Repository: spark
Updated Branches:
  refs/heads/master 77ab49b85 -> b1a771231


[SPARK-12480][SQL] add Hash expression that can calculate hash value for a 
group of expressions

Just write the arguments into an unsafe row and use murmur3 to calculate the hash code.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #10435 from cloud-fan/hash-expr.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b1a77123
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b1a77123
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b1a77123

Branch: refs/heads/master
Commit: b1a771231e20df157fb3e780287390a883c0cc6f
Parents: 77ab49b
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Mon Jan 4 18:49:41 2016 -0800
Committer: Reynold Xin <r...@databricks.com>
Committed: Mon Jan 4 18:49:41 2016 -0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/UnsafeRow.java     |  4 ++
 .../catalyst/analysis/FunctionRegistry.scala    |  3 +-
 .../spark/sql/catalyst/expressions/misc.scala   | 44 ++++++++++++
 .../sql/catalyst/encoders/RowEncoderSuite.scala |  2 +-
 .../expressions/MiscFunctionsSuite.scala        | 73 +++++++++++++++++++-
 .../scala/org/apache/spark/sql/functions.scala  | 11 +++
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 10 +++
 .../hive/execution/HiveCompatibilitySuite.scala |  3 +
 .../apache/spark/sql/hive/test/TestHive.scala   | 24 +++++++
 .../sql/hive/execution/HiveQuerySuite.scala     |  3 -
 10 files changed, 171 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b1a77123/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 1a35193..b8d3c49 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -566,6 +566,10 @@ public final class UnsafeRow extends MutableRow implements 
Externalizable, KryoS
     return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 
42);
   }
 
+  public int hashCode(int seed) {
+    return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 
seed);
+  }
+
   @Override
   public boolean equals(Object other) {
     if (other instanceof UnsafeRow) {

http://git-wip-us.apache.org/repos/asf/spark/blob/b1a77123/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 57d1a11..5c2aa3c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -49,7 +49,7 @@ trait FunctionRegistry {
 
 class SimpleFunctionRegistry extends FunctionRegistry {
 
-  private val functionBuilders =
+  private[sql] val functionBuilders =
     StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false)
 
   override def registerFunction(
@@ -278,6 +278,7 @@ object FunctionRegistry {
     // misc functions
     expression[Crc32]("crc32"),
     expression[Md5]("md5"),
+    expression[Murmur3Hash]("hash"),
     expression[Sha1]("sha"),
     expression[Sha1]("sha1"),
     expression[Sha2]("sha2"),

http://git-wip-us.apache.org/repos/asf/spark/blob/b1a77123/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index d0ec99b..8834924 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -22,6 +22,8 @@ import java.util.zip.CRC32
 
 import org.apache.commons.codec.digest.DigestUtils
 
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -177,3 +179,45 @@ case class Crc32(child: Expression) extends 
UnaryExpression with ImplicitCastInp
     })
   }
 }
