yelite commented on code in PR #14166:
URL: https://github.com/apache/tvm/pull/14166#discussion_r1123535730
##########
python/tvm/relax/backend/contrib/cutlass.py:
##########
@@ -17,84 +17,141 @@
"""Pattern table for CUTLASS backend"""
-from tvm.relax import transform
+from typing import Mapping, Optional, Tuple
+
+import tvm
+from tvm.contrib.cutlass.build import is_valid_for_cutlass_matmul
+from tvm.relax import Call, Expr, ShapeExpr, transform
+from tvm.relax.dpl import DFPattern
from ..pattern_registry import get_patterns_with_prefix, register_patterns
from ..patterns import make_fused_bias_activation_pattern, make_matmul_pattern
+
+def _get_static_shape(shape: ShapeExpr) -> Optional[Tuple[int]]:
+ result = []
+ for dim in shape.values:
+ if isinstance(dim, tvm.tir.expr.IntImm):
+ result.append(int(dim))
+ else:
+ return None
+ return result
+
+
+def _is_supported_dtype(lhs_dtype, rhs_dtype):
+ """Check if dtypes in the given workload are supported by CUTLASS."""
+ return (
+ (lhs_dtype == "float16" and rhs_dtype == "float16")
+ or (lhs_dtype == "float32" and rhs_dtype == "float32")
+ or (lhs_dtype in ("int8", "uint8") and rhs_dtype in ("int8", "uint8"))
+ )
+
+
+def _check_matmul(
+ match_result: Mapping[DFPattern, Expr],
+ _: Expr,
+) -> bool:
+ matmul_call: Call = None
+ for _, expr in match_result.items():
+ if isinstance(expr, Call) and expr.op.name == "relax.matmul":
Review Comment:
I will fix this in my follow-up PR on dynamic shape support.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]