asfgit closed pull request #23302: [SPARK-24522][UI] Create filter to apply HTTP security checks consistently.
URL: https://github.com/apache/spark/pull/23302
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
diff --git
a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
index 00ca4efa4d266..7a8ab7fddd79f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
@@ -27,9 +27,8 @@ import org.apache.spark.ui.{UIUtils, WebUIPage}
private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
def render(request: HttpServletRequest): Seq[Node] = {
- // stripXSS is called first to remove suspicious characters used in XSS attacks
- val requestedIncomplete =
- Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean
+ val requestedIncomplete = Option(request.getParameter("showIncomplete"))
+ .getOrElse("false").toBoolean
val displayApplications = parent.getApplicationList()
.exists(isApplicationCompleted(_) != requestedIncomplete)
diff --git
a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
index b9303388638fd..ff2ea3b843ee3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -150,17 +150,15 @@ class HistoryServer(
ui: SparkUI,
completed: Boolean) {
assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs")
- handlers.synchronized {
- ui.getHandlers.foreach(attachHandler)
+ ui.getHandlers.foreach { handler =>
+ serverInfo.get.addHandler(handler, ui.securityManager)
}
}
/** Detach a reconstructed UI from this server. Only valid after bind(). */
override def detachSparkUI(appId: String, attemptId: Option[String], ui: SparkUI): Unit = {
assert(serverInfo.isDefined, "HistoryServer must be bound before detaching SparkUIs")
- handlers.synchronized {
- ui.getHandlers.foreach(detachHandler)
- }
+ ui.getHandlers.foreach(detachHandler)
provider.onUIDetached(appId, attemptId, ui)
}
diff --git
a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index fad4e46dc035d..bcd7a7e4ccdb5 100644
---
a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++
b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -33,8 +33,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI)
extends WebUIPage("app")
/** Executor details for a particular application */
def render(request: HttpServletRequest): Seq[Node] = {
- // stripXSS is called first to remove suspicious characters used in XSS attacks
- val appId = UIUtils.stripXSS(request.getParameter("appId"))
+ val appId = request.getParameter("appId")
val state = master.askSync[MasterStateResponse](RequestMasterState)
val app = state.activeApps.find(_.id == appId)
.getOrElse(state.completedApps.find(_.id == appId).orNull)
diff --git
a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index b8afe203fbfa2..6701465c023c7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -57,10 +57,8 @@ private[ui] class MasterPage(parent: MasterWebUI) extends
WebUIPage("") {
private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = {
if (parent.killEnabled &&
parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) {
- // stripXSS is called first to remove suspicious characters used in XSS attacks
- val killFlag =
- Option(UIUtils.stripXSS(request.getParameter("terminate"))).getOrElse("false").toBoolean
- val id = Option(UIUtils.stripXSS(request.getParameter("id")))
+ val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
+ val id = Option(request.getParameter("id"))
+ val id = Option(request.getParameter("id"))
if (id.isDefined && killFlag) {
action(id.get)
}
diff --git
a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
index 4fca9342c0378..4e720a759a1bc 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
@@ -33,15 +33,13 @@ private[ui] class LogPage(parent: WorkerWebUI) extends
WebUIPage("logPage") with
private val supportedLogTypes = Set("stderr", "stdout")
private val defaultBytes = 100 * 1024
- // stripXSS is called first to remove suspicious characters used in XSS attacks
def renderLog(request: HttpServletRequest): String = {
- val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
- val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
- val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
- val logType = UIUtils.stripXSS(request.getParameter("logType"))
- val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
- val byteLength =
- Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt)
+ val appId = Option(request.getParameter("appId"))
+ val executorId = Option(request.getParameter("executorId"))
+ val driverId = Option(request.getParameter("driverId"))
+ val logType = request.getParameter("logType")
+ val offset = Option(request.getParameter("offset")).map(_.toLong)
+ val byteLength = Option(request.getParameter("byteLength")).map(_.toInt)
.getOrElse(defaultBytes)
val logDir = (appId, executorId, driverId) match {
@@ -58,15 +56,13 @@ private[ui] class LogPage(parent: WorkerWebUI) extends
WebUIPage("logPage") with
pre + logText
}
- // stripXSS is called first to remove suspicious characters used in XSS attacks
def render(request: HttpServletRequest): Seq[Node] = {
- val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
- val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
- val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
- val logType = UIUtils.stripXSS(request.getParameter("logType"))
- val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
- val byteLength =
- Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt)
+ val appId = Option(request.getParameter("appId"))
+ val executorId = Option(request.getParameter("executorId"))
+ val driverId = Option(request.getParameter("driverId"))
+ val logType = request.getParameter("logType")
+ val offset = Option(request.getParameter("offset")).map(_.toLong)
+ val byteLength = Option(request.getParameter("byteLength")).map(_.toInt)
.getOrElse(defaultBytes)
val (logDir, params, pageName) = (appId, executorId, driverId) match {
diff --git
a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index ea67b7434a769..54886955b98fb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -50,7 +50,6 @@ class WorkerWebUI(
addStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE)
attachHandler(createServletHandler("/log",
(request: HttpServletRequest) => logPage.renderLog(request),
- worker.securityMgr,
worker.conf))
}
}
diff --git
a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
index 68b58b8490641..bea24ca7807e4 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
@@ -51,7 +51,7 @@ private[spark] class MetricsServlet(
def getHandlers(conf: SparkConf): Array[ServletContextHandler] = {
Array[ServletContextHandler](
createServletHandler(servletPath,
- new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr, conf)
+ new ServletParams(request => getMetricsSnapshot(request), "text/json"), conf)
)
}
diff --git
a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala
b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala
deleted file mode 100644
index 1cd37185d6601..0000000000000
--- a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.status.api.v1
-
-import javax.ws.rs.container.{ContainerRequestContext, ContainerRequestFilter}
-import javax.ws.rs.core.Response
-import javax.ws.rs.ext.Provider
-
-@Provider
-private[v1] class SecurityFilter extends ContainerRequestFilter with ApiRequestContext {
- override def filter(req: ContainerRequestContext): Unit = {
- val user = httpRequest.getRemoteUser()
- if (!uiRoot.securityManager.checkUIViewPermissions(user)) {
- req.abortWith(
- Response
- .status(Response.Status.FORBIDDEN)
- .entity(raw"""user "$user" is not authorized""")
- .build()
- )
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala
b/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala
new file mode 100644
index 0000000000000..da84fdf8fe140
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import java.util.{Enumeration, Map => JMap}
+import javax.servlet._
+import javax.servlet.http.{HttpServletRequest, HttpServletRequestWrapper, HttpServletResponse}
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.lang3.StringEscapeUtils
+
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.internal.config._
+
+/**
+ * A servlet filter that implements HTTP security features. The following actions are taken
+ * for every request:
+ *
+ * - perform access control of authenticated requests.
+ * - check request data for disallowed content (e.g. things that could be used to create XSS
+ * attacks).
+ * - set response headers to prevent certain kinds of attacks.
+ *
+ * Request parameters are sanitized so that HTML content is escaped, and disallowed content is
+ * removed.
+ */
+private class HttpSecurityFilter(
+ conf: SparkConf,
+ securityMgr: SecurityManager) extends Filter {
+
+ override def destroy(): Unit = { }
+
+ override def init(config: FilterConfig): Unit = { }
+
+ override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit = {
+ val hreq = req.asInstanceOf[HttpServletRequest]
+ val hres = res.asInstanceOf[HttpServletResponse]
+ hres.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
+
+ if (!securityMgr.checkUIViewPermissions(hreq.getRemoteUser())) {
+ hres.sendError(HttpServletResponse.SC_FORBIDDEN,
+ "User is not authorized to access this page.")
+ return
+ }
+
+ // SPARK-10589 avoid frame-related click-jacking vulnerability, using X-Frame-Options
+ // (see http://tools.ietf.org/html/rfc7034). By default allow framing only from the
+ // same origin, but allow framing for a specific named URI.
+ // Example: spark.ui.allowFramingFrom = https://example.com/
+ val xFrameOptionsValue = conf.getOption("spark.ui.allowFramingFrom")
+ .map { uri => s"ALLOW-FROM $uri" }
+ .getOrElse("SAMEORIGIN")
+
+ hres.setHeader("X-Frame-Options", xFrameOptionsValue)
+ hres.setHeader("X-XSS-Protection", conf.get(UI_X_XSS_PROTECTION))
+ if (conf.get(UI_X_CONTENT_TYPE_OPTIONS)) {
+ hres.setHeader("X-Content-Type-Options", "nosniff")
+ }
+ if (hreq.getScheme() == "https") {
+ conf.get(UI_STRICT_TRANSPORT_SECURITY).foreach(
+ hres.setHeader("Strict-Transport-Security", _))
+ }
+
+ chain.doFilter(new XssSafeRequest(hreq), res)
+ }
+
+}
+
+private class XssSafeRequest(req: HttpServletRequest) extends HttpServletRequestWrapper(req) {
+
+ private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r
+
+ private val parameterMap: Map[String, Array[String]] = {
+ super.getParameterMap().asScala.map { case (name, values) =>
+ stripXSS(name) -> values.map(stripXSS)
+ }.toMap
+ }
+
+ override def getParameterMap(): JMap[String, Array[String]] = parameterMap.asJava
+
+ override def getParameterNames(): Enumeration[String] = {
+ parameterMap.keys.iterator.asJavaEnumeration
+ }
+
+ override def getParameterValues(name: String): Array[String] = parameterMap.get(name).orNull
+
+ override def getParameter(name: String): String = {
+ parameterMap.get(name).flatMap(_.headOption).orNull
+ }
+
+ private def stripXSS(str: String): String = {
+ if (str != null) {
+ // Remove new lines and single quotes, followed by escaping HTML version 4.0
+ StringEscapeUtils.escapeHtml4(NEWLINE_AND_SINGLE_QUOTE_REGEX.replaceAllIn(str, ""))
+ } else {
+ null
+ }
+ }
+
+}
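For reference, the sanitization performed by XssSafeRequest above can be exercised on its own. A minimal, self-contained sketch, reusing the regex and commons-lang3 escaping shown in the new filter (the object name and sample input are illustrative only, not part of the patch):

    import org.apache.commons.lang3.StringEscapeUtils

    object StripXssSketch {
      // Same pattern as XssSafeRequest: drop CR/LF (raw or URL-encoded) and single quotes.
      private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r

      def stripXSS(str: String): String = {
        if (str == null) null
        else StringEscapeUtils.escapeHtml4(NEWLINE_AND_SINGLE_QUOTE_REGEX.replaceAllIn(str, ""))
      }

      def main(args: Array[String]): Unit = {
        // The script tag survives removal but is HTML-escaped, so it cannot execute when echoed.
        println(stripXSS(""">"'><script>alert(401)<%2Fscript>"""))
        // prints: &gt;&quot;&gt;&lt;script&gt;alert(401)&lt;%2Fscript&gt;
      }
    }

Because the wrapper overrides getParameter, getParameterValues, and getParameterMap, servlets behind the filter only ever see the sanitized values.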
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 316af9b79d286..08f5fb937da7e 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ui
import java.net.{URI, URL}
+import java.util.EnumSet
import javax.servlet.DispatcherType
import javax.servlet.http.{HttpServlet, HttpServletRequest,
HttpServletResponse}
@@ -68,43 +69,16 @@ private[spark] object JettyUtils extends Logging {
implicit def textResponderToServlet(responder: Responder[String]):
ServletParams[String] =
new ServletParams(responder, "text/plain")
- def createServlet[T <: AnyRef](
+ private def createServlet[T <: AnyRef](
servletParams: ServletParams[T],
- securityMgr: SecurityManager,
conf: SparkConf): HttpServlet = {
-
- // SPARK-10589 avoid frame-related click-jacking vulnerability, using
X-Frame-Options
- // (see http://tools.ietf.org/html/rfc7034). By default allow framing only
from the
- // same origin, but allow framing for a specific named URI.
- // Example: spark.ui.allowFramingFrom = https://example.com/
- val allowFramingFrom = conf.getOption("spark.ui.allowFramingFrom")
- val xFrameOptionsValue =
- allowFramingFrom.map(uri => s"ALLOW-FROM $uri").getOrElse("SAMEORIGIN")
-
new HttpServlet {
override def doGet(request: HttpServletRequest, response: HttpServletResponse) {
try {
- if (securityMgr.checkUIViewPermissions(request.getRemoteUser)) {
- response.setContentType("%s;charset=utf-8".format(servletParams.contentType))
- response.setStatus(HttpServletResponse.SC_OK)
- val result = servletParams.responder(request)
- response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
- response.setHeader("X-Frame-Options", xFrameOptionsValue)
- response.setHeader("X-XSS-Protection", conf.get(UI_X_XSS_PROTECTION))
- if (conf.get(UI_X_CONTENT_TYPE_OPTIONS)) {
- response.setHeader("X-Content-Type-Options", "nosniff")
- }
- if (request.getScheme == "https") {
- conf.get(UI_STRICT_TRANSPORT_SECURITY).foreach(
- response.setHeader("Strict-Transport-Security", _))
- }
- response.getWriter.print(servletParams.extractFn(result))
- } else {
- response.setStatus(HttpServletResponse.SC_FORBIDDEN)
- response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
- response.sendError(HttpServletResponse.SC_FORBIDDEN,
- "User is not authorized to access this page.")
- }
+ response.setContentType("%s;charset=utf-8".format(servletParams.contentType))
+ response.setStatus(HttpServletResponse.SC_OK)
+ val result = servletParams.responder(request)
+ response.getWriter.print(servletParams.extractFn(result))
} catch {
case e: IllegalArgumentException =>
response.sendError(HttpServletResponse.SC_BAD_REQUEST, e.getMessage)
@@ -124,10 +98,9 @@ private[spark] object JettyUtils extends Logging {
def createServletHandler[T <: AnyRef](
path: String,
servletParams: ServletParams[T],
- securityMgr: SecurityManager,
conf: SparkConf,
basePath: String = ""): ServletContextHandler = {
- createServletHandler(path, createServlet(servletParams, securityMgr, conf), basePath)
+ createServletHandler(path, createServlet(servletParams, conf), basePath)
}
/** Create a context handler that responds to a request with the given path prefix */
@@ -257,36 +230,6 @@ private[spark] object JettyUtils extends Logging {
contextHandler
}
- /** Add filters, if any, to the given list of ServletContextHandlers */
- def addFilters(handlers: Seq[ServletContextHandler], conf: SparkConf) {
- val filters: Array[String] = conf.get("spark.ui.filters", "").split(',').map(_.trim())
- filters.foreach {
- case filter : String =>
- if (!filter.isEmpty) {
- logInfo(s"Adding filter $filter to ${handlers.map(_.getContextPath).mkString(", ")}.")
- val holder : FilterHolder = new FilterHolder()
- holder.setClassName(filter)
- // Get any parameters for each filter
- conf.get("spark." + filter + ".params", "").split(',').map(_.trim()).toSet.foreach {
- param: String =>
- if (!param.isEmpty) {
- val parts = param.split("=")
- if (parts.length == 2) holder.setInitParameter(parts(0), parts(1))
- }
- }
-
- val prefix = s"spark.$filter.param."
- conf.getAll
- .filter { case (k, v) => k.length() > prefix.length() && k.startsWith(prefix) }
- .foreach { case (k, v) => holder.setInitParameter(k.substring(prefix.length()), v) }
-
- val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR,
- DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST)
- handlers.foreach { case(handler) => handler.addFilter(holder, "/*", enumDispatcher) }
- }
- }
- }
-
/**
* Attempt to start a Jetty server bound to the supplied hostName:port using the given
* context handlers.
@@ -298,12 +241,9 @@ private[spark] object JettyUtils extends Logging {
hostName: String,
port: Int,
sslOptions: SSLOptions,
- handlers: Seq[ServletContextHandler],
conf: SparkConf,
serverName: String = ""): ServerInfo = {
- addFilters(handlers, conf)
-
// Start the server first, with no connectors.
val pool = new QueuedThreadPool
if (serverName.nonEmpty) {
@@ -398,16 +338,6 @@ private[spark] object JettyUtils extends Logging {
}
server.addConnector(httpConnector)
-
- // Add all the known handlers now that connectors are configured.
- handlers.foreach { h =>
- h.setVirtualHosts(toVirtualHosts(SPARK_CONNECTOR_NAME))
- val gzipHandler = new GzipHandler()
- gzipHandler.setHandler(h)
- collection.addHandler(gzipHandler)
- gzipHandler.start()
- }
-
pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads))
ServerInfo(server, httpPort, securePort, conf, collection)
} catch {
@@ -489,6 +419,16 @@ private[spark] object JettyUtils extends Logging {
}
}
+ def addFilter(
+ handler: ServletContextHandler,
+ filter: String,
+ params: Map[String, String]): Unit = {
+ val holder = new FilterHolder()
+ holder.setClassName(filter)
+ params.foreach { case (k, v) => holder.setInitParameter(k, v) }
+ handler.addFilter(holder, "/*", EnumSet.allOf(classOf[DispatcherType]))
+ }
+
// Create a new URI from the arguments, handling IPv6 host encoding and default ports.
private def createRedirectURI(
scheme: String, server: String, port: Int, path: String, query: String) = {
@@ -509,20 +449,37 @@ private[spark] case class ServerInfo(
server: Server,
boundPort: Int,
securePort: Option[Int],
- conf: SparkConf,
- private val rootHandler: ContextHandlerCollection) {
+ private val conf: SparkConf,
+ private val rootHandler: ContextHandlerCollection) extends Logging {
- def addHandler(handler: ServletContextHandler): Unit = {
+ def addHandler(
+ handler: ServletContextHandler,
+ securityMgr: SecurityManager): Unit = synchronized {
handler.setVirtualHosts(JettyUtils.toVirtualHosts(JettyUtils.SPARK_CONNECTOR_NAME))
- JettyUtils.addFilters(Seq(handler), conf)
- rootHandler.addHandler(handler)
+ addFilters(handler, securityMgr)
+
+ val gzipHandler = new GzipHandler()
+ gzipHandler.setHandler(handler)
+ rootHandler.addHandler(gzipHandler)
+
if (!handler.isStarted()) {
handler.start()
}
+ gzipHandler.start()
}
- def removeHandler(handler: ContextHandler): Unit = {
- rootHandler.removeHandler(handler)
+ def removeHandler(handler: ServletContextHandler): Unit = synchronized {
+ // Since addHandler() always adds a wrapping gzip handler, find the container handler
+ // and remove it.
+ rootHandler.getHandlers()
+ .find { h =>
+ h.isInstanceOf[GzipHandler] && h.asInstanceOf[GzipHandler].getHandler() == handler
+ }
+ .foreach { h =>
+ rootHandler.removeHandler(h)
+ h.stop()
+ }
+
if (handler.isStarted) {
handler.stop()
}
@@ -537,4 +494,33 @@ private[spark] case class ServerInfo(
threadPool.asInstanceOf[LifeCycle].stop
}
}
+
+ /**
+ * Add filters, if any, to the given ServletContextHandlers. Always adds a filter at the end
+ * of the chain to perform security-related functions.
+ */
+ private def addFilters(handler: ServletContextHandler, securityMgr: SecurityManager): Unit = {
+ conf.getOption("spark.ui.filters").toSeq.flatMap(Utils.stringToSeq).foreach { filter =>
+ logInfo(s"Adding filter to ${handler.getContextPath()}: $filter")
+ val oldParams = conf.getOption(s"spark.$filter.params").toSeq
+ .flatMap(Utils.stringToSeq)
+ .flatMap { param =>
+ val parts = param.split("=")
+ if (parts.length == 2) Some(parts(0) -> parts(1)) else None
+ }
+ .toMap
+
+ val newParams = conf.getAllWithPrefix(s"spark.$filter.param.").toMap
+
+ JettyUtils.addFilter(handler, filter, oldParams ++ newParams)
+ }
+
+ // This filter must come after user-installed filters, since that's where authentication
+ // filters are installed. This means that custom filters will see the request before it's
+ // been validated by the security filter.
+ val securityFilter = new HttpSecurityFilter(conf, securityMgr)
+ val holder = new FilterHolder(securityFilter)
+ handler.addFilter(holder, "/*", EnumSet.allOf(classOf[DispatcherType]))
+ }
+
}
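The new per-handler addFilters shown above reads user-supplied filters from spark.ui.filters and collects their init parameters from either the legacy comma-separated spark.<class>.params value or individual spark.<class>.param.<name> keys, appending the security filter last. A configuration sketch under those assumptions (the filter class name and parameter names are hypothetical, not part of this patch):

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      // Servlet filter class that must be present on the driver's classpath.
      .set("spark.ui.filters", "com.example.MyAuthFilter")
      // Legacy form: comma-separated key=value pairs.
      .set("spark.com.example.MyAuthFilter.params", "realm=spark,timeout=30")
      // Newer form: one init parameter per key; these win over the legacy values on conflict.
      .set("spark.com.example.MyAuthFilter.param.realm", "spark-ui")

Every handler added through ServerInfo.addHandler picks up this configuration, so user filters and the security filter apply uniformly to everything served by the UI's Jetty server.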
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index 60a929375baae..967435030bc4d 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -27,8 +27,6 @@ import scala.util.control.NonFatal
import scala.xml._
import scala.xml.transform.{RewriteRule, RuleTransformer}
-import org.apache.commons.lang3.StringEscapeUtils
-
import org.apache.spark.internal.Logging
import org.apache.spark.ui.scope.RDDOperationGraph
@@ -38,8 +36,6 @@ private[spark] object UIUtils extends Logging {
val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped"
val TABLE_CLASS_STRIPED_SORTABLE = TABLE_CLASS_STRIPED + " sortable"
- private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r
-
// SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use.
private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
override def initialValue(): SimpleDateFormat =
@@ -552,23 +548,6 @@ private[spark] object UIUtils extends Logging {
}
}
- /**
- * Remove suspicious characters of user input to prevent Cross-Site scripting (XSS) attacks
- *
- * For more information about XSS testing:
- * https://www.owasp.org/index.php/XSS_Filter_Evasion_Cheat_Sheet and
- * https://www.owasp.org/index.php/Testing_for_Reflected_Cross_site_scripting_(OTG-INPVAL-001)
- */
- def stripXSS(requestParameter: String): String = {
- if (requestParameter == null) {
- null
- } else {
- // Remove new lines and single quotes, followed by escaping HTML version 4.0
- StringEscapeUtils.escapeHtml4(
- NEWLINE_AND_SINGLE_QUOTE_REGEX.replaceAllIn(requestParameter, ""))
- }
- }
-
def buildErrorResponse(status: Response.Status, msg: String): Response = {
Response.status(status).entity(msg).`type`(MediaType.TEXT_PLAIN).build()
}
diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala
b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
index 2e43f17e6a8e3..ebf8655ce8c2f 100644
--- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
@@ -58,7 +58,6 @@ private[spark] abstract class WebUI(
def getBasePath: String = basePath
def getTabs: Seq[WebUITab] = tabs
def getHandlers: Seq[ServletContextHandler] = handlers
- def getSecurityManager: SecurityManager = securityManager
/** Attaches a tab to this UI, along with all of its attached pages. */
def attachTab(tab: WebUITab): Unit = {
@@ -81,9 +80,9 @@ private[spark] abstract class WebUI(
def attachPage(page: WebUIPage): Unit = {
val pagePath = "/" + page.prefix
val renderHandler = createServletHandler(pagePath,
- (request: HttpServletRequest) => page.render(request), securityManager, conf, basePath)
+ (request: HttpServletRequest) => page.render(request), conf, basePath)
val renderJsonHandler = createServletHandler(pagePath.stripSuffix("/") + "/json",
- (request: HttpServletRequest) => page.renderJson(request), securityManager, conf, basePath)
+ (request: HttpServletRequest) => page.renderJson(request), conf, basePath)
attachHandler(renderHandler)
attachHandler(renderJsonHandler)
val handlers = pageToHandlers.getOrElseUpdate(page, ArrayBuffer[ServletContextHandler]())
@@ -91,13 +90,13 @@ private[spark] abstract class WebUI(
}
/** Attaches a handler to this UI. */
- def attachHandler(handler: ServletContextHandler): Unit = {
+ def attachHandler(handler: ServletContextHandler): Unit = synchronized {
handlers += handler
- serverInfo.foreach(_.addHandler(handler))
+ serverInfo.foreach(_.addHandler(handler, securityManager))
}
/** Detaches a handler from this UI. */
- def detachHandler(handler: ServletContextHandler): Unit = {
+ def detachHandler(handler: ServletContextHandler): Unit = synchronized {
handlers -= handler
serverInfo.foreach(_.removeHandler(handler))
}
@@ -129,7 +128,9 @@ private[spark] abstract class WebUI(
assert(serverInfo.isEmpty, s"Attempted to bind $className more than once!")
try {
val host = Option(conf.getenv("SPARK_LOCAL_IP")).getOrElse("0.0.0.0")
- serverInfo = Some(startJettyServer(host, port, sslOptions, handlers, conf, name))
+ val server = startJettyServer(host, port, sslOptions, conf, name)
+ handlers.foreach(server.addHandler(_, securityManager))
+ serverInfo = Some(server)
logInfo(s"Bound $className to $host, and started at $webUrl")
} catch {
case e: Exception =>
diff --git
a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
index f9713fb5b4a3c..a13037b5e24db 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
@@ -28,10 +28,8 @@ private[ui] class ExecutorThreadDumpPage(
parent: SparkUITab,
sc: Option[SparkContext]) extends WebUIPage("threadDump") {
- // stripXSS is called first to remove suspicious characters used in XSS attacks
def render(request: HttpServletRequest): Seq[Node] = {
- val executorId =
- Option(UIUtils.stripXSS(request.getParameter("executorId"))).map { executorId =>
+ val executorId = Option(request.getParameter("executorId")).map { executorId =>
UIUtils.decodeURLParameter(executorId)
}.getOrElse {
throw new IllegalArgumentException(s"Missing executorId parameter")
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
index 2c22e0555fcb8..b35ea5b52549b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -205,21 +205,17 @@ private[ui] class AllJobsPage(parent: JobsTab, store:
AppStatusStore) extends We
jobTag: String,
jobs: Seq[v1.JobData],
killEnabled: Boolean): Seq[Node] = {
- // stripXSS is called to remove suspicious characters used in XSS attacks
- val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) =>
- UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq
- }
- val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag))
+ val parameterOtherTable = request.getParameterMap().asScala
+ .filterNot(_._1.startsWith(jobTag))
.map(para => para._1 + "=" + para._2(0))
val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined)
val jobIdTitle = if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id"
- // stripXSS is called first to remove suspicious characters used in XSS attacks
- val parameterJobPage = UIUtils.stripXSS(request.getParameter(jobTag + ".page"))
- val parameterJobSortColumn = UIUtils.stripXSS(request.getParameter(jobTag + ".sort"))
- val parameterJobSortDesc = UIUtils.stripXSS(request.getParameter(jobTag + ".desc"))
- val parameterJobPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".pageSize"))
+ val parameterJobPage = request.getParameter(jobTag + ".page")
+ val parameterJobSortColumn = request.getParameter(jobTag + ".sort")
+ val parameterJobSortDesc = request.getParameter(jobTag + ".desc")
+ val parameterJobPageSize = request.getParameter(jobTag + ".pageSize")
val jobPage = Option(parameterJobPage).map(_.toInt).getOrElse(1)
val jobSortColumn = Option(parameterJobSortColumn).map { sortColumn =>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
index cd82439223b07..46295e73e086b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
@@ -184,8 +184,7 @@ private[ui] class JobPage(parent: JobsTab, store:
AppStatusStore) extends WebUIP
}
def render(request: HttpServletRequest): Seq[Node] = {
- // stripXSS is called first to remove suspicious characters used in XSS attacks
- val parameterId = UIUtils.stripXSS(request.getParameter("id"))
+ val parameterId = request.getParameter("id")
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")
val jobId = parameterId.toInt
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
index ff1b75e5c5065..37bb292bd5950 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
@@ -47,9 +47,7 @@ private[ui] class JobsTab(parent: SparkUI, store:
AppStatusStore)
def handleKillRequest(request: HttpServletRequest): Unit = {
if (killEnabled &&
parent.securityManager.checkModifyPermissions(request.getRemoteUser)) {
- // stripXSS is called first to remove suspicious characters used in XSS attacks
- val jobId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt)
- jobId.foreach { id =>
+ Option(request.getParameter("id")).map(_.toInt).foreach { id =>
store.asOption(store.job(id)).foreach { job =>
if (job.status == JobExecutionStatus.RUNNING) {
sc.foreach(_.cancelJob(id))
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
index 22a40101e33df..6d2710385d9d1 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
@@ -29,8 +29,7 @@ import org.apache.spark.ui.{UIUtils, WebUIPage}
private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") {
def render(request: HttpServletRequest): Seq[Node] = {
- // stripXSS is called first to remove suspicious characters used in XSS attacks
- val poolName = Option(UIUtils.stripXSS(request.getParameter("poolname"))).map { poolname =>
+ val poolName = Option(request.getParameter("poolname")).map { poolname =>
UIUtils.decodeURLParameter(poolname)
}.getOrElse {
throw new IllegalArgumentException(s"Missing poolname parameter")
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index a213b764abea7..c6a59125ce3cb 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -80,22 +80,19 @@ private[ui] class StagePage(parent: StagesTab, store:
AppStatusStore) extends We
}
def render(request: HttpServletRequest): Seq[Node] = {
- // stripXSS is called first to remove suspicious characters used in XSS attacks
- val parameterId = UIUtils.stripXSS(request.getParameter("id"))
+ val parameterId = request.getParameter("id")
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")
- val parameterAttempt = UIUtils.stripXSS(request.getParameter("attempt"))
+ val parameterAttempt = request.getParameter("attempt")
require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter")
- val parameterTaskPage = UIUtils.stripXSS(request.getParameter("task.page"))
- val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort"))
- val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc"))
- val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize"))
+ val parameterTaskPage = request.getParameter("task.page")
+ val parameterTaskSortColumn = request.getParameter("task.sort")
+ val parameterTaskSortDesc = request.getParameter("task.desc")
+ val parameterTaskPageSize = request.getParameter("task.pageSize")
- val eventTimelineParameterTaskPage = UIUtils.stripXSS(
- request.getParameter("task.eventTimelinePageNumber"))
- val eventTimelineParameterTaskPageSize = UIUtils.stripXSS(
- request.getParameter("task.eventTimelinePageSize"))
+ val eventTimelineParameterTaskPage = request.getParameter("task.eventTimelinePageNumber")
+ val eventTimelineParameterTaskPageSize = request.getParameter("task.eventTimelinePageSize")
var eventTimelineTaskPage =
Option(eventTimelineParameterTaskPage).map(_.toInt).getOrElse(1)
var eventTimelineTaskPageSize = Option(
eventTimelineParameterTaskPageSize).map(_.toInt).getOrElse(100)
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 766efc15e26ba..330b6422a13af 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -42,17 +42,14 @@ private[ui] class StageTableBase(
isFairScheduler: Boolean,
killEnabled: Boolean,
isFailedStage: Boolean) {
- // stripXSS is called to remove suspicious characters used in XSS attacks
- val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) =>
- UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq
- }
- val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag))
+ val parameterOtherTable = request.getParameterMap().asScala
+ .filterNot(_._1.startsWith(stageTag))
.map(para => para._1 + "=" + para._2(0))
- val parameterStagePage = UIUtils.stripXSS(request.getParameter(stageTag + ".page"))
- val parameterStageSortColumn = UIUtils.stripXSS(request.getParameter(stageTag + ".sort"))
- val parameterStageSortDesc = UIUtils.stripXSS(request.getParameter(stageTag + ".desc"))
- val parameterStagePageSize = UIUtils.stripXSS(request.getParameter(stageTag + ".pageSize"))
+ val parameterStagePage = request.getParameter(stageTag + ".page")
+ val parameterStageSortColumn = request.getParameter(stageTag + ".sort")
+ val parameterStageSortDesc = request.getParameter(stageTag + ".desc")
+ val parameterStagePageSize = request.getParameter(stageTag + ".pageSize")
val stagePage = Option(parameterStagePage).map(_.toInt).getOrElse(1)
val stageSortColumn = Option(parameterStageSortColumn).map { sortColumn =>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
index 10b032084ce4f..e16c337ba1643 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
@@ -45,9 +45,7 @@ private[ui] class StagesTab(val parent: SparkUI, val store:
AppStatusStore)
def handleKillRequest(request: HttpServletRequest): Unit = {
if (killEnabled &&
parent.securityManager.checkModifyPermissions(request.getRemoteUser)) {
- // stripXSS is called first to remove suspicious characters used in XSS attacks
- val stageId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt)
- stageId.foreach { id =>
+ Option(request.getParameter("id")).map(_.toInt).foreach { id =>
store.asOption(store.lastStageAttempt(id)).foreach { stage =>
val status = stage.status
if (status == StageStatus.ACTIVE || status == StageStatus.PENDING) {
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
index 87da290c83057..dde441abe5903 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
@@ -31,14 +31,13 @@ import org.apache.spark.util.Utils
private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends
WebUIPage("rdd") {
def render(request: HttpServletRequest): Seq[Node] = {
- // stripXSS is called first to remove suspicious characters used in XSS attacks
- val parameterId = UIUtils.stripXSS(request.getParameter("id"))
+ val parameterId = request.getParameter("id")
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")
- val parameterBlockPage = UIUtils.stripXSS(request.getParameter("block.page"))
- val parameterBlockSortColumn = UIUtils.stripXSS(request.getParameter("block.sort"))
- val parameterBlockSortDesc = UIUtils.stripXSS(request.getParameter("block.desc"))
- val parameterBlockPageSize = UIUtils.stripXSS(request.getParameter("block.pageSize"))
+ val parameterBlockPage = request.getParameter("block.page")
+ val parameterBlockSortColumn = request.getParameter("block.sort")
+ val parameterBlockSortDesc = request.getParameter("block.desc")
+ val parameterBlockPageSize = request.getParameter("block.pageSize")
val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1)
val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name")
diff --git
a/core/src/test/scala/org/apache/spark/ui/HttpSecurityFilterSuite.scala
b/core/src/test/scala/org/apache/spark/ui/HttpSecurityFilterSuite.scala
new file mode 100644
index 0000000000000..f46cc293ed271
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ui/HttpSecurityFilterSuite.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import java.util.UUID
+import javax.servlet.FilterChain
+import javax.servlet.http.{HttpServletRequest, HttpServletResponse}
+
+import scala.collection.JavaConverters._
+
+import org.mockito.ArgumentCaptor
+import org.mockito.ArgumentMatchers.{any, eq => meq}
+import org.mockito.Mockito.{mock, times, verify, when}
+
+import org.apache.spark._
+import org.apache.spark.internal.config._
+
+class HttpSecurityFilterSuite extends SparkFunSuite {
+
+ test("filter bad user input") {
+ val badValues = Map(
+ "encoded" -> "Encoding:base64%0d%0a%0d%0aPGh0bWw%2bjcmlwdD48L2h0bWw%2b",
+ "alert1" -> """>"'><script>alert(401)<%2Fscript>""",
+ "alert2" -> """app-20161208133404-0002<iframe+src%3Djavascript%3Aalert(1705)>""",
+ "alert3" -> """stdout'%2Balert(60)%2B'""",
+ "html" ->
"""stdout'"><iframe+id%3D1131+src%3Dhttp%3A%2F%2Fdemo.test.net%2Fphishing.html>"""
+ )
+ val badKeys = badValues.map(_.swap)
+ val goodInput = Map("goodKey" -> "goodValue")
+
+ val conf = new SparkConf()
+ val filter = new HttpSecurityFilter(conf, new SecurityManager(conf))
+
+ def newRequest(): HttpServletRequest = {
+ val req = mock(classOf[HttpServletRequest])
+ when(req.getParameterMap()).thenReturn(Map.empty[String, Array[String]].asJava)
+ req
+ }
+
+ def doRequest(k: String, v: String): HttpServletRequest = {
+ val req = newRequest()
+ when(req.getParameterMap()).thenReturn(Map(k -> Array(v)).asJava)
+
+ val chain = mock(classOf[FilterChain])
+ val res = mock(classOf[HttpServletResponse])
+ filter.doFilter(req, res, chain)
+
+ val captor = ArgumentCaptor.forClass(classOf[HttpServletRequest])
+ verify(chain).doFilter(captor.capture(), any())
+ captor.getValue()
+ }
+
+ badKeys.foreach { case (k, v) =>
+ val req = doRequest(k, v)
+ assert(req.getParameter(k) === null)
+ assert(req.getParameterValues(k) === null)
+ assert(!req.getParameterMap().containsKey(k))
+ }
+
+ badValues.foreach { case (k, v) =>
+ val req = doRequest(k, v)
+ assert(req.getParameter(k) !== null)
+ assert(req.getParameter(k) !== v)
+ assert(req.getParameterValues(k) !== null)
+ assert(req.getParameterValues(k) !== Array(v))
+ assert(req.getParameterMap().get(k) !== null)
+ assert(req.getParameterMap().get(k) !== Array(v))
+ }
+
+ goodInput.foreach { case (k, v) =>
+ val req = doRequest(k, v)
+ assert(req.getParameter(k) === v)
+ assert(req.getParameterValues(k) === Array(v))
+ assert(req.getParameterMap().get(k) === Array(v))
+ }
+ }
+
+ test("perform access control") {
+ val conf = new SparkConf(false)
+ .set("spark.ui.acls.enable", "true")
+ .set("spark.admin.acls", "admin")
+ .set("spark.ui.view.acls", "alice")
+ val secMgr = new SecurityManager(conf)
+
+ val req = mockEmptyRequest()
+ val res = mock(classOf[HttpServletResponse])
+ val chain = mock(classOf[FilterChain])
+
+ val filter = new HttpSecurityFilter(conf, secMgr)
+
+ when(req.getRemoteUser()).thenReturn("admin")
+ filter.doFilter(req, res, chain)
+ verify(chain, times(1)).doFilter(any(), any())
+
+ when(req.getRemoteUser()).thenReturn("alice")
+ filter.doFilter(req, res, chain)
+ verify(chain, times(2)).doFilter(any(), any())
+
+ // Because the current user is added to the view ACLs, let's try to create an invalid
+ // name, to avoid matching some common user name.
+ when(req.getRemoteUser()).thenReturn(UUID.randomUUID().toString())
+ filter.doFilter(req, res, chain)
+
+ // chain.doFilter() should not be called again, so same count as above.
+ verify(chain, times(2)).doFilter(any(), any())
+ verify(res).sendError(meq(HttpServletResponse.SC_FORBIDDEN), any())
+ }
+
+ test("set security-related headers") {
+ val conf = new SparkConf(false)
+ .set("spark.ui.allowFramingFrom", "example.com")
+ .set(UI_X_XSS_PROTECTION, "xssProtection")
+ .set(UI_X_CONTENT_TYPE_OPTIONS, true)
+ .set(UI_STRICT_TRANSPORT_SECURITY, "tsec")
+ val secMgr = new SecurityManager(conf)
+ val req = mockEmptyRequest()
+ val res = mock(classOf[HttpServletResponse])
+ val chain = mock(classOf[FilterChain])
+
+ when(req.getScheme()).thenReturn("https")
+
+ val filter = new HttpSecurityFilter(conf, secMgr)
+ filter.doFilter(req, res, chain)
+
+ Map(
+ "X-Frame-Options" -> "ALLOW-FROM example.com",
+ "X-XSS-Protection" -> "xssProtection",
+ "X-Content-Type-Options" -> "nosniff",
+ "Strict-Transport-Security" -> "tsec"
+ ).foreach { case (name, value) =>
+ verify(res).setHeader(meq(name), meq(value))
+ }
+ }
+
+ private def mockEmptyRequest(): HttpServletRequest = {
+ val params: Map[String, Array[String]] = Map.empty
+ val req = mock(classOf[HttpServletRequest])
+ when(req.getParameterMap()).thenReturn(params.asJava)
+ req
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala
b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
index 36ea3799afdf2..eaa8f28ae0621 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ui
import java.net.{BindException, ServerSocket}
import java.net.{URI, URL}
import java.util.Locale
+import javax.servlet._
import javax.servlet.http.{HttpServlet, HttpServletRequest,
HttpServletResponse}
import scala.io.Source
@@ -49,12 +50,14 @@ class UISuite extends SparkFunSuite {
sc
}
- private def sslDisabledConf(): (SparkConf, SSLOptions) = {
+ private def sslDisabledConf(): (SparkConf, SecurityManager, SSLOptions) = {
val conf = new SparkConf
- (conf, new SecurityManager(conf).getSSLOptions("ui"))
+ val securityMgr = new SecurityManager(conf)
+ (conf, securityMgr, securityMgr.getSSLOptions("ui"))
}
- private def sslEnabledConf(sslPort: Option[Int] = None): (SparkConf, SSLOptions) = {
+ private def sslEnabledConf(sslPort: Option[Int] = None):
+ (SparkConf, SecurityManager, SSLOptions) = {
val keyStoreFilePath = getTestResourcePath("spark.keystore")
val conf = new SparkConf()
.set("spark.ssl.ui.enabled", "true")
@@ -64,7 +67,8 @@ class UISuite extends SparkFunSuite {
sslPort.foreach { p =>
conf.set("spark.ssl.ui.port", p.toString)
}
- (conf, new SecurityManager(conf).getSSLOptions("ui"))
+ val securityMgr = new SecurityManager(conf)
+ (conf, securityMgr, securityMgr.getSSLOptions("ui"))
}
ignore("basic ui visibility") {
@@ -95,14 +99,12 @@ class UISuite extends SparkFunSuite {
var server: ServerSocket = null
var serverInfo1: ServerInfo = null
var serverInfo2: ServerInfo = null
- val (conf, sslOptions) = sslDisabledConf()
+ val (conf, _, sslOptions) = sslDisabledConf()
try {
server = new ServerSocket(0)
val startPort = server.getLocalPort
- serverInfo1 = JettyUtils.startJettyServer(
- "0.0.0.0", startPort, sslOptions, Seq[ServletContextHandler](), conf)
- serverInfo2 = JettyUtils.startJettyServer(
- "0.0.0.0", startPort, sslOptions, Seq[ServletContextHandler](), conf)
+ serverInfo1 = JettyUtils.startJettyServer("0.0.0.0", startPort, sslOptions, conf)
+ serverInfo2 = JettyUtils.startJettyServer("0.0.0.0", startPort, sslOptions, conf)
// Allow some wiggle room in case ports on the machine are under contention
val boundPort1 = serverInfo1.boundPort
val boundPort2 = serverInfo2.boundPort
@@ -123,11 +125,9 @@ class UISuite extends SparkFunSuite {
try {
server = new ServerSocket(0)
val startPort = server.getLocalPort
- val (conf, sslOptions) = sslEnabledConf()
- serverInfo1 = JettyUtils.startJettyServer(
- "0.0.0.0", startPort, sslOptions, Seq[ServletContextHandler](), conf, "server1")
- serverInfo2 = JettyUtils.startJettyServer(
- "0.0.0.0", startPort, sslOptions, Seq[ServletContextHandler](), conf, "server2")
+ val (conf, _, sslOptions) = sslEnabledConf()
+ serverInfo1 = JettyUtils.startJettyServer("0.0.0.0", startPort, sslOptions, conf, "server1")
+ serverInfo2 = JettyUtils.startJettyServer("0.0.0.0", startPort, sslOptions, conf, "server2")
// Allow some wiggle room in case ports on the machine are under contention
val boundPort1 = serverInfo1.boundPort
val boundPort2 = serverInfo2.boundPort
@@ -144,10 +144,9 @@ class UISuite extends SparkFunSuite {
test("jetty binds to port 0 correctly") {
var socket: ServerSocket = null
var serverInfo: ServerInfo = null
- val (conf, sslOptions) = sslDisabledConf()
+ val (conf, _, sslOptions) = sslDisabledConf()
try {
- serverInfo = JettyUtils.startJettyServer(
- "0.0.0.0", 0, sslOptions, Seq[ServletContextHandler](), conf)
+ serverInfo = JettyUtils.startJettyServer("0.0.0.0", 0, sslOptions, conf)
val server = serverInfo.server
val boundPort = serverInfo.boundPort
assert(server.getState === "STARTED")
@@ -165,9 +164,8 @@ class UISuite extends SparkFunSuite {
var socket: ServerSocket = null
var serverInfo: ServerInfo = null
try {
- val (conf, sslOptions) = sslEnabledConf()
- serverInfo = JettyUtils.startJettyServer(
- "0.0.0.0", 0, sslOptions, Seq[ServletContextHandler](), conf)
+ val (conf, _, sslOptions) = sslEnabledConf()
+ serverInfo = JettyUtils.startJettyServer("0.0.0.0", 0, sslOptions, conf)
val server = serverInfo.server
val boundPort = serverInfo.boundPort
assert(server.getState === "STARTED")
@@ -231,30 +229,49 @@ class UISuite extends SparkFunSuite {
assert(newHeader === null)
}
- test("http -> https redirect applies to all URIs") {
- var serverInfo: ServerInfo = null
+ test("add and remove handlers with custom user filter") {
+ val (conf, securityMgr, sslOptions) = sslDisabledConf()
+ conf.set("spark.ui.filters", classOf[TestFilter].getName())
+ conf.set(s"spark.${classOf[TestFilter].getName()}.param.responseCode",
+ HttpServletResponse.SC_NOT_ACCEPTABLE.toString)
+
+ val serverInfo = JettyUtils.startJettyServer("0.0.0.0", 0, sslOptions, conf)
try {
- val servlet = new HttpServlet() {
- override def doGet(req: HttpServletRequest, res: HttpServletResponse):
Unit = {
- res.sendError(HttpServletResponse.SC_OK)
- }
- }
+ val path = "/test"
+ val url = new URL(s"http://localhost:${serverInfo.boundPort}$path/root")
- def newContext(path: String): ServletContextHandler = {
- val ctx = new ServletContextHandler()
- ctx.setContextPath(path)
- ctx.addServlet(new ServletHolder(servlet), "/root")
- ctx
- }
+ assert(TestUtils.httpResponseCode(url) === HttpServletResponse.SC_NOT_FOUND)
+
+ val (servlet, ctx) = newContext(path)
+ serverInfo.addHandler(ctx, securityMgr)
+ assert(TestUtils.httpResponseCode(url) === HttpServletResponse.SC_NOT_ACCEPTABLE)
+
+ // Try a request with bad content in a parameter to make sure the security filter
+ // is being added to new handlers.
+ val badRequest = new URL(
+ s"http://localhost:${serverInfo.boundPort}$path/root?bypass&invalid<=foo")
+ assert(TestUtils.httpResponseCode(badRequest) === HttpServletResponse.SC_OK)
+ assert(servlet.lastRequest.getParameter("invalid<") === null)
+ assert(servlet.lastRequest.getParameter("invalid&lt;") !== null)
- val (conf, sslOptions) = sslEnabledConf()
- serverInfo = JettyUtils.startJettyServer("0.0.0.0", 0, sslOptions,
- Seq[ServletContextHandler](newContext("/"), newContext("/test1")),
- conf)
+ serverInfo.removeHandler(ctx)
+ assert(TestUtils.httpResponseCode(url) === HttpServletResponse.SC_NOT_FOUND)
+ } finally {
+ stopServer(serverInfo)
+ }
+ }
+
+ test("http -> https redirect applies to all URIs") {
+ val (conf, securityMgr, sslOptions) = sslEnabledConf()
+ val serverInfo = JettyUtils.startJettyServer("0.0.0.0", 0, sslOptions, conf)
+ try {
+ Seq(newContext("/"), newContext("/test1")).foreach { case (_, ctx) =>
+ serverInfo.addHandler(ctx, securityMgr)
+ }
assert(serverInfo.server.getState === "STARTED")
- val testContext = newContext("/test2")
- serverInfo.addHandler(testContext)
+ val (_, testContext) = newContext("/test2")
+ serverInfo.addHandler(testContext, securityMgr)
testContext.start()
val httpPort = serverInfo.boundPort
@@ -286,10 +303,10 @@ class UISuite extends SparkFunSuite {
// Make sure the SSL port lies way outside the "http + 400" range used
as the default.
val baseSslPort = Utils.userPort(socket.getLocalPort(), 10000)
- val (conf, sslOptions) = sslEnabledConf(sslPort = Some(baseSslPort))
+ val (conf, _, sslOptions) = sslEnabledConf(sslPort = Some(baseSslPort))
serverInfo = JettyUtils.startJettyServer("0.0.0.0", socket.getLocalPort() + 1,
- sslOptions, Seq[ServletContextHandler](), conf, "server1")
+ sslOptions, conf, serverName = "server1")
val notAllowed = Utils.userPort(serverInfo.boundPort, 400)
assert(serverInfo.securePort.isDefined)
@@ -300,6 +317,18 @@ class UISuite extends SparkFunSuite {
}
}
+ /**
+ * Create a new context handler for the given path, with a single servlet that responds to
+ * requests in `$path/root`.
+ */
+ private def newContext(path: String): (CapturingServlet, ServletContextHandler) = {
+ val servlet = new CapturingServlet()
+ val ctx = new ServletContextHandler()
+ ctx.setContextPath(path)
+ ctx.addServlet(new ServletHolder(servlet), "/root")
+ (servlet, ctx)
+ }
+
def stopServer(info: ServerInfo): Unit = {
if (info != null) info.stop()
}
@@ -307,4 +336,40 @@ class UISuite extends SparkFunSuite {
def closeSocket(socket: ServerSocket): Unit = {
if (socket != null) socket.close
}
+
+ /** Test servlet that exposes the last request object for GET calls. */
+ private class CapturingServlet extends HttpServlet {
+
+ @volatile var lastRequest: HttpServletRequest = _
+
+ override def doGet(req: HttpServletRequest, res: HttpServletResponse): Unit = {
+ lastRequest = req
+ res.sendError(HttpServletResponse.SC_OK)
+ }
+
+ }
+
+}
+
+// Filter for testing; returns a configurable code for every request.
+private[spark] class TestFilter extends Filter {
+
+ private var rc: Int = HttpServletResponse.SC_OK
+
+ override def destroy(): Unit = { }
+
+ override def init(config: FilterConfig): Unit = {
+ if (config.getInitParameter("responseCode") != null) {
+ rc = config.getInitParameter("responseCode").toInt
+ }
+ }
+
+ override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit = {
+ if (req.getParameter("bypass") == null) {
+ res.asInstanceOf[HttpServletResponse].sendError(rc, "Test.")
+ } else {
+ chain.doFilter(req, res)
+ }
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
index 423daacc0f5a5..c770fd5da76f7 100644
--- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
@@ -133,45 +133,6 @@ class UIUtilsSuite extends SparkFunSuite {
assert(decoded2 === decodeURLParameter(decoded2))
}
- test("SPARK-20393: Prevent newline characters in parameters.") {
- val encoding = "Encoding:base64%0d%0a%0d%0aPGh0bWw%2bjcmlwdD48L2h0bWw%2b"
- val stripEncoding = "Encoding:base64PGh0bWw%2bjcmlwdD48L2h0bWw%2b"
-
- assert(stripEncoding === stripXSS(encoding))
- }
-
- test("SPARK-20393: Prevent script from parameters running on page.") {
- val scriptAlert = """>"'><script>alert(401)<%2Fscript>"""
- val stripScriptAlert = "&gt;&quot;&gt;&lt;script&gt;alert(401)&lt;%2Fscript&gt;"
-
- assert(stripScriptAlert === stripXSS(scriptAlert))
- }
-
- test("SPARK-20393: Prevent javascript from parameters running on page.") {
- val javascriptAlert =
- """app-20161208133404-0002<iframe+src%3Djavascript%3Aalert(1705)>"""
- val stripJavascriptAlert =
- "app-20161208133404-0002&lt;iframe+src%3Djavascript%3Aalert(1705)&gt;"
-
- assert(stripJavascriptAlert === stripXSS(javascriptAlert))
- }
-
- test("SPARK-20393: Prevent links from parameters on page.") {
- val link =
- """stdout'"><iframe+id%3D1131+src%3Dhttp%3A%2F%2Fdemo.test.net%2Fphishing.html>"""
- val stripLink =
- "stdout&quot;&gt;&lt;iframe+id%3D1131+src%3Dhttp%3A%2F%2Fdemo.test.net%2Fphishing.html&gt;"
-
- assert(stripLink === stripXSS(link))
- }
-
- test("SPARK-20393: Prevent popups from parameters on page.") {
- val popup = """stdout'%2Balert(60)%2B'"""
- val stripPopup = "stdout%2Balert(60)%2B"
-
- assert(stripPopup === stripXSS(popup))
- }
-
private def verify(
desc: String,
expected: Node,
diff --git
a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
index 91f64141e5318..6e4571eba0361 100644
---
a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
+++
b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
@@ -29,8 +29,7 @@ import org.apache.spark.ui.{UIUtils, WebUIPage}
private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") {
override def render(request: HttpServletRequest): Seq[Node] = {
- // stripXSS is called first to remove suspicious characters used in XSS attacks
- val driverId = UIUtils.stripXSS(request.getParameter("id"))
+ val driverId = request.getParameter("id")
require(driverId != null && driverId.nonEmpty, "Missing id parameter")
val state = parent.scheduler.getDriverState(driverId)
diff --git
a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 6357d4adbcd99..a9ff3023a5811 100644
---
a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++
b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -17,7 +17,9 @@
package org.apache.spark.scheduler.cluster
+import java.util.EnumSet
import java.util.concurrent.atomic.{AtomicBoolean}
+import javax.servlet.DispatcherType
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
@@ -25,6 +27,7 @@ import scala.util.{Failure, Success}
import scala.util.control.NonFatal
import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId}
+import org.eclipse.jetty.servlet.{FilterHolder, FilterMapping}
import org.apache.spark.SparkContext
import org.apache.spark.deploy.security.HadoopDelegationTokenManager
@@ -159,7 +162,7 @@ private[spark] abstract class YarnSchedulerBackend(
/**
* Add filters to the SparkUI.
*/
- private def addWebUIFilter(
+ private[cluster] def addWebUIFilter(
filterName: String,
filterParams: Map[String, String],
proxyBase: String): Unit = {
@@ -174,9 +177,33 @@ private[spark] abstract class YarnSchedulerBackend(
// SPARK-26255: Append user provided filters(spark.ui.filters) with yarn filter.
val allFilters = filterName + "," + conf.get("spark.ui.filters", "")
logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase")
- conf.set("spark.ui.filters", allFilters)
- filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) }
- scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) }
+
+ // For already installed handlers, prepend the filter.
+ scheduler.sc.ui.foreach { ui =>
+ // Lock the UI so that new handlers are not added while this is running. Set the updated
+ // filter config inside the lock so that we're sure all handlers will properly get it.
+ ui.synchronized {
+ filterParams.foreach { case (k, v) =>
+ conf.set(s"spark.$filterName.param.$k", v)
+ }
+ conf.set("spark.ui.filters", allFilters)
+
+ ui.getHandlers.map(_.getServletHandler()).foreach { h =>
+ val holder = new FilterHolder()
+ holder.setName(filterName)
+ holder.setClassName(filterName)
+ filterParams.foreach { case (k, v) => holder.setInitParameter(k, v) }
+ h.addFilter(holder)
+
+ val mapping = new FilterMapping()
+ mapping.setFilterName(filterName)
+ mapping.setPathSpec("/*")
+ mapping.setDispatcherTypes(EnumSet.allOf(classOf[DispatcherType]))
+
+ h.prependFilterMapping(mapping)
+ }
+ }
+ }
}
}
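
To see the Jetty mechanics in isolation: the snippet below is a rough, self-contained sketch that mirrors the change above (the object and helper names are made up, and this is not Spark code). It shows how a filter can be registered on an already-built ServletContextHandler so that it runs ahead of any previously mapped filters, which is what prependFilterMapping buys over a plain append.

import java.util.EnumSet
import javax.servlet.DispatcherType

import org.eclipse.jetty.servlet.{FilterHolder, FilterMapping, ServletContextHandler}

object FilterPrepend {
  /** Register `filterClass` on `ctx` and make it the first filter in the chain. */
  def prependFilter(
      ctx: ServletContextHandler,
      filterClass: String,
      params: Map[String, String]): Unit = {
    val servletHandler = ctx.getServletHandler()

    // Declare the filter under its class name and pass its init parameters.
    val holder = new FilterHolder()
    holder.setName(filterClass)
    holder.setClassName(filterClass)
    params.foreach { case (k, v) => holder.setInitParameter(k, v) }
    servletHandler.addFilter(holder)

    // Map it to every path and dispatcher type. Prepending the mapping (rather
    // than appending it) is what makes it see requests before filters that were
    // installed earlier, e.g. those configured through spark.ui.filters.
    val mapping = new FilterMapping()
    mapping.setFilterName(filterClass)
    mapping.setPathSpec("/*")
    mapping.setDispatcherTypes(EnumSet.allOf(classOf[DispatcherType]))
    servletHandler.prependFilterMapping(mapping)
  }
}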
diff --git
a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala
b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala
index 7fac57ff68abc..5d285f89f22f5 100644
---
a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala
+++
b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala
@@ -16,14 +16,19 @@
*/
package org.apache.spark.scheduler.cluster
+import java.net.URL
+import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
+
import scala.language.reflectiveCalls
+import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}
import org.mockito.Mockito.when
import org.scalatest.mockito.MockitoSugar
-import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite}
+import org.apache.spark._
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.ui.TestFilter
class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with LocalSparkContext {
@@ -54,7 +59,57 @@ class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with Loc
// Serialize to make sure serialization doesn't throw an error
ser.serialize(req)
}
- sc.stop()
+ }
+
+ test("Respect user filters when adding AM IP filter") {
+ val conf = new SparkConf(false)
+ .set("spark.ui.filters", classOf[TestFilter].getName())
+ .set(s"spark.${classOf[TestFilter].getName()}.param.responseCode",
+ HttpServletResponse.SC_BAD_GATEWAY.toString)
+
+ sc = new SparkContext("local", "YarnSchedulerBackendSuite", conf)
+ val sched = mock[TaskSchedulerImpl]
+ when(sched.sc).thenReturn(sc)
+
+ val url = new URL(sc.uiWebUrl.get)
+ // Before adding the "YARN" filter, should get the code from the filter in
SparkConf.
+ assert(TestUtils.httpResponseCode(url) ===
HttpServletResponse.SC_BAD_GATEWAY)
+
+ val backend = new YarnSchedulerBackend(sched, sc) { }
+
+ backend.addWebUIFilter(classOf[TestFilter2].getName(),
+ Map("responseCode" -> HttpServletResponse.SC_NOT_ACCEPTABLE.toString),
"")
+
+ sc.ui.get.getHandlers.foreach { h =>
+ // Two filters above + security filter.
+ assert(h.getServletHandler().getFilters().length === 3)
+ }
+
+ // The filter should have been added first in the chain, so we should get SC_NOT_ACCEPTABLE
+ // instead of SC_OK.
+ assert(TestUtils.httpResponseCode(url) === HttpServletResponse.SC_NOT_ACCEPTABLE)
+
+ // Add a new handler and make sure the added filter is properly registered.
+ val servlet = new HttpServlet() {
+ override def doGet(req: HttpServletRequest, res: HttpServletResponse): Unit = {
+ res.sendError(HttpServletResponse.SC_CONFLICT)
+ }
+ }
+
+ val ctx = new ServletContextHandler()
+ ctx.setContextPath("/new-handler")
+ ctx.addServlet(new ServletHolder(servlet), "/")
+
+ sc.ui.get.attachHandler(ctx)
+
+ val newUrl = new URL(sc.uiWebUrl.get + "/new-handler/")
+ assert(TestUtils.httpResponseCode(newUrl) === HttpServletResponse.SC_NOT_ACCEPTABLE)
+
+ val bypassUrl = new URL(sc.uiWebUrl.get + "/new-handler/?bypass")
+ assert(TestUtils.httpResponseCode(bypassUrl) === HttpServletResponse.SC_CONFLICT)
}
}
+
+// Just extend the test filter so we can configure two of them.
+class TestFilter2 extends TestFilter
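
The status checks above go through Spark's TestUtils.httpResponseCode helper; to reproduce the assertions without Spark's test utilities, a plain-JDK approximation (an assumption about the helper's behaviour, not its actual code) would look like this:

import java.net.{HttpURLConnection, URL}

object HttpCheck {
  /** Issue a GET and return only the HTTP status code. */
  def httpResponseCode(url: URL): Int = {
    val conn = url.openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod("GET")
    try {
      conn.connect()
      conn.getResponseCode()
    } finally {
      conn.disconnect()
    }
  }
}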
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
index 4958f154e625f..05890de5e1260 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
@@ -158,19 +158,14 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L
showSucceededJobs: Boolean,
showFailedJobs: Boolean): Seq[Node] = {
- // stripXSS is called to remove suspicious characters used in XSS attacks
- val allParameters = request.getParameterMap.asScala.toMap.map { case (k, v) =>
- UIUtils.stripXSS(k) -> v.map(UIUtils.stripXSS).toSeq
+ val parameterOtherTable = request.getParameterMap().asScala.map { case (name, vals) =>
+ name + "=" + vals(0)
}
- val parameterOtherTable = allParameters.filterNot(_._1.startsWith(executionTag))
- .map(para => para._1 + "=" + para._2(0))
-
- val parameterExecutionPage = UIUtils.stripXSS(request.getParameter(s"$executionTag.page"))
- val parameterExecutionSortColumn = UIUtils.stripXSS(request
- .getParameter(s"$executionTag.sort"))
- val parameterExecutionSortDesc = UIUtils.stripXSS(request.getParameter(s"$executionTag.desc"))
- val parameterExecutionPageSize = UIUtils.stripXSS(request
- .getParameter(s"$executionTag.pageSize"))
+
+ val parameterExecutionPage = request.getParameter(s"$executionTag.page")
+ val parameterExecutionSortColumn = request.getParameter(s"$executionTag.sort")
+ val parameterExecutionSortDesc = request.getParameter(s"$executionTag.desc")
+ val parameterExecutionPageSize = request.getParameter(s"$executionTag.pageSize")
val executionPage = Option(parameterExecutionPage).map(_.toInt).getOrElse(1)
val executionSortColumn = Option(parameterExecutionSortColumn).map { sortColumn =>
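
Since stripXSS no longer wraps these lookups, missing-parameter handling rests entirely on the Option pattern already used here. As a tiny, self-contained illustration (the object and helper below are hypothetical, not part of the patch):

import javax.servlet.http.HttpServletRequest

object PageParams {
  // getParameter returns null when a query parameter is absent, so wrap it in
  // Option and fall back to a default before converting.
  def pageNumber(request: HttpServletRequest, tag: String): Int =
    Option(request.getParameter(s"$tag.page")).map(_.toInt).getOrElse(1)
}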
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
index e4c119e6d06c3..875086cda258d 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
@@ -30,8 +30,7 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
private val sqlStore = parent.sqlStore
override def render(request: HttpServletRequest): Seq[Node] = {
- // stripXSS is called first to remove suspicious characters used in XSS attacks
- val parameterExecutionId = UIUtils.stripXSS(request.getParameter("id"))
+ val parameterExecutionId = request.getParameter("id")
require(parameterExecutionId != null && parameterExecutionId.nonEmpty,
"Missing execution id parameter")
diff --git
a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala
b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala
index f46eeea941540..fdc9bee5ed056 100644
---
a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala
+++
b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala
@@ -39,8 +39,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab)
/** Render the page */
def render(request: HttpServletRequest): Seq[Node] = {
- // stripXSS is called first to remove suspicious characters used in XSS attacks
- val parameterId = UIUtils.stripXSS(request.getParameter("id"))
+ val parameterId = request.getParameter("id")
require(parameterId != null && parameterId.nonEmpty, "Missing id
parameter")
val content =
diff --git
a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala
b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala
index 884d21d0afdd3..dc7876bad68d9 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala
@@ -317,12 +317,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") {
}
def render(request: HttpServletRequest): Seq[Node] = streamingListener.synchronized {
- // stripXSS is called first to remove suspicious characters used in XSS attacks
- val batchTime =
- Option(SparkUIUtils.stripXSS(request.getParameter("id"))).map(id => Time(id.toLong))
+ val batchTime = Option(request.getParameter("id")).map(id => Time(id.toLong))
.getOrElse {
- throw new IllegalArgumentException(s"Missing id parameter")
- }
+ throw new IllegalArgumentException(s"Missing id parameter")
+ }
val formattedBatchTime =
UIUtils.formatBatchTime(batchTime.milliseconds, streamingListener.batchDuration)