This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit 3801b42c0a5237ad7ff6b4c22eec96bcde628034
Author: Julian Hyde <[email protected]>
AuthorDate: Mon Jan 16 20:05:55 2023 -0800

    [CALCITE-5342] Refactor SqlFunctions methods lastDay, addMonths, subtractMonths to use DateTimeUtils from Avatica
---
 .../calcite/adapter/enumerable/RexImpTable.java    |  34 ++++-
 .../org/apache/calcite/rel/type/TimeFrameSet.java  |   5 +-
 .../org/apache/calcite/runtime/SqlFunctions.java   | 151 ++++-----------------
 .../org/apache/calcite/util/BuiltInMethod.java     |   9 +-
 .../org/apache/calcite/test/SqlFunctionsTest.java  |  45 ------
 5 files changed, 63 insertions(+), 181 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
index 2454fe9f8c..c4845a7d29 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
@@ -527,7 +527,8 @@ public class RexImpTable {
       map.put(TIMESTAMP_TRUNC, map.get(FLOOR));
       map.put(TIME_TRUNC, map.get(FLOOR));
 
-      defineMethod(LAST_DAY, "lastDay", NullPolicy.STRICT);
+      map.put(LAST_DAY,
+          new LastDayImplementor("lastDay", BuiltInMethod.LAST_DAY));
       map.put(DAYNAME,
           new PeriodNameImplementor("dayName",
               BuiltInMethod.DAYNAME_WITH_TIMESTAMP,
@@ -2081,6 +2082,37 @@ public class RexImpTable {
     }
   }
 
+  /** Implementor for {@code LAST_DAY}; converts a TIMESTAMP operand to a DATE first. */
+  private static class LastDayImplementor extends MethodNameImplementor {
+    private final BuiltInMethod dateMethod;
+
+    LastDayImplementor(String methodName, BuiltInMethod dateMethod) {
+      super(methodName, NullPolicy.STRICT, false);
+      this.dateMethod = dateMethod;
+    }
+
+    @Override String getVariableName() {
+      return methodName;
+    }
+
+    @Override Expression implementSafe(final RexToLixTranslator translator,
+        final RexCall call, final List<Expression> argValueList) {
+      Expression operand = argValueList.get(0);
+      final RelDataType type = call.operands.get(0).getType();
+      switch (type.getSqlTypeName()) {
+      case TIMESTAMP:
+        operand =
+            Expressions.call(BuiltInMethod.TIMESTAMP_TO_DATE.method, operand);
+        // fall through
+      case DATE:
+        return Expressions.call(dateMethod.method.getDeclaringClass(),
+            dateMethod.method.getName(), operand);
+      default:
+        throw new AssertionError("unknown type " + type);
+      }
+    }
+  }
+
   /** Implementor for the {@code FLOOR} and {@code CEIL} functions. */
   private static class FloorImplementor extends MethodNameImplementor {
     final Method timestampMethod;
diff --git a/core/src/main/java/org/apache/calcite/rel/type/TimeFrameSet.java b/core/src/main/java/org/apache/calcite/rel/type/TimeFrameSet.java
index 52713dc122..07869b8504 100644
--- a/core/src/main/java/org/apache/calcite/rel/type/TimeFrameSet.java
+++ b/core/src/main/java/org/apache/calcite/rel/type/TimeFrameSet.java
@@ -19,7 +19,6 @@ package org.apache.calcite.rel.type;
 import org.apache.calcite.avatica.util.DateTimeUtils;
 import org.apache.calcite.avatica.util.TimeUnit;
 import org.apache.calcite.avatica.util.TimeUnitRange;
-import org.apache.calcite.runtime.SqlFunctions;
 import org.apache.calcite.util.NameMap;
 import org.apache.calcite.util.TimestampString;
 import org.apache.calcite.util.Util;
@@ -214,7 +213,7 @@ public class TimeFrameSet {
     if (perMonth != null
         && perMonth.getNumerator().equals(BigInteger.ONE)) {
       final int m = perMonth.getDenominator().intValueExact(); // e.g. 12 for YEAR
-      return SqlFunctions.addMonths(date, interval * m);
+      return DateTimeUtils.addMonths(date, interval * m);
     }
 
     // Unknown time frame. Return the original value unchanged.
@@ -235,7 +234,7 @@ public class TimeFrameSet {
     if (perMonth != null
         && perMonth.getNumerator().equals(BigInteger.ONE)) {
       final long m = perMonth.getDenominator().longValueExact(); // e.g. 12 for YEAR
-      return SqlFunctions.addMonths(timestamp, (int) (interval * m));
+      return DateTimeUtils.addMonths(timestamp, (int) (interval * m));
     }
 
     // Unknown time frame. Return the original value unchanged.
diff --git a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
index 9a7fa9b9b1..bb617f8cf1 100644
--- a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
+++ b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
@@ -84,8 +84,6 @@ import java.util.regex.Pattern;
 import static org.apache.calcite.linq4j.Nullness.castNonNull;
 import static org.apache.calcite.util.Static.RESOURCE;
 
-import static java.lang.Math.floorDiv;
-import static java.lang.Math.floorMod;
 import static java.nio.charset.StandardCharsets.UTF_8;
 import static java.util.Objects.requireNonNull;
 
@@ -403,8 +401,7 @@ public class SqlFunctions {
 
   /** SQL {@code ENDS_WITH(binary, binary)} function. */
   public static boolean endsWith(ByteString s0, ByteString s1) {
-    return s0.length() >= s1.length()
-        && s0.substring(s0.length() - s1.length()).equals(s1);
+    return s0.endsWith(s1);
   }
 
   /** SQL {@code STARTS_WITH(string, string)} function. */
@@ -414,8 +411,7 @@ public class SqlFunctions {
 
   /** SQL {@code STARTS_WITH(binary, binary)} function. */
   public static boolean startsWith(ByteString s0, ByteString s1) {
-    return s0.length() >= s1.length()
-        && s0.substring(0, s1.length()).equals(s1);
+    return s0.startsWith(s1);
   }
 
   /** SQL SUBSTRING(string FROM ...) function. */
@@ -2725,33 +2721,6 @@ public class SqlFunctions {
     return v - remainder;
   }
 
-  /**
-   * SQL {@code LAST_DAY} function.
-   *
-   * @param date days since epoch
-   * @return days of the last day of the month since epoch
-   */
-  public static int lastDay(int date) {
-    int y0 = (int) DateTimeUtils.unixDateExtract(TimeUnitRange.YEAR, date);
-    int m0 = (int) DateTimeUtils.unixDateExtract(TimeUnitRange.MONTH, date);
-    int last = lastDay(y0, m0);
-    return DateTimeUtils.ymdToUnixDate(y0, m0, last);
-  }
-
-  /**
-   * SQL {@code LAST_DAY} function.
-   *
-   * @param timestamp milliseconds from epoch
-   * @return milliseconds of the last day of the month since epoch
-   */
-  public static int lastDay(long timestamp) {
-    int date = (int) (timestamp / DateTimeUtils.MILLIS_PER_DAY);
-    int y0 = (int) DateTimeUtils.unixDateExtract(TimeUnitRange.YEAR, date);
-    int m0 = (int) DateTimeUtils.unixDateExtract(TimeUnitRange.MONTH, date);
-    int last = lastDay(y0, m0);
-    return DateTimeUtils.ymdToUnixDate(y0, m0, last);
-  }
-
   /**
    * SQL {@code DAYNAME} function, applied to a TIMESTAMP argument.
    *
@@ -2820,21 +2789,32 @@ public class SqlFunctions {
    * @return localDate
    */
   private static LocalDate timeStampToLocalDate(long timestamp) {
-    int date = (int) (timestamp / DateTimeUtils.MILLIS_PER_DAY);
+    int date = timestampToDate(timestamp);
     return dateToLocalDate(date);
   }
 
+  /** Converts a timestamp (milliseconds since epoch) to a date (days since
+   * epoch), rounding down so that pre-1970 instants map to the right day. */
+  public static int timestampToDate(long timestamp) {
+    return (int) Math.floorDiv(timestamp, DateTimeUtils.MILLIS_PER_DAY);
+  }
+
+  /** Converts a timestamp (milliseconds since epoch) to a time (milliseconds
+   * since midnight); never negative, even for timestamps before 1970. */
+  public static int timestampToTime(long timestamp) {
+    return (int) Math.floorMod(timestamp, DateTimeUtils.MILLIS_PER_DAY);
+  }
+
   /** SQL {@code CURRENT_TIMESTAMP} function. */
   @NonDeterministic
   public static long currentTimestamp(DataContext root) {
-    // Cast required for JDK 1.6.
-    return (Long) DataContext.Variable.CURRENT_TIMESTAMP.get(root);
+    return DataContext.Variable.CURRENT_TIMESTAMP.get(root);
   }
 
   /** SQL {@code CURRENT_TIME} function. */
   @NonDeterministic
   public static int currentTime(DataContext root) {
-    int time = (int) (currentTimestamp(root) % DateTimeUtils.MILLIS_PER_DAY);
+    int time = timestampToTime(currentTimestamp(root));
     if (time < 0) {
       time = (int) (time + DateTimeUtils.MILLIS_PER_DAY);
     }
@@ -2845,8 +2825,8 @@ public class SqlFunctions {
   @NonDeterministic
   public static int currentDate(DataContext root) {
     final long timestamp = currentTimestamp(root);
-    int date = (int) (timestamp / DateTimeUtils.MILLIS_PER_DAY);
-    final int time = (int) (timestamp % DateTimeUtils.MILLIS_PER_DAY);
+    int date = timestampToDate(timestamp);
+    final int time = timestampToTime(timestamp);
     if (time < 0) {
       --date;
     }
@@ -2856,19 +2836,18 @@ public class SqlFunctions {
   /** SQL {@code LOCAL_TIMESTAMP} function. */
   @NonDeterministic
   public static long localTimestamp(DataContext root) {
-    // Cast required for JDK 1.6.
-    return (Long) DataContext.Variable.LOCAL_TIMESTAMP.get(root);
+    return DataContext.Variable.LOCAL_TIMESTAMP.get(root);
   }
 
   /** SQL {@code LOCAL_TIME} function. */
   @NonDeterministic
   public static int localTime(DataContext root) {
-    return (int) (localTimestamp(root) % DateTimeUtils.MILLIS_PER_DAY);
+    return timestampToTime(localTimestamp(root));
   }
 
   @NonDeterministic
   public static TimeZone timeZone(DataContext root) {
-    return (TimeZone) DataContext.Variable.TIME_ZONE.get(root);
+    return DataContext.Variable.TIME_ZONE.get(root);
   }
 
   /** SQL {@code USER} function. */
@@ -2885,7 +2864,7 @@ public class SqlFunctions {
 
   @NonDeterministic
   public static Locale locale(DataContext root) {
-    return (Locale) DataContext.Variable.LOCALE.get(root);
+    return DataContext.Variable.LOCALE.get(root);
   }
 
   /** SQL {@code DATEADD} function applied to a custom time frame.
@@ -3335,90 +3314,6 @@ public class SqlFunctions {
     };
   }
 
-  /** Adds a given number of months to a timestamp, represented as the number
-   * of milliseconds since the epoch. */
-  public static long addMonths(long timestamp, int m) {
-    final long millis = floorMod(timestamp, DateTimeUtils.MILLIS_PER_DAY);
-    timestamp -= millis;
-    final long x =
-        addMonths((int) (timestamp / DateTimeUtils.MILLIS_PER_DAY), m);
-    return x * DateTimeUtils.MILLIS_PER_DAY + millis;
-  }
-
-  /** Adds a given number of months to a date, represented as the number of
-   * days since the epoch. */
-  public static int addMonths(int date, int m) {
-    int y0 = (int) DateTimeUtils.unixDateExtract(TimeUnitRange.YEAR, date);
-    int m0 = (int) DateTimeUtils.unixDateExtract(TimeUnitRange.MONTH, date);
-    int d0 = (int) DateTimeUtils.unixDateExtract(TimeUnitRange.DAY, date);
-    m0 += m;
-    int deltaYear = floorDiv(m0, 12);
-    y0 += deltaYear;
-    m0 = floorMod(m0, 12);
-    if (m0 == 0) {
-      y0 -= 1;
-      m0 += 12;
-    }
-
-    int last = lastDay(y0, m0);
-    if (d0 > last) {
-      d0 = last;
-    }
-    return DateTimeUtils.ymdToUnixDate(y0, m0, d0);
-  }
-
-  private static int lastDay(int y, int m) {
-    switch (m) {
-    case 2:
-      return y % 4 == 0
-          && (y % 100 != 0
-          || y % 400 == 0)
-          ? 29 : 28;
-    case 4:
-    case 6:
-    case 9:
-    case 11:
-      return 30;
-    default:
-      return 31;
-    }
-  }
-
-  /** Finds the number of months between two dates, each represented as the
-   * number of days since the epoch. */
-  public static int subtractMonths(int date0, int date1) {
-    if (date0 < date1) {
-      return -subtractMonths(date1, date0);
-    }
-    // Start with an estimate.
-    // Since no month has more than 31 days, the estimate is <= the true value.
-    int m = (date0 - date1) / 31;
-    for (;;) {
-      int date2 = addMonths(date1, m);
-      if (date2 >= date0) {
-        return m;
-      }
-      int date3 = addMonths(date1, m + 1);
-      if (date3 > date0) {
-        return m;
-      }
-      ++m;
-    }
-  }
-
-  public static int subtractMonths(long t0, long t1) {
-    final long millis0 = floorMod(t0, DateTimeUtils.MILLIS_PER_DAY);
-    final int d0 = (int) floorDiv(t0 - millis0, DateTimeUtils.MILLIS_PER_DAY);
-    final long millis1 = floorMod(t1, DateTimeUtils.MILLIS_PER_DAY);
-    final int d1 = (int) floorDiv(t1 - millis1, DateTimeUtils.MILLIS_PER_DAY);
-    int x = subtractMonths(d0, d1);
-    final long d2 = addMonths(d1, x);
-    if (d2 == d0 && millis0 < millis1) {
-      --x;
-    }
-    return x;
-  }
-
   /**
    * Implements the {@code .} (field access) operator on an object
    * whose type is not known until runtime.
diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
index ff9c5fcb30..22a36ebb9a 100644
--- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
+++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
@@ -410,9 +410,9 @@ public enum BuiltInMethod {
   MULTI_STRING_CONCAT(SqlFunctions.class, "concatMulti", String[].class),
   FLOOR_DIV(Math.class, "floorDiv", long.class, long.class),
   FLOOR_MOD(Math.class, "floorMod", long.class, long.class),
-  ADD_MONTHS(SqlFunctions.class, "addMonths", long.class, int.class),
-  ADD_MONTHS_INT(SqlFunctions.class, "addMonths", int.class, int.class),
-  SUBTRACT_MONTHS(SqlFunctions.class, "subtractMonths", long.class,
+  ADD_MONTHS(DateTimeUtils.class, "addMonths", long.class, int.class),
+  ADD_MONTHS_INT(DateTimeUtils.class, "addMonths", int.class, int.class),
+  SUBTRACT_MONTHS(DateTimeUtils.class, "subtractMonths", long.class,
       long.class),
   FLOOR(SqlFunctions.class, "floor", int.class, int.class),
   CEIL(SqlFunctions.class, "ceil", int.class, int.class),
@@ -528,7 +528,8 @@ public enum BuiltInMethod {
       DataContext.class, String.class, long.class),
   CUSTOM_TIMESTAMP_CEIL(SqlFunctions.class, "customTimestampCeil",
       DataContext.class, String.class, long.class),
-  LAST_DAY(SqlFunctions.class, "lastDay", int.class),
+  TIMESTAMP_TO_DATE(SqlFunctions.class, "timestampToDate", long.class),
+  LAST_DAY(DateTimeUtils.class, "lastDay", int.class),
   DAYNAME_WITH_TIMESTAMP(SqlFunctions.class,
       "dayNameWithTimestamp", long.class, Locale.class),
   DAYNAME_WITH_DATE(SqlFunctions.class,
diff --git a/core/src/test/java/org/apache/calcite/test/SqlFunctionsTest.java b/core/src/test/java/org/apache/calcite/test/SqlFunctionsTest.java
index ea1fdc2d6f..9d2c1b89ec 100644
--- a/core/src/test/java/org/apache/calcite/test/SqlFunctionsTest.java
+++ b/core/src/test/java/org/apache/calcite/test/SqlFunctionsTest.java
@@ -39,8 +39,6 @@ import static org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY;
 import static org.apache.calcite.avatica.util.DateTimeUtils.dateStringToUnixDate;
 import static org.apache.calcite.avatica.util.DateTimeUtils.timeStringToUnixDate;
 import static org.apache.calcite.avatica.util.DateTimeUtils.timestampStringToUnixDate;
-import static org.apache.calcite.avatica.util.DateTimeUtils.ymdToUnixDate;
-import static org.apache.calcite.runtime.SqlFunctions.addMonths;
 import static org.apache.calcite.runtime.SqlFunctions.charLength;
 import static org.apache.calcite.runtime.SqlFunctions.concat;
 import static org.apache.calcite.runtime.SqlFunctions.fromBase64;
@@ -57,7 +55,6 @@ import static org.apache.calcite.runtime.SqlFunctions.posixRegex;
 import static org.apache.calcite.runtime.SqlFunctions.regexpReplace;
 import static org.apache.calcite.runtime.SqlFunctions.rtrim;
 import static org.apache.calcite.runtime.SqlFunctions.sha1;
-import static org.apache.calcite.runtime.SqlFunctions.subtractMonths;
 import static org.apache.calcite.runtime.SqlFunctions.toBase64;
 import static org.apache.calcite.runtime.SqlFunctions.toInt;
 import static org.apache.calcite.runtime.SqlFunctions.toIntOptional;
@@ -71,7 +68,6 @@ import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.CoreMatchers.nullValue;
 import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.core.AnyOf.anyOf;
 import static org.junit.jupiter.api.Assertions.assertSame;
 import static org.junit.jupiter.api.Assertions.fail;
 
@@ -302,47 +298,6 @@ class SqlFunctionsTest {
     return trim(true, true, " ", s);
   }
 
-  @Test void testAddMonths() {
-    checkAddMonths(2016, 1, 1, 2016, 2, 1, 1);
-    checkAddMonths(2016, 1, 1, 2017, 1, 1, 12);
-    checkAddMonths(2016, 1, 1, 2017, 2, 1, 13);
-    checkAddMonths(2016, 1, 1, 2015, 1, 1, -12);
-    checkAddMonths(2016, 1, 1, 2018, 10, 1, 33);
-    checkAddMonths(2016, 1, 31, 2016, 4, 30, 3);
-    checkAddMonths(2016, 4, 30, 2016, 7, 30, 3);
-    checkAddMonths(2016, 1, 31, 2016, 2, 29, 1);
-    checkAddMonths(2016, 3, 31, 2016, 2, 29, -1);
-    checkAddMonths(2016, 3, 31, 2116, 3, 31, 1200);
-    checkAddMonths(2016, 2, 28, 2116, 2, 28, 1200);
-    checkAddMonths(2019, 9, 1, 2020, 3, 1, 6);
-    checkAddMonths(2019, 9, 1, 2016, 8, 1, -37);
-  }
-
-  private void checkAddMonths(int y0, int m0, int d0, int y1, int m1, int d1,
-      int months) {
-    final int date0 = ymdToUnixDate(y0, m0, d0);
-    final long date = addMonths(date0, months);
-    final int date1 = ymdToUnixDate(y1, m1, d1);
-    assertThat((int) date, is(date1));
-
-    assertThat(subtractMonths(date1, date0),
-        anyOf(is(months), is(months + 1)));
-    assertThat(subtractMonths(date1 + 1, date0),
-        anyOf(is(months), is(months + 1)));
-    assertThat(subtractMonths(date1, date0 + 1),
-        anyOf(is(months), is(months - 1)));
-    assertThat(subtractMonths(d2ts(date1, 1), d2ts(date0, 0)),
-        anyOf(is(months), is(months + 1)));
-    assertThat(subtractMonths(d2ts(date1, 0), d2ts(date0, 1)),
-        anyOf(is(months - 1), is(months), is(months + 1)));
-  }
-
-  /** Converts a date (days since epoch) and milliseconds (since midnight)
-   * into a timestamp (milliseconds since epoch). */
-  private long d2ts(int date, int millis) {
-    return date * DateTimeUtils.MILLIS_PER_DAY + millis;
-  }
-
   @Test void testFloor() {
     checkFloor(0, 10, 0);
     checkFloor(27, 10, 20);

Reply via email to