This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new f2c07aa  [FLINK-22405] [table-runtime] Support CharType for the LeadLagAggFunction in batch mode
f2c07aa is described below

commit f2c07aaee210e2205187dc58a92a8d96582d2369
Author: liliwei <[email protected]>
AuthorDate: Thu Aug 5 19:42:42 2021 +0800

    [FLINK-22405] [table-runtime] Support CharType for the LeadLagAggFunction in batch mode
    
    This closes #16650
---
 .../functions/aggfunctions/LeadLagAggFunction.java | 17 ++++++
 .../planner/plan/utils/AggFunctionFactory.scala    |  3 +
 .../batch/sql/agg/AggregateITCaseBase.scala        | 68 +++++++++++++++++++++-
 3 files changed, 87 insertions(+), 1 deletion(-)

diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/LeadLagAggFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/LeadLagAggFunction.java
index f016d6c..f7af12c 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/LeadLagAggFunction.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/LeadLagAggFunction.java
@@ -23,6 +23,7 @@ import org.apache.flink.table.expressions.Expression;
 import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
 import org.apache.flink.table.runtime.operators.over.frame.OffsetOverFrame;
 import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.CharType;
 import org.apache.flink.table.types.logical.DecimalType;
 import org.apache.flink.table.types.logical.TimeType;
 import org.apache.flink.table.types.logical.TimestampType;
@@ -230,6 +231,22 @@ public abstract class LeadLagAggFunction extends DeclarativeAggregateFunction {
         }
     }
 
