This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 736f2e898a Fix rendering parameters in PapermillOperator (#28979)
736f2e898a is described below

commit 736f2e898a91d3f2cb3e9ca811c9833609f28c5a
Author: Andrey Anshin <[email protected]>
AuthorDate: Mon Jan 23 02:32:20 2023 +0400

    Fix rendering parameters in PapermillOperator (#28979)
---
 airflow/providers/papermill/operators/papermill.py | 70 ++++++++++------
 .../papermill/operators/test_papermill.py          | 96 +++++++++++++++++-----
 2 files changed, 119 insertions(+), 47 deletions(-)

diff --git a/airflow/providers/papermill/operators/papermill.py b/airflow/providers/papermill/operators/papermill.py
index 531304d441..9e5d5c2e1b 100644
--- a/airflow/providers/papermill/operators/papermill.py
+++ b/airflow/providers/papermill/operators/papermill.py
@@ -17,7 +17,7 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, ClassVar, Collection, Optional, Sequence
 
 import attr
 import papermill as pm
@@ -33,8 +33,17 @@ if TYPE_CHECKING:
 class NoteBook(File):
     """Jupyter notebook"""
 
-    type_hint: str | None = "jupyter_notebook"
-    parameters: dict | None = {}
+    # For compatibility with Airflow 2.3:
+    # 1. Use predefined set because `File.template_fields` introduced in Airflow 2.4
+    # 2. Use old styled annotations because `cattrs` doesn't work well with PEP 604.
+
+    template_fields: ClassVar[Collection[str]] = {
+        "parameters",
+        *(File.template_fields if hasattr(File, "template_fields") else {"url"}),
+    }
+
+    type_hint: Optional[str] = "jupyter_notebook"  # noqa: UP007
+    parameters: Optional[dict] = {}  # noqa: UP007
 
     meta_schema: str = __name__ + ".NoteBook"
 
@@ -43,8 +52,8 @@ class PapermillOperator(BaseOperator):
     """
     Executes a jupyter notebook through papermill that is annotated with parameters
 
-    :param input_nb: input notebook (can also be a NoteBook or a File inlet)
-    :param output_nb: output notebook (can also be a NoteBook or File outlet)
+    :param input_nb: input notebook, either path or NoteBook inlet.
+    :param output_nb: output notebook, either path or NoteBook outlet.
     :param parameters: the notebook parameters to set
     :param kernel_name: (optional) name of kernel to execute the notebook against
         (ignores kernel name in the notebook document metadata)
@@ -57,36 +66,43 @@ class PapermillOperator(BaseOperator):
     def __init__(
         self,
         *,
-        input_nb: str | None = None,
-        output_nb: str | None = None,
+        input_nb: str | NoteBook | None = None,
+        output_nb: str | NoteBook | None = None,
         parameters: dict | None = None,
         kernel_name: str | None = None,
         language_name: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-
-        self.input_nb = input_nb
-        self.output_nb = output_nb
         self.parameters = parameters
+
+        if not input_nb:
+            raise ValueError("Input notebook is not specified")
+        elif not isinstance(input_nb, NoteBook):
+            self.input_nb = NoteBook(url=input_nb, parameters=self.parameters)
+        else:
+            self.input_nb = input_nb
+
+        if not output_nb:
+            raise ValueError("Output notebook is not specified")
+        elif not isinstance(output_nb, NoteBook):
+            self.output_nb = NoteBook(url=output_nb)
+        else:
+            self.output_nb = output_nb
+
         self.kernel_name = kernel_name
         self.language_name = language_name
-        if input_nb:
-            self.inlets.append(NoteBook(url=input_nb, parameters=self.parameters))
-        if output_nb:
-            self.outlets.append(NoteBook(url=output_nb))
+
+        self.inlets.append(self.input_nb)
+        self.outlets.append(self.output_nb)
 
     def execute(self, context: Context):
-        if not self.inlets or not self.outlets:
-            raise ValueError("Input notebook or output notebook is not specified")
-
-        for i, item in enumerate(self.inlets):
-            pm.execute_notebook(
-                item.url,
-                self.outlets[i].url,
-                parameters=item.parameters,
-                progress_bar=False,
-                report_mode=True,
-                kernel_name=self.kernel_name,
-                language=self.language_name,
-            )
+        pm.execute_notebook(
+            self.input_nb.url,
+            self.output_nb.url,
+            parameters=self.input_nb.parameters,
+            progress_bar=False,
+            report_mode=True,
+            kernel_name=self.kernel_name,
+            language=self.language_name,
+        )
diff --git a/tests/providers/papermill/operators/test_papermill.py b/tests/providers/papermill/operators/test_papermill.py
index 2ab23280a7..55569e68ba 100644
--- a/tests/providers/papermill/operators/test_papermill.py
+++ b/tests/providers/papermill/operators/test_papermill.py
@@ -19,14 +19,55 @@ from __future__ import annotations
 
 from unittest.mock import patch
 
-from airflow.models import DAG, DagRun, TaskInstance
-from airflow.providers.papermill.operators.papermill import PapermillOperator
+import pytest
+
+from airflow.providers.papermill.operators.papermill import NoteBook, PapermillOperator
 from airflow.utils import timezone
 
 DEFAULT_DATE = timezone.datetime(2021, 1, 1)
+TEST_INPUT_URL = "/foo/bar"
+TEST_OUTPUT_URL = "/spam/egg"
+
+
+class TestNoteBook:
+    """Test NoteBook object."""
+
+    def test_templated_fields(self):
+        assert hasattr(NoteBook, "template_fields")
+        assert "parameters" in NoteBook.template_fields
 
 
 class TestPapermillOperator:
+    """Test PapermillOperator."""
+
+    def test_mandatory_attributes(self):
+        """Test missing Input or Output notebooks."""
+        with pytest.raises(ValueError, match="Input notebook is not specified"):
+            PapermillOperator(task_id="missing_input_nb", output_nb="foo-bar")
+
+        with pytest.raises(ValueError, match="Output notebook is not specified"):
+            PapermillOperator(task_id="missing_input_nb", input_nb="foo-bar")
+
+    @pytest.mark.parametrize(
+        "output_nb",
+        [
+            pytest.param(TEST_OUTPUT_URL, id="output-as-string"),
+            pytest.param(NoteBook(TEST_OUTPUT_URL), id="output-as-notebook-object"),
+        ],
+    )
+    @pytest.mark.parametrize(
+        "input_nb",
+        [
+            pytest.param(TEST_INPUT_URL, id="input-as-string"),
+            pytest.param(NoteBook(TEST_INPUT_URL), id="input-as-notebook-object"),
+        ],
+    )
+    def test_notebooks_objects(self, input_nb, output_nb):
+        """Test different type of Input/Output notebooks arguments."""
+        op = PapermillOperator(task_id="test_notebooks_objects", input_nb=input_nb, output_nb=output_nb)
+        assert op.input_nb.url == TEST_INPUT_URL
+        assert op.output_nb.url == TEST_OUTPUT_URL
+
     @patch("airflow.providers.papermill.operators.papermill.pm")
     def test_execute(self, mock_papermill):
         in_nb = "/tmp/does_not_exist"
@@ -57,26 +98,41 @@ class TestPapermillOperator:
             report_mode=True,
         )
 
-    def test_render_template(self):
-        args = {"owner": "airflow", "start_date": DEFAULT_DATE}
-        dag = DAG("test_render_template", default_args=args)
-
-        operator = PapermillOperator(
-            task_id="render_dag_test",
+    def test_render_template(self, create_task_instance_of_operator):
+        """Test rendering fields."""
+        ti = create_task_instance_of_operator(
+            PapermillOperator,
             input_nb="/tmp/{{ dag.dag_id }}.ipynb",
             output_nb="/tmp/out-{{ dag.dag_id }}.ipynb",
-            parameters={"msgs": "dag id is {{ dag.dag_id }}!"},
-            kernel_name="python3",
-            language_name="python",
-            dag=dag,
+            parameters={"msgs": "dag id is {{ dag.dag_id }}!", "test_dt": "{{ ds }}"},
+            kernel_name="{{ params.kernel_name }}",
+            language_name="{{ params.language_name }}",
+            # Additional parameters for render fields
+            params={
+                "kernel_name": "python3",
+                "language_name": "python",
+            },
+            # TI Settings
+            dag_id="test_render_template",
+            task_id="render_dag_test",
+            execution_date=DEFAULT_DATE,
         )
+        task = ti.render_templates()
+
+        # Test render Input/Output notebook attributes
+        assert task.input_nb.url == "/tmp/test_render_template.ipynb"
+        assert task.input_nb.parameters == {
+            "msgs": "dag id is test_render_template!",
+            "test_dt": DEFAULT_DATE.date().isoformat(),
+        }
+        assert task.output_nb.url == "/tmp/out-test_render_template.ipynb"
+        assert task.output_nb.parameters == {}
 
-        ti = TaskInstance(operator, run_id="papermill_test")
-        ti.dag_run = DagRun(execution_date=DEFAULT_DATE)
-        ti.render_templates()
+        # Test render other templated attributes
+        assert task.parameters == task.input_nb.parameters
+        assert "python3" == task.kernel_name
+        assert "python" == task.language_name
 
-        assert "/tmp/test_render_template.ipynb" == operator.input_nb
-        assert "/tmp/out-test_render_template.ipynb" == operator.output_nb
-        assert {"msgs": "dag id is test_render_template!"} == operator.parameters
-        assert "python3" == operator.kernel_name
-        assert "python" == operator.language_name
+        # Test render Lineage inlets/outlets
+        assert task.inlets[0] == task.input_nb
+        assert task.outlets[0] == task.output_nb

Reply via email to