This is an automated email from the ASF dual-hosted git repository.

wlo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/gobblin.git


The following commit(s) were added to refs/heads/master by this push:
     new 16389dc75 [GOBBLIN-1904] Refactoring SalesforceSource class for better code organization and testability (#3768)
16389dc75 is described below

commit 16389dc751cebb985c75faf97b35dd6f51c13653
Author: Gautam Kumar <[email protected]>
AuthorDate: Wed Sep 13 09:29:24 2023 +0530

    [GOBBLIN-1904] Refactoring SalesforceSource class for better code organization and testability (#3768)
    
    * Optimizing high watermark metadata query for SFDC
    
    * Enabling support for EarlyStop in single-partition mode
    
    * Refactoring SalesforceSource class for better code organization and testability
    
    - In the current form, it's very difficult to add tests for changes, as we need to handle all sorts of calls.
    - Moved the code related to histogram calculation to a new class, SalesforceHistogramService. This allows us to mock everything that happens in that class, making tests easy to write.
    - Added a test for the new change introduced as part of https://github.com/apache/gobblin/pull/3766
    
    * Refactoring SalesforceSource class for better code organization and 
testability
    
    - In the current form, it's very difficult to add tests for changes as we 
need to handle all sorts of calls.
    - Moved the code related to histogram calculation to a new class - 
SalesforceHistogramService. This allows us to mock all that happens in this 
class, making writing tests easy.
    - Added a test for the new change introduced as part of 
https://github.com/apache/gobblin/pull/3766
    
    * Addressed comments
    
    * Addressed comments to pass along the connector instance instead of creating a new one in the SalesforceHistogramService class. Also (auto)reorganized imports.
    
    ---------
    
    Co-authored-by: Gautam Kumar <[email protected]>
---
 gobblin-salesforce/build.gradle                    |   1 +
 .../org/apache/gobblin/salesforce/Histogram.java   |  52 +++
 .../apache/gobblin/salesforce/HistogramGroup.java  |  33 ++
 .../salesforce/SalesforceHistogramService.java     | 373 ++++++++++++++++
 .../gobblin/salesforce/SalesforceSource.java       | 470 +++------------------
 .../gobblin/salesforce/SalesforceSourceTest.java   |  86 +++-
 6 files changed, 600 insertions(+), 415 deletions(-)
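
 For orientation, the injection seam this change introduces can be exercised as in the
 minimal sketch below (Java, inside the org.apache.gobblin.salesforce test package, using
 the Mockito dependency added to build.gradle; the empty canned histogram is illustrative
 only):

    // Stub out live SFDC histogram computation (mirrors SalesforceSourceTest below).
    SalesforceHistogramService histogramService = Mockito.mock(SalesforceHistogramService.class);
    Mockito.when(histogramService.getHistogram(Mockito.anyString(), Mockito.anyString(),
            Mockito.any(SourceState.class), Mockito.any(Partition.class)))
        .thenReturn(new Histogram());  // canned histogram; no REST calls issued

    // The new package-private constructor injects the service, so work-unit
    // generation can be tested without a Salesforce connection.
    SalesforceSource source = new SalesforceSource(histogramService);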

diff --git a/gobblin-salesforce/build.gradle b/gobblin-salesforce/build.gradle
index 0e663c4a6..b42b751ba 100644
--- a/gobblin-salesforce/build.gradle
+++ b/gobblin-salesforce/build.gradle
@@ -37,6 +37,7 @@ dependencies {
     compile externalDependency.salesforcePartner
 
     testCompile externalDependency.testng
+    testCompile externalDependency.mockito
 }
 
 configurations {
diff --git a/gobblin-salesforce/src/main/java/org/apache/gobblin/salesforce/Histogram.java b/gobblin-salesforce/src/main/java/org/apache/gobblin/salesforce/Histogram.java
new file mode 100644
index 000000000..bb0350392
--- /dev/null
+++ b/gobblin-salesforce/src/main/java/org/apache/gobblin/salesforce/Histogram.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gobblin.salesforce;
+
+import java.util.ArrayList;
+import java.util.List;
+import lombok.Getter;
+
+
+@Getter
+public class Histogram {
+  private long totalRecordCount;
+  private final List<HistogramGroup> groups;
+
+  Histogram() {
+    totalRecordCount = 0;
+    groups = new ArrayList<>();
+  }
+
+  void add(HistogramGroup group) {
+    groups.add(group);
+    totalRecordCount += group.getCount();
+  }
+
+  void add(Histogram histogram) {
+    groups.addAll(histogram.getGroups());
+    totalRecordCount += histogram.totalRecordCount;
+  }
+
+  HistogramGroup get(int idx) {
+    return this.groups.get(idx);
+  }
+
+  @Override
+  public String toString() {
+    return groups.toString();
+  }
+}
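
 A quick usage sketch of the extracted class (bucket keys follow the same
 yyyy-MM-dd-HH:mm:ss format used by the tests; the values are illustrative):

    // Each HistogramGroup maps a bucket start time to a record count;
    // Histogram.add(...) keeps totalRecordCount in sync with the groups.
    Histogram histogram = new Histogram();
    histogram.add(new HistogramGroup("2014-02-13-00:00:00", 3));
    histogram.add(new HistogramGroup("2014-04-15-00:00:00", 1));
    assert histogram.getTotalRecordCount() == 4;
    assert histogram.get(0).getCount() == 3;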
diff --git a/gobblin-salesforce/src/main/java/org/apache/gobblin/salesforce/HistogramGroup.java b/gobblin-salesforce/src/main/java/org/apache/gobblin/salesforce/HistogramGroup.java
new file mode 100644
index 000000000..f43d61060
--- /dev/null
+++ b/gobblin-salesforce/src/main/java/org/apache/gobblin/salesforce/HistogramGroup.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gobblin.salesforce;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+
+@Getter
+@AllArgsConstructor
+class HistogramGroup {
+  private final String key;
+  private final int count;
+
+  @Override
+  public String toString() {
+    return key + ":" + count;
+  }
+}
\ No newline at end of file
diff --git a/gobblin-salesforce/src/main/java/org/apache/gobblin/salesforce/SalesforceHistogramService.java b/gobblin-salesforce/src/main/java/org/apache/gobblin/salesforce/SalesforceHistogramService.java
new file mode 100644
index 000000000..ec09da57a
--- /dev/null
+++ b/gobblin-salesforce/src/main/java/org/apache/gobblin/salesforce/SalesforceHistogramService.java
@@ -0,0 +1,373 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gobblin.salesforce;
+
+import com.google.common.math.DoubleMath;
+import com.google.gson.Gson;
+import com.google.gson.JsonArray;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import java.math.RoundingMode;
+import java.util.ArrayList;
+import java.util.Calendar;
+import java.util.Date;
+import java.util.GregorianCalendar;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import lombok.RequiredArgsConstructor;
+import lombok.SneakyThrows;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang3.text.StrSubstitutor;
+import org.apache.gobblin.configuration.SourceState;
+import org.apache.gobblin.source.extractor.DataRecordException;
+import org.apache.gobblin.source.extractor.exception.RestApiClientException;
+import org.apache.gobblin.source.extractor.exception.RestApiConnectionException;
+import org.apache.gobblin.source.extractor.exception.RestApiProcessingException;
+import org.apache.gobblin.source.extractor.extract.Command;
+import org.apache.gobblin.source.extractor.extract.CommandOutput;
+import org.apache.gobblin.source.extractor.extract.restapi.RestApiConnector;
+import org.apache.gobblin.source.extractor.partition.Partition;
+import org.apache.gobblin.source.extractor.partition.Partitioner;
+import org.apache.gobblin.source.extractor.utils.Utils;
+
+import static org.apache.gobblin.configuration.ConfigurationKeys.*;
+
+
+/**
+ * This class encapsulates everything related to histogram calculation for Salesforce. A histogram here refers to a
+ * mapping of number of records to be fetched by time intervals.
+ */
+@Slf4j
+public class SalesforceHistogramService {
+  private static final int MIN_SPLIT_TIME_MILLIS = 1000;
+  private static final String ZERO_TIME_SUFFIX = "-00:00:00";
+  private static final Gson GSON = new Gson();
+  // this is used to generate histogram buckets smaller than the target partition size to allow for more even
+  // packing of the generated partitions
+  private static final String PROBE_TARGET_RATIO = "salesforce.probeTargetRatio";
+  private static final double DEFAULT_PROBE_TARGET_RATIO = 0.60;
+  private static final String DYNAMIC_PROBING_LIMIT = "salesforce.dynamicProbingLimit";
+  private static final int DEFAULT_DYNAMIC_PROBING_LIMIT = 1000;
+
+  private static final String DAY_PARTITION_QUERY_TEMPLATE =
+      "SELECT count(${column}) cnt, DAY_ONLY(${column}) time FROM ${table} " + "WHERE ${column} ${greater} ${start}"
+          + " AND ${column} ${less} ${end} GROUP BY DAY_ONLY(${column}) ORDER BY DAY_ONLY(${column})";
+  private static final String PROBE_PARTITION_QUERY_TEMPLATE = "SELECT count(${column}) cnt FROM ${table} "
+      + "WHERE ${column} ${greater} ${start} AND ${column} ${less} ${end}";
+
+  protected SalesforceConnector salesforceConnector;
+  private final SfConfig sfConfig;
+
+  SalesforceHistogramService(SfConfig sfConfig, SalesforceConnector connector) {
+    this.sfConfig = sfConfig;
+    salesforceConnector = connector;
+  }
+
+  /**
+   * Generate the histogram
+   */
+  Histogram getHistogram(String entity, String watermarkColumn, SourceState state,
+      Partition partition) {
+
+    try {
+      if (!salesforceConnector.connect()) {
+        throw new RuntimeException("Failed to connect.");
+      }
+    } catch (RestApiConnectionException e) {
+      throw new RuntimeException("Failed to connect.", e);
+    }
+
+    Histogram histogram = getHistogramByDayBucketing(salesforceConnector, entity, watermarkColumn, partition);
+
+    // exchange the first histogram group key with the global low watermark to ensure that the low watermark is captured
+    // in the range of generated partitions
+    HistogramGroup firstGroup = histogram.get(0);
+    Date lwmDate = Utils.toDate(partition.getLowWatermark(), Partitioner.WATERMARKTIMEFORMAT);
+    histogram.getGroups().set(0, new HistogramGroup(Utils.epochToDate(lwmDate.getTime(), SalesforceSource.SECONDS_FORMAT),
+        firstGroup.getCount()));
+
+    // refine the histogram
+    if (state.getPropAsBoolean(SalesforceSource.ENABLE_DYNAMIC_PROBING)) {
+      histogram = getRefinedHistogram(salesforceConnector, entity, watermarkColumn, state, partition, histogram);
+    }
+
+    return histogram;
+  }
+
+  /**
+   * Get a histogram with day granularity buckets.
+   */
+  private Histogram getHistogramByDayBucketing(SalesforceConnector connector, String entity, String watermarkColumn,
+      Partition partition) {
+    Histogram histogram = new Histogram();
+
+    Calendar calendar = new GregorianCalendar();
+    Date startDate = Utils.toDate(partition.getLowWatermark(), Partitioner.WATERMARKTIMEFORMAT);
+    calendar.setTime(startDate);
+    int startYear = calendar.get(Calendar.YEAR);
+    String lowWatermarkDate = Utils.dateToString(startDate, SalesforceExtractor.SALESFORCE_TIMESTAMP_FORMAT);
+
+    Date endDate = Utils.toDate(partition.getHighWatermark(), Partitioner.WATERMARKTIMEFORMAT);
+    calendar.setTime(endDate);
+    int endYear = calendar.get(Calendar.YEAR);
+    String highWatermarkDate = Utils.dateToString(endDate, SalesforceExtractor.SALESFORCE_TIMESTAMP_FORMAT);
+
+    Map<String, String> values = new HashMap<>();
+    values.put("table", entity);
+    values.put("column", watermarkColumn);
+    StrSubstitutor sub = new StrSubstitutor(values);
+
+    for (int year = startYear; year <= endYear; year++) {
+      if (year == startYear) {
+        values.put("start", lowWatermarkDate);
+        values.put("greater", partition.isLowWatermarkInclusive() ? ">=" : 
">");
+      } else {
+        values.put("start", getDateString(year));
+        values.put("greater", ">=");
+      }
+
+      if (year == endYear) {
+        values.put("end", highWatermarkDate);
+        values.put("less", partition.isHighWatermarkInclusive() ? "<=" : "<");
+      } else {
+        values.put("end", getDateString(year + 1));
+        values.put("less", "<");
+      }
+
+      String query = sub.replace(DAY_PARTITION_QUERY_TEMPLATE);
+      log.info("Histogram query: " + query);
+
+      histogram.add(parseDayBucketingHistogram(getRecordsForQuery(connector, query)));
+    }
+
+    return histogram;
+  }
+
+  /**
+   * Refine the histogram by probing to split large buckets
+   * @return the refined histogram
+   */
+  private Histogram getRefinedHistogram(SalesforceConnector connector, String entity, String watermarkColumn,
+      SourceState state, Partition partition, Histogram histogram) {
+    final int maxPartitions = state.getPropAsInt(SOURCE_MAX_NUMBER_OF_PARTITIONS, DEFAULT_MAX_NUMBER_OF_PARTITIONS);
+    final int probeLimit = state.getPropAsInt(
+        DYNAMIC_PROBING_LIMIT, DEFAULT_DYNAMIC_PROBING_LIMIT);
+    final int minTargetPartitionSize = state.getPropAsInt(
+        SalesforceSource.MIN_TARGET_PARTITION_SIZE, SalesforceSource.DEFAULT_MIN_TARGET_PARTITION_SIZE);
+    final Histogram outputHistogram = new Histogram();
+    final double probeTargetRatio = state.getPropAsDouble(
+        PROBE_TARGET_RATIO, DEFAULT_PROBE_TARGET_RATIO);
+    final int bucketSizeLimit =
+        (int) (probeTargetRatio * computeTargetPartitionSize(histogram, minTargetPartitionSize, maxPartitions));
+
+    log.info("Refining histogram with bucket size limit {}.", bucketSizeLimit);
+
+    HistogramGroup currentGroup;
+    HistogramGroup nextGroup;
+    final TableCountProbingContext probingContext =
+        new TableCountProbingContext(connector, entity, watermarkColumn, bucketSizeLimit, probeLimit);
+
+    if (histogram.getGroups().isEmpty()) {
+      return outputHistogram;
+    }
+
+    // make a copy of the histogram list and add a dummy entry at the end to avoid special processing of the last group
+    List<HistogramGroup> list = new ArrayList(histogram.getGroups());
+    Date hwmDate = Utils.toDate(partition.getHighWatermark(), Partitioner.WATERMARKTIMEFORMAT);
+    list.add(new HistogramGroup(Utils.epochToDate(hwmDate.getTime(), SalesforceSource.SECONDS_FORMAT), 0));
+
+    for (int i = 0; i < list.size() - 1; i++) {
+      currentGroup = list.get(i);
+      nextGroup = list.get(i + 1);
+
+      // split the group if it is larger than the bucket size limit
+      if (currentGroup.getCount() > bucketSizeLimit) {
+        long startEpoch = Utils.toDate(currentGroup.getKey(), SalesforceSource.SECONDS_FORMAT).getTime();
+        long endEpoch = Utils.toDate(nextGroup.getKey(), SalesforceSource.SECONDS_FORMAT).getTime();
+
+        outputHistogram.add(getHistogramByProbing(probingContext, currentGroup.getCount(), startEpoch, endEpoch));
+      } else {
+        outputHistogram.add(currentGroup);
+      }
+    }
+
+    log.info("Executed {} probes for refining the histogram.", 
probingContext.probeCount);
+
+    // if the probe limit has been reached then print a warning
+    if (probingContext.probeCount >= probingContext.probeLimit) {
+      log.warn("Reached the probe limit");
+    }
+
+    return outputHistogram;
+  }
+
+
+  /**
+   * Get a histogram for the time range by probing to break down large buckets. Use count instead of
+   * querying if it is non-negative.
+   */
+  private Histogram getHistogramByProbing(TableCountProbingContext probingContext, int count, long startEpoch,
+      long endEpoch) {
+    Histogram histogram = new Histogram();
+
+    Map<String, String> values = new HashMap<>();
+    values.put("table", probingContext.entity);
+    values.put("column", probingContext.watermarkColumn);
+    values.put("greater", ">=");
+    values.put("less", "<");
+    StrSubstitutor sub = new StrSubstitutor(values);
+
+    getHistogramRecursively(probingContext, histogram, sub, values, count, startEpoch, endEpoch);
+
+    return histogram;
+  }
+
+  private String getDateString(int year) {
+    Calendar calendar = new GregorianCalendar();
+    calendar.clear();
+    calendar.set(Calendar.YEAR, year);
+    return Utils.dateToString(calendar.getTime(), SalesforceExtractor.SALESFORCE_TIMESTAMP_FORMAT);
+  }
+
+  /**
+   * Parse the query results into a {@link Histogram}
+   */
+  private Histogram parseDayBucketingHistogram(JsonArray records) {
+    log.info("Parse day-based histogram");
+
+    Histogram histogram = new Histogram();
+
+    Iterator<JsonElement> elements = records.iterator();
+    JsonObject element;
+
+    while (elements.hasNext()) {
+      element = elements.next().getAsJsonObject();
+      String time = element.get("time").getAsString() + ZERO_TIME_SUFFIX;
+      int count = element.get("cnt").getAsInt();
+
+      histogram.add(new HistogramGroup(time, count));
+    }
+
+    return histogram;
+  }
+
+  /**
+   * Split a histogram bucket along the midpoint if it is larger than the bucket size limit.
+   */
+  private void getHistogramRecursively(TableCountProbingContext probingContext, Histogram histogram, StrSubstitutor sub,
+      Map<String, String> values, int count, long startEpoch, long endEpoch) {
+    long midpointEpoch = startEpoch + (endEpoch - startEpoch) / 2;
+
+    // don't split further if small, above the probe limit, or less than 1 second difference between the midpoint and start
+    if (count <= probingContext.bucketSizeLimit
+        || probingContext.probeCount > probingContext.probeLimit
+        || (midpointEpoch - startEpoch < MIN_SPLIT_TIME_MILLIS)) {
+      histogram.add(new HistogramGroup(Utils.epochToDate(startEpoch, SalesforceSource.SECONDS_FORMAT), count));
+      return;
+    }
+
+    int countLeft = getCountForRange(probingContext, sub, values, startEpoch, midpointEpoch);
+
+    getHistogramRecursively(probingContext, histogram, sub, values, countLeft, startEpoch, midpointEpoch);
+    log.debug("Count {} for left partition {} to {}", countLeft, startEpoch, midpointEpoch);
+
+    int countRight = count - countLeft;
+
+    getHistogramRecursively(probingContext, histogram, sub, values, countRight, midpointEpoch, endEpoch);
+    log.debug("Count {} for right partition {} to {}", countRight, midpointEpoch, endEpoch);
+  }
+
+
+  /**
+   * Get a {@link JsonArray} containing the query results
+   */
+  @SneakyThrows
+  private JsonArray getRecordsForQuery(SalesforceConnector connector, String query) {
+    RestApiProcessingException exception = null;
+    for (int i = 0; i < sfConfig.restApiRetryLimit + 1; i++) {
+      try {
+        String soqlQuery = SalesforceExtractor.getSoqlUrl(query);
+        List<Command> commands = RestApiConnector.constructGetCommand(connector.getFullUri(soqlQuery));
+        CommandOutput<?, ?> response = connector.getResponse(commands);
+
+        String output;
+        Iterator<String> itr = (Iterator<String>) response.getResults().values().iterator();
+        if (itr.hasNext()) {
+          output = itr.next();
+        } else {
+          throw new DataRecordException("Failed to get data from salesforce; REST response has no output");
+        }
+
+        return GSON.fromJson(output, JsonObject.class).getAsJsonArray("records");
+      } catch (RestApiClientException | DataRecordException e) {
+        throw new RuntimeException("Fail to get data from salesforce", e);
+      } catch (RestApiProcessingException e) {
+        exception = e;
+        log.info("Caught RestApiProcessingException, retrying({}) rest query: 
{}", i+1, query);
+        Thread.sleep(sfConfig.restApiRetryInterval);
+      }
+    }
+    throw new RuntimeException("Fail to get data from salesforce", exception);
+  }
+
+  /**
+   * Get the row count for a time range
+   */
+  private int getCountForRange(TableCountProbingContext probingContext, StrSubstitutor sub,
+      Map<String, String> subValues, long startTime, long endTime) {
+    String startTimeStr = Utils.dateToString(new Date(startTime), SalesforceExtractor.SALESFORCE_TIMESTAMP_FORMAT);
+    String endTimeStr = Utils.dateToString(new Date(endTime), SalesforceExtractor.SALESFORCE_TIMESTAMP_FORMAT);
+
+    subValues.put("start", startTimeStr);
+    subValues.put("end", endTimeStr);
+
+    String query = sub.replace(PROBE_PARTITION_QUERY_TEMPLATE);
+
+    log.debug("Count query: " + query);
+    probingContext.probeCount++;
+
+    JsonArray records = getRecordsForQuery(probingContext.connector, query);
+    Iterator<JsonElement> elements = records.iterator();
+    JsonObject element = elements.next().getAsJsonObject();
+
+    return element.get("cnt").getAsInt();
+  }
+
+  /**
+   * Compute the target partition size.
+   */
+  private int computeTargetPartitionSize(Histogram histogram, int minTargetPartitionSize, int maxPartitions) {
+    return Math.max(minTargetPartitionSize,
+        DoubleMath.roundToInt((double) histogram.getTotalRecordCount() / maxPartitions, RoundingMode.CEILING));
+  }
+
+  /**
+   * Context for probing the table for row counts of a time range
+   */
+  @RequiredArgsConstructor
+  public static class TableCountProbingContext {
+    private final SalesforceConnector connector;
+    private final String entity;
+    private final String watermarkColumn;
+    private final int bucketSizeLimit;
+    private final int probeLimit;
+
+    private int probeCount = 0;
+  }
+}
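
 The core of the refinement above is the midpoint split in getHistogramRecursively: an
 oversized bucket is probed at its temporal midpoint and each half is refined recursively
 until it fits under the limit or spans less than one second. A self-contained sketch of
 that control flow (the probe-budget check is omitted for brevity, and the hypothetical
 RangeCounter stands in for the SOQL COUNT() probe issued by getCountForRange):

    interface RangeCounter {  // hypothetical stand-in for the REST count probe
      int countInRange(long startEpoch, long endEpoch);
    }

    static void refine(RangeCounter counter, java.util.List<long[]> buckets, int count,
        long startEpoch, long endEpoch, int bucketSizeLimit) {
      long midpointEpoch = startEpoch + (endEpoch - startEpoch) / 2;
      // stop splitting once the bucket fits or is under 1s wide (MIN_SPLIT_TIME_MILLIS)
      if (count <= bucketSizeLimit || midpointEpoch - startEpoch < 1000) {
        buckets.add(new long[]{startEpoch, count});  // emit {bucketStartEpoch, count}
        return;
      }
      // one COUNT() probe covers both halves: the right count follows by subtraction
      int countLeft = counter.countInRange(startEpoch, midpointEpoch);
      refine(counter, buckets, countLeft, startEpoch, midpointEpoch, bucketSizeLimit);
      refine(counter, buckets, count - countLeft, midpointEpoch, endEpoch, bucketSizeLimit);
    }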
diff --git a/gobblin-salesforce/src/main/java/org/apache/gobblin/salesforce/SalesforceSource.java b/gobblin-salesforce/src/main/java/org/apache/gobblin/salesforce/SalesforceSource.java
index 9d06402bf..aae7681e1 100644
--- a/gobblin-salesforce/src/main/java/org/apache/gobblin/salesforce/SalesforceSource.java
+++ b/gobblin-salesforce/src/main/java/org/apache/gobblin/salesforce/SalesforceSource.java
@@ -17,37 +17,29 @@
 
 package org.apache.gobblin.salesforce;
 
-import com.google.common.collect.Lists;
-import java.io.IOException;
-import java.math.RoundingMode;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Calendar;
-import java.util.Date;
-import java.util.GregorianCalendar;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
-import java.util.stream.Collectors;
-import lombok.SneakyThrows;
-import org.apache.commons.lang3.text.StrSubstitutor;
-import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
-
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Joiner;
 import com.google.common.base.Optional;
 import com.google.common.base.Strings;
 import com.google.common.base.Throwables;
+import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 import com.google.common.math.DoubleMath;
 import com.google.gson.Gson;
 import com.google.gson.JsonArray;
 import com.google.gson.JsonElement;
 import com.google.gson.JsonObject;
-
+import java.io.IOException;
+import java.math.RoundingMode;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
 import org.apache.gobblin.configuration.ConfigurationKeys;
 import org.apache.gobblin.configuration.SourceState;
 import org.apache.gobblin.configuration.State;
@@ -55,10 +47,9 @@ import org.apache.gobblin.configuration.WorkUnitState;
 import org.apache.gobblin.dataset.DatasetConstants;
 import org.apache.gobblin.dataset.DatasetDescriptor;
 import org.apache.gobblin.metrics.event.lineage.LineageInfo;
-import org.apache.gobblin.source.extractor.DataRecordException;
+import org.apache.gobblin.salesforce.SalesforceExtractor.BatchIdAndResultId;
 import org.apache.gobblin.source.extractor.Extractor;
 import org.apache.gobblin.source.extractor.exception.ExtractPrepareException;
-import org.apache.gobblin.source.extractor.exception.RestApiClientException;
 import org.apache.gobblin.source.extractor.exception.RestApiConnectionException;
 import org.apache.gobblin.source.extractor.exception.RestApiProcessingException;
 import org.apache.gobblin.source.extractor.extract.Command;
@@ -73,16 +64,9 @@ import org.apache.gobblin.source.extractor.watermark.WatermarkType;
 import org.apache.gobblin.source.workunit.Extract;
 import org.apache.gobblin.source.workunit.WorkUnit;
 
-import lombok.AllArgsConstructor;
-import lombok.Getter;
-import lombok.RequiredArgsConstructor;
-import lombok.extern.slf4j.Slf4j;
-
 import static org.apache.gobblin.configuration.ConfigurationKeys.*;
 import static org.apache.gobblin.salesforce.SalesforceConfigurationKeys.*;
 
-import org.apache.gobblin.salesforce.SalesforceExtractor.BatchIdAndResultId;
-
 /**
  * An implementation of {@link QueryBasedSource} for salesforce data sources.
  */
@@ -91,35 +75,23 @@ public class SalesforceSource extends QueryBasedSource<JsonArray, JsonElement> {
   public static final String USE_ALL_OBJECTS = "use.all.objects";
   public static final boolean DEFAULT_USE_ALL_OBJECTS = false;
 
-  private static final String ENABLE_DYNAMIC_PROBING = "salesforce.enableDynamicProbing";
-  private static final String DYNAMIC_PROBING_LIMIT = "salesforce.dynamicProbingLimit";
-  private static final int DEFAULT_DYNAMIC_PROBING_LIMIT = 1000;
-  private static final String MIN_TARGET_PARTITION_SIZE = "salesforce.minTargetPartitionSize";
-  private static final int DEFAULT_MIN_TARGET_PARTITION_SIZE = 250000;
-  // this is used to generate histogram buckets smaller than the target partition size to allow for more even
-  // packing of the generated partitions
-  private static final String PROBE_TARGET_RATIO = "salesforce.probeTargetRatio";
-  private static final double DEFAULT_PROBE_TARGET_RATIO = 0.60;
-  private static final int MIN_SPLIT_TIME_MILLIS = 1000;
-
-  private static final String DAY_PARTITION_QUERY_TEMPLATE =
-      "SELECT count(${column}) cnt, DAY_ONLY(${column}) time FROM ${table} " + "WHERE ${column} ${greater} ${start}"
-          + " AND ${column} ${less} ${end} GROUP BY DAY_ONLY(${column}) ORDER BY DAY_ONLY(${column})";
-  private static final String PROBE_PARTITION_QUERY_TEMPLATE = "SELECT count(${column}) cnt FROM ${table} "
-      + "WHERE ${column} ${greater} ${start} AND ${column} ${less} ${end}";
-
-  private static final String ENABLE_DYNAMIC_PARTITIONING = "salesforce.enableDynamicPartitioning";
-  private static final String EARLY_STOP_TOTAL_RECORDS_LIMIT = "salesforce.earlyStopTotalRecordsLimit";
+  @VisibleForTesting
+  static final String ENABLE_DYNAMIC_PROBING = "salesforce.enableDynamicProbing";
+  static final String MIN_TARGET_PARTITION_SIZE = "salesforce.minTargetPartitionSize";
+  static final int DEFAULT_MIN_TARGET_PARTITION_SIZE = 250000;
+
+  @VisibleForTesting
+  static final String ENABLE_DYNAMIC_PARTITIONING = "salesforce.enableDynamicPartitioning";
+  @VisibleForTesting
+  static final String EARLY_STOP_TOTAL_RECORDS_LIMIT = "salesforce.earlyStopTotalRecordsLimit";
   private static final long DEFAULT_EARLY_STOP_TOTAL_RECORDS_LIMIT = DEFAULT_MIN_TARGET_PARTITION_SIZE * 4;
 
-  private static final String SECONDS_FORMAT = "yyyy-MM-dd-HH:mm:ss";
-  private static final String ZERO_TIME_SUFFIX = "-00:00:00";
+  static final String SECONDS_FORMAT = "yyyy-MM-dd-HH:mm:ss";
 
-  private static final Gson GSON = new Gson();
   private boolean isEarlyStopped = false;
   protected SalesforceConnector salesforceConnector = null;
 
-  private SfConfig workUnitConf;
+  private SalesforceHistogramService salesforceHistogramService;
 
   public SalesforceSource() {
     this.lineageInfo = Optional.absent();
@@ -130,6 +102,12 @@ public class SalesforceSource extends QueryBasedSource<JsonArray, JsonElement> {
     this.lineageInfo = Optional.fromNullable(lineageInfo);
   }
 
+  @VisibleForTesting
+  SalesforceSource(SalesforceHistogramService salesforceHistogramService) {
+    this.lineageInfo = Optional.absent();
+    this.salesforceHistogramService = salesforceHistogramService;
+  }
+
   @Override
   public Extractor<JsonArray, JsonElement> getExtractor(WorkUnitState state) throws IOException {
     try {
@@ -155,8 +133,14 @@ public class SalesforceSource extends QueryBasedSource<JsonArray, JsonElement> {
   }
   @Override
   protected List<WorkUnit> generateWorkUnits(SourceEntity sourceEntity, SourceState state, long previousWatermark) {
-    List<WorkUnit> workUnits = null;
-    workUnitConf = new SfConfig(state.getProperties());
+    SalesforceConnector connector = getConnector(state);
+
+    SfConfig sfConfig = new SfConfig(state.getProperties());
+    if (salesforceHistogramService == null) {
+      salesforceHistogramService = new SalesforceHistogramService(sfConfig, connector);
+    }
+
+    List<WorkUnit> workUnits;
     String partitionType = state.getProp(SALESFORCE_PARTITION_TYPE, "");
     if (partitionType.equals("PK_CHUNKING")) {
      // pk-chunking only supports start-time by source.querybased.start.value, and does not support end-time.
@@ -164,10 +148,10 @@ public class SalesforceSource extends QueryBasedSource<JsonArray, JsonElement> {
      // we should only pk chunking based work units only in case of snapshot/full ingestion
      workUnits = generateWorkUnitsPkChunking(sourceEntity, state, previousWatermark);
    } else {
-      workUnits = generateWorkUnitsStrategy(sourceEntity, state, previousWatermark);
+      workUnits = generateWorkUnitsHelper(sourceEntity, state, previousWatermark);
     }
     log.info("====Generated {} workUnit(s)====", workUnits.size());
-    if (workUnitConf.partitionOnly) {
+    if (sfConfig.partitionOnly) {
       log.info("It is partitionOnly mode, return blank workUnit list");
       return new ArrayList<>();
     } else {
@@ -277,11 +261,13 @@ public class SalesforceSource extends QueryBasedSource<JsonArray, JsonElement> {
   }
 
   /**
-   *
+   * Generates {@link WorkUnit}s based on a bunch of config values like max number of partitions, early stop,
+   * dynamic partitioning, dynamic probing, etc.
    */
-  private List<WorkUnit> generateWorkUnitsStrategy(SourceEntity sourceEntity, SourceState state, long previousWatermark) {
-    Boolean disableSoft = state.getPropAsBoolean(SOURCE_QUERYBASED_SALESFORCE_IS_SOFT_DELETES_PULL_DISABLED, false);
-    log.info("disable soft delete pull: " + disableSoft);
+  @VisibleForTesting
+  List<WorkUnit> generateWorkUnitsHelper(SourceEntity sourceEntity, SourceState state, long previousWatermark) {
+    boolean isSoftDeletePullDisabled = state.getPropAsBoolean(SOURCE_QUERYBASED_SALESFORCE_IS_SOFT_DELETES_PULL_DISABLED, false);
+    log.info("disable soft delete pull: " + isSoftDeletePullDisabled);
     WatermarkType watermarkType = WatermarkType.valueOf(
         state.getProp(ConfigurationKeys.SOURCE_QUERYBASED_WATERMARK_TYPE, ConfigurationKeys.DEFAULT_WATERMARK_TYPE)
             .toUpperCase());
@@ -292,10 +278,12 @@ public class SalesforceSource extends QueryBasedSource<JsonArray, JsonElement> {
     int minTargetPartitionSize = state.getPropAsInt(MIN_TARGET_PARTITION_SIZE, DEFAULT_MIN_TARGET_PARTITION_SIZE);
 
     // Only support time related watermark
-    if (watermarkType == WatermarkType.SIMPLE || Strings.isNullOrEmpty(watermarkColumn) || !state.getPropAsBoolean(
-        ENABLE_DYNAMIC_PARTITIONING)) {
+    if (watermarkType == WatermarkType.SIMPLE
+        || Strings.isNullOrEmpty(watermarkColumn)
+        || !state.getPropAsBoolean(ENABLE_DYNAMIC_PARTITIONING)) {
       List<WorkUnit> workUnits = super.generateWorkUnits(sourceEntity, state, previousWatermark);
-      workUnits.forEach(x -> x.setProp(SOURCE_QUERYBASED_SALESFORCE_IS_SOFT_DELETES_PULL_DISABLED, disableSoft));
+      workUnits.forEach(workUnit ->
+          workUnit.setProp(SOURCE_QUERYBASED_SALESFORCE_IS_SOFT_DELETES_PULL_DISABLED, isSoftDeletePullDisabled));
       return workUnits;
     }
 
@@ -305,7 +293,8 @@ public class SalesforceSource extends QueryBasedSource<JsonArray, JsonElement> {
     }
 
     Partition partition = partitioner.getGlobalPartition(previousWatermark);
-    Histogram histogram = getHistogram(sourceEntity.getSourceEntityName(), watermarkColumn, state, partition);
+    Histogram histogram =
+        salesforceHistogramService.getHistogram(sourceEntity.getSourceEntityName(), watermarkColumn, state, partition);
 
     // we should look if the count is too big, cut off early if count exceeds the limit, or bucket size is too large
 
@@ -316,8 +305,8 @@ public class SalesforceSource extends QueryBasedSource<JsonArray, JsonElement> {
       histogramAdjust = new Histogram();
       for (HistogramGroup group : histogram.getGroups()) {
         histogramAdjust.add(group);
-        if (histogramAdjust.getTotalRecordCount() > state
-            .getPropAsLong(EARLY_STOP_TOTAL_RECORDS_LIMIT, DEFAULT_EARLY_STOP_TOTAL_RECORDS_LIMIT)) {
+        long earlyStopRecordLimit = state.getPropAsLong(EARLY_STOP_TOTAL_RECORDS_LIMIT, DEFAULT_EARLY_STOP_TOTAL_RECORDS_LIMIT);
+        if (histogramAdjust.getTotalRecordCount() > earlyStopRecordLimit) {
           break;
         }
       }
@@ -345,7 +334,8 @@ public class SalesforceSource extends QueryBasedSource<JsonArray, JsonElement> {
     state.setProp(Partitioner.IS_EARLY_STOPPED, isEarlyStopped);
 
     List<WorkUnit> workUnits = super.generateWorkUnits(sourceEntity, state, previousWatermark);
-    workUnits.stream().forEach(x -> x.setProp(SOURCE_QUERYBASED_SALESFORCE_IS_SOFT_DELETES_PULL_DISABLED, disableSoft));
+    workUnits.forEach(workUnit ->
+        workUnit.setProp(SOURCE_QUERYBASED_SALESFORCE_IS_SOFT_DELETES_PULL_DISABLED, isSoftDeletePullDisabled));
     return workUnits;
   }
 
@@ -353,12 +343,13 @@ public class SalesforceSource extends QueryBasedSource<JsonArray, JsonElement> {
     return state.getPropAsBoolean(ConfigurationKeys.SOURCE_EARLY_STOP_ENABLED, ConfigurationKeys.DEFAULT_SOURCE_EARLY_STOP_ENABLED);
   }
 
+  @VisibleForTesting
   String generateSpecifiedPartitions(Histogram histogram, int minTargetPartitionSize, int maxPartitions, long lowWatermark,
       long expectedHighWatermark) {
     int interval = computeTargetPartitionSize(histogram, minTargetPartitionSize, maxPartitions);
     int totalGroups = histogram.getGroups().size();
 
-    log.info("Histogram total record count: " + histogram.totalRecordCount);
+    log.info("Histogram total record count: " + histogram.getTotalRecordCount());
     log.info("Histogram total groups: " + totalGroups);
     log.info("maxPartitions: " + maxPartitions);
     log.info("interval: " + interval);
@@ -384,15 +375,15 @@ public class SalesforceSource extends QueryBasedSource<JsonArray, JsonElement> {
       * will have total size larger or equal to interval x 2. Hence, we are saturating all intervals (with original size)
       * without leaving any unused space in between. We could choose x3,x4... but it is not space efficient.
       */
-      if (count != 0 && count + group.count >= 2 * interval) {
+      if (count != 0 && count + group.getCount() >= 2 * interval) {
         // Summarize current group
         statistics.addValue(count);
         // A step-in start
        partitionPoints.add(Utils.toDateTimeFormat(group.getKey(), SECONDS_FORMAT, Partitioner.WATERMARKTIMEFORMAT));
-        count = group.count;
+        count = group.getCount();
       } else {
         // Add group into current partition
-        count += group.count;
+        count += group.getCount();
       }
 
       if (count >= interval) {
@@ -428,325 +419,7 @@ public class SalesforceSource extends QueryBasedSource<JsonArray, JsonElement> {
    */
   private int computeTargetPartitionSize(Histogram histogram, int minTargetPartitionSize, int maxPartitions) {
     return Math.max(minTargetPartitionSize,
-        DoubleMath.roundToInt((double) histogram.totalRecordCount / maxPartitions, RoundingMode.CEILING));
-  }
-
-  /**
-   * Get a {@link JsonArray} containing the query results
-   */
-  @SneakyThrows
-  private JsonArray getRecordsForQuery(SalesforceConnector connector, String query) {
-    RestApiProcessingException exception = null;
-    for (int i = 0; i < workUnitConf.restApiRetryLimit + 1; i++) {
-      try {
-        String soqlQuery = SalesforceExtractor.getSoqlUrl(query);
-        List<Command> commands = RestApiConnector.constructGetCommand(connector.getFullUri(soqlQuery));
-        CommandOutput<?, ?> response = connector.getResponse(commands);
-
-        String output;
-        Iterator<String> itr = (Iterator<String>) response.getResults().values().iterator();
-        if (itr.hasNext()) {
-          output = itr.next();
-        } else {
-          throw new DataRecordException("Failed to get data from salesforce; REST response has no output");
-        }
-
-        return GSON.fromJson(output, JsonObject.class).getAsJsonArray("records");
-      } catch (RestApiClientException | DataRecordException e) {
-        throw new RuntimeException("Fail to get data from salesforce", e);
-      } catch (RestApiProcessingException e) {
-        exception = e;
-        log.info("Caught RestApiProcessingException, retrying({}) rest query: {}", i+1, query);
-        Thread.sleep(workUnitConf.restApiRetryInterval);
-        continue;
-      }
-    }
-    throw new RuntimeException("Fail to get data from salesforce", exception);
-  }
-
-  /**
-   * Get the row count for a time range
-   */
-  private int getCountForRange(TableCountProbingContext probingContext, StrSubstitutor sub,
-      Map<String, String> subValues, long startTime, long endTime) {
-    String startTimeStr = Utils.dateToString(new Date(startTime), SalesforceExtractor.SALESFORCE_TIMESTAMP_FORMAT);
-    String endTimeStr = Utils.dateToString(new Date(endTime), SalesforceExtractor.SALESFORCE_TIMESTAMP_FORMAT);
-
-    subValues.put("start", startTimeStr);
-    subValues.put("end", endTimeStr);
-
-    String query = sub.replace(PROBE_PARTITION_QUERY_TEMPLATE);
-
-    log.debug("Count query: " + query);
-    probingContext.probeCount++;
-
-    JsonArray records = getRecordsForQuery(probingContext.connector, query);
-    Iterator<JsonElement> elements = records.iterator();
-    JsonObject element = elements.next().getAsJsonObject();
-
-    return element.get("cnt").getAsInt();
-  }
-
-  /**
-   * Split a histogram bucket along the midpoint if it is larger than the bucket size limit.
-   */
-  private void getHistogramRecursively(TableCountProbingContext probingContext, Histogram histogram, StrSubstitutor sub,
-      Map<String, String> values, int count, long startEpoch, long endEpoch) {
-    long midpointEpoch = startEpoch + (endEpoch - startEpoch) / 2;
-
-    // don't split further if small, above the probe limit, or less than 1 second difference between the midpoint and start
-    if (count <= probingContext.bucketSizeLimit
-        || probingContext.probeCount > probingContext.probeLimit
-        || (midpointEpoch - startEpoch < MIN_SPLIT_TIME_MILLIS)) {
-      histogram.add(new HistogramGroup(Utils.epochToDate(startEpoch, SECONDS_FORMAT), count));
-      return;
-    }
-
-    int countLeft = getCountForRange(probingContext, sub, values, startEpoch, midpointEpoch);
-
-    getHistogramRecursively(probingContext, histogram, sub, values, countLeft, startEpoch, midpointEpoch);
-    log.debug("Count {} for left partition {} to {}", countLeft, startEpoch, midpointEpoch);
-
-    int countRight = count - countLeft;
-
-    getHistogramRecursively(probingContext, histogram, sub, values, countRight, midpointEpoch, endEpoch);
-    log.debug("Count {} for right partition {} to {}", countRight, midpointEpoch, endEpoch);
-  }
-
-  /**
-   * Get a histogram for the time range by probing to break down large buckets. Use count instead of
-   * querying if it is non-negative.
-   */
-  private Histogram getHistogramByProbing(TableCountProbingContext probingContext, int count, long startEpoch,
-      long endEpoch) {
-    Histogram histogram = new Histogram();
-
-    Map<String, String> values = new HashMap<>();
-    values.put("table", probingContext.entity);
-    values.put("column", probingContext.watermarkColumn);
-    values.put("greater", ">=");
-    values.put("less", "<");
-    StrSubstitutor sub = new StrSubstitutor(values);
-
-    getHistogramRecursively(probingContext, histogram, sub, values, count, startEpoch, endEpoch);
-
-    return histogram;
-  }
-
-  /**
-   * Refine the histogram by probing to split large buckets
-   * @return the refined histogram
-   */
-  private Histogram getRefinedHistogram(SalesforceConnector connector, String entity, String watermarkColumn,
-      SourceState state, Partition partition, Histogram histogram) {
-    final int maxPartitions = state.getPropAsInt(SOURCE_MAX_NUMBER_OF_PARTITIONS, DEFAULT_MAX_NUMBER_OF_PARTITIONS);
-    final int probeLimit = state.getPropAsInt(DYNAMIC_PROBING_LIMIT, DEFAULT_DYNAMIC_PROBING_LIMIT);
-    final int minTargetPartitionSize = state.getPropAsInt(MIN_TARGET_PARTITION_SIZE, DEFAULT_MIN_TARGET_PARTITION_SIZE);
-    final Histogram outputHistogram = new Histogram();
-    final double probeTargetRatio = state.getPropAsDouble(PROBE_TARGET_RATIO, DEFAULT_PROBE_TARGET_RATIO);
-    final int bucketSizeLimit =
-        (int) (probeTargetRatio * computeTargetPartitionSize(histogram, minTargetPartitionSize, maxPartitions));
-
-    log.info("Refining histogram with bucket size limit {}.", bucketSizeLimit);
-
-    HistogramGroup currentGroup;
-    HistogramGroup nextGroup;
-    final TableCountProbingContext probingContext =
-        new TableCountProbingContext(connector, entity, watermarkColumn, bucketSizeLimit, probeLimit);
-
-    if (histogram.getGroups().isEmpty()) {
-      return outputHistogram;
-    }
-
-    // make a copy of the histogram list and add a dummy entry at the end to avoid special processing of the last group
-    List<HistogramGroup> list = new ArrayList(histogram.getGroups());
-    Date hwmDate = Utils.toDate(partition.getHighWatermark(), Partitioner.WATERMARKTIMEFORMAT);
-    list.add(new HistogramGroup(Utils.epochToDate(hwmDate.getTime(), SECONDS_FORMAT), 0));
-
-    for (int i = 0; i < list.size() - 1; i++) {
-      currentGroup = list.get(i);
-      nextGroup = list.get(i + 1);
-
-      // split the group if it is larger than the bucket size limit
-      if (currentGroup.count > bucketSizeLimit) {
-        long startEpoch = Utils.toDate(currentGroup.getKey(), SECONDS_FORMAT).getTime();
-        long endEpoch = Utils.toDate(nextGroup.getKey(), SECONDS_FORMAT).getTime();
-
-        outputHistogram.add(getHistogramByProbing(probingContext, currentGroup.count, startEpoch, endEpoch));
-      } else {
-        outputHistogram.add(currentGroup);
-      }
-    }
-
-    log.info("Executed {} probes for refining the histogram.", 
probingContext.probeCount);
-
-    // if the probe limit has been reached then print a warning
-    if (probingContext.probeCount >= probingContext.probeLimit) {
-      log.warn("Reached the probe limit");
-    }
-
-    return outputHistogram;
-  }
-
-  /**
-   * Get a histogram with day granularity buckets.
-   */
-  private Histogram getHistogramByDayBucketing(SalesforceConnector connector, String entity, String watermarkColumn,
-      Partition partition) {
-    Histogram histogram = new Histogram();
-
-    Calendar calendar = new GregorianCalendar();
-    Date startDate = Utils.toDate(partition.getLowWatermark(), Partitioner.WATERMARKTIMEFORMAT);
-    calendar.setTime(startDate);
-    int startYear = calendar.get(Calendar.YEAR);
-    String lowWatermarkDate = Utils.dateToString(startDate, SalesforceExtractor.SALESFORCE_TIMESTAMP_FORMAT);
-
-    Date endDate = Utils.toDate(partition.getHighWatermark(), Partitioner.WATERMARKTIMEFORMAT);
-    calendar.setTime(endDate);
-    int endYear = calendar.get(Calendar.YEAR);
-    String highWatermarkDate = Utils.dateToString(endDate, SalesforceExtractor.SALESFORCE_TIMESTAMP_FORMAT);
-
-    Map<String, String> values = new HashMap<>();
-    values.put("table", entity);
-    values.put("column", watermarkColumn);
-    StrSubstitutor sub = new StrSubstitutor(values);
-
-    for (int year = startYear; year <= endYear; year++) {
-      if (year == startYear) {
-        values.put("start", lowWatermarkDate);
-        values.put("greater", partition.isLowWatermarkInclusive() ? ">=" : 
">");
-      } else {
-        values.put("start", getDateString(year));
-        values.put("greater", ">=");
-      }
-
-      if (year == endYear) {
-        values.put("end", highWatermarkDate);
-        values.put("less", partition.isHighWatermarkInclusive() ? "<=" : "<");
-      } else {
-        values.put("end", getDateString(year + 1));
-        values.put("less", "<");
-      }
-
-      String query = sub.replace(DAY_PARTITION_QUERY_TEMPLATE);
-      log.info("Histogram query: " + query);
-
-      histogram.add(parseDayBucketingHistogram(getRecordsForQuery(connector, query)));
-    }
-
-    return histogram;
-  }
-
-  protected SalesforceConnector getConnector(State state) {
-    if (this.salesforceConnector == null) {
-      this.salesforceConnector = new SalesforceConnector(state);
-    }
-    return this.salesforceConnector;
-  }
-
-  /**
-   * Generate the histogram
-   */
-  private Histogram getHistogram(String entity, String watermarkColumn, SourceState state,
-      Partition partition) {
-    SalesforceConnector connector = getConnector(state);
-
-    try {
-      if (!connector.connect()) {
-        throw new RuntimeException("Failed to connect.");
-      }
-    } catch (RestApiConnectionException e) {
-      throw new RuntimeException("Failed to connect.", e);
-    }
-
-    Histogram histogram = getHistogramByDayBucketing(connector, entity, watermarkColumn, partition);
-
-    // exchange the first histogram group key with the global low watermark to ensure that the low watermark is captured
-    // in the range of generated partitions
-    HistogramGroup firstGroup = histogram.get(0);
-    Date lwmDate = Utils.toDate(partition.getLowWatermark(), Partitioner.WATERMARKTIMEFORMAT);
-    histogram.getGroups().set(0, new HistogramGroup(Utils.epochToDate(lwmDate.getTime(), SECONDS_FORMAT),
-        firstGroup.getCount()));
-
-    // refine the histogram
-    if (state.getPropAsBoolean(ENABLE_DYNAMIC_PROBING)) {
-      histogram = getRefinedHistogram(connector, entity, watermarkColumn, state, partition, histogram);
-    }
-
-    return histogram;
-  }
-
-  private String getDateString(int year) {
-    Calendar calendar = new GregorianCalendar();
-    calendar.clear();
-    calendar.set(Calendar.YEAR, year);
-    return Utils.dateToString(calendar.getTime(), SalesforceExtractor.SALESFORCE_TIMESTAMP_FORMAT);
-  }
-
-  /**
-   * Parse the query results into a {@link Histogram}
-   */
-  private Histogram parseDayBucketingHistogram(JsonArray records) {
-    log.info("Parse day-based histogram");
-
-    Histogram histogram = new Histogram();
-
-    Iterator<JsonElement> elements = records.iterator();
-    JsonObject element;
-
-    while (elements.hasNext()) {
-      element = elements.next().getAsJsonObject();
-      String time = element.get("time").getAsString() + ZERO_TIME_SUFFIX;
-      int count = element.get("cnt").getAsInt();
-
-      histogram.add(new HistogramGroup(time, count));
-    }
-
-    return histogram;
-  }
-
-  @AllArgsConstructor
-  static class HistogramGroup {
-    @Getter
-    private final String key;
-    @Getter
-    private final int count;
-
-    @Override
-    public String toString() {
-      return key + ":" + count;
-    }
-  }
-
-  static class Histogram {
-    @Getter
-    private long totalRecordCount;
-    @Getter
-    private List<HistogramGroup> groups;
-
-    Histogram() {
-      totalRecordCount = 0;
-      groups = new ArrayList<>();
-    }
-
-    void add(HistogramGroup group) {
-      groups.add(group);
-      totalRecordCount += group.count;
-    }
-
-    void add(Histogram histogram) {
-      groups.addAll(histogram.getGroups());
-      totalRecordCount += histogram.totalRecordCount;
-    }
-
-    HistogramGroup get(int idx) {
-      return this.groups.get(idx);
-    }
-
-    @Override
-    public String toString() {
-      return groups.toString();
-    }
+        DoubleMath.roundToInt((double) histogram.getTotalRecordCount() / maxPartitions, RoundingMode.CEILING));
   }
 
   protected Set<SourceEntity> getSourceEntities(State state) {
@@ -788,17 +461,10 @@ public class SalesforceSource extends QueryBasedSource<JsonArray, JsonElement> {
     return result;
   }
 
-  /**
-   * Context for probing the table for row counts of a time range
-   */
-  @RequiredArgsConstructor
-  private static class TableCountProbingContext {
-    private final SalesforceConnector connector;
-    private final String entity;
-    private final String watermarkColumn;
-    private final int bucketSizeLimit;
-    private final int probeLimit;
-
-    private int probeCount = 0;
+  protected SalesforceConnector getConnector(State state) {
+    if (this.salesforceConnector == null) {
+      this.salesforceConnector = new SalesforceConnector(state);
+    }
+    return this.salesforceConnector;
   }
 }
diff --git a/gobblin-salesforce/src/test/java/org/apache/gobblin/salesforce/SalesforceSourceTest.java b/gobblin-salesforce/src/test/java/org/apache/gobblin/salesforce/SalesforceSourceTest.java
index 422a93363..9e01b9c73 100644
--- a/gobblin-salesforce/src/test/java/org/apache/gobblin/salesforce/SalesforceSourceTest.java
+++ b/gobblin-salesforce/src/test/java/org/apache/gobblin/salesforce/SalesforceSourceTest.java
@@ -16,20 +16,22 @@
  */
 package org.apache.gobblin.salesforce;
 
-import java.util.List;
-
-import org.testng.Assert;
-import org.testng.annotations.Test;
-
+import com.google.gson.Gson;
 import com.typesafe.config.ConfigFactory;
-
+import java.util.HashMap;
+import java.util.List;
 import org.apache.gobblin.configuration.ConfigurationKeys;
 import org.apache.gobblin.configuration.SourceState;
 import org.apache.gobblin.metrics.event.lineage.LineageInfo;
 import org.apache.gobblin.source.extractor.extract.QueryBasedSource;
+import org.apache.gobblin.source.extractor.partition.Partition;
 import org.apache.gobblin.source.extractor.partition.Partitioner;
 import org.apache.gobblin.source.workunit.WorkUnit;
+import org.testng.Assert;
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Test;
 
+import static org.mockito.Mockito.*;
 
 public class SalesforceSourceTest {
   @Test
@@ -55,8 +57,8 @@ public class SalesforceSourceTest {
 
   @Test
   void testGenerateSpecifiedPartitionFromSinglePointHistogram() {
-    SalesforceSource.Histogram histogram = new SalesforceSource.Histogram();
-    histogram.add(new SalesforceSource.HistogramGroup("2014-02-13-00:00:00", 
10));
+    Histogram histogram = new Histogram();
+    histogram.add(new HistogramGroup("2014-02-13-00:00:00", 10));
     SalesforceSource source = new SalesforceSource();
 
     long expectedHighWatermark = 20170407152123L;
@@ -69,11 +71,7 @@ public class SalesforceSourceTest {
 
   @Test
   void testGenerateSpecifiedPartition() {
-    SalesforceSource.Histogram histogram = new SalesforceSource.Histogram();
-    for (String group: HISTOGRAM.split(", ")) {
-      String[] groupInfo = group.split("::");
-      histogram.add(new SalesforceSource.HistogramGroup(groupInfo[0], Integer.parseInt(groupInfo[1])));
-    }
+    Histogram histogram = getHistogram();
     SalesforceSource source = new SalesforceSource();
 
     long expectedHighWatermark = 20170407152123L;
@@ -84,5 +82,67 @@ public class SalesforceSourceTest {
     Assert.assertEquals(actualPartitions, expectedPartitions);
   }
 
+  @DataProvider
+  private Object[][] provideGenerateWorkUnitsHelperForSinglePartitionAndEarlyStopTestData() {
+    return new Object[][] {
+        {
+            1000L,  // earlyStopRecordCount
+            20140508000000L  // expectedHighWtm
+        },
+        {
+            10000L,
+            20150119000000L
+        },
+        {
+            100000L,
+            20170214000000L
+        },
+        {
+            1000000L,
+            20170301000000L
+        }
+    };
+  }
+  @Test(dataProvider = "provideGenerateWorkUnitsHelperForSinglePartitionAndEarlyStopTestData")
+  void testGenerateWorkUnitsHelperForSinglePartitionAndEarlyStop(long earlyStopRecordCount, long expectedHighWtm) {
+    QueryBasedSource.SourceEntity sourceEntity = QueryBasedSource.SourceEntity.fromSourceEntityName("contacts");
+    SourceState state = getDefaultSourceState();
+    state.setProp(ConfigurationKeys.SOURCE_MAX_NUMBER_OF_PARTITIONS, 1);
+    state.setProp(SalesforceSource.ENABLE_DYNAMIC_PARTITIONING, true);
+    state.setProp(ConfigurationKeys.SOURCE_EARLY_STOP_ENABLED, true);
+    state.setProp(SalesforceSource.ENABLE_DYNAMIC_PROBING, true);
+    state.setProp(SalesforceSource.EARLY_STOP_TOTAL_RECORDS_LIMIT, earlyStopRecordCount);
+    long previousWtm = 20140213000000L;
+
+    SalesforceHistogramService salesforceHistogramService = mock(SalesforceHistogramService.class);
+    String deltaFieldKey = state.getProp(ConfigurationKeys.EXTRACT_DELTA_FIELDS_KEY);
+    Partition partition = new Partitioner(state).getGlobalPartition(previousWtm);
+    when(salesforceHistogramService.getHistogram(sourceEntity.getSourceEntityName(), deltaFieldKey, state, partition))
+        .thenReturn(getHistogram());
+
+    List<WorkUnit> actualWorkUnits =  new SalesforceSource(salesforceHistogramService).generateWorkUnitsHelper(sourceEntity, state, previousWtm);
+    Assert.assertEquals(actualWorkUnits.size(), 1);
+    double actualHighWtm = (double) new Gson().fromJson(actualWorkUnits.get(0).getExpectedHighWatermark(), HashMap.class).get("value");
+    Assert.assertEquals(actualHighWtm, Double.parseDouble(String.valueOf(expectedHighWtm)));
+  }
+
+  private SourceState getDefaultSourceState() {
+    SourceState sourceState = new SourceState();
+    sourceState.setProp(ConfigurationKeys.EXTRACT_NAMESPACE_NAME_KEY, "salesforce");
+    sourceState.setProp(ConfigurationKeys.EXTRACT_TABLE_TYPE_KEY, "snapshot_append");
+    sourceState.setProp(ConfigurationKeys.SOURCE_QUERYBASED_EXTRACT_TYPE, "SNAPSHOT");
+    sourceState.setProp(ConfigurationKeys.EXTRACT_DELTA_FIELDS_KEY, "LastModifiedDate");
+    return sourceState;
+  }
+
+  private Histogram getHistogram() {
+    Histogram histogram = new Histogram();
+    for (String group: HISTOGRAM.split(", ")) {
+      String[] groupInfo = group.split("::");
+      histogram.add(new HistogramGroup(groupInfo[0], Integer.parseInt(groupInfo[1])));
+    }
+    return histogram;
+  }
+
  static final String HISTOGRAM = "2014-02-13-00:00:00::3, 2014-04-15-00:00:00::1, 2014-05-06-00:00:00::624, 2014-05-07-00:00:00::1497, 2014-05-08-00:00:00::10, 2014-05-18-00:00:00::3, 2014-05-19-00:00:00::2, 2014-05-20-00:00:00::1, 2014-05-21-00:00:00::8, 2014-05-26-00:00:00::2, 2014-05-28-00:00:00::1, 2014-05-31-00:00:00::1, 2014-06-02-00:00:00::1, 2014-06-03-00:00:00::1, 2014-06-04-00:00:00::1, 2014-06-10-00:00:00::2, 2014-06-12-00:00:00::1, 2014-06-23-00:00:00::1, 2014-06-24-00:00:00 [...]
 }
