This is an automated email from the ASF dual-hosted git repository.
taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 169ce92249 Consolidate hook management in HiveOperator (#34430)
169ce92249 is described below
commit 169ce92249d700c5ad1a4fdac35ba4feb8feee04
Author: Hussein Awala <[email protected]>
AuthorDate: Mon Sep 18 21:06:44 2023 +0200
Consolidate hook management in HiveOperator (#34430)
* Consolidate hook management in HiveOperator
* use AirflowProviderDeprecationWarning
---
airflow/providers/apache/hive/operators/hive.py | 20 +++++++++++---------
tests/providers/apache/hive/operators/test_hive.py | 10 ++++------
2 files changed, 15 insertions(+), 15 deletions(-)
diff --git a/airflow/providers/apache/hive/operators/hive.py b/airflow/providers/apache/hive/operators/hive.py
index 71bd8ac49b..640943467f 100644
--- a/airflow/providers/apache/hive/operators/hive.py
+++ b/airflow/providers/apache/hive/operators/hive.py
@@ -19,9 +19,13 @@ from __future__ import annotations
import os
import re
+from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
+from deprecated.classic import deprecated
+
from airflow.configuration import conf
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.apache.hive.hooks.hive import HiveCliHook
from airflow.utils import operator_helpers
@@ -116,13 +120,8 @@ class HiveOperator(BaseOperator):
)
self.mapred_job_name_template: str = job_name_template
- # assigned lazily - just for consistency we can create the attribute with a
- # `None` initial value, later it will be populated by the execute method.
- # This also makes `on_kill` implementation consistent since it assumes `self.hook`
- # is defined.
- self.hook: HiveCliHook | None = None
-
- def get_hook(self) -> HiveCliHook:
+ @cached_property
+ def hook(self) -> HiveCliHook:
"""Get Hive cli hook."""
return HiveCliHook(
hive_cli_conn_id=self.hive_cli_conn_id,
@@ -134,6 +133,11 @@ class HiveOperator(BaseOperator):
auth=self.auth,
)
+ @deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning)
+ def get_hook(self) -> HiveCliHook:
+ """Get Hive cli hook."""
+ return self.hook
+
def prepare_template(self) -> None:
if self.hiveconf_jinja_translate:
self.hql = re.sub(r"(\$\{(hiveconf:)?([ a-zA-Z0-9_]*)\})", r"{{ \g<3> }}", self.hql)
@@ -142,7 +146,6 @@ class HiveOperator(BaseOperator):
def execute(self, context: Context) -> None:
self.log.info("Executing: %s", self.hql)
- self.hook = self.get_hook()
# set the mapred_job_name if it's not set with dag, task, execution time info
if not self.mapred_job_name:
@@ -169,7 +172,6 @@ class HiveOperator(BaseOperator):
# existing env vars from impacting behavior.
self.clear_airflow_vars()
- self.hook = self.get_hook()
self.hook.test_hql(hql=self.hql)
def on_kill(self) -> None:
diff --git a/tests/providers/apache/hive/operators/test_hive.py b/tests/providers/apache/hive/operators/test_hive.py
index d64a2f7bc3..f02f69c2a4 100644
--- a/tests/providers/apache/hive/operators/test_hive.py
+++ b/tests/providers/apache/hive/operators/test_hive.py
@@ -41,7 +41,7 @@ class HiveOperatorConfigTest(TestHiveEnvironment):
# just check that the correct default value in test_default.cfg is used
test_config_hive_mapred_queue = conf.get("hive", "default_hive_mapred_queue")
- assert op.get_hook().mapred_queue == test_config_hive_mapred_queue
+ assert op.hook.mapred_queue == test_config_hive_mapred_queue
def test_hive_airflow_default_config_queue_override(self):
specific_mapred_queue = "default"
@@ -54,7 +54,7 @@ class HiveOperatorConfigTest(TestHiveEnvironment):
dag=self.dag,
)
- assert op.get_hook().mapred_queue == specific_mapred_queue
+ assert op.hook.mapred_queue == specific_mapred_queue
class HiveOperatorTest(TestHiveEnvironment):
@@ -75,10 +75,8 @@ class HiveOperatorTest(TestHiveEnvironment):
op.prepare_template()
assert op.hql == "SELECT * FROM ${hiveconf:table} PARTITION (${hiveconf:day});"
- @mock.patch("airflow.providers.apache.hive.operators.hive.HiveOperator.get_hook")
- def test_mapred_job_name(self, mock_get_hook):
- mock_hook = mock.MagicMock()
- mock_get_hook.return_value = mock_hook
+ @mock.patch("airflow.providers.apache.hive.operators.hive.HiveOperator.hook", mock.MagicMock())
+ def test_mapred_job_name(self, mock_hook):
op = HiveOperator(task_id="test_mapred_job_name", hql=self.hql, dag=self.dag)
fake_run_id = "test_mapred_job_name"