itscaro closed pull request #4116: WIP [AIRFLOW-3207] option to stop task pushing result to xcom
URL: https://github.com/apache/incubator-airflow/pull/4116
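
In short, the change adds a do_xcom_push flag to BaseOperator (and threads it through a number of
contrib operators) so that a task can opt out of pushing its return value to XCom. A minimal,
hypothetical usage sketch against this WIP branch (the DAG, task and callable names are
illustrative, not taken from the PR):

    from datetime import datetime

    from airflow import DAG
    from airflow.operators.python_operator import PythonOperator


    def produce_large_payload():
        # The return value would normally be pushed to XCom automatically.
        return ['row-%d' % i for i in range(100000)]


    dag = DAG(dag_id='example_no_xcom_push',
              start_date=datetime(2018, 1, 1),
              schedule_interval=None)

    # With do_xcom_push=False the return value is discarded rather than stored
    # under the return-value XCom key, so downstream xcom_pull() calls get None.
    skip_push = PythonOperator(
        task_id='produce_but_do_not_push',
        python_callable=produce_large_payload,
        do_xcom_push=False,
        dag=dag,
    )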
 
 
   

This is a PR from a forked repository. As GitHub hides the original diff of a foreign (fork)
pull request on merge, it is supplied below for the sake of provenance:

diff --git a/airflow/contrib/operators/bigquery_get_data.py b/airflow/contrib/operators/bigquery_get_data.py
index f5e6e50f06..e9c7787dd3 100644
--- a/airflow/contrib/operators/bigquery_get_data.py
+++ b/airflow/contrib/operators/bigquery_get_data.py
@@ -66,6 +66,8 @@ class BigQueryGetDataOperator(BaseOperator):
         For this to work, the service account making the request must have domain-wide
         delegation enabled.
     :type delegate_to: str
+    :param do_xcom_push: return the result which also get set in XCOM
+    :type do_xcom_push: bool
     """
     template_fields = ('dataset_id', 'table_id', 'max_results')
     ui_color = '#e4f0e8'
@@ -78,6 +80,7 @@ def __init__(self,
                  selected_fields=None,
                  bigquery_conn_id='bigquery_default',
                  delegate_to=None,
+                 do_xcom_push=True,
                  *args,
                  **kwargs):
         super(BigQueryGetDataOperator, self).__init__(*args, **kwargs)
@@ -87,6 +90,7 @@ def __init__(self,
         self.selected_fields = selected_fields
         self.bigquery_conn_id = bigquery_conn_id
         self.delegate_to = delegate_to
+        self.do_xcom_push = do_xcom_push
 
     def execute(self, context):
         self.log.info('Fetching Data from:')
@@ -113,4 +117,5 @@ def execute(self, context):
                 single_row.append(fields['v'])
             table_data.append(single_row)
 
-        return table_data
+        if self.do_xcom_push:
+            return table_data
diff --git a/airflow/contrib/operators/dataflow_operator.py b/airflow/contrib/operators/dataflow_operator.py
index 5378735f94..dfb65761a1 100644
--- a/airflow/contrib/operators/dataflow_operator.py
+++ b/airflow/contrib/operators/dataflow_operator.py
@@ -311,7 +311,6 @@ def __init__(
             poll_sleep=10,
             *args,
             **kwargs):
-
         super(DataFlowPythonOperator, self).__init__(*args, **kwargs)
 
         self.py_file = py_file
@@ -335,9 +334,11 @@ def execute(self, context):
                             poll_sleep=self.poll_sleep)
         dataflow_options = self.dataflow_default_options.copy()
         dataflow_options.update(self.options)
+
         # Convert argument names from lowerCamelCase to snake case.
-        camel_to_snake = lambda name: re.sub(
-            r'[A-Z]', lambda x: '_' + x.group(0).lower(), name)
+        def camel_to_snake(name):
+            return re.sub(r'[A-Z]', lambda x: '_' + x.group(0).lower(), name)
+
         formatted_options = {camel_to_snake(key): dataflow_options[key]
                              for key in dataflow_options}
         hook.start_python_dataflow(
diff --git a/airflow/contrib/operators/dataproc_operator.py b/airflow/contrib/operators/dataproc_operator.py
index 60fc2bcf15..8823c56c30 100644
--- a/airflow/contrib/operators/dataproc_operator.py
+++ b/airflow/contrib/operators/dataproc_operator.py
@@ -229,14 +229,14 @@ def _get_cluster(self, service):
         cluster = [c for c in cluster_list if c['clusterName'] == self.cluster_name]
         if cluster:
             return cluster[0]
-        return None
+        return
 
     def _get_cluster_state(self, service):
         cluster = self._get_cluster(service)
         if 'status' in cluster:
             return cluster['status']['state']
         else:
-            return None
+            return
 
     def _cluster_ready(self, state, service):
         if state == 'RUNNING':
@@ -407,7 +407,7 @@ def execute(self, context):
                 self.cluster_name
             )
             self._wait_for_done(service)
-            return True
+            return
 
         cluster_data = self._build_cluster_data()
         try:
@@ -425,7 +425,7 @@ def execute(self, context):
                     self.cluster_name
                 )
                 self._wait_for_done(service)
-                return True
+                return
             else:
                 raise e
 
diff --git a/airflow/contrib/operators/emr_add_steps_operator.py b/airflow/contrib/operators/emr_add_steps_operator.py
index 959543e617..c0048681cf 100644
--- a/airflow/contrib/operators/emr_add_steps_operator.py
+++ b/airflow/contrib/operators/emr_add_steps_operator.py
@@ -32,6 +32,8 @@ class EmrAddStepsOperator(BaseOperator):
     :type aws_conn_id: str
     :param steps: boto3 style steps to be added to the jobflow. (templated)
     :type steps: list
+    :param do_xcom_push: return the Step IDs which also get set in XCOM
+    :type do_xcom_push: bool
     """
     template_fields = ['job_flow_id', 'steps']
     template_ext = ()
@@ -43,12 +45,14 @@ def __init__(
             job_flow_id,
             aws_conn_id='s3_default',
             steps=None,
+            do_xcom_push=True,
             *args, **kwargs):
         super(EmrAddStepsOperator, self).__init__(*args, **kwargs)
         steps = steps or []
         self.job_flow_id = job_flow_id
         self.aws_conn_id = aws_conn_id
         self.steps = steps
+        self.do_xcom_push = do_xcom_push
 
     def execute(self, context):
         emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
@@ -60,4 +64,5 @@ def execute(self, context):
             raise AirflowException('Adding steps failed: %s' % response)
         else:
             self.log.info('Steps %s added to JobFlow', response['StepIds'])
-            return response['StepIds']
+            if self.do_xcom_push:
+                return response['StepIds']
diff --git a/airflow/contrib/operators/emr_create_job_flow_operator.py b/airflow/contrib/operators/emr_create_job_flow_operator.py
index 42886d6006..6d936b8df8 100644
--- a/airflow/contrib/operators/emr_create_job_flow_operator.py
+++ b/airflow/contrib/operators/emr_create_job_flow_operator.py
@@ -35,6 +35,8 @@ class EmrCreateJobFlowOperator(BaseOperator):
     :param job_flow_overrides: boto3 style arguments to override
        emr_connection extra. (templated)
     :type steps: dict
+    :param do_xcom_push: return the status code which also get set in XCOM
+    :type do_xcom_push: bool
     """
     template_fields = ['job_flow_overrides']
     template_ext = ()
@@ -46,6 +48,7 @@ def __init__(
             aws_conn_id='s3_default',
             emr_conn_id='emr_default',
             job_flow_overrides=None,
+            do_xcom_push=True,
             *args, **kwargs):
         super(EmrCreateJobFlowOperator, self).__init__(*args, **kwargs)
         self.aws_conn_id = aws_conn_id
@@ -53,6 +56,7 @@ def __init__(
         if job_flow_overrides is None:
             job_flow_overrides = {}
         self.job_flow_overrides = job_flow_overrides
+        self.do_xcom_push = do_xcom_push
 
     def execute(self, context):
         emr = EmrHook(aws_conn_id=self.aws_conn_id, emr_conn_id=self.emr_conn_id)
@@ -67,4 +71,5 @@ def execute(self, context):
             raise AirflowException('JobFlow creation failed: %s' % response)
         else:
             self.log.info('JobFlow with id %s created', response['JobFlowId'])
-            return response['JobFlowId']
+            if self.do_xcom_push:
+                return response['JobFlowId']
diff --git a/airflow/contrib/operators/gcp_container_operator.py b/airflow/contrib/operators/gcp_container_operator.py
index fda4d44b9d..814d8d6f6c 100644
--- a/airflow/contrib/operators/gcp_container_operator.py
+++ b/airflow/contrib/operators/gcp_container_operator.py
@@ -133,6 +133,9 @@ class GKEClusterCreateOperator(BaseOperator):
     :type gcp_conn_id: str
     :param api_version: The api version to use
     :type api_version: str
+    :param do_xcom_push: return the result of cluster creation operation which also get
+        set in XCOM
+    :type do_xcom_push: bool
     """
     template_fields = ['project_id', 'gcp_conn_id', 'location', 'api_version', 'body']
 
@@ -143,6 +146,7 @@ def __init__(self,
                  body=None,
                  gcp_conn_id='google_cloud_default',
                  api_version='v2',
+                 do_xcom_push=True,
                  *args,
                  **kwargs):
         super(GKEClusterCreateOperator, self).__init__(*args, **kwargs)
@@ -154,6 +158,7 @@ def __init__(self,
         self.location = location
         self.api_version = api_version
         self.body = body
+        self.do_xcom_push = do_xcom_push
 
     def _check_input(self):
         if all([self.project_id, self.location, self.body]):
@@ -175,7 +180,9 @@ def execute(self, context):
         self._check_input()
         hook = GKEClusterHook(self.project_id, self.location)
         create_op = hook.create_cluster(cluster=self.body)
-        return create_op
+
+        if self.do_xcom_push:
+            return create_op
 
 
 KUBE_CONFIG_ENV_VAR = "KUBECONFIG"
diff --git a/airflow/contrib/operators/gcp_function_operator.py b/airflow/contrib/operators/gcp_function_operator.py
index 8207b9d084..677acb8d87 100644
--- a/airflow/contrib/operators/gcp_function_operator.py
+++ b/airflow/contrib/operators/gcp_function_operator.py
@@ -274,6 +274,8 @@ class GcfFunctionDeleteOperator(BaseOperator):
     :type gcp_conn_id: str
     :param api_version: Version of the API used (for example v1).
     :type api_version: str
+    :param do_xcom_push: return the file list which also get set in XCOM
+    :type do_xcom_push: bool
     """
 
     @apply_defaults
@@ -281,10 +283,13 @@ def __init__(self,
                  name,
                  gcp_conn_id='google_cloud_default',
                  api_version='v1',
-                 *args, **kwargs):
+                 do_xcom_push=True,
+                 *args,
+                 **kwargs):
         self.name = name
         self.gcp_conn_id = gcp_conn_id
         self.api_version = api_version
+        self.do_xcom_push = do_xcom_push
         self._validate_inputs()
         self.hook = GcfHook(gcp_conn_id=self.gcp_conn_id, api_version=self.api_version)
         super(GcfFunctionDeleteOperator, self).__init__(*args, **kwargs)
@@ -300,7 +305,10 @@ def _validate_inputs(self):
 
     def execute(self, context):
         try:
-            return self.hook.delete_function(self.name)
+            result = self.hook.delete_function(self.name)
+
+            if self.do_xcom_push:
+                return result
         except HttpError as e:
             status = e.resp.status
             if status == 404:
diff --git a/airflow/contrib/operators/gcs_list_operator.py b/airflow/contrib/operators/gcs_list_operator.py
index 7b37b269a6..1e1587dc1d 100644
--- a/airflow/contrib/operators/gcs_list_operator.py
+++ b/airflow/contrib/operators/gcs_list_operator.py
@@ -45,6 +45,9 @@ class GoogleCloudStorageListOperator(BaseOperator):
         For this to work, the service account making the request must have
         domain-wide delegation enabled.
     :type delegate_to: str
+    :param do_xcom_push: return the list of Google Cloud Storage Hook which also get set
+        in XCOM
+    :type do_xcom_push: bool
 
     **Example**:
         The following Operator would list all the Avro files from ``sales/sales-2017``
@@ -68,6 +71,7 @@ def __init__(self,
                  delimiter=None,
                  google_cloud_storage_conn_id='google_cloud_default',
                  delegate_to=None,
+                 do_xcom_push=True,
                  *args,
                  **kwargs):
         super(GoogleCloudStorageListOperator, self).__init__(*args, **kwargs)
@@ -76,6 +80,7 @@ def __init__(self,
         self.delimiter = delimiter
         self.google_cloud_storage_conn_id = google_cloud_storage_conn_id
         self.delegate_to = delegate_to
+        self.do_xcom_push = do_xcom_push
 
     def execute(self, context):
 
@@ -87,6 +92,9 @@ def execute(self, context):
         self.log.info('Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s',
                       self.bucket, self.delimiter, self.prefix)
 
-        return hook.list(bucket=self.bucket,
-                         prefix=self.prefix,
-                         delimiter=self.delimiter)
+        files = hook.list(bucket=self.bucket,
+                          prefix=self.prefix,
+                          delimiter=self.delimiter)
+
+        if self.do_xcom_push or self.__class__.__name__ != 'GoogleCloudStorageListOperator':
+            return files
diff --git a/airflow/contrib/operators/gcs_to_bq.py b/airflow/contrib/operators/gcs_to_bq.py
index 39dff21606..d41f4b6acb 100644
--- a/airflow/contrib/operators/gcs_to_bq.py
+++ b/airflow/contrib/operators/gcs_to_bq.py
@@ -87,10 +87,7 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator):
     :type allow_jagged_rows: bool
     :param max_id_key: If set, the name of a column in the BigQuery table
         that's to be loaded. This will be used to select the MAX value from
-        BigQuery after the load occurs. The results will be returned by the
-        execute() command, which in turn gets stored in XCom for future
-        operators to use. This can be helpful with incremental loads--during
-        future executions, you can pick up from the max ID.
+        BigQuery after the load occurs. (used in combination with do_xcom_push)
     :type max_id_key: str
     :param bigquery_conn_id: Reference to a specific BigQuery hook.
     :type bigquery_conn_id: str
@@ -119,6 +116,11 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator):
         time_partitioning. The order of columns given determines the sort order.
         Not applicable for external tables.
     :type cluster_fields: list of str
+    :param do_xcom_push: return the max id which also get set in XCOM.
+        (max_id_key must be set)
+        This can be helpful with incremental loads--during future executions, you can pick
+        up from the max ID.
+    :type do_xcom_push: bool
     """
     template_fields = ('bucket', 'source_objects',
                        'schema_object', 'destination_project_dataset_table')
@@ -152,6 +154,7 @@ def __init__(self,
                  external_table=False,
                  time_partitioning=None,
                  cluster_fields=None,
+                 do_xcom_push=True,
                  *args, **kwargs):
 
         super(GoogleCloudStorageToBigQueryOperator, self).__init__(*args, **kwargs)
@@ -190,6 +193,7 @@ def __init__(self,
         self.src_fmt_configs = src_fmt_configs
         self.time_partitioning = time_partitioning
         self.cluster_fields = cluster_fields
+        self.do_xcom_push = do_xcom_push
 
     def execute(self, context):
         bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
@@ -248,7 +252,7 @@ def execute(self, context):
                 time_partitioning=self.time_partitioning,
                 cluster_fields=self.cluster_fields)
 
-        if self.max_id_key:
+        if self.do_xcom_push and self.max_id_key:
             cursor.execute('SELECT MAX({}) FROM {}'.format(
                 self.max_id_key,
                 self.destination_project_dataset_table))
diff --git a/airflow/contrib/operators/gcs_to_s3.py b/airflow/contrib/operators/gcs_to_s3.py
index d8b180c81a..9db7a7ffb5 100644
--- a/airflow/contrib/operators/gcs_to_s3.py
+++ b/airflow/contrib/operators/gcs_to_s3.py
@@ -57,6 +57,8 @@ class GoogleCloudStorageToS3Operator(GoogleCloudStorageListOperator):
                  You can specify this argument if you want to use a different
                  CA cert bundle than the one used by botocore.
     :type dest_verify: bool or str
+    :param do_xcom_push: return the file list which also get set in XCOM
+    :type do_xcom_push: bool
     """
     template_fields = ('bucket', 'prefix', 'delimiter', 'dest_s3_key')
     ui_color = '#f0eee4'
@@ -72,6 +74,7 @@ def __init__(self,
                  dest_s3_key=None,
                  dest_verify=None,
                  replace=False,
+                 do_xcom_push=True,
                  *args,
                  **kwargs):
 
@@ -88,6 +91,7 @@ def __init__(self,
         self.dest_s3_key = dest_s3_key
         self.dest_verify = dest_verify
         self.replace = replace
+        self.do_xcom_push = do_xcom_push
 
     def execute(self, context):
         # use the super to list all files in an Google Cloud Storage bucket
@@ -122,4 +126,5 @@ def execute(self, context):
         else:
             self.log.info("In sync, no files needed to be uploaded to S3")
 
-        return files
+        if self.do_xcom_push:
+            return files
diff --git a/airflow/contrib/operators/hive_to_dynamodb.py b/airflow/contrib/operators/hive_to_dynamodb.py
index 4a39e40741..bd0b937e7e 100644
--- a/airflow/contrib/operators/hive_to_dynamodb.py
+++ b/airflow/contrib/operators/hive_to_dynamodb.py
@@ -70,7 +70,8 @@ def __init__(
             schema='default',
             hiveserver2_conn_id='hiveserver2_default',
             aws_conn_id='aws_default',
-            *args, **kwargs):
+            *args,
+            **kwargs):
         super(HiveToDynamoDBTransferOperator, self).__init__(*args, **kwargs)
         self.sql = sql
         self.table_name = table_name
diff --git a/airflow/contrib/operators/jenkins_job_trigger_operator.py b/airflow/contrib/operators/jenkins_job_trigger_operator.py
index 3e1aba0c1b..027343405a 100644
--- a/airflow/contrib/operators/jenkins_job_trigger_operator.py
+++ b/airflow/contrib/operators/jenkins_job_trigger_operator.py
@@ -106,6 +106,8 @@ class JenkinsJobTriggerOperator(BaseOperator):
     :param max_try_before_job_appears: The maximum number of requests to make
         while waiting for the job to appears on jenkins server (default 10)
     :type max_try_before_job_appears: int
+    :param do_xcom_push: return the build URL which also get set in XCOM
+    :type do_xcom_push: bool
     """
     template_fields = ('parameters',)
     template_ext = ('.json',)
@@ -118,6 +120,7 @@ def __init__(self,
                  parameters="",
                  sleep_time=10,
                  max_try_before_job_appears=10,
+                 do_xcom_push=True,
                  *args,
                  **kwargs):
         super(JenkinsJobTriggerOperator, self).__init__(*args, **kwargs)
@@ -128,6 +131,7 @@ def __init__(self,
         self.sleep_time = sleep_time
         self.jenkins_connection_id = jenkins_connection_id
         self.max_try_before_job_appears = max_try_before_job_appears
+        self.do_xcom_push = do_xcom_push
 
     def build_job(self, jenkins_server):
         """
@@ -243,7 +247,8 @@ def execute(self, context):
                     'this exception for unknown parameters'
                    'You can also check logs for more details on this exception '
                     '(jenkins_url/log/rss)', str(err))
-        if build_info:
+
+        if self.do_xcom_push and build_info:
             # If we can we return the url of the job
             # for later use (like retrieving an artifact)
             return build_info['url']
diff --git a/airflow/contrib/operators/jira_operator.py b/airflow/contrib/operators/jira_operator.py
index 01d78b8645..b165c1cbc3 100644
--- a/airflow/contrib/operators/jira_operator.py
+++ b/airflow/contrib/operators/jira_operator.py
@@ -41,6 +41,8 @@ class JiraOperator(BaseOperator):
     :param get_jira_resource_method: function or operator to get jira resource
                                     on which the provided jira_method will be executed
     :type get_jira_resource_method: function
+    :param do_xcom_push: return the result which also get set in XCOM
+    :type do_xcom_push: bool
     """
 
     template_fields = ("jira_method_args",)
@@ -52,6 +54,7 @@ def __init__(self,
                  jira_method_args=None,
                  result_processor=None,
                  get_jira_resource_method=None,
+                 do_xcom_push=True,
                  *args,
                  **kwargs):
         super(JiraOperator, self).__init__(*args, **kwargs)
@@ -60,6 +63,7 @@ def __init__(self,
         self.jira_method_args = jira_method_args
         self.result_processor = result_processor
         self.get_jira_resource_method = get_jira_resource_method
+        self.do_xcom_push = do_xcom_push
 
     def execute(self, context):
         try:
@@ -82,10 +86,12 @@ def execute(self, context):
             # ex: self.xcom_push(context, key='operator_response', value=jira_response)
             # This could potentially throw error if jira_result is not picklable
             jira_result = getattr(resource, self.method_name)(**self.jira_method_args)
-            if self.result_processor:
-                return self.result_processor(context, jira_result)
 
-            return jira_result
+            if self.do_xcom_push:
+                if self.result_processor:
+                    return self.result_processor(context, jira_result)
+
+                return jira_result
 
         except JIRAError as jira_error:
             raise AirflowException("Failed to execute jiraOperator, error: %s"
diff --git a/airflow/contrib/operators/mlengine_operator.py b/airflow/contrib/operators/mlengine_operator.py
index 8091ceefff..d9a123df56 100644
--- a/airflow/contrib/operators/mlengine_operator.py
+++ b/airflow/contrib/operators/mlengine_operator.py
@@ -150,6 +150,8 @@ class MLEngineBatchPredictionOperator(BaseOperator):
         For this to work, the service account making the request must
         have doamin-wide delegation enabled.
     :type delegate_to: str
+    :param do_xcom_push: return the result which also get set in XCOM
+    :type do_xcom_push: bool
 
     Raises:
         ``ValueError``: if a unique model/version origin cannot be determined.
@@ -181,6 +183,7 @@ def __init__(self,
                  runtime_version=None,
                  gcp_conn_id='google_cloud_default',
                  delegate_to=None,
+                 do_xcom_push=True,
                  *args,
                  **kwargs):
         super(MLEngineBatchPredictionOperator, self).__init__(*args, **kwargs)
@@ -198,6 +201,7 @@ def __init__(self,
         self._runtime_version = runtime_version
         self._gcp_conn_id = gcp_conn_id
         self._delegate_to = delegate_to
+        self.do_xcom_push = do_xcom_push
 
         if not self._project_id:
             raise AirflowException('Google Cloud project id is required.')
@@ -272,7 +276,8 @@ def check_existing_job(existing_job):
                 str(finished_prediction_job)))
             raise RuntimeError(finished_prediction_job['errorMessage'])
 
-        return finished_prediction_job['predictionOutput']
+        if self.do_xcom_push:
+            return finished_prediction_job['predictionOutput']
 
 
 class MLEngineModelOperator(BaseOperator):
@@ -300,6 +305,8 @@ class MLEngineModelOperator(BaseOperator):
         For this to work, the service account making the request must have
         domain-wide delegation enabled.
     :type delegate_to: str
+    :param do_xcom_push: return the result which also get set in XCOM
+    :type do_xcom_push: bool
     """
 
     template_fields = [
@@ -313,6 +320,7 @@ def __init__(self,
                  operation='create',
                  gcp_conn_id='google_cloud_default',
                  delegate_to=None,
+                 do_xcom_push=True,
                  *args,
                  **kwargs):
         super(MLEngineModelOperator, self).__init__(*args, **kwargs)
@@ -321,17 +329,21 @@ def __init__(self,
         self._operation = operation
         self._gcp_conn_id = gcp_conn_id
         self._delegate_to = delegate_to
+        self.do_xcom_push = do_xcom_push
 
     def execute(self, context):
         hook = MLEngineHook(
             gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
         if self._operation == 'create':
-            return hook.create_model(self._project_id, self._model)
+            result = hook.create_model(self._project_id, self._model)
         elif self._operation == 'get':
-            return hook.get_model(self._project_id, self._model['name'])
+            result = hook.get_model(self._project_id, self._model['name'])
         else:
             raise ValueError('Unknown operation: {}'.format(self._operation))
 
+        if self.do_xcom_push:
+            return result
+
 
 class MLEngineVersionOperator(BaseOperator):
     """
@@ -387,6 +399,8 @@ class MLEngineVersionOperator(BaseOperator):
         For this to work, the service account making the request must have
         domain-wide delegation enabled.
     :type delegate_to: str
+    :param do_xcom_push: return the result which also get set in XCOM
+    :type do_xcom_push: bool
     """
 
     template_fields = [
@@ -404,6 +418,7 @@ def __init__(self,
                  operation='create',
                  gcp_conn_id='google_cloud_default',
                  delegate_to=None,
+                 do_xcom_push=True,
                  *args,
                  **kwargs):
 
@@ -415,6 +430,7 @@ def __init__(self,
         self._operation = operation
         self._gcp_conn_id = gcp_conn_id
         self._delegate_to = delegate_to
+        self.do_xcom_push = do_xcom_push
 
     def execute(self, context):
         if 'name' not in self._version:
@@ -427,19 +443,22 @@ def execute(self, context):
             if not self._version:
                 raise ValueError("version attribute of {} could not "
                                  "be empty".format(self.__class__.__name__))
-            return hook.create_version(self._project_id, self._model_name,
-                                       self._version)
+            result = hook.create_version(self._project_id, self._model_name,
+                                         self._version)
         elif self._operation == 'set_default':
-            return hook.set_default_version(self._project_id, self._model_name,
-                                            self._version['name'])
+            result = hook.set_default_version(self._project_id, self._model_name,
+                                              self._version['name'])
         elif self._operation == 'list':
-            return hook.list_versions(self._project_id, self._model_name)
+            result = hook.list_versions(self._project_id, self._model_name)
         elif self._operation == 'delete':
-            return hook.delete_version(self._project_id, self._model_name,
-                                       self._version['name'])
+            result = hook.delete_version(self._project_id, self._model_name,
+                                         self._version['name'])
         else:
             raise ValueError('Unknown operation: {}'.format(self._operation))
 
+        if self.do_xcom_push:
+            return result
+
 
 class MLEngineTrainingOperator(BaseOperator):
     """
@@ -585,8 +604,7 @@ def execute(self, context):
 
         if self._mode == 'DRY_RUN':
             self.log.info('In dry_run mode.')
-            self.log.info('MLEngine Training job request is: {}'.format(
-                training_request))
+            self.log.info('MLEngine Training job request is: {}'.format(training_request))
             return
 
         hook = MLEngineHook(
diff --git a/airflow/contrib/operators/mlengine_prediction_summary.py b/airflow/contrib/operators/mlengine_prediction_summary.py
index def793c1be..227cf74488 100644
--- a/airflow/contrib/operators/mlengine_prediction_summary.py
+++ b/airflow/contrib/operators/mlengine_prediction_summary.py
@@ -113,7 +113,7 @@ def decode(x):
 
 
 @beam.ptransform_fn
-def MakeSummary(pcoll, metric_fn, metric_keys):  # pylint: disable=invalid-name
+def make_summary(pcoll, metric_fn, metric_keys):  # pylint: disable=invalid-name
     return (
         pcoll |
         "ApplyMetricFnPerInstance" >> beam.Map(metric_fn) |
@@ -156,8 +156,7 @@ def run(argv=None):
         raise ValueError("--metric_fn_encoded must be an encoded callable.")
     metric_keys = known_args.metric_keys.split(",")
 
-    with beam.Pipeline(
-        options=beam.pipeline.PipelineOptions(pipeline_args)) as p:
+    with beam.Pipeline(options=beam.pipeline.PipelineOptions(pipeline_args)) as p:
         # This is apache-beam ptransform's convention
         # pylint: disable=no-value-for-parameter
         _ = (p
@@ -165,7 +164,7 @@ def run(argv=None):
                  os.path.join(known_args.prediction_path,
                               "prediction.results-*-of-*"),
                  coder=JsonCoder())
-             | "Summary" >> MakeSummary(metric_fn, metric_keys)
+             | "Summary" >> make_summary(metric_fn, metric_keys)
              | "Write" >> beam.io.WriteToText(
                  os.path.join(known_args.prediction_path,
                               "prediction.summary.json"),
diff --git a/airflow/contrib/operators/mongo_to_s3.py b/airflow/contrib/operators/mongo_to_s3.py
index 8bfa7a52f8..82e8040132 100644
--- a/airflow/contrib/operators/mongo_to_s3.py
+++ b/airflow/contrib/operators/mongo_to_s3.py
@@ -94,8 +94,6 @@ def execute(self, context):
             replace=self.replace
         )
 
-        return True
-
     @staticmethod
     def _stringify(iterable, joinable='\n'):
         """
diff --git a/airflow/contrib/operators/pubsub_operator.py b/airflow/contrib/operators/pubsub_operator.py
index e40828bf92..38f14f2eb7 100644
--- a/airflow/contrib/operators/pubsub_operator.py
+++ b/airflow/contrib/operators/pubsub_operator.py
@@ -159,6 +159,7 @@ def __init__(
             fail_if_exists=False,
             gcp_conn_id='google_cloud_default',
             delegate_to=None,
+            do_xcom_push=True,
             *args,
             **kwargs):
         """
@@ -185,6 +186,8 @@ def __init__(
             For this to work, the service account making the request
             must have domain-wide delegation enabled.
         :type delegate_to: str
+        :param do_xcom_push: return the result which also get set in XCOM
+        :type do_xcom_push: bool
         """
         super(PubSubSubscriptionCreateOperator, self).__init__(*args, **kwargs)
 
@@ -196,16 +199,20 @@ def __init__(
         self.fail_if_exists = fail_if_exists
         self.gcp_conn_id = gcp_conn_id
         self.delegate_to = delegate_to
+        self.do_xcom_push = do_xcom_push
 
     def execute(self, context):
         hook = PubSubHook(gcp_conn_id=self.gcp_conn_id,
                           delegate_to=self.delegate_to)
 
-        return hook.create_subscription(
+        result = hook.create_subscription(
             self.topic_project, self.topic, self.subscription,
             self.subscription_project, self.ack_deadline_secs,
             self.fail_if_exists)
 
+        if self.do_xcom_push:
+            return result
+
 
 class PubSubTopicDeleteOperator(BaseOperator):
     """Delete a PubSub topic.
diff --git a/airflow/contrib/operators/qubole_operator.py b/airflow/contrib/operators/qubole_operator.py
index 82ee293b93..79d12dde75 100755
--- a/airflow/contrib/operators/qubole_operator.py
+++ b/airflow/contrib/operators/qubole_operator.py
@@ -28,6 +28,8 @@ class QuboleOperator(BaseOperator):
 
     :param qubole_conn_id: Connection id which consists of qds auth_token
     :type qubole_conn_id: str
+    :param do_xcom_push: return the result which also get set in XCOM
+    :type do_xcom_push: bool
 
     kwargs:
        :command_type: type of command to be executed, e.g. hivecmd, shellcmd, hadoopcmd
@@ -129,11 +131,16 @@ class QuboleOperator(BaseOperator):
     ui_fgcolor = '#fff'
 
     @apply_defaults
-    def __init__(self, qubole_conn_id="qubole_default", *args, **kwargs):
+    def __init__(self,
+                 qubole_conn_id="qubole_default",
+                 do_xcom_push=True,
+                 *args,
+                 **kwargs):
         self.args = args
         self.kwargs = kwargs
         self.kwargs['qubole_conn_id'] = qubole_conn_id
         super(QuboleOperator, self).__init__(*args, **kwargs)
+        self.do_xcom_push = do_xcom_push
 
         if self.on_failure_callback is None:
             self.on_failure_callback = QuboleHook.handle_failure_retry
@@ -142,7 +149,10 @@ def __init__(self, qubole_conn_id="qubole_default", *args, **kwargs):
             self.on_retry_callback = QuboleHook.handle_failure_retry
 
     def execute(self, context):
-        return self.get_hook().execute(context)
+        result = self.get_hook().execute(context)
+
+        if self.do_xcom_push:
+            return result
 
     def on_kill(self, ti=None):
         self.get_hook().kill(ti)
diff --git a/airflow/contrib/operators/s3_list_operator.py b/airflow/contrib/operators/s3_list_operator.py
index 3ca22d5932..88a5346a36 100644
--- a/airflow/contrib/operators/s3_list_operator.py
+++ b/airflow/contrib/operators/s3_list_operator.py
@@ -48,6 +48,8 @@ class S3ListOperator(BaseOperator):
                  You can specify this argument if you want to use a different
                  CA cert bundle than the one used by botocore.
     :type verify: bool or str
+    :param do_xcom_push: return the key list which also get set in XCOM
+    :type do_xcom_push: bool
 
     **Example**:
         The following operator would list all the files
@@ -72,6 +74,7 @@ def __init__(self,
                  delimiter='',
                  aws_conn_id='aws_default',
                  verify=None,
+                 do_xcom_push=True,
                  *args,
                  **kwargs):
         super(S3ListOperator, self).__init__(*args, **kwargs)
@@ -80,6 +83,7 @@ def __init__(self,
         self.delimiter = delimiter
         self.aws_conn_id = aws_conn_id
         self.verify = verify
+        self.do_xcom_push = do_xcom_push
 
     def execute(self, context):
         hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
@@ -88,7 +92,10 @@ def execute(self, context):
             'Getting the list of files from bucket: {0} in prefix: {1} (Delimiter {2})'.
             format(self.bucket, self.prefix, self.delimiter))
 
-        return hook.list_keys(
+        list_keys = hook.list_keys(
             bucket_name=self.bucket,
             prefix=self.prefix,
             delimiter=self.delimiter)
+
+        if self.do_xcom_push:
+            return list_keys
diff --git a/airflow/contrib/operators/s3_to_gcs_operator.py b/airflow/contrib/operators/s3_to_gcs_operator.py
index 5dd355a6fd..e590e1a252 100644
--- a/airflow/contrib/operators/s3_to_gcs_operator.py
+++ b/airflow/contrib/operators/s3_to_gcs_operator.py
@@ -64,6 +64,8 @@ class S3ToGoogleCloudStorageOperator(S3ListOperator):
     :param replace: Whether you want to replace existing destination files
         or not.
     :type replace: bool
+    :param do_xcom_push: return the file list which also get set in XCOM
+    :type do_xcom_push: bool
 
 
     **Example**:
@@ -97,6 +99,7 @@ def __init__(self,
                  dest_gcs=None,
                  delegate_to=None,
                  replace=False,
+                 do_xcom_push=True,
                  *args,
                  **kwargs):
 
@@ -112,6 +115,7 @@ def __init__(self,
         self.delegate_to = delegate_to
         self.replace = replace
         self.verify = verify
+        self.do_xcom_push = do_xcom_push
 
         if dest_gcs and not self._gcs_object_is_directory(self.dest_gcs):
             self.log.info(
@@ -194,7 +198,8 @@ def execute(self, context):
                 'In sync, no files needed to be uploaded to Google Cloud'
                 'Storage')
 
-        return files
+        if self.do_xcom_push:
+            return files
 
     # Following functionality may be better suited in
     # airflow/contrib/hooks/gcs_hook.py
diff --git a/airflow/contrib/operators/sftp_operator.py b/airflow/contrib/operators/sftp_operator.py
index 620d875f89..fb0ebbd5db 100644
--- a/airflow/contrib/operators/sftp_operator.py
+++ b/airflow/contrib/operators/sftp_operator.py
@@ -117,4 +117,4 @@ def execute(self, context):
             raise AirflowException("Error while transferring {0}, error: {1}"
                                    .format(file_msg, str(e)))
 
-        return None
+        return
diff --git a/airflow/contrib/operators/ssh_operator.py b/airflow/contrib/operators/ssh_operator.py
index 2bf342935d..cede566562 100644
--- a/airflow/contrib/operators/ssh_operator.py
+++ b/airflow/contrib/operators/ssh_operator.py
@@ -45,7 +45,7 @@ class SSHOperator(BaseOperator):
     :type command: str
     :param timeout: timeout (in seconds) for executing the command.
     :type timeout: int
-    :param do_xcom_push: return the stdout which also get set in xcom by airflow platform
+    :param do_xcom_push: return the stdout which also get set in XCOM
     :type do_xcom_push: bool
     """
 
@@ -166,8 +166,6 @@ def execute(self, context):
         except Exception as e:
             raise AirflowException("SSH operator error: {0}".format(str(e)))
 
-        return True
-
     def tunnel(self):
         ssh_client = self.ssh_hook.get_conn()
         ssh_client.get_transport()
diff --git a/airflow/contrib/operators/winrm_operator.py b/airflow/contrib/operators/winrm_operator.py
index c81acac44f..5404c3a8b4 100644
--- a/airflow/contrib/operators/winrm_operator.py
+++ b/airflow/contrib/operators/winrm_operator.py
@@ -48,7 +48,7 @@ class WinRMOperator(BaseOperator):
     :type command: str
     :param timeout: timeout for executing the command.
     :type timeout: int
-    :param do_xcom_push: return the stdout which also get set in xcom by airflow platform
+    :param do_xcom_push: return the stdout which also get set in XCOM
     :type do_xcom_push: bool
     """
     template_fields = ('command',)
@@ -145,5 +145,3 @@ def execute(self, context):
             raise AirflowException(error_msg)
 
         self.log.info("Finished!")
-
-        return True
diff --git a/airflow/models.py b/airflow/models.py
index 6cdfc0fd81..3984f56ec4 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -1641,7 +1641,7 @@ def signal_handler(signum, frame):
                     result = task_copy.execute(context=context)
 
                 # If the task returns a result, push an XCom containing it
-                if result is not None:
+                if task_copy.do_xcom_push and result is not None:
                     self.xcom_push(key=XCOM_RETURN_KEY, value=result)
 
                 task_copy.post_execute(context=context, result=result)
@@ -2419,6 +2419,9 @@ class derived from this one results in the creation of a task object,
             )
 
     :type executor_config: dict
+    :param do_xcom_push: if True, an XCom is pushed containing the Operator's
+        result
+    :type do_xcom_push: bool
     """
 
     # For derived classes to define which fields will get jinjaified
@@ -2473,6 +2476,7 @@ def __init__(
             run_as_user=None,
             task_concurrency=None,
             executor_config=None,
+            do_xcom_push=False,
             inlets=None,
             outlets=None,
             *args,
@@ -2557,6 +2561,7 @@ def __init__(
         self.run_as_user = run_as_user
         self.task_concurrency = task_concurrency
         self.executor_config = executor_config or {}
+        self.do_xcom_push = do_xcom_push
 
         # Private attributes
         self._upstream_task_ids = set()
@@ -2610,6 +2615,7 @@ def __init__(
             'on_failure_callback',
             'on_success_callback',
             'on_retry_callback',
+            'do_xcom_push',
         }
 
     def __eq__(self, other):
diff --git a/tests/contrib/operators/test_gcs_list_operator.py b/tests/contrib/operators/test_gcs_list_operator.py
index 8f0281f76f..10f7d9ddab 100644
--- a/tests/contrib/operators/test_gcs_list_operator.py
+++ b/tests/contrib/operators/test_gcs_list_operator.py
@@ -52,3 +52,19 @@ def test_execute(self, mock_hook):
             bucket=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER
         )
         self.assertEqual(sorted(files), sorted(MOCK_FILES))
+
+    @mock.patch('airflow.contrib.operators.gcs_list_operator.GoogleCloudStorageHook')
+    def test_execute_without_xcom_push(self, mock_hook):
+        mock_hook.return_value.list.return_value = MOCK_FILES
+
+        operator = GoogleCloudStorageListOperator(task_id=TASK_ID,
+                                                  bucket=TEST_BUCKET,
+                                                  prefix=PREFIX,
+                                                  delimiter=DELIMITER,
+                                                  do_xcom_push=False)
+
+        files = operator.execute(None)
+        mock_hook.return_value.list.assert_called_once_with(
+            bucket=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER
+        )
+        self.assertEqual(files, None)
diff --git a/tests/contrib/operators/test_gcs_to_s3_operator.py b/tests/contrib/operators/test_gcs_to_s3_operator.py
index 1a35a223ab..15ef9886b7 100644
--- a/tests/contrib/operators/test_gcs_to_s3_operator.py
+++ b/tests/contrib/operators/test_gcs_to_s3_operator.py
@@ -72,3 +72,29 @@ def test_execute(self, mock_hook, mock_hook2):
                          sorted(uploaded_files))
         self.assertEqual(sorted(MOCK_FILES),
                          sorted(hook.list_keys('bucket', delimiter='/')))
+
+    @mock_s3
+    @mock.patch('airflow.contrib.operators.gcs_list_operator.GoogleCloudStorageHook')
+    @mock.patch('airflow.contrib.operators.gcs_to_s3.GoogleCloudStorageHook')
+    def test_execute_without_xcom_push(self, mock_hook, mock_hook2):
+        mock_hook.return_value.list.return_value = MOCK_FILES
+        mock_hook.return_value.download.return_value = b"testing"
+        mock_hook2.return_value.list.return_value = MOCK_FILES
+
+        operator = GoogleCloudStorageToS3Operator(task_id=TASK_ID,
+                                                  bucket=GCS_BUCKET,
+                                                  prefix=PREFIX,
+                                                  delimiter=DELIMITER,
+                                                  dest_aws_conn_id=None,
+                                                  dest_s3_key=S3_BUCKET,
+                                                  do_xcom_push=False)
+        # create dest bucket
+        hook = S3Hook(aws_conn_id=None)
+        b = hook.get_bucket('bucket')
+        b.create()
+        b.put_object(Key=MOCK_FILES[0], Body=b'testing')
+
+        uploaded_files = operator.execute(None)
+        self.assertEqual(uploaded_files, None)
+        self.assertEqual(sorted(MOCK_FILES),
+                         sorted(hook.list_keys('bucket', delimiter='/')))
diff --git a/tests/models.py b/tests/models.py
index f2d36a263b..5723cea421 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -2264,6 +2264,32 @@ def test_xcom_pull_different_execution_date(self):
                                       include_prior_dates=True),
                          value)
 
+    def test_xcom_push_flag(self):
+        """
+        Tests the option for Operators to push XComs
+        """
+        value = 'hello'
+        task_id = 'test_no_xcom_push'
+        dag = models.DAG(dag_id='test_xcom')
+
+        # nothing saved to XCom
+        task = PythonOperator(
+            task_id=task_id,
+            dag=dag,
+            python_callable=lambda: value,
+            do_xcom_push=False,
+            owner='airflow',
+            start_date=datetime.datetime(2017, 1, 1)
+        )
+        ti = TI(task=task, execution_date=datetime.datetime(2017, 1, 1))
+        ti.run()
+        self.assertEqual(
+            ti.xcom_pull(
+                task_ids=task_id, key=models.XCOM_RETURN_KEY
+            ),
+            None
+        )
+
     def test_post_execute_hook(self):
         """
         Test that post_execute hook is called with the Operator's result.
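
All of the operator changes above follow the same shape. For an operator author working on top of
this branch, the pattern would look roughly like the sketch below (MyFetchOperator and its body
are made up for illustration; only the do_xcom_push handling mirrors the PR):

    from airflow.models import BaseOperator
    from airflow.utils.decorators import apply_defaults


    class MyFetchOperator(BaseOperator):
        """Illustrative only; not part of the diff above."""

        @apply_defaults
        def __init__(self, do_xcom_push=True, *args, **kwargs):
            super(MyFetchOperator, self).__init__(*args, **kwargs)
            self.do_xcom_push = do_xcom_push

        def execute(self, context):
            result = {'rows': 42}  # stand-in for the operator's real work
            # Returning None keeps the value out of XCom; with the models.py change
            # above, the TaskInstance also checks task.do_xcom_push before pushing.
            if self.do_xcom_push:
                return result

With the models.py change, relying on the task-level flag alone would also work; the explicit
check in execute() simply mirrors the contrib operators touched by this PR.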


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services
