This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 736f2e898a Fix rendering parameters in PapermillOperator (#28979)
736f2e898a is described below
commit 736f2e898a91d3f2cb3e9ca811c9833609f28c5a
Author: Andrey Anshin <[email protected]>
AuthorDate: Mon Jan 23 02:32:20 2023 +0400
Fix rendering parameters in PapermillOperator (#28979)
---
airflow/providers/papermill/operators/papermill.py | 70 ++++++++++------
.../papermill/operators/test_papermill.py | 96 +++++++++++++++++-----
2 files changed, 119 insertions(+), 47 deletions(-)
diff --git a/airflow/providers/papermill/operators/papermill.py b/airflow/providers/papermill/operators/papermill.py
index 531304d441..9e5d5c2e1b 100644
--- a/airflow/providers/papermill/operators/papermill.py
+++ b/airflow/providers/papermill/operators/papermill.py
@@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, ClassVar, Collection, Optional, Sequence
import attr
import papermill as pm
@@ -33,8 +33,17 @@ if TYPE_CHECKING:
class NoteBook(File):
"""Jupyter notebook"""
- type_hint: str | None = "jupyter_notebook"
- parameters: dict | None = {}
+ # For compatibility with Airflow 2.3:
+ # 1. Use predefined set because `File.template_fields` introduced in Airflow 2.4
+ # 2. Use old styled annotations because `cattrs` doesn't work well with PEP 604.
+
+ template_fields: ClassVar[Collection[str]] = {
+ "parameters",
+ *(File.template_fields if hasattr(File, "template_fields") else
{"url"}),
+ }
+
+ type_hint: Optional[str] = "jupyter_notebook" # noqa: UP007
+ parameters: Optional[dict] = {} # noqa: UP007
meta_schema: str = __name__ + ".NoteBook"
@@ -43,8 +52,8 @@ class PapermillOperator(BaseOperator):
"""
Executes a jupyter notebook through papermill that is annotated with parameters
- :param input_nb: input notebook (can also be a NoteBook or a File inlet)
- :param output_nb: output notebook (can also be a NoteBook or File outlet)
+ :param input_nb: input notebook, either path or NoteBook inlet.
+ :param output_nb: output notebook, either path or NoteBook outlet.
:param parameters: the notebook parameters to set
:param kernel_name: (optional) name of kernel to execute the notebook against
(ignores kernel name in the notebook document metadata)
@@ -57,36 +66,43 @@ class PapermillOperator(BaseOperator):
def __init__(
self,
*,
- input_nb: str | None = None,
- output_nb: str | None = None,
+ input_nb: str | NoteBook | None = None,
+ output_nb: str | NoteBook | None = None,
parameters: dict | None = None,
kernel_name: str | None = None,
language_name: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
-
- self.input_nb = input_nb
- self.output_nb = output_nb
self.parameters = parameters
+
+ if not input_nb:
+ raise ValueError("Input notebook is not specified")
+ elif not isinstance(input_nb, NoteBook):
+ self.input_nb = NoteBook(url=input_nb, parameters=self.parameters)
+ else:
+ self.input_nb = input_nb
+
+ if not output_nb:
+ raise ValueError("Output notebook is not specified")
+ elif not isinstance(output_nb, NoteBook):
+ self.output_nb = NoteBook(url=output_nb)
+ else:
+ self.output_nb = output_nb
+
self.kernel_name = kernel_name
self.language_name = language_name
- if input_nb:
- self.inlets.append(NoteBook(url=input_nb,
parameters=self.parameters))
- if output_nb:
- self.outlets.append(NoteBook(url=output_nb))
+
+ self.inlets.append(self.input_nb)
+ self.outlets.append(self.output_nb)
def execute(self, context: Context):
- if not self.inlets or not self.outlets:
- raise ValueError("Input notebook or output notebook is not
specified")
-
- for i, item in enumerate(self.inlets):
- pm.execute_notebook(
- item.url,
- self.outlets[i].url,
- parameters=item.parameters,
- progress_bar=False,
- report_mode=True,
- kernel_name=self.kernel_name,
- language=self.language_name,
- )
+ pm.execute_notebook(
+ self.input_nb.url,
+ self.output_nb.url,
+ parameters=self.input_nb.parameters,
+ progress_bar=False,
+ report_mode=True,
+ kernel_name=self.kernel_name,
+ language=self.language_name,
+ )
diff --git a/tests/providers/papermill/operators/test_papermill.py b/tests/providers/papermill/operators/test_papermill.py
index 2ab23280a7..55569e68ba 100644
--- a/tests/providers/papermill/operators/test_papermill.py
+++ b/tests/providers/papermill/operators/test_papermill.py
@@ -19,14 +19,55 @@ from __future__ import annotations
from unittest.mock import patch
-from airflow.models import DAG, DagRun, TaskInstance
-from airflow.providers.papermill.operators.papermill import PapermillOperator
+import pytest
+
+from airflow.providers.papermill.operators.papermill import NoteBook,
PapermillOperator
from airflow.utils import timezone
DEFAULT_DATE = timezone.datetime(2021, 1, 1)
+TEST_INPUT_URL = "/foo/bar"
+TEST_OUTPUT_URL = "/spam/egg"
+
+
+class TestNoteBook:
+ """Test NoteBook object."""
+
+ def test_templated_fields(self):
+ assert hasattr(NoteBook, "template_fields")
+ assert "parameters" in NoteBook.template_fields
class TestPapermillOperator:
+ """Test PapermillOperator."""
+
+ def test_mandatory_attributes(self):
+ """Test missing Input or Output notebooks."""
+ with pytest.raises(ValueError, match="Input notebook is not
specified"):
+ PapermillOperator(task_id="missing_input_nb", output_nb="foo-bar")
+
+ with pytest.raises(ValueError, match="Output notebook is not
specified"):
+ PapermillOperator(task_id="missing_input_nb", input_nb="foo-bar")
+
+ @pytest.mark.parametrize(
+ "output_nb",
+ [
+ pytest.param(TEST_OUTPUT_URL, id="output-as-string"),
+ pytest.param(NoteBook(TEST_OUTPUT_URL),
id="output-as-notebook-object"),
+ ],
+ )
+ @pytest.mark.parametrize(
+ "input_nb",
+ [
+ pytest.param(TEST_INPUT_URL, id="input-as-string"),
+ pytest.param(NoteBook(TEST_INPUT_URL),
id="input-as-notebook-object"),
+ ],
+ )
+ def test_notebooks_objects(self, input_nb, output_nb):
+ """Test different type of Input/Output notebooks arguments."""
+ op = PapermillOperator(task_id="test_notebooks_objects",
input_nb=input_nb, output_nb=output_nb)
+ assert op.input_nb.url == TEST_INPUT_URL
+ assert op.output_nb.url == TEST_OUTPUT_URL
+
@patch("airflow.providers.papermill.operators.papermill.pm")
def test_execute(self, mock_papermill):
in_nb = "/tmp/does_not_exist"
@@ -57,26 +98,41 @@ class TestPapermillOperator:
report_mode=True,
)
- def test_render_template(self):
- args = {"owner": "airflow", "start_date": DEFAULT_DATE}
- dag = DAG("test_render_template", default_args=args)
-
- operator = PapermillOperator(
- task_id="render_dag_test",
+ def test_render_template(self, create_task_instance_of_operator):
+ """Test rendering fields."""
+ ti = create_task_instance_of_operator(
+ PapermillOperator,
input_nb="/tmp/{{ dag.dag_id }}.ipynb",
output_nb="/tmp/out-{{ dag.dag_id }}.ipynb",
- parameters={"msgs": "dag id is {{ dag.dag_id }}!"},
- kernel_name="python3",
- language_name="python",
- dag=dag,
+ parameters={"msgs": "dag id is {{ dag.dag_id }}!", "test_dt": "{{
ds }}"},
+ kernel_name="{{ params.kernel_name }}",
+ language_name="{{ params.language_name }}",
+ # Additional parameters for render fields
+ params={
+ "kernel_name": "python3",
+ "language_name": "python",
+ },
+ # TI Settings
+ dag_id="test_render_template",
+ task_id="render_dag_test",
+ execution_date=DEFAULT_DATE,
)
+ task = ti.render_templates()
+
+ # Test render Input/Output notebook attributes
+ assert task.input_nb.url == "/tmp/test_render_template.ipynb"
+ assert task.input_nb.parameters == {
+ "msgs": "dag id is test_render_template!",
+ "test_dt": DEFAULT_DATE.date().isoformat(),
+ }
+ assert task.output_nb.url == "/tmp/out-test_render_template.ipynb"
+ assert task.output_nb.parameters == {}
- ti = TaskInstance(operator, run_id="papermill_test")
- ti.dag_run = DagRun(execution_date=DEFAULT_DATE)
- ti.render_templates()
+ # Test render other templated attributes
+ assert task.parameters == task.input_nb.parameters
+ assert "python3" == task.kernel_name
+ assert "python" == task.language_name
- assert "/tmp/test_render_template.ipynb" == operator.input_nb
- assert "/tmp/out-test_render_template.ipynb" == operator.output_nb
- assert {"msgs": "dag id is test_render_template!"} ==
operator.parameters
- assert "python3" == operator.kernel_name
- assert "python" == operator.language_name
+ # Test render Lineage inlets/outlets
+ assert task.inlets[0] == task.input_nb
+ assert task.outlets[0] == task.output_nb