This is an automated email from the ASF dual-hosted git repository. xyz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git
The following commit(s) were added to refs/heads/master by this push: new 42ded0d Optimize performance by passing MessageID implementations by pointers (#968) 42ded0d is described below commit 42ded0d59c46fd3fdaad45f045f7e8bf091131a5 Author: Yunze Xu <xyzinfern...@163.com> AuthorDate: Thu Mar 2 00:04:14 2023 +0800 Optimize performance by passing MessageID implementations by pointers (#968) ### Motivation Currently there are three implementations of the `MessageID` interface: - `messageID`: 24 bytes - `trackingMessageID`: 64 bytes - `chunkMessageID`: 80 bytes However, all of their methods use value receivers rather than pointer receivers. This is inefficient because the whole struct is copied every time a method is called. Reference: https://go.dev/tour/methods/8 ### Modifications - Change the receiver from value to pointer for all `MessageID` implementations. - Use pointers to these implementations as return values and function parameters everywhere. The `trackingMessageID.Undefined` method is removed because it is no longer used. Although it is an exported method, the struct and its factory function are unexported, so removing it is reasonable. Remove the benchmark added in https://github.com/apache/pulsar-client-go/pull/324: its result is obvious, so the test is no longer meaningful. Passing `trackingMessageID` by pointer instead reduced the result from 8.548 ns/op to 1.628 ns/op. This is expected because a pointer is only 8 bytes while a `trackingMessageID` is 64 bytes; the overhead of indirection through a pointer is far less than the cost of copying the extra bytes. 
--- pulsar/ack_grouping_tracker_test.go | 16 ++-- pulsar/consumer_impl.go | 22 ++--- pulsar/consumer_multitopic.go | 16 ++-- pulsar/consumer_partition.go | 171 +++++++++++++++++------------------ pulsar/consumer_partition_test.go | 17 ++-- pulsar/consumer_regex.go | 16 ++-- pulsar/consumer_test.go | 4 +- pulsar/impl_message.go | 94 +++++++++---------- pulsar/impl_message_bench_test.go | 49 ---------- pulsar/impl_message_test.go | 14 +-- pulsar/message_chunking_test.go | 4 +- pulsar/negative_acks_tracker.go | 2 +- pulsar/negative_acks_tracker_test.go | 16 ++-- pulsar/producer_partition.go | 10 +- pulsar/producer_test.go | 8 +- pulsar/reader_impl.go | 32 +++---- 16 files changed, 216 insertions(+), 275 deletions(-) diff --git a/pulsar/ack_grouping_tracker_test.go b/pulsar/ack_grouping_tracker_test.go index d7903e8..e7a6725 100644 --- a/pulsar/ack_grouping_tracker_test.go +++ b/pulsar/ack_grouping_tracker_test.go @@ -184,14 +184,14 @@ func TestTimedTrackerCumulativeAck(t *testing.T) { func TestTimedTrackerIsDuplicate(t *testing.T) { tracker := newAckGroupingTracker(nil, func(id MessageID) {}, func(id MessageID) {}) - tracker.add(messageID{batchIdx: 0, batchSize: 3}) - tracker.add(messageID{batchIdx: 2, batchSize: 3}) - assert.True(t, tracker.isDuplicate(messageID{batchIdx: 0, batchSize: 3})) - assert.False(t, tracker.isDuplicate(messageID{batchIdx: 1, batchSize: 3})) - assert.True(t, tracker.isDuplicate(messageID{batchIdx: 2, batchSize: 3})) + tracker.add(&messageID{batchIdx: 0, batchSize: 3}) + tracker.add(&messageID{batchIdx: 2, batchSize: 3}) + assert.True(t, tracker.isDuplicate(&messageID{batchIdx: 0, batchSize: 3})) + assert.False(t, tracker.isDuplicate(&messageID{batchIdx: 1, batchSize: 3})) + assert.True(t, tracker.isDuplicate(&messageID{batchIdx: 2, batchSize: 3})) tracker.flush() - assert.False(t, tracker.isDuplicate(messageID{batchIdx: 0, batchSize: 3})) - assert.False(t, tracker.isDuplicate(messageID{batchIdx: 1, batchSize: 3})) - assert.False(t, 
tracker.isDuplicate(messageID{batchIdx: 2, batchSize: 3})) + assert.False(t, tracker.isDuplicate(&messageID{batchIdx: 0, batchSize: 3})) + assert.False(t, tracker.isDuplicate(&messageID{batchIdx: 1, batchSize: 3})) + assert.False(t, tracker.isDuplicate(&messageID{batchIdx: 2, batchSize: 3})) } diff --git a/pulsar/consumer_impl.go b/pulsar/consumer_impl.go index 8ee1822..d16f719 100644 --- a/pulsar/consumer_impl.go +++ b/pulsar/consumer_impl.go @@ -383,7 +383,7 @@ func (c *consumer) internalTopicSubscribeToPartitions() error { metadata: metadata, subProperties: subProperties, replicateSubscriptionState: c.options.ReplicateSubscriptionState, - startMessageID: trackingMessageID{}, + startMessageID: nil, subscriptionMode: durable, readCompacted: c.options.ReadCompacted, interceptors: c.options.Interceptors, @@ -531,8 +531,8 @@ func (c *consumer) ReconsumeLaterWithCustomProperties(msg Message, customPropert if delay < 0 { delay = 0 } - msgID, ok := c.messageID(msg.ID()) - if !ok { + msgID := c.messageID(msg.ID()) + if msgID == nil { return } props := make(map[string]string) @@ -581,8 +581,8 @@ func (c *consumer) ReconsumeLaterWithCustomProperties(msg Message, customPropert func (c *consumer) Nack(msg Message) { if c.options.EnableDefaultNackBackoffPolicy || c.options.NackBackoffPolicy != nil { - mid, ok := c.messageID(msg.ID()) - if !ok { + mid := c.messageID(msg.ID()) + if mid == nil { return } @@ -743,11 +743,11 @@ func toProtoInitialPosition(p SubscriptionInitialPosition) pb.CommandSubscribe_I return pb.CommandSubscribe_Latest } -func (c *consumer) messageID(msgID MessageID) (trackingMessageID, bool) { - mid, ok := toTrackingMessageID(msgID) - if !ok { +func (c *consumer) messageID(msgID MessageID) *trackingMessageID { + mid := toTrackingMessageID(msgID) + if mid == nil { c.log.Warnf("invalid message id type %T", msgID) - return trackingMessageID{}, false + return nil } partition := int(mid.partitionIdx) @@ -755,10 +755,10 @@ func (c *consumer) messageID(msgID 
MessageID) (trackingMessageID, bool) { if partition < 0 || partition >= len(c.consumers) { c.log.Warnf("invalid partition index %d expected a partition between [0-%d]", partition, len(c.consumers)) - return trackingMessageID{}, false + return nil } - return mid, true + return mid } func addMessageCryptoIfMissing(client *client, options *ConsumerOptions, topics interface{}) error { diff --git a/pulsar/consumer_multitopic.go b/pulsar/consumer_multitopic.go index 452915a..c0fcaef 100644 --- a/pulsar/consumer_multitopic.go +++ b/pulsar/consumer_multitopic.go @@ -125,8 +125,8 @@ func (c *multiTopicConsumer) Ack(msg Message) error { // AckID the consumption of a single message, identified by its MessageID func (c *multiTopicConsumer) AckID(msgID MessageID) error { - mid, ok := toTrackingMessageID(msgID) - if !ok { + mid := toTrackingMessageID(msgID) + if mid == nil { c.log.Warnf("invalid message id type %T", msgID) return errors.New("invalid message id type in multi_consumer") } @@ -152,8 +152,8 @@ func (c *multiTopicConsumer) AckCumulative(msg Message) error { // AckIDCumulative the reception of all the messages in the stream up to (and including) // the provided message, identified by its MessageID func (c *multiTopicConsumer) AckIDCumulative(msgID MessageID) error { - mid, ok := toTrackingMessageID(msgID) - if !ok { + mid := toTrackingMessageID(msgID) + if mid == nil { c.log.Warnf("invalid message id type %T", msgID) return errors.New("invalid message id type in multi_consumer") } @@ -203,8 +203,8 @@ func (c *multiTopicConsumer) ReconsumeLaterWithCustomProperties(msg Message, cus func (c *multiTopicConsumer) Nack(msg Message) { if c.options.EnableDefaultNackBackoffPolicy || c.options.NackBackoffPolicy != nil { msgID := msg.ID() - mid, ok := toTrackingMessageID(msgID) - if !ok { + mid := toTrackingMessageID(msgID) + if mid == nil { c.log.Warnf("invalid message id type %T", msgID) return } @@ -221,8 +221,8 @@ func (c *multiTopicConsumer) Nack(msg Message) { } func (c 
*multiTopicConsumer) NackID(msgID MessageID) { - mid, ok := toTrackingMessageID(msgID) - if !ok { + mid := toTrackingMessageID(msgID) + if mid == nil { c.log.Warnf("invalid message id type %T", msgID) return } diff --git a/pulsar/consumer_partition.go b/pulsar/consumer_partition.go index 95a8d32..0f7af3d 100644 --- a/pulsar/consumer_partition.go +++ b/pulsar/consumer_partition.go @@ -41,10 +41,6 @@ import ( uAtomic "go.uber.org/atomic" ) -var ( - lastestMessageID = LatestMessageID() -) - type consumerState int const ( @@ -98,7 +94,7 @@ type partitionConsumerOpts struct { metadata map[string]string subProperties map[string]string replicateSubscriptionState bool - startMessageID trackingMessageID + startMessageID *trackingMessageID startMessageIDInclusive bool subscriptionMode subscriptionMode readCompacted bool @@ -149,13 +145,13 @@ type partitionConsumer struct { queueSize int32 queueCh chan []*message startMessageID atomicMessageID - lastDequeuedMsg trackingMessageID + lastDequeuedMsg *trackingMessageID eventsCh chan interface{} connectedCh chan struct{} connectClosedCh chan connectionClosed closeCh chan struct{} - clearQueueCh chan func(id trackingMessageID) + clearQueueCh chan func(id *trackingMessageID) nackTracker *negativeAcksTracker dlq *dlqRouter @@ -217,17 +213,17 @@ func (p *availablePermits) reset() { // atomicMessageID is a wrapper for trackingMessageID to make get and set atomic type atomicMessageID struct { - msgID trackingMessageID + msgID *trackingMessageID sync.RWMutex } -func (a *atomicMessageID) get() trackingMessageID { +func (a *atomicMessageID) get() *trackingMessageID { a.RLock() defer a.RUnlock() return a.msgID } -func (a *atomicMessageID) set(msgID trackingMessageID) { +func (a *atomicMessageID) set(msgID *trackingMessageID) { a.Lock() defer a.Unlock() a.msgID = msgID @@ -303,7 +299,7 @@ func newPartitionConsumer(parent Consumer, client *client, options *partitionCon messageCh: messageCh, connectClosedCh: make(chan connectionClosed, 10), 
closeCh: make(chan struct{}), - clearQueueCh: make(chan func(id trackingMessageID)), + clearQueueCh: make(chan func(id *trackingMessageID)), compressionProviders: sync.Map{}, dlq: dlq, metrics: metrics, @@ -347,7 +343,8 @@ func newPartitionConsumer(parent Consumer, client *client, options *partitionCon pc.log.Info("Created consumer") pc.setConsumerState(consumerReady) - if pc.options.startMessageIDInclusive && pc.startMessageID.get().equal(lastestMessageID.(messageID)) { + startingMessageID := pc.startMessageID.get() + if pc.options.startMessageIDInclusive && startingMessageID != nil && startingMessageID.equal(latestMessageID) { msgID, err := pc.requestGetLastMessageID() if err != nil { pc.nackTracker.Close() @@ -418,10 +415,10 @@ func (pc *partitionConsumer) internalUnsubscribe(unsub *unsubscribeRequest) { pc.setConsumerState(consumerClosed) } -func (pc *partitionConsumer) getLastMessageID() (trackingMessageID, error) { +func (pc *partitionConsumer) getLastMessageID() (*trackingMessageID, error) { if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing { pc.log.WithField("state", state).Error("Failed to redeliver closing or closed consumer") - return trackingMessageID{}, errors.New("failed to redeliver closing or closed consumer") + return nil, errors.New("failed to redeliver closing or closed consumer") } req := &getLastMsgIDRequest{doneCh: make(chan struct{})} pc.eventsCh <- req @@ -436,10 +433,10 @@ func (pc *partitionConsumer) internalGetLastMessageID(req *getLastMsgIDRequest) req.msgID, req.err = pc.requestGetLastMessageID() } -func (pc *partitionConsumer) requestGetLastMessageID() (trackingMessageID, error) { +func (pc *partitionConsumer) requestGetLastMessageID() (*trackingMessageID, error) { if state := pc.getConsumerState(); state == consumerClosed || state == consumerClosing { pc.log.WithField("state", state).Error("Failed to getLastMessageID closing or closed consumer") - return trackingMessageID{}, errors.New("failed to 
getLastMessageID closing or closed consumer") + return nil, errors.New("failed to getLastMessageID closing or closed consumer") } requestID := pc.client.rpcClient.NewRequestID() @@ -451,7 +448,7 @@ func (pc *partitionConsumer) requestGetLastMessageID() (trackingMessageID, error pb.BaseCommand_GET_LAST_MESSAGE_ID, cmdGetLastMessageID) if err != nil { pc.log.WithError(err).Error("Failed to get last message id") - return trackingMessageID{}, err + return nil, err } id := res.Response.GetLastMessageIdResponse.GetLastMessageId() return convertToMessageID(id), nil @@ -463,16 +460,16 @@ func (pc *partitionConsumer) ackID(msgID MessageID, withResponse bool) error { return errors.New("consumer state is closed") } - if cmid, ok := toChunkedMessageID(msgID); ok { + if cmid, ok := msgID.(*chunkMessageID); ok { return pc.unAckChunksTracker.ack(cmid) } - trackingID, ok := toTrackingMessageID(msgID) - if !ok { + trackingID := toTrackingMessageID(msgID) + if trackingID == nil { return errors.New("failed to convert trackingMessageID") } - if !trackingID.Undefined() && trackingID.ack() { + if trackingID != nil && trackingID.ack() { pc.metrics.AcksCounter.Inc() pc.metrics.ProcessingTime.Observe(float64(time.Now().UnixNano()-trackingID.receivedTime.UnixNano()) / 1.0e9) } else if !pc.options.enableBatchIndexAck { @@ -481,10 +478,10 @@ func (pc *partitionConsumer) ackID(msgID MessageID, withResponse bool) error { var ackReq *ackRequest if withResponse { - ackReq := pc.sendIndividualAck(&trackingID) + ackReq := pc.sendIndividualAck(trackingID) <-ackReq.doneCh } else { - pc.ackGroupingTracker.add(&trackingID) + pc.ackGroupingTracker.add(trackingID) } pc.options.interceptors.OnAcknowledge(pc.parentConsumer, msgID) if ackReq == nil { @@ -526,15 +523,12 @@ func (pc *partitionConsumer) internalAckIDCumulative(msgID MessageID, withRespon } // chunk message id will be converted to tracking message id - trackingID, ok := toTrackingMessageID(msgID) - if !ok { + trackingID := 
toTrackingMessageID(msgID) + if trackingID == nil { return errors.New("failed to convert trackingMessageID") } - if trackingID.Undefined() { - return nil - } - var msgIDToAck trackingMessageID + var msgIDToAck *trackingMessageID if trackingID.ackCumulative() || pc.options.enableBatchIndexAck { msgIDToAck = trackingID } else if !trackingID.tracker.hasPrevBatchAcked() { @@ -551,15 +545,15 @@ func (pc *partitionConsumer) internalAckIDCumulative(msgID MessageID, withRespon var ackReq *ackRequest if withResponse { - ackReq := pc.sendCumulativeAck(&msgIDToAck) + ackReq := pc.sendCumulativeAck(msgIDToAck) <-ackReq.doneCh } else { - pc.ackGroupingTracker.addCumulative(&msgIDToAck) + pc.ackGroupingTracker.addCumulative(msgIDToAck) } pc.options.interceptors.OnAcknowledge(pc.parentConsumer, msgID) - if cmid, ok := toChunkedMessageID(msgID); ok { + if cmid, ok := msgID.(*chunkMessageID); ok { pc.unAckChunksTracker.remove(cmid) } @@ -580,13 +574,13 @@ func (pc *partitionConsumer) sendCumulativeAck(msgID MessageID) *ackRequest { } func (pc *partitionConsumer) NackID(msgID MessageID) { - if cmid, ok := toChunkedMessageID(msgID); ok { + if cmid, ok := msgID.(*chunkMessageID); ok { pc.unAckChunksTracker.nack(cmid) return } - trackingID, ok := toTrackingMessageID(msgID) - if !ok { + trackingID := toTrackingMessageID(msgID) + if trackingID == nil { return } @@ -674,9 +668,9 @@ func (pc *partitionConsumer) Seek(msgID MessageID) error { req := &seekRequest{ doneCh: make(chan struct{}), } - if cmid, ok := toChunkedMessageID(msgID); ok { + if cmid, ok := msgID.(*chunkMessageID); ok { req.msgID = cmid.firstChunkID - } else if tmid, ok := toTrackingMessageID(msgID); ok { + } else if tmid := toTrackingMessageID(msgID); tmid != nil { req.msgID = tmid.messageID } else { // will never reach @@ -695,7 +689,7 @@ func (pc *partitionConsumer) internalSeek(seek *seekRequest) { defer close(seek.doneCh) seek.err = pc.requestSeek(seek.msgID) } -func (pc *partitionConsumer) requestSeek(msgID messageID) 
error { +func (pc *partitionConsumer) requestSeek(msgID *messageID) error { if err := pc.requestSeekWithoutClear(msgID); err != nil { return err } @@ -703,7 +697,7 @@ func (pc *partitionConsumer) requestSeek(msgID messageID) error { return nil } -func (pc *partitionConsumer) requestSeekWithoutClear(msgID messageID) error { +func (pc *partitionConsumer) requestSeekWithoutClear(msgID *messageID) error { state := pc.getConsumerState() if state == consumerClosing || state == consumerClosed { pc.log.WithField("state", state).Error("failed seek by consumer is closing or has closed") @@ -1063,7 +1057,7 @@ func (pc *partitionConsumer) processMessageChunk(compressedPayload internal.Buff numChunks := msgMeta.GetNumChunksFromMsg() totalChunksSize := int(msgMeta.GetTotalChunkMsgSize()) chunkID := msgMeta.GetChunkId() - msgID := messageID{ + msgID := &messageID{ ledgerID: int64(pbMsgID.GetLedgerId()), entryID: int64(pbMsgID.GetEntryId()), batchIdx: -1, @@ -1105,12 +1099,12 @@ func (pc *partitionConsumer) processMessageChunk(compressedPayload internal.Buff return ctx.chunkedMsgBuffer } -func (pc *partitionConsumer) messageShouldBeDiscarded(msgID trackingMessageID) bool { - if pc.startMessageID.get().Undefined() { +func (pc *partitionConsumer) messageShouldBeDiscarded(msgID *trackingMessageID) bool { + if pc.startMessageID.get() == nil { return false } // if we start at latest message, we should never discard - if pc.options.startMessageID.equal(latestMessageID) { + if pc.options.startMessageID != nil && pc.options.startMessageID.equal(latestMessageID) { return false } @@ -1263,7 +1257,7 @@ func (pc *partitionConsumer) dispatcher() { case clearQueueCb := <-pc.clearQueueCh: // drain the message queue on any new connection by sending a // special nil message to the channel so we know when to stop dropping messages - var nextMessageInQueue trackingMessageID + var nextMessageInQueue *trackingMessageID go func() { pc.queueCh <- nil }() @@ -1272,8 +1266,8 @@ func (pc 
*partitionConsumer) dispatcher() { // the queue has been drained if m == nil { break - } else if nextMessageInQueue.Undefined() { - nextMessageInQueue, _ = toTrackingMessageID(m[0].msgID) + } else if nextMessageInQueue == nil { + nextMessageInQueue = toTrackingMessageID(m[0].msgID) } } @@ -1311,13 +1305,13 @@ type redeliveryRequest struct { type getLastMsgIDRequest struct { doneCh chan struct{} - msgID trackingMessageID + msgID *trackingMessageID err error } type seekRequest struct { doneCh chan struct{} - msgID messageID + msgID *messageID err error } @@ -1578,15 +1572,15 @@ func (pc *partitionConsumer) grabConn() error { } } -func (pc *partitionConsumer) clearQueueAndGetNextMessage() trackingMessageID { +func (pc *partitionConsumer) clearQueueAndGetNextMessage() *trackingMessageID { if pc.getConsumerState() != consumerReady { - return trackingMessageID{} + return nil } wg := &sync.WaitGroup{} wg.Add(1) - var msgID trackingMessageID + var msgID *trackingMessageID - pc.clearQueueCh <- func(id trackingMessageID) { + pc.clearQueueCh <- func(id *trackingMessageID) { msgID = id wg.Done() } @@ -1599,16 +1593,16 @@ func (pc *partitionConsumer) clearQueueAndGetNextMessage() trackingMessageID { * Clear the internal receiver queue and returns the message id of what was the 1st message in the queue that was * not seen by the application */ -func (pc *partitionConsumer) clearReceiverQueue() trackingMessageID { +func (pc *partitionConsumer) clearReceiverQueue() *trackingMessageID { nextMessageInQueue := pc.clearQueueAndGetNextMessage() - if pc.startMessageID.get().Undefined() { + if pc.startMessageID.get() == nil { return pc.startMessageID.get() } - if !nextMessageInQueue.Undefined() { + if nextMessageInQueue != nil { return getPreviousMessage(nextMessageInQueue) - } else if !pc.lastDequeuedMsg.Undefined() { + } else if pc.lastDequeuedMsg != nil { // If the queue was empty we need to restart from the message just after the last one that has been dequeued // in the past return 
pc.lastDequeuedMsg @@ -1618,10 +1612,10 @@ func (pc *partitionConsumer) clearReceiverQueue() trackingMessageID { } } -func getPreviousMessage(mid trackingMessageID) trackingMessageID { +func getPreviousMessage(mid *trackingMessageID) *trackingMessageID { if mid.batchIdx >= 0 { - return trackingMessageID{ - messageID: messageID{ + return &trackingMessageID{ + messageID: &messageID{ ledgerID: mid.ledgerID, entryID: mid.entryID, batchIdx: mid.batchIdx - 1, @@ -1634,8 +1628,8 @@ func getPreviousMessage(mid trackingMessageID) trackingMessageID { } // Get on previous message in previous entry - return trackingMessageID{ - messageID: messageID{ + return &trackingMessageID{ + messageID: &messageID{ ledgerID: mid.ledgerID, entryID: mid.entryID - 1, batchIdx: mid.batchIdx, @@ -1734,8 +1728,8 @@ func (pc *partitionConsumer) _getConn() internal.Connection { return pc.conn.Load().(internal.Connection) } -func convertToMessageIDData(msgID trackingMessageID) *pb.MessageIdData { - if msgID.Undefined() { +func convertToMessageIDData(msgID *trackingMessageID) *pb.MessageIdData { + if msgID == nil { return nil } @@ -1745,13 +1739,13 @@ func convertToMessageIDData(msgID trackingMessageID) *pb.MessageIdData { } } -func convertToMessageID(id *pb.MessageIdData) trackingMessageID { +func convertToMessageID(id *pb.MessageIdData) *trackingMessageID { if id == nil { - return trackingMessageID{} + return nil } - msgID := trackingMessageID{ - messageID: messageID{ + msgID := &trackingMessageID{ + messageID: &messageID{ ledgerID: int64(*id.LedgerId), entryID: int64(*id.EntryId), }, @@ -1767,7 +1761,7 @@ type chunkedMsgCtx struct { totalChunks int32 chunkedMsgBuffer internal.Buffer lastChunkedMsgID int32 - chunkedMsgIDs []messageID + chunkedMsgIDs []*messageID receivedTime int64 mu sync.Mutex @@ -1778,12 +1772,12 @@ func newChunkedMsgCtx(numChunksFromMsg int32, totalChunkMsgSize int) *chunkedMsg totalChunks: numChunksFromMsg, chunkedMsgBuffer: internal.NewBuffer(totalChunkMsgSize), 
lastChunkedMsgID: -1, - chunkedMsgIDs: make([]messageID, numChunksFromMsg), + chunkedMsgIDs: make([]*messageID, numChunksFromMsg), receivedTime: time.Now().Unix(), } } -func (c *chunkedMsgCtx) append(chunkID int32, msgID messageID, partPayload internal.Buffer) { +func (c *chunkedMsgCtx) append(chunkID int32, msgID *messageID, partPayload internal.Buffer) { c.mu.Lock() defer c.mu.Unlock() c.chunkedMsgIDs[chunkID] = msgID @@ -1791,20 +1785,20 @@ func (c *chunkedMsgCtx) append(chunkID int32, msgID messageID, partPayload inter c.lastChunkedMsgID = chunkID } -func (c *chunkedMsgCtx) firstChunkID() messageID { +func (c *chunkedMsgCtx) firstChunkID() *messageID { c.mu.Lock() defer c.mu.Unlock() if len(c.chunkedMsgIDs) == 0 { - return messageID{} + return nil } return c.chunkedMsgIDs[0] } -func (c *chunkedMsgCtx) lastChunkID() messageID { +func (c *chunkedMsgCtx) lastChunkID() *messageID { c.mu.Lock() defer c.mu.Unlock() if len(c.chunkedMsgIDs) == 0 { - return messageID{} + return nil } return c.chunkedMsgIDs[len(c.chunkedMsgIDs)-1] } @@ -1814,9 +1808,13 @@ func (c *chunkedMsgCtx) discard(pc *partitionConsumer) { defer c.mu.Unlock() for _, mid := range c.chunkedMsgIDs { + if mid == nil { + continue + } pc.log.Info("Removing chunk message-id", mid.String()) - tmid, _ := toTrackingMessageID(mid) - pc.AckID(tmid) + if tmid := toTrackingMessageID(mid); tmid != nil { + pc.AckID(tmid) + } } } @@ -1935,40 +1933,41 @@ func (c *chunkedMsgCtxMap) Close() { } type unAckChunksTracker struct { - chunkIDs map[chunkMessageID][]messageID + // TODO: use hash code of chunkMessageID as the key + chunkIDs map[chunkMessageID][]*messageID pc *partitionConsumer mu sync.Mutex } func newUnAckChunksTracker(pc *partitionConsumer) *unAckChunksTracker { return &unAckChunksTracker{ - chunkIDs: make(map[chunkMessageID][]messageID), + chunkIDs: make(map[chunkMessageID][]*messageID), pc: pc, } } -func (u *unAckChunksTracker) add(cmid chunkMessageID, ids []messageID) { +func (u *unAckChunksTracker) 
add(cmid *chunkMessageID, ids []*messageID) { u.mu.Lock() defer u.mu.Unlock() - u.chunkIDs[cmid] = ids + u.chunkIDs[*cmid] = ids } -func (u *unAckChunksTracker) get(cmid chunkMessageID) []messageID { +func (u *unAckChunksTracker) get(cmid *chunkMessageID) []*messageID { u.mu.Lock() defer u.mu.Unlock() - return u.chunkIDs[cmid] + return u.chunkIDs[*cmid] } -func (u *unAckChunksTracker) remove(cmid chunkMessageID) { +func (u *unAckChunksTracker) remove(cmid *chunkMessageID) { u.mu.Lock() defer u.mu.Unlock() - delete(u.chunkIDs, cmid) + delete(u.chunkIDs, *cmid) } -func (u *unAckChunksTracker) ack(cmid chunkMessageID) error { +func (u *unAckChunksTracker) ack(cmid *chunkMessageID) error { ids := u.get(cmid) for _, id := range ids { if err := u.pc.AckID(id); err != nil { @@ -1979,7 +1978,7 @@ func (u *unAckChunksTracker) ack(cmid chunkMessageID) error { return nil } -func (u *unAckChunksTracker) nack(cmid chunkMessageID) { +func (u *unAckChunksTracker) nack(cmid *chunkMessageID) { ids := u.get(cmid) for _, id := range ids { u.pc.NackID(id) diff --git a/pulsar/consumer_partition_test.go b/pulsar/consumer_partition_test.go index b9a9a02..16c4399 100644 --- a/pulsar/consumer_partition_test.go +++ b/pulsar/consumer_partition_test.go @@ -48,11 +48,11 @@ func TestSingleMessageIDNoAckTracker(t *testing.T) { // ensure the tracker was set on the message id messages := <-pc.queueCh for _, m := range messages { - assert.Nil(t, m.ID().(trackingMessageID).tracker) + assert.Nil(t, m.ID().(*trackingMessageID).tracker) } // ack the message id - pc.AckID(messages[0].msgID.(trackingMessageID)) + pc.AckID(messages[0].msgID.(*trackingMessageID)) select { case <-eventsCh: @@ -86,11 +86,12 @@ func TestBatchMessageIDNoAckTracker(t *testing.T) { // ensure the tracker was set on the message id messages := <-pc.queueCh for _, m := range messages { - assert.Nil(t, m.ID().(trackingMessageID).tracker) + assert.Nil(t, m.ID().(*trackingMessageID).tracker) } // ack the message id - 
pc.AckID(messages[0].msgID.(trackingMessageID)) + err := pc.AckID(messages[0].msgID.(*trackingMessageID)) + assert.Nil(t, err) select { case <-eventsCh: @@ -120,12 +121,13 @@ func TestBatchMessageIDWithAckTracker(t *testing.T) { // ensure the tracker was set on the message id messages := <-pc.queueCh for _, m := range messages { - assert.NotNil(t, m.ID().(trackingMessageID).tracker) + assert.NotNil(t, m.ID().(*trackingMessageID).tracker) } // ack all message ids except the last one for i := 0; i < 9; i++ { - pc.AckID(messages[i].msgID.(trackingMessageID)) + err := pc.AckID(messages[i].msgID.(*trackingMessageID)) + assert.Nil(t, err) } select { @@ -135,7 +137,8 @@ func TestBatchMessageIDWithAckTracker(t *testing.T) { } // ack last message - pc.AckID(messages[9].msgID.(trackingMessageID)) + err := pc.AckID(messages[9].msgID.(*trackingMessageID)) + assert.Nil(t, err) select { case <-eventsCh: diff --git a/pulsar/consumer_regex.go b/pulsar/consumer_regex.go index d890c67..fdfecec 100644 --- a/pulsar/consumer_regex.go +++ b/pulsar/consumer_regex.go @@ -174,8 +174,8 @@ func (c *regexConsumer) ReconsumeLaterWithCustomProperties(msg Message, customPr // AckID the consumption of a single message, identified by its MessageID func (c *regexConsumer) AckID(msgID MessageID) error { - mid, ok := toTrackingMessageID(msgID) - if !ok { + mid := toTrackingMessageID(msgID) + if mid == nil { c.log.Warnf("invalid message id type %T", msgID) return errors.New("invalid message id type") } @@ -201,8 +201,8 @@ func (c *regexConsumer) AckCumulative(msg Message) error { // AckIDCumulative the reception of all the messages in the stream up to (and including) // the provided message, identified by its MessageID func (c *regexConsumer) AckIDCumulative(msgID MessageID) error { - mid, ok := toTrackingMessageID(msgID) - if !ok { + mid := toTrackingMessageID(msgID) + if mid == nil { c.log.Warnf("invalid message id type %T", msgID) return errors.New("invalid message id type") } @@ -222,8 +222,8 @@ 
func (c *regexConsumer) AckIDCumulative(msgID MessageID) error { func (c *regexConsumer) Nack(msg Message) { if c.options.EnableDefaultNackBackoffPolicy || c.options.NackBackoffPolicy != nil { msgID := msg.ID() - mid, ok := toTrackingMessageID(msgID) - if !ok { + mid := toTrackingMessageID(msgID) + if mid == nil { c.log.Warnf("invalid message id type %T", msgID) return } @@ -240,8 +240,8 @@ func (c *regexConsumer) Nack(msg Message) { } func (c *regexConsumer) NackID(msgID MessageID) { - mid, ok := toTrackingMessageID(msgID) - if !ok { + mid := toTrackingMessageID(msgID) + if mid == nil { c.log.Warnf("invalid message id type %T", msgID) return } diff --git a/pulsar/consumer_test.go b/pulsar/consumer_test.go index de90c0e..21fa7d0 100644 --- a/pulsar/consumer_test.go +++ b/pulsar/consumer_test.go @@ -910,7 +910,7 @@ func TestConsumerNoBatchCumulativeAck(t *testing.T) { if i == N/2-1 { // cumulative acks the first half of messages - consumer.AckCumulative(msg) + assert.Nil(t, consumer.AckCumulative(msg)) } } @@ -931,7 +931,7 @@ func TestConsumerNoBatchCumulativeAck(t *testing.T) { assert.Nil(t, err) assert.Equal(t, fmt.Sprintf("msg-content-%d", i), string(msg.Payload())) - consumer.Ack(msg) + assert.Nil(t, consumer.Ack(msg)) } } diff --git a/pulsar/impl_message.go b/pulsar/impl_message.go index 39db8e1..68ddecd 100644 --- a/pulsar/impl_message.go +++ b/pulsar/impl_message.go @@ -39,7 +39,7 @@ type messageID struct { batchSize int32 } -var latestMessageID = messageID{ +var latestMessageID = &messageID{ ledgerID: math.MaxInt64, entryID: math.MaxInt64, batchIdx: -1, @@ -47,7 +47,7 @@ var latestMessageID = messageID{ batchSize: 0, } -var earliestMessageID = messageID{ +var earliestMessageID = &messageID{ ledgerID: -1, entryID: -1, batchIdx: -1, @@ -56,18 +56,14 @@ var earliestMessageID = messageID{ } type trackingMessageID struct { - messageID + *messageID tracker *ackTracker consumer acker receivedTime time.Time } -func (id trackingMessageID) Undefined() bool { - return 
id == trackingMessageID{} -} - -func (id trackingMessageID) Ack() error { +func (id *trackingMessageID) Ack() error { if id.consumer == nil { return errors.New("consumer is nil in trackingMessageID") } @@ -78,7 +74,7 @@ func (id trackingMessageID) Ack() error { return nil } -func (id trackingMessageID) AckWithResponse() error { +func (id *trackingMessageID) AckWithResponse() error { if id.consumer == nil { return errors.New("consumer is nil in trackingMessageID") } @@ -89,37 +85,37 @@ func (id trackingMessageID) AckWithResponse() error { return nil } -func (id trackingMessageID) Nack() { +func (id *trackingMessageID) Nack() { if id.consumer == nil { return } id.consumer.NackID(id) } -func (id trackingMessageID) NackByMsg(msg Message) { +func (id *trackingMessageID) NackByMsg(msg Message) { if id.consumer == nil { return } id.consumer.NackMsg(msg) } -func (id trackingMessageID) ack() bool { +func (id *trackingMessageID) ack() bool { if id.tracker != nil && id.batchIdx > -1 { return id.tracker.ack(int(id.batchIdx)) } return true } -func (id trackingMessageID) ackCumulative() bool { +func (id *trackingMessageID) ackCumulative() bool { if id.tracker != nil && id.batchIdx > -1 { return id.tracker.ackCumulative(int(id.batchIdx)) } return true } -func (id trackingMessageID) prev() trackingMessageID { - return trackingMessageID{ - messageID: messageID{ +func (id *trackingMessageID) prev() *trackingMessageID { + return &trackingMessageID{ + messageID: &messageID{ ledgerID: id.ledgerID, entryID: id.entryID - 1, partitionIdx: id.partitionIdx, @@ -129,11 +125,11 @@ func (id trackingMessageID) prev() trackingMessageID { } } -func (id messageID) isEntryIDValid() bool { +func (id *messageID) isEntryIDValid() bool { return id.entryID >= 0 } -func (id messageID) greater(other messageID) bool { +func (id *messageID) greater(other *messageID) bool { if id.ledgerID != other.ledgerID { return id.ledgerID > other.ledgerID } @@ -145,17 +141,17 @@ func (id messageID) greater(other 
messageID) bool { return id.batchIdx > other.batchIdx } -func (id messageID) equal(other messageID) bool { +func (id *messageID) equal(other *messageID) bool { return id.ledgerID == other.ledgerID && id.entryID == other.entryID && id.batchIdx == other.batchIdx } -func (id messageID) greaterEqual(other messageID) bool { +func (id *messageID) greaterEqual(other *messageID) bool { return id.equal(other) || id.greater(other) } -func (id messageID) Serialize() []byte { +func (id *messageID) Serialize() []byte { msgID := &pb.MessageIdData{ LedgerId: proto.Uint64(uint64(id.ledgerID)), EntryId: proto.Uint64(uint64(id.entryID)), @@ -167,27 +163,27 @@ func (id messageID) Serialize() []byte { return data } -func (id messageID) LedgerID() int64 { +func (id *messageID) LedgerID() int64 { return id.ledgerID } -func (id messageID) EntryID() int64 { +func (id *messageID) EntryID() int64 { return id.entryID } -func (id messageID) BatchIdx() int32 { +func (id *messageID) BatchIdx() int32 { return id.batchIdx } -func (id messageID) PartitionIdx() int32 { +func (id *messageID) PartitionIdx() int32 { return id.partitionIdx } -func (id messageID) BatchSize() int32 { +func (id *messageID) BatchSize() int32 { return id.batchSize } -func (id messageID) String() string { +func (id *messageID) String() string { return fmt.Sprintf("%d:%d:%d", id.ledgerID, id.entryID, id.partitionIdx) } @@ -208,7 +204,7 @@ func deserializeMessageID(data []byte) (MessageID, error) { } func newMessageID(ledgerID int64, entryID int64, batchIdx int32, partitionIdx int32, batchSize int32) MessageID { - return messageID{ + return &messageID{ ledgerID: ledgerID, entryID: entryID, batchIdx: batchIdx, @@ -218,9 +214,9 @@ func newMessageID(ledgerID int64, entryID int64, batchIdx int32, partitionIdx in } func newTrackingMessageID(ledgerID int64, entryID int64, batchIdx int32, partitionIdx int32, batchSize int32, - tracker *ackTracker) trackingMessageID { - return trackingMessageID{ - messageID: messageID{ + tracker 
*ackTracker) *trackingMessageID { + return &trackingMessageID{ + messageID: &messageID{ ledgerID: ledgerID, entryID: entryID, batchIdx: batchIdx, @@ -232,31 +228,23 @@ func newTrackingMessageID(ledgerID int64, entryID int64, batchIdx int32, partiti } } -func toTrackingMessageID(msgID MessageID) (trackingMessageID, bool) { - if mid, ok := msgID.(messageID); ok { - return trackingMessageID{ +func toTrackingMessageID(msgID MessageID) *trackingMessageID { + if mid, ok := msgID.(*messageID); ok { + return &trackingMessageID{ messageID: mid, receivedTime: time.Now(), - }, true - } else if mid, ok := msgID.(trackingMessageID); ok { - return mid, true - } else if cmid, ok := msgID.(chunkMessageID); ok { - return trackingMessageID{ + } + } else if mid, ok := msgID.(*trackingMessageID); ok { + return mid + } else if cmid, ok := msgID.(*chunkMessageID); ok { + return &trackingMessageID{ messageID: cmid.messageID, receivedTime: cmid.receivedTime, consumer: cmid.consumer, - }, true + } } else { - return trackingMessageID{}, false - } -} - -func toChunkedMessageID(msgID MessageID) (chunkMessageID, bool) { - cid, ok := msgID.(chunkMessageID) - if ok { - return cid, true + return nil } - return chunkMessageID{}, false } func timeFromUnixTimestampMillis(timestamp uint64) time.Time { @@ -449,16 +437,16 @@ func (t *ackTracker) toAckSet() []int64 { } type chunkMessageID struct { - messageID + *messageID - firstChunkID messageID + firstChunkID *messageID receivedTime time.Time consumer acker } -func newChunkMessageID(firstChunkID messageID, lastChunkID messageID) chunkMessageID { - return chunkMessageID{ +func newChunkMessageID(firstChunkID *messageID, lastChunkID *messageID) *chunkMessageID { + return &chunkMessageID{ messageID: lastChunkID, firstChunkID: firstChunkID, receivedTime: time.Now(), diff --git a/pulsar/impl_message_bench_test.go b/pulsar/impl_message_bench_test.go deleted file mode 100644 index 4b6ca10..0000000 --- a/pulsar/impl_message_bench_test.go +++ /dev/null @@ -1,49 
+0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package pulsar - -import ( - "testing" -) - -var ( - usedByProducer messageID - usedByConsumer trackingMessageID -) - -func producerCall(id messageID) messageID { - id.entryID++ - return id -} - -func consumerCall(id trackingMessageID) trackingMessageID { - id.entryID++ - return id -} - -func BenchmarkProducerCall(b *testing.B) { - for i := 0; i < b.N; i++ { - usedByProducer = producerCall(usedByProducer) - } -} - -func BenchmarkConsumerCall(b *testing.B) { - for i := 0; i < b.N; i++ { - usedByConsumer = consumerCall(usedByConsumer) - } -} diff --git a/pulsar/impl_message_test.go b/pulsar/impl_message_test.go index 413a39f..6a21171 100644 --- a/pulsar/impl_message_test.go +++ b/pulsar/impl_message_test.go @@ -31,11 +31,11 @@ func TestMessageId(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, id2) - assert.Equal(t, int64(1), id2.(messageID).ledgerID) - assert.Equal(t, int64(2), id2.(messageID).entryID) - assert.Equal(t, int32(3), id2.(messageID).batchIdx) - assert.Equal(t, int32(4), id2.(messageID).partitionIdx) - assert.Equal(t, int32(5), id2.(messageID).batchSize) + assert.Equal(t, int64(1), id2.(*messageID).ledgerID) + assert.Equal(t, int64(2), 
id2.(*messageID).entryID) + assert.Equal(t, int32(3), id2.(*messageID).batchIdx) + assert.Equal(t, int32(4), id2.(*messageID).partitionIdx) + assert.Equal(t, int32(5), id2.(*messageID).batchSize) id, err = DeserializeMessageID(nil) assert.Error(t, err) @@ -110,7 +110,7 @@ func TestAckingMessageIDBatchOne(t *testing.T) { func TestAckingMessageIDBatchTwo(t *testing.T) { tracker := newAckTracker(2) - ids := []trackingMessageID{ + ids := []*trackingMessageID{ newTrackingMessageID(1, 1, 0, 0, 0, tracker), newTrackingMessageID(1, 1, 1, 0, 0, tracker), } @@ -121,7 +121,7 @@ func TestAckingMessageIDBatchTwo(t *testing.T) { // try reverse order tracker = newAckTracker(2) - ids = []trackingMessageID{ + ids = []*trackingMessageID{ newTrackingMessageID(1, 1, 0, 0, 0, tracker), newTrackingMessageID(1, 1, 1, 0, 0, tracker), } diff --git a/pulsar/message_chunking_test.go b/pulsar/message_chunking_test.go index b3d64af..aac87c7 100644 --- a/pulsar/message_chunking_test.go +++ b/pulsar/message_chunking_test.go @@ -454,10 +454,10 @@ func TestChunkSize(t *testing.T) { }) assert.NoError(t, err) if size <= payloadChunkSize { - _, ok := msgID.(messageID) + _, ok := msgID.(*messageID) assert.Equal(t, true, ok) } else { - _, ok := msgID.(chunkMessageID) + _, ok := msgID.(*chunkMessageID) assert.Equal(t, true, ok) } } diff --git a/pulsar/negative_acks_tracker.go b/pulsar/negative_acks_tracker.go index 79ed694..58f5676 100644 --- a/pulsar/negative_acks_tracker.go +++ b/pulsar/negative_acks_tracker.go @@ -65,7 +65,7 @@ func newNegativeAcksTracker(rc redeliveryConsumer, delay time.Duration, return t } -func (t *negativeAcksTracker) Add(msgID messageID) { +func (t *negativeAcksTracker) Add(msgID *messageID) { // Always clear up the batch index since we want to track the nack // for the entire batch batchMsgID := messageID{ diff --git a/pulsar/negative_acks_tracker_test.go b/pulsar/negative_acks_tracker_test.go index 5faa947..12d33de 100644 --- a/pulsar/negative_acks_tracker_test.go +++ 
b/pulsar/negative_acks_tracker_test.go @@ -81,13 +81,13 @@ func TestNacksTracker(t *testing.T) { nmc := newNackMockedConsumer(nil) nacks := newNegativeAcksTracker(nmc, testNackDelay, nil, log.DefaultNopLogger()) - nacks.Add(messageID{ + nacks.Add(&messageID{ ledgerID: 1, entryID: 1, batchIdx: 1, }) - nacks.Add(messageID{ + nacks.Add(&messageID{ ledgerID: 2, entryID: 2, batchIdx: 1, @@ -114,25 +114,25 @@ func TestNacksWithBatchesTracker(t *testing.T) { nmc := newNackMockedConsumer(nil) nacks := newNegativeAcksTracker(nmc, testNackDelay, nil, log.DefaultNopLogger()) - nacks.Add(messageID{ + nacks.Add(&messageID{ ledgerID: 1, entryID: 1, batchIdx: 1, }) - nacks.Add(messageID{ + nacks.Add(&messageID{ ledgerID: 1, entryID: 1, batchIdx: 2, }) - nacks.Add(messageID{ + nacks.Add(&messageID{ ledgerID: 1, entryID: 1, batchIdx: 3, }) - nacks.Add(messageID{ + nacks.Add(&messageID{ ledgerID: 2, entryID: 2, batchIdx: 1, @@ -194,7 +194,7 @@ func (msg *mockMessage1) Payload() []byte { } func (msg *mockMessage1) ID() MessageID { - return messageID{ + return &messageID{ ledgerID: 1, entryID: 1, batchIdx: 1, @@ -270,7 +270,7 @@ func (msg *mockMessage2) Payload() []byte { } func (msg *mockMessage2) ID() MessageID { - return messageID{ + return &messageID{ ledgerID: 2, entryID: 2, batchIdx: 1, diff --git a/pulsar/producer_partition.go b/pulsar/producer_partition.go index c3a0aa9..744df79 100644 --- a/pulsar/producer_partition.go +++ b/pulsar/producer_partition.go @@ -1163,7 +1163,7 @@ func (p *partitionProducer) ReceivedSendReceipt(response *pb.CommandSendReceipt) if sr.totalChunks > 1 { if sr.chunkID == 0 { sr.chunkRecorder.setFirstChunkID( - messageID{ + &messageID{ int64(response.MessageId.GetLedgerId()), int64(response.MessageId.GetEntryId()), -1, @@ -1172,7 +1172,7 @@ func (p *partitionProducer) ReceivedSendReceipt(response *pb.CommandSendReceipt) }) } else if sr.chunkID == sr.totalChunks-1 { sr.chunkRecorder.setLastChunkID( - messageID{ + &messageID{ 
int64(response.MessageId.GetLedgerId()), int64(response.MessageId.GetEntryId()), -1, @@ -1180,7 +1180,7 @@ func (p *partitionProducer) ReceivedSendReceipt(response *pb.CommandSendReceipt) 0, }) // use chunkMsgID to set msgID - msgID = sr.chunkRecorder.chunkedMsgID + msgID = &sr.chunkRecorder.chunkedMsgID } } @@ -1375,11 +1375,11 @@ func newChunkRecorder() *chunkRecorder { } } -func (c *chunkRecorder) setFirstChunkID(msgID messageID) { +func (c *chunkRecorder) setFirstChunkID(msgID *messageID) { c.chunkedMsgID.firstChunkID = msgID } -func (c *chunkRecorder) setLastChunkID(msgID messageID) { +func (c *chunkRecorder) setLastChunkID(msgID *messageID) { c.chunkedMsgID.messageID = msgID } diff --git a/pulsar/producer_test.go b/pulsar/producer_test.go index f86d01a..a6f5e39 100644 --- a/pulsar/producer_test.go +++ b/pulsar/producer_test.go @@ -325,7 +325,7 @@ func TestFlushInProducer(t *testing.T) { assert.Nil(t, err) msgCount++ - msgID := msg.ID().(trackingMessageID) + msgID := msg.ID().(*trackingMessageID) // Since messages are batched, they will be sharing the same ledgerId/entryId if ledgerID == -1 { ledgerID = msgID.ledgerID @@ -742,7 +742,7 @@ func TestBatchDelayMessage(t *testing.T) { var delayMsgID int64 ch := make(chan struct{}, 2) producer.SendAsync(ctx, delayMsg, func(id MessageID, producerMessage *ProducerMessage, err error) { - atomic.StoreInt64(&delayMsgID, id.(messageID).entryID) + atomic.StoreInt64(&delayMsgID, id.(*messageID).entryID) ch <- struct{}{} }) delayMsgPublished := false @@ -758,13 +758,13 @@ func TestBatchDelayMessage(t *testing.T) { } var noDelayMsgID int64 producer.SendAsync(ctx, noDelayMsg, func(id MessageID, producerMessage *ProducerMessage, err error) { - atomic.StoreInt64(&noDelayMsgID, id.(messageID).entryID) + atomic.StoreInt64(&noDelayMsgID, id.(*messageID).entryID) }) for i := 0; i < 2; i++ { msg, err := consumer.Receive(context.Background()) assert.Nil(t, err, "unexpected error occurred when recving message from topic") - switch 
msg.ID().(trackingMessageID).entryID { + switch msg.ID().(*trackingMessageID).entryID { case atomic.LoadInt64(&noDelayMsgID): assert.LessOrEqual(t, time.Since(msg.PublishTime()).Nanoseconds(), int64(batchingDelay*2)) case atomic.LoadInt64(&delayMsgID): diff --git a/pulsar/reader_impl.go b/pulsar/reader_impl.go index 079754b..68dd084 100644 --- a/pulsar/reader_impl.go +++ b/pulsar/reader_impl.go @@ -37,7 +37,7 @@ type reader struct { client *client pc *partitionConsumer messageCh chan ConsumerMessage - lastMessageInBroker trackingMessageID + lastMessageInBroker *trackingMessageID log log.Logger metrics *internal.LeveledMetrics } @@ -51,8 +51,8 @@ func newReader(client *client, options ReaderOptions) (Reader, error) { return nil, newError(InvalidConfiguration, "StartMessageID is required") } - startMessageID, ok := toTrackingMessageID(options.StartMessageID) - if !ok { + startMessageID := toTrackingMessageID(options.StartMessageID) + if startMessageID == nil { // a custom type satisfying MessageID may not be a messageID or trackingMessageID // so re-create messageID using its data deserMsgID, err := deserializeMessageID(options.StartMessageID.Serialize()) @@ -60,8 +60,8 @@ func newReader(client *client, options ReaderOptions) (Reader, error) { return nil, err } // de-serialized MessageID is a messageID - startMessageID = trackingMessageID{ - messageID: deserMsgID.(messageID), + startMessageID = &trackingMessageID{ + messageID: deserMsgID.(*messageID), receivedTime: time.Now(), } } @@ -148,7 +148,7 @@ func (r *reader) Next(ctx context.Context) (Message, error) { // Acknowledge message immediately because the reader is based on non-durable subscription. 
When it reconnects, // it will specify the subscription position anyway msgID := cm.Message.ID() - if mid, ok := toTrackingMessageID(msgID); ok { + if mid := toTrackingMessageID(msgID); mid != nil { r.pc.lastDequeuedMsg = mid r.pc.AckID(mid) return cm.Message, nil @@ -161,7 +161,7 @@ func (r *reader) Next(ctx context.Context) (Message, error) { } func (r *reader) HasNext() bool { - if !r.lastMessageInBroker.Undefined() && r.hasMoreMessages() { + if r.lastMessageInBroker != nil && r.hasMoreMessages() { return true } @@ -180,7 +180,7 @@ func (r *reader) HasNext() bool { } func (r *reader) hasMoreMessages() bool { - if !r.pc.lastDequeuedMsg.Undefined() { + if r.pc.lastDequeuedMsg != nil { return r.lastMessageInBroker.isEntryIDValid() && r.lastMessageInBroker.greater(r.pc.lastDequeuedMsg.messageID) } @@ -200,29 +200,29 @@ func (r *reader) Close() { r.metrics.ReadersClosed.Inc() } -func (r *reader) messageID(msgID MessageID) (trackingMessageID, bool) { - mid, ok := toTrackingMessageID(msgID) - if !ok { +func (r *reader) messageID(msgID MessageID) *trackingMessageID { + mid := toTrackingMessageID(msgID) + if mid == nil { r.log.Warnf("invalid message id type %T", msgID) - return trackingMessageID{}, false + return nil } partition := int(mid.partitionIdx) // did we receive a valid partition index? if partition < 0 { r.log.Warnf("invalid partition index %d expected", partition) - return trackingMessageID{}, false + return nil } - return mid, true + return mid } func (r *reader) Seek(msgID MessageID) error { r.Lock() defer r.Unlock() - mid, ok := r.messageID(msgID) - if !ok { + mid := r.messageID(msgID) + if mid == nil { return nil }