This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3c7adfb1f7 Use `packaging.version.parse` instead of `distutils.version.LooseVersion` (#17173)
3c7adfb1f7 is described below

commit 3c7adfb1f7015078903ba53cc5317ead1b4f5f32
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sat Jul 20 04:00:01 2024 +0900

    Use `packaging.version.parse` instead of `distutils.version.LooseVersion` (#17173)
    
    use `packaging.version.parse` instead of `distutils.version.LooseVersion`
---
 python/tvm/contrib/msc/core/utils/info.py             |  6 +++---
 python/tvm/relay/frontend/pytorch_utils.py            |  4 ++--
 python/tvm/relay/op/contrib/ethosn.py                 |  6 +++---
 python/tvm/relay/testing/tflite.py                    |  4 ++--
 .../contrib/test_arm_compute_lib/test_network.py      |  4 ++--
 tests/python/frontend/tensorflow/test_forward.py      |  9 ++++-----
 tests/python/frontend/tflite/test_forward.py          | 19 +++++++++----------
 7 files changed, 25 insertions(+), 27 deletions(-)

diff --git a/python/tvm/contrib/msc/core/utils/info.py b/python/tvm/contrib/msc/core/utils/info.py
index 4fea45f8fa..58b0811279 100644
--- a/python/tvm/contrib/msc/core/utils/info.py
+++ b/python/tvm/contrib/msc/core/utils/info.py
@@ -17,7 +17,7 @@
 """tvm.contrib.msc.core.utils.info"""
 
 from typing import List, Tuple, Dict, Any, Union
-from distutils.version import LooseVersion
+from packaging.version import parse
 import numpy as np
 
 import tvm
@@ -409,8 +409,8 @@ def get_version(framework: str) -> List[int]:
             raw_version = "1.0.0"
     except:  # pylint: disable=bare-except
         raw_version = "1.0.0"
-    raw_version = raw_version or "1.0.0"
-    return LooseVersion(raw_version).version
+    version = parse(raw_version or "1.0.0")
+    return [version.major, version.minor, version.micro]
 
 
 def compare_version(given_version: List[int], target_version: List[int]) -> int:
diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py
index 7de1248bda..8686be4b1e 100644
--- a/python/tvm/relay/frontend/pytorch_utils.py
+++ b/python/tvm/relay/frontend/pytorch_utils.py
@@ -36,7 +36,7 @@ def is_version_greater_than(ver):
     than the one given as an argument.
     """
     import torch
-    from distutils.version import LooseVersion
+    from packaging.version import parse
 
     torch_ver = torch.__version__
     # PT version numbers can include +cu[cuda version code]
@@ -44,7 +44,7 @@ def is_version_greater_than(ver):
     if "+cu" in torch_ver:
         torch_ver = torch_ver.split("+cu")[0]
 
-    return LooseVersion(torch_ver) > ver
+    return parse(torch_ver) > parse(ver)
 
 
 def getattr_attr_name(node):
diff --git a/python/tvm/relay/op/contrib/ethosn.py b/python/tvm/relay/op/contrib/ethosn.py
index 81534d48a2..c1e87ad5d9 100644
--- a/python/tvm/relay/op/contrib/ethosn.py
+++ b/python/tvm/relay/op/contrib/ethosn.py
@@ -17,7 +17,7 @@
 # pylint: disable=invalid-name, unused-argument
 """Arm(R) Ethos(TM)-N NPU supported operators."""
 from enum import Enum
-from distutils.version import LooseVersion
+from packaging.version import parse
 
 import tvm.ir
 from tvm.relay import transform
@@ -118,7 +118,7 @@ def partition_for_ethosn(mod, params=None, **opts):
     """
     api_version = ethosn_api_version()
     supported_api_versions = ["3.2.0"]
-    if all(api_version != LooseVersion(exp_ver) for exp_ver in supported_api_versions):
+    if all(parse(api_version) != parse(exp_ver) for exp_ver in supported_api_versions):
         raise ValueError(
             f"Driver stack version {api_version} is unsupported. "
             f"Please use version in {supported_api_versions}."
@@ -433,7 +433,7 @@ def split(expr):
     """Check if a split is supported by Ethos-N."""
     if not ethosn_available():
         return False
-    if ethosn_api_version() == LooseVersion("3.0.1"):
+    if parse(ethosn_api_version()) == parse("3.0.1"):
         return False
     if not _ethosn.split(expr):
         return False
diff --git a/python/tvm/relay/testing/tflite.py b/python/tvm/relay/testing/tflite.py
index df9c0bcadf..29f6bc62ca 100644
--- a/python/tvm/relay/testing/tflite.py
+++ b/python/tvm/relay/testing/tflite.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Common utilities for creating TFLite models"""
-from distutils.version import LooseVersion
+from packaging.version import parse
 import numpy as np
 import pytest
 import tflite.Model  # pylint: disable=wrong-import-position
@@ -134,7 +134,7 @@ class TFLiteModel:
         assert self.serial_model is not None, "TFLite model was not created."
 
         output_tolerance = None
-        if tf.__version__ < LooseVersion("2.5.0"):
+        if parse(tf.__version__) < parse("2.5.0"):
             output_tolerance = 1
             interpreter = tf.lite.Interpreter(model_content=self.serial_model)
         else:
diff --git a/tests/python/contrib/test_arm_compute_lib/test_network.py b/tests/python/contrib/test_arm_compute_lib/test_network.py
index 3cf81e971f..8c6302abf8 100644
--- a/tests/python/contrib/test_arm_compute_lib/test_network.py
+++ b/tests/python/contrib/test_arm_compute_lib/test_network.py
@@ -16,7 +16,7 @@
 # under the License.
 """Arm Compute Library network tests."""
 
-from distutils.version import LooseVersion
+from packaging.version import parse
 
 import numpy as np
 import pytest
@@ -137,7 +137,7 @@ def test_mobilenet():
         mod, params = _get_keras_model(mobilenet, inputs)
         return mod, params, inputs
 
-    if keras.__version__ < LooseVersion("2.9"):
+    if parse(keras.__version__) < parse("2.9"):
         # This can be removed after we migrate to TF/Keras >= 2.9
         expected_tvm_ops = 56
         expected_acl_partitions = 31
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index db270ccb2e..354ed38a62 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -21,7 +21,6 @@ Tensorflow testcases
 This article is a test script to test tensorflow operator with Relay.
 """
 from __future__ import print_function
-from distutils.version import LooseVersion
 
 import threading
 import platform
@@ -1755,7 +1754,7 @@ def _test_concat_v2(shape1, shape2, dim):
 
 
 def test_forward_concat_v2():
-    if tf.__version__ < LooseVersion("1.4.1"):
+    if package_version.parse(tf.__version__) < package_version.parse("1.4.1"):
         return
 
     _test_concat_v2([2, 3], [2, 3], 0)
@@ -3128,7 +3127,7 @@ def _test_forward_clip_by_value(ip_shape, clip_value_min, clip_value_max, dtype)
 
 def test_forward_clip_by_value():
     """test ClipByValue op"""
-    if tf.__version__ < LooseVersion("1.9"):
+    if package_version.parse(tf.__version__) < package_version.parse("1.9"):
         _test_forward_clip_by_value((4,), 0.1, 5.0, "float32")
         _test_forward_clip_by_value((4, 4), 1, 5, "int32")
 
@@ -4482,7 +4481,7 @@ def _test_forward_zeros_like(in_shape, dtype):
 
 
 def test_forward_zeros_like():
-    if tf.__version__ < LooseVersion("1.2"):
+    if package_version.parse(tf.__version__) < package_version.parse("1.2"):
         _test_forward_zeros_like((2, 3), "int32")
         _test_forward_zeros_like((2, 3, 5), "int8")
         _test_forward_zeros_like((2, 3, 5, 7), "uint16")
@@ -5566,7 +5565,7 @@ def test_forward_spop():
     # This test is expected to fail in TF version >= 2.6
     # as the generated graph will be considered frozen, hence
     # not passing the criteria for the test below.
-    if tf.__version__ < LooseVersion("2.6.1"):
+    if package_version.parse(tf.__version__) < package_version.parse("2.6.1"):
         _test_spop_resource_variables()
 
     # Placeholder test cases
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 75a2a37c63..cb0b17ea3f 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -22,7 +22,6 @@ This article is a test script to test TFLite operator with Relay.
 """
 from __future__ import print_function
 from functools import partial
-from distutils.version import LooseVersion
 import platform
 import os
 import tempfile
@@ -1054,7 +1053,7 @@ def _test_tflite2_quantized_convolution(
     input_node = subgraph.Tensors(model_input).Name().decode("utf-8")
 
     tflite_output = run_tflite_graph(tflite_model_quant, data)
-    if tf.__version__ < LooseVersion("2.9"):
+    if package_version.parse(tf.__version__) < package_version.parse("2.9"):
         input_node = data_in.name.replace(":0", "")
     else:
         input_node = "serving_default_" + data_in.name + ":0"
@@ -1775,7 +1774,7 @@ def _test_tflite2_quantized_transpose_conv(
 
     tflite_output = run_tflite_graph(tflite_model_quant, data)
 
-    if tf.__version__ < LooseVersion("2.9"):
+    if package_version.parse(tf.__version__) < package_version.parse("2.9"):
         input_node = data_in.name.replace(":0", "")
     else:
         input_node = "serving_default_" + data_in.name + ":0"
@@ -2219,9 +2218,9 @@ def _test_abs(data, quantized, int_quant_dtype=tf.int8):
         tflite_output = run_tflite_graph(tflite_model_quant, data)
 
         # TFLite 2.6.x upgrade support
-        if tf.__version__ < LooseVersion("2.6.1"):
+        if package_version.parse(tf.__version__) < package_version.parse("2.6.1"):
             in_node = ["serving_default_input_int8"]
-        elif tf.__version__ < LooseVersion("2.9"):
+        elif package_version.parse(tf.__version__) < package_version.parse("2.9"):
             in_node = (
                 ["serving_default_input_int16"] if int_quant_dtype == tf.int16 else ["tfl.quantize"]
             )
@@ -2245,7 +2244,7 @@ def _test_rsqrt(data, quantized, int_quant_dtype=tf.int8):
     """One iteration of rsqrt"""
 
     # tensorflow version upgrade support
-    if tf.__version__ < LooseVersion("2.6.1") or not quantized:
+    if package_version.parse(tf.__version__) < package_version.parse("2.6.1") or not quantized:
         return _test_unary_elemwise(
             math_ops.rsqrt, data, quantized, quant_range=[1, 6], int_quant_dtype=int_quant_dtype
         )
@@ -2254,7 +2253,7 @@ def _test_rsqrt(data, quantized, int_quant_dtype=tf.int8):
             tf.math.rsqrt, data, int_quant_dtype=int_quant_dtype
         )
         tflite_output = run_tflite_graph(tflite_model_quant, data)
-        if tf.__version__ < LooseVersion("2.9"):
+        if package_version.parse(tf.__version__) < package_version.parse("2.9"):
             in_node = ["tfl.quantize"]
         else:
             in_node = "serving_default_input"
@@ -2338,7 +2337,7 @@ def _test_cos(data, quantized, int_quant_dtype=tf.int8):
             tf.math.cos, data, int_quant_dtype=int_quant_dtype
         )
         tflite_output = run_tflite_graph(tflite_model_quant, data)
-        if tf.__version__ < LooseVersion("2.9"):
+        if package_version.parse(tf.__version__) < package_version.parse("2.9"):
             in_node = ["tfl.quantize"]
         else:
             in_node = "serving_default_input"
@@ -3396,7 +3395,7 @@ def _test_quantize_dequantize(data):
     tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True)
 
     tflite_output = run_tflite_graph(tflite_model_quant, data)
-    if tf.__version__ < LooseVersion("2.9"):
+    if package_version.parse(tf.__version__) < package_version.parse("2.9"):
         in_node = data_in.name.split(":")[0]
     else:
         in_node = "serving_default_" + data_in.name + ":0"
@@ -3426,7 +3425,7 @@ def _test_quantize_dequantize_const(data):
     tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True)
 
     tflite_output = run_tflite_graph(tflite_model_quant, data)
-    if tf.__version__ < LooseVersion("2.9"):
+    if package_version.parse(tf.__version__) < package_version.parse("2.9"):
         in_node = data_in.name.split(":")[0]
     else:
         in_node = "serving_default_" + data_in.name + ":0"

Reply via email to