You are viewing a plain-text version of this content; the canonical link to it ("here" in the original page) did not survive the plain-text conversion.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/06 00:12:03 UTC

[tvm] branch main updated: add aten::randn (#11994)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5bc6684c9d add aten::randn (#11994)
5bc6684c9d is described below

commit 5bc6684c9d5664b8f3aeac8f8503f894ab2bfee5
Author: Yuanjing Shi <yu...@octoml.ai>
AuthorDate: Tue Jul 5 17:11:57 2022 -0700

    add aten::randn (#11994)
---
 python/tvm/relay/frontend/pytorch.py          |  9 +++++++++
 tests/python/frontend/pytorch/test_forward.py | 12 ++++++++++++
 2 files changed, 21 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index cb5392fa16..b1a7608860 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2504,6 +2504,14 @@ class PyTorchOpConverter:
             dtype = input_types[0]
         return _op.zeros(shape, dtype)
 
+    def randn(self, inputs, input_types):
+        """aten::randn -> relay random.normal of shape inputs[0]; other args (e.g. dtype) unused."""
+        import random  # fresh seed per op: int(time.time()) repeats within the same second
+        shape = inputs[0]
+        key = _op.random.threefry_key(random.getrandbits(63))
+        _, values = _expr.TupleWrapper(_op.random.normal(key, shape), 2)
+        return values
+
     def bincount(self, inputs, input_types):
         data = inputs[0]
         weights = inputs[1]
@@ -3415,6 +3423,7 @@ class PyTorchOpConverter:
             "aten::numel": self.numel,
             "aten::empty": self.empty,
             "aten::empty_like": self.empty_like,
+            "aten::randn": self.randn,
             "aten::bincount": self.bincount,
             "aten::scatter_add": self.scatter_add,
             "aten::__not__": self.logical_not,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 80a5cd07f7..30ba713396 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3895,6 +3895,18 @@ def test_empty_like():
     verify_model_with_input(test_func, [torch.rand([1, 3, 10, 10]).float()], assert_shape_only=True)
 
 
+def test_randn():
+    def test_func():
+        return torch.randn([1, 3, 10, 10])
+
+    def test_func1():
+        return torch.randn(1, 3, 10, 10)
+
+    # Outputs are random, so the harness is asked to compare output shapes only.
+    for func in (test_func, test_func1):
+        verify_model_with_input(func, [], assert_shape_only=True)
+
+
 def test_forward_pretrained_bert_base_uncased():
     ######################################################################
     # This is an example how to run BERT models using TVM