This is an automated email from the ASF dual-hosted git repository.

kenn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new bcee5d081d0 Remove expensive shuffle of read data in KafkaIO when using sdf and commit offsets (#31682)
bcee5d081d0 is described below

commit bcee5d081d05841bb52f0851f94a04f9c8968b88
Author: Sam Whittle <[email protected]>
AuthorDate: Thu Aug 29 16:03:47 2024 +0200

    Remove expensive shuffle of read data in KafkaIO when using sdf and commit offsets (#31682)
---
 .../beam/sdk/io/kafka/KafkaCommitOffset.java       |  83 +++++++++-
 .../java/org/apache/beam/sdk/io/kafka/KafkaIO.java |  80 +++++++---
 .../beam/sdk/io/kafka/KafkaCommitOffsetTest.java   | 169 ++++++++++++++++++---
 3 files changed, 278 insertions(+), 54 deletions(-)
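
For context, this change affects the SDF-based read expansion when offset commits are enabled
on the read (the isCommitOffsetEnabled() path changed below). A minimal usage sketch follows;
it is not part of this commit, and the broker address, topic, group id, and deserializers are
assumed placeholders.

    // Hypothetical pipeline exercising the KafkaCommitOffset path (assumed names and values).
    import java.util.Collections;
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.io.kafka.KafkaIO;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;
    import org.apache.kafka.clients.consumer.ConsumerConfig;
    import org.apache.kafka.common.serialization.LongDeserializer;
    import org.apache.kafka.common.serialization.StringDeserializer;

    public class KafkaCommitOffsetExample {
      public static void main(String[] args) {
        Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());
        p.apply(
            KafkaIO.<Long, String>read()
                .withBootstrapServers("broker:9092") // assumed broker address
                .withTopic("my_topic") // assumed topic
                .withKeyDeserializer(LongDeserializer.class)
                .withValueDeserializer(StringDeserializer.class)
                .withConsumerConfigUpdates(
                    Collections.<String, Object>singletonMap(
                        ConsumerConfig.GROUP_ID_CONFIG, "my-group"))
                .commitOffsetsInFinalize() // enables the KafkaCommitOffset branch
                .withoutMetadata());
        p.run().waitUntilFinish();
      }
    }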

diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCommitOffset.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCommitOffset.java
index 3816ee0bb85..fa692d3aaf4 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCommitOffset.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaCommitOffset.java
@@ -33,6 +33,7 @@ import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
@@ -40,7 +41,9 @@ import org.apache.beam.sdk.values.TypeDescriptor;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
 import org.joda.time.Duration;
+import org.joda.time.Instant;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -49,9 +52,12 @@ public class KafkaCommitOffset<K, V>
     extends PTransform<
         PCollection<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>>, PCollection<Void>> {
   private final KafkaIO.ReadSourceDescriptors<K, V> readSourceDescriptors;
+  private final boolean use259implementation;
 
-  KafkaCommitOffset(KafkaIO.ReadSourceDescriptors<K, V> readSourceDescriptors) {
+  KafkaCommitOffset(
+      KafkaIO.ReadSourceDescriptors<K, V> readSourceDescriptors, boolean use259implementation) {
     this.readSourceDescriptors = readSourceDescriptors;
+    this.use259implementation = use259implementation;
   }
 
   static class CommitOffsetDoFn extends DoFn<KV<KafkaSourceDescriptor, Long>, Void> {
@@ -90,7 +96,7 @@ public class KafkaCommitOffset<K, V>
               || description.getBootStrapServers() != null);
       Map<String, Object> config = new HashMap<>(currentConfig);
       if (description.getBootStrapServers() != null
-          && description.getBootStrapServers().size() > 0) {
+          && !description.getBootStrapServers().isEmpty()) {
         config.put(
             ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG,
             String.join(",", description.getBootStrapServers()));
@@ -99,13 +105,78 @@ public class KafkaCommitOffset<K, V>
     }
   }
 
+  private static final class MaxOffsetFn<K, V>
+      extends DoFn<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>, KV<KafkaSourceDescriptor, Long>> {
+    private static class OffsetAndTimestamp {
+      OffsetAndTimestamp(long offset, Instant timestamp) {
+        this.offset = offset;
+        this.timestamp = timestamp;
+      }
+
+      void merge(long offset, Instant timestamp) {
+        if (this.offset < offset) {
+          this.offset = offset;
+          this.timestamp = timestamp;
+        }
+      }
+
+      long offset;
+      Instant timestamp;
+    }
+
+    private transient @MonotonicNonNull Map<KafkaSourceDescriptor, OffsetAndTimestamp> maxObserved;
+
+    @StartBundle
+    public void startBundle() {
+      if (maxObserved == null) {
+        maxObserved = new HashMap<>();
+      } else {
+        maxObserved.clear();
+      }
+    }
+
+    @RequiresStableInput
+    @ProcessElement
+    @SuppressWarnings("nullness") // startBundle guaranteed to initialize
+    public void processElement(
+        @Element KV<KafkaSourceDescriptor, KafkaRecord<K, V>> element,
+        @Timestamp Instant timestamp) {
+      maxObserved.compute(
+          element.getKey(),
+          (k, v) -> {
+            long offset = element.getValue().getOffset();
+            if (v == null) {
+              return new OffsetAndTimestamp(offset, timestamp);
+            }
+            v.merge(offset, timestamp);
+            return v;
+          });
+    }
+
+    @FinishBundle
+    @SuppressWarnings("nullness") // startBundle guaranteed to initialize
+    public void finishBundle(FinishBundleContext context) {
+      maxObserved.forEach(
+          (k, v) -> context.output(KV.of(k, v.offset), v.timestamp, GlobalWindow.INSTANCE));
+    }
+  }
+
   @Override
   public PCollection<Void> expand(PCollection<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>> input) {
     try {
-      return input
-          .apply(
-              MapElements.into(new TypeDescriptor<KV<KafkaSourceDescriptor, Long>>() {})
-                  .via(element -> KV.of(element.getKey(), element.getValue().getOffset())))
+      PCollection<KV<KafkaSourceDescriptor, Long>> offsets;
+      if (use259implementation) {
+        offsets =
+            input.apply(
+                MapElements.into(new TypeDescriptor<KV<KafkaSourceDescriptor, Long>>() {})
+                    .via(element -> KV.of(element.getKey(), element.getValue().getOffset())));
+      } else {
+        // Reduce the amount of data to combine by calculating a max within the generally dense
+        // bundles of reading
+        // from a Kafka partition.
+        offsets = input.apply(ParDo.of(new MaxOffsetFn<>()));
+      }
+      return offsets
           .setCoder(
               KvCoder.of(
                   input
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index c1526d5382b..1fd3e3e044e 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -31,6 +31,7 @@ import java.lang.reflect.Method;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -60,12 +61,14 @@ import org.apache.beam.sdk.io.kafka.KafkaIOReadImplementationCompatibility.Kafka
 import org.apache.beam.sdk.options.Default;
 import org.apache.beam.sdk.options.ExperimentalOptions;
 import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.options.ValueProvider;
 import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.runners.PTransformOverride;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.schemas.JavaFieldSchema;
 import org.apache.beam.sdk.schemas.NoSuchSchemaException;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
 import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
 import org.apache.beam.sdk.schemas.annotations.SchemaCreate;
 import org.apache.beam.sdk.schemas.transforms.Convert;
@@ -103,6 +106,7 @@ import org.apache.beam.sdk.values.TupleTagList;
 import org.apache.beam.sdk.values.TypeDescriptor;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Joiner;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Comparators;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import org.apache.kafka.clients.consumer.Consumer;
@@ -2136,9 +2140,9 @@ public class KafkaIO {
    * the transform will expand to:
    *
    * <pre>{@code
-   * PCollection<KafkaSourceDescriptor> --> ParDo(ReadFromKafkaDoFn<KafkaSourceDescriptor, KV<KafkaSourceDescriptor, KafkaRecord>>) --> Reshuffle() --> Map(output KafkaRecord)
-   *                                                                                                                                                         |
-   *                                                                                                                                                         --> KafkaCommitOffset
+   * PCollection<KafkaSourceDescriptor> --> ParDo(ReadFromKafkaDoFn<KafkaSourceDescriptor, KV<KafkaSourceDescriptor, KafkaRecord>>) --> Map(output KafkaRecord)
+   *                                                                                                                                  |
+   *                                                                                                                                  --> KafkaCommitOffset
    * }</pre>
    *
   . Note that this expansion is not supported when running with x-lang on Dataflow.
@@ -2682,33 +2686,61 @@ public class KafkaIO {
                             .getSchemaRegistry()
                             .getSchemaCoder(KafkaSourceDescriptor.class),
                         recordCoder));
-        if (isCommitOffsetEnabled() && !configuredKafkaCommit() && !isRedistribute()) {
-          outputWithDescriptor =
-              outputWithDescriptor
-                  .apply(Reshuffle.viaRandomKey())
-                  .setCoder(
-                      KvCoder.of(
-                          input
-                              .getPipeline()
-                              .getSchemaRegistry()
-                              .getSchemaCoder(KafkaSourceDescriptor.class),
-                          recordCoder));
-
-          PCollection<Void> unused = outputWithDescriptor.apply(new KafkaCommitOffset<K, V>(this));
-          unused.setCoder(VoidCoder.of());
+
+        boolean applyCommitOffsets =
+            isCommitOffsetEnabled() && !configuredKafkaCommit() && !isRedistribute();
+        if (!applyCommitOffsets) {
+          return outputWithDescriptor
+              .apply(MapElements.into(new TypeDescriptor<KafkaRecord<K, V>>() {}).via(KV::getValue))
+              .setCoder(recordCoder);
+        }
+
+        // Add transform for committing offsets to Kafka with consistency with beam pipeline data
+        // processing.
+        String requestedVersionString =
+            input
+                .getPipeline()
+                .getOptions()
+                .as(StreamingOptions.class)
+                .getUpdateCompatibilityVersion();
+        if (requestedVersionString != null) {
+          List<String> requestedVersion = Arrays.asList(requestedVersionString.split("\\."));
+          List<String> targetVersion = Arrays.asList("2", "60", "0");
+
+          if (Comparators.lexicographical(Comparator.<String>naturalOrder())
+                  .compare(requestedVersion, targetVersion)
+              < 0) {
+            return expand259Commits(
+                outputWithDescriptor, recordCoder, input.getPipeline().getSchemaRegistry());
+          }
         }
-        PCollection<KafkaRecord<K, V>> output =
-            outputWithDescriptor
-                .apply(
-                    MapElements.into(new TypeDescriptor<KafkaRecord<K, V>>() {})
-                        .via(element -> element.getValue()))
-                .setCoder(recordCoder);
-        return output;
+        outputWithDescriptor.apply(new KafkaCommitOffset<>(this, false)).setCoder(VoidCoder.of());
+        return outputWithDescriptor
+            .apply(MapElements.into(new TypeDescriptor<KafkaRecord<K, V>>() {}).via(KV::getValue))
+            .setCoder(recordCoder);
       } catch (NoSuchSchemaException e) {
         throw new RuntimeException(e.getMessage());
       }
     }
 
+    private PCollection<KafkaRecord<K, V>> expand259Commits(
+        PCollection<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>> outputWithDescriptor,
+        Coder<KafkaRecord<K, V>> recordCoder,
+        SchemaRegistry schemaRegistry)
+        throws NoSuchSchemaException {
+      // Reshuffles the data and then branches off applying commit offsets.
+      outputWithDescriptor =
+          outputWithDescriptor
+              .apply(Reshuffle.viaRandomKey())
+              .setCoder(
+                  KvCoder.of(
+                      schemaRegistry.getSchemaCoder(KafkaSourceDescriptor.class), recordCoder));
+      outputWithDescriptor.apply(new KafkaCommitOffset<>(this, true)).setCoder(VoidCoder.of());
+      return outputWithDescriptor
+          .apply(MapElements.into(new TypeDescriptor<KafkaRecord<K, V>>() {}).via(KV::getValue))
+          .setCoder(recordCoder);
+    }
+
     private Coder<K> getKeyCoder(CoderRegistry coderRegistry) {
       return (getKeyCoder() != null)
           ? getKeyCoder()
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaCommitOffsetTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaCommitOffsetTest.java
index f258328c109..c16e25510ab 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaCommitOffsetTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaCommitOffsetTest.java
@@ -17,13 +17,25 @@
  */
 package org.apache.beam.sdk.io.kafka;
 
+import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.io.kafka.KafkaCommitOffset.CommitOffsetDoFn;
 import org.apache.beam.sdk.io.kafka.KafkaIO.ReadSourceDescriptors;
 import org.apache.beam.sdk.testing.ExpectedLogs;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.MockConsumer;
@@ -33,14 +45,14 @@ import org.apache.kafka.common.TopicPartition;
 import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
+import org.junit.experimental.categories.Category;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
 /** Unit tests for {@link KafkaCommitOffset}. */
 @RunWith(JUnit4.class)
 public class KafkaCommitOffsetTest {
-
-  private final TopicPartition partition = new TopicPartition("topic", 0);
+  @Rule public final transient TestPipeline pipeline = TestPipeline.create();
   @Rule public ExpectedLogs expectedLogs = ExpectedLogs.none(CommitOffsetDoFn.class);
 
   private final KafkaCommitOffsetMockConsumer consumer =
@@ -48,29 +60,132 @@ public class KafkaCommitOffsetTest {
   private final KafkaCommitOffsetMockConsumer errorConsumer =
       new KafkaCommitOffsetMockConsumer(null, true);
 
+  private static final KafkaCommitOffsetMockConsumer COMPOSITE_CONSUMER =
+      new KafkaCommitOffsetMockConsumer(null, false);
+  private static final KafkaCommitOffsetMockConsumer COMPOSITE_CONSUMER_BOOTSTRAP =
+      new KafkaCommitOffsetMockConsumer(null, false);
+
+  private static final Map<String, Object> configMap =
+      ImmutableMap.of(ConsumerConfig.GROUP_ID_CONFIG, "group1");
+
   @Test
   public void testCommitOffsetDoFn() {
-    Map<String, Object> configMap = new HashMap<>();
-    configMap.put(ConsumerConfig.GROUP_ID_CONFIG, "group1");
-
     ReadSourceDescriptors<Object, Object> descriptors =
         ReadSourceDescriptors.read()
             .withBootstrapServers("bootstrap_server")
             .withConsumerConfigUpdates(configMap)
             .withConsumerFactoryFn(
-                new SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>() {
-                  @Override
-                  public Consumer<byte[], byte[]> apply(Map<String, Object> input) {
-                    Assert.assertEquals("group1", input.get(ConsumerConfig.GROUP_ID_CONFIG));
-                    return consumer;
-                  }
-                });
+                (SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>)
+                    input -> {
+                      Assert.assertEquals("group1", input.get(ConsumerConfig.GROUP_ID_CONFIG));
+                      return consumer;
+                    });
     CommitOffsetDoFn doFn = new CommitOffsetDoFn(descriptors);
 
+    final TopicPartition topicPartition1 = new TopicPartition("topic", 0);
+    final TopicPartition topicPartition2 = new TopicPartition("other_topic", 1);
     doFn.processElement(
-        KV.of(KafkaSourceDescriptor.of(partition, null, null, null, null, null), 1L));
+        KV.of(KafkaSourceDescriptor.of(topicPartition1, null, null, null, null, null), 2L));
+    doFn.processElement(
+        KV.of(KafkaSourceDescriptor.of(topicPartition2, null, null, null, null, null), 200L));
 
-    Assert.assertEquals(2L, consumer.commit.get(partition).offset());
+    Assert.assertEquals(3L, (long) consumer.commitOffsets.get(topicPartition1));
+    Assert.assertEquals(201L, (long) consumer.commitOffsets.get(topicPartition2));
+
+    doFn.processElement(
+        KV.of(KafkaSourceDescriptor.of(topicPartition1, null, null, null, null, null), 3L));
+    Assert.assertEquals(4L, (long) consumer.commitOffsets.get(topicPartition1));
+  }
+
+  KafkaRecord<String, String> makeTestRecord(int i) {
+    return new KafkaRecord<>(
+        "", 0, i, 0, KafkaTimestampType.NO_TIMESTAMP_TYPE, null, KV.of("key" + 
i, "value" + i));
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testKafkaOffsetComposite() throws CannotProvideCoderException {
+    testKafkaOffsetHelper(false);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testKafkaOffsetCompositeLegacy() throws CannotProvideCoderException {
+    testKafkaOffsetHelper(true);
+  }
+
+  private void testKafkaOffsetHelper(boolean use259Implementation)
+      throws CannotProvideCoderException {
+    COMPOSITE_CONSUMER.commitOffsets.clear();
+    COMPOSITE_CONSUMER_BOOTSTRAP.commitOffsets.clear();
+
+    ReadSourceDescriptors<String, String> descriptors =
+        ReadSourceDescriptors.<String, String>read()
+            .withBootstrapServers("bootstrap_server")
+            .withConsumerConfigUpdates(configMap)
+            .withConsumerFactoryFn(
+                (SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>)
+                    input -> {
+                      Assert.assertEquals("group1", input.get(ConsumerConfig.GROUP_ID_CONFIG));
+                      if (input
+                          .get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG)
+                          .equals("bootstrap_server")) {
+                        return COMPOSITE_CONSUMER;
+                      }
+                      Assert.assertEquals(
+                          "bootstrap_overridden",
+                          input.get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG));
+                      return COMPOSITE_CONSUMER_BOOTSTRAP;
+                    });
+
+    String topic0 = "topic0_" + (use259Implementation ? "259" : "260");
+    String topic1 = "topic1_" + (use259Implementation ? "259" : "260");
+    KafkaSourceDescriptor d1 =
+        KafkaSourceDescriptor.of(new TopicPartition(topic0, 0), null, null, null, null, null);
+    KafkaSourceDescriptor d2 =
+        KafkaSourceDescriptor.of(new TopicPartition(topic0, 1), null, null, null, null, null);
+    KafkaSourceDescriptor d3 =
+        KafkaSourceDescriptor.of(
+            new TopicPartition(topic1, 0),
+            null,
+            null,
+            null,
+            null,
+            ImmutableList.of("bootstrap_overridden"));
+    KafkaSourceDescriptor d4 =
+        KafkaSourceDescriptor.of(
+            new TopicPartition(topic1, 1),
+            null,
+            null,
+            null,
+            null,
+            ImmutableList.of("bootstrap_overridden"));
+    List<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> elements = new ArrayList<>();
+    elements.add(KV.of(d1, makeTestRecord(10)));
+
+    elements.add(KV.of(d2, makeTestRecord(20)));
+    elements.add(KV.of(d3, makeTestRecord(30)));
+    elements.add(KV.of(d4, makeTestRecord(40)));
+    elements.add(KV.of(d2, makeTestRecord(10)));
+    elements.add(KV.of(d1, makeTestRecord(100)));
+    PCollection<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> input =
+        pipeline.apply(
+            Create.of(elements)
+                .withCoder(
+                    KvCoder.of(
+                        pipeline.getCoderRegistry().getCoder(KafkaSourceDescriptor.class),
+                        KafkaRecordCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))));
+    input.apply(new KafkaCommitOffset<>(descriptors, use259Implementation));
+    pipeline.run();
+
+    HashMap<TopicPartition, Long> expectedOffsets = new HashMap<>();
+    expectedOffsets.put(d1.getTopicPartition(), 101L);
+    expectedOffsets.put(d2.getTopicPartition(), 21L);
+    Assert.assertEquals(expectedOffsets, COMPOSITE_CONSUMER.commitOffsets);
+    expectedOffsets.clear();
+    expectedOffsets.put(d3.getTopicPartition(), 31L);
+    expectedOffsets.put(d4.getTopicPartition(), 41L);
+    Assert.assertEquals(expectedOffsets, COMPOSITE_CONSUMER_BOOTSTRAP.commitOffsets);
   }
 
   @Test
@@ -83,25 +198,25 @@ public class KafkaCommitOffsetTest {
             .withBootstrapServers("bootstrap_server")
             .withConsumerConfigUpdates(configMap)
             .withConsumerFactoryFn(
-                new SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>() {
-                  @Override
-                  public Consumer<byte[], byte[]> apply(Map<String, Object> input) {
-                    Assert.assertEquals("group1", input.get(ConsumerConfig.GROUP_ID_CONFIG));
-                    return errorConsumer;
-                  }
-                });
+                (SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>)
+                    input -> {
+                      Assert.assertEquals("group1", input.get(ConsumerConfig.GROUP_ID_CONFIG));
+                      return errorConsumer;
+                    });
     CommitOffsetDoFn doFn = new CommitOffsetDoFn(descriptors);
 
+    final TopicPartition partition = new TopicPartition("topic", 0);
     doFn.processElement(
         KV.of(KafkaSourceDescriptor.of(partition, null, null, null, null, null), 1L));
 
     expectedLogs.verifyWarn("Getting exception when committing offset: Test 
Exception");
+    Assert.assertTrue(errorConsumer.commitOffsets.isEmpty());
   }
 
   private static class KafkaCommitOffsetMockConsumer extends MockConsumer<byte[], byte[]> {
 
-    public Map<TopicPartition, OffsetAndMetadata> commit;
-    private boolean throwException;
+    public final HashMap<TopicPartition, Long> commitOffsets = new HashMap<>();
+    private final boolean throwException;
 
     public KafkaCommitOffsetMockConsumer(
         OffsetResetStrategy offsetResetStrategy, boolean throwException) {
@@ -115,8 +230,14 @@ public class KafkaCommitOffsetTest {
         throw new RuntimeException("Test Exception");
       } else {
         commitAsync(offsets, null);
-        commit = offsets;
+        offsets.forEach(
+            (topic, offsetMetadata) -> commitOffsets.put(topic, offsetMetadata.offset()));
       }
     }
+
+    @Override
+    public synchronized void close(long timeout, TimeUnit unit) {
+      // Ignore closing since we're using a single consumer.
+    }
   }
 }
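
A practical note on the version check in the KafkaIO.java hunk above: the commit-offset branch
without the Reshuffle (using MaxOffsetFn) is only chosen when the pipeline's update
compatibility version is unset or at least 2.60.0; older requested versions fall back to
expand259Commits. A pipeline that must stay update-compatible with a job launched from an older
SDK can therefore pin that option. A minimal sketch, assuming the standard setter paired with
the getUpdateCompatibilityVersion() getter read in the diff:

    // Keep the pre-2.60.0 (Reshuffle-based) commit-offset expansion for pipeline update
    // compatibility. setUpdateCompatibilityVersion is assumed to be the setter counterpart
    // of the getter used above.
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;
    import org.apache.beam.sdk.options.StreamingOptions;

    public class PinnedCompatibilityExample {
      public static void main(String[] args) {
        StreamingOptions options =
            PipelineOptionsFactory.fromArgs(args).as(StreamingOptions.class);
        options.setUpdateCompatibilityVersion("2.59.0"); // any version below 2.60.0
        Pipeline p = Pipeline.create(options);
        // ... build the KafkaIO read with commitOffsetsInFinalize() as usual ...
        p.run();
      }
    }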
