This is an automated email from the ASF dual-hosted git repository.

mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit aea65f23d3754aab03c16fe43479192594b4a06b
Author: Mihai Budiu <[email protected]>
AuthorDate: Wed Aug 14 11:47:30 2024 -0700

    [CALCITE-6372] Add ASOF join to the Calcite validator
    
    Signed-off-by: Mihai Budiu <[email protected]>
---
 .../main/java/org/apache/calcite/sql/SqlKind.java  |  14 ++
 .../apache/calcite/sql/validate/JoinNamespace.java |   1 +
 .../sql/validate/SqlNonNullableAccessors.java      |   6 +
 .../calcite/sql/validate/SqlValidatorImpl.java     | 144 +++++++++++++++++++++
 .../calcite/sql/validate/SqlValidatorUtil.java     |   1 +
 .../org/apache/calcite/test/SqlValidatorTest.java  |  83 +++++++++++-
 6 files changed, 248 insertions(+), 1 deletion(-)

diff --git a/core/src/main/java/org/apache/calcite/sql/SqlKind.java 
b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
index 0b6b48bd09..e292990e47 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlKind.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
@@ -1479,6 +1479,20 @@ public enum SqlKind {
           LESS_THAN, GREATER_THAN,
           GREATER_THAN_OR_EQUAL, LESS_THAN_OR_EQUAL);
 
+  /**
+   * Comparison operators that order values.
+   *
+   * <p>Consists of:
+   * {@link #LESS_THAN},
+   * {@link #GREATER_THAN},
+   * {@link #LESS_THAN_OR_EQUAL},
+   * {@link #GREATER_THAN_OR_EQUAL}.
+   */
+  public static final Set<SqlKind> ORDER_COMPARISON =
+      EnumSet.of(
+          LESS_THAN, GREATER_THAN,
+          GREATER_THAN_OR_EQUAL, LESS_THAN_OR_EQUAL);
+
   /**
    * Category of binary arithmetic.
    *
diff --git 
a/core/src/main/java/org/apache/calcite/sql/validate/JoinNamespace.java 
b/core/src/main/java/org/apache/calcite/sql/validate/JoinNamespace.java
index f816d6deb5..c533796acd 100644
--- a/core/src/main/java/org/apache/calcite/sql/validate/JoinNamespace.java
+++ b/core/src/main/java/org/apache/calcite/sql/validate/JoinNamespace.java
@@ -48,6 +48,7 @@ class JoinNamespace extends AbstractNamespace {
     final RelDataTypeFactory typeFactory = validator.getTypeFactory();
     switch (join.getJoinType()) {
     case LEFT:
+    case LEFT_ASOF:
       rightType = typeFactory.createTypeWithNullability(rightType, true);
       break;
     case RIGHT:
diff --git 
a/core/src/main/java/org/apache/calcite/sql/validate/SqlNonNullableAccessors.java
 
b/core/src/main/java/org/apache/calcite/sql/validate/SqlNonNullableAccessors.java
index eab9007b96..7fbca5f888 100644
--- 
a/core/src/main/java/org/apache/calcite/sql/validate/SqlNonNullableAccessors.java
+++ 
b/core/src/main/java/org/apache/calcite/sql/validate/SqlNonNullableAccessors.java
@@ -16,6 +16,7 @@
  */
 package org.apache.calcite.sql.validate;
 
+import org.apache.calcite.sql.SqlAsofJoin;
 import org.apache.calcite.sql.SqlCallBinding;
 import org.apache.calcite.sql.SqlDelete;
 import org.apache.calcite.sql.SqlJoin;
@@ -71,6 +72,11 @@ public class SqlNonNullableAccessors {
         () -> "getCondition of " + safeToString(join));
   }
 
