You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/11/04 11:00:03 UTC

[GitHub] [incubator-tvm] leandron commented on a change in pull request #6797: [TVMSCRIPT] Using diagnostics for TVM Script

leandron commented on a change in pull request #6797:
URL: https://github.com/apache/incubator-tvm/pull/6797#discussion_r517248319



##########
File path: python/tvm/script/parser.py
##########
@@ -493,299 +529,281 @@ def visit_With(self, node):
                 with tir.let()/tir.Assert()/tir.attr()//tir.realize()
         """
 
-        if not len(node.items) == 1:
-            self.report_error("Only one with element is supported now")
-        if not isinstance(node.items[0].context_expr, ast.Call):
-            self.report_error("The context expression of with should be a Call")
+        if not isinstance(node.rhs, ast.Call):
+            self.report_error(
+                "The context expression of a `with` statement should be a function call.",
+                node.rhs.span,
+            )
 
-        func_call = node.items[0].context_expr
-        func_node = func_call.func
-        func = self.visit(func_node)
+        func = self.transform(node.rhs.func_name)
 
         if not isinstance(func, WithScopeHandler):
-            self.report_error("Function not allowed in with scope")
+            self.report_error(
+                f"Function {func} cannot be used in a `with` statement.", node.rhs.func_name.span
+            )
         # prepare for new block scope
         old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
-        self.current_lineno, self.current_col_offset = (
-            self.base_lineno + func_call.lineno - 1,
-            func_call.col_offset,
-        )
-        self.context.new_scope(nodes=node.body)
+        self.current_lineno = node.body.span.start_line
+        self.current_col_offset = node.body.span.start_column
+        self.context.new_scope(nodes=node.body.stmts)
         # with scope handler process the scope
         func.enter_scope(node, self.context)
-        func.body = self.parse_body()
-        arg_list = self.parse_arg_list(func, func_call)
+        func.body = self.parse_body(node)
+        arg_list = self.parse_arg_list(func, node.rhs)
         res = func.exit_scope(node, self.context, arg_list)
         # exit the scope
         self.context.pop_scope()
         self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
         return res
 
-    def visit_If(self, node):
+    def transform_If(self, node):
         """If visitor
         AST abstract grammar:
             If(expr test, stmt* body, stmt* orelse)
         """
 
-        condition = self.visit(node.test)
+        condition = self.transform(node.condition)
         # then body
-        self.context.new_scope(nodes=node.body)
-        then_body = self.parse_body()
+        self.context.new_scope(nodes=node.true.stmts)
+        then_body = self.parse_body(node)
         self.context.pop_scope()
 
         # else body
-        if len(node.orelse) > 0:
-            self.context.new_scope(nodes=node.orelse)
-            else_body = self.parse_body()
+        if len(node.false.stmts) > 0:
+            self.context.new_scope(nodes=node.false.stmts)
+            else_body = self.parse_body(node)
             self.context.pop_scope()
         else:
             else_body = None
 
         return tvm.tir.IfThenElse(condition, then_body, else_body)
 
-    def visit_Call(self, node):
+    def transform_Call(self, node):
         """Call visitor
-        AST abstract grammar:
-            Call(expr func, expr* args, keyword* keywords)
-            keyword = (identifier? arg, expr value)
 
