Posted to commits@tvm.apache.org by ma...@apache.org on 2021/03/03 05:46:10 UTC

[tvm] branch main updated: [torch] Add linear operator support (#7569)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 67bba90  [torch] Add linear operator support (#7569)
67bba90 is described below

commit 67bba9032577025419dc0e110fdf4b08c5f66895
Author: Alexander Pivovarov <pi...@amazon.com>
AuthorDate: Tue Mar 2 21:45:51 2021 -0800

    [torch] Add linear operator support (#7569)
---
 python/tvm/relay/frontend/pytorch.py          | 15 ++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 34 +++++++++++++++++++++++++++
 2 files changed, 49 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 3c61749..dcf2f08 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1374,6 +1374,20 @@ class PyTorchOpConverter:
             count_include_pad=count_include_pad,
         )
 
+    def linear(self, inputs, input_types):
+        # https://pytorch.org/docs/stable/nn.functional.html#linear
+        # inputs[0] - input tensor
+        # inputs[1] - weight
+        bias = inputs[2]
+        mm_out = self.matmul(inputs[:2], input_types[:2])
+        if isinstance(bias, _expr.Expr):
+            bias_ndims = len(self.infer_shape_with_prelude(bias))
+            if bias_ndims == 1:
+                return _op.nn.bias_add(mm_out, bias)
+            mm_dtype = self.infer_type_with_prelude(mm_out).dtype
+            return self.add([mm_out, bias], [mm_dtype, input_types[2]])
+        return mm_out
+
     def dropout(self, inputs, input_types):
         data = inputs[0]
         rate = float(inputs[1])
@@ -2289,6 +2303,7 @@ class PyTorchOpConverter:
             "aten::softplus": self.softplus,
             "aten::avg_pool2d": self.avg_pool2d,
             "aten::avg_pool3d": self.avg_pool3d,
+            "aten::linear": self.linear,
             "aten::dropout": self.dropout,
             "aten::dropout_": self.dropout,
             "aten::feature_dropout": self.dropout,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 9f035ad..54bf2fd 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -24,6 +24,7 @@ import numpy as np
 import torch
 import torchvision
 from torch.nn import Module
+from torch.nn import functional as F
 import tvm
 from tvm import relay
 from tvm.contrib import graph_runtime
@@ -1460,6 +1461,39 @@ def test_forward_dense():
 
 
 @tvm.testing.uses_gpu
+def test_forward_linear():
+    torch.set_grad_enabled(False)
+
+    class Linear(Module):
+        def forward(self, input, weight, bias):
+            return F.linear(input, weight, bias)
+
+    class LinearNoBias(Module):
+        def forward(self, input, weight):
+            return F.linear(input, weight)
+
+    input2d = torch.rand([2, 2]).float()
+    weight1d = torch.rand([2]).float()
+    weight2d = torch.rand([2, 2]).float()
+    bias1d = torch.rand([2]).float()
+    bias2d = torch.rand([2, 2]).float()
+    # 2D input, 2D weight, 1D bias
+    verify_model(Linear(), input_data=[input2d, weight2d, bias1d])
+    # 2D input, 2D weight, 2D bias
+    verify_model(Linear(), input_data=[input2d, weight2d, bias2d])
+    # 2D input, 2D weight, no bias
+    verify_model(LinearNoBias(), input_data=[input2d, weight2d])
+    # 2D input, 1D weight, 1D bias is not supported by F.linear()
+    # 2D input, 1D weight, no bias
+    verify_model(LinearNoBias(), input_data=[input2d, weight1d])
+    # TODO: Add the following cases when matmul(1D, _) is supported by TVM
+    # 1D input, 2D weight, 1D bias
+    # 1D input, 2D weight, no bias
+    # 1D input, 1D weight, scalar bias
+    # 1D input, 1D weight, no bias
+
+
+@tvm.testing.uses_gpu
 def test_forward_dropout():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
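
To exercise the new aten::linear path by hand, here is a minimal sketch along
the lines of what verify_model does internally. It assumes a PyTorch version
whose tracer emits aten::linear for F.linear; the input names and shapes below
are illustrative:

    import torch
    import torch.nn.functional as F
    from tvm import relay

    class Linear(torch.nn.Module):
        def forward(self, inp, weight, bias):
            return F.linear(inp, weight, bias)

    inp, weight, bias = torch.rand(2, 2), torch.rand(2, 2), torch.rand(2)
    traced = torch.jit.trace(Linear(), (inp, weight, bias))

    # Map each traced graph input to a name and shape for the importer.
    shape_list = [("inp", (2, 2)), ("weight", (2, 2)), ("bias", (2,))]
    mod, params = relay.frontend.from_pytorch(traced, shape_list)
    print(mod["main"])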