You are viewing a plain-text version of this content. The canonical link for it was a hyperlink in the original page and is not preserved in this plain-text version.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/09/14 04:10:35 UTC

[tvm] branch torchbench created (now bacf3946c7)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a change to branch torchbench
in repository https://gitbox.apache.org/repos/asf/tvm.git


      at bacf3946c7 support constant folding on ndarray_size

This branch includes the following new commits:

     new 292f55b59b add copy_ and embedding_bag
     new 5aa6d7c253 fix rebase
     new bacf3946c7 support constant folding on ndarray_size

The 3 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[tvm] 02/03: fix rebase

Posted by ma...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch torchbench
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 5aa6d7c25360ee5b339c42f8c9f5c655943a4333
Author: YJ Shi <yu...@octoml.ai>
AuthorDate: Tue Sep 13 16:15:09 2022 -0700

    fix rebase
---
 python/tvm/relay/frontend/pytorch.py          | 12 ++++--------
 tests/python/frontend/pytorch/test_forward.py |  7 +++----
 2 files changed, 7 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 9255c42383..e2badaabf7 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -39,11 +39,7 @@ from ..loops import while_loop
 from ..prelude import Prelude, StaticTensorArrayOps
 from ..ty import Any, TensorType, TupleType
 from . import qnn_torch
-<<<<<<< HEAD
-from .common import AttrCvt, get_relay_op, gru_cell, logger, rnn_cell
-=======
-from .common import AttrCvt, fold_constant, get_relay_op, gru_cell, infer_shape, logger
->>>>>>> dfcf28b5d... add copy_ and embedding_bag
+from .common import AttrCvt, fold_constant, get_relay_op, gru_cell, logger
 from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
 from .common import infer_value_simulated as _infer_value_simulated
@@ -3415,7 +3411,7 @@ class PyTorchOpConverter:
         output = _op.random.multinomial(key, probs, num_samples)
         _, indices = _expr.TupleWrapper(output, 2)
         return indices
-    
+
     def embedding_bag(self, inputs, _):
         assert len(inputs) == 9, "embedding_bag needs 9 arguments"
         (
@@ -3433,10 +3429,10 @@ class PyTorchOpConverter:
         assert scale_grad_by_freq == 0, "scale_grad_by_freq not supported in embedding_bag."
         assert padding_idx == None, "padding_idx not supported in embedding_bag."
 
-        assert len(infer_shape(indices)) == 1, "Expects 1D indices for aten::embedding_bag."
+        assert len(_infer_shape(indices)) == 1, "Expects 1D indices for aten::embedding_bag."
 
         offsets_const_fold = fold_constant(offsets_1d)
-
+        print(offsets_const_fold)
         assert isinstance(
             offsets_const_fold, _expr.Constant
         ), "Only constant offsets are supported."
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index f9ff4a212c..58a4dfbe94 100755
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4608,7 +4608,6 @@ def test_mod():
         verify_model(test_fn, [torch.tensor([1, 2, 3, 4, 5]), torch.tensor(-1.5)])
 
 
-<<<<<<< HEAD
 def test_softmax_fuse():
     # https://github.com/apache/tvm/issues/12001
     class Model(torch.nn.Module):
@@ -4686,15 +4685,15 @@ def test_multinomial():
         _test_multinomial(1),
         [torch.rand(size=[4, 5]).float()],
         cpu_only=True,
-        check_correctness=False,
-=======
+    )
+
+
 def test_embedding_bag():
     embedding_matrix = torch.rand(10, 3)
     inp = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9], [6, 7, 8, 9]])
     verify_model(
         F.embedding_bag,
         [inp, embedding_matrix],
->>>>>>> dfcf28b5d... add copy_ and embedding_bag
     )
 
 


[tvm] 01/03: add copy_ and embedding_bag

Posted by ma...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch torchbench
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 292f55b59b7e82381cb339bfb6f0885b866f097d
Author: YJ Shi <yu...@octoml.ai>
AuthorDate: Wed Jul 6 00:52:34 2022 -0700

    add copy_ and embedding_bag
---
 python/tvm/relay/frontend/pytorch.py          | 65 +++++++++++++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py |  9 ++++
 2 files changed, 74 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 0e6d4caae0..9255c42383 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -39,7 +39,11 @@ from ..loops import while_loop
 from ..prelude import Prelude, StaticTensorArrayOps
 from ..ty import Any, TensorType, TupleType
 from . import qnn_torch
+<<<<<<< HEAD
 from .common import AttrCvt, get_relay_op, gru_cell, logger, rnn_cell
+=======
+from .common import AttrCvt, fold_constant, get_relay_op, gru_cell, infer_shape, logger
+>>>>>>> dfcf28b5d... add copy_ and embedding_bag
 from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
 from .common import infer_value_simulated as _infer_value_simulated
@@ -811,6 +815,10 @@ class PyTorchOpConverter:
         fill_value = inputs[1]
         return self.full_impl(self.infer_shape(data), fill_value, input_types[0])
 
+    def copy_(self, inputs, input_types):
+        src = inputs[1]
+        return _op.tensor.copy(src)
+
     def linspace(self, inputs, input_types):
         start = inputs[0]
         stop = inputs[1]
@@ -3407,6 +3415,61 @@ class PyTorchOpConverter:
         output = _op.random.multinomial(key, probs, num_samples)
         _, indices = _expr.TupleWrapper(output, 2)
         return indices
