http://git-wip-us.apache.org/repos/asf/mahout/blob/f7c1f802/native-viennaCL/src/main/cpp/viennacl/device_specific/mapped_objects.hpp ---------------------------------------------------------------------- diff --git a/native-viennaCL/src/main/cpp/viennacl/device_specific/mapped_objects.hpp b/native-viennaCL/src/main/cpp/viennacl/device_specific/mapped_objects.hpp new file mode 100644 index 0000000..19f7993 --- /dev/null +++ b/native-viennaCL/src/main/cpp/viennacl/device_specific/mapped_objects.hpp @@ -0,0 +1,512 @@ +#ifndef VIENNACL_DEVICE_SPECIFIC_MAPPED_TYPE_HPP +#define VIENNACL_DEVICE_SPECIFIC_MAPPED_TYPE_HPP + +/* ========================================================================= + Copyright (c) 2010-2016, Institute for Microelectronics, + Institute for Analysis and Scientific Computing, + TU Wien. + Portions of this software are copyright by UChicago Argonne, LLC. + + ----------------- + ViennaCL - The Vienna Computing Library + ----------------- + + Project Head: Karl Rupp [email protected] + + (A list of authors and contributors can be found in the manual) + + License: MIT (X11), see file LICENSE in the base directory +============================================================================= */ + + +/** @file viennacl/device_specific/mapped_objects.hpp + @brief Map ViennaCL objects to generator wrappers +*/ + +#include <string> + +#include "viennacl/scheduler/forwards.h" +#include "viennacl/device_specific/forwards.h" +#include "viennacl/device_specific/utils.hpp" + +namespace viennacl +{ + +namespace device_specific +{ + +/** @brief Mapped Object +* +* This object populates the symbolic mapping associated with a statement. 
(root_id, LHS|RHS|PARENT) => mapped_object +* The tree can then be reconstructed in its symbolic form +*/ +class mapped_object +{ +private: + virtual void postprocess(std::string &) const { } + +protected: + struct MorphBase { virtual ~MorphBase(){} }; + struct MorphBase1D : public MorphBase { public: virtual std::string operator()(std::string const & i) const = 0; }; + struct MorphBase2D : public MorphBase { public: virtual std::string operator()(std::string const & i, std::string const & j) const = 0; }; + + static void replace_offset(std::string & str, MorphBase const & morph) + { + vcl_size_t pos = 0; + while ((pos=str.find("$OFFSET", pos))!=std::string::npos) + { + std::string postprocessed; + vcl_size_t pos_po = str.find('{', pos); + vcl_size_t pos_pe = str.find('}', pos_po); + + if (MorphBase2D const * p2d = dynamic_cast<MorphBase2D const *>(&morph)) + { + vcl_size_t pos_comma = str.find(',', pos_po); + std::string i = str.substr(pos_po + 1, pos_comma - pos_po - 1); + std::string j = str.substr(pos_comma + 1, pos_pe - pos_comma - 1); + postprocessed = (*p2d)(i, j); + } + else if (MorphBase1D const * p1d = dynamic_cast<MorphBase1D const *>(&morph)) + { + std::string i = str.substr(pos_po + 1, pos_pe - pos_po - 1); + postprocessed = (*p1d)(i); + } + + str.replace(pos, pos_pe + 1 - pos, postprocessed); + pos = pos_pe; + } + } + + void register_attribute(std::string & attribute, std::string const & key, std::string const & value) + { + attribute = value; + keywords_[key] = attribute; + } + +public: + struct node_info + { + node_info(mapping_type const * _mapping, scheduler::statement const * _statement, vcl_size_t _root_idx) : + mapping(_mapping), statement(_statement), root_idx(_root_idx) { } + mapping_type const * mapping; + scheduler::statement const * statement; + vcl_size_t root_idx; + }; + +public: + mapped_object(std::string const & scalartype, unsigned int id, std::string const & type_key) : type_key_(type_key) + { + register_attribute(scalartype_, 
"#scalartype", scalartype); + register_attribute(name_, "#name", "obj" + tools::to_string(id)); + } + + virtual ~mapped_object(){ } + + virtual std::string & append_kernel_arguments(std::set<std::string> &, std::string & str, unsigned int) const { return str; } + + std::string type_key() const { return type_key_; } + + std::string const & name() const { return name_; } + + std::string process(std::string const & in) const + { + std::string res(in); + for (std::map<std::string,std::string>::const_iterator it = keywords_.begin(); it != keywords_.end(); ++it) + tools::find_and_replace(res, it->first, it->second); + postprocess(res); + return res; + } + + std::string evaluate(std::map<std::string, std::string> const & accessors) const + { + if (accessors.find(type_key_)==accessors.end()) + return name_; + return process(at(accessors, type_key_)); + } + + +protected: + std::string name_; + std::string scalartype_; + std::string type_key_; + std::map<std::string, std::string> keywords_; +}; + + +/** @brief Binary leaf interface +* +* Some subtrees have to be interpret at leaves when reconstructing the final expression. It is the case of trans(), diag(), prod(), etc... 
+* This interface stores basic infos about the sub-trees +*/ +class binary_leaf +{ +public: + binary_leaf(mapped_object::node_info info) : info_(info){ } + + void process_recursive(utils::kernel_generation_stream & stream, leaf_t leaf, std::string const & key, std::string const & process_str, std::set<std::string> & already_fetched) + { + tree_parsing::process(stream, leaf, key, process_str, *info_.statement, info_.root_idx, *info_.mapping, already_fetched); + } + + std::string evaluate_recursive(leaf_t leaf, std::map<std::string, std::string> const & accessors) + { + return tree_parsing::evaluate(leaf, accessors, *info_.statement, info_.root_idx, *info_.mapping); + } + +protected: + mapped_object::node_info info_; +}; + +/** @brief Matrix product + * + * Maps prod(matrix_expression, matrix_expression) + */ +class mapped_matrix_product : public mapped_object, public binary_leaf +{ +public: + mapped_matrix_product(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "matrix_product"), binary_leaf(info) { } +}; + +/** @brief Reduction +* +* Base class for mapping a reduction +*/ +class mapped_reduction : public mapped_object, public binary_leaf +{ +public: + mapped_reduction(std::string const & scalartype, unsigned int id, node_info info, std::string const & type_key) : mapped_object(scalartype, id, type_key), binary_leaf(info){ } + + vcl_size_t root_idx() const { return info_.root_idx; } + scheduler::statement const & statement() const { return *info_.statement; } + scheduler::statement_node root_node() const { return statement().array()[root_idx()]; } + bool is_index_reduction() const { return utils::is_index_reduction(info_.statement->array()[info_.root_idx].op); } + + scheduler::op_element root_op() const + { + scheduler::op_element res = info_.statement->array()[info_.root_idx].op; + if (res.type==scheduler::OPERATION_BINARY_MAT_VEC_PROD_TYPE + ||res.type==scheduler::OPERATION_BINARY_INNER_PROD_TYPE) + res.type = 
scheduler::OPERATION_BINARY_ADD_TYPE; + return res; + } +}; + +/** @brief Scalar reduction +* +* Maps a scalar reduction (max, min, argmax, inner_prod, etc..) +*/ +class mapped_scalar_reduction : public mapped_reduction +{ +public: + mapped_scalar_reduction(std::string const & scalartype, unsigned int id, node_info info) : mapped_reduction(scalartype, id, info, "scalar_reduction"){ } +}; + +/** @brief Vector reduction +* +* Maps a row-wise reduction (max, min, argmax, matrix-vector product, etc..) +*/ +class mapped_row_wise_reduction : public mapped_reduction +{ +public: + mapped_row_wise_reduction(std::string const & scalartype, unsigned int id, node_info info) : mapped_reduction(scalartype, id, info, "row_wise_reduction") { } +}; + +/** @brief Host scalar + * + * Maps a host scalar (passed by value) + */ +class mapped_host_scalar : public mapped_object +{ +public: + mapped_host_scalar(std::string const & scalartype, unsigned int id) : mapped_object(scalartype, id, "host_scalar"){ } + + std::string & append_kernel_arguments(std::set<std::string> & already_generated, std::string & str, unsigned int width) const + { + if (already_generated.insert(name_).second) + str += generate_value_kernel_argument(utils::append_width(scalartype_, width), name_); + return str; + } +}; + +/** @brief Handle +* +* Maps an object passed by pointer +*/ +class mapped_handle : public mapped_object +{ +private: + virtual void append_optional_arguments(std::string &) const = 0; + +public: + mapped_handle(std::string const & scalartype, unsigned int id, std::string const & type_key) : mapped_object(scalartype, id, type_key) + { + register_attribute(pointer_, "#pointer", name_ + "_pointer"); + } + + std::string & append_kernel_arguments(std::set<std::string> & already_generated, std::string & str, unsigned int width) const + { + if (already_generated.insert(name_).second) + { + str += generate_pointer_kernel_argument("__global", utils::append_width(scalartype_, width), pointer_); + 
append_optional_arguments(str); + } + return str; + } + +private: + std::string pointer_; +}; + + +/** @brief Scalar + * + * Maps a scalar passed by pointer + */ +class mapped_scalar : public mapped_handle +{ +private: + void append_optional_arguments(std::string &) const{ } + +public: + mapped_scalar(std::string const & scalartype, unsigned int id) : mapped_handle(scalartype, id, "scalar") { } +}; + +/** @brief Buffered + * + * Maps a buffered object (vector, matrix) + */ +class mapped_buffer : public mapped_handle +{ +public: + mapped_buffer(std::string const & scalartype, unsigned int id, std::string const & type_key) : mapped_handle(scalartype, id, type_key){ } +}; + +/** @brief Vector + * + * Maps a vector + */ +class mapped_vector : public mapped_buffer +{ + void append_optional_arguments(std::string & str) const + { + str += generate_value_kernel_argument("unsigned int", start_); + str += generate_value_kernel_argument("unsigned int", stride_); + } + +public: + mapped_vector(std::string const & scalartype, unsigned int id) : mapped_buffer(scalartype, id, "vector") + { + register_attribute(start_, "#start", name_ + "_start"); + register_attribute(stride_, "#stride", name_ + "_stride"); + } + +private: + std::string start_; + std::string stride_; +}; + +/** @brief Matrix + * + * Maps a matrix + */ +class mapped_matrix : public mapped_buffer +{ +private: + void append_optional_arguments(std::string & str) const + { + str += generate_value_kernel_argument("unsigned int", ld_); + str += generate_value_kernel_argument("unsigned int", start1_); + str += generate_value_kernel_argument("unsigned int", start2_); + str += generate_value_kernel_argument("unsigned int", stride1_); + str += generate_value_kernel_argument("unsigned int", stride2_); + } + + void postprocess(std::string & str) const + { + struct Morph : public MorphBase2D + { + Morph(bool _is_row_major, std::string const & _ld) : is_row_major(_is_row_major), ld(_ld){ } + std::string operator()(std::string 
const & i, std::string const & j) const + { + if (is_row_major) + return "(" + i + ") * " + ld + " + (" + j + ")"; + return "(" + i + ") + (" + j + ") * " + ld; + } + private: + bool is_row_major; + std::string const & ld; + }; + replace_offset(str, Morph(row_major_, ld_)); + } + +public: + mapped_matrix(std::string const & scalartype, unsigned int id, bool row_major) : mapped_buffer(scalartype, id, "matrix"), row_major_(row_major) + { + register_attribute(ld_, "#ld", name_ + "_ld"); + register_attribute(start1_, "#start1", name_ + "_start1"); + register_attribute(start2_, "#start2", name_ + "_start2"); + register_attribute(stride1_, "#stride1", name_ + "_stride1"); + register_attribute(stride2_, "#stride2", name_ + "_stride2"); + if (row_major_) + keywords_["#nldstride"] = "#stride1"; + else + keywords_["#nldstride"] = "#stride2"; + + if (row_major_) + { + std::swap(start1_, start2_); + std::swap(stride1_, stride2_); + } + } + + bool row_major() const + { + return row_major_; + } + +private: + std::string ld_; + std::string start1_; + std::string start2_; + std::string stride1_; + std::string stride2_; + bool row_major_; +}; + +/** @brief Vector diag +* +* Maps a diag(vector_expression) node into a diagonal matrix +*/ +class mapped_vector_diag : public mapped_object, public binary_leaf +{ +private: + void postprocess(std::string &res) const + { + std::map<std::string, std::string> accessors; + tools::find_and_replace(res, "#diag_offset", tree_parsing::evaluate(RHS_NODE_TYPE, accessors, *info_.statement, info_.root_idx, *info_.mapping)); + accessors["vector"] = res; + res = tree_parsing::evaluate(LHS_NODE_TYPE, accessors, *info_.statement, info_.root_idx, *info_.mapping); + } + +public: + mapped_vector_diag(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "vector_diag"), binary_leaf(info){ } +}; + + +/** @brief Trans +* +* Maps trans(matrix_expression) into the transposed of matrix_expression +*/ +class mapped_trans: 
public mapped_object, public binary_leaf +{ +private: + void postprocess(std::string &res) const + { + std::map<std::string, std::string> accessors; + accessors["matrix"] = res; + res = tree_parsing::evaluate(LHS_NODE_TYPE, accessors, *info_.statement, info_.root_idx, *info_.mapping); + } + +public: + mapped_trans(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "matrix_trans"), binary_leaf(info){ } +}; + +/** @brief Matrix row +* +* Maps row(matrix_expression, scalar_expression) into the scalar_expression's row of matrix_expression +*/ +class mapped_matrix_row : public mapped_object, binary_leaf +{ +private: + void postprocess(std::string &res) const + { + std::map<std::string, std::string> accessors; + tools::find_and_replace(res, "#row", tree_parsing::evaluate(RHS_NODE_TYPE, accessors, *info_.statement, info_.root_idx, *info_.mapping)); + accessors["matrix"] = res; + res = tree_parsing::evaluate(LHS_NODE_TYPE, accessors, *info_.statement, info_.root_idx, *info_.mapping); + } + +public: + mapped_matrix_row(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "matrix_row"), binary_leaf(info) + { } +}; + + +/** @brief Matrix column +* +* Maps column(matrix_expression, scalar_expression) into the scalar_expression's column of matrix_expression +*/ +class mapped_matrix_column : public mapped_object, binary_leaf +{ +private: + void postprocess(std::string &res) const + { + std::map<std::string, std::string> accessors; + tools::find_and_replace(res, "#column", tree_parsing::evaluate(RHS_NODE_TYPE, accessors, *info_.statement, info_.root_idx, *info_.mapping)); + accessors["matrix"] = res; + res = tree_parsing::evaluate(LHS_NODE_TYPE, accessors, *info_.statement, info_.root_idx, *info_.mapping); + } + +public: + mapped_matrix_column(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "matrix_column"), binary_leaf(info) + { } +}; + +/** 
@brief Matrix diag +* +* Maps a diag(matrix_expression) node into the vector of its diagonal elements +*/ +class mapped_matrix_diag : public mapped_object, binary_leaf +{ +private: + void postprocess(std::string &res) const + { + std::map<std::string, std::string> accessors; + tools::find_and_replace(res, "#diag_offset", tree_parsing::evaluate(RHS_NODE_TYPE, accessors, *info_.statement, info_.root_idx, *info_.mapping)); + accessors["matrix"] = res; + res = tree_parsing::evaluate(LHS_NODE_TYPE, accessors, *info_.statement, info_.root_idx, *info_.mapping); + } + +public: + mapped_matrix_diag(std::string const & scalartype, unsigned int id, node_info info) : mapped_object(scalartype, id, "matrix_diag"), binary_leaf(info) + { } +}; + +/** @brief Implicit vector + * + * Maps an implicit vector + */ +class mapped_implicit_vector : public mapped_object +{ +public: + mapped_implicit_vector(std::string const & scalartype, unsigned int id) : mapped_object(scalartype, id, "implicit_vector") + { } + + std::string & append_kernel_arguments(std::set<std::string> & /*already_generated*/, std::string & str, unsigned int width) const + { + str += generate_value_kernel_argument(utils::append_width(scalartype_, width), name_); + return str; + } +}; + +/** @brief Implicit matrix + * + * Maps an implicit matrix + */ +class mapped_implicit_matrix : public mapped_object +{ +public: + mapped_implicit_matrix(std::string const & scalartype, unsigned int id) : mapped_object(scalartype, id, "implicit_matrix") + { } + + std::string & append_kernel_arguments(std::set<std::string> & /*already_generated*/, std::string & str, unsigned int width) const + { + str += generate_value_kernel_argument(utils::append_width(scalartype_, width), name_); + return str; + } +}; + +} + +} +#endif
http://git-wip-us.apache.org/repos/asf/mahout/blob/f7c1f802/native-viennaCL/src/main/cpp/viennacl/device_specific/templates/matrix_product_template.hpp ---------------------------------------------------------------------- diff --git a/native-viennaCL/src/main/cpp/viennacl/device_specific/templates/matrix_product_template.hpp b/native-viennaCL/src/main/cpp/viennacl/device_specific/templates/matrix_product_template.hpp new file mode 100644 index 0000000..1f082ac --- /dev/null +++ b/native-viennaCL/src/main/cpp/viennacl/device_specific/templates/matrix_product_template.hpp @@ -0,0 +1,859 @@ +#ifndef VIENNACL_DEVICE_SPECIFIC_TEMPLATES_MATRIX_PRODUCT_HPP +#define VIENNACL_DEVICE_SPECIFIC_TEMPLATES_MATRIX_PRODUCT_HPP + +/* ========================================================================= +Copyright (c) 2010-2016, Institute for Microelectronics, + Institute for Analysis and Scientific Computing, + TU Wien. +Portions of this software are copyright by UChicago Argonne, LLC. + + ----------------- + ViennaCL - The Vienna Computing Library + ----------------- + +Project Head: Karl Rupp [email protected] + +(A list of authors and contributors can be found in the manual) + +License: MIT (X11), see file LICENSE in the base directory +============================================================================= */ + + +/** @file viennacl/device_specific/templates/matrix_product_template.hpp +* +* Kernel template for the matrix product operation +*/ + +#include <vector> + +#include "viennacl/scheduler/forwards.h" + +#include "viennacl/detail/matrix_def.hpp" +#include "viennacl/matrix_proxy.hpp" + +#include "viennacl/device_specific/templates/template_base.hpp" +#include "viennacl/device_specific/mapped_objects.hpp" +#include "viennacl/device_specific/utils.hpp" +#include "viennacl/device_specific/tree_parsing.hpp" +#include "viennacl/forwards.h" + +#include "viennacl/tools/tools.hpp" + +namespace viennacl +{ +namespace device_specific +{ + +struct matrix_product_parameters 
: public template_base::parameters_type +{ + matrix_product_parameters(unsigned int simd_width + , unsigned int local_size_0, unsigned int KL, unsigned int local_size_1 + , unsigned int ms, unsigned int ks, unsigned int ns + , fetching_policy_type A_fetching_policy_param, fetching_policy_type B_fetching_policy_param + , unsigned int local_fetch_0_param, unsigned int local_fetch_1_param): template_base::parameters_type(simd_width, local_size_0, local_size_1, 1), + kL(KL), mS(ms), kS(ks), nS(ns), A_fetching_policy(A_fetching_policy_param), B_fetching_policy(B_fetching_policy_param), + local_fetch_0(local_fetch_0_param), local_fetch_1(local_fetch_1_param), + mL(ms*local_size_0), nL(ns*local_size_1){} + + unsigned int kL; + + unsigned int mS; + unsigned int kS; + unsigned int nS; + + fetching_policy_type A_fetching_policy; + fetching_policy_type B_fetching_policy; + + unsigned int local_fetch_0; + unsigned int local_fetch_1; + + unsigned int mL; + unsigned int nL; +}; + +class matrix_product_template : public template_base_impl<matrix_product_template, matrix_product_parameters> +{ + +private: + unsigned int n_lmem_elements() const + { + unsigned int N = 0; + if (p_.A_fetching_policy==FETCH_FROM_LOCAL) + N += p_.kL * (p_.mL+1); + if (p_.B_fetching_policy==FETCH_FROM_LOCAL) + N += p_.nL * (p_.kL+1); + return N; + } + + int check_invalid_impl(viennacl::ocl::device const & /*device*/) const + { + if (p_.A_fetching_policy!=FETCH_FROM_LOCAL && p_.B_fetching_policy!=FETCH_FROM_LOCAL&& (p_.local_fetch_0!=0 || p_.local_fetch_1!=0)) + return TEMPLATE_GLOBAL_MEMORY_REQUIRES_ZERO_LOCAL_FETCH; + + if ((p_.mS % p_.simd_width) > 0 || (p_.nS % p_.simd_width) > 0) + return TEMPLATE_MS_NS_MUST_BE_SIMD_WIDTH_MULTIPLE; + + if (p_.kS > p_.kL) + return TEMPLATE_KS_MUST_BE_SMALLER_THAN_KL; + + if (!(A_trans_=='N' && B_trans_=='T') && p_.simd_width>1) + return TEMPLATE_SIMD_WIDTH_MUST_BE_ONE; + + if (p_.A_fetching_policy==FETCH_FROM_LOCAL || p_.B_fetching_policy==FETCH_FROM_LOCAL) + { + if 
((p_.local_fetch_0*p_.local_fetch_1) !=(p_.local_size_0*p_.local_size_1)) + return TEMPLATE_LOCAL_FETCH_PRODUCT_MUST_MATCH_LOCAL_SIZE_PRODUCT; + } + + if (p_.A_fetching_policy==FETCH_FROM_LOCAL) + { + unsigned int bound1 = (A_trans_=='N')?p_.kL:p_.mL; + unsigned int bound0 = (A_trans_=='N')?p_.mL:p_.kL; + + if (p_.local_fetch_1>0 && (bound1 % p_.local_fetch_1)> 0) + return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE; + + if (p_.local_fetch_0>0 && (bound0 % (p_.local_fetch_0*p_.simd_width)) > 0) + return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_0_MUST_BE_NL_MULTIPLE:TEMPLATE_LOCAL_FETCH_0_MUST_BE_KL_MULTIPLE; + + } + if (p_.B_fetching_policy==FETCH_FROM_LOCAL) + { + unsigned int bound1 = (B_trans_=='T')?p_.kL:p_.nL; + unsigned int bound0 = (B_trans_=='T')?p_.nL:p_.kL; + + if (p_.local_fetch_1>0 && (bound1 % p_.local_fetch_1)> 0) + return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE; + + if (p_.local_fetch_0>0 && (bound0 % (p_.local_fetch_0*p_.simd_width)) > 0) + return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE; + + } + + return TEMPLATE_VALID; + } + + static void parse(scheduler::statement const & s, + vcl_size_t & C_idx, leaf_t & C_leaf, vcl_size_t & alpha_idx, leaf_t & alpha_leaf, + vcl_size_t & A_idx, leaf_t & A_leaf, bool& A_trans, vcl_size_t & B_idx, leaf_t & B_leaf, bool& B_trans, + vcl_size_t & beta_idx, leaf_t & beta_leaf) + { + using namespace tree_parsing; + using namespace scheduler; + + scheduler::statement::container_type const & array = s.array(); + vcl_size_t root_idx = s.root(); + + C_idx = root_idx; + C_leaf = LHS_NODE_TYPE; + + vcl_size_t node_add_idx = array[root_idx].rhs.node_index; + + vcl_size_t node_1_idx = array[node_add_idx].lhs.node_index; + alpha_idx = node_1_idx; + alpha_leaf = RHS_NODE_TYPE; + + vcl_size_t mat_prod_idx = array[node_1_idx].lhs.node_index; + if 
(array[mat_prod_idx].lhs.type_family==MATRIX_TYPE_FAMILY) + { + A_trans = false; + A_idx = mat_prod_idx; + } + else + { + A_trans = true; + A_idx = array[mat_prod_idx].lhs.node_index; + } + A_leaf = LHS_NODE_TYPE; + + if (array[mat_prod_idx].rhs.type_family==MATRIX_TYPE_FAMILY) + { + B_trans = false; + B_idx = mat_prod_idx; + B_leaf = RHS_NODE_TYPE; + } + else + { + B_trans = true; + B_idx = array[mat_prod_idx].rhs.node_index; + B_leaf = LHS_NODE_TYPE; + } + + vcl_size_t node_2_idx = array[node_add_idx].rhs.node_index; + beta_idx = node_2_idx; + beta_leaf = RHS_NODE_TYPE; + } + + void VIENNACL_HANDLE_BOUNDS(bool fallback, utils::kernel_generation_stream & stream, std::string const & inbounds, std::string const & do_if, std::string do_else) const + { + if (fallback) + { + stream << "if (" << inbounds << ")" << std::endl; + stream.inc_tab(); + stream << do_if << ";" << std::endl; + stream.dec_tab(); + stream << "else" << std::endl; + stream.inc_tab(); + stream << do_else << ";" << std::endl; + stream.dec_tab(); + } + else + stream << do_if << ";" << std::endl; + } + + + std::string generate_impl(const std::string &kernel_prefix, const statements_container &statements, const std::vector<mapping_type> &mappings, bool fallback) const + { + using std::string; + using tools::to_string; + + parameters_type pfallback(1, p_.local_size_0, p_.kL, p_.local_size_1, p_.mS, 1, p_.nS, p_.A_fetching_policy, p_.B_fetching_policy, p_.local_fetch_0, p_.local_fetch_1); + parameters_type const & p = fallback?pfallback:p_; + +#define VIENNACL_MUL_STRIDE1 string(fallback?"*#stride1":"") +#define VIENNACL_HANDLE_BOUNDS(in_bounds, to_load) (!fallback?string(to_load):string( string(in_bounds) + "?" 
+ string(to_load) + ":0")) +#define VIENNACL_VSTORE(value, offset, ptr) vstore(p.simd_width, value, offset, ptr) + + string widthstr = tools::to_string(p.simd_width); + + ////////////////// + /// INIT + /// ////////////// + utils::kernel_generation_stream stream; + scheduler::statement const & st = statements.data().front(); + mapping_type const & mapping = mappings.front(); + + bool A_trans = false, B_trans = false; + vcl_size_t C_idx=0, alpha_idx=0, A_idx=0, B_idx=0, beta_idx=0; + leaf_t C_leaf=LHS_NODE_TYPE, alpha_leaf=LHS_NODE_TYPE, A_leaf=LHS_NODE_TYPE, B_leaf=LHS_NODE_TYPE, beta_leaf=LHS_NODE_TYPE; + parse(st, C_idx, C_leaf, alpha_idx, alpha_leaf, A_idx, A_leaf, A_trans, B_idx, B_leaf, B_trans, beta_idx, beta_leaf); + + mapped_matrix * C = (mapped_matrix* )at(mapping, mapping_key( C_idx, C_leaf)).get(); + mapped_host_scalar * alpha = (mapped_host_scalar*)at(mapping, mapping_key(alpha_idx, alpha_leaf)).get(); + mapped_matrix * A = (mapped_matrix* )at(mapping, mapping_key( A_idx, A_leaf)).get(); + mapped_matrix * B = (mapped_matrix* )at(mapping, mapping_key( B_idx, B_leaf)).get(); + mapped_host_scalar * beta = (mapped_host_scalar*)at(mapping, mapping_key( beta_idx, beta_leaf)).get(); + + ////////////////// + /// DECLARATIONS + /// ////////////// + + stream << " __attribute__((reqd_work_group_size(" << p.local_size_0 << "," << p.local_size_1 << ",1)))" << std::endl; + std::map<std::string, unsigned int> widths; + widths[A->name()] = p.simd_width; + widths[B->name()] = p.simd_width; + generate_prototype(stream, kernel_prefix, "unsigned int M, unsigned int N, unsigned int K, ", mappings, statements, widths); + stream << "{" << std::endl; + stream.inc_tab(); + if(!fallback) + { + stream << A->process("#start1 /= " + to_string(p.simd_width) + ";") << std::endl; + stream << A->process("#ld /= " + to_string(p.simd_width) + ";") << std::endl; + stream << B->process("#start1/= " + to_string(p.simd_width) + ";") << std::endl; + stream << B->process("#ld /= " + 
to_string(p.simd_width) + ";") << std::endl; + } + tree_parsing::process(stream, PARENT_NODE_TYPE, "matrix", "#pointer += $OFFSET{#start1, #start2};", statements, mappings); + tree_parsing::process(stream, PARENT_NODE_TYPE, "matrix", "#ld *= #nldstride;", statements, mappings); + + ///Result Values + stream << C->process("#scalartype rC[" + to_string(p.mS) + "][" + to_string(p.nS) + "] = {{(#scalartype)0}};") << std::endl; + if (p.A_fetching_policy==FETCH_FROM_LOCAL) + stream << A->process("#scalartype rA[" + to_string(p.kS) + "][" + to_string(p.mS) + "];") << std::endl; + else + stream << A->process(utils::append_width("#scalartype",p.simd_width) + " rA[" + to_string(p.kS) + "][" + to_string(p.mS/p.simd_width) + "];") << std::endl; + if (p.B_fetching_policy==FETCH_FROM_LOCAL) + stream << B->process("#scalartype rB[" + to_string(p.kS) + "][" + to_string(p.nS) + "];"); + else + stream << B->process(utils::append_width("#scalartype",p.simd_width) + " rB[" + to_string(p.kS) + "][" + to_string(p.nS/p.simd_width) + "];") << std::endl; + + + if (p.A_fetching_policy==FETCH_FROM_LOCAL) + stream << A->process("__local #scalartype lA[" + to_string(p.kL*(p.mL+1)) + "];"); + if (p.B_fetching_policy==FETCH_FROM_LOCAL) + stream << B->process("__local #scalartype lB[" + to_string(p.kL*(p.nL+1)) + "];"); + stream << std::endl; + + stream << "size_t gidx = get_group_id(0);" << std::endl; + stream << "size_t gidy = get_group_id(1);" << std::endl; + stream << "size_t idx = get_local_id(0);" << std::endl; + stream << "size_t idy = get_local_id(1);" << std::endl; + + if (p.A_fetching_policy==FETCH_FROM_LOCAL || p.B_fetching_policy==FETCH_FROM_LOCAL) + { + stream << std::endl; + stream << "size_t idt = " << p.local_size_0 << "*idy + idx;" << std::endl; + stream << "size_t idxT = idt % " << p.local_fetch_0 << ";" << std::endl; + stream << "size_t idyT = idt / " << p.local_fetch_0 << ";" << std::endl; + } + stream << std::endl; + + if (fallback) + { + //Bounds checking for M (in A, C) + 
stream << "bool in_bounds_m[" << p.mS << "];" << std::endl; + stream << "for(size_t m = 0; m < " << p.mS << "; m++)" << std::endl; + stream.inc_tab(); + switch (p.A_fetching_policy) + { + case FETCH_FROM_GLOBAL_CONTIGUOUS: + stream << "in_bounds_m[m] = gidx*" << p.mL << " + idx*" << p.mS << " + m < M;" << std::endl; + break; + default: + stream << "in_bounds_m[m] = gidx*" << p.mL << " + idx + m*" << p.local_size_0 << " < M;" << std::endl; + break; + } + stream.dec_tab(); + + //Bounds checking for A if Local + if (p.A_fetching_policy==FETCH_FROM_LOCAL) + { + unsigned int fetch_size = (A_trans_=='N'?p.local_fetch_0*p.simd_width:p.local_fetch_1); + stream << "bool in_bounds_m_local[" << p.mL/fetch_size << "];" << std::endl; + stream << "for(size_t m = 0; m < " << p.mL/fetch_size << "; m++)" << std::endl; + stream.inc_tab(); + stream << "in_bounds_m_local[m] = gidx*" << p.mL << " + " << (A_trans_=='N'?"idxT":"idyT") << " + m*" << fetch_size << " < M;" << std::endl; + stream.dec_tab(); + } + + //Bounds checking for N (in B, C) + stream << "bool in_bounds_n[" << p.nS << "];" << std::endl; + stream << "for(size_t n = 0; n < " << p.nS << "; n++)" << std::endl; + stream.inc_tab(); + switch (p.B_fetching_policy) + { + case FETCH_FROM_GLOBAL_CONTIGUOUS: + stream << "in_bounds_n[n] = gidy*" << p.nL << " + idy*" << p.nS << " + n < N;" << std::endl; + break; + default: + stream << "in_bounds_n[n] = gidy*" << p.nL << " + idy + n*" << p.local_size_1 << " < N;" << std::endl; + break; + } + stream.dec_tab(); + + //Bounds checking for B if Local + if (p.B_fetching_policy==FETCH_FROM_LOCAL) + { + unsigned int fetch_size = (B_trans_=='T'?p.local_fetch_0*p.simd_width:p.local_fetch_1); + stream << "bool in_bounds_n_local[" << p.nL/fetch_size << "];" << std::endl; + stream << "for(size_t n = 0; n < " << p.nL/fetch_size << "; n++)" << std::endl; + stream.inc_tab(); + stream << "in_bounds_n_local[n] = gidy*" << p.nL << " + " << (B_trans_=='T'?"idxT":"idyT") << " + n*" << fetch_size << " < 
N;" << std::endl; + stream.dec_tab(); + } + } + + switch (p.A_fetching_policy) + { + case FETCH_FROM_LOCAL: + if (A_trans_=='N') + stream << A->process("#pointer += (gidx*" + to_string(p.mL/p.simd_width) + " + idxT)" + VIENNACL_MUL_STRIDE1 + " + idyT*#ld;") << std::endl; + else + stream << A->process("#pointer += idxT" + VIENNACL_MUL_STRIDE1 + " + gidx*" + to_string(p.mL/p.simd_width) + "*#ld + idyT*#ld;") << std::endl; + break; + + case FETCH_FROM_GLOBAL_CONTIGUOUS: + if (A_trans_=='N') + stream << A->process("#pointer += (gidx*" + to_string(p.mL/p.simd_width) + "+ idx*" + to_string(p.mS/p.simd_width) + ")" + VIENNACL_MUL_STRIDE1 + ";") << std::endl; + else + stream << A->process("#pointer += (gidx*" + to_string(p.mL/p.simd_width) + "+ idx*" + to_string(p.mS/p.simd_width) + ")*#ld;") << std::endl; + break; + + case FETCH_FROM_GLOBAL_STRIDED: + if (A_trans_=='N') + stream << A->process("#pointer += (gidx*" + to_string(p.mL/p.simd_width) + "+ idx" + ")" + VIENNACL_MUL_STRIDE1 + ";") << std::endl; + else + stream << A->process("#pointer += (gidx*" + to_string(p.mL/p.simd_width) + "+ idx)*#ld;") << std::endl; + break; + + //default: break; + } + + switch (p.B_fetching_policy) + { + case FETCH_FROM_LOCAL: + if (B_trans_=='T') + stream << B->process("#pointer += (gidy*" + to_string(p.nL/p.simd_width) + " + idxT" + ")" + VIENNACL_MUL_STRIDE1 + " + idyT*#ld;") << std::endl; + else + stream << B->process("#pointer += idxT" + VIENNACL_MUL_STRIDE1 + " + gidy*" + to_string(p.nL/p.simd_width) + "*#ld + idyT*#ld;") << std::endl; + break; + + case FETCH_FROM_GLOBAL_CONTIGUOUS: + if (B_trans_=='T') + stream << B->process("#pointer += (gidy*" + to_string(p.nL/p.simd_width) + "+ idy*" + to_string(p.nS/p.simd_width) + ")" + VIENNACL_MUL_STRIDE1 + ";") << std::endl; + else + stream << B->process("#pointer += (gidy*" + to_string(p.nL/p.simd_width) + "+ idy*" + to_string(p.nS/p.simd_width) + ")*#ld;") << std::endl; + break; + + case FETCH_FROM_GLOBAL_STRIDED: + if (B_trans_=='T') + 
stream << B->process("#pointer += (gidy*" + to_string(p.nL/p.simd_width) + "+ idy" + ")" + VIENNACL_MUL_STRIDE1 + ";") << std::endl; + else + stream << B->process("#pointer += (gidy*" + to_string(p.nL/p.simd_width) + "+ idy)*#ld;") << std::endl; + break; + + //default: break; + } + + stream << std::endl; + stream << "size_t K_size_t = K;" << std::endl; + stream << "for(size_t block_k=0; block_k < K_size_t; block_k+=" << p.kL << "){" << std::endl; + stream.inc_tab(); + + if (p.A_fetching_policy==FETCH_FROM_LOCAL) + { + if (A_trans_=='N') + stream << A->process("__local #scalartype* plA = lA + idyT*" + to_string(p.mL + 1) + " + " + to_string(p.simd_width) + "*idxT;") << std::endl; + else + stream << A->process("__local #scalartype* plA = lA + idxT*" + to_string(p.mL + 1) + " + idyT;") << std::endl; + } + + + if (p.B_fetching_policy==FETCH_FROM_LOCAL) + { + if (B_trans_=='T') + stream << B->process("__local #scalartype* plB = lB + idyT*" + to_string(p.nL+1) + " + " + to_string(p.simd_width) + "*idxT;") << std::endl; + else + stream << B->process("__local #scalartype* plB = lB + idxT*" + to_string(p.nL+1) + "+ idyT;") <<std::endl; + } + + + if (p.A_fetching_policy==FETCH_FROM_LOCAL || p.B_fetching_policy==FETCH_FROM_LOCAL) + stream << "barrier(CLK_LOCAL_MEM_FENCE);" << std::endl; + + ///Fetch LHS to Local Memory + if (p.A_fetching_policy==FETCH_FROM_LOCAL && A_trans_=='N') + for (unsigned int k = 0; k < p.kL; k += p.local_fetch_1) + for (unsigned int m = 0; m < p.mL; m += p.local_fetch_0*p.simd_width) + { + string in_bounds = "in_bounds_m_local[" + to_string(m/(p.local_fetch_0*p.simd_width)) + "]"; + string to_load = "#pointer[" + to_string(k) + "*#ld + " + to_string(m/p.simd_width) + VIENNACL_MUL_STRIDE1 + "]"; + stream << A->process(VIENNACL_VSTORE(VIENNACL_HANDLE_BOUNDS(in_bounds, to_load), "0", "plA + " + to_string(k*(p.mL+1)+m))) << ";" << std::endl; + } + else if (p.A_fetching_policy==FETCH_FROM_LOCAL && A_trans_=='T') + for (unsigned int k = 0; k < p.mL; k += 
p.local_fetch_1) + for (unsigned int m = 0; m < p.kL; m += p.local_fetch_0*p.simd_width) + { + string in_bounds = "in_bounds_m_local[" + to_string(k/p.local_fetch_1) + "]"; + string to_load = "#pointer[" + to_string(k) + "*#ld + " + to_string(m/p.simd_width) + VIENNACL_MUL_STRIDE1 + "]"; + stream << A->process(VIENNACL_VSTORE(VIENNACL_HANDLE_BOUNDS(in_bounds, to_load), "0", "plA + " + to_string(m*(p.mL+1)+k))) << ";" << std::endl; + } + + if (p.B_fetching_policy==FETCH_FROM_LOCAL && B_trans_=='T') + for (unsigned int k = 0; k < p.kL; k += p.local_fetch_1) + for (unsigned int n = 0; n < p.nL; n += p.local_fetch_0*p.simd_width) + { + string in_bounds = "in_bounds_n_local[" + to_string(n/(p.local_fetch_0*p.simd_width)) + "]"; + string to_load = "#pointer[" + to_string(k) + "*#ld + " + to_string(n/p.simd_width) + VIENNACL_MUL_STRIDE1 + "]"; + stream << B->process(VIENNACL_VSTORE(VIENNACL_HANDLE_BOUNDS(in_bounds, to_load), "0", "plB + " + to_string(k*(p.nL+1)+n))) << ";" << std::endl; + } + else if (p.B_fetching_policy==FETCH_FROM_LOCAL && B_trans_=='N') + for (unsigned int k = 0; k < p.nL; k += p.local_fetch_1) + for (unsigned int n = 0; n < p.kL; n += p.local_fetch_0*p.simd_width) + { + string in_bounds = "in_bounds_n_local[" + to_string(k/p.local_fetch_1) + "]"; + string to_load = "#pointer[" + to_string(k) + "*#ld + " + to_string(n/p.simd_width) + VIENNACL_MUL_STRIDE1 + "]"; + stream << B->process(VIENNACL_VSTORE(VIENNACL_HANDLE_BOUNDS(in_bounds, to_load), "0", "plB + " + to_string(n*(p.nL+1)+k))) << ";" << std::endl; + } + + if (p.A_fetching_policy==FETCH_FROM_LOCAL || p.B_fetching_policy == FETCH_FROM_LOCAL) + { + stream << "barrier(CLK_LOCAL_MEM_FENCE);" << std::endl; + stream << "size_t offA = " << p.simd_width << "*idx;" << std::endl; + stream << "size_t offB = " << p.simd_width << "*idy;" << std::endl; + } + + if (fallback) + stream << "for(size_t k = 0; k < " << p.kL << " && (block_k + k < K_size_t); k+=" << p.kS << "){" << std::endl; + else + stream << 
"for(size_t k = 0; k < " << p.kL << "; k+=" << p.kS << "){" << std::endl; + stream.inc_tab(); + + ///Fetch LHS to registers + stream << "#pragma unroll " << p.kS << std::endl; + stream << "for(size_t kk = 0; kk < " << p.kS << "; kk++)" << std::endl; + stream << "#pragma unroll " << p.mS/p.simd_width << std::endl; + stream << "for(size_t mm = 0; mm < " << p.mS/p.simd_width << "; mm++)" << std::endl; + stream << "{" << std::endl; + stream.inc_tab(); + switch (p.A_fetching_policy) + { + case FETCH_FROM_LOCAL: + for (unsigned int ss = 0; ss < p.simd_width; ++ss) + stream << "rA[kk][mm*" << p.simd_width << "+" << ss << "] = lA[offA + mm*" << p.local_size_0*p.simd_width << "+" << ss << "+ kk*" << (p.mL+1) << "];" << std::endl; + break; + + case FETCH_FROM_GLOBAL_CONTIGUOUS: + { + if (A_trans_=='N') + stream << "rA[kk][mm] = " << A->process(VIENNACL_HANDLE_BOUNDS("in_bounds_m[mm]", "#pointer[kk*#ld + mm" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl; + else + stream << "rA[kk][mm] = " << A->process(VIENNACL_HANDLE_BOUNDS("in_bounds_m[mm]", "#pointer[mm*#ld + kk" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl; + break; + } + + case FETCH_FROM_GLOBAL_STRIDED: + { + if (A_trans_=='N') + stream << "rA[kk][mm] = " << A->process(VIENNACL_HANDLE_BOUNDS("in_bounds_m[mm]", "#pointer[kk*#ld + mm*" + to_string(p.local_size_0) + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl; + else + stream << "rA[kk][mm] = " << A->process(VIENNACL_HANDLE_BOUNDS("in_bounds_m[mm]", "#pointer[mm*#ld*" + to_string(p.local_size_0) + " + kk" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl; + break; + } + + //default: break; + } + stream.dec_tab(); + stream << "}" << std::endl; + + stream << "#pragma unroll " << p.kS << std::endl; + stream << "for(size_t kk = 0; kk < " << p.kS << "; kk++)" << std::endl; + stream << "#pragma unroll " << p.nS/p.simd_width << std::endl; + stream << "for(size_t nn = 0; nn < " << p.nS/p.simd_width << "; nn++)" << std::endl; + stream << "{" << std::endl; + 
stream.inc_tab(); + switch (p.B_fetching_policy) + { + case FETCH_FROM_LOCAL: + for (unsigned int ss = 0; ss < p.simd_width; ++ss) + stream << "rB[kk][nn*" << p.simd_width << "+" << ss << "] = lB[offB + nn*" << p.local_size_1*p.simd_width << "+" << ss << "+ kk*" << (p.nL+1) << "];" << std::endl; + break; + + case FETCH_FROM_GLOBAL_CONTIGUOUS: + { + if (B_trans_=='T') + stream << "rB[kk][nn] = " << B->process(VIENNACL_HANDLE_BOUNDS("in_bounds_n[nn]", "#pointer[kk*#ld + nn" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl; + else + stream << "rB[kk][nn] = " << B->process(VIENNACL_HANDLE_BOUNDS("in_bounds_n[nn]", "#pointer[nn*#ld + kk" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl; + break; + } + + case FETCH_FROM_GLOBAL_STRIDED: + { + if (B_trans_=='T') + stream << "rB[kk][nn] = " << B->process(VIENNACL_HANDLE_BOUNDS("in_bounds_n[nn]", "#pointer[kk*#ld + nn*" + to_string(p.local_size_1) + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl; + else + stream << "rB[kk][nn] = " << B->process(VIENNACL_HANDLE_BOUNDS("in_bounds_n[nn]", "#pointer[nn*#ld*" + to_string(p.local_size_1) + " + kk" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl; + break; + } + + //default: break; + } + stream.dec_tab(); + stream << "}" << std::endl; + + + ///Increment pointers + switch (p.A_fetching_policy) + { + case FETCH_FROM_LOCAL: + stream << "offA += " << p.kS*(p.mL+1) << ";" << std::endl; + break; + + default: + if (A_trans_=='N') + stream << A->process("#pointer += " + to_string(p.kS) + "*#ld;") << std::endl; + else + stream << A->process("#pointer += " + to_string(p.kS) + "" + VIENNACL_MUL_STRIDE1 + ";") << std::endl; + break; + } + + + switch (p.B_fetching_policy) + { + case FETCH_FROM_LOCAL: + stream << "offB += " << p.kS*(p.nL+1) << ";" << std::endl; + break; + + default: + if (B_trans_=='T') + stream << B->process("#pointer += " + to_string(p.kS) + "*#ld;") << std::endl; + else + stream << B->process("#pointer += " + to_string(p.kS) + "" + VIENNACL_MUL_STRIDE1 + ";") << 
std::endl; + break; + } + + + stream << "#pragma unroll " << p.kS << std::endl; + stream << "for(size_t kk = 0; kk <" << p.kS << "; ++kk)" << std::endl; + stream << "{" << std::endl; + stream.inc_tab(); + for (unsigned int nn=0; nn < p.nS; ++nn) + for (unsigned int mm=0; mm < p.mS; ++mm) + { + string res_str, lhs_str, rhs_str; + res_str = "rC[" + tools::to_string(mm) + "][" + tools::to_string(nn) + "]"; + if (p.A_fetching_policy==FETCH_FROM_LOCAL || p.simd_width==1) + lhs_str = "rA[kk][" + tools::to_string(mm) + "]"; + else + lhs_str = "rA[kk][" + tools::to_string(mm/p.simd_width) + "].s" + tools::to_string(mm%p.simd_width); + if (p.B_fetching_policy==FETCH_FROM_LOCAL || p.simd_width==1) + rhs_str = "rB[kk]["+tools::to_string(nn)+"]"; + else + rhs_str = "rB[kk]["+tools::to_string(nn/p.simd_width)+"].s"+tools::to_string(nn%p.simd_width); + stream << res_str << "=" << "fma(" << lhs_str << "," << rhs_str << "," << res_str << ");" << std::endl; + } + stream.dec_tab(); + stream << "}" << std::endl; + + + + + stream.dec_tab(); + stream << "}" << std::endl; + + //Increment global pointer if local memory is used + //Else, it's incremented directly when fetching + if (p.A_fetching_policy==FETCH_FROM_LOCAL) + { + if (A_trans_=='N') + stream << A->process("#pointer += " + to_string(p.kL) + "*#ld;") << std::endl; + else + stream << A->process("#pointer += " + to_string(p.kL) + "" + VIENNACL_MUL_STRIDE1 + ";") << std::endl; + } + + if (p.B_fetching_policy==FETCH_FROM_LOCAL) + { + if (B_trans_=='T') + stream << B->process("#pointer += " + to_string(p.kL) + "*#ld;") << std::endl; + else + stream << B->process("#pointer += " + to_string(p.kL) + "" + VIENNACL_MUL_STRIDE1 + ";") << std::endl; + } + + stream.dec_tab(); + stream << "}" << std::endl; + + + if (C->row_major()) + { + unsigned int ministartstride0 = p.A_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?p.mS:p.simd_width; + unsigned int ministartstride1 = p.B_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?p.nS:p.simd_width; + + 
stream << C->process("#pointer += gidx*" + to_string(p.mL) + "*#ld;") << std::endl; + stream << C->process("#pointer += idx*" + to_string(ministartstride0) + "*#ld;") << std::endl; + stream << C->process("#pointer += gidy*" + to_string(p.nL) + "*#stride2;") << std::endl; + stream << C->process("#pointer += idy*" + to_string(ministartstride1) + "*#stride2;") << std::endl; + + for (unsigned int n=0; n < p.nS; ++n) + { + for (unsigned int m=0; m < p.mS; ++m) + { + unsigned int ministride1 = p.A_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?1:p.local_size_0; + string Cj = to_string((m/p.simd_width)*(ministride1*p.simd_width) + m%p.simd_width); + if (fallback) + { + stream << "if (in_bounds_m[" + to_string(m) + "] && in_bounds_n[" + to_string(n) + "])" << std::endl; + stream.inc_tab(); + } + stream << C->process("#pointer[" + Cj + "*#ld] = rC[" + to_string(m) + "][" + to_string(n) + "]*" + alpha->name() + "+ #pointer[" + Cj + "*#ld]*" + beta->name() + ";") << std::endl; + if (fallback) + stream.dec_tab(); + } + if ((n+1)%p.simd_width>0 || p.B_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS) + stream << C->process("#pointer += #stride2;") << std::endl; + else + stream << C->process("#pointer += " + to_string((p.local_size_1*p.simd_width) - (p.simd_width-1)) + "*#stride2;") << std::endl; + } + + } + else + { + unsigned int ministartstride0 = p.A_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?p.mS:p.simd_width; + unsigned int ministartstride1 = p.B_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?p.nS:p.simd_width; + + stream << C->process("#pointer += gidx*" + to_string(p.mL) + "*#stride1;") << std::endl; + stream << C->process("#pointer += idx*" + to_string(ministartstride0) + "*#stride1;") << std::endl; + stream << C->process("#pointer += gidy*" + to_string(p.nL) + "*#ld;") << std::endl; + stream << C->process("#pointer += idy*" + to_string(ministartstride1) + "*#ld;") << std::endl; + + for (unsigned int m=0; m < p.mS; ++m) + { + for (unsigned int n=0; n < p.nS; ++n) + { + 
unsigned int ministride1 = p.B_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?1:p.local_size_1; + string Cj = to_string((n/p.simd_width)*(ministride1*p.simd_width) + n%p.simd_width); + if (fallback) + { + stream << "if (in_bounds_m[" + to_string(m) + "] && in_bounds_n[" + to_string(n) + "])" << std::endl; + stream.inc_tab(); + } + stream << C->process("#pointer[" + Cj + "*#ld] = rC[" + to_string(m) + "][" + to_string(n) + "]*" + alpha->name() + " + #pointer[" + Cj + "*#ld]*" + beta->name() + ";") << std::endl; + if (fallback) + stream.dec_tab(); + } + + if ((m+1)%p.simd_width>0 || p.A_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS) + stream << C->process("#pointer += #stride1;") << std::endl; + else + stream << C->process("#pointer += " + to_string((p.local_size_0*p.simd_width) - (p.simd_width-1)) + "*#stride1;") << std::endl; + } + } + + stream.dec_tab(); + stream << "}" << std::endl; + + return stream.str(); + +#undef VIENNACL_MUL_STRIDE1 +#undef VIENNACL_HANDLE_BOUNDS +#undef VIENNACL_VSTORE + } + + std::vector<std::string> generate_impl(std::string const & kernel_prefix, statements_container const & statements, std::vector<mapping_type> const & mappings) const + { + std::vector<std::string> res; + res.push_back(generate_impl(kernel_prefix, statements, mappings, false)); + res.push_back(generate_impl(kernel_prefix, statements, mappings, true)); + return res; + } + + template<class NumericT> + void enqueue_block(scheduler::statement & statement, + scheduler::lhs_rhs_element& eA, scheduler::lhs_rhs_element& eB, scheduler::lhs_rhs_element& eC, scheduler::lhs_rhs_element& ebeta, + matrix_base<NumericT> const & A, matrix_base<NumericT> const & B, matrix_base<NumericT> const & C, NumericT beta, + std::vector<lazy_program_compiler> & programs, std::string const & kernel_prefix, vcl_size_t id) + { + if (A.size1()==0 || A.size2()==0 || B.size1()==0 || B.size2()==0 || C.size1()==0 || C.size2()==0) + return; + + viennacl::ocl::kernel& kernel = 
programs[id].program().get_kernel(kernel_prefix); + + kernel.local_work_size(0, p_.local_size_0); + kernel.local_work_size(1, p_.local_size_1); + + scheduler::statement::assign_element(eA, A); + scheduler::statement::assign_element(eB, B); + scheduler::statement::assign_element(eC, C); + scheduler::statement::assign_element(ebeta, beta); + + if (id==1) + { + kernel.global_work_size(0, tools::align_to_multiple(tools::align_to_multiple((unsigned int)C.size1(),p_.mS)/p_.mS, p_.local_size_0)); + kernel.global_work_size(1, tools::align_to_multiple(tools::align_to_multiple((unsigned int)C.size2(),p_.nS)/p_.nS, p_.local_size_1)); + } + else + { + kernel.global_work_size(0, C.size1()/p_.mS); + kernel.global_work_size(1, C.size2()/p_.nS); + } + unsigned int current_arg = 0; + kernel.arg(current_arg++, cl_uint(C.size1())); + kernel.arg(current_arg++, cl_uint(C.size2())); + if (A.row_major()) + kernel.arg(current_arg++, cl_uint(A_trans_=='T'?A.size2():A.size1())); + else + kernel.arg(current_arg++, cl_uint(A_trans_=='N'?A.size2():A.size1())); + set_arguments(statement, kernel, current_arg); + viennacl::ocl::enqueue(kernel); + + } + + template<class NumericT> + matrix_slice< viennacl::matrix_base<NumericT> > create_slice(viennacl::matrix_base<NumericT>* scheduler::lhs_rhs_element::*ptr, scheduler::lhs_rhs_element const & element, + vcl_size_t s0_0, vcl_size_t s0_1, vcl_size_t s1_0, vcl_size_t s1_1, bool swap) + { + matrix_base<NumericT> & M = *(element.*ptr); + slice s0(s0_0, 1, s0_1 - s0_0); + slice s1(s1_0, 1, s1_1 - s1_0); + if (swap) + std::swap(s0, s1); + return matrix_slice<viennacl::matrix_base<NumericT> >(M, s0, s1); + } + + template<class NumericT> + void enqueue_impl(viennacl::matrix_base<NumericT>* scheduler::lhs_rhs_element::*ptr_matrix, + scheduler::statement & statement, scheduler::lhs_rhs_element & A, scheduler::lhs_rhs_element & B, scheduler::lhs_rhs_element & C, scheduler::lhs_rhs_element & beta, + NumericT beta_value, std::vector<lazy_program_compiler> & 
programs, std::string const & kernel_prefix) + { + using namespace device_specific::utils; + vcl_size_t ldstrideA = call_on_matrix(A, leading_stride()); + vcl_size_t ldstrideB = call_on_matrix(B, leading_stride()); + vcl_size_t ldstrideC = call_on_matrix(C, leading_stride()); + vcl_size_t ldstartA = call_on_matrix(A, leading_start()); + vcl_size_t ldstartB = call_on_matrix(B, leading_start()); + bool swap_A = ((A_trans_=='T') ^ utils::call_on_matrix(A, row_major_fun())); + bool swap_B = ((B_trans_=='T') ^ utils::call_on_matrix(B, row_major_fun())); + + vcl_size_t M = call_on_matrix(C, size1_fun()); + vcl_size_t N = call_on_matrix(C, size2_fun()); + vcl_size_t K; + if (utils::call_on_matrix(A, row_major_fun())) + K = A_trans_=='T'?call_on_matrix(A, size2_fun()):call_on_matrix(A, size1_fun()); + else + K = A_trans_=='N'?call_on_matrix(A, size2_fun()):call_on_matrix(A, size1_fun()); + + if (M < p_.mL || N < p_.nL || K < p_.kL || ldstrideA> 1 || ldstrideB > 1 || ldstrideC > 1 || + (p_.simd_width>1 && (ldstartA % p_.simd_width > 0 || ldstartB % p_.simd_width > 0))) + { + enqueue_block(statement, A, B, C, beta, create_slice(ptr_matrix, A, 0, M, 0, K, swap_A), + create_slice(ptr_matrix, B, 0, K, 0, N, swap_B), + create_slice(ptr_matrix, C, 0, M, 0, N, false), beta_value, programs, kernel_prefix, 1); + return; + } + + + scheduler::lhs_rhs_element Acopy = A; + scheduler::lhs_rhs_element Bcopy = B; + scheduler::lhs_rhs_element Ccopy = C; + + vcl_size_t lM = M / p_.mL * p_.mL; + vcl_size_t lN = N / p_.nL * p_.nL; + vcl_size_t lK = K / p_.kL * p_.kL; + + + enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, 0, lM, 0, lK, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, 0, lK, 0, lN, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, 0, lM, 0, lN, false), beta_value, programs, kernel_prefix, 0); + enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, 0, lM, lK, K, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, lK, 
K, 0, lN, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, 0, lM, 0, lN, false), (NumericT)1, programs, kernel_prefix, 1); + + enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, 0, lM, 0, lK, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, 0, lK, lN, N, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, 0, lM, lN, N, false), beta_value, programs, kernel_prefix, 1); + enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, 0, lM, lK, K, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, lK, K, lN, N, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, 0, lM, lN, N, false), (NumericT)1, programs, kernel_prefix, 1); + + enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, lM, M, 0, lK, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, 0, lK, 0, lN, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, lM, M, 0, lN, false), beta_value, programs, kernel_prefix, 1); + enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, lM, M, lK, K, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, lK, K, 0, lN, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, lM, M, 0, lN, false), (NumericT)1, programs, kernel_prefix, 1); + + enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, lM, M, 0, lK, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, 0, lK, lN, N, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, lM, M, lN, N, false), beta_value, programs, kernel_prefix, 1); + enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, lM, M, lK, K, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, lK, K, lN, N, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, lM, M, lN, N, false), (NumericT)1, programs, kernel_prefix, 1); + } + +public: + matrix_product_template(matrix_product_template::parameters_type const & parameters, char A_trans, char B_trans) : template_base_impl<matrix_product_template, 
matrix_product_parameters>(parameters, BIND_ALL_UNIQUE), A_trans_(A_trans), B_trans_(B_trans){ } + + virtual void enqueue(std::string const & kernel_prefix, std::vector<lazy_program_compiler> & programs, statements_container const & statements) + { + using namespace device_specific::utils; + using namespace tree_parsing; + + scheduler::statement const & st = statements.data().front(); + bool A_trans, B_trans; + vcl_size_t C_idx=0, A_idx=0, B_idx=0, alpha_idx=0, beta_idx = 0; + leaf_t C_leaf=LHS_NODE_TYPE, A_leaf=LHS_NODE_TYPE, B_leaf=LHS_NODE_TYPE, alpha_leaf=LHS_NODE_TYPE, beta_leaf=LHS_NODE_TYPE; + parse(st, C_idx, C_leaf, alpha_idx, alpha_leaf, A_idx, A_leaf, A_trans, B_idx, B_leaf, B_trans, beta_idx, beta_leaf); + + scheduler::statement stcopy = st; + scheduler::lhs_rhs_element& A = utils::lhs_rhs_element(stcopy, A_idx, A_leaf); + scheduler::lhs_rhs_element& B = utils::lhs_rhs_element(stcopy, B_idx, B_leaf); + scheduler::lhs_rhs_element& C = utils::lhs_rhs_element(stcopy, C_idx, C_leaf); + scheduler::lhs_rhs_element& beta = utils::lhs_rhs_element(stcopy, beta_idx, beta_leaf); + + + + + + + if (C.numeric_type==scheduler::FLOAT_TYPE) + enqueue_impl<float>(&scheduler::lhs_rhs_element::matrix_float, stcopy, A, B, C, beta, beta.host_float, programs, kernel_prefix); + else if (C.numeric_type==scheduler::DOUBLE_TYPE) + enqueue_impl<double>(&scheduler::lhs_rhs_element::matrix_double, stcopy, A, B, C, beta, beta.host_double, programs, kernel_prefix); + else + throw generator_not_supported_exception("GEMM only supported for float/double"); + + } + +private: + const char A_trans_; + const char B_trans_; +}; + +} + +} + +#endif http://git-wip-us.apache.org/repos/asf/mahout/blob/f7c1f802/native-viennaCL/src/main/cpp/viennacl/device_specific/templates/template_base.hpp ---------------------------------------------------------------------- diff --git a/native-viennaCL/src/main/cpp/viennacl/device_specific/templates/template_base.hpp 
b/native-viennaCL/src/main/cpp/viennacl/device_specific/templates/template_base.hpp new file mode 100644 index 0000000..40e3168 --- /dev/null +++ b/native-viennaCL/src/main/cpp/viennacl/device_specific/templates/template_base.hpp @@ -0,0 +1,596 @@ +#ifndef VIENNACL_DEVICE_SPECIFIC_TEMPLATES_TEMPLATE_BASE_ +#define VIENNACL_DEVICE_SPECIFIC_TEMPLATES_TEMPLATE_BASE_ + +/* ========================================================================= + Copyright (c) 2010-2016, Institute for Microelectronics, + Institute for Analysis and Scientific Computing, + TU Wien. + Portions of this software are copyright by UChicago Argonne, LLC. + + ----------------- + ViennaCL - The Vienna Computing Library + ----------------- + + Project Head: Karl Rupp [email protected] + + (A list of authors and contributors can be found in the manual) + + License: MIT (X11), see file LICENSE in the base directory +============================================================================= */ + + +/** @file viennacl/device_specific/templates/template_base.hpp + * + * Base classes for the profiles +*/ + +#include <list> +#include <set> + +#include "viennacl/ocl/kernel.hpp" +#include "viennacl/ocl/device.hpp" +#include "viennacl/ocl/device_utils.hpp" + +#include "viennacl/scheduler/forwards.h" +#include "viennacl/scheduler/io.hpp" + +#include "viennacl/device_specific/lazy_program_compiler.hpp" +#include "viennacl/device_specific/mapped_objects.hpp" +#include "viennacl/device_specific/tree_parsing.hpp" +#include "viennacl/device_specific/utils.hpp" + +namespace viennacl +{ +namespace device_specific +{ + +enum fetching_policy_type +{ + FETCH_FROM_LOCAL, + FETCH_FROM_GLOBAL_STRIDED, + FETCH_FROM_GLOBAL_CONTIGUOUS +}; + +class template_base +{ +public: + struct parameters_type + { + parameters_type(unsigned int _simd_width, unsigned int _local_size_1, unsigned int _local_size_2, unsigned int _num_kernels) : simd_width(_simd_width), local_size_0(_local_size_1), local_size_1(_local_size_2), 
num_kernels(_num_kernels){ } + + unsigned int simd_width; + unsigned int local_size_0; + unsigned int local_size_1; + unsigned int num_kernels; + }; + +private: + /** @brief Functor to map the statements to the types defined in mapped_objects.hpp */ + class map_functor : public tree_parsing::traversal_functor + { + + scheduler::statement_node_numeric_type numeric_type(scheduler::statement const * statement, vcl_size_t root_idx) const + { + scheduler::statement_node const * root_node = &statement->array()[root_idx]; + while (root_node->lhs.numeric_type==scheduler::INVALID_NUMERIC_TYPE) + root_node = &statement->array()[root_node->lhs.node_index]; + return root_node->lhs.numeric_type; + } + + public: + typedef tools::shared_ptr<mapped_object> result_type; + + map_functor(symbolic_binder & binder, mapping_type & mapping) : binder_(binder), mapping_(mapping){ } + + /** @brief Binary leaf */ + template<class T> + result_type binary_leaf(scheduler::statement const * statement, vcl_size_t root_idx, mapping_type const * mapping) const + { + return result_type(new T(utils::numeric_type_to_string(numeric_type(statement,root_idx)), binder_.get(NULL), mapped_object::node_info(mapping, statement, root_idx))); + } + + template<class NumericT> + result_type operator()(NumericT const & /*scalar*/) const + { + return result_type(new mapped_host_scalar(utils::type_to_string<NumericT>::value(), binder_.get(NULL))); + } + + /** @brief Scalar mapping */ + template<class NumericT> + result_type operator()(scalar<NumericT> const & scal) const + { + return result_type(new mapped_scalar(utils::type_to_string<NumericT>::value(), binder_.get(&viennacl::traits::handle(scal)))); + } + + /** @brief Vector mapping */ + template<class NumericT> + result_type operator()(vector_base<NumericT> const & vec) const + { + return result_type(new mapped_vector(utils::type_to_string<NumericT>::value(), binder_.get(&viennacl::traits::handle(vec)))); + } + + /** @brief Implicit vector mapping */ + 
template<class NumericT> + result_type operator()(implicit_vector_base<NumericT> const & /*vec*/) const + { + return result_type(new mapped_implicit_vector(utils::type_to_string<NumericT>::value(), binder_.get(NULL))); + } + + /** @brief Matrix mapping */ + template<class NumericT> + result_type operator()(matrix_base<NumericT> const & mat) const + { + return result_type(new mapped_matrix(utils::type_to_string<NumericT>::value(), binder_.get(&viennacl::traits::handle(mat)), + viennacl::traits::row_major(mat))); + } + + /** @brief Implicit matrix mapping */ + template<class NumericT> + result_type operator()(implicit_matrix_base<NumericT> const & /*mat*/) const + { + return result_type(new mapped_implicit_matrix(utils::type_to_string<NumericT>::value(), binder_.get(NULL))); + } + + /** @brief Traversal functor */ + void operator()(scheduler::statement const & statement, vcl_size_t root_idx, leaf_t leaf_t) const { + mapping_type::key_type key(root_idx, leaf_t); + scheduler::statement_node const & root_node = statement.array()[root_idx]; + + if (leaf_t == LHS_NODE_TYPE && root_node.lhs.type_family != scheduler::COMPOSITE_OPERATION_FAMILY) + mapping_.insert(mapping_type::value_type(key, utils::call_on_element(root_node.lhs, *this))); + else if (leaf_t == RHS_NODE_TYPE && root_node.rhs.type_family != scheduler::COMPOSITE_OPERATION_FAMILY) + mapping_.insert(mapping_type::value_type(key, utils::call_on_element(root_node.rhs, *this))); + else if ( leaf_t== PARENT_NODE_TYPE) + { + if (root_node.op.type==scheduler::OPERATION_BINARY_VECTOR_DIAG_TYPE) + mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_vector_diag>(&statement, root_idx, &mapping_))); + else if (root_node.op.type==scheduler::OPERATION_BINARY_MATRIX_DIAG_TYPE) + mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_diag>(&statement, root_idx, &mapping_))); + else if (root_node.op.type==scheduler::OPERATION_BINARY_MATRIX_ROW_TYPE) + mapping_.insert(mapping_type::value_type(key, 
binary_leaf<mapped_matrix_row>(&statement, root_idx, &mapping_))); + else if (root_node.op.type==scheduler::OPERATION_BINARY_MATRIX_COLUMN_TYPE) + mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_column>(&statement, root_idx, &mapping_))); + else if (is_scalar_reduction(root_node)) + mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_scalar_reduction>(&statement, root_idx, &mapping_))); + else if (is_vector_reduction(root_node)) + mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_row_wise_reduction>(&statement, root_idx, &mapping_))); + else if (root_node.op.type == scheduler::OPERATION_BINARY_MAT_MAT_PROD_TYPE) + mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_matrix_product>(&statement, root_idx, &mapping_))); + else if (root_node.op.type == scheduler::OPERATION_UNARY_TRANS_TYPE) + mapping_.insert(mapping_type::value_type(key, binary_leaf<mapped_trans>(&statement, root_idx, &mapping_))); + } + } + + private: + symbolic_binder & binder_; + mapping_type & mapping_; + }; + + /** @brief functor for generating the prototype of a statement */ + class prototype_generation_traversal : public tree_parsing::traversal_functor + { + private: + std::set<std::string> & already_generated_; + std::string & str_; + mapping_type const & mapping_; + std::map<std::string, unsigned int> const & widths_; + public: + prototype_generation_traversal(std::set<std::string> & already_generated, std::string & str, mapping_type const & mapping, std::map<std::string, unsigned int> const & widths) : + already_generated_(already_generated), str_(str), mapping_(mapping), widths_(widths){ } + + void operator()(scheduler::statement const & statement, vcl_size_t root_idx, leaf_t leaf) const + { + scheduler::statement_node const & root_node = statement.array()[root_idx]; + if ( (leaf==LHS_NODE_TYPE && root_node.lhs.type_family!=scheduler::COMPOSITE_OPERATION_FAMILY) + ||(leaf==RHS_NODE_TYPE && 
root_node.rhs.type_family!=scheduler::COMPOSITE_OPERATION_FAMILY) )
+      {
+        mapped_object * obj = at(mapping_, std::make_pair(root_idx,leaf)).get();
+        if(widths_.find(obj->name())!=widths_.end())
+          obj->append_kernel_arguments(already_generated_, str_, at(widths_, obj->name()));
+        else
+          obj->append_kernel_arguments(already_generated_, str_, 1);
+      }
+    }
+  };
+
+
+
+  /** @brief Functor for setting the arguments of a generated kernel
+  *
+  *  set_arguments() applies it to the leaves of every statement. Each argument is set at index
+  *  current_arg, which is post-incremented once per argument. Buffer-backed leaves (scalars,
+  *  vectors, matrices) are only set when binder_.bind() accepts their handle; host scalars and
+  *  implicit vectors/matrices are always set.
+  *  NOTE(review): the argument order presumably has to match the prototype produced through
+  *  generate_prototype() / mapped_object::append_kernel_arguments - keep both in sync when editing.
+  */
+  class set_arguments_functor : public tree_parsing::traversal_functor
+  {
+  public:
+    typedef void result_type;
+
+    set_arguments_functor(symbolic_binder & binder, unsigned int & current_arg, viennacl::ocl::kernel & kernel) : binder_(binder), current_arg_(current_arg), kernel_(kernel){ }
+
+    /** @brief Host scalar mapping: passed by value, converted to the matching OpenCL type */
+    template<class NumericT>
+    result_type operator()(NumericT const & scal) const {
+      typedef typename viennacl::result_of::cl_type<NumericT>::type cl_scalartype;
+      kernel_.arg(current_arg_++, cl_scalartype(scal));
+    }
+
+    /** @brief Scalar mapping: sets the device buffer holding the scalar */
+    template<class NumericT>
+    result_type operator()(scalar<NumericT> const & scal) const {
+      if (binder_.bind(&viennacl::traits::handle(scal)))
+        kernel_.arg(current_arg_++, scal.handle().opencl_handle());
+    }
+
+    /** @brief Vector mapping: sets the buffer, then start and stride as cl_uint */
+    template<class NumericT>
+    result_type operator()(vector_base<NumericT> const & vec) const {
+      if (binder_.bind(&viennacl::traits::handle(vec)))
+      {
+        kernel_.arg(current_arg_++, vec.handle().opencl_handle());
+        kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start(vec)));
+        kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride(vec)));
+      }
+    }
+
+    /** @brief Implicit vector mapping: sets the value, followed by the index when the vector has one */
+    template<class NumericT>
+    result_type operator()(implicit_vector_base<NumericT> const & vec) const
+    {
+      typedef typename viennacl::result_of::cl_type<NumericT>::type cl_scalartype;
+      kernel_.arg(current_arg_++, cl_scalartype(vec.value()));
+      if (vec.has_index())
+        kernel_.arg(current_arg_++, cl_uint(vec.index()));
+    }
+
+    /** @brief Matrix mapping: sets the buffer and leading dimension, then both starts and both strides as cl_uint */
+    template<class NumericT>
+    result_type operator()(matrix_base<NumericT> const & mat) const
+    {
+      if (binder_.bind(&viennacl::traits::handle(mat)))
+      {
+        kernel_.arg(current_arg_++, mat.handle().opencl_handle());
+        kernel_.arg(current_arg_++, cl_uint(viennacl::traits::ld(mat)));
+        //Row-major matrices pass the second dimension's start/stride first, column-major ones the first dimension's
+        if (mat.row_major())
+        {
+          kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start2(mat)));
+          kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start1(mat)));
+          kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride2(mat)));
+          kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride1(mat)));
+        }
+        else
+        {
+          kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start1(mat)));
+          kernel_.arg(current_arg_++, cl_uint(viennacl::traits::start2(mat)));
+          kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride1(mat)));
+          kernel_.arg(current_arg_++, cl_uint(viennacl::traits::stride2(mat)));
+        }
+      }
+    }
+
+    /** @brief Implicit matrix mapping: sets the value, converted to the matching OpenCL type */
+    template<class NumericT>
+    result_type operator()(implicit_matrix_base<NumericT> const & mat) const
+    {
+      kernel_.arg(current_arg_++, typename viennacl::result_of::cl_type<NumericT>::type(mat.value()));
+    }
+
+    /** @brief Traversal functor: forwards the requested leaf of node root_idx to the overloads above,
+    *  unless that leaf is itself a composite operation.
+    */
+    void operator()(scheduler::statement const & statement, vcl_size_t root_idx, leaf_t leaf) const
+    {
+      scheduler::statement_node const & root_node = statement.array()[root_idx];
+      if (leaf==LHS_NODE_TYPE && root_node.lhs.type_family != scheduler::COMPOSITE_OPERATION_FAMILY)
+        utils::call_on_element(root_node.lhs, *this);
+      else if (leaf==RHS_NODE_TYPE && root_node.rhs.type_family != scheduler::COMPOSITE_OPERATION_FAMILY)
+        utils::call_on_element(root_node.rhs, *this);
+    }
+
+  private:
+    symbolic_binder & binder_;
+    unsigned int & current_arg_;
+    viennacl::ocl::kernel & kernel_;
+  };
+
+protected:
+
+  /** @brief Writes the "__kernel void <name>(<arguments>)" prototype line to stream
+  *
+  *  @param stream           destination of the generated source
+  *  @param name             kernel name
+  *  @param first_arguments  text emitted before the statement-derived arguments
+  *  @param mappings         one mapping per statement, walked in lockstep with statements (needs at least one entry per statement)
+  *  @param statements       statements whose leaves contribute kernel arguments
+  *  @param widths           SIMD width per mapped-object name; objects not listed are declared with width 1
+  */
+  static void generate_prototype(utils::kernel_generation_stream & stream, std::string const & name, std::string const & first_arguments, std::vector<mapping_type> const & mappings, statements_container const & statements,
+                                 std::map<std::string, unsigned int> const & widths)
+  {
+    statements_container::data_type::const_iterator sit;
+    std::vector<mapping_type>::const_iterator mit;
+    std::set<std::string> already_generated;
+
+    std::string arguments = first_arguments;
+    for (mit = mappings.begin(), sit = statements.data().begin(); sit != statements.data().end(); ++sit, ++mit)
+      tree_parsing::traverse(*sit, sit->root(), prototype_generation_traversal(already_generated, arguments, *mit, widths), true);
+    //Prune the trailing comma left after the last argument. Guarded: on an empty list size()-1 wraps
+    //around to npos and std::string::erase would throw std::out_of_range.
+    if (!arguments.empty())
+      arguments.erase(arguments.size()-1);
+    stream << "__kernel " << "void " << name << "(" << arguments << ")" << std::endl;
+  }
+
+  /** @brief Overload of generate_prototype() without per-object widths: every mapped object uses width 1 */
+  static void generate_prototype(utils::kernel_generation_stream & stream, std::string const & name, std::string const & first_arguments, std::vector<mapping_type> const & mappings, statements_container const & statements)
+  {
+    generate_prototype(stream, name, first_arguments, mappings, statements, std::map<std::string, unsigned int>());
+  }
+
+  /** @brief Sets the kernel arguments for all statements, starting at index current_arg
+  *
+  *  current_arg is advanced past the last argument set. A fresh binder created from binding_policy_
+  *  is shared by all statements of this call.
+  */
+  void set_arguments(statements_container const & statements, viennacl::ocl::kernel & kernel, unsigned int & current_arg)
+  {
+    tools::shared_ptr<symbolic_binder> binder = make_binder(binding_policy_);
+    for (statements_container::data_type::const_iterator itt = statements.data().begin(); itt != statements.data().end(); ++itt)
+      tree_parsing::traverse(*itt, itt->root(), set_arguments_functor(*binder,current_arg,kernel), true);
+  }
+
+  /** @brief Exception reporting that a template cannot be applied to a statement */
+  class invalid_template_exception : public std::exception
+  {
+  public:
+    invalid_template_exception() : message_() {}
+    invalid_template_exception(std::string message) :
+      message_("ViennaCL: Internal error: The generator cannot apply the given template to the given statement: " + message + "\n"
+               "If you are using a builtin template, please report on [email protected]!
We will provide a fix as soon as possible\n"
+               "If you are using your own template, please try using other parameters") {}
+    virtual const char* what() const throw() { return message_.c_str(); }
+    virtual ~invalid_template_exception() throw() {}
+  private:
+    std::string message_;
+  };
+
+  /** @brief Computes the bounds of a generated fetching loop for the given policy
+  *
+  *  FETCH_FROM_GLOBAL_STRIDED:    init = domain_id, upper_bound = bound, inc = domain_size.
+  *  FETCH_FROM_GLOBAL_CONTIGUOUS: emits chunk_size/chunk_start/chunk_end declarations to stream
+  *                                (chunks of ceil(bound/domain_size) elements, clamped to bound)
+  *                                and iterates over the chunk with inc = 1.
+  *  Any other policy leaves init, upper_bound and inc unchanged.
+  */
+  static void fetching_loop_info(fetching_policy_type policy, std::string const & bound, utils::kernel_generation_stream & stream, std::string & init, std::string & upper_bound, std::string & inc, std::string const & domain_id, std::string const & domain_size)
+  {
+    if (policy==FETCH_FROM_GLOBAL_STRIDED)
+    {
+      init = domain_id;
+      upper_bound = bound;
+      inc = domain_size;
+    }
+    else if (policy==FETCH_FROM_GLOBAL_CONTIGUOUS)
+    {
+      std::string chunk_size = "chunk_size";
+      std::string chunk_start = "chunk_start";
+      std::string chunk_end = "chunk_end";
+
+      stream << "unsigned int " << chunk_size << " = (" << bound << "+" << domain_size << "-1)/" << domain_size << ";" << std::endl;
+      stream << "unsigned int " << chunk_start << " =" << domain_id << "*" << chunk_size << ";" << std::endl;
+      stream << "unsigned int " << chunk_end << " = min(" << chunk_start << "+" << chunk_size << ", " << bound << ");" << std::endl;
+      init = chunk_start;
+      upper_bound = chunk_end;
+      inc = "1";
+    }
+  }
+
+  /** @brief Tells whether the given leaf of node root_idx is transposed
+  *
+  *  Follows the lhs (or rhs) operand through nested composite operations and returns true when an
+  *  odd number of the nodes reached this way are OPERATION_UNARY_TRANS_TYPE.
+  */
+  static bool is_node_trans(scheduler::statement::container_type const & array, vcl_size_t root_idx, leaf_t leaf_type)
+  {
+    bool res = false;
+    scheduler::lhs_rhs_element scheduler::statement_node::*ptr;
+    if (leaf_type==LHS_NODE_TYPE)
+      ptr = &scheduler::statement_node::lhs;
+    else
+      ptr = &scheduler::statement_node::rhs;
+    scheduler::statement_node const * node = &array[root_idx];
+    while ((node->*ptr).type_family==scheduler::COMPOSITE_OPERATION_FAMILY)
+    {
+      if (array[(node->*ptr).node_index].op.type==scheduler::OPERATION_UNARY_TRANS_TYPE)
+        res = !res;
+      node = &array[(node->*ptr).node_index];
+    }
+    return res;
+  }
+
+protected:
+
+  /** @brief Appends the component suffix ('0'-'9', 'a'-'f') of SIMD lane i to str; requires i < 16 */
+  static std::string append_simd_suffix(std::string const & str, unsigned int i)
+  {
+    assert(i < 16);
+    static const char suffixes[] = {'0','1','2','3','4','5','6','7','8','9',
+                                    'a','b','c','d','e','f'};
+    return str + tools::to_string(suffixes[i]);
+  }
+
+  /** @brief True for matrix column, row and diagonal extraction nodes */
+  static bool is_striding_operator(scheduler::statement_node const & node)
+  {
+    return node.op.type==scheduler::OPERATION_BINARY_MATRIX_COLUMN_TYPE
+        || node.op.type==scheduler::OPERATION_BINARY_MATRIX_ROW_TYPE
+        || node.op.type==scheduler::OPERATION_BINARY_MATRIX_DIAG_TYPE;
+  }
+
+  /** @brief Returns true if any statement accesses memory with a non-unit stride
+  *
+  *  That is the case for a dense vector with stride > 1, a dense matrix with stride1 > 1 or
+  *  stride2 > 1, or any row/column/diagonal extraction (see is_striding_operator()).
+  */
+  static bool has_strided_access(statements_container const & statements)
+  {
+    for (statements_container::data_type::const_iterator it = statements.data().begin(); it != statements.data().end(); ++it)
+    {
+      //checks for vectors
+      std::vector<scheduler::lhs_rhs_element> vectors;
+      tree_parsing::traverse(*it, it->root(), tree_parsing::filter_elements(scheduler::DENSE_VECTOR_TYPE, vectors), true);
+      for (std::vector<scheduler::lhs_rhs_element>::iterator itt = vectors.begin(); itt != vectors.end(); ++itt)
+        if (utils::call_on_vector(*itt, utils::stride_fun())>1)
+          return true;
+
+      //checks for matrix: a non-unit stride in either dimension counts, exactly as for vectors
+      std::vector<scheduler::lhs_rhs_element> matrices;
+      tree_parsing::traverse(*it, it->root(), tree_parsing::filter_elements(scheduler::DENSE_MATRIX_TYPE, matrices), true);
+      for (std::vector<scheduler::lhs_rhs_element>::iterator itt = matrices.begin(); itt != matrices.end(); ++itt)
+        if (utils::call_on_matrix(*itt, utils::stride1_fun())>1 || utils::call_on_matrix(*itt, utils::stride2_fun())>1)
+          return true;
+
+      //row/column/diagonal extractions are always treated as strided access
+      std::vector<vcl_size_t> striding_operators;
+      tree_parsing::traverse(*it, it->root(), tree_parsing::filter(&is_striding_operator, striding_operators), false);
+      if(!striding_operators.empty())
+        return true;
+    }
+    return false;
+  }
+
+  /** @brief Length of the vector operand of node
+  *
+  *  Diagonal extraction: min(size1, size2) of the lhs matrix; row extraction: its size2; column
+  *  extraction: its size1; anything else: size of the lhs vector. With up_to_internal_size the
+  *  internal_size* variants are used instead.
+  */
+  static vcl_size_t vector_size(scheduler::statement_node const & node, bool up_to_internal_size)
+  {
+    using namespace scheduler;
+    using namespace utils;
+    if (node.op.type==OPERATION_BINARY_MATRIX_DIAG_TYPE)
+    {
+      vcl_size_t size1 = up_to_internal_size?call_on_matrix(node.lhs, internal_size1_fun()):call_on_matrix(node.lhs, size1_fun());
+      vcl_size_t size2 = up_to_internal_size?call_on_matrix(node.lhs, internal_size2_fun()):call_on_matrix(node.lhs, size2_fun());
+      return std::min<vcl_size_t>(size1, size2);
+    }
+    else if (node.op.type==OPERATION_BINARY_MATRIX_ROW_TYPE)
+      return up_to_internal_size?call_on_matrix(node.lhs, internal_size2_fun()):call_on_matrix(node.lhs, size2_fun());
+    else if (node.op.type==OPERATION_BINARY_MATRIX_COLUMN_TYPE)
+      return up_to_internal_size?call_on_matrix(node.lhs, internal_size1_fun()):call_on_matrix(node.lhs, size1_fun());
+    else
+      return up_to_internal_size?call_on_vector(node.lhs, internal_size_fun()):call_on_vector(node.lhs, size_fun());
+  }
+
+  //NB : templates are not used here because declaring a functor out of the generate() functions would be harder to read
+  /** @brief Callback emitting the body of a generated loop for a given SIMD width */
+  struct loop_body_base
+  {
+    virtual void operator()(utils::kernel_generation_stream & stream, unsigned int simd_width) const = 0;
+    virtual ~loop_body_base() {}
+  };
+
+  /** @brief Emits an element-wise 1D loop over [0, bound) into stream
+  *
+  *  The main loop iterates i over bound/simd_width vector elements (bounds chosen by
+  *  fetching_loop_info() for the fetch policy) and calls loop_body with simd_width. When
+  *  simd_width > 1, a second loop strided by domain_size covers the remaining scalars from
+  *  (bound/simd_width)*simd_width up to bound, calling loop_body with width 1.
+  *  NOTE(review): bound is pasted unparenthesized into "bound/simd_width", so it presumably must be
+  *  a single operand (identifier or call), not an expression such as "a+b".
+  */
+  static void element_wise_loop_1D(utils::kernel_generation_stream & stream, loop_body_base const & loop_body,
+                                   fetching_policy_type fetch, unsigned int simd_width, std::string const & i, std::string const & bound, std::string const & domain_id, std::string const & domain_size)
+  {
+    std::string strwidth = tools::to_string(simd_width);
+    std::string boundround = bound + "/" + strwidth;
+
+    std::string init, upper_bound, inc;
+    fetching_loop_info(fetch, boundround, stream, init, upper_bound, inc, domain_id, domain_size);
+    stream << "for(unsigned int " << i << " = " << init << "; " << i << " < " << upper_bound << "; " << i << " += " << inc << ")" << std::endl;
+    stream << "{" << std::endl;
+    stream.inc_tab();
+    loop_body(stream, simd_width);
+    stream.dec_tab();
+    stream << "}" << std::endl;
+
+    if (simd_width>1)
+    {
+      stream << "for(unsigned int " << i << " = " << boundround << "*" << strwidth << " + " << domain_id << "; " << i << " < " << bound << "; " << i << " += " + domain_size + ")" << std::endl;
+      stream << "{" << std::endl;
+      stream.inc_tab();
+      loop_body(stream, 1);
+      stream.dec_tab();
+      stream << "}" << std::endl;
+    }
+  }
+
+  /** @brief Expression storing value at ptr[offset] (simd_width 1) or through vstore<simd_width>
+  *
+  *  In the vector case OpenCL's vstoreN writes to ptr + offset*N, i.e. offset counts vectors.
+  */
+  static std::string vstore(unsigned int simd_width, std::string const & value, std::string const & offset, std::string const & ptr)
+  {
+    if (simd_width==1)
+      return "(" + ptr + ")[" + offset + "] = " + value;
+    else
+      return utils::append_width("vstore", simd_width) + "(" + value + ", " + offset + ", " + ptr + ")";
+  }
+
+  /** @brief Expression loading ptr[offset] (simd_width 1) or through vload<simd_width>
+  *
+  *  In the vector case OpenCL's vloadN reads from ptr + offset*N, i.e. offset counts vectors.
+  */
+  static std::string vload(unsigned int simd_width, std::string const & offset, std::string const & ptr)
+  {
+    if (simd_width==1)
+      return "(" + ptr + ")[" + offset + "]";
+    else
+      return utils::append_width("vload", simd_width) + "(" + offset + ", " + ptr + ")";
+  }
+
+private:
+  /** @brief Generates the body of the associated kernel function */
+  virtual std::vector<std::string> generate_impl(std::string const & kernel_prefix, statements_container const & statements, std::vector<mapping_type> const & mapping) const = 0;
+
+public:
+  template_base(binding_policy_t binding_policy) : binding_policy_(binding_policy) {}
+
+  virtual ~template_base(){ }
+
+  /** @brief Generates the kernel source(s) for statements
+  *
+  *  Throws generator_not_supported_exception (carrying the error code) when check_invalid() returns
+  *  a non-zero code for device; otherwise builds one mapping per statement with a fresh binder and
+  *  forwards to generate_impl().
+  */
+  std::vector<std::string> generate(std::string const & kernel_prefix, statements_container const & statements, viennacl::ocl::device const & device)
+  {
+    statements_container::data_type::const_iterator sit;
+    std::vector<mapping_type>::iterator mit;
+
+    if(int err = check_invalid(statements, device))
+      throw generator_not_supported_exception("The supplied parameters for this template are invalid : err " + tools::to_string(err));
+
+    //Create mapping
+    std::vector<mapping_type> mappings(statements.data().size());
+    tools::shared_ptr<symbolic_binder> binder = make_binder(binding_policy_);
+    for (mit = mappings.begin(), sit = statements.data().begin(); sit != statements.data().end(); ++sit, ++mit)
+      tree_parsing::traverse(*sit, sit->root(), map_functor(*binder,*mit), true);
+
+    return generate_impl(kernel_prefix, statements, mappings);
+  }
+
+  /** @brief Checks whether the template can be used for statements on device; generate() treats any non-zero result as an error */
+  virtual int check_invalid(statements_container const & statements, viennacl::ocl::device const & device) const = 0;
+
+  /** @brief Enqueues the generated kernel(s) for statements (implemented by each template) */
+  virtual void enqueue(std::string const & kernel_prefix, std::vector<lazy_program_compiler> & programs, statements_container const & statements) = 0;
+
+  /** @brief Polymorphic copy */
+  virtual tools::shared_ptr<template_base> clone() const = 0;
+private:
+  binding_policy_t binding_policy_;
+};
+
+
+/** @brief Common base storing the tuning parameters and implementing clone() and the device checks
+*
+*  TemplateType must be the concrete class deriving from this one: clone() downcasts *this to it.
+*/
+template<class TemplateType, class ParametersType>
+class template_base_impl : public template_base
+{
+private:
+  /** @brief Template-specific validity checks, run last by check_invalid() */
+  virtual int check_invalid_impl(viennacl::ocl::device const & /*dev*/) const { return TEMPLATE_VALID; }
+
+  /** @brief Number of local-memory elements used (multiplied by the scalar size in check_invalid()) */
+  virtual unsigned int n_lmem_elements() const { return 0; }
+
+public:
+  typedef ParametersType parameters_type;
+
+  /** @brief The constructor */
+  template_base_impl(parameters_type const & parameters, binding_policy_t binding_policy) : template_base(binding_policy), p_(parameters){ }
+
+  parameters_type const & parameters() const
+  {
+    return p_;
+  }
+
+  tools::shared_ptr<template_base> clone() const
+  {
+    return tools::shared_ptr<template_base>(new TemplateType(*dynamic_cast<TemplateType const *>(this)));
+  }
+
+  /** @brief Checks the parameters against the limits of device
+  *
+  *  Returns the code of the first failing check, in this order: TEMPLATE_LOCAL_MEMORY_OVERFLOW,
+  *  TEMPLATE_WORK_GROUP_SIZE_OVERFLOW, TEMPLATE_LOCAL_SIZE_0_OVERFLOW, TEMPLATE_LOCAL_SIZE_1_OVERFLOW,
+  *  TEMPLATE_LOCAL_SIZE_NOT_WARP_MULTIPLE, TEMPLATE_INVALID_SIMD_WIDTH; otherwise check_invalid_impl().
+  *  NOTE(review): statements.data().front() is read unconditionally, so statements must not be empty.
+  */
+  int check_invalid(statements_container const & statements, viennacl::ocl::device const & device) const
+  {
+    using namespace viennacl::tools;
+
+    scheduler::statement const & statement = statements.data().front();
+    unsigned int scalartype_size = utils::size_of(lhs_most(statement.array(), statement.root()).lhs.numeric_type);
+
+    //Query device information
+    vcl_size_t lmem_available = static_cast<vcl_size_t>(device.local_mem_size());
+    vcl_size_t lmem_usage = scalartype_size*n_lmem_elements();
+    if (lmem_usage>lmem_available)
+      return TEMPLATE_LOCAL_MEMORY_OVERFLOW;
+
+    //Invalid work group size
+    vcl_size_t max_workgroup_size = device.max_work_group_size();
+    std::vector<vcl_size_t> max_work_item_sizes = device.max_work_item_sizes();
+    if (p_.local_size_0*p_.local_size_1 > max_workgroup_size)
+      return TEMPLATE_WORK_GROUP_SIZE_OVERFLOW;
+    if (p_.local_size_0 > max_work_item_sizes[0])
+      return TEMPLATE_LOCAL_SIZE_0_OVERFLOW;
+
+    if (p_.local_size_1 > max_work_item_sizes[1])
+      return TEMPLATE_LOCAL_SIZE_1_OVERFLOW;
+
+    //Advice from the Intel guide
+    unsigned int warp_size = 8;
+    if (device.type()==CL_DEVICE_TYPE_GPU)
+    {
+      //Advice from the nvidia guide
+      warp_size = 32;
+      //Advice from the AMD guide (4098 == 0x1002, AMD's vendor ID)
+      if (device.vendor_id()==4098)
+        warp_size = 64;
+    }
+    if (((p_.local_size_0*p_.local_size_1)%warp_size)>0)
+      return TEMPLATE_LOCAL_SIZE_NOT_WARP_MULTIPLE;
+
+    //Invalid SIMD Width
+    if (p_.simd_width!=1 && p_.simd_width!=2 &&
+        p_.simd_width!=4 && p_.simd_width!=8 &&
+        p_.simd_width!=16)
+      return TEMPLATE_INVALID_SIMD_WIDTH;
+
+    return check_invalid_impl(device);
+  }
+
+protected:
+  parameters_type p_;
+};
+
+}
+}
+
+#endif
