lhutton1 commented on a change in pull request #9442:
URL: https://github.com/apache/tvm/pull/9442#discussion_r742958078



##########
File path: tests/python/contrib/test_ethosu/test_legalize.py
##########
@@ -558,5 +558,193 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethosu_main_0"])
 
 
+@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"])
+@pytest.mark.parametrize(
+    "ifm_shape, ifm2_shape, reversed_operands",
+    [
+        ([1, 2, 3, 4], [1, 2, 3, 4], False),
+        ([1, 2, 3, 4], [1, 1, 3, 1], False),
+        ([1, 1, 3, 1], [1, 2, 3, 4], True),
+    ],
+)
+@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
+def test_tflite_binary_elemwise_legalize(
+    operator_type,
+    ifm_shape,
+    ifm2_shape,
+    reversed_operands,
+    activation_function,
+):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x, y):
+                if operator_type == "ADD":
+                    op = tf.math.add(x, y)
+                elif operator_type == "SUB":
+                    op = tf.math.subtract(x, y)
+                elif operator_type == "MUL":
+                    op = tf.math.multiply(x, y)
+                elif operator_type == "MIN":
+                    op = tf.math.minimum(x, y)
+                elif operator_type == "MAX":
+                    op = tf.math.maximum(x, y)
+                if activation_function == "RELU":
+                    op = tf.nn.relu(op)
+                return op
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32), tf.TensorSpec(ifm2_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                data2 = np.random.rand(*tuple(ifm2_shape)) * 2
+                yield [data.astype(np.float32), data2.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        out_shape = ifm2_shape if reversed_operands else ifm_shape
+        shapes = [ifm_shape, ifm2_shape]
+        ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1)
+        op = ext_func.body
+        assert list(op.args[0].checked_type.shape) == shapes[ifm_index]
+        assert list(op.args[1].checked_type.shape) == shapes[ifm2_index]
+        assert op.args[0].checked_type.dtype == dtype
+        assert list(op.checked_type.shape) == out_shape
+        assert op.checked_type.dtype == dtype
+        assert op.attrs.operator_type == operator_type
+        assert op.attrs.reversed_operands == reversed_operands
+        if activation_function == "RELU":
+            assert str(op.attrs.activation) == "CLIP"
+
+    if operator_type == "ADD":
+        rewriter = legalize.AddRewriter()
+        pattern_table = [
+            (
+                ethosu.AddParams.composite_name,
+                ethosu.qnn_add_pattern(),
+                lambda pat: ethosu.AddParams(pat).is_valid(),
+            ),
+        ]
+    elif operator_type == "SUB":
+        rewriter = legalize.SubRewriter()
+        pattern_table = [
+            (
+                ethosu.SubParams.composite_name,
+                ethosu.qnn_subtract_pattern(),
+                lambda pat: ethosu.SubParams(pat).is_valid(),
+            ),
+        ]
+    elif operator_type == "MUL":
+        rewriter = legalize.MulRewriter()
+        pattern_table = [
+            (
+                ethosu.MulParams.composite_name,
+                ethosu.qnn_mul_pattern(),
+                lambda pat: ethosu.MulParams(pat).is_valid(),
+            ),
+        ]
+    elif operator_type == "MIN":
+        rewriter = legalize.MinRewriter()
+        pattern_table = [
+            (
+                ethosu.MinParams.composite_name,
+                ethosu.minimum_pattern(),
+                lambda pat: ethosu.MinParams(pat).is_valid(),
+            ),
+        ]
+    elif operator_type == "MAX":
+        rewriter = legalize.MaxRewriter()
+        pattern_table = [
+            (
+                ethosu.MaxParams.composite_name,
+                ethosu.maximum_pattern(),
+                lambda pat: ethosu.MaxParams(pat).is_valid(),
+            ),
+        ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, _ = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"x": ifm_shape, "y": ifm2_shape},
+        dtype_dict={"x": dtype, "y": dtype},
+    )
+    mod = partition_ethosu_by_table(mod, pattern_table)
+
+    mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite(
+        rewriter, mod["tvmgen_default_ethosu_main_0"]
+    )
+    verify(mod["tvmgen_default_ethosu_main_0"])
+
+
+@pytest.mark.parametrize(
+    "ifm_shape, ifm2_shape, reversed_operands",
+    [
+        ([1, 2, 3, 4], [1, 2, 3, 4], False),
+        ([1, 2, 3, 4], [1, 1, 3, 1], False),
+        ([1, 1, 3, 1], [1, 2, 3, 4], True),
+    ],
+)
+def test_ethosu_left_shift_binary_elemwise_legalize(ifm_shape, ifm2_shape, reversed_operands):
+    dtype = "int32"
+    operator_type = "SHL"
+
+    def create_graph():
+        input1 = relay.var("x1", shape=ifm_shape, dtype=dtype)
+        input2 = relay.var("x2", shape=ifm2_shape, dtype=dtype)
+        c1 = relay.left_shift(input1, input2)

Review comment:
       IIRC the difference is due to Relay using a `floor` while the NPU does 
something semantically equivalent to a `ceil`. If that's the case, can we not 
add some trickery when legalizing a right-shift operator directly? E.g. 
legalize a right shift to a right shift followed by an element-wise subtract of 1?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to