This is an automated email from the ASF dual-hosted git repository.

joaoreis pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra-gocql-driver.git


The following commit(s) were added to refs/heads/trunk by this push:
     new c65c762b Iter.MapScan is unable to scan data of user-defined types fix
c65c762b is described below

commit c65c762b83eccdb6a6a083c123a5935c4302745c
Author: Bohdan Siryk <[email protected]>
AuthorDate: Tue Mar 31 12:51:22 2026 +0300

    Iter.MapScan is unable to scan data of user-defined types fix
    
    Previously, Iter.MapScan was unable to scan columns of user-defined types
    because of a bug that left the destination variable holding an empty map
    instead of the decoded UDT fields. With this patch, such columns are
    handled correctly.
    
    Patch by Bohdan Siryk; reviewed by João Reis, James Hartig for CASSGO-115
---
 CHANGELOG.md      |  6 +++++
 cassandra_test.go | 59 +++++++++++++++++++++++++++++++++++++++++++++
 marshal.go        | 71 +++++++++++++++++++++++++++++++------------------------
 marshal_test.go   | 22 +++++++++++++++++
 4 files changed, 127 insertions(+), 31 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c064c600..3b07b9bf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [2.1.1]
+
+### Fixed
+
+- Iter.MapScan can now scan columns of user-defined types (CASSGO-115)
+
 ## [2.1.0]
 
 ### Added
diff --git a/cassandra_test.go b/cassandra_test.go
index d33de21c..2f386c80 100644
--- a/cassandra_test.go
+++ b/cassandra_test.go
@@ -4617,3 +4617,62 @@ func TestNewSession_SchemaListenersValidation(t 
*testing.T) {
                })
        }
 }
