comaniac commented on code in PR #11142:
URL: https://github.com/apache/tvm/pull/11142#discussion_r862270681


##########
src/relay/transforms/to_mixed_precision.cc:
##########
@@ -36,6 +36,7 @@
 namespace tvm {
 namespace relay {
 
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.ToMixedPrecision.enable_original_type", Bool);

Review Comment:
   This name is a bit confusing. Maybe something like `keep_orig_output_dtype`?



##########
src/relay/transforms/to_mixed_precision.cc:
##########
@@ -381,6 +400,11 @@ class MixedPrecisionPass : public MixedModeMutator {
       if (accumulation_dtype != output_dtype) {
         output = CastArg(output, GetType(output), output_dtype);
       }
+      if (pre_call_node == static_cast<const CallNode*>(root_) && enable_original_type_) {

Review Comment:
   Could we just use `.same_as`?



##########
tests/python/relay/test_to_mixed_precision.py:
##########
@@ -41,17 +41,31 @@ def verify_mixed_precision_output_close(
     mixed_precision_dtype="float16",
     rtol: float = 1e-3,
     atol: float = 0,
+    enable_original_type=False,
 ) -> tvm.runtime.Module:
 
     mod = InferType()(mod)
     result_fp32 = run_module(mod, mod_params)
-    fp16_mod = ToMixedPrecision(mixed_precision_dtype)(mod)
-    result_fp16 = run_module(fp16_mod, mod_params)
+
+    if enable_original_type == False:

Review Comment:
   ```suggestion
       if not enable_original_type:
   ```
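For context, PEP 8 recommends against comparing boolean values to `True` or `False` with `==`. The minimal sketch below illustrates the difference. The helpers `pick_branch` and `pick_branch_eq` are hypothetical and not part of the test file. The two forms agree when the flag is a real `bool`, but they diverge for `None`. Only `not` treats `None` as "off".

```python
# Hypothetical helpers (not in test_to_mixed_precision.py) that contrast the
# suggested truthiness test with the original `== False` comparison.


def pick_branch(enable_original_type=False):
    # Suggested form: any falsy value (False, None, 0) takes the default path.
    if not enable_original_type:
        return "default"
    return "original_type"


def pick_branch_eq(enable_original_type=False):
    # Original form: only values equal to False (False, 0, 0.0) take the
    # default path; None does not compare equal to False.
    if enable_original_type == False:  # noqa: E712
        return "default"
    return "original_type"


for flag in (False, True, None):
    print(flag, pick_branch(flag), pick_branch_eq(flag))
```

With a plain `bool` argument, as in the current test helper, both versions behave identically. The suggestion is mainly about idiom and readability.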
   