+  public static SqlNode getMatchCondition(SqlAsofJoin join) {
+    return requireNonNull(join.getMatchCondition(),
+        () -> "getMatchCondition of " + safeToString(join));
+  }
+
   @API(since = "1.27", status = API.Status.EXPERIMENTAL)
   static SqlNode getNode(ScopeChild child) {
     return requireNonNull(child.namespace.getNode(),
diff --git 
a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java 
b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java
index e433a5254c..b6813c751d 100644
--- a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java
+++ b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java
@@ -48,6 +48,7 @@ import org.apache.calcite.sql.SqlAccessEnum;
 import org.apache.calcite.sql.SqlAccessType;
 import org.apache.calcite.sql.SqlAggFunction;
 import org.apache.calcite.sql.SqlAsOperator;
+import org.apache.calcite.sql.SqlAsofJoin;
 import org.apache.calcite.sql.SqlBasicCall;
 import org.apache.calcite.sql.SqlCall;
 import org.apache.calcite.sql.SqlCallBinding;
@@ -170,6 +171,7 @@ import static org.apache.calcite.sql.SqlUtil.stripAs;
 import static org.apache.calcite.sql.type.NonNullableAccessors.getCharset;
 import static org.apache.calcite.sql.type.NonNullableAccessors.getCollation;
 import static 
org.apache.calcite.sql.validate.SqlNonNullableAccessors.getCondition;
+import static 
org.apache.calcite.sql.validate.SqlNonNullableAccessors.getMatchCondition;
 import static org.apache.calcite.sql.validate.SqlNonNullableAccessors.getTable;
 import static org.apache.calcite.util.Static.RESOURCE;
 import static org.apache.calcite.util.Util.first;
@@ -2398,6 +2400,8 @@ public class SqlValidatorImpl implements 
SqlValidatorWithHints {
    *                      scope
    * @return registered node, usually the same as {@code node}
    */
+  // CHECKSTYLE: OFF
+  // CheckStyle thinks this method is too long
   private SqlNode registerFrom(
       SqlValidatorScope parentScope0,
       SqlValidatorScope usingScope,
@@ -2559,6 +2563,7 @@ public class SqlValidatorImpl implements 
SqlValidatorWithHints {
       boolean forceRightNullable = forceNullable;
       switch (join.getJoinType()) {
       case LEFT:
+      case LEFT_ASOF:
         forceRightNullable = true;
         break;
       case RIGHT:
@@ -2777,6 +2782,7 @@ public class SqlValidatorImpl implements 
SqlValidatorWithHints {
       throw Util.unexpected(kind);
     }
   }
+  // CHECKSTYLE: ON
 
   protected boolean shouldAllowOverRelation() {
     return false;
@@ -3774,11 +3780,149 @@ public class SqlValidatorImpl implements 
SqlValidatorWithHints {
             RESOURCE.crossJoinDisallowsCondition());
       }
       break;
+    case LEFT_ASOF:
+    case ASOF: {
+      // In addition to the standard join checks, the ASOF join requires the
+      // ON conditions to be a conjunction of simple equalities from both 
relations.
+      SqlAsofJoin asof = (SqlAsofJoin) join;
+      SqlNode matchCondition = getMatchCondition(asof);
+      matchCondition = expand(matchCondition, joinScope);
+      join.setOperand(6, matchCondition);
+      validateWhereOrOn(joinScope, matchCondition, "MATCH_CONDITION");
+      SqlNode condition = join.getCondition();
+      if (condition == null) {
+        throw newValidationError(join, RESOURCE.joinRequiresCondition());
+      }
+      ConjunctionOfEqualities conj = new ConjunctionOfEqualities();
+      condition.accept(conj);
+      if (conj.illegal) {
+        throw newValidationError(condition, 
RESOURCE.asofConditionMustBeComparison());
+      }
+
+      CompareFromBothSides validateCompare =
+          new CompareFromBothSides(joinScope,
+              catalogReader, RESOURCE.asofConditionMustBeComparison());
+      condition.accept(validateCompare);
+
+      // It also requires the MATCH condition to be a comparison.
+      if (!(matchCondition instanceof SqlCall)) {
+        throw newValidationError(matchCondition, 
RESOURCE.asofMatchMustBeComparison());
+      }
+      SqlCall matchCall = (SqlCall) matchCondition;
+      SqlOperator operator = matchCall.getOperator();
+      if (!SqlKind.ORDER_COMPARISON.contains(operator.kind)) {
+        throw newValidationError(matchCondition, 
RESOURCE.asofMatchMustBeComparison());
+      }
+
+      // Change the exception in validateCompare when we validate the match 
condition
+      validateCompare =
+          new CompareFromBothSides(joinScope,
+              catalogReader, RESOURCE.asofMatchMustBeComparison());
+      matchCondition.accept(validateCompare);
+      break;
+    }
     default:
       throw Util.unexpected(joinType);
     }
   }
 
+  /**
+   * Shuttle which determines whether all SqlCalls that are
+   * comparisons are comparing columns from both namespaces.
+   * The shuttle will throw an exception if that happens.
+   * If it returns all SqlCalls have the expected shape.
+   */
+  private class CompareFromBothSides extends SqlShuttle {
+    final SqlValidatorScope scope;
+    final SqlValidatorCatalogReader catalogReader;
+    final Resources.ExInst<SqlValidatorException> exception;
+
+    private CompareFromBothSides(
+        SqlValidatorScope scope,
+        SqlValidatorCatalogReader catalogReader,
+        Resources.ExInst<SqlValidatorException> exception) {
+      this.scope = scope;
+      this.catalogReader = catalogReader;
+      this.exception = exception;
+    }
+
+    @Override public @Nullable SqlNode visit(final SqlCall call) {
+      SqlKind kind = call.getKind();
+      if (SqlKind.COMPARISON.contains(kind)) {
+        assert call.getOperandList().size() == 2;
+
+        boolean leftFound = false;
+        boolean rightFound = false;
+        // The two sides of the comparison must be from different tables
+        for (SqlNode operand : call.getOperandList()) {
+          if (!(operand instanceof SqlIdentifier)) {
+            throw newValidationError(call, this.exception);
+          }
+          // We know that all identifiers have been expanded by the caller,
+          // so they have the shape namespace.field
+          SqlIdentifier id = (SqlIdentifier) operand;
+          final SqlNameMatcher nameMatcher = catalogReader.nameMatcher();
+          final SqlValidatorScope.ResolvedImpl resolved = new 
SqlValidatorScope.ResolvedImpl();
+          // Lookup just the first component of the name
+          scope.resolve(id.names.subList(0, id.names.size() - 1), nameMatcher, 
false, resolved);
+          SqlValidatorScope.Resolve resolve = resolved.only();
+          int index = resolve.path.steps().get(0).i;
+          if (index == 0) {
+            leftFound = true;
+          }
+          if (index == 1) {
+            rightFound = true;
+          }
+
+          if (!leftFound && !rightFound) {
+            throw newValidationError(call, this.exception);
+          }
+        }
+        if (!leftFound || !rightFound) {
+          // The comparison does not look at both tables
+          throw newValidationError(call, this.exception);
+        }
+      }
+      return super.visit(call);
+    }
+  }
+
+  /**
+   * Shuttle which determines whether an expression is a simple conjunction
+   * of equalities. */
+  private static class ConjunctionOfEqualities extends SqlShuttle {
+    boolean illegal = false;
+
+    // Check an AND node.  Children can be AND nodes or EQUAL nodes.
+    void checkAnd(SqlCall call) {
+      // This doesn't seem to use the visitor pattern,
+      // because we recurse explicitly on the tree structure.
+      // The visitor is useful to make sure no other kinds of operations
+      // appear in the expression tree.
+      List<SqlNode> operands = call.getOperandList();
+      for (SqlNode operand : operands) {
+        if (operand.getKind() == SqlKind.AND) {
+          this.checkAnd((SqlCall) operand);
+          return;
+        }
+        if (operand.getKind() != SqlKind.EQUALS) {
+          illegal = true;
+        }
+      }
+    }
+
+    @Override public @Nullable SqlNode visit(final 
org.apache.calcite.sql.SqlCall call) {
+      SqlKind kind = call.getKind();
+      if (kind != SqlKind.AND && kind != SqlKind.EQUALS) {
+        illegal = true;
+      }
+      if (kind == SqlKind.AND) {
+        this.checkAnd(call);
+      }
+      return super.visit(call);
+    }
+  }
+
   /**
    * Throws an error if there is an aggregate or windowed aggregate in the
    * given clause.
diff --git 
a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java 
b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java
index ba9888b51a..3f88e78fff 100644
--- a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java
+++ b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java
@@ -538,6 +538,7 @@ public class SqlValidatorUtil {
     requireNonNull(systemFieldList, "systemFieldList");
     switch (joinType) {
     case LEFT:
+    case LEFT_ASOF:
       rightType =
           typeFactory.createTypeWithNullability(
               requireNonNull(rightType, "rightType"), true);
diff --git a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java 
b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
index de415956a1..e040836ec5 100644
--- a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
+++ b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
@@ -3249,6 +3249,87 @@ public class SqlValidatorTest extends 
SqlValidatorTestCase {
     sql(query3).type(type);
   }
 
+  @Test void testAsOfJoin() {
+    final String type0 = "RecordType(INTEGER NOT NULL EMPNO, INTEGER NOT NULL 
DEPTNO) NOT NULL";
+    final String sql0 = "select emp.empno, dept.deptno from emp asof join 
dept\n"
+        + "match_condition emp.deptno <= dept.deptno\n"
+        + "on emp.ename = dept.name";
+    sql(sql0).type(type0);
+    // ASOF join of a join result
+    final String sql1 = "select emp.empno, D.deptno from emp asof join\n"
+        + "(select L.* FROM dept AS L join dept AS R on L.deptno = R.deptno) 
as D\n"
+        + "match_condition emp.deptno <= D.deptno\n"
+        + "on emp.ename = D.name";
+    sql(sql1).type(type0);
+
+    // LEFT ASOF JOIN
+    final String sql2 = "select emp.empno, dept.deptno from emp left asof join 
dept\n"
+        + "match_condition emp.deptno <= dept.deptno\n"
+        + "on emp.ename = dept.name";
+    final String type2 = "RecordType(INTEGER NOT NULL EMPNO, INTEGER DEPTNO) 
NOT NULL";
+    sql(sql2).type(type2);
+    // LEFT ASOF join of a join result
+    final String sql3 = "select emp.empno, D.deptno from emp left asof join\n"
+        + "(select L.* FROM dept AS L join dept AS R on L.deptno = R.deptno) 
as D\n"
+        + "match_condition emp.deptno <= D.deptno\n"
+        + "on emp.ename = D.name";
+    sql(sql3).type(type2);
+
+    // Longer sequence of comparisons
+    final String sql6 = "select emp.empno, dept.deptno from emp asof join 
dept\n"
+        + "match_condition emp.deptno <= dept.deptno\n"
+        + "on emp.ename = dept.name AND emp.deptno = dept.deptno AND emp.job = 
dept.name";
+    sql(sql6).type(type0);
+
+    // No table specified for on condition
+    final String sql4 = "select emp.empno, dept.deptno from emp asof join 
dept\n"
+        + "match_condition emp.deptno <= dept.deptno\n"
+        + "on ename = name";
+    sql(sql4).type(type0);
+
+    // No table specified for match condition
+    final String sql5 = "select emp.empno, dno as deptno from emp asof join "
+        + "(select deptno as dno, name from dept)\n"
+        + "match_condition deptno <= dno\n"
+        + "on ename = name";
+    sql(sql5).type(type0);
+
+    // Failure cases
+    // match condition is not an inequality test
+    sql("select emp.empno from emp asof join dept\n"
+        + "match_condition ^emp.deptno IN (1, 2)^\n"
+        + "on emp.ename = dept.name")
+        .fails(
+            "ASOF JOIN MATCH_CONDITION must be a comparison between columns 
from the two inputs");
+    // match condition does not compare columns from both tables
+    sql("select emp.empno from emp asof join dept\n"
+        + "match_condition ^emp.deptno < 12^\n"
+        + "on emp.ename = dept.name")
+        .fails(
+            "ASOF JOIN MATCH_CONDITION must be a comparison between columns 
from the two inputs");
+    // comparison is not a conjunction of equality tests
+    sql("select emp.empno from emp asof join dept\n"
+        + "match_condition emp.deptno < dept.deptno\n"
+        + "on ^emp.ename < 'foo'^")
+        .fails("ASOF JOIN condition must be a conjunction of equality 
comparisons");
+    // comparison contains an equality test that does not check both tables 
joined
+    sql("select emp.empno from emp asof join dept\n"
+        + "match_condition emp.deptno < dept.deptno\n"
+        + "on ^emp.ename = 'foo'^")
+        .fails("ASOF JOIN condition must be a conjunction of equality 
comparisons");
+    // comparison contains is not a conjunction
+    sql("select emp.empno from emp asof join dept\n"
+        + "match_condition emp.deptno < dept.deptno\n"
+        + "on ^emp.ename = dept.name OR emp.deptno = dept.deptno^")
+        .fails("ASOF JOIN condition must be a conjunction of equality 
comparisons");
+    // comparison is not a conjunction
+    sql("select * from (VALUES(true, false)) AS T0(b0, b1)\n"
+        + "asof join (VALUES(false, false)) AS T1(b0, b1)\n"
+        + "match_condition T0.b0 < T1.b0\n"
+        + "on ^T0.b1 AND T1.b1^")
+        .fails("ASOF JOIN condition must be a conjunction of equality 
comparisons");
+  }
+
   @Test void testInvalidWindowFunctionWithGroupBy() {
     sql("select max(^empno^) over () from emp\n"
         + "group by deptno")
@@ -5079,7 +5160,7 @@ public class SqlValidatorTest extends 
SqlValidatorTestCase {
   @Test void testInnerJoinWithoutUsingOrOnFails() {
     sql("select * from emp inner ^join^ dept\n"
         + "where emp.deptno = dept.deptno")
-        .fails("INNER, LEFT, RIGHT or FULL join requires a condition "
+        .fails("INNER, LEFT, RIGHT, FULL, or ASOF join requires a condition "
             + "\\(NATURAL keyword or ON or USING clause\\)");
   }
 

Reply via email to