This is an automated email from the ASF dual-hosted git repository. andrewzhaoluo pushed a commit to branch aluo/autotensorization-redux-09192022 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 72d4ca5218e1a1ee0378bc6681eff756598e3598 Author: Andrew Zhao Luo <[email protected]> AuthorDate: Mon Sep 19 12:58:43 2022 -0700 changes change --- python/tvm/meta_schedule/default_config.py | 108 ++++++++++++++++++- python/tvm/relay/op/contrib/dnnl.py | 64 ++++++++--- python/tvm/relay/qnn/op/qnn.py | 68 ++++++++++++ python/tvm/relay/qnn/transform.py | 78 ++++++++++++++ .../transform/fake_quantization_to_integer.py | 86 +++++++++++++++ src/relay/qnn/op/div.cc | 117 +++++++++++++++++++++ 6 files changed, 501 insertions(+), 20 deletions(-) diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py index ac4028ec50..eaa026e3b4 100644 --- a/python/tvm/meta_schedule/default_config.py +++ b/python/tvm/meta_schedule/default_config.py @@ -20,6 +20,8 @@ import logging from os import path as osp from typing import Callable, Dict, List, Optional, Union +from tvm._ffi.registry import register_func +from tvm.contrib import nvcc from tvm.ir import IRModule from tvm.target import Target from tvm.tir import PrimFunc @@ -43,6 +45,20 @@ FnPostproc = Callable[[], List[Postproc]] FnMutatorProb = Callable[[], Dict[Mutator, float]] +def target_has_vnni(target): + return target in { + "cascadelake", + "icelake-client", + "icelake-server", + "rocketlake", + "tigerlake", + "cooperlake", + "sapphirerapids", + "alderlake", + } + + +@register_func("tvm.meta_schedule.tune.parse_mod") # for use in ApplyHistoryBest def mod(mod: Union[PrimFunc, IRModule]) -> IRModule: # pylint: disable=redefined-outer-name """Normalize the input to an IRModule""" if isinstance(mod, PrimFunc): @@ -174,9 +190,13 @@ def schedule_rules( # pylint: disable=redefined-outer-name return sch_rules() if sch_rules is not None: raise TypeError(f"Expected `sch_rules` to be None or callable, but gets: {sch_rules}") - if target.kind.name in ["llvm", "hexagon"]: + if target.kind.name == "llvm": + if target_has_vnni(target.mcpu): + return _DefaultLLVMVNNI.schedule_rules() return _DefaultLLVM.schedule_rules() if target.kind.name in ["cuda", "rocm", "vulkan"]: + if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target): + return _DefaultCUDATensorCore.schedule_rules() return _DefaultCUDA.schedule_rules() raise ValueError(f"Unsupported target: {target}") @@ -190,9 +210,13 @@ def postproc( # pylint: disable=redefined-outer-name return postproc() if postproc is not None: raise TypeError(f"Expected `postproc` to be None or callable, but gets: {postproc}") - if target.kind.name in ["llvm", "hexagon"]: + if target.kind.name == "llvm": + if target_has_vnni(target.mcpu): + return _DefaultLLVMVNNI.postprocs() return _DefaultLLVM.postprocs() if target.kind.name in ["cuda", "rocm", "vulkan"]: + if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target): + return _DefaultCUDATensorCore.postprocs() return _DefaultCUDA.postprocs() raise ValueError(f"Unsupported target: {target}") @@ -208,9 +232,13 @@ def mutator_probs( # pylint: disable=redefined-outer-name raise TypeError( f"Expected `mutator_probs` to be None or callable, but gets: {mutator_probs}" ) - if target.kind.name in ["llvm", "hexagon"]: + if target.kind.name == "llvm": + if target_has_vnni(target.mcpu): + return _DefaultLLVMVNNI.mutator_probs() return _DefaultLLVM.mutator_probs() if target.kind.name in ["cuda", "rocm", "vulkan"]: + if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target): + return _DefaultCUDATensorCore.mutator_probs() return _DefaultCUDA.mutator_probs() raise ValueError(f"Unsupported target: {target}") @@ -277,6 +305,78 @@ class _DefaultLLVM: } +class _DefaultLLVMVNNI: + """Default tuning configuration for LLVM with VNNI.""" + + @staticmethod + def schedule_rules() -> List[ScheduleRule]: + from tvm.meta_schedule import schedule_rule as M + from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN + + logger.info("Using schedule rule: LLVM VNNI") + + return [ + M.AutoInline( + into_producer=False, + into_consumer=True, + inline_const_tensor=True, + disallow_if_then_else=True, + require_injective=True, + require_ordered=True, + disallow_op=["tir.exp"], + ), + M.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64), + M.MultiLevelTilingWithIntrin( + VNNI_DOT_16x4_INTRIN, + structure="SSRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_lens=None, + reuse_read=None, + reuse_write=M.ReuseType( + req="may", + levels=[1, 2], + scope="global", + ), + ), + M.MultiLevelTiling( + structure="SSRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_lens=None, + reuse_read=None, + reuse_write=M.ReuseType( + req="may", + levels=[1, 2], + scope="global", + ), + ), + M.ParallelizeVectorizeUnroll( + max_jobs_per_core=16, + max_vectorize_extent=64, + unroll_max_steps=[0, 16, 64, 512], + unroll_explicit=True, + ), + M.RandomComputeLocation(), + ] + + @staticmethod + def postprocs() -> List[Postproc]: + from tvm.meta_schedule import postproc as M + + return [ + M.DisallowDynamicLoop(), + M.RewriteParallelVectorizeUnroll(), + M.RewriteReductionBlock(), + M.RewriteTensorize(vectorize_init_loop=True), + M.RewriteLayout(), + ] + + @staticmethod + def mutator_probs() -> Dict[Mutator, float]: + return _DefaultLLVM.mutator_probs() + + class _DefaultCUDA: """Default tuning configuration for CUDA.""" @@ -355,6 +455,8 @@ class _DefaultCUDATensorCore: from tvm.meta_schedule import schedule_rule as M from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group + logger.info("Using schedule rule: CUDA tensorcore") + return [ M.MultiLevelTilingTensorCore( intrin_groups=[ diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index f7752e41b0..67909b04b8 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -36,22 +36,18 @@ import logging from functools import reduce import tvm.ir -from tvm.ir import Op from tvm import relay +from tvm.ir import Op +from tvm.relay import expr as _expr from tvm.relay import transform -from tvm.relay.expr import GlobalVar -from tvm.relay.expr_functor import ExprMutator, ExprVisitor -from tvm.relay.expr import const - from tvm.relay.analysis import analysis as _analysis -from tvm.relay import expr as _expr +from tvm.relay.expr import Call, GlobalVar, TupleGetItem, const +from tvm.relay.expr_functor import ExprMutator, ExprVisitor -from tvm.relay.expr import Call, TupleGetItem from ... import _ffi_api -from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback +from ...dataflow_pattern import DFPatternCallback, is_constant, is_expr, is_op, rewrite, wildcard from .register import register_pattern_table - logger = logging.getLogger("DNNL") supported_post_elts = ["nn.relu", "tanh", "sigmoid", "clip", "gelu", "swish", "mish", None] @@ -809,7 +805,7 @@ def prune_dnnl_subgraphs(mod): return new_mod -class LayerNormRewrite(DFPatternCallback): +class LayerNormRewritePattern1(DFPatternCallback): """ A callback to rewrite the following operators into a single layer normalization operator. @@ -826,7 +822,42 @@ class LayerNormRewrite(DFPatternCallback): /* ty=Tensor[(1, 3136, 64), float32] */; 10 %13 = add(%12, meta[relay.Constant][3] /* ty=Tensor[(64), float32] */) /* ty=Tensor[(1, 3136, 64), float32] */; + """ + + def __init__(self): + super(LayerNormRewritePattern1, self).__init__() + self.data = wildcard() + self.gamma = wildcard() + self.beta = wildcard() + mu = is_op("mean")(self.data) + diff = is_op("subtract")(self.data, mu) + cdiff = is_op("cast")(diff) | diff # cast does not need to be here usually + const_two = ( + is_expr(relay.const(2)) + | is_expr(relay.const(2.0)) + | is_expr(relay.const(2.0, dtype="float16")) + ) + p1 = is_op("power")(cdiff, const_two) + mp1 = is_op("mean")(p1) + eps = is_constant() # TODO: check epsilon is something reasonable + added_eps = is_op("add")(mp1, eps) + deno = is_op("sqrt")(added_eps) + div_out = is_op("divide")(diff, deno) + div_out2 = diff * is_op("rsqrt")(added_eps) + weighted = is_op("multiply")(div_out | div_out2, self.gamma) + added_bias = is_op("add")(weighted, self.beta) + self.pattern = added_bias + def callback(self, pre, post, node_map): + data = node_map[self.data][0] + gamma = node_map[self.gamma][0] + beta = node_map[self.beta][0] + return relay.op.nn.layer_norm(data=data, gamma=gamma, beta=beta) + + +class LayerNormRewritePattern2(DFPatternCallback): + """ + A callback to rewrite the following operators into a single layer normalization operator. Pattern #2: 1 %0 = mean(%input, axis=[-1], keepdims=True); 2 %1 = variance(%input, %0, axis=[-1], keepdims=True); @@ -842,19 +873,16 @@ class LayerNormRewrite(DFPatternCallback): """ def __init__(self): - super(LayerNormRewrite, self).__init__() + super(LayerNormRewritePattern2, self).__init__() self.data = wildcard() self.gamma = wildcard() self.beta = wildcard() mu = is_op("mean")(self.data) - diff = is_op("subtract")(self.data, mu) - cdiff = diff | is_op("cast")(diff) - const_two = is_expr(relay.const(2)) | is_expr(relay.const(2.0)) - p1 = is_op("power")(cdiff, const_two) - mp1 = is_op("mean")(p1) | is_op("variance")(self.data, mu) + mp1 = is_op("variance")(self.data, mu) eps = is_expr(relay.const(1e-5)) | is_expr(relay.const(1e-6)) added_eps = is_op("add")(mp1, eps) deno = is_op("sqrt")(added_eps) + diff = is_op("subtract")(self.data, mu) div_out = is_op("divide")(diff, deno) div_out2 = diff * is_op("rsqrt")(added_eps) weighted = is_op("multiply")(div_out | div_out2, self.gamma) @@ -872,7 +900,9 @@ def rewrite_layer_norm(mod): """Rewrite the input graph to replace multiple operators with a TVM native layer normalization operator so that we can offload them to dnnl layer normalization byoc part. """ - mod["main"] = rewrite(LayerNormRewrite(), mod["main"]) + mod["main"] = rewrite(LayerNormRewritePattern1(), mod["main"]) + mod["main"] = rewrite(LayerNormRewritePattern2(), mod["main"]) + return mod diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 1f38385107..6d1cabeb8d 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -788,6 +788,74 @@ def mul( ) +def div( + lhs, + rhs, + lhs_scale, + lhs_zero_point, + rhs_scale, + rhs_zero_point, + output_scale, + output_zero_point, + lhs_axis=-1, + rhs_axis=-1, +): + """Quantized division with numpy-style broadcasting. + + Parameters + ---------- + lhs : relay.Expr + The left hand side quantized input data. + + rhs : relay.Expr + The right hand side quantized input data. + + lhs_scale: relay.Expr + The scale of the lhs quantized expr. + + lhs_zero_point: relay.Expr + The zero point of lhs quantized expr. + + rhs_scale: relay.Expr + The scale of the rhs quantized expr. + + rhs_zero_point: relay.Expr + The zero point of rhs quantized expr. + + output_scale: relay.Expr + The scale of the output quantized expr. + + output_zero_point: relay.Expr + The zero point of output quantized expr. + + lhs_axis: int + The channel axis for lhs quantization. Default value is -1 which corresponds + to the last axis. + + rhs_axis: int + The channel axis for rhs quantization. Default value is -1 which corresponds + to the last axis. + + Returns + ------- + result : relay.Expr + The computed result. + + """ + return _make.div( + lhs, + rhs, + lhs_scale, + lhs_zero_point, + rhs_scale, + rhs_zero_point, + output_scale, + output_zero_point, + lhs_axis, + rhs_axis, + ) + + def tanh(x, scale, zero_point, output_scale, output_zero_point): """Quantized tanh. diff --git a/python/tvm/relay/qnn/transform.py b/python/tvm/relay/qnn/transform.py index 0485cecb99..7b42942c8b 100644 --- a/python/tvm/relay/qnn/transform.py +++ b/python/tvm/relay/qnn/transform.py @@ -114,3 +114,81 @@ def Legalize(): """ return relay.transform.Legalize("FTVMQnnLegalize") + + +from tvm.relay.dataflow_pattern import ( + DFPatternCallback, + is_constant, + is_expr, + is_op, + rewrite, + wildcard, +) + + +class RSqrtPattern(DFPatternCallback): + """ + Rewrites QNN RSQRT Pattern + """ + + def __init__(self): + super(RSqrtPattern, self).__init__() + + self.sqrt_data = wildcard() + self.sqrt_data_input_scale = wildcard() + self.sqrt_data_input_zp = wildcard() + + self.numerator = wildcard() + self.numerator_scale = wildcard() + self.numerator_zp = wildcard() + + self.output_scale = wildcard() + self.output_zp = wildcard() + + self.sqrt = is_op("qnn.sqrt")( + self.sqrt_data, + self.sqrt_data_input_scale, + self.sqrt_data_input_zp, + wildcard(), + wildcard(), + ) + + # TODO: match axis properly + self.rsqrt = is_op("qnn.div")( + self.numerator, + self.sqrt, + self.numerator_scale, + self.numerator_zp, + wildcard(), + wildcard(), + self.output_scale, + self.output_zp, + ) + + self.pattern = self.rsqrt + + def callback(self, pre, post, node_map): + sqrt_data = node_map[self.sqrt_data][0] + sqrt_data_scale = node_map[self.sqrt_data_input_scale][0] + sqrt_data_zp = node_map[self.sqrt_data_input_zp][0] + + numerator = node_map[self.numerator][0] + numerator_scale = node_map[self.numerator_scale][0] + numerator_zp = node_map[self.numerator_zp][0] + + output_scale = node_map[self.output_scale][0] + output_zp = node_map[self.output_zp][0] + + rsqrt = relay.qnn.op.rsqrt( + sqrt_data, sqrt_data_scale, sqrt_data_zp, numerator_scale, numerator_zp + ) + return relay.qnn.op.mul( + numerator, + rsqrt, + numerator_scale, + numerator_zp, + numerator_scale, + numerator_zp, + output_scale, + output_zp, + ) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 242740399f..3dd2474170 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -19,6 +19,7 @@ import numpy as np import tvm from tvm import relay from tvm.ir import TensorAffineType, TupleAffineType +from tvm.relay.op.tensor import ones_like # import to register canonicalization funcs for fq2i # pylint: disable=unused-import @@ -198,6 +199,60 @@ def broadcast_to(expr, type_map): return [out, t] +@register_fake_quantization_to_integer("take") +def take(expr, type_map): + """Rewrite a take op""" + arg1 = expr.args[0] + t = type_map[arg1] + arg2 = expr.args[1] + out = relay.op.take( + arg1, + arg2, + axis=expr.attrs.axis, + batch_dims=expr.attrs.batch_dims, + mode=expr.attrs.mode, + ) + return [out, t] + + +@register_fake_quantization_to_integer("power") +def power(expr, type_map): + base = expr.args[0] + exponent = expr.args[1] + + base_type = type_map[base] + + if not isinstance(exponent, relay.Constant): + return [expr, type_map[expr]] + + data = exponent.data.numpy() + if not len(data.shape) == 0: + return [expr, type_map[expr]] + + data = data.item() + if data != 2: + return [expr, type_map[expr]] + + out = relay.qnn.op.mul( + base, + base, + base_type.scale, + base_type.zero_point, + base_type.scale, + base_type.zero_point, + output_scale=base_type.scale * base_type.scale, + output_zero_point=base_type.zero_point, + lhs_axis=base_type.axis, + rhs_axis=base_type.axis, + ) + return [ + out, + TensorAffineType( + base_type.scale * base_type.scale, base_type.zero_point, base_type.dtype, base_type.axis + ), + ] + + @register_fake_quantization_to_integer("nn.bias_add") def bias_add(expr, type_map): """Rewrite a bias_add op""" @@ -520,6 +575,37 @@ def register_binary_qnn(op_name, op): register_binary_qnn("add", lambda *args: relay.qnn.op.add(*args)) register_binary_qnn("multiply", lambda *args: relay.qnn.op.mul(*args)) register_binary_qnn("subtract", lambda *args: relay.qnn.op.subtract(*args)) +# register_binary_qnn("divide", lambda *args: relay.qnn.op.div(*args)) + + +''' +@register_fake_quantization_to_integer("divide") +def divide(expr, type_map): + """Rewrite an adaptive avgpool op""" + numerator = expr.args[0] + denominator = expr.args[1] + numerator_t = type_map[numerator] + denominator_t = type_map[denominator] + new_scale = numerator_t.scale / (denominator_t.scale * (denominator - denominator_t.zero_point)) + out = relay.divide(numerator, ones_like(denominator)) + assert numerator_t.axis == denominator_t.axis, "Only support identical axis for now." + # print(out) + + print("new out:") + str_new_out = str(relay.transform.InferType()(tvm.IRModule.from_expr(out))) + print("\n".join(str_new_out.split("\n")[-10:])) + print("old_out:") + str_old_out = str(relay.transform.InferType()(tvm.IRModule.from_expr(expr))) + print("\n".join(str_old_out.split("\n")[-10:])) + print() + breakpoint() + # print("yay!") + # This is to get broadcasting working to get same shape + return [ + out, + TensorAffineType(new_scale, numerator_t.zero_point, numerator_t.dtype, numerator_t.axis), + ] +''' def register_binary_identity(op_name, op): diff --git a/src/relay/qnn/op/div.cc b/src/relay/qnn/op/div.cc new file mode 100644 index 0000000000..3c37ed41c4 --- /dev/null +++ b/src/relay/qnn/op/div.cc @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/qnn/op/mul.cc + * \brief QNN mul operator. + */ +#include <tvm/relay/analysis.h> +#include <tvm/relay/op_attr_types.h> +#include <tvm/relay/qnn/attrs.h> + +#include "../../transforms/pattern_utils.h" +#include "../utils.h" +#include "op_common.h" + +namespace tvm { +namespace relay { +namespace qnn { + +/* + * \brief Canonicalizes the QNN div op. + * \param attrs The QNN div attrs. + * \param new_args The new mutated args to the call node. + * \param arg_types The types of input and output. + * \return The sequence of Relay ops for mul op. + */ +Expr QnnDivCanonicalize(const Attrs& attrs, const Array<Expr>& new_args, + const Array<tvm::relay::Type>& arg_types) { + Expr output; + + // Get the attrs. + QnnBinaryOpArguments args(new_args); + + // Get the input dtype and shape. + QnnBinaryOpTensorType input_type(arg_types, 0); + + // data types + const auto int32_dtype = DataType::Int(32); + const auto float32_dtype = DataType::Float(32); + + const auto* broadcast_attrs = attrs.as<BroadcastAttrs>(); + ICHECK(broadcast_attrs != nullptr); + + if (IsConstScalar(args.lhs_scale) && IsConstScalar(args.rhs_scale)) { + /* If both are constant: + + n1/n2 = [s1(q1-z1)] / [s2(q2-z2)] + n1/n2 = [s1/s2][(q1-z1)/(q2-z2)] + + As [(q1-z1)/(q2-z2)] is integer division, we lose perhaps significant precision. + To get around this we scale the numerator by C to ensure that + + |C(q1-z1)| >> (q2 - z2) and the precision loss from the division is minimal: + + n1/n2 = [s1/(s2 * C)][C(q1-z1)/(q2-z2)] + */ + + auto lhs_shifted = Cast(args.lhs, int32_dtype); + auto rhs_shifted = Cast(args.rhs, int32_dtype); + + auto zero_scalar = MakeConstantScalar(int32_dtype, 0); + if (!IsEqualScalar(args.lhs_zero_point, zero_scalar)) { + lhs_shifted = Subtract(lhs_shifted, args.lhs_zero_point); + } + + if (!IsEqualScalar(args.rhs_zero_point, zero_scalar)) { + rhs_shifted = Subtract(rhs_shifted, args.rhs_zero_point); + } + + // multiply numerator to avoid precision loss, as accumulate in INT32 and + // may deal with UINT16, multiply by 2^15 + int divide_scale_factor = 32768; + auto divide_scale_factor_constant = MakeConstantScalar(int32_dtype, divide_scale_factor); + output = Divide(Multiply(lhs_shifted, divide_scale_factor_constant), rhs_shifted); + + // Get the adjusted new scale and zero points. + float lhs_scale_float = GetScalarFromConstant<float>(args.lhs_scale); + float rhs_scale_float = GetScalarFromConstant<float>(args.rhs_scale); + float new_scale_float = lhs_scale_float / (rhs_scale_float * divide_scale_factor); + auto new_input_scale = MakeConstantScalar(float32_dtype, new_scale_float); + auto new_input_zero_point = zero_scalar; + + // Requantize to get Q_c + output = Requantize(output, input_type.shape, new_input_scale, new_input_zero_point, + args.output_scale, args.output_zero_point, input_type.dtype); + } else { + LOG(FATAL) << "Non-constant scale_factor not supported yet."; + } + + return output; +} + +// QNN Multiplication operator. +QNN_REGISTER_BINARY_OP("div") + .describe("Elementwise div with broadcasting for quantized tensors.") + .set_support_level(11) + .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnDivCanonicalize); + +} // namespace qnn +} // namespace relay +} // namespace tvm
