joellubi commented on code in PR #1722:
URL: https://github.com/apache/arrow-adbc/pull/1722#discussion_r1663202041
##########
go/adbc/driver/bigquery/driver_test.go:
##########

@@ -0,0 +1,1354 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package bigquery_test
+
+import (
+    "bytes"
+    "context"
+    "fmt"
+    "os"
+    "strings"
+    "testing"
+    "time"
+
+    "github.com/apache/arrow-adbc/go/adbc"
+    driver "github.com/apache/arrow-adbc/go/adbc/driver/bigquery"
+    "github.com/apache/arrow/go/v17/arrow"
+    "github.com/apache/arrow/go/v17/arrow/array"
+    "github.com/apache/arrow/go/v17/arrow/decimal128"
+    "github.com/apache/arrow/go/v17/arrow/decimal256"
+    "github.com/apache/arrow/go/v17/arrow/memory"
+    "github.com/google/uuid"
+    "github.com/stretchr/testify/suite"
+)
+
+type BigQueryQuirks struct {
+    mem       *memory.CheckedAllocator
+    authType  string
+    authValue string
+    // catalogName is the same as projectID
+    catalogName string
+    // schemaName is the same as datasetID
+    schemaName string
+}
+
+func (q *BigQueryQuirks) SetupDriver(t *testing.T) adbc.Driver {
+    q.mem = memory.NewCheckedAllocator(memory.DefaultAllocator)
+    return driver.NewDriver(q.mem)
+}
+
+func (q *BigQueryQuirks) TearDownDriver(t *testing.T, _ adbc.Driver) {
+    q.mem.AssertSize(t, 0)
+}
+
+func (q *BigQueryQuirks) DatabaseOptions() map[string]string {
+    return map[string]string{
+        driver.OptionStringAuthType:        q.authType,
+        driver.OptionStringAuthCredentials: q.authValue,
+        driver.OptionStringProjectID:       q.catalogName,
+        driver.OptionStringDatasetID:       q.schemaName,
+    }
+}
+
+func getSqlTypeFromArrowField(f arrow.Field) string {
+    switch f.Type.ID() {
+    case arrow.BOOL:
+        return "BOOLEAN"
+    case arrow.UINT8, arrow.INT8, arrow.UINT16, arrow.INT16, arrow.UINT32, arrow.INT32, arrow.UINT64, arrow.INT64:
+        return "INTEGER"
+    case arrow.FLOAT32, arrow.FLOAT64:
+        return "FLOAT64"
+    case arrow.STRING:
+        return "STRING"
+    case arrow.BINARY, arrow.FIXED_SIZE_BINARY:
+        return "BYTES"
+    case arrow.DATE32, arrow.DATE64:
+        return "DATE"
+    case arrow.TIMESTAMP:
+        return "TIMESTAMP"
+    case arrow.TIME32, arrow.TIME64:
+        return "TIME"
+    case arrow.INTERVAL_MONTHS:
+        return "INTERVAL_MONTHS"
+    case arrow.DECIMAL128:
+        return "NUMERIC"
+    case arrow.DECIMAL256:
+        return "BIGNUMERIC"
+    case arrow.LIST:
+        elem := getSqlTypeFromArrowField(f.Type.(*arrow.ListType).ElemField())
+        return "ARRAY<" + elem + ">"
+    case arrow.STRUCT:
+        fields := f.Type.(*arrow.StructType).Fields()
+        childTypes := make([]string, len(fields))
+        for i, field := range fields {
+            childTypes[i] = fmt.Sprintf("%s %s", field.Name, getSqlTypeFromArrowField(field))
+        }
+        return fmt.Sprintf("STRUCT<%s>", strings.Join(childTypes, ","))
+    default:
+        return ""
+    }
+}
+
+func (q *BigQueryQuirks) quoteTblName(name string) string {
+    return fmt.Sprintf("`%s.%s`", q.schemaName, strings.ReplaceAll(name, "\"", "\"\""))
+}
+
+func (q *BigQueryQuirks) CreateSampleTableWithRecords(tableName string, r arrow.Record) error {
+    var b strings.Builder
+    b.WriteString("CREATE OR REPLACE TABLE ")
+    b.WriteString(q.quoteTblName(tableName))
+    b.WriteString(" (")
+
+    for i := 0; i < int(r.NumCols()); i++ {
+        if i != 0 {
+            b.WriteString(", ")
+        }
+        f := r.Schema().Field(i)
+        b.WriteString(f.Name)
+        b.WriteByte(' ')
+        b.WriteString(getSqlTypeFromArrowField(f))
+    }
+    b.WriteString(")")
+
+    ctx := context.Background()
+    mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+    tmpDriver := driver.NewDriver(mem)
+    db, err := tmpDriver.NewDatabase(q.DatabaseOptions())
+    if err != nil {
+        panic(err)
+    }
+    defer db.Close()
+
+    cnxn, err := db.Open(ctx)
+    if err != nil {
+        panic(err)
+    }
+    defer cnxn.Close()
+
+    stmt, err := cnxn.NewStatement()
+    if err != nil {
+        panic(err)
+    }
+    defer stmt.Close()
+
+    err = stmt.SetOption(driver.OptionBoolQueryUseLegacySQL, "false")
+    if err != nil {
+        panic(err)
+    }
+
+    creationQuery := b.String()
+    err = stmt.SetSqlQuery(creationQuery)
+    if err != nil {
+        panic(err)
+    }
+    _, err = stmt.ExecuteUpdate(ctx)
+    if err != nil {
+        panic(err)
+    }
+
+    // wait for some time before accessing it
+    // BigQuery needs some time to make the table available
+    // otherwise the query will fail with error saying the table cannot be found
+    time.Sleep(5 * time.Second)
+
+    insertQuery := "INSERT INTO " + q.quoteTblName(tableName) + " VALUES ("
+    bindings := strings.Repeat("?,", int(r.NumCols()))
+    insertQuery += bindings[:len(bindings)-1] + ")"
+    err = stmt.Bind(ctx, r)
+    if err != nil {
+        return err
+    }
+    err = stmt.SetSqlQuery(insertQuery)
+    if err != nil {
+        return err
+    }
+    rdr, _, err := stmt.ExecuteQuery(ctx)
+    if err != nil {
+        return err
+    }
+
+    rdr.Release()
+    return nil
+}
+
+func (q *BigQueryQuirks) CreateSampleTableWithStreams(tableName string, rdr array.RecordReader) error {
+    var b strings.Builder
+    b.WriteString("CREATE OR REPLACE TABLE ")
+    b.WriteString(q.quoteTblName(tableName))
+    b.WriteString(" (")
+
+    for i := 0; i < rdr.Schema().NumFields(); i++ {
+        if i != 0 {
+            b.WriteString(", ")
+        }
+        f := rdr.Schema().Field(i)
+        b.WriteString(f.Name)
+        b.WriteByte(' ')
+        b.WriteString(getSqlTypeFromArrowField(f))
+    }
+    b.WriteString(")")
+
+    ctx := context.Background()
+    mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+    tmpDriver := driver.NewDriver(mem)
+    db, err := tmpDriver.NewDatabase(q.DatabaseOptions())
+    if err != nil {
+        panic(err)
+    }
+    defer db.Close()
+
+    cnxn, err := db.Open(ctx)
+    if err != nil {
+        panic(err)
+    }
+    defer cnxn.Close()
+
+    stmt, err := cnxn.NewStatement()
+    if err != nil {
+        panic(err)
+    }
+    defer stmt.Close()
+
+    err = stmt.SetOption(driver.OptionBoolQueryUseLegacySQL, "false")
+    if err != nil {
+        panic(err)
+    }
+
+    creationQuery := b.String()
+    err = stmt.SetSqlQuery(creationQuery)
+    if err != nil {
+        panic(err)
+    }
+    _, err = stmt.ExecuteUpdate(ctx)
+    if err != nil {
+        panic(err)
+    }
+
+    // wait for some time before accessing it
+    // BigQuery needs some time to make the table available
+    // otherwise the query will fail with error saying the table cannot be found
+    time.Sleep(5 * time.Second)
+
+    insertQuery := "INSERT INTO " + q.quoteTblName(tableName) + " VALUES ("
+    bindings := strings.Repeat("?,", rdr.Schema().NumFields())
+    insertQuery += bindings[:len(bindings)-1] + ")"
+    err = stmt.BindStream(ctx, rdr)
+    if err != nil {
+        return err
+    }
+    err = stmt.SetSqlQuery(insertQuery)
+    if err != nil {
+        return err
+    }
+    res, _, err := stmt.ExecuteQuery(ctx)
+    if err != nil {
+        return err
+    }
+
+    res.Release()
+    return nil
+}
+
+func (q *BigQueryQuirks) DropTable(cnxn adbc.Connection, tblname string) error {
+    stmt, err := cnxn.NewStatement()
+    if err != nil {
+        return err
+    }
+    defer stmt.Close()
+
+    if err = stmt.SetSqlQuery(`DROP TABLE IF EXISTS ` + q.quoteTblName(tblname)); err != nil {
+        return err
+    }
+
+    _, err = stmt.ExecuteUpdate(context.Background())
+    return err
+}
+
+func (q *BigQueryQuirks) Alloc() memory.Allocator        { return q.mem }
+func (q *BigQueryQuirks) BindParameter(_ int) string     { return "?" }
+func (q *BigQueryQuirks) SupportsBulkIngest(string) bool { return true }

Review Comment:
   I think this should be set to `false` (which is fine for now). Bulk ingestion is a specific [feature of ADBC](https://arrow.apache.org/adbc/current/format/specification.html#bulk-ingestion) where the `OptionKeyIngestTargetTable` key can simply be set to a table name, and then you don't need to write an insert or load query at all. For example, you can see how the Snowflake driver handles [checking for it](https://github.com/apache/arrow-adbc/blob/d9a92b81f0ccd723aa6a3e8b2fcbea793f2aa37e/go/adbc/driver/snowflake/statement.go#L451-L455). The existing ingestion tests look good anyway, and we'll likely be able to just replace the `Quirks.CreateSampleTableWithRecords` helper with a call to native ingestion.
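   For illustration, here is a minimal sketch of what driving native bulk ingestion looks like from the client side, per the spec linked above. `stmt`, `ctx`, and `rec` stand in for an open `adbc.Statement`, a `context.Context`, and the `arrow.Record` to ingest; `"my_table"` is a placeholder name.

   ```go
   // Sketch only: bulk ingestion via the standard ADBC statement options.
   // Point the statement at a target table instead of setting a SQL query.
   if err := stmt.SetOption(adbc.OptionKeyIngestTargetTable, "my_table"); err != nil {
       return err
   }
   // "create" mode creates the table from the bound schema;
   // adbc.OptionValueIngestModeAppend would append to an existing table instead.
   if err := stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreate); err != nil {
       return err
   }
   // Bind the data and execute; no INSERT statement is written by hand.
   if err := stmt.Bind(ctx, rec); err != nil {
       return err
   }
   if _, err := stmt.ExecuteUpdate(ctx); err != nil {
       return err
   }
   ```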
##########
go/adbc/driver/bigquery/driver_test.go:
##########

@@ -0,0 +1,1354 @@
+func (q *BigQueryQuirks) Alloc() memory.Allocator            { return q.mem }
+func (q *BigQueryQuirks) BindParameter(_ int) string         { return "?" }
+func (q *BigQueryQuirks) SupportsBulkIngest(string) bool     { return true }
+func (q *BigQueryQuirks) SupportsConcurrentStatements() bool { return true }
+func (q *BigQueryQuirks) SupportsCurrentCatalogSchema() bool { return true }
+func (q *BigQueryQuirks) SupportsExecuteSchema() bool        { return true }

Review Comment:
   I don't think this should be `true` either.
   `ExecuteSchema` currently returns an "unimplemented" error, which should fail this test. I haven't run the tests myself, though.
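   Concretely, applying both of the `false` suggestions above to the quirks would look something like this (a sketch; I haven't run it against the validation suite):

   ```go
   // Report these features as unsupported until the BigQuery driver implements them.
   func (q *BigQueryQuirks) SupportsBulkIngest(string) bool { return false }
   func (q *BigQueryQuirks) SupportsExecuteSchema() bool    { return false }
   ```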
##########
go/adbc/driver/bigquery/record_reader.go:
##########

@@ -0,0 +1,315 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package bigquery
+
+import (
+    "context"
+    "errors"
+    "sync/atomic"
+
+    "cloud.google.com/go/bigquery"
+    "github.com/apache/arrow-adbc/go/adbc"
+    "github.com/apache/arrow/go/v17/arrow"
+    "github.com/apache/arrow/go/v17/arrow/array"
+    "github.com/apache/arrow/go/v17/arrow/ipc"
+    "github.com/apache/arrow/go/v17/arrow/memory"
+    "golang.org/x/sync/errgroup"
+)
+
+type reader struct {
+    refCount   int64
+    schema     *arrow.Schema
+    chs        []chan arrow.Record
+    curChIndex int
+    rec        arrow.Record
+    err        error
+
+    cancelFn context.CancelFunc
+}
+
+func checkContext(ctx context.Context, maybeErr error) error {
+    if maybeErr != nil {
+        return maybeErr
+    } else if errors.Is(ctx.Err(), context.Canceled) {
+        return adbc.Error{Msg: ctx.Err().Error(), Code: adbc.StatusCancelled}
+    } else if errors.Is(ctx.Err(), context.DeadlineExceeded) {
+        return adbc.Error{Msg: ctx.Err().Error(), Code: adbc.StatusTimeout}
+    }
+    return ctx.Err()
+}
+
+func runQuery(ctx context.Context, query *bigquery.Query, executeUpdate bool) (bigquery.ArrowIterator, int64, error) {
+    job, err := query.Run(ctx)
+    if err != nil {
+        return nil, -1, err
+    }
+    if executeUpdate {
+        return nil, 0, nil
+    }
+
+    iter, err := job.Read(ctx)
+    if err != nil {
+        return nil, -1, err
+    }
+    arrowIterator, err := iter.ArrowIterator()
+    if err != nil {
+        return nil, -1, err
+    }
+    totalRows := int64(iter.TotalRows)
+    return arrowIterator, totalRows, nil
+}
+
+func ipcReaderFromArrowIterator(arrowIterator bigquery.ArrowIterator, alloc memory.Allocator) (*ipc.Reader, error) {
+    arrowItReader := bigquery.NewArrowIteratorReader(arrowIterator)
+    return ipc.NewReader(arrowItReader, ipc.WithAllocator(alloc))
+}
+
+func getQueryParameter(values arrow.Record, row int, parameterMode string) ([]bigquery.QueryParameter, error) {
+    parameters := make([]bigquery.QueryParameter, values.NumCols())
+    includeName := parameterMode == OptionValueQueryParameterModeNamed
+    schema := values.Schema()
+    for i, v := range values.Columns() {
+        pi, err := arrowValueToQueryParameterValue(schema.Field(i), v, row)
+        if err != nil {
+            return nil, err
+        }
+        parameters[i] = pi
+        if includeName {
+            parameters[i].Name = values.ColumnName(i)
+        }
+    }
+    return parameters, nil
+}
+
+func runPlainQuery(ctx context.Context, query *bigquery.Query, alloc memory.Allocator, resultRecordBufferSize int) (bigqueryRdr *reader, totalRows int64, err error) {
+    arrowIterator, totalRows, err := runQuery(ctx, query, false)
+    if err != nil {
+        return nil, -1, err
+    }
+    rdr, err := ipcReaderFromArrowIterator(arrowIterator, alloc)
+    if err != nil {
+        return nil, -1, err
+    }
+
+    chs := make([]chan arrow.Record, 1)
+    ctx, cancelFn := context.WithCancel(ctx)
+    ch := make(chan arrow.Record, resultRecordBufferSize)
+    chs[0] = ch
+
+    defer func() {
+        if err != nil {
+            close(ch)
+            cancelFn()
+        }
+    }()
+
+    bigqueryRdr = &reader{
+        refCount:   1,
+        chs:        chs,
+        curChIndex: 0,
+        err:        nil,
+        cancelFn:   cancelFn,
+        schema:     nil,
+    }
+
+    go func() {
+        defer rdr.Release()
+        for rdr.Next() && ctx.Err() == nil {
+            rec := rdr.Record()
+            rec.Retain()
+            ch <- rec
+        }
+
+        err = checkContext(ctx, rdr.Err())
+        defer close(ch)
+    }()
+    return bigqueryRdr, totalRows, nil
+}
+
+// kicks off a goroutine for each endpoint and returns a reader which
+// gathers all of the records as they come in.
+func newRecordReader(ctx context.Context, query *bigquery.Query, boundParameters array.RecordReader, parameterMode string, alloc memory.Allocator, resultRecordBufferSize, prefetchConcurrency int) (bigqueryRdr *reader, totalRows int64, err error) {
+    if boundParameters == nil {
+        return runPlainQuery(ctx, query, alloc, resultRecordBufferSize)
+    }
+
+    recs := make([]arrow.Record, 0)
+    for boundParameters.Next() {
+        rec := boundParameters.Record()
+        recs = append(recs, rec)
+    }

Review Comment:
   RecordReaders use reference counting and should be released once they've been consumed.
   ```suggestion
       for boundParameters.Next() {
           rec := boundParameters.Record()
           recs = append(recs, rec)
       }
       defer boundParameters.Release()
   ```
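   As general background on the `array.RecordReader` contract (an illustration, not part of the original review thread): a record returned by `Record()` is only guaranteed to stay valid until the next call to `Next()`, so a consumer that buffers records across iterations typically retains each one before finally releasing the reader:

   ```go
   // Illustrative consume-and-buffer pattern for an array.RecordReader rdr.
   recs := make([]arrow.Record, 0)
   for rdr.Next() {
       rec := rdr.Record()
       rec.Retain() // records are only valid until the next Next() unless retained
       recs = append(recs, rec)
   }
   rdr.Release() // the reader itself is reference counted as well
   ```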
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]