You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/03/23 20:48:16 UTC
[tvm] branch main updated: [FIX] Fix temporary allocation size in threefry (#7709)
This is an automated email from the ASF dual-hosted git repository.
marisa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6f0a656 [FIX] Fix temporary allocation size in threefry (#7709)
6f0a656 is described below
commit 6f0a6561593898053cde051fbb4687eef3adec39
Author: Tristan Konolige <tr...@gmail.com>
AuthorDate: Tue Mar 23 13:47:53 2021 -0700
[FIX] Fix temporary allocation size in threefry (#7709)
* [FIX] Fix temporary allocation size in threefry
* bump sizes
---
python/tvm/topi/random/kernel.py | 2 +-
tests/python/topi/python/test_topi_prng.py | 10 +++++-----
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/python/tvm/topi/random/kernel.py b/python/tvm/topi/random/kernel.py
index 728cd68..a09a5f3 100644
--- a/python/tvm/topi/random/kernel.py
+++ b/python/tvm/topi/random/kernel.py
@@ -141,7 +141,7 @@ def _threefry(
return [x, y]
# temporary buffer for holding the results of _PERMUTATIONS
- tmp = irb.allocate(out_buf.dtype, out_shape, name="tmp", scope="global")
+ tmp = irb.allocate(out_buf.dtype, out_shape * nwords, name="tmp", scope="global")
tmp_offset = 0
# Initialize entire key. It is composed of the original key with one
diff --git a/tests/python/topi/python/test_topi_prng.py b/tests/python/topi/python/test_topi_prng.py
index 649e541..102e93f 100644
--- a/tests/python/topi/python/test_topi_prng.py
+++ b/tests/python/topi/python/test_topi_prng.py
@@ -87,9 +87,9 @@ def test_threefry_generate(target, ctx):
gen = tvm.relay.random.threefry_key(0).data.asnumpy()
# check that we can generate some data
- a, rands = threefry_generate(target, ctx, gen, (100,))
+ a, rands = threefry_generate(target, ctx, gen, (2048,))
assert (
- rands.shape[0] == 100 and len(rands.shape) == 1
+ rands.shape[0] == 2048 and len(rands.shape) == 1
), "Output shape should match requested shape"
# check that gen out does not equal input
@@ -99,13 +99,13 @@ def test_threefry_generate(target, ctx):
gen = np.array(
[0, 0, 0, 0, 0, 0, 0, 2 ** 64 - 2, 1 << 63, 0], dtype="uint64"
) # make counter large
- a, rands = threefry_generate(target, ctx, gen, (100,))
+ a, rands = threefry_generate(target, ctx, gen, (2048,))
assert gen[4] != a[4], "Overflow of counter should trigger path change"
- assert a[7] == 100, "Overflow of counter should still update counter"
+ assert a[7] == 2048, "Overflow of counter should still update counter"
# check generate with path at length limit
gen = np.array([0, 0, 0, 0, 0, 0, 0, 2 ** 64 - 2, 0, 0], dtype="uint64") # make counter large
- a, rands = threefry_generate(target, ctx, gen, (100,))
+ a, rands = threefry_generate(target, ctx, gen, (2048,))
assert (
gen[0:4] != a[0:4]
).any(), "Overflowing counter with no space left in path should change state"