This is an automated email from the ASF dual-hosted git repository.

zike pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git


The following commit(s) were added to refs/heads/master by this push:
     new d98c4f1  [refactor]: Refactor the toTrackingMessageID() (#972)
d98c4f1 is described below

commit d98c4f17c6f8927072d146f4a10c8df73e21be6e
Author: Jiaqi Shen <[email protected]>
AuthorDate: Tue Mar 7 18:14:11 2023 +0800

    [refactor]: Refactor the toTrackingMessageID() (#972)
    
    ### Motivation
    
    `toTrackingMessageID()` is a function that is widely used in the `consumer`
implementation. It converts the interface type `MessageID` into the struct
type `trackingMessageID`. In addition, its second return value also serves to
check the `MessageID` type: a successful result indicates that the `MessageID`
is **not** a user-defined type. From the perspective of code readability,
`toTrackingMessageID()` should not do both.
    
    **Note**: After #968, `toTrackingMessageID()` returns only a pointer, and
the role of the original `ok` is now played by a nil pointer. However, the main
point discussed in this PR is unchanged.
    
    For example:
    
    
https://github.com/apache/pulsar-client-go/blob/e2ea255052e8a527091791ef368851d885ee2d45/pulsar/consumer_regex.go#L176-L181
    
    This example shows correct usage: the `ok` returned by
`toTrackingMessageID()` is used to reject a user-defined `MessageID`.
    
    
https://github.com/apache/pulsar-client-go/blob/e2ea255052e8a527091791ef368851d885ee2d45/pulsar/consumer_partition.go#L470-L473
    
    This example is less clear. Its actual effect is the same as in the
previous example, but it returns the error `failed to convert trackingMessageID`,
which is confusing.
    
    
https://github.com/apache/pulsar-client-go/blob/e2ea255052e8a527091791ef368851d885ee2d45/pulsar/consumer_partition.go#L1816-L1820
    
    In this case, we just want to convert `MessageID` into `trackingMessageID`.
We do not care what its concrete type is, because an invalid `MessageID`
implementation cannot occur here.
    
    So the original `toTrackingMessageID()` has to be read carefully to be used
correctly. I think it would be better to split it into two different methods:
`toTrackingMessageID()` only does the struct conversion, which is clearer, and
a separate check handles the type validation. When a new message ID type is
added, we only need to modify `checkMessageIDType()`.
    
    ### Modifications
    
    - Refactor the `toTrackingMessageID()`
    - Add `checkMessageIDType()` to check whether a `MessageID` is
user-defined.
---
 pulsar/consumer_impl.go       | 15 +++++++++++----
 pulsar/consumer_multitopic.go | 16 ++++++++--------
 pulsar/consumer_partition.go  | 44 ++++++++++++++++++++++++++++++-------------
 pulsar/consumer_regex.go      | 23 ++++++++++++----------
 pulsar/impl_message.go        | 44 ++++++++++++++++++++++++++++---------------
 pulsar/reader_impl.go         | 30 ++++++++++++++---------------
 6 files changed, 106 insertions(+), 66 deletions(-)

diff --git a/pulsar/consumer_impl.go b/pulsar/consumer_impl.go
index 7a86574..d19e522 100644
--- a/pulsar/consumer_impl.go
+++ b/pulsar/consumer_impl.go
@@ -531,10 +531,17 @@ func (c *consumer) ReconsumeLaterWithCustomProperties(msg 
Message, customPropert
        if delay < 0 {
                delay = 0
        }
+
+       if !checkMessageIDType(msg.ID()) {
+               c.log.Warnf("invalid message id type %T", msg.ID())
+               return
+       }
+
        msgID := c.messageID(msg.ID())
        if msgID == nil {
                return
        }
+
        props := make(map[string]string)
        for k, v := range msg.Properties() {
                props[k] = v
@@ -580,6 +587,10 @@ func (c *consumer) ReconsumeLaterWithCustomProperties(msg 
Message, customPropert
 }
 
 func (c *consumer) Nack(msg Message) {
+       if !checkMessageIDType(msg.ID()) {
+               c.log.Warnf("invalid message id type %T", msg.ID())
+               return
+       }
        if c.options.EnableDefaultNackBackoffPolicy || 
c.options.NackBackoffPolicy != nil {
                mid := c.messageID(msg.ID())
                if mid == nil {
@@ -745,10 +756,6 @@ func toProtoInitialPosition(p SubscriptionInitialPosition) 
pb.CommandSubscribe_I
 
 func (c *consumer) messageID(msgID MessageID) *trackingMessageID {
        mid := toTrackingMessageID(msgID)
-       if mid == nil {
-               c.log.Warnf("invalid message id type %T", msgID)
-               return nil
-       }
 
        partition := int(mid.partitionIdx)
        // did we receive a valid partition index?
diff --git a/pulsar/consumer_multitopic.go b/pulsar/consumer_multitopic.go
index c0fcaef..8108c29 100644
--- a/pulsar/consumer_multitopic.go
+++ b/pulsar/consumer_multitopic.go
@@ -125,11 +125,11 @@ func (c *multiTopicConsumer) Ack(msg Message) error {
 
 // AckID the consumption of a single message, identified by its MessageID
 func (c *multiTopicConsumer) AckID(msgID MessageID) error {
-       mid := toTrackingMessageID(msgID)
-       if mid == nil {
+       if !checkMessageIDType(msgID) {
                c.log.Warnf("invalid message id type %T", msgID)
                return errors.New("invalid message id type in multi_consumer")
        }
+       mid := toTrackingMessageID(msgID)
 
        if mid.consumer == nil {
                c.log.Warnf("unable to ack messageID=%+v can not determine 
topic", msgID)
@@ -152,11 +152,11 @@ func (c *multiTopicConsumer) AckCumulative(msg Message) 
error {
 // AckIDCumulative the reception of all the messages in the stream up to (and 
including)
 // the provided message, identified by its MessageID
 func (c *multiTopicConsumer) AckIDCumulative(msgID MessageID) error {
-       mid := toTrackingMessageID(msgID)
-       if mid == nil {
+       if !checkMessageIDType(msgID) {
                c.log.Warnf("invalid message id type %T", msgID)
                return errors.New("invalid message id type in multi_consumer")
        }
+       mid := toTrackingMessageID(msgID)
 
        if mid.consumer == nil {
                c.log.Warnf("unable to ack messageID=%+v can not determine 
topic", msgID)
@@ -203,11 +203,11 @@ func (c *multiTopicConsumer) 
ReconsumeLaterWithCustomProperties(msg Message, cus
 func (c *multiTopicConsumer) Nack(msg Message) {
        if c.options.EnableDefaultNackBackoffPolicy || 
c.options.NackBackoffPolicy != nil {
                msgID := msg.ID()
-               mid := toTrackingMessageID(msgID)
-               if mid == nil {
+               if !checkMessageIDType(msgID) {
                        c.log.Warnf("invalid message id type %T", msgID)
                        return
                }
+               mid := toTrackingMessageID(msgID)
 
                if mid.consumer == nil {
                        c.log.Warnf("unable to nack messageID=%+v can not 
determine topic", msgID)
@@ -221,11 +221,11 @@ func (c *multiTopicConsumer) Nack(msg Message) {
 }
 
 func (c *multiTopicConsumer) NackID(msgID MessageID) {
-       mid := toTrackingMessageID(msgID)
-       if mid == nil {
+       if !checkMessageIDType(msgID) {
                c.log.Warnf("invalid message id type %T", msgID)
                return
        }
+       mid := toTrackingMessageID(msgID)
 
        if mid.consumer == nil {
                c.log.Warnf("unable to nack messageID=%+v can not determine 
topic", msgID)
diff --git a/pulsar/consumer_partition.go b/pulsar/consumer_partition.go
index 0f7af3d..a3dac19 100644
--- a/pulsar/consumer_partition.go
+++ b/pulsar/consumer_partition.go
@@ -465,9 +465,6 @@ func (pc *partitionConsumer) ackID(msgID MessageID, 
withResponse bool) error {
        }
 
        trackingID := toTrackingMessageID(msgID)
-       if trackingID == nil {
-               return errors.New("failed to convert trackingMessageID")
-       }
 
        if trackingID != nil && trackingID.ack() {
                pc.metrics.AcksCounter.Inc()
@@ -501,18 +498,34 @@ func (pc *partitionConsumer) sendIndividualAck(msgID 
MessageID) *ackRequest {
 }
 
 func (pc *partitionConsumer) AckIDWithResponse(msgID MessageID) error {
+       if !checkMessageIDType(msgID) {
+               pc.log.Errorf("invalid message id type %T", msgID)
+               return fmt.Errorf("invalid message id type %T", msgID)
+       }
        return pc.ackID(msgID, true)
 }
 
 func (pc *partitionConsumer) AckID(msgID MessageID) error {
+       if !checkMessageIDType(msgID) {
+               pc.log.Errorf("invalid message id type %T", msgID)
+               return fmt.Errorf("invalid message id type %T", msgID)
+       }
        return pc.ackID(msgID, false)
 }
 
 func (pc *partitionConsumer) AckIDCumulative(msgID MessageID) error {
+       if !checkMessageIDType(msgID) {
+               pc.log.Errorf("invalid message id type %T", msgID)
+               return fmt.Errorf("invalid message id type %T", msgID)
+       }
        return pc.internalAckIDCumulative(msgID, false)
 }
 
 func (pc *partitionConsumer) AckIDWithResponseCumulative(msgID MessageID) 
error {
+       if !checkMessageIDType(msgID) {
+               pc.log.Errorf("invalid message id type %T", msgID)
+               return fmt.Errorf("invalid message id type %T", msgID)
+       }
        return pc.internalAckIDCumulative(msgID, true)
 }
 
@@ -574,15 +587,17 @@ func (pc *partitionConsumer) sendCumulativeAck(msgID 
MessageID) *ackRequest {
 }
 
 func (pc *partitionConsumer) NackID(msgID MessageID) {
+       if !checkMessageIDType(msgID) {
+               pc.log.Warnf("invalid message id type %T", msgID)
+               return
+       }
+
        if cmid, ok := msgID.(*chunkMessageID); ok {
                pc.unAckChunksTracker.nack(cmid)
                return
        }
 
        trackingID := toTrackingMessageID(msgID)
-       if trackingID == nil {
-               return
-       }
 
        pc.nackTracker.Add(trackingID.messageID)
        pc.metrics.NacksCounter.Inc()
@@ -665,16 +680,20 @@ func (pc *partitionConsumer) Seek(msgID MessageID) error {
                pc.log.WithField("state", state).Error("Failed to seek by 
closing or closed consumer")
                return errors.New("failed to seek by closing or closed 
consumer")
        }
+
+       if !checkMessageIDType(msgID) {
+               pc.log.Errorf("invalid message id type %T", msgID)
+               return fmt.Errorf("invalid message id type %T", msgID)
+       }
+
        req := &seekRequest{
                doneCh: make(chan struct{}),
        }
        if cmid, ok := msgID.(*chunkMessageID); ok {
                req.msgID = cmid.firstChunkID
-       } else if tmid := toTrackingMessageID(msgID); tmid != nil {
-               req.msgID = tmid.messageID
        } else {
-               // will never reach
-               return errors.New("unhandled messageID type")
+               tmid := toTrackingMessageID(msgID)
+               req.msgID = tmid.messageID
        }
 
        pc.ackGroupingTracker.flushAndClean()
@@ -1812,9 +1831,8 @@ func (c *chunkedMsgCtx) discard(pc *partitionConsumer) {
                        continue
                }
                pc.log.Info("Removing chunk message-id", mid.String())
-               if tmid := toTrackingMessageID(mid); tmid != nil {
-                       pc.AckID(tmid)
-               }
+               tmid := toTrackingMessageID(mid)
+               pc.AckID(tmid)
        }
 }
 
diff --git a/pulsar/consumer_regex.go b/pulsar/consumer_regex.go
index fdfecec..2520af5 100644
--- a/pulsar/consumer_regex.go
+++ b/pulsar/consumer_regex.go
@@ -174,12 +174,13 @@ func (c *regexConsumer) 
ReconsumeLaterWithCustomProperties(msg Message, customPr
 
 // AckID the consumption of a single message, identified by its MessageID
 func (c *regexConsumer) AckID(msgID MessageID) error {
-       mid := toTrackingMessageID(msgID)
-       if mid == nil {
+       if !checkMessageIDType(msgID) {
                c.log.Warnf("invalid message id type %T", msgID)
-               return errors.New("invalid message id type")
+               return fmt.Errorf("invalid message id type %T", msgID)
        }
 
+       mid := toTrackingMessageID(msgID)
+
        if mid.consumer == nil {
                c.log.Warnf("unable to ack messageID=%+v can not determine 
topic", msgID)
                return errors.New("consumer is nil in consumer_regex")
@@ -201,12 +202,13 @@ func (c *regexConsumer) AckCumulative(msg Message) error {
 // AckIDCumulative the reception of all the messages in the stream up to (and 
including)
 // the provided message, identified by its MessageID
 func (c *regexConsumer) AckIDCumulative(msgID MessageID) error {
-       mid := toTrackingMessageID(msgID)
-       if mid == nil {
+       if !checkMessageIDType(msgID) {
                c.log.Warnf("invalid message id type %T", msgID)
-               return errors.New("invalid message id type")
+               return fmt.Errorf("invalid message id type %T", msgID)
        }
 
+       mid := toTrackingMessageID(msgID)
+
        if mid.consumer == nil {
                c.log.Warnf("unable to ack messageID=%+v can not determine 
topic", msgID)
                return errors.New("unable to ack message because consumer is 
nil")
@@ -222,11 +224,11 @@ func (c *regexConsumer) AckIDCumulative(msgID MessageID) 
error {
 func (c *regexConsumer) Nack(msg Message) {
        if c.options.EnableDefaultNackBackoffPolicy || 
c.options.NackBackoffPolicy != nil {
                msgID := msg.ID()
-               mid := toTrackingMessageID(msgID)
-               if mid == nil {
+               if !checkMessageIDType(msgID) {
                        c.log.Warnf("invalid message id type %T", msgID)
                        return
                }
+               mid := toTrackingMessageID(msgID)
 
                if mid.consumer == nil {
                        c.log.Warnf("unable to nack messageID=%+v can not 
determine topic", msgID)
@@ -240,12 +242,13 @@ func (c *regexConsumer) Nack(msg Message) {
 }
 
 func (c *regexConsumer) NackID(msgID MessageID) {
-       mid := toTrackingMessageID(msgID)
-       if mid == nil {
+       if !checkMessageIDType(msgID) {
                c.log.Warnf("invalid message id type %T", msgID)
                return
        }
 
+       mid := toTrackingMessageID(msgID)
+
        if mid.consumer == nil {
                c.log.Warnf("unable to nack messageID=%+v can not determine 
topic", msgID)
                return
diff --git a/pulsar/impl_message.go b/pulsar/impl_message.go
index 68ddecd..9c56070 100644
--- a/pulsar/impl_message.go
+++ b/pulsar/impl_message.go
@@ -213,6 +213,16 @@ func newMessageID(ledgerID int64, entryID int64, batchIdx 
int32, partitionIdx in
        }
 }
 
+func fromMessageID(msgID MessageID) *messageID {
+       return &messageID{
+               ledgerID:     msgID.LedgerID(),
+               entryID:      msgID.EntryID(),
+               batchIdx:     msgID.BatchIdx(),
+               partitionIdx: msgID.PartitionIdx(),
+               batchSize:    msgID.BatchSize(),
+       }
+}
+
 func newTrackingMessageID(ledgerID int64, entryID int64, batchIdx int32, 
partitionIdx int32, batchSize int32,
        tracker *ackTracker) *trackingMessageID {
        return &trackingMessageID{
@@ -228,22 +238,26 @@ func newTrackingMessageID(ledgerID int64, entryID int64, 
batchIdx int32, partiti
        }
 }
 
-func toTrackingMessageID(msgID MessageID) *trackingMessageID {
-       if mid, ok := msgID.(*messageID); ok {
-               return &trackingMessageID{
-                       messageID:    mid,
-                       receivedTime: time.Now(),
-               }
-       } else if mid, ok := msgID.(*trackingMessageID); ok {
+// checkMessageIDType checks if the MessageID is user-defined
+func checkMessageIDType(msgID MessageID) (valid bool) {
+       switch msgID.(type) {
+       case *trackingMessageID:
+               return true
+       case *chunkMessageID:
+               return true
+       case *messageID:
+               return true
+       default:
+               return false
+       }
+}
+
+func toTrackingMessageID(msgID MessageID) (trackingMsgID *trackingMessageID) {
+       if mid, ok := msgID.(*trackingMessageID); ok {
                return mid
-       } else if cmid, ok := msgID.(*chunkMessageID); ok {
-               return &trackingMessageID{
-                       messageID:    cmid.messageID,
-                       receivedTime: cmid.receivedTime,
-                       consumer:     cmid.consumer,
-               }
-       } else {
-               return nil
+       }
+       return &trackingMessageID{
+               messageID: fromMessageID(msgID),
        }
 }
 
diff --git a/pulsar/reader_impl.go b/pulsar/reader_impl.go
index 68dd084..c7620ad 100644
--- a/pulsar/reader_impl.go
+++ b/pulsar/reader_impl.go
@@ -51,8 +51,8 @@ func newReader(client *client, options ReaderOptions) 
(Reader, error) {
                return nil, newError(InvalidConfiguration, "StartMessageID is 
required")
        }
 
-       startMessageID := toTrackingMessageID(options.StartMessageID)
-       if startMessageID == nil {
+       var startMessageID *trackingMessageID
+       if !checkMessageIDType(options.StartMessageID) {
                // a custom type satisfying MessageID may not be a messageID or 
trackingMessageID
                // so re-create messageID using its data
                deserMsgID, err := 
deserializeMessageID(options.StartMessageID.Serialize())
@@ -60,10 +60,9 @@ func newReader(client *client, options ReaderOptions) 
(Reader, error) {
                        return nil, err
                }
                // de-serialized MessageID is a messageID
-               startMessageID = &trackingMessageID{
-                       messageID:    deserMsgID.(*messageID),
-                       receivedTime: time.Now(),
-               }
+               startMessageID = toTrackingMessageID(deserMsgID)
+       } else {
+               startMessageID = toTrackingMessageID(options.StartMessageID)
        }
 
        subscriptionName := options.SubscriptionName
@@ -148,12 +147,10 @@ func (r *reader) Next(ctx context.Context) (Message, 
error) {
                        // Acknowledge message immediately because the reader 
is based on non-durable subscription. When it reconnects,
                        // it will specify the subscription position anyway
                        msgID := cm.Message.ID()
-                       if mid := toTrackingMessageID(msgID); mid != nil {
-                               r.pc.lastDequeuedMsg = mid
-                               r.pc.AckID(mid)
-                               return cm.Message, nil
-                       }
-                       return nil, newError(InvalidMessage, 
fmt.Sprintf("invalid message id type %T", msgID))
+                       mid := toTrackingMessageID(msgID)
+                       r.pc.lastDequeuedMsg = mid
+                       r.pc.AckID(mid)
+                       return cm.Message, nil
                case <-ctx.Done():
                        return nil, ctx.Err()
                }
@@ -202,10 +199,6 @@ func (r *reader) Close() {
 
 func (r *reader) messageID(msgID MessageID) *trackingMessageID {
        mid := toTrackingMessageID(msgID)
-       if mid == nil {
-               r.log.Warnf("invalid message id type %T", msgID)
-               return nil
-       }
 
        partition := int(mid.partitionIdx)
        // did we receive a valid partition index?
@@ -221,6 +214,11 @@ func (r *reader) Seek(msgID MessageID) error {
        r.Lock()
        defer r.Unlock()
 
+       if !checkMessageIDType(msgID) {
+               r.log.Warnf("invalid message id type %T", msgID)
+               return fmt.Errorf("invalid message id type %T", msgID)
+       }
+
        mid := r.messageID(msgID)
        if mid == nil {
                return nil

Reply via email to