+
+func TestIterMapScanUDT(t *testing.T) {
+       session := createSession(t)
+       defer session.Close()
+
+       err := createTable(session, `CREATE TYPE IF NOT EXISTS 
gocql_test.top_level_mapscan_udt (
+         field_a text,
+         field_b int
+       );`)
+       require.NoError(t, err)
+
+       err = createTable(session, `CREATE TABLE IF NOT EXISTS 
gocql_test.top_level_mapscan_udt_table (
+         id int PRIMARY KEY,
+         value frozen<top_level_mapscan_udt>
+       );`)
+       require.NoError(t, err)
+
+       value := map[string]interface{}{
+               "field_a": "test_text",
+               "field_b": 42,
+       }
+
+       err = session.Query("INSERT INTO top_level_mapscan_udt_table (id, 
value) VALUES (?, ?)", 1, value).Exec()
+       require.NoError(t, err)
+
+       var scanned map[string]interface{}
+       err = session.Query("SELECT value FROM top_level_mapscan_udt_table 
WHERE id = ?", 1).Scan(&scanned)
+       require.NoError(t, err)
+
+       require.Equal(t, value["field_a"], scanned["field_a"])
+       require.Equal(t, value["field_b"], scanned["field_b"])
+
+       rawResult := map[string]interface{}{}
+       rawResultIter := session.Query("SELECT value FROM 
top_level_mapscan_udt_table WHERE id = ?", 1).Iter()
+       rawResultIter.MapScan(rawResult)
+       err = rawResultIter.Close()
+       require.NoError(t, err)
+
+       rawValue, ok := rawResult["value"].(map[string]interface{})
+       require.True(t, ok, "expected MapScan() value column to be 
map[string]interface{} got %T", rawResult["value"])
+       require.Equal(t, value["field_a"], rawValue["field_a"])
+       require.Equal(t, value["field_b"], rawValue["field_b"])
+
+       // Test for null udt value
+       err = session.Query("INSERT INTO top_level_mapscan_udt_table (id) 
VALUES (?)", 2).Exec()
+       require.NoError(t, err)
+
+       scanned = nil
+       err = session.Query("SELECT value FROM top_level_mapscan_udt_table 
WHERE id = ?", 2).Scan(&scanned)
+       require.NoError(t, err)
+       require.Nil(t, scanned)
+
+       rawResult = map[string]interface{}{}
+       rawResultIter = session.Query("SELECT value FROM 
top_level_mapscan_udt_table WHERE id = ?", 2).Iter()
+       rawResultIter.MapScan(rawResult)
+       err = rawResultIter.Close()
+       require.NoError(t, err)
+       require.Nil(t, rawResult["value"])
+}
diff --git a/marshal.go b/marshal.go
index 9252c8c0..972ed2f0 100644
--- a/marshal.go
+++ b/marshal.go
@@ -2921,12 +2921,6 @@ func (udt UDTTypeInfo) Marshal(value interface{}) 
([]byte, error) {
 
 // Unmarshal unmarshals the byte slice into the value.
 func (udt UDTTypeInfo) Unmarshal(data []byte, value interface{}) error {
-       // do this up here so we don't need to duplicate all of the map logic 
below
-       if iptr, ok := value.(*interface{}); ok && iptr != nil {
-               v := map[string]interface{}{}
-               *iptr = v
-               value = &v
-       }
        switch v := value.(type) {
        case UDTUnmarshaler:
                for id, e := range udt.Elements {
@@ -2945,34 +2939,18 @@ func (udt UDTTypeInfo) Unmarshal(data []byte, value 
interface{}) error {
                }
 
                return nil
-       case *map[string]interface{}:
-               if data == nil {
-                       *v = nil
-                       return nil
-               }
-
-               m := map[string]interface{}{}
-               *v = m
-
-               for id, e := range udt.Elements {
-                       if len(data) == 0 {
-                               return nil
-                       }
-                       if len(data) < 4 {
-                               return unmarshalErrorf("can not unmarshal UDT: 
field [%d]%s: unexpected eof", id, e.Name)
-                       }
-
-                       var p []byte
-                       p, data = readBytes(data)
-
-                       v := reflect.New(reflect.TypeOf(e.Type.Zero()))
-                       if err := Unmarshal(e.Type, p, v.Interface()); err != 
nil {
+       case *interface{}:
+               if v != nil {
+                       // m will be initialized by the unmarshalIntoMap 
function
+                       var m map[string]interface{}
+                       if err := udt.unmarshalIntoMap(data, &m); err != nil {
                                return err
                        }
-                       m[e.Name] = v.Elem().Interface()
+                       *v = m
+                       return nil
                }
-
-               return nil
+       case *map[string]interface{}:
+               return udt.unmarshalIntoMap(data, v)
        }
 
        rv := reflect.ValueOf(value)
@@ -3037,6 +3015,37 @@ func (udt UDTTypeInfo) Unmarshal(data []byte, value 
interface{}) error {
        return nil
 }
 
+// Unmarshals data into map and store its pointer in the dstMap.
+func (udt UDTTypeInfo) unmarshalIntoMap(data []byte, dstMap 
*map[string]interface{}) error {
+       if data == nil {
+               *dstMap = nil
+               return nil
+       }
+
+       m := map[string]interface{}{}
+       *dstMap = m
+
+       for id, e := range udt.Elements {
+               if len(data) == 0 {
+                       return nil
+               }
+               if len(data) < 4 {
+                       return unmarshalErrorf("can not unmarshal UDT: field 
[%d]%s: unexpected eof", id, e.Name)
+               }
+
+               var p []byte
+               p, data = readBytes(data)
+
+               v := reflect.New(reflect.TypeOf(e.Type.Zero()))
+               if err := Unmarshal(e.Type, p, v.Interface()); err != nil {
+                       return err
+               }
+               m[e.Name] = v.Elem().Interface()
+       }
+
+       return nil
+}
+
 // MarshalError represents an error that occurred during marshaling.
 type MarshalError string
 
diff --git a/marshal_test.go b/marshal_test.go
index 8a7b989e..55b70dd1 100644
--- a/marshal_test.go
+++ b/marshal_test.go
@@ -2575,6 +2575,28 @@ func TestUnmarshalUDT(t *testing.T) {
                        t.Errorf(`Expected 42 for second but received: %T(%v)`, 
value["second"], value["second"])
                }
        }
+
+       interfaceValue := interface{}(map[string]interface{}{})
+       err = Unmarshal(info, data, &interfaceValue)
+       if err != nil {
+               t.Error(err)
+       }
+
+       result, ok := interfaceValue.(map[string]interface{})
+       if !ok {
+               t.Error("expected result to be map[string]interface{}")
+       }
+       if result == nil {
+               t.Error("expected result to be not nil")
+       }
+
+       if result["first"] != "Hello" {
+               t.Error("expected result[first] to be Hello")
+       }
+
+       if result["second"] != int16(42) {
+               t.Error("expected result[second] to be 42")
+       }
 }
 
 // bytesWithLength concatenates all data slices and prepends the total length 
as uint32.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to