You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/02/08 17:44:50 UTC

[GitHub] [tvm] areusch commented on a change in pull request #8509: [TIR] Tir constants integration into compilation pipeline

areusch commented on a change in pull request #8509:
URL: https://github.com/apache/tvm/pull/8509#discussion_r801844326



##########
File path: python/tvm/script/tir/scope_handler.py
##########
@@ -157,6 +158,53 @@ def setup_buffer_var(
         context.update_symbol(name, self.buffer_var, node)
 
 
+@register
+class AllocateConst(WithScopeHandler):
+    """With scope handler tir.allocate_const(data, extents, dtype, condition)"""

Review comment:
       Could you expand the docstring? What does this class do, and where should it be used?

##########
File path: tests/python/relay/test_pass_fuse_ops.py
##########
@@ -635,27 +635,31 @@ def before():
         out = relay.op.take(concat, indices=relay.const([0], dtype="int64"))
         return relay.Function(relay.analysis.free_vars(out), out)
 
-    def expected():
+    def expected(link_params):
         shape1 = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
         shape2 = (tvm.tir.const(1, "int64"),)
         x = relay.var("x", shape=shape1)
         p0 = relay.var("p0", shape=shape1)
         p1 = relay.var("p1", shape=shape2, dtype="int64")
         c = relay.const([0], dtype="int64")
         concat = relay.concatenate([p0, p0], axis=-1)
-        out = relay.op.take(concat, indices=p1)
+        out = relay.op.take(concat, indices=c if link_params else p1)
 
-        f0 = relay.Function([p0, p1], out)
+        f0 = relay.Function([p0] if link_params else [p0, p1], out)
         f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
 
-        y = relay.Call(f0, [x, c])
+        y = relay.Call(f0, [x] if link_params else [x, c])
         return relay.Function([x], y)
 
-    orig = before()
-    m = fuse2(tvm.IRModule.from_expr(orig))
-    relay.build(m, "llvm")
-    after = run_opt_pass(expected(), transform.InferType())
-    assert tvm.ir.structural_equal(m["main"], after)
+    for link_params in [False, True]:

Review comment:
       Do you want to make `link_params` a `tvm.testing.parameter` instead of looping over it?

##########
File path: src/tir/ir/stmt.cc
##########
@@ -366,19 +366,19 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, Prim
   data_ = std::move(node);
 }
 
-int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
-  int64_t result = 1;
+size_t AllocateNode::ConstantAllocationSize(const Array<PrimExpr>& extents) {
+  size_t result = 1;

Review comment:
       I think we should use the same data type that would be used to compute extents from a DLTensor, which I believe is `int64_t`. Thoughts?

##########
File path: python/tvm/te/operation.py
##########
@@ -362,6 +362,28 @@ def var(name="tindex", dtype="int32", span=None):
     return tvm.tir.Var(name, dtype, span)
 
 
+def const(name="tindex", dtype="int32", span=None):

Review comment:
       Should the default name be something other than `tindex`? Or should it even be possible to omit the name?

##########
File path: src/ir/module.cc
##########
@@ -219,6 +224,64 @@ void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
   global_var_map_.Set(var->name_hint, var);
 }
 
+// Replaces constant data to index into mod's "Constants" attrs array.
+// Only processes tir::PrimFunc and ignores everything else
+void IRModuleNode::ExtractConstants(BaseFunc f) {
+  using ConstArrayType = Array<runtime::NDArray>;
+  class Applicator : public tir::StmtExprVisitor {
+   protected:
+    // returns index of the a in constant_array_, if not found - appends
+    // TODO(@d-smirnov): make real content comparision with already existing NDArrays

Review comment:
       Does this TODO impact functionality? Could you explain why (e.g., any hidden assumptions here)?

##########
File path: src/printer/tvmscript_printer.cc
##########
@@ -349,6 +350,26 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
   }
 };
 
+/*!
+ * \brief special method to print NDArray in TIR
+ * \param arr the NDArray to be printed
+ * \param os the output stream where the NDArray will be printed to
+ */
+template <typename T>
+void NDArrayToTIR(::tvm::runtime::NDArray arr, std::ostream& os) {
+  int ndim = arr->ndim;
+  int tot_dim = 1;
+  for (int i = 0; i < ndim; i++) {
+    tot_dim *= arr->shape[i];
+  }
+  T* data_ptr = reinterpret_cast<T*>(arr->data);
+  os << "[";
+  for (int i = 0; i < tot_dim; i++) {
+    os << data_ptr[i] << ", ";

Review comment:
       Should we omit the trailing comma after the last element?

##########
File path: tests/python/relay/test_pass_fuse_ops.py
##########
@@ -668,27 +672,31 @@ def before():
         out = relay.gather_nd(concat, indices=relay.expr.const([[0, 1], [1, 0]], dtype="int64"))
         return relay.Function(relay.analysis.free_vars(out), out)
 
-    def expected():
+    def expected(link_params):
         shape1 = (tvm.tir.const(10, "int64"), tvm.tir.const(1, "int64"))
         shape2 = (tvm.tir.const(2, "int64"), tvm.tir.const(2, "int64"))
         x = relay.var("x", shape=shape1)
         p0 = relay.var("p0", shape=shape1)
         p1 = relay.var("p1", shape=shape2, dtype="int64")
         c = relay.const([[0, 1], [1, 0]], dtype="int64")
         concat = relay.concatenate([p0, p0], axis=-1)
-        out = relay.gather_nd(concat, indices=p1)
+        out = relay.gather_nd(concat, indices=c if link_params else p1)
 
-        f0 = relay.Function([p0, p1], out)
+        f0 = relay.Function([p0] if link_params else [p0, p1], out)
         f0 = f0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
 
-        y = relay.Call(f0, [x, c])
+        y = relay.Call(f0, [x] if link_params else [x, c])
         return relay.Function([x], y)
 
-    orig = before()
-    m = fuse2(tvm.IRModule.from_expr(orig))
-    relay.build(m, "llvm")
-    after = run_opt_pass(expected(), transform.InferType())
-    assert tvm.ir.structural_equal(m["main"], after)
+    for link_params in [False, True]:

Review comment:
       Same question here: do you want to make `link_params` a `tvm.testing.parameter` instead of looping over it?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org