You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/11/16 08:04:16 UTC

[GitHub] [tvm] echuraev commented on a diff in pull request #13402: [Relay][Frontend] Span filling common API

echuraev commented on code in PR #13402:
URL: https://github.com/apache/tvm/pull/13402#discussion_r1023597049


##########
python/tvm/relay/frontend/common.py:
##########
@@ -997,3 +1003,135 @@ def try_resolve_var_to_const(x, graph_params):
         return _op.const(value, dtype)
 
     return x
+
+
+class _SpanFiller(ExprMutator):
+    """SpanFiller"""
+
+    def __init__(self, span):
+        ExprMutator.__init__(self)
+        if isinstance(span, tvm.relay.Span):
+            self._span = span
+        elif isinstance(span, str):
+            self._span = tvm.relay.Span(tvm.relay.SourceName(span), 0, 0, 0, 0)
+        elif isinstance(span, bytes):
+            self._span = tvm.relay.Span(tvm.relay.SourceName(span.decode("utf-8")), 0, 0, 0, 0)
+        else:
+            assert False, f"unsupported span type: {type(span)}"
+
+    def visit(self, expr):
+        if hasattr(expr, "span") and expr.span:
+            return expr
+
+        return super().visit(expr)
+
+    def visit_function(self, fn):
+        new_params = [self.visit(x) for x in fn.params]
+        new_body = self.visit(fn.body)
+        return _function.Function(
+            list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs, self._span
+        )
+
+    def visit_let(self, let):
+        new_variable = self.visit(let.var)
+        new_value = self.visit(let.value)
+        new_body = self.visit(let.body)
+        return _expr.Let(new_variable, new_value, new_body, self._span)
+
+    def visit_call(self, call):
+        new_args = [self.visit(arg) for arg in call.args]
+        # call.op might be RelayExpr or Op type
+        # ExprMutator will return directly if subject belongs to Op type
+        new_op = self.visit(call.op)
+        return _expr.Call(new_op, new_args, call.attrs, call.type_args, self._span)
+
+    def visit_var(self, var):
+        return _expr.Var(var.name_hint, var.type_annotation, self._span)
+
+    def visit_if(self, ite):
+        return _expr.If(
+            self.visit(ite.cond),
+            self.visit(ite.true_branch),
+            self.visit(ite.false_branch),
+            self._span,
+        )
+
+    def visit_tuple(self, tup):
+        return _expr.Tuple([self.visit(field) for field in tup.fields], self._span)
+
+    def visit_tuple_getitem(self, op):
+        return _expr.TupleGetItem(self.visit(op.tuple_value), op.index, self._span)
+
+    def visit_constant(self, const):
+        return _expr.Constant(const.data, self._span)
+
+    # TODO: Frontend model translation could not use following relay expressions so far,
+    #       enable them when new models/impls leverage these kinds of relay expressions.
+    def visit_ref_create(self, _):
+        raise NotImplementedError()
+
+    def visit_ref_write(self, _):
+        raise NotImplementedError()
+
+    def visit_ref_read(self, _):
+        raise NotImplementedError()
+
+    def visit_match(self, _):
+        raise NotImplementedError()
+
+    def fill(self, sym):
+        """Fill span to sym when it is an expr, or return it without change
+
+        Parameters
+        ----------
+        sym :
+            A symbol which is generated from the conversion of a frontend operator.
+
+        Returns
+        -------
+        sym:
+            A expr with span-filled or the original sym.
+        """
+        if isinstance(sym, _expr.TupleWrapper):
+            return _expr.TupleWrapper(self.visit(sym.tuple_value), sym.size)
+        elif isinstance(sym, _expr.RelayExpr):
+            return self.visit(sym)
+        elif isinstance(sym, list):
+            assert all(
+                isinstance(expr, _expr.RelayExpr) for expr in sym
+            ), f"unexpected relay expressions in {sym}"
+            return [self.visit(expr) for expr in sym]
+        elif isinstance(sym, tuple):
+            # some op conversion may return dummy elements
+            # e.g. op in frontend/pytorch.py: min_max_common
+            assert all(
+                isinstance(expr, (_expr.RelayExpr, type(None))) for expr in sym
+            ), f"unexpected relay expressions in {sym}"
+            return tuple(self.visit(expr) if expr else None for expr in sym)
+        elif isinstance(sym, (float, int)):
+            return sym
+        elif isinstance(sym, np.ndarray):
+            return sym
+
+        raise RuntimeError(f"unsupported type {type(sym)}")
+
+
+def _should_fill_span():
+    should_fill_span = os.environ.get("TVM_SPANFILLING", "1")

Review Comment:
   Will we document somewhere how to use span filling, including the `TVM_SPANFILLING` environment variable?
   At a minimum, we should probably describe this variable in the docstring of the `set_span` function.



##########
python/tvm/relay/frontend/common.py:
##########
@@ -304,13 +306,17 @@ def __init__(self):
         self.const_ctr = 1
         self.in_padding = False
 
-    def new_const(self, value, shape=None, dtype="float32"):
+    def new_const(self, value, shape=None, dtype="float32", source_name=None):
+        """Construct a new var expr and add to exprs dictionary"""
         name = "_param_%d" % (self.const_ctr)
         if hasattr(value, "shape"):
             shape = value.shape
         self.const_ctr += 1
         self.params[name] = value
-        self.exprs[name] = _expr.var(name_hint=name, shape=shape, dtype=dtype)
+        tmp_var = _expr.var(name_hint=name, shape=shape, dtype=dtype)
+        if source_name:
+            tmp_var = set_span(tmp_var, source_name)
+        self.exprs[name] = tmp_var

Review Comment:
   ```suggestion
           self.exprs[name] = _expr.var(name_hint=name, shape=shape, dtype=dtype)
           if source_name:
               self.exprs[name] = set_span(self.exprs[name], source_name)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org