This is an automated email from the ASF dual-hosted git repository.

silun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 28ddefd12d [CALCITE-7431] RelTraitSet#getTrait seems to mishandle RelCompositeTrait
28ddefd12d is described below

commit 28ddefd12d88f9288e93fe485fb69ea601ed1e51
Author: Zhen Chen <[email protected]>
AuthorDate: Mon Mar 30 17:42:25 2026 +0800

    [CALCITE-7431] RelTraitSet#getTrait seems to mishandle RelCompositeTrait
---
 .../adapter/enumerable/EnumerableMergeUnion.java   | 14 +++--
 .../java/org/apache/calcite/plan/RelTraitSet.java  | 70 ++++++++++++++++++----
 .../java/org/apache/calcite/plan/RelTraitTest.java | 35 +++++++++++
 3 files changed, 102 insertions(+), 17 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnion.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnion.java
index d06fa31661..09d5a9e221 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnion.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnion.java
@@ -22,10 +22,8 @@
 import org.apache.calcite.linq4j.tree.Expressions;
 import org.apache.calcite.linq4j.tree.ParameterExpression;
 import org.apache.calcite.plan.RelOptCluster;
-import org.apache.calcite.plan.RelTrait;
 import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelCollation;
-import org.apache.calcite.rel.RelCollationTraitDef;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.util.BuiltInMethod;
 import org.apache.calcite.util.Pair;
