cocoa-xu commented on code in PR #1722: URL: https://github.com/apache/arrow-adbc/pull/1722#discussion_r1662948519
########## go/adbc/driver/bigquery/connection.go: ########## @@ -0,0 +1,730 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package bigquery + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" + "time" + + "cloud.google.com/go/bigquery" + "github.com/apache/arrow-adbc/go/adbc" + "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase" + "github.com/apache/arrow/go/v17/arrow" + "golang.org/x/oauth2" + "google.golang.org/api/option" +) + +type connectionImpl struct { + driverbase.ConnectionImplBase + + authType string + credentials string + clientID string + clientSecret string + refreshToken string + + // catalog is the same as the project id in BigQuery + catalog string + // dbSchema is the same as the dataset id in BigQuery + dbSchema string + // tableID is the default table for statement + tableID string + + resultRecordBufferSize int + prefetchConcurrency int + + client *bigquery.Client +} + +type bigQueryTokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` +} + +// GetCurrentCatalog 
implements driverbase.CurrentNamespacer. +func (c *connectionImpl) GetCurrentCatalog() (string, error) { + return c.catalog, nil +} + +// GetCurrentDbSchema implements driverbase.CurrentNamespacer. +func (c *connectionImpl) GetCurrentDbSchema() (string, error) { + return c.dbSchema, nil +} + +// SetCurrentCatalog implements driverbase.CurrentNamespacer. +func (c *connectionImpl) SetCurrentCatalog(value string) error { + c.catalog = value + return nil +} + +// SetCurrentDbSchema implements driverbase.CurrentNamespacer. +func (c *connectionImpl) SetCurrentDbSchema(value string) error { + sanitized, err := sanitizeDataset(value) + if err != nil { + return err + } + c.dbSchema = sanitized + return nil +} + +// ListTableTypes implements driverbase.TableTypeLister. +func (c *connectionImpl) ListTableTypes(ctx context.Context) ([]string, error) { + return []string{ + string(bigquery.RegularTable), + string(bigquery.ViewTable), + string(bigquery.ExternalTable), + string(bigquery.MaterializedView), + string(bigquery.Snapshot), + }, nil +} + +// SetAutocommit implements driverbase.AutocommitSetter. +func (c *connectionImpl) SetAutocommit(enabled bool) error { + if enabled { + return nil + } + return adbc.Error{ + Code: adbc.StatusNotImplemented, + Msg: "SetAutocommit to `false` is not yet implemented", + } +} + +// Commit commits any pending transactions on this connection, it should +// only be used if autocommit is disabled. +// +// Behavior is undefined if this is mixed with SQL transaction statements. +func (c *connectionImpl) Commit(_ context.Context) error { + return adbc.Error{ + Code: adbc.StatusNotImplemented, + Msg: "Commit not yet implemented for BigQuery driver", + } +} + +// Rollback rolls back any pending transactions. Only used if autocommit +// is disabled. +// +// Behavior is undefined if this is mixed with SQL transaction statements. 
+func (c *connectionImpl) Rollback(_ context.Context) error { + return adbc.Error{ + Code: adbc.StatusNotImplemented, + Msg: "Rollback not yet implemented for BigQuery driver", + } +} + +// Close closes this connection and releases any associated resources. +func (c *connectionImpl) Close() error { + return c.client.Close() +} + +// Metadata methods +// Generally these methods return an array.RecordReader that +// can be consumed to retrieve metadata about the database as Arrow +// data. The returned metadata has an expected schema given in the +// doc strings of the specific methods. Schema fields are nullable +// unless otherwise marked. While no Statement is used in these +// methods, the result set may count as an active statement to the +// driver for the purposes of concurrency management (e.g. if the +// driver has a limit on concurrent active statements and it must +// execute a SQL query internally in order to implement the metadata +// method). +// +// Some methods accept "search pattern" arguments, which are strings +// that can contain the special character "%" to match zero or more +// characters, or "_" to match exactly one character. (See the +// documentation of DatabaseMetaData in JDBC or "Pattern Value Arguments" +// in the ODBC documentation.) Escaping is not currently supported. +// GetObjects gets a hierarchical view of all catalogs, database schemas, +// tables, and columns. 
+// +// The result is an Arrow Dataset with the following schema: +// +// Field Name | Field Type +// ----------------------------|---------------------------- +// catalog_name | utf8 +// catalog_db_schemas | list<DB_SCHEMA_SCHEMA> +// +// DB_SCHEMA_SCHEMA is a Struct with the fields: +// +// Field Name | Field Type +// ----------------------------|---------------------------- +// db_schema_name | utf8 +// db_schema_tables | list<TABLE_SCHEMA> +// +// TABLE_SCHEMA is a Struct with the fields: +// +// Field Name | Field Type +// ----------------------------|---------------------------- +// table_name | utf8 not null +// table_type | utf8 not null +// table_columns | list<COLUMN_SCHEMA> +// table_constraints | list<CONSTRAINT_SCHEMA> +// +// COLUMN_SCHEMA is a Struct with the fields: +// +// Field Name | Field Type | Comments +// ----------------------------|---------------------|--------- +// column_name | utf8 not null | +// ordinal_position | int32 | (1) +// remarks | utf8 | (2) +// xdbc_data_type | int16 | (3) +// xdbc_type_name | utf8 | (3) +// xdbc_column_size | int32 | (3) +// xdbc_decimal_digits | int16 | (3) +// xdbc_num_prec_radix | int16 | (3) +// xdbc_nullable | int16 | (3) +// xdbc_column_def | utf8 | (3) +// xdbc_sql_data_type | int16 | (3) +// xdbc_datetime_sub | int16 | (3) +// xdbc_char_octet_length | int32 | (3) +// xdbc_is_nullable | utf8 | (3) +// xdbc_scope_catalog | utf8 | (3) +// xdbc_scope_schema | utf8 | (3) +// xdbc_scope_table | utf8 | (3) +// xdbc_is_autoincrement | bool | (3) +// xdbc_is_generatedcolumn | utf8 | (3) +// +// 1. The column's ordinal position in the table (starting from 1). +// 2. Database-specific description of the column. +// 3. Optional Value. Should be null if not supported by the driver. +// xdbc_values are meant to provide JDBC/ODBC-compatible metadata +// in an agnostic manner. 
+// +// CONSTRAINT_SCHEMA is a Struct with the fields: +// +// Field Name | Field Type | Comments +// ----------------------------|---------------------|--------- +// constraint_name | utf8 | +// constraint_type | utf8 not null | (1) +// constraint_column_names | list<utf8> not null | (2) +// constraint_column_usage | list<USAGE_SCHEMA> | (3) +// +// 1. One of 'CHECK', 'FOREIGN KEY', 'PRIMARY KEY', or 'UNIQUE'. +// 2. The columns on the current table that are constrained, in order. +// 3. For FOREIGN KEY only, the referenced table and columns. +// +// USAGE_SCHEMA is a Struct with fields: +// +// Field Name | Field Type +// ----------------------------|---------------------------- +// fk_catalog | utf8 +// fk_db_schema | utf8 +// fk_table | utf8 not null +// fk_column_name | utf8 not null +// +// For the parameters: If nil is passed, then that parameter will not +// be filtered by at all. If an empty string, then only objects without +// that property (ie: catalog or db schema) will be returned. +// +// tableName and columnName must be either nil (do not filter by +// table name or column name) or non-empty. +// +// All non-empty, non-nil strings should be a search pattern (as described +// earlier). 
+ +func (c *connectionImpl) GetTableSchema(ctx context.Context, catalog *string, dbSchema *string, tableName string) (*arrow.Schema, error) { + return c.getTableSchemaWithFilter(ctx, catalog, dbSchema, tableName, nil) +} + +// NewStatement initializes a new statement object tied to this connection +func (c *connectionImpl) NewStatement() (adbc.Statement, error) { + return &statement{ + connectionImpl: c, + query: c.client.Query(""), + parameterMode: OptionValueQueryParameterModePositional, + resultRecordBufferSize: c.resultRecordBufferSize, + prefetchConcurrency: c.prefetchConcurrency, + }, nil +} + +func (c *connectionImpl) GetOption(key string) (string, error) { + switch key { + case OptionStringAuthType: + return c.authType, nil + case OptionStringAuthCredentials: + return c.credentials, nil + case OptionStringAuthClientID: + return c.clientID, nil + case OptionStringAuthClientSecret: + return c.clientSecret, nil + case OptionStringAuthRefreshToken: + return c.refreshToken, nil + case OptionStringProjectID: + return c.catalog, nil + case OptionStringDatasetID: + return c.dbSchema, nil + case OptionStringTableID: + return c.tableID, nil + default: + return c.ConnectionImplBase.GetOption(key) + } +} + +func (c *connectionImpl) GetOptionInt(key string) (int64, error) { + switch key { + case OptionIntQueryResultBufferSize: + return int64(c.resultRecordBufferSize), nil + case OptionIntQueryPrefetchConcurrency: + return int64(c.prefetchConcurrency), nil + default: + return c.ConnectionImplBase.GetOptionInt(key) + } +} + +func (c *connectionImpl) SetOptionInt(key string, value int64) error { + switch key { + case OptionIntQueryResultBufferSize: + c.resultRecordBufferSize = int(value) + return nil + case OptionIntQueryPrefetchConcurrency: + c.prefetchConcurrency = int(value) + return nil + default: + return c.ConnectionImplBase.SetOptionInt(key, value) + } +} + +func (c *connectionImpl) newClient(ctx context.Context) error { + if c.catalog == "" { + return adbc.Error{ + 
Code: adbc.StatusInvalidArgument, + Msg: "ProjectID is empty", + } + } + switch c.authType { + case OptionValueAuthTypeJSONCredentialFile, OptionValueAuthTypeJSONCredentialString, OptionValueAuthTypeUserAuthentication: + var credentials option.ClientOption + if c.authType == OptionValueAuthTypeJSONCredentialFile { + credentials = option.WithCredentialsFile(c.credentials) + } else if c.authType == OptionValueAuthTypeJSONCredentialString { + credentials = option.WithCredentialsJSON([]byte(c.credentials)) + } else { + if c.clientID == "" { + return adbc.Error{ + Code: adbc.StatusInvalidArgument, + Msg: fmt.Sprintf("The `%s` parameter is empty", OptionStringAuthClientID), + } + } + if c.clientSecret == "" { + return adbc.Error{ + Code: adbc.StatusInvalidArgument, + Msg: fmt.Sprintf("The `%s` parameter is empty", OptionStringAuthClientSecret), + } + } + if c.refreshToken == "" { + return adbc.Error{ + Code: adbc.StatusInvalidArgument, + Msg: fmt.Sprintf("The `%s` parameter is empty", OptionStringAuthRefreshToken), + } + } + credentials = option.WithTokenSource(c) + } + + client, err := bigquery.NewClient(ctx, c.catalog, credentials) + if err != nil { + return err + } + + err = client.EnableStorageReadClient(ctx, credentials) + if err != nil { + return err + } + + c.client = client + default: + client, err := bigquery.NewClient(ctx, c.catalog) + if err != nil { + return err + } + + err = client.EnableStorageReadClient(ctx) + if err != nil { + return err + } + + c.client = client + } + return nil +} + +var ( + // Dataset: + // + // https://cloud.google.com/bigquery/docs/datasets#dataset-naming + // + // When you create a dataset in BigQuery, the dataset name must be unique for each project. + // The dataset name can contain the following: + // - Up to 1,024 characters. + // - Letters (uppercase or lowercase), numbers, and underscores. + // Dataset names are case-sensitive by default. 
mydataset and MyDataset can coexist in the same project, + // unless one of them has case-sensitivity turned off. + // Dataset names cannot contain spaces or special characters such as -, &, @, or %. + datasetRegex = regexp.MustCompile("^[a-zA-Z0-9_-]") + + precisionScaleRegex = regexp.MustCompile(`^(?:BIG)?NUMERIC\((?P<precision>\d+)(?:,(?P<scale>\d+))?\)$`) + simpleDataType = map[string]arrow.DataType{ + "BOOL": arrow.FixedWidthTypes.Boolean, + "BOOLEAN": arrow.FixedWidthTypes.Boolean, + "FLOAT": arrow.PrimitiveTypes.Float64, + "FLOAT64": arrow.PrimitiveTypes.Float64, + "BYTES": arrow.BinaryTypes.Binary, + "STRING": arrow.BinaryTypes.String, + // TODO: potentially we should consider using GeoArrow for this + "GEOGRAPHY": arrow.BinaryTypes.String, + "JSON": arrow.BinaryTypes.String, + "DATE": arrow.FixedWidthTypes.Date32, + "DATETIME": &arrow.TimestampType{Unit: arrow.Microsecond}, + "TIMESTAMP": &arrow.TimestampType{Unit: arrow.Microsecond}, + "TIME": arrow.FixedWidthTypes.Time64us, + } +) + +func sanitizeDataset(value string) (string, error) { + if value == "" { + return value, nil + } + + if datasetRegex.MatchString(value) { + if len(value) > 1024 { + return "", adbc.Error{ + Code: adbc.StatusInvalidArgument, + Msg: "Dataset name exceeds 1024 characters", + } + } + return value, nil + } + + return "", adbc.Error{ + Code: adbc.StatusInvalidArgument, + Msg: fmt.Sprintf("invalid characters in value `%s`", value), + } +} + +func buildSchemaField(name string, typeString string) (arrow.Field, error) { + index := strings.Index(name, "(") + if index == -1 { + index = strings.Index(name, "<") + } else { + lIndex := strings.Index(name, "<") + if index < lIndex { + index = lIndex + } + } + + dataType := typeString + if index != -1 { + dataType = dataType[:index] + } + return buildField(name, typeString, dataType, index) +} + +func buildField(name, typeString, dataType string, index int) (arrow.Field, error) { + // 
https://cloud.google.com/bigquery/docs/reference/storage#arrow_schema_details + field := arrow.Field{ + Name: strings.Clone(name), + } + val, ok := simpleDataType[dataType] + if ok { + field.Type = val + return field, nil + } + + switch dataType { + case "NUMERIC", "DECIMAL": + precision, scale, err := parsePrecisionAndScale(name, typeString) + if err != nil { + return arrow.Field{}, err + } + field.Type = &arrow.Decimal128Type{ + Precision: precision, + Scale: scale, + } + case "BIGNUMERIC", "BIGDECIMAL": + precision, scale, err := parsePrecisionAndScale(name, typeString) + if err != nil { + return arrow.Field{}, err + } + field.Type = &arrow.Decimal256Type{ + Precision: precision, + Scale: scale, + } + case "ARRAY": + arrayType := strings.Replace(typeString[:len(dataType)], "<", "", 1) + rIndex := strings.LastIndex(arrayType, ">") + if rIndex == -1 { + return arrow.Field{}, adbc.Error{ + Code: adbc.StatusInvalidData, + Msg: fmt.Sprintf("Cannot parse array type `%s` for field `%s`: cannot find `>`", typeString, name), + } + } + arrayType = arrayType[:rIndex] + arrayType[rIndex+1:] + arrayFieldType, err := buildField(name, typeString, arrayType, rIndex) + if err != nil { + return arrow.Field{}, err + } + field.Type = arrow.ListOf(arrayFieldType.Type) + field.Metadata = arrayFieldType.Metadata + field.Nullable = arrayFieldType.Nullable + case "RECORD", "STRUCT": + fieldRecords := typeString[index+1:] + fieldRecords = fieldRecords[:len(fieldRecords)-1] + nestedFields := make([]arrow.Field, 0) + for _, record := range strings.Split(fieldRecords, ",") { + fieldRecord := strings.TrimSpace(record) + recordParts := strings.SplitN(fieldRecord, " ", 2) + if len(recordParts) != 2 { + return arrow.Field{}, adbc.Error{ + Code: adbc.StatusInvalidData, + Msg: fmt.Sprintf("invalid field record `%s` for type `%s`", fieldRecord, dataType), + } + } + fieldName := recordParts[0] + fieldType := recordParts[1] + nestedField, err := buildSchemaField(fieldName, fieldType) + if err != nil 
{ + return arrow.Field{}, err + } + nestedFields = append(nestedFields, nestedField) + } + structType := arrow.StructOf(nestedFields...) + if structType == nil { + return arrow.Field{}, adbc.Error{ + Code: adbc.StatusInvalidArgument, + Msg: fmt.Sprintf("Cannot create a struct schema for record `%s`", fieldRecords), + } + } + field.Type = structType + default: + return arrow.Field{}, adbc.Error{ + Code: adbc.StatusInvalidArgument, + Msg: fmt.Sprintf("Cannot handle data type `%s`, type=`%s`", dataType, typeString), + } + } + return field, nil +} + +func parsePrecisionAndScale(name, typeString string) (int32, int32, error) { + typeString = strings.TrimSpace(typeString) + if len(typeString) == 0 { + return 0, 0, adbc.Error{ + Code: adbc.StatusInvalidData, + Msg: fmt.Sprintf("Cannot parse data type `%s` for field `%s`", typeString, name), + } + } + + if typeString == "NUMERIC" { + return 38, 9, nil + } else if typeString == "BIGNUMERIC" { + return 76, 38, nil + } + + r := precisionScaleRegex.FindStringSubmatch(typeString) + if len(r) != 3 { + return 0, 0, adbc.Error{ + Code: adbc.StatusInvalidData, + Msg: fmt.Sprintf("Cannot parse data type `%s` for field `%s`", typeString, name), + } + } + + precisionString := r[precisionScaleRegex.SubexpIndex("precision")] + precision, err := strconv.ParseInt(precisionString, 10, 32) + if err != nil { + return 0, 0, adbc.Error{ + Code: adbc.StatusInvalidData, + Msg: fmt.Sprintf("Cannot parse precision `%s` for field `%s`: %s", precisionString, name, err.Error()), + } + } + + scaleString := r[precisionScaleRegex.SubexpIndex("scale")] + scale, err := strconv.ParseInt(scaleString, 10, 32) + if err != nil { + return 0, 0, adbc.Error{ + Code: adbc.StatusInvalidData, + Msg: fmt.Sprintf("Cannot parse scale `%s` for field `%s`: %s", scaleString, name, err.Error()), + } + } + return int32(precision), int32(scale), nil +} + +func (c *connectionImpl) getTableSchemaWithFilter(ctx context.Context, catalog *string, dbSchema *string, tableName 
string, columnName *string) (*arrow.Schema, error) { Review Comment: Good catch! Sorry, I didn't notice that before. I've pushed a commit that uses this information! -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
