F21 closed pull request #34: [CALCITE-2763] Fix handling of nils (nulls) when 
executing queries and scanning query results
URL: https://github.com/apache/calcite-avatica-go/pull/34
 
 
   

This is a PR merged from a forked repository (a foreign pull request).
Because GitHub hides the original diff of such a pull request once it
has been merged, the diff is reproduced below for the sake of provenance:

diff --git a/driver_hsqldb_test.go b/driver_hsqldb_test.go
index 539a80f..39143ed 100644
--- a/driver_hsqldb_test.go
+++ b/driver_hsqldb_test.go
@@ -286,6 +286,270 @@ func TestHSQLDBDataTypes(t *testing.T) {
        })
 }
 
+func TestHSQLDBSQLNullTypes(t *testing.T) {
+
+       skipTestIfNotHSQLDB(t)
+
+       runTests(t, dsn, func(dbt *DBTest) {
+
+               // Create and seed table
+               dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` (
+                               id INTEGER PRIMARY KEY,
+                               int INTEGER,
+                               tint TINYINT,
+                               sint SMALLINT,
+                               bint BIGINT,
+                               num NUMERIC(10,3),
+                               dec DECIMAL(10,3),
+                               re REAL,
+                               flt FLOAT,
+                               dbl DOUBLE,
+                               bool BOOLEAN,
+                               ch CHAR(3),
+                               var VARCHAR(128),
+                               bin BINARY(20),
+                               varbin VARBINARY(128),
+                               dt DATE,
+                               tmstmp TIMESTAMP,
+                           )`)
+
+               var (
+                       idValue                 = time.Now().Unix()
+                       integerValue            = sql.NullInt64{}
+                       tintValue               = sql.NullInt64{}
+                       sintValue               = sql.NullInt64{}
+                       bintValue               = sql.NullInt64{}
+                       numValue                = sql.NullString{}
+                       decValue                = sql.NullString{}
+                       reValue                 = sql.NullFloat64{}
+                       fltValue                = sql.NullFloat64{}
+                       dblValue                = sql.NullFloat64{}
+                       booleanValue            = sql.NullBool{}
+                       chValue                 = sql.NullString{}
+                       varcharValue            = sql.NullString{}
+                       binValue     *[]byte    = nil
+                       varbinValue  *[]byte    = nil
+                       dtValue      *time.Time = nil
+                       tmstmpValue  *time.Time = nil
+               )
+
+               dbt.mustExec(`INSERT INTO `+dbt.tableName+` (id, int, tint, 
sint, bint, num, dec, re, flt, dbl, bool, ch, var, bin, varbin, dt, tmstmp) 
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+                       idValue,
+                       integerValue,
+                       tintValue,
+                       sintValue,
+                       bintValue,
+                       numValue,
+                       decValue,
+                       reValue,
+                       fltValue,
+                       dblValue,
+                       booleanValue,
+                       chValue,
+                       varcharValue,
+                       binValue,
+                       varbinValue,
+                       dtValue,
+                       tmstmpValue,
+               )
+
+               rows := dbt.mustQuery("SELECT * FROM "+dbt.tableName+" WHERE id 
= ?", idValue)
+               defer rows.Close()
+
+               var (
+                       id      int64
+                       integer sql.NullInt64
+                       tint    sql.NullInt64
+                       sint    sql.NullInt64
+                       bint    sql.NullInt64
+                       num     sql.NullString
+                       dec     sql.NullString
+                       re      sql.NullFloat64
+                       flt     sql.NullFloat64
+                       dbl     sql.NullFloat64
+                       boolean sql.NullBool
+                       ch      sql.NullString
+                       varchar sql.NullString
+                       bin     *[]byte
+                       varbin  *[]byte
+                       dt      *time.Time
+                       tmstmp  *time.Time
+               )
+
+               for rows.Next() {
+                       err := rows.Scan(&id, &integer, &tint, &sint, &bint, 
&num, &dec, &re, &flt, &dbl, &boolean, &ch, &varchar, &bin, &varbin, &dt, 
&tmstmp)
+
+                       if err != nil {
+                               dbt.Fatal(err)
+                       }
+               }
+
+               comparisons := []struct {
+                       result   interface{}
+                       expected interface{}
+               }{
+                       {integer, integerValue},
+                       {tint, tintValue},
+                       {sint, sintValue},
+                       {bint, bintValue},
+                       {num, numValue},
+                       {dec, decValue},
+                       {re, reValue},
+                       {flt, fltValue},
+                       {dbl, dblValue},
+                       {boolean, booleanValue},
+                       {ch, chValue},
+                       {varchar, varcharValue},
+                       {bin, binValue},
+                       {varbin, varbinValue},
+                       {dt, dtValue},
+                       {tmstmp, tmstmpValue},
+               }
+
+               for i, tt := range comparisons {
+
+                       if v, ok := tt.expected.(time.Time); ok {
+
+                               if !v.Equal(tt.result.(time.Time)) {
+                                       dbt.Fatalf("Expected %v for case %d, 
got %v.", tt.expected, i, tt.result)
+                               }
+
+                       } else if v, ok := tt.expected.([]byte); ok {
+
+                               if !bytes.Equal(v, tt.result.([]byte)) {
+                                       dbt.Fatalf("Expected %v for case %d, 
got %v.", tt.expected, i, tt.result)
+                               }
+
+                       } else if tt.expected != tt.result {
+                               dbt.Errorf("Expected %v for case %d, got %v.", 
tt.expected, i, tt.result)
+                       }
+               }
+       })
+}
+
+func TestHSQLDBNulls(t *testing.T) {
+
+       skipTestIfNotHSQLDB(t)
+
+       runTests(t, dsn, func(dbt *DBTest) {
+
+               // Create and seed table
+               dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` (
+                               id INTEGER PRIMARY KEY,
+                               int INTEGER,
+                               tint TINYINT,
+                               sint SMALLINT,
+                               bint BIGINT,
+                               num NUMERIC(10,3),
+                               dec DECIMAL(10,3),
+                               re REAL,
+                               flt FLOAT,
+                               dbl DOUBLE,
+                               bool BOOLEAN,
+                               ch CHAR(3),
+                               var VARCHAR(128),
+                               bin BINARY(20),
+                               varbin VARBINARY(128),
+                               dt DATE,
+                               tmstmp TIMESTAMP,
+                           )`)
+
+               idValue := time.Now().Unix()
+
+               dbt.mustExec(`INSERT INTO `+dbt.tableName+` (id, int, tint, 
sint, bint, num, dec, re, flt, dbl, bool, ch, var, bin, varbin, dt, tmstmp) 
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+                       idValue,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+               )
+
+               rows := dbt.mustQuery("SELECT * FROM "+dbt.tableName+" WHERE id 
= ?", idValue)
+               defer rows.Close()
+
+               var (
+                       id      int64
+                       integer sql.NullInt64
+                       tint    sql.NullInt64
+                       sint    sql.NullInt64
+                       bint    sql.NullInt64
+                       num     sql.NullString
+                       dec     sql.NullString
+                       re      sql.NullFloat64
+                       flt     sql.NullFloat64
+                       dbl     sql.NullFloat64
+                       boolean sql.NullBool
+                       ch      sql.NullString
+                       varchar sql.NullString
+                       bin     *[]byte
+                       varbin  *[]byte
+                       dt      *time.Time
+                       tmstmp  *time.Time
+               )
+
+               for rows.Next() {
+                       err := rows.Scan(&id, &integer, &tint, &sint, &bint, 
&num, &dec, &re, &flt, &dbl, &boolean, &ch, &varchar, &bin, &varbin, &dt, 
&tmstmp)
+
+                       if err != nil {
+                               dbt.Fatal(err)
+                       }
+               }
+
+               comparisons := []struct {
+                       result   interface{}
+                       expected interface{}
+               }{
+                       {integer, sql.NullInt64{}},
+                       {tint, sql.NullInt64{}},
+                       {sint, sql.NullInt64{}},
+                       {bint, sql.NullInt64{}},
+                       {num, sql.NullString{}},
+                       {dec, sql.NullString{}},
+                       {re, sql.NullFloat64{}},
+                       {flt, sql.NullFloat64{}},
+                       {dbl, sql.NullFloat64{}},
+                       {boolean, sql.NullBool{}},
+                       {ch, sql.NullString{}},
+                       {varchar, sql.NullString{}},
+                       {bin, (*[]byte)(nil)},
+                       {varbin, (*[]byte)(nil)},
+                       {dt, (*time.Time)(nil)},
+                       {tmstmp, (*time.Time)(nil)},
+               }
+
+               for i, tt := range comparisons {
+
+                       if v, ok := tt.expected.(time.Time); ok {
+
+                               if !v.Equal(tt.result.(time.Time)) {
+                                       dbt.Fatalf("Expected %v for case %d, 
got %v.", tt.expected, i, tt.result)
+                               }
+
+                       } else if v, ok := tt.expected.([]byte); ok {
+
+                               if !bytes.Equal(v, tt.result.([]byte)) {
+                                       dbt.Fatalf("Expected %v for case %d, 
got %v.", tt.expected, i, tt.result)
+                               }
+
+                       } else if tt.expected != tt.result {
+                               dbt.Errorf("Expected %v for case %d, got %v.", 
tt.expected, i, tt.result)
+                       }
+               }
+       })
+}
+
 // TODO: Test case commented out due to CALCITE-1951
 /*func TestHSQLDBLocations(t *testing.T) {
 
diff --git a/driver_phoenix_test.go b/driver_phoenix_test.go
index f004413..90892e2 100644
--- a/driver_phoenix_test.go
+++ b/driver_phoenix_test.go
@@ -93,7 +93,7 @@ func TestPhoenixZeroValues(t *testing.T) {
                        var i int
                        var flt float64
                        var b bool
-                       var s string
+                       var s sql.NullString
 
                        err := rows.Scan(&i, &flt, &b, &s)
 
@@ -113,8 +113,8 @@ func TestPhoenixZeroValues(t *testing.T) {
                                dbt.Fatalf("Boolean should be false, got %v", b)
                        }
 
-                       if s != "" {
-                               dbt.Fatalf("String should be \"\", got %v", s)
+                       if val, _ := s.Value(); val != nil {
+                               dbt.Fatalf("String should be nil, got %v", s)
                        }
                }
 
@@ -301,6 +301,344 @@ func TestPhoenixDataTypes(t *testing.T) {
        })
 }
 
+func TestPhoenixSQLNullTypes(t *testing.T) {
+
+       skipTestIfNotPhoenix(t)
+
+       runTests(t, dsn, func(dbt *DBTest) {
+
+               // Create and seed table
+               dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` (
+                               id INTEGER PRIMARY KEY,
+                               int INTEGER,
+                               uint UNSIGNED_INT,
+                               bint BIGINT,
+                               ulong UNSIGNED_LONG,
+                               tint TINYINT,
+                               utint UNSIGNED_TINYINT,
+                               sint SMALLINT,
+                               usint UNSIGNED_SMALLINT,
+                               flt FLOAT,
+                               uflt UNSIGNED_FLOAT,
+                               dbl DOUBLE,
+                               udbl UNSIGNED_DOUBLE,
+                               dec DECIMAL,
+                               bool BOOLEAN,
+                               tm TIME,
+                               dt DATE,
+                               tmstmp TIMESTAMP,
+                               utm UNSIGNED_TIME,
+                               udt UNSIGNED_DATE,
+                               utmstmp UNSIGNED_TIMESTAMP,
+                               var VARCHAR,
+                               ch CHAR(3),
+                               bin BINARY(20),
+                               varbin VARBINARY
+                           ) TRANSACTIONAL=false`)
+
+               var (
+                       idValue                  = time.Now().Unix()
+                       integerValue             = sql.NullInt64{}
+                       uintegerValue            = sql.NullInt64{}
+                       bintValue                = sql.NullInt64{}
+                       ulongValue               = sql.NullInt64{}
+                       tintValue                = sql.NullInt64{}
+                       utintValue               = sql.NullInt64{}
+                       sintValue                = sql.NullInt64{}
+                       usintValue               = sql.NullInt64{}
+                       fltValue                 = sql.NullFloat64{}
+                       ufltValue                = sql.NullFloat64{}
+                       dblValue                 = sql.NullFloat64{}
+                       udblValue                = sql.NullFloat64{}
+                       decValue                 = sql.NullString{}
+                       booleanValue             = sql.NullBool{}
+                       tmValue       *time.Time = nil
+                       dtValue       *time.Time = nil
+                       tmstmpValue   *time.Time = nil
+                       utmValue      *time.Time = nil
+                       udtValue      *time.Time = nil
+                       utmstmpValue  *time.Time = nil
+                       varcharValue             = sql.NullString{}
+                       chValue                  = sql.NullString{}
+                       binValue      *[]byte    = nil
+                       varbinValue   *[]byte    = nil
+               )
+
+               dbt.mustExec(`UPSERT INTO `+dbt.tableName+` VALUES (?, ?, ?, ?, 
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+                       idValue,
+                       integerValue,
+                       uintegerValue,
+                       bintValue,
+                       ulongValue,
+                       tintValue,
+                       utintValue,
+                       sintValue,
+                       usintValue,
+                       fltValue,
+                       ufltValue,
+                       dblValue,
+                       udblValue,
+                       decValue,
+                       booleanValue,
+                       tmValue,
+                       dtValue,
+                       tmstmpValue,
+                       utmValue,
+                       udtValue,
+                       utmstmpValue,
+                       varcharValue,
+                       chValue,
+                       binValue,
+                       varbinValue,
+               )
+
+               rows := dbt.mustQuery("SELECT * FROM "+dbt.tableName+" WHERE id 
= ?", idValue)
+               defer rows.Close()
+
+               var (
+                       id       int64
+                       integer  sql.NullInt64
+                       uinteger sql.NullInt64
+                       bint     sql.NullInt64
+                       ulong    sql.NullInt64
+                       tint     sql.NullInt64
+                       utint    sql.NullInt64
+                       sint     sql.NullInt64
+                       usint    sql.NullInt64
+                       flt      sql.NullFloat64
+                       uflt     sql.NullFloat64
+                       dbl      sql.NullFloat64
+                       udbl     sql.NullFloat64
+                       dec      sql.NullString
+                       boolean  sql.NullBool
+                       tm       *time.Time
+                       dt       *time.Time
+                       tmstmp   *time.Time
+                       utm      *time.Time
+                       udt      *time.Time
+                       utmstmp  *time.Time
+                       varchar  sql.NullString
+                       ch       sql.NullString
+                       bin      *[]byte
+                       varbin   *[]byte
+               )
+
+               for rows.Next() {
+
+                       err := rows.Scan(&id, &integer, &uinteger, &bint, 
&ulong, &tint, &utint, &sint, &usint, &flt, &uflt, &dbl, &udbl, &dec, &boolean, 
&tm, &dt, &tmstmp, &utm, &udt, &utmstmp, &varchar, &ch, &bin, &varbin)
+
+                       if err != nil {
+                               dbt.Fatal(err)
+                       }
+               }
+
+               comparisons := []struct {
+                       result   interface{}
+                       expected interface{}
+               }{
+                       {integer, integerValue},
+                       {uinteger, uintegerValue},
+                       {bint, bintValue},
+                       {ulong, ulongValue},
+                       {tint, tintValue},
+                       {utint, utintValue},
+                       {sint, sintValue},
+                       {usint, usintValue},
+                       {flt, fltValue},
+                       {uflt, ufltValue},
+                       {dbl, dblValue},
+                       {udbl, udblValue},
+                       {dec, decValue},
+                       {boolean, booleanValue},
+                       {tm, tmValue},
+                       {dt, dtValue},
+                       {tmstmp, tmstmpValue},
+                       {utm, utmValue},
+                       {udt, udtValue},
+                       {utmstmp, utmstmpValue},
+                       {varchar, varcharValue},
+                       {ch, chValue},
+                       {*bin, []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0}},
+                       {varbin, varbinValue},
+               }
+
+               for i, tt := range comparisons {
+
+                       if v, ok := tt.expected.(time.Time); ok {
+
+                               if !v.Equal(tt.result.(time.Time)) {
+                                       dbt.Fatalf("Expected %v for case %d, 
got %v.", tt.expected, i, tt.result)
+                               }
+
+                       } else if v, ok := tt.expected.([]byte); ok {
+
+                               if !bytes.Equal(v, tt.result.([]byte)) {
+                                       dbt.Fatalf("Expected %v for case %d, 
got %v.", tt.expected, i, tt.result)
+                               }
+
+                       } else if tt.expected != tt.result {
+                               dbt.Errorf("Expected %v for case %d, got %v.", 
tt.expected, i, tt.result)
+                       }
+               }
+       })
+}
+
+func TestPhoenixNulls(t *testing.T) {
+
+       skipTestIfNotPhoenix(t)
+
+       runTests(t, dsn, func(dbt *DBTest) {
+
+               // Create and seed table
+               dbt.mustExec(`CREATE TABLE ` + dbt.tableName + ` (
+                               id INTEGER PRIMARY KEY,
+                               int INTEGER,
+                               uint UNSIGNED_INT,
+                               bint BIGINT,
+                               ulong UNSIGNED_LONG,
+                               tint TINYINT,
+                               utint UNSIGNED_TINYINT,
+                               sint SMALLINT,
+                               usint UNSIGNED_SMALLINT,
+                               flt FLOAT,
+                               uflt UNSIGNED_FLOAT,
+                               dbl DOUBLE,
+                               udbl UNSIGNED_DOUBLE,
+                               dec DECIMAL,
+                               bool BOOLEAN,
+                               tm TIME,
+                               dt DATE,
+                               tmstmp TIMESTAMP,
+                               utm UNSIGNED_TIME,
+                               udt UNSIGNED_DATE,
+                               utmstmp UNSIGNED_TIMESTAMP,
+                               var VARCHAR,
+                               ch CHAR(3),
+                               bin BINARY(20),
+                               varbin VARBINARY
+                           ) TRANSACTIONAL=false`)
+
+               idValue := time.Now().Unix()
+
+               dbt.mustExec(`UPSERT INTO `+dbt.tableName+` VALUES (?, ?, ?, ?, 
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+                       idValue,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+                       nil,
+               )
+
+               rows := dbt.mustQuery("SELECT * FROM "+dbt.tableName+" WHERE id 
= ?", idValue)
+               defer rows.Close()
+
+               var (
+                       id       int64
+                       integer  sql.NullInt64
+                       uinteger sql.NullInt64
+                       bint     sql.NullInt64
+                       ulong    sql.NullInt64
+                       tint     sql.NullInt64
+                       utint    sql.NullInt64
+                       sint     sql.NullInt64
+                       usint    sql.NullInt64
+                       flt      sql.NullFloat64
+                       uflt     sql.NullFloat64
+                       dbl      sql.NullFloat64
+                       udbl     sql.NullFloat64
+                       dec      sql.NullString
+                       boolean  sql.NullBool
+                       tm       *time.Time
+                       dt       *time.Time
+                       tmstmp   *time.Time
+                       utm      *time.Time
+                       udt      *time.Time
+                       utmstmp  *time.Time
+                       varchar  sql.NullString
+                       ch       sql.NullString
+                       bin      *[]byte
+                       varbin   *[]byte
+               )
+
+               for rows.Next() {
+
+                       err := rows.Scan(&id, &integer, &uinteger, &bint, 
&ulong, &tint, &utint, &sint, &usint, &flt, &uflt, &dbl, &udbl, &dec, &boolean, 
&tm, &dt, &tmstmp, &utm, &udt, &utmstmp, &varchar, &ch, &bin, &varbin)
+
+                       if err != nil {
+                               dbt.Fatal(err)
+                       }
+               }
+
+               comparisons := []struct {
+                       result   interface{}
+                       expected interface{}
+               }{
+                       {integer, sql.NullInt64{}},
+                       {uinteger, sql.NullInt64{}},
+                       {bint, sql.NullInt64{}},
+                       {ulong, sql.NullInt64{}},
+                       {tint, sql.NullInt64{}},
+                       {utint, sql.NullInt64{}},
+                       {sint, sql.NullInt64{}},
+                       {usint, sql.NullInt64{}},
+                       {flt, sql.NullFloat64{}},
+                       {uflt, sql.NullFloat64{}},
+                       {dbl, sql.NullFloat64{}},
+                       {udbl, sql.NullFloat64{}},
+                       {dec, sql.NullString{}},
+                       {boolean, sql.NullBool{}},
+                       {tm, (*time.Time)(nil)},
+                       {dt, (*time.Time)(nil)},
+                       {tmstmp, (*time.Time)(nil)},
+                       {utm, (*time.Time)(nil)},
+                       {udt, (*time.Time)(nil)},
+                       {utmstmp, (*time.Time)(nil)},
+                       {varchar, sql.NullString{}},
+                       {ch, sql.NullString{}},
+                       {*bin, []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0}},
+                       {varbin, (*[]byte)(nil)},
+               }
+
+               for i, tt := range comparisons {
+
+                       if v, ok := tt.expected.(time.Time); ok {
+
+                               if !v.Equal(tt.result.(time.Time)) {
+                                       dbt.Fatalf("Expected %v for case %d, 
got %v.", tt.expected, i, tt.result)
+                               }
+
+                       } else if v, ok := tt.expected.([]byte); ok {
+
+                               if !bytes.Equal(v, tt.result.([]byte)) {
+                                       dbt.Fatalf("Expected %v for case %d, 
got %v.", tt.expected, i, tt.result)
+                               }
+
+                       } else if tt.expected != tt.result {
+                               dbt.Errorf("Expected %v for case %d, got %v.", 
tt.expected, i, tt.result)
+                       }
+               }
+       })
+}
+
 func TestPhoenixLocations(t *testing.T) {
 
        skipTestIfNotPhoenix(t)
diff --git a/rows.go b/rows.go
index fa3eddb..8fe7f1e 100644
--- a/rows.go
+++ b/rows.go
@@ -180,7 +180,13 @@ func newRows(conn *conn, statementID uint32, resultSets 
[]*message.ResultSetResp
 
 // typedValueToNative converts values from avatica's types to Go's native types
 func typedValueToNative(rep message.Rep, v *message.TypedValue, config 
*Config) interface{} {
+
+       if v.Type == message.Rep_NULL {
+               return nil
+       }
+
        switch rep {
+
        case message.Rep_BOOLEAN, message.Rep_PRIMITIVE_BOOLEAN:
                return v.BoolValue
 
diff --git a/statement.go b/statement.go
index 3e3cab2..af2e2ce 100644
--- a/statement.go
+++ b/statement.go
@@ -156,6 +156,7 @@ func (s *stmt) parametersToTypedValues(vals []namedValue) 
[]*message.TypedValue
                typed := message.TypedValue{}
                if val.Value == nil {
                        typed.Null = true
+                       typed.Type = message.Rep_NULL
                } else {
 
                        switch v := val.Value.(type) {


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to