This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new bde7aa61ce3  [SPARK-44613][CONNECT] Add Encoders object
bde7aa61ce3 is described below

commit bde7aa61ce3de15323a8920e8114a681fcd17000
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Tue Aug 1 14:39:38 2023 -0400

    [SPARK-44613][CONNECT] Add Encoders object

    ### What changes were proposed in this pull request?
    This PR adds the org.apache.spark.sql.Encoders object to Connect.

    ### Why are the changes needed?
    To increase compatibility with the SQL DataFrame API.

    ### Does this PR introduce _any_ user-facing change?
    Yes, it adds missing functionality.

    ### How was this patch tested?
    Added a couple of Java-based tests.

    Closes #42264 from hvanhovell/SPARK-44613.

    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit 4f62f8a718e80dca13a1d44b6fdf8857f037c15e)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Encoders.scala | 262 +++++++++++++++++++++
 .../spark/sql/connect/client/SparkResult.scala     |  14 +-
 .../org/apache/spark/sql/JavaEncoderSuite.java     |  94 ++++++++
 .../CheckConnectJvmClientCompatibility.scala       |   8 +-
 .../connect/client/util/RemoteSparkSession.scala   |   2 +-
 5 files changed, 371 insertions(+), 9 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
new file mode 100644
index 00000000000..3f2f7ec96d4
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
+
+/**
+ * Methods for creating an [[Encoder]].
+ *
+ * @since 3.5.0
+ */
+object Encoders {
+
+  /**
+   * An encoder for nullable boolean type. The Scala primitive encoder is available as
+   * [[scalaBoolean]].
+   * @since 3.5.0
+   */
+  def BOOLEAN: Encoder[java.lang.Boolean] = BoxedBooleanEncoder
+
+  /**
+   * An encoder for nullable byte type. The Scala primitive encoder is available as [[scalaByte]].
+   * @since 3.5.0
+   */
+  def BYTE: Encoder[java.lang.Byte] = BoxedByteEncoder
+
+  /**
+   * An encoder for nullable short type. The Scala primitive encoder is available as
+   * [[scalaShort]].
+   * @since 3.5.0
+   */
+  def SHORT: Encoder[java.lang.Short] = BoxedShortEncoder
+
+  /**
+   * An encoder for nullable int type. The Scala primitive encoder is available as [[scalaInt]].
+   * @since 3.5.0
+   */
+  def INT: Encoder[java.lang.Integer] = BoxedIntEncoder
+
+  /**
+   * An encoder for nullable long type. The Scala primitive encoder is available as [[scalaLong]].
+   * @since 3.5.0
+   */
+  def LONG: Encoder[java.lang.Long] = BoxedLongEncoder
+
+  /**
+   * An encoder for nullable float type. The Scala primitive encoder is available as
+   * [[scalaFloat]].
+   * @since 3.5.0
+   */
+  def FLOAT: Encoder[java.lang.Float] = BoxedFloatEncoder
+
+  /**
+   * An encoder for nullable double type. The Scala primitive encoder is available as
+   * [[scalaDouble]].
+   * @since 3.5.0
+   */
+  def DOUBLE: Encoder[java.lang.Double] = BoxedDoubleEncoder
+
+  /**
+   * An encoder for nullable string type.
+   *
+   * @since 3.5.0
+   */
+  def STRING: Encoder[java.lang.String] = StringEncoder
+
+  /**
+   * An encoder for nullable decimal type.
+   *
+   * @since 3.5.0
+   */
+  def DECIMAL: Encoder[java.math.BigDecimal] = DEFAULT_JAVA_DECIMAL_ENCODER
+
+  /**
+   * An encoder for nullable date type.
+   *
+   * @since 3.5.0
+   */
+  def DATE: Encoder[java.sql.Date] = DateEncoder(lenientSerialization = false)
+
+  /**
+   * Creates an encoder that serializes instances of the `java.time.LocalDate` class to the
+   * internal representation of nullable Catalyst's DateType.
+   *
+   * @since 3.5.0
+   */
+  def LOCALDATE: Encoder[java.time.LocalDate] = STRICT_LOCAL_DATE_ENCODER
+
+  /**
+   * Creates an encoder that serializes instances of the `java.time.LocalDateTime` class to the
+   * internal representation of nullable Catalyst's TimestampNTZType.
+   *
+   * @since 3.5.0
+   */
+  def LOCALDATETIME: Encoder[java.time.LocalDateTime] = LocalDateTimeEncoder
+
+  /**
+   * An encoder for nullable timestamp type.
+   *
+   * @since 3.5.0
+   */
+  def TIMESTAMP: Encoder[java.sql.Timestamp] = STRICT_TIMESTAMP_ENCODER
+
+  /**
+   * Creates an encoder that serializes instances of the `java.time.Instant` class to the internal
+   * representation of nullable Catalyst's TimestampType.
+   *
+   * @since 3.5.0
+   */
+  def INSTANT: Encoder[java.time.Instant] = STRICT_INSTANT_ENCODER
+
+  /**
+   * An encoder for arrays of bytes.
+   *
+   * @since 3.5.0
+   */
+  def BINARY: Encoder[Array[Byte]] = BinaryEncoder
+
+  /**
+   * Creates an encoder that serializes instances of the `java.time.Duration` class to the
+   * internal representation of nullable Catalyst's DayTimeIntervalType.
+   *
+   * @since 3.5.0
+   */
+  def DURATION: Encoder[java.time.Duration] = DayTimeIntervalEncoder
+
+  /**
+   * Creates an encoder that serializes instances of the `java.time.Period` class to the internal
+   * representation of nullable Catalyst's YearMonthIntervalType.
+   *
+   * @since 3.5.0
+   */
+  def PERIOD: Encoder[java.time.Period] = YearMonthIntervalEncoder
+
+  /**
+   * Creates an encoder for Java Bean of type T.
+   *
+   * T must be publicly accessible.
+   *
+   * supported types for java bean field:
+   *   - primitive types: boolean, int, double, etc.
+   *   - boxed types: Boolean, Integer, Double, etc.
+   *   - String
+   *   - java.math.BigDecimal, java.math.BigInteger
+   *   - time related: java.sql.Date, java.sql.Timestamp, java.time.LocalDate, java.time.Instant
+   *   - collection types: array, java.util.List, and map
+   *   - nested java bean.
+   *
+   * @since 3.5.0
+   */
+  def bean[T](beanClass: Class[T]): Encoder[T] = JavaTypeInference.encoderFor(beanClass)
+
+  private def tupleEncoder[T](encoders: Encoder[_]*): Encoder[T] = {
+    ProductEncoder.tuple(encoders.asInstanceOf[Seq[AgnosticEncoder[_]]]).asInstanceOf[Encoder[T]]
+  }
+
+  /**
+   * An encoder for 2-ary tuples.
+   *
+   * @since 3.5.0
+   */
+  def tuple[T1, T2](e1: Encoder[T1], e2: Encoder[T2]): Encoder[(T1, T2)] = tupleEncoder(e1, e2)
+
+  /**
+   * An encoder for 3-ary tuples.
+   *
+   * @since 3.5.0
+   */
+  def tuple[T1, T2, T3](
+      e1: Encoder[T1],
+      e2: Encoder[T2],
+      e3: Encoder[T3]): Encoder[(T1, T2, T3)] = tupleEncoder(e1, e2, e3)
+
+  /**
+   * An encoder for 4-ary tuples.
+   *
+   * @since 3.5.0
+   */
+  def tuple[T1, T2, T3, T4](
+      e1: Encoder[T1],
+      e2: Encoder[T2],
+      e3: Encoder[T3],
+      e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = tupleEncoder(e1, e2, e3, e4)
+
+  /**
+   * An encoder for 5-ary tuples.
+   *
+   * @since 3.5.0
+   */
+  def tuple[T1, T2, T3, T4, T5](
+      e1: Encoder[T1],
+      e2: Encoder[T2],
+      e3: Encoder[T3],
+      e4: Encoder[T4],
+      e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = tupleEncoder(e1, e2, e3, e4, e5)
+
+  /**
+   * An encoder for Scala's product type (tuples, case classes, etc).
+   * @since 3.5.0
+   */
+  def product[T <: Product: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T]
+
+  /**
+   * An encoder for Scala's primitive int type.
+   * @since 3.5.0
+   */
+  def scalaInt: Encoder[Int] = PrimitiveIntEncoder
+
+  /**
+   * An encoder for Scala's primitive long type.
+   * @since 3.5.0
+   */
+  def scalaLong: Encoder[Long] = PrimitiveLongEncoder
+
+  /**
+   * An encoder for Scala's primitive double type.
+   * @since 3.5.0
+   */
+  def scalaDouble: Encoder[Double] = PrimitiveDoubleEncoder
+
+  /**
+   * An encoder for Scala's primitive float type.
+   * @since 3.5.0
+   */
+  def scalaFloat: Encoder[Float] = PrimitiveFloatEncoder
+
+  /**
+   * An encoder for Scala's primitive byte type.
+   * @since 3.5.0
+   */
+  def scalaByte: Encoder[Byte] = PrimitiveByteEncoder
+
+  /**
+   * An encoder for Scala's primitive short type.
+   * @since 3.5.0
+   */
+  def scalaShort: Encoder[Short] = PrimitiveShortEncoder
+
+  /**
+   * An encoder for Scala's primitive boolean type.
+   * @since 3.5.0
+   */
+  def scalaBoolean: Encoder[Boolean] = PrimitiveBooleanEncoder
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index e3055b2678f..93c32aa2954 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -182,11 +182,15 @@ private[sql] class SparkResult[T](
   def toArray: Array[T] = {
     val result = encoder.clsTag.newArray(length)
     val rows = iterator
-    var i = 0
-    while (rows.hasNext) {
-      result(i) = rows.next()
-      assert(i < numRecords)
-      i += 1
+    try {
+      var i = 0
+      while (rows.hasNext) {
+        result(i) = rows.next()
+        assert(i < numRecords)
+        i += 1
+      }
+    } finally {
+      rows.close()
     }
     result
   }
diff --git a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java
new file mode 100644
index 00000000000..c8210a7a485
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql;
+
+import org.junit.*;
+import static org.junit.Assert.*;
+
+import static org.apache.spark.sql.Encoders.*;
+import static org.apache.spark.sql.functions.*;
+import org.apache.spark.sql.connect.client.SparkConnectClient;
+import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils;
+
+import java.math.BigDecimal;
+import java.util.Arrays;
+
+/**
+ * Tests for the encoders class.
+ */
+public class JavaEncoderSuite {
+  private static SparkSession spark;
+
+  @BeforeClass
+  public static void setup() {
+    SparkConnectServerUtils.start();
+    spark = SparkSession
+        .builder()
+        .client(SparkConnectClient
+            .builder()
+            .port(SparkConnectServerUtils.port())
+            .build())
+        .create();
+  }
+
+  @AfterClass
+  public static void tearDown() {
+    spark.stop();
+    spark = null;
+    SparkConnectServerUtils.stop();
+  }
+
+  private static BigDecimal bigDec(long unscaled, int scale) {
+    return BigDecimal.valueOf(unscaled, scale);
+  }
+
+
+  private <T> Dataset<T> dataset(Encoder<T> encoder, T... elements) {
+    return spark.createDataset(Arrays.asList(elements), encoder);
+  }
+
+  @Test
+  public void testSimpleEncoders() {
+    final Column v = col("value");
+    assertFalse(
+        dataset(BOOLEAN(), false, true, false).select(every(v)).as(BOOLEAN()).head());
+    assertEquals(
+        7L,
+        dataset(BYTE(), (byte) -120, (byte)127).select(sum(v)).as(LONG()).head().longValue());
+    assertEquals(
+        (short) 16,
+        dataset(SHORT(), (short)16, (short)2334).select(min(v)).as(SHORT()).head().shortValue());
+    assertEquals(
+        10L,
+        dataset(INT(), 1, 2, 3, 4).select(sum(v)).as(LONG()).head().longValue());
+    assertEquals(
+        96L,
+        dataset(LONG(), 77L, 19L).select(sum(v)).as(LONG()).head().longValue());
+    assertEquals(
+        0.12f,
+        dataset(FLOAT(), 0.12f, 0.3f, 44f).select(min(v)).as(FLOAT()).head(),
+        0.0001f);
+    assertEquals(
+        789d,
+        dataset(DOUBLE(), 789d, 12.213d, 10.01d).select(max(v)).as(DOUBLE()).head(),
+        0.0001f);
+    assertEquals(
+        bigDec(1002, 2),
+        dataset(DECIMAL(), bigDec(1000, 2), bigDec(2, 2))
+            .select(sum(v)).as(DECIMAL()).head().setScale(2));
+  }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 08028f26eb4..6e577e0f212 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -192,7 +192,6 @@ object CheckConnectJvmClientCompatibility {
 
       // functions
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udf"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.call_udf"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.callUDF"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.unwrap_udt"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"),
@@ -216,7 +215,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sqlContext"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.listenerManager"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.experimental"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.udf"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.udtf"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.streams"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataFrame"),
@@ -418,7 +416,11 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[MissingClassProblem](
         "org.apache.spark.sql.streaming.RemoteStreamingQuery"),
       ProblemFilters.exclude[MissingClassProblem](
-        "org.apache.spark.sql.streaming.RemoteStreamingQuery$"))
+        "org.apache.spark.sql.streaming.RemoteStreamingQuery$"),
+
+      // Encoders are in the wrong JAR
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders$"))
 
     checkMiMaCompatibility(sqlJar, clientJar, includedRules, excludeRules)
   }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
index 88c0785d3af..f14109e49b5 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
@@ -50,7 +50,7 @@ import org.apache.spark.sql.connect.common.config.ConnectCommon
 object SparkConnectServerUtils {
 
   // Server port
-  private[spark] val port: Int =
+  val port: Int =
     ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000)
 
   @volatile private var stopped = false
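
For readers new to this API, here is a minimal Scala usage sketch mirroring the
pattern exercised by JavaEncoderSuite above. The sc://localhost endpoint and the
EncodersSketch object name are illustrative assumptions, not part of this commit:

import java.util.Arrays

import org.apache.spark.sql.{Encoders, SparkSession}
import org.apache.spark.sql.functions._

// Hypothetical example; assumes a Spark Connect server is reachable
// at sc://localhost.
object EncodersSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().remote("sc://localhost").create()

    // Build a typed Dataset with one of the newly added encoders, then
    // aggregate and read the result back through another encoder.
    val ds = spark.createDataset(Arrays.asList(1, 2, 3, 4), Encoders.scalaInt)
    val total = ds.select(sum(col("value"))).as(Encoders.scalaLong).head()
    println(s"sum = $total") // prints: sum = 10

    spark.stop()
  }
}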