gussmith23 commented on a change in pull request #5812:
URL: https://github.com/apache/incubator-tvm/pull/5812#discussion_r470754362



##########
File path: tests/python/unittest/test_custom_datatypes.py
##########
@@ -0,0 +1,407 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities for changing datatypes of models."""
+import tvm
+import topi.testing
+import numpy as np
+from numpy.random import MT19937, RandomState, SeedSequence
+from tvm import relay
+from tvm.relay.testing.inception_v3 import get_workload as get_inception
+from tvm.relay.testing.resnet import get_workload as get_resnet
+from tvm.relay.testing.mobilenet import get_workload as get_mobilenet
+from tvm.target.datatype import register, register_min_func, register_op, 
create_lower_func, lower_ite
+from nose.tools import nottest
+
+# Compilation target name. NOTE(review): not referenced by the test code
+# visible here; the lowering registrations in setup() hard-code "llvm".
+tgt = "llvm"
+# Module-wide random state with a fixed seed. The random test inputs are
+# drawn from it, so the generated test data is reproducible across runs.
+rs = RandomState(MT19937(SeedSequence(123456789)))
+
+def convert_ndarray(dst_dtype, *arrays):
+    """Converts NDArray(s) into the specified datatype"""
+    def convert(array):
+        x = relay.var('x', shape=array.shape, dtype=str(array.dtype))
+        cast = relay.Function([x], x.astype(dst_dtype))
+        with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+            return relay.create_executor('graph').evaluate(cast)(array)
+
+    return tuple([convert(x) for x in arrays])
+
+
+def change_dtype(src, dst, module, params):
+    module = relay.frontend.ChangeDatatype(src, dst)(module)
+    module = relay.transform.InferType()(module)
+    params = dict((p, convert_ndarray(dst, params[p])) for p in params)
+    return module, params
+
+def compare(module, input, src_dtype, dst_dtype, rtol, atol, params = {}):
+    ex = relay.create_executor("graph", mod=module)
+
+    correct = ex.evaluate()(*input, **params)
+
+    module, _ = change_dtype(src_dtype, dst_dtype, module, [])
+    ex = relay.create_executor("graph", mod=module)
+    # converts all inputs to dst_dtype
+    x_converted = convert_ndarray(dst_dtype, *input)
+
+    # Vectorization is not implemented with custom datatypes
+    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
+        maybe_correct = ex.evaluate()(*x_converted, **params)
+        # TODO(andrew) this only works on single output
+        maybe_correct_converted = convert_ndarray(src_dtype, maybe_correct)[0]
+    np.testing.assert_allclose(maybe_correct_converted.asnumpy(),
+                                correct.asnumpy(),
+                                rtol=rtol,
+                                atol=atol)
+
+def setup():
+    """Set up tests.
+
+    Registers the custom datatype "posites2" (type code 131) with the Bring
+    Your Own Datatypes framework. It then registers, for the "llvm" target,
+    the functions that lower each operation on that datatype to a named
+    posit routine at 8, 16 and 32 bits, and the datatype's minimum value.
+
+    NOTE(review): presumably invoked by the test runner as a nose-style
+    module-level ``setup`` fixture before the tests run -- confirm.
+    """
+
+    # To use datatype operations in an external library, you should first load
+    # the library containing the datatype implementation:
+    # CDLL("libposit.so", RTLD_GLOBAL)
+    # In this case, the datatype library we are using is built right into TVM,
+    # so we do not need to explicitly load any library.
+
+    # You can pick a code for your datatype arbitrarily, as long as it is
+    # greater than 128 and has not already been chosen.
+
+    register("posites2", 131)
+
+    # Casts. For casts the create_lower_func table is keyed by a pair of bit
+    # widths; judging by the routine names the pair is (source, destination),
+    # e.g. (32, 16) maps float32 -> posit16.
+    register_op(create_lower_func(
+        {
+            (32, 32): "FloatToPosit32es2",
+            (32, 16): "FloatToPosit16es2",
+            (32, 8): 'FloatToPosit8es2',
+        }),
+        "Cast", "llvm", "float", "posites2")
+    register_op(create_lower_func(
+        {
+            (32, 32): "Posit32es2ToFloat",
+            (16, 32): 'Posit16es2ToFloat',
+            (8, 32): 'Posit8es2ToFloat',
+        }),
+        "Cast", "llvm", "posites2", "float")
+    # NOTE(review): the source width for int -> posit casts is 4 here rather
+    # than 32; confirm it matches the bit width of the ints being cast.
+    register_op(create_lower_func(
+        {
+            (4, 32): 'IntToPosit32es2',
+            (4, 16): 'IntToPosit16es2',
+            (4, 8): 'IntToPosit8es2'
+        }),
+        "Cast", "llvm", "int", "posites2")
+    # Arithmetic. For non-cast ops the table is keyed by the posit bit width.
+    register_op(create_lower_func({
+        32: 'Posit32es2Add',
+        16: 'Posit16es2Add',
+        8: 'Posit8es2Add'
+    }), "Add", "llvm", "posites2")
+    register_op(create_lower_func({
+        32: 'Posit32es2Sub',
+        16: 'Posit16es2Sub',
+        8: 'Posit8es2Sub'
+    }), "Sub", "llvm", "posites2")
+    # Float literals (FloatImm) are lowered with the same float -> posit
+    # conversion routines used by the float -> posites2 cast above.
+    register_op(create_lower_func({
+        32: 'FloatToPosit32es2',
+        16: 'FloatToPosit16es2',
+        8: 'FloatToPosit8es2'
+    }), "FloatImm", "llvm", "posites2")
+    register_op(create_lower_func({
+        32: 'Posit32es2Mul',
+        16: 'Posit16es2Mul',
+        8: 'Posit8es2Mul'
+    }), "Mul", "llvm", "posites2")
+    register_op(create_lower_func({
+        32: 'Posit32es2Div',
+        16: 'Posit16es2Div',
+        8: 'Posit8es2Div'
+    }), "Div", "llvm", "posites2")
+    register_op(create_lower_func({
+        32: 'Posit32es2Max',
+        16: 'Posit16es2Max',
+        8: 'Posit8es2Max'
+    }), "Max", "llvm", "posites2")
+    # Intrinsic calls; each registration is selected by intrinsic_name.
+    register_op(create_lower_func({
+        32: 'Posit32es2Sqrt',
+        16: 'Posit16es2Sqrt',
+        8: 'Posit8es2Sqrt'
+    }), "Call", "llvm", "posites2", intrinsic_name="sqrt")
+    # tvm_if_then_else uses the lower_ite helper instead of a
+    # create_lower_func table.
+    register_op(lower_ite,
+                "Call",
+                "llvm",
+                "posites2",
+                intrinsic_name="tvm_if_then_else")
+    register_op(create_lower_func({
+        32: 'Posit32es2Exp',
+        16: 'Posit16es2Exp',
+        8: 'Posit8es2Exp'
+    }), "Call", "llvm", "posites2", intrinsic_name="exp")
+    register_op(create_lower_func({
+        32: 'Posit32es2Log',
+        16: 'Posit16es2Log',
+        8: 'Posit8es2Log'
+    }), "Call", "llvm", "posites2", intrinsic_name="log")
+    register_op(create_lower_func({
+        32: 'Posit32es2Sigmoid',
+        16: 'Posit16es2Sigmoid',
+        8: 'Posit8es2Sigmoid'
+    }), "Call", "llvm", "posites2", intrinsic_name="sigmoid")
+    register_op(create_lower_func({
+        32: 'Posit32es2Tanh',
+        16: 'Posit16es2Tanh',
+        8: 'Posit8es2Tanh'
+    }), "Call", "llvm", "posites2", intrinsic_name="tanh")
+    # Minimum value for a given bit width:
+    # -(2 ** 2 ** 2) ** (num_bits - 2) == -(16 ** (num_bits - 2)).
+    # For es=2 posits, 16 == 2 ** (2 ** es) ("useed"), and
+    # useed ** (num_bits - 2) is the largest finite magnitude, so this is its
+    # negation.
+    register_min_func(lambda num_bits: - (2 ** 2 ** 2) ** (num_bits - 2), "posites2")
+
+def run_ops(src_dtype, dst_dtype, rtol=1e-7, atol=1e-7):
+    """Run the same op, but with two different datatypes"""
+    # used for unary ops, first shape in binary ops
+    shape1 = (5, 10, 5)
+    # second shape for binary ops
+    shape2 = (5, )
+
+    def check_unary_op(op, src_dtype, dst_dtype):
+        t1 = relay.TensorType(shape1, src_dtype)
+        x = relay.var("x", t1)
+        z = op(x)
+        x_data = rs.rand(*shape1).astype(t1.dtype)
+
+        module = tvm.IRModule.from_expr(relay.Function([x], z))
+
+        compare(module, (x_data, ), src_dtype, dst_dtype, rtol, atol)
+        # print(maybe_correct_converted)
+        # print(correct)
+
+    for op in [
+            relay.nn.softmax,
+            tvm.relay.log,
+            tvm.relay.exp,
+            tvm.relay.sqrt,
+            tvm.relay.rsqrt,
+            tvm.relay.sigmoid,
+            tvm.relay.tanh,
+            relay.nn.relu,
+    ]:
+        check_unary_op(op, src_dtype, dst_dtype)
+
+    def check_binary_op(opfunc, src_dtype, dst_dtype):
+        t1 = relay.TensorType(shape1, src_dtype)
+        t2 = relay.TensorType(shape2, src_dtype)
+        x = relay.var("x", t1)
+        y = relay.var("y", t2)
+        z = opfunc(x, y)
+        x_data = rs.rand(*shape1).astype(t1.dtype)
+        y_data = rs.rand(*shape2).astype(t2.dtype)
+        module = tvm.IRModule.from_expr(relay.Function([x, y], z))
+
+        compare(module, (x_data, y_data), src_dtype, dst_dtype, rtol, atol)
+
+    for op in [
+            relay.add,
+            relay.subtract,
+            relay.divide,
+            relay.multiply,
+    ]:
+        check_binary_op(op, src_dtype, dst_dtype)
+
+    # we would like to test tvm_if_then_else
+    # but Relay.IfNode is not lowered to this intrinsic,
+    # so to keep our tests consistent with relay, we decide to not unit test
+    # Note: tvm_if_then_else is tested as part of the mobile_net model
+
+
+def run_model(get_workload,
+              input_shape,
+              src_dtype,
+              dst_dtype,
+              num_classes,
+              rtol=0.0001,
+              atol=0.0001):
+    module, params = get_workload(image_shape=input_shape,
+                                  num_classes=num_classes)
+
+    # generate random input with appropriate shape/type
+    input = tvm.nd.array(rs.rand(*input_shape).astype(src_dtype))
+
+    compare(module, (input, ), src_dtype, dst_dtype, rtol, atol, params)
+
+def run_conv2d(src_dtype, dst_dtype):
+    def run_test_conv2d(src_dtype,
+                        dst_dtype,
+                        scale,
+                        dshape,
+                        kshape,
+                        padding=(1, 1),
+                        fref=None,
+                        groups=1,
+                        dilation=(1, 1),
+                        except_targets=None,
+                        **attrs):
+        if except_targets is None:
+            except_targets = []
+
+        x = relay.var("x", shape=dshape, dtype=src_dtype)
+        w = relay.var("w", shape=kshape, dtype=src_dtype)
+        y = relay.nn.conv2d(x,
+                            w,
+                            padding=padding,
+                            dilation=dilation,
+                            groups=groups,
+                            **attrs)
+        module = tvm.IRModule.from_expr(relay.Function([x, w], y))
+        data = rs.uniform(-scale, scale, size=dshape).astype(src_dtype)
+        kernel = rs.uniform(-scale, scale,
+                                   size=kshape).astype(src_dtype)
+        dkernel = topi.testing.dilate_python(kernel, (1, 1) + dilation)
+        if fref is None:
+            ref_res = topi.testing.conv2d_nchw_python(
+                data.astype(src_dtype),
+                dkernel.astype(src_dtype),
+                1,
+                padding,
+                groups=groups)
+        else:
+            ref_res = fref(data.astype(src_dtype), dkernel.astype(src_dtype))
+
+        for target, ctx in [("llvm", tvm.cpu(0))]:
+            if target in except_targets:
+                continue
+            intrp1 = relay.create_executor("graph",
+                                           ctx=ctx,
+                                           target=target,
+                                           mod=module)
+            module, _ = change_dtype(src_dtype, dst_dtype, module, [])
+            data_converted = convert_ndarray(dst_dtype, data)
+            kernel_converted = convert_ndarray(dst_dtype, kernel)
+            with tvm.transform.PassContext(
+                    config={"tir.disable_vectorize": True}):
+                op_res1 = intrp1.evaluate()(data_converted, kernel_converted)
+            op_res1_converted = convert_ndarray(src_dtype, op_res1)
+            tvm.testing.assert_allclose(op_res1_converted.asnumpy(), ref_res)
+
+    # depthwise conv2d
+    dshape = (1, 32, 18, 18)
+    kshape = (32, 1, 3, 3)
+    run_test_conv2d(src_dtype,
+                    dst_dtype,
+                    1,
+                    dshape,
+                    kshape,
+                    padding=(1, 1),
+                    channels=32,
+                    groups=32,
+                    kernel_size=(3, 3),
+                    fref=lambda x, w: topi.testing.
+                    depthwise_conv2d_python_nchw(x, w, (1, 1), "SAME"))
+
+    # CUDA is disabled for 'direct' schedule:
+    # https://github.com/dmlc/tvm/pull/3070#issuecomment-486597553
+    # group conv2d
+    dshape = (1, 32, 18, 18)
+    kshape = (32, 4, 3, 3)
+    run_test_conv2d(src_dtype,
+                    dst_dtype,
+                    1,
+                    dshape,
+                    kshape,
+                    padding=(1, 1),
+                    channels=32,
+                    groups=8,
+                    kernel_size=(3, 3),
+                    except_targets=['cuda'])
+    # also group conv2d
+    dshape = (1, 32, 18, 18)
+    kshape = (64, 1, 3, 3)
+    run_test_conv2d(src_dtype,
+                    dst_dtype,
+                    1,
+                    dshape,
+                    kshape,
+                    padding=(1, 1),
+                    channels=64,
+                    groups=32,
+                    kernel_size=(3, 3),
+                    except_targets=['cuda'])
+
+    # normal conv2d
+    dshape = (1, 3, 224, 224)
+    kshape = (10, 3, 3, 3)
+    run_test_conv2d(src_dtype,
+                    dst_dtype,
+                    1,
+                    dshape,
+                    kshape,
+                    padding=(1, 1),
+                    channels=10,
+                    kernel_size=(3, 3))
+
+    # dilated conv2d
+    dshape = (1, 3, 18, 18)
+    kshape = (10, 3, 3, 3)
+    run_test_conv2d(src_dtype,
+                    dst_dtype,
+                    1,
+                    dshape,
+                    kshape,
+                    padding=(1, 1),
+                    channels=10,
+                    kernel_size=(3, 3),
+                    dilation=(3, 3))
+
+
+def test_ops():
+    run_ops('float32', 'custom[posites2]8', rtol=1, atol=1)
+    run_ops('float32', 'custom[posites2]16', rtol=0.01, atol=1)
+    run_ops('float32', 'custom[posites2]32')
+
+
+def test_conv2d():
+    # TODO(@gussmith23) slow and broken, needing refactor!
+    # run_conv2d('float32', 'custom[posit32]32')
+    pass
+
+
+def test_models():
+    # Expected posit8 might be faster, but it's not.
+    # run_model(get_mobilenet, (3, 224, 224), 'float32', 'custom[posit8]8')
+    # run_model(get_mobilenet, (3, 224, 224), 'float32', 'custom[posit32]32')
+    # run_model(get_inception, (3, 299, 299), 'float32', 'custom[posit32]32')
+    # run_model(get_resnet, (3, 224, 224), 'float32', 'custom[posit32]32')
+
+    # Run cifar-10 sizes to be a little faster...
+    run_model(get_mobilenet, (3, 32, 32),
+              'float32',
+              'custom[posites2]32',
+              num_classes=10)
+    # run_model(get_inception, (3, 32, 32),
+    #           'float32',
+    #           'custom[posites2]32',
+    #           num_classes=10)
+    # run_model(get_resnet, (3, 32, 32),

Review comment:
       We should debug the small Resnet: something is going wrong. The easiest way to debug it is to return early from the Resnet constructor:
   
   ``` diff
   modified   python/tvm/relay/testing/resnet.py
   @@ -162,6 +162,7 @@ def resnet(units,
        data = relay.var("data", shape=data_shape, dtype=dtype)
        data = layers.batch_norm_infer(data=data, epsilon=2e-5, scale=False, 
name='bn_data')
        (_, _, height, _) = data_shape
   +    return relay.Function(relay.analysis.free_vars(data), data)
        if height <= 32:            # such as cifar10
            body = layers.conv2d(
                data=data, channels=filter_list[0], kernel_size=(3, 3),
   @@ -194,7 +195,6 @@ def resnet(units,
        flat = relay.nn.batch_flatten(data=pool1)
        fc1 = layers.dense_add_bias(data=flat, units=num_classes, name='fc1')
        net = relay.nn.softmax(data=fc1)
   -    return relay.Function(relay.analysis.free_vars(net), net)
    
    
    def get_net(batch_size,
   ```
   This should work well with the current test file.
   You can move the early return up and down to see where things diverge.
   In my limited testing, it looks like the posit outputs become NaN after this first batch norm! Running the above version of Resnet produces NaNs, but if we move the early return up one line (before the batch norm), then things are fine:
   
   ``` diff
   modified   python/tvm/relay/testing/resnet.py
   @@ -160,6 +160,7 @@ def resnet(units,
        num_unit = len(units)
        assert num_unit == num_stages
        data = relay.var("data", shape=data_shape, dtype=dtype)
   +    return relay.Function(relay.analysis.free_vars(data), data)
        data = layers.batch_norm_infer(data=data, epsilon=2e-5, scale=False, 
name='bn_data')
        (_, _, height, _) = data_shape
        if height <= 32:            # such as cifar10
   @@ -194,7 +195,6 @@ def resnet(units,
        flat = relay.nn.batch_flatten(data=pool1)
        fc1 = layers.dense_add_bias(data=flat, units=num_classes, name='fc1')
        net = relay.nn.softmax(data=fc1)
   -    return relay.Function(relay.analysis.free_vars(net), net)
    
    
    def get_net(batch_size,
   ```
   
   This version of Resnet literally doesn't do anything; it just returns the input. But we know it works, so we can see that things go wrong right after that first batch norm.
   
   My first thought was "oh, we need to run `SimplifyInference`!" So I did:
   ``` diff
   modified   tests/python/unittest/test_custom_datatypes.py
   @@ -49,6 +49,8 @@ def change_dtype(src, dst, module, params):
        return module, params
    
    def compare(module, input, src_dtype, dst_dtype, rtol, atol, params = {}):
   +    module = relay.transform.SimplifyInference()(module)
   +
        ex = relay.create_executor("graph", mod=module)
    
        correct = ex.evaluate()(*input, **params)
   ```
   
   This fixes the batch norm... somewhat. We still see numerical differences between the float32 and posit32 outputs of the batch norm:
   
   ```
   AssertionError: 
   Not equal to tolerance rtol=0.0001, atol=0.0001
   
   Mismatch: 99.2%
   Max absolute difference: 0.01599717
   Max relative difference: 0.01600081
    x: array([[[[0.692229, 0.071629, 0.014329, ..., 0.984048, 0.871523,
             0.445013],
            [0.728206, 0.627117, 0.563761, ..., 0.378826, 0.326279,...
    y: array([[[[0.681328, 0.070501, 0.014103, ..., 0.968551, 0.857798,
             0.438005],
            [0.716737, 0.61724 , 0.554882, ..., 0.37286 , 0.321141,...
   ```
   
   There shouldn't be this much numerical error between posit32es2 and float32 -- posit32es2 should be considerably more accurate than float32 for values around [0, 1].
   
   Here's a debugging list:
   - I think we do still need to run `SimplifyInference` explicitly. This is 
confusing, because it's run in `Optimize()`...maybe it's not getting run?
     + [ ] Add `SimplifyInference` somewhere (maybe in `compare()` like I did? 
I think that might be the best place)
     + [ ] Figure out why `SimplifyInference` seemingly isn't getting run later on
   - [ ] I think we should add a test for batch norm itself, because it's been 
such a thorn in the side of this project. Have a test that runs just a batch 
norm through your `compare()`. Once we fix batch norm, hopefully everything 
else will be fixed.
   - To debug the numerical issues with batch norm, here are a few dead-simple 
ideas off the top of my head:
     + [ ] Make sure that the workload is using the posit32 C functions. Maybe 
your `create_lower_func()` call is lowering them to the wrong bitwidth?
     + [ ] Make sure the params going into the batch norm are the same between 
float32 and posit32. (Batch norms take four parameters: mean, var, beta, gamma)
   
   Let me know where you get with this! Good luck!




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to