This is an automated email from the ASF dual-hosted git repository.
reuvenlax pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 77b295b Merge pull request #8311: Allow Schema field selections in
DoFn using NewDoFn injection
77b295b is described below
commit 77b295b1c2b0a206099b8f50c4d3180c248e252c
Author: reuvenlax <[email protected]>
AuthorDate: Sat Apr 27 03:01:26 2019 -0700
Merge pull request #8311: Allow Schema field selections in DoFn using
NewDoFn injection
---
.../construction/SplittableParDoNaiveBounded.java | 2 +-
...TimeBoundedSplittableProcessElementInvoker.java | 2 +-
.../apache/beam/runners/core/SimpleDoFnRunner.java | 13 +-
.../apache/beam/sdk/schemas/SchemaRegistry.java | 42 +++-
.../sdk/schemas/annotations/DefaultSchema.java | 89 +++++---
.../beam/sdk/schemas/transforms/Convert.java | 77 ++-----
.../apache/beam/sdk/schemas/transforms/Select.java | 2 +-
.../beam/sdk/schemas/utils/ConvertHelpers.java | 204 ++++++++++++++++++
.../sdk/schemas/utils/StaticSchemaInference.java | 4 +-
.../beam/sdk/transforms/DoFnSchemaInformation.java | 199 +++++++++++++++++-
.../org/apache/beam/sdk/transforms/DoFnTester.java | 2 +-
.../java/org/apache/beam/sdk/transforms/ParDo.java | 111 ++++------
.../reflect/ByteBuddyDoFnInvokerFactory.java | 34 ++-
.../beam/sdk/transforms/reflect/DoFnInvoker.java | 15 +-
.../beam/sdk/transforms/reflect/DoFnSignature.java | 36 +++-
.../sdk/transforms/reflect/DoFnSignatures.java | 29 ++-
.../beam/sdk/transforms/ParDoSchemaTest.java | 233 +++++++++++++++++----
.../sdk/transforms/reflect/DoFnSignaturesTest.java | 35 +++-
.../apache/beam/fn/harness/FnApiDoFnRunner.java | 11 +-
19 files changed, 883 insertions(+), 257 deletions(-)
diff --git
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java
index 35ba2ff..ef55460 100644
---
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java
+++
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java
@@ -238,7 +238,7 @@ public class SplittableParDoNaiveBounded {
}
@Override
- public Object schemaElement(DoFn<InputT, OutputT> doFn) {
+ public Object schemaElement(int index) {
throw new UnsupportedOperationException();
}
diff --git
a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java
b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java
index 2009f90..7a7bc60 100644
---
a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java
+++
b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java
@@ -121,7 +121,7 @@ public class
OutputAndTimeBoundedSplittableProcessElementInvoker<
}
@Override
- public Object schemaElement(DoFn<InputT, OutputT> doFn) {
+ public Object schemaElement(int index) {
throw new UnsupportedOperationException("Not supported in
SplittableDoFn");
}
diff --git
a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
index bbe2730..8fd2ad3 100644
---
a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
+++
b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
@@ -39,6 +39,7 @@ import
org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver;
import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
import org.apache.beam.sdk.transforms.DoFnOutputReceivers;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
+import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
@@ -301,7 +302,7 @@ public class SimpleDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Out
}
@Override
- public Object schemaElement(DoFn<InputT, OutputT> doFn) {
+ public Object schemaElement(int index) {
throw new UnsupportedOperationException(
"Element parameters are not supported outside of @ProcessElement
method.");
}
@@ -415,7 +416,7 @@ public class SimpleDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Out
}
@Override
- public Object schemaElement(DoFn<InputT, OutputT> doFn) {
+ public Object schemaElement(int index) {
throw new UnsupportedOperationException(
"Cannot access element outside of @ProcessElement method.");
}
@@ -631,9 +632,9 @@ public class SimpleDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Out
}
@Override
- public Object schemaElement(DoFn<InputT, OutputT> doFn) {
- Row row = schemaCoder.getToRowFunction().apply(element());
- return
doFnSchemaInformation.getElementParameterSchema().getFromRowFunction().apply(row);
+ public Object schemaElement(int index) {
+ SerializableFunction converter =
doFnSchemaInformation.getElementConverters().get(index);
+ return converter.apply(element());
}
@Override
@@ -781,7 +782,7 @@ public class SimpleDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Out
}
@Override
- public Object schemaElement(DoFn<InputT, OutputT> doFn) {
+ public Object schemaElement(int index) {
throw new UnsupportedOperationException("Element parameters are not
supported.");
}
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java
index 6f565ac..8074500 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaRegistry.java
@@ -78,22 +78,52 @@ public class SchemaRegistry {
@Nullable
@Override
public <T> Schema schemaFor(TypeDescriptor<T> typeDescriptor) {
- SchemaProvider schemaProvider = providers.get(typeDescriptor);
- return (schemaProvider != null) ?
schemaProvider.schemaFor(typeDescriptor) : null;
+ TypeDescriptor<?> type = typeDescriptor;
+ do {
+ SchemaProvider schemaProvider = providers.get(type);
+ if (schemaProvider != null) {
+ return schemaProvider.schemaFor(type);
+ }
+ Class<?> superClass = type.getRawType().getSuperclass();
+ if (superClass == null || superClass.equals(Object.class)) {
+ return null;
+ }
+ type = TypeDescriptor.of(superClass);
+ } while (true);
}
@Nullable
@Override
public <T> SerializableFunction<T, Row> toRowFunction(TypeDescriptor<T>
typeDescriptor) {
- SchemaProvider schemaProvider = providers.get(typeDescriptor);
- return (schemaProvider != null) ?
schemaProvider.toRowFunction(typeDescriptor) : null;
+ TypeDescriptor<?> type = typeDescriptor;
+ do {
+ SchemaProvider schemaProvider = providers.get(type);
+ if (schemaProvider != null) {
+ return (SerializableFunction<T, Row>)
schemaProvider.toRowFunction(type);
+ }
+ Class<?> superClass = type.getRawType().getSuperclass();
+ if (superClass == null || superClass.equals(Object.class)) {
+ return null;
+ }
+ type = TypeDescriptor.of(superClass);
+ } while (true);
}
@Nullable
@Override
public <T> SerializableFunction<Row, T> fromRowFunction(TypeDescriptor<T>
typeDescriptor) {
- SchemaProvider schemaProvider = providers.get(typeDescriptor);
- return (schemaProvider != null) ?
schemaProvider.fromRowFunction(typeDescriptor) : null;
+ TypeDescriptor<?> type = typeDescriptor;
+ do {
+ SchemaProvider schemaProvider = providers.get(type);
+ if (schemaProvider != null) {
+ return (SerializableFunction<Row, T>)
schemaProvider.fromRowFunction(type);
+ }
+ Class<?> superClass = type.getRawType().getSuperclass();
+ if (superClass == null || superClass.equals(Object.class)) {
+ return null;
+ }
+ type = TypeDescriptor.of(superClass);
+ } while (true);
}
}
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/annotations/DefaultSchema.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/annotations/DefaultSchema.java
index 00c9376..d3b7d10 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/annotations/DefaultSchema.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/annotations/DefaultSchema.java
@@ -19,6 +19,7 @@ package org.apache.beam.sdk.schemas.annotations;
import static
org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;
+import java.io.Serializable;
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
@@ -71,46 +72,64 @@ public @interface DefaultSchema {
* delegates to that provider.
*/
class DefaultSchemaProvider implements SchemaProvider {
- final Map<TypeDescriptor, SchemaProvider> cachedProviders =
Maps.newConcurrentMap();
+ final Map<TypeDescriptor, ProviderAndDescriptor> cachedProviders =
Maps.newConcurrentMap();
+
+ private static final class ProviderAndDescriptor implements Serializable {
+ final SchemaProvider schemaProvider;
+ final TypeDescriptor<?> typeDescriptor;
+
+ public ProviderAndDescriptor(
+ SchemaProvider schemaProvider, TypeDescriptor<?> typeDescriptor) {
+ this.schemaProvider = schemaProvider;
+ this.typeDescriptor = typeDescriptor;
+ }
+ }
@Nullable
- private SchemaProvider getSchemaProvider(TypeDescriptor<?> typeDescriptor)
{
+ private ProviderAndDescriptor getSchemaProvider(TypeDescriptor<?>
typeDescriptor) {
return cachedProviders.computeIfAbsent(
typeDescriptor,
type -> {
Class<?> clazz = type.getRawType();
- DefaultSchema annotation =
clazz.getAnnotation(DefaultSchema.class);
- if (annotation == null) {
- return null;
- }
- Class<? extends SchemaProvider> providerClass = annotation.value();
- checkArgument(
- providerClass != null,
- "Type " + type + " has a @DefaultSchema annotation with a null
argument.");
+ do {
+ DefaultSchema annotation =
clazz.getAnnotation(DefaultSchema.class);
+ if (annotation != null) {
+ Class<? extends SchemaProvider> providerClass =
annotation.value();
+ checkArgument(
+ providerClass != null,
+ "Type " + type + " has a @DefaultSchema annotation with a
null argument.");
- try {
- return providerClass.getDeclaredConstructor().newInstance();
- } catch (NoSuchMethodException
- | InstantiationException
- | IllegalAccessException
- | InvocationTargetException e) {
- throw new IllegalStateException(
- "Failed to create SchemaProvider "
- + providerClass.getSimpleName()
- + " which was"
- + " specified as the default SchemaProvider for type "
- + type
- + ". Make "
- + " sure that this class has a public default
constructor.",
- e);
- }
+ try {
+ return new ProviderAndDescriptor(
+ providerClass.getDeclaredConstructor().newInstance(),
+ TypeDescriptor.of(clazz));
+ } catch (NoSuchMethodException
+ | InstantiationException
+ | IllegalAccessException
+ | InvocationTargetException e) {
+ throw new IllegalStateException(
+ "Failed to create SchemaProvider "
+ + providerClass.getSimpleName()
+ + " which was"
+ + " specified as the default SchemaProvider for type
"
+ + type
+ + ". Make "
+ + " sure that this class has a public default
constructor.",
+ e);
+ }
+ }
+ clazz = clazz.getSuperclass();
+ } while (clazz != null && !clazz.equals(Object.class));
+ return null;
});
}
@Override
public <T> Schema schemaFor(TypeDescriptor<T> typeDescriptor) {
- SchemaProvider schemaProvider = getSchemaProvider(typeDescriptor);
- return (schemaProvider != null) ?
schemaProvider.schemaFor(typeDescriptor) : null;
+ ProviderAndDescriptor providerAndDescriptor =
getSchemaProvider(typeDescriptor);
+ return (providerAndDescriptor != null)
+ ?
providerAndDescriptor.schemaProvider.schemaFor(providerAndDescriptor.typeDescriptor)
+ : null;
}
/**
@@ -119,8 +138,11 @@ public @interface DefaultSchema {
*/
@Override
public <T> SerializableFunction<T, Row> toRowFunction(TypeDescriptor<T>
typeDescriptor) {
- SchemaProvider schemaProvider = getSchemaProvider(typeDescriptor);
- return (schemaProvider != null) ?
schemaProvider.toRowFunction(typeDescriptor) : null;
+ ProviderAndDescriptor providerAndDescriptor =
getSchemaProvider(typeDescriptor);
+ return (providerAndDescriptor != null)
+ ? providerAndDescriptor.schemaProvider.toRowFunction(
+ (TypeDescriptor<T>) providerAndDescriptor.typeDescriptor)
+ : null;
}
/**
@@ -129,8 +151,11 @@ public @interface DefaultSchema {
*/
@Override
public <T> SerializableFunction<Row, T> fromRowFunction(TypeDescriptor<T>
typeDescriptor) {
- SchemaProvider schemaProvider = getSchemaProvider(typeDescriptor);
- return (schemaProvider != null) ?
schemaProvider.fromRowFunction(typeDescriptor) : null;
+ ProviderAndDescriptor providerAndDescriptor =
getSchemaProvider(typeDescriptor);
+ return (providerAndDescriptor != null)
+ ? providerAndDescriptor.schemaProvider.fromRowFunction(
+ (TypeDescriptor<T>) providerAndDescriptor.typeDescriptor)
+ : null;
}
}
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java
index 9b01b35..b137c6a 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java
@@ -20,15 +20,13 @@ package org.apache.beam.sdk.schemas.transforms;
import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;
-import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
-import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.schemas.SchemaRegistry;
+import org.apache.beam.sdk.schemas.utils.ConvertHelpers;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptor;
@@ -99,7 +97,6 @@ public class Convert {
private static class ConvertTransform<InputT, OutputT>
extends PTransform<PCollection<InputT>, PCollection<OutputT>> {
TypeDescriptor<OutputT> outputTypeDescriptor;
- Schema unboxedSchema = null;
ConvertTransform(TypeDescriptor<OutputT> outputTypeDescriptor) {
this.outputTypeDescriptor = outputTypeDescriptor;
@@ -124,62 +121,34 @@ public class Convert {
throw new RuntimeException("Convert requires a schema on the input.");
}
- final SchemaCoder<OutputT> outputSchemaCoder;
- boolean toRow =
outputTypeDescriptor.equals(TypeDescriptor.of(Row.class));
- if (toRow) {
- // If the output is of type Row, then just forward the schema of the
input type to the
- // output.
- outputSchemaCoder =
- (SchemaCoder<OutputT>)
- SchemaCoder.of(
- input.getSchema(),
- SerializableFunctions.identity(),
- SerializableFunctions.identity());
- } else {
- // Otherwise, try to find a schema for the output type in the schema
registry.
- SchemaRegistry registry = input.getPipeline().getSchemaRegistry();
- try {
- outputSchemaCoder =
- SchemaCoder.of(
- registry.getSchema(outputTypeDescriptor),
- registry.getToRowFunction(outputTypeDescriptor),
- registry.getFromRowFunction(outputTypeDescriptor));
-
- Schema outputSchema = outputSchemaCoder.getSchema();
- if (!outputSchema.assignableToIgnoreNullable(input.getSchema())) {
- // We also support unboxing nested Row schemas, so attempt that.
- // TODO: Support unboxing to primitive types as well.
- unboxedSchema = getBoxedNestedSchema(input.getSchema());
- if (unboxedSchema == null ||
!outputSchema.assignableToIgnoreNullable(unboxedSchema)) {
- Schema checked = (unboxedSchema == null) ? input.getSchema() :
unboxedSchema;
- throw new RuntimeException(
- "Cannot convert between types that don't have equivalent
schemas."
- + " input schema: "
- + checked
- + " output schema: "
- + outputSchemaCoder.getSchema());
- }
- }
- } catch (NoSuchSchemaException e) {
- throw new RuntimeException("No schema registered for " +
outputTypeDescriptor);
- }
- }
-
- return input
- .apply(
+ SchemaRegistry registry = input.getPipeline().getSchemaRegistry();
+ ConvertHelpers.ConvertedSchemaInformation<OutputT> converted =
+ ConvertHelpers.getConvertedSchemaInformation(
+ input.getSchema(), outputTypeDescriptor, registry);
+ boolean unbox = converted.unboxedType != null;
+ PCollection<OutputT> output =
+ input.apply(
ParDo.of(
new DoFn<InputT, OutputT>() {
@ProcessElement
public void processElement(@Element Row row,
OutputReceiver<OutputT> o) {
// Read the row, potentially unboxing if necessary.
- Row input = (unboxedSchema == null) ? row :
row.getValue(0);
-
o.output(outputSchemaCoder.getFromRowFunction().apply(input));
+ Object input = unbox ? row.getValue(0) : row;
+ // The output has a schema, so we need to convert to the
appropriate type.
+
o.output(converted.outputSchemaCoder.getFromRowFunction().apply((Row) input));
}
- }))
- .setSchema(
- outputSchemaCoder.getSchema(),
- outputSchemaCoder.getToRowFunction(),
- outputSchemaCoder.getFromRowFunction());
+ }));
+ if (converted.outputSchemaCoder != null) {
+ output =
+ output.setSchema(
+ converted.outputSchemaCoder.getSchema(),
+ converted.outputSchemaCoder.getToRowFunction(),
+ converted.outputSchemaCoder.getFromRowFunction());
+ } else {
+ // TODO: Support full unboxing and boxing in Create.
+ throw new RuntimeException("Unboxing is not yet supported in the
Create transform");
+ }
+ return output;
}
}
}
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Select.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Select.java
index 077cc33..7626869 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Select.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Select.java
@@ -64,7 +64,7 @@ import org.apache.beam.sdk.values.Row;
*
* <pre>{@code
* PCollection<UserEvent> events = readUserEvents();
- * PCollection<Row> rows = event.apply(Select.fieldNames("location")
+ * PCollection<Location> rows = event.apply(Select.fieldNames("location")
* .apply(Convert.to(Location.class));
* }</pre>
*/
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java
new file mode 100644
index 0000000..e74f85f
--- /dev/null
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/ConvertHelpers.java
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.schemas.utils;
+
+import java.io.Serializable;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Type;
+import javax.annotation.Nullable;
+import net.bytebuddy.ByteBuddy;
+import net.bytebuddy.description.type.TypeDescription;
+import net.bytebuddy.dynamic.DynamicType;
+import net.bytebuddy.dynamic.loading.ClassLoadingStrategy;
+import net.bytebuddy.dynamic.scaffold.InstrumentedType;
+import net.bytebuddy.implementation.Implementation;
+import net.bytebuddy.implementation.bytecode.ByteCodeAppender;
+import net.bytebuddy.implementation.bytecode.ByteCodeAppender.Size;
+import net.bytebuddy.implementation.bytecode.StackManipulation;
+import net.bytebuddy.implementation.bytecode.member.MethodReturn;
+import net.bytebuddy.implementation.bytecode.member.MethodVariableAccess;
+import net.bytebuddy.matcher.ElementMatchers;
+import org.apache.beam.sdk.schemas.JavaFieldSchema.JavaFieldTypeSupplier;
+import org.apache.beam.sdk.schemas.NoSuchSchemaException;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
+import org.apache.beam.sdk.schemas.SchemaCoder;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
+import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.ConvertType;
+import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils.ConvertValueForSetter;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.SerializableFunctions;
+import org.apache.beam.sdk.util.common.ReflectHelpers;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.TypeDescriptor;
+import
org.apache.beam.vendor.guava.v20_0.com.google.common.primitives.Primitives;
+
+/** Helper functions for converting between equivalent schema types. */
+public class ConvertHelpers {
+ /** Return value after converting a schema. */
+ public static class ConvertedSchemaInformation<T> implements Serializable {
+ // If the output type is a composite type, this is the schema coder.
+ @Nullable public final SchemaCoder<T> outputSchemaCoder;
+ @Nullable public final FieldType unboxedType;
+
+ public ConvertedSchemaInformation(
+ @Nullable SchemaCoder<T> outputSchemaCoder, @Nullable FieldType
unboxedType) {
+ this.outputSchemaCoder = outputSchemaCoder;
+ this.unboxedType = unboxedType;
+ }
+ }
+
+ /** Get the coder used for converting from an inputSchema to a given type. */
+ public static <T> ConvertedSchemaInformation<T>
getConvertedSchemaInformation(
+ Schema inputSchema, TypeDescriptor<T> outputType, SchemaRegistry
schemaRegistry) {
+ ConvertedSchemaInformation<T> convertedSchema = null;
+ boolean toRow = outputType.equals(TypeDescriptor.of(Row.class));
+ if (toRow) {
+ // If the output is of type Row, then just forward the schema of the
input type to the
+ // output.
+ convertedSchema =
+ new ConvertedSchemaInformation<>(
+ (SchemaCoder<T>)
+ SchemaCoder.of(
+ inputSchema,
+ SerializableFunctions.identity(),
+ SerializableFunctions.identity()),
+ null);
+ } else {
+ // Otherwise, try to find a schema for the output type in the schema
registry.
+ Schema outputSchema = null;
+ SchemaCoder<T> outputSchemaCoder = null;
+ try {
+ outputSchema = schemaRegistry.getSchema(outputType);
+ outputSchemaCoder =
+ SchemaCoder.of(
+ outputSchema,
+ schemaRegistry.getToRowFunction(outputType),
+ schemaRegistry.getFromRowFunction(outputType));
+ } catch (NoSuchSchemaException e) {
+
+ }
+ FieldType unboxedType = null;
+ // TODO: Properly handle nullable.
+ if (outputSchema == null ||
!outputSchema.assignableToIgnoreNullable(inputSchema)) {
+ // The schema is not convertible directly. Attempt to unbox it and see
if the schema matches
+ // then.
+ Schema checkedSchema = inputSchema;
+ if (inputSchema.getFieldCount() == 1) {
+ unboxedType = inputSchema.getField(0).getType();
+ if (unboxedType.getTypeName().isCompositeType()
+ &&
!outputSchema.assignableToIgnoreNullable(unboxedType.getRowSchema())) {
+ checkedSchema = unboxedType.getRowSchema();
+ } else {
+ checkedSchema = null;
+ }
+ }
+ if (checkedSchema != null) {
+ throw new RuntimeException(
+ "Cannot convert between types that don't have equivalent
schemas."
+ + " input schema: "
+ + checkedSchema
+ + " output schema: "
+ + outputSchema);
+ }
+ }
+ convertedSchema = new ConvertedSchemaInformation<T>(outputSchemaCoder,
unboxedType);
+ }
+ return convertedSchema;
+ }
+
+ /**
+ * Returns a function to convert a Row into a primitive type. This only
works when the row schema
+ * contains a single field, and that field is convertible to the primitive
type.
+ */
+ @SuppressWarnings("unchecked")
+ public static <OutputT> SerializableFunction<?, OutputT> getConvertPrimitive(
+ FieldType fieldType, TypeDescriptor<?> outputTypeDescriptor) {
+ FieldType expectedFieldType =
+ StaticSchemaInference.fieldFromType(outputTypeDescriptor,
JavaFieldTypeSupplier.INSTANCE);
+ if (!expectedFieldType.equals(fieldType)) {
+ throw new IllegalArgumentException(
+ "Element argument type "
+ + outputTypeDescriptor
+ + " does not work with expected schema field type "
+ + fieldType);
+ }
+
+ Type expectedInputType = new
ConvertType(true).convert(outputTypeDescriptor);
+
+ TypeDescriptor<?> outputType = outputTypeDescriptor;
+ if (outputType.getRawType().isPrimitive()) {
+ // A SerializableFunction can only return an Object type, so if the DoFn
parameter is a
+ // primitive type, then box it for the return. The return type will be
unboxed before being
+ // forwarded to the DoFn parameter.
+ outputType = TypeDescriptor.of(Primitives.wrap(outputType.getRawType()));
+ }
+
+ TypeDescription.Generic genericType =
+ TypeDescription.Generic.Builder.parameterizedType(
+ SerializableFunction.class, expectedInputType,
outputType.getType())
+ .build();
+ DynamicType.Builder<SerializableFunction> builder =
+ (DynamicType.Builder<SerializableFunction>) new
ByteBuddy().subclass(genericType);
+ try {
+ return builder
+ .method(ElementMatchers.named("apply"))
+ .intercept(new ConvertPrimitiveInstruction(outputType))
+ .make()
+ .load(ReflectHelpers.findClassLoader(),
ClassLoadingStrategy.Default.INJECTION)
+ .getLoaded()
+ .getDeclaredConstructor()
+ .newInstance();
+ } catch (InstantiationException
+ | IllegalAccessException
+ | NoSuchMethodException
+ | InvocationTargetException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ static class ConvertPrimitiveInstruction implements Implementation {
+ private final TypeDescriptor<?> outputFieldType;
+
+ public ConvertPrimitiveInstruction(TypeDescriptor<?> outputFieldType) {
+ this.outputFieldType = outputFieldType;
+ }
+
+ @Override
+ public InstrumentedType prepare(InstrumentedType instrumentedType) {
+ return instrumentedType;
+ }
+
+ @Override
+ public ByteCodeAppender appender(final Target implementationTarget) {
+ return (methodVisitor, implementationContext, instrumentedMethod) -> {
+ int numLocals = 1 + instrumentedMethod.getParameters().size();
+
+ // Method param is offset 1 (offset 0 is the this parameter).
+ StackManipulation readValue =
MethodVariableAccess.REFERENCE.loadFrom(1);
+ StackManipulation stackManipulation =
+ new StackManipulation.Compound(
+ new ConvertValueForSetter(readValue).convert(outputFieldType),
+ MethodReturn.REFERENCE);
+
+ StackManipulation.Size size = stackManipulation.apply(methodVisitor,
implementationContext);
+ return new Size(size.getMaximalSize(), numLocals);
+ };
+ }
+ }
+}
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java
index 073ead1..7de5fae 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/StaticSchemaInference.java
@@ -91,8 +91,8 @@ public class StaticSchemaInference {
return builder.build();
}
- // Map a Java field type to a Beam Schema FieldType.
- private static Schema.FieldType fieldFromType(
+ /** Map a Java field type to a Beam Schema FieldType. */
+ public static Schema.FieldType fieldFromType(
TypeDescriptor type, FieldValueTypeSupplier fieldValueTypeSupplier) {
FieldType primitiveType = PRIMITIVE_TYPES.get(type.getRawType());
if (primitiveType != null) {
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnSchemaInformation.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnSchemaInformation.java
index 5c0347f..ab54fb3 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnSchemaInformation.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnSchemaInformation.java
@@ -19,8 +19,17 @@ package org.apache.beam.sdk.transforms;
import com.google.auto.value.AutoValue;
import java.io.Serializable;
-import javax.annotation.Nullable;
+import java.util.Collections;
+import java.util.List;
+import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.SchemaCoder;
+import org.apache.beam.sdk.schemas.utils.ConvertHelpers;
+import org.apache.beam.sdk.schemas.utils.SelectHelpers;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.TypeDescriptor;
+import
org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList;
/** Represents information about how a DoFn extracts schemas. */
@AutoValue
@@ -29,25 +38,201 @@ public abstract class DoFnSchemaInformation implements
Serializable {
* The schema of the @Element parameter. If the Java type does not match the
input PCollection but
* the schemas are compatible, Beam will automatically convert between the
Java types.
*/
- @Nullable
- public abstract SchemaCoder<?> getElementParameterSchema();
+ public abstract List<SerializableFunction<?, ?>> getElementConverters();
/** Create an instance. */
public static DoFnSchemaInformation create() {
- return new AutoValue_DoFnSchemaInformation.Builder().build();
+ return new AutoValue_DoFnSchemaInformation.Builder()
+ .setElementConverters(Collections.emptyList())
+ .build();
}
/** The builder object. */
@AutoValue.Builder
public abstract static class Builder {
- abstract Builder setElementParameterSchema(@Nullable SchemaCoder<?>
schemaCoder);
+ abstract Builder setElementConverters(List<SerializableFunction<?, ?>>
converters);
abstract DoFnSchemaInformation build();
}
public abstract Builder toBuilder();
- public <T> DoFnSchemaInformation withElementParameterSchema(SchemaCoder<T>
schemaCoder) {
- return toBuilder().setElementParameterSchema(schemaCoder).build();
+ /**
+ * Specified a parameter that is a selection from an input schema (specified
using FieldAccess).
+ * This method is called when the input parameter itself has a schema. The
input parameter does
+ * not need to be a Row. If it is a type with a compatible registered
schema, then the conversion
+ * will be done automatically.
+ *
+ * @param inputCoder The coder for the ParDo's input elements.
+ * @param selectDescriptor The descriptor describing which field to select.
+ * @param selectOutputSchema The schema of the selected parameter.
+ * @param parameterCoder The coder for the input parameter to the method.
+ * @param unbox If unbox is true, then the select result is a 1-field schema
that needs to be
+ * unboxed.
+ * @return
+ */
+ DoFnSchemaInformation withSelectFromSchemaParameter(
+ SchemaCoder<?> inputCoder,
+ FieldAccessDescriptor selectDescriptor,
+ Schema selectOutputSchema,
+ SchemaCoder<?> parameterCoder,
+ boolean unbox) {
+ List<SerializableFunction<?, ?>> converters =
+ ImmutableList.<SerializableFunction<?, ?>>builder()
+ .addAll(getElementConverters())
+ .add(
+ ConversionFunction.of(
+ inputCoder.getSchema(),
+ inputCoder.getToRowFunction(),
+ parameterCoder.getFromRowFunction(),
+ selectDescriptor,
+ selectOutputSchema,
+ unbox))
+ .build();
+
+ return toBuilder().setElementConverters(converters).build();
+ }
+
+ /**
+ * Specifies a parameter that is a selection from an input schema (specified
using FieldAccess).
+ * This method is called when the input parameter is a Java type that does
not itself have a
+ * schema, e.g. long, or String. In this case we expect the selection
predicate to return a
+ * single-field row with a field of the output type.
+ *
+ * @param inputCoder The coder for the ParDo's input elements.
+ * @param selectDescriptor The descriptor describing which field to select.
+ * @param selectOutputSchema The schema of the selected parameter.
+ * @param elementT The type of the method's input parameter.
+ * @return a {@link DoFnSchemaInformation} with an unboxing converter added for this parameter.
+ */
+ DoFnSchemaInformation withUnboxPrimitiveParameter(
+ SchemaCoder inputCoder,
+ FieldAccessDescriptor selectDescriptor,
+ Schema selectOutputSchema,
+ TypeDescriptor<?> elementT) {
+ if (selectOutputSchema.getFieldCount() != 1) {
+ throw new RuntimeException("Parameter has no schema and the input is not
a simple type.");
+ }
+ FieldType fieldType = selectOutputSchema.getField(0).getType();
+ if (fieldType.getTypeName().isCompositeType()) {
+ throw new RuntimeException("Parameter has no schema and the input is not
a primitive type.");
+ }
+
+ List<SerializableFunction<?, ?>> converters =
+ ImmutableList.<SerializableFunction<?, ?>>builder()
+ .addAll(getElementConverters())
+ .add(
+ UnboxingConversionFunction.of(
+ inputCoder.getSchema(),
+ inputCoder.getToRowFunction(),
+ selectDescriptor,
+ selectOutputSchema,
+ elementT))
+ .build();
+
+ return toBuilder().setElementConverters(converters).build();
+ }
+
+ private static class ConversionFunction<InputT, OutputT>
+ implements SerializableFunction<InputT, OutputT> {
+ private final Schema inputSchema;
+ private final SerializableFunction<InputT, Row> toRowFunction;
+ private final SerializableFunction<Row, OutputT> fromRowFunction;
+ private final FieldAccessDescriptor selectDescriptor;
+ private final Schema selectOutputSchema;
+ private final boolean unbox;
+
+ private ConversionFunction(
+ Schema inputSchema,
+ SerializableFunction<InputT, Row> toRowFunction,
+ SerializableFunction<Row, OutputT> fromRowFunction,
+ FieldAccessDescriptor selectDescriptor,
+ Schema selectOutputSchema,
+ boolean unbox) {
+ this.inputSchema = inputSchema;
+ this.toRowFunction = toRowFunction;
+ this.fromRowFunction = fromRowFunction;
+ this.selectDescriptor = selectDescriptor;
+ this.selectOutputSchema = selectOutputSchema;
+ this.unbox = unbox;
+ }
+
+ public static <InputT, OutputT> ConversionFunction of(
+ Schema inputSchema,
+ SerializableFunction<InputT, Row> toRowFunction,
+ SerializableFunction<Row, OutputT> fromRowFunction,
+ FieldAccessDescriptor selectDescriptor,
+ Schema selectOutputSchema,
+ boolean unbox) {
+ return new ConversionFunction<>(
+ inputSchema, toRowFunction, fromRowFunction, selectDescriptor,
selectOutputSchema, unbox);
+ }
+
+ @Override
+ public OutputT apply(InputT input) {
+ Row row = toRowFunction.apply(input);
+ Row selected =
+ SelectHelpers.selectRow(row, selectDescriptor, inputSchema,
selectOutputSchema);
+ if (unbox) {
+ selected = selected.getRow(0);
+ }
+ return fromRowFunction.apply(selected);
+ }
+ }
+
+ /**
+ * This function is used when the schema is a singleton schema containing a
single primitive field
+ * and the Java type we are converting to is that of the primitive field.
+ */
+ private static class UnboxingConversionFunction<InputT, OutputT>
+ implements SerializableFunction<InputT, OutputT> {
+ private final Schema inputSchema;
+ private final SerializableFunction<InputT, Row> toRowFunction;
+ private final FieldAccessDescriptor selectDescriptor;
+ private final Schema selectOutputSchema;
+ private final FieldType primitiveType;
+ private final TypeDescriptor<?> primitiveOutputType;
+ private transient SerializableFunction<InputT, OutputT> conversionFunction;
+
+ private UnboxingConversionFunction(
+ Schema inputSchema,
+ SerializableFunction<InputT, Row> toRowFunction,
+ FieldAccessDescriptor selectDescriptor,
+ Schema selectOutputSchema,
+ TypeDescriptor<?> primitiveOutputType) {
+ this.inputSchema = inputSchema;
+ this.toRowFunction = toRowFunction;
+ this.selectDescriptor = selectDescriptor;
+ this.selectOutputSchema = selectOutputSchema;
+ this.primitiveType = selectOutputSchema.getField(0).getType();
+ this.primitiveOutputType = primitiveOutputType;
+ }
+
+ public static <InputT, OutputT> UnboxingConversionFunction of(
+ Schema inputSchema,
+ SerializableFunction<InputT, Row> toRowFunction,
+ FieldAccessDescriptor selectDescriptor,
+ Schema selectOutputSchema,
+ TypeDescriptor<?> primitiveOutputType) {
+ return new UnboxingConversionFunction<>(
+ inputSchema, toRowFunction, selectDescriptor, selectOutputSchema,
primitiveOutputType);
+ }
+
+ @Override
+ public OutputT apply(InputT input) {
+ Row row = toRowFunction.apply(input);
+ Row selected =
+ SelectHelpers.selectRow(row, selectDescriptor, inputSchema,
selectOutputSchema);
+ return getConversionFunction().apply(selected.getValue(0));
+ }
+
+ private SerializableFunction<InputT, OutputT> getConversionFunction() {
+ if (conversionFunction == null) {
+ conversionFunction =
+ (SerializableFunction<InputT, OutputT>)
+ ConvertHelpers.getConvertPrimitive(primitiveType,
primitiveOutputType);
+ }
+ return conversionFunction;
+ }
}
}
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
index c3b1f82..1c5f4b6 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
@@ -251,7 +251,7 @@ public class DoFnTester<InputT, OutputT> implements
AutoCloseable {
}
@Override
- public InputT schemaElement(DoFn<InputT, OutputT> doFn) {
+ public InputT schemaElement(int index) {
throw new UnsupportedOperationException("Schemas are not
supported by DoFnTester");
}
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
index 9febcba..08bae25 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java
@@ -39,6 +39,8 @@ import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.schemas.SchemaRegistry;
+import org.apache.beam.sdk.schemas.utils.ConvertHelpers;
+import org.apache.beam.sdk.schemas.utils.SelectHelpers;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.transforms.DoFn.WindowedContext;
import org.apache.beam.sdk.transforms.display.DisplayData;
@@ -59,7 +61,6 @@ import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PCollectionViews;
import org.apache.beam.sdk.values.PValue;
-import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.sdk.values.TypeDescriptor;
@@ -435,7 +436,7 @@ public class ParDo {
}
}
- private static void validateFieldAccessParameter(
+ private static FieldAccessDescriptor getFieldAccessDescriptorFromParameter(
@Nullable String fieldAccessString,
Schema inputSchema,
Map<String, FieldAccessDeclaration> fieldAccessDeclarations,
@@ -448,25 +449,25 @@ public class ParDo {
// here as well to catch these errors.
FieldAccessDescriptor fieldAccessDescriptor = null;
if (fieldAccessString == null) {
- // This is the case where no FieldId is defined, just an @Element Row
row. Default to all
- // fields accessed.
+ // This is the case where no FieldId is defined. Default to all fields
accessed.
fieldAccessDescriptor = FieldAccessDescriptor.withAllFields();
} else {
- // In this case, we expect to have a FieldAccessDescriptor defined in
the class.
+ // If there is a FieldAccessDescriptor in the class with this id, use
that.
FieldAccessDeclaration fieldAccessDeclaration =
fieldAccessDeclarations.get(fieldAccessString);
- checkArgument(
- fieldAccessDeclaration != null,
- "No FieldAccessDeclaration defined with id",
- fieldAccessString);
-
checkArgument(fieldAccessDeclaration.field().getType().equals(FieldAccessDescriptor.class));
- try {
- fieldAccessDescriptor = (FieldAccessDescriptor)
fieldAccessDeclaration.field().get(fn);
- } catch (IllegalAccessException e) {
- throw new RuntimeException(e);
+ if (fieldAccessDeclaration != null) {
+
checkArgument(fieldAccessDeclaration.field().getType().equals(FieldAccessDescriptor.class));
+ try {
+ fieldAccessDescriptor = (FieldAccessDescriptor)
fieldAccessDeclaration.field().get(fn);
+ } catch (IllegalAccessException e) {
+ throw new RuntimeException(e);
+ }
+ } else {
+ // Otherwise, interpret the string as a field-name expression.
+ fieldAccessDescriptor =
FieldAccessDescriptor.withFieldNames(fieldAccessString);
}
}
- fieldAccessDescriptor.resolve(inputSchema);
+ return fieldAccessDescriptor.resolve(inputSchema);
}
/**
@@ -571,64 +572,44 @@ public class ParDo {
DoFn<?, ?> fn, PCollection<?> input) {
DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass());
DoFnSignature.ProcessElementMethod processElementMethod =
signature.processElement();
- SchemaElementParameter elementParameter =
processElementMethod.getSchemaElementParameter();
- boolean validateInputSchema = elementParameter != null;
- TypeDescriptor<?> elementT = null;
- if (validateInputSchema) {
- elementT = (TypeDescriptor<?>) elementParameter.elementT();
- }
-
- DoFnSchemaInformation doFnSchemaInformation =
DoFnSchemaInformation.create();
- if (validateInputSchema) {
- // Element type doesn't match input type, so we need to covnert.
+ if (!processElementMethod.getSchemaElementParameters().isEmpty()) {
if (!input.hasSchema()) {
throw new IllegalArgumentException("Type of @Element must match the
DoFn type" + input);
}
+ }
- validateFieldAccessParameter(
- elementParameter.fieldAccessString(),
- input.getSchema(),
- signature.fieldAccessDeclarations(),
- fn);
-
- boolean toRow = elementT.equals(TypeDescriptor.of(Row.class));
- if (toRow) {
+ SchemaRegistry schemaRegistry = input.getPipeline().getSchemaRegistry();
+ DoFnSchemaInformation doFnSchemaInformation =
DoFnSchemaInformation.create();
+ for (SchemaElementParameter parameter :
processElementMethod.getSchemaElementParameters()) {
+ TypeDescriptor<?> elementT = parameter.elementT();
+ FieldAccessDescriptor accessDescriptor =
+ getFieldAccessDescriptorFromParameter(
+ parameter.fieldAccessString(),
+ input.getSchema(),
+ signature.fieldAccessDeclarations(),
+ fn);
+ Schema selectedSchema = SelectHelpers.getOutputSchema(input.getSchema(),
accessDescriptor);
+ ConvertHelpers.ConvertedSchemaInformation converted =
+ ConvertHelpers.getConvertedSchemaInformation(selectedSchema,
elementT, schemaRegistry);
+ if (converted.outputSchemaCoder != null) {
doFnSchemaInformation =
- doFnSchemaInformation.withElementParameterSchema(
- SchemaCoder.of(
- input.getSchema(),
- SerializableFunctions.identity(),
- SerializableFunctions.identity()));
+ doFnSchemaInformation.withSelectFromSchemaParameter(
+ (SchemaCoder<?>) input.getCoder(),
+ accessDescriptor,
+ selectedSchema,
+ converted.outputSchemaCoder,
+ converted.unboxedType != null);
} else {
- // For now we assume the parameter is not of type Row (TODO: change
this)
- SchemaRegistry schemaRegistry =
input.getPipeline().getSchemaRegistry();
- try {
- Schema schema = schemaRegistry.getSchema(elementT);
- SerializableFunction toRowFunction =
schemaRegistry.getToRowFunction(elementT);
- SerializableFunction fromRowFunction =
schemaRegistry.getFromRowFunction(elementT);
- doFnSchemaInformation =
- doFnSchemaInformation.withElementParameterSchema(
- SchemaCoder.of(schema, toRowFunction, fromRowFunction));
-
- // assert matches input schema.
- // TODO: Properly handle nullable.
- if (!doFnSchemaInformation
- .getElementParameterSchema()
- .getSchema()
- .assignableToIgnoreNullable(input.getSchema())) {
- throw new IllegalArgumentException(
- "Input to DoFn has schema: "
- + input.getSchema()
- + " However @ElementParameter of type "
- + elementT
- + " has incompatible schema "
- +
doFnSchemaInformation.getElementParameterSchema().getSchema());
- }
- } catch (NoSuchSchemaException e) {
- throw new RuntimeException("No schema registered for " + elementT);
- }
+ // If the selected schema is a Row containing a single primitive type
(which is the output
+ // of Select when selecting a primitive), attempt to unbox it and
match against the
+ // parameter.
+ checkArgument(converted.unboxedType != null);
+ doFnSchemaInformation =
+ doFnSchemaInformation.withUnboxPrimitiveParameter(
+ (SchemaCoder<?>) input.getCoder(), accessDescriptor,
selectedSchema, elementT);
}
}
+
return doFnSchemaInformation;
}
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
index 47f09f4..457ee55 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
@@ -34,6 +34,7 @@ import net.bytebuddy.description.field.FieldDescription;
import net.bytebuddy.description.method.MethodDescription;
import net.bytebuddy.description.modifier.Visibility;
import net.bytebuddy.description.type.TypeDescription;
+import net.bytebuddy.description.type.TypeDescription.ForLoadedType;
import net.bytebuddy.description.type.TypeList;
import net.bytebuddy.dynamic.DynamicType;
import net.bytebuddy.dynamic.loading.ClassLoadingStrategy;
@@ -46,10 +47,12 @@ import net.bytebuddy.implementation.Implementation.Context;
import net.bytebuddy.implementation.MethodDelegation;
import net.bytebuddy.implementation.bytecode.ByteCodeAppender;
import net.bytebuddy.implementation.bytecode.StackManipulation;
+import net.bytebuddy.implementation.bytecode.StackManipulation.Compound;
import net.bytebuddy.implementation.bytecode.Throw;
import net.bytebuddy.implementation.bytecode.assign.Assigner;
import net.bytebuddy.implementation.bytecode.assign.Assigner.Typing;
import net.bytebuddy.implementation.bytecode.assign.TypeCasting;
+import net.bytebuddy.implementation.bytecode.constant.IntegerConstant;
import net.bytebuddy.implementation.bytecode.constant.TextConstant;
import net.bytebuddy.implementation.bytecode.member.FieldAccess;
import net.bytebuddy.implementation.bytecode.member.MethodInvocation;
@@ -87,6 +90,7 @@ import
org.apache.beam.sdk.transforms.splittabledofn.HasDefaultTracker;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.util.UserCodeException;
import org.apache.beam.sdk.values.TypeDescriptor;
+import
org.apache.beam.vendor.guava.v20_0.com.google.common.primitives.Primitives;
/** Dynamically generates a {@link DoFnInvoker} instances for invoking a
{@link DoFn}. */
public class ByteBuddyDoFnInvokerFactory implements DoFnInvokerFactory {
@@ -663,13 +667,29 @@ public class ByteBuddyDoFnInvokerFactory implements
DoFnInvokerFactory {
@Override
public StackManipulation dispatch(SchemaElementParameter p) {
- // Ignore FieldAccess id for now.
- return new StackManipulation.Compound(
- pushDelegate,
- MethodInvocation.invoke(
- getExtraContextFactoryMethodDescription(
- SCHEMA_ELEMENT_PARAMETER_METHOD, DoFn.class)),
- TypeCasting.to(new
TypeDescription.ForLoadedType(p.elementT().getRawType())));
+ ForLoadedType elementType = new
ForLoadedType(p.elementT().getRawType());
+ ForLoadedType castType =
+ elementType.isPrimitive()
+ ? new
ForLoadedType(Primitives.wrap(p.elementT().getRawType()))
+ : elementType;
+
+ StackManipulation stackManipulation =
+ new StackManipulation.Compound(
+ IntegerConstant.forValue(p.index()),
+ MethodInvocation.invoke(
+ getExtraContextFactoryMethodDescription(
+ SCHEMA_ELEMENT_PARAMETER_METHOD, int.class)),
+ TypeCasting.to(castType));
+ if (elementType.isPrimitive()) {
+ stackManipulation =
+ new Compound(
+ stackManipulation,
+ Assigner.DEFAULT.assign(
+ elementType.asBoxed().asGenericType(),
+ elementType.asUnboxed().asGenericType(),
+ Typing.STATIC));
+ }
+ return stackManipulation;
}
@Override
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
index 438a918..d8504ee 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
@@ -132,16 +132,19 @@ public interface DoFnInvoker<InputT, OutputT> {
/** Provide a {@link DoFn.OnTimerContext} to use with the given {@link
DoFn}. */
DoFn<InputT, OutputT>.OnTimerContext onTimerContext(DoFn<InputT, OutputT>
doFn);
- /** Provide a link to the input element. */
+ /** Provide a reference to the input element. */
InputT element(DoFn<InputT, OutputT> doFn);
- /** Provide a link to the input element. */
- Object schemaElement(DoFn<InputT, OutputT> doFn);
+ /**
+ * Provide a reference to the selected schema field corresponding to the
input argument
+ * specified by index.
+ */
+ Object schemaElement(int index);
- /** Provide a link to the input element timestamp. */
+ /** Provide a reference to the input element timestamp. */
Instant timestamp(DoFn<InputT, OutputT> doFn);
- /** Provide a link to the time domain for a timer firing. */
+ /** Provide a reference to the time domain for a timer firing. */
TimeDomain timeDomain(DoFn<InputT, OutputT> doFn);
/** Provide a {@link OutputReceiver} for outputting to the default output.
*/
@@ -188,7 +191,7 @@ public interface DoFnInvoker<InputT, OutputT> {
}
@Override
- public InputT schemaElement(DoFn<InputT, OutputT> doFn) {
+ public InputT schemaElement(int index) {
throw new UnsupportedOperationException(
String.format(
"Should never call non-overridden methods of %s",
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
index 6151cab..727d4f7 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
@@ -24,6 +24,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptions;
@@ -398,8 +399,12 @@ public abstract class DoFnSignature {
}
public static SchemaElementParameter schemaElementParameter(
- TypeDescriptor<?> elementT, @Nullable String fieldAccessId) {
- return new
AutoValue_DoFnSignature_Parameter_SchemaElementParameter(elementT,
fieldAccessId);
+ TypeDescriptor<?> elementT, @Nullable String fieldAccessString, int
index) {
+ return new
AutoValue_DoFnSignature_Parameter_SchemaElementParameter.Builder()
+ .setElementT(elementT)
+ .setFieldAccessString(fieldAccessString)
+ .setIndex(index)
+ .build();
}
public static TimestampParameter timestampParameter() {
@@ -511,6 +516,22 @@ public abstract class DoFnSignature {
@Nullable
public abstract String fieldAccessString();
+
+ public abstract int index();
+
+ /** Builder class. */
+ @AutoValue.Builder
+ public abstract static class Builder {
+ public abstract Builder setElementT(TypeDescriptor<?> elementT);
+
+ public abstract Builder setFieldAccessString(@Nullable String
fieldAccess);
+
+ public abstract Builder setIndex(int index);
+
+ public abstract SchemaElementParameter build();
+ }
+
+ public abstract Builder toBuilder();
}
/**
@@ -691,12 +712,11 @@ public abstract class DoFnSignature {
}
@Nullable
- public SchemaElementParameter getSchemaElementParameter() {
- Optional<Parameter> parameter =
- extraParameters().stream()
-
.filter(Predicates.instanceOf(SchemaElementParameter.class)::apply)
- .findFirst();
- return parameter.isPresent() ? ((SchemaElementParameter)
parameter.get()) : null;
+ public List<SchemaElementParameter> getSchemaElementParameters() {
+ return extraParameters().stream()
+ .filter(Predicates.instanceOf(SchemaElementParameter.class)::apply)
+ .map(SchemaElementParameter.class::cast)
+ .collect(Collectors.toList());
}
/** The {@link OutputReceiverParameter} for a main output, or null if
there is none. */
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
index 9889adc..69b03a4 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
@@ -54,6 +54,7 @@ import org.apache.beam.sdk.transforms.DoFn.TimerId;
import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.FieldAccessDeclaration;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter;
import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.RestrictionTrackerParameter;
+import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SchemaElementParameter;
import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter;
import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimerParameter;
import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.WindowParameter;
@@ -260,6 +261,10 @@ public class DoFnSignatures {
return Collections.unmodifiableList(extraParameters);
}
+ public void setParameter(int index, Parameter parameter) {
+ extraParameters.set(index, parameter);
+ }
+
/**
* Returns an {@link MethodAnalysisContext} like this one but including
the provided {@link
* StateParameter}.
@@ -814,6 +819,16 @@ public class DoFnSignatures {
methodContext.addParameter(extraParam);
}
+ int schemaElementIndex = 0;
+ for (int i = 0; i < methodContext.getExtraParameters().size(); ++i) {
+ Parameter parameter = methodContext.getExtraParameters().get(i);
+ if (parameter instanceof SchemaElementParameter) {
+ SchemaElementParameter schemaParameter = (SchemaElementParameter)
parameter;
+ schemaParameter =
schemaParameter.toBuilder().setIndex(schemaElementIndex).build();
+ methodContext.setParameter(i, schemaParameter);
+ ++schemaElementIndex;
+ }
+ }
// The allowed parameters depend on whether this DoFn is splittable
if (methodContext.hasRestrictionTrackerParameter()) {
@@ -867,13 +882,13 @@ public class DoFnSignatures {
ErrorReporter paramErrors = methodErrors.forParameter(param);
- if (hasElementAnnotation(param.getAnnotations())) {
- if (paramT.equals(inputT)) {
- return Parameter.elementParameter(paramT);
- } else {
- String fieldAccessString = getFieldAccessId(param.getAnnotations());
- return Parameter.schemaElementParameter(paramT, fieldAccessString);
- }
+ String fieldAccessString = getFieldAccessId(param.getAnnotations());
+ if (fieldAccessString != null) {
+ return Parameter.schemaElementParameter(paramT, fieldAccessString,
param.getIndex());
+ } else if (hasElementAnnotation(param.getAnnotations())) {
+ return (paramT.equals(inputT))
+ ? Parameter.elementParameter(paramT)
+ : Parameter.schemaElementParameter(paramT, null, param.getIndex());
} else if (hasTimestampAnnotation(param.getAnnotations())) {
methodErrors.checkArgument(
rawType.equals(Instant.class),
diff --git
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoSchemaTest.java
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoSchemaTest.java
index 96ef3b63..9408fc3 100644
---
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoSchemaTest.java
+++
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoSchemaTest.java
@@ -17,15 +17,18 @@
*/
package org.apache.beam.sdk.transforms;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
+import com.google.auto.value.AutoValue;
import java.io.Serializable;
+import java.util.Arrays;
import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
-import org.apache.beam.sdk.schemas.JavaFieldSchema;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
-import org.apache.beam.sdk.schemas.annotations.SchemaCreate;
import org.apache.beam.sdk.testing.NeedsRunner;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
@@ -307,8 +310,7 @@ public class ParDoSchemaTest implements Serializable {
FieldAccessDescriptor.withAllFields();
@ProcessElement
- public void process(
- @FieldAccess("foo") @Element Row row,
OutputReceiver<String> r) {
+ public void process(@FieldAccess("foo") Row row,
OutputReceiver<String> r) {
r.output(row.getString(0) + ":" + row.getInt32(1));
}
}));
@@ -361,31 +363,29 @@ public class ParDoSchemaTest implements Serializable {
}
/** POJO used for testing. */
- @DefaultSchema(JavaFieldSchema.class)
- static class InferredPojo {
- final String stringField;
- final Integer integerField;
+ @DefaultSchema(AutoValueSchema.class)
+ @AutoValue
+ abstract static class Inferred {
+ abstract String getStringField();
- @SchemaCreate
- InferredPojo(String stringField, Integer integerField) {
- this.stringField = stringField;
- this.integerField = integerField;
- }
+ abstract Integer getIntegerField();
}
@Test
@Category({ValidatesRunner.class, UsesSchema.class})
public void testInferredSchemaPipeline() {
- List<InferredPojo> pojoList =
+ List<Inferred> pojoList =
Lists.newArrayList(
- new InferredPojo("a", 1), new InferredPojo("b", 2), new
InferredPojo("c", 3));
+ new AutoValue_ParDoSchemaTest_Inferred("a", 1),
+ new AutoValue_ParDoSchemaTest_Inferred("b", 2),
+ new AutoValue_ParDoSchemaTest_Inferred("c", 3));
PCollection<String> output =
pipeline
.apply(Create.of(pojoList))
.apply(
ParDo.of(
- new DoFn<InferredPojo, String>() {
+ new DoFn<Inferred, String>() {
@ProcessElement
public void process(@Element Row row,
OutputReceiver<String> r) {
r.output(row.getString(0) + ":" + row.getInt32(1));
@@ -398,61 +398,57 @@ public class ParDoSchemaTest implements Serializable {
@Test
@Category({ValidatesRunner.class, UsesSchema.class})
public void testSchemasPassedThrough() {
- List<InferredPojo> pojoList =
+ List<Inferred> pojoList =
Lists.newArrayList(
- new InferredPojo("a", 1), new InferredPojo("b", 2), new
InferredPojo("c", 3));
+ new AutoValue_ParDoSchemaTest_Inferred("a", 1),
+ new AutoValue_ParDoSchemaTest_Inferred("b", 2),
+ new AutoValue_ParDoSchemaTest_Inferred("c", 3));
- PCollection<InferredPojo> out =
pipeline.apply(Create.of(pojoList)).apply(Filter.by(e -> true));
+ PCollection<Inferred> out =
pipeline.apply(Create.of(pojoList)).apply(Filter.by(e -> true));
assertTrue(out.hasSchema());
pipeline.run();
}
/** Pojo used for testing. */
- @DefaultSchema(JavaFieldSchema.class)
- static class InferredPojo2 {
- final Integer integerField;
- final String stringField;
+ @DefaultSchema(AutoValueSchema.class)
+ @AutoValue
+ abstract static class Inferred2 {
+ abstract Integer getIntegerField();
- @SchemaCreate
- InferredPojo2(String stringField, Integer integerField) {
- this.stringField = stringField;
- this.integerField = integerField;
- }
+ abstract String getStringField();
}
@Test
@Category({ValidatesRunner.class, UsesSchema.class})
public void testSchemaConversionPipeline() {
- List<InferredPojo> pojoList =
+ List<Inferred> pojoList =
Lists.newArrayList(
- new InferredPojo("a", 1), new InferredPojo("b", 2), new
InferredPojo("c", 3));
+ new AutoValue_ParDoSchemaTest_Inferred("a", 1),
+ new AutoValue_ParDoSchemaTest_Inferred("b", 2),
+ new AutoValue_ParDoSchemaTest_Inferred("c", 3));
PCollection<String> output =
pipeline
.apply(Create.of(pojoList))
.apply(
ParDo.of(
- new DoFn<InferredPojo, String>() {
+ new DoFn<Inferred, String>() {
@ProcessElement
- public void process(@Element InferredPojo2 pojo,
OutputReceiver<String> r) {
- r.output(pojo.stringField + ":" + pojo.integerField);
+ public void process(@Element Inferred2 pojo,
OutputReceiver<String> r) {
+ r.output(pojo.getStringField() + ":" +
pojo.getIntegerField());
}
}));
PAssert.that(output).containsInAnyOrder("a:1", "b:2", "c:3");
pipeline.run();
}
- @DefaultSchema(JavaFieldSchema.class)
- static class Nested {
- final int field1;
- final InferredPojo inner;
+ @DefaultSchema(AutoValueSchema.class)
+ @AutoValue
+ abstract static class Nested {
+ abstract int getField1();
- @SchemaCreate
- public Nested(int field1, InferredPojo inner) {
- this.field1 = field1;
- this.inner = inner;
- }
+ abstract Inferred getInner();
}
@Test
@@ -460,9 +456,10 @@ public class ParDoSchemaTest implements Serializable {
public void testNestedSchema() {
List<Nested> pojoList =
Lists.newArrayList(
- new Nested(1, new InferredPojo("a", 1)),
- new Nested(2, new InferredPojo("b", 2)),
- new Nested(3, new InferredPojo("c", 3)));
+ new AutoValue_ParDoSchemaTest_Nested(1, new
AutoValue_ParDoSchemaTest_Inferred("a", 1)),
+ new AutoValue_ParDoSchemaTest_Nested(2, new
AutoValue_ParDoSchemaTest_Inferred("b", 2)),
+ new AutoValue_ParDoSchemaTest_Nested(
+ 3, new AutoValue_ParDoSchemaTest_Inferred("c", 3)));
PCollection<String> output =
pipeline
@@ -475,10 +472,154 @@ public class ParDoSchemaTest implements Serializable {
new DoFn<Nested, String>() {
@ProcessElement
public void process(@Element Nested nested,
OutputReceiver<String> r) {
- r.output(nested.inner.stringField + ":" +
nested.inner.integerField);
+ r.output(
+ nested.getInner().getStringField()
+ + ":"
+ + nested.getInner().getIntegerField());
}
}));
PAssert.that(output).containsInAnyOrder("a:1", "b:2", "c:3");
pipeline.run();
}
+
+ @DefaultSchema(AutoValueSchema.class)
+ @AutoValue
+ abstract static class ForExtraction {
+ abstract Integer getIntegerField();
+
+ abstract String getStringField();
+
+ abstract List<Integer> getInts();
+ }
+
+ @Test
+ @Category({ValidatesRunner.class, UsesSchema.class})
+ public void testSchemaFieldSelectionUnboxing() {
+ List<ForExtraction> pojoList =
+ Lists.newArrayList(
+ new AutoValue_ParDoSchemaTest_ForExtraction(1, "a",
Lists.newArrayList(1, 2)),
+ new AutoValue_ParDoSchemaTest_ForExtraction(2, "b",
Lists.newArrayList(2, 3)),
+ new AutoValue_ParDoSchemaTest_ForExtraction(3, "c",
Lists.newArrayList(3, 4)));
+
+ PCollection<String> output =
+ pipeline
+ .apply(Create.of(pojoList))
+ .apply(
+ ParDo.of(
+ new DoFn<ForExtraction, String>() {
+ // Read the list twice as two equivalent types to ensure
that Beam properly
+ // converts.
+ @ProcessElement
+ public void process(
+ @FieldAccess("stringField") String stringField,
+ @FieldAccess("integerField") Integer integerField,
+ @FieldAccess("ints") Integer[] intArray,
+ @FieldAccess("ints") List<Integer> intList,
+ OutputReceiver<String> r) {
+
+ r.output(
+ stringField
+ + ":"
+ + integerField
+ + ":"
+ + Arrays.toString(intArray)
+ + ":"
+ + intList.toString());
+ }
+ }));
+ PAssert.that(output)
+ .containsInAnyOrder("a:1:[1, 2]:[1, 2]", "b:2:[2, 3]:[2, 3]", "c:3:[3,
4]:[3, 4]");
+ pipeline.run();
+ }
+
+ @Test
+ @Category({ValidatesRunner.class, UsesSchema.class})
+ public void testSchemaFieldDescriptorSelectionUnboxing() {
+ List<ForExtraction> pojoList =
+ Lists.newArrayList(
+ new AutoValue_ParDoSchemaTest_ForExtraction(1, "a",
Lists.newArrayList(1, 2)),
+ new AutoValue_ParDoSchemaTest_ForExtraction(2, "b",
Lists.newArrayList(2, 3)),
+ new AutoValue_ParDoSchemaTest_ForExtraction(3, "c",
Lists.newArrayList(3, 4)));
+
+ PCollection<String> output =
+ pipeline
+ .apply(Create.of(pojoList))
+ .apply(
+ ParDo.of(
+ new DoFn<ForExtraction, String>() {
+ @FieldAccess("stringSelector")
+ final FieldAccessDescriptor stringSelector =
+ FieldAccessDescriptor.withFieldNames("stringField");
+
+ @FieldAccess("intSelector")
+ final FieldAccessDescriptor intSelector =
+ FieldAccessDescriptor.withFieldNames("integerField");
+
+ @FieldAccess("intsSelector")
+ final FieldAccessDescriptor intsSelector =
+ FieldAccessDescriptor.withFieldNames("ints");
+
+ @ProcessElement
+ public void process(
+ @FieldAccess("stringSelector") String stringField,
+ @FieldAccess("intSelector") int integerField,
+ @FieldAccess("intsSelector") int[] intArray,
+ OutputReceiver<String> r) {
+ r.output(
+ stringField + ":" + integerField + ":" +
Arrays.toString(intArray));
+ }
+ }));
+ PAssert.that(output).containsInAnyOrder("a:1:[1, 2]", "b:2:[2, 3]",
"c:3:[3, 4]");
+ pipeline.run();
+ }
+
+ @DefaultSchema(AutoValueSchema.class)
+ @AutoValue
+ abstract static class NestedForExtraction {
+ abstract ForExtraction getInner();
+ }
+
+ @Test
+ @Category({ValidatesRunner.class, UsesSchema.class})
+ public void testSchemaFieldSelectionNested() {
+ List<ForExtraction> pojoList =
+ Lists.newArrayList(
+ new AutoValue_ParDoSchemaTest_ForExtraction(1, "a",
Lists.newArrayList(1, 2)),
+ new AutoValue_ParDoSchemaTest_ForExtraction(2, "b",
Lists.newArrayList(2, 3)),
+ new AutoValue_ParDoSchemaTest_ForExtraction(3, "c",
Lists.newArrayList(3, 4)));
+ List<NestedForExtraction> outerList =
+ pojoList.stream()
+ .map(AutoValue_ParDoSchemaTest_NestedForExtraction::new)
+ .collect(Collectors.toList());
+
+ PCollection<String> output =
+ pipeline
+ .apply(Create.of(outerList))
+ .apply(
+ ParDo.of(
+ new DoFn<NestedForExtraction, String>() {
+
+ @ProcessElement
+ public void process(
+ @FieldAccess("inner.*") ForExtraction extracted,
+ @FieldAccess("inner") ForExtraction extracted1,
+ @FieldAccess("inner.stringField") String stringField,
+ @FieldAccess("inner.integerField") int integerField,
+ @FieldAccess("inner.ints") List<Integer> intArray,
+ OutputReceiver<String> r) {
+ assertEquals(extracted, extracted1);
+ assertEquals(stringField, extracted.getStringField());
+ assertEquals(integerField, (int)
extracted.getIntegerField());
+ assertEquals(intArray, extracted.getInts());
+ r.output(
+ extracted.getStringField()
+ + ":"
+ + extracted.getIntegerField()
+ + ":"
+ + extracted.getInts().toString());
+ }
+ }));
+ PAssert.that(output).containsInAnyOrder("a:1:[1, 2]", "b:2:[2, 3]",
"c:3:[3, 4]");
+ pipeline.run();
+ }
}
diff --git
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
index 063af02..8c80cfc 100644
---
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
+++
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
@@ -25,6 +25,8 @@ import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
@@ -66,6 +68,7 @@ import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.sdk.values.TypeDescriptors;
import org.hamcrest.Matcher;
import org.hamcrest.Matchers;
import org.joda.time.Instant;
@@ -181,7 +184,35 @@ public class DoFnSignaturesTest {
@ProcessElement
public void process(@Element Row row) {}
}.getClass());
- assertThat(sig.processElement().getSchemaElementParameter(), notNullValue());
+ assertFalse(sig.processElement().getSchemaElementParameters().isEmpty());
+ }
+
+ @Test
+ public void testMultipleSchemaParameters() {
+ DoFnSignature sig =
+ DoFnSignatures.getSignature(
+ new DoFn<String, String>() {
+ @ProcessElement
+ public void process(
+ @Element Row row1,
+ @Timestamp Instant ts,
+ @Element Row row2,
+ OutputReceiver<String> o,
+ @Element Integer intParameter) {}
+ }.getClass());
+ assertEquals(3, sig.processElement().getSchemaElementParameters().size());
+ assertEquals(0, sig.processElement().getSchemaElementParameters().get(0).index());
+ assertEquals(
+ TypeDescriptors.rows(),
+ sig.processElement().getSchemaElementParameters().get(0).elementT());
+ assertEquals(1, sig.processElement().getSchemaElementParameters().get(1).index());
+ assertEquals(
+ TypeDescriptors.rows(),
+ sig.processElement().getSchemaElementParameters().get(1).elementT());
+ assertEquals(2, sig.processElement().getSchemaElementParameters().get(2).index());
+ assertEquals(
+ TypeDescriptors.integers(),
+ sig.processElement().getSchemaElementParameters().get(2).elementT());
}
@Test
@@ -202,7 +233,7 @@ public class DoFnSignaturesTest {
assertThat(field.getName(), equalTo("fieldAccess"));
assertThat(field.get(doFn), equalTo(descriptor));
- assertThat(sig.processElement().getSchemaElementParameter(), notNullValue());
+ assertFalse(sig.processElement().getSchemaElementParameters().isEmpty());
}
@Test
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index b8a2dfb..78f035b 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -44,6 +44,7 @@ import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver;
import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
import org.apache.beam.sdk.transforms.DoFnOutputReceivers;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
+import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration;
@@ -120,6 +121,7 @@ public class FnApiDoFnRunner<InputT, OutputT>
this.mainOutputConsumers =
(Collection<FnDataReceiver<WindowedValue<OutputT>>>)
(Collection)
context.localNameToConsumer.get(context.mainOutputTag.getId());
+ this.doFnSchemaInformation = ParDoTranslation.getSchemaInformation(context.parDoPayload);
this.doFnInvoker = DoFnInvokers.invokerFor(context.doFn);
this.doFnInvoker.invokeSetup();
@@ -157,7 +159,6 @@ public class FnApiDoFnRunner<InputT, OutputT>
outputTo(consumers, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
}
};
- this.doFnSchemaInformation = ParDoTranslation.getSchemaInformation(context.parDoPayload);
}
@Override
@@ -396,9 +397,9 @@ public class FnApiDoFnRunner<InputT, OutputT>
}
@Override
- public Object schemaElement(DoFn<InputT, OutputT> doFn) {
- Row row = context.schemaCoder.getToRowFunction().apply(element());
- return doFnSchemaInformation.getElementParameterSchema().getFromRowFunction().apply(row);
+ public Object schemaElement(int index) {
+ SerializableFunction converter = doFnSchemaInformation.getElementConverters().get(index);
+ return converter.apply(element());
}
@Override
@@ -580,7 +581,7 @@ public class FnApiDoFnRunner<InputT, OutputT>
}
@Override
- public Object schemaElement(DoFn<InputT, OutputT> doFn) {
+ public Object schemaElement(int index) {
throw new UnsupportedOperationException("Element parameters are not supported.");
}