This is an automated email from the ASF dual-hosted git repository.
tyrantlucifer pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git
The following commit(s) were added to refs/heads/dev by this push:
new 05cf84fb17 [Improve][Spark] Convert array type to exact type (#7758)
05cf84fb17 is described below
commit 05cf84fb17c49e0f3a3e0db8b091bc8e3f7cc9b0
Author: corgy-w <[email protected]>
AuthorDate: Sat Sep 28 14:47:00 2024 +0800
[Improve][Spark] Convert array type to exact type (#7758)
---
.../spark/serialization/InternalRowConverter.java | 18 ++++--
.../spark/execution/MultiTableManagerTest.java | 72 +++++++++++-----------
2 files changed, 49 insertions(+), 41 deletions(-)
diff --git
a/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/main/java/org/apache/seatunnel/translation/spark/serialization/InternalRowConverter.java
b/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/main/java/org/apache/seatunnel/translation/spark/serialization/InternalRowConverter.java
index 8dbca1975a..7e2ee4ef87 100644
---
a/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/main/java/org/apache/seatunnel/translation/spark/serialization/InternalRowConverter.java
+++
b/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/main/java/org/apache/seatunnel/translation/spark/serialization/InternalRowConverter.java
@@ -53,6 +53,7 @@ import scala.collection.immutable.HashMap.HashTrieMap;
import scala.collection.mutable.WrappedArray;
import java.io.IOException;
+import java.lang.reflect.Array;
import java.math.BigDecimal;
import java.sql.Date;
import java.sql.Timestamp;
@@ -107,17 +108,23 @@ public final class InternalRowConverter extends
RowConverter<InternalRow> {
case DECIMAL:
return Decimal.apply((BigDecimal) field);
case ARRAY:
+ Class<?> elementTypeClass =
+ ((ArrayType<?, ?>)
dataType).getElementType().getTypeClass();
// if string array, we need to covert every item in array from
String to UTF8String
if (((ArrayType<?, ?>)
dataType).getElementType().equals(BasicType.STRING_TYPE)) {
Object[] fields = (Object[]) field;
- Object[] objects =
+ UTF8String[] objects =
Arrays.stream(fields)
.map(v -> UTF8String.fromString((String)
v))
- .toArray();
+ .toArray(UTF8String[]::new);
return ArrayData.toArrayData(objects);
}
// except string, now only support convert boolean int tinyint
smallint bigint float
// double, because SeaTunnel Array only support these types
+ Object array = Array.newInstance(elementTypeClass, ((Object[])
field).length);
+ for (int i = 0; i < ((Object[]) field).length; i++) {
+ Array.set(array, i, ((Object[]) field)[i]);
+ }
return ArrayData.toArrayData(field);
default:
if (field instanceof scala.Some) {
@@ -339,14 +346,17 @@ public final class InternalRowConverter extends
RowConverter<InternalRow> {
}
private static Object reconvertArray(ArrayData arrayData, ArrayType<?, ?>
arrayType) {
+ Class<?> elementTypeClass = arrayType.getElementType().getTypeClass();
if (arrayData == null || arrayData.numElements() == 0) {
return Collections.emptyList().toArray();
}
- Object[] newArray = new Object[arrayData.numElements()];
+ Object[] newArray = (Object[]) Array.newInstance(elementTypeClass,
arrayData.numElements());
Object[] values =
arrayData.toObjectArray(TypeConverterUtils.convert(arrayType.getElementType()));
for (int i = 0; i < arrayData.numElements(); i++) {
- newArray[i] = reconvert(values[i], arrayType.getElementType());
+ Object reconvert =
+ elementTypeClass.cast(reconvert(values[i],
arrayType.getElementType()));
+ newArray[i] = reconvert;
}
return newArray;
}
diff --git
a/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/test/java/org/apache/seatunnel/translation/spark/execution/MultiTableManagerTest.java
b/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/test/java/org/apache/seatunnel/translation/spark/execution/MultiTableManagerTest.java
index ad5fdfbea5..24dd23148a 100644
---
a/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/test/java/org/apache/seatunnel/translation/spark/execution/MultiTableManagerTest.java
+++
b/seatunnel-translation/seatunnel-translation-spark/seatunnel-translation-spark-common/src/test/java/org/apache/seatunnel/translation/spark/execution/MultiTableManagerTest.java
@@ -648,14 +648,14 @@ public class MultiTableManagerTest {
LocalDate.parse("2001-01-01"),
LocalDateTime.parse("2031-01-01T00:00:00"),
null,
- new Object[] {"string1fsa", "stringdsa2",
"strfdsaing3"},
- new Object[] {false, true, true},
- new Object[] {(byte) 6, (byte) 2, (byte) 1},
- new Object[] {(short) 7, (short) 8, (short) 9},
- new Object[] {3, 77, 22},
- new Object[] {143L, 642L, 533L},
- new Object[] {24.1f, 54.2f, 1.3f},
- new Object[] {431.11, 2422.22, 3243.33},
+ new String[] {"string1fsa", "stringdsa2",
"strfdsaing3"},
+ new Boolean[] {false, true, true},
+ new Byte[] {(byte) 6, (byte) 2, (byte) 1},
+ new Short[] {(short) 7, (short) 8, (short) 9},
+ new Integer[] {3, 77, 22},
+ new Long[] {143L, 642L, 533L},
+ new Float[] {24.1f, 54.2f, 1.3f},
+ new Double[] {431.11, 2422.22, 3243.33},
new HashMap<String, String>() {
{
put("keyfs1", "valfdsue1");
@@ -680,14 +680,14 @@ public class MultiTableManagerTest {
LocalDate.parse("2001-01-01"),
LocalDateTime.parse("2031-01-01T00:00:00"),
null,
- new Object[] {"string1fsa", "stringdsa2",
"strfdsaing3"},
- new Object[] {false, true, true},
- new Object[] {(byte) 6, (byte) 2, (byte) 1},
- new Object[] {(short) 7, (short) 8, (short) 9},
- new Object[] {3, 77, 22},
- new Object[] {143L, 642L, 533L},
- new Object[] {24.1f, 54.2f, 1.3f},
- new Object[] {431.11, 2422.22, 3243.33},
+ new String[] {"string1fsa", "stringdsa2",
"strfdsaing3"},
+ new Boolean[] {false, true, true},
+ new Byte[] {(byte) 6, (byte) 2, (byte) 1},
+ new Short[] {(short) 7, (short) 8, (short) 9},
+ new Integer[] {3, 77, 22},
+ new Long[] {143L, 642L, 533L},
+ new Float[] {24.1f, 54.2f, 1.3f},
+ new Double[] {431.11, 2422.22, 3243.33},
new HashMap<String, String>() {
{
put("keyfs1", "valfdsue1");
@@ -736,37 +736,36 @@ public class MultiTableManagerTest {
}));
mutableValues[13] = new MutableAny();
- mutableValues[13].update(ArrayData.toArrayData(new Object[] {false,
true, true}));
+ mutableValues[13].update(ArrayData.toArrayData(new Boolean[] {false,
true, true}));
mutableValues[14] = new MutableAny();
- mutableValues[14].update(
- ArrayData.toArrayData(new Object[] {(byte) 6, (byte) 2, (byte)
1}));
+ mutableValues[14].update(ArrayData.toArrayData(new Byte[] {(byte) 6,
(byte) 2, (byte) 1}));
mutableValues[15] = new MutableAny();
mutableValues[15].update(
- ArrayData.toArrayData(new Object[] {(short) 7, (short) 8,
(short) 9}));
+ ArrayData.toArrayData(new Short[] {(short) 7, (short) 8,
(short) 9}));
mutableValues[16] = new MutableAny();
- mutableValues[16].update(ArrayData.toArrayData(new Object[] {3, 77,
22}));
+ mutableValues[16].update(ArrayData.toArrayData(new Integer[] {3, 77,
22}));
mutableValues[17] = new MutableAny();
- mutableValues[17].update(ArrayData.toArrayData(new Object[] {143L,
642L, 533L}));
+ mutableValues[17].update(ArrayData.toArrayData(new Long[] {143L, 642L,
533L}));
mutableValues[18] = new MutableAny();
- mutableValues[18].update(ArrayData.toArrayData(new Object[] {24.1f,
54.2f, 1.3f}));
+ mutableValues[18].update(ArrayData.toArrayData(new Float[] {24.1f,
54.2f, 1.3f}));
mutableValues[19] = new MutableAny();
- mutableValues[19].update(ArrayData.toArrayData(new Object[] {431.11,
2422.22, 3243.33}));
+ mutableValues[19].update(ArrayData.toArrayData(new Double[] {431.11,
2422.22, 3243.33}));
mutableValues[20] = new MutableAny();
mutableValues[20].update(
ArrayBasedMapData.apply(
- new Object[] {
+ new UTF8String[] {
UTF8String.fromString("kefdsay3"),
UTF8String.fromString("keyfs1"),
UTF8String.fromString("kedfasy2")
},
- new Object[] {
+ new UTF8String[] {
UTF8String.fromString("vfdasalue3"),
UTF8String.fromString("valfdsue1"),
UTF8String.fromString("vafdslue2")
@@ -808,44 +807,43 @@ public class MultiTableManagerTest {
mutableValues1[14] = new MutableAny();
mutableValues1[14].update(
ArrayData.toArrayData(
- new Object[] {
+ new UTF8String[] {
UTF8String.fromString("string1fsa"),
UTF8String.fromString("stringdsa2"),
UTF8String.fromString("strfdsaing3")
}));
mutableValues1[15] = new MutableAny();
- mutableValues1[15].update(ArrayData.toArrayData(new Object[] {false,
true, true}));
+ mutableValues1[15].update(ArrayData.toArrayData(new Boolean[] {false,
true, true}));
mutableValues1[16] = new MutableAny();
- mutableValues1[16].update(
- ArrayData.toArrayData(new Object[] {(byte) 6, (byte) 2, (byte)
1}));
+ mutableValues1[16].update(ArrayData.toArrayData(new Byte[] {(byte) 6,
(byte) 2, (byte) 1}));
mutableValues1[17] = new MutableAny();
mutableValues1[17].update(
- ArrayData.toArrayData(new Object[] {(short) 7, (short) 8,
(short) 9}));
+ ArrayData.toArrayData(new Short[] {(short) 7, (short) 8,
(short) 9}));
mutableValues1[18] = new MutableAny();
- mutableValues1[18].update(ArrayData.toArrayData(new Object[] {3, 77,
22}));
+ mutableValues1[18].update(ArrayData.toArrayData(new Integer[] {3, 77,
22}));
mutableValues1[19] = new MutableAny();
- mutableValues1[19].update(ArrayData.toArrayData(new Object[] {143L,
642L, 533L}));
+ mutableValues1[19].update(ArrayData.toArrayData(new Long[] {143L,
642L, 533L}));
mutableValues1[20] = new MutableAny();
- mutableValues1[20].update(ArrayData.toArrayData(new Object[] {24.1f,
54.2f, 1.3f}));
+ mutableValues1[20].update(ArrayData.toArrayData(new Float[] {24.1f,
54.2f, 1.3f}));
mutableValues1[21] = new MutableAny();
- mutableValues1[21].update(ArrayData.toArrayData(new Object[] {431.11,
2422.22, 3243.33}));
+ mutableValues1[21].update(ArrayData.toArrayData(new Double[] {431.11,
2422.22, 3243.33}));
mutableValues1[22] = new MutableAny();
mutableValues1[22].update(
ArrayBasedMapData.apply(
- new Object[] {
+ new UTF8String[] {
UTF8String.fromString("kefdsay3"),
UTF8String.fromString("keyfs1"),
UTF8String.fromString("kedfasy2")
},
- new Object[] {
+ new UTF8String[] {
UTF8String.fromString("vfdasalue3"),
UTF8String.fromString("valfdsue1"),
UTF8String.fromString("vafdslue2")