This is an automated email from the ASF dual-hosted git repository.

kdoran pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/nifi.git


The following commit(s) were added to refs/heads/main by this push:
     new bbb2b152816 NIFI-15614 - ConsumeKafka - Duplicate messages during 
consumer group rebalance (#10908)
bbb2b152816 is described below

commit bbb2b152816f1e9310b83d405ad62bff9865e344
Author: Pierre Villard <[email protected]>
AuthorDate: Tue Apr 7 02:15:40 2026 +0200

    NIFI-15614 - ConsumeKafka - Duplicate messages during consumer group 
rebalance (#10908)
    
    Signed-off-by: Kevin Doran <[email protected]>
---
 .../kafka/processors/ConsumeKafkaRebalanceIT.java  | 330 +++++++++++++++++++++
 .../apache/nifi/kafka/processors/ConsumeKafka.java | 177 +++++++----
 .../kafka/processors/consumer/OffsetTracker.java   |   6 +
 .../service/api/consumer/KafkaConsumerService.java |  40 +++
 .../service/api/consumer/RebalanceCallback.java    |  58 ++++
 .../kafka/service/api/consumer/SessionContext.java |  29 ++
 .../service/consumer/Kafka3ConsumerService.java    |  79 ++++-
 7 files changed, 653 insertions(+), 66 deletions(-)

diff --git 
a/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-3-integration/src/test/java/org/apache/nifi/kafka/processors/ConsumeKafkaRebalanceIT.java
 
b/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-3-integration/src/test/java/org/apache/nifi/kafka/processors/ConsumeKafkaRebalanceIT.java
index f37c52523cf..e81b3a8c179 100644
--- 
a/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-3-integration/src/test/java/org/apache/nifi/kafka/processors/ConsumeKafkaRebalanceIT.java
+++ 
b/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-3-integration/src/test/java/org/apache/nifi/kafka/processors/ConsumeKafkaRebalanceIT.java
@@ -30,12 +30,17 @@ import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.serialization.ByteArrayDeserializer;
 import org.apache.kafka.common.serialization.StringSerializer;
 import org.apache.nifi.kafka.service.api.consumer.AutoOffsetReset;
+import org.apache.nifi.kafka.service.api.consumer.RebalanceCallback;
+import org.apache.nifi.kafka.service.api.consumer.SessionContext;
 import org.apache.nifi.kafka.service.api.record.ByteRecord;
 import org.apache.nifi.kafka.service.consumer.Kafka3ConsumerService;
 import org.apache.nifi.kafka.service.consumer.Subscription;
 import org.apache.nifi.logging.ComponentLog;
+import org.junit.jupiter.api.Assumptions;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.Timeout;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.time.Duration;
 import java.util.Collections;
@@ -45,9 +50,14 @@ import java.util.Map;
 import java.util.Properties;
 import java.util.Set;
 import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -60,6 +70,8 @@ import static org.mockito.Mockito.mock;
  */
 class ConsumeKafkaRebalanceIT extends AbstractConsumeKafkaIT {
 
+    private static final Logger logger = 
LoggerFactory.getLogger(ConsumeKafkaRebalanceIT.class);
+
     private static final int NUM_PARTITIONS = 3;
     private static final int MESSAGES_PER_PARTITION = 20;
 
@@ -298,6 +310,313 @@ class ConsumeKafkaRebalanceIT extends 
AbstractConsumeKafkaIT {
                 " records since processing was never completed.");
     }
 
+    /**
+     * Tests that a REAL Kafka rebalance (triggered by a second consumer 
joining) does not cause
+     * duplicate message processing.
+     *
+     * This test reproduces the real-world scenario where:
+     * 1. Consumer 1 is actively polling and processing messages (with slow 
processing)
+     * 2. Consumer 2 joins the same group, triggering a Kafka rebalance
+     * 3. During Consumer 1's poll(), onPartitionsRevoked() is called 
internally by Kafka
+     * 4. The RebalanceCallback is invoked, allowing the processor to commit 
its session
+     * 5. Kafka offsets are committed synchronously while still in 
onPartitionsRevoked()
+     * 6. Rebalance completes successfully with no duplicates
+     *
+     * The fix commits offsets INSIDE the onPartitionsRevoked() callback, 
which is the
+     * only time when the consumer is still in a valid state to commit. This 
is similar to how
+     * NiFi 1.x handled rebalances in ConsumerLease.
+     */
+    @Test
+    @Timeout(value = 120, unit = TimeUnit.SECONDS)
+    void testRealRebalanceDoesNotCauseDuplicates() throws Exception {
+        final String topic = "real-rebalance-test-" + UUID.randomUUID();
+        final String groupId = "real-rebalance-group-" + UUID.randomUUID();
+        final int numPartitions = 6;
+        final int messagesPerPartition = 500; // More messages to ensure 
overlap
+        final int totalMessages = numPartitions * messagesPerPartition;
+
+        createTopic(topic, numPartitions);
+        produceMessagesToTopic(topic, numPartitions, messagesPerPartition);
+
+        // Track all consumed message IDs across both consumers
+        final Set<String> allConsumedMessages = ConcurrentHashMap.newKeySet();
+        final AtomicInteger duplicateCount = new AtomicInteger(0);
+        final AtomicInteger rebalanceCount = new AtomicInteger(0);
+        final CountDownLatch consumer1Started = new CountDownLatch(1);
+        final CountDownLatch consumer2Started = new CountDownLatch(1);
+        final CountDownLatch testComplete = new CountDownLatch(2);
+        final AtomicInteger consumer1Count = new AtomicInteger(0);
+        final AtomicInteger consumer2Count = new AtomicInteger(0);
+
+        final ComponentLog mockLog = mock(ComponentLog.class);
+        final ExecutorService executor = Executors.newFixedThreadPool(2);
+
+        // Rebalance callback that simulates processor committing its session
+        // In a real processor, this would commit FlowFiles; here we just log 
and allow the commit
+        final RebalanceCallback callback = (revokedPartitions, context) -> {
+            rebalanceCount.incrementAndGet();
+            logger.info("Rebalance callback invoked for partitions: {}", 
revokedPartitions);
+        };
+
+        try {
+            // Consumer 1: Start consuming with simulated slow processing
+            executor.submit(() -> {
+                final Properties props1 = getConsumerProperties(groupId);
+                // Fetch fewer records per poll to slow down consumption
+                props1.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, "10");
+                try (KafkaConsumer<byte[], byte[]> kafkaConsumer1 = new 
KafkaConsumer<>(props1)) {
+                    final Subscription subscription = new 
Subscription(groupId, Collections.singletonList(topic), 
AutoOffsetReset.EARLIEST);
+                    // Use the constructor with callback to enable synchronous 
commit during rebalance
+                    final Kafka3ConsumerService service1 = new 
Kafka3ConsumerService(mockLog, kafkaConsumer1, subscription, callback);
+                    consumer1Started.countDown();
+
+                    int emptyPolls = 0;
+                    while (emptyPolls < 15 && allConsumedMessages.size() < 
totalMessages) {
+                        boolean hasRecords = false;
+                        for (ByteRecord record : 
service1.poll(Duration.ofSeconds(1))) {
+                            hasRecords = true;
+                            final String messageId = record.getTopic() + "-" + 
record.getPartition() + "-" + record.getOffset();
+                            if (!allConsumedMessages.add(messageId)) {
+                                duplicateCount.incrementAndGet();
+                            }
+                            consumer1Count.incrementAndGet();
+                        }
+
+                        if (hasRecords) {
+                            emptyPolls = 0;
+                            // Simulate slow processing
+                            Thread.sleep(50);
+                        } else {
+                            emptyPolls++;
+                        }
+                    }
+                    service1.close();
+                } catch (Exception e) {
+                    logger.error("Consumer 1 error", e);
+                } finally {
+                    testComplete.countDown();
+                }
+            });
+
+            // Wait for consumer 1 to start
+            assertTrue(consumer1Started.await(30, TimeUnit.SECONDS), "Consumer 
1 did not start");
+
+            // Wait a bit then start consumer 2 to trigger rebalance while 
consumer 1 is actively consuming
+            Thread.sleep(200);
+
+            // Consumer 2: Join the group to trigger rebalance
+            executor.submit(() -> {
+                final Properties props2 = getConsumerProperties(groupId);
+                props2.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, "10");
+                try (KafkaConsumer<byte[], byte[]> kafkaConsumer2 = new 
KafkaConsumer<>(props2)) {
+                    final Subscription subscription = new 
Subscription(groupId, Collections.singletonList(topic), 
AutoOffsetReset.EARLIEST);
+                    // Use the constructor with callback to enable synchronous 
commit during rebalance
+                    final Kafka3ConsumerService service2 = new 
Kafka3ConsumerService(mockLog, kafkaConsumer2, subscription, callback);
+                    consumer2Started.countDown();
+
+                    int emptyPolls = 0;
+                    while (emptyPolls < 15 && allConsumedMessages.size() < 
totalMessages) {
+                        boolean hasRecords = false;
+                        for (ByteRecord record : 
service2.poll(Duration.ofSeconds(1))) {
+                            hasRecords = true;
+                            final String messageId = record.getTopic() + "-" + 
record.getPartition() + "-" + record.getOffset();
+                            if (!allConsumedMessages.add(messageId)) {
+                                duplicateCount.incrementAndGet();
+                            }
+                            consumer2Count.incrementAndGet();
+                        }
+
+                        if (hasRecords) {
+                            emptyPolls = 0;
+                            Thread.sleep(50);
+                        } else {
+                            emptyPolls++;
+                        }
+                    }
+                    service2.close();
+                } catch (Exception e) {
+                    logger.error("Consumer 2 error", e);
+                } finally {
+                    testComplete.countDown();
+                }
+            });
+
+            // Wait for consumer 2 to start (confirms rebalance was triggered)
+            assertTrue(consumer2Started.await(30, TimeUnit.SECONDS), "Consumer 
2 did not start");
+
+            // Wait for both consumers to finish
+            assertTrue(testComplete.await(90, TimeUnit.SECONDS), "Test did not 
complete in time");
+
+        } finally {
+            executor.shutdownNow();
+        }
+
+        // Log results for debugging
+        logger.info("Consumer 1 polled: {} records", consumer1Count.get());
+        logger.info("Consumer 2 polled: {} records", consumer2Count.get());
+        logger.info("Total unique messages: {}", allConsumedMessages.size());
+        logger.info("Duplicate count: {}", duplicateCount.get());
+        logger.info("Rebalance count: {}", rebalanceCount.get());
+
+        // Verify both consumers participated (rebalance occurred)
+        assertTrue(consumer2Count.get() > 0,
+                "Consumer 2 should have consumed some records after rebalance, 
but got " + consumer2Count.get());
+
+        // Verify no duplicates occurred
+        assertEquals(0, duplicateCount.get(),
+                "Duplicate messages detected during rebalance! " + 
duplicateCount.get() + " duplicates found. " +
+                "Consumer 1 polled " + consumer1Count.get() + " records, " +
+                "Consumer 2 polled " + consumer2Count.get() + " records, " +
+                "but only " + allConsumedMessages.size() + " unique 
messages.");
+
+        // Verify all messages were consumed
+        assertEquals(totalMessages, allConsumedMessages.size(),
+                "Expected to consume " + totalMessages + " unique messages but 
got " + allConsumedMessages.size());
+    }
+
+    /**
+     * Tests that the per-service session context ensures thread safety during 
rebalances.
+     *
+     * With multiple concurrent tasks, each consumer service has its own 
session context.
+     * This test verifies that when a rebalance callback fires, it receives 
the correct
+     * session context from its own service, not from another concurrent task.
+     *
+     * This test:
+     * 1. Creates two consumer services in the same group
+     * 2. Each service sets its own session context ("session-1" and 
"session-2")
+     * 3. When a rebalance occurs, each callback receives its own service's 
session context
+     * 4. Verifies that consumer 1's callback sees "session-1" (not 
"session-2")
+     */
+    @Test
+    @Timeout(value = 60, unit = TimeUnit.SECONDS)
+    void testPerServiceSessionContextEnsuresThreadSafety() throws Exception {
+        final String topic = "session-context-test-" + UUID.randomUUID();
+        final String groupId = "session-context-group-" + UUID.randomUUID();
+
+        createTopic(topic, 2);
+        produceMessagesToTopic(topic, 2, 10);
+
+        final ComponentLog mockLog = mock(ComponentLog.class);
+
+        // Track what session ID was seen in each callback
+        final AtomicReference<String> sessionSeenInCallback1 = new 
AtomicReference<>();
+        final AtomicReference<String> sessionSeenInCallback2 = new 
AtomicReference<>();
+
+        // Latches to coordinate thread timing
+        final CountDownLatch thread1SetContext = new CountDownLatch(1);
+        final CountDownLatch thread2SetContext = new CountDownLatch(1);
+        final CountDownLatch thread1RebalanceTriggered = new CountDownLatch(1);
+        final CountDownLatch testComplete = new CountDownLatch(2);
+
+        // Callback for consumer 1 - reads from sessionContext parameter 
(per-service)
+        final RebalanceCallback callback1 = (revokedPartitions, context) -> {
+            final String sessionId = context != null ? ((TestSessionContext) 
context).sessionId : null;
+            sessionSeenInCallback1.set(sessionId);
+            logger.info("Consumer 1 callback fired, saw session context: {}", 
sessionId);
+            thread1RebalanceTriggered.countDown();
+        };
+
+        // Callback for consumer 2 - also reads from sessionContext parameter
+        final RebalanceCallback callback2 = (revokedPartitions, context) -> {
+            final String sessionId = context != null ? ((TestSessionContext) 
context).sessionId : null;
+            sessionSeenInCallback2.set(sessionId);
+            logger.info("Consumer 2 callback fired, saw session context: {}", 
sessionId);
+        };
+
+        final ExecutorService executor = Executors.newFixedThreadPool(2);
+
+        try {
+            // Thread 1: Consumer 1 sets its own session context to "session-1"
+            executor.submit(() -> {
+                try {
+                    final Properties consumerProps = 
getConsumerProperties(groupId);
+                    try (KafkaConsumer<byte[], byte[]> kafkaConsumer = new 
KafkaConsumer<>(consumerProps)) {
+                        final Subscription subscription = new 
Subscription(groupId, Collections.singletonList(topic), 
AutoOffsetReset.EARLIEST);
+                        final Kafka3ConsumerService service1 = new 
Kafka3ConsumerService(mockLog, kafkaConsumer, subscription, callback1);
+
+                        // Set session context on THIS service (not a shared 
holder)
+                        service1.setSessionContext(new 
TestSessionContext("session-1"));
+                        logger.info("Thread 1 set service1 context to: 
session-1");
+                        thread1SetContext.countDown();
+
+                        // Wait for Thread 2 to set its context
+                        assertTrue(thread2SetContext.await(30, 
TimeUnit.SECONDS), "Thread 2 did not set context");
+
+                        // Poll to consume some records - this might trigger a 
rebalance
+                        // when consumer 2 joins the group
+                        for (int i = 0; i < 10; i++) {
+                            final Iterator<ByteRecord> records = 
service1.poll(Duration.ofMillis(500)).iterator();
+                            while (records.hasNext()) {
+                                records.next();
+                            }
+                            // Check if rebalance callback was triggered
+                            if (sessionSeenInCallback1.get() != null) {
+                                break;
+                            }
+                        }
+
+                        service1.close();
+                    }
+                } catch (Exception e) {
+                    logger.error("Thread 1 error", e);
+                } finally {
+                    testComplete.countDown();
+                }
+            });
+
+            // Thread 2: Sets its own session context to "session-2"
+            executor.submit(() -> {
+                try {
+                    // Wait for Thread 1 to set its context
+                    assertTrue(thread1SetContext.await(30, TimeUnit.SECONDS), 
"Thread 1 did not set context");
+
+                    final Properties consumerProps = 
getConsumerProperties(groupId);
+                    try (KafkaConsumer<byte[], byte[]> kafkaConsumer = new 
KafkaConsumer<>(consumerProps)) {
+                        final Subscription subscription = new 
Subscription(groupId, Collections.singletonList(topic), 
AutoOffsetReset.EARLIEST);
+                        final Kafka3ConsumerService service2 = new 
Kafka3ConsumerService(mockLog, kafkaConsumer, subscription, callback2);
+
+                        // Set session context on THIS service (different from 
service1)
+                        service2.setSessionContext(new 
TestSessionContext("session-2"));
+                        logger.info("Thread 2 set service2 context to: 
session-2");
+                        thread2SetContext.countDown();
+
+                        // Poll to trigger rebalance (joining the group)
+                        for (int i = 0; i < 10; i++) {
+                            final Iterator<ByteRecord> records = 
service2.poll(Duration.ofMillis(500)).iterator();
+                            while (records.hasNext()) {
+                                records.next();
+                            }
+                        }
+
+                        service2.close();
+                    }
+                } catch (Exception e) {
+                    logger.error("Thread 2 error", e);
+                } finally {
+                    testComplete.countDown();
+                }
+            });
+
+            // Wait for test to complete
+            assertTrue(testComplete.await(45, TimeUnit.SECONDS), "Test did not 
complete in time");
+
+        } finally {
+            executor.shutdownNow();
+        }
+
+        // Skip the test if no rebalance occurred - we can't verify the 
thread-safety property without it
+        Assumptions.assumeTrue(sessionSeenInCallback1.get() != null,
+                "No rebalance occurred for consumer 1 - test is inconclusive");
+
+        logger.info("Consumer 1 callback saw session: {}, expected: 
session-1", sessionSeenInCallback1.get());
+
+        // With per-service session context, consumer 1's callback should see 
"session-1"
+        // (its own session context), not "session-2" (consumer 2's session 
context)
+        assertEquals("session-1", sessionSeenInCallback1.get(),
+                "Per-service session context failed! Consumer 1's rebalance 
callback saw the wrong session. " +
+                "Expected 'session-1' but got '" + 
sessionSeenInCallback1.get() + "'.");
+    }
+
     /**
      * Produces messages to a specific topic with a given number of partitions.
      */
@@ -364,4 +683,15 @@ class ConsumeKafkaRebalanceIT extends 
AbstractConsumeKafkaIT {
         props.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "3000");
         return props;
     }
+
+    /**
+     * Simple test implementation of SessionContext for thread-safety testing.
+     */
+    private static class TestSessionContext implements SessionContext {
+        final String sessionId;
+
+        TestSessionContext(final String sessionId) {
+            this.sessionId = sessionId;
+        }
+    }
 }
diff --git 
a/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-processors/src/main/java/org/apache/nifi/kafka/processors/ConsumeKafka.java
 
b/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-processors/src/main/java/org/apache/nifi/kafka/processors/ConsumeKafka.java
index bd79089542c..9a50eb6060c 100644
--- 
a/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-processors/src/main/java/org/apache/nifi/kafka/processors/ConsumeKafka.java
+++ 
b/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-processors/src/main/java/org/apache/nifi/kafka/processors/ConsumeKafka.java
@@ -46,6 +46,8 @@ import 
org.apache.nifi.kafka.service.api.common.PartitionState;
 import org.apache.nifi.kafka.service.api.consumer.AutoOffsetReset;
 import org.apache.nifi.kafka.service.api.consumer.KafkaConsumerService;
 import org.apache.nifi.kafka.service.api.consumer.PollingContext;
+import org.apache.nifi.kafka.service.api.consumer.RebalanceCallback;
+import org.apache.nifi.kafka.service.api.consumer.SessionContext;
 import org.apache.nifi.kafka.service.api.record.ByteRecord;
 import org.apache.nifi.kafka.shared.attribute.KafkaFlowFileAttribute;
 import org.apache.nifi.kafka.shared.property.KeyEncoding;
@@ -430,77 +432,87 @@ public class ConsumeKafka extends AbstractProcessor 
implements VerifiableProcess
         final OffsetTracker offsetTracker = new OffsetTracker();
         boolean recordsReceived = false;
 
-        while (System.currentTimeMillis() < stopTime) {
-            try {
-                final Duration maxWaitDuration = Duration.ofMillis(stopTime - 
System.currentTimeMillis());
-                if (maxWaitDuration.toMillis() <= 0) {
-                    break;
-                }
+        final RebalanceSessionHolder sessionHolder = new 
RebalanceSessionHolder(session, offsetTracker);
+        consumerService.setSessionContext(sessionHolder);
 
-                final Iterator<ByteRecord> consumerRecords = 
consumerService.poll(maxWaitDuration).iterator();
-                if (!consumerRecords.hasNext()) {
-                    getLogger().trace("No Kafka Records consumed: {}", 
pollingContext);
-                    // Check if a rebalance occurred during poll - if so, 
break to commit what we have
-                    if (consumerService.hasRevokedPartitions()) {
-                        getLogger().debug("Rebalance detected with revoked 
partitions, breaking to commit session");
+        try {
+            while (System.currentTimeMillis() < stopTime) {
+                try {
+                    final Duration maxWaitDuration = 
Duration.ofMillis(stopTime - System.currentTimeMillis());
+                    if (maxWaitDuration.toMillis() <= 0) {
                         break;
                     }
-                    continue;
-                }
 
-                recordsReceived = true;
-                processConsumerRecords(context, session, offsetTracker, 
consumerRecords);
+                    final Iterator<ByteRecord> consumerRecords = 
consumerService.poll(maxWaitDuration).iterator();
+                    if (!consumerRecords.hasNext()) {
+                        getLogger().trace("No Kafka Records consumed: {}", 
pollingContext);
+                        // Check if a rebalance occurred during poll - if so, 
break to commit what we have
+                        if (consumerService.hasRevokedPartitions()) {
+                            getLogger().debug("Rebalance detected with revoked 
partitions, breaking to commit session");
+                            break;
+                        }
+                        continue;
+                    }
 
-                // Check if a rebalance occurred during poll - if so, break to 
commit what we have
-                if (consumerService.hasRevokedPartitions()) {
-                    getLogger().debug("Rebalance detected with revoked 
partitions, breaking to commit session");
-                    break;
-                }
+                    recordsReceived = true;
+                    processConsumerRecords(context, session, offsetTracker, 
consumerRecords);
 
-                if (maxUncommittedSizeConfigured) {
-                    // Stop consuming before reaching Max Uncommitted Time 
when exceeding Max Uncommitted Size
-                    final long totalRecordSize = 
offsetTracker.getTotalRecordSize();
-                    if (totalRecordSize > maxUncommittedSize) {
+                    // Check if a rebalance occurred during poll - if so, 
break to commit what we have
+                    if (consumerService.hasRevokedPartitions()) {
+                        getLogger().debug("Rebalance detected with revoked 
partitions, breaking to commit session");
                         break;
                     }
+
+                    if (maxUncommittedSizeConfigured) {
+                        // Stop consuming before reaching Max Uncommitted Time 
when exceeding Max Uncommitted Size
+                        final long totalRecordSize = 
offsetTracker.getTotalRecordSize();
+                        if (totalRecordSize > maxUncommittedSize) {
+                            break;
+                        }
+                    }
+                } catch (final Exception e) {
+                    getLogger().error("Failed to consume Kafka Records", e);
+                    consumerService.rollback();
+                    close(consumerService, "Encountered Exception while 
consuming or writing out Kafka Records");
+                    context.yield();
+                    // If there are any FlowFiles already created and 
transferred, roll them back because we're rolling back offsets and
+                    // because we will consume the data again, we don't want 
to transfer out the FlowFiles.
+                    session.rollback();
+                    return;
                 }
-            } catch (final Exception e) {
-                getLogger().error("Failed to consume Kafka Records", e);
-                consumerService.rollback();
-                close(consumerService, "Encountered Exception while consuming 
or writing out Kafka Records");
-                context.yield();
-                // If there are any FlowFiles already created and transferred, 
roll them back because we're rolling back offsets and
-                // because we will consume the data again, we don't want to 
transfer out the FlowFiles.
-                session.rollback();
-                return;
             }
-        }
 
-        if (!recordsReceived && !consumerService.hasRevokedPartitions()) {
-            getLogger().trace("No Kafka Records consumed, re-queuing 
consumer");
-            consumerServices.offer(consumerService);
-            return;
-        }
+            if (!recordsReceived && !consumerService.hasRevokedPartitions()) {
+                getLogger().trace("No Kafka Records consumed, re-queuing 
consumer");
+                consumerServices.offer(consumerService);
+                return;
+            }
 
-        // If no records received but we have revoked partitions, we still 
need to commit their offsets
-        if (!recordsReceived && consumerService.hasRevokedPartitions()) {
-            getLogger().debug("No records received but rebalance occurred, 
committing offsets for revoked partitions");
-            try {
-                consumerService.commitOffsetsForRevokedPartitions();
-            } catch (final Exception e) {
-                getLogger().warn("Failed to commit offsets for revoked 
partitions", e);
+            // If no records received but we have revoked partitions, we still 
need to commit their offsets.
+            // Note: When a rebalance callback is registered (which is the 
case in this processor), offsets for
+            // revoked partitions are committed synchronously during 
onPartitionsRevoked(), so hasRevokedPartitions()
+            // will return false. This code path exists for backward 
compatibility when no callback is registered.
+            if (!recordsReceived && consumerService.hasRevokedPartitions()) {
+                getLogger().debug("No records received but rebalance occurred, 
committing offsets for revoked partitions");
+                try {
+                    consumerService.commitOffsetsForRevokedPartitions();
+                } catch (final Exception e) {
+                    getLogger().warn("Failed to commit offsets for revoked 
partitions", e);
+                }
+                consumerServices.offer(consumerService);
+                return;
             }
-            consumerServices.offer(consumerService);
-            return;
-        }
 
-        session.commitAsync(
-            () -> commitOffsets(consumerService, offsetTracker, 
pollingContext, session),
-            throwable -> {
-                getLogger().error("Failed to commit session; will roll back 
any uncommitted records", throwable);
-                rollback(consumerService, offsetTracker, session);
-                context.yield();
-            });
+            session.commitAsync(
+                () -> commitOffsets(consumerService, offsetTracker, 
pollingContext, session),
+                throwable -> {
+                    getLogger().error("Failed to commit session; will roll 
back any uncommitted records", throwable);
+                    rollback(consumerService, offsetTracker, session);
+                    context.yield();
+                });
+        } finally {
+            consumerService.setSessionContext(null);
+        }
     }
 
     private void commitOffsets(final KafkaConsumerService consumerService, 
final OffsetTracker offsetTracker, final PollingContext pollingContext, final 
ProcessSession session) {
@@ -513,7 +525,9 @@ public class ConsumeKafka extends AbstractProcessor 
implements VerifiableProcess
                 });
             }
 
-            // After successful session commit, also commit offsets for any 
partitions that were revoked during rebalance
+            // After successful session commit, also commit offsets for any 
partitions that were revoked during rebalance.
+            // Note: When a rebalance callback is registered, this check will 
always be false since offsets are
+            // committed synchronously during onPartitionsRevoked(). This code 
path is for backward compatibility.
             if (consumerService.hasRevokedPartitions()) {
                 getLogger().debug("Committing offsets for partitions revoked 
during rebalance");
                 consumerService.commitOffsetsForRevokedPartitions();
@@ -688,7 +702,43 @@ public class ConsumeKafka extends AbstractProcessor 
implements VerifiableProcess
 
         getLogger().info("No Kafka Consumer Service available; creating a new 
one. Active count: {}", activeCount);
         final KafkaConnectionService connectionService = 
context.getProperty(CONNECTION_SERVICE).asControllerService(KafkaConnectionService.class);
-        return connectionService.getConsumerService(pollingContext);
+        final KafkaConsumerService newService = 
connectionService.getConsumerService(pollingContext);
+        newService.setRebalanceCallback(createRebalanceCallback());
+        return newService;
+    }
+
+    private RebalanceCallback createRebalanceCallback() {
+        return new RebalanceCallback() {
+            @Override
+            public void onPartitionsRevoked(final Collection<PartitionState> 
revokedPartitions, final SessionContext sessionContext) {
+                if (sessionContext == null) {
+                    getLogger().debug("No session context during rebalance 
callback, nothing to commit");
+                    return;
+                }
+
+                final RebalanceSessionHolder holder = (RebalanceSessionHolder) 
sessionContext;
+                final ProcessSession session = holder.session;
+                final OffsetTracker offsetTracker = holder.offsetTracker;
+
+                getLogger().info("Rebalance callback invoked for {} revoked 
partitions, committing session synchronously",
+                        revokedPartitions.size());
+
+                try {
+                    session.commit();
+                    getLogger().debug("Session committed successfully during 
rebalance callback");
+
+                    if (offsetTracker != null) {
+                        offsetTracker.getRecordCounts().forEach((topic, count) 
-> {
+                            session.adjustCounter("Records Acknowledged for " 
+ topic, count, true);
+                        });
+                        offsetTracker.clear();
+                    }
+                } catch (final Exception e) {
+                    getLogger().error("Failed to commit session during 
rebalance callback", e);
+                    throw new RuntimeException("Failed to commit session 
during rebalance", e);
+                }
+            }
+        };
     }
 
     private int getMaxConsumerCount() {
@@ -780,4 +830,13 @@ public class ConsumeKafka extends AbstractProcessor 
implements VerifiableProcess
         return pollingContext;
     }
 
+    private static class RebalanceSessionHolder implements SessionContext {
+        private final ProcessSession session;
+        private final OffsetTracker offsetTracker;
+
+        RebalanceSessionHolder(final ProcessSession session, final 
OffsetTracker offsetTracker) {
+            this.session = session;
+            this.offsetTracker = offsetTracker;
+        }
+    }
 }
diff --git 
a/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-processors/src/main/java/org/apache/nifi/kafka/processors/consumer/OffsetTracker.java
 
b/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-processors/src/main/java/org/apache/nifi/kafka/processors/consumer/OffsetTracker.java
index e421b9c5b58..b79725940c1 100644
--- 
a/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-processors/src/main/java/org/apache/nifi/kafka/processors/consumer/OffsetTracker.java
+++ 
b/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-processors/src/main/java/org/apache/nifi/kafka/processors/consumer/OffsetTracker.java
@@ -64,4 +64,10 @@ public class OffsetTracker {
         }
         return pollingSummary;
     }
+
+    public void clear() {
+        offsets.clear();
+        recordCounts.clear();
+        totalRecordSize.set(0);
+    }
 }
diff --git 
a/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-service-api/src/main/java/org/apache/nifi/kafka/service/api/consumer/KafkaConsumerService.java
 
b/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-service-api/src/main/java/org/apache/nifi/kafka/service/api/consumer/KafkaConsumerService.java
index 310f31ad267..5e46894a825 100644
--- 
a/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-service-api/src/main/java/org/apache/nifi/kafka/service/api/consumer/KafkaConsumerService.java
+++ 
b/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-service-api/src/main/java/org/apache/nifi/kafka/service/api/consumer/KafkaConsumerService.java
@@ -106,4 +106,44 @@ public interface KafkaConsumerService extends Closeable {
      */
     default void clearRevokedPartitions() {
     }
+
+    /**
+     * Set a callback to be invoked during consumer group rebalance when 
partitions are revoked.
+     * The callback is invoked inside {@code onPartitionsRevoked()} before 
Kafka offsets are committed,
+     * allowing the processor to commit its session (FlowFiles) first.
+     * <p>
+     * This is critical for preventing both data loss and duplicates during 
rebalance:
+     * <ul>
+     *   <li>The callback commits the NiFi session (FlowFiles) first</li>
+     *   <li>Then Kafka offsets are committed while consumer is still in valid 
state</li>
+     * </ul>
+     * </p>
+     *
+     * @param callback the callback to invoke during rebalance, or null to 
clear
+     */
+    default void setRebalanceCallback(RebalanceCallback callback) {
+    }
+
+    /**
+     * Set a session context object that will be associated with this consumer 
service instance.
+     * This is used to store processor-specific state (like session and offset 
tracker) that
+     * needs to be accessed during rebalance callbacks.
+     * <p>
+     * Each consumer service instance should have its own session context to 
ensure thread safety
+     * when multiple concurrent tasks are using different consumer services 
from a pool.
+     * </p>
+     *
+     * @param sessionContext the context object to associate with this 
service, or null to clear
+     */
+    default void setSessionContext(SessionContext sessionContext) {
+    }
+
+    /**
+     * Get the session context object associated with this consumer service 
instance.
+     *
+     * @return the session context, or null if not set
+     */
+    default SessionContext getSessionContext() {
+        return null;
+    }
 }
diff --git 
a/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-service-api/src/main/java/org/apache/nifi/kafka/service/api/consumer/RebalanceCallback.java
 
b/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-service-api/src/main/java/org/apache/nifi/kafka/service/api/consumer/RebalanceCallback.java
new file mode 100644
index 00000000000..838b54fafda
--- /dev/null
+++ 
b/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-service-api/src/main/java/org/apache/nifi/kafka/service/api/consumer/RebalanceCallback.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.nifi.kafka.service.api.consumer;
+
+import org.apache.nifi.kafka.service.api.common.PartitionState;
+
+import java.util.Collection;
+
+/**
+ * Callback interface for handling Kafka consumer group rebalance events.
+ * <p>
+ * When a rebalance occurs and partitions are being revoked from a consumer,
+ * Kafka calls {@code onPartitionsRevoked()} on the ConsumerRebalanceListener.
+ * This is the ONLY time when the consumer is still in a valid state to commit 
offsets.
+ * After this callback returns, the consumer is no longer part of an active 
group
+ * and any commit attempts will fail with RebalanceInProgressException.
+ * </p>
+ * <p>
+ * This callback allows processors to be notified during the rebalance so they 
can
+ * commit their session (FlowFiles) before Kafka offsets are committed. This 
ensures:
+ * <ul>
+ *   <li>No data loss: NiFi session is committed before Kafka offsets</li>
+ *   <li>No duplicates: Kafka offsets are committed while consumer is still 
valid</li>
+ * </ul>
+ * </p>
+ */
+@FunctionalInterface
+public interface RebalanceCallback {
+
+    /**
+     * Called during {@code onPartitionsRevoked()} when partitions with 
uncommitted offsets
+     * are being revoked from this consumer.
+     * <p>
+     * The implementation should commit any pending work (e.g., NiFi session 
with FlowFiles)
+     * for the specified partitions. After this method returns, Kafka offsets 
will be committed
+     * for these partitions.
+     * </p>
+     *
+     * @param revokedPartitions the partitions being revoked that have 
uncommitted offsets
+     * @param sessionContext the session context stored in the consumer 
service, containing
+     *                       processor-specific state (e.g., ProcessSession 
and OffsetTracker)
+     */
+    void onPartitionsRevoked(Collection<PartitionState> revokedPartitions, 
SessionContext sessionContext);
+}
diff --git 
a/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-service-api/src/main/java/org/apache/nifi/kafka/service/api/consumer/SessionContext.java
 
b/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-service-api/src/main/java/org/apache/nifi/kafka/service/api/consumer/SessionContext.java
new file mode 100644
index 00000000000..c9b213989eb
--- /dev/null
+++ 
b/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-service-api/src/main/java/org/apache/nifi/kafka/service/api/consumer/SessionContext.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.nifi.kafka.service.api.consumer;
+
+/**
+ * Marker interface for session context objects stored in KafkaConsumerService.
+ * <p>
+ * Implementations of this interface hold processor-specific state (such as 
ProcessSession
+ * and offset tracking information) that needs to be accessed during rebalance 
callbacks.
+ * Each consumer service instance should have its own session context to 
ensure thread safety
+ * when multiple concurrent tasks are using different consumer services from a 
pool.
+ * </p>
+ */
+public interface SessionContext {
+}
diff --git 
a/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-service-shared/src/main/java/org/apache/nifi/kafka/service/consumer/Kafka3ConsumerService.java
 
b/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-service-shared/src/main/java/org/apache/nifi/kafka/service/consumer/Kafka3ConsumerService.java
index 7dce728b5e9..37cb7e18254 100644
--- 
a/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-service-shared/src/main/java/org/apache/nifi/kafka/service/consumer/Kafka3ConsumerService.java
+++ 
b/nifi-extension-bundles/nifi-kafka-bundle/nifi-kafka-service-shared/src/main/java/org/apache/nifi/kafka/service/consumer/Kafka3ConsumerService.java
@@ -27,6 +27,8 @@ import 
org.apache.nifi.kafka.service.api.common.PartitionState;
 import org.apache.nifi.kafka.service.api.common.TopicPartitionSummary;
 import org.apache.nifi.kafka.service.api.consumer.KafkaConsumerService;
 import org.apache.nifi.kafka.service.api.consumer.PollingSummary;
+import org.apache.nifi.kafka.service.api.consumer.RebalanceCallback;
+import org.apache.nifi.kafka.service.api.consumer.SessionContext;
 import org.apache.nifi.kafka.service.api.header.RecordHeader;
 import org.apache.nifi.kafka.service.api.record.ByteRecord;
 import org.apache.nifi.logging.ComponentLog;
@@ -57,14 +59,22 @@ public class Kafka3ConsumerService implements 
KafkaConsumerService, Closeable, C
     private final ComponentLog componentLog;
     private final Consumer<byte[], byte[]> consumer;
     private final Subscription subscription;
+    private volatile RebalanceCallback rebalanceCallback;
+    private volatile SessionContext sessionContext;
     private final Map<TopicPartition, Long> uncommittedOffsets = new 
ConcurrentHashMap<>();
     private final Set<TopicPartition> revokedPartitions = new 
CopyOnWriteArraySet<>();
     private volatile boolean closed = false;
 
     public Kafka3ConsumerService(final ComponentLog componentLog, final 
Consumer<byte[], byte[]> consumer, final Subscription subscription) {
+        this(componentLog, consumer, subscription, null);
+    }
+
+    public Kafka3ConsumerService(final ComponentLog componentLog, final 
Consumer<byte[], byte[]> consumer,
+            final Subscription subscription, final RebalanceCallback 
rebalanceCallback) {
         this.componentLog = Objects.requireNonNull(componentLog, "Component 
Log required");
         this.consumer = consumer;
         this.subscription = subscription;
+        this.rebalanceCallback = rebalanceCallback;
 
         final Optional<Pattern> topicPatternFound = 
subscription.getTopicPattern();
         if (topicPatternFound.isPresent()) {
@@ -85,17 +95,57 @@ public class Kafka3ConsumerService implements 
KafkaConsumerService, Closeable, C
     public void onPartitionsRevoked(final Collection<TopicPartition> 
partitions) {
         componentLog.info("Kafka revoked the following Partitions from this 
consumer: {}", partitions);
 
-        // Store revoked partitions for the processor to handle after 
committing its session.
-        // We do NOT commit offsets here to avoid data loss - the processor 
must commit its
-        // session first, then call commitOffsetsForRevokedPartitions().
+        // Identify partitions with uncommitted offsets
+        final Map<TopicPartition, Long> partitionsWithUncommittedOffsets = new 
HashMap<>();
         for (final TopicPartition partition : partitions) {
-            if (uncommittedOffsets.containsKey(partition)) {
-                revokedPartitions.add(partition);
+            final Long offset = uncommittedOffsets.get(partition);
+            if (offset != null) {
+                partitionsWithUncommittedOffsets.put(partition, offset);
             }
         }
 
-        if (!revokedPartitions.isEmpty()) {
-            componentLog.info("Partitions revoked with uncommitted offsets, 
pending processor commit: {}", revokedPartitions);
+        if (partitionsWithUncommittedOffsets.isEmpty()) {
+            return;
+        }
+
+        componentLog.info("Partitions revoked with uncommitted offsets: {}", 
partitionsWithUncommittedOffsets.keySet());
+
+        // If a callback is registered, we can safely commit offsets 
synchronously:
+        // 1. Call the callback so the processor can commit its session 
(FlowFiles) first
+        // 2. Then commit Kafka offsets immediately while consumer is still in 
valid state
+        // This prevents both data loss (session committed first) and 
duplicates (offsets committed while the consumer is still a valid group member).
+        if (rebalanceCallback != null) {
+            final Collection<PartitionState> revokedStates = 
partitionsWithUncommittedOffsets.keySet().stream()
+                    .map(tp -> new PartitionState(tp.topic(), tp.partition()))
+                    .collect(Collectors.toList());
+
+            try {
+                componentLog.debug("Invoking rebalance callback for 
partitions: {}", revokedStates);
+                rebalanceCallback.onPartitionsRevoked(revokedStates, 
sessionContext);
+            } catch (final Exception e) {
+                componentLog.warn("Rebalance callback failed, offsets will not 
be committed for revoked partitions", e);
+                return;
+            }
+
+            // Commit offsets for revoked partitions immediately while still 
in valid state
+            final Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = new 
HashMap<>();
+            for (final Map.Entry<TopicPartition, Long> entry : 
partitionsWithUncommittedOffsets.entrySet()) {
+                offsetsToCommit.put(entry.getKey(), new 
OffsetAndMetadata(entry.getValue()));
+                uncommittedOffsets.remove(entry.getKey());
+            }
+
+            try {
+                consumer.commitSync(offsetsToCommit);
+                componentLog.info("Committed offsets during rebalance for 
partitions: {}", offsetsToCommit);
+            } catch (final Exception e) {
+                componentLog.warn("Failed to commit offsets during rebalance 
for partitions: {}", offsetsToCommit.keySet(), e);
+            }
+        } else {
+            // No callback registered - defer commit to avoid data loss.
+            // Store revoked partitions so the processor can call 
commitOffsetsForRevokedPartitions()
+            // after successfully committing its session.
+            
revokedPartitions.addAll(partitionsWithUncommittedOffsets.keySet());
+            componentLog.info("No rebalance callback registered, deferring 
commit for partitions: {}", revokedPartitions);
         }
     }
 
@@ -248,6 +298,21 @@ public class Kafka3ConsumerService implements 
KafkaConsumerService, Closeable, C
         revokedPartitions.clear();
     }
 
+    @Override
+    public void setRebalanceCallback(final RebalanceCallback callback) {
+        this.rebalanceCallback = callback;
+    }
+
+    @Override
+    public void setSessionContext(final SessionContext sessionContext) {
+        this.sessionContext = sessionContext;
+    }
+
+    @Override
+    public SessionContext getSessionContext() {
+        return sessionContext;
+    }
+
     private Map<TopicPartition, OffsetAndMetadata> getOffsets(final 
PollingSummary pollingSummary) {
         final Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
 

Reply via email to