lukecwik commented on a change in pull request #15343: URL: https://github.com/apache/beam/pull/15343#discussion_r698701576
########## File path: sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProvider.java ########## @@ -0,0 +1,449 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.expansion.service; + +import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.auto.value.AutoValue; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.lang.annotation.Annotation; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.JavaClassLookupPayload; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.Parameter; +import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; +import 
org.apache.beam.repackaged.core.org.apache.commons.lang3.ClassUtils; +import org.apache.beam.sdk.expansion.service.ExpansionService.ExternalTransformRegistrarLoader; +import org.apache.beam.sdk.expansion.service.ExpansionService.TransformProvider; +import org.apache.beam.sdk.schemas.JavaFieldSchema; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaRegistry; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.common.ReflectHelpers; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.InvalidProtocolBufferException; +import org.checkerframework.checker.nullness.qual.NonNull; + +/** + * A transform provider that can be used to directly instantiate a transform using Java class name + * and builder methods. 
+ * + * @param <InputT> input {@link PInput} type of the transform + * @param <OutputT> output {@link POutput} type of the transform + */ +@SuppressWarnings({"argument.type.incompatible", "assignment.type.incompatible"}) +@SuppressFBWarnings("UWF_UNWRITTEN_PUBLIC_OR_PROTECTED_FIELD") +class JavaClassLookupTransformProvider<InputT extends PInput, OutputT extends POutput> + implements TransformProvider<PInput, POutput> { + + private static final SchemaRegistry SCHEMA_REGISTRY = SchemaRegistry.createDefault(); + AllowList allowList; + public static final String ALLOW_LIST_VERSION = "v1"; Review comment: ```suggestion public static final String ALLOW_LIST_VERSION = "v1"; private static final SchemaRegistry SCHEMA_REGISTRY = SchemaRegistry.createDefault(); private final AllowList allowList; ``` ########## File path: sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProvider.java ########## @@ -0,0 +1,449 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.expansion.service; + +import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.auto.value.AutoValue; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.lang.annotation.Annotation; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.JavaClassLookupPayload; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.Parameter; +import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.ClassUtils; +import org.apache.beam.sdk.expansion.service.ExpansionService.ExternalTransformRegistrarLoader; +import org.apache.beam.sdk.expansion.service.ExpansionService.TransformProvider; +import org.apache.beam.sdk.schemas.JavaFieldSchema; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaRegistry; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.common.ReflectHelpers; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.InvalidProtocolBufferException; +import org.checkerframework.checker.nullness.qual.NonNull; + +/** + * A transform provider that can be used to directly instantiate a 
transform using Java class name + * and builder methods. + * + * @param <InputT> input {@link PInput} type of the transform + * @param <OutputT> output {@link POutput} type of the transform + */ +@SuppressWarnings({"argument.type.incompatible", "assignment.type.incompatible"}) +@SuppressFBWarnings("UWF_UNWRITTEN_PUBLIC_OR_PROTECTED_FIELD") +class JavaClassLookupTransformProvider<InputT extends PInput, OutputT extends POutput> + implements TransformProvider<PInput, POutput> { + + private static final SchemaRegistry SCHEMA_REGISTRY = SchemaRegistry.createDefault(); + AllowList allowList; + public static final String ALLOW_LIST_VERSION = "v1"; + + public JavaClassLookupTransformProvider(AllowList allowList) { + if (!allowList.getVersion().equals(ALLOW_LIST_VERSION)) { + throw new IllegalArgumentException("Unknown allow-list version"); + } + this.allowList = allowList; + } + + @Override + public PTransform<PInput, POutput> getTransform(FunctionSpec spec) { + JavaClassLookupPayload payload = null; + try { + payload = JavaClassLookupPayload.parseFrom(spec.getPayload()); + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException( + "Invalid payload type for URN " + getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP), e); + } + + String className = payload.getClassName(); + try { + AllowedClass allowlistClass = null; + if (this.allowList != null) { + for (AllowedClass cls : this.allowList.getAllowedClasses()) { + if (cls.getClassName().equals(className)) { + if (allowlistClass != null) { + throw new IllegalArgumentException( + "Found two matching allowlist classes " + allowlistClass + " and " + cls); + } + allowlistClass = cls; + } + } + } + if (allowlistClass == null) { + throw new UnsupportedOperationException( + "Expanding a transform class by the name " + className + " is not allowed."); Review comment: ```suggestion "The provided allow list does not enable expanding a transform class by the name " + className + "."); ``` ########## File path: 
sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceOptions.java ########## @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.expansion.service; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProvider.AllowList; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.DefaultValueFactory; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptions; + +public interface ExpansionServiceOptions extends PipelineOptions { + + @Description("Allow list for Java class based transform expansion") + @Default.InstanceFactory(JavaClassLookupAllowListFactory.class) + AllowList getJavaClassLookupAllowlist(); + + void setJavaClassLookupAllowlist(AllowList file); + + @Description("Allow list file for Java class based transform expansion") + @Default.String("") + String getJavaClassLookupAllowlistFile(); + + void setJavaClassLookupAllowlistFile(String file); + + 
class JavaClassLookupAllowListFactory implements DefaultValueFactory<AllowList> { Review comment: Add a class comment, e.g.: `Loads the allow list from {@link #getJavaClassLookupAllowlistFile}, defaulting to an empty AllowList.` ########## File path: sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProvider.java ########## @@ -0,0 +1,449 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.beam.sdk.expansion.service; + +import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.auto.value.AutoValue; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.lang.annotation.Annotation; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.JavaClassLookupPayload; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.Parameter; +import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.ClassUtils; +import org.apache.beam.sdk.expansion.service.ExpansionService.ExternalTransformRegistrarLoader; +import org.apache.beam.sdk.expansion.service.ExpansionService.TransformProvider; +import org.apache.beam.sdk.schemas.JavaFieldSchema; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaRegistry; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.common.ReflectHelpers; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.InvalidProtocolBufferException; +import org.checkerframework.checker.nullness.qual.NonNull; + +/** + * A transform provider that can be used to directly instantiate a 
transform using Java class name + * and builder methods. + * + * @param <InputT> input {@link PInput} type of the transform + * @param <OutputT> output {@link POutput} type of the transform + */ +@SuppressWarnings({"argument.type.incompatible", "assignment.type.incompatible"}) +@SuppressFBWarnings("UWF_UNWRITTEN_PUBLIC_OR_PROTECTED_FIELD") +class JavaClassLookupTransformProvider<InputT extends PInput, OutputT extends POutput> + implements TransformProvider<PInput, POutput> { + + private static final SchemaRegistry SCHEMA_REGISTRY = SchemaRegistry.createDefault(); + AllowList allowList; + public static final String ALLOW_LIST_VERSION = "v1"; + + public JavaClassLookupTransformProvider(AllowList allowList) { + if (!allowList.getVersion().equals(ALLOW_LIST_VERSION)) { Review comment: What is the purpose of having a version? ########## File path: model/pipeline/src/main/proto/external_transforms.proto ########## @@ -40,3 +41,64 @@ message ExternalConfigurationPayload { // schema. bytes payload = 2; } + +// This defines a single parameter that should be provided to a method (or a +// constructor) of the transform class. +message Parameter { + // Name of the parameter. + // Optional. If available, may be used to validate the parameter’s name at + // runtime. + string name = 1; + + // A schema that maps to the parameter’s type. + Schema schema = 2; + + // A payload which can be decoded using ‘beam:coder:row:v1’ and the given + // schema. + bytes payload = 3; +} + +// This represents a builder method of the transform class. This may take one +// or more parameters. This has to return an instance of the transform. +message BuilderMethod { + // Name of the builder method + string name = 1; + + // Builder method parameters (in order) + repeated Parameter parameter = 2; +} + +// Defines specific expansion methods that may be used to expand cross-language +// transforms. +// Has to be set as the URN of the transform of the expansion request. 
+message ExpansionMethods { + enum Enum { + // Expand a Java transform using specified constructor and builder methods. + // Transform payload will be of type JavaClassLookupPayload. + JAVA_CLASS_LOOKUP = 0 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:expansion:payload:java_class_lookup:v1"]; + } +} + +// A configuration payload for an external transform. +// Used to define a Java transform that can be directly instantiated by a Java // expansion service. +message JavaClassLookupPayload { + // Name of the Java transform class. + string class_name = 1; + + // A method to construct the initial instance of the transform. + // In not provided, a constructor of the class will be used. Review comment: ```suggestion // A static method to construct the initial instance of the transform. // If not provided, the class's constructor will be used. ``` ########## File path: sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceOptions.java ########## @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.expansion.service; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProvider.AllowList; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.DefaultValueFactory; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptions; + +public interface ExpansionServiceOptions extends PipelineOptions { + + @Description("Allow list for Java class based transform expansion") + @Default.InstanceFactory(JavaClassLookupAllowListFactory.class) + AllowList getJavaClassLookupAllowlist(); + + void setJavaClassLookupAllowlist(AllowList file); + + @Description("Allow list file for Java class based transform expansion") + @Default.String("") Review comment: ```suggestion ``` ########## File path: sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceOptions.java ########## @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.expansion.service; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProvider.AllowList; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.DefaultValueFactory; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptions; + +public interface ExpansionServiceOptions extends PipelineOptions { + + @Description("Allow list for Java class based transform expansion") + @Default.InstanceFactory(JavaClassLookupAllowListFactory.class) + AllowList getJavaClassLookupAllowlist(); + + void setJavaClassLookupAllowlist(AllowList file); + + @Description("Allow list file for Java class based transform expansion") + @Default.String("") + String getJavaClassLookupAllowlistFile(); + + void setJavaClassLookupAllowlistFile(String file); + + class JavaClassLookupAllowListFactory implements DefaultValueFactory<AllowList> { + + @Override + public AllowList create(PipelineOptions options) { + String allowListFile = + options.as(ExpansionServiceOptions.class).getJavaClassLookupAllowlistFile(); + if (!allowListFile.isEmpty()) { Review comment: ```suggestion if (allowListFile != null) { ``` ########## File path: sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceOptions.java ########## @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.expansion.service; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProvider.AllowList; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.DefaultValueFactory; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptions; + +public interface ExpansionServiceOptions extends PipelineOptions { Review comment: Please add a Javadoc comment describing this interface. ########## File path: sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java ########## @@ -0,0 +1,466 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.expansion.service; + +import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.net.URL; +import org.apache.beam.model.expansion.v1.ExpansionApi; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.core.construction.PipelineTranslation; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.Resources; +import org.hamcrest.Matchers; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link JavaCLassLookupTransformProvider}. 
*/ +@RunWith(JUnit4.class) +public class JavaCLassLookupTransformProviderTest { + + private static final String TEST_URN = "test:beam:transforms:count"; + + private static final String TEST_NAME = "TestName"; + + private static final String TEST_NAMESPACE = "namespace"; + + private static ExpansionService expansionService; + + @BeforeClass + public static void setupExpansionService() { + PipelineOptionsFactory.register(ExpansionServiceOptions.class); + URL allowListFile = Resources.getResource("./test_allowlist.yaml"); + System.out.println("Exists: " + new File(allowListFile.getPath()).exists()); + expansionService = + new ExpansionService( + new String[] {"--javaClassLookupAllowlistFile=" + allowListFile.getPath()}); + } + + public static class DummyTransform extends PTransform<PBegin, PCollection<String>> { + + String strField1; + String strField2; + int intField1; + + @Override + public PCollection<String> expand(PBegin input) { + return input + .apply("MyCreateTransform", Create.of("aaa", "bbb", "ccc")) + .apply( + "MyParDoTransform", + ParDo.of( + new DoFn<String, String>() { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element() + strField1); + } + })); + } + } + + public static class DummyTransformWithConstructor extends DummyTransform { + + public DummyTransformWithConstructor(String strField1) { + this.strField1 = strField1; + } + } + + public static class DummyTransformWithConstructorAndBuilderMethods extends DummyTransform { + + public DummyTransformWithConstructorAndBuilderMethods(String strField1) { + this.strField1 = strField1; + } + + public DummyTransformWithConstructorAndBuilderMethods withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + public DummyTransformWithConstructorAndBuilderMethods withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + + public static class DummyTransformWithConstructorMethod extends DummyTransform { + + public static 
DummyTransformWithConstructorMethod from(String strField1) { + DummyTransformWithConstructorMethod transform = new DummyTransformWithConstructorMethod(); + transform.strField1 = strField1; + return transform; + } + } + + public static class DummyTransformWithConstructorMethodAndBuilderMethods extends DummyTransform { + + public static DummyTransformWithConstructorMethodAndBuilderMethods from(String strField1) { + DummyTransformWithConstructorMethodAndBuilderMethods transform = + new DummyTransformWithConstructorMethodAndBuilderMethods(); + transform.strField1 = strField1; + return transform; + } + + public DummyTransformWithConstructorMethodAndBuilderMethods withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + public DummyTransformWithConstructorMethodAndBuilderMethods withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + + public static class DummyTransformWithMultiLanguageAnnotations extends DummyTransform { + + @MultiLanguageConstructorMethod(name = "create_transform") + public static DummyTransformWithMultiLanguageAnnotations from(String strField1) { + DummyTransformWithMultiLanguageAnnotations transform = + new DummyTransformWithMultiLanguageAnnotations(); + transform.strField1 = strField1; + return transform; + } + + @MultiLanguageBuilderMethod(name = "abc") + public DummyTransformWithMultiLanguageAnnotations withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + @MultiLanguageBuilderMethod(name = "xyz") + public DummyTransformWithMultiLanguageAnnotations withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + + void testClassLookupExpansionRequestConstruction( + ExternalTransforms.JavaClassLookupPayload payloaad) { + Pipeline p = Pipeline.create(); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + + ExpansionApi.ExpansionRequest request = + ExpansionApi.ExpansionRequest.newBuilder() + 
.setComponents(pipelineProto.getComponents()) + .setTransform( + RunnerApi.PTransform.newBuilder() + .setUniqueName(TEST_NAME) + .setSpec( + RunnerApi.FunctionSpec.newBuilder() + .setUrn(getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP)) + .setPayload(payloaad.toByteString()))) + .setNamespace(TEST_NAMESPACE) + .build(); + ExpansionApi.ExpansionResponse response = expansionService.expand(request); + RunnerApi.PTransform expandedTransform = response.getTransform(); + assertEquals(TEST_NAMESPACE + TEST_NAME, expandedTransform.getUniqueName()); + assertThat(expandedTransform.getInputsCount(), Matchers.is(0)); + assertThat(expandedTransform.getOutputsCount(), Matchers.is(1)); + assertEquals(2, expandedTransform.getSubtransformsCount()); + assertEquals(2, expandedTransform.getSubtransformsCount()); + assertThat( + expandedTransform.getSubtransforms(0), + anyOf(containsString("MyCreateTransform"), containsString("MyParDoTransform"))); + assertThat( + expandedTransform.getSubtransforms(1), + anyOf(containsString("MyCreateTransform"), containsString("MyParDoTransform"))); + } + + @Test + public void testJavaClassLookupWithConstructor() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor"); + + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); + } + + @Test + public void testJavaClassLookupWithConstructorMethod() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + 
"org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethod"); + + payloadBuilder.setConstructorMethod("from"); + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); Review comment: Check that strField1 was set? ########## File path: sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java ########## @@ -0,0 +1,466 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.expansion.service; + +import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.net.URL; +import org.apache.beam.model.expansion.v1.ExpansionApi; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.core.construction.PipelineTranslation; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.Resources; +import org.hamcrest.Matchers; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link JavaCLassLookupTransformProvider}. 
*/ +@RunWith(JUnit4.class) +public class JavaCLassLookupTransformProviderTest { + + private static final String TEST_URN = "test:beam:transforms:count"; + + private static final String TEST_NAME = "TestName"; + + private static final String TEST_NAMESPACE = "namespace"; + + private static ExpansionService expansionService; + + @BeforeClass + public static void setupExpansionService() { + PipelineOptionsFactory.register(ExpansionServiceOptions.class); + URL allowListFile = Resources.getResource("./test_allowlist.yaml"); + System.out.println("Exists: " + new File(allowListFile.getPath()).exists()); + expansionService = + new ExpansionService( + new String[] {"--javaClassLookupAllowlistFile=" + allowListFile.getPath()}); + } + + public static class DummyTransform extends PTransform<PBegin, PCollection<String>> { + + String strField1; Review comment: Add tests for a wrapper type (e.g. a Double field), a complex type (i.e. not a string, primitive, or wrapper), a list of a simple type, and a list of a complex type. ########## File path: sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProvider.java ########## @@ -0,0 +1,449 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.beam.sdk.expansion.service; + +import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.auto.value.AutoValue; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.lang.annotation.Annotation; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.JavaClassLookupPayload; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.Parameter; +import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.ClassUtils; +import org.apache.beam.sdk.expansion.service.ExpansionService.ExternalTransformRegistrarLoader; +import org.apache.beam.sdk.expansion.service.ExpansionService.TransformProvider; +import org.apache.beam.sdk.schemas.JavaFieldSchema; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaRegistry; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.common.ReflectHelpers; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.InvalidProtocolBufferException; +import org.checkerframework.checker.nullness.qual.NonNull; + +/** + * A transform provider that can be used to directly instantiate a 
transform using Java class name + * and builder methods. + * + * @param <InputT> input {@link PInput} type of the transform + * @param <OutputT> output {@link POutput} type of the transform + */ +@SuppressWarnings({"argument.type.incompatible", "assignment.type.incompatible"}) +@SuppressFBWarnings("UWF_UNWRITTEN_PUBLIC_OR_PROTECTED_FIELD") +class JavaClassLookupTransformProvider<InputT extends PInput, OutputT extends POutput> + implements TransformProvider<PInput, POutput> { + + private static final SchemaRegistry SCHEMA_REGISTRY = SchemaRegistry.createDefault(); + AllowList allowList; + public static final String ALLOW_LIST_VERSION = "v1"; + + public JavaClassLookupTransformProvider(AllowList allowList) { + if (!allowList.getVersion().equals(ALLOW_LIST_VERSION)) { + throw new IllegalArgumentException("Unknown allow-list version"); + } + this.allowList = allowList; + } + + @Override + public PTransform<PInput, POutput> getTransform(FunctionSpec spec) { + JavaClassLookupPayload payload = null; + try { + payload = JavaClassLookupPayload.parseFrom(spec.getPayload()); + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException( + "Invalid payload type for URN " + getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP), e); + } + + String className = payload.getClassName(); + try { + AllowedClass allowlistClass = null; + if (this.allowList != null) { + for (AllowedClass cls : this.allowList.getAllowedClasses()) { + if (cls.getClassName().equals(className)) { + if (allowlistClass != null) { + throw new IllegalArgumentException( + "Found two matching allowlist classes " + allowlistClass + " and " + cls); + } + allowlistClass = cls; + } + } + } + if (allowlistClass == null) { + throw new UnsupportedOperationException( + "Expanding a transform class by the name " + className + " is not allowed."); + } + Class<PTransform<InputT, OutputT>> transformClass = + (Class<PTransform<InputT, OutputT>>) + ReflectHelpers.findClassLoader().loadClass(className); + 
PTransform<PInput, POutput> transform; + if (payload.getConstructorMethod().isEmpty()) { + Constructor<?>[] constructors = transformClass.getConstructors(); + Constructor<PTransform<InputT, OutputT>> constructor = + findMappingConstructor(constructors, payload); + Object[] parameterValues = + getParameterValues( + constructor.getParameters(), + payload.getConstructorParametersList().toArray(new Parameter[0])); + transform = (PTransform<PInput, POutput>) constructor.newInstance(parameterValues); + } else { + Method[] methods = transformClass.getMethods(); + Method method = findMappingConstructorMethod(methods, payload, allowlistClass); + Object[] parameterValues = + getParameterValues( + method.getParameters(), + payload.getConstructorParametersList().toArray(new Parameter[0])); + transform = (PTransform<PInput, POutput>) method.invoke(null /* static */, parameterValues); + } + return applyBuilderMethods(transform, payload, allowlistClass); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException("Could not find class " + className, e); + } catch (InstantiationException + | IllegalArgumentException + | IllegalAccessException + | InvocationTargetException e) { + throw new IllegalArgumentException("Could not instantiate class " + className, e); + } + } + + private PTransform<PInput, POutput> applyBuilderMethods( + PTransform<PInput, POutput> transform, + JavaClassLookupPayload payload, + AllowedClass allowListClass) { + for (BuilderMethod builderMethod : payload.getBuilderMethodsList()) { + Method method = getMethod(transform, builderMethod, allowListClass); + try { + transform = + (PTransform<PInput, POutput>) + method.invoke( + transform, + getParameterValues( + method.getParameters(), + builderMethod.getParameterList().toArray(new Parameter[0]))); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new IllegalArgumentException( + "Could not invoke the builder method " + + builderMethod + + " on transform " + + transform + + 
" with parameters " + + builderMethod.getParameterList(), + e); + } + } + + return transform; + } + + private boolean isBuilderMethodForName( + Method method, String nameFromPayload, AllowedClass allowListClass) { + // Lookup based on method annotations + for (Annotation annotation : method.getAnnotations()) { + if (annotation instanceof MultiLanguageBuilderMethod) { + if (nameFromPayload.equals(((MultiLanguageBuilderMethod) annotation).name())) { + if (allowListClass.getAllowedBuilderMethods().contains(nameFromPayload)) { + return true; + } else { + throw new RuntimeException( + "Builder method " + nameFromPayload + " has to be explicitly allowed"); + } + } + } + } + + // Lookup based on the method name. + boolean match = method.getName().equals(nameFromPayload); + String consideredMethodName = method.getName(); + + // We provide a simplification for common Java builder pattern naming convention where builder + // methods start with "with". In this case, for a builder method name in the form "withXyz", + // users may just use "xyz". If additional updates to the method name are needed the transform + // has to be updated by adding annotations. 
+ if (!match && consideredMethodName.length() > 4 && consideredMethodName.startsWith("with")) { + consideredMethodName = + consideredMethodName.substring(4, 5).toLowerCase() + consideredMethodName.substring(5); + match = consideredMethodName.equals(nameFromPayload); + } + if (match && !allowListClass.getAllowedBuilderMethods().contains(consideredMethodName)) { + throw new RuntimeException( + "Builder method name " + consideredMethodName + " has to be explicitly allowed"); + } + return match; + } + + private Method getMethod( + PTransform<PInput, POutput> transform, + BuilderMethod builderMethod, + AllowedClass allowListClass) { + List<Method> matchingMethods = + Arrays.stream(transform.getClass().getMethods()) + .filter(m -> isBuilderMethodForName(m, builderMethod.getName(), allowListClass)) + .filter( + m -> + parametersCompatible( + m.getParameters(), + builderMethod.getParameterList().toArray(new Parameter[0]))) + .filter(m -> PTransform.class.isAssignableFrom(m.getReturnType())) + .collect(Collectors.toList()); + + if (matchingMethods.size() != 1) { + throw new RuntimeException( + "Expected to find exact one matching method in transform " Review comment: ```suggestion "Expected to find exact one matching method in transform " ``` ```suggestion "Expected to find exactly one matching method in transform " ``` ########## File path: sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java ########## @@ -0,0 +1,466 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.expansion.service; + +import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.net.URL; +import org.apache.beam.model.expansion.v1.ExpansionApi; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.core.construction.PipelineTranslation; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.Resources; +import org.hamcrest.Matchers; +import org.junit.BeforeClass; +import 
org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link JavaCLassLookupTransformProvider}. */ +@RunWith(JUnit4.class) +public class JavaCLassLookupTransformProviderTest { + + private static final String TEST_URN = "test:beam:transforms:count"; + + private static final String TEST_NAME = "TestName"; + + private static final String TEST_NAMESPACE = "namespace"; + + private static ExpansionService expansionService; + + @BeforeClass + public static void setupExpansionService() { + PipelineOptionsFactory.register(ExpansionServiceOptions.class); + URL allowListFile = Resources.getResource("./test_allowlist.yaml"); + System.out.println("Exists: " + new File(allowListFile.getPath()).exists()); + expansionService = + new ExpansionService( + new String[] {"--javaClassLookupAllowlistFile=" + allowListFile.getPath()}); + } + + public static class DummyTransform extends PTransform<PBegin, PCollection<String>> { + + String strField1; + String strField2; + int intField1; + + @Override + public PCollection<String> expand(PBegin input) { + return input + .apply("MyCreateTransform", Create.of("aaa", "bbb", "ccc")) + .apply( + "MyParDoTransform", + ParDo.of( + new DoFn<String, String>() { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element() + strField1); + } + })); + } + } + + public static class DummyTransformWithConstructor extends DummyTransform { + + public DummyTransformWithConstructor(String strField1) { + this.strField1 = strField1; + } + } + + public static class DummyTransformWithConstructorAndBuilderMethods extends DummyTransform { + + public DummyTransformWithConstructorAndBuilderMethods(String strField1) { + this.strField1 = strField1; + } + + public DummyTransformWithConstructorAndBuilderMethods withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + public DummyTransformWithConstructorAndBuilderMethods withIntField1(int intField1) { + this.intField1 = 
intField1; + return this; + } + } + + public static class DummyTransformWithConstructorMethod extends DummyTransform { + + public static DummyTransformWithConstructorMethod from(String strField1) { + DummyTransformWithConstructorMethod transform = new DummyTransformWithConstructorMethod(); + transform.strField1 = strField1; + return transform; + } + } + + public static class DummyTransformWithConstructorMethodAndBuilderMethods extends DummyTransform { + + public static DummyTransformWithConstructorMethodAndBuilderMethods from(String strField1) { + DummyTransformWithConstructorMethodAndBuilderMethods transform = + new DummyTransformWithConstructorMethodAndBuilderMethods(); + transform.strField1 = strField1; + return transform; + } + + public DummyTransformWithConstructorMethodAndBuilderMethods withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + public DummyTransformWithConstructorMethodAndBuilderMethods withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + + public static class DummyTransformWithMultiLanguageAnnotations extends DummyTransform { + + @MultiLanguageConstructorMethod(name = "create_transform") + public static DummyTransformWithMultiLanguageAnnotations from(String strField1) { + DummyTransformWithMultiLanguageAnnotations transform = + new DummyTransformWithMultiLanguageAnnotations(); + transform.strField1 = strField1; + return transform; + } + + @MultiLanguageBuilderMethod(name = "abc") + public DummyTransformWithMultiLanguageAnnotations withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + @MultiLanguageBuilderMethod(name = "xyz") + public DummyTransformWithMultiLanguageAnnotations withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + + void testClassLookupExpansionRequestConstruction( + ExternalTransforms.JavaClassLookupPayload payloaad) { + Pipeline p = Pipeline.create(); + + RunnerApi.Pipeline pipelineProto = 
PipelineTranslation.toProto(p); + + ExpansionApi.ExpansionRequest request = + ExpansionApi.ExpansionRequest.newBuilder() + .setComponents(pipelineProto.getComponents()) + .setTransform( + RunnerApi.PTransform.newBuilder() + .setUniqueName(TEST_NAME) + .setSpec( + RunnerApi.FunctionSpec.newBuilder() + .setUrn(getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP)) + .setPayload(payloaad.toByteString()))) + .setNamespace(TEST_NAMESPACE) + .build(); + ExpansionApi.ExpansionResponse response = expansionService.expand(request); + RunnerApi.PTransform expandedTransform = response.getTransform(); + assertEquals(TEST_NAMESPACE + TEST_NAME, expandedTransform.getUniqueName()); + assertThat(expandedTransform.getInputsCount(), Matchers.is(0)); + assertThat(expandedTransform.getOutputsCount(), Matchers.is(1)); + assertEquals(2, expandedTransform.getSubtransformsCount()); + assertEquals(2, expandedTransform.getSubtransformsCount()); + assertThat( + expandedTransform.getSubtransforms(0), + anyOf(containsString("MyCreateTransform"), containsString("MyParDoTransform"))); + assertThat( + expandedTransform.getSubtransforms(1), + anyOf(containsString("MyCreateTransform"), containsString("MyParDoTransform"))); + } + + @Test + public void testJavaClassLookupWithConstructor() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor"); + + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); + } + + @Test + public void testJavaClassLookupWithConstructorMethod() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + 
ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethod"); + + payloadBuilder.setConstructorMethod("from"); + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); + } + + @Test + public void testJavaClassLookupWithConstructorAndBuilderMethods() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods"); + + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withStrField2"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING))) + .withFieldValue("strField2", "test_str_2") + .build(), + "strField2")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withIntField1"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32))) + .withFieldValue("intField1", 10) + .build(), + "intField1")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + 
testClassLookupExpansionRequestConstruction(payloadBuilder.build()); + } + + @Test + public void testJavaClassLookupWithConstructorMethodAndBuilderMethods() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods"); + payloadBuilder.setConstructorMethod("from"); + + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withStrField2"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING))) + .withFieldValue("strField2", "test_str_2") + .build(), + "strField2")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withIntField1"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32))) + .withFieldValue("intField1", 10) + .build(), + "intField1")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); + } + + @Test + public void testJavaClassLookupWithSimplifiedBuilderMethodNames() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods"); + 
payloadBuilder.setConstructorMethod("from"); + + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("strField2"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING))) + .withFieldValue("strField2", "test_str_2") + .build(), + "strField2")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("intField1"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32))) + .withFieldValue("intField1", 10) + .build(), + "intField1")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); Review comment: Check that strField1, strField2 and intField1 were set. ########## File path: sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java ########## @@ -0,0 +1,466 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.expansion.service; + +import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.net.URL; +import org.apache.beam.model.expansion.v1.ExpansionApi; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.core.construction.PipelineTranslation; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.Resources; +import org.hamcrest.Matchers; +import org.junit.BeforeClass; +import 
org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link JavaCLassLookupTransformProvider}. */ +@RunWith(JUnit4.class) +public class JavaCLassLookupTransformProviderTest { + + private static final String TEST_URN = "test:beam:transforms:count"; + + private static final String TEST_NAME = "TestName"; + + private static final String TEST_NAMESPACE = "namespace"; + + private static ExpansionService expansionService; + + @BeforeClass + public static void setupExpansionService() { + PipelineOptionsFactory.register(ExpansionServiceOptions.class); + URL allowListFile = Resources.getResource("./test_allowlist.yaml"); + System.out.println("Exists: " + new File(allowListFile.getPath()).exists()); + expansionService = + new ExpansionService( + new String[] {"--javaClassLookupAllowlistFile=" + allowListFile.getPath()}); + } + + public static class DummyTransform extends PTransform<PBegin, PCollection<String>> { + + String strField1; + String strField2; + int intField1; + + @Override + public PCollection<String> expand(PBegin input) { + return input + .apply("MyCreateTransform", Create.of("aaa", "bbb", "ccc")) + .apply( + "MyParDoTransform", + ParDo.of( + new DoFn<String, String>() { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element() + strField1); + } + })); + } + } + + public static class DummyTransformWithConstructor extends DummyTransform { + + public DummyTransformWithConstructor(String strField1) { + this.strField1 = strField1; + } + } + + public static class DummyTransformWithConstructorAndBuilderMethods extends DummyTransform { + + public DummyTransformWithConstructorAndBuilderMethods(String strField1) { + this.strField1 = strField1; + } + + public DummyTransformWithConstructorAndBuilderMethods withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + public DummyTransformWithConstructorAndBuilderMethods withIntField1(int intField1) { + this.intField1 = 
intField1; + return this; + } + } + + public static class DummyTransformWithConstructorMethod extends DummyTransform { + + public static DummyTransformWithConstructorMethod from(String strField1) { + DummyTransformWithConstructorMethod transform = new DummyTransformWithConstructorMethod(); + transform.strField1 = strField1; + return transform; + } + } + + public static class DummyTransformWithConstructorMethodAndBuilderMethods extends DummyTransform { + + public static DummyTransformWithConstructorMethodAndBuilderMethods from(String strField1) { + DummyTransformWithConstructorMethodAndBuilderMethods transform = + new DummyTransformWithConstructorMethodAndBuilderMethods(); + transform.strField1 = strField1; + return transform; + } + + public DummyTransformWithConstructorMethodAndBuilderMethods withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + public DummyTransformWithConstructorMethodAndBuilderMethods withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + + public static class DummyTransformWithMultiLanguageAnnotations extends DummyTransform { + + @MultiLanguageConstructorMethod(name = "create_transform") + public static DummyTransformWithMultiLanguageAnnotations from(String strField1) { + DummyTransformWithMultiLanguageAnnotations transform = + new DummyTransformWithMultiLanguageAnnotations(); + transform.strField1 = strField1; + return transform; + } + + @MultiLanguageBuilderMethod(name = "abc") + public DummyTransformWithMultiLanguageAnnotations withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + @MultiLanguageBuilderMethod(name = "xyz") + public DummyTransformWithMultiLanguageAnnotations withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + + void testClassLookupExpansionRequestConstruction( + ExternalTransforms.JavaClassLookupPayload payloaad) { + Pipeline p = Pipeline.create(); + + RunnerApi.Pipeline pipelineProto = 
PipelineTranslation.toProto(p); + + ExpansionApi.ExpansionRequest request = + ExpansionApi.ExpansionRequest.newBuilder() + .setComponents(pipelineProto.getComponents()) + .setTransform( + RunnerApi.PTransform.newBuilder() + .setUniqueName(TEST_NAME) + .setSpec( + RunnerApi.FunctionSpec.newBuilder() + .setUrn(getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP)) + .setPayload(payloaad.toByteString()))) + .setNamespace(TEST_NAMESPACE) + .build(); + ExpansionApi.ExpansionResponse response = expansionService.expand(request); + RunnerApi.PTransform expandedTransform = response.getTransform(); + assertEquals(TEST_NAMESPACE + TEST_NAME, expandedTransform.getUniqueName()); + assertThat(expandedTransform.getInputsCount(), Matchers.is(0)); + assertThat(expandedTransform.getOutputsCount(), Matchers.is(1)); + assertEquals(2, expandedTransform.getSubtransformsCount()); + assertEquals(2, expandedTransform.getSubtransformsCount()); + assertThat( + expandedTransform.getSubtransforms(0), + anyOf(containsString("MyCreateTransform"), containsString("MyParDoTransform"))); + assertThat( + expandedTransform.getSubtransforms(1), + anyOf(containsString("MyCreateTransform"), containsString("MyParDoTransform"))); + } + + @Test + public void testJavaClassLookupWithConstructor() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor"); + + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); + } + + @Test + public void testJavaClassLookupWithConstructorMethod() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + 
ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethod"); + + payloadBuilder.setConstructorMethod("from"); + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); + } + + @Test + public void testJavaClassLookupWithConstructorAndBuilderMethods() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods"); + + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withStrField2"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING))) + .withFieldValue("strField2", "test_str_2") + .build(), + "strField2")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withIntField1"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32))) + .withFieldValue("intField1", 10) + .build(), + "intField1")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + 
testClassLookupExpansionRequestConstruction(payloadBuilder.build()); + } + + @Test + public void testJavaClassLookupWithConstructorMethodAndBuilderMethods() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods"); + payloadBuilder.setConstructorMethod("from"); + + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withStrField2"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING))) + .withFieldValue("strField2", "test_str_2") + .build(), + "strField2")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withIntField1"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32))) + .withFieldValue("intField1", 10) + .build(), + "intField1")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); Review comment: Check that strField1, strField2 and intField1 were set. ########## File path: sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProvider.java ########## @@ -0,0 +1,449 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.expansion.service; + +import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.auto.value.AutoValue; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.lang.annotation.Annotation; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.JavaClassLookupPayload; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.Parameter; +import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.ClassUtils; +import org.apache.beam.sdk.expansion.service.ExpansionService.ExternalTransformRegistrarLoader; +import org.apache.beam.sdk.expansion.service.ExpansionService.TransformProvider; +import 
org.apache.beam.sdk.schemas.JavaFieldSchema; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaRegistry; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.common.ReflectHelpers; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.InvalidProtocolBufferException; +import org.checkerframework.checker.nullness.qual.NonNull; + +/** + * A transform provider that can be used to directly instantiate a transform using Java class name + * and builder methods. + * + * @param <InputT> input {@link PInput} type of the transform + * @param <OutputT> output {@link POutput} type of the transform + */ +@SuppressWarnings({"argument.type.incompatible", "assignment.type.incompatible"}) +@SuppressFBWarnings("UWF_UNWRITTEN_PUBLIC_OR_PROTECTED_FIELD") +class JavaClassLookupTransformProvider<InputT extends PInput, OutputT extends POutput> + implements TransformProvider<PInput, POutput> { + + private static final SchemaRegistry SCHEMA_REGISTRY = SchemaRegistry.createDefault(); + AllowList allowList; + public static final String ALLOW_LIST_VERSION = "v1"; + + public JavaClassLookupTransformProvider(AllowList allowList) { + if (!allowList.getVersion().equals(ALLOW_LIST_VERSION)) { + throw new IllegalArgumentException("Unknown allow-list version"); + } + this.allowList = allowList; + } + + @Override + public PTransform<PInput, POutput> getTransform(FunctionSpec spec) { + JavaClassLookupPayload payload = null; + try { + payload = JavaClassLookupPayload.parseFrom(spec.getPayload()); + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException( + "Invalid payload type for URN " + getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP), e); + } + + String 
className = payload.getClassName(); + try { + AllowedClass allowlistClass = null; + if (this.allowList != null) { + for (AllowedClass cls : this.allowList.getAllowedClasses()) { + if (cls.getClassName().equals(className)) { + if (allowlistClass != null) { + throw new IllegalArgumentException( + "Found two matching allowlist classes " + allowlistClass + " and " + cls); + } + allowlistClass = cls; + } + } + } + if (allowlistClass == null) { + throw new UnsupportedOperationException( + "Expanding a transform class by the name " + className + " is not allowed."); + } + Class<PTransform<InputT, OutputT>> transformClass = + (Class<PTransform<InputT, OutputT>>) + ReflectHelpers.findClassLoader().loadClass(className); + PTransform<PInput, POutput> transform; + if (payload.getConstructorMethod().isEmpty()) { + Constructor<?>[] constructors = transformClass.getConstructors(); + Constructor<PTransform<InputT, OutputT>> constructor = + findMappingConstructor(constructors, payload); + Object[] parameterValues = + getParameterValues( + constructor.getParameters(), + payload.getConstructorParametersList().toArray(new Parameter[0])); + transform = (PTransform<PInput, POutput>) constructor.newInstance(parameterValues); + } else { + Method[] methods = transformClass.getMethods(); + Method method = findMappingConstructorMethod(methods, payload, allowlistClass); + Object[] parameterValues = + getParameterValues( + method.getParameters(), + payload.getConstructorParametersList().toArray(new Parameter[0])); + transform = (PTransform<PInput, POutput>) method.invoke(null /* static */, parameterValues); + } + return applyBuilderMethods(transform, payload, allowlistClass); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException("Could not find class " + className, e); + } catch (InstantiationException + | IllegalArgumentException + | IllegalAccessException + | InvocationTargetException e) { + throw new IllegalArgumentException("Could not instantiate class " + className, 
e); + } + } + + private PTransform<PInput, POutput> applyBuilderMethods( + PTransform<PInput, POutput> transform, + JavaClassLookupPayload payload, + AllowedClass allowListClass) { + for (BuilderMethod builderMethod : payload.getBuilderMethodsList()) { + Method method = getMethod(transform, builderMethod, allowListClass); + try { + transform = + (PTransform<PInput, POutput>) + method.invoke( + transform, + getParameterValues( + method.getParameters(), + builderMethod.getParameterList().toArray(new Parameter[0]))); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new IllegalArgumentException( + "Could not invoke the builder method " + + builderMethod + + " on transform " + + transform + + " with parameters " + + builderMethod.getParameterList(), + e); + } + } + + return transform; + } + + private boolean isBuilderMethodForName( + Method method, String nameFromPayload, AllowedClass allowListClass) { + // Lookup based on method annotations + for (Annotation annotation : method.getAnnotations()) { + if (annotation instanceof MultiLanguageBuilderMethod) { + if (nameFromPayload.equals(((MultiLanguageBuilderMethod) annotation).name())) { + if (allowListClass.getAllowedBuilderMethods().contains(nameFromPayload)) { + return true; + } else { + throw new RuntimeException( + "Builder method " + nameFromPayload + " has to be explicitly allowed"); + } + } + } + } + + // Lookup based on the method name. + boolean match = method.getName().equals(nameFromPayload); + String consideredMethodName = method.getName(); + + // We provide a simplification for common Java builder pattern naming convention where builder + // methods start with "with". In this case, for a builder method name in the form "withXyz", + // users may just use "xyz". If additional updates to the method name are needed the transform + // has to be updated by adding annotations. 
+ if (!match && consideredMethodName.length() > 4 && consideredMethodName.startsWith("with")) { + consideredMethodName = + consideredMethodName.substring(4, 5).toLowerCase() + consideredMethodName.substring(5); + match = consideredMethodName.equals(nameFromPayload); + } + if (match && !allowListClass.getAllowedBuilderMethods().contains(consideredMethodName)) { + throw new RuntimeException( + "Builder method name " + consideredMethodName + " has to be explicitly allowed"); + } + return match; + } + + private Method getMethod( + PTransform<PInput, POutput> transform, + BuilderMethod builderMethod, + AllowedClass allowListClass) { + List<Method> matchingMethods = + Arrays.stream(transform.getClass().getMethods()) + .filter(m -> isBuilderMethodForName(m, builderMethod.getName(), allowListClass)) + .filter( + m -> + parametersCompatible( + m.getParameters(), + builderMethod.getParameterList().toArray(new Parameter[0]))) + .filter(m -> PTransform.class.isAssignableFrom(m.getReturnType())) + .collect(Collectors.toList()); + + if (matchingMethods.size() != 1) { + throw new RuntimeException( + "Expected to find exact one matching method in transform " + + transform + + " for BuilderMethod" + + builderMethod + + " but found " + + matchingMethods.size()); + } + return matchingMethods.get(0); + } + + private static boolean isPrimitiveOrWrapperOrString(java.lang.Class<?> type) { + return ClassUtils.isPrimitiveOrWrapper(type) || type == String.class; + } + + private boolean parametersCompatible( + java.lang.reflect.Parameter[] methodParameters, Parameter[] payloadParameters) { + if (methodParameters.length != payloadParameters.length) { + return false; + } + + for (int i = 0; i < methodParameters.length; i++) { + java.lang.reflect.Parameter parameterFromReflection = methodParameters[i]; + Parameter parameterFromPayload = payloadParameters[i]; + + String paramNameFromReflection = parameterFromReflection.getName(); + if (!paramNameFromReflection.startsWith("arg") + && 
!paramNameFromReflection.equals(parameterFromPayload.getName())) { + // Parameter name through reflection is from the class file (not through synthesizing, + // hence we can validate names) + return false; + } + + Class<PTransform<InputT, OutputT>> parameterClass = + (Class<PTransform<InputT, OutputT>>) parameterFromReflection.getType(); + Row parameterRow = + ExternalTransformRegistrarLoader.decodeRow( + parameterFromPayload.getSchema(), parameterFromPayload.getPayload()); + + Schema parameterSchema = null; + if (isPrimitiveOrWrapperOrString(parameterClass)) { + if (parameterRow.getFieldCount() != 1) { + throw new RuntimeException( + "Expected a row for a single primitive field but received " + parameterRow); + } + // We get the value just for validation here. + getPrimitiveValueFromRow(parameterRow); + } else { + try { + parameterSchema = SCHEMA_REGISTRY.getSchema(parameterClass); + } catch (NoSuchSchemaException e) { + + SCHEMA_REGISTRY.registerSchemaProvider(parameterClass, new JavaFieldSchema()); + try { + parameterSchema = SCHEMA_REGISTRY.getSchema(parameterClass); + } catch (NoSuchSchemaException e1) { + throw new RuntimeException(e1); + } + if (parameterSchema != null && parameterSchema.getFieldCount() == 0) { + throw new RuntimeException( + "Could not determine a valid schema for parameter class " + parameterClass); + } + } + } + + if (parameterSchema != null && !parameterRow.getSchema().assignableTo(parameterSchema)) { + return false; + } + } + return true; + } + + private Object[] getParameterValues( + java.lang.reflect.Parameter[] parameters, Parameter[] payloadParameters) { + ArrayList<Object> parameterValues = new ArrayList<>(); + int i = 0; + for (java.lang.reflect.Parameter parameter : parameters) { + Parameter parameterConfig = payloadParameters[i]; + Class<?> parameterClass = parameter.getType(); + + Row parameterRow = + ExternalTransformRegistrarLoader.decodeRow( + parameterConfig.getSchema(), parameterConfig.getPayload()); + + Object 
parameterValue = null; + if (isPrimitiveOrWrapperOrString(parameterClass)) { + parameterValue = getPrimitiveValueFromRow(parameterRow); + } else { + SerializableFunction<Row, ?> fromRowFunc = null; + // SCHEMA_REGISTRY. + try { + fromRowFunc = SCHEMA_REGISTRY.getFromRowFunction(parameterClass); + } catch (NoSuchSchemaException e) { + throw new IllegalArgumentException( + "Could not determine the row function for class " + parameterClass, e); + } + parameterValue = fromRowFunc.apply(parameterRow); + } + parameterValues.add(parameterValue); + i++; Review comment: nit: Use a classic for loop; it's simpler to understand for what you're trying to do ```suggestion for (int i = 0; i < parameters.length; ++i) { java.lang.reflect.Parameter parameter = parameters[i]; Parameter parameterConfig = payloadParameters[i]; Class<?> parameterClass = parameter.getType(); Row parameterRow = ExternalTransformRegistrarLoader.decodeRow( parameterConfig.getSchema(), parameterConfig.getPayload()); Object parameterValue = null; if (isPrimitiveOrWrapperOrString(parameterClass)) { parameterValue = getPrimitiveValueFromRow(parameterRow); } else { SerializableFunction<Row, ?> fromRowFunc = null; // SCHEMA_REGISTRY. try { fromRowFunc = SCHEMA_REGISTRY.getFromRowFunction(parameterClass); } catch (NoSuchSchemaException e) { throw new IllegalArgumentException( "Could not determine the row function for class " + parameterClass, e); } parameterValue = fromRowFunc.apply(parameterRow); } parameterValues.add(parameterValue); ``` ########## File path: sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java ########## @@ -0,0 +1,466 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.expansion.service; + +import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.net.URL; +import org.apache.beam.model.expansion.v1.ExpansionApi; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.core.construction.PipelineTranslation; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import 
org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.Resources; +import org.hamcrest.Matchers; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link JavaCLassLookupTransformProvider}. */ +@RunWith(JUnit4.class) +public class JavaCLassLookupTransformProviderTest { + + private static final String TEST_URN = "test:beam:transforms:count"; + + private static final String TEST_NAME = "TestName"; + + private static final String TEST_NAMESPACE = "namespace"; + + private static ExpansionService expansionService; + + @BeforeClass + public static void setupExpansionService() { + PipelineOptionsFactory.register(ExpansionServiceOptions.class); + URL allowListFile = Resources.getResource("./test_allowlist.yaml"); + System.out.println("Exists: " + new File(allowListFile.getPath()).exists()); + expansionService = + new ExpansionService( + new String[] {"--javaClassLookupAllowlistFile=" + allowListFile.getPath()}); + } + + public static class DummyTransform extends PTransform<PBegin, PCollection<String>> { + + String strField1; + String strField2; + int intField1; + + @Override + public PCollection<String> expand(PBegin input) { + return input + .apply("MyCreateTransform", Create.of("aaa", "bbb", "ccc")) + .apply( + "MyParDoTransform", + ParDo.of( + new DoFn<String, String>() { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element() + strField1); + } + })); + } + } + + public static class DummyTransformWithConstructor extends DummyTransform { + + public DummyTransformWithConstructor(String strField1) { + this.strField1 = strField1; + } + } + + public static class DummyTransformWithConstructorAndBuilderMethods extends DummyTransform { + + public DummyTransformWithConstructorAndBuilderMethods(String strField1) { + this.strField1 = strField1; + } + + public DummyTransformWithConstructorAndBuilderMethods 
withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + public DummyTransformWithConstructorAndBuilderMethods withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + + public static class DummyTransformWithConstructorMethod extends DummyTransform { + + public static DummyTransformWithConstructorMethod from(String strField1) { + DummyTransformWithConstructorMethod transform = new DummyTransformWithConstructorMethod(); + transform.strField1 = strField1; + return transform; + } + } + + public static class DummyTransformWithConstructorMethodAndBuilderMethods extends DummyTransform { + + public static DummyTransformWithConstructorMethodAndBuilderMethods from(String strField1) { + DummyTransformWithConstructorMethodAndBuilderMethods transform = + new DummyTransformWithConstructorMethodAndBuilderMethods(); + transform.strField1 = strField1; + return transform; + } + + public DummyTransformWithConstructorMethodAndBuilderMethods withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + public DummyTransformWithConstructorMethodAndBuilderMethods withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + + public static class DummyTransformWithMultiLanguageAnnotations extends DummyTransform { + + @MultiLanguageConstructorMethod(name = "create_transform") + public static DummyTransformWithMultiLanguageAnnotations from(String strField1) { + DummyTransformWithMultiLanguageAnnotations transform = + new DummyTransformWithMultiLanguageAnnotations(); + transform.strField1 = strField1; + return transform; + } + + @MultiLanguageBuilderMethod(name = "abc") + public DummyTransformWithMultiLanguageAnnotations withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + @MultiLanguageBuilderMethod(name = "xyz") + public DummyTransformWithMultiLanguageAnnotations withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + + 
void testClassLookupExpansionRequestConstruction( + ExternalTransforms.JavaClassLookupPayload payloaad) { + Pipeline p = Pipeline.create(); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + + ExpansionApi.ExpansionRequest request = + ExpansionApi.ExpansionRequest.newBuilder() + .setComponents(pipelineProto.getComponents()) + .setTransform( + RunnerApi.PTransform.newBuilder() + .setUniqueName(TEST_NAME) + .setSpec( + RunnerApi.FunctionSpec.newBuilder() + .setUrn(getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP)) + .setPayload(payloaad.toByteString()))) + .setNamespace(TEST_NAMESPACE) + .build(); + ExpansionApi.ExpansionResponse response = expansionService.expand(request); + RunnerApi.PTransform expandedTransform = response.getTransform(); + assertEquals(TEST_NAMESPACE + TEST_NAME, expandedTransform.getUniqueName()); + assertThat(expandedTransform.getInputsCount(), Matchers.is(0)); + assertThat(expandedTransform.getOutputsCount(), Matchers.is(1)); + assertEquals(2, expandedTransform.getSubtransformsCount()); + assertEquals(2, expandedTransform.getSubtransformsCount()); + assertThat( + expandedTransform.getSubtransforms(0), + anyOf(containsString("MyCreateTransform"), containsString("MyParDoTransform"))); + assertThat( + expandedTransform.getSubtransforms(1), + anyOf(containsString("MyCreateTransform"), containsString("MyParDoTransform"))); + } + + @Test + public void testJavaClassLookupWithConstructor() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor"); + + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + 
testClassLookupExpansionRequestConstruction(payloadBuilder.build()); + } + + @Test + public void testJavaClassLookupWithConstructorMethod() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethod"); + + payloadBuilder.setConstructorMethod("from"); + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); + } + + @Test + public void testJavaClassLookupWithConstructorAndBuilderMethods() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods"); + + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withStrField2"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING))) + .withFieldValue("strField2", "test_str_2") + .build(), + "strField2")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withIntField1"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + 
Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32))) + .withFieldValue("intField1", 10) + .build(), + "intField1")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); Review comment: Check that strField1, strField2 and intField1 were set. ########## File path: sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java ########## @@ -0,0 +1,466 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.beam.sdk.expansion.service; + +import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.net.URL; +import org.apache.beam.model.expansion.v1.ExpansionApi; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.core.construction.PipelineTranslation; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.Resources; +import org.hamcrest.Matchers; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link JavaCLassLookupTransformProvider}. 
*/ +@RunWith(JUnit4.class) +public class JavaCLassLookupTransformProviderTest { + + private static final String TEST_URN = "test:beam:transforms:count"; + + private static final String TEST_NAME = "TestName"; + + private static final String TEST_NAMESPACE = "namespace"; + + private static ExpansionService expansionService; + + @BeforeClass + public static void setupExpansionService() { + PipelineOptionsFactory.register(ExpansionServiceOptions.class); + URL allowListFile = Resources.getResource("./test_allowlist.yaml"); + System.out.println("Exists: " + new File(allowListFile.getPath()).exists()); + expansionService = + new ExpansionService( + new String[] {"--javaClassLookupAllowlistFile=" + allowListFile.getPath()}); + } + + public static class DummyTransform extends PTransform<PBegin, PCollection<String>> { + + String strField1; + String strField2; + int intField1; + + @Override + public PCollection<String> expand(PBegin input) { + return input + .apply("MyCreateTransform", Create.of("aaa", "bbb", "ccc")) + .apply( + "MyParDoTransform", + ParDo.of( + new DoFn<String, String>() { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element() + strField1); + } + })); + } + } + + public static class DummyTransformWithConstructor extends DummyTransform { + + public DummyTransformWithConstructor(String strField1) { + this.strField1 = strField1; + } + } + + public static class DummyTransformWithConstructorAndBuilderMethods extends DummyTransform { + + public DummyTransformWithConstructorAndBuilderMethods(String strField1) { + this.strField1 = strField1; + } + + public DummyTransformWithConstructorAndBuilderMethods withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + public DummyTransformWithConstructorAndBuilderMethods withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + + public static class DummyTransformWithConstructorMethod extends DummyTransform { + + public static 
DummyTransformWithConstructorMethod from(String strField1) { + DummyTransformWithConstructorMethod transform = new DummyTransformWithConstructorMethod(); + transform.strField1 = strField1; + return transform; + } + } + + public static class DummyTransformWithConstructorMethodAndBuilderMethods extends DummyTransform { + + public static DummyTransformWithConstructorMethodAndBuilderMethods from(String strField1) { + DummyTransformWithConstructorMethodAndBuilderMethods transform = + new DummyTransformWithConstructorMethodAndBuilderMethods(); + transform.strField1 = strField1; + return transform; + } + + public DummyTransformWithConstructorMethodAndBuilderMethods withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + public DummyTransformWithConstructorMethodAndBuilderMethods withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + + public static class DummyTransformWithMultiLanguageAnnotations extends DummyTransform { + + @MultiLanguageConstructorMethod(name = "create_transform") + public static DummyTransformWithMultiLanguageAnnotations from(String strField1) { + DummyTransformWithMultiLanguageAnnotations transform = + new DummyTransformWithMultiLanguageAnnotations(); + transform.strField1 = strField1; + return transform; + } + + @MultiLanguageBuilderMethod(name = "abc") + public DummyTransformWithMultiLanguageAnnotations withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + @MultiLanguageBuilderMethod(name = "xyz") + public DummyTransformWithMultiLanguageAnnotations withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + + void testClassLookupExpansionRequestConstruction( + ExternalTransforms.JavaClassLookupPayload payloaad) { + Pipeline p = Pipeline.create(); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + + ExpansionApi.ExpansionRequest request = + ExpansionApi.ExpansionRequest.newBuilder() + 
.setComponents(pipelineProto.getComponents()) + .setTransform( + RunnerApi.PTransform.newBuilder() + .setUniqueName(TEST_NAME) + .setSpec( + RunnerApi.FunctionSpec.newBuilder() + .setUrn(getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP)) + .setPayload(payloaad.toByteString()))) + .setNamespace(TEST_NAMESPACE) + .build(); + ExpansionApi.ExpansionResponse response = expansionService.expand(request); + RunnerApi.PTransform expandedTransform = response.getTransform(); + assertEquals(TEST_NAMESPACE + TEST_NAME, expandedTransform.getUniqueName()); + assertThat(expandedTransform.getInputsCount(), Matchers.is(0)); + assertThat(expandedTransform.getOutputsCount(), Matchers.is(1)); + assertEquals(2, expandedTransform.getSubtransformsCount()); + assertEquals(2, expandedTransform.getSubtransformsCount()); + assertThat( + expandedTransform.getSubtransforms(0), + anyOf(containsString("MyCreateTransform"), containsString("MyParDoTransform"))); + assertThat( + expandedTransform.getSubtransforms(1), + anyOf(containsString("MyCreateTransform"), containsString("MyParDoTransform"))); + } + + @Test + public void testJavaClassLookupWithConstructor() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor"); + + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); + } + + @Test + public void testJavaClassLookupWithConstructorMethod() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + 
"org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethod"); + + payloadBuilder.setConstructorMethod("from"); + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); + } + + @Test + public void testJavaClassLookupWithConstructorAndBuilderMethods() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods"); + + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withStrField2"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING))) + .withFieldValue("strField2", "test_str_2") + .build(), + "strField2")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withIntField1"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32))) + .withFieldValue("intField1", 10) + .build(), + "intField1")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); + } + + @Test + public void 
testJavaClassLookupWithConstructorMethodAndBuilderMethods() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods"); + payloadBuilder.setConstructorMethod("from"); + + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withStrField2"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING))) + .withFieldValue("strField2", "test_str_2") + .build(), + "strField2")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withIntField1"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32))) + .withFieldValue("intField1", 10) + .build(), + "intField1")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); + } + + @Test + public void testJavaClassLookupWithSimplifiedBuilderMethodNames() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods"); + payloadBuilder.setConstructorMethod("from"); + + payloadBuilder.addConstructorParameters( + 
ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("strField2"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING))) + .withFieldValue("strField2", "test_str_2") + .build(), + "strField2")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("intField1"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32))) + .withFieldValue("intField1", 10) + .build(), + "intField1")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); + } + + @Test + public void testJavaClassLookupWithAnnotations() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiLanguageAnnotations"); + payloadBuilder.setConstructorMethod("create_transform"); + + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("abc"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING))) + .withFieldValue("strField2", "test_str_2") 
+ .build(), + "strField2")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("xyz"); + builderMethodBuilder.addParameter( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32))) + .withFieldValue("intField1", 10) + .build(), + "intField1")); + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); Review comment: Check that strField1, abc and xyz was set. ########## File path: sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java ########## @@ -0,0 +1,466 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.expansion.service; + +import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.net.URL; +import org.apache.beam.model.expansion.v1.ExpansionApi; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.runners.core.construction.PipelineTranslation; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.Resources; +import org.hamcrest.Matchers; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link JavaCLassLookupTransformProvider}. 
*/ +@RunWith(JUnit4.class) +public class JavaCLassLookupTransformProviderTest { + + private static final String TEST_URN = "test:beam:transforms:count"; + + private static final String TEST_NAME = "TestName"; + + private static final String TEST_NAMESPACE = "namespace"; + + private static ExpansionService expansionService; + + @BeforeClass + public static void setupExpansionService() { + PipelineOptionsFactory.register(ExpansionServiceOptions.class); + URL allowListFile = Resources.getResource("./test_allowlist.yaml"); + System.out.println("Exists: " + new File(allowListFile.getPath()).exists()); + expansionService = + new ExpansionService( + new String[] {"--javaClassLookupAllowlistFile=" + allowListFile.getPath()}); + } + + public static class DummyTransform extends PTransform<PBegin, PCollection<String>> { + + String strField1; + String strField2; + int intField1; + + @Override + public PCollection<String> expand(PBegin input) { + return input + .apply("MyCreateTransform", Create.of("aaa", "bbb", "ccc")) + .apply( + "MyParDoTransform", + ParDo.of( + new DoFn<String, String>() { + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element() + strField1); + } + })); + } + } + + public static class DummyTransformWithConstructor extends DummyTransform { + + public DummyTransformWithConstructor(String strField1) { + this.strField1 = strField1; + } + } + + public static class DummyTransformWithConstructorAndBuilderMethods extends DummyTransform { + + public DummyTransformWithConstructorAndBuilderMethods(String strField1) { + this.strField1 = strField1; + } + + public DummyTransformWithConstructorAndBuilderMethods withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + public DummyTransformWithConstructorAndBuilderMethods withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + + public static class DummyTransformWithConstructorMethod extends DummyTransform { + + public static 
DummyTransformWithConstructorMethod from(String strField1) { + DummyTransformWithConstructorMethod transform = new DummyTransformWithConstructorMethod(); + transform.strField1 = strField1; + return transform; + } + } + + public static class DummyTransformWithConstructorMethodAndBuilderMethods extends DummyTransform { + + public static DummyTransformWithConstructorMethodAndBuilderMethods from(String strField1) { + DummyTransformWithConstructorMethodAndBuilderMethods transform = + new DummyTransformWithConstructorMethodAndBuilderMethods(); + transform.strField1 = strField1; + return transform; + } + + public DummyTransformWithConstructorMethodAndBuilderMethods withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + public DummyTransformWithConstructorMethodAndBuilderMethods withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + + public static class DummyTransformWithMultiLanguageAnnotations extends DummyTransform { + + @MultiLanguageConstructorMethod(name = "create_transform") + public static DummyTransformWithMultiLanguageAnnotations from(String strField1) { + DummyTransformWithMultiLanguageAnnotations transform = + new DummyTransformWithMultiLanguageAnnotations(); + transform.strField1 = strField1; + return transform; + } + + @MultiLanguageBuilderMethod(name = "abc") + public DummyTransformWithMultiLanguageAnnotations withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + @MultiLanguageBuilderMethod(name = "xyz") + public DummyTransformWithMultiLanguageAnnotations withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + + void testClassLookupExpansionRequestConstruction( + ExternalTransforms.JavaClassLookupPayload payloaad) { + Pipeline p = Pipeline.create(); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + + ExpansionApi.ExpansionRequest request = + ExpansionApi.ExpansionRequest.newBuilder() + 
.setComponents(pipelineProto.getComponents()) + .setTransform( + RunnerApi.PTransform.newBuilder() + .setUniqueName(TEST_NAME) + .setSpec( + RunnerApi.FunctionSpec.newBuilder() + .setUrn(getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP)) + .setPayload(payloaad.toByteString()))) + .setNamespace(TEST_NAMESPACE) + .build(); + ExpansionApi.ExpansionResponse response = expansionService.expand(request); + RunnerApi.PTransform expandedTransform = response.getTransform(); + assertEquals(TEST_NAMESPACE + TEST_NAME, expandedTransform.getUniqueName()); + assertThat(expandedTransform.getInputsCount(), Matchers.is(0)); + assertThat(expandedTransform.getOutputsCount(), Matchers.is(1)); + assertEquals(2, expandedTransform.getSubtransformsCount()); + assertEquals(2, expandedTransform.getSubtransformsCount()); + assertThat( + expandedTransform.getSubtransforms(0), + anyOf(containsString("MyCreateTransform"), containsString("MyParDoTransform"))); + assertThat( + expandedTransform.getSubtransforms(1), + anyOf(containsString("MyCreateTransform"), containsString("MyParDoTransform"))); + } + + @Test + public void testJavaClassLookupWithConstructor() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor"); + + payloadBuilder.addConstructorParameters( + ExpansionServiceTest.encodeRowIntoParameter( + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(), + "strField1")); + + testClassLookupExpansionRequestConstruction(payloadBuilder.build()); Review comment: Check that strField1 was set? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
