Repository: incubator-airflow Updated Branches: refs/heads/master 9ebb04acb -> e9babff4e
[AIRFLOW-2463] Make task instance context available for hive queries [AIRFLOW-2463] Make task instance context available for hive queries update UPDATING.md, please squash Closes #3405 from yrqls21/kevin_yang_add_context Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/e9babff4 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/e9babff4 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/e9babff4 Branch: refs/heads/master Commit: e9babff4eb3334b0d71cda31c2f6cbfe7b741389 Parents: 9ebb04a Author: Kevin Yang <kevin.y...@airbnb.com> Authored: Wed Jul 11 10:28:06 2018 +0200 Committer: Fokko Driesprong <fokkodriespr...@godatadriven.com> Committed: Wed Jul 11 10:28:06 2018 +0200 ---------------------------------------------------------------------- UPDATING.md | 3 +- airflow/hooks/hive_hooks.py | 88 +++++++++++++---- airflow/operators/bash_operator.py | 14 ++- airflow/operators/hive_to_mysql.py | 8 +- airflow/operators/hive_to_samba_operator.py | 4 +- airflow/operators/python_operator.py | 16 +++- airflow/utils/operator_helpers.py | 59 ++++++++---- tests/hooks/test_hive_hook.py | 83 ++++++++++++++-- tests/operators/bash_operator.py | 12 ++- tests/operators/python_operator.py | 116 +++++++++++++++++++++-- tests/utils/test_operator_helpers.py | 32 +++++-- 11 files changed, 358 insertions(+), 77 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/UPDATING.md ---------------------------------------------------------------------- diff --git a/UPDATING.md b/UPDATING.md index 82983b2..74a0b1b 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -42,6 +42,7 @@ There are five roles created for Airflow by default: Admin, User, Op, Viewer, an - Airflow dag home page is now `/home` (instead of `/admin`). 
- All ModelViews in Flask-AppBuilder follow a different pattern from Flask-Admin. The `/admin` part of the url path will no longer exist. For example: `/admin/connection` becomes `/connection/list`, `/admin/connection/new` becomes `/connection/add`, `/admin/connection/edit` becomes `/connection/edit`, etc. - Due to security concerns, the new webserver will no longer support the features in the `Data Profiling` menu of old UI, including `Ad Hoc Query`, `Charts`, and `Known Events`. +- HiveServer2Hook.get_results() always returns a list of tuples, even when a single column is queried, as per the Python DB API 2.0 specification (PEP 249). ### airflow.contrib.sensors.hdfs_sensors renamed to airflow.contrib.sensors.hdfs_sensor @@ -77,7 +78,7 @@ supported and will be removed entirely in Airflow 2.0 With Airflow 1.9 or lower, Unload operation always included header row. In order to include header row, we need to turn off parallel unload. It is preferred to perform unload operation using all nodes so that it is faster for larger tables. So, parameter called `include_header` is added and default is set to False. 
-Header row will be added only if this parameter is set True and also in that case parallel will be automatically turned off (`PARALLEL OFF`) +Header row will be added only if this parameter is set True and also in that case parallel will be automatically turned off (`PARALLEL OFF`) ### Google cloud connection string http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/airflow/hooks/hive_hooks.py ---------------------------------------------------------------------- diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py index 93e1f45..5ac99ee 100644 --- a/airflow/hooks/hive_hooks.py +++ b/airflow/hooks/hive_hooks.py @@ -21,30 +21,40 @@ from __future__ import print_function, unicode_literals import contextlib import os - -from six.moves import zip -from past.builtins import basestring, unicode - -import unicodecsv as csv import re -import six import subprocess import time from collections import OrderedDict from tempfile import NamedTemporaryFile + import hmsclient +import six +import unicodecsv as csv +from past.builtins import basestring +from past.builtins import unicode +from six.moves import zip -from airflow import configuration as conf +import airflow.security.utils as utils +from airflow import configuration from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook -from airflow.utils.helpers import as_flattened_list from airflow.utils.file import TemporaryDirectory -from airflow import configuration -import airflow.security.utils as utils +from airflow.utils.helpers import as_flattened_list +from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING HIVE_QUEUE_PRIORITIES = ['VERY_HIGH', 'HIGH', 'NORMAL', 'LOW', 'VERY_LOW'] +def get_context_from_env_var(): + """ + Extract context from env variable, e.g. dag_id, task_id and execution_date, + so that they can be used inside BashOperator and PythonOperator. + :return: The context of interest. 
+ """ + return {format_map['default']: os.environ.get(format_map['env_var_format'], '') + for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()} + + class HiveCliHook(BaseHook): """Simple wrapper around the hive CLI. @@ -92,8 +102,8 @@ class HiveCliHook(BaseHook): "Invalid Mapred Queue Priority. Valid values are: " "{}".format(', '.join(HIVE_QUEUE_PRIORITIES))) - self.mapred_queue = mapred_queue or conf.get('hive', - 'default_hive_mapred_queue') + self.mapred_queue = mapred_queue or configuration.get('hive', + 'default_hive_mapred_queue') self.mapred_queue_priority = mapred_queue_priority self.mapred_job_name = mapred_job_name @@ -126,6 +136,7 @@ class HiveCliHook(BaseHook): jdbc_url += ";auth=" + self.auth jdbc_url = jdbc_url.format(**locals()) + jdbc_url = '"{}"'.format(jdbc_url) cmd_extra += ['-u', jdbc_url] if conn.login: @@ -184,10 +195,15 @@ class HiveCliHook(BaseHook): with TemporaryDirectory(prefix='airflow_hiveop_') as tmp_dir: with NamedTemporaryFile(dir=tmp_dir) as f: + hql = hql + '\n' f.write(hql.encode('UTF-8')) f.flush() hive_cmd = self._prepare_cli_cmd() - hive_conf_params = self._prepare_hiveconf(hive_conf) + env_context = get_context_from_env_var() + # Only extend the hive_conf if it is defined. 
+ if hive_conf: + env_context.update(hive_conf) + hive_conf_params = self._prepare_hiveconf(env_context) if self.mapred_queue: hive_conf_params.extend( ['-hiveconf', @@ -772,7 +788,7 @@ class HiveServer2Hook(BaseHook): username=db.login or username, database=schema or db.schema or 'default') - def _get_results(self, hql, schema='default', fetch_size=None): + def _get_results(self, hql, schema='default', fetch_size=None, hive_conf=None): from pyhive.exc import ProgrammingError if isinstance(hql, basestring): hql = [hql] @@ -780,12 +796,21 @@ class HiveServer2Hook(BaseHook): with contextlib.closing(self.get_conn(schema)) as conn, \ contextlib.closing(conn.cursor()) as cur: cur.arraysize = fetch_size or 1000 + + env_context = get_context_from_env_var() + if hive_conf: + env_context.update(hive_conf) + for k, v in env_context.items(): + cur.execute("set {}={}".format(k, v)) + for statement in hql: cur.execute(statement) # we only get results of statements that returns lowered_statement = statement.lower().strip() if (lowered_statement.startswith('select') or - lowered_statement.startswith('with')): + lowered_statement.startswith('with') or + (lowered_statement.startswith('set') and + '=' not in lowered_statement)): description = [c for c in cur.description] if previous_description and previous_description != description: message = '''The statements are producing different descriptions: @@ -805,8 +830,17 @@ class HiveServer2Hook(BaseHook): except ProgrammingError: self.log.debug("get_results returned no records") - def get_results(self, hql, schema='default', fetch_size=None): - results_iter = self._get_results(hql, schema, fetch_size=fetch_size) + def get_results(self, hql, schema='default', fetch_size=None, hive_conf=None): + """ + Get results of the provided hql in target schema. + :param hql: hql to be executed. + :param schema: target schema, default to 'default'. + :param fetch_size max size of result to fetch. 
+ :param hive_conf: hive_conf to execute alone with the hql. + :return: results of hql execution. + """ + results_iter = self._get_results(hql, schema, + fetch_size=fetch_size, hive_conf=hive_conf) header = next(results_iter) results = { 'data': list(results_iter), @@ -822,9 +856,23 @@ class HiveServer2Hook(BaseHook): delimiter=',', lineterminator='\r\n', output_header=True, - fetch_size=1000): - - results_iter = self._get_results(hql, schema, fetch_size=fetch_size) + fetch_size=1000, + hive_conf=None): + """ + Execute hql in target schema and write results to a csv file. + :param hql: hql to be executed. + :param csv_filepath: filepath of csv to write results into. + :param schema: target schema, , default to 'default'. + :param delimiter: delimiter of the csv file. + :param lineterminator: lineterminator of the csv file. + :param output_header: header of the csv file. + :param fetch_size: number of result rows to write into the csv file. + :param hive_conf: hive_conf to execute alone with the hql. + :return: + """ + + results_iter = self._get_results(hql, schema, + fetch_size=fetch_size, hive_conf=hive_conf) header = next(results_iter) message = None http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/airflow/operators/bash_operator.py ---------------------------------------------------------------------- diff --git a/airflow/operators/bash_operator.py b/airflow/operators/bash_operator.py index 37a19db..17de014 100644 --- a/airflow/operators/bash_operator.py +++ b/airflow/operators/bash_operator.py @@ -18,16 +18,18 @@ # under the License. 
-from builtins import bytes import os import signal from subprocess import Popen, STDOUT, PIPE from tempfile import gettempdir, NamedTemporaryFile +from builtins import bytes + from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults from airflow.utils.file import TemporaryDirectory +from airflow.utils.operator_helpers import context_to_airflow_vars class BashOperator(BaseOperator): @@ -73,6 +75,16 @@ class BashOperator(BaseOperator): """ self.log.info("Tmp dir root location: \n %s", gettempdir()) + # Prepare env for child process. + if self.env is None: + self.env = os.environ.copy() + airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) + self.log.info("Exporting the following env vars:\n" + + '\n'.join(["{}={}".format(k, v) + for k, v in + airflow_context_vars.items()])) + self.env.update(airflow_context_vars) + self.lineage_data = self.bash_command with TemporaryDirectory(prefix='airflowtmp') as tmp_dir: http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/airflow/operators/hive_to_mysql.py ---------------------------------------------------------------------- diff --git a/airflow/operators/hive_to_mysql.py b/airflow/operators/hive_to_mysql.py index 4dc25a6..882a9d8 100644 --- a/airflow/operators/hive_to_mysql.py +++ b/airflow/operators/hive_to_mysql.py @@ -17,12 +17,13 @@ # specific language governing permissions and limitations # under the License. 
+from tempfile import NamedTemporaryFile + from airflow.hooks.hive_hooks import HiveServer2Hook from airflow.hooks.mysql_hook import MySqlHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults - -from tempfile import NamedTemporaryFile +from airflow.utils.operator_helpers import context_to_airflow_vars class HiveToMySqlTransfer(BaseOperator): @@ -88,7 +89,8 @@ class HiveToMySqlTransfer(BaseOperator): if self.bulk_load: tmpfile = NamedTemporaryFile() hive.to_csv(self.sql, tmpfile.name, delimiter='\t', - lineterminator='\n', output_header=False) + lineterminator='\n', output_header=False, + hive_conf=context_to_airflow_vars(context)) else: results = hive.get_records(self.sql) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/airflow/operators/hive_to_samba_operator.py ---------------------------------------------------------------------- diff --git a/airflow/operators/hive_to_samba_operator.py b/airflow/operators/hive_to_samba_operator.py index f6978ac..fa2a961 100644 --- a/airflow/operators/hive_to_samba_operator.py +++ b/airflow/operators/hive_to_samba_operator.py @@ -23,6 +23,7 @@ from airflow.hooks.hive_hooks import HiveServer2Hook from airflow.hooks.samba_hook import SambaHook from airflow.models import BaseOperator from airflow.utils.decorators import apply_defaults +from airflow.utils.operator_helpers import context_to_airflow_vars class Hive2SambaOperator(BaseOperator): @@ -60,6 +61,7 @@ class Hive2SambaOperator(BaseOperator): hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id) tmpfile = tempfile.NamedTemporaryFile() self.log.info("Fetching file from Hive") - hive.to_csv(hql=self.hql, csv_filepath=tmpfile.name) + hive.to_csv(hql=self.hql, csv_filepath=tmpfile.name, + hive_conf=context_to_airflow_vars(context)) self.log.info("Pushing to samba") samba.push_from_local(self.destination_filepath, tmpfile.name) 
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/airflow/operators/python_operator.py ---------------------------------------------------------------------- diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index a564897..678a3de 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -17,21 +17,22 @@ # specific language governing permissions and limitations # under the License. -from builtins import str -import dill import inspect import os import pickle import subprocess import sys import types +from textwrap import dedent + +import dill +from builtins import str from airflow.exceptions import AirflowException from airflow.models import BaseOperator, SkipMixin from airflow.utils.decorators import apply_defaults from airflow.utils.file import TemporaryDirectory - -from textwrap import dedent +from airflow.utils.operator_helpers import context_to_airflow_vars class PythonOperator(BaseOperator): @@ -91,6 +92,13 @@ class PythonOperator(BaseOperator): self.template_ext = templates_exts def execute(self, context): + # Export context to make it available for callables to use. 
+ airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True) + self.log.info("Exporting the following env vars:\n" + + '\n'.join(["{}={}".format(k, v) + for k, v in airflow_context_vars.items()])) + os.environ.update(airflow_context_vars) + if self.provide_context: context.update(self.op_kwargs) context['templates_dict'] = self.templates_dict http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/airflow/utils/operator_helpers.py ---------------------------------------------------------------------- diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py index 356aa65..e981941 100644 --- a/airflow/utils/operator_helpers.py +++ b/airflow/utils/operator_helpers.py @@ -7,9 +7,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,32 +18,49 @@ # under the License. # +AIRFLOW_VAR_NAME_FORMAT_MAPPING = { + 'AIRFLOW_CONTEXT_DAG_ID': {'default': 'airflow.ctx.dag_id', + 'env_var_format': 'AIRFLOW_CTX_DAG_ID'}, + 'AIRFLOW_CONTEXT_TASK_ID': {'default': 'airflow.ctx.task_id', + 'env_var_format': 'AIRFLOW_CTX_TASK_ID'}, + 'AIRFLOW_CONTEXT_EXECUTION_DATE': {'default': 'airflow.ctx.execution_date', + 'env_var_format': 'AIRFLOW_CTX_EXECUTION_DATE'}, + 'AIRFLOW_CONTEXT_DAG_RUN_ID': {'default': 'airflow.ctx.dag_run_id', + 'env_var_format': 'AIRFLOW_CTX_DAG_RUN_ID'} +} + -def context_to_airflow_vars(context): +def context_to_airflow_vars(context, in_env_var_format=False): """ Given a context, this function provides a dictionary of values that can be used to externally reconstruct relations between dags, dag_runs, tasks and task_instances. 
+ Default to abc.def.ghi format and can be made to ABC_DEF_GHI format if + in_env_var_format is set to True. - :param context: The context for the task_instance of interest + :param context: The context for the task_instance of interest. :type context: dict + :param in_env_var_format: If returned vars should be in ABC_DEF_GHI format. + :type in_env_var_format: bool + :return task_instance context as dict. """ - params = {} - dag = context.get('dag') - if dag and dag.dag_id: - params['airflow.ctx.dag.dag_id'] = dag.dag_id - - dag_run = context.get('dag_run') - if dag_run and dag_run.execution_date: - params['airflow.ctx.dag_run.execution_date'] = dag_run.execution_date.isoformat() - - task = context.get('task') - if task and task.task_id: - params['airflow.ctx.task.task_id'] = task.task_id - + params = dict() + if in_env_var_format: + name_format = 'env_var_format' + else: + name_format = 'default' task_instance = context.get('task_instance') + if task_instance and task_instance.dag_id: + params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID'][ + name_format]] = task_instance.dag_id + if task_instance and task_instance.task_id: + params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID'][ + name_format]] = task_instance.task_id if task_instance and task_instance.execution_date: - params['airflow.ctx.task_instance.execution_date'] = ( - task_instance.execution_date.isoformat() - ) - + params[ + AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][ + name_format]] = task_instance.execution_date.isoformat() + dag_run = context.get('dag_run') + if dag_run and dag_run.run_id: + params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][ + name_format]] = dag_run.run_id return params http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/tests/hooks/test_hive_hook.py ---------------------------------------------------------------------- diff --git a/tests/hooks/test_hive_hook.py b/tests/hooks/test_hive_hook.py index 
81270c6..1cac74c 100644 --- a/tests/hooks/test_hive_hook.py +++ b/tests/hooks/test_hive_hook.py @@ -21,24 +21,22 @@ import datetime import itertools import os - -import pandas as pd import random - -import mock import unittest - from collections import OrderedDict + +import mock +import pandas as pd from hmsclient import HMSClient +from airflow import DAG, configuration from airflow.exceptions import AirflowException from airflow.hooks.hive_hooks import HiveCliHook, HiveMetastoreHook, HiveServer2Hook -from airflow import DAG, configuration from airflow.operators.hive_operator import HiveOperator from airflow.utils import timezone +from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING from airflow.utils.tests import assertEqualIgnoreMultipleSpaces - configuration.load_test_config() @@ -97,6 +95,39 @@ class TestHiveCliHook(unittest.TestCase): hook = HiveCliHook() hook.run_cli("SHOW DATABASES") + def test_run_cli_with_hive_conf(self): + hql = "set key;\n" \ + "set airflow.ctx.dag_id;\nset airflow.ctx.dag_run_id;\n" \ + "set airflow.ctx.task_id;\nset airflow.ctx.execution_date;\n" + + dag_id_ctx_var_name = \ + AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format'] + task_id_ctx_var_name = \ + AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format'] + execution_date_ctx_var_name = \ + AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][ + 'env_var_format'] + dag_run_id_ctx_var_name = \ + AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][ + 'env_var_format'] + os.environ[dag_id_ctx_var_name] = 'test_dag_id' + os.environ[task_id_ctx_var_name] = 'test_task_id' + os.environ[execution_date_ctx_var_name] = 'test_execution_date' + os.environ[dag_run_id_ctx_var_name] = 'test_dag_run_id' + + hook = HiveCliHook() + output = hook.run_cli(hql=hql, hive_conf={'key': 'value'}) + self.assertIn('value', output) + self.assertIn('test_dag_id', output) + self.assertIn('test_task_id', output) + 
self.assertIn('test_execution_date', output) + self.assertIn('test_dag_run_id', output) + + del os.environ[dag_id_ctx_var_name] + del os.environ[task_id_ctx_var_name] + del os.environ[execution_date_ctx_var_name] + del os.environ[dag_run_id_ctx_var_name] + @mock.patch('airflow.hooks.hive_hooks.HiveCliHook.run_cli') def test_load_file(self, mock_run_cli): filepath = "/path/to/input/file" @@ -419,3 +450,41 @@ class TestHiveServer2Hook(unittest.TestCase): hook = HiveServer2Hook() results = hook.get_records(sqls, schema=self.database) self.assertListEqual(results, [(1, 1), (2, 2)]) + + def test_get_results_with_hive_conf(self): + hql = ["set key", + "set airflow.ctx.dag_id", + "set airflow.ctx.dag_run_id", + "set airflow.ctx.task_id", + "set airflow.ctx.execution_date"] + + dag_id_ctx_var_name = \ + AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format'] + task_id_ctx_var_name = \ + AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format'] + execution_date_ctx_var_name = \ + AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][ + 'env_var_format'] + dag_run_id_ctx_var_name = \ + AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][ + 'env_var_format'] + os.environ[dag_id_ctx_var_name] = 'test_dag_id' + os.environ[task_id_ctx_var_name] = 'test_task_id' + os.environ[execution_date_ctx_var_name] = 'test_execution_date' + os.environ[dag_run_id_ctx_var_name] = 'test_dag_run_id' + + hook = HiveServer2Hook() + output = '\n'.join(res_tuple[0] + for res_tuple + in hook.get_results(hql=hql, + hive_conf={'key': 'value'})['data']) + self.assertIn('value', output) + self.assertIn('test_dag_id', output) + self.assertIn('test_task_id', output) + self.assertIn('test_execution_date', output) + self.assertIn('test_dag_run_id', output) + + del os.environ[dag_id_ctx_var_name] + del os.environ[task_id_ctx_var_name] + del os.environ[execution_date_ctx_var_name] + del os.environ[dag_run_id_ctx_var_name] 
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/tests/operators/bash_operator.py ---------------------------------------------------------------------- diff --git a/tests/operators/bash_operator.py b/tests/operators/bash_operator.py index 1ce77e9..e0a0ff3 100644 --- a/tests/operators/bash_operator.py +++ b/tests/operators/bash_operator.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest import os +import unittest from datetime import datetime, timedelta from airflow import DAG @@ -59,7 +59,11 @@ class BashOperatorTestCase(unittest.TestCase): task_id='echo_env_vars', dag=self.dag, bash_command='echo $AIRFLOW_HOME>> {0};' - 'echo $PYTHONPATH>> {0};'.format(fname) + 'echo $PYTHONPATH>> {0};' + 'echo $AIRFLOW_CTX_DAG_ID >> {0};' + 'echo $AIRFLOW_CTX_TASK_ID>> {0};' + 'echo $AIRFLOW_CTX_EXECUTION_DATE>> {0};' + 'echo $AIRFLOW_CTX_DAG_RUN_ID>> {0};'.format(fname) ) os.environ['AIRFLOW_HOME'] = 'MY_PATH_TO_AIRFLOW_HOME' t.run(DEFAULT_DATE, DEFAULT_DATE, @@ -70,3 +74,7 @@ class BashOperatorTestCase(unittest.TestCase): self.assertIn('MY_PATH_TO_AIRFLOW_HOME', output) # exported in run_unit_tests.sh as part of PYTHONPATH self.assertIn('tests/test_utils', output) + self.assertIn('bash_op_test', output) + self.assertIn('echo_env_vars', output) + self.assertIn(DEFAULT_DATE.isoformat(), output) + self.assertIn('manual__' + DEFAULT_DATE.isoformat(), output) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/tests/operators/python_operator.py ---------------------------------------------------------------------- diff --git a/tests/operators/python_operator.py b/tests/operators/python_operator.py index 43aa8a6..735a4d7 100644 --- a/tests/operators/python_operator.py +++ b/tests/operators/python_operator.py @@ -20,28 +20,43 @@ from __future__ import print_function, unicode_literals import copy -import datetime +import logging +import os import unittest 
+from datetime import timedelta -from airflow import configuration, DAG -from airflow.models import TaskInstance as TI +from airflow import configuration +from airflow.exceptions import AirflowException +from airflow.models import TaskInstance as TI, DAG, DagRun +from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python_operator import PythonOperator, BranchPythonOperator from airflow.operators.python_operator import ShortCircuitOperator -from airflow.operators.dummy_operator import DummyOperator from airflow.settings import Session from airflow.utils import timezone from airflow.utils.state import State -from airflow.exceptions import AirflowException -import logging - DEFAULT_DATE = timezone.datetime(2016, 1, 1) END_DATE = timezone.datetime(2016, 1, 2) -INTERVAL = datetime.timedelta(hours=12) +INTERVAL = timedelta(hours=12) FROZEN_NOW = timezone.datetime(2016, 1, 2, 12, 1, 1) +TI_CONTEXT_ENV_VARS = ['AIRFLOW_CTX_DAG_ID', + 'AIRFLOW_CTX_TASK_ID', + 'AIRFLOW_CTX_EXECUTION_DATE', + 'AIRFLOW_CTX_DAG_RUN_ID'] + class PythonOperatorTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + super(PythonOperatorTest, cls).setUpClass() + + session = Session() + + session.query(DagRun).delete() + session.query(TI).delete() + session.commit() + session.close() def setUp(self): super(PythonOperatorTest, self).setUp() @@ -56,6 +71,21 @@ class PythonOperatorTest(unittest.TestCase): self.clear_run() self.addCleanup(self.clear_run) + def tearDown(self): + super(PythonOperatorTest, self).tearDown() + + session = Session() + + session.query(DagRun).delete() + session.query(TI).delete() + print(len(session.query(DagRun).all())) + session.commit() + session.close() + + for var in TI_CONTEXT_ENV_VARS: + if var in os.environ: + del os.environ[var] + def do_run(self): self.run = True @@ -107,8 +137,46 @@ class PythonOperatorTest(unittest.TestCase): self.assertEquals(id(original_task.python_callable), id(new_task.python_callable)) + def 
_env_var_check_callback(self): + self.assertEqual('test_dag', os.environ['AIRFLOW_CTX_DAG_ID']) + self.assertEqual('hive_in_python_op', os.environ['AIRFLOW_CTX_TASK_ID']) + self.assertEqual(DEFAULT_DATE.isoformat(), + os.environ['AIRFLOW_CTX_EXECUTION_DATE']) + self.assertEqual('manual__' + DEFAULT_DATE.isoformat(), + os.environ['AIRFLOW_CTX_DAG_RUN_ID']) + + def test_echo_env_variables(self): + """ + Test that env variables are exported correctly to the + python callback in the task. + """ + self.dag.create_dagrun( + run_id='manual__' + DEFAULT_DATE.isoformat(), + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + state=State.RUNNING, + external_trigger=False, + ) + + t = PythonOperator(task_id='hive_in_python_op', + dag=self.dag, + python_callable=self._env_var_check_callback + ) + t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + class BranchOperatorTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + super(BranchOperatorTest, cls).setUpClass() + + session = Session() + + session.query(DagRun).delete() + session.query(TI).delete() + session.commit() + session.close() + def setUp(self): self.dag = DAG('branch_operator_test', default_args={ @@ -125,6 +193,17 @@ class BranchOperatorTest(unittest.TestCase): self.branch_2.set_upstream(self.branch_op) self.dag.clear() + def tearDown(self): + super(BranchOperatorTest, self).tearDown() + + session = Session() + + session.query(DagRun).delete() + session.query(TI).delete() + print(len(session.query(DagRun).all())) + session.commit() + session.close() + def test_without_dag_run(self): """This checks the defensive against non existent tasks in a dag run""" self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) @@ -170,6 +249,27 @@ class BranchOperatorTest(unittest.TestCase): class ShortCircuitOperatorTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + super(ShortCircuitOperatorTest, cls).setUpClass() + + session = Session() + + session.query(DagRun).delete() + 
session.query(TI).delete() + session.commit() + session.close() + + def tearDown(self): + super(ShortCircuitOperatorTest, self).tearDown() + + session = Session() + + session.query(DagRun).delete() + session.query(TI).delete() + session.commit() + session.close() + def test_without_dag_run(self): """This checks the defensive against non existent tasks in a dag run""" value = False http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/tests/utils/test_operator_helpers.py ---------------------------------------------------------------------- diff --git a/tests/utils/test_operator_helpers.py b/tests/utils/test_operator_helpers.py index 592b456..a358601 100644 --- a/tests/utils/test_operator_helpers.py +++ b/tests/utils/test_operator_helpers.py @@ -7,9 +7,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -17,9 +17,10 @@ # specific language governing permissions and limitations # under the License. 
+import unittest from datetime import datetime + import mock -import unittest from airflow.utils import operator_helpers @@ -31,16 +32,18 @@ class TestOperatorHelpers(unittest.TestCase): self.dag_id = 'dag_id' self.task_id = 'task_id' self.execution_date = '2017-05-21T00:00:00' + self.dag_run_id = 'dag_run_id' self.context = { - 'dag': mock.MagicMock(name='dag', dag_id=self.dag_id), 'dag_run': mock.MagicMock( name='dag_run', + run_id=self.dag_run_id, execution_date=datetime.strptime(self.execution_date, '%Y-%m-%dT%H:%M:%S'), ), - 'task': mock.MagicMock(name='task', task_id=self.task_id), 'task_instance': mock.MagicMock( name='task_instance', + task_id=self.task_id, + dag_id=self.dag_id, execution_date=datetime.strptime(self.execution_date, '%Y-%m-%dT%H:%M:%S'), ), @@ -53,10 +56,21 @@ class TestOperatorHelpers(unittest.TestCase): self.assertDictEqual( operator_helpers.context_to_airflow_vars(self.context), { - 'airflow.ctx.dag.dag_id': self.dag_id, - 'airflow.ctx.dag_run.execution_date': self.execution_date, - 'airflow.ctx.task.task_id': self.task_id, - 'airflow.ctx.task_instance.execution_date': self.execution_date, + 'airflow.ctx.dag_id': self.dag_id, + 'airflow.ctx.execution_date': self.execution_date, + 'airflow.ctx.task_id': self.task_id, + 'airflow.ctx.dag_run_id': self.dag_run_id, + } + ) + + self.assertDictEqual( + operator_helpers.context_to_airflow_vars(self.context, + in_env_var_format=True), + { + 'AIRFLOW_CTX_DAG_ID': self.dag_id, + 'AIRFLOW_CTX_EXECUTION_DATE': self.execution_date, + 'AIRFLOW_CTX_TASK_ID': self.task_id, + 'AIRFLOW_CTX_DAG_RUN_ID': self.dag_run_id, } )