Repository: spark
Updated Branches:
  refs/heads/master 571aa2755 -> 28315714d
[SPARK-22791][SQL][SS] Redact Output of Explain

## What changes were proposed in this pull request?

When calling explain on a query, the output can contain sensitive information. We should provide a way for admins/users to redact such information.

Before this PR, the physical plan of a Structured Streaming query looked like this:

```
== Physical Plan ==
*HashAggregate(keys=[value#6], functions=[count(1)], output=[value#6, count(1)#12L])
+- StateStoreSave [value#6], state info [ checkpoint = file:/private/var/folders/vx/j0ydl5rn0gd9mgrh1pljnw900000gn/T/temporary-91c6fac0-609f-4bc8-ad57-52c189f06797/state, runId = 05a4b3af-f02c-40f8-9ff9-a3e18bae496f, opId = 0, ver = 0, numPartitions = 5], Complete, 0
   +- *HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#18L])
      +- StateStoreRestore [value#6], state info [ checkpoint = file:/private/var/folders/vx/j0ydl5rn0gd9mgrh1pljnw900000gn/T/temporary-91c6fac0-609f-4bc8-ad57-52c189f06797/state, runId = 05a4b3af-f02c-40f8-9ff9-a3e18bae496f, opId = 0, ver = 0, numPartitions = 5]
         +- *HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#18L])
            +- Exchange hashpartitioning(value#6, 5)
               +- *HashAggregate(keys=[value#6], functions=[partial_count(1)], output=[value#6, count#18L])
                  +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6]
                     +- *MapElements <function1>, obj#5: java.lang.String
                        +- *DeserializeToObject value#30.toString, obj#4: java.lang.String
                           +- LocalTableScan [value#30]
```

After this PR, we can get the following output if users set `spark.redaction.string.regex` to `file:/[\\w_]+`:

```
== Physical Plan ==
*HashAggregate(keys=[value#6], functions=[count(1)], output=[value#6, count(1)#12L])
+- StateStoreSave [value#6], state info [ checkpoint = *********(redacted)/var/folders/vx/j0ydl5rn0gd9mgrh1pljnw900000gn/T/temporary-e7da9b7d-3ec0-474d-8b8c-927f7d12ed72/state, runId = 8a9c3761-93d5-4896-ab82-14c06240dcea, opId = 0, ver = 0, numPartitions = 5], Complete, 0
   +- *HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#32L])
      +- StateStoreRestore [value#6], state info [ checkpoint = *********(redacted)/var/folders/vx/j0ydl5rn0gd9mgrh1pljnw900000gn/T/temporary-e7da9b7d-3ec0-474d-8b8c-927f7d12ed72/state, runId = 8a9c3761-93d5-4896-ab82-14c06240dcea, opId = 0, ver = 0, numPartitions = 5]
         +- *HashAggregate(keys=[value#6], functions=[merge_count(1)], output=[value#6, count#32L])
            +- Exchange hashpartitioning(value#6, 5)
               +- *HashAggregate(keys=[value#6], functions=[partial_count(1)], output=[value#6, count#32L])
                  +- *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#6]
                     +- *MapElements <function1>, obj#5: java.lang.String
                        +- *DeserializeToObject value#27.toString, obj#4: java.lang.String
                           +- LocalTableScan [value#27]
```

## How was this patch tested?

Added a test case.

Author: gatorsmile <gatorsm...@gmail.com>

Closes #19985 from gatorsmile/redactPlan.
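To illustrate the behavior this change enables, here is a minimal sketch of exercising redaction on a batch query from a local session. The parquet path, regex, and session setup are illustrative; the config key `spark.sql.redaction.string.regex` (which falls back to `spark.redaction.string.regex` when unset) is the one introduced in the diff below.

```scala
// Illustrative only: exercise the new redaction setting from a local SparkSession.
// The warehouse path and regex are made up; the config keys come from this change.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("explain-redaction-demo")
  .getOrCreate()

// Session-level pattern; when unset, the global spark.redaction.string.regex is used.
spark.conf.set("spark.sql.redaction.string.regex", "file:/[\\w_/]+")

spark.range(0, 10).toDF("a").write.mode("overwrite").parquet("/tmp/redaction_demo")
val df = spark.read.parquet("/tmp/redaction_demo")

// Matching file paths in the physical plan should now appear as "*********(redacted)".
df.explain()
```

With the pattern set, substrings of the explain output that match the regex are replaced with the dummy value instead of showing the original file path.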
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/28315714
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/28315714
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/28315714

Branch: refs/heads/master
Commit: 28315714ddef3ddcc192375e98dd5207cf4ecc98
Parents: 571aa27
Author: gatorsmile <gatorsm...@gmail.com>
Authored: Tue Dec 19 22:12:23 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Dec 19 22:12:23 2017 +0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/util/Utils.scala     | 26 +++++++++++----
 .../org/apache/spark/sql/internal/SQLConf.scala | 11 +++++++
 .../sql/execution/DataSourceScanExec.scala      |  2 +-
 .../spark/sql/execution/QueryExecution.scala    | 13 ++++++--
 .../DataSourceScanExecRedactionSuite.scala      | 31 ++++++++++++++++++
 .../spark/sql/streaming/StreamSuite.scala       | 33 +++++++++++++++++++-
 6 files changed, 105 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/28315714/core/src/main/scala/org/apache/spark/util/Utils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 8871870..5853302 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2651,14 +2651,28 @@ private[spark] object Utils extends Logging {
   }
 
   /**
+   * Redact the sensitive values in the given map. If a map key matches the redaction pattern then
+   * its value is replaced with a dummy text.
+   */
+  def redact(regex: Option[Regex], kvs: Seq[(String, String)]): Seq[(String, String)] = {
+    regex match {
+      case None => kvs
+      case Some(r) => redact(r, kvs)
+    }
+  }
+
+  /**
    * Redact the sensitive information in the given string.
    */
-  def redact(conf: SparkConf, text: String): String = {
-    if (text == null || text.isEmpty || conf == null || !conf.contains(STRING_REDACTION_PATTERN)) {
-      text
-    } else {
-      val regex = conf.get(STRING_REDACTION_PATTERN).get
-      regex.replaceAllIn(text, REDACTION_REPLACEMENT_TEXT)
+  def redact(regex: Option[Regex], text: String): String = {
+    regex match {
+      case None => text
+      case Some(r) =>
+        if (text == null || text.isEmpty) {
+          text
+        } else {
+          r.replaceAllIn(text, REDACTION_REPLACEMENT_TEXT)
+        }
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/28315714/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index cf7e3eb..bdc8d92 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicReference
 
 import scala.collection.JavaConverters._
 import scala.collection.immutable
+import scala.util.matching.Regex
 
 import org.apache.hadoop.fs.Path
 
@@ -1035,6 +1036,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val SQL_STRING_REDACTION_PATTERN =
+    ConfigBuilder("spark.sql.redaction.string.regex")
+      .doc("Regex to decide which parts of strings produced by Spark contain sensitive " +
+        "information. When this regex matches a string part, that string part is replaced by a " +
+        "dummy value. This is currently used to redact the output of SQL explain commands. " +
+        "When this conf is not set, the value from `spark.redaction.string.regex` is used.")
+      .fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN)
+
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -1173,6 +1182,8 @@ class SQLConf extends Serializable with Logging {
 
   def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
 
+  def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader)
+
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
    * identifiers are equal.

http://git-wip-us.apache.org/repos/asf/spark/blob/28315714/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 747749b..27c7dc3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -69,7 +69,7 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
    * Shorthand for calling redactString() without specifying redacting rules
    */
   private def redact(text: String): String = {
-    Utils.redact(SparkSession.getActiveSession.map(_.sparkContext.conf).orNull, text)
+    Utils.redact(sqlContext.sessionState.conf.stringRedationPattern, text)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/28315714/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 946475a..8bfe3ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -194,13 +194,13 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
     }
   }
 
-  def simpleString: String = {
+  def simpleString: String = withRedaction {
     s"""== Physical Plan ==
        |${stringOrError(executedPlan.treeString(verbose = false))}
     """.stripMargin.trim
   }
 
-  override def toString: String = {
+  override def toString: String = withRedaction {
     def output = Utils.truncatedString(
       analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ")
     val analyzedPlan = Seq(
@@ -219,7 +219,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
     """.stripMargin.trim
   }
 
-  def stringWithStats: String = {
+  def stringWithStats: String = withRedaction {
     // trigger to compute stats for logical plans
     optimizedPlan.stats
 
@@ -231,6 +231,13 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
     """.stripMargin.trim
   }
 
+  /**
+   * Redact the sensitive information in the given string.
+   */
+  private def withRedaction(message: String): String = {
+    Utils.redact(sparkSession.sessionState.conf.stringRedationPattern, message)
+  }
+
   /** A special namespace for commands that can be used to debug query execution. */
   // scalastyle:off
   object debug {

http://git-wip-us.apache.org/repos/asf/spark/blob/28315714/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
index 423e128..c8d045a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
@@ -20,6 +20,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 
 /**
@@ -52,4 +53,34 @@ class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext {
       assert(df.queryExecution.simpleString.contains(replacement))
     }
   }
+
+  private def isIncluded(queryExecution: QueryExecution, msg: String): Boolean = {
+    queryExecution.toString.contains(msg) ||
+      queryExecution.simpleString.contains(msg) ||
+      queryExecution.stringWithStats.contains(msg)
+  }
+
+  test("explain is redacted using SQLConf") {
+    withTempDir { dir =>
+      val basePath = dir.getCanonicalPath
+      spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+      val df = spark.read.parquet(basePath)
+      val replacement = "*********"
+
+      // Respect SparkConf and replace file:/
+      assert(isIncluded(df.queryExecution, replacement))
+
+      assert(isIncluded(df.queryExecution, "FileScan"))
+      assert(!isIncluded(df.queryExecution, "file:/"))
+
+      withSQLConf(SQLConf.SQL_STRING_REDACTION_PATTERN.key -> "(?i)FileScan") {
+        // Respect SQLConf and replace FileScan
+        assert(isIncluded(df.queryExecution, replacement))
+
+        assert(!isIncluded(df.queryExecution, "FileScan"))
+        assert(isIncluded(df.queryExecution, "file:/"))
+      }
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/28315714/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 9e696b2..fa4b2dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -28,7 +28,7 @@ import com.google.common.util.concurrent.UncheckedExecutionException
 import org.apache.commons.io.FileUtils
 import org.apache.hadoop.conf.Configuration
 
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.logical.Range
@@ -418,6 +418,37 @@ class StreamSuite extends StreamTest {
     assert(OutputMode.Update === InternalOutputModes.Update)
   }
 
+  override protected def sparkConf: SparkConf = super.sparkConf
+    .set("spark.redaction.string.regex", "file:/[\\w_]+")
+
+  test("explain - redaction") {
+    val replacement = "*********"
+
+    val inputData = MemoryStream[String]
+    val df = inputData.toDS().map(_ + "foo").groupBy("value").agg(count("*"))
+    // Test StreamingQuery.display
+    val q = df.writeStream.queryName("memory_explain").outputMode("complete").format("memory")
+      .start()
+      .asInstanceOf[StreamingQueryWrapper]
+      .streamingQuery
+    try {
+      inputData.addData("abc")
+      q.processAllAvailable()
+
+      val explainWithoutExtended = q.explainInternal(false)
+      assert(explainWithoutExtended.contains(replacement))
+      assert(explainWithoutExtended.contains("StateStoreRestore"))
+      assert(!explainWithoutExtended.contains("file:/"))
+
+      val explainWithExtended = q.explainInternal(true)
+      assert(explainWithExtended.contains(replacement))
+      assert(explainWithExtended.contains("StateStoreRestore"))
+      assert(!explainWithoutExtended.contains("file:/"))
+    } finally {
+      q.stop()
+    }
+  }
+
   test("explain") {
     val inputData = MemoryStream[String]
     val df = inputData.toDS().map(_ + "foo").groupBy("value").agg(count("*"))
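For readers who want the redaction logic in isolation, here is a self-contained sketch mirroring the new `Utils.redact(regex: Option[Regex], text: String)` overload from the Utils.scala hunk above. It is an illustrative re-implementation, not the Spark source; `RedactionSketch` and `REDACTION_TEXT` are made-up names, the latter standing in for Spark's `REDACTION_REPLACEMENT_TEXT` constant.

```scala
import scala.util.matching.Regex

// Standalone, illustrative re-implementation of the string-redaction helper added to
// Utils in this commit. REDACTION_TEXT stands in for Spark's REDACTION_REPLACEMENT_TEXT.
object RedactionSketch {
  private val REDACTION_TEXT = "*********(redacted)"

  // No pattern means no-op; otherwise every regex match in the text is replaced
  // with the dummy value.
  def redact(regex: Option[Regex], text: String): String = regex match {
    case None => text
    case Some(_) if text == null || text.isEmpty => text
    case Some(r) => r.replaceAllIn(text, REDACTION_TEXT)
  }

  def main(args: Array[String]): Unit = {
    val pattern = Some("file:/[\\w_/]+".r)
    val plan = "FileScan parquet [a#0L] Location: file:/tmp/secret_dir"
    // Prints the plan with the path replaced by *********(redacted).
    println(redact(pattern, plan))
  }
}
```

Because a `None` pattern leaves the text untouched, passing the new `stringRedationPattern: Option[Regex]` accessor straight through keeps redaction a no-op unless `spark.sql.redaction.string.regex` or its fallback `spark.redaction.string.regex` is set.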