This is an automated email from the ASF dual-hosted git repository.
jwfromm pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 9c36056a12 [Unity] Allow modifying function signature by AMP to accept fp16 inputs (#14719)
9c36056a12 is described below
commit 9c36056a12c0588540671b8771ad681d4bfb6618
Author: masahi <[email protected]>
AuthorDate: Fri Apr 28 00:43:18 2023 +0900
[Unity] Allow modifying function signature by AMP to accept fp16 inputs (#14719)
* Modify func sig to accept fp16 inputs in AMP
* add test
* add doc
* fix
* cpplint
---
include/tvm/relax/transform.h | 5 ++-
python/tvm/relax/transform/transform.py | 10 ++++-
src/relax/transform/to_mixed_precision.cc | 40 +++++++++++++++----
.../relax/test_transform_to_mixed_precision.py | 46 ++++++++++++++++++++++
4 files changed, 91 insertions(+), 10 deletions(-)
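For readers skimming the patch: the new fp16_input_names option lets the pass
rewrite the dtype of the named function parameters to fp16, changing the
function signature so callers can pass fp16 tensors directly. A minimal usage
sketch, grounded on the test added at the end of this patch (here "Input" is
an fp32 Relax module whose main function has parameters named "w" and "bias"):

    from tvm.relax.transform import ToMixedPrecision

    # Cast eligible ops to fp16 and also retype the listed parameters, so the
    # rewritten main accepts fp16 "w" and "bias" without inserted casts.
    mod = ToMixedPrecision(out_dtype="float16", fp16_input_names=["w", "bias"])(Input)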
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 27bd1bd702..9c3a763d69 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -479,9 +479,12 @@ TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);
* \brief Automatic mixed precision pass. Currently the pass assumes the input module to be fp32
* only, and will automatically cast fp32 to fp16 for certain ops.
* \param out_dtype The output data type of gemm/conv, which is the data type of the accumulator.
+ * \param fp16_input_names The names of function parameters whose dtype should become fp16. The
+ * function signature would change accordingly.
* \return The Pass.
*/
-TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype);
+TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype,
+ Optional<Array<String>> fp16_input_names = NullOpt);
/*!
* \brief Rewrite a Relax module for executing with CUDA graph. This pass identifies
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 870b731883..46f908c448 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -914,19 +914,25 @@ def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> tvm.ir.t
return _ffi_api.DeadCodeElimination(entry_functions) # type: ignore
-def ToMixedPrecision(out_dtype="float32") -> tvm.ir.transform.Pass:
+def ToMixedPrecision(
+ out_dtype="float32", fp16_input_names: Optional[List[str]] = None
+) -> tvm.ir.transform.Pass:
"""Automatic mixed precision pass. Currently the pass assumes the input
module to be fp32
only, and will automatically cast fp32 to fp16 for certain ops.
Parameters
----------
out_dtype : str
The output data type of gemm/conv, which is the data type of the accumulator.
+ fp16_input_names : List[str]
+ The names of function parameters whose dtype should become fp16. The function signature
+ would change accordingly.
+
Returns
-------
ret : tvm.transform.Pass
The registered pass for mixed precision.
"""
- return _ffi_api.ToMixedPrecision(out_dtype) # type: ignore
+ return _ffi_api.ToMixedPrecision(out_dtype, fp16_input_names) # type: ignore
def SplitCallTIRByPattern(patterns, fcodegen) -> tvm.ir.transform.Pass:
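As a hedged sketch of how the extended Python API composes with other Relax
passes (both pass names exist on the unity branch; "mod", the fp32 input
module, and the parameter name "w" are assumptions for illustration):

    import tvm
    from tvm import relax

    seq = tvm.transform.Sequential(
        [
            relax.transform.ToMixedPrecision(out_dtype="float16", fp16_input_names=["w"]),
            relax.transform.DeadCodeElimination(),
        ]
    )
    mod = seq(mod)  # main's "w" parameter is declared fp16 in the result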
diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc
index a04d5dbd3a..64763276d0 100644
--- a/src/relax/transform/to_mixed_precision.cc
+++ b/src/relax/transform/to_mixed_precision.cc
@@ -27,6 +27,7 @@
#include <array>
#include <cstdint>
+#include <unordered_set>
#include "../op/nn/convolution.h"
#include "../op/tensor/datatype.h"
@@ -273,13 +274,30 @@ class DTypeDecisionCollector : public ExprVisitor {
class ToMixedPrecisionRewriter : public ExprMutator {
public:
- explicit ToMixedPrecisionRewriter(const VarDTypeMap* only_fp16_map, DataType output_dtype)
- : only_fp16_map_(only_fp16_map), output_dtype_(output_dtype) {}
+ explicit ToMixedPrecisionRewriter(const VarDTypeMap* only_fp16_map, DataType output_dtype,
+ const std::unordered_set<std::string>& fp16_input_names)
+ : only_fp16_map_(only_fp16_map),
+ output_dtype_(output_dtype),
+ fp16_input_names_(fp16_input_names) {}
private:
Var GetRemapped(const Var& var) {
auto it = var_remap_.find(var->vid);
- return it == var_remap_.end() ? var : it->second;
+ if (it != var_remap_.end()) {
+ return it->second;
+ } else {
+ if (fp16_input_names_.count(var->name_hint())) {
+ auto sinfo = GetStructInfo(var);
+ if (auto tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
+ TensorStructInfo fp16_sinfo(tensor_sinfo->shape.value(), DataType::Float(16),
+ tensor_sinfo->span);
+ Var fp16_var(var->vid, fp16_sinfo, var->span);
+ var_remap_[var->vid] = fp16_var;
+ return fp16_var;
+ }
+ }
+ return var;
+ }
}
Array<Expr> RemapArgs(const Array<Expr>& args) {
@@ -427,6 +445,8 @@ class ToMixedPrecisionRewriter : public ExprMutator {
return VisitVar_(GetRef<Var>(op));
}
+ Var VisitVarDef(const Var& var) { return GetRemapped(var); }
+
Expr VisitExpr_(const DataflowVarNode* op) final {
if (!builder_->CurrentBlockIsDataFlow()) {
return ExprMutator::VisitExpr_(op);
@@ -561,22 +581,28 @@ class ToMixedPrecisionRewriter : public ExprMutator {
DataType fp32_ = DataType(DataType::TypeCode::kFloat, 32, 1);
DataType output_dtype_;
Array<Var> params_;
+ std::unordered_set<std::string> fp16_input_names_;
const Op& wrap_param_op = Op::Get("relax.wrap_param");
};
-Expr ToMixedPrecision(const Function& f, const DataType& out_dtype) {
+Expr ToMixedPrecision(const Function& f, const DataType& out_dtype,
+ Optional<Array<String>> fp16_input_names) {
VarDTypeMap only_fp16_map = std::move(DTypeDecisionCollector::Collect(f, out_dtype));
- ToMixedPrecisionRewriter mutator(&only_fp16_map, out_dtype);
+ std::unordered_set<std::string> fp16_input_names_set;
+ if (fp16_input_names) {
+ fp16_input_names_set.insert(fp16_input_names.value().begin(), fp16_input_names.value().end());
+ }
+ ToMixedPrecisionRewriter mutator(&only_fp16_map, out_dtype, fp16_input_names_set);
return mutator(f);
}
namespace transform {
-Pass ToMixedPrecision(const DataType& out_dtype) {
+Pass ToMixedPrecision(const DataType& out_dtype, Optional<Array<String>> fp16_input_names) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(ToMixedPrecision(f, out_dtype));
+ return Downcast<Function>(ToMixedPrecision(f, out_dtype, fp16_input_names));
};
return CreateFunctionPass(pass_func, 0, "ToMixedPrecision", {});
}
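For intuition, the parameter-retyping branch added to GetRemapped above
corresponds to this Python-level sketch. It is illustrative only: it mirrors
the C++ logic rather than any public TVM API, and like the C++ call
tensor_sinfo->shape.value() it assumes the tensor's shape is known. One
simplification: relax.Var below mints a fresh Id, whereas the C++ reuses
var->vid so existing uses keep resolving to the same binding.

    from tvm import relax

    def remap_param(var, var_remap, fp16_input_names):
        # Reuse a remapped var if this vid was already rewritten.
        if var.vid in var_remap:
            return var_remap[var.vid]
        # Only parameters the user explicitly named are retyped to fp16.
        if var.name_hint in fp16_input_names:
            sinfo = var.struct_info
            if isinstance(sinfo, relax.TensorStructInfo):
                fp16_var = relax.Var(
                    var.name_hint, relax.TensorStructInfo(sinfo.shape, "float16")
                )
                var_remap[var.vid] = fp16_var
                return fp16_var
        # All other vars pass through unchanged.
        return var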
diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py
index 6bae732927..cb179a8c25 100644
--- a/tests/python/relax/test_transform_to_mixed_precision.py
+++ b/tests/python/relax/test_transform_to_mixed_precision.py
@@ -990,5 +990,51 @@ def test_conv2d_bias_fp32():
_assert_test(Input_bound, expected2=Expected_no_bias_cast)
+def test_convert_sig():
+ @tvm.script.ir_module
+ class Input:
+ @R.function
+ def main(
+ x: R.Tensor((1, 4, 64, 64), dtype="float32"),
+ w: R.Tensor((512, 4, 3, 3), dtype="float32"),
+ bias: R.Tensor((512,), dtype="float32"),
+ ) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv142: R.Tensor((1, 512, 64, 64), dtype="float32") = R.nn.conv2d(
+ x,
+ w,
+ strides=[1, 1],
+ padding=[1, 1, 1, 1],
+ out_dtype="float32",
+ )
+ lv143: R.Tensor((1, 512, 1, 1), dtype="float32") = R.reshape(bias, (1, 512, 1, 1))
+ lv144: R.Tensor((1, 512, 64, 64), dtype="float32") = R.add(lv142, lv143)
+ R.output(lv144)
+ return lv144
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((1, 4, 64, 64), dtype="float32"),
+ w: R.Tensor((512, 4, 3, 3), dtype="float16"),
+ bias: R.Tensor((512,), dtype="float16"),
+ ) -> R.Tensor((1, 512, 64, 64), dtype="float32"):
+ with R.dataflow():
+ lv = R.astype(x, dtype="float16")
+ lv142 = R.nn.conv2d(
+ lv, w, strides=[1, 1], padding=[1, 1, 1, 1], out_dtype="float16"
+ )
+ lv143 = R.reshape(bias, R.shape([1, 512, 1, 1]))
+ lv1 = R.add(lv142, lv143)
+ lv144 = R.astype(lv1, dtype="float32")
+ R.output(lv144)
+ return lv144
+
+ mod = ToMixedPrecision(out_dtype="float16", fp16_input_names=["w", "bias"])(Input)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()