This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3c7adfb1f7 Use `packaging.version.parse` instead of `distutils.version.LooseVersion` (#17173)
3c7adfb1f7 is described below
commit 3c7adfb1f7015078903ba53cc5317ead1b4f5f32
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sat Jul 20 04:00:01 2024 +0900
Use `packaging.version.parse` instead of `distutils.version.LooseVersion` (#17173)
use `packaging.version.parse` instead of `distutils.version.LooseVersion`
---
python/tvm/contrib/msc/core/utils/info.py | 6 +++---
python/tvm/relay/frontend/pytorch_utils.py | 4 ++--
python/tvm/relay/op/contrib/ethosn.py | 6 +++---
python/tvm/relay/testing/tflite.py | 4 ++--
.../contrib/test_arm_compute_lib/test_network.py | 4 ++--
tests/python/frontend/tensorflow/test_forward.py | 9 ++++-----
tests/python/frontend/tflite/test_forward.py | 19 +++++++++----------
7 files changed, 25 insertions(+), 27 deletions(-)
diff --git a/python/tvm/contrib/msc/core/utils/info.py
b/python/tvm/contrib/msc/core/utils/info.py
index 4fea45f8fa..58b0811279 100644
--- a/python/tvm/contrib/msc/core/utils/info.py
+++ b/python/tvm/contrib/msc/core/utils/info.py
@@ -17,7 +17,7 @@
"""tvm.contrib.msc.core.utils.info"""
from typing import List, Tuple, Dict, Any, Union
-from distutils.version import LooseVersion
+from packaging.version import parse
import numpy as np
import tvm
@@ -409,8 +409,8 @@ def get_version(framework: str) -> List[int]:
raw_version = "1.0.0"
except: # pylint: disable=bare-except
raw_version = "1.0.0"
- raw_version = raw_version or "1.0.0"
- return LooseVersion(raw_version).version
+ version = parse(raw_version or "1.0.0")
+ return [version.major, version.minor, version.micro]
def compare_version(given_version: List[int], target_version: List[int]) ->
int:
diff --git a/python/tvm/relay/frontend/pytorch_utils.py
b/python/tvm/relay/frontend/pytorch_utils.py
index 7de1248bda..8686be4b1e 100644
--- a/python/tvm/relay/frontend/pytorch_utils.py
+++ b/python/tvm/relay/frontend/pytorch_utils.py
@@ -36,7 +36,7 @@ def is_version_greater_than(ver):
than the one given as an argument.
"""
import torch
- from distutils.version import LooseVersion
+ from packaging.version import parse
torch_ver = torch.__version__
# PT version numbers can include +cu[cuda version code]
@@ -44,7 +44,7 @@ def is_version_greater_than(ver):
if "+cu" in torch_ver:
torch_ver = torch_ver.split("+cu")[0]
- return LooseVersion(torch_ver) > ver
+ return parse(torch_ver) > parse(ver)
def getattr_attr_name(node):
diff --git a/python/tvm/relay/op/contrib/ethosn.py
b/python/tvm/relay/op/contrib/ethosn.py
index 81534d48a2..c1e87ad5d9 100644
--- a/python/tvm/relay/op/contrib/ethosn.py
+++ b/python/tvm/relay/op/contrib/ethosn.py
@@ -17,7 +17,7 @@
# pylint: disable=invalid-name, unused-argument
"""Arm(R) Ethos(TM)-N NPU supported operators."""
from enum import Enum
-from distutils.version import LooseVersion
+from packaging.version import parse
import tvm.ir
from tvm.relay import transform
@@ -118,7 +118,7 @@ def partition_for_ethosn(mod, params=None, **opts):
"""
api_version = ethosn_api_version()
supported_api_versions = ["3.2.0"]
- if all(api_version != LooseVersion(exp_ver) for exp_ver in
supported_api_versions):
+ if all(parse(api_version) != parse(exp_ver) for exp_ver in
supported_api_versions):
raise ValueError(
f"Driver stack version {api_version} is unsupported. "
f"Please use version in {supported_api_versions}."
@@ -433,7 +433,7 @@ def split(expr):
"""Check if a split is supported by Ethos-N."""
if not ethosn_available():
return False
- if ethosn_api_version() == LooseVersion("3.0.1"):
+ if parse(ethosn_api_version()) == parse("3.0.1"):
return False
if not _ethosn.split(expr):
return False
diff --git a/python/tvm/relay/testing/tflite.py
b/python/tvm/relay/testing/tflite.py
index df9c0bcadf..29f6bc62ca 100644
--- a/python/tvm/relay/testing/tflite.py
+++ b/python/tvm/relay/testing/tflite.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Common utilities for creating TFLite models"""
-from distutils.version import LooseVersion
+from packaging.version import parse
import numpy as np
import pytest
import tflite.Model # pylint: disable=wrong-import-position
@@ -134,7 +134,7 @@ class TFLiteModel:
assert self.serial_model is not None, "TFLite model was not created."
output_tolerance = None
- if tf.__version__ < LooseVersion("2.5.0"):
+ if parse(tf.__version__) < parse("2.5.0"):
output_tolerance = 1
interpreter = tf.lite.Interpreter(model_content=self.serial_model)
else:
diff --git a/tests/python/contrib/test_arm_compute_lib/test_network.py
b/tests/python/contrib/test_arm_compute_lib/test_network.py
index 3cf81e971f..8c6302abf8 100644
--- a/tests/python/contrib/test_arm_compute_lib/test_network.py
+++ b/tests/python/contrib/test_arm_compute_lib/test_network.py
@@ -16,7 +16,7 @@
# under the License.
"""Arm Compute Library network tests."""
-from distutils.version import LooseVersion
+from packaging.version import parse
import numpy as np
import pytest
@@ -137,7 +137,7 @@ def test_mobilenet():
mod, params = _get_keras_model(mobilenet, inputs)
return mod, params, inputs
- if keras.__version__ < LooseVersion("2.9"):
+ if parse(keras.__version__) < parse("2.9"):
# This can be removed after we migrate to TF/Keras >= 2.9
expected_tvm_ops = 56
expected_acl_partitions = 31
diff --git a/tests/python/frontend/tensorflow/test_forward.py
b/tests/python/frontend/tensorflow/test_forward.py
index db270ccb2e..354ed38a62 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -21,7 +21,6 @@ Tensorflow testcases
This article is a test script to test tensorflow operator with Relay.
"""
from __future__ import print_function
-from distutils.version import LooseVersion
import threading
import platform
@@ -1755,7 +1754,7 @@ def _test_concat_v2(shape1, shape2, dim):
def test_forward_concat_v2():
- if tf.__version__ < LooseVersion("1.4.1"):
+ if package_version.parse(tf.__version__) < package_version.parse("1.4.1"):
return
_test_concat_v2([2, 3], [2, 3], 0)
@@ -3128,7 +3127,7 @@ def _test_forward_clip_by_value(ip_shape, clip_value_min,
clip_value_max, dtype)
def test_forward_clip_by_value():
"""test ClipByValue op"""
- if tf.__version__ < LooseVersion("1.9"):
+ if package_version.parse(tf.__version__) < package_version.parse("1.9"):
_test_forward_clip_by_value((4,), 0.1, 5.0, "float32")
_test_forward_clip_by_value((4, 4), 1, 5, "int32")
@@ -4482,7 +4481,7 @@ def _test_forward_zeros_like(in_shape, dtype):
def test_forward_zeros_like():
- if tf.__version__ < LooseVersion("1.2"):
+ if package_version.parse(tf.__version__) < package_version.parse("1.2"):
_test_forward_zeros_like((2, 3), "int32")
_test_forward_zeros_like((2, 3, 5), "int8")
_test_forward_zeros_like((2, 3, 5, 7), "uint16")
@@ -5566,7 +5565,7 @@ def test_forward_spop():
# This test is expected to fail in TF version >= 2.6
# as the generated graph will be considered frozen, hence
# not passing the criteria for the test below.
- if tf.__version__ < LooseVersion("2.6.1"):
+ if package_version.parse(tf.__version__) < package_version.parse("2.6.1"):
_test_spop_resource_variables()
# Placeholder test cases
diff --git a/tests/python/frontend/tflite/test_forward.py
b/tests/python/frontend/tflite/test_forward.py
index 75a2a37c63..cb0b17ea3f 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -22,7 +22,6 @@ This article is a test script to test TFLite operator with
Relay.
"""
from __future__ import print_function
from functools import partial
-from distutils.version import LooseVersion
import platform
import os
import tempfile
@@ -1054,7 +1053,7 @@ def _test_tflite2_quantized_convolution(
input_node = subgraph.Tensors(model_input).Name().decode("utf-8")
tflite_output = run_tflite_graph(tflite_model_quant, data)
- if tf.__version__ < LooseVersion("2.9"):
+ if package_version.parse(tf.__version__) < package_version.parse("2.9"):
input_node = data_in.name.replace(":0", "")
else:
input_node = "serving_default_" + data_in.name + ":0"
@@ -1775,7 +1774,7 @@ def _test_tflite2_quantized_transpose_conv(
tflite_output = run_tflite_graph(tflite_model_quant, data)
- if tf.__version__ < LooseVersion("2.9"):
+ if package_version.parse(tf.__version__) < package_version.parse("2.9"):
input_node = data_in.name.replace(":0", "")
else:
input_node = "serving_default_" + data_in.name + ":0"
@@ -2219,9 +2218,9 @@ def _test_abs(data, quantized, int_quant_dtype=tf.int8):
tflite_output = run_tflite_graph(tflite_model_quant, data)
# TFLite 2.6.x upgrade support
- if tf.__version__ < LooseVersion("2.6.1"):
+ if package_version.parse(tf.__version__) <
package_version.parse("2.6.1"):
in_node = ["serving_default_input_int8"]
- elif tf.__version__ < LooseVersion("2.9"):
+ elif package_version.parse(tf.__version__) <
package_version.parse("2.9"):
in_node = (
["serving_default_input_int16"] if int_quant_dtype == tf.int16
else ["tfl.quantize"]
)
@@ -2245,7 +2244,7 @@ def _test_rsqrt(data, quantized, int_quant_dtype=tf.int8):
"""One iteration of rsqrt"""
# tensorflow version upgrade support
- if tf.__version__ < LooseVersion("2.6.1") or not quantized:
+ if package_version.parse(tf.__version__) < package_version.parse("2.6.1")
or not quantized:
return _test_unary_elemwise(
math_ops.rsqrt, data, quantized, quant_range=[1, 6],
int_quant_dtype=int_quant_dtype
)
@@ -2254,7 +2253,7 @@ def _test_rsqrt(data, quantized, int_quant_dtype=tf.int8):
tf.math.rsqrt, data, int_quant_dtype=int_quant_dtype
)
tflite_output = run_tflite_graph(tflite_model_quant, data)
- if tf.__version__ < LooseVersion("2.9"):
+ if package_version.parse(tf.__version__) <
package_version.parse("2.9"):
in_node = ["tfl.quantize"]
else:
in_node = "serving_default_input"
@@ -2338,7 +2337,7 @@ def _test_cos(data, quantized, int_quant_dtype=tf.int8):
tf.math.cos, data, int_quant_dtype=int_quant_dtype
)
tflite_output = run_tflite_graph(tflite_model_quant, data)
- if tf.__version__ < LooseVersion("2.9"):
+ if package_version.parse(tf.__version__) <
package_version.parse("2.9"):
in_node = ["tfl.quantize"]
else:
in_node = "serving_default_input"
@@ -3396,7 +3395,7 @@ def _test_quantize_dequantize(data):
tflite_model_quant = _quantize_keras_model(keras_model,
representative_data_gen, True, True)
tflite_output = run_tflite_graph(tflite_model_quant, data)
- if tf.__version__ < LooseVersion("2.9"):
+ if package_version.parse(tf.__version__) < package_version.parse("2.9"):
in_node = data_in.name.split(":")[0]
else:
in_node = "serving_default_" + data_in.name + ":0"
@@ -3426,7 +3425,7 @@ def _test_quantize_dequantize_const(data):
tflite_model_quant = _quantize_keras_model(keras_model,
representative_data_gen, True, True)
tflite_output = run_tflite_graph(tflite_model_quant, data)
- if tf.__version__ < LooseVersion("2.9"):
+ if package_version.parse(tf.__version__) < package_version.parse("2.9"):
in_node = data_in.name.split(":")[0]
else:
in_node = "serving_default_" + data_in.name + ":0"