This is an automated email from the ASF dual-hosted git repository.
diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git
The following commit(s) were added to refs/heads/master by this push:
new 4849ee7 [feature](step:one)Add arrow format (#161)
4849ee7 is described below
commit 4849ee72968d2da958bde621e00d252e3049f500
Author: wuwenchi <[email protected]>
AuthorDate: Fri Dec 8 21:53:27 2023 +0800
[feature](step:one)Add arrow format (#161)
---
README.md | 23 +
spark-doris-connector/pom.xml | 2 +-
.../org/apache/doris/spark/load/DataFormat.java | 24 ++
.../apache/doris/spark/load/DorisStreamLoad.java | 29 +-
.../org/apache/doris/spark/load/RecordBatch.java | 37 +-
.../doris/spark/load/RecordBatchInputStream.java | 78 +++-
.../java/org/apache/doris/spark/util/DataUtil.java | 2 -
.../testcase/TestStreamLoadForArrowType.scala | 461 +++++++++++++++++++++
.../spark/sql/doris/spark/ArrowSchemaUtils.scala | 52 +++
.../doris/spark/sql/TestConnectorWriteDoris.scala | 1 -
10 files changed, 665 insertions(+), 44 deletions(-)
diff --git a/README.md b/README.md
index 70c0a77..5661318 100644
--- a/README.md
+++ b/README.md
@@ -124,6 +124,29 @@ dorisSparkDF = spark.read.format("doris")
dorisSparkDF.show(5)
```
+## Type conversion for writing to Doris using Arrow
+| Doris Type | Spark Type |
+|---|---|
+| BOOLEAN | BooleanType |
+| TINYINT | ByteType |
+| SMALLINT | ShortType |
+| INT | IntegerType |
+| BIGINT | LongType |
+| LARGEINT | StringType |
+| FLOAT | FloatType |
+| DOUBLE | DoubleType |
+| DECIMAL(M,D) | DecimalType(M,D) |
+| DATE | DateType |
+| DATETIME | TimestampType |
+| CHAR(L) | StringType |
+| VARCHAR(L) | StringType |
+| STRING | StringType |
+| ARRAY | ARRAY |
+| MAP | MAP |
+| STRUCT | STRUCT |
+
+
+
## Report issues or submit pull request
If you find any bugs, feel free to file a [GitHub
issue](https://github.com/apache/doris/issues) or fix it by submitting a [pull
request](https://github.com/apache/doris/pulls).
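For reference, a minimal write sketch that exercises the new Arrow path, using the options from the test case added in this commit (the fenodes, credentials, and table identifier below are placeholders):

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch: write a DataFrame to Doris using the Arrow stream load format.
// Connection details and the target table are placeholders.
val spark = SparkSession.builder().master("local[1]").getOrCreate()
val df = spark.range(10).toDF("id")

df.write
  .format("doris")
  .option("doris.fenodes", "127.0.0.1:8030")
  .option("user", "root")
  .option("password", "")
  .option("doris.table.identifier", "example_db.example_table")
  .option("doris.sink.properties.format", "arrow") // default is csv
  .save()
```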
diff --git a/spark-doris-connector/pom.xml b/spark-doris-connector/pom.xml
index 518a3e2..89e1716 100644
--- a/spark-doris-connector/pom.xml
+++ b/spark-doris-connector/pom.xml
@@ -72,7 +72,7 @@
<spark.major.version>3.1</spark.major.version>
<scala.version>2.12</scala.version>
<libthrift.version>0.16.0</libthrift.version>
- <arrow.version>5.0.0</arrow.version>
+ <arrow.version>13.0.0</arrow.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.scm.id>github</project.scm.id>
<netty.version>4.1.77.Final</netty.version>
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DataFormat.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DataFormat.java
new file mode 100644
index 0000000..e3e3b0b
--- /dev/null
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DataFormat.java
@@ -0,0 +1,24 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.load;
+
+public enum DataFormat {
+ CSV,
+ JSON,
+ ARROW
+}
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
index f473636..7758c6f 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
@@ -98,7 +98,7 @@ public class DorisStreamLoad implements Serializable {
private static final long cacheExpireTimeout = 4 * 60;
private final LoadingCache<String, List<BackendV2.BackendRowV2>> cache;
private final String fenodes;
- private final String fileType;
+ private final DataFormat dataFormat;
private String FIELD_DELIMITER;
private final String LINE_DELIMITER;
private boolean streamingPassthrough = false;
@@ -119,16 +119,19 @@ public class DorisStreamLoad implements Serializable {
         this.maxFilterRatio = settings.getProperty(ConfigurationOptions.DORIS_MAX_FILTER_RATIO);
         this.streamLoadProp = getStreamLoadProp(settings);
         cache = CacheBuilder.newBuilder().expireAfterWrite(cacheExpireTimeout, TimeUnit.MINUTES).build(new BackendCacheLoader(settings));
-        fileType = streamLoadProp.getOrDefault("format", "csv");
-        if ("csv".equals(fileType)) {
-            FIELD_DELIMITER = escapeString(streamLoadProp.getOrDefault("column_separator", "\t"));
-            this.addDoubleQuotes = Boolean.parseBoolean(streamLoadProp.getOrDefault("add_double_quotes", "false"));
-            if (addDoubleQuotes) {
-                LOG.info("set add_double_quotes for csv mode, add trim_double_quotes to true for prop.");
-                streamLoadProp.put("trim_double_quotes", "true");
-            }
-        } else if ("json".equalsIgnoreCase(fileType)) {
-            streamLoadProp.put("read_json_by_line", "true");
+        dataFormat = DataFormat.valueOf(streamLoadProp.getOrDefault("format", "csv").toUpperCase());
+        switch (dataFormat) {
+            case CSV:
+                FIELD_DELIMITER = escapeString(streamLoadProp.getOrDefault("column_separator", "\t"));
+                this.addDoubleQuotes = Boolean.parseBoolean(streamLoadProp.getOrDefault("add_double_quotes", "false"));
+                if (addDoubleQuotes) {
+                    LOG.info("set add_double_quotes for csv mode, add trim_double_quotes to true for prop.");
+                    streamLoadProp.put("trim_double_quotes", "true");
+                }
+                break;
+            case JSON:
+                streamLoadProp.put("read_json_by_line", "true");
+                break;
         }
         LINE_DELIMITER = escapeString(streamLoadProp.getOrDefault("line_delimiter", "\n"));
         this.streamingPassthrough = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH,
@@ -218,7 +221,7 @@ public class DorisStreamLoad implements Serializable {
this.loadUrlStr = loadUrlStr;
HttpPut httpPut = getHttpPut(label, loadUrlStr, enable2PC, schema);
         RecordBatchInputStream recodeBatchInputStream = new RecordBatchInputStream(RecordBatch.newBuilder(rows)
- .format(fileType)
+ .format(dataFormat)
.sep(FIELD_DELIMITER)
.delim(LINE_DELIMITER)
.schema(schema)
@@ -492,7 +495,7 @@ public class DorisStreamLoad implements Serializable {
*/
private void handleStreamPassThrough() {
- if ("json".equalsIgnoreCase(fileType)) {
+ if (dataFormat.equals(DataFormat.JSON)) {
LOG.info("handle stream pass through, force set read_json_by_line
is true for json format");
streamLoadProp.put("read_json_by_line", "true");
streamLoadProp.remove("strip_outer_array");
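With this change, the stream load `format` property (set through `doris.sink.properties.format`) is parsed directly into the `DataFormat` enum. A quick illustrative sketch of that parsing; the behavior for unknown values follows standard Java enum semantics:

```scala
import org.apache.doris.spark.load.DataFormat

// "csv", "json" and "arrow" (in any case) resolve to the enum constants.
val format = DataFormat.valueOf("arrow".toUpperCase) // DataFormat.ARROW

// Any other value, e.g. "parquet", makes valueOf throw
// java.lang.IllegalArgumentException, so unsupported formats fail fast.
```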
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
index e471d5b..b514586 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
@@ -17,7 +17,11 @@
package org.apache.doris.spark.load;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.doris.spark.ArrowSchemaUtils;
import org.apache.spark.sql.types.StructType;
import java.nio.charset.Charset;
@@ -39,7 +43,7 @@ public class RecordBatch {
/**
* stream load format
*/
- private final String format;
+ private final DataFormat format;
/**
* column separator, only used when the format is csv
@@ -58,7 +62,11 @@ public class RecordBatch {
private final boolean addDoubleQuotes;
-    private RecordBatch(Iterator<InternalRow> iterator, String format, String sep, byte[] delim,
+ private VectorSchemaRoot arrowRoot = null;
+
+ private int arrowBatchSize = 1000;
+
+    private RecordBatch(Iterator<InternalRow> iterator, DataFormat format, String sep, byte[] delim,
StructType schema, boolean addDoubleQuotes) {
this.iterator = iterator;
this.format = format;
@@ -66,13 +74,17 @@ public class RecordBatch {
this.delim = delim;
this.schema = schema;
this.addDoubleQuotes = addDoubleQuotes;
+ if (format.equals(DataFormat.ARROW)) {
+ Schema arrowSchema = ArrowSchemaUtils.toArrowSchema(schema, "UTC");
+            this.arrowRoot = VectorSchemaRoot.create(arrowSchema, new RootAllocator(Integer.MAX_VALUE));
+ }
}
public Iterator<InternalRow> getIterator() {
return iterator;
}
- public String getFormat() {
+ public DataFormat getFormat() {
return format;
}
@@ -84,13 +96,28 @@ public class RecordBatch {
return delim;
}
+ public VectorSchemaRoot getVectorSchemaRoot() {
+ return arrowRoot;
+ }
+
public StructType getSchema() {
return schema;
}
+ public int getArrowBatchSize() {
+ return arrowBatchSize;
+ }
+
public boolean getAddDoubleQuotes(){
return addDoubleQuotes;
}
+
+ public void clearBatch() {
+ if (format.equals(DataFormat.ARROW)) {
+ this.arrowRoot.clear();
+ }
+ }
+
public static Builder newBuilder(Iterator<InternalRow> iterator) {
return new Builder(iterator);
}
@@ -102,7 +129,7 @@ public class RecordBatch {
private final Iterator<InternalRow> iterator;
- private String format;
+ private DataFormat format;
private String sep;
@@ -116,7 +143,7 @@ public class RecordBatch {
this.iterator = iterator;
}
- public Builder format(String format) {
+ public Builder format(DataFormat format) {
this.format = format;
return this;
}
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
index 544e683..c43f685 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
@@ -17,6 +17,8 @@
package org.apache.doris.spark.load;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.doris.spark.exception.DorisException;
import org.apache.doris.spark.exception.IllegalArgumentException;
import org.apache.doris.spark.exception.ShouldNeverHappenException;
@@ -24,9 +26,11 @@ import org.apache.doris.spark.util.DataUtil;
import com.fasterxml.jackson.core.JsonProcessingException;
import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.execution.arrow.ArrowWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
@@ -54,12 +58,17 @@ public class RecordBatchInputStream extends InputStream {
* record buffer
*/
- private ByteBuffer lineBuf = ByteBuffer.allocate(0);;
+ private ByteBuffer lineBuf = ByteBuffer.allocate(0);
private ByteBuffer delimBuf = ByteBuffer.allocate(0);
private final byte[] delim;
+ /**
+     * number of records read into the current batch
+ */
+ private int readCount = 0;
+
/**
* streaming mode pass through data without process
*/
@@ -73,18 +82,13 @@ public class RecordBatchInputStream extends InputStream {
@Override
public int read() throws IOException {
- try {
- if (lineBuf.remaining() == 0 && endOfBatch()) {
- return -1;
- }
-
- if (delimBuf != null && delimBuf.remaining() > 0) {
- return delimBuf.get() & 0xff;
- }
- } catch (DorisException e) {
- throw new IOException(e);
+ byte[] bytes = new byte[1];
+ int read = read(bytes, 0, 1);
+ if (read < 0) {
+ return -1;
+ } else {
+            return bytes[0] & 0xFF;
}
- return lineBuf.get() & 0xFF;
}
@Override
@@ -102,6 +106,7 @@ public class RecordBatchInputStream extends InputStream {
} catch (DorisException e) {
throw new IOException(e);
}
+
int bytesRead = Math.min(len, lineBuf.remaining());
lineBuf.get(b, off, bytesRead);
return bytesRead;
@@ -121,6 +126,8 @@ public class RecordBatchInputStream extends InputStream {
readNext(iterator);
return false;
}
+
+ recordBatch.clearBatch();
delimBuf = null;
return true;
}
@@ -135,14 +142,41 @@ public class RecordBatchInputStream extends InputStream {
if (!iterator.hasNext()) {
throw new ShouldNeverHappenException();
}
- byte[] rowBytes = rowToByte(iterator.next());
- if (isFirst) {
+
+ if (recordBatch.getFormat().equals(DataFormat.ARROW)) {
+            ArrowWriter arrowWriter = ArrowWriter.create(recordBatch.getVectorSchemaRoot());
+            while (iterator.hasNext() && readCount < recordBatch.getArrowBatchSize()) {
+ arrowWriter.write(iterator.next());
+ readCount++;
+ }
+ arrowWriter.finish();
+
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ ArrowStreamWriter writer = new ArrowStreamWriter(
+ recordBatch.getVectorSchemaRoot(),
+ new DictionaryProvider.MapDictionaryProvider(),
+ out);
+
+ try {
+ writer.writeBatch();
+ writer.end();
+ } catch (IOException e) {
+ throw new DorisException(e);
+ }
+
delimBuf = null;
- lineBuf = ByteBuffer.wrap(rowBytes);
- isFirst = false;
+ lineBuf = ByteBuffer.wrap(out.toByteArray());
+ readCount = 0;
} else {
- delimBuf = ByteBuffer.wrap(delim);
- lineBuf = ByteBuffer.wrap(rowBytes);
+ byte[] rowBytes = rowToByte(iterator.next());
+ if (isFirst) {
+ delimBuf = null;
+ lineBuf = ByteBuffer.wrap(rowBytes);
+ isFirst = false;
+ } else {
+ delimBuf = ByteBuffer.wrap(delim);
+ lineBuf = ByteBuffer.wrap(rowBytes);
+ }
}
}
@@ -162,11 +196,11 @@ public class RecordBatchInputStream extends InputStream {
return bytes;
}
- switch (recordBatch.getFormat().toLowerCase()) {
- case "csv":
+ switch (recordBatch.getFormat()) {
+ case CSV:
                bytes = DataUtil.rowToCsvBytes(row, recordBatch.getSchema(), recordBatch.getSep(), recordBatch.getAddDoubleQuotes());
break;
- case "json":
+ case JSON:
try {
                    bytes = DataUtil.rowToJsonBytes(row, recordBatch.getSchema());
} catch (JsonProcessingException e) {
@@ -174,7 +208,7 @@ public class RecordBatchInputStream extends InputStream {
}
break;
default:
-                throw new IllegalArgumentException("format", recordBatch.getFormat());
+                throw new IllegalArgumentException("Unsupported format: ", recordBatch.getFormat().toString());
}
return bytes;
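The Arrow branch in readNext above buffers up to arrowBatchSize rows into the shared VectorSchemaRoot via Spark's ArrowWriter, then serializes the batch as an Arrow IPC stream. A minimal standalone sketch of that serialization step, assuming the root has already been populated:

```scala
import java.io.ByteArrayOutputStream

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.arrow.vector.ipc.ArrowStreamWriter

// Serialize the current contents of a VectorSchemaRoot into Arrow IPC stream bytes,
// i.e. the payload RecordBatchInputStream exposes when format=arrow.
def arrowBatchToBytes(root: VectorSchemaRoot): Array[Byte] = {
  val out = new ByteArrayOutputStream()
  val writer = new ArrowStreamWriter(root, new DictionaryProvider.MapDictionaryProvider(), out)
  writer.writeBatch() // writes the schema message (on first call) plus one record batch
  writer.end()        // writes the end-of-stream marker
  out.toByteArray
}
```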
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java
index 763d72b..f7218c3 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java
@@ -31,10 +31,8 @@ import org.apache.spark.sql.types.StructType;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
-import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
-import java.util.stream.Stream;
public class DataUtil {
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/testcase/TestStreamLoadForArrowType.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/testcase/TestStreamLoadForArrowType.scala
new file mode 100644
index 0000000..f1e1977
--- /dev/null
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/testcase/TestStreamLoadForArrowType.scala
@@ -0,0 +1,461 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.testcase
+
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{Row, SparkSession}
+
+import java.sql.{Date, Timestamp}
+import scala.collection.mutable.ListBuffer
+
+// This object is used to test writing from Spark into Doris with the Arrow format,
+// and it will be executed in Doris's pipeline.
+object TestStreamLoadForArrowType {
+  val spark: SparkSession = SparkSession.builder().master("local[1]").getOrCreate()
+ var dorisFeNodes = "127.0.0.1:8030"
+ var dorisUser = "root"
+ val dorisPwd = ""
+ var databaseName = ""
+
+ def main(args: Array[String]): Unit = {
+
+ dorisFeNodes = args(0)
+ dorisUser = args(1)
+ databaseName = args(2)
+
+ testDataframeWritePrimitiveType()
+ testDataframeWriteArrayTypes()
+ testDataframeWriteMapType()
+ testDataframeWriteStructType()
+
+ spark.stop()
+ }
+
+
+ def testDataframeWritePrimitiveType(): Unit = {
+ /*
+
+ CREATE TABLE `spark_connector_primitive` (
+ `id` int(11) NOT NULL,
+ `c_bool` boolean NULL,
+ `c_tinyint` tinyint NULL,
+ `c_smallint` smallint NULL,
+ `c_int` int NULL,
+ `c_bigint` bigint NULL,
+ `c_largeint` largeint NULL,
+ `c_float` float NULL,
+ `c_double` double NULL,
+ `c_decimal` DECIMAL(10, 5) NULL,
+ `c_date` date NULL,
+ `c_datetime` datetime(6) NULL,
+ `c_char` char(10) NULL,
+ `c_varchar` varchar(10) NULL,
+ `c_string` string NULL
+ ) ENGINE=OLAP
+ DUPLICATE KEY(`id`)
+ COMMENT 'OLAP'
+ DISTRIBUTED BY HASH(`id`) BUCKETS 1
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1"
+ );
+
+ +------------+----------------+------+-------+---------+-------+
+ | Field | Type | Null | Key | Default | Extra |
+ +------------+----------------+------+-------+---------+-------+
+ | id | INT | No | true | NULL | |
+ | c_bool | BOOLEAN | Yes | false | NULL | NONE |
+ | c_tinyint | TINYINT | Yes | false | NULL | NONE |
+ | c_smallint | SMALLINT | Yes | false | NULL | NONE |
+ | c_int | INT | Yes | false | NULL | NONE |
+ | c_bigint | BIGINT | Yes | false | NULL | NONE |
+ | c_largeint | LARGEINT | Yes | false | NULL | NONE |
+ | c_float | FLOAT | Yes | false | NULL | NONE |
+ | c_double | DOUBLE | Yes | false | NULL | NONE |
+ | c_decimal | DECIMAL(10, 5) | Yes | false | NULL | NONE |
+ | c_date | DATE | Yes | false | NULL | NONE |
+ | c_datetime | DATETIME(6) | Yes | false | NULL | NONE |
+ | c_char | CHAR(10) | Yes | false | NULL | NONE |
+ | c_varchar | VARCHAR(10) | Yes | false | NULL | NONE |
+ | c_string | TEXT | Yes | false | NULL | NONE |
+ +------------+----------------+------+-------+---------+-------+
+
+ */
+
+ val schema = new StructType()
+ .add("id", IntegerType)
+ .add("c_bool", BooleanType)
+ .add("c_tinyint", ByteType)
+ .add("c_smallint", ShortType)
+ .add("c_int", IntegerType)
+ .add("c_bigint", LongType)
+ .add("c_largeint", StringType)
+ .add("c_float", FloatType)
+ .add("c_double", DoubleType)
+ .add("c_decimal", DecimalType.apply(10, 5))
+ .add("c_date", DateType)
+ .add("c_datetime", TimestampType)
+ .add("c_char", StringType)
+ .add("c_varchar", StringType)
+ .add("c_string", StringType)
+
+ val row = Row(
+ 1,
+ true,
+ 1.toByte,
+ 2.toShort,
+ 3,
+ 4.toLong,
+ "123456789",
+ 6.6.floatValue(),
+ 7.7.doubleValue(),
+ Decimal.apply(3.12),
+ Date.valueOf("2023-09-08"),
+ Timestamp.valueOf("2023-09-08 17:12:34.123456"),
+ "char",
+ "varchar",
+ "string"
+ )
+
+
+ val inputList = ListBuffer[Row]()
+ for (a <- 0 until 7) {
+ inputList.append(row)
+ }
+
+ val rdd = spark.sparkContext.parallelize(inputList, 1)
+ val df = spark.createDataFrame(rdd, schema).toDF()
+
+ df.write
+ .format("doris")
+ .option("doris.fenodes", dorisFeNodes)
+ .option("user", dorisUser)
+ .option("password", dorisPwd)
+      .option("doris.table.identifier", s"$databaseName.spark_connector_primitive")
+ .option("doris.sink.batch.size", 3)
+ .option("doris.sink.properties.format", "arrow")
+ .option("doris.sink.max-retries", 0)
+ .save()
+ }
+
+ def testDataframeWriteArrayTypes(): Unit = {
+ /*
+
+ CREATE TABLE `spark_connector_array` (
+ `id` int(11) NOT NULL,
+ `c_array_boolean` ARRAY<boolean> NULL,
+ `c_array_tinyint` ARRAY<tinyint> NULL,
+ `c_array_smallint` ARRAY<smallint> NULL,
+ `c_array_int` ARRAY<int> NULL,
+ `c_array_bigint` ARRAY<bigint> NULL,
+ `c_array_largeint` ARRAY<largeint> NULL,
+ `c_array_float` ARRAY<float> NULL,
+ `c_array_double` ARRAY<double> NULL,
+ `c_array_decimal` ARRAY<DECIMAL(10, 5)> NULL,
+ `c_array_date` ARRAY<date> NULL,
+ `c_array_datetime` ARRAY<datetime(6)> NULL,
+ `c_array_char` ARRAY<char(10)> NULL,
+ `c_array_varchar` ARRAY<varchar(10)> NULL,
+ `c_array_string` ARRAY<string> NULL
+ ) ENGINE=OLAP
+ DUPLICATE KEY(`id`)
+ COMMENT 'OLAP'
+ DISTRIBUTED BY HASH(`id`) BUCKETS 1
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1"
+ );
+
+    +------------------+-------------------------+------+-------+---------+-------+
+    | Field            | Type                    | Null | Key   | Default | Extra |
+    +------------------+-------------------------+------+-------+---------+-------+
+    | id               | INT                     | No   | true  | NULL    |       |
+    | c_array_boolean  | ARRAY<BOOLEAN>          | Yes  | false | []      | NONE  |
+    | c_array_tinyint  | ARRAY<TINYINT>          | Yes  | false | []      | NONE  |
+    | c_array_smallint | ARRAY<SMALLINT>         | Yes  | false | []      | NONE  |
+    | c_array_int      | ARRAY<INT>              | Yes  | false | []      | NONE  |
+    | c_array_bigint   | ARRAY<BIGINT>           | Yes  | false | []      | NONE  |
+    | c_array_largeint | ARRAY<LARGEINT>         | Yes  | false | []      | NONE  |
+    | c_array_float    | ARRAY<FLOAT>            | Yes  | false | []      | NONE  |
+    | c_array_double   | ARRAY<DOUBLE>           | Yes  | false | []      | NONE  |
+    | c_array_decimal  | ARRAY<DECIMALV3(10, 5)> | Yes  | false | []      | NONE  |
+    | c_array_date     | ARRAY<DATEV2>           | Yes  | false | []      | NONE  |
+    | c_array_datetime | ARRAY<DATETIMEV2(6)>    | Yes  | false | []      | NONE  |
+    | c_array_char     | ARRAY<CHAR(10)>         | Yes  | false | []      | NONE  |
+    | c_array_varchar  | ARRAY<VARCHAR(10)>      | Yes  | false | []      | NONE  |
+    | c_array_string   | ARRAY<TEXT>             | Yes  | false | []      | NONE  |
+    +------------------+-------------------------+------+-------+---------+-------+
+
+ */
+
+ val schema = new StructType()
+ .add("id", IntegerType)
+ .add("c_array_boolean", ArrayType(BooleanType))
+ .add("c_array_tinyint", ArrayType(ByteType))
+ .add("c_array_smallint", ArrayType(ShortType))
+ .add("c_array_int", ArrayType(IntegerType))
+ .add("c_array_bigint", ArrayType(LongType))
+ .add("c_array_largeint", ArrayType(StringType))
+ .add("c_array_float", ArrayType(FloatType))
+ .add("c_array_double", ArrayType(DoubleType))
+ .add("c_array_decimal", ArrayType(DecimalType.apply(10, 5)))
+ .add("c_array_date", ArrayType(DateType))
+ .add("c_array_datetime", ArrayType(TimestampType))
+ .add("c_array_char", ArrayType(StringType))
+ .add("c_array_varchar", ArrayType(StringType))
+ .add("c_array_string", ArrayType(StringType))
+
+ val row = Row(
+ 1,
+ Array(true, false, false, true, true),
+ Array(1.toByte, 2.toByte, 3.toByte),
+ Array(2.toShort, 12.toShort, 32.toShort),
+ Array(3, 4, 5, 6),
+ Array(4.toLong, 5.toLong, 6.toLong),
+ Array("123456789", "987654321", "123789456"),
+ Array(6.6.floatValue(), 6.7.floatValue(), 7.8.floatValue()),
+ Array(7.7.doubleValue(), 8.8.doubleValue(), 8.9.floatValue()),
+ Array(Decimal.apply(3.12), Decimal.apply(1.12345)),
+ Array(Date.valueOf("2023-09-08"), Date.valueOf("2027-10-28")),
+      Array(Timestamp.valueOf("2023-09-08 17:12:34.123456"), Timestamp.valueOf("2024-09-08 18:12:34.123456")),
+ Array("char", "char2"),
+ Array("varchar", "varchar2"),
+ Array("string", "string2")
+ )
+
+
+ val inputList = ListBuffer[Row]()
+ for (a <- 0 until 7) {
+ inputList.append(row)
+ }
+
+ val rdd = spark.sparkContext.parallelize(inputList, 1)
+ val df = spark.createDataFrame(rdd, schema).toDF()
+
+ df.write
+ .format("doris")
+ .option("doris.fenodes", dorisFeNodes)
+ .option("user", dorisUser)
+ .option("password", dorisPwd)
+ .option("doris.table.identifier", s"$databaseName.spark_connector_array")
+ .option("doris.sink.batch.size", 30)
+ .option("doris.sink.properties.format", "arrow")
+ .option("doris.sink.max-retries", 0)
+ .save()
+ }
+
+ def testDataframeWriteMapType(): Unit = {
+ /*
+
+ CREATE TABLE `spark_connector_map` (
+ `id` int(11) NOT NULL,
+ `c_map_bool` Map<boolean,boolean> NULL,
+ `c_map_tinyint` Map<tinyint,tinyint> NULL,
+ `c_map_smallint` Map<smallint,smallint> NULL,
+ `c_map_int` Map<int,int> NULL,
+ `c_map_bigint` Map<bigint,bigint> NULL,
+ `c_map_largeint` Map<largeint,largeint> NULL,
+ `c_map_float` Map<float,float> NULL,
+ `c_map_double` Map<double,double> NULL,
+ `c_map_decimal` Map<DECIMAL(10, 5),DECIMAL(10, 5)> NULL,
+ `c_map_date` Map<date,date> NULL,
+ `c_map_datetime` Map<datetime(6),datetime(6)> NULL,
+ `c_map_char` Map<char(10),char(10)> NULL,
+ `c_map_varchar` Map<varchar(10),varchar(10)> NULL,
+ `c_map_string` Map<string,string> NULL
+ ) ENGINE=OLAP
+ DUPLICATE KEY(`id`)
+ COMMENT 'OLAP'
+ DISTRIBUTED BY HASH(`id`) BUCKETS 1
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1"
+ );
+
+    +----------------+----------------------------------------+------+-------+---------+-------+
+    | Field          | Type                                   | Null | Key   | Default | Extra |
+    +----------------+----------------------------------------+------+-------+---------+-------+
+    | id             | INT                                    | No   | true  | NULL    |       |
+    | c_map_bool     | MAP<BOOLEAN,BOOLEAN>                   | Yes  | false | NULL    | NONE  |
+    | c_map_tinyint  | MAP<TINYINT,TINYINT>                   | Yes  | false | NULL    | NONE  |
+    | c_map_smallint | MAP<SMALLINT,SMALLINT>                 | Yes  | false | NULL    | NONE  |
+    | c_map_int      | MAP<INT,INT>                           | Yes  | false | NULL    | NONE  |
+    | c_map_bigint   | MAP<BIGINT,BIGINT>                     | Yes  | false | NULL    | NONE  |
+    | c_map_largeint | MAP<LARGEINT,LARGEINT>                 | Yes  | false | NULL    | NONE  |
+    | c_map_float    | MAP<FLOAT,FLOAT>                       | Yes  | false | NULL    | NONE  |
+    | c_map_double   | MAP<DOUBLE,DOUBLE>                     | Yes  | false | NULL    | NONE  |
+    | c_map_decimal  | MAP<DECIMALV3(10, 5),DECIMALV3(10, 5)> | Yes  | false | NULL    | NONE  |
+    | c_map_date     | MAP<DATEV2,DATEV2>                     | Yes  | false | NULL    | NONE  |
+    | c_map_datetime | MAP<DATETIMEV2(6),DATETIMEV2(6)>       | Yes  | false | NULL    | NONE  |
+    | c_map_char     | MAP<CHAR(10),CHAR(10)>                 | Yes  | false | NULL    | NONE  |
+    | c_map_varchar  | MAP<VARCHAR(10),VARCHAR(10)>           | Yes  | false | NULL    | NONE  |
+    | c_map_string   | MAP<TEXT,TEXT>                         | Yes  | false | NULL    | NONE  |
+    +----------------+----------------------------------------+------+-------+---------+-------+
+
+ */
+
+ val schema = new StructType()
+ .add("id", IntegerType)
+ .add("c_map_bool", MapType(BooleanType, BooleanType))
+ .add("c_map_tinyint", MapType(ByteType, ByteType))
+ .add("c_map_smallint", MapType(ShortType, ShortType))
+ .add("c_map_int", MapType(IntegerType, IntegerType))
+ .add("c_map_bigint", MapType(LongType, LongType))
+ .add("c_map_largeint", MapType(StringType, StringType))
+ .add("c_map_float", MapType(FloatType, FloatType))
+ .add("c_map_double", MapType(DoubleType, DoubleType))
+      .add("c_map_decimal", MapType(DecimalType.apply(10, 5), DecimalType.apply(10, 5)))
+ .add("c_map_date", MapType(DateType, DateType))
+ .add("c_map_datetime", MapType(TimestampType, TimestampType))
+ .add("c_map_char", MapType(StringType, StringType))
+ .add("c_map_varchar", MapType(StringType, StringType))
+ .add("c_map_string", MapType(StringType, StringType))
+
+ val row = Row(
+ 1,
+ Map(true -> false, false -> true, true -> true),
+ Map(1.toByte -> 2.toByte, 3.toByte -> 4.toByte),
+ Map(2.toShort -> 4.toShort, 5.toShort -> 6.toShort),
+ Map(3 -> 4, 7 -> 8),
+ Map(4.toLong -> 5.toLong, 1.toLong -> 2.toLong),
+ Map("123456789" -> "987654321", "789456123" -> "456789123"),
+      Map(6.6.floatValue() -> 8.8.floatValue(), 9.9.floatValue() -> 10.1.floatValue()),
+      Map(7.7.doubleValue() -> 1.1.doubleValue(), 2.2 -> 3.3.doubleValue()),
+      Map(Decimal.apply(3.12) -> Decimal.apply(1.23), Decimal.apply(2.34) -> Decimal.apply(5.67)),
+      Map(Date.valueOf("2023-09-08") -> Date.valueOf("2024-09-08"), Date.valueOf("1023-09-08") -> Date.valueOf("2023-09-08")),
+      Map(Timestamp.valueOf("1023-09-08 17:12:34.123456") -> Timestamp.valueOf("2023-09-08 17:12:34.123456"), Timestamp.valueOf("3023-09-08 17:12:34.123456") -> Timestamp.valueOf("4023-09-08 17:12:34.123456")),
+ Map("char" -> "char2", "char2" -> "char3"),
+ Map("varchar" -> "varchar2", "varchar3" -> "varchar4"),
+ Map("string" -> "string2", "string3" -> "string4")
+ )
+
+
+ val inputList = ListBuffer[Row]()
+ for (a <- 0 until 7) {
+ inputList.append(row)
+ }
+
+ val rdd = spark.sparkContext.parallelize(inputList, 1)
+ val df = spark.createDataFrame(rdd, schema).toDF()
+
+ df.write
+ .format("doris")
+ .option("doris.fenodes", dorisFeNodes)
+ .option("user", dorisUser)
+ .option("password", dorisPwd)
+ .option("doris.table.identifier", s"$databaseName.spark_connector_map")
+ .option("doris.sink.batch.size", 3)
+ .option("doris.sink.properties.format", "arrow")
+ .option("doris.sink.max-retries", 0)
+ .save()
+ }
+
+ def testDataframeWriteStructType(): Unit = {
+ /*
+
+CREATE TABLE `spark_connector_struct` (
+ `id` int NOT NULL,
+ `st` STRUCT<
+ `c_bool`:boolean,
+ `c_tinyint`:tinyint(4),
+ `c_smallint`:smallint(6),
+ `c_int`:int(11),
+ `c_bigint`:bigint(20),
+ `c_largeint`:largeint(40),
+ `c_float`:float,
+ `c_double`:double,
+ `c_decimal`:DECIMAL(10, 5),
+ `c_date`:date,
+ `c_datetime`:datetime(6),
+ `c_char`:char(10),
+ `c_varchar`:varchar(10),
+ `c_string`:string
+ > NULL
+ ) ENGINE=OLAP
+ DUPLICATE KEY(`id`)
+ COMMENT 'OLAP'
+ DISTRIBUTED BY HASH(`id`) BUCKETS 1
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1"
+ );
+
+
+    +-------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------+-------+---------+-------+
+    | Field | Type                                                                                                                                                                                                                                                               | Null | Key   | Default | Extra |
+    +-------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------+-------+---------+-------+
+    | id    | INT                                                                                                                                                                                                                                                                | No   | true  | NULL    |       |
+    | st    | STRUCT<c_bool:BOOLEAN,c_tinyint:TINYINT,c_smallint:SMALLINT,c_int:INT,c_bigint:BIGINT,c_largeint:LARGEINT,c_float:FLOAT,c_double:DOUBLE,c_decimal:DECIMALV3(10, 5),c_date:DATEV2,c_datetime:DATETIMEV2(6),c_char:CHAR(10),c_varchar:VARCHAR(10),c_string:TEXT>    | Yes  | false | NULL    | NONE  |
+    +-------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------+-------+---------+-------+
+
+
+ */
+
+ val st = new StructType()
+ .add("c_bool", BooleanType)
+ .add("c_tinyint", ByteType)
+ .add("c_smallint", ShortType)
+ .add("c_int", IntegerType)
+ .add("c_bigint", LongType)
+ .add("c_largeint", StringType)
+ .add("c_float", FloatType)
+ .add("c_double", DoubleType)
+ .add("c_decimal", DecimalType.apply(10, 5))
+ .add("c_date", DateType)
+ .add("c_datetime", TimestampType)
+ .add("c_char", StringType)
+ .add("c_varchar", StringType)
+ .add("c_string", StringType)
+
+ val schema = new StructType()
+ .add("id", IntegerType)
+ .add("st", st)
+
+ val row = Row(
+ 1,
+ Row(true,
+ 1.toByte,
+ 2.toShort,
+ 3,
+ 4.toLong,
+ "123456789",
+ 6.6.floatValue(),
+ 7.7.doubleValue(),
+ Decimal.apply(3.12),
+ Date.valueOf("2023-09-08"),
+ Timestamp.valueOf("2023-09-08 17:12:34.123456"),
+ "char",
+ "varchar",
+ "string")
+ )
+
+
+ val inputList = ListBuffer[Row]()
+ for (a <- 0 until 7) {
+ inputList.append(row)
+ }
+
+ val rdd = spark.sparkContext.parallelize(inputList, 1)
+ val df = spark.createDataFrame(rdd, schema).toDF()
+
+ df.write
+ .format("doris")
+ .option("doris.fenodes", dorisFeNodes)
+ .option("user", dorisUser)
+ .option("password", dorisPwd)
+      .option("doris.table.identifier", s"$databaseName.spark_connector_struct")
+ .option("doris.sink.batch.size", 3)
+ .option("doris.sink.properties.format", "arrow")
+ .option("doris.sink.max-retries", 0)
+ .save()
+ }
+}
diff --git a/spark-doris-connector/src/main/scala/org/apache/spark/sql/doris/spark/ArrowSchemaUtils.scala b/spark-doris-connector/src/main/scala/org/apache/spark/sql/doris/spark/ArrowSchemaUtils.scala
new file mode 100644
index 0000000..3ef75a2
--- /dev/null
+++ b/spark-doris-connector/src/main/scala/org/apache/spark/sql/doris/spark/ArrowSchemaUtils.scala
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.spark.sql.doris.spark
+
+import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark.sql.types.StructType
+
+object ArrowSchemaUtils {
+ var classArrowUtils: Option[Class[_]] = None: Option[Class[_]]
+
+ def tryLoadArrowUtilsClass(): Unit = {
+ if (classArrowUtils.isEmpty) {
+ // for spark3.x
+      classArrowUtils = classArrowUtils.orElse(tryLoadClass("org.apache.spark.sql.util.ArrowUtils"))
+      // for spark2.x
+      classArrowUtils = classArrowUtils.orElse(tryLoadClass("org.apache.spark.sql.execution.arrow.ArrowUtils"))
+ if (classArrowUtils.isEmpty) {
+ throw new ClassNotFoundException("can't load class for ArrowUtils")
+ }
+ }
+ }
+
+ def tryLoadClass(className: String): Option[Class[_]] = {
+ try {
+ Some(Class.forName(className))
+ } catch {
+ case e: ClassNotFoundException =>
+ None
+ }
+ }
+
+ def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
+ tryLoadArrowUtilsClass()
+    val toArrowSchema = classArrowUtils.get.getMethod("toArrowSchema", classOf[StructType], classOf[String])
+ toArrowSchema.invoke(null, schema, timeZoneId).asInstanceOf[Schema]
+ }
+}
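A small usage sketch of this helper, converting a Spark schema to an Arrow schema the same way RecordBatch does above (column names are illustrative; "UTC" matches the timezone id hard-coded in RecordBatch):

```scala
import org.apache.spark.sql.doris.spark.ArrowSchemaUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

// Convert a Spark StructType to an Arrow Schema via the reflectively loaded
// ArrowUtils (org.apache.spark.sql.util.ArrowUtils on Spark 3.x).
val sparkSchema = new StructType()
  .add("id", IntegerType)
  .add("name", StringType)

val arrowSchema = ArrowSchemaUtils.toArrowSchema(sparkSchema, "UTC")
println(arrowSchema) // e.g. Schema<id: Int(32, true), name: Utf8>
```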
diff --git a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestConnectorWriteDoris.scala b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestConnectorWriteDoris.scala
index fecface..dbea21a 100644
--- a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestConnectorWriteDoris.scala
+++ b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestConnectorWriteDoris.scala
@@ -184,5 +184,4 @@ class TestConnectorWriteDoris {
.save()
spark.stop()
}
-
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]