This is an automated email from the ASF dual-hosted git repository.

uce pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 440076a751e169358935960eb34fdd8d9051357b
Author: Ufuk Celebi <[email protected]>
AuthorDate: Tue May 28 11:41:55 2024 +0200

    [FLINK-26808][rest] Only accept file upload at multipart routes
---
 .../flink/runtime/rest/FileUploadHandler.java      | 31 ++++++++-
 .../flink/runtime/rest/RestServerEndpoint.java     | 33 +++++++++-
 .../runtime/rest/RestServerEndpointITCase.java     | 77 ++++++++++++++++------
 3 files changed, 116 insertions(+), 25 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/FileUploadHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/FileUploadHandler.java
index c9e1fd78d74..3a969a57308 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/FileUploadHandler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/FileUploadHandler.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.rest;
 
 import org.apache.flink.runtime.rest.handler.FileUploads;
+import org.apache.flink.runtime.rest.handler.router.MultipartRoutes;
 import org.apache.flink.runtime.rest.handler.util.HandlerUtils;
 import org.apache.flink.runtime.rest.messages.ErrorResponseBody;
 import org.apache.flink.runtime.rest.util.RestConstants;
@@ -79,15 +80,15 @@ public class FileUploadHandler extends SimpleChannelInboundHandler<HttpObject> {
 
     private final Path uploadDir;
 
+    private final MultipartRoutes multipartRoutes;
+
     private HttpPostRequestDecoder currentHttpPostRequestDecoder;
 
     private HttpRequest currentHttpRequest;
     private byte[] currentJsonPayload;
     private Path currentUploadDir;
 
-    private boolean addCRPrefix = false;
-
-    public FileUploadHandler(final Path uploadDir) {
+    public FileUploadHandler(final Path uploadDir, final MultipartRoutes 
multipartRoutes) {
         super(true);
 
         // the clean up of temp files when jvm exits is handled by
@@ -103,6 +104,7 @@ public class FileUploadHandler extends SimpleChannelInboundHandler<HttpObject> {
         DiskAttribute.baseDirectory = DiskFileUpload.baseDirectory;
 
         this.uploadDir = requireNonNull(uploadDir);
+        this.multipartRoutes = requireNonNull(multipartRoutes);
     }
 
     @Override
@@ -125,6 +127,18 @@ public class FileUploadHandler extends SimpleChannelInboundHandler<HttpObject> {
                                 new HttpPostRequestDecoder(DATA_FACTORY, 
httpRequest);
                         currentHttpRequest = 
ReferenceCountUtil.retain(httpRequest);
 
+                        // We check this after initializing the multipart file 
upload in order for
+                        // handleError to work correctly.
+                        if (!multipartRoutes.isPostRoute(httpRequest.uri())) {
+                            LOG.trace("POST request not allowed for {}.", 
httpRequest.uri());
+                            handleError(
+                                    ctx,
+                                    "POST request not allowed",
+                                    HttpResponseStatus.BAD_REQUEST,
+                                    null);
+                            return;
+                        }
+
                         // make sure that we still have a upload dir in case 
that it got deleted in
                         // the meanwhile
                         RestServerEndpoint.createUploadDir(uploadDir, LOG, 
false);
@@ -151,6 +165,17 @@ public class FileUploadHandler extends SimpleChannelInboundHandler<HttpObject> {
                         && hasNext(currentHttpPostRequestDecoder)) {
                     final InterfaceHttpData data = 
currentHttpPostRequestDecoder.next();
                     if (data.getHttpDataType() == 
InterfaceHttpData.HttpDataType.FileUpload) {
+                        HttpRequest httpRequest = currentHttpRequest;
+                        if 
(!multipartRoutes.isFileUploadRoute(httpRequest.uri())) {
+                            LOG.trace("File upload not allowed for {}.", 
httpRequest.uri());
+                            handleError(
+                                    ctx,
+                                    "File upload not allowed",
+                                    HttpResponseStatus.BAD_REQUEST,
+                                    null);
+                            return;
+                        }
+
                         final DiskFileUpload fileUpload = (DiskFileUpload) 
data;
                         checkState(fileUpload.isCompleted());
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestServerEndpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestServerEndpoint.java
index 817e25521d8..4d021e019d0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestServerEndpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestServerEndpoint.java
@@ -28,8 +28,10 @@ import org.apache.flink.runtime.io.network.netty.SSLHandlerFactory;
 import org.apache.flink.runtime.net.RedirectingSslHandler;
 import org.apache.flink.runtime.rest.handler.PipelineErrorHandler;
 import org.apache.flink.runtime.rest.handler.RestHandlerSpecification;
+import org.apache.flink.runtime.rest.handler.router.MultipartRoutes;
 import org.apache.flink.runtime.rest.handler.router.Router;
 import org.apache.flink.runtime.rest.handler.router.RouterHandler;
+import org.apache.flink.runtime.rest.messages.UntypedResponseMessageHeaders;
 import org.apache.flink.runtime.rest.versioning.RestAPIVersion;
 import org.apache.flink.util.AutoCloseableAsync;
 import org.apache.flink.util.ConfigurationException;
@@ -196,6 +198,9 @@ public abstract class RestServerEndpoint implements RestService {
             checkAllEndpointsAndHandlersAreUnique(handlers);
             handlers.forEach(handler -> registerHandler(router, handler, log));
 
+            MultipartRoutes multipartRoutes = createMultipartRoutes(handlers);
+            log.debug("Using {} for FileUploadHandler", multipartRoutes);
+
             ChannelInitializer<SocketChannel> initializer =
                     new ChannelInitializer<SocketChannel>() {
 
@@ -216,7 +221,7 @@ public abstract class RestServerEndpoint implements RestService {
 
                             ch.pipeline()
                                     .addLast(new HttpServerCodec())
-                                    .addLast(new FileUploadHandler(uploadDir))
+                                    .addLast(new FileUploadHandler(uploadDir, 
multipartRoutes))
                                     .addLast(
                                             new FlinkHttpObjectAggregator(
                                                     maxContentLength, 
responseHeaders));
@@ -635,6 +640,32 @@ public abstract class RestServerEndpoint implements RestService {
         }
     }
 
+    private MultipartRoutes createMultipartRoutes(
+            List<Tuple2<RestHandlerSpecification, ChannelInboundHandler>> 
handlers) {
+        MultipartRoutes.Builder builder = new MultipartRoutes.Builder();
+
+        for (Tuple2<RestHandlerSpecification, ChannelInboundHandler> handler : 
handlers) {
+            if (handler.f0.getHttpMethod() == HttpMethodWrapper.POST) {
+                for (String url : getHandlerRoutes(handler.f0)) {
+                    builder.addPostRoute(url);
+                }
+            }
+
+            // The cast is necessary, because currently only 
UntypedResponseMessageHeaders exposes
+            // whether the handler accepts file uploads.
+            if (handler.f0 instanceof UntypedResponseMessageHeaders) {
+                UntypedResponseMessageHeaders<?, ?> header =
+                        (UntypedResponseMessageHeaders<?, ?>) handler.f0;
+                if (header.acceptsFileUploads()) {
+                    for (String url : getHandlerRoutes(header)) {
+                        builder.addFileUploadRoute(url);
+                    }
+                }
+            }
+        }
+        return builder.build();
+    }
+
     private static Iterable<String> getHandlerRoutes(RestHandlerSpecification 
handlerSpec) {
         final List<String> registeredRoutes = new ArrayList<>();
         final String handlerUrl = handlerSpec.getTargetRestEndpointURL();
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestServerEndpointITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestServerEndpointITCase.java
index 2c6d3b0afef..bcc1ee07f0e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestServerEndpointITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rest/RestServerEndpointITCase.java
@@ -374,25 +374,12 @@ public class RestServerEndpointITCase {
     @TestTemplate
     void testFileUpload() throws Exception {
         final String boundary = generateMultiPartBoundary();
-        final String crlf = "\r\n";
         final String uploadedContent = "hello";
-        final HttpURLConnection connection = 
openHttpConnectionForUpload(boundary);
-
-        try (OutputStream output = connection.getOutputStream();
-                PrintWriter writer =
-                        new PrintWriter(
-                                new OutputStreamWriter(output, 
StandardCharsets.UTF_8), true)) {
+        final HttpURLConnection connection =
+                openHttpConnectionForUpload(
+                        boundary, 
TestUploadHeaders.INSTANCE.getTargetRestEndpointURL());
 
-            writer.append("--" + boundary).append(crlf);
-            writer.append("Content-Disposition: form-data; name=\"foo\"; 
filename=\"bar\"")
-                    .append(crlf);
-            writer.append("Content-Type: plain/text; 
charset=utf8").append(crlf);
-            writer.append(crlf).flush();
-            output.write(uploadedContent.getBytes(StandardCharsets.UTF_8));
-            output.flush();
-            writer.append(crlf).flush();
-            writer.append("--" + boundary + "--").append(crlf).flush();
-        }
+        uploadFile(connection, uploadedContent, boundary);
 
         assertThat(connection.getResponseCode()).isEqualTo(200);
         final byte[] lastUploadedFileContents = 
testUploadHandler.getLastUploadedFileContents();
@@ -400,6 +387,32 @@ public class RestServerEndpointITCase {
                 .isEqualTo(new String(lastUploadedFileContents, 
StandardCharsets.UTF_8));
     }
 
+    /**
+     * Tests that when a handler is marked as not accepting file uploads we 
(1) return an error and
+     * (2) don't upload the file to the upload directory.
+     */
+    @TestTemplate
+    void testFileUploadLimitedToAllowedUris() throws Exception {
+        final String boundary = generateMultiPartBoundary();
+        final File uploadDir = new File(tempFolder.toString(), 
"flink-web-upload");
+        final File[] preUploadFiles = uploadDir.listFiles();
+
+        // We need a handler that does not accept file uploads for this test
+        assertThat(TestVersionHeaders.INSTANCE.acceptsFileUploads()).isFalse();
+        String uri = TestVersionHeaders.INSTANCE.getTargetRestEndpointURL();
+
+        final HttpURLConnection connection = 
openHttpConnectionForUpload(boundary, uri);
+
+        uploadFile(connection, "hello", boundary);
+
+        assertThat(connection.getResponseCode()).isEqualTo(400);
+
+        // This is the important check. We don't want additional files when 
the handler does
+        // not accept file uploads.
+        final File[] postUploadFiles = uploadDir.listFiles();
+        assertThat(postUploadFiles).isEqualTo(preUploadFiles);
+    }
+
     /**
      * Sending multipart/form-data without a file should result in a bad 
request if the handler
      * expects a file upload.
@@ -408,7 +421,9 @@ public class RestServerEndpointITCase {
     void testMultiPartFormDataWithoutFileUpload() throws Exception {
         final String boundary = generateMultiPartBoundary();
         final String crlf = "\r\n";
-        final HttpURLConnection connection = 
openHttpConnectionForUpload(boundary);
+        final HttpURLConnection connection =
+                openHttpConnectionForUpload(
+                        boundary, 
TestUploadHeaders.INSTANCE.getTargetRestEndpointURL());
 
         try (OutputStream output = connection.getOutputStream();
                 PrintWriter writer =
@@ -715,11 +730,11 @@ public class RestServerEndpointITCase {
         return new File(resource.getFile());
     }
 
-    private HttpURLConnection openHttpConnectionForUpload(final String 
boundary)
-            throws IOException {
+    private HttpURLConnection openHttpConnectionForUpload(
+            final String boundary, final String uploadUri) throws IOException {
         final HttpURLConnection connection =
                 (HttpURLConnection)
-                        new URL(serverEndpoint.getRestBaseUrl() + 
"/upload").openConnection();
+                        new URL(serverEndpoint.getRestBaseUrl() + 
uploadUri).openConnection();
         connection.setDoOutput(true);
         connection.setRequestProperty("Content-Type", "multipart/form-data; 
boundary=" + boundary);
         return connection;
@@ -737,6 +752,26 @@ public class RestServerEndpointITCase {
         return sb.toString();
     }
 
+    private static void uploadFile(HttpURLConnection connection, String 
content, String boundary)
+            throws IOException {
+        final String crlf = "\r\n";
+        try (OutputStream output = connection.getOutputStream();
+                PrintWriter writer =
+                        new PrintWriter(
+                                new OutputStreamWriter(output, 
StandardCharsets.UTF_8), true)) {
+
+            writer.append("--" + boundary).append(crlf);
+            writer.append("Content-Disposition: form-data; name=\"foo\"; 
filename=\"bar\"")
+                    .append(crlf);
+            writer.append("Content-Type: plain/text; 
charset=utf8").append(crlf);
+            writer.append(crlf).flush();
+            output.write(content.getBytes(StandardCharsets.UTF_8));
+            output.flush();
+            writer.append(crlf).flush();
+            writer.append("--" + boundary + "--").append(crlf).flush();
+        }
+    }
+
     private static class TestHandler
             extends AbstractRestHandler<RestfulGateway, TestRequest, 
TestResponse, TestParameters> {
 

Reply via email to