This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b23e0a9  New example of custom operator using RTC (#9870)
b23e0a9 is described below

commit b23e0a9f9f28f886ab20e48d0fcabcf0f8db91c4
Author: Deokjae Lee <36436141+asitsta...@users.noreply.github.com>
AuthorDate: Sat Feb 24 03:15:14 2018 +0900

    New example of custom operator using RTC (#9870)
    
    * Example implementation of the softmax output layer using RTC
    
    * Remove broken NDArrayOp example
    
    * Update README.md of the python custom operator examples
---
 example/numpy-ops/README.md             |   9 +-
 example/numpy-ops/custom_softmax_rtc.py | 162 ++++++++++++++++++++++++++++++++
 example/numpy-ops/ndarray_softmax.py    | 106 ---------------------
 3 files changed, 165 insertions(+), 112 deletions(-)

diff --git a/example/numpy-ops/README.md b/example/numpy-ops/README.md
index 1ec8a40..aa4911f 100644
--- a/example/numpy-ops/README.md
+++ b/example/numpy-ops/README.md
@@ -1,7 +1,4 @@
-# Training MNIST With NumpyOp
+# Training with Custom Operators in Python
 
-Uses the same setup as example/mnist/mlp.py. Except the loss symbol is
-custom defined with NumpyOp. mxnet.operator.NumpyOp help move computation
-in a symbol's forward/backward operation to python frontend. This is for
-fast implementation/experimentation of non-performance-critical symbols.
-If it is becoming a bottleneck, please consider write a C++/CUDA version.
\ No newline at end of file
+These examples demonstrate custom operator implementations in Python.
+You can implement the computation entirely in Python, or write custom CUDA kernels in C/C++ inside your Python source code with the help of Run-Time Compilation (RTC).
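
[Editor's note: the sketch below is not part of the patch. It illustrates the RTC workflow the README describes, using the same mx.rtc.CudaModule / get_kernel / launch calls as the new example added by this commit; the element-wise `scale` kernel, shapes, and launch dimensions are hypothetical.]

    import mxnet as mx

    # Hypothetical element-wise kernel, compiled at runtime by NVRTC.
    src = r'''
    extern "C" __global__ void scale(const float* x, float* y, float alpha) {
        const int i = threadIdx.x + blockDim.x * blockIdx.x;
        y[i] = alpha * x[i];
    }
    '''
    module = mx.rtc.CudaModule(src)
    kernel = module.get_kernel("scale", "const float*, float*, float")

    x = mx.nd.ones((1024,), ctx=mx.gpu(0))
    y = mx.nd.zeros((1024,), ctx=mx.gpu(0))
    # args, ctx, grid_shape, block_shape
    kernel.launch((x, y, 2.0), mx.gpu(0), (4, 1, 1), (256, 1, 1))
    print(y[:5].asnumpy())  # expected: [2. 2. 2. 2. 2.]
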
diff --git a/example/numpy-ops/custom_softmax_rtc.py b/example/numpy-ops/custom_softmax_rtc.py
new file mode 100644
index 0000000..906cbbe
--- /dev/null
+++ b/example/numpy-ops/custom_softmax_rtc.py
@@ -0,0 +1,162 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+
+import logging
+import numpy as np
+import mxnet as mx
+
+class Softmax(mx.operator.CustomOp):
+    def __init__(self):
+        self.fwd_kernel_mod = None
+        self.bwd_kernel_mod = None
+        super().__init__()
+
+    def forward(self, is_train, req, in_data, out_data, aux):
+        if req[0] == "null":
+            return
+        x = in_data[0]  # input
+        y = out_data[0] # output
+        if self.fwd_kernel_mod is None:
+            # Each thread processes a row (a sample in the batch).
+            src = r"""
+                template<class DType>
+                __global__ void fwd(const DType* x, DType* y, const int row_size, const int req) {
+                    const int offset = row_size * threadIdx.x;
+                    DType max = x[offset];
+                    for(int i = 1; i < row_size; ++i) {
+                        if(max < x[offset + i]) {
+                            max = x[offset + i];
+                        }
+                    }
+                    DType sum = 0;
+                    for(int i = 0; i < row_size; ++i) {
+                        sum += exp(x[offset + i] - max);
+                    }
+                    switch(req) {
+                        case 1:
+                            for(int i = 0; i < row_size; ++i) {
+                                y[offset + i] = exp(x[offset + i] - max) / sum;
+                            }
+                            break;
+                        case 2:
+                            for(int i = 0; i < row_size; ++i) {
+                                y[offset + i] += exp(x[offset + i] - max) / sum;
+                            }
+                            break;
+                    }
+                }
+                """
+            self.fwd_kernel_mod = mx.rtc.CudaModule(src, exports=["fwd<float>", "fwd<double>"])
+        dtype = "double" if y.dtype == np.float64 else "float"
+        kernel_signature = "const {0}*, const {0}*, const int, const int".format(dtype)
+        kernel = self.fwd_kernel_mod.get_kernel("fwd<{}>".format(dtype), kernel_signature)
+        # args, ctx, grid_shape, block_shape, shared_mem = 0
+        kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))
+
+    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+        if req[0] == "null":
+            return
+        l = in_data[1]  # label
+        y = out_data[0] # output from the forward pass
+        dx = in_grad[0] # the storage for the gradient
+        if self.bwd_kernel_mod is None:
+            # Each block processes a row and each thread in a block calculates an element of `dx`.
+            src = r"""
+                template<class DType>
+                __global__ void bwd(const DType* l, const DType* y, DType* dx, const int req) {
+                    const int z = static_cast<int>(l[blockIdx.x]);
+                    const int i = threadIdx.x + blockDim.x * blockIdx.x;
+                    if(req == 1) {
+                        dx[i]  = threadIdx.x == z ? y[i] - 1 : y[i];
+                    } else {
+                        dx[i] += threadIdx.x == z ? y[i] - 1 : y[i];
+                    }
+                }
+                """
+            self.bwd_kernel_mod = mx.rtc.CudaModule(src, exports=["bwd<float>", "bwd<double>"])
+        dtype = "double" if dx.dtype == np.float64 else "float"
+        kernel_signature = "const {0}*, const {0}*, {0}*, const int".format(dtype)
+        kernel = self.bwd_kernel_mod.get_kernel("bwd<{}>".format(dtype), kernel_signature)
+        # args, ctx, grid_shape, block_shape, shared_mem = 0
+        kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))
+
+    def _reqCode(self, req):
+        if(req == "write"):
+            return 1
+        elif(req == "add"):
+            return 2
+        elif(req == "null"):
+            return 0
+        else:
+            raise ValueError("Invalid value of `req`: {}".format(req))
+
+
+@mx.operator.register("softmax")
+class SoftmaxProp(mx.operator.CustomOpProp):
+    def __init__(self):
+        super(SoftmaxProp, self).__init__(need_top_grad=False)
+
+    def list_arguments(self):
+        return ['data', 'label']
+
+    def list_outputs(self):
+        return ['output']
+
+    def infer_shape(self, in_shape):
+        data_shape = in_shape[0]
+        label_shape = (in_shape[0][0],)
+        output_shape = in_shape[0]
+        return [data_shape, label_shape], [output_shape], []
+
+    def infer_type(self, in_type):
+        return in_type, [in_type[0]], []
+
+    def create_operator(self, ctx, in_shapes, in_dtypes):
+        return Softmax()
+
+# define mlp
+
+data = mx.symbol.Variable('data')
+fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)
+act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
+fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=64)
+act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type="relu")
+fc3 = mx.symbol.FullyConnected(data=act2, name='fc3', num_hidden=10)
+#mlp = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')
+mlp = mx.symbol.Custom(data=fc3, name='softmax', op_type='softmax')
+
+# data
+
+train, val = mx.test_utils.get_mnist_iterator(batch_size=100, input_shape=(784,))
+
+# train
+
+logging.basicConfig(level=logging.DEBUG)
+
+context = mx.gpu(0)
+mod = mx.mod.Module(mlp, context=context)
+mod.fit(
+    train_data=train,
+    eval_data=val,
+    optimizer='sgd',
+    optimizer_params={'learning_rate':0.1, 'momentum': 0.9, 'wd': 0.00001},
+    num_epoch=10,
+    batch_end_callback=mx.callback.Speedometer(100, 100)
+)
+
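
[Editor's note: the usage sketch below is not part of the patch. Once the Softmax/SoftmaxProp classes above are registered under op_type='softmax', the operator can also be exercised imperatively through the NDArray API; this assumes the class definitions above have already been executed in the current session and a GPU is available.]

    import numpy as np
    import mxnet as mx

    # Assumes Softmax/SoftmaxProp from custom_softmax_rtc.py are already defined,
    # which registers the operator under op_type='softmax'.
    data  = mx.nd.random.uniform(shape=(100, 10), ctx=mx.gpu(0))
    label = mx.nd.array(np.random.randint(0, 10, size=(100,)), ctx=mx.gpu(0))

    out = mx.nd.Custom(data, label, op_type='softmax')
    print(out.shape)                      # (100, 10)
    print(out.sum(axis=1)[:3].asnumpy())  # each row sums to ~1.0
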
diff --git a/example/numpy-ops/ndarray_softmax.py b/example/numpy-ops/ndarray_softmax.py
deleted file mode 100644
index 58eab3d..0000000
--- a/example/numpy-ops/ndarray_softmax.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: skip-file
-import mxnet as mx
-from mxnet.test_utils import get_mnist_iterator
-import numpy as np
-import logging
-
-class NDArraySoftmax(mx.operator.NDArrayOp):
-    def __init__(self):
-        super(NDArraySoftmax, self).__init__(False)
-        self.fwd_kernel = None
-        self.bwd_kernel = None
-
-    def list_arguments(self):
-        return ['data', 'label']
-
-    def list_outputs(self):
-        return ['output']
-
-    def infer_shape(self, in_shape):
-        data_shape = in_shape[0]
-        label_shape = (in_shape[0][0],)
-        output_shape = in_shape[0]
-        return [data_shape, label_shape], [output_shape]
-
-    def forward(self, in_data, out_data):
-        x = in_data[0]
-        y = out_data[0]
-        if self.fwd_kernel is None:
-            self.fwd_kernel = mx.rtc('softmax', [('x', x)], [('y', y)], """
-int i = threadIdx.x + blockIdx.x*blockDim.x;
-float max_x = x[i*x_dims[1]];
-for (int j = 1; j < x_dims[1]; ++j) {
-    if (max_x < x[i*x_dims[1]+j]) {
-        max_x = x[i*x_dims[1]+j];
-    }
-}
-float sum = 0.0f;
-for (int j = 0; j < x_dims[1]; ++j) {
-    sum += expf(x[i*x_dims[1]+j]-max_x);
-}
-for (int j = 0; j < x_dims[1]; ++j) {
-    y[i*x_dims[1]+j] = expf(x[i*x_dims[1]+j]-max_x)/sum;
-}
-""")
-        self.fwd_kernel.push([x], [y], (1, 1, 1), (x.shape[0], 1, 1))
-
-    def backward(self, out_grad, in_data, out_data, in_grad):
-        l = in_data[1]
-        y = out_data[0]
-        dx = in_grad[0]
-        if self.bwd_kernel is None:
-            self.bwd_kernel = mx.rtc('softmax_grad', [('y', y), ('l', l)], [('dx', dx)], """
-int i = blockIdx.x;
-int j = threadIdx.x;
-int k = static_cast<int>(l[i]);
-if (j == k) {
-    dx[i*dx_dims[1]+j] = y[i*dx_dims[1]+j] - 1.0f;
-} else {
-    dx[i*dx_dims[1]+j] = y[i*dx_dims[1]+j];
-}
-""")
-        self.bwd_kernel.push([y,l], [dx], (y.shape[0],1,1), (y.shape[1], 1, 1))
-
-# define mlp
-
-data = mx.symbol.Variable('data')
-fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
-act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
-fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
-act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
-fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
-#mlp = mx.symbol.Softmax(data = fc3, name = 'mlp')
-mysoftmax = NDArraySoftmax()
-mlp = mysoftmax(data=fc3, name = 'softmax')
-
-# data
-
-train, val = get_mnist_iterator(batch_size=100, input_shape = (784,))
-
-# train
-
-logging.basicConfig(level=logging.DEBUG)
-
-model = mx.model.FeedForward(
-    ctx = mx.gpu(0), symbol = mlp, num_epoch = 20,
-    learning_rate = 0.1, momentum = 0.9, wd = 0.00001)
-
-model.fit(X=train, eval_data=val)
-

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.