-        By now 3 patterns of Call is allowed
-            1. Intrin representing PrimExpr/IterVar
+        3 different Call patterns are allowed:
+            1. Intrin representing a PrimExpr/IterVar
                 1.1 tir.int/uint/float8/16/32/64/floormod/floordiv/load/cast/ramp/broadcast/max
                 1.2 tir.range/reduce_axis/scan_axis/opaque_axis
             2. tir.Op(dtype, ...)
             3. other callable functions
         """
 
-        func = self.visit(node.func)
-        if isinstance(func, Intrin) and not func.stmt:
-            # pattern 1
-            arg_list = self.parse_arg_list(func, node)
-            return func.handle(arg_list)
+        if isinstance(node.func_name, ast.Op):
+            if node.func_name.name == ast.BuiltinOp.Subscript:
+                return self.transform_Subscript(node)
+            if node.func_name.name in self._binop_maker:
+                lhs = self.transform(node.params[0])
+                rhs = self.transform(node.params[1])
+                return self._binop_maker[node.func_name.name](lhs, rhs)
+            if node.func_name.name in self._unaryop_maker:
+                rhs = self.transform(node.params[0])
+                return self._unaryop_maker[node.func_name.name](rhs)
+            self.report_error(f"Unsupported operator {node.func_name.name}.", node.func_name.span)
         else:
-            args = [self.visit(arg) for arg in node.args]
-            kw_args = [self.visit(keyword) for keyword in node.keywords]
-            kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args}
-            if isinstance(func, tvm.tir.op.Op):
-                # pattern 2
-                return tvm.tir.Call(kw_args["dtype"], func, args)
-            elif callable(func):
-                # pattern 3
-                return func(*args, **kw_args)
-
-        self.report_error("Unsupported function call")
-
-    def visit_Expr(self, node):
-        """Expr visitor
-        AST abstract grammar:
-            Expr(expr value)
-
-        Now only 3 types of Expr stmt is allowed:
-            1. Intrin representing Stmt without body
-                tir.store()/tir.evaluate()
-            2. with scope handlers with concise scoping without var def
-                tir.attr()/tir.assert()/tir.allocate()/tir.realize()
-            3. special stmt without var def
-                tir.func_attr()
+            func = self.transform(node.func_name)
+            if isinstance(func, Intrin) and not func.stmt:
+                # pattern 1
+                arg_list = self.parse_arg_list(func, node)
+                return func.handle(arg_list)
+            else:
+                args = [self.transform(arg) for arg in node.params]
+                kw_args = {
+                    self.transform(k): self.transform(v) for k, v in node.keyword_params.items()
+                }
+                if isinstance(func, tvm.tir.op.Op):
+                    # pattern 2
+                    return tvm.tir.Call(kw_args["dtype"], func, args)
+                elif callable(func):
+                    # pattern 3
+                    return func(*args, **kw_args)
+
+        self.report_error("Unsupported function call.", node.func_name.span)
+
+    def transform_UnassignedCall(self, node):
+        """Visitor for statements that are function calls.
+
+        This handles function calls that appear on thier own line like `tir.realize`.
+
+        Examples
+        --------
+        .. code-block:: python
+
+            @tvm.script.tir
+            def f():
+                A = tir.buffer_decl([10, 10])
+                tir.realize(A[1:2, 1:2], "")  # This is an UnassignedCall
+                A[1, 1] = 2  # This is also an UnassignedCall
         """
+        # Only allowed builtin operator that can be a statement is x[1] = 3 i.e. subscript assign.
+        if isinstance(node.call.func_name, ast.Op):
+            if node.call.func_name.name != ast.BuiltinOp.SubscriptAssign:
+                self.report_error(
+                    "Binary and unary operators are not allowed as a statement", node.span
+                )
+            else:
+                return self.transform_SubscriptAssign(node.call)
 
-        if not isinstance(node.value, ast.Call):
-            self.report_error("Unsupported Expr stmt")
+        # handle a regular funciton call

Review comment:
       ```suggestion
           # handle a regular function call
   ```

