This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 976064dc6c Add Snowpark operator and decorator (#42457)
976064dc6c is described below
commit 976064dc6ce95d3b5cead1a7d2fcad4971c61b9a
Author: Jianzhun Du <[email protected]>
AuthorDate: Wed Oct 2 18:55:57 2024 -0700
Add Snowpark operator and decorator (#42457)
---
airflow/providers/snowflake/decorators/__init__.py | 16 ++
airflow/providers/snowflake/decorators/snowpark.py | 124 ++++++++++++
airflow/providers/snowflake/hooks/snowflake.py | 22 ++
airflow/providers/snowflake/operators/snowpark.py | 133 +++++++++++++
airflow/providers/snowflake/provider.yaml | 7 +
airflow/providers/snowflake/utils/snowpark.py | 44 ++++
.../decorators/index.rst | 25 +++
.../decorators/snowpark.rst | 70 +++++++
docs/apache-airflow-providers-snowflake/index.rst | 1 +
.../operators/snowpark.rst | 74 +++++++
docs/spelling_wordlist.txt | 3 +
generated/provider_dependencies.json | 1 +
tests/providers/snowflake/decorators/__init__.py | 16 ++
.../snowflake/decorators/test_snowpark.py | 221 +++++++++++++++++++++
tests/providers/snowflake/hooks/test_snowflake.py | 27 +++
.../providers/snowflake/operators/test_snowpark.py | 181 +++++++++++++++++
tests/providers/snowflake/utils/test_snowpark.py | 36 ++++
.../snowflake/example_snowpark_decorator.py | 85 ++++++++
.../snowflake/example_snowpark_operator.py | 94 +++++++++
19 files changed, 1180 insertions(+)
diff --git a/airflow/providers/snowflake/decorators/__init__.py
b/airflow/providers/snowflake/decorators/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/snowflake/decorators/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/snowflake/decorators/snowpark.py
b/airflow/providers/snowflake/decorators/snowpark.py
new file mode 100644
index 0000000000..406d817e9d
--- /dev/null
+++ b/airflow/providers/snowflake/decorators/snowpark.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable, Sequence
+
+from airflow.decorators.base import DecoratedOperator, task_decorator_factory
+from airflow.providers.snowflake.operators.snowpark import SnowparkOperator
+from airflow.providers.snowflake.utils.snowpark import
inject_session_into_op_kwargs
+
+if TYPE_CHECKING:
+ from airflow.decorators.base import TaskDecorator
+
+
+class _SnowparkDecoratedOperator(DecoratedOperator, SnowparkOperator):
+ """
+ Wraps a Python callable that contains Snowpark code and captures
args/kwargs when called for execution.
+
+ :param snowflake_conn_id: Reference to
+ :ref:`Snowflake connection id<howto/connection:snowflake>`
+ :param python_callable: A reference to an object that is callable
+ :param op_args: a list of positional arguments that will get unpacked when
+ calling your callable
+ :param op_kwargs: a dictionary of keyword arguments that will get unpacked
+ in your function
+ :param warehouse: name of warehouse (will overwrite any warehouse
+ defined in the connection's extra JSON)
+ :param database: name of database (will overwrite database defined
+ in connection)
+ :param schema: name of schema (will overwrite schema defined in
+ connection)
+ :param role: name of role (will overwrite any role defined in
+ connection's extra JSON)
+ :param authenticator: authenticator for Snowflake.
+ 'snowflake' (default) to use the internal Snowflake authenticator
+ 'externalbrowser' to authenticate using your web browser and
+ Okta, ADFS or any other SAML 2.0-compliant identity provider
+ (IdP) that has been defined for your account
+ 'https://<your_okta_account_name>.okta.com' to authenticate
+ through native Okta.
+ :param session_parameters: You can set session-level parameters at
+ the time you connect to Snowflake
+ :param multiple_outputs: If set to True, the decorated function's return
value will be unrolled to
+ multiple XCom values. Dict will unroll to XCom values with its keys as
XCom keys. Defaults to False.
+ """
+
+ custom_operator_name = "@task.snowpark"
+
+ def __init__(
+ self,
+ *,
+ snowflake_conn_id: str = "snowflake_default",
+ python_callable: Callable,
+ op_args: Sequence | None = None,
+ op_kwargs: dict | None = None,
+ warehouse: str | None = None,
+ database: str | None = None,
+ role: str | None = None,
+ schema: str | None = None,
+ authenticator: str | None = None,
+ session_parameters: dict | None = None,
+ **kwargs,
+ ) -> None:
+ kwargs_to_upstream = {
+ "python_callable": python_callable,
+ "op_args": op_args,
+ "op_kwargs": op_kwargs,
+ }
+ super().__init__(
+ kwargs_to_upstream=kwargs_to_upstream,
+ snowflake_conn_id=snowflake_conn_id,
+ python_callable=python_callable,
+ op_args=op_args,
+ # airflow.decorators.base.DecoratedOperator checks if the
functions are bindable, so we have to
+ # add an artificial value to pass the validation if there is a
keyword argument named `session`
+ # in the signature of the python callable. The real value is
determined at runtime.
+ op_kwargs=inject_session_into_op_kwargs(python_callable,
op_kwargs, None)
+ if op_kwargs is not None
+ else op_kwargs,
+ warehouse=warehouse,
+ database=database,
+ role=role,
+ schema=schema,
+ authenticator=authenticator,
+ session_parameters=session_parameters,
+ **kwargs,
+ )
+
+
+def snowpark_task(
+ python_callable: Callable | None = None,
+ multiple_outputs: bool | None = None,
+ **kwargs,
+) -> TaskDecorator:
+ """
+ Wrap a function that contains Snowpark code into an Airflow operator.
+
+ Accepts kwargs for operator kwargs. Can be reused in a single DAG.
+
+ :param python_callable: Function to decorate
+ :param multiple_outputs: If set to True, the decorated function's return
value will be unrolled to
+ multiple XCom values. Dict will unroll to XCom values with its keys as
XCom keys. Defaults to False.
+ """
+ return task_decorator_factory(
+ python_callable=python_callable,
+ multiple_outputs=multiple_outputs,
+ decorated_operator_class=_SnowparkDecoratedOperator,
+ **kwargs,
+ )
diff --git a/airflow/providers/snowflake/hooks/snowflake.py
b/airflow/providers/snowflake/hooks/snowflake.py
index 4b4143fdd1..0f81f2e384 100644
--- a/airflow/providers/snowflake/hooks/snowflake.py
+++ b/airflow/providers/snowflake/hooks/snowflake.py
@@ -321,6 +321,28 @@ class SnowflakeHook(DbApiHook):
engine_kwargs["connect_args"][key] = conn_params[key]
return create_engine(self._conn_params_to_sqlalchemy_uri(conn_params),
**engine_kwargs)
+ def get_snowpark_session(self):
+ """
+ Get a Snowpark session object.
+
+ :return: the created session.
+ """
+ from snowflake.snowpark import Session
+
+ from airflow import __version__ as airflow_version
+ from airflow.providers.snowflake import __version__ as provider_version
+
+ conn_config = self._get_conn_params
+ session = Session.builder.configs(conn_config).create()
+ # add query tag for observability
+ session.update_query_tag(
+ {
+ "airflow_version": airflow_version,
+ "airflow_provider_version": provider_version,
+ }
+ )
+ return session
+
def set_autocommit(self, conn, autocommit: Any) -> None:
conn.autocommit(autocommit)
conn.autocommit_mode = autocommit
diff --git a/airflow/providers/snowflake/operators/snowpark.py
b/airflow/providers/snowflake/operators/snowpark.py
new file mode 100644
index 0000000000..1635eebaa3
--- /dev/null
+++ b/airflow/providers/snowflake/operators/snowpark.py
@@ -0,0 +1,133 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import Any, Callable, Collection, Mapping, Sequence
+
+from airflow.operators.python import PythonOperator, get_current_context
+from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
+from airflow.providers.snowflake.utils.snowpark import
inject_session_into_op_kwargs
+
+
+class SnowparkOperator(PythonOperator):
+ """
+ Executes a Python function with Snowpark Python code.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:SnowparkOperator`
+
+ :param snowflake_conn_id: Reference to
+ :ref:`Snowflake connection id<howto/connection:snowflake>`
+ :param python_callable: A reference to an object that is callable
+ :param op_args: a list of positional arguments that will get unpacked when
+ calling your callable
+ :param op_kwargs: a dictionary of keyword arguments that will get unpacked
+ in your function
+ :param templates_dict: a dictionary where the values are templates that
+ will get templated by the Airflow engine sometime between
+ ``__init__`` and ``execute`` takes place and are made available
+ in your callable's context after the template has been applied.
(templated)
+ :param templates_exts: a list of file extensions to resolve while
+ processing templated fields, for examples ``['.sql', '.hql']``
+ :param show_return_value_in_logs: a bool value whether to show return_value
+ logs. Defaults to True, which allows return value log output.
+ It can be set to False to prevent log output of return value when you
return huge data
+ such as transmitting a large amount of XCom data to TaskAPI.
+ :param warehouse: name of warehouse (will overwrite any warehouse
+ defined in the connection's extra JSON)
+ :param database: name of database (will overwrite database defined
+ in connection)
+ :param schema: name of schema (will overwrite schema defined in
+ connection)
+ :param role: name of role (will overwrite any role defined in
+ connection's extra JSON)
+ :param authenticator: authenticator for Snowflake.
+ 'snowflake' (default) to use the internal Snowflake authenticator
+ 'externalbrowser' to authenticate using your web browser and
+ Okta, ADFS or any other SAML 2.0-compliant identity provider
+ (IdP) that has been defined for your account
+ 'https://<your_okta_account_name>.okta.com' to authenticate
+ through native Okta.
+ :param session_parameters: You can set session-level parameters at
+ the time you connect to Snowflake
+ """
+
+ def __init__(
+ self,
+ *,
+ snowflake_conn_id: str = "snowflake_default",
+ python_callable: Callable,
+ op_args: Collection[Any] | None = None,
+ op_kwargs: Mapping[str, Any] | None = None,
+ templates_dict: dict[str, Any] | None = None,
+ templates_exts: Sequence[str] | None = None,
+ show_return_value_in_logs: bool = True,
+ warehouse: str | None = None,
+ database: str | None = None,
+ schema: str | None = None,
+ role: str | None = None,
+ authenticator: str | None = None,
+ session_parameters: dict | None = None,
+ **kwargs,
+ ):
+ super().__init__(
+ python_callable=python_callable,
+ op_args=op_args,
+ op_kwargs=op_kwargs,
+ templates_dict=templates_dict,
+ templates_exts=templates_exts,
+ show_return_value_in_logs=show_return_value_in_logs,
+ **kwargs,
+ )
+ self.snowflake_conn_id = snowflake_conn_id
+ self.warehouse = warehouse
+ self.database = database
+ self.schema = schema
+ self.role = role
+ self.authenticator = authenticator
+ self.session_parameters = session_parameters
+
+ def execute_callable(self):
+ hook = SnowflakeHook(
+ snowflake_conn_id=self.snowflake_conn_id,
+ warehouse=self.warehouse,
+ database=self.database,
+ role=self.role,
+ schema=self.schema,
+ authenticator=self.authenticator,
+ session_parameters=self.session_parameters,
+ )
+ session = hook.get_snowpark_session()
+ context = get_current_context()
+ session.update_query_tag(
+ {
+ "dag_id": context["dag_run"].dag_id,
+ "dag_run_id": context["dag_run"].run_id,
+ "task_id": context["task_instance"].task_id,
+ "operator": self.__class__.__name__,
+ }
+ )
+ try:
+ # inject session object if the function has "session" keyword as
an argument
+ self.op_kwargs = inject_session_into_op_kwargs(
+ self.python_callable, dict(self.op_kwargs), session
+ )
+ return super().execute_callable()
+ finally:
+ session.close()
diff --git a/airflow/providers/snowflake/provider.yaml
b/airflow/providers/snowflake/provider.yaml
index 067f673d70..47de902ff6 100644
--- a/airflow/providers/snowflake/provider.yaml
+++ b/airflow/providers/snowflake/provider.yaml
@@ -90,12 +90,14 @@ dependencies:
- pyarrow>=14.0.1
- snowflake-connector-python>=3.7.1
- snowflake-sqlalchemy>=1.4.0
+ - snowflake-snowpark-python>=1.17.0;python_version<"3.12"
integrations:
- integration-name: Snowflake
external-doc-url: https://snowflake.com/
how-to-guide:
- /docs/apache-airflow-providers-snowflake/operators/snowflake.rst
+ - /docs/apache-airflow-providers-snowflake/operators/snowpark.rst
logo: /integration-logos/snowflake/Snowflake.png
tags: [service]
@@ -103,6 +105,11 @@ operators:
- integration-name: Snowflake
python-modules:
- airflow.providers.snowflake.operators.snowflake
+ - airflow.providers.snowflake.operators.snowpark
+
+task-decorators:
+ - class-name: airflow.providers.snowflake.decorators.snowpark.snowpark_task
+ name: snowpark
hooks:
- integration-name: Snowflake
diff --git a/airflow/providers/snowflake/utils/snowpark.py
b/airflow/providers/snowflake/utils/snowpark.py
new file mode 100644
index 0000000000..a6617bb920
--- /dev/null
+++ b/airflow/providers/snowflake/utils/snowpark.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import inspect
+from typing import TYPE_CHECKING, Callable
+
+if TYPE_CHECKING:
+ from snowflake.snowpark import Session
+
+
+def inject_session_into_op_kwargs(
+ python_callable: Callable, op_kwargs: dict, session: Session | None
+) -> dict:
+ """
+ Inject Snowpark session into operator kwargs based on signature of python
callable.
+
+ If there is a keyword argument named `session` in the signature of the
python callable,
+ a Snowpark session object will be injected into kwargs.
+
+ :param python_callable: Python callable
+ :param op_kwargs: Operator kwargs
+ :param session: Snowpark session
+ """
+ signature = inspect.signature(python_callable)
+ if "session" in signature.parameters:
+ return {**op_kwargs, "session": session}
+ else:
+ return op_kwargs
diff --git a/docs/apache-airflow-providers-snowflake/decorators/index.rst
b/docs/apache-airflow-providers-snowflake/decorators/index.rst
new file mode 100644
index 0000000000..7871e3553b
--- /dev/null
+++ b/docs/apache-airflow-providers-snowflake/decorators/index.rst
@@ -0,0 +1,25 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+Snowflake decorators
+====================
+
+.. toctree::
+ :maxdepth: 1
+ :glob:
+
+ *
diff --git a/docs/apache-airflow-providers-snowflake/decorators/snowpark.rst
b/docs/apache-airflow-providers-snowflake/decorators/snowpark.rst
new file mode 100644
index 0000000000..09be01e3ef
--- /dev/null
+++ b/docs/apache-airflow-providers-snowflake/decorators/snowpark.rst
@@ -0,0 +1,70 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. _howto/decorators:snowpark:
+
+``@task.snowpark``
+==================
+
+Use the :func:`@task.snowpark
<airflow.providers.snowflake.decorators.snowpark.snowpark_task>` to run
+`Snowpark Python
<https://docs.snowflake.com/en/developer-guide/snowpark/python/index.html>`__
code in a `Snowflake <https://docs.snowflake.com/en/>`__ database.
+
+.. warning::
+
+ - Snowpark does not support Python 3.12 yet.
+ - Currently, this decorator does not support `Snowpark pandas API
<https://docs.snowflake.com/en/developer-guide/snowpark/python/pandas-on-snowflake>`__
because conflicting pandas version is used in Airflow.
+ Consider using Snowpark pandas API with other Snowpark decorators or
operators.
+
+Prerequisite Tasks
+^^^^^^^^^^^^^^^^^^
+
+To use this decorator, you must do a few things:
+
+ * Install provider package via **pip**.
+
+ .. code-block:: bash
+
+ pip install 'apache-airflow-providers-snowflake'
+
+ Detailed information is available for :doc:`Installation
<apache-airflow:installation/index>`.
+
+ * :doc:`Set up a Snowflake Connection </connections/snowflake>`.
+
+Using the Operator
+^^^^^^^^^^^^^^^^^^
+
+Use the ``snowflake_conn_id`` argument to specify connection used. If not
specified, ``snowflake_default`` will be used.
+
+An example usage of the ``@task.snowpark`` is as follows:
+
+.. exampleinclude::
/../../tests/system/providers/snowflake/example_snowpark_decorator.py
+ :language: python
+ :start-after: [START howto_decorator_snowpark]
+ :end-before: [END howto_decorator_snowpark]
+
+
+As the example demonstrates, there are two ways to use the Snowpark session
object in your Python function:
+
+ * Pass the Snowpark session object to the function as a keyword argument
named ``session``. The Snowpark session will be automatically injected into the
function, allowing you to use it as you normally would.
+
+ * Use `get_active_session
<https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/1.3.0/api/snowflake.snowpark.context.get_active_session>`__
+ function from Snowpark to retrieve the Snowpark session object inside the
function.
+
+.. note::
+
+ Parameters that can be passed onto the decorators will be given priority
over the parameters already given
+ in the Airflow connection metadata (such as ``schema``, ``role``,
``database`` and so forth).
diff --git a/docs/apache-airflow-providers-snowflake/index.rst
b/docs/apache-airflow-providers-snowflake/index.rst
index 5b9a8a5133..b00ea39c52 100644
--- a/docs/apache-airflow-providers-snowflake/index.rst
+++ b/docs/apache-airflow-providers-snowflake/index.rst
@@ -36,6 +36,7 @@
Connection Types <connections/snowflake>
Operators <operators/index>
+ Decorators <decorators/index>
.. toctree::
:hidden:
diff --git a/docs/apache-airflow-providers-snowflake/operators/snowpark.rst
b/docs/apache-airflow-providers-snowflake/operators/snowpark.rst
new file mode 100644
index 0000000000..755fa6c529
--- /dev/null
+++ b/docs/apache-airflow-providers-snowflake/operators/snowpark.rst
@@ -0,0 +1,74 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. _howto/operator:SnowparkOperator:
+
+SnowparkOperator
+================
+
+Use the :class:`SnowparkOperator
<airflow.providers.snowflake.operators.snowpark.SnowparkOperator>` to run
+`Snowpark Python
<https://docs.snowflake.com/en/developer-guide/snowpark/python/index.html>`__
code in a `Snowflake <https://docs.snowflake.com/en/>`__ database.
+
+.. warning::
+
+ - Snowpark does not support Python 3.12 yet.
+ - Currently, this operator does not support `Snowpark pandas API
<https://docs.snowflake.com/en/developer-guide/snowpark/python/pandas-on-snowflake>`__
because conflicting pandas version is used in Airflow.
+ Consider using Snowpark pandas API with other Snowpark decorators or
operators.
+
+.. tip::
+
+ The :doc:`@task.snowpark </decorators/snowpark>` decorator is recommended
over the ``SnowparkOperator`` to run Snowpark Python code.
+
+Prerequisite Tasks
+^^^^^^^^^^^^^^^^^^
+
+To use this operator, you must do a few things:
+
+ * Install provider package via **pip**.
+
+ .. code-block:: bash
+
+ pip install 'apache-airflow-providers-snowflake'
+
+ Detailed information is available for :doc:`Installation
<apache-airflow:installation/index>`.
+
+ * :doc:`Set up a Snowflake Connection </connections/snowflake>`.
+
+Using the Operator
+^^^^^^^^^^^^^^^^^^
+
+Use the ``snowflake_conn_id`` argument to specify connection used. If not
specified, ``snowflake_default`` will be used.
+
+An example usage of the ``SnowparkOperator`` is as follows:
+
+.. exampleinclude::
/../../tests/system/providers/snowflake/example_snowpark_operator.py
+ :language: python
+ :start-after: [START howto_operator_snowpark]
+ :end-before: [END howto_operator_snowpark]
+
+
+As the example demonstrates, there are two ways to use the Snowpark session
object in your Python function:
+
+ * Pass the Snowpark session object to the function as a keyword argument
named ``session``. The Snowpark session will be automatically injected into the
function, allowing you to use it as you normally would.
+
+ * Use `get_active_session
<https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/1.3.0/api/snowflake.snowpark.context.get_active_session>`__
+ function from Snowpark to retrieve the Snowpark session object inside the
function.
+
+.. note::
+
+ Parameters that can be passed onto the operators will be given priority over
the parameters already given
+ in the Airflow connection metadata (such as ``schema``, ``role``,
``database`` and so forth).
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index eb6e612e09..cf6856838d 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1485,6 +1485,9 @@ SlackResponse
slas
smtp
SnowflakeHook
+Snowpark
+snowpark
+SnowparkOperator
somecollection
somedatabase
sortable
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index 2a81933c6c..10631afb9b 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -1230,6 +1230,7 @@
"pandas>=2.1.2,<2.2;python_version>=\"3.9\"",
"pyarrow>=14.0.1",
"snowflake-connector-python>=3.7.1",
+ "snowflake-snowpark-python>=1.17.0;python_version<\"3.12\"",
"snowflake-sqlalchemy>=1.4.0"
],
"devel-deps": [],
diff --git a/tests/providers/snowflake/decorators/__init__.py
b/tests/providers/snowflake/decorators/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/snowflake/decorators/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/snowflake/decorators/test_snowpark.py
b/tests/providers/snowflake/decorators/test_snowpark.py
new file mode 100644
index 0000000000..b14b6bd5c0
--- /dev/null
+++ b/tests/providers/snowflake/decorators/test_snowpark.py
@@ -0,0 +1,221 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING
+from unittest import mock
+
+import pytest
+
+from airflow.decorators import task
+from airflow.utils import timezone
+
+if TYPE_CHECKING:
+ from snowflake.snowpark import Session
+
+DEFAULT_DATE = timezone.datetime(2024, 9, 1)
+TEST_DAG_ID = "test_snowpark_decorator"
+TASK_ID = "snowpark_task"
+CONN_ID = "snowflake_default"
+
+
[email protected]_test
[email protected](sys.version_info >= (3, 12), reason="Snowpark Python
doesn't support Python 3.12 yet")
+class TestSnowparkDecorator:
+ @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
+ def test_snowpark_decorator_no_param(self, mock_snowflake_hook, dag_maker):
+ number = 11
+
+ @task.snowpark(
+ task_id=f"{TASK_ID}_1",
+ snowflake_conn_id=CONN_ID,
+ warehouse="test_warehouse",
+ database="test_database",
+ schema="test_schema",
+ role="test_role",
+ authenticator="externalbrowser",
+ )
+ def func1(session: Session):
+ assert session ==
mock_snowflake_hook.return_value.get_snowpark_session.return_value
+ return number
+
+ @task.snowpark(
+ task_id=f"{TASK_ID}_2",
+ snowflake_conn_id=CONN_ID,
+ warehouse="test_warehouse",
+ database="test_database",
+ schema="test_schema",
+ role="test_role",
+ authenticator="externalbrowser",
+ )
+ def func2():
+ return number
+
+ with dag_maker(dag_id=TEST_DAG_ID):
+ rets = [func1(), func2()]
+
+ dr = dag_maker.create_dagrun()
+ for ret in rets:
+ ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ for ti in dr.get_task_instances():
+ assert ti.xcom_pull() == number
+ assert mock_snowflake_hook.call_count == 2
+ assert
mock_snowflake_hook.return_value.get_snowpark_session.call_count == 2
+
+ @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
+ def test_snowpark_decorator_with_param(self, mock_snowflake_hook,
dag_maker):
+ number = 11
+
+ @task.snowpark(
+ task_id=f"{TASK_ID}_1",
+ snowflake_conn_id=CONN_ID,
+ warehouse="test_warehouse",
+ database="test_database",
+ schema="test_schema",
+ role="test_role",
+ authenticator="externalbrowser",
+ )
+ def func1(session: Session, number: int):
+ assert session ==
mock_snowflake_hook.return_value.get_snowpark_session.return_value
+ return number
+
+ @task.snowpark(
+ task_id=f"{TASK_ID}_2",
+ snowflake_conn_id=CONN_ID,
+ warehouse="test_warehouse",
+ database="test_database",
+ schema="test_schema",
+ role="test_role",
+ authenticator="externalbrowser",
+ )
+ def func2(number: int, session: Session):
+ assert session ==
mock_snowflake_hook.return_value.get_snowpark_session.return_value
+ return number
+
+ @task.snowpark(
+ task_id=f"{TASK_ID}_3",
+ snowflake_conn_id=CONN_ID,
+ warehouse="test_warehouse",
+ database="test_database",
+ schema="test_schema",
+ role="test_role",
+ authenticator="externalbrowser",
+ )
+ def func3(number: int):
+ return number
+
+ with dag_maker(dag_id=TEST_DAG_ID):
+ rets = [func1(number=number), func2(number=number),
func3(number=number)]
+
+ dr = dag_maker.create_dagrun()
+ for ret in rets:
+ ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ for ti in dr.get_task_instances():
+ assert ti.xcom_pull() == number
+ assert mock_snowflake_hook.call_count == 3
+ assert
mock_snowflake_hook.return_value.get_snowpark_session.call_count == 3
+
+ @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
+ def test_snowpark_decorator_no_return(self, mock_snowflake_hook,
dag_maker):
+ @task.snowpark(
+ task_id=TASK_ID,
+ snowflake_conn_id=CONN_ID,
+ warehouse="test_warehouse",
+ database="test_database",
+ schema="test_schema",
+ role="test_role",
+ authenticator="externalbrowser",
+ )
+ def func(session: Session):
+ assert session ==
mock_snowflake_hook.return_value.get_snowpark_session.return_value
+
+ with dag_maker(dag_id=TEST_DAG_ID):
+ ret = func()
+
+ dr = dag_maker.create_dagrun()
+ ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ for ti in dr.get_task_instances():
+ assert ti.xcom_pull() is None
+ mock_snowflake_hook.assert_called_once()
+
mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once()
+
+ @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
+ def test_snowpark_decorator_multiple_output(self, mock_snowflake_hook,
dag_maker):
+ @task.snowpark(
+ task_id=TASK_ID,
+ snowflake_conn_id=CONN_ID,
+ warehouse="test_warehouse",
+ database="test_database",
+ schema="test_schema",
+ role="test_role",
+ authenticator="externalbrowser",
+ multiple_outputs=True,
+ )
+ def func(session: Session):
+ assert session ==
mock_snowflake_hook.return_value.get_snowpark_session.return_value
+ return {"a": 1, "b": "2"}
+
+ with dag_maker(dag_id=TEST_DAG_ID):
+ ret = func()
+
+ dr = dag_maker.create_dagrun()
+ ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ ti = dr.get_task_instances()[0]
+ assert ti.xcom_pull(key="a") == 1
+ assert ti.xcom_pull(key="b") == "2"
+ assert ti.xcom_pull() == {"a": 1, "b": "2"}
+ mock_snowflake_hook.assert_called_once()
+
mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once()
+
+ @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
+ def test_snowpark_decorator_session_tag(self, mock_snowflake_hook,
dag_maker):
+ mock_session =
mock_snowflake_hook.return_value.get_snowpark_session.return_value
+ mock_session.query_tag = {}
+
+ # Mock the update_query_tag function to combine with another dict
+ def update_query_tag(new_tags):
+ mock_session.query_tag.update(new_tags)
+
+ mock_session.update_query_tag = mock.Mock(side_effect=update_query_tag)
+
+ @task.snowpark(
+ task_id=TASK_ID,
+ snowflake_conn_id=CONN_ID,
+ warehouse="test_warehouse",
+ database="test_database",
+ schema="test_schema",
+ role="test_role",
+ authenticator="externalbrowser",
+ )
+ def func(session: Session):
+ return session.query_tag
+
+ with dag_maker(dag_id=TEST_DAG_ID):
+ ret = func()
+
+ dr = dag_maker.create_dagrun()
+ ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ ti = dr.get_task_instances()[0]
+ query_tag = ti.xcom_pull()
+ assert query_tag == {
+ "dag_id": TEST_DAG_ID,
+ "dag_run_id": dr.run_id,
+ "task_id": TASK_ID,
+ "operator": "_SnowparkDecoratedOperator",
+ }
diff --git a/tests/providers/snowflake/hooks/test_snowflake.py
b/tests/providers/snowflake/hooks/test_snowflake.py
index 16e10db048..9ef0c4d2a5 100644
--- a/tests/providers/snowflake/hooks/test_snowflake.py
+++ b/tests/providers/snowflake/hooks/test_snowflake.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import json
+import sys
from copy import deepcopy
from typing import TYPE_CHECKING, Any
from unittest import mock
@@ -611,3 +612,29 @@ class TestPytestSnowflakeHook:
hook_with_schema_param =
SnowflakeHook(snowflake_conn_id="test_conn", schema="my_schema")
assert hook_with_schema_param.get_openlineage_default_schema() ==
"my_schema"
mock_get_first.assert_not_called()
+
+ @pytest.mark.skipif(sys.version_info >= (3, 12), reason="Snowpark Python doesn't support Python 3.12 yet")
+ @mock.patch("snowflake.snowpark.Session.builder")
+ def test_get_snowpark_session(self, mock_session_builder):
+ from airflow import __version__ as airflow_version
+ from airflow.providers.snowflake import __version__ as provider_version
+
+ mock_session = mock.MagicMock()
+ mock_session_builder.configs.return_value.create.return_value = mock_session
+
+ with mock.patch.dict(
+ "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**BASE_CONNECTION_KWARGS).get_uri()
+ ):
+ hook = SnowflakeHook(snowflake_conn_id="test_conn")
+ session = hook.get_snowpark_session()
+ assert session == mock_session
+
+ mock_session_builder.configs.assert_called_once_with(hook._get_conn_params)
+
+ # Verify that update_query_tag was called with the expected tag
dictionary
+ mock_session.update_query_tag.assert_called_once_with(
+ {
+ "airflow_version": airflow_version,
+ "airflow_provider_version": provider_version,
+ }
+ )
diff --git a/tests/providers/snowflake/operators/test_snowpark.py
b/tests/providers/snowflake/operators/test_snowpark.py
new file mode 100644
index 0000000000..b39bf3c105
--- /dev/null
+++ b/tests/providers/snowflake/operators/test_snowpark.py
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING
+from unittest import mock
+
+import pytest
+
+from airflow.providers.snowflake.operators.snowpark import SnowparkOperator
+from airflow.utils import timezone
+
+if TYPE_CHECKING:
+ from snowflake.snowpark import Session
+
+DEFAULT_DATE = timezone.datetime(2024, 9, 1)
+TEST_DAG_ID = "test_snowpark_operator"
+TASK_ID = "snowpark_task"
+CONN_ID = "snowflake_default"
+
+
[email protected]_test
[email protected](sys.version_info >= (3, 12), reason="Snowpark Python doesn't support Python 3.12 yet")
+class TestSnowparkOperator:
+ @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
+ def test_snowpark_operator_no_param(self, mock_snowflake_hook, dag_maker):
+ number = 11
+
+ with dag_maker(dag_id=TEST_DAG_ID) as dag:
+
+ def func1(session: Session):
+ assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value
+ return number
+
+ def func2():
+ return number
+
+ operators = [
+ SnowparkOperator(
+ task_id=f"{TASK_ID}_{i}",
+ snowflake_conn_id=CONN_ID,
+ python_callable=func,
+ warehouse="test_warehouse",
+ database="test_database",
+ schema="test_schema",
+ role="test_role",
+ authenticator="externalbrowser",
+ dag=dag,
+ )
+ for i, func in enumerate([func1, func2])
+ ]
+
+ dr = dag_maker.create_dagrun()
+ for operator in operators:
+ operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ for ti in dr.get_task_instances():
+ assert ti.xcom_pull() == number
+ assert mock_snowflake_hook.call_count == 2
+ assert
mock_snowflake_hook.return_value.get_snowpark_session.call_count == 2
+
+ @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
+ def test_snowpark_operator_with_param(self, mock_snowflake_hook,
dag_maker):
+ number = 11
+
+ with dag_maker(dag_id=TEST_DAG_ID) as dag:
+
+ def func1(session: Session, number: int):
+ assert session ==
mock_snowflake_hook.return_value.get_snowpark_session.return_value
+ return number
+
+ def func2(number: int, session: Session):
+ assert session ==
mock_snowflake_hook.return_value.get_snowpark_session.return_value
+ return number
+
+ def func3(number: int):
+ return number
+
+ operators = [
+ SnowparkOperator(
+ task_id=f"{TASK_ID}_{i}",
+ snowflake_conn_id=CONN_ID,
+ python_callable=func,
+ op_kwargs={"number": number},
+ warehouse="test_warehouse",
+ database="test_database",
+ schema="test_schema",
+ role="test_role",
+ authenticator="externalbrowser",
+ dag=dag,
+ )
+ for i, func in enumerate([func1, func2, func3])
+ ]
+
+ dr = dag_maker.create_dagrun()
+ for operator in operators:
+ operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ for ti in dr.get_task_instances():
+ assert ti.xcom_pull() == number
+ assert mock_snowflake_hook.call_count == 3
+ assert
mock_snowflake_hook.return_value.get_snowpark_session.call_count == 3
+
+ @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
+ def test_snowpark_operator_no_return(self, mock_snowflake_hook, dag_maker):
+ with dag_maker(dag_id=TEST_DAG_ID) as dag:
+
+ def func(session: Session):
+ assert session ==
mock_snowflake_hook.return_value.get_snowpark_session.return_value
+
+ operator = SnowparkOperator(
+ task_id=TASK_ID,
+ snowflake_conn_id=CONN_ID,
+ python_callable=func,
+ warehouse="test_warehouse",
+ database="test_database",
+ schema="test_schema",
+ role="test_role",
+ authenticator="externalbrowser",
+ dag=dag,
+ )
+
+ dr = dag_maker.create_dagrun()
+ operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ for ti in dr.get_task_instances():
+ assert ti.xcom_pull() is None
+ mock_snowflake_hook.assert_called_once()
+
mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once()
+
+ @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook")
+ def test_snowpark_operator_session_tag(self, mock_snowflake_hook,
dag_maker):
+ mock_session = mock_snowflake_hook.return_value.get_snowpark_session.return_value
+ mock_session.query_tag = {}
+
+ # Mock the update_query_tag function to combine with another dict
+ def update_query_tag(new_tags):
+ mock_session.query_tag.update(new_tags)
+
+ mock_session.update_query_tag = mock.Mock(side_effect=update_query_tag)
+
+ with dag_maker(dag_id=TEST_DAG_ID) as dag:
+
+ def func(session: Session):
+ return session.query_tag
+
+ operator = SnowparkOperator(
+ task_id=TASK_ID,
+ snowflake_conn_id=CONN_ID,
+ python_callable=func,
+ warehouse="test_warehouse",
+ database="test_database",
+ schema="test_schema",
+ role="test_role",
+ authenticator="externalbrowser",
+ dag=dag,
+ )
+
+ dr = dag_maker.create_dagrun()
+ operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ ti = dr.get_task_instances()[0]
+ query_tag = ti.xcom_pull()
+ assert query_tag == {
+ "dag_id": TEST_DAG_ID,
+ "dag_run_id": dr.run_id,
+ "task_id": TASK_ID,
+ "operator": "SnowparkOperator",
+ }
diff --git a/tests/providers/snowflake/utils/test_snowpark.py
b/tests/providers/snowflake/utils/test_snowpark.py
new file mode 100644
index 0000000000..c0c8b507ef
--- /dev/null
+++ b/tests/providers/snowflake/utils/test_snowpark.py
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.snowflake.utils.snowpark import
inject_session_into_op_kwargs
+
+
[email protected](
+ "func,expected_injected",
+ [
+ (lambda x: x, False),
+ (lambda: 1, False),
+ (lambda session: 1, True),
+ (lambda session, x: x, True),
+ (lambda x, session: 2 * x, True),
+ ],
+)
+def test_inject_session_into_op_kwargs(func, expected_injected):
+ result = inject_session_into_op_kwargs(func, {}, None)
+ assert ("session" in result) == expected_injected
diff --git a/tests/system/providers/snowflake/example_snowpark_decorator.py
b/tests/system/providers/snowflake/example_snowpark_decorator.py
new file mode 100644
index 0000000000..1a303b1fdf
--- /dev/null
+++ b/tests/system/providers/snowflake/example_snowpark_decorator.py
@@ -0,0 +1,85 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example use of Snowflake Snowpark Python related decorators.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from snowflake.snowpark import Session
+
+from airflow import DAG
+from airflow.decorators import task
+
+SNOWFLAKE_CONN_ID = "snowflake_default"
+DAG_ID = "example_snowpark"
+
+with DAG(
+ DAG_ID,
+ start_date=datetime(2024, 1, 1),
+ schedule="@once",
+ default_args={"snowflake_conn_id": SNOWFLAKE_CONN_ID},
+ tags=["example"],
+ catchup=False,
+) as dag:
+ # [START howto_decorator_snowpark]
+ @task.snowpark
+ def setup_data(session: Session):
+ # The Snowpark session object is injected as an argument
+ data = [
+ (1, 0, 5, "Product 1", "prod-1", 1, 10),
+ (2, 1, 5, "Product 1A", "prod-1-A", 1, 20),
+ (3, 1, 5, "Product 1B", "prod-1-B", 1, 30),
+ (4, 0, 10, "Product 2", "prod-2", 2, 40),
+ (5, 4, 10, "Product 2A", "prod-2-A", 2, 50),
+ (6, 4, 10, "Product 2B", "prod-2-B", 2, 60),
+ (7, 0, 20, "Product 3", "prod-3", 3, 70),
+ (8, 7, 20, "Product 3A", "prod-3-A", 3, 80),
+ (9, 7, 20, "Product 3B", "prod-3-B", 3, 90),
+ (10, 0, 50, "Product 4", "prod-4", 4, 100),
+ (11, 10, 50, "Product 4A", "prod-4-A", 4, 100),
+ (12, 10, 50, "Product 4B", "prod-4-B", 4, 100),
+ ]
+ columns = ["id", "parent_id", "category_id", "name", "serial_number",
"key", "3rd"]
+ df = session.create_dataframe(data, schema=columns)
+ table_name = "sample_product_data"
+ df.write.save_as_table(table_name, mode="overwrite")
+ return table_name
+
+ table_name = setup_data() # type: ignore[call-arg]
+
+ @task.snowpark
+ def check_num_rows(table_name: str):
+ # Alternatively, retrieve the Snowpark session object using
`get_active_session`
+ from snowflake.snowpark.context import get_active_session
+
+ session = get_active_session()
+ df = session.table(table_name)
+ assert df.count() == 12
+
+ check_num_rows(table_name)
+ # [END howto_decorator_snowpark]
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/snowflake/example_snowpark_operator.py
b/tests/system/providers/snowflake/example_snowpark_operator.py
new file mode 100644
index 0000000000..090a0f53a4
--- /dev/null
+++ b/tests/system/providers/snowflake/example_snowpark_operator.py
@@ -0,0 +1,94 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example use of Snowflake Snowpark Python related operators.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from snowflake.snowpark import Session
+
+from airflow import DAG
+from airflow.providers.snowflake.operators.snowpark import SnowparkOperator
+
+SNOWFLAKE_CONN_ID = "snowflake_default"
+DAG_ID = "example_snowpark"
+
+with DAG(
+ DAG_ID,
+ start_date=datetime(2024, 1, 1),
+ schedule="@once",
+ default_args={"snowflake_conn_id": SNOWFLAKE_CONN_ID},
+ tags=["example"],
+ catchup=False,
+) as dag:
+ # [START howto_operator_snowpark]
+ def setup_data(session: Session):
+ # The Snowpark session object is injected as an argument
+ data = [
+ (1, 0, 5, "Product 1", "prod-1", 1, 10),
+ (2, 1, 5, "Product 1A", "prod-1-A", 1, 20),
+ (3, 1, 5, "Product 1B", "prod-1-B", 1, 30),
+ (4, 0, 10, "Product 2", "prod-2", 2, 40),
+ (5, 4, 10, "Product 2A", "prod-2-A", 2, 50),
+ (6, 4, 10, "Product 2B", "prod-2-B", 2, 60),
+ (7, 0, 20, "Product 3", "prod-3", 3, 70),
+ (8, 7, 20, "Product 3A", "prod-3-A", 3, 80),
+ (9, 7, 20, "Product 3B", "prod-3-B", 3, 90),
+ (10, 0, 50, "Product 4", "prod-4", 4, 100),
+ (11, 10, 50, "Product 4A", "prod-4-A", 4, 100),
+ (12, 10, 50, "Product 4B", "prod-4-B", 4, 100),
+ ]
+ columns = ["id", "parent_id", "category_id", "name", "serial_number",
"key", "3rd"]
+ df = session.create_dataframe(data, schema=columns)
+ table_name = "sample_product_data"
+ df.write.save_as_table(table_name, mode="overwrite")
+ return table_name
+
+ setup_data_operator = SnowparkOperator(
+ task_id="setup_data",
+ python_callable=setup_data,
+ dag=dag,
+ )
+
+ def check_num_rows(table_name: str):
+ # Alternatively, retrieve the Snowpark session object using
`get_active_session`
+ from snowflake.snowpark.context import get_active_session
+
+ session = get_active_session()
+ df = session.table(table_name)
+ assert df.count() == 12
+
+ check_num_rows_operator = SnowparkOperator(
+ task_id="check_num_rows",
+ python_callable=check_num_rows,
+ op_kwargs={"table_name": "{{
task_instance.xcom_pull(task_ids='setup_data') }}"},
+ dag=dag,
+ )
+
+ setup_data_operator >> check_num_rows_operator
+ # [END howto_operator_snowpark]
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)