cocoa-xu commented on code in PR #1722:
URL: https://github.com/apache/arrow-adbc/pull/1722#discussion_r1664487224


##########
go/adbc/driver/bigquery/record_reader.go:
##########
@@ -0,0 +1,315 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package bigquery
+
+import (
+       "context"
+       "errors"
+       "sync/atomic"
+
+       "cloud.google.com/go/bigquery"
+       "github.com/apache/arrow-adbc/go/adbc"
+       "github.com/apache/arrow/go/v17/arrow"
+       "github.com/apache/arrow/go/v17/arrow/array"
+       "github.com/apache/arrow/go/v17/arrow/ipc"
+       "github.com/apache/arrow/go/v17/arrow/memory"
+       "golang.org/x/sync/errgroup"
+)
+
+type reader struct {
+       refCount   int64
+       schema     *arrow.Schema
+       chs        []chan arrow.Record
+       curChIndex int
+       rec        arrow.Record
+       err        error
+
+       cancelFn context.CancelFunc
+}
+
+func checkContext(ctx context.Context, maybeErr error) error {
+       if maybeErr != nil {
+               return maybeErr
+       } else if errors.Is(ctx.Err(), context.Canceled) {
+               return adbc.Error{Msg: ctx.Err().Error(), Code: 
adbc.StatusCancelled}
+       } else if errors.Is(ctx.Err(), context.DeadlineExceeded) {
+               return adbc.Error{Msg: ctx.Err().Error(), Code: 
adbc.StatusTimeout}
+       }
+       return ctx.Err()
+}
+
+func runQuery(ctx context.Context, query *bigquery.Query, executeUpdate bool) 
(bigquery.ArrowIterator, int64, error) {
+       job, err := query.Run(ctx)
+       if err != nil {
+               return nil, -1, err
+       }
+       if executeUpdate {
+               return nil, 0, nil
+       }
+
+       iter, err := job.Read(ctx)
+       if err != nil {
+               return nil, -1, err
+       }
+       arrowIterator, err := iter.ArrowIterator()
+       if err != nil {
+               return nil, -1, err
+       }
+       totalRows := int64(iter.TotalRows)
+       return arrowIterator, totalRows, nil
+}
+
+func ipcReaderFromArrowIterator(arrowIterator bigquery.ArrowIterator, alloc 
memory.Allocator) (*ipc.Reader, error) {
+       arrowItReader := bigquery.NewArrowIteratorReader(arrowIterator)
+       return ipc.NewReader(arrowItReader, ipc.WithAllocator(alloc))
+}
+
+func getQueryParameter(values arrow.Record, row int, parameterMode string) 
([]bigquery.QueryParameter, error) {
+       parameters := make([]bigquery.QueryParameter, values.NumCols())
+       includeName := parameterMode == OptionValueQueryParameterModeNamed
+       schema := values.Schema()
+       for i, v := range values.Columns() {
+               pi, err := arrowValueToQueryParameterValue(schema.Field(i), v, 
row)
+               if err != nil {
+                       return nil, err
+               }
+               parameters[i] = pi
+               if includeName {
+                       parameters[i].Name = values.ColumnName(i)
+               }
+       }
+       return parameters, nil
+}
+
+func runPlainQuery(ctx context.Context, query *bigquery.Query, alloc 
memory.Allocator, resultRecordBufferSize int) (bigqueryRdr *reader, totalRows 
int64, err error) {
+       arrowIterator, totalRows, err := runQuery(ctx, query, false)
+       if err != nil {
+               return nil, -1, err
+       }
+       rdr, err := ipcReaderFromArrowIterator(arrowIterator, alloc)
+       if err != nil {
+               return nil, -1, err
+       }
+
+       chs := make([]chan arrow.Record, 1)
+       ctx, cancelFn := context.WithCancel(ctx)
+       ch := make(chan arrow.Record, resultRecordBufferSize)
+       chs[0] = ch
+
+       defer func() {
+               if err != nil {
+                       close(ch)
+                       cancelFn()
+               }
+       }()
+
+       bigqueryRdr = &reader{
+               refCount:   1,
+               chs:        chs,
+               curChIndex: 0,
+               err:        nil,
+               cancelFn:   cancelFn,
+               schema:     nil,
+       }
+
+       go func() {
+               defer rdr.Release()
+               for rdr.Next() && ctx.Err() == nil {
+                       rec := rdr.Record()
+                       rec.Retain()
+                       ch <- rec
+               }
+
+               err = checkContext(ctx, rdr.Err())
+               defer close(ch)
+       }()
+       return bigqueryRdr, totalRows, nil
+}
+
+// kicks off a goroutine for each endpoint and returns a reader which
+// gathers all of the records as they come in.
+func newRecordReader(ctx context.Context, query *bigquery.Query, 
boundParameters array.RecordReader, parameterMode string, alloc 
memory.Allocator, resultRecordBufferSize, prefetchConcurrency int) (bigqueryRdr 
*reader, totalRows int64, err error) {
+       if boundParameters == nil {
+               return runPlainQuery(ctx, query, alloc, resultRecordBufferSize)
+       }
+
+       recs := make([]arrow.Record, 0)
+       for boundParameters.Next() {
+               rec := boundParameters.Record()
+               recs = append(recs, rec)
+       }

Review Comment:
   Hi @zeroshade, thanks for this advice! I was looking at Snowflake's 
implementation as a reference, and here, if I'm not wrong, the idea is to count 
how many channels we need so that each Record batch and its results can be 
put in a separate channel (and the requests would run in their own goroutines).
   
   ```go
   batches := int64(len(recs))
   chs := make([]chan arrow.Record, batches)
   ```



##########
go/adbc/driver/bigquery/record_reader.go:
##########
@@ -0,0 +1,315 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package bigquery
+
+import (
+       "context"
+       "errors"
+       "sync/atomic"
+
+       "cloud.google.com/go/bigquery"
+       "github.com/apache/arrow-adbc/go/adbc"
+       "github.com/apache/arrow/go/v17/arrow"
+       "github.com/apache/arrow/go/v17/arrow/array"
+       "github.com/apache/arrow/go/v17/arrow/ipc"
+       "github.com/apache/arrow/go/v17/arrow/memory"
+       "golang.org/x/sync/errgroup"
+)
+
+type reader struct {
+       refCount   int64
+       schema     *arrow.Schema
+       chs        []chan arrow.Record
+       curChIndex int
+       rec        arrow.Record
+       err        error
+
+       cancelFn context.CancelFunc
+}
+
+func checkContext(ctx context.Context, maybeErr error) error {
+       if maybeErr != nil {
+               return maybeErr
+       } else if errors.Is(ctx.Err(), context.Canceled) {
+               return adbc.Error{Msg: ctx.Err().Error(), Code: 
adbc.StatusCancelled}
+       } else if errors.Is(ctx.Err(), context.DeadlineExceeded) {
+               return adbc.Error{Msg: ctx.Err().Error(), Code: 
adbc.StatusTimeout}
+       }
+       return ctx.Err()
+}
+
+func runQuery(ctx context.Context, query *bigquery.Query, executeUpdate bool) 
(bigquery.ArrowIterator, int64, error) {
+       job, err := query.Run(ctx)
+       if err != nil {
+               return nil, -1, err
+       }
+       if executeUpdate {
+               return nil, 0, nil
+       }
+
+       iter, err := job.Read(ctx)
+       if err != nil {
+               return nil, -1, err
+       }
+       arrowIterator, err := iter.ArrowIterator()
+       if err != nil {
+               return nil, -1, err
+       }
+       totalRows := int64(iter.TotalRows)
+       return arrowIterator, totalRows, nil
+}
+
+func ipcReaderFromArrowIterator(arrowIterator bigquery.ArrowIterator, alloc 
memory.Allocator) (*ipc.Reader, error) {
+       arrowItReader := bigquery.NewArrowIteratorReader(arrowIterator)
+       return ipc.NewReader(arrowItReader, ipc.WithAllocator(alloc))
+}
+
+func getQueryParameter(values arrow.Record, row int, parameterMode string) 
([]bigquery.QueryParameter, error) {
+       parameters := make([]bigquery.QueryParameter, values.NumCols())
+       includeName := parameterMode == OptionValueQueryParameterModeNamed
+       schema := values.Schema()
+       for i, v := range values.Columns() {
+               pi, err := arrowValueToQueryParameterValue(schema.Field(i), v, 
row)
+               if err != nil {
+                       return nil, err
+               }
+               parameters[i] = pi
+               if includeName {
+                       parameters[i].Name = values.ColumnName(i)
+               }
+       }
+       return parameters, nil
+}
+
+func runPlainQuery(ctx context.Context, query *bigquery.Query, alloc 
memory.Allocator, resultRecordBufferSize int) (bigqueryRdr *reader, totalRows 
int64, err error) {
+       arrowIterator, totalRows, err := runQuery(ctx, query, false)
+       if err != nil {
+               return nil, -1, err
+       }
+       rdr, err := ipcReaderFromArrowIterator(arrowIterator, alloc)
+       if err != nil {
+               return nil, -1, err
+       }
+
+       chs := make([]chan arrow.Record, 1)
+       ctx, cancelFn := context.WithCancel(ctx)
+       ch := make(chan arrow.Record, resultRecordBufferSize)
+       chs[0] = ch
+
+       defer func() {
+               if err != nil {
+                       close(ch)
+                       cancelFn()
+               }
+       }()
+
+       bigqueryRdr = &reader{
+               refCount:   1,
+               chs:        chs,
+               curChIndex: 0,
+               err:        nil,
+               cancelFn:   cancelFn,
+               schema:     nil,
+       }
+
+       go func() {
+               defer rdr.Release()
+               for rdr.Next() && ctx.Err() == nil {
+                       rec := rdr.Record()
+                       rec.Retain()
+                       ch <- rec
+               }
+
+               err = checkContext(ctx, rdr.Err())
+               defer close(ch)
+       }()
+       return bigqueryRdr, totalRows, nil
+}
+
+// kicks off a goroutine for each endpoint and returns a reader which
+// gathers all of the records as they come in.
+func newRecordReader(ctx context.Context, query *bigquery.Query, 
boundParameters array.RecordReader, parameterMode string, alloc 
memory.Allocator, resultRecordBufferSize, prefetchConcurrency int) (bigqueryRdr 
*reader, totalRows int64, err error) {
+       if boundParameters == nil {
+               return runPlainQuery(ctx, query, alloc, resultRecordBufferSize)
+       }
+
+       recs := make([]arrow.Record, 0)
+       for boundParameters.Next() {
+               rec := boundParameters.Record()
+               recs = append(recs, rec)
+       }

Review Comment:
   Hi @zeroshade, thanks for this advice! I was looking at Snowflake's 
implementation for reference, and here, if I'm not wrong, the idea is to count 
how many channels we need so that each Record batch and its results can be 
put in a separate channel (and the requests would run in their own goroutines).
   
   ```go
   batches := int64(len(recs))
   chs := make([]chan arrow.Record, batches)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to