Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/02/08 03:28:28 UTC
[GitHub] [tvm] ilovetvm opened a new issue #7420: Parallel on Reduction Axis in CPU needs rigorous checking
ilovetvm opened a new issue #7420:
URL: https://github.com/apache/tvm/issues/7420
Using the `parallel` primitive on a reduction axis on CPU (`target="llvm"`) can produce silently wrong results. In the following GEMM example, TVM compiles and runs the schedule without any error or warning. However, comparing the output with PyTorch shows that the result is wrong.
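Why the result can be wrong without any error: once the reduction loop is marked `parallel`, its iterations may run concurrently. Every iteration still performs a read-modify-write on the same output element `C[i, j]`, so if two iterations load the accumulator before either stores it back, one update is lost. The pure-Python sketch below does not use TVM; the function names are hypothetical. It replays the worst-case interleaving, in which every iteration loads the initial value 0 before any store happens. The real outcome depends on thread scheduling, so some elements can still come out correct by chance.

```python
import numpy as np


def sequential_reduce(a_row, b_col):
    """Reference: the serial reduction loop, one update after another."""
    acc = 0.0
    for k in range(len(a_row)):
        acc = acc + a_row[k] * b_col[k]
    return acc


def racy_parallel_reduce(a_row, b_col):
    """Simulate the worst-case interleaving of a parallel reduction loop.

    Every hypothetical "thread" (one per k) loads the shared accumulator
    before any thread stores, so each store overwrites the previous ones
    and only the last update survives.
    """
    acc = 0.0  # shared output element, zero-initialised before the loop
    # Phase 1: every thread loads the same (stale) accumulator value.
    loaded = [acc for _ in range(len(a_row))]
    # Phase 2: every thread adds its own product to its stale copy and stores.
    for k in range(len(a_row)):
        acc = loaded[k] + a_row[k] * b_col[k]
    return acc


a = np.array([1.0, 2.0, 3.0, 4.0])
b = np.array([10.0, 20.0, 30.0, 40.0])
print(sequential_reduce(a, b))     # 300.0
print(racy_parallel_reduce(a, b))  # 160.0: only the k = 3 product survives
```

Because all products in this example are non-negative, a lost update can only make the total smaller. That is consistent with the run shown further down, where every mismatched TVM value is smaller than the PyTorch reference.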
It would help if TVM checked for this more rigorously. At a minimum, it should emit a warning when a reduction axis is parallelized, to alert users who are still learning how to write correct schedules.
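One possible shape for such a check is sketched below on a toy loop-nest representation, not on TVM's real IR. The `Loop` and `Store` classes, `find_parallel_races`, and `gemm_ir` are all hypothetical. The rule is to flag any store inside a `parallel` loop whose index does not mention that loop's variable, because every iteration then writes the same element. This is a simple heuristic, not a full dependence analysis.

```python
from dataclasses import dataclass, field


@dataclass
class Store:
    """A write to buffer[index]; index_vars are the loop variables the index uses."""
    buffer: str
    index_vars: frozenset


@dataclass
class Loop:
    var: str
    kind: str                                  # "serial" or "parallel"
    body: list = field(default_factory=list)   # nested Loop or Store nodes


def find_parallel_races(node, parallel_vars=(), warnings=None):
    """Warn about stores whose index ignores an enclosing parallel loop variable."""
    if warnings is None:
        warnings = []
    if isinstance(node, Store):
        for var in parallel_vars:
            if var not in node.index_vars:
                warnings.append(
                    f"store to {node.buffer} does not depend on parallel loop var '{var}'"
                )
    else:
        inner = parallel_vars + (node.var,) if node.kind == "parallel" else parallel_vars
        for child in node.body:
            find_parallel_races(child, inner, warnings)
    return warnings


def gemm_ir(parallel_var=None):
    """Toy model of the lowered GEMM loop nest; parallel_var marks one loop parallel."""
    def kind(var):
        return "parallel" if var == parallel_var else "serial"

    return Loop("i", kind("i"), [
        Loop("j", kind("j"), [
            Store("compute", frozenset({"i", "j"})),      # compute[i, j] = 0
            Loop("rv", kind("rv"), [
                Store("compute", frozenset({"i", "j"})),  # compute[i, j] += A[i, rv] * B[rv, j]
            ]),
        ]),
    ])


print(find_parallel_races(gemm_ir()))      # []
print(find_parallel_races(gemm_ir("i")))   # []
print(find_parallel_races(gemm_ir("rv")))  # one warning about 'rv'
```

Parallelizing the spatial axis `i` raises no warning, because each iteration writes only its own rows of the output. Parallelizing the reduction axis `rv` is flagged.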
```
import tvm
import torch
import numpy as np

# Problem size: C[M, N] = A[M, K] @ B[K, N]
M = 2
N = 2
K = 4


def gemm(inputs, weight):
    # Plain GEMM expressed as a TE compute with an explicit reduction axis k.
    assert inputs.shape[1].value == weight.shape[0].value
    M, K = inputs.shape[0], inputs.shape[1]
    N = weight.shape[1].value
    k = tvm.te.reduce_axis((0, K))
    return tvm.te.compute((M, N), lambda i, j: tvm.te.sum(inputs[i, k] * weight[k, j], axis=k))


def test(parallel):
    # Random inputs and a PyTorch reference result.
    A_np = np.random.random([M, K]).astype(np.float32) * 100
    B_np = np.random.random([K, N]).astype(np.float32) * 100
    A_torch = torch.tensor(A_np)
    B_torch = torch.tensor(B_np)
    C_torch = A_torch @ B_torch

    # Device buffers for the TVM kernel.
    tvm_ctx = tvm.context("llvm", 0)
    A_tvm = tvm.nd.array(A_np, tvm_ctx)
    B_tvm = tvm.nd.array(B_np, tvm_ctx)
    C_tvm = tvm.nd.array(np.zeros(C_torch.shape).astype(np.float32), tvm_ctx)

    A_t = tvm.te.placeholder(A_np.shape, dtype="float32")
    B_t = tvm.te.placeholder(B_np.shape, dtype="float32")
    C = gemm(A_t, B_t)
    s = tvm.te.create_schedule(C.op)
    if parallel:
        # The schedule under test: mark the reduction axis k as parallel.
        k_axis, = s[C].op.reduce_axis
        s[C].parallel(k_axis)
    print(tvm.lower(s, [A_t, B_t, C], simple_mode=True))

    func = tvm.build(s, [A_t, B_t, C], "llvm")
    func(A_tvm, B_tvm, C_tvm)
    np.testing.assert_allclose(C_tvm.asnumpy(), C_torch.numpy(), rtol=1e-5)


test(False)  # serial reduction: matches PyTorch
print("===========================================================")
test(True)   # parallel reduction: silently wrong result
```
Running the above code produces:
```
primfn(placeholder_2: handle, placeholder_3: handle, compute_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {compute: Buffer(compute_2: Pointer(float32), float32, [2, 2], []),
             placeholder: Buffer(placeholder_4: Pointer(float32), float32, [2, 4], []),
             placeholder_1: Buffer(placeholder_5: Pointer(float32), float32, [4, 2], [])}
  buffer_map = {placeholder_2: placeholder, placeholder_3: placeholder_1, compute_1: compute} {
  for (i: int32, 0, 2) {
    for (j: int32, 0, 2) {
      compute_2[((i*2) + j)] = 0f32
      for (rv: int32, 0, 4) {
        compute_2[((i*2) + j)] = ((float32*)compute_2[((i*2) + j)] + ((float32*)placeholder_4[((i*4) + rv)]*(float32*)placeholder_5[((rv*2) + j)]))
      }
    }
  }
}
===========================================================
primfn(placeholder_2: handle, placeholder_3: handle, compute_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {compute: Buffer(compute_2: Pointer(float32), float32, [2, 2], []),
             placeholder: Buffer(placeholder_4: Pointer(float32), float32, [2, 4], []),
             placeholder_1: Buffer(placeholder_5: Pointer(float32), float32, [4, 2], [])}
  buffer_map = {placeholder_2: placeholder, placeholder_3: placeholder_1, compute_1: compute} {
  for (i: int32, 0, 2) {
    for (j: int32, 0, 2) {
      compute_2[((i*2) + j)] = 0f32
      for (rv: int32, 0, 4) "parallel" {
        compute_2[((i*2) + j)] = ((float32*)compute_2[((i*2) + j)] + ((float32*)placeholder_4[((i*4) + rv)]*(float32*)placeholder_5[((rv*2) + j)]))
      }
    }
  }
}
Traceback (most recent call last):
  ...
    np.testing.assert_allclose(C_tvm.asnumpy(), C_torch.numpy(), rtol=1e-5)
  ...
AssertionError:
Not equal to tolerance rtol=1e-05, atol=0
Mismatched elements: 3 / 4 (75%)
Max absolute difference: 1769.9526
Max relative difference: 0.21025768
 x: array([[ 6648.064, 9747.168],
       [ 5256.175, 11761.918]], dtype=float32)
 y: array([[ 8418.017 , 10098.07 ],
       [ 5256.1753, 13276.437 ]], dtype=float32)
```
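A common way to parallelize a reduction safely is to give each worker its own partial accumulator over a disjoint slice of `k`, then combine the partials serially. TVM's tensor-expression schedules provide an `rfactor` primitive for this pattern. In this small example, parallelizing a spatial axis instead, e.g. `s[C].parallel(s[C].op.axis[0])`, also avoids the race. The sketch below does not use TVM. It only illustrates the partial-sum idea with NumPy and a thread pool, and `parallel_gemm_partial_sums` is a hypothetical name.

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor


def parallel_gemm_partial_sums(a, b, num_workers=2):
    """Parallelize the k reduction of a @ b without a data race.

    Each worker reduces a disjoint slice of k into its OWN partial buffer,
    and the partials are summed serially afterwards.
    """
    k_total = a.shape[1]
    # Contiguous, non-overlapping slices of the reduction axis, one per worker.
    bounds = np.linspace(0, k_total, num_workers + 1).astype(int)
    partials = np.zeros((num_workers, a.shape[0], b.shape[1]), dtype=a.dtype)

    def work(w):
        lo, hi = bounds[w], bounds[w + 1]
        # Writes only partials[w], so no two workers touch the same memory.
        partials[w] = a[:, lo:hi] @ b[lo:hi, :]

    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        list(pool.map(work, range(num_workers)))  # list() re-raises worker errors
    # Serial combine step: the only place the per-worker results meet.
    return partials.sum(axis=0)


rng = np.random.default_rng(0)
A = (rng.random((2, 4)) * 100).astype(np.float32)
B = (rng.random((4, 2)) * 100).astype(np.float32)
np.testing.assert_allclose(parallel_gemm_partial_sums(A, B), A @ B, rtol=1e-5)
```

The cost of this approach is one extra output-sized buffer per worker plus a short serial combine, which is why it pays off mainly when the reduction axis is long compared with the spatial axes.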