http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/ClientTest.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/ClientTest.java b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/ClientTest.java new file mode 100644 index 0000000..1fa4deb --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/ClientTest.java @@ -0,0 +1,782 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.queryablestate.network; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.queryablestate.KvStateID; +import org.apache.flink.queryablestate.client.VoidNamespace; +import org.apache.flink.queryablestate.client.VoidNamespaceSerializer; +import org.apache.flink.queryablestate.client.state.serialization.KvStateSerializer; +import org.apache.flink.queryablestate.messages.KvStateInternalRequest; +import org.apache.flink.queryablestate.messages.KvStateResponse; +import org.apache.flink.queryablestate.network.messages.MessageSerializer; +import org.apache.flink.queryablestate.network.messages.MessageType; +import org.apache.flink.queryablestate.network.stats.AtomicKvStateRequestStats; +import org.apache.flink.queryablestate.server.KvStateServerImpl; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.operators.testutils.DummyEnvironment; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.AbstractStateBackend; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.internal.InternalKvState; +import org.apache.flink.runtime.state.memory.MemoryStateBackend; +import org.apache.flink.util.NetUtils; + +import org.apache.flink.shaded.netty4.io.netty.bootstrap.ServerBootstrap; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.channel.Channel; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter; +import 
org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer; +import org.apache.flink.shaded.netty4.io.netty.channel.nio.NioEventLoopGroup; +import org.apache.flink.shaded.netty4.io.netty.channel.socket.SocketChannel; +import org.apache.flink.shaded.netty4.io.netty.channel.socket.nio.NioServerSocketChannel; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.LengthFieldBasedFrameDecoder; + +import org.junit.AfterClass; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.ConnectException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import scala.concurrent.duration.Deadline; +import scala.concurrent.duration.FiniteDuration; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * Tests for {@link Client}. 
+ */
+public class ClientTest {
+
+	private static final Logger LOG = LoggerFactory.getLogger(ClientTest.class);
+
+	// Event loop group for the test servers created by createServerChannel (shared between tests).
+	// The clients under test create their own event loops.
+	private static final NioEventLoopGroup NIO_GROUP = new NioEventLoopGroup();
+
+	private static final FiniteDuration TEST_TIMEOUT = new FiniteDuration(10L, TimeUnit.SECONDS);
+
+	@AfterClass
+	public static void tearDown() throws Exception {
+		NIO_GROUP.shutdownGracefully();
+	}
+
+	/**
+	 * Tests simple queries, of which half succeed and half fail.
+	 */
+	@Test
+	public void testSimpleRequests() throws Exception {
+		Deadline deadline = TEST_TIMEOUT.fromNow();
+		AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+		MessageSerializer<KvStateInternalRequest, KvStateResponse> serializer =
+				new MessageSerializer<>(new KvStateInternalRequest.KvStateInternalRequestDeserializer(), new KvStateResponse.KvStateResponseDeserializer());
+
+		Client<KvStateInternalRequest, KvStateResponse> client = null;
+		Channel serverChannel = null;
+
+		try {
+			client = new Client<>("Test Client", 1, serializer, stats);
+
+			// Random result
+			final byte[] expected = new byte[1024];
+			ThreadLocalRandom.current().nextBytes(expected);
+
+			final LinkedBlockingQueue<ByteBuf> received = new LinkedBlockingQueue<>();
+			final AtomicReference<Channel> channel = new AtomicReference<>();
+
+			serverChannel = createServerChannel(new ChannelInboundHandlerAdapter() {
+				@Override
+				public void channelActive(ChannelHandlerContext ctx) throws Exception {
+					channel.set(ctx.channel());
+				}
+
+				@Override
+				public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+					received.add((ByteBuf) msg);
+				}
+			});
+
+			InetSocketAddress serverAddress = getKvStateServerAddress(serverChannel);
+
+			long numQueries = 1024L;
+
+			List<CompletableFuture<KvStateResponse>> futures = new ArrayList<>();
+			for (long i = 0L; i < numQueries; i++) {
+				KvStateInternalRequest request = new KvStateInternalRequest(new KvStateID(), new byte[0]);
+				futures.add(client.sendRequest(serverAddress, request));
+			}
+
+			// Respond to messages: even-numbered requests succeed, odd-numbered ones fail
+			Exception testException = new RuntimeException("Expected test Exception");
+
+			for (long i = 0L; i < numQueries; i++) {
+				ByteBuf buf = received.poll(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+				assertNotNull("Receive timed out", buf);
+
+				Channel ch = channel.get();
+				assertNotNull("Channel not active", ch);
+
+				assertEquals(MessageType.REQUEST, MessageSerializer.deserializeHeader(buf));
+				long requestId = MessageSerializer.getRequestId(buf);
+				// Only checks that the request body can be deserialized; its content is not inspected.
+				serializer.deserializeRequest(buf);
+
+				buf.release();
+
+				if (i % 2L == 0L) {
+					ByteBuf response = MessageSerializer.serializeResponse(
+							serverChannel.alloc(),
+							requestId,
+							new KvStateResponse(expected));
+
+					ch.writeAndFlush(response);
+				} else {
+					ByteBuf response = MessageSerializer.serializeRequestFailure(
+							serverChannel.alloc(),
+							requestId,
+							testException);
+
+					ch.writeAndFlush(response);
+				}
+			}
+
+			for (long i = 0L; i < numQueries; i++) {
+
+				if (i % 2L == 0L) {
+					KvStateResponse serializedResult = futures.get((int) i).get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+					assertArrayEquals(expected, serializedResult.getContent());
+				} else {
+					try {
+						futures.get((int) i).get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+						fail("Did not throw expected Exception");
+					} catch (ExecutionException e) {
+
+						if (!(e.getCause() instanceof RuntimeException)) {
+							fail("Did not throw expected Exception");
+						}
+						// else expected
+					}
+				}
+			}
+
+			assertEquals(numQueries, stats.getNumRequests());
+			long expectedRequests = numQueries / 2L;
+
+			// Counts can take some time to propagate
+			while (deadline.hasTimeLeft() && (stats.getNumSuccessful() != expectedRequests ||
+					stats.getNumFailed() != expectedRequests)) {
+				Thread.sleep(100L);
+			}
+
+			assertEquals(expectedRequests, stats.getNumSuccessful());
+			assertEquals(expectedRequests, stats.getNumFailed());
+		} finally {
+			if (client != null) {
+				client.shutdown();
+			}
+
+			if (serverChannel != null) {
+				serverChannel.close();
+			}
+
+			assertEquals("Channel leak", 0L, stats.getNumConnections());
+		}
+	}
+
+	/**
+	 * Tests that a request to an unavailable host is failed with ConnectException.
+	 */
+	@Test
+	public void testRequestUnavailableHost() throws Exception {
+		Deadline deadline = TEST_TIMEOUT.fromNow();
+		AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+		MessageSerializer<KvStateInternalRequest, KvStateResponse> serializer =
+				new MessageSerializer<>(new KvStateInternalRequest.KvStateInternalRequestDeserializer(), new KvStateResponse.KvStateResponseDeserializer());
+
+		Client<KvStateInternalRequest, KvStateResponse> client = null;
+
+		try {
+			client = new Client<>("Test Client", 1, serializer, stats);
+
+			int availablePort = NetUtils.getAvailablePort();
+
+			InetSocketAddress serverAddress = new InetSocketAddress(
+					InetAddress.getLocalHost(),
+					availablePort);
+
+			KvStateInternalRequest request = new KvStateInternalRequest(new KvStateID(), new byte[0]);
+			CompletableFuture<KvStateResponse> future = client.sendRequest(serverAddress, request);
+
+			try {
+				future.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+				fail("Did not throw expected ConnectException");
+			} catch (ExecutionException e) {
+				if (!(e.getCause() instanceof ConnectException)) {
+					fail("Did not throw expected ConnectException");
+				}
+				// else expected
+			}
+		} finally {
+			if (client != null) {
+				client.shutdown();
+			}
+
+			assertEquals("Channel leak", 0L, stats.getNumConnections());
+		}
+	}
+
+	/**
+	 * Multiple threads concurrently fire queries.
+	 */
+	@Test
+	public void testConcurrentQueries() throws Exception {
+		Deadline deadline = TEST_TIMEOUT.fromNow();
+		AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+		final MessageSerializer<KvStateInternalRequest, KvStateResponse> serializer =
+				new MessageSerializer<>(new KvStateInternalRequest.KvStateInternalRequestDeserializer(), new KvStateResponse.KvStateResponseDeserializer());
+
+		ExecutorService executor = null;
+		Client<KvStateInternalRequest, KvStateResponse> client = null;
+		Channel serverChannel = null;
+
+		final byte[] serializedResult = new byte[1024];
+		ThreadLocalRandom.current().nextBytes(serializedResult);
+
+		try {
+			int numQueryTasks = 4;
+			final int numQueriesPerTask = 1024;
+
+			executor = Executors.newFixedThreadPool(numQueryTasks);
+
+			client = new Client<>("Test Client", 1, serializer, stats);
+
+			serverChannel = createServerChannel(new ChannelInboundHandlerAdapter() {
+				@Override
+				public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+					ByteBuf buf = (ByteBuf) msg;
+					assertEquals(MessageType.REQUEST, MessageSerializer.deserializeHeader(buf));
+					long requestId = MessageSerializer.getRequestId(buf);
+					// Only checks that the request body can be deserialized; its content is not inspected.
+					serializer.deserializeRequest(buf);
+
+					buf.release();
+
+					KvStateResponse response = new KvStateResponse(serializedResult);
+					ByteBuf serResponse = MessageSerializer.serializeResponse(
+							ctx.alloc(),
+							requestId,
+							response);
+
+					ctx.channel().writeAndFlush(serResponse);
+				}
+			});
+
+			final InetSocketAddress serverAddress = getKvStateServerAddress(serverChannel);
+
+			final Client<KvStateInternalRequest, KvStateResponse> finalClient = client;
+			Callable<List<CompletableFuture<KvStateResponse>>> queryTask = () -> {
+				List<CompletableFuture<KvStateResponse>> results = new ArrayList<>(numQueriesPerTask);
+
+				for (int i = 0; i < numQueriesPerTask; i++) {
+					KvStateInternalRequest request = new KvStateInternalRequest(new KvStateID(), new byte[0]);
+					results.add(finalClient.sendRequest(serverAddress, request));
+				}
+
+				return results;
+			};
+
+			// Submit query tasks
+			List<Future<List<CompletableFuture<KvStateResponse>>>> futures = new ArrayList<>();
+			for (int i = 0; i < numQueryTasks; i++) {
+				futures.add(executor.submit(queryTask));
+			}
+
+			// Verify results
+			for (Future<List<CompletableFuture<KvStateResponse>>> future : futures) {
+				List<CompletableFuture<KvStateResponse>> results = future.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+				for (CompletableFuture<KvStateResponse> result : results) {
+					KvStateResponse actual = result.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+					assertArrayEquals(serializedResult, actual.getContent());
+				}
+			}
+
+			int totalQueries = numQueryTasks * numQueriesPerTask;
+
+			// Counts can take some time to propagate
+			while (deadline.hasTimeLeft() && stats.getNumSuccessful() != totalQueries) {
+				Thread.sleep(100L);
+			}
+
+			assertEquals(totalQueries, stats.getNumRequests());
+			assertEquals(totalQueries, stats.getNumSuccessful());
+		} finally {
+			if (executor != null) {
+				executor.shutdown();
+			}
+
+			if (serverChannel != null) {
+				serverChannel.close();
+			}
+
+			if (client != null) {
+				client.shutdown();
+			}
+
+			assertEquals("Channel leak", 0L, stats.getNumConnections());
+		}
+	}
+
+	/**
+	 * Tests that a server failure closes the connection and removes it from
+	 * the established connections.
+	 */
+	@Test
+	public void testFailureClosesChannel() throws Exception {
+		Deadline deadline = TEST_TIMEOUT.fromNow();
+		AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+		final MessageSerializer<KvStateInternalRequest, KvStateResponse> serializer =
+				new MessageSerializer<>(new KvStateInternalRequest.KvStateInternalRequestDeserializer(), new KvStateResponse.KvStateResponseDeserializer());
+
+		Client<KvStateInternalRequest, KvStateResponse> client = null;
+		Channel serverChannel = null;
+
+		try {
+			client = new Client<>("Test Client", 1, serializer, stats);
+
+			final LinkedBlockingQueue<ByteBuf> received = new LinkedBlockingQueue<>();
+			final AtomicReference<Channel> channel = new AtomicReference<>();
+
+			serverChannel = createServerChannel(new ChannelInboundHandlerAdapter() {
+				@Override
+				public void channelActive(ChannelHandlerContext ctx) throws Exception {
+					channel.set(ctx.channel());
+				}
+
+				@Override
+				public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+					received.add((ByteBuf) msg);
+				}
+			});
+
+			InetSocketAddress serverAddress = getKvStateServerAddress(serverChannel);
+
+			// Requests
+			List<Future<KvStateResponse>> futures = new ArrayList<>();
+			KvStateInternalRequest request = new KvStateInternalRequest(new KvStateID(), new byte[0]);
+
+			futures.add(client.sendRequest(serverAddress, request));
+			futures.add(client.sendRequest(serverAddress, request));
+
+			ByteBuf buf = received.poll(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+			assertNotNull("Receive timed out", buf);
+			buf.release();
+
+			buf = received.poll(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+			assertNotNull("Receive timed out", buf);
+			buf.release();
+
+			assertEquals(1L, stats.getNumConnections());
+
+			Channel ch = channel.get();
+			assertNotNull("Channel not active", ch);
+
+			// Respond with failure
+			ch.writeAndFlush(MessageSerializer.serializeServerFailure(
+					serverChannel.alloc(),
+					new RuntimeException("Expected test server failure")));
+
+			try {
+				futures.remove(0).get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+				fail("Did not throw expected server failure");
+			} catch (ExecutionException e) {
+
+				if (!(e.getCause() instanceof RuntimeException)) {
+					fail("Did not throw expected Exception");
+				}
+				// Expected
+			}
+
+			try {
+				futures.remove(0).get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+				fail("Did not throw expected server failure");
+			} catch (ExecutionException e) {
+
+				if (!(e.getCause() instanceof RuntimeException)) {
+					fail("Did not throw expected Exception");
+				}
+				// Expected
+			}
+
+			assertEquals(0L, stats.getNumConnections());
+
+			// Counts can take some time to propagate
+			while (deadline.hasTimeLeft() && (stats.getNumSuccessful() != 0L || stats.getNumFailed() != 2L)) {
+				Thread.sleep(100L);
+			}
+
+			assertEquals(2L, stats.getNumRequests());
+			assertEquals(0L, stats.getNumSuccessful());
+			assertEquals(2L, stats.getNumFailed());
+		} finally {
+			if (client != null) {
+				client.shutdown();
+			}
+
+			if (serverChannel != null) {
+				serverChannel.close();
+			}
+
+			assertEquals("Channel leak", 0L, stats.getNumConnections());
+		}
+	}
+
+	/**
+	 * Tests that closing the server-side channel closes the connection and
+	 * removes it from the established connections.
+	 */
+	@Test
+	public void testServerClosesChannel() throws Exception {
+		Deadline deadline = TEST_TIMEOUT.fromNow();
+		AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+		final MessageSerializer<KvStateInternalRequest, KvStateResponse> serializer =
+				new MessageSerializer<>(new KvStateInternalRequest.KvStateInternalRequestDeserializer(), new KvStateResponse.KvStateResponseDeserializer());
+
+		Client<KvStateInternalRequest, KvStateResponse> client = null;
+		Channel serverChannel = null;
+
+		try {
+			client = new Client<>("Test Client", 1, serializer, stats);
+
+			final AtomicBoolean received = new AtomicBoolean();
+			final AtomicReference<Channel> channel = new AtomicReference<>();
+
+			serverChannel = createServerChannel(new ChannelInboundHandlerAdapter() {
+				@Override
+				public void channelActive(ChannelHandlerContext ctx) throws Exception {
+					channel.set(ctx.channel());
+				}
+
+				@Override
+				public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+					received.set(true);
+				}
+			});
+
+			InetSocketAddress serverAddress = getKvStateServerAddress(serverChannel);
+
+			// Requests
+			KvStateInternalRequest request = new KvStateInternalRequest(new KvStateID(), new byte[0]);
+			Future<KvStateResponse> future = client.sendRequest(serverAddress, request);
+
+			while (!received.get() && deadline.hasTimeLeft()) {
+				Thread.sleep(50L);
+			}
+			assertTrue("Receive timed out", received.get());
+
+			assertEquals(1L, stats.getNumConnections());
+
+			channel.get().close().await(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+
+			try {
+				future.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+				fail("Did not throw expected server failure");
+			} catch (ExecutionException e) {
+				if (!(e.getCause() instanceof ClosedChannelException)) {
+					fail("Did not throw expected Exception");
+				}
+				// Expected
+			}
+
+			assertEquals(0L, stats.getNumConnections());
+
+			// Counts can take some time to propagate
+			while (deadline.hasTimeLeft() && (stats.getNumSuccessful() != 0L || stats.getNumFailed() != 1L)) {
+				Thread.sleep(100L);
+			}
+
+			assertEquals(1L, stats.getNumRequests());
+			assertEquals(0L, stats.getNumSuccessful());
+			assertEquals(1L, stats.getNumFailed());
+		} finally {
+			if (client != null) {
+				client.shutdown();
+			}
+
+			if (serverChannel != null) {
+				serverChannel.close();
+			}
+
+			assertEquals("Channel leak", 0L, stats.getNumConnections());
+		}
+	}
+
+	/**
+	 * Tests multiple clients querying multiple servers until 100k queries have
+	 * been processed. At this point, the client is shut down and it is verified
+	 * that all ongoing requests are failed.
+	 */
+	@Test
+	public void testClientServerIntegration() throws Throwable {
+		// Config
+		final int numServers = 2;
+		final int numServerEventLoopThreads = 2;
+		final int numServerQueryThreads = 2;
+
+		final int numClientEventLoopThreads = 4;
+		final int numClientsTasks = 8;
+
+		final int batchSize = 16;
+
+		final int numKeyGroups = 1;
+
+		AbstractStateBackend abstractBackend = new MemoryStateBackend();
+		KvStateRegistry dummyRegistry = new KvStateRegistry();
+		DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
+		dummyEnv.setKvStateRegistry(dummyRegistry);
+
+		AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
+				dummyEnv,
+				new JobID(),
+				"test_op",
+				IntSerializer.INSTANCE,
+				numKeyGroups,
+				new KeyGroupRange(0, 0),
+				dummyRegistry.createTaskRegistry(new JobID(), new JobVertexID()));
+
+		final FiniteDuration timeout = new FiniteDuration(10, TimeUnit.SECONDS);
+
+		AtomicKvStateRequestStats clientStats = new AtomicKvStateRequestStats();
+
+		final MessageSerializer<KvStateInternalRequest, KvStateResponse> serializer =
+				new MessageSerializer<>(new KvStateInternalRequest.KvStateInternalRequestDeserializer(), new KvStateResponse.KvStateResponseDeserializer());
+
+		Client<KvStateInternalRequest, KvStateResponse> client = null;
+		ExecutorService clientTaskExecutor = null;
+		final KvStateServerImpl[] server = new KvStateServerImpl[numServers];
+
+		try {
+			client = new Client<>("Test Client", numClientEventLoopThreads, serializer, clientStats);
+			clientTaskExecutor = Executors.newFixedThreadPool(numClientsTasks);
+
+			// Create state
+			ValueStateDescriptor<Integer> desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE);
+			desc.setQueryable("any");
+
+			// Create servers
+			KvStateRegistry[] registry = new KvStateRegistry[numServers];
+			AtomicKvStateRequestStats[] serverStats = new AtomicKvStateRequestStats[numServers];
+			final KvStateID[] ids = new KvStateID[numServers];
+
+			for (int i = 0; i < numServers; i++) {
+				registry[i] = new KvStateRegistry();
+				serverStats[i] = new AtomicKvStateRequestStats();
+				server[i] = new KvStateServerImpl(
+						InetAddress.getLocalHost(),
+						Collections.singletonList(0).iterator(),
+						numServerEventLoopThreads,
+						numServerQueryThreads,
+						registry[i],
+						serverStats[i]);
+
+				server[i].start();
+
+				backend.setCurrentKey(1010 + i);
+
+				// Value per server
+				ValueState<Integer> state = backend.getPartitionedState(
+						VoidNamespace.INSTANCE,
+						VoidNamespaceSerializer.INSTANCE,
+						desc);
+
+				state.update(201 + i);
+
+				// we know it must be a KvState but this is not exposed to the user via State
+				InternalKvState<?> kvState = (InternalKvState<?>) state;
+
+				// Register KvState (one state instance for all servers)
+				ids[i] = registry[i].registerKvState(new JobID(), new JobVertexID(), new KeyGroupRange(0, 0), "any", kvState);
+			}
+
+			final Client<KvStateInternalRequest, KvStateResponse> finalClient = client;
+			// Loops forever; it only terminates by throwing (failed query, assertion, or interrupt).
+			Callable<Void> queryTask = () -> {
+				while (true) {
+					if (Thread.interrupted()) {
+						throw new InterruptedException();
+					}
+
+					// Random server permutation
+					List<Integer> random = new ArrayList<>();
+					for (int j = 0; j < batchSize; j++) {
+						random.add(j);
+					}
+					Collections.shuffle(random);
+
+					// Dispatch queries
+					List<Future<KvStateResponse>> futures = new ArrayList<>(batchSize);
+
+					for (int j = 0; j < batchSize; j++) {
+						int targetServer = random.get(j) % numServers;
+
+						byte[] serializedKeyAndNamespace = KvStateSerializer.serializeKeyAndNamespace(
+								1010 + targetServer,
+								IntSerializer.INSTANCE,
+								VoidNamespace.INSTANCE,
+								VoidNamespaceSerializer.INSTANCE);
+
+						KvStateInternalRequest request = new KvStateInternalRequest(ids[targetServer], serializedKeyAndNamespace);
+						futures.add(finalClient.sendRequest(server[targetServer].getServerAddress(), request));
+					}
+
+					// Verify results
+					for (int j = 0; j < batchSize; j++) {
+						int targetServer = random.get(j) % numServers;
+
+						Future<KvStateResponse> future = futures.get(j);
+						byte[] buf = future.get(timeout.toMillis(), TimeUnit.MILLISECONDS).getContent();
+						int value = KvStateSerializer.deserializeValue(buf, IntSerializer.INSTANCE);
+						assertEquals(201L + targetServer, value);
+					}
+				}
+			};
+
+			// Submit tasks
+			List<Future<Void>> taskFutures = new ArrayList<>();
+			for (int i = 0; i < numClientsTasks; i++) {
+				taskFutures.add(clientTaskExecutor.submit(queryTask));
+			}
+
+			long numRequests;
+			while ((numRequests = clientStats.getNumRequests()) < 100_000L) {
+				// A finished query task can only have failed. Surface that failure right away
+				// instead of waiting forever for requests that may never be sent.
+				for (Future<Void> taskFuture : taskFutures) {
+					if (taskFuture.isDone()) {
+						taskFuture.get();
+					}
+				}
+
+				Thread.sleep(100L);
+				LOG.info("Number of requests {}/100_000", numRequests);
+			}
+
+			// Shut down
+			client.shutdown();
+
+			for (Future<Void> future : taskFutures) {
+				try {
+					future.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
+					fail("Did not throw expected Exception after shut down");
+				} catch (ExecutionException e) {
+					// The task rethrows the ExecutionException of its failed query future, which the
+					// executor wraps once more, so the query failure sits two levels down.
+					Throwable queryFailure = e.getCause() != null ? e.getCause().getCause() : null;
+					if (!(queryFailure instanceof ClosedChannelException) &&
+							!(queryFailure instanceof IllegalStateException)) {
+						LOG.error("Query task failed with unexpected exception", e);
+						fail("Failed with unexpected Exception type: " +
+								(queryFailure != null ? queryFailure : e.getCause()));
+					}
+					// else expected
+				}
+			}
+
+			assertEquals("Connection leak (client)", 0L, clientStats.getNumConnections());
+			for (int i = 0; i < numServers; i++) {
+				boolean success = false;
+				int numRetries = 0;
+				while (!success) {
+					try {
+						assertEquals("Connection leak (server)", 0L, serverStats[i].getNumConnections());
+						success = true;
+					} catch (Throwable t) {
+						if (numRetries < 10) {
+							LOG.info("Retrying connection leak check (server)");
+							Thread.sleep((numRetries + 1) * 50L);
+							numRetries++;
+						} else {
+							throw t;
+						}
+					}
+				}
+			}
+		} finally {
+			if (client != null) {
+				client.shutdown();
+			}
+
+			for (int i = 0; i < numServers; i++) {
+				if (server[i] != null) {
+					server[i].shutdown();
+				}
+			}
+
+			if (clientTaskExecutor != null) {
+				// Interrupt query tasks that may still be looping if the test failed early;
+				// they check Thread.interrupted() and terminate.
+				clientTaskExecutor.shutdownNow();
+			}
+
+			backend.dispose();
+		}
+	}
+
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Binds a server channel on an ephemeral local port whose child pipeline strips
+	 * the 4-byte length frame and then passes each frame to the given handlers.
+	 */
+	private Channel createServerChannel(final ChannelHandler... handlers) throws UnknownHostException, InterruptedException {
+		ServerBootstrap bootstrap = new ServerBootstrap()
+				// Bind address and port
+				.localAddress(InetAddress.getLocalHost(), 0)
+				// NIO server channels
+				.group(NIO_GROUP)
+				.channel(NioServerSocketChannel.class)
+				// See initializer for pipeline details
+				.childHandler(new ChannelInitializer<SocketChannel>() {
+					@Override
+					protected void initChannel(SocketChannel ch) throws Exception {
+						ch.pipeline()
+								.addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
+								.addLast(handlers);
+					}
+				});
+
+		return bootstrap.bind().sync().channel();
+	}
+
+	private InetSocketAddress getKvStateServerAddress(Channel serverChannel) {
+		return (InetSocketAddress) serverChannel.localAddress();
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateClientHandlerTest.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateClientHandlerTest.java b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateClientHandlerTest.java new file mode 100644 index 0000000..cb490aa --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateClientHandlerTest.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.flink.queryablestate.network;
+
+import org.apache.flink.queryablestate.messages.KvStateInternalRequest;
+import org.apache.flink.queryablestate.messages.KvStateResponse;
+import org.apache.flink.queryablestate.network.messages.MessageSerializer;
+
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
+
+import org.junit.Test;
+
+import java.nio.channels.ClosedChannelException;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+/**
+ * Tests for {@link ClientHandler}.
+ */
+public class KvStateClientHandlerTest {
+
+	/**
+	 * Tests that on reads the expected callback methods are called and read
+	 * buffers are recycled.
+	 *
+	 * <p>NOTE(review): the {@code onFailure} verifications below use cumulative counts
+	 * (1, 2, 3, 4) because Mockito 1.x {@code Matchers.any(Class)} does not check the
+	 * argument type. They therefore count every failure reported so far.
+	 */
+	@Test
+	@SuppressWarnings("unchecked") // mocking the generic ClientHandlerCallback interface
+	public void testReadCallbacksAndBufferRecycling() throws Exception {
+		final ClientHandlerCallback<KvStateResponse> callback = mock(ClientHandlerCallback.class);
+
+		final MessageSerializer<KvStateInternalRequest, KvStateResponse> serializer =
+				new MessageSerializer<>(new KvStateInternalRequest.KvStateInternalRequestDeserializer(), new KvStateResponse.KvStateResponseDeserializer());
+		final EmbeddedChannel channel = new EmbeddedChannel(new ClientHandler<>("Test Client", serializer, callback));
+
+		final byte[] content = new byte[0];
+		final KvStateResponse response = new KvStateResponse(content);
+
+		//
+		// Request success
+		//
+		ByteBuf buf = MessageSerializer.serializeResponse(channel.alloc(), 1222112277L, response);
+		buf.skipBytes(4); // skip frame length
+
+		// Verify callback
+		channel.writeInbound(buf);
+		verify(callback, times(1)).onRequestResult(eq(1222112277L), any(KvStateResponse.class));
+		assertEquals("Buffer not recycled", 0, buf.refCnt());
+
+		//
+		// Request failure
+		//
+		buf = MessageSerializer.serializeRequestFailure(
+				channel.alloc(),
+				1222112278L,
+				new RuntimeException("Expected test Exception"));
+		buf.skipBytes(4); // skip frame length
+
+		// Verify callback
+		channel.writeInbound(buf);
+		verify(callback, times(1)).onRequestFailure(eq(1222112278L), any(RuntimeException.class));
+		assertEquals("Buffer not recycled", 0, buf.refCnt());
+
+		//
+		// Server failure
+		//
+		buf = MessageSerializer.serializeServerFailure(
+				channel.alloc(),
+				new RuntimeException("Expected test Exception"));
+		buf.skipBytes(4); // skip frame length
+
+		// Verify callback
+		channel.writeInbound(buf);
+		verify(callback, times(1)).onFailure(any(RuntimeException.class));
+		assertEquals("Buffer not recycled", 0, buf.refCnt());
+
+		//
+		// Unexpected messages
+		//
+		buf = channel.alloc().buffer(4).writeInt(1223823);
+
+		// Verify callback
+		channel.writeInbound(buf);
+		verify(callback, times(2)).onFailure(any(IllegalStateException.class));
+		assertEquals("Buffer not recycled", 0, buf.refCnt());
+
+		//
+		// Exception caught
+		//
+		channel.pipeline().fireExceptionCaught(new RuntimeException("Expected test Exception"));
+		verify(callback, times(3)).onFailure(any(RuntimeException.class));
+
+		//
+		// Channel inactive
+		//
+		channel.pipeline().fireChannelInactive();
+		verify(callback, times(4)).onFailure(any(ClosedChannelException.class));
+	}
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateRequestSerializerTest.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateRequestSerializerTest.java b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateRequestSerializerTest.java new file mode 100644 index 0000000..d3314ab --- /dev/null +++
b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateRequestSerializerTest.java
@@ -0,0 +1,410 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.queryablestate.network;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.queryablestate.client.VoidNamespace;
+import org.apache.flink.queryablestate.client.VoidNamespaceSerializer;
+import org.apache.flink.queryablestate.client.state.serialization.KvStateSerializer;
+import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
+import org.apache.flink.runtime.state.internal.InternalKvState;
+import org.apache.flink.runtime.state.internal.InternalListState;
+import org.apache.flink.runtime.state.internal.InternalMapState;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ThreadLocalRandom;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for {@link KvStateSerializer}.
+ */
+@RunWith(Parameterized.class)
+public class KvStateRequestSerializerTest {
+
+	@Parameterized.Parameters
+	public static Collection<Boolean> parameters() {
+		return Arrays.asList(false, true);
+	}
+
+	@Parameterized.Parameter
+	public boolean async;
+
+	/**
+	 * Tests key and namespace serialization utils.
+	 */
+	@Test
+	public void testKeyAndNamespaceSerialization() throws Exception {
+		TypeSerializer<Long> keySerializer = LongSerializer.INSTANCE;
+		TypeSerializer<String> namespaceSerializer = StringSerializer.INSTANCE;
+
+		long expectedKey = Integer.MAX_VALUE + 12323L;
+		String expectedNamespace = "knilf";
+
+		byte[] serializedKeyAndNamespace = KvStateSerializer.serializeKeyAndNamespace(
+				expectedKey, keySerializer, expectedNamespace, namespaceSerializer);
+
+		Tuple2<Long, String> actual = KvStateSerializer.deserializeKeyAndNamespace(
+				serializedKeyAndNamespace, keySerializer, namespaceSerializer);
+
+		assertEquals(expectedKey, actual.f0.longValue());
+		assertEquals(expectedNamespace, actual.f1);
+	}
+
+	/**
+	 * Tests key and namespace deserialization utils with zero bytes.
+	 */
+	@Test(expected = IOException.class)
+	public void testKeyAndNamespaceDeserializationEmpty() throws Exception {
+		KvStateSerializer.deserializeKeyAndNamespace(
+				new byte[] {}, LongSerializer.INSTANCE, StringSerializer.INSTANCE);
+	}
+
+	/**
+	 * Tests key and namespace deserialization utils with too few bytes.
+	 */
+	@Test(expected = IOException.class)
+	public void testKeyAndNamespaceDeserializationTooShort() throws Exception {
+		KvStateSerializer.deserializeKeyAndNamespace(
+				new byte[] {1}, LongSerializer.INSTANCE, StringSerializer.INSTANCE);
+	}
+
+	/**
+	 * Tests key and namespace deserialization utils with too many bytes.
+	 */
+	@Test(expected = IOException.class)
+	public void testKeyAndNamespaceDeserializationTooMany1() throws Exception {
+		// Long + null String + 1 byte
+		KvStateSerializer.deserializeKeyAndNamespace(
+				new byte[] {1, 1, 1, 1, 1, 1, 1, 1, 42, 0, 2}, LongSerializer.INSTANCE,
+				StringSerializer.INSTANCE);
+	}
+
+	/**
+	 * Tests key and namespace deserialization utils with too many bytes.
+	 */
+	@Test(expected = IOException.class)
+	public void testKeyAndNamespaceDeserializationTooMany2() throws Exception {
+		// Long + null String + 2 bytes
+		KvStateSerializer.deserializeKeyAndNamespace(
+				new byte[] {1, 1, 1, 1, 1, 1, 1, 1, 42, 0, 2, 2}, LongSerializer.INSTANCE,
+				StringSerializer.INSTANCE);
+	}
+
+	/**
+	 * Tests value serialization utils.
+	 */
+	@Test
+	public void testValueSerialization() throws Exception {
+		TypeSerializer<Long> valueSerializer = LongSerializer.INSTANCE;
+		long expectedValue = Long.MAX_VALUE - 1292929292L;
+
+		byte[] serializedValue = KvStateSerializer.serializeValue(expectedValue, valueSerializer);
+		long actualValue = KvStateSerializer.deserializeValue(serializedValue, valueSerializer);
+
+		assertEquals(expectedValue, actualValue);
+	}
+
+	/**
+	 * Tests value deserialization with zero bytes.
+	 */
+	@Test(expected = IOException.class)
+	public void testDeserializeValueEmpty() throws Exception {
+		KvStateSerializer.deserializeValue(new byte[] {}, LongSerializer.INSTANCE);
+	}
+
+	/**
+	 * Tests value deserialization with too few bytes.
+	 */
+	@Test(expected = IOException.class)
+	public void testDeserializeValueTooShort() throws Exception {
+		// 1 byte (incomplete Long)
+		KvStateSerializer.deserializeValue(new byte[] {1}, LongSerializer.INSTANCE);
+	}
+
+	/**
+	 * Tests value deserialization with too many bytes.
+	 */
+	@Test(expected = IOException.class)
+	public void testDeserializeValueTooMany1() throws Exception {
+		// Long + 1 byte
+		KvStateSerializer.deserializeValue(new byte[] {1, 1, 1, 1, 1, 1, 1, 1, 2},
+				LongSerializer.INSTANCE);
+	}
+
+	/**
+	 * Tests value deserialization with too many bytes.
+	 */
+	@Test(expected = IOException.class)
+	public void testDeserializeValueTooMany2() throws Exception {
+		// Long + 2 bytes
+		KvStateSerializer.deserializeValue(new byte[] {1, 1, 1, 1, 1, 1, 1, 1, 2, 2},
+				LongSerializer.INSTANCE);
+	}
+
+	/**
+	 * Tests list serialization utils.
+	 */
+	@Test
+	public void testListSerialization() throws Exception {
+		final long key = 0L;
+
+		// objects for heap state list serialisation
+		final HeapKeyedStateBackend<Long> longHeapKeyedStateBackend =
+			new HeapKeyedStateBackend<>(
+				mock(TaskKvStateRegistry.class),
+				LongSerializer.INSTANCE,
+				ClassLoader.getSystemClassLoader(),
+				1,
+				new KeyGroupRange(0, 0),
+				async,
+				new ExecutionConfig()
+			);
+		longHeapKeyedStateBackend.setCurrentKey(key);
+
+		final InternalListState<VoidNamespace, Long> listState = longHeapKeyedStateBackend.createListState(
+				VoidNamespaceSerializer.INSTANCE,
+				new ListStateDescriptor<>("test", LongSerializer.INSTANCE));
+
+		testListSerialization(key, listState);
+	}
+
+	/**
+	 * Verifies that the serialization of a list using the given list state
+	 * matches the deserialization with {@link KvStateSerializer#deserializeList}.
+	 *
+	 * @param key
+	 * 		key of the list state
+	 * @param listState
+	 * 		list state using the {@link VoidNamespace}, must also be a {@link InternalKvState} instance
+	 *
+	 * @throws Exception
+	 */
+	public static void testListSerialization(
+			final long key,
+			final InternalListState<VoidNamespace, Long> listState) throws Exception {
+
+		TypeSerializer<Long> valueSerializer = LongSerializer.INSTANCE;
+		listState.setCurrentNamespace(VoidNamespace.INSTANCE);
+
+		// List
+		final int numElements = 10;
+
+		final List<Long> expectedValues = new ArrayList<>();
+		for (int i = 0; i < numElements; i++) {
+			final long value = ThreadLocalRandom.current().nextLong();
+			expectedValues.add(value);
+			listState.add(value);
+		}
+
+		final byte[] serializedKey =
+			KvStateSerializer.serializeKeyAndNamespace(
+				key, LongSerializer.INSTANCE,
+				VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE);
+
+		final byte[] serializedValues = listState.getSerializedValue(serializedKey);
+
+		List<Long> actualValues = KvStateSerializer.deserializeList(serializedValues, valueSerializer);
+		assertEquals(expectedValues, actualValues);
+
+		// Single value
+		long expectedValue = ThreadLocalRandom.current().nextLong();
+		byte[] serializedValue = KvStateSerializer.serializeValue(expectedValue, valueSerializer);
+		List<Long> actualValue = KvStateSerializer.deserializeList(serializedValue, valueSerializer);
+		assertEquals(1, actualValue.size());
+		assertEquals(expectedValue, actualValue.get(0).longValue());
+	}
+
+	/**
+	 * Tests that list deserialization of an empty byte array yields an empty list.
+	 */
+	@Test
+	public void testDeserializeListEmpty() throws Exception {
+		List<Long> actualValue = KvStateSerializer
+			.deserializeList(new byte[] {}, LongSerializer.INSTANCE);
+		assertEquals(0, actualValue.size());
+	}
+
+	/**
+	 * Tests list deserialization with too few bytes.
+	 */
+	@Test(expected = IOException.class)
+	public void testDeserializeListTooShort1() throws Exception {
+		// 1 byte (incomplete Long)
+		KvStateSerializer.deserializeList(new byte[] {1}, LongSerializer.INSTANCE);
+	}
+
+	/**
+	 * Tests list deserialization with too few bytes.
+	 */
+	@Test(expected = IOException.class)
+	public void testDeserializeListTooShort2() throws Exception {
+		// Long + 1 byte (separator) + 1 byte (incomplete Long)
+		KvStateSerializer.deserializeList(new byte[] {1, 1, 1, 1, 1, 1, 1, 1, 2, 3},
+				LongSerializer.INSTANCE);
+	}
+
+	/**
+	 * Tests map serialization utils.
+	 */
+	@Test
+	public void testMapSerialization() throws Exception {
+		final long key = 0L;
+
+		// objects for heap state map serialisation
+		final HeapKeyedStateBackend<Long> longHeapKeyedStateBackend =
+			new HeapKeyedStateBackend<>(
+				mock(TaskKvStateRegistry.class),
+				LongSerializer.INSTANCE,
+				ClassLoader.getSystemClassLoader(),
+				1,
+				new KeyGroupRange(0, 0),
+				async,
+				new ExecutionConfig()
+			);
+		longHeapKeyedStateBackend.setCurrentKey(key);
+
+		final InternalMapState<VoidNamespace, Long, String> mapState = (InternalMapState<VoidNamespace, Long, String>) longHeapKeyedStateBackend.getPartitionedState(
+				VoidNamespace.INSTANCE,
+				VoidNamespaceSerializer.INSTANCE,
+				new MapStateDescriptor<>("test", LongSerializer.INSTANCE, StringSerializer.INSTANCE));
+
+		testMapSerialization(key, mapState);
+	}
+
+	/**
+	 * Verifies that the serialization of a map using the given map state
+	 * matches the deserialization with {@link KvStateSerializer#deserializeMap}.
+	 *
+	 * @param key
+	 * 		key of the map state
+	 * @param mapState
+	 * 		map state using the {@link VoidNamespace}, must also be a {@link InternalKvState} instance
+	 *
+	 * @throws Exception
+	 */
+	public static void testMapSerialization(
+			final long key,
+			final InternalMapState<VoidNamespace, Long, String> mapState) throws Exception {
+
+		TypeSerializer<Long> userKeySerializer = LongSerializer.INSTANCE;
+		TypeSerializer<String> userValueSerializer = StringSerializer.INSTANCE;
+		mapState.setCurrentNamespace(VoidNamespace.INSTANCE);
+
+		// Map
+		final int numElements = 10;
+
+		final Map<Long, String> expectedValues = new HashMap<>();
+		for (int i = 1; i <= numElements; i++) {
+			final long value = ThreadLocalRandom.current().nextLong();
+			expectedValues.put(value, Long.toString(value));
+			mapState.put(value, Long.toString(value));
+		}
+
+		expectedValues.put(0L, null);
+		mapState.put(0L, null);
+
+		final byte[] serializedKey =
+			KvStateSerializer.serializeKeyAndNamespace(
+				key, LongSerializer.INSTANCE,
+				VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE);
+
+		final byte[] serializedValues = mapState.getSerializedValue(serializedKey);
+
+		Map<Long, String> actualValues = KvStateSerializer.deserializeMap(serializedValues, userKeySerializer, userValueSerializer);
+		assertEquals(expectedValues.size(), actualValues.size());
+		for (Map.Entry<Long, String> actualEntry : actualValues.entrySet()) {
+			assertEquals(expectedValues.get(actualEntry.getKey()), actualEntry.getValue());
+		}
+
+		// Single value
+		ByteArrayOutputStream baos = new ByteArrayOutputStream();
+		long expectedKey = ThreadLocalRandom.current().nextLong();
+		String expectedValue = Long.toString(expectedKey);
+		byte[] isNull = {0};
+
+		baos.write(KvStateSerializer.serializeValue(expectedKey, userKeySerializer));
+		baos.write(isNull);
+		baos.write(KvStateSerializer.serializeValue(expectedValue, userValueSerializer));
+		byte[] serializedValue = baos.toByteArray();
+
+		Map<Long, String> actualValue = KvStateSerializer.deserializeMap(serializedValue, userKeySerializer, userValueSerializer);
+		assertEquals(1, actualValue.size());
+		assertEquals(expectedValue, actualValue.get(expectedKey));
+	}
+
+	/**
+	 * Tests that map deserialization of an empty byte array yields an empty map.
+	 */
+	@Test
+	public void testDeserializeMapEmpty() throws Exception {
+		Map<Long, String> actualValue = KvStateSerializer
+			.deserializeMap(new byte[] {}, LongSerializer.INSTANCE, StringSerializer.INSTANCE);
+		assertEquals(0, actualValue.size());
+	}
+
+	/**
+	 * Tests map deserialization with too few bytes.
+	 */
+	@Test(expected = IOException.class)
+	public void testDeserializeMapTooShort1() throws Exception {
+		// 1 byte (incomplete Key)
+		KvStateSerializer.deserializeMap(new byte[] {1}, LongSerializer.INSTANCE, StringSerializer.INSTANCE);
+	}
+
+	/**
+	 * Tests map deserialization with too few bytes.
+	 */
+	@Test(expected = IOException.class)
+	public void testDeserializeMapTooShort2() throws Exception {
+		// Long (Key) + Boolean (false, i.e. a non-null Value follows) + missing Value
+		KvStateSerializer.deserializeMap(new byte[]{1, 1, 1, 1, 1, 1, 1, 1, 0},
+				LongSerializer.INSTANCE, LongSerializer.INSTANCE);
+	}
+
+	/**
+	 * Tests map deserialization with too few bytes.
+	 */
+	@Test(expected = IOException.class)
+	public void testDeserializeMapTooShort3() throws Exception {
+		// Long (Key1) + Boolean (false) + Long (Value1) + 1 byte (incomplete Key2)
+		KvStateSerializer.deserializeMap(new byte[] {1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 3},
+				LongSerializer.INSTANCE, LongSerializer.INSTANCE);
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateServerHandlerTest.java
----------------------------------------------------------------------
diff --git a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateServerHandlerTest.java b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateServerHandlerTest.java
new file mode 100644
index 0000000..041544d
--- /dev/null
+++ b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateServerHandlerTest.java
@@ -0,0 +1,758 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.queryablestate.network; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer; +import org.apache.flink.queryablestate.KvStateID; +import org.apache.flink.queryablestate.client.VoidNamespace; +import org.apache.flink.queryablestate.client.VoidNamespaceSerializer; +import org.apache.flink.queryablestate.client.state.serialization.KvStateSerializer; +import org.apache.flink.queryablestate.exceptions.UnknownKeyOrNamespaceException; +import org.apache.flink.queryablestate.exceptions.UnknownKvStateIdException; +import org.apache.flink.queryablestate.messages.KvStateInternalRequest; +import org.apache.flink.queryablestate.messages.KvStateResponse; +import org.apache.flink.queryablestate.network.messages.MessageSerializer; +import org.apache.flink.queryablestate.network.messages.MessageType; +import org.apache.flink.queryablestate.network.messages.RequestFailure; +import org.apache.flink.queryablestate.network.stats.AtomicKvStateRequestStats; +import org.apache.flink.queryablestate.network.stats.DisabledKvStateRequestStats; +import org.apache.flink.queryablestate.network.stats.KvStateRequestStats; +import org.apache.flink.queryablestate.server.KvStateServerHandler; +import org.apache.flink.queryablestate.server.KvStateServerImpl; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.operators.testutils.DummyEnvironment; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.query.KvStateRegistryListener; +import 
org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.AbstractStateBackend; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.internal.InternalKvState; +import org.apache.flink.runtime.state.memory.MemoryStateBackend; +import org.apache.flink.util.TestLogger; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler; +import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.LengthFieldBasedFrameDecoder; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.net.InetAddress; +import java.util.Collections; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link KvStateServerHandler}. + */ +public class KvStateServerHandlerTest extends TestLogger { + + private static KvStateServerImpl testServer; + + private static final long READ_TIMEOUT_MILLIS = 10000L; + + @BeforeClass + public static void setup() { + try { + testServer = new KvStateServerImpl( + InetAddress.getLocalHost(), + Collections.singletonList(0).iterator(), + 1, + 1, + new KvStateRegistry(), + new DisabledKvStateRequestStats()); + testServer.start(); + } catch (Throwable e) { + e.printStackTrace(); + } + } + + @AfterClass + public static void tearDown() throws Exception { + testServer.shutdown(); + } + + /** + * Tests a simple successful query via an EmbeddedChannel. 
+	 */
+	@Test
+	public void testSimpleQuery() throws Exception {
+		KvStateRegistry registry = new KvStateRegistry();
+		AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+		MessageSerializer<KvStateInternalRequest, KvStateResponse> serializer =
+				new MessageSerializer<>(new KvStateInternalRequest.KvStateInternalRequestDeserializer(), new KvStateResponse.KvStateResponseDeserializer());
+
+		KvStateServerHandler handler = new KvStateServerHandler(testServer, registry, serializer, stats);
+		EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler);
+
+		// Register state
+		ValueStateDescriptor<Integer> desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE);
+		desc.setQueryable("vanilla");
+
+		int numKeyGroups = 1;
+		AbstractStateBackend abstractBackend = new MemoryStateBackend();
+		DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
+		dummyEnv.setKvStateRegistry(registry);
+		AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
+				dummyEnv,
+				new JobID(),
+				"test_op",
+				IntSerializer.INSTANCE,
+				numKeyGroups,
+				new KeyGroupRange(0, 0),
+				registry.createTaskRegistry(dummyEnv.getJobID(), dummyEnv.getJobVertexId()));
+
+		final TestRegistryListener registryListener = new TestRegistryListener();
+		registry.registerListener(registryListener);
+
+		// Update the KvState and request it
+		int expectedValue = 712828289;
+
+		int key = 99812822;
+		backend.setCurrentKey(key);
+		ValueState<Integer> state = backend.getPartitionedState(
+				VoidNamespace.INSTANCE,
+				VoidNamespaceSerializer.INSTANCE,
+				desc);
+
+		state.update(expectedValue);
+
+		byte[] serializedKeyAndNamespace = KvStateSerializer.serializeKeyAndNamespace(
+				key,
+				IntSerializer.INSTANCE,
+				VoidNamespace.INSTANCE,
+				VoidNamespaceSerializer.INSTANCE);
+
+		long requestId = Integer.MAX_VALUE + 182828L;
+
+		assertEquals("vanilla", registryListener.registrationName);
+
+		KvStateInternalRequest request = new KvStateInternalRequest(registryListener.kvStateId, serializedKeyAndNamespace);
+
+		ByteBuf serRequest = MessageSerializer.serializeRequest(channel.alloc(), requestId, request);
+
+		// Write the request and wait for the response
+		channel.writeInbound(serRequest);
+
+		ByteBuf buf = (ByteBuf) readInboundBlocking(channel);
+		buf.skipBytes(4); // skip frame length
+
+		// Verify the response
+		assertEquals(MessageType.REQUEST_RESULT, MessageSerializer.deserializeHeader(buf));
+		long deserRequestId = MessageSerializer.getRequestId(buf);
+		KvStateResponse response = serializer.deserializeResponse(buf);
+
+		assertEquals(requestId, deserRequestId);
+
+		int actualValue = KvStateSerializer.deserializeValue(response.getContent(), IntSerializer.INSTANCE);
+		assertEquals(expectedValue, actualValue);
+
+		assertEquals(stats.toString(), 1, stats.getNumRequests());
+
+		// Wait for async successful request report
+		long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(30);
+		while (stats.getNumSuccessful() != 1L && System.nanoTime() <= deadline) {
+			Thread.sleep(10L);
+		}
+
+		assertEquals(stats.toString(), 1L, stats.getNumSuccessful());
+	}
+
+	/**
+	 * Tests the failure response with {@link UnknownKvStateIdException} as cause on
+	 * queries for unregistered KvStateIDs.
+	 */
+	@Test
+	public void testQueryUnknownKvStateID() throws Exception {
+		KvStateRegistry registry = new KvStateRegistry();
+		AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+		MessageSerializer<KvStateInternalRequest, KvStateResponse> serializer =
+				new MessageSerializer<>(new KvStateInternalRequest.KvStateInternalRequestDeserializer(), new KvStateResponse.KvStateResponseDeserializer());
+
+		KvStateServerHandler handler = new KvStateServerHandler(testServer, registry, serializer, stats);
+		EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler);
+
+		long requestId = Integer.MAX_VALUE + 182828L;
+
+		KvStateInternalRequest request = new KvStateInternalRequest(new KvStateID(), new byte[0]);
+
+		ByteBuf serRequest = MessageSerializer.serializeRequest(channel.alloc(), requestId, request);
+
+		// Write the request and wait for the response
+		channel.writeInbound(serRequest);
+
+		ByteBuf buf = (ByteBuf) readInboundBlocking(channel);
+		buf.skipBytes(4); // skip frame length
+
+		// Verify the response
+		assertEquals(MessageType.REQUEST_FAILURE, MessageSerializer.deserializeHeader(buf));
+		RequestFailure response = MessageSerializer.deserializeRequestFailure(buf);
+
+		assertEquals(requestId, response.getRequestId());
+
+		assertTrue("Did not respond with expected failure cause", response.getCause() instanceof UnknownKvStateIdException);
+
+		assertEquals(1L, stats.getNumRequests());
+		assertEquals(1L, stats.getNumFailed());
+	}
+
+	/**
+	 * Tests the failure response with {@link UnknownKeyOrNamespaceException} as cause
+	 * on queries for non-existing keys.
+	 */
+	@Test
+	public void testQueryUnknownKey() throws Exception {
+		KvStateRegistry registry = new KvStateRegistry();
+		AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+		MessageSerializer<KvStateInternalRequest, KvStateResponse> serializer =
+				new MessageSerializer<>(new KvStateInternalRequest.KvStateInternalRequestDeserializer(), new KvStateResponse.KvStateResponseDeserializer());
+
+		KvStateServerHandler handler = new KvStateServerHandler(testServer, registry, serializer, stats);
+		EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler);
+
+		int numKeyGroups = 1;
+		AbstractStateBackend abstractBackend = new MemoryStateBackend();
+		DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0);
+		dummyEnv.setKvStateRegistry(registry);
+		KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend(
+				dummyEnv,
+				new JobID(),
+				"test_op",
+				IntSerializer.INSTANCE,
+				numKeyGroups,
+				new KeyGroupRange(0, 0),
+				registry.createTaskRegistry(dummyEnv.getJobID(), dummyEnv.getJobVertexId()));
+
+		final TestRegistryListener registryListener = new TestRegistryListener();
+		registry.registerListener(registryListener);
+
+		// Register state
+		ValueStateDescriptor<Integer> desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE);
+		desc.setQueryable("vanilla");
+
+		backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, desc);
+
+		byte[] serializedKeyAndNamespace = KvStateSerializer.serializeKeyAndNamespace(
+				1238283,
+				IntSerializer.INSTANCE,
+				VoidNamespace.INSTANCE,
+				VoidNamespaceSerializer.INSTANCE);
+
+		long requestId = Integer.MAX_VALUE + 22982L;
+
+		assertEquals("vanilla", registryListener.registrationName);
+
+		KvStateInternalRequest request = new KvStateInternalRequest(registryListener.kvStateId, serializedKeyAndNamespace);
+		ByteBuf serRequest = MessageSerializer.serializeRequest(channel.alloc(), requestId, request);
+
+		// Write the request and wait for the response
+		channel.writeInbound(serRequest);
+
+		ByteBuf buf = (ByteBuf) readInboundBlocking(channel);
+		buf.skipBytes(4); // skip frame length
+
+		// Verify the response
+		assertEquals(MessageType.REQUEST_FAILURE, MessageSerializer.deserializeHeader(buf));
+		RequestFailure response = MessageSerializer.deserializeRequestFailure(buf);
+
+		assertEquals(requestId, response.getRequestId());
+
+		assertTrue("Did not respond with expected failure cause", response.getCause() instanceof UnknownKeyOrNamespaceException);
+
+		assertEquals(1L, stats.getNumRequests());
+		assertEquals(1L, stats.getNumFailed());
+	}
+
+	/**
+	 * Tests the failure response on a failure on the {@link InternalKvState#getSerializedValue(byte[])} call.
+	 */
+	@Test
+	public void testFailureOnGetSerializedValue() throws Exception {
+		KvStateRegistry registry = new KvStateRegistry();
+		AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+		MessageSerializer<KvStateInternalRequest, KvStateResponse> serializer =
+				new MessageSerializer<>(new KvStateInternalRequest.KvStateInternalRequestDeserializer(), new KvStateResponse.KvStateResponseDeserializer());
+
+		KvStateServerHandler handler = new KvStateServerHandler(testServer, registry, serializer, stats);
+		EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler);
+
+		// Failing KvState
+		InternalKvState<?> kvState = mock(InternalKvState.class);
+		when(kvState.getSerializedValue(any(byte[].class)))
+				.thenThrow(new RuntimeException("Expected test Exception"));
+
+		KvStateID kvStateId = registry.registerKvState(
+				new JobID(),
+				new JobVertexID(),
+				new KeyGroupRange(0, 0),
+				"vanilla",
+				kvState);
+
+		KvStateInternalRequest request = new KvStateInternalRequest(kvStateId, new byte[0]);
+		ByteBuf serRequest = MessageSerializer.serializeRequest(channel.alloc(), 282872L, request);
+
+		// Write the request and wait for the response
+		channel.writeInbound(serRequest);
+
+		ByteBuf buf = (ByteBuf) readInboundBlocking(channel);
+		buf.skipBytes(4); // skip frame length
+
+		// Verify the response
+		assertEquals(MessageType.REQUEST_FAILURE, MessageSerializer.deserializeHeader(buf));
+		RequestFailure response = MessageSerializer.deserializeRequestFailure(buf);
+		assertEquals(282872L, response.getRequestId());
+		assertTrue(response.getCause().getMessage().contains("Expected test Exception"));
+
+		assertEquals(1L, stats.getNumRequests());
+		assertEquals(1L, stats.getNumFailed());
+	}
+
+	/**
+	 * Tests that the channel is closed if an Exception reaches the channel handler.
+	 */
+	@Test
+	public void testCloseChannelOnExceptionCaught() throws Exception {
+		KvStateRegistry registry = new KvStateRegistry();
+		AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+		MessageSerializer<KvStateInternalRequest, KvStateResponse> serializer =
+				new MessageSerializer<>(new KvStateInternalRequest.KvStateInternalRequestDeserializer(), new KvStateResponse.KvStateResponseDeserializer());
+
+		KvStateServerHandler handler = new KvStateServerHandler(testServer, registry, serializer, stats);
+		EmbeddedChannel channel = new EmbeddedChannel(handler);
+
+		channel.pipeline().fireExceptionCaught(new RuntimeException("Expected test Exception"));
+
+		ByteBuf buf = (ByteBuf) readInboundBlocking(channel);
+		buf.skipBytes(4); // skip frame length
+
+		// Verify the response
+		assertEquals(MessageType.SERVER_FAILURE, MessageSerializer.deserializeHeader(buf));
+		Throwable response = MessageSerializer.deserializeServerFailure(buf);
+
+		assertTrue(response.getMessage().contains("Expected test Exception"));
+
+		assertTrue("Channel was not closed in time", channel.closeFuture().await(READ_TIMEOUT_MILLIS));
+		assertFalse(channel.isActive());
+	}
+
+	/**
+	 * Tests the failure response on a rejected execution, because the query executor has been closed.
+ */ + @Test + public void testQueryExecutorShutDown() throws Throwable { + KvStateRegistry registry = new KvStateRegistry(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + + KvStateServerImpl localTestServer = new KvStateServerImpl( + InetAddress.getLocalHost(), + Collections.singletonList(0).iterator(), + 1, + 1, + new KvStateRegistry(), + new DisabledKvStateRequestStats()); + + localTestServer.start(); + localTestServer.shutdown(); + assertTrue(localTestServer.isExecutorShutdown()); + + MessageSerializer<KvStateInternalRequest, KvStateResponse> serializer = + new MessageSerializer<>(new KvStateInternalRequest.KvStateInternalRequestDeserializer(), new KvStateResponse.KvStateResponseDeserializer()); + + KvStateServerHandler handler = new KvStateServerHandler(localTestServer, registry, serializer, stats); + EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler); + + int numKeyGroups = 1; + AbstractStateBackend abstractBackend = new MemoryStateBackend(); + DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0); + dummyEnv.setKvStateRegistry(registry); + KeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend( + dummyEnv, + new JobID(), + "test_op", + IntSerializer.INSTANCE, + numKeyGroups, + new KeyGroupRange(0, 0), + registry.createTaskRegistry(dummyEnv.getJobID(), dummyEnv.getJobVertexId())); + + final TestRegistryListener registryListener = new TestRegistryListener(); + registry.registerListener(registryListener); + + // Register state + ValueStateDescriptor<Integer> desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE); + desc.setQueryable("vanilla"); + + backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, desc); + + assertTrue(registryListener.registrationName.equals("vanilla")); + + KvStateInternalRequest request = new KvStateInternalRequest(registryListener.kvStateId, new byte[0]); + ByteBuf serRequest = 
MessageSerializer.serializeRequest(channel.alloc(), 282872L, request); + + // Write the request and wait for the response + channel.writeInbound(serRequest); + + ByteBuf buf = (ByteBuf) readInboundBlocking(channel); + buf.skipBytes(4); // skip frame length + + // Verify the response + assertEquals(MessageType.REQUEST_FAILURE, MessageSerializer.deserializeHeader(buf)); + RequestFailure response = MessageSerializer.deserializeRequestFailure(buf); + + assertTrue(response.getCause().getMessage().contains("RejectedExecutionException")); + + assertEquals(1L, stats.getNumRequests()); + assertEquals(1L, stats.getNumFailed()); + + localTestServer.shutdown(); + } + + /** + * Tests response on unexpected messages. + */ + @Test + public void testUnexpectedMessage() throws Exception { + KvStateRegistry registry = new KvStateRegistry(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + + MessageSerializer<KvStateInternalRequest, KvStateResponse> serializer = + new MessageSerializer<>(new KvStateInternalRequest.KvStateInternalRequestDeserializer(), new KvStateResponse.KvStateResponseDeserializer()); + + KvStateServerHandler handler = new KvStateServerHandler(testServer, registry, serializer, stats); + EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler); + + // Write the request and wait for the response + ByteBuf unexpectedMessage = Unpooled.buffer(8); + unexpectedMessage.writeInt(4); + unexpectedMessage.writeInt(123238213); + + channel.writeInbound(unexpectedMessage); + + ByteBuf buf = (ByteBuf) readInboundBlocking(channel); + buf.skipBytes(4); // skip frame length + + // Verify the response + assertEquals(MessageType.SERVER_FAILURE, MessageSerializer.deserializeHeader(buf)); + Throwable response = MessageSerializer.deserializeServerFailure(buf); + + assertEquals(0L, stats.getNumRequests()); + assertEquals(0L, stats.getNumFailed()); + + KvStateResponse stateResponse = new KvStateResponse(new byte[0]); + unexpectedMessage = 
MessageSerializer.serializeResponse(channel.alloc(), 192L, stateResponse); + + channel.writeInbound(unexpectedMessage); + + buf = (ByteBuf) readInboundBlocking(channel); + buf.skipBytes(4); // skip frame length + + // Verify the response + assertEquals(MessageType.SERVER_FAILURE, MessageSerializer.deserializeHeader(buf)); + response = MessageSerializer.deserializeServerFailure(buf); + + assertTrue("Unexpected failure cause " + response.getClass().getName(), response instanceof IllegalArgumentException); + + assertEquals(0L, stats.getNumRequests()); + assertEquals(0L, stats.getNumFailed()); + } + + /** + * Tests that incoming buffer instances are recycled. + */ + @Test + public void testIncomingBufferIsRecycled() throws Exception { + KvStateRegistry registry = new KvStateRegistry(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + + MessageSerializer<KvStateInternalRequest, KvStateResponse> serializer = + new MessageSerializer<>(new KvStateInternalRequest.KvStateInternalRequestDeserializer(), new KvStateResponse.KvStateResponseDeserializer()); + + KvStateServerHandler handler = new KvStateServerHandler(testServer, registry, serializer, stats); + EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler); + + KvStateInternalRequest request = new KvStateInternalRequest(new KvStateID(), new byte[0]); + ByteBuf serRequest = MessageSerializer.serializeRequest(channel.alloc(), 282872L, request); + + assertEquals(1L, serRequest.refCnt()); + + // Write regular request + channel.writeInbound(serRequest); + assertEquals("Buffer not recycled", 0L, serRequest.refCnt()); + + // Write unexpected msg + ByteBuf unexpected = channel.alloc().buffer(8); + unexpected.writeInt(4); + unexpected.writeInt(4); + + assertEquals(1L, unexpected.refCnt()); + + channel.writeInbound(unexpected); + assertEquals("Buffer not recycled", 0L, unexpected.refCnt()); + } + + /** + * Tests the failure response if the serializers don't match. 
+ */ + @Test + public void testSerializerMismatch() throws Exception { + KvStateRegistry registry = new KvStateRegistry(); + AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats(); + + MessageSerializer<KvStateInternalRequest, KvStateResponse> serializer = + new MessageSerializer<>(new KvStateInternalRequest.KvStateInternalRequestDeserializer(), new KvStateResponse.KvStateResponseDeserializer()); + + KvStateServerHandler handler = new KvStateServerHandler(testServer, registry, serializer, stats); + EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler); + + int numKeyGroups = 1; + AbstractStateBackend abstractBackend = new MemoryStateBackend(); + DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0); + dummyEnv.setKvStateRegistry(registry); + AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend( + dummyEnv, + new JobID(), + "test_op", + IntSerializer.INSTANCE, + numKeyGroups, + new KeyGroupRange(0, 0), + registry.createTaskRegistry(dummyEnv.getJobID(), dummyEnv.getJobVertexId())); + + final TestRegistryListener registryListener = new TestRegistryListener(); + registry.registerListener(registryListener); + + // Register state + ValueStateDescriptor<Integer> desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE); + desc.setQueryable("vanilla"); + + ValueState<Integer> state = backend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + desc); + + int key = 99812822; + + // Update the KvState + backend.setCurrentKey(key); + state.update(712828289); + + byte[] wrongKeyAndNamespace = KvStateSerializer.serializeKeyAndNamespace( + "wrong-key-type", + StringSerializer.INSTANCE, + "wrong-namespace-type", + StringSerializer.INSTANCE); + + byte[] wrongNamespace = KvStateSerializer.serializeKeyAndNamespace( + key, + IntSerializer.INSTANCE, + "wrong-namespace-type", + StringSerializer.INSTANCE); + + assertTrue(registryListener.registrationName.equals("vanilla")); + + 
KvStateInternalRequest request = new KvStateInternalRequest(registryListener.kvStateId, wrongKeyAndNamespace); + ByteBuf serRequest = MessageSerializer.serializeRequest(channel.alloc(), 182828L, request); + + // Write the request and wait for the response + channel.writeInbound(serRequest); + + ByteBuf buf = (ByteBuf) readInboundBlocking(channel); + buf.skipBytes(4); // skip frame length + + // Verify the response + assertEquals(MessageType.REQUEST_FAILURE, MessageSerializer.deserializeHeader(buf)); + RequestFailure response = MessageSerializer.deserializeRequestFailure(buf); + assertEquals(182828L, response.getRequestId()); + assertTrue(response.getCause().getMessage().contains("IOException")); + + // Repeat with wrong namespace only + request = new KvStateInternalRequest(registryListener.kvStateId, wrongNamespace); + serRequest = MessageSerializer.serializeRequest(channel.alloc(), 182829L, request); + + // Write the request and wait for the response + channel.writeInbound(serRequest); + + buf = (ByteBuf) readInboundBlocking(channel); + buf.skipBytes(4); // skip frame length + + // Verify the response + assertEquals(MessageType.REQUEST_FAILURE, MessageSerializer.deserializeHeader(buf)); + response = MessageSerializer.deserializeRequestFailure(buf); + assertEquals(182829L, response.getRequestId()); + assertTrue(response.getCause().getMessage().contains("IOException")); + + assertEquals(2L, stats.getNumRequests()); + assertEquals(2L, stats.getNumFailed()); + } + + /** + * Tests that large responses are chunked. 
+ */ + @Test + public void testChunkedResponse() throws Exception { + KvStateRegistry registry = new KvStateRegistry(); + KvStateRequestStats stats = new AtomicKvStateRequestStats(); + + MessageSerializer<KvStateInternalRequest, KvStateResponse> serializer = + new MessageSerializer<>(new KvStateInternalRequest.KvStateInternalRequestDeserializer(), new KvStateResponse.KvStateResponseDeserializer()); + + KvStateServerHandler handler = new KvStateServerHandler(testServer, registry, serializer, stats); + EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler); + + int numKeyGroups = 1; + AbstractStateBackend abstractBackend = new MemoryStateBackend(); + DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0); + dummyEnv.setKvStateRegistry(registry); + AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend( + dummyEnv, + new JobID(), + "test_op", + IntSerializer.INSTANCE, + numKeyGroups, + new KeyGroupRange(0, 0), + registry.createTaskRegistry(dummyEnv.getJobID(), dummyEnv.getJobVertexId())); + + final TestRegistryListener registryListener = new TestRegistryListener(); + registry.registerListener(registryListener); + + // Register state + ValueStateDescriptor<byte[]> desc = new ValueStateDescriptor<>("any", BytePrimitiveArraySerializer.INSTANCE); + desc.setQueryable("vanilla"); + + ValueState<byte[]> state = backend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + desc); + + // Update KvState + byte[] bytes = new byte[2 * channel.config().getWriteBufferHighWaterMark()]; + + byte current = 0; + for (int i = 0; i < bytes.length; i++) { + bytes[i] = current++; + } + + int key = 99812822; + backend.setCurrentKey(key); + state.update(bytes); + + // Request + byte[] serializedKeyAndNamespace = KvStateSerializer.serializeKeyAndNamespace( + key, + IntSerializer.INSTANCE, + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE); + + long requestId = Integer.MAX_VALUE + 182828L; + + 
assertTrue(registryListener.registrationName.equals("vanilla")); + + KvStateInternalRequest request = new KvStateInternalRequest(registryListener.kvStateId, serializedKeyAndNamespace); + ByteBuf serRequest = MessageSerializer.serializeRequest(channel.alloc(), requestId, request); + + // Write the request and wait for the response + channel.writeInbound(serRequest); + + Object msg = readInboundBlocking(channel); + assertTrue("Not ChunkedByteBuf", msg instanceof ChunkedByteBuf); + } + + // ------------------------------------------------------------------------ + + /** + * Queries the embedded channel for data. + */ + private Object readInboundBlocking(EmbeddedChannel channel) throws InterruptedException, TimeoutException { + final long sleepMillis = 50L; + + long sleptMillis = 0L; + + Object msg = null; + while (sleptMillis < READ_TIMEOUT_MILLIS && + (msg = channel.readOutbound()) == null) { + + Thread.sleep(sleepMillis); + sleptMillis += sleepMillis; + } + + if (msg == null) { + throw new TimeoutException(); + } else { + return msg; + } + } + + /** + * Frame length decoder (expected by the serialized messages). + */ + private ChannelHandler getFrameDecoder() { + return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4); + } + + /** + * A listener that keeps the last updated KvState information so that a test + * can retrieve it. 
+ */ + static class TestRegistryListener implements KvStateRegistryListener { + volatile JobVertexID jobVertexID; + volatile KeyGroupRange keyGroupIndex; + volatile String registrationName; + volatile KvStateID kvStateId; + + @Override + public void notifyKvStateRegistered(JobID jobId, + JobVertexID jobVertexId, + KeyGroupRange keyGroupRange, + String registrationName, + KvStateID kvStateId) { + this.jobVertexID = jobVertexId; + this.keyGroupIndex = keyGroupRange; + this.registrationName = registrationName; + this.kvStateId = kvStateId; + } + + @Override + public void notifyKvStateUnregistered(JobID jobId, + JobVertexID jobVertexId, + KeyGroupRange keyGroupRange, + String registrationName) { + + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/0c771505/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateServerTest.java ---------------------------------------------------------------------- diff --git a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateServerTest.java b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateServerTest.java new file mode 100644 index 0000000..debd190 --- /dev/null +++ b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateServerTest.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.queryablestate.network; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.queryablestate.client.VoidNamespace; +import org.apache.flink.queryablestate.client.VoidNamespaceSerializer; +import org.apache.flink.queryablestate.client.state.serialization.KvStateSerializer; +import org.apache.flink.queryablestate.messages.KvStateInternalRequest; +import org.apache.flink.queryablestate.messages.KvStateResponse; +import org.apache.flink.queryablestate.network.messages.MessageSerializer; +import org.apache.flink.queryablestate.network.messages.MessageType; +import org.apache.flink.queryablestate.network.stats.AtomicKvStateRequestStats; +import org.apache.flink.queryablestate.network.stats.KvStateRequestStats; +import org.apache.flink.queryablestate.server.KvStateServerImpl; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.operators.testutils.DummyEnvironment; +import org.apache.flink.runtime.query.KvStateRegistry; +import org.apache.flink.runtime.state.AbstractKeyedStateBackend; +import org.apache.flink.runtime.state.AbstractStateBackend; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.memory.MemoryStateBackend; + +import org.apache.flink.shaded.netty4.io.netty.bootstrap.Bootstrap; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; 
+import org.apache.flink.shaded.netty4.io.netty.channel.Channel; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandler; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInitializer; +import org.apache.flink.shaded.netty4.io.netty.channel.EventLoopGroup; +import org.apache.flink.shaded.netty4.io.netty.channel.nio.NioEventLoopGroup; +import org.apache.flink.shaded.netty4.io.netty.channel.socket.SocketChannel; +import org.apache.flink.shaded.netty4.io.netty.channel.socket.nio.NioSocketChannel; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.LengthFieldBasedFrameDecoder; + +import org.junit.AfterClass; +import org.junit.Test; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Tests for {@link KvStateServerImpl}. + */ +public class KvStateServerTest { + + // Thread pool for client bootstrap (shared between tests) + private static final NioEventLoopGroup NIO_GROUP = new NioEventLoopGroup(); + + private static final int TIMEOUT_MILLIS = 10000; + + @AfterClass + public static void tearDown() throws Exception { + if (NIO_GROUP != null) { + NIO_GROUP.shutdownGracefully(); + } + } + + /** + * Tests a simple successful query via a SocketChannel. 
+ */ + @Test + public void testSimpleRequest() throws Throwable { + KvStateServerImpl server = null; + Bootstrap bootstrap = null; + try { + KvStateRegistry registry = new KvStateRegistry(); + KvStateRequestStats stats = new AtomicKvStateRequestStats(); + + server = new KvStateServerImpl( + InetAddress.getLocalHost(), + Collections.singletonList(0).iterator(), + 1, + 1, + registry, + stats); + server.start(); + + InetSocketAddress serverAddress = server.getServerAddress(); + int numKeyGroups = 1; + AbstractStateBackend abstractBackend = new MemoryStateBackend(); + DummyEnvironment dummyEnv = new DummyEnvironment("test", 1, 0); + dummyEnv.setKvStateRegistry(registry); + AbstractKeyedStateBackend<Integer> backend = abstractBackend.createKeyedStateBackend( + dummyEnv, + new JobID(), + "test_op", + IntSerializer.INSTANCE, + numKeyGroups, + new KeyGroupRange(0, 0), + registry.createTaskRegistry(new JobID(), new JobVertexID())); + + final KvStateServerHandlerTest.TestRegistryListener registryListener = + new KvStateServerHandlerTest.TestRegistryListener(); + + registry.registerListener(registryListener); + + ValueStateDescriptor<Integer> desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE); + desc.setQueryable("vanilla"); + + ValueState<Integer> state = backend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + desc); + + // Update KvState + int expectedValue = 712828289; + + int key = 99812822; + backend.setCurrentKey(key); + state.update(expectedValue); + + // Request + byte[] serializedKeyAndNamespace = KvStateSerializer.serializeKeyAndNamespace( + key, + IntSerializer.INSTANCE, + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE); + + // Connect to the server + final BlockingQueue<ByteBuf> responses = new LinkedBlockingQueue<>(); + bootstrap = createBootstrap( + new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4), + new ChannelInboundHandlerAdapter() { + @Override + public void 
channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + responses.add((ByteBuf) msg); + } + }); + + Channel channel = bootstrap + .connect(serverAddress.getAddress(), serverAddress.getPort()) + .sync().channel(); + + long requestId = Integer.MAX_VALUE + 182828L; + + assertTrue(registryListener.registrationName.equals("vanilla")); + + final KvStateInternalRequest request = new KvStateInternalRequest( + registryListener.kvStateId, + serializedKeyAndNamespace); + + ByteBuf serializeRequest = MessageSerializer.serializeRequest( + channel.alloc(), + requestId, + request); + + channel.writeAndFlush(serializeRequest); + + ByteBuf buf = responses.poll(TIMEOUT_MILLIS, TimeUnit.MILLISECONDS); + + assertEquals(MessageType.REQUEST_RESULT, MessageSerializer.deserializeHeader(buf)); + assertEquals(requestId, MessageSerializer.getRequestId(buf)); + KvStateResponse response = server.getSerializer().deserializeResponse(buf); + + int actualValue = KvStateSerializer.deserializeValue(response.getContent(), IntSerializer.INSTANCE); + assertEquals(expectedValue, actualValue); + } finally { + if (server != null) { + server.shutdown(); + } + + if (bootstrap != null) { + EventLoopGroup group = bootstrap.group(); + if (group != null) { + group.shutdownGracefully(); + } + } + } + } + + /** + * Creates a client bootstrap. + */ + private Bootstrap createBootstrap(final ChannelHandler... handlers) { + return new Bootstrap().group(NIO_GROUP).channel(NioSocketChannel.class) + .handler(new ChannelInitializer<SocketChannel>() { + @Override + protected void initChannel(SocketChannel ch) throws Exception { + ch.pipeline().addLast(handlers); + } + }); + } + +}
