This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ff9c480  making quantization tweaks (#6731)
ff9c480 is described below

commit ff9c4803913b82085f281c98afbd54feedefeb7c
Author: Thierry Moreau <tmor...@octoml.ai>
AuthorDate: Fri Nov 6 18:20:56 2020 -0800

    making quantization tweaks (#6731)
---
 python/tvm/relay/quantize/_annotate.py | 43 ++++++++++++++++++++++++++++++++++
 src/relay/quantize/realize.cc          | 36 ++++++++++++++++++++++++++++
 2 files changed, 79 insertions(+)

diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py
index b187387..6c395e2 100644
--- a/python/tvm/relay/quantize/_annotate.py
+++ b/python/tvm/relay/quantize/_annotate.py
@@ -175,6 +175,28 @@ def conv2d_rewrite(ref_call, new_args, ctx):
     return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
 
 
+@register_annotate_function("nn.conv1d")
+def conv1d_rewrite(ref_call, new_args, ctx):
+    """Rewrite function for conv1d. Lhs of conv will be quantized to
+    input field, and rhs of conv will be quantized to weight field.
+    Output would be in activation field"""
+    if quantize_context().check_to_skip(ref_call):
+        return None
+
+    lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
+    rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
+
+    if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
+        lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
+
+    assert rhs_kind is None
+    rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
+
+    expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
+
+    return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
+
+
 @register_annotate_function("nn.dense")
 def dense_rewrite(ref_call, new_args, ctx):
     """Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
@@ -289,6 +311,8 @@ register_annotate_function("clip", identity_rewrite)
 register_annotate_function("nn.relu", identity_rewrite)
 register_annotate_function("strided_slice", identity_rewrite)
 register_annotate_function("nn.avg_pool2d", identity_rewrite)
+register_annotate_function("nn.batch_flatten", identity_rewrite)
+register_annotate_function("transpose", identity_rewrite)
 register_annotate_function("annotation.stop_fusion", identity_rewrite)
 
 
@@ -311,6 +335,25 @@ def pool2d_rewrite(ref_call, new_args, ctx):
 register_annotate_function("nn.max_pool2d", pool2d_rewrite)
 
 
+def pool1d_rewrite(ref_call, new_args, ctx):
+    """Rewrite function for max pool1d"""
+    if quantize_context().check_to_skip(ref_call):
+        return None
+
+    expr, x_kind = _get_expr_kind(new_args[0])
+
+    if x_kind is None:
+        return None
+    if x_kind == QAnnotateKind.ACTIVATION:
+        expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT)
+
+    expr = _forward_op(ref_call, [expr])
+    return QAnnotateExpr(expr, QAnnotateKind.INPUT)
+
+
+register_annotate_function("nn.max_pool1d", pool1d_rewrite)
+
+
 @register_annotate_function("annotation.cast_hint")
 def cast_hint_rewrite(ref_call, new_args, ctx):
     """Rewrite function to force cast"""
diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc
index 8db72a3..2716c6e 100644
--- a/src/relay/quantize/realize.cc
+++ b/src/relay/quantize/realize.cc
@@ -234,6 +234,37 @@ Expr Conv2dRealize(const Call& ref_call, const Array<Expr>& new_args, const Obje
 
 RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
 
+Expr Conv1dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
+  const QConfig& cfg = QConfig::Current();
+  CHECK_EQ(new_args.size(), 2);
+  if (!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>()) {
+    return Expr(nullptr);
+  }
+  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
+  CHECK(lhs);
+  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
+  CHECK(rhs);
+
+  Expr ldata = lhs->data;
+  if (lhs->dtype != cfg->dtype_input) {
+    ldata = Cast(ldata, cfg->dtype_input);
+  }
+  Expr rdata = Cast(rhs->data, cfg->dtype_weight);
+
+  const auto ref_attrs = ref_call->attrs.as<Conv1DAttrs>();
+  auto attrs = make_object<Conv1DAttrs>();
+  *attrs = *ref_attrs;
+  DataType out_dtype = cfg->dtype_activation;
+  attrs->out_dtype = out_dtype;
+
+  Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
+  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
+  Expr dom_scale = FoldConstantOpt(mul);
+  return QRealizeIntExpr(ret, dom_scale, out_dtype);
+}
+
+RELAY_REGISTER_OP("nn.conv1d").set_attr<FForwardRewrite>("FQRealizeRewrite", Conv1dRealize);
+
 Expr DenseRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
   const QConfig& cfg = QConfig::Current();
   ICHECK_EQ(new_args.size(), 2);
@@ -449,6 +480,8 @@ RELAY_REGISTER_OP("strided_slice").set_attr<FForwardRewrite>("FQRealizeRewrite",
 RELAY_REGISTER_OP("nn.batch_flatten")
     .set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
 
+RELAY_REGISTER_OP("transpose").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+
 RELAY_REGISTER_OP("annotation.stop_fusion")
     .set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
 
@@ -469,6 +502,9 @@ Expr CastDtypeInputRealize(const Call& ref_call, const Array<Expr>& new_args,
 RELAY_REGISTER_OP("nn.max_pool2d")
     .set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
 
+RELAY_REGISTER_OP("nn.max_pool1d")
+    .set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
+
 Expr AvgPoolRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
   const QConfig& cfg = QConfig::Current();
   ICHECK_EQ(new_args.size(), 1);

Reply via email to