elvin-n commented on a change in pull request #10239:
URL: https://github.com/apache/tvm/pull/10239#discussion_r807959262



##########
File path: src/relay/transforms/fake_quantization_to_integer.cc
##########
@@ -270,8 +293,233 @@ class FakeQuantizationRewriter : public MixedModeMutator {
   const bool hard_fail_;
 };
 
+bool is_op_enabled_for_optional_fq2i(const CallNode* call_node) {
+  const Op op = Downcast<Op>(call_node->op);
+  static auto fqfq = Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
+  static std::unordered_set<Op, tvm::ObjectHash, tvm::ObjectEqual> ops = {
+      Op::Get("reshape"),
+      Op::Get("squeeze"),
+      Op::Get("strided_slice"),
+      Op::Get("transpose"),
+      Op::Get("expand_dims"),
+      Op::Get("nn.max_pool2d"),
+      Op::Get("nn.batch_flatten"),
+      Op::Get("nn.depth_to_space"),
+      Op::Get("max"),
+      Op::Get("min"),
+      Op::Get("nn.avg_pool2d"),
+      Op::Get("nn.global_avg_pool2d"),
+      Op::Get("nn.bias_add"),
+      Op::Get("nn.conv2d"),
+      Op::Get("nn.conv2d_transpose"),
+      Op::Get("nn.dense"),
+      Op::Get("nn.batch_matmul"),
+      Op::Get("split"),
+      Op::Get("clip"),
+      Op::Get("nn.relu"),
+      Op::Get("nn.pad"),
+      Op::Get("broadcast_to"),
+      Op::Get("minimum"),
+      Op::Get("maximum")};
+
+  auto is_enabled = [&](const auto i) { return i == call_node->op; };
+  auto result = std::find_if(std::begin(ops), std::end(ops), is_enabled);
+  return result != ops.end() && fqfq.count(Downcast<Op>(op));
+}
+
+class OptionalSubgraphExtractor : public ExprVisitor {
+ public:
+  const ExprSet GetSubgraph(const Expr& expr) {
+    expr_call_node_ = expr.as<CallNode>();
+    ICHECK(expr_call_node_ != nullptr);
+    ICHECK(is_op_enabled_for_optional_fq2i(expr_call_node_));
+
+    VisitExpr(expr);
+
+    ExprSet subgraph;
+    if (is_fake_quantized_) {
+      for (auto kv : this->visit_counter_) {
+        if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) {
+          if (call_node != expr_call_node_) {
+            subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first)));
+          }
+        }
+      }
+    }
+    return subgraph;
+  }
+  const AffineTypeMap GetAffineTypes() { return affine_types_; }
+  void VisitExpr(const Expr& expr) override {
+    // When looking for fake quantized subgraphs, we only support data-flow regions of the graph,
+    // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we
+    // abort the rewrite.
+    if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr &&
+        expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr &&
+        expr.as<ConstantNode>() == nullptr) {
+      DLOG(INFO) << "FakeQuantizationToInteger found a non - dataflow op inside a fake quantize "
+                    "region, aborting this rewrite";
+      is_fake_quantized_ = false;
+    } else {
+      ExprVisitor::VisitExpr(expr);
+    }
+  }
+
+ protected:
+  void VisitExpr_(const CallNode* call_node) override {
+    if (call_node->op == dequantize_op_) {
+      const auto* attrs = call_node->attrs.as<qnn::DequantizeAttrs>();
+      ICHECK(attrs != nullptr);
+
+      affine_types_.Set(
+          GetRef<Expr>(call_node),
+          TensorAffineType(
+              call_node->args[1], call_node->args[2],
+              tvm::relay::transform::InferTypeLocal(call_node->args[0]).as<TensorTypeNode>()->dtype,
+              attrs->axis));
+    } else if (call_node == expr_call_node_) {
+      for (auto arg : call_node->args) {
+        VisitExpr(arg);
+      }
+    } else {
+      // run normally on everything else.
+      ExprVisitor::VisitExpr_(call_node);
+    }
+  }
+
+  const Op dequantize_op_ = Op::Get("qnn.dequantize");
+  bool is_fake_quantized_ = true;
+  AffineTypeMap affine_types_;
+  const CallNode* expr_call_node_ = nullptr;
+};
+
+class OptionalSubgraphMutator : public ExprMutator {
+ public:
+  OptionalSubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail)
+      : subgraph_(subgraph), affine_types_(affine_types), hard_fail_(hard_fail) {}
+
+  Expr MutateSubgraph(const Expr& expr) {
+    if (subgraph_.size() == 0) {
+      return expr;
+    }
+
+    quantize_node_ = expr.as<CallNode>();
+    ICHECK(quantize_node_);
+    ICHECK(is_op_enabled_for_optional_fq2i(quantize_node_));
+
+    for (auto node : subgraph_) {
+      const Op op = Downcast<Op>(node.as<CallNode>()->op);
+
+      if (node.as<CallNode>()->op != dequantize_op_) {
+        // Only modify the subgraph if we have translation
+        // rules for every op
+        if (hard_fail_) {
+          LOG(FATAL) << "Found no rewrite rule for " << AsText(op, false) << std::endl;
+        } else {
+          DLOG(INFO) << "Found no rewrite rule for " << AsText(op, false) << std::endl;
+          return expr;
+        }
+      }
+    }
+    try {
+      return Mutate(expr);
+    } catch (std::exception& e) {
+      if (hard_fail_) {
+        throw e;
+      } else {
+        DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping" << expr << std::endl;
+        return expr;
+      }
+    }
+  }
+
+ protected:
+  Expr VisitExpr_(const CallNode* call_node) {
+    Expr out;
+    static auto fqfq =
+        Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
+
+    Op op = Downcast<Op>(call_node->op);
+    if (fqfq.count(op)) {
+      Expr expr;
+      if (op == dequantize_op_) {
+        expr = GetRef<Expr>(call_node);
+      } else {
+        expr = ExprMutator::VisitExpr_(call_node);
+      }
+      // Call the rewrite
+      Array<ObjectRef> vals = fqfq[op](expr, affine_types_);
+      // Save the outputs of the rewrite
+      ICHECK(vals.size() == 2)
+          << "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for "
+          << AsText(op, false);
+      out = Downcast<Expr>(vals[0]);
+
+      affine_types_.Set(out, Downcast<AffineType>(vals[1]));
+
+      if (call_node == quantize_node_) {
+        out = qnn::MakeDequantize(out, vals[1].as<TensorAffineTypeNode>()->scale,
+                                  vals[1].as<TensorAffineTypeNode>()->zero_point,
+                                  vals[1].as<TensorAffineTypeNode>()->axis);
+      }
+    } else {
+      ICHECK(false) << "When rewriting a fake quantized graph, found an invalid node "
+                    << AsText(GetRef<Expr>(call_node), false);
+    }
+    return out;
+  }
+
+  Expr VisitExpr_(const TupleNode* node) {
+    Expr expr = ExprMutator::VisitExpr_(node);
+    auto new_node = expr.as<TupleNode>();
+    Array<TensorAffineType> types;
+    for (Expr field : new_node->fields) {
+      ICHECK(affine_types_[field].as<TensorAffineTypeNode>());
+      types.push_back(Downcast<TensorAffineType>(affine_types_[field]));
+    }
+    affine_types_.Set(expr, TupleAffineType(types));
+    return expr;
+  }
+
+  Expr VisitExpr_(const TupleGetItemNode* node) {
+    Expr expr = ExprMutator::VisitExpr_(node);
+    auto tuple_type = affine_types_[expr.as<TupleGetItemNode>()->tuple].as<TupleAffineTypeNode>();
+    affine_types_.Set(expr, tuple_type->types[node->index]);
+    return expr;
+  }
+
+  ExprSet subgraph_;
+  AffineTypeMap affine_types_;
+  const bool hard_fail_;
+  const Op dequantize_op_ = Op::Get("qnn.dequantize");
+  const CallNode* quantize_node_ = nullptr;
+};
+
+class OptionalFakeQuantizationRewriter : public MixedModeMutator {
+ public:
+  explicit OptionalFakeQuantizationRewriter(bool hard_fail) : hard_fail_(hard_fail) {}
+
+ protected:
+  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
+    if (const CallNode* call_node = post.as<CallNode>()) {
+      const Op op = Downcast<Op>(call_node->op);
+      if (is_op_enabled_for_optional_fq2i(call_node)) {
+        OptionalSubgraphExtractor extractor;
+        ExprSet subgraph = extractor.GetSubgraph(post);
+        AffineTypeMap affine_types = extractor.GetAffineTypes();
+        Expr out = OptionalSubgraphMutator(subgraph, affine_types, hard_fail_).MutateSubgraph(post);
+        return out;
+      }
+    }
+    return post;
+  }
+  const bool hard_fail_;
+};
+
 Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod, bool hard_fail) {
-  return FakeQuantizationRewriter(hard_fail).Mutate(expr);
+  auto fq_expr = FakeQuantizationRewriter(hard_fail).Mutate(expr);
+  auto fq_inferred_expr = tvm::relay::InferType(fq_expr);
+  auto ofq_expr = OptionalFakeQuantizationRewriter(hard_fail).Mutate(fq_inferred_expr);
+  return ofq_expr;
 }

Review comment:
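       Let's add an `enableQAT` parameter to the FakeQuantizationToInteger pass whose default value does not run the QAT transformation. That way the new optional rewrite (OptionalFakeQuantizationRewriter) runs only when explicitly requested, and the pass keeps its current behavior by default.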




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to