ashutosh-arm commented on a change in pull request #9331:
URL: https://github.com/apache/tvm/pull/9331#discussion_r734653298



##########
File path: python/tvm/relay/op/contrib/cmsisnn.py
##########
@@ -47,42 +47,93 @@ def partition_for_cmsisnn(mod, params=None, **opts):
     if params:
         mod["main"] = bind_params_by_name(mod["main"], params)
 
+    tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)
+
     seq = tvm.transform.Sequential(
         [
             transform.InferType(),
             transform.MergeComposite(pattern_table()),
             transform.AnnotateTarget("cmsisnn"),
-            transform.MergeCompilerRegions(),
             transform.PartitionGraph(),
+            GenerateCMSISNNConstants(),
+            ExtractConstantsFromPartitionedFunction(),
+            transform.InferType(),
         ]
     )
-
     return seq(mod)
 
 
 @register_pattern_table("cmsisnn")
 def pattern_table():
     """Get the cmsisnn compiler pattern table."""
 
-    def softmax_pattern():
+    def qnn_softmax_pattern():
+        """Create pattern for quantized softmax"""
         pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
         pattern = is_op("nn.softmax")(pattern)
         pattern = is_op("qnn.quantize")(pattern, is_constant(), is_constant())
         return pattern
 
-    def check_quantized_softmax(extract):
+    def check_qnn_softmax(pattern):
         """Check if softmax is supported by CMSIS-NN."""
-        dequantize_call = extract.args[0].args[0]
-        scale = extract.args[1].data.numpy().item(0)
-        zero_point = extract.args[2].data.numpy().item(0)
+        dequantize_call = pattern.args[0].args[0]
+        scale = pattern.args[1].data.numpy().item(0)
+        zero_point = pattern.args[2].data.numpy().item(0)
 
         # check for dtypes of quantize and dequantize
         return (
             (scale == 1.0 / 256 and zero_point == -128)
-            and extract.attrs.out_dtype == "int8"
+            and pattern.attrs.out_dtype == "int8"
             and dequantize_call.args[0].checked_type.dtype == "int8"
         )
 
+    def qnn_conv2d_pattern():
+        """Create pattern for qnn.conv2D with optional fused relu."""
+        qnn_conv2d = is_op("qnn.conv2d")(
+            wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
+        ).has_attr({"kernel_layout": "HWIO"})
+        bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant())
+        req = is_op("qnn.requantize")(
+            qnn_conv2d | bias_add, is_constant(), is_constant(), is_constant(), is_constant()
+        )
+        clip_or_req = req.optional(is_op("clip"))
+        return clip_or_req
+
+    def check_qnn_conv2d(pattern):
+        """Check if the Conv2D is supported by CMSIS-NN."""
+        if str(pattern.op.name) == "clip":
+            relu = pattern
+            requantize = relu.args[0]
+        else:
+            requantize = pattern
+        requantize_input = requantize.args[0]
+        bias_add = None
+        bias_dtype = "int32"
+        if str(requantize_input.op.name) == "nn.bias_add":
+            bias_add = requantize_input
+            conv2d = bias_add.args[0]
+            bias_dtype = bias_add.args[1].checked_type.dtype
+        else:
+            conv2d = requantize_input
+        conv2d_input = conv2d.args[0]
+        conv2d_weight = conv2d.args[1]
+
+        # kernel zero_point should be 0
+        kernel_zp = conv2d.args[3].data.numpy()
+        kernel_zp = [kernel_zp] if kernel_zp.ndim == 0 else kernel_zp
+
+        return (
+            conv2d.attrs.kernel_layout == "HWIO"

Review comment:
       Yes, in the end we convert it into OHWI, but at the moment only from HWIO. I had assumed that frontends always convert Conv2D kernel layouts to HWIO, but after rechecking ```Conv2DRel```, that does not seem to be the case.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to