This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 372d737 [RELAY] Refactor FoldConstant to skip TNonComputationalOps (#6720)
372d737 is described below
commit 372d7374d221fb98f7e7fe5d9d5c937059a35515
Author: Lily Orth-Smith <[email protected]>
AuthorDate: Sat Oct 24 00:23:50 2020 -0700
[RELAY] Refactor FoldConstant to skip TNonComputationalOps (#6720)
* add TNonComputational to qnn ops and change FoldConstant
* remove comments
* check if op in nonComputational map
* forgot to mark device_copy op as TNonComputational
* hacky fix to fuseops pass
* fix typo
* manually skip device_copy in fold_constant
* Update src/relay/transforms/fold_constant.cc
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
---
src/relay/qnn/op/concatenate.cc | 1 +
src/relay/qnn/op/convolution.cc | 1 +
src/relay/qnn/op/dense.cc | 1 +
src/relay/qnn/op/dequantize.cc | 1 +
src/relay/qnn/op/op_common.h | 1 +
src/relay/qnn/op/quantize.cc | 1 +
src/relay/qnn/op/requantize.cc | 1 +
src/relay/transforms/fold_constant.cc | 9 ++++++---
8 files changed, 13 insertions(+), 3 deletions(-)
diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc
index 29ecf45..88d2ecc 100644
--- a/src/relay/qnn/op/concatenate.cc
+++ b/src/relay/qnn/op/concatenate.cc
@@ -207,6 +207,7 @@ RELAY_REGISTER_OP("qnn.concatenate")
"The quantization zero_point of the output tensor.")
.set_support_level(11)
.add_type_rel("QnnConcatenate", QnnConcatenateRel)
+ .set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
QnnConcatenateLayout);
diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc
index b2b6b09..73ee456 100644
--- a/src/relay/qnn/op/convolution.cc
+++ b/src/relay/qnn/op/convolution.cc
@@ -733,6 +733,7 @@ operator to understand how to scale back the int32 output to (u)int8.
"The quantization zero_point of the weight tensor.")
.set_support_level(11)
.add_type_rel("QnnConv2D", QnnConv2DRel)
+ .set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnConv2DCanonicalize)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
QnnConvInferCorrectLayout);
diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc
index 3cfc418..e1cbfaf 100644
--- a/src/relay/qnn/op/dense.cc
+++ b/src/relay/qnn/op/dense.cc
@@ -189,6 +189,7 @@ RELAY_REGISTER_OP("qnn.dense")
"The quantization zero_point of the weight tensor.")
.set_support_level(11)
.add_type_rel("QDense", QnnDenseRel)
+ .set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnDenseCanonicalize);
TVM_REGISTER_GLOBAL("relay.qnn.op._make.dense").set_body_typed(MakeQuantizedDense);
diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc
index f0c139c..0a81f3f 100644
--- a/src/relay/qnn/op/dequantize.cc
+++ b/src/relay/qnn/op/dequantize.cc
@@ -136,6 +136,7 @@ The input is always quantized (int8, uint8) and will be converted to float32 given scale and zero_point.
     .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.")
.set_support_level(11)
.add_type_rel("Dequantize", DequantizeRel)
+ .set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", DequantizeQnnCanonicalize);
TVM_REGISTER_GLOBAL("relay.qnn.op._make.dequantize").set_body_typed(MakeDequantize);
diff --git a/src/relay/qnn/op/op_common.h b/src/relay/qnn/op/op_common.h
index e99c11b..3ca8f64 100644
--- a/src/relay/qnn/op/op_common.h
+++ b/src/relay/qnn/op/op_common.h
@@ -215,6 +215,7 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
       .add_argument("output_scale", "Tensor", "The scale of the output tensor.")           \
       .add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \
       .add_type_rel("QnnBroadcast", QnnBroadcastRel)                                       \
+      .set_attr<TNonComputational>("TNonComputational", true)                              \
       .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnBinaryBroadcastLayout)
} // namespace qnn
diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc
index 1b5cb5e..0784791 100644
--- a/src/relay/qnn/op/quantize.cc
+++ b/src/relay/qnn/op/quantize.cc
@@ -150,6 +150,7 @@ scale and zero point.
"The quantization zero_point of the output tensor.")
.set_support_level(11)
.add_type_rel("Quantize", QuantizeRel)
+ .set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QuantizeQnnCanonicalize);
TVM_REGISTER_GLOBAL("relay.qnn.op._make.quantize").set_body_typed(MakeQuantize);
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index ea87855..3572a39 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -324,6 +324,7 @@ Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input)
"The quantization zero_point of the output tensor.")
.set_support_level(11)
.add_type_rel("Requantize", RequantizeRel)
+ .set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", RequantizeQnnCanonicalize)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
RequantizeInferCorrectLayout);
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index 1de690d..4a739dd 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -151,9 +151,12 @@ class ConstantFolder : public MixedModeMutator {
}
     // We should think about potentially constant evaluation over these ops too.
-    if (call->op == invoke_tvm_op_ || call->op == shape_func_op_ || call->op == alloc_tensor_op_ ||
-        call->op == alloc_storage_op_ || call->op == device_copy_op_) {
-      return GetRef<Call>(call);
+    static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");
+    if (const auto* call_node = call->op.as<OpNode>()) {
+      Op op = GetRef<Op>(call_node);
+      if ((fnoncomputational.count(op) && fnoncomputational[op]) || (call->op == device_copy_op_)) {
+        return GetRef<Call>(call);
+      }
     }
bool all_const_args = true;