This is an automated email from the ASF dual-hosted git repository.

jameshartig pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra-gocql-driver.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 22ab88e7 CASSGO-92: add public method to retrieve StatementMetadata 
and LogField methods
22ab88e7 is described below

commit 22ab88e75597baf630dc553439039ca1f2ad3bfc
Author: James Hartig <[email protected]>
AuthorDate: Thu Oct 23 18:01:15 2025 +0000

    CASSGO-92: add public method to retrieve StatementMetadata and LogField 
methods
    
    Session.StatementMetadata returns the bind-column, result-column, and
    partition-key information for a given query. This was previously not
    exposed publicly, but it is needed by some HostSelectionPolicy
    implementations, such as token-aware routing. It may also be useful for
    CI tooling or for runtime analysis of queries and their column types.
    
    The NewLogField* functions return a LogField with the given name and a
    value of a specific type.
    
    Finally, session initialization was cleaned up so that a
    HostSelectionPolicy no longer causes a panic if it tries to make a query
    during init. The HostSelectionPolicy interface documentation now states
    that queries should not be attempted during init.
    
    Patch by James Hartig for CASSGO-92; reviewed by João Reis for CASSGO-92
---
 CHANGELOG.md      |  11 ++
 cassandra_test.go | 139 +++++++++++++------------
 cluster.go        |   4 +-
 conn.go           |  16 +--
 conn_test.go      |  14 +--
 connectionpool.go |  18 ++--
 control.go        |  62 ++++++------
 events.go         |  18 ++--
 frame.go          |  12 +--
 host_source.go    |  10 +-
 logger.go         |  20 ++--
 policies.go       |  12 ++-
 query_executor.go |  20 ++--
 session.go        | 296 ++++++++++++++++++++++++++----------------------------
 topology.go       |   6 +-
 15 files changed, 342 insertions(+), 316 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2fa50aa0..66a6067e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,17 @@ All notable changes to this project will be documented in this 
