tqchen commented on a change in pull request #7045:
URL: https://github.com/apache/tvm/pull/7045#discussion_r551976566
##########
File path: src/arith/canonical_simplify.cc
##########
@@ -1071,6 +1208,33 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
return ret;
}
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) {
+ if (!IsIndexType(op->dtype)) {
+ return Rewriter::VisitExpr_(op);
+ }
+ // normalize
+ PrimExpr value = this->CanonicalMutate(op->value);
+ PrimExpr ret;
+ // PushCastToChildren
+ if (value.as<SumExprNode>()) {
+ SumExpr se = Downcast<SumExpr>(value);
+ if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
+ se.CopyOnWrite()->PushCastToChildren(op->dtype);
+ ret = se;
Review comment:
Consider returning directly here.
##########
File path: src/arith/canonical_simplify.cc
##########
@@ -1071,6 +1208,33 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
return ret;
}
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) {
+ if (!IsIndexType(op->dtype)) {
+ return Rewriter::VisitExpr_(op);
+ }
+ // normalize
+ PrimExpr value = this->CanonicalMutate(op->value);
+ PrimExpr ret;
+ // PushCastToChildren
+ if (value.as<SumExprNode>()) {
+ SumExpr se = Downcast<SumExpr>(value);
+ if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
+ se.CopyOnWrite()->PushCastToChildren(op->dtype);
+ ret = se;
+ }
+ } else if (value.as<SplitExprNode>()) {
+ SplitExpr se = Downcast<SplitExpr>(value);
+ if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
+ se.CopyOnWrite()->PushCastToChildren(op->dtype);
+ ret = se;
Review comment:
Consider returning directly here.
##########
File path: src/arith/canonical_simplify.cc
##########
@@ -1071,6 +1208,33 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
return ret;
}
+PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) {
+ if (!IsIndexType(op->dtype)) {
+ return Rewriter::VisitExpr_(op);
+ }
+ // normalize
+ PrimExpr value = this->CanonicalMutate(op->value);
+ PrimExpr ret;
+ // PushCastToChildren
+ if (value.as<SumExprNode>()) {
+ SumExpr se = Downcast<SumExpr>(value);
+ if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
+ se.CopyOnWrite()->PushCastToChildren(op->dtype);
+ ret = se;
+ }
+ } else if (value.as<SplitExprNode>()) {
+ SplitExpr se = Downcast<SplitExpr>(value);
+ if (se->CanPushCastToChildren(op->dtype, analyzer_)) {
+ se.CopyOnWrite()->PushCastToChildren(op->dtype);
+ ret = se;
+ }
+ }
+ if (!ret.defined()) {
+ ret = Rewriter::VisitExpr_(op);
Review comment:
Consider `return Rewriter::VisitExpr_(op);` here instead of assigning to `ret`.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]