This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 7c4c2c2f42 [Disco][Fix] Remove Dependency to PyTest (#15886)
7c4c2c2f42 is described below
commit 7c4c2c2f427c458095dac82becd090c99faeb73a
Author: Junru Shao <[email protected]>
AuthorDate: Sat Oct 7 09:09:30 2023 -0700
[Disco][Fix] Remove Dependency to PyTest (#15886)
The Disco worker originally imported `tvm.testing.disco` automatically
for convenient unit testing. However, `tvm.testing` is a special subpackage
that pulls in many unnecessary dependencies, for example pytest. This
PR removes those dependencies by moving the test-function registration
logic directly into the worker entry file.
---
python/tvm/exec/disco_worker.py | 37 ++++++++++++++++++++++++++--
python/tvm/testing/__init__.py | 2 +-
python/tvm/testing/disco.py | 53 -----------------------------------------
3 files changed, 36 insertions(+), 56 deletions(-)
diff --git a/python/tvm/exec/disco_worker.py b/python/tvm/exec/disco_worker.py
index 9faa5742ae..b5eea6328d 100644
--- a/python/tvm/exec/disco_worker.py
+++ b/python/tvm/exec/disco_worker.py
@@ -20,8 +20,41 @@ import os
import sys
from tvm import runtime as _ # pylint: disable=unused-import
-from tvm._ffi import get_global_func
-from tvm.testing import disco as _ # pylint: disable=unused-import
+from tvm._ffi import get_global_func, register_func
+from tvm.runtime import NDArray, ShapeTuple, String
+from tvm.runtime.ndarray import array
+
+
+@register_func("tests.disco.add_one")
+def _add_one(x: int) -> int: # pylint: disable=invalid-name
+ return x + 1
+
+
+@register_func("tests.disco.add_one_float", override=True)
+def _add_one_float(x: float): # pylint: disable=invalid-name
+ return x + 0.5
+
+
+@register_func("tests.disco.add_one_ndarray", override=True)
+def _add_one_ndarray(x: NDArray) -> NDArray: # pylint: disable=invalid-name
+ return array(x.numpy() + 1)
+
+
+@register_func("tests.disco.str", override=True)
+def _str_func(x: str): # pylint: disable=invalid-name
+ return x + "_suffix"
+
+
+@register_func("tests.disco.str_obj", override=True)
+def _str_obj_func(x: String): # pylint: disable=invalid-name
+ assert isinstance(x, String)
+ return String(x + "_suffix")
+
+
+@register_func("tests.disco.shape_tuple", override=True)
+def _shape_tuple_func(x: ShapeTuple): # pylint: disable=invalid-name
+ assert isinstance(x, ShapeTuple)
+ return ShapeTuple(list(x) + [4, 5])
def main():
diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py
index 9aa1a31933..041207b66f 100644
--- a/python/tvm/testing/__init__.py
+++ b/python/tvm/testing/__init__.py
@@ -17,7 +17,7 @@
# pylint: disable=redefined-builtin, wildcard-import
"""Utility Python functions for TVM testing"""
-from . import auto_scheduler, autotvm, disco
+from . import auto_scheduler, autotvm
from ._ffi_api import (
ErrorTest,
FrontendTestModule,
diff --git a/python/tvm/testing/disco.py b/python/tvm/testing/disco.py
deleted file mode 100644
index c13e83b7c4..0000000000
--- a/python/tvm/testing/disco.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=invalid-name, missing-function-docstring, missing-class-docstring
-"""Common utilities for testing disco"""
-from tvm._ffi import register_func
-from tvm.runtime import NDArray, ShapeTuple, String
-from tvm.runtime.ndarray import array
-
-
-@register_func("tests.disco.add_one")
-def add_one(x: int) -> int: # pylint: disable=invalid-name
- return x + 1
-
-
-@register_func("tests.disco.add_one_float", override=True)
-def add_one_float(x: float): # pylint: disable=invalid-name
- return x + 0.5
-
-
-@register_func("tests.disco.add_one_ndarray", override=True)
-def add_one_ndarray(x: NDArray) -> NDArray: # pylint: disable=invalid-name
- return array(x.numpy() + 1)
-
-
-@register_func("tests.disco.str", override=True)
-def str_func(x: str): # pylint: disable=invalid-name
- return x + "_suffix"
-
-
-@register_func("tests.disco.str_obj", override=True)
-def str_obj_func(x: String): # pylint: disable=invalid-name
- assert isinstance(x, String)
- return String(x + "_suffix")
-
-
-@register_func("tests.disco.shape_tuple", override=True)
-def shape_tuple_func(x: ShapeTuple): # pylint: disable=invalid-name
- assert isinstance(x, ShapeTuple)
- return ShapeTuple(list(x) + [4, 5])