+    
+    def embedding_bag(self, inputs, _):
+        assert len(inputs) == 9, "embedding_bag needs 9 arguments"
+        (
+            weights,
+            indices,
+            offsets_1d,
+            scale_grad_by_freq,
+            mode,
+            sparse,
+            per_sample_weights,
+            include_last_offset,
+            padding_idx,
+        ) = inputs
+
+        assert scale_grad_by_freq == 0, "scale_grad_by_freq not supported in embedding_bag."
+        assert padding_idx == None, "padding_idx not supported in embedding_bag."
+
+        assert len(infer_shape(indices)) == 1, "Expects 1D indices for aten::embedding_bag."
+
+        offsets_const_fold = fold_constant(offsets_1d)
+
+        assert isinstance(
+            offsets_const_fold, _expr.Constant
+        ), "Only constant offsets are supported."
+
+        offsets_np = offsets_const_fold.data.numpy()
+        if include_last_offset == 1:
+            offsets_np = offsets_np[..., 0]  # exclude last dimension
+        offsets_diff = np.diff(offsets_np)
+
+        assert np.all(offsets_diff[1:] == offsets_diff[0]), "Only 2D cases supported for now."
+
+        indices_2d = _op.reshape(indices, (-1, offsets_diff[0]))
+
+        mode_map = {0: _op.sum, 1: _op.mean, 2: _op.max}
+        assert mode in mode_map, "unsupported reduction op mode %d." % mode
+
+        reduce_op = mode_map[mode]
+
+        # TOOD(masahi): Implementing embedding_bag in terms of gather and reduce defeats the
+        # purpose of using this op. Implement Relay / topi op for fused gather and reduce.
+        gather = _op.take(weights, indices_2d, axis=0)
+        if per_sample_weights is not None:
+            if mode != 0:
+                raise NotImplementedError(
+                    "Only mode 'sum' is supported when per_sample_weights is passed."
+                )
+            gather = gather * per_sample_weights
+        reduced = reduce_op(gather, 1)
+        # pytorch/aten/src/ATen/native/EmbeddingBag.cpp shows that aten::embedding_bag returns
+        # 4 outputs: output, offset2bag, bag_size, max_indices
+        # The Python version of the op only returns the first output, so we also support only the
+        # first output. If the model uses other outputs, the conversion would fail.
+        return reduced, None, None, None
 
     # Operator mappings
     def create_convert_map(self):
@@ -3444,6 +3507,7 @@ class PyTorchOpConverter:
             "aten::full_like": self.full_like,
             "aten::new_full": self.new_full,
             "aten::fill_": self.fill_,
+            "aten::copy_": self.copy_,
             "aten::linspace": self.linspace,
             "aten::reciprocal": self.reciprocal,
             "aten::repeat": self.repeat,
@@ -3670,6 +3734,7 @@ class PyTorchOpConverter:
             "aten::__lshift__": self.make_elemwise("left_shift"),
             "aten::__rshift__": self.make_elemwise("right_shift"),
             "aten::multinomial": self.multinomial,
+            "aten::embedding_bag": self.embedding_bag,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 4c78ba4b85..f9ff4a212c 100755
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4608,6 +4608,7 @@ def test_mod():
         verify_model(test_fn, [torch.tensor([1, 2, 3, 4, 5]), torch.tensor(-1.5)])
 
 
+<<<<<<< HEAD
 def test_softmax_fuse():
     # https://github.com/apache/tvm/issues/12001
     class Model(torch.nn.Module):
@@ -4686,6 +4687,14 @@ def test_multinomial():
         [torch.rand(size=[4, 5]).float()],
         cpu_only=True,
         check_correctness=False,
+=======
+def test_embedding_bag():
+    embedding_matrix = torch.rand(10, 3)
+    inp = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9], [6, 7, 8, 9]])
+    verify_model(
+        F.embedding_bag,
+        [inp, embedding_matrix],
+>>>>>>> dfcf28b5d... add copy_ and embedding_bag
     )
 
 


[tvm] 03/03: support constant folding on ndarray_size

Posted by ma...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch torchbench
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit bacf3946c727682e7aad82f03e34abbbd9f120a2
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Wed Sep 14 13:09:45 2022 +0900

    support constant folding on ndarray_size
---
 python/tvm/relay/frontend/pytorch.py  |  2 +-
 src/relay/transforms/fold_constant.cc | 10 ++++++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index e2badaabf7..722b2889d3 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2489,7 +2489,7 @@ class PyTorchOpConverter:
         )
 
     def numel(self, inputs, input_types):
-        return _op.ndarray_size(inputs[0])
+        return fold_constant(_op.ndarray_size(inputs[0]))
 
     def empty(self, inputs, input_types):
         shape = inputs[0]
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 9dec840be0..f484dfc700 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -188,8 +188,7 @@ class ConstantFolder : public MixedModeMutator {
     if (is_no_computational && (is_no_qnn_canonicalized || !fold_qnn_)) {
       return std::move(post_call);
     }
-    if (op == device_copy_op_ || op == shape_of_op_ || op == vm_shape_of_op_ ||
-        op == ndarray_size_op_) {
+    if (op == device_copy_op_ || op == shape_of_op_ || op == vm_shape_of_op_) {
       // We should think about potentially constant evaluation over these ops too.
       return std::move(post_call);
     }
@@ -383,6 +382,13 @@ class ConstantFolder : public MixedModeMutator {
       // TODO(mbs): This is not necessary since we only ever ask for the shapes for
       // pre-rewritten expressions which will always have a checked_type.
       return const_node->tensor_type()->shape;
+      //    } else if (auto ttype = input->type_as<TensorTypeNode>()) {
+    }  else if (const auto* var = input.as<VarNode>()) {
+      auto ty = var->type_annotation;
+      if (ty->IsInstance<TensorTypeNode>()) {
+        return Downcast<TensorType>(ty)->shape;
+      }
+      return {};
     } else if (input->checked_type_.defined()) {
       return input->checked_type().as<TensorTypeNode>()->shape;
     } else {