This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git


The following commit(s) were added to refs/heads/master by this push:
     new 4849ee7  [feature](step:one)Add arrow format (#161)
4849ee7 is described below

commit 4849ee72968d2da958bde621e00d252e3049f500
Author: wuwenchi <[email protected]>
AuthorDate: Fri Dec 8 21:53:27 2023 +0800

    [feature](step:one)Add arrow format (#161)
---
 README.md                                          |  23 +
 spark-doris-connector/pom.xml                      |   2 +-
 .../org/apache/doris/spark/load/DataFormat.java    |  24 ++
 .../apache/doris/spark/load/DorisStreamLoad.java   |  29 +-
 .../org/apache/doris/spark/load/RecordBatch.java   |  37 +-
 .../doris/spark/load/RecordBatchInputStream.java   |  78 +++-
 .../java/org/apache/doris/spark/util/DataUtil.java |   2 -
 .../testcase/TestStreamLoadForArrowType.scala      | 461 +++++++++++++++++++++
 .../spark/sql/doris/spark/ArrowSchemaUtils.scala   |  52 +++
 .../doris/spark/sql/TestConnectorWriteDoris.scala  |   1 -
 10 files changed, 665 insertions(+), 44 deletions(-)

diff --git a/README.md b/README.md
index 70c0a77..5661318 100644
--- a/README.md
+++ b/README.md
@@ -124,6 +124,29 @@ dorisSparkDF = spark.read.format("doris")
 dorisSparkDF.show(5)
 ```
 
+## Type conversion for writing to Doris using Arrow
+|doris|spark|
+|---|---|
+| BOOLEAN | BooleanType |
+| TINYINT | ByteType |
+| SMALLINT | ShortType |
+| INT | IntegerType |
+| BIGINT | LongType |
+| LARGEINT | StringType |
+| FLOAT | FloatType |
+| DOUBLE | DoubleType |
+| DECIMAL(M,D) | DecimalType(M,D) |
+| DATE | DateType |
+| DATETIME | TimestampType |
+| CHAR(L) | StringType |
+| VARCHAR(L) | StringType |
+| STRING | StringType |
+| ARRAY | ArrayType |
+| MAP | MapType |
+| STRUCT | StructType |
+
+
+
 ## Report issues or submit pull request
 
 If you find any bugs, feel free to file a [GitHub 
issue](https://github.com/apache/doris/issues) or fix it by submitting a [pull 
request](https://github.com/apache/doris/pulls).
diff --git a/spark-doris-connector/pom.xml b/spark-doris-connector/pom.xml
index 518a3e2..89e1716 100644
--- a/spark-doris-connector/pom.xml
+++ b/spark-doris-connector/pom.xml
@@ -72,7 +72,7 @@
         <spark.major.version>3.1</spark.major.version>
         <scala.version>2.12</scala.version>
         <libthrift.version>0.16.0</libthrift.version>
-        <arrow.version>5.0.0</arrow.version>
+        <arrow.version>13.0.0</arrow.version>
         <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
         <project.scm.id>github</project.scm.id>
         <netty.version>4.1.77.Final</netty.version>
diff --git 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DataFormat.java
 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DataFormat.java
new file mode 100644
index 0000000..e3e3b0b
--- /dev/null
+++ 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DataFormat.java
@@ -0,0 +1,24 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.load;
+
+public enum DataFormat {
+    CSV,
+    JSON,
+    ARROW
+}
diff --git 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
index f473636..7758c6f 100644
--- 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
+++ 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
@@ -98,7 +98,7 @@ public class DorisStreamLoad implements Serializable {
     private static final long cacheExpireTimeout = 4 * 60;
     private final LoadingCache<String, List<BackendV2.BackendRowV2>> cache;
     private final String fenodes;
-    private final String fileType;
+    private final DataFormat dataFormat;
     private String FIELD_DELIMITER;
     private final String LINE_DELIMITER;
     private boolean streamingPassthrough = false;
@@ -119,16 +119,19 @@ public class DorisStreamLoad implements Serializable {
         this.maxFilterRatio = 
settings.getProperty(ConfigurationOptions.DORIS_MAX_FILTER_RATIO);
         this.streamLoadProp = getStreamLoadProp(settings);
         cache = CacheBuilder.newBuilder().expireAfterWrite(cacheExpireTimeout, 
TimeUnit.MINUTES).build(new BackendCacheLoader(settings));
-        fileType = streamLoadProp.getOrDefault("format", "csv");
-        if ("csv".equals(fileType)) {
-            FIELD_DELIMITER = 
escapeString(streamLoadProp.getOrDefault("column_separator", "\t"));
-            this.addDoubleQuotes = 
Boolean.parseBoolean(streamLoadProp.getOrDefault("add_double_quotes", "false"));
-            if (addDoubleQuotes) {
-                LOG.info("set add_double_quotes for csv mode, add 
trim_double_quotes to true for prop.");
-                streamLoadProp.put("trim_double_quotes", "true");
-            }
-        } else if ("json".equalsIgnoreCase(fileType)) {
-            streamLoadProp.put("read_json_by_line", "true");
+        dataFormat = DataFormat.valueOf(streamLoadProp.getOrDefault("format", 
"csv").toUpperCase());
+        switch (dataFormat) {
+            case CSV:
+                FIELD_DELIMITER = 
escapeString(streamLoadProp.getOrDefault("column_separator", "\t"));
+                this.addDoubleQuotes = 
Boolean.parseBoolean(streamLoadProp.getOrDefault("add_double_quotes", "false"));
+                if (addDoubleQuotes) {
+                    LOG.info("set add_double_quotes for csv mode, add 
trim_double_quotes to true for prop.");
+                    streamLoadProp.put("trim_double_quotes", "true");
+                }
+                break;
+            case JSON:
+                streamLoadProp.put("read_json_by_line", "true");
+                break;
         }
         LINE_DELIMITER = 
escapeString(streamLoadProp.getOrDefault("line_delimiter", "\n"));
         this.streamingPassthrough = 
settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH,
@@ -218,7 +221,7 @@ public class DorisStreamLoad implements Serializable {
             this.loadUrlStr = loadUrlStr;
             HttpPut httpPut = getHttpPut(label, loadUrlStr, enable2PC, schema);
             RecordBatchInputStream recodeBatchInputStream = new 
RecordBatchInputStream(RecordBatch.newBuilder(rows)
-                    .format(fileType)
+                    .format(dataFormat)
                     .sep(FIELD_DELIMITER)
                     .delim(LINE_DELIMITER)
                     .schema(schema)
@@ -492,7 +495,7 @@ public class DorisStreamLoad implements Serializable {
      */
     private void handleStreamPassThrough() {
 
-        if ("json".equalsIgnoreCase(fileType)) {
+        if (dataFormat.equals(DataFormat.JSON)) {
             LOG.info("handle stream pass through, force set read_json_by_line 
is true for json format");
             streamLoadProp.put("read_json_by_line", "true");
             streamLoadProp.remove("strip_outer_array");
diff --git 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
index e471d5b..b514586 100644
--- 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
+++ 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
@@ -17,7 +17,11 @@
 
 package org.apache.doris.spark.load;
 
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.doris.spark.ArrowSchemaUtils;
 import org.apache.spark.sql.types.StructType;
 
 import java.nio.charset.Charset;
@@ -39,7 +43,7 @@ public class RecordBatch {
     /**
      * stream load format
      */
-    private final String format;
+    private final DataFormat format;
 
     /**
      * column separator, only used when the format is csv
@@ -58,7 +62,11 @@ public class RecordBatch {
 
     private final boolean addDoubleQuotes;
 
-    private RecordBatch(Iterator<InternalRow> iterator, String format, String 
sep, byte[] delim,
+    private VectorSchemaRoot arrowRoot = null;
+
+    private int arrowBatchSize = 1000;
+
+    private RecordBatch(Iterator<InternalRow> iterator, DataFormat format, 
String sep, byte[] delim,
                         StructType schema, boolean addDoubleQuotes) {
         this.iterator = iterator;
         this.format = format;
@@ -66,13 +74,17 @@ public class RecordBatch {
         this.delim = delim;
         this.schema = schema;
         this.addDoubleQuotes = addDoubleQuotes;
+        if (format.equals(DataFormat.ARROW)) {
+            Schema arrowSchema = ArrowSchemaUtils.toArrowSchema(schema, "UTC");
+            this.arrowRoot = VectorSchemaRoot.create(arrowSchema, new 
RootAllocator(Integer.MAX_VALUE));
+        }
     }
 
     public Iterator<InternalRow> getIterator() {
         return iterator;
     }
 
-    public String getFormat() {
+    public DataFormat getFormat() {
         return format;
     }
 
@@ -84,13 +96,28 @@ public class RecordBatch {
         return delim;
     }
 
+    public VectorSchemaRoot getVectorSchemaRoot() {
+        return arrowRoot;
+    }
+
     public StructType getSchema() {
         return schema;
     }
 
+    public int getArrowBatchSize() {
+        return arrowBatchSize;
+    }
+
     public boolean getAddDoubleQuotes(){
         return addDoubleQuotes;
     }
+
+    public void clearBatch() {
+        if (format.equals(DataFormat.ARROW)) {
+            this.arrowRoot.clear();
+        }
+    }
+
     public static Builder newBuilder(Iterator<InternalRow> iterator) {
         return new Builder(iterator);
     }
@@ -102,7 +129,7 @@ public class RecordBatch {
 
         private final Iterator<InternalRow> iterator;
 
-        private String format;
+        private DataFormat format;
 
         private String sep;
 
@@ -116,7 +143,7 @@ public class RecordBatch {
             this.iterator = iterator;
         }
 
-        public Builder format(String format) {
+        public Builder format(DataFormat format) {
             this.format = format;
             return this;
         }
diff --git 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
index 544e683..c43f685 100644
--- 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
+++ 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
@@ -17,6 +17,8 @@
 
 package org.apache.doris.spark.load;
 
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.ipc.ArrowStreamWriter;
 import org.apache.doris.spark.exception.DorisException;
 import org.apache.doris.spark.exception.IllegalArgumentException;
 import org.apache.doris.spark.exception.ShouldNeverHappenException;
@@ -24,9 +26,11 @@ import org.apache.doris.spark.util.DataUtil;
 
 import com.fasterxml.jackson.core.JsonProcessingException;
 import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.execution.arrow.ArrowWriter;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.nio.ByteBuffer;
@@ -54,12 +58,17 @@ public class RecordBatchInputStream extends InputStream {
      * record buffer
      */
 
-    private ByteBuffer lineBuf = ByteBuffer.allocate(0);;
+    private ByteBuffer lineBuf = ByteBuffer.allocate(0);
 
     private ByteBuffer delimBuf = ByteBuffer.allocate(0);
 
     private final byte[] delim;
 
+    /**
+     * record count has been read
+     */
+    private int readCount = 0;
+
     /**
      * streaming mode pass through data without process
      */
@@ -73,18 +82,13 @@ public class RecordBatchInputStream extends InputStream {
 
     @Override
     public int read() throws IOException {
-        try {
-            if (lineBuf.remaining() == 0 && endOfBatch()) {
-                return -1;
-            }
-
-            if (delimBuf != null && delimBuf.remaining() > 0) {
-                return delimBuf.get() & 0xff;
-            }
-        } catch (DorisException e) {
-            throw new IOException(e);
+        byte[] bytes = new byte[1]; // single-byte read is served by the bulk read(byte[], int, int)
+        int read = read(bytes, 0, 1);
+        if (read < 0) {
+            return -1;
+        } else {
+            return bytes[0] & 0xFF; // mask: InputStream.read() must return 0-255; a raw 0xFF byte would sign-extend to -1 (false EOF)
         }
-        return lineBuf.get() & 0xFF;
     }
 
     @Override
@@ -102,6 +106,7 @@ public class RecordBatchInputStream extends InputStream {
         } catch (DorisException e) {
             throw new IOException(e);
         }
+
         int bytesRead = Math.min(len, lineBuf.remaining());
         lineBuf.get(b, off, bytesRead);
         return bytesRead;
@@ -121,6 +126,8 @@ public class RecordBatchInputStream extends InputStream {
             readNext(iterator);
             return false;
         }
+
+        recordBatch.clearBatch();
         delimBuf = null;
         return true;
     }
@@ -135,14 +142,41 @@ public class RecordBatchInputStream extends InputStream {
         if (!iterator.hasNext()) {
             throw new ShouldNeverHappenException();
         }
-        byte[] rowBytes = rowToByte(iterator.next());
-        if (isFirst) {
+
+        if (recordBatch.getFormat().equals(DataFormat.ARROW)) {
+            ArrowWriter arrowWriter = 
ArrowWriter.create(recordBatch.getVectorSchemaRoot());
+            while (iterator.hasNext() && readCount <  
recordBatch.getArrowBatchSize()) {
+                arrowWriter.write(iterator.next());
+                readCount++;
+            }
+            arrowWriter.finish();
+
+            ByteArrayOutputStream out = new ByteArrayOutputStream();
+            ArrowStreamWriter writer = new ArrowStreamWriter(
+                recordBatch.getVectorSchemaRoot(),
+                new DictionaryProvider.MapDictionaryProvider(),
+                out);
+
+            try {
+                writer.writeBatch();
+                writer.end();
+            } catch (IOException e) {
+                throw new DorisException(e);
+            }
+
             delimBuf = null;
-            lineBuf = ByteBuffer.wrap(rowBytes);
-            isFirst = false;
+            lineBuf = ByteBuffer.wrap(out.toByteArray());
+            readCount = 0;
         } else {
-            delimBuf =  ByteBuffer.wrap(delim);
-            lineBuf = ByteBuffer.wrap(rowBytes);
+            byte[] rowBytes = rowToByte(iterator.next());
+            if (isFirst) {
+                delimBuf = null;
+                lineBuf = ByteBuffer.wrap(rowBytes);
+                isFirst = false;
+            } else {
+                delimBuf =  ByteBuffer.wrap(delim);
+                lineBuf = ByteBuffer.wrap(rowBytes);
+            }
         }
     }
 
@@ -162,11 +196,11 @@ public class RecordBatchInputStream extends InputStream {
             return bytes;
         }
 
-        switch (recordBatch.getFormat().toLowerCase()) {
-            case "csv":
+        switch (recordBatch.getFormat()) {
+            case CSV:
                 bytes = DataUtil.rowToCsvBytes(row, recordBatch.getSchema(), 
recordBatch.getSep(), recordBatch.getAddDoubleQuotes());
                 break;
-            case "json":
+            case JSON:
                 try {
                     bytes = DataUtil.rowToJsonBytes(row, 
recordBatch.getSchema());
                 } catch (JsonProcessingException e) {
@@ -174,7 +208,7 @@ public class RecordBatchInputStream extends InputStream {
                 }
                 break;
             default:
-                throw new IllegalArgumentException("format", 
recordBatch.getFormat());
+                throw new IllegalArgumentException("Unsupported format: ", 
recordBatch.getFormat().toString());
         }
 
         return bytes;
diff --git 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java
index 763d72b..f7218c3 100644
--- 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java
+++ 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java
@@ -31,10 +31,8 @@ import org.apache.spark.sql.types.StructType;
 import java.nio.charset.StandardCharsets;
 import java.util.HashMap;
 import java.util.Map;
-import java.util.Objects;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
-import java.util.stream.Stream;
 
 public class DataUtil {
 
diff --git 
a/spark-doris-connector/src/main/scala/org/apache/doris/spark/testcase/TestStreamLoadForArrowType.scala
 
b/spark-doris-connector/src/main/scala/org/apache/doris/spark/testcase/TestStreamLoadForArrowType.scala
new file mode 100644
index 0000000..f1e1977
--- /dev/null
+++ 
b/spark-doris-connector/src/main/scala/org/apache/doris/spark/testcase/TestStreamLoadForArrowType.scala
@@ -0,0 +1,461 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.testcase
+
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{Row, SparkSession}
+
+import java.sql.{Date, Timestamp}
+import scala.collection.mutable.ListBuffer
+
+// This object is used to test writing from spark into doris with arrow format,
+// And it will be executed in doris's pipeline.
+object TestStreamLoadForArrowType {
+  val spark: SparkSession = 
SparkSession.builder().master("local[1]").getOrCreate()
+  var dorisFeNodes = "127.0.0.1:8030"
+  var dorisUser = "root"
+  val dorisPwd = ""
+  var databaseName = ""
+
+  def main(args: Array[String]): Unit = {
+
+    dorisFeNodes = args(0)
+    dorisUser = args(1)
+    databaseName = args(2)
+
+    testDataframeWritePrimitiveType()
+    testDataframeWriteArrayTypes()
+    testDataframeWriteMapType()
+    testDataframeWriteStructType()
+
+    spark.stop()
+  }
+
+
+  def testDataframeWritePrimitiveType(): Unit = {
+    /*
+
+  CREATE TABLE `spark_connector_primitive` (
+        `id` int(11) NOT NULL,
+        `c_bool` boolean NULL,
+        `c_tinyint` tinyint NULL,
+        `c_smallint` smallint NULL,
+        `c_int` int NULL,
+        `c_bigint` bigint NULL,
+        `c_largeint` largeint NULL,
+        `c_float` float NULL,
+        `c_double` double NULL,
+        `c_decimal` DECIMAL(10, 5) NULL,
+        `c_date` date NULL,
+        `c_datetime` datetime(6) NULL,
+        `c_char` char(10) NULL,
+        `c_varchar` varchar(10) NULL,
+        `c_string` string NULL
+      ) ENGINE=OLAP
+      DUPLICATE KEY(`id`)
+      COMMENT 'OLAP'
+      DISTRIBUTED BY HASH(`id`) BUCKETS 1
+      PROPERTIES (
+      "replication_allocation" = "tag.location.default: 1"
+      );
+
+    +------------+----------------+------+-------+---------+-------+
+    | Field      | Type           | Null | Key   | Default | Extra |
+    +------------+----------------+------+-------+---------+-------+
+    | id         | INT            | No   | true  | NULL    |       |
+    | c_bool     | BOOLEAN        | Yes  | false | NULL    | NONE  |
+    | c_tinyint  | TINYINT        | Yes  | false | NULL    | NONE  |
+    | c_smallint | SMALLINT       | Yes  | false | NULL    | NONE  |
+    | c_int      | INT            | Yes  | false | NULL    | NONE  |
+    | c_bigint   | BIGINT         | Yes  | false | NULL    | NONE  |
+    | c_largeint | LARGEINT       | Yes  | false | NULL    | NONE  |
+    | c_float    | FLOAT          | Yes  | false | NULL    | NONE  |
+    | c_double   | DOUBLE         | Yes  | false | NULL    | NONE  |
+    | c_decimal  | DECIMAL(10, 5) | Yes  | false | NULL    | NONE  |
+    | c_date     | DATE           | Yes  | false | NULL    | NONE  |
+    | c_datetime | DATETIME(6)    | Yes  | false | NULL    | NONE  |
+    | c_char     | CHAR(10)       | Yes  | false | NULL    | NONE  |
+    | c_varchar  | VARCHAR(10)    | Yes  | false | NULL    | NONE  |
+    | c_string   | TEXT           | Yes  | false | NULL    | NONE  |
+    +------------+----------------+------+-------+---------+-------+
+
+     */
+
+    val schema = new StructType()
+      .add("id", IntegerType)
+      .add("c_bool", BooleanType)
+      .add("c_tinyint", ByteType)
+      .add("c_smallint", ShortType)
+      .add("c_int", IntegerType)
+      .add("c_bigint", LongType)
+      .add("c_largeint", StringType)
+      .add("c_float", FloatType)
+      .add("c_double", DoubleType)
+      .add("c_decimal", DecimalType.apply(10, 5))
+      .add("c_date", DateType)
+      .add("c_datetime", TimestampType)
+      .add("c_char", StringType)
+      .add("c_varchar", StringType)
+      .add("c_string", StringType)
+
+    val row = Row(
+      1,
+      true,
+      1.toByte,
+      2.toShort,
+      3,
+      4.toLong,
+      "123456789",
+      6.6.floatValue(),
+      7.7.doubleValue(),
+      Decimal.apply(3.12),
+      Date.valueOf("2023-09-08"),
+      Timestamp.valueOf("2023-09-08 17:12:34.123456"),
+      "char",
+      "varchar",
+      "string"
+    )
+
+
+    val inputList = ListBuffer[Row]()
+    for (a <- 0 until 7) {
+      inputList.append(row)
+    }
+
+    val rdd = spark.sparkContext.parallelize(inputList, 1)
+    val df = spark.createDataFrame(rdd, schema).toDF()
+
+    df.write
+      .format("doris")
+      .option("doris.fenodes", dorisFeNodes)
+      .option("user", dorisUser)
+      .option("password", dorisPwd)
+      .option("doris.table.identifier", 
s"$databaseName.spark_connector_primitive")
+      .option("doris.sink.batch.size", 3)
+      .option("doris.sink.properties.format", "arrow")
+      .option("doris.sink.max-retries", 0)
+      .save()
+  }
+
+  def testDataframeWriteArrayTypes(): Unit = {
+    /*
+
+  CREATE TABLE `spark_connector_array` (
+        `id` int(11) NOT NULL,
+        `c_array_boolean` ARRAY<boolean> NULL,
+        `c_array_tinyint` ARRAY<tinyint> NULL,
+        `c_array_smallint` ARRAY<smallint> NULL,
+        `c_array_int` ARRAY<int> NULL,
+        `c_array_bigint` ARRAY<bigint> NULL,
+        `c_array_largeint` ARRAY<largeint> NULL,
+        `c_array_float` ARRAY<float> NULL,
+        `c_array_double` ARRAY<double> NULL,
+        `c_array_decimal` ARRAY<DECIMAL(10, 5)> NULL,
+        `c_array_date` ARRAY<date> NULL,
+        `c_array_datetime` ARRAY<datetime(6)> NULL,
+        `c_array_char` ARRAY<char(10)> NULL,
+        `c_array_varchar` ARRAY<varchar(10)> NULL,
+        `c_array_string` ARRAY<string> NULL
+      ) ENGINE=OLAP
+      DUPLICATE KEY(`id`)
+      COMMENT 'OLAP'
+      DISTRIBUTED BY HASH(`id`) BUCKETS 1
+      PROPERTIES (
+      "replication_allocation" = "tag.location.default: 1"
+      );
+
+    
+------------------+-------------------------+------+-------+---------+-------+
+    | Field            | Type                    | Null | Key   | Default | 
Extra |
+    
+------------------+-------------------------+------+-------+---------+-------+
+    | id               | INT                     | No   | true  | NULL    |    
   |
+    | c_array_boolean  | ARRAY<BOOLEAN>          | Yes  | false | []      | 
NONE  |
+    | c_array_tinyint  | ARRAY<TINYINT>          | Yes  | false | []      | 
NONE  |
+    | c_array_smallint | ARRAY<SMALLINT>         | Yes  | false | []      | 
NONE  |
+    | c_array_int      | ARRAY<INT>              | Yes  | false | []      | 
NONE  |
+    | c_array_bigint   | ARRAY<BIGINT>           | Yes  | false | []      | 
NONE  |
+    | c_array_largeint | ARRAY<LARGEINT>         | Yes  | false | []      | 
NONE  |
+    | c_array_float    | ARRAY<FLOAT>            | Yes  | false | []      | 
NONE  |
+    | c_array_double   | ARRAY<DOUBLE>           | Yes  | false | []      | 
NONE  |
+    | c_array_decimal  | ARRAY<DECIMALV3(10, 5)> | Yes  | false | []      | 
NONE  |
+    | c_array_date     | ARRAY<DATEV2>           | Yes  | false | []      | 
NONE  |
+    | c_array_datetime | ARRAY<DATETIMEV2(6)>    | Yes  | false | []      | 
NONE  |
+    | c_array_char     | ARRAY<CHAR(10)>         | Yes  | false | []      | 
NONE  |
+    | c_array_varchar  | ARRAY<VARCHAR(10)>      | Yes  | false | []      | 
NONE  |
+    | c_array_string   | ARRAY<TEXT>             | Yes  | false | []      | 
NONE  |
+    
+------------------+-------------------------+------+-------+---------+-------+
+
+     */
+
+    val schema = new StructType()
+      .add("id", IntegerType)
+      .add("c_array_boolean", ArrayType(BooleanType))
+      .add("c_array_tinyint", ArrayType(ByteType))
+      .add("c_array_smallint", ArrayType(ShortType))
+      .add("c_array_int", ArrayType(IntegerType))
+      .add("c_array_bigint", ArrayType(LongType))
+      .add("c_array_largeint", ArrayType(StringType))
+      .add("c_array_float", ArrayType(FloatType))
+      .add("c_array_double", ArrayType(DoubleType))
+      .add("c_array_decimal", ArrayType(DecimalType.apply(10, 5)))
+      .add("c_array_date", ArrayType(DateType))
+      .add("c_array_datetime", ArrayType(TimestampType))
+      .add("c_array_char", ArrayType(StringType))
+      .add("c_array_varchar", ArrayType(StringType))
+      .add("c_array_string", ArrayType(StringType))
+
+    val row = Row(
+      1,
+      Array(true, false, false, true, true),
+      Array(1.toByte, 2.toByte, 3.toByte),
+      Array(2.toShort, 12.toShort, 32.toShort),
+      Array(3, 4, 5, 6),
+      Array(4.toLong, 5.toLong, 6.toLong),
+      Array("123456789", "987654321", "123789456"),
+      Array(6.6.floatValue(), 6.7.floatValue(), 7.8.floatValue()),
+      Array(7.7.doubleValue(), 8.8.doubleValue(), 8.9.floatValue()),
+      Array(Decimal.apply(3.12), Decimal.apply(1.12345)),
+      Array(Date.valueOf("2023-09-08"), Date.valueOf("2027-10-28")),
+      Array(Timestamp.valueOf("2023-09-08 17:12:34.123456"), 
Timestamp.valueOf("2024-09-08 18:12:34.123456")),
+      Array("char", "char2"),
+      Array("varchar", "varchar2"),
+      Array("string", "string2")
+    )
+
+
+    val inputList = ListBuffer[Row]()
+    for (a <- 0 until 7) {
+      inputList.append(row)
+    }
+
+    val rdd = spark.sparkContext.parallelize(inputList, 1)
+    val df = spark.createDataFrame(rdd, schema).toDF()
+
+    df.write
+      .format("doris")
+      .option("doris.fenodes", dorisFeNodes)
+      .option("user", dorisUser)
+      .option("password", dorisPwd)
+      .option("doris.table.identifier", s"$databaseName.spark_connector_array")
+      .option("doris.sink.batch.size", 30)
+      .option("doris.sink.properties.format", "arrow")
+      .option("doris.sink.max-retries", 0)
+      .save()
+  }
+
+  def testDataframeWriteMapType(): Unit = {
+    /*
+
+  CREATE TABLE `spark_connector_map` (
+        `id` int(11) NOT NULL,
+        `c_map_bool` Map<boolean,boolean> NULL,
+        `c_map_tinyint` Map<tinyint,tinyint> NULL,
+        `c_map_smallint` Map<smallint,smallint> NULL,
+        `c_map_int` Map<int,int> NULL,
+        `c_map_bigint` Map<bigint,bigint> NULL,
+        `c_map_largeint` Map<largeint,largeint> NULL,
+        `c_map_float` Map<float,float> NULL,
+        `c_map_double` Map<double,double> NULL,
+        `c_map_decimal` Map<DECIMAL(10, 5),DECIMAL(10, 5)> NULL,
+        `c_map_date` Map<date,date> NULL,
+        `c_map_datetime` Map<datetime(6),datetime(6)> NULL,
+        `c_map_char` Map<char(10),char(10)> NULL,
+        `c_map_varchar` Map<varchar(10),varchar(10)> NULL,
+        `c_map_string` Map<string,string> NULL
+      ) ENGINE=OLAP
+      DUPLICATE KEY(`id`)
+      COMMENT 'OLAP'
+      DISTRIBUTED BY HASH(`id`) BUCKETS 1
+      PROPERTIES (
+      "replication_allocation" = "tag.location.default: 1"
+      );
+
+    
+----------------+----------------------------------------+------+-------+---------+-------+
+    | Field          | Type                                   | Null | Key   | 
Default | Extra |
+    
+----------------+----------------------------------------+------+-------+---------+-------+
+    | id             | INT                                    | No   | true  | 
NULL    |       |
+    | c_map_bool     | MAP<BOOLEAN,BOOLEAN>                   | Yes  | false | 
NULL    | NONE  |
+    | c_map_tinyint  | MAP<TINYINT,TINYINT>                   | Yes  | false | 
NULL    | NONE  |
+    | c_map_smallint | MAP<SMALLINT,SMALLINT>                 | Yes  | false | 
NULL    | NONE  |
+    | c_map_int      | MAP<INT,INT>                           | Yes  | false | 
NULL    | NONE  |
+    | c_map_bigint   | MAP<BIGINT,BIGINT>                     | Yes  | false | 
NULL    | NONE  |
+    | c_map_largeint | MAP<LARGEINT,LARGEINT>                 | Yes  | false | 
NULL    | NONE  |
+    | c_map_float    | MAP<FLOAT,FLOAT>                       | Yes  | false | 
NULL    | NONE  |
+    | c_map_double   | MAP<DOUBLE,DOUBLE>                     | Yes  | false | 
NULL    | NONE  |
+    | c_map_decimal  | MAP<DECIMALV3(10, 5),DECIMALV3(10, 5)> | Yes  | false | 
NULL    | NONE  |
+    | c_map_date     | MAP<DATEV2,DATEV2>                     | Yes  | false | 
NULL    | NONE  |
+    | c_map_datetime | MAP<DATETIMEV2(6),DATETIMEV2(6)>       | Yes  | false | 
NULL    | NONE  |
+    | c_map_char     | MAP<CHAR(10),CHAR(10)>                 | Yes  | false | 
NULL    | NONE  |
+    | c_map_varchar  | MAP<VARCHAR(10),VARCHAR(10)>           | Yes  | false | 
NULL    | NONE  |
+    | c_map_string   | MAP<TEXT,TEXT>                         | Yes  | false | 
NULL    | NONE  |
+    
+----------------+----------------------------------------+------+-------+---------+-------+
+
+     */
+
+    val schema = new StructType()
+      .add("id", IntegerType)
+      .add("c_map_bool", MapType(BooleanType, BooleanType))
+      .add("c_map_tinyint", MapType(ByteType, ByteType))
+      .add("c_map_smallint", MapType(ShortType, ShortType))
+      .add("c_map_int", MapType(IntegerType, IntegerType))
+      .add("c_map_bigint", MapType(LongType, LongType))
+      .add("c_map_largeint", MapType(StringType, StringType))
+      .add("c_map_float", MapType(FloatType, FloatType))
+      .add("c_map_double", MapType(DoubleType, DoubleType))
+      .add("c_map_decimal", MapType(DecimalType.apply(10, 5), 
DecimalType.apply(10, 5)))
+      .add("c_map_date", MapType(DateType, DateType))
+      .add("c_map_datetime", MapType(TimestampType, TimestampType))
+      .add("c_map_char", MapType(StringType, StringType))
+      .add("c_map_varchar", MapType(StringType, StringType))
+      .add("c_map_string", MapType(StringType, StringType))
+
+    val row = Row(
+      1,
+      Map(true -> false, false -> true, true -> true),
+      Map(1.toByte -> 2.toByte, 3.toByte -> 4.toByte),
+      Map(2.toShort -> 4.toShort, 5.toShort -> 6.toShort),
+      Map(3 -> 4, 7 -> 8),
+      Map(4.toLong -> 5.toLong, 1.toLong -> 2.toLong),
+      Map("123456789" -> "987654321", "789456123" -> "456789123"),
+      Map(6.6.floatValue() -> 8.8.floatValue(), 9.9.floatValue() -> 
10.1.floatValue()),
+      Map(7.7.doubleValue() -> 1.1.doubleValue(), 2.2 -> 3.3.doubleValue()),
+      Map(Decimal.apply(3.12) -> Decimal.apply(1.23), Decimal.apply(2.34) -> 
Decimal.apply(5.67)),
+      Map(Date.valueOf("2023-09-08") -> Date.valueOf("2024-09-08"), 
Date.valueOf("1023-09-08") -> Date.valueOf("2023-09-08")),
+      Map(Timestamp.valueOf("1023-09-08 17:12:34.123456") -> 
Timestamp.valueOf("2023-09-08 17:12:34.123456"), Timestamp.valueOf("3023-09-08 
17:12:34.123456") -> Timestamp.valueOf("4023-09-08 17:12:34.123456")),
+      Map("char" -> "char2", "char2" -> "char3"),
+      Map("varchar" -> "varchar2", "varchar3" -> "varchar4"),
+      Map("string" -> "string2", "string3" -> "string4")
+    )
+
+
+    val inputList = ListBuffer[Row]()
+    for (a <- 0 until 7) {
+      inputList.append(row)
+    }
+
+    val rdd = spark.sparkContext.parallelize(inputList, 1)
+    val df = spark.createDataFrame(rdd, schema).toDF()
+
+    df.write
+      .format("doris")
+      .option("doris.fenodes", dorisFeNodes)
+      .option("user", dorisUser)
+      .option("password", dorisPwd)
+      .option("doris.table.identifier", s"$databaseName.spark_connector_map")
+      .option("doris.sink.batch.size", 3)
+      .option("doris.sink.properties.format", "arrow")
+      .option("doris.sink.max-retries", 0)
+      .save()
+  }
+
+  def testDataframeWriteStructType(): Unit = {
+    /*
+     * Writes rows containing a STRUCT column to Doris via stream load with the arrow
+     * format. Each struct member uses the Spark type paired with its Doris type below
+     * (LARGEINT, CHAR, VARCHAR and STRING members are all sent as StringType).
+     *
+     * Target table:
+
+CREATE TABLE `spark_connector_struct` (
+          `id` int NOT NULL,
+          `st` STRUCT<
+              `c_bool`:boolean,
+              `c_tinyint`:tinyint(4),
+              `c_smallint`:smallint(6),
+              `c_int`:int(11),
+              `c_bigint`:bigint(20),
+              `c_largeint`:largeint(40),
+              `c_float`:float,
+              `c_double`:double,
+              `c_decimal`:DECIMAL(10, 5),
+              `c_date`:date,
+              `c_datetime`:datetime(6),
+              `c_char`:char(10),
+              `c_varchar`:varchar(10),
+              `c_string`:string
+            > NULL
+        ) ENGINE=OLAP
+        DUPLICATE KEY(`id`)
+        COMMENT 'OLAP'
+        DISTRIBUTED BY HASH(`id`) BUCKETS 1
+        PROPERTIES (
+        "replication_allocation" = "tag.location.default: 1"
+        );
+
+     * `DESC spark_connector_struct` then reports:
+     *   id  INT     Null: No   Key: true   Default: NULL
+     *   st  STRUCT<c_bool:BOOLEAN,c_tinyint:TINYINT,c_smallint:SMALLINT,c_int:INT,
+     *              c_bigint:BIGINT,c_largeint:LARGEINT,c_float:FLOAT,c_double:DOUBLE,
+     *              c_decimal:DECIMALV3(10, 5),c_date:DATEV2,c_datetime:DATETIMEV2(6),
+     *              c_char:CHAR(10),c_varchar:VARCHAR(10),c_string:TEXT>
+     *       Null: Yes  Key: false  Default: NULL  Extra: NONE
+     */
+
+    val st = new StructType()
+      .add("c_bool", BooleanType)
+      .add("c_tinyint", ByteType)
+      .add("c_smallint", ShortType)
+      .add("c_int", IntegerType)
+      .add("c_bigint", LongType)
+      .add("c_largeint", StringType)
+      .add("c_float", FloatType)
+      .add("c_double", DoubleType)
+      .add("c_decimal", DecimalType.apply(10, 5))
+      .add("c_date", DateType)
+      .add("c_datetime", TimestampType)
+      .add("c_char", StringType)
+      .add("c_varchar", StringType)
+      .add("c_string", StringType)
+
+    val schema = new StructType()
+      .add("id", IntegerType)
+      .add("st", st)
+
+    val row = Row(
+      1,
+      Row(true,
+        1.toByte,
+        2.toShort,
+        3,
+        4.toLong,
+        "123456789",
+        6.6.floatValue(),
+        7.7.doubleValue(),
+        Decimal.apply(3.12),
+        Date.valueOf("2023-09-08"),
+        Timestamp.valueOf("2023-09-08 17:12:34.123456"),
+        "char",
+        "varchar",
+        "string")
+    )
+
+    // Seven copies of the same row; with doris.sink.batch.size = 3 this presumably yields
+    // batches of 3, 3 and 1 rows, covering a partially filled final batch (confirm against
+    // the sink's batching logic). An immutable Seq is used because scala.Seq is immutable
+    // in Scala 2.13, where parallelize() would not accept a mutable ListBuffer.
+    val inputList = Seq.fill(7)(row)
+
+    val rdd = spark.sparkContext.parallelize(inputList, 1)
+    val df = spark.createDataFrame(rdd, schema)
+
+    df.write
+      .format("doris")
+      .option("doris.fenodes", dorisFeNodes)
+      .option("user", dorisUser)
+      .option("password", dorisPwd)
+      .option("doris.table.identifier", s"$databaseName.spark_connector_struct")
+      .option("doris.sink.batch.size", 3)
+      .option("doris.sink.properties.format", "arrow")
+      .option("doris.sink.max-retries", 0)
+      .save()
+  }
+}
diff --git a/spark-doris-connector/src/main/scala/org/apache/spark/sql/doris/spark/ArrowSchemaUtils.scala b/spark-doris-connector/src/main/scala/org/apache/spark/sql/doris/spark/ArrowSchemaUtils.scala
new file mode 100644
index 0000000..3ef75a2
--- /dev/null
+++ b/spark-doris-connector/src/main/scala/org/apache/spark/sql/doris/spark/ArrowSchemaUtils.scala
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.spark.sql.doris.spark
+
+import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark.sql.types.StructType
+
+object ArrowSchemaUtils {
+
+  import java.lang.reflect.{InvocationTargetException, Method, Modifier}
+
+  /**
+   * Spark's ArrowUtils class once resolved by [[tryLoadArrowUtilsClass]]; `None` until then.
+   * It is looked up reflectively because its package differs between Spark 2.x and 3.x.
+   */
+  var classArrowUtils: Option[Class[_]] = None
+
+  /** Fully-qualified ArrowUtils class names, in lookup order. */
+  private val arrowUtilsClassNames: Seq[String] = Seq(
+    // for spark3.x
+    "org.apache.spark.sql.util.ArrowUtils",
+    // for spark2.x
+    "org.apache.spark.sql.execution.arrow.ArrowUtils")
+
+  /**
+   * Resolves [[classArrowUtils]] if it has not been resolved yet, trying each known
+   * ArrowUtils class name in order and keeping the first one that loads.
+   *
+   * @throws ClassNotFoundException if none of the known ArrowUtils classes can be loaded
+   */
+  def tryLoadArrowUtilsClass(): Unit = {
+    if (classArrowUtils.isEmpty) {
+      classArrowUtils = arrowUtilsClassNames.iterator
+        .map(tryLoadClass)
+        .collectFirst { case Some(loaded) => loaded }
+      if (classArrowUtils.isEmpty) {
+        throw new ClassNotFoundException("can't load class for ArrowUtils")
+      }
+    }
+  }
+
+  /**
+   * Loads `className` with `Class.forName`.
+   *
+   * @param className fully-qualified class name
+   * @return the class, or `None` if it is not on the classpath
+   */
+  def tryLoadClass(className: String): Option[Class[_]] = {
+    try {
+      Some(Class.forName(className))
+    } catch {
+      // Only "class not found" counts as a miss; other errors (e.g. linkage errors of a
+      // class that exists) are not caught, so their real cause stays visible.
+      case _: ClassNotFoundException => None
+    }
+  }
+
+  /**
+   * Converts a Spark schema to an Arrow schema by delegating to Spark's own ArrowUtils.
+   *
+   * The method is located by name and leading parameter types rather than by one fixed
+   * signature, because newer Spark releases append Boolean parameters to
+   * `ArrowUtils.toArrowSchema` (NOTE(review): 3.4 adds `errorOnDuplicatedFieldNames`,
+   * 3.5 adds `largeVarTypes` — re-check on Spark upgrades).
+   *
+   * @param schema     Spark schema to convert
+   * @param timeZoneId time zone id forwarded to ArrowUtils
+   * @return the Arrow schema produced by ArrowUtils
+   * @throws ClassNotFoundException if no ArrowUtils class can be loaded
+   * @throws NoSuchMethodException  if ArrowUtils has no compatible `toArrowSchema` method
+   */
+  def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
+    tryLoadArrowUtilsClass()
+    val arrowUtils = classArrowUtils.getOrElse(
+      throw new ClassNotFoundException("can't load class for ArrowUtils"))
+    val method = findToArrowSchemaMethod(arrowUtils)
+    // Parameters beyond (schema, timeZoneId) are the Boolean flags of newer Spark versions.
+    val extraArgs = (2 until method.getParameterTypes.length).map(extraBooleanArg)
+    val args: Seq[AnyRef] = Seq[AnyRef](schema, timeZoneId) ++ extraArgs
+    try {
+      method.invoke(null, args: _*).asInstanceOf[Schema]
+    } catch {
+      // Rethrow whatever ArrowUtils itself threw instead of the reflection wrapper.
+      case e: InvocationTargetException if e.getCause != null => throw e.getCause
+    }
+  }
+
+  /**
+   * Finds the static `toArrowSchema(StructType, String, boolean...)` method with the fewest
+   * parameters on the given ArrowUtils class.
+   */
+  private def findToArrowSchemaMethod(arrowUtils: Class[_]): Method = {
+    val candidates = arrowUtils.getMethods.filter { m =>
+      val params = m.getParameterTypes
+      m.getName == "toArrowSchema" &&
+        Modifier.isStatic(m.getModifiers) &&
+        params.length >= 2 &&
+        params(0) == classOf[StructType] &&
+        params(1) == classOf[String] &&
+        params.drop(2).forall(_ == java.lang.Boolean.TYPE)
+    }
+    if (candidates.isEmpty) {
+      throw new NoSuchMethodException(
+        s"${arrowUtils.getName}.toArrowSchema(StructType, String, ...)")
+    }
+    candidates.minBy(_.getParameterTypes.length)
+  }
+
+  /**
+   * Value passed for the parameter at `index` (0-based; only called for index >= 2).
+   *
+   * NOTE(review): assumes index 2 is `errorOnDuplicatedFieldNames` (true: reject duplicate
+   * field names) and any later flag, such as `largeVarTypes`, keeps Spark's default `false`.
+   */
+  private def extraBooleanArg(index: Int): AnyRef =
+    java.lang.Boolean.valueOf(index == 2)
+}
diff --git a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestConnectorWriteDoris.scala b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestConnectorWriteDoris.scala
index fecface..dbea21a 100644
--- a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestConnectorWriteDoris.scala
+++ b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestConnectorWriteDoris.scala
@@ -184,5 +184,4 @@ class TestConnectorWriteDoris {
       .save()
     spark.stop()
   }
-
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to