You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/06/27 13:09:45 UTC

[GitHub] [tvm] Hzfengsy opened a new pull request #8354: [TIR][TVMScript] specialize

Hzfengsy opened a new pull request #8354:
URL: https://github.com/apache/tvm/pull/8354


   This PR enables meta_programming for TVM Functions (usually used for TVMScript), which is useful to define only one function(TVMScript) but used in different shapes.
   
   Example:
   Mate fucntion
   ```Python
   @tvm.script.tir
   def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None:
       A = tir.match_buffer(a, (m, n), "float32")
       B = tir.match_buffer(b, (m, n), "float32")
   
       with tir.block([m, n], "") as [vi, vj]:
           B[vi, vj] = A[vi, vj]
   ```
   
   Instruction:
   ```Python
   a, _, m, n = mem_copy.params
   func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
   # or
   func = mem_copy.specialize({n: 16, m: 16})
   ```
   
   Specialized function:
   ```Python
   @tvm.script.tir
   def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None:
       A = tir.match_buffer(a, (16, 16), "float32")
       B = tir.match_buffer(b, (16, 16), "float32")
   
       with tir.block([16, 16], "") as [vi, vj]:
           B[vi, vj] = A[vi, vj]
   ```
   
   cc @tqchen @junrushao1994 @MasterJH5574 @comaniac 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#issuecomment-869648242


   cc @vinx13 @junrushao1994 please help to take a look


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on a change in pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#discussion_r659730908



##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {

Review comment:
       need to recursive visit the expressions(e.g. what if the src contains a buffer load, please add a regression test).
   
   Always consider recursive visit first

##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    }
+
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });

Review comment:
       need to recursive visit the indices, always use the post order recursive pattern when possible.
   
   e.g. first call StmtExprMutator::VisitExpr_(op);
   
   This will also allows you to skip the lines of indices mutation. Please add a regression test

##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);

Review comment:
       always try to return the original ones when possible.

##########
File path: tests/python/unittest/test_tvmscript_meta_programming.py
##########
@@ -0,0 +1,185 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring, missing-module-docstring
+
+import tvm
+from tvm import tir
+from tvm.script import ty
+
+
+@tvm.script.tir
+def matmul(a: ty.handle, b: ty.handle, c: ty.handle, n: ty.int32) -> None:
+    m = tir.var("int32")
+    A = tir.match_buffer(a, [m, n])
+    B = tir.match_buffer(b, [m, n])
+    C = tir.match_buffer(c, [m, m])
+
+    with tir.block([m, m, tir.reduce_axis(0, n)], "update") as [vi, vj, vk]:
+        with tir.init():
+            C[vi, vj] = 0.0
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def matmul_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, [128, 128])
+    B = tir.match_buffer(b, [128, 128])
+    C = tir.match_buffer(c, [128, 128])
+
+    with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+        with tir.init():
+            C[vi, vj] = 0.0
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def matmul_m_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+    m = tir.var("int32")
+    A = tir.match_buffer(a, [m, 128])
+    B = tir.match_buffer(b, [m, 128])
+    C = tir.match_buffer(c, [m, m])
+
+    with tir.block([m, m, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+        with tir.init():
+            C[vi, vj] = 0.0
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def matmul_m_8x(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+    x = tir.var("int32")
+    m = tir.var("int32")
+    A = tir.match_buffer(a, [m, x * 8])
+    B = tir.match_buffer(b, [m, x * 8])
+    C = tir.match_buffer(c, [m, m])
+
+    with tir.block([m, m, tir.reduce_axis(0, x * 8)], "update") as [vi, vj, vk]:
+        with tir.init():
+            C[vi, vj] = 0.0
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def element_wise(a: ty.handle, c: ty.handle) -> None:
+    m = tir.var("int32")
+    n = tir.var("int32")
+    A = tir.match_buffer(a, (m, n), "float32")
+    C = tir.match_buffer(c, (m, n), "float32")
+
+    B = tir.alloc_buffer((m, n), "float32")
+
+    with tir.block([m, n], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+
+    with tir.block([m, n], "C") as [vi, vj]:
+        C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def element_wise_128_64(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 64), "float32")
+    C = tir.match_buffer(c, (128, 64), "float32")
+    B = tir.alloc_buffer((128, 64), "float32")
+
+    with tir.block([128, 64], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+
+    with tir.block([128, 64], "C") as [vi, vj]:
+        C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def element_wise_128_n(a: ty.handle, c: ty.handle) -> None:
+    n = tir.var("int32")
+    A = tir.match_buffer(a, (128, n), "float32")
+    C = tir.match_buffer(c, (128, n), "float32")
+    B = tir.alloc_buffer((128, n), "float32")
+
+    with tir.block([128, n], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+
+    with tir.block([128, n], "C") as [vi, vj]:
+        C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None:
+    A = tir.match_buffer(a, (m, n), "float32")
+    B = tir.match_buffer(b, (m, n), "float32")
+
+    with tir.block([m, n], "") as [vi, vj]:
+        B[vi, vj] = A[vi, vj]
+
+
+@tvm.script.tir
+def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32")
+    B = tir.match_buffer(b, (16, 16), "float32")
+
+    with tir.block([16, 16], "") as [vi, vj]:
+        B[vi, vj] = A[vi, vj]
+
+
+@tvm.script.tir
+def mem_copy_m_16(a: ty.handle, b: ty.handle, m: ty.int32) -> None:
+    A = tir.match_buffer(a, (m, 16), "float32")
+    B = tir.match_buffer(b, (m, 16), "float32")
+
+    with tir.block([m, 16], "") as [vi, vj]:
+        B[vi, vj] = A[vi, vj]
+
+
+def test_tensor_dimension_invariant_code_matmul():
+    a, _, _, n = matmul.params
+    # fully specialized
+    func = matmul.specialize({a: tir.decl_buffer((128, 128))})
+    tvm.ir.assert_structural_equal(func, matmul_128)
+    # partially specialized
+    func = matmul.specialize({n: 128})
+    tvm.ir.assert_structural_equal(func, matmul_m_128)
+    # symbolic specialized
+    func = matmul.specialize({n: tir.Var("x", "int32") * 8})
+    tvm.ir.assert_structural_equal(func, matmul_m_8x)
+
+
+def test_tensor_dimension_invariant_code_elemwise():
+    a, c = element_wise.params
+    C = element_wise.buffer_map[c]
+    # fully specialized
+    func = element_wise.specialize({a: tir.decl_buffer((128, 64))})
+    tvm.ir.assert_structural_equal(func, element_wise_128_64)
+    # partially specialized
+    func = element_wise.specialize({c: tir.decl_buffer((128, C.shape[1]))})
+    tvm.ir.assert_structural_equal(func, element_wise_128_n)

Review comment:
       add a testcase of nothing being specialized, require that the function itself remains the same(pointer equality)




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] Hzfengsy commented on pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#issuecomment-870552987


   Thanks, @comaniac. Just have fixed it


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on a change in pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#discussion_r659322308



##########
File path: python/tvm/tir/function.py
##########
@@ -85,3 +87,19 @@ def with_body(self, new_body, span=None):
             The created new function.
         """
         return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map, self.attrs, span)
+
+    def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]):
+        """Metaprogramming usage: specialize parameters of PrimFunc
+

Review comment:
       Add an code example block to show before and after




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] Hzfengsy commented on a change in pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on a change in pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#discussion_r660399589



##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);

Review comment:
       In this case, I return the original one before at line 131. If the `buffer` is in `buffer_map_` it will always change




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on a change in pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#discussion_r661431202



##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/**************** Specializer ****************/
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(const VarMap& var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    bool buffer_map_updated = false;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        buffer_map_updated = true;
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    bool param_updated = false;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      } else {
+        param_updated = true;
+      }
+    }
+
+    // Updating function body
+    Stmt body = specializer(f->body);
+
+    if (param_updated || buffer_map_updated || !f->body.same_as(body)) {
+      PrimFuncNode* f_ptr = f.CopyOnWrite();
+      f_ptr->params = std::move(params);
+      f_ptr->buffer_map = std::move(buffer_map);
+      f_ptr->body = std::move(body);
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    // Step.0. Define buffer mappings which is allocated inside the block
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+
+    // Step.1. Recursively visit block body
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BlockNode>();
+    ICHECK(op != nullptr);
+
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op != nullptr);
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    } else {
+      auto n = CopyOnWrite(op);
+      n->buffer = it->second;
+      return Stmt(n);
+    }
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op != nullptr);
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    } else {
+      auto n = make_object<BufferLoadNode>(*op);
+      n->buffer = it->second;
+      return PrimExpr(n);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(const Buffer& buffer) const {
+    Array<PrimExpr> shape =
+        MutateArray(buffer->shape, [this](const PrimExpr& e) { return Substitute(e, var_map_); });
+    Array<PrimExpr> strides =
+        MutateArray(buffer->strides, [this](const PrimExpr& e) { return Substitute(e, var_map_); });
+
+    PrimExpr elem_offset = Substitute(buffer->elem_offset, var_map_);
+
+    if (buffer->elem_offset.same_as(elem_offset) && buffer->shape.same_as(shape) &&
+        buffer->strides.same_as(strides)) {
+      return buffer;
+    } else {
+      auto n = make_object<BufferNode>(*buffer.get());
+      n->elem_offset = std::move(elem_offset);
+      n->shape = std::move(shape);
+      n->strides = std::move(strides);
+      return Buffer(n);
+    }
+  }
+
+  Range MutateRange(const Range& range) {
+    PrimExpr min = this->VisitExpr(range->min);
+    PrimExpr extent = this->VisitExpr(range->extent);
+    if (min.same_as(range->min) && extent.same_as(range->extent)) {
+      return range;
+    } else {
+      return Range::FromMinExtent(std::move(min), std::move(extent));
+    }
+  }
+
+  Buffer MutateAllocBuffer(const Buffer& alloc_buf) {
+    Buffer buf = MutateBuffer(alloc_buf);
+    if (buf.same_as(alloc_buf)) {
+      return alloc_buf;
+    } else {
+      ICHECK(buffer_map_.find(alloc_buf) == buffer_map_.end());
+      buffer_map_[alloc_buf] = buf;
+      return buf;
+    }
+  }
+
+  BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) {
+    auto it = buffer_map_.find(buffer_region->buffer);
+    Array<Range> region =
+        MutateArray(buffer_region->region,
+                    std::bind(&PrimFuncSpecializer::MutateRange, this, std::placeholders::_1));
+    if (it == buffer_map_.end() && region.same_as(buffer_region->region)) {
+      return buffer_region;
+    } else {
+      return BufferRegion(it->second, std::move(region));
+    }
+  }
+
+ private:
+  /*! \brief The vars to be substitute and their values */
+  const VarMap& var_map_;
+  /*! \brief map from old buffer to mutated buffer */
+  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
+};
+
+void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer& specific_buf,

Review comment:
       document the function a bit

