chamikaramj commented on a change in pull request #15549:
URL: https://github.com/apache/beam/pull/15549#discussion_r718685572



##########
File path: 
sdks/java/io/redis/src/main/java/org/apache/beam/sdk/io/redis/RedisIO.java
##########
@@ -334,94 +336,88 @@ public void setup() {
     public void teardown() {
       jedis.close();
     }
-  }
-
-  private static class ReadKeysWithPattern extends BaseReadFn<String> {
 
-    ReadKeysWithPattern(RedisConnectionConfiguration connectionConfiguration) {
-      super(connectionConfiguration);
+    @GetInitialRestriction
+    public OffsetRange getInitialRestriction(@Element String pattern) {
+      return new OffsetRange(0, getKeyPatters(pattern).size());
     }
 
     @ProcessElement
-    public void processElement(ProcessContext c) {
-      ScanParams scanParams = new ScanParams();
-      scanParams.match(c.element());
-
-      String cursor = ScanParams.SCAN_POINTER_START;
-      boolean finished = false;
-      while (!finished) {
-        ScanResult<String> scanResult = jedis.scan(cursor, scanParams);
-        List<String> keys = scanResult.getResult();
-        for (String k : keys) {
-          c.output(k);
+    public void processElement(ProcessContext c, 
RestrictionTracker<OffsetRange, Long> tracker) {
+      String pattern = c.element();
+      List<String> keys = getKeyPatters(pattern);
+      List<String> bundle = new ArrayList<>();
+      for (long i = tracker.currentRestriction().getFrom();
+          i < tracker.currentRestriction().getTo();
+          i++) {
+        if (tracker.tryClaim(i)) {
+          bundle.add(keys.get((int) i));
         }
-        cursor = scanResult.getCursor();
-        if (cursor.equals(ScanParams.SCAN_POINTER_START)) {
-          finished = true;
+      }
+      if (bundle.size() > 0) {
+        List<KV<String, String>> kvs = fetchAndFlush(bundle);

Review comment:
       Loading the whole bundle into memory here could probably result in OOMs.

##########
File path: 
sdks/java/io/redis/src/main/java/org/apache/beam/sdk/io/redis/RedisIO.java
##########
@@ -334,94 +336,88 @@ public void setup() {
     public void teardown() {
       jedis.close();
     }
-  }
-
-  private static class ReadKeysWithPattern extends BaseReadFn<String> {
 
-    ReadKeysWithPattern(RedisConnectionConfiguration connectionConfiguration) {
-      super(connectionConfiguration);
+    @GetInitialRestriction
+    public OffsetRange getInitialRestriction(@Element String pattern) {
+      return new OffsetRange(0, getKeyPatters(pattern).size());
     }
 
     @ProcessElement
-    public void processElement(ProcessContext c) {
-      ScanParams scanParams = new ScanParams();
-      scanParams.match(c.element());
-
-      String cursor = ScanParams.SCAN_POINTER_START;
-      boolean finished = false;
-      while (!finished) {
-        ScanResult<String> scanResult = jedis.scan(cursor, scanParams);
-        List<String> keys = scanResult.getResult();
-        for (String k : keys) {
-          c.output(k);
+    public void processElement(ProcessContext c, 
RestrictionTracker<OffsetRange, Long> tracker) {
+      String pattern = c.element();
+      List<String> keys = getKeyPatters(pattern);
+      List<String> bundle = new ArrayList<>();
+      for (long i = tracker.currentRestriction().getFrom();
+          i < tracker.currentRestriction().getTo();
+          i++) {
+        if (tracker.tryClaim(i)) {
+          bundle.add(keys.get((int) i));
         }
-        cursor = scanResult.getCursor();
-        if (cursor.equals(ScanParams.SCAN_POINTER_START)) {
-          finished = true;
+      }
+      if (bundle.size() > 0) {
+        List<KV<String, String>> kvs = fetchAndFlush(bundle);
+        for (KV<String, String> kv : kvs) {
+          c.output(kv);
         }
       }
     }
-  }
-
-  /** A {@link DoFn} requesting Redis server to get key/value pairs. */
-  private static class ReadFn extends BaseReadFn<KV<String, String>> {
-    transient @Nullable Multimap<BoundedWindow, String> bundles = null;
-    @Nullable AtomicInteger batchCount = null;
-    private final int batchSize;
-
-    ReadFn(RedisConnectionConfiguration connectionConfiguration, int 
batchSize) {
-      super(connectionConfiguration);
-      this.batchSize = batchSize;
-    }
 
-    @StartBundle
-    public void startBundle() {
-      bundles = ArrayListMultimap.create();
-      batchCount = new AtomicInteger();
+    @SplitRestriction
+    public void split(@Restriction OffsetRange restriction, 
OutputReceiver<OffsetRange> out) {
+      for (OffsetRange offsetRange :
+          splitBlockWithLimit(restriction.getFrom(), restriction.getTo())) {
+        out.output(offsetRange);
+      }
     }
 
-    @ProcessElement
-    public void processElement(ProcessContext c, BoundedWindow window) {
-      String key = c.element();
-      bundles.put(window, key);
-      if (batchCount.incrementAndGet() > getBatchSize()) {
-        Multimap<BoundedWindow, KV<String, String>> kvs = fetchAndFlush();
-        for (BoundedWindow w : kvs.keySet()) {
-          for (KV<String, String> kv : kvs.get(w)) {
-            c.output(kv);
-          }
+    public ArrayList<OffsetRange> splitBlockWithLimit(long start, long end) {
+      ArrayList<OffsetRange> offsetList = new ArrayList<>();
+      long newStart = start;
+      long newEnd = start;
+      while (true) {
+        if (newStart + batchSize <= end) {
+          offsetList.add(new OffsetRange(newStart, newStart + batchSize));
+          newEnd = newStart + batchSize;
+          newStart = newStart + batchSize + 1;

Review comment:
       Seems like we are losing a record position here. If the previous split 
ends at (newStart + batchSize), the next split should start at that. Also please add 
unit tests that would cover this.

##########
File path: 
sdks/java/io/redis/src/main/java/org/apache/beam/sdk/io/redis/RedisIO.java
##########
@@ -334,94 +336,88 @@ public void setup() {
     public void teardown() {
       jedis.close();
     }
-  }
-
-  private static class ReadKeysWithPattern extends BaseReadFn<String> {
 
-    ReadKeysWithPattern(RedisConnectionConfiguration connectionConfiguration) {
-      super(connectionConfiguration);
+    @GetInitialRestriction
+    public OffsetRange getInitialRestriction(@Element String pattern) {
+      return new OffsetRange(0, getKeyPatters(pattern).size());
     }
 
     @ProcessElement
-    public void processElement(ProcessContext c) {
-      ScanParams scanParams = new ScanParams();
-      scanParams.match(c.element());
-
-      String cursor = ScanParams.SCAN_POINTER_START;
-      boolean finished = false;
-      while (!finished) {
-        ScanResult<String> scanResult = jedis.scan(cursor, scanParams);
-        List<String> keys = scanResult.getResult();
-        for (String k : keys) {
-          c.output(k);
+    public void processElement(ProcessContext c, 
RestrictionTracker<OffsetRange, Long> tracker) {

Review comment:
       Please check other source implementations and add similar tests related 
to splitting (and reading if anything is missing).
   Here are the tests for TextSource, for example (even though it's not an SDF).
   
https://github.com/apache/beam/blob/master/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOReadTest.java
   
   

##########
File path: 
sdks/java/io/redis/src/main/java/org/apache/beam/sdk/io/redis/RedisIO.java
##########
@@ -334,94 +336,88 @@ public void setup() {
     public void teardown() {
       jedis.close();
     }
-  }
-
-  private static class ReadKeysWithPattern extends BaseReadFn<String> {
 
-    ReadKeysWithPattern(RedisConnectionConfiguration connectionConfiguration) {
-      super(connectionConfiguration);
+    @GetInitialRestriction
+    public OffsetRange getInitialRestriction(@Element String pattern) {
+      return new OffsetRange(0, getKeyPatters(pattern).size());
     }
 
     @ProcessElement
-    public void processElement(ProcessContext c) {
-      ScanParams scanParams = new ScanParams();
-      scanParams.match(c.element());
-
-      String cursor = ScanParams.SCAN_POINTER_START;
-      boolean finished = false;
-      while (!finished) {
-        ScanResult<String> scanResult = jedis.scan(cursor, scanParams);
-        List<String> keys = scanResult.getResult();
-        for (String k : keys) {
-          c.output(k);
+    public void processElement(ProcessContext c, 
RestrictionTracker<OffsetRange, Long> tracker) {
+      String pattern = c.element();
+      List<String> keys = getKeyPatters(pattern);
+      List<String> bundle = new ArrayList<>();
+      for (long i = tracker.currentRestriction().getFrom();
+          i < tracker.currentRestriction().getTo();
+          i++) {
+        if (tracker.tryClaim(i)) {
+          bundle.add(keys.get((int) i));
         }
-        cursor = scanResult.getCursor();
-        if (cursor.equals(ScanParams.SCAN_POINTER_START)) {
-          finished = true;
+      }
+      if (bundle.size() > 0) {
+        List<KV<String, String>> kvs = fetchAndFlush(bundle);
+        for (KV<String, String> kv : kvs) {
+          c.output(kv);
         }
       }
     }
-  }
-
-  /** A {@link DoFn} requesting Redis server to get key/value pairs. */
-  private static class ReadFn extends BaseReadFn<KV<String, String>> {
-    transient @Nullable Multimap<BoundedWindow, String> bundles = null;
-    @Nullable AtomicInteger batchCount = null;
-    private final int batchSize;
-
-    ReadFn(RedisConnectionConfiguration connectionConfiguration, int 
batchSize) {
-      super(connectionConfiguration);
-      this.batchSize = batchSize;
-    }
 
-    @StartBundle
-    public void startBundle() {
-      bundles = ArrayListMultimap.create();
-      batchCount = new AtomicInteger();
+    @SplitRestriction
+    public void split(@Restriction OffsetRange restriction, 
OutputReceiver<OffsetRange> out) {
+      for (OffsetRange offsetRange :
+          splitBlockWithLimit(restriction.getFrom(), restriction.getTo())) {
+        out.output(offsetRange);
+      }
     }
 
-    @ProcessElement
-    public void processElement(ProcessContext c, BoundedWindow window) {
-      String key = c.element();
-      bundles.put(window, key);
-      if (batchCount.incrementAndGet() > getBatchSize()) {
-        Multimap<BoundedWindow, KV<String, String>> kvs = fetchAndFlush();
-        for (BoundedWindow w : kvs.keySet()) {
-          for (KV<String, String> kv : kvs.get(w)) {
-            c.output(kv);
-          }
+    public ArrayList<OffsetRange> splitBlockWithLimit(long start, long end) {
+      ArrayList<OffsetRange> offsetList = new ArrayList<>();
+      long newStart = start;
+      long newEnd = start;
+      while (true) {
+        if (newStart + batchSize <= end) {
+          offsetList.add(new OffsetRange(newStart, newStart + batchSize));
+          newEnd = newStart + batchSize;

Review comment:
       Probably simpler to just set correct newStart and newEnd values for the 
next batch and use that here.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to