This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 09a0ea0cd4 [FIX][TVMC] Fix the mixed precision conversion pipeline
09a0ea0cd4 is described below
commit 09a0ea0cd4f6f99342966c06fd371d9b8af84dbe
Author: Siva <[email protected]>
AuthorDate: Sun Jan 26 12:48:24 2025 +0530
[FIX][TVMC] Fix the mixed precision conversion pipeline
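Fixed the mixed precision conversion pipeline: apply_graph_transforms now accepts an optional params argument and runs relay.quantize.prerequisite_optimize(mod, params) before convert_to_mixed_precision. The autotuner, compiler, and related tests now pass the model params through to apply_graph_transforms.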
---
python/tvm/driver/tvmc/autotuner.py | 4 ++--
python/tvm/driver/tvmc/compiler.py | 2 +-
python/tvm/driver/tvmc/transform.py | 5 ++++-
tests/python/driver/tvmc/test_transform.py | 2 ++
tests/python/relay/opencl_texture/test_network.py | 1 +
5 files changed, 10 insertions(+), 4 deletions(-)
diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py
index 82b5cc1598..ad4f8ae616 100644
--- a/python/tvm/driver/tvmc/autotuner.py
+++ b/python/tvm/driver/tvmc/autotuner.py
@@ -672,7 +672,7 @@ def autotvm_get_tuning_tasks(
"""
target, target_host = Target.canon_target_and_host(target, target_host)
- mod = apply_graph_transforms(mod, transform_args)
+ mod = apply_graph_transforms(mod, transform_args, params)
tasks = autotvm.task.extract_from_program(
mod["main"],
@@ -718,7 +718,7 @@ def autoscheduler_get_tuning_tasks(
"""
target, target_host = Target.canon_target_and_host(target, target_host)
- mod = apply_graph_transforms(mod, transform_args)
+ mod = apply_graph_transforms(mod, transform_args, params)
# Extract the tasks
tasks, task_weights = auto_scheduler.extract_tasks(
diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py
index 43c53e8859..058ae62d18 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -401,7 +401,7 @@ def compile_model(
instruments=instruments,
):
transform_args = parse_graph_transform_args(locals())
- mod = apply_graph_transforms(mod, transform_args)
+ mod = apply_graph_transforms(mod, transform_args, params)
for partition_function, opts in zip(partition_functions,
partition_opts):
mod = partition_function(mod, params, mod_name=mod_name, **opts)
diff --git a/python/tvm/driver/tvmc/transform.py b/python/tvm/driver/tvmc/transform.py
index 30d9bfa639..253c624e6e 100644
--- a/python/tvm/driver/tvmc/transform.py
+++ b/python/tvm/driver/tvmc/transform.py
@@ -162,7 +162,7 @@ def convert_graph_layout(mod, desired_layouts, ops=None):
raise TVMCException("Error converting layouts: {}".format(str(err)))
-def apply_graph_transforms(mod, args):
+def apply_graph_transforms(mod, args, params=None):
"""Alter the layout of the input graph.
Parameters
@@ -171,6 +171,8 @@ def apply_graph_transforms(mod, args):
The relay module to convert.
args : dict
The transform arguments.
+ params: dict
+ Module params
Returns
-------
@@ -188,6 +190,7 @@ def apply_graph_transforms(mod, args):
# ToMixedPrecision
if args.get("mixed_precision", False):
+ mod = relay.quantize.prerequisite_optimize(mod, params)
mod = convert_to_mixed_precision(
mod,
args.get("mixed_precision_ops"),
diff --git a/tests/python/driver/tvmc/test_transform.py b/tests/python/driver/tvmc/test_transform.py
index 06af3cb156..ebf067990d 100644
--- a/tests/python/driver/tvmc/test_transform.py
+++ b/tests/python/driver/tvmc/test_transform.py
@@ -226,6 +226,7 @@ def test_layout_transform_to_mixed_precision_pass_args_graph():
"mixed_precision_calculation_type": "float16",
"mixed_precision_acc_type": "float16",
},
+ params,
)
ret = CheckOpMutator("float16", "float16", "nn.conv2d").check(mod["main"])
assert ret
@@ -240,6 +241,7 @@ def test_layout_transform_to_mixed_precision_pass_args_graph():
"mixed_precision_calculation_type": "float16",
"mixed_precision_acc_type": "float32",
},
+ params,
)
ret = CheckOpMutator("float16", "float32", "nn.conv2d").check(mod["main"])
assert ret
diff --git a/tests/python/relay/opencl_texture/test_network.py b/tests/python/relay/opencl_texture/test_network.py
index 2b2f3741cb..66c88ebbe2 100644
--- a/tests/python/relay/opencl_texture/test_network.py
+++ b/tests/python/relay/opencl_texture/test_network.py
@@ -47,6 +47,7 @@ def _test_mobilenet_v1(remote, target, calc_dtype, executor_type, acc_dtype):
"mixed_precision_calculation_type": calc_dtype,
"mixed_precision_acc_type": acc_dtype,
},
+ params,
)
if executor_type == "ge":