This is an automated email from the ASF dual-hosted git repository. jagadish pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/samza.git
The following commit(s) were added to refs/heads/master by this push: new 1b267f7 SAMZA-2055: Async high level api 1b267f7 is described below commit 1b267f70b0db9222e2002c18f831ec23a6f08c42 Author: mynameborat <bharath.kumarasubraman...@gmail.com> AuthorDate: Wed Apr 3 12:21:00 2019 -0700 SAMZA-2055: Async high level api https://cwiki.apache.org/confluence/display/SAMZA/SEP-21%3A+Samza+Async+API+for+High+Level Author: mynameborat <bharath.kumarasubraman...@gmail.com> Reviewers: Jagadish <jagad...@apache.org> Closes #905 from mynameborat/async-high-level-api --- .../org/apache/samza/operators/MessageStream.java | 27 +++ .../operators/functions/AsyncFlatMapFunction.java | 73 ++++++++ .../apache/samza/operators/MessageStreamImpl.java | 10 ++ ...atorImpl.java => AsyncFlatmapOperatorImpl.java} | 38 ++-- .../operators/impl/BroadcastOperatorImpl.java | 7 +- ...mOperatorImpl.java => FlatmapOperatorImpl.java} | 10 +- .../samza/operators/impl/InputOperatorImpl.java | 22 ++- .../apache/samza/operators/impl/OperatorImpl.java | 198 ++++++++++++++------- .../samza/operators/impl/OperatorImplGraph.java | 5 +- .../samza/operators/impl/OutputOperatorImpl.java | 6 +- .../operators/impl/PartialJoinOperatorImpl.java | 12 +- .../operators/impl/PartitionByOperatorImpl.java | 6 +- .../operators/impl/SendToTableOperatorImpl.java | 7 +- .../samza/operators/impl/SinkOperatorImpl.java | 6 +- .../impl/StreamTableJoinOperatorImpl.java | 27 +-- .../samza/operators/impl/TriggerScheduler.java | 7 +- .../samza/operators/impl/WindowOperatorImpl.java | 12 +- .../operators/spec/AsyncFlatMapOperatorSpec.java | 59 ++++++ .../apache/samza/operators/spec/OperatorSpec.java | 3 +- .../apache/samza/operators/spec/OperatorSpecs.java | 16 ++ .../org/apache/samza/task/StreamOperatorTask.java | 50 +++++- .../org/apache/samza/task/TaskFactoryUtil.java | 2 +- .../apache/samza/operators/TestJoinOperator.java | 52 +++--- ...ratorImpl.java => TestFlatmapOperatorImpl.java} | 10 +- .../samza/operators/impl/TestOperatorImpl.java | 28 +-- .../operators/impl/TestOperatorImplGraph.java | 4 +- .../impl/TestStreamTableJoinOperatorImpl.java | 5 +- .../samza/operators/impl/TestWindowOperator.java | 98 +++++----- .../org/apache/samza/task/TestTaskFactoryUtil.java | 4 +- .../samza/example/AppWithGlobalConfigExample.java | 32 +--- .../samza/example/AsyncApplicationExample.java | 135 ++++++++++++++ .../org/apache/samza/example/BroadcastExample.java | 13 +- .../apache/samza/example/KeyValueStoreExample.java | 23 +-- .../org/apache/samza/example/MergeExample.java | 8 +- .../samza/example/PageViewCounterExample.java | 32 +--- .../apache/samza/example/RepartitionExample.java | 17 +- .../org/apache/samza/example/WindowExample.java | 11 +- .../apache/samza/example/models/AdClickEvent.java | 37 ++++ .../samza/example/models/EnrichedAdClickEvent.java | 43 +++++ .../org/apache/samza/example/models/Member.java | 43 +++++ .../apache/samza/example/models/PageViewCount.java | 44 +++++ .../apache/samza/example/models/PageViewEvent.java | 43 +++++ .../controlmessages/WatermarkIntegrationTest.java | 7 +- .../samza/test/operator/TestAsyncFlatMap.java | 180 +++++++++++++++++++ .../test/samzasql/TestSamzaSqlRemoteTable.java | 2 + 45 files changed, 1125 insertions(+), 349 deletions(-) diff --git a/samza-api/src/main/java/org/apache/samza/operators/MessageStream.java b/samza-api/src/main/java/org/apache/samza/operators/MessageStream.java index f951a84..141a4d2 100644 --- a/samza-api/src/main/java/org/apache/samza/operators/MessageStream.java +++ b/samza-api/src/main/java/org/apache/samza/operators/MessageStream.java @@ -22,7 +22,9 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Collection; +import java.util.concurrent.CompletionStage; import org.apache.samza.annotation.InterfaceStability; +import org.apache.samza.operators.functions.AsyncFlatMapFunction; import org.apache.samza.operators.functions.FilterFunction; import org.apache.samza.operators.functions.FlatMapFunction; import org.apache.samza.operators.functions.JoinFunction; @@ -68,6 +70,31 @@ public interface MessageStream<M> { <OM> MessageStream<OM> flatMap(FlatMapFunction<? super M, ? extends OM> flatMapFn); /** + * Applies the provided 1:n transformation asynchronously to this {@link MessageStream}. The asynchronous transformation + * is specified through {@link AsyncFlatMapFunction}. The results are emitted to the downstream operators upon the + * completion of the {@link CompletionStage} returned from the {@link AsyncFlatMapFunction}. + * <p> + * The operator can operate in two modes depending on <i>task.max.concurrency.</i>. + * <ul> + * <li> + * Serialized (task.max.concurrency=1) - In this mode, each invocation of the {@link AsyncFlatMapFunction} is guaranteed + * to happen-before next invocation. + * </li> + * <li> + * Parallel (task.max.concurrency>1) - In this mode, multiple invocations can happen in parallel without happens-before guarantee + * and the {@link AsyncFlatMapFunction} is required synchronize any shared state. The operator doesn't provide any ordering guarantees. + * i.e The results corresponding to each invocation of this operator might not be emitted in the same order as invocations. + * By extension, the operator chain that follows it also doesn't have any ordering guarantees. + * </li> + * </ul> + * + * @param asyncFlatMapFn the async function to transform a message to zero or more messages + * @param <OM> the type of messages in the transformed {@link MessageStream} + * @return the transformed {@link MessageStream} + */ + <OM> MessageStream<OM> flatMapAsync(AsyncFlatMapFunction<? super M, ? extends OM> asyncFlatMapFn); + + /** * Applies the provided function to messages in this {@link MessageStream} and returns the * filtered {@link MessageStream}. * <p> diff --git a/samza-api/src/main/java/org/apache/samza/operators/functions/AsyncFlatMapFunction.java b/samza-api/src/main/java/org/apache/samza/operators/functions/AsyncFlatMapFunction.java new file mode 100644 index 0000000..a24e0fe --- /dev/null +++ b/samza-api/src/main/java/org/apache/samza/operators/functions/AsyncFlatMapFunction.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.operators.functions; + +import java.io.Serializable; +import java.util.Collection; +import java.util.concurrent.CompletionStage; +import org.apache.samza.SamzaException; +import org.apache.samza.annotation.InterfaceStability; + + +/** + * Asynchronous variant of the {@link FlatMapFunction} used in tandem with {@link org.apache.samza.operators.MessageStream#flatMapAsync(AsyncFlatMapFunction)} + * to transform a collection of 0 or more messages. + * <p> + * Typically, {@link AsyncFlatMapFunction} is used for describing complex transformations that involve IO operations or remote calls. + * The following pseudo code demonstrates a sample implementation of {@link AsyncFlatMapFunction} that sends out an email + * and returns the status asynchronously. + * <pre> {@code + * AsyncFlatMapFunction<Email, Status> asyncEmailSender = (Email message) -> { + * ... + * + * Request<Email> emailRequest = buildEmailRequest(message); + * Future<EmailResponse> emailResponseFuture = emailClient.sendRequest(emailRequest); // send email asynchronously + * ... + * + * return new CompletableFuture<>(emailResponseFuture) + * .thenApply(response -> fetchStatus(response); + * } + * } + * </pre> + * + * <p> + * The function needs to be thread safe in case of task.max.concurrency>1. It also needs to coordinate any + * shared state since happens-before is not guaranteed between the messages delivered to the function. Refer to + * {@link org.apache.samza.operators.MessageStream#flatMapAsync(AsyncFlatMapFunction)} docs for more details on the modes + * and guarantees. + * + * <p> + * For each invocation, the {@link CompletionStage} returned by the function should be completed successfully/exceptionally + * within task.callback.timeout.ms; failure to do so will result in {@link SamzaException} bringing down the application. + * + * @param <M> type of the input message + * @param <OM> type of the transformed messages + */ +@InterfaceStability.Unstable +@FunctionalInterface +public interface AsyncFlatMapFunction<M, OM> extends InitableFunction, ClosableFunction, Serializable { + + /** + * Transforms the provided message into a collection of 0 or more messages. + * + * @param message the input message to be transformed + * @return a {@link CompletionStage} of a {@link Collection} of transformed messages + */ + CompletionStage<Collection<OM>> apply(M message); +} diff --git a/samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java b/samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java index b652d68..99cc81a 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java @@ -25,12 +25,14 @@ import java.util.Collection; import org.apache.samza.SamzaException; import org.apache.samza.application.descriptors.StreamApplicationDescriptorImpl; +import org.apache.samza.operators.functions.AsyncFlatMapFunction; import org.apache.samza.operators.functions.FilterFunction; import org.apache.samza.operators.functions.FlatMapFunction; import org.apache.samza.operators.functions.JoinFunction; import org.apache.samza.operators.functions.MapFunction; import org.apache.samza.operators.functions.SinkFunction; import org.apache.samza.operators.functions.StreamTableJoinFunction; +import org.apache.samza.operators.spec.AsyncFlatMapOperatorSpec; import org.apache.samza.operators.spec.BroadcastOperatorSpec; import org.apache.samza.operators.spec.JoinOperatorSpec; import org.apache.samza.operators.spec.OperatorSpec; @@ -103,6 +105,14 @@ public class MessageStreamImpl<M> implements MessageStream<M> { } @Override + public <OM> MessageStream<OM> flatMapAsync(AsyncFlatMapFunction<? super M, ? extends OM> flatMapFn) { + String opId = this.streamAppDesc.getNextOpId(OpCode.ASYNC_FLAT_MAP); + AsyncFlatMapOperatorSpec<M, OM> op = OperatorSpecs.createAsyncOperatorSpec(flatMapFn, opId); + this.operatorSpec.registerNextOperatorSpec(op); + return new MessageStreamImpl<>(this.streamAppDesc, op); + } + + @Override public void sink(SinkFunction<? super M> sinkFn) { String opId = this.streamAppDesc.getNextOpId(OpCode.SINK); SinkOperatorSpec<M> op = OperatorSpecs.createSinkOperatorSpec(sinkFn, opId); diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/StreamOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/AsyncFlatmapOperatorImpl.java similarity index 60% copy from samza-core/src/main/java/org/apache/samza/operators/impl/StreamOperatorImpl.java copy to samza-core/src/main/java/org/apache/samza/operators/impl/AsyncFlatmapOperatorImpl.java index 1a615bd..fa5e56a 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/StreamOperatorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/AsyncFlatmapOperatorImpl.java @@ -18,49 +18,41 @@ */ package org.apache.samza.operators.impl; +import java.util.Collection; +import java.util.concurrent.CompletionStage; import org.apache.samza.context.Context; -import org.apache.samza.operators.functions.FlatMapFunction; +import org.apache.samza.operators.functions.AsyncFlatMapFunction; +import org.apache.samza.operators.spec.AsyncFlatMapOperatorSpec; import org.apache.samza.operators.spec.OperatorSpec; -import org.apache.samza.operators.spec.StreamOperatorSpec; import org.apache.samza.task.MessageCollector; import org.apache.samza.task.TaskCoordinator; -import java.util.Collection; - - -/** - * A simple operator that accepts a 1:n transform function and applies it to each incoming message. - * - * @param <M> the type of input message - * @param <RM> the type of result - */ -class StreamOperatorImpl<M, RM> extends OperatorImpl<M, RM> { - private final StreamOperatorSpec<M, RM> streamOpSpec; - private final FlatMapFunction<M, RM> transformFn; +public class AsyncFlatmapOperatorImpl<M, RM> extends OperatorImpl<M, RM> { + private final AsyncFlatMapOperatorSpec<M, RM> opSpec; + private final AsyncFlatMapFunction<M, RM> transformFn; - StreamOperatorImpl(StreamOperatorSpec<M, RM> streamOpSpec) { - this.streamOpSpec = streamOpSpec; - this.transformFn = streamOpSpec.getTransformFn(); + AsyncFlatmapOperatorImpl(AsyncFlatMapOperatorSpec<M, RM> opSpec) { + this.opSpec = opSpec; + this.transformFn = opSpec.getTransformFn(); } - @Override protected void handleInit(Context context) { - transformFn.init(context); + this.transformFn.init(context); } @Override - public Collection<RM> handleMessage(M message, MessageCollector collector, + protected CompletionStage<Collection<RM>> handleMessageAsync(M message, MessageCollector collector, TaskCoordinator coordinator) { - return this.transformFn.apply(message); + return transformFn.apply(message); } @Override protected void handleClose() { - this.transformFn.close(); } + @Override protected OperatorSpec<M, RM> getOperatorSpec() { - return streamOpSpec; + return opSpec; } } diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/BroadcastOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/BroadcastOperatorImpl.java index 4965f7b..4f93f2c 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/BroadcastOperatorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/BroadcastOperatorImpl.java @@ -19,6 +19,8 @@ package org.apache.samza.operators.impl; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import org.apache.samza.context.Context; import org.apache.samza.operators.spec.BroadcastOperatorSpec; import org.apache.samza.operators.spec.OperatorSpec; @@ -50,9 +52,10 @@ class BroadcastOperatorImpl<M> extends OperatorImpl<M, Void> { } @Override - protected Collection<Void> handleMessage(M message, MessageCollector collector, TaskCoordinator coordinator) { + protected CompletionStage<Collection<Void>> handleMessageAsync(M message, MessageCollector collector, + TaskCoordinator coordinator) { collector.send(new OutgoingMessageEnvelope(systemStream, 0, null, message)); - return Collections.emptyList(); + return CompletableFuture.completedFuture(Collections.emptyList()); } @Override diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/StreamOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/FlatmapOperatorImpl.java similarity index 82% rename from samza-core/src/main/java/org/apache/samza/operators/impl/StreamOperatorImpl.java rename to samza-core/src/main/java/org/apache/samza/operators/impl/FlatmapOperatorImpl.java index 1a615bd..3191d33 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/StreamOperatorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/FlatmapOperatorImpl.java @@ -18,6 +18,8 @@ */ package org.apache.samza.operators.impl; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import org.apache.samza.context.Context; import org.apache.samza.operators.functions.FlatMapFunction; import org.apache.samza.operators.spec.OperatorSpec; @@ -34,12 +36,12 @@ import java.util.Collection; * @param <M> the type of input message * @param <RM> the type of result */ -class StreamOperatorImpl<M, RM> extends OperatorImpl<M, RM> { +class FlatmapOperatorImpl<M, RM> extends OperatorImpl<M, RM> { private final StreamOperatorSpec<M, RM> streamOpSpec; private final FlatMapFunction<M, RM> transformFn; - StreamOperatorImpl(StreamOperatorSpec<M, RM> streamOpSpec) { + FlatmapOperatorImpl(StreamOperatorSpec<M, RM> streamOpSpec) { this.streamOpSpec = streamOpSpec; this.transformFn = streamOpSpec.getTransformFn(); } @@ -50,9 +52,9 @@ class StreamOperatorImpl<M, RM> extends OperatorImpl<M, RM> { } @Override - public Collection<RM> handleMessage(M message, MessageCollector collector, + protected CompletionStage<Collection<RM>> handleMessageAsync(M message, MessageCollector collector, TaskCoordinator coordinator) { - return this.transformFn.apply(message); + return CompletableFuture.completedFuture(this.transformFn.apply(message)); } @Override diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/InputOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/InputOperatorImpl.java index 8cf528c..d9dff9e 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/InputOperatorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/InputOperatorImpl.java @@ -18,6 +18,9 @@ */ package org.apache.samza.operators.impl; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import org.apache.samza.context.Context; import org.apache.samza.operators.KV; import org.apache.samza.system.descriptors.InputTransformer; @@ -47,18 +50,21 @@ public final class InputOperatorImpl extends OperatorImpl<IncomingMessageEnvelop } @Override - public Collection<Object> handleMessage(IncomingMessageEnvelope ime, MessageCollector collector, TaskCoordinator coordinator) { - Object message; + protected CompletionStage<Collection<Object>> handleMessageAsync(IncomingMessageEnvelope message, + MessageCollector collector, TaskCoordinator coordinator) { + Object result; InputTransformer transformer = inputOpSpec.getTransformer(); if (transformer != null) { - message = transformer.apply(ime); + result = transformer.apply(message); } else { - message = this.inputOpSpec.isKeyed() ? KV.of(ime.getKey(), ime.getMessage()) : ime.getMessage(); + result = this.inputOpSpec.isKeyed() ? KV.of(message.getKey(), message.getMessage()) : message.getMessage(); } - if (message != null) { - return Collections.singletonList(message); - } - return Collections.emptyList(); + + Collection<Object> output = Optional.ofNullable(result) + .map(Collections::singletonList) + .orElse(Collections.emptyList()); + + return CompletableFuture.completedFuture(output); } @Override diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java index 276e8c3..8d4ae21 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java @@ -18,6 +18,9 @@ */ package org.apache.samza.operators.impl; +import com.google.common.annotations.VisibleForTesting; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import org.apache.samza.SamzaException; import org.apache.samza.config.Config; import org.apache.samza.config.MetricsConfig; @@ -161,21 +164,13 @@ public abstract class OperatorImpl<M, RM> { || taskModel.getSystemStreamPartitions().stream().anyMatch(ssp -> ssp.getSystemStream().equals(input)); } - /** - * Handle the incoming {@code message} for this {@link OperatorImpl} and propagate results to registered operators. - * <p> - * Delegates to {@link #handleMessage(Object, MessageCollector, TaskCoordinator)} for handling the message. - * - * @param message the input message - * @param collector the {@link MessageCollector} for this message - * @param coordinator the {@link TaskCoordinator} for this message - */ - public final void onMessage(M message, MessageCollector collector, TaskCoordinator coordinator) { + public final CompletionStage<Void> onMessageAsync(M message, MessageCollector collector, + TaskCoordinator coordinator) { this.numMessage.inc(); long startNs = this.highResClock.nanoTime(); - Collection<RM> results; + CompletionStage<Collection<RM>> completableResultsFuture; try { - results = handleMessage(message, collector, coordinator); + completableResultsFuture = handleMessageAsync(message, collector, coordinator); } catch (ClassCastException e) { String actualType = e.getMessage().replaceFirst(" cannot be cast to .*", ""); String expectedType = e.getMessage().replaceFirst(".* cannot be cast to ", ""); @@ -186,30 +181,39 @@ public abstract class OperatorImpl<M, RM> { getOpImplId(), getOperatorSpec().getSourceLocation(), expectedType, actualType), e); } - long endNs = this.highResClock.nanoTime(); - this.handleMessageNs.update(endNs - startNs); - - results.forEach(rm -> - this.registeredOperators.forEach(op -> - op.onMessage(rm, collector, coordinator))); - - WatermarkFunction watermarkFn = getOperatorSpec().getWatermarkFn(); - if (watermarkFn != null) { - // check whether there is new watermark emitted from the user function - Long outputWm = watermarkFn.getOutputWatermark(); - propagateWatermark(outputWm, collector, coordinator); - } + CompletionStage<Void> result = completableResultsFuture.thenCompose(results -> { + long endNs = this.highResClock.nanoTime(); + this.handleMessageNs.update(endNs - startNs); + + return CompletableFuture.allOf(results.stream() + .flatMap(r -> this.registeredOperators.stream() + .map(op -> op.onMessageAsync(r, collector, coordinator))) + .toArray(CompletableFuture[]::new)); + }); + + result.thenAccept(x -> { + WatermarkFunction watermarkFn = getOperatorSpec().getWatermarkFn(); + if (watermarkFn != null) { + // check whether there is new watermark emitted from the user function + Long outputWm = watermarkFn.getOutputWatermark(); + propagateWatermark(outputWm, collector, coordinator); + } + }); + + return result; } /** - * Handle the incoming {@code message} and return the results to be propagated to registered operators. + * Handle the incoming {@code message} asynchronously and return a {@link CompletionStage} of the results to be propagated + * to the registered operators. * - * @param message the input message - * @param collector the {@link MessageCollector} in the context - * @param coordinator the {@link TaskCoordinator} in the context - * @return results of the transformation + * @param message the input message + * @param collector the {@link MessageCollector} in the context + * @param coordinator the {@link TaskCoordinator} in the context + * + * @return a {@code CompletionStage} of the results of the transformation */ - protected abstract Collection<RM> handleMessage(M message, MessageCollector collector, + protected abstract CompletionStage<Collection<RM>> handleMessageAsync(M message, MessageCollector collector, TaskCoordinator coordinator); /** @@ -220,17 +224,23 @@ public abstract class OperatorImpl<M, RM> { * @param collector the {@link MessageCollector} in the context * @param coordinator the {@link TaskCoordinator} in the context */ - public final void onTimer(MessageCollector collector, TaskCoordinator coordinator) { + public final CompletionStage<Void> onTimer(MessageCollector collector, TaskCoordinator coordinator) { long startNs = this.highResClock.nanoTime(); Collection<RM> results = handleTimer(collector, coordinator); long endNs = this.highResClock.nanoTime(); this.handleTimerNs.update(endNs - startNs); - results.forEach(rm -> - this.registeredOperators.forEach(op -> - op.onMessage(rm, collector, coordinator))); - this.registeredOperators.forEach(op -> - op.onTimer(collector, coordinator)); + CompletionStage<Void> resultFuture = CompletableFuture.allOf( + results.stream() + .flatMap(r -> this.registeredOperators.stream() + .map(op -> op.onMessageAsync(r, collector, coordinator))) + .toArray(CompletableFuture[]::new)); + + return resultFuture.thenCompose(x -> + CompletableFuture.allOf(this.registeredOperators + .stream() + .map(op -> op.onTimer(collector, coordinator)) + .toArray(CompletableFuture[]::new))); } /** @@ -254,12 +264,14 @@ public abstract class OperatorImpl<M, RM> { * @param collector message collector * @param coordinator task coordinator */ - public final void aggregateEndOfStream(EndOfStreamMessage eos, SystemStreamPartition ssp, MessageCollector collector, + public final CompletionStage<Void> aggregateEndOfStream(EndOfStreamMessage eos, SystemStreamPartition ssp, MessageCollector collector, TaskCoordinator coordinator) { LOG.info("Received end-of-stream message from task {} in {}", eos.getTaskName(), ssp); eosStates.update(eos, ssp); SystemStream stream = ssp.getSystemStream(); + CompletionStage<Void> endOfStreamFuture = CompletableFuture.completedFuture(null); + if (eosStates.isEndOfStream(stream)) { LOG.info("Input {} reaches the end for task {}", stream.toString(), taskName.getTaskName()); if (eos.getTaskName() != null) { @@ -267,16 +279,20 @@ public abstract class OperatorImpl<M, RM> { // broadcast the end-of-stream to all the peer partitions controlMessageSender.broadcastToOtherPartitions(new EndOfStreamMessage(), ssp, collector); } - // populate the end-of-stream through the dag - onEndOfStream(collector, coordinator); - if (eosStates.allEndOfStream()) { - // all inputs have been end-of-stream, shut down the task - LOG.info("All input streams have reached the end for task {}", taskName.getTaskName()); - coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK); - coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK); - } + // populate the end-of-stream through the dag + endOfStreamFuture = onEndOfStream(collector, coordinator) + .thenAccept(result -> { + if (eosStates.allEndOfStream()) { + // all inputs have been end-of-stream, shut down the task + LOG.info("All input streams have reached the end for task {}", taskName.getTaskName()); + coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK); + coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK); + } + }); } + + return endOfStreamFuture; } /** @@ -285,16 +301,25 @@ public abstract class OperatorImpl<M, RM> { * @param collector message collector * @param coordinator task coordinator */ - private final void onEndOfStream(MessageCollector collector, TaskCoordinator coordinator) { + private CompletionStage<Void> onEndOfStream(MessageCollector collector, TaskCoordinator coordinator) { + CompletionStage<Void> endOfStreamFuture = CompletableFuture.completedFuture(null); + if (inputStreams.stream().allMatch(input -> eosStates.isEndOfStream(input))) { Collection<RM> results = handleEndOfStream(collector, coordinator); - results.forEach(rm -> - this.registeredOperators.forEach(op -> - op.onMessage(rm, collector, coordinator))); + CompletionStage<Void> resultFuture = CompletableFuture.allOf( + results.stream() + .flatMap(r -> this.registeredOperators.stream() + .map(op -> op.onMessageAsync(r, collector, coordinator))) + .toArray(CompletableFuture[]::new)); - this.registeredOperators.forEach(op -> op.onEndOfStream(collector, coordinator)); + endOfStreamFuture = resultFuture.thenCompose(x -> + CompletableFuture.allOf(this.registeredOperators.stream() + .map(op -> op.onEndOfStream(collector, coordinator)) + .toArray(CompletableFuture[]::new))); } + + return endOfStreamFuture; } /** @@ -318,11 +343,13 @@ public abstract class OperatorImpl<M, RM> { * @param collector message collector * @param coordinator task coordinator */ - public final void aggregateWatermark(WatermarkMessage watermarkMessage, SystemStreamPartition ssp, + public final CompletionStage<Void> aggregateWatermark(WatermarkMessage watermarkMessage, SystemStreamPartition ssp, MessageCollector collector, TaskCoordinator coordinator) { LOG.debug("Received watermark {} from {}", watermarkMessage.getTimestamp(), ssp); watermarkStates.update(watermarkMessage, ssp); long watermark = watermarkStates.getWatermark(ssp.getSystemStream()); + CompletionStage<Void> watermarkFuture = CompletableFuture.completedFuture(null); + if (currentWatermark < watermark) { LOG.debug("Got watermark {} from stream {}", watermark, ssp.getSystemStream()); @@ -332,11 +359,11 @@ public abstract class OperatorImpl<M, RM> { controlMessageSender.broadcastToOtherPartitions(new WatermarkMessage(watermark), ssp, collector); } // populate the watermark through the dag - onWatermark(watermark, collector, coordinator); - - // update metrics - watermarkStates.updateAggregateMetric(ssp, watermark); + watermarkFuture = onWatermark(watermark, collector, coordinator) + .thenAccept(ignored -> watermarkStates.updateAggregateMetric(ssp, watermark)); } + + return watermarkFuture; } /** @@ -347,16 +374,20 @@ public abstract class OperatorImpl<M, RM> { * @param collector message collector * @param coordinator task coordinator */ - private final void onWatermark(long watermark, MessageCollector collector, TaskCoordinator coordinator) { + private CompletionStage<Void> onWatermark(long watermark, MessageCollector collector, TaskCoordinator coordinator) { final long inputWatermarkMin; if (prevOperators.isEmpty()) { // for input operator, use the watermark time coming from the source input inputWatermarkMin = watermark; } else { // InputWatermark(op) = min { OutputWatermark(op') | op' is upstream of op} - inputWatermarkMin = prevOperators.stream().map(op -> op.getOutputWatermark()).min(Long::compare).get(); + inputWatermarkMin = prevOperators.stream() + .map(op -> op.getOutputWatermark()) + .min(Long::compare) + .get(); } + CompletionStage<Void> watermarkFuture = CompletableFuture.completedFuture(null); if (currentWatermark < inputWatermarkMin) { // advance the watermark time of this operator currentWatermark = inputWatermarkMin; @@ -377,26 +408,38 @@ public abstract class OperatorImpl<M, RM> { } if (!output.isEmpty()) { - output.forEach(rm -> - this.registeredOperators.forEach(op -> - op.onMessage(rm, collector, coordinator))); + watermarkFuture = CompletableFuture.allOf( + output.stream() + .flatMap(rm -> this.registeredOperators.stream() + .map(op -> op.onMessageAsync(rm, collector, coordinator))) + .toArray(CompletableFuture[]::new)); } - propagateWatermark(outputWm, collector, coordinator); + watermarkFuture.thenCompose(res -> propagateWatermark(outputWm, collector, coordinator)); } + + return watermarkFuture; } - private void propagateWatermark(Long outputWm, MessageCollector collector, TaskCoordinator coordinator) { + private CompletionStage<Void> propagateWatermark(Long outputWm, MessageCollector collector, TaskCoordinator coordinator) { + CompletionStage<Void> watermarkFuture = CompletableFuture.completedFuture(null); + if (outputWm != null) { if (outputWatermark < outputWm) { // advance the watermark outputWatermark = outputWm; LOG.debug("Advance output watermark to {} in operator {}", outputWatermark, getOpImplId()); - this.registeredOperators.forEach(op -> op.onWatermark(outputWatermark, collector, coordinator)); + watermarkFuture = CompletableFuture.allOf( + this.registeredOperators + .stream() + .map(op -> op.onWatermark(outputWatermark, collector, coordinator)) + .toArray(CompletableFuture[]::new)); } else if (outputWatermark > outputWm) { LOG.warn("Ignore watermark {} that is smaller than the previous watermark {}.", outputWm, outputWatermark); } } + + return watermarkFuture; } /** @@ -449,9 +492,12 @@ public abstract class OperatorImpl<M, RM> { final Collection<RM> output = scheduledFn.onCallback(key, time); if (!output.isEmpty()) { - output.forEach(rm -> - registeredOperators.forEach(op -> - op.onMessage(rm, collector, coordinator))); + CompletableFuture<Void> timerFuture = CompletableFuture.allOf(output.stream() + .flatMap(r -> registeredOperators.stream() + .map(op -> op.onMessageAsync(r, collector, coordinator))) + .toArray(CompletableFuture[]::new)); + + timerFuture.join(); } } else { throw new SamzaException( @@ -499,6 +545,24 @@ public abstract class OperatorImpl<M, RM> { return getOperatorSpec().getOpId(); } + /* Package Private helper method for tests to perform onMessage synchronously + * Note: It is only intended for test use + */ + @VisibleForTesting + final void onMessage(M message, MessageCollector collector, TaskCoordinator coordinator) { + onMessageAsync(message, collector, coordinator) + .toCompletableFuture().join(); + } + + /* Package Private helper method for tests to perform handleMessage synchronously + * Note: It is only intended for test use + */ + @VisibleForTesting + final Collection<RM> handleMessage(M message, MessageCollector collector, TaskCoordinator coordinator) { + return handleMessageAsync(message, collector, coordinator) + .toCompletableFuture().join(); + } + private HighResolutionClock createHighResClock(Config config) { MetricsConfig metricsConfig = new MetricsConfig(config); // The timer metrics calculation here is only enabled for debugging diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImplGraph.java b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImplGraph.java index 9f33356..2b95321 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImplGraph.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImplGraph.java @@ -31,6 +31,7 @@ import org.apache.samza.operators.OperatorSpecGraph; import org.apache.samza.operators.Scheduler; import org.apache.samza.operators.functions.JoinFunction; import org.apache.samza.operators.functions.PartialJoinFunction; +import org.apache.samza.operators.spec.AsyncFlatMapOperatorSpec; import org.apache.samza.operators.spec.BroadcastOperatorSpec; import org.apache.samza.operators.spec.InputOperatorSpec; import org.apache.samza.operators.spec.JoinOperatorSpec; @@ -218,7 +219,7 @@ public class OperatorImplGraph { if (operatorSpec instanceof InputOperatorSpec) { return new InputOperatorImpl((InputOperatorSpec) operatorSpec); } else if (operatorSpec instanceof StreamOperatorSpec) { - return new StreamOperatorImpl((StreamOperatorSpec) operatorSpec); + return new FlatmapOperatorImpl((StreamOperatorSpec) operatorSpec); } else if (operatorSpec instanceof SinkOperatorSpec) { return new SinkOperatorImpl((SinkOperatorSpec) operatorSpec); } else if (operatorSpec instanceof OutputOperatorSpec) { @@ -243,6 +244,8 @@ public class OperatorImplGraph { String streamId = ((BroadcastOperatorSpec) operatorSpec).getOutputStream().getStreamId(); SystemStream systemStream = streamConfig.streamIdToSystemStream(streamId); return new BroadcastOperatorImpl((BroadcastOperatorSpec) operatorSpec, systemStream, context); + } else if (operatorSpec instanceof AsyncFlatMapOperatorSpec) { + return new AsyncFlatmapOperatorImpl((AsyncFlatMapOperatorSpec) operatorSpec); } throw new IllegalArgumentException( String.format("Unsupported OperatorSpec: %s", operatorSpec.getClass().getName())); diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/OutputOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/OutputOperatorImpl.java index 407cdd9..566485a 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/OutputOperatorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/OutputOperatorImpl.java @@ -18,6 +18,8 @@ */ package org.apache.samza.operators.impl; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import org.apache.samza.context.Context; import org.apache.samza.operators.KV; import org.apache.samza.operators.spec.OperatorSpec; @@ -52,7 +54,7 @@ class OutputOperatorImpl<M> extends OperatorImpl<M, Void> { } @Override - public Collection<Void> handleMessage(M message, MessageCollector collector, + protected CompletionStage<Collection<Void>> handleMessageAsync(M message, MessageCollector collector, TaskCoordinator coordinator) { Object key, value; if (outputStream.isKeyed()) { @@ -64,7 +66,7 @@ class OutputOperatorImpl<M> extends OperatorImpl<M, Void> { } collector.send(new OutgoingMessageEnvelope(systemStream, null, key, value)); - return Collections.emptyList(); + return CompletableFuture.completedFuture(Collections.emptyList()); } @Override diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/PartialJoinOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/PartialJoinOperatorImpl.java index 55658eb..6497f67 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/PartialJoinOperatorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/PartialJoinOperatorImpl.java @@ -18,6 +18,8 @@ */ package org.apache.samza.operators.impl; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import org.apache.samza.SamzaException; import org.apache.samza.context.Context; import org.apache.samza.operators.functions.PartialJoinFunction; @@ -68,7 +70,10 @@ class PartialJoinOperatorImpl<K, M, OM, JM> extends OperatorImpl<M, JM> { } @Override - public Collection<JM> handleMessage(M message, MessageCollector collector, TaskCoordinator coordinator) { + protected CompletionStage<Collection<JM>> handleMessageAsync(M message, MessageCollector collector, + TaskCoordinator coordinator) { + Collection<JM> output = Collections.emptyList(); + try { KeyValueStore<K, TimestampedValue<M>> thisState = thisPartialJoinFn.getState(); KeyValueStore<K, TimestampedValue<OM>> otherState = otherPartialJoinFn.getState(); @@ -80,12 +85,13 @@ class PartialJoinOperatorImpl<K, M, OM, JM> extends OperatorImpl<M, JM> { long now = clock.currentTimeMillis(); if (otherMessage != null && otherMessage.getTimestamp() > now - ttlMs) { JM joinResult = thisPartialJoinFn.apply(message, otherMessage.getValue()); - return Collections.singletonList(joinResult); + output = Collections.singletonList(joinResult); } } catch (Exception e) { throw new SamzaException("Error handling message in PartialJoinOperatorImpl " + getOpImplId(), e); } - return Collections.emptyList(); + + return CompletableFuture.completedFuture(output); } @Override diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/PartitionByOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/PartitionByOperatorImpl.java index 134a517..47ad4f6 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/PartitionByOperatorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/PartitionByOperatorImpl.java @@ -18,6 +18,8 @@ */ package org.apache.samza.operators.impl; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import org.apache.samza.context.Context; import org.apache.samza.context.InternalTaskContext; import org.apache.samza.operators.functions.MapFunction; @@ -66,13 +68,13 @@ class PartitionByOperatorImpl<M, K, V> extends OperatorImpl<M, Void> { } @Override - public Collection<Void> handleMessage(M message, MessageCollector collector, + protected CompletionStage<Collection<Void>> handleMessageAsync(M message, MessageCollector collector, TaskCoordinator coordinator) { K key = keyFunction.apply(message); V value = valueFunction.apply(message); Long partitionKey = key == null ? 0L : null; collector.send(new OutgoingMessageEnvelope(systemStream, partitionKey, key, value)); - return Collections.emptyList(); + return CompletableFuture.completedFuture(Collections.emptyList()); } @Override diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/SendToTableOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/SendToTableOperatorImpl.java index 6d84b17..1197b37 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/SendToTableOperatorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/SendToTableOperatorImpl.java @@ -18,6 +18,8 @@ */ package org.apache.samza.operators.impl; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import org.apache.samza.context.Context; import org.apache.samza.operators.KV; import org.apache.samza.operators.spec.OperatorSpec; @@ -52,10 +54,11 @@ public class SendToTableOperatorImpl<K, V> extends OperatorImpl<KV<K, V>, Void> } @Override - protected Collection<Void> handleMessage(KV<K, V> message, MessageCollector collector, TaskCoordinator coordinator) { + protected CompletionStage<Collection<Void>> handleMessageAsync(KV<K, V> message, MessageCollector collector, + TaskCoordinator coordinator) { table.put(message.getKey(), message.getValue()); // there should be no further chained operators since this is a terminal operator. - return Collections.emptyList(); + return CompletableFuture.completedFuture(Collections.emptyList()); } @Override diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/SinkOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/SinkOperatorImpl.java index 6fe9006..56b2d33 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/SinkOperatorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/SinkOperatorImpl.java @@ -18,6 +18,8 @@ */ package org.apache.samza.operators.impl; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import org.apache.samza.context.Context; import org.apache.samza.operators.functions.SinkFunction; import org.apache.samza.operators.spec.OperatorSpec; @@ -48,11 +50,11 @@ class SinkOperatorImpl<M> extends OperatorImpl<M, Void> { } @Override - public Collection<Void> handleMessage(M message, MessageCollector collector, + protected CompletionStage<Collection<Void>> handleMessageAsync(M message, MessageCollector collector, TaskCoordinator coordinator) { this.sinkFn.apply(message, collector, coordinator); // there should be no further chained operators since this is a terminal operator. - return Collections.emptyList(); + return CompletableFuture.completedFuture(Collections.emptyList()); } @Override diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/StreamTableJoinOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/StreamTableJoinOperatorImpl.java index e3fc266..ec4d45d 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/StreamTableJoinOperatorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/StreamTableJoinOperatorImpl.java @@ -18,6 +18,9 @@ */ package org.apache.samza.operators.impl; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import org.apache.samza.context.Context; import org.apache.samza.operators.KV; import org.apache.samza.operators.spec.OperatorSpec; @@ -55,26 +58,30 @@ class StreamTableJoinOperatorImpl<K, M, R extends KV, JM> extends OperatorImpl<M } @Override - public Collection<JM> handleMessage(M message, MessageCollector collector, TaskCoordinator coordinator) { + protected CompletionStage<Collection<JM>> handleMessageAsync(M message, MessageCollector collector, + TaskCoordinator coordinator) { if (message == null) { - return Collections.emptyList(); + return CompletableFuture.completedFuture(Collections.emptyList()); } K key = joinOpSpec.getJoinFn().getMessageKey(message); - Object recordValue = null; - if (key != null) { - recordValue = table.get(key); - } + return Optional.ofNullable(key) + .map(joinKey -> table.getAsync(joinKey) + .thenApply(val -> getJoinOutput(joinKey, val, message))) + .orElseGet(() -> CompletableFuture.completedFuture(getJoinOutput(key, null, message))); + } - R record = recordValue != null ? (R) KV.of(key, recordValue) : null; - JM output = joinOpSpec.getJoinFn().apply(message, record); + private Collection<JM> getJoinOutput(K key, Object value, M message) { + JM output = Optional.ofNullable(value) + .map(val -> (R) KV.of(key, val)) + .map(record -> joinOpSpec.getJoinFn().apply(message, record)) + .orElseGet(() -> joinOpSpec.getJoinFn().apply(message, null)); // The support for inner and outer join will be provided in the jonFn. For inner join, the joinFn might // return null, when the corresponding record is absent in the table. return output != null ? - Collections.singletonList(output) - : Collections.emptyList(); + Collections.singletonList(output) : Collections.emptyList(); } @Override diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/TriggerScheduler.java b/samza-core/src/main/java/org/apache/samza/operators/impl/TriggerScheduler.java index 952d9f1..7628d33 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/TriggerScheduler.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/TriggerScheduler.java @@ -19,6 +19,8 @@ package org.apache.samza.operators.impl; +import java.util.Queue; +import java.util.concurrent.PriorityBlockingQueue; import org.apache.samza.operators.triggers.Cancellable; import org.apache.samza.util.Clock; import org.slf4j.Logger; @@ -26,7 +28,6 @@ import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.List; -import java.util.PriorityQueue; /** * Allows to schedule and cancel callbacks for triggers. @@ -35,11 +36,11 @@ public class TriggerScheduler<WK> { private static final Logger LOG = LoggerFactory.getLogger(TriggerScheduler.class); - private final PriorityQueue<TriggerCallbackState<WK>> pendingCallbacks; + private final Queue<TriggerCallbackState<WK>> pendingCallbacks; private final Clock clock; public TriggerScheduler(Clock clock) { - this.pendingCallbacks = new PriorityQueue<>(); + this.pendingCallbacks = new PriorityBlockingQueue<>(); this.clock = clock; } diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/WindowOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/WindowOperatorImpl.java index 0241d9e..f2b5914 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/WindowOperatorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/WindowOperatorImpl.java @@ -21,6 +21,9 @@ package org.apache.samza.operators.impl; import com.google.common.base.Preconditions; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.ConcurrentHashMap; import org.apache.samza.context.Context; import org.apache.samza.operators.functions.FoldLeftFunction; import org.apache.samza.operators.functions.MapFunction; @@ -53,7 +56,6 @@ import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -96,7 +98,7 @@ public class WindowOperatorImpl<M, K> extends OperatorImpl<M, WindowPane<K, Obje private final MapFunction<M, K> keyFn; private final TriggerScheduler<K> triggerScheduler; - private final Map<TriggerKey<K>, TriggerImplHandler> triggers = new HashMap<>(); + private final Map<TriggerKey<K>, TriggerImplHandler> triggers = new ConcurrentHashMap<>(); private TimeSeriesStore<K, Object> timeSeriesStore; public WindowOperatorImpl(WindowOperatorSpec<M, K, Object> windowOpSpec, Clock clock) { @@ -134,8 +136,8 @@ public class WindowOperatorImpl<M, K> extends OperatorImpl<M, WindowPane<K, Obje } @Override - public Collection<WindowPane<K, Object>> handleMessage( - M message, MessageCollector collector, TaskCoordinator coordinator) { + protected CompletionStage<Collection<WindowPane<K, Object>>> handleMessageAsync(M message, MessageCollector collector, + TaskCoordinator coordinator) { LOG.trace("Processing message envelope: {}", message); List<WindowPane<K, Object>> results = new ArrayList<>(); @@ -177,7 +179,7 @@ public class WindowOperatorImpl<M, K> extends OperatorImpl<M, WindowPane<K, Obje maybeTriggeredPane.ifPresent(results::add); } - return results; + return CompletableFuture.completedFuture(results); } @Override diff --git a/samza-core/src/main/java/org/apache/samza/operators/spec/AsyncFlatMapOperatorSpec.java b/samza-core/src/main/java/org/apache/samza/operators/spec/AsyncFlatMapOperatorSpec.java new file mode 100644 index 0000000..968859b --- /dev/null +++ b/samza-core/src/main/java/org/apache/samza/operators/spec/AsyncFlatMapOperatorSpec.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.operators.spec; + +import org.apache.samza.operators.functions.AsyncFlatMapFunction; +import org.apache.samza.operators.functions.ScheduledFunction; +import org.apache.samza.operators.functions.WatermarkFunction; + + +/** + * The spec for an operator that transforms each input message to a collection of output messages. + * + * @param <M> type of input message + * @param <OM> type of output messages + */ +public class AsyncFlatMapOperatorSpec<M, OM> extends OperatorSpec<M, OM> { + protected final AsyncFlatMapFunction<M, OM> transformFn; + + /** + * Constructor for a {@link AsyncFlatMapOperatorSpec}. + * + * @param transformFn the transformation function + * @param opId the unique ID for this {@link OperatorSpec} + */ + AsyncFlatMapOperatorSpec(AsyncFlatMapFunction<M, OM> transformFn, String opId) { + super(OpCode.FLAT_MAP, opId); + this.transformFn = transformFn; + } + + @Override + public WatermarkFunction getWatermarkFn() { + return this.transformFn instanceof WatermarkFunction ? (WatermarkFunction) this.transformFn : null; + } + + @Override + public ScheduledFunction getScheduledFn() { + return this.transformFn instanceof ScheduledFunction ? (ScheduledFunction) this.transformFn : null; + } + + public AsyncFlatMapFunction<M, OM> getTransformFn() { + return this.transformFn; + } +} diff --git a/samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpec.java b/samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpec.java index 1886d1b..a047889 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpec.java +++ b/samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpec.java @@ -51,7 +51,8 @@ public abstract class OperatorSpec<M, OM> implements Serializable { MERGE, PARTITION_BY, OUTPUT, - BROADCAST + BROADCAST, + ASYNC_FLAT_MAP } private final String opId; diff --git a/samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java b/samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java index ff3fe67..89d6b38 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java +++ b/samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java @@ -20,6 +20,7 @@ package org.apache.samza.operators.spec; import org.apache.samza.operators.KV; +import org.apache.samza.operators.functions.AsyncFlatMapFunction; import org.apache.samza.operators.functions.FilterFunction; import org.apache.samza.operators.functions.FlatMapFunction; import org.apache.samza.system.descriptors.InputTransformer; @@ -96,6 +97,21 @@ public class OperatorSpecs { } /** + * Creates a {@link AsyncFlatMapOperatorSpec} for {@link AsyncFlatMapFunction}. + * + * @param asyncFlatMapFn the transformation function + * @param opId the unique ID of the operator + * @param <M> type of input message + * @param <OM> type of output message + * @return the {@link AsyncFlatMapOperatorSpec} + */ + public static <M, OM> AsyncFlatMapOperatorSpec<M, OM> createAsyncOperatorSpec( + AsyncFlatMapFunction<? super M, ? extends OM> asyncFlatMapFn, String opId) { + return new AsyncFlatMapOperatorSpec<>((AsyncFlatMapFunction<M, OM>) asyncFlatMapFn, opId); + } + + + /** * Creates a {@link SinkOperatorSpec} for the sink operator. * * @param sinkFn the sink function provided by the user diff --git a/samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java b/samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java index 87131d7..387ad2e 100644 --- a/samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java +++ b/samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java @@ -18,6 +18,10 @@ */ package org.apache.samza.task; +import com.google.common.base.Preconditions; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import org.apache.samza.SamzaException; import org.apache.samza.context.Context; import org.apache.samza.operators.OperatorSpecGraph; import org.apache.samza.operators.impl.InputOperatorImpl; @@ -37,7 +41,7 @@ import org.slf4j.LoggerFactory; * A {@link StreamTask} implementation that brings all the operator API implementation components together and * feeds the input messages into the user-defined transformation chains in {@link OperatorSpecGraph}. */ -public class StreamOperatorTask implements StreamTask, InitableTask, WindowableTask, ClosableTask { +public class StreamOperatorTask implements AsyncStreamTask, InitableTask, WindowableTask, ClosableTask { private static final Logger LOG = LoggerFactory.getLogger(StreamOperatorTask.class); private final OperatorSpecGraph specGraph; @@ -90,34 +94,56 @@ public class StreamOperatorTask implements StreamTask, InitableTask, WindowableT * @param ime incoming message envelope to process * @param collector the collector to send messages with * @param coordinator the coordinator to request commits or shutdown + * @param callback the task callback handle */ @Override - public final void process(IncomingMessageEnvelope ime, MessageCollector collector, TaskCoordinator coordinator) { + public final void processAsync(IncomingMessageEnvelope ime, MessageCollector collector, TaskCoordinator coordinator, + TaskCallback callback) { SystemStream systemStream = ime.getSystemStreamPartition().getSystemStream(); InputOperatorImpl inputOpImpl = operatorImplGraph.getInputOperator(systemStream); if (inputOpImpl != null) { - switch (MessageType.of(ime.getMessage())) { + CompletionStage<Void> processFuture; + MessageType messageType = MessageType.of(ime.getMessage()); + switch (messageType) { case USER_MESSAGE: - inputOpImpl.onMessage(ime, collector, coordinator); + processFuture = inputOpImpl.onMessageAsync(ime, collector, coordinator); break; case END_OF_STREAM: EndOfStreamMessage eosMessage = (EndOfStreamMessage) ime.getMessage(); - inputOpImpl.aggregateEndOfStream(eosMessage, ime.getSystemStreamPartition(), collector, coordinator); + processFuture = + inputOpImpl.aggregateEndOfStream(eosMessage, ime.getSystemStreamPartition(), collector, coordinator); break; case WATERMARK: WatermarkMessage watermarkMessage = (WatermarkMessage) ime.getMessage(); - inputOpImpl.aggregateWatermark(watermarkMessage, ime.getSystemStreamPartition(), collector, coordinator); + processFuture = + inputOpImpl.aggregateWatermark(watermarkMessage, ime.getSystemStreamPartition(), collector, coordinator); + break; + + default: + processFuture = failedFuture(new SamzaException("Unknown message type " + messageType + " encountered.")); break; } + + processFuture.whenComplete((val, ex) -> { + if (ex != null) { + callback.failure(ex); + } else { + callback.complete(); + } + }); } } @Override public final void window(MessageCollector collector, TaskCoordinator coordinator) { - operatorImplGraph.getAllInputOperators() - .forEach(inputOperator -> inputOperator.onTimer(collector, coordinator)); + CompletableFuture<Void> windowFuture = CompletableFuture.allOf(operatorImplGraph.getAllInputOperators() + .stream() + .map(inputOperator -> inputOperator.onTimer(collector, coordinator)) + .toArray(CompletableFuture[]::new)); + + windowFuture.join(); } @Override @@ -131,4 +157,12 @@ public class StreamOperatorTask implements StreamTask, InitableTask, WindowableT OperatorImplGraph getOperatorImplGraph() { return this.operatorImplGraph; } + + private static CompletableFuture<Void> failedFuture(Throwable ex) { + Preconditions.checkNotNull(ex); + CompletableFuture<Void> failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(ex); + + return failedFuture; + } } diff --git a/samza-core/src/main/java/org/apache/samza/task/TaskFactoryUtil.java b/samza-core/src/main/java/org/apache/samza/task/TaskFactoryUtil.java index b2297e1..7ad1a0a 100644 --- a/samza-core/src/main/java/org/apache/samza/task/TaskFactoryUtil.java +++ b/samza-core/src/main/java/org/apache/samza/task/TaskFactoryUtil.java @@ -48,7 +48,7 @@ public class TaskFactoryUtil { if (appDesc instanceof TaskApplicationDescriptorImpl) { return ((TaskApplicationDescriptorImpl) appDesc).getTaskFactory(); } else if (appDesc instanceof StreamApplicationDescriptorImpl) { - return (StreamTaskFactory) () -> new StreamOperatorTask( + return (AsyncStreamTaskFactory) () -> new StreamOperatorTask( ((StreamApplicationDescriptorImpl) appDesc).getOperatorSpecGraph()); } throw new IllegalArgumentException(String.format("ApplicationDescriptorImpl has to be either TaskApplicationDescriptorImpl or " diff --git a/samza-core/src/test/java/org/apache/samza/operators/TestJoinOperator.java b/samza-core/src/test/java/org/apache/samza/operators/TestJoinOperator.java index 90e59ae..b310e6f 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/TestJoinOperator.java +++ b/samza-core/src/test/java/org/apache/samza/operators/TestJoinOperator.java @@ -41,6 +41,7 @@ import org.apache.samza.system.SystemStream; import org.apache.samza.system.SystemStreamPartition; import org.apache.samza.task.MessageCollector; import org.apache.samza.task.StreamOperatorTask; +import org.apache.samza.task.TaskCallback; import org.apache.samza.task.TaskCoordinator; import org.apache.samza.testUtils.StreamTestUtils; import org.apache.samza.testUtils.TestClock; @@ -69,6 +70,7 @@ public class TestJoinOperator { private static final Duration JOIN_TTL = Duration.ofMinutes(10); private final TaskCoordinator taskCoordinator = mock(TaskCoordinator.class); + private final TaskCallback taskCallback = mock(TaskCallback.class); private final Set<Integer> numbers = ImmutableSet.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); @Test @@ -79,9 +81,9 @@ public class TestJoinOperator { MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage()); // push messages to first stream - numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new FirstStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); // push messages to second stream with same keys - numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new SecondStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); int outputSum = output.stream().reduce(0, (s, m) -> s + m); assertEquals(110, outputSum); @@ -118,7 +120,7 @@ public class TestJoinOperator { MessageCollector messageCollector = mock(MessageCollector.class); // push messages to first stream - numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new FirstStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); // close should not be called till now sot.close(); @@ -137,9 +139,9 @@ public class TestJoinOperator { MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage()); // push messages to second stream - numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new SecondStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); // push messages to first stream with same keys - numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new FirstStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); int outputSum = output.stream().reduce(0, (s, m) -> s + m); assertEquals(110, outputSum); @@ -153,9 +155,9 @@ public class TestJoinOperator { MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage()); // push messages to first stream - numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new FirstStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); // push messages to second stream with different keys - numbers.forEach(n -> sot.process(new SecondStreamIME(n + 100, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new SecondStreamIME(n + 100, n), messageCollector, taskCoordinator, taskCallback)); assertTrue(output.isEmpty()); } @@ -168,9 +170,9 @@ public class TestJoinOperator { MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage()); // push messages to second stream - numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new SecondStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); // push messages to first stream with different keys - numbers.forEach(n -> sot.process(new FirstStreamIME(n + 100, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new FirstStreamIME(n + 100, n), messageCollector, taskCoordinator, taskCallback)); assertTrue(output.isEmpty()); } @@ -183,11 +185,11 @@ public class TestJoinOperator { MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage()); // push messages to first stream - numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new FirstStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); // push messages to first stream again with same keys but different values - numbers.forEach(n -> sot.process(new FirstStreamIME(n, 2 * n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new FirstStreamIME(n, 2 * n), messageCollector, taskCoordinator, taskCallback)); // push messages to second stream with same key - numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new SecondStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); int outputSum = output.stream().reduce(0, (s, m) -> s + m); assertEquals(165, outputSum); // should use latest messages in the first stream @@ -201,11 +203,11 @@ public class TestJoinOperator { MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage()); // push messages to second stream - numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new SecondStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); // push messages to second stream again with same keys but different values - numbers.forEach(n -> sot.process(new SecondStreamIME(n, 2 * n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new SecondStreamIME(n, 2 * n), messageCollector, taskCoordinator, taskCallback)); // push messages to first stream with same key - numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new FirstStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); int outputSum = output.stream().reduce(0, (s, m) -> s + m); assertEquals(165, outputSum); // should use latest messages in the second stream @@ -219,9 +221,9 @@ public class TestJoinOperator { MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage()); // push messages to first stream - numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new FirstStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); // push messages to second stream with same key - numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new SecondStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); int outputSum = output.stream().reduce(0, (s, m) -> s + m); assertEquals(110, outputSum); @@ -229,7 +231,7 @@ public class TestJoinOperator { output.clear(); // push messages to first stream with same keys once again. - numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new FirstStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); int newOutputSum = output.stream().reduce(0, (s, m) -> s + m); assertEquals(110, newOutputSum); // should produce the same output as before } @@ -242,9 +244,9 @@ public class TestJoinOperator { MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage()); // push messages to first stream - numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new FirstStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); // push messages to second stream with same key - numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new SecondStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); int outputSum = output.stream().reduce(0, (s, m) -> s + m); assertEquals(110, outputSum); @@ -252,7 +254,7 @@ public class TestJoinOperator { output.clear(); // push messages to second stream with same keys once again. - numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new SecondStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); int newOutputSum = output.stream().reduce(0, (s, m) -> s + m); assertEquals(110, newOutputSum); // should produce the same output as before } @@ -266,13 +268,13 @@ public class TestJoinOperator { MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage()); // push messages to first stream - numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new FirstStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); testClock.advanceTime(JOIN_TTL.plus(Duration.ofMinutes(1))); // 1 minute after ttl sot.window(messageCollector, taskCoordinator); // should expire first stream messages // push messages to second stream with same key - numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new SecondStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); assertTrue(output.isEmpty()); } @@ -286,13 +288,13 @@ public class TestJoinOperator { MessageCollector messageCollector = envelope -> output.add((Integer) envelope.getMessage()); // push messages to second stream - numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new SecondStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); testClock.advanceTime(JOIN_TTL.plus(Duration.ofMinutes(1))); // 1 minute after ttl sot.window(messageCollector, taskCoordinator); // should expire second stream messages // push messages to first stream with same key - numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + numbers.forEach(n -> sot.processAsync(new FirstStreamIME(n, n), messageCollector, taskCoordinator, taskCallback)); assertTrue(output.isEmpty()); } diff --git a/samza-core/src/test/java/org/apache/samza/operators/impl/TestStreamOperatorImpl.java b/samza-core/src/test/java/org/apache/samza/operators/impl/TestFlatmapOperatorImpl.java similarity index 90% rename from samza-core/src/test/java/org/apache/samza/operators/impl/TestStreamOperatorImpl.java rename to samza-core/src/test/java/org/apache/samza/operators/impl/TestFlatmapOperatorImpl.java index ae05305..378f574 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/impl/TestStreamOperatorImpl.java +++ b/samza-core/src/test/java/org/apache/samza/operators/impl/TestFlatmapOperatorImpl.java @@ -35,7 +35,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -public class TestStreamOperatorImpl { +public class TestFlatmapOperatorImpl { @Test @SuppressWarnings("unchecked") @@ -43,8 +43,8 @@ public class TestStreamOperatorImpl { StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> mockOp = mock(StreamOperatorSpec.class); FlatMapFunction<TestMessageEnvelope, TestOutputMessageEnvelope> txfmFn = mock(FlatMapFunction.class); when(mockOp.getTransformFn()).thenReturn(txfmFn); - StreamOperatorImpl<TestMessageEnvelope, TestOutputMessageEnvelope> opImpl = - new StreamOperatorImpl<>(mockOp); + FlatmapOperatorImpl<TestMessageEnvelope, TestOutputMessageEnvelope> opImpl = + new FlatmapOperatorImpl<>(mockOp); TestMessageEnvelope inMsg = mock(TestMessageEnvelope.class); Collection<TestOutputMessageEnvelope> mockOutputs = mock(Collection.class); when(txfmFn.apply(inMsg)).thenReturn(mockOutputs); @@ -62,8 +62,8 @@ public class TestStreamOperatorImpl { FlatMapFunction<TestMessageEnvelope, TestOutputMessageEnvelope> txfmFn = mock(FlatMapFunction.class); when(mockOp.getTransformFn()).thenReturn(txfmFn); - StreamOperatorImpl<TestMessageEnvelope, TestOutputMessageEnvelope> opImpl = - new StreamOperatorImpl<>(mockOp); + FlatmapOperatorImpl<TestMessageEnvelope, TestOutputMessageEnvelope> opImpl = + new FlatmapOperatorImpl<>(mockOp); // ensure that close is not called yet verify(txfmFn, times(0)).close(); diff --git a/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java b/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java index 6279960..aea35aa 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java +++ b/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java @@ -21,6 +21,8 @@ package org.apache.samza.operators.impl; import java.util.Collection; import java.util.Collections; import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import org.apache.samza.context.Context; import org.apache.samza.context.InternalTaskContext; import org.apache.samza.context.MockContext; @@ -85,13 +87,15 @@ public class TestOperatorImpl { // register a couple of operators OperatorImpl mockNextOpImpl1 = mock(OperatorImpl.class); when(mockNextOpImpl1.getOperatorSpec()).thenReturn(new TestOpSpec()); - when(mockNextOpImpl1.handleMessage(anyObject(), anyObject(), anyObject())).thenReturn(Collections.emptyList()); + when(mockNextOpImpl1.handleMessageAsync(anyObject(), anyObject(), anyObject())) + .thenReturn(CompletableFuture.completedFuture(Collections.emptyList())); mockNextOpImpl1.init(this.internalTaskContext); opImpl.registerNextOperator(mockNextOpImpl1); OperatorImpl mockNextOpImpl2 = mock(OperatorImpl.class); when(mockNextOpImpl2.getOperatorSpec()).thenReturn(new TestOpSpec()); - when(mockNextOpImpl2.handleMessage(anyObject(), anyObject(), anyObject())).thenReturn(Collections.emptyList()); + when(mockNextOpImpl2.handleMessageAsync(anyObject(), anyObject(), anyObject())) + .thenReturn(CompletableFuture.completedFuture(Collections.emptyList())); mockNextOpImpl2.init(this.internalTaskContext); opImpl.registerNextOperator(mockNextOpImpl2); @@ -101,8 +105,8 @@ public class TestOperatorImpl { opImpl.onMessage(mock(Object.class), mockCollector, mockCoordinator); // verify that it propagates its handleMessage results to next operators - verify(mockNextOpImpl1, times(1)).handleMessage(mockTestOpImplOutput, mockCollector, mockCoordinator); - verify(mockNextOpImpl2, times(1)).handleMessage(mockTestOpImplOutput, mockCollector, mockCoordinator); + verify(mockNextOpImpl1, times(1)).handleMessageAsync(mockTestOpImplOutput, mockCollector, mockCoordinator); + verify(mockNextOpImpl2, times(1)).handleMessageAsync(mockTestOpImplOutput, mockCollector, mockCoordinator); } @Test @@ -137,13 +141,15 @@ public class TestOperatorImpl { // register a couple of operators OperatorImpl mockNextOpImpl1 = mock(OperatorImpl.class); when(mockNextOpImpl1.getOperatorSpec()).thenReturn(new TestOpSpec()); - when(mockNextOpImpl1.handleMessage(anyObject(), anyObject(), anyObject())).thenReturn(Collections.emptyList()); + when(mockNextOpImpl1.handleMessageAsync(anyObject(), anyObject(), anyObject())) + .thenReturn(CompletableFuture.completedFuture(Collections.emptyList())); mockNextOpImpl1.init(this.internalTaskContext); opImpl.registerNextOperator(mockNextOpImpl1); OperatorImpl mockNextOpImpl2 = mock(OperatorImpl.class); when(mockNextOpImpl2.getOperatorSpec()).thenReturn(new TestOpSpec()); - when(mockNextOpImpl2.handleMessage(anyObject(), anyObject(), anyObject())).thenReturn(Collections.emptyList()); + when(mockNextOpImpl2.handleMessageAsync(anyObject(), anyObject(), anyObject())) + .thenReturn(CompletableFuture.completedFuture(Collections.emptyList())); mockNextOpImpl2.init(this.internalTaskContext); opImpl.registerNextOperator(mockNextOpImpl2); @@ -153,8 +159,8 @@ public class TestOperatorImpl { opImpl.onTimer(mockCollector, mockCoordinator); // verify that it propagates its handleTimer results to next operators - verify(mockNextOpImpl1, times(1)).handleMessage(mockTestOpImplOutput, mockCollector, mockCoordinator); - verify(mockNextOpImpl2, times(1)).handleMessage(mockTestOpImplOutput, mockCollector, mockCoordinator); + verify(mockNextOpImpl1, times(1)).handleMessageAsync(mockTestOpImplOutput, mockCollector, mockCoordinator); + verify(mockNextOpImpl2, times(1)).handleMessageAsync(mockTestOpImplOutput, mockCollector, mockCoordinator); // verify that it propagates the timer tick to next operators verify(mockNextOpImpl1, times(1)).handleTimer(mockCollector, mockCoordinator); @@ -197,9 +203,9 @@ public class TestOperatorImpl { protected void handleInit(Context context) {} @Override - public Collection<Object> handleMessage(Object message, - MessageCollector collector, TaskCoordinator coordinator) { - return Collections.singletonList(mockOutput); + public CompletionStage<Collection<Object>> handleMessageAsync(Object message, MessageCollector collector, + TaskCoordinator coordinator) { + return CompletableFuture.completedFuture(Collections.singletonList(mockOutput)); } @Override diff --git a/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImplGraph.java b/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImplGraph.java index 7fbeb74..61ee7a2 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImplGraph.java +++ b/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImplGraph.java @@ -150,11 +150,11 @@ public class TestOperatorImplGraph { InputOperatorImpl inputOpImpl = opImplGraph.getInputOperator(new SystemStream(inputSystem, inputPhysicalName)); assertEquals(1, inputOpImpl.registeredOperators.size()); - OperatorImpl filterOpImpl = (StreamOperatorImpl) inputOpImpl.registeredOperators.iterator().next(); + OperatorImpl filterOpImpl = (FlatmapOperatorImpl) inputOpImpl.registeredOperators.iterator().next(); assertEquals(1, filterOpImpl.registeredOperators.size()); assertEquals(OpCode.FILTER, filterOpImpl.getOperatorSpec().getOpCode()); - OperatorImpl mapOpImpl = (StreamOperatorImpl) filterOpImpl.registeredOperators.iterator().next(); + OperatorImpl mapOpImpl = (FlatmapOperatorImpl) filterOpImpl.registeredOperators.iterator().next(); assertEquals(1, mapOpImpl.registeredOperators.size()); assertEquals(OpCode.MAP, mapOpImpl.getOperatorSpec().getOpCode()); diff --git a/samza-core/src/test/java/org/apache/samza/operators/impl/TestStreamTableJoinOperatorImpl.java b/samza-core/src/test/java/org/apache/samza/operators/impl/TestStreamTableJoinOperatorImpl.java index 8fd161b..69950d9 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/impl/TestStreamTableJoinOperatorImpl.java +++ b/samza-core/src/test/java/org/apache/samza/operators/impl/TestStreamTableJoinOperatorImpl.java @@ -18,6 +18,7 @@ */ package org.apache.samza.operators.impl; +import java.util.concurrent.CompletableFuture; import junit.framework.Assert; import org.apache.samza.SamzaException; import org.apache.samza.context.Context; @@ -72,8 +73,8 @@ public class TestStreamTableJoinOperatorImpl { } }); ReadWriteTable table = mock(ReadWriteTable.class); - when(table.get("1")).thenReturn("r1"); - when(table.get("2")).thenReturn(null); + when(table.getAsync("1")).thenReturn(CompletableFuture.completedFuture("r1")); + when(table.getAsync("2")).thenReturn(CompletableFuture.completedFuture(null)); Context context = new MockContext(); when(context.getTaskContext().getTable(tableId)).thenReturn(table); diff --git a/samza-core/src/test/java/org/apache/samza/operators/impl/TestWindowOperator.java b/samza-core/src/test/java/org/apache/samza/operators/impl/TestWindowOperator.java index c588b3c..594cd4a 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/impl/TestWindowOperator.java +++ b/samza-core/src/test/java/org/apache/samza/operators/impl/TestWindowOperator.java @@ -61,6 +61,7 @@ import org.apache.samza.system.SystemStream; import org.apache.samza.system.SystemStreamPartition; import org.apache.samza.task.MessageCollector; import org.apache.samza.task.StreamOperatorTask; +import org.apache.samza.task.TaskCallback; import org.apache.samza.task.TaskCoordinator; import org.apache.samza.testUtils.TestClock; import org.junit.Assert; @@ -74,6 +75,7 @@ import static org.mockito.Mockito.when; public class TestWindowOperator { private final TaskCoordinator taskCoordinator = mock(TaskCoordinator.class); + private final TaskCallback taskCallback = mock(TaskCallback.class); private final List<Integer> integers = ImmutableList.of(1, 2, 1, 2, 1, 2, 1, 2, 3); private Context context; private Config config; @@ -113,7 +115,7 @@ public class TestWindowOperator { task.init(this.context); MessageCollector messageCollector = envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage()); - integers.forEach(n -> task.process(new IntegerEnvelope(n), messageCollector, taskCoordinator)); + integers.forEach(n -> task.processAsync(new IntegerEnvelope(n), messageCollector, taskCoordinator, taskCallback)); testClock.advanceTime(Duration.ofSeconds(1)); task.window(messageCollector, taskCoordinator); @@ -149,7 +151,7 @@ public class TestWindowOperator { envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage()); Assert.assertEquals(windowPanes.size(), 0); - integers.forEach(n -> task.process(new IntegerEnvelope(n), messageCollector, taskCoordinator)); + integers.forEach(n -> task.processAsync(new IntegerEnvelope(n), messageCollector, taskCoordinator, taskCallback)); Assert.assertEquals(windowPanes.size(), 0); testClock.advanceTime(Duration.ofSeconds(1)); @@ -173,7 +175,7 @@ public class TestWindowOperator { StreamOperatorTask task = new StreamOperatorTask(sgb, testClock); task.init(this.context); MessageCollector messageCollector = envelope -> windowPanes.add((WindowPane<Integer, Integer>) envelope.getMessage()); - integers.forEach(n -> task.process(new IntegerEnvelope(n), messageCollector, taskCoordinator)); + integers.forEach(n -> task.processAsync(new IntegerEnvelope(n), messageCollector, taskCoordinator, taskCallback)); testClock.advanceTime(Duration.ofSeconds(1)); task.window(messageCollector, taskCoordinator); @@ -196,7 +198,7 @@ public class TestWindowOperator { MessageCollector messageCollector = envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage()); - integers.forEach(n -> task.process(new IntegerEnvelope(n), messageCollector, taskCoordinator)); + integers.forEach(n -> task.processAsync(new IntegerEnvelope(n), messageCollector, taskCoordinator, taskCallback)); testClock.advanceTime(Duration.ofSeconds(1)); task.window(messageCollector, taskCoordinator); @@ -224,8 +226,8 @@ public class TestWindowOperator { task.init(this.context); MessageCollector messageCollector = envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage()); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); testClock.advanceTime(Duration.ofSeconds(1)); task.window(messageCollector, taskCoordinator); @@ -233,10 +235,10 @@ public class TestWindowOperator { Assert.assertEquals(windowPanes.get(0).getKey().getPaneId(), "1"); Assert.assertEquals(windowPanes.get(0).getKey().getKey(), new Integer(1)); - task.process(new IntegerEnvelope(2), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(2), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(3), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(3), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(2), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(2), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(3), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(3), messageCollector, taskCoordinator, taskCallback); testClock.advanceTime(Duration.ofSeconds(1)); task.window(messageCollector, taskCoordinator); @@ -249,8 +251,8 @@ public class TestWindowOperator { Assert.assertEquals((windowPanes.get(1).getMessage()).size(), 2); Assert.assertEquals((windowPanes.get(2).getMessage()).size(), 2); - task.process(new IntegerEnvelope(2), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(2), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(2), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(2), messageCollector, taskCoordinator, taskCallback); testClock.advanceTime(Duration.ofSeconds(1)); task.window(messageCollector, taskCoordinator); @@ -272,15 +274,15 @@ public class TestWindowOperator { envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage()); task.init(this.context); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); testClock.advanceTime(Duration.ofSeconds(1)); - task.process(new IntegerEnvelope(2), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(2), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(2), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(2), messageCollector, taskCoordinator, taskCallback); - task.process(new IntegerEnvelope(2), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(2), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(2), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(2), messageCollector, taskCoordinator, taskCallback); testClock.advanceTime(Duration.ofSeconds(1)); task.window(messageCollector, taskCoordinator); @@ -303,16 +305,16 @@ public class TestWindowOperator { List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>(); MessageCollector messageCollector = envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage()); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); Assert.assertEquals(windowPanes.size(), 1); Assert.assertEquals(windowPanes.get(0).getKey().getPaneId(), "0"); Assert.assertEquals(windowPanes.get(0).getKey().getKey(), new Integer(1)); Assert.assertEquals(windowPanes.get(0).getFiringType(), FiringType.EARLY); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); Assert.assertEquals(windowPanes.size(), 1); @@ -324,7 +326,7 @@ public class TestWindowOperator { Assert.assertEquals(windowPanes.get(1).getKey().getPaneId(), "0"); Assert.assertEquals(windowPanes.get(1).getFiringType(), FiringType.DEFAULT); - task.process(new IntegerEnvelope(3), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(3), messageCollector, taskCoordinator, taskCallback); testClock.advanceTime(Duration.ofSeconds(1)); task.window(messageCollector, taskCoordinator); @@ -347,8 +349,8 @@ public class TestWindowOperator { List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>(); MessageCollector messageCollector = envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage()); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); //assert that the count trigger fired Assert.assertEquals(windowPanes.size(), 1); @@ -358,9 +360,9 @@ public class TestWindowOperator { //assert that the triggering of the count trigger cancelled the inner timeSinceFirstMessage trigger Assert.assertEquals(windowPanes.size(), 1); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); //advance timer by 500 more millis to enable the default trigger testClock.advanceTime(Duration.ofMillis(500)); @@ -373,7 +375,7 @@ public class TestWindowOperator { Assert.assertEquals(windowPanes.get(1).getKey().getPaneId(), "0"); Assert.assertEquals((windowPanes.get(1).getMessage()).size(), 5); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); //advance timer by 500 millis to enable the inner timeSinceFirstMessage trigger testClock.advanceTime(Duration.ofMillis(500)); @@ -407,24 +409,24 @@ public class TestWindowOperator { StreamOperatorTask task = new StreamOperatorTask(sgb, testClock); task.init(this.context); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); //assert that the count trigger fired Assert.assertEquals(windowPanes.size(), 1); //advance the timer to enable the potential triggering of the inner timeSinceFirstMessage trigger - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); testClock.advanceTime(Duration.ofMillis(500)); //assert that the triggering of the count trigger cancelled the inner timeSinceFirstMessage trigger task.window(messageCollector, taskCoordinator); Assert.assertEquals(windowPanes.size(), 2); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); Assert.assertEquals(windowPanes.size(), 3); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); //advance timer by 500 more millis to enable the default trigger testClock.advanceTime(Duration.ofMillis(500)); task.window(messageCollector, taskCoordinator); @@ -448,7 +450,7 @@ public class TestWindowOperator { Assert.assertEquals(windowPanes.size(), 0); List<Integer> integerList = ImmutableList.of(1, 2, 1, 2, 1); - integerList.forEach(n -> task.process(new IntegerEnvelope(n), messageCollector, taskCoordinator)); + integerList.forEach(n -> task.processAsync(new IntegerEnvelope(n), messageCollector, taskCoordinator, taskCallback)); // early triggers should emit (1,2) and (1,2) in the same window. Assert.assertEquals(windowPanes.size(), 2); @@ -458,7 +460,7 @@ public class TestWindowOperator { final IncomingMessageEnvelope endOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope( new SystemStreamPartition("kafka", "integers", new Partition(0))); - task.process(endOfStream, messageCollector, taskCoordinator); + task.processAsync(endOfStream, messageCollector, taskCoordinator, taskCallback); // end of stream flushes the last entry (1) Assert.assertEquals(windowPanes.size(), 3); @@ -479,18 +481,18 @@ public class TestWindowOperator { MessageCollector messageCollector = envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage()); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); testClock.advanceTime(1000); task.window(messageCollector, taskCoordinator); Assert.assertEquals(windowPanes.size(), 1); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); final IncomingMessageEnvelope endOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope( new SystemStreamPartition("kafka", "integers", new Partition(0))); - task.process(endOfStream, messageCollector, taskCoordinator); + task.processAsync(endOfStream, messageCollector, taskCoordinator, taskCallback); Assert.assertEquals(windowPanes.size(), 2); Assert.assertEquals(windowPanes.get(0).getMessage().size(), 2); verify(taskCoordinator, times(1)).commit(TaskCoordinator.RequestScope.CURRENT_TASK); @@ -510,14 +512,14 @@ public class TestWindowOperator { MessageCollector messageCollector = envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage()); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); - task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); + task.processAsync(new IntegerEnvelope(1), messageCollector, taskCoordinator, taskCallback); final IncomingMessageEnvelope endOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope( new SystemStreamPartition("kafka", "integers", new Partition(0))); - task.process(endOfStream, messageCollector, taskCoordinator); + task.processAsync(endOfStream, messageCollector, taskCoordinator, taskCallback); Assert.assertEquals(windowPanes.size(), 1); Assert.assertEquals(windowPanes.get(0).getMessage().size(), 4); verify(taskCoordinator, times(1)).commit(TaskCoordinator.RequestScope.CURRENT_TASK); diff --git a/samza-core/src/test/java/org/apache/samza/task/TestTaskFactoryUtil.java b/samza-core/src/test/java/org/apache/samza/task/TestTaskFactoryUtil.java index 82fb41c..3e2b89b 100644 --- a/samza-core/src/test/java/org/apache/samza/task/TestTaskFactoryUtil.java +++ b/samza-core/src/test/java/org/apache/samza/task/TestTaskFactoryUtil.java @@ -111,8 +111,8 @@ public class TestTaskFactoryUtil { OperatorSpecGraph mockSpecGraph = mock(OperatorSpecGraph.class); when(mockStreamApp.getOperatorSpecGraph()).thenReturn(mockSpecGraph); TaskFactory streamTaskFactory = TaskFactoryUtil.getTaskFactory(mockStreamApp); - assertTrue(streamTaskFactory instanceof StreamTaskFactory); - StreamTask streamTask = ((StreamTaskFactory) streamTaskFactory).createInstance(); + assertTrue(streamTaskFactory instanceof AsyncStreamTaskFactory); + AsyncStreamTask streamTask = ((AsyncStreamTaskFactory) streamTaskFactory).createInstance(); assertTrue(streamTask instanceof StreamOperatorTask); verify(mockSpecGraph).clone(); } diff --git a/samza-test/src/main/java/org/apache/samza/example/AppWithGlobalConfigExample.java b/samza-test/src/main/java/org/apache/samza/example/AppWithGlobalConfigExample.java index 766b529..ee299b7 100644 --- a/samza-test/src/main/java/org/apache/samza/example/AppWithGlobalConfigExample.java +++ b/samza-test/src/main/java/org/apache/samza/example/AppWithGlobalConfigExample.java @@ -23,6 +23,8 @@ import java.util.HashMap; import org.apache.samza.application.StreamApplication; import org.apache.samza.application.descriptors.StreamApplicationDescriptor; import org.apache.samza.config.Config; +import org.apache.samza.example.models.PageViewCount; +import org.apache.samza.example.models.PageViewEvent; import org.apache.samza.operators.KV; import org.apache.samza.operators.triggers.Triggers; import org.apache.samza.operators.windows.AccumulationMode; @@ -66,37 +68,21 @@ public class AppWithGlobalConfigExample implements StreamApplication { KVSerde.of(new StringSerde(), new JsonSerdeV2<>(PageViewCount.class))); appDescriptor.getInputStream(inputStreamDescriptor) - .window(Windows.<PageViewEvent, String, Integer>keyedTumblingWindow(m -> m.memberId, Duration.ofSeconds(10), () -> 0, (m, c) -> c + 1, + .window(Windows.<PageViewEvent, String, Integer>keyedTumblingWindow(PageViewEvent::getMemberId, Duration.ofSeconds(10), () -> 0, (m, c) -> c + 1, null, null) .setEarlyTrigger(Triggers.repeat(Triggers.count(5))) .setAccumulationMode(AccumulationMode.DISCARDING), "window1") - .map(m -> KV.of(m.getKey().getKey(), new PageViewCount(m))) + .map(m -> KV.of(m.getKey().getKey(), buildPageViewCount(m))) .sendTo(appDescriptor.getOutputStream(outputStreamDescriptor)); appDescriptor.withMetricsReporterFactories(new HashMap<>()); } - class PageViewEvent { - String pageId; - String memberId; - long timestamp; + static PageViewCount buildPageViewCount(WindowPane<String, Integer> windowPane) { + String memberId = windowPane.getKey().getKey(); + long timestamp = Long.valueOf(windowPane.getKey().getPaneId()); + int count = windowPane.getMessage(); - PageViewEvent(String pageId, String memberId, long timestamp) { - this.pageId = pageId; - this.memberId = memberId; - this.timestamp = timestamp; - } - } - - static class PageViewCount { - String memberId; - long timestamp; - int count; - - PageViewCount(WindowPane<String, Integer> m) { - this.memberId = m.getKey().getKey(); - this.timestamp = Long.valueOf(m.getKey().getPaneId()); - this.count = m.getMessage(); - } + return new PageViewCount(memberId, timestamp, count); } } diff --git a/samza-test/src/main/java/org/apache/samza/example/AsyncApplicationExample.java b/samza-test/src/main/java/org/apache/samza/example/AsyncApplicationExample.java new file mode 100644 index 0000000..9ec1dca --- /dev/null +++ b/samza-test/src/main/java/org/apache/samza/example/AsyncApplicationExample.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.example; + +import com.google.common.collect.ImmutableList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import org.apache.samza.application.StreamApplication; +import org.apache.samza.application.descriptors.StreamApplicationDescriptor; +import org.apache.samza.config.Config; +import org.apache.samza.example.models.AdClickEvent; +import org.apache.samza.example.models.EnrichedAdClickEvent; +import org.apache.samza.example.models.Member; +import org.apache.samza.operators.KV; +import org.apache.samza.operators.MessageStream; +import org.apache.samza.operators.OutputStream; +import org.apache.samza.runtime.ApplicationRunner; +import org.apache.samza.runtime.ApplicationRunners; +import org.apache.samza.serializers.JsonSerdeV2; +import org.apache.samza.serializers.KVSerde; +import org.apache.samza.serializers.StringSerde; +import org.apache.samza.system.kafka.descriptors.KafkaInputDescriptor; +import org.apache.samza.system.kafka.descriptors.KafkaOutputDescriptor; +import org.apache.samza.system.kafka.descriptors.KafkaSystemDescriptor; +import org.apache.samza.util.CommandLine; + + +/** + * An illustration of use of async APIs in high level application. + * The following example demonstrates the use of {@link MessageStream#flatMapAsync(org.apache.samza.operators.functions.AsyncFlatMapFunction)}. We use a mock + * member decorator service which returns a future in response to decorate request. Typically, in real world scenarios, + * this mock member service will be replaced with rest call to a remote service. + */ +public class AsyncApplicationExample implements StreamApplication { + + @Override + public void describe(StreamApplicationDescriptor appDescriptor) { + KafkaSystemDescriptor trackingSystem = new KafkaSystemDescriptor("tracking"); + + KafkaInputDescriptor<AdClickEvent> inputStreamDescriptor = + trackingSystem.getInputDescriptor("adClickEvent", new JsonSerdeV2<>(AdClickEvent.class)); + + KafkaOutputDescriptor<KV<String, EnrichedAdClickEvent>> outputStreamDescriptor = + trackingSystem.getOutputDescriptor("enrichedAdClickEvent", + KVSerde.of(new StringSerde(), new JsonSerdeV2<>(EnrichedAdClickEvent.class))); + + MessageStream<AdClickEvent> adClickEventStream = appDescriptor.getInputStream(inputStreamDescriptor); + OutputStream<KV<String, EnrichedAdClickEvent>> enrichedAdClickStream = + appDescriptor.getOutputStream(outputStreamDescriptor); + + adClickEventStream + .flatMapAsync(AsyncApplicationExample::enrichAdClickEvent) + .map(enrichedAdClickEvent -> KV.of(enrichedAdClickEvent.getCountry(), enrichedAdClickEvent)) + .sendTo(enrichedAdClickStream); + } + + public static void main(String[] args) { + CommandLine cmdLine = new CommandLine(); + Config config = cmdLine.loadConfig(cmdLine.parser().parse(args)); + ApplicationRunner runner = ApplicationRunners.getApplicationRunner(new AsyncApplicationExample(), config); + + runner.run(); + runner.waitForFinish(); + } + + private static CompletionStage<Collection<EnrichedAdClickEvent>> enrichAdClickEvent(AdClickEvent adClickEvent) { + CompletionStage<Member> decoratedMemberFuture = MemberDecoratorService.decorateMember(adClickEvent.getMemberId()); + return decoratedMemberFuture + .thenApply(member -> Collections.singleton( + new EnrichedAdClickEvent(adClickEvent.getId(), member.getGender(), member.getCountry()))); + } + + /** + * A mock member decorator service that introduces delay to the member decorate call for illustrating async APIs + * use in high level application. In real world, this component would correspond to a component that makes remote + * calls. + */ + private static class MemberDecoratorService { + private static final String[] GENDER = {"F", "M", "U"}; + private static final List<String> COUNTRY = ImmutableList.of( + "KENYA", + "NEW ZEALAND", + "INDONESIA", + "PERU", + "FRANCE", + "MEXICO"); + private static final Random RANDOM = new Random(); + + static CompletionStage<Member> decorateMember(int memberId) { + return CompletableFuture.supplyAsync(() -> { + /* + * Introduce some lag to mimic remote call. In real use cases, this typically translates to over the wire + * network call to some rest service. + */ + try { + Thread.sleep((long) (Math.random() * 10000)); + } catch (InterruptedException ec) { + System.out.println("Interrupted during sleep"); + } + + return new Member(memberId, getRandomGender(), getRandomCountry()); + }); + } + + static String getRandomGender() { + int index = RANDOM.nextInt(GENDER.length); + return GENDER[index]; + } + + static String getRandomCountry() { + int index = RANDOM.nextInt(COUNTRY.size()); + return COUNTRY.get(index); + } + } +} diff --git a/samza-test/src/main/java/org/apache/samza/example/BroadcastExample.java b/samza-test/src/main/java/org/apache/samza/example/BroadcastExample.java index bf641ce..598e332 100644 --- a/samza-test/src/main/java/org/apache/samza/example/BroadcastExample.java +++ b/samza-test/src/main/java/org/apache/samza/example/BroadcastExample.java @@ -22,6 +22,7 @@ package org.apache.samza.example; import org.apache.samza.application.StreamApplication; import org.apache.samza.application.descriptors.StreamApplicationDescriptor; import org.apache.samza.config.Config; +import org.apache.samza.example.models.PageViewEvent; import org.apache.samza.operators.KV; import org.apache.samza.operators.MessageStream; import org.apache.samza.runtime.ApplicationRunner; @@ -41,7 +42,7 @@ import org.apache.samza.util.CommandLine; public class BroadcastExample implements StreamApplication { // local execution mode - public static void main(String[] args) throws Exception { + public static void main(String[] args) { CommandLine cmdLine = new CommandLine(); Config config = cmdLine.loadConfig(cmdLine.parser().parse(args)); ApplicationRunner runner = ApplicationRunners.getApplicationRunner(new BroadcastExample(), config); @@ -67,14 +68,4 @@ public class BroadcastExample implements StreamApplication { inputStream.filter(m -> m.key.equals("key2")).sendTo(appDescriptor.getOutputStream(outStream2)); inputStream.filter(m -> m.key.equals("key3")).sendTo(appDescriptor.getOutputStream(outStream3)); } - - class PageViewEvent { - String key; - long timestamp; - - public PageViewEvent(String key, long timestamp) { - this.key = key; - this.timestamp = timestamp; - } - } } diff --git a/samza-test/src/main/java/org/apache/samza/example/KeyValueStoreExample.java b/samza-test/src/main/java/org/apache/samza/example/KeyValueStoreExample.java index 444039a..4b3ee38 100644 --- a/samza-test/src/main/java/org/apache/samza/example/KeyValueStoreExample.java +++ b/samza-test/src/main/java/org/apache/samza/example/KeyValueStoreExample.java @@ -26,6 +26,7 @@ import org.apache.samza.application.StreamApplication; import org.apache.samza.application.descriptors.StreamApplicationDescriptor; import org.apache.samza.config.Config; import org.apache.samza.context.Context; +import org.apache.samza.example.models.PageViewEvent; import org.apache.samza.operators.KV; import org.apache.samza.operators.MessageStream; import org.apache.samza.operators.OutputStream; @@ -48,7 +49,7 @@ import org.apache.samza.util.CommandLine; public class KeyValueStoreExample implements StreamApplication { // local execution mode - public static void main(String[] args) throws Exception { + public static void main(String[] args) { CommandLine cmdLine = new CommandLine(); Config config = cmdLine.loadConfig(cmdLine.parser().parse(args)); ApplicationRunner runner = ApplicationRunners.getApplicationRunner(new KeyValueStoreExample(), config); @@ -73,7 +74,7 @@ public class KeyValueStoreExample implements StreamApplication { OutputStream<KV<String, StatsOutput>> pageViewEventPerMember = appDescriptor.getOutputStream(outputStreamDescriptor); pageViewEvents - .partitionBy(pve -> pve.memberId, pve -> pve, + .partitionBy(pve -> pve.getMemberId(), pve -> pve, KVSerde.of(new StringSerde(), new JsonSerdeV2<>(PageViewEvent.class)), "partitionBy") .map(KV::getValue) .flatMap(new MyStatsCounter()) @@ -95,8 +96,8 @@ public class KeyValueStoreExample implements StreamApplication { @Override public Collection<StatsOutput> apply(PageViewEvent message) { List<StatsOutput> outputStats = new ArrayList<>(); - long wndTimestamp = (long) Math.floor(TimeUnit.MILLISECONDS.toMinutes(message.timestamp) / 5) * 5; - String wndKey = String.format("%s-%d", message.memberId, wndTimestamp); + long wndTimestamp = (long) Math.floor(TimeUnit.MILLISECONDS.toMinutes(message.getTimestamp()) / 5) * 5; + String wndKey = String.format("%s-%d", message.getMemberId(), wndTimestamp); StatsWindowState curState = this.statsStore.get(wndKey); if (curState == null) { curState = new StatsWindowState(); @@ -107,7 +108,7 @@ public class KeyValueStoreExample implements StreamApplication { curState.timeAtLastOutput = curTimeMs; curState.lastCount += curState.newCount; curState.newCount = 0; - outputStats.add(new StatsOutput(message.memberId, wndTimestamp, curState.lastCount)); + outputStats.add(new StatsOutput(message.getMemberId(), wndTimestamp, curState.lastCount)); } // update counter w/o generating output this.statsStore.put(wndKey, curState); @@ -121,18 +122,6 @@ public class KeyValueStoreExample implements StreamApplication { } } - class PageViewEvent { - String pageId; - String memberId; - long timestamp; - - PageViewEvent(String pageId, String memberId, long timestamp) { - this.pageId = pageId; - this.memberId = memberId; - this.timestamp = timestamp; - } - } - static class StatsOutput { private String memberId; private long timestamp; diff --git a/samza-test/src/main/java/org/apache/samza/example/MergeExample.java b/samza-test/src/main/java/org/apache/samza/example/MergeExample.java index e3eee23..d4da7a5 100644 --- a/samza-test/src/main/java/org/apache/samza/example/MergeExample.java +++ b/samza-test/src/main/java/org/apache/samza/example/MergeExample.java @@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableList; import org.apache.samza.application.StreamApplication; import org.apache.samza.application.descriptors.StreamApplicationDescriptor; import org.apache.samza.config.Config; +import org.apache.samza.example.models.PageViewEvent; import org.apache.samza.operators.MessageStream; import org.apache.samza.runtime.ApplicationRunner; import org.apache.samza.runtime.ApplicationRunners; @@ -39,7 +40,7 @@ import org.apache.samza.util.CommandLine; public class MergeExample implements StreamApplication { // local execution mode - public static void main(String[] args) throws Exception { + public static void main(String[] args) { CommandLine cmdLine = new CommandLine(); Config config = cmdLine.loadConfig(cmdLine.parser().parse(args)); ApplicationRunner runner = ApplicationRunners.getApplicationRunner(new MergeExample(), config); @@ -67,9 +68,4 @@ public class MergeExample implements StreamApplication { .mergeAll(ImmutableList.of(appDescriptor.getInputStream(isd1), appDescriptor.getInputStream(isd2), appDescriptor.getInputStream(isd3))) .sendTo(appDescriptor.getOutputStream(osd)); } - - class PageViewEvent { - String pageId; - long viewTimestamp; - } } \ No newline at end of file diff --git a/samza-test/src/main/java/org/apache/samza/example/PageViewCounterExample.java b/samza-test/src/main/java/org/apache/samza/example/PageViewCounterExample.java index 5fe7b9c..60dc2af 100644 --- a/samza-test/src/main/java/org/apache/samza/example/PageViewCounterExample.java +++ b/samza-test/src/main/java/org/apache/samza/example/PageViewCounterExample.java @@ -22,6 +22,8 @@ import java.time.Duration; import org.apache.samza.application.descriptors.StreamApplicationDescriptor; import org.apache.samza.application.StreamApplication; import org.apache.samza.config.Config; +import org.apache.samza.example.models.PageViewCount; +import org.apache.samza.example.models.PageViewEvent; import org.apache.samza.operators.KV; import org.apache.samza.operators.MessageStream; import org.apache.samza.operators.OutputStream; @@ -74,34 +76,18 @@ public class PageViewCounterExample implements StreamApplication { SupplierFunction<Integer> initialValue = () -> 0; FoldLeftFunction<PageViewEvent, Integer> foldLeftFn = (m, c) -> c + 1; pageViewEvents - .window(Windows.keyedTumblingWindow(m -> m.memberId, Duration.ofSeconds(10), initialValue, foldLeftFn, null, null) + .window(Windows.keyedTumblingWindow(PageViewEvent::getMemberId, Duration.ofSeconds(10), initialValue, foldLeftFn, null, null) .setEarlyTrigger(Triggers.repeat(Triggers.count(5))) .setAccumulationMode(AccumulationMode.DISCARDING), "tumblingWindow") - .map(windowPane -> KV.of(windowPane.getKey().getKey(), new PageViewCount(windowPane))) + .map(windowPane -> KV.of(windowPane.getKey().getKey(), buildPageViewCount(windowPane))) .sendTo(pageViewEventPerMemberStream); } - class PageViewEvent { - String pageId; - String memberId; - long timestamp; + static PageViewCount buildPageViewCount(WindowPane<String, Integer> windowPane) { + String memberId = windowPane.getKey().getKey(); + long timestamp = Long.valueOf(windowPane.getKey().getPaneId()); + int count = windowPane.getMessage(); - PageViewEvent(String pageId, String memberId, long timestamp) { - this.pageId = pageId; - this.memberId = memberId; - this.timestamp = timestamp; - } - } - - static class PageViewCount { - String memberId; - long timestamp; - int count; - - PageViewCount(WindowPane<String, Integer> m) { - this.memberId = m.getKey().getKey(); - this.timestamp = Long.valueOf(m.getKey().getPaneId()); - this.count = m.getMessage(); - } + return new PageViewCount(memberId, timestamp, count); } } diff --git a/samza-test/src/main/java/org/apache/samza/example/RepartitionExample.java b/samza-test/src/main/java/org/apache/samza/example/RepartitionExample.java index 19403b0..1f5c91b 100644 --- a/samza-test/src/main/java/org/apache/samza/example/RepartitionExample.java +++ b/samza-test/src/main/java/org/apache/samza/example/RepartitionExample.java @@ -22,6 +22,7 @@ import java.time.Duration; import org.apache.samza.application.StreamApplication; import org.apache.samza.application.descriptors.StreamApplicationDescriptor; import org.apache.samza.config.Config; +import org.apache.samza.example.models.PageViewEvent; import org.apache.samza.operators.KV; import org.apache.samza.operators.MessageStream; import org.apache.samza.operators.OutputStream; @@ -44,7 +45,7 @@ import org.apache.samza.util.CommandLine; public class RepartitionExample implements StreamApplication { // local execution mode - public static void main(String[] args) throws Exception { + public static void main(String[] args) { CommandLine cmdLine = new CommandLine(); Config config = cmdLine.loadConfig(cmdLine.parser().parse(args)); ApplicationRunner runner = ApplicationRunners.getApplicationRunner(new RepartitionExample(), config); @@ -69,7 +70,7 @@ public class RepartitionExample implements StreamApplication { OutputStream<KV<String, MyStreamOutput>> pageViewEventPerMember = appDescriptor.getOutputStream(outputStreamDescriptor); pageViewEvents - .partitionBy(pve -> pve.memberId, pve -> pve, + .partitionBy(pve -> pve.getMemberId(), pve -> pve, KVSerde.of(new StringSerde(), new JsonSerdeV2<>(PageViewEvent.class)), "partitionBy") .window(Windows.keyedTumblingWindow( KV::getKey, Duration.ofMinutes(5), () -> 0, (m, c) -> c + 1, null, null), "window") @@ -77,18 +78,6 @@ public class RepartitionExample implements StreamApplication { .sendTo(pageViewEventPerMember); } - class PageViewEvent { - String pageId; - String memberId; - long timestamp; - - PageViewEvent(String pageId, String memberId, long timestamp) { - this.pageId = pageId; - this.memberId = memberId; - this.timestamp = timestamp; - } - } - static class MyStreamOutput { String memberId; long timestamp; diff --git a/samza-test/src/main/java/org/apache/samza/example/WindowExample.java b/samza-test/src/main/java/org/apache/samza/example/WindowExample.java index 426fd8d..d73b30f 100644 --- a/samza-test/src/main/java/org/apache/samza/example/WindowExample.java +++ b/samza-test/src/main/java/org/apache/samza/example/WindowExample.java @@ -23,6 +23,7 @@ import java.time.Duration; import org.apache.samza.application.StreamApplication; import org.apache.samza.application.descriptors.StreamApplicationDescriptor; import org.apache.samza.config.Config; +import org.apache.samza.example.models.PageViewEvent; import org.apache.samza.operators.MessageStream; import org.apache.samza.operators.OutputStream; import org.apache.samza.operators.functions.FoldLeftFunction; @@ -80,14 +81,4 @@ public class WindowExample implements StreamApplication { .map(WindowPane::getMessage) .sendTo(outputStream); } - - class PageViewEvent { - String key; - long timestamp; - - public PageViewEvent(String key, long timestamp) { - this.key = key; - this.timestamp = timestamp; - } - } } diff --git a/samza-test/src/main/java/org/apache/samza/example/models/AdClickEvent.java b/samza-test/src/main/java/org/apache/samza/example/models/AdClickEvent.java new file mode 100644 index 0000000..46cf94a --- /dev/null +++ b/samza-test/src/main/java/org/apache/samza/example/models/AdClickEvent.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.example.models; + +public class AdClickEvent { + private int id; + private int memberId; + + public AdClickEvent(int id, int memberId) { + this.id = id; + this.memberId = memberId; + } + + public int getId() { + return id; + } + + public int getMemberId() { + return memberId; + } +} diff --git a/samza-test/src/main/java/org/apache/samza/example/models/EnrichedAdClickEvent.java b/samza-test/src/main/java/org/apache/samza/example/models/EnrichedAdClickEvent.java new file mode 100644 index 0000000..4a3b874 --- /dev/null +++ b/samza-test/src/main/java/org/apache/samza/example/models/EnrichedAdClickEvent.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.example.models; + +public class EnrichedAdClickEvent { + private int id; + private String gender; + private String country; + + public EnrichedAdClickEvent(int id, String gender, String country) { + this.id = id; + this.gender = gender; + this.country = country; + } + + public int getId() { + return id; + } + + public String getGender() { + return gender; + } + + public String getCountry() { + return country; + } +} diff --git a/samza-test/src/main/java/org/apache/samza/example/models/Member.java b/samza-test/src/main/java/org/apache/samza/example/models/Member.java new file mode 100644 index 0000000..e321bbc --- /dev/null +++ b/samza-test/src/main/java/org/apache/samza/example/models/Member.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.example.models; + +public class Member { + private int memberId; + private String gender; + private String country; + + public Member(int memberId, String gender, String country) { + this.memberId = memberId; + this.gender = gender; + this.country = country; + } + + public int getMemberId() { + return memberId; + } + + public String getGender() { + return gender; + } + + public String getCountry() { + return country; + } +} diff --git a/samza-test/src/main/java/org/apache/samza/example/models/PageViewCount.java b/samza-test/src/main/java/org/apache/samza/example/models/PageViewCount.java new file mode 100644 index 0000000..eeb5d24 --- /dev/null +++ b/samza-test/src/main/java/org/apache/samza/example/models/PageViewCount.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.example.models; + + +public class PageViewCount { + private String memberId; + private long timestamp; + private int count; + + public PageViewCount(String memberId, long timestamp, int count) { + this.memberId = memberId; + this.timestamp = timestamp; + this.count = count; + } + + public String getMemberId() { + return memberId; + } + + public long getTimestamp() { + return timestamp; + } + + public int getCount() { + return count; + } +} diff --git a/samza-test/src/main/java/org/apache/samza/example/models/PageViewEvent.java b/samza-test/src/main/java/org/apache/samza/example/models/PageViewEvent.java new file mode 100644 index 0000000..fba1bfb --- /dev/null +++ b/samza-test/src/main/java/org/apache/samza/example/models/PageViewEvent.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.example.models; + +public class PageViewEvent { + private String pageId; + private String memberId; + private long timestamp; + + public PageViewEvent(String pageId, String memberId, long timestamp) { + this.pageId = pageId; + this.memberId = memberId; + this.timestamp = timestamp; + } + + public String getPageId() { + return pageId; + } + + public String getMemberId() { + return memberId; + } + + public long getTimestamp() { + return timestamp; + } +} diff --git a/samza-test/src/test/java/org/apache/samza/test/controlmessages/WatermarkIntegrationTest.java b/samza-test/src/test/java/org/apache/samza/test/controlmessages/WatermarkIntegrationTest.java index 9436fb2..63dc254 100644 --- a/samza-test/src/test/java/org/apache/samza/test/controlmessages/WatermarkIntegrationTest.java +++ b/samza-test/src/test/java/org/apache/samza/test/controlmessages/WatermarkIntegrationTest.java @@ -21,7 +21,6 @@ package org.apache.samza.test.controlmessages; import scala.collection.JavaConverters; -import java.lang.reflect.Field; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -65,7 +64,6 @@ import org.apache.samza.system.SystemConsumer; import org.apache.samza.system.SystemFactory; import org.apache.samza.system.SystemProducer; import org.apache.samza.system.SystemStreamPartition; -import org.apache.samza.task.AsyncStreamTaskAdapter; import org.apache.samza.task.StreamOperatorTask; import org.apache.samza.task.TestStreamOperatorTask; import org.apache.samza.test.controlmessages.TestData.PageView; @@ -198,10 +196,7 @@ public class WatermarkIntegrationTest extends IntegrationTestHarness { Map<TaskName, TaskInstance> taskInstances = JavaConverters.mapAsJavaMapConverter(container.getTaskInstances()).asJava(); Map<String, StreamOperatorTask> tasks = new HashMap<>(); for (Map.Entry<TaskName, TaskInstance> entry : taskInstances.entrySet()) { - AsyncStreamTaskAdapter adapter = (AsyncStreamTaskAdapter) entry.getValue().task(); - Field field = AsyncStreamTaskAdapter.class.getDeclaredField("wrappedTask"); - field.setAccessible(true); - StreamOperatorTask task = (StreamOperatorTask) field.get(adapter); + StreamOperatorTask task = (StreamOperatorTask) entry.getValue().task(); tasks.put(entry.getKey().getTaskName(), task); } return tasks; diff --git a/samza-test/src/test/java/org/apache/samza/test/operator/TestAsyncFlatMap.java b/samza-test/src/test/java/org/apache/samza/test/operator/TestAsyncFlatMap.java new file mode 100644 index 0000000..4afff92 --- /dev/null +++ b/samza-test/src/test/java/org/apache/samza/test/operator/TestAsyncFlatMap.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.test.operator; + +import com.google.common.collect.ImmutableList; +import java.io.Serializable; +import java.time.Duration; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.function.Predicate; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import org.apache.samza.SamzaException; +import org.apache.samza.application.StreamApplication; +import org.apache.samza.application.descriptors.StreamApplicationDescriptor; +import org.apache.samza.config.Config; +import org.apache.samza.config.MapConfig; +import org.apache.samza.config.StreamConfig; +import org.apache.samza.config.TaskConfig; +import org.apache.samza.operators.OutputStream; +import org.apache.samza.serializers.NoOpSerde; +import org.apache.samza.system.kafka.descriptors.KafkaOutputDescriptor; +import org.apache.samza.system.kafka.descriptors.KafkaSystemDescriptor; +import org.apache.samza.test.framework.TestRunner; +import org.apache.samza.test.framework.system.descriptors.InMemoryInputDescriptor; +import org.apache.samza.test.framework.system.descriptors.InMemoryOutputDescriptor; +import org.apache.samza.test.framework.system.descriptors.InMemorySystemDescriptor; +import org.apache.samza.test.harness.IntegrationTestHarness; +import org.apache.samza.test.operator.data.PageView; +import org.junit.Test; + +import static org.junit.Assert.*; + + +public class TestAsyncFlatMap extends IntegrationTestHarness { + private static final String TEST_SYSTEM = "test"; + private static final String PAGE_VIEW_STREAM = "test-async-page-view-stream"; + private static final String NON_GUEST_PAGE_VIEW_STREAM = "test-async-non-guest-page-view-stream"; + private static final String FAIL_PROCESS = "process.fail"; + private static final String FAIL_DOWNSTREAM_OPERATOR = "downstream.operator.fail"; + private static final String LOGIN_PAGE = "login-page"; + private static final String PROCESS_JITTER = "process.jitter"; + + private static final List<PageView> PAGE_VIEWS = ImmutableList.of( + new PageView("1", LOGIN_PAGE, "1"), + new PageView("2", "home-page", "2"), + new PageView("3", "profile-page", "0"), + new PageView("4", LOGIN_PAGE, "0")); + + + @Test + public void testProcessingFutureCompletesSuccessfully() { + List<PageView> expectedPageViews = PAGE_VIEWS.stream() + .filter(pageView -> !pageView.getPageId().equals(LOGIN_PAGE) && Long.valueOf(pageView.getUserId()) > 0) + .collect(Collectors.toList()); + + List<PageView> actualPageViews = runTest(PAGE_VIEWS, new HashMap<>()); + assertEquals("Mismatch between expected vs actual page views", expectedPageViews, actualPageViews); + } + + @Test(expected = SamzaException.class) + public void testProcessingFutureCompletesAfterTaskTimeout() { + Map<String, String> configs = new HashMap<>(); + configs.put(TaskConfig.CALLBACK_TIMEOUT_MS(), "100"); + configs.put(PROCESS_JITTER, "200"); + + runTest(PAGE_VIEWS, configs); + } + + @Test(expected = RuntimeException.class) + public void testProcessingExceptionIsBubbledUp() { + Map<String, String> configs = new HashMap<>(); + configs.put(FAIL_PROCESS, "true"); + + runTest(PAGE_VIEWS, configs); + } + + @Test(expected = RuntimeException.class) + public void testDownstreamOperatorExceptionIsBubbledUp() { + Map<String, String> configs = new HashMap<>(); + configs.put(FAIL_DOWNSTREAM_OPERATOR, "true"); + + runTest(PAGE_VIEWS, configs); + } + + private List<PageView> runTest(List<PageView> pageViews, Map<String, String> configs) { + configs.put(String.format(StreamConfig.SYSTEM_FOR_STREAM_ID(), PAGE_VIEW_STREAM), TEST_SYSTEM); + + InMemorySystemDescriptor isd = new InMemorySystemDescriptor(TEST_SYSTEM); + InMemoryInputDescriptor<PageView> pageViewStreamDesc = isd + .getInputDescriptor(PAGE_VIEW_STREAM, new NoOpSerde<>()); + + + InMemoryOutputDescriptor<PageView> outputStreamDesc = isd + .getOutputDescriptor(NON_GUEST_PAGE_VIEW_STREAM, new NoOpSerde<>()); + + TestRunner + .of(new AsyncFlatMapExample()) + .addInputStream(pageViewStreamDesc, pageViews) + .addOutputStream(outputStreamDesc, 1) + .addConfig(new MapConfig(configs)) + .run(Duration.ofMillis(50000)); + + Map<Integer, List<PageView>> result = TestRunner.consumeStream(outputStreamDesc, Duration.ofMillis(1000)); + List<PageView> results = result.values().stream() + .flatMap(List::stream) + .collect(Collectors.toList()); + + return results; + } + + static class AsyncFlatMapExample implements StreamApplication { + @Override + public void describe(StreamApplicationDescriptor appDescriptor) { + Config config = appDescriptor.getConfig(); + KafkaSystemDescriptor kafkaSystemDescriptor = new KafkaSystemDescriptor(TEST_SYSTEM); + KafkaOutputDescriptor<PageView> + outputDescriptor = kafkaSystemDescriptor.getOutputDescriptor(NON_GUEST_PAGE_VIEW_STREAM, new NoOpSerde<>()); + OutputStream<PageView> nonGuestPageViewStream = appDescriptor.getOutputStream(outputDescriptor); + + Predicate<PageView> failProcess = (Predicate<PageView> & Serializable) (ignored) -> config.getBoolean(FAIL_PROCESS, false); + Predicate<PageView> failDownstreamOperator = (Predicate<PageView> & Serializable) (ignored) -> config.getBoolean(FAIL_DOWNSTREAM_OPERATOR, false); + Supplier<Long> processJitter = (Supplier<Long> & Serializable) () -> config.getLong(PROCESS_JITTER, 100); + + appDescriptor.getInputStream(kafkaSystemDescriptor.getInputDescriptor(PAGE_VIEW_STREAM, new NoOpSerde<PageView>())) + .flatMapAsync(pageView -> filterGuestPageViews(pageView, failProcess, processJitter)) + .filter(pageView -> filterLoginPageViews(pageView, failDownstreamOperator)) + .sendTo(nonGuestPageViewStream); + } + + private static CompletionStage<Collection<PageView>> filterGuestPageViews(PageView pageView, + Predicate<PageView> shouldFailProcess, Supplier<Long> processJitter) { + CompletableFuture<Collection<PageView>> filteredPageViews = CompletableFuture.supplyAsync(() -> { + try { + Thread.sleep(processJitter.get()); + } catch (InterruptedException ex) { + System.out.println("Interrupted during sleep."); + } + + return Long.valueOf(pageView.getUserId()) < 1 ? Collections.emptyList() : Collections.singleton(pageView); + }); + + if (shouldFailProcess.test(pageView)) { + filteredPageViews.completeExceptionally(new RuntimeException("Remote service threw an exception")); + } + + return filteredPageViews; + } + + private static boolean filterLoginPageViews(PageView pageView, Predicate<PageView> shouldFailProcess) { + if (shouldFailProcess.test(pageView)) { + throw new RuntimeException("Filtering login page views ran into an exception"); + } + + return !LOGIN_PAGE.equals(pageView.getPageId()); + } + + } +} diff --git a/samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlRemoteTable.java b/samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlRemoteTable.java index d962b14..94efa79 100644 --- a/samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlRemoteTable.java +++ b/samza-test/src/test/java/org/apache/samza/test/samzasql/TestSamzaSqlRemoteTable.java @@ -32,6 +32,7 @@ import org.apache.samza.sql.util.JsonUtil; import org.apache.samza.sql.util.SamzaSqlTestConfig; import org.apache.samza.sql.util.RemoteStoreIOResolverTestFactory; import org.junit.Assert; +import org.junit.Ignore; import org.junit.Test; @@ -53,6 +54,7 @@ public class TestSamzaSqlRemoteTable extends SamzaSqlIntegrationTestHarness { } @Test + @Ignore("Disabled due to flakiness related to data generation; Refer Pull Request #905 for details") public void testSinkEndToEndWithKeyWithNullRecords() { int numMessages = 20;