Implement multiple result sets, row type information, and support for all Phoenix data types


Project: http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/commit/ae26325d
Tree: http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/tree/ae26325d
Diff: http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/diff/ae26325d

Branch: refs/heads/master
Commit: ae26325d3a2dd737a0fdf9d172b354122301f271
Parents: 261f94f
Author: Francis Chuang <francis.chu...@boostport.com>
Authored: Wed Mar 8 15:37:59 2017 +1100
Committer: Julian Hyde <jh...@apache.org>
Committed: Thu Aug 10 18:47:10 2017 -0700

----------------------------------------------------------------------
 connection.go  |   7 +-
 driver_test.go |  27 ++++++-
 rows.go        | 215 ++++++++++++++++++++++++++++++++++++++--------------
 rows_go18.go   |  59 ++++++++++++++
 statement.go   |  18 +++--
 5 files changed, 255 insertions(+), 71 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/ae26325d/connection.go
----------------------------------------------------------------------
diff --git a/connection.go b/connection.go
index 2149261..9e59701 100644
--- a/connection.go
+++ b/connection.go
@@ -134,7 +134,7 @@ func (c *conn) exec(ctx context.Context, query string, args 
[]namedValue) (drive
                return nil, err
        }
 
-       // Currently there is only 1 ResultSet per response
+       // Currently there is only 1 ResultSet per response for exec
        changed := int64(res.(*message.ExecuteResponse).Results[0].UpdateCount)
 
        return &result{
@@ -178,8 +178,7 @@ func (c *conn) query(ctx context.Context, query string, 
args []namedValue) (driv
                return nil, err
        }
 
-       // Currently there is only 1 ResultSet per response
-       resultSet := res.(*message.ExecuteResponse).Results[0]
+       resultSets := res.(*message.ExecuteResponse).Results
 
-       return newRows(c, st.(*message.CreateStatementResponse).StatementId, 
resultSet), nil
+       return newRows(c, st.(*message.CreateStatementResponse).StatementId, 
resultSets), nil
 }

