This is an automated email from the ASF dual-hosted git repository.
bolke pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 18dac61bf6 Add remote kernel support for papermill operator (#34840)
18dac61bf6 is described below
commit 18dac61bf6a72f660295512291d5bcdc726d7d2d
Author: Akshay Chitneni <[email protected]>
AuthorDate: Mon Nov 13 03:21:35 2023 -0800
Add remote kernel support for papermill operator (#34840)
Co-authored-by: Akshay Chitneni <[email protected]>
---
airflow/providers/papermill/hooks/__init__.py | 17 ++
airflow/providers/papermill/hooks/kernel.py | 171 +++++++++++++++++++++
airflow/providers/papermill/operators/papermill.py | 40 ++++-
airflow/providers/papermill/provider.yaml | 10 ++
.../connections/index.rst | 28 ++++
.../connections/jupyter_kernel.rst | 78 ++++++++++
docs/apache-airflow-providers-papermill/index.rst | 1 +
.../operators.rst | 8 +
generated/provider_dependencies.json | 1 +
tests/providers/papermill/hooks/__init__.py | 16 ++
tests/providers/papermill/hooks/test_kernel.py | 41 +++++
.../papermill/operators/test_papermill.py | 56 ++++++-
tests/system/providers/papermill/conftest.py | 56 +++++++
.../papermill/example_papermill_remote_verify.py | 80 ++++++++++
14 files changed, 600 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/papermill/hooks/__init__.py
b/airflow/providers/papermill/hooks/__init__.py
new file mode 100644
index 0000000000..217e5db960
--- /dev/null
+++ b/airflow/providers/papermill/hooks/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/papermill/hooks/kernel.py
b/airflow/providers/papermill/hooks/kernel.py
new file mode 100644
index 0000000000..0bac65dc98
--- /dev/null
+++ b/airflow/providers/papermill/hooks/kernel.py
@@ -0,0 +1,171 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from jupyter_client import AsyncKernelManager
+from papermill.clientwrap import PapermillNotebookClient
+from papermill.engines import NBClientEngine
+from papermill.utils import merge_kwargs, remove_args
+from traitlets import Unicode
+
+if TYPE_CHECKING:
+ from pydantic import typing
+
+from airflow.hooks.base import BaseHook
+
+JUPYTER_KERNEL_SHELL_PORT = 60316
+JUPYTER_KERNEL_IOPUB_PORT = 60317
+JUPYTER_KERNEL_STDIN_PORT = 60318
+JUPYTER_KERNEL_CONTROL_PORT = 60319
+JUPYTER_KERNEL_HB_PORT = 60320
+REMOTE_KERNEL_ENGINE = "remote_kernel_engine"
+
+
+class KernelConnection:
+ """Class to represent kernel connection object."""
+
+ ip: str
+ shell_port: int
+ iopub_port: int
+ stdin_port: int
+ control_port: int
+ hb_port: int
+ session_key: str
+
+
+class KernelHook(BaseHook):
+ """
+    The KernelHook can be used to interact with a remote Jupyter kernel.
+
+ Takes kernel host/ip from connection and refers to jupyter kernel ports
and session_key
+    from the ``extra`` field.
+
+ :param kernel_conn_id: connection that has kernel host/ip
+ """
+
+ conn_name_attr = "kernel_conn_id"
+ default_conn_name = "jupyter_kernel_default"
+ conn_type = "jupyter_kernel"
+ hook_name = "Jupyter Kernel"
+
+ def __init__(self, kernel_conn_id: str = default_conn_name, *args,
**kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.kernel_conn = self.get_connection(kernel_conn_id)
+ register_remote_kernel_engine()
+
+ def get_conn(self) -> KernelConnection:
+ kernel_connection = KernelConnection()
+ kernel_connection.ip = self.kernel_conn.host
+ kernel_connection.shell_port = self.kernel_conn.extra_dejson.get(
+ "shell_port", JUPYTER_KERNEL_SHELL_PORT
+ )
+ kernel_connection.iopub_port = self.kernel_conn.extra_dejson.get(
+ "iopub_port", JUPYTER_KERNEL_IOPUB_PORT
+ )
+ kernel_connection.stdin_port = self.kernel_conn.extra_dejson.get(
+ "stdin_port", JUPYTER_KERNEL_STDIN_PORT
+ )
+ kernel_connection.control_port = self.kernel_conn.extra_dejson.get(
+ "control_port", JUPYTER_KERNEL_CONTROL_PORT
+ )
+ kernel_connection.hb_port =
self.kernel_conn.extra_dejson.get("hb_port", JUPYTER_KERNEL_HB_PORT)
+ kernel_connection.session_key =
self.kernel_conn.extra_dejson.get("session_key", "")
+ return kernel_connection
+
+
+def register_remote_kernel_engine():
+ """Registers ``RemoteKernelEngine`` papermill engine."""
+ from papermill.engines import papermill_engines
+
+ papermill_engines.register(REMOTE_KERNEL_ENGINE, RemoteKernelEngine)
+
+
+class RemoteKernelManager(AsyncKernelManager):
+ """Jupyter kernel manager that connects to a remote kernel."""
+
+ session_key = Unicode("", config=True, help="Session key to connect to
remote kernel")
+
+ @property
+ def has_kernel(self) -> bool:
+ return True
+
+ async def _async_is_alive(self) -> bool:
+ return True
+
+ def shutdown_kernel(self, now: bool = False, restart: bool = False):
+ pass
+
+ def client(self, **kwargs: typing.Any):
+ """Create a client configured to connect to our kernel."""
+ kernel_client = super().client(**kwargs)
+ # load connection info to set session_key
+ config: dict[str, int | str | bytes] = dict(
+ ip=self.ip,
+ shell_port=self.shell_port,
+ iopub_port=self.iopub_port,
+ stdin_port=self.stdin_port,
+ control_port=self.control_port,
+ hb_port=self.hb_port,
+ key=self.session_key,
+ transport="tcp",
+ signature_scheme="hmac-sha256",
+ )
+ kernel_client.load_connection_info(config)
+ return kernel_client
+
+
+class RemoteKernelEngine(NBClientEngine):
+ """Papermill engine to use ``RemoteKernelManager`` to connect to remote
kernel and execute notebook."""
+
+ @classmethod
+ def execute_managed_notebook(
+ cls,
+ nb_man,
+ kernel_name,
+ log_output=False,
+ stdout_file=None,
+ stderr_file=None,
+ start_timeout=60,
+ execution_timeout=None,
+ **kwargs,
+ ):
+ """Performs the actual execution of the parameterized notebook
locally."""
+ km = RemoteKernelManager()
+ km.ip = kwargs["kernel_ip"]
+ km.shell_port = kwargs["kernel_shell_port"]
+ km.iopub_port = kwargs["kernel_iopub_port"]
+ km.stdin_port = kwargs["kernel_stdin_port"]
+ km.control_port = kwargs["kernel_control_port"]
+ km.hb_port = kwargs["kernel_hb_port"]
+ km.ip = kwargs["kernel_ip"]
+ km.session_key = kwargs["kernel_session_key"]
+
+        # Exclude parameters that are named differently downstream
+ safe_kwargs = remove_args(["timeout", "startup_timeout"], **kwargs)
+
+ final_kwargs = merge_kwargs(
+ safe_kwargs,
+ timeout=execution_timeout if execution_timeout else
kwargs.get("timeout"),
+ startup_timeout=start_timeout,
+ log_output=False,
+ stdout_file=stdout_file,
+ stderr_file=stderr_file,
+ )
+
+ return PapermillNotebookClient(nb_man, km=km, **final_kwargs).execute()
diff --git a/airflow/providers/papermill/operators/papermill.py
b/airflow/providers/papermill/operators/papermill.py
index 0a3091a3ef..3326b2e3c9 100644
--- a/airflow/providers/papermill/operators/papermill.py
+++ b/airflow/providers/papermill/operators/papermill.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations
+from functools import cached_property
from typing import TYPE_CHECKING, ClassVar, Collection, Sequence
import attr
@@ -24,6 +25,7 @@ import papermill as pm
from airflow.lineage.entities import File
from airflow.models import BaseOperator
+from airflow.providers.papermill.hooks.kernel import REMOTE_KERNEL_ENGINE,
KernelHook
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -54,7 +56,14 @@ class PapermillOperator(BaseOperator):
supports_lineage = True
- template_fields: Sequence[str] = ("input_nb", "output_nb", "parameters",
"kernel_name", "language_name")
+ template_fields: Sequence[str] = (
+ "input_nb",
+ "output_nb",
+ "parameters",
+ "kernel_name",
+ "language_name",
+ "kernel_conn_id",
+ )
def __init__(
self,
@@ -64,6 +73,7 @@ class PapermillOperator(BaseOperator):
parameters: dict | None = None,
kernel_name: str | None = None,
language_name: str | None = None,
+ kernel_conn_id: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -85,11 +95,29 @@ class PapermillOperator(BaseOperator):
self.kernel_name = kernel_name
self.language_name = language_name
+ self.kernel_conn_id = kernel_conn_id
self.inlets.append(self.input_nb)
self.outlets.append(self.output_nb)
def execute(self, context: Context):
+ remote_kernel_kwargs = {}
+ kernel_hook = self.hook
+ if kernel_hook:
+ engine_name = REMOTE_KERNEL_ENGINE
+ kernel_connection = kernel_hook.get_conn()
+ remote_kernel_kwargs = {
+ "kernel_ip": kernel_connection.ip,
+ "kernel_shell_port": kernel_connection.shell_port,
+ "kernel_iopub_port": kernel_connection.iopub_port,
+ "kernel_stdin_port": kernel_connection.stdin_port,
+ "kernel_control_port": kernel_connection.control_port,
+ "kernel_hb_port": kernel_connection.hb_port,
+ "kernel_session_key": kernel_connection.session_key,
+ }
+ else:
+ engine_name = None
+
pm.execute_notebook(
self.input_nb.url,
self.output_nb.url,
@@ -98,4 +126,14 @@ class PapermillOperator(BaseOperator):
report_mode=True,
kernel_name=self.kernel_name,
language=self.language_name,
+ engine_name=engine_name,
+ **remote_kernel_kwargs,
)
+
+ @cached_property
+ def hook(self) -> KernelHook | None:
+ """Get valid hook."""
+ if self.kernel_conn_id:
+ return KernelHook(kernel_conn_id=self.kernel_conn_id)
+ else:
+ return None
diff --git a/airflow/providers/papermill/provider.yaml
b/airflow/providers/papermill/provider.yaml
index 70c4903858..c85fde1677 100644
--- a/airflow/providers/papermill/provider.yaml
+++ b/airflow/providers/papermill/provider.yaml
@@ -44,6 +44,7 @@ dependencies:
- apache-airflow>=2.5.0
- papermill[all]>=1.2.1
- scrapbook[all]
+ - ipykernel
integrations:
- integration-name: Papermill
@@ -57,3 +58,12 @@ operators:
- integration-name: Papermill
python-modules:
- airflow.providers.papermill.operators.papermill
+
+hooks:
+ - integration-name: Papermill
+ python-modules:
+ - airflow.providers.papermill.hooks.kernel
+
+connection-types:
+ - hook-class-name: airflow.providers.papermill.hooks.kernel.KernelHook
+ connection-type: jupyter_kernel
diff --git a/docs/apache-airflow-providers-papermill/connections/index.rst
b/docs/apache-airflow-providers-papermill/connections/index.rst
new file mode 100644
index 0000000000..c66accd066
--- /dev/null
+++ b/docs/apache-airflow-providers-papermill/connections/index.rst
@@ -0,0 +1,28 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+
+Jupyter Kernel Connections
+==========================
+
+
+.. toctree::
+ :maxdepth: 1
+ :glob:
+
+ *
diff --git
a/docs/apache-airflow-providers-papermill/connections/jupyter_kernel.rst
b/docs/apache-airflow-providers-papermill/connections/jupyter_kernel.rst
new file mode 100644
index 0000000000..2ae56cfa7c
--- /dev/null
+++ b/docs/apache-airflow-providers-papermill/connections/jupyter_kernel.rst
@@ -0,0 +1,78 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+
+.. _howto/connection:jupyter_kernel:
+
+Jupyter Kernel Connection
+=========================
+
+The Jupyter Kernel connection type enables remote kernel connections.
+
+
+Default Connection ID
+---------------------
+
+The default Jupyter Kernel connection ID is ``jupyter_kernel_default``.
+
+Configuring the Connection
+--------------------------
+
+host
+ HOSTNAME/IP of the remote Jupyter Kernel
+
+Extra (optional)
+ Specify the extra parameters (as json dictionary) that can be used in
kernel connection.
+ All parameters are optional.
+
+ * ``session_key``: Session key to initiate a connection to remote kernel
[default: ''].
+ * ``shell_port``: SHELL port [default: 60316].
+ * ``iopub_port``: IOPUB port [default: 60317].
+ * ``stdin_port``: STDIN port [default: 60318].
+ * ``control_port``: CONTROL port [default: 60319].
+ * ``hb_port``: HEARTBEAT port [default: 60320].
+
+If you are configuring the connection via a URI, ensure that all components of
the URI are URL-encoded.
+
+Examples
+--------
+
+**Set Remote Kernel Connection as Environment Variable (JSON format)**
+ .. code-block:: bash
+
+ export AIRFLOW_CONN_JUPYTER_KERNEL_DEFAULT='{"host": "remote_host",
"extra": {"session_key": "notebooks"}}'
+
+**Snippet for creating a Connection as a URI**:
+ .. code-block:: python
+
+ from airflow.models.connection import Connection
+
+ conn = Connection(
+ conn_id="jupyter_kernel_default",
+ conn_type="jupyter_kernel",
+ host="remote_host",
+ extra={
+ # Specify extra parameters here
+ "session_key": "notebooks",
+ },
+ )
+
+ # Generate Environment Variable Name
+ env_key = f"AIRFLOW_CONN_{conn.conn_id.upper()}"
+
+ print(f"{env_key}='{conn.get_uri()}'")
diff --git a/docs/apache-airflow-providers-papermill/index.rst
b/docs/apache-airflow-providers-papermill/index.rst
index cb68701cb9..9a3a42c260 100644
--- a/docs/apache-airflow-providers-papermill/index.rst
+++ b/docs/apache-airflow-providers-papermill/index.rst
@@ -35,6 +35,7 @@
:caption: Guides
Operators <operators>
+ Connection types <connections/index>
.. toctree::
:hidden:
diff --git a/docs/apache-airflow-providers-papermill/operators.rst
b/docs/apache-airflow-providers-papermill/operators.rst
index 69a8259a9c..ed1cf580c8 100644
--- a/docs/apache-airflow-providers-papermill/operators.rst
+++ b/docs/apache-airflow-providers-papermill/operators.rst
@@ -62,3 +62,11 @@ Example DAG to Verify the message in the notebook:
:language: python
:start-after: [START howto_verify_operator_papermill]
:end-before: [END howto_verify_operator_papermill]
+
+
+Example DAG to Verify the message in the notebook using a remote Jupyter
kernel:
+
+.. exampleinclude::
/../../tests/system/providers/papermill/example_papermill_remote_verify.py
+ :language: python
+ :start-after: [START howto_verify_operator_papermill_remote_kernel]
+ :end-before: [END howto_verify_operator_papermill_remote_kernel]
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index 4d4d6cb90a..c7c939670a 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -732,6 +732,7 @@
"papermill": {
"deps": [
"apache-airflow>=2.5.0",
+ "ipykernel",
"papermill[all]>=1.2.1",
"scrapbook[all]"
],
diff --git a/tests/providers/papermill/hooks/__init__.py
b/tests/providers/papermill/hooks/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/papermill/hooks/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/papermill/hooks/test_kernel.py
b/tests/providers/papermill/hooks/test_kernel.py
new file mode 100644
index 0000000000..4266d5fcb7
--- /dev/null
+++ b/tests/providers/papermill/hooks/test_kernel.py
@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import patch
+
+from airflow.models import Connection
+from airflow.providers.papermill.hooks.kernel import KernelHook
+
+
+class TestKernelHook:
+ """
+ Tests for Kernel connection
+ """
+
+ def test_kernel_connection(self):
+ """
+        Test that fetches a KernelConnection with the configured host and ports
+ """
+ conn = Connection(
+ conn_type="jupyter_kernel", host="test_host",
extra='{"shell_port": 60000, "session_key": "key"}'
+ )
+ with patch.object(KernelHook, "get_connection", return_value=conn):
+ hook = KernelHook()
+ assert hook.get_conn().ip == "test_host"
+ assert hook.get_conn().shell_port == 60000
+ assert hook.get_conn().session_key == "key"
diff --git a/tests/providers/papermill/operators/test_papermill.py
b/tests/providers/papermill/operators/test_papermill.py
index d862a2eb87..03fd0dc74e 100644
--- a/tests/providers/papermill/operators/test_papermill.py
+++ b/tests/providers/papermill/operators/test_papermill.py
@@ -17,11 +17,18 @@
# under the License.
from __future__ import annotations
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
import pytest
-from airflow.providers.papermill.operators.papermill import NoteBook,
PapermillOperator
+from airflow.providers.papermill.hooks.kernel import (
+ JUPYTER_KERNEL_CONTROL_PORT,
+ JUPYTER_KERNEL_HB_PORT,
+ JUPYTER_KERNEL_IOPUB_PORT,
+ JUPYTER_KERNEL_SHELL_PORT,
+ JUPYTER_KERNEL_STDIN_PORT,
+)
+from airflow.providers.papermill.operators.papermill import
REMOTE_KERNEL_ENGINE, NoteBook, PapermillOperator
from airflow.utils import timezone
DEFAULT_DATE = timezone.datetime(2021, 1, 1)
@@ -96,6 +103,51 @@ class TestPapermillOperator:
language=language_name,
progress_bar=False,
report_mode=True,
+ engine_name=None,
+ )
+
+
@patch("airflow.providers.papermill.hooks.kernel.KernelHook.get_connection")
+ @patch("airflow.providers.papermill.operators.papermill.pm")
+ def test_execute_remote_kernel(self, mock_papermill, kernel_hook):
+ in_nb = "/tmp/does_not_exist"
+ out_nb = "/tmp/will_not_exist"
+ kernel_name = "python3"
+ language_name = "python"
+ parameters = {"msg": "hello_world", "train": 1}
+ conn = MagicMock()
+ conn.host = "127.0.0.1"
+ conn.extra_dejson = {"session_key": "notebooks"}
+ kernel_hook.return_value = conn
+
+ op = PapermillOperator(
+ input_nb=in_nb,
+ output_nb=out_nb,
+ parameters=parameters,
+ task_id="papermill_operator_test",
+ kernel_name=kernel_name,
+ language_name=language_name,
+ kernel_conn_id="jupyter_kernel_default",
+ dag=None,
+ )
+
+ op.execute(context={})
+
+ mock_papermill.execute_notebook.assert_called_once_with(
+ in_nb,
+ out_nb,
+ parameters=parameters,
+ kernel_name=kernel_name,
+ language=language_name,
+ progress_bar=False,
+ report_mode=True,
+ engine_name=REMOTE_KERNEL_ENGINE,
+ kernel_session_key="notebooks",
+ kernel_shell_port=JUPYTER_KERNEL_SHELL_PORT,
+ kernel_iopub_port=JUPYTER_KERNEL_IOPUB_PORT,
+ kernel_stdin_port=JUPYTER_KERNEL_STDIN_PORT,
+ kernel_control_port=JUPYTER_KERNEL_CONTROL_PORT,
+ kernel_hb_port=JUPYTER_KERNEL_HB_PORT,
+ kernel_ip="127.0.0.1",
)
@pytest.mark.db_test
diff --git a/tests/system/providers/papermill/conftest.py
b/tests/system/providers/papermill/conftest.py
new file mode 100644
index 0000000000..4594c13e84
--- /dev/null
+++ b/tests/system/providers/papermill/conftest.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+import subprocess
+
+import pytest
+
+from airflow.providers.papermill.hooks.kernel import (
+ JUPYTER_KERNEL_CONTROL_PORT,
+ JUPYTER_KERNEL_HB_PORT,
+ JUPYTER_KERNEL_IOPUB_PORT,
+ JUPYTER_KERNEL_SHELL_PORT,
+ JUPYTER_KERNEL_STDIN_PORT,
+)
+
+
[email protected](scope="session", autouse=True)
+def remote_kernel(request):
+ proc = subprocess.Popen(
+ [
+ "python3",
+ "-m",
+ "ipykernel",
+ '--Session.key=b""',
+ f"--hb={JUPYTER_KERNEL_HB_PORT}",
+ f"--shell={JUPYTER_KERNEL_SHELL_PORT}",
+ f"--iopub={JUPYTER_KERNEL_IOPUB_PORT}",
+ f"--stdin={JUPYTER_KERNEL_STDIN_PORT}",
+ f"--control={JUPYTER_KERNEL_CONTROL_PORT}",
+ "--ip=0.0.0.0",
+ ]
+ )
+ request.addfinalizer(proc.kill)
+
+
[email protected](scope="session", autouse=True)
+def airflow_conn(remote_kernel):
+ os.environ[
+ "AIRFLOW_CONN_JUPYTER_KERNEL_DEFAULT"
+ ] = '{"host": "localhost", "extra": {"shell_port": 60316} }'
diff --git
a/tests/system/providers/papermill/example_papermill_remote_verify.py
b/tests/system/providers/papermill/example_papermill_remote_verify.py
new file mode 100644
index 0000000000..ba57e15454
--- /dev/null
+++ b/tests/system/providers/papermill/example_papermill_remote_verify.py
@@ -0,0 +1,80 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+This DAG will use Papermill to run the notebook "hello_world", based on the
execution date
+it will create an output notebook "out-<date>". All fields, including the keys
in the parameters, are
+templated.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime, timedelta
+
+import scrapbook as sb
+
+from airflow import DAG
+from airflow.decorators import task
+from airflow.providers.papermill.operators.papermill import PapermillOperator
+
+START_DATE = datetime(2021, 1, 1)
+SCHEDULE_INTERVAL = "@once"
+DAGRUN_TIMEOUT = timedelta(minutes=60)
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+DAG_ID = "example_papermill_operator_remote_verify"
+
+
+# [START howto_verify_operator_papermill_remote_kernel]
+@task
+def check_notebook(output_notebook, execution_date):
+ """
+ Verify the message in the notebook
+ """
+ notebook = sb.read_notebook(output_notebook)
+ message = notebook.scraps["message"]
+ print(f"Message in notebook {message} for {execution_date}")
+
+ if message.data != f"Ran from Airflow at {execution_date}!":
+ return False
+
+ return True
+
+
+with DAG(
+ dag_id="example_papermill_operator_remote_verify",
+ schedule="@once",
+ start_date=START_DATE,
+ dagrun_timeout=DAGRUN_TIMEOUT,
+ catchup=False,
+) as dag:
+ run_this = PapermillOperator(
+ task_id="run_example_notebook",
+ input_nb=os.path.join(os.path.dirname(os.path.realpath(__file__)),
"input_notebook.ipynb"),
+ output_nb="/tmp/out-{{ execution_date }}.ipynb",
+ parameters={"msgs": "Ran from Airflow at {{ execution_date }}!"},
+ kernel_conn_id="jupyter_kernel_default",
+ )
+
+ run_this >> check_notebook(
+ output_notebook="/tmp/out-{{ execution_date }}.ipynb",
execution_date="{{ execution_date }}"
+ )
+# [END howto_verify_operator_papermill_remote_kernel]
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)