This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git


The following commit(s) were added to refs/heads/master by this push:
     new da673a1  convert date/datev2 to sql date and largeint into decimal(38,0) (#125)
da673a1 is described below

commit da673a134204bbff3300f6cb0901134f63da7adf
Author: gnehil <[email protected]>
AuthorDate: Mon Aug 14 15:24:30 2023 +0800

    convert date/datev2 to sql date and largeint into decimal(38,0) (#125)
---
 .../apache/doris/spark/serialization/RowBatch.java |  53 +++++-
 .../org/apache/doris/spark/sql/SchemaUtils.scala   |   6 +-
 .../doris/spark/serialization/TestRowBatch.java    | 192 +++++++++++++++++++--
 3 files changed, 227 insertions(+), 24 deletions(-)

diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
index 371c491..faa8ef5 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
@@ -20,7 +20,12 @@ package org.apache.doris.spark.serialization;
 import java.io.ByteArrayInputStream;
 import java.io.IOException;
 import java.math.BigDecimal;
+import java.math.BigInteger;
 import java.nio.charset.StandardCharsets;
+import java.sql.Date;
+import java.time.LocalDate;
+import java.time.LocalDateTime;
+import java.time.format.DateTimeFormatter;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.NoSuchElementException;
@@ -30,6 +35,7 @@ import org.apache.arrow.vector.BigIntVector;
 import org.apache.arrow.vector.BitVector;
 import org.apache.arrow.vector.DecimalVector;
 import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.FixedSizeBinaryVector;
 import org.apache.arrow.vector.Float4Vector;
 import org.apache.arrow.vector.Float8Vector;
 import org.apache.arrow.vector.IntVector;
@@ -41,9 +47,12 @@ import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.complex.ListVector;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
 import org.apache.arrow.vector.types.Types;
+
 import org.apache.doris.sdk.thrift.TScanBatchResult;
 import org.apache.doris.spark.exception.DorisException;
 import org.apache.doris.spark.rest.models.Schema;
+
+import org.apache.commons.lang3.ArrayUtils;
 import org.apache.spark.sql.types.Decimal;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -98,7 +107,7 @@ public class RowBatch {
                             fieldVectors.size(), schema.size());
                     throw new DorisException("Load Doris data failed, schema 
size of fetch data is wrong.");
                 }
-                if (fieldVectors.size() == 0 || root.getRowCount() == 0) {
+                if (fieldVectors.isEmpty() || root.getRowCount() == 0) {
                     logger.debug("One batch in arrow has no data.");
                     continue;
                 }
@@ -190,6 +199,34 @@ public class RowBatch {
                             addValueToRow(rowIndex, fieldValue);
                         }
                         break;
+                    case "LARGEINT":
+                        
Preconditions.checkArgument(mt.equals(Types.MinorType.FIXEDSIZEBINARY) ||
+                                mt.equals(Types.MinorType.VARCHAR), 
typeMismatchMessage(currentType, mt));
+                        if (mt.equals(Types.MinorType.FIXEDSIZEBINARY)) {
+                            FixedSizeBinaryVector largeIntVector = 
(FixedSizeBinaryVector) curFieldVector;
+                            for (int rowIndex = 0; rowIndex < 
rowCountInOneBatch; rowIndex++) {
+                                if (largeIntVector.isNull(rowIndex)) {
+                                    addValueToRow(rowIndex, null);
+                                    continue;
+                                }
+                                byte[] bytes = largeIntVector.get(rowIndex);
+                                ArrayUtils.reverse(bytes);
+                                BigInteger largeInt = new BigInteger(bytes);
+                                addValueToRow(rowIndex, 
Decimal.apply(largeInt));
+                            }
+                        } else {
+                            VarCharVector largeIntVector = (VarCharVector) 
curFieldVector;
+                            for (int rowIndex = 0; rowIndex < 
rowCountInOneBatch; rowIndex++) {
+                                if (largeIntVector.isNull(rowIndex)) {
+                                    addValueToRow(rowIndex, null);
+                                    continue;
+                                }
+                                String stringValue = new 
String(largeIntVector.get(rowIndex));
+                                BigInteger largeInt = new 
BigInteger(stringValue);
+                                addValueToRow(rowIndex, 
Decimal.apply(largeInt));
+                            }
+                        }
+                        break;
                     case "FLOAT":
                         
Preconditions.checkArgument(mt.equals(Types.MinorType.FLOAT4),
                                 typeMismatchMessage(currentType, mt));
@@ -257,9 +294,21 @@ public class RowBatch {
                         break;
                     case "DATE":
                     case "DATEV2":
+                        
Preconditions.checkArgument(mt.equals(Types.MinorType.VARCHAR),
+                                typeMismatchMessage(currentType, mt));
+                        VarCharVector date = (VarCharVector) curFieldVector;
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; 
rowIndex++) {
+                            if (date.isNull(rowIndex)) {
+                                addValueToRow(rowIndex, null);
+                                continue;
+                            }
+                            String stringValue = new 
String(date.get(rowIndex));
+                            LocalDate localDate = LocalDate.parse(stringValue);
+                            addValueToRow(rowIndex, Date.valueOf(localDate));
+                        }
+                        break;
                     case "DATETIME":
                     case "DATETIMEV2":
-                    case "LARGEINT":
                     case "CHAR":
                     case "VARCHAR":
                     case "STRING":
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
index c7fad41..c8aa034 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
@@ -102,14 +102,14 @@ private[spark] object SchemaUtils {
       case "BIGINT"          => DataTypes.LongType
       case "FLOAT"           => DataTypes.FloatType
       case "DOUBLE"          => DataTypes.DoubleType
-      case "DATE"            => DataTypes.StringType
-      case "DATEV2"          => DataTypes.StringType
+      case "DATE"            => DataTypes.DateType
+      case "DATEV2"          => DataTypes.DateType
       case "DATETIME"        => DataTypes.StringType
       case "DATETIMEV2"      => DataTypes.StringType
       case "BINARY"          => DataTypes.BinaryType
       case "DECIMAL"         => DecimalType(precision, scale)
       case "CHAR"            => DataTypes.StringType
-      case "LARGEINT"        => DataTypes.StringType
+      case "LARGEINT"        => DecimalType(38,0)
       case "VARCHAR"         => DataTypes.StringType
       case "JSONB"           => DataTypes.StringType
       case "DECIMALV2"       => DecimalType(precision, scale)
diff --git a/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java b/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
index ceeeb32..ace928f 100644
--- a/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
+++ b/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
@@ -17,19 +17,20 @@
 
 package org.apache.doris.spark.serialization;
 
-import static org.hamcrest.core.StringStartsWith.startsWith;
-
-import java.io.ByteArrayOutputStream;
-import java.math.BigDecimal;
-import java.util.Arrays;
-import java.util.List;
-import java.util.NoSuchElementException;
+import org.apache.doris.sdk.thrift.TScanBatchResult;
+import org.apache.doris.sdk.thrift.TStatus;
+import org.apache.doris.sdk.thrift.TStatusCode;
+import org.apache.doris.spark.exception.DorisException;
+import org.apache.doris.spark.rest.RestService;
+import org.apache.doris.spark.rest.models.Schema;
 
+import com.google.common.collect.ImmutableList;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.BigIntVector;
 import org.apache.arrow.vector.BitVector;
 import org.apache.arrow.vector.DecimalVector;
 import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.FixedSizeBinaryVector;
 import org.apache.arrow.vector.Float4Vector;
 import org.apache.arrow.vector.Float8Vector;
 import org.apache.arrow.vector.IntVector;
@@ -44,11 +45,7 @@ import org.apache.arrow.vector.types.FloatingPointPrecision;
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.FieldType;
-import org.apache.doris.sdk.thrift.TScanBatchResult;
-import org.apache.doris.sdk.thrift.TStatus;
-import org.apache.doris.sdk.thrift.TStatusCode;
-import org.apache.doris.spark.rest.RestService;
-import org.apache.doris.spark.rest.models.Schema;
+import org.apache.commons.lang3.ArrayUtils;
 import org.apache.spark.sql.types.Decimal;
 import org.junit.Assert;
 import org.junit.Rule;
@@ -57,7 +54,16 @@ import org.junit.rules.ExpectedException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import com.google.common.collect.ImmutableList;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.sql.Date;
+import java.util.Arrays;
+import java.util.List;
+import java.util.NoSuchElementException;
+
+import static org.hamcrest.core.StringStartsWith.startsWith;
 
 public class TestRowBatch {
     private final static Logger logger = 
LoggerFactory.getLogger(TestRowBatch.class);
@@ -250,7 +256,7 @@ public class TestRowBatch {
                 1L,
                 (float) 1.1,
                 (double) 1.1,
-                "2008-08-08",
+                Date.valueOf("2008-08-08"),
                 "2008-08-08 00:00:00",
                 Decimal.apply(1234L, 4, 2),
                 "char1"
@@ -264,7 +270,7 @@ public class TestRowBatch {
                 2L,
                 (float) 2.2,
                 (double) 2.2,
-                "1900-08-08",
+                Date.valueOf("1900-08-08"),
                 "1900-08-08 00:00:00",
                 Decimal.apply(8888L, 4, 2),
                 "char2"
@@ -278,7 +284,7 @@ public class TestRowBatch {
                 3L,
                 (float) 3.3,
                 (double) 3.3,
-                "2100-08-08",
+                Date.valueOf("2100-08-08"),
                 "2100-08-08 00:00:00",
                 Decimal.apply(10L, 2, 0),
                 "char3"
@@ -286,15 +292,15 @@ public class TestRowBatch {
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow1 = rowBatch.next();
-        Assert.assertEquals(expectedRow1, actualRow1);
+        Assert.assertArrayEquals(expectedRow1.toArray(), actualRow1.toArray());
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow2 = rowBatch.next();
-        Assert.assertEquals(expectedRow2, actualRow2);
+        Assert.assertArrayEquals(expectedRow2.toArray(), actualRow2.toArray());
 
         Assert.assertTrue(rowBatch.hasNext());
         List<Object> actualRow3 = rowBatch.next();
-        Assert.assertEquals(expectedRow3, actualRow3);
+        Assert.assertArrayEquals(expectedRow3.toArray(), actualRow3.toArray());
 
         Assert.assertFalse(rowBatch.hasNext());
         thrown.expect(NoSuchElementException.class);
@@ -437,4 +443,152 @@ public class TestRowBatch {
         thrown.expectMessage(startsWith("Get row offset:"));
         rowBatch.next();
     }
+
+    @Test
+    public void testDate() throws DorisException, IOException {
+
+        ImmutableList.Builder<Field> childrenBuilder = ImmutableList.builder();
+        childrenBuilder.add(new Field("k1", FieldType.nullable(new 
ArrowType.Utf8()), null));
+        childrenBuilder.add(new Field("k2", FieldType.nullable(new 
ArrowType.Utf8()), null));
+
+        VectorSchemaRoot root = VectorSchemaRoot.create(
+                new 
org.apache.arrow.vector.types.pojo.Schema(childrenBuilder.build(), null),
+                new RootAllocator(Integer.MAX_VALUE));
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(
+                root,
+                new DictionaryProvider.MapDictionaryProvider(),
+                outputStream);
+
+        arrowStreamWriter.start();
+        root.setRowCount(1);
+
+        FieldVector vector = root.getVector("k1");
+        VarCharVector dateVector = (VarCharVector)vector;
+        dateVector.setInitialCapacity(1);
+        dateVector.allocateNew();
+        dateVector.setIndexDefined(0);
+        dateVector.setValueLengthSafe(0, 10);
+        dateVector.setSafe(0, "2023-08-09".getBytes());
+        vector.setValueCount(1);
+
+
+        vector = root.getVector("k2");
+        VarCharVector dateV2Vector = (VarCharVector)vector;
+        dateV2Vector.setInitialCapacity(1);
+        dateV2Vector.allocateNew();
+        dateV2Vector.setIndexDefined(0);
+        dateV2Vector.setValueLengthSafe(0, 10);
+        dateV2Vector.setSafe(0, "2023-08-10".getBytes());
+        vector.setValueCount(1);
+
+        arrowStreamWriter.writeBatch();
+
+        arrowStreamWriter.end();
+        arrowStreamWriter.close();
+
+        TStatus status = new TStatus();
+        status.setStatusCode(TStatusCode.OK);
+        TScanBatchResult scanBatchResult = new TScanBatchResult();
+        scanBatchResult.setStatus(status);
+        scanBatchResult.setEos(false);
+        scanBatchResult.setRows(outputStream.toByteArray());
+
+
+        String schemaStr = "{\"properties\":[" +
+                "{\"type\":\"DATE\",\"name\":\"k1\",\"comment\":\"\"}, " +
+                "{\"type\":\"DATEV2\",\"name\":\"k2\",\"comment\":\"\"}" +
+                "], \"status\":200}";
+
+        Schema schema = RestService.parseSchema(schemaStr, logger);
+
+        RowBatch rowBatch = new RowBatch(scanBatchResult, schema);
+
+        Assert.assertTrue(rowBatch.hasNext());
+        List<Object> actualRow0 = rowBatch.next();
+        Assert.assertEquals(Date.valueOf("2023-08-09"), actualRow0.get(0));
+        Assert.assertEquals(Date.valueOf("2023-08-10"), actualRow0.get(1));
+
+        Assert.assertFalse(rowBatch.hasNext());
+        thrown.expect(NoSuchElementException.class);
+        thrown.expectMessage(startsWith("Get row offset:"));
+        rowBatch.next();
+
+    }
+
+    @Test
+    public void testLargeInt() throws DorisException, IOException {
+
+        ImmutableList.Builder<Field> childrenBuilder = ImmutableList.builder();
+        childrenBuilder.add(new Field("k1", FieldType.nullable(new 
ArrowType.Utf8()), null));
+        childrenBuilder.add(new Field("k2", FieldType.nullable(new 
ArrowType.FixedSizeBinary(16)), null));
+
+        VectorSchemaRoot root = VectorSchemaRoot.create(
+                new 
org.apache.arrow.vector.types.pojo.Schema(childrenBuilder.build(), null),
+                new RootAllocator(Integer.MAX_VALUE));
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(
+                root,
+                new DictionaryProvider.MapDictionaryProvider(),
+                outputStream);
+
+        arrowStreamWriter.start();
+        root.setRowCount(1);
+
+        FieldVector vector = root.getVector("k1");
+        VarCharVector lageIntVector = (VarCharVector)vector;
+        lageIntVector.setInitialCapacity(1);
+        lageIntVector.allocateNew();
+        lageIntVector.setIndexDefined(0);
+        lageIntVector.setValueLengthSafe(0, 19);
+        lageIntVector.setSafe(0, "9223372036854775808".getBytes());
+        vector.setValueCount(1);
+
+
+        vector = root.getVector("k2");
+        FixedSizeBinaryVector lageIntVector1 = (FixedSizeBinaryVector)vector;
+        lageIntVector1.setInitialCapacity(1);
+        lageIntVector1.allocateNew();
+        lageIntVector1.setIndexDefined(0);
+        byte[] bytes = new BigInteger("9223372036854775809").toByteArray();
+        byte[] fixedBytes = new byte[16];
+        System.arraycopy(bytes, 0, fixedBytes, 16 - bytes.length, 
bytes.length);
+        ArrayUtils.reverse(fixedBytes);
+        lageIntVector1.setSafe(0, fixedBytes);
+        vector.setValueCount(1);
+
+        arrowStreamWriter.writeBatch();
+
+        arrowStreamWriter.end();
+        arrowStreamWriter.close();
+
+        TStatus status = new TStatus();
+        status.setStatusCode(TStatusCode.OK);
+        TScanBatchResult scanBatchResult = new TScanBatchResult();
+        scanBatchResult.setStatus(status);
+        scanBatchResult.setEos(false);
+        scanBatchResult.setRows(outputStream.toByteArray());
+
+        String schemaStr = "{\"properties\":[" +
+                "{\"type\":\"LARGEINT\",\"name\":\"k1\",\"comment\":\"\"}, " +
+                "{\"type\":\"LARGEINT\",\"name\":\"k2\",\"comment\":\"\"}" +
+                "], \"status\":200}";
+
+        Schema schema = RestService.parseSchema(schemaStr, logger);
+
+        RowBatch rowBatch = new RowBatch(scanBatchResult, schema);
+
+        Assert.assertTrue(rowBatch.hasNext());
+        List<Object> actualRow0 = rowBatch.next();
+
+        Assert.assertEquals(Decimal.apply(new 
BigInteger("9223372036854775808")), actualRow0.get(0));
+        Assert.assertEquals(Decimal.apply(new 
BigInteger("9223372036854775809")), actualRow0.get(1));
+
+        Assert.assertFalse(rowBatch.hasNext());
+        thrown.expect(NoSuchElementException.class);
+        thrown.expectMessage(startsWith("Get row offset:"));
+        rowBatch.next();
+
+    }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to