masahi commented on code in PR #12895:
URL: https://github.com/apache/tvm/pull/12895#discussion_r988620495
##########
python/tvm/meta_schedule/relay_integration.py:
##########
@@ -15,28 +15,101 @@
# specific language governing permissions and limitations
# under the License.
"""MetaSchedule-Relay integration"""
-from typing import Any, Dict, List, Optional
+from contextlib import contextmanager
+from types import MappingProxyType
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
Union
+# isort: off
+from typing_extensions import Literal
+
+# isort: on
import numpy as np # type: ignore
from tvm import nd
from tvm._ffi import get_global_func
from tvm.ir import IRModule, transform
from tvm.runtime import NDArray
from tvm.target import Target
+from .builder import Builder
+from .cost_model import CostModel
+from .database import Database
from .extracted_task import ExtractedTask
-from .utils import autotvm_silencer
+from .logging import get_loggers_from_work_dir
+from .measure_callback import MeasureCallback
+from .profiler import Profiler
+from .runner import Runner
+from .search_strategy import SearchStrategy
+from .space_generator import SpaceGenerator
+from .task_scheduler import TaskScheduler
+from .tune import tune_tasks
+from .tune_context import TuneContext
+from .utils import fork_seed
+
+if TYPE_CHECKING:
+ from tvm import relay
+
+# Handle to the "relay.backend.MetaScheduleExtractTask" global function,
+# looked up once when this module is imported.
+# NOTE(review): with allow_missing=True the lookup presumably returns None
+# when the function is not registered -- confirm that callers check for it.
+_extract_task = get_global_func( # pylint: disable=invalid-name
+ "relay.backend.MetaScheduleExtractTask",
+ allow_missing=True,
+)
+
+
+@contextmanager
+def _autotvm_silencer():
+ """A context manager that silences autotvm warnings.
+
+ Forces ``autotvm.GLOBAL_SCOPE.silent`` to True for the duration of the
+ ``with`` block and restores the caller's previous value on exit, including
+ when the block raises.
+ """
+ # Import is deferred until the context manager is actually entered.
+ from tvm import autotvm # pylint: disable=import-outside-toplevel
+
+ # Remember the caller's setting so it can be restored in ``finally``.
+ silent = autotvm.GLOBAL_SCOPE.silent
+ autotvm.GLOBAL_SCOPE.silent = True
+ try:
+ yield
+ finally:
+ autotvm.GLOBAL_SCOPE.silent = silent
+
+
+def _normalize_params(
+ mod: IRModule,
+ target: Union[Target, str],
+ params: Optional[Dict[str, NDArray]],
+ pass_config: Mapping[str, Any],
+ executor: Optional["relay.backend.Executor"],
+) -> Tuple[IRModule, Target, Dict[str, NDArray], Dict[str, Any],
"relay.backend.Executor"]:
+ from tvm import relay # pylint: disable=import-outside-toplevel
+
+ if isinstance(mod, relay.Function):
+ mod = IRModule.from_expr(mod)
+ if not isinstance(target, Target):
+ target = Target(target)
+ if params is None:
+ params = {}
+ relay_params = {}
+ for name, param in params.items():
+ if isinstance(param, np.ndarray):
+ param = nd.array(param)
+ relay_params[name] = param
+ if executor is None:
+ executor = relay.backend.Executor("graph")
+ pass_config = dict(pass_config)
+ pass_config.setdefault(
+ "relay.FuseOps.link_params",
+ executor.attrs.get("link_params", False),
+ )
Review Comment:
Instead of modifying `pass_config`, we can use `mod.with_attr("executor",
executor)`, as is done in
https://github.com/apache/tvm/blob/d4bf9ecf5524d265916ac7b860b0027f5eee5c49/src/relay/backend/build_module.cc#L414-L415.
This lets us align the `link-params` attribute lookup in task extraction and
`relay.build` (instead of relying on the ad hoc config
`relay.FuseOps.link_params` for the former).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]