You are viewing a plain-text version of this content. The canonical link for it was a hyperlink on the original page and is not included in this text.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/03/23 20:48:16 UTC

[tvm] branch main updated: [FIX] Fix temporary allocation size in threefry (#7709)

This is an automated email from the ASF dual-hosted git repository.

marisa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6f0a656  [FIX] Fix temporary allocation size in threefry (#7709)
6f0a656 is described below

commit 6f0a6561593898053cde051fbb4687eef3adec39
Author: Tristan Konolige <tr...@gmail.com>
AuthorDate: Tue Mar 23 13:47:53 2021 -0700

    [FIX] Fix temporary allocation size in threefry (#7709)
    
    * [FIX] Fix temporary allocation size in threefry
    
    * bump sizes
---
 python/tvm/topi/random/kernel.py           |  2 +-
 tests/python/topi/python/test_topi_prng.py | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/python/tvm/topi/random/kernel.py b/python/tvm/topi/random/kernel.py
index 728cd68..a09a5f3 100644
--- a/python/tvm/topi/random/kernel.py
+++ b/python/tvm/topi/random/kernel.py
@@ -141,7 +141,7 @@ def _threefry(
         return [x, y]
 
     # temporary buffer for holding the results of _PERMUTATIONS
-    tmp = irb.allocate(out_buf.dtype, out_shape, name="tmp", scope="global")
+    tmp = irb.allocate(out_buf.dtype, out_shape * nwords, name="tmp", scope="global")
     tmp_offset = 0
 
     # Initialize entire key. It is composed of the original key with one
diff --git a/tests/python/topi/python/test_topi_prng.py b/tests/python/topi/python/test_topi_prng.py
index 649e541..102e93f 100644
--- a/tests/python/topi/python/test_topi_prng.py
+++ b/tests/python/topi/python/test_topi_prng.py
@@ -87,9 +87,9 @@ def test_threefry_generate(target, ctx):
     gen = tvm.relay.random.threefry_key(0).data.asnumpy()
 
     # check that we can generate some data
-    a, rands = threefry_generate(target, ctx, gen, (100,))
+    a, rands = threefry_generate(target, ctx, gen, (2048,))
     assert (
-        rands.shape[0] == 100 and len(rands.shape) == 1
+        rands.shape[0] == 2048 and len(rands.shape) == 1
     ), "Output shape should match requested shape"
 
     # check that gen out does not equal input
@@ -99,13 +99,13 @@ def test_threefry_generate(target, ctx):
     gen = np.array(
         [0, 0, 0, 0, 0, 0, 0, 2 ** 64 - 2, 1 << 63, 0], dtype="uint64"
     )  # make counter large
-    a, rands = threefry_generate(target, ctx, gen, (100,))
+    a, rands = threefry_generate(target, ctx, gen, (2048,))
     assert gen[4] != a[4], "Overflow of counter should trigger path change"
-    assert a[7] == 100, "Overflow of counter should still update counter"
+    assert a[7] == 2048, "Overflow of counter should still update counter"
 
     # check generate with path at length limit
     gen = np.array([0, 0, 0, 0, 0, 0, 0, 2 ** 64 - 2, 0, 0], dtype="uint64")  # make counter large
-    a, rands = threefry_generate(target, ctx, gen, (100,))
+    a, rands = threefry_generate(target, ctx, gen, (2048,))
     assert (
         gen[0:4] != a[0:4]
     ).any(), "Overflowing counter with no space left in path should change state"