This is an automated email from the ASF dual-hosted git repository. joaoreis pushed a commit to branch trunk in repository https://gitbox.apache.org/repos/asf/cassandra-gocql-driver.git
The following commit(s) were added to refs/heads/trunk by this push: new a7f7f31d Fix tinyint silent unmarshal error and vector SliceMap bug a7f7f31d is described below commit a7f7f31d6f4b33d01c3aa2eb4d03a7d7ea3e89f5 Author: João Reis <joaor...@apache.org> AuthorDate: Thu Jul 3 13:53:32 2025 +0100 Fix tinyint silent unmarshal error and vector SliceMap bug Tinyint unmarshal silently errors and unmarshals tinyint values as 0 in some situations. This patch also removes other cases of silent errors in decode functions. Fix issue where reading vectors with SliceMap is not possible. Patch by João Reis; reviewed by James Hartig for CASSGO-82, CASSGO-83 --- CHANGELOG.md | 2 + integration_test.go | 705 ++++++++++++++++++++++++++++++++++++++++++++++++++++ marshal.go | 159 ++++++++---- marshal_test.go | 115 ++++++++- vector.go | 12 +- 5 files changed, 929 insertions(+), 64 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e550f83..5c047fd0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -64,6 +64,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix deadlock in refresh debouncer stop (CASSGO-41) - Endless query execution fix (CASSGO-50) - Accept peers with empty rack (CASSGO-6) +- Fix tinyint unmarshal regression (CASSGO-82) +- Vector columns can't be used with SliceMap() (CASSGO-83) ## [1.7.0] - 2024-09-23 diff --git a/integration_test.go b/integration_test.go index 18acaf4a..bd6ccb5c 100644 --- a/integration_test.go +++ b/integration_test.go @@ -30,9 +30,15 @@ package gocql // This file groups integration tests where Cassandra has to be set up with some special integration variables import ( "context" + "fmt" + "math/big" + "net" "reflect" + "strings" "testing" "time" + + inf "gopkg.in/inf.v0" ) // TestAuthentication verifies that gocql will work with a host configured to only accept authenticated connections @@ -272,3 +278,702 @@ func TestUDF(t *testing.T) { t.Fatal(err) } } + +// SliceMapTypesTestCase defines a test case for validating SliceMap/MapScan behavior +type SliceMapTypesTestCase struct { + CQLType string + CQLValue string // Non-NULL value to insert + ExpectedValue interface{} // Expected value for non-NULL case + ExpectedNullValue interface{} // Expected value for NULL +} + +// compareCollectionValues compares collection values (lists, sets, maps) with special handling +func compareCollectionValues(t *testing.T, cqlType string, expected, actual interface{}) bool { + switch { + case strings.HasPrefix(cqlType, "set<"): + // Sets are returned as slices, but order is not guaranteed + expectedSlice := reflect.ValueOf(expected) + actualSlice := reflect.ValueOf(actual) + if expectedSlice.Kind() != reflect.Slice || actualSlice.Kind() != reflect.Slice { + return false + } + if expectedSlice.Len() != actualSlice.Len() { + return false + } + + // Convert to maps for unordered comparison + expectedSet := make(map[interface{}]bool) + for i := 0; i < expectedSlice.Len(); i++ { + expectedSet[expectedSlice.Index(i).Interface()] = true + } + + actualSet := make(map[interface{}]bool) + for i := 0; i < actualSlice.Len(); i++ { + actualSet[actualSlice.Index(i).Interface()] = true + } + + return reflect.DeepEqual(expectedSet, actualSet) + + default: + // For lists, maps, and other collections, reflect.DeepEqual works fine + return reflect.DeepEqual(expected, actual) + } +} + +// compareValues compares expected and actual values with type-specific logic +func compareValues(t *testing.T, cqlType string, expected, actual interface{}) bool { + switch cqlType { + case "varint": + // big.Int needs Cmp() for proper comparison, but handle nil pointers safely + if expectedBig, ok := expected.(*big.Int); ok { + if actualBig, ok := actual.(*big.Int); ok { + // Handle nil cases + if expectedBig == nil && actualBig == nil { + return true + } + if expectedBig == nil || actualBig == nil { + return false + } + return expectedBig.Cmp(actualBig) == 0 + } + } + return reflect.DeepEqual(expected, actual) + + case "decimal": + // inf.Dec needs Cmp() for proper comparison, but handle nil pointers safely + if expectedDec, ok := expected.(*inf.Dec); ok { + if actualDec, ok := actual.(*inf.Dec); ok { + // Handle nil cases + if expectedDec == nil && actualDec == nil { + return true + } + if expectedDec == nil || actualDec == nil { + return false + } + return expectedDec.Cmp(actualDec) == 0 + } + } + return reflect.DeepEqual(expected, actual) + + default: + // reflect.DeepEqual handles nil vs empty slice/map distinction correctly for all types + // including inet (net.IP), blob ([]byte), collections ([]T, map[K]V), etc. + // This is critical for catching zero value behavior changes in the driver + return reflect.DeepEqual(expected, actual) + } +} + +// TestSliceMapMapScanTypes tests SliceMap and MapScan with various CQL types +func TestSliceMapMapScanTypes(t *testing.T) { + session := createSession(t) + defer session.Close() + + // Create test table + tableCQL := ` + CREATE TABLE IF NOT EXISTS gocql_test.slicemap_test ( + id int PRIMARY KEY, + tinyint_col tinyint, + smallint_col smallint, + int_col int, + bigint_col bigint, + float_col float, + double_col double, + boolean_col boolean, + text_col text, + ascii_col ascii, + varchar_col varchar, + timestamp_col timestamp, + uuid_col uuid, + timeuuid_col timeuuid, + inet_col inet, + blob_col blob, + varint_col varint, + decimal_col decimal, + date_col date, + time_col time, + duration_col duration + )` + + if err := createTable(session, tableCQL); err != nil { + t.Fatal("Failed to create test table:", err) + } + + // Clear existing data + if err := session.Query("TRUNCATE gocql_test.slicemap_test").Exec(); err != nil { + t.Fatal("Failed to truncate test table:", err) + } + + testCases := []SliceMapTypesTestCase{ + {"tinyint", "42", int8(42), int8(0)}, + {"smallint", "1234", int16(1234), int16(0)}, + {"int", "123456", int(123456), int(0)}, + {"bigint", "1234567890", int64(1234567890), int64(0)}, + {"float", "3.14", float32(3.14), float32(0)}, + {"double", "2.718281828", float64(2.718281828), float64(0)}, + {"boolean", "true", true, false}, + {"text", "'hello world'", "hello world", ""}, + {"ascii", "'hello ascii'", "hello ascii", ""}, + {"varchar", "'hello varchar'", "hello varchar", ""}, + {"timestamp", "1388534400000", time.Unix(1388534400, 0).UTC(), time.Time{}}, + {"uuid", "550e8400-e29b-41d4-a716-446655440000", mustParseUUID("550e8400-e29b-41d4-a716-446655440000"), UUID{}}, + {"timeuuid", "60d79c23-5793-11f0-8afe-bcfce78b517a", mustParseUUID("60d79c23-5793-11f0-8afe-bcfce78b517a"), UUID{}}, + {"inet", "'127.0.0.1'", net.ParseIP("127.0.0.1").To4(), net.IP(nil)}, + {"blob", "0x48656c6c6f", []byte("Hello"), []byte(nil)}, + {"varint", "123456789012345678901234567890", mustParseBigInt("123456789012345678901234567890"), (*big.Int)(nil)}, + {"decimal", "123.45", mustParseDecimal("123.45"), (*inf.Dec)(nil)}, + {"date", "'2015-05-03'", time.Date(2015, 5, 3, 0, 0, 0, 0, time.UTC), time.Time{}}, + {"time", "'13:30:54.234'", 13*time.Hour + 30*time.Minute + 54*time.Second + 234*time.Millisecond, time.Duration(0)}, + {"duration", "1y2mo3d4h5m6s789ms", mustCreateDuration(14, 3, 4*time.Hour+5*time.Minute+6*time.Second+789*time.Millisecond), Duration{}}, + } + + for i, tc := range testCases { + t.Run(tc.CQLType, func(t *testing.T) { + testSliceMapMapScanSimple(t, session, tc, i) + }) + } +} + +// Simplified test function that tests both SliceMap and MapScan with both NULL and non-NULL values +func testSliceMapMapScanSimple(t *testing.T, session *Session, tc SliceMapTypesTestCase, id int) { + colName := tc.CQLType + "_col" + + // Test non-NULL value + t.Run("NonNull", func(t *testing.T) { + // Insert non-NULL value + insertQuery := fmt.Sprintf("INSERT INTO gocql_test.slicemap_test (id, %s) VALUES (?, %s)", colName, tc.CQLValue) + if err := session.Query(insertQuery, id*2).Exec(); err != nil { + t.Fatalf("Failed to insert non-NULL value: %v", err) + } + + // Test both SliceMap and MapScan + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + result := queryAndExtractValue(t, session, colName, id*2, method) + validateResult(t, tc.CQLType, tc.ExpectedValue, result, method, "non-NULL") + }) + } + }) + + // Test NULL value + t.Run("Null", func(t *testing.T) { + // Insert NULL value + insertQuery := fmt.Sprintf("INSERT INTO gocql_test.slicemap_test (id, %s) VALUES (?, NULL)", colName) + if err := session.Query(insertQuery, id*2+1).Exec(); err != nil { + t.Fatalf("Failed to insert NULL value: %v", err) + } + + // Test both SliceMap and MapScan + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + result := queryAndExtractValue(t, session, colName, id*2+1, method) + validateResult(t, tc.CQLType, tc.ExpectedNullValue, result, method, "NULL") + }) + } + }) +} + +// Helper function to query and extract value using either SliceMap or MapScan +func queryAndExtractValue(t *testing.T, session *Session, colName string, id int, method string) interface{} { + selectQuery := fmt.Sprintf("SELECT %s FROM gocql_test.slicemap_test WHERE id = ?", colName) + + switch method { + case "SliceMap": + iter := session.Query(selectQuery, id).Iter() + sliceResults, err := iter.SliceMap() + iter.Close() + if err != nil { + t.Fatalf("SliceMap failed: %v", err) + } + if len(sliceResults) != 1 { + t.Fatalf("Expected 1 result, got %d", len(sliceResults)) + } + return sliceResults[0][colName] + + case "MapScan": + mapResult := make(map[string]interface{}) + if err := session.Query(selectQuery, id).MapScan(mapResult); err != nil { + t.Fatalf("MapScan failed: %v", err) + } + return mapResult[colName] + + default: + t.Fatalf("Unknown method: %s", method) + return nil + } +} + +// Helper function to validate results +func validateResult(t *testing.T, cqlType string, expected, actual interface{}, method, valueType string) { + // Check type + if expected != nil && actual != nil { + expectedType := reflect.TypeOf(expected) + actualType := reflect.TypeOf(actual) + if expectedType != actualType { + t.Errorf("%s %s %s: expected type %v, got %v", method, valueType, cqlType, expectedType, actualType) + } + } + + // Check value + if !compareValues(t, cqlType, expected, actual) { + t.Errorf("%s %s %s: expected value %v (type %T), got %v (type %T)", + method, valueType, cqlType, expected, expected, actual, actual) + } +} + +// Helper function to parse UUID (for test data) +func mustParseUUID(s string) UUID { + u, err := ParseUUID(s) + if err != nil { + panic(err) + } + return u +} + +// Helper function to parse big.Int (for test data) +func mustParseBigInt(s string) *big.Int { + i := new(big.Int) + if _, ok := i.SetString(s, 10); !ok { + panic("failed to parse big.Int: " + s) + } + return i +} + +// Helper function to parse inf.Dec (for test data) +func mustParseDecimal(s string) *inf.Dec { + dec := new(inf.Dec) + if _, ok := dec.SetString(s); !ok { + panic("failed to parse inf.Dec: " + s) + } + return dec +} + +// Helper function to create Duration (for test data) +func mustCreateDuration(months int32, days int32, timeDuration time.Duration) Duration { + return Duration{ + Months: months, + Days: days, + Nanoseconds: timeDuration.Nanoseconds(), + } +} + +// TestSliceMapMapScanCounterTypes tests counter types separately since they have special restrictions +// (counter columns can't be mixed with other column types in the same table) +func TestSliceMapMapScanCounterTypes(t *testing.T) { + session := createSession(t) + defer session.Close() + + // Create separate table for counter types + if err := createTable(session, ` + CREATE TABLE IF NOT EXISTS gocql_test.slicemap_counter_test ( + id int PRIMARY KEY, + counter_col counter + ) + `); err != nil { + t.Fatal("Failed to create counter test table:", err) + } + + // Clear existing data + if err := session.Query("TRUNCATE gocql_test.slicemap_counter_test").Exec(); err != nil { + t.Fatal("Failed to truncate counter test table:", err) + } + + testID := 1 + expectedValue := int64(42) + + // Increment counter (can't INSERT into counter, must UPDATE) + err := session.Query("UPDATE gocql_test.slicemap_counter_test SET counter_col = counter_col + 42 WHERE id = ?", testID).Exec() + if err != nil { + t.Fatalf("Failed to increment counter: %v", err) + } + + // Test both SliceMap and MapScan + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + var result interface{} + + selectQuery := "SELECT counter_col FROM gocql_test.slicemap_counter_test WHERE id = ?" + if method == "SliceMap" { + iter := session.Query(selectQuery, testID).Iter() + sliceResults, err := iter.SliceMap() + iter.Close() + if err != nil { + t.Fatalf("SliceMap failed: %v", err) + } + if len(sliceResults) != 1 { + t.Fatalf("Expected 1 result, got %d", len(sliceResults)) + } + result = sliceResults[0]["counter_col"] + } else { + mapResult := make(map[string]interface{}) + if err := session.Query(selectQuery, testID).MapScan(mapResult); err != nil { + t.Fatalf("MapScan failed: %v", err) + } + result = mapResult["counter_col"] + } + + validateResult(t, "counter", expectedValue, result, method, "incremented") + }) + } +} + +// TestSliceMapMapScanTupleTypes tests tuple types separately since they have special handling +// (tuple elements get split into individual columns) +func TestSliceMapMapScanTupleTypes(t *testing.T) { + session := createSession(t) + defer session.Close() + + // Create test table with tuple column + if err := createTable(session, ` + CREATE TABLE IF NOT EXISTS gocql_test.slicemap_tuple_test ( + id int PRIMARY KEY, + tuple_col tuple<int, text> + ) + `); err != nil { + t.Fatal("Failed to create tuple test table:", err) + } + + // Clear existing data + if err := session.Query("TRUNCATE gocql_test.slicemap_tuple_test").Exec(); err != nil { + t.Fatal("Failed to truncate tuple test table:", err) + } + + // Test non-NULL tuple + t.Run("NonNull", func(t *testing.T) { + testID := 1 + // Insert tuple value + err := session.Query("INSERT INTO gocql_test.slicemap_tuple_test (id, tuple_col) VALUES (?, (42, 'hello'))", testID).Exec() + if err != nil { + t.Fatalf("Failed to insert tuple value: %v", err) + } + + // Test both SliceMap and MapScan + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + var result map[string]interface{} + + selectQuery := "SELECT tuple_col FROM gocql_test.slicemap_tuple_test WHERE id = ?" + if method == "SliceMap" { + iter := session.Query(selectQuery, testID).Iter() + sliceResults, err := iter.SliceMap() + iter.Close() + if err != nil { + t.Fatalf("SliceMap failed: %v", err) + } + if len(sliceResults) != 1 { + t.Fatalf("Expected 1 result, got %d", len(sliceResults)) + } + result = sliceResults[0] + } else { + result = make(map[string]interface{}) + if err := session.Query(selectQuery, testID).MapScan(result); err != nil { + t.Fatalf("MapScan failed: %v", err) + } + } + + // Check tuple elements (tuples get split into individual columns) + elem0Key := TupleColumnName("tuple_col", 0) + elem1Key := TupleColumnName("tuple_col", 1) + + if result[elem0Key] != 42 { + t.Errorf("%s tuple[0]: expected 42, got %v", method, result[elem0Key]) + } + if result[elem1Key] != "hello" { + t.Errorf("%s tuple[1]: expected 'hello', got %v", method, result[elem1Key]) + } + }) + } + }) + + // Test NULL tuple + t.Run("Null", func(t *testing.T) { + testID := 2 + // Insert NULL tuple + err := session.Query("INSERT INTO gocql_test.slicemap_tuple_test (id, tuple_col) VALUES (?, NULL)", testID).Exec() + if err != nil { + t.Fatalf("Failed to insert NULL tuple: %v", err) + } + + // Test both SliceMap and MapScan + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + var result map[string]interface{} + + selectQuery := "SELECT tuple_col FROM gocql_test.slicemap_tuple_test WHERE id = ?" + if method == "SliceMap" { + iter := session.Query(selectQuery, testID).Iter() + sliceResults, err := iter.SliceMap() + iter.Close() + if err != nil { + t.Fatalf("SliceMap failed: %v", err) + } + if len(sliceResults) != 1 { + t.Fatalf("Expected 1 result, got %d", len(sliceResults)) + } + result = sliceResults[0] + } else { + result = make(map[string]interface{}) + if err := session.Query(selectQuery, testID).MapScan(result); err != nil { + t.Fatalf("MapScan failed: %v", err) + } + } + + // Check tuple elements (NULL tuple gives zero values) + elem0Key := TupleColumnName("tuple_col", 0) + elem1Key := TupleColumnName("tuple_col", 1) + + if result[elem0Key] != 0 { + t.Errorf("%s NULL tuple[0]: expected 0, got %v", method, result[elem0Key]) + } + if result[elem1Key] != "" { + t.Errorf("%s NULL tuple[1]: expected '', got %v", method, result[elem1Key]) + } + }) + } + }) +} + +// TestSliceMapMapScanVectorTypes tests vector types separately since they need Cassandra 5.0+ and special table setup +// (vectors need separate tables and version checks) +func TestSliceMapMapScanVectorTypes(t *testing.T) { + session := createSession(t) + defer session.Close() + + // Vector types require Cassandra 5.0+ + if session.control.getConn().host.Version().Before(5, 0, 0) { + t.Skip("Vector types require Cassandra 5.0+") + } + + // Create test table with vector columns + if err := createTable(session, ` + CREATE TABLE IF NOT EXISTS gocql_test.slicemap_vector_test ( + id int PRIMARY KEY, + vector_float_col vector<float, 3>, + vector_text_col vector<text, 2> + ) + `); err != nil { + t.Fatal("Failed to create vector test table:", err) + } + + // Clear existing data + if err := session.Query("TRUNCATE gocql_test.slicemap_vector_test").Exec(); err != nil { + t.Fatal("Failed to truncate vector test table:", err) + } + + testCases := []struct { + colName string + cqlValue string + expectedValue interface{} + expectedNull interface{} + }{ + {"vector_float_col", "[1.0, 2.5, -3.0]", []float32{1.0, 2.5, -3.0}, []float32(nil)}, + {"vector_text_col", "['hello', 'world']", []string{"hello", "world"}, []string(nil)}, + } + + for _, tc := range testCases { + t.Run(tc.colName, func(t *testing.T) { + // Test non-NULL value + t.Run("NonNull", func(t *testing.T) { + testID := 1 + // Insert non-NULL value + insertQuery := fmt.Sprintf("INSERT INTO gocql_test.slicemap_vector_test (id, %s) VALUES (?, %s)", tc.colName, tc.cqlValue) + if err := session.Query(insertQuery, testID).Exec(); err != nil { + t.Fatalf("Failed to insert non-NULL value: %v", err) + } + + // Test both SliceMap and MapScan + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + var result interface{} + + selectQuery := fmt.Sprintf("SELECT %s FROM gocql_test.slicemap_vector_test WHERE id = ?", tc.colName) + if method == "SliceMap" { + iter := session.Query(selectQuery, testID).Iter() + sliceResults, err := iter.SliceMap() + iter.Close() + if err != nil { + t.Fatalf("SliceMap failed: %v", err) + } + if len(sliceResults) != 1 { + t.Fatalf("Expected 1 result, got %d", len(sliceResults)) + } + result = sliceResults[0][tc.colName] + } else { + mapResult := make(map[string]interface{}) + if err := session.Query(selectQuery, testID).MapScan(mapResult); err != nil { + t.Fatalf("MapScan failed: %v", err) + } + result = mapResult[tc.colName] + } + + validateResult(t, tc.colName, tc.expectedValue, result, method, "non-NULL") + }) + } + }) + + // Test NULL value + t.Run("Null", func(t *testing.T) { + testID := 2 + // Insert NULL value + insertQuery := fmt.Sprintf("INSERT INTO gocql_test.slicemap_vector_test (id, %s) VALUES (?, NULL)", tc.colName) + if err := session.Query(insertQuery, testID).Exec(); err != nil { + t.Fatalf("Failed to insert NULL value: %v", err) + } + + // Test both SliceMap and MapScan + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + var result interface{} + + selectQuery := fmt.Sprintf("SELECT %s FROM gocql_test.slicemap_vector_test WHERE id = ?", tc.colName) + if method == "SliceMap" { + iter := session.Query(selectQuery, testID).Iter() + sliceResults, err := iter.SliceMap() + iter.Close() + if err != nil { + t.Fatalf("SliceMap failed: %v", err) + } + if len(sliceResults) != 1 { + t.Fatalf("Expected 1 result, got %d", len(sliceResults)) + } + result = sliceResults[0][tc.colName] + } else { + mapResult := make(map[string]interface{}) + if err := session.Query(selectQuery, testID).MapScan(mapResult); err != nil { + t.Fatalf("MapScan failed: %v", err) + } + result = mapResult[tc.colName] + } + + // Vectors should return nil slices for NULL values for consistency + validateResult(t, tc.colName, tc.expectedNull, result, method, "NULL") + }) + } + }) + }) + } +} + +// TestSliceMapMapScanCollectionTypes tests collection types separately since they have special handling +// (collections should return nil slices/maps for NULL values for consistency with other slice-based types) +func TestSliceMapMapScanCollectionTypes(t *testing.T) { + session := createSession(t) + defer session.Close() + + // Create test table with collection columns + if err := createTable(session, ` + CREATE TABLE IF NOT EXISTS gocql_test.slicemap_collection_test ( + id int PRIMARY KEY, + list_col list<text>, + set_col set<int>, + map_col map<text, int> + ) + `); err != nil { + t.Fatal("Failed to create collection test table:", err) + } + + // Clear existing data + if err := session.Query("TRUNCATE gocql_test.slicemap_collection_test").Exec(); err != nil { + t.Fatal("Failed to truncate collection test table:", err) + } + + testCases := []struct { + colName string + cqlValue string + expectedValue interface{} + expectedNull interface{} + }{ + {"list_col", "['a', 'b', 'c']", []string{"a", "b", "c"}, []string(nil)}, + {"set_col", "{1, 2, 3}", []int{1, 2, 3}, []int(nil)}, + {"map_col", "{'key1': 1, 'key2': 2}", map[string]int{"key1": 1, "key2": 2}, map[string]int(nil)}, + } + + for _, tc := range testCases { + t.Run(tc.colName, func(t *testing.T) { + // Test non-NULL value + t.Run("NonNull", func(t *testing.T) { + testID := 1 + // Insert non-NULL value + insertQuery := fmt.Sprintf("INSERT INTO gocql_test.slicemap_collection_test (id, %s) VALUES (?, %s)", tc.colName, tc.cqlValue) + if err := session.Query(insertQuery, testID).Exec(); err != nil { + t.Fatalf("Failed to insert non-NULL value: %v", err) + } + + // Test both SliceMap and MapScan + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + var result interface{} + + selectQuery := fmt.Sprintf("SELECT %s FROM gocql_test.slicemap_collection_test WHERE id = ?", tc.colName) + if method == "SliceMap" { + iter := session.Query(selectQuery, testID).Iter() + sliceResults, err := iter.SliceMap() + iter.Close() + if err != nil { + t.Fatalf("SliceMap failed: %v", err) + } + if len(sliceResults) != 1 { + t.Fatalf("Expected 1 result, got %d", len(sliceResults)) + } + result = sliceResults[0][tc.colName] + } else { + mapResult := make(map[string]interface{}) + if err := session.Query(selectQuery, testID).MapScan(mapResult); err != nil { + t.Fatalf("MapScan failed: %v", err) + } + result = mapResult[tc.colName] + } + + // For sets, we need special comparison since order is not guaranteed + if strings.HasPrefix(tc.colName, "set_") { + if !compareCollectionValues(t, tc.colName, tc.expectedValue, result) { + t.Errorf("%s non-NULL %s: expected %v, got %v", method, tc.colName, tc.expectedValue, result) + } + } else { + validateResult(t, tc.colName, tc.expectedValue, result, method, "non-NULL") + } + }) + } + }) + + // Test NULL value + t.Run("Null", func(t *testing.T) { + testID := 2 + // Insert NULL value + insertQuery := fmt.Sprintf("INSERT INTO gocql_test.slicemap_collection_test (id, %s) VALUES (?, NULL)", tc.colName) + if err := session.Query(insertQuery, testID).Exec(); err != nil { + t.Fatalf("Failed to insert NULL value: %v", err) + } + + // Test both SliceMap and MapScan + for _, method := range []string{"SliceMap", "MapScan"} { + t.Run(method, func(t *testing.T) { + var result interface{} + + selectQuery := fmt.Sprintf("SELECT %s FROM gocql_test.slicemap_collection_test WHERE id = ?", tc.colName) + if method == "SliceMap" { + iter := session.Query(selectQuery, testID).Iter() + sliceResults, err := iter.SliceMap() + iter.Close() + if err != nil { + t.Fatalf("SliceMap failed: %v", err) + } + if len(sliceResults) != 1 { + t.Fatalf("Expected 1 result, got %d", len(sliceResults)) + } + result = sliceResults[0][tc.colName] + } else { + mapResult := make(map[string]interface{}) + if err := session.Query(selectQuery, testID).MapScan(mapResult); err != nil { + t.Fatalf("MapScan failed: %v", err) + } + result = mapResult[tc.colName] + } + + // Collections should return nil slices/maps for NULL values for consistency + validateResult(t, tc.colName, tc.expectedNull, result, method, "NULL") + }) + } + }) + }) + } +} diff --git a/marshal.go b/marshal.go index 5133a96d..5562d9fd 100644 --- a/marshal.go +++ b/marshal.go @@ -432,15 +432,19 @@ func (smallIntTypeInfo) Marshal(value interface{}) ([]byte, error) { // Unmarshal unmarshals the byte slice into the value. func (s smallIntTypeInfo) Unmarshal(data []byte, value interface{}) error { + decodedData, err := decShort(data) + if err != nil { + return unmarshalErrorf("%s", err.Error()) + } if iptr, ok := value.(*interface{}); ok && iptr != nil { var v int16 - if err := unmarshalIntlike(TypeSmallInt, int64(decShort(data)), data, &v); err != nil { + if err := unmarshalIntlike(TypeSmallInt, int64(decodedData), data, &v); err != nil { return err } *iptr = v return nil } - return unmarshalIntlike(TypeSmallInt, int64(decShort(data)), data, value) + return unmarshalIntlike(TypeSmallInt, int64(decodedData), data, value) } type tinyIntTypeInfo struct{} @@ -540,15 +544,19 @@ func (tinyIntTypeInfo) Marshal(value interface{}) ([]byte, error) { // Unmarshal unmarshals the byte slice into the value. func (t tinyIntTypeInfo) Unmarshal(data []byte, value interface{}) error { + decodedData, err := decTiny(data) + if err != nil { + return unmarshalErrorf("%s", err.Error()) + } if iptr, ok := value.(*interface{}); ok && iptr != nil { var v int8 - if err := unmarshalIntlike(TypeSmallInt, int64(decShort(data)), data, &v); err != nil { + if err := unmarshalIntlike(TypeTinyInt, int64(decodedData), data, &v); err != nil { return err } *iptr = v return nil } - return unmarshalIntlike(TypeTinyInt, int64(decTiny(data)), data, value) + return unmarshalIntlike(TypeTinyInt, int64(decodedData), data, value) } type intTypeInfo struct{} @@ -636,26 +644,34 @@ func (intTypeInfo) Marshal(value interface{}) ([]byte, error) { // Unmarshal unmarshals the byte slice into the value. func (i intTypeInfo) Unmarshal(data []byte, value interface{}) error { + decodedData, err := decInt(data) + if err != nil { + return unmarshalErrorf("%s", err.Error()) + } if iptr, ok := value.(*interface{}); ok && iptr != nil { var v int - if err := unmarshalIntlike(TypeInt, int64(decInt(data)), data, &v); err != nil { + if err := unmarshalIntlike(TypeInt, int64(decodedData), data, &v); err != nil { return err } *iptr = v return nil } - return unmarshalIntlike(TypeInt, int64(decInt(data)), data, value) + return unmarshalIntlike(TypeInt, int64(decodedData), data, value) } func encInt(x int32) []byte { return []byte{byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)} } -func decInt(x []byte) int32 { +func decInt(x []byte) (int32, error) { + if x == nil || len(x) == 0 { + // len(x)==0 is to keep old behavior from 1.x (empty values can be in the DB and are different from NULL) + return 0, nil + } if len(x) != 4 { - return 0 + return 0, fmt.Errorf("expected 4 bytes decoding int but got %v", len(x)) } - return int32(x[0])<<24 | int32(x[1])<<16 | int32(x[2])<<8 | int32(x[3]) + return int32(x[0])<<24 | int32(x[1])<<16 | int32(x[2])<<8 | int32(x[3]), nil } func encShort(x int16) []byte { @@ -665,18 +681,26 @@ func encShort(x int16) []byte { return p } -func decShort(p []byte) int16 { +func decShort(p []byte) (int16, error) { + if p == nil || len(p) == 0 { + // len(p)==0 is to keep old behavior from 1.x (empty values can be in the DB and are different from NULL) + return 0, nil + } if len(p) != 2 { - return 0 + return 0, fmt.Errorf("expected 2 bytes decoding short but got %v", len(p)) } - return int16(p[0])<<8 | int16(p[1]) + return int16(p[0])<<8 | int16(p[1]), nil } -func decTiny(p []byte) int8 { +func decTiny(p []byte) (int8, error) { + if p == nil || len(p) == 0 { + // len(p)==0 is to keep old behavior from 1.x (empty values can be in the DB and are different from NULL) + return 0, nil + } if len(p) != 1 { - return 0 + return 0, fmt.Errorf("expected 1 byte decoding tinyint but got %v", len(p)) } - return int8(p[0]) + return int8(p[0]), nil } type bigIntLikeTypeInfo struct { @@ -774,15 +798,19 @@ func bytesToUint64(data []byte) (ret uint64) { // Unmarshal unmarshals the byte slice into the value. func (b bigIntLikeTypeInfo) Unmarshal(data []byte, value interface{}) error { + decodedData, err := decBigInt(data) + if err != nil { + return unmarshalErrorf("can not unmarshal bigint: %s", err.Error()) + } if iptr, ok := value.(*interface{}); ok && iptr != nil { var v int64 - if err := unmarshalIntlike(b.typ, decBigInt(data), data, &v); err != nil { + if err := unmarshalIntlike(b.typ, decodedData, data, &v); err != nil { return err } *iptr = v return nil } - return unmarshalIntlike(b.typ, decBigInt(data), data, value) + return unmarshalIntlike(b.typ, decodedData, data, value) } type varintTypeInfo struct{} @@ -1083,14 +1111,18 @@ func unmarshalIntlike(typ Type, int64Val int64, data []byte, value interface{}) return unmarshalErrorf("can not unmarshal int-like into %T. Accepted types: big.Int, int8, uint8, int16, uint16, int, uint, int32, uint32, int64, uint64, string, *interface{}.", value) } -func decBigInt(data []byte) int64 { +func decBigInt(data []byte) (int64, error) { + if data == nil || len(data) == 0 { + // len(data)==0 is to keep old behavior from 1.x (empty values can be in the DB and are different from NULL) + return 0, nil + } if len(data) != 8 { - return 0 + return 0, fmt.Errorf("expected 8 bytes, got %d", len(data)) } return int64(data[0])<<56 | int64(data[1])<<48 | int64(data[2])<<40 | int64(data[3])<<32 | int64(data[4])<<24 | int64(data[5])<<16 | - int64(data[6])<<8 | int64(data[7]) + int64(data[6])<<8 | int64(data[7]), nil } type booleanTypeInfo struct{} @@ -1130,12 +1162,16 @@ func (b booleanTypeInfo) Marshal(value interface{}) ([]byte, error) { // Unmarshal unmarshals the byte slice into the value. func (b booleanTypeInfo) Unmarshal(data []byte, value interface{}) error { + decodedData, err := decBool(data) + if err != nil { + return unmarshalErrorf("can not unmarshal boolean: %s", err.Error()) + } switch v := value.(type) { case *bool: - *v = decBool(data) + *v = decodedData return nil case *interface{}: - *v = decBool(data) + *v = decodedData return nil } rv := reflect.ValueOf(value) @@ -1145,7 +1181,7 @@ func (b booleanTypeInfo) Unmarshal(data []byte, value interface{}) error { rv = rv.Elem() switch rv.Type().Kind() { case reflect.Bool: - rv.SetBool(decBool(data)) + rv.SetBool(decodedData) return nil } return unmarshalErrorf("can not unmarshal boolean into %T. Accepted types: *bool, *interface{}.", value) @@ -1158,11 +1194,15 @@ func encBool(v bool) []byte { return []byte{0} } -func decBool(v []byte) bool { - if len(v) == 0 { - return false +func decBool(v []byte) (bool, error) { + if v == nil || len(v) == 0 { + // len(v)==0 is to keep old behavior from 1.x (empty values can be in the DB and are different from NULL) + return false, nil + } + if len(v) != 1 { + return false, fmt.Errorf("expected 1 byte, got %d", len(v)) } - return v[0] != 0 + return v[0] != 0, nil } type floatTypeInfo struct{} @@ -1200,12 +1240,16 @@ func (floatTypeInfo) Marshal(value interface{}) ([]byte, error) { // Unmarshal unmarshals the byte slice into the value. func (floatTypeInfo) Unmarshal(data []byte, value interface{}) error { + decodedData, err := decInt(data) + if err != nil { + return err + } switch v := value.(type) { case *float32: - *v = math.Float32frombits(uint32(decInt(data))) + *v = math.Float32frombits(uint32(decodedData)) return nil case *interface{}: - *v = math.Float32frombits(uint32(decInt(data))) + *v = math.Float32frombits(uint32(decodedData)) return nil } rv := reflect.ValueOf(value) @@ -1215,7 +1259,7 @@ func (floatTypeInfo) Unmarshal(data []byte, value interface{}) error { rv = rv.Elem() switch rv.Type().Kind() { case reflect.Float32: - rv.SetFloat(float64(math.Float32frombits(uint32(decInt(data))))) + rv.SetFloat(float64(math.Float32frombits(uint32(decodedData)))) return nil } return unmarshalErrorf("can not unmarshal float into %T. Accepted types: *float32, *interface{}, UnsetValue.", value) @@ -1254,12 +1298,17 @@ func (doubleTypeInfo) Marshal(value interface{}) ([]byte, error) { // Unmarshal unmarshals the byte slice into the value. func (doubleTypeInfo) Unmarshal(data []byte, value interface{}) error { + decodedData, err := decBigInt(data) + if err != nil { + return unmarshalErrorf("can not unmarshal double: %s", err.Error()) + } + decodedUint64 := uint64(decodedData) switch v := value.(type) { case *float64: - *v = math.Float64frombits(uint64(decBigInt(data))) + *v = math.Float64frombits(decodedUint64) return nil case *interface{}: - *v = math.Float64frombits(uint64(decBigInt(data))) + *v = math.Float64frombits(decodedUint64) return nil } rv := reflect.ValueOf(value) @@ -1269,7 +1318,7 @@ func (doubleTypeInfo) Unmarshal(data []byte, value interface{}) error { rv = rv.Elem() switch rv.Type().Kind() { case reflect.Float64: - rv.SetFloat(math.Float64frombits(uint64(decBigInt(data)))) + rv.SetFloat(math.Float64frombits(decodedUint64)) return nil } return unmarshalErrorf("can not unmarshal double into %T. Accepted types: *float64, *interface{}.", value) @@ -1312,20 +1361,22 @@ func (decimalTypeInfo) Marshal(value interface{}) ([]byte, error) { // Unmarshal unmarshals the byte slice into the value. func (decimalTypeInfo) Unmarshal(data []byte, value interface{}) error { + if len(data) < 4 { + return unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data)) + } + + decodedData, err := decInt(data[0:4]) + if err != nil { + return err + } switch v := value.(type) { case *inf.Dec: - if len(data) < 4 { - return unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data)) - } - scale := decInt(data[0:4]) + scale := decodedData unscaled := decBigInt2C(data[4:], nil) *v = *inf.NewDecBig(unscaled, inf.Scale(scale)) return nil case *interface{}: - if len(data) < 4 { - return unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data)) - } - scale := decInt(data[0:4]) + scale := decodedData unscaled := decBigInt2C(data[4:], nil) *v = inf.NewDecBig(unscaled, inf.Scale(scale)) return nil @@ -1414,16 +1465,20 @@ func (timestampTypeInfo) Marshal(value interface{}) ([]byte, error) { // Unmarshal unmarshals the byte slice into the value. func (timestampTypeInfo) Unmarshal(data []byte, value interface{}) error { + decodedData, err := decBigInt(data) + if err != nil { + return unmarshalErrorf("can not unmarshal timestamp: %s", err.Error()) + } switch v := value.(type) { case *int64: - *v = decBigInt(data) + *v = decodedData return nil case *time.Time: if len(data) == 0 { *v = time.Time{} return nil } - x := decBigInt(data) + x := decodedData sec := x / 1000 nsec := (x - sec*1000) * 1000000 *v = time.Unix(sec, nsec).In(time.UTC) @@ -1433,7 +1488,7 @@ func (timestampTypeInfo) Unmarshal(data []byte, value interface{}) error { *v = time.Time{} return nil } - x := decBigInt(data) + x := decodedData sec := x / 1000 nsec := (x - sec*1000) * 1000000 *v = time.Unix(sec, nsec).In(time.UTC) @@ -1447,7 +1502,7 @@ func (timestampTypeInfo) Unmarshal(data []byte, value interface{}) error { rv = rv.Elem() switch rv.Type().Kind() { case reflect.Int64: - rv.SetInt(decBigInt(data)) + rv.SetInt(decodedData) return nil } return unmarshalErrorf("can not unmarshal timestamp into %T. Accepted types: *int64, *time.Time, *interface{}.", value) @@ -1490,15 +1545,19 @@ func (timeTypeInfo) Marshal(value interface{}) ([]byte, error) { // Unmarshal unmarshals the byte slice into the value. func (timeTypeInfo) Unmarshal(data []byte, value interface{}) error { + decodedData, err := decBigInt(data) + if err != nil { + return unmarshalErrorf("can not unmarshal time: %s", err.Error()) + } switch v := value.(type) { case *int64: - *v = decBigInt(data) + *v = decodedData return nil case *time.Duration: - *v = time.Duration(decBigInt(data)) + *v = time.Duration(decodedData) return nil case *interface{}: - *v = time.Duration(decBigInt(data)) + *v = time.Duration(decodedData) return nil } @@ -1509,7 +1568,7 @@ func (timeTypeInfo) Unmarshal(data []byte, value interface{}) error { rv = rv.Elem() switch rv.Type().Kind() { case reflect.Int64: - rv.SetInt(decBigInt(data)) + rv.SetInt(decodedData) return nil } return unmarshalErrorf("can not unmarshal time into %T. Accepted types: *int64, *time.Duration, *interface{}.", value) @@ -2010,7 +2069,7 @@ func (c CollectionType) unmarshalListSet(data []byte, value interface{}) error { } return nil } - return unmarshalErrorf("can not unmarshal collection into %T. Accepted types: *slice, *array.", value) + return unmarshalErrorf("can not unmarshal collection into %T. Accepted types: *slice, *array, *interface{}.", value) } type mapCQLType struct { diff --git a/marshal_test.go b/marshal_test.go index 5b518b13..8a7b989e 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -1704,9 +1704,85 @@ func TestMarshalTime(t *testing.T) { t.Errorf("marshalTest[%d]: %v", i, err) continue } + decoded, err := decBigInt(test.Data) + if err != nil { + t.Error(err) + } if !bytes.Equal(data, test.Data) { t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decInt(test.Data), data, decInt(data), test.Value) + test.Data, decoded, data, decoded, test.Value) + } + } +} + +func TestUnmarshalTimestamp(t *testing.T) { + var marshalTimestampTests = []struct { + Info TypeInfo + Data []byte + Value interface{} + }{ + { + timestampTypeInfo{}, + []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), + time.Date(2013, time.August, 13, 9, 52, 3, 0, time.UTC), + }, + { + timestampTypeInfo{}, + []byte("\x00\x00\x01\x40\x77\x16\xe1\xb8"), + int64(1376387523000), + }, + { + // 9223372036854 is the maximum time representable in ms since the epoch + // with int64 if using UnixNano to convert + timestampTypeInfo{}, + []byte("\x00\x00\x08\x63\x7b\xd0\x5a\xf6"), + time.Date(2262, time.April, 11, 23, 47, 16, 854775807, time.UTC), + }, + { + // One nanosecond after causes overflow when using UnixNano + // Instead it should resolve to the same time in ms + timestampTypeInfo{}, + []byte("\x00\x00\x08\x63\x7b\xd0\x5a\xf6"), + time.Date(2262, time.April, 11, 23, 47, 16, 854775808, time.UTC), + }, + { + // -9223372036855 is the minimum time representable in ms since the epoch + // with int64 if using UnixNano to convert + timestampTypeInfo{}, + []byte("\xff\xff\xf7\x9c\x84\x2f\xa5\x09"), + time.Date(1677, time.September, 21, 00, 12, 43, 145224192, time.UTC), + }, + { + // One nanosecond earlier causes overflow when using UnixNano + // it should resolve to the same time in ms + timestampTypeInfo{}, + []byte("\xff\xff\xf7\x9c\x84\x2f\xa5\x09"), + time.Date(1677, time.September, 21, 00, 12, 43, 145224191, time.UTC), + }, + { + // Store the zero time as a blank slice + timestampTypeInfo{}, + []byte{}, + time.Time{}, + }, + { + // Store the zero time as a nil slice + timestampTypeInfo{}, + []byte(nil), + time.Time{}, + }, + } + + for i, test := range marshalTimestampTests { + v := reflect.New(reflect.TypeOf(test.Value)).Interface() + err := Unmarshal(test.Info, test.Data, &v) + if err != nil { + t.Errorf("marshalTest[%d]: %v", i, err) + continue + } + if reflect.DeepEqual(v, test.Value) { + t.Errorf("marshalTest[%d]: expected %v, got %v", i, + test.Value, v) } } } @@ -1761,6 +1837,12 @@ func TestMarshalTimestamp(t *testing.T) { []byte{}, time.Time{}, }, + { + // Store the zero time as a nil slice + timestampTypeInfo{}, + []byte(nil), + time.Time{}, + }, } for i, test := range marshalTimestampTests { @@ -1770,8 +1852,8 @@ func TestMarshalTimestamp(t *testing.T) { continue } if !bytes.Equal(data, test.Data) { - t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decBigInt(test.Data), data, decBigInt(data), test.Value) + t.Errorf("marshalTest[%d]: expected %x, got %x for time %s", i, + test.Data, data, test.Value) } } } @@ -1897,23 +1979,22 @@ func TestMarshalTuple(t *testing.T) { }, } - for _, tc := range testCases { + for i, tc := range testCases { t.Run(tc.name, func(t *testing.T) { data, err := Marshal(info, tc.value) if err != nil { - t.Errorf("marshalTest: %v", err) + t.Errorf("marshalTest[%d]: %v", i, err) return } - if !bytes.Equal(data, tc.expected) { - t.Errorf("marshalTest: expected %x (%v), got %x (%v)", - tc.expected, decBigInt(tc.expected), data, decBigInt(data)) + t.Errorf("marshalTest[%d]: expected %x, got %x", + i, tc.expected, data) return } err = Unmarshal(info, data, tc.checkValue) if err != nil { - t.Errorf("unmarshalTest: %v", err) + t.Errorf("marshalTest[%d]: %v", i, err) return } @@ -2284,9 +2365,13 @@ func TestMarshalDate(t *testing.T) { t.Errorf("marshalTest[%d]: %v", i, err) continue } + decoded, err := decInt(test.Data) + if err != nil { + t.Error(err) + } if !bytes.Equal(data, test.Data) { t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decInt(test.Data), data, decInt(data), test.Value) + test.Data, decoded, data, decoded, test.Value) } } } @@ -2323,9 +2408,13 @@ func TestLargeDate(t *testing.T) { t.Errorf("largeDateTest[%d]: %v", i, err) continue } + decoded, err := decInt(test.Data) + if err != nil { + t.Error(err) + } if !bytes.Equal(data, test.Data) { t.Errorf("largeDateTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decInt(test.Data), data, decInt(data), test.Value) + test.Data, decoded, data, decoded, test.Value) } var date time.Time @@ -2379,8 +2468,8 @@ func TestMarshalDuration(t *testing.T) { continue } if !bytes.Equal(data, test.Data) { - t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decInt(test.Data), data, decInt(data), test.Value) + t.Errorf("marshalTest[%d]: expected %x, got %x for time %s", i, + test.Data, data, test.Value) } } } diff --git a/vector.go b/vector.go index 266e341e..72a11c2c 100644 --- a/vector.go +++ b/vector.go @@ -124,6 +124,13 @@ func (v VectorType) Unmarshal(data []byte, value interface{}) error { } rv = rv.Elem() t := rv.Type() + if t.Kind() == reflect.Interface { + if t.NumMethod() != 0 { + return unmarshalErrorf("can not unmarshal into non-empty interface %T", value) + } + t = reflect.TypeOf(v.Zero()) + } + k := t.Kind() switch k { case reflect.Slice, reflect.Array: @@ -143,6 +150,9 @@ func (v VectorType) Unmarshal(data []byte, value interface{}) error { } } else { rv.Set(reflect.MakeSlice(t, v.Dimensions, v.Dimensions)) + if rv.Kind() == reflect.Interface { + rv = rv.Elem() + } } elemSize := len(data) / v.Dimensions for i := 0; i < v.Dimensions; i++ { @@ -173,7 +183,7 @@ func (v VectorType) Unmarshal(data []byte, value interface{}) error { } return nil } - return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: slice, array.", v, value) + return unmarshalErrorf("can not unmarshal %s into %T. Accepted types: *slice, *array, *interface{}.", v, value) } // isVectorVariableLengthType determines if a type requires explicit length serialization within a vector. --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org For additional commands, e-mail: commits-h...@cassandra.apache.org