Add custom rehydration for ParDo
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/7fb3e793
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/7fb3e793
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/7fb3e793

Branch: refs/heads/master
Commit: 7fb3e79328e1a9ef8340170aecd44c89e596eec5
Parents: 92209c3
Author: Kenneth Knowles <[email protected]>
Authored: Tue Oct 3 19:17:48 2017 -0700
Committer: Kenneth Knowles <[email protected]>
Committed: Tue Oct 17 12:45:11 2017 -0700

----------------------------------------------------------------------
 .../core/construction/ParDoTranslation.java    | 226 ++++++++++++++++---
 .../core/construction/PipelineTranslation.java |  22 --
 2 files changed, 194 insertions(+), 54 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/7fb3e793/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
----------------------------------------------------------------------
NOTE(review): RawParDo.getAdditionalInputs passes the caught IOException as the cause of the
IllegalStateException it throws, so the underlying lookup failure is not lost.
NOTE(review): payloadForParDoLike is public, but its ParDoLike parameter type is a private
interface, so no caller outside ParDoTranslation can supply an argument; consider narrowing it.
NOTE(review): the MoreObjects.firstNonNull wrappers in RawParDo look redundant if the
protobuf-generated list/map getters never return null -- confirm against the protobuf Java docs.
NOTE(review): RawParDo.rehydratedComponents is declared final transient, so it would be null in a
Java-deserialized copy; getAdditionalInputs dereferences it.
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
index 5092448..f88cbe5 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
@@ -26,6 +26,7 @@ import static org.apache.beam.runners.core.construction.PTransformTranslation.PA
 import com.google.auto.service.AutoService;
 import com.google.auto.value.AutoValue;
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.MoreObjects;
 import com.google.common.base.Optional;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
@@ -35,6 +36,7 @@ import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -75,6 +77,7 @@ import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
 import org.apache.beam.sdk.values.WindowingStrategy;
