Add tests for go 1.8 interfaces
Project: http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/repo Commit: http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/commit/b4003a72 Tree: http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/tree/b4003a72 Diff: http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/diff/b4003a72 Branch: refs/heads/master Commit: b4003a72aa1edb90174a0601fdc13845b85dda8e Parents: ae26325 Author: Francis Chuang <francis.chu...@boostport.com> Authored: Wed Mar 8 21:04:23 2017 +1100 Committer: Julian Hyde <jh...@apache.org> Committed: Thu Aug 10 18:47:10 2017 -0700 ---------------------------------------------------------------------- compat_go18.go | 2 + connection.go | 2 - connection_go18.go | 9 +- driver_go18_test.go | 793 +++++++++++++++++++++++++++++++++++++++++++++++ rows_go18.go | 1 + statement_go18.go | 5 +- 6 files changed, 802 insertions(+), 10 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/b4003a72/compat_go18.go ---------------------------------------------------------------------- diff --git a/compat_go18.go b/compat_go18.go index 8efd1b5..f57cc11 100644 --- a/compat_go18.go +++ b/compat_go18.go @@ -1,3 +1,5 @@ +// +build go1.8 + package avatica import ( http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/b4003a72/connection.go ---------------------------------------------------------------------- diff --git a/connection.go b/connection.go index 9e59701..dfb5b5a 100644 --- a/connection.go +++ b/connection.go @@ -68,7 +68,6 @@ func (c *conn) Close() error { // Begin starts and returns a new transaction. func (c *conn) Begin() (driver.Tx, error) { - return c.begin(context.Background(), isolationUseCurrent) } @@ -143,7 +142,6 @@ func (c *conn) exec(ctx context.Context, query string, args []namedValue) (drive } // Query prepares and executes a query and returns the result directly. -// Query's optimizations are currently disabled due to CALCITE-1181. func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) { list := driverValueToNamedValue(args) return c.query(context.Background(), query, list) http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/b4003a72/connection_go18.go ---------------------------------------------------------------------- diff --git a/connection_go18.go b/connection_go18.go index a54a115..1403679 100644 --- a/connection_go18.go +++ b/connection_go18.go @@ -1,15 +1,14 @@ // +build go1.8 + package avatica import ( + "database/sql" "database/sql/driver" - "errors" - - "database/sql" "fmt" - "golang.org/x/net/context" + "context" ) func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { @@ -60,7 +59,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name func (c *conn) Ping(ctx context.Context) error { - _, err := c.ExecContext(ctx, "SELECT 1;", []driver.NamedValue{}) + _, err := c.ExecContext(ctx, "SELECT 1", []driver.NamedValue{}) if err != nil { return fmt.Errorf("Error pinging database: %s", err) http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/b4003a72/driver_go18_test.go ---------------------------------------------------------------------- diff --git a/driver_go18_test.go b/driver_go18_test.go new file mode 100644 index 0000000..d5b7617 --- /dev/null +++ b/driver_go18_test.go @@ -0,0 +1,793 @@ +// +build go1.8 + +package avatica + +import ( + "database/sql" + "testing" + "time" + + "context" + "math" + "reflect" +) + +func (dbt *DBTest) mustExecContext(ctx context.Context, query string, args ...interface{}) (res sql.Result) { + res, err := dbt.db.ExecContext(ctx, query, args...) + + if err != nil { + dbt.fail("exec", query, err) + } + + return res +} + +func (dbt *DBTest) mustQueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows) { + rows, err := dbt.db.QueryContext(ctx, query, args...) + + if err != nil { + dbt.fail("query", query, err) + } + + return rows +} + +func getContext() context.Context { + ctx, _ := context.WithTimeout(context.Background(), 2*time.Minute) + + return ctx +} + +func TestContext(t *testing.T) { + + runTests(t, dsn, func(dbt *DBTest) { + + // Create and seed table + dbt.mustExecContext(getContext(), "CREATE TABLE "+dbt.tableName+" (id BIGINT PRIMARY KEY, val VARCHAR) TRANSACTIONAL=false") + + dbt.mustExecContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES (1,'A')") + + dbt.mustExecContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES (2,'B')") + + rows := dbt.mustQueryContext(getContext(), "SELECT COUNT(*) FROM "+dbt.tableName) + defer rows.Close() + + for rows.Next() { + + var count int + + err := rows.Scan(&count) + + if err != nil { + dbt.Fatal(err) + } + + if count != 2 { + dbt.Fatalf("There should be 2 rows, got %d", count) + } + } + + // Test transactions and prepared statements + _, err := dbt.db.BeginTx(getContext(), &sql.TxOptions{Isolation: sql.LevelReadUncommitted, ReadOnly: true}) + + if err == nil { + t.Error("Expected an error while creating a read only transaction, but no error was returned") + } + + tx, err := dbt.db.BeginTx(getContext(), &sql.TxOptions{Isolation: sql.LevelReadCommitted}) + + if err != nil { + t.Errorf("Unexpected error while creating transaction: %s", err) + } + + stmt, err := tx.PrepareContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES(?,?)") + + if err != nil { + t.Errorf("Unexpected error while preparing statement: %s", err) + } + + res, err := stmt.ExecContext(getContext(), 3, "C") + + if err != nil { + t.Errorf("Unexpected error while executing statement: %s", err) + } + + affected, err := res.RowsAffected() + + if err != nil { + t.Errorf("Error getting affected rows: %s", err) + } + + if affected != 1 { + t.Errorf("Expected 1 affected row, got %d", affected) + } + + err = tx.Commit() + + if err != nil { + t.Errorf("Error committing transaction: %s", err) + } + + stmt2, err := dbt.db.PrepareContext(getContext(), "SELECT * FROM "+dbt.tableName+" WHERE id = ?") + + if err != nil { + t.Errorf("Error preparing statement: %s", err) + } + + row := stmt2.QueryRowContext(getContext(), 3) + + if err != nil { + t.Errorf("Error querying for row: %s", err) + } + + var ( + queryID int64 + queryVal string + ) + + err = row.Scan(&queryID, &queryVal) + + if err != nil { + t.Errorf("Error scanning results into variable: %s", err) + } + + if queryID != 3 { + t.Errorf("Expected scanned id to be %d, got %d", 3, queryID) + } + + if queryVal != "C" { + t.Errorf("Expected scanned string to be %s, got %s", "C", queryVal) + } + }) +} + +func TestPing(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + err := dbt.db.Ping() + + if err != nil { + t.Errorf("Expected ping to succeed, got error: %s", err) + } + }) +} + +func TestInvalidPing(t *testing.T) { + runTests(t, "http://invalid-server:8765", func(dbt *DBTest) { + err := dbt.db.Ping() + + if err == nil { + t.Error("Expected ping to fail, but did not get any error") + } + }) +} + +func TestMultipleResultSets(t *testing.T) { + + runTests(t, dsn, func(dbt *DBTest) { + // Create and seed table + dbt.mustExecContext(getContext(), "CREATE TABLE "+dbt.tableName+" (id BIGINT PRIMARY KEY, val VARCHAR) TRANSACTIONAL=false") + + dbt.mustExecContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES (1,'A')") + + dbt.mustExecContext(getContext(), "UPSERT INTO "+dbt.tableName+" VALUES (2,'B')") + + rows, err := dbt.db.QueryContext(getContext(), "SELECT * FROM "+dbt.tableName+" WHERE id = 1") + + if err != nil { + t.Errorf("Unexpected error while executing query: %s", err) + } + + defer rows.Close() + + for rows.Next() { + var ( + id int64 + val string + ) + + if err := rows.Scan(&id, &val); err != nil { + t.Errorf("Error while scanning row into variables: %s", err) + } + + if id != 1 { + t.Errorf("Expected id to be %d, got %d", 1, id) + } + + if val != "A" { + t.Errorf("Expected value to be %s, got %s", "A", val) + } + } + + if rows.NextResultSet() { + t.Error("There should be no more result sets, but got another result set") + } + }) +} + +func TestColumnTypes(t *testing.T) { + + runTests(t, dsn, func(dbt *DBTest) { + + // Create and seed table + dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` ( + int INTEGER PRIMARY KEY, + uint UNSIGNED_INT, + bint BIGINT, + ulong UNSIGNED_LONG, + tint TINYINT, + utint UNSIGNED_TINYINT, + sint SMALLINT, + usint UNSIGNED_SMALLINT, + flt FLOAT, + uflt UNSIGNED_FLOAT, + dbl DOUBLE, + udbl UNSIGNED_DOUBLE, + dec DECIMAL(10, 5), + dec2 DECIMAL, + bool BOOLEAN, + tm TIME, + dt DATE, + tmstmp TIMESTAMP, + utm UNSIGNED_TIME, + udt UNSIGNED_DATE, + utmstmp UNSIGNED_TIMESTAMP, + var VARCHAR(10), + ch CHAR(3), + bin BINARY(20), + varbin VARBINARY + ) TRANSACTIONAL=false`) + + // Select + rows, err := dbt.db.QueryContext(getContext(), "SELECT * FROM "+dbt.tableName) + + if err != nil { + t.Errorf("Unexpected error while selecting from table: %s", err) + } + + columnNames, err := rows.Columns() + + if err != nil { + t.Errorf("Error getting column names: %s", err) + } + + expectedColumnNames := []string{"INT", "UINT", "BINT", "ULONG", "TINT", "UTINT", "SINT", "USINT", "FLT", "UFLT", "DBL", "UDBL", "DEC", "DEC2", "BOOL", "TM", "DT", "TMSTMP", "UTM", "UDT", "UTMSTMP", "VAR", "CH", "BIN", "VARBIN"} + + if !reflect.DeepEqual(columnNames, expectedColumnNames) { + t.Error("Column names does not match expected column names") + } + + type decimalSize struct { + precision int64 + scale int64 + ok bool + } + + type length struct { + length int64 + ok bool + } + + type nullable struct { + nullable bool + ok bool + } + + expectedColumnTypes := []struct { + databaseTypeName string + decimalSize decimalSize + length length + name string + nullable nullable + scanType reflect.Type + }{ + { + databaseTypeName: "INTEGER", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "INT", + nullable: nullable{ + nullable: false, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "UNSIGNED_INT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "UINT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "BIGINT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "BINT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "UNSIGNED_LONG", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "ULONG", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "TINYINT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "TINT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "UNSIGNED_TINYINT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "UTINT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "SMALLINT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "SINT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "UNSIGNED_SMALLINT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "USINT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(int64(0)), + }, + { + databaseTypeName: "FLOAT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "FLT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(float64(0)), + }, + { + databaseTypeName: "UNSIGNED_FLOAT", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "UFLT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(float64(0)), + }, + { + databaseTypeName: "DOUBLE", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "DBL", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(float64(0)), + }, + { + databaseTypeName: "UNSIGNED_DOUBLE", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "UDBL", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(float64(0)), + }, + { + databaseTypeName: "DECIMAL", + decimalSize: decimalSize{ + precision: 10, + scale: 5, + ok: true, + }, + length: length{ + length: 0, + ok: false, + }, + name: "DEC", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(""), + }, + { + databaseTypeName: "DECIMAL", + decimalSize: decimalSize{ + precision: math.MaxInt64, + scale: math.MaxInt64, + ok: true, + }, + length: length{ + length: 0, + ok: false, + }, + name: "DEC2", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(""), + }, + { + databaseTypeName: "BOOLEAN", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "BOOL", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(bool(false)), + }, + { + databaseTypeName: "TIME", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "TM", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(time.Time{}), + }, + { + databaseTypeName: "DATE", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "DT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(time.Time{}), + }, + { + databaseTypeName: "TIMESTAMP", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "TMSTMP", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(time.Time{}), + }, + { + databaseTypeName: "UNSIGNED_TIME", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "UTM", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(time.Time{}), + }, + { + databaseTypeName: "UNSIGNED_DATE", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "UDT", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(time.Time{}), + }, + { + databaseTypeName: "UNSIGNED_TIMESTAMP", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 0, + ok: false, + }, + name: "UTMSTMP", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(time.Time{}), + }, + { + databaseTypeName: "VARCHAR", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 10, + ok: true, + }, + name: "VAR", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(""), + }, + { + databaseTypeName: "CHAR", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 3, + ok: true, + }, + name: "CH", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf(""), + }, + { + databaseTypeName: "BINARY", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: 20, + ok: true, + }, + name: "BIN", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf([]byte{}), + }, + { + databaseTypeName: "VARBINARY", + decimalSize: decimalSize{ + precision: 0, + scale: 0, + ok: false, + }, + length: length{ + length: math.MaxInt64, + ok: true, + }, + name: "VARBIN", + nullable: nullable{ + nullable: true, + ok: true, + }, + scanType: reflect.TypeOf([]byte{}), + }, + } + + columnTypes, err := rows.ColumnTypes() + + if err != nil { + t.Errorf("Error getting column types: %s", err) + } + + for index, columnType := range columnTypes { + + expected := expectedColumnTypes[index] + + if columnType.DatabaseTypeName() != expected.databaseTypeName { + t.Errorf("Expected database type name for index %d to be %s, got %s", index, expected.databaseTypeName, columnType.DatabaseTypeName()) + } + + precision, scale, ok := columnType.DecimalSize() + + if precision != expected.decimalSize.precision { + t.Errorf("Expected decimal precision for index %d to be %d, got %d", index, expected.decimalSize.precision, precision) + } + + if scale != expected.decimalSize.scale { + t.Errorf("Expected decimal scale for index %d to be %d, got %d", index, expected.decimalSize.scale, scale) + } + + if ok != expected.decimalSize.ok { + t.Errorf("Expected decimal ok for index %d to be %t, got %t", index, expected.decimalSize.ok, ok) + } + + length, ok := columnType.Length() + + if length != expected.length.length { + t.Errorf("Expected length for index %d to be %d, got %d", index, expected.length.length, length) + } + + if ok != expected.length.ok { + t.Errorf("Expected length ok for index %d to be %t, got %t", index, expected.length.ok, ok) + } + + if columnType.Name() != expected.name { + t.Errorf("Expected column name for index %d to be %s, got %s", index, expected.name, columnType.Name()) + } + + nullable, ok := columnType.Nullable() + + if nullable != expected.nullable.nullable { + t.Errorf("Expected nullable for index %d to be %t, got %t", index, expected.nullable.nullable, nullable) + } + + if ok != expected.nullable.ok { + t.Errorf("Expected nullable ok for index %d to be %t, got %t", index, expected.nullable.ok, ok) + } + + if columnType.ScanType() != expected.scanType { + t.Errorf("Expected scan type for index %d to be %s, got %s", index, expected.scanType, columnType.ScanType()) + } + } + + }) +} http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/b4003a72/rows_go18.go ---------------------------------------------------------------------- diff --git a/rows_go18.go b/rows_go18.go index 9a783c2..20a3648 100644 --- a/rows_go18.go +++ b/rows_go18.go @@ -1,4 +1,5 @@ // +build go1.8 + package avatica import ( http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/b4003a72/statement_go18.go ---------------------------------------------------------------------- diff --git a/statement_go18.go b/statement_go18.go index e7ee65e..856f19c 100644 --- a/statement_go18.go +++ b/statement_go18.go @@ -1,12 +1,11 @@ // +build go1.8 + package avatica import ( + "context" "database/sql/driver" - "fmt" - - "golang.org/x/net/context" ) func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {