tqchen commented on code in PR #16052:
URL: https://github.com/apache/tvm/pull/16052#discussion_r1436460198


##########
python/tvm/topi/nn/dense.py:
##########
@@ -65,86 +65,103 @@ def matmul(
     output : tvm.te.Tensor
         2-D with shape [batch, out_dim]
     """
-    # TODO(jcf94): Add multi-dim support for tensor_a
-    assert len(tensor_a.shape) == 2, "only support 2-dim matmul"
+    # TODO(yixin): support cases for 1-dim input
+    # TODO(yixin): adding support and further check for >2-dim input in autotvm template
+    assert (
+        len(tensor_a.shape) >= 2 and len(tensor_b.shape) >= 2
+    ), "1-dim matmul is not supported yet."
     if bias is not None:
         assert len(bias.shape) == 1
     if out_dtype is None:
         out_dtype = tensor_a.dtype
     if transpose_a:
-        in_dim, batch = tensor_a.shape
+        reduce_dim_a, in_dim = tensor_a.shape[-2:]
     else:
-        batch, in_dim = tensor_a.shape
+        in_dim, reduce_dim_a = tensor_a.shape[-2:]
+    batch_dims_a = tensor_a.shape[:-2]
 
     if auto_scheduler_rewritten_layout:
         # Infer shape for the rewritten layout
-        out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
+        assert len(tensor_b).shape == 2, "only support 2-dim matmul when using auto-scheduler"

Review Comment:
   cc @Ubospica please followup
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

Reply via email to