itscaro closed pull request #4116: WIP [AIRFLOW-3207] option to stop task pushing result to xcom
URL: https://github.com/apache/incubator-airflow/pull/4116
This is a pull request from a forked repository. As GitHub hides the
original diff once the pull request is closed, it is reproduced below for
the sake of provenance.
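For context: the diff threads a new do_xcom_push flag through BaseOperator
and a number of contrib operators so a task can opt out of pushing its return
value to XCom. As a rough usage sketch (not part of this PR; the DAG id,
task id and callable are made up, following the pattern exercised in the
tests/models.py change below):

    from datetime import datetime

    from airflow import DAG
    from airflow.operators.python_operator import PythonOperator


    def produce_large_result():
        # Normally the return value would be pushed to XCom; with
        # do_xcom_push=False (this PR) it is simply discarded.
        return ['row'] * 100000


    dag = DAG(dag_id='example_no_xcom_push',
              start_date=datetime(2018, 1, 1),
              schedule_interval=None)

    skip_xcom = PythonOperator(
        task_id='skip_xcom',
        python_callable=produce_large_result,
        do_xcom_push=False,  # new flag introduced by this PR
        dag=dag,
    )

The archived diff follows: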
diff --git a/airflow/contrib/operators/bigquery_get_data.py b/airflow/contrib/operators/bigquery_get_data.py
index f5e6e50f06..e9c7787dd3 100644
--- a/airflow/contrib/operators/bigquery_get_data.py
+++ b/airflow/contrib/operators/bigquery_get_data.py
@@ -66,6 +66,8 @@ class BigQueryGetDataOperator(BaseOperator):
For this to work, the service account making the request must have
domain-wide delegation enabled.
:type delegate_to: str
+ :param do_xcom_push: return the result, which also gets set in XCom
+ :type do_xcom_push: bool
"""
template_fields = ('dataset_id', 'table_id', 'max_results')
ui_color = '#e4f0e8'
@@ -78,6 +80,7 @@ def __init__(self,
selected_fields=None,
bigquery_conn_id='bigquery_default',
delegate_to=None,
+ do_xcom_push=True,
*args,
**kwargs):
super(BigQueryGetDataOperator, self).__init__(*args, **kwargs)
@@ -87,6 +90,7 @@ def __init__(self,
self.selected_fields = selected_fields
self.bigquery_conn_id = bigquery_conn_id
self.delegate_to = delegate_to
+ self.do_xcom_push = do_xcom_push
def execute(self, context):
self.log.info('Fetching Data from:')
@@ -113,4 +117,5 @@ def execute(self, context):
single_row.append(fields['v'])
table_data.append(single_row)
- return table_data
+ if self.do_xcom_push:
+ return table_data
diff --git a/airflow/contrib/operators/dataflow_operator.py b/airflow/contrib/operators/dataflow_operator.py
index 5378735f94..dfb65761a1 100644
--- a/airflow/contrib/operators/dataflow_operator.py
+++ b/airflow/contrib/operators/dataflow_operator.py
@@ -311,7 +311,6 @@ def __init__(
poll_sleep=10,
*args,
**kwargs):
-
super(DataFlowPythonOperator, self).__init__(*args, **kwargs)
self.py_file = py_file
@@ -335,9 +334,11 @@ def execute(self, context):
poll_sleep=self.poll_sleep)
dataflow_options = self.dataflow_default_options.copy()
dataflow_options.update(self.options)
+
# Convert argument names from lowerCamelCase to snake case.
- camel_to_snake = lambda name: re.sub(
- r'[A-Z]', lambda x: '_' + x.group(0).lower(), name)
+ def camel_to_snake(name):
+ return re.sub(r'[A-Z]', lambda x: '_' + x.group(0).lower(), name)
+
formatted_options = {camel_to_snake(key): dataflow_options[key]
for key in dataflow_options}
hook.start_python_dataflow(
diff --git a/airflow/contrib/operators/dataproc_operator.py b/airflow/contrib/operators/dataproc_operator.py
index 60fc2bcf15..8823c56c30 100644
--- a/airflow/contrib/operators/dataproc_operator.py
+++ b/airflow/contrib/operators/dataproc_operator.py
@@ -229,14 +229,14 @@ def _get_cluster(self, service):
cluster = [c for c in cluster_list if c['clusterName'] == self.cluster_name]
if cluster:
return cluster[0]
- return None
+ return
def _get_cluster_state(self, service):
cluster = self._get_cluster(service)
if 'status' in cluster:
return cluster['status']['state']
else:
- return None
+ return
def _cluster_ready(self, state, service):
if state == 'RUNNING':
@@ -407,7 +407,7 @@ def execute(self, context):
self.cluster_name
)
self._wait_for_done(service)
- return True
+ return
cluster_data = self._build_cluster_data()
try:
@@ -425,7 +425,7 @@ def execute(self, context):
self.cluster_name
)
self._wait_for_done(service)
- return True
+ return
else:
raise e
diff --git a/airflow/contrib/operators/emr_add_steps_operator.py b/airflow/contrib/operators/emr_add_steps_operator.py
index 959543e617..c0048681cf 100644
--- a/airflow/contrib/operators/emr_add_steps_operator.py
+++ b/airflow/contrib/operators/emr_add_steps_operator.py
@@ -32,6 +32,8 @@ class EmrAddStepsOperator(BaseOperator):
:type aws_conn_id: str
:param steps: boto3 style steps to be added to the jobflow. (templated)
:type steps: list
+ :param do_xcom_push: return the Step IDs, which also get set in XCom
+ :type do_xcom_push: bool
"""
template_fields = ['job_flow_id', 'steps']
template_ext = ()
@@ -43,12 +45,14 @@ def __init__(
job_flow_id,
aws_conn_id='s3_default',
steps=None,
+ do_xcom_push=True,
*args, **kwargs):
super(EmrAddStepsOperator, self).__init__(*args, **kwargs)
steps = steps or []
self.job_flow_id = job_flow_id
self.aws_conn_id = aws_conn_id
self.steps = steps
+ self.do_xcom_push = do_xcom_push
def execute(self, context):
emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
@@ -60,4 +64,5 @@ def execute(self, context):
raise AirflowException('Adding steps failed: %s' % response)
else:
self.log.info('Steps %s added to JobFlow', response['StepIds'])
- return response['StepIds']
+ if self.do_xcom_push:
+ return response['StepIds']
diff --git a/airflow/contrib/operators/emr_create_job_flow_operator.py b/airflow/contrib/operators/emr_create_job_flow_operator.py
index 42886d6006..6d936b8df8 100644
--- a/airflow/contrib/operators/emr_create_job_flow_operator.py
+++ b/airflow/contrib/operators/emr_create_job_flow_operator.py
@@ -35,6 +35,8 @@ class EmrCreateJobFlowOperator(BaseOperator):
:param job_flow_overrides: boto3 style arguments to override
emr_connection extra. (templated)
:type steps: dict
+ :param do_xcom_push: return the job flow ID, which also gets set in XCom
+ :type do_xcom_push: bool
"""
template_fields = ['job_flow_overrides']
template_ext = ()
@@ -46,6 +48,7 @@ def __init__(
aws_conn_id='s3_default',
emr_conn_id='emr_default',
job_flow_overrides=None,
+ do_xcom_push=True,
*args, **kwargs):
super(EmrCreateJobFlowOperator, self).__init__(*args, **kwargs)
self.aws_conn_id = aws_conn_id
@@ -53,6 +56,7 @@ def __init__(
if job_flow_overrides is None:
job_flow_overrides = {}
self.job_flow_overrides = job_flow_overrides
+ self.do_xcom_push = do_xcom_push
def execute(self, context):
emr = EmrHook(aws_conn_id=self.aws_conn_id,
emr_conn_id=self.emr_conn_id)
@@ -67,4 +71,5 @@ def execute(self, context):
raise AirflowException('JobFlow creation failed: %s' % response)
else:
self.log.info('JobFlow with id %s created', response['JobFlowId'])
- return response['JobFlowId']
+ if self.do_xcom_push:
+ return response['JobFlowId']
diff --git a/airflow/contrib/operators/gcp_container_operator.py b/airflow/contrib/operators/gcp_container_operator.py
index fda4d44b9d..814d8d6f6c 100644
--- a/airflow/contrib/operators/gcp_container_operator.py
+++ b/airflow/contrib/operators/gcp_container_operator.py
@@ -133,6 +133,9 @@ class GKEClusterCreateOperator(BaseOperator):
:type gcp_conn_id: str
:param api_version: The api version to use
:type api_version: str
+ :param do_xcom_push: return the result of the cluster creation operation,
+ which also gets set in XCom
+ :type do_xcom_push: bool
"""
template_fields = ['project_id', 'gcp_conn_id', 'location', 'api_version',
'body']
@@ -143,6 +146,7 @@ def __init__(self,
body=None,
gcp_conn_id='google_cloud_default',
api_version='v2',
+ do_xcom_push=True,
*args,
**kwargs):
super(GKEClusterCreateOperator, self).__init__(*args, **kwargs)
@@ -154,6 +158,7 @@ def __init__(self,
self.location = location
self.api_version = api_version
self.body = body
+ self.do_xcom_push = do_xcom_push
def _check_input(self):
if all([self.project_id, self.location, self.body]):
@@ -175,7 +180,9 @@ def execute(self, context):
self._check_input()
hook = GKEClusterHook(self.project_id, self.location)
create_op = hook.create_cluster(cluster=self.body)
- return create_op
+
+ if self.do_xcom_push:
+ return create_op
KUBE_CONFIG_ENV_VAR = "KUBECONFIG"
diff --git a/airflow/contrib/operators/gcp_function_operator.py b/airflow/contrib/operators/gcp_function_operator.py
index 8207b9d084..677acb8d87 100644
--- a/airflow/contrib/operators/gcp_function_operator.py
+++ b/airflow/contrib/operators/gcp_function_operator.py
@@ -274,6 +274,8 @@ class GcfFunctionDeleteOperator(BaseOperator):
:type gcp_conn_id: str
:param api_version: Version of the API used (for example v1).
:type api_version: str
+ :param do_xcom_push: return the result of the delete operation, which also gets set in XCom
+ :type do_xcom_push: bool
"""
@apply_defaults
@@ -281,10 +283,13 @@ def __init__(self,
name,
gcp_conn_id='google_cloud_default',
api_version='v1',
- *args, **kwargs):
+ do_xcom_push=True,
+ *args,
+ **kwargs):
self.name = name
self.gcp_conn_id = gcp_conn_id
self.api_version = api_version
+ self.do_xcom_push = do_xcom_push
self._validate_inputs()
self.hook = GcfHook(gcp_conn_id=self.gcp_conn_id,
api_version=self.api_version)
super(GcfFunctionDeleteOperator, self).__init__(*args, **kwargs)
@@ -300,7 +305,10 @@ def _validate_inputs(self):
def execute(self, context):
try:
- return self.hook.delete_function(self.name)
+ result = self.hook.delete_function(self.name)
+
+ if self.do_xcom_push:
+ return result
except HttpError as e:
status = e.resp.status
if status == 404:
diff --git a/airflow/contrib/operators/gcs_list_operator.py b/airflow/contrib/operators/gcs_list_operator.py
index 7b37b269a6..1e1587dc1d 100644
--- a/airflow/contrib/operators/gcs_list_operator.py
+++ b/airflow/contrib/operators/gcs_list_operator.py
@@ -45,6 +45,9 @@ class GoogleCloudStorageListOperator(BaseOperator):
For this to work, the service account making the request must have
domain-wide delegation enabled.
:type delegate_to: str
+ :param do_xcom_push: return the list of files, which also gets set
+ in XCom
+ :type do_xcom_push: bool
**Example**:
The following Operator would list all the Avro files from
``sales/sales-2017``
@@ -68,6 +71,7 @@ def __init__(self,
delimiter=None,
google_cloud_storage_conn_id='google_cloud_default',
delegate_to=None,
+ do_xcom_push=True,
*args,
**kwargs):
super(GoogleCloudStorageListOperator, self).__init__(*args, **kwargs)
@@ -76,6 +80,7 @@ def __init__(self,
self.delimiter = delimiter
self.google_cloud_storage_conn_id = google_cloud_storage_conn_id
self.delegate_to = delegate_to
+ self.do_xcom_push = do_xcom_push
def execute(self, context):
@@ -87,6 +92,9 @@ def execute(self, context):
self.log.info('Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s',
self.bucket, self.delimiter, self.prefix)
- return hook.list(bucket=self.bucket,
- prefix=self.prefix,
- delimiter=self.delimiter)
+ files = hook.list(bucket=self.bucket,
+ prefix=self.prefix,
+ delimiter=self.delimiter)
+
+ if self.do_xcom_push or self.__class__.__name__ != 'GoogleCloudStorageListOperator':
+ return files
diff --git a/airflow/contrib/operators/gcs_to_bq.py b/airflow/contrib/operators/gcs_to_bq.py
index 39dff21606..d41f4b6acb 100644
--- a/airflow/contrib/operators/gcs_to_bq.py
+++ b/airflow/contrib/operators/gcs_to_bq.py
@@ -87,10 +87,7 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator):
:type allow_jagged_rows: bool
:param max_id_key: If set, the name of a column in the BigQuery table
that's to be loaded. This will be used to select the MAX value from
- BigQuery after the load occurs. The results will be returned by the
- execute() command, which in turn gets stored in XCom for future
- operators to use. This can be helpful with incremental loads--during
- future executions, you can pick up from the max ID.
+ BigQuery after the load occurs. (used in combination with do_xcom_push)
:type max_id_key: str
:param bigquery_conn_id: Reference to a specific BigQuery hook.
:type bigquery_conn_id: str
@@ -119,6 +116,11 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator):
time_partitioning. The order of columns given determines the sort order.
Not applicable for external tables.
:type cluster_fields: list of str
+ :param do_xcom_push: return the max ID, which also gets set in XCom
+ (max_id_key must be set). This can be helpful with incremental loads--during
+ future executions, you can pick up from the max ID.
+ :type do_xcom_push: bool
"""
template_fields = ('bucket', 'source_objects',
'schema_object', 'destination_project_dataset_table')
@@ -152,6 +154,7 @@ def __init__(self,
external_table=False,
time_partitioning=None,
cluster_fields=None,
+ do_xcom_push=True,
*args, **kwargs):
super(GoogleCloudStorageToBigQueryOperator, self).__init__(*args,
**kwargs)
@@ -190,6 +193,7 @@ def __init__(self,
self.src_fmt_configs = src_fmt_configs
self.time_partitioning = time_partitioning
self.cluster_fields = cluster_fields
+ self.do_xcom_push = do_xcom_push
def execute(self, context):
bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
@@ -248,7 +252,7 @@ def execute(self, context):
time_partitioning=self.time_partitioning,
cluster_fields=self.cluster_fields)
- if self.max_id_key:
+ if self.do_xcom_push and self.max_id_key:
cursor.execute('SELECT MAX({}) FROM {}'.format(
self.max_id_key,
self.destination_project_dataset_table))
diff --git a/airflow/contrib/operators/gcs_to_s3.py b/airflow/contrib/operators/gcs_to_s3.py
index d8b180c81a..9db7a7ffb5 100644
--- a/airflow/contrib/operators/gcs_to_s3.py
+++ b/airflow/contrib/operators/gcs_to_s3.py
@@ -57,6 +57,8 @@ class GoogleCloudStorageToS3Operator(GoogleCloudStorageListOperator):
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:type dest_verify: bool or str
+ :param do_xcom_push: return the file list, which also gets set in XCom
+ :type do_xcom_push: bool
"""
template_fields = ('bucket', 'prefix', 'delimiter', 'dest_s3_key')
ui_color = '#f0eee4'
@@ -72,6 +74,7 @@ def __init__(self,
dest_s3_key=None,
dest_verify=None,
replace=False,
+ do_xcom_push=True,
*args,
**kwargs):
@@ -88,6 +91,7 @@ def __init__(self,
self.dest_s3_key = dest_s3_key
self.dest_verify = dest_verify
self.replace = replace
+ self.do_xcom_push = do_xcom_push
def execute(self, context):
# use the super to list all files in an Google Cloud Storage bucket
@@ -122,4 +126,5 @@ def execute(self, context):
else:
self.log.info("In sync, no files needed to be uploaded to S3")
- return files
+ if self.do_xcom_push:
+ return files
diff --git a/airflow/contrib/operators/hive_to_dynamodb.py b/airflow/contrib/operators/hive_to_dynamodb.py
index 4a39e40741..bd0b937e7e 100644
--- a/airflow/contrib/operators/hive_to_dynamodb.py
+++ b/airflow/contrib/operators/hive_to_dynamodb.py
@@ -70,7 +70,8 @@ def __init__(
schema='default',
hiveserver2_conn_id='hiveserver2_default',
aws_conn_id='aws_default',
- *args, **kwargs):
+ *args,
+ **kwargs):
super(HiveToDynamoDBTransferOperator, self).__init__(*args, **kwargs)
self.sql = sql
self.table_name = table_name
diff --git a/airflow/contrib/operators/jenkins_job_trigger_operator.py b/airflow/contrib/operators/jenkins_job_trigger_operator.py
index 3e1aba0c1b..027343405a 100644
--- a/airflow/contrib/operators/jenkins_job_trigger_operator.py
+++ b/airflow/contrib/operators/jenkins_job_trigger_operator.py
@@ -106,6 +106,8 @@ class JenkinsJobTriggerOperator(BaseOperator):
:param max_try_before_job_appears: The maximum number of requests to make
while waiting for the job to appears on jenkins server (default 10)
:type max_try_before_job_appears: int
+ :param do_xcom_push: return the build URL, which also gets set in XCom
+ :type do_xcom_push: bool
"""
template_fields = ('parameters',)
template_ext = ('.json',)
@@ -118,6 +120,7 @@ def __init__(self,
parameters="",
sleep_time=10,
max_try_before_job_appears=10,
+ do_xcom_push=True,
*args,
**kwargs):
super(JenkinsJobTriggerOperator, self).__init__(*args, **kwargs)
@@ -128,6 +131,7 @@ def __init__(self,
self.sleep_time = sleep_time
self.jenkins_connection_id = jenkins_connection_id
self.max_try_before_job_appears = max_try_before_job_appears
+ self.do_xcom_push = do_xcom_push
def build_job(self, jenkins_server):
"""
@@ -243,7 +247,8 @@ def execute(self, context):
'this exception for unknown parameters'
'You can also check logs for more details on this exception '
'(jenkins_url/log/rss)', str(err))
- if build_info:
+
+ if self.do_xcom_push and build_info:
# If we can we return the url of the job
# for later use (like retrieving an artifact)
return build_info['url']
diff --git a/airflow/contrib/operators/jira_operator.py b/airflow/contrib/operators/jira_operator.py
index 01d78b8645..b165c1cbc3 100644
--- a/airflow/contrib/operators/jira_operator.py
+++ b/airflow/contrib/operators/jira_operator.py
@@ -41,6 +41,8 @@ class JiraOperator(BaseOperator):
:param get_jira_resource_method: function or operator to get jira resource
on which the provided jira_method will be executed
:type get_jira_resource_method: function
+ :param do_xcom_push: return the result, which also gets set in XCom
+ :type do_xcom_push: bool
"""
template_fields = ("jira_method_args",)
@@ -52,6 +54,7 @@ def __init__(self,
jira_method_args=None,
result_processor=None,
get_jira_resource_method=None,
+ do_xcom_push=True,
*args,
**kwargs):
super(JiraOperator, self).__init__(*args, **kwargs)
@@ -60,6 +63,7 @@ def __init__(self,
self.jira_method_args = jira_method_args
self.result_processor = result_processor
self.get_jira_resource_method = get_jira_resource_method
+ self.do_xcom_push = do_xcom_push
def execute(self, context):
try:
@@ -82,10 +86,12 @@ def execute(self, context):
# ex: self.xcom_push(context, key='operator_response', value=jira_response)
# This could potentially throw error if jira_result is not picklable
jira_result = getattr(resource, self.method_name)(**self.jira_method_args)
- if self.result_processor:
- return self.result_processor(context, jira_result)
- return jira_result
+ if self.do_xcom_push:
+ if self.result_processor:
+ return self.result_processor(context, jira_result)
+
+ return jira_result
except JIRAError as jira_error:
raise AirflowException("Failed to execute jiraOperator, error: %s"
diff --git a/airflow/contrib/operators/mlengine_operator.py b/airflow/contrib/operators/mlengine_operator.py
index 8091ceefff..d9a123df56 100644
--- a/airflow/contrib/operators/mlengine_operator.py
+++ b/airflow/contrib/operators/mlengine_operator.py
@@ -150,6 +150,8 @@ class MLEngineBatchPredictionOperator(BaseOperator):
For this to work, the service account making the request must
have doamin-wide delegation enabled.
:type delegate_to: str
+ :param do_xcom_push: return the result, which also gets set in XCom
+ :type do_xcom_push: bool
Raises:
``ValueError``: if a unique model/version origin cannot be determined.
@@ -181,6 +183,7 @@ def __init__(self,
runtime_version=None,
gcp_conn_id='google_cloud_default',
delegate_to=None,
+ do_xcom_push=True,
*args,
**kwargs):
super(MLEngineBatchPredictionOperator, self).__init__(*args, **kwargs)
@@ -198,6 +201,7 @@ def __init__(self,
self._runtime_version = runtime_version
self._gcp_conn_id = gcp_conn_id
self._delegate_to = delegate_to
+ self.do_xcom_push = do_xcom_push
if not self._project_id:
raise AirflowException('Google Cloud project id is required.')
@@ -272,7 +276,8 @@ def check_existing_job(existing_job):
str(finished_prediction_job)))
raise RuntimeError(finished_prediction_job['errorMessage'])
- return finished_prediction_job['predictionOutput']
+ if self.do_xcom_push:
+ return finished_prediction_job['predictionOutput']
class MLEngineModelOperator(BaseOperator):
@@ -300,6 +305,8 @@ class MLEngineModelOperator(BaseOperator):
For this to work, the service account making the request must have
domain-wide delegation enabled.
:type delegate_to: str
+ :param do_xcom_push: return the result, which also gets set in XCom
+ :type do_xcom_push: bool
"""
template_fields = [
@@ -313,6 +320,7 @@ def __init__(self,
operation='create',
gcp_conn_id='google_cloud_default',
delegate_to=None,
+ do_xcom_push=True,
*args,
**kwargs):
super(MLEngineModelOperator, self).__init__(*args, **kwargs)
@@ -321,17 +329,21 @@ def __init__(self,
self._operation = operation
self._gcp_conn_id = gcp_conn_id
self._delegate_to = delegate_to
+ self.do_xcom_push = do_xcom_push
def execute(self, context):
hook = MLEngineHook(
gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
if self._operation == 'create':
- return hook.create_model(self._project_id, self._model)
+ result = hook.create_model(self._project_id, self._model)
elif self._operation == 'get':
- return hook.get_model(self._project_id, self._model['name'])
+ result = hook.get_model(self._project_id, self._model['name'])
else:
raise ValueError('Unknown operation: {}'.format(self._operation))
+ if self.do_xcom_push:
+ return result
+
class MLEngineVersionOperator(BaseOperator):
"""
@@ -387,6 +399,8 @@ class MLEngineVersionOperator(BaseOperator):
For this to work, the service account making the request must have
domain-wide delegation enabled.
:type delegate_to: str
+ :param do_xcom_push: return the result, which also gets set in XCom
+ :type do_xcom_push: bool
"""
template_fields = [
@@ -404,6 +418,7 @@ def __init__(self,
operation='create',
gcp_conn_id='google_cloud_default',
delegate_to=None,
+ do_xcom_push=True,
*args,
**kwargs):
@@ -415,6 +430,7 @@ def __init__(self,
self._operation = operation
self._gcp_conn_id = gcp_conn_id
self._delegate_to = delegate_to
+ self.do_xcom_push = do_xcom_push
def execute(self, context):
if 'name' not in self._version:
@@ -427,19 +443,22 @@ def execute(self, context):
if not self._version:
raise ValueError("version attribute of {} could not "
"be empty".format(self.__class__.__name__))
- return hook.create_version(self._project_id, self._model_name,
- self._version)
+ result = hook.create_version(self._project_id, self._model_name,
+ self._version)
elif self._operation == 'set_default':
- return hook.set_default_version(self._project_id, self._model_name,
- self._version['name'])
+ result = hook.set_default_version(self._project_id, self._model_name,
+ self._version['name'])
elif self._operation == 'list':
- return hook.list_versions(self._project_id, self._model_name)
+ result = hook.list_versions(self._project_id, self._model_name)
elif self._operation == 'delete':
- return hook.delete_version(self._project_id, self._model_name,
- self._version['name'])
+ result = hook.delete_version(self._project_id, self._model_name,
+ self._version['name'])
else:
raise ValueError('Unknown operation: {}'.format(self._operation))
+ if self.do_xcom_push:
+ return result
+
class MLEngineTrainingOperator(BaseOperator):
"""
@@ -585,8 +604,7 @@ def execute(self, context):
if self._mode == 'DRY_RUN':
self.log.info('In dry_run mode.')
- self.log.info('MLEngine Training job request is: {}'.format(
- training_request))
+ self.log.info('MLEngine Training job request is: {}'.format(training_request))
return
hook = MLEngineHook(
diff --git a/airflow/contrib/operators/mlengine_prediction_summary.py b/airflow/contrib/operators/mlengine_prediction_summary.py
index def793c1be..227cf74488 100644
--- a/airflow/contrib/operators/mlengine_prediction_summary.py
+++ b/airflow/contrib/operators/mlengine_prediction_summary.py
@@ -113,7 +113,7 @@ def decode(x):
@beam.ptransform_fn
-def MakeSummary(pcoll, metric_fn, metric_keys): # pylint: disable=invalid-name
+def make_summary(pcoll, metric_fn, metric_keys): # pylint: disable=invalid-name
return (
pcoll |
"ApplyMetricFnPerInstance" >> beam.Map(metric_fn) |
@@ -156,8 +156,7 @@ def run(argv=None):
raise ValueError("--metric_fn_encoded must be an encoded callable.")
metric_keys = known_args.metric_keys.split(",")
- with beam.Pipeline(
- options=beam.pipeline.PipelineOptions(pipeline_args)) as p:
+ with beam.Pipeline(options=beam.pipeline.PipelineOptions(pipeline_args)) as p:
# This is apache-beam ptransform's convention
# pylint: disable=no-value-for-parameter
_ = (p
@@ -165,7 +164,7 @@ def run(argv=None):
os.path.join(known_args.prediction_path,
"prediction.results-*-of-*"),
coder=JsonCoder())
- | "Summary" >> MakeSummary(metric_fn, metric_keys)
+ | "Summary" >> make_summary(metric_fn, metric_keys)
| "Write" >> beam.io.WriteToText(
os.path.join(known_args.prediction_path,
"prediction.summary.json"),
diff --git a/airflow/contrib/operators/mongo_to_s3.py b/airflow/contrib/operators/mongo_to_s3.py
index 8bfa7a52f8..82e8040132 100644
--- a/airflow/contrib/operators/mongo_to_s3.py
+++ b/airflow/contrib/operators/mongo_to_s3.py
@@ -94,8 +94,6 @@ def execute(self, context):
replace=self.replace
)
- return True
-
@staticmethod
def _stringify(iterable, joinable='\n'):
"""
diff --git a/airflow/contrib/operators/pubsub_operator.py b/airflow/contrib/operators/pubsub_operator.py
index e40828bf92..38f14f2eb7 100644
--- a/airflow/contrib/operators/pubsub_operator.py
+++ b/airflow/contrib/operators/pubsub_operator.py
@@ -159,6 +159,7 @@ def __init__(
fail_if_exists=False,
gcp_conn_id='google_cloud_default',
delegate_to=None,
+ do_xcom_push=True,
*args,
**kwargs):
"""
@@ -185,6 +186,8 @@ def __init__(
For this to work, the service account making the request
must have domain-wide delegation enabled.
:type delegate_to: str
+ :param do_xcom_push: return the result, which also gets set in XCom
+ :type do_xcom_push: bool
"""
super(PubSubSubscriptionCreateOperator, self).__init__(*args, **kwargs)
@@ -196,16 +199,20 @@ def __init__(
self.fail_if_exists = fail_if_exists
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
+ self.do_xcom_push = do_xcom_push
def execute(self, context):
hook = PubSubHook(gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to)
- return hook.create_subscription(
+ result = hook.create_subscription(
self.topic_project, self.topic, self.subscription,
self.subscription_project, self.ack_deadline_secs,
self.fail_if_exists)
+ if self.do_xcom_push:
+ return result
+
class PubSubTopicDeleteOperator(BaseOperator):
"""Delete a PubSub topic.
diff --git a/airflow/contrib/operators/qubole_operator.py b/airflow/contrib/operators/qubole_operator.py
index 82ee293b93..79d12dde75 100755
--- a/airflow/contrib/operators/qubole_operator.py
+++ b/airflow/contrib/operators/qubole_operator.py
@@ -28,6 +28,8 @@ class QuboleOperator(BaseOperator):
:param qubole_conn_id: Connection id which consists of qds auth_token
:type qubole_conn_id: str
+ :param do_xcom_push: return the result, which also gets set in XCom
+ :type do_xcom_push: bool
kwargs:
:command_type: type of command to be executed, e.g. hivecmd, shellcmd, hadoopcmd
@@ -129,11 +131,16 @@ class QuboleOperator(BaseOperator):
ui_fgcolor = '#fff'
@apply_defaults
- def __init__(self, qubole_conn_id="qubole_default", *args, **kwargs):
+ def __init__(self,
+ qubole_conn_id="qubole_default",
+ do_xcom_push=True,
+ *args,
+ **kwargs):
self.args = args
self.kwargs = kwargs
self.kwargs['qubole_conn_id'] = qubole_conn_id
super(QuboleOperator, self).__init__(*args, **kwargs)
+ self.do_xcom_push = do_xcom_push
if self.on_failure_callback is None:
self.on_failure_callback = QuboleHook.handle_failure_retry
@@ -142,7 +149,10 @@ def __init__(self, qubole_conn_id="qubole_default", *args, **kwargs):
self.on_retry_callback = QuboleHook.handle_failure_retry
def execute(self, context):
- return self.get_hook().execute(context)
+ result = self.get_hook().execute(context)
+
+ if self.do_xcom_push:
+ return result
def on_kill(self, ti=None):
self.get_hook().kill(ti)
diff --git a/airflow/contrib/operators/s3_list_operator.py b/airflow/contrib/operators/s3_list_operator.py
index 3ca22d5932..88a5346a36 100644
--- a/airflow/contrib/operators/s3_list_operator.py
+++ b/airflow/contrib/operators/s3_list_operator.py
@@ -48,6 +48,8 @@ class S3ListOperator(BaseOperator):
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:type verify: bool or str
+ :param do_xcom_push: return the key list, which also gets set in XCom
+ :type do_xcom_push: bool
**Example**:
The following operator would list all the files
@@ -72,6 +74,7 @@ def __init__(self,
delimiter='',
aws_conn_id='aws_default',
verify=None,
+ do_xcom_push=True,
*args,
**kwargs):
super(S3ListOperator, self).__init__(*args, **kwargs)
@@ -80,6 +83,7 @@ def __init__(self,
self.delimiter = delimiter
self.aws_conn_id = aws_conn_id
self.verify = verify
+ self.do_xcom_push = do_xcom_push
def execute(self, context):
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
@@ -88,7 +92,10 @@ def execute(self, context):
'Getting the list of files from bucket: {0} in prefix: {1} (Delimiter {2})'.
format(self.bucket, self.prefix, self.delimiter))
- return hook.list_keys(
+ list_keys = hook.list_keys(
bucket_name=self.bucket,
prefix=self.prefix,
delimiter=self.delimiter)
+
+ if self.do_xcom_push:
+ return list_keys
diff --git a/airflow/contrib/operators/s3_to_gcs_operator.py b/airflow/contrib/operators/s3_to_gcs_operator.py
index 5dd355a6fd..e590e1a252 100644
--- a/airflow/contrib/operators/s3_to_gcs_operator.py
+++ b/airflow/contrib/operators/s3_to_gcs_operator.py
@@ -64,6 +64,8 @@ class S3ToGoogleCloudStorageOperator(S3ListOperator):
:param replace: Whether you want to replace existing destination files
or not.
:type replace: bool
+ :param do_xcom_push: return the file list, which also gets set in XCom
+ :type do_xcom_push: bool
**Example**:
@@ -97,6 +99,7 @@ def __init__(self,
dest_gcs=None,
delegate_to=None,
replace=False,
+ do_xcom_push=True,
*args,
**kwargs):
@@ -112,6 +115,7 @@ def __init__(self,
self.delegate_to = delegate_to
self.replace = replace
self.verify = verify
+ self.do_xcom_push = do_xcom_push
if dest_gcs and not self._gcs_object_is_directory(self.dest_gcs):
self.log.info(
@@ -194,7 +198,8 @@ def execute(self, context):
'In sync, no files needed to be uploaded to Google Cloud'
'Storage')
- return files
+ if self.do_xcom_push:
+ return files
# Following functionality may be better suited in
# airflow/contrib/hooks/gcs_hook.py
diff --git a/airflow/contrib/operators/sftp_operator.py b/airflow/contrib/operators/sftp_operator.py
index 620d875f89..fb0ebbd5db 100644
--- a/airflow/contrib/operators/sftp_operator.py
+++ b/airflow/contrib/operators/sftp_operator.py
@@ -117,4 +117,4 @@ def execute(self, context):
raise AirflowException("Error while transferring {0}, error: {1}"
.format(file_msg, str(e)))
- return None
+ return
diff --git a/airflow/contrib/operators/ssh_operator.py b/airflow/contrib/operators/ssh_operator.py
index 2bf342935d..cede566562 100644
--- a/airflow/contrib/operators/ssh_operator.py
+++ b/airflow/contrib/operators/ssh_operator.py
@@ -45,7 +45,7 @@ class SSHOperator(BaseOperator):
:type command: str
:param timeout: timeout (in seconds) for executing the command.
:type timeout: int
- :param do_xcom_push: return the stdout which also get set in xcom by airflow platform
+ :param do_xcom_push: return the stdout, which also gets set in XCom
:type do_xcom_push: bool
"""
@@ -166,8 +166,6 @@ def execute(self, context):
except Exception as e:
raise AirflowException("SSH operator error: {0}".format(str(e)))
- return True
-
def tunnel(self):
ssh_client = self.ssh_hook.get_conn()
ssh_client.get_transport()
diff --git a/airflow/contrib/operators/winrm_operator.py b/airflow/contrib/operators/winrm_operator.py
index c81acac44f..5404c3a8b4 100644
--- a/airflow/contrib/operators/winrm_operator.py
+++ b/airflow/contrib/operators/winrm_operator.py
@@ -48,7 +48,7 @@ class WinRMOperator(BaseOperator):
:type command: str
:param timeout: timeout for executing the command.
:type timeout: int
- :param do_xcom_push: return the stdout which also get set in xcom by airflow platform
+ :param do_xcom_push: return the stdout, which also gets set in XCom
:type do_xcom_push: bool
"""
template_fields = ('command',)
@@ -145,5 +145,3 @@ def execute(self, context):
raise AirflowException(error_msg)
self.log.info("Finished!")
-
- return True
diff --git a/airflow/models.py b/airflow/models.py
index 6cdfc0fd81..3984f56ec4 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -1641,7 +1641,7 @@ def signal_handler(signum, frame):
result = task_copy.execute(context=context)
# If the task returns a result, push an XCom containing it
- if result is not None:
+ if task_copy.do_xcom_push and result is not None:
self.xcom_push(key=XCOM_RETURN_KEY, value=result)
task_copy.post_execute(context=context, result=result)
@@ -2419,6 +2419,9 @@ class derived from this one results in the creation of a task object,
)
:type executor_config: dict
+ :param do_xcom_push: if True, an XCom is pushed containing the Operator's
+ result
+ :type do_xcom_push: bool
"""
# For derived classes to define which fields will get jinjaified
@@ -2473,6 +2476,7 @@ def __init__(
run_as_user=None,
task_concurrency=None,
executor_config=None,
+ do_xcom_push=False,
inlets=None,
outlets=None,
*args,
@@ -2557,6 +2561,7 @@ def __init__(
self.run_as_user = run_as_user
self.task_concurrency = task_concurrency
self.executor_config = executor_config or {}
+ self.do_xcom_push = do_xcom_push
# Private attributes
self._upstream_task_ids = set()
@@ -2610,6 +2615,7 @@ def __init__(
'on_failure_callback',
'on_success_callback',
'on_retry_callback',
+ 'do_xcom_push',
}
def __eq__(self, other):
diff --git a/tests/contrib/operators/test_gcs_list_operator.py b/tests/contrib/operators/test_gcs_list_operator.py
index 8f0281f76f..10f7d9ddab 100644
--- a/tests/contrib/operators/test_gcs_list_operator.py
+++ b/tests/contrib/operators/test_gcs_list_operator.py
@@ -52,3 +52,19 @@ def test_execute(self, mock_hook):
bucket=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER
)
self.assertEqual(sorted(files), sorted(MOCK_FILES))
+
+ @mock.patch('airflow.contrib.operators.gcs_list_operator.GoogleCloudStorageHook')
+ def test_execute_without_xcom_push(self, mock_hook):
+ mock_hook.return_value.list.return_value = MOCK_FILES
+
+ operator = GoogleCloudStorageListOperator(task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ prefix=PREFIX,
+ delimiter=DELIMITER,
+ do_xcom_push=False)
+
+ files = operator.execute(None)
+ mock_hook.return_value.list.assert_called_once_with(
+ bucket=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER
+ )
+ self.assertEqual(files, None)
diff --git a/tests/contrib/operators/test_gcs_to_s3_operator.py b/tests/contrib/operators/test_gcs_to_s3_operator.py
index 1a35a223ab..15ef9886b7 100644
--- a/tests/contrib/operators/test_gcs_to_s3_operator.py
+++ b/tests/contrib/operators/test_gcs_to_s3_operator.py
@@ -72,3 +72,29 @@ def test_execute(self, mock_hook, mock_hook2):
sorted(uploaded_files))
self.assertEqual(sorted(MOCK_FILES),
sorted(hook.list_keys('bucket', delimiter='/')))
+
+ @mock_s3
+ @mock.patch('airflow.contrib.operators.gcs_list_operator.GoogleCloudStorageHook')
+ @mock.patch('airflow.contrib.operators.gcs_to_s3.GoogleCloudStorageHook')
+ def test_execute_without_xcom_push(self, mock_hook, mock_hook2):
+ mock_hook.return_value.list.return_value = MOCK_FILES
+ mock_hook.return_value.download.return_value = b"testing"
+ mock_hook2.return_value.list.return_value = MOCK_FILES
+
+ operator = GoogleCloudStorageToS3Operator(task_id=TASK_ID,
+ bucket=GCS_BUCKET,
+ prefix=PREFIX,
+ delimiter=DELIMITER,
+ dest_aws_conn_id=None,
+ dest_s3_key=S3_BUCKET,
+ do_xcom_push=False)
+ # create dest bucket
+ hook = S3Hook(aws_conn_id=None)
+ b = hook.get_bucket('bucket')
+ b.create()
+ b.put_object(Key=MOCK_FILES[0], Body=b'testing')
+
+ uploaded_files = operator.execute(None)
+ self.assertEqual(uploaded_files, None)
+ self.assertEqual(sorted(MOCK_FILES),
+ sorted(hook.list_keys('bucket', delimiter='/')))
diff --git a/tests/models.py b/tests/models.py
index f2d36a263b..5723cea421 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -2264,6 +2264,32 @@ def test_xcom_pull_different_execution_date(self):
include_prior_dates=True),
value)
+ def test_xcom_push_flag(self):
+ """
+ Tests the option for Operators to push XComs
+ """
+ value = 'hello'
+ task_id = 'test_no_xcom_push'
+ dag = models.DAG(dag_id='test_xcom')
+
+ # nothing saved to XCom
+ task = PythonOperator(
+ task_id=task_id,
+ dag=dag,
+ python_callable=lambda: value,
+ do_xcom_push=False,
+ owner='airflow',
+ start_date=datetime.datetime(2017, 1, 1)
+ )
+ ti = TI(task=task, execution_date=datetime.datetime(2017, 1, 1))
+ ti.run()
+ self.assertEqual(
+ ti.xcom_pull(
+ task_ids=task_id, key=models.XCOM_RETURN_KEY
+ ),
+ None
+ )
+
def test_post_execute_hook(self):
"""
Test that post_execute hook is called with the Operator's result.
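The same pattern could presumably be carried over to custom operators outside
this diff: accept the flag, store it, and only return the result when it is
set. A rough sketch (the operator name and helper method are illustrative,
not from the PR):

    from airflow.models import BaseOperator
    from airflow.utils.decorators import apply_defaults


    class MyListThingsOperator(BaseOperator):
        """Illustrative operator that only returns (and therefore
        XCom-pushes) its result when do_xcom_push is enabled."""

        @apply_defaults
        def __init__(self, do_xcom_push=True, *args, **kwargs):
            super(MyListThingsOperator, self).__init__(*args, **kwargs)
            # Override BaseOperator's default so this operator pushes by default.
            self.do_xcom_push = do_xcom_push

        def execute(self, context):
            result = self._list_things()  # illustrative helper
            if self.do_xcom_push:
                return result

        def _list_things(self):
            return ['a', 'b', 'c']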