This is an automated email from the ASF dual-hosted git repository.
chamikara pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 1455c54 [BEAM-12769] Adds support for expanding a Java cross-language
transform using the class name and builder methods (#15343)
1455c54 is described below
commit 1455c545c2e7a4d89b949ebb75712c30fa996925
Author: Chamikara Jayalath <[email protected]>
AuthorDate: Sat Sep 4 14:20:26 2021 -0700
[BEAM-12769] Adds support for expanding a Java cross-language transform
using the class name and builder methods (#15343)
* Adds support for expanding a Java cross-language transform using the
class name and builder methods
* Adds an allowlist and adds support for annotations
* Fix tests
* Address CheckerFramework errors
* Adds license
* Addresses reviewer comments.
* Apply suggestions from code review
Co-authored-by: Lukasz Cwik <[email protected]>
* Addresses reviewer comments.
* Updated the proto to include a single schema/payload for constructor and
each builder method.
Updated the implementation accordingly and added additional tests.
* Some doc updates and a few other minor updates.
* Addressing reviewer comments
Co-authored-by: Lukasz Cwik <[email protected]>
---
.../src/main/proto/external_transforms.proto | 63 ++
sdks/java/expansion-service/build.gradle | 3 +
.../sdk/expansion/service/ExpansionService.java | 36 +-
.../expansion/service/ExpansionServiceOptions.java | 75 ++
.../service/JavaClassLookupTransformProvider.java | 526 +++++++++
.../service/MultiLanguageBuilderMethod.java | 33 +-
.../service/MultiLanguageConstructorMethod.java | 33 +-
.../expansion/service/ExpansionServiceTest.java | 16 +-
.../JavaCLassLookupTransformProviderTest.java | 1111 ++++++++++++++++++++
.../src/test/resources/test_allowlist.yaml | 67 ++
10 files changed, 1901 insertions(+), 62 deletions(-)
diff --git a/model/pipeline/src/main/proto/external_transforms.proto
b/model/pipeline/src/main/proto/external_transforms.proto
index f2d47a1..a528e56 100644
--- a/model/pipeline/src/main/proto/external_transforms.proto
+++ b/model/pipeline/src/main/proto/external_transforms.proto
@@ -29,6 +29,7 @@ option java_package = "org.apache.beam.model.pipeline.v1";
option java_outer_classname = "ExternalTransforms";
import "schema.proto";
+import "beam_runner_api.proto";
// A configuration payload for an external transform.
// Used as the payload of ExternalTransform as part of an ExpansionRequest.
@@ -40,3 +41,65 @@ message ExternalConfigurationPayload {
// schema.
bytes payload = 2;
}
+
+// Defines specific expansion methods that may be used to expand cross-language
+// transforms.
+// The method's URN must be set as the URN of the expansion request's transform.
+message ExpansionMethods {
+ enum Enum {
+ // Expand a Java transform using specified constructor and builder methods.
+ // Transform payload will be of type JavaClassLookupPayload.
+ JAVA_CLASS_LOOKUP = 0 [(org.apache.beam.model.pipeline.v1.beam_urn) =
+ "beam:expansion:payload:java_class_lookup:v1"];
+ }
+}
+
+// A configuration payload for an external transform.
+// Used to define a Java transform that can be directly instantiated by a Java
+// expansion service.
+message JavaClassLookupPayload {
+ // Name of the Java transform class.
+ string class_name = 1;
+
+ // A static method to construct the initial instance of the transform.
+ // If not provided, the transform should be instantiated using a class
+ // constructor.
+ string constructor_method = 2;
+
+ // The top level fields of the schema represent the method parameters in
+ // order.
+ // If able, top level field names are also verified against the method
+ // parameters for a match.
+ Schema constructor_schema = 3;
+
+ // A payload which can be decoded using beam:coder:row:v1 and the provided
+ // constructor schema.
+ bytes constructor_payload = 4;
+
+ // Set of builder methods and corresponding parameters to apply after the
+ // transform object is constructed.
+ // When constructing the transform object, given builder methods will be
+ // applied in order.
+ repeated BuilderMethod builder_methods = 5;
+}
+
+// This represents a builder method of the transform class that should be
+// applied in-order after instantiating the initial transform object.
+// Each builder method may take one or more parameters and has to return an
+// instance of the transform object.
+message BuilderMethod {
+ // Name of the builder method.
+ string name = 1;
+
+ // The top level fields of the schema represent the method parameters in
+ // order.
+ // If able, top level field names are also verified against the method
+ // parameters for a match.
+ Schema schema = 2;
+
+ // A payload which can be decoded using beam:coder:row:v1 and the builder
+ // method schema.
+ bytes payload = 3;
+}
+
+
diff --git a/sdks/java/expansion-service/build.gradle
b/sdks/java/expansion-service/build.gradle
index a626303..2a0ffd0 100644
--- a/sdks/java/expansion-service/build.gradle
+++ b/sdks/java/expansion-service/build.gradle
@@ -38,6 +38,9 @@ dependencies {
compile project(path: ":sdks:java:core", configuration: "shadow")
compile project(path: ":runners:core-construction-java")
compile project(path: ":runners:java-fn-execution")
+ compile library.java.jackson_annotations
+ compile library.java.jackson_databind
+ compile library.java.jackson_dataformat_yaml
compile library.java.vendored_grpc_1_36_0
compile library.java.vendored_guava_26_0_jre
compile library.java.slf4j_api
diff --git
a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java
b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java
index eaa1cbe..6e1f3d3 100644
---
a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java
+++
b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.expansion.service;
+import static org.apache.beam.runners.core.construction.BeamUrns.getUrn;
import static
org.apache.beam.runners.core.construction.resources.PipelineResources.detectClassPathResourcesToStage;
import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
@@ -35,8 +36,10 @@ import java.util.Set;
import java.util.stream.Collectors;
import org.apache.beam.model.expansion.v1.ExpansionApi;
import org.apache.beam.model.expansion.v1.ExpansionServiceGrpc;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods;
import
org.apache.beam.model.pipeline.v1.ExternalTransforms.ExternalConfigurationPayload;
import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.model.pipeline.v1.SchemaApi;
import org.apache.beam.runners.core.construction.Environments;
import org.apache.beam.runners.core.construction.PipelineTranslation;
import org.apache.beam.runners.core.construction.RehydratedComponents;
@@ -49,6 +52,7 @@ import org.apache.beam.sdk.PipelineRunner;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.expansion.ExternalTransformRegistrar;
+import
org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProvider.AllowList;
import org.apache.beam.sdk.options.ExperimentalOptions;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
@@ -70,6 +74,7 @@ import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p36p0.io.grpc.Server;
import org.apache.beam.vendor.grpc.v1p36p0.io.grpc.ServerBuilder;
import org.apache.beam.vendor.grpc.v1p36p0.io.grpc.stub.StreamObserver;
@@ -172,8 +177,8 @@ public class ExpansionService extends
ExpansionServiceGrpc.ExpansionServiceImplB
return configurationClass;
}
- private static <ConfigT> Row decodeRow(ExternalConfigurationPayload
payload) {
- Schema payloadSchema =
SchemaTranslation.schemaFromProto(payload.getSchema());
+ static <ConfigT> Row decodeConfigObjectRow(SchemaApi.Schema schema,
ByteString payload) {
+ Schema payloadSchema = SchemaTranslation.schemaFromProto(schema);
if (payloadSchema.getFieldCount() == 0) {
return Row.withSchema(Schema.of()).build();
@@ -200,7 +205,7 @@ public class ExpansionService extends
ExpansionServiceGrpc.ExpansionServiceImplB
Row configRow;
try {
- configRow =
RowCoder.of(payloadSchema).decode(payload.getPayload().newInput());
+ configRow = RowCoder.of(payloadSchema).decode(payload.newInput());
} catch (IOException e) {
throw new RuntimeException("Error decoding payload", e);
}
@@ -247,7 +252,7 @@ public class ExpansionService extends
ExpansionServiceGrpc.ExpansionServiceImplB
SerializableFunction<Row, ConfigT> fromRowFunc =
SCHEMA_REGISTRY.getFromRowFunction(configurationClass);
- Row payloadRow = decodeRow(payload);
+ Row payloadRow = decodeConfigObjectRow(payload.getSchema(),
payload.getPayload());
if (!payloadRow.getSchema().assignableTo(configSchema)) {
throw new IllegalArgumentException(
@@ -263,7 +268,7 @@ public class ExpansionService extends
ExpansionServiceGrpc.ExpansionServiceImplB
private static <ConfigT> ConfigT payloadToConfigSetters(
ExternalConfigurationPayload payload, Class<ConfigT>
configurationClass)
throws ReflectiveOperationException {
- Row configRow = decodeRow(payload);
+ Row configRow = decodeConfigObjectRow(payload.getSchema(),
payload.getPayload());
Constructor<ConfigT> constructor =
configurationClass.getDeclaredConstructor();
constructor.setAccessible(true);
@@ -459,13 +464,22 @@ public class ExpansionService extends
ExpansionServiceGrpc.ExpansionServiceImplB
}
}));
- @Nullable
- TransformProvider transformProvider =
-
getRegisteredTransforms().get(request.getTransform().getSpec().getUrn());
- if (transformProvider == null) {
- throw new UnsupportedOperationException(
- "Unknown urn: " + request.getTransform().getSpec().getUrn());
+ String urn = request.getTransform().getSpec().getUrn();
+
+ TransformProvider transformProvider = null;
+ if (getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP).equals(urn)) {
+ AllowList allowList =
+
pipelineOptions.as(ExpansionServiceOptions.class).getJavaClassLookupAllowlist();
+ assert allowList != null;
+ transformProvider = new JavaClassLookupTransformProvider(allowList);
+ } else {
+ transformProvider = getRegisteredTransforms().get(urn);
+ if (transformProvider == null) {
+ throw new UnsupportedOperationException(
+ "Unknown urn: " + request.getTransform().getSpec().getUrn());
+ }
}
+
Map<String, PCollection<?>> outputs =
transformProvider.apply(
pipeline,
diff --git
a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceOptions.java
b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceOptions.java
new file mode 100644
index 0000000..79e870c
--- /dev/null
+++
b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceOptions.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.expansion.service;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import
org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProvider.AllowList;
+import org.apache.beam.sdk.options.Default;
+import org.apache.beam.sdk.options.DefaultValueFactory;
+import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.PipelineOptions;
+
+/** Options used to configure the {@link ExpansionService}. */
+public interface ExpansionServiceOptions extends PipelineOptions {
+
+ @Description("Allow list for Java class based transform expansion")
+ @Default.InstanceFactory(JavaClassLookupAllowListFactory.class)
+ AllowList getJavaClassLookupAllowlist();
+
+ void setJavaClassLookupAllowlist(AllowList file);
+
+ @Description("Allow list file for Java class based transform expansion")
+ String getJavaClassLookupAllowlistFile();
+
+ void setJavaClassLookupAllowlistFile(String file);
+
+ /**
+ * Loads the allow list from {@link #getJavaClassLookupAllowlistFile},
defaulting to an empty
+ * {@link JavaClassLookupTransformProvider.AllowList}.
+ */
+ class JavaClassLookupAllowListFactory implements
DefaultValueFactory<AllowList> {
+
+ @Override
+ public AllowList create(PipelineOptions options) {
+ String allowListFile =
+
options.as(ExpansionServiceOptions.class).getJavaClassLookupAllowlistFile();
+ if (allowListFile != null) {
+ ObjectMapper mapper = new ObjectMapper(new YAMLFactory());
+ File allowListFileObj = new File(allowListFile);
+ if (!allowListFileObj.exists()) {
+ throw new IllegalArgumentException(
+ "Allow list file " + allowListFile + " does not exist");
+ }
+ try {
+ return mapper.readValue(allowListFileObj, AllowList.class);
+ } catch (IOException e) {
+ throw new IllegalArgumentException(
+ "Could not load the provided allowlist file " + allowListFile,
e);
+ }
+ }
+
+ // By default produces an empty allow-list.
+ return new AutoValue_JavaClassLookupTransformProvider_AllowList(
+ JavaClassLookupTransformProvider.ALLOW_LIST_VERSION, new
ArrayList<>());
+ }
+ }
+}
diff --git
a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProvider.java
b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProvider.java
new file mode 100644
index 0000000..d32c7e4
--- /dev/null
+++
b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProvider.java
@@ -0,0 +1,526 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.expansion.service;
+
+import static org.apache.beam.runners.core.construction.BeamUrns.getUrn;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.auto.value.AutoValue;
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import java.io.IOException;
+import java.lang.annotation.Annotation;
+import java.lang.reflect.Array;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.lang.reflect.ParameterizedType;
+import java.lang.reflect.Type;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods;
+import
org.apache.beam.model.pipeline.v1.ExternalTransforms.JavaClassLookupPayload;
+import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
+import org.apache.beam.model.pipeline.v1.SchemaApi;
+import org.apache.beam.repackaged.core.org.apache.commons.lang3.ClassUtils;
+import org.apache.beam.sdk.coders.RowCoder;
+import
org.apache.beam.sdk.expansion.service.ExpansionService.TransformProvider;
+import org.apache.beam.sdk.schemas.JavaFieldSchema;
+import org.apache.beam.sdk.schemas.NoSuchSchemaException;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.Field;
+import org.apache.beam.sdk.schemas.Schema.TypeName;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
+import org.apache.beam.sdk.schemas.SchemaTranslation;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.util.common.ReflectHelpers;
+import org.apache.beam.sdk.values.PInput;
+import org.apache.beam.sdk.values.POutput;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
+import
org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.InvalidProtocolBufferException;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+/**
+ * A transform provider that can be used to directly instantiate a transform
using Java class name
+ * and builder methods.
+ *
+ * @param <InputT> input {@link PInput} type of the transform
+ * @param <OutputT> output {@link POutput} type of the transform
+ */
+@SuppressWarnings({"argument.type.incompatible",
"assignment.type.incompatible"})
+@SuppressFBWarnings("UWF_UNWRITTEN_PUBLIC_OR_PROTECTED_FIELD")
+class JavaClassLookupTransformProvider<InputT extends PInput, OutputT extends
POutput>
+ implements TransformProvider<PInput, POutput> {
+
+ public static final String ALLOW_LIST_VERSION = "v1";
+ private static final SchemaRegistry SCHEMA_REGISTRY =
SchemaRegistry.createDefault();
+ private final AllowList allowList;
+
+ public JavaClassLookupTransformProvider(AllowList allowList) {
+ if (!allowList.getVersion().equals(ALLOW_LIST_VERSION)) {
+ throw new IllegalArgumentException("Unknown allow-list version");
+ }
+ this.allowList = allowList;
+ }
+
+ @Override
+ public PTransform<PInput, POutput> getTransform(FunctionSpec spec) {
+ JavaClassLookupPayload payload;
+ try {
+ payload = JavaClassLookupPayload.parseFrom(spec.getPayload());
+ } catch (InvalidProtocolBufferException e) {
+ throw new IllegalArgumentException(
+ "Invalid payload type for URN " +
getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP), e);
+ }
+
+ String className = payload.getClassName();
+ try {
+ AllowedClass allowlistClass = null;
+ if (this.allowList != null) {
+ for (AllowedClass cls : this.allowList.getAllowedClasses()) {
+ if (cls.getClassName().equals(className)) {
+ if (allowlistClass != null) {
+ throw new IllegalArgumentException(
+ "Found two matching allowlist classes " + allowlistClass + "
and " + cls);
+ }
+ allowlistClass = cls;
+ }
+ }
+ }
+ if (allowlistClass == null) {
+ throw new UnsupportedOperationException(
+ "The provided allow list does not enable expanding a transform
class by the name "
+ + className
+ + ".");
+ }
+ Class<PTransform<InputT, OutputT>> transformClass =
+ (Class<PTransform<InputT, OutputT>>)
+ ReflectHelpers.findClassLoader().loadClass(className);
+ PTransform<PInput, POutput> transform;
+ Row constructorRow =
+ decodeRow(payload.getConstructorSchema(),
payload.getConstructorPayload());
+ if (payload.getConstructorMethod().isEmpty()) {
+ Constructor<?>[] constructors = transformClass.getConstructors();
+ Constructor<PTransform<InputT, OutputT>> constructor =
+ findMappingConstructor(constructors, payload);
+ Object[] parameterValues =
+ getParameterValues(
+ constructor.getParameters(),
+ constructorRow,
+ constructor.getGenericParameterTypes());
+ transform = (PTransform<PInput, POutput>)
constructor.newInstance(parameterValues);
+ } else {
+ Method[] methods = transformClass.getMethods();
+ Method method = findMappingConstructorMethod(methods, payload,
allowlistClass);
+ Object[] parameterValues =
+ getParameterValues(
+ method.getParameters(), constructorRow,
method.getGenericParameterTypes());
+ transform = (PTransform<PInput, POutput>) method.invoke(null /* static
*/, parameterValues);
+ }
+ return applyBuilderMethods(transform, payload, allowlistClass);
+ } catch (ClassNotFoundException e) {
+ throw new IllegalArgumentException("Could not find class " + className,
e);
+ } catch (InstantiationException
+ | IllegalArgumentException
+ | IllegalAccessException
+ | InvocationTargetException e) {
+ throw new IllegalArgumentException("Could not instantiate class " +
className, e);
+ }
+ }
+
+ private PTransform<PInput, POutput> applyBuilderMethods(
+ PTransform<PInput, POutput> transform,
+ JavaClassLookupPayload payload,
+ AllowedClass allowListClass) {
+ for (BuilderMethod builderMethod : payload.getBuilderMethodsList()) {
+ Method method = getMethod(transform, builderMethod, allowListClass);
+ try {
+ Row builderMethodRow = decodeRow(builderMethod.getSchema(),
builderMethod.getPayload());
+ transform =
+ (PTransform<PInput, POutput>)
+ method.invoke(
+ transform,
+ getParameterValues(
+ method.getParameters(),
+ builderMethodRow,
+ method.getGenericParameterTypes()));
+ } catch (IllegalAccessException | InvocationTargetException e) {
+ throw new IllegalArgumentException(
+ "Could not invoke the builder method "
+ + builderMethod
+ + " on transform "
+ + transform
+ + " with parameter schema "
+ + builderMethod.getSchema(),
+ e);
+ }
+ }
+
+ return transform;
+ }
+
+ private boolean isBuilderMethodForName(
+ Method method, String nameFromPayload, AllowedClass allowListClass) {
+ // Lookup based on method annotations
+ for (Annotation annotation : method.getAnnotations()) {
+ if (annotation instanceof MultiLanguageBuilderMethod) {
+ if (nameFromPayload.equals(((MultiLanguageBuilderMethod)
annotation).name())) {
+ if
(allowListClass.getAllowedBuilderMethods().contains(nameFromPayload)) {
+ return true;
+ } else {
+ throw new RuntimeException(
+ "Builder method " + nameFromPayload + " has to be explicitly
allowed");
+ }
+ }
+ }
+ }
+
+ // Lookup based on the method name.
+ boolean match = method.getName().equals(nameFromPayload);
+ String consideredMethodName = method.getName();
+
+ // We provide a simplification for common Java builder pattern naming
convention where builder
+ // methods start with "with". In this case, for a builder method name in
the form "withXyz",
+ // users may just use "xyz". If additional updates to the method name are
needed the transform
+ // has to be updated by adding annotations.
+ if (!match && consideredMethodName.length() > 4 &&
consideredMethodName.startsWith("with")) {
+ consideredMethodName =
+ consideredMethodName.substring(4, 5).toLowerCase() +
consideredMethodName.substring(5);
+ match = consideredMethodName.equals(nameFromPayload);
+ }
+ if (match &&
!allowListClass.getAllowedBuilderMethods().contains(consideredMethodName)) {
+ throw new RuntimeException(
+ "Builder method name " + consideredMethodName + " has to be
explicitly allowed");
+ }
+ return match;
+ }
+
+ private Method getMethod(
+ PTransform<PInput, POutput> transform,
+ BuilderMethod builderMethod,
+ AllowedClass allowListClass) {
+
+ Row builderMethodRow = decodeRow(builderMethod.getSchema(),
builderMethod.getPayload());
+
+ List<Method> matchingMethods =
+ Arrays.stream(transform.getClass().getMethods())
+ .filter(m -> isBuilderMethodForName(m, builderMethod.getName(),
allowListClass))
+ .filter(m -> parametersCompatible(m.getParameters(),
builderMethodRow))
+ .filter(m -> PTransform.class.isAssignableFrom(m.getReturnType()))
+ .collect(Collectors.toList());
+
+ if (matchingMethods.size() != 1) {
+ throw new RuntimeException(
+ "Expected to find exactly one matching method in transform "
+ + transform
+ + " for BuilderMethod"
+ + builderMethod
+ + " but found "
+ + matchingMethods.size());
+ }
+ return matchingMethods.get(0);
+ }
+
+ private static boolean isPrimitiveOrWrapperOrString(java.lang.Class<?> type)
{
+ return ClassUtils.isPrimitiveOrWrapper(type) || type == String.class;
+ }
+
+ private Schema getParameterSchema(Class<?> parameterClass) {
+ Schema parameterSchema;
+ try {
+ parameterSchema = SCHEMA_REGISTRY.getSchema(parameterClass);
+ } catch (NoSuchSchemaException e) {
+
+ SCHEMA_REGISTRY.registerSchemaProvider(parameterClass, new
JavaFieldSchema());
+ try {
+ parameterSchema = SCHEMA_REGISTRY.getSchema(parameterClass);
+ } catch (NoSuchSchemaException e1) {
+ throw new RuntimeException(e1);
+ }
+ if (parameterSchema != null && parameterSchema.getFieldCount() == 0) {
+ throw new RuntimeException(
+ "Could not determine a valid schema for parameter class " +
parameterClass);
+ }
+ }
+ return parameterSchema;
+ }
+
+ private boolean parametersCompatible(
+ java.lang.reflect.Parameter[] methodParameters, Row constructorRow) {
+ Schema constructorSchema = constructorRow.getSchema();
+ if (methodParameters.length != constructorSchema.getFieldCount()) {
+ return false;
+ }
+
+ for (int i = 0; i < methodParameters.length; i++) {
+ java.lang.reflect.Parameter parameterFromReflection =
methodParameters[i];
+ Field parameterFromPayload = constructorSchema.getField(i);
+
+ String paramNameFromReflection = parameterFromReflection.getName();
+ if (!paramNameFromReflection.startsWith("arg")
+ && !paramNameFromReflection.equals(parameterFromPayload.getName())) {
+ // Parameter name through reflection is from the class file (not through synthesizing),
+ // hence we can validate names.
+ return false;
+ }
+
+ Class<?> parameterClass = parameterFromReflection.getType();
+ if (isPrimitiveOrWrapperOrString(parameterClass)) {
+ continue;
+ }
+
+ // We perform additional validation for arrays and non-primitive types.
+ if (parameterClass.isArray()) {
+ Class<?> arrayFieldClass = parameterClass.getComponentType();
+ if (parameterFromPayload.getType().getTypeName() != TypeName.ARRAY) {
+ throw new RuntimeException(
+ "Expected a schema with a single array field but received "
+ + parameterFromPayload.getType().getTypeName());
+ }
+
+ // Following is a best-effort validation that may not cover all cases.
Idea is to resolve
+ // ambiguities as much as possible to determine an exact match for the
given set of
+ // parameters. If there are ambiguities, the expansion will fail.
+ if (!isPrimitiveOrWrapperOrString(arrayFieldClass)) {
+ @Nullable Collection<Row> values = constructorRow.getArray(i);
+ Schema arrayFieldSchema = getParameterSchema(arrayFieldClass);
+ if (arrayFieldSchema == null) {
+ throw new RuntimeException("Could not determine a schema for type
" + arrayFieldClass);
+ }
+ if (values != null) {
+ @Nullable Row firstItem = values.iterator().next();
+ if (firstItem != null &&
!(firstItem.getSchema().assignableTo(arrayFieldSchema))) {
+ return false;
+ }
+ }
+ }
+ } else if (constructorRow.getValue(i) instanceof Row) {
+ @Nullable Row parameterRow = constructorRow.getRow(i);
+ Schema schema = getParameterSchema(parameterClass);
+ if (schema == null) {
+ throw new RuntimeException("Could not determine a schema for type "
+ parameterClass);
+ }
+ if (parameterRow != null &&
!parameterRow.getSchema().assignableTo(schema)) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
+ private @Nullable Object getDecodedValueFromRow(
+ Class<?> type, Object valueFromRow, @Nullable Type genericType) {
+ if (isPrimitiveOrWrapperOrString(type)) {
+ if (!isPrimitiveOrWrapperOrString(valueFromRow.getClass())) {
+ throw new IllegalArgumentException(
+ "Expected a Java primitive value but received " + valueFromRow);
+ }
+ return valueFromRow;
+ } else if (type.isArray()) {
+ Class<?> arrayComponentClass = type.getComponentType();
+ return getDecodedArrayValueFromRow(arrayComponentClass, valueFromRow);
+ } else if (Collection.class.isAssignableFrom(type)) {
+ List<Object> originalList = (List) valueFromRow;
+ List<Object> decodedList = new ArrayList<>();
+ for (Object obj : originalList) {
+ if (genericType instanceof ParameterizedType) {
+ Class<?> elementType =
+ (Class<?>) ((ParameterizedType)
genericType).getActualTypeArguments()[0];
+ decodedList.add(getDecodedValueFromRow(elementType, obj, null));
+ } else {
+ throw new RuntimeException("Could not determine the generic type of
the list");
+ }
+ }
+ return decodedList;
+ } else if (valueFromRow instanceof Row) {
+ Row row = (Row) valueFromRow;
+ SerializableFunction<Row, ?> fromRowFunc;
+ try {
+ fromRowFunc = SCHEMA_REGISTRY.getFromRowFunction(type);
+ } catch (NoSuchSchemaException e) {
+ throw new IllegalArgumentException(
+ "Could not determine the row function for class " + type, e);
+ }
+ return fromRowFunc.apply(row);
+ }
+ throw new RuntimeException("Could not decode the value from Row " +
valueFromRow);
+ }
+
+ private Object[] getParameterValues(
+ java.lang.reflect.Parameter[] parameters, Row constrtuctorRow, Type[]
genericTypes) {
+ ArrayList<Object> parameterValues = new ArrayList<>();
+ for (int i = 0; i < parameters.length; ++i) {
+ java.lang.reflect.Parameter parameter = parameters[i];
+ Class<?> parameterClass = parameter.getType();
+ Object parameterValue =
+ getDecodedValueFromRow(parameterClass, constrtuctorRow.getValue(i),
genericTypes[i]);
+ parameterValues.add(parameterValue);
+ }
+
+ return parameterValues.toArray();
+ }
+
+ private Object[] getDecodedArrayValueFromRow(Class<?> arrayComponentType,
Object valueFromRow) {
+ List<Object> originalValues = (List<Object>) valueFromRow;
+ List<Object> decodedValues = new ArrayList<>();
+ for (Object obj : originalValues) {
+ decodedValues.add(getDecodedValueFromRow(arrayComponentType, obj, null));
+ }
+
+ // We have to construct and return an array of the correct type. Otherwise
Java reflection
+ // constructor/method invocations that use the returned value may consider
the array as varargs
+ // (different parameters).
+ Object valueTypeArray = Array.newInstance(arrayComponentType,
decodedValues.size());
+ for (int i = 0; i < decodedValues.size(); i++) {
+ Array.set(valueTypeArray, i,
arrayComponentType.cast(decodedValues.get(i)));
+ }
+ return (Object[]) valueTypeArray;
+ }
+
+ private Constructor<PTransform<InputT, OutputT>> findMappingConstructor(
+ Constructor<?>[] constructors, JavaClassLookupPayload payload) {
+ Row constructorRow = decodeRow(payload.getConstructorSchema(),
payload.getConstructorPayload());
+
+ List<Constructor<?>> mappingConstructors =
+ Arrays.stream(constructors)
+ .filter(c -> c.getParameterCount() ==
payload.getConstructorSchema().getFieldsCount())
+ .filter(c -> parametersCompatible(c.getParameters(),
constructorRow))
+ .collect(Collectors.toList());
+ if (mappingConstructors.size() != 1) {
+ throw new RuntimeException(
+ "Expected to find a single mapping constructor but found " +
mappingConstructors.size());
+ }
+ return (Constructor<PTransform<InputT, OutputT>>)
mappingConstructors.get(0);
+ }
+
+ private boolean isConstructorMethodForName(
+ Method method, String nameFromPayload, AllowedClass allowListClass) {
+ for (Annotation annotation : method.getAnnotations()) {
+ if (annotation instanceof MultiLanguageConstructorMethod) {
+ if (nameFromPayload.equals(((MultiLanguageConstructorMethod)
annotation).name())) {
+ if
(allowListClass.getAllowedConstructorMethods().contains(nameFromPayload)) {
+ return true;
+ } else {
+ throw new RuntimeException(
+ "Constructor method " + nameFromPayload + " needs to be
explicitly allowed");
+ }
+ }
+ }
+ }
+ if (method.getName().equals(nameFromPayload)) {
+ if
(allowListClass.getAllowedConstructorMethods().contains(nameFromPayload)) {
+ return true;
+ } else {
+ throw new RuntimeException(
+ "Constructor method " + nameFromPayload + " needs to be explicitly
allowed");
+ }
+ }
+ return false;
+ }
+
+ private Method findMappingConstructorMethod(
+ Method[] methods, JavaClassLookupPayload payload, AllowedClass
allowListClass) {
+
+ Row constructorRow = decodeRow(payload.getConstructorSchema(),
payload.getConstructorPayload());
+
+ List<Method> mappingConstructorMethods =
+ Arrays.stream(methods)
+ .filter(
+ m -> isConstructorMethodForName(m,
payload.getConstructorMethod(), allowListClass))
+ .filter(m -> m.getParameterCount() ==
payload.getConstructorSchema().getFieldsCount())
+ .filter(m -> parametersCompatible(m.getParameters(),
constructorRow))
+ .collect(Collectors.toList());
+
+ if (mappingConstructorMethods.size() != 1) {
+ throw new RuntimeException(
+ "Expected to find a single mapping constructor method but found "
+ + mappingConstructorMethods.size()
+ + " Payload was "
+ + payload);
+ }
+ return mappingConstructorMethods.get(0);
+ }
+
+ @AutoValue
+ public abstract static class AllowList {
+
+ public abstract String getVersion();
+
+ public abstract List<AllowedClass> getAllowedClasses();
+
+ @JsonCreator
+ static AllowList create(
+ @JsonProperty("version") String version,
+ @JsonProperty("allowedClasses") @javax.annotation.Nullable
+ List<AllowedClass> allowedClasses) {
+ if (allowedClasses == null) {
+ allowedClasses = new ArrayList<>();
+ }
+ return new AutoValue_JavaClassLookupTransformProvider_AllowList(version,
allowedClasses);
+ }
+ }
+
+ @AutoValue
+ public abstract static class AllowedClass {
+
+ public abstract String getClassName();
+
+ public abstract List<String> getAllowedBuilderMethods();
+
+ public abstract List<String> getAllowedConstructorMethods();
+
+ @JsonCreator
+ static AllowedClass create(
+ @JsonProperty("className") String className,
+ @JsonProperty("allowedBuilderMethods") @javax.annotation.Nullable
+ List<String> allowedBuilderMethods,
+ @JsonProperty("allowedConstructorMethods") @javax.annotation.Nullable
+ List<String> allowedConstructorMethods) {
+ if (allowedBuilderMethods == null) {
+ allowedBuilderMethods = new ArrayList<>();
+ }
+ if (allowedConstructorMethods == null) {
+ allowedConstructorMethods = new ArrayList<>();
+ }
+ return new AutoValue_JavaClassLookupTransformProvider_AllowedClass(
+ className, allowedBuilderMethods, allowedConstructorMethods);
+ }
+ }
+
+ static Row decodeRow(SchemaApi.Schema schema, ByteString payload) {
+ Schema payloadSchema = SchemaTranslation.schemaFromProto(schema);
+
+ if (payloadSchema.getFieldCount() == 0) {
+ return Row.withSchema(Schema.of()).build();
+ }
+
+ Row row;
+ try {
+ row = RowCoder.of(payloadSchema).decode(payload.newInput());
+ } catch (IOException e) {
+ throw new RuntimeException("Error decoding payload", e);
+ }
+ return row;
+ }
+}
diff --git a/model/pipeline/src/main/proto/external_transforms.proto
b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/MultiLanguageBuilderMethod.java
similarity index 53%
copy from model/pipeline/src/main/proto/external_transforms.proto
copy to
sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/MultiLanguageBuilderMethod.java
index f2d47a1..3ee9ef5 100644
--- a/model/pipeline/src/main/proto/external_transforms.proto
+++
b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/MultiLanguageBuilderMethod.java
@@ -15,28 +15,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+package org.apache.beam.sdk.expansion.service;
-/*
- * Protocol Buffers describing the external transforms available.
- */
-
-syntax = "proto3";
-
-package org.apache.beam.model.pipeline.v1;
-
-option go_package =
"github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1;pipeline_v1";
-option java_package = "org.apache.beam.model.pipeline.v1";
-option java_outer_classname = "ExternalTransforms";
-
-import "schema.proto";
-
-// A configuration payload for an external transform.
-// Used as the payload of ExternalTransform as part of an ExpansionRequest.
-message ExternalConfigurationPayload {
- // A schema for use in beam:coder:row:v1
- Schema schema = 1;
+import java.lang.annotation.Documented;
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
- // A payload which can be decoded using beam:coder:row:v1 and the given
- // schema.
- bytes payload = 2;
+/**
+ * Method-level annotation, retained at runtime, that attaches a {@link #name()} to a method.
+ *
+ * <p>NOTE(review): presumably the expansion service uses this name to resolve a builder method
+ * requested by a cross-language caller - confirm against {@code
+ * JavaClassLookupTransformProvider}.
+ */
+@Documented
+@Target({ElementType.METHOD})
+@Retention(RetentionPolicy.RUNTIME)
+public @interface MultiLanguageBuilderMethod {
+ /** The name associated with the annotated method. */
+ String name();
}
diff --git a/model/pipeline/src/main/proto/external_transforms.proto
b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/MultiLanguageConstructorMethod.java
similarity index 53%
copy from model/pipeline/src/main/proto/external_transforms.proto
copy to
sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/MultiLanguageConstructorMethod.java
index f2d47a1..e89f460 100644
--- a/model/pipeline/src/main/proto/external_transforms.proto
+++
b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/MultiLanguageConstructorMethod.java
@@ -15,28 +15,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+package org.apache.beam.sdk.expansion.service;
-/*
- * Protocol Buffers describing the external transforms available.
- */
-
-syntax = "proto3";
-
-package org.apache.beam.model.pipeline.v1;
-
-option go_package =
"github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1;pipeline_v1";
-option java_package = "org.apache.beam.model.pipeline.v1";
-option java_outer_classname = "ExternalTransforms";
-
-import "schema.proto";
-
-// A configuration payload for an external transform.
-// Used as the payload of ExternalTransform as part of an ExpansionRequest.
-message ExternalConfigurationPayload {
- // A schema for use in beam:coder:row:v1
- Schema schema = 1;
+import java.lang.annotation.Documented;
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
- // A payload which can be decoded using beam:coder:row:v1 and the given
- // schema.
- bytes payload = 2;
+/**
+ * Method-level annotation, retained at runtime, that attaches a {@link #name()} to a method.
+ *
+ * <p>NOTE(review): presumably the expansion service uses this name to resolve a constructor
+ * (static factory) method requested by a cross-language caller - confirm against {@code
+ * JavaClassLookupTransformProvider}.
+ */
+@Documented
+@Target({ElementType.METHOD})
+@Retention(RetentionPolicy.RUNTIME)
+public @interface MultiLanguageConstructorMethod {
+ /** The name associated with the annotated method. */
+ String name();
}
diff --git
a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java
b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java
index 5e2a243..e8ecf46 100644
---
a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java
+++
b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java
@@ -90,7 +90,8 @@ public class ExpansionServiceTest {
/** Registers a single test transformation. */
@AutoService(ExpansionService.ExpansionServiceRegistrar.class)
- public static class TestTransforms implements
ExpansionService.ExpansionServiceRegistrar {
+ public static class TestTransformRegistrar implements
ExpansionService.ExpansionServiceRegistrar {
+
@Override
public Map<String, ExpansionService.TransformProvider> knownTransforms() {
return ImmutableMap.of(TEST_URN, spec -> Count.perElement());
@@ -140,9 +141,9 @@ public class ExpansionServiceTest {
}
@Test
- public void testConstructGenerateSequence() {
+ public void testConstructGenerateSequenceWithRegistration() {
ExternalTransforms.ExternalConfigurationPayload payload =
- encodeRow(
+ encodeRowIntoExternalConfigurationPayload(
Row.withSchema(
Schema.of(
Field.of("start", FieldType.INT64),
@@ -176,7 +177,7 @@ public class ExpansionServiceTest {
@Test
public void testCompoundCodersForExternalConfiguration_setters() throws
Exception {
ExternalTransforms.ExternalConfigurationPayload externalConfig =
- encodeRow(
+ encodeRowIntoExternalConfigurationPayload(
Row.withSchema(
Schema.of(
Field.nullable("config_key1", FieldType.INT64),
@@ -253,7 +254,7 @@ public class ExpansionServiceTest {
@Test
public void testCompoundCodersForExternalConfiguration_schemas() throws
Exception {
ExternalTransforms.ExternalConfigurationPayload externalConfig =
- encodeRow(
+ encodeRowIntoExternalConfigurationPayload(
Row.withSchema(
Schema.of(
Field.nullable("configKey1", FieldType.INT64),
@@ -320,7 +321,7 @@ public class ExpansionServiceTest {
@Test
public void testExternalConfiguration_simpleSchema() throws Exception {
ExternalTransforms.ExternalConfigurationPayload externalConfig =
- encodeRow(
+ encodeRowIntoExternalConfigurationPayload(
Row.withSchema(
Schema.of(
Field.of("bar", FieldType.STRING),
@@ -350,7 +351,8 @@ public class ExpansionServiceTest {
abstract List<String> getList();
}
- private static ExternalTransforms.ExternalConfigurationPayload encodeRow(Row
row) {
+ private static ExternalTransforms.ExternalConfigurationPayload
+ encodeRowIntoExternalConfigurationPayload(Row row) {
ByteString.Output outputStream = ByteString.newOutput();
try {
SchemaCoder.of(row.getSchema()).encode(row, outputStream);
diff --git
a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java
b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java
new file mode 100644
index 0000000..5244108
--- /dev/null
+++
b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java
@@ -0,0 +1,1111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.expansion.service;
+
+import static org.apache.beam.runners.core.construction.BeamUrns.getUrn;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.containsString;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.Serializable;
+import java.net.URL;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import org.apache.beam.model.expansion.v1.ExpansionApi;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
+import org.apache.beam.model.pipeline.v1.SchemaApi;
+import org.apache.beam.runners.core.construction.ParDoTranslation;
+import org.apache.beam.runners.core.construction.PipelineTranslation;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.Field;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
+import org.apache.beam.sdk.schemas.SchemaCoder;
+import org.apache.beam.sdk.schemas.SchemaTranslation;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.PBegin;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
+import
org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.InvalidProtocolBufferException;
+import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.Resources;
+import org.hamcrest.Matchers;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link JavaClassLookupTransformProvider}. */
+@RunWith(JUnit4.class)
+public class JavaCLassLookupTransformProviderTest {
+
+ private static final String TEST_URN = "test:beam:transforms:count";
+
+ private static final String TEST_NAME = "TestName";
+
+ private static final String TEST_NAMESPACE = "namespace";
+
+ private static ExpansionService expansionService;
+
+ /**
+  * Creates the shared {@link ExpansionService}, pointing its {@code
+  * --javaClassLookupAllowlistFile} option at the {@code test_allowlist.yaml} test resource.
+  */
+ @BeforeClass
+ public static void setupExpansionService() {
+   PipelineOptionsFactory.register(ExpansionServiceOptions.class);
+   URL allowListFile = Resources.getResource("./test_allowlist.yaml");
+   String allowListPath = allowListFile.getPath();
+   // Fail fast with a clear message rather than printing a debug line to stdout.
+   assertTrue(
+       "Allowlist file does not exist: " + allowListPath, new File(allowListPath).exists());
+   expansionService =
+       new ExpansionService(
+           new String[] {"--javaClassLookupAllowlistFile=" + allowListPath});
+ }
+
+ /**
+ * {@link DoFn} that outputs every input element unchanged.
+ *
+ * <p>Its fields only carry the configuration values handed over by {@link
+ * DummyTransform#expand}; the test helper reads them back from the expanded ParDo payload to
+ * check how the transform was configured.
+ */
+ static class DummyDoFn extends DoFn<String, String> {
+ String strField1;
+ String strField2;
+ int intField1;
+ Double doubleWrapperField;
+ String[] strArrayField;
+ DummyComplexType complexTypeField;
+ DummyComplexType[] complexTypeArrayField;
+ List<String> strListField;
+ List<DummyComplexType> complexTypeListField;
+
+ // Stores each argument in the field of the same name.
+ private DummyDoFn(
+ String strField1,
+ String strField2,
+ int intField1,
+ Double doubleWrapperField,
+ String[] strArrayField,
+ DummyComplexType complexTypeField,
+ DummyComplexType[] complexTypeArrayField,
+ List<String> strListField,
+ List<DummyComplexType> complexTypeListField) {
+ this.intField1 = intField1;
+ this.strField1 = strField1;
+ this.strField2 = strField2;
+ this.doubleWrapperField = doubleWrapperField;
+ this.strArrayField = strArrayField;
+ this.complexTypeField = complexTypeField;
+ this.complexTypeArrayField = complexTypeArrayField;
+ this.strListField = strListField;
+ this.complexTypeListField = complexTypeListField;
+ }
+
+ // Identity: emits the element it received.
+ @ProcessElement
+ public void processElement(ProcessContext c) {
+ c.output(c.element());
+ }
+ }
+
+ /**
+  * Serializable value type with one string and one int field, used as a "complex" parameter type
+  * by the test transforms.
+  *
+  * <p>The no-arg constructor leaves {@code complexTypeStrField} as {@code null}, so {@link
+  * #hashCode()} and {@link #equals(Object)} are written to tolerate a null string.
+  */
+ public static class DummyComplexType implements Serializable {
+   String complexTypeStrField;
+   int complexTypeIntField;
+
+   public DummyComplexType() {}
+
+   public DummyComplexType(String complexTypeStrField, int complexTypeIntField) {
+     this.complexTypeStrField = complexTypeStrField;
+     this.complexTypeIntField = complexTypeIntField;
+   }
+
+   /** Same value as before for a non-null string; a null string contributes 0. */
+   @Override
+   public int hashCode() {
+     int strHash = java.util.Objects.hashCode(this.complexTypeStrField);
+     return strHash + this.complexTypeIntField * 31;
+   }
+
+   /** Two instances are equal when both fields match; null strings only equal null strings. */
+   @Override
+   public boolean equals(Object obj) {
+     if (!(obj instanceof DummyComplexType)) {
+       return false;
+     }
+     DummyComplexType toCompare = (DummyComplexType) obj;
+     return this.complexTypeIntField == toCompare.complexTypeIntField
+         && java.util.Objects.equals(this.complexTypeStrField, toCompare.complexTypeStrField);
+   }
+ }
+
+ /**
+ * Base transform for the fixtures below. Subclasses populate the fields through their
+ * constructors, static factory methods and builder methods.
+ *
+ * <p>{@link #expand} creates three fixed strings and passes them through a {@link DummyDoFn}
+ * that is given a copy of every field, which lets tests inspect the configuration after
+ * expansion.
+ */
+ public static class DummyTransform extends PTransform<PBegin,
PCollection<String>> {
+ String strField1;
+ String strField2;
+ int intField1;
+ Double doubleWrapperField;
+ String[] strArrayField;
+ DummyComplexType complexTypeField;
+ DummyComplexType[] complexTypeArrayField;
+ List<String> strListField;
+ List<DummyComplexType> complexTypeListField;
+
+ @Override
+ public PCollection<String> expand(PBegin input) {
+ // The step names are asserted on by the test helper; keep them in sync.
+ return input
+ .apply("MyCreateTransform", Create.of("aaa", "bbb", "ccc"))
+ .apply(
+ "MyParDoTransform",
+ ParDo.of(
+ new DummyDoFn(
+ this.strField1,
+ this.strField2,
+ this.intField1,
+ this.doubleWrapperField,
+ this.strArrayField,
+ this.complexTypeField,
+ this.complexTypeArrayField,
+ this.strListField,
+ this.complexTypeListField)));
+ }
+ }
+
+ /** Fixture with a single-argument public constructor that sets {@code strField1}. */
+ public static class DummyTransformWithConstructor extends DummyTransform {
+
+ public DummyTransformWithConstructor(String strField1) {
+ this.strField1 = strField1;
+ }
+ }
+
+ /**
+ * Fixture with a public constructor setting {@code strField1} and two single-argument builder
+ * methods, {@code withStrField2} and {@code withIntField1}, each returning {@code this}.
+ */
+ public static class DummyTransformWithConstructorAndBuilderMethods extends
DummyTransform {
+
+ public DummyTransformWithConstructorAndBuilderMethods(String strField1) {
+ this.strField1 = strField1;
+ }
+
+ public DummyTransformWithConstructorAndBuilderMethods withStrField2(String
strField2) {
+ this.strField2 = strField2;
+ return this;
+ }
+
+ public DummyTransformWithConstructorAndBuilderMethods withIntField1(int
intField1) {
+ this.intField1 = intField1;
+ return this;
+ }
+ }
+
+ /**
+ * Fixture whose builder method {@code withFields} takes two arguments and sets both {@code
+ * strField2} and {@code intField1}.
+ */
+ public static class DummyTransformWithMultiArgumentBuilderMethod extends
DummyTransform {
+
+ public DummyTransformWithMultiArgumentBuilderMethod(String strField1) {
+ this.strField1 = strField1;
+ }
+
+ public DummyTransformWithMultiArgumentBuilderMethod withFields(
+ String strField2, int intField1) {
+ this.strField2 = strField2;
+ this.intField1 = intField1;
+ return this;
+ }
+ }
+
+ /** Fixture whose constructor takes two strings and sets {@code strField1} and {@code strField2}. */
+ public static class DummyTransformWithMultiArgumentConstructor extends
DummyTransform {
+
+ public DummyTransformWithMultiArgumentConstructor(String strField1, String
strField2) {
+ this.strField1 = strField1;
+ this.strField2 = strField2;
+ }
+ }
+
+ /** Fixture created through the static factory method {@code from}, which sets {@code strField1}. */
+ public static class DummyTransformWithConstructorMethod extends
DummyTransform {
+
+ public static DummyTransformWithConstructorMethod from(String strField1) {
+ DummyTransformWithConstructorMethod transform = new
DummyTransformWithConstructorMethod();
+ transform.strField1 = strField1;
+ return transform;
+ }
+ }
+
+ /**
+ * Fixture combining the static factory method {@code from} (sets {@code strField1}) with the
+ * builder methods {@code withStrField2} and {@code withIntField1}.
+ */
+ public static class DummyTransformWithConstructorMethodAndBuilderMethods
extends DummyTransform {
+
+ public static DummyTransformWithConstructorMethodAndBuilderMethods
from(String strField1) {
+ DummyTransformWithConstructorMethodAndBuilderMethods transform =
+ new DummyTransformWithConstructorMethodAndBuilderMethods();
+ transform.strField1 = strField1;
+ return transform;
+ }
+
+ public DummyTransformWithConstructorMethodAndBuilderMethods
withStrField2(String strField2) {
+ this.strField2 = strField2;
+ return this;
+ }
+
+ public DummyTransformWithConstructorMethodAndBuilderMethods
withIntField1(int intField1) {
+ this.intField1 = intField1;
+ return this;
+ }
+ }
+
+ /**
+ * Fixture whose factory and builder methods carry multi-language annotations: {@code from} is
+ * annotated with the name {@code create_transform}, {@code withStrField2} with {@code abc} and
+ * {@code withIntField1} with {@code xyz}.
+ */
+ public static class DummyTransformWithMultiLanguageAnnotations extends
DummyTransform {
+
+ @MultiLanguageConstructorMethod(name = "create_transform")
+ public static DummyTransformWithMultiLanguageAnnotations from(String
strField1) {
+ DummyTransformWithMultiLanguageAnnotations transform =
+ new DummyTransformWithMultiLanguageAnnotations();
+ transform.strField1 = strField1;
+ return transform;
+ }
+
+ @MultiLanguageBuilderMethod(name = "abc")
+ public DummyTransformWithMultiLanguageAnnotations withStrField2(String
strField2) {
+ this.strField2 = strField2;
+ return this;
+ }
+
+ @MultiLanguageBuilderMethod(name = "xyz")
+ public DummyTransformWithMultiLanguageAnnotations withIntField1(int
intField1) {
+ this.intField1 = intField1;
+ return this;
+ }
+ }
+
+ /** Fixture whose builder method takes a boxed {@link Double} rather than a primitive. */
+ public static class DummyTransformWithWrapperTypes extends DummyTransform {
+ public DummyTransformWithWrapperTypes(String strField1) {
+ this.strField1 = strField1;
+ }
+
+ public DummyTransformWithWrapperTypes withDoubleWrapperField(Double
doubleWrapperField) {
+ this.doubleWrapperField = doubleWrapperField;
+ return this;
+ }
+ }
+
+ /** Fixture whose builder method takes a {@link DummyComplexType} argument. */
+ public static class DummyTransformWithComplexTypes extends DummyTransform {
+ public DummyTransformWithComplexTypes(String strField1) {
+ this.strField1 = strField1;
+ }
+
+ public DummyTransformWithComplexTypes
withComplexTypeField(DummyComplexType complexTypeField) {
+ this.complexTypeField = complexTypeField;
+ return this;
+ }
+ }
+
+ /** Fixture whose builder method takes a {@code String[]} argument. */
+ public static class DummyTransformWithArray extends DummyTransform {
+ public DummyTransformWithArray(String strField1) {
+ this.strField1 = strField1;
+ }
+
+ public DummyTransformWithArray withStrArrayField(String[] strArrayField) {
+ this.strArrayField = strArrayField;
+ return this;
+ }
+ }
+
+ /** Fixture whose builder method takes a {@code List<String>} argument. */
+ public static class DummyTransformWithList extends DummyTransform {
+ public DummyTransformWithList(String strField1) {
+ this.strField1 = strField1;
+ }
+
+ public DummyTransformWithList withStrListField(List<String> strListField) {
+ this.strListField = strListField;
+ return this;
+ }
+ }
+
+ /** Fixture whose builder method takes a {@code DummyComplexType[]} argument. */
+ public static class DummyTransformWithComplexTypeArray extends
DummyTransform {
+ public DummyTransformWithComplexTypeArray(String strField1) {
+ this.strField1 = strField1;
+ }
+
+ public DummyTransformWithComplexTypeArray withComplexTypeArrayField(
+ DummyComplexType[] complexTypeArrayField) {
+ this.complexTypeArrayField = complexTypeArrayField;
+ return this;
+ }
+ }
+
+ /** Fixture whose builder method takes a {@code List<DummyComplexType>} argument. */
+ public static class DummyTransformWithComplexTypeList extends DummyTransform
{
+ public DummyTransformWithComplexTypeList(String strField1) {
+ this.strField1 = strField1;
+ }
+
+ public DummyTransformWithComplexTypeList withComplexTypeListField(
+ List<DummyComplexType> complexTypeListField) {
+ this.complexTypeListField = complexTypeListField;
+ return this;
+ }
+ }
+
+ /**
+  * Sends a {@code JAVA_CLASS_LOOKUP} expansion request carrying {@code payload} to the shared
+  * expansion service and verifies the expanded transform.
+  *
+  * <p>The expanded transform must have no inputs, one output and exactly two sub-transforms named
+  * after {@code MyCreateTransform} / {@code MyParDoTransform}. The {@link DummyDoFn} is then
+  * recovered from the expanded ParDo payload and every entry of {@code fieldsToVerify} is
+  * compared with the matching field of that DoFn.
+  *
+  * @param payload class-lookup payload describing the transform to construct
+  * @param fieldsToVerify expected values keyed by field name; {@code complexTypeStrField} and
+  *     {@code complexTypeIntField} are read from the DoFn's {@code complexTypeField}
+  * @throws RuntimeException if the ParDo payload cannot be parsed, or if {@code fieldsToVerify}
+  *     contains a key this helper does not know how to verify
+  */
+ void testClassLookupExpansionRequestConstruction(
+     ExternalTransforms.JavaClassLookupPayload payload, Map<String, Object> fieldsToVerify) {
+   Pipeline p = Pipeline.create();
+
+   RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+
+   ExpansionApi.ExpansionRequest request =
+       ExpansionApi.ExpansionRequest.newBuilder()
+           .setComponents(pipelineProto.getComponents())
+           .setTransform(
+               RunnerApi.PTransform.newBuilder()
+                   .setUniqueName(TEST_NAME)
+                   .setSpec(
+                       RunnerApi.FunctionSpec.newBuilder()
+                           .setUrn(getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP))
+                           .setPayload(payload.toByteString())))
+           .setNamespace(TEST_NAMESPACE)
+           .build();
+   ExpansionApi.ExpansionResponse response = expansionService.expand(request);
+   RunnerApi.PTransform expandedTransform = response.getTransform();
+   assertEquals(TEST_NAMESPACE + TEST_NAME, expandedTransform.getUniqueName());
+   assertThat(expandedTransform.getInputsCount(), Matchers.is(0));
+   assertThat(expandedTransform.getOutputsCount(), Matchers.is(1));
+   assertEquals(2, expandedTransform.getSubtransformsCount());
+   assertThat(
+       expandedTransform.getSubtransforms(0),
+       anyOf(containsString("MyCreateTransform"), containsString("MyParDoTransform")));
+   assertThat(
+       expandedTransform.getSubtransforms(1),
+       anyOf(containsString("MyCreateTransform"), containsString("MyParDoTransform")));
+
+   // Locate the ParDo that wraps DummyDoFn; if several ids match, the last one wins.
+   RunnerApi.PTransform userParDoTransform = null;
+   for (Map.Entry<String, RunnerApi.PTransform> entry :
+       response.getComponents().getTransformsMap().entrySet()) {
+     if (entry.getKey().contains("ParMultiDo-Dummy-")) {
+       userParDoTransform = entry.getValue();
+     }
+   }
+   assertNotNull(userParDoTransform);
+
+   ParDoPayload parDoPayload;
+   try {
+     parDoPayload = ParDoPayload.parseFrom(userParDoTransform.getSpec().getPayload());
+   } catch (InvalidProtocolBufferException e) {
+     throw new RuntimeException(e);
+   }
+   DummyDoFn doFn =
+       (DummyDoFn)
+           ParDoTranslation.doFnWithExecutionInformationFromProto(parDoPayload.getDoFn())
+               .getDoFn();
+
+   // Assertions below follow JUnit's (expected, actual) argument order.
+   List<String> verifiedFields = new ArrayList<>();
+   if (fieldsToVerify.containsKey("strField1")) {
+     assertEquals(fieldsToVerify.get("strField1"), doFn.strField1);
+     verifiedFields.add("strField1");
+   }
+   if (fieldsToVerify.containsKey("strField2")) {
+     assertEquals(fieldsToVerify.get("strField2"), doFn.strField2);
+     verifiedFields.add("strField2");
+   }
+   if (fieldsToVerify.containsKey("intField1")) {
+     assertEquals(fieldsToVerify.get("intField1"), doFn.intField1);
+     verifiedFields.add("intField1");
+   }
+   if (fieldsToVerify.containsKey("doubleWrapperField")) {
+     assertEquals(fieldsToVerify.get("doubleWrapperField"), doFn.doubleWrapperField);
+     verifiedFields.add("doubleWrapperField");
+   }
+   if (fieldsToVerify.containsKey("complexTypeStrField")) {
+     assertEquals(
+         fieldsToVerify.get("complexTypeStrField"), doFn.complexTypeField.complexTypeStrField);
+     verifiedFields.add("complexTypeStrField");
+   }
+   if (fieldsToVerify.containsKey("complexTypeIntField")) {
+     assertEquals(
+         fieldsToVerify.get("complexTypeIntField"), doFn.complexTypeField.complexTypeIntField);
+     verifiedFields.add("complexTypeIntField");
+   }
+   if (fieldsToVerify.containsKey("strArrayField")) {
+     assertArrayEquals((String[]) fieldsToVerify.get("strArrayField"), doFn.strArrayField);
+     verifiedFields.add("strArrayField");
+   }
+   if (fieldsToVerify.containsKey("strListField")) {
+     assertEquals(fieldsToVerify.get("strListField"), doFn.strListField);
+     verifiedFields.add("strListField");
+   }
+   if (fieldsToVerify.containsKey("complexTypeArrayField")) {
+     assertArrayEquals(
+         (DummyComplexType[]) fieldsToVerify.get("complexTypeArrayField"),
+         doFn.complexTypeArrayField);
+     verifiedFields.add("complexTypeArrayField");
+   }
+   if (fieldsToVerify.containsKey("complexTypeListField")) {
+     assertEquals(fieldsToVerify.get("complexTypeListField"), doFn.complexTypeListField);
+     verifiedFields.add("complexTypeListField");
+   }
+
+   // Guard against a test passing a key this helper silently ignores.
+   List<String> unverifiedFields = new ArrayList<>(fieldsToVerify.keySet());
+   unverifiedFields.removeAll(verifiedFields);
+   if (!unverifiedFields.isEmpty()) {
+     throw new RuntimeException("Failed to verify some fields: " + unverifiedFields);
+   }
+ }
+
+ /** Expands a transform that is created through its single-argument public constructor. */
+ @Test
+ public void testJavaClassLookupWithConstructor() {
+   Row constructorArgs =
+       Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
+           .withFieldValue("strField1", "test_str_1")
+           .build();
+
+   ExternalTransforms.JavaClassLookupPayload lookupPayload =
+       ExternalTransforms.JavaClassLookupPayload.newBuilder()
+           .setClassName(
+               "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor")
+           .setConstructorSchema(getProtoSchemaFromRow(constructorArgs))
+           .setConstructorPayload(getProtoPayloadFromRow(constructorArgs))
+           .build();
+
+   testClassLookupExpansionRequestConstruction(
+       lookupPayload, ImmutableMap.of("strField1", "test_str_1"));
+ }
+
+ /** Expands a transform that is created through the static factory method {@code from}. */
+ @Test
+ public void testJavaClassLookupWithConstructorMethod() {
+   Row constructorArgs =
+       Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
+           .withFieldValue("strField1", "test_str_1")
+           .build();
+
+   ExternalTransforms.JavaClassLookupPayload lookupPayload =
+       ExternalTransforms.JavaClassLookupPayload.newBuilder()
+           .setClassName(
+               "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethod")
+           .setConstructorMethod("from")
+           .setConstructorSchema(getProtoSchemaFromRow(constructorArgs))
+           .setConstructorPayload(getProtoPayloadFromRow(constructorArgs))
+           .build();
+
+   testClassLookupExpansionRequestConstruction(
+       lookupPayload, ImmutableMap.of("strField1", "test_str_1"));
+ }
+
+ /** Expands a transform configured through its constructor plus two builder methods. */
+ @Test
+ public void testJavaClassLookupWithConstructorAndBuilderMethods() {
+   Row constructorArgs =
+       Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
+           .withFieldValue("strField1", "test_str_1")
+           .build();
+   Row strField2Args =
+       Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING)))
+           .withFieldValue("strField2", "test_str_2")
+           .build();
+   Row intField1Args =
+       Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32)))
+           .withFieldValue("intField1", 10)
+           .build();
+
+   // One BuilderMethod entry per call, added in the order they should be applied.
+   BuilderMethod.Builder withStrField2 =
+       BuilderMethod.newBuilder()
+           .setName("withStrField2")
+           .setSchema(getProtoSchemaFromRow(strField2Args))
+           .setPayload(getProtoPayloadFromRow(strField2Args));
+   BuilderMethod.Builder withIntField1 =
+       BuilderMethod.newBuilder()
+           .setName("withIntField1")
+           .setSchema(getProtoSchemaFromRow(intField1Args))
+           .setPayload(getProtoPayloadFromRow(intField1Args));
+
+   ExternalTransforms.JavaClassLookupPayload lookupPayload =
+       ExternalTransforms.JavaClassLookupPayload.newBuilder()
+           .setClassName(
+               "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods")
+           .setConstructorSchema(getProtoSchemaFromRow(constructorArgs))
+           .setConstructorPayload(getProtoPayloadFromRow(constructorArgs))
+           .addBuilderMethods(withStrField2)
+           .addBuilderMethods(withIntField1)
+           .build();
+
+   testClassLookupExpansionRequestConstruction(
+       lookupPayload,
+       ImmutableMap.of("strField1", "test_str_1", "strField2", "test_str_2", "intField1", 10));
+ }
+
+ /** Expands a transform whose constructor takes two string arguments. */
+ @Test
+ public void testJavaClassLookupWithMultiArgumentConstructor() {
+   Schema constructorSchema =
+       Schema.of(
+           Field.of("strField1", FieldType.STRING), Field.of("strField2", FieldType.STRING));
+   Row constructorArgs =
+       Row.withSchema(constructorSchema)
+           .withFieldValue("strField1", "test_str_1")
+           .withFieldValue("strField2", "test_str_2")
+           .build();
+
+   ExternalTransforms.JavaClassLookupPayload lookupPayload =
+       ExternalTransforms.JavaClassLookupPayload.newBuilder()
+           .setClassName(
+               "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiArgumentConstructor")
+           .setConstructorSchema(getProtoSchemaFromRow(constructorArgs))
+           .setConstructorPayload(getProtoPayloadFromRow(constructorArgs))
+           .build();
+
+   testClassLookupExpansionRequestConstruction(
+       lookupPayload, ImmutableMap.of("strField1", "test_str_1", "strField2", "test_str_2"));
+ }
+
+ /** Expands a transform whose builder method {@code withFields} takes two arguments. */
+ @Test
+ public void testJavaClassLookupWithMultiArgumentBuilderMethod() {
+   Row constructorArgs =
+       Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
+           .withFieldValue("strField1", "test_str_1")
+           .build();
+   Schema withFieldsSchema =
+       Schema.of(Field.of("strField2", FieldType.STRING), Field.of("intField1", FieldType.INT32));
+   Row withFieldsArgs =
+       Row.withSchema(withFieldsSchema)
+           .withFieldValue("strField2", "test_str_2")
+           .withFieldValue("intField1", 10)
+           .build();
+
+   BuilderMethod.Builder withFields =
+       BuilderMethod.newBuilder()
+           .setName("withFields")
+           .setSchema(getProtoSchemaFromRow(withFieldsArgs))
+           .setPayload(getProtoPayloadFromRow(withFieldsArgs));
+
+   ExternalTransforms.JavaClassLookupPayload lookupPayload =
+       ExternalTransforms.JavaClassLookupPayload.newBuilder()
+           .setClassName(
+               "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiArgumentBuilderMethod")
+           .setConstructorSchema(getProtoSchemaFromRow(constructorArgs))
+           .setConstructorPayload(getProtoPayloadFromRow(constructorArgs))
+           .addBuilderMethods(withFields)
+           .build();
+
+   testClassLookupExpansionRequestConstruction(
+       lookupPayload,
+       ImmutableMap.of("strField1", "test_str_1", "strField2", "test_str_2", "intField1", 10));
+ }
+
+ /** Expands a transform whose builder method takes a boxed {@code Double}. */
+ @Test
+ public void testJavaClassLookupWithWrapperTypes() {
+   Row constructorArgs =
+       Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
+           .withFieldValue("strField1", "test_str_1")
+           .build();
+   Row doubleArgs =
+       Row.withSchema(Schema.of(Field.of("doubleWrapperField", FieldType.DOUBLE)))
+           .withFieldValue("doubleWrapperField", 123.56)
+           .build();
+
+   BuilderMethod.Builder withDoubleWrapperField =
+       BuilderMethod.newBuilder()
+           .setName("withDoubleWrapperField")
+           .setSchema(getProtoSchemaFromRow(doubleArgs))
+           .setPayload(getProtoPayloadFromRow(doubleArgs));
+
+   ExternalTransforms.JavaClassLookupPayload lookupPayload =
+       ExternalTransforms.JavaClassLookupPayload.newBuilder()
+           .setClassName(
+               "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithWrapperTypes")
+           .setConstructorSchema(getProtoSchemaFromRow(constructorArgs))
+           .setConstructorPayload(getProtoPayloadFromRow(constructorArgs))
+           .addBuilderMethods(withDoubleWrapperField)
+           .build();
+
+   testClassLookupExpansionRequestConstruction(
+       lookupPayload, ImmutableMap.of("doubleWrapperField", 123.56));
+ }
+
+ /** Expands a transform whose builder method takes a {@code DummyComplexType}, sent as a nested row. */
+ @Test
+ public void testJavaClassLookupWithComplexTypes() {
+   Row constructorArgs =
+       Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
+           .withFieldValue("strField1", "test_str_1")
+           .build();
+
+   // The complex argument is described by its own schema and wrapped in a single-field row.
+   Schema complexTypeSchema =
+       Schema.builder()
+           .addStringField("complexTypeStrField")
+           .addInt32Field("complexTypeIntField")
+           .build();
+   Row complexTypeValue =
+       Row.withSchema(complexTypeSchema)
+           .withFieldValue("complexTypeStrField", "complex_type_str_1")
+           .withFieldValue("complexTypeIntField", 123)
+           .build();
+   Row complexTypeArgs =
+       Row.withSchema(Schema.builder().addRowField("complexTypeField", complexTypeSchema).build())
+           .withFieldValue("complexTypeField", complexTypeValue)
+           .build();
+
+   BuilderMethod.Builder withComplexTypeField =
+       BuilderMethod.newBuilder()
+           .setName("withComplexTypeField")
+           .setSchema(getProtoSchemaFromRow(complexTypeArgs))
+           .setPayload(getProtoPayloadFromRow(complexTypeArgs));
+
+   ExternalTransforms.JavaClassLookupPayload lookupPayload =
+       ExternalTransforms.JavaClassLookupPayload.newBuilder()
+           .setClassName(
+               "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypes")
+           .setConstructorSchema(getProtoSchemaFromRow(constructorArgs))
+           .setConstructorPayload(getProtoPayloadFromRow(constructorArgs))
+           .addBuilderMethods(withComplexTypeField)
+           .build();
+
+   testClassLookupExpansionRequestConstruction(
+       lookupPayload,
+       ImmutableMap.of("complexTypeStrField", "complex_type_str_1", "complexTypeIntField", 123));
+ }
+
+ /** Expands a transform whose builder method takes a {@code String[]}, sent as an array field. */
+ @Test
+ public void testJavaClassLookupWithSimpleArrayType() {
+   Row constructorArgs =
+       Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
+           .withFieldValue("strField1", "test_str_1")
+           .build();
+   Row arrayArgs =
+       Row.withSchema(Schema.builder().addArrayField("strArrayField", FieldType.STRING).build())
+           .withFieldValue(
+               "strArrayField", ImmutableList.of("test_str_1", "test_str_2", "test_str_3"))
+           .build();
+
+   BuilderMethod.Builder withStrArrayField =
+       BuilderMethod.newBuilder()
+           .setName("withStrArrayField")
+           .setSchema(getProtoSchemaFromRow(arrayArgs))
+           .setPayload(getProtoPayloadFromRow(arrayArgs));
+
+   ExternalTransforms.JavaClassLookupPayload lookupPayload =
+       ExternalTransforms.JavaClassLookupPayload.newBuilder()
+           .setClassName(
+               "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithArray")
+           .setConstructorSchema(getProtoSchemaFromRow(constructorArgs))
+           .setConstructorPayload(getProtoPayloadFromRow(constructorArgs))
+           .addBuilderMethods(withStrArrayField)
+           .build();
+
+   String[] expectedArray = {"test_str_1", "test_str_2", "test_str_3"};
+   testClassLookupExpansionRequestConstruction(
+       lookupPayload, ImmutableMap.of("strArrayField", expectedArray));
+ }
+
+  /**
+   * Builds a class-lookup payload for {@code DummyTransformWithList}: a constructor row with the
+   * single string field {@code strField1}, plus one {@code withStrListField} builder method whose
+   * row carries an iterable-of-string field. The payload is handed to {@code
+   * testClassLookupExpansionRequestConstruction} together with the {@code List<String>} expected
+   * under {@code strListField}.
+   */
+  @Test
+  public void testJavaClassLookupWithSimpleListType() {
+    ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
+        ExternalTransforms.JavaClassLookupPayload.newBuilder();
+    payloadBuilder.setClassName(
+        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithList");
+
+    // Constructor arguments travel as a Row: its schema and its encoded bytes are set separately.
+    Row constructorRow =
+        Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
+            .withFieldValue("strField1", "test_str_1")
+            .build();
+
+    payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow));
+    payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow));
+
+    BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder();
+    builderMethodBuilder.setName("withStrListField");
+
+    // Unlike the array variant, the field is declared as an iterable here.
+    Schema builderMethodSchema =
+        Schema.builder().addIterableField("strListField", FieldType.STRING).build();
+
+    Row builderMethodRow =
+        Row.withSchema(builderMethodSchema)
+            .withFieldValue(
+                "strListField", ImmutableList.of("test_str_1", "test_str_2", "test_str_3"))
+            .build();
+
+    builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow));
+    builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow));
+
+    payloadBuilder.addBuilderMethods(builderMethodBuilder);
+
+    List<String> resultList = new ArrayList<>();
+    resultList.add("test_str_1");
+    resultList.add("test_str_2");
+    resultList.add("test_str_3");
+    testClassLookupExpansionRequestConstruction(
+        payloadBuilder.build(), ImmutableMap.of("strListField", resultList));
+  }
+
+  /**
+   * Builds a class-lookup payload for {@code DummyTransformWithComplexTypeArray}: a constructor row
+   * with the single string field {@code strField1}, plus one {@code withComplexTypeArrayField}
+   * builder method whose row carries an array of nested rows (a string and an int32 field each).
+   * The payload is handed to {@code testClassLookupExpansionRequestConstruction} together with the
+   * {@code DummyComplexType[]} expected under {@code complexTypeArrayField}.
+   */
+  @Test
+  public void testJavaClassLookupWithComplexArrayType() {
+    ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
+        ExternalTransforms.JavaClassLookupPayload.newBuilder();
+    payloadBuilder.setClassName(
+        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypeArray");
+
+    // Schema of one array element.
+    Schema complexTypeSchema =
+        Schema.builder()
+            .addStringField("complexTypeStrField")
+            .addInt32Field("complexTypeIntField")
+            .build();
+
+    Schema builderMethodSchema =
+        Schema.builder()
+            .addArrayField("complexTypeArrayField", FieldType.row(complexTypeSchema))
+            .build();
+
+    Row constructorRow =
+        Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
+            .withFieldValue("strField1", "test_str_1")
+            .build();
+
+    payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow));
+    payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow));
+
+    List<Row> complexTypeList = new ArrayList<>();
+    complexTypeList.add(
+        Row.withSchema(complexTypeSchema)
+            .withFieldValue("complexTypeStrField", "complex_type_str_1")
+            .withFieldValue("complexTypeIntField", 123)
+            .build());
+    complexTypeList.add(
+        Row.withSchema(complexTypeSchema)
+            .withFieldValue("complexTypeStrField", "complex_type_str_2")
+            .withFieldValue("complexTypeIntField", 456)
+            .build());
+    complexTypeList.add(
+        Row.withSchema(complexTypeSchema)
+            .withFieldValue("complexTypeStrField", "complex_type_str_3")
+            .withFieldValue("complexTypeIntField", 789)
+            .build());
+
+    BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder();
+    builderMethodBuilder.setName("withComplexTypeArrayField");
+
+    Row builderMethodRow =
+        Row.withSchema(builderMethodSchema)
+            .withFieldValue("complexTypeArrayField", complexTypeList)
+            .build();
+
+    builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow));
+    builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow));
+
+    payloadBuilder.addBuilderMethods(builderMethodBuilder);
+
+    // Expected values mirror the three rows above, one DummyComplexType per row.
+    ArrayList<DummyComplexType> resultList = new ArrayList<>();
+    resultList.add(new DummyComplexType("complex_type_str_1", 123));
+    resultList.add(new DummyComplexType("complex_type_str_2", 456));
+    resultList.add(new DummyComplexType("complex_type_str_3", 789));
+
+    testClassLookupExpansionRequestConstruction(
+        payloadBuilder.build(),
+        ImmutableMap.of("complexTypeArrayField", resultList.toArray(new DummyComplexType[0])));
+  }
+
+  /**
+   * Builds a class-lookup payload for {@code DummyTransformWithComplexTypeList}: a constructor row
+   * with the single string field {@code strField1}, plus one {@code withComplexTypeListField}
+   * builder method whose row carries an iterable of nested rows (a string and an int32 field
+   * each). The payload is handed to {@code testClassLookupExpansionRequestConstruction} together
+   * with the list of {@code DummyComplexType} expected under {@code complexTypeListField}.
+   */
+  @Test
+  public void testJavaClassLookupWithComplexListType() {
+    ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
+        ExternalTransforms.JavaClassLookupPayload.newBuilder();
+    payloadBuilder.setClassName(
+        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypeList");
+
+    // Schema of one list element.
+    Schema complexTypeSchema =
+        Schema.builder()
+            .addStringField("complexTypeStrField")
+            .addInt32Field("complexTypeIntField")
+            .build();
+
+    Schema builderMethodSchema =
+        Schema.builder()
+            .addIterableField("complexTypeListField", FieldType.row(complexTypeSchema))
+            .build();
+
+    Row constructorRow =
+        Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
+            .withFieldValue("strField1", "test_str_1")
+            .build();
+
+    payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow));
+    payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow));
+
+    List<Row> complexTypeList = new ArrayList<>();
+    complexTypeList.add(
+        Row.withSchema(complexTypeSchema)
+            .withFieldValue("complexTypeStrField", "complex_type_str_1")
+            .withFieldValue("complexTypeIntField", 123)
+            .build());
+    complexTypeList.add(
+        Row.withSchema(complexTypeSchema)
+            .withFieldValue("complexTypeStrField", "complex_type_str_2")
+            .withFieldValue("complexTypeIntField", 456)
+            .build());
+    complexTypeList.add(
+        Row.withSchema(complexTypeSchema)
+            .withFieldValue("complexTypeStrField", "complex_type_str_3")
+            .withFieldValue("complexTypeIntField", 789)
+            .build());
+
+    BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder();
+    builderMethodBuilder.setName("withComplexTypeListField");
+
+    Row builderMethodRow =
+        Row.withSchema(builderMethodSchema)
+            .withFieldValue("complexTypeListField", complexTypeList)
+            .build();
+
+    builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow));
+    builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow));
+
+    payloadBuilder.addBuilderMethods(builderMethodBuilder);
+
+    // Expected values mirror the three rows above, one DummyComplexType per row.
+    ArrayList<DummyComplexType> resultList = new ArrayList<>();
+    resultList.add(new DummyComplexType("complex_type_str_1", 123));
+    resultList.add(new DummyComplexType("complex_type_str_2", 456));
+    resultList.add(new DummyComplexType("complex_type_str_3", 789));
+
+    testClassLookupExpansionRequestConstruction(
+        payloadBuilder.build(), ImmutableMap.of("complexTypeListField", resultList));
+  }
+
+  /**
+   * Builds a class-lookup payload for {@code DummyTransformWithConstructorMethodAndBuilderMethods}
+   * that names {@code from} as the constructor method and chains two builder methods, {@code
+   * withStrField2} and {@code withIntField1}, each with its own single-field row. The payload is
+   * handed to {@code testClassLookupExpansionRequestConstruction} together with the three expected
+   * field values.
+   */
+  @Test
+  public void testJavaClassLookupWithConstructorMethodAndBuilderMethods() {
+    ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
+        ExternalTransforms.JavaClassLookupPayload.newBuilder();
+    payloadBuilder.setClassName(
+        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods");
+    payloadBuilder.setConstructorMethod("from");
+
+    Row constructorRow =
+        Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
+            .withFieldValue("strField1", "test_str_1")
+            .build();
+
+    payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow));
+    payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow));
+
+    // First builder method: one string argument.
+    BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder();
+    builderMethodBuilder.setName("withStrField2");
+
+    Row builderMethodRow =
+        Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING)))
+            .withFieldValue("strField2", "test_str_2")
+            .build();
+    builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow));
+    builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow));
+
+    payloadBuilder.addBuilderMethods(builderMethodBuilder);
+
+    // Second builder method: one int32 argument. A fresh builder is used for each method.
+    builderMethodBuilder = BuilderMethod.newBuilder();
+    builderMethodBuilder.setName("withIntField1");
+
+    builderMethodRow =
+        Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32)))
+            .withFieldValue("intField1", 10)
+            .build();
+    builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow));
+    builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow));
+
+    payloadBuilder.addBuilderMethods(builderMethodBuilder);
+
+    testClassLookupExpansionRequestConstruction(
+        payloadBuilder.build(),
+        ImmutableMap.of("strField1", "test_str_1", "strField2", "test_str_2", "intField1", 10));
+  }
+
+  /**
+   * Same target class, constructor method and expected field values as {@code
+   * testJavaClassLookupWithConstructorMethodAndBuilderMethods}, but the builder methods are named
+   * {@code strField2} and {@code intField1}, i.e. without the {@code with} prefix. The payload is
+   * handed to {@code testClassLookupExpansionRequestConstruction}.
+   */
+  @Test
+  public void testJavaClassLookupWithSimplifiedBuilderMethodNames() {
+    ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
+        ExternalTransforms.JavaClassLookupPayload.newBuilder();
+    payloadBuilder.setClassName(
+        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods");
+    payloadBuilder.setConstructorMethod("from");
+
+    Row constructorRow =
+        Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
+            .withFieldValue("strField1", "test_str_1")
+            .build();
+
+    payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow));
+    payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow));
+
+    // Builder method named after the field only.
+    BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder();
+    builderMethodBuilder.setName("strField2");
+    Row builderMethodRow =
+        Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING)))
+            .withFieldValue("strField2", "test_str_2")
+            .build();
+    builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow));
+    builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow));
+
+    payloadBuilder.addBuilderMethods(builderMethodBuilder);
+
+    builderMethodBuilder = BuilderMethod.newBuilder();
+    builderMethodBuilder.setName("intField1");
+    builderMethodRow =
+        Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32)))
+            .withFieldValue("intField1", 10)
+            .build();
+    builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow));
+    builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow));
+
+    payloadBuilder.addBuilderMethods(builderMethodBuilder);
+
+    testClassLookupExpansionRequestConstruction(
+        payloadBuilder.build(),
+        ImmutableMap.of("strField1", "test_str_1", "strField2", "test_str_2", "intField1", 10));
+  }
+
+  /**
+   * Builds a class-lookup payload for {@code DummyTransformWithMultiLanguageAnnotations} using the
+   * constructor method name {@code create_transform} and the builder method names {@code abc} and
+   * {@code xyz}. The payload is handed to {@code testClassLookupExpansionRequestConstruction}
+   * together with the three expected field values.
+   *
+   * <p>NOTE(review): these names are presumably the aliases declared through the multi-language
+   * annotations on the dummy transform rather than Java method names; confirm against the class
+   * definition earlier in this file.
+   */
+  @Test
+  public void testJavaClassLookupWithAnnotations() {
+    ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
+        ExternalTransforms.JavaClassLookupPayload.newBuilder();
+    payloadBuilder.setClassName(
+        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiLanguageAnnotations");
+    payloadBuilder.setConstructorMethod("create_transform");
+    Row constructorRow =
+        Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
+            .withFieldValue("strField1", "test_str_1")
+            .build();
+
+    payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow));
+    payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow));
+
+    BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder();
+    builderMethodBuilder.setName("abc");
+    Row builderMethodRow =
+        Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING)))
+            .withFieldValue("strField2", "test_str_2")
+            .build();
+    builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow));
+    builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow));
+
+    payloadBuilder.addBuilderMethods(builderMethodBuilder);
+
+    builderMethodBuilder = BuilderMethod.newBuilder();
+    builderMethodBuilder.setName("xyz");
+    builderMethodRow =
+        Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32)))
+            .withFieldValue("intField1", 10)
+            .build();
+    builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow));
+    builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow));
+
+    payloadBuilder.addBuilderMethods(builderMethodBuilder);
+
+    testClassLookupExpansionRequestConstruction(
+        payloadBuilder.build(),
+        ImmutableMap.of("strField1", "test_str_1", "strField2", "test_str_2", "intField1", 10));
+  }
+
+  /**
+   * Requests the class name {@code ...$UnavailableClass} and expects {@code
+   * testClassLookupExpansionRequestConstruction} to throw a {@link RuntimeException} whose message
+   * contains {@code "does not enable"}.
+   *
+   * <p>NOTE(review): test_allowlist.yaml has no entry for this class, which is presumably what
+   * triggers the failure; confirm against the provider implementation.
+   */
+  @Test
+  public void testJavaClassLookupClassNotAvailable() {
+    ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
+        ExternalTransforms.JavaClassLookupPayload.newBuilder();
+    payloadBuilder.setClassName(
+        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$UnavailableClass");
+    Row constructorRow =
+        Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
+            .withFieldValue("strField1", "test_str_1")
+            .build();
+
+    payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow));
+    payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow));
+
+    // No expected fields are supplied: the call must fail before anything could be compared.
+    RuntimeException thrown =
+        assertThrows(
+            RuntimeException.class,
+            () ->
+                testClassLookupExpansionRequestConstruction(
+                    payloadBuilder.build(), ImmutableMap.of()));
+    assertTrue(thrown.getMessage().contains("does not enable"));
+  }
+
+  /**
+   * Targets {@code DummyTransformWithConstructor} but names the only constructor field {@code
+   * incorrectField}. Expects {@code testClassLookupExpansionRequestConstruction} to throw a {@link
+   * RuntimeException} whose message contains {@code "Expected to find a single mapping
+   * constructor"}.
+   */
+  @Test
+  public void testJavaClassLookupIncorrectConstructionParameter() {
+    ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
+        ExternalTransforms.JavaClassLookupPayload.newBuilder();
+    payloadBuilder.setClassName(
+        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor");
+    // Presumably no constructor parameter carries this name; verify against the dummy transform.
+    Row constructorRow =
+        Row.withSchema(Schema.of(Field.of("incorrectField", FieldType.STRING)))
+            .withFieldValue("incorrectField", "test_str_1")
+            .build();
+
+    payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow));
+    payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow));
+
+    RuntimeException thrown =
+        assertThrows(
+            RuntimeException.class,
+            () ->
+                testClassLookupExpansionRequestConstruction(
+                    payloadBuilder.build(), ImmutableMap.of()));
+    assertTrue(thrown.getMessage().contains("Expected to find a single mapping constructor"));
+  }
+
+  /**
+   * Targets {@code DummyTransformWithConstructorAndBuilderMethods} with a valid-looking constructor
+   * row, but gives the {@code withStrField2} builder method a row whose only field is named {@code
+   * incorrectParam}. Expects {@code testClassLookupExpansionRequestConstruction} to throw a {@link
+   * RuntimeException} whose message contains {@code "Expected to find exactly one matching
+   * method"}.
+   */
+  @Test
+  public void testJavaClassLookupIncorrectBuilderMethodParameter() {
+    ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder =
+        ExternalTransforms.JavaClassLookupPayload.newBuilder();
+    payloadBuilder.setClassName(
+        "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods");
+    Row constructorRow =
+        Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING)))
+            .withFieldValue("strField1", "test_str_1")
+            .build();
+
+    payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow));
+    payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow));
+
+    BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder();
+    builderMethodBuilder.setName("withStrField2");
+    // Presumably the method's parameter is not called incorrectParam; verify against the class.
+    Row builderMethodRow =
+        Row.withSchema(Schema.of(Field.of("incorrectParam", FieldType.STRING)))
+            .withFieldValue("incorrectParam", "test_str_2")
+            .build();
+    builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow));
+    builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow));
+
+    payloadBuilder.addBuilderMethods(builderMethodBuilder);
+
+    RuntimeException thrown =
+        assertThrows(
+            RuntimeException.class,
+            () ->
+                testClassLookupExpansionRequestConstruction(
+                    payloadBuilder.build(), ImmutableMap.of()));
+    assertTrue(thrown.getMessage().contains("Expected to find exactly one matching method"));
+  }
+
+  /**
+   * Returns the proto form of the schema attached to {@code row}, produced by {@code
+   * SchemaTranslation.schemaToProto} with {@code true} as its second argument.
+   */
+  private SchemaApi.Schema getProtoSchemaFromRow(Row row) {
+    Schema rowSchema = row.getSchema();
+    return SchemaTranslation.schemaToProto(rowSchema, true);
+  }
+
+  /**
+   * Encodes {@code row} with a {@code SchemaCoder} built from the row's own schema and returns the
+   * bytes written.
+   *
+   * @throws RuntimeException wrapping any {@link IOException} raised while encoding
+   */
+  private ByteString getProtoPayloadFromRow(Row row) {
+    ByteString.Output encodedRow = ByteString.newOutput();
+    try {
+      SchemaCoder.of(row.getSchema()).encode(row, encodedRow);
+      return encodedRow.toByteString();
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+}
diff --git a/sdks/java/expansion-service/src/test/resources/test_allowlist.yaml b/sdks/java/expansion-service/src/test/resources/test_allowlist.yaml
new file mode 100644
index 0000000..ad11523
--- /dev/null
+++ b/sdks/java/expansion-service/src/test/resources/test_allowlist.yaml
@@ -0,0 +1,67 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+version: v1
+allowedClasses:
+- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor
+- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethod
+  allowedConstructorMethods:
+    - from
+- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods
+  allowedBuilderMethods:
+    - withStrField2
+    - withIntField1
+- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiArgumentConstructor
+- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiArgumentBuilderMethod
+  allowedBuilderMethods:
+    - withFields
+- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods
+  allowedConstructorMethods:
+    - from
+  allowedBuilderMethods:
+    - withStrField2
+    - withIntField1
+    - strField2
+    - intField1
+- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiLanguageAnnotations
+  allowedConstructorMethods:
+    - create_transform
+  allowedBuilderMethods:
+    - abc
+    - xyz
+- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithWrapperTypes
+  allowedBuilderMethods:
+    - withDoubleWrapperField
+- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypes
+  allowedBuilderMethods:
+    - withComplexTypeField
+- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithArray
+  allowedBuilderMethods:
+    - withStrArrayField
+- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithList
+  allowedBuilderMethods:
+    - withStrListField
+- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypeArray
+  allowedBuilderMethods:
+    - withComplexTypeArrayField
+- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypeList
+  allowedBuilderMethods:
+    - withComplexTypeListField
+
+