uros-b commented on code in PR #56343:
URL: https://github.com/apache/spark/pull/56343#discussion_r3473403527


##########
python/pyspark/sql/tests/connect/client/test_client.py:
##########
@@ -277,6 +280,64 @@ def test_user_agent_default(self):
             mock.req.client_type, r"^_SPARK_CONNECT_PYTHON spark/[^ ]+ os/[^ ]+ python/[^ ]+$"
         )
 
+    class MultiBatchMockService:
+        """Mock service returning a single arrow batch message whose IPC stream
+        carries multiple RecordBatches. ``declared_row_count`` is the row_count
+        the server claims, allowing tests to exercise both the matching and the
+        mismatching validation paths."""
+
+        def __init__(self, session_id: str, values, declared_row_count: int):
+            self._session_id = session_id
+            self._values = values
+            self._declared_row_count = declared_row_count
+            self.req: Optional[proto.ExecutePlanRequest] = None
+
+        def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata, timeout=None):
+            self.req = req
+            resp = proto.ExecutePlanResponse()
+            resp.session_id = self._session_id
+            resp.operation_id = req.operation_id
+
+            pdf = pd.DataFrame(data={"col1": self._values})
+            schema = pa.Schema.from_pandas(pdf)
+            table = pa.Table.from_pandas(pdf)
+            sink = pa.BufferOutputStream()
+            writer = pa.ipc.new_stream(sink, schema=schema)
+            # Split the data into multiple RecordBatches within one IPC stream.
+            for batch in table.to_batches(max_chunksize=2):
+                writer.write_batch(batch)
+            writer.close()
+
+            resp.arrow_batch.data = sink.getvalue().to_pybytes()
+            resp.arrow_batch.row_count = self._declared_row_count
+            return [resp]
+
+    def test_multiple_record_batches_in_single_arrow_batch(self):
+        # An Arrow IPC stream may carry multiple RecordBatches; row_count is the total
+        # across them and must be validated only after the stream is fully consumed.
+        client = SparkConnectClient("sc://foo/", use_reattachable_execute=False)
+        values = [1, 2, 3, 4]
+        client._stub = self.MultiBatchMockService(
+            client._session_id, values, declared_row_count=len(values)
+        )
+
+        table, _, _ = client.to_table(proto.Plan(), {})
+        self.assertEqual(table.num_rows, len(values))
+        self.assertEqual(table.column("col1").to_pylist(), values)
+
+    def test_multiple_record_batches_row_count_mismatch_raises(self):

Review Comment:
   Nit: These two tests share identical client/stub setup and differ only in `declared_row_count`; a small helper would reduce the duplication.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to