Add Create.TimestampedValues.withType

This brings Create.TimestampedValues to parity with Create.Values, which already supports withType.

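Update CreateTest so that Create's default coder inference would fail if it
ran. The tests mix Record and Record2 elements, so they verify that the explicitly supplied coder or type descriptor is used.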


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/43825093
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/43825093
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/43825093

Branch: refs/heads/master
Commit: 43825093e233c25e5920eb2bacca98673e25de75
Parents: 7c71036
Author: Aviem Zur <[email protected]>
Authored: Mon Mar 13 21:45:21 2017 +0200
Committer: Thomas Groh <[email protected]>
Committed: Wed Mar 15 10:32:58 2017 -0700

----------------------------------------------------------------------
 .../org/apache/beam/sdk/transforms/Create.java  | 71 +++++++++++++++-----
 .../apache/beam/sdk/transforms/CreateTest.java  | 32 ++++++++-
 2 files changed, 82 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/43825093/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java
index 4f746d0..ffc2d8d 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Create.java
@@ -207,7 +207,10 @@ public class Create<T> {
    * Otherwise, use {@link Create.TimestampedValues#withCoder} to set the 
coder explicitly.
    */
   public static <T> TimestampedValues<T> 
timestamped(Iterable<TimestampedValue<T>> elems) {
-    return new TimestampedValues<>(elems, Optional.<Coder<T>>absent());
+    return new TimestampedValues<>(
+        elems,
+        Optional.<Coder<T>>absent(),
+        Optional.<TypeDescriptor<T>>absent());
   }
 
   /**
@@ -495,27 +498,32 @@ public class Create<T> {
      * is used.
      */
     public TimestampedValues<T> withCoder(Coder<T> coder) {
-      return new TimestampedValues<>(timestampedElements, 
Optional.<Coder<T>>of(coder));
+      return new TimestampedValues<>(timestampedElements, Optional.of(coder), 
typeDescriptor);
+    }
+
+    /**
+     * Returns a {@link Create.TimestampedValues} PTransform like this one 
that uses the given
+     * {@code TypeDescriptor<T>} to determine the {@code Coder} to use to 
decode each of the
+     * objects into a value of type {@code T}. Note that a default coder must 
be registered for the
+     * class described in the {@code TypeDescriptor<T>}.
+     *
+     * <p>By default, {@code Create.TimestampedValues} can automatically 
determine the {@code Coder}
+     * to use if all elements have the same non-parameterized run-time class, 
and a default coder is
+     * registered for that class. See {@link CoderRegistry} for details on how 
defaults are
+     * determined.
+     *
+     * <p>Note that for {@link Create.TimestampedValues} with no elements, the 
{@link VoidCoder} is
+     * used.
+     */
+    public TimestampedValues<T> withType(TypeDescriptor<T> type) {
+      return new TimestampedValues<>(timestampedElements, elementCoder, 
Optional.of(type));
     }
 
     @Override
     public PCollection<T> expand(PBegin input) {
       try {
-        Iterable<T> rawElements =
-            Iterables.transform(
-                timestampedElements,
-                new Function<TimestampedValue<T>, T>() {
-                  @Override
-                  public T apply(TimestampedValue<T> input) {
-                    return input.getValue();
-                  }
-                });
-        Coder<T> coder;
-        if (elementCoder.isPresent()) {
-          coder = elementCoder.get();
-        } else {
-          coder = 
getDefaultCreateCoder(input.getPipeline().getCoderRegistry(), rawElements);
-        }
+        Coder<T> coder = getDefaultOutputCoder(input);
+
         PCollection<TimestampedValue<T>> intermediate = 
Pipeline.applyTransform(input,
             
Create.of(timestampedElements).withCoder(TimestampedValueCoder.of(coder)));
 
@@ -533,12 +541,19 @@ public class Create<T> {
     /** The timestamped elements of the resulting PCollection. */
     private final transient Iterable<TimestampedValue<T>> timestampedElements;
 
+    /** The coder used to encode the values to and from a binary 
representation. */
     private final transient Optional<Coder<T>> elementCoder;
 
+    /** The value type. */
+    private final transient Optional<TypeDescriptor<T>> typeDescriptor;
+
     private TimestampedValues(
-        Iterable<TimestampedValue<T>> timestampedElements, Optional<Coder<T>> 
elementCoder) {
+        Iterable<TimestampedValue<T>> timestampedElements,
+        Optional<Coder<T>> elementCoder,
+        Optional<TypeDescriptor<T>> typeDescriptor) {
       this.timestampedElements = timestampedElements;
       this.elementCoder = elementCoder;
+      this.typeDescriptor = typeDescriptor;
     }
 
     private static class ConvertTimestamps<T> extends 
DoFn<TimestampedValue<T>, T> {
@@ -547,6 +562,26 @@ public class Create<T> {
         c.outputWithTimestamp(c.element().getValue(), 
c.element().getTimestamp());
       }
     }
+
+    @Override
+    public Coder<T> getDefaultOutputCoder(PBegin input) throws 
CannotProvideCoderException {
+      if (elementCoder.isPresent()) {
+        return elementCoder.get();
+      } else if (typeDescriptor.isPresent()) {
+        return 
input.getPipeline().getCoderRegistry().getDefaultCoder(typeDescriptor.get());
+      } else {
+        Iterable<T> rawElements =
+            Iterables.transform(
+                timestampedElements,
+                new Function<TimestampedValue<T>, T>() {
+                  @Override
+                  public T apply(TimestampedValue<T> input) {
+                    return input.getValue();
+                  }
+                });
+        return getDefaultCreateCoder(input.getPipeline().getCoderRegistry(), 
rawElements);
+      }
+    }
   }
 
   private static <T> Coder<T> getDefaultCreateCoder(CoderRegistry registry, 
Iterable<T> elems)

http://git-wip-us.apache.org/repos/asf/beam/blob/43825093/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CreateTest.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CreateTest.java 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CreateTest.java
index af917cf..d21e502 100644
--- 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CreateTest.java
+++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/CreateTest.java
@@ -303,6 +303,33 @@ public class CreateTest {
   }
 
   @Test
+  public void testCreateTimestampedDefaultOutputCoderUsingCoder() throws 
Exception {
+    Coder<Record> coder = new RecordCoder();
+    PBegin pBegin = PBegin.in(p);
+    Create.TimestampedValues<Record> values =
+        Create.timestamped(
+            TimestampedValue.of(new Record(), new Instant(0)),
+            TimestampedValue.<Record>of(new Record2(), new Instant(0)))
+            .withCoder(coder);
+    Coder<Record> defaultCoder = values.getDefaultOutputCoder(pBegin);
+    assertThat(defaultCoder, equalTo(coder));
+  }
+
+  @Test
+  public void testCreateTimestampedDefaultOutputCoderUsingTypeDescriptor() 
throws Exception {
+    Coder<Record> coder = new RecordCoder();
+    p.getCoderRegistry().registerCoder(Record.class, coder);
+    PBegin pBegin = PBegin.in(p);
+    Create.TimestampedValues<Record> values =
+        Create.timestamped(
+            TimestampedValue.of(new Record(), new Instant(0)),
+            TimestampedValue.<Record>of(new Record2(), new Instant(0)))
+            .withType(new TypeDescriptor<Record>() {});
+    Coder<Record> defaultCoder = values.getDefaultOutputCoder(pBegin);
+    assertThat(defaultCoder, equalTo(coder));
+  }
+
+  @Test
   @Category(RunnableOnService.class)
   public void testCreateWithVoidType() throws Exception {
     PCollection<Void> output = p.apply(Create.of((Void) null, (Void) null));
@@ -346,7 +373,7 @@ public class CreateTest {
     Coder<Record> coder = new RecordCoder();
     PBegin pBegin = PBegin.in(p);
     Create.Values<Record> values =
-        Create.of(new Record(), new Record(), new Record()).withCoder(coder);
+        Create.of(new Record(), new Record2()).withCoder(coder);
     Coder<Record> defaultCoder = values.getDefaultOutputCoder(pBegin);
     assertThat(defaultCoder, equalTo(coder));
   }
@@ -357,8 +384,7 @@ public class CreateTest {
     p.getCoderRegistry().registerCoder(Record.class, coder);
     PBegin pBegin = PBegin.in(p);
     Create.Values<Record> values =
-        Create.of(new Record(), new Record(), new Record())
-            .withType(new TypeDescriptor<Record>() {});
+        Create.of(new Record(), new Record2()).withType(new 
TypeDescriptor<Record>() {});
     Coder<Record> defaultCoder = values.getDefaultOutputCoder(pBegin);
     assertThat(defaultCoder, equalTo(coder));
   }

Reply via email to