This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b3d01c2295 [Relax][Bugfix] Preserve dtype in ToMixedPrecision for kNever ops (#17263)
b3d01c2295 is described below

commit b3d01c2295cde9dcd02980bad49fcd9cd3049231
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Aug 11 13:43:09 2024 -0500

    [Relax][Bugfix] Preserve dtype in ToMixedPrecision for kNever ops (#17263)
    
    Prior to this commit, while an operator with the
    `MixedPrecisionPolicyKind::kNever` attribute would not be updated from
    `float32` to `float16`, it would be erroneously updated from `float16`
    to `float32`.
    
    This commit updates `ToMixedPrecision` to preserve the datatype of any
    arguments used in a `kNever` operation, rather than forcing them to a
    `float32` datatype.
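    
    For illustration, a minimal, self-contained sketch of the fixed behavior
    (it mirrors the regression test added at the end of this diff; the module
    name `Module` and variable `After` are only example names):
    
        import tvm
        from tvm.relax.transform import ToMixedPrecision
        from tvm.script.parser import ir as I, relax as R, tir as T
        
        @I.ir_module
        class Module:
            @R.function
            def main(A: R.Tensor([64], "float16")):
                cls = Module
                with R.dataflow():
                    # R.call_tir is handled under the kNever policy here, so its
                    # float16 arguments must keep their dtype after the pass.
                    B = R.call_tir(cls.tir_identity, [A], out_sinfo=R.Tensor([64], "float16"))
                    C = R.call_tir(cls.tir_identity, [B], out_sinfo=R.Tensor([64], "float16"))
                    R.output(C)
                return C
        
            @T.prim_func
            def tir_identity(
                Input: T.Buffer(64, "float16"),
                Output: T.Buffer(64, "float16"),
            ):
                for i in range(64):
                    with T.block("copy"):
                        vi = T.axis.remap("S", [i])
                        Output[vi] = Input[vi]
        
        # With this fix, ToMixedPrecision leaves the all-float16 module unchanged;
        # previously the kNever arguments were erroneously cast to float32.
        After = ToMixedPrecision()(Module)
        tvm.ir.assert_structural_equal(Module, After)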
---
 src/relax/transform/to_mixed_precision.cc          | 69 ++++++++++++++--------
 .../relax/test_transform_to_mixed_precision.py     | 34 ++++++++++-
 2 files changed, 75 insertions(+), 28 deletions(-)

diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc
index c844d59356..1b660b8fec 100644
--- a/src/relax/transform/to_mixed_precision.cc
+++ b/src/relax/transform/to_mixed_precision.cc
@@ -303,11 +303,7 @@ class ToMixedPrecisionRewriter : public ExprMutator {
   }
 
   Array<Expr> RemapArgs(const Array<Expr>& args) {
-    Array<Expr> new_args;
-    for (const auto& arg : args) {
-      new_args.push_back(VarReplacer::Replace(arg, var_remap_));
-    }
-    return new_args;
+    return args.Map([this](Expr arg) { return VarReplacer::Replace(arg, var_remap_); });
   }
 
   // Util function to rewrite the expr to the given dtype
@@ -475,37 +471,60 @@ class ToMixedPrecisionRewriter : public ExprMutator {
       ReEmitBinding(binding, call_node->args[0]);
       return;
     }
-    DataType to;
-    ObjectPtr<CallNode> new_call = make_object<CallNode>(*call_node);
+
+    Call new_call = GetRef<Call>(call_node);
+
     // We first to remap the args to the current vars according to the var_remap_
-    new_call->args = std::move(RemapArgs(call_node->args));
+    new_call.CopyOnWrite()->args = RemapArgs(new_call->args);
+
     // Then we rewrite the args according to the policy
+    std::optional<DataType> opt_new_dtype = std::nullopt;
+
     if (policy == kAlways) {
-      to = fp16_;
+      opt_new_dtype = fp16_;
       auto attr_map = Op::GetAttrMap<FInferMixedPrecision>("FInferMixedPrecision");
       ICHECK(attr_map.count(op));
-      auto f = attr_map[op];
-      new_call = make_object<CallNode>(*(f(Call(new_call), output_dtype_).get()));
+      new_call = attr_map[op](new_call, output_dtype_);
     } else if (policy == kFollow) {
-      to = AllFP16Castable(new_call->args) ? fp16_ : fp32_;
+      opt_new_dtype = AllFP16Castable(new_call->args) ? fp16_ : fp32_;
     } else if (policy == kNever) {
-      to = fp32_;
+      // An upstream operation may have changed the datatype of the
+      // arguments.  Because this operation must be provided with
+      // exactly the same dtype as it previously had, it may require a
+      // cast back to the original datatype.
+
+      if (!new_call->args.same_as(call_node->args)) {
+        Array<Expr> new_typed_args;
+        for (size_t i = 0; i < call_node->args.size(); i++) {
+          auto arg = new_call->args[i];
+          auto old_ntype = NTypeFrom(call_node->args[i]);
+          new_typed_args.push_back(RewriteExpr(arg, old_ntype));
+        }
+        new_call.CopyOnWrite()->args = new_typed_args;
+      }
+
     } else {
       LOG(FATAL) << "Unsupported TMixedPrecisionPolicy: " << policy;
     }
-    new_call->args = std::move(RewriteArgs(new_call->args, to));
-    new_call->struct_info_ = NullOpt;
-    Expr new_value = builder_->Normalize(Call(new_call));
-    if (policy == kAlways && binding->var->IsInstance<DataflowVarNode>()) {
-      // kAlways: store the tensors to fp16
-      // But global vars will be stored to the original dtype anyway (see below)
-      new_value = RewriteExpr(new_value, NTypeFrom(new_value, fp16_));
-    }
-    if (!binding->var->IsInstance<DataflowVarNode>()) {
-      // Global var: store the tensors to the original dtype
-      NType to = NTypeFrom(binding->var);
-      new_value = RewriteExpr(new_value, to);
+
+    Expr new_value = new_call;
+    if (opt_new_dtype) {
+      auto new_dtype = opt_new_dtype.value();
+      new_call.CopyOnWrite()->args = RewriteArgs(new_call->args, new_dtype);
+      new_call.CopyOnWrite()->struct_info_ = NullOpt;
+
+      new_value = builder_->Normalize(Call(new_call));
+
+      if (!binding->var->IsInstance<DataflowVarNode>()) {
+        // Non-Dataflow var: store the tensors to the original dtype
+        new_value = RewriteExpr(new_value, NTypeFrom(binding->var));
+      } else if (policy == kAlways && binding->var->IsInstance<DataflowVarNode>()) {
+        // kAlways: store the tensors to fp16
+        // But non-dataflow vars will be stored to the original dtype anyway (see above)
+        new_value = RewriteExpr(new_value, NTypeFrom(new_value, new_dtype));
+      }
     }
+
     ReEmitBinding(binding, builder_->Normalize(new_value));
   }
 
diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py
index 4ddf47b462..ed10fc95c7 100644
--- a/tests/python/relax/test_transform_to_mixed_precision.py
+++ b/tests/python/relax/test_transform_to_mixed_precision.py
@@ -20,7 +20,7 @@ import tvm
 from tvm import relax
 import tvm.testing
 from tvm.relax.transform import ToMixedPrecision
-from tvm.script.parser import ir as I, relax as R
+from tvm.script.parser import ir as I, relax as R, tir as T
 
 
 def _assert_test(input, expected=None, expected2=None):
@@ -614,8 +614,8 @@ def test_conv2d_softmax():
             x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((3, 3, 3, 3), "float32")
         ) -> R.Tensor(None, "float32", ndim=4):
             with R.dataflow():
-                gv: R.Tensor((2, 3, 26, 26), "float32") = R.nn.conv2d(x, w, padding=(1, 1))
-                gv1: R.Tensor((2, 3, 26, 26), "float32") = R.nn.softmax(x, axis=1)
+                gv: R.Tensor((2, 3, 28, 28), "float32") = R.nn.conv2d(x, w, padding=(1, 1))
+                gv1: R.Tensor((2, 3, 28, 28), "float32") = R.nn.softmax(x, axis=1)
                 gv2 = R.add(gv, gv1)
                 R.output(gv2)
             return gv2
@@ -1036,5 +1036,33 @@ def test_convert_sig():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_call_tir_with_float16_args():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor([64], "float16")):
+            cls = Before
+            with R.dataflow():
+                B = R.call_tir(cls.tir_identity, [A], out_sinfo=R.Tensor([64], "float16"))
+                C = R.call_tir(cls.tir_identity, [B], out_sinfo=R.Tensor([64], "float16"))
+                R.output(C)
+            return C
+
+        @T.prim_func
+        def tir_identity(
+            Input: T.Buffer(64, "float16"),
+            Output: T.Buffer(64, "float16"),
+        ):
+            for i in range(64):
+                with T.block("copy"):
+                    vi = T.axis.remap("S", [i])
+                    Output[vi] = Input[vi]
+
+    Expected = Before
+
+    After = ToMixedPrecision()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()