This is an automated email from the ASF dual-hosted git repository.
diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git
The following commit(s) were added to refs/heads/master by this push:
new 54ac94b [fix](connector) Fixed writing issues in arrow format (#270)
54ac94b is described below
commit 54ac94bc5067588c22391306d082017d63a22d65
Author: gnehil <[email protected]>
AuthorDate: Tue Feb 25 11:14:42 2025 +0800
[fix](connector) Fixed writing issues in arrow format (#270)
---
.../spark/client/read/AbstractThriftReader.java | 2 +-
.../apache/doris/spark/client/read/RowBatch.java | 22 ++++++--
.../client/write/AbstractStreamLoadProcessor.java | 20 ++++---
.../spark/client/write/StreamLoadProcessor.java | 9 ++--
.../org/apache/doris/spark/load/StreamLoader.scala | 2 +
.../apache/doris/spark/util/RowConvertors.scala | 5 +-
.../apache/doris/spark/util/SchemaConvertors.scala | 4 +-
.../doris/spark/client/read/RowBatchTest.java | 61 +++++++++++++++++-----
8 files changed, 94 insertions(+), 31 deletions(-)
diff --git
a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
index 7fdb1cf..f200b38 100644
---
a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
+++
b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
@@ -110,7 +110,7 @@ public abstract class AbstractThriftReader extends
DorisReader {
this.rowBatchQueue = null;
this.asyncThread = null;
}
- this.datetimeJava8ApiEnabled = false;
+ this.datetimeJava8ApiEnabled = partition.getDateTimeJava8APIEnabled();
}
private void runAsync() throws DorisException, InterruptedException {
diff --git
a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
index 759678d..ba61d6a 100644
---
a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
+++
b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
@@ -59,10 +59,12 @@ import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.sql.Date;
+import java.sql.Timestamp;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
+import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.time.format.DateTimeFormatterBuilder;
import java.time.temporal.ChronoField;
@@ -72,6 +74,7 @@ import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
+import java.util.TimeZone;
/**
* row batch data container.
@@ -403,8 +406,15 @@ public class RowBatch implements Serializable {
addValueToRow(rowIndex, null);
continue;
}
- String value = new
String(varCharVector.get(rowIndex), StandardCharsets.UTF_8);
- addValueToRow(rowIndex, value);
+ String stringValue = completeMilliseconds(new
String(varCharVector.get(rowIndex),
+ StandardCharsets.UTF_8));
+ LocalDateTime dateTime =
LocalDateTime.parse(stringValue, dateTimeV2Formatter);
+ if (datetimeJava8ApiEnabled) {
+ Instant instant =
dateTime.atZone(DEFAULT_ZONE_ID).toInstant();
+ addValueToRow(rowIndex, instant);
+ } else {
+ addValueToRow(rowIndex,
Timestamp.valueOf(dateTime));
+ }
}
} else if (curFieldVector instanceof TimeStampVector) {
TimeStampVector timeStampVector =
(TimeStampVector) curFieldVector;
@@ -414,8 +424,12 @@ public class RowBatch implements Serializable {
continue;
}
LocalDateTime dateTime = getDateTime(rowIndex,
timeStampVector);
- String formatted =
DATE_TIME_FORMATTER.format(dateTime);
- addValueToRow(rowIndex, formatted);
+ if (datetimeJava8ApiEnabled) {
+ Instant instant =
dateTime.atZone(DEFAULT_ZONE_ID).toInstant();
+ addValueToRow(rowIndex, instant);
+ } else {
+ addValueToRow(rowIndex,
Timestamp.valueOf(dateTime));
+ }
}
} else {
String errMsg = String.format("Unsupported type
for DATETIMEV2, minorType %s, class is %s",
diff --git
a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/AbstractStreamLoadProcessor.java
b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/AbstractStreamLoadProcessor.java
index 8c9e859..484653e 100644
---
a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/AbstractStreamLoadProcessor.java
+++
b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/AbstractStreamLoadProcessor.java
@@ -48,9 +48,9 @@ import java.io.IOException;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
+import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -106,7 +106,7 @@ public abstract class AbstractStreamLoadProcessor<R>
implements DorisWriter<R>,
private boolean isFirstRecordOfBatch = true;
- private final List<R> recordBuffer = new ArrayList<>();
+ private final List<R> recordBuffer = new LinkedList<>();
private static final int arrowBufferSize = 1000;
@@ -161,6 +161,12 @@ public abstract class AbstractStreamLoadProcessor<R>
implements DorisWriter<R>,
@Override
public String stop() throws Exception {
+ // arrow format needs to send all buffered data before stopping
+ if (!recordBuffer.isEmpty() && "arrow".equalsIgnoreCase(format)) {
+ List<R> rs = new LinkedList<>(recordBuffer);
+ recordBuffer.clear();
+ output.write(toArrowFormat(rs));
+ }
output.close();
CloseableHttpResponse res = requestFuture.get();
if (res.getStatusLine().getStatusCode() != HttpStatus.SC_OK) {
@@ -239,13 +245,13 @@ public abstract class AbstractStreamLoadProcessor<R>
implements DorisWriter<R>,
case "json":
return toStringFormat(row, format);
case "arrow":
- recordBuffer.add(row);
+ recordBuffer.add(copy(row));
if (recordBuffer.size() < arrowBufferSize) {
return new byte[0];
} else {
- R[] dataArray = (R[]) recordBuffer.toArray();
+ LinkedList<R> rs = new LinkedList<>(recordBuffer);
recordBuffer.clear();
- return toArrowFormat(dataArray);
+ return toArrowFormat(rs);
}
default:
throw new IllegalArgumentException("Unsupported stream load
format: " + format);
@@ -263,7 +269,7 @@ public abstract class AbstractStreamLoadProcessor<R>
implements DorisWriter<R>,
public abstract String stringify(R row, String format);
- public abstract byte[] toArrowFormat(R[] rowArray) throws IOException;
+ public abstract byte[] toArrowFormat(List<R> rows) throws IOException;
public abstract String getWriteFields() throws OptionRequiredException;
@@ -364,4 +370,6 @@ public abstract class AbstractStreamLoadProcessor<R>
implements DorisWriter<R>,
}
}
+ protected abstract R copy(R row);
+
}
diff --git
a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/StreamLoadProcessor.java
b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/StreamLoadProcessor.java
index e5ac4fa..2f787a5 100644
---
a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/StreamLoadProcessor.java
+++
b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/write/StreamLoadProcessor.java
@@ -34,6 +34,7 @@ import org.apache.spark.sql.util.ArrowUtils;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
+import java.util.List;
public class StreamLoadProcessor extends
AbstractStreamLoadProcessor<InternalRow> {
@@ -50,7 +51,7 @@ public class StreamLoadProcessor extends
AbstractStreamLoadProcessor<InternalRow
}
@Override
- public byte[] toArrowFormat(InternalRow[] rowArray) throws IOException {
+ public byte[] toArrowFormat(List<InternalRow> rowArray) throws IOException
{
Schema arrowSchema = ArrowUtils.toArrowSchema(schema, "UTC");
VectorSchemaRoot root = VectorSchemaRoot.create(arrowSchema, new
RootAllocator(Integer.MAX_VALUE));
ArrowWriter arrowWriter = ArrowWriter.create(root);
@@ -112,6 +113,8 @@ public class StreamLoadProcessor extends
AbstractStreamLoadProcessor<InternalRow
this.schema = schema;
}
-
-
+ @Override
+ protected InternalRow copy(InternalRow row) {
+ return row.copy();
+ }
}
\ No newline at end of file
diff --git
a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/load/StreamLoader.scala
b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/load/StreamLoader.scala
index 73e8c9b..109fa5b 100644
---
a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/load/StreamLoader.scala
+++
b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/load/StreamLoader.scala
@@ -269,6 +269,8 @@ class StreamLoader(settings: SparkSettings, isStreaming:
Boolean) extends Loader
*/
private def buildLoadRequest(iterator: Iterator[InternalRow], schema:
StructType, label: String): HttpUriRequest = {
+ iterator.next().copy()
+
currentLoadUrl = URLs.streamLoad(getNode, database, table, enableHttps)
val put = new HttpPut(currentLoadUrl)
addCommonHeader(put)
diff --git
a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/RowConvertors.scala
b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/RowConvertors.scala
index 31b7196..b75d1ce 100644
---
a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/RowConvertors.scala
+++
b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/RowConvertors.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import java.sql.{Date, Timestamp}
-import java.time.LocalDate
+import java.time.{Instant, LocalDate}
import scala.collection.JavaConverters.mapAsScalaMapConverter
import scala.collection.mutable
@@ -110,7 +110,8 @@ object RowConvertors {
def convertValue(v: Any, dataType: DataType, datetimeJava8ApiEnabled:
Boolean): Any = {
dataType match {
case StringType => UTF8String.fromString(v.asInstanceOf[String])
- case TimestampType =>
DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(v.asInstanceOf[String]))
+ case TimestampType if datetimeJava8ApiEnabled =>
DateTimeUtils.instantToMicros(v.asInstanceOf[Instant])
+ case TimestampType =>
DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp])
case DateType if datetimeJava8ApiEnabled =>
v.asInstanceOf[LocalDate].toEpochDay.toInt
case DateType => DateTimeUtils.fromJavaDate(v.asInstanceOf[Date])
case _: MapType =>
diff --git
a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
index 303aa1f..e694083 100644
---
a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
+++
b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
@@ -37,8 +37,8 @@ object SchemaConvertors {
case "DOUBLE" => DataTypes.DoubleType
case "DATE" => DataTypes.DateType
case "DATEV2" => DataTypes.DateType
- case "DATETIME" => DataTypes.StringType
- case "DATETIMEV2" => DataTypes.StringType
+ case "DATETIME" => DataTypes.TimestampType
+ case "DATETIMEV2" => DataTypes.TimestampType
case "BINARY" => DataTypes.BinaryType
case "DECIMAL" => DecimalType(precision, scale)
case "CHAR" => DataTypes.StringType
diff --git
a/spark-doris-connector/spark-doris-connector-base/src/test/java/org/apache/doris/spark/client/read/RowBatchTest.java
b/spark-doris-connector/spark-doris-connector-base/src/test/java/org/apache/doris/spark/client/read/RowBatchTest.java
index 05123b4..acc7712 100644
---
a/spark-doris-connector/spark-doris-connector-base/src/test/java/org/apache/doris/spark/client/read/RowBatchTest.java
+++
b/spark-doris-connector/spark-doris-connector-base/src/test/java/org/apache/doris/spark/client/read/RowBatchTest.java
@@ -56,12 +56,10 @@ import org.apache.doris.sdk.thrift.TStatusCode;
import org.apache.doris.spark.exception.DorisException;
import org.apache.doris.spark.rest.RestService;
import org.apache.doris.spark.rest.models.Schema;
-import org.apache.spark.sql.internal.SQLConf;
-import org.apache.spark.sql.internal.SQLConf$;
import org.apache.spark.sql.types.Decimal;
import static org.hamcrest.core.StringStartsWith.startsWith;
import org.junit.Assert;
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@@ -75,6 +73,7 @@ import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.sql.Date;
+import java.sql.Timestamp;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
@@ -275,7 +274,7 @@ public class RowBatchTest {
(float) 1.1,
(double) 1.1,
Date.valueOf("2008-08-08"),
- "2008-08-08 00:00:00",
+ Timestamp.valueOf("2008-08-08 00:00:00"),
Decimal.apply(1234L, 4, 2),
"char1"
);
@@ -289,7 +288,7 @@ public class RowBatchTest {
(float) 2.2,
(double) 2.2,
Date.valueOf("1900-08-08"),
- "1900-08-08 00:00:00",
+ Timestamp.valueOf("1900-08-08 00:00:00"),
Decimal.apply(8888L, 4, 2),
"char2"
);
@@ -303,7 +302,7 @@ public class RowBatchTest {
(float) 3.3,
(double) 3.3,
Date.valueOf("2100-08-08"),
- "2100-08-08 00:00:00",
+ Timestamp.valueOf("2100-08-08 00:00:00"),
Decimal.apply(10L, 2, 0),
"char3"
);
@@ -831,16 +830,16 @@ public class RowBatchTest {
Assert.assertTrue(rowBatch.hasNext());
List<Object> actualRow0 = rowBatch.next();
- Assert.assertEquals("2024-03-20 00:00:00", actualRow0.get(0));
- Assert.assertEquals("2024-03-20 00:00:00", actualRow0.get(1));
+ Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:00"),
actualRow0.get(0));
+ Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:00"),
actualRow0.get(1));
List<Object> actualRow1 = rowBatch.next();
- Assert.assertEquals("2024-03-20 00:00:01", actualRow1.get(0));
- Assert.assertEquals("2024-03-20 00:00:00.123", actualRow1.get(1));
+ Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:01"),
actualRow1.get(0));
+ Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:00.123"),
actualRow1.get(1));
List<Object> actualRow2 = rowBatch.next();
- Assert.assertEquals("2024-03-20 00:00:02", actualRow2.get(0));
- Assert.assertEquals("2024-03-20 00:00:00.123456", actualRow2.get(1));
+ Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:02"),
actualRow2.get(0));
+ Assert.assertEquals(Timestamp.valueOf("2024-03-20 00:00:00.123456"),
actualRow2.get(1));
Assert.assertFalse(rowBatch.hasNext());
@@ -1169,6 +1168,10 @@ public class RowBatchTest {
ImmutableList.Builder<Field> childrenBuilder = ImmutableList.builder();
childrenBuilder.add(new Field("k0", FieldType.nullable(new
ArrowType.Utf8()), null));
childrenBuilder.add(new Field("k1", FieldType.nullable(new
ArrowType.Date(DateUnit.DAY)), null));
+ childrenBuilder.add(new Field("k2", FieldType.nullable(new
ArrowType.Timestamp(TimeUnit.MICROSECOND,
+ null)), null));
+ childrenBuilder.add(new Field("k3", FieldType.nullable(new
ArrowType.Timestamp(TimeUnit.MICROSECOND,
+ null)), null));
VectorSchemaRoot root = VectorSchemaRoot.create(
new
org.apache.arrow.vector.types.pojo.Schema(childrenBuilder.build(), null),
@@ -1202,6 +1205,32 @@ public class RowBatchTest {
date2Vector.setSafe(0, (int) date);
vector.setValueCount(1);
+ LocalDateTime localDateTime = LocalDateTime.of(2025, 2, 24,
+ 0, 0, 0, 123000000);
+ long second =
localDateTime.atZone(ZoneId.systemDefault()).toEpochSecond();
+ int nano = localDateTime.getNano();
+
+ vector = root.getVector("k2");
+ TimeStampMicroVector datetimeV2Vector = (TimeStampMicroVector) vector;
+ datetimeV2Vector.setInitialCapacity(1);
+ datetimeV2Vector.allocateNew();
+ datetimeV2Vector.setIndexDefined(0);
+ datetimeV2Vector.setSafe(0, second * 1000000 + nano / 1000);
+ vector.setValueCount(1);
+
+ LocalDateTime localDateTime1 = LocalDateTime.of(2025, 2, 24,
+ 1, 2, 3, 123456000);
+ long second1 =
localDateTime1.atZone(ZoneId.systemDefault()).toEpochSecond();
+ int nano1 = localDateTime1.getNano();
+
+ vector = root.getVector("k3");
+ TimeStampMicroVector datetimeV2Vector1 = (TimeStampMicroVector) vector;
+ datetimeV2Vector1.setInitialCapacity(1);
+ datetimeV2Vector1.allocateNew();
+ datetimeV2Vector1.setIndexDefined(0);
+ datetimeV2Vector1.setSafe(0, second1 * 1000000 + nano1 / 1000);
+ vector.setValueCount(1);
+
arrowStreamWriter.writeBatch();
arrowStreamWriter.end();
@@ -1217,7 +1246,9 @@ public class RowBatchTest {
String schemaStr = "{\"properties\":[" +
"{\"type\":\"DATE\",\"name\":\"k0\",\"comment\":\"\"}, " +
- "{\"type\":\"DATEV2\",\"name\":\"k1\",\"comment\":\"\"}" +
+ "{\"type\":\"DATEV2\",\"name\":\"k1\",\"comment\":\"\"}," +
+ "{\"type\":\"DATETIME\",\"name\":\"k2\",\"comment\":\"\"}," +
+ "{\"type\":\"DATETIMEV2\",\"name\":\"k3\",\"comment\":\"\"}" +
"], \"status\":200}";
Schema schema = RestService.parseSchema(schemaStr, logger);
@@ -1228,6 +1259,8 @@ public class RowBatchTest {
List<Object> actualRow0 = rowBatch1.next();
Assert.assertEquals(Date.valueOf("2025-01-01"), actualRow0.get(0));
Assert.assertEquals(Date.valueOf("2025-02-01"), actualRow0.get(1));
+ Assert.assertEquals(Timestamp.valueOf("2025-02-24 00:00:00.123"),
actualRow0.get(2));
+ Assert.assertEquals(Timestamp.valueOf("2025-02-24 01:02:03.123456"),
actualRow0.get(3));
Assert.assertFalse(rowBatch1.hasNext());
@@ -1237,6 +1270,8 @@ public class RowBatchTest {
List<Object> actualRow01 = rowBatch2.next();
Assert.assertEquals(LocalDate.of(2025,1,1), actualRow01.get(0));
Assert.assertEquals(localDate, actualRow01.get(1));
+
Assert.assertEquals(localDateTime.atZone(ZoneId.systemDefault()).toInstant(),
actualRow01.get(2));
+
Assert.assertEquals(localDateTime1.atZone(ZoneId.systemDefault()).toInstant(),
actualRow01.get(3));
Assert.assertFalse(rowBatch2.hasNext());
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]