You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/12/24 03:04:16 UTC
[tvm] branch main updated: [TE] Support varargs in te.compute (#9796)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 066b417 [TE] Support varargs in te.compute (#9796)
066b417 is described below
commit 066b4170d5193e40e105fd43c7df2a6ec00ad951
Author: Tristan Konolige <tk...@octoml.ai>
AuthorDate: Thu Dec 23 19:03:32 2021 -0800
[TE] Support varargs in te.compute (#9796)
* [TE] Support varargs in te.compute
Support varargs (`lambda x, *args: ...`) in te.compute. The varargs take
indices into the remaining dimensions of the output's shape. This
requires using inspect.getfullargspec instead of `fcompute.__code__`.
Also add checks that there are no keyword arguments.
* implicitly broadcast to remaining dimensions
---
python/tvm/te/operation.py | 29 ++++++++++++++++++++---------
python/tvm/te/tensor.py | 4 +++-
2 files changed, 23 insertions(+), 10 deletions(-)
diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index 5cb58a8..dafbbfd 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -18,6 +18,7 @@
# pylint: disable=invalid-name
from numbers import Integral as _Integral
from typing import List, Union
+import inspect
import tvm._ffi
from tvm._ffi.base import string_types
@@ -89,18 +90,28 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape
# for python3
shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
- ndim = len(shape)
- code = fcompute.__code__
-
- out_ndim = ndim
- if code.co_argcount == 0:
- arg_names = ["i%d" % i for i in range(ndim)]
+ out_ndim = len(shape)
+
+ argspec = inspect.getfullargspec(fcompute)
+ if len(argspec.args) == 0 and argspec.varargs is None:
+ arg_names = ["i%d" % i for i in range(out_ndim)]
+ elif argspec.varargs is not None:
+ # if there is a varargs, it takes the remaining dimensions of out_ndim
+ arg_names = argspec.args + [f"i{i}" for i in range(out_ndim - len(argspec.args))]
else:
- arg_names = code.co_varnames[: code.co_argcount]
- out_ndim = code.co_argcount
+ arg_names = argspec.args
+ # if there are fewer args than out dimensions, the remaining dimensions
+ # are implicitly broadcast
+ out_ndim = len(arg_names)
+ assert argspec.varkw is None, "Variable keyword arguments not supported in fcompute"
+ assert argspec.defaults is None, "Default arguments not supported in fcompute"
+ assert len(argspec.kwonlyargs) == 0, "Keyword arguments are not supported in fcompute"
if out_ndim != len(arg_names):
- raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
+ raise ValueError(
+ "Number of args to fcompute does not match dimension, "
+ "args=%d, dimension=%d" % (len(arg_names), out_ndim)
+ )
dim_var = [tvm.tir.IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
body = fcompute(*[v.var for v in dim_var])
diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py
index bdf3954..fc85d83 100644
--- a/python/tvm/te/tensor.py
+++ b/python/tvm/te/tensor.py
@@ -60,7 +60,9 @@ class Tensor(DataProducer, _expr.ExprOp):
def __call__(self, *indices):
ndim = self.ndim
if len(indices) != ndim:
- raise ValueError("Need to provide %d index in tensor slice" % ndim)
+ raise ValueError(
+ "Need to provide %d index in tensor but %d was provided" % (ndim, len(indices))
+ )
indices = convert_to_object(indices)
args = []
for x in indices: