This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new cedfbf59e14 [SPARK-42579][CONNECT] Part-1: `function.lit` support
`Array[_]` dataType
cedfbf59e14 is described below
commit cedfbf59e140058c58ba25314b38c43f87a7ede2
Author: yangjie01 <[email protected]>
AuthorDate: Mon Mar 6 19:26:39 2023 -0400
[SPARK-42579][CONNECT] Part-1: `function.lit` support `Array[_]` dataType
### What changes were proposed in this pull request?
This is the first part of SPARK-42579; this PR aims to add support for the
`Array[_]` data type in `function.lit`.
### Why are the changes needed?
Make `function.lit` support nested data types.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- Add new test
- Manually checked Scala 2.13 test
```
build/sbt "connect-client-jvm/test" -Phive -Pscala-2.13
build/sbt "connect/test" -Phive -Pscala-2.13
```
Closes #40218 from LuciferYang/SPARK-42579.
Authored-by: yangjie01 <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
(cherry picked from commit 0def3de6ed1000efe72c8bbdd3b3804bb34ce620)
Signed-off-by: Herman van Hovell <[email protected]>
---
.../sql/expressions/LiteralProtoConverter.scala | 145 +++++++
.../scala/org/apache/spark/sql/functions.scala | 65 +--
.../apache/spark/sql/PlanGenerationTestSuite.scala | 37 ++
.../main/protobuf/spark/connect/expressions.proto | 6 +
.../explain-results/function_lit_array.explain | 2 +
.../query-tests/queries/function_lit_array.json | 461 +++++++++++++++++++++
.../queries/function_lit_array.proto.bin | Bin 0 -> 885 bytes
.../planner/LiteralValueProtoConverter.scala | 65 +++
.../pyspark/sql/connect/proto/expressions_pb2.py | 79 ++--
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 38 ++
10 files changed, 805 insertions(+), 93 deletions(-)
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala
new file mode 100644
index 00000000000..b3b9f53e7bb
--- /dev/null
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.expressions
+
+import java.lang.{Boolean => JBoolean, Byte => JByte, Character => JChar,
Double => JDouble, Float => JFloat, Integer => JInteger, Long => JLong, Short
=> JShort}
+import java.math.{BigDecimal => JBigDecimal}
+import java.sql.{Date, Timestamp}
+import java.time._
+
+import com.google.protobuf.ByteString
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
+import org.apache.spark.sql.connect.client.unsupported
+import org.apache.spark.sql.connect.common.DataTypeProtoConverter._
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+
+object LiteralProtoConverter {
+
+ private lazy val nullType =
+
proto.DataType.newBuilder().setNull(proto.DataType.NULL.getDefaultInstance).build()
+
+ /**
+ * Transforms literal value to the `proto.Expression.Literal.Builder`.
+ *
+ * @return
+ * proto.Expression.Literal.Builder
+ */
+ @scala.annotation.tailrec
+ def toLiteralProtoBuilder(literal: Any): proto.Expression.Literal.Builder = {
+ val builder = proto.Expression.Literal.newBuilder()
+
+ def decimalBuilder(precision: Int, scale: Int, value: String) = {
+
builder.getDecimalBuilder.setPrecision(precision).setScale(scale).setValue(value)
+ }
+
+ def calendarIntervalBuilder(months: Int, days: Int, microseconds: Long) = {
+ builder.getCalendarIntervalBuilder
+ .setMonths(months)
+ .setDays(days)
+ .setMicroseconds(microseconds)
+ }
+
+ def arrayBuilder(array: Array[_]) = {
+ val ab = builder.getArrayBuilder
+
.setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType)))
+ array.foreach(x => ab.addElement(toLiteralProto(x)))
+ ab
+ }
+
+ literal match {
+ case v: Boolean => builder.setBoolean(v)
+ case v: Byte => builder.setByte(v)
+ case v: Short => builder.setShort(v)
+ case v: Int => builder.setInteger(v)
+ case v: Long => builder.setLong(v)
+ case v: Float => builder.setFloat(v)
+ case v: Double => builder.setDouble(v)
+ case v: BigDecimal =>
+ builder.setDecimal(decimalBuilder(v.precision, v.scale, v.toString))
+ case v: JBigDecimal =>
+ builder.setDecimal(decimalBuilder(v.precision, v.scale, v.toString))
+ case v: String => builder.setString(v)
+ case v: Char => builder.setString(v.toString)
+ case v: Array[Char] => builder.setString(String.valueOf(v))
+ case v: Array[Byte] => builder.setBinary(ByteString.copyFrom(v))
+ case v: collection.mutable.WrappedArray[_] =>
toLiteralProtoBuilder(v.array)
+ case v: LocalDate => builder.setDate(v.toEpochDay.toInt)
+ case v: Decimal =>
+ builder.setDecimal(decimalBuilder(Math.max(v.precision, v.scale),
v.scale, v.toString))
+ case v: Instant => builder.setTimestamp(DateTimeUtils.instantToMicros(v))
+ case v: Timestamp =>
builder.setTimestamp(DateTimeUtils.fromJavaTimestamp(v))
+ case v: LocalDateTime =>
builder.setTimestampNtz(DateTimeUtils.localDateTimeToMicros(v))
+ case v: Date => builder.setDate(DateTimeUtils.fromJavaDate(v))
+ case v: Duration =>
builder.setDayTimeInterval(IntervalUtils.durationToMicros(v))
+ case v: Period =>
builder.setYearMonthInterval(IntervalUtils.periodToMonths(v))
+ case v: Array[_] => builder.setArray(arrayBuilder(v))
+ case v: CalendarInterval =>
+ builder.setCalendarInterval(calendarIntervalBuilder(v.months, v.days,
v.microseconds))
+ case null => builder.setNull(nullType)
+ case _ => unsupported(s"literal $literal not supported (yet).")
+ }
+ }
+
+ /**
+ * Transforms literal value to the `proto.Expression.Literal`.
+ *
+ * @return
+ * proto.Expression.Literal
+ */
+ private def toLiteralProto(literal: Any): proto.Expression.Literal =
+ toLiteralProtoBuilder(literal).build()
+
+ private def toDataType(clz: Class[_]): DataType = clz match {
+ // primitive types
+ case JShort.TYPE => ShortType
+ case JInteger.TYPE => IntegerType
+ case JLong.TYPE => LongType
+ case JDouble.TYPE => DoubleType
+ case JByte.TYPE => ByteType
+ case JFloat.TYPE => FloatType
+ case JBoolean.TYPE => BooleanType
+ case JChar.TYPE => StringType
+
+ // java classes
+ case _ if clz == classOf[LocalDate] || clz == classOf[Date] => DateType
+ case _ if clz == classOf[Instant] || clz == classOf[Timestamp] =>
TimestampType
+ case _ if clz == classOf[LocalDateTime] => TimestampNTZType
+ case _ if clz == classOf[Duration] => DayTimeIntervalType.DEFAULT
+ case _ if clz == classOf[Period] => YearMonthIntervalType.DEFAULT
+ case _ if clz == classOf[JBigDecimal] => DecimalType.SYSTEM_DEFAULT
+ case _ if clz == classOf[Array[Byte]] => BinaryType
+ case _ if clz == classOf[Array[Char]] => StringType
+ case _ if clz == classOf[JShort] => ShortType
+ case _ if clz == classOf[JInteger] => IntegerType
+ case _ if clz == classOf[JLong] => LongType
+ case _ if clz == classOf[JDouble] => DoubleType
+ case _ if clz == classOf[JByte] => ByteType
+ case _ if clz == classOf[JFloat] => FloatType
+ case _ if clz == classOf[JBoolean] => BooleanType
+
+ // other scala classes
+ case _ if clz == classOf[String] => StringType
+ case _ if clz == classOf[BigInt] || clz == classOf[BigDecimal] =>
DecimalType.SYSTEM_DEFAULT
+ case _ if clz == classOf[CalendarInterval] => CalendarIntervalType
+ case _ if clz.isArray => ArrayType(toDataType(clz.getComponentType))
+ case _ =>
+ throw new UnsupportedOperationException(s"Unsupported component type
$clz in arrays.")
+ }
+}
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 76a27686bfd..8ce90886e0f 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -16,24 +16,17 @@
*/
package org.apache.spark.sql
-import java.math.{BigDecimal => JBigDecimal}
-import java.sql.{Date, Timestamp}
-import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
import java.util.Collections
import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.{typeTag, TypeTag}
-import com.google.protobuf.ByteString
-
import org.apache.spark.connect.proto
import
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveLongEncoder
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
-import org.apache.spark.sql.connect.client.unsupported
import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction,
UserDefinedFunction}
-import org.apache.spark.sql.types.{DataType, Decimal, StructType}
+import org.apache.spark.sql.expressions.LiteralProtoConverter._
+import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.types.DataType.parseTypeWithFallback
-import org.apache.spark.unsafe.types.CalendarInterval
/**
* Commonly used functions available for DataFrame operations. Using functions
defined here
@@ -93,32 +86,10 @@ object functions {
*/
def column(colName: String): Column = col(colName)
- private def createLiteral(f: proto.Expression.Literal.Builder => Unit):
Column = Column {
- builder =>
- val literalBuilder = proto.Expression.Literal.newBuilder()
- f(literalBuilder)
- builder.setLiteral(literalBuilder)
+ private def createLiteral(literalBuilder: proto.Expression.Literal.Builder):
Column = Column {
+ builder => builder.setLiteral(literalBuilder)
}
- private def createDecimalLiteral(precision: Int, scale: Int, value: String):
Column =
- createLiteral { builder =>
- builder.getDecimalBuilder
- .setPrecision(precision)
- .setScale(scale)
- .setValue(value)
- }
-
- private def createCalendarIntervalLiteral(months: Int, days: Int,
microseconds: Long): Column =
- createLiteral { builder =>
- builder.getCalendarIntervalBuilder
- .setMonths(months)
- .setDays(days)
- .setMicroseconds(microseconds)
- }
-
- private val nullType =
-
proto.DataType.newBuilder().setNull(proto.DataType.NULL.getDefaultInstance).build()
-
/**
* Creates a [[Column]] of literal value.
*
@@ -128,37 +99,11 @@ object functions {
*
* @since 3.4.0
*/
- @scala.annotation.tailrec
def lit(literal: Any): Column = {
literal match {
case c: Column => c
case s: Symbol => Column(s.name)
- case v: Boolean => createLiteral(_.setBoolean(v))
- case v: Byte => createLiteral(_.setByte(v))
- case v: Short => createLiteral(_.setShort(v))
- case v: Int => createLiteral(_.setInteger(v))
- case v: Long => createLiteral(_.setLong(v))
- case v: Float => createLiteral(_.setFloat(v))
- case v: Double => createLiteral(_.setDouble(v))
- case v: BigDecimal => createDecimalLiteral(v.precision, v.scale,
v.toString)
- case v: JBigDecimal => createDecimalLiteral(v.precision, v.scale,
v.toString)
- case v: String => createLiteral(_.setString(v))
- case v: Char => createLiteral(_.setString(v.toString))
- case v: Array[Char] => createLiteral(_.setString(String.valueOf(v)))
- case v: Array[Byte] => createLiteral(_.setBinary(ByteString.copyFrom(v)))
- case v: collection.mutable.WrappedArray[_] => lit(v.array)
- case v: LocalDate => createLiteral(_.setDate(v.toEpochDay.toInt))
- case v: Decimal => createDecimalLiteral(Math.max(v.precision, v.scale),
v.scale, v.toString)
- case v: Instant =>
createLiteral(_.setTimestamp(DateTimeUtils.instantToMicros(v)))
- case v: Timestamp =>
createLiteral(_.setTimestamp(DateTimeUtils.fromJavaTimestamp(v)))
- case v: LocalDateTime =>
-
createLiteral(_.setTimestampNtz(DateTimeUtils.localDateTimeToMicros(v)))
- case v: Date => createLiteral(_.setDate(DateTimeUtils.fromJavaDate(v)))
- case v: Duration =>
createLiteral(_.setDayTimeInterval(IntervalUtils.durationToMicros(v)))
- case v: Period =>
createLiteral(_.setYearMonthInterval(IntervalUtils.periodToMonths(v)))
- case v: CalendarInterval => createCalendarIntervalLiteral(v.months,
v.days, v.microseconds)
- case null => createLiteral(_.setNull(nullType))
- case _ => unsupported(s"literal $literal not supported (yet).")
+ case _ => createLiteral(toLiteralProtoBuilder(literal))
}
}
//////////////////////////////////////////////////////////////////////////////////////////////
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index f5ffaf9b73a..85523a22d2b 100755
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -2012,6 +2012,43 @@ class PlanGenerationTestSuite
fn.lit(new CalendarInterval(2, 20, 100L)))
}
+ test("function lit array") {
+ simple.select(
+ fn.lit(Array.emptyDoubleArray),
+ fn.lit(Array(Array(1), Array(2), Array(3))),
+ fn.lit(Array(Array(Array(1)), Array(Array(2)), Array(Array(3)))),
+ fn.lit(Array(true, false)),
+ fn.lit(Array(67.toByte, 68.toByte, 69.toByte)),
+ fn.lit(Array(9872.toShort, 9873.toShort, 9874.toShort)),
+ fn.lit(Array(-8726532, 8726532, -8726533)),
+ fn.lit(Array(7834609328726531L, 7834609328726532L, 7834609328726533L)),
+ fn.lit(Array(Math.E, 1.toDouble, 2.toDouble)),
+ fn.lit(Array(-0.8f, -0.7f, -0.9f)),
+ fn.lit(Array(BigDecimal(8997620, 5), BigDecimal(8997621, 5))),
+ fn.lit(
+ Array(BigDecimal(898897667231L, 7).bigDecimal,
BigDecimal(898897667231L, 7).bigDecimal)),
+ fn.lit(Array("connect!", "disconnect!")),
+ fn.lit(Array('T', 'F')),
+ fn.lit(
+ Array(
+ Array.tabulate(10)(i => ('A' + i).toChar),
+ Array.tabulate(10)(i => ('B' + i).toChar))),
+ fn.lit(Array(java.time.LocalDate.of(2020, 10, 10),
java.time.LocalDate.of(2020, 10, 11))),
+ fn.lit(
+ Array(
+ java.time.Instant.ofEpochMilli(1677155519808L),
+ java.time.Instant.ofEpochMilli(1677155519809L))),
+ fn.lit(Array(new java.sql.Timestamp(12345L), new
java.sql.Timestamp(23456L))),
+ fn.lit(
+ Array(
+ java.time.LocalDateTime.of(2023, 2, 23, 20, 36),
+ java.time.LocalDateTime.of(2023, 2, 23, 21, 36))),
+ fn.lit(Array(java.sql.Date.valueOf("2023-02-23"),
java.sql.Date.valueOf("2023-03-01"))),
+ fn.lit(Array(java.time.Duration.ofSeconds(100L),
java.time.Duration.ofSeconds(200L))),
+ fn.lit(Array(java.time.Period.ofDays(100),
java.time.Period.ofDays(200))),
+ fn.lit(Array(new CalendarInterval(2, 20, 100L), new CalendarInterval(2,
21, 200L))))
+ }
+
/* Window API */
test("window") {
simple.select(
diff --git
a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index e37a13ee959..6eb769ad27e 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -172,6 +172,7 @@ message Expression {
CalendarInterval calendar_interval = 19;
int32 year_month_interval = 20;
int64 day_time_interval = 21;
+ Array array = 22;
}
message Decimal {
@@ -189,6 +190,11 @@ message Expression {
int32 days = 2;
int64 microseconds = 3;
}
+
+ message Array {
+ DataType elementType = 1;
+ repeated Literal element = 2;
+ }
}
// An unresolved attribute that is not explicitly bound to a specific
column, but the column
diff --git
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain
new file mode 100644
index 00000000000..74d512b6910
--- /dev/null
+++
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain
@@ -0,0 +1,2 @@
+Project [[] AS ARRAY()#0, [[1],[2],[3]] AS ARRAY(ARRAY(1), ARRAY(2),
ARRAY(3))#0, [[[1]],[[2]],[[3]]] AS ARRAY(ARRAY(ARRAY(1)), ARRAY(ARRAY(2)),
ARRAY(ARRAY(3)))#0, [true,false] AS ARRAY(true, false)#0, 0x434445 AS
X'434445'#0, [9872,9873,9874] AS ARRAY(9872S, 9873S, 9874S)#0,
[-8726532,8726532,-8726533] AS ARRAY(-8726532, 8726532, -8726533)#0,
[7834609328726531,7834609328726532,7834609328726533] AS
ARRAY(7834609328726531L, 7834609328726532L, 7834609328726533L)#0,
[2.718281828459045,1.0, [...]
++- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git
a/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.json
b/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.json
new file mode 100644
index 00000000000..c9441c9e77c
--- /dev/null
+++
b/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.json
@@ -0,0 +1,461 @@
+{
+ "common": {
+ "planId": "1"
+ },
+ "project": {
+ "input": {
+ "common": {
+ "planId": "0"
+ },
+ "localRelation": {
+ "schema": "struct\u003cid:bigint,a:int,b:double\u003e"
+ }
+ },
+ "expressions": [{
+ "literal": {
+ "array": {
+ "elementType": {
+ "double": {
+ }
+ }
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "array": {
+ "elementType": {
+ "integer": {
+ }
+ },
+ "containsNull": true
+ }
+ },
+ "element": [{
+ "array": {
+ "elementType": {
+ "integer": {
+ }
+ },
+ "element": [{
+ "integer": 1
+ }]
+ }
+ }, {
+ "array": {
+ "elementType": {
+ "integer": {
+ }
+ },
+ "element": [{
+ "integer": 2
+ }]
+ }
+ }, {
+ "array": {
+ "elementType": {
+ "integer": {
+ }
+ },
+ "element": [{
+ "integer": 3
+ }]
+ }
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "array": {
+ "elementType": {
+ "array": {
+ "elementType": {
+ "integer": {
+ }
+ },
+ "containsNull": true
+ }
+ },
+ "containsNull": true
+ }
+ },
+ "element": [{
+ "array": {
+ "elementType": {
+ "array": {
+ "elementType": {
+ "integer": {
+ }
+ },
+ "containsNull": true
+ }
+ },
+ "element": [{
+ "array": {
+ "elementType": {
+ "integer": {
+ }
+ },
+ "element": [{
+ "integer": 1
+ }]
+ }
+ }]
+ }
+ }, {
+ "array": {
+ "elementType": {
+ "array": {
+ "elementType": {
+ "integer": {
+ }
+ },
+ "containsNull": true
+ }
+ },
+ "element": [{
+ "array": {
+ "elementType": {
+ "integer": {
+ }
+ },
+ "element": [{
+ "integer": 2
+ }]
+ }
+ }]
+ }
+ }, {
+ "array": {
+ "elementType": {
+ "array": {
+ "elementType": {
+ "integer": {
+ }
+ },
+ "containsNull": true
+ }
+ },
+ "element": [{
+ "array": {
+ "elementType": {
+ "integer": {
+ }
+ },
+ "element": [{
+ "integer": 3
+ }]
+ }
+ }]
+ }
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "boolean": {
+ }
+ },
+ "element": [{
+ "boolean": true
+ }, {
+ "boolean": false
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "binary": "Q0RF"
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "short": {
+ }
+ },
+ "element": [{
+ "short": 9872
+ }, {
+ "short": 9873
+ }, {
+ "short": 9874
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "integer": {
+ }
+ },
+ "element": [{
+ "integer": -8726532
+ }, {
+ "integer": 8726532
+ }, {
+ "integer": -8726533
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "long": {
+ }
+ },
+ "element": [{
+ "long": "7834609328726531"
+ }, {
+ "long": "7834609328726532"
+ }, {
+ "long": "7834609328726533"
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "double": {
+ }
+ },
+ "element": [{
+ "double": 2.718281828459045
+ }, {
+ "double": 1.0
+ }, {
+ "double": 2.0
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "float": {
+ }
+ },
+ "element": [{
+ "float": -0.8
+ }, {
+ "float": -0.7
+ }, {
+ "float": -0.9
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "decimal": {
+ "scale": 18,
+ "precision": 38
+ }
+ },
+ "element": [{
+ "decimal": {
+ "value": "89.97620",
+ "precision": 7,
+ "scale": 5
+ }
+ }, {
+ "decimal": {
+ "value": "89.97621",
+ "precision": 7,
+ "scale": 5
+ }
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "decimal": {
+ "scale": 18,
+ "precision": 38
+ }
+ },
+ "element": [{
+ "decimal": {
+ "value": "89889.7667231",
+ "precision": 12,
+ "scale": 7
+ }
+ }, {
+ "decimal": {
+ "value": "89889.7667231",
+ "precision": 12,
+ "scale": 7
+ }
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "string": {
+ }
+ },
+ "element": [{
+ "string": "connect!"
+ }, {
+ "string": "disconnect!"
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "string": "TF"
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "string": {
+ }
+ },
+ "element": [{
+ "string": "ABCDEFGHIJ"
+ }, {
+ "string": "BCDEFGHIJK"
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "date": {
+ }
+ },
+ "element": [{
+ "date": 18545
+ }, {
+ "date": 18546
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "timestamp": {
+ }
+ },
+ "element": [{
+ "timestamp": "1677155519808000"
+ }, {
+ "timestamp": "1677155519809000"
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "timestamp": {
+ }
+ },
+ "element": [{
+ "timestamp": "12345000"
+ }, {
+ "timestamp": "23456000"
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "timestampNtz": {
+ }
+ },
+ "element": [{
+ "timestampNtz": "1677184560000000"
+ }, {
+ "timestampNtz": "1677188160000000"
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "date": {
+ }
+ },
+ "element": [{
+ "date": 19411
+ }, {
+ "date": 19417
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "dayTimeInterval": {
+ "startField": 0,
+ "endField": 3
+ }
+ },
+ "element": [{
+ "dayTimeInterval": "100000000"
+ }, {
+ "dayTimeInterval": "200000000"
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "yearMonthInterval": {
+ "startField": 0,
+ "endField": 1
+ }
+ },
+ "element": [{
+ "yearMonthInterval": 0
+ }, {
+ "yearMonthInterval": 0
+ }]
+ }
+ }
+ }, {
+ "literal": {
+ "array": {
+ "elementType": {
+ "calendarInterval": {
+ }
+ },
+ "element": [{
+ "calendarInterval": {
+ "months": 2,
+ "days": 20,
+ "microseconds": "100"
+ }
+ }, {
+ "calendarInterval": {
+ "months": 2,
+ "days": 21,
+ "microseconds": "200"
+ }
+ }]
+ }
+ }
+ }]
+ }
+}
\ No newline at end of file
diff --git
a/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin
b/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin
new file mode 100644
index 00000000000..9763bed6b50
Binary files /dev/null and
b/connector/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin
differ
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
index 6ddaabb1b88..79c489b9f5b 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
@@ -17,8 +17,12 @@
package org.apache.spark.sql.connect.planner
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter,
InvalidPlanInput}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -97,6 +101,10 @@ object LiteralValueProtoConverter {
case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
expressions.Literal(lit.getDayTimeInterval, DayTimeIntervalType())
+ case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
+ expressions.Literal.create(
+ toArrayData(lit.getArray),
+
ArrayType(DataTypeProtoConverter.toCatalystType(lit.getArray.getElementType)))
case _ =>
throw InvalidPlanInput(
s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
@@ -130,4 +138,61 @@ object LiteralValueProtoConverter {
case o => throw new Exception(s"Unsupported value type: $o")
}
}
+
+ private def toArrayData(array: proto.Expression.Literal.Array): Any = {
+ def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit
+ tag: ClassTag[T]): Array[T] = {
+ val builder = mutable.ArrayBuilder.make[T]
+ val elementList = array.getElementList
+ builder.sizeHint(elementList.size())
+ val iter = elementList.iterator()
+ while (iter.hasNext) {
+ builder += converter(iter.next())
+ }
+ builder.result()
+ }
+
+ val elementType = array.getElementType
+ if (elementType.hasShort) {
+ makeArrayData(v => v.getShort.toShort)
+ } else if (elementType.hasInteger) {
+ makeArrayData(v => v.getInteger)
+ } else if (elementType.hasLong) {
+ makeArrayData(v => v.getLong)
+ } else if (elementType.hasDouble) {
+ makeArrayData(v => v.getDouble)
+ } else if (elementType.hasByte) {
+ makeArrayData(v => v.getByte.toByte)
+ } else if (elementType.hasFloat) {
+ makeArrayData(v => v.getFloat)
+ } else if (elementType.hasBoolean) {
+ makeArrayData(v => v.getBoolean)
+ } else if (elementType.hasString) {
+ makeArrayData(v => v.getString)
+ } else if (elementType.hasBinary) {
+ makeArrayData(v => v.getBinary.toByteArray)
+ } else if (elementType.hasDate) {
+ makeArrayData(v => DateTimeUtils.toJavaDate(v.getDate))
+ } else if (elementType.hasTimestamp) {
+ makeArrayData(v => DateTimeUtils.toJavaTimestamp(v.getTimestamp))
+ } else if (elementType.hasTimestampNtz) {
+ makeArrayData(v =>
DateTimeUtils.microsToLocalDateTime(v.getTimestampNtz))
+ } else if (elementType.hasDayTimeInterval) {
+ makeArrayData(v => IntervalUtils.microsToDuration(v.getDayTimeInterval))
+ } else if (elementType.hasYearMonthInterval) {
+ makeArrayData(v => IntervalUtils.monthsToPeriod(v.getYearMonthInterval))
+ } else if (elementType.hasDecimal) {
+ makeArrayData(v => Decimal(v.getDecimal.getValue))
+ } else if (elementType.hasCalendarInterval) {
+ makeArrayData(v => {
+ val interval = v.getCalendarInterval
+ new CalendarInterval(interval.getMonths, interval.getDays,
interval.getMicroseconds)
+ })
+ } else if (elementType.hasArray) {
+ makeArrayData(v => toArrayData(v.getArray))
+ } else {
+ throw InvalidPlanInput(s"Unsupported Literal Type: $elementType)")
+ }
+ }
+
}
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py
b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 6e515235c7d..d0db2ad56cc 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as
spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xe6%\n\nExpression\x12=\n\x07literal\x18\x01
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct
[...]
+
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xa8\'\n\nExpression\x12=\n\x07literal\x18\x01
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunc
[...]
)
@@ -49,6 +49,7 @@ _EXPRESSION_CAST = _EXPRESSION.nested_types_by_name["Cast"]
_EXPRESSION_LITERAL = _EXPRESSION.nested_types_by_name["Literal"]
_EXPRESSION_LITERAL_DECIMAL =
_EXPRESSION_LITERAL.nested_types_by_name["Decimal"]
_EXPRESSION_LITERAL_CALENDARINTERVAL =
_EXPRESSION_LITERAL.nested_types_by_name["CalendarInterval"]
+_EXPRESSION_LITERAL_ARRAY = _EXPRESSION_LITERAL.nested_types_by_name["Array"]
_EXPRESSION_UNRESOLVEDATTRIBUTE =
_EXPRESSION.nested_types_by_name["UnresolvedAttribute"]
_EXPRESSION_UNRESOLVEDFUNCTION =
_EXPRESSION.nested_types_by_name["UnresolvedFunction"]
_EXPRESSION_EXPRESSIONSTRING =
_EXPRESSION.nested_types_by_name["ExpressionString"]
@@ -142,6 +143,15 @@ Expression = _reflection.GeneratedProtocolMessageType(
#
@@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.CalendarInterval)
},
),
+ "Array": _reflection.GeneratedProtocolMessageType(
+ "Array",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _EXPRESSION_LITERAL_ARRAY,
+ "__module__": "spark.connect.expressions_pb2"
+ #
@@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.Array)
+ },
+ ),
"DESCRIPTOR": _EXPRESSION_LITERAL,
"__module__": "spark.connect.expressions_pb2"
#
@@protoc_insertion_point(class_scope:spark.connect.Expression.Literal)
@@ -251,6 +261,7 @@ _sym_db.RegisterMessage(Expression.Cast)
_sym_db.RegisterMessage(Expression.Literal)
_sym_db.RegisterMessage(Expression.Literal.Decimal)
_sym_db.RegisterMessage(Expression.Literal.CalendarInterval)
+_sym_db.RegisterMessage(Expression.Literal.Array)
_sym_db.RegisterMessage(Expression.UnresolvedAttribute)
_sym_db.RegisterMessage(Expression.UnresolvedFunction)
_sym_db.RegisterMessage(Expression.ExpressionString)
@@ -300,7 +311,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options =
b"\n\036org.apache.spark.connect.protoP\001"
_EXPRESSION._serialized_start = 105
- _EXPRESSION._serialized_end = 4943
+ _EXPRESSION._serialized_end = 5137
_EXPRESSION_WINDOW._serialized_start = 1475
_EXPRESSION_WINDOW._serialized_end = 2258
_EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1765
@@ -318,35 +329,37 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_EXPRESSION_CAST._serialized_start = 2689
_EXPRESSION_CAST._serialized_end = 2834
_EXPRESSION_LITERAL._serialized_start = 2837
- _EXPRESSION_LITERAL._serialized_end = 3713
- _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3480
- _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3597
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3599
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3697
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3715
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3827
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3830
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4034
- _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4036
- _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4086
- _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4088
- _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4170
- _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4172
- _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4258
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4261
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4393
- _EXPRESSION_UPDATEFIELDS._serialized_start = 4396
- _EXPRESSION_UPDATEFIELDS._serialized_end = 4583
- _EXPRESSION_ALIAS._serialized_start = 4585
- _EXPRESSION_ALIAS._serialized_end = 4705
- _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4708
- _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4866
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4868
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4930
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4946
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5257
- _PYTHONUDF._serialized_start = 5260
- _PYTHONUDF._serialized_end = 5390
- _SCALARSCALAUDF._serialized_start = 5393
- _SCALARSCALAUDF._serialized_end = 5577
+ _EXPRESSION_LITERAL._serialized_end = 3907
+ _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3545
+ _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3662
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3664
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3762
+ _EXPRESSION_LITERAL_ARRAY._serialized_start = 3764
+ _EXPRESSION_LITERAL_ARRAY._serialized_end = 3891
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3909
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4021
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4024
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4228
+ _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4230
+ _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4280
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4282
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4364
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4366
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4452
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4455
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4587
+ _EXPRESSION_UPDATEFIELDS._serialized_start = 4590
+ _EXPRESSION_UPDATEFIELDS._serialized_end = 4777
+ _EXPRESSION_ALIAS._serialized_start = 4779
+ _EXPRESSION_ALIAS._serialized_end = 4899
+ _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4902
+ _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5060
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5062
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5124
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5140
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5451
+ _PYTHONUDF._serialized_start = 5454
+ _PYTHONUDF._serialized_end = 5584
+ _SCALARSCALAUDF._serialized_start = 5587
+ _SCALARSCALAUDF._serialized_end = 5771
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 996de7fef2d..37db24ff91a 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -440,6 +440,35 @@ class Expression(google.protobuf.message.Message):
],
) -> None: ...
+ class Array(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ ELEMENTTYPE_FIELD_NUMBER: builtins.int
+ ELEMENT_FIELD_NUMBER: builtins.int
+ @property
+ def elementType(self) ->
pyspark.sql.connect.proto.types_pb2.DataType: ...
+ @property
+ def element(
+ self,
+ ) ->
google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ global___Expression.Literal
+ ]: ...
+ def __init__(
+ self,
+ *,
+ elementType: pyspark.sql.connect.proto.types_pb2.DataType |
None = ...,
+ element: collections.abc.Iterable[global___Expression.Literal]
| None = ...,
+ ) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["elementType",
b"elementType"]
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "element", b"element", "elementType", b"elementType"
+ ],
+ ) -> None: ...
+
NULL_FIELD_NUMBER: builtins.int
BINARY_FIELD_NUMBER: builtins.int
BOOLEAN_FIELD_NUMBER: builtins.int
@@ -457,6 +486,7 @@ class Expression(google.protobuf.message.Message):
CALENDAR_INTERVAL_FIELD_NUMBER: builtins.int
YEAR_MONTH_INTERVAL_FIELD_NUMBER: builtins.int
DAY_TIME_INTERVAL_FIELD_NUMBER: builtins.int
+ ARRAY_FIELD_NUMBER: builtins.int
@property
def null(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
binary: builtins.bytes
@@ -480,6 +510,8 @@ class Expression(google.protobuf.message.Message):
def calendar_interval(self) ->
global___Expression.Literal.CalendarInterval: ...
year_month_interval: builtins.int
day_time_interval: builtins.int
+ @property
+ def array(self) -> global___Expression.Literal.Array: ...
def __init__(
self,
*,
@@ -500,10 +532,13 @@ class Expression(google.protobuf.message.Message):
calendar_interval: global___Expression.Literal.CalendarInterval |
None = ...,
year_month_interval: builtins.int = ...,
day_time_interval: builtins.int = ...,
+ array: global___Expression.Literal.Array | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
+ "array",
+ b"array",
"binary",
b"binary",
"boolean",
@@ -545,6 +580,8 @@ class Expression(google.protobuf.message.Message):
def ClearField(
self,
field_name: typing_extensions.Literal[
+ "array",
+ b"array",
"binary",
b"binary",
"boolean",
@@ -603,6 +640,7 @@ class Expression(google.protobuf.message.Message):
"calendar_interval",
"year_month_interval",
"day_time_interval",
+ "array",
] | None: ...
class UnresolvedAttribute(google.protobuf.message.Message):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]