Generalize extraction of DoFn parameters from context
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/40ff9d40 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/40ff9d40 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/40ff9d40 Branch: refs/heads/master Commit: 40ff9d401f0ba3f85d1bab848d1c6a662b03bc99 Parents: ac252a7 Author: Kenneth Knowles <[email protected]> Authored: Thu Nov 3 18:42:25 2016 -0700 Committer: Kenneth Knowles <[email protected]> Committed: Mon Nov 7 15:25:03 2016 -0800 ---------------------------------------------------------------------- .../sdk/transforms/reflect/DoFnInvokers.java | 115 ++++++++++++------- .../sdk/transforms/reflect/DoFnSignature.java | 11 +- .../sdk/transforms/reflect/DoFnSignatures.java | 15 ++- .../DoFnSignaturesSplittableDoFnTest.java | 5 +- 4 files changed, 94 insertions(+), 52 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/40ff9d40/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java index c5a23dc..ad2b766 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokers.java @@ -69,6 +69,13 @@ import org.apache.beam.sdk.transforms.DoFn.ExtraContextFactory; import org.apache.beam.sdk.transforms.DoFn.ProcessElement; import org.apache.beam.sdk.transforms.DoFnAdapters; import org.apache.beam.sdk.transforms.OldDoFn; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.BoundedWindowParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.Cases; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.InputProviderParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.OutputReceiverParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.RestrictionTrackerParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimerParameter; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.values.TypeDescriptor; @@ -496,40 +503,76 @@ public class DoFnInvokers { } } + private static StackManipulation simpleExtraContextParameter( + String methodName, + StackManipulation pushExtraContextFactory) { + try { + return new StackManipulation.Compound( + pushExtraContextFactory, + MethodInvocation.invoke( + new MethodDescription.ForLoadedMethod( + DoFn.ExtraContextFactory.class.getMethod(methodName)))); + } catch (Exception e) { + throw new IllegalStateException( + String.format( + "Failed to locate required method %s.%s", + ExtraContextFactory.class.getSimpleName(), methodName), + e); + } + } + + private static StackManipulation getExtraContextParameter( + DoFnSignature.Parameter parameter, + final StackManipulation pushExtraContextFactory) { + + return parameter.match(new Cases<StackManipulation>() { + + @Override + public StackManipulation dispatch(BoundedWindowParameter p) { + return simpleExtraContextParameter("window", pushExtraContextFactory); + } + + @Override + public StackManipulation dispatch(InputProviderParameter p) { + return simpleExtraContextParameter("inputProvider", pushExtraContextFactory); + } + + @Override + public StackManipulation dispatch(OutputReceiverParameter p) { + return simpleExtraContextParameter("outputReceiver", pushExtraContextFactory); + } + + @Override + public StackManipulation dispatch(RestrictionTrackerParameter p) { + // ExtraContextFactory.restrictionTracker() returns a RestrictionTracker, + // but the @ProcessElement method expects a concrete subtype of it. + // Insert a downcast. + return new StackManipulation.Compound( + simpleExtraContextParameter("restrictionTracker", pushExtraContextFactory), + TypeCasting.to(new TypeDescription.ForLoadedType(p.trackerT().getRawType()))); + } + + @Override + public StackManipulation dispatch(StateParameter p) { + throw new UnsupportedOperationException("State parameters are not yet supported."); + } + + @Override + public StackManipulation dispatch(TimerParameter p) { + throw new UnsupportedOperationException("Timer parameters are not yet supported."); + } + }); + } + /** * Implements the invoker's {@link DoFnInvoker#invokeProcessElement} method by delegating to the * {@link DoFn.ProcessElement} method. */ private static final class ProcessElementDelegation extends DoFnMethodDelegation { - private static final Map<DoFnSignature.Parameter, MethodDescription> - EXTRA_CONTEXT_FACTORY_METHODS; private static final MethodDescription PROCESS_CONTINUATION_STOP_METHOD; static { try { - Map<DoFnSignature.Parameter, MethodDescription> methods = new HashMap<>(); - methods.put( - DoFnSignature.Parameter.boundedWindow(), - new MethodDescription.ForLoadedMethod( - DoFn.ExtraContextFactory.class.getMethod("window"))); - methods.put( - DoFnSignature.Parameter.inputProvider(), - new MethodDescription.ForLoadedMethod( - DoFn.ExtraContextFactory.class.getMethod("inputProvider"))); - methods.put( - DoFnSignature.Parameter.outputReceiver(), - new MethodDescription.ForLoadedMethod( - DoFn.ExtraContextFactory.class.getMethod("outputReceiver"))); - methods.put( - DoFnSignature.Parameter.restrictionTracker(), - new MethodDescription.ForLoadedMethod( - DoFn.ExtraContextFactory.class.getMethod("restrictionTracker"))); - EXTRA_CONTEXT_FACTORY_METHODS = Collections.unmodifiableMap(methods); - } catch (Exception e) { - throw new RuntimeException( - "Failed to locate an ExtraContextFactory method that was expected to exist", e); - } - try { PROCESS_CONTINUATION_STOP_METHOD = new MethodDescription.ForLoadedMethod(DoFn.ProcessContinuation.class.getMethod("stop")); } catch (NoSuchMethodException e) { @@ -537,10 +580,6 @@ public class DoFnInvokers { } } - private static MethodDescription getExtraContextFactoryMethod(DoFnSignature.Parameter param) { - return EXTRA_CONTEXT_FACTORY_METHODS.get(param); - } - private final DoFnSignature.ProcessElementMethod signature; /** Implementation of {@link MethodDelegation} for the {@link ProcessElement} method. */ @@ -555,25 +594,15 @@ public class DoFnInvokers { // DoFn.ProcessContext, ExtraContextFactory. // Parameters of the wrapped DoFn method: // DoFn.ProcessContext, [BoundedWindow, InputProvider, OutputReceiver] in any order - ArrayList<StackManipulation> parameters = new ArrayList<>(); + ArrayList<StackManipulation> pushParameters = new ArrayList<>(); // Push the ProcessContext argument. - parameters.add(MethodVariableAccess.REFERENCE.loadOffset(1)); + pushParameters.add(MethodVariableAccess.REFERENCE.loadOffset(1)); // Push the extra arguments in their actual order. StackManipulation pushExtraContextFactory = MethodVariableAccess.REFERENCE.loadOffset(2); for (DoFnSignature.Parameter param : signature.extraParameters()) { - parameters.add( - new StackManipulation.Compound( - pushExtraContextFactory, - MethodInvocation.invoke(getExtraContextFactoryMethod(param)), - // ExtraContextFactory.restrictionTracker() returns a RestrictionTracker, - // but the @ProcessElement method expects a concrete subtype of it. - // Insert a downcast. - DoFnSignature.Parameter.restrictionTracker().equals(param) - ? TypeCasting.to( - new TypeDescription.ForLoadedType(signature.trackerT().getRawType())) - : StackManipulation.Trivial.INSTANCE)); + pushParameters.add(getExtraContextParameter(param, pushExtraContextFactory)); } - return new StackManipulation.Compound(parameters); + return new StackManipulation.Compound(pushParameters); } @Override http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/40ff9d40/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index 11f6aa7..a189bd5 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -34,6 +34,7 @@ import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation; import org.apache.beam.sdk.transforms.DoFn.StateId; import org.apache.beam.sdk.transforms.DoFn.TimerId; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.BoundedWindowParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.RestrictionTrackerParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; @@ -230,8 +231,6 @@ public abstract class DoFnSignature { // These parameter descriptors are constant private static final BoundedWindowParameter BOUNDED_WINDOW_PARAMETER = new AutoValue_DoFnSignature_Parameter_BoundedWindowParameter(); - private static final RestrictionTrackerParameter RESTRICTION_TRACKER_PARAMETER = - new AutoValue_DoFnSignature_Parameter_RestrictionTrackerParameter(); private static final InputProviderParameter INPUT_PROVIDER_PARAMETER = new AutoValue_DoFnSignature_Parameter_InputProviderParameter(); private static final OutputReceiverParameter OUTPUT_RECEIVER_PARAMETER = @@ -261,8 +260,8 @@ public abstract class DoFnSignature { /** * Returns a {@link RestrictionTrackerParameter}. */ - public static RestrictionTrackerParameter restrictionTracker() { - return RESTRICTION_TRACKER_PARAMETER; + public static RestrictionTrackerParameter restrictionTracker(TypeDescriptor<?> trackerT) { + return new AutoValue_DoFnSignature_Parameter_RestrictionTrackerParameter(trackerT); } /** @@ -315,6 +314,7 @@ public abstract class DoFnSignature { public abstract static class RestrictionTrackerParameter extends Parameter { // Package visible for AutoValue RestrictionTrackerParameter() {} + public abstract TypeDescriptor<?> trackerT(); } /** @@ -388,7 +388,8 @@ public abstract class DoFnSignature { * Whether this {@link DoFn} is <a href="https://s.apache.org/splittable-do-fn">splittable</a>. */ public boolean isSplittable() { - return extraParameters().contains(Parameter.restrictionTracker()); + return Iterables.any( + extraParameters(), Predicates.instanceOf(RestrictionTrackerParameter.class)); } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/40ff9d40/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java index 0475404..09c5f3d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java @@ -21,7 +21,9 @@ import static com.google.common.base.Preconditions.checkState; import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Predicates; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; import com.google.common.collect.Maps; import java.lang.annotation.Annotation; import java.lang.reflect.AnnotatedElement; @@ -45,6 +47,7 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFn.StateId; import org.apache.beam.sdk.transforms.DoFn.TimerId; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.RestrictionTrackerParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimerParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration; @@ -148,6 +151,12 @@ public class DoFnSignatures { private MethodAnalysisContext() {} + /** Indicates whether a {@link RestrictionTrackerParameter} is known in this context. */ + public boolean hasRestrictionTrackerParameter() { + return Iterables.any( + extraParameters, Predicates.instanceOf(RestrictionTrackerParameter.class)); + } + /** State parameters declared in this context, keyed by {@link StateId}. */ public Map<String, StateParameter> getStateParameters() { return Collections.unmodifiableMap(stateParameters); @@ -663,7 +672,7 @@ public class DoFnSignatures { } // A splittable DoFn can not have any other extra context parameters. - if (methodContext.getExtraParameters().contains(DoFnSignature.Parameter.restrictionTracker())) { + if (methodContext.hasRestrictionTrackerParameter()) { errors.checkArgument( methodContext.getExtraParameters().size() == 1, "Splittable DoFn must not have any extra arguments, but has: %s", @@ -724,10 +733,10 @@ public class DoFnSignatures { } else if (RestrictionTracker.class.isAssignableFrom(rawType)) { methodErrors.checkArgument( - !methodContext.getExtraParameters().contains(Parameter.restrictionTracker()), + !methodContext.hasRestrictionTrackerParameter(), "Multiple %s parameters", RestrictionTracker.class.getSimpleName()); - return Parameter.restrictionTracker(); + return Parameter.restrictionTracker(paramT); } else if (rawType.equals(Timer.class)) { // m.getParameters() is not available until Java 8 http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/40ff9d40/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java index 0751b59..7b93eb9 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java @@ -22,6 +22,8 @@ import static org.apache.beam.sdk.transforms.reflect.DoFnSignaturesTestUtils.err import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import com.google.common.base.Predicates; +import com.google.common.collect.Iterables; import java.util.List; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; @@ -80,7 +82,8 @@ public class DoFnSignaturesSplittableDoFnTest { }); assertTrue(signature.isSplittable()); - assertTrue(signature.extraParameters().contains(DoFnSignature.Parameter.restrictionTracker())); + assertTrue(Iterables.any(signature.extraParameters(), + Predicates.instanceOf(DoFnSignature.Parameter.RestrictionTrackerParameter.class))); assertEquals(SomeRestrictionTracker.class, signature.trackerT().getRawType()); }
