Repository: incubator-airflow
Updated Branches:
  refs/heads/master b9f4a7437 -> cc9295fe3
[AIRFLOW-1953] Add labels to dataflow operators

Closes #2913 from fenglu-g/master

Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/cc9295fe
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/cc9295fe
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/cc9295fe

Branch: refs/heads/master
Commit: cc9295fe37ed6fb1ddfa077ee065ca6e0849a617
Parents: b9f4a74
Author: fenglu-g <[email protected]>
Authored: Wed Jan 3 11:16:39 2018 -0800
Committer: Chris Riccomini <[email protected]>
Committed: Wed Jan 3 11:16:39 2018 -0800

----------------------------------------------------------------------
 UPDATING.md                                     |  5 ++
 airflow/contrib/hooks/gcp_dataflow_hook.py      | 36 ++++++++----
 airflow/contrib/operators/dataflow_operator.py  | 13 +++--
 setup.py                                        |  2 +-
 tests/contrib/hooks/test_gcp_dataflow_hook.py   | 60 +++++++++++++++++---
 .../contrib/operators/test_dataflow_operator.py | 14 ++++-
 .../operators/test_mlengine_operator_utils.py   |  3 +
 7 files changed, 107 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cc9295fe/UPDATING.md
----------------------------------------------------------------------
diff --git a/UPDATING.md b/UPDATING.md
index 9c39634..7a801e5 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -14,6 +14,11 @@ celery_result_backend -> result_backend
 ```
 This will result in the same config parameters as Celery 4 and will make it more transparent.
 
+### GCP Dataflow Operators
+Dataflow job labeling is now supported in Dataflow{Java,Python}Operator with a default
+"airflow-version" label; please upgrade your google-cloud-dataflow or apache-beam version
+to 2.2.0 or greater.
+
 ## Airflow 1.9
 
 ### SSH Hook updates, along with new SSH Operator & SFTP Operator
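[Editor's note] Before the implementation diffs below, a minimal sketch of what the new option looks like from a DAG author's point of view; the DAG id, project, bucket paths, and label values are all illustrative, not part of this commit:

```
import datetime

from airflow import DAG
from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator

# Illustrative DAG; only the 'labels' entry exercises the new behavior.
dag = DAG(dag_id='example_dataflow_labels',
          start_date=datetime.datetime(2018, 1, 1),
          schedule_interval=None)

task = DataFlowPythonOperator(
    task_id='wordcount',
    py_file='gs://my-bucket/wordcount.py',
    dataflow_default_options={
        'project': 'my-gcp-project',
        'staging_location': 'gs://my-bucket/staging',
    },
    options={
        'output': 'gs://my-bucket/output',
        # User-supplied labels; the operator merges a default
        # 'airflow-version' label on top of these.
        'labels': {'team': 'data-eng'},
    },
    dag=dag)
```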
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cc9295fe/airflow/contrib/hooks/gcp_dataflow_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_dataflow_hook.py b/airflow/contrib/hooks/gcp_dataflow_hook.py
index 1928c3b..7cb7c79 100644
--- a/airflow/contrib/hooks/gcp_dataflow_hook.py
+++ b/airflow/contrib/hooks/gcp_dataflow_hook.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import select
 import subprocess
 import time
@@ -147,27 +148,40 @@ class DataFlowHook(GoogleCloudBaseHook):
         http_authorized = self._authorize()
         return build('dataflow', 'v1b3', http=http_authorized)
 
-    def _start_dataflow(self, task_id, variables, dataflow, name, command_prefix):
-        cmd = command_prefix + self._build_cmd(task_id, variables, dataflow)
+    def _start_dataflow(self, task_id, variables, dataflow,
+                        name, command_prefix, label_formatter):
+        cmd = command_prefix + self._build_cmd(task_id, variables,
+                                               dataflow, label_formatter)
         _Dataflow(cmd).wait_for_done()
-        _DataflowJob(
-            self.get_conn(), variables['project'], name, self.poll_sleep).wait_for_done()
+        _DataflowJob(self.get_conn(), variables['project'],
+                     name, self.poll_sleep).wait_for_done()
 
     def start_java_dataflow(self, task_id, variables, dataflow):
         name = task_id + "-" + str(uuid.uuid1())[:8]
         variables['jobName'] = name
-        self._start_dataflow(
-            task_id, variables, dataflow, name, ["java", "-jar"])
+
+        def label_formatter(labels_dict):
+            return ['--labels={}'.format(
+                json.dumps(labels_dict).replace(' ', ''))]
+        self._start_dataflow(task_id, variables, dataflow, name,
+                             ["java", "-jar"], label_formatter)
 
     def start_python_dataflow(self, task_id, variables, dataflow, py_options):
         name = task_id + "-" + str(uuid.uuid1())[:8]
         variables["job_name"] = name
-        self._start_dataflow(
-            task_id, variables, dataflow, name, ["python"] + py_options)
 
-    def _build_cmd(self, task_id, variables, dataflow):
+        def label_formatter(labels_dict):
+            return ['--labels={}={}'.format(key, value)
+                    for key, value in labels_dict.items()]
+        self._start_dataflow(task_id, variables, dataflow, name,
+                             ["python"] + py_options, label_formatter)
+
+    def _build_cmd(self, task_id, variables, dataflow, label_formatter):
         command = [dataflow, "--runner=DataflowRunner"]
         if variables is not None:
-            for attr, value in variables.iteritems():
-                command.append("--" + attr + "=" + value)
+            for attr, value in variables.items():
+                if attr == 'labels':
+                    command += label_formatter(value)
+                else:
+                    command.append("--" + attr + "=" + value)
         return command
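[Editor's note] The two `label_formatter` closures above are the core of the hook change: the Java runner expects a single `--labels` flag carrying compact JSON, while the Python runner expects one `key=value` flag per label. A standalone sketch of the difference (label values are placeholders; exact flag order follows dict iteration order):

```
import json

labels = {'foo': 'bar', 'airflow-version': 'v1-9-0'}

# Java runner: a single --labels flag with whitespace-free JSON.
java_flags = ['--labels={}'.format(json.dumps(labels).replace(' ', ''))]
# e.g. ['--labels={"foo":"bar","airflow-version":"v1-9-0"}']

# Python runner: one --labels flag per key=value pair.
python_flags = ['--labels={}={}'.format(k, v) for k, v in labels.items()]
# e.g. ['--labels=foo=bar', '--labels=airflow-version=v1-9-0']
```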
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cc9295fe/airflow/contrib/operators/dataflow_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/dataflow_operator.py b/airflow/contrib/operators/dataflow_operator.py
index 6fd23f1..01fbd35 100644
--- a/airflow/contrib/operators/dataflow_operator.py
+++ b/airflow/contrib/operators/dataflow_operator.py
@@ -19,6 +19,7 @@ import uuid
 from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
 from airflow.contrib.hooks.gcp_dataflow_hook import DataFlowHook
 from airflow.models import BaseOperator
+from airflow.version import version
 from airflow.utils.decorators import apply_defaults
 
 
@@ -52,7 +53,8 @@ class DataFlowJavaOperator(BaseOperator):
               'autoscalingAlgorithm': 'BASIC',
               'maxNumWorkers': '50',
               'start': '{{ds}}',
-              'partitionType': 'DAY'
+              'partitionType': 'DAY',
+              'labels': {'foo' : 'bar'}
           },
           dag=my-dag)
       ```
@@ -97,7 +99,7 @@ class DataFlowJavaOperator(BaseOperator):
         For this to work, the service account making the request must have
         domain-wide delegation enabled.
     :type delegate_to: string
-    :param poll_sleep: The time in seconds to sleep between polling Google 
+    :param poll_sleep: The time in seconds to sleep between polling Google
         Cloud Platform for the dataflow job status while the job is in the
         JOB_STATE_RUNNING state.
     :type poll_sleep: int
@@ -106,7 +108,8 @@ class DataFlowJavaOperator(BaseOperator):
 
         dataflow_default_options = dataflow_default_options or {}
         options = options or {}
-
+        options.setdefault('labels', {}).update(
+            {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')})
         self.gcp_conn_id = gcp_conn_id
         self.delegate_to = delegate_to
         self.jar = jar
@@ -171,7 +174,7 @@ class DataFlowPythonOperator(BaseOperator):
         For this to work, the service account making the request must have
         domain-wide delegation enabled.
     :type delegate_to: string
-    :param poll_sleep: The time in seconds to sleep between polling Google 
+    :param poll_sleep: The time in seconds to sleep between polling Google
        Cloud Platform for the dataflow job status while the job is in the
        JOB_STATE_RUNNING state.
     :type poll_sleep: int
@@ -182,6 +185,8 @@ class DataFlowPythonOperator(BaseOperator):
         self.py_options = py_options or []
         self.dataflow_default_options = dataflow_default_options or {}
         self.options = options or {}
+        self.options.setdefault('labels', {}).update(
+            {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')})
         self.gcp_conn_id = gcp_conn_id
         self.delegate_to = delegate_to
         self.poll_sleep = poll_sleep


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cc9295fe/setup.py
----------------------------------------------------------------------
diff --git a/setup.py b/setup.py
index 84da6f1..a63ce79 100644
--- a/setup.py
+++ b/setup.py
@@ -123,7 +123,7 @@ gcp_api = [
     'google-api-python-client>=1.5.0, <1.6.0',
     'oauth2client>=2.0.2, <2.1.0',
     'PyOpenSSL',
-    'google-cloud-dataflow',
+    'google-cloud-dataflow>=2.2.0',
     'pandas-gbq'
]
 hdfs = ['snakebite>=2.7.8']
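[Editor's note] Both operators derive the default label value with the same expression (the two `setdefault(...).update(...)` lines in the operator diff above). A quick illustration of the transformation; the version strings are examples, and the rewrite to hyphens is needed because GCP label values may not contain '.' or '+':

```
# Sketch of the default-label derivation used by both operators.
for version in ('1.9.0', '1.10.0.dev0+incubating'):
    label = 'v' + version.replace('.', '-').replace('+', '-')
    print('{} -> {}'.format(version, label))
# 1.9.0 -> v1-9-0
# 1.10.0.dev0+incubating -> v1-10-0-dev0-incubating
```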
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cc9295fe/tests/contrib/hooks/test_gcp_dataflow_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_dataflow_hook.py b/tests/contrib/hooks/test_gcp_dataflow_hook.py
index 1ab5a99..a37b153 100644
--- a/tests/contrib/hooks/test_gcp_dataflow_hook.py
+++ b/tests/contrib/hooks/test_gcp_dataflow_hook.py
@@ -31,13 +31,21 @@ except ImportError:
 
 TASK_ID = 'test-python-dataflow'
 PY_FILE = 'apache_beam.examples.wordcount'
+JAR_FILE = 'unitest.jar'
 PY_OPTIONS = ['-m']
-OPTIONS = {
+DATAFLOW_OPTIONS_PY = {
     'project': 'test',
-    'staging_location': 'gs://test/staging'
+    'staging_location': 'gs://test/staging',
+    'labels': {'foo': 'bar'}
+}
+DATAFLOW_OPTIONS_JAVA = {
+    'project': 'test',
+    'stagingLocation': 'gs://test/staging',
+    'labels': {'foo': 'bar'}
 }
 BASE_STRING = 'airflow.contrib.hooks.gcp_api_base_hook.{}'
 DATAFLOW_STRING = 'airflow.contrib.hooks.gcp_dataflow_hook.{}'
+MOCK_UUID = '12345678'
 
 
 def mock_init(self, gcp_conn_id, delegate_to=None):
@@ -51,13 +59,51 @@ class DataFlowHookTest(unittest.TestCase):
                         new=mock_init):
             self.dataflow_hook = DataFlowHook(gcp_conn_id='test')
 
-    @mock.patch(DATAFLOW_STRING.format('DataFlowHook._start_dataflow'))
-    def test_start_python_dataflow(self, internal_dataflow_mock):
+    @mock.patch(DATAFLOW_STRING.format('uuid.uuid1'))
+    @mock.patch(DATAFLOW_STRING.format('_DataflowJob'))
+    @mock.patch(DATAFLOW_STRING.format('_Dataflow'))
+    @mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn'))
+    def test_start_python_dataflow(self, mock_conn,
+                                   mock_dataflow, mock_dataflowjob, mock_uuid):
+        mock_uuid.return_value = MOCK_UUID
+        mock_conn.return_value = None
+        dataflow_instance = mock_dataflow.return_value
+        dataflow_instance.wait_for_done.return_value = None
+        dataflowjob_instance = mock_dataflowjob.return_value
+        dataflowjob_instance.wait_for_done.return_value = None
         self.dataflow_hook.start_python_dataflow(
-            task_id=TASK_ID, variables=OPTIONS,
+            task_id=TASK_ID, variables=DATAFLOW_OPTIONS_PY,
             dataflow=PY_FILE, py_options=PY_OPTIONS)
-        internal_dataflow_mock.assert_called_once_with(
-            TASK_ID, OPTIONS, PY_FILE, mock.ANY, ['python'] + PY_OPTIONS)
+        EXPECTED_CMD = ['python', '-m', PY_FILE,
+                        '--runner=DataflowRunner', '--project=test',
+                        '--labels=foo=bar',
+                        '--staging_location=gs://test/staging',
+                        '--job_name={}-{}'.format(TASK_ID, MOCK_UUID)]
+        self.assertListEqual(sorted(mock_dataflow.call_args[0][0]),
+                             sorted(EXPECTED_CMD))
+
+    @mock.patch(DATAFLOW_STRING.format('uuid.uuid1'))
+    @mock.patch(DATAFLOW_STRING.format('_DataflowJob'))
+    @mock.patch(DATAFLOW_STRING.format('_Dataflow'))
+    @mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn'))
+    def test_start_java_dataflow(self, mock_conn,
+                                 mock_dataflow, mock_dataflowjob, mock_uuid):
+        mock_uuid.return_value = MOCK_UUID
+        mock_conn.return_value = None
+        dataflow_instance = mock_dataflow.return_value
+        dataflow_instance.wait_for_done.return_value = None
+        dataflowjob_instance = mock_dataflowjob.return_value
+        dataflowjob_instance.wait_for_done.return_value = None
+        self.dataflow_hook.start_java_dataflow(
+            task_id=TASK_ID, variables=DATAFLOW_OPTIONS_JAVA,
+            dataflow=JAR_FILE)
+        EXPECTED_CMD = ['java', '-jar', JAR_FILE,
+                        '--runner=DataflowRunner', '--project=test',
+                        '--stagingLocation=gs://test/staging',
+                        '--labels={"foo":"bar"}',
+                        '--jobName={}-{}'.format(TASK_ID, MOCK_UUID)]
+        self.assertListEqual(sorted(mock_dataflow.call_args[0][0]),
+                             sorted(EXPECTED_CMD))
 
     @mock.patch('airflow.contrib.hooks.gcp_dataflow_hook._Dataflow.log')
     @mock.patch('subprocess.Popen')
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cc9295fe/tests/contrib/operators/test_dataflow_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_dataflow_operator.py b/tests/contrib/operators/test_dataflow_operator.py
index 77fc1f6..5b07051 100644
--- a/tests/contrib/operators/test_dataflow_operator.py
+++ b/tests/contrib/operators/test_dataflow_operator.py
@@ -16,6 +16,7 @@
 import unittest
 
 from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
+from airflow.version import version
 
 try:
     from unittest import mock
@@ -34,7 +35,13 @@ DEFAULT_OPTIONS = {
     'stagingLocation': 'gs://test/staging'
 }
 ADDITIONAL_OPTIONS = {
-    'output': 'gs://test/output'
+    'output': 'gs://test/output',
+    'labels': {'foo': 'bar'}
+}
+TEST_VERSION = 'v{}'.format(version.replace('.', '-').replace('+', '-'))
+EXPECTED_ADDITIONAL_OPTIONS = {
+    'output': 'gs://test/output',
+    'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION}
 }
 POLL_SLEEP = 30
 GCS_HOOK_STRING = 'airflow.contrib.operators.dataflow_operator.{}'
@@ -60,7 +67,7 @@ class DataFlowPythonOperatorTest(unittest.TestCase):
         self.assertEqual(self.dataflow.dataflow_default_options,
                          DEFAULT_OPTIONS)
         self.assertEqual(self.dataflow.options,
-                         ADDITIONAL_OPTIONS)
+                         EXPECTED_ADDITIONAL_OPTIONS)
 
     @mock.patch('airflow.contrib.operators.dataflow_operator.DataFlowHook')
     @mock.patch(GCS_HOOK_STRING.format('GoogleCloudBucketHelper'))
@@ -76,7 +83,8 @@ class DataFlowPythonOperatorTest(unittest.TestCase):
         expected_options = {
             'project': 'test',
             'staging_location': 'gs://test/staging',
-            'output': 'gs://test/output'
+            'output': 'gs://test/output',
+            'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION}
         }
         gcs_download_hook.assert_called_once_with(PY_FILE)
         start_python_hook.assert_called_once_with(TASK_ID, expected_options,


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cc9295fe/tests/contrib/operators/test_mlengine_operator_utils.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_mlengine_operator_utils.py b/tests/contrib/operators/test_mlengine_operator_utils.py
index 80ab01a..c8f6fb5 100644
--- a/tests/contrib/operators/test_mlengine_operator_utils.py
+++ b/tests/contrib/operators/test_mlengine_operator_utils.py
@@ -26,11 +26,13 @@ from airflow import configuration, DAG
 from airflow.contrib.operators import mlengine_operator_utils
 from airflow.contrib.operators.mlengine_operator_utils import create_evaluate_ops
 from airflow.exceptions import AirflowException
+from airflow.version import version
 
 from mock import ANY
 from mock import patch
 
 DEFAULT_DATE = datetime.datetime(2017, 6, 6)
+TEST_VERSION = 'v{}'.format(version.replace('.', '-').replace('+', '-'))
 
 
 class CreateEvaluateOpsTest(unittest.TestCase):
@@ -115,6 +117,7 @@ class CreateEvaluateOpsTest(unittest.TestCase):
             'eval-test-summary',
             {
                 'prediction_path': 'gs://legal-bucket/fake-output-path',
+                'labels': {'airflow-version': TEST_VERSION},
                 'metric_keys': 'err',
                 'metric_fn_encoded': self.metric_fn_encoded,
             },
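[Editor's note] Taken together, the operator tests above pin down the merge semantics: user-supplied labels survive and the default is layered on top. A compact sketch of the same merge outside the operator (option values are illustrative):

```
from airflow.version import version

options = {'output': 'gs://test/output', 'labels': {'foo': 'bar'}}

# Mirrors the constructor logic asserted above: setdefault keeps the
# user's labels dict, update layers the default label on top.
options.setdefault('labels', {}).update(
    {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')})

print(options['labels'])
# {'foo': 'bar', 'airflow-version': 'v1-9-0'}  (value depends on version)
```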
