This is an automated email from the ASF dual-hosted git repository.

lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 28aead905b [microNPU][ETHOSU] Fix SoftMax legalization parameters (#15069)
28aead905b is described below

commit 28aead905b2b8ea3984f75d16819a4079b11b23e
Author: Aleksei-grovety <[email protected]>
AuthorDate: Mon Jun 26 20:57:54 2023 +0400

    [microNPU][ETHOSU] Fix SoftMax legalization parameters (#15069)
    
    * [microNPU][ETHOSU] Fix Softmax activation parameters
    
    Fix activation parameters for operations according to the values in Vela.
    
    * fix legalization parameters
    
    * Update test_legalize.py
    
    * Update test_legalize.py
---
 .../backend/contrib/ethosu/softmax_rewriter.py     |  75 +++++++++-----
 tests/python/contrib/test_ethosu/test_legalize.py  | 114 +++++++++++++--------
 2 files changed, 122 insertions(+), 67 deletions(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/softmax_rewriter.py b/python/tvm/relay/backend/contrib/ethosu/softmax_rewriter.py
index 6c0a1dffc3..23d4f4b45b 100644
--- a/python/tvm/relay/backend/contrib/ethosu/softmax_rewriter.py
+++ b/python/tvm/relay/backend/contrib/ethosu/softmax_rewriter.py
@@ -82,6 +82,8 @@ class SoftmaxRewriter(DFPatternCallback):
         self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
     ) -> tvm.relay.Expr:
         params = self.params_class(post.op.body)
+        quant_min = -128
+        quant_max = 127
 
         ifm = post.args[0]
         ifm_dtype = ifm.checked_type.dtype
@@ -121,12 +123,14 @@ class SoftmaxRewriter(DFPatternCallback):
             ifm2_scale=0.0,
             ifm2_zero_point=int(params.ifm.q_params.zero_point),
             ofm_scale=1.0,
-            ofm_zero_point=127,
+            ofm_zero_point=quant_max,
             ifm_channels=depth,
             ifm2_channels=1,
             reversed_operands=False,
             ofm_dtype="int32",
             activation="LUT",
+            clip_min=-255,
+            clip_max=0,
         )
 
         # PASS 2 - SHR
@@ -147,8 +151,8 @@ class SoftmaxRewriter(DFPatternCallback):
             reversed_operands=False,
             ofm_dtype="int32",
             activation="CLIP",
-            clip_min=-128,
-            clip_max=127,
+            clip_min=quant_min,
+            clip_max=quant_max,
             rounding_mode="NATURAL",
         )
 
@@ -165,6 +169,9 @@ class SoftmaxRewriter(DFPatternCallback):
             ofm_channels=1,
             upscale="NONE",
             ofm_dtype="int32",
+            activation="CLIP",
+            clip_min=quant_min,
+            clip_max=quant_max,
         )
 
         # PASS 4 - CLZ
@@ -177,6 +184,9 @@ class SoftmaxRewriter(DFPatternCallback):
             ofm_scale=0.0,
             ofm_zero_point=int(params.ifm.q_params.zero_point),
             ofm_channels=1,
+            activation="CLIP",
+            clip_min=quant_min,
+            clip_max=quant_max,
         )
 
         # PASS 5 - Sub
@@ -196,6 +206,9 @@ class SoftmaxRewriter(DFPatternCallback):
             ifm2_channels=1,
             reversed_operands=False,
             ofm_dtype="int32",
+            activation="CLIP",
+            clip_min=quant_min,
+            clip_max=quant_max,
         )
 
         # PASS 6 - Sub
@@ -215,6 +228,9 @@ class SoftmaxRewriter(DFPatternCallback):
             ifm2_channels=1,
             reversed_operands=False,
             ofm_dtype="int32",
+            activation="CLIP",
+            clip_min=quant_min,
+            clip_max=quant_max,
         )
 
         # PASS 7 - SHL
@@ -229,13 +245,13 @@ class SoftmaxRewriter(DFPatternCallback):
             ifm2_zero_point=0,
             ofm_scale=0.0,
             ofm_zero_point=int(params.ifm.q_params.zero_point),
-            ifm_channels=depth,
+            ifm_channels=1,
             ifm2_channels=1,
             reversed_operands=False,
             ofm_dtype="int32",
             activation="CLIP",
-            clip_min=-128,
-            clip_max=127,
+            clip_min=quant_min,
+            clip_max=quant_max,
         )
 
         # PASS 8 - Sub
