This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b3d01c2295 [Relax][Bugfix] Preserve dtype in ToMixedPrecision for
kNever ops (#17263)
b3d01c2295 is described below
commit b3d01c2295cde9dcd02980bad49fcd9cd3049231
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Aug 11 13:43:09 2024 -0500
[Relax][Bugfix] Preserve dtype in ToMixedPrecision for kNever ops (#17263)
Prior to this commit, while an operator with the
`MixedPrecisionPolicyKind::kNever` attribute would not be updated from
`float32` to `float16`, it would be erroneously updated from `float16`
to `float32`.
This commit updates `ToMixedPrecision` to preserve the datatype of any
arguments used in a `kNever` operation, rather than forcing them to a
`float32` datatype.
---
src/relax/transform/to_mixed_precision.cc | 69 ++++++++++++++--------
.../relax/test_transform_to_mixed_precision.py | 34 ++++++++++-
2 files changed, 75 insertions(+), 28 deletions(-)
diff --git a/src/relax/transform/to_mixed_precision.cc
b/src/relax/transform/to_mixed_precision.cc
index c844d59356..1b660b8fec 100644
--- a/src/relax/transform/to_mixed_precision.cc
+++ b/src/relax/transform/to_mixed_precision.cc
@@ -303,11 +303,7 @@ class ToMixedPrecisionRewriter : public ExprMutator {
}
Array<Expr> RemapArgs(const Array<Expr>& args) {
- Array<Expr> new_args;
- for (const auto& arg : args) {
- new_args.push_back(VarReplacer::Replace(arg, var_remap_));
- }
- return new_args;
+ return args.Map([this](Expr arg) { return VarReplacer::Replace(arg,
var_remap_); });
}
// Util function to rewrite the expr to the given dtype
@@ -475,37 +471,60 @@ class ToMixedPrecisionRewriter : public ExprMutator {
ReEmitBinding(binding, call_node->args[0]);
return;
}
- DataType to;
- ObjectPtr<CallNode> new_call = make_object<CallNode>(*call_node);
+
+ Call new_call = GetRef<Call>(call_node);
+
// We first remap the args to the current vars according to the
var_remap_
- new_call->args = std::move(RemapArgs(call_node->args));
+ new_call.CopyOnWrite()->args = RemapArgs(new_call->args);
+
// Then we rewrite the args according to the policy
+ std::optional<DataType> opt_new_dtype = std::nullopt;
+
if (policy == kAlways) {
- to = fp16_;
+ opt_new_dtype = fp16_;
auto attr_map =
Op::GetAttrMap<FInferMixedPrecision>("FInferMixedPrecision");
ICHECK(attr_map.count(op));
- auto f = attr_map[op];
- new_call = make_object<CallNode>(*(f(Call(new_call),
output_dtype_).get()));
+ new_call = attr_map[op](new_call, output_dtype_);
} else if (policy == kFollow) {
- to = AllFP16Castable(new_call->args) ? fp16_ : fp32_;
+ opt_new_dtype = AllFP16Castable(new_call->args) ? fp16_ : fp32_;
} else if (policy == kNever) {
- to = fp32_;
+ // An upstream operation may have changed the datatype of the
+ // arguments. Because this operation must be provided with
+ // exactly the same dtype as it previously had, it may require a
+ // cast back to the original datatype.
+
+ if (!new_call->args.same_as(call_node->args)) {
+ Array<Expr> new_typed_args;
+ for (size_t i = 0; i < call_node->args.size(); i++) {
+ auto arg = new_call->args[i];
+ auto old_ntype = NTypeFrom(call_node->args[i]);
+ new_typed_args.push_back(RewriteExpr(arg, old_ntype));
+ }
+ new_call.CopyOnWrite()->args = new_typed_args;
+ }
+
} else {
LOG(FATAL) << "Unsupported TMixedPrecisionPolicy: " << policy;
}
- new_call->args = std::move(RewriteArgs(new_call->args, to));
- new_call->struct_info_ = NullOpt;
- Expr new_value = builder_->Normalize(Call(new_call));
- if (policy == kAlways && binding->var->IsInstance<DataflowVarNode>()) {
- // kAlways: store the tensors to fp16
- // But global vars will be stored to the original dtype anyway (see
below)
- new_value = RewriteExpr(new_value, NTypeFrom(new_value, fp16_));
- }
- if (!binding->var->IsInstance<DataflowVarNode>()) {
- // Global var: store the tensors to the original dtype
- NType to = NTypeFrom(binding->var);
- new_value = RewriteExpr(new_value, to);
+
+ Expr new_value = new_call;
+ if (opt_new_dtype) {
+ auto new_dtype = opt_new_dtype.value();
+ new_call.CopyOnWrite()->args = RewriteArgs(new_call->args, new_dtype);
+ new_call.CopyOnWrite()->struct_info_ = NullOpt;
+
+ new_value = builder_->Normalize(Call(new_call));
+
+ if (!binding->var->IsInstance<DataflowVarNode>()) {
+ // Non-Dataflow var: store the tensors to the original dtype
+ new_value = RewriteExpr(new_value, NTypeFrom(binding->var));
+ } else if (policy == kAlways &&
binding->var->IsInstance<DataflowVarNode>()) {
+ // kAlways: store the tensors to fp16
+ // But non-dataflow vars will be stored to the original dtype anyway
(see above)
+ new_value = RewriteExpr(new_value, NTypeFrom(new_value, new_dtype));
+ }
}
+
ReEmitBinding(binding, builder_->Normalize(new_value));
}
diff --git a/tests/python/relax/test_transform_to_mixed_precision.py
b/tests/python/relax/test_transform_to_mixed_precision.py
index 4ddf47b462..ed10fc95c7 100644
--- a/tests/python/relax/test_transform_to_mixed_precision.py
+++ b/tests/python/relax/test_transform_to_mixed_precision.py
@@ -20,7 +20,7 @@ import tvm
from tvm import relax
import tvm.testing
from tvm.relax.transform import ToMixedPrecision
-from tvm.script.parser import ir as I, relax as R
+from tvm.script.parser import ir as I, relax as R, tir as T
def _assert_test(input, expected=None, expected2=None):
@@ -614,8 +614,8 @@ def test_conv2d_softmax():
x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((3, 3, 3, 3),
"float32")
) -> R.Tensor(None, "float32", ndim=4):
with R.dataflow():
- gv: R.Tensor((2, 3, 26, 26), "float32") = R.nn.conv2d(x, w,
padding=(1, 1))
- gv1: R.Tensor((2, 3, 26, 26), "float32") = R.nn.softmax(x,
axis=1)
+ gv: R.Tensor((2, 3, 28, 28), "float32") = R.nn.conv2d(x, w,
padding=(1, 1))
+ gv1: R.Tensor((2, 3, 28, 28), "float32") = R.nn.softmax(x,
axis=1)
gv2 = R.add(gv, gv1)
R.output(gv2)
return gv2
@@ -1036,5 +1036,33 @@ def test_convert_sig():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_call_tir_with_float16_args():
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(A: R.Tensor([64], "float16")):
+ cls = Before
+ with R.dataflow():
+ B = R.call_tir(cls.tir_identity, [A], out_sinfo=R.Tensor([64],
"float16"))
+ C = R.call_tir(cls.tir_identity, [B], out_sinfo=R.Tensor([64],
"float16"))
+ R.output(C)
+ return C
+
+ @T.prim_func
+ def tir_identity(
+ Input: T.Buffer(64, "float16"),
+ Output: T.Buffer(64, "float16"),
+ ):
+ for i in range(64):
+ with T.block("copy"):
+ vi = T.axis.remap("S", [i])
+ Output[vi] = Input[vi]
+
+ Expected = Before
+
+ After = ToMixedPrecision()(Before)
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
if __name__ == "__main__":
tvm.testing.main()