This is an automated email from the ASF dual-hosted git repository.
bowenliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kyuubi.git
The following commit(s) were added to refs/heads/master by this push:
new 9047151d8 [KYUUBI #5851] Generalize TRowSet generators
9047151d8 is described below
commit 9047151d8dd24cb5834770c9b216438838b7f9ed
Author: Bowen Liang <[email protected]>
AuthorDate: Fri Dec 15 17:44:07 2023 +0800
[KYUUBI #5851] Generalize TRowSet generators
# :mag: Description
## Issue References :link:
As described in #5851.
## Describe Your Solution :wrench:
- Introduced a generalized TRowSet generator,
`AbstractTRowSetGenerator[SchemaT, RowT, ColumnT]`
  - extracts the common logic for iterating over rows and assembling them
into a TRowSet
  - supports generating either a column-based or a row-based TRowSet
- Each engine provides a sub-generator of `AbstractTRowSetGenerator`, which
  - focuses on the mapping and conversion from the engine's data types to
the corresponding Thrift types
  - implements the schema data type and column value methods
- A generator instance is created per use, instead of the previously shared
`RowSet` object, to isolate session-aware or thread-aware configs and
context, e.g. the time zone ID for Flink and the Hive time formatters for
Spark. A minimal sketch of the pattern follows this list.
- This PR covers TRowSet generation for the server and for the
Spark/Flink/Trino/Chat engines; the JDBC engine will be supported in
follow-ups alongside JDBC dialect support.
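For illustration, here is a minimal, self-contained sketch of the pattern.
The names below (`TinyRowSet`, `SimpleTRowSetGenerator`,
`StringEngineGenerator`) are hypothetical stand-ins, not Kyuubi APIs: the
real trait is `org.apache.kyuubi.engine.schema.AbstractTRowSetGenerator`,
and it assembles Thrift `TRowSet`/`TColumn`/`TColumnValue` structures, as
shown in the diff below.

```scala
// Toy model of the generator pattern: the abstract class owns the looping
// and assembly, each engine only maps its own row/schema types to cells.
final case class TinyRowSet(columnBased: Boolean, cells: Seq[Seq[String]])

abstract class SimpleTRowSetGenerator[SchemaT, RowT] {
  // engine-specific mapping methods, implemented by each sub-generator
  def getColumnSize(schema: SchemaT): Int
  def getCell(row: RowT, ordinal: Int): String

  // common looping/assembly logic shared by all engines
  final def toRowBasedSet(rows: Seq[RowT], schema: SchemaT): TinyRowSet =
    TinyRowSet(
      columnBased = false,
      rows.map(row => (0 until getColumnSize(schema)).map(getCell(row, _))))

  final def toColumnBasedSet(rows: Seq[RowT], schema: SchemaT): TinyRowSet =
    TinyRowSet(
      columnBased = true,
      (0 until getColumnSize(schema)).map(col => rows.map(getCell(_, col))))

  // pick the layout from the client protocol version, as the real generator
  // does with HIVE_CLI_SERVICE_PROTOCOL_V6
  final def toTRowSet(rows: Seq[RowT], schema: SchemaT, v6OrLater: Boolean): TinyRowSet =
    if (v6OrLater) toColumnBasedSet(rows, schema) else toRowBasedSet(rows, schema)
}

// a string-only engine, analogous to ChatTRowSetGenerator in this PR; an
// engine with per-session context (e.g. Flink's time zone ID) would take it
// as a constructor argument, creating one generator instance per fetch
class StringEngineGenerator
  extends SimpleTRowSetGenerator[Seq[String], Seq[String]] {
  override def getColumnSize(schema: Seq[String]): Int = schema.length
  override def getCell(row: Seq[String], ordinal: Int): String = row(ordinal)
}

object GeneratorDemo extends App {
  val gen = new StringEngineGenerator
  val schema = Seq("String", "String")
  val rows = Seq(Seq("a1", "a2"), Seq("b1", "b2"))
  println(gen.toTRowSet(rows, schema, v6OrLater = true)) // columns: a1,b1 / a2,b2
  println(gen.toTRowSet(rows, schema, v6OrLater = false)) // rows: a1,a2 / b1,b2
}
```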
## Types of changes :bookmark:
- [ ] Bugfix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
## Test Plan :test_tube:
#### Behavior Without This Pull Request :coffin:
No behavior changes.
#### Behavior With This Pull Request :tada:
No behavior changes.
#### Related Unit Tests
Covered by existing CI tests.
---
# Checklists
## :memo: Author Self Checklist
- [x] My code follows the [style
guidelines](https://kyuubi.readthedocs.io/en/master/contributing/code/style.html)
of this project
- [x] I have performed a self-review
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature
works
- [x] New and existing unit tests pass locally with my changes
- [x] This patch was not authored or co-authored using [Generative
Tooling](https://www.apache.org/legal/generative-tooling.html)
## :memo: Committer Pre-Merge Checklist
- [ ] Pull request title is okay.
- [ ] No license issues.
- [ ] Milestone is correctly set.
- [ ] Test coverage is okay.
- [ ] Assignees are selected.
- [ ] Minimum number of approvals is met.
- [ ] No changes are requested.
**Be nice. Be informative.**
Closes #5851 from bowenliang123/rowset-gen.
Closes #5851
1d2f73ab4 [Bowen Liang] common RowSetGenerator
Authored-by: Bowen Liang <[email protected]>
Signed-off-by: Bowen Liang <[email protected]>
---
.../engine/chat/operation/ChatOperation.scala | 10 +-
.../engine/chat/schema/ChatTRowSetGenerator.scala | 52 +++++
.../apache/kyuubi/engine/chat/schema/RowSet.scala | 107 ---------
.../engine/flink/operation/FlinkOperation.scala | 5 +-
.../flink/schema/FlinkTRowSetGenerator.scala | 141 ++++++++++++
.../apache/kyuubi/engine/flink/schema/RowSet.scala | 239 +--------------------
.../engine/flink/result/ResultSetSuite.scala | 10 +-
.../engine/spark/operation/SparkOperation.scala | 14 +-
.../apache/kyuubi/engine/spark/schema/RowSet.scala | 228 --------------------
.../spark/schema/SparkArrowTRowSetGenerator.scala | 77 +++++++
.../spark/schema/SparkTRowSetGenerator.scala | 93 ++++++++
.../kyuubi/engine/spark/schema/RowSetSuite.scala | 6 +-
.../engine/trino/operation/ExecuteStatement.scala | 5 +-
.../engine/trino/operation/TrinoOperation.scala | 6 +-
.../apache/kyuubi/engine/trino/schema/RowSet.scala | 217 -------------------
.../trino/schema/TrinoTRowSetGenerator.scala | 96 +++++++++
.../kyuubi/engine/trino/schema/RowSetSuite.scala | 6 +-
.../engine/schema/AbstractTRowSetGenerator.scala | 210 ++++++++++++++++++
.../kyuubi/sql/plan/command/RunnableCommand.scala | 4 +-
.../apache/kyuubi/sql/schema/RowSetHelper.scala | 209 ------------------
.../kyuubi/sql/schema/ServerTRowSetGenerator.scala | 78 +++++++
21 files changed, 785 insertions(+), 1028 deletions(-)
diff --git
a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/operation/ChatOperation.scala
b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/operation/ChatOperation.scala
index 9cddc3e66..60f15ea65 100644
---
a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/operation/ChatOperation.scala
+++
b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/operation/ChatOperation.scala
@@ -18,7 +18,8 @@ package org.apache.kyuubi.engine.chat.operation
import org.apache.kyuubi.{KyuubiSQLException, Utils}
import org.apache.kyuubi.config.KyuubiConf
-import org.apache.kyuubi.engine.chat.schema.{RowSet, SchemaHelper}
+import org.apache.kyuubi.engine.chat.schema.{ChatTRowSetGenerator,
SchemaHelper}
+import
org.apache.kyuubi.engine.chat.schema.ChatTRowSetGenerator.COL_STRING_TYPE
import org.apache.kyuubi.operation.{AbstractOperation, FetchIterator,
OperationState}
import org.apache.kyuubi.operation.FetchOrientation.{FETCH_FIRST, FETCH_NEXT,
FETCH_PRIOR, FetchOrientation}
import org.apache.kyuubi.session.Session
@@ -45,8 +46,11 @@ abstract class ChatOperation(session: Session) extends
AbstractOperation(session
iter.fetchAbsolute(0)
}
- val taken = iter.take(rowSetSize)
- val resultRowSet = RowSet.toTRowSet(taken.toSeq, 1, getProtocolVersion)
+ val taken = iter.take(rowSetSize).map(_.toSeq)
+ val resultRowSet = new ChatTRowSetGenerator().toTRowSet(
+ taken.toSeq,
+ Seq(COL_STRING_TYPE),
+ getProtocolVersion)
resultRowSet.setStartRowOffset(iter.getPosition)
val resp = new TFetchResultsResp(OK_STATUS)
resp.setResults(resultRowSet)
diff --git
a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/schema/ChatTRowSetGenerator.scala
b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/schema/ChatTRowSetGenerator.scala
new file mode 100644
index 000000000..990a19764
--- /dev/null
+++
b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/schema/ChatTRowSetGenerator.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.chat.schema
+
+import org.apache.kyuubi.engine.chat.schema.ChatTRowSetGenerator._
+import org.apache.kyuubi.engine.schema.AbstractTRowSetGenerator
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift._
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TTypeId._
+
+class ChatTRowSetGenerator
+ extends AbstractTRowSetGenerator[Seq[String], Seq[String], String] {
+
+ override def getColumnSizeFromSchemaType(schema: Seq[String]): Int =
schema.length
+
+ override def getColumnType(schema: Seq[String], ordinal: Int): String =
COL_STRING_TYPE
+
+ override protected def isColumnNullAt(row: Seq[String], ordinal: Int):
Boolean =
+ row(ordinal) == null
+
+ override def getColumnAs[T](row: Seq[String], ordinal: Int): T =
row(ordinal).asInstanceOf[T]
+
+ override def toTColumn(rows: Seq[Seq[String]], ordinal: Int, typ: String):
TColumn =
+ typ match {
+ case COL_STRING_TYPE => toTTypeColumn(STRING_TYPE, rows, ordinal)
+ case otherType => throw new UnsupportedOperationException(s"type
$otherType")
+ }
+
+ override def toTColumnValue(ordinal: Int, row: Seq[String], types:
Seq[String]): TColumnValue =
+ getColumnType(types, ordinal) match {
+ case "String" => toTTypeColumnVal(STRING_TYPE, row, ordinal)
+ case otherType => throw new UnsupportedOperationException(s"type
$otherType")
+ }
+}
+
+object ChatTRowSetGenerator {
+ val COL_STRING_TYPE: String = classOf[String].getSimpleName
+}
diff --git
a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/schema/RowSet.scala
b/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/schema/RowSet.scala
deleted file mode 100644
index 827940001..000000000
---
a/externals/kyuubi-chat-engine/src/main/scala/org/apache/kyuubi/engine/chat/schema/RowSet.scala
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.kyuubi.engine.chat.schema
-
-import java.util
-
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift._
-import org.apache.kyuubi.util.RowSetUtils._
-
-object RowSet {
-
- def emptyTRowSet(): TRowSet = {
- new TRowSet(0, new java.util.ArrayList[TRow](0))
- }
-
- def toTRowSet(
- rows: Seq[Array[String]],
- columnSize: Int,
- protocolVersion: TProtocolVersion): TRowSet = {
- if (protocolVersion.getValue <
TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
- toRowBasedSet(rows, columnSize)
- } else {
- toColumnBasedSet(rows, columnSize)
- }
- }
-
- def toRowBasedSet(rows: Seq[Array[String]], columnSize: Int): TRowSet = {
- val rowSize = rows.length
- val tRows = new java.util.ArrayList[TRow](rowSize)
- var i = 0
- while (i < rowSize) {
- val row = rows(i)
- val tRow = new TRow()
- var j = 0
- val columnSize = row.length
- while (j < columnSize) {
- val columnValue = stringTColumnValue(j, row)
- tRow.addToColVals(columnValue)
- j += 1
- }
- i += 1
- tRows.add(tRow)
- }
- new TRowSet(0, tRows)
- }
-
- def toColumnBasedSet(rows: Seq[Array[String]], columnSize: Int): TRowSet = {
- val rowSize = rows.length
- val tRowSet = new TRowSet(0, new util.ArrayList[TRow](rowSize))
- var i = 0
- while (i < columnSize) {
- val tColumn = toTColumn(rows, i)
- tRowSet.addToColumns(tColumn)
- i += 1
- }
- tRowSet
- }
-
- private def toTColumn(rows: Seq[Array[String]], ordinal: Int): TColumn = {
- val nulls = new java.util.BitSet()
- val values = getOrSetAsNull[String](rows, ordinal, nulls, "")
- TColumn.stringVal(new TStringColumn(values, nulls))
- }
-
- private def getOrSetAsNull[String](
- rows: Seq[Array[String]],
- ordinal: Int,
- nulls: util.BitSet,
- defaultVal: String): util.List[String] = {
- val size = rows.length
- val ret = new util.ArrayList[String](size)
- var idx = 0
- while (idx < size) {
- val row = rows(idx)
- val isNull = row(ordinal) == null
- if (isNull) {
- nulls.set(idx, true)
- ret.add(idx, defaultVal)
- } else {
- ret.add(idx, row(ordinal))
- }
- idx += 1
- }
- ret
- }
-
- private def stringTColumnValue(ordinal: Int, row: Array[String]):
TColumnValue = {
- val tStringValue = new TStringValue
- if (row(ordinal) != null) tStringValue.setValue(row(ordinal))
- TColumnValue.stringVal(tStringValue)
- }
-}
diff --git
a/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/operation/FlinkOperation.scala
b/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/operation/FlinkOperation.scala
index ff2e99c0c..df067a888 100644
---
a/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/operation/FlinkOperation.scala
+++
b/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/operation/FlinkOperation.scala
@@ -31,7 +31,7 @@ import org.apache.flink.types.Row
import org.apache.kyuubi.{KyuubiSQLException, Utils}
import org.apache.kyuubi.engine.flink.result.ResultSet
-import org.apache.kyuubi.engine.flink.schema.RowSet
+import org.apache.kyuubi.engine.flink.schema.{FlinkTRowSetGenerator, RowSet}
import org.apache.kyuubi.engine.flink.session.FlinkSessionImpl
import org.apache.kyuubi.operation.{AbstractOperation, OperationState}
import org.apache.kyuubi.operation.FetchOrientation.{FETCH_FIRST, FETCH_NEXT,
FETCH_PRIOR, FetchOrientation}
@@ -133,10 +133,9 @@ abstract class FlinkOperation(session: Session) extends
AbstractOperation(sessio
case Some(tz) => ZoneId.of(tz)
case None => ZoneId.systemDefault()
}
- val resultRowSet = RowSet.resultSetToTRowSet(
+ val resultRowSet = new FlinkTRowSetGenerator(zoneId).toTRowSet(
batch.toList,
resultSet,
- zoneId,
getProtocolVersion)
val resp = new TFetchResultsResp(OK_STATUS)
resp.setResults(resultRowSet)
diff --git
a/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/FlinkTRowSetGenerator.scala
b/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/FlinkTRowSetGenerator.scala
new file mode 100644
index 000000000..b53aab47f
--- /dev/null
+++
b/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/FlinkTRowSetGenerator.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kyuubi.engine.flink.schema
+
+import java.time.{Instant, ZonedDateTime, ZoneId}
+
+import scala.collection.JavaConverters._
+
+import org.apache.flink.table.data.StringData
+import org.apache.flink.table.types.logical._
+import org.apache.flink.types.Row
+
+import org.apache.kyuubi.engine.flink.result.ResultSet
+import org.apache.kyuubi.engine.flink.schema.RowSet.{toHiveString,
TIMESTAMP_LZT_FORMATTER}
+import org.apache.kyuubi.engine.schema.AbstractTRowSetGenerator
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift._
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TTypeId._
+import org.apache.kyuubi.util.RowSetUtils.bitSetToBuffer
+
+class FlinkTRowSetGenerator(zoneId: ZoneId)
+ extends AbstractTRowSetGenerator[ResultSet, Row, LogicalType] {
+ override def getColumnSizeFromSchemaType(schema: ResultSet): Int =
schema.columns.size
+
+ override def getColumnType(schema: ResultSet, ordinal: Int): LogicalType =
+ schema.columns.get(ordinal).getDataType.getLogicalType
+
+ override def isColumnNullAt(row: Row, ordinal: Int): Boolean =
row.getField(ordinal) == null
+
+ override def getColumnAs[T](row: Row, ordinal: Int): T =
row.getFieldAs[T](ordinal)
+
+ override def toTColumnValue(ordinal: Int, row: Row, types: ResultSet):
TColumnValue = {
+ getColumnType(types, ordinal) match {
+ case _: BooleanType => toTTypeColumnVal(BOOLEAN_TYPE, row, ordinal)
+ case _: TinyIntType => toTTypeColumnVal(BINARY_TYPE, row, ordinal)
+ case _: SmallIntType => toTTypeColumnVal(TINYINT_TYPE, row, ordinal)
+ case _: IntType => toTTypeColumnVal(INT_TYPE, row, ordinal)
+ case _: BigIntType => toTTypeColumnVal(BIGINT_TYPE, row, ordinal)
+ case _: DoubleType => toTTypeColumnVal(DOUBLE_TYPE, row, ordinal)
+ case _: FloatType => toTTypeColumnVal(FLOAT_TYPE, row, ordinal)
+ case t @ (_: VarCharType | _: CharType) =>
+ val tStringValue = new TStringValue
+ val fieldValue = row.getField(ordinal)
+ fieldValue match {
+ case value: String =>
+ tStringValue.setValue(value)
+ case value: StringData =>
+ tStringValue.setValue(value.toString)
+ case null =>
+ tStringValue.setValue(null)
+ case other =>
+ throw new IllegalArgumentException(
+ s"Unsupported conversion class ${other.getClass} " +
+ s"for type ${t.getClass}.")
+ }
+ TColumnValue.stringVal(tStringValue)
+ case _: LocalZonedTimestampType =>
+ val tStringValue = new TStringValue
+ val fieldValue = row.getField(ordinal)
+ tStringValue.setValue(TIMESTAMP_LZT_FORMATTER.format(
+ ZonedDateTime.ofInstant(fieldValue.asInstanceOf[Instant], zoneId)))
+ TColumnValue.stringVal(tStringValue)
+ case t =>
+ val tStringValue = new TStringValue
+ if (row.getField(ordinal) != null) {
+ tStringValue.setValue(toHiveString((row.getField(ordinal), t)))
+ }
+ TColumnValue.stringVal(tStringValue)
+ }
+ }
+
+ override def toTColumn(rows: Seq[Row], ordinal: Int, logicalType:
LogicalType): TColumn = {
+ val nulls = new java.util.BitSet()
+    // for each column, determine the conversion class by sampling the first
non-null value
+    // if there are no rows, set the entire column empty
+ val sampleField = rows.iterator.map(_.getField(ordinal)).find(_ ne
null).orNull
+ logicalType match {
+ case _: BooleanType => toTTypeColumn(BOOLEAN_TYPE, rows, ordinal)
+ case _: TinyIntType => toTTypeColumn(BINARY_TYPE, rows, ordinal)
+ case _: SmallIntType => toTTypeColumn(TINYINT_TYPE, rows, ordinal)
+ case _: IntType => toTTypeColumn(INT_TYPE, rows, ordinal)
+ case _: BigIntType => toTTypeColumn(BIGINT_TYPE, rows, ordinal)
+ case _: FloatType => toTTypeColumn(FLOAT_TYPE, rows, ordinal)
+ case _: DoubleType => toTTypeColumn(DOUBLE_TYPE, rows, ordinal)
+ case t @ (_: VarCharType | _: CharType) =>
+ val values: java.util.List[String] = new java.util.ArrayList[String](0)
+ sampleField match {
+ case _: String =>
+ values.addAll(getOrSetAsNull[String](rows, ordinal, nulls, ""))
+ case _: StringData =>
+ val stringDataValues =
+ getOrSetAsNull[StringData](rows, ordinal, nulls,
StringData.fromString(""))
+ stringDataValues.forEach(e => values.add(e.toString))
+ case null =>
+ values.addAll(getOrSetAsNull[String](rows, ordinal, nulls, ""))
+ case other =>
+ throw new IllegalArgumentException(
+ s"Unsupported conversion class ${other.getClass} " +
+ s"for type ${t.getClass}.")
+ }
+ TColumn.stringVal(new TStringColumn(values, nulls))
+ case _: LocalZonedTimestampType =>
+ val values = getOrSetAsNull[Instant](rows, ordinal, nulls,
Instant.EPOCH)
+ .toArray().map(v =>
+ TIMESTAMP_LZT_FORMATTER.format(
+ ZonedDateTime.ofInstant(v.asInstanceOf[Instant], zoneId)))
+ TColumn.stringVal(new TStringColumn(values.toList.asJava, nulls))
+ case _ =>
+ var i = 0
+ val rowSize = rows.length
+ val values = new java.util.ArrayList[String](rowSize)
+ while (i < rowSize) {
+ val row = rows(i)
+ nulls.set(i, row.getField(ordinal) == null)
+ val value =
+ if (row.getField(ordinal) == null) {
+ ""
+ } else {
+ toHiveString((row.getField(ordinal), logicalType))
+ }
+ values.add(value)
+ i += 1
+ }
+ TColumn.stringVal(new TStringColumn(values, nulls))
+ }
+ }
+
+}
diff --git
a/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/RowSet.scala
b/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/RowSet.scala
index a000869cc..7015d7c52 100644
---
a/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/RowSet.scala
+++
b/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/RowSet.scala
@@ -17,262 +17,25 @@
package org.apache.kyuubi.engine.flink.schema
-import java.{lang, util}
-import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
-import java.time.{Instant, LocalDate, LocalDateTime, ZonedDateTime, ZoneId}
+import java.time.{LocalDate, LocalDateTime}
import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder,
TextStyle}
import java.time.temporal.ChronoField
import java.util.Collections
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
-import scala.language.implicitConversions
import org.apache.flink.table.catalog.Column
-import org.apache.flink.table.data.StringData
import org.apache.flink.table.types.logical._
import org.apache.flink.types.Row
-import org.apache.kyuubi.engine.flink.result.ResultSet
import org.apache.kyuubi.shaded.hive.service.rpc.thrift._
import org.apache.kyuubi.util.RowSetUtils._
object RowSet {
- def resultSetToTRowSet(
- rows: Seq[Row],
- resultSet: ResultSet,
- zoneId: ZoneId,
- protocolVersion: TProtocolVersion): TRowSet = {
- if (protocolVersion.getValue <
TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
- toRowBaseSet(rows, resultSet, zoneId)
- } else {
- toColumnBasedSet(rows, resultSet, zoneId)
- }
- }
-
- def toRowBaseSet(rows: Seq[Row], resultSet: ResultSet, zoneId: ZoneId):
TRowSet = {
- val rowSize = rows.size
- val tRows = new util.ArrayList[TRow](rowSize)
- var i = 0
- while (i < rowSize) {
- val row = rows(i)
- val tRow = new TRow()
- val columnSize = row.getArity
- var j = 0
- while (j < columnSize) {
- val columnValue = toTColumnValue(j, row, resultSet, zoneId)
- tRow.addToColVals(columnValue)
- j += 1
- }
- tRows.add(tRow)
- i += 1
- }
-
- new TRowSet(0, tRows)
- }
-
- def toColumnBasedSet(rows: Seq[Row], resultSet: ResultSet, zoneId: ZoneId):
TRowSet = {
- val size = rows.length
- val tRowSet = new TRowSet(0, new util.ArrayList[TRow](size))
- val columnSize = resultSet.getColumns.size()
- var i = 0
- while (i < columnSize) {
- val field = resultSet.getColumns.get(i)
- val tColumn = toTColumn(rows, i, field.getDataType.getLogicalType,
zoneId)
- tRowSet.addToColumns(tColumn)
- i += 1
- }
- tRowSet
- }
-
- private def toTColumnValue(
- ordinal: Int,
- row: Row,
- resultSet: ResultSet,
- zoneId: ZoneId): TColumnValue = {
-
- val column = resultSet.getColumns.get(ordinal)
- val logicalType = column.getDataType.getLogicalType
-
- logicalType match {
- case _: BooleanType =>
- val boolValue = new TBoolValue
- if (row.getField(ordinal) != null) {
- boolValue.setValue(row.getField(ordinal).asInstanceOf[Boolean])
- }
- TColumnValue.boolVal(boolValue)
- case _: TinyIntType =>
- val tByteValue = new TByteValue
- if (row.getField(ordinal) != null) {
- tByteValue.setValue(row.getField(ordinal).asInstanceOf[Byte])
- }
- TColumnValue.byteVal(tByteValue)
- case _: SmallIntType =>
- val tI16Value = new TI16Value
- if (row.getField(ordinal) != null) {
- tI16Value.setValue(row.getField(ordinal).asInstanceOf[Short])
- }
- TColumnValue.i16Val(tI16Value)
- case _: IntType =>
- val tI32Value = new TI32Value
- if (row.getField(ordinal) != null) {
- tI32Value.setValue(row.getField(ordinal).asInstanceOf[Int])
- }
- TColumnValue.i32Val(tI32Value)
- case _: BigIntType =>
- val tI64Value = new TI64Value
- if (row.getField(ordinal) != null) {
- tI64Value.setValue(row.getField(ordinal).asInstanceOf[Long])
- }
- TColumnValue.i64Val(tI64Value)
- case _: FloatType =>
- val tDoubleValue = new TDoubleValue
- if (row.getField(ordinal) != null) {
- val doubleValue =
lang.Double.valueOf(row.getField(ordinal).asInstanceOf[Float].toString)
- tDoubleValue.setValue(doubleValue)
- }
- TColumnValue.doubleVal(tDoubleValue)
- case _: DoubleType =>
- val tDoubleValue = new TDoubleValue
- if (row.getField(ordinal) != null) {
- tDoubleValue.setValue(row.getField(ordinal).asInstanceOf[Double])
- }
- TColumnValue.doubleVal(tDoubleValue)
- case t @ (_: VarCharType | _: CharType) =>
- val tStringValue = new TStringValue
- val fieldValue = row.getField(ordinal)
- fieldValue match {
- case value: String =>
- tStringValue.setValue(value)
- case value: StringData =>
- tStringValue.setValue(value.toString)
- case null =>
- tStringValue.setValue(null)
- case other =>
- throw new IllegalArgumentException(
- s"Unsupported conversion class ${other.getClass} " +
- s"for type ${t.getClass}.")
- }
- TColumnValue.stringVal(tStringValue)
- case _: LocalZonedTimestampType =>
- val tStringValue = new TStringValue
- val fieldValue = row.getField(ordinal)
- tStringValue.setValue(TIMESTAMP_LZT_FORMATTER.format(
- ZonedDateTime.ofInstant(fieldValue.asInstanceOf[Instant], zoneId)))
- TColumnValue.stringVal(tStringValue)
- case t =>
- val tStringValue = new TStringValue
- if (row.getField(ordinal) != null) {
- tStringValue.setValue(toHiveString((row.getField(ordinal), t)))
- }
- TColumnValue.stringVal(tStringValue)
- }
- }
-
- implicit private def bitSetToBuffer(bitSet: java.util.BitSet): ByteBuffer = {
- ByteBuffer.wrap(bitSet.toByteArray)
- }
-
- private def toTColumn(
- rows: Seq[Row],
- ordinal: Int,
- logicalType: LogicalType,
- zoneId: ZoneId): TColumn = {
- val nulls = new java.util.BitSet()
- // for each column, determine the conversion class by sampling the first
non-value value
- // if there's no row, set the entire column empty
- val sampleField = rows.iterator.map(_.getField(ordinal)).find(_ ne
null).orNull
- logicalType match {
- case _: BooleanType =>
- val values = getOrSetAsNull[lang.Boolean](rows, ordinal, nulls, true)
- TColumn.boolVal(new TBoolColumn(values, nulls))
- case _: TinyIntType =>
- val values = getOrSetAsNull[lang.Byte](rows, ordinal, nulls, 0.toByte)
- TColumn.byteVal(new TByteColumn(values, nulls))
- case _: SmallIntType =>
- val values = getOrSetAsNull[lang.Short](rows, ordinal, nulls,
0.toShort)
- TColumn.i16Val(new TI16Column(values, nulls))
- case _: IntType =>
- val values = getOrSetAsNull[lang.Integer](rows, ordinal, nulls, 0)
- TColumn.i32Val(new TI32Column(values, nulls))
- case _: BigIntType =>
- val values = getOrSetAsNull[lang.Long](rows, ordinal, nulls, 0L)
- TColumn.i64Val(new TI64Column(values, nulls))
- case _: FloatType =>
- val values = getOrSetAsNull[lang.Float](rows, ordinal, nulls, 0.0f)
- .asScala.map(n => lang.Double.valueOf(n.toString)).asJava
- TColumn.doubleVal(new TDoubleColumn(values, nulls))
- case _: DoubleType =>
- val values = getOrSetAsNull[lang.Double](rows, ordinal, nulls, 0.0)
- TColumn.doubleVal(new TDoubleColumn(values, nulls))
- case t @ (_: VarCharType | _: CharType) =>
- val values: util.List[String] = new util.ArrayList[String](0)
- sampleField match {
- case _: String =>
- values.addAll(getOrSetAsNull[String](rows, ordinal, nulls, ""))
- case _: StringData =>
- val stringDataValues =
- getOrSetAsNull[StringData](rows, ordinal, nulls,
StringData.fromString(""))
- stringDataValues.forEach(e => values.add(e.toString))
- case null =>
- values.addAll(getOrSetAsNull[String](rows, ordinal, nulls, ""))
- case other =>
- throw new IllegalArgumentException(
- s"Unsupported conversion class ${other.getClass} " +
- s"for type ${t.getClass}.")
- }
- TColumn.stringVal(new TStringColumn(values, nulls))
- case _: LocalZonedTimestampType =>
- val values = getOrSetAsNull[Instant](rows, ordinal, nulls,
Instant.EPOCH)
- .toArray().map(v =>
- TIMESTAMP_LZT_FORMATTER.format(
- ZonedDateTime.ofInstant(v.asInstanceOf[Instant], zoneId)))
- TColumn.stringVal(new TStringColumn(values.toList.asJava, nulls))
- case _ =>
- var i = 0
- val rowSize = rows.length
- val values = new java.util.ArrayList[String](rowSize)
- while (i < rowSize) {
- val row = rows(i)
- nulls.set(i, row.getField(ordinal) == null)
- val value =
- if (row.getField(ordinal) == null) {
- ""
- } else {
- toHiveString((row.getField(ordinal), logicalType))
- }
- values.add(value)
- i += 1
- }
- TColumn.stringVal(new TStringColumn(values, nulls))
- }
- }
-
- private def getOrSetAsNull[T](
- rows: Seq[Row],
- ordinal: Int,
- nulls: java.util.BitSet,
- defaultVal: T): java.util.List[T] = {
- val size = rows.length
- val ret = new java.util.ArrayList[T](size)
- var idx = 0
- while (idx < size) {
- val row = rows(idx)
- val isNull = row.getField(ordinal) == null
- if (isNull) {
- nulls.set(idx, true)
- ret.add(idx, defaultVal)
- } else {
- ret.add(idx, row.getFieldAs[T](ordinal))
- }
- idx += 1
- }
- ret
- }
-
def toTColumnDesc(field: Column, pos: Int): TColumnDesc = {
val tColumnDesc = new TColumnDesc()
tColumnDesc.setColumnName(field.getName)
diff --git
a/externals/kyuubi-flink-sql-engine/src/test/scala/org/apache/kyuubi/engine/flink/result/ResultSetSuite.scala
b/externals/kyuubi-flink-sql-engine/src/test/scala/org/apache/kyuubi/engine/flink/result/ResultSetSuite.scala
index 9ee5c658b..5e58d433f 100644
---
a/externals/kyuubi-flink-sql-engine/src/test/scala/org/apache/kyuubi/engine/flink/result/ResultSetSuite.scala
+++
b/externals/kyuubi-flink-sql-engine/src/test/scala/org/apache/kyuubi/engine/flink/result/ResultSetSuite.scala
@@ -25,7 +25,7 @@ import org.apache.flink.table.data.StringData
import org.apache.flink.types.Row
import org.apache.kyuubi.KyuubiFunSuite
-import org.apache.kyuubi.engine.flink.schema.RowSet
+import org.apache.kyuubi.engine.flink.schema.FlinkTRowSetGenerator
class ResultSetSuite extends KyuubiFunSuite {
@@ -47,9 +47,9 @@ class ResultSetSuite extends KyuubiFunSuite {
.build
val timeZone = ZoneId.of("America/Los_Angeles")
- assert(RowSet.toRowBaseSet(rowsNew, resultSetNew, timeZone)
- === RowSet.toRowBaseSet(rowsOld, resultSetOld, timeZone))
- assert(RowSet.toColumnBasedSet(rowsNew, resultSetNew, timeZone)
- === RowSet.toColumnBasedSet(rowsOld, resultSetOld, timeZone))
+ assert(new FlinkTRowSetGenerator(timeZone).toRowBasedSet(rowsNew,
resultSetNew)
+ === new FlinkTRowSetGenerator(timeZone).toRowBasedSet(rowsOld,
resultSetOld))
+ assert(new FlinkTRowSetGenerator(timeZone).toColumnBasedSet(rowsNew,
resultSetNew)
+ === new FlinkTRowSetGenerator(timeZone).toColumnBasedSet(rowsOld,
resultSetOld))
}
}
diff --git
a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala
b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala
index e2fc80c6b..1d271cfce 100644
---
a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala
+++
b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala
@@ -24,7 +24,7 @@ import org.apache.spark.kyuubi.{SparkProgressMonitor,
SQLOperationListener}
import org.apache.spark.kyuubi.SparkUtilsHelper.redact
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
import org.apache.kyuubi.{KyuubiSQLException, Utils}
import org.apache.kyuubi.config.KyuubiConf
@@ -33,7 +33,7 @@ import
org.apache.kyuubi.config.KyuubiReservedKeys.{KYUUBI_SESSION_SIGN_PUBLICKE
import org.apache.kyuubi.engine.spark.KyuubiSparkUtil.{getSessionConf,
SPARK_SCHEDULER_POOL_KEY}
import org.apache.kyuubi.engine.spark.events.SparkOperationEvent
import org.apache.kyuubi.engine.spark.operation.SparkOperation.TIMEZONE_KEY
-import org.apache.kyuubi.engine.spark.schema.{RowSet, SchemaHelper}
+import org.apache.kyuubi.engine.spark.schema.{SchemaHelper,
SparkArrowTRowSetGenerator, SparkTRowSetGenerator}
import org.apache.kyuubi.engine.spark.session.SparkSessionImpl
import org.apache.kyuubi.events.EventBus
import org.apache.kyuubi.operation.{AbstractOperation, FetchIterator,
OperationState, OperationStatus}
@@ -42,6 +42,7 @@ import
org.apache.kyuubi.operation.OperationState.OperationState
import org.apache.kyuubi.operation.log.OperationLog
import org.apache.kyuubi.session.Session
import org.apache.kyuubi.shaded.hive.service.rpc.thrift.{TFetchResultsResp,
TGetResultSetMetadataResp, TProgressUpdateResp, TRowSet}
+import org.apache.kyuubi.util.ThriftUtils
abstract class SparkOperation(session: Session)
extends AbstractOperation(session) {
@@ -243,13 +244,16 @@ abstract class SparkOperation(session: Session)
if (isArrowBasedOperation) {
if (iter.hasNext) {
val taken = iter.next().asInstanceOf[Array[Byte]]
- RowSet.toTRowSet(taken, getProtocolVersion)
+ new SparkArrowTRowSetGenerator().toTRowSet(
+ Seq(taken),
+ new StructType().add(StructField(null, BinaryType)),
+ getProtocolVersion)
} else {
- RowSet.emptyTRowSet()
+ ThriftUtils.newEmptyRowSet
}
} else {
val taken = iter.take(rowSetSize)
- RowSet.toTRowSet(
+ new SparkTRowSetGenerator().toTRowSet(
taken.toSeq.asInstanceOf[Seq[Row]],
resultSchema,
getProtocolVersion)
diff --git
a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala
b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala
index 806451907..c5f322108 100644
---
a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala
+++
b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala
@@ -17,18 +17,10 @@
package org.apache.kyuubi.engine.spark.schema
-import java.nio.ByteBuffer
-
-import scala.collection.JavaConverters._
-
-import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.HiveResult
import org.apache.spark.sql.execution.HiveResult.TimeFormatters
import org.apache.spark.sql.types._
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift._
-import org.apache.kyuubi.util.RowSetUtils._
-
object RowSet {
def toHiveString(
@@ -38,224 +30,4 @@ object RowSet {
HiveResult.toHiveString(valueAndType, nested, timeFormatters)
}
- def toTRowSet(
- bytes: Array[Byte],
- protocolVersion: TProtocolVersion): TRowSet = {
- if (protocolVersion.getValue <
TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
- throw new UnsupportedOperationException
- } else {
- toColumnBasedSet(bytes)
- }
- }
-
- def emptyTRowSet(): TRowSet = {
- new TRowSet(0, new java.util.ArrayList[TRow](0))
- }
-
- def toColumnBasedSet(data: Array[Byte]): TRowSet = {
- val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](1))
- val tColumn = toTColumn(data)
- tRowSet.addToColumns(tColumn)
- tRowSet
- }
-
- def toTRowSet(
- rows: Seq[Row],
- schema: StructType,
- protocolVersion: TProtocolVersion): TRowSet = {
- if (protocolVersion.getValue <
TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
- toRowBasedSet(rows, schema)
- } else {
- toColumnBasedSet(rows, schema)
- }
- }
-
- def toRowBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
- val rowSize = rows.length
- val tRows = new java.util.ArrayList[TRow](rowSize)
- val timeFormatters = HiveResult.getTimeFormatters
- var i = 0
- while (i < rowSize) {
- val row = rows(i)
- var j = 0
- val columnSize = row.length
- val tColumnValues = new java.util.ArrayList[TColumnValue](columnSize)
- while (j < columnSize) {
- val columnValue = toTColumnValue(j, row, schema, timeFormatters)
- tColumnValues.add(columnValue)
- j += 1
- }
- i += 1
- val tRow = new TRow(tColumnValues)
- tRows.add(tRow)
- }
- new TRowSet(0, tRows)
- }
-
- def toColumnBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
- val rowSize = rows.length
- val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](rowSize))
- val timeFormatters = HiveResult.getTimeFormatters
- var i = 0
- val columnSize = schema.length
- val tColumns = new java.util.ArrayList[TColumn](columnSize)
- while (i < columnSize) {
- val field = schema(i)
- val tColumn = toTColumn(rows, i, field.dataType, timeFormatters)
- tColumns.add(tColumn)
- i += 1
- }
- tRowSet.setColumns(tColumns)
- tRowSet
- }
-
- private def toTColumn(
- rows: Seq[Row],
- ordinal: Int,
- typ: DataType,
- timeFormatters: TimeFormatters): TColumn = {
- val nulls = new java.util.BitSet()
- typ match {
- case BooleanType =>
- val values = getOrSetAsNull[java.lang.Boolean](rows, ordinal, nulls,
true)
- TColumn.boolVal(new TBoolColumn(values, nulls))
-
- case ByteType =>
- val values = getOrSetAsNull[java.lang.Byte](rows, ordinal, nulls,
0.toByte)
- TColumn.byteVal(new TByteColumn(values, nulls))
-
- case ShortType =>
- val values = getOrSetAsNull[java.lang.Short](rows, ordinal, nulls,
0.toShort)
- TColumn.i16Val(new TI16Column(values, nulls))
-
- case IntegerType =>
- val values = getOrSetAsNull[java.lang.Integer](rows, ordinal, nulls, 0)
- TColumn.i32Val(new TI32Column(values, nulls))
-
- case LongType =>
- val values = getOrSetAsNull[java.lang.Long](rows, ordinal, nulls, 0L)
- TColumn.i64Val(new TI64Column(values, nulls))
-
- case FloatType =>
- val values = getOrSetAsNull[java.lang.Float](rows, ordinal, nulls,
0.toFloat)
- .asScala.map(n => java.lang.Double.valueOf(n.toString)).asJava
- TColumn.doubleVal(new TDoubleColumn(values, nulls))
-
- case DoubleType =>
- val values = getOrSetAsNull[java.lang.Double](rows, ordinal, nulls,
0.toDouble)
- TColumn.doubleVal(new TDoubleColumn(values, nulls))
-
- case StringType =>
- val values = getOrSetAsNull[java.lang.String](rows, ordinal, nulls, "")
- TColumn.stringVal(new TStringColumn(values, nulls))
-
- case BinaryType =>
- val values = getOrSetAsNull[Array[Byte]](rows, ordinal, nulls, Array())
- .asScala
- .map(ByteBuffer.wrap)
- .asJava
- TColumn.binaryVal(new TBinaryColumn(values, nulls))
-
- case _ =>
- var i = 0
- val rowSize = rows.length
- val values = new java.util.ArrayList[String](rowSize)
- while (i < rowSize) {
- val row = rows(i)
- nulls.set(i, row.isNullAt(ordinal))
- values.add(toHiveString(row.get(ordinal) -> typ, timeFormatters =
timeFormatters))
- i += 1
- }
- TColumn.stringVal(new TStringColumn(values, nulls))
- }
- }
-
- private def getOrSetAsNull[T](
- rows: Seq[Row],
- ordinal: Int,
- nulls: java.util.BitSet,
- defaultVal: T): java.util.List[T] = {
- val size = rows.length
- val ret = new java.util.ArrayList[T](size)
- var idx = 0
- while (idx < size) {
- val row = rows(idx)
- val isNull = row.isNullAt(ordinal)
- if (isNull) {
- nulls.set(idx, true)
- ret.add(idx, defaultVal)
- } else {
- ret.add(idx, row.getAs[T](ordinal))
- }
- idx += 1
- }
- ret
- }
-
- private def toTColumnValue(
- ordinal: Int,
- row: Row,
- types: StructType,
- timeFormatters: TimeFormatters): TColumnValue = {
- types(ordinal).dataType match {
- case BooleanType =>
- val boolValue = new TBoolValue
- if (!row.isNullAt(ordinal)) boolValue.setValue(row.getBoolean(ordinal))
- TColumnValue.boolVal(boolValue)
-
- case ByteType =>
- val byteValue = new TByteValue
- if (!row.isNullAt(ordinal)) byteValue.setValue(row.getByte(ordinal))
- TColumnValue.byteVal(byteValue)
-
- case ShortType =>
- val tI16Value = new TI16Value
- if (!row.isNullAt(ordinal)) tI16Value.setValue(row.getShort(ordinal))
- TColumnValue.i16Val(tI16Value)
-
- case IntegerType =>
- val tI32Value = new TI32Value
- if (!row.isNullAt(ordinal)) tI32Value.setValue(row.getInt(ordinal))
- TColumnValue.i32Val(tI32Value)
-
- case LongType =>
- val tI64Value = new TI64Value
- if (!row.isNullAt(ordinal)) tI64Value.setValue(row.getLong(ordinal))
- TColumnValue.i64Val(tI64Value)
-
- case FloatType =>
- val tDoubleValue = new TDoubleValue
- if (!row.isNullAt(ordinal)) {
- val doubleValue =
java.lang.Double.valueOf(row.getFloat(ordinal).toString)
- tDoubleValue.setValue(doubleValue)
- }
- TColumnValue.doubleVal(tDoubleValue)
-
- case DoubleType =>
- val tDoubleValue = new TDoubleValue
- if (!row.isNullAt(ordinal))
tDoubleValue.setValue(row.getDouble(ordinal))
- TColumnValue.doubleVal(tDoubleValue)
-
- case StringType =>
- val tStringValue = new TStringValue
- if (!row.isNullAt(ordinal))
tStringValue.setValue(row.getString(ordinal))
- TColumnValue.stringVal(tStringValue)
-
- case _ =>
- val tStrValue = new TStringValue
- if (!row.isNullAt(ordinal)) {
- tStrValue.setValue(toHiveString(
- row.get(ordinal) -> types(ordinal).dataType,
- timeFormatters = timeFormatters))
- }
- TColumnValue.stringVal(tStrValue)
- }
- }
-
- private def toTColumn(data: Array[Byte]): TColumn = {
- val values = new java.util.ArrayList[ByteBuffer](1)
- values.add(ByteBuffer.wrap(data))
- val nulls = new java.util.BitSet()
- TColumn.binaryVal(new TBinaryColumn(values, nulls))
- }
}
diff --git
a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/SparkArrowTRowSetGenerator.scala
b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/SparkArrowTRowSetGenerator.scala
new file mode 100644
index 000000000..ded022ad0
--- /dev/null
+++
b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/SparkArrowTRowSetGenerator.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.spark.schema
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.sql.types._
+
+import org.apache.kyuubi.engine.schema.AbstractTRowSetGenerator
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift._
+import org.apache.kyuubi.util.RowSetUtils.bitSetToBuffer
+
+class SparkArrowTRowSetGenerator
+ extends AbstractTRowSetGenerator[StructType, Array[Byte], DataType] {
+ override def toColumnBasedSet(rows: Seq[Array[Byte]], schema: StructType):
TRowSet = {
+    require(schema.length == 1, "ArrowRowSetGenerator accepts only a single
byte array")
+ require(schema.head.dataType == BinaryType, "ArrowRowSetGenerator accepts
only BinaryType")
+
+ val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](1))
+ val tColumn = toTColumn(rows, 1, schema.head.dataType)
+ tRowSet.addToColumns(tColumn)
+ tRowSet
+ }
+
+ override def toTColumn(rows: Seq[Array[Byte]], ordinal: Int, typ: DataType):
TColumn = {
+    require(rows.length == 1, "ArrowRowSetGenerator accepts only a single
byte array")
+ typ match {
+ case BinaryType =>
+ val values = new java.util.ArrayList[ByteBuffer](1)
+ values.add(ByteBuffer.wrap(rows.head))
+ val nulls = new java.util.BitSet()
+ TColumn.binaryVal(new TBinaryColumn(values, nulls))
+ case _ => throw new IllegalArgumentException(
+ s"unsupported datatype $typ, ArrowRowSetGenerator accepts only
BinaryType")
+ }
+ }
+
+ override def toRowBasedSet(rows: Seq[Array[Byte]], schema: StructType):
TRowSet = {
+ throw new UnsupportedOperationException
+ }
+
+ override def getColumnSizeFromSchemaType(schema: StructType): Int = {
+ throw new UnsupportedOperationException
+ }
+
+ override def getColumnType(schema: StructType, ordinal: Int): DataType = {
+ throw new UnsupportedOperationException
+ }
+
+ override def isColumnNullAt(row: Array[Byte], ordinal: Int): Boolean = {
+ throw new UnsupportedOperationException
+ }
+
+ override def getColumnAs[T](row: Array[Byte], ordinal: Int): T = {
+ throw new UnsupportedOperationException
+ }
+
+ override def toTColumnValue(ordinal: Int, row: Array[Byte], types:
StructType): TColumnValue = {
+ throw new UnsupportedOperationException
+ }
+
+}
diff --git
a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/SparkTRowSetGenerator.scala
b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/SparkTRowSetGenerator.scala
new file mode 100644
index 000000000..a35455292
--- /dev/null
+++
b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/SparkTRowSetGenerator.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.spark.schema
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.HiveResult
+import org.apache.spark.sql.execution.HiveResult.TimeFormatters
+import org.apache.spark.sql.types._
+
+import org.apache.kyuubi.engine.schema.AbstractTRowSetGenerator
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift._
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TTypeId._
+import org.apache.kyuubi.util.RowSetUtils.bitSetToBuffer
+
+class SparkTRowSetGenerator
+ extends AbstractTRowSetGenerator[StructType, Row, DataType] {
+
+ // reused time formatters in single RowSet generation, see KYUUBI-5811
+ private val tf = HiveResult.getTimeFormatters
+
+ override def getColumnSizeFromSchemaType(schema: StructType): Int =
schema.length
+
+ override def getColumnType(schema: StructType, ordinal: Int): DataType =
schema(ordinal).dataType
+
+ override def isColumnNullAt(row: Row, ordinal: Int): Boolean =
row.isNullAt(ordinal)
+
+ override def getColumnAs[T](row: Row, ordinal: Int): T =
row.getAs[T](ordinal)
+
+ override def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType): TColumn
= {
+ val timeFormatters: TimeFormatters = tf
+ val nulls = new java.util.BitSet()
+ typ match {
+ case BooleanType => toTTypeColumn(BOOLEAN_TYPE, rows, ordinal)
+ case ByteType => toTTypeColumn(BINARY_TYPE, rows, ordinal)
+ case ShortType => toTTypeColumn(TINYINT_TYPE, rows, ordinal)
+ case IntegerType => toTTypeColumn(INT_TYPE, rows, ordinal)
+ case LongType => toTTypeColumn(BIGINT_TYPE, rows, ordinal)
+ case FloatType => toTTypeColumn(FLOAT_TYPE, rows, ordinal)
+ case DoubleType => toTTypeColumn(DOUBLE_TYPE, rows, ordinal)
+ case StringType => toTTypeColumn(STRING_TYPE, rows, ordinal)
+ case BinaryType => toTTypeColumn(ARRAY_TYPE, rows, ordinal)
+ case _ =>
+ var i = 0
+ val rowSize = rows.length
+ val values = new java.util.ArrayList[String](rowSize)
+ while (i < rowSize) {
+ val row = rows(i)
+ nulls.set(i, row.isNullAt(ordinal))
+ values.add(RowSet.toHiveString(row.get(ordinal) -> typ,
timeFormatters = timeFormatters))
+ i += 1
+ }
+ TColumn.stringVal(new TStringColumn(values, nulls))
+ }
+ }
+
+ override def toTColumnValue(ordinal: Int, row: Row, types: StructType):
TColumnValue = {
+ val timeFormatters: TimeFormatters = tf
+ getColumnType(types, ordinal) match {
+ case BooleanType => toTTypeColumnVal(BOOLEAN_TYPE, row, ordinal)
+ case ByteType => toTTypeColumnVal(BINARY_TYPE, row, ordinal)
+ case ShortType => toTTypeColumnVal(TINYINT_TYPE, row, ordinal)
+ case IntegerType => toTTypeColumnVal(INT_TYPE, row, ordinal)
+ case LongType => toTTypeColumnVal(BIGINT_TYPE, row, ordinal)
+ case FloatType => toTTypeColumnVal(FLOAT_TYPE, row, ordinal)
+ case DoubleType => toTTypeColumnVal(DOUBLE_TYPE, row, ordinal)
+ case StringType => toTTypeColumnVal(STRING_TYPE, row, ordinal)
+ case _ =>
+ val tStrValue = new TStringValue
+ if (!row.isNullAt(ordinal)) {
+ tStrValue.setValue(RowSet.toHiveString(
+ row.get(ordinal) -> types(ordinal).dataType,
+ timeFormatters = timeFormatters))
+ }
+ TColumnValue.stringVal(tStrValue)
+ }
+ }
+
+}
diff --git
a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala
b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala
index dec185897..228bdcaf2 100644
---
a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala
+++
b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala
@@ -99,7 +99,7 @@ class RowSetSuite extends KyuubiFunSuite {
private val rows: Seq[Row] = (0 to 10).map(genRow) ++
Seq(Row.fromSeq(Seq.fill(17)(null)))
test("column based set") {
- val tRowSet = RowSet.toColumnBasedSet(rows, schema)
+ val tRowSet = new SparkTRowSetGenerator().toColumnBasedSet(rows, schema)
assert(tRowSet.getColumns.size() === schema.size)
assert(tRowSet.getRowsSize === 0)
@@ -210,7 +210,7 @@ class RowSetSuite extends KyuubiFunSuite {
}
test("row based set") {
- val tRowSet = RowSet.toRowBasedSet(rows, schema)
+ val tRowSet = new SparkTRowSetGenerator().toRowBasedSet(rows, schema)
assert(tRowSet.getColumnCount === 0)
assert(tRowSet.getRowsSize === rows.size)
val iter = tRowSet.getRowsIterator
@@ -258,7 +258,7 @@ class RowSetSuite extends KyuubiFunSuite {
test("to row set") {
TProtocolVersion.values().foreach { proto =>
- val set = RowSet.toTRowSet(rows, schema, proto)
+ val set = new SparkTRowSetGenerator().toTRowSet(rows, schema, proto)
if (proto.getValue <
TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
assert(!set.isSetColumns, proto.toString)
assert(set.isSetRows, proto.toString)
diff --git
a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/ExecuteStatement.scala
b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/ExecuteStatement.scala
index 4f5049223..3de2ae59f 100644
---
a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/ExecuteStatement.scala
+++
b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/ExecuteStatement.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.RejectedExecutionException
import org.apache.kyuubi.{KyuubiSQLException, Logging}
import org.apache.kyuubi.engine.trino.TrinoStatement
import org.apache.kyuubi.engine.trino.event.TrinoOperationEvent
-import org.apache.kyuubi.engine.trino.schema.RowSet
+import org.apache.kyuubi.engine.trino.schema.TrinoTRowSetGenerator
import org.apache.kyuubi.events.EventBus
import org.apache.kyuubi.operation.{ArrayFetchIterator, FetchIterator,
OperationState}
import org.apache.kyuubi.operation.FetchOrientation.{FETCH_FIRST, FETCH_NEXT,
FETCH_PRIOR, FetchOrientation}
@@ -96,7 +96,8 @@ class ExecuteStatement(
throw KyuubiSQLException(s"Fetch orientation[$order] is not supported
in $mode mode")
}
val taken = iter.take(rowSetSize)
- val resultRowSet = RowSet.toTRowSet(taken.toList, schema,
getProtocolVersion)
+ val resultRowSet = new TrinoTRowSetGenerator()
+ .toTRowSet(taken.toList, schema, getProtocolVersion)
resultRowSet.setStartRowOffset(iter.getPosition)
val fetchResultsResp = new TFetchResultsResp(OK_STATUS)
fetchResultsResp.setResults(resultRowSet)
diff --git
a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperation.scala
b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperation.scala
index d82b11adc..822f1726a 100644
---
a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperation.scala
+++
b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperation.scala
@@ -25,8 +25,7 @@ import io.trino.client.StatementClient
import org.apache.kyuubi.KyuubiSQLException
import org.apache.kyuubi.Utils
import org.apache.kyuubi.engine.trino.TrinoContext
-import org.apache.kyuubi.engine.trino.schema.RowSet
-import org.apache.kyuubi.engine.trino.schema.SchemaHelper
+import org.apache.kyuubi.engine.trino.schema.{SchemaHelper,
TrinoTRowSetGenerator}
import org.apache.kyuubi.engine.trino.session.TrinoSessionImpl
import org.apache.kyuubi.operation.AbstractOperation
import org.apache.kyuubi.operation.FetchIterator
@@ -66,7 +65,8 @@ abstract class TrinoOperation(session: Session) extends
AbstractOperation(sessio
case FETCH_FIRST => iter.fetchAbsolute(0)
}
val taken = iter.take(rowSetSize)
- val resultRowSet = RowSet.toTRowSet(taken.toList, schema,
getProtocolVersion)
+ val resultRowSet =
+ new TrinoTRowSetGenerator().toTRowSet(taken.toSeq, schema,
getProtocolVersion)
resultRowSet.setStartRowOffset(iter.getPosition)
val resp = new TFetchResultsResp(OK_STATUS)
resp.setResults(resultRowSet)
diff --git
a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/schema/RowSet.scala
b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/schema/RowSet.scala
index 2bb16622e..22e09f381 100644
---
a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/schema/RowSet.scala
+++
b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/schema/RowSet.scala
@@ -17,233 +17,16 @@
package org.apache.kyuubi.engine.trino.schema
-import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
-import java.util
import scala.collection.JavaConverters._
import io.trino.client.ClientStandardTypes._
import io.trino.client.ClientTypeSignature
-import io.trino.client.Column
import io.trino.client.Row
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TBinaryColumn
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TBoolColumn
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TBoolValue
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TByteColumn
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TByteValue
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TColumn
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TColumnValue
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TDoubleColumn
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TDoubleValue
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TI16Column
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TI16Value
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TI32Column
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TI32Value
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TI64Column
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TI64Value
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TProtocolVersion
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TRow
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TRowSet
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TStringColumn
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TStringValue
-import org.apache.kyuubi.util.RowSetUtils.bitSetToBuffer
-
object RowSet {
- def toTRowSet(
- rows: Seq[List[_]],
- schema: List[Column],
- protocolVersion: TProtocolVersion): TRowSet = {
- if (protocolVersion.getValue <
TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
- toRowBasedSet(rows, schema)
- } else {
- toColumnBasedSet(rows, schema)
- }
- }
-
- def toRowBasedSet(rows: Seq[List[_]], schema: List[Column]): TRowSet = {
- val rowSize = rows.length
- val tRows = new util.ArrayList[TRow](rowSize)
- var i = 0
- while (i < rowSize) {
- val row = rows(i)
- val tRow = new TRow()
- val columnSize = row.size
- var j = 0
- while (j < columnSize) {
- val columnValue = toTColumnValue(j, row, schema)
- tRow.addToColVals(columnValue)
- j += 1
- }
- tRows.add(tRow)
- i += 1
- }
- new TRowSet(0, tRows)
- }
-
- def toColumnBasedSet(rows: Seq[List[_]], schema: List[Column]): TRowSet = {
- val size = rows.size
- val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](size))
- val columnSize = schema.length
- var i = 0
- while (i < columnSize) {
- val field = schema(i)
- val tColumn = toTColumn(rows, i, field.getTypeSignature)
- tRowSet.addToColumns(tColumn)
- i += 1
- }
- tRowSet
- }
-
- private def toTColumn(
- rows: Seq[Seq[Any]],
- ordinal: Int,
- typ: ClientTypeSignature): TColumn = {
- val nulls = new java.util.BitSet()
- typ.getRawType match {
- case BOOLEAN =>
- val values = getOrSetAsNull[java.lang.Boolean](rows, ordinal, nulls,
true)
- TColumn.boolVal(new TBoolColumn(values, nulls))
-
- case TINYINT =>
- val values = getOrSetAsNull[java.lang.Byte](rows, ordinal, nulls,
0.toByte)
- TColumn.byteVal(new TByteColumn(values, nulls))
-
- case SMALLINT =>
- val values = getOrSetAsNull[java.lang.Short](rows, ordinal, nulls, 0.toShort)
- TColumn.i16Val(new TI16Column(values, nulls))
-
- case INTEGER =>
- val values = getOrSetAsNull[java.lang.Integer](rows, ordinal, nulls, 0)
- TColumn.i32Val(new TI32Column(values, nulls))
-
- case BIGINT =>
- val values = getOrSetAsNull[java.lang.Long](rows, ordinal, nulls, 0L)
- TColumn.i64Val(new TI64Column(values, nulls))
-
- case REAL =>
- val values = getOrSetAsNull[java.lang.Float](rows, ordinal, nulls, 0.toFloat)
- .asScala.map(n => java.lang.Double.valueOf(n.toString)).asJava
- TColumn.doubleVal(new TDoubleColumn(values, nulls))
-
- case DOUBLE =>
- val values = getOrSetAsNull[java.lang.Double](rows, ordinal, nulls, 0.toDouble)
- TColumn.doubleVal(new TDoubleColumn(values, nulls))
-
- case VARCHAR =>
- val values = getOrSetAsNull[String](rows, ordinal, nulls, "")
- TColumn.stringVal(new TStringColumn(values, nulls))
-
- case VARBINARY =>
- val values = getOrSetAsNull[Array[Byte]](rows, ordinal, nulls, Array())
- .asScala
- .map(ByteBuffer.wrap)
- .asJava
- TColumn.binaryVal(new TBinaryColumn(values, nulls))
-
- case _ =>
- val rowSize = rows.length
- val values = new util.ArrayList[String](rowSize)
- var i = 0
- while (i < rowSize) {
- val row = rows(i)
- nulls.set(i, row(ordinal) == null)
- val value =
- if (row(ordinal) == null) {
- ""
- } else {
- toHiveString(row(ordinal), typ)
- }
- values.add(value)
- i += 1
- }
- TColumn.stringVal(new TStringColumn(values, nulls))
- }
- }
-
- private def getOrSetAsNull[T](
- rows: Seq[Seq[Any]],
- ordinal: Int,
- nulls: java.util.BitSet,
- defaultVal: T): java.util.List[T] = {
- val size = rows.length
- val ret = new java.util.ArrayList[T](size)
- var idx = 0
- while (idx < size) {
- val row = rows(idx)
- val isNull = row(ordinal) == null
- if (isNull) {
- nulls.set(idx, true)
- ret.add(idx, defaultVal)
- } else {
- ret.add(idx, row(ordinal).asInstanceOf[T])
- }
- idx += 1
- }
- ret
- }
-
- private def toTColumnValue(
- ordinal: Int,
- row: List[Any],
- types: List[Column]): TColumnValue = {
-
- types(ordinal).getTypeSignature.getRawType match {
- case BOOLEAN =>
- val boolValue = new TBoolValue
- if (row(ordinal) != null) boolValue.setValue(row(ordinal).asInstanceOf[Boolean])
- TColumnValue.boolVal(boolValue)
-
- case TINYINT =>
- val byteValue = new TByteValue
- if (row(ordinal) != null) byteValue.setValue(row(ordinal).asInstanceOf[Byte])
- TColumnValue.byteVal(byteValue)
-
- case SMALLINT =>
- val tI16Value = new TI16Value
- if (row(ordinal) != null) tI16Value.setValue(row(ordinal).asInstanceOf[Short])
- TColumnValue.i16Val(tI16Value)
-
- case INTEGER =>
- val tI32Value = new TI32Value
- if (row(ordinal) != null) tI32Value.setValue(row(ordinal).asInstanceOf[Int])
- TColumnValue.i32Val(tI32Value)
-
- case BIGINT =>
- val tI64Value = new TI64Value
- if (row(ordinal) != null) tI64Value.setValue(row(ordinal).asInstanceOf[Long])
- TColumnValue.i64Val(tI64Value)
-
- case REAL =>
- val tDoubleValue = new TDoubleValue
- if (row(ordinal) != null) {
- val doubleValue = java.lang.Double.valueOf(row(ordinal).asInstanceOf[Float].toString)
- tDoubleValue.setValue(doubleValue)
- }
- TColumnValue.doubleVal(tDoubleValue)
-
- case DOUBLE =>
- val tDoubleValue = new TDoubleValue
- if (row(ordinal) != null) tDoubleValue.setValue(row(ordinal).asInstanceOf[Double])
- TColumnValue.doubleVal(tDoubleValue)
-
- case VARCHAR =>
- val tStringValue = new TStringValue
- if (row(ordinal) != null) tStringValue.setValue(row(ordinal).asInstanceOf[String])
- TColumnValue.stringVal(tStringValue)
-
- case _ =>
- val tStrValue = new TStringValue
- if (row(ordinal) != null) {
- tStrValue.setValue(
- toHiveString(row(ordinal), types(ordinal).getTypeSignature))
- }
- TColumnValue.stringVal(tStrValue)
- }
- }
-
/**
* A simpler impl of Trino's toHiveString
*/
diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/schema/TrinoTRowSetGenerator.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/schema/TrinoTRowSetGenerator.scala
new file mode 100644
index 000000000..9c323a508
--- /dev/null
+++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/schema/TrinoTRowSetGenerator.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.trino.schema
+
+import io.trino.client.{ClientTypeSignature, Column}
+import io.trino.client.ClientStandardTypes._
+
+import org.apache.kyuubi.engine.schema.AbstractTRowSetGenerator
+import org.apache.kyuubi.engine.trino.schema.RowSet.toHiveString
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift._
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TTypeId._
+import org.apache.kyuubi.util.RowSetUtils._
+
+class TrinoTRowSetGenerator
+ extends AbstractTRowSetGenerator[Seq[Column], Seq[_], ClientTypeSignature] {
+
+ override def getColumnSizeFromSchemaType(schema: Seq[Column]): Int = schema.length
+
+ override def getColumnType(schema: Seq[Column], ordinal: Int): ClientTypeSignature = {
+ schema(ordinal).getTypeSignature
+ }
+
+ override def isColumnNullAt(row: Seq[_], ordinal: Int): Boolean =
+ row(ordinal) == null
+
+ override def getColumnAs[T](row: Seq[_], ordinal: Int): T =
+ row(ordinal).asInstanceOf[T]
+
+ override def toTColumn(rows: Seq[Seq[_]], ordinal: Int, typ: ClientTypeSignature): TColumn = {
+ val nulls = new java.util.BitSet()
+ typ.getRawType match {
+ case BOOLEAN => toTTypeColumn(BOOLEAN_TYPE, rows, ordinal)
+ case TINYINT => toTTypeColumn(BINARY_TYPE, rows, ordinal)
+ case SMALLINT => toTTypeColumn(TINYINT_TYPE, rows, ordinal)
+ case INTEGER => toTTypeColumn(INT_TYPE, rows, ordinal)
+ case BIGINT => toTTypeColumn(BIGINT_TYPE, rows, ordinal)
+ case REAL => toTTypeColumn(FLOAT_TYPE, rows, ordinal)
+ case DOUBLE => toTTypeColumn(DOUBLE_TYPE, rows, ordinal)
+ case VARCHAR => toTTypeColumn(STRING_TYPE, rows, ordinal)
+ case VARBINARY => toTTypeColumn(ARRAY_TYPE, rows, ordinal)
+ case _ =>
+ val rowSize = rows.length
+ val values = new java.util.ArrayList[String](rowSize)
+ var i = 0
+ while (i < rowSize) {
+ val row = rows(i)
+ val isNull = isColumnNullAt(row, ordinal)
+ nulls.set(i, isNull)
+ val value = if (isNull) {
+ ""
+ } else {
+ toHiveString(row(ordinal), typ)
+ }
+ values.add(value)
+ i += 1
+ }
+ TColumn.stringVal(new TStringColumn(values, nulls))
+ }
+ }
+
+ override def toTColumnValue(ordinal: Int, row: Seq[_], types: Seq[Column]): TColumnValue = {
+ getColumnType(types, ordinal).getRawType match {
+ case BOOLEAN => toTTypeColumnVal(BOOLEAN_TYPE, row, ordinal)
+ case TINYINT => toTTypeColumnVal(BINARY_TYPE, row, ordinal)
+ case SMALLINT => toTTypeColumnVal(TINYINT_TYPE, row, ordinal)
+ case INTEGER => toTTypeColumnVal(INT_TYPE, row, ordinal)
+ case BIGINT => toTTypeColumnVal(BIGINT_TYPE, row, ordinal)
+ case REAL => toTTypeColumnVal(FLOAT_TYPE, row, ordinal)
+ case DOUBLE => toTTypeColumnVal(DOUBLE_TYPE, row, ordinal)
+ case VARCHAR => toTTypeColumnVal(STRING_TYPE, row, ordinal)
+ case _ =>
+ val tStrValue = new TStringValue
+ if (row(ordinal) != null) {
+ tStrValue.setValue(
+ toHiveString(row(ordinal), types(ordinal).getTypeSignature))
+ }
+ TColumnValue.stringVal(tStrValue)
+ }
+ }
+
+}
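
(For illustration only; not part of the patch.) A minimal sketch of driving the new Trino-side generator, mirroring the updated RowSetSuite below. The single INTEGER column and the rows are made-up placeholders; note that TINYINT and VARBINARY are deliberately routed through the shared helper's BINARY_TYPE and ARRAY_TYPE branches, which emit TByteColumn and TBinaryColumn respectively.

    import io.trino.client.{ClientTypeSignature, Column}
    import io.trino.client.ClientStandardTypes.INTEGER
    import org.apache.kyuubi.engine.trino.schema.TrinoTRowSetGenerator
    import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TProtocolVersion

    // A hypothetical one-column schema with a null in the last row.
    val schema: Seq[Column] = Seq(new Column("id", INTEGER, new ClientTypeSignature(INTEGER)))
    val rows: Seq[Seq[_]] = Seq(Seq(1), Seq(2), Seq(null))

    // One generator instance per result set keeps any session-aware state
    // isolated; toTRowSet picks row- or column-based encoding from the
    // negotiated protocol version.
    val tRowSet = new TrinoTRowSetGenerator()
      .toTRowSet(rows, schema, TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6)
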
diff --git a/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/schema/RowSetSuite.scala b/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/schema/RowSetSuite.scala
index acc55d5a3..461c453ec 100644
--- a/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/schema/RowSetSuite.scala
+++ b/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/schema/RowSetSuite.scala
@@ -126,7 +126,7 @@ class RowSetSuite extends KyuubiFunSuite {
def uuidSuffix(value: Int): String = if (value > 9) value.toString else s"f$value"
test("column based set") {
- val tRowSet = RowSet.toColumnBasedSet(rows, schema)
+ val tRowSet = new TrinoTRowSetGenerator().toColumnBasedSet(rows, schema)
assert(tRowSet.getColumns.size() === schema.size)
assert(tRowSet.getRowsSize === 0)
@@ -277,7 +277,7 @@ class RowSetSuite extends KyuubiFunSuite {
}
test("row based set") {
- val tRowSet = RowSet.toRowBasedSet(rows, schema)
+ val tRowSet = new TrinoTRowSetGenerator().toRowBasedSet(rows, schema)
assert(tRowSet.getColumnCount === 0)
assert(tRowSet.getRowsSize === rows.size)
val iter = tRowSet.getRowsIterator
@@ -333,7 +333,7 @@ class RowSetSuite extends KyuubiFunSuite {
test("to row set") {
TProtocolVersion.values().foreach { proto =>
- val set = RowSet.toTRowSet(rows, schema, proto)
+ val set = new TrinoTRowSetGenerator().toTRowSet(rows, schema, proto)
if (proto.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
assert(!set.isSetColumns, proto.toString)
assert(set.isSetRows, proto.toString)
diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/engine/schema/AbstractTRowSetGenerator.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/engine/schema/AbstractTRowSetGenerator.scala
new file mode 100644
index 000000000..365ed7298
--- /dev/null
+++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/engine/schema/AbstractTRowSetGenerator.scala
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.engine.schema
+
+import java.nio.ByteBuffer
+import java.util.{ArrayList => JArrayList, BitSet => JBitSet, List => JList}
+
+import scala.collection.JavaConverters._
+
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift._
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TTypeId._
+import org.apache.kyuubi.util.RowSetUtils.bitSetToBuffer
+
+trait AbstractTRowSetGenerator[SchemaT, RowT, ColumnT] {
+
+ protected def getColumnSizeFromSchemaType(schema: SchemaT): Int
+
+ protected def getColumnType(schema: SchemaT, ordinal: Int): ColumnT
+
+ protected def isColumnNullAt(row: RowT, ordinal: Int): Boolean
+
+ protected def getColumnAs[T](row: RowT, ordinal: Int): T
+
+ protected def toTColumn(rows: Seq[RowT], ordinal: Int, typ: ColumnT): TColumn
+
+ protected def toTColumnValue(ordinal: Int, row: RowT, types: SchemaT): TColumnValue
+
+ def toTRowSet(
+ rows: Seq[RowT],
+ schema: SchemaT,
+ protocolVersion: TProtocolVersion): TRowSet = {
+ if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
+ toRowBasedSet(rows, schema)
+ } else {
+ toColumnBasedSet(rows, schema)
+ }
+ }
+
+ def toRowBasedSet(rows: Seq[RowT], schema: SchemaT): TRowSet = {
+ val rowSize = rows.length
+ val tRows = new JArrayList[TRow](rowSize)
+ var i = 0
+ while (i < rowSize) {
+ val row = rows(i)
+ var j = 0
+ val columnSize = getColumnSizeFromSchemaType(schema)
+ val tColumnValues = new JArrayList[TColumnValue](columnSize)
+ while (j < columnSize) {
+ val columnValue = toTColumnValue(j, row, schema)
+ tColumnValues.add(columnValue)
+ j += 1
+ }
+ i += 1
+ val tRow = new TRow(tColumnValues)
+ tRows.add(tRow)
+ }
+ new TRowSet(0, tRows)
+ }
+
+ def toColumnBasedSet(rows: Seq[RowT], schema: SchemaT): TRowSet = {
+ val rowSize = rows.length
+ val tRowSet = new TRowSet(0, new JArrayList[TRow](rowSize))
+ var i = 0
+ val columnSize = getColumnSizeFromSchemaType(schema)
+ val tColumns = new JArrayList[TColumn](columnSize)
+ while (i < columnSize) {
+ val tColumn = toTColumn(rows, i, getColumnType(schema, i))
+ tColumns.add(tColumn)
+ i += 1
+ }
+ tRowSet.setColumns(tColumns)
+ tRowSet
+ }
+
+ protected def getOrSetAsNull[T](
+ rows: Seq[RowT],
+ ordinal: Int,
+ nulls: JBitSet,
+ defaultVal: T): JList[T] = {
+ val size = rows.length
+ val ret = new JArrayList[T](size)
+ var idx = 0
+ while (idx < size) {
+ val row = rows(idx)
+ val isNull = isColumnNullAt(row, ordinal)
+ if (isNull) {
+ nulls.set(idx, true)
+ ret.add(defaultVal)
+ } else {
+ ret.add(getColumnAs[T](row, ordinal))
+ }
+ idx += 1
+ }
+ ret
+ }
+
+ protected def toTTypeColumnVal(typeId: TTypeId, row: RowT, ordinal: Int): TColumnValue = {
+ def isNull = isColumnNullAt(row, ordinal)
+ typeId match {
+ case BOOLEAN_TYPE =>
+ val boolValue = new TBoolValue
+ if (!isNull) boolValue.setValue(getColumnAs[java.lang.Boolean](row, ordinal))
+ TColumnValue.boolVal(boolValue)
+
+ case BINARY_TYPE =>
+ val byteValue = new TByteValue
+ if (!isNull) byteValue.setValue(getColumnAs[java.lang.Byte](row, ordinal))
+ TColumnValue.byteVal(byteValue)
+
+ case TINYINT_TYPE =>
+ val tI16Value = new TI16Value
+ if (!isNull) tI16Value.setValue(getColumnAs[java.lang.Short](row, ordinal))
+ TColumnValue.i16Val(tI16Value)
+
+ case INT_TYPE =>
+ val tI32Value = new TI32Value
+ if (!isNull) tI32Value.setValue(getColumnAs[java.lang.Integer](row, ordinal))
+ TColumnValue.i32Val(tI32Value)
+
+ case BIGINT_TYPE =>
+ val tI64Value = new TI64Value
+ if (!isNull) tI64Value.setValue(getColumnAs[java.lang.Long](row, ordinal))
+ TColumnValue.i64Val(tI64Value)
+
+ case FLOAT_TYPE =>
+ val tDoubleValue = new TDoubleValue
+ if (!isNull) tDoubleValue.setValue(getColumnAs[java.lang.Float](row, ordinal).toDouble)
+ TColumnValue.doubleVal(tDoubleValue)
+
+ case DOUBLE_TYPE =>
+ val tDoubleValue = new TDoubleValue
+ if (!isNull) tDoubleValue.setValue(getColumnAs[java.lang.Double](row, ordinal))
+ TColumnValue.doubleVal(tDoubleValue)
+
+ case STRING_TYPE =>
+ val tStringValue = new TStringValue
+ if (!isNull) tStringValue.setValue(getColumnAs[String](row, ordinal))
+ TColumnValue.stringVal(tStringValue)
+
+ case otherType =>
+ throw new UnsupportedOperationException(s"unsupported type $otherType for toTTypeColumnVal")
+ }
+ }
+
+ protected def toTTypeColumn(typeId: TTypeId, rows: Seq[RowT], ordinal: Int): TColumn = {
+ val nulls = new JBitSet()
+ typeId match {
+ case BOOLEAN_TYPE =>
+ val values = getOrSetAsNull[java.lang.Boolean](rows, ordinal, nulls, true)
+ TColumn.boolVal(new TBoolColumn(values, nulls))
+
+ case BINARY_TYPE =>
+ val values = getOrSetAsNull[java.lang.Byte](rows, ordinal, nulls, 0.toByte)
+ TColumn.byteVal(new TByteColumn(values, nulls))
+
+ case SMALLINT_TYPE =>
+ val values = getOrSetAsNull[java.lang.Short](rows, ordinal, nulls, 0.toShort)
+ TColumn.i16Val(new TI16Column(values, nulls))
+
+ case TINYINT_TYPE =>
+ val values = getOrSetAsNull[java.lang.Short](rows, ordinal, nulls, 0.toShort)
+ TColumn.i16Val(new TI16Column(values, nulls))
+
+ case INT_TYPE =>
+ val values = getOrSetAsNull[java.lang.Integer](rows, ordinal, nulls, 0)
+ TColumn.i32Val(new TI32Column(values, nulls))
+
+ case BIGINT_TYPE =>
+ val values = getOrSetAsNull[java.lang.Long](rows, ordinal, nulls, 0L)
+ TColumn.i64Val(new TI64Column(values, nulls))
+
+ case FLOAT_TYPE =>
+ val values = getOrSetAsNull[java.lang.Float](rows, ordinal, nulls, 0.toFloat)
+ .asScala.map(n => java.lang.Double.valueOf(n.toString)).asJava
+ TColumn.doubleVal(new TDoubleColumn(values, nulls))
+
+ case DOUBLE_TYPE =>
+ val values = getOrSetAsNull[java.lang.Double](rows, ordinal, nulls, 0.toDouble)
+ TColumn.doubleVal(new TDoubleColumn(values, nulls))
+
+ case STRING_TYPE =>
+ val values = getOrSetAsNull[java.lang.String](rows, ordinal, nulls, "")
+ TColumn.stringVal(new TStringColumn(values, nulls))
+
+ case ARRAY_TYPE =>
+ val values = getOrSetAsNull[Array[Byte]](rows, ordinal, nulls, Array())
+ .asScala
+ .map(ByteBuffer.wrap)
+ .asJava
+ TColumn.binaryVal(new TBinaryColumn(values, nulls))
+
+ case otherType =>
+ throw new UnsupportedOperationException(s"unsupported type $otherType for toTTypeColumn")
+ }
+ }
+}
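
(For illustration only; not part of the patch.) A sketch of the contract a new engine implements, assuming a toy engine whose schema is a plain Seq[TTypeId] and whose rows are Seq[Any]; the row/column loops, null bitsets, and protocol-version dispatch all come from the trait, so only the six members below are engine-specific.

    import org.apache.kyuubi.engine.schema.AbstractTRowSetGenerator
    import org.apache.kyuubi.shaded.hive.service.rpc.thrift.{TColumn, TColumnValue, TTypeId}

    // Hypothetical engine generator: delegates the primitive TTypeIds to the
    // shared toTTypeColumn/toTTypeColumnVal helpers inherited from the trait.
    class SimpleTRowSetGenerator
      extends AbstractTRowSetGenerator[Seq[TTypeId], Seq[Any], TTypeId] {
      override def getColumnSizeFromSchemaType(schema: Seq[TTypeId]): Int = schema.length
      override def getColumnType(schema: Seq[TTypeId], ordinal: Int): TTypeId = schema(ordinal)
      override def isColumnNullAt(row: Seq[Any], ordinal: Int): Boolean = row(ordinal) == null
      override def getColumnAs[T](row: Seq[Any], ordinal: Int): T = row(ordinal).asInstanceOf[T]
      override def toTColumn(rows: Seq[Seq[Any]], ordinal: Int, typ: TTypeId): TColumn =
        toTTypeColumn(typ, rows, ordinal)
      override def toTColumnValue(ordinal: Int, row: Seq[Any], types: Seq[TTypeId]): TColumnValue =
        toTTypeColumnVal(getColumnType(types, ordinal), row, ordinal)
    }
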
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/plan/command/RunnableCommand.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/plan/command/RunnableCommand.scala
index 8f19d7f7a..cdfb515bd 100644
--- a/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/plan/command/RunnableCommand.scala
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/plan/command/RunnableCommand.scala
@@ -22,7 +22,7 @@ import org.apache.kyuubi.operation.FetchOrientation.{FETCH_FIRST, FETCH_NEXT, FE
import org.apache.kyuubi.session.KyuubiSession
import org.apache.kyuubi.shaded.hive.service.rpc.thrift.{TProtocolVersion, TRowSet}
import org.apache.kyuubi.sql.plan.KyuubiTreeNode
-import org.apache.kyuubi.sql.schema.{Row, RowSetHelper, Schema}
+import org.apache.kyuubi.sql.schema.{Row, Schema, ServerTRowSetGenerator}
trait RunnableCommand extends KyuubiTreeNode {
@@ -44,7 +44,7 @@ trait RunnableCommand extends KyuubiTreeNode {
case FETCH_FIRST => iter.fetchAbsolute(0)
}
val taken = iter.take(rowSetSize)
- val resultRowSet = RowSetHelper.toTRowSet(
+ val resultRowSet = new ServerTRowSetGenerator().toTRowSet(
taken.toList,
resultSchema,
protocolVersion)
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/schema/RowSetHelper.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/schema/RowSetHelper.scala
deleted file mode 100644
index 7a5fab082..000000000
--- a/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/schema/RowSetHelper.scala
+++ /dev/null
@@ -1,209 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.kyuubi.sql.schema
-
-import java.util
-
-import scala.collection.JavaConverters._
-
-import org.apache.kyuubi.shaded.hive.service.rpc.thrift._
-import org.apache.kyuubi.util.RowSetUtils._
-
-object RowSetHelper {
-
- def toTRowSet(
- rows: Seq[Row],
- schema: Schema,
- protocolVersion: TProtocolVersion): TRowSet = {
- if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
- toRowBasedSet(rows, schema)
- } else {
- toColumnBasedSet(rows, schema)
- }
- }
-
- def toRowBasedSet(rows: Seq[Row], schema: Schema): TRowSet = {
- var i = 0
- val rowSize = rows.length
- val tRows = new java.util.ArrayList[TRow](rowSize)
- while (i < rowSize) {
- val row = rows(i)
- val tRow = new TRow()
- var j = 0
- val columnSize = row.length
- while (j < columnSize) {
- val columnValue = toTColumnValue(j, row, schema)
- tRow.addToColVals(columnValue)
- j += 1
- }
- i += 1
- tRows.add(tRow)
- }
- new TRowSet(0, tRows)
- }
-
- private def toTColumnValue(
- ordinal: Int,
- row: Row,
- types: Schema): TColumnValue = {
- types(ordinal).dataType match {
- case TTypeId.BOOLEAN_TYPE =>
- val boolValue = new TBoolValue
- if (!row.isNullAt(ordinal)) boolValue.setValue(row.getBoolean(ordinal))
- TColumnValue.boolVal(boolValue)
-
- case TTypeId.BINARY_TYPE =>
- val byteValue = new TByteValue
- if (!row.isNullAt(ordinal)) byteValue.setValue(row.getByte(ordinal))
- TColumnValue.byteVal(byteValue)
-
- case TTypeId.TINYINT_TYPE =>
- val tI16Value = new TI16Value
- if (!row.isNullAt(ordinal)) tI16Value.setValue(row.getShort(ordinal))
- TColumnValue.i16Val(tI16Value)
-
- case TTypeId.INT_TYPE =>
- val tI32Value = new TI32Value
- if (!row.isNullAt(ordinal)) tI32Value.setValue(row.getInt(ordinal))
- TColumnValue.i32Val(tI32Value)
-
- case TTypeId.BIGINT_TYPE =>
- val tI64Value = new TI64Value
- if (!row.isNullAt(ordinal)) tI64Value.setValue(row.getLong(ordinal))
- TColumnValue.i64Val(tI64Value)
-
- case TTypeId.FLOAT_TYPE =>
- val tDoubleValue = new TDoubleValue
- if (!row.isNullAt(ordinal)) {
- val doubleValue = java.lang.Double.valueOf(row.getFloat(ordinal).toString)
- tDoubleValue.setValue(doubleValue)
- }
- TColumnValue.doubleVal(tDoubleValue)
-
- case TTypeId.DOUBLE_TYPE =>
- val tDoubleValue = new TDoubleValue
- if (!row.isNullAt(ordinal)) tDoubleValue.setValue(row.getDouble(ordinal))
- TColumnValue.doubleVal(tDoubleValue)
-
- case TTypeId.STRING_TYPE =>
- val tStringValue = new TStringValue
- if (!row.isNullAt(ordinal)) tStringValue.setValue(row.getString(ordinal))
- TColumnValue.stringVal(tStringValue)
-
- case _ =>
- val tStrValue = new TStringValue
- if (!row.isNullAt(ordinal)) {
- tStrValue.setValue((row.get(ordinal), types(ordinal).dataType).toString())
- }
- TColumnValue.stringVal(tStrValue)
- }
- }
-
- def toColumnBasedSet(rows: Seq[Row], schema: Schema): TRowSet = {
- val rowSize = rows.length
- val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](rowSize))
- var i = 0
- val columnSize = schema.length
- while (i < columnSize) {
- val field = schema(i)
- val tColumn = toTColumn(rows, i, field.dataType)
- tRowSet.addToColumns(tColumn)
- i += 1
- }
- tRowSet
- }
-
- private def toTColumn(rows: Seq[Row], ordinal: Int, typ: TTypeId): TColumn = {
- val nulls = new java.util.BitSet()
- typ match {
- case TTypeId.BOOLEAN_TYPE =>
- val values = getOrSetAsNull[java.lang.Boolean](rows, ordinal, nulls, true)
- TColumn.boolVal(new TBoolColumn(values, nulls))
-
- case TTypeId.BINARY_TYPE =>
- val values = getOrSetAsNull[java.lang.Byte](rows, ordinal, nulls, 0.toByte)
- TColumn.byteVal(new TByteColumn(values, nulls))
-
- case TTypeId.TINYINT_TYPE =>
- val values = getOrSetAsNull[java.lang.Short](rows, ordinal, nulls, 0.toShort)
- TColumn.i16Val(new TI16Column(values, nulls))
-
- case TTypeId.INT_TYPE =>
- val values = getOrSetAsNull[java.lang.Integer](rows, ordinal, nulls, 0)
- TColumn.i32Val(new TI32Column(values, nulls))
-
- case TTypeId.BIGINT_TYPE =>
- val values = getOrSetAsNull[java.lang.Long](rows, ordinal, nulls, 0L)
- TColumn.i64Val(new TI64Column(values, nulls))
-
- case TTypeId.FLOAT_TYPE =>
- val values = getOrSetAsNull[java.lang.Float](rows, ordinal, nulls, 0.toFloat)
- .asScala.map(n => java.lang.Double.valueOf(n.toString)).asJava
- TColumn.doubleVal(new TDoubleColumn(values, nulls))
-
- case TTypeId.DOUBLE_TYPE =>
- val values = getOrSetAsNull[java.lang.Double](rows, ordinal, nulls, 0.toDouble)
- TColumn.doubleVal(new TDoubleColumn(values, nulls))
-
- case TTypeId.STRING_TYPE =>
- val values: util.List[String] = getOrSetAsNull[java.lang.String](rows, ordinal, nulls, "")
- TColumn.stringVal(new TStringColumn(values, nulls))
-
- case _ =>
- var i = 0
- val rowSize = rows.length
- val values = new java.util.ArrayList[String](rowSize)
- while (i < rowSize) {
- val row = rows(i)
- nulls.set(i, row.isNullAt(ordinal))
- val value =
- if (row.isNullAt(ordinal)) {
- ""
- } else {
- (row.get(ordinal), typ).toString()
- }
- values.add(value)
- i += 1
- }
- TColumn.stringVal(new TStringColumn(values, nulls))
- }
- }
-
- private def getOrSetAsNull[T](
- rows: Seq[Row],
- ordinal: Int,
- nulls: java.util.BitSet,
- defaultVal: T): java.util.List[T] = {
- val size = rows.length
- val ret = new java.util.ArrayList[T](size)
- var idx = 0
- while (idx < size) {
- val row = rows(idx)
- val isNull = row.isNullAt(ordinal)
- if (isNull) {
- nulls.set(idx, true)
- ret.add(idx, defaultVal)
- } else {
- ret.add(idx, row.getAs[T](ordinal))
- }
- idx += 1
- }
- ret
- }
-
-}
diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/schema/ServerTRowSetGenerator.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/schema/ServerTRowSetGenerator.scala
new file mode 100644
index 000000000..e1a9d55a6
--- /dev/null
+++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/sql/schema/ServerTRowSetGenerator.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql.schema
+
+import org.apache.kyuubi.engine.schema.AbstractTRowSetGenerator
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift._
+import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TTypeId._
+import org.apache.kyuubi.util.RowSetUtils._
+
+class ServerTRowSetGenerator
+ extends AbstractTRowSetGenerator[Schema, Row, TTypeId] {
+
+ override def getColumnSizeFromSchemaType(schema: Schema): Int = schema.length
+
+ override def getColumnType(schema: Schema, ordinal: Int): TTypeId = schema(ordinal).dataType
+
+ override def isColumnNullAt(row: Row, ordinal: Int): Boolean = row.isNullAt(ordinal)
+
+ override def getColumnAs[T](row: Row, ordinal: Int): T = row.getAs[T](ordinal)
+
+ override def toTColumn(rows: Seq[Row], ordinal: Int, typ: TTypeId): TColumn = {
+ val nulls = new java.util.BitSet()
+ typ match {
+ case t @ (BOOLEAN_TYPE | BINARY_TYPE | TINYINT_TYPE | INT_TYPE |
+ BIGINT_TYPE | FLOAT_TYPE | DOUBLE_TYPE | STRING_TYPE) =>
+ toTTypeColumn(t, rows, ordinal)
+
+ case _ =>
+ var i = 0
+ val rowSize = rows.length
+ val values = new java.util.ArrayList[String](rowSize)
+ while (i < rowSize) {
+ val row = rows(i)
+ val isNull = isColumnNullAt(row, ordinal)
+ nulls.set(i, isNull)
+ val value = if (isNull) {
+ ""
+ } else {
+ (row.get(ordinal), typ).toString()
+ }
+ values.add(value)
+ i += 1
+ }
+ TColumn.stringVal(new TStringColumn(values, nulls))
+ }
+ }
+
+ override def toTColumnValue(ordinal: Int, row: Row, types: Schema): TColumnValue = {
+ getColumnType(types, ordinal) match {
+ case t @ (BOOLEAN_TYPE | BINARY_TYPE | TINYINT_TYPE | INT_TYPE |
+ BIGINT_TYPE | FLOAT_TYPE | DOUBLE_TYPE | STRING_TYPE) =>
+ toTTypeColumnVal(t, row, ordinal)
+
+ case _ =>
+ val tStrValue = new TStringValue
+ if (!isColumnNullAt(row, ordinal)) {
+ tStrValue.setValue((row.get(ordinal), types(ordinal).dataType).toString())
+ }
+ TColumnValue.stringVal(tStrValue)
+ }
+ }
+
+}
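
(For illustration only; not part of the patch.) The server-side call site then reduces to the one-liner below, as in the RunnableCommand hunk above; `taken`, `resultSchema`, and `protocolVersion` stand in for the values already in scope there, and a fresh generator per fetch replaces the former RowSetHelper singleton.

    import org.apache.kyuubi.shaded.hive.service.rpc.thrift.TRowSet

    // Assumes taken: Iterator[Row], resultSchema: Schema, and
    // protocolVersion: TProtocolVersion are in scope, as in RunnableCommand.
    val resultRowSet: TRowSet =
      new ServerTRowSetGenerator().toTRowSet(taken.toList, resultSchema, protocolVersion)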