Posted to commits@tvm.apache.org by zh...@apache.org on 2020/03/10 17:59:22 UTC
[incubator-tvm] branch master updated: [Torch] Add initial control flow support (#4964)
This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 06e9542 [Torch] Add initial control flow support (#4964)
06e9542 is described below
commit 06e9542ee0bfd014bd06a4dd4fdb3af9d2d29eb0
Author: masahi <ma...@gmail.com>
AuthorDate: Wed Mar 11 02:59:09 2020 +0900
[Torch] Add initial control flow support (#4964)
* Add support for prim::If and prim::Loop with test cases
* rebase and fix tests
* add some comments
* simplifying, fix float cast
* parse -> convert
* recursively retrieve ops in get_all_op_names
* use multiple return values from block correctly, simplify loop convert
* choose dtype properly for zeros and ones
* simplifying, replace convert_inputs with _get_relay_input_vars
* fix for while loop with non-input-dependent init cond
* add assert on loop var update
* move the condition around
* better testing for seg models
* rebase fix, disable inception v3 in quant test as it is too slow to load with torch-1.4 + torchvision 0.5
* simplify and add more comparison op converter
---
python/tvm/relay/frontend/pytorch.py | 223 ++++++++++++++++++++++----
tests/python/frontend/pytorch/qnn_test.py | 3 +-
tests/python/frontend/pytorch/test_forward.py | 197 ++++++++++++++++++++++-
3 files changed, 385 insertions(+), 38 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index ff37f82..6da91c1 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -20,6 +20,7 @@
"""PT: PyTorch frontend."""
import itertools
import logging
+import sys
import numpy as np
@@ -29,6 +30,7 @@ from tvm.ir import module as _module
from .. import analysis as _analysis
from .. import expr as _expr
from .. import op as _op
+from ..loops import while_loop
from .common import get_relay_op
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
@@ -107,9 +109,8 @@ def _select():
def _impl(inputs, input_types):
data = inputs[0]
dim = int(inputs[1])
- index = int(inputs[2])
-
- return _op.transform.take(data, _expr.const(index, dtype="int32"), axis=dim)
+ index = _wrap_const(inputs[2])
+ return _op.transform.take(data, index, axis=dim)
return _impl
def _ones():
@@ -126,7 +127,10 @@ def _ones():
else:
assert "data type {} could not be parsed in ones op" % (type(data))
- return _op.full(_expr.const(1), shape, dtype=_convert_data_type(input_types[0]))
+ dtype_map = {6: "float32", 3: "int32"}
+ dtype_id = inputs[1]
+ assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id
+ return _op.full(_expr.const(1), shape, dtype=dtype_map[dtype_id])
return _impl
def _zeros():
@@ -143,7 +147,10 @@ def _zeros():
else:
assert "data type {} could not be parsed in zeros op" % (type(data))
- return _op.full(_expr.const(0), shape, dtype=_convert_data_type(input_types[0]))
+ dtype_map = {6: "float32", 3: "int32"}
+ dtype_id = inputs[1]
+ assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id
+ return _op.full(_expr.const(0), shape, dtype=dtype_map[dtype_id])
return _impl
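
The integer ids in dtype_map above come from PyTorch's ScalarType enum, where 3 is int32 and 6 is float32. A minimal sketch (assuming a torch-1.4-era build) of how a scripted graph carries that id:

    import torch

    def make_ones(n: int) -> torch.Tensor:
        # dtype=torch.float32 is recorded in the TorchScript graph
        # as an integer constant with value 6 (ScalarType::Float)
        return torch.ones(n, dtype=torch.float32)

    graph = torch.jit.script(make_ones).graph
    print(graph)  # look for prim::Constant[value=6]() feeding aten::ones
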
def _relu():
@@ -222,12 +229,10 @@ def _convolution():
else:
assert "data type {} could not be parsed in conv op" % (type(weight))
- # TODO: Add reshape when channel multiplier > 1. Pending PR #4644
channels = weight_shape[0]
groups = int(inputs[8])
if groups > 1:
- # in torch, groups == in_channels for depth wise conv
channel_multiplier = channels // groups
new_weight_shape = (groups, channel_multiplier, weight_shape[2], weight_shape[3])
weight = _op.transform.reshape(weight, new_weight_shape)
@@ -496,7 +501,7 @@ def _dropout():
return _impl
def _reduce(name):
- def _impl(inputs, attrs, params):
+ def _impl(inputs, input_types):
data = inputs[0]
return get_relay_op(name)(data)
return _impl
@@ -714,7 +719,6 @@ def _upsample(method):
return _impl
-
def _expand_as():
def _impl(inputs, input_types):
# TODO: maybe fix this
@@ -724,6 +728,29 @@ def _expand_as():
return inputs[0]
return _impl
+def _neg():
+ def _impl(inputs, input_types):
+ data = inputs[0]
+ return _op.tensor.negative(data)
+ return _impl
+
+def _tanh():
+ def _impl(inputs, input_types):
+ data = inputs[0]
+ return _op.tensor.tanh(data)
+ return _impl
+
+def _Bool():
+ def _impl(inputs, input_types):
+ assert len(inputs) == 1
+ return inputs[0]
+ return _impl
+
+def _Float():
+ def _impl(inputs, input_types):
+ assert len(inputs) == 1
+ return _op.cast(inputs[0], "float32")
+ return _impl
# Helper functions for operator implementation
@@ -780,6 +807,11 @@ def _convert_elemwise_input(data, input_type):
else:
return data
+def _wrap_const(c):
+ if not isinstance(c, _expr.Expr) and not isinstance(c, list):
+ return _expr.const(c)
+ return c
+
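
_wrap_const lets converters accept either Relay expressions or bare Python scalars. A standalone sketch of the same idea using the public relay API (the frontend's _expr alias maps to tvm.relay):

    from tvm import relay

    def wrap_const(c):
        # Lift plain Python numbers to relay.const; pass Relay
        # expressions and Python lists through unchanged.
        if not isinstance(c, relay.Expr) and not isinstance(c, list):
            return relay.const(c)
        return c

    assert isinstance(wrap_const(3), relay.Constant)
    assert isinstance(wrap_const(relay.var("x")), relay.Var)
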
# Operator mappings
_convert_map = {
@@ -845,7 +877,16 @@ _convert_map = {
"aten::detach" : _identity(),
"aten::upsample_bilinear2d" : _upsample("bilinear"),
"aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
- "aten::expand_as" : _expand_as()
+ "aten::expand_as" : _expand_as(),
+ "aten::lt" : _elemwise("less"),
+ "aten::gt" : _elemwise("greater"),
+ "aten::le" : _elemwise("less_equal"),
+ "aten::ge" : _elemwise("greater_equal"),
+ "aten::ne" : _elemwise("not_equal"),
+ "aten::Bool" : _Bool(),
+ "aten::Float" : _Float(),
+ "aten::neg" : _neg(),
+ "aten::tanh" : _tanh(),
}
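
The new comparison entries cover the aten kinds that TorchScript emits for Python comparison operators. A quick way to see where they come from (sketch, assuming torch >= 1.4):

    import torch

    @torch.jit.script
    def less_than(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return a < b  # lowers to an aten::lt node in the graph

    print(less_than.graph)  # the aten::lt node is what _elemwise("less") handles
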
@@ -894,7 +935,8 @@ def _report_missing_conversion(op_names):
""" Check if all ops in an input graph are supported by TVM """
known_ops = ["prim::Constant", "prim::GetAttr",
"prim::ListConstruct", "prim::ListUnpack",
- "prim::TupleConstruct", "prim::TupleUnpack"]
+ "prim::TupleConstruct", "prim::TupleUnpack",
+ "prim::If", "prim::Loop"]
known_ops += list(_convert_map.keys())
known_ops += list(qnn_torch.convert_map.keys())
@@ -939,9 +981,13 @@ def _get_input_types(op_node):
input_node_kind = in_ty.kind()
if input_node_kind == 'TensorType':
if in_ty.scalarType() is None:
- input_list_types.append(None)
+ # Tensor's type can be unknown if we use torch.jit.script(...)
+ # Defaults to float for now
+ logging.warning("Untyped Tensor found, assume it is float")
+ input_list_types.append("float")
else:
input_list_types.append(in_ty.scalarType().lower())
+
elif input_node_kind == 'ListType':
input_list_types.append(str(in_ty.getElementType()).lower())
elif input_node_kind in ['IntType', 'FloatType', 'BoolType',
@@ -1004,15 +1050,10 @@ def _get_operator_nodes(nodes):
return ops
-def parse_inputs(graph_inputs, input_shapes):
- """ Return Relay vars from torch input vars """
- ir_inputs = list(graph_inputs)
- input_vars = {}
-
- for input_name, ir_input in zip(input_shapes, ir_inputs[1:]):
- input_vars[input_name] = _expr.var(input_name,
- shape=input_shapes[input_name])
- return input_vars
+def _get_relay_input_vars(input_shapes):
+ """ Return Relay vars from input shapes """
+ return {iname: _expr.var(iname, shape=ishape)
+ for iname, ishape in input_shapes.items()}
def get_use_chains(root_node, terminate=lambda _: False):
@@ -1055,7 +1096,7 @@ def get_attr_chains(root_getattr_node):
return get_use_chains(root_getattr_node, terminate)
-def parse_params(graph, state_dict):
+def convert_params(graph, state_dict):
"""
Return Relay vars and TVM NDArrays for input parameters
A chain of prim::GetAttr nodes is processed one at a time
@@ -1090,7 +1131,109 @@ def parse_params(graph, state_dict):
return params, param_tensors, packed_param_map
-def parse_operators(operators, outputs, output_index_map, ret_name):
+def convert_block(block, outputs, output_index_map):
+ """ Translate Torch "Block", used for prim::If and prim::Loop """
+ ops = _get_operator_nodes(block.nodes())
+ ret_names = _get_input_names(block.returnNode())
+ return convert_operators(ops, outputs, output_index_map, ret_names)
+
+
+def convert_if(if_node, outputs, output_index_map):
+ """ Translate Torch prim::If to Relay If """
+ cond = outputs[output_index_map[if_node.inputsAt(0).debugName()]]
+ blocks = list(if_node.blocks())
+ true_branch = convert_block(blocks[0], outputs, output_index_map)
+ false_branch = convert_block(blocks[1], outputs, output_index_map)
+ assert len(true_branch) == 1 and len(false_branch) == 1
+ return _expr.If(cond, true_branch[0], false_branch[0])
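
convert_if reduces each Torch block to a single Relay expression and wraps the pair in relay.If, which is a value-producing expression rather than a statement. A minimal standalone sketch of the target construct:

    from tvm import relay

    x = relay.var("x", shape=(), dtype="float32")
    cond = relay.greater(x, relay.const(0.0))           # boolean scalar expr
    out = relay.If(cond,
                   relay.add(x, relay.const(1.0)),      # true-branch value
                   relay.subtract(x, relay.const(1.0))) # false-branch value
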
+
+
+def convert_loop(loop_node, outputs, output_index_map):
+ """ Translate Torch prim::Loop to Relay while_loop """
+ def get_input(index):
+ ivalue = loop_node.inputsAt(index)
+ inode = ivalue.node()
+ if inode.kind() == "prim::Constant":
+ return _expr.const(_get_constant(inode))
+ var_name = ivalue.debugName()
+ assert var_name in output_index_map
+ return _wrap_const(outputs[output_index_map[var_name]])
+
+ # Refer to the spec for prim::Loop below
+ # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
+ # The first input: %max_trip_count
+ # The second input: %initial_condition
+    # The rest of the inputs: loop variables
+ max_loop_count = get_input(0)
+ init_cond = get_input(1)
+ num_loop_var = len(list(loop_node.inputs())) - 2
+ init_vals = [get_input(i + 2) for i in range(num_loop_var)]
+
+    # A while loop always has max_loop_count set to int64 max.
+    # max_loop_count.data (tvm.runtime.NDArray) is -1, so call _get_constant again
+ is_while_loop = (isinstance(max_loop_count, _expr.Constant) and
+ _get_constant(loop_node.inputsAt(0).node()) == sys.maxsize)
+
+ body_block = list(loop_node.blocks())[0]
+ block_input_names = _get_input_names(body_block)
+
+ def cond(*current_vals):
+ i = current_vals[0]
+
+ if is_while_loop:
+ return _op.equal(i, _expr.const(True, 'bool'))
+
+ return _op.less(i, max_loop_count)
+
+ def body(*current_vals):
+ # Update loop variables using the prev iteration outputs
+ assert len(current_vals) == len(block_input_names)
+ for (i, iname) in enumerate(block_input_names):
+ outputs[output_index_map[iname]] = current_vals[i]
+
+ block_outputs = convert_block(body_block, outputs, output_index_map)
+
+ if not is_while_loop:
+            # The iter var increment is implicit in torch, so do it manually.
+            # For a while loop, block_outputs[0] is already a boolean:
+            # the result of the termination check.
+ incr = _expr.const(1, dtype="int32")
+ block_outputs[0] = current_vals[0] + incr
+
+ return block_outputs
+
+ def get_var(name, val):
+ if isinstance(val, _expr.Constant):
+ return _expr.var(name, shape=val.data.shape, dtype=val.data.dtype)
+ return _expr.var(name)
+
+ if is_while_loop:
+ loop_iter_dtype = "bool"
+        # A while loop with an input-independent condition, such as `while i < 10`,
+        # has an int init_cond; cast it to bool so the loop type checks
+ if isinstance(init_cond, _expr.Constant):
+ init_cond = _op.cast(init_cond, "bool")
+ init_loop_iter_val = init_cond
+ else:
+ loop_iter_dtype = "int32"
+ # always count from 0
+ init_loop_iter_val = _expr.const(0, dtype="int32")
+
+ name_val_pairs = list(zip(block_input_names,
+ [init_loop_iter_val] + init_vals))
+ _update_outputs_from_pairs(name_val_pairs, outputs, output_index_map)
+
+ loop_iter_var = _expr.var(block_input_names[0], shape=(),
+ dtype=loop_iter_dtype)
+ loop_vars = [get_var(name, val) for name, val in name_val_pairs[1:]]
+ loop = while_loop(cond, [loop_iter_var] + loop_vars, body)
+ loop_val = loop(init_loop_iter_val, *init_vals)
+
+ # The first element is a loop counter or boolean condition, ignore it
+ return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]
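
To make the mapping concrete, here is a standalone sketch of the while_loop pattern convert_loop emits: a loop counter plus one loop variable, with the final values read back via TupleGetItem and the counter at index 0 dropped, as above:

    from tvm import relay
    from tvm.relay.loops import while_loop

    i = relay.var("i", shape=(), dtype="int32")
    acc = relay.var("acc", shape=(), dtype="int32")

    def cond(i, acc):
        # equivalent of a for-loop termination check
        return relay.less(i, relay.const(10, "int32"))

    def body(i, acc):
        # increment the counter manually, as in the non-while branch above
        return [relay.add(i, relay.const(1, "int32")), relay.add(acc, i)]

    loop = while_loop(cond, [i, acc], body)
    result = loop(relay.const(0, "int32"), relay.const(0, "int32"))
    acc_final = relay.TupleGetItem(result, 1)  # skip the counter at index 0
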
+
+
+def convert_operators(operators, outputs, output_index_map, ret_names):
""" Convert each Torch IR operators to Relay equivalent """
for node_name, op_node in operators:
operator = op_node.kind()
@@ -1110,17 +1253,35 @@ def parse_operators(operators, outputs, output_index_map, ret_name):
unpacked_names = _get_output_names(op_node)
_update_outputs_from_pairs(zip(unpacked_names, inputs[0]),
outputs, output_index_map)
+ elif operator == "prim::If":
+ if_out = convert_if(op_node, outputs, output_index_map)
+ output_index_map[node_name] = len(outputs)
+ outputs.append(if_out)
+ elif operator == "prim::Loop":
+ loop_out = convert_loop(op_node, outputs, output_index_map)
+ unpacked_names = _get_output_names(op_node)
+ assert len(loop_out) == len(unpacked_names)
+ _update_outputs_from_pairs(zip(unpacked_names, loop_out),
+ outputs, output_index_map)
else:
output_index_map[node_name] = len(outputs)
relay_op = _convert_map[operator]
outputs.append(relay_op(inputs, _get_input_types(op_node)))
- return outputs[output_index_map[ret_name]]
+ return [_wrap_const(outputs[output_index_map[ret_name]])
+ for ret_name in ret_names]
def get_all_op_names(graph):
""" Return all operator names in the input graph """
- return set(node.kind() for node in graph.nodes())
+ nodes = list(graph.nodes())
+ prim_with_blocks = ["prim::If", "prim::Loop"]
+ for prim in prim_with_blocks:
+ prim_nodes = graph.findAllNodes(prim, recurse=True)
+ for prim_node in prim_nodes:
+ for block in prim_node.blocks():
+ nodes += block.nodes()
+ return set(node.kind() for node in nodes)
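
Since prim::If and prim::Loop keep their bodies in nested blocks, a flat scan of graph.nodes() misses any op used only inside a branch or loop body; findAllNodes(..., recurse=True) plus block traversal recovers them. A small illustration (sketch):

    import torch

    @torch.jit.script
    def f(x: torch.Tensor) -> torch.Tensor:
        if bool(x.sum() > 0):
            out = torch.tanh(x)
        else:
            out = -x
        return out

    top_level = set(n.kind() for n in f.graph.nodes())
    nested = set(n.kind()
                 for if_node in f.graph.findAllNodes("prim::If", recurse=True)
                 for block in if_node.blocks()
                 for n in block.nodes())
    print(nested - top_level)  # aten::tanh / aten::neg live only in the blocks
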
def get_graph_input_names(script_module):
@@ -1167,14 +1328,14 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
_check_input_names(script_module, input_shapes)
params = script_module.state_dict()
- input_vars = parse_inputs(graph.inputs(), input_shapes)
- param_vars, tensors, packed_param_map = parse_params(graph, params)
+ input_vars = _get_relay_input_vars(input_shapes)
+ param_vars, tensors, packed_param_map = convert_params(graph, params)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}
input_vars.update(param_vars)
outputs = list(input_vars.values())
output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
- ret_name = _get_input_names(graph.return_node())[0]
+ ret_name = _get_input_names(graph.return_node())
# For quantized models
if "aten::quantize_per_tensor" in op_names:
@@ -1186,8 +1347,8 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
qnn_torch.add_quant_params(tvm_params, weight_quant_params)
_convert_map.update(qnn_torch.convert_map)
- body = parse_operators(_get_operator_nodes(graph.nodes()), outputs,
- output_index_map, ret_name)
- func = tvm.relay.Function(_analysis.free_vars(body), body)
+ ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs,
+ output_index_map, ret_name)
+ func = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
return _module.IRModule.from_expr(func), tvm_params
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
index e3a876c..23fcb7c 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -347,7 +347,8 @@ def test_quantized_imagenet():
qmodels += [
("resnet18", qresnet.resnet18(pretrained=True), per_channel),
("mobilenet_v2", qmobilenet.mobilenet_v2(pretrained=True), per_channel),
- ("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
+ # disable inception test for now, since loading it takes ~5min on torchvision-0.5
+ #("inception_v3", qinception.inception_v3(pretrained=True), per_channel),
("googlenet", qgooglenet(pretrained=True), per_channel),
]
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index eed47ea..59f93b4 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -756,7 +756,6 @@ def test_vgg11_bn():
verify_model("vgg11_bn")
"""
-
def test_custom_conversion_map():
def get_roi_align():
pool_size = 5
@@ -801,11 +800,193 @@ def test_segmentaton_models():
inp = [torch.rand((1, 3, 300, 300), dtype=torch.float)]
- for model in [fcn, deeplab]:
-    # depthwise + dilated convolution not supported on x86
- # see https://github.com/apache/incubator-tvm/issues/4962
- verify_model(SegmentationModelWrapper(model.eval()), inp,
- ctx_list=[("cuda", tvm.gpu(0))])
+ verify_model(SegmentationModelWrapper(fcn.eval()), inp)
+
+    # depthwise + dilated convolution not supported on x86
+ # see https://github.com/apache/incubator-tvm/issues/4962
+ cuda_ctx = ("cuda", tvm.gpu(0))
+ if cuda_ctx[1].exist:
+ verify_model(SegmentationModelWrapper(deeplab.eval()), inp, [cuda_ctx])
+
+
+def verify_script_model(pt_model, ishapes):
+ script_module = torch.jit.script(pt_model)
+ input_names = get_graph_input_names(script_module)
+ input_shapes = dict(zip(input_names, ishapes))
+
+ inputs = [torch.randn(input_shapes[input_name], dtype=torch.float)
+ for input_name in input_names]
+
+ mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
+
+ executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0),
+ target="llvm")
+ evaluator = executor.evaluate()
+
+ for name, inp in zip(input_names, inputs):
+ params[name] = inp.numpy()
+
+ op_res = evaluator(**params)
+
+ with torch.no_grad():
+ pt_result = pt_model(*inputs)
+
+ if not isinstance(pt_result, torch.Tensor):
+ tvm_res = op_res.asnumpy().item()
+ assert pt_result == tvm_res
+ else:
+ tvm.testing.assert_allclose(op_res.asnumpy(), pt_result.numpy(),
+ rtol=1e-5, atol=1e-5)
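
A note on the executor choice above: converted control flow becomes Relay If nodes and the recursive functions that while_loop produces, which the graph runtime cannot execute, so the helper targets the Relay VM. The one line that matters (mirroring the helper's own names):

    # "vm" can execute Relay If/recursion; the "graph" runtime cannot
    executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0),
                                     target="llvm")
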
+
+
+def test_control_flow():
+ class SimpleIf(torch.nn.Module):
+ def __init__(self, N, M):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.rand(N, M))
+
+ def forward(self, inp):
+ if inp.sum() > 0.:
+ output = self.weight + inp
+ else:
+ output = self.weight - inp
+ return output
+
+ class NestedIf(torch.nn.Module):
+ def __init__(self, N, M):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.rand(N, M))
+
+ def forward(self, inp):
+ if inp.sum() > 0.:
+ if inp.mean() > 0.:
+ output = self.weight + inp
+ else:
+ output = self.weight - inp
+ else:
+ if inp.mean() >= 0.:
+ output = self.weight * inp
+ else:
+ output = self.weight / inp
+
+ return output
+
+ class ScalarLoop(torch.nn.Module):
+ def forward(self, inp):
+ a = 0
+ for i in range(inp.size(0)):
+ b = i * i
+ b = b + 1
+ a += b
+ if a != 0:
+ a += 1
+ else:
+ a += 2
+ return a
+
+ class SimpleLoop(torch.nn.Module):
+ def forward(self, inp):
+ a = inp
+ for i in range(inp.size(0)):
+ b = a * 2.
+ c = a + b
+ a += c
+ return a
+
+ class LoopWithIf(torch.nn.Module):
+ def forward(self, inp):
+ a = inp
+ for i in range(inp.size(0)):
+ b = a * 2.
+ b = a + b
+ if b.sum() > 0.0:
+ a += b
+ else:
+ a -= b
+ return a
+
+ class NestedLoop(torch.nn.Module):
+ def forward(self, inp):
+ a = inp
+ for i in range(inp.size(0)):
+ b = a * float(i)
+ for j in range(inp.size(1)):
+ a += b * float(j)
+ return a
+
+ class SimpleScalarWhileLoop(torch.nn.Module):
+ def forward(self, inp):
+ a = 1
+ i = 0
+ while i <= inp.size(0):
+ a += i
+ i += 2
+ i = 0
+ # also test constant init cond
+ while i < 10:
+ a += i
+ i += 3
+ return a
+
+ class SimpleWhileLoop(torch.nn.Module):
+ def forward(self, inp):
+ a = inp
+ i = 0
+ while i < inp.size(0):
+ a += a * float(i) * 2.0
+ i += 1
+ return a
+
+ models = [
+ SimpleIf(10, 20),
+ NestedIf(10, 20),
+ ScalarLoop(),
+ SimpleLoop(),
+ LoopWithIf(),
+ SimpleScalarWhileLoop(),
+ SimpleWhileLoop(),
+ NestedLoop(),
+ ]
+
+ for pt_model in models:
+ verify_script_model(pt_model.eval(), [(10, 20)])
+
+
+def test_simple_rnn():
+ # The mixed tracing and scripting example from
+ # https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html#mixing-scripting-and-tracing
+ class DecisionGate(torch.nn.Module):
+ def forward(self, x):
+ if x.sum() > 0:
+ return x
+ else:
+ return -x
+
+ class Cell(torch.nn.Module):
+ def __init__(self, dg):
+ super(Cell, self).__init__()
+ self.dg = dg
+ self.linear = torch.nn.Linear(4, 4)
+
+ def forward(self, x, h):
+ new_h = torch.tanh(self.dg(self.linear(x)) + h)
+ return new_h, new_h
+
+ class RNNLoop(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ x = torch.rand(10, 4, dtype=torch.float)
+ h = torch.rand(10, 4, dtype=torch.float)
+ self.cell = torch.jit.trace(Cell(DecisionGate()), (x, h))
+
+ def forward(self, xs):
+ h = torch.zeros(10, 4, dtype=torch.float)
+ y = torch.zeros(10, 4, dtype=torch.float)
+ for i in range(xs.size(0)):
+ y, h = self.cell(xs[i], h)
+ return y
+
+ verify_script_model(RNNLoop().eval(), [(10, 10, 4)])
if __name__ == "__main__":
@@ -860,3 +1041,7 @@ if __name__ == "__main__":
test_quantized_modules()
test_quantized_imagenet()
+
+ # Test simple conditionals and loop
+ test_control_flow()
+ test_simple_rnn()