This is an automated email from the ASF dual-hosted git repository.

cbalint13 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dcecb86291 [Docs] Improve static shape tuning parameter configuration 
(follow-up to commit c71aefc) (#18545)
dcecb86291 is described below

commit dcecb862916314fffb1eb974d61da8f485a259a3
Author: ConvolutedDog <[email protected]>
AuthorDate: Sun Dec 7 06:28:46 2025 +0800

    [Docs] Improve static shape tuning parameter configuration (follow-up to 
commit c71aefc) (#18545)
    
    - Expose max_trials_per_task parameter to static_shape_tuning_pipeline
    - Adjust default TOTAL_TRIALS from 8000 to 512 for tutorial demonstration
    purposes
    - Add documentation for tuning parameters in tutorial, clarifying
    relationship between MAX_TRIALS_PER_TASK and TOTAL_TRIALS
---
 docs/how_to/tutorials/e2e_opt_model.py | 29 +++++++++++++++++++++++++++--
 python/tvm/relax/pipeline.py           | 21 +++++++++++++++++++--
 2 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/docs/how_to/tutorials/e2e_opt_model.py 
b/docs/how_to/tutorials/e2e_opt_model.py
index 8307ddc4f2..507864160d 100644
--- a/docs/how_to/tutorials/e2e_opt_model.py
+++ b/docs/how_to/tutorials/e2e_opt_model.py
@@ -95,13 +95,38 @@ if not IS_IN_CI:
 # leverage MetaSchedule to tune the model and store the tuning logs to the 
database. We also
 # apply the database to the model to get the best performance.
 #
+# The ResNet18 model will be divided into 20 independent tuning tasks during 
compilation.
+# To ensure each task receives adequate tuning resources in one iteration 
while providing
+# early feedback:
+#
+# - To quickly observe tuning progress, each task is allocated a maximum of 16 
trials per
+#   iteration (controlled by ``MAX_TRIALS_PER_TASK=16``). Setting ``TOTAL_TRIALS``
+#   to at least ``320`` (20 tasks * 16 trials) ensures every task receives one full iteration
+#   of tuning. We set it to 512 in our configuration to allow for several more 
iterations,
+#   aiming to explore a wider parameter space and potentially achieve better 
performance.
+# - If ``MAX_TRIALS_PER_TASK`` is ``None``, the system defaults to ``TOTAL_TRIALS`` trials per
+#   task per iteration. An insufficient ``TOTAL_TRIALS`` setting may lead to 
undersubscribed
+#   tuning, potentially skipping some tasks entirely. Explicitly setting both 
parameters
+#   avoids this issue and provides deterministic resource allocation across 
all tasks.
+#
+# Note: These parameter settings are optimized for quick tutorial 
demonstration. For production
+# deployments requiring higher performance, we recommend adjusting both 
``MAX_TRIALS_PER_TASK``
+# and ``TOTAL_TRIALS`` to larger values. This allows more extensive search 
space exploration
+# and typically yields better performance outcomes.
 
-TOTAL_TRIALS = 8000  # Change to 20000 for better performance if needed
+TOTAL_TRIALS = 512  # Change to 20000 for better performance if needed
+MAX_TRIALS_PER_TASK = 16  # Change to more trials per task for better 
performance if needed
 target = tvm.target.Target("nvidia/geforce-rtx-3090-ti")  # Change to your 
target device
 work_dir = "tuning_logs"
 
 if not IS_IN_CI:
-    mod = relax.get_pipeline("static_shape_tuning", target=target, 
total_trials=TOTAL_TRIALS)(mod)
+    mod = relax.get_pipeline(
+        "static_shape_tuning",
+        target=target,
+        work_dir=work_dir,
+        total_trials=TOTAL_TRIALS,
+        max_trials_per_task=MAX_TRIALS_PER_TASK,
+    )(mod)
 
     # Only show the main function
     mod["main"].show()
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index a5850267a8..388f9dbb43 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -21,7 +21,7 @@ This namespace offers a pre-defined collection that can be 
used
 as it is or serves as a basis to do further composition.
 """
 # pylint: disable=unused-argument
-from typing import Union
+from typing import Union, Optional
 
 import tvm
 from tvm import meta_schedule as ms
@@ -111,6 +111,7 @@ def static_shape_tuning_pipeline(
     target: Union[str, tvm.target.Target],
     work_dir: str = "tuning_logs",
     cpu_weight_prepack: bool = False,
+    max_trials_per_task: Optional[int] = None,
 ):
     """Tune the static shape model and store the log to database.
 
@@ -128,6 +129,16 @@ def static_shape_tuning_pipeline(
     cpu_weight_prepack : bool
         Whether to enable the cpu weight prepack feature.
 
+    max_trials_per_task : Optional[int]
+        The maximum number of trials to run per task.
+        If not specified, it defaults to the value of `total_trials`, and this
+        may lead to undersubscribed tuning, potentially skipping some tasks
+        entirely. Explicitly setting both parameters avoids this issue and
+        provides deterministic resource allocation across all tasks.
+        For optimal tuning, set `total_trials` to at least
+        `max_trials_per_task * number_of_tuning_tasks` to ensure
+        each task receives adequate tuning resources in one iteration.
+
     Note
     ----
     `cpu_weight_prepack` is expected to be `True` when running on CPU for
@@ -142,6 +153,7 @@ def static_shape_tuning_pipeline(
             target="llvm -num-cores 16",
             work_dir="tuning_logs",
             cpu_weight_prepack=True,
+            max_trials_per_task=64,
         )(mod)
 
         ex = tvm.compile(mod, target=target)
@@ -177,7 +189,12 @@ def static_shape_tuning_pipeline(
                     *pre_tuning_layout_rewrite,
                     # Skip tuning if total_trials is 0
                     (
-                        transform.MetaScheduleTuneIRMod({}, work_dir, 
total_trials)
+                        transform.MetaScheduleTuneIRMod(
+                            params={},
+                            work_dir=work_dir,
+                            max_trials_global=total_trials,
+                            max_trials_per_task=max_trials_per_task,
+                        )
                         if total_trials > 0
                         else tvm.transform.Sequential([])
                     ),

Reply via email to