Repository: spark
Updated Branches:
  refs/heads/master 9b6baeb7b -> ab3302895


[SPARK-25178][SQL] Directly ship the StructType objects of the keySchema / 
valueSchema for xxxHashMapGenerator

## What changes were proposed in this pull request?

This PR generates code that refers to a `StructType` built in the Scala 
code, instead of generating the `StructType` construction in Java code.

The original code has two issues.
1. Avoid using field names such as `key.name`
1. Support complicated schemas (e.g. nested DataTypes)

At first, [the JIRA entry](https://issues.apache.org/jira/browse/SPARK-25178) 
proposed to change the generated field name of the keySchema / valueSchema to a 
dummy name in `RowBasedHashMapGenerator` and 
`VectorizedHashMapGenerator.scala`. This proposal can address issue 1.

Ueshin suggested an approach that refers to a `StructType` built in the Scala 
code using `ctx.addReferenceObj()`. This approach can address both issues 1 
and 2. Finally, this PR uses this approach.

## How was this patch tested?

Existing UTs

Closes #22187 from kiszk/SPARK-25178.

Authored-by: Kazuaki Ishizaki <ishiz...@jp.ibm.com>
Signed-off-by: Takuya UESHIN <ues...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ab330289
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ab330289
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ab330289

Branch: refs/heads/master
Commit: ab33028957443189efc4106afd9d65dddf8f9c98
Parents: 9b6baeb
Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>
Authored: Fri Aug 24 14:58:55 2018 +0900
Committer: Takuya UESHIN <ues...@databricks.com>
Committed: Fri Aug 24 14:58:55 2018 +0900

----------------------------------------------------------------------
 .../aggregate/RowBasedHashMapGenerator.scala    | 33 +++--------------
 .../aggregate/VectorizedHashMapGenerator.scala  | 37 ++++----------------
 2 files changed, 10 insertions(+), 60 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ab330289/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
index d550827..ca59bb1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
@@ -44,31 +44,8 @@ class RowBasedHashMapGenerator(
     groupingKeySchema, bufferSchema) {
 
   override protected def initializeAggregateHashMap(): String = {
-    val generatedKeySchema: String =
-      s"new org.apache.spark.sql.types.StructType()" +
-        groupingKeySchema.map { key =>
-          val keyName = ctx.addReferenceObj("keyName", key.name)
-          key.dataType match {
-            case d: DecimalType =>
-              s""".add($keyName, 
org.apache.spark.sql.types.DataTypes.createDecimalType(
-                  |${d.precision}, ${d.scale}))""".stripMargin
-            case _ =>
-              s""".add($keyName, 
org.apache.spark.sql.types.DataTypes.${key.dataType})"""
-          }
-        }.mkString("\n").concat(";")
-
-    val generatedValueSchema: String =
-      s"new org.apache.spark.sql.types.StructType()" +
-        bufferSchema.map { key =>
-          val keyName = ctx.addReferenceObj("keyName", key.name)
-          key.dataType match {
-            case d: DecimalType =>
-              s""".add($keyName, 
org.apache.spark.sql.types.DataTypes.createDecimalType(
-                  |${d.precision}, ${d.scale}))""".stripMargin
-            case _ =>
-              s""".add($keyName, 
org.apache.spark.sql.types.DataTypes.${key.dataType})"""
-          }
-        }.mkString("\n").concat(";")
+    val keySchema = ctx.addReferenceObj("keySchemaTerm", groupingKeySchema)
+    val valueSchema = ctx.addReferenceObj("valueSchemaTerm", bufferSchema)
 
     s"""
        |  private 
org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch;
@@ -78,8 +55,6 @@ class RowBasedHashMapGenerator(
        |  private int numBuckets = (int) (capacity / loadFactor);
        |  private int maxSteps = 2;
        |  private int numRows = 0;
-       |  private org.apache.spark.sql.types.StructType keySchema = 
$generatedKeySchema
-       |  private org.apache.spark.sql.types.StructType valueSchema = 
$generatedValueSchema
        |  private Object emptyVBase;
        |  private long emptyVOff;
        |  private int emptyVLen;
@@ -90,9 +65,9 @@ class RowBasedHashMapGenerator(
        |    org.apache.spark.memory.TaskMemoryManager taskMemoryManager,
        |    InternalRow emptyAggregationBuffer) {
        |    batch = 
org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch
-       |      .allocate(keySchema, valueSchema, taskMemoryManager, capacity);
+       |      .allocate($keySchema, $valueSchema, taskMemoryManager, capacity);
        |
-       |    final UnsafeProjection valueProjection = 
UnsafeProjection.create(valueSchema);
+       |    final UnsafeProjection valueProjection = 
UnsafeProjection.create($valueSchema);
        |    final byte[] emptyBuffer = 
valueProjection.apply(emptyAggregationBuffer).getBytes();
        |
        |    emptyVBase = emptyBuffer;

http://git-wip-us.apache.org/repos/asf/spark/blob/ab330289/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index 7b3580c..95ebefe 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -52,31 +52,9 @@ class VectorizedHashMapGenerator(
     groupingKeySchema, bufferSchema) {
 
   override protected def initializeAggregateHashMap(): String = {
-    val generatedSchema: String =
-      s"new org.apache.spark.sql.types.StructType()" +
-        (groupingKeySchema ++ bufferSchema).map { key =>
-          val keyName = ctx.addReferenceObj("keyName", key.name)
-          key.dataType match {
-            case d: DecimalType =>
-              s""".add($keyName, 
org.apache.spark.sql.types.DataTypes.createDecimalType(
-                  |${d.precision}, ${d.scale}))""".stripMargin
-            case _ =>
-              s""".add($keyName, 
org.apache.spark.sql.types.DataTypes.${key.dataType})"""
-          }
-        }.mkString("\n").concat(";")
-
-    val generatedAggBufferSchema: String =
-      s"new org.apache.spark.sql.types.StructType()" +
-        bufferSchema.map { key =>
-          val keyName = ctx.addReferenceObj("keyName", key.name)
-          key.dataType match {
-            case d: DecimalType =>
-              s""".add($keyName, 
org.apache.spark.sql.types.DataTypes.createDecimalType(
-                  |${d.precision}, ${d.scale}))""".stripMargin
-            case _ =>
-              s""".add($keyName, 
org.apache.spark.sql.types.DataTypes.${key.dataType})"""
-          }
-        }.mkString("\n").concat(";")
+    val schemaStructType = new StructType((groupingKeySchema ++ 
bufferSchema).toArray)
+    val schema = ctx.addReferenceObj("schemaTerm", schemaStructType)
+    val aggBufferSchemaFieldsLength = bufferSchema.fields.length
 
     s"""
        |  private ${classOf[OnHeapColumnVector].getName}[] vectors;
@@ -88,18 +66,15 @@ class VectorizedHashMapGenerator(
        |  private int numBuckets = (int) (capacity / loadFactor);
        |  private int maxSteps = 2;
        |  private int numRows = 0;
-       |  private org.apache.spark.sql.types.StructType schema = 
$generatedSchema
-       |  private org.apache.spark.sql.types.StructType aggregateBufferSchema =
-       |    $generatedAggBufferSchema
        |
        |  public $generatedClassName() {
-       |    vectors = 
${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema);
+       |    vectors = 
${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, $schema);
        |    batch = new ${classOf[ColumnarBatch].getName}(vectors);
        |
        |    // Generates a projection to return the aggregate buffer only.
        |    ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors =
-       |      new 
${classOf[OnHeapColumnVector].getName}[aggregateBufferSchema.fields().length];
-       |    for (int i = 0; i < aggregateBufferSchema.fields().length; i++) {
+       |      new 
${classOf[OnHeapColumnVector].getName}[$aggBufferSchemaFieldsLength];
+       |    for (int i = 0; i < $aggBufferSchemaFieldsLength; i++) {
        |      aggBufferVectors[i] = vectors[i + ${groupingKeys.length}];
        |    }
        |    aggBufferRow = new 
${classOf[MutableColumnarRow].getName}(aggBufferVectors);


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to