This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new fd03dc2933 Fix reraise outside of try block in `AthenaHook.get_output_location` (#36008)
fd03dc2933 is described below
commit fd03dc29336e1331d20de0113993dd5a35353ee0
Author: Andrey Anshin <[email protected]>
AuthorDate: Fri Dec 1 20:53:04 2023 +0400
Fix reraise outside of try block in `AthenaHook.get_output_location` (#36008)
---
airflow/providers/amazon/aws/hooks/athena.py | 25 +++++++++++--------------
tests/providers/amazon/aws/hooks/test_athena.py | 25 +++++++++++++++++++++++++
2 files changed, 36 insertions(+), 14 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/athena.py
b/airflow/providers/amazon/aws/hooks/athena.py
index 3715b4a1bc..04853621f1 100644
--- a/airflow/providers/amazon/aws/hooks/athena.py
+++ b/airflow/providers/amazon/aws/hooks/athena.py
@@ -292,20 +292,17 @@ class AthenaHook(AwsBaseHook):
:param query_execution_id: Id of submitted athena query
"""
- if query_execution_id:
- response =
self.get_query_info(query_execution_id=query_execution_id, use_cache=True)
-
- if response:
- try:
- return
response["QueryExecution"]["ResultConfiguration"]["OutputLocation"]
- except KeyError:
- self.log.error(
- "Error retrieving OutputLocation. Query execution id:
%s", query_execution_id
- )
- raise
- else:
- raise
- raise ValueError("Invalid Query execution id. Query execution id: %s",
query_execution_id)
+ if not query_execution_id:
+ raise ValueError(f"Invalid Query execution id. Query execution id:
{query_execution_id}")
+
+ if not (response :=
self.get_query_info(query_execution_id=query_execution_id, use_cache=True)):
+ raise ValueError(f"Unable to get query information for execution
id: {query_execution_id}")
+
+ try:
+ return
response["QueryExecution"]["ResultConfiguration"]["OutputLocation"]
+ except KeyError:
+ self.log.error("Error retrieving OutputLocation. Query execution
id: %s", query_execution_id)
+ raise
def stop_query(self, query_execution_id: str) -> dict:
"""Cancel the submitted query.
diff --git a/tests/providers/amazon/aws/hooks/test_athena.py
b/tests/providers/amazon/aws/hooks/test_athena.py
index 8f224f0b2d..a61663a8fb 100644
--- a/tests/providers/amazon/aws/hooks/test_athena.py
+++ b/tests/providers/amazon/aws/hooks/test_athena.py
@@ -18,6 +18,8 @@ from __future__ import annotations
from unittest import mock
+import pytest
+
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
MOCK_DATA = {
@@ -197,6 +199,29 @@ class TestAthenaHook:
result =
self.athena.get_output_location(query_execution_id=MOCK_DATA["query_execution_id"])
assert result == "s3://test_bucket/test.csv"
+ @pytest.mark.parametrize(
+ "query_execution_id", [pytest.param("", id="empty-string"),
pytest.param(None, id="none")]
+ )
+ def test_hook_get_output_location_empty_execution_id(self,
query_execution_id):
+ with pytest.raises(ValueError, match="Invalid Query execution id"):
+
self.athena.get_output_location(query_execution_id=query_execution_id)
+
+ @pytest.mark.parametrize("response", [pytest.param({}, id="empty-dict"),
pytest.param(None, id="none")])
+ def test_hook_get_output_location_no_response(self, response):
+ with mock.patch.object(AthenaHook, "get_query_info",
return_value=response) as m:
+ with pytest.raises(ValueError, match="Unable to get query
information"):
+
self.athena.get_output_location(query_execution_id="PLACEHOLDER")
+ m.assert_called_once_with(query_execution_id="PLACEHOLDER",
use_cache=True)
+
+ def test_hook_get_output_location_invalid_response(self, caplog):
+ with mock.patch.object(AthenaHook, "get_query_info") as m:
+ m.return_value = {"foo": "bar"}
+ caplog.clear()
+ caplog.set_level("ERROR")
+ with pytest.raises(KeyError):
+
self.athena.get_output_location(query_execution_id="PLACEHOLDER")
+ assert "Error retrieving OutputLocation" in caplog.text
+
@mock.patch.object(AthenaHook, "get_conn")
def test_hook_get_query_info_caching(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value =
MOCK_QUERY_EXECUTION_OUTPUT