This is an automated email from the ASF dual-hosted git repository.

ppa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/ignite-3.git


The following commit(s) were added to refs/heads/main by this push:
     new 53f1f5cedd IGNITE-20649 Sql. Added casting to the required type for EXCEPT/INTERSECT operations (#3479)
53f1f5cedd is described below

commit 53f1f5ceddda5a282d2ed55f0b7b5a381ce757bf
Author: Pavel Pereslegin <[email protected]>
AuthorDate: Wed Mar 27 16:08:39 2024 +0300

    IGNITE-20649 Sql. Added casting to the required type for EXCEPT/INTERSECT operations (#3479)
---
 .../ignite/internal/sql/engine/ItSetOpTest.java    | 40 +++++++++-
 .../sql/engine/rule/SetOpConverterRule.java        |  5 ++
 .../ignite/internal/sql/engine/util/Commons.java   | 61 ++++++++++++++++
 .../sql/engine/planner/SetOpPlannerTest.java       | 85 ++++++++++++++++++++++
 .../internal/sql/engine/util/CommonsTest.java      | 68 ++++++++++++++++-
 5 files changed, 257 insertions(+), 2 deletions(-)

diff --git a/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItSetOpTest.java b/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItSetOpTest.java
index 4ede763a3d..39d1212e4e 100644
--- a/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItSetOpTest.java
+++ b/modules/sql-engine/src/integrationTest/java/org/apache/ignite/internal/sql/engine/ItSetOpTest.java
@@ -20,6 +20,7 @@ package org.apache.ignite.internal.sql.engine;
 import static org.apache.ignite.internal.lang.IgniteStringFormatter.format;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 
+import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.function.Predicate;
@@ -28,12 +29,15 @@ import java.util.stream.StreamSupport;
 import org.apache.ignite.internal.sql.BaseSqlIntegrationTest;
 import org.apache.ignite.internal.sql.engine.hint.IgniteHint;
 import org.apache.ignite.internal.sql.engine.util.HintUtils;
+import org.apache.ignite.internal.sql.engine.util.MetadataMatcher;
 import org.apache.ignite.internal.sql.engine.util.QueryChecker;
+import org.apache.ignite.sql.ColumnType;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.CsvSource;
 import org.junit.jupiter.params.provider.EnumSource;
 import org.junit.jupiter.params.provider.MethodSource;
 
@@ -70,7 +74,22 @@ public class ItSetOpTest extends BaseSqlIntegrationTest {
                 {idx, "Igor1", 13d}
         });
 
-
+        // Creating tables with different numeric types for the "val" column.
+        {
+            sql("CREATE TABLE t1(id INTEGER PRIMARY KEY, val INTEGER)");
+            sql("INSERT INTO t1 VALUES(1, 1)");
+            sql("INSERT INTO t1 VALUES(2, 2)");
+            sql("INSERT INTO t1 VALUES(3, 3)");
+            sql("INSERT INTO t1 VALUES(4, 4)");
+
+            sql("CREATE TABLE t2(id INTEGER PRIMARY KEY, val DECIMAL(4,2))");
+            sql("INSERT INTO t2 VALUES(2, 2)");
+            sql("INSERT INTO t2 VALUES(4, 4)");
+
+            sql("CREATE TABLE t3(id INTEGER PRIMARY KEY, val NUMERIC(4,2))");
+            sql("INSERT INTO t3 VALUES(2, 2)");
+            sql("INSERT INTO t3 VALUES(3, 3)");
+        }
     }
 
     @ParameterizedTest
@@ -240,6 +259,25 @@ public class ItSetOpTest extends BaseSqlIntegrationTest {
                 .check();
     }
 
+    @ParameterizedTest(name = "{0}")
+    @CsvSource({"EXCEPT,1,1.00", "INTERSECT,2,2.00"})
+    public void testSetOpDifferentNumericTypes(String setOp, int expectId, 
String expectVal) {
+        String query = "SELECT id, val FROM t1 "
+                + setOp
+                + " SELECT id, val FROM t2 "
+                + setOp
+                + " SELECT id, val FROM t3 ";
+
+        assertQuery(query)
+                .returns(expectId, new BigDecimal(expectVal))
+                .columnMetadata(
+                        new 
MetadataMatcher().nullable(false).type(ColumnType.INT32),
+                        new 
MetadataMatcher().nullable(true).type(ColumnType.DECIMAL)
+                )
+                .ordered()
+                .check();
+    }
+
     /**
      * Test that set op node can be rewinded.
      */
diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/SetOpConverterRule.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/SetOpConverterRule.java
index ce41411701..c124b07ad4 100644
--- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/SetOpConverterRule.java
+++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/rule/SetOpConverterRule.java
@@ -38,6 +38,7 @@ import 
org.apache.ignite.internal.sql.engine.rel.set.IgniteMapMinus;
 import org.apache.ignite.internal.sql.engine.rel.set.IgniteReduceIntersect;
 import org.apache.ignite.internal.sql.engine.rel.set.IgniteReduceMinus;
 import org.apache.ignite.internal.sql.engine.trait.IgniteDistributions;
+import org.apache.ignite.internal.sql.engine.util.Commons;
 
 /**
  * Set op (MINUS, INTERSECT) converter rule.
@@ -71,6 +72,8 @@ public class SetOpConverterRule {
             RelTraitSet outTrait = 
cluster.traitSetOf(IgniteConvention.INSTANCE).replace(IgniteDistributions.single());
             List<RelNode> inputs = Util.transform(setOp.getInputs(), rel -> 
convert(rel, inTrait));
 
+            inputs = Commons.castInputsToLeastRestrictiveTypeIfNeeded(inputs, 
cluster, inTrait);
+
             return createNode(cluster, outTrait, inputs, setOp.all);
         }
     }
@@ -122,6 +125,8 @@ public class SetOpConverterRule {
             RelTraitSet outTrait = 
cluster.traitSetOf(IgniteConvention.INSTANCE);
             List<RelNode> inputs = Util.transform(setOp.getInputs(), rel -> 
convert(rel, inTrait));
 
+            inputs = Commons.castInputsToLeastRestrictiveTypeIfNeeded(inputs, 
cluster, inTrait);
+
             RelNode map = createMapNode(cluster, outTrait, inputs, setOp.all);
 
             return createReduceNode(
diff --git a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/Commons.java b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/Commons.java
index d4e5890639..fa1ac16a13 100644
--- a/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/Commons.java
+++ b/modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/util/Commons.java
@@ -61,10 +61,13 @@ import org.apache.calcite.plan.ConventionTraitDef;
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelTraitDef;
+import org.apache.calcite.plan.RelTraitSet;
 import org.apache.calcite.rel.RelCollationTraitDef;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.hint.HintStrategyTable;
+import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.SqlNode;
 import org.apache.calcite.sql.type.SqlTypeCoercionRule;
@@ -94,6 +97,7 @@ import 
org.apache.ignite.internal.sql.engine.metadata.cost.IgniteCostFactory;
 import org.apache.ignite.internal.sql.engine.prepare.IgniteConvertletTable;
 import org.apache.ignite.internal.sql.engine.prepare.IgniteTypeCoercion;
 import org.apache.ignite.internal.sql.engine.prepare.PlanningContext;
+import org.apache.ignite.internal.sql.engine.rel.IgniteProject;
 import 
org.apache.ignite.internal.sql.engine.rel.logical.IgniteLogicalTableScan;
 import org.apache.ignite.internal.sql.engine.sql.IgniteSqlCommitTransaction;
 import org.apache.ignite.internal.sql.engine.sql.IgniteSqlConformance;
@@ -829,4 +833,61 @@ public final class Commons {
                 return null;
         }
     }
+
+    /**
+     * Computes the least restrictive type among the provided inputs and
+     * adds a projection where a cast to the inferred type is needed.
+     *
+     * @param inputs Input relational expressions.
+     * @param cluster Cluster.
+     * @param traits Traits of relational expression.
+     * @return Converted inputs.
+     */
+    public static List<RelNode> 
castInputsToLeastRestrictiveTypeIfNeeded(List<RelNode> inputs, RelOptCluster 
cluster, RelTraitSet traits) {
+        List<RelDataType> inputRowTypes = inputs.stream()
+                .map(RelNode::getRowType)
+                .collect(Collectors.toList());
+
+        // Output type of a set operator is equal to 
leastRestrictive(inputTypes) (see SetOp::deriveRowType)
+
+        RelDataType resultType = 
cluster.getTypeFactory().leastRestrictive(inputRowTypes);
+        if (resultType == null) {
+            throw new IllegalArgumentException("Cannot compute compatible row 
type for arguments to set op: " + inputRowTypes);
+        }
+
+        // Check output type of each input, if input's type does not match the 
result type,
+        // then add a projection with casts for non-matching fields.
+
+        RexBuilder rexBuilder = cluster.getRexBuilder();
+        List<RelNode> actualInputs = new ArrayList<>(inputs.size());
+
+        for (RelNode input : inputs) {
+            RelDataType inputRowType = input.getRowType();
+
+            if (resultType.equalsSansFieldNames(inputRowType)) {
+                actualInputs.add(input);
+
+                continue;
+            }
+
+            List<RexNode> exprs = new 
ArrayList<>(inputRowType.getFieldCount());
+
+            for (int i = 0; i < resultType.getFieldCount(); i++) {
+                RelDataType fieldType = 
inputRowType.getFieldList().get(i).getType();
+                RelDataType outFieldType = 
resultType.getFieldList().get(i).getType();
+                RexNode ref = rexBuilder.makeInputRef(input, i);
+
+                if (fieldType.equals(outFieldType)) {
+                    exprs.add(ref);
+                } else {
+                    RexNode expr = rexBuilder.makeCast(outFieldType, ref, 
true, false);
+                    exprs.add(expr);
+                }
+            }
+
+            actualInputs.add(new IgniteProject(cluster, traits, input, exprs, 
resultType));
+        }
+
+        return actualInputs;
+    }
 }
diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/SetOpPlannerTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/SetOpPlannerTest.java
index 1bc5e7d3c7..25f4f5c17d 100644
--- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/SetOpPlannerTest.java
+++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/planner/SetOpPlannerTest.java
@@ -17,11 +17,16 @@
 
 package org.apache.ignite.internal.sql.engine.planner;
 
+import java.util.Arrays;
 import java.util.List;
+import java.util.function.Predicate;
 import java.util.function.UnaryOperator;
 import org.apache.calcite.rel.RelDistribution.Type;
+import org.apache.calcite.rel.RelNode;
+import org.apache.ignite.internal.sql.engine.framework.TestBuilders;
 import 
org.apache.ignite.internal.sql.engine.framework.TestBuilders.TableBuilder;
 import org.apache.ignite.internal.sql.engine.rel.IgniteExchange;
+import org.apache.ignite.internal.sql.engine.rel.IgniteProject;
 import org.apache.ignite.internal.sql.engine.rel.IgniteTrimExchange;
 import org.apache.ignite.internal.sql.engine.rel.set.IgniteColocatedIntersect;
 import org.apache.ignite.internal.sql.engine.rel.set.IgniteColocatedMinus;
@@ -694,6 +699,86 @@ public class SetOpPlannerTest extends AbstractPlannerTest {
         );
     }
 
+    @ParameterizedTest
+    @EnumSource
+    public void testSetOpResultsInLeastRestrictiveType(SetOp setOp) throws 
Exception {
+        IgniteSchema publicSchema = createSchema(
+                TestBuilders.table()
+                        .name("TABLE1")
+                        .addColumn("C1", NativeTypes.INT32)
+                        .addColumn("C2", NativeTypes.STRING)
+                        .distribution(someAffinity())
+                        .build(),
+
+                TestBuilders.table()
+                        .name("TABLE2")
+                        .addColumn("C1", NativeTypes.DOUBLE)
+                        .addColumn("C2", NativeTypes.STRING)
+                        .distribution(someAffinity())
+                        .build(),
+
+                TestBuilders.table()
+                        .name("TABLE3")
+                        .addColumn("C1", NativeTypes.INT64)
+                        .addColumn("C2", NativeTypes.STRING)
+                        .distribution(someAffinity())
+                        .build()
+        );
+
+        String sql = "SELECT * FROM table1 "
+                + setOp
+                + " SELECT * FROM table2 "
+                + setOp
+                + " SELECT * FROM table3 ";
+
+        assertPlan(sql, publicSchema, nodeOrAnyChild(isInstanceOf(setOp.map)
+                        .and(input(0, projectFromTable("TABLE1", 
"CAST($0):DOUBLE", "$1")))
+                        .and(input(1, isTableScan("TABLE2")))
+                        .and(input(2, projectFromTable("TABLE3", 
"CAST($0):DOUBLE", "$1")))
+                )
+        );
+    }
+
+    @ParameterizedTest
+    @EnumSource
+    public void testSetOpDifferentNullability(SetOp setOp) throws Exception {
+        IgniteSchema publicSchema = createSchema(
+                TestBuilders.table()
+                        .name("TABLE1")
+                        .addColumn("C1", NativeTypes.INT32, false)
+                        .addColumn("C2", NativeTypes.STRING)
+                        .distribution(someAffinity())
+                        .build(),
+
+                TestBuilders.table()
+                        .name("TABLE2")
+                        .addColumn("C1", NativeTypes.INT32, true)
+                        .addColumn("C2", NativeTypes.STRING)
+                        .distribution(someAffinity())
+                        .build()
+        );
+
+        String sql = "SELECT * FROM table1 "
+                + setOp
+                + " SELECT * FROM table2";
+
+        assertPlan(sql, publicSchema, nodeOrAnyChild(isInstanceOf(setOp.map)
+                        .and(input(0, projectFromTable("TABLE1", 
"CAST($0):INTEGER", "$1")))
+                        .and(input(1, isTableScan("TABLE2")))
+                )
+        );
+    }
+
+    private Predicate<? extends RelNode> projectFromTable(String tableName, 
String... exprs) {
+        return isInstanceOf(IgniteProject.class)
+                .and(projection -> {
+                    String actualProjStr = projection.getProjects().toString();
+                    String expectedProjStr = Arrays.asList(exprs).toString();
+                    return actualProjStr.equals(expectedProjStr);
+                })
+                .and(hasChildThat(isTableScan(tableName)));
+    }
+
     private String setOp(SetOp setOp) {
         return setOp.name() + ' ';
     }
