This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 9feecce [MXNET-399] Elemwise_mul between dense and csr on CPU & GPU (#10894)
9feecce is described below
commit 9feeccecb4ab64461cfae0bd4e75dd4bcbd7c9d5
Author: Hao Jin <[email protected]>
AuthorDate: Wed May 30 17:33:42 2018 -0700
[MXNET-399] Elemwise_mul between dense and csr on CPU & GPU (#10894)
* support elemwise_mul between dns and csr
* address reviews and support for backward when ograd is dns
---
src/operator/tensor/elemwise_binary_op-inl.h | 85 +++++++++++++++++
src/operator/tensor/elemwise_binary_op.cc | 21 ++++
src/operator/tensor/elemwise_binary_op.h | 121 +++++++++++++++++-------
src/operator/tensor/elemwise_binary_op_basic.cu | 4 +-
tests/python/unittest/test_sparse_operator.py | 14 ++-
5 files changed, 210 insertions(+), 35 deletions(-)
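
For context, the user-visible effect of this patch: elemwise_mul now accepts a
dense/csr operand pair on CPU and GPU and returns a csr result whose sparsity
pattern matches the csr input. A minimal forward sketch (hedged: written
against the mx.nd.sparse Python API of this release line, not taken from the
commit itself):

    import mxnet as mx
    import numpy as np

    shape = (4, 5)
    dns = mx.nd.array(np.random.uniform(size=shape))          # dense operand
    csr = mx.nd.array(np.random.binomial(1, 0.3, size=shape)).tostype('csr')
    out = mx.nd.sparse.elemwise_mul(dns, csr)   # dispatches to the new dns*csr path
    assert out.stype == 'csr'
    assert np.allclose(out.asnumpy(), dns.asnumpy() * csr.asnumpy())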
diff --git a/src/operator/tensor/elemwise_binary_op-inl.h b/src/operator/tensor/elemwise_binary_op-inl.h
index c74f1f9..911c369 100644
--- a/src/operator/tensor/elemwise_binary_op-inl.h
+++ b/src/operator/tensor/elemwise_binary_op-inl.h
@@ -496,6 +496,91 @@ void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<xpu> *s,
}
/*!
+ * \brief Kernel for performing elemwise op between dense and csr matrix
+ * \param i global thread id
+ * \param req type of request
+ * \param out output array
+ * \param dns_data data array of dense input
+ * \param csr_data data array of csr input
+ * \param csr_indices indices array of csr input
+ * \param csr_indptr indptr array of csr input
+ * \param num_rows number of rows of both inputs
+ * \param num_cols number of columns of both inputs
+ */
+template<int req, typename OP, bool reverse>
+struct ElemwiseDnsCsrCsrKernel {
+  template<typename DType, typename IType, typename CType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, DType* dns_data,
+                                  const DType* csr_data, const IType* csr_indices,
+                                  const CType* csr_indptr, const nnvm::dim_t num_rows,
+                                  const nnvm::dim_t num_cols) {
+    if (i < num_rows) {
+      for (int j = csr_indptr[i]; j < csr_indptr[i+1]; ++j) {
+        KERNEL_ASSIGN(out[j], req, reverse ?
+                                   OP::Map(dns_data[i * num_cols + csr_indices[j]], csr_data[j]) :
+                                   OP::Map(csr_data[j], dns_data[i * num_cols + csr_indices[j]]));
+      }
+    }
+  }
+};
+
+/*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
+template<typename xpu, typename OP>
+void ElemwiseBinaryOp::DnsCsrCsrOp(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &dns,
+ const NDArray &csr,
+ const OpReqType req,
+ const NDArray &output,
+ const bool reverse) {
+ using namespace mshadow;
+ using namespace mxnet_op;
+ using namespace csr;
+ CHECK_EQ(dns.storage_type(), kDefaultStorage);
+ CHECK_EQ(csr.storage_type(), kCSRStorage);
+ CHECK_EQ(req, kWriteTo) << "elemwise(dns, csr) = csr only supports kWriteTo";
+ if (req == kNullOp) return;
+ const bool supported_op = std::is_same<OP, mshadow_op::mul>::value;
+ CHECK(supported_op == true) << "elemwise(dns, csr) = csr only supports mul";
+ const nnvm::dim_t num_csr_rows = csr.shape()[0];
+ const nnvm::dim_t num_csr_cols = csr.shape()[1];
+ const nnvm::dim_t nnz = csr.storage_shape()[0];
+ Stream<xpu> *s = ctx.get_stream<xpu>();
+
+ output.CheckAndAlloc({Shape1(num_csr_rows + 1), Shape1(nnz)});
+ if (csr.storage_initialized()) {
+ TBlob csr_data = csr.data();
+ TBlob csr_indices = csr.aux_data(kIdx);
+ TBlob csr_indptr = csr.aux_data(kIndPtr);
+ MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {
+ MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {
+ MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {
+ MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+          if (reverse) {
+            Kernel<ElemwiseDnsCsrCsrKernel<Req, OP, true>, xpu>::Launch(
+              s, num_csr_rows, output.data().dptr<DType>(), dns.data().dptr<DType>(),
+              csr_data.dptr<DType>(), csr_indices.dptr<IType>(), csr_indptr.dptr<CType>(),
+              num_csr_rows, num_csr_cols);
+          } else {
+            Kernel<ElemwiseDnsCsrCsrKernel<Req, OP, false>, xpu>::Launch(
+              s, num_csr_rows, output.data().dptr<DType>(), dns.data().dptr<DType>(),
+              csr_data.dptr<DType>(), csr_indices.dptr<IType>(), csr_indptr.dptr<CType>(),
+              num_csr_rows, num_csr_cols);
+          }
+ Copy(output.aux_data(kIdx).FlatTo1D<xpu, IType>(),
+ csr.aux_data(kIdx).FlatTo1D<xpu, IType>(), s);
+ Copy(output.aux_data(kIndPtr).FlatTo1D<xpu, CType>(),
+ csr.aux_data(kIndPtr).FlatTo1D<xpu, CType>(), s);
+ });
+ });
+ });
+ });
+ } else {
+ FillZerosCsrImpl(s, output);
+ }
+}
+
+/*!
* \brief Kernel for performing elemwise op between dense and rsp tensor
* \param i global thread id
* \param req type of request
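
To summarize the kernel added above: one thread is launched per row; for each
stored element j of row i it combines csr_data[j] with the dense value at
(i, csr_indices[j]), so the output reuses the csr operand's indices and indptr
untouched and only recomputes the data array. A NumPy reference sketch of the
same loop (illustrative only, not part of the commit):

    import numpy as np

    def elemwise_dns_csr_csr(dns, csr_data, csr_indices, csr_indptr, op=np.multiply):
        # Reference semantics of ElemwiseDnsCsrCsrKernel: the output shares the
        # csr input's indices/indptr; only the data array is recomputed.
        out_data = np.empty_like(csr_data)
        num_rows = len(csr_indptr) - 1
        for i in range(num_rows):                      # one kernel thread per row
            for j in range(csr_indptr[i], csr_indptr[i + 1]):
                out_data[j] = op(csr_data[j], dns[i, csr_indices[j]])
        return out_data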
diff --git a/src/operator/tensor/elemwise_binary_op.cc b/src/operator/tensor/elemwise_binary_op.cc
index e8ba2fa..9ccbacc 100644
--- a/src/operator/tensor/elemwise_binary_op.cc
+++ b/src/operator/tensor/elemwise_binary_op.cc
@@ -63,6 +63,11 @@ bool ElemwiseBinaryOp::BackwardUseInStorageType(const nnvm::NodeAttrs& attrs,
const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback :
                                       DispatchMode::kFComputeEx;
+ const int ograd_stype = in_attrs->at(0);
+ const int lhs_stype = in_attrs->at(1);
+ const int rhs_stype = in_attrs->at(2);
+ int& lhs_grad_stype = out_attrs->at(0);
+ int& rhs_grad_stype = out_attrs->at(1);
if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
dispatched = storage_type_assign(out_attrs, kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
@@ -74,6 +79,22 @@ bool ElemwiseBinaryOp::BackwardUseInStorageType(const nnvm::NodeAttrs& attrs,
dispatch_mode, dispatch_ex);
}
}
+ if (!dispatched && ograd_stype == kDefaultStorage &&
+ ((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+ (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
+ const bool reverse = (lhs_stype == kCSRStorage);
+ if (reverse &&
+ type_assign(&lhs_grad_stype, kDefaultStorage) &&
+ type_assign(&rhs_grad_stype, kCSRStorage)) {
+ DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+ dispatched = true;
+ } else if (!reverse &&
+ type_assign(&lhs_grad_stype, kCSRStorage) &&
+ type_assign(&rhs_grad_stype, kDefaultStorage)) {
+ DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+ dispatched = true;
+ }
+ }
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
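
Note how the inferred gradient storage types cross over in the branch added
above: for z = lhs * rhs with lhs csr and rhs dense, d(lhs) = ograd * rhs is
generally fully dense, while d(rhs) = ograd * lhs is zero wherever lhs is, so
it keeps the csr pattern. A small Python sketch of the decision table
(illustrative; not the C++ API):

    def backward_mul_grad_stypes(ograd_stype, lhs_stype, rhs_stype):
        # Mirrors the new branch in BackwardUseInStorageType: it only fires
        # when ograd is dense and exactly one of lhs/rhs is csr.
        if ograd_stype == 'default' and {lhs_stype, rhs_stype} == {'csr', 'default'}:
            if lhs_stype == 'csr':           # the "reverse" case
                return ('default', 'csr')    # (lhs_grad, rhs_grad)
            return ('csr', 'default')
        return None                          # handled by other branches or fallback

    assert backward_mul_grad_stypes('default', 'csr', 'default') == ('default', 'csr')
    assert backward_mul_grad_stypes('default', 'default', 'csr') == ('csr', 'default')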
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index a5b73da..ad4b3e7 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -165,12 +165,11 @@ class ElemwiseBinaryOp : public OpBase {
typename xpu,
typename LOP,
typename ROP,
- typename DType,
bool in0_ok_dense = false,
bool in1_ok_dense = false,
bool in2_ok_dense = false,
typename BackupCompute>
- static inline void BackwardUseInEx_(const nnvm::NodeAttrs &attrs,
+ static inline void RspRspOpBackward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
@@ -200,6 +199,33 @@ class ElemwiseBinaryOp : public OpBase {
}
}
+ template<typename xpu, typename LOP, typename ROP>
+ static inline void DnsCsrCsrOpBackward(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const std::vector<NDArray> &inputs,
+ const std::vector<OpReqType> &req,
+ const std::vector<NDArray> &outputs) {
+ const bool supported_ops = std::is_same<mshadow_op::right, LOP>::value &&
+ std::is_same<mshadow_op::left, ROP>::value;
+ CHECK(supported_ops)
+ << "Only backward for mul is supported (LOP should be right, ROP should
be left)";
+ const NDArray& out_grad = inputs[0];
+ const NDArray& lhs_in = inputs[1];
+ const NDArray& rhs_in = inputs[2];
+ const NDArray& lhs_grad = outputs[0];
+ const NDArray& rhs_grad = outputs[1];
+ const bool reverse = (outputs[0].storage_type() == kCSRStorage);
+    if (reverse) {
+      DnsCsrCsrOp<xpu, mshadow_op::mul>(attrs, ctx, out_grad, rhs_in, req[0], lhs_grad, false);
+      Compute<xpu, mshadow_op::mul>(attrs, ctx, {out_grad.data(), lhs_in.data()}, {req[1]},
+                                    {rhs_grad.data()});
+    } else {
+      DnsCsrCsrOp<xpu, mshadow_op::mul>(attrs, ctx, out_grad, lhs_in, req[1], rhs_grad, false);
+      Compute<xpu, mshadow_op::mul>(attrs, ctx, {out_grad.data(), rhs_in.data()}, {req[0]},
+                                    {lhs_grad.data()});
+    }
+ }
+
public:
/*! \brief Binary op handling for lhr/rhs: RspDns, RspRsp, DnsRsp, or RspRsp->Dns result */
template<typename OP>
@@ -232,44 +258,54 @@ class ElemwiseBinaryOp : public OpBase {
/*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
template<typename OP>
static void CsrCsrOp(mshadow::Stream<cpu> *s,
- const nnvm::NodeAttrs &attrs,
- const OpContext &ctx,
- const NDArray &lhs,
- const NDArray &rhs,
- OpReqType req,
- const NDArray &output);
+ const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &lhs,
+ const NDArray &rhs,
+ OpReqType req,
+ const NDArray &output);
/*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
template<typename OP>
static void CsrCsrOp(mshadow::Stream<gpu> *s,
- const nnvm::NodeAttrs &attrs,
- const OpContext &ctx,
- const NDArray &lhs,
- const NDArray &rhs,
- OpReqType req,
- const NDArray &output);
+ const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &lhs,
+ const NDArray &rhs,
+ OpReqType req,
+ const NDArray &output);
/*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
template<typename xpu, typename OP>
static void DnsCsrDnsOp(mshadow::Stream<xpu> *s,
- const nnvm::NodeAttrs &attrs,
- const OpContext &ctx,
- const NDArray &lhs,
- const NDArray &rhs,
- OpReqType req,
- const NDArray &output,
- const bool reverse);
+ const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &lhs,
+ const NDArray &rhs,
+ OpReqType req,
+ const NDArray &output,
+ const bool reverse);
+
+ /*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
+ template<typename xpu, typename OP>
+ static void DnsCsrCsrOp(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &lhs,
+ const NDArray &rhs,
+ OpReqType req,
+ const NDArray &output,
+ const bool reverse);
/*! \brief DNS -op- RSP binary operator for non-canonical NDArray */
template<typename xpu, typename OP>
static void DnsRspDnsOp(mshadow::Stream<xpu> *s,
- const nnvm::NodeAttrs &attrs,
- const OpContext &ctx,
- const NDArray &lhs,
- const NDArray &rhs,
- OpReqType req,
- const NDArray &output,
- const bool reverse);
+ const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &lhs,
+ const NDArray &rhs,
+ OpReqType req,
+ const NDArray &output,
+ const bool reverse);
public:
/*!
@@ -336,6 +372,14 @@ class ElemwiseBinaryOp : public OpBase {
dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
dispatch_mode, dispatch_ex);
}
+ if (!dispatched &&
+ ((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+ (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
+ // csr, dns -> csr
+ // dns, csr -> csr
+      dispatched = storage_type_assign(&out_stype, kCSRStorage,
+                                       dispatch_mode, DispatchMode::kFComputeEx);
+ }
if (!dispatched) {
dispatched = dispatch_fallback(out_attrs, dispatch_mode);
}
@@ -540,6 +584,14 @@ class ElemwiseBinaryOp : public OpBase {
req[0], outputs[0], lhs_may_be_dense, rhs_may_be_dense, false, false);
} else if (lhs_stype == kCSRStorage && rhs_stype == kCSRStorage) {
ComputeEx<xpu, OP>(attrs, ctx, inputs, req, outputs);
+ } else if (((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+ (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage)) &&
+ out_stype == kCSRStorage) {
+      const NDArray& dns = (lhs_stype == kDefaultStorage)? inputs[0] : inputs[1];
+ const NDArray& csr = (lhs_stype == kCSRStorage)? inputs[0] : inputs[1];
+ const bool reverse = (lhs_stype == kCSRStorage);
+
+ DnsCsrCsrOp<xpu, OP>(attrs, ctx, dns, csr, req[0], outputs[0], reverse);
} else {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
@@ -635,16 +687,21 @@ class ElemwiseBinaryOp : public OpBase {
using namespace common;
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 2U); // lhs input grad, rhs input grad
+ const auto out_grad_stype = inputs[0].storage_type();
const auto lhs_grad_stype = outputs[0].storage_type();
const auto rhs_grad_stype = outputs[1].storage_type();
if (ContainsOnlyStorage(inputs, kRowSparseStorage) &&
(lhs_grad_stype == kDefaultStorage || lhs_grad_stype == kRowSparseStorage) &&
(rhs_grad_stype == kDefaultStorage || rhs_grad_stype == kRowSparseStorage)) {
// rsp, rsp, rsp -> [dns, rsp], [dns, rsp]
-    MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
-      BackwardUseInEx_<xpu, LOP, ROP, DType, in0_ok_dense, in1_ok_dense, in2_ok_dense>(
-        attrs, ctx, inputs, req, outputs, BackwardUseIn<xpu, LOP, ROP>);
-    });
+    RspRspOpBackward<xpu, LOP, ROP, in0_ok_dense, in1_ok_dense, in2_ok_dense>(
+      attrs, ctx, inputs, req, outputs, BackwardUseIn<xpu, LOP, ROP>);
+ }
+    if (((lhs_grad_stype == kDefaultStorage && rhs_grad_stype == kCSRStorage) ||
+         (lhs_grad_stype == kCSRStorage && rhs_grad_stype == kDefaultStorage)) &&
+        out_grad_stype == kDefaultStorage) {
+      // dns, csr, dns -> [csr, dns] / csr, dns, dns -> [dns, csr]
+      DnsCsrCsrOpBackward<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
}
}
}; // class ElemwiseBinaryOp
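
In DnsCsrCsrOpBackward above, the gradient that ends up csr is produced by the
forward DnsCsrCsrOp kernel (ograd times the csr input), while the dense
gradient goes through the ordinary dense Compute path (ograd times the dense
input). The masking identity this relies on, checked in NumPy (illustrative
only):

    import numpy as np

    ograd = np.random.rand(4, 5)          # dense head gradient
    lhs = np.random.rand(4, 5)            # dense input
    mask = np.random.rand(4, 5) < 0.3
    rhs = np.where(mask, np.random.rand(4, 5), 0.0)   # csr-like input

    lhs_grad = ograd * rhs                # zero wherever rhs is zero -> storable as csr
    rhs_grad = ograd * lhs                # generally fully dense -> stored as dns
    assert np.all(lhs_grad[~mask] == 0.0)  # lhs_grad inherits rhs's sparsity pattern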
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu b/src/operator/tensor/elemwise_binary_op_basic.cu
index 9c1fd0e..5cdd894 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cu
+++ b/src/operator/tensor/elemwise_binary_op_basic.cu
@@ -51,7 +51,9 @@ NNVM_REGISTER_OP(_backward_sub)
mshadow_op::negation>);
NNVM_REGISTER_OP(elemwise_mul)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::mul>);
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::mul>)
+.set_attr<FComputeEx>("FComputeEx<gpu>",
+ ElemwiseBinaryOp::ComputeDnsLRValueEx<gpu, op::mshadow_op::mul, true, true>);
NNVM_REGISTER_OP(_backward_mul)
.set_attr<FCompute>("FCompute<gpu>",
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 226db70..b2ff0fe 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -329,9 +329,19 @@ def test_elemwise_binary_ops():
return 'row_sparse'
elif lstype == 'row_sparse' and rstype == 'default':
return 'row_sparse'
+ elif lstype == 'default' and rstype == 'csr':
+ return 'csr'
+ elif lstype == 'csr' and rstype == 'default':
+ return 'csr'
else:
return 'default'
+ def elemwise_mul_lhs_grad_stype(lstype, rstype):
+ return elemwise_mul_stype(elemwise_mul_stype(lstype, rstype), rstype)
+
+ def elemwise_mul_rhs_grad_stype(lstype, rstype):
+ return elemwise_mul_stype(elemwise_mul_stype(lstype, rstype), lstype)
+
def check_elemwise_binary_ops(lhs_stype, rhs_stype, shape,
lhs_grad_stype=None, rhs_grad_stype=None,
lhs_density=.5, rhs_density=.5,
@@ -378,8 +388,8 @@ def test_elemwise_binary_ops():
lambda l, r: mx.sym.sparse.elemwise_mul(l, r),
lambda l, r: l * r,
lambda outg, l, r: (outg * r, outg * l),
- elemwise_mul_stype(lhs_stype, rhs_stype),
- elemwise_mul_stype(lhs_stype, rhs_stype),
+                                elemwise_mul_lhs_grad_stype(lhs_stype, rhs_stype),
+                                elemwise_mul_rhs_grad_stype(lhs_stype, rhs_stype),
expected_result_storage_type=elemwise_mul_stype(lhs_stype, rhs_stype),
ograd_density=ograd_density,
force_lr_overlap=force_lr_overlap,
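
The two helpers added to the test derive the expected gradient stypes by
composing the forward rule with itself: lhs_grad takes the stype of
ograd * rhs and rhs_grad that of ograd * lhs, where ograd is assumed to carry
the forward output's stype. A condensed walk-through (the same-stype cases
such as csr/csr -> csr live outside this hunk and are assumed here):

    def elemwise_mul_stype(lstype, rstype):
        # condensed sketch of the test helper, same-stype branch assumed
        if lstype == rstype:
            return lstype
        if {lstype, rstype} == {'default', 'row_sparse'}:
            return 'row_sparse'
        if {lstype, rstype} == {'default', 'csr'}:
            return 'csr'
        return 'default'

    def elemwise_mul_lhs_grad_stype(l, r):
        return elemwise_mul_stype(elemwise_mul_stype(l, r), r)   # stype of ograd * rhs

    def elemwise_mul_rhs_grad_stype(l, r):
        return elemwise_mul_stype(elemwise_mul_stype(l, r), l)   # stype of ograd * lhs

    assert elemwise_mul_lhs_grad_stype('default', 'csr') == 'csr'
    assert elemwise_mul_rhs_grad_stype('default', 'csr') == 'csr'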