This is an automated email from the ASF dual-hosted git repository.

krisztiankasa pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hive.git


The following commit(s) were added to refs/heads/master by this push:
     new 067d598a113 HIVE-26652: HiveSortPullUpConstantsRule produces an invalid plan when pulling up constants for nullable fields (Alessandro Solimando, reviewed by Krisztian Kasa)
067d598a113 is described below

commit 067d598a1131709172af3efd8b95ea370c0408c2
Author: Alessandro Solimando <[email protected]>
AuthorDate: Mon Oct 24 16:34:58 2022 +0200

    HIVE-26652: HiveSortPullUpConstantsRule produces an invalid plan when pulling up constants for nullable fields (Alessandro Solimando, reviewed by Krisztian Kasa)
---
 .../calcite/rules/HiveSortPullUpConstantsRule.java |  16 +-
 .../rules/TestHivePointLookupOptimizerRule.java    | 216 +++++++++------------
 .../TestHiveSortExchangePullUpConstantsRule.java   | 110 +++++++++++
 .../TestHiveSortLimitPullUpConstantsRule.java      | 109 +++++++++++
 .../rules/TestHiveUnionPullUpConstantsRule.java    |  94 +++------
 .../ql/optimizer/calcite/rules/TestRuleHelper.java | 110 +++++++++++
 6 files changed, 456 insertions(+), 199 deletions(-)

diff --git 
a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveSortPullUpConstantsRule.java
 
