This is an automated email from the ASF dual-hosted git repository.

zihanli58 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/gobblin.git


The following commit(s) were added to refs/heads/master by this push:
     new 918716fbd [GOBBLIN-1922]Add function in Kafka Source to recompute workUnits for filtered partitions (#3798)
918716fbd is described below

commit 918716fbd7dcc4dfaef7071c058f7209d64e694b
Author: Hanghang Nate Liu <[email protected]>
AuthorDate: Thu Oct 26 10:12:57 2023 -0700

    [GOBBLIN-1922]Add function in Kafka Source to recompute workUnits for filtered partitions (#3798)
    
    * add function in Kafka Source to recompute workUnits for filtered partitions
    
    * address comments
    
    * set default min container value to 1
    
    * add condition when create empty wu
    
    * update the condition
---
 .../client/AbstractBaseKafkaConsumerClient.java    |   1 -
 .../extractor/extract/kafka/KafkaSource.java       | 130 +++++++++++++++------
 .../extractor/extract/kafka/KafkaSourceTest.java   | 114 +++++++++++++++++-
 3 files changed, 204 insertions(+), 41 deletions(-)

diff --git 
a/gobblin-modules/gobblin-kafka-common/src/main/java/org/apache/gobblin/kafka/client/AbstractBaseKafkaConsumerClient.java
 
b/gobblin-modules/gobblin-kafka-common/src/main/java/org/apache/gobblin/kafka/client/AbstractBaseKafkaConsumerClient.java
index 3933adb3c..18cc75c30 100644
--- 
a/gobblin-modules/gobblin-kafka-common/src/main/java/org/apache/gobblin/kafka/client/AbstractBaseKafkaConsumerClient.java
+++ 
b/gobblin-modules/gobblin-kafka-common/src/main/java/org/apache/gobblin/kafka/client/AbstractBaseKafkaConsumerClient.java
@@ -157,7 +157,6 @@ public abstract class AbstractBaseKafkaConsumerClient 
implements GobblinKafkaCon
     return processedName;
   }
 
-
   /**
    * Get a list of all kafka topics
    */
diff --git 
a/gobblin-modules/gobblin-kafka-common/src/main/java/org/apache/gobblin/source/extractor/extract/kafka/KafkaSource.java
 
b/gobblin-modules/gobblin-kafka-common/src/main/java/org/apache/gobblin/source/extractor/extract/kafka/KafkaSource.java
index 0bca916a7..80ef8c09d 100644
--- 
a/gobblin-modules/gobblin-kafka-common/src/main/java/org/apache/gobblin/source/extractor/extract/kafka/KafkaSource.java
+++ 
b/gobblin-modules/gobblin-kafka-common/src/main/java/org/apache/gobblin/source/extractor/extract/kafka/KafkaSource.java
@@ -18,6 +18,10 @@
 package org.apache.gobblin.source.extractor.extract.kafka;
 
 import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -29,6 +33,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.regex.Pattern;
 
+import java.util.stream.Collectors;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -60,6 +65,7 @@ import 
org.apache.gobblin.kafka.client.GobblinKafkaConsumerClient.GobblinKafkaCo
 import org.apache.gobblin.metrics.MetricContext;
 import org.apache.gobblin.metrics.event.lineage.LineageInfo;
 import org.apache.gobblin.source.extractor.extract.EventBasedSource;
+import 
org.apache.gobblin.source.extractor.extract.kafka.workunit.packer.KafkaTopicGroupingWorkUnitPacker;
 import 
org.apache.gobblin.source.extractor.extract.kafka.workunit.packer.KafkaWorkUnitPacker;
 import 
org.apache.gobblin.source.extractor.extract.kafka.validator.TopicValidators;
 import org.apache.gobblin.source.extractor.limiter.LimiterConfigurationKeys;
@@ -190,10 +196,27 @@ public abstract class KafkaSource<S, D> extends 
EventBasedSource<S, D> {
 
   @Override
   public List<WorkUnit> getWorkunits(SourceState state) {
+    return this.getWorkunitsForFilteredPartitions(state, Optional.absent(), 
Optional.absent());
+  }
+
+  /**
+   * Compute Workunits for Kafka Topics. If filteredTopicPartition present, 
respect this map and only compute the provided
+   * topics and filtered partitions. If not, use state to discover Kafka 
topics and all available partitions.
+   *
+   * @param filteredTopicPartition optional parameter to determine if only 
filtered topic-partitions are needed.
+   * @param minContainer give an option to specify a minimum container count. 
Please be advised that how it being used is
+   *                     determined by the implementation of concrete {@link 
KafkaWorkUnitPacker} class.
+   *
+   * TODO: Utilize the minContainer in {@link 
KafkaTopicGroupingWorkUnitPacker#pack(Map, int)}, as the numContainers variable
+   *                     is not used currently.
+   */
+  public List<WorkUnit> getWorkunitsForFilteredPartitions(SourceState state,
+      Optional<Map<String, List<Integer>>> filteredTopicPartition, 
Optional<Integer> minContainer) {
     this.metricContext = Instrumented.getMetricContext(state, 
KafkaSource.class);
     this.lineageInfo = LineageInfo.getLineageInfo(state.getBroker());
 
-    Map<String, List<WorkUnit>> workUnits = Maps.newConcurrentMap();
+    Map<String, List<Integer>> filteredTopicPartitionMap = 
filteredTopicPartition.or(new HashMap<>());
+    Map<String, List<WorkUnit>> kafkaTopicWorkunitMap = 
Maps.newConcurrentMap();
     if 
(state.getPropAsBoolean(KafkaSource.GOBBLIN_KAFKA_EXTRACT_ALLOW_TABLE_TYPE_NAMESPACE_CUSTOMIZATION))
 {
       String tableTypeStr =
           state.getProp(ConfigurationKeys.EXTRACT_TABLE_TYPE_KEY, 
KafkaSource.DEFAULT_TABLE_TYPE.toString());
@@ -213,18 +236,22 @@ public abstract class KafkaSource<S, D> extends 
EventBasedSource<S, D> {
     try {
       Config config = ConfigUtils.propertiesToConfig(state.getProperties());
       GobblinKafkaConsumerClientFactory kafkaConsumerClientFactory = 
kafkaConsumerClientResolver
-              .resolveClass(
-                      
state.getProp(GOBBLIN_KAFKA_CONSUMER_CLIENT_FACTORY_CLASS,
-                              
DEFAULT_GOBBLIN_KAFKA_CONSUMER_CLIENT_FACTORY_CLASS)).newInstance();
+          .resolveClass(
+              state.getProp(GOBBLIN_KAFKA_CONSUMER_CLIENT_FACTORY_CLASS,
+                  
DEFAULT_GOBBLIN_KAFKA_CONSUMER_CLIENT_FACTORY_CLASS)).newInstance();
 
       this.kafkaConsumerClient.set(kafkaConsumerClientFactory.create(config));
 
-      List<KafkaTopic> topics = getValidTopics(getFilteredTopics(state), 
state);
+      Collection<KafkaTopic> topics;
+      if(filteredTopicPartition.isPresent()) {
+        // If filteredTopicPartition present, use it to construct the 
whitelist pattern while leave blacklist empty
+        topics = 
this.kafkaConsumerClient.get().getFilteredTopics(Collections.emptyList(),
+            
filteredTopicPartitionMap.keySet().stream().map(Pattern::compile).collect(Collectors.toList()));
+      } else {
+        topics = getValidTopics(getFilteredTopics(state), state);
+      }
       this.topicsToProcess = 
topics.stream().map(KafkaTopic::getName).collect(toSet());
 
-      for (String topic : this.topicsToProcess) {
-        LOG.info("Discovered topic " + topic);
-      }
       Map<String, State> topicSpecificStateMap =
           DatasetUtils.getDatasetSpecificProps(Iterables.transform(topics, new 
Function<KafkaTopic, String>() {
 
@@ -234,20 +261,13 @@ public abstract class KafkaSource<S, D> extends 
EventBasedSource<S, D> {
             }
           }), state);
 
-      for (KafkaTopic topic : topics) {
-        if (topic.getTopicSpecificState().isPresent()) {
-          topicSpecificStateMap.computeIfAbsent(topic.getName(), k -> new 
State())
-              .addAllIfNotExist(topic.getTopicSpecificState().get());
-        }
-      }
-
       int numOfThreads = 
state.getPropAsInt(ConfigurationKeys.KAFKA_SOURCE_WORK_UNITS_CREATION_THREADS,
           
ConfigurationKeys.KAFKA_SOURCE_WORK_UNITS_CREATION_DEFAULT_THREAD_COUNT);
       ExecutorService threadPool =
           Executors.newFixedThreadPool(numOfThreads, 
ExecutorsUtils.newThreadFactory(Optional.of(LOG)));
 
       if 
(state.getPropAsBoolean(ConfigurationKeys.KAFKA_SOURCE_SHARE_CONSUMER_CLIENT,
-              ConfigurationKeys.DEFAULT_KAFKA_SOURCE_SHARE_CONSUMER_CLIENT)) {
+          ConfigurationKeys.DEFAULT_KAFKA_SOURCE_SHARE_CONSUMER_CLIENT)) {
         this.sharedKafkaConsumerClient = this.kafkaConsumerClient.get();
       } else {
         // preallocate one client per thread
@@ -257,32 +277,44 @@ public abstract class KafkaSource<S, D> extends 
EventBasedSource<S, D> {
       Stopwatch createWorkUnitStopwatch = Stopwatch.createStarted();
 
       for (KafkaTopic topic : topics) {
+        LOG.info("Discovered topic " + topic);
+        if (topic.getTopicSpecificState().isPresent()) {
+          topicSpecificStateMap.computeIfAbsent(topic.getName(), k -> new 
State())
+              .addAllIfNotExist(topic.getTopicSpecificState().get());
+        }
+        Optional<Set<Integer>> partitionIDSet = Optional.absent();
+        if(filteredTopicPartition.isPresent()) {
+          List<Integer> list = 
java.util.Optional.ofNullable(filteredTopicPartitionMap.get(topic.getName()))
+              .orElse(new ArrayList<>());
+          partitionIDSet = Optional.of(new HashSet<>(list));
+          LOG.info("Compute the workunit for topic {} with num of filtered 
partitions: {}",
+              topic.getName(), list.size());
+        }
+
         threadPool.submit(
             new WorkUnitCreator(topic, state, 
Optional.fromNullable(topicSpecificStateMap.get(topic.getName())),
-                workUnits));
+                kafkaTopicWorkunitMap, partitionIDSet));
       }
 
       ExecutorsUtils.shutdownExecutorService(threadPool, Optional.of(LOG), 1L, 
TimeUnit.HOURS);
-      LOG.info(String.format("Created workunits for %d topics in %d seconds", 
workUnits.size(),
+      LOG.info(String.format("Created workunits for %d topics in %d seconds", 
kafkaTopicWorkunitMap.size(),
           createWorkUnitStopwatch.elapsed(TimeUnit.SECONDS)));
 
       // Create empty WorkUnits for skipped partitions (i.e., partitions that 
have previous offsets,
-      // but aren't processed).
-      createEmptyWorkUnitsForSkippedPartitions(workUnits, 
topicSpecificStateMap, state);
-      //determine the number of mappers
-      int maxMapperNum =
-          state.getPropAsInt(ConfigurationKeys.MR_JOB_MAX_MAPPERS_KEY, 
ConfigurationKeys.DEFAULT_MR_JOB_MAX_MAPPERS);
+      // but aren't processed). When filteredTopicPartition present, only 
filtered topic-partitions are needed so skip this call
+      if(!filteredTopicPartition.isPresent()) {
+        createEmptyWorkUnitsForSkippedPartitions(kafkaTopicWorkunitMap, 
topicSpecificStateMap, state);
+      }
+
       KafkaWorkUnitPacker kafkaWorkUnitPacker = 
KafkaWorkUnitPacker.getInstance(this, state, Optional.of(this.metricContext));
-      int numOfMultiWorkunits = maxMapperNum;
-      if(state.contains(ConfigurationKeys.MR_TARGET_MAPPER_SIZE)) {
-        double totalEstDataSize = 
kafkaWorkUnitPacker.setWorkUnitEstSizes(workUnits);
-        LOG.info(String.format("The total estimated data size is %.2f", 
totalEstDataSize));
-        double targetMapperSize = 
state.getPropAsDouble(ConfigurationKeys.MR_TARGET_MAPPER_SIZE);
-        numOfMultiWorkunits = (int) (totalEstDataSize / targetMapperSize) + 1;
-        numOfMultiWorkunits = Math.min(numOfMultiWorkunits, maxMapperNum);
+      int numOfMultiWorkunits = minContainer.or(1);
+      if(state.contains(ConfigurationKeys.MR_JOB_MAX_MAPPERS_KEY)) {
+        numOfMultiWorkunits = Math.max(numOfMultiWorkunits,
+            calculateNumMappersForPacker(state, kafkaWorkUnitPacker, 
kafkaTopicWorkunitMap));
       }
-      addTopicSpecificPropsToWorkUnits(workUnits, topicSpecificStateMap);
-      List<WorkUnit> workUnitList = kafkaWorkUnitPacker.pack(workUnits, 
numOfMultiWorkunits);
+
+      addTopicSpecificPropsToWorkUnits(kafkaTopicWorkunitMap, 
topicSpecificStateMap);
+      List<WorkUnit> workUnitList = 
kafkaWorkUnitPacker.pack(kafkaTopicWorkunitMap, numOfMultiWorkunits);
       setLimiterReportKeyListToWorkUnits(workUnitList, 
getLimiterExtractorReportKeys());
       return workUnitList;
     } catch (InstantiationException | IllegalAccessException | 
ClassNotFoundException e) {
@@ -305,6 +337,7 @@ public abstract class KafkaSource<S, D> extends 
EventBasedSource<S, D> {
         LOG.error("Exception {} encountered closing GobblinKafkaConsumerClient 
", t);
       }
     }
+
   }
 
   protected void populateClientPool(int count,
@@ -377,10 +410,27 @@ public abstract class KafkaSource<S, D> extends 
EventBasedSource<S, D> {
     }
   }
 
+  //determine the number of mappers/containers for workunit packer
+  private int calculateNumMappersForPacker(SourceState state,
+      KafkaWorkUnitPacker kafkaWorkUnitPacker, Map<String, List<WorkUnit>> 
workUnits) {
+    int maxMapperNum =
+        state.getPropAsInt(ConfigurationKeys.MR_JOB_MAX_MAPPERS_KEY, 
ConfigurationKeys.DEFAULT_MR_JOB_MAX_MAPPERS);
+    int numContainers = maxMapperNum;
+    if(state.contains(ConfigurationKeys.MR_TARGET_MAPPER_SIZE)) {
+      double totalEstDataSize = 
kafkaWorkUnitPacker.setWorkUnitEstSizes(workUnits);
+      LOG.info(String.format("The total estimated data size is %.2f", 
totalEstDataSize));
+      double targetMapperSize = 
state.getPropAsDouble(ConfigurationKeys.MR_TARGET_MAPPER_SIZE);
+      numContainers = (int) (totalEstDataSize / targetMapperSize) + 1;
+      numContainers = Math.min(numContainers, maxMapperNum);
+    }
+    return numContainers;
+  }
+
   /*
    * This function need to be thread safe since it is called in the Runnable
    */
-  private List<WorkUnit> getWorkUnitsForTopic(KafkaTopic topic, SourceState 
state, Optional<State> topicSpecificState) {
+  private List<WorkUnit> getWorkUnitsForTopic(KafkaTopic topic, SourceState 
state,
+      Optional<State> topicSpecificState, Optional<Set<Integer>> 
filteredPartitions) {
     Timer.Context context = 
this.metricContext.timer("isTopicQualifiedTimer").time();
     boolean topicQualified = isTopicQualified(topic);
     context.close();
@@ -388,6 +438,9 @@ public abstract class KafkaSource<S, D> extends 
EventBasedSource<S, D> {
     List<WorkUnit> workUnits = Lists.newArrayList();
     List<KafkaPartition> topicPartitions = topic.getPartitions();
     for (KafkaPartition partition : topicPartitions) {
+      if(filteredPartitions.isPresent() && 
!filteredPartitions.get().contains(partition.getId())) {
+        continue;
+      }
       WorkUnit workUnit = getWorkUnitForTopicPartition(partition, state, 
topicSpecificState);
       if (workUnit != null) {
         // For disqualified topics, for each of its workunits set the high 
watermark to be the same
@@ -895,13 +948,20 @@ public abstract class KafkaSource<S, D> extends 
EventBasedSource<S, D> {
     private final SourceState state;
     private final Optional<State> topicSpecificState;
     private final Map<String, List<WorkUnit>> allTopicWorkUnits;
+    private final Optional<Set<Integer>> filteredPartitionsId;
 
     WorkUnitCreator(KafkaTopic topic, SourceState state, Optional<State> 
topicSpecificState,
         Map<String, List<WorkUnit>> workUnits) {
+      this(topic, state, topicSpecificState, workUnits, Optional.absent());
+    }
+
+    WorkUnitCreator(KafkaTopic topic, SourceState state, Optional<State> 
topicSpecificState,
+        Map<String, List<WorkUnit>> workUnits, Optional<Set<Integer>> 
filteredPartitionsId) {
       this.topic = topic;
       this.state = state;
       this.topicSpecificState = topicSpecificState;
       this.allTopicWorkUnits = workUnits;
+      this.filteredPartitionsId = filteredPartitionsId;
     }
 
     @Override
@@ -917,7 +977,7 @@ public abstract class KafkaSource<S, D> extends 
EventBasedSource<S, D> {
         }
 
         this.allTopicWorkUnits.put(this.topic.getName(),
-            KafkaSource.this.getWorkUnitsForTopic(this.topic, this.state, 
this.topicSpecificState));
+            KafkaSource.this.getWorkUnitsForTopic(this.topic, this.state, 
this.topicSpecificState, this.filteredPartitionsId));
       } catch (Throwable t) {
         LOG.error("Caught error in creating work unit for " + 
this.topic.getName(), t);
         throw new RuntimeException(t);
@@ -930,4 +990,4 @@ public abstract class KafkaSource<S, D> extends 
EventBasedSource<S, D> {
       }
     }
   }
-}
+}
\ No newline at end of file
diff --git 
a/gobblin-modules/gobblin-kafka-common/src/test/java/org/apache/gobblin/source/extractor/extract/kafka/KafkaSourceTest.java
 
b/gobblin-modules/gobblin-kafka-common/src/test/java/org/apache/gobblin/source/extractor/extract/kafka/KafkaSourceTest.java
index c26872e1c..fff8581fc 100644
--- 
a/gobblin-modules/gobblin-kafka-common/src/test/java/org/apache/gobblin/source/extractor/extract/kafka/KafkaSourceTest.java
+++ 
b/gobblin-modules/gobblin-kafka-common/src/test/java/org/apache/gobblin/source/extractor/extract/kafka/KafkaSourceTest.java
@@ -17,17 +17,27 @@
 
 package org.apache.gobblin.source.extractor.extract.kafka;
 
+import com.google.common.base.Optional;
+import com.typesafe.config.Config;
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collections;
+import java.util.HashMap;
 import java.util.Iterator;
+import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
 import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 
 import org.apache.commons.collections.CollectionUtils;
+import org.apache.gobblin.annotation.Alias;
+import org.apache.gobblin.configuration.ConfigurationKeys;
 import 
org.apache.gobblin.source.extractor.extract.kafka.validator.TopicNameValidator;
 import 
org.apache.gobblin.source.extractor.extract.kafka.validator.TopicValidators;
+import 
org.apache.gobblin.source.extractor.extract.kafka.workunit.packer.KafkaWorkUnitPacker;
+import org.apache.gobblin.source.workunit.MultiWorkUnit;
+import org.apache.gobblin.source.workunit.WorkUnit;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
@@ -38,8 +48,54 @@ import org.apache.gobblin.kafka.client.KafkaConsumerRecord;
 import org.apache.gobblin.source.extractor.Extractor;
 import org.apache.gobblin.util.DatasetFilterUtils;
 
+import static org.apache.gobblin.source.extractor.extract.kafka.KafkaSource.*;
+
 
 public class KafkaSourceTest {
+  private static List<String> testTopics =Arrays.asList(
+      "topic1", "topic2", "topic3");
+
+  @Test
+  public void testGetWorkunits() {
+    TestKafkaClient testKafkaClient = new TestKafkaClient();
+    testKafkaClient.testTopics = testTopics;
+    SourceState state = new SourceState();
+    state.setProp(ConfigurationKeys.WRITER_OUTPUT_DIR, "TestPath");
+    state.setProp(KafkaWorkUnitPacker.KAFKA_WORKUNIT_PACKER_TYPE, 
KafkaWorkUnitPacker.PackerType.CUSTOM);
+    state.setProp(KafkaWorkUnitPacker.KAFKA_WORKUNIT_PACKER_CUSTOMIZED_TYPE, 
"org.apache.gobblin.source.extractor.extract.kafka.workunit.packer.KafkaTopicGroupingWorkUnitPacker");
+    state.setProp(GOBBLIN_KAFKA_CONSUMER_CLIENT_FACTORY_CLASS, 
"MockTestKafkaConsumerClientFactory");
+    TestKafkaSource testKafkaSource = new TestKafkaSource(testKafkaClient);
+    List<WorkUnit> workUnits = testKafkaSource.getWorkunits(state);
+
+    validatePartitionNumWithinWorkUnits(workUnits, 48);
+
+  }
+
+  @Test
+  public void testGetWorkunitsForFilteredPartitions() {
+    TestKafkaClient testKafkaClient = new TestKafkaClient();
+    List<String> allTopics = testTopics;
+    Map<String, List<Integer>> filteredTopicPartitionMap = new HashMap<>();
+    filteredTopicPartitionMap.put(allTopics.get(0), new LinkedList<>());
+    filteredTopicPartitionMap.put(allTopics.get(1), new LinkedList<>());
+    filteredTopicPartitionMap.put(allTopics.get(2), new LinkedList<>());
+    filteredTopicPartitionMap.get(allTopics.get(0)).addAll(Arrays.asList(0, 
11));
+    filteredTopicPartitionMap.get(allTopics.get(1)).addAll(Arrays.asList(2, 8, 
10));
+    filteredTopicPartitionMap.get(allTopics.get(2)).addAll(Arrays.asList(1, 3, 
5, 7));
+
+    testKafkaClient.testTopics = allTopics;
+    SourceState state = new SourceState();
+    state.setProp(ConfigurationKeys.WRITER_OUTPUT_DIR, "TestPath");
+    state.setProp(GOBBLIN_KAFKA_CONSUMER_CLIENT_FACTORY_CLASS, 
"MockTestKafkaConsumerClientFactory");
+    TestKafkaSource testKafkaSource = new TestKafkaSource(testKafkaClient);
+    List<WorkUnit> workUnits = 
testKafkaSource.getWorkunitsForFilteredPartitions(state, 
Optional.of(filteredTopicPartitionMap), Optional.of(3));
+    validatePartitionNumWithinWorkUnits(workUnits, 9);
+
+    state.setProp(KafkaWorkUnitPacker.KAFKA_WORKUNIT_PACKER_TYPE, 
KafkaWorkUnitPacker.PackerType.CUSTOM);
+    state.setProp(KafkaWorkUnitPacker.KAFKA_WORKUNIT_PACKER_CUSTOMIZED_TYPE, 
"org.apache.gobblin.source.extractor.extract.kafka.workunit.packer.KafkaTopicGroupingWorkUnitPacker");
+    workUnits = testKafkaSource.getWorkunitsForFilteredPartitions(state, 
Optional.of(filteredTopicPartitionMap), Optional.of(1));
+    validatePartitionNumWithinWorkUnits(workUnits, 9);
+  }
 
   @Test
   public void testGetFilteredTopics() {
@@ -89,12 +145,60 @@ public class KafkaSourceTest {
         toKafkaTopicList(allTopics.subList(0, 3))));
   }
 
-  public List<KafkaTopic> toKafkaTopicList(List<String> topicNames) {
-    return topicNames.stream().map(topicName -> new KafkaTopic(topicName, 
Collections.emptyList())).collect(Collectors.toList());
+  public static List<KafkaPartition> creatPartitions(String topicName, int 
partitionNum) {
+    List<KafkaPartition> partitions = new ArrayList<>(partitionNum);
+    for(int i = 0; i < partitionNum; i++ ) {
+      partitions.add(new 
KafkaPartition.Builder().withTopicName(topicName).withId(i).withLeaderHostAndPort("test").withLeaderId(1).build());
+    }
+    return partitions;
+  }
+
+  public static List<KafkaPartition> getPartitionFromWorkUnit(WorkUnit 
workUnit) {
+    List<KafkaPartition> topicPartitions = new ArrayList<>();
+    if(workUnit instanceof MultiWorkUnit) {
+      for(WorkUnit wu : ((MultiWorkUnit) workUnit).getWorkUnits()) {
+        topicPartitions.addAll(getPartitionFromWorkUnit(wu));
+      }
+    }else {
+      int i = 0;
+      String partitionIdProp = KafkaSource.PARTITION_ID + "." + i;
+      while (workUnit.getProp(partitionIdProp) != null) {
+        int partitionId = workUnit.getPropAsInt(partitionIdProp);
+        KafkaPartition topicPartition =
+            new 
KafkaPartition.Builder().withTopicName(workUnit.getProp(KafkaSource.TOPIC_NAME)).withId(partitionId).build();
+        topicPartitions.add(topicPartition);
+        i++;
+        partitionIdProp = KafkaSource.PARTITION_ID + "." + i;
+      }
+    }
+    return topicPartitions;
+  }
+
+
+  public static List<KafkaTopic> toKafkaTopicList(List<String> topicNames) {
+    return topicNames.stream().map(topicName -> new KafkaTopic(topicName, 
creatPartitions(topicName, 16))).collect(Collectors.toList());
+  }
+
+  private void validatePartitionNumWithinWorkUnits(List<WorkUnit> workUnits, 
int expectPartitionNum) {
+    List<KafkaPartition> partitionList = new ArrayList<>();
+    for(WorkUnit workUnit : workUnits) {
+      partitionList.addAll(getPartitionFromWorkUnit(workUnit));
+    }
+    Assert.assertEquals(partitionList.size(), expectPartitionNum);
+  }
+
+  @Alias("MockTestKafkaConsumerClientFactory")
+  public static class MockTestKafkaConsumerClientFactory
+      implements GobblinKafkaConsumerClient.GobblinKafkaConsumerClientFactory {
+
+    @Override
+    public GobblinKafkaConsumerClient create(Config config) {
+      return new TestKafkaClient();
+    }
   }
 
-  private class TestKafkaClient implements GobblinKafkaConsumerClient {
-    List<String> testTopics;
+  public static class TestKafkaClient implements GobblinKafkaConsumerClient {
+    List<String> testTopics = KafkaSourceTest.testTopics;
 
     @Override
     public List<KafkaTopic> getFilteredTopics(List<Pattern> blacklist, 
List<Pattern> whitelist) {

Reply via email to