hzfan commented on a change in pull request #7759:
URL: https://github.com/apache/tvm/pull/7759#discussion_r603032508
##########
File path: src/arith/iter_affine_map.cc
##########
@@ -1028,5 +1028,61 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
}
}
+/*! * \brief Given an IterVarMapExpr, transform it to normal PrimExpr. */
+class IterMapToExprNormalizer {
+ public:
+ explicit IterMapToExprNormalizer(Analyzer* analyzer) : analyzer_(analyzer) {}
+
+ PrimExpr Convert(const IterMapExpr& expr) {
+ if (const auto* op = expr.as<IterSplitExprNode>()) {
+ return ConvertIterSplitExpr(GetRef<IterSplitExpr>(op));
+ } else if (const auto* op = expr.as<IterSumExprNode>()) {
+ return ConvertIterSumExpr(GetRef<IterSumExpr>(op));
+ } else {
+ ICHECK(expr.defined());
+ LOG(FATAL) << "Unknown IterMapExpr type " << expr->GetTypeKey();
+ return 0;
+ }
+ }
+
+ PrimExpr ConvertIterSumExpr(const IterSumExpr& expr) {
+ PrimExpr res = 0;
+ for (const IterSplitExpr& arg : expr->args) {
+ res += ConvertIterSplitExpr(arg);
+ }
+ res += expr->base;
+ return res;
+ }
+
+ PrimExpr ConvertIterSplitExpr(const IterSplitExpr& expr) {
+ PrimExpr source;
+ if (const auto* op = expr->source->source.as<VarNode>()) {
+ source = GetRef<Var>(op);
+ } else if (const auto& op = expr->source->source.as<IterSumExprNode>()) {
Review comment:
```suggestion
} else if (const auto* op = expr->source->source.as<IterSumExprNode>()) {
```
##########
File path: src/arith/iter_affine_map.cc
##########
@@ -1028,5 +1028,61 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
}
}
+/*! * \brief Given an IterVarMapExpr, transform it to normal PrimExpr. */
+class IterMapToExprNormalizer {
+ public:
+ explicit IterMapToExprNormalizer(Analyzer* analyzer) : analyzer_(analyzer) {}
+
+ PrimExpr Convert(const IterMapExpr& expr) {
+ if (const auto* op = expr.as<IterSplitExprNode>()) {
+ return ConvertIterSplitExpr(GetRef<IterSplitExpr>(op));
+ } else if (const auto* op = expr.as<IterSumExprNode>()) {
+ return ConvertIterSumExpr(GetRef<IterSumExpr>(op));
+ } else {
+ ICHECK(expr.defined());
+ LOG(FATAL) << "Unknown IterMapExpr type " << expr->GetTypeKey();
+ return 0;
+ }
+ }
+
+ PrimExpr ConvertIterSumExpr(const IterSumExpr& expr) {
+ PrimExpr res = 0;
+ for (const IterSplitExpr& arg : expr->args) {
+ res += ConvertIterSplitExpr(arg);
+ }
+ res += expr->base;
+ return res;
+ }
+
+ PrimExpr ConvertIterSplitExpr(const IterSplitExpr& expr) {
+ PrimExpr source;
+ if (const auto* op = expr->source->source.as<VarNode>()) {
+ source = GetRef<Var>(op);
+ } else if (const auto& op = expr->source->source.as<IterSumExprNode>()) {
+ source = ConvertIterSumExpr(GetRef<IterSumExpr>(op));
+ }
Review comment:
Do we need to LOG(FATAL) if the source matches neither branch (i.e. it is neither a Var nor an IterSumExpr)? Otherwise `source` is left undefined.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org