This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new eacd234a887 [SPARK-44963][PYTHON][ML][TESTS] Make PySpark (pyspark-ml 
module) tests passing without any optional dependency
eacd234a887 is described below

commit eacd234a88768d984e3f71f496ea6f4f0930df5d
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Mon Aug 28 11:55:00 2023 +0900

    [SPARK-44963][PYTHON][ML][TESTS] Make PySpark (pyspark-ml module) tests 
passing without any optional dependency
    
    ### What changes were proposed in this pull request?
    
    This PR fixes the tests so that they either run properly or are skipped when optional dependencies are not installed.
    
    ### Why are the changes needed?
    
    Currently, running the pyspark-ml tests fails as shown below:
    
    ```
    ./python/run-tests --python-executables=python3 --modules=pyspark-ml
    ...
    Starting test(python3): pyspark.ml.tests.test_model_cache (temp output: 
/Users/hyukjin.kwon/workspace/forked/spark/python/target/f6f88c1e-0cb2-43e6-980e-47f1cdb9b463/python3__pyspark.ml.tests.test_model_cache__zij05l1u.log)
    Traceback (most recent call last):
      File 
"/Users/hyukjin.kwon/miniconda3/envs/vanilla-3.10/lib/python3.10/runpy.py", 
line 196, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File 
"/Users/hyukjin.kwon/miniconda3/envs/vanilla-3.10/lib/python3.10/runpy.py", 
line 86, in _run_code
        exec(code, run_globals)
      File 
"/Users/hyukjin.kwon/workspace/forked/spark/python/pyspark/ml/tests/test_functions.py",
 line 18, in <module>
        import pandas as pd
    ModuleNotFoundError: No module named 'pandas'
    ```
    
    PySpark tests should pass even when optional dependencies are not installed.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, test-only.
    
    ### How was this patch tested?
    
    Manually ran the test command shown above.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #42678 from HyukjinKwon/SPARK-44963.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/ml/functions.py                     | 27 ++++++++++++++++---
 .../tests/connect/test_connect_classification.py   |  9 +++++--
 .../ml/tests/connect/test_connect_evaluation.py    | 10 ++++++--
 .../ml/tests/connect/test_connect_feature.py       |  6 ++++-
 .../ml/tests/connect/test_connect_pipeline.py      |  6 ++++-
 .../ml/tests/connect/test_connect_summarizer.py    |  6 ++++-
 .../ml/tests/connect/test_connect_tuning.py        |  7 ++++-
 .../connect/test_legacy_mode_classification.py     | 15 +++++++----
 .../tests/connect/test_legacy_mode_evaluation.py   | 19 +++++++++-----
 .../ml/tests/connect/test_legacy_mode_feature.py   | 16 +++++++-----
 .../ml/tests/connect/test_legacy_mode_pipeline.py  | 15 +++++------
 .../tests/connect/test_legacy_mode_summarizer.py   |  6 ++++-
 .../ml/tests/connect/test_legacy_mode_tuning.py    | 30 ++++++++++++----------
 .../tests/connect/test_parity_torch_data_loader.py |  4 +--
 python/pyspark/ml/tests/test_functions.py          | 17 +++++++++---
 python/pyspark/ml/torch/tests/test_distributor.py  |  2 +-
 .../ml/torch/tests/test_log_communication.py       |  2 +-
 17 files changed, 138 insertions(+), 59 deletions(-)

diff --git a/python/pyspark/ml/functions.py b/python/pyspark/ml/functions.py
index 89b05b692ea..55631a818bb 100644
--- a/python/pyspark/ml/functions.py
+++ b/python/pyspark/ml/functions.py
@@ -17,9 +17,16 @@
 from __future__ import annotations
 
 import inspect
-import numpy as np
-import pandas as pd
 import uuid
