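You are viewing a plain-text version of this content. The canonical link to the original was a hyperlink that did not survive this plain-text rendering; the original message is available in the commits@tvm.apache.org mailing-list archive.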
Posted to commits@tvm.apache.org by mb...@apache.org on 2020/12/03 18:32:53 UTC

[tvm] branch main updated: [Diagnostics] Add environment variable for controlling top-level printing and fix issue with pretty printing/parsing roundtrip. (#6874)

This is an automated email from the ASF dual-hosted git repository.

mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8daa97e  [Diagnostics] Add environment variable for controlling top-level printing and fix issue with pretty printing/parsing roundtrip. (#6874)
8daa97e is described below

commit 8daa97ec87118ecdf38453ca878655cb08fba329
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Dec 3 10:32:37 2020 -0800

    [Diagnostics] Add environment variable for controlling top-level printing and fix issue with pretty printing/parsing roundtrip. (#6874)
    
    * Update Parser in order to handle the NMS code
    
    * Add support for displaying traces optionally
    
    * WIP
    
    * Fix
    
    * Fix error reporting in parser and clean up __init__.py due to CR
    
    * Format
    
    * Quick fix for If
    
    * Fix format
    
    * Fix lint
---
 python/tvm/__init__.py               | 21 +++++++--
 src/parser/parser.cc                 | 91 +++++++++++++++++++++++++-----------
 tests/python/relay/test_ir_parser.py | 14 ++++++
 3 files changed, 95 insertions(+), 31 deletions(-)

diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index 569e8f0..c2b4fdb 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -68,15 +68,28 @@ from . import support
 from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel
 
 
+def _should_print_backtrace():
+    in_pytest = "PYTEST_CURRENT_TEST" in os.environ
+    tvm_backtrace = os.environ.get("TVM_BACKTRACE", "0")
+
+    try:
+        tvm_backtrace = bool(int(tvm_backtrace))
+    except ValueError:
+        raise ValueError(
+            f"invalid value for TVM_BACKTRACE `{tvm_backtrace}`, please set to 0 or 1."
+        )
+
+    return in_pytest or tvm_backtrace
+
+
 def tvm_wrap_excepthook(exception_hook):
     """Wrap given excepthook with TVM additional work."""
 
     def wrapper(exctype, value, trbk):
         """Clean subprocesses when TVM is interrupted."""
-        in_pytest = "PYTEST_CURRENT_TEST" in os.environ
-
-        if exctype is error.DiagnosticError and not in_pytest:
-            pass
+        if exctype is error.DiagnosticError and not _should_print_backtrace():
+            # TODO(@jroesch): consider moving to C++?
+            print("note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.")
         else:
             exception_hook(exctype, value, trbk)
 
diff --git a/src/parser/parser.cc b/src/parser/parser.cc
index 987a6e2..afcf707 100644
--- a/src/parser/parser.cc
+++ b/src/parser/parser.cc
@@ -605,30 +605,43 @@ class Parser {
     return ast;
   }
 
+  struct MetaRef {
+    std::string type_key;
+    uint64_t node_index;
+    Span span;
+    MetaRef(std::string type_key, uint64_t node_index, Span span)
+        : type_key(type_key), node_index(node_index), span(span) {}
+  };
+
+  MetaRef MetaRefFromToken(const Token& tok) {
+    Call ref = Downcast<Call>(tok->data);
+    auto attrs = ref->attrs.as<MetaRefAttrs>();
+    auto type_key = attrs->node_type_key;
+    auto index = attrs->node_index;
+    return MetaRef(type_key, index, ref->span);
+  }
+
   /*! \brief Parse a meta reference of the form `meta[type_key][node_index]`.
    * For example `meta[relay.Constant][0]` references the first constant, `meta[relay.Constant][1]`
    * the second, and so on.
    */
   ObjectRef ParseMetaRef() {
-    auto meta_ref = Match(TokenType::kMetaReference);
-    Call ref = Downcast<Call>(meta_ref->data);
-    auto attrs = ref->attrs.as<MetaRefAttrs>();
-    auto type_key = attrs->node_type_key;
-    auto index = attrs->node_index;
-    auto it = this->meta_table.find(type_key);
+    auto meta_ref_tok = Match(TokenType::kMetaReference);
+    auto meta_ref = MetaRefFromToken(meta_ref_tok);
+    auto it = this->meta_table.find(meta_ref.type_key);
     if (it != this->meta_table.end()) {
       auto nodes = (*it).second;
-      if (index < nodes.size()) {
-        return nodes[index];
+      if (meta_ref.node_index < nodes.size()) {
+        return nodes[meta_ref.node_index];
       } else {
-        this->diag_ctx.Emit(Diagnostic::Error(meta_ref->span)
-                            << "the node index `" << index << "` is out of bounds for `" << type_key
-                            << "`");
+        this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span)
+                            << "the node index `" << meta_ref.node_index
+                            << "` is out of bounds for `" << meta_ref.type_key << "`");
         return ObjectRef();
       }
     } else {
-      this->diag_ctx.Emit(Diagnostic::Error(meta_ref->span)
-                          << "no entry in the meta table for `" << type_key << "`");
+      this->diag_ctx.Emit(Diagnostic::Error(meta_ref.span)
+                          << "no entry in the meta table for `" << meta_ref.type_key << "`");
       return ObjectRef();
     }
   }
@@ -922,10 +935,7 @@ class Parser {
             exprs.push_back(ParseMatch(is_total));
             break;
           }
-          case TokenType::kIf: {
-            exprs.push_back(ParseIf());
-            break;
-          }
+
           // %x ...
           case TokenType::kGraph:
             if (Lookahead(2)->token_type == TokenType::kEqual) {
@@ -1344,6 +1354,10 @@ class Parser {
             Match(TokenType::kIdentifier);
             return ObjectRef();
           }
+          if (id == "None") {
+            Match(TokenType::kIdentifier);
+            return Optional<ObjectRef>();
+          }
         }
       }
       default:
@@ -1372,7 +1386,7 @@ class Parser {
     ICHECK(op.defined()) << "the operator must be defined";
 
     DLOG(INFO) << "Parser::ParseCallArgs";
-    Map<String, ObjectRef> raw_attrs;
+    Attrs attrs;
     std::string op_key;
     bool is_op = false;
 
@@ -1388,21 +1402,40 @@ class Parser {
           [&] {
             auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier;
             auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual;
-
-            if (is_op && is_ident && next_is_equal) {
-              raw_attrs = ParseAttrs();
+            auto is_pretty_attrs = is_ident && next_is_equal;
+            auto is_meta_next = Lookahead(1)->token_type == TokenType::kMetaReference;
+            // TODO(@jroesch): might not handle trailing comma
+            auto last_meta = Lookahead(2)->token_type == TokenType::kCloseParen;
+            auto is_meta_attrs = is_meta_next && last_meta;
+
+            if (is_op && (is_pretty_attrs || is_meta_attrs)) {
+              if (is_meta_attrs) {
+                auto meta_ref = ParseMetaRef();
+                if (meta_ref.as<BaseAttrsNode>()) {
+                  attrs = Downcast<Attrs>(meta_ref);
+                } else {
+                  // Not awesome parsing code here.
+                  this->pos--;
+                  return false;
+                }
+              } else {
+                auto raw_attrs = ParseAttrs();
+                auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs);
+                ICHECK(attr_obj.defined());
+                attrs = Downcast<Attrs>(attr_obj);
+              }
               return true;
             }
 
             return false;
           });
 
-      Attrs attrs;
-
-      if (is_op && op_key.size()) {
-        auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs);
-        ICHECK(attr_obj.defined());
-        attrs = Downcast<Attrs>(attr_obj);
+      if (!attrs.defined()) {
+        if (is_op && op_key.size()) {
+          auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, {});
+          ICHECK(attr_obj.defined());
+          attrs = Downcast<Attrs>(attr_obj);
+        }
       }
 
       // TODO(@jroesch): in a secondary pass adjust spans.
@@ -1527,6 +1560,10 @@ class Parser {
           ICHECK(e->span.defined()) << "function spans must be defined.\n" << e;
           return e;
         }
+        case TokenType::kIf: {
+          Expr e = ParseIf();
+          return e;
+        }
         case TokenType::kRef: {
           Consume(TokenType::kRef);
           Match(TokenType::kOpenParen);
diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py
index c5217ba..1622717 100644
--- a/tests/python/relay/test_ir_parser.py
+++ b/tests/python/relay/test_ir_parser.py
@@ -875,6 +875,20 @@ def test_tuple_return_value():
     parse_module(program)
 
 
+def test_parse_if_in_binding():
+    program = """
+    def @example(%b: bool) {
+        %0 = if (%b) {
+            1
+        } else {
+            0
+        };
+        %0
+    }
+    """
+    parse_module(program)
+
+
 def test_op_string_attr():
     call = parse_text(
         """