Repository: incubator-airflow
Updated Branches:
  refs/heads/master 9ebb04acb -> e9babff4e


[AIRFLOW-2463] Make task instance context available for hive queries

[AIRFLOW-2463] Make task instance context
available for hive queries

update UPDATING.md, please squash

Closes #3405 from yrqls21/kevin_yang_add_context


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/e9babff4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/e9babff4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/e9babff4

Branch: refs/heads/master
Commit: e9babff4eb3334b0d71cda31c2f6cbfe7b741389
Parents: 9ebb04a
Author: Kevin Yang <kevin.y...@airbnb.com>
Authored: Wed Jul 11 10:28:06 2018 +0200
Committer: Fokko Driesprong <fokkodriespr...@godatadriven.com>
Committed: Wed Jul 11 10:28:06 2018 +0200

----------------------------------------------------------------------
 UPDATING.md                                 |   3 +-
 airflow/hooks/hive_hooks.py                 |  88 +++++++++++++----
 airflow/operators/bash_operator.py          |  14 ++-
 airflow/operators/hive_to_mysql.py          |   8 +-
 airflow/operators/hive_to_samba_operator.py |   4 +-
 airflow/operators/python_operator.py        |  16 +++-
 airflow/utils/operator_helpers.py           |  59 ++++++++----
 tests/hooks/test_hive_hook.py               |  83 ++++++++++++++--
 tests/operators/bash_operator.py            |  12 ++-
 tests/operators/python_operator.py          | 116 +++++++++++++++++++++--
 tests/utils/test_operator_helpers.py        |  32 +++++--
 11 files changed, 358 insertions(+), 77 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/UPDATING.md
----------------------------------------------------------------------
diff --git a/UPDATING.md b/UPDATING.md
index 82983b2..74a0b1b 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -42,6 +42,7 @@ There are five roles created for Airflow by default: Admin, User, Op, Viewer, an
 - Airflow dag home page is now `/home` (instead of `/admin`).
 - All ModelViews in Flask-AppBuilder follow a different pattern from Flask-Admin. The `/admin` part of the url path will no longer exist. For example: `/admin/connection` becomes `/connection/list`, `/admin/connection/new` becomes `/connection/add`, `/admin/connection/edit` becomes `/connection/edit`, etc.
 - Due to security concerns, the new webserver will no longer support the features in the `Data Profiling` menu of old UI, including `Ad Hoc Query`, `Charts`, and `Known Events`.
+- HiveServer2Hook.get_results() always returns a list of tuples, even when a single column is queried, as per the Python DB API 2 (PEP 249).
 
 ### airflow.contrib.sensors.hdfs_sensors renamed to airflow.contrib.sensors.hdfs_sensor
 
@@ -77,7 +78,7 @@ supported and will be removed entirely in Airflow 2.0
 With Airflow 1.9 or lower, Unload operation always included header row. In order to include header row,
 we need to turn off parallel unload. It is preferred to perform unload operation using all nodes so that it is
 faster for larger tables. So, parameter called `include_header` is added and default is set to False.
-Header row will be added only if this parameter is set True and also in that case parallel will be automatically turned off (`PARALLEL OFF`)  
+Header row will be added only if this parameter is set True and also in that case parallel will be automatically turned off (`PARALLEL OFF`)
 
 ### Google cloud connection string
 

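Illustrative usage of the get_results() behaviour noted above (not part of this patch; the connection id and query are placeholders): rows come back as tuples even for a single-column query, so callers should index into each row.

    from airflow.hooks.hive_hooks import HiveServer2Hook

    # Placeholder connection id and query, for illustration only.
    hook = HiveServer2Hook(hiveserver2_conn_id='hiveserver2_default')
    results = hook.get_results(hql="SELECT part_date FROM my_table LIMIT 3")

    # Each row is a tuple even though only one column was selected,
    # so take element 0 instead of treating the row as a scalar.
    dates = [row[0] for row in results['data']]
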
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/airflow/hooks/hive_hooks.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py
index 93e1f45..5ac99ee 100644
--- a/airflow/hooks/hive_hooks.py
+++ b/airflow/hooks/hive_hooks.py
@@ -21,30 +21,40 @@ from __future__ import print_function, unicode_literals
 
 import contextlib
 import os
-
-from six.moves import zip
-from past.builtins import basestring, unicode
-
-import unicodecsv as csv
 import re
-import six
 import subprocess
 import time
 from collections import OrderedDict
 from tempfile import NamedTemporaryFile
+
 import hmsclient
+import six
+import unicodecsv as csv
+from past.builtins import basestring
+from past.builtins import unicode
+from six.moves import zip
 
-from airflow import configuration as conf
+import airflow.security.utils as utils
+from airflow import configuration
 from airflow.exceptions import AirflowException
 from airflow.hooks.base_hook import BaseHook
-from airflow.utils.helpers import as_flattened_list
 from airflow.utils.file import TemporaryDirectory
-from airflow import configuration
-import airflow.security.utils as utils
+from airflow.utils.helpers import as_flattened_list
+from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING
 
 HIVE_QUEUE_PRIORITIES = ['VERY_HIGH', 'HIGH', 'NORMAL', 'LOW', 'VERY_LOW']
 
 
+def get_context_from_env_var():
+    """
+    Extract the task instance context (e.g. dag_id, task_id and execution_date)
+    from environment variables so that it can be used inside BashOperator and PythonOperator.
+    :return: The context of interest.
+    """
+    return {format_map['default']: os.environ.get(format_map['env_var_format'], '')
+            for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()}
+
+
 class HiveCliHook(BaseHook):
     """Simple wrapper around the hive CLI.
 
@@ -92,8 +102,8 @@ class HiveCliHook(BaseHook):
                     "Invalid Mapred Queue Priority.  Valid values are: "
                     "{}".format(', '.join(HIVE_QUEUE_PRIORITIES)))
 
-        self.mapred_queue = mapred_queue or conf.get('hive',
-                                                     'default_hive_mapred_queue')
+        self.mapred_queue = mapred_queue or configuration.get('hive',
+                                                              'default_hive_mapred_queue')
         self.mapred_queue_priority = mapred_queue_priority
         self.mapred_job_name = mapred_job_name
 
@@ -126,6 +136,7 @@ class HiveCliHook(BaseHook):
                 jdbc_url += ";auth=" + self.auth
 
             jdbc_url = jdbc_url.format(**locals())
+            jdbc_url = '"{}"'.format(jdbc_url)
 
             cmd_extra += ['-u', jdbc_url]
             if conn.login:
@@ -184,10 +195,15 @@ class HiveCliHook(BaseHook):
 
         with TemporaryDirectory(prefix='airflow_hiveop_') as tmp_dir:
             with NamedTemporaryFile(dir=tmp_dir) as f:
+                hql = hql + '\n'
                 f.write(hql.encode('UTF-8'))
                 f.flush()
                 hive_cmd = self._prepare_cli_cmd()
-                hive_conf_params = self._prepare_hiveconf(hive_conf)
+                env_context = get_context_from_env_var()
+                # Only extend the hive_conf if it is defined.
+                if hive_conf:
+                    env_context.update(hive_conf)
+                hive_conf_params = self._prepare_hiveconf(env_context)
                 if self.mapred_queue:
                     hive_conf_params.extend(
                         ['-hiveconf',
@@ -772,7 +788,7 @@ class HiveServer2Hook(BaseHook):
             username=db.login or username,
             database=schema or db.schema or 'default')
 
-    def _get_results(self, hql, schema='default', fetch_size=None):
+    def _get_results(self, hql, schema='default', fetch_size=None, hive_conf=None):
         from pyhive.exc import ProgrammingError
         if isinstance(hql, basestring):
             hql = [hql]
@@ -780,12 +796,21 @@ class HiveServer2Hook(BaseHook):
         with contextlib.closing(self.get_conn(schema)) as conn, \
                 contextlib.closing(conn.cursor()) as cur:
             cur.arraysize = fetch_size or 1000
+
+            env_context = get_context_from_env_var()
+            if hive_conf:
+                env_context.update(hive_conf)
+            for k, v in env_context.items():
+                cur.execute("set {}={}".format(k, v))
+
             for statement in hql:
                 cur.execute(statement)
                 # we only get results of statements that returns
                 lowered_statement = statement.lower().strip()
                 if (lowered_statement.startswith('select') or
-                   lowered_statement.startswith('with')):
+                    lowered_statement.startswith('with') or
+                    (lowered_statement.startswith('set') and
+                     '=' not in lowered_statement)):
                     description = [c for c in cur.description]
                     if previous_description and previous_description != description:
                         message = '''The statements are producing different descriptions:
@@ -805,8 +830,17 @@ class HiveServer2Hook(BaseHook):
                     except ProgrammingError:
                         self.log.debug("get_results returned no records")
 
-    def get_results(self, hql, schema='default', fetch_size=None):
-        results_iter = self._get_results(hql, schema, fetch_size=fetch_size)
+    def get_results(self, hql, schema='default', fetch_size=None, hive_conf=None):
+        """
+        Get results of the provided hql in target schema.
+        :param hql: hql to be executed.
+        :param schema: target schema, defaults to 'default'.
+        :param fetch_size: max size of result to fetch.
+        :param hive_conf: hive_conf to execute along with the hql.
+        :return: results of hql execution.
+        """
+        results_iter = self._get_results(hql, schema,
+                                         fetch_size=fetch_size, hive_conf=hive_conf)
         header = next(results_iter)
         results = {
             'data': list(results_iter),
@@ -822,9 +856,23 @@ class HiveServer2Hook(BaseHook):
             delimiter=',',
             lineterminator='\r\n',
             output_header=True,
-            fetch_size=1000):
-
-        results_iter = self._get_results(hql, schema, fetch_size=fetch_size)
+            fetch_size=1000,
+            hive_conf=None):
+        """
+        Execute hql in target schema and write results to a csv file.
+        :param hql: hql to be executed.
+        :param csv_filepath: filepath of csv to write results into.
+        :param schema: target schema, defaults to 'default'.
+        :param delimiter: delimiter of the csv file.
+        :param lineterminator: lineterminator of the csv file.
+        :param output_header: whether to write the column names as a header row in the csv file.
+        :param fetch_size: number of result rows to fetch and write into the csv file at a time.
+        :param hive_conf: hive_conf to execute along with the hql.
+        :return:
+        """
+
+        results_iter = self._get_results(hql, schema,
+                                         fetch_size=fetch_size, hive_conf=hive_conf)
         header = next(results_iter)
         message = None
 

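Rough sketch of how the new context plumbing in HiveCliHook fits together (the values below are made up, and run_cli needs a working hive CLI/connection): run_cli() reads any AIRFLOW_CTX_* variables from the environment via get_context_from_env_var(), merges in the caller-supplied hive_conf, and passes everything to the CLI as -hiveconf key=value pairs, so the query session can read them back with "set airflow.ctx.dag_id;".

    import os

    from airflow.hooks.hive_hooks import HiveCliHook, get_context_from_env_var

    # Simulate the variables a running task instance would export.
    os.environ['AIRFLOW_CTX_DAG_ID'] = 'example_dag'
    os.environ['AIRFLOW_CTX_TASK_ID'] = 'example_task'

    print(get_context_from_env_var())
    # {'airflow.ctx.dag_id': 'example_dag', 'airflow.ctx.task_id': 'example_task',
    #  'airflow.ctx.execution_date': '', 'airflow.ctx.dag_run_id': ''}

    # Both the context and the extra hive_conf are set before the HQL runs.
    hook = HiveCliHook()
    hook.run_cli(hql='set airflow.ctx.dag_id;',
                 hive_conf={'example.key': 'example.value'})
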
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/airflow/operators/bash_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/bash_operator.py b/airflow/operators/bash_operator.py
index 37a19db..17de014 100644
--- a/airflow/operators/bash_operator.py
+++ b/airflow/operators/bash_operator.py
@@ -18,16 +18,18 @@
 # under the License.
 
 
-from builtins import bytes
 import os
 import signal
 from subprocess import Popen, STDOUT, PIPE
 from tempfile import gettempdir, NamedTemporaryFile
 
+from builtins import bytes
+
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.utils.decorators import apply_defaults
 from airflow.utils.file import TemporaryDirectory
+from airflow.utils.operator_helpers import context_to_airflow_vars
 
 
 class BashOperator(BaseOperator):
@@ -73,6 +75,16 @@ class BashOperator(BaseOperator):
         """
         self.log.info("Tmp dir root location: \n %s", gettempdir())
 
+        # Prepare env for child process.
+        if self.env is None:
+            self.env = os.environ.copy()
+        airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
+        self.log.info("Exporting the following env vars:\n" +
+                      '\n'.join(["{}={}".format(k, v)
+                                 for k, v in
+                                 airflow_context_vars.items()]))
+        self.env.update(airflow_context_vars)
+
         self.lineage_data = self.bash_command
 
         with TemporaryDirectory(prefix='airflowtmp') as tmp_dir:

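Minimal sketch of what the BashOperator change enables (dag id, dates and command are placeholders): the bash command can read the task instance context straight from its environment.

    from datetime import datetime

    from airflow import DAG
    from airflow.operators.bash_operator import BashOperator

    dag = DAG('context_env_demo', start_date=datetime(2018, 1, 1),
              schedule_interval=None)

    # execute() exports AIRFLOW_CTX_DAG_ID, AIRFLOW_CTX_TASK_ID,
    # AIRFLOW_CTX_EXECUTION_DATE and AIRFLOW_CTX_DAG_RUN_ID into the
    # child process environment before running the command.
    BashOperator(
        task_id='print_context',
        bash_command='echo "$AIRFLOW_CTX_DAG_ID $AIRFLOW_CTX_EXECUTION_DATE"',
        dag=dag)
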
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/airflow/operators/hive_to_mysql.py
----------------------------------------------------------------------
diff --git a/airflow/operators/hive_to_mysql.py b/airflow/operators/hive_to_mysql.py
index 4dc25a6..882a9d8 100644
--- a/airflow/operators/hive_to_mysql.py
+++ b/airflow/operators/hive_to_mysql.py
@@ -17,12 +17,13 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from tempfile import NamedTemporaryFile
+
 from airflow.hooks.hive_hooks import HiveServer2Hook
 from airflow.hooks.mysql_hook import MySqlHook
 from airflow.models import BaseOperator
 from airflow.utils.decorators import apply_defaults
-
-from tempfile import NamedTemporaryFile
+from airflow.utils.operator_helpers import context_to_airflow_vars
 
 
 class HiveToMySqlTransfer(BaseOperator):
@@ -88,7 +89,8 @@ class HiveToMySqlTransfer(BaseOperator):
         if self.bulk_load:
             tmpfile = NamedTemporaryFile()
             hive.to_csv(self.sql, tmpfile.name, delimiter='\t',
-                        lineterminator='\n', output_header=False)
+                        lineterminator='\n', output_header=False,
+                        hive_conf=context_to_airflow_vars(context))
         else:
             results = hive.get_records(self.sql)
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/airflow/operators/hive_to_samba_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/hive_to_samba_operator.py b/airflow/operators/hive_to_samba_operator.py
index f6978ac..fa2a961 100644
--- a/airflow/operators/hive_to_samba_operator.py
+++ b/airflow/operators/hive_to_samba_operator.py
@@ -23,6 +23,7 @@ from airflow.hooks.hive_hooks import HiveServer2Hook
 from airflow.hooks.samba_hook import SambaHook
 from airflow.models import BaseOperator
 from airflow.utils.decorators import apply_defaults
+from airflow.utils.operator_helpers import context_to_airflow_vars
 
 
 class Hive2SambaOperator(BaseOperator):
@@ -60,6 +61,7 @@ class Hive2SambaOperator(BaseOperator):
         hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
         tmpfile = tempfile.NamedTemporaryFile()
         self.log.info("Fetching file from Hive")
-        hive.to_csv(hql=self.hql, csv_filepath=tmpfile.name)
+        hive.to_csv(hql=self.hql, csv_filepath=tmpfile.name,
+                    hive_conf=context_to_airflow_vars(context))
         self.log.info("Pushing to samba")
         samba.push_from_local(self.destination_filepath, tmpfile.name)

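Both transfer operators now follow the same pattern, sketched below with a hypothetical helper, hql and file path: the task instance context is converted once with context_to_airflow_vars() and handed to the hook as extra hiveconf values, so it ends up on the Hive session that runs the query.

    from airflow.hooks.hive_hooks import HiveServer2Hook
    from airflow.utils.operator_helpers import context_to_airflow_vars

    def fetch_to_csv(context):
        """Hypothetical helper mirroring Hive2SambaOperator.execute()."""
        hive = HiveServer2Hook(hiveserver2_conn_id='hiveserver2_default')
        # The airflow.ctx.* keys are set on the Hive session before the query
        # runs, so they are visible in HiveServer2 logs and `set` output.
        hive.to_csv(hql='SELECT * FROM my_table',
                    csv_filepath='/tmp/my_table.csv',
                    hive_conf=context_to_airflow_vars(context))
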
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/airflow/operators/python_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py
index a564897..678a3de 100644
--- a/airflow/operators/python_operator.py
+++ b/airflow/operators/python_operator.py
@@ -17,21 +17,22 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from builtins import str
-import dill
 import inspect
 import os
 import pickle
 import subprocess
 import sys
 import types
+from textwrap import dedent
+
+import dill
+from builtins import str
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator, SkipMixin
 from airflow.utils.decorators import apply_defaults
 from airflow.utils.file import TemporaryDirectory
-
-from textwrap import dedent
+from airflow.utils.operator_helpers import context_to_airflow_vars
 
 
 class PythonOperator(BaseOperator):
@@ -91,6 +92,13 @@ class PythonOperator(BaseOperator):
             self.template_ext = templates_exts
 
     def execute(self, context):
+        # Export context to make it available for callables to use.
+        airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
+        self.log.info("Exporting the following env vars:\n" +
+                      '\n'.join(["{}={}".format(k, v)
+                                 for k, v in airflow_context_vars.items()]))
+        os.environ.update(airflow_context_vars)
+
         if self.provide_context:
             context.update(self.op_kwargs)
             context['templates_dict'] = self.templates_dict

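Sketch of the effect on Python callables (dag id and callable are illustrative): once execute() has pushed the context into os.environ, even a callable that does not accept the context as arguments can discover which task invoked it.

    import os
    from datetime import datetime

    from airflow import DAG
    from airflow.operators.python_operator import PythonOperator

    def report_where_i_run():
        # Populated by PythonOperator.execute() right before the callable runs.
        print('dag=%s task=%s run=%s' % (
            os.environ.get('AIRFLOW_CTX_DAG_ID'),
            os.environ.get('AIRFLOW_CTX_TASK_ID'),
            os.environ.get('AIRFLOW_CTX_DAG_RUN_ID')))

    dag = DAG('python_context_demo', start_date=datetime(2018, 1, 1),
              schedule_interval=None)
    PythonOperator(task_id='whereami', python_callable=report_where_i_run, dag=dag)
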
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/airflow/utils/operator_helpers.py
----------------------------------------------------------------------
diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py
index 356aa65..e981941 100644
--- a/airflow/utils/operator_helpers.py
+++ b/airflow/utils/operator_helpers.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,32 +18,49 @@
 # under the License.
 #
 
+AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
+    'AIRFLOW_CONTEXT_DAG_ID': {'default': 'airflow.ctx.dag_id',
+                               'env_var_format': 'AIRFLOW_CTX_DAG_ID'},
+    'AIRFLOW_CONTEXT_TASK_ID': {'default': 'airflow.ctx.task_id',
+                                'env_var_format': 'AIRFLOW_CTX_TASK_ID'},
+    'AIRFLOW_CONTEXT_EXECUTION_DATE': {'default': 'airflow.ctx.execution_date',
+                                       'env_var_format': 'AIRFLOW_CTX_EXECUTION_DATE'},
+    'AIRFLOW_CONTEXT_DAG_RUN_ID': {'default': 'airflow.ctx.dag_run_id',
+                                   'env_var_format': 'AIRFLOW_CTX_DAG_RUN_ID'}
+}
+
 
-def context_to_airflow_vars(context):
+def context_to_airflow_vars(context, in_env_var_format=False):
     """
     Given a context, this function provides a dictionary of values that can be used to
     externally reconstruct relations between dags, dag_runs, tasks and task_instances.
+    Defaults to the abc.def.ghi format and can be switched to the ABC_DEF_GHI format
+    if in_env_var_format is set to True.
 
-    :param context: The context for the task_instance of interest
+    :param context: The context for the task_instance of interest.
     :type context: dict
+    :param in_env_var_format: Whether the returned vars should be in ABC_DEF_GHI format.
+    :type in_env_var_format: bool
+    :return: task_instance context as dict.
     """
-    params = {}
-    dag = context.get('dag')
-    if dag and dag.dag_id:
-        params['airflow.ctx.dag.dag_id'] = dag.dag_id
-
-    dag_run = context.get('dag_run')
-    if dag_run and dag_run.execution_date:
-        params['airflow.ctx.dag_run.execution_date'] = dag_run.execution_date.isoformat()
-
-    task = context.get('task')
-    if task and task.task_id:
-        params['airflow.ctx.task.task_id'] = task.task_id
-
+    params = dict()
+    if in_env_var_format:
+        name_format = 'env_var_format'
+    else:
+        name_format = 'default'
     task_instance = context.get('task_instance')
+    if task_instance and task_instance.dag_id:
+        params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID'][
+            name_format]] = task_instance.dag_id
+    if task_instance and task_instance.task_id:
+        params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID'][
+            name_format]] = task_instance.task_id
     if task_instance and task_instance.execution_date:
-        params['airflow.ctx.task_instance.execution_date'] = (
-            task_instance.execution_date.isoformat()
-        )
-
+        params[
+            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][
+                name_format]] = task_instance.execution_date.isoformat()
+    dag_run = context.get('dag_run')
+    if dag_run and dag_run.run_id:
+        params[AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][
+            name_format]] = dag_run.run_id
     return params

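To make the two naming formats concrete, here is a hedged sketch of what context_to_airflow_vars() returns for stand-in task_instance/dag_run objects (the stubs below are not Airflow classes, just attribute holders for illustration).

    from collections import namedtuple
    from datetime import datetime

    from airflow.utils.operator_helpers import context_to_airflow_vars

    # Minimal stand-ins for the task_instance / dag_run objects of a real context.
    TaskInstanceStub = namedtuple('TaskInstanceStub',
                                  ['dag_id', 'task_id', 'execution_date'])
    DagRunStub = namedtuple('DagRunStub', ['run_id'])

    context = {
        'task_instance': TaskInstanceStub('my_dag', 'my_task', datetime(2018, 7, 11)),
        'dag_run': DagRunStub('manual__2018-07-11T00:00:00'),
    }

    print(context_to_airflow_vars(context))
    # {'airflow.ctx.dag_id': 'my_dag', 'airflow.ctx.task_id': 'my_task',
    #  'airflow.ctx.execution_date': '2018-07-11T00:00:00',
    #  'airflow.ctx.dag_run_id': 'manual__2018-07-11T00:00:00'}

    print(context_to_airflow_vars(context, in_env_var_format=True))
    # Same values, keyed as AIRFLOW_CTX_DAG_ID, AIRFLOW_CTX_TASK_ID,
    # AIRFLOW_CTX_EXECUTION_DATE and AIRFLOW_CTX_DAG_RUN_ID.
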
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/tests/hooks/test_hive_hook.py
----------------------------------------------------------------------
diff --git a/tests/hooks/test_hive_hook.py b/tests/hooks/test_hive_hook.py
index 81270c6..1cac74c 100644
--- a/tests/hooks/test_hive_hook.py
+++ b/tests/hooks/test_hive_hook.py
@@ -21,24 +21,22 @@
 import datetime
 import itertools
 import os
-
-import pandas as pd
 import random
-
-import mock
 import unittest
-
 from collections import OrderedDict
+
+import mock
+import pandas as pd
 from hmsclient import HMSClient
 
+from airflow import DAG, configuration
 from airflow.exceptions import AirflowException
 from airflow.hooks.hive_hooks import HiveCliHook, HiveMetastoreHook, HiveServer2Hook
-from airflow import DAG, configuration
 from airflow.operators.hive_operator import HiveOperator
 from airflow.utils import timezone
+from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING
 from airflow.utils.tests import assertEqualIgnoreMultipleSpaces
 
-
 configuration.load_test_config()
 
 
@@ -97,6 +95,39 @@ class TestHiveCliHook(unittest.TestCase):
         hook = HiveCliHook()
         hook.run_cli("SHOW DATABASES")
 
+    def test_run_cli_with_hive_conf(self):
+        hql = "set key;\n" \
+              "set airflow.ctx.dag_id;\nset airflow.ctx.dag_run_id;\n" \
+              "set airflow.ctx.task_id;\nset airflow.ctx.execution_date;\n"
+
+        dag_id_ctx_var_name = \
+            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format']
+        task_id_ctx_var_name = \
+            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format']
+        execution_date_ctx_var_name = \
+            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][
+                'env_var_format']
+        dag_run_id_ctx_var_name = \
+            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][
+                'env_var_format']
+        os.environ[dag_id_ctx_var_name] = 'test_dag_id'
+        os.environ[task_id_ctx_var_name] = 'test_task_id'
+        os.environ[execution_date_ctx_var_name] = 'test_execution_date'
+        os.environ[dag_run_id_ctx_var_name] = 'test_dag_run_id'
+
+        hook = HiveCliHook()
+        output = hook.run_cli(hql=hql, hive_conf={'key': 'value'})
+        self.assertIn('value', output)
+        self.assertIn('test_dag_id', output)
+        self.assertIn('test_task_id', output)
+        self.assertIn('test_execution_date', output)
+        self.assertIn('test_dag_run_id', output)
+
+        del os.environ[dag_id_ctx_var_name]
+        del os.environ[task_id_ctx_var_name]
+        del os.environ[execution_date_ctx_var_name]
+        del os.environ[dag_run_id_ctx_var_name]
+
     @mock.patch('airflow.hooks.hive_hooks.HiveCliHook.run_cli')
     def test_load_file(self, mock_run_cli):
         filepath = "/path/to/input/file"
@@ -419,3 +450,41 @@ class TestHiveServer2Hook(unittest.TestCase):
         hook = HiveServer2Hook()
         results = hook.get_records(sqls, schema=self.database)
         self.assertListEqual(results, [(1, 1), (2, 2)])
+
+    def test_get_results_with_hive_conf(self):
+        hql = ["set key",
+               "set airflow.ctx.dag_id",
+               "set airflow.ctx.dag_run_id",
+               "set airflow.ctx.task_id",
+               "set airflow.ctx.execution_date"]
+
+        dag_id_ctx_var_name = \
+            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format']
+        task_id_ctx_var_name = \
+            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format']
+        execution_date_ctx_var_name = \
+            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][
+                'env_var_format']
+        dag_run_id_ctx_var_name = \
+            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][
+                'env_var_format']
+        os.environ[dag_id_ctx_var_name] = 'test_dag_id'
+        os.environ[task_id_ctx_var_name] = 'test_task_id'
+        os.environ[execution_date_ctx_var_name] = 'test_execution_date'
+        os.environ[dag_run_id_ctx_var_name] = 'test_dag_run_id'
+
+        hook = HiveServer2Hook()
+        output = '\n'.join(res_tuple[0]
+                           for res_tuple
+                           in hook.get_results(hql=hql,
+                                               hive_conf={'key': 'value'})['data'])
+        self.assertIn('value', output)
+        self.assertIn('test_dag_id', output)
+        self.assertIn('test_task_id', output)
+        self.assertIn('test_execution_date', output)
+        self.assertIn('test_dag_run_id', output)
+
+        del os.environ[dag_id_ctx_var_name]
+        del os.environ[task_id_ctx_var_name]
+        del os.environ[execution_date_ctx_var_name]
+        del os.environ[dag_run_id_ctx_var_name]

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/tests/operators/bash_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/bash_operator.py b/tests/operators/bash_operator.py
index 1ce77e9..e0a0ff3 100644
--- a/tests/operators/bash_operator.py
+++ b/tests/operators/bash_operator.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import unittest
 import os
+import unittest
 from datetime import datetime, timedelta
 
 from airflow import DAG
@@ -59,7 +59,11 @@ class BashOperatorTestCase(unittest.TestCase):
                 task_id='echo_env_vars',
                 dag=self.dag,
                 bash_command='echo $AIRFLOW_HOME>> {0};'
-                             'echo $PYTHONPATH>> {0};'.format(fname)
+                             'echo $PYTHONPATH>> {0};'
+                             'echo $AIRFLOW_CTX_DAG_ID >> {0};'
+                             'echo $AIRFLOW_CTX_TASK_ID>> {0};'
+                             'echo $AIRFLOW_CTX_EXECUTION_DATE>> {0};'
+                             'echo $AIRFLOW_CTX_DAG_RUN_ID>> {0};'.format(fname)
             )
             os.environ['AIRFLOW_HOME'] = 'MY_PATH_TO_AIRFLOW_HOME'
             t.run(DEFAULT_DATE, DEFAULT_DATE,
@@ -70,3 +74,7 @@ class BashOperatorTestCase(unittest.TestCase):
                 self.assertIn('MY_PATH_TO_AIRFLOW_HOME', output)
                 # exported in run_unit_tests.sh as part of PYTHONPATH
                 self.assertIn('tests/test_utils', output)
+                self.assertIn('bash_op_test', output)
+                self.assertIn('echo_env_vars', output)
+                self.assertIn(DEFAULT_DATE.isoformat(), output)
+                self.assertIn('manual__' + DEFAULT_DATE.isoformat(), output)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/tests/operators/python_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/python_operator.py b/tests/operators/python_operator.py
index 43aa8a6..735a4d7 100644
--- a/tests/operators/python_operator.py
+++ b/tests/operators/python_operator.py
@@ -20,28 +20,43 @@
 from __future__ import print_function, unicode_literals
 
 import copy
-import datetime
+import logging
+import os
 import unittest
+from datetime import timedelta
 
-from airflow import configuration, DAG
-from airflow.models import TaskInstance as TI
+from airflow import configuration
+from airflow.exceptions import AirflowException
+from airflow.models import TaskInstance as TI, DAG, DagRun
+from airflow.operators.dummy_operator import DummyOperator
 from airflow.operators.python_operator import PythonOperator, BranchPythonOperator
 from airflow.operators.python_operator import ShortCircuitOperator
-from airflow.operators.dummy_operator import DummyOperator
 from airflow.settings import Session
 from airflow.utils import timezone
 from airflow.utils.state import State
 
-from airflow.exceptions import AirflowException
-import logging
-
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
 END_DATE = timezone.datetime(2016, 1, 2)
-INTERVAL = datetime.timedelta(hours=12)
+INTERVAL = timedelta(hours=12)
 FROZEN_NOW = timezone.datetime(2016, 1, 2, 12, 1, 1)
 
+TI_CONTEXT_ENV_VARS = ['AIRFLOW_CTX_DAG_ID',
+                       'AIRFLOW_CTX_TASK_ID',
+                       'AIRFLOW_CTX_EXECUTION_DATE',
+                       'AIRFLOW_CTX_DAG_RUN_ID']
+
 
 class PythonOperatorTest(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        super(PythonOperatorTest, cls).setUpClass()
+
+        session = Session()
+
+        session.query(DagRun).delete()
+        session.query(TI).delete()
+        session.commit()
+        session.close()
 
     def setUp(self):
         super(PythonOperatorTest, self).setUp()
@@ -56,6 +71,21 @@ class PythonOperatorTest(unittest.TestCase):
         self.clear_run()
         self.addCleanup(self.clear_run)
 
+    def tearDown(self):
+        super(PythonOperatorTest, self).tearDown()
+
+        session = Session()
+
+        session.query(DagRun).delete()
+        session.query(TI).delete()
+        print(len(session.query(DagRun).all()))
+        session.commit()
+        session.close()
+
+        for var in TI_CONTEXT_ENV_VARS:
+            if var in os.environ:
+                del os.environ[var]
+
     def do_run(self):
         self.run = True
 
@@ -107,8 +137,46 @@ class PythonOperatorTest(unittest.TestCase):
         self.assertEquals(id(original_task.python_callable),
                           id(new_task.python_callable))
 
+    def _env_var_check_callback(self):
+        self.assertEqual('test_dag', os.environ['AIRFLOW_CTX_DAG_ID'])
+        self.assertEqual('hive_in_python_op', os.environ['AIRFLOW_CTX_TASK_ID'])
+        self.assertEqual(DEFAULT_DATE.isoformat(),
+                         os.environ['AIRFLOW_CTX_EXECUTION_DATE'])
+        self.assertEqual('manual__' + DEFAULT_DATE.isoformat(),
+                         os.environ['AIRFLOW_CTX_DAG_RUN_ID'])
+
+    def test_echo_env_variables(self):
+        """
+        Test that env variables are exported correctly to the
+        python callback in the task.
+        """
+        self.dag.create_dagrun(
+            run_id='manual__' + DEFAULT_DATE.isoformat(),
+            execution_date=DEFAULT_DATE,
+            start_date=DEFAULT_DATE,
+            state=State.RUNNING,
+            external_trigger=False,
+        )
+
+        t = PythonOperator(task_id='hive_in_python_op',
+                           dag=self.dag,
+                           python_callable=self._env_var_check_callback
+                           )
+        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
 
 class BranchOperatorTest(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        super(BranchOperatorTest, cls).setUpClass()
+
+        session = Session()
+
+        session.query(DagRun).delete()
+        session.query(TI).delete()
+        session.commit()
+        session.close()
+
     def setUp(self):
         self.dag = DAG('branch_operator_test',
                        default_args={
@@ -125,6 +193,17 @@ class BranchOperatorTest(unittest.TestCase):
         self.branch_2.set_upstream(self.branch_op)
         self.dag.clear()
 
+    def tearDown(self):
+        super(BranchOperatorTest, self).tearDown()
+
+        session = Session()
+
+        session.query(DagRun).delete()
+        session.query(TI).delete()
+        print(len(session.query(DagRun).all()))
+        session.commit()
+        session.close()
+
     def test_without_dag_run(self):
         """This checks the defensive against non existent tasks in a dag run"""
         self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
@@ -170,6 +249,27 @@ class BranchOperatorTest(unittest.TestCase):
 
 
 class ShortCircuitOperatorTest(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        super(ShortCircuitOperatorTest, cls).setUpClass()
+
+        session = Session()
+
+        session.query(DagRun).delete()
+        session.query(TI).delete()
+        session.commit()
+        session.close()
+
+    def tearDown(self):
+        super(ShortCircuitOperatorTest, self).tearDown()
+
+        session = Session()
+
+        session.query(DagRun).delete()
+        session.query(TI).delete()
+        session.commit()
+        session.close()
+
     def test_without_dag_run(self):
         """This checks the defensive against non existent tasks in a dag run"""
         value = False

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e9babff4/tests/utils/test_operator_helpers.py
----------------------------------------------------------------------
diff --git a/tests/utils/test_operator_helpers.py b/tests/utils/test_operator_helpers.py
index 592b456..a358601 100644
--- a/tests/utils/test_operator_helpers.py
+++ b/tests/utils/test_operator_helpers.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -17,9 +17,10 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import unittest
 from datetime import datetime
+
 import mock
-import unittest
 
 from airflow.utils import operator_helpers
 
@@ -31,16 +32,18 @@ class TestOperatorHelpers(unittest.TestCase):
         self.dag_id = 'dag_id'
         self.task_id = 'task_id'
         self.execution_date = '2017-05-21T00:00:00'
+        self.dag_run_id = 'dag_run_id'
         self.context = {
-            'dag': mock.MagicMock(name='dag', dag_id=self.dag_id),
             'dag_run': mock.MagicMock(
                 name='dag_run',
+                run_id=self.dag_run_id,
                 execution_date=datetime.strptime(self.execution_date,
                                                  '%Y-%m-%dT%H:%M:%S'),
             ),
-            'task': mock.MagicMock(name='task', task_id=self.task_id),
             'task_instance': mock.MagicMock(
                 name='task_instance',
+                task_id=self.task_id,
+                dag_id=self.dag_id,
                 execution_date=datetime.strptime(self.execution_date,
                                                  '%Y-%m-%dT%H:%M:%S'),
             ),
@@ -53,10 +56,21 @@ class TestOperatorHelpers(unittest.TestCase):
         self.assertDictEqual(
             operator_helpers.context_to_airflow_vars(self.context),
             {
-                'airflow.ctx.dag.dag_id': self.dag_id,
-                'airflow.ctx.dag_run.execution_date': self.execution_date,
-                'airflow.ctx.task.task_id': self.task_id,
-                'airflow.ctx.task_instance.execution_date': self.execution_date,
+                'airflow.ctx.dag_id': self.dag_id,
+                'airflow.ctx.execution_date': self.execution_date,
+                'airflow.ctx.task_id': self.task_id,
+                'airflow.ctx.dag_run_id': self.dag_run_id,
+            }
+        )
+
+        self.assertDictEqual(
+            operator_helpers.context_to_airflow_vars(self.context,
+                                                     in_env_var_format=True),
+            {
+                'AIRFLOW_CTX_DAG_ID': self.dag_id,
+                'AIRFLOW_CTX_EXECUTION_DATE': self.execution_date,
+                'AIRFLOW_CTX_TASK_ID': self.task_id,
+                'AIRFLOW_CTX_DAG_RUN_ID': self.dag_run_id,
             }
         )
 