+from typing import Any, Callable, Iterator, List, Mapping, TYPE_CHECKING, 
Tuple, Union, Optional
+
+import numpy as np
+
+try:
+    import pandas as pd
+except ImportError:
+    pass  # Let it throw a better error message later when the API is invoked.
+
 from pyspark import SparkContext
 from pyspark.sql.functions import pandas_udf
 from pyspark.sql.column import Column, _to_java_column
@@ -36,7 +43,6 @@ from pyspark.sql.types import (
     StructType,
 )
 from pyspark.ml.util import try_remote_functions
-from typing import Any, Callable, Iterator, List, Mapping, TYPE_CHECKING, 
Tuple, Union, Optional
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import UserDefinedFunctionLike
@@ -822,6 +828,21 @@ def _test() -> None:
     import pyspark.ml.functions
     import sys
 
+    from pyspark.sql.pandas.utils import (
+        require_minimum_pandas_version,
+        require_minimum_pyarrow_version,
+    )
+
+    try:
+        require_minimum_pandas_version()
+        require_minimum_pyarrow_version()
+    except Exception as e:
+        print(
+            f"Skipping pyspark.ml.functions doctests: {e}",
+            file=sys.stderr,
+        )
+        sys.exit(0)
+
     globs = pyspark.ml.functions.__dict__.copy()
     spark = SparkSession.builder.master("local[2]").appName("ml.functions 
tests").getOrCreate()
     sc = spark.sparkContext
diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py 
b/python/pyspark/ml/tests/connect/test_connect_classification.py
index f3e621c19f0..1c777fc3d40 100644
--- a/python/pyspark/ml/tests/connect/test_connect_classification.py
+++ b/python/pyspark/ml/tests/connect/test_connect_classification.py
@@ -18,7 +18,7 @@
 
 import unittest
 from pyspark.sql import SparkSession
-from pyspark.ml.tests.connect.test_legacy_mode_classification import 
ClassificationTestsMixin
+from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
 have_torch = True
 try:
@@ -26,8 +26,13 @@ try:
 except ImportError:
     have_torch = False
 
+if should_test_connect:
+    from pyspark.ml.tests.connect.test_legacy_mode_classification import 
ClassificationTestsMixin
 
-@unittest.skipIf(not have_torch, "torch is required")
+
+@unittest.skipIf(
+    not should_test_connect or not have_torch, connect_requirement_message or 
"torch is required"
+)
 class ClassificationTestsOnConnect(ClassificationTestsMixin, 
unittest.TestCase):
     def setUp(self) -> None:
         self.spark = (
diff --git a/python/pyspark/ml/tests/connect/test_connect_evaluation.py 
b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
index ce7cf03049d..0512619d2bf 100644
--- a/python/pyspark/ml/tests/connect/test_connect_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
@@ -17,7 +17,7 @@
 
 import unittest
 from pyspark.sql import SparkSession
-from pyspark.ml.tests.connect.test_legacy_mode_evaluation import 
EvaluationTestsMixin
+from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
 have_torcheval = True
 try:
@@ -25,8 +25,14 @@ try:
 except ImportError:
     have_torcheval = False
 
+if should_test_connect:
+    from pyspark.ml.tests.connect.test_legacy_mode_evaluation import 
EvaluationTestsMixin
 
-@unittest.skipIf(not have_torcheval, "torcheval is required")
+
+@unittest.skipIf(
+    not should_test_connect or not have_torcheval,
+    connect_requirement_message or "torcheval is required",
+)
 class EvaluationTestsOnConnect(EvaluationTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_connect_feature.py 
b/python/pyspark/ml/tests/connect/test_connect_feature.py
index d7698c37722..bd5eebe6e42 100644
--- a/python/pyspark/ml/tests/connect/test_connect_feature.py
+++ b/python/pyspark/ml/tests/connect/test_connect_feature.py
@@ -17,9 +17,13 @@
 
 import unittest
 from pyspark.sql import SparkSession
-from pyspark.ml.tests.connect.test_legacy_mode_feature import FeatureTestsMixin
+from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
+if should_test_connect:
+    from pyspark.ml.tests.connect.test_legacy_mode_feature import 
FeatureTestsMixin
 
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
 class FeatureTestsOnConnect(FeatureTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py 
b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
index e676c8bfee9..d2d960d6b74 100644
--- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
@@ -18,9 +18,13 @@
 
 import unittest
 from pyspark.sql import SparkSession
-from pyspark.ml.tests.connect.test_legacy_mode_pipeline import 
PipelineTestsMixin
+from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
+if should_test_connect:
+    from pyspark.ml.tests.connect.test_legacy_mode_pipeline import 
PipelineTestsMixin
 
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
 class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = (
diff --git a/python/pyspark/ml/tests/connect/test_connect_summarizer.py 
b/python/pyspark/ml/tests/connect/test_connect_summarizer.py
index 0b0537dfee3..107f8348d7e 100644
--- a/python/pyspark/ml/tests/connect/test_connect_summarizer.py
+++ b/python/pyspark/ml/tests/connect/test_connect_summarizer.py
@@ -17,9 +17,13 @@
 
 import unittest
 from pyspark.sql import SparkSession
-from pyspark.ml.tests.connect.test_legacy_mode_summarizer import 
SummarizerTestsMixin
+from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
+if should_test_connect:
+    from pyspark.ml.tests.connect.test_legacy_mode_summarizer import 
SummarizerTestsMixin
 
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
 class SummarizerTestsOnConnect(SummarizerTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_connect_tuning.py 
b/python/pyspark/ml/tests/connect/test_connect_tuning.py
index 18673d4b26b..a38b081636a 100644
--- a/python/pyspark/ml/tests/connect/test_connect_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_connect_tuning.py
@@ -18,9 +18,14 @@
 
 import unittest
 from pyspark.sql import SparkSession
-from pyspark.ml.tests.connect.test_legacy_mode_tuning import 
CrossValidatorTestsMixin
 
+from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
+if should_test_connect:
+    from pyspark.ml.tests.connect.test_legacy_mode_tuning import 
CrossValidatorTestsMixin
+
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
 class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, 
unittest.TestCase):
     def setUp(self) -> None:
         self.spark = (
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py 
b/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
index 84d5829122a..5e5f1b64a33 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
@@ -19,12 +19,8 @@ import os
 import tempfile
 import unittest
 import numpy as np
-from pyspark.ml.connect.classification import (
-    LogisticRegression as LORV2,
-    LogisticRegressionModel as LORV2Model,
-)
 from pyspark.sql import SparkSession
-
+from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
 have_torch = True
 try:
@@ -32,6 +28,12 @@ try:
 except ImportError:
     have_torch = False
 
+if should_test_connect:
+    from pyspark.ml.connect.classification import (
+        LogisticRegression as LORV2,
+        LogisticRegressionModel as LORV2Model,
+    )
+
 
 class ClassificationTestsMixin:
     @staticmethod
@@ -218,6 +220,9 @@ class ClassificationTestsMixin:
             loaded_model.transform(eval_df1.toPandas())
 
 
+@unittest.skipIf(
+    not should_test_connect or not have_torch, connect_requirement_message or 
"No torch found"
+)
 class ClassificationTests(ClassificationTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = SparkSession.builder.master("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py 
b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
index 9ff26c1f450..19442667b2b 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
@@ -19,13 +19,8 @@ import unittest
 import numpy as np
 import tempfile
 
-from pyspark.ml.connect.evaluation import (
-    RegressionEvaluator,
-    BinaryClassificationEvaluator,
-    MulticlassClassificationEvaluator,
-)
 from pyspark.sql import SparkSession
-
+from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
 have_torcheval = True
 try:
@@ -33,6 +28,13 @@ try:
 except ImportError:
     have_torcheval = False
 
+if should_test_connect:
+    from pyspark.ml.connect.evaluation import (
+        RegressionEvaluator,
+        BinaryClassificationEvaluator,
+        MulticlassClassificationEvaluator,
+    )
+
 
 class EvaluationTestsMixin:
     def test_regressor_evaluator(self):
@@ -173,7 +175,10 @@ class EvaluationTestsMixin:
             assert loaded_evaluator.getMetricName() == "accuracy"
 
 
-@unittest.skipIf(not have_torcheval, "torcheval is required")
+@unittest.skipIf(
+    not should_test_connect or not have_torcheval,
+    connect_requirement_message or "torcheval is required",
+)
 class EvaluationTests(EvaluationTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = SparkSession.builder.master("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py 
b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
index 3aac4a0e097..4f8b74e1f70 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
@@ -22,13 +22,16 @@ import numpy as np
 import tempfile
 import unittest
 
-from pyspark.ml.connect.feature import (
-    MaxAbsScaler,
-    MaxAbsScalerModel,
-    StandardScaler,
-    StandardScalerModel,
-)
 from pyspark.sql import SparkSession
+from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
+
+if should_test_connect:
+    from pyspark.ml.connect.feature import (
+        MaxAbsScaler,
+        MaxAbsScalerModel,
+        StandardScaler,
+        StandardScalerModel,
+    )
 
 
 class FeatureTestsMixin:
@@ -136,6 +139,7 @@ class FeatureTestsMixin:
                 np.testing.assert_allclose(sk_result, expected_result)
 
 
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
 class FeatureTests(FeatureTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = SparkSession.builder.master("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py 
b/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
index 5fd4f6f16cf..009c17e5b05 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
@@ -19,17 +19,13 @@ import os
 import tempfile
 import unittest
 import numpy as np
-from pyspark.ml.connect.feature import StandardScaler
-from pyspark.ml.connect.classification import LogisticRegression as LORV2
-from pyspark.ml.connect.pipeline import Pipeline
 from pyspark.sql import SparkSession
+from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
-
-have_torch = True
-try:
-    import torch  # noqa: F401
-except ImportError:
-    have_torch = False
+if should_test_connect:
+    from pyspark.ml.connect.feature import StandardScaler
+    from pyspark.ml.connect.classification import LogisticRegression as LORV2
+    from pyspark.ml.connect.pipeline import Pipeline
 
 
 class PipelineTestsMixin:
@@ -164,6 +160,7 @@ class PipelineTestsMixin:
         assert lorv2.getOrDefault(lorv2.maxIter) == 200
 
 
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
 class PipelineTests(PipelineTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = SparkSession.builder.master("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py 
b/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py
index 49c092b5023..2e6299dabdf 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py
@@ -19,8 +19,11 @@
 import unittest
 import numpy as np
 
-from pyspark.ml.connect.summarizer import summarize_dataframe
 from pyspark.sql import SparkSession
+from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
+
+if should_test_connect:
+    from pyspark.ml.connect.summarizer import summarize_dataframe
 
 
 class SummarizerTestsMixin:
@@ -58,6 +61,7 @@ class SummarizerTestsMixin:
         assert_dict_allclose(result_local, expected_result)
 
 
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
 class SummarizerTests(SummarizerTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = SparkSession.builder.master("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py 
b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
index 0ade227540c..5f714eeb169 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
@@ -19,26 +19,27 @@
 import tempfile
 import unittest
 import numpy as np
-import pandas as pd
+
 from pyspark.ml.param import Param, Params
-from pyspark.ml.connect import Model, Estimator
-from pyspark.ml.connect.feature import StandardScaler
-from pyspark.ml.connect.classification import LogisticRegression as LORV2
-from pyspark.ml.connect.pipeline import Pipeline
-from pyspark.ml.connect.tuning import CrossValidator, CrossValidatorModel
-from pyspark.ml.connect.evaluation import BinaryClassificationEvaluator, 
RegressionEvaluator
 from pyspark.ml.tuning import ParamGridBuilder
 from pyspark.sql import SparkSession
 from pyspark.sql.functions import rand
+from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
-from sklearn.datasets import load_breast_cancer
-
-
-have_torch = True
+have_sklearn = True
 try:
-    import torch  # noqa: F401
+    from sklearn.datasets import load_breast_cancer  # noqa: F401
 except ImportError:
-    have_torch = False
+    have_sklearn = False
+
+if should_test_connect:
+    import pandas as pd
+    from pyspark.ml.connect import Model, Estimator
+    from pyspark.ml.connect.feature import StandardScaler
+    from pyspark.ml.connect.classification import LogisticRegression as LORV2
+    from pyspark.ml.connect.pipeline import Pipeline
+    from pyspark.ml.connect.tuning import CrossValidator, CrossValidatorModel
+    from pyspark.ml.connect.evaluation import BinaryClassificationEvaluator, 
RegressionEvaluator
 
 
 class HasInducedError(Params):
@@ -272,6 +273,9 @@ class CrossValidatorTestsMixin:
         cv.fit(train_dataset)
 
 
+@unittest.skipIf(
+    not should_test_connect or not have_sklearn, connect_requirement_message 
or "No sklearn found"
+)
 class CrossValidatorTests(CrossValidatorTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = SparkSession.builder.master("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py 
b/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
index 18556633d89..68a281dbefa 100644
--- a/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
+++ b/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
@@ -18,14 +18,14 @@
 import unittest
 from pyspark.sql import SparkSession
 
+from pyspark.ml.torch.tests.test_data_loader import 
TorchDistributorDataLoaderUnitTests
+
 have_torch = True
 try:
     import torch  # noqa: F401
 except ImportError:
     have_torch = False
 
-from pyspark.ml.torch.tests.test_data_loader import 
TorchDistributorDataLoaderUnitTests
-
 
 @unittest.skipIf(not have_torch, "torch is required")
 class 
TorchDistributorBaselineUnitTestsOnConnect(TorchDistributorDataLoaderUnitTests):
diff --git a/python/pyspark/ml/tests/test_functions.py 
b/python/pyspark/ml/tests/test_functions.py
index 894db2f8a7d..e3c7982f92b 100644
--- a/python/pyspark/ml/tests/test_functions.py
+++ b/python/pyspark/ml/tests/test_functions.py
@@ -15,17 +15,28 @@
 # limitations under the License.
 #
 import numpy as np
-import pandas as pd
 import unittest
 
 from pyspark.ml.functions import predict_batch_udf
 from pyspark.sql.functions import array, struct, col
 from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructType, 
StructField, FloatType
 from pyspark.testing.mlutils import SparkSessionTestCase
-
-
+from pyspark.testing.sqlutils import (
+    have_pandas,
+    have_pyarrow,
+    pandas_requirement_message,
+    pyarrow_requirement_message,
+)
+
+
+@unittest.skipIf(
+    not have_pandas or not have_pyarrow,
+    pandas_requirement_message or pyarrow_requirement_message,
+)
 class PredictBatchUDFTests(SparkSessionTestCase):
     def setUp(self):
+        import pandas as pd
+
         super(PredictBatchUDFTests, self).setUp()
         self.data = np.arange(0, 1000, dtype=np.float64).reshape(-1, 4)
 
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py 
b/python/pyspark/ml/torch/tests/test_distributor.py
index 364ed83f98d..29ee20d1dfb 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -18,7 +18,7 @@
 import contextlib
 import os
 import shutil
-from six import StringIO
+from io import StringIO
 import stat
 import subprocess
 import sys
diff --git a/python/pyspark/ml/torch/tests/test_log_communication.py 
b/python/pyspark/ml/torch/tests/test_log_communication.py
index 164c7556d12..5a77b37c64c 100644
--- a/python/pyspark/ml/torch/tests/test_log_communication.py
+++ b/python/pyspark/ml/torch/tests/test_log_communication.py
@@ -18,7 +18,7 @@
 from __future__ import absolute_import, division, print_function
 
 import contextlib
-from six import StringIO
+from io import StringIO
 import sys
 import time
 from typing import Any, Callable


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to