n-oden commented on a change in pull request #15858:
URL: https://github.com/apache/beam/pull/15858#discussion_r766953848



##########
File path: sdks/java/io/redis/src/main/java/org/apache/beam/sdk/io/redis/RedisIO.java
##########
@@ -709,4 +720,146 @@ public void teardown() {
       }
     }
   }
+
+  /**
+   * A {@link PTransform} to write stream key pairs (https://redis.io/topics/streams-intro) to a
+   * Redis server.
+   */
+  @AutoValue
+  public abstract static class WriteStreams
+      extends PTransform<PCollection<KV<String, Map<String, String>>>, PDone> {
+
+    abstract @Nullable RedisConnectionConfiguration connectionConfiguration();
+
+    abstract @Nullable Long maxLen();
+
+    abstract boolean approximateTrim();
+
+    abstract Builder toBuilder();
+
+    @AutoValue.Builder
+    abstract static class Builder {
+
+      abstract Builder setConnectionConfiguration(
+          RedisConnectionConfiguration connectionConfiguration);
+
+      abstract Builder setMaxLen(Long maxLen);
+
+      abstract Builder setApproximateTrim(boolean approximateTrim);
+
+      abstract WriteStreams build();
+    }
+
+    public WriteStreams withEndpoint(String host, int port) {
+      checkArgument(host != null, "host can not be null");
+      checkArgument(port > 0, "port can not be negative or 0");
+      return toBuilder()
+          .setConnectionConfiguration(connectionConfiguration().withHost(host).withPort(port))
+          .build();
+    }
+
+    public WriteStreams withAuth(String auth) {
+      checkArgument(auth != null, "auth can not be null");
+      return toBuilder()
+          .setConnectionConfiguration(connectionConfiguration().withAuth(auth))
+          .build();
+    }
+
+    public WriteStreams withTimeout(int timeout) {
+      checkArgument(timeout >= 0, "timeout can not be negative");
+      return toBuilder()
+          .setConnectionConfiguration(connectionConfiguration().withTimeout(timeout))
+          .build();
+    }
+
+    public WriteStreams withConnectionConfiguration(RedisConnectionConfiguration connection) {
+      checkArgument(connection != null, "connection can not be null");
+      return toBuilder().setConnectionConfiguration(connection).build();
+    }
+
+    public WriteStreams withMaxLen(Long maxLen) {
+      checkArgument(maxLen >= 0L, "maxLen can not be negative if set");
+      return toBuilder().setMaxLen(maxLen).build();
+    }
+
+    public WriteStreams withApproximateTrim(boolean approximateTrim) {
+      return toBuilder().setApproximateTrim(approximateTrim).build();
+    }
+
+    @Override
+    public PDone expand(PCollection<KV<String, Map<String, String>>> input) {
+      checkArgument(connectionConfiguration() != null, "withConnectionConfiguration() is required");
+
+      input.apply(ParDo.of(new WriteStreamFn(this)));
+      return PDone.in(input.getPipeline());
+    }
+
+    private static class WriteStreamFn extends DoFn<KV<String, Map<String, String>>, Void> {
+
+      private static final int DEFAULT_BATCH_SIZE = 1000;
+
+      private final WriteStreams spec;
+
+      private transient Jedis jedis;
+      private transient Pipeline pipeline;
+
+      private int batchCount;
+
+      public WriteStreamFn(WriteStreams spec) {
+        this.spec = spec;
+      }
+
+      @Setup
+      public void setup() {
+        jedis = spec.connectionConfiguration().connect();
+      }
+
+      @StartBundle
+      public void startBundle() {
+        pipeline = jedis.pipelined();
+        pipeline.multi();
+        batchCount = 0;
+      }
+
+      @ProcessElement
+      public void processElement(ProcessContext c) {
+        KV<String, Map<String, String>> record = c.element();
+
+        writeRecord(record);
+
+        batchCount++;
+
+        if (batchCount >= DEFAULT_BATCH_SIZE) {
+          pipeline.exec();
+          pipeline.sync();
+          pipeline.multi();
+          batchCount = 0;
+        }
+      }
+
+      private void writeRecord(KV<String, Map<String, String>> record) {
+        String key = record.getKey();
+        Map<String, String> value = record.getValue();
+        if (spec.maxLen() > 0L) {
+          pipeline.xadd(key, StreamEntryID.NEW_ENTRY, value, spec.maxLen(), spec.approximateTrim());
+        } else {
+          pipeline.xadd(key, StreamEntryID.NEW_ENTRY, value);
+        }
+      }
+
+      @FinishBundle
+      public void finishBundle() {
+        if (pipeline.isInMulti()) {
+          pipeline.exec();

Review comment:
       Small follow-up: if we're inside a MULTI transaction, we definitely don't need to check each nested response. Per the Redis docs, "Either all of the commands or none are processed, so a Redis transaction is also atomic." If EXEC does not throw an error, the transaction has been applied.
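
       To make that concrete, here is a minimal sketch of a `finishBundle()` that leans on the EXEC guarantee; the excerpt above cuts off right after `exec()`, so the trailing `sync()` call and counter reset below are assumptions, not the PR's actual code:

```java
      @FinishBundle
      public void finishBundle() {
        if (pipeline.isInMulti()) {
          // EXEC applies every queued command or none of them, so there is no
          // need to inspect each nested response; an exception thrown here is
          // the only failure signal that matters.
          pipeline.exec();
          pipeline.sync(); // assumption: flush the pipelined responses
        }
        batchCount = 0; // assumption: reset the per-bundle counter
      }
```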

##########
File path: sdks/java/io/redis/src/test/java/org/apache/beam/sdk/io/redis/RedisIOTest.java
##########
@@ -205,6 +212,62 @@ public void testWriteUsingDECRBY() {
     assertEquals(-1, count);
   }
 
+  @Test
+  public void testWriteStreams() {
+    List<String> keys = Arrays.asList("a", "b", "c", "d", "e", "f", "g", "h", "i", "j");
+    List<KV<String, Map<String, String>>> data = new ArrayList<>();
+    for (String key : keys) {
+      Map<String, String> values =
+          Stream.of(
+                  new AbstractMap.SimpleEntry<String, String>("foo", "bar"),
+                  new AbstractMap.SimpleEntry<String, String>("baz", "qux"))
+              .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+      data.add(KV.of(key, values));
+    }
+    PCollection<KV<String, Map<String, String>>> write =
+        p.apply(
+            Create.of(data)
+                .withCoder(
+                    KvCoder.of(
+                        StringUtf8Coder.of(),
+                        MapCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))));
+    write.apply(RedisIO.writeStreams().withEndpoint(REDIS_HOST, port));
+    p.run();
+
+    for (String key : keys) {
+      long count = client.xlen(key);
+      assertEquals(2, count);

Review comment:
       👍 

##########
File path: sdks/java/io/redis/src/test/java/org/apache/beam/sdk/io/redis/RedisIOTest.java
##########
@@ -205,6 +212,62 @@ public void testWriteUsingDECRBY() {
     assertEquals(-1, count);
   }
 
+  @Test
+  public void testWriteStreams() {
+    List<String> keys = Arrays.asList("a", "b", "c", "d", "e", "f", "g", "h", "i", "j");
+    List<KV<String, Map<String, String>>> data = new ArrayList<>();
+    for (String key : keys) {
+      Map<String, String> values =
+          Stream.of(
+                  new AbstractMap.SimpleEntry<String, String>("foo", "bar"),
+                  new AbstractMap.SimpleEntry<String, String>("baz", "qux"))
+              .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+      data.add(KV.of(key, values));
+    }
+    PCollection<KV<String, Map<String, String>>> write =
+        p.apply(
+            Create.of(data)
+                .withCoder(
+                    KvCoder.of(
+                        StringUtf8Coder.of(),
+                        MapCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))));
+    write.apply(RedisIO.writeStreams().withEndpoint(REDIS_HOST, port));
+    p.run();
+
+    for (String key : keys) {
+      long count = client.xlen(key);
+      assertEquals(2, count);
+    }
+  }
+
+  @Test
+  public void testWriteStreamsWithTruncation() {
+    List<String> keys = Arrays.asList("a", "b", "c", "d", "e", "f", "g", "h", "i", "j");
+    List<KV<String, Map<String, String>>> data = new ArrayList<>();
+    for (String key : keys) {
+      Map<String, String> values =
+          Stream.of(
+                  new AbstractMap.SimpleEntry<String, String>("foo", "bar"),
+                  new AbstractMap.SimpleEntry<String, String>("baz", "qux"))
+              .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+      data.add(KV.of(key, values));
+    }
+    PCollection<KV<String, Map<String, String>>> write =
+        p.apply(
+            Create.of(data)
+                .withCoder(
+                    KvCoder.of(
+                        StringUtf8Coder.of(),
+                        MapCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))));
+    write.apply(RedisIO.writeStreams().withEndpoint(REDIS_HOST, port).withMaxLen(1L));

Review comment:
       Yeah, I think I confused myself with the overly complicated test data creation pattern. This should actually test what it claims to now.
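
       The diff excerpt above is cut off before the assertions, so the following is only a hypothetical sketch of what the truncation check might look like: with `withMaxLen(1L)`, each stream should end up with at most one entry.

```java
    // Hypothetical assertion, not quoted in the excerpt above: each stream
    // should have been trimmed down to the configured maximum length of 1.
    for (String key : keys) {
      assertTrue(client.xlen(key) <= 1L);
    }
```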

##########
File path: sdks/java/io/redis/src/test/java/org/apache/beam/sdk/io/redis/RedisIOTest.java
##########
@@ -205,6 +212,62 @@ public void testWriteUsingDECRBY() {
     assertEquals(-1, count);
   }
 
+  @Test
+  public void testWriteStreams() {
+    List<String> keys = Arrays.asList("a", "b", "c", "d", "e", "f", "g", "h", 
"i", "j");
+    List<KV<String, Map<String, String>>> data = new ArrayList<>();
+    for (String key : keys) {
+      Map<String, String> values =

Review comment:
       Oh nice, that's a much more idiomatic pattern, thank you. (Java is, to put it mildly, not my primary language.)
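
       The exact pattern that was suggested isn't quoted in this thread, so as an illustration only, one common idiom for building this kind of fixture is a map literal, assuming Guava's `ImmutableMap` (or Beam's vendored copy of it) is on the test classpath:

```java
    // Illustrative only; assumes com.google.common.collect.ImmutableMap
    // (or Beam's vendored Guava) is available in the test.
    List<KV<String, Map<String, String>>> data = new ArrayList<>();
    for (String key : keys) {
      data.add(KV.of(key, ImmutableMap.of("foo", "bar", "baz", "qux")));
    }
```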



