Fokko closed pull request #3981: [AIRFLOW-3133] Implement xcom_push flag for contrib's operators
URL: https://github.com/apache/incubator-airflow/pull/3981
This is a PR merged from a forked repository. As GitHub hides the original diff
on merge, it is displayed below for the sake of provenance:
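To illustrate what the new flag is meant to do, here is a minimal, hypothetical DAG snippet
using one of the operators touched by this change; the DAG id, job flow id and step list are
placeholders for illustration only and are not part of this patch:

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.emr_add_steps_operator import EmrAddStepsOperator

    # Purely illustrative DAG; the ids and step definitions below are placeholders.
    dag = DAG('emr_xcom_example', start_date=datetime(2018, 1, 1), schedule_interval=None)

    add_steps = EmrAddStepsOperator(
        task_id='add_steps',
        job_flow_id='j-XXXXXXXX',  # placeholder EMR job flow id
        steps=[],                  # boto3-style step definitions would go here
        xcom_push=False,           # new flag: execute() returns nothing, so no StepIds land in XCom
        dag=dag,
    )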
diff --git a/airflow/contrib/operators/bigquery_get_data.py b/airflow/contrib/operators/bigquery_get_data.py
index f5e6e50f06..f96392945d 100644
--- a/airflow/contrib/operators/bigquery_get_data.py
+++ b/airflow/contrib/operators/bigquery_get_data.py
@@ -66,6 +66,8 @@ class BigQueryGetDataOperator(BaseOperator):
For this to work, the service account making the request must have domain-wide
delegation enabled.
:type delegate_to: str
+ :param xcom_push: return the result which also get set in XCOM
+ :type xcom_push: bool
"""
template_fields = ('dataset_id', 'table_id', 'max_results')
ui_color = '#e4f0e8'
@@ -78,6 +80,7 @@ def __init__(self,
selected_fields=None,
bigquery_conn_id='bigquery_default',
delegate_to=None,
+ xcom_push=True,
*args,
**kwargs):
super(BigQueryGetDataOperator, self).__init__(*args, **kwargs)
@@ -87,6 +90,7 @@ def __init__(self,
self.selected_fields = selected_fields
self.bigquery_conn_id = bigquery_conn_id
self.delegate_to = delegate_to
+ self.xcom_push_flag = xcom_push
def execute(self, context):
self.log.info('Fetching Data from:')
@@ -113,4 +117,5 @@ def execute(self, context):
single_row.append(fields['v'])
table_data.append(single_row)
- return table_data
+ if self.xcom_push_flag:
+ return table_data
diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py
index 3ebc729f78..49b55b4bc8 100644
--- a/airflow/contrib/operators/databricks_operator.py
+++ b/airflow/contrib/operators/databricks_operator.py
@@ -63,11 +63,11 @@ def _handle_databricks_operator_execution(operator, hook, log, context):
:param operator: Databricks operator being handled
:param context: Airflow context
"""
- if operator.do_xcom_push:
+ if operator.xcom_push:
context['ti'].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id)
log.info('Run submitted with run_id: %s', operator.run_id)
run_page_url = hook.get_run_page_url(operator.run_id)
- if operator.do_xcom_push:
+ if operator.xcom_push:
context['ti'].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=run_page_url)
log.info('View run status, Spark UI, and logs at %s', run_page_url)
@@ -209,8 +209,8 @@ class DatabricksSubmitRunOperator(BaseOperator):
:param databricks_retry_delay: Number of seconds to wait between retries (it
might be a floating point number).
:type databricks_retry_delay: float
- :param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
- :type do_xcom_push: bool
+ :param xcom_push: Whether we should push run_id and run_page_url to xcom.
+ :type xcom_push: bool
"""
# Used in airflow.models.BaseOperator
template_fields = ('json',)
@@ -232,7 +232,7 @@ def __init__(
polling_period_seconds=30,
databricks_retry_limit=3,
databricks_retry_delay=1,
- do_xcom_push=False,
+ xcom_push=False,
**kwargs):
"""
Creates a new ``DatabricksSubmitRunOperator``.
@@ -263,7 +263,7 @@ def __init__(
self.json = _deep_string_coerce(self.json)
# This variable will be used in case our task gets killed.
self.run_id = None
- self.do_xcom_push = do_xcom_push
+ self.xcom_push_flag = xcom_push
def get_hook(self):
return DatabricksHook(
@@ -410,8 +410,8 @@ class DatabricksRunNowOperator(BaseOperator):
:param databricks_retry_limit: Amount of times retry if the Databricks backend is
unreachable. Its value must be greater than or equal to 1.
:type databricks_retry_limit: int
- :param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
- :type do_xcom_push: bool
+ :param xcom_push: Whether we should push run_id and run_page_url to xcom.
+ :type xcom_push: bool
"""
# Used in airflow.models.BaseOperator
template_fields = ('json',)
@@ -430,7 +430,7 @@ def __init__(
polling_period_seconds=30,
databricks_retry_limit=3,
databricks_retry_delay=1,
- do_xcom_push=False,
+ xcom_push=False,
**kwargs):
"""
@@ -455,7 +455,7 @@ def __init__(
self.json = _deep_string_coerce(self.json)
# This variable will be used in case our task gets killed.
self.run_id = None
- self.do_xcom_push = do_xcom_push
+ self.xcom_push_flag = xcom_push
def get_hook(self):
return DatabricksHook(
diff --git a/airflow/contrib/operators/dataflow_operator.py b/airflow/contrib/operators/dataflow_operator.py
index 7f7a18495d..c10baec081 100644
--- a/airflow/contrib/operators/dataflow_operator.py
+++ b/airflow/contrib/operators/dataflow_operator.py
@@ -311,7 +311,6 @@ def __init__(
poll_sleep=10,
*args,
**kwargs):
-
super(DataFlowPythonOperator, self).__init__(*args, **kwargs)
self.py_file = py_file
@@ -335,9 +334,11 @@ def execute(self, context):
poll_sleep=self.poll_sleep)
dataflow_options = self.dataflow_default_options.copy()
dataflow_options.update(self.options)
+
# Convert argument names from lowerCamelCase to snake case.
- camel_to_snake = lambda name: re.sub(
- r'[A-Z]', lambda x: '_' + x.group(0).lower(), name)
+ def camel_to_snake(name):
+ return re.sub(r'[A-Z]', lambda x: '_' + x.group(0).lower(), name)
+
formatted_options = {camel_to_snake(key): dataflow_options[key]
for key in dataflow_options}
hook.start_python_dataflow(
diff --git a/airflow/contrib/operators/dataproc_operator.py b/airflow/contrib/operators/dataproc_operator.py
index 49e24a3df2..65181799bc 100644
--- a/airflow/contrib/operators/dataproc_operator.py
+++ b/airflow/contrib/operators/dataproc_operator.py
@@ -221,14 +221,14 @@ def _get_cluster(self, service):
cluster = [c for c in cluster_list if c['clusterName'] ==
self.cluster_name]
if cluster:
return cluster[0]
- return None
+ return
def _get_cluster_state(self, service):
cluster = self._get_cluster(service)
if 'status' in cluster:
return cluster['status']['state']
else:
- return None
+ return
def _cluster_ready(self, state, service):
if state == 'RUNNING':
@@ -394,7 +394,7 @@ def execute(self, context):
self.cluster_name
)
self._wait_for_done(service)
- return True
+ return
cluster_data = self._build_cluster_data()
try:
@@ -412,7 +412,7 @@ def execute(self, context):
self.cluster_name
)
self._wait_for_done(service)
- return True
+ return
else:
raise e
diff --git a/airflow/contrib/operators/datastore_export_operator.py b/airflow/contrib/operators/datastore_export_operator.py
index 9d95eadc74..0fa5790922 100644
--- a/airflow/contrib/operators/datastore_export_operator.py
+++ b/airflow/contrib/operators/datastore_export_operator.py
@@ -82,7 +82,7 @@ def __init__(self,
self.labels = labels
self.polling_interval_in_seconds = polling_interval_in_seconds
self.overwrite_existing = overwrite_existing
- self.xcom_push = xcom_push
+ self.xcom_push_flag = xcom_push
def execute(self, context):
self.log.info('Exporting data to Cloud Storage bucket ' + self.bucket)
@@ -106,5 +106,5 @@ def execute(self, context):
if state != 'SUCCESSFUL':
raise AirflowException('Operation failed: result={}'.format(result))
- if self.xcom_push:
+ if self.xcom_push_flag:
return result
diff --git a/airflow/contrib/operators/datastore_import_operator.py b/airflow/contrib/operators/datastore_import_operator.py
index c79767f35e..24e61da044 100644
--- a/airflow/contrib/operators/datastore_import_operator.py
+++ b/airflow/contrib/operators/datastore_import_operator.py
@@ -76,7 +76,7 @@ def __init__(self,
self.entity_filter = entity_filter
self.labels = labels
self.polling_interval_in_seconds = polling_interval_in_seconds
- self.xcom_push = xcom_push
+ self.xcom_push_flag = xcom_push
def execute(self, context):
self.log.info('Importing data from Cloud Storage bucket %s',
self.bucket)
@@ -94,5 +94,5 @@ def execute(self, context):
if state != 'SUCCESSFUL':
raise AirflowException('Operation failed: result={}'.format(result))
- if self.xcom_push:
+ if self.xcom_push_flag:
return result
diff --git a/airflow/contrib/operators/emr_add_steps_operator.py b/airflow/contrib/operators/emr_add_steps_operator.py
index 959543e617..74ad6742ba 100644
--- a/airflow/contrib/operators/emr_add_steps_operator.py
+++ b/airflow/contrib/operators/emr_add_steps_operator.py
@@ -32,6 +32,8 @@ class EmrAddStepsOperator(BaseOperator):
:type aws_conn_id: str
:param steps: boto3 style steps to be added to the jobflow. (templated)
:type steps: list
+ :param xcom_push: return the Step IDs which also get set in XCOM
+ :type xcom_push: bool
"""
template_fields = ['job_flow_id', 'steps']
template_ext = ()
@@ -43,12 +45,14 @@ def __init__(
job_flow_id,
aws_conn_id='s3_default',
steps=None,
+ xcom_push=True,
*args, **kwargs):
super(EmrAddStepsOperator, self).__init__(*args, **kwargs)
steps = steps or []
self.job_flow_id = job_flow_id
self.aws_conn_id = aws_conn_id
self.steps = steps
+ self.xcom_push_flag = xcom_push
def execute(self, context):
emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
@@ -60,4 +64,5 @@ def execute(self, context):
raise AirflowException('Adding steps failed: %s' % response)
else:
self.log.info('Steps %s added to JobFlow', response['StepIds'])
- return response['StepIds']
+ if self.xcom_push_flag:
+ return response['StepIds']
diff --git a/airflow/contrib/operators/emr_create_job_flow_operator.py b/airflow/contrib/operators/emr_create_job_flow_operator.py
index 42886d6006..e6445f7684 100644
--- a/airflow/contrib/operators/emr_create_job_flow_operator.py
+++ b/airflow/contrib/operators/emr_create_job_flow_operator.py
@@ -35,6 +35,8 @@ class EmrCreateJobFlowOperator(BaseOperator):
:param job_flow_overrides: boto3 style arguments to override
emr_connection extra. (templated)
:type steps: dict
+ :param xcom_push: return the status code which also get set in XCOM
+ :type xcom_push: bool
"""
template_fields = ['job_flow_overrides']
template_ext = ()
@@ -46,6 +48,7 @@ def __init__(
aws_conn_id='s3_default',
emr_conn_id='emr_default',
job_flow_overrides=None,
+ xcom_push=True,
*args, **kwargs):
super(EmrCreateJobFlowOperator, self).__init__(*args, **kwargs)
self.aws_conn_id = aws_conn_id
@@ -53,6 +56,7 @@ def __init__(
if job_flow_overrides is None:
job_flow_overrides = {}
self.job_flow_overrides = job_flow_overrides
+ self.xcom_push_flag = xcom_push
def execute(self, context):
emr = EmrHook(aws_conn_id=self.aws_conn_id,
emr_conn_id=self.emr_conn_id)
@@ -67,4 +71,5 @@ def execute(self, context):
raise AirflowException('JobFlow creation failed: %s' % response)
else:
self.log.info('JobFlow with id %s created', response['JobFlowId'])
- return response['JobFlowId']
+ if self.xcom_push_flag:
+ return response['JobFlowId']
diff --git a/airflow/contrib/operators/gcp_container_operator.py b/airflow/contrib/operators/gcp_container_operator.py
index fda4d44b9d..9a755ff293 100644
--- a/airflow/contrib/operators/gcp_container_operator.py
+++ b/airflow/contrib/operators/gcp_container_operator.py
@@ -133,6 +133,9 @@ class GKEClusterCreateOperator(BaseOperator):
:type gcp_conn_id: str
:param api_version: The api version to use
:type api_version: str
+ :param xcom_push: return the result of cluster creation operation which also get
+ set in XCOM
+ :type xcom_push: bool
"""
template_fields = ['project_id', 'gcp_conn_id', 'location', 'api_version',
'body']
@@ -143,6 +146,7 @@ def __init__(self,
body=None,
gcp_conn_id='google_cloud_default',
api_version='v2',
+ xcom_push=True,
*args,
**kwargs):
super(GKEClusterCreateOperator, self).__init__(*args, **kwargs)
@@ -154,6 +158,7 @@ def __init__(self,
self.location = location
self.api_version = api_version
self.body = body
+ self.xcom_push_flag = xcom_push
def _check_input(self):
if all([self.project_id, self.location, self.body]):
@@ -175,7 +180,9 @@ def execute(self, context):
self._check_input()
hook = GKEClusterHook(self.project_id, self.location)
create_op = hook.create_cluster(cluster=self.body)
- return create_op
+
+ if self.xcom_push_flag:
+ return create_op
KUBE_CONFIG_ENV_VAR = "KUBECONFIG"
diff --git a/airflow/contrib/operators/gcp_function_operator.py b/airflow/contrib/operators/gcp_function_operator.py
index 4455307c93..4bd1a94ad2 100644
--- a/airflow/contrib/operators/gcp_function_operator.py
+++ b/airflow/contrib/operators/gcp_function_operator.py
@@ -527,6 +527,8 @@ class GcfFunctionDeleteOperator(BaseOperator):
:type gcp_conn_id: str
:param api_version: Version of the API used (for example v1).
:type api_version: str
+ :param xcom_push: return the file list which also get set in XCOM
+ :type xcom_push: bool
"""
@apply_defaults
@@ -534,10 +536,13 @@ def __init__(self,
name,
gcp_conn_id='google_cloud_default',
api_version='v1',
- *args, **kwargs):
+ xcom_push=True,
+ *args,
+ **kwargs):
self.name = name
self.gcp_conn_id = gcp_conn_id
self.api_version = api_version
+ self.xcom_push_flag = xcom_push
self._validate_inputs()
self.hook = GcfHook(gcp_conn_id=self.gcp_conn_id,
api_version=self.api_version)
super(GcfFunctionDeleteOperator, self).__init__(*args, **kwargs)
@@ -553,7 +558,10 @@ def _validate_inputs(self):
def execute(self, context):
try:
- return self.hook.delete_function(self.name)
+ result = self.hook.delete_function(self.name)
+
+ if self.xcom_push_flag:
+ return result
except HttpError as e:
status = e.resp.status
if status == 404:
diff --git a/airflow/contrib/operators/gcs_list_operator.py b/airflow/contrib/operators/gcs_list_operator.py
index 7b37b269a6..ed9311f5d2 100644
--- a/airflow/contrib/operators/gcs_list_operator.py
+++ b/airflow/contrib/operators/gcs_list_operator.py
@@ -45,6 +45,9 @@ class GoogleCloudStorageListOperator(BaseOperator):
For this to work, the service account making the request must have
domain-wide delegation enabled.
:type delegate_to: str
+ :param xcom_push: return the list of Google Cloud Storage Hook which also get set
+ in XCOM
+ :type xcom_push: bool
**Example**:
The following Operator would list all the Avro files from
``sales/sales-2017``
@@ -68,6 +71,7 @@ def __init__(self,
delimiter=None,
google_cloud_storage_conn_id='google_cloud_default',
delegate_to=None,
+ xcom_push=True,
*args,
**kwargs):
super(GoogleCloudStorageListOperator, self).__init__(*args, **kwargs)
@@ -76,6 +80,7 @@ def __init__(self,
self.delimiter = delimiter
self.google_cloud_storage_conn_id = google_cloud_storage_conn_id
self.delegate_to = delegate_to
+ self.xcom_push_flag = xcom_push
def execute(self, context):
@@ -87,6 +92,9 @@ def execute(self, context):
self.log.info('Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s',
self.bucket, self.delimiter, self.prefix)
- return hook.list(bucket=self.bucket,
- prefix=self.prefix,
- delimiter=self.delimiter)
+ files = hook.list(bucket=self.bucket,
+ prefix=self.prefix,
+ delimiter=self.delimiter)
+
+ if self.xcom_push_flag or __name__ != 'GoogleCloudStorageListOperator':
+ return files
diff --git a/airflow/contrib/operators/gcs_to_bq.py b/airflow/contrib/operators/gcs_to_bq.py
index 39dff21606..1542171ac4 100644
--- a/airflow/contrib/operators/gcs_to_bq.py
+++ b/airflow/contrib/operators/gcs_to_bq.py
@@ -87,10 +87,7 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator):
:type allow_jagged_rows: bool
:param max_id_key: If set, the name of a column in the BigQuery table
that's to be loaded. This will be used to select the MAX value from
- BigQuery after the load occurs. The results will be returned by the
- execute() command, which in turn gets stored in XCom for future
- operators to use. This can be helpful with incremental loads--during
- future executions, you can pick up from the max ID.
+ BigQuery after the load occurs. (used in combination with xcom_push)
:type max_id_key: str
:param bigquery_conn_id: Reference to a specific BigQuery hook.
:type bigquery_conn_id: str
@@ -119,6 +116,11 @@ class GoogleCloudStorageToBigQueryOperator(BaseOperator):
time_partitioning. The order of columns given determines the sort order.
Not applicable for external tables.
:type cluster_fields: list of str
+ :param xcom_push: return the max id which also get set in XCOM. (max_id_key must be set)
+ This can be helpful with incremental loads--during future executions, you can pick
+ up from the max ID.
+ :type xcom_push: bool
"""
template_fields = ('bucket', 'source_objects',
'schema_object', 'destination_project_dataset_table')
@@ -152,6 +154,7 @@ def __init__(self,
external_table=False,
time_partitioning=None,
cluster_fields=None,
+ xcom_push=True,
*args, **kwargs):
super(GoogleCloudStorageToBigQueryOperator, self).__init__(*args,
**kwargs)
@@ -190,6 +193,7 @@ def __init__(self,
self.src_fmt_configs = src_fmt_configs
self.time_partitioning = time_partitioning
self.cluster_fields = cluster_fields
+ self.xcom_push_flag = xcom_push
def execute(self, context):
bq_hook = BigQueryHook(bigquery_conn_id=self.bigquery_conn_id,
@@ -248,7 +252,7 @@ def execute(self, context):
time_partitioning=self.time_partitioning,
cluster_fields=self.cluster_fields)
- if self.max_id_key:
+ if self.xcom_push_flag and self.max_id_key:
cursor.execute('SELECT MAX({}) FROM {}'.format(
self.max_id_key,
self.destination_project_dataset_table))
diff --git a/airflow/contrib/operators/gcs_to_s3.py b/airflow/contrib/operators/gcs_to_s3.py
index d8b180c81a..9c23be771b 100644
--- a/airflow/contrib/operators/gcs_to_s3.py
+++ b/airflow/contrib/operators/gcs_to_s3.py
@@ -57,6 +57,8 @@ class GoogleCloudStorageToS3Operator(GoogleCloudStorageListOperator):
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:type dest_verify: bool or str
+ :param xcom_push: return the file list which also get set in XCOM
+ :type xcom_push: bool
"""
template_fields = ('bucket', 'prefix', 'delimiter', 'dest_s3_key')
ui_color = '#f0eee4'
@@ -72,6 +74,7 @@ def __init__(self,
dest_s3_key=None,
dest_verify=None,
replace=False,
+ xcom_push=True,
*args,
**kwargs):
@@ -88,6 +91,7 @@ def __init__(self,
self.dest_s3_key = dest_s3_key
self.dest_verify = dest_verify
self.replace = replace
+ self.xcom_push_flag = xcom_push
def execute(self, context):
# use the super to list all files in an Google Cloud Storage bucket
@@ -122,4 +126,5 @@ def execute(self, context):
else:
self.log.info("In sync, no files needed to be uploaded to S3")
- return files
+ if self.xcom_push_flag:
+ return files
diff --git a/airflow/contrib/operators/hive_to_dynamodb.py b/airflow/contrib/operators/hive_to_dynamodb.py
index 4a39e40741..bd0b937e7e 100644
--- a/airflow/contrib/operators/hive_to_dynamodb.py
+++ b/airflow/contrib/operators/hive_to_dynamodb.py
@@ -70,7 +70,8 @@ def __init__(
schema='default',
hiveserver2_conn_id='hiveserver2_default',
aws_conn_id='aws_default',
- *args, **kwargs):
+ *args,
+ **kwargs):
super(HiveToDynamoDBTransferOperator, self).__init__(*args, **kwargs)
self.sql = sql
self.table_name = table_name
diff --git a/airflow/contrib/operators/jenkins_job_trigger_operator.py b/airflow/contrib/operators/jenkins_job_trigger_operator.py
index 3e1aba0c1b..0a6e71fa7b 100644
--- a/airflow/contrib/operators/jenkins_job_trigger_operator.py
+++ b/airflow/contrib/operators/jenkins_job_trigger_operator.py
@@ -106,6 +106,8 @@ class JenkinsJobTriggerOperator(BaseOperator):
:param max_try_before_job_appears: The maximum number of requests to make
while waiting for the job to appears on jenkins server (default 10)
:type max_try_before_job_appears: int
+ :param xcom_push: return the build URL which also get set in XCOM
+ :type xcom_push: bool
"""
template_fields = ('parameters',)
template_ext = ('.json',)
@@ -118,6 +120,7 @@ def __init__(self,
parameters="",
sleep_time=10,
max_try_before_job_appears=10,
+ xcom_push=True,
*args,
**kwargs):
super(JenkinsJobTriggerOperator, self).__init__(*args, **kwargs)
@@ -128,6 +131,7 @@ def __init__(self,
self.sleep_time = sleep_time
self.jenkins_connection_id = jenkins_connection_id
self.max_try_before_job_appears = max_try_before_job_appears
+ self.xcom_push_flag = xcom_push
def build_job(self, jenkins_server):
"""
@@ -243,7 +247,8 @@ def execute(self, context):
'this exception for unknown parameters'
'You can also check logs for more details on this exception '
'(jenkins_url/log/rss)', str(err))
- if build_info:
+
+ if self.xcom_push_flag and build_info:
# If we can we return the url of the job
# for later use (like retrieving an artifact)
return build_info['url']
diff --git a/airflow/contrib/operators/jira_operator.py b/airflow/contrib/operators/jira_operator.py
index 01d78b8645..5bf81769d6 100644
--- a/airflow/contrib/operators/jira_operator.py
+++ b/airflow/contrib/operators/jira_operator.py
@@ -41,6 +41,8 @@ class JiraOperator(BaseOperator):
:param get_jira_resource_method: function or operator to get jira resource
on which the provided jira_method will be executed
:type get_jira_resource_method: function
+ :param xcom_push: return the result which also get set in XCOM
+ :type xcom_push: bool
"""
template_fields = ("jira_method_args",)
@@ -52,6 +54,7 @@ def __init__(self,
jira_method_args=None,
result_processor=None,
get_jira_resource_method=None,
+ xcom_push=True,
*args,
**kwargs):
super(JiraOperator, self).__init__(*args, **kwargs)
@@ -60,6 +63,7 @@ def __init__(self,
self.jira_method_args = jira_method_args
self.result_processor = result_processor
self.get_jira_resource_method = get_jira_resource_method
+ self.xcom_push_flag = xcom_push
def execute(self, context):
try:
@@ -82,10 +86,12 @@ def execute(self, context):
# ex: self.xcom_push(context, key='operator_response',
value=jira_response)
# This could potentially throw error if jira_result is not picklable
jira_result = getattr(resource,
self.method_name)(**self.jira_method_args)
- if self.result_processor:
- return self.result_processor(context, jira_result)
- return jira_result
+ if self.xcom_push_flag:
+ if self.result_processor:
+ return self.result_processor(context, jira_result)
+
+ return jira_result
except JIRAError as jira_error:
raise AirflowException("Failed to execute jiraOperator, error: %s"
diff --git a/airflow/contrib/operators/kubernetes_pod_operator.py b/airflow/contrib/operators/kubernetes_pod_operator.py
index d3c396e668..2c0eed9095 100644
--- a/airflow/contrib/operators/kubernetes_pod_operator.py
+++ b/airflow/contrib/operators/kubernetes_pod_operator.py
@@ -126,7 +126,7 @@ def execute(self, context):
raise AirflowException(
'Pod returned a failure: {state}'.format(state=final_state)
)
- if self.xcom_push:
+ if self.xcom_push_flag:
return result
except AirflowException as ex:
raise AirflowException('Pod Launching failed: {error}'.format(error=ex))
@@ -179,7 +179,7 @@ def __init__(self,
self.node_selectors = node_selectors or {}
self.annotations = annotations or {}
self.affinity = affinity or {}
- self.xcom_push = xcom_push
+ self.xcom_push_flag = xcom_push
self.resources = resources or Resources()
self.config_file = config_file
self.image_pull_secrets = image_pull_secrets
diff --git a/airflow/contrib/operators/mlengine_operator.py b/airflow/contrib/operators/mlengine_operator.py
index 8091ceefff..21f18255ba 100644
--- a/airflow/contrib/operators/mlengine_operator.py
+++ b/airflow/contrib/operators/mlengine_operator.py
@@ -150,6 +150,8 @@ class MLEngineBatchPredictionOperator(BaseOperator):
For this to work, the service account making the request must
have doamin-wide delegation enabled.
:type delegate_to: str
+ :param xcom_push: return the result which also get set in XCOM
+ :type xcom_push: bool
Raises:
``ValueError``: if a unique model/version origin cannot be determined.
@@ -181,6 +183,7 @@ def __init__(self,
runtime_version=None,
gcp_conn_id='google_cloud_default',
delegate_to=None,
+ xcom_push=True,
*args,
**kwargs):
super(MLEngineBatchPredictionOperator, self).__init__(*args, **kwargs)
@@ -198,6 +201,7 @@ def __init__(self,
self._runtime_version = runtime_version
self._gcp_conn_id = gcp_conn_id
self._delegate_to = delegate_to
+ self.xcom_push_flag = xcom_push
if not self._project_id:
raise AirflowException('Google Cloud project id is required.')
@@ -272,7 +276,8 @@ def check_existing_job(existing_job):
str(finished_prediction_job)))
raise RuntimeError(finished_prediction_job['errorMessage'])
- return finished_prediction_job['predictionOutput']
+ if self.xcom_push_flag:
+ return finished_prediction_job['predictionOutput']
class MLEngineModelOperator(BaseOperator):
@@ -300,6 +305,8 @@ class MLEngineModelOperator(BaseOperator):
For this to work, the service account making the request must have
domain-wide delegation enabled.
:type delegate_to: str
+ :param xcom_push: return the result which also get set in XCOM
+ :type xcom_push: bool
"""
template_fields = [
@@ -313,6 +320,7 @@ def __init__(self,
operation='create',
gcp_conn_id='google_cloud_default',
delegate_to=None,
+ xcom_push=True,
*args,
**kwargs):
super(MLEngineModelOperator, self).__init__(*args, **kwargs)
@@ -321,17 +329,21 @@ def __init__(self,
self._operation = operation
self._gcp_conn_id = gcp_conn_id
self._delegate_to = delegate_to
+ self.xcom_push_flag = xcom_push
def execute(self, context):
hook = MLEngineHook(
gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
if self._operation == 'create':
- return hook.create_model(self._project_id, self._model)
+ result = hook.create_model(self._project_id, self._model)
elif self._operation == 'get':
- return hook.get_model(self._project_id, self._model['name'])
+ result = hook.get_model(self._project_id, self._model['name'])
else:
raise ValueError('Unknown operation: {}'.format(self._operation))
+ if self.xcom_push_flag:
+ return result
+
class MLEngineVersionOperator(BaseOperator):
"""
@@ -387,6 +399,8 @@ class MLEngineVersionOperator(BaseOperator):
For this to work, the service account making the request must have
domain-wide delegation enabled.
:type delegate_to: str
+ :param xcom_push: return the result which also get set in XCOM
+ :type xcom_push: bool
"""
template_fields = [
@@ -404,6 +418,7 @@ def __init__(self,
operation='create',
gcp_conn_id='google_cloud_default',
delegate_to=None,
+ xcom_push=True,
*args,
**kwargs):
@@ -415,6 +430,7 @@ def __init__(self,
self._operation = operation
self._gcp_conn_id = gcp_conn_id
self._delegate_to = delegate_to
+ self.xcom_push_flag = xcom_push
def execute(self, context):
if 'name' not in self._version:
@@ -427,19 +443,22 @@ def execute(self, context):
if not self._version:
raise ValueError("version attribute of {} could not "
"be empty".format(self.__class__.__name__))
- return hook.create_version(self._project_id, self._model_name,
- self._version)
+ result = hook.create_version(self._project_id, self._model_name,
+ self._version)
elif self._operation == 'set_default':
- return hook.set_default_version(self._project_id, self._model_name,
- self._version['name'])
+ result = hook.set_default_version(self._project_id, self._model_name,
+ self._version['name'])
elif self._operation == 'list':
- return hook.list_versions(self._project_id, self._model_name)
+ result = hook.list_versions(self._project_id, self._model_name)
elif self._operation == 'delete':
- return hook.delete_version(self._project_id, self._model_name,
- self._version['name'])
+ result = hook.delete_version(self._project_id, self._model_name,
+ self._version['name'])
else:
raise ValueError('Unknown operation: {}'.format(self._operation))
+ if self.xcom_push_flag:
+ return result
+
class MLEngineTrainingOperator(BaseOperator):
"""
@@ -585,8 +604,7 @@ def execute(self, context):
if self._mode == 'DRY_RUN':
self.log.info('In dry_run mode.')
- self.log.info('MLEngine Training job request is: {}'.format(
- training_request))
+ self.log.info('MLEngine Training job request is: {}'.format(training_request))
return
hook = MLEngineHook(
diff --git a/airflow/contrib/operators/mlengine_prediction_summary.py b/airflow/contrib/operators/mlengine_prediction_summary.py
index def793c1be..227cf74488 100644
--- a/airflow/contrib/operators/mlengine_prediction_summary.py
+++ b/airflow/contrib/operators/mlengine_prediction_summary.py
@@ -113,7 +113,7 @@ def decode(x):
@beam.ptransform_fn
-def MakeSummary(pcoll, metric_fn, metric_keys): # pylint: disable=invalid-name
+def make_summary(pcoll, metric_fn, metric_keys): # pylint: disable=invalid-name
return (
pcoll |
"ApplyMetricFnPerInstance" >> beam.Map(metric_fn) |
@@ -156,8 +156,7 @@ def run(argv=None):
raise ValueError("--metric_fn_encoded must be an encoded callable.")
metric_keys = known_args.metric_keys.split(",")
- with beam.Pipeline(
- options=beam.pipeline.PipelineOptions(pipeline_args)) as p:
+ with beam.Pipeline(options=beam.pipeline.PipelineOptions(pipeline_args)) as p:
# This is apache-beam ptransform's convention
# pylint: disable=no-value-for-parameter
_ = (p
@@ -165,7 +164,7 @@ def run(argv=None):
os.path.join(known_args.prediction_path,
"prediction.results-*-of-*"),
coder=JsonCoder())
- | "Summary" >> MakeSummary(metric_fn, metric_keys)
+ | "Summary" >> make_summary(metric_fn, metric_keys)
| "Write" >> beam.io.WriteToText(
os.path.join(known_args.prediction_path,
"prediction.summary.json"),
diff --git a/airflow/contrib/operators/mongo_to_s3.py b/airflow/contrib/operators/mongo_to_s3.py
index 8bfa7a52f8..82e8040132 100644
--- a/airflow/contrib/operators/mongo_to_s3.py
+++ b/airflow/contrib/operators/mongo_to_s3.py
@@ -94,8 +94,6 @@ def execute(self, context):
replace=self.replace
)
- return True
-
@staticmethod
def _stringify(iterable, joinable='\n'):
"""
diff --git a/airflow/contrib/operators/pubsub_operator.py b/airflow/contrib/operators/pubsub_operator.py
index e40828bf92..9fabf473b6 100644
--- a/airflow/contrib/operators/pubsub_operator.py
+++ b/airflow/contrib/operators/pubsub_operator.py
@@ -159,6 +159,7 @@ def __init__(
fail_if_exists=False,
gcp_conn_id='google_cloud_default',
delegate_to=None,
+ xcom_push=True,
*args,
**kwargs):
"""
@@ -185,6 +186,8 @@ def __init__(
For this to work, the service account making the request
must have domain-wide delegation enabled.
:type delegate_to: str
+ :param xcom_push: return the result which also get set in XCOM
+ :type xcom_push: bool
"""
super(PubSubSubscriptionCreateOperator, self).__init__(*args, **kwargs)
@@ -196,16 +199,20 @@ def __init__(
self.fail_if_exists = fail_if_exists
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
+ self.xcom_push_flag = xcom_push
def execute(self, context):
hook = PubSubHook(gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to)
- return hook.create_subscription(
+ result = hook.create_subscription(
self.topic_project, self.topic, self.subscription,
self.subscription_project, self.ack_deadline_secs,
self.fail_if_exists)
+ if self.xcom_push_flag:
+ return result
+
class PubSubTopicDeleteOperator(BaseOperator):
"""Delete a PubSub topic.
diff --git a/airflow/contrib/operators/qubole_operator.py b/airflow/contrib/operators/qubole_operator.py
index 82ee293b93..d86c1b35ea 100755
--- a/airflow/contrib/operators/qubole_operator.py
+++ b/airflow/contrib/operators/qubole_operator.py
@@ -28,6 +28,8 @@ class QuboleOperator(BaseOperator):
:param qubole_conn_id: Connection id which consists of qds auth_token
:type qubole_conn_id: str
+ :param xcom_push: return the result which also get set in XCOM
+ :type xcom_push: bool
kwargs:
:command_type: type of command to be executed, e.g. hivecmd, shellcmd,
hadoopcmd
@@ -129,11 +131,16 @@ class QuboleOperator(BaseOperator):
ui_fgcolor = '#fff'
@apply_defaults
- def __init__(self, qubole_conn_id="qubole_default", *args, **kwargs):
+ def __init__(self,
+ qubole_conn_id="qubole_default",
+ xcom_push=True,
+ *args,
+ **kwargs):
self.args = args
self.kwargs = kwargs
self.kwargs['qubole_conn_id'] = qubole_conn_id
super(QuboleOperator, self).__init__(*args, **kwargs)
+ self.xcom_push_flag = xcom_push
if self.on_failure_callback is None:
self.on_failure_callback = QuboleHook.handle_failure_retry
@@ -142,7 +149,10 @@ def __init__(self, qubole_conn_id="qubole_default", *args, **kwargs):
self.on_retry_callback = QuboleHook.handle_failure_retry
def execute(self, context):
- return self.get_hook().execute(context)
+ result = self.get_hook().execute(context)
+
+ if self.xcom_push_flag:
+ return result
def on_kill(self, ti=None):
self.get_hook().kill(ti)
diff --git a/airflow/contrib/operators/s3_list_operator.py b/airflow/contrib/operators/s3_list_operator.py
index 3ca22d5932..c66e816c21 100644
--- a/airflow/contrib/operators/s3_list_operator.py
+++ b/airflow/contrib/operators/s3_list_operator.py
@@ -48,6 +48,8 @@ class S3ListOperator(BaseOperator):
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:type verify: bool or str
+ :param xcom_push: return the key list which also get set in XCOM
+ :type xcom_push: bool
**Example**:
The following operator would list all the files
@@ -72,6 +74,7 @@ def __init__(self,
delimiter='',
aws_conn_id='aws_default',
verify=None,
+ xcom_push=True,
*args,
**kwargs):
super(S3ListOperator, self).__init__(*args, **kwargs)
@@ -80,6 +83,7 @@ def __init__(self,
self.delimiter = delimiter
self.aws_conn_id = aws_conn_id
self.verify = verify
+ self.xcom_push_flag = xcom_push
def execute(self, context):
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
@@ -88,7 +92,10 @@ def execute(self, context):
'Getting the list of files from bucket: {0} in prefix: {1} (Delimiter {2})'.
format(self.bucket, self.prefix, self.delimiter))
- return hook.list_keys(
+ list_keys = hook.list_keys(
bucket_name=self.bucket,
prefix=self.prefix,
delimiter=self.delimiter)
+
+ if self.xcom_push_flag:
+ return list_keys
diff --git a/airflow/contrib/operators/s3_to_gcs_operator.py b/airflow/contrib/operators/s3_to_gcs_operator.py
index 5dd355a6fd..f7d14ef2a5 100644
--- a/airflow/contrib/operators/s3_to_gcs_operator.py
+++ b/airflow/contrib/operators/s3_to_gcs_operator.py
@@ -64,6 +64,8 @@ class S3ToGoogleCloudStorageOperator(S3ListOperator):
:param replace: Whether you want to replace existing destination files
or not.
:type replace: bool
+ :param xcom_push: return the file list which also get set in XCOM
+ :type xcom_push: bool
**Example**:
@@ -97,6 +99,7 @@ def __init__(self,
dest_gcs=None,
delegate_to=None,
replace=False,
+ xcom_push=True,
*args,
**kwargs):
@@ -112,6 +115,7 @@ def __init__(self,
self.delegate_to = delegate_to
self.replace = replace
self.verify = verify
+ self.xcom_push_flag = xcom_push
if dest_gcs and not self._gcs_object_is_directory(self.dest_gcs):
self.log.info(
@@ -194,7 +198,8 @@ def execute(self, context):
'In sync, no files needed to be uploaded to Google Cloud'
'Storage')
- return files
+ if self.xcom_push_flag:
+ return files
# Following functionality may be better suited in
# airflow/contrib/hooks/gcs_hook.py
diff --git a/airflow/contrib/operators/sftp_operator.py b/airflow/contrib/operators/sftp_operator.py
index a3b5c1f244..a2180fa679 100644
--- a/airflow/contrib/operators/sftp_operator.py
+++ b/airflow/contrib/operators/sftp_operator.py
@@ -117,4 +117,4 @@ def execute(self, context):
raise AirflowException("Error while transferring {0}, error: {1}"
.format(file_msg, str(e)))
- return None
+ return
diff --git a/airflow/contrib/operators/ssh_operator.py b/airflow/contrib/operators/ssh_operator.py
index 2bf342935d..0b031ae46f 100644
--- a/airflow/contrib/operators/ssh_operator.py
+++ b/airflow/contrib/operators/ssh_operator.py
@@ -45,8 +45,8 @@ class SSHOperator(BaseOperator):
:type command: str
:param timeout: timeout (in seconds) for executing the command.
:type timeout: int
- :param do_xcom_push: return the stdout which also get set in xcom by airflow platform
- :type do_xcom_push: bool
+ :param xcom_push: return the stdout which also get set in XCOM
+ :type xcom_push: bool
"""
template_fields = ('command', 'remote_host')
@@ -59,7 +59,7 @@ def __init__(self,
remote_host=None,
command=None,
timeout=10,
- do_xcom_push=False,
+ xcom_push=False,
*args,
**kwargs):
super(SSHOperator, self).__init__(*args, **kwargs)
@@ -68,7 +68,7 @@ def __init__(self,
self.remote_host = remote_host
self.command = command
self.timeout = timeout
- self.do_xcom_push = do_xcom_push
+ self.xcom_push_flag = xcom_push
def execute(self, context):
try:
@@ -148,8 +148,8 @@ def execute(self, context):
exit_status = stdout.channel.recv_exit_status()
if exit_status is 0:
- # returning output if do_xcom_push is set
- if self.do_xcom_push:
+ # returning output if xcom_push is set
+ if self.xcom_push_flag:
enable_pickling = configuration.conf.getboolean(
'core', 'enable_xcom_pickling'
)
@@ -166,8 +166,6 @@ def execute(self, context):
except Exception as e:
raise AirflowException("SSH operator error: {0}".format(str(e)))
- return True
-
def tunnel(self):
ssh_client = self.ssh_hook.get_conn()
ssh_client.get_transport()
diff --git a/airflow/contrib/operators/winrm_operator.py b/airflow/contrib/operators/winrm_operator.py
index c81acac44f..fb7f229049 100644
--- a/airflow/contrib/operators/winrm_operator.py
+++ b/airflow/contrib/operators/winrm_operator.py
@@ -48,8 +48,8 @@ class WinRMOperator(BaseOperator):
:type command: str
:param timeout: timeout for executing the command.
:type timeout: int
- :param do_xcom_push: return the stdout which also get set in xcom by airflow platform
- :type do_xcom_push: bool
+ :param xcom_push: return the stdout which also get set in XCOM
+ :type xcom_push: bool
"""
template_fields = ('command',)
@@ -60,7 +60,7 @@ def __init__(self,
remote_host=None,
command=None,
timeout=10,
- do_xcom_push=False,
+ xcom_push=False,
*args,
**kwargs):
super(WinRMOperator, self).__init__(*args, **kwargs)
@@ -69,7 +69,7 @@ def __init__(self,
self.remote_host = remote_host
self.command = command
self.timeout = timeout
- self.do_xcom_push = do_xcom_push
+ self.xcom_push_flag = xcom_push
def execute(self, context):
if self.ssh_conn_id and not self.winrm_hook:
@@ -107,7 +107,7 @@ def execute(self, context):
)
# Only buffer stdout if we need to so that we minimize memory usage.
- if self.do_xcom_push:
+ if self.xcom_push_flag:
stdout_buffer.append(stdout)
stderr_buffer.append(stderr)
@@ -127,8 +127,8 @@ def execute(self, context):
raise AirflowException("WinRM operator error: {0}".format(str(e)))
if return_code is 0:
- # returning output if do_xcom_push is set
- if self.do_xcom_push:
+ # returning output if xcom_push is set
+ if self.xcom_push_flag:
enable_pickling = configuration.conf.getboolean(
'core', 'enable_xcom_pickling'
)
@@ -145,5 +145,3 @@ def execute(self, context):
raise AirflowException(error_msg)
self.log.info("Finished!")
-
- return True
diff --git a/tests/contrib/operators/test_gcs_list_operator.py b/tests/contrib/operators/test_gcs_list_operator.py
index 8f0281f76f..cb28ff59fc 100644
--- a/tests/contrib/operators/test_gcs_list_operator.py
+++ b/tests/contrib/operators/test_gcs_list_operator.py
@@ -52,3 +52,19 @@ def test_execute(self, mock_hook):
bucket=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER
)
self.assertEqual(sorted(files), sorted(MOCK_FILES))
+
+ @mock.patch('airflow.contrib.operators.gcs_list_operator.GoogleCloudStorageHook')
+ def test_execute_without_xcom_push(self, mock_hook):
+ mock_hook.return_value.list.return_value = MOCK_FILES
+
+ operator = GoogleCloudStorageListOperator(task_id=TASK_ID,
+ bucket=TEST_BUCKET,
+ prefix=PREFIX,
+ delimiter=DELIMITER,
+ xcom_push=False)
+
+ files = operator.execute(None)
+ mock_hook.return_value.list.assert_called_once_with(
+ bucket=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER
+ )
+ self.assertEqual(files, None)
diff --git a/tests/contrib/operators/test_gcs_to_s3_operator.py b/tests/contrib/operators/test_gcs_to_s3_operator.py
index 1a35a223ab..01cf12e6f1 100644
--- a/tests/contrib/operators/test_gcs_to_s3_operator.py
+++ b/tests/contrib/operators/test_gcs_to_s3_operator.py
@@ -72,3 +72,29 @@ def test_execute(self, mock_hook, mock_hook2):
sorted(uploaded_files))
self.assertEqual(sorted(MOCK_FILES),
sorted(hook.list_keys('bucket', delimiter='/')))
+
+ @mock_s3
+ @mock.patch('airflow.contrib.operators.gcs_list_operator.GoogleCloudStorageHook')
+ @mock.patch('airflow.contrib.operators.gcs_to_s3.GoogleCloudStorageHook')
+ def test_execute_without_xcom_push(self, mock_hook, mock_hook2):
+ mock_hook.return_value.list.return_value = MOCK_FILES
+ mock_hook.return_value.download.return_value = b"testing"
+ mock_hook2.return_value.list.return_value = MOCK_FILES
+
+ operator = GoogleCloudStorageToS3Operator(task_id=TASK_ID,
+ bucket=GCS_BUCKET,
+ prefix=PREFIX,
+ delimiter=DELIMITER,
+ dest_aws_conn_id=None,
+ dest_s3_key=S3_BUCKET,
+ xcom_push=False)
+ # create dest bucket
+ hook = S3Hook(aws_conn_id=None)
+ b = hook.get_bucket('bucket')
+ b.create()
+ b.put_object(Key=MOCK_FILES[0], Body=b'testing')
+
+ uploaded_files = operator.execute(None)
+ self.assertEqual(uploaded_files, None)
+ self.assertEqual(sorted(MOCK_FILES),
+ sorted(hook.list_keys('bucket', delimiter='/')))
diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py
index bf4525e311..6788e84f7a 100644
--- a/tests/contrib/operators/test_sftp_operator.py
+++ b/tests/contrib/operators/test_sftp_operator.py
@@ -95,7 +95,7 @@ def test_pickle_file_transfer_put(self):
task_id="test_check_file",
ssh_hook=self.hook,
command="cat {0}".format(self.test_remote_filepath),
- do_xcom_push=True,
+ xcom_push=True,
dag=self.dag
)
self.assertIsNotNone(check_file_task)
@@ -132,7 +132,7 @@ def test_json_file_transfer_put(self):
task_id="test_check_file",
ssh_hook=self.hook,
command="cat {0}".format(self.test_remote_filepath),
- do_xcom_push=True,
+ xcom_push=True,
dag=self.dag
)
self.assertIsNotNone(check_file_task)
@@ -155,7 +155,7 @@ def test_pickle_file_transfer_get(self):
ssh_hook=self.hook,
command="echo '{0}' > {1}".format(test_remote_file_content,
self.test_remote_filepath),
- do_xcom_push=True,
+ xcom_push=True,
dag=self.dag
)
self.assertIsNotNone(create_file_task)
@@ -193,7 +193,7 @@ def test_json_file_transfer_get(self):
ssh_hook=self.hook,
command="echo '{0}' > {1}".format(test_remote_file_content,
self.test_remote_filepath),
- do_xcom_push=True,
+ xcom_push=True,
dag=self.dag
)
self.assertIsNotNone(create_file_task)
@@ -295,7 +295,7 @@ def delete_remote_resource(self):
task_id="test_check_file",
ssh_hook=self.hook,
command="rm {0}".format(self.test_remote_filepath),
- do_xcom_push=True,
+ xcom_push=True,
dag=self.dag
)
self.assertIsNotNone(remove_file_task)
diff --git a/tests/contrib/operators/test_ssh_operator.py b/tests/contrib/operators/test_ssh_operator.py
index 1a2c788596..b757389e3b 100644
--- a/tests/contrib/operators/test_ssh_operator.py
+++ b/tests/contrib/operators/test_ssh_operator.py
@@ -82,7 +82,7 @@ def test_json_command_execution(self):
task_id="test",
ssh_hook=self.hook,
command="echo -n airflow",
- do_xcom_push=True,
+ xcom_push=True,
dag=self.dag,
)
@@ -101,7 +101,7 @@ def test_pickle_command_execution(self):
task_id="test",
ssh_hook=self.hook,
command="echo -n airflow",
- do_xcom_push=True,
+ xcom_push=True,
dag=self.dag,
)
@@ -119,7 +119,7 @@ def test_command_execution_with_env(self):
task_id="test",
ssh_hook=self.hook,
command="echo -n airflow",
- do_xcom_push=True,
+ xcom_push=True,
dag=self.dag,
)
@@ -137,7 +137,7 @@ def test_no_output_command(self):
task_id="test",
ssh_hook=self.hook,
command="sleep 1",
- do_xcom_push=True,
+ xcom_push=True,
dag=self.dag,
)
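Note that for SSHOperator, WinRMOperator and the Databricks operators this patch renames the
existing do_xcom_push argument rather than introducing a new one. A minimal sketch of a call
site after the change, assuming a hypothetical connection id and command that are not part of
this patch:

    from airflow.contrib.operators.ssh_operator import SSHOperator

    # Before this patch the same behaviour was requested with do_xcom_push=True (illustrative).
    run_cmd = SSHOperator(
        task_id='run_cmd',
        ssh_conn_id='ssh_default',  # assumed connection id
        command='uptime',
        xcom_push=True,             # renamed flag: the command's stdout is returned and stored in XCom
    )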
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services