rasagna-quic commented on code in PR #14210:
URL: https://github.com/apache/tvm/pull/14210#discussion_r1133026510


##########
python/tvm/contrib/hexagon/transform.py:
##########
@@ -148,3 +151,163 @@ def transform(func, mod, ctx):
 
 def ir_lower_vtcm_pass():
     return [(3, ir_lower_vtcm())]
+
+
+class qdistilbert_rewrite(DFPatternCallback):
+    """
+    A callback to replace the below pattern:
+    Pattern:
+    %35 = strided_slice(%34, begin=[0, 0, 0], end=[1, 128, 64], strides=[1, 1, 
1], axes=None);
+    %44 = reshape(%35, newshape=[-1, 64]);
+    <snip>
+    %42 = strided_slice(%41, begin=[0, 0, 0], end=[1, 64, 128], strides=[1, 1, 
1], axes=None);
+    %43 = reshape(%42, newshape=[64, 128]);
+    %45 = transpose(%43, axes=[1, 0]);
+    <snip>
+    %46 = qnn.dense(%44, %45, 13, 1, 0.0541715f, 0.0489368f, units=None, 
out_dtype="int32");
+    %47 = qnn.requantize(%46, 0.00265098f, 0, 0.728874f, -14, axis=1, 
out_dtype="int8");
+    <snip>
+    %125 = expand_dims(%47, axis=0) /* ty=Tensor[(1, 128, 128), int8] */;
+    < The above pattern repeats 12 times, which is the batch size >
+
+    %137 = (%125, %126, %127, %128, %129, %130, %131, %132, %133, %134, %135, 
%136);
+    %138 = concatenate(%137);
+
+    """
+
+    def __init__(self):
+        super(qdistilbert_rewrite, self).__init__()
+        self.A = wildcard()  # Tensor A
+        self.B = wildcard()  # Tensor B
+        self.batch = 12  # Number of time pattern repeats or Batch size
+
+        self.d = []  # List of dense quantization parameters
+        self.q = []  # List of requantize parameters
+        L = []  # List of patterns
+
+        z = tvm.tir.IntImm("int64", 0)
+        s1 = tvm.tir.IntImm("int64", 1)
+
+        for i in range(self.batch):
+            x = tvm.tir.IntImm("int64", i)
+
+            self.d.append([is_constant(), is_constant(), is_constant(), 
is_constant()])
+            self.q.append([is_constant(), is_constant(), is_constant(), 
is_constant()])
+
+            pat_a = is_op("strided_slice")(self.A).has_attr(
+                {"begin": [x, z, z], "strides": [s1, s1, s1]}
+            )
+            pat_a = is_op("reshape")(pat_a)
+
+            pat_b = is_op("strided_slice")(self.B).has_attr(
+                {"begin": [x, z, z], "strides": [s1, s1, s1]}
+            )
+            pat_b = is_op("reshape")(pat_b)
+            pat_b = is_op("transpose")(pat_b)
+
+            pat = is_op("qnn.dense")(
+                pat_a, pat_b, self.d[i][0], self.d[i][1], self.d[i][2], 
self.d[i][3]
+            )
+            pat = is_op("qnn.requantize")(
+                pat, self.q[i][0], self.q[i][1], self.q[i][2], self.q[i][3]
+            )
+            pat = is_op("expand_dims")(pat)
+            L.append(pat)
+
+        T = is_tuple(L)
+        self.pattern = is_op("concatenate")(T)
+
+    def check_quant_params(self, node_map):
+        """checking if dense and requant params are the same across patterns"""
+        r = self.batch
+        x1 = [node_map[self.d[0][i]][0].data.numpy().item() for i in range(4)]
+        x2 = [node_map[self.q[0][i]][0].data.numpy().item() for i in range(4)]
+        for i in range(1, r):
+            for j in range(4):
+                y1 = node_map[self.d[i][j]][0].data.numpy().item()
+                y2 = node_map[self.q[i][j]][0].data.numpy().item()
+                if x1[j] != y1 or x2[j] != y2:
+                    return False
+        return True
+
+    def callback(self, pre, post, node_map):
+        A = node_map[self.A][0]
+        B = node_map[self.B][0]
+
+        if not self.check_quant_params(node_map):
+            return post
+
+        [a0, a1, a2] = [0, 0, 0]  # Tensor A shape
+        [b0, b1, b2] = [0, 0, 0]  # Tensor B shape
+
+        if isinstance(A, relay.expr.Call) and isinstance(B, relay.expr.Call):
+            if A.checked_type is None or B.checked_type is None:
+                # Need infer pass to be run before this pass
+                return post
+            if len(A.checked_type.shape) == 3 and len(B.checked_type.shape) == 
3:
+                [a0, a1, a2] = A.checked_type.shape
+                [b0, b1, b2] = B.checked_type.shape
+
+        if isinstance(A, relay.Var) and isinstance(B, relay.Var):
+            if len(A.type_annotation.shape) == 3 and 
len(B.type_annotation.shape) == 3:
+                [a0, a1, a2] = A.type_annotation.shape
+                [b0, b1, b2] = B.type_annotation.shape
+
+        # Check if the batch size is same as expected tensor size
+        if (a0 != self.batch) or (b0 != self.batch):
+            return post
+
+        for i in range(self.batch):
+            # end=(x, pa1, pa2) attribute of strided_slice for Tensor A
+            pa1 = 
pre.args[0][i].args[0].args[0].args[0].args[0].attrs.end[1].value
+            pa2 = 
pre.args[0][i].args[0].args[0].args[0].args[0].attrs.end[2].value
+
+            # end=(x, pb1, pb2) attribute of strided_slice for Tensor B
+            pb1 = 
pre.args[0][i].args[0].args[0].args[1].args[0].args[0].attrs.end[1].value
+            pb2 = 
pre.args[0][i].args[0].args[0].args[1].args[0].args[0].attrs.end[2].value
+
+            if a1 != pa1 or a2 != pa2 or b1 != pb1 or b2 != pb2:
+                return post
+
+        d = [node_map[self.d[0][i]][0] for i in range(4)]
+        q = [node_map[self.q[0][i]][0] for i in range(4)]
+
+        out = relay.op.transpose(B, axes=[0, 2, 1])
+        out = relay.qnn.op.batch_matmul(A, out, d[0], d[1], d[2], d[3], 
out_dtype="int32")
+        out = relay.qnn.op.requantize(out, q[0], q[1], q[2], q[3], 
out_dtype="int8")
+        return out
+
+
+def rewrite_qdistilbert(mod):
+    """Rewrite the Quantized Distilbert to reduce computational complexity."""
+    mod["main"] = rewrite(qdistilbert_rewrite(), mod["main"])
+    return mod
+
+
+class remove_empty_pad_callback(DFPatternCallback):

Review Comment:
   Thank you, @masahi, for your quick review and help. I will keep the above 
changes in mind when I create a new PR.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to