piiswrong closed pull request #9870: New example of custom operator using RTC
URL: https://github.com/apache/incubator-mxnet/pull/9870
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/example/numpy-ops/README.md b/example/numpy-ops/README.md
index 1ec8a404c3..aa4911f674 100644
--- a/example/numpy-ops/README.md
+++ b/example/numpy-ops/README.md
@@ -1,7 +1,4 @@
-# Training MNIST With NumpyOp
+# Training with Custom Operators in Python
 
-Uses the same setup as example/mnist/mlp.py. Except the loss symbol is
-custom defined with NumpyOp. mxnet.operator.NumpyOp help move computation
-in a symbol's forward/backward operation to python frontend. This is for
-fast implementation/experimentation of non-performance-critical symbols.
-If it is becoming a bottleneck, please consider write a C++/CUDA version.
\ No newline at end of file
+These examples demonstrate custom operator implementations in Python.
+You can implement the computation entirely in Python, or write custom CUDA kernels in C/C++ inside your Python source code with the help of Run-Time Compilation (RTC).
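
For readers unfamiliar with the RTC API the README refers to, the workflow amounts to compiling CUDA source at run time with mx.rtc.CudaModule, fetching a typed kernel with get_kernel, and launching it on GPU NDArrays. Below is a minimal stand-alone sketch (not part of this PR; the axpy kernel and the sizes are purely illustrative):

    import mxnet as mx

    # Compile a tiny CUDA kernel at run time. `extern "C"` keeps the symbol
    # name unmangled so it can be looked up without an exports list.
    source = r"""
    extern "C" __global__ void axpy(const float *x, float *y, float alpha) {
        int i = threadIdx.x + blockIdx.x * blockDim.x;
        y[i] += alpha * x[i];
    }
    """
    module = mx.rtc.CudaModule(source)
    kernel = module.get_kernel("axpy", "const float *x, float *y, float alpha")

    x = mx.nd.ones((1024,), ctx=mx.gpu(0))
    y = mx.nd.zeros((1024,), ctx=mx.gpu(0))
    # launch(args, ctx, grid_dims, block_dims, shared_mem=0):
    # 4 blocks of 256 threads cover all 1024 elements.
    kernel.launch([x, y, 3.0], mx.gpu(0), (4, 1, 1), (256, 1, 1))
    print(y.asnumpy()[:5])  # every element should be 3.0

The custom_softmax_rtc.py example below follows the same compile / get_kernel / launch pattern, but wraps the kernels in a CustomOp so they run inside a symbolic network.
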
diff --git a/example/numpy-ops/custom_softmax_rtc.py b/example/numpy-ops/custom_softmax_rtc.py
new file mode 100644
index 0000000000..906cbbeac0
--- /dev/null
+++ b/example/numpy-ops/custom_softmax_rtc.py
@@ -0,0 +1,162 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+
+import logging
+import numpy as np
+import mxnet as mx
+
+class Softmax(mx.operator.CustomOp):
+    def __init__(self):
+        self.fwd_kernel_mod = None
+        self.bwd_kernel_mod = None
+        super().__init__()
+
+    def forward(self, is_train, req, in_data, out_data, aux):
+        if req[0] == "null":
+            return
+        x = in_data[0]  # input
+        y = out_data[0] # output
+        if self.fwd_kernel_mod is None:
+            # Each thread processes a row (a sample in the batch).
+            src = r"""
+                template<class DType>
+                __global__ void fwd(const DType* x, DType* y, const int row_size, const int req) {
+                    const int offset = row_size * threadIdx.x;
+                    DType max = x[offset];
+                    for(int i = 1; i < row_size; ++i) {
+                        if(max < x[offset + i]) {
+                            max = x[offset + i];
+                        }
+                    }
+                    DType sum = 0;
+                    for(int i = 0; i < row_size; ++i) {
+                        sum += exp(x[offset + i] - max);
+                    }
+                    switch(req) {
+                        case 1:
+                            for(int i = 0; i < row_size; ++i) {
+                                y[offset + i] = exp(x[offset + i] - max) / sum;
+                            }
+                            break;
+                        case 2:
+                            for(int i = 0; i < row_size; ++i) {
+                                y[offset + i] += exp(x[offset + i] - max) / sum;
+                            }
+                            break;
+                    }
+                }
+                """
+            self.fwd_kernel_mod = mx.rtc.CudaModule(src, exports=["fwd<float>", "fwd<double>"])
+        dtype = "double" if y.dtype == np.float64 else "float"
+        kernel_signature = "const {0}*, {0}*, const int, const int".format(dtype)
+        kernel = self.fwd_kernel_mod.get_kernel("fwd<{}>".format(dtype), kernel_signature)
+        # args, ctx, grid_shape, block_shape, shared_mem = 0
+        kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))
+
+    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+        if req[0] == "null":
+            return
+        l = in_data[1]  # label
+        y = out_data[0] # output from the forward pass
+        dx = in_grad[0] # the storage for the gradient
+        if self.bwd_kernel_mod is None:
+            # Each block processes a row and each thread in a block calculates one element of `dx`.
+            src = r"""
+                template<class DType>
+                __global__ void bwd(const DType* l, const DType* y, DType* dx, const int req) {
+                    const int z = static_cast<int>(l[blockIdx.x]);
+                    const int i = threadIdx.x + blockDim.x * blockIdx.x;
+                    if(req == 1) {
+                        dx[i]  = threadIdx.x == z ? y[i] - 1 : y[i];
+                    } else {
+                        dx[i] += threadIdx.x == z ? y[i] - 1 : y[i];
+                    }
+                }
+                """
+            self.bwd_kernel_mod = mx.rtc.CudaModule(src, exports=["bwd<float>", "bwd<double>"])
+        dtype = "double" if dx.dtype == np.float64 else "float"
+        kernel_signature = "const {0}*, const {0}*, {0}*, const int".format(dtype)
+        kernel = self.bwd_kernel_mod.get_kernel("bwd<{}>".format(dtype), kernel_signature)
+        # args, ctx, grid_shape, block_shape, shared_mem = 0
+        kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))
+
+    def _reqCode(self, req):
+        if(req == "write"):
+            return 1
+        elif(req == "add"):
+            return 2
+        elif(req == "null"):
+            return 0
+        else:
+            raise ValueError("Invalid value of `req`: {}".format(req))
+
+
+@mx.operator.register("softmax")
+class SoftmaxProp(mx.operator.CustomOpProp):
+    def __init__(self):
+        super(SoftmaxProp, self).__init__(need_top_grad=False)
+
+    def list_arguments(self):
+        return ['data', 'label']
+
+    def list_outputs(self):
+        return ['output']
+
+    def infer_shape(self, in_shape):
+        data_shape = in_shape[0]
+        label_shape = (in_shape[0][0],)
+        output_shape = in_shape[0]
+        return [data_shape, label_shape], [output_shape], []
+
+    def infer_type(self, in_type):
+        return in_type, [in_type[0]], []
+
+    def create_operator(self, ctx, in_shapes, in_dtypes):
+        return Softmax()
+
+# define mlp
+
+data = mx.symbol.Variable('data')
+fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)
+act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
+fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=64)
+act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type="relu")
+fc3 = mx.symbol.FullyConnected(data=act2, name='fc3', num_hidden=10)
+#mlp = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')
+mlp = mx.symbol.Custom(data=fc3, name='softmax', op_type='softmax')
+
+# data
+
+train, val = mx.test_utils.get_mnist_iterator(batch_size=100, input_shape=(784,))
+
+# train
+
+logging.basicConfig(level=logging.DEBUG)
+
+context = mx.gpu(0)
+mod = mx.mod.Module(mlp, context=context)
+mod.fit(
+    train_data=train,
+    eval_data=val,
+    optimizer='sgd',
+    optimizer_params={'learning_rate':0.1, 'momentum': 0.9, 'wd': 0.00001},
+    num_epoch=10,
+    batch_end_callback=mx.callback.Speedometer(100, 100)
+)
+
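
Beyond the Module-based MNIST training above, the registered operator can also be exercised imperatively through mx.nd.Custom. A small smoke-test sketch (an illustration only, not part of the PR; it assumes the Softmax/SoftmaxProp classes above are already defined in the session and a GPU is available):

    import numpy as np
    import mxnet as mx

    # Random logits and integer class labels on the GPU (shapes are arbitrary).
    data = mx.nd.random.uniform(shape=(4, 10), ctx=mx.gpu(0))
    label = mx.nd.array(np.arange(4), ctx=mx.gpu(0))

    # The op was registered under the name "softmax" via mx.operator.register.
    out = mx.nd.Custom(data, label, op_type='softmax')
    print(out.sum(axis=1).asnumpy())  # each row of the softmax output sums to ~1.0
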
diff --git a/example/numpy-ops/ndarray_softmax.py b/example/numpy-ops/ndarray_softmax.py
deleted file mode 100644
index 58eab3d538..0000000000
--- a/example/numpy-ops/ndarray_softmax.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: skip-file
-import mxnet as mx
-from mxnet.test_utils import get_mnist_iterator
-import numpy as np
-import logging
-
-class NDArraySoftmax(mx.operator.NDArrayOp):
-    def __init__(self):
-        super(NDArraySoftmax, self).__init__(False)
-        self.fwd_kernel = None
-        self.bwd_kernel = None
-
-    def list_arguments(self):
-        return ['data', 'label']
-
-    def list_outputs(self):
-        return ['output']
-
-    def infer_shape(self, in_shape):
-        data_shape = in_shape[0]
-        label_shape = (in_shape[0][0],)
-        output_shape = in_shape[0]
-        return [data_shape, label_shape], [output_shape]
-
-    def forward(self, in_data, out_data):
-        x = in_data[0]
-        y = out_data[0]
-        if self.fwd_kernel is None:
-            self.fwd_kernel = mx.rtc('softmax', [('x', x)], [('y', y)], """
-int i = threadIdx.x + blockIdx.x*blockDim.x;
-float max_x = x[i*x_dims[1]];
-for (int j = 1; j < x_dims[1]; ++j) {
-    if (max_x < x[i*x_dims[1]+j]) {
-        max_x = x[i*x_dims[1]+j];
-    }
-}
-float sum = 0.0f;
-for (int j = 0; j < x_dims[1]; ++j) {
-    sum += expf(x[i*x_dims[1]+j]-max_x);
-}
-for (int j = 0; j < x_dims[1]; ++j) {
-    y[i*x_dims[1]+j] = expf(x[i*x_dims[1]+j]-max_x)/sum;
-}
-""")
-        self.fwd_kernel.push([x], [y], (1, 1, 1), (x.shape[0], 1, 1))
-
-    def backward(self, out_grad, in_data, out_data, in_grad):
-        l = in_data[1]
-        y = out_data[0]
-        dx = in_grad[0]
-        if self.bwd_kernel is None:
-            self.bwd_kernel = mx.rtc('softmax_grad', [('y', y), ('l', l)], [('dx', dx)], """
-int i = blockIdx.x;
-int j = threadIdx.x;
-int k = static_cast<int>(l[i]);
-if (j == k) {
-    dx[i*dx_dims[1]+j] = y[i*dx_dims[1]+j] - 1.0f;
-} else {
-    dx[i*dx_dims[1]+j] = y[i*dx_dims[1]+j];
-}
-""")
-        self.bwd_kernel.push([y,l], [dx], (y.shape[0],1,1), (y.shape[1], 1, 1))
-
-# define mlp
-
-data = mx.symbol.Variable('data')
-fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
-act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
-fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
-act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
-fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
-#mlp = mx.symbol.Softmax(data = fc3, name = 'mlp')
-mysoftmax = NDArraySoftmax()
-mlp = mysoftmax(data=fc3, name = 'softmax')
-
-# data
-
-train, val = get_mnist_iterator(batch_size=100, input_shape = (784,))
-
-# train
-
-logging.basicConfig(level=logging.DEBUG)
-
-model = mx.model.FeedForward(
-    ctx = mx.gpu(0), symbol = mlp, num_epoch = 20,
-    learning_rate = 0.1, momentum = 0.9, wd = 0.00001)
-
-model.fit(X=train, eval_data=val)
-


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
