Repository: spark
Updated Branches:
  refs/heads/branch-1.6 f9aeb961e -> 9bf988555
[SPARK-11564][SQL][FOLLOW-UP] clean up java tuple encoder

We need to support custom classes like Java beans and to combine them into
tuples, which is very hard to do with the TypeTag-based approach. We should
keep only the compose-based way to create tuple encoders. This PR also moves
`Encoder` to `org.apache.spark.sql`.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #9567 from cloud-fan/java.

(cherry picked from commit ec2b807212e568c9e98cd80746bcb61e02c7a98e)
Signed-off-by: Michael Armbrust <mich...@databricks.com>

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9bf98855
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9bf98855
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9bf98855

Branch: refs/heads/branch-1.6
Commit: 9bf988555ac94daba72b432129cafb22fd39c95c
Parents: f9aeb96
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Wed Nov 11 10:52:23 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Wed Nov 11 10:52:37 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Encoder.scala    | 131 +++++++++++++
 .../spark/sql/catalyst/encoders/Encoder.scala   | 182 -------------------
 .../catalyst/encoders/ExpressionEncoder.scala   |  10 +-
 .../spark/sql/catalyst/encoders/package.scala   |   3 +-
 .../catalyst/plans/logical/basicOperators.scala |   1 +
 .../scala/org/apache/spark/sql/Column.scala     |   2 +-
 .../scala/org/apache/spark/sql/DataFrame.scala  |   2 -
 .../org/apache/spark/sql/GroupedDataset.scala   |   2 +-
 .../scala/org/apache/spark/sql/SQLContext.scala |   2 +-
 .../aggregate/TypedAggregateExpression.scala    |   3 +-
 .../spark/sql/expressions/Aggregator.scala      |   3 +-
 .../scala/org/apache/spark/sql/functions.scala  |   2 +-
 .../org/apache/spark/sql/JavaDatasetSuite.java  |  78 ++++----
 .../spark/sql/DatasetAggregatorSuite.scala      |   4 +-
 .../scala/org/apache/spark/sql/QueryTest.scala  |   1 -
 15 files changed, 189 insertions(+), 237 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9bf98855/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
new file mode 100644
index 0000000..1ff7340
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+import scala.reflect.ClassTag
+
+/**
+ * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
+ *
+ * Encoders are not intended to be thread-safe and thus they are allowed to avoid internal locking
+ * and reuse internal buffers to improve performance.
+ */
+trait Encoder[T] extends Serializable {
+
+  /** Returns the schema of encoding this type of object as a Row. */
+  def schema: StructType
+
+  /** A ClassTag that can be used to construct an Array to contain a collection of `T`. */
+  def clsTag: ClassTag[T]
+}
+
+object Encoders {
+  def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
+  def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
+  def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
+  def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true)
+  def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true)
+  def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true)
+  def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
+  def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)
+
+  def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = {
+    tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]]))
+      .asInstanceOf[ExpressionEncoder[(T1, T2)]]
+  }
+
+  def tuple[T1, T2, T3](
+      enc1: Encoder[T1],
+      enc2: Encoder[T2],
+      enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
+    tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]]))
+      .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
+  }
+
+  def tuple[T1, T2, T3, T4](
+      enc1: Encoder[T1],
+      enc2: Encoder[T2],
+      enc3: Encoder[T3],
+      enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
+    tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]]))
+      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
+  }
+
+  def tuple[T1, T2, T3, T4, T5](
+      enc1: Encoder[T1],
+      enc2: Encoder[T2],
+      enc3: Encoder[T3],
+      enc4: Encoder[T4],
+      enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
+    tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]]))
+      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
+  }
+
+  private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+    assert(encoders.length > 1)
+    // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
+    assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty))
+
+    val schema = StructType(encoders.zipWithIndex.map {
+      case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
+    })
+
+    val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
+
+    val extractExpressions = encoders.map {
+      case e if e.flat => e.extractExpressions.head
+      case other => CreateStruct(other.extractExpressions)
+    }.zipWithIndex.map { case (expr, index) =>
+      expr.transformUp {
+        case BoundReference(0, t: ObjectType, _) =>
+          Invoke(
+            BoundReference(0, ObjectType(cls), nullable = true),
+            s"_${index + 1}",
+            t)
+      }
+    }
+
+    val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
+      if (enc.flat) {
+        enc.constructExpression.transform {
+          case b: BoundReference => b.copy(ordinal = index)
+        }
+      } else {
+        enc.constructExpression.transformUp {
+          case BoundReference(ordinal, dt, _) =>
+            GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
+        }
+      }
+    }
+
+    val constructExpression =
+      NewInstance(cls, constructExpressions, propagateNull = false, ObjectType(cls))
+
+    new ExpressionEncoder[Any](
+      schema,
+      false,
+      extractExpressions,
+      constructExpression,
+      ClassTag.apply(cls))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf98855/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
deleted file mode 100644
index 6569b90..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ /dev/null
@@ -1,182 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.util.Utils
-import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
-import org.apache.spark.sql.catalyst.expressions._
-
-/**
- * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
- *
- * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking
- * and reuse internal buffers to improve performance.
- */
-trait Encoder[T] extends Serializable {
-
-  /** Returns the schema of encoding this type of object as a Row. */
-  def schema: StructType
-
-  /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */
-  def clsTag: ClassTag[T]
-}
-
-object Encoder {
-  import scala.reflect.runtime.universe._
-
-  def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
-  def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
-  def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
-  def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true)
-  def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true)
-  def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true)
-  def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
-  def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)
-
-  def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = {
-    tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]]))
-      .asInstanceOf[ExpressionEncoder[(T1, T2)]]
-  }
-
-  def tuple[T1, T2, T3](
-      enc1: Encoder[T1],
-      enc2: Encoder[T2],
-      enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
-    tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]]))
-      .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
-  }
-
-  def tuple[T1, T2, T3, T4](
-      enc1: Encoder[T1],
-      enc2: Encoder[T2],
-      enc3: Encoder[T3],
-      enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
-    tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]]))
-      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
-  }
-
-  def tuple[T1, T2, T3, T4, T5](
-      enc1: Encoder[T1],
-      enc2: Encoder[T2],
-      enc3: Encoder[T3],
-      enc4: Encoder[T4],
-      enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
-    tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]]))
-      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
-  }
-
-  private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
-    assert(encoders.length > 1)
-    // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
-    assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty))
-
-    val schema = StructType(encoders.zipWithIndex.map {
-      case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
-    })
-
-    val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
-
-    val extractExpressions = encoders.map {
-      case e if e.flat => e.extractExpressions.head
-      case other => CreateStruct(other.extractExpressions)
-    }.zipWithIndex.map { case (expr, index) =>
-      expr.transformUp {
-        case BoundReference(0, t: ObjectType, _) =>
-          Invoke(
-            BoundReference(0, ObjectType(cls), nullable = true),
-            s"_${index + 1}",
-            t)
-      }
-    }
-
-    val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
-      if (enc.flat) {
-        enc.constructExpression.transform {
-          case b: BoundReference => b.copy(ordinal = index)
-        }
-      } else {
-        enc.constructExpression.transformUp {
-          case BoundReference(ordinal, dt, _) =>
-            GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
-        }
-      }
-    }
-
-    val constructExpression =
-      NewInstance(cls, constructExpressions, propagateNull = false, ObjectType(cls))
-
-    new ExpressionEncoder[Any](
-      schema,
-      false,
-      extractExpressions,
-      constructExpression,
-      ClassTag.apply(cls))
-  }
-
-  def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)]
-
-  private def getTypeTag[T](c: Class[T]): TypeTag[T] = {
-    import scala.reflect.api
-
-    // val mirror = runtimeMirror(c.getClassLoader)
-    val mirror = rootMirror
-    val sym = mirror.staticClass(c.getName)
-    val tpe = sym.selfType
-    TypeTag(mirror, new api.TypeCreator {
-      def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) =
-        if (m eq mirror) tpe.asInstanceOf[U # Type]
-        else throw new IllegalArgumentException(
-          s"Type tag defined in $mirror cannot be migrated to other mirrors.")
-    })
-  }
-
-  def forTuple[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    ExpressionEncoder[(T1, T2)]()
-  }
-
-  def forTuple[T1, T2, T3](c1: Class[T1], c2: Class[T2], c3: Class[T3]): Encoder[(T1, T2, T3)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    implicit val typeTag3 = getTypeTag(c3)
-    ExpressionEncoder[(T1, T2, T3)]()
-  }
-
-  def forTuple[T1, T2, T3, T4](
-      c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4]): Encoder[(T1, T2, T3, T4)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    implicit val typeTag3 = getTypeTag(c3)
-    implicit val typeTag4 = getTypeTag(c4)
-    ExpressionEncoder[(T1, T2, T3, T4)]()
-  }
-
-  def forTuple[T1, T2, T3, T4, T5](
-      c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4], c5: Class[T5])
-    : Encoder[(T1, T2, T3, T4, T5)] = {
-    implicit val typeTag1 = getTypeTag(c1)
-    implicit val typeTag2 = getTypeTag(c2)
-    implicit val typeTag3 = getTypeTag(c3)
-    implicit val typeTag4 = getTypeTag(c4)
-    implicit val typeTag5 = getTypeTag(c5)
-    ExpressionEncoder[(T1, T2, T3, T4, T5)]()
-  }
-}
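The `forTuple` overloads and reflective `getTypeTag` machinery deleted above were the TypeTag-based path; the compose-based path this patch keeps builds the same encoders out of existing ones via `Encoders.tuple`, as the `JavaDatasetSuite` changes further down exercise. A minimal standalone sketch of that usage from Java follows; the local-mode `SparkConf`/`SQLContext` setup and the class name are illustrative assumptions, while `Encoders.tuple` and `createDataset` come straight from this patch:

import java.util.Arrays;
import java.util.List;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SQLContext;

public class TupleEncoderExample {
  public static void main(String[] args) {
    // Illustrative local-mode setup; any SQLContext works the same way.
    JavaSparkContext jsc = new JavaSparkContext(
      new SparkConf().setMaster("local").setAppName("tuple-encoder-example"));
    SQLContext sqlContext = new SQLContext(jsc);

    // Compose primitive encoders into a tuple encoder. This replaces the
    // removed TypeTag-based Encoder.forTuple(Integer.class, String.class).
    Encoder<Tuple2<Integer, String>> encoder =
      Encoders.tuple(Encoders.INT(), Encoders.STRING());

    List<Tuple2<Integer, String>> data =
      Arrays.asList(new Tuple2<>(1, "a"), new Tuple2<>(2, "b"));
    Dataset<Tuple2<Integer, String>> ds = sqlContext.createDataset(data, encoder);

    // Round-trips through the internal row format: prints [(1,a), (2,b)].
    System.out.println(ds.collectAsList());

    jsc.stop();
  }
}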
http://git-wip-us.apache.org/repos/asf/spark/blob/9bf98855/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 005c062..294afde 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -17,18 +17,18 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
-import org.apache.spark.util.Utils
-
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{typeTag, TypeTag}
 
+import org.apache.spark.util.Utils
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType}
+import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
 
 /**
  * A factory for constructing encoders that convert objects and primitives to and from the

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf98855/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
index d4642a5..2c35adc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.sql.catalyst
 
+import org.apache.spark.sql.Encoder
+
 package object encoders {
   private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
     case e: ExpressionEncoder[A] => e
     case _ => sys.error(s"Only expression encoders are supported today")
   }
 }
-

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf98855/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 764f8aa..597f03e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf98855/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index d26b6c3..f0f275e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.DataTypeParser
 import org.apache.spark.sql.types._

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf98855/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 691b476..a492099 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -23,7 +23,6 @@ import java.util.Properties
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
-import scala.util.control.NonFatal
 
 import com.fasterxml.jackson.core.JsonFactory
 import org.apache.commons.lang3.StringUtils
@@ -35,7 +34,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.encoders.Encoder
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf98855/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index db61499..61e2a95 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
 import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf98855/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 1cf1e30..cd1fdc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
 import org.apache.spark.sql.SQLConf.SQLConfEntry
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.errors.DialectException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf98855/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index b5a87c5..dfcbac8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -20,9 +20,10 @@ package org.apache.spark.sql.execution.aggregate
 import scala.language.existentials
 
 import org.apache.spark.Logging
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf98855/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 2aa5a7d..360c9a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.expressions
 
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder}
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn}

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf98855/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index a59d738..ab49ed4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -26,7 +26,7 @@ import scala.util.Try
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, Encoder}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
http://git-wip-us.apache.org/repos/asf/spark/blob/9bf98855/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 2da63d1..33d8388 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -30,8 +30,8 @@ import org.apache.spark.Accumulator;
 import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.function.*;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.catalyst.encoders.Encoder;
-import org.apache.spark.sql.catalyst.encoders.Encoder$;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.GroupedDataset;
 import org.apache.spark.sql.test.TestSQLContext;
@@ -41,7 +41,6 @@ import static org.apache.spark.sql.functions.*;
 public class JavaDatasetSuite implements Serializable {
   private transient JavaSparkContext jsc;
   private transient TestSQLContext context;
-  private transient Encoder$ e = Encoder$.MODULE$;
 
   @Before
   public void setUp() {
@@ -66,7 +65,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testCollect() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     List<String> collected = ds.collectAsList();
     Assert.assertEquals(Arrays.asList("hello", "world"), collected);
   }
@@ -74,7 +73,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testTake() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     List<String> collected = ds.takeAsList(1);
     Assert.assertEquals(Arrays.asList("hello"), collected);
   }
@@ -82,7 +81,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testCommonOperation() {
     List<String> data = Arrays.asList("hello", "world");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     Assert.assertEquals("hello", ds.first());
 
     Dataset<String> filtered = ds.filter(new FilterFunction<String>() {
@@ -99,7 +98,7 @@ public class JavaDatasetSuite implements Serializable {
       public Integer call(String v) throws Exception {
         return v.length();
       }
-    }, e.INT());
+    }, Encoders.INT());
 
     Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList());
 
     Dataset<String> parMapped = ds.mapPartitions(new MapPartitionsFunction<String, String>() {
@@ -111,7 +110,7 @@ public class JavaDatasetSuite implements Serializable {
         }
         return ls;
       }
-    }, e.STRING());
+    }, Encoders.STRING());
 
     Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList());
 
     Dataset<String> flatMapped = ds.flatMap(new FlatMapFunction<String, String>() {
@@ -123,7 +122,7 @@ public class JavaDatasetSuite implements Serializable {
         }
         return ls;
       }
-    }, e.STRING());
+    }, Encoders.STRING());
 
     Assert.assertEquals(
       Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"),
       flatMapped.collectAsList());
@@ -133,7 +132,7 @@ public class JavaDatasetSuite implements Serializable {
   public void testForeach() {
     final Accumulator<Integer> accum = jsc.accumulator(0);
     List<String> data = Arrays.asList("a", "b", "c");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
 
     ds.foreach(new ForeachFunction<String>() {
       @Override
@@ -147,7 +146,7 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testReduce() {
     List<Integer> data = Arrays.asList(1, 2, 3);
-    Dataset<Integer> ds = context.createDataset(data, e.INT());
+    Dataset<Integer> ds = context.createDataset(data, Encoders.INT());
 
     int reduced = ds.reduce(new ReduceFunction<Integer>() {
       @Override
@@ -161,13 +160,13 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testGroupBy() {
     List<String> data = Arrays.asList("a", "foo", "bar");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     GroupedDataset<Integer, String> grouped = ds.groupBy(new MapFunction<String, Integer>() {
       @Override
       public Integer call(String v) throws Exception {
         return v.length();
       }
-    }, e.INT());
+    }, Encoders.INT());
 
     Dataset<String> mapped = grouped.map(new MapGroupFunction<Integer, String, String>() {
       @Override
@@ -178,7 +177,7 @@ public class JavaDatasetSuite implements Serializable {
         }
         return sb.toString();
       }
-    }, e.STRING());
+    }, Encoders.STRING());
 
     Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
 
@@ -193,27 +192,27 @@ public class JavaDatasetSuite implements Serializable {
           return Collections.singletonList(sb.toString());
         }
       },
-      e.STRING());
+      Encoders.STRING());
 
     Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList());
 
     List<Integer> data2 = Arrays.asList(2, 6, 10);
-    Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
+    Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
     GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
       @Override
       public Integer call(Integer v) throws Exception {
         return v / 2;
       }
-    }, e.INT());
+    }, Encoders.INT());
 
     Dataset<String> cogrouped = grouped.cogroup(
       grouped2,
       new CoGroupFunction<Integer, String, Integer, String>() {
        @Override
        public Iterable<String> call(
-           Integer key,
-           Iterator<String> left,
-           Iterator<Integer> right) throws Exception {
+          Integer key,
+          Iterator<String> left,
+          Iterator<Integer> right) throws Exception {
         StringBuilder sb = new StringBuilder(key.toString());
         while (left.hasNext()) {
           sb.append(left.next());
@@ -225,7 +224,7 @@ public class JavaDatasetSuite implements Serializable {
           return Collections.singletonList(sb.toString());
         }
       },
-      e.STRING());
+      Encoders.STRING());
 
     Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList());
   }
@@ -233,8 +232,9 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testGroupByColumn() {
     List<String> data = Arrays.asList("a", "foo", "bar");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
-    GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+    GroupedDataset<Integer, String> grouped =
+      ds.groupBy(length(col("value"))).asKey(Encoders.INT());
 
     Dataset<String> mapped = grouped.map(
       new MapGroupFunction<Integer, String, String>() {
@@ -247,7 +247,7 @@ public class JavaDatasetSuite implements Serializable {
           return sb.toString();
         }
       },
-      e.STRING());
+      Encoders.STRING());
 
     Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
   }
@@ -255,11 +255,11 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testSelect() {
     List<Integer> data = Arrays.asList(2, 6);
-    Dataset<Integer> ds = context.createDataset(data, e.INT());
+    Dataset<Integer> ds = context.createDataset(data, Encoders.INT());
 
     Dataset<Tuple2<Integer, String>> selected = ds.select(
       expr("value + 1"),
-      col("value").cast("string")).as(e.tuple(e.INT(), e.STRING()));
+      col("value").cast("string")).as(Encoders.tuple(Encoders.INT(), Encoders.STRING()));
 
     Assert.assertEquals(
       Arrays.asList(tuple2(3, "2"), tuple2(7, "6")),
@@ -269,14 +269,14 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testSetOperation() {
     List<String> data = Arrays.asList("abc", "abc", "xyz");
-    Dataset<String> ds = context.createDataset(data, e.STRING());
+    Dataset<String> ds = context.createDataset(data, Encoders.STRING());
 
     Assert.assertEquals(
       Arrays.asList("abc", "xyz"),
       sort(ds.distinct().collectAsList().toArray(new String[0])));
 
     List<String> data2 = Arrays.asList("xyz", "foo", "foo");
-    Dataset<String> ds2 = context.createDataset(data2, e.STRING());
+    Dataset<String> ds2 = context.createDataset(data2, Encoders.STRING());
 
     Dataset<String> intersected = ds.intersect(ds2);
     Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList());
@@ -298,9 +298,9 @@ public class JavaDatasetSuite implements Serializable {
   @Test
   public void testJoin() {
     List<Integer> data = Arrays.asList(1, 2, 3);
-    Dataset<Integer> ds = context.createDataset(data, e.INT()).as("a");
+    Dataset<Integer> ds = context.createDataset(data, Encoders.INT()).as("a");
     List<Integer> data2 = Arrays.asList(2, 3, 4);
-    Dataset<Integer> ds2 = context.createDataset(data2, e.INT()).as("b");
+    Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()).as("b");
 
     Dataset<Tuple2<Integer, Integer>> joined =
       ds.joinWith(ds2, col("a.value").equalTo(col("b.value")));
@@ -311,26 +311,28 @@ public class JavaDatasetSuite implements Serializable {
 
   @Test
   public void testTupleEncoder() {
-    Encoder<Tuple2<Integer, String>> encoder2 = e.tuple(e.INT(), e.STRING());
+    Encoder<Tuple2<Integer, String>> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING());
     List<Tuple2<Integer, String>> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b"));
     Dataset<Tuple2<Integer, String>> ds2 = context.createDataset(data2, encoder2);
     Assert.assertEquals(data2, ds2.collectAsList());
 
-    Encoder<Tuple3<Integer, Long, String>> encoder3 = e.tuple(e.INT(), e.LONG(), e.STRING());
+    Encoder<Tuple3<Integer, Long, String>> encoder3 =
+      Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING());
     List<Tuple3<Integer, Long, String>> data3 =
       Arrays.asList(new Tuple3<Integer, Long, String>(1, 2L, "a"));
     Dataset<Tuple3<Integer, Long, String>> ds3 = context.createDataset(data3, encoder3);
     Assert.assertEquals(data3, ds3.collectAsList());
 
     Encoder<Tuple4<Integer, String, Long, String>> encoder4 =
-      e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING());
+      Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING());
     List<Tuple4<Integer, String, Long, String>> data4 =
      Arrays.asList(new Tuple4<Integer, String, Long, String>(1, "b", 2L, "a"));
     Dataset<Tuple4<Integer, String, Long, String>> ds4 = context.createDataset(data4, encoder4);
     Assert.assertEquals(data4, ds4.collectAsList());
 
     Encoder<Tuple5<Integer, String, Long, String, Boolean>> encoder5 =
-      e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING(), e.BOOLEAN());
+      Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING(),
+        Encoders.BOOLEAN());
     List<Tuple5<Integer, String, Long, String, Boolean>> data5 =
       Arrays.asList(new Tuple5<Integer, String, Long, String, Boolean>(1, "b", 2L, "a", true));
     Dataset<Tuple5<Integer, String, Long, String, Boolean>> ds5 =
@@ -342,7 +344,7 @@ public class JavaDatasetSuite implements Serializable {
   public void testNestedTupleEncoder() {
     // test ((int, string), string)
     Encoder<Tuple2<Tuple2<Integer, String>, String>> encoder =
-      e.tuple(e.tuple(e.INT(), e.STRING()), e.STRING());
+      Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING());
     List<Tuple2<Tuple2<Integer, String>, String>> data =
       Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b"));
     Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = context.createDataset(data, encoder);
@@ -350,7 +352,8 @@ public class JavaDatasetSuite implements Serializable {
 
     // test (int, (string, string, long))
     Encoder<Tuple2<Integer, Tuple3<String, String, Long>>> encoder2 =
-      e.tuple(e.INT(), e.tuple(e.STRING(), e.STRING(), e.LONG()));
+      Encoders.tuple(Encoders.INT(),
+        Encoders.tuple(Encoders.STRING(), Encoders.STRING(), Encoders.LONG()));
     List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 =
       Arrays.asList(tuple2(1, new Tuple3<String, String, Long>("a", "b", 3L)));
     Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 =
@@ -359,7 +362,8 @@ public class JavaDatasetSuite implements Serializable {
 
     // test (int, ((string, long), string))
     Encoder<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> encoder3 =
-      e.tuple(e.INT(), e.tuple(e.tuple(e.STRING(), e.LONG()), e.STRING()));
+      Encoders.tuple(Encoders.INT(),
+        Encoders.tuple(Encoders.tuple(Encoders.STRING(), Encoders.LONG()), Encoders.STRING()));
     List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 =
       Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b")));
     Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 =

http://git-wip-us.apache.org/repos/asf/spark/blob/9bf98855/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index d4f0ab7..378cd36 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -17,13 +17,11 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.catalyst.encoders.Encoder
-import org.apache.spark.sql.functions._
 
 import scala.language.postfixOps
 
 import org.apache.spark.sql.test.SharedSQLContext
-
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.expressions.Aggregator
 
 /** An `Aggregator` that adds up any numeric type returned by the given function. */
http://git-wip-us.apache.org/repos/asf/spark/blob/9bf98855/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 3c174ef..7a8b7ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -24,7 +24,6 @@ import scala.collection.JavaConverters._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.catalyst.encoders.Encoder
 
 abstract class QueryTest extends PlanTest {
 

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org