file.
 The format is based on [Keep a 
Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic 
Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [2.1.0]
+
+### Added
+
+- Session.StatementMetadata (CASSGO-92)
+- NewLogFieldIP, NewLogFieldError, NewLogFieldStringer, NewLogFieldString, 
NewLogFieldInt, NewLogFieldBool (CASSGO-92)
+
+### Fixed
+
+- Prevent panic with queries during session init (CASSGO-92)
+
 ## [2.0.0]
 
 ### Removed
diff --git a/cassandra_test.go b/cassandra_test.go
index bc260e44..937a7c05 100644
--- a/cassandra_test.go
+++ b/cassandra_test.go
@@ -2794,66 +2794,62 @@ func TestKeyspaceMetadata(t *testing.T) {
 }
 
 // Integration test of the routing key calculation
-func TestRoutingKey(t *testing.T) {
+func TestRoutingStatementMetadata(t *testing.T) {
        session := createSession(t)
        defer session.Close()
 
-       if err := createTable(session, "CREATE TABLE 
gocql_test.test_single_routing_key (first_id int, second_id int, PRIMARY KEY 
(first_id, second_id))"); err != nil {
+       if err := createTable(session, "CREATE TABLE 
gocql_test.test_single_routing_key (first_id int, second_id varchar, PRIMARY 
KEY (first_id, second_id))"); err != nil {
                t.Fatalf("failed to create table with error '%v'", err)
        }
-       if err := createTable(session, "CREATE TABLE 
gocql_test.test_composite_routing_key (first_id int, second_id int, PRIMARY KEY 
((first_id, second_id)))"); err != nil {
+       if err := createTable(session, "CREATE TABLE 
gocql_test.test_composite_routing_key (first_id int, second_id varchar, PRIMARY 
KEY ((first_id, second_id)))"); err != nil {
                t.Fatalf("failed to create table with error '%v'", err)
        }
 
-       routingKeyInfo, err := session.routingKeyInfo(context.Background(), 
"SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "")
+       meta, err := session.routingStatementMetadata(context.Background(), 
"SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "")
        if err != nil {
-               t.Fatalf("failed to get routing key info due to error: %v", err)
+               t.Fatalf("failed to get routing statement metadata due to 
error: %v", err)
        }
-       if routingKeyInfo == nil {
-               t.Fatal("Expected routing key info, but was nil")
+       if meta == nil {
+               t.Fatal("Expected routing statement metadata, but was nil")
        }
-       if len(routingKeyInfo.indexes) != 1 {
-               t.Fatalf("Expected routing key indexes length to be 1 but was 
%d", len(routingKeyInfo.indexes))
+       if len(meta.PKBindColumnIndexes) != 1 {
+               t.Fatalf("Expected routing statement metadata 
PKBindColumnIndexes length to be 1 but was %d", len(meta.PKBindColumnIndexes))
        }
-       if routingKeyInfo.indexes[0] != 1 {
-               t.Errorf("Expected routing key index[0] to be 1 but was %d", 
routingKeyInfo.indexes[0])
+       if meta.PKBindColumnIndexes[0] != 1 {
+               t.Errorf("Expected routing statement metadata 
PKBindColumnIndexes[0] to be 1 but was %d", meta.PKBindColumnIndexes[0])
        }
-       if len(routingKeyInfo.types) != 1 {
-               t.Fatalf("Expected routing key types length to be 1 but was 
%d", len(routingKeyInfo.types))
+       if len(meta.BindColumns) != 2 {
+               t.Fatalf("Expected routing statement metadata BindColumns 
length to be 2 but was %d", len(meta.BindColumns))
        }
-       if routingKeyInfo.types[0] == nil {
-               t.Fatal("Expected routing key types[0] to be non-nil")
+       if meta.BindColumns[0].TypeInfo.Type() != TypeVarchar {
+               t.Fatalf("Expected routing statement metadata 
BindColumns[0].TypeInfo.Type to be %v but was %v", TypeVarchar, 
meta.BindColumns[0].TypeInfo.Type())
        }
-       if routingKeyInfo.types[0].Type() != TypeInt {
-               t.Fatalf("Expected routing key types[0].Type to be %v but was 
%v", TypeInt, routingKeyInfo.types[0].Type())
+       if meta.BindColumns[1].TypeInfo.Type() != TypeInt {
+               t.Fatalf("Expected routing statement metadata 
BindColumns[1].TypeInfo.Type to be %v but was %v", TypeInt, 
meta.BindColumns[1].TypeInfo.Type())
        }
-
-       // verify the cache is working
-       routingKeyInfo, err = session.routingKeyInfo(context.Background(), 
"SELECT * FROM test_single_routing_key WHERE second_id=? AND first_id=?", "")
-       if err != nil {
-               t.Fatalf("failed to get routing key info due to error: %v", err)
+       if len(meta.ResultColumns) != 2 {
+               t.Fatalf("Expected routing statement metadata ResultColumns 
length to be 2 but was %d", len(meta.ResultColumns))
        }
-       if len(routingKeyInfo.indexes) != 1 {
-               t.Fatalf("Expected routing key indexes length to be 1 but was 
%d", len(routingKeyInfo.indexes))
+       if meta.ResultColumns[0].Name != "first_id" {
+               t.Fatalf("Expected routing statement metadata 
ResultColumns[0].Name to be %v but was %v", "first_id", 
meta.ResultColumns[0].Name)
        }
-       if routingKeyInfo.indexes[0] != 1 {
-               t.Errorf("Expected routing key index[0] to be 1 but was %d", 
routingKeyInfo.indexes[0])
+       if meta.ResultColumns[0].TypeInfo.Type() != TypeInt {
+               t.Fatalf("Expected routing statement metadata 
ResultColumns[0].TypeInfo.Type to be %v but was %v", TypeInt, 
meta.ResultColumns[0].TypeInfo.Type())
        }
-       if len(routingKeyInfo.types) != 1 {
-               t.Fatalf("Expected routing key types length to be 1 but was 
%d", len(routingKeyInfo.types))
+       if meta.ResultColumns[1].Name != "second_id" {
+               t.Fatalf("Expected routing statement metadata 
ResultColumns[1].Name to be %v but was %v", "second_id", 
meta.ResultColumns[1].Name)
        }
-       if routingKeyInfo.types[0] == nil {
-               t.Fatal("Expected routing key types[0] to be non-nil")
+       if meta.ResultColumns[1].TypeInfo.Type() != TypeVarchar {
+               t.Fatalf("Expected routing statement metadata 
ResultColumns[1].TypeInfo.Type to be %v but was %v", TypeVarchar, 
meta.ResultColumns[1].TypeInfo.Type())
        }
-       if routingKeyInfo.types[0].Type() != TypeInt {
-               t.Fatalf("Expected routing key types[0] to be %v but was %v", 
TypeInt, routingKeyInfo.types[0].Type())
-       }
-       cacheSize := session.routingKeyInfoCache.lru.Len()
+
+       // verify the cache is working
+       cacheSize := session.routingMetadataCache.lru.Len()
        if cacheSize != 1 {
                t.Errorf("Expected cache size to be 1 but was %d", cacheSize)
        }
 
-       query := newInternalQuery(session.Query("SELECT * FROM 
test_single_routing_key WHERE second_id=? AND first_id=?", 1, 2), nil)
+       query := newInternalQuery(session.Query("SELECT * FROM 
test_single_routing_key WHERE second_id=? AND first_id=?", "1", 2), nil)
        routingKey, err := query.GetRoutingKey()
        if err != nil {
                t.Fatalf("Failed to get routing key due to error: %v", err)
@@ -2863,50 +2859,59 @@ func TestRoutingKey(t *testing.T) {
                t.Errorf("Expected routing key %v but was %v", 
expectedRoutingKey, routingKey)
        }
 
-       routingKeyInfo, err = session.routingKeyInfo(context.Background(), 
"SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?", "")
+       meta, err = session.routingStatementMetadata(context.Background(), 
"SELECT * FROM test_composite_routing_key WHERE second_id=? AND first_id=?", "")
        if err != nil {
-               t.Fatalf("failed to get routing key info due to error: %v", err)
+               t.Fatalf("failed to get routing statement metadata due to 
error: %v", err)
+       }
+       if meta == nil {
+               t.Fatal("Expected routing statement metadata, but was nil")
+       }
+       if len(meta.PKBindColumnIndexes) != 2 {
+               t.Fatalf("Expected routing statement metadata 
PKBindColumnIndexes length to be 2 but was %d", len(meta.PKBindColumnIndexes))
+       }
+       if meta.PKBindColumnIndexes[0] != 1 {
+               t.Errorf("Expected routing statement metadata 
PKBindColumnIndexes[0] to be 1 but was %d", meta.PKBindColumnIndexes[0])
        }
-       if routingKeyInfo == nil {
-               t.Fatal("Expected routing key info, but was nil")
+       if meta.PKBindColumnIndexes[1] != 0 {
+               t.Errorf("Expected routing statement metadata 
PKBindColumnIndexes[1] to be 0 but was %d", meta.PKBindColumnIndexes[1])
        }
-       if len(routingKeyInfo.indexes) != 2 {
-               t.Fatalf("Expected routing key indexes length to be 2 but was 
%d", len(routingKeyInfo.indexes))
+       if len(meta.BindColumns) != 2 {
+               t.Fatalf("Expected routing statement metadata BindColumns 
length to be 2 but was %d", len(meta.BindColumns))
        }
-       if routingKeyInfo.indexes[0] != 1 {
-               t.Errorf("Expected routing key index[0] to be 1 but was %d", 
routingKeyInfo.indexes[0])
+       if meta.BindColumns[0].TypeInfo.Type() != TypeVarchar {
+               t.Fatalf("Expected routing statement metadata 
BindColumns[0].TypeInfo.Type to be %v but was %v", TypeVarchar, 
meta.BindColumns[0].TypeInfo.Type())
        }
-       if routingKeyInfo.indexes[1] != 0 {
-               t.Errorf("Expected routing key index[1] to be 0 but was %d", 
routingKeyInfo.indexes[1])
+       if meta.BindColumns[1].TypeInfo.Type() != TypeInt {
+               t.Fatalf("Expected routing statement metadata 
BindColumns[1].TypeInfo.Type to be %v but was %v", TypeInt, 
meta.BindColumns[1].TypeInfo.Type())
        }
-       if len(routingKeyInfo.types) != 2 {
-               t.Fatalf("Expected routing key types length to be 1 but was 
%d", len(routingKeyInfo.types))
+       if len(meta.ResultColumns) != 2 {
+               t.Fatalf("Expected routing statement metadata ResultColumns 
length to be 2 but was %d", len(meta.ResultColumns))
        }
-       if routingKeyInfo.types[0] == nil {
-               t.Fatal("Expected routing key types[0] to be non-nil")
+       if meta.ResultColumns[0].Name != "first_id" {
+               t.Fatalf("Expected routing statement metadata 
ResultColumns[0].Name to be %v but was %v", "first_id", 
meta.ResultColumns[0].Name)
        }
-       if routingKeyInfo.types[0].Type() != TypeInt {
-               t.Fatalf("Expected routing key types[0] to be %v but was %v", 
TypeInt, routingKeyInfo.types[0].Type())
+       if meta.ResultColumns[0].TypeInfo.Type() != TypeInt {
+               t.Fatalf("Expected routing statement metadata 
ResultColumns[0].TypeInfo.Type to be %v but was %v", TypeInt, 
meta.ResultColumns[0].TypeInfo.Type())
        }
-       if routingKeyInfo.types[1] == nil {
-               t.Fatal("Expected routing key types[1] to be non-nil")
+       if meta.ResultColumns[1].Name != "second_id" {
+               t.Fatalf("Expected routing statement metadata 
ResultColumns[1].Name to be %v but was %v", "second_id", 
meta.ResultColumns[1].Name)
        }
-       if routingKeyInfo.types[1].Type() != TypeInt {
-               t.Fatalf("Expected routing key types[0] to be %v but was %v", 
TypeInt, routingKeyInfo.types[1].Type())
+       if meta.ResultColumns[1].TypeInfo.Type() != TypeVarchar {
+               t.Fatalf("Expected routing statement metadata 
ResultColumns[1].TypeInfo.Type to be %v but was %v", TypeVarchar, 
meta.ResultColumns[1].TypeInfo.Type())
        }
 
-       query = newInternalQuery(session.Query("SELECT * FROM 
test_composite_routing_key WHERE second_id=? AND first_id=?", 1, 2), nil)
+       query = newInternalQuery(session.Query("SELECT * FROM 
test_composite_routing_key WHERE second_id=? AND first_id=?", "1", 2), nil)
        routingKey, err = query.GetRoutingKey()
        if err != nil {
                t.Fatalf("Failed to get routing key due to error: %v", err)
        }
-       expectedRoutingKey = []byte{0, 4, 0, 0, 0, 2, 0, 0, 4, 0, 0, 0, 1, 0}
+       expectedRoutingKey = []byte{0, 4, 0, 0, 0, 2, 0, 0, 1, 49, 0}
        if !reflect.DeepEqual(expectedRoutingKey, routingKey) {
                t.Errorf("Expected routing key %v but was %v", 
expectedRoutingKey, routingKey)
        }
 
        // verify the cache is working
-       cacheSize = session.routingKeyInfoCache.lru.Len()
+       cacheSize = session.routingMetadataCache.lru.Len()
        if cacheSize != 2 {
                t.Errorf("Expected cache size to be 2 but was %d", cacheSize)
        }
@@ -3956,17 +3961,17 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t 
*testing.T) {
                t.Fatal(err)
        }
 
-       getRoutingKeyInfo := func(key string) *routingKeyInfo {
+       getStatementMetadata := func(key string) *StatementMetadata {
                t.Helper()
-               session.routingKeyInfoCache.mu.Lock()
-               value, ok := session.routingKeyInfoCache.lru.Get(key)
+               session.routingMetadataCache.mu.Lock()
+               value, ok := session.routingMetadataCache.lru.Get(key)
                if !ok {
                        t.Fatalf("routing key not found in cache for key %v", 
key)
                }
-               session.routingKeyInfoCache.mu.Unlock()
+               session.routingMetadataCache.mu.Unlock()
 
                inflight := value.(*inflightCachedEntry)
-               return inflight.value.(*routingKeyInfo)
+               return inflight.value.(*StatementMetadata)
        }
 
        const insertQuery = "INSERT INTO routing_key_cache_uses_overridden_ks 
(id) VALUES (?)"
@@ -3979,8 +3984,8 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t 
*testing.T) {
        require.NoError(t, err)
 
        // Ensuring that the cache contains the query with default ks
-       routingKeyInfo1 := getRoutingKeyInfo("gocql_test" + b1.Entries[0].Stmt)
-       require.Equal(t, "gocql_test", routingKeyInfo1.keyspace)
+       meta1 := getStatementMetadata("gocql_test" + b1.Entries[0].Stmt)
+       require.Equal(t, "gocql_test", meta1.Keyspace)
 
        // Running batch in gocql_test_routing_key_cache ks
        b2 := session.Batch(LoggedBatch)
@@ -3991,8 +3996,8 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t 
*testing.T) {
        require.NoError(t, err)
 
        // Ensuring that the cache contains the query with 
gocql_test_routing_key_cache ks
-       routingKeyInfo2 := getRoutingKeyInfo("gocql_test_routing_key_cache" + 
b2.Entries[0].Stmt)
-       require.Equal(t, "gocql_test_routing_key_cache", 
routingKeyInfo2.keyspace)
+       meta2 := getStatementMetadata("gocql_test_routing_key_cache" + 
b2.Entries[0].Stmt)
+       require.Equal(t, "gocql_test_routing_key_cache", meta2.Keyspace)
 
        const selectStmt = "SELECT * FROM routing_key_cache_uses_overridden_ks 
WHERE id=?"
 
diff --git a/cluster.go b/cluster.go
index aec8be64..ebd4a186 100644
--- a/cluster.go
+++ b/cluster.go
@@ -352,8 +352,8 @@ func (cfg *ClusterConfig) translateAddressPort(addr net.IP, 
port int, logger Str
        }
        newAddr, newPort := cfg.AddressTranslator.Translate(addr, port)
        logger.Debug("Translating address.",
-               newLogFieldIp("old_addr", addr), newLogFieldInt("old_port", 
port),
-               newLogFieldIp("new_addr", newAddr), newLogFieldInt("new_port", 
newPort))
+               NewLogFieldIP("old_addr", addr), NewLogFieldInt("old_port", 
port),
+               NewLogFieldIP("new_addr", newAddr), NewLogFieldInt("new_port", 
newPort))
        return newAddr, newPort
 }
 
diff --git a/conn.go b/conn.go
index ddf13081..40044565 100644
--- a/conn.go
+++ b/conn.go
@@ -709,7 +709,7 @@ func (c *Conn) processFrame(ctx context.Context, r 
io.Reader) error {
        delete(c.calls, head.stream)
        c.mu.Unlock()
        if call == nil || !ok {
-               c.logger.Warning("Received response for stream which has no 
handler.", newLogFieldString("header", head.String()))
+               c.logger.Warning("Received response for stream which has no 
handler.", NewLogFieldString("header", head.String()))
                return c.discardFrame(r, head)
        } else if head.stream != call.streamID {
                panic(fmt.Sprintf("call has incorrect streamID: got %d expected 
%d", call.streamID, head.stream))
@@ -1316,7 +1316,7 @@ func (c *Conn) execInternal(ctx context.Context, req 
frameBuilder, tracer Tracer
                        responseFrame, err := resp.framer.parseFrame()
                        if err != nil {
                                c.logger.Warning("Framer error while attempting 
to parse potential protocol error.",
-                                       newLogFieldError("err", err))
+                                       NewLogFieldError("err", err))
                                return nil, errProtocol
                        }
                        //goland:noinspection GoTypeAssertionOnErrors
@@ -1333,17 +1333,17 @@ func (c *Conn) execInternal(ctx context.Context, req 
frameBuilder, tracer Tracer
        case <-timeoutCh:
                close(call.timeout)
                c.logger.Debug("Request timed out on connection.",
-                       newLogFieldString("host_id", c.host.HostID()), 
newLogFieldIp("addr", c.host.ConnectAddress()))
+                       NewLogFieldString("host_id", c.host.HostID()), 
NewLogFieldIP("addr", c.host.ConnectAddress()))
                return nil, ErrTimeoutNoResponse
        case <-ctxDone:
                c.logger.Debug("Request failed because context elapsed out on 
connection.",
-                       newLogFieldString("host_id", c.host.HostID()), 
newLogFieldIp("addr", c.host.ConnectAddress()),
-                       newLogFieldError("ctx_err", ctx.Err()))
+                       NewLogFieldString("host_id", c.host.HostID()), 
NewLogFieldIP("addr", c.host.ConnectAddress()),
+                       NewLogFieldError("ctx_err", ctx.Err()))
                close(call.timeout)
                return nil, ctx.Err()
        case <-c.ctx.Done():
                c.logger.Debug("Request failed because connection closed.",
-                       newLogFieldString("host_id", c.host.HostID()), 
newLogFieldIp("addr", c.host.ConnectAddress()))
+                       NewLogFieldString("host_id", c.host.HostID()), 
NewLogFieldIP("addr", c.host.ConnectAddress()))
                close(call.timeout)
                return nil, ErrConnectionClosed
        }
@@ -1698,7 +1698,7 @@ func (c *Conn) executeQuery(ctx context.Context, q 
*internalQuery) *Iter {
                iter.framer = framer
                if err := c.awaitSchemaAgreement(ctx); err != nil {
                        // TODO: should have this behind a flag
-                       c.logger.Warning("Error while awaiting for schema 
agreement after a schema change event.", newLogFieldError("err", err))
+                       c.logger.Warning("Error while awaiting for schema 
agreement after a schema change event.", NewLogFieldError("err", err))
                }
                // dont return an error from this, might be a good idea to give 
a warning
                // though. The impact of this returning an error would be that 
the cluster
@@ -1956,7 +1956,7 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) 
(err error) {
                                goto cont
                        }
                        if !isValidPeer(host) || host.schemaVersion == "" {
-                               c.logger.Warning("Invalid peer or peer with 
empty schema_version.", newLogFieldIp("peer", host.ConnectAddress()))
+                               c.logger.Warning("Invalid peer or peer with 
empty schema_version.", NewLogFieldIP("peer", host.ConnectAddress()))
                                continue
                        }
 
diff --git a/conn_test.go b/conn_test.go
index 3646dcc8..60e4a2a8 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -356,13 +356,13 @@ func (o *testQueryObserver) ObserveQuery(ctx 
context.Context, q ObservedQuery) {
        host := q.Host.ConnectAddress().String()
        o.metrics[host] = q.Metrics
        o.logger.Debug("Observed query.",
-               newLogFieldString("stmt", q.Statement),
-               newLogFieldInt("rows", q.Rows),
-               newLogFieldString("duration", q.End.Sub(q.Start).String()),
-               newLogFieldString("host", host),
-               newLogFieldInt("attempts", q.Metrics.Attempts),
-               newLogFieldString("latency", 
strconv.FormatInt(q.Metrics.TotalLatency, 10)),
-               newLogFieldError("err", q.Err))
+               NewLogFieldString("stmt", q.Statement),
+               NewLogFieldInt("rows", q.Rows),
+               NewLogFieldString("duration", q.End.Sub(q.Start).String()),
+               NewLogFieldString("host", host),
+               NewLogFieldInt("attempts", q.Metrics.Attempts),
+               NewLogFieldString("latency", 
strconv.FormatInt(q.Metrics.TotalLatency, 10)),
+               NewLogFieldError("err", q.Err))
 }
 
 func (o *testQueryObserver) GetMetrics(host *HostInfo) *hostMetrics {
diff --git a/connectionpool.go b/connectionpool.go
index f316b569..56ca5370 100644
--- a/connectionpool.go
+++ b/connectionpool.go
@@ -494,11 +494,11 @@ func (pool *hostConnPool) logConnectErr(err error) {
                // connection refused
                // these are typical during a node outage so avoid log spam.
                pool.logger.Debug("Pool unable to establish a connection to 
host.",
-                       newLogFieldIp("host_addr", pool.host.ConnectAddress()), 
newLogFieldString("host_id", pool.host.HostID()), newLogFieldError("err", err))
+                       NewLogFieldIP("host_addr", pool.host.ConnectAddress()), 
NewLogFieldString("host_id", pool.host.HostID()), NewLogFieldError("err", err))
        } else if err != nil {
                // unexpected error
                pool.logger.Debug("Pool failed to connect to host due to 
error.",
-                       newLogFieldIp("host_addr", pool.host.ConnectAddress()), 
newLogFieldString("host_id", pool.host.HostID()), newLogFieldError("err", err))
+                       NewLogFieldIP("host_addr", pool.host.ConnectAddress()), 
NewLogFieldString("host_id", pool.host.HostID()), NewLogFieldError("err", err))
        }
 }
 
@@ -506,7 +506,7 @@ func (pool *hostConnPool) logConnectErr(err error) {
 func (pool *hostConnPool) fillingStopped(err error) {
        if err != nil {
                pool.logger.Warning("Connection pool filling failed.",
-                       newLogFieldIp("host_addr", pool.host.ConnectAddress()), 
newLogFieldString("host_id", pool.host.HostID()), newLogFieldError("err", err))
+                       NewLogFieldIP("host_addr", pool.host.ConnectAddress()), 
NewLogFieldString("host_id", pool.host.HostID()), NewLogFieldError("err", err))
                // wait for some time to avoid back-to-back filling
                // this provides some time between failed attempts
                // to fill the pool for the host to recover
@@ -523,7 +523,7 @@ func (pool *hostConnPool) fillingStopped(err error) {
        // if we errored and the size is now zero, make sure the host is marked 
as down
        // see https://github.com/apache/cassandra-gocql-driver/issues/1614
        pool.logger.Debug("Logging number of connections of pool after filling 
stopped.",
-               newLogFieldIp("host_addr", host.ConnectAddress()), 
newLogFieldString("host_id", host.HostID()), newLogFieldInt("count", count))
+               NewLogFieldIP("host_addr", host.ConnectAddress()), 
NewLogFieldString("host_id", host.HostID()), NewLogFieldInt("count", count))
        if err != nil && count == 0 {
                if pool.session.cfg.ConvictionPolicy.AddFailure(err, host) {
                        pool.session.handleNodeDown(host.ConnectAddress(), port)
@@ -580,10 +580,10 @@ func (pool *hostConnPool) connect() (err error) {
                        }
                }
                pool.logger.Warning("Pool failed to connect to host. 
Reconnecting according to the reconnection policy.",
-                       newLogFieldIp("host", pool.host.ConnectAddress()),
-                       newLogFieldString("host_id", pool.host.HostID()),
-                       newLogFieldError("err", err),
-                       newLogFieldString("reconnectionPolicy", 
fmt.Sprintf("%T", reconnectionPolicy)))
+                       NewLogFieldIP("host", pool.host.ConnectAddress()),
+                       NewLogFieldString("host_id", pool.host.HostID()),
+                       NewLogFieldError("err", err),
+                       NewLogFieldString("reconnectionPolicy", 
fmt.Sprintf("%T", reconnectionPolicy)))
                time.Sleep(reconnectionPolicy.GetInterval(i))
        }
 
@@ -631,7 +631,7 @@ func (pool *hostConnPool) HandleError(conn *Conn, err 
error, closed bool) {
        }
 
        pool.logger.Info("Pool connection error.",
-               newLogFieldString("addr", conn.addr), newLogFieldError("err", 
err))
+               NewLogFieldString("addr", conn.addr), NewLogFieldError("err", 
err))
 
        // find the connection index
        for i, candidate := range pool.conns {
diff --git a/control.go b/control.go
index de374c58..e59acb40 100644
--- a/control.go
+++ b/control.go
@@ -105,7 +105,7 @@ func (c *controlConn) heartBeat() {
 
                resp, err := c.writeFrame(&writeOptionsFrame{})
                if err != nil {
-                       c.session.logger.Debug("Control connection failed to 
send heartbeat.", newLogFieldError("err", err))
+                       c.session.logger.Debug("Control connection failed to 
send heartbeat.", NewLogFieldError("err", err))
                        goto reconn
                }
 
@@ -115,10 +115,10 @@ func (c *controlConn) heartBeat() {
                        sleepTime = 5 * time.Second
                        continue
                case error:
-                       c.session.logger.Debug("Control connection heartbeat 
failed.", newLogFieldError("err", actualResp))
+                       c.session.logger.Debug("Control connection heartbeat 
failed.", NewLogFieldError("err", actualResp))
                        goto reconn
                default:
-                       c.session.logger.Error("Unknown frame in response to 
options.", newLogFieldString("frame_type", fmt.Sprintf("%T", resp)))
+                       c.session.logger.Error("Unknown frame in response to 
options.", NewLogFieldString("frame_type", fmt.Sprintf("%T", resp)))
                }
 
        reconn:
@@ -270,18 +270,18 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) 
(int, error) {
 
                if err == nil {
                        c.session.logger.Debug("Discovered protocol version 
using host.",
-                               newLogFieldInt("protocol_version", 
connCfg.ProtoVersion), newLogFieldIp("host_addr", host.ConnectAddress()), 
newLogFieldString("host_id", host.HostID()))
+                               NewLogFieldInt("protocol_version", 
connCfg.ProtoVersion), NewLogFieldIP("host_addr", host.ConnectAddress()), 
NewLogFieldString("host_id", host.HostID()))
                        return connCfg.ProtoVersion, nil
                }
 
                if proto := parseProtocolFromError(err); proto > 0 {
                        c.session.logger.Debug("Discovered protocol version 
using host after parsing protocol error.",
-                               newLogFieldInt("protocol_version", proto), 
newLogFieldIp("host_addr", host.ConnectAddress()), newLogFieldString("host_id", 
host.HostID()))
+                               NewLogFieldInt("protocol_version", proto), 
NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", 
host.HostID()))
                        return proto, nil
                }
 
                c.session.logger.Debug("Failed to discover protocol version 
using host.",
-                       newLogFieldIp("host_addr", host.ConnectAddress()), 
newLogFieldString("host_id", host.HostID()), newLogFieldError("err", err))
+                       NewLogFieldIP("host_addr", host.ConnectAddress()), 
NewLogFieldString("host_id", host.HostID()), NewLogFieldError("err", err))
        }
 
        return 0, err
@@ -305,10 +305,10 @@ func (c *controlConn) connect(hosts []*HostInfo, 
sessionInit bool) error {
                conn, err = c.session.dial(c.session.ctx, host, &cfg, c)
                if err != nil {
                        c.session.logger.Info("Control connection failed to 
establish a connection to host.",
-                               newLogFieldIp("host_addr", 
host.ConnectAddress()),
-                               newLogFieldInt("port", host.Port()),
-                               newLogFieldString("host_id", host.HostID()),
-                               newLogFieldError("err", err))
+                               NewLogFieldIP("host_addr", 
host.ConnectAddress()),
+                               NewLogFieldInt("port", host.Port()),
+                               NewLogFieldString("host_id", host.HostID()),
+                               NewLogFieldError("err", err))
                        continue
                }
                err = c.setupConn(conn, sessionInit)
@@ -316,10 +316,10 @@ func (c *controlConn) connect(hosts []*HostInfo, 
sessionInit bool) error {
                        break
                }
                c.session.logger.Info("Control connection setup failed after 
connecting to host.",
-                       newLogFieldIp("host_addr", host.ConnectAddress()),
-                       newLogFieldInt("port", host.Port()),
-                       newLogFieldString("host_id", host.HostID()),
-                       newLogFieldError("err", err))
+                       NewLogFieldIP("host_addr", host.ConnectAddress()),
+                       NewLogFieldInt("port", host.Port()),
+                       NewLogFieldString("host_id", host.HostID()),
+                       NewLogFieldError("err", err))
                conn.Close()
                conn = nil
        }
@@ -368,7 +368,7 @@ func (c *controlConn) setupConn(conn *Conn, sessionInit 
bool) error {
                        msg = "Added control host (session initialization)."
                }
                logHelper(c.session.logger, logLevel, msg,
-                       newLogFieldIp("host_addr", host.ConnectAddress()), 
newLogFieldString("host_id", host.HostID()))
+                       NewLogFieldIP("host_addr", host.ConnectAddress()), 
NewLogFieldString("host_id", host.HostID()))
        }
 
        if err := c.registerEvents(conn); err != nil {
@@ -383,7 +383,7 @@ func (c *controlConn) setupConn(conn *Conn, sessionInit 
bool) error {
        c.conn.Store(ch)
 
        c.session.logger.Info("Control connection connected to host.",
-               newLogFieldIp("host_addr", host.ConnectAddress()), 
newLogFieldString("host_id", host.HostID()))
+               NewLogFieldIP("host_addr", host.ConnectAddress()), 
NewLogFieldString("host_id", host.HostID()))
 
        if c.session.initialized() {
                // We connected to control conn, so add the connect the host in 
pool as well.
@@ -445,14 +445,14 @@ func (c *controlConn) reconnect() {
 
        if err != nil {
                c.session.logger.Error("Unable to reconnect control 
connection.",
-                       newLogFieldError("err", err))
+                       NewLogFieldError("err", err))
                return
        }
 
        err = c.session.refreshRing()
        if err != nil {
                c.session.logger.Warning("Unable to refresh ring.",
-                       newLogFieldError("err", err))
+                       NewLogFieldError("err", err))
        }
 }
 
@@ -482,7 +482,7 @@ func (c *controlConn) attemptReconnect() (*Conn, error) {
                return conn, err
        }
 
-       c.session.logger.Error("Unable to connect to any ring node, control 
connection falling back to initial contact points.", newLogFieldError("err", 
err))
+       c.session.logger.Error("Unable to connect to any ring node, control 
connection falling back to initial contact points.", NewLogFieldError("err", 
err))
        // Fallback to initial contact points, as it may be the case that all 
known initialHosts
        // changed their IPs while keeping the same hostname(s).
        initialHosts, resolvErr := addrsToHosts(c.session.cfg.Hosts, 
c.session.cfg.Port, c.session.logger)
@@ -500,10 +500,10 @@ func (c *controlConn) attemptReconnectToAnyOfHosts(hosts 
[]*HostInfo) (*Conn, er
                conn, err = c.session.connect(c.session.ctx, host, c)
                if err != nil {
                        c.session.logger.Info("During reconnection, control 
connection failed to establish a connection to host.",
-                               newLogFieldIp("host_addr", 
host.ConnectAddress()),
-                               newLogFieldInt("port", host.Port()),
-                               newLogFieldString("host_id", host.HostID()),
-                               newLogFieldError("err", err))
+                               NewLogFieldIP("host_addr", 
host.ConnectAddress()),
+                               NewLogFieldInt("port", host.Port()),
+                               NewLogFieldString("host_id", host.HostID()),
+                               NewLogFieldError("err", err))
                        continue
                }
                err = c.setupConn(conn, false)
@@ -511,10 +511,10 @@ func (c *controlConn) attemptReconnectToAnyOfHosts(hosts 
[]*HostInfo) (*Conn, er
                        break
                }
                c.session.logger.Info("During reconnection, control connection 
setup failed after connecting to host.",
-                       newLogFieldIp("host_addr", host.ConnectAddress()),
-                       newLogFieldInt("port", host.Port()),
-                       newLogFieldString("host_id", host.HostID()),
-                       newLogFieldError("err", err))
+                       NewLogFieldIP("host_addr", host.ConnectAddress()),
+                       NewLogFieldInt("port", host.Port()),
+                       NewLogFieldString("host_id", host.HostID()),
+                       NewLogFieldError("err", err))
                conn.Close()
                conn = nil
        }
@@ -535,9 +535,9 @@ func (c *controlConn) HandleError(conn *Conn, err error, 
closed bool) {
        }
 
        c.session.logger.Warning("Control connection error.",
-               newLogFieldIp("host_addr", conn.host.ConnectAddress()),
-               newLogFieldString("host_id", conn.host.HostID()),
-               newLogFieldError("err", err))
+               NewLogFieldIP("host_addr", conn.host.ConnectAddress()),
+               NewLogFieldString("host_id", conn.host.HostID()),
+               NewLogFieldError("err", err))
 
        c.reconnect()
 }
@@ -602,7 +602,7 @@ func (c *controlConn) query(statement string, values 
...interface{}) (iter *Iter
 
                if iter.err != nil {
                        c.session.logger.Warning("Error executing control 
connection statement.",
-                               newLogFieldString("statement", statement), 
newLogFieldError("err", iter.err))
+                               NewLogFieldString("statement", statement), 
NewLogFieldError("err", iter.err))
                }
 
                qry.metrics.attempt(0)
diff --git a/events.go b/events.go
index 8f4bd1db..d511d9ae 100644
--- a/events.go
+++ b/events.go
@@ -104,7 +104,7 @@ func (e *eventDebouncer) debounce(frame frame) {
                e.events = append(e.events, frame)
        } else {
                e.logger.Warning("Event buffer full, dropping event frame.",
-                       newLogFieldString("event_name", e.name), 
newLogFieldStringer("frame", frame))
+                       NewLogFieldString("event_name", e.name), 
NewLogFieldStringer("frame", frame))
        }
 
        e.mu.Unlock()
@@ -113,11 +113,11 @@ func (e *eventDebouncer) debounce(frame frame) {
 func (s *Session) handleEvent(framer *framer) {
        frame, err := framer.parseFrame()
        if err != nil {
-               s.logger.Error("Unable to parse event frame.", 
newLogFieldError("err", err))
+               s.logger.Error("Unable to parse event frame.", 
NewLogFieldError("err", err))
                return
        }
 
-       s.logger.Debug("Handling event frame.", newLogFieldStringer("frame", 
frame))
+       s.logger.Debug("Handling event frame.", NewLogFieldStringer("frame", 
frame))
 
        switch f := frame.(type) {
        case *schemaChangeKeyspace, *schemaChangeFunction,
@@ -128,7 +128,7 @@ func (s *Session) handleEvent(framer *framer) {
                s.nodeEvents.debounce(frame)
        default:
                s.logger.Error("Invalid event frame.",
-                       newLogFieldString("frame_type", fmt.Sprintf("%T", f)), 
newLogFieldStringer("frame", f))
+                       NewLogFieldString("frame_type", fmt.Sprintf("%T", f)), 
NewLogFieldStringer("frame", f))
        }
 }
 
@@ -181,7 +181,7 @@ func (s *Session) handleNodeEvent(frames []frame) {
                switch f := frame.(type) {
                case *topologyChangeEventFrame:
                        s.logger.Info("Received topology change event.",
-                               newLogFieldString("frame", 
strings.Join([]string{f.change, "->", f.host.String(), ":", 
strconv.Itoa(f.port)}, "")))
+                               NewLogFieldString("frame", 
strings.Join([]string{f.change, "->", f.host.String(), ":", 
strconv.Itoa(f.port)}, "")))
                        topologyEventReceived = true
                case *statusChangeEventFrame:
                        event, ok := sEvents[f.host.String()]
@@ -199,7 +199,7 @@ func (s *Session) handleNodeEvent(frames []frame) {
 
        for _, f := range sEvents {
                s.logger.Info("Dispatching status change event.",
-                       newLogFieldString("frame", 
strings.Join([]string{f.change, "->", f.host.String(), ":", 
strconv.Itoa(f.port)}, "")))
+                       NewLogFieldString("frame", 
strings.Join([]string{f.change, "->", f.host.String(), ":", 
strconv.Itoa(f.port)}, "")))
 
                // ignore events we received if they were disabled
                // see 
https://github.com/apache/cassandra-gocql-driver/issues/1591
@@ -218,7 +218,7 @@ func (s *Session) handleNodeEvent(frames []frame) {
 
 func (s *Session) handleNodeUp(eventIp net.IP, eventPort int) {
        s.logger.Info("Node is UP.",
-               newLogFieldStringer("event_ip", eventIp), 
newLogFieldInt("event_port", eventPort))
+               NewLogFieldStringer("event_ip", eventIp), 
NewLogFieldInt("event_port", eventPort))
 
        host, ok := s.ring.getHostByIP(eventIp.String())
        if !ok {
@@ -244,7 +244,7 @@ func (s *Session) startPoolFill(host *HostInfo) {
 
 func (s *Session) handleNodeConnected(host *HostInfo) {
        s.logger.Debug("Pool connected to node.",
-               newLogFieldIp("host_addr", host.ConnectAddress()), 
newLogFieldInt("port", host.Port()), newLogFieldString("host_id", 
host.HostID()))
+               NewLogFieldIP("host_addr", host.ConnectAddress()), 
NewLogFieldInt("port", host.Port()), NewLogFieldString("host_id", 
host.HostID()))
 
        host.setState(NodeUp)
 
@@ -255,7 +255,7 @@ func (s *Session) handleNodeConnected(host *HostInfo) {
 
 func (s *Session) handleNodeDown(ip net.IP, port int) {
        s.logger.Warning("Node is DOWN.",
-               newLogFieldIp("host_addr", ip), newLogFieldInt("port", port))
+               NewLogFieldIP("host_addr", ip), NewLogFieldInt("port", port))
 
        host, ok := s.ring.getHostByIP(ip.String())
        if ok {
diff --git a/frame.go b/frame.go
index 0ee6a2e1..403da877 100644
--- a/frame.go
+++ b/frame.go
@@ -1046,12 +1046,11 @@ func (f *framer) readTypeInfo() (TypeInfo, error) {
 type preparedMetadata struct {
        resultMetadata
 
-       // proto v4+
-       pkeyColumns []int
-
-       keyspace string
-
-       table string
+       // pkeyColumns is only present in protocol v4+
+       pkeyColumns         []int
+       supportsPKeyColumns bool
+       keyspace            string
+       table               string
 }
 
 func (r preparedMetadata) String() string {
@@ -1090,6 +1089,7 @@ func (f *framer) parsePreparedMetadata() 
(preparedMetadata, error) {
                        pkeys[i] = int(c)
                }
                meta.pkeyColumns = pkeys
+               meta.supportsPKeyColumns = true
        }
 
        if meta.flags&flagHasMorePages == flagHasMorePages {
diff --git a/host_source.go b/host_source.go
index ab653740..622d2294 100644
--- a/host_source.go
+++ b/host_source.go
@@ -784,7 +784,7 @@ func (r *ringDescriber) getClusterPeerInfo(localHost 
*HostInfo) ([]*HostInfo, er
                                return nil, fmt.Errorf("unable to fetch peer 
host info: %s", iterErr)
                        }
                        // skip over peers that we couldn't parse
-                       r.session.logger.Warning("Failed to parse peer this 
host will be ignored.", newLogFieldError("err", err))
+                       r.session.logger.Warning("Failed to parse peer this 
host will be ignored.", NewLogFieldError("err", err))
                        continue
                }
                // if nil then none left
@@ -794,7 +794,7 @@ func (r *ringDescriber) getClusterPeerInfo(localHost 
*HostInfo) ([]*HostInfo, er
                if !isValidPeer(host) {
                        // If it's not a valid peer
                        r.session.logger.Warning("Found invalid peer "+
-                               "likely due to a gossip or snitch issue, this 
host will be ignored.", newLogFieldStringer("host", host))
+                               "likely due to a gossip or snitch issue, this 
host will be ignored.", NewLogFieldStringer("host", host))
                        continue
                }
 
@@ -866,7 +866,7 @@ func refreshRing(r *ringDescriber) error {
                }
 
                if host, ok := r.session.ring.addHostIfMissing(h); !ok {
-                       r.session.logger.Info("Adding host.", 
newLogFieldIp("host_addr", h.ConnectAddress()), newLogFieldString("host_id", 
h.HostID()))
+                       r.session.logger.Info("Adding host.", 
NewLogFieldIP("host_addr", h.ConnectAddress()), NewLogFieldString("host_id", 
h.HostID()))
                        r.session.startPoolFill(h)
                } else {
                        // host (by hostID) already exists; determine if IP has 
changed
@@ -885,7 +885,7 @@ func refreshRing(r *ringDescriber) error {
                                if _, alreadyExists := 
r.session.ring.addHostIfMissing(h); alreadyExists {
                                        return fmt.Errorf("add new host=%s 
after removal: %w", h, ErrHostAlreadyExists)
                                }
-                               r.session.logger.Info("Adding host with new IP 
after removing old host.", newLogFieldIp("host_addr", h.ConnectAddress()), 
newLogFieldString("host_id", h.HostID()))
+                               r.session.logger.Info("Adding host with new IP 
after removing old host.", NewLogFieldIP("host_addr", h.ConnectAddress()), 
NewLogFieldString("host_id", h.HostID()))
                                // add new HostInfo (same hostID, new IP)
                                r.session.startPoolFill(h)
                        }
@@ -899,7 +899,7 @@ func refreshRing(r *ringDescriber) error {
 
        r.session.metadata.setPartitioner(partitioner)
        r.session.policy.SetPartitioner(partitioner)
-       r.session.logger.Info("Refreshed ring.", newLogFieldString("ring", 
ringString(r.session.ring.allHosts())))
+       r.session.logger.Info("Refreshed ring.", NewLogFieldString("ring", 
ringString(r.session.ring.allHosts())))
        return nil
 }
 
diff --git a/logger.go b/logger.go
index df865c84..ad8c2e5f 100644
--- a/logger.go
+++ b/logger.go
@@ -48,7 +48,7 @@ func logHelper(logger StructuredLogger, level LogLevel, msg 
string, fields ...Lo
        case LogLevelError:
                logger.Error(msg, fields...)
        default:
-               logger.Error("Unknown log level", newLogFieldInt("level", 
int(level)), newLogFieldString("msg", msg))
+               logger.Error("Unknown log level", NewLogFieldInt("level", 
int(level)), NewLogFieldString("msg", msg))
        }
 }
 
@@ -229,7 +229,8 @@ func newLogField(name string, value LogFieldValue) LogField 
{
        }
 }
 
-func newLogFieldIp(name string, value net.IP) LogField {
+// NewLogFieldIP creates a new LogField with the given name and net.IP.
+func NewLogFieldIP(name string, value net.IP) LogField {
        var str string
        if value == nil {
                str = "<nil>"
@@ -239,7 +240,8 @@ func newLogFieldIp(name string, value net.IP) LogField {
        return newLogField(name, logFieldValueString(str))
 }
 
-func newLogFieldError(name string, value error) LogField {
+// NewLogFieldError creates a new LogField with the given name and error.
+func NewLogFieldError(name string, value error) LogField {
        var str string
        if value != nil {
                str = value.Error()
@@ -247,7 +249,8 @@ func newLogFieldError(name string, value error) LogField {
        return newLogField(name, logFieldValueString(str))
 }
 
-func newLogFieldStringer(name string, value fmt.Stringer) LogField {
+// NewLogFieldStringer creates a new LogField with the given name and 
fmt.Stringer.
+func NewLogFieldStringer(name string, value fmt.Stringer) LogField {
        var str string
        if value != nil {
                str = value.String()
@@ -255,15 +258,18 @@ func newLogFieldStringer(name string, value fmt.Stringer) 
LogField {
        return newLogField(name, logFieldValueString(str))
 }
 
-func newLogFieldString(name string, value string) LogField {
+// NewLogFieldString creates a new LogField with the given name and string.
+func NewLogFieldString(name string, value string) LogField {
        return newLogField(name, logFieldValueString(value))
 }
 
-func newLogFieldInt(name string, value int) LogField {
+// NewLogFieldInt creates a new LogField with the given name and int.
+func NewLogFieldInt(name string, value int) LogField {
        return newLogField(name, logFieldValueInt64(int64(value)))
 }
 
-func newLogFieldBool(name string, value bool) LogField {
+// NewLogFieldBool creates a new LogField with the given name and bool.
+func NewLogFieldBool(name string, value bool) LogField {
        return newLogField(name, logFieldValueBool(value))
 }
 
diff --git a/policies.go b/policies.go
index 8f968087..5fc61d78 100644
--- a/policies.go
+++ b/policies.go
@@ -304,9 +304,19 @@ type HostTierer interface {
 type HostSelectionPolicy interface {
        HostStateNotifier
        SetPartitioner
+
+       // KeyspaceChanged is called when the driver receives a keyspace change 
event.
        KeyspaceChanged(KeyspaceUpdateEvent)
+
+       // Init is called automatically during session creation so the policy 
can store
+       // a reference to the attached session. Notably the session is not 
usable yet
+       // when it's passed to this method.
        Init(*Session)
+
+       // IsLocal should return true if the given Host is considered "local" 
by some
+       // criteria. "Local" hosts are preferred over non-local hosts.
        IsLocal(host *HostInfo) bool
+
        // Pick returns an iteration function over selected hosts.
        // Multiple attempts of a single query execution won't call the 
returned NextHost function concurrently,
        // so it's safe to have internal state without additional 
synchronization as long as every call to Pick returns
@@ -576,7 +586,7 @@ func (m *clusterMeta) resetTokenRing(partitioner string, 
hosts []*HostInfo, logg
        // create a new token ring
        tokenRing, err := newTokenRing(partitioner, hosts)
        if err != nil {
-               logger.Warning("Unable to update the token ring due to error.", 
newLogFieldError("err", err))
+               logger.Warning("Unable to update the token ring due to error.", 
NewLogFieldError("err", err))
                return
        }
 
diff --git a/query_executor.go b/query_executor.go
index 35422ffe..2d7a6233 100644
--- a/query_executor.go
+++ b/query_executor.go
@@ -423,18 +423,18 @@ func (q *internalQuery) GetRoutingKey() ([]byte, error) {
        }
 
        // try to determine the routing key
-       routingKeyInfo, err := q.session.routingKeyInfo(q.Context(), 
q.qryOpts.stmt, q.qryOpts.keyspace)
+       meta, err := q.session.routingStatementMetadata(q.Context(), 
q.qryOpts.stmt, q.qryOpts.keyspace)
        if err != nil {
                return nil, err
        }
 
-       if routingKeyInfo != nil {
+       if meta != nil {
                q.routingInfo.mu.Lock()
-               q.routingInfo.keyspace = routingKeyInfo.keyspace
-               q.routingInfo.table = routingKeyInfo.table
+               q.routingInfo.keyspace = meta.Keyspace
+               q.routingInfo.table = meta.Table
                q.routingInfo.mu.Unlock()
        }
-       return createRoutingKey(routingKeyInfo, q.qryOpts.values)
+       return createRoutingKey(meta, q.qryOpts.values)
 }
 
 func (q *internalQuery) Keyspace() string {
@@ -643,19 +643,19 @@ func (b *internalBatch) GetRoutingKey() ([]byte, error) {
                return nil, nil
        }
        // try to determine the routing key
-       routingKeyInfo, err := b.session.routingKeyInfo(b.Context(), 
entry.Stmt, b.batchOpts.keyspace)
+       meta, err := b.session.routingStatementMetadata(b.Context(), 
entry.Stmt, b.batchOpts.keyspace)
        if err != nil {
                return nil, err
        }
 
-       if routingKeyInfo != nil {
+       if meta != nil {
                b.routingInfo.mu.Lock()
-               b.routingInfo.keyspace = routingKeyInfo.keyspace
-               b.routingInfo.table = routingKeyInfo.table
+               b.routingInfo.keyspace = meta.Keyspace
+               b.routingInfo.table = meta.Table
                b.routingInfo.mu.Unlock()
        }
 
-       return createRoutingKey(routingKeyInfo, entry.Args)
+       return createRoutingKey(meta, entry.Args)
 }
 
 func (b *internalBatch) Keyspace() string {
diff --git a/session.go b/session.go
index 111c7d01..264ed892 100644
--- a/session.go
+++ b/session.go
@@ -51,21 +51,21 @@ import (
 // and automatically sets a default consistency level on all operations
 // that do not have a consistency level set.
 type Session struct {
-       cons                Consistency
-       pageSize            int
-       prefetch            float64
-       routingKeyInfoCache routingKeyInfoLRU
-       schemaDescriber     *schemaDescriber
-       trace               Tracer
-       queryObserver       QueryObserver
-       batchObserver       BatchObserver
-       connectObserver     ConnectObserver
-       frameObserver       FrameHeaderObserver
-       streamObserver      StreamObserver
-       hostSource          *ringDescriber
-       ringRefresher       *refreshDebouncer
-       stmtsLRU            *preparedLRU
-       types               *RegisteredTypes
+       cons                 Consistency
+       pageSize             int
+       prefetch             float64
+       routingMetadataCache routingKeyInfoLRU
+       schemaDescriber      *schemaDescriber
+       trace                Tracer
+       queryObserver        QueryObserver
+       batchObserver        BatchObserver
+       connectObserver      ConnectObserver
+       frameObserver        FrameHeaderObserver
+       streamObserver       StreamObserver
+       hostSource           *ringDescriber
+       ringRefresher        *refreshDebouncer
+       stmtsLRU             *preparedLRU
+       types                *RegisteredTypes
 
        connCfg *ConnConfig
 
@@ -111,7 +111,7 @@ func addrsToHosts(addrs []string, defaultPort int, logger 
StructuredLogger) ([]*
                if err != nil {
                        // Try other hosts if unable to resolve DNS name
                        if _, ok := err.(*net.DNSError); ok {
-                               logger.Error("DNS error.", 
newLogFieldError("err", err))
+                               logger.Error("DNS error.", 
NewLogFieldError("err", err))
                                continue
                        }
                        return nil, err
@@ -167,24 +167,11 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
        s.nodeEvents = newEventDebouncer("NodeEvents", s.handleNodeEvent, 
s.logger)
        s.schemaEvents = newEventDebouncer("SchemaEvents", s.handleSchemaEvent, 
s.logger)
 
-       s.routingKeyInfoCache.lru = lru.New(cfg.MaxRoutingKeyInfo)
+       s.routingMetadataCache.lru = lru.New(cfg.MaxRoutingKeyInfo)
 
        s.hostSource = &ringDescriber{session: s}
        s.ringRefresher = newRefreshDebouncer(ringRefreshDebounceTime, func() 
error { return refreshRing(s.hostSource) })
 
-       if cfg.PoolConfig.HostSelectionPolicy == nil {
-               cfg.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
-       }
-       s.pool = cfg.PoolConfig.buildPool(s)
-
-       s.policy = cfg.PoolConfig.HostSelectionPolicy
-       s.policy.Init(s)
-
-       s.executor = &queryExecutor{
-               pool:   s.pool,
-               policy: cfg.PoolConfig.HostSelectionPolicy,
-       }
-
        s.queryObserver = cfg.QueryObserver
        s.batchObserver = cfg.BatchObserver
        s.connectObserver = cfg.ConnectObserver
@@ -199,6 +186,20 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
        }
        s.connCfg = connCfg
 
+       if cfg.PoolConfig.HostSelectionPolicy == nil {
+               cfg.PoolConfig.HostSelectionPolicy = RoundRobinHostPolicy()
+       }
+       s.pool = cfg.PoolConfig.buildPool(s)
+       s.policy = cfg.PoolConfig.HostSelectionPolicy
+
+       // set the executor here in case the policy needs to execute queries in 
Init
+       s.executor = &queryExecutor{
+               pool:   s.pool,
+               policy: cfg.PoolConfig.HostSelectionPolicy,
+       }
+
+       s.policy.Init(s)
+
        if err := s.init(); err != nil {
                s.Close()
                if err == ErrNoConnectionsStarted {
@@ -234,7 +235,7 @@ func (s *Session) init() error {
                        // TODO(zariel): we really only need this in 1 place
                        s.cfg.ProtoVersion = proto
                        s.connCfg.ProtoVersion = proto
-                       s.logger.Info("Discovered protocol version.", 
newLogFieldInt("protocol_version", proto))
+                       s.logger.Info("Discovered protocol version.", 
NewLogFieldInt("protocol_version", proto))
                }
 
                if err := s.control.connect(hosts, true); err != nil {
@@ -256,7 +257,7 @@ func (s *Session) init() error {
                        }
 
                        hosts = filteredHosts
-                       s.logger.Info("Refreshed ring.", 
newLogFieldString("ring", ringString(hosts)))
+                       s.logger.Info("Refreshed ring.", 
NewLogFieldString("ring", ringString(hosts)))
                } else {
                        s.logger.Info("Not performing a ring refresh because 
DisableInitialHostLookup is true.")
                }
@@ -295,7 +296,7 @@ func (s *Session) init() error {
                }
                if !exists {
                        s.logger.Info("Adding host (session initialization).",
-                               newLogFieldIp("host_addr", 
host.ConnectAddress()), newLogFieldString("host_id", host.HostID()))
+                               NewLogFieldIP("host_addr", 
host.ConnectAddress()), NewLogFieldString("host_id", host.HostID()))
                }
 
                atomic.AddInt64(&left, 1)
@@ -404,16 +405,16 @@ func (s *Session) reconnectDownedHosts(intv 
time.Duration) {
                        hosts := s.ring.allHosts()
 
                        // Print session.ring for debug.
-                       s.logger.Debug("Logging current ring state.", 
newLogFieldString("ring", ringString(hosts)))
+                       s.logger.Debug("Logging current ring state.", 
NewLogFieldString("ring", ringString(hosts)))
 
                        for _, h := range hosts {
                                if h.IsUp() {
                                        continue
                                }
                                s.logger.Debug("Reconnecting to downed host.",
-                                       newLogFieldIp("host_addr", 
h.ConnectAddress()),
-                                       newLogFieldInt("host_port", h.Port()),
-                                       newLogFieldString("host_id", 
h.HostID()))
+                                       NewLogFieldIP("host_addr", 
h.ConnectAddress()),
+                                       NewLogFieldInt("host_port", h.Port()),
+                                       NewLogFieldString("host_id", 
h.HostID()))
                                // we let the pool call handleNodeConnected to 
change the host state
                                s.pool.addHost(h)
                        }
@@ -578,7 +579,7 @@ func (s *Session) executeQuery(qry *internalQuery) (it 
*Iter) {
 }
 
 func (s *Session) removeHost(h *HostInfo) {
-       s.logger.Warning("Removing host.", newLogFieldIp("host_addr", 
h.ConnectAddress()), newLogFieldString("host_id", h.HostID()))
+       s.logger.Warning("Removing host.", NewLogFieldIP("host_addr", 
h.ConnectAddress()), NewLogFieldString("host_id", h.HostID()))
        s.policy.RemoveHost(h)
        hostID := h.HostID()
        s.pool.removeHost(hostID)
@@ -599,6 +600,7 @@ func (s *Session) KeyspaceMetadata(keyspace string) 
(*KeyspaceMetadata, error) {
 
 func (s *Session) getConn() *Conn {
        hosts := s.ring.allHosts()
+
        for _, host := range hosts {
                if !host.IsUp() {
                        continue
@@ -615,23 +617,22 @@ func (s *Session) getConn() *Conn {
        return nil
 }
 
-// Returns routing key indexes and type info.
+// Returns statement metadata for the purposes of generating a routing key.
 // If keyspace == "" it uses the keyspace which is specified in 
Cluster.Keyspace
-func (s *Session) routingKeyInfo(ctx context.Context, stmt string, keyspace 
string) (*routingKeyInfo, error) {
+func (s *Session) routingStatementMetadata(ctx context.Context, stmt string, 
keyspace string) (*StatementMetadata, error) {
        if keyspace == "" {
                keyspace = s.cfg.Keyspace
        }
 
-       routingKeyInfoCacheKey := keyspace + stmt
-
-       s.routingKeyInfoCache.mu.Lock()
+       key := keyspace + stmt
+       s.routingMetadataCache.mu.Lock()
 
        // Using here keyspace + stmt as a cache key because
        // the query keyspace could be overridden via SetKeyspace
-       entry, cached := s.routingKeyInfoCache.lru.Get(routingKeyInfoCacheKey)
+       entry, cached := s.routingMetadataCache.lru.Get(key)
        if cached {
                // done accessing the cache
-               s.routingKeyInfoCache.mu.Unlock()
+               s.routingMetadataCache.mu.Unlock()
                // the entry is an inflight struct similar to that used by
                // Conn to prepare statements
                inflight := entry.(*inflightCachedEntry)
@@ -643,7 +644,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt 
string, keyspace stri
                        return nil, inflight.err
                }
 
-               key, _ := inflight.value.(*routingKeyInfo)
+               key, _ := inflight.value.(*StatementMetadata)
 
                return key, nil
        }
@@ -652,114 +653,113 @@ func (s *Session) routingKeyInfo(ctx context.Context, 
stmt string, keyspace stri
        inflight := new(inflightCachedEntry)
        inflight.wg.Add(1)
        defer inflight.wg.Done()
-       s.routingKeyInfoCache.lru.Add(routingKeyInfoCacheKey, inflight)
-       s.routingKeyInfoCache.mu.Unlock()
-
-       var (
-               info         *preparedStatment
-               partitionKey []*ColumnMetadata
-       )
+       s.routingMetadataCache.lru.Add(key, inflight)
+       s.routingMetadataCache.mu.Unlock()
 
-       conn := s.getConn()
-       if conn == nil {
-               // TODO: better error?
-               inflight.err = errors.New("gocql: unable to fetch prepared 
info: no connection available")
-               return nil, inflight.err
-       }
-
-       // get the query info for the statement
-       info, inflight.err = conn.prepareStatement(ctx, stmt, nil, keyspace)
+       var meta StatementMetadata
+       meta, inflight.err = s.StatementMetadata(ctx, stmt, keyspace)
        if inflight.err != nil {
                // don't cache this error
-               s.routingKeyInfoCache.Remove(stmt)
+               s.routingMetadataCache.Remove(key)
                return nil, inflight.err
        }
 
-       // TODO: it would be nice to mark hosts here but as we are not using 
the policies
-       // to fetch hosts we cant
+       inflight.value = &meta
 
-       if info.request.colCount == 0 {
-               // no arguments, no routing key, and no error
-               return nil, nil
-       }
+       return &meta, nil
+}
 
-       table := info.request.table
-       if info.request.keyspace != "" {
-               keyspace = info.request.keyspace
-       }
+// StatementMetadata represents various metadata about a statement.
+type StatementMetadata struct {
+       // Keyspace is the keyspace of the table for the statement.
+       Keyspace string
 
-       if len(info.request.pkeyColumns) > 0 {
-               // proto v4 dont need to calculate primary key columns
-               types := make([]TypeInfo, len(info.request.pkeyColumns))
-               for i, col := range info.request.pkeyColumns {
-                       types[i] = info.request.columns[col].TypeInfo
-               }
+       // Table is the table of the statement.
+       Table string
 
-               routingKeyInfo := &routingKeyInfo{
-                       indexes:  info.request.pkeyColumns,
-                       types:    types,
-                       keyspace: keyspace,
-                       table:    table,
-               }
+       // BindColumns are columns bound to the statement.
+       BindColumns []ColumnInfo
 
-               inflight.value = routingKeyInfo
-               return routingKeyInfo, nil
-       }
+       // PKBindColumnIndexes are the indexes of the BindColumns that 
correspond to
+       // partition key columns. If this is empty then one or more columns in 
the
+       // partition key were not bound to the statement.
+       PKBindColumnIndexes []int
 
-       var keyspaceMetadata *KeyspaceMetadata
-       keyspaceMetadata, inflight.err = 
s.KeyspaceMetadata(info.request.columns[0].Keyspace)
-       if inflight.err != nil {
-               // don't cache this error
-               s.routingKeyInfoCache.Remove(stmt)
-               return nil, inflight.err
+       // ResultColumns are the columns that are returned by the statement.
+       ResultColumns []ColumnInfo
+}
+
+// StatementMetadata returns metadata for a statement. If keyspace is empty,
+// the session's keyspace is used.
+func (s *Session) StatementMetadata(ctx context.Context, stmt, keyspace 
string) (StatementMetadata, error) {
+       if keyspace == "" {
+               keyspace = s.cfg.Keyspace
        }
 
-       tableMetadata, found := keyspaceMetadata.Tables[table]
-       if !found {
-               // unlikely that the statement could be prepared and the 
metadata for
-               // the table couldn't be found, but this may indicate either a 
bug
-               // in the metadata code, or that the table was just dropped.
-               inflight.err = ErrNoMetadata
-               // don't cache this error
-               s.routingKeyInfoCache.Remove(stmt)
-               return nil, inflight.err
+       conn := s.getConn()
+       if conn == nil {
+               return StatementMetadata{}, ErrNoConnections
        }
 
-       partitionKey = tableMetadata.PartitionKey
+       // get the query info for the statement
+       info, err := conn.prepareStatement(ctx, stmt, nil, keyspace)
+       if err != nil {
+               // TODO: it would be nice to mark hosts here but as we are not 
using the policies
+               // to fetch hosts we cant and we can't use the policies because 
they might
+               // require token awareness which requires this method
+               return StatementMetadata{}, err
+       }
 
-       size := len(partitionKey)
-       routingKeyInfo := &routingKeyInfo{
-               indexes:  make([]int, size),
-               types:    make([]TypeInfo, size),
-               keyspace: keyspace,
-               table:    table,
+       if info.request.keyspace != "" {
+               keyspace = info.request.keyspace
        }
 
-       for keyIndex, keyColumn := range partitionKey {
-               // set an indicator for checking if the mapping is missing
-               routingKeyInfo.indexes[keyIndex] = -1
+       meta := StatementMetadata{
+               Keyspace:            keyspace,
+               Table:               info.request.table,
+               BindColumns:         info.request.columns,
+               PKBindColumnIndexes: info.request.pkeyColumns,
+               ResultColumns:       info.response.columns,
+       }
 
-               // find the column in the query info
-               for argIndex, boundColumn := range info.request.columns {
-                       if keyColumn.Name == boundColumn.Name {
-                               // there may be many such bound columns, pick 
the first
-                               routingKeyInfo.indexes[keyIndex] = argIndex
-                               routingKeyInfo.types[keyIndex] = 
boundColumn.TypeInfo
-                               break
-                       }
+       // if it is protocol < v4 then we need to calculate the routing key info
+       if !info.request.supportsPKeyColumns && len(info.request.columns) > 0 {
+               keyspaceMetadata, err := s.KeyspaceMetadata(meta.Keyspace)
+               if err != nil {
+                       // don't cache this error
+                       return StatementMetadata{}, err
                }
 
-               if routingKeyInfo.indexes[keyIndex] == -1 {
-                       // missing a routing key column mapping
-                       // no routing key, and no error
-                       return nil, nil
+               tableMetadata, found := keyspaceMetadata.Tables[meta.Table]
+               if !found {
+                       // unlikely that the statement could be prepared and 
the metadata for
+                       // the table couldn't be found, but this may indicate 
either a bug
+                       // in the metadata code, or that the table was just 
dropped.
+                       return StatementMetadata{}, ErrNoMetadata
                }
-       }
 
-       // cache this result
-       inflight.value = routingKeyInfo
+               meta.PKBindColumnIndexes = make([]int, 
len(tableMetadata.PartitionKey))
+               for keyIndex, keyColumn := range tableMetadata.PartitionKey {
+                       // set an indicator for checking if the mapping is 
missing
+                       meta.PKBindColumnIndexes[keyIndex] = -1
+
+                       // find the column in the query info
+                       for colIndex, boundColumn := range info.request.columns 
{
+                               if keyColumn.Name == boundColumn.Name {
+                                       // there may be many such bound 
columns, pick the first
+                                       meta.PKBindColumnIndexes[keyIndex] = 
colIndex
+                                       break
+                               }
+                       }
 
-       return routingKeyInfo, nil
+                       if meta.PKBindColumnIndexes[keyIndex] == -1 {
+                               // the partition key column is not bound to the 
statement
+                               meta.PKBindColumnIndexes = nil
+                               break
+                       }
+               }
+       }
+       return meta, nil
 }
 
 // Exec executes a batch operation and returns nil if successful
@@ -2102,16 +2102,20 @@ func (b *Batch) WithTimestamp(timestamp int64) *Batch {
        return b
 }
 
-func createRoutingKey(routingKeyInfo *routingKeyInfo, values []interface{}) 
([]byte, error) {
-       if routingKeyInfo == nil {
+func createRoutingKey(meta *StatementMetadata, values []interface{}) ([]byte, 
error) {
+       if meta == nil || len(meta.PKBindColumnIndexes) == 0 {
                return nil, nil
        }
 
-       if len(routingKeyInfo.indexes) == 1 {
+       if len(values) != len(meta.BindColumns) {
+               return nil, errors.New("gocql: number of values does not match 
the number of bind columns")
+       }
+
+       if len(meta.PKBindColumnIndexes) == 1 {
                // single column routing key
                routingKey, err := Marshal(
-                       routingKeyInfo.types[0],
-                       values[routingKeyInfo.indexes[0]],
+                       meta.BindColumns[meta.PKBindColumnIndexes[0]].TypeInfo,
+                       values[meta.PKBindColumnIndexes[0]],
                )
                if err != nil {
                        return nil, err
@@ -2121,22 +2125,23 @@ func createRoutingKey(routingKeyInfo *routingKeyInfo, 
values []interface{}) ([]b
 
        // composite routing key
        buf := bytes.NewBuffer(make([]byte, 0, 256))
-       for i := range routingKeyInfo.indexes {
+       lenBuf := make([]byte, 2)
+       for i := range meta.PKBindColumnIndexes {
                encoded, err := Marshal(
-                       routingKeyInfo.types[i],
-                       values[routingKeyInfo.indexes[i]],
+                       meta.BindColumns[meta.PKBindColumnIndexes[i]].TypeInfo,
+                       values[meta.PKBindColumnIndexes[i]],
                )
                if err != nil {
                        return nil, err
                }
-               lenBuf := []byte{0x00, 0x00}
+               // first write the length of the encoded value as a 16-bit big 
endian integer
                binary.BigEndian.PutUint16(lenBuf, uint16(len(encoded)))
                buf.Write(lenBuf)
+               // then write the encoded value and a null byte to separate the 
values
                buf.Write(encoded)
                buf.WriteByte(0x00)
        }
-       routingKey := buf.Bytes()
-       return routingKey, nil
+       return buf.Bytes(), nil
 }
 
 // SetKeyspace will enable keyspace flag on the query.
@@ -2195,17 +2200,6 @@ type routingKeyInfoLRU struct {
        mu  sync.Mutex
 }
 
-type routingKeyInfo struct {
-       indexes  []int
-       types    []TypeInfo
-       keyspace string
-       table    string
-}
-
-func (r *routingKeyInfo) String() string {
-       return fmt.Sprintf("routing key index=%v types=%v", r.indexes, r.types)
-}
-
 func (r *routingKeyInfoLRU) Remove(key string) {
        r.mu.Lock()
        r.lru.Remove(key)
diff --git a/topology.go b/topology.go
index 5a9b50fa..e5741ff2 100644
--- a/topology.go
+++ b/topology.go
@@ -96,7 +96,7 @@ func getStrategy(ks *KeyspaceMetadata, logger 
StructuredLogger) placementStrateg
                rf, err := 
getReplicationFactorFromOpts(ks.StrategyOptions["replication_factor"])
                if err != nil {
                        logger.Warning("Failed to parse replication factor of 
keyspace configured with SimpleStrategy.",
-                               newLogFieldString("keyspace", ks.Name), 
newLogFieldError("err", err))
+                               NewLogFieldString("keyspace", ks.Name), 
NewLogFieldError("err", err))
                        return nil
                }
                return &simpleStrategy{rf: rf}
@@ -110,7 +110,7 @@ func getStrategy(ks *KeyspaceMetadata, logger 
StructuredLogger) placementStrateg
                        rf, err := getReplicationFactorFromOpts(rf)
                        if err != nil {
                                logger.Warning("Failed to parse replication 
factors of keyspace configured with NetworkTopologyStrategy.",
-                                       newLogFieldString("keyspace", ks.Name), 
newLogFieldString("dc", dc), newLogFieldError("err", err))
+                                       NewLogFieldString("keyspace", ks.Name), 
NewLogFieldString("dc", dc), NewLogFieldError("err", err))
                                // skip DC if the rf is invalid/unsupported, so 
that we can at least work with other working DCs.
                                continue
                        }
@@ -122,7 +122,7 @@ func getStrategy(ks *KeyspaceMetadata, logger 
StructuredLogger) placementStrateg
                return nil
        default:
                logger.Warning("Failed to parse replication factor of keyspace 
due to unknown strategy class.",
-                       newLogFieldString("keyspace", ks.Name), 
newLogFieldString("strategy_class", ks.StrategyClass))
+                       NewLogFieldString("keyspace", ks.Name), 
NewLogFieldString("strategy_class", ks.StrategyClass))
                return nil
        }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to