This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit 3e877d11ae4bf3ff6dbfc990b6a13914e7c30944
Author: Julian Hyde <[email protected]>
AuthorDate: Sun Apr 10 15:10:46 2022 +0200

    Refactor RelWriterTest
    
    (No changes in functionality.)
---
 .../java/org/apache/calcite/util/JsonBuilder.java  |   2 +
 .../org/apache/calcite/plan/RelWriterTest.java     | 673 ++++++++++-----------
 2 files changed, 336 insertions(+), 339 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/util/JsonBuilder.java b/core/src/main/java/org/apache/calcite/util/JsonBuilder.java
index 6e20834af..b9fdec835 100644
--- a/core/src/main/java/org/apache/calcite/util/JsonBuilder.java
+++ b/core/src/main/java/org/apache/calcite/util/JsonBuilder.java
@@ -18,6 +18,8 @@ package org.apache.calcite.util;
 
 import org.apache.calcite.avatica.util.Spaces;
 
+import com.google.common.collect.ImmutableList;
+
 import org.checkerframework.checker.nullness.qual.Nullable;
 
 import java.util.ArrayList;
diff --git a/core/src/test/java/org/apache/calcite/plan/RelWriterTest.java b/core/src/test/java/org/apache/calcite/plan/RelWriterTest.java
index 6b5c55347..b75d09f70 100644
--- a/core/src/test/java/org/apache/calcite/plan/RelWriterTest.java
+++ b/core/src/test/java/org/apache/calcite/plan/RelWriterTest.java
@@ -47,7 +47,6 @@ import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexCorrelVariable;
 import org.apache.calcite.rex.RexFieldCollation;
 import org.apache.calcite.rex.RexInputRef;
-import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexProgramBuilder;
 import org.apache.calcite.rex.RexWindowBounds;
@@ -92,6 +91,7 @@ import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.function.Function;
 import java.util.stream.Stream;
 
 import static org.apache.calcite.test.Matchers.isLinux;
@@ -101,6 +101,8 @@ import static org.hamcrest.CoreMatchers.notNullValue;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 
+import static java.util.Objects.requireNonNull;
+
 /**
  * Unit test for {@link org.apache.calcite.rel.externalize.RelJson}.
  */
@@ -437,6 +439,11 @@ class RelWriterTest {
     return Stream.of(SqlExplainFormat.TEXT, SqlExplainFormat.DOT);
   }
 
