Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/05/14 04:17:11 UTC

[GitHub] [tvm] zhuzilin opened a new pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

zhuzilin opened a new pull request #8041:
URL: https://github.com/apache/tvm/pull/8041


   This PR adds a uniform distribution generator using the threefry PRNG introduced in #7083. We need uniform to implement training-phase dropout, following this roadmap:
   
   ```
   uniform -> bernoulli -> dropout
   ```
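The roadmap above can be sketched in a few lines of plain Python. This is only an illustration of how each step builds on the previous one: the helper names `bernoulli` and `dropout` and the inverted-dropout scaling are assumptions made for the sketch, not the ops added by this PR or any follow-up, and the uniform samples are passed in as a list rather than drawn from threefry.

```python
def bernoulli(uniform_samples, p):
    """Turn samples from [0, 1) into 0/1 draws that are 1 with probability p."""
    return [1 if u < p else 0 for u in uniform_samples]


def dropout(values, uniform_samples, rate):
    """Zero each value with probability `rate` and rescale the survivors.

    Assumes 0 <= rate < 1 ("inverted dropout": kept values are divided by
    the keep probability so the expected value is unchanged).
    """
    keep = 1.0 - rate
    mask = bernoulli(uniform_samples, keep)
    return [v * m / keep for v, m in zip(values, mask)]


# Fixed "uniform" samples so the result is reproducible.
samples = [0.1, 0.9, 0.4, 0.6]
dropped = dropout([1.0, 2.0, 3.0, 4.0], samples, rate=0.5)
```

With these fixed samples the second and fourth values are dropped and the other two are doubled.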
   
   The algorithm is basically the same as the one used in jax: the random bits generated by `threefry_generate` are used as the fraction section of a float32 or float64. To be specific, I use the last 23 bits of the random bits for float32 and the last 52 for float64. There is one difference from the jax implementation. In jax, they use a bitcast to turn the uint into a float:
   
   ```python
   # jax implementation
   def _uniform(key, shape, dtype, minval, maxval) -> jnp.ndarray:
     ...
     bits = _random_bits(key, nbits, shape)
   
     # The strategy here is to randomize only the mantissa bits with an exponent of
     # 1 (after applying the bias), then shift and scale to the desired range. The
     # bit-level transformation we use relies on Numpy and XLA having bit-for-bit
     # equivalent float representations, which might not be true on all platforms.
     float_bits = lax.bitwise_or(
         lax.shift_right_logical(bits, np.array(nbits - nmant, lax.dtype(bits))),
         np.array(1., dtype).view(_UINT_DTYPES[nbits]))
     floats = lax.bitcast_convert_type(float_bits, dtype) - np.array(1., dtype)
     return lax.max(
         minval,
         lax.reshape(floats * (maxval - minval) + minval, shape.positional))
   ```
   
   However, since I haven't found a bitcast in te or topi, I use a division to convert the type, which may be slower:
   
   ```python
       def uniform_scalar(bits):
           bits = bits >> (nbits - nfraction)
           standard_uniform = bits.astype(out_dtype) / float(1 << nfraction)
           return standard_uniform
   ```
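To check that the two conversions agree, here is a minimal sketch in plain Python, with `struct` standing in for a bitcast. The function names are made up for illustration, only float32 is covered, and the 23 bits are taken from the top of a 64-bit word, as `bits >> (nbits - nfraction)` does in the snippet above.

```python
import struct


def uniform_by_divide(bits64, nfraction=23):
    """Keep the top `nfraction` bits and scale by 2**-nfraction."""
    top = bits64 >> (64 - nfraction)
    return top / float(1 << nfraction)


def uniform_by_bitcast(bits64):
    """float32 only: OR 23 random bits into the mantissa of 1.0,
    reinterpret the result as a float in [1, 2), then subtract 1."""
    mantissa = bits64 >> (64 - 23)
    one_bits = struct.unpack("<I", struct.pack("<f", 1.0))[0]  # bit pattern of 1.0f
    as_float = struct.unpack("<f", struct.pack("<I", one_bits | mantissa))[0]
    return as_float - 1.0


# All-zeros, top bit only, and all-ones 64-bit words.
words = [0, 1 << 63, (1 << 64) - 1]
by_divide = [uniform_by_divide(w) for w in words]
by_bitcast = [uniform_by_bitcast(w) for w in words]
```

Under these assumptions both functions return k / 2**23 for the same 23-bit integer k, so the division would cost speed rather than precision.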
   
   Thank you for taking the time to review this PR. I may not be familiar enough with the tvm codebase yet, so I apologize if I've broken any community conventions; I'd be happy to fix them :).
   
   Gently ping @tqchen @altanh @tkonolige 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r634155870



##########
File path: python/tvm/relay/op/strategy/generic.py
##########
@@ -1495,6 +1495,28 @@ def threefry_split_strategy(attrs, inputs, out_type, target):
     return strategy
 
 
+# uniform
+def wrap_compute_uniform(topi_compute):
+    """Wrap uniform topi compute"""
+
+    def _compute_uniform(attrs, inputs, _):
+        return list(topi_compute(inputs[0], inputs[1], inputs[2], attrs.out_shape, attrs.out_dtype))
+
+    return _compute_uniform
+
+
+@override_native_generic_func("uniform_strategy")
+def uniform_strategy(attrs, inputs, out_type, target):
+    """uniform generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_uniform(topi.random.uniform),
+        wrap_topi_schedule(topi.generic.schedule_extern),

Review comment:
       For me, the main goal right now is to add the missing pieces needed to support training, so I'm not so sure about the performance... I think we could optimize gradually once we have a complete training workflow~







[GitHub] [tvm] tkonolige commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
tkonolige commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632637754



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]

Review comment:
       You could probably say `ThreefryKey` instead of `Tensor[10, uint64]`







[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632430646



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]
+        Generator state. Can be create with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.
+
+    out_dtype : str
+        The output dtype.
+
+    Returns
+    -------
+    out : Tensor[out_shape, out_dtype]
+        Tensor of random numbers with shape `out_shape` and type `out_dtype`.
+    """
+    _, random_bits = threefry_generate(gen, out_shape)
+    nbits = 64
+    if out_dtype == "float32":
+        nfraction = 23
+    elif out_dtype == "float64":
+        nfraction = 52
+

Review comment:
       Thank you for the catch! Could you tell me how to restrict `out_dtype` in `UniformAttrs`? It would also be great if there are any example ops that I can learn from. Thank you~







[GitHub] [tvm] altanh commented on pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
altanh commented on pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#issuecomment-841437945


   Thanks for this PR! I will be reading it soon, and just wanted to point you to a branch I worked on a while ago where I hacked a uniform op + dropout support: https://github.com/altanh/tvm/commits/prng (just in case it might be useful for you to check and compare).
   
   > However, as I haven't found the bitcast in te or topi, I use a divide to cast the type, which may be slower:
   
   Perhaps this is the operation you're looking for? https://github.com/altanh/tvm/blob/2d9ac7710ab055d4f20e8b5a0a3580836723efac/python/tvm/topi/generic/algorithm.py#L465
   
   Thanks!
   





[GitHub] [tvm] FrozenGene commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632467451



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]

Review comment:
       If so, let's add a comment describing what the 10 means.







[GitHub] [tvm] FrozenGene commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632280793



##########
File path: tests/python/topi/python/test_topi_prng.py
##########
@@ -43,6 +43,23 @@ def threefry_generate(target, dev, gen, size):
     return out_gen.asnumpy(), rands.asnumpy()
 
 
+def uniform(target, dev, gen, low, high, size, dtype):
+    gen_placeholder = tvm.te.placeholder(gen.shape, name="gen", dtype="uint64")
+    low_placeholder = tvm.te.placeholder(low.shape, name="low", dtype=dtype)
+    high_placeholder = tvm.te.placeholder(high.shape, name="high", dtype=dtype)
+    print(low_placeholder)
+    print(high_placeholder)
+    out_placeholder = tvm.topi.random.uniform(
+        gen_placeholder, low_placeholder, high_placeholder, size, dtype
+    )
+    print(out_placeholder)

Review comment:
       ditto

##########
File path: tests/python/topi/python/test_topi_prng.py
##########
@@ -43,6 +43,23 @@ def threefry_generate(target, dev, gen, size):
     return out_gen.asnumpy(), rands.asnumpy()
 
 
+def uniform(target, dev, gen, low, high, size, dtype):
+    gen_placeholder = tvm.te.placeholder(gen.shape, name="gen", dtype="uint64")
+    low_placeholder = tvm.te.placeholder(low.shape, name="low", dtype=dtype)
+    high_placeholder = tvm.te.placeholder(high.shape, name="high", dtype=dtype)
+    print(low_placeholder)
+    print(high_placeholder)

Review comment:
       remove it







[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r634892173



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,67 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : ThreefryKey
+        Generator state. Can be create with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.
+
+    out_dtype : str
+        The output dtype.
+
+    Returns
+    -------
+    new_gen : ThreefryKey
+        New generator state that is distinct from `gen`.
+
+    out : Tensor[out_shape, out_dtype]
+        Tensor of random numbers with shape `out_shape` and type `out_dtype`.
+    """
+    new_gen, random_bits = threefry_generate(gen, out_shape)
+    assert out_dtype in ("float32", "float64")

Review comment:
       Added.

##########
File path: src/relay/op/random/kernel.cc
##########
@@ -85,5 +85,52 @@ RELAY_REGISTER_OP("random.threefry_split")
     .add_argument("key", "Tensor", "Input Threefry key")
     .add_type_rel("ThreefrySplit", ThreefrySplitRel);
 
+TVM_REGISTER_NODE_TYPE(UniformAttrs);
+
+bool UniformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  const UniformAttrs* param = attrs.as<UniformAttrs>();
+  ICHECK_EQ(types.size(), 4) << "Uniform should have three input and one output";

Review comment:
       Fixed.







[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632983593



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]
+        Generator state. Can be create with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.

Review comment:
       Sorry, I've just rethought this problem. There should not be any restriction on the output shape... We could change the input restriction of `threefry_generate` in another PR.
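One possible workaround on the caller's side, while the underlying generator still requires a multiple of 4, is to round the request up and discard the surplus. The sketch below is plain Python; `generate_block` is a hypothetical stand-in for a generator that can only yield values in groups of 4, and this is not the approach the thread settles on.

```python
import itertools


def padded_length(n, block=4):
    """Smallest multiple of `block` that is >= n."""
    return -(-n // block) * block


def generate_block(counter):
    """Hypothetical generator that can only emit 4 values at a time."""
    return [counter * 4 + i for i in range(4)]


def generate(n):
    """Produce exactly n values by over-generating and truncating."""
    blocks = padded_length(n) // 4
    values = itertools.chain.from_iterable(generate_block(c) for c in range(blocks))
    return list(itertools.islice(values, n))
```

The cost of this approach is up to 3 wasted values per call, which is why lifting the restriction inside the generator itself is the cleaner fix.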







[GitHub] [tvm] altanh commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
altanh commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r633781691



##########
File path: src/relay/op/random/kernel.cc
##########
@@ -85,5 +85,47 @@ RELAY_REGISTER_OP("random.threefry_split")
     .add_argument("key", "Tensor", "Input Threefry key")
     .add_type_rel("ThreefrySplit", ThreefrySplitRel);
 
+TVM_REGISTER_NODE_TYPE(UniformAttrs);
+
+bool UniformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  const UniformAttrs* param = attrs.as<UniformAttrs>();
+  ICHECK_EQ(types.size(), 4) << "Uniform should have three input and one output";
+
+  std::vector<IndexExpr> oshape;
+  for (auto& x : param->out_shape) {
+    oshape.push_back(x);
+  }
+  DataType out_dtype = param->out_dtype;
+  // we are supporting float32 and float64 at the moment.
+  ICHECK(out_dtype.is_float() && (out_dtype.bits() == 32 || out_dtype.bits() == 64));

Review comment:
       Could you use the Diagnostics API for erroring out this check?
   
   See for example: https://github.com/apache/tvm/blob/main/src/relay/op/nn/nn.cc#L65

##########
File path: python/tvm/relay/op/strategy/generic.py
##########
@@ -1495,6 +1495,28 @@ def threefry_split_strategy(attrs, inputs, out_type, target):
     return strategy
 
 
+# uniform
+def wrap_compute_uniform(topi_compute):
+    """Wrap uniform topi compute"""
+
+    def _compute_uniform(attrs, inputs, _):
+        return list(topi_compute(inputs[0], inputs[1], inputs[2], attrs.out_shape, attrs.out_dtype))
+
+    return _compute_uniform
+
+
+@override_native_generic_func("uniform_strategy")
+def uniform_strategy(attrs, inputs, out_type, target):
+    """uniform generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_uniform(topi.random.uniform),
+        wrap_topi_schedule(topi.generic.schedule_extern),

Review comment:
       Do we need to specialize some schedules for this op? That could be a follow-up PR; I'm just wondering how the perf looks with only the generic schedule.







[GitHub] [tvm] FrozenGene commented on pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#issuecomment-845621917


   Thanks @zhuzilin @altanh @tkonolige merged now





[GitHub] [tvm] FrozenGene edited a comment on pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
FrozenGene edited a comment on pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#issuecomment-841919711


   > @FrozenGene Thank you for the clue! However, I haven't found how to restrict the dtype attributes in `Conv2DRel` or `Conv2DAttrs`... Should I add the type restriction to `UniformAttrs`, or raise an error in `MakeUniform` and `UniformRel`?
   
   Suggest `UniformRel`
   





[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r634864945



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]
+        Generator state. Can be create with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.

Review comment:
       @tkonolige Sure, I will submit one. Could you tell me what kind of update to the key `tmp` we need before the second `_threefry`? I can only think of updating the increment counter (`tmp[7]`).







[GitHub] [tvm] tkonolige commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
tkonolige commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r635397639



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]
+        Generator state. Can be create with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.

Review comment:
       You'll need to update the counter buffer to be equal to out_len







[GitHub] [tvm] tkonolige commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
tkonolige commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r633744201



##########
File path: python/tvm/relay/op/random/kernel.py
##########
@@ -132,3 +132,55 @@ def foo(key):
         :py:func:`threefry_generate`.
     """
     return _make.threefry_split(key)
+
+
+def uniform(key, shape, dtype="float32", low=0.0, high=1.0):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Example
+    -------
+
+    .. code-block:: python
+
+        key = threefry_key(0)
+        random_values = uniform(key, (100,), low=0, high=10)

Review comment:
       ```suggestion
           key = threefry_key(0)
           key, random_values = uniform(key, (100,), low=0, high=10)
   ```
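The reason for returning the key alongside the samples can be shown with a toy keyed generator. This is a sketch only: `toy_uniform` is a hypothetical stand-in built on `hashlib`, not threefry and not the TVM API, and it exists just to show that reusing a key repeats the numbers while threading the returned key does not.

```python
import hashlib


def _word(key, tag):
    """Derive a 64-bit integer from (key, tag); toy construction for illustration."""
    digest = hashlib.sha256(f"{key}:{tag}".encode()).digest()
    return int.from_bytes(digest[:8], "little")


def toy_uniform(key, n):
    """Hypothetical keyed generator: returns (new_key, n samples in [0, 1))."""
    samples = [(_word(key, i) >> 11) / float(1 << 53) for i in range(n)]
    new_key = _word(key, "next")
    return new_key, samples


key = 0
key1, first = toy_uniform(key, 4)
_, repeated = toy_uniform(key, 4)   # same key reused: identical numbers
_, second = toy_uniform(key1, 4)    # threaded key: a fresh stream
```

Writing `key, values = uniform(key, ...)` in the docstring example encourages the threaded pattern by default.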

##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]
+        Generator state. Can be create with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.
+
+    out_dtype : str
+        The output dtype.
+
+    Returns
+    -------
+    out : Tensor[out_shape, out_dtype]
+        Tensor of random numbers with shape `out_shape` and type `out_dtype`.
+    """
+    _, random_bits = threefry_generate(gen, out_shape)
+    nbits = 64
+    if out_dtype == "float32":
+        nfraction = 23
+    elif out_dtype == "float64":
+        nfraction = 52
+

Review comment:
       I would add a check here anyways just to be safe.







[GitHub] [tvm] zhuzilin edited a comment on pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin edited a comment on pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#issuecomment-842966025


   @FrozenGene @altanh @tkonolige I've updated the PR based on the reviews. Could you take another look? Thank you~





[GitHub] [tvm] FrozenGene commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632401575



##########
File path: tests/python/topi/python/test_topi_prng.py
##########
@@ -118,7 +132,22 @@ def test_threefry_wrapping(target, dev):
     ), f"{target} does not suppport wrapping unsigned integer arithmetic"
 
 
+@tvm.testing.parametrize_targets
+def test_uniform(target, dev):
+    gen = tvm.relay.random.threefry_key(0).data.asnumpy()
+    m = 1024
+    n = 1024
+    dtype = "float32"

Review comment:
       ditto

##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]

Review comment:
       What is the meaning of `10`?

##########
File path: tests/python/relay/test_prng.py
##########
@@ -103,6 +103,19 @@ def test_threefry_split_infer():
     assert tvm.ir.structural_equal(f.ret_type, expected_type)
 
 
+def test_uniform_infer():
+    oshape = (12,)
+    odtype = "float32"

Review comment:
       Should cover more types, for example the `float64` you have implemented.
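A dtype-parametrized check might look like the following sketch. It is plain Python rather than a TVM test: `uniform_from_bits` and `check_uniform` are illustrative names, the 64-bit words come from Python's `random` module instead of threefry, and because Python floats are doubles the float32 case only models the 23-bit fraction, not float32 rounding.

```python
import random

FRACTION_BITS = {"float32": 23, "float64": 52}


def uniform_from_bits(bits64, dtype, low, high):
    """Map one 64-bit word to [low, high) via the divide construction."""
    nfraction = FRACTION_BITS[dtype]
    unit = (bits64 >> (64 - nfraction)) / float(1 << nfraction)
    return low + unit * (high - low)


def check_uniform(dtype, low=0.0, high=10.0, n=1000, seed=0):
    """Return True if every sample for `dtype` lies in [low, high)."""
    rng = random.Random(seed)
    # Include the all-zeros and all-ones words so both ends of the range are exercised.
    words = [0, (1 << 64) - 1] + [rng.getrandbits(64) for _ in range(n)]
    samples = [uniform_from_bits(w, dtype, low, high) for w in words]
    return all(low <= s < high for s in samples)


results = {dtype: check_uniform(dtype) for dtype in FRACTION_BITS}
```

Looping over the supported dtypes keeps the float64 path from going untested when only float32 is exercised by default.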

##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]
+        Generator state. Can be create with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.

Review comment:
       What is the reason for `product must be a multiple of 4`?

##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]
+        Generator state. Can be create with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.
+
+    out_dtype : str
+        The output dtype.
+
+    Returns
+    -------
+    out : Tensor[out_shape, out_dtype]
+        Tensor of random numbers with shape `out_shape` and type `out_dtype`.
+    """
+    _, random_bits = threefry_generate(gen, out_shape)
+    nbits = 64
+    if out_dtype == "float32":
+        nfraction = 23
+    elif out_dtype == "float64":
+        nfraction = 52
+

Review comment:
       What happens if `out_dtype` is some other type, for example `float16`, or even `int8` / `int32`?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] FrozenGene merged pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
FrozenGene merged pull request #8041:
URL: https://github.com/apache/tvm/pull/8041


   



[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632282508



##########
File path: tests/python/topi/python/test_topi_prng.py
##########
@@ -43,6 +43,23 @@ def threefry_generate(target, dev, gen, size):
     return out_gen.asnumpy(), rands.asnumpy()
 
 
+def uniform(target, dev, gen, low, high, size, dtype):
+    gen_placeholder = tvm.te.placeholder(gen.shape, name="gen", dtype="uint64")
+    low_placeholder = tvm.te.placeholder(low.shape, name="low", dtype=dtype)
+    high_placeholder = tvm.te.placeholder(high.shape, name="high", dtype=dtype)
+    print(low_placeholder)
+    print(high_placeholder)
+    out_placeholder = tvm.topi.random.uniform(
+        gen_placeholder, low_placeholder, high_placeholder, size, dtype
+    )
+    print(out_placeholder)

Review comment:
       Sorry for that :disappointed_relieved: Removed.





[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632978098



##########
File path: src/relay/op/random/kernel.cc
##########
@@ -85,5 +85,46 @@ RELAY_REGISTER_OP("random.threefry_split")
     .add_argument("key", "Tensor", "Input Threefry key")
     .add_type_rel("ThreefrySplit", ThreefrySplitRel);
 
+TVM_REGISTER_NODE_TYPE(UniformAttrs);
+
+bool UniformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  const UniformAttrs* param = attrs.as<UniformAttrs>();
+  ICHECK_EQ(types.size(), 4) << "ThreefryGenerate should have one input and one output";

Review comment:
       Fixed.





[GitHub] [tvm] altanh commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
altanh commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r634803831



##########
File path: tests/python/topi/python/test_topi_prng.py
##########
@@ -118,7 +135,24 @@ def test_threefry_wrapping(target, dev):
     ), f"{target} does not suppport wrapping unsigned integer arithmetic"
 
 
+@tvm.testing.parametrize_targets
+def test_uniform(target, dev):
+    gen = tvm.relay.random.threefry_key(0).data.asnumpy()
+    m = 1024
+    n = 1024
+    dtypes = ["float32", "float64"]
+    for dtype in dtypes:
+        low = np.array(5.0, dtype=dtype)
+        high = np.array(10.0, dtype=dtype)
+        new_gen, rands = uniform(target, dev, gen, low, high, (m, n), dtype)
+        assert (gen != new_gen).any()
+        assert abs(np.mean(rands) - 7.5) < 1e-1
+        assert abs(np.min(rands) - 5.0) < 1e-3
+        assert abs(np.max(rands) - 10.0) < 1e-3

Review comment:
       `<=` check?

##########
File path: tests/python/topi/python/test_topi_prng.py
##########
@@ -118,7 +135,24 @@ def test_threefry_wrapping(target, dev):
     ), f"{target} does not suppport wrapping unsigned integer arithmetic"
 
 
+@tvm.testing.parametrize_targets
+def test_uniform(target, dev):
+    gen = tvm.relay.random.threefry_key(0).data.asnumpy()
+    m = 1024
+    n = 1024
+    dtypes = ["float32", "float64"]
+    for dtype in dtypes:
+        low = np.array(5.0, dtype=dtype)
+        high = np.array(10.0, dtype=dtype)
+        new_gen, rands = uniform(target, dev, gen, low, high, (m, n), dtype)
+        assert (gen != new_gen).any()
+        assert abs(np.mean(rands) - 7.5) < 1e-1
+        assert abs(np.min(rands) - 5.0) < 1e-3

Review comment:
       could we change this to a `>=` check? 





[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632978037



##########
File path: python/tvm/relay/op/random/kernel.py
##########
@@ -132,3 +132,52 @@ def foo(key):
         :py:func:`threefry_generate`.
     """
     return _make.threefry_split(key)
+
+
+def uniform(key, shape, dtype="float32", low=0.0, high=1.0):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Example
+    -------
+
+    .. code-block:: python
+
+        key = threefry_key(0)
+        random_values = uniform(key, (100,), low=0, high=10)
+
+    Parameters
+    ----------
+    key : relay.Expr
+        key that uniquely determines the random values. Multiple uses with the
+        same generator will generate the same random values. This generator should be
+        treated as an opaque pointer. You can create one from calling
+        :py:func:`threefry_key`, :py:func:`threefry_split`, or
+        :py:func:`threefry_generate`. **Do not use this generator again after calling
+        this function.**
+
+    shape : Sequence[int]
+        Desired outputs shape of random numbers.
+
+    dtype : str
+        Desired outputs type of random numbers.
+
+    low : float or relay.Expr, optional
+        Lower bound of the uniform distribution.
+
+    high : float or relay.Expr, optional
+        Upper bound of the uniform distribution.
+
+    Returns

Review comment:
       Agree~ I've updated the output.





[GitHub] [tvm] altanh commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
altanh commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r633750359



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]
+        Generator state. Can be create with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.

Review comment:
       Do you mind sending a PR to update the `threefry_generate` output? Or rather, what approach do you have in mind? I tried to avoid this problem by truncating the output buffer, but that required an extra copy. I wonder if you have something else in mind.
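
The truncation workaround mentioned in this comment can be sketched in plain Python as follows. This is only an illustration under stated assumptions: `generate_blocks` is a hypothetical stand-in for a generator that emits words in groups of 4 (it is not `threefry_generate`), and `generate_any_shape` is a hypothetical wrapper name.

```python
import math
import random


def generate_blocks(rng, n_words):
    """Hypothetical stand-in for a generator that only emits multiples of 4 words."""
    if n_words % 4 != 0:
        raise ValueError(f"n_words must be a multiple of 4, got {n_words}")
    return [rng.getrandbits(64) for _ in range(n_words)]


def generate_any_shape(rng, shape):
    """Return prod(shape) random 64-bit words for an arbitrary shape."""
    count = math.prod(shape)
    padded = -(-count // 4) * 4  # round up to the next multiple of 4
    words = generate_blocks(rng, padded)
    return words[:count]  # the slice is the extra copy mentioned above


words = generate_any_shape(random.Random(0), (3, 5))
print(len(words))  # 15
```

The slice at the end is where the extra copy comes from: the padded buffer is produced first and then cut down to the requested size.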





[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632282652



##########
File path: tests/python/topi/python/test_topi_prng.py
##########
@@ -43,6 +43,23 @@ def threefry_generate(target, dev, gen, size):
     return out_gen.asnumpy(), rands.asnumpy()
 
 
+def uniform(target, dev, gen, low, high, size, dtype):
+    gen_placeholder = tvm.te.placeholder(gen.shape, name="gen", dtype="uint64")
+    low_placeholder = tvm.te.placeholder(low.shape, name="low", dtype=dtype)
+    high_placeholder = tvm.te.placeholder(high.shape, name="high", dtype=dtype)
+    print(low_placeholder)
+    print(high_placeholder)

Review comment:
       Removed.





[GitHub] [tvm] FrozenGene commented on pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#issuecomment-841919711


   > @FrozenGene Thank you for the clue! However, I haven't found how to restrict the dtype attributes in `Conv2DRel` or `Conv2DAttrs`... Should I add the type restriction to the `UniformAttrs`, or raise an error in `MakeUniform` and `UniformRel`?
   
   I suggest `UniformRel`.
   



[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r634892900



##########
File path: tests/python/topi/python/test_topi_prng.py
##########
@@ -118,7 +135,24 @@ def test_threefry_wrapping(target, dev):
     ), f"{target} does not suppport wrapping unsigned integer arithmetic"
 
 
+@tvm.testing.parametrize_targets
+def test_uniform(target, dev):
+    gen = tvm.relay.random.threefry_key(0).data.asnumpy()
+    m = 1024
+    n = 1024
+    dtypes = ["float32", "float64"]
+    for dtype in dtypes:
+        low = np.array(5.0, dtype=dtype)
+        high = np.array(10.0, dtype=dtype)
+        new_gen, rands = uniform(target, dev, gen, low, high, (m, n), dtype)
+        assert (gen != new_gen).any()
+        assert abs(np.mean(rands) - 7.5) < 1e-1
+        assert abs(np.min(rands) - 5.0) < 1e-3

Review comment:
       Changed to `>=` and `<=`. They should pass all the time.
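
A bounds check of this kind might look like the sketch below. It is an assumption-laden illustration rather than the test in the PR: `check_uniform_samples` is a hypothetical helper, and Python's `random.Random` stands in for the TVM kernel so the snippet runs on its own.

```python
import random


def check_uniform_samples(samples, low, high, mean_tol=1e-1):
    """Check range and mean of samples expected to be uniform on [low, high)."""
    # Inclusive comparisons hold for every draw, unlike a closeness check on
    # the observed min and max, which depends on how many samples were drawn.
    assert min(samples) >= low
    assert max(samples) <= high
    mean = sum(samples) / len(samples)
    assert abs(mean - (low + high) / 2) < mean_tol
    return True


rng = random.Random(0)  # stand-in generator, not the threefry kernel
samples = [5.0 + 5.0 * rng.random() for _ in range(100_000)]
check_uniform_samples(samples, 5.0, 10.0)
```

The upper bound uses `<=` because floating-point rounding in `low + (high - low) * r` can land exactly on `high` even when `r < 1`.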





[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r634156686



##########
File path: python/tvm/relay/op/random/kernel.py
##########
@@ -132,3 +132,55 @@ def foo(key):
         :py:func:`threefry_generate`.
     """
     return _make.threefry_split(key)
+
+
+def uniform(key, shape, dtype="float32", low=0.0, high=1.0):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Example
+    -------
+
+    .. code-block:: python
+
+        key = threefry_key(0)
+        random_values = uniform(key, (100,), low=0, high=10)

Review comment:
       Fixed.





[GitHub] [tvm] zhuzilin commented on pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#issuecomment-843724437


   @FrozenGene Could you take another look at this PR? Thank you~



[GitHub] [tvm] zhuzilin commented on pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#issuecomment-845608860


   @FrozenGene Could you have another look at this PR? Thank you!



[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632983593



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]
+        Generator state. Can be create with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.

Review comment:
       Sorry, I just rethought this problem. There should not be any restriction on the output shape... I will fix this soon.





[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r634153432



##########
File path: src/relay/op/random/kernel.cc
##########
@@ -85,5 +85,47 @@ RELAY_REGISTER_OP("random.threefry_split")
     .add_argument("key", "Tensor", "Input Threefry key")
     .add_type_rel("ThreefrySplit", ThreefrySplitRel);
 
+TVM_REGISTER_NODE_TYPE(UniformAttrs);
+
+bool UniformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  const UniformAttrs* param = attrs.as<UniformAttrs>();
+  ICHECK_EQ(types.size(), 4) << "Uniform should have three input and one output";
+
+  std::vector<IndexExpr> oshape;
+  for (auto& x : param->out_shape) {
+    oshape.push_back(x);
+  }
+  DataType out_dtype = param->out_dtype;
+  // we are supporting float32 and float64 at the moment.
+  ICHECK(out_dtype.is_float() && (out_dtype.bits() == 32 || out_dtype.bits() == 64));

Review comment:
       Updated.





[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632983593



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]
+        Generator state. Can be create with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.

Review comment:
       Sorry, I just rethought this problem. There should not be any restriction on the output shape... We could change the input restriction of `threefry_generate` in another PR.





[GitHub] [tvm] FrozenGene commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632479707



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]
+        Generator state. Can be create with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.
+
+    out_dtype : str
+        The output dtype.
+
+    Returns
+    -------
+    out : Tensor[out_shape, out_dtype]
+        Tensor of random numbers with shape `out_shape` and type `out_dtype`.
+    """
+    _, random_bits = threefry_generate(gen, out_shape)
+    nbits = 64
+    if out_dtype == "float32":
+        nfraction = 23
+    elif out_dtype == "float64":
+        nfraction = 52
+

Review comment:
       I think there are 2 options: 1. Refer to `Conv2DRel`. 2. You could restrict the type here and raise an exception. I prefer option 1.





[GitHub] [tvm] tkonolige commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
tkonolige commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r634782901



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,67 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : ThreefryKey
+        Generator state. Can be create with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.
+
+    out_dtype : str
+        The output dtype.
+
+    Returns
+    -------
+    new_gen : ThreefryKey
+        New generator state that is distinct from `gen`.
+
+    out : Tensor[out_shape, out_dtype]
+        Tensor of random numbers with shape `out_shape` and type `out_dtype`.
+    """
+    new_gen, random_bits = threefry_generate(gen, out_shape)
+    assert out_dtype in ("float32", "float64")

Review comment:
       Can you add an error message?
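
One way to attach a message to that check is sketched below. The helper name `fraction_bits` and the wording of the message are assumptions for illustration; the fraction widths 23 and 52 come from the earlier revision of the diff.

```python
def fraction_bits(out_dtype):
    """Return the number of fraction (mantissa) bits for a supported dtype."""
    supported = {"float32": 23, "float64": 52}
    # The message names both the accepted dtypes and the offending value.
    assert out_dtype in supported, (
        f"uniform only supports {sorted(supported)}, but got out_dtype={out_dtype!r}"
    )
    return supported[out_dtype]


print(fraction_bits("float32"))  # 23
```

With a message like this, a call such as `fraction_bits("float16")` fails with text that says which dtype was rejected instead of a bare `AssertionError`.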

##########
File path: src/relay/op/random/kernel.cc
##########
@@ -85,5 +85,52 @@ RELAY_REGISTER_OP("random.threefry_split")
     .add_argument("key", "Tensor", "Input Threefry key")
     .add_type_rel("ThreefrySplit", ThreefrySplitRel);
 
+TVM_REGISTER_NODE_TYPE(UniformAttrs);
+
+bool UniformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  const UniformAttrs* param = attrs.as<UniformAttrs>();
+  ICHECK_EQ(types.size(), 4) << "Uniform should have three input and one output";

Review comment:
       ```suggestion
     ICHECK_EQ(types.size(), 4) << "Uniform should have three inputs and one output";
   ```





[GitHub] [tvm] zhuzilin edited a comment on pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin edited a comment on pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#issuecomment-841694188


   @FrozenGene @tkonolige @altanh Thank you for your reviews. I've updated this PR based on them.
   
   > I think there are 2 options: 1. Ref Conv2DRel. 2. You could restrict the type here and raise exception. I prefer option 1.
   
   @FrozenGene Thank you for the clue! However, I haven't found how to restrict the dtype attributes in `Conv2DRel` or `Conv2DAttrs`... Should I add the type restriction to the `UniformAttrs`, or raise an error in `MakeUniform` and `UniformRel`?
   
   > How well does this approach work when we have a large range (high - low)? It seems like we would be losing a lot of potential randomness.
   
   @tkonolige Because this approach uses only the fraction bits to represent the float, randomness is lost for every range: at least `(2^nexp-1) / 2^nexp` of the floats are lost (`nexp` stands for the number of exponent bits). However, it is a little tricky to use all 64 random bits to represent a uniformly distributed float... Do you have any ideas on that?
   
   > Perhaps this is the operation you're looking for? https://github.com/altanh/tvm/blob/2d9ac7710ab055d4f20e8b5a0a3580836723efac/python/tvm/topi/generic/algorithm.py#L465
   
   @altanh Thank you for your references! The `reinterpret` is exactly what I was looking for. I've updated the algorithm and right now it is the same as the one used in jax.
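
The bit-level approach discussed in this comment can be sketched in plain Python for `float32`. This is a minimal illustration that follows the jax snippet quoted in the PR description, not the TVM kernel: the function names are hypothetical, and `struct` stands in for the `reinterpret` operation.

```python
import struct


def bits_to_unit_float32(bits):
    """Map 32 random bits to a float in [0, 1) using only the 23 fraction bits."""
    nbits, nmant = 32, 23
    # Bit pattern of float32 1.0: sign 0, biased exponent 127, fraction 0.
    one_bits = struct.unpack("<I", struct.pack("<f", 1.0))[0]
    # Keep the top 23 random bits as the fraction, giving a value in [1, 2).
    float_bits = (bits >> (nbits - nmant)) | one_bits
    value = struct.unpack("<f", struct.pack("<I", float_bits))[0]
    return value - 1.0


def uniform_from_bits(bits, low, high):
    """Shift and scale a unit sample into [low, high)."""
    return bits_to_unit_float32(bits) * (high - low) + low


print(bits_to_unit_float32(0))  # 0.0
```

Because only the 23 fraction bits vary, the unit sample takes at most 2^23 distinct values no matter how wide `high - low` is, which is the loss of randomness raised in the review.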



[GitHub] [tvm] tkonolige commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
tkonolige commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632639897



##########
File path: python/tvm/relay/op/random/kernel.py
##########
@@ -132,3 +132,52 @@ def foo(key):
         :py:func:`threefry_generate`.
     """
     return _make.threefry_split(key)
+
+
+def uniform(key, shape, dtype="float32", low=0.0, high=1.0):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Example
+    -------
+
+    .. code-block:: python
+
+        key = threefry_key(0)
+        random_values = uniform(key, (100,), low=0, high=10)
+
+    Parameters
+    ----------
+    key : relay.Expr
+        key that uniquely determines the random values. Multiple uses with the
+        same generator will generate the same random values. This generator should be
+        treated as an opaque pointer. You can create one from calling
+        :py:func:`threefry_key`, :py:func:`threefry_split`, or
+        :py:func:`threefry_generate`. **Do not use this generator again after calling
+        this function.**
+
+    shape : Sequence[int]
+        Desired outputs shape of random numbers.
+
+    dtype : str
+        Desired outputs type of random numbers.
+
+    low : float or relay.Expr, optional
+        Lower bound of the uniform distribution.
+
+    high : float or relay.Expr, optional
+        Upper bound of the uniform distribution.
+
+    Returns

Review comment:
       One thing we need to decide here is the API for generating random values. The main question is how to handle the threefry key: do we want the user to split the key before every invocation, or do we want anything that uses the key to return the next one?
   
   i.e.
   
   ```
   key_1, key_2 = split(key)
   data = uniform(key_2, ...)
   my_other_random_function(key_1)
   ```
   
   vs
   
   ```
   data, key = uniform(key, ...)
   my_other_random_function(key)
   ```
   
   Personally I prefer the second. It also makes better use of the key space and avoids recomputing the key too often.
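
The two calling conventions above can be sketched in plain Python. This is a toy illustration only: `toy_split`, `toy_uniform01`, `uniform_split_style`, and `uniform_return_key_style` are hypothetical names, and SHA-256 stands in for threefry. None of this is TVM's API, and a real implementation could return the advanced generator state instead of splitting internally as the toy does.

```python
import hashlib


def _derive(key: bytes, tag: bytes) -> bytes:
    """Derive a new 32-byte value from a key and a tag (toy stand-in for threefry)."""
    return hashlib.sha256(key + tag).digest()


def toy_split(key: bytes):
    """Return two distinct child keys derived from the parent key."""
    return _derive(key, b"left"), _derive(key, b"right")


def toy_uniform01(key: bytes, n: int):
    """Generate n floats in [0, 1) that are fully determined by the key."""
    out = []
    for i in range(n):
        block = _derive(key, b"gen" + i.to_bytes(8, "little"))
        bits = int.from_bytes(block[:8], "little") >> 11  # keep 53 random bits
        out.append(bits / (1 << 53))
    return out


def uniform_split_style(key: bytes, n: int):
    """Style 1: the caller must split the key before every call."""
    return toy_uniform01(key, n)


def uniform_return_key_style(key: bytes, n: int):
    """Style 2: the function hands back the data and the key to use next."""
    next_key, gen_key = toy_split(key)
    return toy_uniform01(gen_key, n), next_key


root = b"\x00" * 32

# Style 1: explicit split at the call site.
key_1, key_2 = toy_split(root)
data_a = uniform_split_style(key_2, 4)

# Style 2: the key is threaded through each call.
data_b, key = uniform_return_key_style(root, 4)
data_c, key = uniform_return_key_style(key, 4)
```

In the second style the caller cannot accidentally reuse a key as long as it always rebinds `key` to the returned value.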

##########
File path: src/relay/op/random/kernel.cc
##########
@@ -85,5 +85,46 @@ RELAY_REGISTER_OP("random.threefry_split")
     .add_argument("key", "Tensor", "Input Threefry key")
     .add_type_rel("ThreefrySplit", ThreefrySplitRel);
 
+TVM_REGISTER_NODE_TYPE(UniformAttrs);
+
+bool UniformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  const UniformAttrs* param = attrs.as<UniformAttrs>();
+  ICHECK_EQ(types.size(), 4) << "ThreefryGenerate should have one input and one output";

Review comment:
       ```suggestion
     ICHECK_EQ(types.size(), 4) << "Uniform should have three inputs and one output";
   ```

##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]
+        Generator state. Can be created with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.
+
+    out_dtype : str
+        The output dtype.
+
+    Returns
+    -------
+    out : Tensor[out_shape, out_dtype]
+        Tensor of random numbers with shape `out_shape` and type `out_dtype`.
+    """
+    _, random_bits = threefry_generate(gen, out_shape)
+    nbits = 64
+    if out_dtype == "float32":
+        nfraction = 23
+    elif out_dtype == "float64":
+        nfraction = 52
+
+    def uniform_scalar(bits):
+        bits = bits >> (nbits - nfraction)
+        standard_uniform = bits.astype(out_dtype) / float(1 << nfraction)
+        return standard_uniform
+
+    standard_uniform_values = tvm.te.compute(out_shape, lambda *i: uniform_scalar(random_bits(*i)))
+
+    uniform_values = tvm.topi.add(tvm.topi.multiply(standard_uniform_values, high - low), low)

Review comment:
       How well does this approach work when we have a large range (high - low)? It seems like we would be losing a lot of potential randomness.
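
A rough way to quantify this concern, assuming the scheme in the diff above (`nfraction` random bits scaled by `high - low` and shifted by `low`): the output can take at most `2^nfraction` distinct values, so the gap between neighbouring outputs grows linearly with the range. The helper names below are hypothetical and exist only for this sketch.

```python
def num_distinct_values(nfraction: int) -> int:
    """Upper bound on distinct outputs when only nfraction random bits are used."""
    return 1 << nfraction


def grid_spacing(low: float, high: float, nfraction: int) -> float:
    """Distance between neighbouring outputs before floating-point rounding."""
    return (high - low) / (1 << nfraction)


# float32 uses 23 fraction bits in the diff above.
for low, high in [(0.0, 1.0), (0.0, 1.0e6)]:
    spacing = grid_spacing(low, high, 23)
    print(f"[{low}, {high}): {num_distinct_values(23)} values, spacing {spacing:.3e}")
```

For a range of one million, neighbouring float32 samples end up roughly 0.12 apart, even though float32 can represent much finer steps near the low end of that interval.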







[GitHub] [tvm] FrozenGene edited a comment on pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
FrozenGene edited a comment on pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#issuecomment-841919711


   > @FrozenGene Thank you for the clue! However, I haven't found how to restrict the dtype attributes in `Conv2DRel` or `Conv2DAttrs`... Should I add the type restriction to `UniformAttrs`, or raise an error in `MakeUniform` and `UniformRel`?
   
   Suggest `UniformRel`
   





[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632413056



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]
+        Generator state. Can be created with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.

Review comment:
       It's a property of the threefry key. Please refer to this comment: https://github.com/apache/tvm/pull/7083#discussion_r554177250







[GitHub] [tvm] zhuzilin edited a comment on pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin edited a comment on pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#issuecomment-841694188


   @FrozenGene @tkonolige @altanh Thank you for your reviews. I've updated this PR based on them.
   
   > I think there are 2 options: 1. Ref Conv2DRel. 2. You could restrict the type here and raise exception. I prefer option 1.
   
   @FrozenGene Thank you for the clue! However, I haven't found how to restrict the dtype attributes in `Conv2DRel` or `Conv2DAttrs`... Should I add the type restriction to `UniformAttrs`, or raise an error in `MakeUniform` and `UniformRel`?
   
   > How well does this approach work when we have a large range (high - low)? It seems like we would be losing a lot of potential randomness.
   
   @tkonolige As this approach only uses the fraction bits to represent the float, there will be a loss of randomness for all floats, at least `(2^nexp-1) / 2^nexp` of the floats (nexp stands for the number of exponent bits). However, it's a little tricky to use all 64 random bits to represent a uniformly distributed float... Do you have any ideas on that?
   
   > Perhaps this is the operation you're looking for? https://github.com/altanh/tvm/blob/2d9ac7710ab055d4f20e8b5a0a3580836723efac/python/tvm/topi/generic/algorithm.py#L465
   
   @altanh Thank you for your references! The `reinterpret` is exactly what I was looking for. I've updated the algorithm and right now it is the same as the one used in jax.





[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r634130654



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]
+        Generator state. Can be created with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.

Review comment:
       @altanh Sorry, I'm not familiar with the threefry algorithm. Is it possible to call `_threefry` twice in `threefry_generate`, in the following form? Something like:
   
   ```python
   out_array = irb.buffer_ptr(out_array_ptr)
   # deal with most of the array
   _threefry(irb, tmp, 0, tmp, 4, out_array, 0, out_len // 4)
   if out_len % 4 != 0:
       # generate remainders in a small tmp buffer
       tmp_array = irb.allocate(gen.dtype, 4, name="tmp", scope="global")
       # may need to update the tmp key in between
       # ...
       _threefry(irb, tmp, 0, tmp, 4, tmp_array, 0, out_len // 4)
       # only copy the tmp buffer
       for i in range(out_len // 4 * 4, out_len):
           out_array[i] = tmp_array[i%4]
   ```
   
   In this way, we could avoid copying the whole generated tensor.
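
The idea can be illustrated outside of TVM with a plain-Python sketch. `generate_block` is a hypothetical stand-in for one threefry invocation that yields four values per counter; the real `_threefry` signature and key handling are not modelled here.

```python
def generate_block(counter: int):
    """Hypothetical stand-in for one threefry call: four values per counter."""
    return [counter * 4 + j for j in range(4)]


def generate(out_len: int):
    """Fill out_len slots, using a small temporary block only for the remainder."""
    out = [0] * out_len
    full_blocks = out_len // 4

    # Bulk of the output: written directly, block by block.
    for b in range(full_blocks):
        out[b * 4:(b + 1) * 4] = generate_block(b)

    # Remainder: generate one extra block and copy only the values needed.
    if out_len % 4 != 0:
        tmp = generate_block(full_blocks)
        for i in range(full_blocks * 4, out_len):
            out[i] = tmp[i % 4]
    return out
```

Only up to three values are copied from the temporary block, so the cost of the copy no longer scales with the output size.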







[GitHub] [tvm] zhuzilin commented on pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#issuecomment-842404407


   > Suggest `UniformRel`
   
   @FrozenGene Thank you. I've added the type restriction.








[GitHub] [tvm] tkonolige commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
tkonolige commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r634781789



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]
+        Generator state. Can be created with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.

Review comment:
       Yeah, you could do that. Maybe submit it in a new PR?







[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632416258



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]

Review comment:
       This is the `ThreefryKeyType` introduced in #7083. Please refer to: https://github.com/apache/tvm/blob/c999a840cb5579c493f5b5e7f20bc619260dad08/src/relay/op/random/kernel.cc#L28







[GitHub] [tvm] zhuzilin commented on pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#issuecomment-841694188


   @FrozenGene @tkonolige @altanh Thank you for your reviews. I've updated this PR based on them.
   
   > I think there are 2 options: 1. Ref Conv2DRel. 2. You could restrict the type here and raise exception. I prefer option 1.
   
   @FrozenGene Thank you for the clue! However, I haven't found how to restrict the dtype attributes in `Conv2DRel` or `Conv2DAttrs`... Should I add the type restriction to `UniformAttrs`, or raise an error in `MakeUniform` and `UniformRel`?
   
   > How well does this approach work when we have a large range (high - low)? It seems like we would be losing a lot of potential randomness.
   
   @tkonolige As this approach only uses the fraction bits to represent the float, there will be a loss of randomness for all floats, at least `(2^nexp-1) / 2^nexp` of the floats (nexp stands for the number of exponent bits). However, it's a little tricky to use all 64 random bits to represent a uniformly distributed float... Do you have any ideas on that?
   
   > Perhaps this is the operation you're looking for? https://github.com/altanh/tvm/blob/2d9ac7710ab055d4f20e8b5a0a3580836723efac/python/tvm/topi/generic/algorithm.py#L465
   
   @altanh Thank you for your references! The `reinterpret` is exactly what I was looking for. I've tried to turn the algorithm into the following form:
   
   ```python
       fraction = tvm.topi.right_shift(random_bits, tvm.tir.const(nbits - nfraction, dtype="uint64"))
       # 127 is the exponent bias for float32; a biased exponent of 127 gives a scale of 2^0 = 1.
       exponent = tvm.topi.left_shift(tvm.topi.full(out_shape, "uint64", 127), tvm.tir.const(nfraction, dtype="uint64"))
       mantissa = tvm.topi.bitwise_or(fraction, exponent)
       standard_uniform_values = tvm.topi.reinterpret(mantissa, out_dtype) - 1
       uniform_values = tvm.topi.add(tvm.topi.multiply(standard_uniform_values, high - low), low)
   ```
   
   But this produces an LLVM bitcast error on the second line from the bottom (the `- 1`):
   ```
   TVMError: LLVM module verification failed with the following errors: 
   Invalid bitcast
     %39 = bitcast <2 x i64> %37 to <2 x float>
   Invalid bitcast
     %40 = bitcast <2 x i64> %38 to <2 x float>
   Invalid bitcast
     %52 = bitcast i64 %51 to float
   ```
    Do you have any clue why this kind of error happens?
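
One likely cause, offered as an assumption rather than a confirmed diagnosis: a bit-level reinterpret needs source and destination types of the same width, and the error lines show 64-bit integers (`i64`) being bitcast to 32-bit `float`. The plain-Python sketch below uses only the standard `struct` module (it is not TVM code, and the function names are hypothetical) to show the mantissa trick with matching widths: 32-bit patterns for float32 and 64-bit patterns for float64.

```python
import struct


def bits_to_unit_float32(bits32: int) -> float:
    """Map 32 random bits to a float32 value in [0, 1) via the mantissa trick."""
    fraction = (bits32 & 0xFFFFFFFF) >> (32 - 23)  # keep 23 fraction bits
    one_bits = 0x3F800000  # bit pattern of float32 1.0 (biased exponent 127)
    # Reinterpret a 32-bit integer as a 32-bit float: the widths match.
    (value,) = struct.unpack("<f", struct.pack("<I", fraction | one_bits))
    return value - 1.0  # shift [1, 2) down to [0, 1)


def bits_to_unit_float64(bits64: int) -> float:
    """Map 64 random bits to a float64 value in [0, 1) via the mantissa trick."""
    fraction = (bits64 & 0xFFFFFFFFFFFFFFFF) >> (64 - 52)  # keep 52 fraction bits
    one_bits = 0x3FF0000000000000  # bit pattern of float64 1.0 (biased exponent 1023)
    # Reinterpret a 64-bit integer as a 64-bit float: the widths match.
    (value,) = struct.unpack("<d", struct.pack("<Q", fraction | one_bits))
    return value - 1.0
```

Under this reading, the float32 path would need its bit pattern narrowed to a 32-bit integer type before the reinterpret, while the float64 path can reinterpret the 64-bit pattern directly.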





[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632978952



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]

Review comment:
       Fixed.







[GitHub] [tvm] zhuzilin commented on a change in pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on a change in pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#discussion_r632901257



##########
File path: python/tvm/topi/random/kernel.py
##########
@@ -466,3 +466,54 @@ def gen_ir(out_ptr):
     out_ary = tvm.nd.array(np.ones((1,), "uint64"), device)
     tvm.build(s, [f], target=target)(out_ary)
     return out_ary.asnumpy()[0] == 0
+
+
+def uniform(gen, low, high, out_shape, out_dtype):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval [low, high)
+    (includes low, but excludes high). In other words, any value within the
+    given interval is equally likely to be drawn by uniform.
+
+    Parameters
+    ----------
+    gen : Tensor[10, uint64]
+        Generator state. Can be created with :py:func:`tvm.relay.threefry_key`. This should not be
+        reused in another function, otherwise random numbers will be repeated.
+
+    low : Tensor[(), out_dtype]
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low.
+
+    high : Tensor[(), out_dtype]
+        Upper boundary of the output interval. All values generated will be
+        less than high.
+
+    out_shape : Sequence[int]
+        Output shape of the random numbers. Product of all dimensions must be a multiple of 4.
+
+    out_dtype : str
+        The output dtype.
+
+    Returns
+    -------
+    out : Tensor[out_shape, out_dtype]
+        Tensor of random numbers with shape `out_shape` and type `out_dtype`.
+    """
+    _, random_bits = threefry_generate(gen, out_shape)
+    nbits = 64
+    if out_dtype == "float32":
+        nfraction = 23
+    elif out_dtype == "float64":
+        nfraction = 52
+

Review comment:
       Thanks a lot!







[GitHub] [tvm] zhuzilin commented on pull request #8041: [Relay][PRNG] Add uniform distribution generator wrt threefry PRNG

Posted by GitBox <gi...@apache.org>.
zhuzilin commented on pull request #8041:
URL: https://github.com/apache/tvm/pull/8041#issuecomment-842966025


   @altanh @tkonolige I've updated the PR based on the reviews. Could you take another look? Thank you~

