187 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
		
		
			
		
	
	
			187 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| 
								 | 
							
								// Copyright Jim Bosch 2010-2012.
							 | 
						||
| 
								 | 
							
								// Copyright Stefan Seefeld 2016.
							 | 
						||
| 
								 | 
							
								// Distributed under the Boost Software License, Version 1.0.
							 | 
						||
| 
								 | 
							
								// (See accompanying file LICENSE_1_0.txt or copy at
							 | 
						||
| 
								 | 
							
								// http://www.boost.org/LICENSE_1_0.txt)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								#ifndef boost_python_numpy_invoke_matching_hpp_
							 | 
						||
| 
								 | 
							
								#define boost_python_numpy_invoke_matching_hpp_
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								/**
							 | 
						||
| 
								 | 
							
								 *  @brief Template invocation based on dtype matching.
							 | 
						||
| 
								 | 
							
								 */
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								#include <boost/python/numpy/dtype.hpp>
							 | 
						||
| 
								 | 
							
								#include <boost/python/numpy/ndarray.hpp>
							 | 
						||
| 
								 | 
							
								#include <boost/mpl/integral_c.hpp>
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								namespace boost { namespace python { namespace numpy {
							 | 
						||
| 
								 | 
							
								namespace detail 
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								struct add_pointer_meta 
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  template <typename T>
							 | 
						||
| 
								 | 
							
								  struct apply 
							 | 
						||
| 
								 | 
							
								  {
							 | 
						||
| 
								 | 
							
								    typedef typename boost::add_pointer<T>::type type;
							 | 
						||
| 
								 | 
							
								  };
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								};
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								/// Marker exception thrown by dtype_template_invoker::operator() right after a
								/// matching dtype has been dispatched; invoke_matching_dtype catches it.
								struct dtype_template_match_found {};
							 | 
						||
| 
								 | 
							
								/// Marker exception thrown by nd_template_invoker::operator() right after a
								/// matching dimension count has been dispatched; invoke_matching_nd catches it.
								struct nd_template_match_found {};
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								template <typename Function>
							 | 
						||
| 
								 | 
							
								struct dtype_template_invoker 
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								    
							 | 
						||
| 
								 | 
							
								  template <typename T>
							 | 
						||
| 
								 | 
							
								  void operator()(T *) const 
							 | 
						||
| 
								 | 
							
								  {
							 | 
						||
| 
								 | 
							
								    if (dtype::get_builtin<T>() == m_dtype) 
							 | 
						||
| 
								 | 
							
								    {
							 | 
						||
| 
								 | 
							
								      m_func.Function::template apply<T>();
							 | 
						||
| 
								 | 
							
								      throw dtype_template_match_found();
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  dtype_template_invoker(dtype const & dtype_, Function func) 
							 | 
						||
| 
								 | 
							
								    : m_dtype(dtype_), m_func(func) {}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								private:
							 | 
						||
| 
								 | 
							
								  dtype const & m_dtype;
							 | 
						||
| 
								 | 
							
								  Function m_func;
							 | 
						||
| 
								 | 
							
								};
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								template <typename Function>
							 | 
						||
| 
								 | 
							
								struct dtype_template_invoker< boost::reference_wrapper<Function> > 
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								    
							 | 
						||
| 
								 | 
							
								  template <typename T>
							 | 
						||
| 
								 | 
							
								  void operator()(T *) const 
							 | 
						||
| 
								 | 
							
								  {
							 | 
						||
| 
								 | 
							
								    if (dtype::get_builtin<T>() == m_dtype) 
							 | 
						||
| 
								 | 
							
								    {
							 | 
						||
| 
								 | 
							
								      m_func.Function::template apply<T>();
							 | 
						||
| 
								 | 
							
								      throw dtype_template_match_found();
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  dtype_template_invoker(dtype const & dtype_, Function & func)
							 | 
						||
| 
								 | 
							
								    : m_dtype(dtype_), m_func(func) {}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								private:
							 | 
						||
| 
								 | 
							
								  dtype const & m_dtype;
							 | 
						||
| 
								 | 
							
								  Function & m_func;
							 | 
						||
| 
								 | 
							
								};
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								template <typename Function>
							 | 
						||
| 
								 | 
							
								struct nd_template_invoker 
							 | 
						||
| 
								 | 
							
								{    
							 | 
						||
| 
								 | 
							
								  template <int N>
							 | 
						||
| 
								 | 
							
								  void operator()(boost::mpl::integral_c<int,N> *) const 
							 | 
						||
| 
								 | 
							
								  {
							 | 
						||
| 
								 | 
							
								    if (m_nd == N) 
							 | 
						||
| 
								 | 
							
								    {
							 | 
						||
| 
								 | 
							
								      m_func.Function::template apply<N>();
							 | 
						||
| 
								 | 
							
								      throw nd_template_match_found();
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  nd_template_invoker(int nd, Function func) : m_nd(nd), m_func(func) {}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								private:
							 | 
						||
| 
								 | 
							
								  int m_nd;
							 | 
						||
| 
								 | 
							
								  Function m_func;
							 | 
						||
| 
								 | 
							
								};
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								template <typename Function>
							 | 
						||
| 
								 | 
							
								struct nd_template_invoker< boost::reference_wrapper<Function> > 
							 | 
						||
| 
								 | 
							
								{    
							 | 
						||
| 
								 | 
							
								  template <int N>
							 | 
						||
| 
								 | 
							
								  void operator()(boost::mpl::integral_c<int,N> *) const 
							 | 
						||
| 
								 | 
							
								  {
							 | 
						||
| 
								 | 
							
								    if (m_nd == N) 
							 | 
						||
| 
								 | 
							
								    {
							 | 
						||
| 
								 | 
							
								      m_func.Function::template apply<N>();
							 | 
						||
| 
								 | 
							
								      throw nd_template_match_found();
							 | 
						||
| 
								 | 
							
								    }
							 | 
						||
| 
								 | 
							
								  }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								  nd_template_invoker(int nd, Function & func) : m_nd(nd), m_func(func) {}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								private:
							 | 
						||
| 
								 | 
							
								  int m_nd;
							 | 
						||
| 
								 | 
							
								  Function & m_func;
							 | 
						||
| 
								 | 
							
								};
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								} // namespace boost::python::numpy::detail
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								template <typename Sequence, typename Function>
							 | 
						||
| 
								 | 
							
								void invoke_matching_nd(int nd, Function f) 
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  detail::nd_template_invoker<Function> invoker(nd, f);
							 | 
						||
| 
								 | 
							
								  try { boost::mpl::for_each< Sequence, detail::add_pointer_meta >(invoker);}
							 | 
						||
| 
								 | 
							
								  catch (detail::nd_template_match_found &) { return;}
							 | 
						||
| 
								 | 
							
								  PyErr_SetString(PyExc_TypeError, "number of dimensions not found in template list.");
							 | 
						||
| 
								 | 
							
								  python::throw_error_already_set();
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								template <typename Sequence, typename Function>
							 | 
						||
| 
								 | 
							
								void invoke_matching_dtype(dtype const & dtype_, Function f) 
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  detail::dtype_template_invoker<Function> invoker(dtype_, f);
							 | 
						||
| 
								 | 
							
								  try { boost::mpl::for_each< Sequence, detail::add_pointer_meta >(invoker);}
							 | 
						||
| 
								 | 
							
								  catch (detail::dtype_template_match_found &) { return;}
							 | 
						||
| 
								 | 
							
								  PyErr_SetString(PyExc_TypeError, "dtype not found in template list.");
							 | 
						||
| 
								 | 
							
								  python::throw_error_already_set();
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								namespace detail 
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								/**
								 *  @brief Adapter that fixes the element type T and exposes apply<N>().
								 *
								 *  apply<N>() forwards to Function::apply<T,N>() on the referenced function
								 *  object.  The function object is held by reference, so it must outlive the
								 *  wrapper; being a reference member, it is called as a non-const Function
								 *  even though apply<N>() itself is const.
								 */
								template <typename T, typename Function>
								struct array_template_invoker_wrapper_2
								{
								  /// Forward to the two-parameter apply<T,N>() of the wrapped function object.
								  template <int N>
								  void apply() const { m_func.Function::template apply<T,N>();}
								  array_template_invoker_wrapper_2(Function & func) : m_func(func) {}

								private:
								  Function & m_func;
								};
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								template <typename DimSequence, typename Function>
							 | 
						||
| 
								 | 
							
								struct array_template_invoker_wrapper_1 
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  template <typename T>
							 | 
						||
| 
								 | 
							
								  void apply() const { invoke_matching_nd<DimSequence>(m_nd, array_template_invoker_wrapper_2<T,Function>(m_func));}
							 | 
						||
| 
								 | 
							
								  array_template_invoker_wrapper_1(int nd, Function & func) : m_nd(nd), m_func(func) {}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								private:
							 | 
						||
| 
								 | 
							
								  int m_nd;
							 | 
						||
| 
								 | 
							
								  Function & m_func;
							 | 
						||
| 
								 | 
							
								};
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								template <typename DimSequence, typename Function>
							 | 
						||
| 
								 | 
							
								struct array_template_invoker_wrapper_1< DimSequence, boost::reference_wrapper<Function> >
							 | 
						||
| 
								 | 
							
								  : public array_template_invoker_wrapper_1< DimSequence, Function >
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  array_template_invoker_wrapper_1(int nd, Function & func)
							 | 
						||
| 
								 | 
							
								    : array_template_invoker_wrapper_1< DimSequence, Function >(nd, func) {}
							 | 
						||
| 
								 | 
							
								};
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								} // namespace boost::python::numpy::detail
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								template <typename TypeSequence, typename DimSequence, typename Function>
							 | 
						||
| 
								 | 
							
								void invoke_matching_array(ndarray const & array_, Function f) 
							 | 
						||
| 
								 | 
							
								{
							 | 
						||
| 
								 | 
							
								  detail::array_template_invoker_wrapper_1<DimSequence,Function> wrapper(array_.get_nd(), f);
							 | 
						||
| 
								 | 
							
								  invoke_matching_dtype<TypeSequence>(array_.get_dtype(), wrapper);
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								}}} // namespace boost::python::numpy
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								#endif
							 |