vinx13 commented on code in PR #16895:
URL: https://github.com/apache/tvm/pull/16895#discussion_r1569209004
##########
python/tvm/relax/backend/contrib/cublas.py:
##########
@@ -68,11 +69,30 @@ def _check_matmul(context: PatternCheckContext) -> bool:
# Rows number must be multiples of 4 for IGEMM
return False
elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
- # Matrix dimensions must be multiples of 16. This requirement is missing from the cuBLAS
- # docs, but it was observed during testing.
- if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or rhs_shape[-1] % 16 != 0:
+ matmul_rhs_var = matmul_call.args[1]
+ rhs_transposed = False
+ if matmul_rhs_var in context.matched_bindings:
+ matmul_rhs_call = context.matched_bindings[matmul_rhs_var]
+ assert (
+ isinstance(matmul_rhs_call, tvm.relax.Call)
+ and matmul_rhs_call.op.name == "relax.permute_dims"
+ )
Review Comment:
The condition `if matmul_rhs_var in context.matched_bindings:` implies that rhs is transposed. The transposed-rhs pattern is the only one in which rhs is itself a matched binding. That is why I added the assertion here. It won't fire for a non-transposed rhs, because a non-transposed rhs never enters this branch.
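A minimal, self-contained sketch of the reasoning above. `Var`, `Call`, and `rhs_is_transposed` are hypothetical stand-ins, not the TVM API. Only the `"relax.permute_dims"` op name comes from the hunk. The sketch assumes the transposed-rhs pattern is the only one that records rhs in the matched bindings.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


# Hypothetical stand-ins for tvm.relax.Var / tvm.relax.Call (NOT the TVM API).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op_name: str
    args: Tuple[Var, ...]


def rhs_is_transposed(matmul_call: Call, matched_bindings: Dict[Var, Call]) -> bool:
    """Return True if the matmul's rhs is bound by the matched pattern."""
    rhs_var = matmul_call.args[1]
    if rhs_var in matched_bindings:
        rhs_call = matched_bindings[rhs_var]
        # Assumption: only the transposed-rhs pattern binds rhs, so any other
        # op here would indicate a bug in the pattern table.
        assert rhs_call.op_name == "relax.permute_dims"
        return True
    # A non-transposed rhs is a plain input, never a matched binding,
    # so it cannot reach the assertion above.
    return False


x, w, w_t = Var("x"), Var("w"), Var("w_t")
print(rhs_is_transposed(Call("relax.matmul", (x, w)), {}))
print(rhs_is_transposed(Call("relax.matmul", (x, w_t)),
                        {w_t: Call("relax.permute_dims", (w,))}))
```

The assertion documents an invariant of the pattern table rather than validating user input. A plain matmul skips it entirely and returns `False`.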
--
This is an automated message from the Apache Git Service.