This is an automated email from the ASF dual-hosted git repository.

zero323 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6d3cfed  [SPARK-36989][TESTS][PYTHON] Add type hints data tests
6d3cfed is described below

commit 6d3cfed0e8ab2812bc44bfcd5b82d8d227e77a29
Author: zero323 <mszymkiew...@gmail.com>
AuthorDate: Tue Oct 26 10:32:34 2021 +0200

    [SPARK-36989][TESTS][PYTHON] Add type hints data tests
    
    ### What changes were proposed in this pull request?
    
    This PR:
    
    - Adds a basic data test runner to `dev/lint-python`, using [`typeddjango/pytest-mypy-plugins`](https://github.com/typeddjango/pytest-mypy-plugins)
    - Migrates data test cases from `pyspark-stubs`
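
The data test step added to `dev/lint-python` (`mypy_data_test`) works in three stages. It first checks whether `pytest_mypy_plugins` can be imported and skips with a message if it cannot. Otherwise it runs pytest against `python/pyspark` with `MYPYPATH=python`. The following Python sketch mirrors that control flow for illustration only. The function names `plugin_available`, `build_data_test_command` and `run_mypy_data_test` are hypothetical; the pytest flags are copied from the shell function.

```python
import importlib.util
import os
import subprocess
from typing import Dict, List, Tuple


def plugin_available(module: str = "pytest_mypy_plugins") -> bool:
    # Same idea as the script's find_spec one-liner: is the module importable?
    return importlib.util.find_spec(module) is not None


def build_data_test_command(python_dir: str = "python") -> Tuple[List[str], Dict[str, str]]:
    # Flags mirror the pytest invocation in mypy_data_test.
    cmd = [
        "pytest",
        "-c", os.path.join(python_dir, "pyproject.toml"),
        "--rootdir", python_dir,
        "--mypy-only-local-stub",
        "--mypy-ini-file", os.path.join(python_dir, "mypy.ini"),
        os.path.join(python_dir, "pyspark"),
    ]
    # MYPYPATH adds the python/ directory to mypy's module search path.
    env = dict(os.environ, MYPYPATH=python_dir)
    return cmd, env


def run_mypy_data_test() -> int:
    if not plugin_available():
        print("pytest-mypy-plugins missing. Skipping for now.")
        return 0
    cmd, env = build_data_test_command()
    # Capture stdout and stderr together, like the script's 2>&1.
    result = subprocess.run(cmd, env=env, capture_output=True, text=True)
    if result.returncode != 0:
        print("annotations failed data checks:")
        print(result.stdout + result.stderr)
    else:
        print("annotations passed data checks.")
    return result.returncode
```

Unlike this sketch, the shell function calls `exit` with pytest's status on failure, which aborts the remaining lint steps.
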
    
    In case of failure, a message similar to the following
    
    ```
    starting mypy annotations test...
    annotations passed mypy checks.
    
    starting mypy data test...
    annotations failed data checks:
    ============================= test session starts ==============================
    platform linux -- Python 3.9.7, pytest-6.2.5, py-1.10.0, pluggy-1.0.0
    rootdir: /path/to/spark/python, configfile: pyproject.toml
    plugins: mypy-plugins-1.9.2
    collected 37 items
    
    python/pyspark/ml/tests/typing/test_classification.yml ..                [  5%]
    python/pyspark/ml/tests/typing/test_evaluation.yml .                     [  8%]
    python/pyspark/ml/tests/typing/test_feature.yml .                        [ 10%]
    python/pyspark/ml/tests/typing/test_param.yml .                          [ 13%]
    python/pyspark/ml/tests/typing/test_readable.yml .                       [ 16%]
    python/pyspark/ml/tests/typing/test_regression.yml ..                    [ 21%]
    python/pyspark/sql/tests/typing/test_column.yml F                        [ 24%]
    python/pyspark/sql/tests/typing/test_dataframe.yml .......               [ 43%]
    python/pyspark/sql/tests/typing/test_functions.yml .                     [ 45%]
    python/pyspark/sql/tests/typing/test_pandas_compatibility.yml ..         [ 51%]
    python/pyspark/sql/tests/typing/test_readwriter.yml ..                   [ 56%]
    python/pyspark/sql/tests/typing/test_session.yml .....                   [ 70%]
    python/pyspark/sql/tests/typing/test_udf.yml .......                     [ 89%]
    python/pyspark/tests/typing/test_context.yml .                           [ 91%]
    python/pyspark/tests/typing/test_core.yml .                              [ 94%]
    python/pyspark/tests/typing/test_rdd.yml .                               [ 97%]
    python/pyspark/tests/typing/test_resultiterable.yml .                    [100%]
    
    =================================== FAILURES ===================================
    ______________________________ colDateTimeCompare ______________________________
    /path/to/spark/python/pyspark/sql/tests/typing/test_column.yml:39:
    E   pytest_mypy_plugins.utils.TypecheckAssertionError: Invalid output:
    E   Actual:
    E     main:20: note: Revealed type is "pyspark.sql.column.Column" (diff)
    E   Expected:
    E     main:20: note: Revealed type is "datetime.date*" (diff)
    E   Alignment of first line difference:
    E     E: ...ote: Revealed type is "datetime.date*"
    E     A: ...ote: Revealed type is "pyspark.sql.column.Column"
    E                                  ^
    =========================== short test summary info ============================
    FAILED python/pyspark/sql/tests/typing/test_column.yml::colDateTimeCompare -
    ======================== 1 failed, 36 passed in 56.13s =========================
    ```
    
    will be displayed.
    
    ### Why are the changes needed?
    
    Currently, type annotations are tested primarily for integrity and, to a lesser extent, against the actual API. Testing against examples is a work in progress (SPARK-36997). Data tests allow us to improve coverage and to test negative cases (code that should fail type checker validation).
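
A negative data test is a snippet annotated with the errors mypy is expected to report. These are written as inline `# E:` comments, as in the YAML cases in this commit, and the test fails when mypy's actual output differs. The following is a deliberately simplified, hypothetical sketch of that comparison, not the pytest-mypy-plugins implementation. It handles only `# E:` comments, assumes the snippet is checked as module `main`, and compares whole lines as exact strings. It reports missing expected lines with `E:` and unexpected actual lines with `A:`, loosely modeled on the failure report shown above.

```python
import re
from typing import List

# Matches an inline expectation such as "# E: <message>  [code]".
EXPECTED_ERROR = re.compile(r"#\s*E:\s*(?P<msg>.*)$")


def expected_errors(snippet: str) -> List[str]:
    """Turn inline '# E:' comments into 'main:<line>: error: <message>' strings."""
    expected = []
    for lineno, line in enumerate(snippet.splitlines(), start=1):
        match = EXPECTED_ERROR.search(line)
        if match:
            expected.append(f"main:{lineno}: error: {match.group('msg').strip()}")
    return expected


def compare(expected: List[str], actual: List[str]) -> List[str]:
    """Return an empty list on success, otherwise the mismatching lines."""
    missing = [f"E: {line}" for line in expected if line not in actual]
    unexpected = [f"A: {line}" for line in actual if line not in expected]
    return missing + unexpected


# Example snippet; the expected message is taken from test_param.yml.
snippet = (
    "def get_foo() -> str:\n"
    '    return 1  # E: Incompatible return value type (got "int", expected "str")  [return-value]\n'
)
```

Calling `compare(expected_errors(snippet), mypy_lines)` yields an empty list only when mypy reports exactly the annotated errors. A real runner also has to handle notes, such as the `Revealed type` lines in the failure output above, and the `out:` blocks used by several cases in this commit.
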
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    By running the linter tests with the additions proposed in this PR.
    
    Closes #34296 from zero323/SPARK-36989.
    
    Authored-by: zero323 <mszymkiew...@gmail.com>
    Signed-off-by: zero323 <mszymkiew...@gmail.com>
---
 .github/workflows/build_and_test.yml               |   2 +
 dev/lint-python                                    |  75 +++++++--
 python/pyproject.toml                              |  25 +++
 .../ml/tests/typing/test_classification.yml        |  38 +++++
 python/pyspark/ml/tests/typing/test_evaluation.yml |  26 ++++
 python/pyspark/ml/tests/typing/test_feature.yml    |  44 ++++++
 python/pyspark/ml/tests/typing/test_param.yml      |  30 ++++
 python/pyspark/ml/tests/typing/test_readable.yml   |  28 ++++
 python/pyspark/ml/tests/typing/test_regression.yml |  38 +++++
 python/pyspark/sql/tests/typing/test_column.yml    |  37 +++++
 python/pyspark/sql/tests/typing/test_dataframe.yml | 140 +++++++++++++++++
 python/pyspark/sql/tests/typing/test_functions.yml |  90 +++++++++++
 .../sql/tests/typing/test_pandas_compatibility.yml |  35 +++++
 .../pyspark/sql/tests/typing/test_readwriter.yml   |  45 ++++++
 python/pyspark/sql/tests/typing/test_session.yml   | 113 ++++++++++++++
 python/pyspark/sql/tests/typing/test_udf.yml       | 170 +++++++++++++++++++++
 python/pyspark/tests/typing/test_context.yml       |  21 +++
 python/pyspark/tests/typing/test_core.yml          |  20 +++
 python/pyspark/tests/typing/test_rdd.yml           |  62 ++++++++
 .../pyspark/tests/typing/test_resultiterable.yml   |  22 +++
 20 files changed, 1047 insertions(+), 14 deletions(-)

diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index f586d55..3f2d500 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -442,6 +442,8 @@ jobs:
         # Jinja2 3.0.0+ causes error when building with Sphinx.
         #   See also https://issues.apache.org/jira/browse/SPARK-35375.
         python3.9 -m pip install flake8 pydata_sphinx_theme 'mypy==0.910' numpydoc 'jinja2<3.0.0' 'black==21.5b2'
+        # TODO Update to PyPI
+        python3.9 -m pip install git+https://github.com/typeddjango/pytest-mypy-plugins.git@b0020061f48e85743ee3335bd62a3a608d17c6bd
     - name: Install R linter dependencies and SparkR
       run: |
         apt-get install -y libcurl4-openssl-dev libgit2-dev libssl-dev libxml2-dev
diff --git a/dev/lint-python b/dev/lint-python
index 0639ff6..f04cd1a 100755
--- a/dev/lint-python
+++ b/dev/lint-python
@@ -22,6 +22,7 @@ MINIMUM_MYPY="0.910"
 MYPY_BUILD="mypy"
 PYCODESTYLE_BUILD="pycodestyle"
 MINIMUM_PYCODESTYLE="2.7.0"
+PYTEST_BUILD="pytest"
 
 PYTHON_EXECUTABLE="${PYTHON_EXECUTABLE:-python3}"
 
@@ -123,10 +124,66 @@ function pycodestyle_test {
     fi
 }
 
-function mypy_test {
+
+function mypy_annotation_test {
     local MYPY_REPORT=
     local MYPY_STATUS=
 
+    echo "starting mypy annotations test..."
+    MYPY_REPORT=$( ($MYPY_BUILD \
+      --config-file python/mypy.ini \
+      --cache-dir /tmp/.mypy_cache/ \
+      python/pyspark) 2>&1)
+    MYPY_STATUS=$?
+
+    if [ "$MYPY_STATUS" -ne 0 ]; then
+        echo "annotations failed mypy checks:"
+        echo "$MYPY_REPORT"
+        echo "$MYPY_STATUS"
+        exit "$MYPY_STATUS"
+    else
+        echo "annotations passed mypy checks."
+        echo
+    fi
+}
+
+
+function mypy_data_test {
+    local PYTEST_REPORT=
+    local PYTEST_STATUS=
+
+    echo "starting mypy data test..."
+
+    $PYTHON_EXECUTABLE -c "import importlib.util; import sys; \
+               sys.exit(0 if importlib.util.find_spec('pytest_mypy_plugins') else 1)"
+
+    if [ $? -ne 0 ]; then
+      echo "pytest-mypy-plugins missing. Skipping for now."
+      return
+    fi
+
+    PYTEST_REPORT=$( (MYPYPATH=python $PYTEST_BUILD \
+      -c python/pyproject.toml \
+      --rootdir python \
+      --mypy-only-local-stub \
+      --mypy-ini-file python/mypy.ini \
+      python/pyspark ) 2>&1)
+
+    PYTEST_STATUS=$?
+
+    if [ "$PYTEST_STATUS" -ne 0 ]; then
+        echo "annotations failed data checks:"
+        echo "$PYTEST_REPORT"
+        echo "$PYTEST_STATUS"
+        exit "$PYTEST_STATUS"
+    else
+      echo "annotations passed data checks."
+      echo
+    fi
+}
+
+
+function mypy_test {
     if ! hash "$MYPY_BUILD" 2> /dev/null; then
         echo "The $MYPY_BUILD command was not found. Skipping for now."
         return
@@ -141,21 +198,11 @@ function mypy_test {
         return
     fi
 
-    echo "starting mypy test..."
-    MYPY_REPORT=$( ($MYPY_BUILD --config-file python/mypy.ini python/pyspark) 2>&1)
-    MYPY_STATUS=$?
-
-    if [ "$MYPY_STATUS" -ne 0 ]; then
-        echo "mypy checks failed:"
-        echo "$MYPY_REPORT"
-        echo "$MYPY_STATUS"
-        exit "$MYPY_STATUS"
-    else
-        echo "mypy checks passed."
-        echo
-    fi
+    mypy_annotation_test
+    mypy_data_test
 }
 
+
 function flake8_test {
     local FLAKE8_VERSION=
     local EXPECTED_FLAKE8=
diff --git a/python/pyproject.toml b/python/pyproject.toml
new file mode 100644
index 0000000..286b728
--- /dev/null
+++ b/python/pyproject.toml
@@ -0,0 +1,25 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+[tool.pytest.ini_options]
+# Pytest is used only to run mypy data tests
+python_files = "test_*.yml"
+testpaths = [
+  "pyspark/tests/typing",
+  "pyspark/sql/tests/typing",
+  "pyspark/ml/typing",
+]
diff --git a/python/pyspark/ml/tests/typing/test_classification.yml b/python/pyspark/ml/tests/typing/test_classification.yml
new file mode 100644
index 0000000..a6efc76
--- /dev/null
+++ b/python/pyspark/ml/tests/typing/test_classification.yml
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+- case: oneVsRest
+  main: |
+    from pyspark.ml.classification import (
+        OneVsRest, OneVsRestModel, LogisticRegression, LogisticRegressionModel
+    )
+
+    # Should support
+    OneVsRest(classifier=LogisticRegression())
+    OneVsRest(classifier=LogisticRegressionModel.load("/foo"))  # E: Argument "classifier" to "OneVsRest" has incompatible type "LogisticRegressionModel"; expected "Optional[Estimator[<nothing>]]"  [arg-type]
+    OneVsRest(classifier="foo")  # E: Argument "classifier" to "OneVsRest" has incompatible type "str"; expected "Optional[Estimator[<nothing>]]"  [arg-type]
+
+
+- case: fitFMClassifier
+  main: |
+    from pyspark.sql import SparkSession
+    from pyspark.ml.classification import FMClassifier, FMClassificationModel
+
+    spark = SparkSession.builder.getOrCreate()
+    fm_model: FMClassificationModel = FMClassifier().fit(spark.read.parquet("/foo"))
+    fm_model.linear.toArray()
+    fm_model.factors.numRows
diff --git a/python/pyspark/ml/tests/typing/test_evaluation.yml b/python/pyspark/ml/tests/typing/test_evaluation.yml
new file mode 100644
index 0000000..e9e8f20
--- /dev/null
+++ b/python/pyspark/ml/tests/typing/test_evaluation.yml
@@ -0,0 +1,26 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+- case: BinaryClassificationEvaluator
+  main: |
+    from pyspark.ml.evaluation import BinaryClassificationEvaluator
+
+    BinaryClassificationEvaluator().setMetricName("areaUnderROC")
+    BinaryClassificationEvaluator(metricName="areaUnderPR")
+
+    BinaryClassificationEvaluator().setMetricName("foo")  # E: Argument 1 to "setMetricName" of "BinaryClassificationEvaluator" has incompatible type "Literal['foo']"; expected "Union[Literal['areaUnderROC'], Literal['areaUnderPR']]"  [arg-type]
+    BinaryClassificationEvaluator(metricName="bar")  # E: Argument "metricName" to "BinaryClassificationEvaluator" has incompatible type "Literal['bar']"; expected "Union[Literal['areaUnderROC'], Literal['areaUnderPR']]"  [arg-type]
diff --git a/python/pyspark/ml/tests/typing/test_feature.yml b/python/pyspark/ml/tests/typing/test_feature.yml
new file mode 100644
index 0000000..3d6b090
--- /dev/null
+++ b/python/pyspark/ml/tests/typing/test_feature.yml
@@ -0,0 +1,44 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+- case: stringIndexerOverloads
+  main: |
+    from pyspark.ml.feature import StringIndexer
+
+    # No arguments is OK
+    StringIndexer()
+
+    StringIndexer(inputCol="foo")
+    StringIndexer(outputCol="bar")
+    StringIndexer(inputCol="foo", outputCol="bar")
+
+    StringIndexer(inputCols=["foo"])
+    StringIndexer(outputCols=["bar"])
+    StringIndexer(inputCols=["foo"], outputCols=["bar"])
+
+    StringIndexer(inputCol="foo", outputCols=["bar"])
+    StringIndexer(inputCols=["foo"], outputCol="bar")
+
+  out: |
+    main:14: error: No overload variant of "StringIndexer" matches argument types "str", "List[str]"  [call-overload]
+    main:14: note: Possible overload variants:
+    main:14: note:     def StringIndexer(self, *, inputCol: Optional[str] = ..., outputCol: Optional[str] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer
+    main:14: note:     def StringIndexer(self, *, inputCols: Optional[List[str]] = ..., outputCols: Optional[List[str]] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer
+    main:15: error: No overload variant of "StringIndexer" matches argument types "List[str]", "str"  [call-overload]
+    main:15: note: Possible overload variants:
+    main:15: note:     def StringIndexer(self, *, inputCol: Optional[str] = ..., outputCol: Optional[str] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer
+    main:15: note:     def StringIndexer(self, *, inputCols: Optional[List[str]] = ..., outputCols: Optional[List[str]] = ..., handleInvalid: str = ..., stringOrderType: str = ...) -> StringIndexer
\ No newline at end of file
diff --git a/python/pyspark/ml/tests/typing/test_param.yml b/python/pyspark/ml/tests/typing/test_param.yml
new file mode 100644
index 0000000..0b423f0
--- /dev/null
+++ b/python/pyspark/ml/tests/typing/test_param.yml
@@ -0,0 +1,30 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+- case: paramGenric
+  main: |
+    from pyspark.ml.param import Param, Params, TypeConverters
+
+    class Foo(Params):
+        foo = Param(Params(), "foo", "foo", TypeConverters.toInt)
+        def getFoo(self) -> int:
+            return self.getOrDefault(self.foo)
+
+    class Bar(Params):
+        bar = Param(Params(), "bar", "bar", TypeConverters.toInt)
+        def getFoo(self) -> str:
+            return self.getOrDefault(self.bar)  # E: Incompatible return value type (got "int", expected "str")  [return-value]
diff --git a/python/pyspark/ml/tests/typing/test_readable.yml b/python/pyspark/ml/tests/typing/test_readable.yml
new file mode 100644
index 0000000..772133a
--- /dev/null
+++ b/python/pyspark/ml/tests/typing/test_readable.yml
@@ -0,0 +1,28 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+- case: readLinearSVCModel
+  main: |
+    from pyspark.ml.classification import LinearSVCModel
+
+    model1 = LinearSVCModel.load("dummy")
+    model1.coefficients.toArray()
+    model1.foo()  # E: "LinearSVCModel" has no attribute "foo"  [attr-defined]
+
+    model2 = LinearSVCModel.read().load("dummy")
+    model2.coefficients.toArray()
+    model2.foo()  # E: "LinearSVCModel" has no attribute "foo"  [attr-defined]
diff --git a/python/pyspark/ml/tests/typing/test_regression.yml b/python/pyspark/ml/tests/typing/test_regression.yml
new file mode 100644
index 0000000..b045bec
--- /dev/null
+++ b/python/pyspark/ml/tests/typing/test_regression.yml
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+- case: loadFMRegressor
+  main: |
+    from pyspark.ml.regression import FMRegressor, FMRegressionModel
+
+    fm = FMRegressor.load("/foo")
+    fm.setMiniBatchFraction(0.1)
+
+    fm_model = FMRegressionModel.load("/bar")
+    fm_model.factors.numCols
+
+    fm_model.foo()  # E: "FMRegressionModel" has no attribute "foo"  [attr-defined]
+
+
+- case: loadLinearRegressor
+  main: |
+    from pyspark.ml.regression import LinearRegressionModel
+
+    lr_model = LinearRegressionModel.load("/foo")
+    lr_model.getLabelCol().upper()
+
+    lr_model.foo  # E: "LinearRegressionModel" has no attribute "foo"  [attr-defined]
diff --git a/python/pyspark/sql/tests/typing/test_column.yml b/python/pyspark/sql/tests/typing/test_column.yml
new file mode 100644
index 0000000..26eeb61
--- /dev/null
+++ b/python/pyspark/sql/tests/typing/test_column.yml
@@ -0,0 +1,37 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+- case: colDateTimeCompare
+  main: |
+    import datetime
+    from pyspark.sql.functions import col
+
+    today = datetime.date.today()
+    now = datetime.datetime.now()
+    a_col = col("")
+
+    a_col < today
+    a_col <= today
+    a_col == today
+    a_col >= today
+    a_col > today
+
+    a_col < now
+    a_col <= now
+    a_col == now
+    a_col >= now
+    a_col > now
diff --git a/python/pyspark/sql/tests/typing/test_dataframe.yml b/python/pyspark/sql/tests/typing/test_dataframe.yml
new file mode 100644
index 0000000..79a3bcd
--- /dev/null
+++ b/python/pyspark/sql/tests/typing/test_dataframe.yml
@@ -0,0 +1,140 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+- case: sampling
+  main: |
+    from pyspark.sql import SparkSession
+
+    spark = SparkSession.builder.getOrCreate()
+    df = spark.range(1)
+    df.sample(1.0)
+    df.sample(0.5, 3)
+    df.sample(fraction=0.5, seed=3)
+    df.sample(withReplacement=True, fraction=0.5, seed=3)
+    df.sample(fraction=1.0)
+    df.sample(False, fraction=1.0)
+
+    # Will raise a runtime error, though not a typecheck error, as bool is
+    # duck-type compatible with float.
+    df.sample(True)
+
+    df.sample(withReplacement=False)
+
+  out: |
+    main:16: error: No overload variant of "sample" of "DataFrame" matches argument type "bool"  [call-overload]
+    main:16: note: Possible overload variants:
+    main:16: note:     def sample(self, fraction: float, seed: Optional[int] = ...) -> DataFrame
+    main:16: note:     def sample(self, withReplacement: Optional[bool], fraction: float, seed: Optional[int] = ...) -> DataFrame
+
+
+- case: selectColumns
+  main: |
+    from pyspark.sql import SparkSession
+    from pyspark.sql.functions import col
+
+    spark = SparkSession.builder.getOrCreate()
+
+    data = [('Alice', 1)]
+    df = spark.createDataFrame(data, schema="name str, age int")
+
+    df.select(["name", "age"])
+    df.select([col("name"), col("age")])
+
+    df.select(["name", col("age")])  # E: Argument 1 to "select" of "DataFrame" has incompatible type "List[object]"; expected "Union[List[Column], List[str]]"  [arg-type]
+
+
+- case: groupBy
+  main: |
+    from pyspark.sql import SparkSession
+    from pyspark.sql.functions import col
+
+    spark = SparkSession.builder.getOrCreate()
+
+    data = [('Alice', 1)]
+    df = spark.createDataFrame(data, schema="name str, age int")
+
+    df.groupBy(["name", "age"])
+    df.groupby(["name", "age"])
+    df.groupBy([col("name"), col("age")])
+    df.groupby([col("name"), col("age")])
+    df.groupBy(["name", col("age")])  # E: Argument 1 to "groupBy" of "DataFrame" has incompatible type "List[object]"; expected "Union[List[Column], List[str]]"  [arg-type]
+
+
+- case: rollup
+  main: |
+    from pyspark.sql import SparkSession
+    from pyspark.sql.functions import col
+
+    spark = SparkSession.builder.getOrCreate()
+
+    data = [('Alice', 1)]
+    df = spark.createDataFrame(data, schema="name str, age int")
+
+    df.rollup(["name", "age"])
+    df.rollup([col("name"), col("age")])
+
+
+    df.rollup(["name", col("age")])  # E: Argument 1 to "rollup" of "DataFrame" has incompatible type "List[object]"; expected "Union[List[Column], List[str]]"  [arg-type]
+
+
+- case: cube
+  main: |
+    from pyspark.sql import SparkSession
+    from pyspark.sql.functions import col
+
+    spark = SparkSession.builder.getOrCreate()
+
+    data = [('Alice', 1)]
+    df = spark.createDataFrame(data, schema="name str, age int")
+
+    df.cube(["name", "age"])
+    df.cube([col("name"), col("age")])
+
+
+    df.cube(["name", col("age")])  # E: Argument 1 to "cube" of "DataFrame" has incompatible type "List[object]"; expected "Union[List[Column], List[str]]"  [arg-type]
+
+
+- case: dropColumns
+  main: |
+    from pyspark.sql import SparkSession
+    from pyspark.sql.functions import lit, col
+
+    spark = SparkSession.builder.getOrCreate()
+    df = spark.range(1)
+    df.drop("id")
+    df.drop("id", "foo")
+    df.drop(df.id)
+
+    df.drop(col("id"), col("foo"))
+
+  out: |
+    main:10: error: No overload variant of "drop" of "DataFrame" matches argument types "Column", "Column"  [call-overload]
+    main:10: note: Possible overload variant:
+    main:10: note:     def drop(self, *cols: str) -> DataFrame
+    main:10: note:     <1 more non-matching overload not shown>
+
+
+- case: fillNullValues
+  main: |
+    from pyspark.sql import SparkSession
+
+    spark = SparkSession.builder.getOrCreate()
+    df = spark.createDataFrame([(1,2)], schema=("id1", "id2"))
+
+    df.fillna(value=1, subset="id1")
+    df.fillna(value=1, subset=("id1", "id2"))
+    df.fillna(value=1, subset=["id1"])
diff --git a/python/pyspark/sql/tests/typing/test_functions.yml b/python/pyspark/sql/tests/typing/test_functions.yml
new file mode 100644
index 0000000..70f6fd9
--- /dev/null
+++ b/python/pyspark/sql/tests/typing/test_functions.yml
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+- case: varargFunctionsOverloads
+  expect_fail: true  # TODO: Remove once SPARK-37085 is resolved
+  main: |
+    from pyspark.sql.functions import (
+      array,
+      col,
+      create_map,
+      map_concat,
+      struct,
+    )
+
+    array(col("foo"), col("bar"))
+    array([col("foo"), col("bar")])
+    array("foo", "bar")
+    array(["foo", "bar"])
+
+    create_map(col("foo"), col("bar"))
+    create_map([col("foo"), col("bar")])
+    create_map("foo", "bar")
+    create_map(["foo", "bar"])
+
+    map_concat(col("foo"), col("bar"))
+    map_concat([col("foo"), col("bar")])
+    map_concat("foo", "bar")
+    map_concat(["foo", "bar"])
+
+    struct(col("foo"), col("bar"))
+    struct([col("foo"), col("bar")])
+    struct("foo", "bar")
+    struct(["foo", "bar"])
+
+    array([col("foo")], [col("bar")])
+    create_map([col("foo")], [col("bar")])
+    map_concat([col("foo")], [col("bar")])
+    struct(["foo"], ["bar"])
+    array(["foo"], ["bar"])
+    create_map(["foo"], ["bar"])
+    map_concat(["foo"], ["bar"])
+    struct(["foo"], ["bar"])
+
+  out: |
+    main.py:29: error: No overload variant of "array" matches argument types "List[Column]", "List[Column]"
+    main.py:29: note: Possible overload variant:
+    main.py:29: note:     def array(*cols: Union[Column, str]) -> Column
+    main.py:29: note:     <1 more non-matching overload not shown>
+    main.py:30: error: No overload variant of "create_map" matches argument types "List[Column]", "List[Column]"
+    main.py:30: note: Possible overload variant:
+    main.py:30: note:     def create_map(*cols: Union[Column, str]) -> Column
+    main.py:30: note:     <1 more non-matching overload not shown>
+    main.py:31: error: No overload variant of "map_concat" matches argument types "List[Column]", "List[Column]"
+    main.py:31: note: Possible overload variant:
+    main.py:31: note:     def map_concat(*cols: Union[Column, str]) -> Column
+    main.py:31: note:     <1 more non-matching overload not shown>
+    main.py:32: error: No overload variant of "struct" matches argument types "List[str]", "List[str]"
+    main.py:32: note: Possible overload variant:
+    main.py:32: note:     def struct(*cols: Union[Column, str]) -> Column
+    main.py:32: note:     <1 more non-matching overload not shown>
+    main.py:33: error: No overload variant of "array" matches argument types "List[str]", "List[str]"
+    main.py:33: note: Possible overload variant:
+    main.py:33: note:     def array(*cols: Union[Column, str]) -> Column
+    main.py:33: note:     <1 more non-matching overload not shown>
+    main.py:34: error: No overload variant of "create_map" matches argument types "List[str]", "List[str]"
+    main.py:34: note: Possible overload variant:
+    main.py:34: note:     def create_map(*cols: Union[Column, str]) -> Column
+    main.py:34: note:     <1 more non-matching overload not shown>
+    main.py:35: error: No overload variant of "map_concat" matches argument types "List[str]", "List[str]"
+    main.py:35: note: Possible overload variant:
+    main.py:35: note:     def map_concat(*cols: Union[Column, str]) -> Column
+    main.py:35: note:     <1 more non-matching overload not shown>
+    main.py:36: error: No overload variant of "struct" matches argument types "List[str]", "List[str]"
+    main.py:36: note: Possible overload variant:
+    main.py:36: note:     def struct(*cols: Union[Column, str]) -> Column
+    main.py:36: note:     <1 more non-matching overload not shown>
diff --git a/python/pyspark/sql/tests/typing/test_pandas_compatibility.yml b/python/pyspark/sql/tests/typing/test_pandas_compatibility.yml
new file mode 100644
index 0000000..00d44a7
--- /dev/null
+++ b/python/pyspark/sql/tests/typing/test_pandas_compatibility.yml
@@ -0,0 +1,35 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+- case: pandasDataFrameProtocol
+  main: |
+    from pyspark.sql.pandas._typing import PandasDataFrame
+
+    df = PandasDataFrame({"a":  1})
+    df.groupby("a")
+
+    df.foo()  # E: "DataFrameLike" has no attribute "foo"  [attr-defined]
+
+
+- case: PandasSeriesProtocol
+  main: |
+    from pyspark.sql.pandas._typing import PandasSeries
+
+    series = PandasSeries([1, 2, 3])
+    series.nsmallest(3)
+
+    series.foo()  # E: "SeriesLike" has no attribute "foo"  [attr-defined]
diff --git a/python/pyspark/sql/tests/typing/test_readwriter.yml b/python/pyspark/sql/tests/typing/test_readwriter.yml
new file mode 100644
index 0000000..2ce3637
--- /dev/null
+++ b/python/pyspark/sql/tests/typing/test_readwriter.yml
@@ -0,0 +1,45 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+- case: readWriterOptions
+  main: |
+    from pyspark.sql import SparkSession
+
+    spark = SparkSession.builder.getOrCreate()
+
+    spark.read.option("foo", 1)
+    spark.createDataFrame([(1, 2)], ["foo", "bar"]).write.option("bar", True)
+
+    spark.read.load(foo=True)
+
+    spark.read.load(foo=["a"])  # E: Argument "foo" to "load" of "DataFrameReader" has incompatible type "List[str]"; expected "Union[bool, float, int, str, None]"  [arg-type]
+    spark.read.option("foo", (1, ))  # E: Argument 2 to "option" of "DataFrameReader" has incompatible type "Tuple[int]"; expected "Union[bool, float, int, str, None]"  [arg-type]
+
+
+- case: readStreamOptions
+  main: |
+    from pyspark.sql import SparkSession
+
+    spark = SparkSession.builder.getOrCreate()
+
+    spark.read.option("foo", True).option("foo", 1).option("foo", 1.0).option("foo", "1").option("foo", None)
+    spark.readStream.option("foo", True).option("foo", 1).option("foo", 1.0).option("foo", "1").option("foo", None)
+
+    spark.read.options(foo=True, bar=1).options(foo=1.0, bar="1", baz=None)
+    spark.readStream.options(foo=True, bar=1).options(foo=1.0, bar="1", baz=None)
+
+    spark.readStream.load(foo=True)
diff --git a/python/pyspark/sql/tests/typing/test_session.yml b/python/pyspark/sql/tests/typing/test_session.yml
new file mode 100644
index 0000000..f06e79e
--- /dev/null
+++ b/python/pyspark/sql/tests/typing/test_session.yml
@@ -0,0 +1,113 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+- case: createDataFrameStructsValid
+  main: |
+    from pyspark.sql import SparkSession
+    from pyspark.sql.types import StructType, StructField, StringType, IntegerType
+
+    spark = SparkSession.builder.getOrCreate()
+
+    data = [('Alice', 1)]
+    schema = StructType([
+        StructField("name", StringType(), True),
+        StructField("age", IntegerType(), True)
+    ])
+
+    # Valid structs
+    spark.createDataFrame(data)
+    spark.createDataFrame(data, samplingRatio=0.1)
+    spark.createDataFrame(data, ("name", "age"))
+    spark.createDataFrame(data, schema)
+    spark.createDataFrame(data, "name string, age integer")
+    spark.createDataFrame([(1, ("foo", "bar"))], ("_1", "_2"))
+    spark.createDataFrame(data, ("name", "age"), samplingRatio=0.1)  # type: ignore
+
+
+- case: createDataFrameScalarsValid
+  main: |
+
+    from pyspark.sql import SparkSession
+    from pyspark.sql.types import StructType, StructField, StringType, IntegerType
+
+    spark = SparkSession.builder.getOrCreate()
+
+    # Scalars
+    spark.createDataFrame([1, 2, 3], IntegerType())
+    spark.createDataFrame(["foo", "bar"], "string")
+
+
+- case: createDataFrameScalarsInvalid
+  main: |
+    from pyspark.sql import SparkSession
+    from pyspark.sql.types import StructType, StructField, StringType, IntegerType
+
+    spark = SparkSession.builder.getOrCreate()
+
+    schema = StructType([
+        StructField("name", StringType(), True),
+        StructField("age", IntegerType(), True)
+    ])
+
+    # Invalid - scalars require schema
+    spark.createDataFrame(["foo", "bar"]) # E: Value of type variable "RowLike" of "createDataFrame" of "SparkSession" cannot be "str"  [type-var]
+
+    # Invalid - data has to match schema (either product -> struct or scalar -> atomic)
+    spark.createDataFrame([1, 2, 3], schema) # E: Value of type variable "RowLike" of "createDataFrame" of "SparkSession" cannot be "int"  [type-var]
+
+
+- case: createDataFrameStructsInvalid
+  main: |
+    from pyspark.sql import SparkSession
+    from pyspark.sql.types import StructType, StructField, StringType, IntegerType
+
+    spark = SparkSession.builder.getOrCreate()
+
+    data = [('Alice', 1)]
+
+    schema = StructType([
+        StructField("name", StringType(), True),
+        StructField("age", IntegerType(), True)
+    ])
+
+    # Invalid product should have StructType schema
+    spark.createDataFrame(data, IntegerType())
+
+    # This shouldn't type check, though is technically speaking valid
+    # because  samplingRatio is ignored
+    spark.createDataFrame(data, schema, samplingRatio=0.1)
+
+  out: |
+    main:14: error: Argument 1 to "createDataFrame" of "SparkSession" has incompatible type "List[Tuple[str, int]]"; expected "Union[RDD[Union[Union[datetime, date], Union[bool, float, int, str], Decimal]], Iterable[Union[Union[datetime, date], Union[bool, float, int, str], Decimal]]]"  [arg-type]
+    main:18: error: No overload variant of "createDataFrame" of "SparkSession" matches argument types "List[Tuple[str, int]]", "StructType", "float"  [call-overload]
+    main:18: note: Possible overload variants:
+    main:18: note:     def [RowLike in (List[Any], Tuple[Any, ...], Row)] createDataFrame(self, data: Union[RDD[RowLike], Iterable[RowLike]], samplingRatio: Optional[float] = ...) -> DataFrame
+    main:18: note:     def [RowLike in (List[Any], Tuple[Any, ...], Row)] createDataFrame(self, data: Union[RDD[RowLike], Iterable[RowLike]], schema: Union[List[str], Tuple[str, ...]] = ..., verifySchema: bool = ...) -> DataFrame
+    main:18: note:     <4 more similar overloads not shown, out of 6 total overloads>
+
+
+- case: createDataFrameFromEmptyRdd
+  main: |
+    from pyspark.sql import SparkSession
+    from pyspark.sql.types import StructType
+
+    spark = SparkSession.builder.getOrCreate()
+
+    spark.createDataFrame(
+        spark.sparkContext.emptyRDD(),
+        schema=StructType(),
+    )
diff --git a/python/pyspark/sql/tests/typing/test_udf.yml b/python/pyspark/sql/tests/typing/test_udf.yml
new file mode 100644
index 0000000..3860830
--- /dev/null
+++ b/python/pyspark/sql/tests/typing/test_udf.yml
@@ -0,0 +1,170 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+- case: scalarUDF
+  main: |
+    from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType
+    import pandas.core.series  # type: ignore[import]
+    import pandas.core.frame  # type: ignore[import]
+
+    @pandas_udf("str", PandasUDFType.SCALAR)
+    def f(x: pandas.core.series.Series) -> pandas.core.series.Series:
+        return x
+
+    @pandas_udf("str", PandasUDFType.SCALAR)
+    def g(x: pandas.core.series.Series, y: pandas.core.series.Series) -> pandas.core.series.Series:
+        return x
+
+    @pandas_udf("str", PandasUDFType.SCALAR)
+    def h(*xs: pandas.core.series.Series) -> pandas.core.series.Series:
+        return xs[0]
+
+    @pandas_udf("str", PandasUDFType.SCALAR)
+    def k(x: pandas.core.frame.DataFrame, y: pandas.core.series.Series) -> pandas.core.series.Series:
+        return x
+
+    pandas_udf(lambda x: x, "str", PandasUDFType.SCALAR)
+    pandas_udf(lambda x, y: x, "str", PandasUDFType.SCALAR)
+    pandas_udf(lambda *xs: xs[0], "str", PandasUDFType.SCALAR)
+
+
+- case: scalarIterUDF
+  main: |
+    from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType
+    from pyspark.sql.types import IntegerType
+    import pandas.core.series  # type: ignore[import]
+    from typing import Iterable
+
+    @pandas_udf(IntegerType(), PandasUDFType.SCALAR_ITER)
+    def f(xs: pandas.core.series.Series) -> Iterable[pandas.core.series.Series]:
+        for x in xs:
+            yield x + 1
+
+
+- case: groupedMapUdf
+  main: |
+    from typing import Any
+
+    from pyspark.sql.session import SparkSession
+    from pyspark.sql.types import StructField, StructType, LongType
+    from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType
+    import pandas.core.frame  # type: ignore[import]
+
+    @pandas_udf("id long", PandasUDFType.GROUPED_MAP)
+    def f(pdf: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame:
+       return pdf
+
+    spark = SparkSession.builder.getOrCreate()
+
+    dfg = spark.range(1).groupBy("id")
+    dfg.apply(f)
+
+    @pandas_udf("id long", PandasUDFType.GROUPED_MAP)
+    def g(key: Any, pdf: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame:
+       return pdf
+
+    dfg.apply(g)
+
+
+    def h(pdf: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame:
+       return pdf
+
+    dfg.applyInPandas(h, "id long")
+    dfg.applyInPandas(h, StructType([StructField("id", LongType())]))
+
+
+- case: groupedAggUDF
+  main: |
+    # Let's keep this one to make sure compatibility imports work
+    from pyspark.sql.functions import pandas_udf, PandasUDFType
+    from pyspark.sql.types import IntegerType
+    import pandas.core.series  # type: ignore[import]
+
+    @pandas_udf(IntegerType(), PandasUDFType.GROUPED_AGG)
+    def f(x: pandas.core.series.Series) -> int:
+        return 42
+
+    @pandas_udf("int", PandasUDFType.GROUPED_AGG)
+    def g(x: pandas.core.series.Series, y: pandas.core.series.Series) -> int:
+        return 42
+
+    @pandas_udf("int", PandasUDFType.GROUPED_AGG)
+    def h(*xs: pandas.core.series.Series) -> int:
+        return 42
+
+    pandas_udf(lambda x: 42, "str", PandasUDFType.GROUPED_AGG)
+    pandas_udf(lambda x, y: 42, "str", PandasUDFType.GROUPED_AGG)
+    pandas_udf(lambda *xs: 42, "str", PandasUDFType.GROUPED_AGG)
+
+
+- case: mapIterUdf
+  main: |
+    from pyspark.sql.session import SparkSession
+    from typing import Iterable
+    import pandas.core.frame  # type: ignore[import]
+
+    spark = SparkSession.builder.getOrCreate()
+
+    def f(batch_iter: Iterable[pandas.core.frame.DataFrame]) -> Iterable[pandas.core.frame.DataFrame]:
+        for pdf in batch_iter:
+            yield pdf[pdf.id == 1]
+
+    spark.range(1).mapInPandas(f, "id long").show()
+
+
+- case: legacyUDF
+  main: |
+    from pyspark.sql.functions import udf
+    from pyspark.sql.types import IntegerType
+
+    udf(lambda x: x, "string")
+
+    udf(lambda x: x)
+
+    @udf("string")
+    def f(x: str) -> str:
+        return x
+
+    @udf(returnType="string")
+    def g(x: str) -> str:
+        return x
+
+    @udf(returnType=IntegerType())
+    def h(x: int) -> int:
+        return x
+
+    @udf
+    def i(x: str) -> str:
+        return x
+
+
+- case: cogroupedAggUdf
+  main: |
+    from pyspark.sql.session import SparkSession
+    import pandas.core.frame  # type: ignore[import]
+    from  pyspark.sql.types import StructType, StructField, LongType
+
+    spark = SparkSession.builder.getOrCreate()
+
+    dfg1 = spark.range(1).groupBy("id")
+    dfg2 = spark.range(1).groupBy("id")
+
+    def f(x: pandas.core.frame.DataFrame, y: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame:
+        return x
+
+    dfg1.cogroup(dfg2).applyInPandas(f, "id int")
+    dfg1.cogroup(dfg2).applyInPandas(f, StructType([StructField("id", LongType())]))
diff --git a/python/pyspark/tests/typing/test_context.yml b/python/pyspark/tests/typing/test_context.yml
new file mode 100644
index 0000000..1217651
--- /dev/null
+++ b/python/pyspark/tests/typing/test_context.yml
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+- case: contextInitalization
+  main: |
+    from pyspark import SparkContext
+    sc: SparkContext = SparkContext()
diff --git a/python/pyspark/tests/typing/test_core.yml b/python/pyspark/tests/typing/test_core.yml
new file mode 100644
index 0000000..ff58613
--- /dev/null
+++ b/python/pyspark/tests/typing/test_core.yml
@@ -0,0 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+- case: coreImports
+  main: |
+    from pyspark import keyword_only, Row, SQLContext
diff --git a/python/pyspark/tests/typing/test_rdd.yml b/python/pyspark/tests/typing/test_rdd.yml
new file mode 100644
index 0000000..749ad53
--- /dev/null
+++ b/python/pyspark/tests/typing/test_rdd.yml
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+- case: toDF
+  main: |
+    from pyspark.sql.types import (
+      IntegerType,
+      Row,
+      StructType,
+      StringType,
+      StructField,
+    )
+    from collections import namedtuple
+    from pyspark.sql import SparkSession
+
+    spark = SparkSession.builder.getOrCreate()
+    sc = spark.sparkContext
+
+    struct = StructType([
+        StructField("a", IntegerType()),
+        StructField("b", StringType())
+    ])
+
+    AB = namedtuple("AB", ["a", "b"])
+
+    rdd_row = sc.parallelize([Row(a=1, b="foo")])
+    rdd_row.toDF()
+    rdd_row.toDF(sampleRatio=0.4)
+    rdd_row.toDF(["a", "b"], sampleRatio=0.4)
+    rdd_row.toDF(struct)
+
+    rdd_tuple = sc.parallelize([(1, "foo")])
+    rdd_tuple.toDF()
+    rdd_tuple.toDF(sampleRatio=0.4)
+    rdd_tuple.toDF(["a", "b"], sampleRatio=0.4)
+    rdd_tuple.toDF(struct)
+
+    rdd_list = sc.parallelize([[1, "foo"]])
+    rdd_list.toDF()
+    rdd_list.toDF(sampleRatio=0.4)
+    rdd_list.toDF(["a", "b"], sampleRatio=0.4)
+    rdd_list.toDF(struct)
+
+    rdd_named_tuple = sc.parallelize([AB(1, "foo")])
+    rdd_named_tuple.toDF()
+    rdd_named_tuple.toDF(sampleRatio=0.4)
+    rdd_named_tuple.toDF(["a", "b"], sampleRatio=0.4)
+    rdd_named_tuple.toDF(struct)
diff --git a/python/pyspark/tests/typing/test_resultiterable.yml b/python/pyspark/tests/typing/test_resultiterable.yml
new file mode 100644
index 0000000..4261a29
--- /dev/null
+++ b/python/pyspark/tests/typing/test_resultiterable.yml
@@ -0,0 +1,22 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+- case: resultIterable
+  main: |
+    from pyspark.resultiterable import ResultIterable
+
+    ResultIterable([])
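
Several cases above (for example `pandasDataFrameProtocol` and `readWriterOptions`) do not use a separate `out:` block. Instead, they state the expected diagnostic inline as a `# E: ...` comment on the line that should trigger it. The sketch below uses only the standard library to show how such inline comments map onto the `main:<line>: <severity>: <message>` format used in the `out:` blocks of `test_session.yml`. `expected_output` and its regex are hypothetical helpers written for illustration. They are not pytest-mypy-plugins' actual implementation. The `# N:`/`# W:` (note/warning) markers are assumed to follow the same pattern.

```python
import re
from typing import List

# Hypothetical helper, not pytest-mypy-plugins' real code: matches an inline
# expectation such as '# E: message  [code]'. The letter selects the severity.
_INLINE = re.compile(r"#\s*(?P<kind>[ENW]):\s*(?P<msg>.+)$")
_SEVERITY = {"E": "error", "N": "note", "W": "warning"}


def expected_output(main: str, filename: str = "main") -> List[str]:
    """Turn inline '# E:' / '# N:' / '# W:' comments into mypy-style lines.

    Line numbers are 1-based and count every line of the snippet, blank
    lines included. This matches how 'main:14' and 'main:18' in the
    createDataFrameStructsInvalid case point into its 'main' block.
    """
    expected = []
    for lineno, line in enumerate(main.splitlines(), start=1):
        match = _INLINE.search(line)
        if match is None:
            continue  # ordinary code, '# type: ignore', plain comments
        severity = _SEVERITY[match.group("kind")]
        message = match.group("msg").strip()  # keeps the inner '  [code]' gap
        expected.append(f"{filename}:{lineno}: {severity}: {message}")
    return expected
```

If you feed it the `main` block of `pandasDataFrameProtocol`, it produces a single expected line for line 6 (`df.foo()`). Comments such as `# type: ignore` or `# Invalid - scalars require schema` are skipped because they do not start with an `E`/`N`/`W` marker.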

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
