You are viewing a plain-text version of this content; the hyperlink to its canonical version was lost in the plain-text conversion.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/06 00:12:03 UTC
[tvm] branch main updated: add aten::randn (#11994)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5bc6684c9d add aten::randn (#11994)
5bc6684c9d is described below
commit 5bc6684c9d5664b8f3aeac8f8503f894ab2bfee5
Author: Yuanjing Shi <yu...@octoml.ai>
AuthorDate: Tue Jul 5 17:11:57 2022 -0700
add aten::randn (#11994)
---
python/tvm/relay/frontend/pytorch.py | 9 +++++++++
tests/python/frontend/pytorch/test_forward.py | 12 ++++++++++++
2 files changed, 21 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index cb5392fa16..b1a7608860 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2504,6 +2504,14 @@ class PyTorchOpConverter:
dtype = input_types[0]
return _op.zeros(shape, dtype)
+ def randn(self, inputs, input_types):
+ import time # use current time as seed
+
+ shape = inputs[0]
+ output = _op.random.normal(_op.random.threefry_key(int(time.time())), shape)
+ _, values = _expr.TupleWrapper(output, 2)
+ return values
+
def bincount(self, inputs, input_types):
data = inputs[0]
weights = inputs[1]
@@ -3415,6 +3423,7 @@ class PyTorchOpConverter:
"aten::numel": self.numel,
"aten::empty": self.empty,
"aten::empty_like": self.empty_like,
+ "aten::randn": self.randn,
"aten::bincount": self.bincount,
"aten::scatter_add": self.scatter_add,
"aten::__not__": self.logical_not,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 80a5cd07f7..30ba713396 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3895,6 +3895,18 @@ def test_empty_like():
verify_model_with_input(test_func, [torch.rand([1, 3, 10, 10]).float()], assert_shape_only=True)
+def test_randn():
+ def test_func():
+ return torch.randn([1, 3, 10, 10])
+
+ verify_model_with_input(test_func, [], assert_shape_only=True)
+
+ def test_func1():
+ return torch.randn(1, 3, 10, 10)
+
+ verify_model_with_input(test_func1, [], assert_shape_only=True)
+
+
def test_forward_pretrained_bert_base_uncased():
######################################################################
# This is an example how to run BERT models using TVM