@@ -49,13 +47,17 @@ protected EnumerableMergeUnion(RelOptCluster cluster, RelTraitSet traitSet,
       throw new IllegalArgumentException("EnumerableMergeUnion with no collation");
     }
     for (RelNode input : inputs) {
-      final RelTrait inputCollationTrait =
-          input.getTraitSet().getTrait(RelCollationTraitDef.INSTANCE);
+      // Use getCollations() rather than getTrait() so that we handle the case
+      // where the input's collation slot holds a RelCompositeTrait (multiple
+      // collations).  For each required collation, at least one of the input's
+      // collations must satisfy it.
+      final List<RelCollation> inputCollations = input.getTraitSet().getCollations();
       for (RelCollation collation : collations) {
-        if (inputCollationTrait == null || !inputCollationTrait.satisfies(collation)) {
+        boolean satisfied = inputCollations.stream().anyMatch(ic -> ic.satisfies(collation));
+        if (!satisfied) {
           throw new IllegalArgumentException("EnumerableMergeUnion input does "
               + "not satisfy collation. EnumerableMergeUnion collation: "
-              + collation + ". Input collation: " + inputCollationTrait + ". Input: "
+              + collation + ". Input collations: " + inputCollations + ". Input: "
               + input);
         }
       }
diff --git a/core/src/main/java/org/apache/calcite/plan/RelTraitSet.java b/core/src/main/java/org/apache/calcite/plan/RelTraitSet.java
index 6574742606..fff7eea4f9 100644
--- a/core/src/main/java/org/apache/calcite/plan/RelTraitSet.java
+++ b/core/src/main/java/org/apache/calcite/plan/RelTraitSet.java
@@ -87,7 +87,13 @@ public static RelTraitSet createEmpty() {
    *                                        {@link #size()} or less than 0.
    */
   public RelTrait getTrait(int index) {
-    return traits[index];
+    final RelTrait trait = traits[index];
+    if (trait instanceof RelCompositeTrait) {
+      throw new IllegalStateException("Trait index " + index
+          + " has multiple values in this trait set; "
+          + "use getTraits(RelTraitDef) instead of getTrait(RelTraitDef)");
+    }
+    return trait;
   }
 
   /**
@@ -110,21 +116,29 @@ public <E extends RelMultipleTrait> List<E> getTraits(int index) {
   }
 
   @Override public RelTrait get(int index) {
-    return getTrait(index);
+    return traits[index];
   }
 
   /**
    * Returns whether a given kind of trait is enabled.
    */
   public <T extends RelTrait> boolean isEnabled(RelTraitDef<T> traitDef) {
-    return getTrait(traitDef) != null;
+    return findIndex(traitDef) >= 0;
   }
 
   /**
    * Retrieves a RelTrait of the given type from the set.
    *
+   * <p>If this trait def supports multiple values (i.e. its trait implements
+   * {@link RelMultipleTrait}), the underlying slot may contain a
+   * {@link RelCompositeTrait} when more than one value is present. In that
+   * case this method throws {@link IllegalStateException}; use
+   * {@link #getTraits(RelTraitDef)} instead.
+   *
    * @param traitDef the type of RelTrait to retrieve
    * @return the RelTrait, or null if not found
+   * @throws IllegalStateException if the slot holds a composite (multiple)
+   *   trait; use {@link #getTraits(RelTraitDef)} in that case
    */
   public <T extends RelTrait> @Nullable T getTrait(RelTraitDef<T> traitDef) {
     int index = findIndex(traitDef);
@@ -375,17 +389,44 @@ public RelTraitSet getDefaultSansConvention() {
    * {@link RelDistributionTraitDef#INSTANCE}, or null if the
    * {@link RelDistributionTraitDef#INSTANCE} is not registered
    * in this traitSet.
+   *
+   * <p>If this trait set contains multiple distributions (a composite trait),
+   * this method throws {@link IllegalStateException}. Use
+   * {@link #getDistributions()} to handle both the single and multi-distribution
+   * cases uniformly.
    */
   @SuppressWarnings("unchecked")
   public <T extends RelDistribution> @Nullable T getDistribution() {
     return (@Nullable T) getTrait(RelDistributionTraitDef.INSTANCE);
   }
 
+  /**
+   * Returns {@link RelDistribution} traits defined by
+   * {@link RelDistributionTraitDef#INSTANCE}.
+   *
+   * <p>Returns an empty list when the trait def is not registered, a
+   * singleton list for the common single-distribution case, and a list with
+   * more than one element when a {@link RelCompositeTrait} is present.
+   */
+  @SuppressWarnings("unchecked")
+  public List<RelDistribution> getDistributions() {
+    int index = findIndex(RelDistributionTraitDef.INSTANCE);
+    if (index < 0) {
+      return ImmutableList.of();
+    }
+    return (List<RelDistribution>) (List<?>) getTraits(index);
+  }
+
   /**
    * Returns {@link RelCollation} trait defined by
    * {@link RelCollationTraitDef#INSTANCE}, or null if the
    * {@link RelCollationTraitDef#INSTANCE} is not registered
    * in this traitSet.
+   *
+   * <p>If this trait set contains multiple collations (a composite trait),
+   * this method throws {@link IllegalStateException}. Use
+   * {@link #getCollations()} to handle both the single and multi-collation
+   * cases uniformly.
    */
   @SuppressWarnings("unchecked")
   public <T extends RelCollation> @Nullable T getCollation() {
@@ -395,17 +436,19 @@ public RelTraitSet getDefaultSansConvention() {
   /**
    * Returns {@link RelCollation} traits defined by
    * {@link RelCollationTraitDef#INSTANCE}.
+   *
+   * <p>Returns an empty list when the trait def is not registered, a
+   * singleton list for the common single-collation case, and a list with
+   * more than one element when a {@link RelCompositeTrait} is present.
    */
   @SuppressWarnings("unchecked")
   public List<RelCollation> getCollations() {
-    RelCollation trait = getTrait(RelCollationTraitDef.INSTANCE);
-    if (trait == null) {
+    int index = findIndex(RelCollationTraitDef.INSTANCE);
+    if (index < 0) {
       return ImmutableList.of();
     }
-    if (trait instanceof RelCompositeTrait) {
-      return ((RelCompositeTrait<RelCollation>) trait).traitList();
-    }
-    return ImmutableList.of(trait);
+    // getTraits(int) already unwraps RelCompositeTrait transparently.
+    return (List<RelCollation>) (List<?>) getTraits(index);
   }
 
   /**
@@ -577,8 +620,13 @@ public boolean contains(RelTrait trait) {
    */
   public boolean containsIfApplicable(RelTrait trait) {
     // Note that '==' is sufficient, because trait should be canonized.
-    final RelTrait trait1 = getTrait(trait.getTraitDef());
-    return trait1 == null || trait1 == trait;
+    int index = findIndex(trait.getTraitDef());
+    if (index < 0) {
+      // TraitDef not registered in this set → treat as "not applicable" → true
+      return true;
+    }
+    final RelTrait stored = get(index);
+    return stored == trait;
   }
 
   /**
diff --git a/core/src/test/java/org/apache/calcite/plan/RelTraitTest.java b/core/src/test/java/org/apache/calcite/plan/RelTraitTest.java
index 6a9ddbaa49..2760b1e157 100644
--- a/core/src/test/java/org/apache/calcite/plan/RelTraitTest.java
+++ b/core/src/test/java/org/apache/calcite/plan/RelTraitTest.java
@@ -20,6 +20,10 @@
 import org.apache.calcite.rel.RelCollation;
 import org.apache.calcite.rel.RelCollationTraitDef;
 import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.test.RelBuilderTest;
+import org.apache.calcite.tools.FrameworkConfig;
+import org.apache.calcite.tools.RelBuilder;
 
 import com.google.common.collect.ImmutableList;
 
@@ -33,6 +37,7 @@
 import static org.hamcrest.Matchers.hasSize;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 import static java.lang.Integer.toHexString;
@@ -98,4 +103,34 @@ private void assertCanonical(String message, Supplier<List<RelCollation>> collat
     RelTraitSet traits3 = traits2.replace(RelCollations.of(1));
     assertFalse(traits3.equalsSansConvention(traits2));
   }
+
+  /** Test for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-7431">[CALCITE-7431]
+   * RelTraitSet#getTrait seems to mishandle RelCompositeTrait</a>. */
+  @Test void testRelCompositeTrait() {
+    // Build:  EMP -> Sort(MGR asc) -> Project(MGR, MGR as MGR2)
+    // The project maps both output columns 0 and 1 back to input column 3
+    // (MGR), so the planner derives two collations: [0 ASC] and [1 ASC], which
+    // are stored as a RelCompositeTrait in the output trait set.
+    final FrameworkConfig config = RelBuilderTest.config().build();
+    final RelBuilder b = RelBuilder.create(config);
+    final RelNode in = b
+        .scan("EMP")
+        .sort(3)                                          // MGR asc
+        .project(b.field(3), b.alias(b.field(3), "MGR2")) // MGR, MGR as MGR2
+        .build();
+
+    final RelTraitSet traitSet = in.getTraitSet();
+
+    final List<RelCollation> collations = traitSet.getCollations();
+    assertTrue(collations.size() >= 2,
+        "getCollations() should expose all composite collations");
+
+    assertThrows(IllegalStateException.class, traitSet::getCollation,
+        "getCollation() should throw when a RelCompositeTrait is present");
+
+    assertThrows(IllegalStateException.class,
+        () -> traitSet.getTrait(RelCollationTraitDef.INSTANCE),
+        "getTrait() should throw when a RelCompositeTrait is present");
+  }
 }

Reply via email to