b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveSortPullUpConstantsRule.java
index 5cf2eb6a6c6..51f53cd0ead 100644
--- 
a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveSortPullUpConstantsRule.java
+++ 
b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveSortPullUpConstantsRule.java
@@ -50,7 +50,7 @@ import com.google.common.collect.ImmutableList;
 
 /**
  * Planner rule that pulls up constant keys through a SortLimit or 
SortExchange operator.
- *
+ * <p>
  * This rule is only applied on SortLimit operators that are not the root
  * of the plan tree. This is done because the interaction of this rule
  * with the AST conversion may cause some optimizations to not kick in
@@ -67,7 +67,7 @@ public final class HiveSortPullUpConstantsRule {
   private static final class HiveSortLimitPullUpConstantsRule
       extends HiveSortPullUpConstantsRuleBase<HiveSortLimit> {
 
-    protected HiveSortLimitPullUpConstantsRule() {
+    private HiveSortLimitPullUpConstantsRule() {
       super(HiveSortLimit.class);
     }
 
@@ -87,7 +87,7 @@ public final class HiveSortPullUpConstantsRule {
   private static final class HiveSortExchangePullUpConstantsRule
       extends HiveSortPullUpConstantsRuleBase<HiveSortExchange> {
 
-    protected HiveSortExchangePullUpConstantsRule() {
+    private HiveSortExchangePullUpConstantsRule() {
       super(HiveSortExchange.class);
     }
 
@@ -154,10 +154,14 @@ public final class HiveSortPullUpConstantsRule {
         RexNode expr = rexBuilder.makeInputRef(sortNode.getInput(), i);
         RelDataTypeField field = fields.get(i);
         if (constants.containsKey(expr)) {
-          topChildExprs.add(constants.get(expr));
+          if (constants.get(expr).getType().equals(field.getType())) {
+            topChildExprs.add(constants.get(expr));
+          } else {
+            topChildExprs.add(rexBuilder.makeCast(field.getType(), 
constants.get(expr), true));
+          }
           topChildExprsFields.add(field.getName());
         } else {
-          newChildExprs.add(Pair.<RexNode, String>of(expr, field.getName()));
+          newChildExprs.add(Pair.of(expr, field.getName()));
           topChildExprs.add(expr);
           topChildExprsFields.add(field.getName());
         }
@@ -199,7 +203,7 @@ public final class HiveSortPullUpConstantsRule {
           // It is a constant, we can ignore it
           continue;
         }
-        fieldCollations.add(fc.copy(target));
+        fieldCollations.add(fc.withFieldIndex(target));
       }
       return fieldCollations;
     }
diff --git 
a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/rules/TestHivePointLookupOptimizerRule.java
 
b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/rules/TestHivePointLookupOptimizerRule.java
index 67ba43784a6..ba695f44af2 100644
--- 
a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/rules/TestHivePointLookupOptimizerRule.java
+++ 
b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/rules/TestHivePointLookupOptimizerRule.java
@@ -18,35 +18,29 @@
 
 package org.apache.hadoop.hive.ql.optimizer.calcite.rules;
 
-import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
-import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.AbstractRelOptPlanner;
 import org.apache.calcite.plan.RelOptSchema;
-import org.apache.calcite.plan.hep.HepPlanner;
-import org.apache.calcite.plan.hep.HepProgramBuilder;
 import org.apache.calcite.rel.RelNode;
-import org.apache.calcite.rel.logical.LogicalTableScan;
-import org.apache.calcite.rel.type.RelDataType;
-import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.hadoop.hive.ql.metadata.Table;
-import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
 import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveFilter;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
-import org.mockito.ArgumentMatchers;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
 
 import java.util.Collections;
 
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.buildPlanner;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.buildRelBuilder;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.and;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.eq;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.or;
+
 import static org.junit.Assert.assertEquals;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.doReturn;
-import static org.mockito.Mockito.lenient;
 
 @RunWith(MockitoJUnitRunner.class)
 public class TestHivePointLookupOptimizerRule {
@@ -58,8 +52,8 @@ public class TestHivePointLookupOptimizerRule {
   @Mock
   Table hiveTableMDMock;
 
-  private HepPlanner planner;
-  private RelBuilder builder;
+  private AbstractRelOptPlanner planner;
+  private RelBuilder relBuilder;
 
   @SuppressWarnings("unused")
   private static class MyRecord {
@@ -71,53 +65,25 @@ public class TestHivePointLookupOptimizerRule {
 
   @Before
   public void before() {
-    HepProgramBuilder programBuilder = new HepProgramBuilder();
-    programBuilder.addRuleInstance(new 
HivePointLookupOptimizerRule.FilterCondition(2));
-
-    planner = new HepPlanner(programBuilder.build());
-
-    JavaTypeFactoryImpl typeFactory = new JavaTypeFactoryImpl();
-    RexBuilder rexBuilder = new RexBuilder(typeFactory);
-    final RelOptCluster optCluster = RelOptCluster.create(planner, rexBuilder);
-    RelDataType rowTypeMock = typeFactory.createStructType(MyRecord.class);
-    doReturn(rowTypeMock).when(tableMock).getRowType();
-    LogicalTableScan tableScan = LogicalTableScan.create(optCluster, 
tableMock, Collections.emptyList());
-    doReturn(tableScan).when(tableMock).toRel(ArgumentMatchers.any());
-    doReturn(tableMock).when(schemaMock).getTableForMember(any());
-    lenient().doReturn(hiveTableMDMock).when(tableMock).getHiveTableMD();
-
-    builder = HiveRelFactories.HIVE_BUILDER.create(optCluster, schemaMock);
-
-  }
-
-  public RexNode or(RexNode... args) {
-    return builder.call(SqlStdOperatorTable.OR, args);
-  }
-
-  public RexNode and(RexNode... args) {
-    return builder.call(SqlStdOperatorTable.AND, args);
-  }
-
-  public RexNode eq(String field, Number value) {
-    return builder.call(SqlStdOperatorTable.EQUALS,
-        builder.field(field), builder.literal(value));
+    planner = buildPlanner(Collections.singletonList(new 
HivePointLookupOptimizerRule.FilterCondition(2)));
+    relBuilder = buildRelBuilder(planner, schemaMock, tableMock, 
hiveTableMDMock, MyRecord.class);
   }
 
   @Test
   public void testSimpleCase() {
 
     // @formatter:off
-    final RelNode basePlan = builder
+    final RelNode basePlan = relBuilder
           .scan("t")
           .filter(
-              and(
-                or(
-                    eq("f1",1),
-                    eq("f1",2)
+              and(relBuilder,
+                or(relBuilder,
+                    eq(relBuilder, "f1",1),
+                    eq(relBuilder, "f1",2)
                     ),
-                or(
-                    eq("f2",3),
-                    eq("f2",4)
+                or(relBuilder,
+                    eq(relBuilder, "f2",3),
+                    eq(relBuilder, "f2",4)
                     )
                 )
               )
@@ -136,17 +102,17 @@ public class TestHivePointLookupOptimizerRule {
   public void testInExprsMergedSingleOverlap() {
 
     // @formatter:off
-    final RelNode basePlan = builder
+    final RelNode basePlan = relBuilder
         .scan("t")
         .filter(
-            and(
-                or(
-                    eq("f1",1),
-                    eq("f1",2)
+            and(relBuilder,
+                or(relBuilder,
+                    eq(relBuilder,"f1",1),
+                    eq(relBuilder,"f1",2)
                 ),
-                or(
-                    eq("f1",1),
-                    eq("f1",3)
+                or(relBuilder,
+                    eq(relBuilder,"f1",1),
+                    eq(relBuilder,"f1",3)
                 )
             )
         )
@@ -165,19 +131,19 @@ public class TestHivePointLookupOptimizerRule {
   public void testInExprsAndEqualsMerged() {
 
     // @formatter:off
-    final RelNode basePlan = builder
+    final RelNode basePlan = relBuilder
         .scan("t")
         .filter(
-            and(
-                or(
-                    eq("f1",1),
-                    eq("f1",2)
+            and(relBuilder,
+                or(relBuilder,
+                    eq(relBuilder,"f1",1),
+                    eq(relBuilder,"f1",2)
                 ),
-                or(
-                    eq("f1",1),
-                    eq("f1",3)
+                or(relBuilder,
+                    eq(relBuilder,"f1",1),
+                    eq(relBuilder,"f1",3)
                 ),
-                eq("f1",1)
+                eq(relBuilder,"f1",1)
             )
         )
         .build();
@@ -195,21 +161,21 @@ public class TestHivePointLookupOptimizerRule {
   public void testInExprsMergedMultipleOverlap() {
 
     // @formatter:off
-    final RelNode basePlan = builder
+    final RelNode basePlan = relBuilder
         .scan("t")
         .filter(
-            and(
-                or(
-                    eq("f1",1),
-                    eq("f1",2),
-                    eq("f1",4),
-                    eq("f1",3)
+            and(relBuilder,
+                or(relBuilder,
+                    eq(relBuilder,"f1",1),
+                    eq(relBuilder,"f1",2),
+                    eq(relBuilder,"f1",4),
+                    eq(relBuilder,"f1",3)
                 ),
-                or(
-                    eq("f1",5),
-                    eq("f1",1),
-                    eq("f1",2),
-                    eq("f1",3)
+                or(relBuilder,
+                    eq(relBuilder,"f1",5),
+                    eq(relBuilder,"f1",1),
+                    eq(relBuilder,"f1",2),
+                    eq(relBuilder,"f1",3)
                 )
             )
         )
@@ -228,18 +194,18 @@ public class TestHivePointLookupOptimizerRule {
   public void testCaseWithConstantsOfDifferentType() {
 
     // @formatter:off
-    final RelNode basePlan = builder
+    final RelNode basePlan = relBuilder
         .scan("t")
         .filter(
-            and(
-                or(
-                    eq("f1",1),
-                    eq("f1",2)
+            and(relBuilder,
+                or(relBuilder,
+                    eq(relBuilder,"f1",1),
+                    eq(relBuilder,"f1",2)
                 ),
-                eq("f1", 1.0),
-                or(
-                    eq("f4",3.0),
-                    eq("f4",4.1)
+                eq(relBuilder,"f1", 1.0),
+                or(relBuilder,
+                    eq(relBuilder,"f4",3.0),
+                    eq(relBuilder,"f4",4.1)
                 )
             )
         )
@@ -261,20 +227,20 @@ public class TestHivePointLookupOptimizerRule {
   public void testCaseInAndEqualsWithConstantsOfDifferentType() {
 
     // @formatter:off
-    final RelNode basePlan = builder
+    final RelNode basePlan = relBuilder
         .scan("t")
         .filter(
-            and(
-                or(
-                    eq("f1",1),
-                    eq("f1",2)
+            and(relBuilder,
+                or(relBuilder,
+                    eq(relBuilder,"f1",1),
+                    eq(relBuilder,"f1",2)
                 ),
-                eq("f1",1),
-                or(
-                    eq("f4",3.0),
-                    eq("f4",4.1)
+                eq(relBuilder,"f1",1),
+                or(relBuilder,
+                    eq(relBuilder,"f4",3.0),
+                    eq(relBuilder,"f4",4.1)
                 ),
-                eq("f4",4.1)
+                eq(relBuilder,"f4",4.1)
             )
         )
         .build();
@@ -292,12 +258,14 @@ public class TestHivePointLookupOptimizerRule {
   public void testSimpleStructCase() {
 
     // @formatter:off
-    final RelNode basePlan = builder
+    final RelNode basePlan = relBuilder
           .scan("t")
           .filter(
-              or(
-                  and( eq("f1",1),eq("f2",1)),
-                  and( eq("f1",2),eq("f2",2))
+              or(relBuilder,
+                  and(relBuilder,
+                      eq(relBuilder,"f1",1), eq(relBuilder,"f2",1)),
+                  and(relBuilder,
+                      eq(relBuilder,"f1",2), eq(relBuilder,"f2",2))
                   )
               )
           .build();
@@ -316,13 +284,13 @@ public class TestHivePointLookupOptimizerRule {
   public void testObscuredSimple() {
 
     // @formatter:off
-    final RelNode basePlan = builder
+    final RelNode basePlan = relBuilder
           .scan("t")
           .filter(
-              or(
-                  eq("f2",99),
-                  eq("f1",1),
-                  eq("f1",2)
+              or(relBuilder,
+                  eq(relBuilder,"f2",99),
+                  eq(relBuilder,"f1",1),
+                  eq(relBuilder,"f1",2)
                   )
               )
           .build();
@@ -342,23 +310,27 @@ public class TestHivePointLookupOptimizerRule {
   public void testRecursionIsNotObstructed() {
 
     // @formatter:off
-    final RelNode basePlan = builder
+    final RelNode basePlan = relBuilder
           .scan("t")
           .filter(
-              and(
-                or(
-                    eq("f1",1),
-                    eq("f1",2)
+              and(relBuilder,
+                or(relBuilder,
+                    eq(relBuilder,"f1",1),
+                    eq(relBuilder,"f1",2)
                     )
                 ,
-                or(
-                    and(
-                        or(eq("f2",1),eq("f2",2)),
-                        or(eq("f3",1),eq("f3",2))
+                or(relBuilder,
+                    and(relBuilder,
+                        or(relBuilder,
+                            eq(relBuilder,"f2",1), eq(relBuilder,"f2",2)),
+                        or(relBuilder,
+                            eq(relBuilder,"f3",1), eq(relBuilder,"f3",2))
                         ),
-                    and(
-                        or(eq("f2",3),eq("f2",4)),
-                        or(eq("f3",3),eq("f3",4))
+                    and(relBuilder,
+                        or(relBuilder,
+                            eq(relBuilder,"f2",3),eq(relBuilder,"f2",4)),
+                        or(relBuilder,
+                            eq(relBuilder,"f3",3),eq(relBuilder,"f3",4))
                         )
                 )
               ))
diff --git 
a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/rules/TestHiveSortExchangePullUpConstantsRule.java
 
b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/rules/TestHiveSortExchangePullUpConstantsRule.java
new file mode 100644
index 00000000000..30681a7ec40
--- /dev/null
+++ 
b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/rules/TestHiveSortExchangePullUpConstantsRule.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.optimizer.calcite.rules;
+
+import org.apache.calcite.plan.AbstractRelOptPlanner;
+import org.apache.calcite.plan.RelOptSchema;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.RelDistributions;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.hadoop.hive.ql.metadata.Table;
+import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import java.util.Arrays;
+
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.assertPlans;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.buildPlanner;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.buildRelBuilder;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.MyRecord;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.MyRecordWithNullableField;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.eq;
+
+@RunWith(MockitoJUnitRunner.class)
+public class TestHiveSortExchangePullUpConstantsRule {
+
+  @Mock
+  private RelOptSchema schemaMock;
+  @Mock
+  RelOptHiveTable tableMock;
+  @Mock
+  Table hiveTableMDMock;
+
+  private AbstractRelOptPlanner planner;
+  private RelBuilder relBuilder;
+
+  public void before(Class<?> clazz) {
+    planner = buildPlanner(Arrays.asList(
+        HiveSortPullUpConstantsRule.SORT_EXCHANGE_INSTANCE, 
HiveProjectMergeRule.INSTANCE));
+    relBuilder = buildRelBuilder(planner, schemaMock, tableMock, 
hiveTableMDMock, clazz);
+  }
+
+  @org.junit.Test
+  public void testNonNullableFields() {
+    before(MyRecord.class);
+
+    final RelNode plan = relBuilder
+        .scan("t")
+        .filter(eq(relBuilder, "f1",1))
+        .sortExchange(RelDistributions.ROUND_ROBIN_DISTRIBUTED, 
RelCollations.of(0))
+        .project(relBuilder.field("f1"), relBuilder.field("f2"))
+        .build();
+
+    String prePlan = "HiveProject(f1=[$0], f2=[$1])\n"
+                   + "  HiveSortExchange(distribution=[rr], collation=[[0]])\n"
+                   + "    HiveFilter(condition=[=($0, 1)])\n"
+                   + "      LogicalTableScan(table=[[]])\n";
+
+    String postPlan = "HiveProject(f1=[1], f2=[$0])\n"
+                    + "  HiveSortExchange(distribution=[rr], collation=[[]])\n"
+                    + "    HiveProject(f2=[$1], f3=[$2])\n"
+                    + "      HiveFilter(condition=[=($0, 1)])\n"
+                    + "        LogicalTableScan(table=[[]])\n";
+
+    assertPlans(planner, plan, prePlan, postPlan);
+  }
+
+  @org.junit.Test
+  public void testNullableFields() {
+    before(MyRecordWithNullableField.class);
+
+    final RelNode plan = relBuilder
+        .scan("t")
+        .filter(eq(relBuilder, "f1",1))
+        .sortExchange(RelDistributions.ROUND_ROBIN_DISTRIBUTED, 
RelCollations.of(0))
+        .project(relBuilder.field("f1"), relBuilder.field("f2"))
+        .build();
+
+    String prePlan = "HiveProject(f1=[$0], f2=[$1])\n"
+                   + "  HiveSortExchange(distribution=[rr], collation=[[0]])\n"
+                   + "    HiveFilter(condition=[=($0, 1)])\n"
+                   + "      LogicalTableScan(table=[[]])\n";
+
+    String postPlan = "HiveProject(f1=[CAST(1):JavaType(class 
java.lang.Integer)], f2=[$0])\n"
+                    + "  HiveSortExchange(distribution=[rr], collation=[[]])\n"
+                    + "    HiveProject(f2=[$1], f3=[$2])\n"
+                    + "      HiveFilter(condition=[=($0, 1)])\n"
+                    + "        LogicalTableScan(table=[[]])\n";
+
+    assertPlans(planner, plan, prePlan, postPlan);
+  }
+}
diff --git 
a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/rules/TestHiveSortLimitPullUpConstantsRule.java
 
b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/rules/TestHiveSortLimitPullUpConstantsRule.java
new file mode 100644
index 00000000000..1f3288ce33d
--- /dev/null
+++ 
b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/rules/TestHiveSortLimitPullUpConstantsRule.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.optimizer.calcite.rules;
+
+import org.apache.calcite.plan.AbstractRelOptPlanner;
+import org.apache.calcite.plan.RelOptSchema;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.hadoop.hive.ql.metadata.Table;
+import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import java.util.Arrays;
+
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.assertPlans;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.buildPlanner;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.buildRelBuilder;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.MyRecord;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.MyRecordWithNullableField;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.eq;
+
+@RunWith(MockitoJUnitRunner.class)
+public class TestHiveSortLimitPullUpConstantsRule {
+
+  @Mock
+  private RelOptSchema schemaMock;
+  @Mock
+  RelOptHiveTable tableMock;
+  @Mock
+  Table hiveTableMDMock;
+
+  private AbstractRelOptPlanner planner;
+  private RelBuilder relBuilder;
+
+  public void before(Class<?> clazz) {
+    planner = buildPlanner(Arrays.asList(
+        HiveSortPullUpConstantsRule.SORT_LIMIT_INSTANCE, 
HiveProjectMergeRule.INSTANCE));
+    relBuilder = buildRelBuilder(planner, schemaMock, tableMock, 
hiveTableMDMock, clazz);
+  }
+
+  @Test
+  public void testNonNullableFields() {
+    before(MyRecord.class);
+
+    final RelNode plan = relBuilder
+        .scan("t")
+        .filter(eq(relBuilder, "f1",1))
+        .sort(relBuilder.field("f1"), relBuilder.field("f2"))
+        .project(relBuilder.field("f1"), relBuilder.field("f2"))
+        .build();
+
+    String prePlan = "HiveProject(f1=[$0], f2=[$1])\n"
+                   + "  HiveSortLimit(sort0=[$0], sort1=[$1], dir0=[ASC], 
dir1=[ASC])\n"
+                   + "    HiveFilter(condition=[=($0, 1)])\n"
+                   + "      LogicalTableScan(table=[[]])\n";
+
+    String postPlan = "HiveProject(f1=[1], f2=[$0])\n"
+                    + "  HiveSortLimit(sort0=[$0], dir0=[ASC])\n"
+                    + "    HiveProject(f2=[$1], f3=[$2])\n"
+                    + "      HiveFilter(condition=[=($0, 1)])\n"
+                    + "        LogicalTableScan(table=[[]])\n";
+
+    assertPlans(planner, plan, prePlan, postPlan);
+  }
+
+  @org.junit.Test
+  public void testNullableFields() {
+    before(MyRecordWithNullableField.class);
+
+    final RelNode plan = relBuilder
+        .scan("t")
+        .filter(eq(relBuilder,"f1",1))
+        .sort(relBuilder.field("f1"), relBuilder.field("f2"))
+        .project(relBuilder.field("f1"), relBuilder.field("f2"))
+        .build();
+
+    String prePlan = "HiveProject(f1=[$0], f2=[$1])\n"
+                   + "  HiveSortLimit(sort0=[$0], sort1=[$1], dir0=[ASC], 
dir1=[ASC])\n"
+                   + "    HiveFilter(condition=[=($0, 1)])\n"
+                   + "      LogicalTableScan(table=[[]])\n";
+
+    String postPlan = "HiveProject(f1=[CAST(1):JavaType(class 
java.lang.Integer)], f2=[$0])\n"
+                    + "  HiveSortLimit(sort0=[$0], dir0=[ASC])\n"
+                    + "    HiveProject(f2=[$1], f3=[$2])\n"
+                    + "      HiveFilter(condition=[=($0, 1)])\n"
+                    + "        LogicalTableScan(table=[[]])\n";
+
+    assertPlans(planner, plan, prePlan, postPlan);
+  }
+}
diff --git 
a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/rules/TestHiveUnionPullUpConstantsRule.java
 
b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/rules/TestHiveUnionPullUpConstantsRule.java
index 6e455aa732d..bc96a2cdb87 100644
--- 
a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/rules/TestHiveUnionPullUpConstantsRule.java
+++ 
b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/rules/TestHiveUnionPullUpConstantsRule.java
@@ -18,40 +18,29 @@
 
 package org.apache.hadoop.hive.ql.optimizer.calcite.rules;
 
-import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
-import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.AbstractRelOptPlanner;
 import org.apache.calcite.plan.RelOptSchema;
-import org.apache.calcite.plan.RelOptUtil;
-import org.apache.calcite.plan.hep.HepPlanner;
-import org.apache.calcite.plan.hep.HepProgramBuilder;
 import org.apache.calcite.rel.RelNode;
-import org.apache.calcite.rel.logical.LogicalTableScan;
-import org.apache.calcite.rel.type.RelDataType;
-import org.apache.calcite.rex.RexBuilder;
-import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.hadoop.hive.ql.metadata.Table;
-import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
 import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
 import org.junit.Test;
 import org.junit.runner.RunWith;
-import org.mockito.ArgumentMatchers;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
 
 import java.util.Collections;
 
-import static org.junit.Assert.assertEquals;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.doReturn;
-import static org.mockito.Mockito.lenient;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.assertPlans;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.buildPlanner;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.buildRelBuilder;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.MyRecord;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.MyRecordWithNullableField;
+import static 
org.apache.hadoop.hive.ql.optimizer.calcite.rules.TestRuleHelper.eq;
 
 @RunWith(MockitoJUnitRunner.class)
 public class TestHiveUnionPullUpConstantsRule {
 
-  private final static JavaTypeFactoryImpl JAVA_TYPE_FACTORY = new 
JavaTypeFactoryImpl();
-
   @Mock
   private RelOptSchema schemaMock;
   @Mock
@@ -59,62 +48,25 @@ public class TestHiveUnionPullUpConstantsRule {
   @Mock
   Table hiveTableMDMock;
 
-  private HepPlanner planner;
-  private RelBuilder rexBuilder;
-
-  private static class MyRecordWithNullableField {
-    public Integer f1;
-    public int f2;
-    public double f3;
-  }
-
-  private static class MyRecord {
-    public int f1;
-    public int f2;
-    public double f3;
-  }
+  private AbstractRelOptPlanner planner;
+  private RelBuilder relBuilder;
 
   public void before(Class<?> clazz) {
-    HepProgramBuilder programBuilder = new HepProgramBuilder();
-    programBuilder.addRuleInstance(HiveUnionPullUpConstantsRule.INSTANCE);
-
-    planner = new HepPlanner(programBuilder.build());
-
-    RexBuilder rexBuilder = new RexBuilder(JAVA_TYPE_FACTORY);
-    final RelOptCluster optCluster = RelOptCluster.create(planner, rexBuilder);
-    RelDataType rowTypeMock = JAVA_TYPE_FACTORY.createStructType(clazz);
-    doReturn(rowTypeMock).when(tableMock).getRowType();
-    LogicalTableScan tableScan = LogicalTableScan.create(optCluster, 
tableMock, Collections.emptyList());
-    doReturn(tableScan).when(tableMock).toRel(ArgumentMatchers.any());
-    doReturn(tableMock).when(schemaMock).getTableForMember(any());
-    lenient().doReturn(hiveTableMDMock).when(tableMock).getHiveTableMD();
-
-    this.rexBuilder = HiveRelFactories.HIVE_BUILDER.create(optCluster, 
schemaMock);
-  }
-
-  public RexNode eq(String field, Number value) {
-    return rexBuilder.call(SqlStdOperatorTable.EQUALS,
-        rexBuilder.field(field), rexBuilder.literal(value));
-  }
-
-  private void test(RelNode plan, String expectedPrePlan, String 
expectedPostPlan) {
-    planner.setRoot(plan);
-    RelNode optimizedRelNode = planner.findBestExp();
-    assertEquals("Original plans do not match", expectedPrePlan, 
RelOptUtil.toString(plan));
-    assertEquals("Optimized plans do not match", expectedPostPlan, 
RelOptUtil.toString(optimizedRelNode));
+    planner = 
buildPlanner(Collections.singletonList(HiveUnionPullUpConstantsRule.INSTANCE));
+    relBuilder = buildRelBuilder(planner, schemaMock, tableMock, 
hiveTableMDMock, clazz);
   }
 
   @Test
   public void testNonNullableFields() {
     before(MyRecord.class);
 
-    final RelNode plan = rexBuilder
+    final RelNode plan = relBuilder
         .scan("t")
-        .filter(eq("f1",1))
-        .project(rexBuilder.field("f1"), rexBuilder.field("f2"))
+        .filter(eq(relBuilder, "f1",1))
+        .project(relBuilder.field("f1"), relBuilder.field("f2"))
         .scan("t")
-        .filter(eq("f1",1))
-        .project(rexBuilder.field("f1"), rexBuilder.field("f2"))
+        .filter(eq(relBuilder, "f1",1))
+        .project(relBuilder.field("f1"), relBuilder.field("f2"))
         .union(true)
         .build();
 
@@ -137,20 +89,20 @@ public class TestHiveUnionPullUpConstantsRule {
                     + "        HiveFilter(condition=[=($0, 1)])\n"
                     + "          LogicalTableScan(table=[[]])\n";
 
-    test(plan, prePlan, postPlan);
+    assertPlans(planner, plan, prePlan, postPlan);
   }
 
   @Test
   public void testNullableFields() {
     before(MyRecordWithNullableField.class);
 
-    final RelNode plan = rexBuilder
+    final RelNode plan = relBuilder
         .scan("t")
-        .filter(eq("f1",1))
-        .project(rexBuilder.field("f1"), rexBuilder.field("f2"))
+        .filter(eq(relBuilder, "f1",1))
+        .project(relBuilder.field("f1"), relBuilder.field("f2"))
         .scan("t")
-        .filter(eq("f1",1))
-        .project(rexBuilder.field("f1"), rexBuilder.field("f2"))
+        .filter(eq(relBuilder, "f1",1))
+        .project(relBuilder.field("f1"), relBuilder.field("f2"))
         .union(false)
         .build();
 
@@ -173,6 +125,6 @@ public class TestHiveUnionPullUpConstantsRule {
                     + "        HiveFilter(condition=[=($0, 1)])\n"
                     + "          LogicalTableScan(table=[[]])\n";
 
-    test(plan, prePlan, postPlan);
+    assertPlans(planner, plan, prePlan, postPlan);
   }
 }
diff --git 
a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/rules/TestRuleHelper.java
 
b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/rules/TestRuleHelper.java
new file mode 100644
index 00000000000..8c49f58c424
--- /dev/null
+++ 
b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/rules/TestRuleHelper.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.optimizer.calcite.rules;
+
+import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
+import org.apache.calcite.plan.AbstractRelOptPlanner;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptPlanner;
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptSchema;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.plan.hep.HepPlanner;
+import org.apache.calcite.plan.hep.HepProgramBuilder;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.logical.LogicalTableScan;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.hadoop.hive.ql.metadata.Table;
+import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
+import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentMatchers;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import java.util.Collection;
+import java.util.Collections;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.lenient;
+
+/**
+ * Shared helpers for the Calcite rule unit tests in this package: building a
+ * {@link HepPlanner} for a set of rules, a Hive {@link RelBuilder} over a
+ * mocked table/schema, simple comparison/boolean expressions, and asserting
+ * the textual plans before and after optimization.
+ */
+public class TestRuleHelper {
+
+  /**
+   * Type factory backing the cluster's {@link RexBuilder} and used to derive the
+   * mocked table's row type from a record class.
+   */
+  protected final static JavaTypeFactoryImpl JAVA_TYPE_FACTORY = new 
JavaTypeFactoryImpl();
+
+  /**
+   * Row-type template whose first column {@code f1} is a boxed {@link Integer},
+   * intended (per the class name) to yield a nullable column; {@code f2} and
+   * {@code f3} are primitives.
+   */
+  static class MyRecordWithNullableField {
+    public Integer f1;
+    public int f2;
+    public double f3;
+  }
+
+  /** Row-type template whose columns {@code f1}, {@code f2}, {@code f3} are all primitives. */
+  static class MyRecord {
+    public int f1;
+    public int f2;
+    public double f3;
+  }
+
+  /**
+   * Builds a {@link HepPlanner} whose program applies the given rules, each
+   * registered via {@code addRuleInstance} in the collection's iteration order.
+   *
+   * @param rules the rules the planner should apply
+   * @return the new planner
+   */
+  public static AbstractRelOptPlanner buildPlanner(Collection<RelOptRule> 
rules) {
+    HepProgramBuilder programBuilder = new HepProgramBuilder();
+    rules.forEach(programBuilder::addRuleInstance);
+    return new HepPlanner(programBuilder.build());
+  }
+
+  /**
+   * Creates a Hive {@link RelBuilder} whose table lookups resolve to a mocked table.
+   * <p>
+   * Stubs {@code tableMock} so that its row type is derived from {@code clazz} and
+   * {@code toRel(..)} returns a {@link LogicalTableScan} over it, and stubs
+   * {@code schemaMock} so that every {@code getTableForMember(..)} call returns
+   * {@code tableMock}, whatever name is requested.
+   *
+   * @param planner planner used to create the {@link RelOptCluster}
+   * @param schemaMock Mockito mock of the schema; stubbed by this method
+   * @param tableMock Mockito mock of the table; stubbed by this method
+   * @param hiveTableMock value returned by {@code tableMock.getHiveTableMD()}
+   * @param clazz record class from which the table's row type is derived
+   * @return a builder created by {@link HiveRelFactories#HIVE_BUILDER} on the new cluster
+   */
+  public static RelBuilder buildRelBuilder(AbstractRelOptPlanner planner,
+      RelOptSchema schemaMock, RelOptHiveTable tableMock, Table hiveTableMock, 
Class<?> clazz) {
+
+    RexBuilder rexBuilder = new RexBuilder(JAVA_TYPE_FACTORY);
+    final RelOptCluster optCluster = RelOptCluster.create(planner, rexBuilder);
+    RelDataType rowTypeMock = JAVA_TYPE_FACTORY.createStructType(clazz);
+    doReturn(rowTypeMock).when(tableMock).getRowType();
+    LogicalTableScan tableScan = LogicalTableScan.create(optCluster, 
tableMock, Collections.emptyList());
+    // NOTE(review): ArgumentMatchers.any() here and the statically imported any()
+    // below are the same matcher; consider using a single form.
+    doReturn(tableScan).when(tableMock).toRel(ArgumentMatchers.any());
+    doReturn(tableMock).when(schemaMock).getTableForMember(any());
+    // lenient(): presumably so that tests which never call getHiveTableMD() do not
+    // trip Mockito's strict-stubbing (unnecessary stubbing) check.
+    lenient().doReturn(hiveTableMock).when(tableMock).getHiveTableMD();
+
+    return HiveRelFactories.HIVE_BUILDER.create(optCluster, schemaMock);
+  }
+
+  /**
+   * Builds the predicate {@code field = value}, where {@code field} is resolved
+   * by name through {@code relBuilder.field(field)} and {@code value} becomes a literal.
+   */
+  static RexNode eq(RelBuilder relBuilder, String field, Number value) {
+    return relBuilder.call(SqlStdOperatorTable.EQUALS,
+        relBuilder.field(field), relBuilder.literal(value));
+  }
+
+  /** Builds the disjunction {@code OR(args)}. */
+  static RexNode or(RelBuilder relBuilder, RexNode... args) {
+    return relBuilder.call(SqlStdOperatorTable.OR, args);
+  }
+
+  /** Builds the conjunction {@code AND(args)}. */
+  static RexNode and(RelBuilder relBuilder, RexNode... args) {
+    return relBuilder.call(SqlStdOperatorTable.AND, args);
+  }
+
+  /**
+   * Optimizes {@code plan} with {@code planner} and compares the original and the
+   * optimized plans against their expected {@link RelOptUtil#toString} output.
+   *
+   * @param planner planner to run; {@code plan} is set as its root
+   * @param plan plan to optimize
+   * @param expectedPrePlan expected textual form of {@code plan}
+   * @param expectedPostPlan expected textual form of the result of {@code findBestExp()}
+   */
+  static void assertPlans(AbstractRelOptPlanner planner, RelNode plan, String 
expectedPrePlan, String expectedPostPlan) {
+    planner.setRoot(plan);
+    RelNode optimizedRelNode = planner.findBestExp();
+    assertEquals("Original plans do not match", expectedPrePlan, 
RelOptUtil.toString(plan));
+    assertEquals("Optimized plans do not match", expectedPostPlan, 
RelOptUtil.toString(optimizedRelNode));
+  }
+}


Reply via email to