eric-haibin-lin closed pull request #10894: [MXNET-399] Elemwise_mul between
dense and csr on CPU & GPU
URL: https://github.com/apache/incubator-mxnet/pull/10894
This is a PR merged from a forked repository.
Because GitHub hides the original diff of a pull request from a fork once
it is merged, the diff is reproduced below for the sake of provenance:
diff --git a/src/operator/tensor/elemwise_binary_op-inl.h
b/src/operator/tensor/elemwise_binary_op-inl.h
index c74f1f93603..911c369b3e6 100644
--- a/src/operator/tensor/elemwise_binary_op-inl.h
+++ b/src/operator/tensor/elemwise_binary_op-inl.h
@@ -495,6 +495,91 @@ void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<xpu> *s,
});
}
+/*!
+ * \brief Kernel for performing elemwise op between dense and csr matrix
+ * \param i global thread id
+ * \param req type of request
+ * \param out output array
+ * \param dns_data data array of dense input
+ * \param csr_data data array of csr input
+ * \param csr_indices indices array of csr input
+ * \param csr_indptr indptr array of csr input
+ * \param num_rows number of rows of both inputs
+ * \param num_cols number of columns of both inputs
+ */
+template<int req, typename OP, bool reverse>
+struct ElemwiseDnsCsrCsrKernel {
+ template<typename DType, typename IType, typename CType>
+ MSHADOW_XINLINE static void Map(int i, DType* out, DType* dns_data,
+ const DType* csr_data, const IType*
csr_indices,
+ const CType* csr_indptr, const nnvm::dim_t
num_rows,
+ const nnvm::dim_t num_cols) {
+ if (i < num_rows) {
+ for (int j = csr_indptr[i]; j < csr_indptr[i+1]; ++j) {
+ KERNEL_ASSIGN(out[j], req, reverse ?
+ OP::Map(dns_data[i * num_cols +
csr_indices[j]], csr_data[j]) :
+ OP::Map(csr_data[j], dns_data[i * num_cols
+ csr_indices[j]]));
+ }
+ }
+ }
+};
+
+/*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
+template<typename xpu, typename OP>
+void ElemwiseBinaryOp::DnsCsrCsrOp(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &dns,
+ const NDArray &csr,
+ const OpReqType req,
+ const NDArray &output,
+ const bool reverse) {
+ using namespace mshadow;
+ using namespace mxnet_op;
+ using namespace csr;
+ CHECK_EQ(dns.storage_type(), kDefaultStorage);
+ CHECK_EQ(csr.storage_type(), kCSRStorage);
+ CHECK_EQ(req, kWriteTo) << "elemwise(dns, csr) = csr only supports kWriteTo";
+ if (req == kNullOp) return;
+ const bool supported_op = std::is_same<OP, mshadow_op::mul>::value;
+ CHECK(supported_op == true) << "elemwise(dns, csr) = csr only supports mul";
+ const nnvm::dim_t num_csr_rows = csr.shape()[0];
+ const nnvm::dim_t num_csr_cols = csr.shape()[1];
+ const nnvm::dim_t nnz = csr.storage_shape()[0];
+ Stream<xpu> *s = ctx.get_stream<xpu>();
+
+ output.CheckAndAlloc({Shape1(num_csr_rows + 1), Shape1(nnz)});
+ if (csr.storage_initialized()) {
+ TBlob csr_data = csr.data();
+ TBlob csr_indices = csr.aux_data(kIdx);
+ TBlob csr_indptr = csr.aux_data(kIndPtr);
+ MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {
+ MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {
+ MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {
+ MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+ if (reverse) {
+ Kernel<ElemwiseDnsCsrCsrKernel<Req, OP, true>, xpu>::Launch(
+ s, num_csr_rows, output.data().dptr<DType>(),
dns.data().dptr<DType>(),
+ csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
csr_indptr.dptr<CType>(),
+ num_csr_rows, num_csr_cols);
+ } else {
+ Kernel<ElemwiseDnsCsrCsrKernel<Req, OP, false>, xpu>::Launch(
+ s, num_csr_rows, output.data().dptr<DType>(),
dns.data().dptr<DType>(),
+ csr_data.dptr<DType>(), csr_indices.dptr<IType>(),
csr_indptr.dptr<CType>(),
+ num_csr_rows, num_csr_cols);
+ }
+ Copy(output.aux_data(kIdx).FlatTo1D<xpu, IType>(),
+ csr.aux_data(kIdx).FlatTo1D<xpu, IType>(), s);
+ Copy(output.aux_data(kIndPtr).FlatTo1D<xpu, CType>(),
+ csr.aux_data(kIndPtr).FlatTo1D<xpu, CType>(), s);
+ });
+ });
+ });
+ });
+ } else {
+ FillZerosCsrImpl(s, output);
+ }
+}
+
/*!
* \brief Kernel for performing elemwise op between dense and rsp tensor
* \param i global thread id
diff --git a/src/operator/tensor/elemwise_binary_op.cc
b/src/operator/tensor/elemwise_binary_op.cc
index e8ba2fa7234..9ccbacc2f65 100644
--- a/src/operator/tensor/elemwise_binary_op.cc
+++ b/src/operator/tensor/elemwise_binary_op.cc
@@ -63,6 +63,11 @@ bool ElemwiseBinaryOp::BackwardUseInStorageType(const
nnvm::NodeAttrs& attrs,
const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback :
DispatchMode::kFComputeEx;
+ const int ograd_stype = in_attrs->at(0);
+ const int lhs_stype = in_attrs->at(1);
+ const int rhs_stype = in_attrs->at(2);
+ int& lhs_grad_stype = out_attrs->at(0);
+ int& rhs_grad_stype = out_attrs->at(1);
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
dispatched = storage_type_assign(out_attrs, kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
@@ -74,6 +79,22 @@ bool ElemwiseBinaryOp::BackwardUseInStorageType(const
nnvm::NodeAttrs& attrs,
dispatch_mode, dispatch_ex);
}
}
+ if (!dispatched && ograd_stype == kDefaultStorage &&
+ ((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+ (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
+ const bool reverse = (lhs_stype == kCSRStorage);
+ if (reverse &&
+ type_assign(&lhs_grad_stype, kDefaultStorage) &&
+ type_assign(&rhs_grad_stype, kCSRStorage)) {
+ DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+ dispatched = true;
+ } else if (!reverse &&
+ type_assign(&lhs_grad_stype, kCSRStorage) &&
+ type_assign(&rhs_grad_stype, kDefaultStorage)) {
+ DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+ dispatched = true;
+ }
+ }
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
diff --git a/src/operator/tensor/elemwise_binary_op.h
b/src/operator/tensor/elemwise_binary_op.h
index a5b73dadd3a..ad4b3e7cc4a 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -165,12 +165,11 @@ class ElemwiseBinaryOp : public OpBase {
typename xpu,
typename LOP,
typename ROP,
- typename DType,
bool in0_ok_dense = false,
bool in1_ok_dense = false,
bool in2_ok_dense = false,
typename BackupCompute>
- static inline void BackwardUseInEx_(const nnvm::NodeAttrs &attrs,
+ static inline void RspRspOpBackward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
@@ -200,6 +199,33 @@ class ElemwiseBinaryOp : public OpBase {
}
}
+ template<typename xpu, typename LOP, typename ROP>
+ static inline void DnsCsrCsrOpBackward(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs) {
+ const bool supported_ops = std::is_same<mshadow_op::right, LOP>::value &&
+ std::is_same<mshadow_op::left, ROP>::value;
+ CHECK(supported_ops)
+ << "Only backward for mul is supported (LOP should be right, ROP should
be left)";
+ const NDArray& out_grad = inputs[0];
+ const NDArray& lhs_in = inputs[1];
+ const NDArray& rhs_in = inputs[2];
+ const NDArray& lhs_grad = outputs[0];
+ const NDArray& rhs_grad = outputs[1];
+ const bool reverse = (outputs[0].storage_type() == kCSRStorage);
+ if (reverse) {
+ DnsCsrCsrOp<xpu, mshadow_op::mul>(attrs, ctx, out_grad, rhs_in, req[0],
lhs_grad, false);
+ Compute<xpu, mshadow_op::mul>(attrs, ctx, {out_grad.data(),
lhs_in.data()}, {req[1]},
+ {rhs_grad.data()});
+ } else {
+ DnsCsrCsrOp<xpu, mshadow_op::mul>(attrs, ctx, out_grad, lhs_in, req[1],
rhs_grad, false);
+ Compute<xpu, mshadow_op::mul>(attrs, ctx, {out_grad.data(),
rhs_in.data()}, {req[0]},
+ {lhs_grad.data()});
+ }
+ }
+
public:
/*! \brief Binary op handling for lhr/rhs: RspDns, RspRsp, DnsRsp, or
RspRsp->Dns result */
template<typename OP>
@@ -232,44 +258,54 @@ class ElemwiseBinaryOp : public OpBase {
/*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
template<typename OP>
static void CsrCsrOp(mshadow::Stream<cpu> *s,
- const nnvm::NodeAttrs &attrs,
- const OpContext &ctx,
- const NDArray &lhs,
- const NDArray &rhs,
- OpReqType req,
- const NDArray &output);
+ const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &lhs,
+ const NDArray &rhs,
+ OpReqType req,
+ const NDArray &output);
/*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
template<typename OP>
static void CsrCsrOp(mshadow::Stream<gpu> *s,
- const nnvm::NodeAttrs &attrs,
- const OpContext &ctx,
- const NDArray &lhs,
- const NDArray &rhs,
- OpReqType req,
- const NDArray &output);
+ const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &lhs,
+ const NDArray &rhs,
+ OpReqType req,
+ const NDArray &output);
/*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
template<typename xpu, typename OP>
static void DnsCsrDnsOp(mshadow::Stream<xpu> *s,
- const nnvm::NodeAttrs &attrs,
- const OpContext &ctx,
- const NDArray &lhs,
- const NDArray &rhs,
- OpReqType req,
- const NDArray &output,
- const bool reverse);
+ const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &lhs,
+ const NDArray &rhs,
+ OpReqType req,
+ const NDArray &output,
+ const bool reverse);
+
+ /*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
+ template<typename xpu, typename OP>
+ static void DnsCsrCsrOp(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &lhs,
+ const NDArray &rhs,
+ OpReqType req,
+ const NDArray &output,
+ const bool reverse);
/*! \brief DNS -op- RSP binary operator for non-canonical NDArray */
template<typename xpu, typename OP>
static void DnsRspDnsOp(mshadow::Stream<xpu> *s,
- const nnvm::NodeAttrs &attrs,
- const OpContext &ctx,
- const NDArray &lhs,
- const NDArray &rhs,
- OpReqType req,
- const NDArray &output,
- const bool reverse);
+ const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &lhs,
+ const NDArray &rhs,
+ OpReqType req,
+ const NDArray &output,
+ const bool reverse);
public:
/*!
@@ -336,6 +372,14 @@ class ElemwiseBinaryOp : public OpBase {
dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
dispatch_mode, dispatch_ex);
}
+ if (!dispatched &&
+ ((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+ (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
+ // csr, dns -> csr
+ // dns, csr -> csr
+ dispatched = storage_type_assign(&out_stype, kCSRStorage,
+ dispatch_mode,
DispatchMode::kFComputeEx);
+ }
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
@@ -540,6 +584,14 @@ class ElemwiseBinaryOp : public OpBase {
req[0], outputs[0], lhs_may_be_dense, rhs_may_be_dense, false,
false);
} else if (lhs_stype == kCSRStorage && rhs_stype == kCSRStorage) {
ComputeEx<xpu, OP>(attrs, ctx, inputs, req, outputs);
+ } else if (((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+ (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage)) &&
+ out_stype == kCSRStorage) {
+ const NDArray& dns = (lhs_stype == kDefaultStorage)? inputs[0] :
inputs[1];
+ const NDArray& csr = (lhs_stype == kCSRStorage)? inputs[0] : inputs[1];
+ const bool reverse = (lhs_stype == kCSRStorage);
+
+ DnsCsrCsrOp<xpu, OP>(attrs, ctx, dns, csr, req[0], outputs[0], reverse);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
@@ -635,16 +687,21 @@ class ElemwiseBinaryOp : public OpBase {
using namespace common;
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 2U); // lhs input grad, rhs input grad
+ const auto out_grad_stype = inputs[0].storage_type();
const auto lhs_grad_stype = outputs[0].storage_type();
const auto rhs_grad_stype = outputs[1].storage_type();
if (ContainsOnlyStorage(inputs, kRowSparseStorage) &&
(lhs_grad_stype == kDefaultStorage || lhs_grad_stype ==
kRowSparseStorage) &&
(rhs_grad_stype == kDefaultStorage || rhs_grad_stype ==
kRowSparseStorage)) {
// rsp, rsp, rsp -> [dns, rsp], [dns, rsp]
- MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
- BackwardUseInEx_<xpu, LOP, ROP, DType, in0_ok_dense, in1_ok_dense,
in2_ok_dense>(
- attrs, ctx, inputs, req, outputs, BackwardUseIn<xpu, LOP, ROP>);
- });
+ RspRspOpBackward<xpu, LOP, ROP, in0_ok_dense, in1_ok_dense,
in2_ok_dense>(
+ attrs, ctx, inputs, req, outputs, BackwardUseIn<xpu, LOP, ROP>);
+ }
+ if (((lhs_grad_stype == kDefaultStorage && rhs_grad_stype == kCSRStorage)
||
+ (lhs_grad_stype == kCSRStorage && rhs_grad_stype == kDefaultStorage))
&&
+ out_grad_stype == kDefaultStorage) {
+ // dns, csr, dns -> [csr, dns] / csr, dns, dns -> [dns, csr]
+ DnsCsrCsrOpBackward<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
}
}
}; // class ElemwiseBinaryOp
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu
b/src/operator/tensor/elemwise_binary_op_basic.cu
index 9c1fd0e14f3..5cdd8947dd4 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cu
+++ b/src/operator/tensor/elemwise_binary_op_basic.cu
@@ -51,7 +51,9 @@ NNVM_REGISTER_OP(_backward_sub)
mshadow_op::negation>);
NNVM_REGISTER_OP(elemwise_mul)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu,
op::mshadow_op::mul>);
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu,
op::mshadow_op::mul>)
+.set_attr<FComputeEx>("FComputeEx<gpu>",
+ ElemwiseBinaryOp::ComputeDnsLRValueEx<gpu, op::mshadow_op::mul, true, true>);
NNVM_REGISTER_OP(_backward_mul)
.set_attr<FCompute>("FCompute<gpu>",
diff --git a/tests/python/unittest/test_sparse_operator.py
b/tests/python/unittest/test_sparse_operator.py
index 226db70a2ac..b2ff0fecb5a 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -329,9 +329,19 @@ def elemwise_mul_stype(lstype, rstype):
return 'row_sparse'
elif lstype == 'row_sparse' and rstype == 'default':
return 'row_sparse'
+ elif lstype == 'default' and rstype == 'csr':
+ return 'csr'
+ elif lstype == 'csr' and rstype == 'default':
+ return 'csr'
else:
return 'default'
+ def elemwise_mul_lhs_grad_stype(lstype, rstype):
+ return elemwise_mul_stype(elemwise_mul_stype(lstype, rstype), rstype)
+
+ def elemwise_mul_rhs_grad_stype(lstype, rstype):
+ return elemwise_mul_stype(elemwise_mul_stype(lstype, rstype), lstype)
+
def check_elemwise_binary_ops(lhs_stype, rhs_stype, shape,
lhs_grad_stype=None, rhs_grad_stype=None,
lhs_density=.5, rhs_density=.5,
@@ -378,8 +388,8 @@ def check_elemwise_binary_ops(lhs_stype, rhs_stype, shape,
lambda l, r: mx.sym.sparse.elemwise_mul(l, r),
lambda l, r: l * r,
lambda outg, l, r: (outg * r, outg * l),
- elemwise_mul_stype(lhs_stype, rhs_stype),
- elemwise_mul_stype(lhs_stype, rhs_stype),
+ elemwise_mul_lhs_grad_stype(lhs_stype,
rhs_stype),
+ elemwise_mul_rhs_grad_stype(lhs_stype,
rhs_stype),
expected_result_storage_type=elemwise_mul_stype(lhs_stype, rhs_stype),
ograd_density=ograd_density,
force_lr_overlap=force_lr_overlap,
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services