eric-haibin-lin commented on a change in pull request #10150: [WIP] [DO NOT
MERGE] Sparse operator broadcast_mul/div(csr, dense) = csr
URL: https://github.com/apache/incubator-mxnet/pull/10150#discussion_r175539266
##########
File path: src/operator/tensor/elemwise_binary_broadcast_op.h
##########
@@ -185,6 +227,75 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
}
}
+template<typename xpu, typename OP>
+void BinaryBroadCastCsrDnsCsrImpl(const OpContext& ctx,
+ const NDArray& csr,
+ const NDArray& dns,
+ const OpReqType req,
+ const NDArray& output) {
+ using namespace mshadow;
+ using namespace mxnet_op;
+ using namespace csr;
+ CHECK_EQ(dns.shape().ndim(), 1) << "input dense should be a vector";
+ mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+ bool col_vec = (dns.shape()[0] == csr.shape()[0])? true : false;
+ if (!csr.storage_initialized()) {
+ FillZerosCsrImpl(s, output);
+ return;
+ }
+ const nnvm::dim_t nnz = csr.storage_shape()[0];
+ const nnvm::dim_t num_rows = output.shape()[0];
+ output.CheckAndAlloc({Shape1(num_rows + 1), Shape1(nnz)});
+
+ MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
+ MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), CType, {
+ MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIndPtr), RType, {
+ MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+ Kernel<csr_dns_csr_broadcast_kernel<DType, CType, RType, req_type,
OP>, xpu>::Launch(
+ s, num_rows, csr.data().dptr<DType>(),
csr.aux_data(kIdx).dptr<CType>(),
+ csr.aux_data(kIndPtr).dptr<RType>(), dns.data().dptr<DType>(),
+ output.data().dptr<DType>(), csr.shape()[1], col_vec);
+ Copy(output.aux_data(kIdx).FlatTo1D<xpu, CType>(),
+ csr.aux_data(kIdx).FlatTo1D<xpu, CType>());
+ Copy(output.aux_data(kIndPtr).FlatTo1D<xpu, RType>(),
+ csr.aux_data(kIndPtr).FlatTo1D<xpu, RType>());
+ });
+ });
+ });
+ });
+}
+
+template<typename xpu, typename OP>
+void BinaryBroadcastComputeCsrEx(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<NDArray>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<NDArray>& outputs) {
+ CHECK_EQ(inputs.size(), 2U);
+ CHECK_EQ(outputs.size(), 1U);
+ CHECK_EQ(req.size(), 1U);
+ CHECK_LE(inputs[1].shape().ndim(), 2U) << "input dense matrix should have
less than 2 dimensions";
+ const auto in1_stype = inputs[0].storage_type();
+ const auto in2_stype = inputs[1].storage_type();
+ const auto out_stype = outputs[0].storage_type();
+ if (!(inputs[1].shape().ndim() == 1U)) {
+ ElemwiseBinaryOp::ComputeEx<xpu, OP>(attrs, ctx, inputs, req, outputs);
+ } else {
+ if (req[0] != kNullOp) {
+ // broadcast(CSR, Dense(1D)) = CSR
+ if (in1_stype == kCSRStorage && in2_stype == kDefaultStorage &&
out_stype == kCSRStorage) {
+ BinaryBroadCastCsrDnsCsrImpl<xpu, OP>(ctx, inputs[0], inputs[1],
req[0], outputs[0]);
+ // broadcast(CSR, Dense(1D)) = Dense
+ //} else if (in1_stype == kCSRStorage && in2_stype == kDefaultStorage &&
+ // out_stype == kDefaultStorage) {
+ // BinaryBroadCastCsrDnsDnsImpl(ctx, inputs[0], input[1], req[0],
outputs[0]);
+ } else {
+ LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
Review comment:
I agree that throwing an error is not desirable and blocks users from what
they want to do. The problem is that FInferStorageType is not aware of shape
and dtype, and dispatches based only on dev_mask and storage types. And for
this sparse broadcast operator it's a lot of work to implement the 2-D and
3-D cases.
Maybe a temporary workaround is to fall back to dense inside the operator.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services