+  /** Creates a fixture. */
+  private static Fixture relFn(Function<RelBuilder, RelNode> relFn) {
+    return new Fixture(relFn, false, SqlExplainFormat.TEXT);
+  }
+
   /** Unit test for {@link RelJson#toJson(Object)} for an object of type
    * {@link RelDataType}. */
   @Test void testTypeJson() {
@@ -594,40 +601,28 @@ class RelWriterTest {
   }
 
   @Test void testExchange() {
-    final FrameworkConfig config = RelBuilderTest.config().build();
-    final RelBuilder builder = RelBuilder.create(config);
-    final RelNode rel = builder
-        .scan("EMP")
-        .exchange(RelDistributions.hash(ImmutableList.of(0, 1)))
-        .build();
-    final String relJson = RelOptUtil.dumpPlan("", rel,
-        SqlExplainFormat.JSON, SqlExplainLevel.EXPPLAN_ATTRIBUTES);
-    String s = deserializeAndDumpToTextFormat(getSchema(rel), relJson);
+    final Function<RelBuilder, RelNode> relFn = b ->
+        b.scan("EMP")
+            .exchange(RelDistributions.hash(ImmutableList.of(0, 1)))
+            .build();
     final String expected = ""
         + "LogicalExchange(distribution=[hash[0, 1]])\n"
         + "  LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(s, isLinux(expected));
+    relFn(relFn)
+        .assertThatPlan(isLinux(expected));
   }
 
   @Test public void testExchangeWithDistributionTraitDef() {
-    final FrameworkConfig config = RelBuilderTest.config().build();
-    final RelBuilder builder = RelBuilder.create(config);
-    final RelNode rel = builder
-        .scan("EMP")
-        .exchange(RelDistributions.hash(ImmutableList.of(0, 1)))
-        .build();
-    final String relJson = RelOptUtil.dumpPlan("", rel,
-        SqlExplainFormat.JSON, SqlExplainLevel.EXPPLAN_ATTRIBUTES);
-
-    VolcanoPlanner planner = new VolcanoPlanner();
-    planner.addRelTraitDef(RelDistributionTraitDef.INSTANCE);
-    RelOptCluster cluster = RelOptCluster.create(planner, 
builder.getRexBuilder());
-
-    String plan = deserializeAndDump(cluster, getSchema(rel), relJson, 
SqlExplainFormat.TEXT);
+    final Function<RelBuilder, RelNode> relFn = b ->
+        b.scan("EMP")
+            .exchange(RelDistributions.hash(ImmutableList.of(0, 1)))
+            .build();
     final String expected = ""
         + "LogicalExchange(distribution=[hash[0, 1]])\n"
         + "  LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(plan, isLinux(expected));
+    relFn(relFn)
+        .withDistribution(true)
+        .assertThatPlan(isLinux(expected));
   }
 
   /**
@@ -714,7 +709,7 @@ class RelWriterTest {
             + "    LogicalTableScan(table=[[hr, emps]])\n"));
   }
 
-  @Test void testJsonToRex() throws JsonProcessingException {
+  @Test void testJsonToRex() {
     // Test simple literal without inputs
     final String jsonString1 = "{\n"
         + "  \"literal\": 10,\n"
@@ -777,9 +772,7 @@ class RelWriterTest {
   }
 
   @Test void testTrim() {
-    final FrameworkConfig config = RelBuilderTest.config().build();
-    final RelBuilder b = RelBuilder.create(config);
-    final RelNode rel =
+    final Function<RelBuilder, RelNode> relFn = b ->
         b.scan("EMP")
             .project(
                 b.alias(
@@ -789,36 +782,26 @@ class RelWriterTest {
                         b.field("ENAME")),
                     "trimmed_ename"))
             .build();
-
-    RelJsonWriter jsonWriter = new RelJsonWriter();
-    rel.explain(jsonWriter);
-    String relJson = jsonWriter.asString();
-    final RelOptSchema schema = getSchema(rel);
-    final String s = deserializeAndDumpToTextFormat(schema, relJson);
     final String expected = ""
         + "LogicalProject(trimmed_ename=[TRIM(FLAG(BOTH), ' ', $1)])\n"
         + "  LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(s, isLinux(expected));
+    relFn(relFn)
+        .assertThatPlan(isLinux(expected));
   }
 
   @Test void testPlusOperator() {
-    final FrameworkConfig config = RelBuilderTest.config().build();
-    final RelBuilder builder = RelBuilder.create(config);
-    final RelNode rel = builder
-        .scan("EMP")
-        .project(
-            builder.call(SqlStdOperatorTable.PLUS,
-                builder.field("SAL"),
-                builder.literal(10)))
-        .build();
-    RelJsonWriter jsonWriter = new RelJsonWriter();
-    rel.explain(jsonWriter);
-    String relJson = jsonWriter.asString();
-    String s = deserializeAndDumpToTextFormat(getSchema(rel), relJson);
+    final Function<RelBuilder, RelNode> relFn = b ->
+        b.scan("EMP")
+            .project(
+                b.call(SqlStdOperatorTable.PLUS,
+                    b.field("SAL"),
+                    b.literal(10)))
+            .build();
     final String expected = ""
         + "LogicalProject($f0=[+($5, 10)])\n"
         + "  LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(s, isLinux(expected));
+    relFn(relFn)
+        .assertThatPlan(isLinux(expected));
   }
 
   @ParameterizedTest
@@ -920,60 +903,43 @@ class RelWriterTest {
   }
 
   @Test void testCalc() {
-    final FrameworkConfig config = RelBuilderTest.config().build();
-    final RelBuilder builder = RelBuilder.create(config);
-    final RexBuilder rexBuilder = builder.getRexBuilder();
-    final LogicalTableScan scan = (LogicalTableScan) 
builder.scan("EMP").build();
-    final RexProgramBuilder programBuilder =
-        new RexProgramBuilder(scan.getRowType(), rexBuilder);
-    final RelDataTypeField field = scan.getRowType().getField("SAL", false, 
false);
-    programBuilder.addIdentity();
-    programBuilder.addCondition(
-        rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN,
-            new RexInputRef(field.getIndex(), field.getType()),
-            builder.literal(10)));
-    final LogicalCalc calc = LogicalCalc.create(scan, 
programBuilder.getProgram());
-    String relJson = RelOptUtil.dumpPlan("", calc,
-        SqlExplainFormat.JSON, SqlExplainLevel.EXPPLAN_ATTRIBUTES);
-    String s =
-        Frameworks.withPlanner((cluster, relOptSchema, rootSchema) -> {
-          final RelJsonReader reader = new RelJsonReader(
-              cluster, getSchema(calc), rootSchema);
-          RelNode node;
-          try {
-            node = reader.read(relJson);
-          } catch (IOException e) {
-            throw TestUtil.rethrow(e);
-          }
-          return RelOptUtil.dumpPlan("", node, SqlExplainFormat.TEXT,
-              SqlExplainLevel.EXPPLAN_ATTRIBUTES);
-        });
-    final String expected =
-        "LogicalCalc(expr#0..7=[{inputs}], expr#8=[10], expr#9=[>($t5, $t8)],"
-            + " proj#0..7=[{exprs}], $condition=[$t9])\n"
-            + "  LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(s, isLinux(expected));
+    final Function<RelBuilder, RelNode> relFn = b ->
+        b.scan("EMP")
+            .let(b2 -> {
+              final RexBuilder rexBuilder = b2.getRexBuilder();
+              final RelNode scan = b2.build();
+              final RelDataType rowType = scan.getRowType();
+              final RexProgramBuilder programBuilder =
+                  new RexProgramBuilder(rowType, rexBuilder);
+              final RelDataTypeField field =
+                  rowType.getField("SAL", false, false);
+              assertThat(field, notNullValue());
+              programBuilder.addIdentity();
+              programBuilder.addCondition(
+                  rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN,
+                      new RexInputRef(field.getIndex(), field.getType()),
+                      b2.literal(10)));
+              return LogicalCalc.create(scan, programBuilder.getProgram());
+            });
+    final String expected = ""
+        + "LogicalCalc(expr#0..7=[{inputs}], expr#8=[10], expr#9=[>($t5, 
$t8)],"
+        + " proj#0..7=[{exprs}], $condition=[$t9])\n"
+        + "  LogicalTableScan(table=[[scott, EMP]])\n";
+    relFn(relFn)
+        .assertThatPlan(isLinux(expected));
   }
 
   @ParameterizedTest
   @MethodSource("explainFormats")
   void testCorrelateQuery(SqlExplainFormat format) {
-    final FrameworkConfig config = RelBuilderTest.config().build();
-    final RelBuilder builder = RelBuilder.create(config);
     final Holder<RexCorrelVariable> v = Holder.empty();
-    RelNode relNode = builder.scan("EMP")
+    final Function<RelBuilder, RelNode> relFn = b -> b.scan("EMP")
         .variable(v)
         .scan("DEPT")
-        .filter(
-            builder.equals(builder.field(0), builder.field(v.get(), "DEPTNO")))
-        .correlate(
-            JoinRelType.INNER, v.get().id, builder.field(2, 0, "DEPTNO"))
+        .filter(b.equals(b.field(0), b.field(v.get(), "DEPTNO")))
+        .correlate(JoinRelType.INNER, v.get().id, b.field(2, 0, "DEPTNO"))
         .build();
-    RelJsonWriter jsonWriter = new RelJsonWriter();
-    relNode.explain(jsonWriter);
-    final String relJson = jsonWriter.asString();
-    String s = deserializeAndDump(getSchema(relNode), relJson, format);
-    String expected = null;
+    final String expected;
     switch (format) {
     case TEXT:
       expected = ""
@@ -993,96 +959,82 @@ class RelWriterTest {
           + "($0, $c\\nor0.DEPTNO)\\n\" [label=\"0\"]\n"
           + "}\n";
       break;
+    default:
+      throw new AssertionError(format);
     }
-    assertThat(s, isLinux(expected));
+    relFn(relFn)
+        .withFormat(format)
+        .assertThatPlan(isLinux(expected));
   }
 
   @Test void testOverWithoutPartition() {
-    // The rel stands for the sql of "select count(*) over (order by deptno) 
from EMP"
-    final RelNode rel = mockCountOver("EMP", ImmutableList.of(), 
ImmutableList.of("DEPTNO"));
-    String relJson = RelOptUtil.dumpPlan("", rel, SqlExplainFormat.JSON,
-        SqlExplainLevel.EXPPLAN_ATTRIBUTES);
-    String s = deserializeAndDumpToTextFormat(getSchema(rel), relJson);
+    // Equivalent SQL:
+    //   SELECT count(*) OVER (ORDER BY deptno) FROM emp
+    final Function<RelBuilder, RelNode> relFn = b ->
+        mockCountOver(b, "EMP", ImmutableList.of(), 
ImmutableList.of("DEPTNO"));
     final String expected = ""
         + "LogicalProject($f0=[COUNT() OVER (ORDER BY $7 NULLS LAST "
         + "ROWS UNBOUNDED PRECEDING)])\n"
         + "  LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(s, isLinux(expected));
+    relFn(relFn)
+        .assertThatPlan(isLinux(expected));
   }
 
   @Test void testOverWithoutOrderKey() {
-    // The rel stands for the sql of "select count(*) over (partition by 
DEPTNO) from EMP"
-    final RelNode rel = mockCountOver("EMP", ImmutableList.of("DEPTNO"), 
ImmutableList.of());
-    String relJson = RelOptUtil.dumpPlan("", rel, SqlExplainFormat.JSON,
-        SqlExplainLevel.EXPPLAN_ATTRIBUTES);
-    String s = deserializeAndDumpToTextFormat(getSchema(rel), relJson);
+    // Equivalent SQL:
+    //   SELECT count(*) OVER (PARTITION BY deptno) FROM emp
+    final Function<RelBuilder, RelNode> relFn = b ->
+        mockCountOver(b, "EMP", ImmutableList.of("DEPTNO"), 
ImmutableList.of());
     final String expected = ""
         + "LogicalProject($f0=[COUNT() OVER (PARTITION BY $7)])\n"
         + "  LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(s, isLinux(expected));
+    relFn(relFn)
+        .assertThatPlan(isLinux(expected));
   }
 
   @Test void testInterval() {
-    final FrameworkConfig config = RelBuilderTest.config().build();
-    final RelBuilder builder = RelBuilder.create(config);
     SqlIntervalQualifier sqlIntervalQualifier =
         new SqlIntervalQualifier(TimeUnit.DAY, TimeUnit.DAY, 
SqlParserPos.ZERO);
     BigDecimal value = new BigDecimal(86400000);
-    RexLiteral intervalLiteral = builder.getRexBuilder()
-        .makeIntervalLiteral(value, sqlIntervalQualifier);
-    final RelNode rel = builder
-        .scan("EMP")
+    final Function<RelBuilder, RelNode> relFn = b -> b.scan("EMP")
         .project(
-            builder.call(
-                SqlStdOperatorTable.TUMBLE_END,
-                builder.field("HIREDATE"),
-                intervalLiteral))
+            b.call(SqlStdOperatorTable.TUMBLE_END,
+                b.field("HIREDATE"),
+                b.getRexBuilder()
+                    .makeIntervalLiteral(value, sqlIntervalQualifier)))
         .build();
-    RelJsonWriter jsonWriter = new RelJsonWriter();
-    rel.explain(jsonWriter);
-    String relJson = jsonWriter.asString();
-    String s = deserializeAndDumpToTextFormat(getSchema(rel), relJson);
     final String expected = ""
         + "LogicalProject($f0=[TUMBLE_END($4, 86400000:INTERVAL DAY)])\n"
         + "  LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(s, isLinux(expected));
+    relFn(relFn)
+        .assertThatPlan(isLinux(expected));
   }
 
   @Test void testUdf() {
-    final FrameworkConfig config = RelBuilderTest.config().build();
-    final RelBuilder builder = RelBuilder.create(config);
-    final RelNode rel = builder
-        .scan("EMP")
-        .project(
-            builder.call(new MockSqlOperatorTable.MyFunction(),
-                builder.field("EMPNO")))
-        .build();
-    String relJson = RelOptUtil.dumpPlan("", rel,
-        SqlExplainFormat.JSON, SqlExplainLevel.EXPPLAN_ATTRIBUTES);
-    String s = deserializeAndDumpToTextFormat(getSchema(rel), relJson);
+    final Function<RelBuilder, RelNode> relFn = b ->
+        b.scan("EMP")
+            .project(
+                b.call(new MockSqlOperatorTable.MyFunction(),
+                    b.field("EMPNO")))
+            .build();
     final String expected = ""
         + "LogicalProject($f0=[MYFUN($0)])\n"
         + "  LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(s, isLinux(expected));
+    relFn(relFn)
+        .assertThatPlan(isLinux(expected));
   }
 
   @ParameterizedTest
   @MethodSource("explainFormats")
   void testUDAF(SqlExplainFormat format) {
-    final FrameworkConfig config = RelBuilderTest.config().build();
-    final RelBuilder builder = RelBuilder.create(config);
-    final RelNode rel = builder
-        .scan("EMP")
-        .project(builder.field("ENAME"), builder.field("DEPTNO"))
-        .aggregate(
-            builder.groupKey("ENAME"),
-            builder.aggregateCall(new MockSqlOperatorTable.MyAggFunc(),
-                builder.field("DEPTNO")))
+    final Function<RelBuilder, RelNode> relFn = b ->
+        b.scan("EMP")
+            .project(b.field("ENAME"), b.field("DEPTNO"))
+            .aggregate(b.groupKey("ENAME"),
+                b.aggregateCall(new MockSqlOperatorTable.MyAggFunc(),
+                    b.field("DEPTNO")))
         .build();
-    final String relJson = RelOptUtil.dumpPlan("", rel,
-        SqlExplainFormat.JSON, SqlExplainLevel.EXPPLAN_ATTRIBUTES);
-    final String result = deserializeAndDump(getSchema(rel), relJson, format);
-    String expected = null;
+    final String expected;
     switch (format) {
     case TEXT:
       expected = ""
@@ -1098,49 +1050,45 @@ class RelWriterTest {
           + "$1\\nDEPTNO = $7\\n\" [label=\"0\"]\n"
           + "}\n";
       break;
+    default:
+      throw new AssertionError(format);
     }
-    assertThat(result, isLinux(expected));
+    relFn(relFn)
+        .withFormat(format)
+        .assertThatPlan(isLinux(expected));
   }
 
   @Test void testArrayType() {
-    final FrameworkConfig config = RelBuilderTest.config().build();
-    final RelBuilder builder = RelBuilder.create(config);
-    final RelNode rel = builder
-        .scan("EMP")
-        .project(
-            builder.call(new MockSqlOperatorTable.SplitFunction(),
-                builder.field("ENAME"), builder.literal(",")))
-        .build();
-    final String relJson = RelOptUtil.dumpPlan("", rel,
-        SqlExplainFormat.JSON, SqlExplainLevel.EXPPLAN_ATTRIBUTES);
-    final String s = deserializeAndDumpToTextFormat(getSchema(rel), relJson);
+    final Function<RelBuilder, RelNode> relFn = b ->
+        b.scan("EMP")
+            .project(
+                b.call(new MockSqlOperatorTable.SplitFunction(),
+                    b.field("ENAME"), b.literal(",")))
+            .build();
     final String expected = ""
         + "LogicalProject($f0=[SPLIT($1, ',')])\n"
         + "  LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(s, isLinux(expected));
+    relFn(relFn)
+        .assertThatPlan(isLinux(expected));
   }
 
   @Test void testMapType() {
-    final FrameworkConfig config = RelBuilderTest.config().build();
-    final RelBuilder builder = RelBuilder.create(config);
-    final RelNode rel = builder
-        .scan("EMP")
-        .project(
-            builder.call(new MockSqlOperatorTable.MapFunction(),
-                builder.literal("key"), builder.literal("value")))
-        .build();
-    final String relJson = RelOptUtil.dumpPlan("", rel,
-        SqlExplainFormat.JSON, SqlExplainLevel.EXPPLAN_ATTRIBUTES);
-    final String s = deserializeAndDumpToTextFormat(getSchema(rel), relJson);
+    final Function<RelBuilder, RelNode> relFn = b ->
+        b.scan("EMP")
+            .project(
+                b.call(new MockSqlOperatorTable.MapFunction(),
+                    b.literal("key"), b.literal("value")))
+            .build();
     final String expected = ""
         + "LogicalProject($f0=[MAP('key', 'value')])\n"
         + "  LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(s, isLinux(expected));
+    relFn(relFn)
+        .assertThatPlan(isLinux(expected));
   }
 
   /** Returns the schema of a {@link org.apache.calcite.rel.core.TableScan}
    * in this plan, or null if there are no scans. */
-  private RelOptSchema getSchema(RelNode rel) {
+  private static RelOptSchema getSchema(RelNode rel) {
     final Holder<@Nullable RelOptSchema> schemaHolder = Holder.empty();
     rel.accept(
         new RelShuttleImpl() {
@@ -1156,26 +1104,24 @@ class RelWriterTest {
    * Deserialize a relnode from the json string by {@link RelJsonReader},
    * and dump it to the given format.
    */
-  private String deserializeAndDump(
-      RelOptSchema schema, String relJson, SqlExplainFormat format) {
-    String s =
-        Frameworks.withPlanner((cluster, relOptSchema, rootSchema) -> {
-          final RelJsonReader reader = new RelJsonReader(
-              cluster, schema, rootSchema);
-          RelNode node;
-          try {
-            node = reader.read(relJson);
-          } catch (IOException e) {
-            throw TestUtil.rethrow(e);
-          }
-          return RelOptUtil.dumpPlan("", node, format,
-              SqlExplainLevel.EXPPLAN_ATTRIBUTES);
-        });
-    return s;
+  private static String deserializeAndDump(RelOptSchema schema, String relJson,
+      SqlExplainFormat format) {
+    return Frameworks.withPlanner((cluster, relOptSchema, rootSchema) -> {
+      final RelJsonReader reader =
+          new RelJsonReader(cluster, schema, rootSchema);
+      RelNode node;
+      try {
+        node = reader.read(relJson);
+      } catch (IOException e) {
+        throw TestUtil.rethrow(e);
+      }
+      return RelOptUtil.dumpPlan("", node, format,
+          SqlExplainLevel.EXPPLAN_ATTRIBUTES);
+    });
   }
 
-  private String deserializeAndDump(RelOptCluster cluster, RelOptSchema 
schema, String relJson,
-      SqlExplainFormat format) {
+  private static String deserializeAndDump(RelOptCluster cluster,
+      RelOptSchema schema, String relJson, SqlExplainFormat format) {
     final RelJsonReader reader = new RelJsonReader(cluster, schema, null);
     RelNode node;
     try {
@@ -1190,7 +1136,8 @@ class RelWriterTest {
    * Deserialize a relnode from the json string by {@link RelJsonReader},
    * and dump it to text format.
    */
-  private String deserializeAndDumpToTextFormat(RelOptSchema schema, String 
relJson) {
+  private static String deserializeAndDumpToTextFormat(RelOptSchema schema,
+      String relJson) {
     return deserializeAndDump(schema, relJson, SqlExplainFormat.TEXT);
   }
 
@@ -1208,11 +1155,8 @@ class RelWriterTest {
    * @param orderKeyNames Order by column names, may empty, can not be null
    * @return RelNode for the SQL
    */
-  private RelNode mockCountOver(String table,
+  private RelNode mockCountOver(RelBuilder builder, String table,
       List<String> partitionKeyNames, List<String> orderKeyNames) {
-
-    final FrameworkConfig config = RelBuilderTest.config().build();
-    final RelBuilder builder = RelBuilder.create(config);
     final RexBuilder rexBuilder = builder.getRexBuilder();
     final RelDataType type = 
rexBuilder.getTypeFactory().createSqlType(SqlTypeName.BIGINT);
     List<RexNode> partitionKeys = new ArrayList<>(partitionKeyNames.size());
@@ -1240,179 +1184,173 @@ class RelWriterTest {
   }
 
   @Test void testHashDistributionWithoutKeys() {
-    final RelNode root = 
createSortPlan(RelDistributions.hash(Collections.emptyList()));
-    final RelJsonWriter writer = new RelJsonWriter();
-    root.explain(writer);
-    final String json = writer.asString();
-    assertThat(json, is(HASH_DIST_WITHOUT_KEYS));
-
-    final String s = deserializeAndDumpToTextFormat(getSchema(root), json);
+    final Function<RelBuilder, RelNode> relFn = b ->
+        createSortPlan(b, RelDistributions.hash(Collections.emptyList()));
     final String expected =
         "LogicalSortExchange(distribution=[hash], collation=[[0]])\n"
             + "  LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(s, isLinux(expected));
+    relFn(relFn)
+        .assertThatJson(is(HASH_DIST_WITHOUT_KEYS))
+        .assertThatPlan(isLinux(expected));
   }
 
   @Test void testWriteSortExchangeWithHashDistribution() {
-    final RelNode root = 
createSortPlan(RelDistributions.hash(Lists.newArrayList(0)));
-    final RelJsonWriter writer = new RelJsonWriter();
-    root.explain(writer);
-    final String json = writer.asString();
-    assertThat(json, is(XX3));
-
-    final String s = deserializeAndDumpToTextFormat(getSchema(root), json);
-    final String expected =
-        "LogicalSortExchange(distribution=[hash[0]], collation=[[0]])\n"
-            + "  LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(s, isLinux(expected));
+    final Function<RelBuilder, RelNode> relFn = b ->
+        createSortPlan(b, RelDistributions.hash(Lists.newArrayList(0)));
+    final String expected = ""
+        + "LogicalSortExchange(distribution=[hash[0]], collation=[[0]])\n"
+        + "  LogicalTableScan(table=[[scott, EMP]])\n";
+    relFn(relFn)
+        .assertThatJson(is(XX3))
+        .assertThatPlan(isLinux(expected));
   }
 
   @Test void testWriteSortExchangeWithRandomDistribution() {
-    final RelNode root = createSortPlan(RelDistributions.RANDOM_DISTRIBUTED);
-    final RelJsonWriter writer = new RelJsonWriter();
-    root.explain(writer);
-    final String json = writer.asString();
-    final String s = deserializeAndDumpToTextFormat(getSchema(root), json);
-    final String expected =
-        "LogicalSortExchange(distribution=[random], collation=[[0]])\n"
-            + "  LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(s, isLinux(expected));
+    final Function<RelBuilder, RelNode> relFn = b ->
+        createSortPlan(b, RelDistributions.RANDOM_DISTRIBUTED);
+    final String expected = ""
+        + "LogicalSortExchange(distribution=[random], collation=[[0]])\n"
+        + "  LogicalTableScan(table=[[scott, EMP]])\n";
+    relFn(relFn)
+        .assertThatPlan(isLinux(expected));
   }
 
   @Test void testTableModifyInsert() {
-    final FrameworkConfig config = RelBuilderTest.config().build();
-    final RelBuilder builder = RelBuilder.create(config);
-    RelNode project = builder
-        .scan("EMP")
-        .project(builder.fields(), ImmutableList.of(), true)
+    final Function<RelBuilder, RelNode> relFn = b ->
+        b.scan("EMP")
+        .project(b.fields(), ImmutableList.of(), true)
+            .let(b2 -> {
+              final RelNode input = b2.build();
+              final RelOptTable table = input.getInput(0).getTable();
+              final LogicalTableModify modify =
+                  LogicalTableModify.create(table,
+                      (Prepare.CatalogReader) table.getRelOptSchema(),
+                      input,
+                      TableModify.Operation.INSERT,
+                      null,
+                      null,
+                      false);
+              return b2.push(modify);
+            })
         .build();
-    LogicalTableModify modify = LogicalTableModify.create(
-        project.getInput(0).getTable(),
-        (Prepare.CatalogReader) 
project.getInput(0).getTable().getRelOptSchema(),
-        project,
-        TableModify.Operation.INSERT,
-        null,
-        null,
-        false);
-    String relJson = RelOptUtil.dumpPlan("", modify,
-        SqlExplainFormat.JSON, SqlExplainLevel.EXPPLAN_ATTRIBUTES);
-    String s = deserializeAndDumpToTextFormat(getSchema(modify), relJson);
     final String expected = ""
         + "LogicalTableModify(table=[[scott, EMP]], operation=[INSERT], 
flattened=[false])\n"
         + "  LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], 
HIREDATE=[$4], SAL=[$5], "
         + "COMM=[$6], DEPTNO=[$7])\n"
         + "    LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(s, isLinux(expected));
+    relFn(relFn)
+        .assertThatPlan(isLinux(expected));
   }
 
   @Test void testTableModifyUpdate() {
-    final FrameworkConfig config = RelBuilderTest.config().build();
-    final RelBuilder builder = RelBuilder.create(config);
-    RelNode filter = builder
-        .scan("EMP")
-        .filter(
-            builder.call(
-                SqlStdOperatorTable.EQUALS,
-                builder.field("JOB"),
-                builder.literal("c")))
-        .build();
-    LogicalTableModify modify = LogicalTableModify.create(
-        filter.getInput(0).getTable(),
-        (Prepare.CatalogReader) 
filter.getInput(0).getTable().getRelOptSchema(),
-        filter,
-        TableModify.Operation.UPDATE,
-        ImmutableList.of("ENAME"),
-        ImmutableList.of(builder.literal("a")),
-        false);
-    String relJson = RelOptUtil.dumpPlan("", modify,
-        SqlExplainFormat.JSON, SqlExplainLevel.EXPPLAN_ATTRIBUTES);
-    String s = deserializeAndDumpToTextFormat(getSchema(modify), relJson);
+    final Function<RelBuilder, RelNode> relFn = b ->
+        b.scan("EMP")
+            .filter(
+                b.equals(b.field("JOB"), b.literal("c")))
+            .let(b2 -> {
+              final RelNode filter = b2.build();
+              final RelOptTable table = filter.getInput(0).getTable();
+              final LogicalTableModify modify =
+                  LogicalTableModify.create(table,
+                      (Prepare.CatalogReader) table.getRelOptSchema(),
+                      filter,
+                      TableModify.Operation.UPDATE,
+                      ImmutableList.of("ENAME"),
+                      ImmutableList.of(b2.literal("a")),
+                      false);
+              return b2.push(modify);
+            })
+            .build();
     final String expected = ""
         + "LogicalTableModify(table=[[scott, EMP]], operation=[UPDATE], 
updateColumnList=[[ENAME]],"
         + " sourceExpressionList=[['a']], flattened=[false])\n"
         + "  LogicalFilter(condition=[=($2, 'c')])\n"
         + "    LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(s, isLinux(expected));
+    relFn(relFn)
+        .assertThatPlan(isLinux(expected));
   }
 
   @Test void testTableModifyDelete() {
-    final FrameworkConfig config = RelBuilderTest.config().build();
-    final RelBuilder builder = RelBuilder.create(config);
-    RelNode filter = builder
-        .scan("EMP")
-        .filter(
-            builder.call(
-                SqlStdOperatorTable.EQUALS,
-                builder.field("JOB"),
-                builder.literal("c")))
-        .build();
-    LogicalTableModify modify = LogicalTableModify.create(
-        filter.getInput(0).getTable(),
-        (Prepare.CatalogReader) 
filter.getInput(0).getTable().getRelOptSchema(),
-        filter,
-        TableModify.Operation.DELETE,
-        null,
-        null,
-        false);
-    String relJson = RelOptUtil.dumpPlan("", modify,
-        SqlExplainFormat.JSON, SqlExplainLevel.EXPPLAN_ATTRIBUTES);
-    String s = deserializeAndDumpToTextFormat(getSchema(modify), relJson);
+    final Function<RelBuilder, RelNode> relFn = b ->
+        b.scan("EMP")
+            .filter(b.equals(b.field("JOB"), b.literal("c")))
+            .let(b2 -> {
+              final RelNode filter = b2.build();
+              final RelOptTable table = filter.getInput(0).getTable();
+              LogicalTableModify modify =
+                  LogicalTableModify.create(table,
+                      (Prepare.CatalogReader) table.getRelOptSchema(),
+                      filter,
+                      TableModify.Operation.DELETE,
+                      null,
+                      null,
+                      false);
+              return b2.push(modify);
+            })
+            .build();
     final String expected = ""
         + "LogicalTableModify(table=[[scott, EMP]], operation=[DELETE], 
flattened=[false])\n"
         + "  LogicalFilter(condition=[=($2, 'c')])\n"
         + "    LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(s, isLinux(expected));
+    relFn(relFn)
+        .assertThatPlan(isLinux(expected));
   }
 
   @Test void testTableModifyMerge() {
-    final FrameworkConfig config = RelBuilderTest.config().build();
-    final RelBuilder builder = RelBuilder.create(config);
-    RelNode deptScan = builder.scan("DEPT").build();
-    RelNode empScan = builder.scan("EMP").build();
-    builder.push(deptScan);
-    builder.push(empScan);
-    RelNode project = builder
-        .join(JoinRelType.LEFT,
-            builder.call(
-                SqlStdOperatorTable.EQUALS,
-                builder.field(2, 0, "DEPTNO"),
-                builder.field(2, 1, "DEPTNO")))
-        .project(
-            builder.literal(0),
-            builder.literal("x"),
-            builder.literal("x"),
-            builder.literal(0),
-            builder.literal("20200501 10:00:00"),
-            builder.literal(0),
-            builder.literal(0),
-            builder.literal(0),
-            builder.literal("false"),
-            builder.field(1, 0, 2),
-            builder.field(1, 0, 3),
-            builder.field(1, 0, 4),
-            builder.field(1, 0, 5),
-            builder.field(1, 0, 6),
-            builder.field(1, 0, 7),
-            builder.field(1, 0, 8),
-            builder.field(1, 0, 9),
-            builder.field(1, 0, 10),
-            builder.literal("a"))
-        .build();
-    // for sql:
-    // merge into emp using dept on emp.deptno = dept.deptno
-    // when matched then update set job = 'a'
-    // when not matched then insert values(0, 'x', 'x', 0, '20200501 
10:00:00', 0, 0, 0, 0)
-    LogicalTableModify modify = LogicalTableModify.create(
-        empScan.getTable(),
-        (Prepare.CatalogReader) empScan.getTable().getRelOptSchema(),
-        project,
-        TableModify.Operation.MERGE,
-        ImmutableList.of("ENAME"),
-        null,
-        false);
-    String relJson = RelOptUtil.dumpPlan("", modify,
-        SqlExplainFormat.JSON, SqlExplainLevel.EXPPLAN_ATTRIBUTES);
-    String s = deserializeAndDumpToTextFormat(getSchema(modify), relJson);
+    final Holder<RelOptTable> emp = Holder.empty();
+    final Holder<RelOptTable> dept = Holder.empty();
+    final Function<RelBuilder, RelNode> relFn = b ->
+        b.scan("DEPT")
+            .let(b2 -> {
+              dept.set(requireNonNull(b2.peek().getTable()));
+              return b2;
+            })
+            .scan("EMP")
+            .let(b2 -> {
+              emp.set(requireNonNull(b2.peek().getTable()));
+              return b2;
+            })
+            .join(JoinRelType.LEFT,
+                b.equals(b.field(2, 0, "DEPTNO"), b.field(2, 1, "DEPTNO")))
+            .project(b.literal(0),
+                b.literal("x"),
+                b.literal("x"),
+                b.literal(0),
+                b.literal("20200501 10:00:00"),
+                b.literal(0),
+                b.literal(0),
+                b.literal(0),
+                b.literal("false"),
+                b.field(1, 0, 2),
+                b.field(1, 0, 3),
+                b.field(1, 0, 4),
+                b.field(1, 0, 5),
+                b.field(1, 0, 6),
+                b.field(1, 0, 7),
+                b.field(1, 0, 8),
+                b.field(1, 0, 9),
+                b.field(1, 0, 10),
+                b.literal("a"))
+            .let(b2 -> {
+              // For SQL:
+              //   MERGE INTO emp USING dept ON emp.deptno = dept.deptno
+              //   WHEN MATCHED THEN
+              //     UPDATE SET job = 'a'
+              //   WHEN NOT MATCHED THEN
+              //     INSERT VALUES (0, 'x', 'x', 0, '20200501 10:00:00',
+              //         0, 0, 0, 0)
+              final RelNode project = b.build();
+              LogicalTableModify modify =
+                  LogicalTableModify.create(emp.get(),
+                      (Prepare.CatalogReader) emp.get().getRelOptSchema(),
+                      project,
+                      TableModify.Operation.MERGE,
+                      ImmutableList.of("ENAME"),
+                      null,
+                      false);
+              return b2.push(modify);
+            })
+            .build();
     final String expected = ""
         + "LogicalTableModify(table=[[scott, EMP]], operation=[MERGE], "
         + "updateColumnList=[[ENAME]], flattened=[false])\n"
@@ -1422,15 +1360,99 @@ class RelWriterTest {
         + "    LogicalJoin(condition=[=($0, $10)], joinType=[left])\n"
         + "      LogicalTableScan(table=[[scott, DEPT]])\n"
         + "      LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(s, isLinux(expected));
+    relFn(relFn)
+        .assertThatPlan(isLinux(expected));
   }
 
-  private RelNode createSortPlan(RelDistribution distribution) {
-    final FrameworkConfig config = RelBuilderTest.config().build();
-    final RelBuilder builder = RelBuilder.create(config);
+  /** Creates a plan that scans {@code EMP} and applies {@code sortExchange}
+   * with the given distribution and a collation on field 0. */
+  private RelNode createSortPlan(RelBuilder builder, RelDistribution distribution) {
     return builder.scan("EMP")
             .sortExchange(distribution,
                 RelCollations.of(0))
             .build();
   }
+
+  /** Test fixture.
+   *
+   * <p>Builds a relational expression by applying {@link #relFn} to a
+   * {@link RelBuilder} configured by {@code RelBuilderTest.config()}, then
+   * makes assertions about its JSON serialization and about the plan that
+   * {@code deserializeAndDump} produces from that JSON. */
+  static class Fixture {
+    /** Creates the relational expression under test. */
+    final Function<RelBuilder, RelNode> relFn;
+    /** Whether {@link #assertThatPlan} deserializes into a cluster whose
+     * planner has {@link RelDistributionTraitDef} registered. */
+    final boolean distribution;
+    /** Format passed to {@code deserializeAndDump} by
+     * {@link #assertThatPlan}. */
+    final SqlExplainFormat format;
+
+    Fixture(Function<RelBuilder, RelNode> relFn, boolean distribution,
+        SqlExplainFormat format) {
+      this.relFn = relFn;
+      this.distribution = distribution;
+      this.format = format;
+    }
+
+    /** Returns a fixture with the given distribution flag, or this fixture
+     * if the flag is unchanged. */
+    Fixture withDistribution(boolean distribution) {
+      if (distribution == this.distribution) {
+        return this;
+      }
+      return new Fixture(relFn, distribution, format);
+    }
+
+    /** Returns a fixture with the given plan format, or this fixture if the
+     * format is unchanged. */
+    Fixture withFormat(SqlExplainFormat format) {
+      if (format == this.format) {
+        return this;
+      }
+      return new Fixture(relFn, distribution, format);
+    }
+
+    /** Creates a {@link RelBuilder} configured by
+     * {@code RelBuilderTest.config()}. */
+    private static RelBuilder createBuilder() {
+      final FrameworkConfig config = RelBuilderTest.config().build();
+      return RelBuilder.create(config);
+    }
+
+    /** Serializes a relational expression to JSON, with attributes. */
+    private static String toJson(RelNode rel) {
+      return RelOptUtil.dumpPlan("", rel, SqlExplainFormat.JSON,
+          SqlExplainLevel.EXPPLAN_ATTRIBUTES);
+    }
+
+    /** Asserts that the JSON serialization of the expression matches. */
+    Fixture assertThatJson(Matcher<String> matcher) {
+      final RelNode rel = relFn.apply(createBuilder());
+      assertThat(toJson(rel), matcher);
+      return this;
+    }
+
+    /** Serializes the expression to JSON, passes the JSON to
+     * {@code deserializeAndDump}, and asserts that the resulting plan
+     * matches. */
+    Fixture assertThatPlan(Matcher<String> matcher) {
+      final RelBuilder b = createBuilder();
+      final RelNode rel = relFn.apply(b);
+      final String relJson = toJson(rel);
+      final String plan;
+      if (distribution) {
+        final VolcanoPlanner planner = new VolcanoPlanner();
+        planner.addRelTraitDef(RelDistributionTraitDef.INSTANCE);
+        final RelOptCluster cluster =
+            RelOptCluster.create(planner, b.getRexBuilder());
+        plan = deserializeAndDump(cluster, getSchema(rel), relJson, format);
+      } else {
+        plan = deserializeAndDump(getSchema(rel), relJson, format);
+      }
+      assertThat(plan, matcher);
+      return this;
+    }
+  }
 }

Reply via email to