You are viewing a plain-text version of this content. (The link to the canonical version was lost in the plain-text conversion.)
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/01/27 19:30:27 UTC
[tvm] branch main updated: [Torch] Various updates for PyTorch frontend (#7348)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 59e0a4a [Torch] Various updates for PyTorch frontend (#7348)
59e0a4a is described below
commit 59e0a4a46461b1a90bc24660cf25e08cfcfb7a1f
Author: masahi <ma...@gmail.com>
AuthorDate: Thu Jan 28 04:30:08 2021 +0900
[Torch] Various updates for PyTorch frontend (#7348)
* add conversion for detr
* remove explicit broadcast_to before batched matmul
* use take with wrap mode
* add test for transformer and negative indices
* add sort and argsort
* add logical_and
* support masked_select
* add gpu targets to masked_select test
* improve sort conversion
---
python/tvm/relay/frontend/pytorch.py | 63 ++++++++++++----
tests/python/frontend/pytorch/test_forward.py | 101 +++++++++++++++++++++++++-
2 files changed, 150 insertions(+), 14 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 991e3a8..68e68fd 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -399,10 +399,7 @@ class PyTorchOpConverter:
begin = [0] * ndim
dim = int(inputs[1])
stride = int(inputs[4])
- if isinstance(inputs[2], _expr.Call):
- begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int)))
- else:
- begin[dim] = int(inputs[2])
+ begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int)))
# Process begin
if not isinstance(begin[dim], int):
@@ -518,13 +515,13 @@ class PyTorchOpConverter:
data = inputs[0]
dim = int(inputs[1])
index = _wrap_const(inputs[2])
- return _op.transform.take(data, index, axis=dim)
+ return _op.transform.take(data, index, axis=dim, mode="wrap")
def take(self, inputs, input_types):
data = inputs[0]
indices = _op.cast(inputs[1], "int32")
- return _op.transform.take(data, indices=indices)
+ return _op.transform.take(data, indices=indices, mode="wrap")
def topk(self, inputs, input_types):
data = inputs[0]
@@ -551,7 +548,13 @@ class PyTorchOpConverter:
def repeat(self, inputs, input_types):
data = inputs[0]
- reps = inputs[1]
+ reps = []
+ for r in inputs[1]:
+ if isinstance(r, int):
+ reps.append(r)
+ else:
+ reps.append(int(_infer_value(r, {}).asnumpy()))
+
return _op.transform.tile(data, reps=reps)
def repeat_interleave(self, inputs, input_types):
@@ -1520,12 +1523,6 @@ class PyTorchOpConverter:
# Convert a and b into 3 dimensional tensors.
a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]])
- # Broadcast b to match batch size of a
- new_b_shape = list(self.infer_shape_with_prelude(b))
- new_a_shape = self.infer_shape_with_prelude(a)
- if new_a_shape[0] > new_b_shape[0]:
- new_b_shape[0] = new_a_shape[0]
- b = _op.broadcast_to(b, new_b_shape)
# Transpose matrix dimensions of b.
b = _op.transpose(b, [0, 2, 1])
# Perform a batch matmul.
@@ -2070,6 +2067,40 @@ class PyTorchOpConverter:
src = inputs[3]
return _op.scatter_add(data, index, src, axis=axis)
+ def cumsum(self, inputs, input_types):
+ data = inputs[0]
+ dim = inputs[1]
+ dtype = inputs[2]
+
+ if inputs[2] is not None:
+ dtype = _convert_dtype_value(inputs[2])
+
+ return _op.cumsum(data, axis=dim, dtype=dtype)
+
+ def masked_fill(self, inputs, input_types):
+ mask = inputs[1]
+ value = _op.cast(_wrap_const(inputs[2]), input_types[0])
+ return _op.where(mask, value, inputs[0])
+
+ def masked_select(self, inputs, input_types):
+ mask = inputs[1]
+ indices = self.nonzero([mask], input_types, is_numpy_style=True)
+ return _op.adv_index([inputs[0]] + [indices[i] for i in range(indices.size)])
+
+ def sort(self, inputs, input_types):
+ data = inputs[0]
+ dim = inputs[1]
+ is_descending = inputs[2]
+ # pytorch sort returns both sorted indices and values
+ indices = _op.argsort(data, dim, not is_descending)
+ return _op.gather(data, dim, indices), indices
+
+ def argsort(self, inputs, input_types):
+ data = inputs[0]
+ dim = inputs[1]
+ is_descending = inputs[2]
+ return _op.argsort(data, dim, not is_descending)
+
def is_floating_point(self, inputs, input_types):
assert len(inputs) == 1
@@ -2263,6 +2294,7 @@ class PyTorchOpConverter:
"torchvision::roi_align": self.roi_align,
"aten::unbind": self.unbind,
"aten::__and__": self.logical_and,
+ "aten::logical_and": self.logical_and,
"aten::_shape_as_tensor": self.shape_as_tensor,
"aten::nonzero": self.nonzero,
"aten::nonzero_numpy": self.nonzero_numpy,
@@ -2278,6 +2310,11 @@ class PyTorchOpConverter:
"aten::__not__": self.logical_not,
"aten::hardswish_": self.hard_swish,
"aten::hardswish": self.hard_swish,
+ "aten::cumsum": self.cumsum,
+ "aten::masked_fill": self.masked_fill,
+ "aten::masked_select": self.masked_select,
+ "aten::argsort": self.argsort,
+ "aten::sort": self.sort,
}
def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 7cdd450..6d9b559 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1147,7 +1147,7 @@ def test_forward_view():
@tvm.testing.uses_gpu
def test_forward_select():
torch.set_grad_enabled(False)
- input_shape = [1, 3, 10, 10]
+ input_shape = [5, 3, 10, 10]
class Select1(Module):
def forward(self, *args):
@@ -1167,6 +1167,9 @@ def test_forward_select():
input_data = torch.rand(input_shape).float()
verify_model(Select1().float().eval(), input_data=input_data)
+ # test negative indexing
+ verify_model(lambda x: x[-1], input_data=input_data)
+
x = torch.randn(3, 4)
indices = torch.tensor([0, 2])
verify_model(IndexedSelect(x, 0).eval(), input_data=indices)
@@ -2653,6 +2656,8 @@ def test_forward_take():
verify_model(Take1().float().eval(), input_data=input_data)
indices = torch.tensor([[0, 0], [1, 0]])
verify_model(Take2().float().eval(), input_data=[input_data, indices])
+ indices = torch.tensor([0, -1])
+ verify_model(Take2().float().eval(), input_data=[input_data, indices])
@tvm.testing.uses_gpu
@@ -3452,6 +3457,93 @@ def test_hard_swish():
verify_model(torch.nn.Hardswish(inplace=True).eval(), input_data=input)
+def test_cumsum():
+ def test_fn(dim, dtype=None):
+ return lambda x: torch.cumsum(x, dim=dim, dtype=dtype)
+
+ inp = torch.randint(0, 100, (10000,), dtype=torch.int32)
+ verify_model(test_fn(0), [inp])
+ verify_model(test_fn(0), [inp.to(torch.int64)])
+ verify_model(test_fn(0, dtype=torch.int64), [inp.to(torch.int64)])
+
+ inp = torch.randn((100, 100), dtype=torch.float32)
+ verify_model(test_fn(dim=0, dtype=torch.float64), [inp])
+ verify_model(test_fn(dim=1), [inp])
+
+ inp = torch.randn((100, 100), dtype=torch.float32) > 0.5
+ verify_model(test_fn(dim=0, dtype=torch.int32), [inp])
+
+
+def test_masked_fill():
+ def test_fn(x, mask):
+ return torch.masked_fill(x, mask, 0.0)
+
+ inp = torch.randn(100, 100)
+ verify_model(test_fn, [inp, inp > 0.5])
+ verify_model(test_fn, [inp.to(torch.float64), inp > 0.5])
+
+
+def test_transformer():
+ model = torch.nn.Transformer(d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6)
+ model = model.eval()
+ src = torch.rand((10, 32, 256))
+ tgt = torch.rand((20, 32, 256))
+ verify_model(model.eval(), input_data=[src, tgt])
+
+
+def test_argsort():
+ def test_fn(dim, descending):
+ return lambda x: torch.argsort(x, dim=dim, descending=descending)
+
+ inp = torch.randn(100)
+ verify_model(test_fn(0, True), [inp])
+ verify_model(test_fn(0, False), [inp])
+
+ inp = torch.randn(100, 100)
+ verify_model(test_fn(0, True), [inp])
+ verify_model(test_fn(0, False), [inp])
+ verify_model(test_fn(1, True), [inp])
+ verify_model(test_fn(1, False), [inp])
+
+
+def test_sort():
+ def test_fn(dim, descending):
+ return lambda x: torch.sort(x, dim=dim, descending=descending)
+
+ inp = torch.randn(100)
+ verify_model(test_fn(0, True), [inp])
+ verify_model(test_fn(0, False), [inp])
+
+ inp = torch.randn(100, 100)
+ verify_model(test_fn(0, True), [inp])
+ verify_model(test_fn(0, False), [inp])
+ verify_model(test_fn(1, True), [inp])
+ verify_model(test_fn(1, False), [inp])
+
+
+def test_logical_and():
+ def test_fn(x, y):
+ return torch.logical_and(x, y)
+
+ a = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
+ b = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
+ verify_model(test_fn, [a, b])
+
+ a = torch.tensor([True, False, True])
+ b = torch.tensor([True, False, False])
+ verify_model(test_fn, [a, b])
+
+
+def test_masked_select():
+ def test_fn(x, mask):
+ return torch.masked_select(x, mask)
+
+ for shape in [(10,), (3, 4), (16, 32, 64)]:
+ x = torch.randn(*shape)
+ mask = x.ge(0.5)
+ verify_trace_model(test_fn, [x, mask], ["llvm", "cuda", "nvptx"])
+
+
if __name__ == "__main__":
# some structural tests
test_forward_traced_function()
@@ -3580,6 +3672,13 @@ if __name__ == "__main__":
test_forward_scatter()
test_numel()
test_bincount()
+ test_cumsum()
+ test_masked_fill()
+ test_transformer()
+ test_sort()
+ test_argsort()
+ test_logical_and()
+ test_masked_select()
# Model tests
test_resnet18()