AndrewZhaoLuo commented on a change in pull request #8069: URL: https://github.com/apache/tvm/pull/8069#discussion_r652035337
########## File path: src/relay/transforms/fp32_to_fp16.h ########## @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file fp32_to_fp16.h + * \brief Utilities and common types used for FP32->FP16 pass. + */ +#ifndef TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_ +#define TVM_RELAY_TRANSFORMS_FP32_TO_FP16_H_ + +#include <tvm/ir/op.h> +#include <tvm/relay/expr.h> +#include <tvm/relay/function.h> + +#include <string> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +namespace tvm { +namespace relay { + +struct FP16OpDType { + DataType accumulation_dtype; + DataType output_dtype; +}; + +// GREEN colored ops should always be done in FP16 due to the speed and memory savings +// GRAY colored ops can be done in FP16 but don't have speedups to justify a dedicated cast. +// RED colored ops should not be done in FP16 due to numerical reasons. +enum FP16ConversionCategory { RED, GRAY, GREEN }; + +using OpStringSet = std::unordered_set<std::string>; + +// Default lists inspired from TF's classifications: Review comment: This is now done. 
########## File path: src/relay/transforms/to_mixed_precision.cc ########## @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file to_mixed_precision.cc + * \brief Automatic mixed precision for relay graphs. i.e. turn a graph into fp16 form. 
+ */ +#include "to_mixed_precision.h" + +#include <tvm/ir/attrs.h> +#include <tvm/relay/expr_functor.h> +#include <tvm/relay/transform.h> +#include <tvm/runtime/object.h> + +#include <utility> + +#include "pattern_utils.h" + +namespace tvm { +namespace relay { + +// A callable which hashes std::pair +struct pair_hash { + template <class T1, class T2> + std::size_t operator()(const std::pair<T1, T2>& pair) const { + auto h1 = std::hash<T1>()(pair.first); + auto h2 = std::hash<T2>()(pair.second); + + // Use boost's combine_hash strategy + return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2)); + } +}; + +// A map of a parent node and a wanted dtype to existing nodes casted to the wanted dtype +using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>, Expr, pair_hash>; + +// A function which maps CallNodes to their initial conversion color +using ColorFunc = std::function<MixedTypeConversionCategory(const CallNode*)>; + +// A function which maps MIXED_PRECISION_ALWAYS CallNodes to wanted accumulation and output dtypes +using OutputDtypeFunc = std::function<MixedPrecisionOpOutDType(const CallNode*)>; + +class MixedPrecisionPass : public MixedModeMutator { + private: + CachedCastNodes cast_nodes_cache; + const ColorFunc colorer; + const OutputDtypeFunc output_dtype_func; + const DataType mixed_precision_type; + + Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const { + /* If the accumulation dtype is in the attributes make a copy and mutate the field. 
*/ + Attrs new_attrs = Attrs(call->attrs); + if (new_attrs.get() != nullptr) { + // TODO(AndrewZhaoLuo): Figure out a better way to do this + // modify output_dtype attributes (accumulation dtypes for ops) + if (auto attrs = new_attrs.as<Conv1DAttrs>()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as<Conv1DTransposeAttrs>()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as<Conv2DAttrs>()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as<Conv2DTransposeAttrs>()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as<Conv2DWinogradAttrs>()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as<DeformableConv2DAttrs>()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as<Conv3DAttrs>()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as<Conv3DTransposeAttrs>()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as<Conv3DWinogradAttrs>()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as<DenseAttrs>()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } else if (auto attrs = new_attrs.as<BatchMatmulAttrs>()) { + ModifyAttrsOutputDType(attrs, accumulation_dtype); + } + + // modify dtype attributes (creating new tensors of type dtype) + if (auto attrs = new_attrs.as<InitOpAttrs>()) { + ModifyAttrsDType(attrs, accumulation_dtype); + } + } + + return new_attrs; + } + + template <typename T> + void ModifyAttrsOutputDType(const T* attrs, const DataType& accumulation_dtype) const { + /* + Helper template to modify relevant attributes with out_dtype type. 
+ These represent accumulation dtypes for some operations e.g. + conv2d might take in fp16 and give a fp32 result. + Attrs is const because we get it as a const. + */ + T* mutable_attrs = const_cast<T*>(attrs); Review comment: Done -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org