+    /** Lead/lag over CHAR(n) operands; the result type reuses the operand's length n. */
+    public static class CharLeadLagAggFunction extends LeadLagAggFunction {
+
+        private final CharType charType;
+
+        public CharLeadLagAggFunction(int operandCount, CharType charType) {
+            super(operandCount);
+            this.charType = charType;
+        }
+
+        @Override
+        public DataType getResultType() {
+            return DataTypes.CHAR(charType.getLength());
+        }
+    }
+
     /** DateLeadLagAggFunction. */
     public static class DateLeadLagAggFunction extends LeadLagAggFunction {
 
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
index a2b795b..860b829 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
@@ -368,6 +368,9 @@ class AggFunctionFactory(
         new LeadLagAggFunction.BooleanLeadLagAggFunction(argTypes.length)
       case VARCHAR =>
         new LeadLagAggFunction.StringLeadLagAggFunction(argTypes.length)
+      case CHAR =>
+        val d = argTypes(0).asInstanceOf[CharType]
+        new LeadLagAggFunction.CharLeadLagAggFunction(argTypes.length, d);
       case DATE =>
         new LeadLagAggFunction.DateLeadLagAggFunction(argTypes.length)
       case TIME_WITHOUT_TIME_ZONE =>
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/AggregateITCaseBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/AggregateITCaseBase.scala
index 2d23a83..05f6174 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/AggregateITCaseBase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/AggregateITCaseBase.scala
@@ -21,7 +21,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
 import org.apache.flink.api.java.typeutils.RowTypeInfo
 import org.apache.flink.api.scala._
-import org.apache.flink.table.api.{TableException, Types}
+import org.apache.flink.table.api.{DataTypes, TableException, Types}
 import org.apache.flink.table.data.DecimalDataUtils
 import org.apache.flink.table.planner.runtime.utils.BatchTestBase
 import org.apache.flink.table.planner.runtime.utils.BatchTestBase.row
@@ -865,6 +865,72 @@ abstract class AggregateITCaseBase(testName: String) extends BatchTestBase {
     // that doesn't make sense, and we do not support it.
   }
 
+  @Test
+  def testLeadLag(): Unit = {
+    // Covers LEAD/LAG over each column type below, including CHAR (FLINK-22405).
+    val testAllDataTypeCardinality = tEnv.fromValues(
+      DataTypes.ROW(
+        DataTypes.FIELD("a", DataTypes.STRING()),
+        DataTypes.FIELD("b", DataTypes.TINYINT()),
+        DataTypes.FIELD("c", DataTypes.SMALLINT()),
+        DataTypes.FIELD("d", DataTypes.INT()),
+        DataTypes.FIELD("e", DataTypes.BIGINT()),
+        DataTypes.FIELD("f", DataTypes.FLOAT()),
+        DataTypes.FIELD("g", DataTypes.DOUBLE()),
+        DataTypes.FIELD("h", DataTypes.BOOLEAN()),
+        DataTypes.FIELD("i", DataTypes.VARCHAR(20)),
+        DataTypes.FIELD("j", DataTypes.CHAR(20)),
+        DataTypes.FIELD("k", DataTypes.DATE()),
+        DataTypes.FIELD("l", DataTypes.TIME()),
+        DataTypes.FIELD("m", DataTypes.TIMESTAMP()),
+        DataTypes.FIELD("n", DataTypes.DECIMAL(3, 2))
+      ),
+      row("Alice", 1, 1, 2, 9223, -2.3F, 9.9D, "true", "varchar", "char",
+        "2021-8-3", "20:8:17", "2021-8-3 20:8:29", 9.99),
+      row("Carol", null, null, null, null, null, null, null, null,
+        null, null, null, null, null),
+      row("Bob", 1, 1, 2, 9223, -2.3F, 9.9D, "true", "varchar", "char",
+        "2021-8-3", "20:8:17", "2021-8-3 20:8:29", 9.99)
+    )
+    // Distinct keys in `a` fix the window order (Alice, Bob, Carol) instead of relying on ties.
+    checkResult(
+      s"""
+         |SELECT
+         |  a,
+         |  b, LEAD(b, 1) over (order by a)  AS bLead, LAG(b, 1) over (order by a)  AS bLag,
+         |  c, LEAD(c, 1) over (order by a)  AS cLead, LAG(c, 1) over (order by a)  AS cLag,
+         |  d, LEAD(d, 1) over (order by a)  AS dLead, LAG(d, 1) over (order by a)  AS dLag,
+         |  e, LEAD(e, 1) over (order by a)  AS eLead, LAG(e, 1) over (order by a)  AS eLag,
+         |  f, LEAD(f, 1) over (order by a)  AS fLead, LAG(f, 1) over (order by a)  AS fLag,
+         |  g, LEAD(g, 1) over (order by a)  AS gLead, LAG(g, 1) over (order by a)  AS gLag,
+         |  h, LEAD(h, 1) over (order by a)  AS hLead, LAG(h, 1) over (order by a)  AS hLag,
+         |  i, LEAD(i, 1) over (order by a)  AS iLead, LAG(i, 1) over (order by a)  AS iLag,
+         |  j, LEAD(j, 1) over (order by a)  AS jLead, LAG(j, 1) over (order by a)  AS jLag,
+         |  k, LEAD(k, 1) over (order by a)  AS kLead, LAG(k, 1) over (order by a)  AS kLag,
+         |  l, LEAD(l, 1) over (order by a)  AS lLead, LAG(l, 1) over (order by a)  AS lLag,
+         |  m, LEAD(m, 1) over (order by a)  AS mLead, LAG(m, 1) over (order by a)  AS mLag,
+         |  n, LEAD(n, 1) over (order by a)  AS nLead, LAG(n, 1) over (order by a)  AS nLag
+         |
+         |FROM ${testAllDataTypeCardinality}
+         |order by a
+         |""".stripMargin,
+      Seq(
+        row("Alice", 1, 1, null, 1, 1, null, 2, 2, null, 9223, 9223, null, -2.3, -2.3, null,
+          9.9, 9.9, null, true, true, null, "varchar", "varchar", null, "char                ",
+          "char                ", null, "2021-08-03", "2021-08-03", null, "20:08:17", "20:08:17",
+          null, "2021-08-03T20:08:29", "2021-08-03T20:08:29", null, 9.99, 9.99, null),
+        row("Bob", 1, null, 1, 1, null, 1, 2, null, 2, 9223, null, 9223,
+          -2.3, null, -2.3, 9.9, null, 9.9, true, null, true, "varchar", null,
+          "varchar", "char                ", null, "char                ", "2021-08-03", null,
+          "2021-08-03", "20:08:17", null, "20:08:17", "2021-08-03T20:08:29", null,
+          "2021-08-03T20:08:29", 9.99, null, 9.99),
+        row("Carol", null, null, 1, null, null, 1, null, null, 2, null, null, 9223, null,
+          null, -2.3, null, null, 9.9, null, null, true, null, null, "varchar", null, null,
+          "char                ", null, null, "2021-08-03", null, null, "20:08:17", null, null,
+          "2021-08-03T20:08:29", null, null, 9.99)
+      ))
+  }
+
   // TODO support csv
 //  @Test
 //  def testMultiGroupBys(): Unit = {

Reply via email to