ashb commented on code in PR #44229:
URL: https://github.com/apache/airflow/pull/44229#discussion_r1851902463


##########
task_sdk/tests/execution_time/test_supervisor.py:
##########
@@ -225,3 +225,88 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine):
             "logger": "task",
             "timestamp": "2024-11-07T12:34:56.078901Z",
         } in captured_logs
+
+
+class TestHandleRequest:
+    @pytest.fixture
+    def watched_subprocess(self, mocker):
+        """Fixture to provide a WatchedSubprocess instance."""
+        return WatchedSubprocess(
+            ti_id=uuid7(),
+            pid=12345,
+            stdin=mocker.Mock(),  # Not used in these tests
+            stdout=mocker.Mock(),  # Not used in these tests
+            stderr=mocker.Mock(),  # Not used in these tests
+            client=mocker.Mock(),
+            process=mocker.Mock(),
+        )
+
+    @pytest.mark.parametrize(
+        ["message", "expected_buffer", "client_attr_path", "method_arg", "mock_response"],
+        [
+            pytest.param(
+                GetConnection(conn_id="test_conn"),
+                b'{"conn_id":"test_conn","conn_type":"mysql"}\n',
+                "connections.get",
+                "test_conn",
+                ConnectionResult(conn_id="test_conn", conn_type="mysql"),
+                id="get_connection",
+            ),
+            pytest.param(
+                GetVariable(key="test_key"),
+                b'{"key":"test_key","value":"test_value"}\n',
+                "variables.get",
+                "test_key",
+                VariableResult(key="test_key", value="test_value"),
+                id="get_variable",
+            ),
+        ],
+    )
+    def test_handle_requests(
+        self,
+        watched_subprocess,
+        mocker,
+        message,
+        expected_buffer,
+        client_attr_path,
+        method_arg,
+        mock_response,
+    ):
+        """
+        Test handling of different messages to the subprocess. For any new message type, add a
+        new parameter set to the `@pytest.mark.parametrize` decorator.
+
+        For each message type, this test:
+
+            1. Sends the message to the subprocess.
+            2. Verifies that the correct client method is called with the expected argument.
+            3. Checks that the buffer is updated with the expected response.
+        """
+
+        def _resolve_nested_attr(obj, attr_path):
+            """Helper to resolve nested attributes like 'variables.get'."""
+            attrs = attr_path.split(".")
+            for attr in attrs:
+                obj = getattr(obj, attr)
+            return obj
+
+        # Mock the client method. E.g. `client.variables.get` or `client.connections.get`
+        mock_client_method = _resolve_nested_attr(watched_subprocess.client, client_attr_path)
+        mock_client_method.return_value = mock_response
+
+        # Mock buffer directly as a real bytearray to avoid TypeError
+        buffer = bytearray()
+        mocker.patch("airflow.sdk.execution_time.supervisor.bytearray", return_value=buffer)

Review Comment:
   This seems.... odd. What was going on here?



##########
task_sdk/tests/execution_time/test_supervisor.py:
##########
@@ -225,3 +225,88 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine):
             "logger": "task",
             "timestamp": "2024-11-07T12:34:56.078901Z",
         } in captured_logs
+
+
+class TestHandleRequest:
+    @pytest.fixture
+    def watched_subprocess(self, mocker):
+        """Fixture to provide a WatchedSubprocess instance."""
+        return WatchedSubprocess(
+            ti_id=uuid7(),
+            pid=12345,
+            stdin=mocker.Mock(),  # Not used in these tests
+            stdout=mocker.Mock(),  # Not used in these tests
+            stderr=mocker.Mock(),  # Not used in these tests
+            client=mocker.Mock(),
+            process=mocker.Mock(),
+        )
+
+    @pytest.mark.parametrize(
+        ["message", "expected_buffer", "client_attr_path", "method_arg", "mock_response"],
+        [
+            pytest.param(
+                GetConnection(conn_id="test_conn"),
+                b'{"conn_id":"test_conn","conn_type":"mysql"}\n',
+                "connections.get",
+                "test_conn",
+                ConnectionResult(conn_id="test_conn", conn_type="mysql"),
+                id="get_connection",
+            ),
+            pytest.param(
+                GetVariable(key="test_key"),
+                b'{"key":"test_key","value":"test_value"}\n',
+                "variables.get",
+                "test_key",
+                VariableResult(key="test_key", value="test_value"),
+                id="get_variable",
+            ),
+        ],
+    )
+    def test_handle_requests(
+        self,
+        watched_subprocess,
+        mocker,
+        message,
+        expected_buffer,
+        client_attr_path,
+        method_arg,
+        mock_response,
+    ):
+        """
+        Test handling of different messages to the subprocess. For any new message type, add a
+        new parameter set to the `@pytest.mark.parametrize` decorator.
+
+        For each message type, this test:
+
+            1. Sends the message to the subprocess.
+            2. Verifies that the correct client method is called with the expected argument.
+            3. Checks that the buffer is updated with the expected response.
+        """
+
+        def _resolve_nested_attr(obj, attr_path):
+            """Helper to resolve nested attributes like 'variables.get'."""
+            attrs = attr_path.split(".")
+            for attr in attrs:
+                obj = getattr(obj, attr)
+            return obj
+
+        # Mock the client method. E.g. `client.variables.get` or `client.connections.get`
+        mock_client_method = _resolve_nested_attr(watched_subprocess.client, client_attr_path)

Review Comment:
   https://docs.python.org/3/library/operator.html#operator.attrgetter
   
   `from operator import attrgetter` and then:
   
   ```suggestion
           mock_client_method = attrgetter(client_attr_path)(watched_subprocess.client)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to