+
+/**
+ * A function that calculates a hash value for a group of expressions.  Note 
that the `seed` argument
+ * is not exposed to users and should only be set inside Spark SQL.
+ *
+ * Internally this function will write the arguments into an [[UnsafeRow]], and 
calculate the hash code of
+ * the unsafe row using the murmur3 hasher with a seed.
+ * We should use this hash function for both shuffle and bucketing, so that we 
can guarantee shuffle
+ * and bucketing have the same data distribution.
+ */
+case class Murmur3Hash(children: Seq[Expression], seed: Int) extends 
Expression {
+  def this(arguments: Seq[Expression]) = this(arguments, 42)
+
+  override def dataType: DataType = IntegerType
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def nullable: Boolean = false
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.isEmpty) {
+      TypeCheckResult.TypeCheckFailure("arguments of function hash cannot be 
empty")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  private lazy val unsafeProjection = UnsafeProjection.create(children)
+
+  override def eval(input: InternalRow): Any = {
+    unsafeProjection(input).hashCode(seed)
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
+    val unsafeRow = GenerateUnsafeProjection.createCode(ctx, children)
+    ev.isNull = "false"
+    s"""
+      ${unsafeRow.code}
+      final int ${ev.value} = ${unsafeRow.value}.hashCode($seed);
+    """
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b1a77123/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 8f4faab..b17f8d5 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -99,7 +99,7 @@ class RowEncoderSuite extends SparkFunSuite {
       .add("binary", BinaryType)
       .add("date", DateType)
       .add("timestamp", TimestampType)
-      .add("udt", new ExamplePointUDT, false))
+      .add("udt", new ExamplePointUDT))
 
   encodeDecodeTest(
     new StructType()

http://git-wip-us.apache.org/repos/asf/spark/blob/b1a77123/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
index 75d1741..9175568 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
@@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.commons.codec.digest.DigestUtils
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType}
+import org.apache.spark.sql.{Row, RandomDataGenerator}
+import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
+import org.apache.spark.sql.types._
 
 class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -59,4 +61,73 @@ class MiscFunctionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     checkEvaluation(Crc32(Literal.create(null, BinaryType)), null)
     checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
   }
+
+  private val structOfString = new StructType().add("str", StringType)
+  private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, 
false)
+  private val arrayOfString = ArrayType(StringType)
+  private val arrayOfNull = ArrayType(NullType)
+  private val mapOfString = MapType(StringType, StringType)
+  private val arrayOfUDT = ArrayType(new ExamplePointUDT, false)
+
+  testMurmur3Hash(
+    new StructType()
+      .add("null", NullType)
+      .add("boolean", BooleanType)
+      .add("byte", ByteType)
+      .add("short", ShortType)
+      .add("int", IntegerType)
+      .add("long", LongType)
+      .add("float", FloatType)
+      .add("double", DoubleType)
+      .add("decimal", DecimalType.SYSTEM_DEFAULT)
+      .add("string", StringType)
+      .add("binary", BinaryType)
+      .add("date", DateType)
+      .add("timestamp", TimestampType)
+      .add("udt", new ExamplePointUDT))
+
+  testMurmur3Hash(
+    new StructType()
+      .add("arrayOfNull", arrayOfNull)
+      .add("arrayOfString", arrayOfString)
+      .add("arrayOfArrayOfString", ArrayType(arrayOfString))
+      .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType)))
+      .add("arrayOfMap", ArrayType(mapOfString))
+      .add("arrayOfStruct", ArrayType(structOfString))
+      .add("arrayOfUDT", arrayOfUDT))
+
+  testMurmur3Hash(
+    new StructType()
+      .add("mapOfIntAndString", MapType(IntegerType, StringType))
+      .add("mapOfStringAndArray", MapType(StringType, arrayOfString))
+      .add("mapOfArrayAndInt", MapType(arrayOfString, IntegerType))
+      .add("mapOfArray", MapType(arrayOfString, arrayOfString))
+      .add("mapOfStringAndStruct", MapType(StringType, structOfString))
+      .add("mapOfStructAndString", MapType(structOfString, StringType))
+      .add("mapOfStruct", MapType(structOfString, structOfString)))
+
+  testMurmur3Hash(
+    new StructType()
+      .add("structOfString", structOfString)
+      .add("structOfStructOfString", new StructType().add("struct", 
structOfString))
+      .add("structOfArray", new StructType().add("array", arrayOfString))
+      .add("structOfMap", new StructType().add("map", mapOfString))
+      .add("structOfArrayAndMap",
+        new StructType().add("array", arrayOfString).add("map", mapOfString))
+      .add("structOfUDT", structOfUDT))
+
+  private def testMurmur3Hash(inputSchema: StructType): Unit = {
+    val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = 
false).get
+    val encoder = RowEncoder(inputSchema)
+    val seed = scala.util.Random.nextInt()
+    test(s"murmur3 hash: ${inputSchema.simpleString}") {
+      for (_ <- 1 to 10) {
+        val input = 
encoder.toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow]
+        val literals = 
input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map {
+          case (value, dt) => Literal.create(value, dt)
+        }
+        checkEvaluation(Murmur3Hash(literals, seed), input.hashCode(seed))
+      }
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b1a77123/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 2b3db39..e223e32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1813,6 +1813,17 @@ object functions extends LegacyFunctions {
    */
   def crc32(e: Column): Column = withExpr { Crc32(e.expr) }
 
+  /**
+   * Calculates the hash code of given columns, and returns the result as an 
int column.
+   *
+   * @group misc_funcs
+   * @since 2.0
+   */
+  @scala.annotation.varargs
+  def hash(col: Column, cols: Column*): Column = withExpr {
+    new Murmur3Hash((col +: cols).map(_.expr))
+  }
+
   
//////////////////////////////////////////////////////////////////////////////////////////////
   // String functions
   
//////////////////////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/b1a77123/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 115b617..7284571 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2057,4 +2057,14 @@ class SQLQuerySuite extends QueryTest with 
SharedSQLContext {
     )
   }
 