@@ -255,6 +271,9 @@ class SoftmaxRewriter(DFPatternCallback):
             ifm2_channels=1,
             reversed_operands=False,
             ofm_dtype="int32",
+            activation="CLIP",
+            clip_min=quant_min,
+            clip_max=quant_max,
         )
 
         # PASS 9 - SHL
@@ -274,8 +293,8 @@ class SoftmaxRewriter(DFPatternCallback):
             reversed_operands=False,
             ofm_dtype="int32",
             activation="CLIP",
-            clip_min=-128,
-            clip_max=127,
+            clip_min=quant_min,
+            clip_max=quant_max,
         )
 
         # PASS 10 - Add
@@ -296,8 +315,8 @@ class SoftmaxRewriter(DFPatternCallback):
             reversed_operands=False,
             ofm_dtype="int32",
             activation="CLIP",
-            clip_min=-128,
-            clip_max=127,
+            clip_min=quant_min,
+            clip_max=quant_max,
             use_rescale=True,
             rescale_scale=1,
             rescale_shift=1,
@@ -316,13 +335,13 @@ class SoftmaxRewriter(DFPatternCallback):
             ifm2_zero_point=0,
             ofm_scale=2.0,
             ofm_zero_point=0,
-            ifm_channels=depth,
+            ifm_channels=1,
             ifm2_channels=1,
             reversed_operands=False,
             ofm_dtype="int32",
             activation="CLIP",
-            clip_min=-128 * 2,
-            clip_max=127 * 2,
+            clip_min=quant_min,
+            clip_max=quant_max,
         )
 
         # PASS 12 - Add
@@ -343,8 +362,8 @@ class SoftmaxRewriter(DFPatternCallback):
             reversed_operands=False,
             ofm_dtype="int32",
             activation="CLIP",
-            clip_min=-128,
-            clip_max=127,
+            clip_min=quant_min,
+            clip_max=quant_max,
         )
 
         nr_x = rescale_w_offset
@@ -368,8 +387,8 @@ class SoftmaxRewriter(DFPatternCallback):
                 reversed_operands=False,
                 ofm_dtype="int32",
                 activation="CLIP",
-                clip_min=-128 * 2,
-                clip_max=127 * 2,
+                clip_min=quant_min,
+                clip_max=quant_max,
             )
 
             # PASS 14, 19, 24 - Sub
@@ -388,6 +407,9 @@ class SoftmaxRewriter(DFPatternCallback):
                 ifm2_channels=1,
                 reversed_operands=False,
                 ofm_dtype="int32",
+                activation="CLIP",
+                clip_min=quant_min,
+                clip_max=quant_max,
             )
 
             # PASS 15, 20, 25 - Mul
@@ -407,8 +429,8 @@ class SoftmaxRewriter(DFPatternCallback):
                 reversed_operands=False,
                 ofm_dtype="int32",
                 activation="CLIP",
-                clip_min=-128 * 2,
-                clip_max=127 * 2,
+                clip_min=quant_min,
+                clip_max=quant_max,
             )
 
             # PASS 16, 21, 26 - Mul
@@ -428,8 +450,8 @@ class SoftmaxRewriter(DFPatternCallback):
                 reversed_operands=False,
                 ofm_dtype="int32",
                 activation="CLIP",
-                clip_min=-128,
-                clip_max=127,
+                clip_min=quant_min,
+                clip_max=quant_max,
             )
 
             # PASS 17, 22, 27 - Add
@@ -448,6 +470,9 @@ class SoftmaxRewriter(DFPatternCallback):
                 ifm2_channels=1,
                 reversed_operands=False,
                 ofm_dtype="int32",
+                activation="CLIP",
+                clip_min=quant_min,
+                clip_max=quant_max,
             )
 
         # PASS 28 - Mul
@@ -468,8 +493,8 @@ class SoftmaxRewriter(DFPatternCallback):
             reversed_operands=False,
             ofm_dtype="int32",
             activation="CLIP",
-            clip_min=-128,
-            clip_max=127,
+            clip_min=quant_min,
+            clip_max=quant_max,
         )
 
         # PASS 29 - Mul
@@ -489,8 +514,8 @@ class SoftmaxRewriter(DFPatternCallback):
             reversed_operands=False,
             ofm_dtype="int32",
             activation="CLIP",
