This is an automated email from the ASF dual-hosted git repository.
huajianlan pushed a commit to branch nested_column_prune
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/nested_column_prune by this
push:
new f9079ca71c8 fix some bugs of prune nested columns
f9079ca71c8 is described below
commit f9079ca71c8024eb9e0e29f5ce042204c0f4a433
Author: 924060929 <[email protected]>
AuthorDate: Mon Oct 27 21:32:47 2025 +0800
fix some bugs of prune nested columns
---
.../rules/expression/ExpressionNormalization.java | 2 +
.../rules/expression/ExpressionRuleType.java | 1 +
.../expression/rules/NormalizeStructElement.java | 66 +++++++++++++++
.../rewrite/AccessPathExpressionCollector.java | 32 +++++++-
.../rules/rewrite/AccessPathPlanCollector.java | 15 ++++
.../nereids/rules/rewrite/NestedColumnPruning.java | 84 ++++++++++++++++++-
.../nereids/rules/rewrite/SlotTypeReplacer.java | 96 +++++++++++++++-------
.../trees/plans/logical/LogicalTVFRelation.java | 23 +++++-
.../java/org/apache/doris/planner/PlanNode.java | 4 +-
.../rules/rewrite/PruneNestedColumnTest.java | 56 +++++++++++++
10 files changed, 342 insertions(+), 37 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
index a4cdb54a5cd..c7a33ce680e 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
@@ -29,6 +29,7 @@ import
org.apache.doris.nereids.rules.expression.rules.LogToLn;
import org.apache.doris.nereids.rules.expression.rules.MedianConvert;
import org.apache.doris.nereids.rules.expression.rules.MergeDateTrunc;
import
org.apache.doris.nereids.rules.expression.rules.NormalizeBinaryPredicatesRule;
+import org.apache.doris.nereids.rules.expression.rules.NormalizeStructElement;
import
org.apache.doris.nereids.rules.expression.rules.SimplifyArithmeticComparisonRule;
import org.apache.doris.nereids.rules.expression.rules.SimplifyArithmeticRule;
import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule;
@@ -67,6 +68,7 @@ public class ExpressionNormalization extends
ExpressionRewrite {
SimplifyArithmeticComparisonRule.INSTANCE,
ConvertAggStateCast.INSTANCE,
MergeDateTrunc.INSTANCE,
+ NormalizeStructElement.INSTANCE,
CheckCast.INSTANCE
)
);
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
index 823dbd49b93..e3359a1f621 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
@@ -58,6 +58,7 @@ public enum ExpressionRuleType {
SIMPLIFY_RANGE,
SIMPLIFY_SELF_COMPARISON,
SUPPORT_JAVA_DATE_FORMATTER,
+ NORMALIZE_STRUCT_ELEMENT,
TOPN_TO_MAX;
public int type() {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NormalizeStructElement.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NormalizeStructElement.java
new file mode 100644
index 00000000000..dc50921bc1e
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NormalizeStructElement.java
@@ -0,0 +1,66 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rules;
+
+import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
+import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
+import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import
org.apache.doris.nereids.trees.expressions.functions.scalar.StructElement;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
+import org.apache.doris.nereids.types.StructField;
+import org.apache.doris.nereids.types.StructType;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+
+/**
+ * Rewrite {@code struct_element(s, <integer literal>)} to the equivalent
+ * {@code struct_element(s, '<field name>')}.
+ *
+ * <p>For example, given column s of type {@code struct<a: int, b: int>}, {@code struct_element(s, 2)}
+ * is rewritten to {@code struct_element(s, 'b')}. Field positions are 1-based; a position outside
+ * {@code [1, number of fields]} (including 0) is left untouched.
+ */
+public class NormalizeStructElement implements ExpressionPatternRuleFactory {
+    public static final NormalizeStructElement INSTANCE = new NormalizeStructElement();
+
+    @Override
+    public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
+        return ImmutableList.of(
+                matchesType(StructElement.class).then(NormalizeStructElement::normalize)
+                        .toRule(ExpressionRuleType.NORMALIZE_STRUCT_ELEMENT)
+        );
+    }
+
+    /**
+     * Replace a literal integer field position with the corresponding field name.
+     *
+     * @param structElement the struct_element call to normalize
+     * @return the rewritten call, or {@code structElement} itself when the position is not an
+     *         integer literal, the first argument is not a struct, or the position is out of range
+     */
+    private static StructElement normalize(StructElement structElement) {
+        Expression struct = structElement.getArgument(0);
+        Expression field = structElement.getArgument(1);
+        if (!(field instanceof IntegerLikeLiteral) || !(struct.getDataType() instanceof StructType)) {
+            return structElement;
+        }
+        int fieldIndex = ((Number) ((IntegerLikeLiteral) field).getValue()).intValue();
+        List<StructField> fields = ((StructType) struct.getDataType()).getFields();
+        // positions are 1-based: accepting 0 here would evaluate fields.get(-1) and throw
+        if (fieldIndex < 1 || fieldIndex > fields.size()) {
+            return structElement;
+        }
+        return structElement.withChildren(
+                ImmutableList.of(
+                        structElement.child(0),
+                        new StringLiteral(fields.get(fieldIndex - 1).getName())
+                )
+        );
+    }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathExpressionCollector.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathExpressionCollector.java
index 57fe9fefcc5..8bbb1a8c25e 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathExpressionCollector.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathExpressionCollector.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.StatementContext;
import
org.apache.doris.nereids.rules.rewrite.AccessPathExpressionCollector.CollectorContext;
+import
org.apache.doris.nereids.rules.rewrite.NestedColumnPruning.DataTypeAccessTree;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
import
org.apache.doris.nereids.trees.expressions.ArrayItemReference.ArrayItemSlot;
@@ -50,6 +51,8 @@ import
org.apache.doris.nereids.trees.expressions.literal.Literal;
import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.NestedColumnPrunable;
+import org.apache.doris.nereids.types.StructField;
+import org.apache.doris.nereids.types.StructType;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.thrift.TAccessPathType;
@@ -57,6 +60,7 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
+import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
@@ -128,7 +132,23 @@ public class AccessPathExpressionCollector extends
DefaultExpressionVisitor<Void
@Override
public Void visitCast(Cast cast, CollectorContext context) {
- return cast.child(0).accept(this, context);
+ // The path collected above the cast is expressed in the cast target type. When both
+ // sides are nested types, translate it into the child's field names (struct fields
+ // are paired by position) and keep collecting the child with the translated path.
+ if (!context.accessPathBuilder.isEmpty()
+ && cast.getDataType() instanceof NestedColumnPrunable
+ && cast.child().getDataType() instanceof NestedColumnPrunable)
{
+
+ DataTypeAccessTree castTree =
DataTypeAccessTree.of(cast.getDataType(), TAccessPathType.DATA);
+ DataTypeAccessTree originTree =
DataTypeAccessTree.of(cast.child().getDataType(), TAccessPathType.DATA);
+
+ List<String> replacePath = new
ArrayList<>(context.accessPathBuilder.getPathList());
+ if (originTree.replacePathByAnotherTree(castTree, replacePath, 0))
{
+ CollectorContext castContext = new
CollectorContext(context.statementContext, context.bottomFilter);
+ castContext.accessPathBuilder.accessPath.addAll(replacePath);
+ return continueCollectAccessPath(cast.child(), castContext);
+ }
+ }
+ // Empty path, a non-nested side, or an untranslatable path: restart collection
+ // below the cast with a fresh, empty path.
+ return cast.child(0).accept(this,
+ new CollectorContext(context.statementContext,
context.bottomFilter)
+ );
}
// array element at
@@ -158,6 +178,15 @@ public class AccessPathExpressionCollector extends
DefaultExpressionVisitor<Void
DataType fieldType = fieldName.getDataType();
if (fieldName.isLiteral() && (fieldType.isIntegerLikeType() ||
fieldType.isStringLikeType())) {
+ if (fieldType.isIntegerLikeType()) {
+ int fieldIndex = ((Number) ((Literal)
fieldName).getValue()).intValue();
+ List<StructField> fields = ((StructType)
struct.getDataType()).getFields();
+ if (fieldIndex >= 1 && fieldIndex <= fields.size()) {
+ String realFieldName = fields.get(fieldIndex -
1).getName();
+ context.accessPathBuilder.addPrefix(realFieldName);
+ return continueCollectAccessPath(struct, context);
+ }
+ }
context.accessPathBuilder.addPrefix(((Literal)
fieldName).getStringValue());
return continueCollectAccessPath(struct, context);
}
@@ -170,6 +199,7 @@ public class AccessPathExpressionCollector extends
DefaultExpressionVisitor<Void
@Override
public Void visitMapKeys(MapKeys mapKeys, CollectorContext context) {
+ // map_keys returns an array of the keys, so a path collected above it refers to that
+ // array rather than to the map: restart from a fresh path and access the KEYS subtree.
+ context = new CollectorContext(context.statementContext,
context.bottomFilter);
context.accessPathBuilder.addPrefix("KEYS");
return continueCollectAccessPath(mapKeys.getArgument(0), context);
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java
index 21406c2ea13..514f7bb1e8c 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java
@@ -28,6 +28,7 @@ import
org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalFileScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.types.NestedColumnPrunable;
@@ -137,6 +138,20 @@ public class AccessPathPlanCollector extends
DefaultPlanVisitor<Void, StatementC
return null;
}
+    /**
+     * For every nested-type output slot of a table-valued function relation, copy the access
+     * paths collected for its ExprId into scanSlotToAccessPaths. Slots without collected paths
+     * are skipped.
+     */
+    @Override
+    public Void visitLogicalTVFRelation(LogicalTVFRelation tvfRelation, StatementContext context) {
+        for (Slot outputSlot : tvfRelation.getOutput()) {
+            if (outputSlot.getDataType() instanceof NestedColumnPrunable) {
+                Collection<CollectAccessPathResult> collected
+                        = allSlotToAccessPaths.get(outputSlot.getExprId().asInt());
+                if (!collected.isEmpty()) {
+                    scanSlotToAccessPaths.put(outputSlot, new ArrayList<>(collected));
+                }
+            }
+        }
+        return null;
+    }
+
@Override
public Void visit(Plan plan, StatementContext context) {
collectByExpressions(plan, context);
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NestedColumnPruning.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NestedColumnPruning.java
index 0f99b576cae..00f79b08eef 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NestedColumnPruning.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NestedColumnPruning.java
@@ -234,7 +234,8 @@ public class NestedColumnPruning implements CustomRewriter {
return result;
}
- private static class DataTypeAccessTree {
+ /** DataTypeAccessTree */
+ public static class DataTypeAccessTree {
private DataType type;
private boolean isRoot;
private boolean accessPartialChild;
@@ -252,6 +253,78 @@ public class NestedColumnPruning implements CustomRewriter
{
this.pathType = pathType;
}
+ /**
+ * Build the cast target type that matches this pruned source type.
+ *
+ * <p>This tree describes the pruned source type, {@code origin} the unpruned source type and
+ * {@code cast} the original cast target type. The result keeps only the struct fields still
+ * present in this tree, renamed to the cast target's field names, and recurses into array
+ * elements and map keys/values.
+ *
+ * <p>NOTE(review): struct fields are paired by the iteration order of {@code origin.children}
+ * and {@code cast.children}; this assumes both iterate in field declaration order and that
+ * {@code cast} has at least as many fields as {@code origin} (otherwise castNames.get throws).
+ */
+ public DataType pruneCastType(DataTypeAccessTree origin,
DataTypeAccessTree cast) {
+ if (type instanceof StructType) {
+ // source field name -> cast field name, paired by position
+ Map<String, String> nameMapping = new LinkedHashMap<>();
+ List<String> castNames = new
ArrayList<>(cast.children.keySet());
+ int i = 0;
+ for (String s : origin.children.keySet()) {
+ nameMapping.put(s, castNames.get(i++));
+ }
+ // rebuild the pruned struct with cast field names and recursively pruned field types
+ List<StructField> mappingFields = new ArrayList<>();
+ StructType originPrunedStructType = (StructType) type;
+ for (Entry<String, DataTypeAccessTree> kv :
children.entrySet()) {
+ String originName = kv.getKey();
+ String mappingName = nameMapping.getOrDefault(originName,
originName);
+ DataTypeAccessTree originPrunedTree = kv.getValue();
+ DataType mappingType = originPrunedTree.pruneCastType(
+ origin.children.get(originName),
+ cast.children.get(mappingName)
+ );
+ StructField field =
originPrunedStructType.getField(originName);
+ mappingFields.add(
+ new StructField(mappingName, mappingType,
field.isNullable(), field.getComment())
+ );
+ }
+ return new StructType(mappingFields);
+ } else if (type instanceof ArrayType) {
+ // arrays: recurse through the first (presumably only) child, the element
+ return ArrayType.of(
+ children.values().iterator().next().pruneCastType(
+ origin.children.values().iterator().next(),
+ cast.children.values().iterator().next()
+ )
+ );
+ } else if (type instanceof MapType) {
+ // maps: recurse into KEYS and VALUES separately
+ return MapType.of(
+
children.get("KEYS").pruneCastType(origin.children.get("KEYS"),
cast.children.get("KEYS")),
+
children.get("VALUES").pruneCastType(origin.children.get("VALUES"),
cast.children.get("VALUES"))
+ );
+ } else {
+ // any other type is returned as-is
+ return type;
+ }
+ }
+
+        /**
+         * Translate an access path expressed in the cast target type ({@code cast}) into the
+         * equivalent path in this (source) type, rewriting struct field names in place.
+         *
+         * <p>Struct fields are paired by position: the i-th field of the cast type maps to the
+         * i-th field of this type. Map children are paired by name (KEYS with KEYS, VALUES with
+         * VALUES).
+         *
+         * @param cast access tree of the cast target type; may be null when no matching child exists
+         * @param path the access path, rewritten in place from {@code index} onwards
+         * @param index the position in {@code path} handled at this tree level
+         * @return true if the remaining path was fully translated; false if it could not be
+         *         mapped, in which case the partially rewritten {@code path} must be discarded
+         */
+        public boolean replacePathByAnotherTree(DataTypeAccessTree cast, List<String> path, int index) {
+            if (index >= path.size()) {
+                return true;
+            }
+            if (cast == null) {
+                return false;
+            }
+            String pathElement = path.get(index);
+            if (cast.type instanceof StructType) {
+                if (!(type instanceof StructType)) {
+                    return false;
+                }
+                List<StructField> castFields = ((StructType) cast.type).getFields();
+                List<StructField> originFields = ((StructType) type).getFields();
+                for (int i = 0; i < castFields.size() && i < originFields.size(); i++) {
+                    String castFieldName = castFields.get(i).getName();
+                    if (castFieldName.equalsIgnoreCase(pathElement)) {
+                        String originFieldName = originFields.get(i).getName();
+                        DataTypeAccessTree originChild = children.get(originFieldName);
+                        if (originChild == null) {
+                            return false;
+                        }
+                        path.set(index, originFieldName);
+                        // look the cast child up by its declared name: the path element only
+                        // matched case-insensitively and could miss, yielding a null tree
+                        return originChild.replacePathByAnotherTree(
+                                cast.children.get(castFieldName), path, index + 1);
+                    }
+                }
+            } else if (cast.type instanceof ArrayType) {
+                if (children.isEmpty() || cast.children.isEmpty()) {
+                    return false;
+                }
+                return children.values().iterator().next().replacePathByAnotherTree(
+                        cast.children.values().iterator().next(), path, index + 1);
+            } else if (cast.type instanceof MapType) {
+                // pair the same child on both sides; previously the source side always used
+                // VALUES, so a KEYS path was matched against the wrong subtree
+                DataTypeAccessTree originChild = children.get(pathElement);
+                if (originChild == null) {
+                    return false;
+                }
+                return originChild.replacePathByAnotherTree(
+                        cast.children.get(pathElement), path, index + 1);
+            }
+            return false;
+        }
+
+ /** setAccessByPath */
public void setAccessByPath(List<String> path, int accessIndex,
TAccessPathType pathType) {
if (accessIndex >= path.size()) {
accessAll = true;
@@ -267,7 +340,12 @@ public class NestedColumnPruning implements CustomRewriter
{
if (this.type.isStructType()) {
String fieldName = path.get(accessIndex);
DataTypeAccessTree child = children.get(fieldName);
- child.setAccessByPath(path, accessIndex + 1, pathType);
+ if (child != null) {
+ child.setAccessByPath(path, accessIndex + 1, pathType);
+ } else {
+ // can not find the field
+ accessAll = true;
+ }
return;
} else if (this.type.isArrayType()) {
DataTypeAccessTree child = children.get("*");
@@ -314,6 +392,7 @@ public class NestedColumnPruning implements CustomRewriter {
return root;
}
+ /** of */
public static DataTypeAccessTree of(DataType type, TAccessPathType
pathType) {
DataTypeAccessTree root = new DataTypeAccessTree(type, pathType);
if (type instanceof StructType) {
@@ -330,6 +409,7 @@ public class NestedColumnPruning implements CustomRewriter {
return root;
}
+ /** pruneDataType */
public Optional<DataType> pruneDataType() {
if (isRoot) {
return children.values().iterator().next().pruneDataType();
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SlotTypeReplacer.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SlotTypeReplacer.java
index 043e9c7bdc1..459c7362758 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SlotTypeReplacer.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SlotTypeReplacer.java
@@ -22,7 +22,9 @@ import org.apache.doris.catalog.Column;
import org.apache.doris.common.Pair;
import org.apache.doris.datasource.iceberg.IcebergExternalTable;
import org.apache.doris.nereids.properties.OrderKey;
+import
org.apache.doris.nereids.rules.rewrite.NestedColumnPruning.DataTypeAccessTree;
import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
+import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
@@ -49,8 +51,8 @@ import
org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalResultSink;
-import org.apache.doris.nereids.trees.plans.logical.LogicalSink;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
+import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
@@ -58,6 +60,7 @@ import
org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.MapType;
+import org.apache.doris.nereids.types.NestedColumnPrunable;
import org.apache.doris.nereids.types.StructType;
import org.apache.doris.nereids.util.MoreFieldsThread;
import org.apache.doris.thrift.TAccessPathType;
@@ -402,6 +405,16 @@ public class SlotTypeReplacer extends
DefaultPlanRewriter<Void> {
return fileScan;
}
+    /**
+     * Run replaceExpressions over the output slots of a table-valued function relation (same
+     * arguments as the OLAP scan visitor) and, when anything was replaced, fix the relation's
+     * output to the replaced slots via withCachedOutputs.
+     */
+    @Override
+    public Plan visitLogicalTVFRelation(LogicalTVFRelation tvfRelation, Void context) {
+        Pair<Boolean, List<Slot>> replaced = replaceExpressions(tvfRelation.getOutput(), false, true);
+        return replaced.first ? tvfRelation.withCachedOutputs(replaced.second) : tvfRelation;
+    }
+
@Override
public Plan visitLogicalOlapScan(LogicalOlapScan olapScan, Void context) {
Pair<Boolean, List<Slot>> replaced =
replaceExpressions(olapScan.getOutput(), false, true);
@@ -445,12 +458,6 @@ public class SlotTypeReplacer extends
DefaultPlanRewriter<Void> {
return logicalResultSink;
}
- @Override
- public Plan visitLogicalSink(LogicalSink<? extends Plan> logicalSink, Void
context) {
- // do nothing
- return logicalSink;
- }
-
private Pair<Boolean, List<OrderExpression>>
replaceOrderExpressions(List<OrderExpression> orderExpressions) {
ImmutableList.Builder<OrderExpression> newOrderKeys
=
ImmutableList.builderWithExpectedSize(orderExpressions.size());
@@ -525,28 +532,39 @@ public class SlotTypeReplacer extends
DefaultPlanRewriter<Void> {
return Pair.of(changed, (C) newExprs.build());
}
- private Expression replaceSlot(Expression expr, boolean fillAccessPath) {
- return MoreFieldsThread.keepFunctionSignature(false, () -> {
- return expr.rewriteUp(e -> {
- if (e instanceof Lambda) {
- return rewriteLambda((Lambda) e, fillAccessPath);
- } else if (e instanceof SlotReference) {
- AccessPathInfo accessPathInfo =
replacedDataTypes.get(((SlotReference) e).getExprId().asInt());
- if (accessPathInfo != null) {
- SlotReference newSlot
- = (SlotReference) ((SlotReference)
e).withNullableAndDataType(
- e.nullable(), accessPathInfo.getPrunedType());
- if (fillAccessPath) {
- newSlot = newSlot.withAccessPaths(
- accessPathInfo.getAllAccessPaths(),
accessPathInfo.getPredicateAccessPaths()
- );
- }
- return newSlot;
- }
+ /**
+ * Rewrite {@code e}, replacing every slot that has an entry in replacedDataTypes with its
+ * pruned type. The rewrite runs inside MoreFieldsThread.keepFunctionSignature(false, ...).
+ */
+ private Expression replaceSlot(Expression e, boolean fillAccessPath) {
+ return MoreFieldsThread.keepFunctionSignature(false,
+ () -> doRewriteExpression(e, fillAccessPath)
+ );
+ }
+
+ /**
+ * Top-down rewrite used by replaceSlot. Lambdas and casts are delegated to rewriteLambda and
+ * rewriteCast, which rewrite their own children. A slot reference found in replacedDataTypes
+ * is re-typed with its pruned type and, when {@code fillAccessPath} is set, annotated with its
+ * access paths. Any other expression is rebuilt only when some child changed.
+ */
+ private Expression doRewriteExpression(Expression e, boolean
fillAccessPath) {
+ if (e instanceof Lambda) {
+ return rewriteLambda((Lambda) e, fillAccessPath);
+ } else if (e instanceof Cast) {
+ return rewriteCast((Cast) e, fillAccessPath);
+ } else if (e instanceof SlotReference) {
+ AccessPathInfo accessPathInfo =
replacedDataTypes.get(((SlotReference) e).getExprId().asInt());
+ if (accessPathInfo != null) {
+ SlotReference newSlot = (SlotReference) ((SlotReference)
e).withNullableAndDataType(
+ e.nullable(), accessPathInfo.getPrunedType());
+ if (fillAccessPath) {
+ newSlot = newSlot.withAccessPaths(
+ accessPathInfo.getAllAccessPaths(),
accessPathInfo.getPredicateAccessPaths()
+ );
}
- return e;
- });
- });
+ return newSlot;
+ }
+ }
+
+ // Generic case, including slots without pruning info: recurse into the children and
+ // rebuild the expression only if at least one child changed.
+ ImmutableList.Builder<Expression> newChildren =
ImmutableList.builderWithExpectedSize(e.arity());
+ boolean changed = false;
+ for (Expression child : e.children()) {
+ Expression newChild = doRewriteExpression(child, fillAccessPath);
+ changed |= child != newChild;
+ newChildren.add(newChild);
+ }
+ return changed ? e.withChildren(newChildren.build()) : e;
}
private Expression rewriteLambda(Lambda e, boolean fillAccessPath) {
@@ -555,7 +573,7 @@ public class SlotTypeReplacer extends
DefaultPlanRewriter<Void> {
for (int i = 0; i < e.arity(); i++) {
Expression child = e.child(i);
if (child instanceof ArrayItemReference) {
- Expression newRef =
child.withChildren(replaceSlot(child.child(0), fillAccessPath));
+ Expression newRef =
child.withChildren(doRewriteExpression(child.child(0), fillAccessPath));
replacedDataTypes.put(((ArrayItemReference)
child).getExprId().asInt(),
new AccessPathInfo(newRef.getDataType(), null, null));
newChildren[i] = newRef;
@@ -567,13 +585,31 @@ public class SlotTypeReplacer extends
DefaultPlanRewriter<Void> {
for (int i = 0; i < newChildren.length; i++) {
Expression child = newChildren[i];
if (!(child instanceof ArrayItemReference)) {
- newChildren[i] = replaceSlot(child, fillAccessPath);
+ newChildren[i] = doRewriteExpression(child, fillAccessPath);
}
}
return e.withChildren(newChildren);
}
+    /**
+     * Rewrite the child of a cast. When both the cast target type and the rewritten child type
+     * are nested types, also shrink the target type to match the pruned child, keeping the
+     * target's field names. The cast is returned unchanged when its child did not change.
+     *
+     * <p>NOTE(review): the result is built with {@code new Cast(child, type)} rather than copied
+     * from {@code cast}; confirm no other Cast attribute needs to be carried over.
+     */
+    private Expression rewriteCast(Cast cast, boolean fillAccessPath) {
+        Expression originChild = cast.child(0);
+        Expression prunedChild = doRewriteExpression(originChild, fillAccessPath);
+        if (prunedChild == originChild) {
+            return cast;
+        }
+        DataType targetType = cast.getDataType();
+        boolean bothNested = targetType instanceof NestedColumnPrunable
+                && prunedChild.getDataType() instanceof NestedColumnPrunable;
+        if (!bothNested) {
+            return new Cast(prunedChild, targetType);
+        }
+        DataTypeAccessTree originTree = DataTypeAccessTree.of(originChild.getDataType(), TAccessPathType.DATA);
+        DataTypeAccessTree prunedTree = DataTypeAccessTree.of(prunedChild.getDataType(), TAccessPathType.DATA);
+        DataTypeAccessTree castTree = DataTypeAccessTree.of(targetType, TAccessPathType.DATA);
+        return new Cast(prunedChild, prunedTree.pruneCastType(originTree, castTree));
+    }
+
private List<TColumnAccessPath> replaceIcebergAccessPathToId(
List<TColumnAccessPath> originAccessPaths, SlotReference
slotReference) {
Column column = slotReference.getOriginalColumn().get();
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTVFRelation.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTVFRelation.java
index a434f81b231..c6d0241f393 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTVFRelation.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTVFRelation.java
@@ -47,20 +47,31 @@ public class LogicalTVFRelation extends LogicalRelation
implements TVFRelation,
private final TableValuedFunction function;
private final ImmutableList<String> qualifier;
private final ImmutableList<Slot> operativeSlots;
+ // When present, computeOutput() returns these slots instead of building new ones from the
+ // table schema; set through withCachedOutputs.
+ private final Optional<List<Slot>> cachedOutputs;
public LogicalTVFRelation(RelationId id, TableValuedFunction function,
ImmutableList<Slot> operativeSlots) {
super(id, PlanType.LOGICAL_TVF_RELATION);
this.operativeSlots = operativeSlots;
this.function = function;
- qualifier = ImmutableList.of(TableValuedFunctionIf.TVF_TABLE_PREFIX +
function.getName());
+ this.qualifier =
ImmutableList.of(TableValuedFunctionIf.TVF_TABLE_PREFIX + function.getName());
+ this.cachedOutputs = Optional.empty();
}
+ // NOTE(review): this overload always passes Optional.empty() for cachedOutputs; any with*/copy
+ // method that still calls it drops outputs fixed by withCachedOutputs. Confirm those call
+ // sites, and whether equals/hashCode need to consider cachedOutputs.
public LogicalTVFRelation(RelationId id, TableValuedFunction function,
ImmutableList<Slot> operativeSlots,
Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties) {
+ this(id, function, operativeSlots, Optional.empty(), groupExpression,
logicalProperties);
+ }
+
+ /**
+ * Full constructor.
+ *
+ * @param cachedOutputs output slots returned by computeOutput() instead of slots derived
+ * from the table schema; must not be null (pass Optional.empty() for none)
+ */
+ public LogicalTVFRelation(RelationId id, TableValuedFunction function,
+ ImmutableList<Slot> operativeSlots,
+ Optional<List<Slot>> cachedOutputs,
+ Optional<GroupExpression> groupExpression,
+ Optional<LogicalProperties> logicalProperties) {
super(id, PlanType.LOGICAL_TVF_RELATION, groupExpression,
logicalProperties);
this.operativeSlots = operativeSlots;
this.function = function;
- qualifier = ImmutableList.of(TableValuedFunctionIf.TVF_TABLE_PREFIX +
function.getName());
+ this.cachedOutputs = Objects.requireNonNull(cachedOutputs,
"cachedOutputs can not be null");
+ this.qualifier =
ImmutableList.of(TableValuedFunctionIf.TVF_TABLE_PREFIX + function.getName());
}
@Override
@@ -120,6 +131,9 @@ public class LogicalTVFRelation extends LogicalRelation
implements TVFRelation,
@Override
public List<Slot> computeOutput() {
+ if (cachedOutputs.isPresent()) {
+ return cachedOutputs.get();
+ }
IdGenerator<ExprId> exprIdGenerator =
StatementScopeIdGenerator.getExprIdGenerator();
return function.getTable().getBaseSchema()
.stream()
@@ -138,4 +152,9 @@ public class LogicalTVFRelation extends LogicalRelation
implements TVFRelation,
public TableValuedFunction getFunction() {
return function;
}
+
+    /**
+     * Return a copy of this relation whose output is fixed to {@code replaceSlots} instead of
+     * being derived from the table schema. The group expression and logical properties are not
+     * carried over (both are passed as empty).
+     *
+     * @param replaceSlots the output slots to use; an immutable copy is stored
+     * @return the new relation
+     */
+    public LogicalTVFRelation withCachedOutputs(List<Slot> replaceSlots) {
+        // copy so that later changes to the caller's list cannot alter this plan node's output
+        List<Slot> outputs = Utils.fastToImmutableList(replaceSlots);
+        return new LogicalTVFRelation(relationId, function, Utils.fastToImmutableList(operativeSlots),
+                Optional.of(outputs), Optional.empty(), Optional.empty());
+    }
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
index 27369e40e18..382782c393a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/PlanNode.java
@@ -518,7 +518,7 @@ public abstract class PlanNode extends TreeNode<PlanNode>
implements PlanStats {
return expBuilder.toString();
}
- private String getplanNodeExplainString(String prefix, TExplainLevel
detailLevel) {
+ private String getPlanNodeExplainString(String prefix, TExplainLevel
detailLevel) {
StringBuilder expBuilder = new StringBuilder();
expBuilder.append(getNodeExplainString(prefix, detailLevel));
if (limit != -1) {
@@ -533,7 +533,7 @@ public abstract class PlanNode extends TreeNode<PlanNode>
implements PlanStats {
}
public void getExplainStringMap(TExplainLevel detailLevel, Map<Integer,
String> planNodeMap) {
- planNodeMap.put(id.asInt(), getplanNodeExplainString("", detailLevel));
+ planNodeMap.put(id.asInt(), getPlanNodeExplainString("", detailLevel));
for (int i = 0; i < children.size(); ++i) {
children.get(i).getExplainStringMap(detailLevel, planNodeMap);
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java
index 40980803931..39d652ab510 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java
@@ -75,10 +75,66 @@ public class PruneNestedColumnTest extends
TestWithFeService implements MemoPatt
+ ">)\n"
+ "properties ('replication_num'='1')");
+ createTable("create table tbl2(\n"
+ + " id2 int,\n"
+ + " s2 struct<\n"
+ + " city2: string,\n"
+ + " data2: array<map<\n"
+ + " int,\n"
+ + " struct<a2: int, b2: double>\n"
+ + " >>\n"
+ + ">)\n"
+ + "properties ('replication_num'='1')");
+
connectContext.getSessionVariable().setDisableNereidsRules(RuleType.PRUNE_EMPTY_PARTITION.name());
connectContext.getSessionVariable().enableNereidsTimeout = false;
}
+ @Test
+ public void testMap() throws Exception {
+ // map_keys(...)[1]: the index into the keys array is not kept; only the KEYS path is
+ assertColumn("select MAP_KEYS(struct_element(s, 'data')[0])[1] from
tbl",
+ "struct<data:array<map<int,struct<a:int,b:double>>>>",
+ ImmutableList.of(path("s", "data", "*", "KEYS")),
+ ImmutableList.of()
+ );
+
+ // map_values(...)[1]: the VALUES subtree is accessed
+ assertColumn("select MAP_VALUES(struct_element(s, 'data')[0])[1] from
tbl",
+ "struct<data:array<map<int,struct<a:int,b:double>>>>",
+ ImmutableList.of(path("s", "data", "*", "VALUES")),
+ ImmutableList.of()
+ );
+ }
+
+ @Test
+ public void testStruct() throws Throwable {
+ // 1-based field position: position 1 of s is 'city'
+ assertColumn("select struct_element(s, 1) from tbl",
+ "struct<city:text>",
+ ImmutableList.of(path("s", "city")),
+ ImmutableList.of()
+ );
+
+ // position 1 of the nested struct<a:int, b:double> is 'a'
+ assertColumn("select struct_element(map_values(struct_element(s,
'data')[0])[0], 1) from tbl",
+ "struct<data:array<map<int,struct<a:int>>>>",
+ ImmutableList.of(path("s", "data", "*", "VALUES", "a")),
+ ImmutableList.of()
+ );
+ }
+
+ @Test
+ public void testPruneCast() throws Exception {
+ // the cast renames s's fields by position (city -> k, data -> l), so reading k only
+ // needs s.city
+ assertColumn("select struct_element(cast(s as
struct<k:text,l:array<map<int,struct<x:int,y:double>>>>), 'k') from tbl",
+ "struct<city:text>",
+ ImmutableList.of(path("s", "city")),
+ ImmutableList.of()
+ );
+
+ // NOTE(review): disabled case: union all of tbl and tbl2, whose struct fields have different
+ // names. Enable it once its expected type and paths are settled, or remove it.
+ // assertColumn("select struct_element(s, 'city'),
struct_element(map_values(struct_element(s, 'data')[0])[0], 'b') from (select *
from tbl union all select * from tbl2)t",
+ // "struct<city:text>",
+ // ImmutableList.of(path("s", "city")),
+ // ImmutableList.of()
+ // );
+ }
+
@Test
public void testPruneArrayLambda() throws Exception {
// map_values(struct_element(s, 'data').*)[0].a
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]