vinx13 commented on code in PR #14166:
URL: https://github.com/apache/tvm/pull/14166#discussion_r1122437329


##########
python/tvm/relax/backend/contrib/cutlass.py:
##########
@@ -17,84 +17,141 @@
 
 """Pattern table for CUTLASS backend"""
 
-from tvm.relax import transform
+from typing import Mapping, Optional, Tuple
+
+import tvm
+from tvm.contrib.cutlass.build import is_valid_for_cutlass_matmul
+from tvm.relax import Call, Expr, ShapeExpr, transform
+from tvm.relax.dpl import DFPattern
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
 from ..patterns import make_fused_bias_activation_pattern, make_matmul_pattern
 
+
+def _get_static_shape(shape: ShapeExpr) -> Optional[Tuple[int]]:
+    result = []
+    for dim in shape.values:
+        if isinstance(dim, tvm.tir.expr.IntImm):
+            result.append(int(dim))
+        else:
+            return None
+    return result
+
+
+def _is_supported_dtype(lhs_dtype, rhs_dtype):
+    """Check if dtypes in the given workload are supported by CUTLASS."""
+    return (
+        (lhs_dtype == "float16" and rhs_dtype == "float16")
+        or (lhs_dtype == "float32" and rhs_dtype == "float32")
+        or (lhs_dtype in ("int8", "uint8") and rhs_dtype in ("int8", "uint8"))
+    )
+
+
+def _check_matmul(
+    match_result: Mapping[DFPattern, Expr],
+    _: Expr,
+) -> bool:
+    matmul_call: Call = None
+    for _, expr in match_result.items():
+        if isinstance(expr, Call) and expr.op.name == "relax.matmul":

Review Comment:
   ```suggestion
           if isinstance(expr, Call) and isinstance(expr.op, tvm.ir.Op) and expr.op.name == "relax.matmul":
   ```
   It's possible the call target is a PrimFunc, in which case `expr.op` will be a GlobalVar rather than an Op, so the `isinstance` check is needed before accessing `.name`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to