AndrewZhaoLuo commented on a change in pull request #10239:
URL: https://github.com/apache/tvm/pull/10239#discussion_r808683849



##########
File path: src/relay/transforms/fake_quantization_to_integer.cc
##########
@@ -270,16 +297,252 @@ class FakeQuantizationRewriter : public MixedModeMutator {
   const bool hard_fail_;
 };
 
-Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod, bool hard_fail) {
-  return FakeQuantizationRewriter(hard_fail).Mutate(expr);
+/* Checks if the operation to convert QAT pass is enabled.
+ * The following conditions must be satisfied:
+ * 1. operations registered for FTVMFakeQuantizationToInteger;
+ * 2. Unary operators or operators with with the TensorAffineType calculated during
+ * FTVMFakeQuantizationToInteger conversion;
+ * 3. Not one of the "key" operations: requantize,quantize and dequantize(they are at the boundaries
+ * of regions defined to be quantized).
+ */
+bool is_op_enabled_for_optional_fq2i(const CallNode* call_node) {
+  const Op op = Downcast<Op>(call_node->op);
+  static auto fqfq = Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
+  static std::unordered_set<Op, tvm::ObjectHash, tvm::ObjectEqual> ops = {
+      Op::Get("broadcast_to"),
+      Op::Get("clip"),
+      Op::Get("expand_dims"),
+      Op::Get("max"),
+      Op::Get("maximum"),
+      Op::Get("min"),
+      Op::Get("minimum"),
+      Op::Get("nn.avg_pool2d"),
+      Op::Get("nn.batch_flatten"),
+      Op::Get("nn.batch_matmul"),
+      Op::Get("nn.bias_add"),
+      Op::Get("nn.conv2d"),
+      Op::Get("nn.conv2d_transpose"),
+      Op::Get("nn.dense"),
+      Op::Get("nn.depth_to_space"),
+      Op::Get("nn.global_avg_pool2d"),
+      Op::Get("nn.max_pool2d"),
+      Op::Get("nn.pad"),
+      Op::Get("nn.relu"),
+      Op::Get("reshape"),
+      Op::Get("split"),
+      Op::Get("squeeze"),
+      Op::Get("strided_slice"),
+      Op::Get("transpose")};
+
+  auto is_enabled = [&](const auto i) { return i == call_node->op; };
+  auto result = std::find_if(std::begin(ops), std::end(ops), is_enabled);

Review comment:
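       Can you just use `ops.find(...)` from `std::unordered_set`? `std::find_if` walks the set element by element, so it is O(n), while the set's own `find` is a hash lookup that is O(1) on average.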
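
       A minimal sketch of the difference, not the PR's actual code. To keep it compilable without TVM, `std::string` stands in for `tvm::Op`, and the names `EnabledOps`, `IsEnabledLinear` and `IsEnabledHashed` are hypothetical. In the PR itself the lookup key would presumably be the `op` already obtained via `Downcast<Op>(call_node->op)`.

```cpp
#include <algorithm>
#include <string>
#include <unordered_set>

// Hypothetical stand-in for the set of enabled ops. The PR stores tvm::Op
// handles; plain strings are used here so the sketch compiles on its own.
static const std::unordered_set<std::string>& EnabledOps() {
  static const std::unordered_set<std::string> ops = {
      "clip", "nn.conv2d", "nn.dense", "reshape"};
  return ops;
}

// Linear scan, as in the diff: std::find_if visits elements one by one, O(n).
bool IsEnabledLinear(const std::string& name) {
  const auto& ops = EnabledOps();
  return std::find_if(ops.begin(), ops.end(),
                      [&](const std::string& op) { return op == name; }) != ops.end();
}

// Hash lookup, as suggested in the review: O(1) on average.
bool IsEnabledHashed(const std::string& name) {
  const auto& ops = EnabledOps();
  return ops.find(name) != ops.end();
}
```

       Both functions return the same answer for any key; only the lookup cost differs. With roughly two dozen ops the gap is small in practice, but the member `find` is also shorter and states the intent directly.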