This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new 5412fb0590e [SPARK-41400][CONNECT] Remove Connect Client Catalyst Dependency 5412fb0590e is described below commit 5412fb0590e55d635e9e31887ec5c72d10011899 Author: Herman van Hovell <her...@databricks.com> AuthorDate: Fri Jul 28 21:30:51 2023 -0400 [SPARK-41400][CONNECT] Remove Connect Client Catalyst Dependency ### What changes were proposed in this pull request? This PR decouples the Spark Connect Scala Client from Catalyst; it now uses the SQL API module instead. There were quite a few changes we still needed to make: - For testing we needed a bunch of utilities. I have moved these to common-utils. - I have moved bits and pieces of IntervalUtils to SparkIntervalUtils. - A lot of small fixes. ### Why are the changes needed? This reduces the client's dependency tree from ~300 MB of deps to ~30 MB. This makes it easier to use the client when you are developing Connect applications. On top of this, the reduced dependency graph also means folks will be less affected by the client's classpath. ### Does this PR introduce _any_ user-facing change? Yes. It changes the classpath exposed by the Spark Connect Scala Client. ### How was this patch tested? Existing tests. Closes #42184 from hvanhovell/SPARK-41400-v1. 
Authored-by: Herman van Hovell <her...@databricks.com> Signed-off-by: Herman van Hovell <her...@databricks.com> (cherry picked from commit 85a4d1e56e80d85dad9b8945c67287927eb379f6) Signed-off-by: Herman van Hovell <her...@databricks.com> --- common/network-common/pom.xml | 3 +- .../apache/spark/network/sasl/SparkSaslSuite.java | 5 +- .../org/apache/spark/network/util/ByteUnit.java | 0 .../org/apache/spark/network/util/JavaUtils.java | 66 ++--- .../org/apache/spark/util/SparkErrorUtils.scala | 53 +++- .../org/apache/spark/util/SparkFileUtils.scala | 80 +++++- .../org/apache/spark/util/SparkSerDeUtils.scala | 9 +- connector/connect/client/jvm/pom.xml | 14 +- .../scala/org/apache/spark/sql/SparkSession.scala | 6 +- .../connect/client/arrow/ArrowDeserializer.scala | 6 +- .../sql/connect/client/arrow/ArrowSerializer.scala | 19 +- .../connect/client/arrow/ArrowVectorReader.scala | 15 +- .../org/apache/spark/sql/protobuf/functions.scala | 15 +- .../org/apache/spark/sql/ClientE2ETestSuite.scala | 7 +- .../spark/sql/DataFrameNaFunctionSuite.scala | 6 +- .../apache/spark/sql/PlanGenerationTestSuite.scala | 6 +- .../scala/org/apache/spark/sql/SQLHelper.scala | 11 +- .../apache/spark/sql/SQLImplicitsTestSuite.scala | 31 ++- .../apache/spark/sql/SparkSessionE2ESuite.scala | 16 +- .../sql/UserDefinedFunctionE2ETestSuite.scala | 15 +- .../spark/sql/UserDefinedFunctionSuite.scala | 4 +- .../sql/connect/client/ClassFinderSuite.scala | 6 +- .../connect/client/arrow/ArrowEncoderSuite.scala | 6 +- .../connect/client/util/IntegrationTestUtils.scala | 11 +- .../spark/sql/connect/client/util/QueryTest.scala | 4 +- .../spark/sql/streaming/StreamingQuerySuite.scala | 6 +- connector/connect/common/pom.xml | 8 +- .../common/LiteralValueProtoConverter.scala | 14 +- .../main/scala/org/apache/spark/util/Utils.scala | 132 +--------- .../catalyst/plans/logical/LogicalGroupState.scala | 16 +- .../sql/catalyst/util/SparkIntervalUtils.scala | 287 ++++++++++++++++++++- 
.../spark/sql/catalyst/util/StringUtils.scala | 20 ++ .../spark/sql/errors/CompilationErrors.scala | 54 ++++ .../org/apache/spark/sql/internal/SqlApiConf.scala | 3 + .../org/apache/spark/sql/types/UpCastRule.scala | 86 ++++++ .../sql/catalyst/analysis/AnsiTypeCoercion.scala | 2 +- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 12 +- .../spark/sql/catalyst/expressions/Cast.scala | 47 +--- .../sql/catalyst/expressions/ToStringBase.scala | 6 +- .../spark/sql/catalyst/plans/logical/object.scala | 3 - .../spark/sql/catalyst/util/IntervalUtils.scala | 260 ------------------- .../apache/spark/sql/catalyst/util/package.scala | 21 +- .../spark/sql/errors/QueryCompilationErrors.scala | 33 +-- .../org/apache/spark/sql/internal/SQLConf.scala | 6 +- .../sql/catalyst/expressions/CastSuiteBase.scala | 2 +- 45 files changed, 751 insertions(+), 681 deletions(-) diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 8a63e999c53..2b43f9ce98a 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -150,7 +150,8 @@ <dependency> <groupId>org.apache.spark</groupId> - <artifactId>spark-tags_${scala.binary.version}</artifactId> + <artifactId>spark-common-utils_${scala.binary.version}</artifactId> + <version>${project.version}</version> </dependency> <!-- diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 6096cd32f3d..3562e9f3025 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -58,10 +58,7 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; -import 
org.apache.spark.network.util.ByteArrayWritableChannel; -import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.MapConfigProvider; -import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.util.*; /** * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes. diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/common/utils/src/main/java/org/apache/spark/network/util/ByteUnit.java similarity index 100% rename from common/network-common/src/main/java/org/apache/spark/network/util/ByteUnit.java rename to common/utils/src/main/java/org/apache/spark/network/util/ByteUnit.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java similarity index 89% rename from common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java rename to common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java index 7e410e9eab2..bbe764b8366 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/utils/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -23,15 +23,11 @@ import java.nio.channels.ReadableByteChannel; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.attribute.BasicFileAttributes; -import java.util.Locale; -import java.util.UUID; +import java.util.*; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableMap; -import io.netty.buffer.Unpooled; import org.apache.commons.lang3.SystemUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -72,7 +68,7 @@ public class JavaUtils { * converted back to the same string through {@link #bytesToString(ByteBuffer)}. 
*/ public static ByteBuffer stringToBytes(String s) { - return Unpooled.wrappedBuffer(s.getBytes(StandardCharsets.UTF_8)).nioBuffer(); + return ByteBuffer.wrap(s.getBytes(StandardCharsets.UTF_8)); } /** @@ -80,7 +76,7 @@ public class JavaUtils { * converted back to the same byte buffer through {@link #stringToBytes(String)}. */ public static String bytesToString(ByteBuffer b) { - return Unpooled.wrappedBuffer(b).toString(StandardCharsets.UTF_8); + return StandardCharsets.UTF_8.decode(b.slice()).toString(); } /** @@ -191,7 +187,7 @@ public class JavaUtils { } private static boolean isSymlink(File file) throws IOException { - Preconditions.checkNotNull(file); + Objects.requireNonNull(file); File fileInCanonicalDir = null; if (file.getParent() == null) { fileInCanonicalDir = file; @@ -201,31 +197,35 @@ public class JavaUtils { return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } - private static final ImmutableMap<String, TimeUnit> timeSuffixes = - ImmutableMap.<String, TimeUnit>builder() - .put("us", TimeUnit.MICROSECONDS) - .put("ms", TimeUnit.MILLISECONDS) - .put("s", TimeUnit.SECONDS) - .put("m", TimeUnit.MINUTES) - .put("min", TimeUnit.MINUTES) - .put("h", TimeUnit.HOURS) - .put("d", TimeUnit.DAYS) - .build(); - - private static final ImmutableMap<String, ByteUnit> byteSuffixes = - ImmutableMap.<String, ByteUnit>builder() - .put("b", ByteUnit.BYTE) - .put("k", ByteUnit.KiB) - .put("kb", ByteUnit.KiB) - .put("m", ByteUnit.MiB) - .put("mb", ByteUnit.MiB) - .put("g", ByteUnit.GiB) - .put("gb", ByteUnit.GiB) - .put("t", ByteUnit.TiB) - .put("tb", ByteUnit.TiB) - .put("p", ByteUnit.PiB) - .put("pb", ByteUnit.PiB) - .build(); + private static final Map<String, TimeUnit> timeSuffixes; + + private static final Map<String, ByteUnit> byteSuffixes; + + static { + final Map<String, TimeUnit> timeSuffixesBuilder = new HashMap<>(); + timeSuffixesBuilder.put("us", TimeUnit.MICROSECONDS); + timeSuffixesBuilder.put("ms", 
TimeUnit.MILLISECONDS); + timeSuffixesBuilder.put("s", TimeUnit.SECONDS); + timeSuffixesBuilder.put("m", TimeUnit.MINUTES); + timeSuffixesBuilder.put("min", TimeUnit.MINUTES); + timeSuffixesBuilder.put("h", TimeUnit.HOURS); + timeSuffixesBuilder.put("d", TimeUnit.DAYS); + timeSuffixes = Collections.unmodifiableMap(timeSuffixesBuilder); + + final Map<String, ByteUnit> byteSuffixesBuilder = new HashMap<>(); + byteSuffixesBuilder.put("b", ByteUnit.BYTE); + byteSuffixesBuilder.put("k", ByteUnit.KiB); + byteSuffixesBuilder.put("kb", ByteUnit.KiB); + byteSuffixesBuilder.put("m", ByteUnit.MiB); + byteSuffixesBuilder.put("mb", ByteUnit.MiB); + byteSuffixesBuilder.put("g", ByteUnit.GiB); + byteSuffixesBuilder.put("gb", ByteUnit.GiB); + byteSuffixesBuilder.put("t", ByteUnit.TiB); + byteSuffixesBuilder.put("tb", ByteUnit.TiB); + byteSuffixesBuilder.put("p", ByteUnit.PiB); + byteSuffixesBuilder.put("pb", ByteUnit.PiB); + byteSuffixes = Collections.unmodifiableMap(byteSuffixesBuilder); + } /** * Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count in the given unit. diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkErrorUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkErrorUtils.scala index 8e4de01885e..97a07984a22 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/SparkErrorUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkErrorUtils.scala @@ -16,13 +16,14 @@ */ package org.apache.spark.util -import java.io.IOException +import java.io.{Closeable, IOException, PrintWriter} +import java.nio.charset.StandardCharsets.UTF_8 import scala.util.control.NonFatal import org.apache.spark.internal.Logging -object SparkErrorUtils extends Logging { +private[spark] trait SparkErrorUtils extends Logging { /** * Execute a block of code that returns a value, re-throwing any non-fatal uncaught * exceptions as IOException. 
This is used when implementing Externalizable and Serializable's @@ -41,4 +42,52 @@ object SparkErrorUtils extends Logging { throw new IOException(e) } } + + def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = { + val resource = createResource + try f.apply(resource) finally resource.close() + } + + /** + * Execute a block of code, then a finally block, but if exceptions happen in + * the finally block, do not suppress the original exception. + * + * This is primarily an issue with `finally { out.close() }` blocks, where + * close needs to be called to clean up `out`, but if an exception happened + * in `out.write`, it's likely `out` may be corrupted and `out.close` will + * fail as well. This would then suppress the original/likely more meaningful + * exception from the original `out.write` call. + */ + def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = { + var originalThrowable: Throwable = null + try { + block + } catch { + case t: Throwable => + // Purposefully not using NonFatal, because even fatal exceptions + // we don't want to have our finallyBlock suppress + originalThrowable = t + throw originalThrowable + } finally { + try { + finallyBlock + } catch { + case t: Throwable if (originalThrowable != null && originalThrowable != t) => + originalThrowable.addSuppressed(t) + logWarning(s"Suppressing exception in finally: ${t.getMessage}", t) + throw originalThrowable + } + } + } + + def stackTraceToString(t: Throwable): String = { + val out = new java.io.ByteArrayOutputStream + SparkErrorUtils.tryWithResource(new PrintWriter(out)) { writer => + t.printStackTrace(writer) + writer.flush() + } + new String(out.toByteArray, UTF_8) + } } + +object SparkErrorUtils extends SparkErrorUtils diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkFileUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkFileUtils.scala index 63d1ab4799a..e12f8acdadd 100644 --- 
a/common/utils/src/main/scala/org/apache/spark/util/SparkFileUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkFileUtils.scala @@ -18,8 +18,12 @@ package org.apache.spark.util import java.io.File import java.net.{URI, URISyntaxException} +import java.nio.file.Files -private[spark] object SparkFileUtils { +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils + +private[spark] trait SparkFileUtils extends Logging { /** * Return a well-formed URI for the file described by a user input string. * @@ -44,4 +48,78 @@ private[spark] object SparkFileUtils { } new File(path).getCanonicalFile().toURI() } + + /** + * Lists files recursively. + */ + def recursiveList(f: File): Array[File] = { + require(f.isDirectory) + val result = f.listFiles.toBuffer + val dirList = result.filter(_.isDirectory) + while (dirList.nonEmpty) { + val curDir = dirList.remove(0) + val files = curDir.listFiles() + result ++= files + dirList ++= files.filter(_.isDirectory) + } + result.toArray + } + + /** + * Create a directory given the abstract pathname + * @return true, if the directory is successfully created; otherwise, return false. + */ + def createDirectory(dir: File): Boolean = { + try { + // SPARK-35907: The check was required by File.mkdirs() because it could sporadically + // fail silently. After switching to Files.createDirectories(), ideally, there should + // no longer be silent fails. But the check is kept for the safety concern. We can + // remove the check when we're sure that Files.createDirectories() would never fail silently. + Files.createDirectories(dir.toPath) + if ( !dir.exists() || !dir.isDirectory) { + logError(s"Failed to create directory " + dir) + } + dir.isDirectory + } catch { + case e: Exception => + logError(s"Failed to create directory " + dir, e) + false + } + } + + /** + * Create a directory inside the given parent directory. 
The directory is guaranteed to be + * newly created, and is not marked for automatic deletion. + */ + def createDirectory(root: String, namePrefix: String = "spark"): File = { + JavaUtils.createDirectory(root, namePrefix) + } + + /** + * Create a temporary directory inside the `java.io.tmpdir` prefixed with `spark`. + * The directory will be automatically deleted when the VM shuts down. + */ + def createTempDir(): File = + createTempDir(System.getProperty("java.io.tmpdir"), "spark") + + /** + * Create a temporary directory inside the given parent directory. The directory will be + * automatically deleted when the VM shuts down. + */ + def createTempDir( + root: String = System.getProperty("java.io.tmpdir"), + namePrefix: String = "spark"): File = { + createDirectory(root, namePrefix) + } + + /** + * Delete a file or directory and its contents recursively. + * Don't follow directories if they are symlinks. + * Throws an exception if deletion is unsuccessful. + */ + def deleteRecursively(file: File): Unit = { + JavaUtils.deleteRecursively(file) + } } + +private[spark] object SparkFileUtils extends SparkFileUtils diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala index 88d1d6bdba8..3069e4c36a7 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.util -import java.io.{ByteArrayOutputStream, ObjectOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} object SparkSerDeUtils { /** Serialize an object using Java serialization */ @@ -27,4 +27,11 @@ object SparkSerDeUtils { oos.close() bos.toByteArray } + + /** Deserialize an object using Java serialization */ + def deserialize[T](bytes: Array[Byte]): T = { + val bis = new ByteArrayInputStream(bytes) + val 
ois = new ObjectInputStream(bis) + ois.readObject.asInstanceOf[T] + } } diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml index f60e60890c0..503888ce095 100644 --- a/connector/connect/client/jvm/pom.xml +++ b/connector/connect/client/jvm/pom.xml @@ -46,23 +46,11 @@ </exclusion> </exclusions> </dependency> - <!--TODO: fix the dependency once the catalyst refactoring is done--> <dependency> <groupId>org.apache.spark</groupId> - <artifactId>spark-catalyst_${scala.binary.version}</artifactId> + <artifactId>spark-sql-api_${scala.binary.version}</artifactId> <version>${project.version}</version> <scope>provided</scope> - <exclusions> - <exclusion> - <groupId>com.google.guava</groupId> - <artifactId>guava</artifactId> - </exclusion> - </exclusions> - </dependency> - <dependency> - <groupId>org.apache.spark</groupId> - <artifactId>spark-common-utils_${scala.binary.version}</artifactId> - <version>${project.version}</version> </dependency> <dependency> <groupId>org.apache.spark</groupId> diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 5a0f33ffd5d..d1832e65f3e 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.connect.client.arrow.ArrowSerializer import org.apache.spark.sql.connect.client.util.Cleaner import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto -import org.apache.spark.sql.internal.CatalogImpl +import org.apache.spark.sql.internal.{CatalogImpl, SqlApiConf} import org.apache.spark.sql.streaming.DataStreamReader import org.apache.spark.sql.streaming.StreamingQueryManager import 
org.apache.spark.sql.types.StructType @@ -127,7 +127,7 @@ class SparkSession private[sql] ( newDataset(encoder) { builder => if (data.nonEmpty) { val arrowData = ArrowSerializer.serialize(data, encoder, allocator, timeZoneId) - if (arrowData.size() <= conf.get("spark.sql.session.localRelationCacheThreshold").toInt) { + if (arrowData.size() <= conf.get(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY).toInt) { builder.getLocalRelationBuilder .setSchema(encoder.schema.json) .setData(arrowData) @@ -528,7 +528,7 @@ class SparkSession private[sql] ( client.semanticHash(plan).getSemanticHash.getResult } - private[sql] def timeZoneId: String = conf.get("spark.sql.session.timeZone") + private[sql] def timeZoneId: String = conf.get(SqlApiConf.SESSION_LOCAL_TIMEZONE_KEY) private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = { val value = client.execute(plan) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index cd3afeb7e90..509ceffc552 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema -import org.apache.spark.sql.errors.{ExecutionErrors, QueryCompilationErrors} +import org.apache.spark.sql.errors.{CompilationErrors, ExecutionErrors} import org.apache.spark.sql.types.Decimal /** @@ -436,13 +436,13 @@ object ArrowDeserializers { val key = toKey(field.getName) val old = lookup.put(key, field) if (old.isDefined) 
{ - throw QueryCompilationErrors.ambiguousColumnOrFieldError( + throw CompilationErrors.ambiguousColumnOrFieldError( field.getName :: Nil, fields.count(f => toKey(f.getName) == key)) } } name => { - lookup.getOrElse(toKey(name), throw QueryCompilationErrors.columnNotFoundError(name)) + lookup.getOrElse(toKey(name), throw CompilationErrors.columnNotFoundError(name)) } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala index 94e01bbef31..c4a2cfa8a85 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.DefinedByConstructorParams import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ -import org.apache.spark.sql.catalyst.util.{IntervalUtils, SparkDateTimeUtils} +import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils} import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types.Decimal import org.apache.spark.sql.util.ArrowUtils @@ -191,11 +191,14 @@ object ArrowSerializer { allocator: BufferAllocator, timeZoneId: String): ByteString = { val serializer = new ArrowSerializer[T](enc, allocator, timeZoneId) - serializer.reset() - input.foreach(serializer.append) - val output = ByteString.newOutput() - serializer.writeIpcStream(output) - output.toByteString + try { + input.foreach(serializer.append) + val output = ByteString.newOutput() + serializer.writeIpcStream(output) + output.toByteString + } finally { + serializer.close() + } } /** @@ -313,12 +316,12 @@ object ArrowSerializer { case 
(DayTimeIntervalEncoder, v: DurationVector) => new FieldSerializer[Duration, DurationVector](v) { override def set(index: Int, value: Duration): Unit = - vector.setSafe(index, IntervalUtils.durationToMicros(value)) + vector.setSafe(index, SparkIntervalUtils.durationToMicros(value)) } case (YearMonthIntervalEncoder, v: IntervalYearVector) => new FieldSerializer[Period, IntervalYearVector](v) { override def set(index: Int, value: Period): Unit = - vector.setSafe(index, IntervalUtils.periodToMonths(value)) + vector.setSafe(index, SparkIntervalUtils.periodToMonths(value)) } case (DateEncoder(true) | LocalDateEncoder(true), v: DateDayVector) => new FieldSerializer[Any, DateDayVector](v) { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala index 63c6a0df5c8..48820857480 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowVectorReader.scala @@ -23,12 +23,11 @@ import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period, ZoneOffse import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector, IntervalYearVector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, VarCharVector} import org.apache.arrow.vector.util.Text -import org.apache.spark.sql.catalyst.expressions.Cast -import org.apache.spark.sql.catalyst.util.{DateFormatter, IntervalUtils, StringUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkIntervalUtils, SparkStringUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND import 
org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._ -import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, Decimal, YearMonthIntervalType} +import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, Decimal, UpCastRule, YearMonthIntervalType} import org.apache.spark.sql.util.ArrowUtils /** @@ -69,7 +68,7 @@ object ArrowVectorReader { vector: FieldVector, timeZoneId: String): ArrowVectorReader = { val vectorDataType = ArrowUtils.fromArrowType(vector.getField.getType) - if (!Cast.canUpCast(vectorDataType, targetDataType)) { + if (!UpCastRule.canUpCast(vectorDataType, targetDataType)) { throw new RuntimeException( s"Reading '$targetDataType' values from a ${vector.getClass} instance is not supported.") } @@ -193,15 +192,15 @@ private[arrow] class VarCharVectorReader(v: VarCharVector) private[arrow] class VarBinaryVectorReader(v: VarBinaryVector) extends TypedArrowVectorReader[VarBinaryVector](v) { override def getBytes(i: Int): Array[Byte] = vector.get(i) - override def getString(i: Int): String = StringUtils.getHexString(getBytes(i)) + override def getString(i: Int): String = SparkStringUtils.getHexString(getBytes(i)) } private[arrow] class DurationVectorReader(v: DurationVector) extends TypedArrowVectorReader[DurationVector](v) { override def getDuration(i: Int): Duration = vector.getObject(i) override def getString(i: Int): String = { - IntervalUtils.toDayTimeIntervalString( - IntervalUtils.durationToMicros(getDuration(i)), + SparkIntervalUtils.toDayTimeIntervalString( + SparkIntervalUtils.durationToMicros(getDuration(i)), ANSI_STYLE, DayTimeIntervalType.DEFAULT.startField, DayTimeIntervalType.DEFAULT.endField) @@ -212,7 +211,7 @@ private[arrow] class IntervalYearVectorReader(v: IntervalYearVector) extends TypedArrowVectorReader[IntervalYearVector](v) { override def getPeriod(i: Int): Period = vector.getObject(i).normalized() override def getString(i: Int): String = { - 
IntervalUtils.toYearMonthIntervalString( + SparkIntervalUtils.toYearMonthIntervalString( vector.get(i), ANSI_STYLE, YearMonthIntervalType.DEFAULT.startField, diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/protobuf/functions.scala index 57ce013065e..293490928a2 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/protobuf/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/protobuf/functions.scala @@ -16,19 +16,16 @@ */ package org.apache.spark.sql.protobuf -import java.io.File import java.io.FileNotFoundException -import java.nio.file.NoSuchFileException +import java.nio.file.{Files, NoSuchFileException, Paths} import java.util.Collections import scala.collection.JavaConverters._ import scala.util.control.NonFatal -import org.apache.commons.io.FileUtils - import org.apache.spark.annotation.Experimental import org.apache.spark.sql.Column -import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.errors.CompilationErrors import org.apache.spark.sql.functions.{fnWithOptions, lit} // scalastyle:off: object.name @@ -309,13 +306,13 @@ object functions { // This method is copied from org.apache.spark.sql.protobuf.util.ProtobufUtils private def readDescriptorFileContent(filePath: String): Array[Byte] = { try { - FileUtils.readFileToByteArray(new File(filePath)) + Files.readAllBytes(Paths.get(filePath)) } catch { case ex: FileNotFoundException => - throw QueryCompilationErrors.cannotFindDescriptorFileError(filePath, ex) + throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex) case ex: NoSuchFileException => - throw QueryCompilationErrors.cannotFindDescriptorFileError(filePath, ex) - case NonFatal(ex) => throw QueryCompilationErrors.descriptorParseError(ex) + throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex) + case NonFatal(ex) => throw 
CompilationErrors.descriptorParseError(ex) } } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index b69151f75be..852f1799f36 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -29,7 +29,8 @@ import org.apache.commons.lang3.{JavaVersion, SystemUtils} import org.scalactic.TolerantNumerics import org.scalatest.PrivateMethodTester -import org.apache.spark.{SPARK_VERSION, SparkException} +import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION} +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.parser.ParseException @@ -37,7 +38,7 @@ import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult} import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession} import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils.port import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types._ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester { @@ -929,7 +930,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM test("SparkSession.createDataFrame - large data set") { val threshold = 1024 * 1024 - withSQLConf(SQLConf.LOCAL_RELATION_CACHE_THRESHOLD.key -> threshold.toString) { + withSQLConf(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY -> threshold.toString) { val count = 2 val suffix = "abcdef" val str = scala.util.Random.alphanumeric.take(1024 * 1024).mkString + suffix 
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala index c44b515bded..525a5902525 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkException import org.apache.spark.sql.connect.client.util.QueryTest -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{StringType, StructType} class DataFrameNaFunctionSuite extends QueryTest with SQLHelper { @@ -388,7 +388,7 @@ class DataFrameNaFunctionSuite extends QueryTest with SQLHelper { } test("replace float with nan") { - withSQLConf(SQLConf.ANSI_ENABLED.key -> false.toString) { + withSQLConf(SqlApiConf.ANSI_ENABLED_KEY -> false.toString) { checkAnswer( createNaNDF().na.replace("*", Map(1.0f -> Float.NaN)), Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: @@ -397,7 +397,7 @@ class DataFrameNaFunctionSuite extends QueryTest with SQLHelper { } test("replace double with nan") { - withSQLConf(SQLConf.ANSI_ENABLED.key -> false.toString) { + withSQLConf(SqlApiConf.ANSI_ENABLED_KEY -> false.toString) { checkAnswer( createNaNDF().na.replace("*", Map(1.0 -> Double.NaN)), Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 7e4e0f24f4f..11d0696b6e1 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.functions.lit import org.apache.spark.sql.protobuf.{functions => pbFn} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval -import org.apache.spark.util.Utils +import org.apache.spark.util.SparkFileUtils // scalastyle:off /** @@ -131,7 +131,7 @@ class PlanGenerationTestSuite private def cleanOrphanedGoldenFile(): Unit = { val allTestNames = testNames.map(_.replace(' ', '_')) - val orphans = Utils + val orphans = SparkFileUtils .recursiveList(queryFilePath.toFile) .filter(g => g.getAbsolutePath.endsWith(".proto.bin") || @@ -139,7 +139,7 @@ class PlanGenerationTestSuite .filter(g => !allTestNames.contains(g.getName.stripSuffix(".proto.bin")) && !allTestNames.contains(g.getName.stripSuffix(".json"))) - orphans.foreach(Utils.deleteRecursively) + orphans.foreach(SparkFileUtils.deleteRecursively) } private def test(name: String)(f: => Dataset[_]): Unit = super.test(name) { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLHelper.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLHelper.scala index 5603099fd49..f357270e20f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLHelper.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLHelper.scala @@ -21,8 +21,7 @@ import java.util.UUID import org.scalatest.Assertions.fail -import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE -import org.apache.spark.util.Utils +import org.apache.spark.util.{SparkErrorUtils, SparkFileUtils} trait SQLHelper { @@ -77,7 +76,7 @@ trait SQLHelper { try f(dbName) finally { if (spark.catalog.currentDatabase == dbName) { - spark.sql(s"USE $DEFAULT_DATABASE") + spark.sql(s"USE default") } spark.sql(s"DROP DATABASE $dbName CASCADE") } @@ -88,17 +87,17 @@ trait SQLHelper { * If a 
file/directory is created there by `f`, it will be delete after `f` returns. */ protected def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() + val path = SparkFileUtils.createTempDir() path.delete() try f(path) - finally Utils.deleteRecursively(path) + finally SparkFileUtils.deleteRecursively(path) } /** * Drops table `tableName` after calling `f`. */ protected def withTable(tableNames: String*)(f: => Unit): Unit = { - Utils.tryWithSafeFinally(f) { + SparkErrorUtils.tryWithSafeFinally(f) { tableNames.foreach { name => spark.sql(s"DROP TABLE IF EXISTS $name").collect() } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala index 3b5d7dae1b3..6db38bfb1c3 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala @@ -22,11 +22,12 @@ import java.time.temporal.ChronoUnit import java.util.concurrent.atomic.AtomicLong import io.grpc.inprocess.InProcessChannelBuilder +import org.apache.arrow.memory.RootAllocator import org.apache.commons.lang3.{JavaVersion, SystemUtils} import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder} import org.apache.spark.sql.connect.client.SparkConnectClient +import org.apache.spark.sql.connect.client.arrow.{ArrowDeserializers, ArrowSerializer} import org.apache.spark.sql.connect.client.util.ConnectFunSuite /** @@ -54,12 +55,28 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { val spark = session import spark.implicits._ def testImplicit[T: Encoder](expected: T): Unit = { - val encoder = implicitly[Encoder[T]].asInstanceOf[AgnosticEncoder[T]] - val expressionEncoder = ExpressionEncoder(encoder).resolveAndBind() - val serializer = 
expressionEncoder.createSerializer() - val deserializer = expressionEncoder.createDeserializer() - val actual = deserializer(serializer(expected)) - assert(actual === expected) + val encoder = encoderFor[T] + val allocator = new RootAllocator() + try { + val batch = ArrowSerializer.serialize( + input = Iterator.single(expected), + enc = encoder, + allocator = allocator, + timeZoneId = "UTC") + val fromArrow = ArrowDeserializers.deserializeFromArrow( + input = Iterator.single(batch.toByteArray), + encoder = encoder, + allocator = allocator, + timeZoneId = "UTC") + try { + assert(fromArrow.next() === expected) + assert(!fromArrow.hasNext) + } finally { + fromArrow.close() + } + } finally { + allocator.close() + } } val booleans = Array(false, true, false, false) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala index 5afafaaa6b9..86deae982a5 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql +import java.util.concurrent.ForkJoinPool + import scala.collection.mutable import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future} import scala.concurrent.duration._ @@ -25,7 +27,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.SparkException import org.apache.spark.sql.connect.client.util.RemoteSparkSession -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.SparkThreadUtils.awaitResult /** * NOTE: Do not import classes that only exist in `spark-connect-client-jvm.jar` into the this @@ -102,7 +104,7 @@ class SparkSessionE2ESuite extends RemoteSparkSession { } assert(e2.getMessage.contains("OPERATION_CANCELED"), s"Unexpected exception: $e2") finished = true - 
assert(ThreadUtils.awaitResult(interruptor, 10.seconds)) + assert(awaitResult(interruptor, 10.seconds)) assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") } @@ -113,7 +115,7 @@ class SparkSessionE2ESuite extends RemoteSparkSession { // global ExecutionContext has only 2 threads in Apache Spark CI // create own thread pool for four Futures used in this test val numThreads = 4 - val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool", numThreads) + val fpool = new ForkJoinPool(numThreads) val executionContext = ExecutionContext.fromExecutorService(fpool) val q1 = Future { @@ -200,11 +202,11 @@ class SparkSessionE2ESuite extends RemoteSparkSession { assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") } val e2 = intercept[SparkException] { - ThreadUtils.awaitResult(q2, 1.minute) + awaitResult(q2, 1.minute) } assert(e2.getCause.getMessage contains "OPERATION_CANCELED") val e3 = intercept[SparkException] { - ThreadUtils.awaitResult(q3, 1.minute) + awaitResult(q3, 1.minute) } assert(e3.getCause.getMessage contains "OPERATION_CANCELED") assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") @@ -217,11 +219,11 @@ class SparkSessionE2ESuite extends RemoteSparkSession { assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") } val e1 = intercept[SparkException] { - ThreadUtils.awaitResult(q1, 1.minute) + awaitResult(q1, 1.minute) } assert(e1.getCause.getMessage contains "OPERATION_CANCELED") val e4 = intercept[SparkException] { - ThreadUtils.awaitResult(q4, 1.minute) + awaitResult(q4, 1.minute) } assert(e4.getCause.getMessage contains "OPERATION_CANCELED") assert(interrupted.length == 2, s"Interrupted operations: $interrupted.") diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala index a4f1a61cf39..258fa1e7c74 100644 
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala @@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.JavaConverters._ -import org.apache.spark.TaskContext import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, PrimitiveLongEncoder} import org.apache.spark.sql.connect.client.util.QueryTest @@ -157,11 +156,8 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest { val sum = new AtomicLong() val func: Iterator[JLong] => Unit = f => { f.foreach(v => sum.addAndGet(v)) - TaskContext - .get() - .addTaskCompletionListener(_ => - // The value should be 45 - assert(sum.get() == -1)) + // The value should be 45 + assert(sum.get() == -1) } val exception = intercept[Exception] { spark.range(10).repartition(1).foreachPartition(func) @@ -178,11 +174,8 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest { .foreachPartition(new ForeachPartitionFunction[JLong] { override def call(t: JIterator[JLong]): Unit = { t.asScala.foreach(v => sum.addAndGet(v)) - TaskContext - .get() - .addTaskCompletionListener(_ => - // The value should be 45 - assert(sum.get() == -1)) + // The value should be 45 + assert(sum.get() == -1) } }) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala index 1c4ee217737..684f5671e48 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.connect.client.util.ConnectFunSuite 
import org.apache.spark.sql.connect.common.UdfPacket import org.apache.spark.sql.functions.udf -import org.apache.spark.util.Utils +import org.apache.spark.util.SparkSerDeUtils class UserDefinedFunctionSuite extends ConnectFunSuite with BeforeAndAfterEach { @@ -42,7 +42,7 @@ class UserDefinedFunctionSuite extends ConnectFunSuite with BeforeAndAfterEach { assert(udfObj.getNullable) - val deSer = Utils.deserialize[UdfPacket](udfObj.getPayload.toByteArray) + val deSer = SparkSerDeUtils.deserialize[UdfPacket](udfObj.getPayload.toByteArray) assert(deSer.function.asInstanceOf[Int => Int](5) == func(5)) assert(deSer.outputEncoder == ScalaReflection.encoderFor(typeTag[Int])) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala index c9066615bb5..625d4cf43e1 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala @@ -21,14 +21,14 @@ import java.nio.file.Paths import org.apache.commons.io.FileUtils import org.apache.spark.sql.connect.client.util.ConnectFunSuite -import org.apache.spark.util.Utils +import org.apache.spark.util.SparkFileUtils class ClassFinderSuite extends ConnectFunSuite { private val classResourcePath = commonResourcePath.resolve("artifact-tests") test("REPLClassDirMonitor functionality test") { - val copyDir = Utils.createTempDir().toPath + val copyDir = SparkFileUtils.createTempDir().toPath FileUtils.copyDirectory(classResourcePath.toFile, copyDir.toFile) val monitor = new REPLClassDirMonitor(copyDir.toAbsolutePath.toString) @@ -47,7 +47,7 @@ class ClassFinderSuite extends ConnectFunSuite { checkClasses(monitor) // Add new class file into directory - val subDir = Utils.createTempDir(copyDir.toAbsolutePath.toString) + val 
subDir = SparkFileUtils.createTempDir(copyDir.toAbsolutePath.toString) val classToCopy = copyDir.resolve("Hello.class") val copyLocation = subDir.toPath.resolve("HelloDup.class") FileUtils.copyFile(classToCopy.toFile, copyLocation.toFile) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 7d5311e7312..dd0e9347ac8 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -36,11 +36,11 @@ import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInfere import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, Primi [...] 
import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => toRowEncoder} -import org.apache.spark.sql.catalyst.util.{DateFormatter, StringUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkStringUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE -import org.apache.spark.sql.catalyst.util.IntervalUtils._ import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._ +import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._ import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum import org.apache.spark.sql.connect.client.util.ConnectFunSuite import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StructType, UserDefinedType, YearMonthIntervalType} @@ -900,7 +900,7 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { YearMonthIntervalType.DEFAULT.endField) }) UpCastTestCase(BinaryEncoder, i => Array.tabulate(10)(j => (64 + j + i).toByte)) - .test(StringEncoder, bytes => StringUtils.getHexString(bytes)) + .test(StringEncoder, bytes => SparkStringUtils.getHexString(bytes)) /* ******************************************************************** * * Arrow serialization/deserialization specific errors diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala index 0eaca7577b9..819df5fc25b 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala @@ -23,7 +23,8 @@ import scala.util.Properties.versionNumberString 
import org.scalatest.Assertions.fail -import org.apache.spark.util.Utils +import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION} +import org.apache.spark.util.SparkFileUtils object IntegrationTestUtils { @@ -75,14 +76,14 @@ object IntegrationTestUtils { private[sql] lazy val isSparkHiveJarAvailable: Boolean = { val filePath = s"$sparkHome/assembly/target/$scalaDir/jars/" + - s"spark-hive_$scalaVersion-${org.apache.spark.SPARK_VERSION}.jar" + s"spark-hive_$scalaVersion-$SPARK_VERSION.jar" Files.exists(Paths.get(filePath)) } private[sql] def cleanUpHiveClassesDirIfNeeded(): Unit = { def delete(f: File): Unit = { if (f.exists()) { - Utils.deleteRecursively(f) + SparkFileUtils.deleteRecursively(f) } } delete(new File(s"$sparkHome/sql/hive/target/$scalaDir/classes")) @@ -105,7 +106,7 @@ object IntegrationTestUtils { val jar = tryFindJar(path, sbtName, mvnName, test).getOrElse({ val suffix = if (test) "-tests.jar" else ".jar" val sbtFileName = s"$sbtName(.*)$suffix" - val mvnFileName = s"$mvnName(.*)${org.apache.spark.SPARK_VERSION}$suffix" + val mvnFileName = s"$mvnName(.*)$SPARK_VERSION$suffix" throw new RuntimeException(s"Failed to find the jar: $sbtFileName or $mvnFileName " + s"inside folder: ${getTargetFilePath(path)}. 
This file can be generated by similar to " + s"the following command: build/sbt package|assembly") @@ -136,7 +137,7 @@ object IntegrationTestUtils { // Maven Jar (f.getParent.endsWith("target") && f.getName.startsWith(mvnName) && - f.getName.endsWith(s"${org.apache.spark.SPARK_VERSION}$suffix")) + f.getName.endsWith(s"$SPARK_VERSION$suffix")) } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala index fdbb3edbf84..a0d3d4368dd 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala @@ -22,7 +22,7 @@ import java.util.TimeZone import org.scalatest.Assertions import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.util.sideBySide +import org.apache.spark.sql.catalyst.util.SparkStringUtils.sideBySide abstract class QueryTest extends RemoteSparkSession { @@ -122,7 +122,7 @@ object QueryTest extends Assertions { |${df.analyze} |== Exception == |$e - |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + |${org.apache.spark.util.SparkErrorUtils.stackTraceToString(e)} """.stripMargin return Some(errorMessage) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 62770ae383e..f1ae9684d2a 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.connect.client.util.QueryTest import org.apache.spark.sql.functions.col import 
org.apache.spark.sql.functions.window import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryStartedEvent, QueryTerminatedEvent} -import org.apache.spark.util.Utils +import org.apache.spark.util.SparkFileUtils class StreamingQuerySuite extends QueryTest with SQLHelper with Logging { @@ -359,7 +359,7 @@ class TestForeachWriter[T] extends ForeachWriter[T] { var path: File = _ def open(partitionId: Long, version: Long): Boolean = { - path = Utils.createTempDir() + path = SparkFileUtils.createTempDir() fileWriter = new FileWriter(path, true) true } @@ -371,7 +371,7 @@ class TestForeachWriter[T] extends ForeachWriter[T] { def close(errorOrNull: Throwable): Unit = { fileWriter.close() - Utils.deleteRecursively(path) + SparkFileUtils.deleteRecursively(path) } } diff --git a/connector/connect/common/pom.xml b/connector/connect/common/pom.xml index 1890384b51d..891cda97236 100644 --- a/connector/connect/common/pom.xml +++ b/connector/connect/common/pom.xml @@ -36,15 +36,9 @@ <dependencies> <dependency> <groupId>org.apache.spark</groupId> - <artifactId>spark-catalyst_${scala.binary.version}</artifactId> + <artifactId>spark-sql-api_${scala.binary.version}</artifactId> <version>${project.version}</version> <scope>provided</scope> - <exclusions> - <exclusion> - <groupId>com.google.guava</groupId> - <artifactId>guava</artifactId> - </exclusion> - </exclusions> </dependency> <dependency> <groupId>org.scala-lang</groupId> diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index f2abbf9c723..00546c02bc7 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -32,11 +32,11 @@ import 
com.google.protobuf.ByteString import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.util.{IntervalUtils, SparkDateTimeUtils} +import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils} import org.apache.spark.sql.connect.common.DataTypeProtoConverter._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval -import org.apache.spark.util.Utils +import org.apache.spark.util.SparkClassUtils object LiteralValueProtoConverter { @@ -93,8 +93,8 @@ object LiteralValueProtoConverter { case v: LocalDateTime => builder.setTimestampNtz(SparkDateTimeUtils.localDateTimeToMicros(v)) case v: Date => builder.setDate(SparkDateTimeUtils.fromJavaDate(v)) - case v: Duration => builder.setDayTimeInterval(IntervalUtils.durationToMicros(v)) - case v: Period => builder.setYearMonthInterval(IntervalUtils.periodToMonths(v)) + case v: Duration => builder.setDayTimeInterval(SparkIntervalUtils.durationToMicros(v)) + case v: Period => builder.setYearMonthInterval(SparkIntervalUtils.periodToMonths(v)) case v: Array[_] => builder.setArray(arrayBuilder(v)) case v: CalendarInterval => builder.setCalendarInterval(calendarIntervalBuilder(v.months, v.days, v.microseconds)) @@ -279,10 +279,10 @@ object LiteralValueProtoConverter { literal.getCalendarInterval.getMicroseconds) case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL => - IntervalUtils.monthsToPeriod(literal.getYearMonthInterval) + SparkIntervalUtils.monthsToPeriod(literal.getYearMonthInterval) case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL => - IntervalUtils.microsToDuration(literal.getDayTimeInterval) + SparkIntervalUtils.microsToDuration(literal.getDayTimeInterval) case proto.Expression.Literal.LiteralTypeCase.ARRAY => toCatalystArray(literal.getArray) @@ -376,7 +376,7 @@ object LiteralValueProtoConverter { def toCatalystStruct(struct: proto.Expression.Literal.Struct): Any = { 
def toTuple[A <: Object](data: Seq[A]): Product = { try { - val tupleClass = Utils.classForName(s"scala.Tuple${data.length}") + val tupleClass = SparkClassUtils.classForName(s"scala.Tuple${data.length}") tupleClass.getConstructors.head.newInstance(data: _*).asInstanceOf[Product] } catch { case _: Exception => diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2d61b1b6305..a3002eb40f4 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -91,7 +91,11 @@ private[spark] object CallSite { /** * Various utility methods used by Spark. */ -private[spark] object Utils extends Logging with SparkClassUtils { +private[spark] object Utils + extends Logging + with SparkClassUtils + with SparkErrorUtils + with SparkFileUtils { private val sparkUncaughtExceptionHandler = new SparkUncaughtExceptionHandler @volatile private var cachedLocalDir: String = "" @@ -119,16 +123,10 @@ private[spark] object Utils extends Logging with SparkClassUtils { }) /** Serialize an object using Java serialization */ - def serialize[T](o: T): Array[Byte] = { - SparkSerDeUtils.serialize(o) - } + def serialize[T](o: T): Array[Byte] = SparkSerDeUtils.serialize(o) /** Deserialize an object using Java serialization */ - def deserialize[T](bytes: Array[Byte]): T = { - val bis = new ByteArrayInputStream(bytes) - val ois = new ObjectInputStream(bis) - ois.readObject.asInstanceOf[T] - } + def deserialize[T](bytes: Array[Byte]): T = SparkSerDeUtils.deserialize(bytes) /** Deserialize an object using Java serialization and the given ClassLoader */ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { @@ -252,48 +250,11 @@ private[spark] object Utils extends Logging with SparkClassUtils { file.setExecutable(true, true) } - /** - * Create a directory given the abstract pathname - * @return true, if the directory is successfully created; otherwise, 
return false. - */ - def createDirectory(dir: File): Boolean = { - try { - // SPARK-35907: The check was required by File.mkdirs() because it could sporadically - // fail silently. After switching to Files.createDirectories(), ideally, there should - // no longer be silent fails. But the check is kept for the safety concern. We can - // remove the check when we're sure that Files.createDirectories() would never fail silently. - Files.createDirectories(dir.toPath) - if ( !dir.exists() || !dir.isDirectory) { - logError(s"Failed to create directory " + dir) - } - dir.isDirectory - } catch { - case e: Exception => - logError(s"Failed to create directory " + dir, e) - false - } - } - - /** - * Create a directory inside the given parent directory. The directory is guaranteed to be - * newly created, and is not marked for automatic deletion. - */ - def createDirectory(root: String, namePrefix: String = "spark"): File = { - JavaUtils.createDirectory(root, namePrefix) - } - - /** - * Create a temporary directory inside the `java.io.tmpdir` prefixed with `spark`. - * The directory will be automatically deleted when the VM shuts down. - */ - def createTempDir(): File = - createTempDir(System.getProperty("java.io.tmpdir"), "spark") - /** * Create a temporary directory inside the given parent directory. The directory will be * automatically deleted when the VM shuts down. */ - def createTempDir( + override def createTempDir( root: String = System.getProperty("java.io.tmpdir"), namePrefix: String = "spark"): File = { val dir = createDirectory(root, namePrefix) @@ -1175,30 +1136,14 @@ private[spark] object Utils extends Logging with SparkClassUtils { s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms" } - /** - * Lists files recursively. 
- */ - def recursiveList(f: File): Array[File] = { - require(f.isDirectory) - val result = f.listFiles.toBuffer - val dirList = result.filter(_.isDirectory) - while (dirList.nonEmpty) { - val curDir = dirList.remove(0) - val files = curDir.listFiles() - result ++= files - dirList ++= files.filter(_.isDirectory) - } - result.toArray - } - /** * Delete a file or directory and its contents recursively. * Don't follow directories if they are symlinks. * Throws an exception if deletion is unsuccessful. */ - def deleteRecursively(file: File): Unit = { + override def deleteRecursively(file: File): Unit = { + super.deleteRecursively(file) if (file != null) { - JavaUtils.deleteRecursively(file) ShutdownHookManager.removeShutdownDeleteDir(file) } } @@ -1443,16 +1388,6 @@ private[spark] object Utils extends Logging with SparkClassUtils { } } - /** - * Execute a block of code that returns a value, re-throwing any non-fatal uncaught - * exceptions as IOException. This is used when implementing Externalizable and Serializable's - * read and write methods, since Java's serializer will not report non-IOExceptions properly; - * see SPARK-4080 for more context. - */ - def tryOrIOException[T](block: => T): T = { - SparkErrorUtils.tryOrIOException(block) - } - /** Executes the given block. Log non-fatal errors if any, and only throw fatal errors */ def tryLogNonFatalError(block: => Unit): Unit = { try { @@ -1463,38 +1398,6 @@ private[spark] object Utils extends Logging with SparkClassUtils { } } - /** - * Execute a block of code, then a finally block, but if exceptions happen in - * the finally block, do not suppress the original exception. - * - * This is primarily an issue with `finally { out.close() }` blocks, where - * close needs to be called to clean up `out`, but if an exception happened - * in `out.write`, it's likely `out` may be corrupted and `out.close` will - * fail as well. 
This would then suppress the original/likely more meaningful - * exception from the original `out.write` call. - */ - def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = { - var originalThrowable: Throwable = null - try { - block - } catch { - case t: Throwable => - // Purposefully not using NonFatal, because even fatal exceptions - // we don't want to have our finallyBlock suppress - originalThrowable = t - throw originalThrowable - } finally { - try { - finallyBlock - } catch { - case t: Throwable if (originalThrowable != null && originalThrowable != t) => - originalThrowable.addSuppressed(t) - logWarning(s"Suppressing exception in finally: ${t.getMessage}", t) - throw originalThrowable - } - } - } - /** * Execute a block of code and call the failure callbacks in the catch block. If exceptions occur * in either the catch or the finally block, they are appended to the list of suppressed @@ -2079,16 +1982,6 @@ private[spark] object Utils extends Logging with SparkClassUtils { } } - /** - * Return a well-formed URI for the file described by a user input string. - * - * If the supplied path does not contain a scheme, or is a relative path, it will be - * converted into an absolute path with a file:// scheme. - */ - def resolveURI(path: String): URI = { - SparkFileUtils.resolveURI(path) - } - /** Resolve a comma-separated list of paths. */ def resolveURIs(paths: String): String = { if (paths == null || paths.trim.isEmpty) { @@ -2740,11 +2633,6 @@ private[spark] object Utils extends Logging with SparkClassUtils { initialExecutors } - def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = { - val resource = createResource - try f.apply(resource) finally resource.close() - } - /** * Returns a path of temporary file which is in the same directory with `path`. 
*/ diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalGroupState.scala similarity index 69% copy from common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala copy to sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalGroupState.scala index 88d1d6bdba8..40ec411aac6 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalGroupState.scala @@ -14,17 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.util +package org.apache.spark.sql.catalyst.plans.logical -import java.io.{ByteArrayOutputStream, ObjectOutputStream} - -object SparkSerDeUtils { - /** Serialize an object using Java serialization */ - def serialize[T](o: T): Array[Byte] = { - val bos = new ByteArrayOutputStream() - val oos = new ObjectOutputStream(bos) - oos.writeObject(o) - oos.close() - bos.toByteArray - } -} +/** Internal class representing State */ +trait LogicalGroupState[S] diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkIntervalUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkIntervalUtils.scala index 05ceb04f12b..f9c132ebf37 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkIntervalUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkIntervalUtils.scala @@ -16,10 +16,113 @@ */ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.util.DateTimeConstants.{DAYS_PER_WEEK, MICROS_PER_HOUR, MICROS_PER_MINUTE, MICROS_PER_SECOND, MONTHS_PER_YEAR, NANOS_PER_MICROS, NANOS_PER_SECOND} +import java.time.{Duration, Period} +import java.time.temporal.ChronoUnit + +import scala.collection.mutable + +import 
org.apache.spark.sql.catalyst.util.DateTimeConstants._ +import org.apache.spark.sql.catalyst.util.IntervalStringStyles.{ANSI_STYLE, HIVE_STYLE, IntervalStyle} +import org.apache.spark.sql.types.{DayTimeIntervalType => DT, YearMonthIntervalType => YM} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} trait SparkIntervalUtils { + protected val MAX_DAY: Long = Long.MaxValue / MICROS_PER_DAY + protected val MAX_HOUR: Long = Long.MaxValue / MICROS_PER_HOUR + protected val MAX_MINUTE: Long = Long.MaxValue / MICROS_PER_MINUTE + protected val MAX_SECOND: Long = Long.MaxValue / MICROS_PER_SECOND + protected val MIN_SECOND: Long = Long.MinValue / MICROS_PER_SECOND + + // The amount of seconds that can cause overflow in the conversion to microseconds + private final val minDurationSeconds = Math.floorDiv(Long.MinValue, MICROS_PER_SECOND) + + /** + * Converts this duration to the total length in microseconds. + * <p> + * If this duration is too large to fit in a [[Long]] microseconds, then an + * exception is thrown. + * <p> + * If this duration has greater than microsecond precision, then the conversion + * will drop any excess precision information as though the amount in nanoseconds + * was subject to integer division by one thousand. 
+ * + * @return The total length of the duration in microseconds + * @throws ArithmeticException If numeric overflow occurs + */ + def durationToMicros(duration: Duration): Long = { + durationToMicros(duration, DT.SECOND) + } + + def durationToMicros(duration: Duration, endField: Byte): Long = { + val seconds = duration.getSeconds + val micros = if (seconds == minDurationSeconds) { + val microsInSeconds = (minDurationSeconds + 1) * MICROS_PER_SECOND + val nanoAdjustment = duration.getNano + assert(0 <= nanoAdjustment && nanoAdjustment < NANOS_PER_SECOND, + "Duration.getNano() must return the adjustment to the seconds field " + + "in the range from 0 to 999999999 nanoseconds, inclusive.") + Math.addExact(microsInSeconds, (nanoAdjustment - NANOS_PER_SECOND) / NANOS_PER_MICROS) + } else { + val microsInSeconds = Math.multiplyExact(seconds, MICROS_PER_SECOND) + Math.addExact(microsInSeconds, duration.getNano / NANOS_PER_MICROS) + } + + endField match { + case DT.DAY => micros - micros % MICROS_PER_DAY + case DT.HOUR => micros - micros % MICROS_PER_HOUR + case DT.MINUTE => micros - micros % MICROS_PER_MINUTE + case DT.SECOND => micros + } + } + + /** + * Gets the total number of months in this period. + * <p> + * This returns the total number of months in the period by multiplying the + * number of years by 12 and adding the number of months. + * <p> + * + * @return The total number of months in the period, may be negative + * @throws ArithmeticException If numeric overflow occurs + */ + def periodToMonths(period: Period): Int = { + periodToMonths(period, YM.MONTH) + } + + def periodToMonths(period: Period, endField: Byte): Int = { + val monthsInYears = Math.multiplyExact(period.getYears, MONTHS_PER_YEAR) + val months = Math.addExact(monthsInYears, period.getMonths) + if (endField == YM.YEAR) { + months - months % MONTHS_PER_YEAR + } else { + months + } + } + + /** + * Obtains a [[Duration]] representing a number of microseconds. 
+ * + * @param micros The number of microseconds, positive or negative + * @return A [[Duration]], not null + */ + def microsToDuration(micros: Long): Duration = Duration.of(micros, ChronoUnit.MICROS) + + /** + * Obtains a [[Period]] representing a number of months. The days unit will be zero, and the years + * and months units will be normalized. + * + * <p> + * The months unit is adjusted to have an absolute value < 12, with the years unit being adjusted + * to compensate. For example, the method returns "2 years and 3 months" for the 27 input months. + * <p> + * The sign of the years and months units will be the same after normalization. + * For example, -13 months will be converted to "-1 year and -1 month". + * + * @param months The number of months, positive or negative + * @return The period of months, not null + */ + def monthsToPeriod(months: Int): Period = Period.ofMonths(months).normalized() + /** * Converts a string to [[CalendarInterval]] case-insensitively. * @@ -226,22 +329,176 @@ trait SparkIntervalUtils { result } + /** + * Converts an year-month interval as a number of months to its textual representation + * which conforms to the ANSI SQL standard. + * + * @param months The number of months, positive or negative + * @param style The style of textual representation of the interval + * @param startField The start field (YEAR or MONTH) which the interval comprises of. + * @param endField The end field (YEAR or MONTH) which the interval comprises of. 
+ * @return Year-month interval string + */ + def toYearMonthIntervalString( + months: Int, + style: IntervalStyle, + startField: Byte, + endField: Byte): String = { + var sign = "" + var absMonths: Long = months + if (months < 0) { + sign = "-" + absMonths = -absMonths + } + val year = s"$sign${absMonths / MONTHS_PER_YEAR}" + val yearAndMonth = s"$year-${absMonths % MONTHS_PER_YEAR}" + style match { + case ANSI_STYLE => + val formatBuilder = new StringBuilder("INTERVAL '") + if (startField == endField) { + startField match { + case YM.YEAR => formatBuilder.append(s"$year' YEAR") + case YM.MONTH => formatBuilder.append(s"$months' MONTH") + } + } else { + formatBuilder.append(s"$yearAndMonth' YEAR TO MONTH") + } + formatBuilder.toString + case HIVE_STYLE => s"$yearAndMonth" + } + } + + /** + * Converts a day-time interval as a number of microseconds to its textual representation + * which conforms to the ANSI SQL standard. + * + * @param micros The number of microseconds, positive or negative + * @param style The style of textual representation of the interval + * @param startField The start field (DAY, HOUR, MINUTE, SECOND) which the interval comprises of. + * @param endField The end field (DAY, HOUR, MINUTE, SECOND) which the interval comprises of. + * @return Day-time interval string + */ + def toDayTimeIntervalString( + micros: Long, + style: IntervalStyle, + startField: Byte, + endField: Byte): String = { + var sign = "" + var rest = micros + // scalastyle:off caselocale + val from = DT.fieldToString(startField).toUpperCase + val to = DT.fieldToString(endField).toUpperCase + // scalastyle:on caselocale + val prefix = "INTERVAL '" + val postfix = s"' ${if (startField == endField) from else s"$from TO $to"}" + + if (micros < 0) { + if (micros == Long.MinValue) { + // Especial handling of minimum `Long` value because negate op overflows `Long`. 
+ // seconds = 106751991 * (24 * 60 * 60) + 4 * 60 * 60 + 54 = 9223372036854 + // microseconds = -9223372036854000000L-775808 == Long.MinValue + val baseStr = "-106751991 04:00:54.775808000" + val minIntervalString = style match { + case ANSI_STYLE => + val firstStr = startField match { + case DT.DAY => s"-$MAX_DAY" + case DT.HOUR => s"-$MAX_HOUR" + case DT.MINUTE => s"-$MAX_MINUTE" + case DT.SECOND => s"-$MAX_SECOND.775808" + } + val followingStr = if (startField == endField) { + "" + } else { + val substrStart = startField match { + case DT.DAY => 10 + case DT.HOUR => 13 + case DT.MINUTE => 16 + } + val substrEnd = endField match { + case DT.HOUR => 13 + case DT.MINUTE => 16 + case DT.SECOND => 26 + } + baseStr.substring(substrStart, substrEnd) + } + + s"$prefix$firstStr$followingStr$postfix" + case HIVE_STYLE => baseStr + } + return minIntervalString + } else { + sign = "-" + rest = -rest + } + } + val intervalString = style match { + case ANSI_STYLE => + val formatBuilder = new mutable.StringBuilder(sign) + val formatArgs = new mutable.ArrayBuffer[Long]() + startField match { + case DT.DAY => + formatBuilder.append(rest / MICROS_PER_DAY) + rest %= MICROS_PER_DAY + case DT.HOUR => + formatBuilder.append("%02d") + formatArgs.append(rest / MICROS_PER_HOUR) + rest %= MICROS_PER_HOUR + case DT.MINUTE => + formatBuilder.append("%02d") + formatArgs.append(rest / MICROS_PER_MINUTE) + rest %= MICROS_PER_MINUTE + case DT.SECOND => + val leadZero = if (rest < 10 * MICROS_PER_SECOND) "0" else "" + formatBuilder.append(s"$leadZero" + + s"${java.math.BigDecimal.valueOf(rest, 6).stripTrailingZeros.toPlainString}") + } + + if (startField < DT.HOUR && DT.HOUR <= endField) { + formatBuilder.append(" %02d") + formatArgs.append(rest / MICROS_PER_HOUR) + rest %= MICROS_PER_HOUR + } + if (startField < DT.MINUTE && DT.MINUTE <= endField) { + formatBuilder.append(":%02d") + formatArgs.append(rest / MICROS_PER_MINUTE) + rest %= MICROS_PER_MINUTE + } + if (startField < DT.SECOND && 
DT.SECOND <= endField) { + val leadZero = if (rest < 10 * MICROS_PER_SECOND) "0" else "" + formatBuilder.append( + s":$leadZero${java.math.BigDecimal.valueOf(rest, 6).stripTrailingZeros.toPlainString}") + } + s"$prefix${formatBuilder.toString.format(formatArgs.toSeq: _*)}$postfix" + case HIVE_STYLE => + val secondsWithFraction = rest % MICROS_PER_MINUTE + rest /= MICROS_PER_MINUTE + val minutes = rest % MINUTES_PER_HOUR + rest /= MINUTES_PER_HOUR + val hours = rest % HOURS_PER_DAY + val days = rest / HOURS_PER_DAY + val seconds = secondsWithFraction / MICROS_PER_SECOND + val nanos = (secondsWithFraction % MICROS_PER_SECOND) * NANOS_PER_MICROS + f"$sign$days $hours%02d:$minutes%02d:$seconds%02d.$nanos%09d" + } + intervalString + } + protected def unitToUtf8(unit: String): UTF8String = { UTF8String.fromString(unit) } - protected val intervalStr = unitToUtf8("interval") + protected val intervalStr: UTF8String = unitToUtf8("interval") - protected val yearStr = unitToUtf8("year") - protected val monthStr = unitToUtf8("month") - protected val weekStr = unitToUtf8("week") - protected val dayStr = unitToUtf8("day") - protected val hourStr = unitToUtf8("hour") - protected val minuteStr = unitToUtf8("minute") - protected val secondStr = unitToUtf8("second") - protected val millisStr = unitToUtf8("millisecond") - protected val microsStr = unitToUtf8("microsecond") - protected val nanosStr = unitToUtf8("nanosecond") + protected val yearStr: UTF8String = unitToUtf8("year") + protected val monthStr: UTF8String = unitToUtf8("month") + protected val weekStr: UTF8String = unitToUtf8("week") + protected val dayStr: UTF8String = unitToUtf8("day") + protected val hourStr: UTF8String = unitToUtf8("hour") + protected val minuteStr: UTF8String = unitToUtf8("minute") + protected val secondStr: UTF8String = unitToUtf8("second") + protected val millisStr: UTF8String = unitToUtf8("millisecond") + protected val microsStr: UTF8String = unitToUtf8("microsecond") + protected val nanosStr: 
UTF8String = unitToUtf8("nanosecond") private object ParseState extends Enumeration { @@ -261,3 +518,9 @@ trait SparkIntervalUtils { } object SparkIntervalUtils extends SparkIntervalUtils + +// The style of textual representation of intervals +object IntervalStringStyles extends Enumeration { + type IntervalStyle = Value + val ANSI_STYLE, HIVE_STYLE = Value +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index 20fb8bb94bd..3ac81b94358 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -105,4 +105,24 @@ object SparkStringUtils extends Logging { // identifier with back-ticks. "`" + name.replace("`", "``") + "`" } + + /** + * Returns a pretty string of the byte array which prints each byte as a hex digit and add spaces + * between them. For example, [1A C0]. 
+ */ + def getHexString(bytes: Array[Byte]): String = bytes.map("%02X".format(_)).mkString("[", " ", "]") + + def sideBySide(left: String, right: String): Seq[String] = { + sideBySide(left.split("\n"), right.split("\n")) + } + + def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = { + val maxLeftSize = left.map(_.length).max + val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("") + val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("") + + leftPadded.zip(rightPadded).map { + case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.length) + 3)) + r + } + } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/CompilationErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/CompilationErrors.scala new file mode 100644 index 00000000000..deae1198d9c --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/CompilationErrors.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.errors + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.internal.SqlApiConf + +private[sql] trait CompilationErrors extends DataTypeErrorsBase { + def ambiguousColumnOrFieldError(name: Seq[String], numMatches: Int): AnalysisException = { + new AnalysisException( + errorClass = "AMBIGUOUS_COLUMN_OR_FIELD", + messageParameters = Map( + "name" -> toSQLId(name), + "n" -> numMatches.toString)) + } + + def columnNotFoundError(colName: String): AnalysisException = { + new AnalysisException( + errorClass = "COLUMN_NOT_FOUND", + messageParameters = Map( + "colName" -> toSQLId(colName), + "caseSensitiveConfig" -> toSQLConf(SqlApiConf.CASE_SENSITIVE_KEY))) + } + + def descriptorParseError(cause: Throwable): AnalysisException = { + new AnalysisException( + errorClass = "CANNOT_PARSE_PROTOBUF_DESCRIPTOR", + messageParameters = Map.empty, + cause = Option(cause)) + } + + def cannotFindDescriptorFileError(filePath: String, cause: Throwable): AnalysisException = { + new AnalysisException( + errorClass = "PROTOBUF_DESCRIPTOR_FILE_NOT_FOUND", + messageParameters = Map("filePath" -> filePath), + cause = Option(cause)) + } +} + +object CompilationErrors extends CompilationErrors diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala index bdc4a5759a1..d746e9037ec 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala @@ -50,6 +50,9 @@ private[sql] object SqlApiConf { // Shared keys. 
val ANSI_ENABLED_KEY: String = "spark.sql.ansi.enabled" val LEGACY_TIME_PARSER_POLICY_KEY: String = "spark.sql.legacy.timeParserPolicy" + val CASE_SENSITIVE_KEY: String = "spark.sql.caseSensitive" + val SESSION_LOCAL_TIMEZONE_KEY: String = "spark.sql.session.timeZone" + val LOCAL_RELATION_CACHE_THRESHOLD_KEY: String = "spark.sql.session.localRelationCacheThreshold" /** * Defines a getter that returns the [[SqlApiConf]] within scope. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala new file mode 100644 index 00000000000..8b091b51812 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.types + +import scala.collection.immutable.IndexedSeq + +/** + * Rule that defines which upcasts are allow in Spark. + */ +private[sql] object UpCastRule { + // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. 
+ // The conversion for integral and floating point types have a linear widening hierarchy: + val numericPrecedence: IndexedSeq[NumericType] = IndexedSeq( + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType) + + /** + * Returns true iff we can safely up-cast the `from` type to `to` type without any truncating or + * precision lose or possible runtime failures. For example, long to int, string to int are not + * up-casts. + */ + def canUpCast(from: DataType, to: DataType): Boolean = (from, to) match { + case _ if from == to => true + case (from: NumericType, to: DecimalType) if to.isWiderThan(from) => true + case (from: DecimalType, to: NumericType) if from.isTighterThan(to) => true + case (f, t) if legalNumericPrecedence(f, t) => true + case (DateType, TimestampType) => true + case (DateType, TimestampNTZType) => true + case (TimestampNTZType, TimestampType) => true + case (TimestampType, TimestampNTZType) => true + case (_: AtomicType, StringType) => true + case (_: CalendarIntervalType, StringType) => true + case (NullType, _) => true + + // Spark supports casting between long and timestamp, please see `longToTimestamp` and + // `timestampToLong` for details. 
+ case (TimestampType, LongType) => true + case (LongType, TimestampType) => true + + case (ArrayType(fromType, fn), ArrayType(toType, tn)) => + resolvableNullability(fn, tn) && canUpCast(fromType, toType) + + case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + resolvableNullability(fn, tn) && canUpCast(fromKey, toKey) && canUpCast(fromValue, toValue) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { + case (f1, f2) => + resolvableNullability(f1.nullable, f2.nullable) && canUpCast(f1.dataType, f2.dataType) + } + + case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true + case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true + + case (from: UserDefinedType[_], to: UserDefinedType[_]) if to.acceptsType(from) => true + + case _ => false + } + + private def legalNumericPrecedence(from: DataType, to: DataType): Boolean = { + val fromPrecedence = numericPrecedence.indexOf(from) + val toPrecedence = numericPrecedence.indexOf(to) + fromPrecedence >= 0 && fromPrecedence < toPrecedence + } + + private def resolvableNullability(from: Boolean, to: Boolean): Boolean = !from || to +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 5854f42a061..8857f0b5a25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ +import 
org.apache.spark.sql.types.UpCastRule.numericPrecedence /** * In Spark ANSI mode, the type coercion rules are based on the type precedence lists of the input diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 2a1067be004..bf9e461744e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.UpCastRule.numericPrecedence abstract class TypeCoercionBase { /** @@ -856,17 +857,6 @@ object TypeCoercion extends TypeCoercionBase { override def canCast(from: DataType, to: DataType): Boolean = Cast.canCast(from, to) - // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. 
- // The conversion for integral and floating point types have a linear widening hierarchy: - val numericPrecedence = - IndexedSeq( - ByteType, - ShortType, - IntegerType, - LongType, - FloatType, - DoubleType) - override val findTightestCommonType: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 72df6c33ecd..24ade61c121 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -23,7 +23,7 @@ import java.util.concurrent.TimeUnit._ import org.apache.spark.SparkArithmeticException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ @@ -289,44 +289,7 @@ object Cast extends QueryErrorsBase { * precision lose or possible runtime failures. For example, long -> int, string -> int are not * up-cast. 
*/ - def canUpCast(from: DataType, to: DataType): Boolean = (from, to) match { - case _ if from == to => true - case (from: NumericType, to: DecimalType) if to.isWiderThan(from) => true - case (from: DecimalType, to: NumericType) if from.isTighterThan(to) => true - case (f, t) if legalNumericPrecedence(f, t) => true - case (DateType, TimestampType) => true - case (DateType, TimestampNTZType) => true - case (TimestampNTZType, TimestampType) => true - case (TimestampType, TimestampNTZType) => true - case (_: AtomicType, StringType) => true - case (_: CalendarIntervalType, StringType) => true - case (NullType, _) => true - - // Spark supports casting between long and timestamp, please see `longToTimestamp` and - // `timestampToLong` for details. - case (TimestampType, LongType) => true - case (LongType, TimestampType) => true - - case (ArrayType(fromType, fn), ArrayType(toType, tn)) => - resolvableNullability(fn, tn) && canUpCast(fromType, toType) - - case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => - resolvableNullability(fn, tn) && canUpCast(fromKey, toKey) && canUpCast(fromValue, toValue) - - case (StructType(fromFields), StructType(toFields)) => - fromFields.length == toFields.length && - fromFields.zip(toFields).forall { - case (f1, f2) => - resolvableNullability(f1.nullable, f2.nullable) && canUpCast(f1.dataType, f2.dataType) - } - - case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true - case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true - - case (from: UserDefinedType[_], to: UserDefinedType[_]) if to.acceptsType(from) => true - - case _ => false - } + def canUpCast(from: DataType, to: DataType): Boolean = UpCastRule.canUpCast(from, to) /** * Returns true iff we can cast the `from` type to `to` type as per the ANSI SQL. 
@@ -359,12 +322,6 @@ object Cast extends QueryErrorsBase { case _ => false } - private def legalNumericPrecedence(from: DataType, to: DataType): Boolean = { - val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from) - val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to) - fromPrecedence >= 0 && fromPrecedence < toPrecedence - } - def canNullSafeCastToDecimal(from: DataType, to: DecimalType): Boolean = from match { case from: BooleanType if to.isWiderThan(DecimalType.BooleanDecimal) => true case from: NumericType if to.isWiderThan(from) => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala index ab9f451e88f..f903863bec6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala @@ -22,7 +22,7 @@ import java.time.ZoneOffset import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayData, DateFormatter, IntervalStringStyles, IntervalUtils, MapData, StringUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.{ArrayData, DateFormatter, IntervalStringStyles, IntervalUtils, MapData, SparkStringUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder @@ -53,7 +53,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression => case CalendarIntervalType => acceptAny[CalendarInterval](i => UTF8String.fromString(i.toString)) case BinaryType if useHexFormatForBinary => - acceptAny[Array[Byte]](binary => UTF8String.fromString(StringUtils.getHexString(binary))) + 
acceptAny[Array[Byte]](binary => UTF8String.fromString(SparkStringUtils.getHexString(binary))) case BinaryType => acceptAny[Array[Byte]](UTF8String.fromBytes) case DateType => @@ -173,7 +173,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression => from match { case BinaryType if useHexFormatForBinary => (c, evPrim) => - val utilCls = StringUtils.getClass.getName.stripSuffix("$") + val utilCls = SparkStringUtils.getClass.getName.stripSuffix("$") code"$evPrim = UTF8String.fromString($utilCls.getHexString($c));" case BinaryType => (c, evPrim) => code"$evPrim = UTF8String.fromBytes($c);" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 0abbbae93c6..d4851019db8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -448,9 +448,6 @@ case class MapGroups( copy(child = newChild) } -/** Internal class representing State */ -trait LogicalGroupState[S] - /** Factory for constructing new `MapGroupsWithState` nodes. 
*/ object FlatMapGroupsWithState { def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder]( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index 24620e692a0..e051cfc37f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -17,18 +17,14 @@ package org.apache.spark.sql.catalyst.util -import java.time.{Duration, Period} -import java.time.temporal.ChronoUnit import java.util.Locale import java.util.concurrent.TimeUnit -import scala.collection.mutable import scala.util.control.NonFatal import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.DateTimeConstants._ -import org.apache.spark.sql.catalyst.util.IntervalStringStyles.{ANSI_STYLE, HIVE_STYLE, IntervalStyle} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -37,20 +33,8 @@ import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} -// The style of textual representation of intervals -object IntervalStringStyles extends Enumeration { - type IntervalStyle = Value - val ANSI_STYLE, HIVE_STYLE = Value -} - object IntervalUtils extends SparkIntervalUtils { - private val MAX_DAY = Long.MaxValue / MICROS_PER_DAY - private val MAX_HOUR = Long.MaxValue / MICROS_PER_HOUR - private val MAX_MINUTE = Long.MaxValue / MICROS_PER_MINUTE - private val MAX_SECOND = Long.MaxValue / MICROS_PER_SECOND - private val MIN_SECOND = Long.MinValue / MICROS_PER_SECOND - def getYears(months: Int): Int = 
months / MONTHS_PER_YEAR def getYears(interval: CalendarInterval): Int = getYears(interval.months) @@ -781,250 +765,6 @@ object IntervalUtils extends SparkIntervalUtils { micros } - // The amount of seconds that can cause overflow in the conversion to microseconds - private final val minDurationSeconds = Math.floorDiv(Long.MinValue, MICROS_PER_SECOND) - - /** - * Converts this duration to the total length in microseconds. - * <p> - * If this duration is too large to fit in a [[Long]] microseconds, then an - * exception is thrown. - * <p> - * If this duration has greater than microsecond precision, then the conversion - * will drop any excess precision information as though the amount in nanoseconds - * was subject to integer division by one thousand. - * - * @return The total length of the duration in microseconds - * @throws ArithmeticException If numeric overflow occurs - */ - def durationToMicros(duration: Duration): Long = { - durationToMicros(duration, DT.SECOND) - } - - def durationToMicros(duration: Duration, endField: Byte): Long = { - val seconds = duration.getSeconds - val micros = if (seconds == minDurationSeconds) { - val microsInSeconds = (minDurationSeconds + 1) * MICROS_PER_SECOND - val nanoAdjustment = duration.getNano - assert(0 <= nanoAdjustment && nanoAdjustment < NANOS_PER_SECOND, - "Duration.getNano() must return the adjustment to the seconds field " + - "in the range from 0 to 999999999 nanoseconds, inclusive.") - Math.addExact(microsInSeconds, (nanoAdjustment - NANOS_PER_SECOND) / NANOS_PER_MICROS) - } else { - val microsInSeconds = Math.multiplyExact(seconds, MICROS_PER_SECOND) - Math.addExact(microsInSeconds, duration.getNano / NANOS_PER_MICROS) - } - - endField match { - case DT.DAY => micros - micros % MICROS_PER_DAY - case DT.HOUR => micros - micros % MICROS_PER_HOUR - case DT.MINUTE => micros - micros % MICROS_PER_MINUTE - case DT.SECOND => micros - } - } - - /** - * Obtains a [[Duration]] representing a number of microseconds. 
- * - * @param micros The number of microseconds, positive or negative - * @return A [[Duration]], not null - */ - def microsToDuration(micros: Long): Duration = Duration.of(micros, ChronoUnit.MICROS) - - /** - * Gets the total number of months in this period. - * <p> - * This returns the total number of months in the period by multiplying the - * number of years by 12 and adding the number of months. - * <p> - * - * @return The total number of months in the period, may be negative - * @throws ArithmeticException If numeric overflow occurs - */ - def periodToMonths(period: Period): Int = { - periodToMonths(period, YM.MONTH) - } - - def periodToMonths(period: Period, endField: Byte): Int = { - val monthsInYears = Math.multiplyExact(period.getYears, MONTHS_PER_YEAR) - val months = Math.addExact(monthsInYears, period.getMonths) - if (endField == YM.YEAR) { - months - months % MONTHS_PER_YEAR - } else { - months - } - } - - /** - * Obtains a [[Period]] representing a number of months. The days unit will be zero, and the years - * and months units will be normalized. - * - * <p> - * The months unit is adjusted to have an absolute value < 12, with the years unit being adjusted - * to compensate. For example, the method returns "2 years and 3 months" for the 27 input months. - * <p> - * The sign of the years and months units will be the same after normalization. - * For example, -13 months will be converted to "-1 year and -1 month". - * - * @param months The number of months, positive or negative - * @return The period of months, not null - */ - def monthsToPeriod(months: Int): Period = Period.ofMonths(months).normalized() - - /** - * Converts an year-month interval as a number of months to its textual representation - * which conforms to the ANSI SQL standard. 
- * - * @param months The number of months, positive or negative - * @param style The style of textual representation of the interval - * @param startField The start field (YEAR or MONTH) which the interval comprises of. - * @param endField The end field (YEAR or MONTH) which the interval comprises of. - * @return Year-month interval string - */ - def toYearMonthIntervalString( - months: Int, - style: IntervalStyle, - startField: Byte, - endField: Byte): String = { - var sign = "" - var absMonths: Long = months - if (months < 0) { - sign = "-" - absMonths = -absMonths - } - val year = s"$sign${absMonths / MONTHS_PER_YEAR}" - val yearAndMonth = s"$year-${absMonths % MONTHS_PER_YEAR}" - style match { - case ANSI_STYLE => - val formatBuilder = new StringBuilder("INTERVAL '") - if (startField == endField) { - startField match { - case YM.YEAR => formatBuilder.append(s"$year' YEAR") - case YM.MONTH => formatBuilder.append(s"$months' MONTH") - } - } else { - formatBuilder.append(s"$yearAndMonth' YEAR TO MONTH") - } - formatBuilder.toString - case HIVE_STYLE => s"$yearAndMonth" - } - } - - /** - * Converts a day-time interval as a number of microseconds to its textual representation - * which conforms to the ANSI SQL standard. - * - * @param micros The number of microseconds, positive or negative - * @param style The style of textual representation of the interval - * @param startField The start field (DAY, HOUR, MINUTE, SECOND) which the interval comprises of. - * @param endField The end field (DAY, HOUR, MINUTE, SECOND) which the interval comprises of. 
- * @return Day-time interval string - */ - def toDayTimeIntervalString( - micros: Long, - style: IntervalStyle, - startField: Byte, - endField: Byte): String = { - var sign = "" - var rest = micros - // scalastyle:off caselocale - val from = DT.fieldToString(startField).toUpperCase - val to = DT.fieldToString(endField).toUpperCase - // scalastyle:on caselocale - val prefix = "INTERVAL '" - val postfix = s"' ${if (startField == endField) from else s"$from TO $to"}" - - if (micros < 0) { - if (micros == Long.MinValue) { - // Especial handling of minimum `Long` value because negate op overflows `Long`. - // seconds = 106751991 * (24 * 60 * 60) + 4 * 60 * 60 + 54 = 9223372036854 - // microseconds = -9223372036854000000L-775808 == Long.MinValue - val baseStr = "-106751991 04:00:54.775808000" - val minIntervalString = style match { - case ANSI_STYLE => - val firstStr = startField match { - case DT.DAY => s"-$MAX_DAY" - case DT.HOUR => s"-$MAX_HOUR" - case DT.MINUTE => s"-$MAX_MINUTE" - case DT.SECOND => s"-$MAX_SECOND.775808" - } - val followingStr = if (startField == endField) { - "" - } else { - val substrStart = startField match { - case DT.DAY => 10 - case DT.HOUR => 13 - case DT.MINUTE => 16 - } - val substrEnd = endField match { - case DT.HOUR => 13 - case DT.MINUTE => 16 - case DT.SECOND => 26 - } - baseStr.substring(substrStart, substrEnd) - } - - s"$prefix$firstStr$followingStr$postfix" - case HIVE_STYLE => baseStr - } - return minIntervalString - } else { - sign = "-" - rest = -rest - } - } - val intervalString = style match { - case ANSI_STYLE => - val formatBuilder = new mutable.StringBuilder(sign) - val formatArgs = new mutable.ArrayBuffer[Long]() - startField match { - case DT.DAY => - formatBuilder.append(rest / MICROS_PER_DAY) - rest %= MICROS_PER_DAY - case DT.HOUR => - formatBuilder.append("%02d") - formatArgs.append(rest / MICROS_PER_HOUR) - rest %= MICROS_PER_HOUR - case DT.MINUTE => - formatBuilder.append("%02d") - formatArgs.append(rest / 
MICROS_PER_MINUTE) - rest %= MICROS_PER_MINUTE - case DT.SECOND => - val leadZero = if (rest < 10 * MICROS_PER_SECOND) "0" else "" - formatBuilder.append(s"$leadZero" + - s"${java.math.BigDecimal.valueOf(rest, 6).stripTrailingZeros.toPlainString}") - } - - if (startField < DT.HOUR && DT.HOUR <= endField) { - formatBuilder.append(" %02d") - formatArgs.append(rest / MICROS_PER_HOUR) - rest %= MICROS_PER_HOUR - } - if (startField < DT.MINUTE && DT.MINUTE <= endField) { - formatBuilder.append(":%02d") - formatArgs.append(rest / MICROS_PER_MINUTE) - rest %= MICROS_PER_MINUTE - } - if (startField < DT.SECOND && DT.SECOND <= endField) { - val leadZero = if (rest < 10 * MICROS_PER_SECOND) "0" else "" - formatBuilder.append( - s":$leadZero${java.math.BigDecimal.valueOf(rest, 6).stripTrailingZeros.toPlainString}") - } - s"$prefix${formatBuilder.toString.format(formatArgs.toSeq: _*)}$postfix" - case HIVE_STYLE => - val secondsWithFraction = rest % MICROS_PER_MINUTE - rest /= MICROS_PER_MINUTE - val minutes = rest % MINUTES_PER_HOUR - rest /= MINUTES_PER_HOUR - val hours = rest % HOURS_PER_DAY - val days = rest / HOURS_PER_DAY - val seconds = secondsWithFraction / MICROS_PER_SECOND - val nanos = (secondsWithFraction % MICROS_PER_SECOND) * NANOS_PER_MICROS - f"$sign$days $hours%02d:$minutes%02d:$seconds%02d.$nanos%09d" - } - intervalString - } - def intToYearMonthInterval(v: Int, startField: Byte, endField: Byte): Int = { endField match { case YEAR => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 8721bb809fc..78d4beab414 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -27,7 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{MetadataBuilder, 
NumericType, StringType, StructType} import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.Utils +import org.apache.spark.util.{SparkErrorUtils, Utils} package object util extends Logging { @@ -81,27 +81,14 @@ package object util extends Logging { } def sideBySide(left: String, right: String): Seq[String] = { - sideBySide(left.split("\n"), right.split("\n")) + SparkStringUtils.sideBySide(left, right) } def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = { - val maxLeftSize = left.map(_.length).max - val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("") - val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("") - - leftPadded.zip(rightPadded).map { - case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.length) + 3)) + r - } + SparkStringUtils.sideBySide(left, right) } - def stackTraceToString(t: Throwable): String = { - val out = new java.io.ByteArrayOutputStream - Utils.tryWithResource(new PrintWriter(out)) { writer => - t.printStackTrace(writer) - writer.flush() - } - new String(out.toByteArray, UTF_8) - } + def stackTraceToString(t: Throwable): String = SparkErrorUtils.stackTraceToString(t) // Replaces attributes, string literals, complex type extractors with their pretty form so that // generated column names don't contain back-ticks or double-quotes. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 79b8e9f389e..79b88fde483 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.types._ * As commands are executed eagerly, this also includes errors thrown during the execution of * commands, which users can see immediately. 
*/ -private[sql] object QueryCompilationErrors extends QueryErrorsBase { +private[sql] object QueryCompilationErrors extends QueryErrorsBase with CompilationErrors { def unexpectedRequiredParameterInFunctionSignature( functionName: String, functionSignature: FunctionSignature) : Throwable = { @@ -1929,15 +1929,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { origin = context) } - def ambiguousColumnOrFieldError( - name: Seq[String], numMatches: Int): Throwable = { - new AnalysisException( - errorClass = "AMBIGUOUS_COLUMN_OR_FIELD", - messageParameters = Map( - "name" -> toSQLId(name), - "n" -> numMatches.toString)) - } - def ambiguousReferenceError(name: String, ambiguousReferences: Seq[Attribute]): Throwable = { new AnalysisException( errorClass = "AMBIGUOUS_REFERENCE", @@ -2449,14 +2440,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { messageParameters = Map("columnName" -> toSQLId(columnName))) } - def columnNotFoundError(colName: String): Throwable = { - new AnalysisException( - errorClass = "COLUMN_NOT_FOUND", - messageParameters = Map( - "colName" -> toSQLId(colName), - "caseSensitiveConfig" -> toSQLConf(SQLConf.CASE_SENSITIVE.key))) - } - def noSuchTableError(db: String, table: String): Throwable = { new NoSuchTableException(db = db, table = table) } @@ -3551,20 +3534,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { messageParameters = Map("messageName" -> messageName)) } - def descriptorParseError(cause: Throwable): Throwable = { - new AnalysisException( - errorClass = "CANNOT_PARSE_PROTOBUF_DESCRIPTOR", - messageParameters = Map.empty, - cause = Option(cause)) - } - - def cannotFindDescriptorFileError(filePath: String, cause: Throwable): Throwable = { - new AnalysisException( - errorClass = "PROTOBUF_DESCRIPTOR_FILE_NOT_FOUND", - messageParameters = Map("filePath" -> filePath), - cause = Option(cause)) - } - def foundRecursionInProtobufSchema(fieldDescriptor: String): Throwable = { 
new AnalysisException( errorClass = "RECURSIVE_PROTOBUF_SCHEMA", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e74c3ada491..444ecbd837f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -885,7 +885,7 @@ object SQLConf { .booleanConf .createWithDefault(false) - val CASE_SENSITIVE = buildConf("spark.sql.caseSensitive") + val CASE_SENSITIVE = buildConf(SqlApiConf.CASE_SENSITIVE_KEY) .internal() .doc("Whether the query analyzer should be case sensitive or not. " + "Default to case insensitive. It is highly discouraged to turn on case sensitive mode.") @@ -2657,7 +2657,7 @@ object SQLConf { Try { DateTimeUtils.getZoneId(zone) }.isSuccess } - val SESSION_LOCAL_TIMEZONE = buildConf("spark.sql.session.timeZone") + val SESSION_LOCAL_TIMEZONE = buildConf(SqlApiConf.SESSION_LOCAL_TIMEZONE_KEY) .doc("The ID of session local timezone in the format of either region-based zone IDs or " + "zone offsets. Region IDs must have the form 'area/city', such as 'America/Los_Angeles'. 
" + "Zone offsets must be in the format '(+|-)HH', '(+|-)HH:mm' or '(+|-)HH:mm:ss', e.g '-08', " + @@ -4325,7 +4325,7 @@ object SQLConf { .createWithDefault(false) val LOCAL_RELATION_CACHE_THRESHOLD = - buildConf("spark.sql.session.localRelationCacheThreshold") + buildConf(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY) .doc("The threshold for the size in bytes of local relations to be cached at " + "the driver side after serialization.") .version("3.5.0") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala index 7c11ccf6a0d..34f87f940a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala @@ -28,7 +28,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch -import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.DateTimeConstants._ @@ -40,6 +39,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearMonthIntervalTypes} import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND} +import org.apache.spark.sql.types.UpCastRule.numericPrecedence import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR} import org.apache.spark.unsafe.types.UTF8String --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: 
commits-h...@spark.apache.org