##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/**************** Specializer ****************/
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(const VarMap& var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    bool buffer_map_updated = false;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        buffer_map_updated = true;
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    bool param_updated = false;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      } else {
+        param_updated = true;
+      }
+    }
+
+    // Updating function body
+    Stmt body = specializer(f->body);
+
+    if (param_updated || buffer_map_updated || !f->body.same_as(body)) {
+      PrimFuncNode* f_ptr = f.CopyOnWrite();
+      f_ptr->params = std::move(params);
+      f_ptr->buffer_map = std::move(buffer_map);
+      f_ptr->body = std::move(body);
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    // Step.0. Define buffer mappings which is allocated inside the block
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+
+    // Step.1. Recursively visit block body
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BlockNode>();
+    ICHECK(op != nullptr);
+
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as<BufferStoreNode>();
+    ICHECK(op != nullptr);
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    } else {
+      auto n = CopyOnWrite(op);
+      n->buffer = it->second;
+      return Stmt(n);
+    }
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
+    op = expr.as<BufferLoadNode>();
+    ICHECK(op != nullptr);
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    } else {
+      auto n = make_object<BufferLoadNode>(*op);
+      n->buffer = it->second;
+      return PrimExpr(n);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(const Buffer& buffer) const {
+    Array<PrimExpr> shape =
+        MutateArray(buffer->shape, [this](const PrimExpr& e) { return Substitute(e, var_map_); });
+    Array<PrimExpr> strides =
+        MutateArray(buffer->strides, [this](const PrimExpr& e) { return Substitute(e, var_map_); });
+
+    PrimExpr elem_offset = Substitute(buffer->elem_offset, var_map_);
+
+    if (buffer->elem_offset.same_as(elem_offset) && buffer->shape.same_as(shape) &&
+        buffer->strides.same_as(strides)) {
+      return buffer;
+    } else {
+      auto n = make_object<BufferNode>(*buffer.get());
+      n->elem_offset = std::move(elem_offset);
+      n->shape = std::move(shape);
+      n->strides = std::move(strides);
+      return Buffer(n);
+    }
+  }
+
+  Range MutateRange(const Range& range) {
+    PrimExpr min = this->VisitExpr(range->min);
+    PrimExpr extent = this->VisitExpr(range->extent);
+    if (min.same_as(range->min) && extent.same_as(range->extent)) {
+      return range;
+    } else {
+      return Range::FromMinExtent(std::move(min), std::move(extent));
+    }
+  }
+
+  Buffer MutateAllocBuffer(const Buffer& alloc_buf) {
+    Buffer buf = MutateBuffer(alloc_buf);
+    if (buf.same_as(alloc_buf)) {
+      return alloc_buf;
+    } else {
+      ICHECK(buffer_map_.find(alloc_buf) == buffer_map_.end());
+      buffer_map_[alloc_buf] = buf;
+      return buf;
+    }
+  }
+
+  BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) {
+    auto it = buffer_map_.find(buffer_region->buffer);
+    Array<Range> region =
+        MutateArray(buffer_region->region,
+                    std::bind(&PrimFuncSpecializer::MutateRange, this, std::placeholders::_1));
+    if (it == buffer_map_.end() && region.same_as(buffer_region->region)) {
+      return buffer_region;
+    } else {
+      return BufferRegion(it->second, std::move(region));
+    }
+  }
+
+ private:
+  /*! \brief The vars to be substitute and their values */
+  const VarMap& var_map_;
+  /*! \brief map from old buffer to mutated buffer */
+  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
+};
+
+void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer& specific_buf,
+                            VarMap* var_map) {
+  // preliminaries
+  tir::ExprDeepEqual equal;
+
+  auto it = func->buffer_map.find(param);
+  CHECK(it != func->buffer_map.end())
+      << "ValueError: specialize expects param to be in PrimFunc's buffer_map";
+  const Buffer& buf_to_specialize = (*it).second;
+
+  // build var mapping using specific_buf's parameters
+  auto build_var_mapping = [&](const PrimExpr& new_expr, const PrimExpr& old_expr) {
+    if (!equal(new_expr, old_expr)) {
+      CHECK(old_expr->IsInstance<VarNode>())
+          << "TypeError: The signature of target buffer exprected an independent Var, but got "
+          << old_expr << ".";
+      const Var& var = Downcast<Var>(old_expr);
+      auto it = var_map->find(var);
+      if (it != var_map->end()) {
+        CHECK(equal(it->second, new_expr))
+            << "ValueError: The assigned value of var " << var << " mismatched. " << it->second
+            << " vs. " << new_expr << ".";
+      } else {
+        (*var_map)[var] = new_expr;
+      }
+    }
+  };
+
+  // Check buffer dimensions
+  CHECK(specific_buf->shape.size() == buf_to_specialize->shape.size())
+      << "ValueError: The buffer dimensions mismatched" << buf_to_specialize->shape.size()
+      << " vs. " << specific_buf->shape.size() << ".";
+
+  CHECK(specific_buf->strides.size() == buf_to_specialize->strides.size())
+      << "ValueError: The buffer strides dimensions mismatched" << buf_to_specialize->strides.size()
+      << " vs. " << specific_buf->strides.size() << ".";
+
+  // Updating var mapping using specific_expr
+  for (size_t i = 0; i < specific_buf->shape.size(); ++i) {
+    build_var_mapping(specific_buf->shape[i], buf_to_specialize->shape[i]);
+  }
+  for (size_t i = 0; i < specific_buf->strides.size(); ++i) {
+    build_var_mapping(specific_buf->strides[i], buf_to_specialize->strides[i]);
+  }
+  build_var_mapping(specific_buf->elem_offset, buf_to_specialize->elem_offset);
+
+  // Check data_alignment and offset_factor.
+  // These two signatures are int, so we do not need map them.
+  CHECK_EQ(specific_buf->data_alignment, buf_to_specialize->data_alignment)
+      << "ValueError: The buffer data_alignment mismatched" << buf_to_specialize->data_alignment
+      << " vs. " << specific_buf->data_alignment << ".";
+
+  CHECK_EQ(specific_buf->offset_factor, buf_to_specialize->offset_factor)
+      << "ValueError: The buffer offset_factor mismatched" << buf_to_specialize->offset_factor
+      << " vs. " << specific_buf->offset_factor << ".";
+}
+
+void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimExpr& specific_expr,
+                            VarMap* var_map) {
+  // check param is in PrimFunc's parameters
+  CHECK(IsParam(func, param)) << "ValueError: Specialize expects param to be in PrimFunc's params";
+  // specialize a param not in buffer_map
+  CHECK_EQ(func->buffer_map.count(param), 0)
+      << "ValueError: Specialize expects param to not be in PrimFunc's buffer_map";
+  // build var mapping using specific_expr
+  (*var_map)[param] = specific_expr;
+}
+
+/**************** FFI ****************/
+
+TVM_REGISTER_GLOBAL("tir.Specialize")
+    .set_body_typed<PrimFunc(PrimFunc, Map<Var, ObjectRef>)>([](PrimFunc func,
+                                                                Map<Var, ObjectRef> param_map) {

Review comment:
       Let us move this function to a separate global Specialize function so it can be called from C++ side as well

##########
File path: tests/python/unittest/test_tvmscript_meta_programming.py
##########
@@ -0,0 +1,193 @@
+# Licensed to the Apache Software Foundation (ASF) under one

Review comment:
       test_tir_specialize




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on a change in pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#discussion_r659744486



##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    }
+
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->indices = std::move(indices);
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(Buffer buffer) const {
+    BufferNode* buffer_ptr = buffer.CopyOnWrite();
+    Array<PrimExpr> new_shape, new_stride;
+    new_shape.reserve(buffer_ptr->shape.size());
+    new_shape.reserve(buffer_ptr->strides.size());
+    for (const auto& dim : buffer_ptr->shape) {
+      new_shape.push_back(Substitute(dim, var_map_));
+    }
+    for (const auto& stride : buffer_ptr->strides) {
+      new_shape.push_back(Substitute(stride, var_map_));
+    }
+    buffer_ptr->elem_offset = Substitute(buffer_ptr->elem_offset, var_map_);
+    buffer_ptr->shape = std::move(new_shape);
+    buffer_ptr->strides = std::move(new_stride);
+    return buffer;
+  }
+
+  Range MutateRange(const Range& range) {
+    PrimExpr min = this->VisitExpr(range->min);
+    PrimExpr extent = this->VisitExpr(range->extent);
+    if (min.same_as(range->min) && extent.same_as(range->extent)) {
+      return range;
+    } else {
+      ObjectPtr<RangeNode> n = CopyOnWrite(range.get());
+      n->min = std::move(min);
+      n->extent = std::move(extent);
+      return Range(n);
+    }
+  }
+
+  IterVar MutateIterVar(const IterVar& iter_var) {
+    Range range = MutateRange(iter_var->dom);
+    if (range.same_as(iter_var->dom)) {
+      return iter_var;
+    } else {
+      auto n = CopyOnWrite(iter_var.get());
+      n->dom = std::move(range);
+      return IterVar(n);
+    }
+  }
+
+  Buffer MutateAllocBuffer(const Buffer& alloc_buf) {
+    Buffer buf = MutateBuffer(alloc_buf);
+    if (buf.same_as(alloc_buf)) {
+      return alloc_buf;
+    } else {
+      buffer_map_[alloc_buf] = buf;

Review comment:
       ICheck that buffer map should not contain alloc_buff




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on a change in pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#discussion_r659739313



##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    }
+
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->indices = std::move(indices);
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(Buffer buffer) const {
+    BufferNode* buffer_ptr = buffer.CopyOnWrite();

Review comment:
       The following invariance must hold: if nothing is remapped, we return the same buffer. Right now this does not hold.
   
   - We need to check the new shape, strides




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on a change in pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#discussion_r659739654



##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {

Review comment:
       Let us skip attrs updates, attrs are supposed to be invariant to the code itself.

##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();

Review comment:
       The following invariant needs to hold, if nothing changes in the body, we should return the same function

##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    }
+
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->indices = std::move(indices);
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(Buffer buffer) const {
+    BufferNode* buffer_ptr = buffer.CopyOnWrite();
+    Array<PrimExpr> new_shape, new_stride;
+    new_shape.reserve(buffer_ptr->shape.size());
+    new_shape.reserve(buffer_ptr->strides.size());
+    for (const auto& dim : buffer_ptr->shape) {
+      new_shape.push_back(Substitute(dim, var_map_));
+    }
+    for (const auto& stride : buffer_ptr->strides) {
+      new_shape.push_back(Substitute(stride, var_map_));
+    }
+    buffer_ptr->elem_offset = Substitute(buffer_ptr->elem_offset, var_map_);
+    buffer_ptr->shape = std::move(new_shape);
+    buffer_ptr->strides = std::move(new_stride);
+    return buffer;
+  }
+
+  Range MutateRange(const Range& range) {
+    PrimExpr min = this->VisitExpr(range->min);
+    PrimExpr extent = this->VisitExpr(range->extent);
+    if (min.same_as(range->min) && extent.same_as(range->extent)) {
+      return range;
+    } else {
+      ObjectPtr<RangeNode> n = CopyOnWrite(range.get());
+      n->min = std::move(min);
+      n->extent = std::move(extent);
+      return Range(n);
+    }
+  }
+
+  IterVar MutateIterVar(const IterVar& iter_var) {
+    Range range = MutateRange(iter_var->dom);
+    if (range.same_as(iter_var->dom)) {
+      return iter_var;
+    } else {
+      auto n = CopyOnWrite(iter_var.get());
+      n->dom = std::move(range);
+      return IterVar(n);
+    }
+  }
+
+  Buffer MutateAllocBuffer(const Buffer& alloc_buf) {
+    Buffer buf = MutateBuffer(alloc_buf);
+    if (buf.same_as(alloc_buf)) {
+      return alloc_buf;
+    } else {
+      buffer_map_[alloc_buf] = buf;
+      return buf;
+    }
+  }
+
+  BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) {
+    auto it = buffer_map_.find(buffer_region->buffer);
+    Array<Range> region =
+        MutateArray(buffer_region->region,
+                    std::bind(&PrimFuncSpecializer::MutateRange, this, std::placeholders::_1));
+    if (it == buffer_map_.end() && region.same_as(buffer_region->region)) {
+      return buffer_region;
+    } else {
+      auto n = CopyOnWrite(buffer_region.get());

Review comment:
       avoid CopyOnWrite except for stmt for now, since we are not tracking cow for other things




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] Hzfengsy commented on a change in pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on a change in pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#discussion_r660399589



##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);

Review comment:
       In this case, I return the original one before at line 131. If the `buffer` is in `buffer_map_` it will always change

##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);

Review comment:
       In this case, I return the original one before at line 131. If the `buffer` is in `buffer_map_` it will always change




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on a change in pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#discussion_r659322308



##########
File path: python/tvm/tir/function.py
##########
@@ -85,3 +87,19 @@ def with_body(self, new_body, span=None):
             The created new function.
         """
         return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map, self.attrs, span)
+
+    def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]):
+        """Metaprogramming usage: specialize parameters of PrimFunc
+

Review comment:
       Add an code example block




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] Hzfengsy commented on pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#issuecomment-870552987


   Thanks, @comaniac. Just have fixed it


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen merged pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
tqchen merged pull request #8354:
URL: https://github.com/apache/tvm/pull/8354


   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on a change in pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#discussion_r659743770



##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    }
+
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->indices = std::move(indices);
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(Buffer buffer) const {
+    BufferNode* buffer_ptr = buffer.CopyOnWrite();
+    Array<PrimExpr> new_shape, new_stride;
+    new_shape.reserve(buffer_ptr->shape.size());
+    new_shape.reserve(buffer_ptr->strides.size());
+    for (const auto& dim : buffer_ptr->shape) {
+      new_shape.push_back(Substitute(dim, var_map_));
+    }
+    for (const auto& stride : buffer_ptr->strides) {
+      new_shape.push_back(Substitute(stride, var_map_));
+    }
+    buffer_ptr->elem_offset = Substitute(buffer_ptr->elem_offset, var_map_);
+    buffer_ptr->shape = std::move(new_shape);
+    buffer_ptr->strides = std::move(new_stride);
+    return buffer;
+  }
+
+  Range MutateRange(const Range& range) {
+    PrimExpr min = this->VisitExpr(range->min);
+    PrimExpr extent = this->VisitExpr(range->extent);
+    if (min.same_as(range->min) && extent.same_as(range->extent)) {
+      return range;
+    } else {
+      ObjectPtr<RangeNode> n = CopyOnWrite(range.get());

Review comment:
       create a new range instead of using copyonwrite

##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    }
+
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->indices = std::move(indices);
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(Buffer buffer) const {
+    BufferNode* buffer_ptr = buffer.CopyOnWrite();
+    Array<PrimExpr> new_shape, new_stride;
+    new_shape.reserve(buffer_ptr->shape.size());
+    new_shape.reserve(buffer_ptr->strides.size());
+    for (const auto& dim : buffer_ptr->shape) {
+      new_shape.push_back(Substitute(dim, var_map_));
+    }
+    for (const auto& stride : buffer_ptr->strides) {
+      new_shape.push_back(Substitute(stride, var_map_));
+    }
+    buffer_ptr->elem_offset = Substitute(buffer_ptr->elem_offset, var_map_);
+    buffer_ptr->shape = std::move(new_shape);
+    buffer_ptr->strides = std::move(new_stride);
+    return buffer;
+  }
+
+  Range MutateRange(const Range& range) {
+    PrimExpr min = this->VisitExpr(range->min);
+    PrimExpr extent = this->VisitExpr(range->extent);
+    if (min.same_as(range->min) && extent.same_as(range->extent)) {
+      return range;
+    } else {
+      ObjectPtr<RangeNode> n = CopyOnWrite(range.get());
+      n->min = std::move(min);
+      n->extent = std::move(extent);
+      return Range(n);
+    }
+  }
+
+  IterVar MutateIterVar(const IterVar& iter_var) {
+    Range range = MutateRange(iter_var->dom);
+    if (range.same_as(iter_var->dom)) {
+      return iter_var;
+    } else {
+      auto n = CopyOnWrite(iter_var.get());

Review comment:
       create a new IterVar instead of copyonWrite here




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on a change in pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#discussion_r659753085



##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    }
+
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->indices = std::move(indices);
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(Buffer buffer) const {
+    BufferNode* buffer_ptr = buffer.CopyOnWrite();
+    Array<PrimExpr> new_shape, new_stride;
+    new_shape.reserve(buffer_ptr->shape.size());
+    new_shape.reserve(buffer_ptr->strides.size());
+    for (const auto& dim : buffer_ptr->shape) {
+      new_shape.push_back(Substitute(dim, var_map_));

Review comment:
       can we reuse MutateArray for this?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] Hzfengsy commented on a change in pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on a change in pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#discussion_r660399589



##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);

Review comment:
       In this case, I return the original one before at line 131. If the `buffer` is in `buffer_map_` it will always change




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on a change in pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#discussion_r659738283



##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    }
+
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->indices = std::move(indices);
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(Buffer buffer) const {
+    BufferNode* buffer_ptr = buffer.CopyOnWrite();
+    Array<PrimExpr> new_shape, new_stride;
+    new_shape.reserve(buffer_ptr->shape.size());
+    new_shape.reserve(buffer_ptr->strides.size());
+    for (const auto& dim : buffer_ptr->shape) {
+      new_shape.push_back(Substitute(dim, var_map_));
+    }
+    for (const auto& stride : buffer_ptr->strides) {
+      new_shape.push_back(Substitute(stride, var_map_));

Review comment:
       new strides, please add a regression test




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on a change in pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#discussion_r659738159



##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    }
+
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->indices = std::move(indices);
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(Buffer buffer) const {
+    BufferNode* buffer_ptr = buffer.CopyOnWrite();
+    Array<PrimExpr> new_shape, new_stride;
+    new_shape.reserve(buffer_ptr->shape.size());
+    new_shape.reserve(buffer_ptr->strides.size());

Review comment:
       new_strides

##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    }
+
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->indices = std::move(indices);
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(Buffer buffer) const {
+    BufferNode* buffer_ptr = buffer.CopyOnWrite();
+    Array<PrimExpr> new_shape, new_stride;
+    new_shape.reserve(buffer_ptr->shape.size());
+    new_shape.reserve(buffer_ptr->strides.size());
+    for (const auto& dim : buffer_ptr->shape) {
+      new_shape.push_back(Substitute(dim, var_map_));
+    }
+    for (const auto& stride : buffer_ptr->strides) {
+      new_shape.push_back(Substitute(stride, var_map_));

Review comment:
       new strides, please add a regression test

##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    }
+
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->indices = std::move(indices);
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(Buffer buffer) const {
+    BufferNode* buffer_ptr = buffer.CopyOnWrite();

Review comment:
       The following invariance must hold: if nothing is remapped, we return the same buffer. Right now this does not hold.
   
   - We need to check the new shape, strides

##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {

Review comment:
       Let us skip attrs updates, attrs are supposed to be invariant to the code itself.

##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();

Review comment:
       The following invariant needs to hold, if nothing changes in the body, we should return the same function

##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    }
+
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->indices = std::move(indices);
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(Buffer buffer) const {
+    BufferNode* buffer_ptr = buffer.CopyOnWrite();
+    Array<PrimExpr> new_shape, new_stride;
+    new_shape.reserve(buffer_ptr->shape.size());
+    new_shape.reserve(buffer_ptr->strides.size());
+    for (const auto& dim : buffer_ptr->shape) {
+      new_shape.push_back(Substitute(dim, var_map_));
+    }
+    for (const auto& stride : buffer_ptr->strides) {
+      new_shape.push_back(Substitute(stride, var_map_));
+    }
+    buffer_ptr->elem_offset = Substitute(buffer_ptr->elem_offset, var_map_);
+    buffer_ptr->shape = std::move(new_shape);
+    buffer_ptr->strides = std::move(new_stride);
+    return buffer;
+  }
+
+  Range MutateRange(const Range& range) {
+    PrimExpr min = this->VisitExpr(range->min);
+    PrimExpr extent = this->VisitExpr(range->extent);
+    if (min.same_as(range->min) && extent.same_as(range->extent)) {
+      return range;
+    } else {
+      ObjectPtr<RangeNode> n = CopyOnWrite(range.get());
+      n->min = std::move(min);
+      n->extent = std::move(extent);
+      return Range(n);
+    }
+  }
+
+  IterVar MutateIterVar(const IterVar& iter_var) {
+    Range range = MutateRange(iter_var->dom);
+    if (range.same_as(iter_var->dom)) {
+      return iter_var;
+    } else {
+      auto n = CopyOnWrite(iter_var.get());
+      n->dom = std::move(range);
+      return IterVar(n);
+    }
+  }
+
+  Buffer MutateAllocBuffer(const Buffer& alloc_buf) {
+    Buffer buf = MutateBuffer(alloc_buf);
+    if (buf.same_as(alloc_buf)) {
+      return alloc_buf;
+    } else {
+      buffer_map_[alloc_buf] = buf;
+      return buf;
+    }
+  }
+
+  BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) {
+    auto it = buffer_map_.find(buffer_region->buffer);
+    Array<Range> region =
+        MutateArray(buffer_region->region,
+                    std::bind(&PrimFuncSpecializer::MutateRange, this, std::placeholders::_1));
+    if (it == buffer_map_.end() && region.same_as(buffer_region->region)) {
+      return buffer_region;
+    } else {
+      auto n = CopyOnWrite(buffer_region.get());

Review comment:
       avoid CopyOnWrite except for stmt for now, since we are not tracking cow for other things

##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    }
+
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->indices = std::move(indices);
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(Buffer buffer) const {
+    BufferNode* buffer_ptr = buffer.CopyOnWrite();
+    Array<PrimExpr> new_shape, new_stride;
+    new_shape.reserve(buffer_ptr->shape.size());
+    new_shape.reserve(buffer_ptr->strides.size());
+    for (const auto& dim : buffer_ptr->shape) {
+      new_shape.push_back(Substitute(dim, var_map_));
+    }
+    for (const auto& stride : buffer_ptr->strides) {
+      new_shape.push_back(Substitute(stride, var_map_));
+    }
+    buffer_ptr->elem_offset = Substitute(buffer_ptr->elem_offset, var_map_);
+    buffer_ptr->shape = std::move(new_shape);
+    buffer_ptr->strides = std::move(new_stride);
+    return buffer;
+  }
+
+  Range MutateRange(const Range& range) {
+    PrimExpr min = this->VisitExpr(range->min);
+    PrimExpr extent = this->VisitExpr(range->extent);
+    if (min.same_as(range->min) && extent.same_as(range->extent)) {
+      return range;
+    } else {
+      ObjectPtr<RangeNode> n = CopyOnWrite(range.get());

Review comment:
       create a new range instead of using copyonwrite

##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    }
+
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->indices = std::move(indices);
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(Buffer buffer) const {
+    BufferNode* buffer_ptr = buffer.CopyOnWrite();
+    Array<PrimExpr> new_shape, new_stride;
+    new_shape.reserve(buffer_ptr->shape.size());
+    new_shape.reserve(buffer_ptr->strides.size());
+    for (const auto& dim : buffer_ptr->shape) {
+      new_shape.push_back(Substitute(dim, var_map_));
+    }
+    for (const auto& stride : buffer_ptr->strides) {
+      new_shape.push_back(Substitute(stride, var_map_));
+    }
+    buffer_ptr->elem_offset = Substitute(buffer_ptr->elem_offset, var_map_);
+    buffer_ptr->shape = std::move(new_shape);
+    buffer_ptr->strides = std::move(new_stride);
+    return buffer;
+  }
+
+  Range MutateRange(const Range& range) {
+    PrimExpr min = this->VisitExpr(range->min);
+    PrimExpr extent = this->VisitExpr(range->extent);
+    if (min.same_as(range->min) && extent.same_as(range->extent)) {
+      return range;
+    } else {
+      ObjectPtr<RangeNode> n = CopyOnWrite(range.get());
+      n->min = std::move(min);
+      n->extent = std::move(extent);
+      return Range(n);
+    }
+  }
+
+  IterVar MutateIterVar(const IterVar& iter_var) {
+    Range range = MutateRange(iter_var->dom);
+    if (range.same_as(iter_var->dom)) {
+      return iter_var;
+    } else {
+      auto n = CopyOnWrite(iter_var.get());

Review comment:
       create a new IterVar instead of copyonWrite here

##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    }
+
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->indices = std::move(indices);
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(Buffer buffer) const {
+    BufferNode* buffer_ptr = buffer.CopyOnWrite();
+    Array<PrimExpr> new_shape, new_stride;
+    new_shape.reserve(buffer_ptr->shape.size());
+    new_shape.reserve(buffer_ptr->strides.size());
+    for (const auto& dim : buffer_ptr->shape) {
+      new_shape.push_back(Substitute(dim, var_map_));
+    }
+    for (const auto& stride : buffer_ptr->strides) {
+      new_shape.push_back(Substitute(stride, var_map_));
+    }
+    buffer_ptr->elem_offset = Substitute(buffer_ptr->elem_offset, var_map_);
+    buffer_ptr->shape = std::move(new_shape);
+    buffer_ptr->strides = std::move(new_stride);
+    return buffer;
+  }
+
+  Range MutateRange(const Range& range) {
+    PrimExpr min = this->VisitExpr(range->min);
+    PrimExpr extent = this->VisitExpr(range->extent);
+    if (min.same_as(range->min) && extent.same_as(range->extent)) {
+      return range;
+    } else {
+      ObjectPtr<RangeNode> n = CopyOnWrite(range.get());
+      n->min = std::move(min);
+      n->extent = std::move(extent);
+      return Range(n);
+    }
+  }
+
+  IterVar MutateIterVar(const IterVar& iter_var) {
+    Range range = MutateRange(iter_var->dom);
+    if (range.same_as(iter_var->dom)) {
+      return iter_var;
+    } else {
+      auto n = CopyOnWrite(iter_var.get());
+      n->dom = std::move(range);
+      return IterVar(n);
+    }
+  }
+
+  Buffer MutateAllocBuffer(const Buffer& alloc_buf) {
+    Buffer buf = MutateBuffer(alloc_buf);
+    if (buf.same_as(alloc_buf)) {
+      return alloc_buf;
+    } else {
+      buffer_map_[alloc_buf] = buf;

Review comment:
       ICheck that buffer map should not contain alloc_buff

##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    }
+
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->indices = std::move(indices);
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(Buffer buffer) const {
+    BufferNode* buffer_ptr = buffer.CopyOnWrite();
+    Array<PrimExpr> new_shape, new_stride;
+    new_shape.reserve(buffer_ptr->shape.size());
+    new_shape.reserve(buffer_ptr->strides.size());
+    for (const auto& dim : buffer_ptr->shape) {
+      new_shape.push_back(Substitute(dim, var_map_));

Review comment:
       can we reuse MutateArray for this?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#issuecomment-869648242


   cc @vinx13 @junrushao1994 please help to take a look


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] MasterJH5574 commented on a change in pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#discussion_r680087655



##########
File path: tests/python/unittest/test_tir_specialize.py
##########
@@ -0,0 +1,199 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring, missing-module-docstring
+
+import tvm
+from tvm import tir
+from tvm.script import ty
+
+
+@tvm.script.tir
+def matmul(a: ty.handle, b: ty.handle, c: ty.handle, n: ty.int32) -> None:
+    m = tir.var("int32")
+    A = tir.match_buffer(a, [m, n])
+    B = tir.match_buffer(b, [m, n])
+    C = tir.match_buffer(c, [m, m])
+
+    with tir.block([m, m, tir.reduce_axis(0, n)], "update") as [vi, vj, vk]:
+        with tir.init():
+            C[vi, vj] = 0.0
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def matmul_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, [128, 128])
+    B = tir.match_buffer(b, [128, 128])
+    C = tir.match_buffer(c, [128, 128])
+
+    with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+        with tir.init():
+            C[vi, vj] = 0.0
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def matmul_m_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+    m = tir.var("int32")
+    A = tir.match_buffer(a, [m, 128])
+    B = tir.match_buffer(b, [m, 128])
+    C = tir.match_buffer(c, [m, m])
+
+    with tir.block([m, m, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+        with tir.init():
+            C[vi, vj] = 0.0
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def matmul_m_8x(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+    x = tir.var("int32")
+    m = tir.var("int32")
+    A = tir.match_buffer(a, [m, x * 8])
+    B = tir.match_buffer(b, [m, x * 8])
+    C = tir.match_buffer(c, [m, m])
+
+    with tir.block([m, m, tir.reduce_axis(0, x * 8)], "update") as [vi, vj, vk]:
+        with tir.init():
+            C[vi, vj] = 0.0
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def element_wise(a: ty.handle, c: ty.handle) -> None:
+    m = tir.var("int32")
+    n = tir.var("int32")
+    A = tir.match_buffer(a, (m, n), "float32")
+    C = tir.match_buffer(c, (m, n), "float32")
+
+    B = tir.alloc_buffer((m, n), "float32")
+
+    with tir.block([m, n], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+
+    with tir.block([m, n], "C") as [vi, vj]:
+        C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def element_wise_128_64(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (128, 64), "float32")
+    C = tir.match_buffer(c, (128, 64), "float32")
+    B = tir.alloc_buffer((128, 64), "float32")
+
+    with tir.block([128, 64], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+
+    with tir.block([128, 64], "C") as [vi, vj]:
+        C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def element_wise_128_n(a: ty.handle, c: ty.handle) -> None:
+    n = tir.var("int32")
+    A = tir.match_buffer(a, (128, n), "float32")
+    C = tir.match_buffer(c, (128, n), "float32")
+    B = tir.alloc_buffer((128, n), "float32")
+
+    with tir.block([128, n], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+
+    with tir.block([128, n], "C") as [vi, vj]:
+        C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.tir
+def mem_copy(
+    a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p: ty.int32, q: ty.int32
+) -> None:
+    A = tir.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=q)
+    B = tir.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=q)
+
+    with tir.block([m, n], "") as [vi, vj]:
+        B[vi, vj] = A[vi, vj]
+
+
+@tvm.script.tir
+def mem_copy_16_16_8_4(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32", strides=[8, 1], elem_offset=4)
+    B = tir.match_buffer(b, (16, 16), "float32", strides=[8, 1], elem_offset=4)
+
+    with tir.block([16, 16], "") as [vi, vj]:
+        B[vi, vj] = A[vi, vj]
+
+
+@tvm.script.tir
+def mem_copy_m_n_p_n(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32, p: ty.int32) -> None:
+    A = tir.match_buffer(a, (m, n), "float32", strides=[p, 1], elem_offset=n)
+    B = tir.match_buffer(b, (m, n), "float32", strides=[p, 1], elem_offset=n)
+
+    with tir.block([m, n], "") as [vi, vj]:
+        B[vi, vj] = A[vi, vj]
+
+
+def test_specialize_nothing():
+    func = matmul.specialize({})
+    assert func.same_as(matmul)  # Pointer the same
+
+
+def test_specialize_matmul():
+    a, _, _, n = matmul.params
+    # fully specialized
+    func = matmul.specialize({a: tir.decl_buffer((128, 128))})
+    tvm.ir.assert_structural_equal(func, matmul_128)
+    # partially specialized
+    func = matmul.specialize({n: 128})
+    tvm.ir.assert_structural_equal(func, matmul_m_128)
+    # symbolic specialized
+    func = matmul.specialize({n: tir.Var("x", "int32") * 8})
+    tvm.ir.assert_structural_equal(func, matmul_m_8x)
+
+
+def test_specialize_elemwise():
+    a, c = element_wise.params
+    C = element_wise.buffer_map[c]
+    # fully specialized
+    func = element_wise.specialize({a: tir.decl_buffer((128, 64))})
+    tvm.ir.assert_structural_equal(func, element_wise_128_64)
+    # partially specialized
+    func = element_wise.specialize({c: tir.decl_buffer((128, C.shape[1]))})
+    tvm.ir.assert_structural_equal(func, element_wise_128_n)
+
+
+def test_specialize_mem_copy():
+    a, _, m, n, p, q = mem_copy.params
+    # fully specialized
+    func = mem_copy.specialize({a: tir.decl_buffer((16, 16), strides=[8, 1], elem_offset=4)})
+    tvm.ir.assert_structural_equal(func, mem_copy_16_16_8_4)
+    func = mem_copy.specialize({n: 16, m: 16, p: 8, q: 4})
+    tvm.ir.assert_structural_equal(func, mem_copy_16_16_8_4)
+    # partially specialized
+    func = mem_copy.specialize({q: n})
+    tvm.ir.assert_structural_equal(func, mem_copy_m_n_p_n)
+
+
+def test_specialize_recursive_load():
+    # TODO(Siyuan): add recursive Load testcase, e.g. A[C[i]]
+    pass

Review comment:
       I saw this TODO item today. Is it intended to leave the TODO here? Or it was because we forgot to write the unittest?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on a change in pull request #8354: [TIR][TVMScript] specialize

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #8354:
URL: https://github.com/apache/tvm/pull/8354#discussion_r659738159



##########
File path: src/tir/ir/specialize.cc
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/ir/specialize.cc
+ * \brief Specialize parameters of PrimFunc.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <functional>
+
+#include "functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
+
+/**************** Helper functions ****************/
+
+/*! \brief Helper function to check whether the given var is in function parameter list. */
+inline bool IsParam(const PrimFunc& func, const Var& param) {
+  return std::any_of(func->params.begin(), func->params.end(),
+                     [&](const Var& var) { return var.same_as(param); });
+}
+
+/*! \brief Mutator to specialize function and remove const parameters */
+class PrimFuncSpecializer : public StmtExprMutator {
+ public:
+  explicit PrimFuncSpecializer(VarMap var_map) : var_map_(var_map) {}
+
+  static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
+    PrimFuncSpecializer specializer(var_map);
+    // Updating Buffer map
+    Map<Var, Buffer> buffer_map;
+    for (const auto& it : f->buffer_map) {
+      const Var& var = it.first;
+      const Buffer& buffer = it.second;
+      Buffer new_buffer = specializer.MutateBuffer(buffer);
+      buffer_map.Set(var, new_buffer);
+      if (!new_buffer.same_as(buffer)) {
+        specializer.buffer_map_[buffer] = new_buffer;
+      }
+    }
+
+    // Updating parmeters
+    Array<Var> params;
+    for (const auto& var : f->params) {
+      // Remove parmeters which has been specialized.
+      if (var_map.find(var) == var_map.end()) {
+        params.push_back(var);
+      }
+    }
+
+    PrimFuncNode* f_ptr = f.CopyOnWrite();
+    f_ptr->params = std::move(params);
+    f_ptr->buffer_map = std::move(buffer_map);
+    f_ptr->body = specializer(std::move(f_ptr->body));
+
+    // Updating attrs
+    if (f->attrs.defined()) {
+      auto& attr_dict = f_ptr->attrs.CopyOnWrite()->dict;
+      for (const auto& kv : attr_dict) {
+        const String& key = kv.first;
+        const ObjectRef& value = kv.second;
+        if (value->IsInstance<PrimExprNode>()) {
+          attr_dict.Set(key, Substitute(Downcast<PrimExpr>(value), var_map));
+        }
+      }
+    }
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Array<Buffer> alloc_buffers = MutateArray(
+        op->alloc_buffers,
+        std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
+    Array<BufferRegion> reads = MutateArray(
+        op->reads,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<BufferRegion> writes = MutateArray(
+        op->writes,
+        std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
+    Array<IterVar> block_vars = MutateArray(
+        op->iter_vars, std::bind(&PrimFuncSpecializer::MutateIterVar, this, std::placeholders::_1));
+    Optional<Stmt> init = NullOpt;
+    if (op->init.defined()) {
+      init = VisitStmt(op->init.value());
+    }
+    Stmt body = VisitStmt(op->body);
+
+    if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
+        writes.same_as(op->writes) && block_vars.same_as(op->iter_vars) && body.same_as(op->body) &&
+        init.same_as(op->init)) {
+      return GetRef<Block>(op);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(op);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->iter_vars = std::move(block_vars);
+      n->body = std::move(body);
+      n->init = std::move(init);
+      return Stmt(n);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferStore>(op);
+    }
+
+    PrimExpr value = VisitExpr(op->value);
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto it = buffer_map_.find(op->buffer);
+    if (it == buffer_map_.end()) {
+      return GetRef<BufferLoad>(op);
+    }
+
+    Array<PrimExpr> indices =
+        MutateArray(op->indices, [this](const PrimExpr& e) { return this->VisitExpr(e); });
+
+    auto n = CopyOnWrite(op);
+    n->buffer = it->second;
+    n->indices = std::move(indices);
+    return PrimExpr(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = var_map_.find(GetRef<Var>(op));
+    if (it == var_map_.end()) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  Buffer MutateBuffer(Buffer buffer) const {
+    BufferNode* buffer_ptr = buffer.CopyOnWrite();
+    Array<PrimExpr> new_shape, new_stride;
+    new_shape.reserve(buffer_ptr->shape.size());
+    new_shape.reserve(buffer_ptr->strides.size());

Review comment:
       new_strides




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org