This is an automated email from the ASF dual-hosted git repository. wuwei pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new ff9c480  making quantization tweaks (#6731)
ff9c480 is described below

commit ff9c4803913b82085f281c98afbd54feedefeb7c
Author: Thierry Moreau <tmor...@octoml.ai>
AuthorDate: Fri Nov 6 18:20:56 2020 -0800

    making quantization tweaks (#6731)
---
 python/tvm/relay/quantize/_annotate.py | 43 ++++++++++++++++++++++++++++++++++
 src/relay/quantize/realize.cc          | 36 ++++++++++++++++++++++++++++
 2 files changed, 79 insertions(+)

diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py
index b187387..6c395e2 100644
--- a/python/tvm/relay/quantize/_annotate.py
+++ b/python/tvm/relay/quantize/_annotate.py
@@ -175,6 +175,28 @@ def conv2d_rewrite(ref_call, new_args, ctx):
     return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
 
 
+@register_annotate_function("nn.conv1d")
+def conv1d_rewrite(ref_call, new_args, ctx):
+    """Rewrite function for conv1d. Lhs of conv will be quantized to
+    input field, and rhs of conv will be quantized to weight field.
+    Output would be in activation field"""
+    if quantize_context().check_to_skip(ref_call):
+        return None
+
+    lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
+    rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
+
+    if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
+        lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
+
+    assert rhs_kind is None
+    rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
+
+    expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
+
+    return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
+
+
 @register_annotate_function("nn.dense")
 def dense_rewrite(ref_call, new_args, ctx):
     """Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
@@ -289,6 +311,8 @@ register_annotate_function("clip", identity_rewrite)
 register_annotate_function("nn.relu", identity_rewrite)
 register_annotate_function("strided_slice", identity_rewrite)
 register_annotate_function("nn.avg_pool2d", identity_rewrite)
+register_annotate_function("nn.batch_flatten", identity_rewrite)
+register_annotate_function("transpose", identity_rewrite)
 register_annotate_function("annotation.stop_fusion", identity_rewrite)
 
 
@@ -311,6 +335,25 @@ def pool2d_rewrite(ref_call, new_args, ctx):
 register_annotate_function("nn.max_pool2d", pool2d_rewrite)
 
 
+def pool1d_rewrite(ref_call, new_args, ctx):
+    """Rewrite function for max pool1d"""
+    if quantize_context().check_to_skip(ref_call):
+        return None
+
+    expr, x_kind = _get_expr_kind(new_args[0])
+
+    if x_kind is None:
+        return None
+    if x_kind == QAnnotateKind.ACTIVATION:
+        expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT)
+
+    expr = _forward_op(ref_call, [expr])
+    return QAnnotateExpr(expr, QAnnotateKind.INPUT)
+
+
+register_annotate_function("nn.max_pool1d", pool1d_rewrite)
+
+
 @register_annotate_function("annotation.cast_hint")
 def cast_hint_rewrite(ref_call, new_args, ctx):
     """Rewrite function to force cast"""
diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc
index 8db72a3..2716c6e 100644
--- a/src/relay/quantize/realize.cc
+++ b/src/relay/quantize/realize.cc
@@ -234,6 +234,37 @@ Expr Conv2dRealize(const Call& ref_call, const Array<Expr>& new_args, const Obje
 
 RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
 
+Expr Conv1dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
+  const QConfig& cfg = QConfig::Current();
+  CHECK_EQ(new_args.size(), 2);
+  if (!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>()) {
+    return Expr(nullptr);
+  }
+  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
+  CHECK(lhs);
+  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
+  CHECK(rhs);
+
+  Expr ldata = lhs->data;
+  if (lhs->dtype != cfg->dtype_input) {
+    ldata = Cast(ldata, cfg->dtype_input);
+  }
+  Expr rdata = Cast(rhs->data, cfg->dtype_weight);
+
+  const auto ref_attrs = ref_call->attrs.as<Conv1DAttrs>();
+  auto attrs = make_object<Conv1DAttrs>();
+  *attrs = *ref_attrs;
+  DataType out_dtype = cfg->dtype_activation;
+  attrs->out_dtype = out_dtype;
+
+  Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
+  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
+  Expr dom_scale = FoldConstantOpt(mul);
+  return QRealizeIntExpr(ret, dom_scale, out_dtype);
+}
+
+RELAY_REGISTER_OP("nn.conv1d").set_attr<FForwardRewrite>("FQRealizeRewrite", Conv1dRealize);
+
 Expr DenseRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
   const QConfig& cfg = QConfig::Current();
   ICHECK_EQ(new_args.size(), 2);
@@ -449,6 +480,8 @@ RELAY_REGISTER_OP("strided_slice").set_attr<FForwardRewrite>("FQRealizeRewrite",
 RELAY_REGISTER_OP("nn.batch_flatten")
     .set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
 
+RELAY_REGISTER_OP("transpose").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
+
 RELAY_REGISTER_OP("annotation.stop_fusion")
     .set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
 
@@ -469,6 +502,9 @@ Expr CastDtypeInputRealize(const Call& ref_call, const Array<Expr>& new_args,
 RELAY_REGISTER_OP("nn.max_pool2d")
     .set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
 
+RELAY_REGISTER_OP("nn.max_pool1d")
+    .set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
+
 Expr AvgPoolRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
   const QConfig& cfg = QConfig::Current();
   ICHECK_EQ(new_args.size(), 1);