spectrometerHBH commented on code in PR #14242:
URL: https://github.com/apache/tvm/pull/14242#discussion_r1129930876


##########
src/relax/transform/to_mixed_precision.cc:
##########
@@ -0,0 +1,538 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/to_mixed_precision.cc
+ * \brief Automatic mixed precision pass.
+ */
+
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/op_attr_types.h>
+#include <tvm/relax/transform.h>
+
+#include <array>
+
+#include "../op/nn/convolution.h"
+#include "../op/tensor/datatype.h"
+#include "../op/tensor/linear_algebra.h"
+#include "infer_amp_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+using runtime::String;
+
+int GetMixedPrecisionInfo(const CallNode* call_node) {
+  const OpNode* op_node = call_node->op.as<OpNode>();
+  if (op_node == nullptr) {
+    return -1;
+  }
+  Op op = GetRef<Op>(op_node);
+  auto attr_map = 
Op::GetAttrMap<TMixedPrecisionPolicy>("TMixedPrecisionPolicy");
+  return attr_map.count(op) ? attr_map[op] : MixedPrecisionPolicyKind::kNever;
+}
+
+/*!
+ * \brief Main logic to automatically cast fp32 input modules to fp16 for 
certain ops.
+ *
+ * Structurally speaking, a Relax function is composed of a series of 
VarBinding and
+ * MatchCast. And a specific class of VarBindings is the basic unit we want to 
rewrite.
+ * Formally, they are of the form:
+ *
+ * var = Call(Op, [args], attrs)
+ *
+ * where Op is a specific op we want to rewrite, and attrs is the attributes 
of the op.
+ * var and args are all exprs with type Tensor or Tuple of Tensors. They might
+ * be vars, constants, or Tuple of vars and constants.
+ * Depending on the properties of the op, we may have 3 different ways to 
rewrite it:
+ *
+ * 1. kAlways: Always cast the args to fp16

Review Comment:
   That's in principle possible. Here the pass is meant to serve a similar 
purpose to `torch.autocast("cuda")`, which always casts the inputs of 
matmul/conv2d from fp32 to fp16.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to