This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f382a79 Add support for templated fields in PapermillOperator (#18357)
f382a79 is described below
commit f382a79adabb2372a1ca5d9e43ed34afd9dec33d
Author: eladkal <[email protected]>
AuthorDate: Mon Sep 20 08:06:47 2021 +0300
Add support for templated fields in PapermillOperator (#18357)
---
airflow/providers/papermill/operators/papermill.py | 7 ++++++-
.../papermill/operators/test_papermill.py | 24 ++++++++++++++++++++++
2 files changed, 30 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/papermill/operators/papermill.py b/airflow/providers/papermill/operators/papermill.py
index 987bc01..ebbda66 100644
--- a/airflow/providers/papermill/operators/papermill.py
+++ b/airflow/providers/papermill/operators/papermill.py
@@ -48,6 +48,8 @@ class PapermillOperator(BaseOperator):
supports_lineage = True
+ template_fields = ('input_nb', 'output_nb', 'parameters')
+
def __init__(
self,
*,
@@ -58,8 +60,11 @@ class PapermillOperator(BaseOperator):
) -> None:
super().__init__(**kwargs)
+ self.input_nb = input_nb
+ self.output_nb = output_nb
+ self.parameters = parameters
if input_nb:
- self.inlets.append(NoteBook(url=input_nb, parameters=parameters))
+ self.inlets.append(NoteBook(url=input_nb,
parameters=self.parameters))
if output_nb:
self.outlets.append(NoteBook(url=output_nb))
diff --git a/tests/providers/papermill/operators/test_papermill.py b/tests/providers/papermill/operators/test_papermill.py
index c2ebf2c..553c1fc 100644
--- a/tests/providers/papermill/operators/test_papermill.py
+++ b/tests/providers/papermill/operators/test_papermill.py
@@ -16,10 +16,14 @@
# specific language governing permissions and limitations
# under the License.
import unittest
+from datetime import datetime
from unittest.mock import patch
+from airflow.models import DAG, DagRun, TaskInstance
from airflow.providers.papermill.operators.papermill import PapermillOperator
+DEFAULT_DATE = datetime(2021, 1, 1)
+
class TestPapermillOperator(unittest.TestCase):
@patch('airflow.providers.papermill.operators.papermill.pm')
@@ -42,3 +46,23 @@ class TestPapermillOperator(unittest.TestCase):
mock_papermill.execute_notebook.assert_called_once_with(
in_nb, out_nb, parameters=parameters, progress_bar=False,
report_mode=True
)
+
+ def test_render_template(self):
+ args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
+ dag = DAG('test_render_template', default_args=args)
+
+ operator = PapermillOperator(
+ task_id="render_dag_test",
+ input_nb="/tmp/{{ dag.dag_id }}.ipynb",
+ output_nb="/tmp/out-{{ dag.dag_id }}.ipynb",
+ parameters={"msgs": "dag id is {{ dag.dag_id }}!"},
+ dag=dag,
+ )
+
+ ti = TaskInstance(operator, run_id="papermill_test")
+ ti.dag_run = DagRun(execution_date=DEFAULT_DATE)
+ ti.render_templates()
+
+ assert "/tmp/test_render_template.ipynb" == getattr(operator,
'input_nb')
+ assert '/tmp/out-test_render_template.ipynb' == getattr(operator,
'output_nb')
+ assert {"msgs": "dag id is test_render_template!"} ==
getattr(operator, 'parameters')