jcf94 commented on a change in pull request #7161:
URL: https://github.com/apache/tvm/pull/7161#discussion_r548502307



##########
File path: python/tvm/relay/op/strategy/x86.py
##########
@@ -364,9 +373,9 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
 def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
     """batch_matmul x86 strategy"""
     strategy = _op.OpStrategy()
-    if is_dynamic(out_type):
+    if is_dynamic(out_type) or is_auto_scheduler_enabled():

Review comment:
       I feel confused about the condition here. Does `is_dynamic(out_type)` just mean the output shape has not been computed yet, rather than that this is a dynamic-shape op?
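To make the distinction in the question concrete, here is a pure-Python sketch that treats "dynamic" as a shape with at least one dimension unknown at compile time. The unknown dimension is represented by a hypothetical `ANY` sentinel, loosely in the spirit of Relay's `Any`. Whether TVM's `is_dynamic(out_type)` has exactly these semantics is the open question in the comment. `ANY`, `is_dynamic_shape`, `pick_batch_matmul_impl`, and the returned implementation names are illustrative assumptions, not TVM APIs.

```python
from typing import Optional, Sequence

# Hypothetical stand-in for a dimension that is unknown at compile time.
# This is NOT a TVM object; it only models the idea of a dynamic dimension.
ANY = None


def is_dynamic_shape(shape: Sequence[Optional[int]]) -> bool:
    """Return True if any dimension of the shape is unknown (ANY)."""
    return any(dim is ANY for dim in shape)


def pick_batch_matmul_impl(out_shape: Sequence[Optional[int]],
                           auto_scheduler_enabled: bool) -> str:
    """Mirror the branch in the diff: use the generic compute when the output
    shape is dynamic or the auto-scheduler is enabled, else the x86 one.

    The returned names are illustrative labels only.
    """
    if is_dynamic_shape(out_shape) or auto_scheduler_enabled:
        return "batch_matmul.generic"
    return "batch_matmul.x86"
```

With this model, `(4, ANY, 64)` always routes to the generic path, while a fully known shape such as `(4, 32, 64)` takes the x86 path unless the auto-scheduler is on. Under this reading, a shape that is merely "not computed yet" would not count as dynamic.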

##########
File path: python/tvm/topi/nn/batch_matmul.py
##########
@@ -36,14 +36,25 @@ def batch_matmul(x, y, oshape=None):
         Explicit intended output shape of the computation. Can be useful in cases
         with dynamic input shapes.
 
+    auto_scheduler_rewritten_layout: str = ""
+        The layout after auto-scheduler's layout rewrite pass.
+
     Returns
     -------
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
     assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
     x_shape = get_const_tuple(x.shape)
-    y_shape = get_const_tuple(y.shape)
+    if auto_scheduler_rewritten_layout:
+        # Infer shape for the rewritten layout
+        y_shape = auto_scheduler.get_shape_from_rewritten_layout(
+            auto_scheduler_rewritten_layout, ["b", "j", "k"]
+        )
+        auto_scheduler.remove_index_check(y)
+    else:
+        y_shape = get_const_tuple(y.shape)
+

Review comment:
       Did you forget to add `auto_scheduler.rewrite_compute_body` at the end of this function?
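For readers following the shape-inference branch in the diff, here is a minimal pure-Python sketch of what recovering `y_shape` from a rewritten layout involves. It assumes, hypothetically, that the layout string is a concatenation of `<extent><axis>` tokens (for example `"2b4j8k2j"`). Under that assumption, a tiled axis can appear more than once, and its original extent is the product of its tile extents. `infer_shape_from_layout` is an illustrative name, not TVM's `auto_scheduler.get_shape_from_rewritten_layout`, whose actual layout format and return type may differ.

```python
import re
from math import prod
from typing import List


def infer_shape_from_layout(layout: str, axis_names: List[str]) -> List[int]:
    """Recover the original extent of each named axis from a tiled layout.

    Hypothetical format: concatenated "<extent><axis>" tokens, e.g. "2b4j8k2j",
    where "j" is split into two tiles of extent 4 and 2.
    """
    tokens = re.findall(r"(\d+)([A-Za-z_]+)", layout)
    # Reject strings containing anything other than well-formed tokens.
    if "".join(extent + axis for extent, axis in tokens) != layout:
        raise ValueError(f"invalid layout: {layout!r}")
    shape = []
    for name in axis_names:
        extents = [int(extent) for extent, axis in tokens if axis == name]
        if not extents:
            raise ValueError(f"axis {name!r} not found in layout {layout!r}")
        # A tiled axis appears once per tile level; multiply the tiles back.
        shape.append(prod(extents))
    return shape
```

For example, `infer_shape_from_layout("2b4j8k2j", ["b", "j", "k"])` gives `[2, 8, 8]`. This covers only the shape-inference half. It does not model the compute-body rewrite the reviewer is asking about.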




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.