kaxil commented on code in PR #44229:
URL: https://github.com/apache/airflow/pull/44229#discussion_r1852434168


##########
task_sdk/tests/execution_time/test_supervisor.py:
##########
@@ -225,3 +225,88 @@ def test_run_simple_dag(self, test_dags_dir, 
captured_logs, time_machine):
             "logger": "task",
             "timestamp": "2024-11-07T12:34:56.078901Z",
         } in captured_logs
+
+
+class TestHandleRequest:
+    @pytest.fixture
+    def watched_subprocess(self, mocker):
+        """Fixture to provide a WatchedSubprocess instance."""
+        return WatchedSubprocess(
+            ti_id=uuid7(),
+            pid=12345,
+            stdin=mocker.Mock(),  # Not used in these tests
+            stdout=mocker.Mock(),  # Not used in these tests
+            stderr=mocker.Mock(),  # Not used in these tests
+            client=mocker.Mock(),
+            process=mocker.Mock(),
+        )
+
+    @pytest.mark.parametrize(
+        ["message", "expected_buffer", "client_attr_path", "method_arg", 
"mock_response"],
+        [
+            pytest.param(
+                GetConnection(conn_id="test_conn"),
+                b'{"conn_id":"test_conn","conn_type":"mysql"}\n',
+                "connections.get",
+                "test_conn",
+                ConnectionResult(conn_id="test_conn", conn_type="mysql"),
+                id="get_connection",
+            ),
+            pytest.param(
+                GetVariable(key="test_key"),
+                b'{"key":"test_key","value":"test_value"}\n',
+                "variables.get",
+                "test_key",
+                VariableResult(key="test_key", value="test_value"),
+                id="get_variable",
+            ),
+        ],
+    )
+    def test_handle_requests(
+        self,
+        watched_subprocess,
+        mocker,
+        message,
+        expected_buffer,
+        client_attr_path,
+        method_arg,
+        mock_response,
+    ):
+        """
+        Test handling of different messages to the subprocess. For any new 
message type, add a
+        new parameter set to the `@pytest.mark.parametrize` decorator.
+
+        For each message type, this test:
+
+            1. Sends the message to the subprocess.
+            2. Verifies that the correct client method is called with the 
expected argument.
+            3. Checks that the buffer is updated with the expected response.
+        """
+
+        def _resolve_nested_attr(obj, attr_path):
+            """Helper to resolve nested attributes like 'variables.get'."""
+            attrs = attr_path.split(".")
+            for attr in attrs:
+                obj = getattr(obj, attr)
+            return obj
+
+        # Mock the client method. E.g. `client.variables.get` or 
`client.connections.get`
+        mock_client_method = _resolve_nested_attr(watched_subprocess.client, 
client_attr_path)
+        mock_client_method.return_value = mock_response
+
+        # Mock buffer directly as a real bytearray to avoid TypeError
+        buffer = bytearray()
+        mocker.patch("airflow.sdk.execution_time.supervisor.bytearray", 
return_value=buffer)

Review Comment:
   Indeed, it was hiding a bug.
   
   ```
   FAILED 
task_sdk/tests/execution_time/test_supervisor.py::TestHandleRequest::test_handle_requests[get_variable]
 - assert equals failed
     
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00
  b'{"key":"test_key","value":"test_value"}\n'
     
\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x
     
00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00{"key":"test_key","value":"test_valu
     e"}\n'
   ```
   
   That was because msgspec truncated the buffer to the end of the serialized 
message for us with 
[`encoder.encode_into`](https://jcristharif.com/msgspec/api.html#msgspec.json.Encoder.encode_into).
   
   So we could add `buffer[:] = b""` to clear the buffer while preserving the 
pre-allocated memory. But @ashb and I discussed it and agreed to simplify things 
and remove the buffer handling, since Pydantic doesn't support buffers.



##########
task_sdk/tests/execution_time/test_supervisor.py:
##########
@@ -225,3 +225,88 @@ def test_run_simple_dag(self, test_dags_dir, 
captured_logs, time_machine):
             "logger": "task",
             "timestamp": "2024-11-07T12:34:56.078901Z",
         } in captured_logs
+
+
+class TestHandleRequest:
+    @pytest.fixture
+    def watched_subprocess(self, mocker):
+        """Fixture to provide a WatchedSubprocess instance."""
+        return WatchedSubprocess(
+            ti_id=uuid7(),
+            pid=12345,
+            stdin=mocker.Mock(),  # Not used in these tests
+            stdout=mocker.Mock(),  # Not used in these tests
+            stderr=mocker.Mock(),  # Not used in these tests
+            client=mocker.Mock(),
+            process=mocker.Mock(),
+        )
+
+    @pytest.mark.parametrize(
+        ["message", "expected_buffer", "client_attr_path", "method_arg", 
"mock_response"],
+        [
+            pytest.param(
+                GetConnection(conn_id="test_conn"),
+                b'{"conn_id":"test_conn","conn_type":"mysql"}\n',
+                "connections.get",
+                "test_conn",
+                ConnectionResult(conn_id="test_conn", conn_type="mysql"),
+                id="get_connection",
+            ),
+            pytest.param(
+                GetVariable(key="test_key"),
+                b'{"key":"test_key","value":"test_value"}\n',
+                "variables.get",
+                "test_key",
+                VariableResult(key="test_key", value="test_value"),
+                id="get_variable",
+            ),
+        ],
+    )
+    def test_handle_requests(
+        self,
+        watched_subprocess,
+        mocker,
+        message,
+        expected_buffer,
+        client_attr_path,
+        method_arg,
+        mock_response,
+    ):
+        """
+        Test handling of different messages to the subprocess. For any new 
message type, add a
+        new parameter set to the `@pytest.mark.parametrize` decorator.
+
+        For each message type, this test:
+
+            1. Sends the message to the subprocess.
+            2. Verifies that the correct client method is called with the 
expected argument.
+            3. Checks that the buffer is updated with the expected response.
+        """
+
+        def _resolve_nested_attr(obj, attr_path):
+            """Helper to resolve nested attributes like 'variables.get'."""
+            attrs = attr_path.split(".")
+            for attr in attrs:
+                obj = getattr(obj, attr)
+            return obj
+
+        # Mock the client method. E.g. `client.variables.get` or 
`client.connections.get`
+        mock_client_method = _resolve_nested_attr(watched_subprocess.client, 
client_attr_path)

Review Comment:
   Good shout



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to