http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-core/src/test/java/org/apache/samza/operators/TestStreamGraphSpec.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/operators/TestStreamGraphSpec.java b/samza-core/src/test/java/org/apache/samza/operators/TestStreamGraphSpec.java new file mode 100644 index 0000000..e476abc --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/operators/TestStreamGraphSpec.java @@ -0,0 +1,601 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.samza.operators; + +import com.google.common.collect.ImmutableList; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; + +import org.apache.samza.SamzaException; +import org.apache.samza.config.Config; +import org.apache.samza.config.JobConfig; +import org.apache.samza.operators.data.TestMessageEnvelope; +import org.apache.samza.operators.spec.InputOperatorSpec; +import org.apache.samza.operators.spec.OperatorSpec.OpCode; +import org.apache.samza.operators.spec.OutputStreamImpl; +import org.apache.samza.operators.stream.IntermediateMessageStreamImpl; +import org.apache.samza.runtime.ApplicationRunner; +import org.apache.samza.serializers.KVSerde; +import org.apache.samza.serializers.NoOpSerde; +import org.apache.samza.serializers.Serde; +import org.apache.samza.system.StreamSpec; +import org.apache.samza.table.TableSpec; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TestStreamGraphSpec { + + @Test + public void testGetInputStreamWithValueSerde() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec); + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + + Serde mockValueSerde = mock(Serde.class); + MessageStream<TestMessageEnvelope> inputStream = graphSpec.getInputStream("test-stream-1", mockValueSerde); + + InputOperatorSpec<String, TestMessageEnvelope> inputOpSpec = + (InputOperatorSpec) ((MessageStreamImpl<TestMessageEnvelope>) inputStream).getOperatorSpec(); + 
assertEquals(OpCode.INPUT, inputOpSpec.getOpCode()); + assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), inputOpSpec); + assertEquals(mockStreamSpec, inputOpSpec.getStreamSpec()); + assertTrue(inputOpSpec.getKeySerde() instanceof NoOpSerde); + assertEquals(mockValueSerde, inputOpSpec.getValueSerde()); + } + + @Test + public void testGetInputStreamWithKeyValueSerde() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec); + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + + KVSerde mockKVSerde = mock(KVSerde.class); + Serde mockKeySerde = mock(Serde.class); + Serde mockValueSerde = mock(Serde.class); + doReturn(mockKeySerde).when(mockKVSerde).getKeySerde(); + doReturn(mockValueSerde).when(mockKVSerde).getValueSerde(); + MessageStream<TestMessageEnvelope> inputStream = graphSpec.getInputStream("test-stream-1", mockKVSerde); + + InputOperatorSpec<String, TestMessageEnvelope> inputOpSpec = + (InputOperatorSpec) ((MessageStreamImpl<TestMessageEnvelope>) inputStream).getOperatorSpec(); + assertEquals(OpCode.INPUT, inputOpSpec.getOpCode()); + assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), inputOpSpec); + assertEquals(mockStreamSpec, inputOpSpec.getStreamSpec()); + assertEquals(mockKeySerde, inputOpSpec.getKeySerde()); + assertEquals(mockValueSerde, inputOpSpec.getValueSerde()); + } + + @Test(expected = NullPointerException.class) + public void testGetInputStreamWithNullSerde() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec); + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + + graphSpec.getInputStream("test-stream-1", null); + } + + @Test + public void 
testGetInputStreamWithDefaultValueSerde() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec); + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + + Serde mockValueSerde = mock(Serde.class); + graphSpec.setDefaultSerde(mockValueSerde); + MessageStream<TestMessageEnvelope> inputStream = graphSpec.getInputStream("test-stream-1"); + + InputOperatorSpec<String, TestMessageEnvelope> inputOpSpec = + (InputOperatorSpec) ((MessageStreamImpl<TestMessageEnvelope>) inputStream).getOperatorSpec(); + assertEquals(OpCode.INPUT, inputOpSpec.getOpCode()); + assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), inputOpSpec); + assertEquals(mockStreamSpec, inputOpSpec.getStreamSpec()); + assertTrue(inputOpSpec.getKeySerde() instanceof NoOpSerde); + assertEquals(mockValueSerde, inputOpSpec.getValueSerde()); + } + + @Test + public void testGetInputStreamWithDefaultKeyValueSerde() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec); + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + + KVSerde mockKVSerde = mock(KVSerde.class); + Serde mockKeySerde = mock(Serde.class); + Serde mockValueSerde = mock(Serde.class); + doReturn(mockKeySerde).when(mockKVSerde).getKeySerde(); + doReturn(mockValueSerde).when(mockKVSerde).getValueSerde(); + graphSpec.setDefaultSerde(mockKVSerde); + MessageStream<TestMessageEnvelope> inputStream = graphSpec.getInputStream("test-stream-1"); + + InputOperatorSpec<String, TestMessageEnvelope> inputOpSpec = + (InputOperatorSpec) ((MessageStreamImpl<TestMessageEnvelope>) inputStream).getOperatorSpec(); + assertEquals(OpCode.INPUT, inputOpSpec.getOpCode()); + 
assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), inputOpSpec); + assertEquals(mockStreamSpec, inputOpSpec.getStreamSpec()); + assertEquals(mockKeySerde, inputOpSpec.getKeySerde()); + assertEquals(mockValueSerde, inputOpSpec.getValueSerde()); + } + + @Test + public void testGetInputStreamWithDefaultDefaultSerde() { + // default default serde == user hasn't provided a default serde + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec); + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + + MessageStream<TestMessageEnvelope> inputStream = graphSpec.getInputStream("test-stream-1"); + + InputOperatorSpec<String, TestMessageEnvelope> inputOpSpec = + (InputOperatorSpec) ((MessageStreamImpl<TestMessageEnvelope>) inputStream).getOperatorSpec(); + assertEquals(OpCode.INPUT, inputOpSpec.getOpCode()); + assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), inputOpSpec); + assertEquals(mockStreamSpec, inputOpSpec.getStreamSpec()); + assertTrue(inputOpSpec.getKeySerde() instanceof NoOpSerde); + assertTrue(inputOpSpec.getValueSerde() instanceof NoOpSerde); + } + + @Test + public void testGetInputStreamWithRelaxedTypes() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec); + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + + MessageStream<TestMessageEnvelope> inputStream = graphSpec.getInputStream("test-stream-1"); + + InputOperatorSpec<String, TestMessageEnvelope> inputOpSpec = + (InputOperatorSpec) ((MessageStreamImpl<TestMessageEnvelope>) inputStream).getOperatorSpec(); + assertEquals(OpCode.INPUT, inputOpSpec.getOpCode()); + assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), inputOpSpec); + 
assertEquals(mockStreamSpec, inputOpSpec.getStreamSpec()); + } + + @Test + public void testMultipleGetInputStreams() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + StreamSpec mockStreamSpec1 = mock(StreamSpec.class); + StreamSpec mockStreamSpec2 = mock(StreamSpec.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec1); + when(mockRunner.getStreamSpec("test-stream-2")).thenReturn(mockStreamSpec2); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + MessageStream<Object> inputStream1 = graphSpec.getInputStream("test-stream-1"); + MessageStream<Object> inputStream2 = graphSpec.getInputStream("test-stream-2"); + + InputOperatorSpec<String, TestMessageEnvelope> inputOpSpec1 = + (InputOperatorSpec) ((MessageStreamImpl<Object>) inputStream1).getOperatorSpec(); + InputOperatorSpec<String, TestMessageEnvelope> inputOpSpec2 = + (InputOperatorSpec) ((MessageStreamImpl<Object>) inputStream2).getOperatorSpec(); + + assertEquals(graphSpec.getInputOperators().size(), 2); + assertEquals(graphSpec.getInputOperators().get(mockStreamSpec1), inputOpSpec1); + assertEquals(graphSpec.getInputOperators().get(mockStreamSpec2), inputOpSpec2); + } + + @Test(expected = IllegalStateException.class) + public void testGetSameInputStreamTwice() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mock(StreamSpec.class)); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + graphSpec.getInputStream("test-stream-1"); + // should throw exception + graphSpec.getInputStream("test-stream-1"); + } + + @Test + public void testGetOutputStreamWithValueSerde() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, 
mock(Config.class)); + + Serde mockValueSerde = mock(Serde.class); + OutputStream<TestMessageEnvelope> outputStream = + graphSpec.getOutputStream("test-stream-1", mockValueSerde); + + OutputStreamImpl<TestMessageEnvelope> outputStreamImpl = (OutputStreamImpl) outputStream; + assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), outputStreamImpl); + assertEquals(mockStreamSpec, outputStreamImpl.getStreamSpec()); + assertTrue(outputStreamImpl.getKeySerde() instanceof NoOpSerde); + assertEquals(mockValueSerde, outputStreamImpl.getValueSerde()); + } + + @Test + public void testGetOutputStreamWithKeyValueSerde() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + KVSerde mockKVSerde = mock(KVSerde.class); + Serde mockKeySerde = mock(Serde.class); + Serde mockValueSerde = mock(Serde.class); + doReturn(mockKeySerde).when(mockKVSerde).getKeySerde(); + doReturn(mockValueSerde).when(mockKVSerde).getValueSerde(); + graphSpec.setDefaultSerde(mockKVSerde); + OutputStream<TestMessageEnvelope> outputStream = graphSpec.getOutputStream("test-stream-1", mockKVSerde); + + OutputStreamImpl<TestMessageEnvelope> outputStreamImpl = (OutputStreamImpl) outputStream; + assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), outputStreamImpl); + assertEquals(mockStreamSpec, outputStreamImpl.getStreamSpec()); + assertEquals(mockKeySerde, outputStreamImpl.getKeySerde()); + assertEquals(mockValueSerde, outputStreamImpl.getValueSerde()); + } + + @Test(expected = NullPointerException.class) + public void testGetOutputStreamWithNullSerde() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec); + + StreamGraphSpec 
graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + + graphSpec.getOutputStream("test-stream-1", null); + } + + @Test + public void testGetOutputStreamWithDefaultValueSerde() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec); + + Serde mockValueSerde = mock(Serde.class); + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + graphSpec.setDefaultSerde(mockValueSerde); + OutputStream<TestMessageEnvelope> outputStream = graphSpec.getOutputStream("test-stream-1"); + + OutputStreamImpl<TestMessageEnvelope> outputStreamImpl = (OutputStreamImpl) outputStream; + assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), outputStreamImpl); + assertEquals(mockStreamSpec, outputStreamImpl.getStreamSpec()); + assertTrue(outputStreamImpl.getKeySerde() instanceof NoOpSerde); + assertEquals(mockValueSerde, outputStreamImpl.getValueSerde()); + } + + @Test + public void testGetOutputStreamWithDefaultKeyValueSerde() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + KVSerde mockKVSerde = mock(KVSerde.class); + Serde mockKeySerde = mock(Serde.class); + Serde mockValueSerde = mock(Serde.class); + doReturn(mockKeySerde).when(mockKVSerde).getKeySerde(); + doReturn(mockValueSerde).when(mockKVSerde).getValueSerde(); + graphSpec.setDefaultSerde(mockKVSerde); + + OutputStream<TestMessageEnvelope> outputStream = graphSpec.getOutputStream("test-stream-1"); + + OutputStreamImpl<TestMessageEnvelope> outputStreamImpl = (OutputStreamImpl) outputStream; + assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), outputStreamImpl); + assertEquals(mockStreamSpec, 
outputStreamImpl.getStreamSpec()); + assertEquals(mockKeySerde, outputStreamImpl.getKeySerde()); + assertEquals(mockValueSerde, outputStreamImpl.getValueSerde()); + } + + @Test + public void testGetOutputStreamWithDefaultDefaultSerde() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mockStreamSpec); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + + OutputStream<TestMessageEnvelope> outputStream = graphSpec.getOutputStream("test-stream-1"); + + OutputStreamImpl<TestMessageEnvelope> outputStreamImpl = (OutputStreamImpl) outputStream; + assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), outputStreamImpl); + assertEquals(mockStreamSpec, outputStreamImpl.getStreamSpec()); + assertTrue(outputStreamImpl.getKeySerde() instanceof NoOpSerde); + assertTrue(outputStreamImpl.getValueSerde() instanceof NoOpSerde); + } + + @Test(expected = IllegalStateException.class) + public void testSetDefaultSerdeAfterGettingStreams() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mock(StreamSpec.class)); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + graphSpec.getInputStream("test-stream-1"); + graphSpec.setDefaultSerde(mock(Serde.class)); // should throw exception + } + + @Test(expected = IllegalStateException.class) + public void testSetDefaultSerdeAfterGettingOutputStream() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mock(StreamSpec.class)); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + graphSpec.getOutputStream("test-stream-1"); + graphSpec.setDefaultSerde(mock(Serde.class)); // should throw exception + } + + @Test(expected = IllegalStateException.class) + public void 
testSetDefaultSerdeAfterGettingIntermediateStream() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mock(StreamSpec.class)); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + graphSpec.getIntermediateStream("test-stream-1", null); + graphSpec.setDefaultSerde(mock(Serde.class)); // should throw exception + } + + @Test(expected = IllegalStateException.class) + public void testGetSameOutputStreamTwice() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mock(StreamSpec.class)); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + graphSpec.getOutputStream("test-stream-1"); + graphSpec.getOutputStream("test-stream-1"); // should throw exception + } + + @Test + public void testGetIntermediateStreamWithValueSerde() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + Config mockConfig = mock(Config.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + String mockStreamName = "mockStreamName"; + when(mockRunner.getStreamSpec(mockStreamName)).thenReturn(mockStreamSpec); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig); + + Serde mockValueSerde = mock(Serde.class); + IntermediateMessageStreamImpl<TestMessageEnvelope> intermediateStreamImpl = + graphSpec.getIntermediateStream(mockStreamName, mockValueSerde); + + assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), intermediateStreamImpl.getOperatorSpec()); + assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), intermediateStreamImpl.getOutputStream()); + assertEquals(mockStreamSpec, intermediateStreamImpl.getStreamSpec()); + assertTrue(intermediateStreamImpl.getOutputStream().getKeySerde() instanceof NoOpSerde); + assertEquals(mockValueSerde, intermediateStreamImpl.getOutputStream().getValueSerde()); + 
assertTrue(((InputOperatorSpec) intermediateStreamImpl.getOperatorSpec()).getKeySerde() instanceof NoOpSerde); + assertEquals(mockValueSerde, ((InputOperatorSpec) intermediateStreamImpl.getOperatorSpec()).getValueSerde()); + } + + @Test + public void testGetIntermediateStreamWithKeyValueSerde() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + Config mockConfig = mock(Config.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + String mockStreamName = "mockStreamName"; + when(mockRunner.getStreamSpec(mockStreamName)).thenReturn(mockStreamSpec); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig); + + KVSerde mockKVSerde = mock(KVSerde.class); + Serde mockKeySerde = mock(Serde.class); + Serde mockValueSerde = mock(Serde.class); + doReturn(mockKeySerde).when(mockKVSerde).getKeySerde(); + doReturn(mockValueSerde).when(mockKVSerde).getValueSerde(); + IntermediateMessageStreamImpl<TestMessageEnvelope> intermediateStreamImpl = + graphSpec.getIntermediateStream(mockStreamName, mockKVSerde); + + assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), intermediateStreamImpl.getOperatorSpec()); + assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), intermediateStreamImpl.getOutputStream()); + assertEquals(mockStreamSpec, intermediateStreamImpl.getStreamSpec()); + assertEquals(mockKeySerde, intermediateStreamImpl.getOutputStream().getKeySerde()); + assertEquals(mockValueSerde, intermediateStreamImpl.getOutputStream().getValueSerde()); + assertEquals(mockKeySerde, ((InputOperatorSpec) intermediateStreamImpl.getOperatorSpec()).getKeySerde()); + assertEquals(mockValueSerde, ((InputOperatorSpec) intermediateStreamImpl.getOperatorSpec()).getValueSerde()); + } + + @Test + public void testGetIntermediateStreamWithDefaultValueSerde() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + Config mockConfig = mock(Config.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + String 
mockStreamName = "mockStreamName"; + when(mockRunner.getStreamSpec(mockStreamName)).thenReturn(mockStreamSpec); + + StreamGraphSpec graph = new StreamGraphSpec(mockRunner, mockConfig); + + Serde mockValueSerde = mock(Serde.class); + graph.setDefaultSerde(mockValueSerde); + IntermediateMessageStreamImpl<TestMessageEnvelope> intermediateStreamImpl = + graph.getIntermediateStream(mockStreamName, null); + + assertEquals(graph.getInputOperators().get(mockStreamSpec), intermediateStreamImpl.getOperatorSpec()); + assertEquals(graph.getOutputStreams().get(mockStreamSpec), intermediateStreamImpl.getOutputStream()); + assertEquals(mockStreamSpec, intermediateStreamImpl.getStreamSpec()); + assertTrue(intermediateStreamImpl.getOutputStream().getKeySerde() instanceof NoOpSerde); + assertEquals(mockValueSerde, intermediateStreamImpl.getOutputStream().getValueSerde()); + assertTrue(((InputOperatorSpec) intermediateStreamImpl.getOperatorSpec()).getKeySerde() instanceof NoOpSerde); + assertEquals(mockValueSerde, ((InputOperatorSpec) intermediateStreamImpl.getOperatorSpec()).getValueSerde()); + } + + @Test + public void testGetIntermediateStreamWithDefaultKeyValueSerde() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + Config mockConfig = mock(Config.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + String mockStreamName = "mockStreamName"; + when(mockRunner.getStreamSpec(mockStreamName)).thenReturn(mockStreamSpec); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig); + + KVSerde mockKVSerde = mock(KVSerde.class); + Serde mockKeySerde = mock(Serde.class); + Serde mockValueSerde = mock(Serde.class); + doReturn(mockKeySerde).when(mockKVSerde).getKeySerde(); + doReturn(mockValueSerde).when(mockKVSerde).getValueSerde(); + graphSpec.setDefaultSerde(mockKVSerde); + IntermediateMessageStreamImpl<TestMessageEnvelope> intermediateStreamImpl = + graphSpec.getIntermediateStream(mockStreamName, null); + + 
assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), intermediateStreamImpl.getOperatorSpec()); + assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), intermediateStreamImpl.getOutputStream()); + assertEquals(mockStreamSpec, intermediateStreamImpl.getStreamSpec()); + assertEquals(mockKeySerde, intermediateStreamImpl.getOutputStream().getKeySerde()); + assertEquals(mockValueSerde, intermediateStreamImpl.getOutputStream().getValueSerde()); + assertEquals(mockKeySerde, ((InputOperatorSpec) intermediateStreamImpl.getOperatorSpec()).getKeySerde()); + assertEquals(mockValueSerde, ((InputOperatorSpec) intermediateStreamImpl.getOperatorSpec()).getValueSerde()); + } + + @Test + public void testGetIntermediateStreamWithDefaultDefaultSerde() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + Config mockConfig = mock(Config.class); + StreamSpec mockStreamSpec = mock(StreamSpec.class); + String mockStreamName = "mockStreamName"; + when(mockRunner.getStreamSpec(mockStreamName)).thenReturn(mockStreamSpec); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig); + IntermediateMessageStreamImpl<TestMessageEnvelope> intermediateStreamImpl = + graphSpec.getIntermediateStream(mockStreamName, null); + + assertEquals(graphSpec.getInputOperators().get(mockStreamSpec), intermediateStreamImpl.getOperatorSpec()); + assertEquals(graphSpec.getOutputStreams().get(mockStreamSpec), intermediateStreamImpl.getOutputStream()); + assertEquals(mockStreamSpec, intermediateStreamImpl.getStreamSpec()); + assertTrue(intermediateStreamImpl.getOutputStream().getKeySerde() instanceof NoOpSerde); + assertTrue(intermediateStreamImpl.getOutputStream().getValueSerde() instanceof NoOpSerde); + assertTrue(((InputOperatorSpec) intermediateStreamImpl.getOperatorSpec()).getKeySerde() instanceof NoOpSerde); + assertTrue(((InputOperatorSpec) intermediateStreamImpl.getOperatorSpec()).getValueSerde() instanceof NoOpSerde); + } + + @Test(expected = 
IllegalStateException.class) + public void testGetSameIntermediateStreamTwice() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(mock(StreamSpec.class)); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); + graphSpec.getIntermediateStream("test-stream-1", mock(Serde.class)); + graphSpec.getIntermediateStream("test-stream-1", mock(Serde.class)); + } + + @Test + public void testGetNextOpIdIncrementsId() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + Config mockConfig = mock(Config.class); + when(mockConfig.get(eq(JobConfig.JOB_NAME()))).thenReturn("jobName"); + when(mockConfig.get(eq(JobConfig.JOB_ID()), anyString())).thenReturn("1234"); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig); + assertEquals("jobName-1234-merge-0", graphSpec.getNextOpId(OpCode.MERGE, null)); + assertEquals("jobName-1234-join-customName", graphSpec.getNextOpId(OpCode.JOIN, "customName")); + assertEquals("jobName-1234-map-2", graphSpec.getNextOpId(OpCode.MAP, null)); + } + + @Test(expected = SamzaException.class) + public void testGetNextOpIdRejectsDuplicates() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + Config mockConfig = mock(Config.class); + when(mockConfig.get(eq(JobConfig.JOB_NAME()))).thenReturn("jobName"); + when(mockConfig.get(eq(JobConfig.JOB_ID()), anyString())).thenReturn("1234"); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig); + assertEquals("jobName-1234-join-customName", graphSpec.getNextOpId(OpCode.JOIN, "customName")); + graphSpec.getNextOpId(OpCode.JOIN, "customName"); // should throw + } + + @Test + public void testUserDefinedIdValidation() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + Config mockConfig = mock(Config.class); + when(mockConfig.get(eq(JobConfig.JOB_NAME()))).thenReturn("jobName"); + when(mockConfig.get(eq(JobConfig.JOB_ID()), 
anyString())).thenReturn("1234"); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig); + + // null and empty userDefinedIDs should fall back to autogenerated IDs. + try { + graphSpec.getNextOpId(OpCode.FILTER, null); + graphSpec.getNextOpId(OpCode.FILTER, ""); + graphSpec.getNextOpId(OpCode.FILTER, " "); + graphSpec.getNextOpId(OpCode.FILTER, "\t"); + } catch (SamzaException e) { + fail("Received an error with a null or empty operator ID instead of defaulting to auto-generated ID."); + } + + List<String> validOpIds = ImmutableList.of("op.id", "op_id", "op-id", "1000", "op_1", "OP_ID"); + for (String validOpId: validOpIds) { + try { + graphSpec.getNextOpId(OpCode.FILTER, validOpId); + } catch (Exception e) { + fail("Received an exception with a valid operator ID: " + validOpId); + } + } + + List<String> invalidOpIds = ImmutableList.of("op id", "op#id"); + for (String invalidOpId: invalidOpIds) { + try { + graphSpec.getNextOpId(OpCode.FILTER, invalidOpId); + fail("Did not receive an exception with an invalid operator ID: " + invalidOpId); + } catch (SamzaException e) { } + } + } + + @Test + public void testGetInputStreamPreservesInsertionOrder() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + Config mockConfig = mock(Config.class); + + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig); + + StreamSpec testStreamSpec1 = new StreamSpec("test-stream-1", "physical-stream-1", "test-system"); + when(mockRunner.getStreamSpec("test-stream-1")).thenReturn(testStreamSpec1); + + StreamSpec testStreamSpec2 = new StreamSpec("test-stream-2", "physical-stream-2", "test-system"); + when(mockRunner.getStreamSpec("test-stream-2")).thenReturn(testStreamSpec2); + + StreamSpec testStreamSpec3 = new StreamSpec("test-stream-3", "physical-stream-3", "test-system"); + when(mockRunner.getStreamSpec("test-stream-3")).thenReturn(testStreamSpec3); + + graphSpec.getInputStream("test-stream-1"); + 
graphSpec.getInputStream("test-stream-2"); + graphSpec.getInputStream("test-stream-3"); + + List<InputOperatorSpec> inputSpecs = new ArrayList<>(graphSpec.getInputOperators().values()); + assertEquals(inputSpecs.size(), 3); + assertEquals(inputSpecs.get(0).getStreamSpec(), testStreamSpec1); + assertEquals(inputSpecs.get(1).getStreamSpec(), testStreamSpec2); + assertEquals(inputSpecs.get(2).getStreamSpec(), testStreamSpec3); + } + + @Test + public void testGetTable() { + ApplicationRunner mockRunner = mock(ApplicationRunner.class); + Config mockConfig = mock(Config.class); + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig); + + BaseTableDescriptor mockTableDescriptor = mock(BaseTableDescriptor.class); + when(mockTableDescriptor.getTableSpec()).thenReturn( + new TableSpec("t1", KVSerde.of(new NoOpSerde(), new NoOpSerde()), "", new HashMap<>())); + assertNotNull(graphSpec.getTable(mockTableDescriptor)); + } +}
http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-core/src/test/java/org/apache/samza/operators/data/TestOutputMessageEnvelope.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/operators/data/TestOutputMessageEnvelope.java b/samza-core/src/test/java/org/apache/samza/operators/data/TestOutputMessageEnvelope.java index f9537a3..519e5df 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/data/TestOutputMessageEnvelope.java +++ b/samza-core/src/test/java/org/apache/samza/operators/data/TestOutputMessageEnvelope.java @@ -35,5 +35,19 @@ public class TestOutputMessageEnvelope { public String getKey() { return this.key; } + + @Override + public boolean equals(Object other) { + if (!(other instanceof TestOutputMessageEnvelope)) { + return false; + } + TestOutputMessageEnvelope otherMsg = (TestOutputMessageEnvelope) other; + return this.key.equals(otherMsg.key) && this.value.equals(otherMsg.value); + } + + @Override + public int hashCode() { + return String.format("%s:%d", key, value).hashCode(); + } } http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImplGraph.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImplGraph.java b/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImplGraph.java index 2d8d1eb..b87e5ed 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImplGraph.java +++ b/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImplGraph.java @@ -21,11 +21,17 @@ package org.apache.samza.operators.impl; import com.google.common.collect.HashMultimap; import com.google.common.collect.Multimap; +import java.io.Serializable; +import java.time.Duration; +import java.util.ArrayList; import java.util.Collection; +import 
java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Function; import org.apache.samza.Partition; import org.apache.samza.config.Config; import org.apache.samza.config.JobConfig; @@ -39,9 +45,11 @@ import org.apache.samza.job.model.TaskModel; import org.apache.samza.metrics.MetricsRegistryMap; import org.apache.samza.operators.KV; import org.apache.samza.operators.MessageStream; +import org.apache.samza.operators.StreamGraphSpec; import org.apache.samza.operators.OutputStream; -import org.apache.samza.operators.StreamGraphImpl; +import org.apache.samza.operators.functions.ClosableFunction; import org.apache.samza.operators.functions.FilterFunction; +import org.apache.samza.operators.functions.InitableFunction; import org.apache.samza.operators.functions.JoinFunction; import org.apache.samza.operators.functions.MapFunction; import org.apache.samza.operators.impl.store.TimestampedValue; @@ -58,34 +66,160 @@ import org.apache.samza.system.SystemStream; import org.apache.samza.system.SystemStreamPartition; import org.apache.samza.task.MessageCollector; import org.apache.samza.task.TaskContext; +import java.util.List; import org.apache.samza.task.TaskCoordinator; import org.apache.samza.util.Clock; import org.apache.samza.util.SystemClock; +import org.junit.After; import org.junit.Test; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertTrue; -import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class 
TestOperatorImplGraph { + private void addOperatorRecursively(HashSet<OperatorImpl> s, OperatorImpl op) { + List<OperatorImpl> operators = new ArrayList<>(); + operators.add(op); + while (!operators.isEmpty()) { + OperatorImpl opImpl = operators.remove(0); + s.add(opImpl); + if (!opImpl.registeredOperators.isEmpty()) { + operators.addAll(opImpl.registeredOperators); + } + } + } + + static class TestMapFunction<M, OM> extends BaseTestFunction implements MapFunction<M, OM> { + final Function<M, OM> mapFn; + + public TestMapFunction(String opId, Function<M, OM> mapFn) { + super(opId); + this.mapFn = mapFn; + } + + @Override + public OM apply(M message) { + return this.mapFn.apply(message); + } + } + + static class TestJoinFunction<K, M, JM, RM> extends BaseTestFunction implements JoinFunction<K, M, JM, RM> { + final BiFunction<M, JM, RM> joiner; + final Function<M, K> firstKeyFn; + final Function<JM, K> secondKeyFn; + final Collection<RM> joinResults = new HashSet<>(); + + public TestJoinFunction(String opId, BiFunction<M, JM, RM> joiner, Function<M, K> firstKeyFn, Function<JM, K> secondKeyFn) { + super(opId); + this.joiner = joiner; + this.firstKeyFn = firstKeyFn; + this.secondKeyFn = secondKeyFn; + } + + @Override + public RM apply(M message, JM otherMessage) { + RM result = this.joiner.apply(message, otherMessage); + this.joinResults.add(result); + return result; + } + + @Override + public K getFirstKey(M message) { + return this.firstKeyFn.apply(message); + } + + @Override + public K getSecondKey(JM message) { + return this.secondKeyFn.apply(message); + } + } + + static abstract class BaseTestFunction implements InitableFunction, ClosableFunction, Serializable { + + static Map<TaskName, Map<String, BaseTestFunction>> perTaskFunctionMap = new HashMap<>(); + static Map<TaskName, List<String>> perTaskInitList = new HashMap<>(); + static Map<TaskName, List<String>> perTaskCloseList = new HashMap<>(); + int numInitCalled = 0; + int numCloseCalled = 0; + TaskName 
taskName = null; + final String opId; + + public BaseTestFunction(String opId) { + this.opId = opId; + } + + static public void reset() { + perTaskFunctionMap.clear(); + perTaskCloseList.clear(); + perTaskInitList.clear(); + } + + static public BaseTestFunction getInstanceByTaskName(TaskName taskName, String opId) { + return perTaskFunctionMap.get(taskName).get(opId); + } + + static public List<String> getInitListByTaskName(TaskName taskName) { + return perTaskInitList.get(taskName); + } + + static public List<String> getCloseListByTaskName(TaskName taskName) { + return perTaskCloseList.get(taskName); + } + + @Override + public void close() { + if (this.taskName == null) { + throw new IllegalStateException("Close called before init"); + } + if (perTaskFunctionMap.get(this.taskName) == null || !perTaskFunctionMap.get(this.taskName).containsKey(opId)) { + throw new IllegalStateException("Close called before init"); + } + + if (perTaskCloseList.get(this.taskName) == null) { + perTaskCloseList.put(taskName, new ArrayList<String>() { { this.add(opId); } }); + } else { + perTaskCloseList.get(taskName).add(opId); + } + + this.numCloseCalled++; + } + + @Override + public void init(Config config, TaskContext context) { + if (perTaskFunctionMap.get(context.getTaskName()) == null) { + perTaskFunctionMap.put(context.getTaskName(), new HashMap<String, BaseTestFunction>() { { this.put(opId, BaseTestFunction.this); } }); + } else { + if (perTaskFunctionMap.get(context.getTaskName()).containsKey(opId)) { + throw new IllegalStateException(String.format("Multiple init called for op %s in the same task instance %s", opId, this.taskName.getTaskName())); + } + perTaskFunctionMap.get(context.getTaskName()).put(opId, this); + } + if (perTaskInitList.get(context.getTaskName()) == null) { + perTaskInitList.put(context.getTaskName(), new ArrayList<String>() { { this.add(opId); } }); + } else { + perTaskInitList.get(context.getTaskName()).add(opId); + } + this.taskName = 
context.getTaskName(); + this.numInitCalled++; + } + } + + @After + public void tearDown() { + BaseTestFunction.reset(); + } + @Test public void testEmptyChain() { - StreamGraphImpl streamGraph = new StreamGraphImpl(mock(ApplicationRunner.class), mock(Config.class)); + StreamGraphSpec graphSpec = new StreamGraphSpec(mock(ApplicationRunner.class), mock(Config.class)); OperatorImplGraph opGraph = - new OperatorImplGraph(streamGraph, mock(Config.class), mock(TaskContextImpl.class), mock(Clock.class)); + new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), mock(Config.class), mock(TaskContextImpl.class), mock(Clock.class)); assertEquals(0, opGraph.getAllInputOperators().size()); } @@ -94,10 +228,10 @@ public class TestOperatorImplGraph { ApplicationRunner mockRunner = mock(ApplicationRunner.class); when(mockRunner.getStreamSpec(eq("input"))).thenReturn(new StreamSpec("input", "input-stream", "input-system")); when(mockRunner.getStreamSpec(eq("output"))).thenReturn(mock(StreamSpec.class)); - StreamGraphImpl streamGraph = new StreamGraphImpl(mockRunner, mock(Config.class)); + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); - MessageStream<Object> inputStream = streamGraph.getInputStream("input"); - OutputStream<Object> outputStream = streamGraph.getOutputStream("output"); + MessageStream<Object> inputStream = graphSpec.getInputStream("input"); + OutputStream<Object> outputStream = graphSpec.getOutputStream("output"); inputStream .filter(mock(FilterFunction.class)) @@ -108,7 +242,7 @@ public class TestOperatorImplGraph { when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap()); when(mockTaskContext.getTaskName()).thenReturn(new TaskName("task 0")); OperatorImplGraph opImplGraph = - new OperatorImplGraph(streamGraph, mock(Config.class), mockTaskContext, mock(Clock.class)); + new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), mock(Config.class), mockTaskContext, mock(Clock.class)); InputOperatorImpl inputOpImpl 
= opImplGraph.getInputOperator(new SystemStream("input-system", "input-stream")); assertEquals(1, inputOpImpl.registeredOperators.size()); @@ -136,9 +270,9 @@ public class TestOperatorImplGraph { Config mockConfig = mock(Config.class); when(mockConfig.get(JobConfig.JOB_NAME())).thenReturn("jobName"); when(mockConfig.get(eq(JobConfig.JOB_ID()), anyString())).thenReturn("jobId"); - StreamGraphImpl streamGraph = new StreamGraphImpl(mockRunner, mockConfig); - MessageStream<Object> inputStream = streamGraph.getInputStream("input"); - OutputStream<KV<Integer, String>> outputStream = streamGraph + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig); + MessageStream<Object> inputStream = graphSpec.getInputStream("input"); + OutputStream<KV<Integer, String>> outputStream = graphSpec .getOutputStream("output", KVSerde.of(mock(IntegerSerde.class), mock(StringSerde.class))); inputStream @@ -160,7 +294,7 @@ public class TestOperatorImplGraph { new SamzaContainerContext("0", mockConfig, Collections.singleton(new TaskName("task 0")), new MetricsRegistryMap()); when(mockTaskContext.getSamzaContainerContext()).thenReturn(containerContext); OperatorImplGraph opImplGraph = - new OperatorImplGraph(streamGraph, mockConfig, mockTaskContext, mock(Clock.class)); + new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), mockConfig, mockTaskContext, mock(Clock.class)); InputOperatorImpl inputOpImpl = opImplGraph.getInputOperator(new SystemStream("input-system", "input-stream")); assertEquals(1, inputOpImpl.registeredOperators.size()); @@ -182,16 +316,16 @@ public class TestOperatorImplGraph { public void testBroadcastChain() { ApplicationRunner mockRunner = mock(ApplicationRunner.class); when(mockRunner.getStreamSpec(eq("input"))).thenReturn(new StreamSpec("input", "input-stream", "input-system")); - StreamGraphImpl streamGraph = new StreamGraphImpl(mockRunner, mock(Config.class)); + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); - 
MessageStream<Object> inputStream = streamGraph.getInputStream("input"); + MessageStream<Object> inputStream = graphSpec.getInputStream("input"); inputStream.filter(mock(FilterFunction.class)); inputStream.map(mock(MapFunction.class)); TaskContextImpl mockTaskContext = mock(TaskContextImpl.class); when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap()); OperatorImplGraph opImplGraph = - new OperatorImplGraph(streamGraph, mock(Config.class), mockTaskContext, mock(Clock.class)); + new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), mock(Config.class), mockTaskContext, mock(Clock.class)); InputOperatorImpl inputOpImpl = opImplGraph.getInputOperator(new SystemStream("input-system", "input-stream")); assertEquals(2, inputOpImpl.registeredOperators.size()); @@ -204,23 +338,36 @@ public class TestOperatorImplGraph { @Test public void testMergeChain() { ApplicationRunner mockRunner = mock(ApplicationRunner.class); - when(mockRunner.getStreamSpec(eq("input"))).thenReturn(new StreamSpec("input", "input-stream", "input-system")); - StreamGraphImpl streamGraph = new StreamGraphImpl(mockRunner, mock(Config.class)); + when(mockRunner.getStreamSpec(eq("input"))) + .thenReturn(new StreamSpec("input", "input-stream", "input-system")); + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mock(Config.class)); - MessageStream<Object> inputStream = streamGraph.getInputStream("input"); + MessageStream<Object> inputStream = graphSpec.getInputStream("input"); MessageStream<Object> stream1 = inputStream.filter(mock(FilterFunction.class)); MessageStream<Object> stream2 = inputStream.map(mock(MapFunction.class)); MessageStream<Object> mergedStream = stream1.merge(Collections.singleton(stream2)); - MapFunction mockMapFunction = mock(MapFunction.class); - mergedStream.map(mockMapFunction); TaskContextImpl mockTaskContext = mock(TaskContextImpl.class); + TaskName mockTaskName = mock(TaskName.class); + 
when(mockTaskContext.getTaskName()).thenReturn(mockTaskName); when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap()); + + MapFunction testMapFunction = new TestMapFunction<Object, Object>("test-map-1", (Function & Serializable) m -> m); + mergedStream.map(testMapFunction); + OperatorImplGraph opImplGraph = - new OperatorImplGraph(streamGraph, mock(Config.class), mockTaskContext, mock(Clock.class)); + new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), mock(Config.class), mockTaskContext, mock(Clock.class)); + + Set<OperatorImpl> opSet = opImplGraph.getAllInputOperators().stream().collect(HashSet::new, + (s, op) -> addOperatorRecursively(s, op), HashSet::addAll); + Object[] mergeOps = opSet.stream().filter(op -> op.getOperatorSpec().getOpCode() == OpCode.MERGE).toArray(); + assertEquals(mergeOps.length, 1); + assertEquals(((OperatorImpl) mergeOps[0]).registeredOperators.size(), 1); + OperatorImpl mapOp = (OperatorImpl) ((OperatorImpl) mergeOps[0]).registeredOperators.iterator().next(); + assertEquals(mapOp.getOperatorSpec().getOpCode(), OpCode.MAP); // verify that the DAG after merge is only traversed & initialized once - verify(mockMapFunction, times(1)).init(any(Config.class), any(TaskContext.class)); + assertEquals(TestMapFunction.getInstanceByTaskName(mockTaskName, "test-map-1").numInitCalled, 1); } @Test @@ -231,25 +378,30 @@ public class TestOperatorImplGraph { Config mockConfig = mock(Config.class); when(mockConfig.get(JobConfig.JOB_NAME())).thenReturn("jobName"); when(mockConfig.get(eq(JobConfig.JOB_ID()), anyString())).thenReturn("jobId"); - StreamGraphImpl streamGraph = new StreamGraphImpl(mockRunner, mockConfig); - - JoinFunction mockJoinFunction = mock(JoinFunction.class); - MessageStream<Object> inputStream1 = streamGraph.getInputStream("input1", new NoOpSerde<>()); - MessageStream<Object> inputStream2 = streamGraph.getInputStream("input2", new NoOpSerde<>()); - inputStream1.join(inputStream2, mockJoinFunction, + 
StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig); + + Integer joinKey = new Integer(1); + Function<Object, Integer> keyFn = (Function & Serializable) m -> joinKey; + JoinFunction testJoinFunction = new TestJoinFunction("jobName-jobId-join-j1", + (BiFunction & Serializable) (m1, m2) -> KV.of(m1, m2), keyFn, keyFn); + MessageStream<Object> inputStream1 = graphSpec.getInputStream("input1", new NoOpSerde<>()); + MessageStream<Object> inputStream2 = graphSpec.getInputStream("input2", new NoOpSerde<>()); + inputStream1.join(inputStream2, testJoinFunction, mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j1"); + TaskName mockTaskName = mock(TaskName.class); TaskContextImpl mockTaskContext = mock(TaskContextImpl.class); + when(mockTaskContext.getTaskName()).thenReturn(mockTaskName); when(mockTaskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap()); KeyValueStore mockLeftStore = mock(KeyValueStore.class); when(mockTaskContext.getStore(eq("jobName-jobId-join-j1-L"))).thenReturn(mockLeftStore); KeyValueStore mockRightStore = mock(KeyValueStore.class); when(mockTaskContext.getStore(eq("jobName-jobId-join-j1-R"))).thenReturn(mockRightStore); OperatorImplGraph opImplGraph = - new OperatorImplGraph(streamGraph, mockConfig, mockTaskContext, mock(Clock.class)); + new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), mockConfig, mockTaskContext, mock(Clock.class)); // verify that join function is initialized once. 
- verify(mockJoinFunction, times(1)).init(any(Config.class), any(TaskContext.class)); + assertEquals(TestJoinFunction.getInstanceByTaskName(mockTaskName, "jobName-jobId-join-j1").numInitCalled, 1); InputOperatorImpl inputOpImpl1 = opImplGraph.getInputOperator(new SystemStream("input-system", "input-stream1")); InputOperatorImpl inputOpImpl2 = opImplGraph.getInputOperator(new SystemStream("input-system", "input-stream2")); @@ -261,24 +413,23 @@ public class TestOperatorImplGraph { assertEquals(leftPartialJoinOpImpl.getOperatorSpec(), rightPartialJoinOpImpl.getOperatorSpec()); assertNotSame(leftPartialJoinOpImpl, rightPartialJoinOpImpl); - Object joinKey = new Object(); // verify that left partial join operator calls getFirstKey Object mockLeftMessage = mock(Object.class); long currentTimeMillis = System.currentTimeMillis(); when(mockLeftStore.get(eq(joinKey))).thenReturn(new TimestampedValue<>(mockLeftMessage, currentTimeMillis)); - when(mockJoinFunction.getFirstKey(eq(mockLeftMessage))).thenReturn(joinKey); inputOpImpl1.onMessage(KV.of("", mockLeftMessage), mock(MessageCollector.class), mock(TaskCoordinator.class)); - verify(mockJoinFunction, times(1)).getFirstKey(mockLeftMessage); // verify that right partial join operator calls getSecondKey Object mockRightMessage = mock(Object.class); when(mockRightStore.get(eq(joinKey))).thenReturn(new TimestampedValue<>(mockRightMessage, currentTimeMillis)); - when(mockJoinFunction.getSecondKey(eq(mockRightMessage))).thenReturn(joinKey); inputOpImpl2.onMessage(KV.of("", mockRightMessage), mock(MessageCollector.class), mock(TaskCoordinator.class)); - verify(mockJoinFunction, times(1)).getSecondKey(mockRightMessage); + // verify that the join function apply is called with the correct messages on match - verify(mockJoinFunction, times(1)).apply(mockLeftMessage, mockRightMessage); + assertEquals(((TestJoinFunction) TestJoinFunction.getInstanceByTaskName(mockTaskName, "jobName-jobId-join-j1")).joinResults.size(), 1); + KV 
joinResult = (KV) ((TestJoinFunction) TestJoinFunction.getInstanceByTaskName(mockTaskName, "jobName-jobId-join-j1")).joinResults.iterator().next(); + assertEquals(joinResult.getKey(), mockLeftMessage); + assertEquals(joinResult.getValue(), mockRightMessage); } @Test @@ -287,23 +438,25 @@ public class TestOperatorImplGraph { when(mockRunner.getStreamSpec("input1")).thenReturn(new StreamSpec("input1", "input-stream1", "input-system")); when(mockRunner.getStreamSpec("input2")).thenReturn(new StreamSpec("input2", "input-stream2", "input-system")); Config mockConfig = mock(Config.class); + TaskName mockTaskName = mock(TaskName.class); TaskContextImpl mockContext = mock(TaskContextImpl.class); + when(mockContext.getTaskName()).thenReturn(mockTaskName); when(mockContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap()); - StreamGraphImpl streamGraph = new StreamGraphImpl(mockRunner, mockConfig); + StreamGraphSpec graphSpec = new StreamGraphSpec(mockRunner, mockConfig); - MessageStream<Object> inputStream1 = streamGraph.getInputStream("input1"); - MessageStream<Object> inputStream2 = streamGraph.getInputStream("input2"); + MessageStream<Object> inputStream1 = graphSpec.getInputStream("input1"); + MessageStream<Object> inputStream2 = graphSpec.getInputStream("input2"); - List<String> initializedOperators = new ArrayList<>(); - List<String> closedOperators = new ArrayList<>(); + Function mapFn = (Function & Serializable) m -> m; + inputStream1.map(new TestMapFunction<Object, Object>("1", mapFn)) + .map(new TestMapFunction<Object, Object>("2", mapFn)); - inputStream1.map(createMapFunction("1", initializedOperators, closedOperators)) - .map(createMapFunction("2", initializedOperators, closedOperators)); + inputStream2.map(new TestMapFunction<Object, Object>("3", mapFn)) + .map(new TestMapFunction<Object, Object>("4", mapFn)); - inputStream2.map(createMapFunction("3", initializedOperators, closedOperators)) - .map(createMapFunction("4", initializedOperators, 
closedOperators)); + OperatorImplGraph opImplGraph = new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), mockConfig, mockContext, SystemClock.instance()); - OperatorImplGraph opImplGraph = new OperatorImplGraph(streamGraph, mockConfig, mockContext, SystemClock.instance()); + List<String> initializedOperators = BaseTestFunction.getInitListByTaskName(mockTaskName); // Assert that initialization occurs in topological order. assertEquals(initializedOperators.get(0), "1"); @@ -313,35 +466,13 @@ public class TestOperatorImplGraph { // Assert that finalization occurs in reverse topological order. opImplGraph.close(); + List<String> closedOperators = BaseTestFunction.getCloseListByTaskName(mockTaskName); assertEquals(closedOperators.get(0), "4"); assertEquals(closedOperators.get(1), "3"); assertEquals(closedOperators.get(2), "2"); assertEquals(closedOperators.get(3), "1"); } - /** - * Creates an identity map function that appends to the provided lists when init/close is invoked. - */ - private MapFunction<Object, Object> createMapFunction(String id, - List<String> initializedOperators, List<String> finalizedOperators) { - return new MapFunction<Object, Object>() { - @Override - public void init(Config config, TaskContext context) { - initializedOperators.add(id); - } - - @Override - public void close() { - finalizedOperators.add(id); - } - - @Override - public Object apply(Object message) { - return message; - } - }; - } - @Test public void testGetStreamToConsumerTasks() { String system = "test-system"; @@ -409,16 +540,16 @@ public class TestOperatorImplGraph { when(runner.getStreamSpec("test-app-1-partition_by-p2")).thenReturn(int1); when(runner.getStreamSpec("test-app-1-partition_by-p1")).thenReturn(int2); - StreamGraphImpl streamGraph = new StreamGraphImpl(runner, config); - MessageStream messageStream1 = streamGraph.getInputStream("input1").map(m -> m); - MessageStream messageStream2 = streamGraph.getInputStream("input2").filter(m -> true); + StreamGraphSpec 
graphSpec = new StreamGraphSpec(runner, config); + MessageStream messageStream1 = graphSpec.getInputStream("input1").map(m -> m); + MessageStream messageStream2 = graphSpec.getInputStream("input2").filter(m -> true); MessageStream messageStream3 = - streamGraph.getInputStream("input3") + graphSpec.getInputStream("input3") .filter(m -> true) .partitionBy(m -> "hehe", m -> m, "p1") .map(m -> m); - OutputStream<Object> outputStream1 = streamGraph.getOutputStream("output1"); - OutputStream<Object> outputStream2 = streamGraph.getOutputStream("output2"); + OutputStream<Object> outputStream1 = graphSpec.getOutputStream("output1"); + OutputStream<Object> outputStream2 = graphSpec.getOutputStream("output2"); messageStream1 .join(messageStream2, mock(JoinFunction.class), @@ -430,7 +561,8 @@ public class TestOperatorImplGraph { mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j2") .sendTo(outputStream2); - Multimap<SystemStream, SystemStream> outputToInput = OperatorImplGraph.getIntermediateToInputStreamsMap(streamGraph); + Multimap<SystemStream, SystemStream> outputToInput = + OperatorImplGraph.getIntermediateToInputStreamsMap(graphSpec.getOperatorSpecGraph()); Collection<SystemStream> inputs = outputToInput.get(int1.toSystemStream()); assertEquals(inputs.size(), 2); assertTrue(inputs.contains(input1.toSystemStream())); http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-core/src/test/java/org/apache/samza/operators/impl/TestStreamOperatorImpl.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/operators/impl/TestStreamOperatorImpl.java b/samza-core/src/test/java/org/apache/samza/operators/impl/TestStreamOperatorImpl.java index a91c1af..873cd3c 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/impl/TestStreamOperatorImpl.java +++ b/samza-core/src/test/java/org/apache/samza/operators/impl/TestStreamOperatorImpl.java @@ -48,7 +48,7 @@ public 
class TestStreamOperatorImpl { Config mockConfig = mock(Config.class); TaskContext mockContext = mock(TaskContext.class); StreamOperatorImpl<TestMessageEnvelope, TestOutputMessageEnvelope> opImpl = - new StreamOperatorImpl<>(mockOp, mockConfig, mockContext); + new StreamOperatorImpl<>(mockOp); TestMessageEnvelope inMsg = mock(TestMessageEnvelope.class); Collection<TestOutputMessageEnvelope> mockOutputs = mock(Collection.class); when(txfmFn.apply(inMsg)).thenReturn(mockOutputs); @@ -69,7 +69,7 @@ public class TestStreamOperatorImpl { TaskContext mockContext = mock(TaskContext.class); StreamOperatorImpl<TestMessageEnvelope, TestOutputMessageEnvelope> opImpl = - new StreamOperatorImpl<>(mockOp, mockConfig, mockContext); + new StreamOperatorImpl<>(mockOp); // ensure that close is not called yet verify(txfmFn, times(0)).close(); http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-core/src/test/java/org/apache/samza/operators/impl/TestWindowOperator.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/operators/impl/TestWindowOperator.java b/samza-core/src/test/java/org/apache/samza/operators/impl/TestWindowOperator.java index 7d0c623..9741fc4 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/impl/TestWindowOperator.java +++ b/samza-core/src/test/java/org/apache/samza/operators/impl/TestWindowOperator.java @@ -22,19 +22,20 @@ package org.apache.samza.operators.impl; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import junit.framework.Assert; import org.apache.samza.Partition; -import org.apache.samza.application.StreamApplication; import org.apache.samza.config.Config; +import org.apache.samza.config.MapConfig; import org.apache.samza.config.JobConfig; import org.apache.samza.container.TaskContextImpl; import org.apache.samza.container.TaskName; import org.apache.samza.metrics.MetricsRegistryMap; import 
org.apache.samza.operators.KV; import org.apache.samza.operators.MessageStream; -import org.apache.samza.operators.StreamGraph; +import org.apache.samza.operators.StreamGraphSpec; +import org.apache.samza.operators.functions.MapFunction; import org.apache.samza.operators.impl.store.TestInMemoryStore; import org.apache.samza.operators.impl.store.TimeSeriesKeySerde; +import org.apache.samza.operators.OperatorSpecGraph; import org.apache.samza.operators.triggers.FiringType; import org.apache.samza.operators.triggers.Trigger; import org.apache.samza.operators.triggers.Triggers; @@ -54,19 +55,25 @@ import org.apache.samza.task.MessageCollector; import org.apache.samza.task.StreamOperatorTask; import org.apache.samza.task.TaskCoordinator; import org.apache.samza.testUtils.TestClock; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import java.io.IOException; import java.time.Duration; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; +import java.util.HashMap; import java.util.List; -import java.util.function.Function; +import java.util.Map; +import java.util.Collections; import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class TestWindowOperator { private final TaskCoordinator taskCoordinator = mock(TaskCoordinator.class); @@ -83,26 +90,32 @@ public class TestWindowOperator { taskContext = mock(TaskContextImpl.class); runner = mock(ApplicationRunner.class); Serde storeKeySerde = new TimeSeriesKeySerde(new IntegerSerde()); - Serde storeValSerde = new IntegerEnvelopeSerde(); + Serde storeValSerde = KVSerde.of(new IntegerSerde(), new IntegerSerde()); when(taskContext.getSystemStreamPartitions()).thenReturn(ImmutableSet .of(new SystemStreamPartition("kafka", "integers", 
new Partition(0)))); when(taskContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap()); - when(taskContext.getStore("jobName-jobId-window-w1")) .thenReturn(new TestInMemoryStore<>(storeKeySerde, storeValSerde)); when(runner.getStreamSpec("integers")).thenReturn(new StreamSpec("integers", "integers", "kafka")); + + Map<String, String> mapConfig = new HashMap<>(); + mapConfig.put("app.runner.class", "org.apache.samza.runtime.LocalApplicationRunner"); + mapConfig.put("job.default.system", "kafka"); + mapConfig.put("job.name", "jobName"); + mapConfig.put("job.id", "jobId"); + config = new MapConfig(mapConfig); } @Test public void testTumblingWindowsDiscardingMode() throws Exception { - StreamApplication sgb = new KeyedTumblingWindowStreamApplication(AccumulationMode.DISCARDING, - Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2))); + OperatorSpecGraph sgb = this.getKeyedTumblingWindowStreamGraph(AccumulationMode.DISCARDING, + Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2))).getOperatorSpecGraph(); List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>(); TestClock testClock = new TestClock(); - StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock); + StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock); task.init(config, taskContext); MessageCollector messageCollector = envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage()); @@ -130,12 +143,12 @@ public class TestWindowOperator { @Test public void testNonKeyedTumblingWindowsDiscardingMode() throws Exception { - StreamApplication sgb = new TumblingWindowStreamApplication(AccumulationMode.DISCARDING, - Duration.ofSeconds(1), Triggers.repeat(Triggers.count(1000))); + OperatorSpecGraph sgb = this.getTumblingWindowStreamGraph(AccumulationMode.DISCARDING, + Duration.ofSeconds(1), Triggers.repeat(Triggers.count(1000))).getOperatorSpecGraph(); List<WindowPane<Integer, 
Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>(); TestClock testClock = new TestClock(); - StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock); + StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock); task.init(config, taskContext); MessageCollector messageCollector = @@ -159,12 +172,12 @@ public class TestWindowOperator { when(taskContext.getStore("jobName-jobId-window-w1")) .thenReturn(new TestInMemoryStore<>(new TimeSeriesKeySerde(new IntegerSerde()), new IntegerSerde())); - StreamApplication sgb = new AggregateTumblingWindowStreamApplication(AccumulationMode.DISCARDING, - Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2))); + OperatorSpecGraph sgb = this.getAggregateTumblingWindowStreamGraph(AccumulationMode.DISCARDING, + Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2))).getOperatorSpecGraph(); List<WindowPane<Integer, Integer>> windowPanes = new ArrayList<>(); TestClock testClock = new TestClock(); - StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock); + StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock); task.init(config, taskContext); MessageCollector messageCollector = envelope -> windowPanes.add((WindowPane<Integer, Integer>) envelope.getMessage()); integers.forEach(n -> task.process(new IntegerEnvelope(n), messageCollector, taskCoordinator)); @@ -181,11 +194,11 @@ public class TestWindowOperator { @Test public void testTumblingWindowsAccumulatingMode() throws Exception { - StreamApplication sgb = new KeyedTumblingWindowStreamApplication(AccumulationMode.ACCUMULATING, - Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2))); + OperatorSpecGraph sgb = this.getKeyedTumblingWindowStreamGraph(AccumulationMode.ACCUMULATING, + Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2))).getOperatorSpecGraph(); List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>(); TestClock testClock = new TestClock(); - 
StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock); + StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock); task.init(config, taskContext); MessageCollector messageCollector = @@ -210,10 +223,11 @@ public class TestWindowOperator { @Test public void testSessionWindowsDiscardingMode() throws Exception { - StreamApplication sgb = new KeyedSessionWindowStreamApplication(AccumulationMode.DISCARDING, Duration.ofMillis(500)); + OperatorSpecGraph sgb = + this.getKeyedSessionWindowStreamGraph(AccumulationMode.DISCARDING, Duration.ofMillis(500)).getOperatorSpecGraph(); TestClock testClock = new TestClock(); List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>(); - StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock); + StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock); task.init(config, taskContext); MessageCollector messageCollector = envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage()); @@ -255,10 +269,10 @@ public class TestWindowOperator { @Test public void testSessionWindowsAccumulatingMode() throws Exception { - StreamApplication sgb = new KeyedSessionWindowStreamApplication(AccumulationMode.DISCARDING, - Duration.ofMillis(500)); + OperatorSpecGraph sgb = this.getKeyedSessionWindowStreamGraph(AccumulationMode.DISCARDING, + Duration.ofMillis(500)).getOperatorSpecGraph(); TestClock testClock = new TestClock(); - StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock); + StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock); List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>(); MessageCollector messageCollector = @@ -287,10 +301,10 @@ public class TestWindowOperator { @Test public void testCancellationOfOnceTrigger() throws Exception { - StreamApplication sgb = new KeyedTumblingWindowStreamApplication(AccumulationMode.ACCUMULATING, - 
Duration.ofSeconds(1), Triggers.count(2)); + OperatorSpecGraph sgb = this.getKeyedTumblingWindowStreamGraph(AccumulationMode.ACCUMULATING, + Duration.ofSeconds(1), Triggers.count(2)).getOperatorSpecGraph(); TestClock testClock = new TestClock(); - StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock); + StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock); task.init(config, taskContext); List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>(); @@ -331,10 +345,10 @@ public class TestWindowOperator { @Test public void testCancellationOfAnyTrigger() throws Exception { - StreamApplication sgb = new KeyedTumblingWindowStreamApplication(AccumulationMode.ACCUMULATING, Duration.ofSeconds(1), - Triggers.any(Triggers.count(2), Triggers.timeSinceFirstMessage(Duration.ofMillis(500)))); + OperatorSpecGraph sgb = this.getKeyedTumblingWindowStreamGraph(AccumulationMode.ACCUMULATING, Duration.ofSeconds(1), + Triggers.any(Triggers.count(2), Triggers.timeSinceFirstMessage(Duration.ofMillis(500)))).getOperatorSpecGraph(); TestClock testClock = new TestClock(); - StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock); + StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock); task.init(config, taskContext); List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>(); @@ -389,15 +403,15 @@ public class TestWindowOperator { @Test public void testCancelationOfRepeatingNestedTriggers() throws Exception { - StreamApplication sgb = new KeyedTumblingWindowStreamApplication(AccumulationMode.ACCUMULATING, Duration.ofSeconds(1), - Triggers.repeat(Triggers.any(Triggers.count(2), Triggers.timeSinceFirstMessage(Duration.ofMillis(500))))); + OperatorSpecGraph sgb = this.getKeyedTumblingWindowStreamGraph(AccumulationMode.ACCUMULATING, Duration.ofSeconds(1), + Triggers.repeat(Triggers.any(Triggers.count(2), 
Triggers.timeSinceFirstMessage(Duration.ofMillis(500))))).getOperatorSpecGraph(); List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>(); MessageCollector messageCollector = envelope -> windowPanes.add((WindowPane<Integer, Collection<IntegerEnvelope>>) envelope.getMessage()); TestClock testClock = new TestClock(); - StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock); + StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock); task.init(config, taskContext); task.process(new IntegerEnvelope(1), messageCollector, taskCoordinator); @@ -434,12 +448,12 @@ public class TestWindowOperator { when(taskContext.fetchObject(EndOfStreamStates.class.getName())).thenReturn(endOfStreamStates); when(taskContext.fetchObject(WatermarkStates.class.getName())).thenReturn(mock(WatermarkStates.class)); - StreamApplication sgb = new TumblingWindowStreamApplication(AccumulationMode.DISCARDING, - Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2))); + OperatorSpecGraph sgb = this.getTumblingWindowStreamGraph(AccumulationMode.DISCARDING, + Duration.ofSeconds(1), Triggers.repeat(Triggers.count(2))).getOperatorSpecGraph(); List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>(); TestClock testClock = new TestClock(); - StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock); + StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock); task.init(config, taskContext); MessageCollector messageCollector = @@ -475,10 +489,11 @@ public class TestWindowOperator { when(taskContext.fetchObject(EndOfStreamStates.class.getName())).thenReturn(endOfStreamStates); when(taskContext.fetchObject(WatermarkStates.class.getName())).thenReturn(mock(WatermarkStates.class)); - StreamApplication sgb = new KeyedSessionWindowStreamApplication(AccumulationMode.DISCARDING, Duration.ofMillis(500)); + OperatorSpecGraph sgb = + 
this.getKeyedSessionWindowStreamGraph(AccumulationMode.DISCARDING, Duration.ofMillis(500)).getOperatorSpecGraph(); TestClock testClock = new TestClock(); List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>(); - StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock); + StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock); task.init(config, taskContext); MessageCollector messageCollector = @@ -511,10 +526,11 @@ public class TestWindowOperator { when(taskContext.fetchObject(EndOfStreamStates.class.getName())).thenReturn(endOfStreamStates); when(taskContext.fetchObject(WatermarkStates.class.getName())).thenReturn(mock(WatermarkStates.class)); - StreamApplication sgb = new KeyedSessionWindowStreamApplication(AccumulationMode.DISCARDING, Duration.ofMillis(500)); + OperatorSpecGraph sgb = + this.getKeyedSessionWindowStreamGraph(AccumulationMode.DISCARDING, Duration.ofMillis(500)).getOperatorSpecGraph(); TestClock testClock = new TestClock(); List<WindowPane<Integer, Collection<IntegerEnvelope>>> windowPanes = new ArrayList<>(); - StreamOperatorTask task = new StreamOperatorTask(sgb, runner, testClock); + StreamOperatorTask task = new StreamOperatorTask(sgb, null, testClock); task.init(config, taskContext); MessageCollector messageCollector = @@ -534,144 +550,83 @@ public class TestWindowOperator { verify(taskCoordinator, times(1)).shutdown(TaskCoordinator.RequestScope.CURRENT_TASK); } - private class KeyedTumblingWindowStreamApplication implements StreamApplication { - - private final AccumulationMode mode; - private final Duration duration; - private final Trigger<IntegerEnvelope> earlyTrigger; - private final SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream"); + private StreamGraphSpec getKeyedTumblingWindowStreamGraph(AccumulationMode mode, + Duration duration, Trigger<KV<Integer, Integer>> earlyTrigger) throws IOException { + StreamGraphSpec graph = new 
StreamGraphSpec(runner, config); - KeyedTumblingWindowStreamApplication(AccumulationMode mode, - Duration timeDuration, Trigger<IntegerEnvelope> earlyTrigger) { - this.mode = mode; - this.duration = timeDuration; - this.earlyTrigger = earlyTrigger; - } + KVSerde<Integer, Integer> kvSerde = KVSerde.of(new IntegerSerde(), new IntegerSerde()); + graph.getInputStream("integers", kvSerde) + .window(Windows.keyedTumblingWindow(KV::getKey, duration, new IntegerSerde(), kvSerde) + .setEarlyTrigger(earlyTrigger).setAccumulationMode(mode), "w1") + .sink((message, messageCollector, taskCoordinator) -> { + SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream"); + messageCollector.send(new OutgoingMessageEnvelope(outputSystemStream, message)); + }); - @Override - public void init(StreamGraph graph, Config config) { - MessageStream<IntegerEnvelope> inStream = - graph.getInputStream("integers", KVSerde.of(new IntegerSerde(), new IntegerSerde())) - .map(kv -> new IntegerEnvelope(kv.getKey())); - Function<IntegerEnvelope, Integer> keyFn = m -> (Integer) m.getKey(); - inStream - .map(m -> m) - .window(Windows.keyedTumblingWindow(keyFn, duration, new IntegerSerde(), new IntegerEnvelopeSerde()) - .setEarlyTrigger(earlyTrigger) - .setAccumulationMode(mode), "w1") - .sink((message, messageCollector, taskCoordinator) -> { - messageCollector.send(new OutgoingMessageEnvelope(outputSystemStream, message)); - }); - } + return graph; } - private class TumblingWindowStreamApplication implements StreamApplication { - - private final AccumulationMode mode; - private final Duration duration; - private final Trigger<IntegerEnvelope> earlyTrigger; - private final SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream"); + private StreamGraphSpec getTumblingWindowStreamGraph(AccumulationMode mode, + Duration duration, Trigger<KV<Integer, Integer>> earlyTrigger) throws IOException { + StreamGraphSpec graph = new StreamGraphSpec(runner, config); - 
TumblingWindowStreamApplication(AccumulationMode mode, - Duration timeDuration, Trigger<IntegerEnvelope> earlyTrigger) { - this.mode = mode; - this.duration = timeDuration; - this.earlyTrigger = earlyTrigger; - } - - @Override - public void init(StreamGraph graph, Config config) { - MessageStream<IntegerEnvelope> inStream = - graph.getInputStream("integers", KVSerde.of(new IntegerSerde(), new IntegerSerde())) - .map(kv -> new IntegerEnvelope(kv.getKey())); - Function<IntegerEnvelope, Integer> keyFn = m -> (Integer) m.getKey(); - inStream - .map(m -> m) - .window(Windows.tumblingWindow(duration, new IntegerEnvelopeSerde()) - .setEarlyTrigger(earlyTrigger) - .setAccumulationMode(mode), "w1") - .sink((message, messageCollector, taskCoordinator) -> { - messageCollector.send(new OutgoingMessageEnvelope(outputSystemStream, message)); - }); - } + KVSerde<Integer, Integer> kvSerde = KVSerde.of(new IntegerSerde(), new IntegerSerde()); + graph.getInputStream("integers", kvSerde) + .window(Windows.tumblingWindow(duration, kvSerde).setEarlyTrigger(earlyTrigger) + .setAccumulationMode(mode), "w1") + .sink((message, messageCollector, taskCoordinator) -> { + SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream"); + messageCollector.send(new OutgoingMessageEnvelope(outputSystemStream, message)); + }); + return graph; } - private class AggregateTumblingWindowStreamApplication implements StreamApplication { + private StreamGraphSpec getKeyedSessionWindowStreamGraph(AccumulationMode mode, Duration duration) throws IOException { + StreamGraphSpec graph = new StreamGraphSpec(runner, config); - private final AccumulationMode mode; - private final Duration duration; - private final Trigger<IntegerEnvelope> earlyTrigger; - private final SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream"); - - AggregateTumblingWindowStreamApplication(AccumulationMode mode, Duration timeDuration, - Trigger<IntegerEnvelope> earlyTrigger) { - 
this.mode = mode; - this.duration = timeDuration; - this.earlyTrigger = earlyTrigger; - } - - @Override - public void init(StreamGraph graph, Config config) { - MessageStream<KV<Integer, Integer>> integers = graph.getInputStream("integers", - KVSerde.of(new IntegerSerde(), new IntegerSerde())); - - integers - .map(kv -> new IntegerEnvelope(kv.getKey())) - .window(Windows.<IntegerEnvelope, Integer>tumblingWindow(this.duration, () -> 0, (m, c) -> c + 1, new IntegerSerde()) - .setEarlyTrigger(earlyTrigger) + KVSerde<Integer, Integer> kvSerde = KVSerde.of(new IntegerSerde(), new IntegerSerde()); + graph.getInputStream("integers", kvSerde) + .window(Windows.keyedSessionWindow(KV::getKey, duration, new IntegerSerde(), kvSerde) .setAccumulationMode(mode), "w1") .sink((message, messageCollector, taskCoordinator) -> { + SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream"); messageCollector.send(new OutgoingMessageEnvelope(outputSystemStream, message)); }); - } + return graph; } - private class KeyedSessionWindowStreamApplication implements StreamApplication { + private StreamGraphSpec getAggregateTumblingWindowStreamGraph(AccumulationMode mode, Duration timeDuration, + Trigger<IntegerEnvelope> earlyTrigger) throws IOException { + StreamGraphSpec graph = new StreamGraphSpec(runner, config); - private final AccumulationMode mode; - private final Duration duration; - private final SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream"); + MessageStream<KV<Integer, Integer>> integers = graph.getInputStream("integers", + KVSerde.of(new IntegerSerde(), new IntegerSerde())); - KeyedSessionWindowStreamApplication(AccumulationMode mode, Duration duration) { - this.mode = mode; - this.duration = duration; - } - - @Override - public void init(StreamGraph graph, Config config) { - MessageStream<IntegerEnvelope> inStream = - graph.getInputStream("integers", KVSerde.of(new IntegerSerde(), new IntegerSerde())) - .map(kv -> new 
IntegerEnvelope(kv.getKey())); - Function<IntegerEnvelope, Integer> keyFn = m -> (Integer) m.getKey(); - - inStream - .map(m -> m) - .window(Windows.keyedSessionWindow(keyFn, duration, new IntegerSerde(), new IntegerEnvelopeSerde()) - .setAccumulationMode(mode), "w1") - .sink((message, messageCollector, taskCoordinator) -> { - messageCollector.send(new OutgoingMessageEnvelope(outputSystemStream, message)); - }); - } + integers + .map(new KVMapFunction()) + .window(Windows.<IntegerEnvelope, Integer>tumblingWindow(timeDuration, () -> 0, (m, c) -> c + 1, new IntegerSerde()) + .setEarlyTrigger(earlyTrigger) + .setAccumulationMode(mode), "w1") + .sink((message, messageCollector, taskCoordinator) -> { + SystemStream outputSystemStream = new SystemStream("outputSystem", "outputStream"); + messageCollector.send(new OutgoingMessageEnvelope(outputSystemStream, message)); + }); + return graph; } - private class IntegerEnvelope extends IncomingMessageEnvelope { + private static class IntegerEnvelope extends IncomingMessageEnvelope { IntegerEnvelope(Integer key) { - super(new SystemStreamPartition("kafka", "integers", new Partition(0)), "1", key, key); + super(new SystemStreamPartition("kafka", "integers", new Partition(0)), null, key, key); } } - private class IntegerEnvelopeSerde implements Serde<IntegerEnvelope> { - private final IntegerSerde intSerde = new IntegerSerde(); + private static class KVMapFunction implements MapFunction<KV<Integer, Integer>, IntegerEnvelope> { @Override - public byte[] toBytes(IntegerEnvelope object) { - return intSerde.toBytes((Integer) object.getKey()); - } - - @Override - public IntegerEnvelope fromBytes(byte[] bytes) { - return new IntegerEnvelope(intSerde.fromBytes(bytes)); + public IntegerEnvelope apply(KV<Integer, Integer> message) { + return new IntegerEnvelope(message.getKey()); } } + }
