shahar1 commented on code in PR #64900:
URL: https://github.com/apache/airflow/pull/64900#discussion_r3287206802


##########
providers/google/src/airflow/providers/google/cloud/operators/dataflow.py:
##########
@@ -1176,3 +1178,145 @@ def execute(self, context: Context):
             raise AirflowException(self.response)
 
         return None
+
+
+class DataflowJobMetricsOperator(GoogleCloudBaseOperator):
+    """
+    Fetches metrics for a single Dataflow job and executes a callback function 
with the result.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:DataflowJobMetricsOperator`
+
+    :param job_id: Dataflow job ID. Jinja-templated.
+    :param callback: Callback function that accepts the metrics list.
+        If provided, the function is called with the metrics and its result is 
returned.
+        If not provided, metrics are pushed to XCom and returned directly.
+        See: 
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/MetricUpdate
+    :param fail_on_terminal_state: If set to True, raises an exception when 
the job
+        is in a terminal state. Default is False.

Review Comment:
   The docstring says the default is `False`, but it is `True` in the code below.



##########
providers/google/tests/system/google/cloud/dataflow/example_dataflow_get_metrics.py:
##########
@@ -0,0 +1,133 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example Airflow DAG for Google Cloud Dataflow Get Metrics Operator.
+
+This DAG demonstrates how to use DataflowJobMetricsOperator to:
+1. Collect metrics from a Dataflow job
+2. Pass metrics to a callback function for processing
+3. Return metrics directly for XCom consumption when no callback is provided
+4. Use deferrable mode for async execution
+5. Consume metrics from XCom in downstream tasks
+"""
+
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow import DAG
+from airflow.operators.empty import EmptyOperator
+from airflow.operators.python import PythonOperator

Review Comment:
   1. `DAG` should be imported from `airflow.sdk`.
   2. `PythonOperator` should be imported from `airflow.providers.standard.operators.python`.



##########
providers/google/tests/unit/google/cloud/operators/test_dataflow.py:
##########
@@ -794,3 +817,168 @@ def test_invalid_response(self, sdk_connection_not_found):
             
DataflowDeletePipelineOperator(**init_kwargs).execute(mock.MagicMock()).return_value
 = {
                 "error": {"message": "example error"}
             }
