This is an automated email from the ASF dual-hosted git repository.
bolke pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0a4ed7d557 Add pyspark decorator (#35247)
0a4ed7d557 is described below
commit 0a4ed7d557809ad81ecc50d197c33c8d178c42ce
Author: Bolke de Bruin <[email protected]>
AuthorDate: Wed Nov 1 11:42:58 2023 +0100
Add pyspark decorator (#35247)
This adds the pyspark decorator so that Spark can be
run inline and results, such as DataFrames, can be
shared between tasks.
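For context, a minimal usage sketch (names are illustrative; the full
example DAG is added under tests/system below):

    from airflow.decorators import task

    @task.pyspark(conn_id="spark-local")
    def make_df(spark, sc):
        # spark (SparkSession) and sc (SparkContext) are injected
        df = spark.createDataFrame([(1, "John Doe", 21)], ["id", "name", "age"])
        return df.toPandas()  # pandas result travels via XCom

    @task
    def consume(df):
        print(df)

    consume(make_df())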
---
.pre-commit-config.yaml | 2 +
airflow/decorators/__init__.pyi | 22 ++++
.../providers/apache/spark/decorators/__init__.py | 17 +++
.../providers/apache/spark/decorators/pyspark.py | 119 +++++++++++++++++++++
airflow/providers/apache/spark/provider.yaml | 4 +
.../decorators/pyspark.rst | 51 +++++++++
.../index.rst | 1 +
.../providers/apache/spark/decorators/__init__.py | 16 +++
.../apache/spark/decorators/test_pyspark.py | 80 ++++++++++++++
.../providers/apache/spark/example_pyspark.py | 75 +++++++++++++
10 files changed, 387 insertions(+)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b7c80a8106..e6b59100bf 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -514,6 +514,7 @@ repos:
^airflow/providers/apache/cassandra/hooks/cassandra.py$|
^airflow/providers/apache/hive/operators/hive_stats.py$|
^airflow/providers/apache/hive/transfers/vertica_to_hive.py$|
+ ^airflow/providers/apache/spark/decorators/|
^airflow/providers/apache/spark/hooks/|
^airflow/providers/apache/spark/operators/|
^airflow/providers/exasol/hooks/exasol.py$|
@@ -542,6 +543,7 @@ repos:
^docs/apache-airflow-providers-amazon/secrets-backends/aws-ssm-parameter-store.rst$|
^docs/apache-airflow-providers-apache-hdfs/connections.rst$|
^docs/apache-airflow-providers-apache-kafka/connections/kafka.rst$|
+ ^docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst$|
^docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst$|
^docs/apache-airflow-providers-microsoft-azure/connections/azure_cosmos.rst$|
^docs/conf.py$|
diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi
index 0c3e94bf5c..f718e35777 100644
--- a/airflow/decorators/__init__.pyi
+++ b/airflow/decorators/__init__.pyi
@@ -566,6 +566,28 @@ class TaskDecoratorCollection:
"""
@overload
def sensor(self, python_callable: Callable[FParams, FReturn] | None = None) -> Task[FParams, FReturn]: ...
+ @overload
+ def pyspark(
+ self,
+ *,
+ multiple_outputs: bool | None = None,
+ conn_id: str | None = None,
+ config_kwargs: dict[str, str] | None = None,
+ **kwargs,
+ ) -> TaskDecorator:
+ """
+ Wraps a Python function that is to be injected with a SparkSession.
+
+ :param multiple_outputs: If set, function return value will be unrolled to multiple XCom values.
+ Dict will unroll to XCom values with keys as XCom keys. Defaults to False.
+ :param conn_id: The connection ID to use for the SparkSession.
+ :param config_kwargs: Additional kwargs to pass to the SparkSession builder. This overrides
+ the config from the connection.
+ """
+ @overload
+ def pyspark(
+ self, python_callable: Callable[FParams, FReturn] | None = None
+ ) -> Task[FParams, FReturn]: ...
task: TaskDecoratorCollection
setup: Callable
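The two overloads above capture the two ways the decorator is applied; a
minimal sketch (function names are illustrative):

    @task.pyspark                          # bare form (second overload)
    def f(spark, sc): ...

    @task.pyspark(conn_id="spark-local")   # parameterized form (first overload)
    def g(spark, sc): ...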
diff --git a/airflow/providers/apache/spark/decorators/__init__.py b/airflow/providers/apache/spark/decorators/__init__.py
new file mode 100644
index 0000000000..217e5db960
--- /dev/null
+++ b/airflow/providers/apache/spark/decorators/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/apache/spark/decorators/pyspark.py b/airflow/providers/apache/spark/decorators/pyspark.py
new file mode 100644
index 0000000000..6f576b03a2
--- /dev/null
+++ b/airflow/providers/apache/spark/decorators/pyspark.py
@@ -0,0 +1,119 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import inspect
+from typing import TYPE_CHECKING, Any, Callable, Sequence
+
+from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
+from airflow.hooks.base import BaseHook
+from airflow.operators.python import PythonOperator
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+SPARK_CONTEXT_KEYS = ["spark", "sc"]
+
+
+class _PySparkDecoratedOperator(DecoratedOperator, PythonOperator):
+ custom_operator_name = "@task.pyspark"
+
+ template_fields: Sequence[str] = ("op_args", "op_kwargs")
+
+ def __init__(
+ self,
+ python_callable: Callable,
+ op_args: Sequence | None = None,
+ op_kwargs: dict | None = None,
+ conn_id: str | None = None,
+ config_kwargs: dict | None = None,
+ **kwargs,
+ ):
+ self.conn_id = conn_id
+ self.config_kwargs = config_kwargs or {}
+
+ signature = inspect.signature(python_callable)
+ parameters = [
+ param.replace(default=None) if param.name in SPARK_CONTEXT_KEYS else param
+ for param in signature.parameters.values()
+ ]
+ # mypy does not understand __signature__ attribute
+ # see https://github.com/python/mypy/issues/12472
+ python_callable.__signature__ = signature.replace(parameters=parameters)  # type: ignore[attr-defined]
+
+ kwargs_to_upstream = {
+ "python_callable": python_callable,
+ "op_args": op_args,
+ "op_kwargs": op_kwargs,
+ }
+ super().__init__(
+ kwargs_to_upstream=kwargs_to_upstream,
+ python_callable=python_callable,
+ op_args=op_args,
+ op_kwargs=op_kwargs,
+ **kwargs,
+ )
+
+ def execute(self, context: Context):
+ from pyspark import SparkConf
+ from pyspark.sql import SparkSession
+
+ conf = SparkConf().setAppName(f"{self.dag_id}-{self.task_id}")
+
+ master = "local[*]"
+ if self.conn_id:
+ conn = BaseHook.get_connection(self.conn_id)
+ if conn.port:
+ master = f"{conn.host}:{conn.port}"
+ elif conn.host:
+ master = conn.host
+
+ for key, value in conn.extra_dejson.items():
+ conf.set(key, value)
+
+ conf.setMaster(master)
+
+ # task can override connection config
+ for key, value in self.config_kwargs.items():
+ conf.set(key, value)
+
+ spark = SparkSession.builder.config(conf=conf).getOrCreate()
+ sc = spark.sparkContext
+
+ if not self.op_kwargs:
+ self.op_kwargs = {}
+
+ op_kwargs: dict[str, Any] = dict(self.op_kwargs)
+ op_kwargs["spark"] = spark
+ op_kwargs["sc"] = sc
+
+ self.op_kwargs = op_kwargs
+ return super().execute(context)
+
+
+def pyspark_task(
+ python_callable: Callable | None = None,
+ multiple_outputs: bool | None = None,
+ **kwargs,
+) -> TaskDecorator:
+ return task_decorator_factory(
+ python_callable=python_callable,
+ multiple_outputs=multiple_outputs,
+ decorated_operator_class=_PySparkDecoratedOperator,
+ **kwargs,
+ )
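For reference, a sketch of the configuration precedence that execute() applies,
assuming a hypothetical connection with host "spark://spark-host", port 7077,
and extras {"spark.executor.memory": "1g"}:

    from pyspark import SparkConf

    conf = SparkConf().setAppName("my_dag-my_task")  # f"{dag_id}-{task_id}"
    conf.set("spark.executor.memory", "1g")          # 1. connection extras
    conf.setMaster("spark://spark-host:7077")        # 2. master from host:port
    conf.set("spark.executor.memory", "2g")          # 3. config_kwargs wins last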
diff --git a/airflow/providers/apache/spark/provider.yaml b/airflow/providers/apache/spark/provider.yaml
index be3e791a92..9316f80fa0 100644
--- a/airflow/providers/apache/spark/provider.yaml
+++ b/airflow/providers/apache/spark/provider.yaml
@@ -83,6 +83,10 @@ connection-types:
- hook-class-name: airflow.providers.apache.spark.hooks.spark_submit.SparkSubmitHook
connection-type: spark
+task-decorators:
+ - class-name: airflow.providers.apache.spark.decorators.pyspark.pyspark_task
+ name: pyspark
+
additional-extras:
- name: cncf.kubernetes
dependencies:
diff --git a/docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst b/docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst
new file mode 100644
index 0000000000..28b51ec848
--- /dev/null
+++ b/docs/apache-airflow-providers-apache-spark/decorators/pyspark.rst
@@ -0,0 +1,51 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+
+.. _howto/decorator:pyspark:
+
+PySpark Decorator
+=================
+
+A Python callable wrapped with the ``@task.pyspark`` decorator is
+injected with a SparkSession (``spark``) and a SparkContext (``sc``) object.
+
+Parameters
+----------
+
+The following parameters can be passed to the decorator:
+
+conn_id: str
+ The connection ID to use for connecting to the Spark cluster. If not
+ specified, the Spark master is set to ``local[*]``.
+config_kwargs: dict
+ The kwargs used for initializing the SparkConf object. This overrides
+ the Spark configuration options set in the connection.
+
+
+Example
+-------
+
+The following example shows how to use the ``@task.pyspark`` decorator. Note
+that the ``spark`` and ``sc`` objects are injected into the function.
+
+.. exampleinclude:: /../../tests/system/providers/apache/spark/example_pyspark.py
+ :language: python
+ :dedent: 4
+ :start-after: [START task_pyspark]
+ :end-before: [END task_pyspark]
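The included example does not exercise config_kwargs; a hedged sketch of
combining it with a connection (values illustrative, mirroring the tests):

    @task.pyspark(conn_id="spark-local", config_kwargs={"spark.executor.memory": "2g"})
    def heavy_task(spark, sc):
        ...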
diff --git a/docs/apache-airflow-providers-apache-spark/index.rst b/docs/apache-airflow-providers-apache-spark/index.rst
index fc3438f0d9..af2d24ba48 100644
--- a/docs/apache-airflow-providers-apache-spark/index.rst
+++ b/docs/apache-airflow-providers-apache-spark/index.rst
@@ -34,6 +34,7 @@
:caption: Guides
Connection types <connections/spark>
+ Decorators <decorators/pyspark>
Operators <operators>
.. toctree::
diff --git a/tests/providers/apache/spark/decorators/__init__.py b/tests/providers/apache/spark/decorators/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/apache/spark/decorators/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/apache/spark/decorators/test_pyspark.py b/tests/providers/apache/spark/decorators/test_pyspark.py
new file mode 100644
index 0000000000..ea307c9c3c
--- /dev/null
+++ b/tests/providers/apache/spark/decorators/test_pyspark.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.decorators import task
+from airflow.models import Connection
+from airflow.utils import db, timezone
+
+DEFAULT_DATE = timezone.datetime(2021, 9, 1)
+
+
+class TestPysparkDecorator:
+ def setup_method(self):
+ db.merge_conn(
+ Connection(
+ conn_id="pyspark_local",
+ conn_type="spark",
+ host="spark://none",
+ extra="",
+ )
+ )
+
+ @pytest.mark.db_test
+ @mock.patch("pyspark.SparkConf.setAppName")
+ @mock.patch("pyspark.sql.SparkSession")
+ def test_pyspark_decorator_with_connection(self, spark_mock, conf_mock, dag_maker):
+ @task.pyspark(conn_id="pyspark_local", config_kwargs={"spark.executor.memory": "2g"})
+ def f(spark, sc):
+ import random
+
+ return [random.random() for _ in range(100)]
+
+ with dag_maker():
+ ret = f()
+
+ dr = dag_maker.create_dagrun()
+ ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
+ ti = dr.get_task_instances()[0]
+ assert len(ti.xcom_pull()) == 100
+ conf_mock().set.assert_called_with("spark.executor.memory", "2g")
+ conf_mock().setMaster.assert_called_once_with("spark://none")
+ spark_mock.builder.config.assert_called_once_with(conf=conf_mock())
+
+ @pytest.mark.db_test
+ @mock.patch("pyspark.SparkConf.setAppName")
+ @mock.patch("pyspark.sql.SparkSession")
+ def test_simple_pyspark_decorator(self, spark_mock, conf_mock, dag_maker):
+ e = 2
+
+ @task.pyspark
+ def f():
+ return e
+
+ with dag_maker():
+ ret = f()
+
+ dr = dag_maker.create_dagrun()
+ ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
+ ti = dr.get_task_instances()[0]
+ assert ti.xcom_pull() == e
+ conf_mock().setMaster.assert_called_once_with("local[*]")
+ spark_mock.builder.config.assert_called_once_with(conf=conf_mock())
diff --git a/tests/system/providers/apache/spark/example_pyspark.py b/tests/system/providers/apache/spark/example_pyspark.py
new file mode 100644
index 0000000000..c671cb40e3
--- /dev/null
+++ b/tests/system/providers/apache/spark/example_pyspark.py
@@ -0,0 +1,75 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import typing
+
+import pendulum
+
+if typing.TYPE_CHECKING:
+ import pandas as pd
+ from pyspark import SparkContext
+ from pyspark.sql import SparkSession
+
+from airflow.decorators import dag, task
+
+
+@dag(
+ schedule=None,
+ start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+ catchup=False,
+ tags=["example"],
+)
+def example_pyspark():
+ """
+ ### Example Pyspark DAG
+ This is an example DAG that uses pyspark.
+ """
+
+ # [START task_pyspark]
+ @task.pyspark(conn_id="spark-local")
+ def spark_task(spark: SparkSession, sc: SparkContext) -> pd.DataFrame:
+ df = spark.createDataFrame(
+ [
+ (1, "John Doe", 21),
+ (2, "Jane Doe", 22),
+ (3, "Joe Bloggs", 23),
+ ],
+ ["id", "name", "age"],
+ )
+ df.show()
+
+ return df.toPandas()
+
+ # [END task_pyspark]
+
+ @task
+ def print_df(df: pd.DataFrame):
+ print(df)
+
+ df = spark_task()
+ print_df(df)
+
+
+# work around pre-commit
+dag = example_pyspark()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)