Cookiee235 opened a new issue, #17270:
URL: https://github.com/apache/tvm/issues/17270
### Actual behavior
```
Traceback (most recent call last):
File "/share_container/optfuzz/res/bugs/inconsis222.py", line 258, in
<module>
np.testing.assert_allclose(before_outputs, after_outputs, 1e-3, 1e-3)
File
"/root/miniconda3/lib/python3.12/site-packages/numpy/testing/_private/utils.py",
line 1504, in assert_allclose
assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
File "/root/miniconda3/lib/python3.12/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File
"/root/miniconda3/lib/python3.12/site-packages/numpy/testing/_private/utils.py",
line 718, in assert_array_compare
flagged = func_assert_same_pos(x, y, func=isnan, hasval='nan')
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File
"/root/miniconda3/lib/python3.12/site-packages/numpy/testing/_private/utils.py",
line 688, in func_assert_same_pos
raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=0.001, atol=0.001
x and y nan location mismatch:
x: array([[ 7.936000e+04, 8.032000e+04, 8.128000e+04, 8.224000e+04,
8.320000e+04, 8.416000e+04, 8.512000e+04, 8.608000e+04,
7.168000e+04, 7.252000e+04, 1.898367e+16, 7.420000e+04,...
y: array([[ 7.936000e+04, 8.032000e+04, 8.128000e+04, 8.224000e+04,
8.320000e+04, 8.416000e+04, 8.512000e+04, 8.608000e+04,
nan, 7.252000e+04, 7.336000e+04, 7.420000e+04,...
```
### Steps to reproduce
<details>
<summary>This is a complex test case; I cannot reduce it further because the root cause is unknown</summary>
```python
import tvm
from tvm import relax
import numpy as np
import tvm
metadata = tvm.ir.load_json("""{
\"root\": 1,
\"nodes\": [
{
\"type_key\": \"\"
},
{
\"type_key\": \"Map\",
\"keys\": [
\"relax.expr.Constant\"
],
\"data\": [2]
},
{
\"type_key\": \"Array\",
\"data\": [3]
},
{
\"type_key\": \"relax.expr.Constant\",
\"attrs\": {
\"_checked_type_\": \"11\",
\"data\": \"0\",
\"span\": \"0\",
\"struct_info_\": \"4\"
}
},
{
\"type_key\": \"relax.TensorStructInfo\",
\"attrs\": {
\"dtype\": \"float32\",
\"ndim\": \"2\",
\"shape\": \"5\",
\"span\": \"0\",
\"vdevice\": \"0\"
}
},
{
\"type_key\": \"relax.expr.ShapeExpr\",
\"attrs\": {
\"_checked_type_\": \"10\",
\"span\": \"0\",
\"struct_info_\": \"9\",
\"values\": \"6\"
}
},
{
\"type_key\": \"Array\",
\"data\": [7, 8]
},
{
\"type_key\": \"IntImm\",
\"attrs\": {
\"dtype\": \"int64\",
\"span\": \"0\",
\"value\": \"16\"
}
},
{
\"type_key\": \"IntImm\",
\"attrs\": {
\"dtype\": \"int64\",
\"span\": \"0\",
\"value\": \"16\"
}
},
{
\"type_key\": \"relax.ShapeStructInfo\",
\"attrs\": {
\"ndim\": \"2\",
\"span\": \"0\",
\"values\": \"6\"
}
},
{
\"type_key\": \"relax.ShapeType\",
\"attrs\": {
\"ndim\": \"2\",
\"span\": \"0\"
}
},
{
\"type_key\": \"relax.DynTensorType\",
\"attrs\": {
\"dtype\": \"float32\",
\"ndim\": \"2\",
\"span\": \"0\"
}
}
],
\"b64ndarrays\": [
\"P6G0lvBAXt0AAAAAAAAAAAEAAAAAAAAAAgAAAAIgAQAQAAAAAAAAABAAAAAAAAAAAAQAAAAAAAAAAAAAAACAPwAAAEAAAEBAAACAQAAAoEAAAMBAAADgQAAAAEEAABBBAAAgQQAAMEEAAEBBAABQQQAAYEEAAHBBAACAQQAAiEEAAJBBAACYQQAAoEEAAKhBAACwQQAAuEEAAMBBAADIQQAA0EEAANhBAADgQQAA6EEAAPBBAAD4QQAAAEIAAARCAAAIQgAADEIAABBCAAAUQgAAGEIAABxCAAAgQgAAJEIAAChCAAAsQgAAMEIAADRCAAA4QgAAPEIAAEBCAABEQgAASEIAAExCAABQQgAAVEIAAFhCAABcQgAAYEIAAGRCAABoQgAAbEIAAHBCAAB0QgAAeEIAAHxCAACAQgAAgkIAAIRCAACGQgAAiEIAAIpCAACMQgAAjkIAAJBCAACSQgAAlEIAAJZCAACYQgAAmkIAAJxCAACeQgAAoEIAAKJCAACkQgAApkIAAKhCAACqQgAArEIAAK5CAACwQgAAskIAALRCAAC2QgAAuEIAALpCAAC8QgAAvkIAAMBCAADCQgAAxEIAAMZCAADIQgAAykIAAMxCAADOQgAA0EIAANJCAADUQgAA1kIAANhCAADaQgAA3EIAAN5CAADgQgAA4kIAAORCAADmQgAA6EIAAOpCAADsQgAA7kIAAPBCAADyQgAA9EIAAPZCAAD4QgAA+kIAAPxCAAD+QgAAAEMAAAFDAAACQwAAA0MAAARDAAAFQwAABkMAAAdDAAAIQwAACUMAAApDAAALQwAADEMAAA1DAAAOQwAAD0MAABBDAAARQwAAEkMAABNDAAAUQwAAFUMAABZDAAAXQwAAGEMAABlDAAAaQwAAG0MAABxDAAAdQwAAHkMAAB9DAAAgQwAAIUMAACJDAAAjQwAAJEMAACVDAAAmQwAAJ0MAAChDAAApQwAAKkMAA
CtDAAAsQwAALUMAAC5DAAAvQwAAMEMAADFDAAAyQwAAM0MAADRDAAA1QwAANkMAADdDAAA4QwAAOUMAADpDAAA7QwAAPEMAAD1DAAA+QwAAP0MAAEBDAABBQwAAQkMAAENDAABEQwAARUMAAEZDAABHQwAASEMAAElDAABKQwAAS0MAAExDAABNQwAATkMAAE9DAABQQwAAUUMAAFJDAABTQwAAVEMAAFVDAABWQwAAV0MAAFhDAABZQwAAWkMAAFtDAABcQwAAXUMAAF5DAABfQwAAYEMAAGFDAABiQwAAY0MAAGRDAABlQwAAZkMAAGdDAABoQwAAaUMAAGpDAABrQwAAbEMAAG1DAABuQwAAb0MAAHBDAABxQwAAckMAAHNDAAB0QwAAdUMAAHZDAAB3QwAAeEMAAHlDAAB6QwAAe0MAAHxDAAB9QwAAfkMAAH9D\"
],
\"attrs\": {\"tvm_version\": \"0.17.dev0\"}
}""")
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func(private=True)
def add(A: T.Buffer((T.int64(16), T.int64(16)), "float32"), B:
T.Buffer((T.int64(16), T.int64(16)), "float32"), T_add: T.Buffer((T.int64(16),
T.int64(16)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]
@T.prim_func(private=True)
def cast(gv: T.Buffer((T.int64(16), T.int64(16)), "float32"), compute:
T.Buffer((T.int64(16), T.int64(16)), "int64")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1 in T.grid(T.int64(16), T.int64(16)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(gv[v_i0, v_i1])
T.writes(compute[v_i0, v_i1])
compute[v_i0, v_i1] = T.Cast("int64", gv[v_i0, v_i1])
@T.prim_func(private=True)
def matmul(x: T.Buffer((T.int64(1), T.int64(16)), "float32"), weight:
T.Buffer((T.int64(16), T.int64(32)), "float32"), matmul: T.Buffer((T.int64(1),
T.int64(32)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for i0, i1, k in T.grid(T.int64(1), T.int64(32), T.int64(16)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(x[v_i0, v_k], weight[v_k, v_i1])
T.writes(matmul[v_i0, v_i1])
with T.init():
matmul[v_i0, v_i1] = T.float32(0)
matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + x[v_i0, v_k] *
weight[v_k, v_i1]
@T.prim_func(private=True)
def reshape(gv: T.Buffer((T.int64(16), T.int64(16)), "float32"),
T_reshape: T.Buffer((T.int64(256),), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0 in range(T.int64(256)):
with T.block("T_reshape"):
v_ax0 = T.axis.spatial(T.int64(256), ax0)
T.reads(gv[v_ax0 % T.int64(256) // T.int64(16), v_ax0 %
T.int64(16)])
T.writes(T_reshape[v_ax0])
T_reshape[v_ax0] = gv[v_ax0 % T.int64(256) // T.int64(16),
v_ax0 % T.int64(16)]
@T.prim_func(private=True)
def reshape1(temp: T.Buffer((T.int64(16),), "float32"), T_reshape:
T.Buffer((T.int64(1), T.int64(16)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(1), T.int64(16)):
with T.block("T_reshape"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(temp[v_ax1 % T.int64(16)])
T.writes(T_reshape[v_ax0, v_ax1])
T_reshape[v_ax0, v_ax1] = temp[v_ax1 % T.int64(16)]
@T.prim_func(private=True)
def reshape2(gv: T.Buffer((T.int64(16), T.int64(16)), "int64"),
T_reshape: T.Buffer((T.int64(256),), "int64")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0 in range(T.int64(256)):
with T.block("T_reshape"):
v_ax0 = T.axis.spatial(T.int64(256), ax0)
T.reads(gv[v_ax0 % T.int64(256) // T.int64(16), v_ax0 %
T.int64(16)])
T.writes(T_reshape[v_ax0])
T_reshape[v_ax0] = gv[v_ax0 % T.int64(256) // T.int64(16),
v_ax0 % T.int64(16)]
@T.prim_func(private=True)
def reshape3(temp: T.Buffer((T.int64(32),), "int64"), T_reshape:
T.Buffer((T.int64(32),), "int64")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0 in range(T.int64(32)):
with T.block("T_reshape"):
v_ax0 = T.axis.spatial(T.int64(32), ax0)
T.reads(temp[v_ax0 % T.int64(32)])
T.writes(T_reshape[v_ax0])
T_reshape[v_ax0] = temp[v_ax0 % T.int64(32)]
@T.prim_func(private=True)
def strided_slice(tensor_1dim: T.Buffer((T.int64(256),), "float32"),
T_strided_slice_with_axes: T.Buffer((T.int64(16),), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0 in range(T.int64(16)):
with T.block("T_strided_slice_with_axes"):
v_ax0 = T.axis.spatial(T.int64(16), ax0)
T.reads(tensor_1dim[v_ax0])
T.writes(T_strided_slice_with_axes[v_ax0])
T_strided_slice_with_axes[v_ax0] = tensor_1dim[v_ax0]
@T.prim_func(private=True)
def strided_slice1(tensor_1dim: T.Buffer((T.int64(256),), "int64"),
T_strided_slice_with_axes: T.Buffer((T.int64(32),), "int64")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0 in range(T.int64(32)):
with T.block("T_strided_slice_with_axes"):
v_ax0 = T.axis.spatial(T.int64(32), ax0)
T.reads(tensor_1dim[v_ax0])
T.writes(T_strided_slice_with_axes[v_ax0])
T_strided_slice_with_axes[v_ax0] = tensor_1dim[v_ax0]
@T.prim_func(private=True)
def take(var_weight_table: T.handle, routing_table:
T.Buffer((T.int64(32),), "int64"), T_take: T.Buffer((T.int64(16), T.int64(32)),
"float32")):
T.func_attr({"tir.noalias": T.bool(True)})
weight_table_size = T.int64()
weight_table = T.match_buffer(var_weight_table, (T.int64(16),
weight_table_size))
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(16), T.int64(32)):
with T.block("T_take"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(weight_table[v_ax0, routing_table[v_ax1]],
routing_table[v_ax1])
T.writes(T_take[v_ax0, v_ax1])
T_take[v_ax0, v_ax1] = weight_table[v_ax0,
routing_table[v_ax1]]
@R.function
def main_7(x: R.Tensor((1, 16), dtype="float32"), weight_table:
R.Tensor((16, "weight_table_size"), dtype="float32"), routing_table:
R.Tensor((32,), dtype="int64")) -> R.Tensor((1, 32), dtype="float32"):
weight_table_size = T.int64()
cls = Module
with R.dataflow():
weight = R.call_tir(cls.take, (weight_table, routing_table),
out_sinfo=R.Tensor((16, 32), dtype="float32"))
out = R.call_tir(cls.matmul, (x, weight), out_sinfo=R.Tensor((1,
32), dtype="float32"))
R.output(out)
return out
@R.function
def main() -> R.Tensor((1, 32), dtype="float32"):
cls = Module
gv = R.call_tir(cls.add, (metadata["relax.expr.Constant"][0],
metadata["relax.expr.Constant"][0]), out_sinfo=R.Tensor((16, 16),
dtype="float32"))
tensor_1dim = R.call_tir(cls.reshape, (gv,),
out_sinfo=R.Tensor((256,), dtype="float32"))
temp = R.call_tir(cls.strided_slice, (tensor_1dim,),
out_sinfo=R.Tensor((16,), dtype="float32"))
para0 = R.call_tir(cls.reshape1, (temp,), out_sinfo=R.Tensor((1,
16), dtype="float32"))
para1: R.Tensor((16, 16), dtype="float32") = gv
gv_1 = R.call_tir(cls.cast, (gv,), out_sinfo=R.Tensor((16, 16),
dtype="int64"))
tensor_1dim_1 = R.call_tir(cls.reshape2, (gv_1,),
out_sinfo=R.Tensor((256,), dtype="int64"))
temp_1 = R.call_tir(cls.strided_slice1, (tensor_1dim_1,),
out_sinfo=R.Tensor((32,), dtype="int64"))
para2 = R.call_tir(cls.reshape3, (temp_1,),
out_sinfo=R.Tensor((32,), dtype="int64"))
res: R.Tensor((1, 32), dtype="float32") = cls.main_7(para0, para1,
para2)
return res
def compile_mod(mod, func_name, target, *inputs):
ex = relax.build(mod, target='llvm')
vm = relax.VirtualMachine(ex, tvm.cpu())
mod_outputs = vm[f'{func_name}'](*inputs)
mod_outputs = mod_outputs.numpy()
return mod_outputs
mod = Module
before_outputs = compile_mod(mod, 'main', 'llvm')
mod = relax.transform.FoldConstant()(mod)
mod = relax.transform.ReorderTakeAfterMatmul()(mod)
after_outputs = compile_mod(mod, 'main', 'llvm')
np.testing.assert_allclose(before_outputs, after_outputs, 1e-3, 1e-3)
```
</details>
CC @Lunderberg @junrushao
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]