This is an automated email from the ASF dual-hosted git repository.
echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c93f0bae9b [Meta-Schedule][OpenCL] Enable MS tuning for Android OpenCL
(#16846)
c93f0bae9b is described below
commit c93f0bae9bf9aa3bd42f3239d4e4a0f2da37ee84
Author: Egor Churaev <[email protected]>
AuthorDate: Fri Apr 5 09:52:41 2024 +0300
[Meta-Schedule][OpenCL] Enable MS tuning for Android OpenCL (#16846)
Added OpenCL as a GPU target for MetaSchedule. Implemented an export
function for Android that can be passed to the MetaSchedule builder
(as `f_export`). Added an integration test which checks that
MetaSchedule tuning on an Android GPU works.
---
python/tvm/contrib/ndk.py | 12 ++++
src/meta_schedule/utils.h | 3 +-
tests/python/contrib/test_android/__init__.py | 18 ++++++
.../python/contrib/test_android/infrastructure.py | 57 +++++++++++++++++
.../contrib/test_android/test_meta_schedule.py | 71 ++++++++++++++++++++++
5 files changed, 160 insertions(+), 1 deletion(-)
diff --git a/python/tvm/contrib/ndk.py b/python/tvm/contrib/ndk.py
index 2a1105ed2b..14820c0ca8 100644
--- a/python/tvm/contrib/ndk.py
+++ b/python/tvm/contrib/ndk.py
@@ -22,7 +22,10 @@ import subprocess
import os
import shutil
from typing import Dict
+import tempfile
+from pathlib import Path
+from .._ffi import register_func
from .._ffi.base import py_str
from . import utils as _utils, tar as _tar, cc as _cc
from .cc import get_target_by_dump_machine
@@ -152,3 +155,12 @@ def get_global_symbol_section_map(path, *, nm=None) -> Dict[str, str]:
base_path = os.path.dirname(compiler)
nm = os.path.join(base_path, "llvm-nm")
return _cc.get_global_symbol_section_map(path, nm=nm)
+
+
+@register_func("meta_schedule.builder.export_ndk")
+def _ndk_export(mod):
+    """Export ``mod`` with ``create_shared`` into a fresh temp dir; return the .so path."""
+    out_dir = Path(tempfile.mkdtemp())
+    lib_path = out_dir / "tmp_binary.so"
+    mod.export_library(lib_path, fcompile=create_shared)
+    return str(lib_path)
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 60840ca163..ceb0356cbc 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -513,7 +513,8 @@ inline void CloneRules(const SpaceGeneratorNode* src, SpaceGeneratorNode* dst) {
/*! \brief Returns true if the given target is one of the supported gpu
targets. */
inline bool IsGPUTarget(const std::string& target_name) {
- static const std::unordered_set<std::string> gpu_targets{"cuda", "rocm",
"vulkan", "metal"};
+ static const std::unordered_set<std::string> gpu_targets{"cuda", "rocm",
"vulkan", "metal",
+ "opencl"};
return gpu_targets.count(target_name);
}
diff --git a/tests/python/contrib/test_android/__init__.py b/tests/python/contrib/test_android/__init__.py
new file mode 100644
index 0000000000..9669578bb7
--- /dev/null
+++ b/tests/python/contrib/test_android/__init__.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Testing infrastructure for Android """
diff --git a/tests/python/contrib/test_android/infrastructure.py b/tests/python/contrib/test_android/infrastructure.py
new file mode 100644
index 0000000000..b78d0bb40e
--- /dev/null
+++ b/tests/python/contrib/test_android/infrastructure.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+
+""" Android testing infrastructure """
+
+import os
+import tvm
+from tvm.meta_schedule.runner import RPCRunner, RPCConfig, EvaluatorConfig
+
+
+def get_rpc_runner() -> tvm.meta_schedule.runner.RPCRunner:
+    """Create a MetaSchedule RPC runner that connects through an RPC tracker.
+
+    The tracker is taken from TVM_TRACKER_HOST / TVM_TRACKER_PORT and the device
+    key from RPC_DEVICE_KEY; a RuntimeError naming any unset variable is raised.
+    """
+    required = ("TVM_TRACKER_HOST", "TVM_TRACKER_PORT", "RPC_DEVICE_KEY")
+    missing = [name for name in required if name not in os.environ]
+    if missing:
+        raise RuntimeError(
+            "Please initialize environment variables for using RPC tracker: "
+            + ", ".join(missing)
+        )
+
+    rpc_config = RPCConfig(
+        tracker_host=os.environ["TVM_TRACKER_HOST"],
+        tracker_port=int(os.environ["TVM_TRACKER_PORT"]),
+        tracker_key=os.environ["RPC_DEVICE_KEY"],
+        session_priority=1,
+        session_timeout_sec=100,
+    )
+    # One run per measurement and no minimum duration, presumably to keep on-device tuning short.
+    evaluator_config = EvaluatorConfig(number=1, repeat=1, min_repeat_ms=0)
+
+    return RPCRunner(rpc_config, evaluator_config)
+
+
+def get_android_gpu_target() -> tvm.target.Target:
+    """Create an OpenCL device target with an LLVM arm64 Android host target."""
+    target_c = "opencl"
+    target_h = "llvm -mtriple=arm64-linux-android"
+    return tvm.target.Target(target_c, host=target_h)
diff --git a/tests/python/contrib/test_android/test_meta_schedule.py b/tests/python/contrib/test_android/test_meta_schedule.py
new file mode 100644
index 0000000000..eac5fab303
--- /dev/null
+++ b/tests/python/contrib/test_android/test_meta_schedule.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test MetaSchedule tuning on Android through RPC."""
+import tempfile
+
+import numpy as np
+import pytest
+import tvm.testing
+import tvm.topi.testing
+from tvm import meta_schedule as ms
+from tvm.meta_schedule.builder import LocalBuilder
+from tvm.script import tir as T
+
+from .infrastructure import get_android_gpu_target, get_rpc_runner
+
+
+@T.prim_func
+def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128])
+    B = T.match_buffer(b, [128, 128])
+    C = T.match_buffer(c, [128, 128])  # computes C = A @ B.T (B is indexed as [j, k])
+    for i, j, k in T.grid(128, 128, 128):
+        with T.block("update"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                C[vi, vj] = 0.0
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@pytest.mark.skip("Integration test")
+def test_tune_tir_on_android():
+    """Tune a TIR matmul on an Android GPU through RPC; fail if no schedule is found."""
+    max_workers = 4
+    # Build candidates as Android shared libraries via the NDK export function.
+    builder = LocalBuilder(f_export="meta_schedule.builder.export_ndk", max_workers=max_workers)
+    runner = get_rpc_runner()
+    target = get_android_gpu_target()
+    with tempfile.TemporaryDirectory() as work_dir:
+        database = ms.tir_integration.tune_tir(
+            mod=matmul,
+            target=target,
+            work_dir=work_dir,
+            max_trials_global=32,
+            num_trials_per_iter=16,
+            builder=builder,
+            runner=runner,
+        )
+        sch = ms.tir_integration.compile_tir(database, matmul, target)
+        # Fail loudly instead of silently passing when tuning produced nothing usable.
+        assert sch is not None, "No valid schedule found!"
+        sch.mod.show()
+        sch.trace.show()
+
+
+if __name__ == "__main__":
+    tvm.testing.main()