kaxil closed pull request #4367: [AIRFLOW-3551] Improve BashOperator Test Coverage URL: https://github.com/apache/incubator-airflow/pull/4367
This PR was merged from a forked repository. Because GitHub hides the original diff of a foreign pull request (one from a fork) once it is merged, the diff is reproduced below for the sake of provenance: diff --git a/airflow/operators/bash_operator.py b/airflow/operators/bash_operator.py index 13aa44fc85..a2217adf40 100644 --- a/airflow/operators/bash_operator.py +++ b/airflow/operators/bash_operator.py @@ -20,11 +20,10 @@ import os import signal +from builtins import bytes from subprocess import Popen, STDOUT, PIPE from tempfile import gettempdir, NamedTemporaryFile -from builtins import bytes - from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults @@ -66,13 +65,12 @@ class BashOperator(BaseOperator): ui_color = '#f0ede4' @apply_defaults - def __init__( - self, - bash_command, - xcom_push=False, - env=None, - output_encoding='utf-8', - *args, **kwargs): + def __init__(self, + bash_command, + xcom_push=False, + env=None, + output_encoding='utf-8', + *args, **kwargs): super(BashOperator, self).__init__(*args, **kwargs) self.bash_command = bash_command @@ -85,14 +83,14 @@ def execute(self, context): Execute the bash command in a temporary directory which will be cleaned afterwards """ - self.log.info("Tmp dir root location: \n %s", gettempdir()) + self.log.info('Tmp dir root location: \n %s', gettempdir()) # Prepare env for child process. 
if self.env is None: self.env = os.environ.copy() - airflow_context_vars = context_to_airflow_vars(context, - in_env_var_format=True) - self.log.info("Exporting the following env vars:\n" + + + airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) + self.log.info('Exporting the following env vars:\n' + '\n'.join(["{}={}".format(k, v) for k, v in airflow_context_vars.items()])) @@ -101,16 +99,11 @@ def execute(self, context): self.lineage_data = self.bash_command with TemporaryDirectory(prefix='airflowtmp') as tmp_dir: - with NamedTemporaryFile(dir=tmp_dir, prefix=self.task_id) as f: - - f.write(bytes(self.bash_command, 'utf_8')) - f.flush() - fname = f.name - script_location = os.path.abspath(fname) - self.log.info( - "Temporary script location: %s", - script_location - ) + with NamedTemporaryFile(dir=tmp_dir, prefix=self.task_id) as tmp_file: + tmp_file.write(bytes(self.bash_command, 'utf_8')) + tmp_file.flush() + script_location = os.path.abspath(tmp_file.name) + self.log.info('Temporary script location: %s', script_location) def pre_exec(): # Restore default signal disposition and invoke setsid @@ -119,32 +112,33 @@ def pre_exec(): signal.signal(getattr(signal, sig), signal.SIG_DFL) os.setsid() - self.log.info("Running command: %s", self.bash_command) - sp = Popen( - ['bash', fname], - stdout=PIPE, stderr=STDOUT, - cwd=tmp_dir, env=self.env, + self.log.info('Running command: %s', self.bash_command) + sub_process = Popen( + ['bash', tmp_file.name], + stdout=PIPE, + stderr=STDOUT, + cwd=tmp_dir, + env=self.env, preexec_fn=pre_exec) - self.sp = sp + self.sub_process = sub_process - self.log.info("Output:") + self.log.info('Output:') line = '' - for line in iter(sp.stdout.readline, b''): - line = line.decode(self.output_encoding).rstrip() + for raw_line in iter(sub_process.stdout.readline, b''): + line = raw_line.decode(self.output_encoding).rstrip() self.log.info(line) - sp.wait() - self.log.info( - "Command exited with return code %s", - 
sp.returncode - ) - if sp.returncode: - raise AirflowException("Bash command failed") + sub_process.wait() + + self.log.info('Command exited with return code %s', sub_process.returncode) + + if sub_process.returncode: + raise AirflowException('Bash command failed') if self.xcom_push_flag: return line def on_kill(self): self.log.info('Sending SIGTERM signal to bash process group') - os.killpg(os.getpgid(self.sp.pid), signal.SIGTERM) + os.killpg(os.getpgid(self.sub_process.pid), signal.SIGTERM) diff --git a/tests/operators/test_bash_operator.py b/tests/operators/test_bash_operator.py index 8f55b9cda1..e298682734 100644 --- a/tests/operators/test_bash_operator.py +++ b/tests/operators/test_bash_operator.py @@ -15,6 +15,7 @@ import os import unittest from datetime import datetime, timedelta +from tempfile import NamedTemporaryFile from airflow import DAG from airflow.models import State @@ -26,7 +27,8 @@ INTERVAL = timedelta(hours=12) -class BashOperatorTestCase(unittest.TestCase): +class BashOperatorTest(unittest.TestCase): + def test_echo_env_variables(self): """ Test that env variables are exported correctly to the @@ -52,10 +54,8 @@ def test_echo_env_variables(self): external_trigger=False, ) - import tempfile - with tempfile.NamedTemporaryFile() as f: - fname = f.name - t = BashOperator( + with NamedTemporaryFile() as tmp_file: + task = BashOperator( task_id='echo_env_vars', dag=self.dag, bash_command='echo $AIRFLOW_HOME>> {0};' @@ -63,17 +63,17 @@ def test_echo_env_variables(self): 'echo $AIRFLOW_CTX_DAG_ID >> {0};' 'echo $AIRFLOW_CTX_TASK_ID>> {0};' 'echo $AIRFLOW_CTX_EXECUTION_DATE>> {0};' - 'echo $AIRFLOW_CTX_DAG_RUN_ID>> {0};'.format(fname) + 'echo $AIRFLOW_CTX_DAG_RUN_ID>> {0};'.format(tmp_file.name) ) original_AIRFLOW_HOME = os.environ['AIRFLOW_HOME'] os.environ['AIRFLOW_HOME'] = 'MY_PATH_TO_AIRFLOW_HOME' - t.run(DEFAULT_DATE, DEFAULT_DATE, - ignore_first_depends_on_past=True, ignore_ti_state=True) + task.run(DEFAULT_DATE, DEFAULT_DATE, + 
ignore_first_depends_on_past=True, ignore_ti_state=True) - with open(fname, 'r') as fr: - output = ''.join(fr.readlines()) + with open(tmp_file.name, 'r') as file: + output = ''.join(file.readlines()) self.assertIn('MY_PATH_TO_AIRFLOW_HOME', output) # exported in run_unit_tests.sh as part of PYTHONPATH self.assertIn('tests/test_utils', output) @@ -83,3 +83,14 @@ def test_echo_env_variables(self): self.assertIn('manual__' + DEFAULT_DATE.isoformat(), output) os.environ['AIRFLOW_HOME'] = original_AIRFLOW_HOME + + def test_return_value_to_xcom(self): + bash_operator = BashOperator( + bash_command='echo "stdout"', + xcom_push=True, + task_id='test_return_value_to_xcom', + dag=None + ) + xcom_return_value = bash_operator.execute(context={}) + + self.assertEqual(xcom_return_value, u'stdout') ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services