ilovetvm opened a new issue #7420:
URL: https://github.com/apache/tvm/issues/7420
Using the `parallel` primitive on a reduction axis on CPU (`target="llvm"`)
can lead to silent errors. In the following `GEMM` example, TVM compiles and
runs without complaint, as if the computation were correct. However,
a comparison with `PyTorch` shows that the final result is wrong.
It would be better if more rigorous checking for this could be added to
TVM. At the very least, a warning should be emitted for parallel reductions
to alert users who are learning how to write correct
schedules.
```
import tvm
import torch
import numpy as np
# Problem size: A is (M, K), B is (K, N), and C = A @ B is (M, N).
M, N, K = 2, 2, 4
def gemm(inputs, weight):
    """Build a TE compute for C = inputs @ weight.

    Parameters
    ----------
    inputs : tvm.te.Tensor
        2-D tensor indexed as ``[i, k]``.
    weight : tvm.te.Tensor
        2-D tensor indexed as ``[k, j]``. Its first dimension must equal the
        second dimension of ``inputs``.

    Returns
    -------
    tvm.te.Tensor
        Tensor of shape ``(M, N)`` with
        ``C[i, j] = sum_k inputs[i, k] * weight[k, j]``.
    """
    # Shapes are read via ``.value``, so they are assumed to be static
    # integers rather than symbolic expressions.
    assert inputs.shape[1].value == weight.shape[0].value
    # These locals shadow the module-level M/K/N inside this function only.
    M, K = inputs.shape[0], inputs.shape[1]
    N = weight.shape[1].value
    # Reduction axis over the shared inner dimension.
    k = tvm.te.reduce_axis((0, K))
    return tvm.te.compute(
        (M, N),
        lambda i, j: tvm.te.sum(inputs[i, k] * weight[k, j], axis=k),
    )
def test(parallel):
    """Compare TVM's GEMM against PyTorch, optionally parallelizing the reduction.

    Random float32 inputs are multiplied with PyTorch as the reference and
    with a TVM kernel built from ``gemm``. When ``parallel`` is truthy, the
    schedule marks the reduction axis ``k`` as parallel. In the lowered IR,
    every iteration of that "parallel" loop then read-modify-writes the same
    output element, so iterations race and the final ``assert_allclose``
    can fail.

    Raises
    ------
    AssertionError
        If TVM's result differs from PyTorch's beyond ``rtol=1e-5``.
    """
    A_np = np.random.random([M, K]).astype(np.float32) * 100
    B_np = np.random.random([K, N]).astype(np.float32) * 100

    # Reference result.
    A_torch = torch.tensor(A_np)
    B_torch = torch.tensor(B_np)
    C_torch = A_torch @ B_torch

    # Device buffers for the TVM kernel; the output starts zeroed.
    tvm_ctx = tvm.context("llvm", 0)
    A_tvm = tvm.nd.array(A_np, tvm_ctx)
    B_tvm = tvm.nd.array(B_np, tvm_ctx)
    C_tvm = tvm.nd.array(np.zeros(C_torch.shape).astype(np.float32), tvm_ctx)

    A_t = tvm.te.placeholder(A_np.shape, dtype="float32")
    B_t = tvm.te.placeholder(B_np.shape, dtype="float32")
    C = gemm(A_t, B_t)
    s = tvm.te.create_schedule(C.op)
    if parallel:
        # Parallelizing the reduction axis is what produces the race.
        (k_axis,) = s[C].op.reduce_axis
        s[C].parallel(k_axis)
    print(tvm.lower(s, [A_t, B_t, C], simple_mode=True))

    func = tvm.build(s, [A_t, B_t, C], "llvm")
    func(A_tvm, B_tvm, C_tvm)
    np.testing.assert_allclose(C_tvm.asnumpy(), C_torch.numpy(), rtol=1e-5)
if __name__ == "__main__":
    # Serial reduction: the TVM result is expected to match PyTorch.
    test(False)
    print("===========================================================")
    # Reduction axis marked "parallel": iterations race on the same output
    # element, so this comparison can fail.
    test(True)
```
The result of running the above code is:
```
primfn(placeholder_2: handle, placeholder_3: handle, compute_1: handle) -> ()
attr = {"global_symbol": "main", "tir.noalias": True}
buffers = {compute: Buffer(compute_2: Pointer(float32), float32, [2, 2],
[]),
placeholder: Buffer(placeholder_4: Pointer(float32), float32,
[2, 4], []),
placeholder_1: Buffer(placeholder_5: Pointer(float32), float32,
[4, 2], [])}
buffer_map = {placeholder_2: placeholder, placeholder_3: placeholder_1,
compute_1: compute} {
for (i: int32, 0, 2) {
for (j: int32, 0, 2) {
compute_2[((i*2) + j)] = 0f32
for (rv: int32, 0, 4) {
compute_2[((i*2) + j)] = ((float32*)compute_2[((i*2) + j)] +
((float32*)placeholder_4[((i*4) + rv)]*(float32*)placeholder_5[((rv*2) + j)]))
}
}
}
}
===========================================================
primfn(placeholder_2: handle, placeholder_3: handle, compute_1: handle) -> ()
attr = {"global_symbol": "main", "tir.noalias": True}
buffers = {compute: Buffer(compute_2: Pointer(float32), float32, [2, 2],
[]),
placeholder: Buffer(placeholder_4: Pointer(float32), float32,
[2, 4], []),
placeholder_1: Buffer(placeholder_5: Pointer(float32), float32,
[4, 2], [])}
buffer_map = {placeholder_2: placeholder, placeholder_3: placeholder_1,
compute_1: compute} {
for (i: int32, 0, 2) {
for (j: int32, 0, 2) {
compute_2[((i*2) + j)] = 0f32
for (rv: int32, 0, 4) "parallel" {
compute_2[((i*2) + j)] = ((float32*)compute_2[((i*2) + j)] +
((float32*)placeholder_4[((i*4) + rv)]*(float32*)placeholder_5[((rv*2) + j)]))
}
}
}
}
Traceback (most recent call last):
...
np.testing.assert_allclose(C_tvm.asnumpy(), C_torch.numpy(), rtol=1e-5)
...
AssertionError:
Not equal to tolerance rtol=1e-05, atol=0
Mismatched elements: 3 / 4 (75%)
Max absolute difference: 1769.9526
Max relative difference: 0.21025768
x: array([[ 6648.064, 9747.168],
[ 5256.175, 11761.918]], dtype=float32)
y: array([[ 8418.017 , 10098.07 ],
[ 5256.1753, 13276.437 ]], dtype=float32)
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]