caldempsey commented on code in PR #152:
URL: https://github.com/apache/spark-connect-go/pull/152#discussion_r2319805765
##########
spark/client/client.go:
##########
@@ -434,6 +443,151 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType,
arrow.Table, error) {
}
}
+func (c *ExecutePlanClient) ToRecordBatches(ctx context.Context) (<-chan
arrow.Record, <-chan error, *types.StructType) {
+ recordChan := make(chan arrow.Record, 10)
+ errorChan := make(chan error, 1)
+
+ go func() {
+ defer func() {
+ // Ensure channels are always closed to prevent
goroutine leaks
+ close(recordChan)
+ close(errorChan)
+ }()
+
+ // Explicitly needed when tracking re-attachable execution.
+ c.done = false
+
+ for {
+ // Check for context cancellation before each iteration
+ select {
+ case <-ctx.Done():
+ // Context cancelled - send the error and
return immediately
+ select {
+ case errorChan <- ctx.Err():
+ default:
+ // Channel might be full, but we're
exiting anyway
+ }
+ return
+ default:
+ // Continue with normal processing
+ }
+
+ resp, err := c.responseStream.Recv()
+
+ // Check for context cancellation after potentially
blocking operations
+ select {
+ case <-ctx.Done():
+ select {
+ case errorChan <- ctx.Err():
+ default:
+ }
+ return
+ default:
+ }
+
+ // EOF is received when the last message has been
processed and the stream
+ // finished normally. Handle this FIRST, before any
other processing.
+ if errors.Is(err, io.EOF) {
+ return
+ }
+
+ // If there's any other error, handle it
+ if err != nil {
+ if se := sparkerrors.FromRPCError(err); se !=
nil {
+ select {
+ case errorChan <-
sparkerrors.WithType(se, sparkerrors.ExecutionError):
+ case <-ctx.Done():
+ return
+ }
+ } else {
+ // Unknown error - still send it
+ select {
+ case errorChan <- err:
+ case <-ctx.Done():
+ return
+ }
+ }
+ return
+ }
+
+ // Only proceed if we have a valid response (no error)
+ if resp == nil {
+ continue
+ }
+
+ // Check that the server returned the session ID that
we were expecting
+ // and that it has not changed.
+ if resp.GetSessionId() != c.sessionId {
+ select {
+ case errorChan <-
sparkerrors.WithType(&sparkerrors.InvalidServerSideSessionDetailsError{
+ OwnSessionId: c.sessionId,
+ ReceivedSessionId: resp.GetSessionId(),
+ }, sparkerrors.InvalidServerSideSessionError):
+ case <-ctx.Done():
+ return
+ }
+ return
+ }
+
+ // Check if the response has already the schema set and
if yes, convert
+ // the proto DataType to a StructType.
+ if resp.Schema != nil {
+ c.schema, err =
types.ConvertProtoDataTypeToStructType(resp.Schema)
+ if err != nil {
+ select {
+ case errorChan <-
sparkerrors.WithType(err, sparkerrors.ExecutionError):
+ case <-ctx.Done():
+ return
+ }
+ return
+ }
+ }
+
+ switch x := resp.ResponseType.(type) {
+ case *proto.ExecutePlanResponse_SqlCommandResult_:
+ if val := x.SqlCommandResult.GetRelation(); val
!= nil {
+ c.properties["sql_command_result"] = val
+ }
+
+ case *proto.ExecutePlanResponse_ArrowBatch_:
+ // This is what we want - stream the record
batch
+ record, err :=
types.ReadArrowBatchToRecord(x.ArrowBatch.Data, c.schema)
+ if err != nil {
+ select {
+ case errorChan <- err:
+ case <-ctx.Done():
+ return
+ }
+ return
+ }
+
+ // Try to send the record, but respect context
cancellation
+ select {
+ case recordChan <- record:
+ // Successfully sent
Review Comment:
Left this as-is for now. I think we can punt this down the road, mainly because
I don't think it needs to be the focus of this PR — it could be the focus of a
follow-up one.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]