##########
File path: python/tvm/script/parser.py
##########
@@ -14,21 +14,22 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script Parser For TIR"""
-# pylint: disable=invalid-name, missing-docstring, inconsistent-return-statements, no-else-return
-# pylint: disable=unnecessary-comprehension, unused-argument
-# pylint: disable=relative-beyond-top-level
+"""TVM Script Parser For TIR
+
+We use [synr](https://synr.readthedocs.io) to get an AST that is stable over
+different python versions. Synr also provides an error handling context that we
+use for error reporting.
+"""
+# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return
 import json
 import operator
 import inspect
-from typed_ast import ast3 as ast
+from synr import ast, Transformer, to_ast

Review comment:
       Out of curiosity, are we deprecating all usage of `typed_ast` in favour of `synr`? If so, I think we also want to replace `typed_ast` with `synr` and make that dependency official, here:
   
   https://github.com/apache/incubator-tvm/blob/c8064b3ca6787961c90fe7c09c3ddc3beba3ece5/python/setup.py#L186
   
   We should also remove `typed_ast` from the Python packages script, maybe in a follow-up patch to support the transition period:
   
   https://github.com/apache/incubator-tvm/blob/c8064b3ca6787961c90fe7c09c3ddc3beba3ece5/docker/install/ubuntu_install_python_package.sh#L24
   
   In case you're doing it in a follow-up patch, can you add a TODO comment?
   
   

##########
File path: python/tvm/script/parser.py
##########
@@ -238,25 +252,26 @@ def parse_arg_list(self, func, node_call):
             internal_args.extend(reader.get_varargs(len(pos_only) + len(kwargs) + 1))
         return internal_args
 
-    def parse_type(self, type_node):
-        """ Parse type """
+    def parse_type(self, type_node, parent):
+        """Parse a type annotation.
+
+        We require the parent object to the type so that we have a place to
+        report the error message if the type does not exist.
+        """
         if type_node is None:
-            self.report_error("missing type annotation")
-        res_type = self.visit(type_node)
+            self.report_error("A type annotation is required", parent.span)
+        res_type = self.transform(type_node)
         return tvm.ir.TupleType([]) if res_type is None else res_type.evaluate()
 
     def generic_visit(self, node):
-        """Override method in ast.NodeVisitor.
-        To directly filter out invalidate type of stmt.
-        """
+        """Fallback visitor if node type is not handled. Reports an error."""
 
-        self.report_error(type(node).__name__ + " AST node is not supported now")
+        self.report_error(type(node).__name__ + " AST node is not supported now", node.span)

Review comment:
       ```suggestion
           self.report_error(type(node).__name__ + " AST node is not supported", node.span)
   ```

##########
File path: python/tvm/script/parser.py
##########
@@ -794,28 +812,54 @@ def visit_Name(self, node):
         symbol = self.context.lookup_symbol(name)
         if symbol is not None:
             return symbol
-        self.report_error("Unknown identifier %s" % name)
+        self.report_error(f"Unknown identifier {name}.", node.span)
+
+    def transform_TypeVar(self, node):
+        """Type variable visitor.
+
+        Equivalent to `transform_Var` but for types.
+        """
+        name = node.id.name
+        symbol = Registry.lookup(name)
+        if symbol is not None:
+            return symbol
+        symbol = self.context.lookup_symbol(name)
+        if symbol is not None:
+            return symbol

Review comment:
       An alternative that I think is equivalent, with fewer branches and return statements in the same function. Feel free to keep your version if you want :)
   
   ```suggestion
           symbol = Registry.lookup(name) or self.context.lookup_symbol(name)
           if symbol is not None:
               return symbol
   ```

##########
File path: python/tvm/script/parser.py
##########
@@ -459,29 +495,29 @@ def visit_For(self, node):
                 for name in tir.serial()/tir.parallel()/tir.vectorized()/tir.unroll()
         """
 
-        if not isinstance(node.iter, ast.Call):
-            self.report_error("The loop iter should be a Call")
-        func = self.visit(node.iter.func)
+        if not isinstance(node.rhs, ast.Call):
+            self.report_error("The loop iterator should be a function call.", node.rhs.span)
+        func = self.transform(node.rhs.func_name)
         if not isinstance(func, ForScopeHandler):
-            self.report_error("Only for scope handlers can be used in for stmt")
+            self.report_error(
+                "Only for scope handlers can be used in a for statement.", node.rhs.func_name.span

Review comment:
       ```suggestion
                   "Only For scope handlers can be used in a for statement.", node.rhs.func_name.span
   ```

##########
File path: python/tvm/script/parser.py
##########
@@ -493,299 +529,281 @@ def visit_With(self, node):
                 with tir.let()/tir.Assert()/tir.attr()//tir.realize()
         """
 
-        if not len(node.items) == 1:
-            self.report_error("Only one with element is supported now")
-        if not isinstance(node.items[0].context_expr, ast.Call):
-            self.report_error("The context expression of with should be a Call")
+        if not isinstance(node.rhs, ast.Call):
+            self.report_error(
+                "The context expression of a `with` statement should be a function call.",
+                node.rhs.span,
+            )
 
-        func_call = node.items[0].context_expr
-        func_node = func_call.func
-        func = self.visit(func_node)
+        func = self.transform(node.rhs.func_name)
 
         if not isinstance(func, WithScopeHandler):
-            self.report_error("Function not allowed in with scope")
+            self.report_error(
+                f"Function {func} cannot be used in a `with` statement.", node.rhs.func_name.span
+            )
         # prepare for new block scope
         old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
-        self.current_lineno, self.current_col_offset = (
-            self.base_lineno + func_call.lineno - 1,
-            func_call.col_offset,
-        )
-        self.context.new_scope(nodes=node.body)
+        self.current_lineno = node.body.span.start_line
+        self.current_col_offset = node.body.span.start_column
+        self.context.new_scope(nodes=node.body.stmts)
         # with scope handler process the scope
         func.enter_scope(node, self.context)
-        func.body = self.parse_body()
-        arg_list = self.parse_arg_list(func, func_call)
+        func.body = self.parse_body(node)
+        arg_list = self.parse_arg_list(func, node.rhs)
         res = func.exit_scope(node, self.context, arg_list)
         # exit the scope
         self.context.pop_scope()
         self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
         return res
 
-    def visit_If(self, node):
+    def transform_If(self, node):
         """If visitor
         AST abstract grammar:
             If(expr test, stmt* body, stmt* orelse)
         """
 
-        condition = self.visit(node.test)
+        condition = self.transform(node.condition)
         # then body
-        self.context.new_scope(nodes=node.body)
-        then_body = self.parse_body()
+        self.context.new_scope(nodes=node.true.stmts)
+        then_body = self.parse_body(node)
         self.context.pop_scope()
 
         # else body
-        if len(node.orelse) > 0:
-            self.context.new_scope(nodes=node.orelse)
-            else_body = self.parse_body()
+        if len(node.false.stmts) > 0:
+            self.context.new_scope(nodes=node.false.stmts)
+            else_body = self.parse_body(node)
             self.context.pop_scope()
         else:
             else_body = None
 
         return tvm.tir.IfThenElse(condition, then_body, else_body)
 
-    def visit_Call(self, node):
+    def transform_Call(self, node):
         """Call visitor
-        AST abstract grammar:
-            Call(expr func, expr* args, keyword* keywords)
-            keyword = (identifier? arg, expr value)
 
-        By now 3 patterns of Call is allowed
-            1. Intrin representing PrimExpr/IterVar
+        3 different Call patterns are allowed:
+            1. Intrin representing a PrimExpr/IterVar
                 1.1 tir.int/uint/float8/16/32/64/floormod/floordiv/load/cast/ramp/broadcast/max
                 1.2 tir.range/reduce_axis/scan_axis/opaque_axis
             2. tir.Op(dtype, ...)
             3. other callable functions
         """
 
-        func = self.visit(node.func)
-        if isinstance(func, Intrin) and not func.stmt:
-            # pattern 1
-            arg_list = self.parse_arg_list(func, node)
-            return func.handle(arg_list)
+        if isinstance(node.func_name, ast.Op):
+            if node.func_name.name == ast.BuiltinOp.Subscript:
+                return self.transform_Subscript(node)
+            if node.func_name.name in self._binop_maker:
+                lhs = self.transform(node.params[0])
+                rhs = self.transform(node.params[1])
+                return self._binop_maker[node.func_name.name](lhs, rhs)
+            if node.func_name.name in self._unaryop_maker:
+                rhs = self.transform(node.params[0])
+                return self._unaryop_maker[node.func_name.name](rhs)
+            self.report_error(f"Unsupported operator {node.func_name.name}.", node.func_name.span)
         else:
-            args = [self.visit(arg) for arg in node.args]
-            kw_args = [self.visit(keyword) for keyword in node.keywords]
-            kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args}
-            if isinstance(func, tvm.tir.op.Op):
-                # pattern 2
-                return tvm.tir.Call(kw_args["dtype"], func, args)
-            elif callable(func):
-                # pattern 3
-                return func(*args, **kw_args)
-
-        self.report_error("Unsupported function call")
-
-    def visit_Expr(self, node):
-        """Expr visitor
-        AST abstract grammar:
-            Expr(expr value)
-
-        Now only 3 types of Expr stmt is allowed:
-            1. Intrin representing Stmt without body
-                tir.store()/tir.evaluate()
-            2. with scope handlers with concise scoping without var def
-                tir.attr()/tir.assert()/tir.allocate()/tir.realize()
-            3. special stmt without var def
-                tir.func_attr()
+            func = self.transform(node.func_name)
+            if isinstance(func, Intrin) and not func.stmt:
+                # pattern 1
+                arg_list = self.parse_arg_list(func, node)
+                return func.handle(arg_list)
+            else:
+                args = [self.transform(arg) for arg in node.params]
+                kw_args = {
+                    self.transform(k): self.transform(v) for k, v in node.keyword_params.items()
+                }
+                if isinstance(func, tvm.tir.op.Op):
+                    # pattern 2
+                    return tvm.tir.Call(kw_args["dtype"], func, args)
+                elif callable(func):
+                    # pattern 3
+                    return func(*args, **kw_args)
+
+        self.report_error("Unsupported function call.", node.func_name.span)
+
+    def transform_UnassignedCall(self, node):
+        """Visitor for statements that are function calls.
+
+        This handles function calls that appear on thier own line like `tir.realize`.
+
+        Examples
+        --------
+        .. code-block:: python
+
+            @tvm.script.tir
+            def f():
+                A = tir.buffer_decl([10, 10])
+                tir.realize(A[1:2, 1:2], "")  # This is an UnassignedCall
+                A[1, 1] = 2  # This is also an UnassignedCall
         """
+        # Only allowed builtin operator that can be a statement is x[1] = 3 i.e. subscript assign.
+        if isinstance(node.call.func_name, ast.Op):
+            if node.call.func_name.name != ast.BuiltinOp.SubscriptAssign:
+                self.report_error(
+                    "Binary and unary operators are not allowed as a statement", node.span
+                )
+            else:
+                return self.transform_SubscriptAssign(node.call)
 
-        if not isinstance(node.value, ast.Call):
-            self.report_error("Unsupported Expr stmt")
+        # handle a regular funciton call
+        func = self.transform(node.call.func_name)
+        arg_list = self.parse_arg_list(func, node.call)
 
-        func = self.visit(node.value.func)
-        arg_list = self.parse_arg_list(func, node.value)
+        if isinstance(func, tvm.script.scope_handler.AssertHandler):
+            self.report_error(
+                "A standalone `tir.Assert` is not allowed. Use `assert condition, message` "
+                "instead.",
+                node.call.func_name.span,
+            )
 
         if isinstance(func, Intrin) and func.stmt:
-            # pattern 1
             return func.handle(arg_list)
         elif isinstance(func, WithScopeHandler) and func.concise_scope and not func.def_symbol:
-            # pattern 2
             func.enter_scope(node, self.context)
-            func.body = self.parse_body()
+            func.body = self.parse_body(node)
             return func.exit_scope(node, self.context, arg_list)
         elif isinstance(func, SpecialStmt) and not func.def_symbol:
-            # pattern 3
             func.handle(node, self.context, arg_list)
             return
 
-        self.report_error("Invalid Expr stmt")
-
-    def visit_BinOp(self, node):
-        """BinOp visitor
-        AST abstract grammar:
-            BinOp(expr left, operator op, expr right)
-        """
-
-        lhs = self.visit(node.left)
-        rhs = self.visit(node.right)
-        if not isinstance(node.op, tuple(TVMScriptParser._binop_maker.keys())):
-            self.report_error("BinOp " + str(type(node.op)) + " is not supported now")
-        return TVMScriptParser._binop_maker[type(node.op)](lhs, rhs)
+        self.report_error(f"Invalid Expr stmt {type(func).__name__}.", node.call.func_name.span)
 
-    def visit_Compare(self, node):
-        """Compare visitor
-        AST abstract grammar:
-            Compare(expr left, expr right, ops=)
-        """
-
-        ops = [self.visit(node.left)]
-        ops += [self.visit(comparator) for comparator in node.comparators]
-        res = []
-        for i in range(len(node.ops)):
-            lhs = ops[i]
-            rhs = ops[i + 1]
-            res.append(TVMScriptParser._binop_maker[type(node.ops[i])](lhs, rhs))
-        return _all(*res)
-
-    def visit_BoolOp(self, node):
-        """BoolOp visitor
-        AST abstract grammar:
-            BoolOp(boolop op, expr* values)
-        """
-
-        values = [self.visit(value) for value in node.values]
-        return TVMScriptParser._binop_maker[type(node.op)](*values)
-
-    def visit_UnaryOp(self, node):
-        """UnaryOp visitor
-        AST abstract grammar:
-            UnaryOp(unaryop op, expr operand)
-        """
+    def transform_Slice(self, node):
+        start = self.transform(node.start)
+        end = self.transform(node.end)
+        if not (isinstance(node.step, ast.Constant) and node.step.value == 1):
+            self.report_error("Only step size 1 is supported for slices.", node.step.span)
+        extent = end - start
+        if isinstance(extent, tvm.tir.PrimExpr):
+            ana = tvm.arith.Analyzer()
+            extent = ana.simplify(extent)
+        return tvm.ir.Range.from_min_extent(start, extent)
 
-        operand = self.visit(node.operand)
-        if not isinstance(node.op, tuple(TVMScriptParser._unaryop_maker.keys())):
-            self.report_error("UnaryOp " + str(type(node.op)) + " is not supported now")
-        return TVMScriptParser._unaryop_maker[type(node.op)](operand)
+    def transform_Subscript(self, node):
+        """Array access visitor.
 
-    def visit_Subscript(self, node):
-        """Subscript visitor
-        AST abstract grammar:
-            Subscript(expr value, slice slice, expr_context ctx)
-            slice = Slice(expr? lower, expr? upper, expr? step)
-                    | ExtSlice(slice* dims)
-                    | Index(expr value)
-        By now 2 patterns of Subscript are supported:
+        By now only 2 types of Subscript are supported:
             1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore)
                Var[index] Buffer element access()
             2. meta[type_key][index], Meta info access
         """
 
-        symbol = self.visit(node.value)
+        symbol = self.transform(node.params[0])
         if symbol is None:
-            self.report_error(node.value.id + " is not defined")
-        if isinstance(symbol, (tvm.tir.expr.Var, tvm.tir.Buffer)):
-            if isinstance(node.slice, ast.Index):
-                # BufferLoad & BufferStore, Buffer/Var[index, index, ...]
-                indexes = self.visit(node.slice.value)
-                indexes = list(indexes) if isinstance(indexes, tuple) else [indexes]
-                if isinstance(node.ctx, ast.Load):
-                    if isinstance(symbol, tvm.tir.expr.Var):
-                        return tvm.tir.Load("float32", symbol, indexes, True)
-                    else:
-                        return tvm.tir.BufferLoad(symbol, indexes)
-                else:
-                    return symbol, indexes
-            else:
-                # Buffer Region, now used in tir.realize(buffer[bounds])
-                doms = []
-                slice_nodes = []
-                if isinstance(node.slice, ast.Slice):
-                    # Buffer[begin:end]
-                    slice_nodes.append(node.slice)
-                elif isinstance(node.slice, ast.ExtSlice):
-                    # Buffer[begin:end, begin:end]
-                    slice_nodes.extend(node.slice.dims)
-
-                for dim in slice_nodes:
-                    if not hasattr(dim, "step"):
-                        self.report_error("slice of Buffer Region ought to be begin:end")
-                    if dim.step is not None:
-                        self.report_error("step is not allowed in Buffer Region")
-                    upper = self.visit(dim.upper)
-                    lower = self.visit(dim.lower)
-                    extent = upper - lower
-                    if isinstance(extent, _expr.PrimExpr):
-                        ana = tvm.arith.Analyzer()
-                        extent = ana.simplify(extent)
-                    doms.append(tvm.ir.Range.from_min_extent(lower, extent))
-                return symbol, doms
-        else:
-            res = symbol[self.visit(slice)]
-            if res is None:
-                self.report_error("Only buffer variable and meta can be subscriptable")
-            return res
+            self.report_error(f"Variable {node.value.id} is not defined.", node.params[0].span)
 
-    def visit_Attribute(self, node):
-        """Attribute visitor
-        AST abstract grammar:
-            Attribute(expr value, identifier attr, expr_context ctx)
+        indexes = [self.transform(x) for x in node.params[1].values]
+        if isinstance(indexes[0], tvm.ir.Range):
+            return symbol, indexes
+
+        if isinstance(symbol, tvm.tir.expr.Var):
+            return tvm.tir.Load("float32", symbol, indexes, True)
+        if isinstance(symbol, tvm.tir.Buffer):
+            return tvm.tir.BufferLoad(symbol, indexes)
+
+        self.report_error(
+            f"Cannot subscript from a {type(symbol).__name__}. Only variables and "
+            "buffers are supported.",
+            node.params[0].span,
+        )
+
+    def transform_Attr(self, node):
+        """Visitor for field access of the form `x.y`.
+
+        This visitor is used to lookup function and symbol names. We have two
+        cases to handle here:
+        1. If we have a statement of the form `tir.something`, then we lookup
+           `tir.somthing` in the `Registry`. If the function is not in the
+           registry, then we try to find a `tvm.ir.op.Op` with the same name.
+        2. All other names `tvm.something` are lookup up in this current python
+           namespace.
         """
 
-        if isinstance(node.value, ast.Name):
-            if node.value.id == "tir":
-                func_name = "tir." + node.attr
+        if isinstance(node.object, ast.Var):
+            if node.object.id.name == "tir":
+                func_name = "tir." + node.field.name
                 res = Registry.lookup(func_name)
                 if res is not None:
                     return res
                 try:
                     return tvm.ir.op.Op.get(func_name)
-                except AttributeError:
-                    self.report_error("Unregistered function tir." + node.attr)
-            elif node.value.id == "ty":
-                if not hasattr(ty, node.attr):
-                    self.report_error("invalid type annotation ty." + node.attr)
-                return getattr(ty, node.attr)
-
-        symbol = self.visit(node.value)
+                except TVMError as e:
+                    # Check if we got an attribute error
+                    if e.args[0].find("AttributeError"):
+                        self.report_error(
+                            f"Unregistered function `tir.{node.field.name}`.", node.field.span
+                        )
+                    else:
+                        raise e
+
+        symbol = self.transform(node.object)
         if symbol is None:
-            self.report_error("Unsupported Attribute expression")
-        if not hasattr(symbol, node.attr):
-            self.report_error("Type " + type(symbol) + " has not attr " + node.attr)
-        res = getattr(symbol, node.attr)
+            self.report_error("Unsupported Attribute expression.", node.object.span)
+        if not hasattr(symbol, node.field.name):
+            self.report_error(
+                f"Type {type(symbol)} does not have a field called `{node.field}`.", node.span
+            )
+        res = getattr(symbol, node.field.name)
         return res
 
-    def visit_Dict(self, node):
-        """Dict visitor
-        AST abstract grammar:
-            Dict(expr* keys, expr* values)
+    def transform_TypeAttr(self, node):
+        """Visitor for field access of the form `x.y` for types.
+
+        We have two cases here:
+        1. If the type is of the form `ty.something`, we look up the type in
+           the `ty` namespace in this module.
+        2. If the type is of the form `tvm.x.somthing` then we look up
+           `tvm.x.somthing` in this modules namespace.

Review comment:
       ```suggestion
           2. If the type is of the form `tvm.x.something` then we look up
              `tvm.x.something` in this modules namespace.
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org