This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f382a79  Add support for templated fields in PapermillOperator (#18357)
f382a79 is described below

commit f382a79adabb2372a1ca5d9e43ed34afd9dec33d
Author: eladkal <[email protected]>
AuthorDate: Mon Sep 20 08:06:47 2021 +0300

    Add support for templated fields in PapermillOperator (#18357)
---
 airflow/providers/papermill/operators/papermill.py |  7 ++++++-
 .../papermill/operators/test_papermill.py          | 24 ++++++++++++++++++++++
 2 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/airflow/providers/papermill/operators/papermill.py b/airflow/providers/papermill/operators/papermill.py
index 987bc01..ebbda66 100644
--- a/airflow/providers/papermill/operators/papermill.py
+++ b/airflow/providers/papermill/operators/papermill.py
@@ -48,6 +48,8 @@ class PapermillOperator(BaseOperator):
 
     supports_lineage = True
 
+    template_fields = ('input_nb', 'output_nb', 'parameters')
+
     def __init__(
         self,
         *,
@@ -58,8 +60,11 @@ class PapermillOperator(BaseOperator):
     ) -> None:
         super().__init__(**kwargs)
 
+        self.input_nb = input_nb
+        self.output_nb = output_nb
+        self.parameters = parameters
         if input_nb:
-            self.inlets.append(NoteBook(url=input_nb, parameters=parameters))
+            self.inlets.append(NoteBook(url=input_nb, parameters=self.parameters))
         if output_nb:
             self.outlets.append(NoteBook(url=output_nb))
 
diff --git a/tests/providers/papermill/operators/test_papermill.py b/tests/providers/papermill/operators/test_papermill.py
index c2ebf2c..553c1fc 100644
--- a/tests/providers/papermill/operators/test_papermill.py
+++ b/tests/providers/papermill/operators/test_papermill.py
@@ -16,10 +16,14 @@
 # specific language governing permissions and limitations
 # under the License.
 import unittest
+from datetime import datetime
 from unittest.mock import patch
 
+from airflow.models import DAG, DagRun, TaskInstance
 from airflow.providers.papermill.operators.papermill import PapermillOperator
 
+DEFAULT_DATE = datetime(2021, 1, 1)
+
 
 class TestPapermillOperator(unittest.TestCase):
     @patch('airflow.providers.papermill.operators.papermill.pm')
@@ -42,3 +46,23 @@ class TestPapermillOperator(unittest.TestCase):
         mock_papermill.execute_notebook.assert_called_once_with(
             in_nb, out_nb, parameters=parameters, progress_bar=False, report_mode=True
         )
+
+    def test_render_template(self):
+        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
+        dag = DAG('test_render_template', default_args=args)
+
+        operator = PapermillOperator(
+            task_id="render_dag_test",
+            input_nb="/tmp/{{ dag.dag_id }}.ipynb",
+            output_nb="/tmp/out-{{ dag.dag_id }}.ipynb",
+            parameters={"msgs": "dag id is {{ dag.dag_id }}!"},
+            dag=dag,
+        )
+
+        ti = TaskInstance(operator, run_id="papermill_test")
+        ti.dag_run = DagRun(execution_date=DEFAULT_DATE)
+        ti.render_templates()
+
+        assert "/tmp/test_render_template.ipynb" == getattr(operator, 'input_nb')
+        assert '/tmp/out-test_render_template.ipynb' == getattr(operator, 'output_nb')
+        assert {"msgs": "dag id is test_render_template!"} == getattr(operator, 'parameters')

Reply via email to