This is an automated email from the ASF dual-hosted git repository.
danny0405 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git
The following commit(s) were added to refs/heads/master by this push:
new 77833cdb096 [HUDI-7311] Add implicit literal type conversion before
filter push down (#10531)
77833cdb096 is described below
commit 77833cdb09661b2cdac740520b51a29264afd9c7
Author: Paul Zhang <[email protected]>
AuthorDate: Wed Jan 24 17:15:07 2024 +0800
[HUDI-7311] Add implicit literal type conversion before filter push down
(#10531)
---
.../apache/hudi/source/ExpressionPredicates.java | 4 +-
.../apache/hudi/util/ImplicitTypeConverter.java | 134 +++++++++++++++++++++
.../hudi/source/TestExpressionPredicates.java | 61 ++++++++++
3 files changed, 198 insertions(+), 1 deletion(-)
diff --git
a/hudi-flink-datasource/hudi-flink/src/main/java/org/apache/hudi/source/ExpressionPredicates.java
b/hudi-flink-datasource/hudi-flink/src/main/java/org/apache/hudi/source/ExpressionPredicates.java
index 8faf705a81f..58ee59a8176 100644
---
a/hudi-flink-datasource/hudi-flink/src/main/java/org/apache/hudi/source/ExpressionPredicates.java
+++
b/hudi-flink-datasource/hudi-flink/src/main/java/org/apache/hudi/source/ExpressionPredicates.java
@@ -26,6 +26,7 @@ import
org.apache.flink.table.expressions.ValueLiteralExpression;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.hudi.util.ImplicitTypeConverter;
import org.apache.parquet.filter2.predicate.FilterPredicate;
import org.apache.parquet.filter2.predicate.Operators;
import org.slf4j.Logger;
@@ -223,7 +224,8 @@ public class ExpressionPredicates {
@Override
public FilterPredicate filter() {
- return toParquetPredicate(getFunctionDefinition(), literalType,
columnName, literal);
+ Serializable convertedLiteral =
ImplicitTypeConverter.convertImplicitly(literalType, literal);
+ return toParquetPredicate(getFunctionDefinition(), literalType,
columnName, convertedLiteral);
}
/**
diff --git
a/hudi-flink-datasource/hudi-flink/src/main/java/org/apache/hudi/util/ImplicitTypeConverter.java
b/hudi-flink-datasource/hudi-flink/src/main/java/org/apache/hudi/util/ImplicitTypeConverter.java
new file mode 100644
index 00000000000..601b878655f
--- /dev/null
+++
b/hudi-flink-datasource/hudi-flink/src/main/java/org/apache/hudi/util/ImplicitTypeConverter.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hudi.util;
+
+import org.apache.flink.table.types.logical.LogicalType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.Serializable;
+import java.time.LocalDate;
+import java.time.LocalDateTime;
+import java.time.LocalTime;
+import java.time.ZoneOffset;
+import java.time.temporal.ChronoField;
+
+/**
+ * Implicit type converter for predicates push down.
+ *
+ * <p>Coerces a filter literal into the Java representation that matches the
+ * root of its {@link LogicalType}, so that a literal supplied in a different
+ * but convertible form (for example a {@code String} or an {@code Integer})
+ * can still be used when the predicate is pushed down.
+ */
+public class ImplicitTypeConverter {
+
+  private static final Logger LOG = LoggerFactory.getLogger(ImplicitTypeConverter.class);
+
+  /**
+   * Convert the literal to the representation that corresponds to the given type.
+   *
+   * <p>Conversion is best effort: when it fails with a {@link RuntimeException}
+   * (for example a malformed number or date string), a warning is logged and the
+   * original literal is returned unchanged. Type roots without a dedicated branch
+   * are returned unchanged as well.
+   *
+   * <p>Results per type root:
+   * <ul>
+   *   <li>BOOLEAN: {@link Boolean}</li>
+   *   <li>TINYINT, SMALLINT, INTEGER: {@link Integer}</li>
+   *   <li>BIGINT: {@link Long}</li>
+   *   <li>FLOAT: {@link Float}; DOUBLE: {@link Double}</li>
+   *   <li>BINARY, VARBINARY: {@code byte[]} (UTF-8 bytes of the string form when
+   *       the literal is not already a byte array)</li>
+   *   <li>DATE: {@link Integer} epoch day</li>
+   *   <li>CHAR, VARCHAR: {@link String}</li>
+   *   <li>TIME_WITHOUT_TIME_ZONE: {@link Integer} millisecond of day</li>
+   *   <li>TIMESTAMP_WITHOUT_TIME_ZONE: {@link Long} epoch milliseconds, with a
+   *       {@link LocalDateTime} interpreted at UTC</li>
+   * </ul>
+   *
+   * @param literalType The type of the literal.
+   * @param literal The literal value.
+   * @return The converted literal, or the original literal if no conversion
+   *         applies or the conversion failed.
+   */
+  public static Serializable convertImplicitly(LogicalType literalType, Serializable literal) {
+    try {
+      switch (literalType.getTypeRoot()) {
+        case BOOLEAN:
+          if (literal instanceof Boolean) {
+            return literal;
+          }
+          return Boolean.valueOf(String.valueOf(literal));
+        case TINYINT:
+        case SMALLINT:
+        case INTEGER:
+          if (literal instanceof Integer) {
+            return literal;
+          }
+          return Integer.valueOf(String.valueOf(literal));
+        case BIGINT:
+          if (literal instanceof Long) {
+            return literal;
+          }
+          if (literal instanceof Integer) {
+            // Widen through the cached factory instead of the deprecated Long constructor.
+            return Long.valueOf(((Integer) literal).longValue());
+          }
+          return Long.valueOf(String.valueOf(literal));
+        case FLOAT:
+          if (literal instanceof Float) {
+            return literal;
+          }
+          return Float.valueOf(String.valueOf(literal));
+        case DOUBLE:
+          if (literal instanceof Double) {
+            return literal;
+          }
+          return Double.valueOf(String.valueOf(literal));
+        case BINARY:
+        case VARBINARY:
+          if (literal instanceof byte[]) {
+            return literal;
+          }
+          // Use an explicit charset so the bytes do not depend on the JVM's platform default.
+          return String.valueOf(literal).getBytes(java.nio.charset.StandardCharsets.UTF_8);
+        case DATE:
+          if (literal instanceof LocalDate) {
+            return (int) ((LocalDate) literal).toEpochDay();
+          }
+          if (literal instanceof Integer) {
+            return literal;
+          }
+          if (literal instanceof Long) {
+            // NOTE(review): intValue() narrows silently; values outside the int range wrap around.
+            return ((Long) literal).intValue();
+          }
+          return (int) LocalDate.parse(String.valueOf(literal)).toEpochDay();
+        case CHAR:
+        case VARCHAR:
+          if (literal instanceof String) {
+            return literal;
+          }
+          return String.valueOf(literal);
+        case TIME_WITHOUT_TIME_ZONE:
+          if (literal instanceof LocalTime) {
+            return ((LocalTime) literal).get(ChronoField.MILLI_OF_DAY);
+          }
+          if (literal instanceof Integer) {
+            return literal;
+          }
+          if (literal instanceof Long) {
+            // NOTE(review): intValue() narrows silently; values outside the int range wrap around.
+            return ((Long) literal).intValue();
+          }
+          return LocalTime.parse(String.valueOf(literal)).get(ChronoField.MILLI_OF_DAY);
+        case TIMESTAMP_WITHOUT_TIME_ZONE:
+          if (literal instanceof LocalDateTime) {
+            return ((LocalDateTime) literal).toInstant(ZoneOffset.UTC).toEpochMilli();
+          }
+          if (literal instanceof Long) {
+            return literal;
+          }
+          if (literal instanceof Integer) {
+            // Widen through the cached factory instead of the deprecated Long constructor.
+            return Long.valueOf(((Integer) literal).longValue());
+          }
+          return LocalDateTime.parse(String.valueOf(literal)).toInstant(ZoneOffset.UTC).toEpochMilli();
+        default:
+          return literal;
+      }
+    } catch (RuntimeException e) {
+      // Best effort: keep the original literal, but pass the exception as the trailing
+      // argument so the logger records why the conversion failed.
+      LOG.warn("Failed to convert literal [{}] to type [{}]. Use its original type", literal, literalType, e);
+      return literal;
+    }
+  }
+}
diff --git
a/hudi-flink-datasource/hudi-flink/src/test/java/org/apache/hudi/source/TestExpressionPredicates.java
b/hudi-flink-datasource/hudi-flink/src/test/java/org/apache/hudi/source/TestExpressionPredicates.java
index 02af3a85006..869b69a1a2d 100644
---
a/hudi-flink-datasource/hudi-flink/src/test/java/org/apache/hudi/source/TestExpressionPredicates.java
+++
b/hudi-flink-datasource/hudi-flink/src/test/java/org/apache/hudi/source/TestExpressionPredicates.java
@@ -18,6 +18,7 @@
package org.apache.hudi.source;
+import org.apache.flink.table.types.DataType;
import org.apache.hudi.source.ExpressionPredicates.And;
import org.apache.hudi.source.ExpressionPredicates.Equals;
import org.apache.hudi.source.ExpressionPredicates.GreaterThan;
@@ -41,11 +42,18 @@ import org.apache.parquet.filter2.predicate.Operators.Gt;
import org.apache.parquet.filter2.predicate.Operators.IntColumn;
import org.apache.parquet.filter2.predicate.Operators.Lt;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import java.time.LocalDate;
+import java.time.LocalDateTime;
+import java.time.LocalTime;
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import java.util.stream.Stream;
import static org.apache.hudi.source.ExpressionPredicates.fromExpression;
import static org.apache.parquet.filter2.predicate.FilterApi.and;
@@ -58,6 +66,7 @@ import static
org.apache.parquet.filter2.predicate.FilterApi.ltEq;
import static org.apache.parquet.filter2.predicate.FilterApi.not;
import static org.apache.parquet.filter2.predicate.FilterApi.notEq;
import static org.apache.parquet.filter2.predicate.FilterApi.or;
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
@@ -66,6 +75,8 @@ import static org.junit.jupiter.api.Assertions.assertNull;
*/
public class TestExpressionPredicates {
+ private static final String TEST_NAME_WITH_PARAMS = "[{index}] Test with
fieldName={0}, dataType={1}, literalValue={2}";
+
@Test
public void testFilterPredicateFromExpression() {
FieldReferenceExpression fieldReference = new
FieldReferenceExpression("f_int", DataTypes.INT(), 0, 0);
@@ -182,4 +193,54 @@ public class TestExpressionPredicates {
assertNull(Or.getInstance().bindPredicates(greaterThanPredicate,
lessThanPredicate).filter(), "Decimal type push down is unsupported, so we
expect null");
assertNull(Not.getInstance().bindPredicate(greaterThanPredicate).filter(),
"Decimal type push down is unsupported, so we expect null");
}
+
+ /**
+  * Supplies (field name, column data type, literal value) triples for
+  * {@code testColumnPredicateLiteralTypeConversion}. Each column type is paired
+  * with its native literal form and with alternative forms (string, int, long)
+  * that are expected to be converted implicitly.
+  *
+  * <p>Temporal literals use fixed values rather than the current clock so the
+  * parameter set and the generated test display names are reproducible.
+  */
+ public static Stream<Arguments> testColumnPredicateLiteralTypeConversionParams() {
+   final LocalDate fixedDate = LocalDate.of(2024, 1, 18);
+   final LocalTime fixedTime = LocalTime.of(20, 0, 0);
+   final LocalDateTime fixedDateTime = LocalDateTime.of(2024, 1, 18, 15, 0, 0);
+
+   return Stream.of(
+       Arguments.of("f_boolean", DataTypes.BOOLEAN(), Boolean.TRUE),
+       Arguments.of("f_boolean", DataTypes.BOOLEAN(), "true"),
+       Arguments.of("f_tinyint", DataTypes.TINYINT(), 12345),
+       Arguments.of("f_tinyint", DataTypes.TINYINT(), "12345"),
+       Arguments.of("f_smallint", DataTypes.SMALLINT(), 12345),
+       Arguments.of("f_smallint", DataTypes.SMALLINT(), "12345"),
+       Arguments.of("f_integer", DataTypes.INT(), 12345),
+       Arguments.of("f_integer", DataTypes.INT(), "12345"),
+       Arguments.of("f_bigint", DataTypes.BIGINT(), 12345L),
+       Arguments.of("f_bigint", DataTypes.BIGINT(), 12345),
+       Arguments.of("f_bigint", DataTypes.BIGINT(), "12345"),
+       Arguments.of("f_float", DataTypes.FLOAT(), 123.45f),
+       Arguments.of("f_float", DataTypes.FLOAT(), "123.45f"),
+       Arguments.of("f_double", DataTypes.DOUBLE(), 123.45),
+       Arguments.of("f_double", DataTypes.DOUBLE(), "123.45"),
+       Arguments.of("f_varbinary", DataTypes.VARBINARY(10), "a".getBytes()),
+       Arguments.of("f_varbinary", DataTypes.VARBINARY(10), "a"),
+       Arguments.of("f_binary", DataTypes.BINARY(10), "a".getBytes()),
+       Arguments.of("f_binary", DataTypes.BINARY(10), "a"),
+       Arguments.of("f_date", DataTypes.DATE(), fixedDate),
+       Arguments.of("f_date", DataTypes.DATE(), 19740),
+       Arguments.of("f_date", DataTypes.DATE(), 19740L),
+       Arguments.of("f_date", DataTypes.DATE(), "2024-01-18"),
+       Arguments.of("f_char", DataTypes.CHAR(1), "a"),
+       Arguments.of("f_char", DataTypes.CHAR(1), 1),
+       Arguments.of("f_varchar", DataTypes.VARCHAR(1), "a"),
+       Arguments.of("f_varchar", DataTypes.VARCHAR(1), 1),
+       Arguments.of("f_time", DataTypes.TIME(), fixedTime),
+       Arguments.of("f_time", DataTypes.TIME(), 12345),
+       Arguments.of("f_time", DataTypes.TIME(), 60981896000L),
+       Arguments.of("f_time", DataTypes.TIME(), "20:00:00"),
+       Arguments.of("f_timestamp", DataTypes.TIMESTAMP(), fixedDateTime),
+       Arguments.of("f_timestamp", DataTypes.TIMESTAMP(), 12345),
+       Arguments.of("f_timestamp", DataTypes.TIMESTAMP(), 1705568913701L),
+       Arguments.of("f_timestamp", DataTypes.TIMESTAMP(), "2024-01-18T15:00:00")
+   );
+ }
+
+ @ParameterizedTest(name = TEST_NAME_WITH_PARAMS)
+ @MethodSource("testColumnPredicateLiteralTypeConversionParams")
+ public void testColumnPredicateLiteralTypeConversion(String fieldName,
DataType dataType, Object literalValue) {
+ FieldReferenceExpression fieldReference = new
FieldReferenceExpression(fieldName, dataType, 0, 0);
+ ValueLiteralExpression valueLiteral = new
ValueLiteralExpression(literalValue);
+
+ ExpressionPredicates.ColumnPredicate predicate =
Equals.getInstance().bindFieldReference(fieldReference).bindValueLiteral(valueLiteral);
+ assertDoesNotThrow(predicate::filter, () -> String.format("Convert from %s
to %s failed", literalValue.getClass().getName(), dataType));
+ }
}