worryg0d commented on code in PR #1868:
URL: 
https://github.com/apache/cassandra-gocql-driver/pull/1868#discussion_r2000954850


##########
conn.go:
##########
@@ -1523,32 +1534,35 @@ func (c *Conn) UseKeyspace(keyspace string) error {
        return nil
 }
 
-func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
+func (c *Conn) executeBatch(ctx context.Context, b *internalBatch) *Iter {
        if c.version == protoVersion1 {
-               return &Iter{err: ErrUnsupported}
+               return newErrIter(ErrUnsupported, b.metrics)
        }
 
-       n := len(batch.Entries)
+       iter := newIter(b.metrics)
+
+       n := len(b.batchOpts.Entries)
        req := &writeBatchFrame{
-               typ:                   batch.Type,
+               typ:                   b.batchOpts.Type,
                statements:            make([]batchStatment, n),
-               consistency:           batch.Cons,
-               serialConsistency:     batch.serialCons,
-               defaultTimestamp:      batch.defaultTimestamp,
-               defaultTimestampValue: batch.defaultTimestampValue,
-               customPayload:         batch.CustomPayload,
+               consistency:           b.GetConsistency(),
+               serialConsistency:     b.batchOpts.serialCons,
+               defaultTimestamp:      b.batchOpts.defaultTimestamp,
+               defaultTimestampValue: b.batchOpts.defaultTimestampValue,
+               customPayload:         b.batchOpts.CustomPayload,
        }
 
-       stmts := make(map[string]string, len(batch.Entries))
+       stmts := make(map[string]string, len(b.batchOpts.Entries))
 
        for i := 0; i < n; i++ {
-               entry := &batch.Entries[i]
-               b := &req.statements[i]
+               entry := &b.batchOpts.Entries[i]
+               batchStmt := &req.statements[i]
 
                if len(entry.Args) > 0 || entry.binding != nil {
-                       info, err := c.prepareStatement(batch.Context(), 
entry.Stmt, batch.trace)
+                       info, err := c.prepareStatement(b.batchOpts.context, 
entry.Stmt, b.batchOpts.trace)

Review Comment:
   The only usage of `ctx` is at line 1634, in the recursive call to 
`executeBatch` made after an unprepared response from Cassandra, so changing 
`b.batchOpts.context` to `ctx` here will make the code clearer and less confusing.



##########
query_executor.go:
##########
@@ -201,16 +210,385 @@ func (q *queryExecutor) do(ctx context.Context, qry 
ExecutableQuery, hostIter Ne
        }
 
        if lastErr != nil {
-               return &Iter{err: lastErr}
+               return newErrIter(lastErr, qry.getQueryMetrics())
        }
 
-       return &Iter{err: ErrNoConnections}
+       return newErrIter(ErrNoConnections, qry.getQueryMetrics())
 }
 
-func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, hostIter 
NextHost, results chan<- *Iter) {
+func (q *queryExecutor) run(ctx context.Context, qry internalRequest, hostIter 
NextHost, results chan<- *Iter) {
        select {
        case results <- q.do(ctx, qry, hostIter):
        case <-ctx.Done():
        }
-       qry.releaseAfterExecution()
+}
+
+type queryOptions struct {
+       stmt                  string
+       values                []interface{}
+       initialConsistency    Consistency
+       pageSize              int
+       initialPageState      []byte
+       prefetch              float64
+       trace                 Tracer
+       observer              QueryObserver
+       rt                    RetryPolicy
+       spec                  SpeculativeExecutionPolicy
+       binding               func(q *QueryInfo) ([]interface{}, error)
+       serialCons            SerialConsistency
+       defaultTimestamp      bool
+       defaultTimestampValue int64
+       disableSkipMetadata   bool
+       context               context.Context
+       idempotent            bool
+       customPayload         map[string][]byte
+       keyspace              string
+       disableAutoPage       bool
+       skipPrepare           bool
+       routingKey            []byte
+
+       // getKeyspace is field so that it can be overriden in tests
+       getKeyspace func() string
+}
+
+func newQueryOptions(q *Query) *queryOptions {
+       return &queryOptions{
+               stmt:                  q.stmt,
+               values:                q.values,
+               initialConsistency:    q.initialConsistency,
+               pageSize:              q.pageSize,
+               initialPageState:      q.initialPageState,
+               prefetch:              q.prefetch,
+               trace:                 q.trace,
+               observer:              q.observer,
+               rt:                    q.rt,
+               spec:                  q.spec,
+               binding:               q.binding,
+               serialCons:            q.serialCons,
+               defaultTimestamp:      q.defaultTimestamp,
+               defaultTimestampValue: q.defaultTimestampValue,
+               disableSkipMetadata:   q.disableSkipMetadata,
+               context:               q.Context(),
+               idempotent:            q.idempotent,
+               customPayload:         q.customPayload,
+               disableAutoPage:       q.disableAutoPage,
+               skipPrepare:           q.skipPrepare,
+               routingKey:            q.routingKey,
+               getKeyspace:           q.getKeyspace,
+       }
+}
+
+type internalQuery struct {
+       originalQuery *Query
+       qryOpts       *queryOptions
+       pageState     []byte
+       metrics       *queryMetrics
+       refCount      uint32
+       conn          *Conn
+       consistency   uint32
+       session       *Session
+       routingInfo   *queryRoutingInfo
+}
+
+func newInternalQuery(q *Query) *internalQuery {
+       return &internalQuery{
+               originalQuery: q,
+               qryOpts:       newQueryOptions(q),
+               metrics:       &queryMetrics{m: make(map[string]*hostMetrics)},
+               consistency:   uint32(q.initialConsistency),
+               pageState:     nil,
+               refCount:      0,
+               conn:          nil,
+               session:       q.session,
+               routingInfo:   &queryRoutingInfo{},
+       }
+}
+
+// Attempts returns the number of times the query was executed.
+func (q *internalQuery) Attempts() int {
+       return q.metrics.attempts()
+}
+
+func (q *internalQuery) attempt(keyspace string, end, start time.Time, iter 
*Iter, host *HostInfo) {
+       latency := end.Sub(start)
+       attempt, metricsForHost := q.metrics.attempt(1, latency, host, 
q.qryOpts.observer != nil)
+
+       if q.qryOpts.observer != nil {
+               q.qryOpts.observer.ObserveQuery(q.qryOpts.context, 
ObservedQuery{
+                       Keyspace:  keyspace,
+                       Statement: q.qryOpts.stmt,
+                       Values:    q.qryOpts.values,
+                       Start:     start,
+                       End:       end,
+                       Rows:      iter.numRows,
+                       Host:      host,
+                       Metrics:   metricsForHost,
+                       Err:       iter.err,
+                       Attempt:   attempt,
+               })
+       }
+}
+
+func (q *internalQuery) execute(ctx context.Context, conn *Conn) *Iter {
+       return conn.executeQuery(ctx, q)
+}
+
+func (q *internalQuery) retryPolicy() RetryPolicy {
+       return q.qryOpts.rt
+}
+
+func (q *internalQuery) speculativeExecutionPolicy() 
SpeculativeExecutionPolicy {
+       return q.qryOpts.spec
+}
+
+func (q *internalQuery) GetRoutingKey() ([]byte, error) {
+       if q.qryOpts.routingKey != nil {
+               return q.qryOpts.routingKey, nil
+       }
+
+       if q.qryOpts.binding != nil && len(q.qryOpts.values) == 0 {
+               // If this query was created using session.Bind we wont have 
the query
+               // values yet, so we have to pass down to the next policy.
+               // TODO: Remove this and handle this case
+               return nil, nil
+       }
+
+       // try to determine the routing key
+       routingKeyInfo, err := q.session.routingKeyInfo(q.qryOpts.context, 
q.qryOpts.stmt)
+       if err != nil {
+               return nil, err
+       }
+
+       if routingKeyInfo != nil {
+               q.routingInfo.mu.Lock()
+               q.routingInfo.keyspace = routingKeyInfo.keyspace
+               q.routingInfo.table = routingKeyInfo.table
+               q.routingInfo.mu.Unlock()
+       }
+       return createRoutingKey(routingKeyInfo, q.qryOpts.values)
+}
+
+func (q *internalQuery) Keyspace() string {
+       if q.qryOpts.getKeyspace != nil {
+               return q.qryOpts.getKeyspace()
+       }
+
+       qrKs := q.routingInfo.getKeyspace()
+       if qrKs != "" {
+               return qrKs
+       }
+
+       if q.session == nil {
+               return ""
+       }
+       // TODO(chbannis): this should be parsed from the query or we should let
+       // this be set by users.
+       return q.session.cfg.Keyspace
+}
+
+func (q *internalQuery) Table() string {
+       return q.routingInfo.getTable()
+}
+
+func (q *internalQuery) IsIdempotent() bool {
+       return q.qryOpts.idempotent
+}
+
+func (q *internalQuery) getQueryMetrics() *queryMetrics {
+       return q.metrics
+}
+
+func (q *internalQuery) SetConsistency(c Consistency) {
+       atomic.StoreUint32(&q.consistency, uint32(c))
+}
+
+func (q *internalQuery) GetConsistency() Consistency {
+       return Consistency(atomic.LoadUint32(&q.consistency))
+}
+
+func (q *internalQuery) Context() context.Context {
+       return q.qryOpts.context
+}
+
+func (q *internalQuery) Statement() Statement {
+       return q.originalQuery
+}
+
+type batchOptions struct {
+       Type                  BatchType
+       Entries               []BatchEntry
+       CustomPayload         map[string][]byte
+       rt                    RetryPolicy
+       spec                  SpeculativeExecutionPolicy
+       trace                 Tracer
+       observer              BatchObserver
+       serialCons            SerialConsistency
+       defaultTimestamp      bool
+       defaultTimestampValue int64
+       context               context.Context
+       keyspace              string
+       idempotent            bool
+       routingKey            []byte
+}
+
+func newBatchOptions(b *Batch) *batchOptions {
+       // make we get a new array so if user keeps appending entries on the 
Batch object it doesn't affect this execution
+       newEntries := make([]BatchEntry, len(b.Entries))
+       for i, e := range b.Entries {
+               newEntries[i] = e
+       }
+       return &batchOptions{
+               Type:                  b.Type,
+               Entries:               b.Entries,

Review Comment:
   Here you should use the newly created `newEntries` instead of `b.Entries`; 
otherwise the copy made above is never used.



##########
conn.go:
##########
@@ -1523,32 +1534,35 @@ func (c *Conn) UseKeyspace(keyspace string) error {
        return nil
 }
 
-func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
+func (c *Conn) executeBatch(ctx context.Context, b *internalBatch) *Iter {
        if c.version == protoVersion1 {
-               return &Iter{err: ErrUnsupported}
+               return newErrIter(ErrUnsupported, b.metrics)
        }
 
-       n := len(batch.Entries)
+       iter := newIter(b.metrics)
+
+       n := len(b.batchOpts.Entries)
        req := &writeBatchFrame{
-               typ:                   batch.Type,
+               typ:                   b.batchOpts.Type,
                statements:            make([]batchStatment, n),
-               consistency:           batch.Cons,
-               serialConsistency:     batch.serialCons,
-               defaultTimestamp:      batch.defaultTimestamp,
-               defaultTimestampValue: batch.defaultTimestampValue,
-               customPayload:         batch.CustomPayload,
+               consistency:           b.GetConsistency(),
+               serialConsistency:     b.batchOpts.serialCons,
+               defaultTimestamp:      b.batchOpts.defaultTimestamp,
+               defaultTimestampValue: b.batchOpts.defaultTimestampValue,
+               customPayload:         b.batchOpts.CustomPayload,
        }
 
-       stmts := make(map[string]string, len(batch.Entries))
+       stmts := make(map[string]string, len(b.batchOpts.Entries))
 
        for i := 0; i < n; i++ {
-               entry := &batch.Entries[i]
-               b := &req.statements[i]
+               entry := &b.batchOpts.Entries[i]
+               batchStmt := &req.statements[i]
 
                if len(entry.Args) > 0 || entry.binding != nil {
-                       info, err := c.prepareStatement(batch.Context(), 
entry.Stmt, batch.trace)
+                       info, err := c.prepareStatement(b.batchOpts.context, 
entry.Stmt, b.batchOpts.trace)

Review Comment:
   If I'm not missing anything here, both the `ctx` arg of the method and 
`b.batchOpts.context` should be the same context value.
   The `ctx` arg is passed by `queryExecutor.executeQuery()` at line 102 
by calling `qry.Context()`, which should return the underlying 
`internalBatch.batchOpts.context`.



##########
query_executor.go:
##########
@@ -201,16 +210,385 @@ func (q *queryExecutor) do(ctx context.Context, qry 
ExecutableQuery, hostIter Ne
        }
 
        if lastErr != nil {
-               return &Iter{err: lastErr}
+               return newErrIter(lastErr, qry.getQueryMetrics())
        }
 
-       return &Iter{err: ErrNoConnections}
+       return newErrIter(ErrNoConnections, qry.getQueryMetrics())
 }
 
-func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, hostIter 
NextHost, results chan<- *Iter) {
+func (q *queryExecutor) run(ctx context.Context, qry internalRequest, hostIter 
NextHost, results chan<- *Iter) {
        select {
        case results <- q.do(ctx, qry, hostIter):
        case <-ctx.Done():
        }
-       qry.releaseAfterExecution()
+}
+
+type queryOptions struct {
+       stmt                  string
+       values                []interface{}
+       initialConsistency    Consistency
+       pageSize              int
+       initialPageState      []byte
+       prefetch              float64
+       trace                 Tracer
+       observer              QueryObserver
+       rt                    RetryPolicy
+       spec                  SpeculativeExecutionPolicy
+       binding               func(q *QueryInfo) ([]interface{}, error)
+       serialCons            SerialConsistency
+       defaultTimestamp      bool
+       defaultTimestampValue int64
+       disableSkipMetadata   bool
+       context               context.Context
+       idempotent            bool
+       customPayload         map[string][]byte
+       keyspace              string
+       disableAutoPage       bool
+       skipPrepare           bool
+       routingKey            []byte
+
+       // getKeyspace is field so that it can be overriden in tests
+       getKeyspace func() string
+}
+
+func newQueryOptions(q *Query) *queryOptions {
+       return &queryOptions{
+               stmt:                  q.stmt,
+               values:                q.values,
+               initialConsistency:    q.initialConsistency,
+               pageSize:              q.pageSize,
+               initialPageState:      q.initialPageState,
+               prefetch:              q.prefetch,
+               trace:                 q.trace,
+               observer:              q.observer,
+               rt:                    q.rt,
+               spec:                  q.spec,
+               binding:               q.binding,
+               serialCons:            q.serialCons,
+               defaultTimestamp:      q.defaultTimestamp,
+               defaultTimestampValue: q.defaultTimestampValue,
+               disableSkipMetadata:   q.disableSkipMetadata,
+               context:               q.Context(),
+               idempotent:            q.idempotent,
+               customPayload:         q.customPayload,
+               disableAutoPage:       q.disableAutoPage,
+               skipPrepare:           q.skipPrepare,
+               routingKey:            q.routingKey,
+               getKeyspace:           q.getKeyspace,
+       }
+}
+
+type internalQuery struct {
+       originalQuery *Query
+       qryOpts       *queryOptions
+       pageState     []byte
+       metrics       *queryMetrics
+       refCount      uint32
+       conn          *Conn
+       consistency   uint32
+       session       *Session
+       routingInfo   *queryRoutingInfo
+}
+
+func newInternalQuery(q *Query) *internalQuery {
+       return &internalQuery{
+               originalQuery: q,
+               qryOpts:       newQueryOptions(q),
+               metrics:       &queryMetrics{m: make(map[string]*hostMetrics)},
+               consistency:   uint32(q.initialConsistency),
+               pageState:     nil,
+               refCount:      0,
+               conn:          nil,
+               session:       q.session,
+               routingInfo:   &queryRoutingInfo{},
+       }
+}
+
+// Attempts returns the number of times the query was executed.
+func (q *internalQuery) Attempts() int {
+       return q.metrics.attempts()
+}
+
+func (q *internalQuery) attempt(keyspace string, end, start time.Time, iter 
*Iter, host *HostInfo) {
+       latency := end.Sub(start)
+       attempt, metricsForHost := q.metrics.attempt(1, latency, host, 
q.qryOpts.observer != nil)
+
+       if q.qryOpts.observer != nil {
+               q.qryOpts.observer.ObserveQuery(q.qryOpts.context, 
ObservedQuery{
+                       Keyspace:  keyspace,
+                       Statement: q.qryOpts.stmt,
+                       Values:    q.qryOpts.values,
+                       Start:     start,
+                       End:       end,
+                       Rows:      iter.numRows,
+                       Host:      host,
+                       Metrics:   metricsForHost,
+                       Err:       iter.err,
+                       Attempt:   attempt,
+               })
+       }
+}
+
+func (q *internalQuery) execute(ctx context.Context, conn *Conn) *Iter {
+       return conn.executeQuery(ctx, q)
+}
+
+func (q *internalQuery) retryPolicy() RetryPolicy {
+       return q.qryOpts.rt
+}
+
+func (q *internalQuery) speculativeExecutionPolicy() 
SpeculativeExecutionPolicy {
+       return q.qryOpts.spec
+}
+
+func (q *internalQuery) GetRoutingKey() ([]byte, error) {
+       if q.qryOpts.routingKey != nil {
+               return q.qryOpts.routingKey, nil
+       }
+
+       if q.qryOpts.binding != nil && len(q.qryOpts.values) == 0 {
+               // If this query was created using session.Bind we wont have 
the query
+               // values yet, so we have to pass down to the next policy.
+               // TODO: Remove this and handle this case
+               return nil, nil
+       }
+
+       // try to determine the routing key
+       routingKeyInfo, err := q.session.routingKeyInfo(q.qryOpts.context, 
q.qryOpts.stmt)
+       if err != nil {
+               return nil, err
+       }
+
+       if routingKeyInfo != nil {
+               q.routingInfo.mu.Lock()
+               q.routingInfo.keyspace = routingKeyInfo.keyspace
+               q.routingInfo.table = routingKeyInfo.table
+               q.routingInfo.mu.Unlock()
+       }
+       return createRoutingKey(routingKeyInfo, q.qryOpts.values)
+}
+
+func (q *internalQuery) Keyspace() string {
+       if q.qryOpts.getKeyspace != nil {
+               return q.qryOpts.getKeyspace()
+       }
+
+       qrKs := q.routingInfo.getKeyspace()
+       if qrKs != "" {
+               return qrKs
+       }
+
+       if q.session == nil {
+               return ""
+       }
+       // TODO(chbannis): this should be parsed from the query or we should let
+       // this be set by users.
+       return q.session.cfg.Keyspace
+}
+
+func (q *internalQuery) Table() string {
+       return q.routingInfo.getTable()
+}
+
+func (q *internalQuery) IsIdempotent() bool {
+       return q.qryOpts.idempotent
+}
+
+func (q *internalQuery) getQueryMetrics() *queryMetrics {
+       return q.metrics
+}
+
+func (q *internalQuery) SetConsistency(c Consistency) {
+       atomic.StoreUint32(&q.consistency, uint32(c))
+}
+
+func (q *internalQuery) GetConsistency() Consistency {
+       return Consistency(atomic.LoadUint32(&q.consistency))
+}
+
+func (q *internalQuery) Context() context.Context {
+       return q.qryOpts.context
+}
+
+func (q *internalQuery) Statement() Statement {
+       return q.originalQuery
+}
+
+type batchOptions struct {
+       Type                  BatchType
+       Entries               []BatchEntry
+       CustomPayload         map[string][]byte
+       rt                    RetryPolicy
+       spec                  SpeculativeExecutionPolicy
+       trace                 Tracer
+       observer              BatchObserver
+       serialCons            SerialConsistency
+       defaultTimestamp      bool
+       defaultTimestampValue int64
+       context               context.Context
+       keyspace              string
+       idempotent            bool
+       routingKey            []byte
+}
+
+func newBatchOptions(b *Batch) *batchOptions {
+       // make we get a new array so if user keeps appending entries on the 
Batch object it doesn't affect this execution
+       newEntries := make([]BatchEntry, len(b.Entries))
+       for i, e := range b.Entries {
+               newEntries[i] = e
+       }
+       return &batchOptions{
+               Type:                  b.Type,
+               Entries:               b.Entries,
+               CustomPayload:         b.CustomPayload,
+               rt:                    b.rt,
+               spec:                  b.spec,
+               trace:                 b.trace,
+               observer:              b.observer,
+               serialCons:            b.serialCons,
+               defaultTimestamp:      b.defaultTimestamp,
+               defaultTimestampValue: b.defaultTimestampValue,
+               context:               b.Context(),
+               keyspace:              b.Keyspace(),
+               idempotent:            b.IsIdempotent(),
+               routingKey:            b.routingKey,
+       }
+}
+
+type internalBatch struct {
+       originalBatch *Batch
+       batchOpts     *batchOptions
+       metrics       *queryMetrics
+       consistency   uint32
+       routingInfo   *queryRoutingInfo
+       session       *Session
+}
+
+func newInternalBatch(batch *Batch) *internalBatch {
+       return &internalBatch{
+               originalBatch: batch,
+               batchOpts:     newBatchOptions(batch),
+               metrics:       &queryMetrics{m: make(map[string]*hostMetrics)},
+               routingInfo:   &queryRoutingInfo{},
+               session:       batch.session,
+       }
+}
+
+// Attempts returns the number of attempts made to execute the batch.
+func (b *internalBatch) Attempts() int {
+       return b.metrics.attempts()
+}
+
+func (b *internalBatch) attempt(keyspace string, end, start time.Time, iter 
*Iter, host *HostInfo) {
+       latency := end.Sub(start)
+       attempt, metricsForHost := b.metrics.attempt(1, latency, host, 
b.batchOpts.observer != nil)
+
+       if b.batchOpts.observer == nil {
+               return
+       }
+
+       statements := make([]string, len(b.batchOpts.Entries))
+       values := make([][]interface{}, len(b.batchOpts.Entries))
+
+       for i, entry := range b.batchOpts.Entries {
+               statements[i] = entry.Stmt
+               values[i] = entry.Args
+       }
+
+       b.batchOpts.observer.ObserveBatch(b.batchOpts.context, ObservedBatch{
+               Keyspace:   keyspace,
+               Statements: statements,
+               Values:     values,
+               Start:      start,
+               End:        end,
+               // Rows not used in batch observations // TODO - might be able 
to support it when using BatchCAS
+               Host:    host,
+               Metrics: metricsForHost,
+               Err:     iter.err,
+               Attempt: attempt,
+       })
+}
+
+func (b *internalBatch) borrowForExecution() {
+       // empty, because Batch has no equivalent of Query.Release()
+       // that would race with speculative executions.
+}
+
+func (b *internalBatch) releaseAfterExecution() {
+       // empty, because Batch has no equivalent of Query.Release()
+       // that would race with speculative executions.
+}

Review Comment:
   Since both of these methods are removed from the interface, you can get rid 
of them here too.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: pr-unsubscr...@cassandra.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: pr-unsubscr...@cassandra.apache.org
For additional commands, e-mail: pr-h...@cassandra.apache.org

Reply via email to