jpf888 commented on issue #17176:
URL: https://github.com/apache/tvm/issues/17176#issuecomment-2261794051
@Lunderberg
1. When I apply dispatch in MLC-LLM, this issue occurs. When the
pattern matching is applied, the call to the fused kernel is generated as
`R.call_dps_packed("fused_relax_nn_conv2d_cudnn", args)` in the class `Module`.
2. When run as a TVM test case, it works fine. Log:
**before pattern** :
from tvm.script import ir as I
from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def main(data: R.Tensor((16, 32, 32, 16), dtype="float16"), weight:
R.Tensor((32, 3, 3, 16), dtype="float16")) -> R.Tensor((16, 32, 32, 32),
dtype="float16"):
with R.dataflow():
lv: R.Tensor((16, 32, 32, 32), dtype="float16") =
R.nn.conv2d(data, weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1,
1], groups=1, data_layout="NHWC", kernel_layout="OHWI", out_layout="NHWC",
out_dtype="float16")
R.output(lv)
return lv
**after pattern:**
from tvm.script import ir as I
from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def fused_relax_nn_conv2d_cudnn(data: R.Tensor((16, 32, 32, 16),
dtype="float16"), weight: R.Tensor((32, 3, 3, 16), dtype="float16")) ->
R.Tensor((16, 32, 32, 32), dtype="float16"):
R.func_attr({"Codegen": "cudnn"})
# from tvm.script import relax as R
@R.function
def local_func(data_1: R.Tensor((16, 32, 32, 16),
dtype="float16"), weight_1: R.Tensor((32, 3, 3, 16), dtype="float16")) ->
R.Tensor((16, 32, 32, 32), dtype="float16"):
R.func_attr({"Composite": "cudnn.conv2d.nhwc_ohwi"})
with R.dataflow():
gv: R.Tensor((16, 32, 32, 32), dtype="float16") =
R.nn.conv2d(data_1, weight_1, strides=[1, 1], padding=[1, 1, 1, 1],
dilation=[1, 1], groups=1, data_layout="NHWC", kernel_layout="OHWI",
out_layout="NHWC", out_dtype="float16")
R.output(gv)
return gv
output: R.Tensor((16, 32, 32, 32), dtype="float16") =
local_func(data, weight)
return output
@R.function
def main(data: R.Tensor((16, 32, 32, 16), dtype="float16"),
weight: R.Tensor((32, 3, 3, 16), dtype="float16")) -> R.Tensor((16, 32, 32,
32), dtype="float16"):
cls = Module
with R.dataflow():
gv: R.Tensor((16, 32, 32, 32), dtype="float16") =
cls.fused_relax_nn_conv2d_cudnn(data, weight)
R.output(gv)
return gv
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]