Repository: spark Updated Branches: refs/heads/master 9b6baeb7b -> ab3302895
[SPARK-25178][SQL] Directly ship the StructType objects of the keySchema / valueSchema for xxxHashMapGenerator ## What changes were proposed in this pull request? This PR generates code that refers to a `StructType` generated in the Scala code instead of generating the `StructType` in Java code. The original code has two issues. 1. It should avoid using field names such as `key.name` 1. It should support complicated schemas (e.g. nested DataType) At first, [the JIRA entry](https://issues.apache.org/jira/browse/SPARK-25178) proposed to change the generated field name of the keySchema / valueSchema to a dummy name in `RowBasedHashMapGenerator` and `VectorizedHashMapGenerator.scala`. This proposal can address issue 1. Ueshin suggested an approach to refer to a `StructType` generated in the Scala code using `ctx.addReferenceObj()`. This approach can address issues 1 and 2. Finally, this PR uses this approach. ## How was this patch tested? Existing UTs Closes #22187 from kiszk/SPARK-25178. Authored-by: Kazuaki Ishizaki <ishiz...@jp.ibm.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ab330289 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ab330289 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ab330289 Branch: refs/heads/master Commit: ab33028957443189efc4106afd9d65dddf8f9c98 Parents: 9b6baeb Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com> Authored: Fri Aug 24 14:58:55 2018 +0900 Committer: Takuya UESHIN <ues...@databricks.com> Committed: Fri Aug 24 14:58:55 2018 +0900 ---------------------------------------------------------------------- .../aggregate/RowBasedHashMapGenerator.scala | 33 +++-------------- .../aggregate/VectorizedHashMapGenerator.scala | 37 ++++---------------- 2 files changed, 10 insertions(+), 60 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/spark/blob/ab330289/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index d550827..ca59bb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -44,31 +44,8 @@ class RowBasedHashMapGenerator( groupingKeySchema, bufferSchema) { override protected def initializeAggregateHashMap(): String = { - val generatedKeySchema: String = - s"new org.apache.spark.sql.types.StructType()" + - groupingKeySchema.map { key => - val keyName = ctx.addReferenceObj("keyName", key.name) - key.dataType match { - case d: DecimalType => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( - |${d.precision}, ${d.scale}))""".stripMargin - case _ => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" - } - }.mkString("\n").concat(";") - - val generatedValueSchema: String = - s"new org.apache.spark.sql.types.StructType()" + - bufferSchema.map { key => - val keyName = ctx.addReferenceObj("keyName", key.name) - key.dataType match { - case d: DecimalType => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( - |${d.precision}, ${d.scale}))""".stripMargin - case _ => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" - } - }.mkString("\n").concat(";") + val keySchema = ctx.addReferenceObj("keySchemaTerm", groupingKeySchema) + val valueSchema = ctx.addReferenceObj("valueSchemaTerm", bufferSchema) s""" | private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch; @@ 
-78,8 +55,6 @@ class RowBasedHashMapGenerator( | private int numBuckets = (int) (capacity / loadFactor); | private int maxSteps = 2; | private int numRows = 0; - | private org.apache.spark.sql.types.StructType keySchema = $generatedKeySchema - | private org.apache.spark.sql.types.StructType valueSchema = $generatedValueSchema | private Object emptyVBase; | private long emptyVOff; | private int emptyVLen; @@ -90,9 +65,9 @@ class RowBasedHashMapGenerator( | org.apache.spark.memory.TaskMemoryManager taskMemoryManager, | InternalRow emptyAggregationBuffer) { | batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch - | .allocate(keySchema, valueSchema, taskMemoryManager, capacity); + | .allocate($keySchema, $valueSchema, taskMemoryManager, capacity); | - | final UnsafeProjection valueProjection = UnsafeProjection.create(valueSchema); + | final UnsafeProjection valueProjection = UnsafeProjection.create($valueSchema); | final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); | | emptyVBase = emptyBuffer; http://git-wip-us.apache.org/repos/asf/spark/blob/ab330289/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 7b3580c..95ebefe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -52,31 +52,9 @@ class VectorizedHashMapGenerator( groupingKeySchema, bufferSchema) { override protected def initializeAggregateHashMap(): String = { - val generatedSchema: String = - s"new org.apache.spark.sql.types.StructType()" + - (groupingKeySchema ++ 
bufferSchema).map { key => - val keyName = ctx.addReferenceObj("keyName", key.name) - key.dataType match { - case d: DecimalType => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( - |${d.precision}, ${d.scale}))""".stripMargin - case _ => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" - } - }.mkString("\n").concat(";") - - val generatedAggBufferSchema: String = - s"new org.apache.spark.sql.types.StructType()" + - bufferSchema.map { key => - val keyName = ctx.addReferenceObj("keyName", key.name) - key.dataType match { - case d: DecimalType => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( - |${d.precision}, ${d.scale}))""".stripMargin - case _ => - s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" - } - }.mkString("\n").concat(";") + val schemaStructType = new StructType((groupingKeySchema ++ bufferSchema).toArray) + val schema = ctx.addReferenceObj("schemaTerm", schemaStructType) + val aggBufferSchemaFieldsLength = bufferSchema.fields.length s""" | private ${classOf[OnHeapColumnVector].getName}[] vectors; @@ -88,18 +66,15 @@ class VectorizedHashMapGenerator( | private int numBuckets = (int) (capacity / loadFactor); | private int maxSteps = 2; | private int numRows = 0; - | private org.apache.spark.sql.types.StructType schema = $generatedSchema - | private org.apache.spark.sql.types.StructType aggregateBufferSchema = - | $generatedAggBufferSchema | | public $generatedClassName() { - | vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, schema); + | vectors = ${classOf[OnHeapColumnVector].getName}.allocateColumns(capacity, $schema); | batch = new ${classOf[ColumnarBatch].getName}(vectors); | | // Generates a projection to return the aggregate buffer only. 
| ${classOf[OnHeapColumnVector].getName}[] aggBufferVectors = - | new ${classOf[OnHeapColumnVector].getName}[aggregateBufferSchema.fields().length]; - | for (int i = 0; i < aggregateBufferSchema.fields().length; i++) { + | new ${classOf[OnHeapColumnVector].getName}[$aggBufferSchemaFieldsLength]; + | for (int i = 0; i < $aggBufferSchemaFieldsLength; i++) { | aggBufferVectors[i] = vectors[i + ${groupingKeys.length}]; | } | aggBufferRow = new ${classOf[MutableColumnarRow].getName}(aggBufferVectors); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org