vinx13 commented on a change in pull request #8325:
URL: https://github.com/apache/tvm/pull/8325#discussion_r665788576



##########
File path: src/relay/transforms/fold_scale_axis.cc
##########
@@ -735,25 +733,41 @@ class BackwardTransformer : public ObjectRef {
   }
   using ContainerType = BackwardTransformerNode;
 };
-
-Expr BackwardTransformerNode::Transform(const CallNode* call_node, Message 
message, Expr scale) {
-  static const auto& ftransform = 
Op::GetAttrMap<FBackwardTransform>("FScaleAxisBackwardTransform");
-  auto f = ftransform.get(call_node->op, nullptr);
-  if (f != nullptr) {
+/*!

Review comment:
       nit: add a blank line between the closing `};` of the class and this doc comment

##########
File path: src/relay/transforms/fold_scale_axis.cc
##########
@@ -735,25 +733,41 @@ class BackwardTransformer : public ObjectRef {
   }
   using ContainerType = BackwardTransformerNode;
 };
-
-Expr BackwardTransformerNode::Transform(const CallNode* call_node, Message 
message, Expr scale) {
-  static const auto& ftransform = 
Op::GetAttrMap<FBackwardTransform>("FScaleAxisBackwardTransform");
-  auto f = ftransform.get(call_node->op, nullptr);
-  if (f != nullptr) {
+/*!
+ * \brief Transform the expr to consider the scaling.
+ *
+ * \param expr The input expression.
+ * \param message The axes to scale.
+ * \param scale The scale applied to the axes.
+ * \return The result of transformation.
+ */
+Expr BackwardTransformerNode::Transform(const Expr& expr, Message message, 
Expr scale) {
+  if (const CallNode* call_node = expr.as<CallNode>()) {
+    static const auto& ftransform =
+        Op::GetAttrMap<FBackwardTransform>("FScaleAxisBackwardTransform");
+    auto f = ftransform.get(call_node->op, nullptr);
     const Call call = GetRef<Call>(call_node);
-    const auto it = memo_.find(call);
-    if (it != memo_.end()) {
-      return it->second;
+    // ignore if there is a message
+    if (!message.defined()) {
+      const auto it = memo_.find(call);
+      if (it != memo_.end()) {
+        return it->second;
+      }
+    }
+    Expr new_expr = NullValue<Expr>();
+    if (f != nullptr) {
+      new_expr = f(call, message, scale, GetRef<BackwardTransformer>(this));
+    } else {
+      ICHECK(!message.defined()) << "outstanding scale";
+      new_expr = NormalCallTransform(call.operator->());
     }
-    Expr new_expr = f(GetRef<Call>(call_node), message, scale, 
GetRef<BackwardTransformer>(this));
     memo_[call] = new_expr;
     return new_expr;
   } else {
     ICHECK(!message.defined()) << "outstanding scale";
-    return NormalCallTransform(call_node);
+    return this->Mutate(expr);
   }
 }
-

Review comment:
       nit: keep the blank line here, after the function's closing brace




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to