This is an automated email from the ASF dual-hosted git repository.

huajianlan pushed a commit to branch nested_column_prune
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/nested_column_prune by this push:
     new f9079ca71c8 fix some bugs of prune nested columns
f9079ca71c8 is described below

commit f9079ca71c8024eb9e0e29f5ce042204c0f4a433
Author: 924060929 <[email protected]>
AuthorDate: Mon Oct 27 21:32:47 2025 +0800

    fix some bugs of prune nested columns
---
 .../rules/expression/ExpressionNormalization.java  |  2 +
 .../rules/expression/ExpressionRuleType.java       |  1 +
 .../expression/rules/NormalizeStructElement.java   | 66 +++++++++++++++
 .../rewrite/AccessPathExpressionCollector.java     | 32 +++++++-
 .../rules/rewrite/AccessPathPlanCollector.java     | 15 ++++
 .../nereids/rules/rewrite/NestedColumnPruning.java | 84 ++++++++++++++++++-
 .../nereids/rules/rewrite/SlotTypeReplacer.java    | 96 +++++++++++++++-------
 .../trees/plans/logical/LogicalTVFRelation.java    | 23 +++++-
 .../java/org/apache/doris/planner/PlanNode.java    |  4 +-
 .../rules/rewrite/PruneNestedColumnTest.java       | 56 +++++++++++++
 10 files changed, 342 insertions(+), 37 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
index a4cdb54a5cd..c7a33ce680e 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
@@ -29,6 +29,7 @@ import 
org.apache.doris.nereids.rules.expression.rules.LogToLn;
 import org.apache.doris.nereids.rules.expression.rules.MedianConvert;
 import org.apache.doris.nereids.rules.expression.rules.MergeDateTrunc;
 import 
org.apache.doris.nereids.rules.expression.rules.NormalizeBinaryPredicatesRule;
+import org.apache.doris.nereids.rules.expression.rules.NormalizeStructElement;
 import 
org.apache.doris.nereids.rules.expression.rules.SimplifyArithmeticComparisonRule;
 import org.apache.doris.nereids.rules.expression.rules.SimplifyArithmeticRule;
 import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule;
