This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 18fb7055ef0 Add unit tests for Gen AI operator exception handling. (#61790)
18fb7055ef0 is described below
commit 18fb7055ef0c2d425989784060c6b0cfcc9a4604
Author: Yorgos Toprakchioglu <[email protected]>
AuthorDate: Sat Feb 21 19:50:27 2026 +0000
Add unit tests for Gen AI operator exception handling. (#61790)
* Add unit tests for Gen AI operator exception handling.
* Remove strict match assertions from Gen AI operator unit tests.
* Fix ClientError instantiation.
* Fix ClientError exception for lowest-direct-dependencies test.
---
.../unit/google/cloud/operators/test_gen_ai.py | 470 +++++++++++++++++++++
1 file changed, 470 insertions(+)
diff --git a/providers/google/tests/unit/google/cloud/operators/test_gen_ai.py b/providers/google/tests/unit/google/cloud/operators/test_gen_ai.py
index c532f8b3e5c..761a835678c 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_gen_ai.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_gen_ai.py
@@ -18,6 +18,8 @@ from __future__ import annotations
from unittest import mock
+import pytest
+from google.genai.errors import ClientError
from google.genai.types import (
Content,
CreateCachedContentConfig,
@@ -28,6 +30,7 @@ from google.genai.types import (
TuningDataset,
)
+from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.operators.gen_ai import (
GenAICountTokensOperator,
GenAICreateCachedContentOperator,
@@ -368,6 +371,116 @@ class TestGenAIGeminiCreateBatchJobOperator:
mock_hook.return_value.get_batch_job.assert_called_once_with("test-name")
mock_job.model_dump.assert_called_once_with(mode="json")
+    def test_init_retrieve_result_and_not_wait_until_complete_raises_airflow_exception(self):
+        """Check the constructor raises AirflowException for retrieve_result=True without waiting."""
+        with pytest.raises(AirflowException):
+            GenAIGeminiCreateBatchJobOperator(
+                task_id=TASK_ID,
+                project_id=GCP_PROJECT,
+                location=GCP_LOCATION,
+                model=TEST_GEMINI_MODEL,
+                gcp_conn_id=GCP_CONN_ID,
+                impersonation_chain=IMPERSONATION_CHAIN,
+                input_source=TEST_BATCH_JOB_INLINED_REQUESTS,
+                gemini_api_key=TEST_GEMINI_API_KEY,
+                wait_until_complete=False,
+                retrieve_result=True,
+            )
+
+    def test_init_input_source_not_string_raises_airflow_exception(self):
+        """Check the constructor raises AirflowException for results_folder with non-string input."""
+        # NOTE(review): wait_until_complete=False is passed as well; confirm the
+        # AirflowException comes from the input_source/results_folder check and
+        # not from another constructor validation.
+        with pytest.raises(AirflowException):
+            GenAIGeminiCreateBatchJobOperator(
+                task_id=TASK_ID,
+                project_id=GCP_PROJECT,
+                location=GCP_LOCATION,
+                model=TEST_GEMINI_MODEL,
+                gcp_conn_id=GCP_CONN_ID,
+                impersonation_chain=IMPERSONATION_CHAIN,
+                input_source=TEST_BATCH_JOB_INLINED_REQUESTS,
+                gemini_api_key=TEST_GEMINI_API_KEY,
+                wait_until_complete=False,
+                results_folder=TEST_FILE_PATH,
+            )
+
+    def test_init_results_folder_not_exists_raises_airflow_exception(self):
+        """Check the constructor raises AirflowException when results_folder does not exist."""
+        # Assumes TEST_FILE_PATH does not exist as a folder on the test host — TODO confirm.
+        with pytest.raises(AirflowException):
+            GenAIGeminiCreateBatchJobOperator(
+                task_id=TASK_ID,
+                project_id=GCP_PROJECT,
+                location=GCP_LOCATION,
+                model=TEST_GEMINI_MODEL,
+                gcp_conn_id=GCP_CONN_ID,
+                impersonation_chain=IMPERSONATION_CHAIN,
+                input_source=TEST_FILE_NAME,
+                gemini_api_key=TEST_GEMINI_API_KEY,
+                wait_until_complete=False,
+                results_folder=TEST_FILE_PATH,
+            )
+
+    @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+    def test__wait_until_complete_exception_raises_airflow_exception(self, mock_hook):
+        """Check _wait_until_complete raises AirflowException when hook get_batch_job raises."""
+        op = GenAIGeminiCreateBatchJobOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            model=TEST_GEMINI_MODEL,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            input_source=TEST_BATCH_JOB_INLINED_REQUESTS,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+        )
+
+        # Any error from the patched hook should surface as AirflowException.
+        mock_hook.return_value.get_batch_job.side_effect = Exception()
+
+        with pytest.raises(AirflowException):
+            op._wait_until_complete(job=mock.MagicMock())
+
+    @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+    def test_execute_exception_error_raises_airflow_exception(self, mock_hook):
+        """Check execute raises AirflowException when hook create_batch_job raises."""
+        op = GenAIGeminiCreateBatchJobOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            model=TEST_GEMINI_MODEL,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            input_source=TEST_BATCH_JOB_INLINED_REQUESTS,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+        )
+
+        mock_hook.return_value.create_batch_job.side_effect = Exception()
+
+        with pytest.raises(AirflowException):
+            op.execute(context={"ti": mock.MagicMock()})
+
+        # The hook is still built once with the operator's connection settings ...
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+        )
+
+        # ... and the failing call received the operator's job arguments.
+        mock_hook.return_value.create_batch_job.assert_called_once_with(
+            source=TEST_BATCH_JOB_INLINED_REQUESTS,
+            model=TEST_GEMINI_MODEL,
+            create_batch_job_config=None,
+        )
+
+    def test_execute_complete_error_status_raises_airflow_exception(self):
+        """Check execute_complete raises AirflowException for an event with status "error"."""
+        op = GenAIGeminiCreateBatchJobOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            model=TEST_GEMINI_MODEL,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            input_source=TEST_BATCH_JOB_INLINED_REQUESTS,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+        )
+
+        # Shape of the event passed back when the operator resumes; here it reports failure.
+        event = {"status": "error", "message": "test-message"}
+
+        with pytest.raises(AirflowException):
+            op.execute_complete(context={"ti": mock.MagicMock()}, event=event)
+
class TestGenAIGeminiGetBatchJobOperator:
@mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
@@ -414,6 +527,33 @@ class TestGenAIGeminiGetBatchJobOperator:
assert result == expected_return
mock_job.model_dump.assert_called_once_with(mode="json")
+    @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+    def test_execute_value_error_raises_airflow_exception(self, mock_hook):
+        """Check execute raises AirflowException when hook get_batch_job raises ValueError."""
+        op = GenAIGeminiGetBatchJobOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            job_name=TEST_BATCH_JOB_NAME,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+        )
+
+        mock_hook.return_value.get_batch_job.side_effect = ValueError()
+
+        with pytest.raises(AirflowException):
+            op.execute(context={"ti": mock.MagicMock()})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+        )
+
+        mock_hook.return_value.get_batch_job.assert_called_once_with(
+            job_name=TEST_BATCH_JOB_NAME,
+        )
+
class TestGenAIGeminiListBatchJobsOperator:
@mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
@@ -458,6 +598,60 @@ class TestGenAIGeminiDeleteBatchJobOperator:
job_name=TEST_BATCH_JOB_NAME,
)
+    @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+    def test_execute_value_error_raises_airflow_exception(self, mock_hook):
+        """Check execute raises AirflowException when hook delete_batch_job raises ValueError."""
+        op = GenAIGeminiDeleteBatchJobOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            job_name=TEST_BATCH_JOB_NAME,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.delete_batch_job.side_effect = ValueError()
+
+        with pytest.raises(AirflowException):
+            op.execute(context={"ti": mock.MagicMock()})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+        )
+
+        mock_hook.return_value.delete_batch_job.assert_called_once_with(
+            job_name=TEST_BATCH_JOB_NAME,
+        )
+
+    @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+    def test_execute_job_error_raises_airflow_exception(self, mock_hook):
+        """Check execute raises AirflowException when delete_batch_job returns an object with an error."""
+        op = GenAIGeminiDeleteBatchJobOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            job_name=TEST_BATCH_JOB_NAME,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        # The hook call succeeds, but the returned object carries a truthy `error` attribute.
+        mock_hook.return_value.delete_batch_job.return_value = mock.MagicMock(error="Test error")
+
+        with pytest.raises(AirflowException):
+            op.execute(context={"ti": mock.MagicMock()})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+        )
+
+        mock_hook.return_value.delete_batch_job.assert_called_once_with(
+            job_name=TEST_BATCH_JOB_NAME,
+        )
+
class TestGenAIGeminiCancelBatchJobOperator:
@mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
@@ -481,6 +675,33 @@ class TestGenAIGeminiCancelBatchJobOperator:
job_name=TEST_BATCH_JOB_NAME,
)
+    @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+    def test_execute_value_error_raises_airflow_exception(self, mock_hook):
+        """Check execute raises AirflowException when hook cancel_batch_job raises ValueError."""
+        op = GenAIGeminiCancelBatchJobOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            job_name=TEST_BATCH_JOB_NAME,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.cancel_batch_job.side_effect = ValueError()
+
+        with pytest.raises(AirflowException):
+            op.execute(context={"ti": mock.MagicMock()})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+        )
+
+        mock_hook.return_value.cancel_batch_job.assert_called_once_with(
+            job_name=TEST_BATCH_JOB_NAME,
+        )
+
class TestGenAIGeminiCreateEmbeddingsBatchJobOperator:
@mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
@@ -508,6 +729,117 @@ class TestGenAIGeminiCreateEmbeddingsBatchJobOperator:
create_embeddings_config=None,
)
+    def test_init_retrieve_result_and_not_wait_until_complete_raises_airflow_exception(self):
+        """Check the constructor raises AirflowException for retrieve_result=True without waiting."""
+        with pytest.raises(AirflowException):
+            GenAIGeminiCreateEmbeddingsBatchJobOperator(
+                task_id=TASK_ID,
+                project_id=GCP_PROJECT,
+                location=GCP_LOCATION,
+                input_source=TEST_EMBEDDINGS_JOB_INLINED_REQUESTS,
+                model=EMBEDDING_MODEL,
+                gemini_api_key=TEST_GEMINI_API_KEY,
+                gcp_conn_id=GCP_CONN_ID,
+                impersonation_chain=IMPERSONATION_CHAIN,
+                wait_until_complete=False,
+                retrieve_result=True,
+            )
+
+    def test_init_input_source_not_string_raises_airflow_exception(self):
+        """Check the constructor raises AirflowException for results_folder with non-string input."""
+        # NOTE(review): wait_until_complete=False is passed as well; confirm the
+        # AirflowException comes from the input_source/results_folder check and
+        # not from another constructor validation.
+        with pytest.raises(AirflowException):
+            GenAIGeminiCreateEmbeddingsBatchJobOperator(
+                task_id=TASK_ID,
+                project_id=GCP_PROJECT,
+                location=GCP_LOCATION,
+                input_source=TEST_EMBEDDINGS_JOB_INLINED_REQUESTS,
+                model=EMBEDDING_MODEL,
+                gemini_api_key=TEST_GEMINI_API_KEY,
+                gcp_conn_id=GCP_CONN_ID,
+                impersonation_chain=IMPERSONATION_CHAIN,
+                wait_until_complete=False,
+                results_folder=TEST_FILE_PATH,
+            )
+
+    def test_init_results_folder_not_exists_raises_airflow_exception(self):
+        """Check the constructor raises AirflowException when results_folder does not exist."""
+        # Assumes TEST_FILE_PATH does not exist as a folder on the test host — TODO confirm.
+        with pytest.raises(AirflowException):
+            GenAIGeminiCreateEmbeddingsBatchJobOperator(
+                task_id=TASK_ID,
+                project_id=GCP_PROJECT,
+                location=GCP_LOCATION,
+                input_source=TEST_FILE_NAME,
+                model=EMBEDDING_MODEL,
+                gemini_api_key=TEST_GEMINI_API_KEY,
+                gcp_conn_id=GCP_CONN_ID,
+                impersonation_chain=IMPERSONATION_CHAIN,
+                wait_until_complete=False,
+                results_folder=TEST_FILE_PATH,
+            )
+
+    @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+    def test__wait_until_complete_exception_raises_airflow_exception(self, mock_hook):
+        """Check _wait_until_complete raises AirflowException when hook get_batch_job raises."""
+        op = GenAIGeminiCreateEmbeddingsBatchJobOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            input_source=TEST_EMBEDDINGS_JOB_INLINED_REQUESTS,
+            model=EMBEDDING_MODEL,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        # Any error from the patched hook should surface as AirflowException.
+        mock_hook.return_value.get_batch_job.side_effect = Exception()
+
+        with pytest.raises(AirflowException):
+            op._wait_until_complete(job=mock.MagicMock())
+
+    @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+    def test_execute_exception_error_raises_airflow_exception(self, mock_hook):
+        """Check execute raises AirflowException when hook create_embeddings raises."""
+        op = GenAIGeminiCreateEmbeddingsBatchJobOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            input_source=TEST_EMBEDDINGS_JOB_INLINED_REQUESTS,
+            model=EMBEDDING_MODEL,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            wait_until_complete=False,
+        )
+
+        mock_hook.return_value.create_embeddings.side_effect = Exception()
+
+        with pytest.raises(AirflowException):
+            op.execute(context={"ti": mock.MagicMock()})
+
+        # The hook is still built once with the operator's connection settings ...
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+        )
+
+        # ... and the failing call received the operator's job arguments.
+        mock_hook.return_value.create_embeddings.assert_called_once_with(
+            source=TEST_EMBEDDINGS_JOB_INLINED_REQUESTS,
+            model=EMBEDDING_MODEL,
+            create_embeddings_config=None,
+        )
+
+    def test_execute_complete_error_status_raises_airflow_exception(self):
+        """Check execute_complete raises AirflowException for an event with status "error"."""
+        op = GenAIGeminiCreateEmbeddingsBatchJobOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            input_source=TEST_EMBEDDINGS_JOB_INLINED_REQUESTS,
+            model=EMBEDDING_MODEL,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        # Shape of the event passed back when the operator resumes; here it reports failure.
+        event = {"status": "error", "message": "test-message"}
+
+        with pytest.raises(AirflowException):
+            op.execute_complete(context={"ti": mock.MagicMock()}, event=event)
+
class TestGenAIGeminiGetFileOperator:
@mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
@@ -531,6 +863,33 @@ class TestGenAIGeminiGetFileOperator:
file_name=TEST_FILE_NAME,
)
+    @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+    def test_execute_client_error_raises_airflow_exception(self, mock_hook):
+        """Check execute raises AirflowException when hook get_file raises ClientError."""
+        op = GenAIGeminiGetFileOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+            file_name=TEST_FILE_NAME,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        # Built via __new__ so ClientError.__init__ is skipped; presumably this keeps the
+        # test independent of the constructor signature across google-genai versions — confirm.
+        mock_hook.return_value.get_file.side_effect = ClientError.__new__(ClientError)
+
+        with pytest.raises(AirflowException):
+            op.execute(context={"ti": mock.MagicMock()})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+        )
+
+        mock_hook.return_value.get_file.assert_called_once_with(
+            file_name=TEST_FILE_NAME,
+        )
+
class TestGenAIGeminiUploadFileOperator:
@mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
@@ -555,6 +914,90 @@ class TestGenAIGeminiUploadFileOperator:
upload_file_config=None,
)
+    @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+    def test_execute_runtime_error_raises_runtime_error(self, mock_hook):
+        """Check a RuntimeError from hook upload_file propagates out of execute as RuntimeError."""
+        op = GenAIGeminiUploadFileOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            file_path=TEST_FILE_PATH,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.upload_file.side_effect = RuntimeError()
+
+        # Unlike the ValueError/FileNotFoundError cases, RuntimeError is expected as-is.
+        with pytest.raises(RuntimeError):
+            op.execute(context={"ti": mock.MagicMock()})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+        )
+
+        mock_hook.return_value.upload_file.assert_called_once_with(
+            path_to_file=TEST_FILE_PATH,
+            upload_file_config=None,
+        )
+
+    @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+    def test_execute_value_error_raises_airflow_exception(self, mock_hook):
+        """Check execute raises AirflowException when hook upload_file raises ValueError."""
+        op = GenAIGeminiUploadFileOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            file_path=TEST_FILE_PATH,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.upload_file.side_effect = ValueError()
+
+        with pytest.raises(AirflowException):
+            op.execute(context={"ti": mock.MagicMock()})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+        )
+
+        mock_hook.return_value.upload_file.assert_called_once_with(
+            path_to_file=TEST_FILE_PATH,
+            upload_file_config=None,
+        )
+
+    @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+    def test_execute_file_not_found_error_raises_airflow_exception(self, mock_hook):
+        """Check execute raises AirflowException when hook upload_file raises FileNotFoundError."""
+        op = GenAIGeminiUploadFileOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            file_path=TEST_FILE_PATH,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.upload_file.side_effect = FileNotFoundError()
+
+        with pytest.raises(AirflowException):
+            op.execute(context={"ti": mock.MagicMock()})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+        )
+
+        mock_hook.return_value.upload_file.assert_called_once_with(
+            path_to_file=TEST_FILE_PATH,
+            upload_file_config=None,
+        )
+
class TestGenAIGeminiListFilesOperator:
@mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
@@ -597,3 +1040,30 @@ class TestGenAIGeminiDeleteFileOperator:
mock_hook.return_value.delete_file.assert_called_once_with(
file_name=TEST_FILE_NAME,
)
+
+    @mock.patch(GEN_AI_PATH.format("GenAIGeminiAPIHook"))
+    def test_execute_client_error_raises_airflow_exception(self, mock_hook):
+        """Check execute raises AirflowException when hook delete_file raises ClientError."""
+        op = GenAIGeminiDeleteFileOperator(
+            task_id=TASK_ID,
+            project_id=GCP_PROJECT,
+            location=GCP_LOCATION,
+            file_name=TEST_FILE_NAME,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        # Built via __new__ so ClientError.__init__ is skipped; presumably this keeps the
+        # test independent of the constructor signature across google-genai versions — confirm.
+        mock_hook.return_value.delete_file.side_effect = ClientError.__new__(ClientError)
+
+        with pytest.raises(AirflowException):
+            op.execute(context={"ti": mock.MagicMock()})
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            gemini_api_key=TEST_GEMINI_API_KEY,
+        )
+
+        mock_hook.return_value.delete_file.assert_called_once_with(
+            file_name=TEST_FILE_NAME,
+        )