arina-grovety commented on code in PR #14765:
URL: https://github.com/apache/tvm/pull/14765#discussion_r1196374217


##########
tests/python/contrib/test_ethosu/test_legalize.py:
##########
@@ -462,6 +463,118 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+def test_tflite_conv2d_with_separate_channel_padding_legalize():
+    dtype = "int8"
+    ifm_shape = (1, 55, 34, 3)
+    kernel_shape = (3, 2)
+    strides = (1, 1)
+    dilation = (2, 1)
+    padding_ch = (1, 1)
+
+    class ArePadOnGraph(ExprVisitor):
+        """
+        Visits the Graph recursively and checks if it contains 'nn.pad' op
+        """
+
+        def __init__(self):
+            ExprVisitor.__init__(self)
+            self.on_graph = False
+
+        def visit_call(self, call):
+            if isinstance(call.op, tvm.ir.Op):
+                if str(call.op.name) == "nn.pad":
+                    self.on_graph = True
+
+            return super().visit_call(call)
+
+        def are_pad_on_graph(self, subgraph) -> bool:
+            """
+            This function recursively visits the graph and checks if 'nn.pad' op is on the graph
+            """
+            self.visit(subgraph)
+            return self.on_graph
+
+    def create_tflite_graph_single():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                tf_strides = [1, strides[0], strides[1], 1]
+                op = tf.pad(
+                    x,
+                    [[0, 0], [0, 0], [0, 0], [padding_ch[0], padding_ch[1]]],
+                    "CONSTANT",
+                )
+                # HWIO
+                weight_shape = [
+                    kernel_shape[0],
+                    kernel_shape[1],
+                    ifm_shape[3] + padding_ch[0] + padding_ch[1],
+                    3,
+                ]
+                weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                return tf.nn.conv2d(
+                    op,
+                    weight,
+                    strides=tf_strides,
+                    padding="VALID",
+                    dilations=dilation,
+                )
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+
+        assert ArePadOnGraph().are_pad_on_graph(ext_func.body) == True
+
+    conv2d_pattern_table = [
+        (
+            ethosu.ChannelPadParams.composite_name,
+            ethosu.pad_pattern(),
+            lambda pat: ethosu.ChannelPadParams(pat).is_valid(),
+        ),
+        (
+            ethosu.QnnConv2DParams.composite_name,
+            ethosu.qnn_conv2d_pattern(),
+            lambda pat: ethosu.QnnConv2DParams(pat).is_valid(),
+        ),
+    ]
+
+    tflite_graph = create_tflite_graph_single()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, conv_params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], conv_params)
+    mod = partition_ethosu_by_table(mod, conv2d_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.Conv2DRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+
+    verify(mod["tvmgen_default_ethos_u_main_0"])

Review Comment:
   Hi @ekalda, thank you for the review!
   Yes, it is. This test checks that a pad along the channel axis is not merged into conv2d, unlike a spatial pad, which is merged.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to