Repository: spark
Updated Branches:
  refs/heads/master 41a7cdf85 -> 4a01bfc2a
[SPARK-9350][SQL] Introduce an InternalRow generic getter that requires a DataType

Currently UnsafeRow cannot support a generic getter. However, if the data type is known, we can support a generic getter.

Author: Reynold Xin <[email protected]>

Closes #7666 from rxin/generic-getter-with-datatype and squashes the following commits:

ee2874c [Reynold Xin] Add a default implementation for getStruct.
1e109a0 [Reynold Xin] [SPARK-9350][SQL] Introduce an InternalRow generic getter that requires a DataType.
033ee88 [Reynold Xin] Removed getAs in non test code.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4a01bfc2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4a01bfc2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4a01bfc2

Branch: refs/heads/master
Commit: 4a01bfc2a2e664186028ea32095d32d29c9f9e38
Parents: 41a7cdf
Author: Reynold Xin <[email protected]>
Authored: Sat Jul 25 23:52:37 2015 -0700
Committer: Reynold Xin <[email protected]>
Committed: Sat Jul 25 23:52:37 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/linalg/Matrices.scala    |  8 +++--
 .../org/apache/spark/mllib/linalg/Vectors.scala |  9 ++++--
 .../sql/catalyst/expressions/UnsafeRow.java     |  5 ---
 .../sql/catalyst/CatalystTypeConverters.scala   | 16 +++++-----
 .../apache/spark/sql/catalyst/InternalRow.scala | 32 +++++++++++---------
 .../catalyst/expressions/BoundAttribute.scala   |  2 +-
 .../spark/sql/catalyst/expressions/Cast.scala   |  5 +--
 .../sql/catalyst/expressions/Projection.scala   |  8 +++++
 .../expressions/SpecificMutableRow.scala        | 10 +++---
 .../sql/catalyst/expressions/aggregates.scala   |  2 +-
 .../expressions/complexTypeCreator.scala        |  2 +-
 .../expressions/complexTypeExtractors.scala     |  4 +--
 .../spark/sql/catalyst/expressions/rows.scala   |  8 +++++
 .../expressions/ExpressionEvalHelper.scala      |  7 +++--
 .../expressions/MathFunctionsSuite.scala        |  2 +-
 .../expressions/UnsafeRowConverterSuite.scala   |  2 +-
 .../sql/execution/SparkSqlSerializer2.scala     |  4 +--
 .../spark/sql/execution/basicOperators.scala    |  4 +--
 .../datasources/DataSourceStrategy.scala        |  4 +--
 .../spark/sql/execution/debug/package.scala     |  2 +-
 .../apache/spark/sql/execution/pythonUDFs.scala |  2 +-
 .../sql/execution/stat/FrequentItems.scala      |  8 ++---
 .../spark/sql/expressions/aggregate/udaf.scala  | 10 ++++--
 .../sql/parquet/ParquetTableOperations.scala    |  2 +-
 .../spark/sql/parquet/ParquetTableSupport.scala |  4 +--
 .../scala/org/apache/spark/sql/RowSuite.scala   | 11 ++++---
 .../hive/execution/InsertIntoHiveTable.scala    |  4 ++-
 .../apache/spark/sql/hive/orc/OrcRelation.scala |  2 +-
 28 files changed, 105 insertions(+), 74 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index b6e2c30..d82ba24 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -179,12 +179,14 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
         val tpe = row.getByte(0)
         val numRows = row.getInt(1)
         val numCols = row.getInt(2)
-        val values = row.getAs[Iterable[Double]](5).toArray
+        val values = row.getAs[Seq[Double]](5, ArrayType(DoubleType, containsNull = false)).toArray
         val isTransposed = row.getBoolean(6)
         tpe match {
           case 0 =>
-            val colPtrs = row.getAs[Iterable[Int]](3).toArray
-            val rowIndices = row.getAs[Iterable[Int]](4).toArray
+            val colPtrs =
+              row.getAs[Seq[Int]](3, ArrayType(IntegerType, containsNull = false)).toArray
+            val rowIndices =
+              row.getAs[Seq[Int]](4, ArrayType(IntegerType, containsNull = false)).toArray
             new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
           case 1 =>
             new DenseMatrix(numRows, numCols, values, isTransposed)


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index c884aad..0cb28d7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -209,11 +209,14 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
         tpe match {
           case 0 =>
             val size = row.getInt(1)
-            val indices = row.getAs[Iterable[Int]](2).toArray
-            val values = row.getAs[Iterable[Double]](3).toArray
+            val indices =
+              row.getAs[Seq[Int]](2, ArrayType(IntegerType, containsNull = false)).toArray
+            val values =
+              row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray
             new SparseVector(size, indices, values)
           case 1 =>
-            val values = row.getAs[Iterable[Double]](3).toArray
+            val values =
+              row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray
             new DenseVector(values)
         }
       }


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 9be9089..87e5a89 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -236,11 +236,6 @@ public final class UnsafeRow extends MutableRow {
   }
 
   @Override
-  public <T> T getAs(int ordinal) {
-    throw new UnsupportedOperationException();
-  }
-
-  @Override
   public boolean isNullAt(int ordinal) {
     assertIndexIsValid(ordinal);
     return BitSetMethods.isSet(baseObject, baseOffset, ordinal);


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 7416ddb..d1d89a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -77,7 +77,7 @@ object CatalystTypeConverters {
       case LongType => LongConverter
       case FloatType => FloatConverter
       case DoubleType => DoubleConverter
-      case _ => IdentityConverter
+      case dataType: DataType => IdentityConverter(dataType)
     }
     converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]]
   }
 
@@ -137,17 +137,19 @@ object CatalystTypeConverters {
     protected def toScalaImpl(row: InternalRow, column: Int): ScalaOutputType
   }
 
-  private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] {
+  private case class IdentityConverter(dataType: DataType)
+    extends CatalystTypeConverter[Any, Any, Any] {
     override def toCatalystImpl(scalaValue: Any): Any = scalaValue
     override def toScala(catalystValue: Any): Any = catalystValue
-    override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column)
+    override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column, dataType)
   }
 
   private case class UDTConverter(
       udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] {
     override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue)
     override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue)
-    override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row.get(column))
+    override def toScalaImpl(row: InternalRow, column: Int): Any =
+      toScala(row.get(column, udt.sqlType))
   }
 
   /** Converter for arrays, sequences, and Java iterables. */
@@ -184,7 +186,7 @@ object CatalystTypeConverters {
     }
 
     override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] =
-      toScala(row.get(column).asInstanceOf[Seq[Any]])
+      toScala(row.get(column, ArrayType(elementType)).asInstanceOf[Seq[Any]])
   }
 
   private case class MapConverter(
@@ -227,7 +229,7 @@ object CatalystTypeConverters {
     }
 
     override def toScalaImpl(row: InternalRow, column: Int): Map[Any, Any] =
-      toScala(row.get(column).asInstanceOf[Map[Any, Any]])
+      toScala(row.get(column, MapType(keyType, valueType)).asInstanceOf[Map[Any, Any]])
   }
 
   private case class StructConverter(
@@ -311,7 +313,7 @@ object CatalystTypeConverters {
     }
     override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal
     override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal =
-      row.get(column).asInstanceOf[Decimal].toJavaBigDecimal
+      row.getDecimal(column).toJavaBigDecimal
   }
 
   private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] {


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 37f0f57..385d967 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.Decimal
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -32,32 +32,36 @@ abstract class InternalRow extends Serializable {
 
   def get(ordinal: Int): Any
 
-  def getAs[T](ordinal: Int): T = get(ordinal).asInstanceOf[T]
+  def genericGet(ordinal: Int): Any = get(ordinal, null)
+
+  def get(ordinal: Int, dataType: DataType): Any = get(ordinal)
+
+  def getAs[T](ordinal: Int, dataType: DataType): T = get(ordinal, dataType).asInstanceOf[T]
 
   def isNullAt(ordinal: Int): Boolean = get(ordinal) == null
 
-  def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal)
+  def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal, BooleanType)
 
-  def getByte(ordinal: Int): Byte = getAs[Byte](ordinal)
+  def getByte(ordinal: Int): Byte = getAs[Byte](ordinal, ByteType)
 
-  def getShort(ordinal: Int): Short = getAs[Short](ordinal)
+  def getShort(ordinal: Int): Short = getAs[Short](ordinal, ShortType)
 
-  def getInt(ordinal: Int): Int = getAs[Int](ordinal)
+  def getInt(ordinal: Int): Int = getAs[Int](ordinal, IntegerType)
 
-  def getLong(ordinal: Int): Long = getAs[Long](ordinal)
+  def getLong(ordinal: Int): Long = getAs[Long](ordinal, LongType)
 
-  def getFloat(ordinal: Int): Float = getAs[Float](ordinal)
+  def getFloat(ordinal: Int): Float = getAs[Float](ordinal, FloatType)
 
-  def getDouble(ordinal: Int): Double = getAs[Double](ordinal)
+  def getDouble(ordinal: Int): Double = getAs[Double](ordinal, DoubleType)
 
-  def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal)
+  def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal, StringType)
 
-  def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal)
+  def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType)
 
-  def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal)
+  def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT)
 
   // This is only use for test and will throw a null pointer exception if the position is null.
-  def getString(ordinal: Int): String = getAs[UTF8String](ordinal).toString
+  def getString(ordinal: Int): String = getUTF8String(ordinal).toString
 
   /**
    * Returns a struct from ordinal position.
@@ -65,7 +69,7 @@ abstract class InternalRow extends Serializable {
    * @param ordinal position to get the struct from.
    * @param numFields number of fields the struct type has
    */
-  def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs[InternalRow](ordinal)
+  def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs[InternalRow](ordinal, null)
 
   override def toString: String = s"[${this.mkString(",")}]"


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 1f7adcd..6b5c450 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -49,7 +49,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
         case StringType => input.getUTF8String(ordinal)
         case BinaryType => input.getBinary(ordinal)
         case t: StructType => input.getStruct(ordinal, t.size)
-        case _ => input.get(ordinal)
+        case dataType => input.get(ordinal, dataType)
       }
     }
 }


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 47ad3e0..e5b83cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -375,7 +375,7 @@ case class Cast(child: Expression, dataType: DataType)
   }
 
   private[this] def castStruct(from: StructType, to: StructType): Any => Any = {
-    val casts = from.fields.zip(to.fields).map {
+    val castFuncs: Array[(Any) => Any] = from.fields.zip(to.fields).map {
       case (fromField, toField) => cast(fromField.dataType, toField.dataType)
     }
     // TODO: Could be faster?
@@ -383,7 +383,8 @@ case class Cast(child: Expression, dataType: DataType)
     buildCast[InternalRow](_, row => {
       var i = 0
       while (i < row.numFields) {
-        newRow.update(i, if (row.isNullAt(i)) null else casts(i)(row.get(i)))
+        newRow.update(i,
+          if (row.isNullAt(i)) null else castFuncs(i)(row.get(i, from.apply(i).dataType)))
         i += 1
       }
       newRow.copy()


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index c1ed9cf..cc89d74 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -225,6 +225,14 @@ class JoinedRow extends InternalRow {
   override def getFloat(i: Int): Float =
     if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields)
 
+  override def getStruct(i: Int, numFields: Int): InternalRow = {
+    if (i < row1.numFields) {
+      row1.getStruct(i, numFields)
+    } else {
+      row2.getStruct(i - row1.numFields, numFields)
+    }
+  }
+
   override def copy(): InternalRow = {
     val totalSize = row1.numFields + row2.numFields
     val copiedValues = new Array[Any](totalSize)


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 4b4833b..5953a09 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -221,6 +221,10 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
 
   override def get(i: Int): Any = values(i).boxed
 
+  override def getStruct(ordinal: Int, numFields: Int): InternalRow = {
+    values(ordinal).boxed.asInstanceOf[InternalRow]
+  }
+
   override def isNullAt(i: Int): Boolean = values(i).isNull
 
   override def copy(): InternalRow = {
@@ -245,8 +249,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
   override def setString(ordinal: Int, value: String): Unit =
     update(ordinal, UTF8String.fromString(value))
 
-  override def getString(ordinal: Int): String = get(ordinal).toString
-
   override def setInt(ordinal: Int, value: Int): Unit = {
     val currentValue = values(ordinal).asInstanceOf[MutableInt]
     currentValue.isNull = false
@@ -316,8 +318,4 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
   override def getByte(i: Int): Byte = {
     values(i).asInstanceOf[MutableByte].value
   }
-
-  override def getAs[T](i: Int): T = {
-    values(i).boxed.asInstanceOf[T]
-  }
 }


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 62b6cc8..42343d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -685,7 +685,7 @@ case class CombineSetsAndSumFunction(
       null
     } else {
       Cast(Literal(
-        casted.iterator.map(f => f.get(0)).reduceLeft(
+        casted.iterator.map(f => f.genericGet(0)).reduceLeft(
           base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
         base.dataType).eval(null)
     }


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 20b1eaa..119168f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index c91122c..6331a9e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -110,7 +110,7 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int)
   override def toString: String = s"$child.${field.name}"
 
   protected override def nullSafeEval(input: Any): Any =
-    input.asInstanceOf[InternalRow].get(ordinal)
+    input.asInstanceOf[InternalRow].get(ordinal, field.dataType)
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     nullSafeCodeGen(ctx, ev, eval => {
@@ -142,7 +142,7 @@ case class GetArrayStructFields(
 
   protected override def nullSafeEval(input: Any): Any = {
     input.asInstanceOf[Seq[InternalRow]].map { row =>
-      if (row == null) null else row.get(ordinal)
+      if (row == null) null else row.get(ordinal, field.dataType)
     }
   }


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 53779dd..daeabe8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -101,6 +101,10 @@ class GenericInternalRow(protected[sql] val values: Array[Any]) extends Internal
 
   override def get(i: Int): Any = values(i)
 
+  override def getStruct(ordinal: Int, numFields: Int): InternalRow = {
+    values(ordinal).asInstanceOf[InternalRow]
+  }
+
   override def copy(): InternalRow = this
 }
 
@@ -128,6 +132,10 @@ class GenericMutableRow(val values: Array[Any]) extends MutableRow {
 
   override def get(i: Int): Any = values(i)
 
+  override def getStruct(ordinal: Int, numFields: Int): InternalRow = {
+    values(ordinal).asInstanceOf[InternalRow]
+  }
+
   override def setNullAt(i: Int): Unit = { values(i) = null}
 
   override def update(i: Int, value: Any): Unit = { values(i) = value }


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 852a8b2..8b0f90c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -113,7 +113,7 @@ trait ExpressionEvalHelper {
       GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(),
       expression)
 
-    val actual = plan(inputRow).get(0)
+    val actual = plan(inputRow).get(0, expression.dataType)
     if (!checkResult(actual, expected)) {
       val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
       fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
@@ -194,13 +194,14 @@ trait ExpressionEvalHelper {
     var plan = generateProject(
       GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
       expression)
-    var actual = plan(inputRow).get(0)
+    var actual = plan(inputRow).get(0, expression.dataType)
     assert(checkResult(actual, expected))
 
     plan = generateProject(
       GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
       expression)
-    actual = FromUnsafeProjection(expression.dataType :: Nil)(plan(inputRow)).get(0)
+    actual = FromUnsafeProjection(expression.dataType :: Nil)(
+      plan(inputRow)).get(0, expression.dataType)
     assert(checkResult(actual, expected))
   }
 }


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
index 6caf8ba..21459a7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala
@@ -158,7 +158,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(),
       expression)
 
-    val actual = plan(inputRow).get(0)
+    val actual = plan(inputRow).get(0, expression.dataType)
     if (!actual.asInstanceOf[Double].isNaN) {
       fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: NaN")
     }


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 4606bcb..2834b54 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -183,7 +183,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6))
     assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7))
     assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8))
-    assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.get(9))
+    assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.getBinary(9))
     // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10))
     // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
index 6ee833c..c808442 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -288,7 +288,7 @@ private[sql] object SparkSqlSerializer2 {
               out.writeByte(NULL)
             } else {
               out.writeByte(NOT_NULL)
-              val bytes = row.getAs[Array[Byte]](i)
+              val bytes = row.getBinary(i)
               out.writeInt(bytes.length)
               out.write(bytes)
             }
@@ -298,7 +298,7 @@ private[sql] object SparkSqlSerializer2 {
               out.writeByte(NULL)
             } else {
               out.writeByte(NOT_NULL)
-              val value = row.getAs[Decimal](i)
+              val value = row.getDecimal(i)
               val javaBigDecimal = value.toJavaBigDecimal
               // First, write out the unscaled value.
               val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index fdd7ad5..fe429d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -17,16 +17,16 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.types.StructType
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.collection.ExternalSorter
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
 import org.apache.spark.util.{CompletionIterator, MutablePair}


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index cdbe423..6b91e51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -189,13 +189,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
       if (i != -1) {
         // If yes, gets column value from partition values.
         (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => {
-          mutableRow(ordinal) = partitionValues.get(i)
+          mutableRow(ordinal) = partitionValues.genericGet(i)
         }
       } else {
         // Otherwise, inherits the value from scanned data.
         val i = nonPartitionColumns.indexOf(name)
         (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => {
-          mutableRow(ordinal) = dataRow.get(i)
+          mutableRow(ordinal) = dataRow.genericGet(i)
         }
       }
     }


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 1fdcc6a..aeeb0e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -136,7 +136,7 @@ package object debug {
         tupleCount += 1
         var i = 0
         while (i < numColumns) {
-          val value = currentRow.get(i)
+          val value = currentRow.get(i, output(i).dataType)
           if (value != null) {
             columnStats(i).elementTypes += HashSet(value.getClass.getName)
           }


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index 970c40d..ec084a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -129,7 +129,7 @@ object EvaluatePython {
       val values = new Array[Any](row.numFields)
       var i = 0
       while (i < row.numFields) {
-        values(i) = toJava(row.get(i), struct.fields(i).dataType)
+        values(i) = toJava(row.get(i, struct.fields(i).dataType), struct.fields(i).dataType)
         i += 1
       }
       new GenericInternalRowWithSchema(values, struct)


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index ec5c695..78da284 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.{Map => MutableMap}
 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.{ArrayType, StructField, StructType}
+import org.apache.spark.sql.types.{DataType, ArrayType, StructField, StructType}
 import org.apache.spark.sql.{Column, DataFrame}
 
 private[sql] object FrequentItems extends Logging {
@@ -85,17 +85,17 @@ private[sql] object FrequentItems extends Logging {
     val sizeOfMap = (1 / support).toInt
     val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap))
     val originalSchema = df.schema
-    val colInfo = cols.map { name =>
+    val colInfo: Array[(String, DataType)] = cols.map { name =>
       val index = originalSchema.fieldIndex(name)
       (name, originalSchema.fields(index).dataType)
-    }
+    }.toArray
 
     val freqItems = df.select(cols.map(Column(_)) : _*).queryExecution.toRdd.aggregate(countMaps)(
       seqOp = (counts, row) => {
         var i = 0
         while (i < numCols) {
           val thisMap = counts(i)
-          val key = row.get(i)
+          val key = row.get(i, colInfo(i)._2)
           thisMap.add(key, 1L)
           i += 1
         }


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
index 7a6e867..4ada9ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
@@ -110,6 +110,7 @@ private[sql] abstract class AggregationBuffer(
  * A Mutable [[Row]] representing an mutable aggregation buffer.
  */
 class MutableAggregationBuffer private[sql] (
+    schema: StructType,
     toCatalystConverters: Array[Any => Any],
     toScalaConverters: Array[Any => Any],
     bufferOffset: Int,
@@ -121,7 +122,7 @@ class MutableAggregationBuffer private[sql] (
       throw new IllegalArgumentException(
         s"Could not access ${i}th value in this buffer because it only has $length values.")
     }
-    toScalaConverters(i)(underlyingBuffer.get(offsets(i)))
+    toScalaConverters(i)(underlyingBuffer.get(offsets(i), schema(i).dataType))
   }
 
   def update(i: Int, value: Any): Unit = {
@@ -134,6 +135,7 @@ class MutableAggregationBuffer private[sql] (
 
   override def copy(): MutableAggregationBuffer = {
     new MutableAggregationBuffer(
+      schema,
       toCatalystConverters,
       toScalaConverters,
       bufferOffset,
@@ -145,6 +147,7 @@ class MutableAggregationBuffer private[sql] (
  * A [[Row]] representing an immutable aggregation buffer.
  */
 class InputAggregationBuffer private[sql] (
+    schema: StructType,
     toCatalystConverters: Array[Any => Any],
     toScalaConverters: Array[Any => Any],
     bufferOffset: Int,
@@ -157,11 +160,12 @@ class InputAggregationBuffer private[sql] (
       throw new IllegalArgumentException(
         s"Could not access ${i}th value in this buffer because it only has $length values.")
     }
     // TODO: Use buffer schema to avoid using generic getter.
-    toScalaConverters(i)(underlyingInputBuffer.get(offsets(i)))
+    toScalaConverters(i)(underlyingInputBuffer.get(offsets(i), schema(i).dataType))
   }
 
   override def copy(): InputAggregationBuffer = {
     new InputAggregationBuffer(
+      schema,
       toCatalystConverters,
       toScalaConverters,
       bufferOffset,
@@ -233,6 +237,7 @@ case class ScalaUDAF(
 
   lazy val inputAggregateBuffer: InputAggregationBuffer =
     new InputAggregationBuffer(
+      bufferSchema,
       bufferValuesToCatalystConverters,
       bufferValuesToScalaConverters,
       bufferOffset,
@@ -240,6 +245,7 @@ case class ScalaUDAF(
 
   lazy val mutableAggregateBuffer: MutableAggregationBuffer =
     new MutableAggregationBuffer(
+      bufferSchema,
       bufferValuesToCatalystConverters,
       bufferValuesToScalaConverters,
       bufferOffset,


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 38bb1e3..75cbbde 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -179,7 +179,7 @@ private[sql] case class ParquetTableScan(
 
           var i = 0
           while (i < row.numFields) {
-            mutableRow(i) = row.get(i)
+            mutableRow(i) = row.genericGet(i)
             i += 1
           }
           // Parquet will leave partitioning columns empty, so we fill them in here.


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index 2c23d4e..7b6a7f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -219,7 +219,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
       // null values indicate optional fields but we do not check currently
       if (!record.isNullAt(index)) {
         writer.startField(attributes(index).name, index)
-        writeValue(attributes(index).dataType, record.get(index))
+        writeValue(attributes(index).dataType, record.get(index, attributes(index).dataType))
         writer.endField(attributes(index).name, index)
       }
       index = index + 1
@@ -280,7 +280,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
     while(i < fields.length) {
       if (!struct.isNullAt(i)) {
         writer.startField(fields(i).name, i)
-        writeValue(fields(i).dataType, struct.get(i))
+        writeValue(fields(i).dataType, struct.get(i, fields(i).dataType))
         writer.endField(fields(i).name, i)
       }
       i = i + 1


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index c6804e8..01b7c21 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -30,23 +30,24 @@ class RowSuite extends SparkFunSuite {
 
   test("create row") {
     val expected = new GenericMutableRow(4)
-    expected.update(0, 2147483647)
+    expected.setInt(0, 2147483647)
     expected.setString(1, "this is a string")
-    expected.update(2, false)
-    expected.update(3, null)
+    expected.setBoolean(2, false)
+    expected.setNullAt(3)
+
     val actual1 = Row(2147483647, "this is a string", false, null)
     assert(expected.numFields === actual1.size)
     assert(expected.getInt(0) === actual1.getInt(0))
     assert(expected.getString(1) === actual1.getString(1))
     assert(expected.getBoolean(2) === actual1.getBoolean(2))
-    assert(expected.get(3) === actual1.get(3))
+    assert(expected.isNullAt(3) === actual1.isNullAt(3))
 
     val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null))
     assert(expected.numFields === actual2.size)
     assert(expected.getInt(0) === actual2.getInt(0))
     assert(expected.getString(1) === actual2.getString(1))
     assert(expected.getBoolean(2) === actual2.getBoolean(2))
-    assert(expected.get(3) === actual2.get(3))
+    assert(expected.isNullAt(3) === actual2.isNullAt(3))
   }
 
   test("SpecificMutableRow.update with null") {


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index f0e0ca0..e4944ca 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.execution.{UnaryNode, SparkPlan}
 import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
 import org.apache.spark.sql.hive._
+import org.apache.spark.sql.types.DataType
 import org.apache.spark.{SparkException, TaskContext}
 
 import scala.collection.JavaConversions._
@@ -96,13 +97,14 @@ case class InsertIntoHiveTable(
     val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray
     val wrappers = fieldOIs.map(wrapperFor)
     val outputData = new Array[Any](fieldOIs.length)
+    val dataTypes: Array[DataType] = child.output.map(_.dataType).toArray
 
     writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)
 
     iterator.foreach { row =>
       var i = 0
       while (i < fieldOIs.length) {
-        outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i))
+        outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
         i += 1
       }


http://git-wip-us.apache.org/repos/asf/spark/blob/4a01bfc2/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index 5844509..924f4d3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -122,7 +122,7 @@ private[orc] class OrcOutputWriter(
   override def writeInternal(row: InternalRow): Unit = {
     var i = 0
     while (i < row.numFields) {
-      reusableOutputBuffer(i) = wrappers(i)(row.get(i))
+      reusableOutputBuffer(i) = wrappers(i)(row.get(i, dataSchema(i).dataType))
       i += 1
    }
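
For readers skimming the diff, the heart of the change is the calling convention: the untyped row.get(i) / getAs[T](i) reads become row.get(i, dataType) / getAs[T](i, dataType), with genericGet(i) kept for the few callers that genuinely lack type information. The following minimal, self-contained Scala sketch shows why passing the DataType matters; SketchRow and everything in it are invented for illustration and are not Spark code. The point is the dispatch: given the type, a row implementation can route a "generic" read to a type-specific getter, which is exactly what a format without boxed per-field storage (like UnsafeRow, which reads fields out of raw memory) needs.

sealed trait DataType
case object IntegerType extends DataType
case object StringType extends DataType

trait SketchRow {
  // Type-erased storage; a purely generic row can always answer genericGet.
  protected def values: Array[Any]

  // Old-style untyped getter: only works when fields are stored as boxed objects.
  def genericGet(ordinal: Int): Any = values(ordinal)

  // New-style generic getter: the caller supplies the DataType, so the
  // implementation can dispatch to a typed read. An UnsafeRow-like class
  // would override getInt/getString to decode raw bytes instead of unboxing.
  def get(ordinal: Int, dataType: DataType): Any = dataType match {
    case IntegerType => getInt(ordinal)
    case StringType  => getString(ordinal)
  }

  def getAs[T](ordinal: Int, dataType: DataType): T =
    get(ordinal, dataType).asInstanceOf[T]

  def getInt(ordinal: Int): Int = values(ordinal).asInstanceOf[Int]
  def getString(ordinal: Int): String = values(ordinal).asInstanceOf[String]
}

object SketchRowDemo extends App {
  val row = new SketchRow {
    protected val values: Array[Any] = Array(42, "spark")
  }
  println(row.getAs[Int](0, IntegerType))    // prints 42
  println(row.getAs[String](1, StringType))  // prints spark
}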

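The call-site pattern repeated across the hunks above (InsertIntoHiveTable, FrequentItems, OrcRelation, debug) is the complementary half of the change: resolve each column's DataType from the schema once, outside the hot loop, then pass it on every read. A hedged sketch of that pattern, reusing the SketchRow trait from the previous example; FieldSpec and processRow are likewise invented names, not Spark APIs.

case class FieldSpec(name: String, dataType: DataType)

// Resolve per-column DataTypes up front, then hand each one to the typed
// generic getter inside the row-processing loop, mirroring the shape of
// `row.get(i, dataTypes(i))` in the InsertIntoHiveTable hunk.
def processRow(row: SketchRow, schema: Array[FieldSpec]): Array[Any] = {
  val out = new Array[Any](schema.length)
  var i = 0
  while (i < schema.length) {  // while loop, as in the Spark hunks, to avoid iterator overhead
    out(i) = row.get(i, schema(i).dataType)
    i += 1
  }
  out
}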