This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new dc7c03ec5a [spark] AbstractSparkInternalRow supports fallback equals and hashCode (#5101)
dc7c03ec5a is described below
commit dc7c03ec5a94b00f0107cb8d89e53bad6cc03b55
Author: Xiduo You <[email protected]>
AuthorDate: Sat Feb 22 22:00:38 2025 +0800
[spark] AbstractSparkInternalRow supports fallback equals and hashCode (#5101)
---
.../java/org/apache/paimon/data/GenericMap.java | 4 +
.../org/apache/paimon/utils/InternalRowUtils.java | 131 +++++++++++++++++++++
.../apache/paimon/utils/InternalRowUtilsTest.java | 78 ++++++++++++
.../paimon/spark/AbstractSparkInternalRow.java | 18 ++-
4 files changed, 229 insertions(+), 2 deletions(-)
diff --git a/paimon-common/src/main/java/org/apache/paimon/data/GenericMap.java
b/paimon-common/src/main/java/org/apache/paimon/data/GenericMap.java
index 0b196c0757..0e07e80a5f 100644
--- a/paimon-common/src/main/java/org/apache/paimon/data/GenericMap.java
+++ b/paimon-common/src/main/java/org/apache/paimon/data/GenericMap.java
@@ -64,6 +64,10 @@ public final class GenericMap implements InternalMap,
Serializable {
return map.get(key);
}
+    /**
+     * Returns whether the wrapped Java map has an entry for the given key.
+     *
+     * @param key the key to look up
+     * @return true if an entry exists for {@code key}
+     */
+    public boolean contains(Object key) {
+        return this.map.containsKey(key);
+    }
+
@Override
public int size() {
return map.size();
diff --git
a/paimon-common/src/main/java/org/apache/paimon/utils/InternalRowUtils.java
b/paimon-common/src/main/java/org/apache/paimon/utils/InternalRowUtils.java
index bd46bae631..b052690f0f 100644
--- a/paimon-common/src/main/java/org/apache/paimon/utils/InternalRowUtils.java
+++ b/paimon-common/src/main/java/org/apache/paimon/utils/InternalRowUtils.java
@@ -48,6 +48,7 @@ import javax.annotation.Nullable;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -55,6 +56,127 @@ import java.util.Map;
/** Utils for {@link InternalRow} structures. */
public class InternalRowUtils {
+    /**
+     * Compares two values of the given {@link DataType} for logical equality.
+     *
+     * <p>Rows, arrays and maps are compared element by element via {@code get(...)}, so two
+     * containers with different physical representations can compare equal. {@code byte[]}
+     * values are compared by content; all other values use {@link Object#equals}, which for
+     * {@link Float} and {@link Double} already treats NaN as equal to NaN.
+     *
+     * <p>NOTE(review): map keys are looked up with their own {@code equals}/{@code hashCode},
+     * so maps keyed by {@code byte[]} only match when the key instances are identical.
+     *
+     * @param data1 the first value, may be null
+     * @param data2 the second value, may be null
+     * @param dataType the logical type shared by both values
+     * @return true if both values are null, or both are non-null and logically equal
+     */
+    public static boolean equals(Object data1, Object data2, DataType dataType) {
+        if (data1 == null || data2 == null) {
+            return data1 == data2;
+        }
+        if (data1 instanceof InternalRow) {
+            RowType rowType = (RowType) dataType;
+            InternalRow row1 = (InternalRow) data1;
+            InternalRow row2 = (InternalRow) data2;
+            for (int i = 0; i < rowType.getFieldCount(); i++) {
+                DataType fieldType = rowType.getTypeAt(i);
+                if (!equals(get(row1, i, fieldType), get(row2, i, fieldType), fieldType)) {
+                    return false;
+                }
+            }
+            return true;
+        }
+        if (data1 instanceof InternalArray) {
+            InternalArray array1 = (InternalArray) data1;
+            InternalArray array2 = (InternalArray) data2;
+            if (array1.size() != array2.size()) {
+                return false;
+            }
+            DataType elementType = ((ArrayType) dataType).getElementType();
+            for (int i = 0; i < array1.size(); i++) {
+                Object value1 = get(array1, i, elementType);
+                Object value2 = get(array2, i, elementType);
+                if (!equals(value1, value2, elementType)) {
+                    return false;
+                }
+            }
+            return true;
+        }
+        if (data1 instanceof InternalMap) {
+            InternalMap internalMap1 = (InternalMap) data1;
+            InternalMap internalMap2 = (InternalMap) data2;
+            if (internalMap1.size() != internalMap2.size()) {
+                return false;
+            }
+            // MULTISET values are InternalMaps too, mapping each element to an INT count.
+            DataType keyType;
+            DataType valueType;
+            if (dataType instanceof MultisetType) {
+                keyType = ((MultisetType) dataType).getElementType();
+                valueType = new IntType();
+            } else {
+                keyType = ((MapType) dataType).getKeyType();
+                valueType = ((MapType) dataType).getValueType();
+            }
+            // Convert each side independently: data2 may use a different representation
+            // (e.g. a BinaryMap) than data1, so it must not be cast based on data1's class.
+            GenericMap map1 =
+                    internalMap1 instanceof GenericMap
+                            ? (GenericMap) internalMap1
+                            : copyToGenericMap(internalMap1, keyType, valueType);
+            GenericMap map2 =
+                    internalMap2 instanceof GenericMap
+                            ? (GenericMap) internalMap2
+                            : copyToGenericMap(internalMap2, keyType, valueType);
+            InternalArray keys1 = map1.keyArray();
+            for (int i = 0; i < map1.size(); i++) {
+                Object key = get(keys1, i, keyType);
+                if (!map2.contains(key) || !equals(map1.get(key), map2.get(key), valueType)) {
+                    return false;
+                }
+            }
+            return true;
+        }
+        if (data1 instanceof byte[]) {
+            return Arrays.equals((byte[]) data1, (byte[]) data2);
+        }
+        return data1.equals(data2);
+    }
+
+    /**
+     * Computes a hash code for a value of the given {@link DataType}, intended to be consistent
+     * with {@link #equals(Object, Object, DataType)}.
+     *
+     * <p>Rows and arrays combine their element hashes in position order. Map entries are
+     * combined with an order-independent sum, because two equal maps may enumerate their keys
+     * in different orders (e.g. hash maps built with different insertion histories).
+     *
+     * @param data the value to hash, may be null
+     * @param dataType the logical type of the value
+     * @return the hash code, 0 for null
+     */
+    public static int hash(Object data, DataType dataType) {
+        if (data == null) {
+            return 0;
+        }
+        if (data instanceof InternalRow) {
+            RowType rowType = (RowType) dataType;
+            InternalRow row = (InternalRow) data;
+            int result = 0;
+            for (int i = 0; i < rowType.getFieldCount(); i++) {
+                DataType fieldType = rowType.getTypeAt(i);
+                result = 37 * result + hash(get(row, i, fieldType), fieldType);
+            }
+            return result;
+        }
+        if (data instanceof InternalArray) {
+            InternalArray array = (InternalArray) data;
+            DataType elementType = ((ArrayType) dataType).getElementType();
+            int result = 0;
+            for (int i = 0; i < array.size(); i++) {
+                result = 37 * result + hash(get(array, i, elementType), elementType);
+            }
+            return result;
+        }
+        if (data instanceof InternalMap) {
+            // MULTISET values are InternalMaps too, mapping each element to an INT count.
+            DataType keyType;
+            DataType valueType;
+            if (dataType instanceof MultisetType) {
+                keyType = ((MultisetType) dataType).getElementType();
+                valueType = new IntType();
+            } else {
+                keyType = ((MapType) dataType).getKeyType();
+                valueType = ((MapType) dataType).getValueType();
+            }
+            GenericMap map =
+                    data instanceof GenericMap
+                            ? (GenericMap) data
+                            : copyToGenericMap((InternalMap) data, keyType, valueType);
+            InternalArray keys = map.keyArray();
+            int result = 0;
+            for (int i = 0; i < map.size(); i++) {
+                Object key = get(keys, i, keyType);
+                // Same per-entry combination as java.util.Map#hashCode; summing keeps the
+                // result independent of the order in which keys are enumerated.
+                result += hash(key, keyType) ^ hash(map.get(key), valueType);
+            }
+            return result;
+        }
+        if (data instanceof byte[]) {
+            return Arrays.hashCode((byte[]) data);
+        }
+        return data.hashCode();
+    }
+
public static InternalRow copyInternalRow(InternalRow row, RowType
rowType) {
if (row instanceof BinaryRow) {
return ((BinaryRow) row).copy();
@@ -117,6 +239,11 @@ public class InternalRowUtils {
return ((BinaryMap) map).copy();
}
+ return copyToGenericMap(map, keyType, valueType);
+ }
+
+ private static GenericMap copyToGenericMap(
+ InternalMap map, DataType keyType, DataType valueType) {
Map<Object, Object> javaMap = new HashMap<>();
InternalArray keys = map.keyArray();
InternalArray values = map.valueArray();
@@ -145,6 +272,10 @@ public class InternalRowUtils {
return copyMap(
(InternalMap) o, ((MultisetType)
type).getElementType(), new IntType());
}
+ } else if (o instanceof byte[]) {
+ byte[] copy = new byte[((byte[]) o).length];
+ System.arraycopy(((byte[]) o), 0, copy, 0, ((byte[]) o).length);
+ return copy;
} else if (o instanceof Decimal) {
return ((Decimal) o).copy();
}
diff --git
a/paimon-common/src/test/java/org/apache/paimon/utils/InternalRowUtilsTest.java
b/paimon-common/src/test/java/org/apache/paimon/utils/InternalRowUtilsTest.java
index ea3bd98cfe..70d32c928c 100644
---
a/paimon-common/src/test/java/org/apache/paimon/utils/InternalRowUtilsTest.java
+++
b/paimon-common/src/test/java/org/apache/paimon/utils/InternalRowUtilsTest.java
@@ -21,6 +21,9 @@ package org.apache.paimon.utils;
import org.apache.paimon.data.BinaryRow;
import org.apache.paimon.data.BinaryString;
import org.apache.paimon.data.Decimal;
+import org.apache.paimon.data.GenericArray;
+import org.apache.paimon.data.GenericMap;
+import org.apache.paimon.data.GenericRow;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.data.Timestamp;
import org.apache.paimon.data.serializer.InternalRowSerializer;
@@ -37,6 +40,8 @@ import org.junit.jupiter.api.Test;
import java.math.BigDecimal;
import java.time.LocalDateTime;
+import java.util.HashMap;
+import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
@@ -52,6 +57,7 @@ public class InternalRowUtilsTest {
.field("intArray",
DataTypes.ARRAY(DataTypes.INT()).nullable())
.field("char", DataTypes.CHAR(10).notNull())
.field("varchar", DataTypes.VARCHAR(10).notNull())
+ .field("binary", DataTypes.BINARY(10).notNull())
.field("boolean", DataTypes.BOOLEAN().nullable())
.field("tinyint", DataTypes.TINYINT())
.field("smallint", DataTypes.SMALLINT())
@@ -144,4 +150,76 @@ public class InternalRowUtilsTest {
DataTypeRoot.VARCHAR))
.isLessThan(0);
}
+
+    @Test
+    public void testEqualsAndHashCode() {
+        for (int i = 0; i < 10; i++) {
+            GenericRow row1 = (GenericRow) rowDataGenerator.next();
+            GenericRow row2 = (GenericRow) InternalRowUtils.copyInternalRow(row1, ROW_TYPE);
+            GenericRow row3 = (GenericRow) rowDataGenerator.next();
+            assertThat(InternalRowUtils.equals(row1, row2, ROW_TYPE)).isTrue();
+            assertThat(InternalRowUtils.equals(row1, row3, ROW_TYPE)).isFalse();
+
+            assertThat(InternalRowUtils.hash(row1, ROW_TYPE))
+                    .isEqualTo(InternalRowUtils.hash(row2, ROW_TYPE));
+            // Different random rows could in principle collide, but that is very unlikely.
+            assertThat(InternalRowUtils.hash(row1, ROW_TYPE))
+                    .isNotEqualTo(InternalRowUtils.hash(row3, ROW_TYPE));
+        }
+
+        RowType rowType =
+                RowType.builder()
+                        .field("f1", DataTypes.DOUBLE())
+                        .field("f2", DataTypes.FLOAT())
+                        .field("f3", DataTypes.BINARY(3))
+                        .field("f4", DataTypes.STRING())
+                        .field("f5", DataTypes.ARRAY(DataTypes.ROW(DataTypes.INT())))
+                        .field(
+                                "f6",
+                                DataTypes.MAP(DataTypes.STRING(), DataTypes.ROW(DataTypes.INT())))
+                        .field("f7", DataTypes.ROW(DataTypes.INT()))
+                        .build();
+        GenericRow row1 = new GenericRow(7);
+        row1.setField(0, Double.NaN);
+        row1.setField(1, Float.NaN);
+        // Spell the bytes out rather than using String#getBytes(), which depends on the
+        // platform default charset.
+        row1.setField(2, new byte[] {'a', 'b', 'c'});
+        row1.setField(3, null);
+        row1.setField(
+                4, new GenericArray(new GenericRow[] {GenericRow.of(1), GenericRow.of(10)}));
+        Map<BinaryString, InternalRow> map = new HashMap<>();
+        map.put(BinaryString.fromString("a"), GenericRow.of(1));
+        map.put(BinaryString.fromString("b"), GenericRow.of(2));
+        row1.setField(5, new GenericMap(map));
+        row1.setField(6, GenericRow.of(1));
+        GenericRow row2 = (GenericRow) InternalRowUtils.copyInternalRow(row1, rowType);
+        assertThat(InternalRowUtils.equals(row1, row2, rowType)).isTrue();
+        assertThat(InternalRowUtils.hash(row1, rowType))
+                .isEqualTo(InternalRowUtils.hash(row2, rowType));
+    }
+
+    @Test
+    public void testEqualsAndHashCodeNegativeCase() {
+        // Array column: element values match the declared ARRAY<INT> element type, so the
+        // element comparison path is exercised, not only the size check.
+        RowType arrayRowType =
+                RowType.builder().field("f1", DataTypes.ARRAY(DataTypes.INT())).build();
+        GenericRow twoElements = GenericRow.of(new GenericArray(new Object[] {1, 10}));
+        GenericRow oneElement = GenericRow.of(new GenericArray(new Object[] {1}));
+        GenericRow otherElement = GenericRow.of(new GenericArray(new Object[] {1, 11}));
+        // different array length
+        assertThat(InternalRowUtils.equals(twoElements, oneElement, arrayRowType)).isFalse();
+        // same length, different element
+        assertThat(InternalRowUtils.equals(twoElements, otherElement, arrayRowType)).isFalse();
+
+        // Map column: values match the declared MAP<STRING, INT> value type.
+        RowType mapRowType =
+                RowType.builder()
+                        .field("f1", DataTypes.MAP(DataTypes.STRING(), DataTypes.INT()))
+                        .build();
+        Map<BinaryString, Integer> twoEntries = new HashMap<>();
+        twoEntries.put(BinaryString.fromString("a"), 1);
+        twoEntries.put(BinaryString.fromString("b"), 2);
+        Map<BinaryString, Integer> oneEntry = new HashMap<>();
+        oneEntry.put(BinaryString.fromString("a"), 1);
+        Map<BinaryString, Integer> otherValue = new HashMap<>();
+        otherValue.put(BinaryString.fromString("a"), 1);
+        otherValue.put(BinaryString.fromString("b"), 3);
+        // different map size
+        assertThat(
+                        InternalRowUtils.equals(
+                                GenericRow.of(new GenericMap(twoEntries)),
+                                GenericRow.of(new GenericMap(oneEntry)),
+                                mapRowType))
+                .isFalse();
+        // same keys, different value
+        assertThat(
+                        InternalRowUtils.equals(
+                                GenericRow.of(new GenericMap(twoEntries)),
+                                GenericRow.of(new GenericMap(otherValue)),
+                                mapRowType))
+                .isFalse();
+    }
}
diff --git
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/AbstractSparkInternalRow.java
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/AbstractSparkInternalRow.java
index 28604a6d62..283077430e 100644
---
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/AbstractSparkInternalRow.java
+++
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/AbstractSparkInternalRow.java
@@ -25,6 +25,7 @@ import org.apache.paimon.types.BigIntType;
import org.apache.paimon.types.DataType;
import org.apache.paimon.types.DataTypeChecks;
import org.apache.paimon.types.RowType;
+import org.apache.paimon.utils.InternalRowUtils;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
@@ -251,11 +252,24 @@ public abstract class AbstractSparkInternalRow extends
SparkInternalRow {
return false;
}
AbstractSparkInternalRow that = (AbstractSparkInternalRow) o;
- return Objects.equals(rowType, that.rowType) && Objects.equals(row,
that.row);
+ if (Objects.equals(rowType, that.rowType)) {
+ try {
+ return Objects.equals(row, that.row);
+ } catch (Exception e) {
+ // The underlying row may not support equals or hashcode,
e.g., `ProjectedRow`,
+ // to be safe, we fallback to a slow performance version.
+ return InternalRowUtils.equals(row, that.row, rowType);
+ }
+ }
+ return false;
}
@Override
public int hashCode() {
- return Objects.hash(rowType, row);
+        int result;
+        try {
+            result = Objects.hash(rowType, row);
+        } catch (Exception e) {
+            // As in equals(): the wrapped row may not support hashCode (e.g. ProjectedRow),
+            // so fall back to the slower, type-driven hash.
+            result = InternalRowUtils.hash(row, rowType);
+        }
+        return result;
}
}