Repository: incubator-airflow
Updated Branches:
  refs/heads/master b3489b99e -> 3c5b73579
[AIRFLOW-1954] Add DataFlowTemplateOperator

Closes #2909 from dsdinter/dataflow_template

Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/3c5b7357
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/3c5b7357
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/3c5b7357

Branch: refs/heads/master
Commit: 3c5b73579a3e6c8a1e47c2fddf201b99d690bafe
Parents: b3489b9
Author: David Sabater <[email protected]>
Authored: Thu Jan 4 13:44:07 2018 -0800
Committer: Chris Riccomini <[email protected]>
Committed: Thu Jan 4 13:44:10 2018 -0800

----------------------------------------------------------------------
 airflow/contrib/hooks/gcp_dataflow_hook.py      | 28 ++++++
 airflow/contrib/operators/dataflow_operator.py  | 98 +++++++++++++++++++-
 tests/contrib/hooks/test_gcp_dataflow_hook.py   | 29 +++++-
 .../contrib/operators/test_dataflow_operator.py | 61 +++++++++++-
 4 files changed, 208 insertions(+), 8 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3c5b7357/airflow/contrib/hooks/gcp_dataflow_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_dataflow_hook.py b/airflow/contrib/hooks/gcp_dataflow_hook.py
index 7cb7c79..f9970d9 100644
--- a/airflow/contrib/hooks/gcp_dataflow_hook.py
+++ b/airflow/contrib/hooks/gcp_dataflow_hook.py
@@ -166,6 +166,11 @@ class DataFlowHook(GoogleCloudBaseHook):
         self._start_dataflow(task_id, variables, dataflow, name,
                              ["java", "-jar"], label_formatter)
 
+    def start_template_dataflow(self, task_id, variables, parameters, dataflow_template):
+        name = task_id + "-" + str(uuid.uuid1())[:8]
+        self._start_template_dataflow(
+            name, variables, parameters, dataflow_template)
+
     def start_python_dataflow(self, task_id, variables, dataflow, py_options):
         name = task_id + "-" + str(uuid.uuid1())[:8]
         variables["job_name"] = name
@@ -185,3 +190,26 @@ class DataFlowHook(GoogleCloudBaseHook):
             else:
                 command.append("--" + attr + "=" + value)
         return command
+
+    def _start_template_dataflow(self, name, variables, parameters, dataflow_template):
+        # Builds RuntimeEnvironment from variables dictionary
+        # https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
+        environment = {}
+        for key in ['maxWorkers', 'zone', 'serviceAccountEmail', 'tempLocation',
+                    'bypassTempDirValidation', 'machineType']:
+            if key in variables:
+                environment.update({key: variables[key]})
+        body = {"jobName": name,
+                "parameters": parameters,
+                "environment": environment}
+        service = self.get_conn()
+        if variables['project'] is None:
+            raise Exception(
+                'Project not specified')
+        request = service.projects().templates().launch(projectId=variables['project'],
+                                                        gcsPath=dataflow_template,
+                                                        body=body)
+        response = request.execute()
+        _DataflowJob(
+            self.get_conn(), variables['project'], name, self.poll_sleep).wait_for_done()
+        return response

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3c5b7357/airflow/contrib/operators/dataflow_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/dataflow_operator.py b/airflow/contrib/operators/dataflow_operator.py
index 01fbd35..915e26c 100644
--- a/airflow/contrib/operators/dataflow_operator.py
+++ b/airflow/contrib/operators/dataflow_operator.py
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
 import re
 import uuid
+import copy
 
 from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
 from airflow.contrib.hooks.gcp_dataflow_hook import DataFlowHook
@@ -131,6 +131,102 @@ class DataFlowJavaOperator(BaseOperator):
         hook.start_java_dataflow(self.task_id, dataflow_options,
                                  self.jar)
 
+
+class DataflowTemplateOperator(BaseOperator):
+    """
+    Start a Templated Cloud DataFlow batch job. The parameters of the operation
+    will be passed to the job.
+    It's a good practice to define dataflow_* parameters in the default_args of the dag
+    like the project, zone and staging location.
+    https://cloud.google.com/dataflow/docs/reference/rest/v1b3/LaunchTemplateParameters
+    https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
+    ```
+    default_args = {
+        'dataflow_default_options': {
+            'project': 'my-gcp-project'
+            'zone': 'europe-west1-d',
+            'tempLocation': 'gs://my-staging-bucket/staging/'
+            }
+        }
+    }
+    ```
+    You need to pass the path to your dataflow template as a file reference with the
+    ``template`` parameter. Use ``parameters`` to pass on parameters to your job.
+    Use ``environment`` to pass on runtime environment variables to your job.
+    ```
+    t1 = DataflowTemplateOperator(
+        task_id='datapflow_example',
+        template='{{var.value.gcp_dataflow_base}}',
+        parameters={
+            'inputFile': "gs://bucket/input/my_input.txt",
+            'outputFile': "gs://bucket/output/my_output.txt"
+        },
+        dag=my-dag)
+    ```
+    ``template`` ``dataflow_default_options`` and ``parameters`` are templated so you can
+    use variables in them.
+    """
+    template_fields = ['parameters', 'dataflow_default_options', 'template']
+    ui_color = '#0273d4'
+
+    @apply_defaults
+    def __init__(
+            self,
+            template,
+            dataflow_default_options=None,
+            parameters=None,
+            gcp_conn_id='google_cloud_default',
+            delegate_to=None,
+            poll_sleep=10,
+            *args,
+            **kwargs):
+        """
+        Create a new DataflowTemplateOperator. Note that
+        dataflow_default_options is expected to save high-level options
+        for project information, which apply to all dataflow operators in the DAG.
+        https://cloud.google.com/dataflow/docs/reference/rest/v1b3
+        /LaunchTemplateParameters
+        https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
+        For more detail on job template execution have a look at the reference:
+        https://cloud.google.com/dataflow/docs/templates/executing-templates
+        :param template: The reference to the DataFlow template.
+        :type template: string
+        :param dataflow_default_options: Map of default job environment options.
+        :type dataflow_default_options: dict
+        :param parameters: Map of job specific parameters for the template.
+        :type parameters: dict
+        :param gcp_conn_id: The connection ID to use connecting to Google Cloud
+            Platform.
+        :type gcp_conn_id: string
+        :param delegate_to: The account to impersonate, if any.
+            For this to work, the service account making the request must have
+            domain-wide delegation enabled.
+        :type delegate_to: string
+        :param poll_sleep: The time in seconds to sleep between polling Google
+            Cloud Platform for the dataflow job status while the job is in the
+            JOB_STATE_RUNNING state.
+        :type poll_sleep: int
+        """
+        super(DataflowTemplateOperator, self).__init__(*args, **kwargs)
+
+        dataflow_default_options = dataflow_default_options or {}
+        parameters = parameters or {}
+
+        self.gcp_conn_id = gcp_conn_id
+        self.delegate_to = delegate_to
+        self.dataflow_default_options = dataflow_default_options
+        self.poll_sleep = poll_sleep
+        self.template = template
+        self.parameters = parameters
+
+    def execute(self, context):
+        hook = DataFlowHook(gcp_conn_id=self.gcp_conn_id,
+                            delegate_to=self.delegate_to,
+                            poll_sleep=self.poll_sleep)
+
+        hook.start_template_dataflow(self.task_id, self.dataflow_default_options,
+                                     self.parameters, self.template)
+
+
 class DataFlowPythonOperator(BaseOperator):
 
     template_fields = ['options', 'dataflow_default_options']

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3c5b7357/tests/contrib/hooks/test_gcp_dataflow_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_dataflow_hook.py b/tests/contrib/hooks/test_gcp_dataflow_hook.py
index a37b153..bf513c8 100644
--- a/tests/contrib/hooks/test_gcp_dataflow_hook.py
+++ b/tests/contrib/hooks/test_gcp_dataflow_hook.py
@@ -29,7 +29,12 @@ except ImportError:
     mock = None
 
 
-TASK_ID = 'test-python-dataflow'
+TASK_ID = 'test-dataflow-operator'
+TEMPLATE = 'gs://dataflow-templates/wordcount/template_file'
+PARAMETERS = {
+    'inputFile': 'gs://dataflow-samples/shakespeare/kinglear.txt',
+    'output': 'gs://test/output/my_output'
+}
 PY_FILE = 'apache_beam.examples.wordcount'
 JAR_FILE = 'unitest.jar'
 PY_OPTIONS = ['-m']
@@ -43,6 +48,11 @@ DATAFLOW_OPTIONS_JAVA = {
     'stagingLocation': 'gs://test/staging',
     'labels': {'foo': 'bar'}
 }
+DATAFLOW_OPTIONS_TEMPLATE = {
+    'project': 'test',
+    'tempLocation': 'gs://test/temp',
+    'zone': 'us-central1-f'
+}
 BASE_STRING = 'airflow.contrib.hooks.gcp_api_base_hook.{}'
 DATAFLOW_STRING = 'airflow.contrib.hooks.gcp_dataflow_hook.{}'
 MOCK_UUID = '12345678'
@@ -52,7 +62,7 @@ def mock_init(self, gcp_conn_id, delegate_to=None):
     pass
 
 
-class DataFlowHookTest(unittest.TestCase):
+class DataFlowPythonHookTest(unittest.TestCase):
 
     def setUp(self):
         with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'),
@@ -129,3 +139,18 @@ class DataFlowHookTest(unittest.TestCase):
 
         self.assertRaises(Exception, dataflow.wait_for_done)
         mock_logging.warning.assert_has_calls([call('test'), call('error')])
+
+
+class DataFlowTemplateHookTest(unittest.TestCase):
+
+    def setUp(self):
+        with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'),
+                        new=mock_init):
+            self.dataflow_hook = DataFlowHook(gcp_conn_id='test')
+
+    @mock.patch(DATAFLOW_STRING.format('DataFlowHook._start_template_dataflow'))
+    def test_start_template_dataflow(self, internal_dataflow_mock):
+        self.dataflow_hook.start_template_dataflow(
+            task_id=TASK_ID, variables=DATAFLOW_OPTIONS_TEMPLATE, parameters=PARAMETERS,
+            dataflow_template=TEMPLATE)
+        internal_dataflow_mock.assert_called_once_with(
+            mock.ANY, DATAFLOW_OPTIONS_TEMPLATE, PARAMETERS, TEMPLATE)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3c5b7357/tests/contrib/operators/test_dataflow_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_dataflow_operator.py b/tests/contrib/operators/test_dataflow_operator.py
index 5b07051..da95d18 100644
--- a/tests/contrib/operators/test_dataflow_operator.py
+++ b/tests/contrib/operators/test_dataflow_operator.py
@@ -15,6 +15,8 @@
 
 import unittest
 
+from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator, \
+    DataflowTemplateOperator
 from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
 from airflow.version import version
 
@@ -27,12 +29,23 @@ except ImportError:
     mock = None
 
 
-TASK_ID = 'test-python-dataflow'
+TASK_ID = 'test-dataflow-operator'
+TEMPLATE = 'gs://dataflow-templates/wordcount/template_file'
+PARAMETERS = {
+    'inputFile': 'gs://dataflow-samples/shakespeare/kinglear.txt',
+    'output': 'gs://test/output/my_output'
+}
 PY_FILE = 'gs://my-bucket/my-object.py'
 PY_OPTIONS = ['-m']
-DEFAULT_OPTIONS = {
+DEFAULT_OPTIONS_PYTHON = {
+    'project': 'test',
+    'stagingLocation': 'gs://test/staging',
+}
+DEFAULT_OPTIONS_TEMPLATE = {
     'project': 'test',
-    'stagingLocation': 'gs://test/staging'
+    'stagingLocation': 'gs://test/staging',
+    'tempLocation': 'gs://test/temp',
+    'zone': 'us-central1-f'
 }
 ADDITIONAL_OPTIONS = {
     'output': 'gs://test/output',
@@ -54,7 +67,7 @@ class DataFlowPythonOperatorTest(unittest.TestCase):
             task_id=TASK_ID,
             py_file=PY_FILE,
             py_options=PY_OPTIONS,
-            dataflow_default_options=DEFAULT_OPTIONS,
+            dataflow_default_options=DEFAULT_OPTIONS_PYTHON,
             options=ADDITIONAL_OPTIONS,
             poll_sleep=POLL_SLEEP)
 
@@ -65,7 +78,7 @@ class DataFlowPythonOperatorTest(unittest.TestCase):
         self.assertEqual(self.dataflow.py_options, PY_OPTIONS)
         self.assertEqual(self.dataflow.poll_sleep, POLL_SLEEP)
         self.assertEqual(self.dataflow.dataflow_default_options,
-                         DEFAULT_OPTIONS)
+                         DEFAULT_OPTIONS_PYTHON)
         self.assertEqual(self.dataflow.options,
                          EXPECTED_ADDITIONAL_OPTIONS)
 
@@ -90,3 +103,41 @@ class DataFlowPythonOperatorTest(unittest.TestCase):
         start_python_hook.assert_called_once_with(TASK_ID, expected_options,
                                                   mock.ANY, PY_OPTIONS)
         self.assertTrue(self.dataflow.py_file.startswith('/tmp/dataflow'))
+
+
+class DataFlowTemplateOperatorTest(unittest.TestCase):
+
+    def setUp(self):
+        self.dataflow = DataflowTemplateOperator(
+            task_id=TASK_ID,
+            template=TEMPLATE,
+            parameters=PARAMETERS,
+            dataflow_default_options=DEFAULT_OPTIONS_TEMPLATE,
+            poll_sleep=POLL_SLEEP)
+
+    def test_init(self):
+        """Test DataflowTemplateOperator instance is properly initialized."""
+        self.assertEqual(self.dataflow.task_id, TASK_ID)
+        self.assertEqual(self.dataflow.template, TEMPLATE)
+        self.assertEqual(self.dataflow.parameters, PARAMETERS)
+        self.assertEqual(self.dataflow.poll_sleep, POLL_SLEEP)
+        self.assertEqual(self.dataflow.dataflow_default_options,
+                         DEFAULT_OPTIONS_TEMPLATE)
+
+    @mock.patch('airflow.contrib.operators.dataflow_operator.DataFlowHook')
+    def test_exec(self, dataflow_mock):
+        """Test DataFlowHook is created and the right args are passed to
+        start_template_workflow.
+
+        """
+        start_template_hook = dataflow_mock.return_value.start_template_dataflow
+        self.dataflow.execute(None)
+        self.assertTrue(dataflow_mock.called)
+        expected_options = {
+            'project': 'test',
+            'stagingLocation': 'gs://test/staging',
+            'tempLocation': 'gs://test/temp',
+            'zone': 'us-central1-f'
+        }
+        start_template_hook.assert_called_once_with(TASK_ID, expected_options,
+                                                    PARAMETERS, TEMPLATE)
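
For readers trying out the new operator, a minimal DAG wired against this change might look like the sketch below. It is only an illustration distilled from the operator docstring above, not part of the commit: the project, zone, bucket paths and template location are placeholders, and the DAG-level arguments (start_date, schedule_interval) are ordinary Airflow settings rather than anything this patch introduces.

```
# Hypothetical example DAG (not part of this commit); names and paths are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.contrib.operators.dataflow_operator import DataflowTemplateOperator

default_args = {
    'start_date': datetime(2018, 1, 1),
    'dataflow_default_options': {
        'project': 'my-gcp-project',                        # required by the hook
        'zone': 'europe-west1-d',                           # forwarded to RuntimeEnvironment
        'tempLocation': 'gs://my-staging-bucket/staging/',  # forwarded to RuntimeEnvironment
    },
}

dag = DAG('dataflow_template_example',
          default_args=default_args,
          schedule_interval=None)

launch_template = DataflowTemplateOperator(
    task_id='launch_wordcount_template',
    # GCS path of the template; the field is templated, so Variables and macros work here
    template='gs://my-bucket/templates/wordcount_template',
    parameters={
        'inputFile': 'gs://my-bucket/input/my_input.txt',
        'outputFile': 'gs://my-bucket/output/my_output.txt',
    },
    dag=dag)
```

Note that, per _start_template_dataflow above, only maxWorkers, zone, serviceAccountEmail, tempLocation, bypassTempDirValidation and machineType from dataflow_default_options are copied into the job's RuntimeEnvironment; 'project' is used to address the projects().templates().launch call itself and must always be present.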
