This is an automated email from the ASF dual-hosted git repository.

bolke pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0a4ed7d557 Add pyspark decorator (#35247)
0a4ed7d557 is described below

commit 0a4ed7d557809ad81ecc50d197c33c8d178c42ce
Author: Bolke de Bruin <[email protected]>
AuthorDate: Wed Nov 1 11:42:58 2023 +0100

    Add pyspark decorator (#35247)
    
    This adds the pyspark decorator so that spark can be
    run inline so that results, like dataframes, can be
    shared.
---
 .pre-commit-config.yaml                            |   2 +
 airflow/decorators/__init__.pyi                    |  22 ++++
 .../providers/apache/spark/decorators/__init__.py  |  17 +++
 .../providers/apache/spark/decorators/pyspark.py   | 119 +++++++++++++++++++++
 airflow/providers/apache/spark/provider.yaml       |   4 +
 .../decorators/pyspark.rst                         |  51 +++++++++
 .../index.rst                                      |   1 +
 .../providers/apache/spark/decorators/__init__.py  |  16 +++
 .../apache/spark/decorators/test_pyspark.py        |  80 ++++++++++++++
 .../providers/apache/spark/example_pyspark.py      |  75 +++++++++++++
 10 files changed, 387 insertions(+)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b7c80a8106..e6b59100bf 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -514,6 +514,7 @@ repos:
           ^airflow/providers/apache/cassandra/hooks/cassandra.py$|
           ^airflow/providers/apache/hive/operators/hive_stats.py$|
           ^airflow/providers/apache/hive/transfers/vertica_to_hive.py$|
+          ^airflow/providers/apache/spark/decorators/|
           ^airflow/providers/apache/spark/hooks/|
           ^airflow/providers/apache/spark/operators/|
           ^airflow/providers/exasol/hooks/exasol.py$|
@@ -542,6 +543,7 @@ repos:
           
^docs/apache-airflow-providers-amazon/secrets-backends/aws-ssm-parameter-store.rst$|
           ^docs/apache-airflow-providers-apache-hdfs/connections.rst$|
           ^docs/apache-airflow-providers-apache-kafka/connections/kafka.rst$|
+          ^docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst$|
           
^docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst$|
           
^docs/apache-airflow-providers-microsoft-azure/connections/azure_cosmos.rst$|
           ^docs/conf.py$|
diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi
index 0c3e94bf5c..f718e35777 100644
--- a/airflow/decorators/__init__.pyi
+++ b/airflow/decorators/__init__.pyi
@@ -566,6 +566,28 @@ class TaskDecoratorCollection:
         """
     @overload
     def sensor(self, python_callable: Callable[FParams, FReturn] | None = 
None) -> Task[FParams, FReturn]: ...
+    @overload
+    def pyspark(
+        self,
+        *,
+        multiple_outputs: bool | None = None,
+        conn_id: str | None = None,
+        config_kwargs: dict[str, str] | None = None,
+        **kwargs,
+    ) -> TaskDecorator:
+        """
+        Wraps a Python function that is to be injected with a SparkSession.
+
+        :param multiple_outputs: If set, function return value will be 
unrolled to multiple XCom values.
+            Dict will unroll to XCom values with keys as XCom keys. Defaults 
to False.
+        :param conn_id: The connection ID to use for the SparkSession.
+        :param config_kwargs: Additional kwargs to pass to the SparkSession 
builder. This overrides
+            the config from the connection.
+        """
+    @overload
+    def pyspark(
+        self, python_callable: Callable[FParams, FReturn] | None = None
+    ) -> Task[FParams, FReturn]: ...
 
 task: TaskDecoratorCollection
 setup: Callable
diff --git a/airflow/providers/apache/spark/decorators/__init__.py 
b/airflow/providers/apache/spark/decorators/__init__.py
new file mode 100644
index 0000000000..217e5db960
--- /dev/null
+++ b/airflow/providers/apache/spark/decorators/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/apache/spark/decorators/pyspark.py 
b/airflow/providers/apache/spark/decorators/pyspark.py
new file mode 100644
index 0000000000..6f576b03a2
--- /dev/null
+++ b/airflow/providers/apache/spark/decorators/pyspark.py
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import inspect
+from typing import TYPE_CHECKING, Any, Callable, Sequence
+
+from airflow.decorators.base import DecoratedOperator, TaskDecorator, 
task_decorator_factory
+from airflow.hooks.base import BaseHook
+from airflow.operators.python import PythonOperator
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+SPARK_CONTEXT_KEYS = ["spark", "sc"]
+
+
+class _PySparkDecoratedOperator(DecoratedOperator, PythonOperator):
+    custom_operator_name = "@task.pyspark"
+
+    template_fields: Sequence[str] = ("op_args", "op_kwargs")
+
+    def __init__(
+        self,
+        python_callable: Callable,
+        op_args: Sequence | None = None,
+        op_kwargs: dict | None = None,
+        conn_id: str | None = None,
+        config_kwargs: dict | None = None,
+        **kwargs,
+    ):
+        self.conn_id = conn_id
+        self.config_kwargs = config_kwargs or {}
+
+        signature = inspect.signature(python_callable)
+        parameters = [
+            param.replace(default=None) if param.name in SPARK_CONTEXT_KEYS 
else param
+            for param in signature.parameters.values()
+        ]
+        # mypy does not understand __signature__ attribute
+        # see https://github.com/python/mypy/issues/12472
+        python_callable.__signature__ = 
signature.replace(parameters=parameters)  # type: ignore[attr-defined]
+
+        kwargs_to_upstream = {
+            "python_callable": python_callable,
+            "op_args": op_args,
+            "op_kwargs": op_kwargs,
+        }
+        super().__init__(
+            kwargs_to_upstream=kwargs_to_upstream,
+            python_callable=python_callable,
+            op_args=op_args,
+            op_kwargs=op_kwargs,
+            **kwargs,
+        )
+
+    def execute(self, context: Context):
+        from pyspark import SparkConf
+        from pyspark.sql import SparkSession
+
+        conf = SparkConf().setAppName(f"{self.dag_id}-{self.task_id}")
+
+        master = "local[*]"
+        if self.conn_id:
+            conn = BaseHook.get_connection(self.conn_id)
+            if conn.port:
+                master = f"{conn.host}:{conn.port}"
+            elif conn.host:
+                master = conn.host
+
+            for key, value in conn.extra_dejson.items():
+                conf.set(key, value)
+
+        conf.setMaster(master)
+
+        # task can override connection config
+        for key, value in self.config_kwargs.items():
+            conf.set(key, value)
+
+        spark = SparkSession.builder.config(conf=conf).getOrCreate()
+        sc = spark.sparkContext
+
+        if not self.op_kwargs:
+            self.op_kwargs = {}
+
+        op_kwargs: dict[str, Any] = dict(self.op_kwargs)
+        op_kwargs["spark"] = spark
+        op_kwargs["sc"] = sc
+
+        self.op_kwargs = op_kwargs
+        return super().execute(context)
+
+
+def pyspark_task(
+    python_callable: Callable | None = None,
+    multiple_outputs: bool | None = None,
+    **kwargs,
+) -> TaskDecorator:
+    return task_decorator_factory(
+        python_callable=python_callable,
+        multiple_outputs=multiple_outputs,
+        decorated_operator_class=_PySparkDecoratedOperator,
+        **kwargs,
+    )
diff --git a/airflow/providers/apache/spark/provider.yaml 
b/airflow/providers/apache/spark/provider.yaml
index be3e791a92..9316f80fa0 100644
--- a/airflow/providers/apache/spark/provider.yaml
+++ b/airflow/providers/apache/spark/provider.yaml
@@ -83,6 +83,10 @@ connection-types:
   - hook-class-name: 
airflow.providers.apache.spark.hooks.spark_submit.SparkSubmitHook
     connection-type: spark
 
+task-decorators:
+  - class-name: airflow.providers.apache.spark.decorators.pyspark.pyspark_task
+    name: pyspark
+
 additional-extras:
   - name: cncf.kubernetes
     dependencies:
diff --git a/docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst 
b/docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst
new file mode 100644
index 0000000000..28b51ec848
--- /dev/null
+++ b/docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst
@@ -0,0 +1,51 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+   or more contributor license agreements.  See the NOTICE file
+   distributed with this work for additional information
+   regarding copyright ownership.  The ASF licenses this file
+   to you under the Apache License, Version 2.0 (the
+   "License"); you may not use this file except in compliance
+   with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+
+
+.. _howto/decorator:pyspark:
+
+PySpark Decorator
+=================
+
+Python callable wrapped within the ``@task.pyspark`` decorator
+is injected with a SparkSession and a SparkContext object.
+
+Parameters
+----------
+
+The following parameters can be passed to the decorator:
+
+conn_id: str
+    The connection ID to use for connecting to the Spark cluster. If not
+    specified, the spark master is set to ``local[*]``.
+config_kwargs: dict
+    The kwargs used for initializing the SparkConf object. This overrides
+    the spark configuration options set in the connection.
+
+
+Example
+-------
+
+The following example shows how to use the ``@task.pyspark`` decorator. Note
+that the ``spark`` and ``sc`` objects are injected into the function.
+
+.. exampleinclude:: 
/../../tests/system/providers/apache/spark/example_pyspark.py
+    :language: python
+    :dedent: 4
+    :start-after: [START task_pyspark]
+    :end-before: [END task_pyspark]
diff --git a/docs/apache-airflow-providers-apache-spark/index.rst 
b/docs/apache-airflow-providers-apache-spark/index.rst
index fc3438f0d9..af2d24ba48 100644
--- a/docs/apache-airflow-providers-apache-spark/index.rst
+++ b/docs/apache-airflow-providers-apache-spark/index.rst
@@ -34,6 +34,7 @@
     :caption: Guides
 
     Connection types <connections/spark>
+    Decorators <decorators/pyspark>
     Operators <operators>
 
 .. toctree::
diff --git a/tests/providers/apache/spark/decorators/__init__.py 
b/tests/providers/apache/spark/decorators/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/apache/spark/decorators/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/apache/spark/decorators/test_pyspark.py 
b/tests/providers/apache/spark/decorators/test_pyspark.py
new file mode 100644
index 0000000000..ea307c9c3c
--- /dev/null
+++ b/tests/providers/apache/spark/decorators/test_pyspark.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.decorators import task
+from airflow.models import Connection
+from airflow.utils import db, timezone
+
+DEFAULT_DATE = timezone.datetime(2021, 9, 1)
+
+
+class TestPysparkDecorator:
+    def setup_method(self):
+        db.merge_conn(
+            Connection(
+                conn_id="pyspark_local",
+                conn_type="spark",
+                host="spark://none",
+                extra="",
+            )
+        )
+
+    @pytest.mark.db_test
+    @mock.patch("pyspark.SparkConf.setAppName")
+    @mock.patch("pyspark.sql.SparkSession")
+    def test_pyspark_decorator_with_connection(self, spark_mock, conf_mock, 
dag_maker):
+        @task.pyspark(conn_id="pyspark_local", 
config_kwargs={"spark.executor.memory": "2g"})
+        def f(spark, sc):
+            import random
+
+            return [random.random() for _ in range(100)]
+
+        with dag_maker():
+            ret = f()
+
+        dr = dag_maker.create_dagrun()
+        ret.operator.run(start_date=dr.execution_date, 
end_date=dr.execution_date)
+        ti = dr.get_task_instances()[0]
+        assert len(ti.xcom_pull()) == 100
+        conf_mock().set.assert_called_with("spark.executor.memory", "2g")
+        conf_mock().setMaster.assert_called_once_with("spark://none")
+        spark_mock.builder.config.assert_called_once_with(conf=conf_mock())
+
+    @pytest.mark.db_test
+    @mock.patch("pyspark.SparkConf.setAppName")
+    @mock.patch("pyspark.sql.SparkSession")
+    def test_simple_pyspark_decorator(self, spark_mock, conf_mock, dag_maker):
+        e = 2
+
+        @task.pyspark
+        def f():
+            return e
+
+        with dag_maker():
+            ret = f()
+
+        dr = dag_maker.create_dagrun()
+        ret.operator.run(start_date=dr.execution_date, 
end_date=dr.execution_date)
+        ti = dr.get_task_instances()[0]
+        assert ti.xcom_pull() == e
+        conf_mock().setMaster.assert_called_once_with("local[*]")
+        spark_mock.builder.config.assert_called_once_with(conf=conf_mock())
diff --git a/tests/system/providers/apache/spark/example_pyspark.py 
b/tests/system/providers/apache/spark/example_pyspark.py
new file mode 100644
index 0000000000..c671cb40e3
--- /dev/null
+++ b/tests/system/providers/apache/spark/example_pyspark.py
@@ -0,0 +1,75 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import typing
+
+import pendulum
+
+if typing.TYPE_CHECKING:
+    import pandas as pd
+    from pyspark import SparkContext
+    from pyspark.sql import SparkSession
+
+from airflow.decorators import dag, task
+
+
+@dag(
+    schedule=None,
+    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+    catchup=False,
+    tags=["example"],
+)
+def example_pyspark():
+    """
+    ### Example Pyspark DAG
+    This is an example DAG which uses pyspark
+    """
+
+    # [START task_pyspark]
+    @task.pyspark(conn_id="spark-local")
+    def spark_task(spark: SparkSession, sc: SparkContext) -> pd.DataFrame:
+        df = spark.createDataFrame(
+            [
+                (1, "John Doe", 21),
+                (2, "Jane Doe", 22),
+                (3, "Joe Bloggs", 23),
+            ],
+            ["id", "name", "age"],
+        )
+        df.show()
+
+        return df.toPandas()
+
+    # [END task_pyspark]
+
+    @task
+    def print_df(df: pd.DataFrame):
+        print(df)
+
+    df = spark_task()
+    print_df(df)
+
+
+# work around pre-commit
+dag = example_pyspark()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: 
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)

Reply via email to