WencongLiu commented on code in PR #20236: URL: https://github.com/apache/flink/pull/20236#discussion_r931887347
########## flink-table/flink-sql-gateway/src/test/java/org/apache/flink/table/gateway/rest/SqlGatewayRestEndpointITCase.java: ########## @@ -0,0 +1,596 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.gateway.rest; + +import org.apache.flink.api.common.time.Time; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.configuration.SecurityOptions; +import org.apache.flink.core.testutils.BlockerSync; +import org.apache.flink.core.testutils.OneShotLatch; +import org.apache.flink.runtime.rest.HttpMethodWrapper; +import org.apache.flink.runtime.rest.RestClient; +import org.apache.flink.runtime.rest.RestServerEndpoint; +import org.apache.flink.runtime.rest.handler.HandlerRequest; +import org.apache.flink.runtime.rest.messages.EmptyMessageParameters; +import org.apache.flink.runtime.rest.messages.EmptyRequestBody; +import org.apache.flink.runtime.rest.messages.RequestBody; +import org.apache.flink.runtime.rest.messages.ResponseBody; +import org.apache.flink.runtime.rest.util.RestClientException; +import org.apache.flink.runtime.rest.util.TestRestServerEndpoint; +import org.apache.flink.runtime.rpc.RpcUtils; +import org.apache.flink.runtime.rpc.exceptions.EndpointNotStartedException; +import org.apache.flink.table.gateway.api.SqlGatewayService; +import org.apache.flink.table.gateway.rest.handler.AbstractSqlGatewayRestHandler; +import org.apache.flink.table.gateway.rest.header.SqlGatewayMessageHeaders; +import org.apache.flink.table.gateway.rest.util.SqlGatewayRestAPIVersion; +import org.apache.flink.table.gateway.rest.util.SqlGatewayRestOptions; +import org.apache.flink.table.gateway.rest.util.TestSqlGatewayRestEndpoint; +import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.FlinkRuntimeException; +import org.apache.flink.util.TestLogger; +import org.apache.flink.util.concurrent.ExecutorThreadFactory; +import org.apache.flink.util.concurrent.FutureUtils; + +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.flink.shaded.netty4.io.netty.handler.codec.http.HttpResponseStatus; + +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import org.hamcrest.MatcherAssert; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Collection; +import java.util.Collections; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + +import static org.apache.flink.core.testutils.CommonTestUtils.assertThrows; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** IT cases for {@link SqlGatewayRestEndpoint}. */ +public class SqlGatewayRestEndpointITCase extends TestLogger { + + private static final SqlGatewayService service = null; + + private static RestServerEndpoint serverEndpoint; + private static RestClient restClient; + private static InetSocketAddress serverAddress; + + private static TestBadCaseHandler testHandler; + private static TestVersionSelectionHeaders1 header1; + private static TestVersionSelectionHeaders2 header2; + private static TestBadCaseHeaders badCaseHeader; + private static TestVersionHandler testVersionHandler1; + private static TestVersionHandler testVersionHandler2; + + private static Configuration config; + private static final Time timeout = Time.seconds(10L); + + private static Configuration getBaseConfig() { + final String loopbackAddress = InetAddress.getLoopbackAddress().getHostAddress(); + + final Configuration config = new Configuration(); + config.setString(SqlGatewayRestOptions.BIND_PORT, "0"); + config.setString(SqlGatewayRestOptions.BIND_ADDRESS, loopbackAddress); + config.setString(SqlGatewayRestOptions.ADDRESS, loopbackAddress); + return config; + } + + @BeforeEach + public void setup() throws Exception { + // Test version cases + header1 = new TestVersionSelectionHeaders1(); + header2 = new TestVersionSelectionHeaders2(); + testVersionHandler1 = new TestVersionHandler(service, RpcUtils.INF_TIMEOUT, header1); + testVersionHandler2 = new TestVersionHandler(service, RpcUtils.INF_TIMEOUT, header2); + + // Test exception cases + badCaseHeader = new TestBadCaseHeaders(); + testHandler = new TestBadCaseHandler(service, RpcUtils.INF_TIMEOUT); + + // Init + config = getBaseConfig(); + serverEndpoint = + TestSqlGatewayRestEndpoint.builder(config, service) + .withHandler(badCaseHeader, testHandler) + .withHandler(header1, testVersionHandler1) + .withHandler(header2, testVersionHandler2) + .buildAndStart(); + + restClient = + new RestClient( + config, + Executors.newFixedThreadPool( + 1, new ExecutorThreadFactory("rest-client-thread-pool"))); + serverAddress = serverEndpoint.getServerAddress(); + } + + @BeforeEach + public void stop() throws Exception { + + if (restClient != null) { + restClient.shutdown(timeout); + restClient = null; + } + + if (serverEndpoint != null) { + serverEndpoint.closeAsync().get(timeout.getSize(), timeout.getUnit()); + serverEndpoint = null; + } + } + + /** Test that {@link SqlGatewayMessageHeaders} can identify the version correctly. */ + @Test + public void testSqlGatewayMessageHeaders() throws Exception { + // The header only support V1, but send request by V0 + Assertions.assertThrows( + IllegalArgumentException.class, + () -> + restClient.sendRequest( + serverAddress.getHostName(), + serverAddress.getPort(), + header2, + EmptyMessageParameters.getInstance(), + EmptyRequestBody.getInstance(), + Collections.emptyList(), + SqlGatewayRestAPIVersion.V0)); + + // The header only support V1, send request by V1 + CompletableFuture<TestResponse> specifiedVersionResponse = + restClient.sendRequest( + serverAddress.getHostName(), + serverAddress.getPort(), + header2, + EmptyMessageParameters.getInstance(), + EmptyRequestBody.getInstance(), + Collections.emptyList(), + SqlGatewayRestAPIVersion.V1); + + TestResponse testResponse1 = specifiedVersionResponse.get(5, TimeUnit.SECONDS); + assertEquals("V1", testResponse1.getStatus()); + + // The header only support V1, send request by latest version V1 + CompletableFuture<TestResponse> unspecifiedVersionResponse = + restClient.sendRequest( + serverAddress.getHostName(), + serverAddress.getPort(), + header2, + EmptyMessageParameters.getInstance(), + EmptyRequestBody.getInstance(), + Collections.emptyList()); + + TestResponse testResponse2 = unspecifiedVersionResponse.get(5, TimeUnit.SECONDS); + assertEquals("V1", testResponse2.getStatus()); + } + + /** Test that requests of different version are routed to correct handlers. */ + @Test + public void testVersionSelection() throws Exception { + CompletableFuture<TestResponse> version1Response = + restClient.sendRequest( + serverAddress.getHostName(), + serverAddress.getPort(), + header1, + EmptyMessageParameters.getInstance(), + EmptyRequestBody.getInstance(), + Collections.emptyList(), + SqlGatewayRestAPIVersion.V0); + + TestResponse testResponse = version1Response.get(5, TimeUnit.SECONDS); + assertEquals("V0", testResponse.getStatus()); + + CompletableFuture<TestResponse> version2Response = + restClient.sendRequest( + serverAddress.getHostName(), + serverAddress.getPort(), + header2, + EmptyMessageParameters.getInstance(), + EmptyRequestBody.getInstance(), + Collections.emptyList(), + SqlGatewayRestAPIVersion.V1); + TestResponse testResponse2 = version2Response.get(5, TimeUnit.SECONDS); + assertEquals("V1", testResponse2.getStatus()); + } + + /** + * Test that {@link AbstractSqlGatewayRestHandler} will use the default endpoint version when + * the url does not contain version. + */ + @Test + public void testDefaultVersionRouting() throws Exception { + assertFalse( + config.getBoolean(SecurityOptions.SSL_REST_ENABLED), + "Ignoring SSL-enabled test to keep OkHttp usage simple."); + + OkHttpClient client = new OkHttpClient(); + final Request request = + new Request.Builder() + .url(serverEndpoint.getRestBaseUrl() + header1.getTargetRestEndpointURL()) + .build(); + + final Response response = client.newCall(request).execute(); + assert response.body() != null; + assertTrue(response.body().string().contains("V1")); + } + + /** + * Tests that request are handled as individual units which don't interfere with each other. + * This means that request responses can overtake each other. + */ + @Test + public void testRequestInterleaving() throws Exception { + final BlockerSync sync = new BlockerSync(); + testHandler.handlerBody = + id -> { + if (id == 1) { + try { + sync.block(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + return CompletableFuture.completedFuture(new TestResponse(id.toString())); + }; + + // send first request and wait until the handler blocks + final CompletableFuture<TestResponse> response1 = + sendRequestToTestHandler(new TestRequest(1)); + sync.awaitBlocker(); + + // send second request and verify response + final CompletableFuture<TestResponse> response2 = + sendRequestToTestHandler(new TestRequest(2)); + assertEquals("2", response2.get().status); + + // wake up blocked handler + sync.releaseBlocker(); + + // verify response to first request + assertEquals("1", response1.get().status); + } + + @Test + public void testDuplicateHandlerRegistrationIsForbidden() { + assertThrows( + "Duplicate REST handler", + FlinkRuntimeException.class, + () -> { + try (TestRestServerEndpoint restServerEndpoint = + TestRestServerEndpoint.builder(config) + .withHandler(header1, testHandler) + .withHandler(badCaseHeader, testHandler) + .build()) { + restServerEndpoint.start(); + return null; + } + }); + } + + @Test + public void testEndpointsMustBeUnique() { + assertThrows( + "REST handler registration", + FlinkRuntimeException.class, + () -> { + try (TestRestServerEndpoint restServerEndpoint = + TestRestServerEndpoint.builder(config) + .withHandler(badCaseHeader, testHandler) + .withHandler(badCaseHeader, testVersionHandler1) + .build()) { + restServerEndpoint.start(); + return null; + } + }); + } + + /** + * Tests that after calling {@link SqlGatewayRestEndpoint#closeAsync()}, the handlers are closed + * first, and we wait for in-flight requests to finish. As long as not all handlers are closed, + * HTTP requests should be served. + */ + @Test + public void testShouldWaitForHandlersWhenClosing() throws Exception { + testHandler.closeFuture = new CompletableFuture<>(); + final BlockerSync sync = new BlockerSync(); + testHandler.handlerBody = + id -> { + // Intentionally schedule the work on a different thread. This is to simulate + // handlers where the CompletableFuture is finished by the RPC framework. + return CompletableFuture.supplyAsync( + () -> { + try { + sync.block(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return new TestResponse(id.toString()); + }); + }; + + // Initiate closing RestServerEndpoint but the test handler should block. + final CompletableFuture<Void> closeRestServerEndpointFuture = serverEndpoint.closeAsync(); + assertFalse(closeRestServerEndpointFuture.isDone()); + + // create an in-flight request + final CompletableFuture<TestResponse> request = + sendRequestToTestHandler(new TestRequest(1)); + sync.awaitBlocker(); + + // Allow handler to close but there is still one in-flight request which should prevent + // the RestServerEndpoint from closing. + testHandler.closeFuture.complete(null); + assertFalse(closeRestServerEndpointFuture.isDone()); + + // Finish the in-flight request. + sync.releaseBlocker(); + + request.get(timeout.getSize(), timeout.getUnit()); + closeRestServerEndpointFuture.get(timeout.getSize(), timeout.getUnit()); + } + + @Test + public void testRestServerBindPort() throws Exception { + final int portRangeStart = 52300; + final int portRangeEnd = 52400; + final Configuration config = new Configuration(); + config.setString(RestOptions.ADDRESS, "localhost"); + config.setString(RestOptions.BIND_PORT, portRangeStart + "-" + portRangeEnd); Review Comment: Done. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
