AndrewZhaoLuo commented on a change in pull request #8069:
URL: https://github.com/apache/tvm/pull/8069#discussion_r653826430
########## File path: python/tvm/relay/transform/mixed_precision.py ##########
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=line-too-long,unused-argument
+"""Default behavior for ops in mixed_precision pass. Import this file to use."""
+from typing import List
+
+from tvm import relay
+from tvm.relay.op import register_mixed_precision_conversion
+
+# MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+# savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+# justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+# numerical reasons.
+MIXED_PRECISION_ALWAYS = 0
+MIXED_PRECISION_FOLLOW = 1
+MIXED_PRECISION_NEVER = 2
+
+# Default lists inspired from TF's classifications:
+# github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
+# They have a bias toward Nvidia Tensor Cores so modify lists per your hardware choice.
+DEFAULT_ALWAYS_LIST = [
+    "nn.conv1d",
+    "nn.conv2d",
+    "nn.conv3d",
+    "nn.conv1d_transpose",
+    "nn.conv2d_transpose",
+    "nn.conv3d_transpose",
+    "nn.dense",
+    # "nn.batch_matmul", # Handled by a special case
+]
+DEFAULT_FOLLOW_LIST = [
+    # These ops add new data or change shape
+    "nn.pad",
+    "nn.batch_flatten",
+    "concatenate",
+    "zeros",
+    "split",
+    "squeeze",
+    "transpose",
+    "expand_dims",
+    "reshape",
+    "dyn.reshape",
+    "broadcast_to_like",
+    "dyn.broadcast_to",
+    "strided_slice",
+    "dyn.strided_slice",
+    "take",
+    "argwhere",
+    "where",
+    "tile",
+    "dyn.tile",
+    "scatter",
+    "full",
+    "dyn.full",
+    # Comparison
+    "less",
+    "greater",
+    "less_equal",
+    "greater_equal",
+    # By definition copy and cast will depend on inputs for output.
+    "copy",
+    "cast",
+    "cast_like",
+    # Simple arithmetic
+    "add",
+    "subtract",
+    "multiply",
+    "divide",
+    "nn.bias_add",
+    "nn.batch_norm",
+    "sum",
+    "mean",
+    "sqrt",
+    "shape_of",
+    # Simple activations
+    "max",
+    "min",
+    "maximum",
+    "minimum",
+    "nn.relu",
+    "nn.leaky_relu",
+    "nn.prelu",
+    "nn.dropout",
+    # Complicated activations which saturate in a narrow range
+    "sigmoid",
+    "tanh",
+    # Pooling operations
+    "nn.max_pool1d",
+    "nn.max_pool2d",
+    "nn.max_pool3d",
+    "nn.avg_pool1d",
+    "nn.avg_pool2d",
+    "nn.avg_pool3d",
+    # "nn.global_max_pool1d", # does not exist yet
+    "nn.global_max_pool2d",
+    # "nn.global_max_pool3d", # does not exist yet
+    # "nn.global_avg_pool1d", # does not exist yet
+    "nn.global_avg_pool2d",
+    # "nn.global_avg_pool3d", # does not exist yet
+    "nn.adaptive_max_pool1d",
+    "nn.adaptive_max_pool2d",
+    "nn.adaptive_max_pool3d",
+    "nn.adaptive_avg_pool1d",
+    "nn.adaptive_avg_pool2d",
+    "nn.adaptive_avg_pool3d",
+]
+DEFAULT_NEVER_LIST = [
+    # In general if |f(x)| >> |x| for expected inputs then put the op here.
+    "exp",
+    "power",
+    "nn.cross_entropy",
+    "nn.cross_entropy_with_logits",
+    "nn.softmax",
+    "nn.l2_normalize",
+    # Error function doesn't seem to be able to be lowered into fp16 version in llvm.
+    # Move to follow list when it does.
+    "erf",
+]
+
+
+# Returns a decorator which registers for every given op, the function under FTVMMixedPrecisionConversionType
+def register_func_to_op_list(list_ops):
+    def decorator(func):
+        for op_name in list_ops:
+            register_mixed_precision_conversion(op_name, func=func)
+
+    return decorator
+
+
+def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]:
+    # Assume support accumulation dtypes <---> has out_dtype attr

Review comment:
   Right now there isn't a good way to tell which ops have "accumulation" dtypes. This is the simplest method I could think of. There is discussion on the discuss thread about making it easier to tell which ops support heterogeneous accumulators.
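   Concretely, the heuristic is just: an op is treated as supporting a separate accumulation dtype exactly when its attributes expose an `out_dtype` field. A minimal sketch of that check (the specific fp32 accumulation default and return values below are illustrative, not necessarily the exact ones this file ends up with):

```python
from typing import List

from tvm import relay


def get_generic_out_dtypes(call_node: relay.Call, mixed_precision_type: str) -> List[str]:
    # Heuristic: having an `out_dtype` attribute <---> the op supports an accumulation dtype.
    if hasattr(call_node.attrs, "out_dtype"):
        # Accumulate in fp32 for safety, output in the mixed precision type.
        return ["float32", mixed_precision_type]
    # Otherwise both the accumulation and output dtypes simply follow the mixed precision type.
    return [mixed_precision_type, mixed_precision_type]
```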
########## File path: python/tvm/relay/transform/mixed_precision.py ##########

Review comment:
   Done

########## File path: src/relay/transforms/to_mixed_precision.cc ##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file to_mixed_precision.cc
+ * \brief Automatic mixed floating point precision for relay graphs. i.e. turn a graph into fp16.
+ *
+ */
+
+#include <tvm/ir/attrs.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+// A callable which hashes std::pair
+struct pair_hash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    auto h1 = std::hash<T1>()(pair.first);
+    auto h2 = std::hash<T2>()(pair.second);
+
+    // Use boost's combine_hash strategy
+    return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+  }
+};
+
+// MIXED_PRECISION_ALWAYS ops should always be done in lower precision due to the speed and memory
+// savings. MIXED_PRECISION_FOLLOW ops can be done in lower precision but don't have speedups to
+// justify a cast. MIXED_PRECISION_NEVER colored ops should not be done in lower precision due to
+// numerical reasons.
+enum MixedTypeConversionCategory : int {
+  MIXED_PRECISION_ALWAYS = 0,
+  MIXED_PRECISION_FOLLOW = 1,
+  MIXED_PRECISION_NEVER = 2
+};
+
+// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype
+using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>;
+
+// Return array is of type : [MixedTypeConversionCategory (int), String, String]
+// The fields are : [ConversionCategory, accumulation_datatype, output_datatype]
+// Call is a call node, DataType is the mixed precision type
+using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
+    const Call& call_node, const std::string& target_dtype_str)>;
+
+class MixedPrecisionPass : public MixedModeMutator {
+ private:
+  CachedCastNodes cast_nodes_cache;
+
+  // The target datatype we want to convert to e.g. FP16
+  const DataType mixed_precision_type;
+
+  // If false, throws a fatal error if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool ignore_missing_ops;
+
+  // If true, emits a warning if an op which is not registered with a
+  // FTVMMixedPrecisionConversionType is encountered.
+  bool warn_missing_ops;
+
+  Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
+    /* If the accumulation dtype is in the attributes make a copy and mutate the field. */
+    Attrs cur_attrs = call->attrs;
+    if (cur_attrs.get() != nullptr) {
+      // TODO(AndrewZhaoLuo): Figure out a better way to do this
+      // modify output_dtype attributes (accumulation dtypes for ops)
+      if (auto attrs = cur_attrs.as<Conv1DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv1DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DeformableConv2DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DTransposeAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<Conv3DWinogradAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<DenseAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      } else if (auto attrs = cur_attrs.as<BatchMatmulAttrs>()) {
+        return ModifyAttrsOutputDType(attrs, accumulation_dtype);
+      }
+
+      // modify dtype attributes (creating new tensors of type dtype)
+      if (auto attrs = cur_attrs.as<InitOpAttrs>()) {
+        return ModifyAttrsDType(attrs, accumulation_dtype);
+      }
+    }
+
+    return cur_attrs;
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with out_dtype type.
+     These represent accumulation dtypes for some operations e.g.
+     conv2d might take in fp16 and give a fp32 result.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->out_dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->out_dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  template <typename T>
+  Attrs ModifyAttrsDType(const T* attrs, const DataType& accumulation_dtype) const {
+    /*
+     Helper template to modify relevant attributes with dtype type.
+     This determines the output dtype for some ops. For example
+     zeros creates a tensor of zeros of the specified dtype.
+     Attrs is const because we get it as a const.
+    */
+    DataType cur_type = (attrs->dtype);
+    ObjectPtr<T> new_attrs = make_object<T>(*attrs);
+    if (cur_type.is_float() || cur_type.is_void()) new_attrs->dtype = accumulation_dtype;
+    return Attrs(new_attrs);
+  }
+
+  Type GetType(const Expr& expr) const {
+    auto mod = IRModule::FromExpr(expr);
+    mod = transform::InferType()(mod);
+
+    if (expr.as<FunctionNode>()) {
+      return mod->Lookup("main")->checked_type();
+    } else {
+      return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
+    }
+  }

Review comment:
   Yeah, the problems fundamentally come from not thinking about and handling algebraic data types. I've added a warning, and the pass should now fail with an appropriate error message. I'm going to push support for this to a future PR; I believe most models don't use these features.
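   For context, here is a rough end-to-end usage sketch of the intended behavior. It assumes the pass is exposed from Python as `relay.transform.ToMixedPrecision("float16")` (the binding for this file), so the exact entry point and argument names may differ from the final version of the PR:

```python
import tvm
from tvm import relay

# Tiny graph mixing ALWAYS (conv2d), FOLLOW (batch_flatten), and NEVER (softmax) ops.
data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))
out = relay.nn.softmax(relay.nn.batch_flatten(conv))

mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
mod = relay.transform.InferType()(mod)

# conv2d should be rewritten to fp16 (its out_dtype attr controls accumulation),
# batch_flatten follows its input, and softmax stays in fp32 with casts inserted
# around it. Graphs using constructs the pass does not handle yet (e.g. ADTs)
# should fail with a clear error rather than being converted silently.
fp16_mod = relay.transform.ToMixedPrecision("float16")(mod)
print(fp16_mod)
```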
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]