+  test("hash function") {
+    val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+    withTempTable("tbl") {
+      df.registerTempTable("tbl")
+      checkAnswer(
+        df.select(hash($"i", $"j")),
+        sql("SELECT hash(i, j) from tbl")
+      )
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b1a77123/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
 
b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 2b0e48d..bd1a52e 100644
--- 
a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ 
b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -53,6 +53,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with 
BeforeAndAfter {
     TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5)
     // Enable in-memory partition pruning for testing purposes
     TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
+    // Use Hive hash expression instead of the native one
+    TestHive.functionRegistry.unregisterFunction("hash")
     RuleExecutor.resetTime()
   }
 
@@ -62,6 +64,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with 
BeforeAndAfter {
     Locale.setDefault(originalLocale)
     TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
     TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, 
originalInMemoryPartitionPruning)
+    TestHive.functionRegistry.restore()
 
     // For debugging dump some statistics about how much time was spent in 
various optimizer rules.
     logWarning(RuleExecutor.dumpTimeSpent())

http://git-wip-us.apache.org/repos/asf/spark/blob/b1a77123/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 013fbab..66d5f20 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -31,10 +31,13 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
 
 import org.apache.spark.sql.{SQLContext, SQLConf}
 import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.CacheTableCommand
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.hive.execution.HiveNativeCommand
+import org.apache.spark.sql.hive.client.ClientWrapper
 import org.apache.spark.util.{ShutdownHookManager, Utils}
 import org.apache.spark.{SparkConf, SparkContext}
 
@@ -451,6 +454,27 @@ class TestHiveContext(sc: SparkContext) extends 
HiveContext(sc) {
         logError("FATAL ERROR: Failed to reset TestDB state.", e)
     }
   }
+
+  @transient
+  override protected[sql] lazy val functionRegistry = new 
TestHiveFunctionRegistry(
+    org.apache.spark.sql.catalyst.analysis.FunctionRegistry.builtin.copy(), 
this.executionHive)
+}
+
+private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, 
client: ClientWrapper)
+  extends HiveFunctionRegistry(fr, client) {
+
+  private val removedFunctions =
+    collection.mutable.ArrayBuffer.empty[(String, (ExpressionInfo, 
FunctionBuilder))]
+
+  def unregisterFunction(name: String): Unit = {
+    fr.functionBuilders.remove(name).foreach(f => removedFunctions += name -> 
f)
+  }
+
+  def restore(): Unit = {
+    removedFunctions.foreach {
+      case (name, (info, builder)) => fr.registerFunction(name, info, builder)
+    }
+  }
 }
 
 private[hive] object TestHiveContext {

http://git-wip-us.apache.org/repos/asf/spark/blob/b1a77123/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 8a5acaf..acd1130 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -387,9 +387,6 @@ class HiveQuerySuite extends HiveComparisonTest with 
BeforeAndAfter {
   createQueryTest("partitioned table scan",
     "SELECT ds, hr, key, value FROM srcpart")
 
-  createQueryTest("hash",
-    "SELECT hash('test') FROM src LIMIT 1")
-
   createQueryTest("create table as",
     """
       |CREATE TABLE createdtable AS SELECT * FROM src;


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to