You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/12/24 03:04:16 UTC

[tvm] branch main updated: [TE] Support varargs in te.compute (#9796)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 066b417  [TE] Support varargs in te.compute (#9796)
066b417 is described below

commit 066b4170d5193e40e105fd43c7df2a6ec00ad951
Author: Tristan Konolige <tk...@octoml.ai>
AuthorDate: Thu Dec 23 19:03:32 2021 -0800

    [TE] Support varargs in te.compute (#9796)
    
    * [TE] Support varargs in te.compute
    
    Support varargs (`lambda x, *args: ...`) in te.compute. The varargs take
    indices into the remaining dimensions of the output's shape. This
    requires using inspect.getfullargspec instead of `fcompute.__code__`.
    
    Also add checks that fcompute takes no variable keyword arguments,
    keyword-only arguments, or default argument values.
    
    * implicitly broadcast to remaining dimensions
---
 python/tvm/te/operation.py | 29 ++++++++++++++++++++---------
 python/tvm/te/tensor.py    |  4 +++-
 2 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py
index 5cb58a8..dafbbfd 100644
--- a/python/tvm/te/operation.py
+++ b/python/tvm/te/operation.py
@@ -18,6 +18,7 @@
 # pylint: disable=invalid-name
 from numbers import Integral as _Integral
 from typing import List, Union
+import inspect
 
 import tvm._ffi
 from tvm._ffi.base import string_types
@@ -89,18 +90,28 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
     shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape
     # for python3
     shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
-    ndim = len(shape)
-    code = fcompute.__code__
-
-    out_ndim = ndim
-    if code.co_argcount == 0:
-        arg_names = ["i%d" % i for i in range(ndim)]
+    out_ndim = len(shape)
+
+    argspec = inspect.getfullargspec(fcompute)
+    if len(argspec.args) == 0 and argspec.varargs is None:
+        arg_names = ["i%d" % i for i in range(out_ndim)]
+    elif argspec.varargs is not None:
+        # if there is a varargs, it takes the remaining dimensions of out_ndim
+        arg_names = argspec.args + [f"i{i}" for i in range(out_ndim - len(argspec.args))]
     else:
-        arg_names = code.co_varnames[: code.co_argcount]
-        out_ndim = code.co_argcount
+        arg_names = argspec.args
+        # if there are fewer args than out dimensions, the remaining dimensions
+        # are implicitly broadcast
+        out_ndim = len(arg_names)
+    assert argspec.varkw is None, "Variable keyword arguments not supported in fcompute"
+    assert argspec.defaults is None, "Default arguments not supported in fcompute"
+    assert len(argspec.kwonlyargs) == 0, "Keyword arguments are not supported in fcompute"
 
     if out_ndim != len(arg_names):
-        raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
+        raise ValueError(
+            "Number of args to fcompute does not match dimension, "
+            "args=%d, dimension=%d" % (len(arg_names), out_ndim)
+        )
 
     dim_var = [tvm.tir.IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
     body = fcompute(*[v.var for v in dim_var])
diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py
index bdf3954..fc85d83 100644
--- a/python/tvm/te/tensor.py
+++ b/python/tvm/te/tensor.py
@@ -60,7 +60,9 @@ class Tensor(DataProducer, _expr.ExprOp):
     def __call__(self, *indices):
         ndim = self.ndim
         if len(indices) != ndim:
-            raise ValueError("Need to provide %d index in tensor slice" % ndim)
+            raise ValueError(
+                "Need to provide %d index in tensor but %d was provided" % (ndim, len(indices))
+            )
         indices = convert_to_object(indices)
         args = []
         for x in indices: