330 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			330 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
//---------------------------------------------------------------------------//
 | 
						|
// Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
 | 
						|
//
 | 
						|
// Distributed under the Boost Software License, Version 1.0
 | 
						|
// See accompanying file LICENSE_1_0.txt or copy at
 | 
						|
// http://www.boost.org/LICENSE_1_0.txt
 | 
						|
//
 | 
						|
// See http://boostorg.github.com/compute for more information.
 | 
						|
//---------------------------------------------------------------------------//
 | 
						|
 | 
						|
#ifndef BOOST_COMPUTE_LAMBDA_CONTEXT_HPP
 | 
						|
#define BOOST_COMPUTE_LAMBDA_CONTEXT_HPP
 | 
						|
 | 
						|
#include <boost/proto/core.hpp>
 | 
						|
#include <boost/proto/context.hpp>
 | 
						|
#include <boost/type_traits.hpp>
 | 
						|
#include <boost/preprocessor/repetition.hpp>
 | 
						|
 | 
						|
#include <boost/compute/config.hpp>
 | 
						|
#include <boost/compute/function.hpp>
 | 
						|
#include <boost/compute/lambda/result_of.hpp>
 | 
						|
#include <boost/compute/lambda/functional.hpp>
 | 
						|
#include <boost/compute/type_traits/result_of.hpp>
 | 
						|
#include <boost/compute/type_traits/type_name.hpp>
 | 
						|
#include <boost/compute/detail/meta_kernel.hpp>
 | 
						|
 | 
						|
namespace boost {
 | 
						|
namespace compute {
 | 
						|
namespace lambda {
 | 
						|
 | 
						|
namespace mpl = boost::mpl;
 | 
						|
namespace proto = boost::proto;
 | 
						|
 | 
						|
// Generates a context::operator() overload which handles the binary proto
// node identified by `tag`. The overload writes the left operand, then `op`,
// then the right operand to the context's stream. An operand is enclosed in
// parentheses whenever it is itself a compound expression (proto arity
// greater than zero); operands with zero arity are written as-is.
//
// The macro must be expanded inside a class which provides a `stream` member
// and is usable as a proto evaluation context.
#define BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(tag, op) \
    template<class LHS, class RHS> \
    void operator()(tag, const LHS &lhs, const RHS &rhs) \
    { \
        const bool wrap_lhs = proto::arity_of<LHS>::value > 0; \
        const bool wrap_rhs = proto::arity_of<RHS>::value > 0; \
        \
        if(wrap_lhs){ stream << '('; } \
        proto::eval(lhs, *this); \
        if(wrap_lhs){ stream << ')'; } \
        \
        stream << op; \
        \
        if(wrap_rhs){ stream << '('; } \
        proto::eval(rhs, *this); \
        if(wrap_rhs){ stream << ')'; } \
    }
 | 
						|
 | 
						|
// lambda expression context
 | 
						|
template<class Args>
 | 
						|
struct context : proto::callable_context<context<Args> >
 | 
						|
{
 | 
						|
    typedef void result_type;
 | 
						|
    typedef Args args_tuple;
 | 
						|
 | 
						|
    // create a lambda context for kernel with args
 | 
						|
    context(boost::compute::detail::meta_kernel &kernel, const Args &args_)
 | 
						|
        : stream(kernel),
 | 
						|
          args(args_)
 | 
						|
    {
 | 
						|
    }
 | 
						|
 | 
						|
    // handle terminals
 | 
						|
    template<class T>
 | 
						|
    void operator()(proto::tag::terminal, const T &x)
 | 
						|
    {
 | 
						|
        // terminal values in lambda expressions are always literals
 | 
						|
        stream << stream.lit(x);
 | 
						|
    }
 | 
						|
 | 
						|
    // handle placeholders
 | 
						|
    template<int I>
 | 
						|
    void operator()(proto::tag::terminal, placeholder<I>)
 | 
						|
    {
 | 
						|
        stream << boost::get<I>(args);
 | 
						|
    }
 | 
						|
 | 
						|
    // handle functions
 | 
						|
    #define BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION_ARG(z, n, unused) \
 | 
						|
        BOOST_PP_COMMA_IF(n) BOOST_PP_CAT(const Arg, n) BOOST_PP_CAT(&arg, n)
 | 
						|
 | 
						|
    #define BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION(z, n, unused) \
 | 
						|
    template<class F, BOOST_PP_ENUM_PARAMS(n, class Arg)> \
 | 
						|
    void operator()( \
 | 
						|
        proto::tag::function, \
 | 
						|
        const F &function, \
 | 
						|
        BOOST_PP_REPEAT(n, BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION_ARG, ~) \
 | 
						|
    ) \
 | 
						|
    { \
 | 
						|
        proto::value(function).apply(*this, BOOST_PP_ENUM_PARAMS(n, arg)); \
 | 
						|
    }
 | 
						|
 | 
						|
    BOOST_PP_REPEAT_FROM_TO(1, BOOST_COMPUTE_MAX_ARITY, BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION, ~)
 | 
						|
 | 
						|
    #undef BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION
 | 
						|
 | 
						|
    // operators
 | 
						|
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::plus, '+')
 | 
						|
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::minus, '-')
 | 
						|
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::multiplies, '*')
 | 
						|
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::divides, '/')
 | 
						|
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::modulus, '%')
 | 
						|
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::less, '<')
 | 
						|
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::greater, '>')
 | 
						|
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::less_equal, "<=")
 | 
						|
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::greater_equal, ">=")
 | 
						|
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::equal_to, "==")
 | 
						|
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::not_equal_to, "!=")
 | 
						|
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::logical_and, "&&")
 | 
						|
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::logical_or, "||")
 | 
						|
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_and, '&')
 | 
						|
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_or, '|')
 | 
						|
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_xor, '^')
 | 
						|
    BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::assign, '=')
 | 
						|
 | 
						|
    // subscript operator
 | 
						|
    template<class LHS, class RHS>
 | 
						|
    void operator()(proto::tag::subscript, const LHS &lhs, const RHS &rhs)
 | 
						|
    {
 | 
						|
        proto::eval(lhs, *this);
 | 
						|
        stream << '[';
 | 
						|
        proto::eval(rhs, *this);
 | 
						|
        stream << ']';
 | 
						|
    }
 | 
						|
 | 
						|
    // ternary conditional operator
 | 
						|
    template<class Pred, class Arg1, class Arg2>
 | 
						|
    void operator()(proto::tag::if_else_, const Pred &p, const Arg1 &x, const Arg2 &y)
 | 
						|
    {
 | 
						|
        proto::eval(p, *this);
 | 
						|
        stream << '?';
 | 
						|
        proto::eval(x, *this);
 | 
						|
        stream << ':';
 | 
						|
        proto::eval(y, *this);
 | 
						|
    }
 | 
						|
 | 
						|
    boost::compute::detail::meta_kernel &stream;
 | 
						|
    Args args;
 | 
						|
};
 | 
						|
 | 
						|
namespace detail {
 | 
						|
 | 
						|
template<class Expr, class Arg>
 | 
						|
struct invoked_unary_expression
 | 
						|
{
 | 
						|
    typedef typename ::boost::compute::result_of<Expr(Arg)>::type result_type;
 | 
						|
 | 
						|
    invoked_unary_expression(const Expr &expr, const Arg &arg)
 | 
						|
        : m_expr(expr),
 | 
						|
          m_arg(arg)
 | 
						|
    {
 | 
						|
    }
 | 
						|
 | 
						|
    Expr m_expr;
 | 
						|
    Arg m_arg;
 | 
						|
};
 | 
						|
 | 
						|
template<class Expr, class Arg>
 | 
						|
boost::compute::detail::meta_kernel&
 | 
						|
operator<<(boost::compute::detail::meta_kernel &kernel,
 | 
						|
           const invoked_unary_expression<Expr, Arg> &expr)
 | 
						|
{
 | 
						|
    context<boost::tuple<Arg> > ctx(kernel, boost::make_tuple(expr.m_arg));
 | 
						|
    proto::eval(expr.m_expr, ctx);
 | 
						|
 | 
						|
    return kernel;
 | 
						|
}
 | 
						|
 | 
						|
template<class Expr, class Arg1, class Arg2>
 | 
						|
struct invoked_binary_expression
 | 
						|
{
 | 
						|
    typedef typename ::boost::compute::result_of<Expr(Arg1, Arg2)>::type result_type;
 | 
						|
 | 
						|
    invoked_binary_expression(const Expr &expr,
 | 
						|
                              const Arg1 &arg1,
 | 
						|
                              const Arg2 &arg2)
 | 
						|
        : m_expr(expr),
 | 
						|
          m_arg1(arg1),
 | 
						|
          m_arg2(arg2)
 | 
						|
    {
 | 
						|
    }
 | 
						|
 | 
						|
    Expr m_expr;
 | 
						|
    Arg1 m_arg1;
 | 
						|
    Arg2 m_arg2;
 | 
						|
};
 | 
						|
 | 
						|
template<class Expr, class Arg1, class Arg2>
 | 
						|
boost::compute::detail::meta_kernel&
 | 
						|
operator<<(boost::compute::detail::meta_kernel &kernel,
 | 
						|
           const invoked_binary_expression<Expr, Arg1, Arg2> &expr)
 | 
						|
{
 | 
						|
    context<boost::tuple<Arg1, Arg2> > ctx(
 | 
						|
        kernel,
 | 
						|
        boost::make_tuple(expr.m_arg1, expr.m_arg2)
 | 
						|
    );
 | 
						|
    proto::eval(expr.m_expr, ctx);
 | 
						|
 | 
						|
    return kernel;
 | 
						|
}
 | 
						|
 | 
						|
} // end detail namespace
 | 
						|
 | 
						|
// forward declare domain
 | 
						|
struct domain;
 | 
						|
 | 
						|
// lambda expression wrapper
 | 
						|
template<class Expr>
 | 
						|