@@ -91,8 +94,7 @@ public class ParDoTranslation {
 
   /** A {@link TransformPayloadTranslator} for {@link ParDo}. */
   public static class ParDoPayloadTranslator
-      extends PTransformTranslation.TransformPayloadTranslator.WithDefaultRehydration<
-          ParDo.MultiOutput<?, ?>> {
+      implements TransformPayloadTranslator<MultiOutput<?, ?>> {
     public static TransformPayloadTranslator create() {
       return new ParDoPayloadTranslator();
     }
@@ -115,6 +117,13 @@ public class ParDoTranslation {
           .build();
     }
 
+    @Override
+    public PTransformTranslation.RawPTransform<?, ?> rehydrate(
+        RunnerApi.PTransform protoTransform, RehydratedComponents rehydratedComponents)
+        throws IOException {
+      return new RawParDo<>(protoTransform, rehydratedComponents);
+    }
+
     /** Registers {@link ParDoPayloadTranslator}. */
     @AutoService(TransformPayloadTranslatorRegistrar.class)
     public static class Registrar implements TransformPayloadTranslatorRegistrar {
@@ -125,41 +134,76 @@ public class ParDoTranslation {
       }
 
       @Override
-      public Map<String, TransformPayloadTranslator> getTransformRehydrators() {
-        return Collections.emptyMap();
+      public Map<String, ? extends TransformPayloadTranslator> getTransformRehydrators() {
+        return Collections.singletonMap(PAR_DO_TRANSFORM_URN, new ParDoPayloadTranslator());
       }
     }
   }
 
-  public static ParDoPayload toProto(ParDo.MultiOutput<?, ?> parDo, SdkComponents components)
+  public static ParDoPayload toProto(final ParDo.MultiOutput<?, ?> parDo, SdkComponents components)
       throws IOException {
-    DoFn<?, ?> doFn = parDo.getFn();
-    DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
-    Map<String, StateDeclaration> states = signature.stateDeclarations();
-    Map<String, TimerDeclaration> timers = signature.timerDeclarations();
-    List<Parameter> parameters = signature.processElement().extraParameters();
-
-    ParDoPayload.Builder builder = ParDoPayload.newBuilder();
-    builder.setDoFn(toProto(parDo.getFn(), parDo.getMainOutputTag()));
-    builder.setSplittable(signature.processElement().isSplittable());
-    for (PCollectionView<?> sideInput : parDo.getSideInputs()) {
-      builder.putSideInputs(sideInput.getTagInternal().getId(), toProto(sideInput));
-    }
-    for (Parameter parameter : parameters) {
-      Optional<RunnerApi.Parameter> protoParameter = toProto(parameter);
-      if (protoParameter.isPresent()) {
-        builder.addParameters(protoParameter.get());
-      }
-    }
-    for (Map.Entry<String, StateDeclaration> state : states.entrySet()) {
-      RunnerApi.StateSpec spec = toProto(getStateSpecOrCrash(state.getValue(), doFn), components);
-      builder.putStateSpecs(state.getKey(), spec);
-    }
-    for (Map.Entry<String, TimerDeclaration> timer : timers.entrySet()) {
-      RunnerApi.TimerSpec spec = toProto(getTimerSpecOrCrash(timer.getValue(), doFn));
-      builder.putTimerSpecs(timer.getKey(), spec);
-    }
-    return builder.build();
+
+    final DoFn<?, ?> doFn = parDo.getFn();
+    final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
+
+    return payloadForParDoLike(
+        new ParDoLike() {
+          @Override
+          public SdkFunctionSpec translateDoFn(SdkComponents newComponents) {
+            return toProto(parDo.getFn(), parDo.getMainOutputTag());
+          }
+
+          @Override
+          public List<RunnerApi.Parameter> translateParameters() {
+            List<RunnerApi.Parameter> parameters = new ArrayList<>();
+            for (Parameter parameter : signature.processElement().extraParameters()) {
+              Optional<RunnerApi.Parameter> protoParameter = toProto(parameter);
+              if (protoParameter.isPresent()) {
+                parameters.add(protoParameter.get());
+              }
+            }
+            return parameters;
+          }
+
+          @Override
+          public Map<String, SideInput> translateSideInputs(SdkComponents components) {
+            Map<String, SideInput> sideInputs = new HashMap<>();
+            for (PCollectionView<?> sideInput : parDo.getSideInputs()) {
+              sideInputs.put(sideInput.getTagInternal().getId(), toProto(sideInput));
+            }
+            return sideInputs;
+          }
+
+          @Override
+          public Map<String, RunnerApi.StateSpec> translateStateSpecs(SdkComponents components)
+              throws IOException {
+            Map<String, RunnerApi.StateSpec> stateSpecs = new HashMap<>();
+            for (Map.Entry<String, StateDeclaration> state :
+                signature.stateDeclarations().entrySet()) {
+              RunnerApi.StateSpec spec =
+                  toProto(getStateSpecOrCrash(state.getValue(), doFn), components);
+              stateSpecs.put(state.getKey(), spec);
+            }
+            return stateSpecs;
+          }
+
+          @Override
+          public Map<String, RunnerApi.TimerSpec> translateTimerSpecs(SdkComponents newComponents) {
+            Map<String, RunnerApi.TimerSpec> timerSpecs = new HashMap<>();
+            for (Map.Entry<String, TimerDeclaration> timer :
+                signature.timerDeclarations().entrySet()) {
+              RunnerApi.TimerSpec spec = toProto(getTimerSpecOrCrash(timer.getValue(), doFn));
+              timerSpecs.put(timer.getKey(), spec);
+            }
+            return timerSpecs;
+          }
+
+          @Override
+          public boolean isSplittable() {
+            return signature.processElement().isSplittable();
+          }
+        },
+        components);
   }
 
   private static StateSpec<?> getStateSpecOrCrash(
@@ -603,4 +647,122 @@ public class ParDoTranslation {
         SerializableUtils.deserializeFromByteArray(
             spec.getPayload().toByteArray(), "Custom WinodwMappingFn");
   }
+
+  static class RawParDo<InputT, OutputT>
+      extends PTransformTranslation.RawPTransform<PCollection<InputT>, PCollection<OutputT>>
+      implements ParDoLike {
+
+    private final RunnerApi.PTransform protoTransform;
+    private final transient RehydratedComponents rehydratedComponents;
+
+    // Parsed from protoTransform and cached
+    private final FunctionSpec spec;
+    private final ParDoPayload payload;
+
+    public RawParDo(RunnerApi.PTransform protoTransform, RehydratedComponents rehydratedComponents)
+        throws IOException {
+      this.rehydratedComponents = rehydratedComponents;
+      this.protoTransform = protoTransform;
+      this.spec = protoTransform.getSpec();
+      this.payload = ParDoPayload.parseFrom(spec.getPayload());
+    }
+
+    @Override
+    public FunctionSpec getSpec() {
+      return spec;
+    }
+
+    @Override
+    public FunctionSpec migrate(SdkComponents components) throws IOException {
+      return FunctionSpec.newBuilder()
+          .setUrn(PAR_DO_TRANSFORM_URN)
+          .setPayload(payloadForParDoLike(this, components).toByteString())
+          .build();
+    }
+
+    @Override
+    public Map<TupleTag<?>, PValue> getAdditionalInputs() {
+      Map<TupleTag<?>, PValue> additionalInputs = new HashMap<>();
+      for (Map.Entry<String, SideInput> sideInputEntry : payload.getSideInputsMap().entrySet()) {
+        try {
+          additionalInputs.put(
+              new TupleTag<>(sideInputEntry.getKey()),
+              rehydratedComponents.getPCollection(
+                  protoTransform.getInputsOrThrow(sideInputEntry.getKey())));
+        } catch (IOException exc) {
+          throw new IllegalStateException(
+              String.format(
+                  "Could not find input with name %s for %s transform",
+                  sideInputEntry.getKey(), ParDo.class.getSimpleName()), exc);
+        }
+      }
+      return additionalInputs;
+    }
+
+    @Override
+    public SdkFunctionSpec translateDoFn(SdkComponents newComponents) {
+      // TODO: re-register the environment with the new components
+      return payload.getDoFn();
+    }
+
+    @Override
+    public List<RunnerApi.Parameter> translateParameters() {
+      return MoreObjects.firstNonNull(
+          payload.getParametersList(), Collections.<RunnerApi.Parameter>emptyList());
+    }
+
+    @Override
+    public Map<String, SideInput> translateSideInputs(SdkComponents components) {
+      // TODO: re-register the PCollections and UDF environments
+      return MoreObjects.firstNonNull(
+          payload.getSideInputsMap(), Collections.<String, SideInput>emptyMap());
+    }
+
+    @Override
+    public Map<String, RunnerApi.StateSpec> translateStateSpecs(SdkComponents components) {
+      // TODO: re-register the coders
+      return MoreObjects.firstNonNull(
+          payload.getStateSpecsMap(), Collections.<String, RunnerApi.StateSpec>emptyMap());
+    }
+
+    @Override
+    public Map<String, RunnerApi.TimerSpec> translateTimerSpecs(SdkComponents newComponents) {
+      return MoreObjects.firstNonNull(
+          payload.getTimerSpecsMap(), Collections.<String, RunnerApi.TimerSpec>emptyMap());
+    }
+
+    @Override
+    public boolean isSplittable() {
+      return payload.getSplittable();
+    }
+  }
+
+  /** These methods drive to-proto translation from Java and from rehydrated ParDos. */
+  private interface ParDoLike {
+    SdkFunctionSpec translateDoFn(SdkComponents newComponents);
+
+    List<RunnerApi.Parameter> translateParameters();
+
+    Map<String, RunnerApi.SideInput> translateSideInputs(SdkComponents components);
+
+    Map<String, RunnerApi.StateSpec> translateStateSpecs(SdkComponents components)
+        throws IOException;
+
+    Map<String, RunnerApi.TimerSpec> translateTimerSpecs(SdkComponents newComponents);
+
+    boolean isSplittable();
+  }
+
+  public static ParDoPayload payloadForParDoLike(ParDoLike parDo, SdkComponents components)
+      throws IOException {
+
+    return ParDoPayload.newBuilder()
+        .setDoFn(parDo.translateDoFn(components))
+        .addAllParameters(parDo.translateParameters())
+        .putAllStateSpecs(parDo.translateStateSpecs(components))
+        .putAllTimerSpecs(parDo.translateTimerSpecs(components))
+        .putAllSideInputs(parDo.translateSideInputs(components))
+        .setSplittable(parDo.isSplittable())
+        .build();
+  }
 }
http://git-wip-us.apache.org/repos/asf/beam/blob/7fb3e793/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PipelineTranslation.java
----------------------------------------------------------------------
NOTE(review): this hunk deletes the WRITE_FILES_TRANSFORM_URN side-input branch as well as the
ParDo one, but the diffstat shows this commit gives custom rehydration only to ParDo
(ParDoTranslation.java). Confirm that rehydrated WriteFiles transforms still report their side
inputs through getAdditionalInputs(); otherwise they may be dropped. The hunk also leaves
sideInputMapToAdditionalInputs with no visible caller -- check whether it is now dead code.
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PipelineTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PipelineTranslation.java
index 85033e5..c8d38eb 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PipelineTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PipelineTranslation.java
@@ -141,31 +141,9 @@ public class PipelineTranslation {
           rehydratedComponents.getPCollection(outputEntry.getValue()));
     }
 
-    RunnerApi.FunctionSpec transformSpec = transformProto.getSpec();
     RawPTransform<?, ?> transform =
         PTransformTranslation.rehydrate(transformProto, rehydratedComponents);
 
-    // By default, no "additional" inputs, since that is an SDK-specific thing.
-    // Only ParDo and WriteFiles really separate main from side inputs
-    Map<TupleTag<?>, PValue> additionalInputs = Collections.emptyMap();
-
-    // TODO: ParDoTranslation should own it - https://issues.apache.org/jira/browse/BEAM-2674
-    if (transformSpec.getUrn().equals(PTransformTranslation.PAR_DO_TRANSFORM_URN)) {
-      RunnerApi.ParDoPayload payload = RunnerApi.ParDoPayload.parseFrom(transformSpec.getPayload());
-      additionalInputs =
-          sideInputMapToAdditionalInputs(
-              transformProto, rehydratedComponents, rehydratedInputs, payload.getSideInputsMap());
-    }
-
-    // TODO: WriteFilesTranslation should own it - https://issues.apache.org/jira/browse/BEAM-2674
-    if (transformSpec.getUrn().equals(PTransformTranslation.WRITE_FILES_TRANSFORM_URN)) {
-      RunnerApi.WriteFilesPayload payload =
-          RunnerApi.WriteFilesPayload.parseFrom(transformSpec.getPayload());
-      additionalInputs =
-          sideInputMapToAdditionalInputs(
-              transformProto, rehydratedComponents, rehydratedInputs, payload.getSideInputsMap());
-    }
-
     if (isPrimitive(transformProto)) {
       transforms.addFinalizedPrimitiveNode(
           transformProto.getUniqueName(), rehydratedInputs, transform, rehydratedOutputs);
