This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0bb56315e6 Added `overrides` parameter to CloudRunExecuteJobOperator
(#34874)
0bb56315e6 is described below
commit 0bb56315e664875cd764486bb2090e0a2ef747d8
Author: Chloe Sheasby <[email protected]>
AuthorDate: Wed Oct 25 14:21:41 2023 -0500
Added `overrides` parameter to CloudRunExecuteJobOperator (#34874)
---
airflow/providers/google/cloud/hooks/cloud_run.py | 12 +++-
.../providers/google/cloud/operators/cloud_run.py | 7 ++-
.../operators/cloud/cloud_run.rst | 9 +++
.../providers/google/cloud/hooks/test_cloud_run.py | 15 ++++-
.../google/cloud/operators/test_cloud_run.py | 68 +++++++++++++++++++++-
.../google/cloud/cloud_run/example_cloud_run.py | 32 +++++++++-
6 files changed, 133 insertions(+), 10 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/cloud_run.py
b/airflow/providers/google/cloud/hooks/cloud_run.py
index 6cc3b304dc..8741aa1d4d 100644
--- a/airflow/providers/google/cloud/hooks/cloud_run.py
+++ b/airflow/providers/google/cloud/hooks/cloud_run.py
@@ -18,7 +18,7 @@
from __future__ import annotations
import itertools
-from typing import TYPE_CHECKING, Iterable, Sequence
+from typing import TYPE_CHECKING, Any, Iterable, Sequence
from google.cloud.run_v2 import (
CreateJobRequest,
@@ -113,9 +113,15 @@ class CloudRunHook(GoogleBaseHook):
@GoogleBaseHook.fallback_to_default_project_id
def execute_job(
- self, job_name: str, region: str, project_id: str = PROVIDE_PROJECT_ID
+ self,
+ job_name: str,
+ region: str,
+ project_id: str = PROVIDE_PROJECT_ID,
+ overrides: dict[str, Any] | None = None,
) -> operation.Operation:
-        run_job_request = RunJobRequest(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}")
+        run_job_request = RunJobRequest(
+            name=f"projects/{project_id}/locations/{region}/jobs/{job_name}", overrides=overrides
+        )
operation = self.get_conn().run_job(request=run_job_request)
return operation
diff --git a/airflow/providers/google/cloud/operators/cloud_run.py
b/airflow/providers/google/cloud/operators/cloud_run.py
index ba50ea111d..14d27810da 100644
--- a/airflow/providers/google/cloud/operators/cloud_run.py
+++ b/airflow/providers/google/cloud/operators/cloud_run.py
@@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
from google.cloud.run_v2 import Job
@@ -248,6 +248,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
:param job_name: Required. The name of the job to update.
:param job: Required. The job descriptor containing the new configuration
of the job to update.
The name field will be replaced by job_name
+ :param overrides: Optional map of override values.
:param gcp_conn_id: The connection ID used to connect to Google Cloud.
:param polling_period_seconds: Optional: Control the rate of the poll for
the result of deferrable run.
By default, the trigger will poll every 10 seconds.
@@ -270,6 +271,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
project_id: str,
region: str,
job_name: str,
+ overrides: dict[str, Any] | None = None,
polling_period_seconds: float = 10,
timeout_seconds: float | None = None,
gcp_conn_id: str = "google_cloud_default",
@@ -281,6 +283,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
self.project_id = project_id
self.region = region
self.job_name = job_name
+ self.overrides = overrides
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.polling_period_seconds = polling_period_seconds
@@ -293,7 +296,7 @@ class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain
)
         self.operation = hook.execute_job(
-            region=self.region, project_id=self.project_id, job_name=self.job_name
+            region=self.region, project_id=self.project_id, job_name=self.job_name, overrides=self.overrides
         )
if not self.deferrable:
diff --git a/docs/apache-airflow-providers-google/operators/cloud/cloud_run.rst
b/docs/apache-airflow-providers-google/operators/cloud/cloud_run.rst
index 7c80f86d15..cf90afde68 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/cloud_run.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/cloud_run.rst
@@ -77,6 +77,15 @@ or you can define the same operator in the deferrable mode:
:start-after: [START howto_operator_cloud_run_execute_job_deferrable_mode]
:end-before: [END howto_operator_cloud_run_execute_job_deferrable_mode]
+You can also specify overrides that allow you to give a new entrypoint command to the job and more:
+
+:class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunExecuteJobOperator`
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_cloud_run_execute_job_with_overrides]
+ :end-before: [END howto_operator_cloud_run_execute_job_with_overrides]
Update a job
diff --git a/tests/providers/google/cloud/hooks/test_cloud_run.py
b/tests/providers/google/cloud/hooks/test_cloud_run.py
index c91bc490f3..6a9a4fa898 100644
--- a/tests/providers/google/cloud/hooks/test_cloud_run.py
+++ b/tests/providers/google/cloud/hooks/test_cloud_run.py
@@ -34,7 +34,7 @@ from airflow.providers.google.cloud.hooks.cloud_run import
CloudRunAsyncHook, Cl
from tests.providers.google.cloud.utils.base_gcp_mock import
mock_base_gcp_hook_default_project_id
-class TestCloudBathHook:
+class TestCloudRunHook:
def dummy_get_credentials(self):
pass
@@ -111,9 +111,18 @@ class TestCloudBathHook:
job_name = "job1"
region = "region1"
project_id = "projectid"
-        run_job_request = RunJobRequest(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}")
+        overrides = {
+            "container_overrides": [{"args": ["python", "main.py"]}],
+            "task_count": 1,
+            "timeout": "60s",
+        }
+        run_job_request = RunJobRequest(
+            name=f"projects/{project_id}/locations/{region}/jobs/{job_name}", overrides=overrides
+        )
-        cloud_run_hook.execute_job(job_name=job_name, region=region, project_id=project_id)
+        cloud_run_hook.execute_job(
+            job_name=job_name, region=region, project_id=project_id, overrides=overrides
+        )
cloud_run_hook._client.run_job.assert_called_once_with(request=run_job_request)
@mock.patch(
diff --git a/tests/providers/google/cloud/operators/test_cloud_run.py
b/tests/providers/google/cloud/operators/test_cloud_run.py
index 0fe7779158..152e625a23 100644
--- a/tests/providers/google/cloud/operators/test_cloud_run.py
+++ b/tests/providers/google/cloud/operators/test_cloud_run.py
@@ -96,7 +96,7 @@ class TestCloudRunExecuteJobOperator:
operator.execute(context=mock.MagicMock())
hook_mock.return_value.execute_job.assert_called_once_with(
-            job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID
+            job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID, overrides=None
)
@mock.patch(CLOUD_RUN_HOOK_PATH)
@@ -209,6 +209,72 @@ class TestCloudRunExecuteJobOperator:
result = operator.execute_complete(mock.MagicMock(), event)
assert result["name"] == JOB_NAME
+ @mock.patch(CLOUD_RUN_HOOK_PATH)
+ def test_execute_overrides(self, hook_mock):
+ hook_mock.return_value.get_job.return_value = JOB
+        hook_mock.return_value.execute_job.return_value = self._mock_operation(3, 3, 0)
+
+ overrides = {
+ "container_overrides": [{"args": ["python", "main.py"]}],
+ "task_count": 1,
+ "timeout": "60s",
+ }
+
+ operator = CloudRunExecuteJobOperator(
+ task_id=TASK_ID, project_id=PROJECT_ID, region=REGION,
job_name=JOB_NAME, overrides=overrides
+ )
+
+ operator.execute(context=mock.MagicMock())
+
+ hook_mock.return_value.execute_job.assert_called_once_with(
+ job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID,
overrides=overrides
+ )
+
+ @mock.patch(CLOUD_RUN_HOOK_PATH)
+ def test_execute_overrides_with_invalid_task_count(self, hook_mock):
+ overrides = {
+ "container_overrides": [{"args": ["python", "main.py"]}],
+ "task_count": -1,
+ "timeout": "60s",
+ }
+
+ operator = CloudRunExecuteJobOperator(
+ task_id=TASK_ID, project_id=PROJECT_ID, region=REGION,
job_name=JOB_NAME, overrides=overrides
+ )
+
+ with pytest.raises(AirflowException):
+ operator.execute(context=mock.MagicMock())
+
+ @mock.patch(CLOUD_RUN_HOOK_PATH)
+ def test_execute_overrides_with_invalid_timeout(self, hook_mock):
+ overrides = {
+ "container_overrides": [{"args": ["python", "main.py"]}],
+ "task_count": 1,
+ "timeout": "60",
+ }
+
+ operator = CloudRunExecuteJobOperator(
+ task_id=TASK_ID, project_id=PROJECT_ID, region=REGION,
job_name=JOB_NAME, overrides=overrides
+ )
+
+ with pytest.raises(AirflowException):
+ operator.execute(context=mock.MagicMock())
+
+ @mock.patch(CLOUD_RUN_HOOK_PATH)
+ def test_execute_overrides_with_invalid_container_args(self, hook_mock):
+ overrides = {
+ "container_overrides": [{"name": "job", "args": "python main.py"}],
+ "task_count": 1,
+ "timeout": "60s",
+ }
+
+ operator = CloudRunExecuteJobOperator(
+ task_id=TASK_ID, project_id=PROJECT_ID, region=REGION,
job_name=JOB_NAME, overrides=overrides
+ )
+
+ with pytest.raises(AirflowException):
+ operator.execute(context=mock.MagicMock())
+
def _mock_operation(self, task_count, succeeded_count, failed_count):
operation = mock.MagicMock()
operation.result.return_value = self._mock_execution(task_count,
succeeded_count, failed_count)
diff --git a/tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
b/tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
index 330789d82d..08c82d6eb0 100644
--- a/tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
+++ b/tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
@@ -44,12 +44,14 @@ region = "us-central1"
job_name_prefix = "cloudrun-system-test-job"
job1_name = f"{job_name_prefix}1"
job2_name = f"{job_name_prefix}2"
+job3_name = f"{job_name_prefix}3"
create1_task_name = "create-job1"
create2_task_name = "create-job2"
execute1_task_name = "execute-job1"
execute2_task_name = "execute-job2"
+execute3_task_name = "execute-job3"
update_job1_task_name = "update-job1"
@@ -70,6 +72,9 @@ def _assert_executed_jobs_xcom(ti):
job2_dicts = ti.xcom_pull(task_ids=[execute2_task_name],
key="return_value")
assert job2_name in job2_dicts[0]["name"]
+ job3_dicts = ti.xcom_pull(task_ids=[execute3_task_name],
key="return_value")
+ assert job3_name in job3_dicts[0]["name"]
+
def _assert_created_jobs_xcom(ti):
job1_dicts = ti.xcom_pull(task_ids=[create1_task_name], key="return_value")
@@ -181,6 +186,31 @@ with DAG(
)
# [END howto_operator_cloud_run_execute_job_deferrable_mode]
+ # [START howto_operator_cloud_run_execute_job_with_overrides]
+ overrides = {
+ "container_overrides": [
+ {
+ "name": "job",
+ "args": ["python", "main.py"],
+ "env": [{"name": "ENV_VAR", "value": "value"}],
+ "clearArgs": False,
+ }
+ ],
+ "task_count": 1,
+ "timeout": "60s",
+ }
+
+ execute3 = CloudRunExecuteJobOperator(
+ task_id=execute3_task_name,
+ project_id=PROJECT_ID,
+ region=region,
+ overrides=overrides,
+ job_name=job3_name,
+ dag=dag,
+ deferrable=False,
+ )
+ # [END howto_operator_cloud_run_execute_job_with_overrides]
+
assert_executed_jobs = PythonOperator(
task_id="assert-executed-jobs",
python_callable=_assert_executed_jobs_xcom, dag=dag
)
@@ -237,7 +267,7 @@ with DAG(
(
(create1, create2)
>> assert_created_jobs
- >> (execute1, execute2)
+ >> (execute1, execute2, execute3)
>> assert_executed_jobs
>> list_jobs_limit
>> assert_jobs_limit