http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/ae26325d/driver_test.go
----------------------------------------------------------------------
diff --git a/driver_test.go b/driver_test.go
index 017920b..f2035f7 100644
--- a/driver_test.go
+++ b/driver_test.go
@@ -5,13 +5,14 @@ import (
        "crypto/sha256"
        "database/sql"
        "fmt"
-       "github.com/satori/go.uuid"
        "io/ioutil"
        "os"
        "path/filepath"
        "strings"
        "testing"
        "time"
+
+       "github.com/satori/go.uuid"
 )
 
 var (
@@ -187,10 +188,14 @@ func TestDataTypes(t *testing.T) {
                                uflt UNSIGNED_FLOAT,
                                dbl DOUBLE,
                                udbl UNSIGNED_DOUBLE,
+                               dec DECIMAL,
                                bool BOOLEAN,
                                tm TIME,
                                dt DATE,
                                tmstmp TIMESTAMP,
+                               utm UNSIGNED_TIME,
+                               udt UNSIGNED_DATE,
+                               utmstmp UNSIGNED_TIMESTAMP,
                                var VARCHAR,
                                ch CHAR(3),
                                bin BINARY(20),
@@ -210,10 +215,14 @@ func TestDataTypes(t *testing.T) {
                        ufltValue     float64   = 3.555
                        dblValue      float64   = -9.555
                        udblValue     float64   = 9.555
+                       decValue      string    = "1.333"
                        booleanValue  bool      = true
                        tmValue       time.Time = time.Date(0, 1, 1, 21, 21, 
21, 222000000, time.UTC)
                        dtValue       time.Time = time.Date(2100, 2, 1, 0, 0, 
0, 0, time.UTC)
                        tmstmpValue   time.Time = time.Date(2100, 2, 1, 21, 21, 
21, 222000000, time.UTC)
+                       utmValue      time.Time = time.Date(0, 1, 1, 21, 21, 
21, 222000000, time.UTC)
+                       udtValue      time.Time = time.Date(2100, 2, 1, 0, 0, 
0, 0, time.UTC)
+                       utmstmpValue  time.Time = time.Date(2100, 2, 1, 21, 21, 
21, 222000000, time.UTC)
                        varcharValue  string    = "test string"
                        chValue       string    = "a"
                        binValue      []byte    = make([]byte, 20, 20)
@@ -222,7 +231,7 @@ func TestDataTypes(t *testing.T) {
 
                copy(binValue[:], "test")
 
-               dbt.mustExec(`UPSERT INTO `+dbt.tableName+` VALUES (?, ?, ?, ?, 
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+               dbt.mustExec(`UPSERT INTO `+dbt.tableName+` VALUES (?, ?, ?, ?, 
?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
                        integerValue,
                        uintegerValue,
                        bintValue,
@@ -235,10 +244,14 @@ func TestDataTypes(t *testing.T) {
                        ufltValue,
                        dblValue,
                        udblValue,
+                       decValue,
                        booleanValue,
                        tmValue,
                        dtValue,
                        tmstmpValue,
+                       utmValue,
+                       udtValue,
+                       utmstmpValue,
                        varcharValue,
                        chValue,
                        binValue,
@@ -261,10 +274,14 @@ func TestDataTypes(t *testing.T) {
                        uflt     float64
                        dbl      float64
                        udbl     float64
+                       dec      string
                        boolean  bool
                        tm       time.Time
                        dt       time.Time
                        tmstmp   time.Time
+                       utm      time.Time
+                       udt      time.Time
+                       utmstmp  time.Time
                        varchar  string
                        ch       string
                        bin      []byte
@@ -273,7 +290,7 @@ func TestDataTypes(t *testing.T) {
 
                for rows.Next() {
 
-                       err := rows.Scan(&integer, &uinteger, &bint, &ulong, 
&tint, &utint, &sint, &usint, &flt, &uflt, &dbl, &udbl, &boolean, &tm, &dt, 
&tmstmp, &varchar, &ch, &bin, &varbin)
+                       err := rows.Scan(&integer, &uinteger, &bint, &ulong, 
&tint, &utint, &sint, &usint, &flt, &uflt, &dbl, &udbl, &dec, &boolean, &tm, 
&dt, &tmstmp, &utm, &udt, &utmstmp, &varchar, &ch, &bin, &varbin)
 
                        if err != nil {
                                dbt.Fatal(err)
@@ -296,10 +313,14 @@ func TestDataTypes(t *testing.T) {
                        {uflt, ufltValue},
                        {dbl, dblValue},
                        {udbl, udblValue},
+                       {dec, decValue},
                        {boolean, booleanValue},
                        {tm, tmValue},
                        {dt, dtValue},
                        {tmstmp, tmstmpValue},
+                       {utm, utmValue},
+                       {udt, udtValue},
+                       {utmstmp, utmstmpValue},
                        {varchar, varcharValue},
                        {ch, chValue},
                        {bin, binValue},

http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/ae26325d/rows.go
----------------------------------------------------------------------
diff --git a/rows.go b/rows.go
index 82b67ff..4d69ff7 100644
--- a/rows.go
+++ b/rows.go
@@ -2,21 +2,47 @@ package avatica
 
 import (
        "database/sql/driver"
-       "github.com/Boostport/avatica/message"
-       "golang.org/x/net/context"
        "io"
        "time"
+
+       "reflect"
+
+       "math"
+
+       "fmt"
+
+       "github.com/Boostport/avatica/message"
+       "golang.org/x/net/context"
 )
 
+type precisionScale struct {
+       precision int64
+       scale     int64
+}
+
+type column struct {
+       name           string
+       typeName       string
+       rep            message.Rep
+       length         int64
+       nullable       bool
+       precisionScale *precisionScale
+       scanType       reflect.Type
+}
+
+type resultSet struct {
+       columns    []*column
+       done       bool
+       offset     uint64
+       data       [][]*message.TypedValue
+       currentRow int
+}
+
 type rows struct {
-       conn        *conn
-       statementID uint32
-       columnNames []string
-       columnTypes []message.Rep
-       done        bool
-       offset      uint64
-       data        [][]*message.TypedValue
-       currentRow  int
+       conn             *conn
+       statementID      uint32
+       resultSets       []*resultSet
+       currentResultSet int
 }
 
 // Columns returns the names of the columns. The number of
@@ -24,7 +50,14 @@ type rows struct {
 // slice.  If a particular column name isn't known, an empty
 // string should be returned for that entry.
 func (r *rows) Columns() []string {
-       return r.columnNames
+
+       cols := []string{}
+
+       for _, column := range r.resultSets[r.currentResultSet].columns {
+               cols = append(cols, column.name)
+       }
+
+       return cols
 }
 
 // Close closes the rows iterator.
@@ -45,9 +78,11 @@ func (r *rows) Close() error {
 // Next should return io.EOF when there are no more rows.
 func (r *rows) Next(dest []driver.Value) error {
 
-       if r.currentRow >= len(r.data) {
+       resultSet := r.resultSets[r.currentResultSet]
+
+       if resultSet.currentRow >= len(resultSet.data) {
 
-               if r.done {
+               if resultSet.done {
                        // Finished iterating through all results
                        return io.EOF
                }
@@ -56,7 +91,7 @@ func (r *rows) Next(dest []driver.Value) error {
                res, err := r.conn.httpClient.post(context.Background(), 
&message.FetchRequest{
                        ConnectionId: r.conn.connectionId,
                        StatementId:  r.statementID,
-                       Offset:       r.offset,
+                       Offset:       resultSet.offset,
                        FrameMaxSize: r.conn.config.frameMaxSize,
                })
 
@@ -84,69 +119,135 @@ func (r *rows) Next(dest []driver.Value) error {
                        data = append(data, rowData)
                }
 
-               r.done = frame.Done
-               r.data = data
-               r.currentRow = 0
+               resultSet.done = frame.Done
+               resultSet.data = data
+               resultSet.currentRow = 0
 
        }
 
-       for i, val := range r.data[r.currentRow] {
-               dest[i] = typedValueToNative(r.columnTypes[i], val, 
r.conn.config)
+       for i, val := range resultSet.data[resultSet.currentRow] {
+               dest[i] = typedValueToNative(resultSet.columns[i].rep, val, 
r.conn.config)
        }
 
-       r.currentRow++
+       resultSet.currentRow++
 
        return nil
 }
 
 // newRows create a new set of rows from a result set.
-func newRows(conn *conn, statementID uint32, resultSet 
*message.ResultSetResponse) *rows {
-
-       columnNames := []string{}
-       columnTypes := []message.Rep{}
-
-       for _, col := range resultSet.Signature.Columns {
-               columnNames = append(columnNames, col.ColumnName)
-
-               // Special case for floats, date, time and timestamp
-               switch col.Type.Name {
-               case "FLOAT":
-                       columnTypes = append(columnTypes, message.Rep_FLOAT)
-               case "UNSIGNED_FLOAT":
-                       columnTypes = append(columnTypes, message.Rep_FLOAT)
-               case "TIME":
-                       columnTypes = append(columnTypes, 
message.Rep_JAVA_SQL_TIME)
-               case "DATE":
-                       columnTypes = append(columnTypes, 
message.Rep_JAVA_SQL_DATE)
-               case "TIMESTAMP":
-                       columnTypes = append(columnTypes, 
message.Rep_JAVA_SQL_TIMESTAMP)
-               default:
-                       columnTypes = append(columnTypes, col.Type.Rep)
+func newRows(conn *conn, statementID uint32, resultSets 
[]*message.ResultSetResponse) *rows {
+
+       rsets := []*resultSet{}
+
+       for _, result := range resultSets {
+               columns := []*column{}
+
+               for _, col := range result.Signature.Columns {
+
+                       column := &column{
+                               name:     col.ColumnName,
+                               typeName: col.Type.Name,
+                               nullable: col.Nullable != 0,
+                       }
+
+                       // Handle precision and length
+                       switch col.Type.Name {
+                       case "DECIMAL":
+
+                               precision := int64(col.Precision)
+
+                               if precision == 0 {
+                                       precision = math.MaxInt64
+                               }
+
+                               scale := int64(col.Scale)
+
+                               if scale == 0 {
+                                       scale = math.MaxInt64
+                               }
+
+                               column.precisionScale = &precisionScale{
+                                       precision: precision,
+                                       scale:     scale,
+                               }
+                       case "VARCHAR", "CHAR", "BINARY":
+                               column.length = int64(col.Precision)
+                       case "VARBINARY":
+                               column.length = math.MaxInt64
+                       }
+
+                       // Handle scan types
+                       switch col.Type.Name {
+                       case "INTEGER", "UNSIGNED_INT", "BIGINT", 
"UNSIGNED_LONG", "TINYINT", "UNSIGNED_TINYINT", "SMALLINT", "UNSIGNED_SMALLINT":
+                               column.scanType = reflect.TypeOf(int64(0))
+
+                       case "FLOAT", "UNSIGNED_FLOAT", "DOUBLE", 
"UNSIGNED_DOUBLE":
+                               column.scanType = reflect.TypeOf(float64(0))
+
+                       case "DECIMAL", "VARCHAR", "CHAR":
+                               column.scanType = reflect.TypeOf("")
+
+                       case "BOOLEAN":
+                               column.scanType = reflect.TypeOf(bool(false))
+
+                       case "TIME", "DATE", "TIMESTAMP", "UNSIGNED_TIME", 
"UNSIGNED_DATE", "UNSIGNED_TIMESTAMP":
+                               column.scanType = reflect.TypeOf(time.Time{})
+
+                       case "BINARY", "VARBINARY":
+                               column.scanType = reflect.TypeOf([]byte{})
+
+                       default:
+                               panic(fmt.Sprintf("scantype for %s is not 
implemented", col.Type.Name))
+                       }
+
+                       // Handle rep type special cases for decimals, floats, 
date, time and timestamp
+                       switch col.Type.Name {
+                       case "DECIMAL":
+                               column.rep = message.Rep_BIG_DECIMAL
+                       case "FLOAT":
+                               column.rep = message.Rep_FLOAT
+                       case "UNSIGNED_FLOAT":
+                               column.rep = message.Rep_FLOAT
+                       case "TIME", "UNSIGNED_TIME":
+                               column.rep = message.Rep_JAVA_SQL_TIME
+                       case "DATE", "UNSIGNED_DATE":
+                               column.rep = message.Rep_JAVA_SQL_DATE
+                       case "TIMESTAMP", "UNSIGNED_TIMESTAMP":
+                               column.rep = message.Rep_JAVA_SQL_TIMESTAMP
+                       default:
+                               column.rep = col.Type.Rep
+                       }
+
+                       columns = append(columns, column)
                }
-       }
 
-       frame := resultSet.FirstFrame
+               frame := result.FirstFrame
 
-       data := [][]*message.TypedValue{}
+               data := [][]*message.TypedValue{}
 
-       for _, row := range frame.Rows {
-               rowData := []*message.TypedValue{}
+               for _, row := range frame.Rows {
+                       rowData := []*message.TypedValue{}
 
-               for _, col := range row.Value {
-                       rowData = append(rowData, col.ScalarValue)
+                       for _, col := range row.Value {
+                               rowData = append(rowData, col.ScalarValue)
+                       }
+
+                       data = append(data, rowData)
                }
 
-               data = append(data, rowData)
+               rsets = append(rsets, &resultSet{
+                       columns: columns,
+                       done:    frame.Done,
+                       offset:  frame.Offset,
+                       data:    data,
+               })
        }
 
        return &rows{
-               conn:        conn,
-               statementID: statementID,
-               columnNames: columnNames,
-               columnTypes: columnTypes,
-               done:        frame.Done,
-               offset:      frame.Offset,
-               data:        data,
+               conn:             conn,
+               statementID:      statementID,
+               resultSets:       rsets,
+               currentResultSet: 0,
        }
 }
 

http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/ae26325d/rows_go18.go
----------------------------------------------------------------------
diff --git a/rows_go18.go b/rows_go18.go
new file mode 100644
index 0000000..9a783c2
--- /dev/null
+++ b/rows_go18.go
@@ -0,0 +1,59 @@
+// +build go1.8
+package avatica
+
+import (
+       "io"
+       "reflect"
+)
+
+func (r *rows) HasNextResultSet() bool {
+       lastResultSetID := len(r.resultSets) - 1
+       return lastResultSetID > r.currentResultSet
+}
+
+func (r *rows) NextResultSet() error {
+
+       lastResultSetID := len(r.resultSets) - 1
+
+       if r.currentResultSet+1 > lastResultSetID {
+               return io.EOF
+       }
+
+       r.currentResultSet++
+
+       return nil
+}
+
+func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
+
+       return r.resultSets[r.currentResultSet].columns[index].typeName
+}
+
+func (r *rows) ColumnTypeLength(index int) (length int64, ok bool) {
+       l := r.resultSets[r.currentResultSet].columns[index].length
+
+       if l == 0 {
+               return 0, false
+       }
+
+       return l, true
+}
+
+func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
+       return r.resultSets[r.currentResultSet].columns[index].nullable, true
+}
+
+func (r *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok 
bool) {
+
+       ps := r.resultSets[r.currentResultSet].columns[index].precisionScale
+
+       if ps != nil {
+               return ps.precision, ps.scale, true
+       }
+
+       return 0, 0, false
+}
+
+func (r *rows) ColumnTypeScanType(index int) reflect.Type {
+       return r.resultSets[r.currentResultSet].columns[index].scanType
+}

http://git-wip-us.apache.org/repos/asf/calcite-avatica-go/blob/ae26325d/statement.go
----------------------------------------------------------------------
diff --git a/statement.go b/statement.go
index 1861a93..3e6b64a 100644
--- a/statement.go
+++ b/statement.go
@@ -98,8 +98,7 @@ func (s *stmt) query(ctx context.Context, args []namedValue) 
(driver.Rows, error
                return nil, err
        }
 
-       // Currently there is only 1 ResultSet per response
-       resultSet := res.(*message.ExecuteResponse).Results[0]
+       resultSet := res.(*message.ExecuteResponse).Results
 
        return newRows(s.conn, s.statementID, resultSet), nil
 }
@@ -110,7 +109,6 @@ func (s *stmt) parametersToTypedValues(vals []namedValue) 
[]*message.TypedValue
 
        for i, val := range vals {
                typed := message.TypedValue{}
-
                if val.Value == nil {
                        typed.Null = true
                } else {
@@ -129,13 +127,19 @@ func (s *stmt) parametersToTypedValues(vals []namedValue) 
[]*message.TypedValue
                                typed.Type = message.Rep_BYTE_STRING
                                typed.BytesValue = v
                        case string:
-                               typed.Type = message.Rep_STRING
+
+                               if s.parameters[i].TypeName == "DECIMAL" {
+                                       typed.Type = message.Rep_BIG_DECIMAL
+                               } else {
+                                       typed.Type = message.Rep_STRING
+                               }
                                typed.StringValue = v
+
                        case time.Time:
                                avaticaParameter := s.parameters[i]
 
                                switch avaticaParameter.TypeName {
-                               case "TIME":
+                               case "TIME", "UNSIGNED_TIME":
                                        typed.Type = message.Rep_JAVA_SQL_TIME
 
                                        // Because a location can have multiple 
time zones due to daylight savings,
@@ -146,7 +150,7 @@ func (s *stmt) parametersToTypedValues(vals []namedValue) 
[]*message.TypedValue
                                        base := time.Date(v.Year(), v.Month(), 
v.Day(), 0, 0, 0, 0, time.FixedZone(zone, offset))
                                        typed.NumberValue = 
int64(v.Sub(base).Nanoseconds() / int64(time.Millisecond))
 
-                               case "DATE":
+                               case "DATE", "UNSIGNED_DATE":
                                        typed.Type = message.Rep_JAVA_SQL_DATE
 
                                        // Because a location can have multiple 
time zones due to daylight savings,
@@ -157,7 +161,7 @@ func (s *stmt) parametersToTypedValues(vals []namedValue) 
[]*message.TypedValue
                                        base := time.Date(1970, 1, 1, 0, 0, 0, 
0, time.FixedZone(zone, offset))
                                        typed.NumberValue = int64(v.Sub(base) / 
(24 * time.Hour))
 
-                               case "TIMESTAMP":
+                               case "TIMESTAMP", "UNSIGNED_TIMESTAMP":
                                        typed.Type = 
message.Rep_JAVA_SQL_TIMESTAMP
 
                                        // Because a location can have multiple 
time zones due to daylight savings,

Reply via email to