You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/06/28 23:02:16 UTC

[incubator-tvm] branch master updated: [REFACTOR][TIR][API-Change] Range/IntSet API style consistency. (#5953)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 0465ffd  [REFACTOR][TIR][API-Change] Range/IntSet API style consistency. (#5953)
0465ffd is described below

commit 0465ffda26a4e9d64944f01b33b566b6663d3afe
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Sun Jun 28 16:02:06 2020 -0700

    [REFACTOR][TIR][API-Change] Range/IntSet API style consistency. (#5953)
    
    - Range::make_by_min_extent -> Range::FromMinExtent
    - Update the APIs in IntSet to use CamelCase
---
 docs/dev/inferbound.rst                            |  6 +--
 include/tvm/arith/int_set.h                        | 51 ++++++++----------
 include/tvm/ir/expr.h                              |  2 +-
 python/tvm/ir/expr.py                              |  4 +-
 python/tvm/te/hybrid/parser.py                     |  2 +-
 python/tvm/te/tensor_intrin.py                     |  2 +-
 src/arith/bound_deducer.cc                         |  8 +--
 src/arith/domain_touched.cc                        |  6 +--
 src/arith/int_set.cc                               | 62 +++++++++++-----------
 src/arith/ir_mutator_with_analyzer.cc              |  4 +-
 src/arith/ir_visitor_with_analyzer.h               |  4 +-
 src/arith/solve_linear_equation.cc                 |  4 +-
 src/ir/expr.cc                                     |  4 +-
 src/target/llvm/codegen_llvm.cc                    |  4 +-
 src/target/spirv/codegen_spirv.cc                  |  4 +-
 src/te/operation/compute_op.cc                     |  4 +-
 src/te/operation/extern_op.cc                      |  6 +--
 src/te/operation/hybrid_op.cc                      |  8 +--
 src/te/operation/scan_op.cc                        | 14 ++---
 src/te/operation/tensorize.cc                      | 10 ++--
 src/te/schedule/bound.cc                           | 22 ++++----
 src/te/schedule/message_passing.cc                 | 46 ++++++++--------
 src/te/schedule/schedule_dataflow_rewrite.cc       |  2 +-
 src/te/schedule/schedule_lang.cc                   |  2 +-
 .../schedule_postproc_rewrite_for_tensor_core.cc   |  4 +-
 src/tir/ir/expr_functor.cc                         |  2 +-
 src/tir/ir/stmt_functor.cc                         |  2 +-
 src/tir/transforms/coproc_sync.cc                  |  2 +-
 src/tir/transforms/inject_prefetch.cc              |  8 +--
 src/tir/transforms/loop_partition.cc               | 24 ++++-----
 src/tir/transforms/lower_warp_memory.cc            |  4 +-
 src/tir/transforms/narrow_datatype.cc              |  4 +-
 src/tir/transforms/simplify.cc                     |  2 +-
 src/tir/transforms/storage_access.cc               |  8 +--
 src/tir/transforms/thread_storage_sync.cc          |  4 +-
 tests/python/unittest/test_arith_intset.py         |  4 +-
 .../unittest/test_arith_solve_linear_system.py     |  4 +-
 tests/python/unittest/test_tir_ir_builder.py       |  4 +-
 .../unittest/test_tir_transform_storage_flatten.py |  2 +-
 39 files changed, 175 insertions(+), 184 deletions(-)

diff --git a/docs/dev/inferbound.rst b/docs/dev/inferbound.rst
index 6520732..63954ac 100644
--- a/docs/dev/inferbound.rst
+++ b/docs/dev/inferbound.rst
@@ -181,8 +181,8 @@ The Ranges of the inner and outer IterVars of the split are set based on the par
 
 .. code:: cpp
 
-   rmap[split->inner] = Range::make_by_min_extent(0, split->factor)
-   rmap[split->outer] = Range::make_by_min_extent(0, DivCeil(rmap[split->parent]->extent, split->factor))
+   rmap[split->inner] = Range::FromMinExtent(0, split->factor)
+   rmap[split->outer] = Range::FromMinExtent(0, DivCeil(rmap[split->parent]->extent, split->factor))
 
 There is an opportunity here to tighten the bounds produced by InferBound, when ``split->factor`` does not evenly divide the parent's extent. Suppose the parent's extent is 20, and the split factor is 16. Then on the second iteration of the outer loop, the inner loop only needs to perform 4 iterations, not 16. If PassDownDomain could set the extent of ``split->inner`` to ``min(split->factor, rmap[split->parent]->extent - (split->outer * split->factor))``, then the extent of the inner var [...]
 
@@ -190,7 +190,7 @@ For Fuse relations, the Range of the fused IterVar is set based on the known Ran
 
 .. code:: cpp
 
-   rmap[fuse->fused] = Range::make_by_min_extent(0, rmap[fuse->outer]->extent * rmap[fuse->inner]->extent)
+   rmap[fuse->fused] = Range::FromMinExtent(0, rmap[fuse->outer]->extent * rmap[fuse->inner]->extent)
 
 
 InferRootBound
diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h
index ae90bde..515392d 100644
--- a/include/tvm/arith/int_set.h
+++ b/include/tvm/arith/int_set.h
@@ -65,82 +65,75 @@ class IntSetNode : public Object {
  */
 class IntSet : public ObjectRef {
  public:
-  /*! \brief constructor */
-  IntSet() {}
-  // constructor from not container.
-  explicit IntSet(ObjectPtr<Object> n) : ObjectRef(n) {}
-  /*!
-   * \brief access the internal node container
-   * \return the pointer to the internal node container
-   */
-  const IntSetNode* operator->() const { return static_cast<const IntSetNode*>(get()); }
   /*!
    * \brief Find a range that covers the region.
    * \param max_range The range to be covered.
    * \return The covering range.
    */
-  Range cover_range(Range max_range) const;
+  Range CoverRange(Range max_range) const;
   /*! \return Lower bound of the set */
   PrimExpr min() const;
   /*! \return upper bound of the set */
   PrimExpr max() const;
+  /*! \return The sign of the elements in the integer set */
+  SignType GetSignType() const;
   /*! \return Whether the set represent nothing  */
-  bool is_nothing() const;
+  bool IsNothing() const;
   /*! \return Whether the set represent everything  */
-  bool is_everything() const;
+  bool IsEverything() const;
   /*! \return Whether the set is a single point */
-  bool is_single_point() const;
+  bool IsSinglePoint() const;
   /*! \return Whether the set is proved to be bigger than 0 */
-  bool can_prove_positive() const;
+  bool CanProvePositive() const;
   /*! \return Whether the set is proved to be smaller than 0 */
-  bool can_prove_negative() const;
+  bool CanProveNegative() const;
   /*! \return Whether the set is proved to be smaller than or equal to 0 */
-  bool can_prove_non_positive() const;
+  bool CanProveNonPositive() const;
   /*! \return Whether the set is proved to be larger than or equal to 0 */
-  bool can_prove_non_negative() const;
-  /*! \return The sign of the elements in the integer set */
-  SignType sign_type() const;
+  bool CanProveNonNegative() const;
   /*!
-   * \brief The single point value, call only if is_single_point is true
+   * \brief The single point value, call only if IsSinglePoint is true
    * \return The point value.
    */
-  PrimExpr point_value() const;
+  PrimExpr PointValue() const;
   /*!
    * \brief Try to match IntSet with range r.
    *
-   * \note It is guanrateed that IntSet::range(r).match_range(r) == true
+   * \note It is guaranteed that IntSet::FromRange(r).MatchRange(r) == true
    * \return true if we can prove they are the same.
    */
-  bool match_range(const Range& r) const;
+  bool MatchRange(const tvm::Range& r) const;
   /*! \return The set contains nothing */
-  static IntSet nothing();
+  static IntSet Nothing();
   /*! \return The set contains everything */
-  static IntSet everything();
+  static IntSet Everything();
   /*!
    * \brief construct a point set.
    * \param point The point in the set.
    * \return construct a single point set
    */
-  static IntSet single_point(PrimExpr point);
+  static IntSet SinglePoint(PrimExpr point);
   /*!
    * \brief construct a integer set from vector expression.
    * \param vec The vector expression, can also be single point.
    * \return The result set containing the indices in the vector.
    */
-  static IntSet vector(PrimExpr vec);
+  static IntSet Vector(PrimExpr vec);
   /*!
    * \brief Construct a set representing a range.
    * \param r The range
    * \return constructed set.
    */
-  static IntSet range(Range r);
+  static IntSet FromRange(tvm::Range r);
   /*!
    * \brief Construct a set representing a interval.
    * \param min The minimum value of the interval.
    * \param max The maximum value of the interval.
    * \return constructed set.
    */
-  static IntSet interval(PrimExpr min, PrimExpr max);
+  static IntSet Interval(PrimExpr min, PrimExpr max);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(IntSet, ObjectRef, IntSetNode);
 };
 
 //-----------------------------------------------
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index b6083c8..d6cfc5a 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -454,7 +454,7 @@ class Range : public ObjectRef {
    * \param min The minimum range.
    * \param extent The extent of the range.
    */
-  static Range make_by_min_extent(PrimExpr min, PrimExpr extent);
+  static Range FromMinExtent(PrimExpr min, PrimExpr extent);
   // declare range.
   TVM_DEFINE_OBJECT_REF_METHODS(Range, ObjectRef, RangeNode);
 };
diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py
index eedfff8..0a3f205 100644
--- a/python/tvm/ir/expr.py
+++ b/python/tvm/ir/expr.py
@@ -118,7 +118,7 @@ class Range(Node):
                 _ffi_api.Range, begin, end)
 
     @staticmethod
-    def make_by_min_extent(min_value, extent):
+    def from_min_extent(min_value, extent):
         """Construct a Range by min and extent.
 
         This constructs a range in [min_value, min_value + extent)
@@ -136,4 +136,4 @@ class Range(Node):
         rng : Range
             The constructed range.
         """
-        return _ffi_api.range_by_min_extent(min_value, extent)
+        return _ffi_api.Range_from_min_extent(min_value, extent)
diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py
index 913b453..b6f6e51 100644
--- a/python/tvm/te/hybrid/parser.py
+++ b/python/tvm/te/hybrid/parser.py
@@ -209,7 +209,7 @@ class HybridParser(ast.NodeVisitor):
             if _scope == 'global':
                 body = self.wrap_up_binds(body)
 
-            _domain = [Range.make_by_min_extent(0, i) for i in _buf.shape]
+            _domain = [Range.from_min_extent(0, i) for i in _buf.shape]
             _dtype = _buf.dtype
             _true = tvm.runtime.convert(True)
             body = tvm.tir.ProducerRealize(_buf, _domain, _true, body)
diff --git a/python/tvm/te/tensor_intrin.py b/python/tvm/te/tensor_intrin.py
index cd488a7..7d396ee 100644
--- a/python/tvm/te/tensor_intrin.py
+++ b/python/tvm/te/tensor_intrin.py
@@ -37,7 +37,7 @@ def _get_region(tslice):
                 begin = idx.var
             else:
                 begin = idx
-            region.append(Range.make_by_min_extent(begin, 1))
+            region.append(Range.from_min_extent(begin, 1))
     return region
 
 
diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc
index 496eb20..f685170 100644
--- a/src/arith/bound_deducer.cc
+++ b/src/arith/bound_deducer.cc
@@ -135,7 +135,7 @@ class BoundDeducer : public ExprVisitor {
     if (operand.dtype().is_uint()) {
       sign_operand = kPositive;
     } else {
-      sign_operand = expr_map_[operand].sign_type();
+      sign_operand = expr_map_[operand].GetSignType();
     }
 
     if (sign_operand == SignType::kNegative) {
@@ -315,7 +315,7 @@ void BoundDeducer::Deduce() {
 void BoundDeducer::Relax() {
   IntSet a = EvalSet(expr_, relax_map_);
   IntSet b = EvalSet(result_, relax_map_);
-  if (a.is_everything() || b.is_everything()) {
+  if (a.IsEverything() || b.IsEverything()) {
     success_ = false;
     return;
   }
@@ -336,7 +336,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e,
                    const std::unordered_map<const VarNode*, IntSet>& relax_map) {
   BoundDeducer d(v, e, hint_map, relax_map);
   d.Deduce();
-  if (!d.success_) return IntSet::nothing();
+  if (!d.success_) return IntSet::Nothing();
   PrimExpr min = neg_inf(), max = pos_inf();
   if (d.comp_op == kEqual) {
     min = d.result_;
@@ -346,7 +346,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e,
   } else {
     max = d.result_;
   }
-  return IntSet::interval(min, max);
+  return IntSet::Interval(min, max);
 }
 
 // assuming e >= 0, deduce the bound of variable from it.
diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc
index b44d9f7..d59486c 100644
--- a/src/arith/domain_touched.cc
+++ b/src/arith/domain_touched.cc
@@ -45,14 +45,14 @@ class BufferTouchedDomain final : public StmtExprVisitor {
     Region ret;
     Range none;
     for (size_t i = 0; i < bounds_.size(); ++i) {
-      ret.push_back(arith::Union(bounds_[i]).cover_range(none));
+      ret.push_back(arith::Union(bounds_[i]).CoverRange(none));
     }
     return ret;
   }
 
   void VisitStmt_(const ForNode* op) final {
     const VarNode* var = op->loop_var.get();
-    dom_map_[var] = IntSet::range(Range::make_by_min_extent(op->min, op->extent));
+    dom_map_[var] = IntSet::FromRange(Range::FromMinExtent(op->min, op->extent));
     StmtExprVisitor::VisitStmt_(op);
     dom_map_.erase(var);
   }
@@ -69,7 +69,7 @@ class BufferTouchedDomain final : public StmtExprVisitor {
       const IterVarNode* thread_axis = op->node.as<IterVarNode>();
       CHECK(thread_axis);
       const VarNode* var = thread_axis->var.get();
-      dom_map_[var] = IntSet::range(Range(make_zero(op->value.dtype()), op->value));
+      dom_map_[var] = IntSet::FromRange(Range(make_zero(op->value.dtype()), op->value));
       StmtExprVisitor::VisitStmt_(op);
       dom_map_.erase(var);
     } else {
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index b043b35..03645b4 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -493,14 +493,14 @@ IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const Map<Var, IntSet>&
 
 // Quickly adapt to IntSet interface
 // TODO(tqchen): revisit IntSet interface as well.
-Range IntSet::cover_range(Range max_range) const {
+Range IntSet::CoverRange(Range max_range) const {
   IntSet temp;
   Analyzer analyzer;
   const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
   CHECK(s_int != nullptr);
   if (s_int->HasUpperBound() && s_int->HasLowerBound()) {
-    return Range::make_by_min_extent(s_int->min_value,
-                                     analyzer.Simplify(s_int->max_value + 1 - s_int->min_value));
+    return Range::FromMinExtent(s_int->min_value,
+                                analyzer.Simplify(s_int->max_value + 1 - s_int->min_value));
   }
   return max_range;
 }
@@ -517,34 +517,34 @@ PrimExpr IntSet::max() const {
   return s_int->max_value;
 }
 
-bool IntSet::is_nothing() const {
+bool IntSet::IsNothing() const {
   const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
   return (s_int && s_int->IsEmpty());
 }
 
-bool IntSet::is_everything() const {
+bool IntSet::IsEverything() const {
   const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
   return (s_int && s_int->IsEverything());
 }
 
-bool IntSet::is_single_point() const {
+bool IntSet::IsSinglePoint() const {
   const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
   return (s_int && s_int->IsSinglePoint());
 }
 
-bool IntSet::can_prove_positive() const {
+bool IntSet::CanProvePositive() const {
   Analyzer analyzer;
   const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
   return (s_int && is_positive_const(analyzer.Simplify(s_int->min_value)));
 }
 
-bool IntSet::can_prove_negative() const {
+bool IntSet::CanProveNegative() const {
   Analyzer analyzer;
   const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
   return (s_int && is_negative_const(analyzer.Simplify(s_int->max_value)));
 }
 
-bool IntSet::can_prove_non_positive() const {
+bool IntSet::CanProveNonPositive() const {
   Analyzer analyzer;
   if (const auto* s_int = (*this).as<IntervalSetNode>()) {
     auto max = analyzer.Simplify(s_int->max_value);
@@ -553,7 +553,7 @@ bool IntSet::can_prove_non_positive() const {
   return false;
 }
 
-bool IntSet::can_prove_non_negative() const {
+bool IntSet::CanProveNonNegative() const {
   Analyzer analyzer;
   if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) {
     auto min = analyzer.Simplify(s_int->min_value);
@@ -562,32 +562,32 @@ bool IntSet::can_prove_non_negative() const {
   return false;
 }
 
-SignType IntSet::sign_type() const {
-  if (can_prove_positive()) {
+SignType IntSet::GetSignType() const {
+  if (CanProvePositive()) {
     return kPositive;
-  } else if (can_prove_negative()) {
+  } else if (CanProveNegative()) {
     return kNegative;
-  } else if (is_single_point() && is_zero(point_value())) {
+  } else if (IsSinglePoint() && is_zero(PointValue())) {
     return kZero;
   } else {
     return kUnknown;
   }
 }
-PrimExpr IntSet::point_value() const {
+PrimExpr IntSet::PointValue() const {
   const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
   CHECK(s_int && s_int->IsSinglePoint());
   return s_int->min_value;
 }
 
-IntSet IntSet::nothing() { return IntervalSet::Empty(); }
+IntSet IntSet::Nothing() { return IntervalSet::Empty(); }
 
-IntSet IntSet::everything() { return IntervalSet::Everything(); }
+IntSet IntSet::Everything() { return IntervalSet::Everything(); }
 
-IntSet IntSet::single_point(PrimExpr x) { return IntervalSet::SinglePoint(x); }
+IntSet IntSet::SinglePoint(PrimExpr x) { return IntervalSet::SinglePoint(x); }
 
-IntSet IntSet::interval(PrimExpr min, PrimExpr max) {
+IntSet IntSet::Interval(PrimExpr min, PrimExpr max) {
   if (min.same_as(max)) {
-    return IntSet::single_point(min);
+    return IntSet::SinglePoint(min);
   }
   return IntervalSet(min, max);
 }
@@ -597,15 +597,15 @@ inline bool ProveEqual(Analyzer* analyzer, PrimExpr lhs, PrimExpr rhs) {
   return is_zero(analyzer->Simplify(lhs - rhs));
 }
 
-IntSet IntSet::range(Range r) {
+IntSet IntSet::FromRange(Range r) {
   // must make sure it can be matched back by MatchRange.
   if (is_one(r->extent)) {
-    return IntSet::single_point(r->min);
+    return IntSet::SinglePoint(r->min);
   }
   return IntervalSet(r->min, r->extent + r->min - 1);
 }
 
-bool IntSet::match_range(const Range& b) const {
+bool IntSet::MatchRange(const Range& b) const {
   const IntSet& a = *this;
   const IntervalSetNode* a_int = a.as<IntervalSetNode>();
   if (!a_int) return false;
@@ -615,7 +615,7 @@ bool IntSet::match_range(const Range& b) const {
 }
 
 IntSet Union(const Array<IntSet>& sets) {
-  if (sets.size() == 0) return IntSet::nothing();
+  if (sets.size() == 0) return IntSet::Nothing();
   if (sets.size() == 1) return sets[0];
   Analyzer ana;
   IntervalSet x = ToIntervalSet(sets[0]);
@@ -626,7 +626,7 @@ IntSet Union(const Array<IntSet>& sets) {
 }
 
 IntSet Intersect(const Array<IntSet>& sets) {
-  if (sets.size() == 0) return IntSet::nothing();
+  if (sets.size() == 0) return IntSet::Nothing();
   if (sets.size() == 1) return sets[0];
   Analyzer ana;
   IntervalSet x = ToIntervalSet(sets[0]);
@@ -657,7 +657,7 @@ IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& dom_map) {
   return IntervalSetEvaluator(&ana, dom_map, false).Eval(e);
 }
 
-IntSet IntSet::vector(PrimExpr x) {
+IntSet IntSet::Vector(PrimExpr x) {
   Analyzer ana;
   Map<Var, IntSet> dmap;
   return IntervalSetEvaluator(&ana, dmap, true).Eval(x);
@@ -730,19 +730,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
                 << "[" << op->min_value << ", " << op->max_value << ']';
     });
 
-TVM_REGISTER_GLOBAL("arith.intset_single_point").set_body_typed(IntSet::single_point);
+TVM_REGISTER_GLOBAL("arith.intset_single_point").set_body_typed(IntSet::SinglePoint);
 
-TVM_REGISTER_GLOBAL("arith.intset_vector").set_body_typed(IntSet::vector);
+TVM_REGISTER_GLOBAL("arith.intset_vector").set_body_typed(IntSet::Vector);
 
-TVM_REGISTER_GLOBAL("arith.intset_interval").set_body_typed(IntSet::interval);
+TVM_REGISTER_GLOBAL("arith.intset_interval").set_body_typed(IntSet::Interval);
 
 TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin").set_body_method(&IntSet::min);
 
 TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax").set_body_method(&IntSet::max);
 
-TVM_REGISTER_GLOBAL("arith.IntSetIsNothing").set_body_method(&IntSet::is_nothing);
+TVM_REGISTER_GLOBAL("arith.IntSetIsNothing").set_body_method(&IntSet::IsNothing);
 
-TVM_REGISTER_GLOBAL("arith.IntSetIsEverything").set_body_method(&IntSet::is_everything);
+TVM_REGISTER_GLOBAL("arith.IntSetIsEverything").set_body_method(&IntSet::IsEverything);
 
 }  // namespace arith
 }  // namespace tvm
diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc
index 259fcd9..2a02661 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -31,7 +31,7 @@ namespace arith {
 using namespace tir;
 
 Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
-  analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
+  analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
   return StmtExprMutator::VisitStmt_(op);
 }
 
@@ -97,7 +97,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) {
   if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) {
     IterVar iv = Downcast<IterVar>(op->node);
     CHECK_NE(iv->thread_tag.length(), 0U);
-    analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value));
+    analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
     Stmt stmt = StmtExprMutator::VisitStmt_(op);
     return stmt;
   } else {
diff --git a/src/arith/ir_visitor_with_analyzer.h b/src/arith/ir_visitor_with_analyzer.h
index 810949b..388720a 100644
--- a/src/arith/ir_visitor_with_analyzer.h
+++ b/src/arith/ir_visitor_with_analyzer.h
@@ -37,7 +37,7 @@ class IRVisitorWithAnalyzer final : public StmtExprVisitor {
   PrimExpr Simplify(const PrimExpr& expr) { return analyzer_.Simplify(expr); }
 
   void VisitStmt_(const ForNode* op) {
-    analyzer_.Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
+    analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
     return StmtExprVisitor::VisitStmt_(op);
   }
 
@@ -45,7 +45,7 @@ class IRVisitorWithAnalyzer final : public StmtExprVisitor {
     if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
       IterVar iv = Downcast<IterVar>(op->node);
       CHECK_NE(iv->thread_tag.length(), 0U);
-      analyzer_.Bind(iv->var, Range::make_by_min_extent(0, op->value));
+      analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value));
       StmtExprVisitor::VisitStmt_(op);
     } else {
       StmtExprVisitor::VisitStmt_(op);
diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc
index 5bf0e0e..cda1ec2 100644
--- a/src/arith/solve_linear_equation.cc
+++ b/src/arith/solve_linear_equation.cc
@@ -225,14 +225,14 @@ Map<Var, Range> InferRange(const Map<Var, PrimExpr>& vars_to_infer, const Array<
       new_ranges.Set(p.first, p.second);
     }
     // Convert original ranges to IntSets
-    var_intsets[p.first.get()] = IntSet::range(p.second);
+    var_intsets[p.first.get()] = IntSet::FromRange(p.second);
   }
 
   // Infer ranges for the new variables and add them to the resulting ranges
   for (const auto& p : vars_to_infer) {
     const auto& var = p.first;
     const auto& expr = p.second;
-    Range range = EvalSet(expr, var_intsets).cover_range(Range());
+    Range range = EvalSet(expr, var_intsets).CoverRange(Range());
     if (range.defined()) {
       new_ranges.Set(var, range);
     }
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index fd380aa..05d41cf 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -119,11 +119,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 Range::Range(PrimExpr begin, PrimExpr end)
     : Range(make_object<RangeNode>(begin, tir::is_zero(begin) ? end : (end - begin))) {}
 
-Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
+Range Range::FromMinExtent(PrimExpr min, PrimExpr extent) {
   return Range(make_object<RangeNode>(min, extent));
 }
 
-TVM_REGISTER_GLOBAL("ir.range_by_min_extent").set_body_typed(Range::make_by_min_extent);
+TVM_REGISTER_GLOBAL("ir.Range_from_min_extent").set_body_typed(Range::FromMinExtent);
 
 TVM_REGISTER_GLOBAL("ir.Range").set_body([](TVMArgs args, TVMRetValue* ret) {
   *ret = Range(args[0], args[1]);
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 99a23c6..5085c1e 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -1190,7 +1190,7 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
 
 void CodeGenLLVM::VisitStmt_(const ForNode* op) {
   CHECK(is_zero(op->min));
-  analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
+  analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
   if (op->for_type == ForType::Unrolled) {
     LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, "
                  << " consider set unroll_explicit=True";
@@ -1264,7 +1264,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
     if (iv->thread_tag.length() != 0) {
       if (!var_map_.count(iv->var.get())) {
         var_map_[iv->var.get()] = GetThreadIndex(iv);
-        analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value));
+        analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
       }
     }
   } else if (op->attr_key == tir::attr::storage_scope) {
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index ff3bc7d..7ff0c55 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -469,7 +469,7 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) {
 
 void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
   CHECK(is_zero(op->min));
-  analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
+  analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
   spirv::Value init_value = MakeValue(op->min);
   spirv::Value extent_value = MakeValue(op->extent);
   // Must get init label after making value(to make sure they are correct)
@@ -569,7 +569,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
     if (iv->thread_tag.length() != 0) {
       if (!var_map_.count(iv->var.get())) {
         var_map_[iv->var.get()] = GetThreadIndex(iv, op->value);
-        analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value));
+        analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
       }
     }
   } else if (op->attr_key == tir::attr::storage_scope) {
diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc
index 21343ec..36d2d33 100644
--- a/src/te/operation/compute_op.cc
+++ b/src/te/operation/compute_op.cc
@@ -230,7 +230,7 @@ void ComputeOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* an
               min_value = shape_i_min_value;
               max_value = shape_i_max_value;
             }
-            dom.data[i].push_back(IntSet::interval(min_value, max_value));
+            dom.data[i].push_back(IntSet::Interval(min_value, max_value));
           } else {
             dom.data[i].push_back(arg_intset);
           }
@@ -247,7 +247,7 @@ void BaseComputeOpNode::GatherBound(const Operation& self,
   CHECK_EQ(self.operator->(), this);
   const TensorDom& tdom = tensor_dom.at(self.output(0));
   for (size_t i = 0; i < this->axis.size(); ++i) {
-    Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
+    Range r = arith::Union(tdom.data.at(i)).CoverRange(this->axis[i]->dom);
     CHECK(!out_dom_map->count(this->axis[i]));
     (*out_dom_map)[this->axis[i]] = r;
   }
diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc
index d789938a..e61fe51 100644
--- a/src/te/operation/extern_op.cc
+++ b/src/te/operation/extern_op.cc
@@ -112,8 +112,8 @@ void ExternOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* ana
     if (it == out_dom_map->end()) continue;
     TensorDom& dom = it->second;
     for (size_t i = 0; i < t->shape.size(); ++i) {
-      dom.data[i].emplace_back(IntSet::range(
-          Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i])));
+      dom.data[i].emplace_back(
+          IntSet::FromRange(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i])));
     }
   }
 }
@@ -131,7 +131,7 @@ Stmt ExternOpNode::BuildRealize(const Stage& stage,
     Tensor t = stage->op.output(k);
     Region bounds;
     for (size_t i = 0; i < t->shape.size(); ++i) {
-      bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i]));
+      bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i]));
     }
     realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body);
   }
diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc
index 9be474d..01162cb 100644
--- a/src/te/operation/hybrid_op.cc
+++ b/src/te/operation/hybrid_op.cc
@@ -127,8 +127,8 @@ void HybridOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* ana
     if (it == out_dom_map->end()) continue;
     TensorDom& dom = it->second;
     for (size_t i = 0; i < t->shape.size(); ++i) {
-      dom.data[i].emplace_back(IntSet::range(
-          Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i])));
+      dom.data[i].emplace_back(
+          IntSet::FromRange(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i])));
     }
   }
 }
@@ -152,7 +152,7 @@ Stmt HybridOpNode::BuildRealize(const Stage& stage,
     Tensor t = stage->op.output(k);
     Region bounds;
     for (size_t i = 0; i < t->shape.size(); ++i) {
-      bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i]));
+      bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i]));
     }
     realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body);
   }
