This is an automated email from the ASF dual-hosted git repository.
wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new f01dc80 Adding sparse support to MXTensor for custom operators (#17569)
f01dc80 is described below
commit f01dc80f030d2d1912c8e134c95f373e9f1f8e7b
Author: guanxinq <[email protected]>
AuthorDate: Sun Mar 22 03:50:55 2020 -0700
Adding sparse support to MXTensor for custom operators (#17569)
* Added enum for sparse storage
* Add structure for Dense and Sparse
* Redesign the data structure for MXSparse
* Pull out aux data from sparse NDArray
* Added more sparse arguments to API interface
* Passed sparse from c_api to lib_api.h and set in MXTensor
* Fix indent
* Fix segfault
* Fix NDArray to MXTensor errors
* Add a sample of sparse (CSR) transpose
* Make CSR transpose temporarily work by hardcoding
* Fixed sparse output size (refined)
* Add tests for symbolic and stateful ops
* Added a sample for row-sparse transpose
* Added real row-sparse transpose
* Fix output size issue by adding lambda for CheckAndAlloc()
* Fix mixed storage formats error
* Added infer storage type function
* Resolve comments
* Set inferSType as optional function
* Resolve comments
* Add error messages
* Resolve comments
* Verify transpose ops results
* Fix sanity check
* Update MX_LIBRARY_VERSION to 5
---
example/extensions/lib_custom_op/Makefile | 10 +-
.../extensions/lib_custom_op/test_transposecsr.py | 78 ++++++
.../lib_custom_op/test_transposerowsp.py | 73 ++++++
.../extensions/lib_custom_op/transposecsr_lib.cc | 197 ++++++++++++++
.../extensions/lib_custom_op/transposerowsp_lib.cc | 199 ++++++++++++++
example/extensions/lib_subgraph/subgraph_lib.cc | 4 +-
include/mxnet/lib_api.h | 286 ++++++++++++++++++---
src/c_api/c_api.cc | 119 ++++++++-
8 files changed, 919 insertions(+), 47 deletions(-)
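In short: after this change a custom operator sees a sparse input as an MXTensor whose data_ptr points to an MXSparse record (data, indices, indptr) rather than to raw values, e.g.

    MXSparse* A = inputs[0].data<MXSparse>();  // only valid when inputs[0].stype != kDefaultStorage

Sparse outputs are allocated through the new OpResource::alloc_sparse() hook, and an optional inferSType callback lets a library choose storage types; operators that do not register one keep the previous dense-only behavior.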
diff --git a/example/extensions/lib_custom_op/Makefile b/example/extensions/lib_custom_op/Makefile
index edd753b..feded29 100644
--- a/example/extensions/lib_custom_op/Makefile
+++ b/example/extensions/lib_custom_op/Makefile
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-all: gemm_lib relu_lib
+all: gemm_lib relu_lib transposecsr_lib transposerowsp_lib
gemm_lib:
	g++ -shared -fPIC -std=c++11 gemm_lib.cc -o libgemm_lib.so -I ../../../include/mxnet
@@ -23,5 +23,11 @@ gemm_lib:
relu_lib:
	nvcc -shared -std=c++11 -Xcompiler -fPIC relu_lib.cu -o librelu_lib.so -I ../../../include/mxnet
+transposecsr_lib:
+	g++ -shared -fPIC -std=c++11 transposecsr_lib.cc -o libtransposecsr_lib.so -I ../../../include/mxnet
+
+transposerowsp_lib:
+	g++ -shared -fPIC -std=c++11 transposerowsp_lib.cc -o libtransposerowsp_lib.so -I ../../../include/mxnet
+
clean:
- rm -rf libgemm_lib.so librelu_lib.so
+	rm -rf libgemm_lib.so librelu_lib.so libtransposecsr_lib.so libtransposerowsp_lib.so
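To try the new samples: running make transposecsr_lib transposerowsp_lib in example/extensions/lib_custom_op produces libtransposecsr_lib.so and libtransposerowsp_lib.so, which the two test scripts below load by absolute path.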
diff --git a/example/extensions/lib_custom_op/test_transposecsr.py b/example/extensions/lib_custom_op/test_transposecsr.py
new file mode 100644
index 0000000..37d066a
--- /dev/null
+++ b/example/extensions/lib_custom_op/test_transposecsr.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=arguments-differ
+
+# This test checks the dynamic loading of a custom library into MXNet
+# and end-to-end compute of a sparse (CSR) 2D transpose custom op
+
+import mxnet as mx
+import os
+
+#load library
+if (os.name=='posix'):
+ path = os.path.abspath('libtransposecsr_lib.so')
+ mx.library.load(path)
+elif (os.name=='nt'):
+ path = os.path.abspath('libtransposecsr_lib.dll')
+ mx.library.load(path)
+
+a = mx.nd.array([[1,3,0,2,1],[0,1,0,0,0],[0,2,4,5,3]])
+a = a.tostype('csr')
+print("--------Input CSR Array---------")
+print("data:", a.data.asnumpy())
+print("indices:", a.indices.asnumpy())
+print("indptr:", a.indptr.asnumpy())
+
+print("--------Start NDArray Compute---------")
+b = mx.nd.my_transposecsr(a)
+print("Compute Results:")
+print("data:", b.data.asnumpy())
+print("indices:", b.indices.asnumpy())
+print("indptr:", b.indptr.asnumpy())
+
+print("Stateful Compute Result:")
+c = mx.nd.my_state_transposecsr(a, test_kw=100)
+print("data:", c.data.asnumpy())
+print("indices:", c.indices.asnumpy())
+print("indptr:", c.indptr.asnumpy())
+
+print("--------start symbolic compute--------")
+d = mx.sym.Variable('d')
+e = mx.sym.my_transposecsr(d)
+f = mx.sym.my_state_transposecsr(d, test_kw=200)
+
+exe = e.bind(ctx=mx.cpu(),args={'d':a})
+exe2 = f.bind(ctx=mx.cpu(),args={'d':a})
+out = exe.forward()
+print("Compute Results:")
+print("data:", out[0].data.asnumpy())
+print("indices:", out[0].indices.asnumpy())
+print("indptr:", out[0].indptr.asnumpy())
+
+out2 = exe2.forward()
+out2 = exe2.forward()
+print("Stateful Compute Result:")
+print("data:", out2[0].data.asnumpy())
+print("indices:", out2[0].indices.asnumpy())
+print("indptr:", out2[0].indptr.asnumpy())
+
+print("--------Baseline(dense)--------")
+print(mx.nd.transpose(a.tostype('default')))
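For reference, the expected transposed output, worked out by hand from the 3x5 input above, is the 5x3 CSR tensor:

    data:    [1. 3. 1. 2. 4. 2. 5. 1. 3.]
    indices: [0 0 1 2 2 0 2 0 2]
    indptr:  [0 1 4 5 7 9]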
diff --git a/example/extensions/lib_custom_op/test_transposerowsp.py b/example/extensions/lib_custom_op/test_transposerowsp.py
new file mode 100644
index 0000000..cea62ec
--- /dev/null
+++ b/example/extensions/lib_custom_op/test_transposerowsp.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=arguments-differ
+
+# This test checks the dynamic loading of a custom library into MXNet
+# and end-to-end compute of a row-sparse 2D transpose custom op
+
+import mxnet as mx
+import os
+
+#load library
+if (os.name=='posix'):
+ path = os.path.abspath('libtransposerowsp_lib.so')
+ mx.library.load(path)
+elif (os.name=='nt'):
+ path = os.path.abspath('libtransposerowsp_lib.dll')
+ mx.library.load(path)
+
+a = mx.nd.array([[1,2,3],[0,0,0],[4,0,5],[0,0,0],[0,0,0]])
+a = a.tostype('row_sparse')
+print("--------Input CSR Array---------")
+print("data:", a.data.asnumpy())
+print("indices:", a.indices.asnumpy())
+
+print("--------Start NDArray Compute---------")
+b = mx.nd.my_transposerowsp(a)
+print("Compute Results:")
+print("data:", b.data.asnumpy())
+print("indices:", b.indices.asnumpy())
+
+print("Stateful Compute Result:")
+c = mx.nd.my_state_transposerowsp(a, test_kw=100)
+print("data:", c.data.asnumpy())
+print("indices:", c.indices.asnumpy())
+
+print("--------start symbolic compute--------")
+d = mx.sym.Variable('d')
+e = mx.sym.my_transposerowsp(d)
+f = mx.sym.my_state_transposerowsp(d, test_kw=200)
+
+exe = e.bind(ctx=mx.cpu(),args={'d':a})
+exe2 = f.bind(ctx=mx.cpu(),args={'d':a})
+out = exe.forward()
+print("Compute Results:")
+print("data:", out[0].data.asnumpy())
+print("indices:", out[0].indices.asnumpy())
+
+out2 = exe2.forward()
+out2 = exe2.forward()
+print("Stateful Compute Result:")
+print("data:", out2[0].data.asnumpy())
+print("indices:", out2[0].indices.asnumpy())
+
+print("--------Baseline(dense)--------")
+print(mx.nd.transpose(a.tostype('default')))
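For reference, the expected transposed output, worked out by hand from the 5x3 input above, is the 3x5 row-sparse tensor (every output row is non-empty here, so indices covers all rows):

    data:    [[1. 0. 4. 0. 0.]
              [2. 0. 0. 0. 0.]
              [3. 0. 5. 0. 0.]]
    indices: [0 1 2]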
diff --git a/example/extensions/lib_custom_op/transposecsr_lib.cc b/example/extensions/lib_custom_op/transposecsr_lib.cc
new file mode 100644
index 0000000..0daeb3e
--- /dev/null
+++ b/example/extensions/lib_custom_op/transposecsr_lib.cc
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file transposecsr_lib.cc
+ * \brief Sample 2D transpose custom operator.
+ */
+
+#include <iostream>
+#include "lib_api.h"
+
+void transpose(MXTensor src, MXTensor dst, OpResource res) {
+ MXSparse* A = src.data<MXSparse>();
+ MXSparse* B = dst.data<MXSparse>();
+ std::vector<int64_t> shape = src.shape;
+ int64_t h = shape[0];
+ int64_t w = shape[1];
+ if(src.stype == kCSRStorage) {
+ float *Aval = (float*) (A->data);
+    // We need one extra element here to help compute the scatter index below.
+ std::vector<int64_t> rowPtr(w + 2, 0);
+    // Count the non-zero elements in each column.
+ for(int i = 0; i < A->data_len; i++) {
+ rowPtr[A->indices[i] + 2]++;
+ }
+    // Prefix sum. After this loop, rowPtr[1:w+2) holds the correct
+    // indptr values of the transposed tensor.
+ for(int i = 2; i < rowPtr.size(); i++) {
+ rowPtr[i] += rowPtr[i - 1];
+ }
+
+ // Alloc memory for sparse data, where 0 is the index
+ // of B in output vector.
+ res.alloc_sparse(B, 0, A->data_len, w + 1);
+ float *Bval = (float*) (B->data);
+ for(int i = 0; i < h; i++) {
+ for(int j = A->indptr[i]; j < A->indptr[i + 1]; j++) {
+        // Scatter while bumping offsets; afterwards rowPtr[0:w+1) holds the
+        // correct indptr values of the transposed tensor.
+ int index = rowPtr[A->indices[j] + 1]++;
+ Bval[index] = Aval[j];
+ B->indices[index] = i;
+ }
+ }
+ memcpy(B->indptr, rowPtr.data(), sizeof(int64_t) * (w + 1));
+ }
+}
+
+MXReturnValue forward(std::map<std::string, std::string> attrs,
+ std::vector<MXTensor> inputs,
+ std::vector<MXTensor> outputs,
+ OpResource res) {
+  // The data types and storage types of inputs and outputs should be the same.
+  if(inputs[0].dtype != outputs[0].dtype || inputs[0].stype != outputs[0].stype) {
+ std::cout << "Error! Expected all inputs and outputs to be the same type."
+ << "Found input storage type:" << inputs[0].stype
+ << " Found output storage type:" << outputs[0].stype
+ << " Found input data type:" << inputs[0].dtype
+ << " Found output data type:" << outputs[0].dtype << std::endl;
+ return MX_FAIL;
+ }
+
+ transpose(inputs[0], outputs[0], res);
+ return MX_SUCCESS;
+}
+
+MXReturnValue backward(std::map<std::string, std::string> attrs,
+ std::vector<MXTensor> inputs,
+ std::vector<MXTensor> outputs,
+ OpResource res) {
+ return MX_SUCCESS;
+}
+
+MXReturnValue parseAttrs(std::map<std::string, std::string> attrs, int* num_in, int* num_out) {
+ *num_in = 1;
+ *num_out = 1;
+ return MX_SUCCESS;
+}
+
+MXReturnValue inferType(std::map<std::string, std::string> attrs,
+ std::vector<int> &intypes,
+ std::vector<int> &outtypes) {
+ // validate inputs
+ if (intypes.size() != 1) {
+ std::cout << "Expected 1 inputs to inferType" << std::endl;
+ return MX_FAIL;
+ }
+ if (intypes[0] != kFloat32) {
+ std::cout << "Expected input to have float32 type" << std::endl;
+ return MX_FAIL;
+ }
+
+ outtypes[0] = intypes[0];
+ return MX_SUCCESS;
+}
+
+MXReturnValue inferSType(std::map<std::string, std::string> attrs,
+ std::vector<int> &instypes,
+ std::vector<int> &outstypes) {
+ if (instypes[0] != kCSRStorage) {
+ std::cout << "Expected storage type is kCSRStorage" << std::endl;
+ return MX_FAIL;
+ }
+ outstypes[0] = instypes[0];
+ return MX_SUCCESS;
+}
+
+MXReturnValue inferShape(std::map<std::string, std::string> attrs,
+ std::vector<std::vector<unsigned int>> &inshapes,
+ std::vector<std::vector<unsigned int>> &outshapes) {
+ // validate inputs
+ if (inshapes.size() != 1) {
+ std::cout << "Expected 1 inputs to inferShape" << std::endl;
+ return MX_FAIL;
+ }
+
+ outshapes[0].push_back(inshapes[0][1]);
+ outshapes[0].push_back(inshapes[0][0]);
+ return MX_SUCCESS;
+}
+
+REGISTER_OP(my_transposecsr)
+.setForward(forward, "cpu")
+.setBackward(backward, "cpu")
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferSType(inferSType)
+.setInferShape(inferShape);
+
+/* ------------------------------------------------------------------------- */
+
+class MyStatefulTransposeCSR : public CustomStatefulOp {
+ public:
+ explicit MyStatefulTransposeCSR(int count) : count(count) {}
+
+ MXReturnValue Forward(std::vector<MXTensor> inputs,
+ std::vector<MXTensor> outputs,
+ OpResource op_res) {
+ std::cout << "Info: keyword + number of forward: " << ++count << std::endl;
+ std::map<std::string, std::string> attrs;
+ return forward(attrs, inputs, outputs, op_res);
+ }
+
+ MXReturnValue Backward(std::vector<MXTensor> inputs,
+ std::vector<MXTensor> outputs,
+ OpResource op_res) {
+ std::map<std::string, std::string> attrs;
+ return backward(attrs, inputs, outputs, op_res);
+ }
+
+ private:
+ int count;
+};
+
+MXReturnValue createOpState(std::map<std::string, std::string> attrs,
+ CustomStatefulOp** op_inst) {
+ // testing passing of keyword arguments
+ int count = attrs.count("test_kw") > 0 ? std::stoi(attrs["test_kw"]) : 0;
+ // creating stateful operator instance
+ *op_inst = new MyStatefulTransposeCSR(count);
+ std::cout << "Info: stateful operator created" << std::endl;
+ return MX_SUCCESS;
+}
+
+REGISTER_OP(my_state_transposecsr)
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferSType(inferSType)
+.setInferShape(inferShape)
+.setCreateOpState(createOpState, "cpu");
+
+MXReturnValue initialize(int version) {
+ if (version >= 10400) {
+ std::cout << "MXNet version " << version << " supported" << std::endl;
+ return MX_SUCCESS;
+ } else {
+ std::cout << "MXNet version " << version << " not supported" << std::endl;
+ return MX_FAIL;
+ }
+}
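The rowPtr bookkeeping in transpose() above (counts shifted by two, prefix-summed, then bumped by one while scattering) is a counting-sort CSR transpose. A minimal standalone sketch of the same idea, with illustrative names that are not part of the lib_api.h API:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Transpose an h x w CSR matrix given as (val, idx, ptr); mirrors the
    // logic of transpose() in transposecsr_lib.cc, independent of MXNet.
    void transpose_csr(int64_t h, int64_t w,
                       const std::vector<float>& val,
                       const std::vector<int64_t>& idx,
                       const std::vector<int64_t>& ptr,
                       std::vector<float>* tval,
                       std::vector<int64_t>* tidx,
                       std::vector<int64_t>* tptr) {
      const int64_t nnz = static_cast<int64_t>(val.size());
      tval->assign(nnz, 0.f);
      tidx->assign(nnz, 0);
      tptr->assign(w + 1, 0);
      // Count non-zeros per column; the +2 shift leaves room so that after
      // the prefix sum, rowPtr[c + 1] is the start offset of output row c
      // and can be bumped in place while scattering.
      std::vector<int64_t> rowPtr(w + 2, 0);
      for (int64_t i = 0; i < nnz; ++i) rowPtr[idx[i] + 2]++;
      for (size_t i = 2; i < rowPtr.size(); ++i) rowPtr[i] += rowPtr[i - 1];
      // Scatter row by row: each element of source row i lands in the output
      // row named by its column index, and records i as its new column.
      for (int64_t i = 0; i < h; ++i) {
        for (int64_t j = ptr[i]; j < ptr[i + 1]; ++j) {
          const int64_t pos = rowPtr[idx[j] + 1]++;
          (*tval)[pos] = val[j];
          (*tidx)[pos] = i;
        }
      }
      // After the scatter, rowPtr[0 : w + 1) is exactly the transposed indptr.
      std::copy(rowPtr.begin(), rowPtr.begin() + w + 1, tptr->begin());
    }

    int main() {
      // The 3x5 matrix from test_transposecsr.py, already in CSR form.
      std::vector<float> val{1, 3, 2, 1, 1, 2, 4, 5, 3};
      std::vector<int64_t> idx{0, 1, 3, 4, 1, 1, 2, 3, 4}, ptr{0, 4, 5, 9};
      std::vector<float> tval;
      std::vector<int64_t> tidx, tptr;
      transpose_csr(3, 5, val, idx, ptr, &tval, &tidx, &tptr);
      for (float v : tval) std::printf("%g ", v);  // 1 3 1 2 4 2 5 1 3
      std::printf("\n");
      return 0;
    }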
diff --git a/example/extensions/lib_custom_op/transposerowsp_lib.cc b/example/extensions/lib_custom_op/transposerowsp_lib.cc
new file mode 100644
index 0000000..883d816
--- /dev/null
+++ b/example/extensions/lib_custom_op/transposerowsp_lib.cc
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file transposerowsp_lib.cc
+ * \brief Sample 2D transpose custom operator.
+ */
+
+#include <iostream>
+#include "lib_api.h"
+
+void transpose(MXTensor src, MXTensor dst, OpResource res) {
+ MXSparse* A = src.data<MXSparse>();
+ MXSparse* B = dst.data<MXSparse>();
+
+ std::vector<int64_t> shape = src.shape;
+ int64_t h = shape[0];
+ int64_t w = shape[1];
+ if(src.stype == kRowSparseStorage) {
+    // Keys of the map are the row indices of the transposed tensor.
+    // Values of the map are the corresponding rows, stored densely.
+ std::map<int, std::vector<float>> mp;
+ float *Aval = (float*) (A->data);
+ for(int i = 0; i < A->data_len; i++) {
+ int row = i / w;
+ int col = i % w;
+ row = A->indices[row];
+      if(Aval[i] != 0) {
+        if(mp.find(col) == mp.end())
+          mp[col] = std::vector<float>(h, 0);
+        mp[col][row] = Aval[i];
+      }
+ }
+
+ // Alloc memory for output tensors.
+ res.alloc_sparse(B, 0, mp.size());
+ float *Bval = (float*) (B->data);
+ int didx = 0, iidx = 0;
+ for(auto i : mp) {
+ B->indices[iidx++] = i.first;
+ for(auto j : i.second) {
+ Bval[didx++] = j;
+ }
+ }
+ }
+}
+
+MXReturnValue forward(std::map<std::string, std::string> attrs,
+ std::vector<MXTensor> inputs,
+ std::vector<MXTensor> outputs,
+ OpResource res) {
+ // The data types and storage types of inputs and outputs should be the same.
+  if(inputs[0].dtype != outputs[0].dtype || inputs[0].stype != outputs[0].stype) {
+ std::cout << "Error! Expected all inputs and outputs to be the same type."
+ << "Found input storage type:" << inputs[0].stype
+ << " Found output storage type:" << outputs[0].stype
+ << " Found input data type:" << inputs[0].dtype
+ << " Found output data type:" << outputs[0].dtype << std::endl;
+ return MX_FAIL;
+ }
+ transpose(inputs[0], outputs[0], res);
+ return MX_SUCCESS;
+}
+
+MXReturnValue backward(std::map<std::string, std::string> attrs,
+ std::vector<MXTensor> inputs,
+ std::vector<MXTensor> outputs,
+ OpResource res) {
+ return MX_SUCCESS;
+}
+
+MXReturnValue parseAttrs(std::map<std::string, std::string> attrs, int* num_in, int* num_out) {
+ *num_in = 1;
+ *num_out = 1;
+ return MX_SUCCESS;
+}
+
+MXReturnValue inferType(std::map<std::string, std::string> attrs,
+ std::vector<int> &intypes,
+ std::vector<int> &outtypes) {
+ // validate inputs
+ if (intypes.size() != 1) {
+ std::cout << "Expected 1 inputs to inferType" << std::endl;
+ return MX_FAIL;
+ }
+ if (intypes[0] != kFloat32) {
+ std::cout << "Expected input to have float32 type" << std::endl;
+ return MX_FAIL;
+ }
+
+ outtypes[0] = intypes[0];
+ return MX_SUCCESS;
+}
+
+MXReturnValue inferSType(std::map<std::string, std::string> attrs,
+ std::vector<int> &instypes,
+ std::vector<int> &outstypes) {
+ if (instypes[0] != kRowSparseStorage) {
+ std::cout << "Expected storage type is kRowSparseStorage" << std::endl;
+ return MX_FAIL;
+ }
+ outstypes[0] = instypes[0];
+ return MX_SUCCESS;
+}
+
+MXReturnValue inferShape(std::map<std::string, std::string> attrs,
+ std::vector<std::vector<unsigned int>> &inshapes,
+ std::vector<std::vector<unsigned int>> &outshapes) {
+ // validate inputs
+ if (inshapes.size() != 1) {
+ std::cout << "Expected 1 inputs to inferShape" << std::endl;
+ return MX_FAIL;
+ }
+
+ outshapes[0].push_back(inshapes[0][1]);
+ outshapes[0].push_back(inshapes[0][0]);
+ return MX_SUCCESS;
+}
+
+REGISTER_OP(my_transposerowsp)
+.setForward(forward, "cpu")
+.setBackward(backward, "cpu")
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferSType(inferSType)
+.setInferShape(inferShape);
+
+/* ------------------------------------------------------------------------- */
+
+class MyStatefulTransposeRowSP : public CustomStatefulOp {
+ public:
+ explicit MyStatefulTransposeRowSP(int count) : count(count) {}
+
+ MXReturnValue Forward(std::vector<MXTensor> inputs,
+ std::vector<MXTensor> outputs,
+ OpResource op_res) {
+ std::cout << "Info: keyword + number of forward: " << ++count << std::endl;
+ std::map<std::string, std::string> attrs;
+ return forward(attrs, inputs, outputs, op_res);
+ }
+
+ MXReturnValue Backward(std::vector<MXTensor> inputs,
+ std::vector<MXTensor> outputs,
+ OpResource op_res) {
+ std::map<std::string, std::string> attrs;
+ return backward(attrs, inputs, outputs, op_res);
+ }
+
+ private:
+ int count;
+};
+
+MXReturnValue createOpState(std::map<std::string, std::string> attrs,
+ CustomStatefulOp** op_inst) {
+ // testing passing of keyword arguments
+ int count = attrs.count("test_kw") > 0 ? std::stoi(attrs["test_kw"]) : 0;
+ // creating stateful operator instance
+ *op_inst = new MyStatefulTransposeRowSP(count);
+ std::cout << "Info: stateful operator created" << std::endl;
+ return MX_SUCCESS;
+}
+
+REGISTER_OP(my_state_transposerowsp)
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferSType(inferSType)
+.setInferShape(inferShape)
+.setCreateOpState(createOpState, "cpu");
+
+MXReturnValue initialize(int version) {
+ if (version >= 10400) {
+ std::cout << "MXNet version " << version << " supported" << std::endl;
+ return MX_SUCCESS;
+ } else {
+ std::cout << "MXNet version " << version << " not supported" << std::endl;
+ return MX_FAIL;
+ }
+}
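A note on the row-sparse sample above: transpose() materializes each non-empty output row as a dense std::vector<float> of length h inside a std::map keyed by output row index, then copies the map into the buffers returned by alloc_sparse(). That keeps the sample readable at the cost of O(k*h) temporary memory for k non-empty output rows; a production kernel would more likely scatter directly into the output buffers.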
diff --git a/example/extensions/lib_subgraph/subgraph_lib.cc b/example/extensions/lib_subgraph/subgraph_lib.cc
index 8c24dd8..d821bdb 100644
--- a/example/extensions/lib_subgraph/subgraph_lib.cc
+++ b/example/extensions/lib_subgraph/subgraph_lib.cc
@@ -84,7 +84,7 @@ MXReturnValue myExecutor(std::vector<MXTensor> inputs,
// get input tensor based on node ID inputs from data storage
MXTensor &input = data[node_inputs.list[0].list[0].num];
// create temporary storage
-      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0});
+      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0}, kDefaultStorage);
// save allocated ptr to free later
to_free.push_back(tmp.data_ptr);
// execute log operator
@@ -95,7 +95,7 @@ MXReturnValue myExecutor(std::vector<MXTensor> inputs,
// get input tensor based on node ID inputs from data storage
MXTensor &input = data[node_inputs.list[0].list[0].num];
// create temporary storage
-      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0});
+      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, {"cpu", 0}, kDefaultStorage);
// save allocated ptr to free later
to_free.push_back(tmp.data_ptr);
// execute exp operator
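The subgraph example only needed a mechanical update: its dense temporaries now pass kDefaultStorage explicitly. Since the extra MXTensor constructor parameter defaults to kDefaultStorage in lib_api.h, older dense-only libraries keep compiling unchanged.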
diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h
index 9b32122..fd526ee 100644
--- a/include/mxnet/lib_api.h
+++ b/include/mxnet/lib_api.h
@@ -39,7 +39,7 @@
#include <utility>
#include <stdexcept>
-#define MX_LIBRARY_VERSION 4
+#define MX_LIBRARY_VERSION 5
/*!
 * \brief For loading multiple custom op libraries in Linux, exporting same symbol multiple
@@ -214,6 +214,18 @@ enum MXDType {
kUNSET = 100,
};
+/*!
+ * \brief MXTensor storage type.
+ */
+enum MXStorageType {
+ // dense
+ kDefaultStorage = 0,
+ // row sparse
+ kRowSparseStorage = 1,
+ // csr
+ kCSRStorage = 2,
+};
+
/*!
* \brief Context info passing from MXNet OpContext
* dev_type is string repr of supported context, currently only "cpu" and "gpu"
@@ -229,25 +241,64 @@ enum MXReturnValue {
MX_SUCCESS = 1,
};
+// For sparse tensors, read/write the data from the NDArray via these pointers.
+struct MXSparse {
+ // Pointer to data.
+ void *data{nullptr};
+  // Length of the (non-zero) data.
+ int64_t data_len;
+
+ // To store aux data for sparse.
+  // For CSR, indices stores the column index of each non-zero element.
+  // For row sparse, indices stores the row indices of rows that contain
+  // non-zero elements.
+ int64_t* indices;
+ int64_t indices_len;
+
+ // For CSR, indptr gives the start and end index of data for each row.
+ // For row sparse, indptr is not used.
+ int64_t* indptr = nullptr;
+ int64_t indptr_len;
+
+ void set(void *data_ptr, const int64_t* dims, int ndims, void *idx,
+ int64_t num_idx, void *idx_ptr = nullptr, int64_t num_idx_ptr = 0) {
+ data = data_ptr;
+    // For CSR, the number of non-zero elements is num_idx;
+    // for row sparse, it is num_idx * width.
+ data_len = num_idx;
+ if (!idx_ptr) {
+ for (int i = 1; i < ndims; ++i)
+ data_len *= dims[i];
+ }
+
+ indices = reinterpret_cast<int64_t*>(idx);
+ indices_len = num_idx;
+
+ if (idx_ptr) {
+ indptr = reinterpret_cast<int64_t*>(idx_ptr);
+ indptr_len = num_idx_ptr;
+ }
+ }
+};
+
/*!
* \brief Tensor data structure used by custom operator
*/
struct MXTensor {
- MXTensor() : data_ptr(nullptr), dtype(kUNSET), verID(0) {}
+  MXTensor() : data_ptr(nullptr), dtype(kUNSET), verID(0), stype(kDefaultStorage) {}
MXTensor(const MXTensor& oth) : data_ptr(oth.data_ptr), shape(oth.shape),
- dtype(oth.dtype), verID(oth.verID), ctx(oth.ctx) {
+ dtype(oth.dtype), verID(oth.verID), ctx(oth.ctx), stype(oth.stype) {
setDLTensor();
}
MXTensor(void *data_ptr, const std::vector<int64_t> &shape, MXDType dtype,
- size_t vID, MXContext mx_ctx)
- : data_ptr(data_ptr), shape(shape), dtype(dtype), verID(vID), ctx(mx_ctx) {
+ size_t vID, MXContext mx_ctx, MXStorageType stype = kDefaultStorage)
+    : data_ptr(data_ptr), shape(shape), dtype(dtype), verID(vID), ctx(mx_ctx), stype(stype) {
setDLTensor();
}
/*! \brief populate internal tensor fields */
void setTensor(void *dptr, MXDType type, const int64_t* dims, int ndims,
- size_t vID, MXContext mx_ctx) {
- data_ptr = dptr; dtype = type; verID = vID; ctx = mx_ctx;
+ size_t vID, MXContext mx_ctx, MXStorageType storage_type) {
+    data_ptr = dptr; dtype = type; verID = vID; ctx = mx_ctx; stype = storage_type;
shape.clear();
for (int j = 0; j < ndims; j++) {
shape.push_back(dims[j]);
@@ -340,11 +391,12 @@ struct MXTensor {
verID == oth.verID &&
ctx.dev_type == oth.ctx.dev_type &&
ctx.dev_id == oth.ctx.dev_id &&
- shape == oth.shape;
+ shape == oth.shape &&
+ stype == oth.stype;
}
- // data is flatten 1D repr of tensor, elements are in continuous memory
- // user can access each element using the shape of tensor
+ // For dense, data_ptr points to data.
+ // For sparse, data_ptr points to MXSparse.
void *data_ptr;
// shape is in [2,3,4] format to represent high-dim tensor
@@ -362,11 +414,16 @@ struct MXTensor {
// corresponding DLTensor repr of MXTensor
// easy way to reuse functions taking DLTensor
DLTensor dltensor;
+
+ // storage type
+ MXStorageType stype;
};
/*! \brief resource malloc function to allocate memory inside Forward/Backward functions */
typedef void* (*xpu_malloc_t)(void*, int);
+typedef void (*sparse_malloc_t)(void*, int, int, int, void**, int64_t**, int64_t**);
+
#if defined(__NVCC__)
typedef cudaStream_t mx_stream_t;
#else
@@ -379,9 +436,11 @@ typedef void* (*xpu_malloc_t)(void*, int);
class OpResource {
public:
OpResource(xpu_malloc_t cpu_malloc_fp, void* cpu_alloc_fp,
- xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream)
+ xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream,
+ sparse_malloc_t sparse_malloc_fp, void* sparse_alloc_fp)
: cpu_malloc(cpu_malloc_fp), gpu_malloc(gpu_malloc_fp),
- cpu_alloc(cpu_alloc_fp), gpu_alloc(gpu_alloc_fp), cuda_stream(stream) {}
+ cpu_alloc(cpu_alloc_fp), gpu_alloc(gpu_alloc_fp), cuda_stream(stream),
+ sparse_malloc(sparse_malloc_fp), sparse_alloc(sparse_alloc_fp) {}
/*! \brief allocate cpu memory controlled by MXNet */
void* alloc_cpu(int size) {
@@ -398,6 +457,12 @@ class OpResource {
return static_cast<mx_stream_t>(cuda_stream);
}
+ /*! \brief allocate sparse memory controlled by MXNet */
+  void alloc_sparse(MXSparse* sparse, int index, int indices_len, int indptr_len = 0) {
+ sparse_malloc(sparse_alloc, index, indices_len, indptr_len,
+ &(sparse->data), &(sparse->indices), &(sparse->indptr));
+ }
+
private:
/*! \brief allocation lambda function */
xpu_malloc_t cpu_malloc, gpu_malloc;
@@ -405,6 +470,10 @@ class OpResource {
void *cpu_alloc, *gpu_alloc;
/*! \brief cuda stream passed from MXNet */
void *cuda_stream;
+ /*! \brief sparse allocation lambda function */
+ sparse_malloc_t sparse_malloc;
+ /*! \brief lambda function to return allocated sparse memory handle */
+ void *sparse_alloc;
};
/*!
@@ -647,6 +716,8 @@ typedef MXReturnValue (*parseAttrs_t)(std::map<std::string, std::string>,
                                       int*, int*);
typedef MXReturnValue (*inferType_t)(std::map<std::string, std::string>,
std::vector<int>&, std::vector<int>&);
+typedef MXReturnValue (*inferSType_t)(std::map<std::string, std::string>,
+ std::vector<int>&, std::vector<int>&);
typedef MXReturnValue (*inferShape_t)(std::map<std::string, std::string>,
std::vector<std::vector<unsigned int> >&,
std::vector<std::vector<unsigned int>
>&);
@@ -660,9 +731,9 @@ typedef MXReturnValue (*createOpState_t)(std::map<std::string, std::string>,
*/
class CustomOp {
public:
- explicit CustomOp(const char* op_name) :
- name(op_name), parse_attrs(nullptr), infer_type(nullptr),
- infer_shape(nullptr), mutate_inputs(nullptr), isSGop(false) {}
+ explicit CustomOp(const char* op_name) : name(op_name),
+    parse_attrs(NULL), infer_type(NULL), infer_storage_type(NULL), infer_shape(NULL),
+ mutate_inputs(NULL), isSGop(false) {}
CustomOp& setForward(fcomp_t fcomp, const char* ctx) {
if (forward_ctx_map.count(ctx) > 0)
raiseDuplicateContextError();
@@ -683,6 +754,10 @@ class CustomOp {
infer_type = func;
return *this;
}
+ CustomOp& setInferSType(inferSType_t func) {
+ infer_storage_type = func;
+ return *this;
+ }
CustomOp& setInferShape(inferShape_t func) {
infer_shape = func;
return *this;
@@ -723,6 +798,7 @@ class CustomOp {
/*! \brief operator functions */
parseAttrs_t parse_attrs;
inferType_t infer_type;
+ inferSType_t infer_storage_type;
inferShape_t infer_shape;
mutateInputs_t mutate_inputs;
bool isSGop;
@@ -876,7 +952,7 @@ typedef int (*opRegGet_t)(int idx, const char** name, int *isSGop,
                           const char*** backward_ctx, fcomp_t** backward_fp, int* backward_count,
                           const char*** create_op_ctx, createOpState_t** create_op_fp, int* create_op_count,
-                          parseAttrs_t* parse, inferType_t* type,
+                          parseAttrs_t* parse, inferType_t* type, inferSType_t* stype,
                           inferShape_t* shape, mutateInputs_t* mutate);
#define MXLIB_OPCALLFREE_STR "_opCallFree"
@@ -898,6 +974,11 @@ typedef int (*opCallInferType_t)(inferType_t inferType, const char* const* keys,
                                  const char* const* vals, int num,
                                  int* intypes, int num_in, int* outtypes, int num_out);
+#define MXLIB_OPCALLINFERSTYPE_STR "_opCallInferSType"
+typedef int (*opCallInferSType_t)(inferSType_t inferSType, const char* const* keys,
+                                  const char* const* vals, int num,
+                                  int* intypes, int num_in, int* outtypes, int num_out);
+
#define MXLIB_OPCALLFCOMP_STR "_opCallFCompute"
typedef int (*opCallFComp_t)(fcomp_t fcomp, const char* const* keys,
const char* const* vals, int num,
@@ -910,7 +991,13 @@ typedef int (*opCallFComp_t)(fcomp_t fcomp, const char* const* keys,
                              size_t* outIDs, const char** outdev_type, int* outdev_id, int num_out,
                              xpu_malloc_t cpu_malloc, void* cpu_alloc,
-                             xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream);
+                             xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream,
+                             sparse_malloc_t sparse_malloc, void* sparse_alloc,
+                             int* instypes, int* outstypes,
+                             void** in_indices, void** out_indices,
+                             void** in_indptr, void** out_indptr,
+                             int64_t* in_indices_shapes, int64_t* out_indices_shapes,
+                             int64_t* in_indptr_shapes, int64_t* out_indptr_shapes);
#define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs"
typedef int (*opCallMutateInputs_t)(mutateInputs_t mutate, const char* const* keys,
@@ -933,7 +1020,13 @@ typedef int (*opCallFStatefulComp_t)(int is_forward, void* state_op,
                                      size_t* outIDs, const char** outdev_type, int* outdev_id, int num_out,
                                      xpu_malloc_t cpu_malloc, void* cpu_alloc,
-                                     xpu_malloc_t gpu_malloc, void* gpu_alloc, void* stream);
+                                     xpu_malloc_t gpu_malloc, void* gpu_alloc, void* stream,
+                                     sparse_malloc_t sparse_malloc, void* sparse_alloc,
+                                     int* instypes, int* outstypes,
+                                     void** in_indices, void** out_indices,
+                                     void** in_indptr, void** out_indptr,
+                                     int64_t* in_indices_shapes, int64_t* out_indices_shapes,
+                                     int64_t* in_indptr_shapes, int64_t* out_indptr_shapes);
#define MXLIB_PARTREGSIZE_STR "_partRegSize"
typedef int (*partRegSize_t)(void);
@@ -1004,12 +1097,13 @@ extern "C" {
            const char*** forward_ctx, fcomp_t** forward_fp, int* forward_count,
            const char*** backward_ctx, fcomp_t** backward_fp, int* backward_count,
            const char*** create_op_ctx, createOpState_t** create_op_fp, int* create_op_count,
- parseAttrs_t* parse, inferType_t* type,
+ parseAttrs_t* parse, inferType_t* type, inferSType_t* stype,
inferShape_t* shape, mutateInputs_t* mutate) {
CustomOp &op = Registry<CustomOp>::get()->get(idx);
*name = op.name;
*parse = op.parse_attrs;
*type = op.infer_type;
+ *stype = op.infer_storage_type;
*shape = op.infer_shape;
*mutate = op.mutate_inputs;
*isSGop = op.isSGop;
@@ -1136,6 +1230,43 @@ extern "C" {
return retval;
}
+  /*! \brief returns status of calling inferSType function for operator from library */
+#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
+ __declspec(dllexport) int __cdecl
+#else
+ int
+#endif
+ _opCallInferSType(inferSType_t inferSType, const char* const* keys,
+ const char* const* vals, int num,
+ int* instypes, int num_in, int* outstypes, int num_out) {
+ // create map of attributes from list
+ std::map<std::string, std::string> attrs;
+ for (int i = 0; i < num; i++) {
+ attrs[std::string(keys[i])] = std::string(vals[i]);
+ }
+
+    // create a vector of storage types for inputs
+ std::vector<int> in_stypes(num_in);
+ for (int i = 0; i < num_in; i++) {
+ in_stypes[i] = instypes[i];
+ }
+
+    // create a vector of storage types for outputs
+ std::vector<int> out_stypes(num_out, -1);
+
+ int retval = inferSType(attrs, in_stypes, out_stypes);
+
+ if (!retval)
+ return retval;
+
+ // copy output storage types
+ for (int i = 0; i < num_out; i++) {
+ outstypes[i] = out_stypes[i];
+ }
+
+ return retval;
+ }
+
  /*! \brief returns status of calling Forward/Backward function for operator from library */
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
__declspec(dllexport) int __cdecl
@@ -1148,7 +1279,12 @@ extern "C" {
                    const int64_t** outshapes, int* outdims, void** outdata, int* outtypes,
                    size_t* outIDs, const char** outdev_type, int* outdev_id, int num_out,
                    xpu_malloc_t cpu_malloc, void* cpu_alloc,
-                    xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream) {
+                    xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream,
+                    sparse_malloc_t sparse_malloc, void* sparse_alloc,
+                    int* instypes, int* outstypes, void** in_indices, void** out_indices,
+ void** in_indptr, void** out_indptr,
+ int64_t* in_indices_shapes, int64_t* out_indices_shapes,
+ int64_t* in_indptr_shapes, int64_t* out_indptr_shapes) {
// create map of attributes from list
std::map<std::string, std::string> attrs;
for (int i = 0; i < num; i++) {
@@ -1157,20 +1293,59 @@ extern "C" {
// create a vector of tensors for inputs
std::vector<MXTensor> inputs(num_in);
+ // create a vector for sparse inputs
+ std::vector<MXSparse> in_sparse(num_in);
+
for (int i = 0; i < num_in; i++) {
- inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i],
indims[i],
- inIDs[i], {indev_type[i], indev_id[i]});
+ // Dense representation.
+ if (instypes[i] == 0) {
+ inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i],
indims[i],
+ inIDs[i], {indev_type[i], indev_id[i]},
kDefaultStorage);
+ } else {
+ // Sparse representation.
+ MXStorageType type;
+ if (instypes[i] == 1) {
+ type = kRowSparseStorage;
+ in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
in_indices_shapes[i]);
+ } else {
+ type = kCSRStorage;
+ in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
+ in_indices_shapes[i], in_indptr[i],
in_indptr_shapes[i]);
+ }
+ inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]),
(MXDType)intypes[i],
+ inshapes[i], indims[i], inIDs[i], {indev_type[i],
indev_id[i]}, type);
+ }
}
// create a vector of tensors for outputs
std::vector<MXTensor> outputs(num_out);
+ std::vector<MXSparse> out_sparse(num_out);
+
for (int i = 0; i < num_out; i++) {
- outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i],
outdims[i],
- outIDs[i], {outdev_type[i], outdev_id[i]});
+ // Dense representation.
+ if (outstypes[i] == 0) {
+ outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i],
outdims[i],
+ outIDs[i], {outdev_type[i], outdev_id[i]},
kDefaultStorage);
+ } else {
+ // Sparse representation.
+ MXStorageType type;
+ if (outstypes[i] == 1) {
+ type = kRowSparseStorage;
+ out_sparse[i].set(outdata[i], outshapes[i], outdims[i],
+ out_indices[i], out_indices_shapes[i]);
+ } else {
+ type = kCSRStorage;
+ out_sparse[i].set(outdata[i], outshapes[i], outdims[i],
out_indices[i],
+ out_indices_shapes[i], out_indptr[i],
out_indptr_shapes[i]);
+ }
+ outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]),
(MXDType)outtypes[i],
+ outshapes[i], outdims[i], outIDs[i],
{outdev_type[i],
+ outdev_id[i]}, type);
+ }
}
- OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc, cuda_stream);
-
+ OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
+ cuda_stream, sparse_malloc, sparse_alloc);
return fcomp(attrs, inputs, outputs, res);
}
@@ -1239,22 +1414,69 @@ extern "C" {
                              const int64_t** outshapes, int* outdims, void** outdata, int* outtypes,
                              size_t* outIDs, const char** outdev_type, int* outdev_id, int num_out,
                              xpu_malloc_t cpu_malloc, void* cpu_alloc,
-                              xpu_malloc_t gpu_malloc, void* gpu_alloc, void* stream) {
+                              xpu_malloc_t gpu_malloc, void* gpu_alloc, void* stream,
+                              sparse_malloc_t sparse_malloc, void* sparse_alloc,
+                              int* instypes, int* outstypes, void** in_indices, void** out_indices,
+                              void** in_indptr, void** out_indptr,
+                              int64_t* in_indices_shapes, int64_t* out_indices_shapes,
+                              int64_t* in_indptr_shapes, int64_t* out_indptr_shapes) {
// create a vector of tensors for inputs
std::vector<MXTensor> inputs(num_in);
+ // create a vector for sparse inputs
+ std::vector<MXSparse> in_sparse(num_in);
+
for (int i = 0; i < num_in; i++) {
- inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i],
indims[i],
- inIDs[i], {indev_type[i], indev_id[i]});
+ if (instypes[i] == 0) {
+ // Dense representation.
+ inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i],
indims[i],
+ inIDs[i], {indev_type[i], indev_id[i]},
kDefaultStorage);
+ } else {
+ // Sparse representation.
+ MXStorageType type;
+ if (instypes[i] == 1) {
+ type = kRowSparseStorage;
+ in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
in_indices_shapes[i]);
+ } else {
+ type = kCSRStorage;
+ in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
+ in_indices_shapes[i], in_indptr[i],
in_indptr_shapes[i]);
+ }
+ inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]),
(MXDType)intypes[i],
+ inshapes[i], indims[i], inIDs[i], {indev_type[i],
+ indev_id[i]}, type);
+ }
}
// create a vector of tensors for outputs
std::vector<MXTensor> outputs(num_out);
+ // create a vector for sparse outputs
+ std::vector<MXSparse> out_sparse(num_out);
+
for (int i = 0; i < num_out; i++) {
- outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i],
outdims[i],
- outIDs[i], {outdev_type[i], outdev_id[i]});
+ if (outstypes[i] == 0) {
+ // Dense representation.
+ outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i],
outdims[i],
+ outIDs[i], {outdev_type[i], outdev_id[i]},
kDefaultStorage);
+ } else {
+ // Sparse representation.
+ MXStorageType type;
+ if (outstypes[i] == 1) {
+ type = kRowSparseStorage;
+ out_sparse[i].set(outdata[i], outshapes[i], outdims[i],
out_indices[i],
+ out_indices_shapes[i]);
+ } else {
+ type = kCSRStorage;
+ out_sparse[i].set(outdata[i], outshapes[i], outdims[i],
out_indices[i],
+ out_indices_shapes[i], out_indptr[i],
out_indptr_shapes[i]);
+ }
+ outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]),
(MXDType)outtypes[i],
+ outshapes[i], outdims[i], outIDs[i],
{outdev_type[i],
+ outdev_id[i]}, type);
+ }
}
- OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc, stream);
+ OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
+ stream, sparse_malloc, sparse_alloc);
CustomStatefulOp* op_ptr = reinterpret_cast<CustomStatefulOp*>(state_op);
if (is_forward) {
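Taken together, the lib_api.h changes reduce to a small surface for operator authors. A minimal sketch of filling a CSR output inside forward(), following the two sample libraries above (nnz and num_rows are illustrative names, not part of the API):

    MXSparse* B = outputs[0].data<MXSparse>();
    // 0 is the position of this tensor in the outputs vector; MXNet
    // allocates the data, indices and (for CSR) indptr buffers.
    res.alloc_sparse(B, 0, nnz, num_rows + 1);    // CSR: indptr needs rows + 1 slots
    // res.alloc_sparse(B, 0, num_nonzero_rows);  // row-sparse variant
    float* Bval = static_cast<float*>(B->data);
    // ... fill Bval, B->indices and, for CSR, B->indptr ...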
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index db0e262..fe00a9a 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -114,7 +114,7 @@ void CustomFComputeDispatcher(const std::string op_name,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
std::vector<void*> in_data, out_data;
- std::vector<const int64_t *> in_shapes, out_shapes;
+ std::vector<const int64_t*> in_shapes, out_shapes;
std::vector<int> in_dims, out_dims;
std::vector<int> in_types, out_types;
std::vector<size_t> in_verIDs, out_verIDs;
@@ -122,6 +122,13 @@ void CustomFComputeDispatcher(const std::string op_name,
std::vector<int> in_dev_id, out_dev_id;
std::vector<NDArray> conv_mkl; // converted NDArrays from MKLDNN format
+ // Extra data for sparse inputs and outputs.
+ std::vector<int> in_stypes(inputs.size(), 0), out_stypes(outputs.size(), 0);
+ std::vector<void*> in_indices(inputs.size(), nullptr),
out_indices(outputs.size(), nullptr);
+ std::vector<void*> in_indptr(inputs.size(), nullptr),
out_indptr(outputs.size(), nullptr);
+ std::vector<int64_t> in_indices_shapes(inputs.size(), 0),
out_indices_shapes(outputs.size(), 0);
+ std::vector<int64_t> in_indptr_shapes(inputs.size(), 0),
out_indptr_shapes(outputs.size(), 0);
+
  // convert inputs/outputs NDArray to C types to be passed to lib_api.h
for (size_t i = 0; i < inputs.size(); i++) {
NDArray const* in_nd = &(inputs[i]);
@@ -141,7 +148,19 @@ void CustomFComputeDispatcher(const std::string op_name,
in_verIDs.push_back(in_nd->version());
    const char* ctx_str = in_nd->ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
in_dev_type.push_back(ctx_str);
+
in_dev_id.push_back(in_nd->ctx().real_dev_id());
+ if (inputs[i].storage_type() == mxnet::kRowSparseStorage) {
+ in_stypes[i] = 1;
+ in_indices[i] = inputs[i].aux_data(rowsparse::kIdx).dptr_;
+ in_indices_shapes[i] = inputs[i].aux_shape(rowsparse::kIdx).Size();
+ } else if (inputs[i].storage_type() == mxnet::kCSRStorage) {
+ in_stypes[i] = 2;
+ in_indices[i] = inputs[i].aux_data(csr::kIdx).dptr_;
+ in_indptr[i] = inputs[i].aux_data(csr::kIndPtr).dptr_;
+ in_indices_shapes[i] = inputs[i].aux_shape(csr::kIdx).Size();
+ in_indptr_shapes[i] = inputs[i].aux_shape(csr::kIndPtr).Size();
+ }
}
for (size_t i = 0; i < outputs.size(); i++) {
@@ -153,6 +172,18 @@ void CustomFComputeDispatcher(const std::string op_name,
    const char* ctx_str = outputs[i].ctx().dev_mask() == Context::kCPU ? "cpu" : "gpu";
out_dev_type.push_back(ctx_str);
out_dev_id.push_back(outputs[i].ctx().real_dev_id());
+
+ if (outputs[i].storage_type() == mxnet::kRowSparseStorage) {
+ out_stypes[i] = 1;
+ out_indices[i] = outputs[i].aux_data(rowsparse::kIdx).dptr_;
+ out_indices_shapes[i] = outputs[i].aux_shape(rowsparse::kIdx).Size();
+ } else if (outputs[i].storage_type() == mxnet::kCSRStorage) {
+ out_stypes[i] = 2;
+ out_indices[i] = outputs[i].aux_data(csr::kIdx).dptr_;
+ out_indptr[i] = outputs[i].aux_data(csr::kIndPtr).dptr_;
+ out_indices_shapes[i] = outputs[i].aux_shape(csr::kIdx).Size();
+ out_indptr_shapes[i] = outputs[i].aux_shape(csr::kIndPtr).Size();
+ }
}
// get memory resource and mxnet backend streams
@@ -173,6 +204,24 @@ void CustomFComputeDispatcher(const std::string op_name,
return workspace.dptr_;
};
+ // create lambda that allocates memory for sparse and
+ // returns allocated arrays for data, indices and indptr.
+ auto sparse_alloc = [&](int index, int indices_len, int idxptr_len,
+ void** data, int64_t** indices, int64_t** indptr) {
+ if (idxptr_len == 0) {
+ // Row Sparse
+ outputs[index].CheckAndAlloc({mshadow::Shape1(indices_len)});
+ *data = outputs[index].data().dptr_;
+      *indices = reinterpret_cast<int64_t*>(outputs[index].aux_data(rowsparse::kIdx).dptr_);
+    } else {
+      // CSR
+      outputs[index].CheckAndAlloc({mshadow::Shape1(idxptr_len), mshadow::Shape1(indices_len)});
+      *data = outputs[index].data().dptr_;
+      *indices = reinterpret_cast<int64_t*>(outputs[index].aux_data(csr::kIdx).dptr_);
+      *indptr = reinterpret_cast<int64_t*>(outputs[index].aux_data(csr::kIndPtr).dptr_);
+ }
+ };
+
// create lambda without captures so that we can cast it to function pointer
  // lambda with captures cannot be cast to a function pointer and passed to lib_api.h
// this needs to be a lambda function so that we can do the decltype cast
@@ -189,6 +238,13 @@ void CustomFComputeDispatcher(const std::string op_name,
return static_cast<void*>((*gpualloc)(size));
};
+ typedef decltype(sparse_alloc) alloc_type_sparse;
+  auto sparse_malloc = [](void* _sparse_alloc, int index, int indices_len, int idxptr_len,
+                          void** data, int64_t** indices, int64_t** indptr) {
+    alloc_type_sparse* sparsealloc = static_cast<alloc_type_sparse*>(_sparse_alloc);
+ (*sparsealloc)(index, indices_len, idxptr_len, data, indices, indptr);
+ };
+
// get actual cudaStream_t out of mxnet gpu stream and pass to lib_api.h
void *cuda_stream = nullptr;
#if MXNET_USE_CUDA
@@ -208,13 +264,18 @@ void CustomFComputeDispatcher(const std::string op_name,
attr_keys.push_back(kv.first.c_str());
attr_vals.push_back(kv.second.c_str());
}
+
// call fcompute function
  CHECK(callFComp(fcomp_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
                  in_shapes.data(), in_dims.data(), in_data.data(), in_types.data(),
                  in_verIDs.data(), in_dev_type.data(), in_dev_id.data(), in_data.size(),
                  out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(),
                  out_verIDs.data(), out_dev_type.data(), out_dev_id.data(), out_data.size(),
-                  cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream))
+                  cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream,
+                  sparse_malloc, &sparse_alloc, in_stypes.data(), out_stypes.data(),
+                  in_indices.data(), out_indices.data(), in_indptr.data(), out_indptr.data(),
+                  in_indices_shapes.data(), out_indices_shapes.data(),
+                  in_indptr_shapes.data(), out_indptr_shapes.data()))
<< "Error calling FCompute for custom operator '" << op_name << "'";
}
@@ -233,7 +294,12 @@ void CustomFComputeDispatcher(const std::string op_name,
                              out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(),
                              out_verIDs.data(), out_dev_type.data(), out_dev_id.data(), out_data.size(),
-                              cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream))
+                              cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream,
+                              sparse_malloc, &sparse_alloc, in_stypes.data(), out_stypes.data(),
+                              in_indices.data(), out_indices.data(),
+                              in_indptr.data(), out_indptr.data(),
+                              in_indices_shapes.data(), out_indices_shapes.data(),
+                              in_indptr_shapes.data(), out_indptr_shapes.data()))
      << "Error calling FStatefulCompute for custom operator '" << op_name << "'";
}
}
@@ -272,6 +338,9 @@ int MXLoadLib(const char *path) {
  opCallInferType_t callInferType =
      get_func<opCallInferType_t>(lib, const_cast<char*>(MXLIB_OPCALLINFERTYPE_STR));
+  opCallInferSType_t callInferSType =
+      get_func<opCallInferSType_t>(lib, const_cast<char*>(MXLIB_OPCALLINFERSTYPE_STR));
+
opCallFComp_t callFComp =
get_func<opCallFComp_t>(lib, const_cast<char*>(MXLIB_OPCALLFCOMP_STR));
@@ -306,6 +375,7 @@ int MXLoadLib(const char *path) {
// function pointers holding implementation from custom library
parseAttrs_t parse_fp = nullptr;
inferType_t type_fp = nullptr;
+ inferSType_t stype_fp = nullptr;
inferShape_t shape_fp = nullptr;
// optional attributes
mutateInputs_t mutate_fp = nullptr;
@@ -322,7 +392,7 @@ int MXLoadLib(const char *path) {
&forward_ctx, &forward_fcomp, &forward_count,
&backward_ctx, &backward_fcomp, &backward_count,
&createop_ctx, &createop_fp, &createop_count,
- &parse_fp, &type_fp, &shape_fp, &mutate_fp);
+ &parse_fp, &type_fp, &stype_fp, &shape_fp, &mutate_fp);
// construct maps of context to forward/backward custom library function
std::unordered_map<std::string, fcomp_t> forward_ctx_map;
@@ -583,12 +653,39 @@ int MXLoadLib(const char *path) {
DispatchMode* dispatch_mode,
std::vector<int>* in_stypes,
std::vector<int>* out_stypes) {
-    // TODO(ziyimu): remove this dense enforce check after supporting sparse tensor
-    CHECK(mxnet::common::ContainsOnlyStorage(*in_stypes, mxnet::kDefaultStorage))
-    << "Error input tensors are not dense for custom operator '" << name_str << "'";
-    // set outputs as dense
-    return op::storage_type_assign(out_stypes, mxnet::kDefaultStorage,
-                                   dispatch_mode, DispatchMode::kFComputeEx);
+    if (stype_fp == nullptr) {
+      // InferSType is not defined in the custom library.
+      CHECK(mxnet::common::ContainsOnlyStorage(*in_stypes, mxnet::kDefaultStorage))
+      << "Error input tensors are not dense for custom operator '" << name_str << "'";
+      // set outputs as dense
+      return op::storage_type_assign(out_stypes, mxnet::kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+    } else {
+      // InferSType is defined in the custom library.
+      // convert attributes to vector of char*
+      std::vector<const char*> attr_keys, attr_vals;
+      for (auto kv : attrs.dict) {
+        attr_keys.push_back(kv.first.c_str());
+        attr_vals.push_back(kv.second.c_str());
+      }
+      // copy input storage types from in_stypes
+      std::vector<int> instypes(*in_stypes);
+
+      // output storage types will be populated by the inferSType function
+      std::vector<int> outstypes(out_stypes->size());
+      CHECK(callInferSType(stype_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(),
+                           instypes.data(), in_stypes->size(),
+                           outstypes.data(), out_stypes->size()))
+      << "Error calling InferSType for custom operator '" << name_str << "'";
+
+      // copy and assign output storage types from custom op to MXNet memory.
+      for (size_t i = 0; i < out_stypes->size(); i++) {
+        STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, outstypes[i]);
+      }
+      // assign dispatch mode
+      DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+      return true;
+    }
};
// FGradient register lambda
@@ -698,8 +795,8 @@ int MXLoadLib(const char *path) {
regOp.set_num_inputs(num_inputs);
regOp.set_num_outputs(num_outputs);
regOp.set_attr<nnvm::FInferType>("FInferType", infer_type, plevel);
- regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_shape, plevel);
regOp.set_attr<FInferStorageType>("FInferStorageType",
infer_storage_type, plevel);
+ regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_shape, plevel);
regOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);
// optionally add fmutate inputs if user specified a function
if (mutate_fp != nullptr)
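Net effect of the c_api.cc changes for library authors: if a custom operator registers no inferSType, MXNet keeps the old behavior (inputs must be dense, outputs are assigned dense storage); if it does, whatever the callback writes into its outstypes vector is propagated. A minimal pass-through implementation, as used by the samples above, is just:

    MXReturnValue inferSType(std::map<std::string, std::string> attrs,
                             std::vector<int>& instypes,
                             std::vector<int>& outstypes) {
      outstypes[0] = instypes[0];  // output storage type follows the input
      return MX_SUCCESS;
    }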