This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 09a0ea0cd4 [FIX][TVMC] Fix the mixed precision conversion pipeline
09a0ea0cd4 is described below

commit 09a0ea0cd4f6f99342966c06fd371d9b8af84dbe
Author: Siva <[email protected]>
AuthorDate: Sun Jan 26 12:48:24 2025 +0530

    [FIX][TVMC] Fix the mixed precision conversion pipeline
    
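    Fixed the mixed precision conversion pipeline issue: apply_graph_transforms
    now accepts the module params and runs relay.quantize.prerequisite_optimize
    on the module before convert_to_mixed_precision. All callers (the tvmc
    compiler, the autotvm/auto_scheduler task extraction helpers, and the
    tests) now pass params through.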
---
 python/tvm/driver/tvmc/autotuner.py               | 4 ++--
 python/tvm/driver/tvmc/compiler.py                | 2 +-
 python/tvm/driver/tvmc/transform.py               | 5 ++++-
 tests/python/driver/tvmc/test_transform.py        | 2 ++
 tests/python/relay/opencl_texture/test_network.py | 1 +
 5 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py
index 82b5cc1598..ad4f8ae616 100644
--- a/python/tvm/driver/tvmc/autotuner.py
+++ b/python/tvm/driver/tvmc/autotuner.py
@@ -672,7 +672,7 @@ def autotvm_get_tuning_tasks(
     """
     target, target_host = Target.canon_target_and_host(target, target_host)
 
-    mod = apply_graph_transforms(mod, transform_args)
+    mod = apply_graph_transforms(mod, transform_args, params)
 
     tasks = autotvm.task.extract_from_program(
         mod["main"],
@@ -718,7 +718,7 @@ def autoscheduler_get_tuning_tasks(
     """
     target, target_host = Target.canon_target_and_host(target, target_host)
 
-    mod = apply_graph_transforms(mod, transform_args)
+    mod = apply_graph_transforms(mod, transform_args, params)
 
     # Extract the tasks
     tasks, task_weights = auto_scheduler.extract_tasks(
diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py
index 43c53e8859..058ae62d18 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -401,7 +401,7 @@ def compile_model(
         instruments=instruments,
     ):
         transform_args = parse_graph_transform_args(locals())
-        mod = apply_graph_transforms(mod, transform_args)
+        mod = apply_graph_transforms(mod, transform_args, params)
 
         for partition_function, opts in zip(partition_functions, partition_opts):
             mod = partition_function(mod, params, mod_name=mod_name, **opts)
diff --git a/python/tvm/driver/tvmc/transform.py b/python/tvm/driver/tvmc/transform.py
index 30d9bfa639..253c624e6e 100644
--- a/python/tvm/driver/tvmc/transform.py
+++ b/python/tvm/driver/tvmc/transform.py
@@ -162,7 +162,7 @@ def convert_graph_layout(mod, desired_layouts, ops=None):
         raise TVMCException("Error converting layouts: {}".format(str(err)))
 
 
-def apply_graph_transforms(mod, args):
+def apply_graph_transforms(mod, args, params=None):
     """Alter the layout of the input graph.
 
     Parameters
@@ -171,6 +171,8 @@ def apply_graph_transforms(mod, args):
         The relay module to convert.
     args : dict
         The transform arguments.
+    params: dict
+        Module params
 
     Returns
     -------
@@ -188,6 +190,7 @@ def apply_graph_transforms(mod, args):
 
     # ToMixedPrecision
     if args.get("mixed_precision", False):
+        mod = relay.quantize.prerequisite_optimize(mod, params)
         mod = convert_to_mixed_precision(
             mod,
             args.get("mixed_precision_ops"),
diff --git a/tests/python/driver/tvmc/test_transform.py b/tests/python/driver/tvmc/test_transform.py
index 06af3cb156..ebf067990d 100644
--- a/tests/python/driver/tvmc/test_transform.py
+++ b/tests/python/driver/tvmc/test_transform.py
@@ -226,6 +226,7 @@ def test_layout_transform_to_mixed_precision_pass_args_graph():
             "mixed_precision_calculation_type": "float16",
             "mixed_precision_acc_type": "float16",
         },
+        params,
     )
     ret = CheckOpMutator("float16", "float16", "nn.conv2d").check(mod["main"])
     assert ret
@@ -240,6 +241,7 @@ def test_layout_transform_to_mixed_precision_pass_args_graph():
             "mixed_precision_calculation_type": "float16",
             "mixed_precision_acc_type": "float32",
         },
+        params,
     )
     ret = CheckOpMutator("float16", "float32", "nn.conv2d").check(mod["main"])
     assert ret
diff --git a/tests/python/relay/opencl_texture/test_network.py b/tests/python/relay/opencl_texture/test_network.py
index 2b2f3741cb..66c88ebbe2 100644
--- a/tests/python/relay/opencl_texture/test_network.py
+++ b/tests/python/relay/opencl_texture/test_network.py
@@ -47,6 +47,7 @@ def _test_mobilenet_v1(remote, target, calc_dtype, executor_type, acc_dtype):
                 "mixed_precision_calculation_type": calc_dtype,
                 "mixed_precision_acc_type": acc_dtype,
             },
+            params,
         )
 
     if executor_type == "ge":

Reply via email to