This is an automated email from the ASF dual-hosted git repository.
ahmedabualsaud pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 7ea8cd2608e Simplify Managed API to avoid dealing with PCollectionRowTuple (#31470)
7ea8cd2608e is described below
commit 7ea8cd2608e1b7550ffca44aff5596dd43fd4aa6
Author: Ahmed Abualsaud <[email protected]>
AuthorDate: Tue Jun 4 19:02:00 2024 -0400
Simplify Managed API to avoid dealing with PCollectionRowTuple (#31470)
* Managed accepts PInput type
* add unit test
* spotless
* spotless
* rename to getSinglePCollection
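In practice, the change replaces the hand-built tuple wrapping with direct pipeline/PCollection usage. A minimal before/after sketch, assuming config is a Map<String, Object> of connector options:

    // Before: wrap and unwrap a PCollectionRowTuple by hand.
    PCollectionRowTuple output =
        PCollectionRowTuple.empty(pipeline)
            .apply(Managed.read(Managed.ICEBERG).withConfig(config));
    PCollection<Row> rows = output.get("output");

    // After: apply directly to the pipeline and fetch the lone output.
    PCollection<Row> rows =
        pipeline
            .apply(Managed.read(Managed.ICEBERG).withConfig(config))
            .getSinglePCollection();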
---
.../beam/sdk/values/PCollectionRowTuple.java | 17 ++++++++
.../apache/beam/sdk/io/iceberg/IcebergIOIT.java | 10 ++---
.../IcebergReadSchemaTransformProviderTest.java | 4 +-
.../IcebergWriteSchemaTransformProviderTest.java | 14 +++----
.../java/org/apache/beam/sdk/io/kafka/KafkaIO.java | 2 +-
.../KafkaReadSchemaTransformProviderTest.java | 4 +-
.../KafkaWriteSchemaTransformProviderTest.java | 7 +---
.../java/org/apache/beam/sdk/managed/Managed.java | 46 +++++++++++++++++-----
.../sdk/managed/ManagedTransformConstants.java | 3 ++
.../ManagedSchemaTransformTranslationTest.java | 3 +-
.../org/apache/beam/sdk/managed/ManagedTest.java | 34 ++++++++++++++--
11 files changed, 104 insertions(+), 40 deletions(-)
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionRowTuple.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionRowTuple.java
index 0e7c52c4ae7..a2a3aa74e53 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionRowTuple.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionRowTuple.java
@@ -23,6 +23,7 @@ import java.util.Map;
import java.util.Objects;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.checkerframework.checker.nullness.qual.Nullable;
@@ -180,6 +181,22 @@ public class PCollectionRowTuple implements PInput, POutput {
return pcollection;
}
+ /**
+ * Like {@link #get(String)}, but is a convenience method to get a single PCollection without
+ * providing a tag for that output. Use only when there is a single collection in this tuple.
+ *
+ * <p>Throws {@link IllegalStateException} if more than one output exists in the {@link
+ * PCollectionRowTuple}.
+ */
+ public PCollection<Row> getSinglePCollection() {
+ Preconditions.checkState(
+ pcollectionMap.size() == 1,
+ "Expected exactly one output PCollection<Row>, but found %s. "
+ + "Please try retrieving a specified output using get(<tag>)
instead.",
+ pcollectionMap.size());
+ return get(pcollectionMap.entrySet().iterator().next().getKey());
+ }
+
/**
* Returns an immutable Map from tag to corresponding {@link PCollection}, for all the members of
* this {@link PCollectionRowTuple}.
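As a quick illustration of the new helper's contract (a sketch; rows and errorRows stand for schema'd PCollection<Row> values not defined in this commit):

    // Exactly one member: returned without naming its tag.
    PCollectionRowTuple single = PCollectionRowTuple.of("output", rows);
    PCollection<Row> result = single.getSinglePCollection();

    // More than one member: the Preconditions.checkState above fails.
    PCollectionRowTuple two = single.and("errors", errorRows);
    two.getSinglePCollection(); // throws IllegalStateException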
diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java
index 06a63909c12..467a2cbaf24 100644
--- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java
+++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java
@@ -38,7 +38,6 @@ import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.hadoop.conf.Configuration;
@@ -216,11 +215,10 @@ public class IcebergIOIT implements Serializable {
.build())
.build();
- PCollectionRowTuple output =
- PCollectionRowTuple.empty(readPipeline)
- .apply(Managed.read(Managed.ICEBERG).withConfig(config));
+ PCollection<Row> rows =
+ readPipeline.apply(Managed.read(Managed.ICEBERG).withConfig(config)).getSinglePCollection();
- PAssert.that(output.get("output")).containsInAnyOrder(expectedRows);
+ PAssert.that(rows).containsInAnyOrder(expectedRows);
readPipeline.run().waitUntilFinish();
}
@@ -258,7 +256,7 @@ public class IcebergIOIT implements Serializable {
.build();
PCollection<Row> input =
writePipeline.apply(Create.of(inputRows)).setRowSchema(BEAM_SCHEMA);
- PCollectionRowTuple.of("input",
input).apply(Managed.write(Managed.ICEBERG).withConfig(config));
+ input.apply(Managed.write(Managed.ICEBERG).withConfig(config));
writePipeline.run().waitUntilFinish();
diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProviderTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProviderTest.java
index 27a31f31830..46168a487dd 100644
--- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProviderTest.java
+++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProviderTest.java
@@ -166,9 +166,9 @@ public class IcebergReadSchemaTransformProviderTest {
Map<String, Object> configMap = new Yaml().load(yamlConfig);
PCollection<Row> output =
- PCollectionRowTuple.empty(testPipeline)
+ testPipeline
.apply(Managed.read(Managed.ICEBERG).withConfig(configMap))
- .get(OUTPUT_TAG);
+ .getSinglePCollection();
PAssert.that(output)
.satisfies(
diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java
index 97aebd5c41f..9ef3e9945ec 100644
--- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java
+++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java
@@ -134,16 +134,12 @@ public class IcebergWriteSchemaTransformProviderTest {
identifier, CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP, warehouse.location);
Map<String, Object> configMap = new Yaml().load(yamlConfig);
- PCollectionRowTuple input =
- PCollectionRowTuple.of(
- INPUT_TAG,
- testPipeline
- .apply(
- "Records To Add",
Create.of(TestFixtures.asRows(TestFixtures.FILE1SNAPSHOT1)))
- .setRowSchema(
-
SchemaAndRowConversions.icebergSchemaToBeamSchema(TestFixtures.SCHEMA)));
+ PCollection<Row> inputRows =
+ testPipeline
+ .apply("Records To Add",
Create.of(TestFixtures.asRows(TestFixtures.FILE1SNAPSHOT1)))
+
.setRowSchema(SchemaAndRowConversions.icebergSchemaToBeamSchema(TestFixtures.SCHEMA));
PCollection<Row> result =
- input.apply(Managed.write(Managed.ICEBERG).withConfig(configMap)).get(OUTPUT_TAG);
+ inputRows.apply(Managed.write(Managed.ICEBERG).withConfig(configMap)).get(OUTPUT_TAG);
PAssert.that(result).satisfies(new VerifyOutputs(identifier, "append"));
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index e897ed439cd..8f995a63a10 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -2663,7 +2663,7 @@ public class KafkaIO {
abstract Builder<K, V> setProducerConfig(Map<String, Object> producerConfig);
abstract Builder<K, V> setProducerFactoryFn(
- SerializableFunction<Map<String, Object>, Producer<K, V>> fn);
+ @Nullable SerializableFunction<Map<String, Object>, Producer<K, V>> fn);
abstract Builder<K, V> setKeySerializer(Class<? extends Serializer<K>> serializer);
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
index f5ac5bb54ad..dfe062e1eef 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
@@ -36,7 +36,7 @@ import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.managed.ManagedTransformConstants;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.utils.YamlUtils;
-import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
@@ -319,7 +319,7 @@ public class KafkaReadSchemaTransformProviderTest {
// Kafka Read SchemaTransform gets built in ManagedSchemaTransformProvider's expand
Managed.read(Managed.KAFKA)
.withConfig(YamlUtils.yamlStringToMap(config))
- .expand(PCollectionRowTuple.empty(Pipeline.create()));
+ .expand(PBegin.in(Pipeline.create()));
}
}
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java
index 60bff89b355..f19e91d8926 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java
@@ -43,7 +43,6 @@ import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
@@ -225,10 +224,8 @@ public class KafkaWriteSchemaTransformProviderTest {
Managed.write(Managed.KAFKA)
.withConfig(YamlUtils.yamlStringToMap(config))
.expand(
- PCollectionRowTuple.of(
- "input",
- Pipeline.create()
- .apply(Create.empty(Schema.builder().addByteArrayField("bytes").build()))));
+ Pipeline.create()
+ .apply(Create.empty(Schema.builder().addByteArrayField("bytes").build())));
}
}
diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java
index da4a0853fb3..6f95290e6ee 100644
--- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java
+++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java
@@ -22,11 +22,16 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
+import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.utils.YamlUtils;
import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PBegin;
+import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.PInput;
+import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
@@ -47,12 +52,13 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Immuta
* specifies arguments like so:
*
* <pre>{@code
- * PCollectionRowTuple output = PCollectionRowTuple.empty(pipeline).apply(
+ * PCollection<Row> rows = pipeline.apply(
* Managed.read(ICEBERG)
* .withConfig(ImmutableMap.<String, Object>builder()
* .put("foo", "abc")
* .put("bar", 123)
- * .build()));
+ * .build()))
+ * .getSinglePCollection();
* }</pre>
*
* <p>Instead of specifying configuration arguments directly in the code, one can provide the
@@ -66,11 +72,9 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Immuta
* <p>The file's path can be passed in to the Managed API like so:
*
* <pre>{@code
- * PCollectionRowTuple input = PCollectionRowTuple.of("input", pipeline.apply(Create.of(...)))
+ * PCollection<Row> inputRows = pipeline.apply(Create.of(...));
*
- * PCollectionRowTuple output = input.apply(
- * Managed.write(ICEBERG)
- * .withConfigUrl(<config path>));
+ * inputRows.apply(Managed.write(ICEBERG).withConfigUrl(<config path>));
* }</pre>
*/
public class Managed {
@@ -132,8 +136,7 @@ public class Managed {
}
@AutoValue
- public abstract static class ManagedTransform
- extends PTransform<PCollectionRowTuple, PCollectionRowTuple> {
+ public abstract static class ManagedTransform extends PTransform<PInput, PCollectionRowTuple> {
abstract String getIdentifier();
abstract @Nullable Map<String, Object> getConfig();
@@ -183,7 +186,9 @@ public class Managed {
}
@Override
- public PCollectionRowTuple expand(PCollectionRowTuple input) {
+ public PCollectionRowTuple expand(PInput input) {
+ PCollectionRowTuple inputTuple = resolveInput(input);
+
ManagedSchemaTransformProvider.ManagedConfig managedConfig =
ManagedSchemaTransformProvider.ManagedConfig.builder()
.setTransformIdentifier(getIdentifier())
@@ -194,7 +199,28 @@ public class Managed {
SchemaTransform underlyingTransform =
new ManagedSchemaTransformProvider(getSupportedIdentifiers()).from(managedConfig);
- return input.apply(underlyingTransform);
+ return inputTuple.apply(underlyingTransform);
+ }
+
+ @VisibleForTesting
+ static PCollectionRowTuple resolveInput(PInput input) {
+ if (input instanceof PBegin) {
+ return PCollectionRowTuple.empty(input.getPipeline());
+ } else if (input instanceof PCollection) {
+ PCollection<?> inputCollection = (PCollection<?>) input;
+ Preconditions.checkArgument(
+ inputCollection.getCoder() instanceof RowCoder,
+ "Input PCollection must contain Row elements with a set Schema "
+ + "(using .setRowSchema()). Instead, found collection %s with
coder: %s.",
+ inputCollection.getName(),
+ inputCollection.getCoder());
+ return PCollectionRowTuple.of(
+ ManagedTransformConstants.INPUT, (PCollection<Row>) inputCollection);
+ } else if (input instanceof PCollectionRowTuple) {
+ return (PCollectionRowTuple) input;
+ }
+
+ throw new IllegalArgumentException("Unsupported input type: " + input.getClass());
}
}
}
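The reworked expand therefore accepts three input shapes, normalized by resolveInput as sketched below. Note that resolveInput is package-private and @VisibleForTesting, so only same-package code (such as the unit test further down) can call it directly; ROWS and SCHEMA are assumed test fixtures, not part of this commit:

    Pipeline p = Pipeline.create();
    PCollection<Row> rows = p.apply(Create.of(ROWS).withRowSchema(SCHEMA));

    // PBegin (a bare pipeline) -> an empty PCollectionRowTuple.
    Managed.ManagedTransform.resolveInput(PBegin.in(p));

    // A schema'd PCollection<Row> -> a one-entry tuple under the standard "input" tag.
    Managed.ManagedTransform.resolveInput(rows);

    // An existing PCollectionRowTuple -> passed through unchanged.
    Managed.ManagedTransform.resolveInput(PCollectionRowTuple.of("pcoll", rows));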
diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java
index 141544305a3..51d0b67b4b8 100644
--- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java
+++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java
@@ -38,6 +38,9 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Immuta
* every single parameter through the Managed interface.
*/
public class ManagedTransformConstants {
+ // Standard input PCollection tag
+ public static final String INPUT = "input";
+
public static final String ICEBERG_READ =
"beam:schematransform:org.apache.beam:iceberg_read:v1";
public static final String ICEBERG_WRITE =
"beam:schematransform:org.apache.beam:iceberg_write:v1";
diff --git a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformTranslationTest.java b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformTranslationTest.java
index f7769a9e1d1..0d122646d89 100644
--- a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformTranslationTest.java
+++ b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformTranslationTest.java
@@ -50,7 +50,6 @@ import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.construction.BeamUrns;
import org.apache.beam.sdk.util.construction.PipelineTranslation;
import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.InvalidProtocolBufferException;
@@ -141,7 +140,7 @@ public class ManagedSchemaTransformTranslationTest {
.setIdentifier(TestSchemaTransformProvider.IDENTIFIER)
.build()
.withConfig(underlyingConfig);
- PCollectionRowTuple.of("input", input).apply(transform).get("output");
+ input.apply(transform);
// Then translate the pipeline to a proto and extract the ManagedSchemaTransform's proto
RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
diff --git a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedTest.java b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedTest.java
index 7ed364d0e17..249faffec56 100644
--- a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedTest.java
+++ b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedTest.java
@@ -17,17 +17,23 @@
*/
package org.apache.beam.sdk.managed;
+import static org.junit.Assert.assertThrows;
+
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
+import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.managed.testing.TestSchemaTransformProvider;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.junit.Rule;
@@ -61,11 +67,33 @@ public class ManagedTest {
Row.withSchema(SCHEMA).withFieldValue("str", "b").withFieldValue("int", 2).build(),
Row.withSchema(SCHEMA).withFieldValue("str", "c").withFieldValue("int", 3).build());
+ @Test
+ public void testResolveInputToPCollectionRowTuple() {
+ Pipeline p = Pipeline.create();
+ List<PInput> inputTypes =
+ Arrays.asList(
+ PBegin.in(p),
+ p.apply(Create.of(ROWS).withRowSchema(SCHEMA)),
+ PCollectionRowTuple.of("pcoll",
p.apply(Create.of(ROWS).withRowSchema(SCHEMA))));
+
+ List<PInput> badInputTypes =
+ Arrays.asList(
+ p.apply(Create.of(1, 2, 3)),
+ p.apply(Create.of(ROWS)),
+ PCollectionTuple.of("pcoll", p.apply(Create.of(ROWS))));
+
+ for (PInput input : inputTypes) {
+ Managed.ManagedTransform.resolveInput(input);
+ }
+ for (PInput badInput : badInputTypes) {
+ assertThrows(
+ IllegalArgumentException.class, () -> Managed.ManagedTransform.resolveInput(badInput));
+ }
+ }
+
public void runTestProviderTest(Managed.ManagedTransform writeOp) {
PCollection<Row> rows =
- PCollectionRowTuple.of("input",
pipeline.apply(Create.of(ROWS)).setRowSchema(SCHEMA))
- .apply(writeOp)
- .get("output");
+ pipeline.apply(Create.of(ROWS)).setRowSchema(SCHEMA).apply(writeOp).getSinglePCollection();
Schema outputSchema = rows.getSchema();
PAssert.that(rows)