+
+
[email protected]
+def sync_operator():
+    """Create a synchronous DataflowJobMetricsOperator instance."""
+    return DataflowJobMetricsOperator(
+        task_id=TASK_ID,
+        job_id=JOB_ID,
+        project_id=PROJECT_ID,
+        location=LOCATION,
+        gcp_conn_id=GCP_CONN_ID,
+        deferrable=False,
+    )
+
+
[email protected]
+def deferrable_operator():
+    """Create a deferrable DataflowJobMetricsOperator instance."""
+    return DataflowJobMetricsOperator(
+        task_id=TASK_ID,
+        job_id=JOB_ID,
+        project_id=PROJECT_ID,
+        location=LOCATION,
+        gcp_conn_id=GCP_CONN_ID,
+        deferrable=True,
+    )
+
+
+class TestDataflowJobMetricsOperatorLocationValidation:
+    """Test location validation during execution."""
+
+    def test_execute_raises_when_location_is_none(self):
+        """Test that execute raises ValueError when location is None."""
+        op = DataflowJobMetricsOperator(
+            task_id=TASK_ID,
+            job_id=JOB_ID,
+            project_id=PROJECT_ID,
+            location=None,
+            deferrable=False,
+        )
+        with pytest.raises(ValueError, match="DataflowJobMetricsOperator 
requires 'location' to be set"):
+            op.execute(mock.MagicMock())
+
+    def test_execute_raises_when_location_is_empty_string(self):
+        """Test that execute raises ValueError when location is empty 
string."""
+        op = DataflowJobMetricsOperator(
+            task_id=TASK_ID,
+            job_id=JOB_ID,
+            project_id=PROJECT_ID,
+            location="",
+            deferrable=False,
+        )
+        with pytest.raises(ValueError, match="DataflowJobMetricsOperator 
requires 'location' to be set"):
+            op.execute(mock.MagicMock())
+
+
+class TestDataflowJobMetricsOperatorExecuteSync:
+    """Test synchronous execution of DataflowJobMetricsOperator."""
+
+    @mock.patch(f"{OPERATOR_PATH}.DataflowHook")
+    def test_execute_sync_without_callback(self, mock_hook, sync_operator):
+        """Test sync execute without callback."""
+        mock_hook.return_value.fetch_job_metrics_by_id.return_value = 
SAMPLE_METRICS
+        mock_context = {"task_instance": mock.MagicMock()}
+
+        result = sync_operator.execute(mock_context)
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=None,
+        )
+        mock_hook.return_value.fetch_job_metrics_by_id.assert_called_once_with(
+            job_id=JOB_ID,
+            project_id=PROJECT_ID,
+            location=LOCATION,
+        )
+        assert result == SAMPLE_METRICS["metrics"]
+
+    @mock.patch(f"{OPERATOR_PATH}.DataflowHook")
+    def test_execute_sync_raise_exception_on_terminal_state(self, mock_hook):
+        """Test that execute raises exception when job is in terminal state 
with fail_on_terminal_state=True."""
+        operator = DataflowJobMetricsOperator(
+            task_id=TASK_ID,
+            job_id=JOB_ID,
+            project_id=PROJECT_ID,
+            location=LOCATION,
+            fail_on_terminal_state=True,
+            deferrable=False,
+            gcp_conn_id=GCP_CONN_ID,
+        )
+
+        mock_hook.return_value.get_job.return_value = {
+            "id": JOB_ID,
+            "currentState": DataflowJobStatus.JOB_STATE_DONE,
+        }
+
+        with pytest.raises(
+            AirflowException,

Review Comment:
   The operator now raises `RuntimeError`, so this should expect `RuntimeError` instead of `AirflowException`.



##########
providers/google/tests/system/google/cloud/dataflow/example_dataflow_get_metrics.py:
##########
@@ -0,0 +1,133 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example Airflow DAG for Google Cloud Dataflow Get Metrics Operator.
+
+This DAG demonstrates how to use DataflowJobMetricsOperator to:
+1. Collect metrics from a Dataflow job
+2. Pass metrics to a callback function for processing
+3. Return metrics directly for XCom consumption when no callback is provided
+4. Use deferrable mode for async execution
+5. Consume metrics from XCom in downstream tasks
+"""
+
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow import DAG
+from airflow.operators.empty import EmptyOperator
+from airflow.operators.python import PythonOperator
+from airflow.providers.google.cloud.operators.dataflow import 
DataflowJobMetricsOperator
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or "test-project"
+DAG_ID = "dataflow_get_metrics"
+LOCATION = "us-central1"
+
+
+def process_metrics_callback(metrics):
+    """Callback function that processes metrics returned by the operator."""
+    metric_list = metrics if isinstance(metrics, list) else []
+    print(f"Metrics count from callback: {len(metric_list)}")
+    return {"processed_metrics_count": len(metric_list)}
+
+
+def consume_metrics_from_xcom(**context):
+    """Consume and display metrics count from XCom."""
+    task_instance = context["task_instance"]
+    metrics = task_instance.xcom_pull(task_ids="collect_metrics_no_callback", 
key="metrics")

Review Comment:
   It will return `None`, because nothing is pushed to the `metrics` key.
   You should retrieve the result from the `return_value` key instead.



##########
providers/google/tests/unit/google/cloud/operators/test_dataflow.py:
##########
@@ -794,3 +817,168 @@ def test_invalid_response(self, sdk_connection_not_found):
             
DataflowDeletePipelineOperator(**init_kwargs).execute(mock.MagicMock()).return_value
 = {
                 "error": {"message": "example error"}
             }
+
+
[email protected]
+def sync_operator():
+    """Create a synchronous DataflowJobMetricsOperator instance."""
+    return DataflowJobMetricsOperator(
+        task_id=TASK_ID,
+        job_id=JOB_ID,
+        project_id=PROJECT_ID,
+        location=LOCATION,
+        gcp_conn_id=GCP_CONN_ID,
+        deferrable=False,
+    )
+
+
[email protected]
+def deferrable_operator():
+    """Create a deferrable DataflowJobMetricsOperator instance."""
+    return DataflowJobMetricsOperator(
+        task_id=TASK_ID,
+        job_id=JOB_ID,
+        project_id=PROJECT_ID,
+        location=LOCATION,
+        gcp_conn_id=GCP_CONN_ID,
+        deferrable=True,
+    )
+
+
+class TestDataflowJobMetricsOperatorLocationValidation:
+    """Test location validation during execution."""
+
+    def test_execute_raises_when_location_is_none(self):
+        """Test that execute raises ValueError when location is None."""
+        op = DataflowJobMetricsOperator(
+            task_id=TASK_ID,
+            job_id=JOB_ID,
+            project_id=PROJECT_ID,
+            location=None,
+            deferrable=False,
+        )
+        with pytest.raises(ValueError, match="DataflowJobMetricsOperator 
requires 'location' to be set"):
+            op.execute(mock.MagicMock())
+
+    def test_execute_raises_when_location_is_empty_string(self):
+        """Test that execute raises ValueError when location is empty 
string."""
+        op = DataflowJobMetricsOperator(
+            task_id=TASK_ID,
+            job_id=JOB_ID,
+            project_id=PROJECT_ID,
+            location="",
+            deferrable=False,
+        )
+        with pytest.raises(ValueError, match="DataflowJobMetricsOperator 
requires 'location' to be set"):
+            op.execute(mock.MagicMock())
+
+
+class TestDataflowJobMetricsOperatorExecuteSync:
+    """Test synchronous execution of DataflowJobMetricsOperator."""
+
+    @mock.patch(f"{OPERATOR_PATH}.DataflowHook")
+    def test_execute_sync_without_callback(self, mock_hook, sync_operator):
+        """Test sync execute without callback."""
+        mock_hook.return_value.fetch_job_metrics_by_id.return_value = 
SAMPLE_METRICS
+        mock_context = {"task_instance": mock.MagicMock()}
+
+        result = sync_operator.execute(mock_context)
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=None,
+        )
+        mock_hook.return_value.fetch_job_metrics_by_id.assert_called_once_with(
+            job_id=JOB_ID,
+            project_id=PROJECT_ID,
+            location=LOCATION,
+        )
+        assert result == SAMPLE_METRICS["metrics"]
+
+    @mock.patch(f"{OPERATOR_PATH}.DataflowHook")
+    def test_execute_sync_raise_exception_on_terminal_state(self, mock_hook):
+        """Test that execute raises exception when job is in terminal state 
with fail_on_terminal_state=True."""
+        operator = DataflowJobMetricsOperator(
+            task_id=TASK_ID,
+            job_id=JOB_ID,
+            project_id=PROJECT_ID,
+            location=LOCATION,
+            fail_on_terminal_state=True,
+            deferrable=False,
+            gcp_conn_id=GCP_CONN_ID,
+        )
+
+        mock_hook.return_value.get_job.return_value = {
+            "id": JOB_ID,
+            "currentState": DataflowJobStatus.JOB_STATE_DONE,
+        }
+
+        with pytest.raises(
+            AirflowException,
+            match=f"Job with id '{JOB_ID}' is already in terminal state: 
{DataflowJobStatus.JOB_STATE_DONE}",
+        ):
+            operator.execute(mock.MagicMock())
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=None,
+        )
+        mock_hook.return_value.get_job.assert_called_once_with(
+            job_id=JOB_ID,
+            project_id=PROJECT_ID,
+            location=LOCATION,
+        )
+
+
+class TestDataflowJobMetricsOperatorExecuteDeferred:
+    """Test deferrable execution of DataflowJobMetricsOperator."""
+
+    @mock.patch(f"{OPERATOR_PATH}.DataflowJobMetricsOperator.defer")
+    def test_execute_deferred_calls_defer(self, mock_defer, 
deferrable_operator):
+        """Test that deferrable operator calls defer method."""
+        deferrable_operator.execute(mock.MagicMock())
+        mock_defer.assert_called_once()
+
+    @mock.patch(f"{OPERATOR_PATH}.DataflowJobMetricsOperator.defer")
+    def test_execute_deferred_with_correct_trigger(self, mock_defer, 
deferrable_operator):
+        """Test that deferrable operator creates trigger with correct 
parameters."""
+        deferrable_operator.execute(mock.MagicMock())
+
+        _, kwargs = mock_defer.call_args
+        trigger = kwargs["trigger"]
+
+        assert isinstance(trigger, DataflowJobMetricsTrigger)
+        assert trigger.job_id == JOB_ID
+        assert trigger.project_id == PROJECT_ID
+        assert trigger.location == LOCATION
+        assert trigger.gcp_conn_id == GCP_CONN_ID
+        assert trigger.fail_on_terminal_state is False
+        assert kwargs["method_name"] == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
+
+
+class TestDataflowJobMetricsOperatorExecuteComplete:
+    """Test execute_complete callback for DataflowJobMetricsOperator."""
+
+    def test_execute_complete_raises_when_event_is_none(self, sync_operator):
+        """Test that execute_complete raises RuntimeError when event is 
None."""
+        with pytest.raises((RuntimeError, TypeError)):
+            sync_operator.execute_complete(context=mock.MagicMock(), 
event=None)
+
+    def test_execute_complete_raises_on_error_status(self, sync_operator):
+        """Test that execute_complete raises RuntimeError on error status."""
+        with pytest.raises(AirflowException):

Review Comment:
   The operator now raises `RuntimeError`, so this should expect `RuntimeError` instead of `AirflowException`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to