masahi commented on a change in pull request #9737:
URL: https://github.com/apache/tvm/pull/9737#discussion_r769075265
##########
File path: python/tvm/contrib/cutlass/gen_conv2d.py
##########
@@ -121,27 +131,67 @@ def get_default(self, out_dtype):
data_type = gemm_profile_result["data_type"]
return create_conv2d_operator([tile_description], data_type, [alignment])[0]
+ def check_align(self, op_name, C, K):
+ """Filter out kernels that cannot be supported."""
+ aligns = re.findall(r"align[1|2|4|8]", op_name)
+ assert len(aligns) == 1
+ align = int(aligns[0][-1])
+ return all([dim % align == 0 for dim in [C, K]])
+
def profile(
- self, d_shape, w_shape, out_shape, out_dtype, profile_all=True, use_multiprocessing=False
+ self,
+ d_shape,
+ w_shape,
+ padding,
+ stride,
+ dilation,
+ out_dtype,
+ profile_all=True,
+ use_multiprocessing=False,
):
"""Profile and select the best kernel from candidate kernels.
If profile_all is False, return immediately after the first applicable
kernel is found.
If use_multiprocessing is True, compile all profiler executables in
parallel.
"""
- B, _, _, IC = d_shape
+ N, H, W, IC = d_shape
OC, R, S, _ = w_shape
- _, P, Q, _ = out_shape
+ workload = (
+ N,
+ H,
+ W,
+ IC,
+ OC,
+ R,
+ S,
+ padding[0],
+ padding[1],
+ stride[0],
+ stride[1],
+ dilation[0],
+ dilation[1],
+ )
- M = B * P * Q
- N = OC
- K = R * S * IC
+ if workload in self.cache:
+ return self.cache[workload]
- gemm_profile_result = self.gemm_profiler.profile(
- M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
- )
+ ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
+ ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops))
- tile_description = gemm_profile_result["tile_description"]
- alignment = gemm_profile_result["alignment"]
- data_type = gemm_profile_result["data_type"]
+ if profile_all:
+ self.engine.compile_all(ops, use_multiprocessing)
- return create_conv2d_operator([tile_description], data_type, [alignment])[0]
+ args = (
+ "--n=%d --h=%d --w=%d --c=%d --k=%d --r=%d --s=%d --pad_h=%d --pad_w=%d "
+ "--stride_h=%d --stride_w=%d --dilation_h=%d --dilation_w=%d"
+ ) % workload
+
+ for op in ops:
+ out = self.engine.evaluate(op, args.split(" "))
+ op["runtime"] = out
+ if out > 0 and not profile_all:
Review comment:
Oops, you are right; changed to `out < float("inf")`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]