Add HSQLDB support and move phoenix support into an adapter
Project: http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/repo Commit: http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/commit/2968def4 Tree: http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/tree/2968def4 Diff: http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/diff/2968def4 Branch: refs/heads/master Commit: 2968def4e8dfefdbec731b710769a07fed5d6ff4 Parents: 7891426 Author: Francis Chuang <francis.chu...@boostport.com> Authored: Sun Apr 15 14:27:20 2018 +1000 Committer: Francis Chuang <francis.chu...@boostport.com> Committed: Mon Apr 16 14:23:14 2018 +1000 ---------------------------------------------------------------------- adapter.go | 30 ++ connection.go | 44 +- connection_go18.go | 2 +- driver.go | 44 +- driver_go18_hsqldb_test.go | 578 +++++++++++++++++++++ driver_go18_phoenix_test.go | 768 ++++++++++++++++++++++++++++ driver_go18_test.go | 736 --------------------------- driver_hsqldb_test.go | 1028 +++++++++++++++++++++++++++++++++++++ driver_phoenix_test.go | 1044 ++++++++++++++++++++++++++++++++++++++ driver_test.go | 977 +---------------------------------- errors.go | 324 ------------ errors/errors.go | 63 +++ gen-protobuf.bat | 2 +- gen-protobuf.sh | 2 +- generic/generic.go | 125 +++++ hsqldb/hsqldb.go | 125 +++++ http_client.go | 16 +- internal/column.go | 39 ++ moby.yml | 9 +- phoenix/phoenix.go | 365 +++++++++++++ rows.go | 106 +--- rows_go18.go | 12 +- statement.go | 10 +- transaction.go | 18 +- 24 files changed, 4311 insertions(+), 2156 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/2968def4/adapter.go ---------------------------------------------------------------------- diff --git a/adapter.go b/adapter.go new file mode 100644 index 0000000..ea0ba93 --- /dev/null +++ b/adapter.go @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package avatica + +import ( + "github.com/apache/calcite-avatica-go/errors" + "github.com/apache/calcite-avatica-go/internal" + "github.com/apache/calcite-avatica-go/message" +) + +type Adapter interface { + GetPingStatement() string + GetColumnTypeDefinition(*message.ColumnMetaData) *internal.Column + ErrorResponseToResponseError(*message.ErrorResponse) errors.ResponseError +} http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/2968def4/connection.go ---------------------------------------------------------------------- diff --git a/connection.go b/connection.go index ef6a599..cf8ad9b 100644 --- a/connection.go +++ b/connection.go @@ -20,6 +20,7 @@ package avatica import ( "database/sql/driver" + "github.com/apache/calcite-avatica-go/errors" "github.com/apache/calcite-avatica-go/message" "golang.org/x/net/context" ) @@ -28,6 +29,7 @@ type conn struct { connectionId string config *Config httpClient *httpClient + adapter Adapter } // Prepare returns a prepared statement, bound to this connection. 
@@ -47,7 +49,7 @@ func (c *conn) prepare(ctx context.Context, query string) (driver.Stmt, error) { }) if err != nil { - return nil, err + return nil, c.avaticaErrorToResponseErrorOrError(err) } prepareResponse := response.(*message.PrepareResponse) @@ -80,7 +82,11 @@ func (c *conn) Close() error { c.connectionId = "" - return err + if err != nil { + return c.avaticaErrorToResponseErrorOrError(err) + } + + return nil } // Begin starts and returns a new transaction. @@ -107,7 +113,7 @@ func (c *conn) begin(ctx context.Context, isolationLevel isoLevel) (driver.Tx, e }) if err != nil { - return nil, err + return nil, c.avaticaErrorToResponseErrorOrError(err) } return &tx{ @@ -135,7 +141,7 @@ func (c *conn) exec(ctx context.Context, query string, args []namedValue) (drive }) if err != nil { - return nil, err + return nil, c.avaticaErrorToResponseErrorOrError(err) } res, err := c.httpClient.post(ctx, &message.PrepareAndExecuteRequest{ @@ -147,7 +153,7 @@ func (c *conn) exec(ctx context.Context, query string, args []namedValue) (drive }) if err != nil { - return nil, err + return nil, c.avaticaErrorToResponseErrorOrError(err) } // Currently there is only 1 ResultSet per response for exec @@ -178,7 +184,7 @@ func (c *conn) query(ctx context.Context, query string, args []namedValue) (driv }) if err != nil { - return nil, err + return nil, c.avaticaErrorToResponseErrorOrError(err) } res, err := c.httpClient.post(ctx, &message.PrepareAndExecuteRequest{ @@ -190,10 +196,34 @@ func (c *conn) query(ctx context.Context, query string, args []namedValue) (driv }) if err != nil { - return nil, err + return nil, c.avaticaErrorToResponseErrorOrError(err) } resultSets := res.(*message.ExecuteResponse).Results return newRows(c, st.(*message.CreateStatementResponse).StatementId, resultSets), nil } + +func (c *conn) avaticaErrorToResponseErrorOrError(err error) error { + + avaticaErr, ok := err.(avaticaError) + + if !ok { + return err + } + + if c.adapter != nil { + return 
c.adapter.ErrorResponseToResponseError(avaticaErr.message) + } + + return errors.ResponseError{ + Exceptions: avaticaErr.message.Exceptions, + ErrorMessage: avaticaErr.message.ErrorMessage, + Severity: int8(avaticaErr.message.Severity), + ErrorCode: errors.ErrorCode(avaticaErr.message.ErrorCode), + SqlState: errors.SQLState(avaticaErr.message.SqlState), + Metadata: &errors.RPCMetadata{ + ServerAddress: avaticaErr.message.GetMetadata().ServerAddress, + }, + } +} http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/2968def4/connection_go18.go ---------------------------------------------------------------------- diff --git a/connection_go18.go b/connection_go18.go index 332e04f..1c75b0d 100644 --- a/connection_go18.go +++ b/connection_go18.go @@ -75,7 +75,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name func (c *conn) Ping(ctx context.Context) error { - _, err := c.ExecContext(ctx, "SELECT 1", []driver.NamedValue{}) + _, err := c.ExecContext(ctx, c.adapter.GetPingStatement(), []driver.NamedValue{}) if err != nil { return fmt.Errorf("Error pinging database: %s", err) http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/2968def4/driver.go ---------------------------------------------------------------------- diff --git a/driver.go b/driver.go index 3137314..6457f40 100644 --- a/driver.go +++ b/driver.go @@ -36,7 +36,10 @@ import ( "database/sql/driver" "fmt" + "github.com/apache/calcite-avatica-go/generic" + "github.com/apache/calcite-avatica-go/hsqldb" "github.com/apache/calcite-avatica-go/message" + "github.com/apache/calcite-avatica-go/phoenix" "github.com/satori/go.uuid" "golang.org/x/net/context" ) @@ -88,6 +91,12 @@ func (a *Driver) Open(dsn string) (driver.Conn, error) { info["password"] = config.password } + conn := &conn{ + connectionId: connectionId.String(), + httpClient: httpClient, + config: config, + } + // Open a connection to the server req := &message.OpenConnectionRequest{ ConnectionId: 
connectionId.String(), @@ -101,18 +110,43 @@ func (a *Driver) Open(dsn string) (driver.Conn, error) { _, err = httpClient.post(context.Background(), req) if err != nil { - return nil, err + return nil, conn.avaticaErrorToResponseErrorOrError(err) } - conn := &conn{ - connectionId: connectionId.String(), - httpClient: httpClient, - config: config, + response, err := httpClient.post(context.Background(), &message.DatabasePropertyRequest{ + ConnectionId: connectionId.String(), + }) + + if err != nil { + return nil, conn.avaticaErrorToResponseErrorOrError(err) } + databasePropertyResponse := response.(*message.DatabasePropertyResponse) + + adapter := "" + + for _, property := range databasePropertyResponse.Props { + if property.Key.Name == "GET_DRIVER_NAME" { + adapter = property.Value.StringValue + } + } + + conn.adapter = getAdapter(adapter) + return conn, nil } +func getAdapter(e string) Adapter { + switch e { + case "HSQL Database Engine Driver": + return hsqldb.Adapter{} + case "PhoenixEmbeddedDriver": + return phoenix.Adapter{} + default: + return generic.Adapter{} + } +} + func init() { sql.Register("avatica", &Driver{}) } http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/2968def4/driver_go18_hsqldb_test.go ---------------------------------------------------------------------- diff --git a/driver_go18_hsqldb_test.go b/driver_go18_hsqldb_test.go new file mode 100644 index 0000000..21c59be --- /dev/null +++ b/driver_go18_hsqldb_test.go @@ -0,0 +1,578 @@ +// +build go1.8 + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package avatica + +import ( + "database/sql" + "math" + "reflect" + "testing" + "time" +) + +func TestHSQLDBContext(t *testing.T) { + + skipTestIfNotHSQLDB(t) + + runTests(t, dsn, func(dbt *DBTest) { + + // Create and seed table + dbt.mustExecContext(getContext(), "CREATE TABLE "+dbt.tableName+" (id BIGINT PRIMARY KEY, val VARCHAR(1))") + + dbt.mustExecContext(getContext(), "INSERT INTO "+dbt.tableName+" VALUES (1,'A')") + + dbt.mustExecContext(getContext(), "INSERT INTO "+dbt.tableName+" VALUES (2,'B')") + + rows := dbt.mustQueryContext(getContext(), "SELECT COUNT(*) FROM "+dbt.tableName) + defer rows.Close() + + for rows.Next() { + + var count int + + err := rows.Scan(&count) + + if err != nil { + dbt.Fatal(err) + } + + if count != 2 { + dbt.Fatalf("There should be 2 rows, got %d", count) + } + } + + // Test transactions and prepared statements + _, err := dbt.db.BeginTx(getContext(), &sql.TxOptions{Isolation: sql.LevelReadUncommitted, ReadOnly: true}) + + if err == nil { + t.Error("Expected an error while creating a read only transaction, but no error was returned") + } + + tx, err := dbt.db.BeginTx(getContext(), &sql.TxOptions{Isolation: sql.LevelReadCommitted}) + + if err != nil { + t.Errorf("Unexpected error while creating transaction: %s", err) + } + + stmt, err := tx.PrepareContext(getContext(), "INSERT INTO "+dbt.tableName+" VALUES(?,?)") + + if err != nil { + t.Errorf("Unexpected error while preparing statement: %s", err) + } + + res, err := stmt.ExecContext(getContext(), 3, "C") + + if err != nil { + t.Errorf("Unexpected error while 
executing statement: %s", err) + } + + affected, err := res.RowsAffected() + + if err != nil { + t.Errorf("Error getting affected rows: %s", err) + } + + if affected != 1 { + t.Errorf("Expected 1 affected row, got %d", affected) + } + + err = tx.Commit() + + if err != nil { + t.Errorf("Error committing transaction: %s", err) + } + + stmt2, err := dbt.db.PrepareContext(getContext(), "SELECT * FROM "+dbt.tableName+" WHERE id = ?") + + if err != nil { + t.Errorf("Error preparing statement: %s", err) + } + + row := stmt2.QueryRowContext(getContext(), 3) + + if err != nil { + t.Errorf("Error querying for row: %s", err) + } + + var ( + queryID int64 + queryVal string + ) + + err = row.Scan(&queryID, &queryVal) + + if err != nil { + t.Errorf("Error scanning results into variable: %s", err) + } + + if queryID != 3 { + t.Errorf("Expected scanned id to be %d, got %d", 3, queryID) + } + + if queryVal != "C" { + t.Errorf("Expected scanned string to be %s, got %s", "C", queryVal) + } + }) +} + +func TestHSQLDBMultipleResultSets(t *testing.T) { + + skipTestIfNotHSQLDB(t) + + runTests(t, dsn, func(dbt *DBTest) { + // Create and seed table + dbt.mustExecContext(getContext(), "CREATE TABLE "+dbt.tableName+" (id BIGINT PRIMARY KEY, val VARCHAR(1))") + + dbt.mustExecContext(getContext(), "INSERT INTO "+dbt.tableName+" VALUES (1,'A')") + + dbt.mustExecContext(getContext(), "INSERT INTO "+dbt.tableName+" VALUES (2,'B')") + + rows, err := dbt.db.QueryContext(getContext(), "SELECT * FROM "+dbt.tableName+" WHERE id = 1") + + if err != nil { + t.Errorf("Unexpected error while executing query: %s", err) + } + + defer rows.Close() + + for rows.Next() { + var ( + id int64 + val string + ) + + if err := rows.Scan(&id, &val); err != nil { + t.Errorf("Error while scanning row into variables: %s", err) + } + + if id != 1 { + t.Errorf("Expected id to be %d, got %d", 1, id) + } + + if val != "A" { + t.Errorf("Expected value to be %s, got %s", "A", val) + } + } + + if rows.NextResultSet() { + 
t.Error("There should be no more result sets, but got another result set") + } + }) +} + +func TestHSQLDBColumnTypes(t *testing.T) { + + skipTestIfNotHSQLDB(t) + + runTests(t, dsn, func(dbt *DBTest) { + + // Create and seed table + dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` ( + int INTEGER PRIMARY KEY, + bint BIGINT, + tint TINYINT, + sint SMALLINT, + dbl DOUBLE, + dec DECIMAL(10, 5), + dec2 DECIMAL, + bool BOOLEAN, + tm TIME, + dt DATE, + tmstmp TIMESTAMP, + var VARCHAR(10), + ch CHAR(3), + bin BINARY(20), + varbin VARBINARY(20) + )`) + + // Select + rows, err := dbt.db.QueryContext(getContext(), "SELECT * FROM "+dbt.tableName) + + if err != nil { + t.Errorf("Unexpected error while selecting from table: %s", err) + } + + columnNames, err := rows.Columns() + + if err != nil { + t.Errorf("Error getting column names: %s", err) + } + + expectedColumnNames := []string{"INT", "BINT", "TINT", "SINT", "DBL", "DEC", "DEC2", "BOOL", "TM", "DT", "TMSTMP", "VAR", "CH", "BIN", "VARBIN"} + + if !reflect.DeepEqual(columnNames, expectedColumnNames) { + t.Error("Column names does not match expected column names") + } + + type decimalSize struct { + precision int64 + scale int64 + ok bool + } + + type length struct { + length int64 + ok bool + } + + type nullable struct { + nullable bool + ok bool + } + + expectedColumnTypes := []struct { + databaseTypeName string + decimalSize decimalSize + length length + name string + nullable nullable + scanType reflect.Type + }{ + { + databaseTypeName: "INTEGER", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "INT", + nullable: nullable{ + nullable: false, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "BIGINT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "BINT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: 
reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "TINYINT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "TINT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "SMALLINT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "SINT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "DOUBLE", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "DBL", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(float64(0)), + }, + { + databaseTypeName: "DECIMAL", + decimalSize: decimalSize{ + precision: 10, + scale: 5, + ok: true, + }, + length: length{ + length: 0, + ok: false, + }, + name: "DEC", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(""), + }, + { + databaseTypeName: "DECIMAL", + decimalSize: decimalSize{ + precision: 128, + scale: math.MaxInt64, + ok: true, + }, + length: length{ + length: 0, + ok: false, + }, + name: "DEC2", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(""), + }, + { + databaseTypeName: "BOOLEAN", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "BOOL", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(bool(false)), + }, + { + databaseTypeName: "TIME", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "TM", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(time.Time{}), + }, + { + databaseTypeName: "DATE", + decimalSize: 
decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "DT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(time.Time{}), + }, + { + databaseTypeName: "TIMESTAMP", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "TMSTMP", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(time.Time{}), + }, + { + databaseTypeName: "VARCHAR", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 10, + ok: true, + }, + name: "VAR", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(""), + }, + { + databaseTypeName: "CHARACTER", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 3, + ok: true, + }, + name: "CH", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(""), + }, + { + databaseTypeName: "BINARY", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 20, + ok: true, + }, + name: "BIN", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf([]byte{}), + }, + { + databaseTypeName: "VARBINARY", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 20, + ok: true, + }, + name: "VARBIN", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf([]byte{}), + }, + } + + columnTypes, err := rows.ColumnTypes() + + if err != nil { + t.Errorf("Error getting column types: %s", err) + } + + for index, columnType := range columnTypes { + + expected := expectedColumnTypes[index] + + if columnType.DatabaseTypeName() != expected.databaseTypeName { + t.Errorf("Expected database type name for index %d to be %s, got %s", index, expected.databaseTypeName, 
columnType.DatabaseTypeName()) + } + + precision, scale, ok := columnType.DecimalSize() + + if precision != expected.decimalSize.precision { + t.Errorf("Expected decimal precision for index %d to be %d, got %d", index, expected.decimalSize.precision, precision) + } + + if scale != expected.decimalSize.scale { + t.Errorf("Expected decimal scale for index %d to be %d, got %d", index, expected.decimalSize.scale, scale) + } + + if ok != expected.decimalSize.ok { + t.Errorf("Expected decimal ok for index %d to be %t, got %t", index, expected.decimalSize.ok, ok) + } + + length, ok := columnType.Length() + + if length != expected.length.length { + t.Errorf("Expected length for index %d to be %d, got %d", index, expected.length.length, length) + } + + if ok != expected.length.ok { + t.Errorf("Expected length ok for index %d to be %t, got %t", index, expected.length.ok, ok) + } + + if columnType.Name() != expected.name { + t.Errorf("Expected column name for index %d to be %s, got %s", index, expected.name, columnType.Name()) + } + + nullable, ok := columnType.Nullable() + + if nullable != expected.nullable.nullable { + t.Errorf("Expected nullable for index %d to be %t, got %t", index, expected.nullable.nullable, nullable) + } + + if ok != expected.nullable.ok { + t.Errorf("Expected nullable ok for index %d to be %t, got %t", index, expected.nullable.ok, ok) + } + + if columnType.ScanType() != expected.scanType { + t.Errorf("Expected scan type for index %d to be %s, got %s", index, expected.scanType, columnType.ScanType()) + } + } + + }) +} http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/2968def4/driver_go18_phoenix_test.go ---------------------------------------------------------------------- diff --git a/driver_go18_phoenix_test.go b/driver_go18_phoenix_test.go new file mode 100644 index 0000000..bb7b322 --- /dev/null +++ b/driver_go18_phoenix_test.go @@ -0,0 +1,768 @@ +// +build go1.8 + +/* + * Licensed to the Apache Software Foundation (ASF) under one or 
more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package avatica + +import ( + "database/sql" + "math" + "reflect" + "testing" + "time" +) + +func TestPhoenixContext(t *testing.T) { + + skipTestIfNotPhoenix(t) + + runTests(t, dsn, func(dbt *DBTest) { + + // Create and seed table + dbt.mustExecContext(getContext(), "CREATE TABLE "+dbt.tableName+" (id BIGINT PRIMARY KEY, val VARCHAR) TRANSACTIONAL=false") + + dbt.mustExecContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES (1,'A')") + + dbt.mustExecContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES (2,'B')") + + rows := dbt.mustQueryContext(getContext(), "SELECT COUNT(*) FROM "+dbt.tableName) + defer rows.Close() + + for rows.Next() { + + var count int + + err := rows.Scan(&count) + + if err != nil { + dbt.Fatal(err) + } + + if count != 2 { + dbt.Fatalf("There should be 2 rows, got %d", count) + } + } + + // Test transactions and prepared statements + _, err := dbt.db.BeginTx(getContext(), &sql.TxOptions{Isolation: sql.LevelReadUncommitted, ReadOnly: true}) + + if err == nil { + t.Error("Expected an error while creating a read only transaction, but no error was returned") + } + + tx, err := dbt.db.BeginTx(getContext(), &sql.TxOptions{Isolation: sql.LevelReadCommitted}) + + if err != nil { + t.Errorf("Unexpected error 
while creating transaction: %s", err) + } + + stmt, err := tx.PrepareContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES(?,?)") + + if err != nil { + t.Errorf("Unexpected error while preparing statement: %s", err) + } + + res, err := stmt.ExecContext(getContext(), 3, "C") + + if err != nil { + t.Errorf("Unexpected error while executing statement: %s", err) + } + + affected, err := res.RowsAffected() + + if err != nil { + t.Errorf("Error getting affected rows: %s", err) + } + + if affected != 1 { + t.Errorf("Expected 1 affected row, got %d", affected) + } + + err = tx.Commit() + + if err != nil { + t.Errorf("Error committing transaction: %s", err) + } + + stmt2, err := dbt.db.PrepareContext(getContext(), "SELECT * FROM "+dbt.tableName+" WHERE id = ?") + + if err != nil { + t.Errorf("Error preparing statement: %s", err) + } + + row := stmt2.QueryRowContext(getContext(), 3) + + if err != nil { + t.Errorf("Error querying for row: %s", err) + } + + var ( + queryID int64 + queryVal string + ) + + err = row.Scan(&queryID, &queryVal) + + if err != nil { + t.Errorf("Error scanning results into variable: %s", err) + } + + if queryID != 3 { + t.Errorf("Expected scanned id to be %d, got %d", 3, queryID) + } + + if queryVal != "C" { + t.Errorf("Expected scanned string to be %s, got %s", "C", queryVal) + } + }) +} + +func TestPhoenixMultipleResultSets(t *testing.T) { + + skipTestIfNotPhoenix(t) + + runTests(t, dsn, func(dbt *DBTest) { + // Create and seed table + dbt.mustExecContext(getContext(), "CREATE TABLE "+dbt.tableName+" (id BIGINT PRIMARY KEY, val VARCHAR) TRANSACTIONAL=false") + + dbt.mustExecContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES (1,'A')") + + dbt.mustExecContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES (2,'B')") + + rows, err := dbt.db.QueryContext(getContext(), "SELECT * FROM "+dbt.tableName+" WHERE id = 1") + + if err != nil { + t.Errorf("Unexpected error while executing query: %s", err) + } + + defer rows.Close() + + for 
rows.Next() { + var ( + id int64 + val string + ) + + if err := rows.Scan(&id, &val); err != nil { + t.Errorf("Error while scanning row into variables: %s", err) + } + + if id != 1 { + t.Errorf("Expected id to be %d, got %d", 1, id) + } + + if val != "A" { + t.Errorf("Expected value to be %s, got %s", "A", val) + } + } + + if rows.NextResultSet() { + t.Error("There should be no more result sets, but got another result set") + } + }) +} + +func TestPhoenixColumnTypes(t *testing.T) { + + skipTestIfNotPhoenix(t) + + runTests(t, dsn, func(dbt *DBTest) { + + // Create and seed table + dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` ( + int INTEGER PRIMARY KEY, + uint UNSIGNED_INT, + bint BIGINT, + ulong UNSIGNED_LONG, + tint TINYINT, + utint UNSIGNED_TINYINT, + sint SMALLINT, + usint UNSIGNED_SMALLINT, + flt FLOAT, + uflt UNSIGNED_FLOAT, + dbl DOUBLE, + udbl UNSIGNED_DOUBLE, + dec DECIMAL(10, 5), + dec2 DECIMAL, + bool BOOLEAN, + tm TIME, + dt DATE, + tmstmp TIMESTAMP, + utm UNSIGNED_TIME, + udt UNSIGNED_DATE, + utmstmp UNSIGNED_TIMESTAMP, + var VARCHAR(10), + ch CHAR(3), + bin BINARY(20), + varbin VARBINARY + ) TRANSACTIONAL=false`) + + // Select + rows, err := dbt.db.QueryContext(getContext(), "SELECT * FROM "+dbt.tableName) + + if err != nil { + t.Errorf("Unexpected error while selecting from table: %s", err) + } + + columnNames, err := rows.Columns() + + if err != nil { + t.Errorf("Error getting column names: %s", err) + } + + expectedColumnNames := []string{"INT", "UINT", "BINT", "ULONG", "TINT", "UTINT", "SINT", "USINT", "FLT", "UFLT", "DBL", "UDBL", "DEC", "DEC2", "BOOL", "TM", "DT", "TMSTMP", "UTM", "UDT", "UTMSTMP", "VAR", "CH", "BIN", "VARBIN"} + + if !reflect.DeepEqual(columnNames, expectedColumnNames) { + t.Error("Column names does not match expected column names") + } + + type decimalSize struct { + precision int64 + scale int64 + ok bool + } + + type length struct { + length int64 + ok bool + } + + type nullable struct { + nullable bool + ok bool + } + + 
expectedColumnTypes := []struct { + databaseTypeName string + decimalSize decimalSize + length length + name string + nullable nullable + scanType reflect.Type + }{ + { + databaseTypeName: "INTEGER", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "INT", + nullable: nullable{ + nullable: false, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "UNSIGNED_INT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "UINT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "BIGINT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "BINT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "UNSIGNED_LONG", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "ULONG", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "TINYINT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "TINT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "UNSIGNED_TINYINT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "UTINT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "SMALLINT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: 
"SINT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "UNSIGNED_SMALLINT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "USINT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "FLOAT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "FLT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(float64(0)), + }, + { + databaseTypeName: "UNSIGNED_FLOAT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "UFLT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(float64(0)), + }, + { + databaseTypeName: "DOUBLE", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "DBL", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(float64(0)), + }, + { + databaseTypeName: "UNSIGNED_DOUBLE", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "UDBL", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(float64(0)), + }, + { + databaseTypeName: "DECIMAL", + decimalSize: decimalSize{ + precision: 10, + scale: 5, + ok: true, + }, + length: length{ + length: 0, + ok: false, + }, + name: "DEC", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(""), + }, + { + databaseTypeName: "DECIMAL", + decimalSize: decimalSize{ + precision: math.MaxInt64, + scale: math.MaxInt64, + ok: true, + }, + length: length{ + length: 0, + ok: false, + }, + name: "DEC2", + nullable: nullable{ + 
nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(""), + }, + { + databaseTypeName: "BOOLEAN", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "BOOL", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(bool(false)), + }, + { + databaseTypeName: "TIME", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "TM", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(time.Time{}), + }, + { + databaseTypeName: "DATE", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "DT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(time.Time{}), + }, + { + databaseTypeName: "TIMESTAMP", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "TMSTMP", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(time.Time{}), + }, + { + databaseTypeName: "UNSIGNED_TIME", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "UTM", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(time.Time{}), + }, + { + databaseTypeName: "UNSIGNED_DATE", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "UDT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(time.Time{}), + }, + { + databaseTypeName: "UNSIGNED_TIMESTAMP", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "UTMSTMP", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: 
reflect.TypeOf(time.Time{}), + }, + { + databaseTypeName: "VARCHAR", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 10, + ok: true, + }, + name: "VAR", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(""), + }, + { + databaseTypeName: "CHAR", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 3, + ok: true, + }, + name: "CH", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(""), + }, + { + databaseTypeName: "BINARY", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 20, + ok: true, + }, + name: "BIN", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf([]byte{}), + }, + { + databaseTypeName: "VARBINARY", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: math.MaxInt64, + ok: true, + }, + name: "VARBIN", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf([]byte{}), + }, + } + + columnTypes, err := rows.ColumnTypes() + + if err != nil { + t.Errorf("Error getting column types: %s", err) + } + + for index, columnType := range columnTypes { + + expected := expectedColumnTypes[index] + + if columnType.DatabaseTypeName() != expected.databaseTypeName { + t.Errorf("Expected database type name for index %d to be %s, got %s", index, expected.databaseTypeName, columnType.DatabaseTypeName()) + } + + precision, scale, ok := columnType.DecimalSize() + + if precision != expected.decimalSize.precision { + t.Errorf("Expected decimal precision for index %d to be %d, got %d", index, expected.decimalSize.precision, precision) + } + + if scale != expected.decimalSize.scale { + t.Errorf("Expected decimal scale for index %d to be %d, got %d", index, expected.decimalSize.scale, scale) + } + + if ok != expected.decimalSize.ok { + t.Errorf("Expected 
decimal ok for index %d to be %t, got %t", index, expected.decimalSize.ok, ok) + } + + length, ok := columnType.Length() + + if length != expected.length.length { + t.Errorf("Expected length for index %d to be %d, got %d", index, expected.length.length, length) + } + + if ok != expected.length.ok { + t.Errorf("Expected length ok for index %d to be %t, got %t", index, expected.length.ok, ok) + } + + if columnType.Name() != expected.name { + t.Errorf("Expected column name for index %d to be %s, got %s", index, expected.name, columnType.Name()) + } + + nullable, ok := columnType.Nullable() + + if nullable != expected.nullable.nullable { + t.Errorf("Expected nullable for index %d to be %t, got %t", index, expected.nullable.nullable, nullable) + } + + if ok != expected.nullable.ok { + t.Errorf("Expected nullable ok for index %d to be %t, got %t", index, expected.nullable.ok, ok) + } + + if columnType.ScanType() != expected.scanType { + t.Errorf("Expected scan type for index %d to be %s, got %s", index, expected.scanType, columnType.ScanType()) + } + } + + }) +} http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/2968def4/driver_go18_test.go ---------------------------------------------------------------------- diff --git a/driver_go18_test.go b/driver_go18_test.go index be760f6..1be71e0 100644 --- a/driver_go18_test.go +++ b/driver_go18_test.go @@ -22,8 +22,6 @@ package avatica import ( "context" "database/sql" - "math" - "reflect" "testing" "time" ) @@ -54,109 +52,6 @@ func getContext() context.Context { return ctx } -func TestContext(t *testing.T) { - - runTests(t, dsn, func(dbt *DBTest) { - - // Create and seed table - dbt.mustExecContext(getContext(), "CREATE TABLE "+dbt.tableName+" (id BIGINT PRIMARY KEY, val VARCHAR) TRANSACTIONAL=false") - - dbt.mustExecContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES (1,'A')") - - dbt.mustExecContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES (2,'B')") - - rows := dbt.mustQueryContext(getContext(), 
"SELECT COUNT(*) FROM "+dbt.tableName) - defer rows.Close() - - for rows.Next() { - - var count int - - err := rows.Scan(&count) - - if err != nil { - dbt.Fatal(err) - } - - if count != 2 { - dbt.Fatalf("There should be 2 rows, got %d", count) - } - } - - // Test transactions and prepared statements - _, err := dbt.db.BeginTx(getContext(), &sql.TxOptions{Isolation: sql.LevelReadUncommitted, ReadOnly: true}) - - if err == nil { - t.Error("Expected an error while creating a read only transaction, but no error was returned") - } - - tx, err := dbt.db.BeginTx(getContext(), &sql.TxOptions{Isolation: sql.LevelReadCommitted}) - - if err != nil { - t.Errorf("Unexpected error while creating transaction: %s", err) - } - - stmt, err := tx.PrepareContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES(?,?)") - - if err != nil { - t.Errorf("Unexpected error while preparing statement: %s", err) - } - - res, err := stmt.ExecContext(getContext(), 3, "C") - - if err != nil { - t.Errorf("Unexpected error while executing statement: %s", err) - } - - affected, err := res.RowsAffected() - - if err != nil { - t.Errorf("Error getting affected rows: %s", err) - } - - if affected != 1 { - t.Errorf("Expected 1 affected row, got %d", affected) - } - - err = tx.Commit() - - if err != nil { - t.Errorf("Error committing transaction: %s", err) - } - - stmt2, err := dbt.db.PrepareContext(getContext(), "SELECT * FROM "+dbt.tableName+" WHERE id = ?") - - if err != nil { - t.Errorf("Error preparing statement: %s", err) - } - - row := stmt2.QueryRowContext(getContext(), 3) - - if err != nil { - t.Errorf("Error querying for row: %s", err) - } - - var ( - queryID int64 - queryVal string - ) - - err = row.Scan(&queryID, &queryVal) - - if err != nil { - t.Errorf("Error scanning results into variable: %s", err) - } - - if queryID != 3 { - t.Errorf("Expected scanned id to be %d, got %d", 3, queryID) - } - - if queryVal != "C" { - t.Errorf("Expected scanned string to be %s, got %s", "C", queryVal) - } - 
}) -} - func TestPing(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { err := dbt.db.Ping() @@ -176,634 +71,3 @@ func TestInvalidPing(t *testing.T) { } }) } - -func TestMultipleResultSets(t *testing.T) { - - runTests(t, dsn, func(dbt *DBTest) { - // Create and seed table - dbt.mustExecContext(getContext(), "CREATE TABLE "+dbt.tableName+" (id BIGINT PRIMARY KEY, val VARCHAR) TRANSACTIONAL=false") - - dbt.mustExecContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES (1,'A')") - - dbt.mustExecContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES (2,'B')") - - rows, err := dbt.db.QueryContext(getContext(), "SELECT * FROM "+dbt.tableName+" WHERE id = 1") - - if err != nil { - t.Errorf("Unexpected error while executing query: %s", err) - } - - defer rows.Close() - - for rows.Next() { - var ( - id int64 - val string - ) - - if err := rows.Scan(&id, &val); err != nil { - t.Errorf("Error while scanning row into variables: %s", err) - } - - if id != 1 { - t.Errorf("Expected id to be %d, got %d", 1, id) - } - - if val != "A" { - t.Errorf("Expected value to be %s, got %s", "A", val) - } - } - - if rows.NextResultSet() { - t.Error("There should be no more result sets, but got another result set") - } - }) -} - -func TestColumnTypes(t *testing.T) { - - runTests(t, dsn, func(dbt *DBTest) { - - // Create and seed table - dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` ( - int INTEGER PRIMARY KEY, - uint UNSIGNED_INT, - bint BIGINT, - ulong UNSIGNED_LONG, - tint TINYINT, - utint UNSIGNED_TINYINT, - sint SMALLINT, - usint UNSIGNED_SMALLINT, - flt FLOAT, - uflt UNSIGNED_FLOAT, - dbl DOUBLE, - udbl UNSIGNED_DOUBLE, - dec DECIMAL(10, 5), - dec2 DECIMAL, - bool BOOLEAN, - tm TIME, - dt DATE, - tmstmp TIMESTAMP, - utm UNSIGNED_TIME, - udt UNSIGNED_DATE, - utmstmp UNSIGNED_TIMESTAMP, - var VARCHAR(10), - ch CHAR(3), - bin BINARY(20), - varbin VARBINARY - ) TRANSACTIONAL=false`) - - // Select - rows, err := dbt.db.QueryContext(getContext(), "SELECT * FROM "+dbt.tableName) 
- - if err != nil { - t.Errorf("Unexpected error while selecting from table: %s", err) - } - - columnNames, err := rows.Columns() - - if err != nil { - t.Errorf("Error getting column names: %s", err) - } - - expectedColumnNames := []string{"INT", "UINT", "BINT", "ULONG", "TINT", "UTINT", "SINT", "USINT", "FLT", "UFLT", "DBL", "UDBL", "DEC", "DEC2", "BOOL", "TM", "DT", "TMSTMP", "UTM", "UDT", "UTMSTMP", "VAR", "CH", "BIN", "VARBIN"} - - if !reflect.DeepEqual(columnNames, expectedColumnNames) { - t.Error("Column names does not match expected column names") - } - - type decimalSize struct { - precision int64 - scale int64 - ok bool - } - - type length struct { - length int64 - ok bool - } - - type nullable struct { - nullable bool - ok bool - } - - expectedColumnTypes := []struct { - databaseTypeName string - decimalSize decimalSize - length length - name string - nullable nullable - scanType reflect.Type - }{ - { - databaseTypeName: "INTEGER", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "INT", - nullable: nullable{ - nullable: false, - ok: true, - }, - scanType: reflect.TypeOf(int64(0)), - }, - { - databaseTypeName: "UNSIGNED_INT", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "UINT", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(int64(0)), - }, - { - databaseTypeName: "BIGINT", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "BINT", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(int64(0)), - }, - { - databaseTypeName: "UNSIGNED_LONG", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "ULONG", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: 
reflect.TypeOf(int64(0)), - }, - { - databaseTypeName: "TINYINT", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "TINT", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(int64(0)), - }, - { - databaseTypeName: "UNSIGNED_TINYINT", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "UTINT", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(int64(0)), - }, - { - databaseTypeName: "SMALLINT", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "SINT", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(int64(0)), - }, - { - databaseTypeName: "UNSIGNED_SMALLINT", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "USINT", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(int64(0)), - }, - { - databaseTypeName: "FLOAT", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "FLT", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(float64(0)), - }, - { - databaseTypeName: "UNSIGNED_FLOAT", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "UFLT", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(float64(0)), - }, - { - databaseTypeName: "DOUBLE", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "DBL", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(float64(0)), - }, - { - 
databaseTypeName: "UNSIGNED_DOUBLE", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "UDBL", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(float64(0)), - }, - { - databaseTypeName: "DECIMAL", - decimalSize: decimalSize{ - precision: 10, - scale: 5, - ok: true, - }, - length: length{ - length: 0, - ok: false, - }, - name: "DEC", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(""), - }, - { - databaseTypeName: "DECIMAL", - decimalSize: decimalSize{ - precision: math.MaxInt64, - scale: math.MaxInt64, - ok: true, - }, - length: length{ - length: 0, - ok: false, - }, - name: "DEC2", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(""), - }, - { - databaseTypeName: "BOOLEAN", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "BOOL", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(false), - }, - { - databaseTypeName: "TIME", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "TM", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(time.Time{}), - }, - { - databaseTypeName: "DATE", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "DT", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(time.Time{}), - }, - { - databaseTypeName: "TIMESTAMP", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "TMSTMP", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(time.Time{}), - }, - { - databaseTypeName: "UNSIGNED_TIME", - decimalSize: 
decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "UTM", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(time.Time{}), - }, - { - databaseTypeName: "UNSIGNED_DATE", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "UDT", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(time.Time{}), - }, - { - databaseTypeName: "UNSIGNED_TIMESTAMP", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 0, - ok: false, - }, - name: "UTMSTMP", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(time.Time{}), - }, - { - databaseTypeName: "VARCHAR", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 10, - ok: true, - }, - name: "VAR", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(""), - }, - { - databaseTypeName: "CHAR", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 3, - ok: true, - }, - name: "CH", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf(""), - }, - { - databaseTypeName: "BINARY", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: 20, - ok: true, - }, - name: "BIN", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf([]byte{}), - }, - { - databaseTypeName: "VARBINARY", - decimalSize: decimalSize{ - precision: 0, - scale: 0, - ok: false, - }, - length: length{ - length: math.MaxInt64, - ok: true, - }, - name: "VARBIN", - nullable: nullable{ - nullable: true, - ok: true, - }, - scanType: reflect.TypeOf([]byte{}), - }, - } - - columnTypes, err := rows.ColumnTypes() - - if err != nil { - t.Errorf("Error getting column types: 
%s", err) - } - - for index, columnType := range columnTypes { - - expected := expectedColumnTypes[index] - - if columnType.DatabaseTypeName() != expected.databaseTypeName { - t.Errorf("Expected database type name for index %d to be %s, got %s", index, expected.databaseTypeName, columnType.DatabaseTypeName()) - } - - precision, scale, ok := columnType.DecimalSize() - - if precision != expected.decimalSize.precision { - t.Errorf("Expected decimal precision for index %d to be %d, got %d", index, expected.decimalSize.precision, precision) - } - - if scale != expected.decimalSize.scale { - t.Errorf("Expected decimal scale for index %d to be %d, got %d", index, expected.decimalSize.scale, scale) - } - - if ok != expected.decimalSize.ok { - t.Errorf("Expected decimal ok for index %d to be %t, got %t", index, expected.decimalSize.ok, ok) - } - - length, ok := columnType.Length() - - if length != expected.length.length { - t.Errorf("Expected length for index %d to be %d, got %d", index, expected.length.length, length) - } - - if ok != expected.length.ok { - t.Errorf("Expected length ok for index %d to be %t, got %t", index, expected.length.ok, ok) - } - - if columnType.Name() != expected.name { - t.Errorf("Expected column name for index %d to be %s, got %s", index, expected.name, columnType.Name()) - } - - nullable, ok := columnType.Nullable() - - if nullable != expected.nullable.nullable { - t.Errorf("Expected nullable for index %d to be %t, got %t", index, expected.nullable.nullable, nullable) - } - - if ok != expected.nullable.ok { - t.Errorf("Expected nullable ok for index %d to be %t, got %t", index, expected.nullable.ok, ok) - } - - if columnType.ScanType() != expected.scanType { - t.Errorf("Expected scan type for index %d to be %s, got %s", index, expected.scanType, columnType.ScanType()) - } - } - - }) -} http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/2968def4/driver_hsqldb_test.go 
---------------------------------------------------------------------- diff --git a/driver_hsqldb_test.go b/driver_hsqldb_test.go new file mode 100644 index 0000000..7597cd2 --- /dev/null +++ b/driver_hsqldb_test.go @@ -0,0 +1,1028 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package avatica + +import ( + "bytes" + "crypto/sha256" + "database/sql" + "io/ioutil" + "os" + "path/filepath" + "testing" + "time" +) + +func skipTestIfNotHSQLDB(t *testing.T) { + + val := os.Getenv("AVATICA_FLAVOR") + + if val != "HSQLDB" { + t.Skip("Skipping Apache Avatica HSQLDB test") + } +} + +func TestHSQLDBConnectionMustBeOpenedWithAutoCommitTrue(t *testing.T) { + + skipTestIfNotHSQLDB(t) + + runTests(t, dsn, func(dbt *DBTest) { + + // Create and seed table + dbt.mustExec("CREATE TABLE " + dbt.tableName + " (id BIGINT PRIMARY KEY, val VARCHAR(1))") + + dbt.mustExec("INSERT INTO " + dbt.tableName + " VALUES (1,'A')") + + dbt.mustExec("INSERT INTO " + dbt.tableName + " VALUES (2,'B')") + + rows := dbt.mustQuery("SELECT COUNT(*) FROM " + dbt.tableName) + defer rows.Close() + + for rows.Next() { + + var count int + + err := rows.Scan(&count) + + if err != nil { + dbt.Fatal(err) + } + + if count != 2 { + dbt.Fatalf("There should be 2 rows, got %d", count) + } + } + + }) +} + +func TestHSQLDBZeroValues(t *testing.T) { + + skipTestIfNotHSQLDB(t) + + runTests(t, dsn, func(dbt *DBTest) { + + // Create and seed table + dbt.mustExec("CREATE TABLE " + dbt.tableName + " (int INTEGER PRIMARY KEY, flt FLOAT, bool BOOLEAN, str VARCHAR(1))") + + dbt.mustExec("INSERT INTO " + dbt.tableName + " VALUES (0, 0.0, false, '')") + + rows := dbt.mustQuery("SELECT * FROM " + dbt.tableName) + defer rows.Close() + + for rows.Next() { + + var i int + var flt float64 + var b bool + var s string + + err := rows.Scan(&i, &flt, &b, &s) + + if err != nil { + dbt.Fatal(err) + } + + if i != 0 { + dbt.Fatalf("Integer should be 0, got %v", i) + } + + if flt != 0.0 { + dbt.Fatalf("Float should be 0.0, got %v", flt) + } + + if b != false { + dbt.Fatalf("Boolean should be false, got %v", b) + } + + if s != "" { + dbt.Fatalf("String should be \"\", got %v", s) + } + } + + }) +} + +func TestHSQLDBDataTypes(t *testing.T) { + + // TODO; Test case for Time type is currently commented out due to 
CALCITE-1951 + + skipTestIfNotHSQLDB(t) + + runTests(t, dsn, func(dbt *DBTest) { + + // Create and seed table + /*dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` ( + int INTEGER PRIMARY KEY, + tint TINYINT, + sint SMALLINT, + bint BIGINT, + num NUMERIC(10,3), + dec DECIMAL(10,3), + re REAL, + flt FLOAT, + dbl DOUBLE, + bool BOOLEAN, + ch CHAR(3), + var VARCHAR(128), + bin BINARY(20), + varbin VARBINARY(128), + dt DATE, + tm TIME, + tmstmp TIMESTAMP, + )`)*/ + + dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` ( + int INTEGER PRIMARY KEY, + tint TINYINT, + sint SMALLINT, + bint BIGINT, + num NUMERIC(10,3), + dec DECIMAL(10,3), + re REAL, + flt FLOAT, + dbl DOUBLE, + bool BOOLEAN, + ch CHAR(3), + var VARCHAR(128), + bin BINARY(20), + varbin VARBINARY(128), + dt DATE, + tmstmp TIMESTAMP, + )`) + + var ( + integerValue int = -20 + tintValue int = -128 + sintValue int = -32768 + bintValue int = -9223372036854775807 + numValue string = "1.333" + decValue string = "1.333" + reValue float64 = 3.555 + fltValue float64 = -3.555 + dblValue float64 = -9.555 + booleanValue bool = true + chValue string = "a" + varcharValue string = "test string" + binValue []byte = make([]byte, 20, 20) + varbinValue []byte = []byte("testtesttest") + dtValue time.Time = time.Date(2100, 2, 1, 0, 0, 0, 0, time.UTC) + // tmValue time.Time = time.Date(0, 1, 1, 21, 21, 21, 222000000, time.UTC) + tmstmpValue time.Time = time.Date(2100, 2, 1, 21, 21, 21, 222000000, time.UTC) + ) + + copy(binValue[:], []byte("test")) + + // dbt.mustExec(`INSERT INTO `+dbt.tableName+` (int, tint, sint, bint, num, dec, re, flt, dbl, bool, ch, var, bin, varbin, dt, tm, tmstmp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + dbt.mustExec(`INSERT INTO `+dbt.tableName+` (int, tint, sint, bint, num, dec, re, flt, dbl, bool, ch, var, bin, varbin, dt, tmstmp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + integerValue, + tintValue, + sintValue, + bintValue, + numValue, + decValue, + reValue, + 
fltValue, + dblValue, + booleanValue, + chValue, + varcharValue, + binValue, + varbinValue, + dtValue, + // tmValue, + tmstmpValue, + ) + + rows := dbt.mustQuery("SELECT * FROM " + dbt.tableName) + defer rows.Close() + + var ( + integer int + tint int + sint int + bint int + num string + dec string + re float64 + flt float64 + dbl float64 + boolean bool + ch string + varchar string + bin []byte + varbin []byte + dt time.Time + // tm time.Time + tmstmp time.Time + ) + + for rows.Next() { + + // err := rows.Scan(&integer, &tint, &sint, &bint, &num, &dec, &re, &flt, &dbl, &boolean, &ch, &varchar, &bin, &varbin, &dt, &tm, &tmstmp) + err := rows.Scan(&integer, &tint, &sint, &bint, &num, &dec, &re, &flt, &dbl, &boolean, &ch, &varchar, &bin, &varbin, &dt, &tmstmp) + + if err != nil { + dbt.Fatal(err) + } + } + + comparisons := []struct { + result interface{} + expected interface{} + }{ + {integer, integerValue}, + {tint, tintValue}, + {sint, sintValue}, + {bint, bintValue}, + {num, numValue}, + {dec, decValue}, + {re, reValue}, + {flt, fltValue}, + {dbl, dblValue}, + {boolean, booleanValue}, + {ch, chValue + " "}, // HSQLDB pads CHAR columns if a length is specified + {varchar, varcharValue}, + {bin, binValue}, + {varbin, varbinValue}, + {dt, dtValue}, + // {tm, tmValue}, + {tmstmp, tmstmpValue}, + } + + for _, tt := range comparisons { + + if v, ok := tt.expected.(time.Time); ok { + + if !v.Equal(tt.result.(time.Time)) { + dbt.Fatalf("Expected %v, got %v.", tt.expected, tt.result) + } + + } else if v, ok := tt.expected.([]byte); ok { + + if !bytes.Equal(v, tt.result.([]byte)) { + dbt.Fatalf("Expected %v, got %v.", tt.expected, tt.result) + } + + } else if tt.expected != tt.result { + dbt.Errorf("Expected %v, got %v.", tt.expected, tt.result) + } + } + }) +} + +// TODO: Test case commented out due to CALCITE-1951 +/*func TestHSQLDBLocations(t *testing.T) { + + skipTestIfNotHSQLDB(t) + + query := "?location=Australia/Melbourne" + + runTests(t, dsn+query, func(dbt *DBTest) 
{ + + // Create and seed table + dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` ( + tm TIME(6) PRIMARY KEY, + dt DATE, + tmstmp TIMESTAMP + )`) + + loc, err := time.LoadLocation("Australia/Melbourne") + + if err != nil { + dbt.Fatalf("Unexpected error: %s", err) + } + + var ( + tmValue time.Time = time.Date(0, 1, 1, 21, 21, 21, 222000000, loc) + dtValue time.Time = time.Date(2100, 2, 1, 0, 0, 0, 0, loc) + tmstmpValue time.Time = time.Date(2100, 2, 1, 21, 21, 21, 222000000, loc) + ) + + dbt.mustExec(`INSERT INTO `+dbt.tableName+`(tm, dt, tmstmp) VALUES (?, ?, ?)`, + tmValue, + dtValue, + tmstmpValue, + ) + + rows := dbt.mustQuery("SELECT * FROM " + dbt.tableName) + defer rows.Close() + + var ( + tm time.Time + dt time.Time + tmstmp time.Time + ) + + for rows.Next() { + + err := rows.Scan(&tm, &dt, &tmstmp) + + if err != nil { + dbt.Fatal(err) + } + } + + comparisons := []struct { + result time.Time + expected time.Time + }{ + {tm, tmValue}, + {dt, dtValue}, + {tmstmp, tmstmpValue}, + } + + for _, tt := range comparisons { + if !tt.result.Equal(tt.expected) { + dbt.Errorf("Expected %v, got %v.", tt.expected, tt.result) + } + } + }) +}*/ + +func TestHSQLDBDateAndTimestampsBefore1970(t *testing.T) { + + skipTestIfNotHSQLDB(t) + + runTests(t, dsn, func(dbt *DBTest) { + + // Create and seed table + dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` ( + int INTEGER PRIMARY KEY, + dt DATE, + tmstmp TIMESTAMP + )`) + + var ( + integerValue int = 1 + dtValue time.Time = time.Date(1945, 5, 20, 0, 0, 0, 0, time.UTC) + tmstmpValue time.Time = time.Date(1911, 5, 20, 21, 21, 21, 222000000, time.UTC) + ) + + dbt.mustExec(`INSERT INTO `+dbt.tableName+`(int, dt, tmstmp) VALUES (?, ?, ?)`, + integerValue, + dtValue, + tmstmpValue, + ) + + rows := dbt.mustQuery("SELECT dt, tmstmp FROM " + dbt.tableName) + defer rows.Close() + + var ( + dt time.Time + tmstmp time.Time + ) + + for rows.Next() { + err := rows.Scan(&dt, &tmstmp) + + if err != nil { + dbt.Fatal(err) + } + } + + comparisons 
:= []struct { + result time.Time + expected time.Time + }{ + {dt, dtValue}, + {tmstmp, tmstmpValue}, + } + + for _, tt := range comparisons { + if !tt.expected.Equal(tt.result) { + dbt.Fatalf("Expected %v, got %v.", tt.expected, tt.result) + } + } + }) +} + +func TestHSQLDBStoreAndRetrieveBinaryData(t *testing.T) { + + skipTestIfNotHSQLDB(t) + + runTests(t, dsn, func(dbt *DBTest) { + + // Create and seed table + // TODO: Switch VARBINARY to BLOB once avatica supports BLOBs and CBLOBs. CALCITE-1957 + dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` ( + int INTEGER PRIMARY KEY, + bin VARBINARY(999999) + )`) + + filePath := filepath.Join("test-fixtures", "gopher.png") + + file, err := ioutil.ReadFile(filePath) + + if err != nil { + t.Fatalf("Unable to read text-fixture: %s", filePath) + } + + hash := sha256.Sum256(file) + + dbt.mustExec(`INSERT INTO `+dbt.tableName+` (int, bin) VALUES (?, ?)`, + 1, + file, + ) + + rows := dbt.mustQuery("SELECT bin FROM " + dbt.tableName) + defer rows.Close() + + var receivedFile []byte + + for rows.Next() { + + err := rows.Scan(&receivedFile) + + if err != nil { + dbt.Fatal(err) + } + } + + ioutil.WriteFile("test-fixtures/gopher.png", receivedFile, os.ModePerm) + + receivedHash := sha256.Sum256(receivedFile) + + if !bytes.Equal(hash[:], receivedHash[:]) { + t.Fatalf("Hash of stored file (%x) does not equal hash of retrieved file (%x).", hash[:], receivedHash[:]) + } + }) +} + +/*func TestHSQLDBCommittingTransactions(t *testing.T) { + + skipTestIfNotHSQLDB(t) + + query := "?transactionIsolation=4" + + runTests(t, dsn+query, func(dbt *DBTest) { + + // Create and seed table + dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` ( + int INTEGER PRIMARY KEY + )`) + + tx, err := dbt.db.Begin() + + if err != nil { + t.Fatalf("Unable to create transaction: %s", err) + } + + stmt, err := tx.Prepare(`INSERT INTO ` + dbt.tableName + `(int) VALUES(?)`) + + if err != nil { + t.Fatalf("Could not prepare statement: %s", err) + } + + totalRows := 6 + + 
for i := 1; i <= totalRows; i++ { + _, err := stmt.Exec(i) + + if err != nil { + dbt.Fatal(err) + } + } + + r := tx.QueryRow("SELECT COUNT(*) FROM " + dbt.tableName) + + var count int + + err = r.Scan(&count) + + if err != nil { + t.Fatalf("Unable to scan row result: %s", err) + } + + if count != totalRows { + t.Fatalf("Expected %d rows, got %d", totalRows, count) + } + + // Commit the transaction + tx.Commit() + + rows := dbt.mustQuery("SELECT COUNT(*) FROM " + dbt.tableName) + + var countAfterRollback int + + for rows.Next() { + err := rows.Scan(&countAfterRollback) + + if err != nil { + dbt.Fatal(err) + } + } + + if countAfterRollback != totalRows { + t.Fatalf("Expected %d rows, got %d", totalRows, countAfterRollback) + } + }) +} + +func TestHSQLDBRollingBackTransactions(t *testing.T) { + + skipTestIfNotHSQLDB(t) + + query := "?transactionIsolation=4" + + runTests(t, dsn+query, func(dbt *DBTest) { + + // Create and seed table + dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` ( + int INTEGER PRIMARY KEY + )`) + + tx, err := dbt.db.Begin() + + if err != nil { + t.Fatalf("Unable to create transaction: %s", err) + } + + stmt, err := tx.Prepare(`INSERT INTO ` + dbt.tableName + `(int) VALUES(?)`) + + if err != nil { + t.Fatalf("Could not prepare statement: %s", err) + } + + totalRows := 6 + + for i := 1; i <= totalRows; i++ { + _, err := stmt.Exec(i) + + if err != nil { + dbt.Fatal(err) + } + } + + r := tx.QueryRow(`SELECT COUNT(*) FROM ` + dbt.tableName) + + var count int + + err = r.Scan(&count) + + if err != nil { + t.Fatalf("Unable to scan row result: %s", err) + } + + if count != totalRows { + t.Fatalf("Expected %d rows, got %d", totalRows, count) + } + + // Rollback the transaction + tx.Rollback() + + rows := dbt.mustQuery(`SELECT COUNT(*) FROM ` + dbt.tableName) + + var countAfterRollback int + + for rows.Next() { + err := rows.Scan(&countAfterRollback) + + if err != nil { + dbt.Fatal(err) + } + } + + if countAfterRollback != 0 { + t.Fatalf("Expected %d rows, 
got %d", 0, countAfterRollback)
+		}
+	})
+}*/
+
+// TestHSQLDBPreparedStatements prepares an INSERT and a parameterized SELECT
+// once each, executes them repeatedly, and checks that every inserted value
+// can be read back through the prepared query.
+func TestHSQLDBPreparedStatements(t *testing.T) {
+
+	skipTestIfNotHSQLDB(t)
+
+	runTests(t, dsn, func(dbt *DBTest) {
+
+		// Create and seed table
+		dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` (
+				int INTEGER PRIMARY KEY
+			)`)
+
+		stmt, err := dbt.db.Prepare(`INSERT INTO ` + dbt.tableName + `(int) VALUES(?)`)
+
+		if err != nil {
+			dbt.Fatal(err)
+		}
+
+		// database/sql does not close prepared statements on its own.
+		defer stmt.Close()
+
+		totalRows := 6
+
+		for i := 1; i <= totalRows; i++ {
+			_, err := stmt.Exec(i)
+
+			if err != nil {
+				dbt.Fatal(err)
+			}
+		}
+
+		queryStmt, err := dbt.db.Prepare(`SELECT * FROM ` + dbt.tableName + ` WHERE int = ?`)
+
+		if err != nil {
+			dbt.Fatal(err)
+		}
+
+		defer queryStmt.Close()
+
+		var res int
+
+		for i := 1; i <= totalRows; i++ {
+
+			err := queryStmt.QueryRow(i).Scan(&res)
+
+			if err != nil {
+				dbt.Fatal(err)
+			}
+
+			if res != i {
+				dbt.Fatalf("Unexpected query result. Expected %d, got %d.", i, res)
+			}
+		}
+	})
+}
+
+// TestHSQLDBFetchingMoreRows inserts several rows and reads them all back
+// through a DSN with frameMaxSize=1. Assuming frameMaxSize caps the rows per
+// frame, the driver must fetch additional frames while iterating.
+func TestHSQLDBFetchingMoreRows(t *testing.T) {
+
+	skipTestIfNotHSQLDB(t)
+
+	query := "?maxRowsTotal=-1&frameMaxSize=1"
+
+	runTests(t, dsn+query, func(dbt *DBTest) {
+
+		// Create and seed table
+		dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` (
+				int INTEGER PRIMARY KEY
+			)`)
+
+		stmt, err := dbt.db.Prepare(`INSERT INTO ` + dbt.tableName + `(int) VALUES(?)`)
+
+		if err != nil {
+			dbt.Fatal(err)
+		}
+
+		defer stmt.Close()
+
+		totalRows := 6
+
+		for i := 1; i <= totalRows; i++ {
+			_, err := stmt.Exec(i)
+
+			if err != nil {
+				dbt.Fatal(err)
+			}
+		}
+
+		rows := dbt.mustQuery(`SELECT * FROM ` + dbt.tableName)
+		defer rows.Close()
+
+		count := 0
+
+		for rows.Next() {
+			count++
+		}
+
+		// rows.Next() returns false on error as well as at the end; surface a
+		// failed fetch instead of reporting a misleading row count.
+		if err := rows.Err(); err != nil {
+			dbt.Fatal(err)
+		}
+
+		if count != totalRows {
+			dbt.Fatalf("Expected %d rows to be retrieved, retrieved %d", totalRows, count)
+		}
+	})
+}
+
+// TestHSQLDBExecuteShortcut runs an INSERT via db.Exec without an explicit
+// Prepare and checks that exactly one row is reported as affected.
+func TestHSQLDBExecuteShortcut(t *testing.T) {
+
+	skipTestIfNotHSQLDB(t)
+
+	runTests(t, dsn, func(dbt *DBTest) {
+
+		// Create and seed table
+		dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` (
+				int INTEGER PRIMARY KEY
+			)`)
+
+		res, err := dbt.db.Exec(`INSERT INTO ` + dbt.tableName + `(int) VALUES(1)`)
+
+		if err != nil {
+			dbt.Fatal(err)
+		}
+
+		affected, err := res.RowsAffected()
+
+		if err != nil {
+			dbt.Fatal(err)
+		}
+
+		if affected != 1 {
+			dbt.Fatalf("Expected 1 row to be affected, %d affected", affected)
+		}
+	})
+}
+
+// TestHSQLDBQueryShortcut seeds a table and reads every row back through the
+// DBTest query helper.
+// NOTE(review): this body is currently identical to TestHSQLDBFetchingMoreRows;
+// confirm whether it was meant to exercise a different query path.
+func TestHSQLDBQueryShortcut(t *testing.T) {
+
+	skipTestIfNotHSQLDB(t)
+
+	query := "?maxRowsTotal=-1&frameMaxSize=1"
+
+	runTests(t, dsn+query, func(dbt *DBTest) {
+
+		// Create and seed table
+		dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` (
+				int INTEGER PRIMARY KEY
+			)`)
+
+		stmt, err := dbt.db.Prepare(`INSERT INTO ` + dbt.tableName + `(int) VALUES(?)`)
+
+		if err != nil {
+			dbt.Fatal(err)
+		}
+
+		defer stmt.Close()
+
+		totalRows := 6
+
+		for i := 1; i <= totalRows; i++ {
+			_, err := stmt.Exec(i)
+
+			if err != nil {
+				dbt.Fatal(err)
+			}
+		}
+
+		rows := dbt.mustQuery(`SELECT * FROM ` + dbt.tableName)
+		defer rows.Close()
+
+		count := 0
+
+		for rows.Next() {
+			count++
+		}
+
+		if err := rows.Err(); err != nil {
+			dbt.Fatal(err)
+		}
+
+		if count != totalRows {
+			dbt.Fatalf("Expected %d rows to be retrieved, retrieved %d", totalRows, count)
+		}
+	})
+}
+
+// TODO: Test disabled due to CALCITE-2250
+/*func TestHSQLDBOptimisticConcurrency(t *testing.T) {
+
+	skipTestIfNotHSQLDB(t)
+
+	query := "?transactionIsolation=4"
+
+	runTests(t, dsn+query, func(dbt *DBTest) {
+
+		// Create and seed table
+		dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` (
+				id INTEGER PRIMARY KEY,
+				msg VARCHAR(64),
+				version INTEGER
+			)`)
+
+		stmt, err := dbt.db.Prepare(`INSERT INTO ` + dbt.tableName + `(id, msg, version) VALUES(?, ?, ?)`)
+
+		if err != nil {
+			dbt.Fatal(err)
+		}
+
+		totalRows := 6
+
+		for i := 1; i <= totalRows; i++ {
+			_, err := stmt.Exec(i, fmt.Sprintf("message version %d", i), i)
+
+			if err != nil {
+				dbt.Fatal(err)
+			}
+		}
+
+		// Start the transactions
+		tx1, err := dbt.db.Begin()
+
+		if err != nil {
+			dbt.Fatal(err)
+		}
+
+		tx2, err := dbt.db.Begin()
+
+		if err != nil {
+			dbt.Fatal(err)
+		}
+
+		// Select from first transaction
+		_ = tx1.QueryRow(`SELECT MAX(version) FROM ` + dbt.tableName)
+
+		// Modify using second transaction
+		_, err = tx2.Exec(`INSERT INTO `+dbt.tableName+`(id, msg, version) VALUES(?, ?, ?)`, 7, "message value 7", 7)
+
+		if err != nil {
+			dbt.Fatal(err)
+		}
+
+		err = tx2.Commit()
+
+		if err != nil {
+			dbt.Fatal(err)
+		}
+
+		// Modify using tx1
+		_, err = tx1.Exec(`INSERT INTO `+dbt.tableName+`(id, msg, version) VALUES(?, ?, ?)`, 7, "message value 7", 7)
+
+		if err != nil {
+			dbt.Fatal(err)
+		}
+
+		err = tx1.Commit()
+
+		if err == nil {
+			dbt.Fatal("Expected an error, but did not receive any.")
+		}
+
+		errName := err.(ResponseError).Name()
+
+		if errName != "transaction_conflict_exception" {
+			dbt.Fatal("Expected transaction_conflict")
+		}
+	})
+}*/
+
+// TestHSQLDBLastInsertIDShouldReturnError inserts into a table with an
+// IDENTITY column and checks that Result.LastInsertId reports an error, as
+// Avatica does not support it.
+func TestHSQLDBLastInsertIDShouldReturnError(t *testing.T) {
+
+	skipTestIfNotHSQLDB(t)
+
+	runTests(t, dsn, func(dbt *DBTest) {
+
+		dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` (
+				id INTEGER IDENTITY PRIMARY KEY,
+				msg VARCHAR(3),
+				version INTEGER
+			)`)
+
+		res, err := dbt.db.Exec(`INSERT INTO ` + dbt.tableName + `(msg, version) VALUES('abc', 1)`)
+
+		if err != nil {
+			dbt.Fatal(err)
+		}
+
+		_, err = res.LastInsertId()
+
+		if err == nil {
+			dbt.Fatal("Expected an error as Avatica does not support LastInsertId(), but there was no error.")
+		}
+	})
+}
+
+// TestHSQLDBSchemaSupport creates the avaticatest schema, connects with
+// "/avaticatest" appended to the DSN, and round-trips one row through a
+// table in that schema.
+func TestHSQLDBSchemaSupport(t *testing.T) {
+
+	skipTestIfNotHSQLDB(t)
+
+	db, err := sql.Open("avatica", dsn)
+
+	if err != nil {
+		t.Fatalf("error connecting: %s", err.Error())
+	}
+
+	defer db.Close()
+
+	_, err = db.Exec("CREATE SCHEMA IF NOT EXISTS avaticatest")
+
+	if err != nil {
+		t.Fatalf("error creating schema: %s", err)
+	}
+
+	// Best-effort cleanup: log a failed DROP rather than failing, so cleanup
+	// problems are visible without masking the actual test result.
+	defer func() {
+		if _, err := db.Exec("DROP SCHEMA IF EXISTS avaticatest"); err != nil {
+			t.Logf("error dropping schema: %s", err)
+		}
+	}()
+
+	path := "/avaticatest"
+
+	runTests(t, dsn+path, func(dbt *DBTest) {
+
+		// Create and seed table
+		dbt.mustExec(`CREATE TABLE avaticatest.` + dbt.tableName + ` (
+				int INTEGER PRIMARY KEY
+			)`)
+
+		defer dbt.mustExec(`DROP TABLE IF EXISTS avaticatest.` + dbt.tableName)
+
+		_, err := dbt.db.Exec(`INSERT INTO avaticatest.` + dbt.tableName + `(int) VALUES(1)`)
+
+		if err != nil {
+			dbt.Fatal(err)
+		}
+
+		rows := dbt.mustQuery(`SELECT * FROM avaticatest.` + dbt.tableName)
+		defer rows.Close()
+
+		count := 0
+
+		for rows.Next() {
+			count++
+		}
+
+		if err := rows.Err(); err != nil {
+			dbt.Fatal(err)
+		}
+
+		if count != 1 {
+			dbt.Errorf("Expected 1 row, got %d rows back", count)
+		}
+	})
+}
+
+// TestHSQLDBMultipleSchemaSupport creates two schemas, connects with
+// "/avaticatest1" in the DSN path, and round-trips one row through a table
+// in avaticatest2 using schema-qualified names.
+func TestHSQLDBMultipleSchemaSupport(t *testing.T) {
+
+	skipTestIfNotHSQLDB(t)
+
+	db, err := sql.Open("avatica", dsn)
+
+	if err != nil {
+		t.Fatalf("error connecting: %s", err.Error())
+	}
+
+	defer db.Close()
+
+	_, err = db.Exec("CREATE SCHEMA IF NOT EXISTS avaticatest1")
+
+	if err != nil {
+		t.Fatalf("error creating schema: %s", err)
+	}
+
+	// Best-effort cleanup: log a failed DROP rather than failing the test.
+	defer func() {
+		if _, err := db.Exec("DROP SCHEMA IF EXISTS avaticatest1"); err != nil {
+			t.Logf("error dropping schema: %s", err)
+		}
+	}()
+
+	_, err = db.Exec("CREATE SCHEMA IF NOT EXISTS avaticatest2")
+
+	if err != nil {
+		t.Fatalf("error creating schema: %s", err)
+	}
+
+	defer func() {
+		if _, err := db.Exec("DROP SCHEMA IF EXISTS avaticatest2"); err != nil {
+			t.Logf("error dropping schema: %s", err)
+		}
+	}()
+
+	path := "/avaticatest1"
+
+	runTests(t, dsn+path, func(dbt *DBTest) {
+
+		// NOTE(review): in HSQLDB, SET INITIAL SCHEMA changes the current
+		// user's default schema for later sessions and is not reset here;
+		// confirm it cannot leak into other tests or block the schema drop.
+		dbt.mustExec(`SET INITIAL SCHEMA avaticatest2`)
+
+		// Create and seed table
+		dbt.mustExec(`CREATE TABLE avaticatest2.` + dbt.tableName + ` (
+				int INTEGER PRIMARY KEY
+			)`)
+
+		defer dbt.mustExec(`DROP TABLE IF EXISTS avaticatest2.` + dbt.tableName)
+
+		_, err := dbt.db.Exec(`INSERT INTO avaticatest2.` + dbt.tableName + `(int) VALUES(1)`)
+
+		if err != nil {
+			dbt.Fatal(err)
+		}
+
+		rows := dbt.mustQuery(`SELECT * FROM avaticatest2.` + dbt.tableName)
+		defer rows.Close()
+
+		count := 0
+
+		for rows.Next() {
+			count++
+		}
+
+		if err := rows.Err(); err != nil {
+			dbt.Fatal(err)
+		}
+
+		if count != 1 {
+			dbt.Errorf("Expected 1 row, got %d rows back", count)
+		}
+	})
+}
+
+// TODO: Test disabled due to CALCITE-1049
+/*func TestHSQLDBErrorCodeParsing(t *testing.T) {
+
+	skipTestIfNotHSQLDB(t)
+
+	db, err := sql.Open("avatica", dsn)
+
+	if err != nil {
+		t.Fatalf("error connecting: %s", err.Error())
+	}
+
+	defer db.Close()
+
+	_, err = db.Query("SELECT * FROM table_that_does_not_exist")
+
+	if err == nil {
+		t.Error("Expected error due to selecting from non-existent table, but there
was no error.") + } + + resErr, ok := err.(ResponseError) + + if !ok { + t.Fatalf("Error type was not ResponseError") + } + + if resErr.ErrorCode != 1012 { + t.Errorf("Expected error code to be %d, got %d.", 1012, resErr.ErrorCode) + } + + if resErr.SqlState != "42M03" { + t.Errorf("Expected SQL state to be %s, got %s.", "42M03", resErr.SqlState) + } +}*/