-            clip_min=-128 * 2,
-            clip_max=127 * 2,
+            clip_min=quant_min,
+            clip_max=quant_max,
         )
 
         # PASS 30 - SHR
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 1b643f8157..c952a13c52 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -3526,9 +3526,10 @@ def test_tflite_hard_swish(ifm_shape):
     assert tuple(func_body.args[1].checked_type.shape) == (256,)
 
 
[email protected]("ifm_shape", [(1, 12), (1, 12, 32)])
-def test_tflite_softmax(ifm_shape):
+def test_tflite_softmax():
+    np.random.seed(0)
     dtype = "int8"
+    ifm_shape = (1, 12)
 
     def create_tflite_graph():
         @tf.function
@@ -3539,7 +3540,7 @@ def test_tflite_softmax(ifm_shape):
         # Convert the model
         def representative_dataset():
             for _ in range(100):
-                data = np.random.rand(*tuple(ifm_shape))
+                data = np.random.uniform(low=-1, high=2, size=tuple(ifm_shape))
                 yield [data.astype(np.float32)]
 
         converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
@@ -3554,44 +3555,54 @@ def test_tflite_softmax(ifm_shape):
     def verify(ext_func):
         out_op = ext_func.body
         ops = []
-        # List of expected operations and their type if it exists
-        expected_ops = [
-            ("reshape", None),
-            ("reshape", None),
-            ("contrib.ethosu.pooling", "MAX"),
-            ("contrib.ethosu.binary_elementwise", "SUB"),
-            ("contrib.ethosu.binary_elementwise", "SHR"),
-            ("contrib.ethosu.pooling", "SUM"),
-            ("contrib.ethosu.unary_elementwise", "CLZ"),
-            ("contrib.ethosu.binary_elementwise", "SUB"),
-            ("contrib.ethosu.binary_elementwise", "SHL"),
-            ("contrib.ethosu.binary_elementwise", "SUB"),
-            ("contrib.ethosu.binary_elementwise", "SHL"),
-            ("contrib.ethosu.binary_elementwise", "ADD"),
-            ("contrib.ethosu.binary_elementwise", "MUL"),
-            ("contrib.ethosu.binary_elementwise", "ADD"),
-            ("contrib.ethosu.binary_elementwise", "MUL"),
-            ("contrib.ethosu.binary_elementwise", "SUB"),
-            ("contrib.ethosu.binary_elementwise", "MUL"),
-            ("contrib.ethosu.binary_elementwise", "MUL"),
-            ("contrib.ethosu.binary_elementwise", "ADD"),
-            ("contrib.ethosu.binary_elementwise", "MUL"),
-            ("contrib.ethosu.binary_elementwise", "SUB"),
-            ("contrib.ethosu.binary_elementwise", "MUL"),
-            ("contrib.ethosu.binary_elementwise", "MUL"),
-            ("contrib.ethosu.binary_elementwise", "ADD"),
-            ("contrib.ethosu.binary_elementwise", "MUL"),
-            ("contrib.ethosu.binary_elementwise", "SUB"),
-            ("contrib.ethosu.binary_elementwise", "MUL"),
-            ("contrib.ethosu.binary_elementwise", "MUL"),
-            ("contrib.ethosu.binary_elementwise", "ADD"),
-            ("contrib.ethosu.binary_elementwise", "MUL"),
-            ("contrib.ethosu.binary_elementwise", "MUL"),
-            ("contrib.ethosu.binary_elementwise", "SUB"),
-            ("contrib.ethosu.binary_elementwise", "SHR"),
-            ("reshape", None),
+        # List of expected operations, their type and activation parameters if it exists
+        expected_ops_params = [
+            ("reshape", None, [None, None, None, None, None, None]),
+            ("reshape", None, [None, None, None, None, None, None]),
+            ("contrib.ethosu.pooling", "MAX", [0.011756093241274357, -43, None, None, 0.0, -43]),
+            (
+                "contrib.ethosu.binary_elementwise",
+                "SUB",
+                [0.011756093241274357, -43, 0.0, -43, 1.0, 127],
+            ),
+            ("contrib.ethosu.binary_elementwise", "SHR", [1.0, 0, 0.0, 0, 0.0, -43]),
+            ("contrib.ethosu.pooling", "SUM", [0.0, 0, None, None, 0.0, -43]),
+            ("contrib.ethosu.unary_elementwise", "CLZ", [0.0, 0, None, None, 0.0, -43]),
+            ("contrib.ethosu.binary_elementwise", "SUB", [0.0, 0, 0.0, 0, 0.0, -43]),
+            ("contrib.ethosu.binary_elementwise", "SHL", [0.0, 0, 0.0, 0, 0.0, -43]),
+            ("contrib.ethosu.binary_elementwise", "SUB", [0.0, 0, 0.0, 0, 0.0, -43]),
+            ("contrib.ethosu.binary_elementwise", "SHL", [0.0, 0, 0.0, 0, 0.0, -43]),
+            ("contrib.ethosu.binary_elementwise", "ADD", [0.0, 0, 0.0, 0, 1.0, 0]),
+            ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]),
+            ("contrib.ethosu.binary_elementwise", "ADD", [2.0, 0, 0.0, 0, 1.0, 0]),
+            ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]),
+            ("contrib.ethosu.binary_elementwise", "SUB", [2.0, 0, 0.0, 0, 1.0, 0]),
+            ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]),
+            ("contrib.ethosu.binary_elementwise", "MUL", [2.0, 0, 0.0, 0, 0.0, -43]),
+            ("contrib.ethosu.binary_elementwise", "ADD", [1.0, 0, 0.0, 0, 1.0, 0]),
+            ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]),
+            ("contrib.ethosu.binary_elementwise", "SUB", [2.0, 0, 0.0, 0, 1.0, 0]),
+            ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]),
+            ("contrib.ethosu.binary_elementwise", "MUL", [2.0, 0, 0.0, 0, 0.0, -43]),
+            ("contrib.ethosu.binary_elementwise", "ADD", [1.0, 0, 0.0, 0, 1.0, 0]),
+            ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]),
+            ("contrib.ethosu.binary_elementwise", "SUB", [2.0, 0, 0.0, 0, 1.0, 0]),
+            ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]),
+            ("contrib.ethosu.binary_elementwise", "MUL", [2.0, 0, 0.0, 0, 0.0, -43]),
+            ("contrib.ethosu.binary_elementwise", "ADD", [1.0, 0, 0.0, 0, 1.0, 0]),
+            ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 0.0, 0, 1.0, 0]),
+            ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0, 0]),
+            ("contrib.ethosu.binary_elementwise", "SUB", [0.0, 0, 0.0, 0, 0.0, -43]),
+            ("contrib.ethosu.binary_elementwise", "SHR", [2.0, 0, 0.0, 0, 0.00390625, -128]),
+            ("reshape", None, [None, None, None, None, None, None]),
         ]
 
