This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 7bfbeb62cb1 [SPARK-44326][SQL][CONNECT] Move utils that are used from Scala client to the common modules 7bfbeb62cb1 is described below commit 7bfbeb62cb1dc58d81243d22888faa688bad8064 Author: Rui Wang <rui.w...@databricks.com> AuthorDate: Fri Jul 7 13:38:39 2023 -0400 [SPARK-44326][SQL][CONNECT] Move utils that are used from Scala client to the common modules ### What changes were proposed in this pull request? Some utils are used in the Scala client, including ser/derse, datetime and interval utils. These can be moved to the common modules. ### Why are the changes needed? To make sure the Scala client does not depend on Catalyst in the future. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests Closes #41885 from amaliujia/SPARK-44326. Authored-by: Rui Wang <rui.w...@databricks.com> Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../org/apache/spark/util/SparkSerDerseUtils.scala | 30 +++ .../spark/sql/execution/streaming/Triggers.scala | 6 +- .../sql/expressions/UserDefinedFunction.scala | 5 +- .../spark/sql/streaming/DataStreamWriter.scala | 4 +- .../main/scala/org/apache/spark/util/Utils.scala | 6 +- .../sql/catalyst/util/SparkDateTimeUtils.scala | 39 +++ .../sql/catalyst/util/SparkIntervalUtils.scala | 263 +++++++++++++++++++++ .../spark/sql/catalyst/util/DateTimeUtils.scala | 2 +- .../spark/sql/catalyst/util/IntervalUtils.scala | 225 +----------------- 9 files changed, 345 insertions(+), 235 deletions(-) diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkSerDerseUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkSerDerseUtils.scala new file mode 100644 index 00000000000..e9150618476 --- /dev/null +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkSerDerseUtils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import java.io.{ByteArrayOutputStream, ObjectOutputStream} + +object SparkSerDerseUtils { + /** Serialize an object using Java serialization */ + def serialize[T](o: T): Array[Byte] = { + val bos = new ByteArrayOutputStream() + val oos = new ObjectOutputStream(bos) + oos.writeObject(o) + oos.close() + bos.toByteArray + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala index be1b0e8ac0c..ad19ad17805 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -22,8 +22,8 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration.Duration import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY -import org.apache.spark.sql.catalyst.util.DateTimeUtils.microsToMillis -import org.apache.spark.sql.catalyst.util.IntervalUtils +import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils.microsToMillis +import org.apache.spark.sql.catalyst.util.SparkIntervalUtils import 
org.apache.spark.sql.streaming.Trigger import org.apache.spark.unsafe.types.UTF8String @@ -35,7 +35,7 @@ private object Triggers { } def convert(interval: String): Long = { - val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval)) + val cal = SparkIntervalUtils.stringToInterval(UTF8String.fromString(interval)) if (cal.months != 0) { throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval") } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 14dfc0c6a86..c3c735cd42e 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfPacket} -import org.apache.spark.util.Utils +import org.apache.spark.util.SparkSerDerseUtils /** * A user-defined function. To create one, use the `udf` functions in `functions`. @@ -103,7 +103,8 @@ case class ScalarUserDefinedFunction( // SPARK-43198: Eagerly serialize to prevent the UDF from containing a reference to this class. 
private[this] val udf = { - val udfPacketBytes = Utils.serialize(UdfPacket(function, inputEncoders, outputEncoder)) + val udfPacketBytes = + SparkSerDerseUtils.serialize(UdfPacket(function, inputEncoders, outputEncoder)) val scalaUdfBuilder = proto.ScalarScalaUDF .newBuilder() .setPayload(ByteString.copyFrom(udfPacketBytes)) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index ed3d2bb8558..27f5642d0ec 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.streaming.AvailableNowTrigger import org.apache.spark.sql.execution.streaming.ContinuousTrigger import org.apache.spark.sql.execution.streaming.OneTimeTrigger import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger -import org.apache.spark.util.Utils +import org.apache.spark.util.SparkSerDerseUtils /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. 
file systems, @@ -214,7 +214,7 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { * @since 3.5.0 */ def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { - val serialized = Utils.serialize(ForeachWriterPacket(writer, ds.encoder)) + val serialized = SparkSerDerseUtils.serialize(ForeachWriterPacket(writer, ds.encoder)) val scalaWriterBuilder = proto.ScalarScalaUDF .newBuilder() .setPayload(ByteString.copyFrom(serialized)) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b5c0ee1bab8..d4b2651748e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -121,11 +121,7 @@ private[spark] object Utils extends Logging with SparkClassUtils { /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { - val bos = new ByteArrayOutputStream() - val oos = new ObjectOutputStream(bos) - oos.writeObject(o) - oos.close() - bos.toByteArray + SparkSerDerseUtils.serialize(o) } /** Deserialize an object using Java serialization */ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala new file mode 100644 index 00000000000..e96da99c6b0 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS + +object SparkDateTimeUtils { + /** + * Converts the timestamp to milliseconds since epoch. In Spark timestamp values have microseconds + * precision, so this conversion is lossy. + */ + def microsToMillis(micros: Long): Long = { + // When the timestamp is negative i.e before 1970, we need to adjust the milliseconds portion. + // Example - 1965-01-01 10:11:12.123456 is represented as (-157700927876544) in micro precision. + // In millis precision the above needs to be represented as (-157700927877). + Math.floorDiv(micros, MICROS_PER_MILLIS) + } + + /** + * Converts milliseconds since the epoch to microseconds. + */ + def millisToMicros(millis: Long): Long = { + Math.multiplyExact(millis, MICROS_PER_MILLIS) + } +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkIntervalUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkIntervalUtils.scala new file mode 100644 index 00000000000..05ceb04f12b --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkIntervalUtils.scala @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.util.DateTimeConstants.{DAYS_PER_WEEK, MICROS_PER_HOUR, MICROS_PER_MINUTE, MICROS_PER_SECOND, MONTHS_PER_YEAR, NANOS_PER_MICROS, NANOS_PER_SECOND} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +trait SparkIntervalUtils { + /** + * Converts a string to [[CalendarInterval]] case-insensitively. + * + * @throws IllegalArgumentException if the input string is not in valid interval format. 
+ */ + def stringToInterval(input: UTF8String): CalendarInterval = { + import ParseState._ + def throwIAE(msg: String, e: Exception = null) = { + throw new IllegalArgumentException(s"Error parsing '$input' to interval, $msg", e) + } + + if (input == null) { + throwIAE("interval string cannot be null") + } + // scalastyle:off caselocale .toLowerCase + val s = input.trimAll().toLowerCase + // scalastyle:on + val bytes = s.getBytes + if (bytes.isEmpty) { + throwIAE("interval string cannot be empty") + } + var state = PREFIX + var i = 0 + var currentValue: Long = 0 + var isNegative: Boolean = false + var months: Int = 0 + var days: Int = 0 + var microseconds: Long = 0 + var fractionScale: Int = 0 + val initialFractionScale = (NANOS_PER_SECOND / 10).toInt + var fraction: Int = 0 + var pointPrefixed: Boolean = false + + def trimToNextState(b: Byte, next: ParseState): Unit = { + if (Character.isWhitespace(b)) { + i += 1 + } else { + state = next + } + } + + def currentWord: String = { + val sep = "\\s+" + val strings = s.toString.split(sep) + val lenRight = s.substring(i, s.numBytes()).toString.split(sep).length + strings(strings.length - lenRight) + } + + while (i < bytes.length) { + val b = bytes(i) + state match { + case PREFIX => + if (s.startsWith(intervalStr)) { + if (s.numBytes() == intervalStr.numBytes()) { + throwIAE("interval string cannot be empty") + } else if (!Character.isWhitespace(bytes(i + intervalStr.numBytes()))) { + throwIAE(s"invalid interval prefix $currentWord") + } else { + i += intervalStr.numBytes() + 1 + } + } + state = TRIM_BEFORE_SIGN + case TRIM_BEFORE_SIGN => trimToNextState(b, SIGN) + case SIGN => + currentValue = 0 + fraction = 0 + // We preset next state from SIGN to TRIM_BEFORE_VALUE. If we meet '.' in the SIGN state, + // it means that the interval value we deal with here is a numeric with only fractional + // part, such as '.11 second', which can be parsed to 0.11 seconds. 
In this case, we need + // to reset next state to `VALUE_FRACTIONAL_PART` to go parse the fraction part of the + // interval value. + state = TRIM_BEFORE_VALUE + // We preset the scale to an invalid value to track fraction presence in the UNIT_BEGIN + // state. If we meet '.', the scale become valid for the VALUE_FRACTIONAL_PART state. + fractionScale = -1 + pointPrefixed = false + b match { + case '-' => + isNegative = true + i += 1 + case '+' => + isNegative = false + i += 1 + case _ if '0' <= b && b <= '9' => + isNegative = false + case '.' => + isNegative = false + fractionScale = initialFractionScale + pointPrefixed = true + i += 1 + state = VALUE_FRACTIONAL_PART + case _ => throwIAE( s"unrecognized number '$currentWord'") + } + case TRIM_BEFORE_VALUE => trimToNextState(b, VALUE) + case VALUE => + b match { + case _ if '0' <= b && b <= '9' => + try { + currentValue = Math.addExact(Math.multiplyExact(10, currentValue), (b - '0')) + } catch { + case e: ArithmeticException => throwIAE(e.getMessage, e) + } + case _ if Character.isWhitespace(b) => state = TRIM_BEFORE_UNIT + case '.' 
=> + fractionScale = initialFractionScale + state = VALUE_FRACTIONAL_PART + case _ => throwIAE(s"invalid value '$currentWord'") + } + i += 1 + case VALUE_FRACTIONAL_PART => + if ('0' <= b && b <= '9' && fractionScale > 0) { + fraction += (b - '0') * fractionScale + fractionScale /= 10 + } else if (Character.isWhitespace(b) && + (!pointPrefixed || fractionScale < initialFractionScale)) { + fraction /= NANOS_PER_MICROS.toInt + state = TRIM_BEFORE_UNIT + } else if ('0' <= b && b <= '9') { + throwIAE(s"interval can only support nanosecond precision, '$currentWord' is out" + + s" of range") + } else { + throwIAE(s"invalid value '$currentWord'") + } + i += 1 + case TRIM_BEFORE_UNIT => trimToNextState(b, UNIT_BEGIN) + case UNIT_BEGIN => + // Checks that only seconds can have the fractional part + if (b != 's' && fractionScale >= 0) { + throwIAE(s"'$currentWord' cannot have fractional part") + } + if (isNegative) { + currentValue = -currentValue + fraction = -fraction + } + try { + b match { + case 'y' if s.matchAt(yearStr, i) => + val monthsInYears = Math.multiplyExact(MONTHS_PER_YEAR, currentValue) + months = Math.toIntExact(Math.addExact(months, monthsInYears)) + i += yearStr.numBytes() + case 'w' if s.matchAt(weekStr, i) => + val daysInWeeks = Math.multiplyExact(DAYS_PER_WEEK, currentValue) + days = Math.toIntExact(Math.addExact(days, daysInWeeks)) + i += weekStr.numBytes() + case 'd' if s.matchAt(dayStr, i) => + days = Math.addExact(days, Math.toIntExact(currentValue)) + i += dayStr.numBytes() + case 'h' if s.matchAt(hourStr, i) => + val hoursUs = Math.multiplyExact(currentValue, MICROS_PER_HOUR) + microseconds = Math.addExact(microseconds, hoursUs) + i += hourStr.numBytes() + case 's' if s.matchAt(secondStr, i) => + val secondsUs = Math.multiplyExact(currentValue, MICROS_PER_SECOND) + microseconds = Math.addExact(Math.addExact(microseconds, secondsUs), fraction) + i += secondStr.numBytes() + case 'm' => + if (s.matchAt(monthStr, i)) { + months = Math.addExact(months, 
Math.toIntExact(currentValue)) + i += monthStr.numBytes() + } else if (s.matchAt(minuteStr, i)) { + val minutesUs = Math.multiplyExact(currentValue, MICROS_PER_MINUTE) + microseconds = Math.addExact(microseconds, minutesUs) + i += minuteStr.numBytes() + } else if (s.matchAt(millisStr, i)) { + val millisUs = SparkDateTimeUtils.millisToMicros(currentValue) + microseconds = Math.addExact(microseconds, millisUs) + i += millisStr.numBytes() + } else if (s.matchAt(microsStr, i)) { + microseconds = Math.addExact(microseconds, currentValue) + i += microsStr.numBytes() + } else throwIAE(s"invalid unit '$currentWord'") + case _ => throwIAE(s"invalid unit '$currentWord'") + } + } catch { + case e: ArithmeticException => throwIAE(e.getMessage, e) + } + state = UNIT_SUFFIX + case UNIT_SUFFIX => + b match { + case 's' => state = UNIT_END + case _ if Character.isWhitespace(b) => state = TRIM_BEFORE_SIGN + case _ => throwIAE(s"invalid unit '$currentWord'") + } + i += 1 + case UNIT_END => + if (Character.isWhitespace(b) ) { + i += 1 + state = TRIM_BEFORE_SIGN + } else { + throwIAE(s"invalid unit '$currentWord'") + } + } + } + + val result = state match { + case UNIT_SUFFIX | UNIT_END | TRIM_BEFORE_SIGN => + new CalendarInterval(months, days, microseconds) + case TRIM_BEFORE_VALUE => throwIAE(s"expect a number after '$currentWord' but hit EOL") + case VALUE | VALUE_FRACTIONAL_PART => + throwIAE(s"expect a unit name after '$currentWord' but hit EOL") + case _ => throwIAE(s"unknown error when parsing '$currentWord'") + } + + result + } + + protected def unitToUtf8(unit: String): UTF8String = { + UTF8String.fromString(unit) + } + + protected val intervalStr = unitToUtf8("interval") + + protected val yearStr = unitToUtf8("year") + protected val monthStr = unitToUtf8("month") + protected val weekStr = unitToUtf8("week") + protected val dayStr = unitToUtf8("day") + protected val hourStr = unitToUtf8("hour") + protected val minuteStr = unitToUtf8("minute") + protected val secondStr = 
unitToUtf8("second") + protected val millisStr = unitToUtf8("millisecond") + protected val microsStr = unitToUtf8("microsecond") + protected val nanosStr = unitToUtf8("nanosecond") + + + private object ParseState extends Enumeration { + type ParseState = Value + + val PREFIX, + TRIM_BEFORE_SIGN, + SIGN, + TRIM_BEFORE_VALUE, + VALUE, + VALUE_FRACTIONAL_PART, + TRIM_BEFORE_UNIT, + UNIT_BEGIN, + UNIT_SUFFIX, + UNIT_END = Value + } +} + +object SparkIntervalUtils extends SparkIntervalUtils diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index a3f74168b84..1142b6184dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -255,7 +255,7 @@ object DateTimeUtils { * Converts milliseconds since the epoch to microseconds. */ def millisToMicros(millis: Long): Long = { - Math.multiplyExact(millis, MICROS_PER_MILLIS) + SparkDateTimeUtils.millisToMicros(millis) } private final val gmtUtf8 = UTF8String.fromString("GMT") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index 455a74e06c4..6ba59b4e730 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -28,7 +28,6 @@ import scala.util.control.NonFatal import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.DateTimeConstants._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToMicros import org.apache.spark.sql.catalyst.util.IntervalStringStyles.{ANSI_STYLE, HIVE_STYLE, IntervalStyle} import 
org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf @@ -44,7 +43,7 @@ object IntervalStringStyles extends Enumeration { val ANSI_STYLE, HIVE_STYLE = Value } -object IntervalUtils { +object IntervalUtils extends SparkIntervalUtils { private val MAX_DAY = Long.MaxValue / MICROS_PER_DAY private val MAX_HOUR = Long.MaxValue / MICROS_PER_HOUR @@ -754,20 +753,6 @@ object IntervalUtils { UNIT_SUFFIX, UNIT_END = Value } - private final val intervalStr = unitToUtf8("interval") - private def unitToUtf8(unit: String): UTF8String = { - UTF8String.fromString(unit) - } - private final val yearStr = unitToUtf8("year") - private final val monthStr = unitToUtf8("month") - private final val weekStr = unitToUtf8("week") - private final val dayStr = unitToUtf8("day") - private final val hourStr = unitToUtf8("hour") - private final val minuteStr = unitToUtf8("minute") - private final val secondStr = unitToUtf8("second") - private final val millisStr = unitToUtf8("millisecond") - private final val microsStr = unitToUtf8("microsecond") - private final val nanosStr = unitToUtf8("nanosecond") /** * A safe version of `stringToInterval`. It returns null for invalid input string. @@ -780,212 +765,6 @@ object IntervalUtils { } } - /** - * Converts a string to [[CalendarInterval]] case-insensitively. - * - * @throws IllegalArgumentException if the input string is not in valid interval format. 
- */ - def stringToInterval(input: UTF8String): CalendarInterval = { - import ParseState._ - def throwIAE(msg: String, e: Exception = null) = { - throw new IllegalArgumentException(s"Error parsing '$input' to interval, $msg", e) - } - - if (input == null) { - throwIAE("interval string cannot be null") - } - // scalastyle:off caselocale .toLowerCase - val s = input.trimAll().toLowerCase - // scalastyle:on - val bytes = s.getBytes - if (bytes.isEmpty) { - throwIAE("interval string cannot be empty") - } - var state = PREFIX - var i = 0 - var currentValue: Long = 0 - var isNegative: Boolean = false - var months: Int = 0 - var days: Int = 0 - var microseconds: Long = 0 - var fractionScale: Int = 0 - val initialFractionScale = (NANOS_PER_SECOND / 10).toInt - var fraction: Int = 0 - var pointPrefixed: Boolean = false - - def trimToNextState(b: Byte, next: ParseState): Unit = { - if (Character.isWhitespace(b)) { - i += 1 - } else { - state = next - } - } - - def currentWord: String = { - val sep = "\\s+" - val strings = s.toString.split(sep) - val lenRight = s.substring(i, s.numBytes()).toString.split(sep).length - strings(strings.length - lenRight) - } - - while (i < bytes.length) { - val b = bytes(i) - state match { - case PREFIX => - if (s.startsWith(intervalStr)) { - if (s.numBytes() == intervalStr.numBytes()) { - throwIAE("interval string cannot be empty") - } else if (!Character.isWhitespace(bytes(i + intervalStr.numBytes()))) { - throwIAE(s"invalid interval prefix $currentWord") - } else { - i += intervalStr.numBytes() + 1 - } - } - state = TRIM_BEFORE_SIGN - case TRIM_BEFORE_SIGN => trimToNextState(b, SIGN) - case SIGN => - currentValue = 0 - fraction = 0 - // We preset next state from SIGN to TRIM_BEFORE_VALUE. If we meet '.' in the SIGN state, - // it means that the interval value we deal with here is a numeric with only fractional - // part, such as '.11 second', which can be parsed to 0.11 seconds. 
In this case, we need - // to reset next state to `VALUE_FRACTIONAL_PART` to go parse the fraction part of the - // interval value. - state = TRIM_BEFORE_VALUE - // We preset the scale to an invalid value to track fraction presence in the UNIT_BEGIN - // state. If we meet '.', the scale become valid for the VALUE_FRACTIONAL_PART state. - fractionScale = -1 - pointPrefixed = false - b match { - case '-' => - isNegative = true - i += 1 - case '+' => - isNegative = false - i += 1 - case _ if '0' <= b && b <= '9' => - isNegative = false - case '.' => - isNegative = false - fractionScale = initialFractionScale - pointPrefixed = true - i += 1 - state = VALUE_FRACTIONAL_PART - case _ => throwIAE( s"unrecognized number '$currentWord'") - } - case TRIM_BEFORE_VALUE => trimToNextState(b, VALUE) - case VALUE => - b match { - case _ if '0' <= b && b <= '9' => - try { - currentValue = Math.addExact(Math.multiplyExact(10, currentValue), (b - '0')) - } catch { - case e: ArithmeticException => throwIAE(e.getMessage, e) - } - case _ if Character.isWhitespace(b) => state = TRIM_BEFORE_UNIT - case '.' 
=> - fractionScale = initialFractionScale - state = VALUE_FRACTIONAL_PART - case _ => throwIAE(s"invalid value '$currentWord'") - } - i += 1 - case VALUE_FRACTIONAL_PART => - if ('0' <= b && b <= '9' && fractionScale > 0) { - fraction += (b - '0') * fractionScale - fractionScale /= 10 - } else if (Character.isWhitespace(b) && - (!pointPrefixed || fractionScale < initialFractionScale)) { - fraction /= NANOS_PER_MICROS.toInt - state = TRIM_BEFORE_UNIT - } else if ('0' <= b && b <= '9') { - throwIAE(s"interval can only support nanosecond precision, '$currentWord' is out" + - s" of range") - } else { - throwIAE(s"invalid value '$currentWord'") - } - i += 1 - case TRIM_BEFORE_UNIT => trimToNextState(b, UNIT_BEGIN) - case UNIT_BEGIN => - // Checks that only seconds can have the fractional part - if (b != 's' && fractionScale >= 0) { - throwIAE(s"'$currentWord' cannot have fractional part") - } - if (isNegative) { - currentValue = -currentValue - fraction = -fraction - } - try { - b match { - case 'y' if s.matchAt(yearStr, i) => - val monthsInYears = Math.multiplyExact(MONTHS_PER_YEAR, currentValue) - months = Math.toIntExact(Math.addExact(months, monthsInYears)) - i += yearStr.numBytes() - case 'w' if s.matchAt(weekStr, i) => - val daysInWeeks = Math.multiplyExact(DAYS_PER_WEEK, currentValue) - days = Math.toIntExact(Math.addExact(days, daysInWeeks)) - i += weekStr.numBytes() - case 'd' if s.matchAt(dayStr, i) => - days = Math.addExact(days, Math.toIntExact(currentValue)) - i += dayStr.numBytes() - case 'h' if s.matchAt(hourStr, i) => - val hoursUs = Math.multiplyExact(currentValue, MICROS_PER_HOUR) - microseconds = Math.addExact(microseconds, hoursUs) - i += hourStr.numBytes() - case 's' if s.matchAt(secondStr, i) => - val secondsUs = Math.multiplyExact(currentValue, MICROS_PER_SECOND) - microseconds = Math.addExact(Math.addExact(microseconds, secondsUs), fraction) - i += secondStr.numBytes() - case 'm' => - if (s.matchAt(monthStr, i)) { - months = Math.addExact(months, 
Math.toIntExact(currentValue)) - i += monthStr.numBytes() - } else if (s.matchAt(minuteStr, i)) { - val minutesUs = Math.multiplyExact(currentValue, MICROS_PER_MINUTE) - microseconds = Math.addExact(microseconds, minutesUs) - i += minuteStr.numBytes() - } else if (s.matchAt(millisStr, i)) { - val millisUs = millisToMicros(currentValue) - microseconds = Math.addExact(microseconds, millisUs) - i += millisStr.numBytes() - } else if (s.matchAt(microsStr, i)) { - microseconds = Math.addExact(microseconds, currentValue) - i += microsStr.numBytes() - } else throwIAE(s"invalid unit '$currentWord'") - case _ => throwIAE(s"invalid unit '$currentWord'") - } - } catch { - case e: ArithmeticException => throwIAE(e.getMessage, e) - } - state = UNIT_SUFFIX - case UNIT_SUFFIX => - b match { - case 's' => state = UNIT_END - case _ if Character.isWhitespace(b) => state = TRIM_BEFORE_SIGN - case _ => throwIAE(s"invalid unit '$currentWord'") - } - i += 1 - case UNIT_END => - if (Character.isWhitespace(b) ) { - i += 1 - state = TRIM_BEFORE_SIGN - } else { - throwIAE(s"invalid unit '$currentWord'") - } - } - } - - val result = state match { - case UNIT_SUFFIX | UNIT_END | TRIM_BEFORE_SIGN => - new CalendarInterval(months, days, microseconds) - case TRIM_BEFORE_VALUE => throwIAE(s"expect a number after '$currentWord' but hit EOL") - case VALUE | VALUE_FRACTIONAL_PART => - throwIAE(s"expect a unit name after '$currentWord' but hit EOL") - case _ => throwIAE(s"unknown error when parsing '$currentWord'") - } - - result - } - def makeInterval( years: Int, months: Int, @@ -1163,8 +942,10 @@ object IntervalUtils { endField: Byte): String = { var sign = "" var rest = micros + // scalastyle:off caselocale val from = DT.fieldToString(startField).toUpperCase val to = DT.fieldToString(endField).toUpperCase + // scalastyle:on caselocale val prefix = "INTERVAL '" val postfix = s"' ${if (startField == endField) from else s"$from TO $to"}" 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org