vinx13 commented on code in PR #14166:
URL: https://github.com/apache/tvm/pull/14166#discussion_r1122276298


##########
python/tvm/relax/backend/contrib/cutlass.py:
##########
@@ -17,84 +17,132 @@
 
 """Pattern table for CUTLASS backend"""
 
-from tvm.relax import transform
+from typing import Mapping, Optional, Tuple
+
+import tvm
+from tvm.contrib.cutlass.build import is_valid_for_cutlass_matmul
+from tvm.relax import Call, Expr, ShapeExpr, transform
+from tvm.relax.dpl import DFPattern
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
 from ..patterns import make_fused_bias_activation_pattern, make_matmul_pattern
 
+
+def _get_static_shape(shape: ShapeExpr) -> Optional[Tuple[int]]:
+    result = []
+    for dim in shape.values:
+        if isinstance(dim, tvm.tir.expr.IntImm):
+            result.append(int(dim))
+        else:
+            return None
+    return result
+
+
+def _check_matmul(
+    match_result: Mapping[DFPattern, Expr],
+    _: Expr,
+) -> bool:
+    matmul_call: Call = None
+    for _, expr in match_result.items():
+        if isinstance(expr, Call) and expr.op.name == "relax.matmul":
+            matmul_call = expr
+    if matmul_call is None:
+        raise ValueError("Cannot find call to matmul from match_result.")
+
+    lhs_shape = _get_static_shape(matmul_call.args[0].struct_info.shape)
+    rhs_shape = _get_static_shape(matmul_call.args[1].struct_info.shape)
+    if len(lhs_shape) < 2 or len(rhs_shape) < 2:
+        return False
+
+    lhs_dtype = matmul_call.args[0].struct_info.dtype
+    rhs_dtype = matmul_call.args[1].struct_info.dtype
+    if lhs_dtype != "float16" or rhs_dtype != "float16":

Review Comment:
   float32 (fp32) is also supported, so this dtype check does not need to be limited to float16 (if fp32 turns out not to work yet, let's revisit it later).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to