This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 8f3f9a53be78e482971d70dbf4a2fa80be8bbe35
Author: 谢健 <[email protected]>
AuthorDate: Fri Mar 22 10:25:49 2024 +0800

    [feat](Nereids): add is null predicate for the first partition when 
updating mv by partition (#32463)
---
 .../plans/commands/UpdateMvByPartitionCommand.java | 103 +++++++++++++--------
 .../commands/UpdateMvByPartitionCommandTest.java   |  62 +++++++------
 2 files changed, 100 insertions(+), 65 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/UpdateMvByPartitionCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/UpdateMvByPartitionCommand.java
index b456be6e260..f371a359e0a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/UpdateMvByPartitionCommand.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/UpdateMvByPartitionCommand.java
@@ -49,12 +49,15 @@ import org.apache.doris.nereids.util.ExpressionUtils;
 import org.apache.doris.nereids.util.RelationUtil;
 import org.apache.doris.qe.ConnectContext;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Range;
+import com.google.common.collect.Sets;
 
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -111,11 +114,27 @@ public class UpdateMvByPartitionCommand extends InsertOverwriteTableCommand {
         return builder.build();
     }
 
-    private static Set<Expression> constructPredicates(Set<PartitionItem> partitions, String colName) {
+    /**
+     * Build one predicate on {@code colName} per partition item: IN / IS NULL for list items,
+     * bound comparisons for range items. An empty item set yields a single TRUE predicate.
+     */
+    @VisibleForTesting
+    public static Set<Expression> constructPredicates(Set<PartitionItem> partitions, String colName) {
         UnboundSlot slot = new UnboundSlot(colName);
-        return partitions.stream()
-                .map(item -> convertPartitionItemToPredicate(item, slot))
-                .collect(ImmutableSet.toImmutableSet());
+        if (partitions.isEmpty()) {
+            return Sets.newHashSet(BooleanLiteral.TRUE);
+        }
+        Set<Expression> predicates = new HashSet<>();
+        for (PartitionItem item : partitions) {
+            // Dispatch on each item's own type rather than only the first element's, so a list item
+            // is never handed to the range converter (or vice versa) and hit a ClassCastException.
+            if (item instanceof ListPartitionItem) {
+                predicates.add(convertListPartitionToIn(item, slot));
+            } else {
+                predicates.add(convertRangePartitionToCompare(item, slot));
+            }
+        }
+        return predicates;
     }
 
     private static Expression convertPartitionKeyToLiteral(PartitionKey key) {
@@ -123,42 +142,48 @@ public class UpdateMvByPartitionCommand extends InsertOverwriteTableCommand {
                 Type.fromPrimitiveType(key.getTypes().get(0)));
     }
 
-    private static Expression convertPartitionItemToPredicate(PartitionItem item, Slot col) {
-        if (item instanceof ListPartitionItem) {
-            List<Expression> inValues = ((ListPartitionItem) item).getItems().stream()
-                    .map(UpdateMvByPartitionCommand::convertPartitionKeyToLiteral)
-                    .collect(ImmutableList.toImmutableList());
-            List<Expression> predicates = new ArrayList<>();
-            if (inValues.stream().anyMatch(NullLiteral.class::isInstance)) {
-                inValues = inValues.stream()
-                        .filter(e -> !(e instanceof NullLiteral))
-                        .collect(Collectors.toList());
-                Expression isNullPredicate = new IsNull(col);
-                predicates.add(isNullPredicate);
-            }
-            if (!inValues.isEmpty()) {
-                predicates.add(new InPredicate(col, inValues));
-            }
-            if (predicates.isEmpty()) {
-                return BooleanLiteral.of(true);
-            }
-            return ExpressionUtils.or(predicates);
-        } else {
-            Range<PartitionKey> range = item.getItems();
-            List<Expression> exprs = new ArrayList<>();
-            if (range.hasLowerBound() && !range.lowerEndpoint().isMinValue()) {
-                PartitionKey key = range.lowerEndpoint();
-                exprs.add(new GreaterThanEqual(col, convertPartitionKeyToLiteral(key)));
-            }
-            if (range.hasUpperBound() && !range.upperEndpoint().isMaxValue()) {
-                PartitionKey key = range.upperEndpoint();
-                exprs.add(new LessThan(col, convertPartitionKeyToLiteral(key)));
-            }
-            if (exprs.isEmpty()) {
-                return BooleanLiteral.of(true);
-            }
-            return ExpressionUtils.and(exprs);
+    private static Expression convertListPartitionToIn(PartitionItem item, Slot col) {
+        List<Expression> inValues = ((ListPartitionItem) item).getItems().stream()
+                .map(UpdateMvByPartitionCommand::convertPartitionKeyToLiteral)
+                .collect(ImmutableList.toImmutableList());
+        List<Expression> predicates = new ArrayList<>();
+        if (inValues.stream().anyMatch(NullLiteral.class::isInstance)) {
+            inValues = inValues.stream()
+                    .filter(e -> !(e instanceof NullLiteral))
+                    .collect(Collectors.toList());
+            Expression isNullPredicate = new IsNull(col);
+            predicates.add(isNullPredicate);
+        }
+        if (!inValues.isEmpty()) {
+            predicates.add(new InPredicate(col, inValues));
+        }
+        if (predicates.isEmpty()) {
+            return BooleanLiteral.of(true);
+        }
+        return ExpressionUtils.or(predicates);
+    }
+
+    private static Expression convertRangePartitionToCompare(PartitionItem item, Slot col) {
+        Range<PartitionKey> range = item.getItems();
+        List<Expression> expressions = new ArrayList<>();
+        if (range.hasLowerBound() && !range.lowerEndpoint().isMinValue()) {
+            PartitionKey key = range.lowerEndpoint();
+            expressions.add(new GreaterThanEqual(col, convertPartitionKeyToLiteral(key)));
+        }
+        if (range.hasUpperBound() && !range.upperEndpoint().isMaxValue()) {
+            PartitionKey key = range.upperEndpoint();
+            expressions.add(new LessThan(col, convertPartitionKeyToLiteral(key)));
+        }
+        if (expressions.isEmpty()) {
+            return BooleanLiteral.of(true);
+        }
+        Expression predicate = ExpressionUtils.and(expressions);
+        // NULL rows land in the first partition of LESS THAN partitioning, whose range has no lower
+        // bound or a MIN_VALUE lower bound (treated as unbounded above), so it must also match IS NULL.
+        if (!range.hasLowerBound() || range.lowerEndpoint().isMinValue()) {
+            predicate = ExpressionUtils.or(predicate, new IsNull(col));
         }
+        return predicate;
     }
 
     static class PredicateAdder extends DefaultPlanRewriter<Map<TableIf, Set<Expression>>> {
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/commands/UpdateMvByPartitionCommandTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/commands/UpdateMvByPartitionCommandTest.java
index b7374565da1..a2909f85a76 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/commands/UpdateMvByPartitionCommandTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/commands/UpdateMvByPartitionCommandTest.java
@@ -20,58 +20,68 @@ package org.apache.doris.nereids.trees.plans.commands;
 import org.apache.doris.analysis.PartitionValue;
 import org.apache.doris.catalog.Column;
 import org.apache.doris.catalog.ListPartitionItem;
-import org.apache.doris.catalog.PartitionItem;
 import org.apache.doris.catalog.PartitionKey;
 import org.apache.doris.catalog.PrimitiveType;
 import org.apache.doris.catalog.RangePartitionItem;
 import org.apache.doris.common.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
 import org.apache.doris.nereids.trees.expressions.IsNull;
-import org.apache.doris.nereids.trees.expressions.Slot;
-import org.apache.doris.nereids.trees.expressions.SlotReference;
-import org.apache.doris.nereids.types.IntegerType;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Range;
+import com.google.common.collect.Sets;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
-import java.lang.reflect.InvocationTargetException;
-import java.lang.reflect.Method;
+import java.util.Set;
 
 class UpdateMvByPartitionCommandTest {
     @Test
-    void testMaxMin() throws AnalysisException, NoSuchMethodException, InvocationTargetException,
-            IllegalAccessException {
-        Method m = UpdateMvByPartitionCommand.class.getDeclaredMethod("convertPartitionItemToPredicate", PartitionItem.class,
-                Slot.class);
-        m.setAccessible(true);
+    void testFirstPartWithoutLowerBound() throws AnalysisException {
         Column column = new Column("a", PrimitiveType.INT);
-        PartitionKey upper = PartitionKey.createPartitionKey(ImmutableList.of(PartitionValue.MAX_VALUE), ImmutableList.of(column));
-        PartitionKey lower = PartitionKey.createPartitionKey(ImmutableList.of(new PartitionValue(1L)), ImmutableList.of(column));
+        PartitionKey upper = PartitionKey.createPartitionKey(ImmutableList.of(new PartitionValue(1L)),
+                ImmutableList.of(column));
+        Range<PartitionKey> range1 = Range.lessThan(upper);
+        RangePartitionItem item1 = new RangePartitionItem(range1);
+
+        Set<Expression> predicates = UpdateMvByPartitionCommand.constructPredicates(Sets.newHashSet(item1), "s");
+        Assertions.assertEquals(1, predicates.size());
+        Assertions.assertEquals("((s < 1) OR s IS NULL)", predicates.iterator().next().toSql());
+    }
+
+    @Test
+    void testMaxMin() throws AnalysisException {
+        Column column = new Column("a", PrimitiveType.INT);
+        PartitionKey upper = PartitionKey.createPartitionKey(ImmutableList.of(PartitionValue.MAX_VALUE),
+                ImmutableList.of(column));
+        PartitionKey lower = PartitionKey.createPartitionKey(ImmutableList.of(new PartitionValue(1L)),
+                ImmutableList.of(column));
         Range<PartitionKey> range = Range.closedOpen(lower, upper);
         RangePartitionItem rangePartitionItem = new RangePartitionItem(range);
-        Expression expr = (Expression) m.invoke(null, rangePartitionItem, new SlotReference("s", IntegerType.INSTANCE));
-        Assertions.assertTrue(expr instanceof GreaterThanEqual);
+        Set<Expression> predicates = UpdateMvByPartitionCommand.constructPredicates(Sets.newHashSet(rangePartitionItem),
+                "s");
+        Assertions.assertEquals(1, predicates.size());
+        Expression expr = predicates.iterator().next();
+        Assertions.assertEquals("(s >= 1)", expr.toSql());
     }
 
     @Test
-    void testNull() throws AnalysisException, NoSuchMethodException, InvocationTargetException,
-            IllegalAccessException {
-        Method m = UpdateMvByPartitionCommand.class.getDeclaredMethod("convertPartitionItemToPredicate", PartitionItem.class,
-                Slot.class);
-        m.setAccessible(true);
+    void testNull() throws AnalysisException {
         Column column = new Column("a", PrimitiveType.INT);
-        PartitionKey v = PartitionKey.createListPartitionKeyWithTypes(ImmutableList.of(new PartitionValue("NULL", true)), ImmutableList.of(column.getType()), false);
+        PartitionKey v = PartitionKey.createListPartitionKeyWithTypes(
+                ImmutableList.of(new PartitionValue("NULL", true)), ImmutableList.of(column.getType()), false);
         ListPartitionItem listPartitionItem = new ListPartitionItem(ImmutableList.of(v));
-        Expression expr = (Expression) m.invoke(null, listPartitionItem, new SlotReference("s", IntegerType.INSTANCE));
+        Expression expr = UpdateMvByPartitionCommand.constructPredicates(Sets.newHashSet(listPartitionItem), "s")
+                .iterator().next();
         Assertions.assertTrue(expr instanceof IsNull);
 
-        PartitionKey v1 = PartitionKey.createListPartitionKeyWithTypes(ImmutableList.of(new PartitionValue("NULL", true)), ImmutableList.of(column.getType()), false);
-        PartitionKey v2 = PartitionKey.createListPartitionKeyWithTypes(ImmutableList.of(new PartitionValue("1", false)), ImmutableList.of(column.getType()), false);
+        PartitionKey v1 = PartitionKey.createListPartitionKeyWithTypes(
+                ImmutableList.of(new PartitionValue("NULL", true)), ImmutableList.of(column.getType()), false);
+        PartitionKey v2 = PartitionKey.createListPartitionKeyWithTypes(ImmutableList.of(new PartitionValue("1", false)),
+                ImmutableList.of(column.getType()), false);
         listPartitionItem = new ListPartitionItem(ImmutableList.of(v1, v2));
-        expr = (Expression) m.invoke(null, listPartitionItem, new SlotReference("s", IntegerType.INSTANCE));
+        expr = UpdateMvByPartitionCommand.constructPredicates(Sets.newHashSet(listPartitionItem), "s").iterator()
+                .next();
         Assertions.assertEquals("(s IS NULL OR s IN (1))", expr.toSql());
     }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to