@@ -447,7 +447,7 @@ std::vector<IterVar> GatherLoopVars(Stmt stmt) {
   PostOrderVisit(stmt, [&res_](const ObjectRef& node) {
     if (const ForNode* op = node.as<ForNode>()) {
       Var loop_var(op->loop_var);
-      Range dom = Range::make_by_min_extent(op->min, op->extent);
+      Range dom = Range::FromMinExtent(op->min, op->extent);
       res_.push_back(IterVar(dom, loop_var, ForTypeToIterVarType(op->for_type)));
     }
   });
diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc
index cc86d0f..99b0edf 100644
--- a/src/te/operation/scan_op.cc
+++ b/src/te/operation/scan_op.cc
@@ -87,7 +87,7 @@ ScanOp::ScanOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
         // setup spatial axis
         std::ostringstream spatial_name;
         spatial_name << name << ".out" << i << ".i" << k;
-        n->spatial_axis_.push_back(IterVar(Range::make_by_min_extent(0, update[i]->shape[k]),
+        n->spatial_axis_.push_back(IterVar(Range::FromMinExtent(0, update[i]->shape[k]),
                                            Var(spatial_name.str()), kOpaque));
       }
     }
@@ -118,7 +118,7 @@ Array<Tensor> scan(Array<Tensor> init, Array<Tensor> update, Array<Tensor> state
                    Array<Tensor> inputs, std::string name, std::string tag,
                    Map<String, ObjectRef> attrs) {
   IterVar scan_axis =
-      IterVar(Range::make_by_min_extent(init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
+      IterVar(Range::FromMinExtent(init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
               Var(name + ".idx"), kOrdered);
   Operation op = ScanOp(name, tag, attrs, scan_axis, init, update, state_placeholder, inputs);
   Array<Tensor> res;
@@ -174,7 +174,7 @@ void ScanOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analy
     // first dimension, always needed.
     if (init_dom) {
       init_dom->data[0].push_back(
-          IntSet::range(Range::make_by_min_extent(0, this->init[i]->shape[0])));
+          IntSet::FromRange(Range::FromMinExtent(0, this->init[i]->shape[0])));
     }
     if (update_dom) {
       update_dom->data[0].push_back(dom_map.at(this->scan_axis->var.get()));
@@ -210,9 +210,9 @@ void ScanOpNode::GatherBound(const Operation& self,
   CHECK(!out_dom_map->count(this->scan_axis));
   arith::Analyzer analyzer;
   Range sdom = this->scan_axis->dom;
-  Range r = arith::Union(time_dom).cover_range(sdom);
+  Range r = arith::Union(time_dom).CoverRange(sdom);
   (*out_dom_map)[this->scan_axis] =
-      Range::make_by_min_extent(sdom->min, analyzer.Simplify(r->extent + r->min - sdom->min));
+      Range::FromMinExtent(sdom->min, analyzer.Simplify(r->extent + r->min - sdom->min));
   Map<IterVar, PrimExpr> fix_pt = ScanFixPointAnalysis(self);
   // Update for spatial axis.
   size_t sp_idx = 0;
@@ -224,7 +224,7 @@ void ScanOpNode::GatherBound(const Operation& self,
       CHECK(fix_pt.count(sp_ax));
       if (fix_pt[sp_ax].as<tir::IntImmNode>()->value) {
         // fix point, we can slice it.
-        (*out_dom_map)[sp_ax] = arith::Union(d.data[k]).cover_range(sp_ax->dom);
+        (*out_dom_map)[sp_ax] = arith::Union(d.data[k]).CoverRange(sp_ax->dom);
       } else {
         // not a fix point, need to include everything.
         (*out_dom_map)[sp_ax] = sp_ax->dom;
@@ -238,7 +238,7 @@ Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map<IterV
   arith::Analyzer analyzer;
   CHECK_EQ(stage->op.get(), this);
   Range sdom = dom_map.at(this->scan_axis);
-  Range tdom = Range::make_by_min_extent(0, analyzer.Simplify(sdom->extent + sdom->min));
+  Range tdom = Range::FromMinExtent(0, analyzer.Simplify(sdom->extent + sdom->min));
   Stmt ret = body;
   size_t sp_idx = 0;
   for (size_t i = 0; i < update.size(); ++i) {
diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc
index d48bf78..1d72345 100644
--- a/src/te/operation/tensorize.cc
+++ b/src/te/operation/tensorize.cc
@@ -55,12 +55,12 @@ size_t InferTensorizeRegion(const ComputeOpNode* self, const Stage& stage,
     CHECK(vit != dom_map.end());
     const Range& vrange = vit->second;
     if (is_one(vrange->extent)) {
-      up_state[iv] = IntSet::single_point(vrange->min);
+      up_state[iv] = IntSet::SinglePoint(vrange->min);
     } else if (found_point) {
       CHECK(is_zero(vrange->min));
-      up_state[iv] = IntSet::single_point(iv->var);
+      up_state[iv] = IntSet::SinglePoint(iv->var);
     } else {
-      up_state[iv] = IntSet::range(vrange);
+      up_state[iv] = IntSet::FromRange(vrange);
     }
     auto iit = stage->iter_var_attrs.find(iv);
     if (iit != stage->iter_var_attrs.end()) {
@@ -88,7 +88,7 @@ size_t InferTensorizeRegion(const ComputeOpNode* self, const Stage& stage,
   }
   for (IterVar iv : self->root_iter_vars()) {
     IntSet iset = up_state.at(iv);
-    Range iv_range = iset.cover_range(dom_map.at(iv));
+    Range iv_range = iset.CoverRange(dom_map.at(iv));
     (*out_dom)[iv] = iv_range;
     analyzer.Bind(iv->var, iv_range);
     temp_dmap[iv->var.get()] = iset;
@@ -100,7 +100,7 @@ size_t InferTensorizeRegion(const ComputeOpNode* self, const Stage& stage,
     Array<Range> vec;
     const Tensor& t = kv.first;
     for (size_t i = 0; i < t.ndim(); ++i) {
-      Range r = arith::Union(kv.second.data.at(i)).cover_range(none);
+      Range r = arith::Union(kv.second.data.at(i)).CoverRange(none);
       CHECK(r.defined()) << "cannot deduce region of tensorized scope for input " << t;
       vec.push_back(std::move(r));
     }
diff --git a/src/te/schedule/bound.cc b/src/te/schedule/bound.cc
index 099f488..83a1caf 100644
--- a/src/te/schedule/bound.cc
+++ b/src/te/schedule/bound.cc
@@ -144,17 +144,17 @@ void InferRootBound(const Stage& stage, const GraphContext& ctx,
       CHECK(it != rmap->end());
       const Range& vrange = it->second;
       if (is_one(vrange->extent)) {
-        up_state[iv] = IntSet::single_point(vrange->min);
+        up_state[iv] = IntSet::SinglePoint(vrange->min);
       } else if (!NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
         CHECK(is_zero(vrange->min)) << "InferBound requires every leaf iter var's min equals 0, "
                                     << " call schedule.normalize to achieve this. ";
         if (ctx.bind_map.count(iv)) {
-          up_state[iv] = IntSet::single_point(ctx.bind_map.at(iv)->var);
+          up_state[iv] = IntSet::SinglePoint(ctx.bind_map.at(iv)->var);
         } else {
-          up_state[iv] = IntSet::single_point(iv->var);
+          up_state[iv] = IntSet::SinglePoint(iv->var);
         }
       } else {
-        up_state[iv] = IntSet::range(vrange);
+        up_state[iv] = IntSet::FromRange(vrange);
       }
     }
     // Consumer's attach nest
@@ -166,9 +166,9 @@ void InferRootBound(const Stage& stage, const GraphContext& ctx,
       CHECK(is_zero(vrange->min)) << "InferBound requires every leaf iter var's min equals 0, "
                                   << "call schedule.normalize to achieve this.";
       if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
-        relax_set.Set(iv->var, IntSet::range(vrange));
+        relax_set.Set(iv->var, IntSet::FromRange(vrange));
         if (ctx.bind_map.count(iv)) {
-          relax_set.Set(ctx.bind_map.at(iv)->var, IntSet::range(vrange));
+          relax_set.Set(ctx.bind_map.at(iv)->var, IntSet::FromRange(vrange));
         }
       }
     }
@@ -186,16 +186,16 @@ void InferRootBound(const Stage& stage, const GraphContext& ctx,
     for (auto iv : op->root_iter_vars()) {
       Range r;
       if (up_state.count(iv)) {
-        r = up_state.at(iv).cover_range(iv->dom);
+        r = up_state.at(iv).CoverRange(iv->dom);
       } else {
         r = iv->dom;
       }
       if (relax_set.size() != 0) {
         dom_map[iv->var.get()] =
-            IntSet::interval(analyzer.int_set(r->min, relax_set).min(),
+            IntSet::Interval(analyzer.int_set(r->min, relax_set).min(),
                              analyzer.int_set(r->min + r->extent - 1, relax_set).max());
       } else {
-        dom_map[iv->var.get()] = IntSet::range(r);
+        dom_map[iv->var.get()] = IntSet::FromRange(r);
       }
       analyzer.Bind(iv->var, r, true);
     }
@@ -247,8 +247,8 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
     }
   }
   for (auto& p : ret) {
-    ret[p.first] = Range::make_by_min_extent(analyzer.Simplify(p.second->min),
-                                             analyzer.Simplify(p.second->extent));
+    ret[p.first] =
+        Range::FromMinExtent(analyzer.Simplify(p.second->min), analyzer.Simplify(p.second->extent));
   }
   return Map<IterVar, Range>(ret.begin(), ret.end());
 }
diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc
index 55593be..4313be8 100644
--- a/src/te/schedule/message_passing.cc
+++ b/src/te/schedule/message_passing.cc
@@ -132,16 +132,14 @@ void PassDownDomain(const Stage& stage, std::unordered_map<IterVar, Range>* p_st
       };
       if (r->factor.defined()) {
         Update(p_state, r->inner,
-               Range::make_by_min_extent(0, resolve_min_extent_for_split(r->inner, r->factor)),
-               actx);
+               Range::FromMinExtent(0, resolve_min_extent_for_split(r->inner, r->factor)), actx);
         Update(p_state, r->outer,
-               Range::make_by_min_extent(0, ceil_div(range_parent->extent, r->factor)), actx);
+               Range::FromMinExtent(0, ceil_div(range_parent->extent, r->factor)), actx);
       } else {
         Update(p_state, r->outer,
-               Range::make_by_min_extent(0, resolve_min_extent_for_split(r->outer, r->nparts)),
-               actx);
+               Range::FromMinExtent(0, resolve_min_extent_for_split(r->outer, r->nparts)), actx);
         Update(p_state, r->inner,
-               Range::make_by_min_extent(0, ceil_div(range_parent->extent, r->nparts)), actx);
+               Range::FromMinExtent(0, ceil_div(range_parent->extent, r->nparts)), actx);
       }
     } else if (const FuseNode* r = rel.as<FuseNode>()) {
       if (!state.count(r->outer) || !state.count(r->inner)) {
@@ -150,15 +148,15 @@ void PassDownDomain(const Stage& stage, std::unordered_map<IterVar, Range>* p_st
       }
       const Range& range_outer = state.at(r->outer);
       const Range& range_inner = state.at(r->inner);
-      state[r->fused] = Range::make_by_min_extent(0, range_outer->extent * range_inner->extent);
+      state[r->fused] = Range::FromMinExtent(0, range_outer->extent * range_inner->extent);
     } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
       if (!state.count(r->parent)) {
         CHECK(allow_missing);
         continue;
       }
-      Update(p_state, r->rebased, Range::make_by_min_extent(0, state.at(r->parent)->extent), actx);
+      Update(p_state, r->rebased, Range::FromMinExtent(0, state.at(r->parent)->extent), actx);
     } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
-      Update(p_state, s->iter, Range::make_by_min_extent(0, 1), actx);
+      Update(p_state, s->iter, Range::FromMinExtent(0, 1), actx);
     } else {
       LOG(FATAL) << "unknown relation type";
     }
@@ -278,8 +276,8 @@ void PassDownIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
 void PassUpDomain(const SplitNode* s, const std::unordered_map<IterVar, Range>& dom_map,
                   const IntSet& outer, const IntSet& inner, IntSet* parent) {
   if (dom_map.count(s->outer) && dom_map.count(s->inner) && dom_map.count(s->parent) &&
-      outer.match_range(dom_map.at(s->outer)) && inner.match_range(dom_map.at(s->inner))) {
-    *parent = IntSet::range(dom_map.at(s->parent));
+      outer.MatchRange(dom_map.at(s->outer)) && inner.MatchRange(dom_map.at(s->inner))) {
+    *parent = IntSet::FromRange(dom_map.at(s->parent));
     return;
   }
   PrimExpr factor = dom_map.at(s->inner)->extent;
@@ -298,33 +296,33 @@ void PassUpDomain(const FuseNode* s, const std::unordered_map<IterVar, Range>& d
   CHECK(dom_map.count(s->fused));
   arith::Analyzer ana;
 
-  if (fused.match_range(dom_map.at(s->fused))) {
-    *outer = IntSet::range(dom_map.at(s->outer));
-    *inner = IntSet::range(dom_map.at(s->inner));
+  if (fused.MatchRange(dom_map.at(s->fused))) {
+    *outer = IntSet::FromRange(dom_map.at(s->outer));
+    *inner = IntSet::FromRange(dom_map.at(s->inner));
     return;
   }
   PrimExpr outer_min = dom_map.at(s->outer)->min;
   PrimExpr inner_min = dom_map.at(s->inner)->min;
 
-  if (fused.is_single_point()) {
-    PrimExpr value = fused.point_value();
+  if (fused.IsSinglePoint()) {
+    PrimExpr value = fused.PointValue();
     PrimExpr factor = dom_map.at(s->inner)->extent;
     PrimExpr v_outer = indexdiv(value, factor);
     PrimExpr v_inner = indexmod(value, factor);
     if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
     if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
-    *outer = IntSet::single_point(v_outer);
-    *inner = IntSet::single_point(v_inner);
+    *outer = IntSet::SinglePoint(v_outer);
+    *inner = IntSet::SinglePoint(v_inner);
   } else {
     PrimExpr fused_extent = (fused.max() - fused.min() + 1);
     PrimExpr inner_extent = dom_map.at(s->inner)->extent;
-    *outer = IntSet::interval(outer_min + indexdiv(fused.min(), inner_extent),
+    *outer = IntSet::Interval(outer_min + indexdiv(fused.min(), inner_extent),
                               outer_min + indexdiv(fused.max(), inner_extent));
     if (is_zero(ana.Simplify(indexmod(inner_extent, fused_extent))) &&
         is_zero(ana.Simplify(indexmod(fused.min(), fused_extent)))) {
       // fused never spans multiple rows, make a tight bounding box
       // there may be other cases when bounding box could be tightened
-      *inner = IntSet::interval(inner_min + indexmod(fused.min(), inner_extent),
+      *inner = IntSet::Interval(inner_min + indexmod(fused.min(), inner_extent),
                                 inner_min + indexmod(fused.max(), inner_extent));
     } else {  // fused may span multiple rows, use full row widths
       if (!is_zero(ana.Simplify(indexmod(fused_extent, inner_extent))) ||
@@ -332,7 +330,7 @@ void PassUpDomain(const FuseNode* s, const std::unordered_map<IterVar, Range>& d
         LOG(WARNING)
             << "fused and original axes are not aligned, this may cause redundant computations";
       }
-      *inner = IntSet::range(dom_map.at(s->inner));
+      *inner = IntSet::FromRange(dom_map.at(s->inner));
     }
     return;
   }
@@ -341,8 +339,8 @@ void PassUpDomain(const FuseNode* s, const std::unordered_map<IterVar, Range>& d
 void PassUpDomain(const RebaseNode* s, const std::unordered_map<IterVar, Range>& dom_map,
                   const IntSet& rebased, IntSet* parent) {
   CHECK(dom_map.count(s->parent));
-  if (rebased.match_range(dom_map.at(s->rebased))) {
-    *parent = IntSet::range(dom_map.at(s->parent));
+  if (rebased.MatchRange(dom_map.at(s->rebased))) {
+    *parent = IntSet::FromRange(dom_map.at(s->parent));
     return;
   }
   PrimExpr parent_min = dom_map.at(s->parent)->min;
@@ -538,7 +536,7 @@ std::vector<PrimExpr> MakeBoundCheck(const Stage& stage, const Map<IterVar, Rang
 
   // setup domain map for set analysis
   for (const auto& kv : dom_map) {
-    iset_dmap.Set(kv.first->var, IntSet::range(kv.second));
+    iset_dmap.Set(kv.first->var, IntSet::FromRange(kv.second));
   }
 
   for (auto entry : dom_map) {
diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc
index f130cb4..52c6757 100644
--- a/src/te/schedule/schedule_dataflow_rewrite.cc
+++ b/src/te/schedule/schedule_dataflow_rewrite.cc
@@ -370,7 +370,7 @@ Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch, const Array<Tensor>& te
     for (Range r : old_region) {
       PrimExpr min = VarReplacer(vsub2newvar)(r->min);
       PrimExpr extent = VarReplacer(vsub2newvar)(r->extent);
-      region.push_back(Range::make_by_min_extent(min, extent));
+      region.push_back(Range::FromMinExtent(min, extent));
     }
     new_regions.push_back(region);
   }
diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc
index 707d52f..6473c7e 100644
--- a/src/te/schedule/schedule_lang.cc
+++ b/src/te/schedule/schedule_lang.cc
@@ -262,7 +262,7 @@ Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* p_target) {  // NOLINT(*
     // special handle fuse empty array.
     // insert at the outer most loop
     IterVar singleton =
-        IterVar(Range::make_by_min_extent(0, 1), Var("singleton", DataType::Int(32)), kDataPar);
+        IterVar(Range::FromMinExtent(0, 1), Var("singleton", DataType::Int(32)), kDataPar);
     self->relations.push_back(Singleton(singleton));
     Array<IterVar>& all_vars = self->all_iter_vars;
     Array<IterVar>& leaf_vars = self->leaf_iter_vars;
diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
index be1bdd9..75605ad 100644
--- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
+++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
@@ -799,9 +799,9 @@ class TensorCoreIRMutator : public StmtExprMutator {
       }
       CHECK_GE(op->bounds.size(), 2) << "Less than 2 dimensions for matrix " << key->GetNameHint();
       new_bounds.push_back(
-          Range::make_by_min_extent(op->bounds[op->bounds.size() - 2]->min, new_extents[0]));
+          Range::FromMinExtent(op->bounds[op->bounds.size() - 2]->min, new_extents[0]));
       new_bounds.push_back(
-          Range::make_by_min_extent(op->bounds[op->bounds.size() - 1]->min, new_extents[1]));
+          Range::FromMinExtent(op->bounds[op->bounds.size() - 1]->min, new_extents[1]));
 
       return ProducerRealize(op->producer, new_bounds, op->condition, op->body);
     }
diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc
index afc128b..0118228 100644
--- a/src/tir/ir/expr_functor.cc
+++ b/src/tir/ir/expr_functor.cc
@@ -214,7 +214,7 @@ PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) {
     if (min.same_as(r->min) && extent.same_as(r->extent)) {
       return v;
     } else {
-      return IterVar(Range::make_by_min_extent(min, extent), v->var, v->iter_type, v->thread_tag);
+      return IterVar(Range::FromMinExtent(min, extent), v->var, v->iter_type, v->thread_tag);
     }
   };
   Array<IterVar> axis = MutateArray(op->axis, fitervar);
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index abf6438..529380b 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -131,7 +131,7 @@ class StmtMutator::Internal {
       if (min.same_as(r->min) && extent.same_as(r->extent)) {
         return r;
       } else {
-        return Range::make_by_min_extent(min, extent);
+        return Range::FromMinExtent(min, extent);
       }
     };
     return MutateArray(arr, fmutate, self->allow_copy_on_write_);
diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc
index eb9ef32..716ec62 100644
--- a/src/tir/transforms/coproc_sync.cc
+++ b/src/tir/transforms/coproc_sync.cc
@@ -328,7 +328,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor {
       wset.push_back(acc.touched);
     }
     Range none;
-    Range r = arith::Union(wset).cover_range(none);
+    Range r = arith::Union(wset).CoverRange(none);
     CHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer;
     PrimExpr min = r->min;
     PrimExpr extent = r->extent;
diff --git a/src/tir/transforms/inject_prefetch.cc b/src/tir/transforms/inject_prefetch.cc
index 9c27a71..4e4f33b 100644
--- a/src/tir/transforms/inject_prefetch.cc
+++ b/src/tir/transforms/inject_prefetch.cc
@@ -49,15 +49,15 @@ class PrefetchInjector : public StmtMutator {
       Region region;
 
       auto iter_var = loop_nest_.back().get();
-      vectorized_[iter_var] = IntSet::single_point(loop_nest_.back() + op->value);
+      vectorized_[iter_var] = IntSet::SinglePoint(loop_nest_.back() + op->value);
 
       for (Range r : domain) {
         if (!r.defined()) {
           LOG(WARNING) << "Cannot decide prefetch region for " << buffer;
           return op->body;
         }
-        Range res(EvalSet(r, vectorized_).cover_range(none));
-        region.push_back(Range::make_by_min_extent(res->min, res->extent));
+        Range res(EvalSet(r, vectorized_).CoverRange(none));
+        region.push_back(Range::FromMinExtent(res->min, res->extent));
       }
 
       vectorized_.erase(iter_var);
@@ -72,7 +72,7 @@ class PrefetchInjector : public StmtMutator {
     auto& var = op->loop_var;
     loop_nest_.push_back(var);
     if (op->for_type == ForType::Vectorized) {
-      vectorized_[var.get()] = IntSet::interval(op->min, (op->min + op->extent) - 1);
+      vectorized_[var.get()] = IntSet::Interval(op->min, (op->min + op->extent) - 1);
     }
     Stmt ret = StmtMutator::VisitStmt_(op);
     if (op->for_type == ForType::Vectorized) {
diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc
index 1876dfe..d8d784b 100644
--- a/src/tir/transforms/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -190,8 +190,8 @@ class PartitionFinder : public StmtExprVisitor {
     if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return;
 
     const VarNode* var = op->loop_var.get();
-    hint_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
-    relax_map_.insert({var, IntSet::interval(op->min, op->min + op->extent - 1)});
+    hint_map_.insert({var, IntSet::Interval(op->min, op->min + op->extent - 1)});
+    relax_map_.insert({var, IntSet::Interval(op->min, op->min + op->extent - 1)});
     StmtExprVisitor::VisitStmt_(op);
     relax_map_.erase(var);
     hint_map_.erase(var);
@@ -203,7 +203,7 @@ class PartitionFinder : public StmtExprVisitor {
       const IterVarNode* thread_axis = op->node.as<IterVarNode>();
       CHECK(thread_axis);
       const VarNode* var = thread_axis->var.get();
-      IntSet dom = IntSet::range(Range(make_zero(op->value.dtype()), op->value));
+      IntSet dom = IntSet::FromRange(Range(make_zero(op->value.dtype()), op->value));
       hint_map_.insert({var, dom});
       relax_map_.insert({var, dom});
       StmtExprVisitor::VisitStmt_(op);
@@ -222,14 +222,14 @@ class PartitionFinder : public StmtExprVisitor {
         // true. Also find the interval, if exists, in which we can prove that cond is
         // false.
         IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_);
-        if (!interval.is_nothing()) {
+        if (!interval.IsNothing()) {
           // cond is true within interval
           partitions[{cond.get(), true}] = interval;
         }
         PrimExpr inverse_cond = InverseCond(cond);
         if (inverse_cond.defined()) {
           IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
-          if (!interval.is_nothing()) {
+          if (!interval.IsNothing()) {
             // cond is false within interval
             partitions[{cond.get(), false}] = interval;
           }
@@ -342,7 +342,7 @@ class LoopPartitioner : public StmtMutator {
 
     // normal path when loop partition fails
     // normal loop variable can be put into hint map.
-    hint_map_.insert({op->loop_var.get(), IntSet::interval(op->min, op->min + op->extent - 1)});
+    hint_map_.insert({op->loop_var.get(), IntSet::Interval(op->min, op->min + op->extent - 1)});
     Stmt res = StmtMutator::VisitStmt_(op);
     hint_map_.erase(op->loop_var.get());
     return res;
@@ -366,11 +366,11 @@ class LoopPartitioner : public StmtMutator {
     Stmt res;
     if (scope.rank == 1) {
       // threadIdx should be put into relax map, in case of divergence.
-      relax_map_.insert({var.get(), IntSet::interval(make_zero(var.dtype()), op->value - 1)});
+      relax_map_.insert({var.get(), IntSet::Interval(make_zero(var.dtype()), op->value - 1)});
       res = StmtMutator::VisitStmt_(op);
       relax_map_.erase(var.get());
     } else {
-      hint_map_.insert({var.get(), IntSet::interval(make_zero(var.dtype()), op->value - 1)});
+      hint_map_.insert({var.get(), IntSet::Interval(make_zero(var.dtype()), op->value - 1)});
       res = StmtMutator::VisitStmt_(op);
       hint_map_.erase(var.get());
     }
@@ -410,7 +410,7 @@ std::pair<IntSet, std::unordered_set<const Object*>> LoopPartitioner::GetInterva
       }
     }
   }
-  IntSet interval = sets.empty() ? IntSet::nothing() : Intersect(sets);
+  IntSet interval = sets.empty() ? IntSet::Nothing() : Intersect(sets);
   return std::make_pair(interval, cond_set);
 }
 
@@ -464,7 +464,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var
                                    PrimExpr max, Stmt body, bool partition_thread_scope) {
   using namespace arith;
   // include hint of var.
-  hint_map_.insert({var.get(), IntSet::interval(min, max)});
+  hint_map_.insert({var.get(), IntSet::Interval(min, max)});
 
   PartitionFinder finder(var, hint_map_, relax_map_);
   finder(body);
@@ -479,12 +479,12 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var
   // find an interval in which all conditions on var are true
   std::tie(middle_interval, cond_set) =
       GetIntervalAndCondset(finder.partitions, for_interval, true);
-  if (middle_interval.is_nothing()) {
+  if (middle_interval.IsNothing()) {
     // if such interval doesn't exist, find an interval in which all
     // conditions on var are false
     std::tie(middle_interval, cond_set) =
         GetIntervalAndCondset(finder.partitions, for_interval, false);
-    if (middle_interval.is_nothing())
+    if (middle_interval.IsNothing())
       // we couldn't find an interval in which the conditions are provably true or false
       // Therefore, we can't partition the loop based on those conds
       return Stmt();
diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index 72423e0..480c62c 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -315,7 +315,7 @@ class BindVarBoundInfo : public StmtVisitor {
 
   void VisitStmt_(const ForNode* op) final {
     const Var& loop_var = op->loop_var;
-    analyzer_->Bind(loop_var, Range::make_by_min_extent(op->min, op->extent));
+    analyzer_->Bind(loop_var, Range::FromMinExtent(op->min, op->extent));
     StmtVisitor::VisitStmt_(op);
   }
 
@@ -324,7 +324,7 @@ class BindVarBoundInfo : public StmtVisitor {
       IterVar iv = Downcast<IterVar>(op->node);
       CHECK_NE(iv->thread_tag.length(), 0U);
       if (!var_dom_.count(iv->var.get())) {
-        Range dom = Range::make_by_min_extent(0, op->value);
+        Range dom = Range::FromMinExtent(0, op->value);
         var_dom_[iv->var.get()] = dom;
         analyzer_->Bind(iv->var, dom);
       }
diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc
index a14fd02..4d6aa88 100644
--- a/src/tir/transforms/narrow_datatype.cc
+++ b/src/tir/transforms/narrow_datatype.cc
@@ -97,7 +97,7 @@ class DataTypeVisitor final : public StmtExprVisitor {
   }
 
   void VisitStmt_(const ForNode* op) {
-    analyzer_.Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
+    analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
     vextent_[op->loop_var.as<VarNode>()] = op->extent.dtype();
     return StmtExprVisitor::VisitStmt_(op);
   }
@@ -106,7 +106,7 @@ class DataTypeVisitor final : public StmtExprVisitor {
     if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
       IterVar iv = Downcast<IterVar>(op->node);
       CHECK_NE(iv->thread_tag.length(), 0U);
-      analyzer_.Bind(iv->var, Range::make_by_min_extent(0, op->value));
+      analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value));
       vextent_[iv->var.as<VarNode>()] = op->value.dtype();
       StmtExprVisitor::VisitStmt_(op);
     } else {
diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc
index 3c8a934..3088b6b 100644
--- a/src/tir/transforms/simplify.cc
+++ b/src/tir/transforms/simplify.cc
@@ -48,7 +48,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
   Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); }
 
   Stmt VisitStmt_(const ForNode* op) final {
-    analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent));
+    analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
     With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
     With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
     return Parent::VisitStmt_(op);
diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc
index 24f8b75..1914609 100644
--- a/src/tir/transforms/storage_access.cc
+++ b/src/tir/transforms/storage_access.cc
@@ -42,7 +42,7 @@ void StorageAccessVisitor::VisitExpr_(const LoadNode* op) {
     e.threads = env_threads();
     e.buffer = op->buffer_var;
     e.dtype = op->dtype.element_of();
-    e.touched = arith::IntSet::vector(op->index);
+    e.touched = arith::IntSet::Vector(op->index);
     e.type = kRead;
     e.scope = scope;
     curr_stmt_.access.emplace_back(std::move(e));
@@ -62,7 +62,7 @@ void StorageAccessVisitor::VisitStmt_(const StoreNode* op) {
     e.threads = env_threads();
     e.buffer = op->buffer_var;
     e.dtype = op->value.dtype().element_of();
-    e.touched = arith::IntSet::vector(op->index);
+    e.touched = arith::IntSet::Vector(op->index);
     e.type = kWrite;
     e.scope = scope;
     curr_stmt_.access.emplace_back(std::move(e));
@@ -148,7 +148,7 @@ void StorageAccessVisitor::VisitStmt_(const ForNode* op) {
     // relax the touched set to contain all ranges in the loop.
     std::unordered_map<const VarNode*, arith::IntSet> relax_map;
     relax_map[op->loop_var.get()] =
-        arith::IntSet::range(Range::make_by_min_extent(op->min, op->extent));
+        arith::IntSet::FromRange(Range::FromMinExtent(op->min, op->extent));
     for (AccessEntry& e : s.access) {
       if (e.buffer.defined()) {
         CHECK(e.touched.defined());
@@ -199,7 +199,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
       e.threads = env_threads();
       e.dtype = dtype;
       e.buffer = Downcast<Var>(op->args[1]);
-      e.touched = arith::IntSet::range(Range::make_by_min_extent(offset, extent));
+      e.touched = arith::IntSet::FromRange(Range::FromMinExtent(offset, extent));
       e.scope = scope;
       if (flag->value & 1) {
         e.type = kRead;
diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc
index a38be3c..4893748 100644
--- a/src/tir/transforms/thread_storage_sync.cc
+++ b/src/tir/transforms/thread_storage_sync.cc
@@ -183,8 +183,8 @@ class ThreadSyncPlanner : public StorageAccessVisitor {
         // Assumes no race between threads
         // Same index value means no conflicts
         // TODO(tqchen) more standard set based testing.
-        if (e.touched.is_single_point() && x.touched.is_single_point()) {
-          if (ExprDeepEqual()(e.touched.point_value(), x.touched.point_value())) continue;
+        if (e.touched.IsSinglePoint() && x.touched.IsSinglePoint()) {
+          if (ExprDeepEqual()(e.touched.PointValue(), x.touched.PointValue())) continue;
         }
         if (x.double_buffer_write && e.type == kRead && !loop_carry) continue;
         return true;
diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py
index 9919c7b..5e8c947 100644
--- a/tests/python/unittest/test_arith_intset.py
+++ b/tests/python/unittest/test_arith_intset.py
@@ -98,11 +98,11 @@ def test_mod():
 
     floordiv = tvm.te.floordiv
     z = te.var("z")
-    ck.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 3))
+    ck.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 3))
     ck.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)},
               (0, 7))
     ck1 = IntSetChecker()
-    ck1.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 2))
+    ck1.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 2))
     ck1.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, (x*4, x*4+3))
 
 
diff --git a/tests/python/unittest/test_arith_solve_linear_system.py b/tests/python/unittest/test_arith_solve_linear_system.py
index 4f4c5ee..550dfef 100644
--- a/tests/python/unittest/test_arith_solve_linear_system.py
+++ b/tests/python/unittest/test_arith_solve_linear_system.py
@@ -199,8 +199,8 @@ def test_low_rank():
 def test_infer_range():
     x, y = te.var("x"), te.var("y")
     ranges = {
-        x: tvm.ir.Range.make_by_min_extent(-5, 10),
-        y: tvm.ir.Range.make_by_min_extent(0, 10),
+        x: tvm.ir.Range.from_min_extent(-5, 10),
+        y: tvm.ir.Range.from_min_extent(0, 10),
     }
 
     solution = arith.solve_linear_equations([
diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py
index 090acda..95047f5 100644
--- a/tests/python/unittest/test_tir_ir_builder.py
+++ b/tests/python/unittest/test_tir_ir_builder.py
@@ -65,8 +65,8 @@ def test_prefetch():
     with ib.for_range(0, n, name="i") as i:
         ib.emit(
             tvm.tir.Prefetch(A,
-                [tvm.ir.Range.make_by_min_extent(i+1, 2),
-                 tvm.ir.Range.make_by_min_extent(0, 20)]))
+                [tvm.ir.Range.from_min_extent(i+1, 2),
+                 tvm.ir.Range.from_min_extent(0, 20)]))
     body = ib.get()
     assert body.body.bounds[0].extent.value == 2
 
diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py
index 468867a..b0acc6c 100644
--- a/tests/python/unittest/test_tir_transform_storage_flatten.py
+++ b/tests/python/unittest/test_tir_transform_storage_flatten.py
@@ -44,7 +44,7 @@ def test_flatten_prefetch():
     _A= tvm.tir.decl_buffer(A.shape, A.dtype, name = 'A');
     i = te.size_var('i')
     j = te.size_var('j')
-    region = [tvm.ir.Range.make_by_min_extent(i[0], i[1]) for i in [(i, 2), (j, 8), (0, 4)]]
+    region = [tvm.ir.Range.from_min_extent(i[0], i[1]) for i in [(i, 2), (j, 8), (0, 4)]]
     stmt = tvm.tir.Prefetch(_A, region)
 
     func = tvm.te.schedule.SchedulePostProcToPrimFunc(