You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/03/29 06:07:10 UTC

[GitHub] [tvm] hzfan commented on a change in pull request #7759: [ARITH] normalize iter affine map expr to PrimExpr

hzfan commented on a change in pull request #7759:
URL: https://github.com/apache/tvm/pull/7759#discussion_r603032508



##########
File path: src/arith/iter_affine_map.cc
##########
@@ -1028,5 +1028,61 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
   }
 }
 
+/*! * \brief Given an IterVarMapExpr, transform it to normal PrimExpr. */
+class IterMapToExprNormalizer {
+ public:
+  explicit IterMapToExprNormalizer(Analyzer* analyzer) : analyzer_(analyzer) {}
+
+  PrimExpr Convert(const IterMapExpr& expr) {
+    if (const auto* op = expr.as<IterSplitExprNode>()) {
+      return ConvertIterSplitExpr(GetRef<IterSplitExpr>(op));
+    } else if (const auto* op = expr.as<IterSumExprNode>()) {
+      return ConvertIterSumExpr(GetRef<IterSumExpr>(op));
+    } else {
+      ICHECK(expr.defined());
+      LOG(FATAL) << "Unknown IterMapExpr type " << expr->GetTypeKey();
+      return 0;
+    }
+  }
+
+  PrimExpr ConvertIterSumExpr(const IterSumExpr& expr) {
+    PrimExpr res = 0;
+    for (const IterSplitExpr& arg : expr->args) {
+      res += ConvertIterSplitExpr(arg);
+    }
+    res += expr->base;
+    return res;
+  }
+
+  PrimExpr ConvertIterSplitExpr(const IterSplitExpr& expr) {
+    PrimExpr source;
+    if (const auto* op = expr->source->source.as<VarNode>()) {
+      source = GetRef<Var>(op);
+    } else if (const auto& op = expr->source->source.as<IterSumExprNode>()) {

Review comment:
       ```suggestion
       } else if (const auto* op = expr->source->source.as<IterSumExprNode>()) {
   ```

##########
File path: src/arith/iter_affine_map.cc
##########
@@ -1028,5 +1028,61 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
   }
 }
 
+/*! * \brief Given an IterVarMapExpr, transform it to normal PrimExpr. */
+class IterMapToExprNormalizer {
+ public:
+  explicit IterMapToExprNormalizer(Analyzer* analyzer) : analyzer_(analyzer) {}
+
+  PrimExpr Convert(const IterMapExpr& expr) {
+    if (const auto* op = expr.as<IterSplitExprNode>()) {
+      return ConvertIterSplitExpr(GetRef<IterSplitExpr>(op));
+    } else if (const auto* op = expr.as<IterSumExprNode>()) {
+      return ConvertIterSumExpr(GetRef<IterSumExpr>(op));
+    } else {
+      ICHECK(expr.defined());
+      LOG(FATAL) << "Unknown IterMapExpr type " << expr->GetTypeKey();
+      return 0;
+    }
+  }
+
+  PrimExpr ConvertIterSumExpr(const IterSumExpr& expr) {
+    PrimExpr res = 0;
+    for (const IterSplitExpr& arg : expr->args) {
+      res += ConvertIterSplitExpr(arg);
+    }
+    res += expr->base;
+    return res;
+  }
+
+  PrimExpr ConvertIterSplitExpr(const IterSplitExpr& expr) {
+    PrimExpr source;
+    if (const auto* op = expr->source->source.as<VarNode>()) {
+      source = GetRef<Var>(op);
+    } else if (const auto& op = expr->source->source.as<IterSumExprNode>()) {
+      source = ConvertIterSumExpr(GetRef<IterSumExpr>(op));
+    }

Review comment:
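       Do we need to LOG(FATAL) when the source is neither a Var nor an IterSumExpr, i.e. it falls into neither branch of the if / else-if? Otherwise `source` silently stays an undefined PrimExpr.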




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org