+        def get_attr_value(op, attr_name):
+            if hasattr(op.attrs, attr_name):
+                return op.attrs[attr_name]
+            else:
+                return None
+
         def get_op_type(op):
             if hasattr(op.attrs, "pooling_type"):
                 return op.attrs.pooling_type
@@ -3599,6 +3610,16 @@ def test_tflite_softmax(ifm_shape):
                 return op.attrs.operator_type
             return None
 
+        def get_activation_params(op):
+            activation_params = []
+            activation_params.append(get_attr_value(op, "ifm_scale"))
+            activation_params.append(get_attr_value(op, "ifm_zero_point"))
+            activation_params.append(get_attr_value(op, "ifm2_scale"))
+            activation_params.append(get_attr_value(op, "ifm2_zero_point"))
+            activation_params.append(get_attr_value(op, "ofm_scale"))
+            activation_params.append(get_attr_value(op, "ofm_zero_point"))
+            return activation_params
+
         def _visit(stmt):
             if isinstance(stmt, relay.expr.Call):
                 ops.append(stmt)
@@ -3616,9 +3637,18 @@ def test_tflite_softmax(ifm_shape):
         assert ofm.dtype == dtype
 
         # check operations
-
-        ops = [(op.op.name, get_op_type(op)) for op in ops]
-        assert expected_ops == ops
+        for op, expected_op_params in zip(ops, expected_ops_params):
+            activation_params = get_activation_params(op)
+            expected_op_name, expected_op_type, expected_activation_params = expected_op_params
+            assert op.op.name == expected_op_name
+            assert expected_op_type == get_op_type(op)
+            for activation_param, expected_activation_param in zip(
+                activation_params, expected_activation_params
+            ):
+                if isinstance(activation_param, float):
+                    assert math.isclose(expected_activation_param, activation_param, abs_tol=1e-7)
+                else:
+                    assert expected_activation_param == activation_param
 
     softmax_pattern_table = [
         (

Reply via email to