This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 32ed4f00de [Unity] Fix CUTLASS tests following LiftTransformParams signature change (#15707)
32ed4f00de is described below
commit 32ed4f00de6001bcaa2181ae55e4ae7f30474d6b
Author: masahi <[email protected]>
AuthorDate: Sat Sep 9 00:15:12 2023 +0900
[Unity] Fix CUTLASS tests following LiftTransformParams signature change (#15707)
#15657 modified the signature of a module generated by the LiftTransformParams
pass to take unpacked params as input rather than a tuple of params. Some
CUTLASS tests need updating.
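
In practice the change only affects how the tests invoke functions on the
Relax VM: parameters lifted by LiftTransformParams are now passed as separate
positional arguments instead of one tuple. A minimal sketch of the before and
after calling conventions, using placeholder names mirroring the test below
(`vm` is a relax.vm.VirtualMachine; the arguments are tvm.nd.NDArrays on the
target device):

    # Sketch only: illustrates the calling-convention change, not a complete
    # TVM program. Names mirror the test below and are illustrative.

    def run_old(vm, x_nd, packed_weight, scales, bias_trans):
        # Before #15657: lifted params were bundled into a single tuple argument.
        params = (packed_weight, scales, bias_trans)
        return vm["main"](x_nd, params).numpy()

    def run_new(vm, x_nd, packed_weight, scales, bias_trans):
        # After #15657: lifted params are passed as separate positional arguments.
        return vm["main"](x_nd, packed_weight, scales, bias_trans).numpy()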
---
tests/python/relax/test_codegen_cutlass.py | 12 +++++-------
1 file changed, 5 insertions(+), 7 deletions(-)
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 0c5b3ea9e0..e8d4e83521 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -1513,15 +1513,15 @@ def test_fp16A_int4B_gemm():
x_nd = tvm.nd.array(x, dev)
residual_nd = tvm.nd.array(residual, dev)
- params = (packed_weight.copyto(dev), scales.copyto(dev), bias_trans.copyto(dev))
+ params = [packed_weight.copyto(dev), scales.copyto(dev), bias_trans.copyto(dev)]
for f_name in ["main_bias", "main_cast_bias", "main_residual"]:
with_residual = "residual" in f_name
if with_residual:
- inp = [x_nd, residual_nd, params]
+ inp = [x_nd, residual_nd] + params
else:
- inp = [x_nd, params]
+ inp = [x_nd] + params
out = vm[f_name](*inp).numpy()
@@ -1665,8 +1665,7 @@ def test_fp16A_int8B_gemm():
vm = relax.vm.VirtualMachine(ex, dev)
x_nd = tvm.nd.array(x, dev)
- params = (packed_weight.copyto(dev), scales.copyto(dev), bias_trans.copyto(dev))
- inp = [x_nd, params]
+ inp = [x_nd, packed_weight.copyto(dev), scales.copyto(dev), bias_trans.copyto(dev)]
out = vm["main"](*inp).numpy()
def gelu_fp16(x):
@@ -1940,8 +1939,7 @@ def test_fp16A_int8B_gemm_batched():
vm = relax.vm.VirtualMachine(ex, dev)
x_nd = tvm.nd.array(x, dev)
- params = (packed_weight.copyto(dev), scales.copyto(dev))
- inp = [x_nd, params]
+ inp = [x_nd, packed_weight.copyto(dev), scales.copyto(dev)]
out = vm["main"](*inp).numpy()
ref = np.dot(x, y.transpose())
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)