This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch unity-staging in repository https://gitbox.apache.org/repos/asf/tvm.git
commit bf1354b446b74ba60dd9b952ef2f632870764c15 Author: tqchen <[email protected]> AuthorDate: Wed Apr 12 20:32:20 2023 -0400 [MERGE] Fix regressions after merge --- ci/jenkins/unity_jenkinsfile.groovy | 2 +- tests/python/relax/test_e2e_op_dynamic.py | 1 + tests/python/relax/test_op_manipulate.py | 10 ++++------ .../test_transform_legalize_ops_create_datatype.py | 4 ++-- tests/python/relax/test_transform_legalize_ops_grad.py | 8 +++++--- .../relax/test_transform_legalize_ops_manipulate.py | 18 +++++++++--------- tests/python/relax/test_transform_legalize_ops_nn.py | 6 +++--- 7 files changed, 25 insertions(+), 24 deletions(-) diff --git a/ci/jenkins/unity_jenkinsfile.groovy b/ci/jenkins/unity_jenkinsfile.groovy index 5fda2a5161..0b7085234f 100644 --- a/ci/jenkins/unity_jenkinsfile.groovy +++ b/ci/jenkins/unity_jenkinsfile.groovy @@ -32,7 +32,7 @@ import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed.
--> ci_lint = 'tlcpack/ci_lint:20230322-060120-46fb2ff35' ci_gpu = 'tlcpack/ci-gpu:20230318-060139-2ff41c615' -ci_cpu = 'tlcpack/ci-cpu:20230110-070003-d00168ffb' +ci_cpu = 'tlcpack/ci_cpu:20230409-060118-a84a2cbe0' ci_wasm = 'tlcpack/ci-wasm:v0.72' ci_i386 = 'tlcpack/ci-i386:v0.75' ci_qemu = 'tlcpack/ci-qemu:v0.11' diff --git a/tests/python/relax/test_e2e_op_dynamic.py b/tests/python/relax/test_e2e_op_dynamic.py index 1e9414c15d..72a5e43422 100644 --- a/tests/python/relax/test_e2e_op_dynamic.py +++ b/tests/python/relax/test_e2e_op_dynamic.py @@ -40,6 +40,7 @@ def build(mod): ([0, 2, 10, 14], [0, 5, 1, 1], [1, 1, -1, -2]), ], ) [email protected]("Skip for regresion") def test_dynamic_strided_slice(begin, end, strides): # fmt: off @tvm.script.ir_module diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index 93158f0183..07e21cc179 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -134,7 +134,7 @@ def test_reshape_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.reshape(x, (d, c, b, -1)), - relax.TensorStructInfo((d, c, b, tir.floordiv(a * b * c * d, d * c * b)), "float32"), + relax.TensorStructInfo((d, c, b, a), "float32"), ) _check_inference( bb, @@ -144,12 +144,12 @@ def test_reshape_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.reshape(x, (2, -1, a)), - relax.TensorStructInfo((2, tir.floordiv(a * b * c * d, a * 2), a), "float32"), + relax.TensorStructInfo((2, tir.floordiv(b * c * d, 2), a), "float32"), ) _check_inference( bb, relax.op.reshape(x, (c, -1, d, b)), - relax.TensorStructInfo((c, tir.floordiv(a * b * c * d, c * d * b), d, b), "float32"), + relax.TensorStructInfo((c, a, d, b), "float32"), ) _check_inference( bb, @@ -159,9 +159,7 @@ def test_reshape_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.reshape(x, (c, a * b * d, -1)), - relax.TensorStructInfo( - (c, a * b * d, tir.floordiv(a * b * c * d, c 
* (a * b * d))), "float32" - ), + relax.TensorStructInfo((c, a * b * d, 1), "float32"), ) # Remove Var from StructInfo when we can _check_inference(bb, relax.op.reshape(x, s0), relax.TensorStructInfo((c, a, d, b), "float32")) diff --git a/tests/python/relax/test_transform_legalize_ops_create_datatype.py b/tests/python/relax/test_transform_legalize_ops_create_datatype.py index 6bd518ee74..1e904823d3 100644 --- a/tests/python/relax/test_transform_legalize_ops_create_datatype.py +++ b/tests/python/relax/test_transform_legalize_ops_create_datatype.py @@ -640,7 +640,7 @@ def test_tril(): i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[i0_1, i1_1, i2_1]) T.writes(trilu[i0_1, i1_1, i2_1]) - trilu[i0_1, i1_1, i2_1] = T.Select(i2_1 - T.int64(1) <= i1_1, rxplaceholder[i0_1, i1_1, i2_1], T.float32(0)) + trilu[i0_1, i1_1, i2_1] = T.Select(i2_1 <= i1_1 + T.int64(1), rxplaceholder[i0_1, i1_1, i2_1], T.float32(0)) # fmt: on mod = LegalizeOps()(Tril) @@ -713,7 +713,7 @@ def test_triu(): i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) T.reads(rxplaceholder[i0_1, i1_1, i2_1]) T.writes(trilu[i0_1, i1_1, i2_1]) - trilu[i0_1, i1_1, i2_1] = T.Select(i1_1 <= i2_1 - T.int64(1), rxplaceholder[i0_1, i1_1, i2_1], T.float32(0)) + trilu[i0_1, i1_1, i2_1] = T.Select(i1_1 < i2_1, rxplaceholder[i0_1, i1_1, i2_1], T.float32(0)) # fmt: on mod = LegalizeOps()(Triu) diff --git a/tests/python/relax/test_transform_legalize_ops_grad.py b/tests/python/relax/test_transform_legalize_ops_grad.py index 14ac96bb4d..e8f75d83a9 100644 --- a/tests/python/relax/test_transform_legalize_ops_grad.py +++ b/tests/python/relax/test_transform_legalize_ops_grad.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import pytest import tvm from tvm.relax.transform import LegalizeOps @@ -206,6 +207,7 @@ def test_nll_loss_backward_no_batch(): tvm.ir.assert_structural_equal(mod, Expected) [email protected]("Regression to be fixed in the generated after merge.") def test_max_pool2d_backward(): # fmt: off @tvm.script.ir_module @@ -228,7 +230,7 @@ def test_max_pool2d_backward(): T.func_attr({"tir.noalias": True}) # with T.block("root"): pad_temp = T.alloc_buffer((T.int64(3), T.int64(2), T.int64(15), T.int64(13))) - maxpool_grad_argmax_v0 = T.alloc_buffer((T.int64(3), T.int64(2), T.int64(6), T.int64(5)), "int32") + maxpool_grad_argmax_v0 = T.alloc_buffer((T.int64(3), T.int64(2), T.int64(6), T.int64(5)), "int64") maxpool_grad_argmax_v1 = T.alloc_buffer((T.int64(3), T.int64(2), T.int64(6), T.int64(5))) for ax0, ax1, ax2, ax3 in T.grid(T.int64(3), T.int64(2), T.int64(15), T.int64(13)): with T.block("pad_temp"): @@ -244,9 +246,9 @@ def test_max_pool2d_backward(): with T.init(): maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3] = -1 maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(-3.4028234663852886e+38) - v_maxpool_grad_argmax_v0: T.int64 = T.Select(maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] > pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw] or maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] == pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw] and T.Cast("int64", maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3]) < v_ax0 * T.int64(390) + v_ax1 * T.int64(195) + v_ax2 * T.int64(26) + v_dh * T.int64(13) + [...] 
+ v_maxpool_grad_argmax_v0: T.int64 = T.Select(maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] > pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw] or maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] == pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw] and maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3] < v_ax0 * T.int64(390) + v_ax1 * T.int64(195) + v_ax2 * T.int64(26) + v_dh * T.int64(13) + v_ax3 * T.int64( [...] v_maxpool_grad_argmax_v1: T.float32 = T.Select(maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] > pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw], maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3], pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw]) - maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("int32", v_maxpool_grad_argmax_v0) + maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3] = v_maxpool_grad_argmax_v0 maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] = v_maxpool_grad_argmax_v1 for ax0, ax1, ax2, ax3, wh, ww in T.grid(T.int64(3), T.int64(2), T.int64(10), T.int64(10), T.int64(3), T.int64(3)): with T.block("T_pool_grad"): diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index 9be39183fd..28e2f3ad0e 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -579,9 +579,9 @@ def test_reshape_symbolic(): for i0, i1 in T.grid(a // T.int64(2), b * T.int64(2)): with T.block("T_reshape"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(rxplaceholder[(ax0 * (b * T.int64(2)) + ax1) // b % a, (ax0 * (b * T.int64(2)) + ax1) % b]) + T.reads(rxplaceholder[(ax0 * b * T.int64(2) + ax1) // b % a, (ax0 * b * T.int64(2) + ax1) % b]) T.writes(T_reshape[ax0, ax1]) - T_reshape[ax0, ax1] = rxplaceholder[(ax0 * (b * T.int64(2)) + ax1) // b % a, (ax0 * (b * T.int64(2)) 
+ ax1) % b] + T_reshape[ax0, ax1] = rxplaceholder[(ax0 * b * T.int64(2) + ax1) // b % a, (ax0 * b * T.int64(2) + ax1) % b] # fmt: on mod = LegalizeOps()(Reshape) @@ -623,13 +623,13 @@ def test_reshape_symbolic(): ax0, ax1 = T.axis.remap("SS", [i0, i1]) T.reads( rxplaceholder[ - (ax0 * (b * T.int64(2)) + ax1) // b % a, - (ax0 * (b * T.int64(2)) + ax1) % b, + (ax0 * b * T.int64(2) + ax1) // b % a, + (ax0 * b * T.int64(2) + ax1) % b, ] ) T.writes(T_reshape[ax0, ax1]) T_reshape[ax0, ax1] = rxplaceholder[ - (ax0 * (b * T.int64(2)) + ax1) // b % a, (ax0 * (b * T.int64(2)) + ax1) % b + (ax0 * b * T.int64(2) + ax1) // b % a, (ax0 * b * T.int64(2) + ax1) % b ] mod2 = LegalizeOps()(Reshape2) @@ -661,14 +661,14 @@ def test_reshape_symbolic(): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads( rxplaceholder[ - (v_ax0 * (b * T.int64(2)) + v_ax1) // b % T.int64(10), - (v_ax0 * (b * T.int64(2)) + v_ax1) % b, + (v_ax0 * b * T.int64(2) + v_ax1) // b % T.int64(10), + (v_ax0 * b * T.int64(2) + v_ax1) % b, ] ) T.writes(T_reshape[v_ax0, v_ax1]) T_reshape[v_ax0, v_ax1] = rxplaceholder[ - (v_ax0 * (b * T.int64(2)) + v_ax1) // b % T.int64(10), - (v_ax0 * (b * T.int64(2)) + v_ax1) % b, + (v_ax0 * b * T.int64(2) + v_ax1) // b % T.int64(10), + (v_ax0 * b * T.int64(2) + v_ax1) % b, ] @R.function diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index b22d9d59c0..d062a91c9d 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -545,7 +545,7 @@ def test_conv2d_transpose_symbolic(): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(data_dilate[v_i0, v_i1, v_i2 + T.int64(1) - kh, v_i3 + T.int64(1) - kw]) T.writes(data_pad[v_i0, v_i1, v_i2, v_i3]) - data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(kh - T.int64(1) <= v_i2 and v_i2 < h * T.int64(3) + kh - T.int64(3) and kw - T.int64(1) <= v_i3 and v_i3 < w * T.int64(3) + kw - 
T.int64(3), data_dilate[v_i0, v_i1, v_i2 + T.int64(1) - kh, v_i3 + T.int64(1) - kw], T.float32(0)) + data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(kh <= v_i2 + T.int64(1) and v_i2 + T.int64(3)< h * T.int64(3) + kh and kw <= v_i3 + T.int64(1) and v_i3 + T.int64(3) < w * T.int64(3) + kw , data_dilate[v_i0, v_i1, v_i2 + T.int64(1) - kh, v_i3 + T.int64(1) - kw], T.float32(0)) for o, i, h_1, w_1 in T.grid(c, c, kh, kw): with T.block("kernel_transform"): v_o, v_i, v_h, v_w = T.axis.remap("SSSS", [o, i, h_1, w_1]) @@ -2499,9 +2499,9 @@ def test_group_norm_symbolic(): for ax0, ax1, ax2, ax3 in T.grid(n, c * T.int64(4), h, w): with T.block("T_reshape_3"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_group_norm[(((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c // T.int64(4) % n, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c % T.int64(4), (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h % c, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w % h, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) % w]) + T.reads(T_group_norm[(((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c // T.int64(4) % n, (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c % T.int64(4), (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h % c, (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w % h, (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) % w]) T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_group_norm[(((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c // T.int64(4) % n, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c % T.int64(4), (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h % c, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + 
v_ax2) * w + v_ax3) // w % h, (((v_ax0 * (c * T.int64(4)) + v_ax1) * h + v_ax2) * w + v_ax3) % w] + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_group_norm[(((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c // T.int64(4) % n, (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h // c % T.int64(4), (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w // h % c, (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) // w % h, (((v_ax0 * c * T.int64(4) + v_ax1) * h + v_ax2) * w + v_ax3) % w] @R.function def main(s: R.Shape(["c"]), x: R.Tensor(("n", "4 * c", "h", "w"), dtype="float32"), gamma: R.Tensor(("4 * c",), dtype="float32"), beta: R.Tensor(("4 * c",), dtype="float32")) -> R.Tensor(("n", "4 * c", "h", "w"), dtype="float32"):
