This is an automated email from the ASF dual-hosted git repository.
diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git
The following commit(s) were added to refs/heads/master by this push:
     new da673a1  convert date/datev2 to sql date and largeint into decimal(38,0) (#125)
da673a1 is described below
commit da673a134204bbff3300f6cb0901134f63da7adf
Author: gnehil <[email protected]>
AuthorDate: Mon Aug 14 15:24:30 2023 +0800
convert date/datev2 to sql date and largeint into decimal(38,0) (#125)
---
.../apache/doris/spark/serialization/RowBatch.java | 53 +++++-
.../org/apache/doris/spark/sql/SchemaUtils.scala | 6 +-
.../doris/spark/serialization/TestRowBatch.java | 192 +++++++++++++++++++--
3 files changed, 227 insertions(+), 24 deletions(-)
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
index 371c491..faa8ef5 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
@@ -20,7 +20,12 @@ package org.apache.doris.spark.serialization;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.math.BigDecimal;
+import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
+import java.sql.Date;
+import java.time.LocalDate;
+import java.time.LocalDateTime;
+import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
@@ -30,6 +35,7 @@ import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.FixedSizeBinaryVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
@@ -41,9 +47,12 @@ import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.types.Types;
+
import org.apache.doris.sdk.thrift.TScanBatchResult;
import org.apache.doris.spark.exception.DorisException;
import org.apache.doris.spark.rest.models.Schema;
+
+import org.apache.commons.lang3.ArrayUtils;
import org.apache.spark.sql.types.Decimal;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -98,7 +107,7 @@ public class RowBatch {
fieldVectors.size(), schema.size());
                throw new DorisException("Load Doris data failed, schema size of fetch data is wrong.");
}
- if (fieldVectors.size() == 0 || root.getRowCount() == 0) {
+ if (fieldVectors.isEmpty() || root.getRowCount() == 0) {
logger.debug("One batch in arrow has no data.");
continue;
}
@@ -190,6 +199,34 @@ public class RowBatch {
addValueToRow(rowIndex, fieldValue);
}
break;
+                    case "LARGEINT":
+                        Preconditions.checkArgument(mt.equals(Types.MinorType.FIXEDSIZEBINARY) ||
+                                mt.equals(Types.MinorType.VARCHAR), typeMismatchMessage(currentType, mt));
+                        if (mt.equals(Types.MinorType.FIXEDSIZEBINARY)) {
+                            FixedSizeBinaryVector largeIntVector = (FixedSizeBinaryVector) curFieldVector;
+                            for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                                if (largeIntVector.isNull(rowIndex)) {
+                                    addValueToRow(rowIndex, null);
+                                    continue;
+                                }
+                                byte[] bytes = largeIntVector.get(rowIndex);
+                                ArrayUtils.reverse(bytes);
+                                BigInteger largeInt = new BigInteger(bytes);
+                                addValueToRow(rowIndex, Decimal.apply(largeInt));
+                            }
+                        } else {
+                            VarCharVector largeIntVector = (VarCharVector) curFieldVector;
+                            for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                                if (largeIntVector.isNull(rowIndex)) {
+                                    addValueToRow(rowIndex, null);
+                                    continue;
+                                }
+                                String stringValue = new String(largeIntVector.get(rowIndex));
+                                BigInteger largeInt = new BigInteger(stringValue);
+                                addValueToRow(rowIndex, Decimal.apply(largeInt));
+                            }
+                        }
+                        break;
case "FLOAT":
                        Preconditions.checkArgument(mt.equals(Types.MinorType.FLOAT4),
                                typeMismatchMessage(currentType, mt));
@@ -257,9 +294,21 @@ public class RowBatch {
break;
case "DATE":
case "DATEV2":
+                        Preconditions.checkArgument(mt.equals(Types.MinorType.VARCHAR),
+                                typeMismatchMessage(currentType, mt));
+                        VarCharVector date = (VarCharVector) curFieldVector;
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                            if (date.isNull(rowIndex)) {
+                                addValueToRow(rowIndex, null);
+                                continue;
+                            }
+                            String stringValue = new String(date.get(rowIndex));
+                            LocalDate localDate = LocalDate.parse(stringValue);
+                            addValueToRow(rowIndex, Date.valueOf(localDate));
+                        }
+                        break;
case "DATETIME":
case "DATETIMEV2":
- case "LARGEINT":
case "CHAR":
case "VARCHAR":
case "STRING":
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
index c7fad41..c8aa034 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
@@ -102,14 +102,14 @@ private[spark] object SchemaUtils {
case "BIGINT" => DataTypes.LongType
case "FLOAT" => DataTypes.FloatType
case "DOUBLE" => DataTypes.DoubleType
- case "DATE" => DataTypes.StringType
- case "DATEV2" => DataTypes.StringType
+ case "DATE" => DataTypes.DateType
+ case "DATEV2" => DataTypes.DateType
case "DATETIME" => DataTypes.StringType
case "DATETIMEV2" => DataTypes.StringType
case "BINARY" => DataTypes.BinaryType
case "DECIMAL" => DecimalType(precision, scale)
case "CHAR" => DataTypes.StringType
- case "LARGEINT" => DataTypes.StringType
+ case "LARGEINT" => DecimalType(38,0)
case "VARCHAR" => DataTypes.StringType
case "JSONB" => DataTypes.StringType
case "DECIMALV2" => DecimalType(precision, scale)
diff --git a/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java b/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
index ceeeb32..ace928f 100644
--- a/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
+++ b/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
@@ -17,19 +17,20 @@
package org.apache.doris.spark.serialization;
-import static org.hamcrest.core.StringStartsWith.startsWith;
-
-import java.io.ByteArrayOutputStream;
-import java.math.BigDecimal;
-import java.util.Arrays;
-import java.util.List;
-import java.util.NoSuchElementException;
+import org.apache.doris.sdk.thrift.TScanBatchResult;
+import org.apache.doris.sdk.thrift.TStatus;
+import org.apache.doris.sdk.thrift.TStatusCode;
+import org.apache.doris.spark.exception.DorisException;
+import org.apache.doris.spark.rest.RestService;
+import org.apache.doris.spark.rest.models.Schema;
+import com.google.common.collect.ImmutableList;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.FixedSizeBinaryVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
@@ -44,11 +45,7 @@ import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
-import org.apache.doris.sdk.thrift.TScanBatchResult;
-import org.apache.doris.sdk.thrift.TStatus;
-import org.apache.doris.sdk.thrift.TStatusCode;
-import org.apache.doris.spark.rest.RestService;
-import org.apache.doris.spark.rest.models.Schema;
+import org.apache.commons.lang3.ArrayUtils;
import org.apache.spark.sql.types.Decimal;
import org.junit.Assert;
import org.junit.Rule;
@@ -57,7 +54,16 @@ import org.junit.rules.ExpectedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import com.google.common.collect.ImmutableList;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.sql.Date;
+import java.util.Arrays;
+import java.util.List;
+import java.util.NoSuchElementException;
+
+import static org.hamcrest.core.StringStartsWith.startsWith;
public class TestRowBatch {
    private final static Logger logger = LoggerFactory.getLogger(TestRowBatch.class);
@@ -250,7 +256,7 @@ public class TestRowBatch {
1L,
(float) 1.1,
(double) 1.1,
- "2008-08-08",
+ Date.valueOf("2008-08-08"),
"2008-08-08 00:00:00",
Decimal.apply(1234L, 4, 2),
"char1"
@@ -264,7 +270,7 @@ public class TestRowBatch {
2L,
(float) 2.2,
(double) 2.2,
- "1900-08-08",
+ Date.valueOf("1900-08-08"),
"1900-08-08 00:00:00",
Decimal.apply(8888L, 4, 2),
"char2"
@@ -278,7 +284,7 @@ public class TestRowBatch {
3L,
(float) 3.3,
(double) 3.3,
- "2100-08-08",
+ Date.valueOf("2100-08-08"),
"2100-08-08 00:00:00",
Decimal.apply(10L, 2, 0),
"char3"
@@ -286,15 +292,15 @@ public class TestRowBatch {
Assert.assertTrue(rowBatch.hasNext());
List<Object> actualRow1 = rowBatch.next();
- Assert.assertEquals(expectedRow1, actualRow1);
+ Assert.assertArrayEquals(expectedRow1.toArray(), actualRow1.toArray());
Assert.assertTrue(rowBatch.hasNext());
List<Object> actualRow2 = rowBatch.next();
- Assert.assertEquals(expectedRow2, actualRow2);
+ Assert.assertArrayEquals(expectedRow2.toArray(), actualRow2.toArray());
Assert.assertTrue(rowBatch.hasNext());
List<Object> actualRow3 = rowBatch.next();
- Assert.assertEquals(expectedRow3, actualRow3);
+ Assert.assertArrayEquals(expectedRow3.toArray(), actualRow3.toArray());
Assert.assertFalse(rowBatch.hasNext());
thrown.expect(NoSuchElementException.class);
@@ -437,4 +443,152 @@ public class TestRowBatch {
thrown.expectMessage(startsWith("Get row offset:"));
rowBatch.next();
}
+
+ @Test
+ public void testDate() throws DorisException, IOException {
+
+ ImmutableList.Builder<Field> childrenBuilder = ImmutableList.builder();
+        childrenBuilder.add(new Field("k1", FieldType.nullable(new ArrowType.Utf8()), null));
+        childrenBuilder.add(new Field("k2", FieldType.nullable(new ArrowType.Utf8()), null));
+
+        VectorSchemaRoot root = VectorSchemaRoot.create(
+                new org.apache.arrow.vector.types.pojo.Schema(childrenBuilder.build(), null),
+                new RootAllocator(Integer.MAX_VALUE));
+ ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+ ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(
+ root,
+ new DictionaryProvider.MapDictionaryProvider(),
+ outputStream);
+
+ arrowStreamWriter.start();
+ root.setRowCount(1);
+
+ FieldVector vector = root.getVector("k1");
+ VarCharVector dateVector = (VarCharVector)vector;
+ dateVector.setInitialCapacity(1);
+ dateVector.allocateNew();
+ dateVector.setIndexDefined(0);
+ dateVector.setValueLengthSafe(0, 10);
+ dateVector.setSafe(0, "2023-08-09".getBytes());
+ vector.setValueCount(1);
+
+
+ vector = root.getVector("k2");
+ VarCharVector dateV2Vector = (VarCharVector)vector;
+ dateV2Vector.setInitialCapacity(1);
+ dateV2Vector.allocateNew();
+ dateV2Vector.setIndexDefined(0);
+ dateV2Vector.setValueLengthSafe(0, 10);
+ dateV2Vector.setSafe(0, "2023-08-10".getBytes());
+ vector.setValueCount(1);
+
+ arrowStreamWriter.writeBatch();
+
+ arrowStreamWriter.end();
+ arrowStreamWriter.close();
+
+ TStatus status = new TStatus();
+ status.setStatusCode(TStatusCode.OK);
+ TScanBatchResult scanBatchResult = new TScanBatchResult();
+ scanBatchResult.setStatus(status);
+ scanBatchResult.setEos(false);
+ scanBatchResult.setRows(outputStream.toByteArray());
+
+
+ String schemaStr = "{\"properties\":[" +
+ "{\"type\":\"DATE\",\"name\":\"k1\",\"comment\":\"\"}, " +
+ "{\"type\":\"DATEV2\",\"name\":\"k2\",\"comment\":\"\"}" +
+ "], \"status\":200}";
+
+ Schema schema = RestService.parseSchema(schemaStr, logger);
+
+ RowBatch rowBatch = new RowBatch(scanBatchResult, schema);
+
+ Assert.assertTrue(rowBatch.hasNext());
+ List<Object> actualRow0 = rowBatch.next();
+ Assert.assertEquals(Date.valueOf("2023-08-09"), actualRow0.get(0));
+ Assert.assertEquals(Date.valueOf("2023-08-10"), actualRow0.get(1));
+
+ Assert.assertFalse(rowBatch.hasNext());
+ thrown.expect(NoSuchElementException.class);
+ thrown.expectMessage(startsWith("Get row offset:"));
+ rowBatch.next();
+
+ }
+
+ @Test
+ public void testLargeInt() throws DorisException, IOException {
+
+ ImmutableList.Builder<Field> childrenBuilder = ImmutableList.builder();
+        childrenBuilder.add(new Field("k1", FieldType.nullable(new ArrowType.Utf8()), null));
+        childrenBuilder.add(new Field("k2", FieldType.nullable(new ArrowType.FixedSizeBinary(16)), null));
+
+        VectorSchemaRoot root = VectorSchemaRoot.create(
+                new org.apache.arrow.vector.types.pojo.Schema(childrenBuilder.build(), null),
+                new RootAllocator(Integer.MAX_VALUE));
+ ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+ ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(
+ root,
+ new DictionaryProvider.MapDictionaryProvider(),
+ outputStream);
+
+ arrowStreamWriter.start();
+ root.setRowCount(1);
+
+ FieldVector vector = root.getVector("k1");
+        VarCharVector largeIntVector = (VarCharVector) vector;
+        largeIntVector.setInitialCapacity(1);
+        largeIntVector.allocateNew();
+        largeIntVector.setIndexDefined(0);
+        largeIntVector.setValueLengthSafe(0, 19);
+        largeIntVector.setSafe(0, "9223372036854775808".getBytes());
+ vector.setValueCount(1);
+
+
+ vector = root.getVector("k2");
+        FixedSizeBinaryVector largeIntVector1 = (FixedSizeBinaryVector) vector;
+        largeIntVector1.setInitialCapacity(1);
+        largeIntVector1.allocateNew();
+        largeIntVector1.setIndexDefined(0);
+        byte[] bytes = new BigInteger("9223372036854775809").toByteArray();
+        byte[] fixedBytes = new byte[16];
+        System.arraycopy(bytes, 0, fixedBytes, 16 - bytes.length, bytes.length);
+        ArrayUtils.reverse(fixedBytes);
+        largeIntVector1.setSafe(0, fixedBytes);
+ vector.setValueCount(1);
+
+ arrowStreamWriter.writeBatch();
+
+ arrowStreamWriter.end();
+ arrowStreamWriter.close();
+
+ TStatus status = new TStatus();
+ status.setStatusCode(TStatusCode.OK);
+ TScanBatchResult scanBatchResult = new TScanBatchResult();
+ scanBatchResult.setStatus(status);
+ scanBatchResult.setEos(false);
+ scanBatchResult.setRows(outputStream.toByteArray());
+
+ String schemaStr = "{\"properties\":[" +
+ "{\"type\":\"LARGEINT\",\"name\":\"k1\",\"comment\":\"\"}, " +
+ "{\"type\":\"LARGEINT\",\"name\":\"k2\",\"comment\":\"\"}" +
+ "], \"status\":200}";
+
+ Schema schema = RestService.parseSchema(schemaStr, logger);
+
+ RowBatch rowBatch = new RowBatch(scanBatchResult, schema);
+
+ Assert.assertTrue(rowBatch.hasNext());
+ List<Object> actualRow0 = rowBatch.next();
+
+        Assert.assertEquals(Decimal.apply(new BigInteger("9223372036854775808")), actualRow0.get(0));
+        Assert.assertEquals(Decimal.apply(new BigInteger("9223372036854775809")), actualRow0.get(1));
+
+ Assert.assertFalse(rowBatch.hasNext());
+ thrown.expect(NoSuchElementException.class);
+ thrown.expectMessage(startsWith("Get row offset:"));
+ rowBatch.next();
+
+ }
+
}
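
The new testLargeInt builds its FixedSizeBinary fixture by inverting the
RowBatch decode: BigInteger.toByteArray() yields big-endian bytes, which are
right-aligned into a 16-byte buffer and then reversed into the little-endian
layout Doris uses. A standalone sketch of that encoding, assuming the value
fits in 128 bits (illustrative names; negative values additionally need the
0xFF sign fill shown):

    import java.math.BigInteger;
    import java.util.Arrays;

    public class LargeIntEncodeSketch {

        static byte[] encodeLittleEndian(BigInteger value) {
            byte[] bytes = value.toByteArray();   // big-endian two's complement
            byte[] fixed = new byte[16];
            if (value.signum() < 0) {
                Arrays.fill(fixed, (byte) 0xFF);  // sign-extend negative values
            }
            System.arraycopy(bytes, 0, fixed, 16 - bytes.length, bytes.length);
            // reverse in place into the little-endian layout Doris uses
            for (int i = 0, j = fixed.length - 1; i < j; i++, j--) {
                byte tmp = fixed[i];
                fixed[i] = fixed[j];
                fixed[j] = tmp;
            }
            return fixed;
        }

        public static void main(String[] args) {
            byte[] le = encodeLittleEndian(new BigInteger("9223372036854775809"));
            System.out.println(le.length); // 16, ready for setSafe(0, le)
        }
    }
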