struct expression : proto::extends<Expr, expression<Expr>, domain>
 | 
						|
{
 | 
						|
    typedef proto::extends<Expr, expression<Expr>, domain> base_type;
 | 
						|
 | 
						|
    BOOST_PROTO_EXTENDS_USING_ASSIGN(expression)
 | 
						|
 | 
						|
    expression(const Expr &expr = Expr())
 | 
						|
        : base_type(expr)
 | 
						|
    {
 | 
						|
    }
 | 
						|
 | 
						|
    // result_of protocol
 | 
						|
    template<class Signature>
 | 
						|
    struct result
 | 
						|
    {
 | 
						|
    };
 | 
						|
 | 
						|
    template<class This>
 | 
						|
    struct result<This()>
 | 
						|
    {
 | 
						|
        typedef
 | 
						|
            typename ::boost::compute::lambda::result_of<Expr>::type type;
 | 
						|
    };
 | 
						|
 | 
						|
    template<class This, class Arg>
 | 
						|
    struct result<This(Arg)>
 | 
						|
    {
 | 
						|
        typedef
 | 
						|
            typename ::boost::compute::lambda::result_of<
 | 
						|
                Expr,
 | 
						|
                typename boost::tuple<Arg>
 | 
						|
            >::type type;
 | 
						|
    };
 | 
						|
 | 
						|
    template<class This, class Arg1, class Arg2>
 | 
						|
    struct result<This(Arg1, Arg2)>
 | 
						|
    {
 | 
						|
        typedef typename
 | 
						|
            ::boost::compute::lambda::result_of<
 | 
						|
                Expr,
 | 
						|
                typename boost::tuple<Arg1, Arg2>
 | 
						|
            >::type type;
 | 
						|
    };
 | 
						|
 | 
						|
    template<class Arg>
 | 
						|
    detail::invoked_unary_expression<expression<Expr>, Arg>
 | 
						|
    operator()(const Arg &x) const
 | 
						|
    {
 | 
						|
        return detail::invoked_unary_expression<expression<Expr>, Arg>(*this, x);
 | 
						|
    }
 | 
						|
 | 
						|
    template<class Arg1, class Arg2>
 | 
						|
    detail::invoked_binary_expression<expression<Expr>, Arg1, Arg2>
 | 
						|
    operator()(const Arg1 &x, const Arg2 &y) const
 | 
						|
    {
 | 
						|
        return detail::invoked_binary_expression<
 | 
						|
                   expression<Expr>,
 | 
						|
                   Arg1,
 | 
						|
                   Arg2
 | 
						|
                >(*this, x, y);
 | 
						|
    }
 | 
						|
 | 
						|
    // function<> conversion operator
 | 
						|
    template<class R, class A1>
 | 
						|
    operator function<R(A1)>() const
 | 
						|
    {
 | 
						|
        using ::boost::compute::detail::meta_kernel;
 | 
						|
 | 
						|
        std::stringstream source;
 | 
						|
 | 
						|
        ::boost::compute::detail::meta_kernel_variable<A1> arg1("x");
 | 
						|
 | 
						|
        source << "inline " << type_name<R>() << " lambda"
 | 
						|
               << ::boost::compute::detail::generate_argument_list<R(A1)>('x')
 | 
						|
               << "{\n"
 | 
						|
               << "    return " << meta_kernel::expr_to_string((*this)(arg1)) << ";\n"
 | 
						|
               << "}\n";
 | 
						|
 | 
						|
        return make_function_from_source<R(A1)>("lambda", source.str());
 | 
						|
    }
 | 
						|
 | 
						|
    template<class R, class A1, class A2>
 | 
						|
    operator function<R(A1, A2)>() const
 | 
						|
    {
 | 
						|
        using ::boost::compute::detail::meta_kernel;
 | 
						|
 | 
						|
        std::stringstream source;
 | 
						|
 | 
						|
        ::boost::compute::detail::meta_kernel_variable<A1> arg1("x");
 | 
						|
        ::boost::compute::detail::meta_kernel_variable<A1> arg2("y");
 | 
						|
 | 
						|
        source << "inline " << type_name<R>() << " lambda"
 | 
						|
               << ::boost::compute::detail::generate_argument_list<R(A1, A2)>('x')
 | 
						|
               << "{\n"
 | 
						|
               << "    return " << meta_kernel::expr_to_string((*this)(arg1, arg2)) << ";\n"
 | 
						|
               << "}\n";
 | 
						|
 | 
						|
        return make_function_from_source<R(A1, A2)>("lambda", source.str());
 | 
						|
    }
 | 
						|
};
 | 
						|
 | 
						|
// lambda expression domain
 | 
						|
// lambda expression domain
//
// Expressions created in this domain are passed through a generator which
// wraps them in the expression<> template.
struct domain : proto::domain<proto::generator<expression> > {};
 | 
						|
 | 
						|
} // end lambda namespace
 | 
						|
} // end compute namespace
 | 
						|
} // end boost namespace
 | 
						|
 | 
						|
#endif // BOOST_COMPUTE_LAMBDA_CONTEXT_HPP
 |