@@ -67,6 +68,7 @@ public class ExpressionNormalization extends 
ExpressionRewrite {
                 SimplifyArithmeticComparisonRule.INSTANCE,
                 ConvertAggStateCast.INSTANCE,
                 MergeDateTrunc.INSTANCE,
+                NormalizeStructElement.INSTANCE,
                 CheckCast.INSTANCE
             )
     );
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
index 823dbd49b93..e3359a1f621 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
@@ -58,6 +58,7 @@ public enum ExpressionRuleType {
     SIMPLIFY_RANGE,
     SIMPLIFY_SELF_COMPARISON,
     SUPPORT_JAVA_DATE_FORMATTER,
+    NORMALIZE_STRUCT_ELEMENT,
     TOPN_TO_MAX;
 
     public int type() {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NormalizeStructElement.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NormalizeStructElement.java
new file mode 100644
index 00000000000..dc50921bc1e
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NormalizeStructElement.java
@@ -0,0 +1,66 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rules;
+
+import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
+import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
+import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import 
org.apache.doris.nereids.trees.expressions.functions.scalar.StructElement;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
+import org.apache.doris.nereids.types.StructField;
+import org.apache.doris.nereids.types.StructType;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * Rewrites a positional struct field access into a name-based one: given a column s of type
+ * struct&lt;a: int, b: int&gt;, {@code struct_element(s, 2)} is rewritten to
+ * {@code struct_element(s, 'b')}.
+ *
+ * <p>Field positions are 1-based. A call is returned unchanged when its field argument is not an
+ * integer literal, its first argument is not a struct, or the position is outside
+ * {@code [1, number of fields]}.
+ */
+public class NormalizeStructElement implements ExpressionPatternRuleFactory {
+    public static final NormalizeStructElement INSTANCE = new NormalizeStructElement();
+
+    @Override
+    public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
+        return ImmutableList.of(
+                matchesType(StructElement.class).then(NormalizeStructElement::normalize)
+                        .toRule(ExpressionRuleType.NORMALIZE_STRUCT_ELEMENT)
+        );
+    }
+
+    /**
+     * Replaces an integer field position with the corresponding field name.
+     *
+     * @param structElement the struct_element call to normalize
+     * @return the rewritten call, or {@code structElement} itself when it cannot be normalized
+     */
+    private static StructElement normalize(StructElement structElement) {
+        Expression field = structElement.getArgument(1);
+        if (!(field instanceof IntegerLikeLiteral)) {
+            return structElement;
+        }
+        if (!(structElement.getArgument(0).getDataType() instanceof StructType)) {
+            return structElement;
+        }
+        // read as long so a literal outside the int range is rejected by the range check below
+        // instead of wrapping around into a valid position
+        long fieldIndex = ((Number) ((IntegerLikeLiteral) field).getValue()).longValue();
+        List<StructField> fields = ((StructType) structElement.getArgument(0).getDataType()).getFields();
+        // positions are 1-based: 0 must be rejected too, otherwise fields.get(-1) would throw
+        if (fieldIndex < 1 || fieldIndex > fields.size()) {
+            return structElement;
+        }
+        return structElement.withChildren(
+                ImmutableList.of(
+                        structElement.child(0),
+                        new StringLiteral(fields.get((int) (fieldIndex - 1)).getName())
+                )
+        );
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathExpressionCollector.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathExpressionCollector.java
index 57fe9fefcc5..8bbb1a8c25e 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathExpressionCollector.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathExpressionCollector.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.rewrite;
 
 import org.apache.doris.nereids.StatementContext;
 import 
org.apache.doris.nereids.rules.rewrite.AccessPathExpressionCollector.CollectorContext;
+import 
org.apache.doris.nereids.rules.rewrite.NestedColumnPruning.DataTypeAccessTree;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
 import 
org.apache.doris.nereids.trees.expressions.ArrayItemReference.ArrayItemSlot;
@@ -50,6 +51,8 @@ import 
org.apache.doris.nereids.trees.expressions.literal.Literal;
 import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
 import org.apache.doris.nereids.types.DataType;
 import org.apache.doris.nereids.types.NestedColumnPrunable;
+import org.apache.doris.nereids.types.StructField;
+import org.apache.doris.nereids.types.StructType;
 import org.apache.doris.nereids.util.Utils;
 import org.apache.doris.thrift.TAccessPathType;
 
@@ -57,6 +60,7 @@ import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Multimap;
 
+import java.util.ArrayList;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -128,7 +132,23 @@ public class AccessPathExpressionCollector extends 
DefaultExpressionVisitor<Void
 
     @Override
     public Void visitCast(Cast cast, CollectorContext context) {
-        return cast.child(0).accept(this, context);
+        // A path collected above this cast is expressed in the cast target's field names. When both
+        // sides are nested-prunable, translate it into the child's field names and keep collecting
+        // on the child with the translated path.
+        if (!context.accessPathBuilder.isEmpty()
+                && cast.getDataType() instanceof NestedColumnPrunable
+                && cast.child().getDataType() instanceof NestedColumnPrunable) {
+
+            DataTypeAccessTree castTree = DataTypeAccessTree.of(cast.getDataType(), TAccessPathType.DATA);
+            DataTypeAccessTree originTree = DataTypeAccessTree.of(cast.child().getDataType(), TAccessPathType.DATA);
+
+            // work on a copy: replacePathByAnotherTree rewrites the list elements in place
+            List<String> replacePath = new ArrayList<>(context.accessPathBuilder.getPathList());
+            if (originTree.replacePathByAnotherTree(castTree, replacePath, 0)) {
+                CollectorContext castContext = new CollectorContext(context.statementContext, context.bottomFilter);
+                castContext.accessPathBuilder.accessPath.addAll(replacePath);
+                return continueCollectAccessPath(cast.child(), castContext);
+            }
+        }
+        // Path could not be translated (or there was none): visit the child with a new
+        // CollectorContext, discarding the path collected so far.
+        return cast.child(0).accept(this,
+                new CollectorContext(context.statementContext, context.bottomFilter)
+        );
     }
 
     // array element at
@@ -158,6 +178,15 @@ public class AccessPathExpressionCollector extends 
DefaultExpressionVisitor<Void
         DataType fieldType = fieldName.getDataType();
 
         if (fieldName.isLiteral() && (fieldType.isIntegerLikeType() || 
fieldType.isStringLikeType())) {
+            if (fieldType.isIntegerLikeType()) {
+                int fieldIndex = ((Number) ((Literal) 
fieldName).getValue()).intValue();
+                List<StructField> fields = ((StructType) 
struct.getDataType()).getFields();
+                if (fieldIndex >= 1 && fieldIndex <= fields.size()) {
+                    String realFieldName = fields.get(fieldIndex - 
1).getName();
+                    context.accessPathBuilder.addPrefix(realFieldName);
+                    return continueCollectAccessPath(struct, context);
+                }
+            }
             context.accessPathBuilder.addPrefix(((Literal) 
fieldName).getStringValue());
             return continueCollectAccessPath(struct, context);
         }
@@ -170,6 +199,7 @@ public class AccessPathExpressionCollector extends 
DefaultExpressionVisitor<Void
 
     @Override
     public Void visitMapKeys(MapKeys mapKeys, CollectorContext context) {
+        // Start from a new CollectorContext so the path collected from enclosing expressions
+        // (e.g. an element access on the keys array) is not kept after KEYS; presumably reading
+        // any part of map_keys(m) needs the whole KEYS sub-column.
+        context = new CollectorContext(context.statementContext, context.bottomFilter);
         context.accessPathBuilder.addPrefix("KEYS");
         return continueCollectAccessPath(mapKeys.getArgument(0), context);
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java
index 21406c2ea13..514f7bb1e8c 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java
@@ -28,6 +28,7 @@ import 
org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFileScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation;
 import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
 import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
 import org.apache.doris.nereids.types.NestedColumnPrunable;
@@ -137,6 +138,20 @@ public class AccessPathPlanCollector extends 
DefaultPlanVisitor<Void, StatementC
         return null;
     }
 
+    /**
+     * Registers, for every nested-prunable output slot of a TVF relation, the access paths
+     * collected for it so far; returns without descending further.
+     */
+    @Override
+    public Void visitLogicalTVFRelation(LogicalTVFRelation tvfRelation, StatementContext context) {
+        for (Slot outputSlot : tvfRelation.getOutput()) {
+            if (outputSlot.getDataType() instanceof NestedColumnPrunable) {
+                int slotId = outputSlot.getExprId().asInt();
+                Collection<CollectAccessPathResult> collectedPaths = allSlotToAccessPaths.get(slotId);
+                if (!collectedPaths.isEmpty()) {
+                    scanSlotToAccessPaths.put(outputSlot, new ArrayList<>(collectedPaths));
+                }
+            }
+        }
+        return null;
+    }
+
     @Override
     public Void visit(Plan plan, StatementContext context) {
         collectByExpressions(plan, context);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NestedColumnPruning.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NestedColumnPruning.java
index 0f99b576cae..00f79b08eef 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NestedColumnPruning.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NestedColumnPruning.java
@@ -234,7 +234,8 @@ public class NestedColumnPruning implements CustomRewriter {
         return result;
     }
 
-    private static class DataTypeAccessTree {
+    /** DataTypeAccessTree */
+    public static class DataTypeAccessTree {
         private DataType type;
         private boolean isRoot;
         private boolean accessPartialChild;
@@ -252,6 +253,78 @@ public class NestedColumnPruning implements CustomRewriter 
{
             this.pathType = pathType;
         }
 
+        /**
+         * Computes the result type of a cast whose child type has been pruned.
+         *
+         * <p>{@code this} is the access tree of the pruned child type, {@code origin} the tree of the
+         * child's original (unpruned) type and {@code cast} the tree of the original cast target
+         * type. Struct fields are paired positionally between {@code origin} and {@code cast} (the
+         * i-th origin field is renamed to the i-th cast field), and only fields present in
+         * {@code this} are kept. Leaf nodes, and nodes whose nested kind the cast changes
+         * (e.g. struct to string), take the cast target's type so the cast still converts values
+         * to the requested type.
+         *
+         * <p>NOTE(review): arrays/maps are rebuilt with ArrayType.of / MapType.of without explicit
+         * nullability, and struct fields keep the origin field's nullable flag and comment; confirm
+         * this matches the cast target.
+         *
+         * @param origin access tree of the cast child's original type
+         * @param cast access tree of the original cast target type
+         * @return the cast target type restricted to the sub-fields still present in this tree
+         */
+        public DataType pruneCastType(DataTypeAccessTree origin, DataTypeAccessTree cast) {
+            if (type instanceof StructType && cast.type instanceof StructType) {
+                // pair origin field names with cast target field names by position
+                Map<String, String> nameMapping = new LinkedHashMap<>();
+                List<String> castNames = new ArrayList<>(cast.children.keySet());
+                int i = 0;
+                for (String originName : origin.children.keySet()) {
+                    if (i >= castNames.size()) {
+                        break;
+                    }
+                    nameMapping.put(originName, castNames.get(i++));
+                }
+                List<StructField> mappingFields = new ArrayList<>();
+                StructType originPrunedStructType = (StructType) type;
+                for (Entry<String, DataTypeAccessTree> kv : children.entrySet()) {
+                    String originName = kv.getKey();
+                    String mappingName = nameMapping.getOrDefault(originName, originName);
+                    DataTypeAccessTree originPrunedTree = kv.getValue();
+                    DataType mappingType = originPrunedTree.pruneCastType(
+                            origin.children.get(originName),
+                            cast.children.get(mappingName)
+                    );
+                    StructField field = originPrunedStructType.getField(originName);
+                    mappingFields.add(
+                            new StructField(mappingName, mappingType, field.isNullable(), field.getComment())
+                    );
+                }
+                return new StructType(mappingFields);
+            } else if (type instanceof ArrayType && cast.type instanceof ArrayType) {
+                return ArrayType.of(
+                        children.values().iterator().next().pruneCastType(
+                                origin.children.values().iterator().next(),
+                                cast.children.values().iterator().next()
+                        )
+                );
+            } else if (type instanceof MapType && cast.type instanceof MapType) {
+                return MapType.of(
+                        children.get("KEYS").pruneCastType(origin.children.get("KEYS"), cast.children.get("KEYS")),
+                        children.get("VALUES").pruneCastType(origin.children.get("VALUES"), cast.children.get("VALUES"))
+                );
+            } else {
+                // leaf, or the cast changes this node's nested kind: the cast target decides the type
+                return cast.type;
+            }
+        }
+
+        /**
+         * Translates an access path written against the cast target type into the equivalent path
+         * against this tree's type (the type of the cast's child).
+         *
+         * <p>Struct fields are paired by position: the element naming the i-th field of the cast
+         * target is replaced in {@code path} with the name of the i-th field of this struct. Array
+         * and map elements are left as they are. A shape mismatch (missing child, non-struct
+         * origin, unmatched field name) makes the translation return false instead of throwing.
+         *
+         * @param cast access tree of the cast target type at the current level
+         * @param path access path relative to the cast result; rewritten in place
+         * @param index position in {@code path} to translate next
+         * @return true if every element from {@code index} onwards was translated, false otherwise
+         */
+        public boolean replacePathByAnotherTree(DataTypeAccessTree cast, List<String> path, int index) {
+            if (index >= path.size()) {
+                return true;
+            }
+            if (cast == null) {
+                return false;
+            }
+            String pathElement = path.get(index);
+            if (cast.type instanceof StructType) {
+                if (!(type instanceof StructType)) {
+                    return false;
+                }
+                List<StructField> castFields = ((StructType) cast.type).getFields();
+                List<StructField> originFields = ((StructType) type).getFields();
+                int fieldCount = Math.min(castFields.size(), originFields.size());
+                for (int i = 0; i < fieldCount; i++) {
+                    String castFieldName = castFields.get(i).getName();
+                    if (castFieldName.equalsIgnoreCase(pathElement)) {
+                        String originFieldName = originFields.get(i).getName();
+                        DataTypeAccessTree originChild = children.get(originFieldName);
+                        if (originChild == null) {
+                            return false;
+                        }
+                        path.set(index, originFieldName);
+                        // The path element matched only case-insensitively, so look the cast child up
+                        // by its declared name. This assumes children are keyed by declared field name,
+                        // as the origin lookup above already does.
+                        return originChild.replacePathByAnotherTree(
+                                cast.children.get(castFieldName), path, index + 1);
+                    }
+                }
+            } else if (cast.type instanceof ArrayType) {
+                if (children.isEmpty() || cast.children.isEmpty()) {
+                    return false;
+                }
+                return children.values().iterator().next().replacePathByAnotherTree(
+                        cast.children.values().iterator().next(), path, index + 1);
+            } else if (cast.type instanceof MapType) {
+                // pair KEYS with KEYS; every other map element continues into the value tree
+                DataTypeAccessTree originChild = children.get("KEYS".equals(pathElement) ? "KEYS" : "VALUES");
+                if (originChild == null) {
+                    return false;
+                }
+                return originChild.replacePathByAnotherTree(cast.children.get(pathElement), path, index + 1);
+            }
+            return false;
+        }
+
+        /** setAccessByPath */
         public void setAccessByPath(List<String> path, int accessIndex, 
TAccessPathType pathType) {
             if (accessIndex >= path.size()) {
                 accessAll = true;
@@ -267,7 +340,12 @@ public class NestedColumnPruning implements CustomRewriter 
{
             if (this.type.isStructType()) {
                 String fieldName = path.get(accessIndex);
                 DataTypeAccessTree child = children.get(fieldName);
-                child.setAccessByPath(path, accessIndex + 1, pathType);
+                if (child != null) {
+                    child.setAccessByPath(path, accessIndex + 1, pathType);
+                } else {
+                    // can not find the field
+                    accessAll = true;
+                }
                 return;
             } else if (this.type.isArrayType()) {
                 DataTypeAccessTree child = children.get("*");
@@ -314,6 +392,7 @@ public class NestedColumnPruning implements CustomRewriter {
             return root;
         }
 
+        /** of */
         public static DataTypeAccessTree of(DataType type, TAccessPathType 
pathType) {
             DataTypeAccessTree root = new DataTypeAccessTree(type, pathType);
             if (type instanceof StructType) {
@@ -330,6 +409,7 @@ public class NestedColumnPruning implements CustomRewriter {
             return root;
         }
 
+        /** pruneDataType */
         public Optional<DataType> pruneDataType() {
             if (isRoot) {
                 return children.values().iterator().next().pruneDataType();
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SlotTypeReplacer.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SlotTypeReplacer.java
index 043e9c7bdc1..459c7362758 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SlotTypeReplacer.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SlotTypeReplacer.java
@@ -22,7 +22,9 @@ import org.apache.doris.catalog.Column;
 import org.apache.doris.common.Pair;
 import org.apache.doris.datasource.iceberg.IcebergExternalTable;
 import org.apache.doris.nereids.properties.OrderKey;
+import 
org.apache.doris.nereids.rules.rewrite.NestedColumnPruning.DataTypeAccessTree;
 import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
+import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.OrderExpression;
@@ -49,8 +51,8 @@ import 
org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
 import org.apache.doris.nereids.trees.plans.logical.LogicalResultSink;
-import org.apache.doris.nereids.trees.plans.logical.LogicalSink;
 import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
+import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation;
 import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
 import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
 import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
@@ -58,6 +60,7 @@ import 
org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
 import org.apache.doris.nereids.types.ArrayType;
 import org.apache.doris.nereids.types.DataType;
 import org.apache.doris.nereids.types.MapType;
+import org.apache.doris.nereids.types.NestedColumnPrunable;
 import org.apache.doris.nereids.types.StructType;
 import org.apache.doris.nereids.util.MoreFieldsThread;
 import org.apache.doris.thrift.TAccessPathType;
@@ -402,6 +405,16 @@ public class SlotTypeReplacer extends 
DefaultPlanRewriter<Void> {
         return fileScan;
     }
 
+    /**
+     * Replaces the TVF relation's output slots with their pruned-type versions; the relation is
+     * rebuilt with cached outputs only when at least one slot changed.
+     */
+    @Override
+    public Plan visitLogicalTVFRelation(LogicalTVFRelation tvfRelation, Void context) {
+        List<Slot> originOutputs = tvfRelation.getOutput();
+        Pair<Boolean, List<Slot>> replaced = replaceExpressions(originOutputs, false, true);
+        return replaced.first ? tvfRelation.withCachedOutputs(replaced.second) : tvfRelation;
+    }
+
     @Override
     public Plan visitLogicalOlapScan(LogicalOlapScan olapScan, Void context) {
         Pair<Boolean, List<Slot>> replaced = 
replaceExpressions(olapScan.getOutput(), false, true);
@@ -445,12 +458,6 @@ public class SlotTypeReplacer extends 
DefaultPlanRewriter<Void> {
         return logicalResultSink;
     }
 
-    @Override
-    public Plan visitLogicalSink(LogicalSink<? extends Plan> logicalSink, Void 
context) {
-        // do nothing
-        return logicalSink;
-    }
-
     private Pair<Boolean, List<OrderExpression>> 
replaceOrderExpressions(List<OrderExpression> orderExpressions) {
         ImmutableList.Builder<OrderExpression> newOrderKeys
                 = 
ImmutableList.builderWithExpectedSize(orderExpressions.size());
@@ -525,28 +532,39 @@ public class SlotTypeReplacer extends 
DefaultPlanRewriter<Void> {
         return Pair.of(changed, (C) newExprs.build());
     }
 
-    private Expression replaceSlot(Expression expr, boolean fillAccessPath) {
-        return MoreFieldsThread.keepFunctionSignature(false, () -> {
-            return expr.rewriteUp(e -> {
-                if (e instanceof Lambda) {
-                    return rewriteLambda((Lambda) e, fillAccessPath);
-                } else if (e instanceof SlotReference) {
-                    AccessPathInfo accessPathInfo = replacedDataTypes.get(((SlotReference) e).getExprId().asInt());
-                    if (accessPathInfo != null) {
-                        SlotReference newSlot
-                                = (SlotReference) ((SlotReference) e).withNullableAndDataType(
-                                e.nullable(), accessPathInfo.getPrunedType());
-                        if (fillAccessPath) {
-                            newSlot = newSlot.withAccessPaths(
-                                    accessPathInfo.getAllAccessPaths(), accessPathInfo.getPredicateAccessPaths()
-                            );
-                        }
-                        return newSlot;
-                    }
+    /**
+     * Entry point: rewrites {@code e} with the pruned slot types inside
+     * MoreFieldsThread.keepFunctionSignature(false, ...).
+     */
+    private Expression replaceSlot(Expression e, boolean fillAccessPath) {
+        return MoreFieldsThread.keepFunctionSignature(false,
+                () -> doRewriteExpression(e, fillAccessPath)
+        );
+    }
+
+    /**
+     * Recursively rewrites {@code e} top-down.
+     *
+     * <p>Lambdas and casts are delegated to rewriteLambda / rewriteCast. A slot reference with a
+     * recorded AccessPathInfo gets the pruned type and, when {@code fillAccessPath} is set, its
+     * access paths. Every other expression has its children rewritten and is rebuilt only if some
+     * child instance changed. The recursion is written by hand instead of rewriteUp, presumably so
+     * that rewriteCast can still see the cast child's original type.
+     */
+    private Expression doRewriteExpression(Expression e, boolean fillAccessPath) {
+        if (e instanceof Lambda) {
+            return rewriteLambda((Lambda) e, fillAccessPath);
+        } else if (e instanceof Cast) {
+            return rewriteCast((Cast) e, fillAccessPath);
+        } else if (e instanceof SlotReference) {
+            AccessPathInfo accessPathInfo = replacedDataTypes.get(((SlotReference) e).getExprId().asInt());
+            if (accessPathInfo != null) {
+                SlotReference newSlot = (SlotReference) ((SlotReference) e).withNullableAndDataType(
+                        e.nullable(), accessPathInfo.getPrunedType());
+                if (fillAccessPath) {
+                    newSlot = newSlot.withAccessPaths(
+                            accessPathInfo.getAllAccessPaths(), accessPathInfo.getPredicateAccessPaths()
+                    );
                 }
-                return e;
-            });
-        });
+                return newSlot;
+            }
+        }
+
+        // generic case: rewrite children and keep the original instance if nothing changed
+        ImmutableList.Builder<Expression> newChildren = ImmutableList.builderWithExpectedSize(e.arity());
+        boolean changed = false;
+        for (Expression child : e.children()) {
+            Expression newChild = doRewriteExpression(child, fillAccessPath);
+            changed |= child != newChild;
+            newChildren.add(newChild);
+        }
+        return changed ? e.withChildren(newChildren.build()) : e;
     }
 
     private Expression rewriteLambda(Lambda e, boolean fillAccessPath) {
@@ -555,7 +573,7 @@ public class SlotTypeReplacer extends 
DefaultPlanRewriter<Void> {
         for (int i = 0; i < e.arity(); i++) {
             Expression child = e.child(i);
             if (child instanceof ArrayItemReference) {
-                Expression newRef = 
child.withChildren(replaceSlot(child.child(0), fillAccessPath));
+                Expression newRef = 
child.withChildren(doRewriteExpression(child.child(0), fillAccessPath));
                 replacedDataTypes.put(((ArrayItemReference) 
child).getExprId().asInt(),
                         new AccessPathInfo(newRef.getDataType(), null, null));
                 newChildren[i] = newRef;
@@ -567,13 +585,31 @@ public class SlotTypeReplacer extends 
DefaultPlanRewriter<Void> {
         for (int i = 0; i < newChildren.length; i++) {
             Expression child = newChildren[i];
             if (!(child instanceof ArrayItemReference)) {
-                newChildren[i] = replaceSlot(child, fillAccessPath);
+                newChildren[i] = doRewriteExpression(child, fillAccessPath);
             }
         }
 
         return e.withChildren(newChildren);
     }
 
+    /**
+     * Rewrites the child of {@code cast}; if the child is unchanged the cast is returned as is.
+     * When both the cast target and the rewritten child are nested-prunable, the target type is
+     * recomputed with DataTypeAccessTree#pruneCastType from the child's original, pruned and
+     * target trees.
+     *
+     * <p>NOTE(review): the rebuilt Cast uses the two-argument constructor; confirm no other
+     * attribute of the original Cast needs to be carried over.
+     */
+    private Expression rewriteCast(Cast cast, boolean fillAccessPath) {
+        Expression originChild = cast.child(0);
+        Expression prunedChild = doRewriteExpression(originChild, fillAccessPath);
+        if (prunedChild == originChild) {
+            return cast;
+        }
+
+        DataType targetType = cast.getDataType();
+        boolean bothPrunable = targetType instanceof NestedColumnPrunable
+                && prunedChild.getDataType() instanceof NestedColumnPrunable;
+        if (!bothPrunable) {
+            return new Cast(prunedChild, targetType);
+        }
+        DataTypeAccessTree originTree = DataTypeAccessTree.of(originChild.getDataType(), TAccessPathType.DATA);
+        DataTypeAccessTree prunedTree = DataTypeAccessTree.of(prunedChild.getDataType(), TAccessPathType.DATA);
+        DataTypeAccessTree castTree = DataTypeAccessTree.of(targetType, TAccessPathType.DATA);
+        return new Cast(prunedChild, prunedTree.pruneCastType(originTree, castTree));
+    }
+
     private List<TColumnAccessPath> replaceIcebergAccessPathToId(
             List<TColumnAccessPath> originAccessPaths, SlotReference 
slotReference) {
         Column column = slotReference.getOriginalColumn().get();
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTVFRelation.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTVFRelation.java
index a434f81b231..c6d0241f393 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTVFRelation.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTVFRelation.java
@@ -47,20 +47,31 @@ public class LogicalTVFRelation extends LogicalRelation 
implements TVFRelation,
     private final TableValuedFunction function;
     private final ImmutableList<String> qualifier;
     private final ImmutableList<Slot> operativeSlots;
+    private final Optional<List<Slot>> cachedOutputs;
 
     public LogicalTVFRelation(RelationId id, TableValuedFunction function, ImmutableList<Slot> operativeSlots) {
         super(id, PlanType.LOGICAL_TVF_RELATION);
         this.operativeSlots = operativeSlots;
         this.function = function;
-        qualifier = ImmutableList.of(TableValuedFunctionIf.TVF_TABLE_PREFIX + function.getName());
+        this.qualifier = ImmutableList.of(TableValuedFunctionIf.TVF_TABLE_PREFIX + function.getName());
+        // no cached outputs: computeOutput() derives the slots from the function's table schema
+        this.cachedOutputs = Optional.empty();
     }
 
     public LogicalTVFRelation(RelationId id, TableValuedFunction function, ImmutableList<Slot> operativeSlots,
             Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties) {
+        // NOTE(review): passes Optional.empty(), so a copy made through this constructor loses any
+        // cached outputs of the instance it was copied from
+        this(id, function, operativeSlots, Optional.empty(), groupExpression, logicalProperties);
+    }
+
+    /**
+     * Full constructor.
+     *
+     * @param cachedOutputs when present, returned by computeOutput() instead of slots derived
+     *         from the function's table schema; must not be null (use Optional.empty())
+     */
+    public LogicalTVFRelation(RelationId id, TableValuedFunction function,
+            ImmutableList<Slot> operativeSlots,
+            Optional<List<Slot>> cachedOutputs,
+            Optional<GroupExpression> groupExpression,
+            Optional<LogicalProperties> logicalProperties) {
         super(id, PlanType.LOGICAL_TVF_RELATION, groupExpression, logicalProperties);
         this.operativeSlots = operativeSlots;
         this.function = function;
-        qualifier = ImmutableList.of(TableValuedFunctionIf.TVF_TABLE_PREFIX + function.getName());
+        this.cachedOutputs = Objects.requireNonNull(cachedOutputs, "cachedOutputs can not be null");
+        this.qualifier = ImmutableList.of(TableValuedFunctionIf.TVF_TABLE_PREFIX + function.getName());
     }
 
     @Override
@@ -120,6 +131,9 @@ public class LogicalTVFRelation extends LogicalRelation 
implements TVFRelation,
 
     @Override
     public List<Slot> computeOutput() {
+        if (cachedOutputs.isPresent()) {
+            return cachedOutputs.get();
+        }
         IdGenerator<ExprId> exprIdGenerator = 
StatementScopeIdGenerator.getExprIdGenerator();
         return function.getTable().getBaseSchema()
                 .stream()
@@ -138,4 +152,9 @@ public class LogicalTVFRelation extends LogicalRelation 
implements TVFRelation,
     public TableValuedFunction getFunction() {
         return function;
     }
+
+    /**
+     * Returns a copy of this relation whose output is fixed to {@code replaceSlots} instead of
+     * being recomputed from the TVF's table schema.
+     *
+     * <p>NOTE(review): the constructor without a cachedOutputs parameter passes Optional.empty(),
+     * so copies made through it do not keep these outputs; verify that the other with* methods of
+     * this class propagate cachedOutputs.
+     *
+     * @param replaceSlots the output slots to use; copied into an immutable list
+     * @return a new relation with cached outputs and no group expression or logical properties
+     */
+    public LogicalTVFRelation withCachedOutputs(List<Slot> replaceSlots) {
+        Objects.requireNonNull(replaceSlots, "replaceSlots can not be null");
+        return new LogicalTVFRelation(relationId, function, Utils.fastToImmutableList(operativeSlots),
+                Optional.of(Utils.fastToImmutableList(replaceSlots)), Optional.empty(), Optional.empty());
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java 
b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
index 27369e40e18..382782c393a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
@@ -518,7 +518,7 @@ public abstract class PlanNode extends TreeNode<PlanNode> 
implements PlanStats {
         return expBuilder.toString();
     }
 
-    private String getplanNodeExplainString(String prefix, TExplainLevel 
detailLevel) {
+    private String getPlanNodeExplainString(String prefix, TExplainLevel 
detailLevel) {
         StringBuilder expBuilder = new StringBuilder();
         expBuilder.append(getNodeExplainString(prefix, detailLevel));
         if (limit != -1) {
@@ -533,7 +533,7 @@ public abstract class PlanNode extends TreeNode<PlanNode> 
implements PlanStats {
     }
 
     public void getExplainStringMap(TExplainLevel detailLevel, Map<Integer, 
String> planNodeMap) {
-        planNodeMap.put(id.asInt(), getplanNodeExplainString("", detailLevel));
+        planNodeMap.put(id.asInt(), getPlanNodeExplainString("", detailLevel));
         for (int i = 0; i < children.size(); ++i) {
             children.get(i).getExplainStringMap(detailLevel, planNodeMap);
         }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java
index 40980803931..39d652ab510 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java
@@ -75,10 +75,66 @@ public class PruneNestedColumnTest extends 
TestWithFeService implements MemoPatt
                 + ">)\n"
                 + "properties ('replication_num'='1')");
 
+        createTable("create table tbl2(\n"
+                + "  id2 int,\n"
+                + "  s2 struct<\n"
+                + "    city2: string,\n"
+                + "    data2: array<map<\n"
+                + "      int,\n"
+                + "      struct<a2: int, b2: double>\n"
+                + "    >>\n"
+                + ">)\n"
+                + "properties ('replication_num'='1')");
+
         
connectContext.getSessionVariable().setDisableNereidsRules(RuleType.PRUNE_EMPTY_PARTITION.name());
         connectContext.getSessionVariable().enableNereidsTimeout = false;
     }
 
+    @Test
+    public void testMap() throws Exception {
+        assertColumn("select MAP_KEYS(struct_element(s, 'data')[0])[1] from 
tbl",
+                "struct<data:array<map<int,struct<a:int,b:double>>>>",
+                ImmutableList.of(path("s", "data", "*", "KEYS")),
+                ImmutableList.of()
+        );
+
+        assertColumn("select MAP_VALUES(struct_element(s, 'data')[0])[1] from 
tbl",
+                "struct<data:array<map<int,struct<a:int,b:double>>>>",
+                ImmutableList.of(path("s", "data", "*", "VALUES")),
+                ImmutableList.of()
+        );
+    }
+
+    @Test
+    public void testStruct() throws Throwable {
+        assertColumn("select struct_element(s, 1) from tbl",
+                "struct<city:text>",
+                ImmutableList.of(path("s", "city")),
+                ImmutableList.of()
+        );
+
+        assertColumn("select struct_element(map_values(struct_element(s, 
'data')[0])[0], 1) from tbl",
+                "struct<data:array<map<int,struct<a:int>>>>",
+                ImmutableList.of(path("s", "data", "*", "VALUES", "a")),
+                ImmutableList.of()
+        );
+    }
+
+    @Test
+    public void testPruneCast() throws Exception {
+        assertColumn("select struct_element(cast(s as 
struct<k:text,l:array<map<int,struct<x:int,y:double>>>>), 'k') from tbl",
+                "struct<city:text>",
+                ImmutableList.of(path("s", "city")),
+                ImmutableList.of()
+        );
+
+        // assertColumn("select struct_element(s, 'city'), 
struct_element(map_values(struct_element(s, 'data')[0])[0], 'b') from (select * 
from tbl union all select * from tbl2)t",
+        //         "struct<city:text>",
+        //         ImmutableList.of(path("s", "city")),
+        //         ImmutableList.of()
+        // );
+    }
+
     @Test
     public void testPruneArrayLambda() throws Exception {
         // map_values(struct_element(s, 'data').*)[0].a


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to