Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/06/23 03:37:00 UTC
[GitHub] [tvm] wrongtest opened a new issue #8308: [BUG] Incorrect buffer offset for vectorized computation
wrongtest opened a new issue #8308:
URL: https://github.com/apache/tvm/issues/8308
Hi there, we are working on vectorized computation in a tagged storage scope and have run into a problem with vectorization.
An example is shown below:
```python
# before vectorize
for (i: int32, 0, 1024) {
  C[i] = ((float32*)A[i] + (float32*)B[i])
}
# after vectorize (still correct)
C[ramp(0, 1, 1024)] = ((float32x1024*)A[ramp(0, 1, 1024)] + (float32x1024*)B[ramp(0, 1, 1024)])
# after storage rewrite (problematic)
Merged[ramp(0, 1, 1024)] = ((float32x1024*)Merged[ramp(0, 1, 1024)] + (float32x1024*)Merged[ramp(1, 1, 1024)])
# after storage rewrite (correct version if we do not vectorize)
for (i: int32, 0, 1024) {
  Merged[i] = ((float32*)Merged[i] + (float32*)Merged[(i + 1024)])
}
```
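In TIR, `ramp(base, stride, lanes)` stands for the `lanes` scalar indices `base, base + stride, base + 2 * stride, ...`. The plain-Python sketch below (the `ramp` helper is hypothetical, not a TVM API) expands the indices from the example above to show which elements of `Merged` each version reads for `B`, assuming `A` occupies `Merged[0: 1024]` as described in the issue.

```python
def ramp(base, stride, lanes):
    """Hypothetical helper: expand a TIR-style ramp(base, stride, lanes) into scalar indices."""
    return [base + k * stride for k in range(lanes)]

LANES = 1024

# Indices the scalar (non-vectorized) loop reads for B after storage rewrite: Merged[i + 1024]
expected_b = [i + 1024 for i in range(LANES)]
# Indices the problematic vectorized load reads: Merged[ramp(1, 1, 1024)]
buggy_b = ramp(1, 1, LANES)
# Indices a correct vectorized load would read: Merged[ramp(1024, 1, 1024)]
fixed_b = ramp(1024, 1, LANES)

a_region = set(range(0, 1024))  # assumed layout: A -> Merged[0: 1024]

print(fixed_b == expected_b)         # True
print(buggy_b[0], buggy_b[-1])       # 1 1024
print(len(a_region & set(buggy_b)))  # 1023
```

So the problematic load reads 1023 elements of `A`'s region and only a single element of `B`.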
Here we use a tagged storage scope, so `A`, `B`, and `C` can be merged into a single buffer `Merged`, with `A` and `C` sharing the same region:
- `A` -> `Merged[0: 1024]`
- `B` -> `Merged[1024: 2048]`
- `C` -> `Merged[0: 1024]`
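The effect of this layout can be simulated with NumPy. In the sketch below (a model of the merged buffer, not TVM-generated code; `OFFSETS` and `run` are hypothetical names), `A` and `B` are copied into one 2048-element array at the offsets listed above, and the vectorized statement is evaluated once with `B` loaded from its correct base (1024) and once from the base produced after storage rewrite (1). The inputs mirror the reproducer script in this issue.

```python
import numpy as np

N = 1024
# Offsets (in float32 elements) of each logical buffer inside Merged, per the layout above.
OFFSETS = {"A": 0, "B": 1024, "C": 0}

a = np.linspace(0, 1024, N).astype("float32")
b = np.linspace(1024, 2048, N).astype("float32")

def run(b_base):
    """Simulate C = A + B on a merged buffer, loading B from Merged[b_base : b_base + N]."""
    merged = np.zeros(2 * N, dtype="float32")
    merged[OFFSETS["A"]:OFFSETS["A"] + N] = a
    merged[OFFSETS["B"]:OFFSETS["B"] + N] = b
    # NumPy evaluates the whole right-hand side before storing, like a single
    # vector load/add/store, so writing C over A's region in place is safe.
    merged[OFFSETS["C"]:OFFSETS["C"] + N] = (
        merged[OFFSETS["A"]:OFFSETS["A"] + N] + merged[b_base:b_base + N]
    )
    return merged[OFFSETS["C"]:OFFSETS["C"] + N]

correct = run(OFFSETS["B"])  # B loaded via ramp(1024, 1, 1024)
buggy = run(1)               # B loaded via ramp(1, 1, 1024), as after storage rewrite

print(np.allclose(correct, a + b))  # True
print(np.allclose(buggy, a + b))    # False
```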
It seems that after buffer merging and index remapping, the ramp node is incorrectly rewritten to
`ramp(origin + 1, 1, 1024)` instead of `ramp(origin + 1024, 1, 1024)`.
The related implementation appears to be here:
https://github.com/apache/tvm/blob/main/src/tir/transforms/storage_rewrite.cc#L507-L512
where the offset is divided by the number of lanes of the data type (1024 in our case).
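The arithmetic can be illustrated with a small model. The function below is hypothetical and is not the C++ code in `storage_rewrite.cc`; it only assumes, as the issue describes, that the buffer's offset inside the merged allocation is divided by the dtype's lanes before being added to the index. Because a `ramp` index counts scalar elements, the offset should be added unscaled. The model also suggests why the non-vectorized version is correct: with `lanes == 1` the division has no effect.

```python
def remap_base(base, offset_elems, lanes, divide_by_lanes):
    """Hypothetical model: add a buffer's offset in the merged allocation
    (measured in scalar elements) to the base of a ramp index."""
    if divide_by_lanes:
        # Behaviour described in the issue: the offset is divided by the dtype's
        # lanes, i.e. counted in whole float32x1024 vectors instead of elements.
        assert offset_elems % lanes == 0
        return base + offset_elems // lanes
    # Ramp indices count scalar elements, so the offset is added as-is.
    return base + offset_elems

B_OFFSET = 1024  # B -> Merged[1024: 2048]

for lanes in (1, 1024):
    buggy = remap_base(0, B_OFFSET, lanes, divide_by_lanes=True)
    fixed = remap_base(0, B_OFFSET, lanes, divide_by_lanes=False)
    print(lanes, buggy, fixed)
# 1 1024 1024
# 1024 1 1024
```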
The following code reproduces the problem:
```python
import tvm
from tvm import te
from tvm import testing
import numpy as np
# Register memory info for the tagged storage scope "global_mycategory".
@tvm.register_func("tvm.info.mem.global_mycategory")
def my_mem_info():
    return tvm.ir.make_node(
        "MemoryInfo",
        unit_bits=8,
        max_simd_bits=512,
        max_num_bits=99999,
        head_address=None,
    )
A = te.placeholder([1024], name="A")
B = te.placeholder([1024], name="B")
C = te.compute([1024], lambda i: A[i] + B[i], name="C")
s = te.create_schedule(C.op)
AA = s.cache_read(A, "global_mycategory", readers=[C])
BB = s.cache_read(B, "global_mycategory", readers=[C])
CC = s.cache_write(C, "global_mycategory")
s[CC].vectorize(s[CC].op.axis[0])
print(tvm.lower(s, [A, B, C]))
f = tvm.build(s, [A, B, C], "llvm")
arr_A = tvm.nd.array(np.linspace(0, 1024, 1024).astype("float32"))
arr_B = tvm.nd.array(np.linspace(1024, 2048, 1024).astype("float32"))
arr_C = tvm.nd.array(np.zeros([1024], dtype="float32"))
f(arr_A, arr_B, arr_C)
print(arr_A.asnumpy())
print(arr_B.asnumpy())
print(arr_C.asnumpy())
tvm.testing.assert_allclose(
arr_C.asnumpy(), arr_A.asnumpy() + arr_B.asnumpy(), rtol=1e-5)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
[GitHub] [tvm] vinx13 closed issue #8308: [BUG] Incorrect buffer offset for vectorized computation
vinx13 closed issue #8308:
URL: https://github.com/apache/tvm/issues/8308
[GitHub] [tvm] wrongtest commented on issue #8308: [BUG] Incorrect buffer offset for vectorized computation
wrongtest commented on issue #8308:
URL: https://github.com/apache/tvm/issues/8308#issuecomment-868305492
Hi, there is a fix at https://github.com/apache/tvm/pull/8338. @tqchen @vinx13 @masahi
[GitHub] [tvm] tqchen commented on issue #8308: [BUG] Incorrect buffer offset for vectorized computation
tqchen commented on issue #8308:
URL: https://github.com/apache/tvm/issues/8308#issuecomment-867616887
Thanks @wrongtest for reporting the issue. Can you help suggest a fix and send a PR?
[GitHub] [tvm] tqchen commented on issue #8308: [BUG] Incorrect buffer offset for vectorized computation
tqchen commented on issue #8308:
URL: https://github.com/apache/tvm/issues/8308#issuecomment-867617327
Also cc @vinx13 @masahi, who can help manage the related PRs.