This is an automated email from the ASF dual-hosted git repository.

yangjie01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4fc2910f92d1 [SPARK-48238][BUILD][YARN] Replace YARN AmIpFilter with a forked implementation
4fc2910f92d1 is described below

commit 4fc2910f92d1b5f7e0dd5f803e822668f23c21c5
Author: Cheng Pan <cheng...@apache.org>
AuthorDate: Mon May 20 20:42:57 2024 +0800

    [SPARK-48238][BUILD][YARN] Replace YARN AmIpFilter with a forked implementation
    
    ### What changes were proposed in this pull request?
    
    This PR replaces AmIpFilter with a forked implementation and removes the `hadoop-yarn-server-web-proxy` dependency.
    
    ### Why are the changes needed?
    
    SPARK-47118 upgraded Spark's built-in Jetty from 10 to 11 and migrated from `javax.servlet` to `jakarta.servlet`, which breaks Spark on YARN.
    
    ```
    Caused by: java.lang.IllegalStateException: class org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter is not a jakarta.servlet.Filter
        at org.sparkproject.jetty.servlet.FilterHolder.doStart(FilterHolder.java:99)
        at org.sparkproject.jetty.util.component.AbstractLifeCycle.start(AbstractLifeCycle.java:93)
        at org.sparkproject.jetty.servlet.ServletHandler.lambda$initialize$2(ServletHandler.java:724)
        at java.base/java.util.ArrayList$ArrayListSpliterator.forEachRemaining(ArrayList.java:1625)
        at java.base/java.util.stream.Streams$ConcatSpliterator.forEachRemaining(Streams.java:734)
        at java.base/java.util.stream.ReferencePipeline$Head.forEach(ReferencePipeline.java:762)
        at org.sparkproject.jetty.servlet.ServletHandler.initialize(ServletHandler.java:749)
        ... 38 more
    ```
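
    For context, a minimal sketch of the check that fails (assuming `jakarta.servlet-api` and `hadoop-yarn-server-web-proxy` are on the classpath; the wrapper class here is illustrative, not Jetty internals):

    ```
    public class JakartaFilterCheckSketch {
      public static void main(String[] args) throws Exception {
        Class<?> hadoopFilter =
            Class.forName("org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter");
        // Jetty 11's FilterHolder effectively performs this check at startup:
        // Hadoop's class still implements javax.servlet.Filter, so it fails.
        if (!jakarta.servlet.Filter.class.isAssignableFrom(hadoopFilter)) {
          throw new IllegalStateException(hadoopFilter + " is not a jakarta.servlet.Filter");
        }
      }
    }
    ```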
    
    During the investigation, I found a comment here: https://github.com/apache/spark/pull/31642#issuecomment-786257114
    
    > Agree that in the long term we should either: 1) consider to re-implement the logic in Spark which allows us to get away from server-side dependency in Hadoop ...
    
    Forking the filter is a simple and clean way to address this exact issue: we no longer need to wait for the Hadoop `jakarta.servlet` migration, and it also strips a Hadoop dependency.
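
    As a rough illustration of how the forked filter gets its configuration (the real wiring is in `ApplicationMaster.addAmIpFilter` in the diff below, and Spark derives the values via `Client.getAmIpFilterParams`; the hosts and URI bases here are made up):

    ```
    import java.util.HashMap;
    import java.util.Map;

    public class AmIpFilterParamsSketch {
      public static void main(String[] args) {
        Map<String, String> params = new HashMap<>();
        // Comma-separated RM proxy hosts (AmIpFilter.PROXY_HOSTS).
        params.put("PROXY_HOSTS", "rm1.example.com,rm2.example.com");
        // One proxy URI base per host (AmIpFilter.PROXY_URI_BASES).
        params.put("PROXY_URI_BASES",
            "http://rm1.example.com:8088/proxy/app_1,http://rm2.example.com:8088/proxy/app_1");
        // The servlet container passes these to AmIpFilter.init(FilterConfig).
        params.forEach((k, v) -> System.out.println(k + " = " + v));
      }
    }
    ```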
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. This restores the bootstrap of Spark applications in YARN mode, keeping the same behavior as Spark 3.5 and earlier versions.
    
    ### How was this patch tested?
    
    UTs are added (ported from `org.apache.hadoop.yarn.server.webproxy.amfilter.TestAmFilter`).
    
    I tested it in a YARN cluster.
    
    Spark successfully started.
    ```
    root@hadoop-master1:/opt/spark-SPARK-48238# JAVA_HOME=/opt/openjdk-17 bin/spark-sql --conf spark.yarn.appMasterEnv.JAVA_HOME=/opt/openjdk-17 --conf spark.executorEnv.JAVA_HOME=/opt/openjdk-17
    WARNING: Using incubator modules: jdk.incubator.vector
    Setting default log level to "WARN".
    To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
    2024-05-18 04:11:36 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
    2024-05-18 04:11:44 WARN Client: Neither spark.yarn.jars nor spark.yarn.archive} is set, falling back to uploading libraries under SPARK_HOME.
    Spark Web UI available at http://hadoop-master1.orb.local:4040
    Spark master: yarn, Application Id: application_1716005503866_0001
    spark-sql (default)> select version();
    4.0.0 4ddc2303c7cbabee12a3de9f674aaacad3f5eb01
    Time taken: 1.707 seconds, Fetched 1 row(s)
    spark-sql (default)>
    ```
    
    When accessing `http://hadoop-master1.orb.local:4040`, it redirects to `http://hadoop-master1.orb.local:8088/proxy/redirect/application_1716005503866_0001/`, and the UI looks correct.
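
    The `/redirect` segment comes from the forked filter's `doFilter`: it inserts `/redirect` right after the `/proxy` path component so the RM web proxy knows the request was already redirected. A standalone sketch of that rewrite, using made-up values:

    ```
    public class RedirectPathSketch {
      public static void main(String[] args) {
        // Mirrors the StringBuilder logic in AmIpFilter.doFilter.
        StringBuilder redirect =
            new StringBuilder("http://rm.example.com:8088/proxy/application_00_0");
        int insertPoint = redirect.indexOf("/proxy");
        if (insertPoint >= 0) {
          redirect.insert(insertPoint + "/proxy".length(), "/redirect");
        }
        // Prints http://rm.example.com:8088/proxy/redirect/application_00_0
        System.out.println(redirect);
      }
    }
    ```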
    
    <img width="1474" alt="image" src="https://github.com/apache/spark/assets/26535726/8500fc83-48c5-4603-8d05-37855f0308ae">
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #46611 from pan3793/SPARK-48238.
    
    Authored-by: Cheng Pan <cheng...@apache.org>
    Signed-off-by: yangjie01 <yangji...@baidu.com>
---
 assembly/pom.xml                                   |   4 -
 dev/deps/spark-deps-hadoop-3-hive-2.3              |   1 -
 pom.xml                                            |  77 -----
 .../org/apache/spark/deploy/yarn/AmIpFilter.java   | 239 ++++++++++++++
 .../apache/spark/deploy/yarn/AmIpPrincipal.java    |  35 +++
 .../deploy/yarn/AmIpServletRequestWrapper.java     |  54 ++++
 .../org/apache/spark/deploy/yarn/ProxyUtils.java   | 126 ++++++++
 .../spark/deploy/yarn/ApplicationMaster.scala      |   2 +-
 .../apache/spark/deploy/yarn/AmIpFilterSuite.scala | 342 +++++++++++++++++++++
 .../org/apache/spark/streaming/Checkpoint.scala    |   2 +-
 10 files changed, 798 insertions(+), 84 deletions(-)

diff --git a/assembly/pom.xml b/assembly/pom.xml
index 6c31ec745b5b..58e7ae5bb0c7 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -136,10 +136,6 @@
           <artifactId>spark-yarn_${scala.binary.version}</artifactId>
           <version>${project.version}</version>
         </dependency>
-        <dependency>
-          <groupId>org.apache.hadoop</groupId>
-          <artifactId>hadoop-yarn-server-web-proxy</artifactId>
-        </dependency>
       </dependencies>
     </profile>
     <profile>
diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3
index 4b6f5dda585b..13e74a1627fb 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -80,7 +80,6 @@ hadoop-client-runtime/3.4.0//hadoop-client-runtime-3.4.0.jar
 hadoop-cloud-storage/3.4.0//hadoop-cloud-storage-3.4.0.jar
 hadoop-huaweicloud/3.4.0//hadoop-huaweicloud-3.4.0.jar
 hadoop-shaded-guava/1.2.0//hadoop-shaded-guava-1.2.0.jar
-hadoop-yarn-server-web-proxy/3.4.0//hadoop-yarn-server-web-proxy-3.4.0.jar
 hive-beeline/2.3.10//hive-beeline-2.3.10.jar
 hive-cli/2.3.10//hive-cli-2.3.10.jar
 hive-common/2.3.10//hive-common-2.3.10.jar
diff --git a/pom.xml b/pom.xml
index d92d210a5ffc..1d11d3840e25 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1769,83 +1769,6 @@
         <version>${yarn.version}</version>
         <scope>test</scope>
       </dependency>
-      <dependency>
-        <groupId>org.apache.hadoop</groupId>
-        <artifactId>hadoop-yarn-server-web-proxy</artifactId>
-        <version>${yarn.version}</version>
-        <scope>${hadoop.deps.scope}</scope>
-        <exclusions>
-          <exclusion>
-            <groupId>org.apache.hadoop</groupId>
-            <artifactId>hadoop-yarn-server-common</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>org.apache.hadoop</groupId>
-            <artifactId>hadoop-yarn-common</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>org.apache.hadoop</groupId>
-            <artifactId>hadoop-yarn-api</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>org.bouncycastle</groupId>
-            <artifactId>bcprov-jdk15on</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>org.bouncycastle</groupId>
-            <artifactId>bcpkix-jdk15on</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>org.fusesource.leveldbjni</groupId>
-            <artifactId>leveldbjni-all</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>asm</groupId>
-            <artifactId>asm</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>org.ow2.asm</groupId>
-            <artifactId>asm</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>org.jboss.netty</groupId>
-            <artifactId>netty</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>javax.servlet</groupId>
-            <artifactId>servlet-api</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>javax.servlet</groupId>
-            <artifactId>javax.servlet-api</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>commons-logging</groupId>
-            <artifactId>commons-logging</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>com.sun.jersey</groupId>
-            <artifactId>*</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>com.sun.jersey.jersey-test-framework</groupId>
-            <artifactId>*</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>com.sun.jersey.contribs</groupId>
-            <artifactId>*</artifactId>
-          </exclusion>
-          <!-- Hadoop-3.x -->
-          <exclusion>
-            <groupId>com.zaxxer</groupId>
-            <artifactId>HikariCP-java7</artifactId>
-          </exclusion>
-          <exclusion>
-            <groupId>com.microsoft.sqlserver</groupId>
-            <artifactId>mssql-jdbc</artifactId>
-          </exclusion>
-        </exclusions>
-      </dependency>
       <dependency>
         <groupId>org.apache.hadoop</groupId>
         <artifactId>hadoop-yarn-client</artifactId>
diff --git a/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpFilter.java b/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpFilter.java
new file mode 100644
index 000000000000..60e880d1ac4a
--- /dev/null
+++ b/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpFilter.java
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn;
+
+import org.apache.hadoop.classification.InterfaceAudience.Public;
+import org.apache.hadoop.classification.VisibleForTesting;
+import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.util.Time;
+
+import jakarta.servlet.*;
+import jakarta.servlet.http.Cookie;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import java.io.IOException;
+import java.net.*;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.spark.internal.SparkLogger;
+import org.apache.spark.internal.SparkLoggerFactory;
+
+// This class is copied from Hadoop 3.4.0
+// org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter
+//
+// Modification:
+// Migrate from javax.servlet to jakarta.servlet
+// Copy constant string definitions to strip external dependency
+//  - RM_HA_URLS
+//  - PROXY_USER_COOKIE_NAME
+@Public
+public class AmIpFilter implements Filter {
+  private static final SparkLogger LOG = SparkLoggerFactory.getLogger(AmIpFilter.class);
+
+  @Deprecated
+  public static final String PROXY_HOST = "PROXY_HOST";
+  @Deprecated
+  public static final String PROXY_URI_BASE = "PROXY_URI_BASE";
+  public static final String PROXY_HOSTS = "PROXY_HOSTS";
+  public static final String PROXY_HOSTS_DELIMITER = ",";
+  public static final String PROXY_URI_BASES = "PROXY_URI_BASES";
+  public static final String PROXY_URI_BASES_DELIMITER = ",";
+  private static final String PROXY_PATH = "/proxy";
+  // RM_HA_URLS is defined in AmFilterInitializer in the original Hadoop code
+  private static final String RM_HA_URLS = "RM_HA_URLS";
+  // PROXY_USER_COOKIE_NAME is defined in WebAppProxyServlet in the original Hadoop code
+  public static final String PROXY_USER_COOKIE_NAME = "proxy-user";
+  // update the proxy IP list about every 5 min
+  private static long updateInterval = TimeUnit.MINUTES.toMillis(5);
+
+  private String[] proxyHosts;
+  private Set<String> proxyAddresses = null;
+  private long lastUpdate;
+  @VisibleForTesting
+  Map<String, String> proxyUriBases;
+  String[] rmUrls = null;
+
+  @Override
+  public void init(FilterConfig conf) throws ServletException {
+    // Maintain for backwards compatibility
+    if (conf.getInitParameter(PROXY_HOST) != null
+        && conf.getInitParameter(PROXY_URI_BASE) != null) {
+      proxyHosts = new String[]{conf.getInitParameter(PROXY_HOST)};
+      proxyUriBases = new HashMap<>(1);
+      proxyUriBases.put("dummy", conf.getInitParameter(PROXY_URI_BASE));
+    } else {
+      proxyHosts = conf.getInitParameter(PROXY_HOSTS)
+        .split(PROXY_HOSTS_DELIMITER);
+
+      String[] proxyUriBasesArr = conf.getInitParameter(PROXY_URI_BASES)
+        .split(PROXY_URI_BASES_DELIMITER);
+      proxyUriBases = new HashMap<>(proxyUriBasesArr.length);
+      for (String proxyUriBase : proxyUriBasesArr) {
+        try {
+          URL url = new URL(proxyUriBase);
+          proxyUriBases.put(url.getHost() + ":" + url.getPort(), proxyUriBase);
+        } catch(MalformedURLException e) {
+          LOG.warn(proxyUriBase + " does not appear to be a valid URL", e);
+        }
+      }
+    }
+
+    if (conf.getInitParameter(RM_HA_URLS) != null) {
+      rmUrls = conf.getInitParameter(RM_HA_URLS).split(",");
+    }
+  }
+
+  protected Set<String> getProxyAddresses() throws ServletException {
+    long now = Time.monotonicNow();
+    synchronized(this) {
+      if (proxyAddresses == null || (lastUpdate + updateInterval) <= now) {
+        proxyAddresses = new HashSet<>();
+        for (String proxyHost : proxyHosts) {
+          try {
+            for (InetAddress add : InetAddress.getAllByName(proxyHost)) {
+              LOG.debug("proxy address is: {}", add.getHostAddress());
+              proxyAddresses.add(add.getHostAddress());
+            }
+            lastUpdate = now;
+          } catch (UnknownHostException e) {
+            LOG.warn("Could not locate " + proxyHost + " - skipping", e);
+          }
+        }
+        if (proxyAddresses.isEmpty()) {
+          throw new ServletException("Could not locate any of the proxy hosts");
+        }
+      }
+      return proxyAddresses;
+    }
+  }
+
+  @Override
+  public void destroy() {
+    // Empty
+  }
+
+  @Override
+  public void doFilter(ServletRequest req, ServletResponse resp,
+      FilterChain chain) throws IOException, ServletException {
+    ProxyUtils.rejectNonHttpRequests(req);
+
+    HttpServletRequest httpReq = (HttpServletRequest)req;
+    HttpServletResponse httpResp = (HttpServletResponse)resp;
+
+    LOG.debug("Remote address for request is: {}", httpReq.getRemoteAddr());
+
+    if (!getProxyAddresses().contains(httpReq.getRemoteAddr())) {
+      StringBuilder redirect = new StringBuilder(findRedirectUrl());
+
+      redirect.append(httpReq.getRequestURI());
+
+      int insertPoint = redirect.indexOf(PROXY_PATH);
+
+      if (insertPoint >= 0) {
+        // Add /redirect as the second component of the path so that the RM web
+        // proxy knows that this request was a redirect.
+        insertPoint += PROXY_PATH.length();
+        redirect.insert(insertPoint, "/redirect");
+      }
+      // add the query parameters on the redirect if there were any
+      String queryString = httpReq.getQueryString();
+      if (queryString != null && !queryString.isEmpty()) {
+        redirect.append("?");
+        redirect.append(queryString);
+      }
+
+      ProxyUtils.sendRedirect(httpReq, httpResp, redirect.toString());
+    } else {
+      String user = null;
+
+      if (httpReq.getCookies() != null) {
+        for (Cookie c: httpReq.getCookies()) {
+          if (PROXY_USER_COOKIE_NAME.equals(c.getName())){
+            user = c.getValue();
+            break;
+          }
+        }
+      }
+      if (user == null) {
+        LOG.debug("Could not find {} cookie, so user will not be set",
+            PROXY_USER_COOKIE_NAME);
+
+        chain.doFilter(req, resp);
+      } else {
+        AmIpPrincipal principal = new AmIpPrincipal(user);
+        ServletRequest requestWrapper = new AmIpServletRequestWrapper(httpReq,
+            principal);
+
+        chain.doFilter(requestWrapper, resp);
+      }
+    }
+  }
+
+  @VisibleForTesting
+  public String findRedirectUrl() throws ServletException {
+    String addr = null;
+    if (proxyUriBases.size() == 1) {
+      // external proxy or not RM HA
+      addr = proxyUriBases.values().iterator().next();
+    } else if (rmUrls != null) {
+      for (String url : rmUrls) {
+        String host = proxyUriBases.get(url);
+        if (isValidUrl(host)) {
+          addr = host;
+          break;
+        }
+      }
+    }
+
+    if (addr == null) {
+      throw new ServletException(
+          "Could not determine the proxy server for redirection");
+    }
+    return addr;
+  }
+
+  @VisibleForTesting
+  public boolean isValidUrl(String url) {
+    boolean isValid = false;
+    try {
+      HttpURLConnection conn = (HttpURLConnection) new URL(url).openConnection();
+      conn.connect();
+      isValid = conn.getResponseCode() == HttpURLConnection.HTTP_OK;
+      // If security is enabled, any valid RM which can give 401 Unauthorized is
+      // good enough to access. Since AM doesn't have enough credential, auth
+      // cannot be completed and hence 401 is fine in such case.
+      if (!isValid && UserGroupInformation.isSecurityEnabled()) {
+        isValid = (conn.getResponseCode() == HttpURLConnection.HTTP_UNAUTHORIZED)
+            || (conn.getResponseCode() == HttpURLConnection.HTTP_FORBIDDEN);
+        return isValid;
+      }
+    } catch (Exception e) {
+      LOG.warn("Failed to connect to " + url + ": " + e.toString());
+    }
+    return isValid;
+  }
+
+  @VisibleForTesting
+  protected static void setUpdateInterval(long updateInterval) {
+    AmIpFilter.updateInterval = updateInterval;
+  }
+}
diff --git a/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpPrincipal.java b/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpPrincipal.java
new file mode 100644
index 000000000000..9d5a5e3b0456
--- /dev/null
+++ b/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpPrincipal.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn;
+
+import java.security.Principal;
+
+// This class is copied from Hadoop 3.4.0
+// org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpPrincipal
+public class AmIpPrincipal implements Principal {
+  private final String name;
+
+  public AmIpPrincipal(String name) {
+    this.name = name;
+  }
+
+  @Override
+  public String getName() {
+    return name;
+  }
+}
diff --git a/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpServletRequestWrapper.java b/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpServletRequestWrapper.java
new file mode 100644
index 000000000000..9082378fe89c
--- /dev/null
+++ b/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpServletRequestWrapper.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn;
+
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletRequestWrapper;
+import java.security.Principal;
+
+// This class is copied from Hadoop 3.4.0
+// org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpServletRequestWrapper
+//
+// Modification:
+// Migrate from javax.servlet to jakarta.servlet
+public class AmIpServletRequestWrapper extends HttpServletRequestWrapper {
+  private final AmIpPrincipal principal;
+
+  public AmIpServletRequestWrapper(HttpServletRequest request,
+      AmIpPrincipal principal) {
+    super(request);
+    this.principal = principal;
+  }
+
+  @Override
+  public Principal getUserPrincipal() {
+    return principal;
+  }
+
+  @Override
+  public String getRemoteUser() {
+    return principal.getName();
+  }
+
+  @Override
+  public boolean isUserInRole(String role) {
+    // No role info so far
+    return false;
+  }
+
+}
diff --git a/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/ProxyUtils.java b/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/ProxyUtils.java
new file mode 100644
index 000000000000..c7a49a76c655
--- /dev/null
+++ b/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/ProxyUtils.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn;
+
+import org.apache.hadoop.yarn.webapp.MimeType;
+import org.apache.hadoop.yarn.webapp.hamlet2.Hamlet;
+
+import jakarta.servlet.ServletException;
+import jakarta.servlet.ServletRequest;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.EnumSet;
+
+import org.apache.spark.internal.SparkLogger;
+import org.apache.spark.internal.SparkLoggerFactory;
+
+// Class containing general purpose proxy utilities
+//
+// This class is copied from Hadoop 3.4.0
+// org.apache.hadoop.yarn.server.webproxy.ProxyUtils
+//
+// Modification:
+// Migrate from javax.servlet to jakarta.servlet
+public class ProxyUtils {
+  private static final SparkLogger LOG = SparkLoggerFactory.getLogger(ProxyUtils.class);
+  public static final String E_HTTP_HTTPS_ONLY =
+      "This filter only works for HTTP/HTTPS";
+  public static final String LOCATION = "Location";
+
+  public static class __ implements Hamlet.__ {
+    // Empty
+  }
+
+  public static class Page extends Hamlet {
+    Page(PrintWriter out) {
+      super(out, 0, false);
+    }
+
+    public HTML<ProxyUtils.__> html() {
+      return new HTML<>("html", null, EnumSet.of(EOpt.ENDTAG));
+    }
+  }
+
+  /**
+   * Handle redirects with a status code that can in future support verbs other
+   * than GET, thus supporting full REST functionality.
+   * <p>
+   * The target URL is included in the redirect text returned
+   * <p>
+   * At the end of this method, the output stream is closed.
+   *
+   * @param request request (hence: the verb and any other information
+   * relevant to a redirect)
+   * @param response the response
+   * @param target the target URL -unencoded
+   *
+   */
+  public static void sendRedirect(HttpServletRequest request,
+      HttpServletResponse response,
+      String target)
+      throws IOException {
+    LOG.debug("Redirecting {} {} to {}",
+          request.getMethod(),
+          request.getRequestURI(),
+          target);
+    String location = response.encodeRedirectURL(target);
+    response.setStatus(HttpServletResponse.SC_FOUND);
+    response.setHeader(LOCATION, location);
+    response.setContentType(MimeType.HTML);
+    PrintWriter writer = response.getWriter();
+    Page p = new Page(writer);
+    p.html()
+      .head().title("Moved").__()
+      .body()
+      .h1("Moved")
+      .div()
+        .__("Content has moved ")
+        .a(location, "here").__()
+      .__().__();
+    writer.close();
+  }
+
+
+  /**
+   * Output 404 with appropriate message.
+   * @param resp the http response.
+   * @param message the message to include on the page.
+   * @throws IOException on any error.
+   */
+  public static void notFound(HttpServletResponse resp, String message)
+      throws IOException {
+    resp.setStatus(HttpServletResponse.SC_NOT_FOUND);
+    resp.setContentType(MimeType.HTML);
+    Page p = new Page(resp.getWriter());
+    p.html().h1(message).__();
+  }
+
+  /**
+   * Reject any request that isn't from an HTTP servlet
+   * @param req request
+   * @throws ServletException if the request is of the wrong type
+   */
+  public static void rejectNonHttpRequests(ServletRequest req) throws
+      ServletException {
+    if (!(req instanceof HttpServletRequest)) {
+      throw new ServletException(E_HTTP_HTTPS_ONLY);
+    }
+  }
+}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 8f20f6602ec5..4b5f9be3193f 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -696,7 +696,7 @@ private[spark] class ApplicationMaster(
 
   /** Add the Yarn IP filter that is required for properly securing the UI. */
   private def addAmIpFilter(driver: Option[RpcEndpointRef], proxyBase: String) = {
-    val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter"
+    val amFilter = classOf[AmIpFilter].getName
     val params = client.getAmIpFilterParams(yarnConf, proxyBase)
     driver match {
       case Some(d) =>
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/AmIpFilterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/AmIpFilterSuite.scala
new file mode 100644
index 000000000000..e25bd665dec0
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/AmIpFilterSuite.scala
@@ -0,0 +1,342 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.{IOException, PrintWriter, StringWriter}
+import java.net.HttpURLConnection
+import java.util
+import java.util.{Collections, Locale}
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.jdk.CollectionConverters._
+
+import jakarta.servlet.{FilterChain, FilterConfig, ServletContext, ServletException, ServletOutputStream, ServletRequest, ServletResponse}
+import jakarta.servlet.http.{Cookie, HttpServlet, HttpServletRequest, HttpServletResponse}
+import jakarta.ws.rs.core.MediaType
+import org.eclipse.jetty.server.{Server, ServerConnector}
+import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
+import org.eclipse.jetty.util.thread.QueuedThreadPool
+import org.mockito.Mockito.{mock, when}
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.SparkFunSuite
+
+// A port of org.apache.hadoop.yarn.server.webproxy.amfilter.TestAmFilter
+class AmIpFilterSuite extends SparkFunSuite {
+
+  private val proxyHost = "localhost"
+  private val proxyUri = "http://bogus"
+
+  class TestAmIpFilter extends AmIpFilter {
+    override def getProxyAddresses: util.Set[String] = Set(proxyHost).asJava
+  }
+
+  class DummyFilterConfig (val map: util.Map[String, String]) extends FilterConfig {
+    override def getFilterName: String = "dummy"
+
+    override def getInitParameter(arg0: String): String = map.get(arg0)
+
+    override def getInitParameterNames: util.Enumeration[String] =
+      Collections.enumeration(map.keySet)
+
+    override def getServletContext: ServletContext = null
+  }
+
+  test("filterNullCookies") {
+    val request = mock(classOf[HttpServletRequest])
+
+    when(request.getCookies).thenReturn(null)
+    when(request.getRemoteAddr).thenReturn(proxyHost)
+
+    val response = mock(classOf[HttpServletResponse])
+    val invoked = new AtomicBoolean
+
+    val chain = new FilterChain() {
+      @throws[IOException]
+      @throws[ServletException]
+      override def doFilter(req: ServletRequest, resp: ServletResponse): Unit = {
+        invoked.set(true)
+      }
+    }
+
+    val params = new util.HashMap[String, String]
+    params.put(AmIpFilter.PROXY_HOST, proxyHost)
+    params.put(AmIpFilter.PROXY_URI_BASE, proxyUri)
+    val conf = new DummyFilterConfig(params)
+    val filter = new TestAmIpFilter
+    filter.init(conf)
+    filter.doFilter(request, response, chain)
+    assert(invoked.get)
+    filter.destroy()
+  }
+
+  test("testFindRedirectUrl") {
+    class EchoServlet extends HttpServlet {
+      @throws[IOException]
+      @throws[ServletException]
+      override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = {
+        response.setContentType(MediaType.TEXT_PLAIN + "; charset=utf-8")
+        val out = response.getWriter
+        request.getParameterNames.asScala.toSeq.sorted.foreach { key =>
+          out.print(key)
+          out.print(':')
+          out.print(request.getParameter(key))
+          out.print('\n')
+        }
+        out.close()
+      }
+    }
+
+    def withHttpEchoServer(body: String => Unit): Unit = {
+      val server = new Server(0)
+      server.getThreadPool.asInstanceOf[QueuedThreadPool].setMaxThreads(20)
+      val context = new ServletContextHandler
+      context.setContextPath("/foo")
+      server.setHandler(context)
+      val servletPath = "/bar"
+      context.addServlet(new ServletHolder(new EchoServlet), servletPath)
+      server.getConnectors.head.asInstanceOf[ServerConnector].setHost("localhost")
+      try {
+        server.start()
+        body(server.getURI.toString + servletPath)
+      } finally {
+        server.stop()
+      }
+    }
+
+    // generate a valid URL
+    withHttpEchoServer { rm1Url =>
+      val rm1 = "rm1"
+      val rm2 = "rm2"
+      // invalid url
+      val rm2Url = "host2:8088"
+
+      val filter = new TestAmIpFilter
+      // make sure findRedirectUrl() go to HA branch
+      filter.proxyUriBases = Map(rm1 -> rm1Url, rm2 -> rm2Url).asJava
+      filter.rmUrls = Array[String](rm1, rm2)
+
+      assert(filter.findRedirectUrl === rm1Url)
+    }
+  }
+
+  test("testProxyUpdate") {
+    var params = new util.HashMap[String, String]
+    params.put(AmIpFilter.PROXY_HOSTS, proxyHost)
+    params.put(AmIpFilter.PROXY_URI_BASES, proxyUri)
+
+    var conf = new DummyFilterConfig(params)
+    val filter = new AmIpFilter
+    val updateInterval = TimeUnit.SECONDS.toMillis(1)
+    AmIpFilter.setUpdateInterval(updateInterval)
+    filter.init(conf)
+
+    // check that the configuration was applied
+    assert(filter.getProxyAddresses.contains("127.0.0.1"))
+
+    // change proxy configurations
+    params = new util.HashMap[String, String]
+    params.put(AmIpFilter.PROXY_HOSTS, "unknownhost")
+    params.put(AmIpFilter.PROXY_URI_BASES, proxyUri)
+    conf = new DummyFilterConfig(params)
+    filter.init(conf)
+
+    // configurations shouldn't be updated now
+    assert(!filter.getProxyAddresses.isEmpty)
+    // waiting for configuration update
+    eventually(timeout(5.seconds), interval(500.millis)) {
+      assertThrows[ServletException] {
+        filter.getProxyAddresses.isEmpty
+      }
+    }
+  }
+
+  test("testFilter") {
+    var doFilterRequest: String = null
+    var servletWrapper: AmIpServletRequestWrapper = null
+
+    val params = new util.HashMap[String, String]
+    params.put(AmIpFilter.PROXY_HOST, proxyHost)
+    params.put(AmIpFilter.PROXY_URI_BASE, proxyUri)
+    val config = new DummyFilterConfig(params)
+
+    // dummy filter
+    val chain = new FilterChain() {
+      @throws[IOException]
+      @throws[ServletException]
+      override def doFilter(req: ServletRequest, resp: ServletResponse): Unit = {
+        doFilterRequest = req.getClass.getName
+        req match {
+          case wrapper: AmIpServletRequestWrapper => servletWrapper = wrapper
+          case _ =>
+        }
+      }
+    }
+    val testFilter = new AmIpFilter
+    testFilter.init(config)
+
+    val response = new HttpServletResponseForTest
+
+    // Test request should implement HttpServletRequest
+    val failRequest = mock(classOf[ServletRequest])
+    val throws = intercept[ServletException] {
+      testFilter.doFilter(failRequest, response, chain)
+    }
+    assert(ProxyUtils.E_HTTP_HTTPS_ONLY === throws.getMessage)
+
+
+    // request with HttpServletRequest
+    val request = mock(classOf[HttpServletRequest])
+    when(request.getRemoteAddr).thenReturn("nowhere")
+    when(request.getRequestURI).thenReturn("/app/application_00_0")
+
+    // address "redirect" is not in host list for non-proxy connection
+    testFilter.doFilter(request, response, chain)
+    assert(HttpURLConnection.HTTP_MOVED_TEMP === response.status)
+    var redirect = response.getHeader(ProxyUtils.LOCATION)
+    assert("http://bogus/app/application_00_0"; === redirect)
+
+    // address "redirect" is not in host list for proxy connection
+    when(request.getRequestURI).thenReturn("/proxy/application_00_0")
+    testFilter.doFilter(request, response, chain)
+    assert(HttpURLConnection.HTTP_MOVED_TEMP === response.status)
+    redirect = response.getHeader(ProxyUtils.LOCATION)
+    assert("http://bogus/proxy/redirect/application_00_0"; === redirect)
+
+    // check for query parameters
+    when(request.getRequestURI).thenReturn("/proxy/application_00_0")
+    when(request.getQueryString).thenReturn("id=0")
+    testFilter.doFilter(request, response, chain)
+    assert(HttpURLConnection.HTTP_MOVED_TEMP === response.status)
+    redirect = response.getHeader(ProxyUtils.LOCATION)
+    assert("http://bogus/proxy/redirect/application_00_0?id=0"; === redirect)
+
+    // "127.0.0.1" contains in host list. Without cookie
+    when(request.getRemoteAddr).thenReturn("127.0.0.1")
+    testFilter.doFilter(request, response, chain)
+    assert(doFilterRequest.contains("HttpServletRequest"))
+
+    // cookie added
+    val cookies = Array[Cookie](new Cookie(AmIpFilter.PROXY_USER_COOKIE_NAME, "user"))
+
+    when(request.getCookies).thenReturn(cookies)
+    testFilter.doFilter(request, response, chain)
+
+    assert(doFilterRequest === classOf[AmIpServletRequestWrapper].getName)
+    // request contains principal from cookie
+    assert(servletWrapper.getUserPrincipal.getName === "user")
+    assert(servletWrapper.getRemoteUser === "user")
+    assert(!servletWrapper.isUserInRole(""))
+  }
+
+  private class HttpServletResponseForTest extends HttpServletResponse {
+    private var redirectLocation = ""
+    var status = 0
+    private var contentType: String = _
+    final private val headers = new util.HashMap[String, String](1)
+    private var body: StringWriter = _
+
+    def getRedirect: String = redirectLocation
+
+    @throws[IOException]
+    override def sendRedirect(location: String): Unit = redirectLocation = location
+
+    override def setDateHeader(name: String, date: Long): Unit = {}
+
+    override def addDateHeader(name: String, date: Long): Unit = {}
+
+    override def addCookie(cookie: Cookie): Unit = {}
+
+    override def containsHeader(name: String): Boolean = false
+
+    override def encodeURL(url: String): String = null
+
+    override def encodeRedirectURL(url: String): String = url
+
+    override def encodeUrl(url: String): String = null
+
+    override def encodeRedirectUrl(url: String): String = null
+
+    @throws[IOException]
+    override def sendError(sc: Int, msg: String): Unit = {}
+
+    @throws[IOException]
+    override def sendError(sc: Int): Unit = {}
+
+    override def setStatus(status: Int): Unit = this.status = status
+
+    override def setStatus(sc: Int, sm: String): Unit = {}
+
+    override def getStatus: Int = 0
+
+    override def setContentType(contentType: String): Unit = this.contentType = contentType
+
+    override def setBufferSize(size: Int): Unit = {}
+
+    override def getBufferSize: Int = 0
+
+    @throws[IOException]
+    override def flushBuffer(): Unit = {}
+
+    override def resetBuffer(): Unit = {}
+
+    override def isCommitted: Boolean = false
+
+    override def reset(): Unit = {}
+
+    override def setLocale(loc: Locale): Unit = {}
+
+    override def getLocale: Locale = null
+
+    override def setHeader(name: String, value: String): Unit = headers.put(name, value)
+
+    override def addHeader(name: String, value: String): Unit = {}
+
+    override def setIntHeader(name: String, value: Int): Unit = {}
+
+    override def addIntHeader(name: String, value: Int): Unit = {}
+
+    override def getHeader(name: String): String = headers.get(name)
+
+    override def getHeaders(name: String): util.Collection[String] = null
+
+    override def getHeaderNames: util.Collection[String] = null
+
+    override def getCharacterEncoding: String = null
+
+    override def getContentType: String = null
+
+    @throws[IOException]
+    override def getOutputStream: ServletOutputStream = null
+
+    @throws[IOException]
+    override def getWriter: PrintWriter = {
+      body = new StringWriter
+      new PrintWriter(body)
+    }
+
+    override def setCharacterEncoding(charset: String): Unit = {}
+
+    override def setContentLength(len: Int): Unit = {}
+
+    override def setContentLengthLong(len: Long): Unit = {}
+  }
+
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index bed048c4b5df..6cbc74a75a06 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -86,7 +86,7 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time)
     }
 
     // Add Yarn proxy filter specific configurations to the recovered SparkConf
-    val filter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter"
+    val filter = "org.apache.spark.deploy.yarn.AmIpFilter"
     val filterPrefix = s"spark.$filter.param."
     newReloadConf.getAll.foreach { case (k, v) =>
       if (k.startsWith(filterPrefix) && k.length > filterPrefix.length) {
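
For reference, the prefix convention this hunk relies on: filter init parameters are recovered from SparkConf keys of the form `spark.<filter class>.param.<name>`. A sketch with a made-up key and value (the real entries come from the YARN client):

```
public class FilterParamPrefixSketch {
  public static void main(String[] args) {
    String filter = "org.apache.spark.deploy.yarn.AmIpFilter";
    String filterPrefix = "spark." + filter + ".param.";
    // Hypothetical recovered configuration entry.
    String key = filterPrefix + "PROXY_HOSTS";
    String value = "rm1.example.com";
    if (key.startsWith(filterPrefix) && key.length() > filterPrefix.length()) {
      // Re-registered as an init parameter for the new filter class.
      System.out.println(key.substring(filterPrefix.length()) + " -> " + value);
    }
  }
}
```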


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

