This is an automated email from the ASF dual-hosted git repository.

tyrantlucifer pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git


The following commit(s) were added to refs/heads/dev by this push:
     new 05cf84fb17 [Improve][Spark] Convert array type to exact type (#7758)
05cf84fb17 is described below

commit 05cf84fb17c49e0f3a3e0db8b091bc8e3f7cc9b0
Author: corgy-w <[email protected]>
AuthorDate: Sat Sep 28 14:47:00 2024 +0800

    [Improve][Spark] Convert array type to exact type (#7758)
---
 .../spark/serialization/InternalRowConverter.java  | 18 ++++--
 .../spark/execution/MultiTableManagerTest.java     | 72 +++++++++++-----------
 2 files changed, 49 insertions(+), 41 deletions(-)

diff --git a/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/main/java/org/apache/seatunnel/translation/spark/serialization/InternalRowConverter.java b/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/main/java/org/apache/seatunnel/translation/spark/serialization/InternalRowConverter.java
index 8dbca1975a..7e2ee4ef87 100644
--- a/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/main/java/org/apache/seatunnel/translation/spark/serialization/InternalRowConverter.java
+++ b/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/main/java/org/apache/seatunnel/translation/spark/serialization/InternalRowConverter.java
@@ -53,6 +53,7 @@ import scala.collection.immutable.HashMap.HashTrieMap;
 import scala.collection.mutable.WrappedArray;
 
 import java.io.IOException;
+import java.lang.reflect.Array;
 import java.math.BigDecimal;
 import java.sql.Date;
 import java.sql.Timestamp;
@@ -107,17 +108,23 @@ public final class InternalRowConverter extends RowConverter<InternalRow> {
             case DECIMAL:
                 return Decimal.apply((BigDecimal) field);
             case ARRAY:
+                Class<?> elementTypeClass =
+                        ((ArrayType<?, ?>) 
dataType).getElementType().getTypeClass();
                 // if string array, we need to covert every item in array from 
String to UTF8String
                 if (((ArrayType<?, ?>) 
dataType).getElementType().equals(BasicType.STRING_TYPE)) {
                     Object[] fields = (Object[]) field;
-                    Object[] objects =
+                    UTF8String[] objects =
                             Arrays.stream(fields)
                                     .map(v -> UTF8String.fromString((String) 
v))
-                                    .toArray();
+                                    .toArray(UTF8String[]::new);
                     return ArrayData.toArrayData(objects);
                 }
                 // except string, now only support convert boolean int tinyint 
smallint bigint float
                 // double, because SeaTunnel Array only support these types
+                Object array = Array.newInstance(elementTypeClass, ((Object[]) 
field).length);
+                for (int i = 0; i < ((Object[]) field).length; i++) {
+                    Array.set(array, i, ((Object[]) field)[i]);
+                }
                 return ArrayData.toArrayData(field);
             default:
                 if (field instanceof scala.Some) {
@@ -339,14 +346,17 @@ public final class InternalRowConverter extends RowConverter<InternalRow> {
     }
 
     private static Object reconvertArray(ArrayData arrayData, ArrayType<?, ?> 
arrayType) {
+        Class<?> elementTypeClass = arrayType.getElementType().getTypeClass();
         if (arrayData == null || arrayData.numElements() == 0) {
             return Collections.emptyList().toArray();
         }
-        Object[] newArray = new Object[arrayData.numElements()];
+        Object[] newArray = (Object[]) Array.newInstance(elementTypeClass, 
arrayData.numElements());
         Object[] values =
                 
arrayData.toObjectArray(TypeConverterUtils.convert(arrayType.getElementType()));
         for (int i = 0; i < arrayData.numElements(); i++) {
-            newArray[i] = reconvert(values[i], arrayType.getElementType());
+            Object reconvert =
+                    elementTypeClass.cast(reconvert(values[i], 
arrayType.getElementType()));
+            newArray[i] = reconvert;
         }
         return newArray;
     }
diff --git a/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/test/java/org/apache/seatunnel/translation/spark/execution/MultiTableManagerTest.java b/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/test/java/org/apache/seatunnel/translation/spark/execution/MultiTableManagerTest.java
index ad5fdfbea5..24dd23148a 100644
--- a/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/test/java/org/apache/seatunnel/translation/spark/execution/MultiTableManagerTest.java
+++ b/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/test/java/org/apache/seatunnel/translation/spark/execution/MultiTableManagerTest.java
@@ -648,14 +648,14 @@ public class MultiTableManagerTest {
                             LocalDate.parse("2001-01-01"),
                             LocalDateTime.parse("2031-01-01T00:00:00"),
                             null,
-                            new Object[] {"string1fsa", "stringdsa2", 
"strfdsaing3"},
-                            new Object[] {false, true, true},
-                            new Object[] {(byte) 6, (byte) 2, (byte) 1},
-                            new Object[] {(short) 7, (short) 8, (short) 9},
-                            new Object[] {3, 77, 22},
-                            new Object[] {143L, 642L, 533L},
-                            new Object[] {24.1f, 54.2f, 1.3f},
-                            new Object[] {431.11, 2422.22, 3243.33},
+                            new String[] {"string1fsa", "stringdsa2", 
"strfdsaing3"},
+                            new Boolean[] {false, true, true},
+                            new Byte[] {(byte) 6, (byte) 2, (byte) 1},
+                            new Short[] {(short) 7, (short) 8, (short) 9},
+                            new Integer[] {3, 77, 22},
+                            new Long[] {143L, 642L, 533L},
+                            new Float[] {24.1f, 54.2f, 1.3f},
+                            new Double[] {431.11, 2422.22, 3243.33},
                             new HashMap<String, String>() {
                                 {
                                     put("keyfs1", "valfdsue1");
@@ -680,14 +680,14 @@ public class MultiTableManagerTest {
                             LocalDate.parse("2001-01-01"),
                             LocalDateTime.parse("2031-01-01T00:00:00"),
                             null,
-                            new Object[] {"string1fsa", "stringdsa2", 
"strfdsaing3"},
-                            new Object[] {false, true, true},
-                            new Object[] {(byte) 6, (byte) 2, (byte) 1},
-                            new Object[] {(short) 7, (short) 8, (short) 9},
-                            new Object[] {3, 77, 22},
-                            new Object[] {143L, 642L, 533L},
-                            new Object[] {24.1f, 54.2f, 1.3f},
-                            new Object[] {431.11, 2422.22, 3243.33},
+                            new String[] {"string1fsa", "stringdsa2", 
"strfdsaing3"},
+                            new Boolean[] {false, true, true},
+                            new Byte[] {(byte) 6, (byte) 2, (byte) 1},
+                            new Short[] {(short) 7, (short) 8, (short) 9},
+                            new Integer[] {3, 77, 22},
+                            new Long[] {143L, 642L, 533L},
+                            new Float[] {24.1f, 54.2f, 1.3f},
+                            new Double[] {431.11, 2422.22, 3243.33},
                             new HashMap<String, String>() {
                                 {
                                     put("keyfs1", "valfdsue1");
@@ -736,37 +736,36 @@ public class MultiTableManagerTest {
                         }));
 
         mutableValues[13] = new MutableAny();
-        mutableValues[13].update(ArrayData.toArrayData(new Object[] {false, 
true, true}));
+        mutableValues[13].update(ArrayData.toArrayData(new Boolean[] {false, 
true, true}));
 
         mutableValues[14] = new MutableAny();
-        mutableValues[14].update(
-                ArrayData.toArrayData(new Object[] {(byte) 6, (byte) 2, (byte) 
1}));
+        mutableValues[14].update(ArrayData.toArrayData(new Byte[] {(byte) 6, 
(byte) 2, (byte) 1}));
 
         mutableValues[15] = new MutableAny();
         mutableValues[15].update(
-                ArrayData.toArrayData(new Object[] {(short) 7, (short) 8, 
(short) 9}));
+                ArrayData.toArrayData(new Short[] {(short) 7, (short) 8, 
(short) 9}));
 
         mutableValues[16] = new MutableAny();
-        mutableValues[16].update(ArrayData.toArrayData(new Object[] {3, 77, 
22}));
+        mutableValues[16].update(ArrayData.toArrayData(new Integer[] {3, 77, 
22}));
 
         mutableValues[17] = new MutableAny();
-        mutableValues[17].update(ArrayData.toArrayData(new Object[] {143L, 
642L, 533L}));
+        mutableValues[17].update(ArrayData.toArrayData(new Long[] {143L, 642L, 
533L}));
 
         mutableValues[18] = new MutableAny();
-        mutableValues[18].update(ArrayData.toArrayData(new Object[] {24.1f, 
54.2f, 1.3f}));
+        mutableValues[18].update(ArrayData.toArrayData(new Float[] {24.1f, 
54.2f, 1.3f}));
 
         mutableValues[19] = new MutableAny();
-        mutableValues[19].update(ArrayData.toArrayData(new Object[] {431.11, 
2422.22, 3243.33}));
+        mutableValues[19].update(ArrayData.toArrayData(new Double[] {431.11, 
2422.22, 3243.33}));
 
         mutableValues[20] = new MutableAny();
         mutableValues[20].update(
                 ArrayBasedMapData.apply(
-                        new Object[] {
+                        new UTF8String[] {
                             UTF8String.fromString("kefdsay3"),
                             UTF8String.fromString("keyfs1"),
                             UTF8String.fromString("kedfasy2")
                         },
-                        new Object[] {
+                        new UTF8String[] {
                             UTF8String.fromString("vfdasalue3"),
                             UTF8String.fromString("valfdsue1"),
                             UTF8String.fromString("vafdslue2")
@@ -808,44 +807,43 @@ public class MultiTableManagerTest {
         mutableValues1[14] = new MutableAny();
         mutableValues1[14].update(
                 ArrayData.toArrayData(
-                        new Object[] {
+                        new UTF8String[] {
                             UTF8String.fromString("string1fsa"),
                             UTF8String.fromString("stringdsa2"),
                             UTF8String.fromString("strfdsaing3")
                         }));
 
         mutableValues1[15] = new MutableAny();
-        mutableValues1[15].update(ArrayData.toArrayData(new Object[] {false, 
true, true}));
+        mutableValues1[15].update(ArrayData.toArrayData(new Boolean[] {false, 
true, true}));
 
         mutableValues1[16] = new MutableAny();
-        mutableValues1[16].update(
-                ArrayData.toArrayData(new Object[] {(byte) 6, (byte) 2, (byte) 
1}));
+        mutableValues1[16].update(ArrayData.toArrayData(new Byte[] {(byte) 6, 
(byte) 2, (byte) 1}));
 
         mutableValues1[17] = new MutableAny();
         mutableValues1[17].update(
-                ArrayData.toArrayData(new Object[] {(short) 7, (short) 8, 
(short) 9}));
+                ArrayData.toArrayData(new Short[] {(short) 7, (short) 8, 
(short) 9}));
 
         mutableValues1[18] = new MutableAny();
-        mutableValues1[18].update(ArrayData.toArrayData(new Object[] {3, 77, 
22}));
+        mutableValues1[18].update(ArrayData.toArrayData(new Integer[] {3, 77, 
22}));
 
         mutableValues1[19] = new MutableAny();
-        mutableValues1[19].update(ArrayData.toArrayData(new Object[] {143L, 
642L, 533L}));
+        mutableValues1[19].update(ArrayData.toArrayData(new Long[] {143L, 
642L, 533L}));
 
         mutableValues1[20] = new MutableAny();
-        mutableValues1[20].update(ArrayData.toArrayData(new Object[] {24.1f, 
54.2f, 1.3f}));
+        mutableValues1[20].update(ArrayData.toArrayData(new Float[] {24.1f, 
54.2f, 1.3f}));
 
         mutableValues1[21] = new MutableAny();
-        mutableValues1[21].update(ArrayData.toArrayData(new Object[] {431.11, 
2422.22, 3243.33}));
+        mutableValues1[21].update(ArrayData.toArrayData(new Double[] {431.11, 
2422.22, 3243.33}));
 
         mutableValues1[22] = new MutableAny();
         mutableValues1[22].update(
                 ArrayBasedMapData.apply(
-                        new Object[] {
+                        new UTF8String[] {
                             UTF8String.fromString("kefdsay3"),
                             UTF8String.fromString("keyfs1"),
                             UTF8String.fromString("kedfasy2")
                         },
-                        new Object[] {
+                        new UTF8String[] {
                             UTF8String.fromString("vfdasalue3"),
                             UTF8String.fromString("valfdsue1"),
                             UTF8String.fromString("vafdslue2")

Reply via email to