yongwww commented on code in PR #14242: URL: https://github.com/apache/tvm/pull/14242#discussion_r1129891535
########## src/relax/transform/to_mixed_precision.cc: ########## @@ -0,0 +1,538 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/to_mixed_precision.cc + * \brief Automatic mixed precision pass. + */ + +#include <tvm/relax/expr_functor.h> +#include <tvm/relax/op_attr_types.h> +#include <tvm/relax/transform.h> + +#include <array> + +#include "../op/nn/convolution.h" +#include "../op/tensor/datatype.h" +#include "../op/tensor/linear_algebra.h" +#include "infer_amp_utils.h" +#include "utils.h" + +namespace tvm { +namespace relax { + +using runtime::String; + +int GetMixedPrecisionInfo(const CallNode* call_node) { + const OpNode* op_node = call_node->op.as<OpNode>(); + if (op_node == nullptr) { + return -1; + } + Op op = GetRef<Op>(op_node); + auto attr_map = Op::GetAttrMap<TMixedPrecisionPolicy>("TMixedPrecisionPolicy"); + return attr_map.count(op) ? attr_map[op] : MixedPrecisionPolicyKind::kNever; +} + +/*! + * \brief Main logic to automatically cast fp32 input modules to fp16 for certain ops. + * + * Structurally speaking, a Relax function is composed of a series of VarBinding and + * MatchCast. And a specific class of VarBindings is the basic unit we want to rewrite. 
+ * Formally, they are of the form: + * + * var = Call(Op, [args], attrs) + * + * where Op is a specific op we want to rewrite, and attrs is the attributes of the op. + * var and args are all exprs with type Tensor or Tuple of Tensors. They might + * be vars, constants, or Tuple of vars and constants. + * Depending on the properties of the op, we may have 3 different ways to rewrite it: + * + * 1. kAlways: Always cast the args to fp16 Review Comment: We will probably lose some accuracy with `kAlways`. I am wondering whether it would be possible to let the user control the `MixedPrecisionPolicyKind`, e.g. `ToMixedPrecision(MixedPrecisionPolicyKind=kAlways/kFollow/AMP/Never)` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
