Posted to commits@tvm.apache.org by wu...@apache.org on 2022/08/26 17:15:59 UTC

[tvm] branch main updated: [TIR] More hygienic TVM_SREF macros (#12607)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 49b3c72935 [TIR] More hygienic TVM_SREF macros (#12607)
49b3c72935 is described below

commit 49b3c72935b290afa9eee1f1c57a4b4c2f10a445
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Fri Aug 26 10:15:54 2022 -0700

    [TIR] More hygienic TVM_SREF macros (#12607)
    
    Previously, the `TVM_SREF_TO_BLOCK`, `TVM_SREF_TO_FOR`, and
    `TVM_TYPE_AS` macros required both the input and the output variable
    names.  The input variable name is useful for improving the error
    message produced, but the output variable name isn't necessary for
    this functionality, and it prevented the macro from being used as
    part of an expression (see the sketch after the list below).
    
    * Generate an immediately-invoked lambda expression to allow for an
      independently-scoped `result` variable.
    
    * Use parentheses around the input argument, in case the sref is
      the result of an expression.
    
    * Update all call sites to remove the first macro argument, which
      previously named the output variable.
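
    For illustration, a minimal before/after sketch (simplified from the
    diff below; `block_sref` stands for any StmtSRef pointing to a
    Block):

        // Before: the first argument named the output variable, so the
        // macro could only appear as an initializer.
        const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);

        // After: only the input is passed; the macro expands to an
        // immediately-invoked lambda that returns the downcast pointer.
        const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);

        // The result can therefore sit directly inside an expression
        // (hypothetical call site):
        Block block_ref = GetRef<Block>(TVM_SREF_TO_BLOCK(block_sref));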
---
 src/meta_schedule/mutator/mutate_parallel.cc       |  4 +-
 src/meta_schedule/mutator/mutate_thread_binding.cc |  8 ++--
 src/meta_schedule/mutator/mutate_tile_size.cc      |  4 +-
 src/meta_schedule/mutator/mutate_unroll.cc         |  4 +-
 .../postproc/rewrite_parallel_vectorize_unroll.cc  |  4 +-
 src/meta_schedule/schedule_rule/auto_bind.cc       |  2 +-
 src/meta_schedule/schedule_rule/auto_inline.cc     |  2 +-
 .../schedule_rule/multi_level_tiling.cc            |  2 +-
 .../multi_level_tiling_tensor_core.cc              |  4 +-
 .../schedule_rule/random_compute_location.cc       |  2 +-
 src/meta_schedule/utils.h                          |  2 +-
 src/tir/schedule/analysis/analysis.cc              | 48 ++++++++++----------
 src/tir/schedule/block_scope.cc                    |  2 +-
 src/tir/schedule/concrete_schedule.cc              |  4 +-
 src/tir/schedule/concrete_schedule.h               |  6 +--
 src/tir/schedule/primitive/block_annotate.cc       |  6 +--
 src/tir/schedule/primitive/blockize_tensorize.cc   |  2 +-
 src/tir/schedule/primitive/cache_read_write.cc     | 14 +++---
 src/tir/schedule/primitive/compute_at.cc           | 12 ++---
 src/tir/schedule/primitive/compute_inline.cc       |  8 ++--
 src/tir/schedule/primitive/decompose_padding.cc    |  2 +-
 src/tir/schedule/primitive/for_kind.cc             |  4 +-
 src/tir/schedule/primitive/get_block_loop.cc       |  2 +-
 .../schedule/primitive/layout_transformation.cc    | 10 ++---
 src/tir/schedule/primitive/loop_transformation.cc  | 10 ++---
 src/tir/schedule/primitive/reduction.cc            | 12 ++---
 src/tir/schedule/primitive/sampling.cc             |  2 +-
 src/tir/schedule/state.cc                          | 14 +++---
 src/tir/schedule/transform.cc                      |  6 +--
 src/tir/schedule/utils.h                           | 51 ++++++++++++++--------
 30 files changed, 133 insertions(+), 120 deletions(-)

diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc
index 5b7fe7f514..82b91da682 100644
--- a/src/meta_schedule/mutator/mutate_parallel.cc
+++ b/src/meta_schedule/mutator/mutate_parallel.cc
@@ -64,7 +64,7 @@ const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) {
     return nullptr;
   }
   ICHECK_EQ(inst->outputs.size(), 1);
-  const BlockRVNode* block = TVM_TYPE_AS(block, inst->outputs[0], BlockRVNode);
+  const BlockRVNode* block = TVM_TYPE_AS(inst->outputs[0], BlockRVNode);
   return block;
 }
 
@@ -82,7 +82,7 @@ std::vector<std::vector<int64_t>> AnalyzeParallel(const ScheduleState& self,
   Array<StmtSRef> block_srefs =
       tir::GetBlocks(self, block_name, self->mod->GetGlobalVar(func_name));
   ICHECK_EQ(block_srefs.size(), 1);
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_srefs[0]);
   ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef<Block>(block));
   std::vector<std::vector<int64_t>> results;
   results.reserve(info.realizes.size());
diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc
index 41207162ee..de780b53e2 100644
--- a/src/meta_schedule/mutator/mutate_thread_binding.cc
+++ b/src/meta_schedule/mutator/mutate_thread_binding.cc
@@ -109,12 +109,12 @@ std::vector<MutateThreadBindingNode::Candidate> MutateThreadBindingNode::FindCan
   for (const Instruction& inst : trace->insts) {
     if (inst->kind.same_as(inst_sample_categorical)) {
       ICHECK_EQ(inst->outputs.size(), 1);
-      const PrimExprNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[0], PrimExprNode);
+      const PrimExprNode* var_rv = TVM_TYPE_AS(inst->outputs[0], PrimExprNode);
       sample_insts[var_rv] = inst.get();
     } else if (is_split_by_sample(inst)) {
       CHECK_EQ(inst->outputs.size(), 2);
       // Only consider the inner loop, which can be bound to threadIdx.x
-      const tir::LoopRVNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[1], tir::LoopRVNode);
+      const tir::LoopRVNode* var_rv = TVM_TYPE_AS(inst->outputs[1], tir::LoopRVNode);
       sampled_split_insts[var_rv] = inst.get();
     } else if (is_thread_binding_by_sample(inst)) {
       bind_insts.push_back(inst.get());
@@ -122,12 +122,12 @@ std::vector<MutateThreadBindingNode::Candidate> MutateThreadBindingNode::FindCan
   }
 
   for (const InstructionNode* bind_inst : bind_insts) {
-    const auto* loop_rv = TVM_TYPE_AS(loop_rv, bind_inst->inputs[0], tir::LoopRVNode);
+    const auto* loop_rv = TVM_TYPE_AS(bind_inst->inputs[0], tir::LoopRVNode);
     auto split_it = sampled_split_insts.find(loop_rv);
     ICHECK(split_it != sampled_split_insts.end());
     const InstructionNode* split_inst = split_it->second;
 
-    const auto* expr_rv = TVM_TYPE_AS(expr_rv, split_inst->inputs[2], PrimExprNode);
+    const auto* expr_rv = TVM_TYPE_AS(split_inst->inputs[2], PrimExprNode);
     auto sample_it = sample_insts.find(expr_rv);
     ICHECK(sample_it != sample_insts.end());
     const InstructionNode* sample_inst = sample_it->second;
diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc
index 00967aef7a..4a3bfda8a4 100644
--- a/src/meta_schedule/mutator/mutate_tile_size.cc
+++ b/src/meta_schedule/mutator/mutate_tile_size.cc
@@ -34,7 +34,7 @@ using tir::Trace;
  * \return The result of the downcast
  */
 std::vector<int64_t> DowncastTilingDecision(const ObjectRef& decision) {
-  const auto* arr = TVM_TYPE_AS(arr, decision, runtime::ArrayNode);
+  const auto* arr = TVM_TYPE_AS(decision, runtime::ArrayNode);
   return support::AsVector<ObjectRef, int64_t>(GetRef<Array<ObjectRef>>(arr));
 }
 
@@ -123,7 +123,7 @@ void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst,
     if (inst->kind.same_as(inst_sample_categorical)) {
       ICHECK_EQ(inst->outputs.size(), 1);
       if (annotated.count(inst->outputs[0].get())) {
-        const auto* d = TVM_TYPE_AS(d, decision, IntImmNode);
+        const auto* d = TVM_TYPE_AS(decision, IntImmNode);
         instructions.push_back(inst);
         decisions.push_back(d->value);
       }
diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc
index 94e8348858..c282a171c3 100644
--- a/src/meta_schedule/mutator/mutate_unroll.cc
+++ b/src/meta_schedule/mutator/mutate_unroll.cc
@@ -91,7 +91,7 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state,
   for (const Instruction& inst : trace->insts) {
     if (inst->kind.same_as(inst_sample_categorical)) {
       ICHECK_EQ(inst->outputs.size(), 1);
-      const PrimExprNode* var_rv = TVM_TYPE_AS(var_rv, inst->outputs[0], PrimExprNode);
+      const PrimExprNode* var_rv = TVM_TYPE_AS(inst->outputs[0], PrimExprNode);
       sample_insts[var_rv] = inst.get();
     } else if (IsAnnotateWithUnroll(inst)) {
       ann_insts.push_back(inst.get());
@@ -103,7 +103,7 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state,
   }
   const InstructionNode* ann_inst = ann_insts[tir::SampleInt(rand_state, 0, n_ann_insts)];
   ICHECK_EQ(ann_inst->inputs.size(), 2);
-  const auto* var_rv = TVM_TYPE_AS(var_rv, ann_inst->inputs[1], PrimExprNode);
+  const auto* var_rv = TVM_TYPE_AS(ann_inst->inputs[1], PrimExprNode);
   ICHECK(sample_insts.count(var_rv));
   const InstructionNode* sample_inst = sample_insts.at(var_rv);
   ICHECK_EQ(sample_inst->attrs.size(), 2);
diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
index f3c2b1328b..08d25d0178 100644
--- a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
+++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -233,7 +233,7 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv,
     int64_t prod_extent = 1;
     for (int i = 0; i < n_loops && loop_types[i] == IterVarType::kDataPar; ++i) {
       const StmtSRef& loop_sref = loop_srefs[i];
-      const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+      const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
       if (HasAnnOrBinding(loop)) {
         break;
       }
@@ -262,7 +262,7 @@ void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv,
     for (int i = n_loops - 1;
          i >= 0 && loop_types[i] == IterVarType::kDataPar && num_fusible < max_fusible; --i) {
       const StmtSRef& loop_sref = loop_srefs[i];
-      const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+      const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
       if (HasAnnOrBinding(loop)) {
         break;
       }
diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc
index ff4d26084e..d8f52fa8e1 100644
--- a/src/meta_schedule/schedule_rule/auto_bind.cc
+++ b/src/meta_schedule/schedule_rule/auto_bind.cc
@@ -45,7 +45,7 @@ void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block_rv,
   int i_spatial_loop = -1;
   for (int i = 0; i < n; ++i) {
     const StmtSRef& loop_sref = loops[i];
-    const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+    const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
     runtime::ThreadScope thread_scope = GetThreadScope(loop);
     if (IsBlockIdx(thread_scope)) {
       if (i_block_idx == -1) {
diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc
index df4d3ac859..76313f46d1 100644
--- a/src/meta_schedule/schedule_rule/auto_inline.cc
+++ b/src/meta_schedule/schedule_rule/auto_inline.cc
@@ -96,7 +96,7 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
   StmtSRef block_sref = sch->GetSRef(block_rv);
   bool is_pure_sptial = IsInSpatialPrimFunc(sch, block_sref);
   ScheduleState state = sch->state();
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   BlockRealize realize = GetBlockRealize(state, block_sref);
   // Cond 1. The block has only one write buffer
   if (block->writes.size() != 1) {
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index eefc2eea41..c126c85446 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -37,7 +37,7 @@ namespace tir {
  * of multi-level tiling, so it's intentionally kept inside this file, not in the analysis header
  */
 std::vector<int> GetReadBufferNDims(const StmtSRef& block_sref) {
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   const BufferNode* write_buffer = block->writes[0]->buffer.get();
   int n = block->reads.size();
   std::vector<int> results(n, -1);
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index 49704fb66b..7ddda9b263 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -411,7 +411,7 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
   tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv);
 
   // Add reindex stages
-  const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   // Hold the reference of the block before reindex
   const tir::Block block_before_reindex = GetRef<tir::Block>(block);
   if (block->reads.size() != 2 || block->writes.size() != 1) {
@@ -488,7 +488,7 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
     }
     visited_buffers.insert(lhs_buffer);
     // Refresh block pointer (block sref is not invalidated)
-    block = TVM_SREF_TO_BLOCK(block, block_sref);
+    block = TVM_SREF_TO_BLOCK(block_sref);
     const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion(
         state->sch->state(), GetRef<tir::Block>(block), buffer_index, index_type);
     auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region);
diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc
index e4b5d5bde2..65988dfd56 100644
--- a/src/meta_schedule/schedule_rule/random_compute_location.cc
+++ b/src/meta_schedule/schedule_rule/random_compute_location.cc
@@ -60,7 +60,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode {
  private:
   bool CheckConditions(const tir::Schedule sch, const tir::BlockRV& block_rv) const {
     tir::StmtSRef block_sref = sch->GetSRef(block_rv);
-    const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+    TVM_SREF_TO_BLOCK(block_sref);
 
     // Cond 1. The block is not the root block.
     if (block_sref->parent == nullptr) {
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index cb84596eed..664a6a609e 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -238,7 +238,7 @@ inline std::string Concat(const Array<String>& strs, const std::string& delim) {
  */
 inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref,
                                   const String& global_var_name) {
-  const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   return sch->GetBlock(block->name_hint, global_var_name);
 }
 
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 62ec0b468f..b9e99257f3 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -150,7 +150,7 @@ Definition of a scope that is a stage pipeline:
   if (require_stage_pipeline) {
     bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline;
     if (stage_pipeline == false) {
-      const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root_sref);
+      const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref);
       throw NotStagePipelineError(self->mod, GetRef<Block>(block));
     }
   }
@@ -229,7 +229,7 @@ bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref,
     }
   }
   // Check whether the input block is the only writer of its outputs
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   for (const BufferRegion& write_region : block->writes) {
     if (buffer_writers.count(write_region->buffer)) {
       if (buffer_writers.at(write_region->buffer).size() != 1) {
@@ -252,7 +252,7 @@ bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref,
 int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref,
                                 const StmtSRef& scope_root_sref) {
   // Cond 1. All block vars are data parallel
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   for (const IterVar& iter_var : block->iter_vars) {
     if (iter_var->iter_type != kDataPar) {
       return 1;
@@ -328,7 +328,7 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
 
   int error_code = CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref);
   if (error_code != 0) {
-    const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+    const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
     throw IncompleteBlockError(self->mod, GetRef<Block>(block), error_code);
   }
 }
@@ -344,7 +344,7 @@ void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref,
  */
 int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref,
                                  const StmtSRef& scope_root_sref) {
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   // Cond 1. The block has the `init` statement.
   if (!block->init.defined()) {
     return 1;
@@ -394,7 +394,7 @@ void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
 
   int error_code = CheckReductionBlockErrorCode(self, block_sref, scope_root_sref);
   if (error_code != 0) {
-    const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+    const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
     throw NotReductionBlockError(self->mod, GetRef<Block>(block), error_code);
   }
 }
@@ -441,7 +441,7 @@ void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& bl
   if (reduction_block_error_code == 0) {
     return;
   }
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   throw NotCompleteOrReductionBlockError(self->mod, GetRef<Block>(block), complete_block_error_code,
                                          reduction_block_error_code);
 }
@@ -491,7 +491,7 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt
     int local_complete_block_code = CheckCompleteBlockErrorCode(self, block_sref, subtree_root),
         local_reduction_block_code = CheckReductionBlockErrorCode(self, block_sref, subtree_root);
     if (local_complete_block_code != 0 && local_reduction_block_code != 0) {
-      const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+      const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
       throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(subtree_root->stmt),
                                     GetRef<Block>(block), local_complete_block_code,
                                     local_reduction_block_code);
@@ -501,8 +501,8 @@ void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subt
 
 bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref,
                    const StmtSRef& scope_root_sref) {
-  const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root, scope_root_sref);
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   std::unordered_set<const BufferNode*> scope_allocated;
   scope_allocated.reserve(scope_root->alloc_buffers.size());
   for (const Buffer& buffer : scope_root->alloc_buffers) {
@@ -532,7 +532,7 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref,
     Block block_;
   };
   if (IsOutputBlock(self, block_sref, scope_root_sref)) {
-    const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+    const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
     throw OutputBlockError(self->mod, GetRef<Block>(block));
   }
 }
@@ -547,12 +547,12 @@ std::vector<IterVarType> GetBlockVarTypes(const BlockNode* block) {
 }
 
 std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   return GetBlockVarTypes(block);
 }
 
 bool IsWriteCache(const StmtSRef& block_sref) {
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   if (block->writes.size() != 1) {
     return false;
   }
@@ -751,7 +751,7 @@ void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sre
     IRModule mod_;
     For loop_;
   };
-  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
   if (!analyzer->CanProve(loop->min == 0)) {
     throw LoopNotStartWithZeroError(self->mod, GetRef<For>(loop));
   }
@@ -856,7 +856,7 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
     const BlockRealizeNode* result;
   };
 
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   if (block_sref->parent == nullptr) {
     const PrimFuncNode* func = GetRootPrimFunc(self->mod, block, nullptr);
     return Downcast<BlockRealize>(func->body);
@@ -870,7 +870,7 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
 }
 
 IterVarType GetLoopIterType(const StmtSRef& loop_sref) {
-  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
   const Var& loop_var = loop->loop_var;
   int n_spatial = 0;
   int n_reduce = 0;
@@ -1924,7 +1924,7 @@ void CheckStorageScope(const ScheduleState& self, String storage_scope) {
 }
 
 bool IsSpatial(const StmtSRef& block_sref) {
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   for (const IterVar& iter_var : block->iter_vars) {
     if (iter_var->iter_type != IterVarType::kDataPar) {
       return false;
@@ -1934,14 +1934,14 @@ bool IsSpatial(const StmtSRef& block_sref) {
 }
 
 bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) {
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  TVM_SREF_TO_BLOCK(block_sref);
   Array<StmtSRef> loops = GetLoops(block_sref);
   Array<PrimExpr> binds = GetBlockRealize(self, block_sref)->iter_values;
   if (loops.size() != binds.size()) {
     return false;
   }
   for (int i = 0, n = loops.size(); i < n; ++i) {
-    const ForNode* loop = TVM_SREF_TO_FOR(loop, loops[i]);
+    const ForNode* loop = TVM_SREF_TO_FOR(loops[i]);
     if (binds[i].get() != loop->loop_var.get()) {
       return false;
     }
@@ -1953,7 +1953,7 @@ bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref
   if (HasBeenMultiLevelTiled(block_sref)) {
     return false;
   }
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   if (block->writes.size() != 1 || block->reads.empty() || IsSpatial(block_sref) ||
       !IsTrivialBinding(self, block_sref)) {
     return false;
@@ -2065,7 +2065,7 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,   //
                                         const tir::StmtSRef& block_sref,  //
                                         int64_t max_parallel_extent,      //
                                         int64_t max_parallel_basic) {
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);
 
   // Cond 1. The block has only one write buffer
@@ -2100,9 +2100,9 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,   //
     }
 
     // Cond 5.
-    const ForNode* loop_i = TVM_SREF_TO_FOR(loop_i, loops[i]);
+    const ForNode* loop_i = TVM_SREF_TO_FOR(loops[i]);
     if (i < loops.size() - 1) {
-      const ForNode* loop_i1 = TVM_SREF_TO_FOR(loop_i1, loops[i + 1]);
+      const ForNode* loop_i1 = TVM_SREF_TO_FOR(loops[i + 1]);
       if (loop_i->body.get() != loop_i1) {
         return false;
       }
@@ -2194,7 +2194,7 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
   TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer, desc_func);
   // Step 2. Collect loops from block_sref
   const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
-  const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
+  TVM_SREF_TO_BLOCK(scope_sref);
   std::vector<const tir::ForNode*> block_loops;
   std::unordered_set<const tir::VarNode*> block_loop_vars;
   {
diff --git a/src/tir/schedule/block_scope.cc b/src/tir/schedule/block_scope.cc
index f1ce65e48e..31452f4a8f 100644
--- a/src/tir/schedule/block_scope.cc
+++ b/src/tir/schedule/block_scope.cc
@@ -76,7 +76,7 @@ BlockScope::BlockScope(const Array<StmtSRef>& child_block_srefs) {
   SMap<Buffer, Array<StmtSRef>> buffer_readers;
   SMap<Buffer, Array<StmtSRef>>& buffer_writers = n->buffer_writers;
   for (const StmtSRef& child_block_sref : child_block_srefs) {
-    const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block, child_block_sref);
+    const BlockNode* child_block = TVM_SREF_TO_BLOCK(child_block_sref);
     // Step 1. Update `buffer_readers` and `buffer_writers` for each buffer
     for (const BufferRegion& region : child_block->reads) {
       buffer_readers[region->buffer].push_back(child_block_sref);
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 5f773a02d6..afc6757997 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -269,7 +269,7 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const Optional<String
         : name_(name), mod_(mod), blocks_{} {
       blocks_.reserve(blocks.size());
       for (const StmtSRef& block_sref : blocks) {
-        const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+        const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
         blocks_.push_back(GetRef<Block>(block));
       }
     }
@@ -432,7 +432,7 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
 
   // Prepare for the splitting
   StmtSRef loop_sref = this->GetSRef(loop_rv);
-  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
   Array<PrimExpr> factors;
   factors.reserve(factor_rvs.size());
   int infer_index = -1;
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 92b9de4088..e79d1d5288 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -206,13 +206,13 @@ class ConcreteScheduleNode : public ScheduleNode {
 
 inline Block ConcreteScheduleNode::Get(const BlockRV& block_rv) const {
   StmtSRef sref = this->GetSRef(block_rv);
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(sref);
   return GetRef<Block>(block);
 }
 
 inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const {
   StmtSRef sref = this->GetSRef(loop_rv);
-  const ForNode* loop = TVM_SREF_TO_FOR(loop, sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(sref);
   return GetRef<For>(loop);
 }
 
@@ -223,7 +223,7 @@ inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const {
       LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << var;
     }
     const ObjectRef& obj = (*it).second;
-    const auto* int_imm = TVM_TYPE_AS(int_imm, obj, IntImmNode);
+    const auto* int_imm = TVM_TYPE_AS(obj, IntImmNode);
     return Integer(int_imm->value);
   });
   return this->analyzer_->Simplify(transformed);
diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc
index 2d876d9bf7..31c938313f 100644
--- a/src/tir/schedule/primitive/block_annotate.cc
+++ b/src/tir/schedule/primitive/block_annotate.cc
@@ -238,7 +238,7 @@ class StorageScopeMutator : private ReplaceBufferMutator {
 
 void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis,
                   int factor, int offset) {
-  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
+  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
   Buffer buffer =
       GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, BufferIndexType::kWrite);
   StorageAlignInvalidFactorError::Check(self->mod, factor);
@@ -274,7 +274,7 @@ void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_ind
 
 void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
               const String& storage_scope) {
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   Buffer buffer =
       GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, BufferIndexType::kWrite);
 
@@ -289,7 +289,7 @@ void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
   // Step 3. Get the allocation site of the target buffer.
   StmtSRef alloc_site_sref =
       NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer);
-  const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site, alloc_site_sref);
+  const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site_sref);
 
   // Step 4. Recursively replace the old buffer to a new buffer, where the new buffer has the given
   // storage scope. In the meanwhile, collect the block sref reuse information.
diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc
index cf6532e82d..7481a7c924 100644
--- a/src/tir/schedule/primitive/blockize_tensorize.cc
+++ b/src/tir/schedule/primitive/blockize_tensorize.cc
@@ -426,7 +426,7 @@ Stmt MakeLoopNest(Stmt stmt, const std::vector<const ForNode*>& loops) {
 
 BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref,
                           Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer) {
-  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  TVM_SREF_TO_FOR(loop_sref);
   // Step 1: Check and get the only block under `loop`.
   BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref);
   Block block = block_realize->block;
diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index 529d3333cd..a221733eb3 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -31,7 +31,7 @@ class NotSingleWriteBlock : public ScheduleError {
     ICHECK_GT(write_blocks.size(), 1);
     write_blocks_.reserve(write_blocks.size());
     for (const StmtSRef& block_sref : write_blocks) {
-      const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+      const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
       write_blocks_.push_back(GetRef<Block>(block));
     }
   }
@@ -532,7 +532,7 @@ class CacheReadRewriter : public StmtExprMutator {
     bool is_consumer = info_->consumer_blocks.empty();
     // Otherwise check if this is one of the specified blocks.
     for (StmtSRef consumer_sref : info_->consumer_blocks) {
-      const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_node, consumer_sref);
+      const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref);
       Block consumer_block = GetRef<Block>(consumer_node);
       if (old_stmt.same_as(consumer_block)) {
         is_consumer = true;
@@ -999,11 +999,11 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff
   CheckStorageScope(self, storage_scope);
 
   // Step 1. Check index, getting the target buffer and the parent scope
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   Buffer read_buffer =
       GetNthAccessBuffer(self, GetRef<Block>(block), read_buffer_index, BufferIndexType::kRead);
   StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
-  const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
+  const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);
 
   // Step 2. Create CacheStageInfo
   CacheStageInfo info;
@@ -1020,7 +1020,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff
   if (Optional<StmtSRef> _write_block_sref = GetOnlyWriteBlock(self, scope_sref, read_buffer)) {
     // Case 1. The buffer is written inside the block.
     StmtSRef write_block_sref = _write_block_sref.value();
-    const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block, write_block_sref);
+    const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref);
     // Find the producing region
     BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, read_buffer).value();
     StmtSRef parent_sref = GetRef<StmtSRef>(write_block_sref->parent);
@@ -1072,7 +1072,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
   CheckStorageScope(self, storage_scope);
 
   // Step 1. Checking index, getting the target buffer and the parent scope
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   Buffer write_buffer =
       GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, BufferIndexType::kWrite);
   StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
@@ -1114,7 +1114,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
 
 StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
                  BufferIndexType buffer_index_type) {
-  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
+  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
   Block block = GetRef<Block>(block_ptr);
   Buffer buffer = GetNthAccessBuffer(self, block, buffer_index, buffer_index_type);
   StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true);
diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc
index 8baedfd70d..83342e351b 100644
--- a/src/tir/schedule/primitive/compute_at.cc
+++ b/src/tir/schedule/primitive/compute_at.cc
@@ -37,7 +37,7 @@ class NotAllRequiredBlocksAreVisitedError : public ScheduleError {
       : mod_(mod), num_not_visited_(num_not_visited) {
     required_.reserve(required.size());
     for (const StmtSRef& block_sref : required) {
-      const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+      const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
       required_.push_back(GetRef<Block>(block));
     }
   }
@@ -306,14 +306,14 @@ class ScopeReconstructor : private StmtMutator {
       return GetRef<Block>(block);
     }
     if (block == rm_src_stmt_.get()) {
-      block = TVM_TYPE_AS(block, rm_tgt_stmt_, BlockNode);
+      block = TVM_TYPE_AS(rm_tgt_stmt_, BlockNode);
     }
     return StmtMutator::VisitStmt_(block);
   }
 
   Stmt VisitStmt_(const ForNode* loop) final {
     if (loop == rm_src_stmt_.get()) {
-      loop = TVM_TYPE_AS(loop, rm_tgt_stmt_, ForNode);
+      loop = TVM_TYPE_AS(rm_tgt_stmt_, ForNode);
     }
     if (loop == loop_.get()) {
       return new_loop_;
@@ -559,7 +559,7 @@ void CalculateProvidedRequiredRegions(
   }
   // Step 2. Calculate the region required by dependent blocks under `loop`
   for (const StmtSRef& required_block_sref : is_compute_at ? consumer_srefs : producer_srefs) {
-    const BlockNode* required_block = TVM_SREF_TO_BLOCK(required_block, required_block_sref);
+    const BlockNode* required_block = TVM_SREF_TO_BLOCK(required_block_sref);
     ICHECK(block2realize.count(required_block));
     RelaxBufferRegions</*relax_storage_scope=*/is_compute_at>(
         /*binding=*/GetBindings(GetRef<BlockRealize>(block2realize.at(required_block))),
@@ -576,8 +576,8 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
                                      const StmtSRef& loop_sref, bool preserve_unit_loops,
                                      arith::Analyzer* analyzer, bool check_only = false,
                                      int index = -1) {
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
-  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
   // Step 1. Bunch of checks
   // Check condition 1) : scope stage pipeline
   StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc
index ad15e06e28..bfda66036f 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -174,7 +174,7 @@ class NonSingleProducerError : public ScheduleError {
         }
       }
     }
-    const BlockNode* block = TVM_SREF_TO_BLOCK(block, consumer_block_sref);
+    const BlockNode* block = TVM_SREF_TO_BLOCK(consumer_block_sref);
     throw NonSingleProducerError(self->mod, GetRef<Block>(block));
   }
 };
@@ -183,7 +183,7 @@ class OpaqueAccessError : public ScheduleError {
  public:
   explicit OpaqueAccessError(IRModule mod, StmtSRef scope_root_sref)
       : mod_(mod), scope_root_(nullptr) {
-    const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root, scope_root_sref);
+    const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root_sref);
     this->scope_root_ = GetRef<Block>(scope_root);
   }
 
@@ -653,7 +653,7 @@ class ReverseComputeInliner : public BaseInliner {
 
 void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
                        bool check_only = false) {
-  const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(_producer_block, producer_block_sref);
+  const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(producer_block_sref);
   Block producer_block = GetRef<Block>(_producer_block);
   HasInitBlock::Check(self->mod, producer_block);
   Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block);
@@ -698,7 +698,7 @@ bool CanComputeInline(const ScheduleState& self, const StmtSRef& producer_block_
 
 void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block_sref,
                               bool check_only = false) {
-  const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(_consumer_block, consumer_block_sref);
+  const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref);
   Block consumer_block = GetRef<Block>(_consumer_block);
   HasInitBlock::Check(self->mod, consumer_block);
   // Step 1. Get the scope block
diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc
index 365c6d43f1..93fb88e666 100644
--- a/src/tir/schedule/primitive/decompose_padding.cc
+++ b/src/tir/schedule/primitive/decompose_padding.cc
@@ -415,7 +415,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref,
    *    - trim original block to write non-padding part only
    */
   // Condition Checks and Information Collection
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
   const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
   Map<Var, Range> dom_map;
   arith::Analyzer analyzer;
diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc
index ec337224e5..cc8cb55fd3 100644
--- a/src/tir/schedule/primitive/for_kind.cc
+++ b/src/tir/schedule/primitive/for_kind.cc
@@ -145,7 +145,7 @@ void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind
  */
 void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref, ForKind for_kind,
                             Optional<IterVar> thread_axis) {
-  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
 
   /*
    * Check:
@@ -186,7 +186,7 @@ void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& thread_a
 }
 
 void Unroll(ScheduleState self, const StmtSRef& loop_sref) {
-  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
   ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
   new_loop->kind = ForKind::kUnrolled;
   new_loop->thread_binding = NullOpt;
diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc
index 746918ac4e..cbdb99c644 100644
--- a/src/tir/schedule/primitive/get_block_loop.cc
+++ b/src/tir/schedule/primitive/get_block_loop.cc
@@ -40,7 +40,7 @@ Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const G
   };
 
   BaseFunc func = self->mod->Lookup(gv);
-  const auto* prim_func = TVM_TYPE_AS(prim_func, func, PrimFuncNode);
+  const auto* prim_func = TVM_TYPE_AS(func, PrimFuncNode);
   Finder finder(self, name);
   finder(prim_func->body);
   return std::move(finder.results_);
diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index 148b3ee033..b4e40fa120 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -134,7 +134,7 @@ class BufferIsSubregionError : public ScheduleError {
 
 void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
                      BufferIndexType buffer_index_type, const IndexMap& index_map) {
-  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
+  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
   Buffer old_buffer =
       GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, buffer_index_type);
   Optional<StmtSRef> defining_site_sref;
@@ -147,7 +147,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
   StmtSRef scope_sref = defining_site_sref.defined()
                             ? defining_site_sref.value()
                             : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
-  const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
+  const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);
 
   // Step 1: Infer the shape of the new buffer
   ObjectPtr<BufferNode> new_buffer_node = make_object<BufferNode>(*(old_buffer.get()));
@@ -344,7 +344,7 @@ class OpaqueNewIterTypeError : public ScheduleError {
 
 void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
                           const IndexMap& index_map) {
-  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
+  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
   const Block& block = GetRef<Block>(block_ptr);
   arith::Analyzer analyzer;
 
@@ -489,7 +489,7 @@ class BufferAxisSeparatorMutator : private ReplaceBufferMutator {
 
 void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
                       BufferIndexType buffer_index_type, const Array<IntImm>& axis_separators) {
-  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
+  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
   Buffer old_buffer =
       GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, buffer_index_type);
   Optional<StmtSRef> defining_site_sref;
@@ -502,7 +502,7 @@ void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer
   StmtSRef scope_sref = defining_site_sref.defined()
                             ? defining_site_sref.value()
                             : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
-  const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
+  const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);
 
   // Step 1: Check and update axis_separators of the buffer.
   Buffer new_buffer = old_buffer;
diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc
index f1b6f46e1b..2db3eb902a 100644
--- a/src/tir/schedule/primitive/loop_transformation.cc
+++ b/src/tir/schedule/primitive/loop_transformation.cc
@@ -87,7 +87,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator {
                               bool preserve_unit_iters) {
     Map<Var, Range> loop_var2extent;
     for (const StmtSRef& sref : loop_srefs) {
-      const ForNode* loop = TVM_SREF_TO_FOR(loop, sref);
+      const ForNode* loop = TVM_SREF_TO_FOR(sref);
       loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
     }
     return Downcast<For>(IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent),
@@ -389,7 +389,7 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array
   // - The execution order has not changed. (The block executes with the same args and the same
  // order as before.)
   // Step 1. Check correctness
-  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
   if (!loop->annotations.empty() || loop->thread_binding.defined()) {
     throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop));
   }
@@ -445,7 +445,7 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array
   result_srefs.reserve(n);
   for (int i = 0; i < n; i++) {
     result_srefs.push_back(self->stmt2ref.at(new_stmt.get()));
-    const ForNode* outer_loop = TVM_TYPE_AS(outer_loop, new_stmt, ForNode);
+    const ForNode* outer_loop = TVM_TYPE_AS(new_stmt, ForNode);
     new_stmt = outer_loop->body;
   }
   return result_srefs;
@@ -464,7 +464,7 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preser
   std::unordered_set<const VarNode*> outer_loop_vars;
   // Step 1. check correctness
   for (const StmtSRef& sref : loop_srefs) {
-    const ForNode* loop = TVM_SREF_TO_FOR(loop, sref);
+    const ForNode* loop = TVM_SREF_TO_FOR(sref);
     if (!loop->annotations.empty() || loop->thread_binding.defined()) {
       throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop));
     }
@@ -554,7 +554,7 @@ std::unordered_set<const StmtSRefNode*> CollectLoopsIntoSet(
   for (const StmtSRef& loop_sref : ordered_loop_srefs) {
     auto inserted = loop_srefs.insert(loop_sref.get());
     if (!inserted.second) {
-      const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+      const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
       throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop));
     }
   }
diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index ad9043e4f2..7a4ace736e 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -123,7 +123,7 @@ class LoopHeightError : public ScheduleError {
        // the block binding shouldn't contain the loop_var of a higher loop
         const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
         if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
-          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
           throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
         }
       }
@@ -183,8 +183,8 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
    *    - generate corresponding init block and update block
    */
   // Condition Checks and Information Collection
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
-  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
   // Get the outer loops from high to low
   Array<StmtSRef> loops = GetLoops(block_sref);
   const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
@@ -264,7 +264,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
   std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map;
   Stmt body = BlockRealize(init_realize);
   for (int i : chosen_loops) {
-    const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+    const ForNode* old_loop = TVM_SREF_TO_FOR(loops[i]);
     // Create a new equivalent to the chosen loop
     Var old_loop_var = old_loop->loop_var;
     Var new_loop_var = old_loop_var.copy_with_suffix("_init");
@@ -277,7 +277,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
   }
   body = Substitute(body, loop_var_map);
   // Step 6. Mutate IR
-  const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(old_scope_root, scope_root_sref);
+  const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(scope_root_sref);
   Block new_scope_root{nullptr};
   Block new_reduction_block{nullptr};
   std::tie(new_scope_root, new_reduction_block) = DecomposeReductionBlockReplacer::Replace(
@@ -1013,7 +1013,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax
   StmtSRef scope_root = GetScopeRoot(self, block_sref,  //
                                      /*require_stage_pipeline=*/true);
   CheckReductionBlock(self, block_sref, scope_root);
-  const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop, rf_loop_sref);
+  const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop_sref);
   if (rf_loop->kind != ForKind::kSerial) {
     throw NotSerialLoopKindError(self->mod, GetRef<For>(rf_loop));
   }
diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc
index 1961565aac..52b5add2bc 100644
--- a/src/tir/schedule/primitive/sampling.cc
+++ b/src/tir/schedule/primitive/sampling.cc
@@ -311,7 +311,7 @@ std::vector<int64_t> SamplePerfectTile(
     support::LinearCongruentialEngine::TRandState* rand_state,  //
     const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor,
     Optional<Array<Integer>>* decision) {
-  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
   const int64_t* extent = GetLoopIntExtent(loop);
   std::vector<int64_t> result;
   if (extent == nullptr) {
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 07481ddb19..15d0e08ddc 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -208,7 +208,7 @@ class BlockInfoCollector : private StmtVisitor {
     if (is_root_block) {
       // If the block doesn't have outer loops and BlockRealize,
       // then we set the affine binding flag as true only if the block has no block vars
-      const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root);
+      const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root);
       if (block->iter_vars.empty()) info.affine_binding = true;
     } else {
       info.affine_binding =
@@ -233,7 +233,7 @@ class BlockInfoCollector : private StmtVisitor {
     block_reads_unbound.reserve(child_block_srefs.size());
     block_writes_unbound.reserve(child_block_srefs.size());
     for (const StmtSRef& block_sref : child_block_srefs) {
-      const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+      const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
       Map<Var, PrimExpr> binding = GetBindings(block2realize_.at(block));
       // Step 1.1. Unbind read regions
       Array<BufferRegion> reads;
@@ -254,7 +254,7 @@ class BlockInfoCollector : private StmtVisitor {
     for (const auto& kv : info.scope->dst2deps) {
       const StmtSRef& consumer_block_sref = kv.first;
       const Array<Dependency>& deps = kv.second;
-      const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block, consumer_block_sref);
+      const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref);
       const BlockRealize& consumer_realize = block2realize_.at(consumer_block);
       bool& region_cover = self_->block_info.at(consumer_block_sref).region_cover = true;
       // Step 2.1. Extract the path to the scope root
@@ -851,7 +851,7 @@ class ChildReplacer : private StmtMutator {
       } else if (const auto* realize = stmt.as<BlockRealizeNode>()) {
         // Case 2. stmt is BlockRealize, src_stmt is Block
         if (realize->block.get() == src_stmt) {
-          const auto* tgt_block = TVM_TYPE_AS(tgt_block, tgt_stmt_, BlockNode);
+          const auto* tgt_block = TVM_TYPE_AS(tgt_stmt_, BlockNode);
           ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
           new_realize->block = GetRef<Block>(tgt_block);
           new_stmt = BlockRealize(std::move(new_realize));
@@ -1044,9 +1044,9 @@ void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_
     // If `g_func` was unique, after the 3 lines above:
     //   `ref_new_func` points to the same unique function that `g_func` points to
    // Update the body of the function the sref belongs to
-    const auto* realize = TVM_TYPE_AS(realize, g_func->body, BlockRealizeNode);
+    const auto* realize = TVM_TYPE_AS(g_func->body, BlockRealizeNode);
     // Make `child_tgt_stmt` the root block
-    const auto* child_block = TVM_TYPE_AS(child_block, child_tgt_stmt, BlockNode);
+    const auto* child_block = TVM_TYPE_AS(child_tgt_stmt, BlockNode);
     ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
     new_realize->block = GetRef<Block>(child_block);
     new_func->body = BlockRealize(std::move(new_realize));
@@ -1078,7 +1078,7 @@ void ScheduleStateNode::DebugVerify() const {
 /**************** BlockInfo-related ****************/
 
 BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const {
-  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  TVM_SREF_TO_BLOCK(block_sref);
   auto it = this->block_info.find(block_sref);
   CHECK(it != this->block_info.end())
       << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n"
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index 1c21d770db..1ebaf202d4 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -36,7 +36,7 @@ Block WithAnnotation(const BlockNode* block, const String& attr_key, const Objec
 Buffer WithScope(const Buffer& buffer, const String& scope) {
   ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*buffer.get());
   ObjectPtr<VarNode> new_var = make_object<VarNode>(*buffer->data.get());
-  const auto* ptr_type = TVM_TYPE_AS(ptr_type, buffer->data->type_annotation, PointerTypeNode);
+  const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode);
   new_var->type_annotation = PointerType(ptr_type->element_type, scope);
   new_buffer->data = Var(new_var->name_hint + "_" + scope, new_var->type_annotation);
   new_buffer->name = buffer->name + "_" + scope;
@@ -253,8 +253,8 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
     }
   }
   ICHECK(sref != nullptr && sref->stmt != nullptr);
-  const auto* leaf_block = TVM_SREF_TO_BLOCK(leaf_block, leaf_block_sref);
-  const auto* scope_block = TVM_SREF_TO_BLOCK(scope_block, sref);
+  const auto* leaf_block = TVM_SREF_TO_BLOCK(leaf_block_sref);
+  const auto* scope_block = TVM_SREF_TO_BLOCK(sref);
   throw OnlyLeafError(self->mod, GetRef<Block>(leaf_block), GetRef<Block>(scope_block));
 }
 
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index 3db80989ae..c289309acc 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -62,25 +62,35 @@ namespace tir {
 
 /*!
  * \brief A helper macro to convert an sref to the block it points to,
- * throwing an internal error if downcasting fails
- * \param Result The result variable, used for checking
+ *
+ * Throws an internal error if downcasting fails.  The variable name
+ * in the parent scope is used for the error message.
+ *
  * \param SRef The SRef to be cast
  */
-#define TVM_SREF_TO_BLOCK(Result, SRef)                   \
-  TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::BlockNode) \
-      << "TypeError: Expects StmtSRef `" << #SRef         \
-      << "` points to `Block`, but gets: " << (SRef->stmt ? SRef->stmt->GetTypeKey() : "None")
+#define TVM_SREF_TO_BLOCK(SRef)                                                                    \
+  [&]() {                                                                                          \
+    auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::BlockNode)                        \
+                  << "TypeError: Expects StmtSRef `" << #SRef << "` points to `Block`, but gets: " \
+                  << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None");                         \
+    return result;                                                                                 \
+  }()
 
 /*!
- * \brief A helper macro to convert an sref to the for-loop it points to,
- * throwing an internal error if downcasting fails
- * \param Result The name of the result variable, used for checking
+ * \brief A helper macro to convert an sref to the for-loop it points to
+ *
+ * Throws an internal error if downcasting fails.  The variable name
+ * in the parent scope is used for the error message.
+ *
  * \param SRef The SRef to be cast
  */
-#define TVM_SREF_TO_FOR(Result, SRef)                   \
-  TVM_SREF_AS_OR_ERR(Result, SRef, ::tvm::tir::ForNode) \
-      << "TypeError: Expects StmtSRef `" << #SRef       \
-      << "` points to `Loop`, but gets: " << (SRef->stmt ? SRef->stmt->GetTypeKey() : "None")
+#define TVM_SREF_TO_FOR(SRef)                                                                     \
+  [&]() {                                                                                         \
+    auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::ForNode)                         \
+                  << "TypeError: Expects StmtSRef `" << #SRef << "` points to `Loop`, but gets: " \
+                  << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None");                        \
+    return result;                                                                                \
+  }()
 
 /*!
  * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as<Type>`,
@@ -100,10 +110,13 @@ namespace tir {
  * \param From The ObjectRef to be downcast
  * \param Type The type to be downcast to
  */
-#define TVM_TYPE_AS(Result, From, Type)                                           \
-  TVM_TYPE_AS_OR_ERR(Result, From, Type)                                          \
-      << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \
-      << "`, but gets: " << (From.defined() ? From->GetTypeKey() : "None")
+#define TVM_TYPE_AS(From, Type)                                                               \
+  [&]() {                                                                                     \
+    auto result = TVM_TYPE_AS_OR_ERR(result, (From), Type)                                    \
+                  << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \
+                  << "`, but gets: " << ((From).defined() ? (From)->GetTypeKey() : "None");   \
+    return result;                                                                            \
+  }()
 
 /*!
  * \brief Convert an array of loop StmtSRefs to an array of loops
@@ -114,7 +127,7 @@ inline Array<For> LoopSRefs2Loops(const Array<StmtSRef>& loop_srefs) {
   Array<For> loops;
   loops.reserve(loop_srefs.size());
   for (StmtSRef loop_sref : loop_srefs) {
-    const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+    const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
     loops.push_back(GetRef<For>(loop));
   }
   return loops;
@@ -264,7 +277,7 @@ inline const int64_t* GetLoopIntExtent(const ForNode* loop) { return as_const_in
  * \return The extent of the loop, nullptr if the extent is not constant
  */
 inline const int64_t* GetLoopIntExtent(const StmtSRef& loop_sref) {
-  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
   return as_const_int(loop->extent);
 }
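
For reference, a rough sketch of what `TVM_SREF_TO_FOR(loop_sref)` now
expands to, assuming `TVM_SREF_AS_OR_ERR` downcasts via
`StmtSRefNode::StmtAs` and streams its message into an `ICHECK` on
failure:

    // Approximate expansion; `result` is scoped inside the lambda,
    // so it cannot collide with any name at the call site:
    [&]() {
      auto result = (loop_sref)->StmtAs<::tvm::tir::ForNode>();
      ICHECK(result) << "TypeError: Expects StmtSRef `loop_sref` points to `Loop`, "
                     << "but gets: "
                     << ((loop_sref)->stmt ? (loop_sref)->stmt->GetTypeKey() : "None");
      return result;
    }()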