diff --git a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/util/CommonsTest.java b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/util/CommonsTest.java
index 9e13644e81..4357cca929 100644
--- a/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/util/CommonsTest.java
+++ b/modules/sql-engine/src/test/java/org/apache/ignite/internal/sql/engine/util/CommonsTest.java
@@ -18,19 +18,31 @@
 package org.apache.ignite.internal.sql.engine.util;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
+import static org.junit.jupiter.api.Assertions.assertSame;
+import static org.mockito.Mockito.when;
 
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.ImmutableIntList;
 import org.apache.calcite.util.mapping.Mapping;
 import org.apache.calcite.util.mapping.Mappings;
+import org.apache.ignite.internal.sql.engine.rel.IgniteProject;
+import org.apache.ignite.internal.sql.engine.type.IgniteTypeFactory;
+import org.apache.ignite.internal.testframework.BaseIgniteAbstractTest;
 import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
 
 /**
  * Tests for utility functions defined in {@link Commons}.
  */
-public class CommonsTest {
+public class CommonsTest extends BaseIgniteAbstractTest {
 
     @Test
     public void testTrimmingMapping() {
@@ -74,6 +86,60 @@ public class CommonsTest {
         assertEquals(vals, Commons.arrayToMap(new Object[]{1, null, "3"}));
     }
 
+    @Test
+    public void testCastInputsToLeastRestrictiveTypeIfNeeded() {
+        IgniteTypeFactory tf = Commons.typeFactory();
+
+        RelDataType row1 = new RelDataTypeFactory.Builder(tf)
+                .add("C1", tf.createSqlType(SqlTypeName.INTEGER))
+                .add("C2", tf.createSqlType(SqlTypeName.INTEGER))
+                .build();
+
+        RelDataType row2 = new RelDataTypeFactory.Builder(tf)
+                .add("C1", 
tf.createTypeWithNullability(tf.createSqlType(SqlTypeName.INTEGER), true))
+                .add("C2", tf.createSqlType(SqlTypeName.REAL))
+                .build();
+
+        RelDataType row3 = new RelDataTypeFactory.Builder(tf)
+                .add("C1", tf.createSqlType(SqlTypeName.INTEGER))
+                .add("C2", tf.createSqlType(SqlTypeName.REAL))
+                .build();
+
+        RelDataType row4 = new RelDataTypeFactory.Builder(tf)
+                .add("C1", 
tf.createTypeWithNullability(tf.createSqlType(SqlTypeName.INTEGER), true))
+                .add("C2", tf.createSqlType(SqlTypeName.INTEGER))
+                .build();
+
+        RelNode node1 = Mockito.mock(RelNode.class);
+        when(node1.getRowType()).thenReturn(row1);
+
+        RelNode node2 = Mockito.mock(RelNode.class);
+        when(node2.getRowType()).thenReturn(row2);
+
+        RelNode node3 = Mockito.mock(RelNode.class);
+        when(node3.getRowType()).thenReturn(row3);
+
+        RelNode node4 = Mockito.mock(RelNode.class);
+        when(node4.getRowType()).thenReturn(row4);
+
+        List<RelNode> relNodes = 
Commons.castInputsToLeastRestrictiveTypeIfNeeded(List.of(node1, node2, node3, 
node4), Commons.cluster(),
+                Commons.cluster().traitSet());
+
+        RelDataType lt = tf.leastRestrictive(List.of(row1, row2, row3, row4));
+
+        IgniteProject project1 = assertInstanceOf(IgniteProject.class, 
relNodes.get(0), "node1");
+        assertEquals(lt, project1.getRowType(), "Invalid types in projection 
for node1");
+
+        // Node 2 has the same type as leastRestrictive(row1, row2)
+        assertSame(node2, relNodes.get(1), "Invalid types in projection for 
node2");
+
+        IgniteProject project3 = assertInstanceOf(IgniteProject.class, 
relNodes.get(2), "node2");
+        assertEquals(lt, project3.getRowType(), "Invalid types in projection 
for node3");
+
+        IgniteProject project4 = assertInstanceOf(IgniteProject.class, 
relNodes.get(3), "node4");
+        assertEquals(lt, project4.getRowType(), "Invalid types in projection 
for node4");
+    }
+
     private static void expectMapped(Mapping mapping, ImmutableBitSet bitSet, 
ImmutableBitSet expected) {
         assertEquals(expected, Mappings.apply(mapping, bitSet), "direct 
mapping");
 

Reply via email to