haojin2 commented on a change in pull request #10550: [MXNET-320] Support elemwise_add/sub/max/min/hypot between dense and csr tensors URL: https://github.com/apache/incubator-mxnet/pull/10550#discussion_r181539348
########## File path: src/operator/tensor/elemwise_binary_op-inl.h ########## @@ -374,6 +374,72 @@ void ElemwiseBinaryOp::CsrCsrOp(mshadow::Stream<cpu> *s, } } +template<typename OP> +struct ElemwiseDnsMapKernel { + template<typename DType> + static void inline Map(int i, const OpReqType req, DType* out, const DType* dns_data, + const nnvm::dim_t num_rows, const nnvm::dim_t num_cols) { + if (i < num_rows*num_cols) { + KERNEL_ASSIGN(out[i], req, OP::Map(dns_data[i], DType(0.0f))); + } + } +}; + +template<typename OP> +struct ElemwiseDnsCsrDnsKernel { + template<typename DType, typename IType, typename CType> + static void inline Map(int i, const OpReqType req, DType* out, DType* dns_data, + const DType* csr_data, const IType* csr_indices, const CType* csr_indptr, + const nnvm::dim_t num_rows, const nnvm::dim_t num_cols) { + if (i < num_rows) { + for (int j = csr_indptr[i]; j < csr_indptr[i+1]; ++j) { + KERNEL_ASSIGN(out[i * num_cols + csr_indices[j]], req, + OP::Map(dns_data[i * num_cols + csr_indices[j]], csr_data[j])); + } + } + } +}; + +/*! \brief CSR -op- CSR binary operator for non-canonical NDArray */ Review comment: I will change this comment soon; I forgot to update it when copying it over from above. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services