http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/stringExpressions.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/stringExpressions.scala
index a692c9e..56b5b5e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/stringExpressions.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/stringExpressions.scala
@@ -32,7 +32,7 @@ import org.apache.flink.api.table.validate._
 case class CharLength(child: Expression) extends UnaryExpression {
   override private[flink] def resultType: TypeInformation[_] = INT_TYPE_INFO
 
-  override private[flink] def validateInput(): ExprValidationResult = {
+  override private[flink] def validateInput(): ValidationResult = {
     if (child.resultType == STRING_TYPE_INFO) {
       ValidationSuccess
     } else {
@@ -55,7 +55,7 @@ case class CharLength(child: Expression) extends UnaryExpression {
 case class InitCap(child: Expression) extends UnaryExpression {
   override private[flink] def resultType: TypeInformation[_] = STRING_TYPE_INFO
 
-  override private[flink] def validateInput(): ExprValidationResult = {
+  override private[flink] def validateInput(): ValidationResult = {
     if (child.resultType == STRING_TYPE_INFO) {
       ValidationSuccess
     } else {
@@ -80,7 +80,7 @@ case class Like(str: Expression, pattern: Expression) extends BinaryExpression {
   override private[flink] def resultType: TypeInformation[_] = BOOLEAN_TYPE_INFO
 
-  override private[flink] def validateInput(): ExprValidationResult = {
+  override private[flink] def validateInput(): ValidationResult = {
     if (str.resultType == STRING_TYPE_INFO && pattern.resultType == STRING_TYPE_INFO) {
       ValidationSuccess
     } else {
@@ -102,7 +102,7 @@ case class Like(str: Expression, pattern: Expression) extends BinaryExpression {
 case class Lower(child: Expression) extends UnaryExpression {
   override private[flink] def resultType: TypeInformation[_] = STRING_TYPE_INFO
 
-  override private[flink] def validateInput(): ExprValidationResult = {
+  override private[flink] def validateInput(): ValidationResult = {
     if (child.resultType == STRING_TYPE_INFO) {
       ValidationSuccess
     } else {
@@ -127,7 +127,7 @@ case class Similar(str: Expression, pattern: Expression) extends BinaryExpressio
   override private[flink] def resultType: TypeInformation[_] = BOOLEAN_TYPE_INFO
 
-  override private[flink] def validateInput(): ExprValidationResult = {
+  override private[flink] def validateInput(): ValidationResult = {
     if (str.resultType == STRING_TYPE_INFO && pattern.resultType == STRING_TYPE_INFO) {
       ValidationSuccess
     } else {
@@ -179,7 +179,7 @@ case class Trim(
   override private[flink] def resultType: TypeInformation[_] = STRING_TYPE_INFO
 
-  override private[flink] def validateInput(): ExprValidationResult = {
+  override private[flink] def validateInput(): ValidationResult = {
     trimMode match {
       case SymbolExpression(_: TrimMode) =>
        if (trimString.resultType != STRING_TYPE_INFO) {
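The rename above is mechanical: every validateInput() override keeps its guard logic and only changes its declared result type from ExprValidationResult to ValidationResult. A minimal sketch of the pattern under the new name (Upper is used here purely as an illustration, not a class touched by these hunks):

case class Upper(child: Expression) extends UnaryExpression {
  override private[flink] def resultType: TypeInformation[_] = STRING_TYPE_INFO

  // same contract as CharLength/InitCap above; only the result type is renamed
  override private[flink] def validateInput(): ValidationResult = {
    if (child.resultType == STRING_TYPE_INFO) {
      ValidationSuccess
    } else {
      ValidationFailure(s"Upper operator requires String input, " +
        s"but $child is of type ${child.resultType}")
    }
  }
}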
http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/time.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/time.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/time.scala
index 488fd33..cd5ca0a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/time.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/time.scala
@@ -29,8 +29,8 @@ import org.apache.flink.api.table.FlinkRelBuilder
 import org.apache.flink.api.table.expressions.ExpressionUtils.{divide, getFactor, mod}
 import org.apache.flink.api.table.expressions.TimeIntervalUnit.TimeIntervalUnit
 import org.apache.flink.api.table.typeutils.TypeCheckUtils.isTimeInterval
-import org.apache.flink.api.table.typeutils.{IntervalTypeInfo, TypeCheckUtils}
-import org.apache.flink.api.table.validate.{ExprValidationResult, ValidationFailure, ValidationSuccess}
+import org.apache.flink.api.table.typeutils.{TimeIntervalTypeInfo, TypeCheckUtils}
+import org.apache.flink.api.table.validate.{ValidationResult, ValidationFailure, ValidationSuccess}
 
 import scala.collection.JavaConversions._
 
@@ -40,7 +40,7 @@ case class Extract(timeIntervalUnit: Expression, temporal: Expression) extends E
 
   override private[flink] def resultType: TypeInformation[_] = LONG_TYPE_INFO
 
-  override private[flink] def validateInput(): ExprValidationResult = {
+  override private[flink] def validateInput(): ValidationResult = {
     if (!TypeCheckUtils.isTemporal(temporal.resultType)) {
       return ValidationFailure(s"Extract operator requires Temporal input, " +
         s"but $temporal is of type ${temporal.resultType}")
@@ -52,8 +52,8 @@ case class Extract(timeIntervalUnit: Expression, temporal: Expression) extends E
       | SymbolExpression(TimeIntervalUnit.DAY)
         if temporal.resultType == SqlTimeTypeInfo.DATE
           || temporal.resultType == SqlTimeTypeInfo.TIMESTAMP
-          || temporal.resultType == IntervalTypeInfo.INTERVAL_MILLIS
-          || temporal.resultType == IntervalTypeInfo.INTERVAL_MONTHS =>
+          || temporal.resultType == TimeIntervalTypeInfo.INTERVAL_MILLIS
+          || temporal.resultType == TimeIntervalTypeInfo.INTERVAL_MONTHS =>
         ValidationSuccess
 
       case SymbolExpression(TimeIntervalUnit.HOUR)
@@ -61,7 +61,7 @@ case class Extract(timeIntervalUnit: Expression, temporal: Expression) extends E
       | SymbolExpression(TimeIntervalUnit.MINUTE)
       | SymbolExpression(TimeIntervalUnit.SECOND)
         if temporal.resultType == SqlTimeTypeInfo.TIME
           || temporal.resultType == SqlTimeTypeInfo.TIMESTAMP
-          || temporal.resultType == IntervalTypeInfo.INTERVAL_MILLIS =>
+          || temporal.resultType == TimeIntervalTypeInfo.INTERVAL_MILLIS =>
         ValidationSuccess
 
       case _ =>
@@ -146,7 +146,7 @@ abstract class TemporalCeilFloor(
 
   override private[flink] def resultType: TypeInformation[_] = temporal.resultType
 
-  override private[flink] def validateInput(): ExprValidationResult = {
+  override private[flink] def validateInput(): ValidationResult = {
     if (!TypeCheckUtils.isTimePoint(temporal.resultType)) {
       return ValidationFailure(s"Temporal ceil/floor operator requires Time Point input, " +
         s"but $temporal is of type ${temporal.resultType}")
@@ -211,7 +211,7 @@ abstract class CurrentTimePoint(
 
   override private[flink] def resultType: TypeInformation[_] = targetType
 
-  override private[flink] def validateInput(): ExprValidationResult = {
+  override private[flink] def validateInput(): ValidationResult = {
     if (!TypeCheckUtils.isTimePoint(targetType)) {
       ValidationFailure(s"CurrentTimePoint operator requires Time Point target type, " +
         s"but get $targetType.")
@@ -293,7 +293,7 @@ case class TemporalOverlaps(
 
   override private[flink] def resultType: TypeInformation[_] = BOOLEAN_TYPE_INFO
 
-  override private[flink] def validateInput(): ExprValidationResult = {
+  override private[flink] def validateInput(): ValidationResult = {
     if (!TypeCheckUtils.isTimePoint(leftTimePoint.resultType)) {
       return ValidationFailure(s"TemporalOverlaps operator requires leftTimePoint to be of type " +
         s"Time Point, but get ${leftTimePoint.resultType}.")
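The Extract changes also swap IntervalTypeInfo for TimeIntervalTypeInfo while keeping the unit/type compatibility rules intact. A standalone restatement of the cases visible in the two hunks above (a simplified sketch under the file's imports, not code from this commit):

// Which temporal types support which extraction units after this change.
def unitSupported(unit: TimeIntervalUnit, t: TypeInformation[_]): Boolean = unit match {
  case TimeIntervalUnit.DAY =>
    t == SqlTimeTypeInfo.DATE || t == SqlTimeTypeInfo.TIMESTAMP ||
      t == TimeIntervalTypeInfo.INTERVAL_MILLIS || t == TimeIntervalTypeInfo.INTERVAL_MONTHS
  case TimeIntervalUnit.HOUR | TimeIntervalUnit.MINUTE | TimeIntervalUnit.SECOND =>
    t == SqlTimeTypeInfo.TIME || t == SqlTimeTypeInfo.TIMESTAMP ||
      t == TimeIntervalTypeInfo.INTERVAL_MILLIS
  case _ => false // the YEAR/MONTH alternatives sit above the hunk and are not shown here
}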
http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/windowProperties.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/windowProperties.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/windowProperties.scala
new file mode 100644
index 0000000..8386c46
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/windowProperties.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.expressions
+
+import org.apache.calcite.rex.RexNode
+import org.apache.calcite.tools.RelBuilder
+import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo
+import org.apache.flink.api.table.FlinkRelBuilder.NamedWindowProperty
+import org.apache.flink.api.table.validate.{ValidationFailure, ValidationSuccess}
+
+abstract class WindowProperty(child: Expression) extends UnaryExpression {
+
+  override def toString = s"WindowProperty($child)"
+
+  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode =
+    throw new UnsupportedOperationException("WindowProperty cannot be transformed to RexNode.")
+
+  override private[flink] def validateInput() =
+    if (child.isInstanceOf[WindowReference]) {
+      ValidationSuccess
+    } else {
+      ValidationFailure("Child must be a window reference.")
+    }
+
+  private[flink] def toNamedWindowProperty(name: String)(implicit relBuilder: RelBuilder)
+    : NamedWindowProperty = NamedWindowProperty(name, this)
+}
+
+case class WindowStart(child: Expression) extends WindowProperty(child) {
+
+  override private[flink] def resultType = SqlTimeTypeInfo.TIMESTAMP
+
+  override def toString: String = s"start($child)"
+}
+
+case class WindowEnd(child: Expression) extends WindowProperty(child) {
+
+  override private[flink] def resultType = SqlTimeTypeInfo.TIMESTAMP
+
+  override def toString: String = s"end($child)"
+}
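WindowStart and WindowEnd only pass validation when their child is a WindowReference, which is what ties window properties back to a named window alias. A short illustration built directly from the classes above (validateInput() is private[flink], so this runs within the flink package):

// passes: the property is applied to a window reference
WindowStart(WindowReference("w")).validateInput()
// => ValidationSuccess

// fails: a plain field reference is not a window reference
WindowEnd(UnresolvedFieldReference("user")).validateInput()
// => ValidationFailure("Child must be a window reference.")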
http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
new file mode 100644
index 0000000..2299bd1
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.plan
+
+import org.apache.flink.api.table.TableEnvironment
+import org.apache.flink.api.table.expressions._
+import org.apache.flink.api.table.plan.logical.LogicalNode
+
+import scala.collection.mutable.ListBuffer
+
+object ProjectionTranslator {
+
+  /**
+    * Extracts all aggregation and property expressions (zero, one, or more) from an expression,
+    * and replaces the original expressions by field accesses expressions.
+    */
+  def extractAggregationsAndProperties(
+      exp: Expression,
+      tableEnv: TableEnvironment)
+    : (Expression, List[NamedExpression], List[NamedExpression]) = {
+
+    exp match {
+      case agg: Aggregation =>
+        val name = tableEnv.createUniqueAttributeName()
+        val aggCall = Alias(agg, name)
+        val fieldExp = UnresolvedFieldReference(name)
+        (fieldExp, List(aggCall), Nil)
+      case prop: WindowProperty =>
+        val name = tableEnv.createUniqueAttributeName()
+        val propCall = Alias(prop, name)
+        val fieldExp = UnresolvedFieldReference(name)
+        (fieldExp, Nil, List(propCall))
+      case n @ Alias(agg: Aggregation, name) =>
+        val fieldExp = UnresolvedFieldReference(name)
+        (fieldExp, List(n), Nil)
+      case n @ Alias(prop: WindowProperty, name) =>
+        val fieldExp = UnresolvedFieldReference(name)
+        (fieldExp, Nil, List(n))
+      case l: LeafExpression =>
+        (l, Nil, Nil)
+      case u: UnaryExpression =>
+        val c = extractAggregationsAndProperties(u.child, tableEnv)
+        (u.makeCopy(Array(c._1)), c._2, c._3)
+      case b: BinaryExpression =>
+        val l = extractAggregationsAndProperties(b.left, tableEnv)
+        val r = extractAggregationsAndProperties(b.right, tableEnv)
+        (b.makeCopy(Array(l._1, r._1)),
+          l._2 ::: r._2,
+          l._3 ::: r._3)
+
+      // Functions calls
+      case c @ Call(name, args) =>
+        val newArgs = args.map(extractAggregationsAndProperties(_, tableEnv))
+        (c.makeCopy((name :: newArgs.map(_._1) :: Nil).toArray),
+          newArgs.flatMap(_._2).toList,
+          newArgs.flatMap(_._3).toList)
+
+      case sfc @ ScalarFunctionCall(clazz, args) =>
+        val newArgs = args.map(extractAggregationsAndProperties(_, tableEnv))
+        (sfc.makeCopy((clazz :: newArgs.map(_._1) :: Nil).toArray),
+          newArgs.flatMap(_._2).toList,
+          newArgs.flatMap(_._3).toList)
+
+      // General expression
+      case e: Expression =>
+        val newArgs = e.productIterator.map {
+          case arg: Expression =>
+            extractAggregationsAndProperties(arg, tableEnv)
+        }
+        (e.makeCopy(newArgs.map(_._1).toArray),
+          newArgs.flatMap(_._2).toList,
+          newArgs.flatMap(_._3).toList)
+    }
+  }
+
+  /**
+    * Parses all input expressions to [[UnresolvedAlias]].
+    * And expands star to parent's full project list.
+    */
+  def expandProjectList(exprs: Seq[Expression], parent: LogicalNode): Seq[NamedExpression] = {
+    val projectList = new ListBuffer[NamedExpression]
+    exprs.foreach {
+      case n: UnresolvedFieldReference if n.name == "*" =>
+        projectList ++= parent.output.map(UnresolvedAlias(_))
+      case e: Expression => projectList += UnresolvedAlias(e)
+    }
+    projectList
+  }
+}
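To make the triple returned by extractAggregationsAndProperties concrete, consider an aggregation nested inside arithmetic next to a window property. The trace below is illustrative; the generated names TMP_0 and TMP_1 stand in for whatever createUniqueAttributeName() actually returns:

// input: Plus(Sum('a), Literal(1))
//   replaced expression: Plus(UnresolvedFieldReference("TMP_0"), Literal(1))
//   aggregations:        List(Alias(Sum('a), "TMP_0"))
//   properties:          Nil
//
// input: WindowStart(WindowReference("w"))
//   replaced expression: UnresolvedFieldReference("TMP_1")
//   aggregations:        Nil
//   properties:          List(Alias(WindowStart(WindowReference("w")), "TMP_1"))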
http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/RexNodeTranslator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/RexNodeTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/RexNodeTranslator.scala
deleted file mode 100644
index eb40bba..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/RexNodeTranslator.scala
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.api.table.plan
-
-import org.apache.flink.api.table.TableEnvironment
-import org.apache.flink.api.table.expressions._
-import org.apache.flink.api.table.plan.logical.LogicalNode
-
-import scala.collection.mutable.ListBuffer
-
-object RexNodeTranslator {
-
-  /**
-    * Extracts all aggregation expressions (zero, one, or more) from an expression,
-    * and replaces the original aggregation expressions by field accesses expressions.
-    */
-  def extractAggregations(
-      exp: Expression,
-      tableEnv: TableEnvironment): Pair[Expression, List[NamedExpression]] = {
-
-    exp match {
-      case agg: Aggregation =>
-        val name = tableEnv.createUniqueAttributeName()
-        val aggCall = Alias(agg, name)
-        val fieldExp = UnresolvedFieldReference(name)
-        (fieldExp, List(aggCall))
-      case n @ Alias(agg: Aggregation, name) =>
-        val fieldExp = UnresolvedFieldReference(name)
-        (fieldExp, List(n))
-      case l: LeafExpression =>
-        (l, Nil)
-      case u: UnaryExpression =>
-        val c = extractAggregations(u.child, tableEnv)
-        (u.makeCopy(Array(c._1)), c._2)
-      case b: BinaryExpression =>
-        val l = extractAggregations(b.left, tableEnv)
-        val r = extractAggregations(b.right, tableEnv)
-        (b.makeCopy(Array(l._1, r._1)), l._2 ::: r._2)
-
-      // Functions calls
-      case c @ Call(name, args) =>
-        val newArgs = args.map(extractAggregations(_, tableEnv))
-        (c.makeCopy((name :: newArgs.map(_._1) :: Nil).toArray), newArgs.flatMap(_._2).toList)
-
-      case sfc @ ScalarFunctionCall(clazz, args) =>
-        val newArgs = args.map(extractAggregations(_, tableEnv))
-        (sfc.makeCopy((clazz :: newArgs.map(_._1) :: Nil).toArray), newArgs.flatMap(_._2).toList)
-
-      // General expression
-      case e: Expression =>
-        val newArgs = e.productIterator.map {
-          case arg: Expression =>
-            extractAggregations(arg, tableEnv)
-        }
-        (e.makeCopy(newArgs.map(_._1).toArray), newArgs.flatMap(_._2).toList)
-    }
-  }
-
-  /**
-    * Parses all input expressions to [[UnresolvedAlias]].
-    * And expands star to parent's full project list.
-    */
-  def expandProjectList(exprs: Seq[Expression], parent: LogicalNode): Seq[NamedExpression] = {
-    val projectList = new ListBuffer[NamedExpression]
-    exprs.foreach {
-      case n: UnresolvedFieldReference if n.name == "*" =>
-        projectList ++= parent.output.map(UnresolvedAlias(_))
-      case e: Expression => projectList += UnresolvedAlias(e)
-    }
-    projectList
-  }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/LogicalNode.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/LogicalNode.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/LogicalNode.scala
index dae02bd..55fba07 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/LogicalNode.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/LogicalNode.scala
@@ -19,8 +19,7 @@ package org.apache.flink.api.table.plan.logical
 
 import org.apache.calcite.rel.RelNode
 import org.apache.calcite.tools.RelBuilder
-
-import org.apache.flink.api.table.{TableEnvironment, ValidationException}
+import org.apache.flink.api.table.{StreamTableEnvironment, TableEnvironment, ValidationException}
 import org.apache.flink.api.table.expressions._
 import org.apache.flink.api.table.trees.TreeNode
 import org.apache.flink.api.table.typeutils.TypeCoercion
@@ -54,7 +53,7 @@ abstract class LogicalNode extends TreeNode[LogicalNode] {
     // resolve references and function calls
     val exprResolved = expressionPostOrderTransform {
       case u @ UnresolvedFieldReference(name) =>
-        resolveReference(name).getOrElse(u)
+        resolveReference(tableEnv, name).getOrElse(u)
       case c @ Call(name, children) if c.childrenValid =>
         tableEnv.getFunctionCatalog.lookupFunction(name, children)
     }
@@ -84,7 +83,7 @@ abstract class LogicalNode extends TreeNode[LogicalNode] {
     resolvedNode.expressionPostOrderTransform {
       case a: Attribute if !a.valid =>
         val from = children.flatMap(_.output).map(_.name).mkString(", ")
-        failValidation(s"cannot resolve [${a.name}] given input [$from]")
+        failValidation(s"Cannot resolve [${a.name}] given input [$from].")
 
       case e: Expression if e.validateInput().isFailure =>
         failValidation(s"Expression $e failed on input check: " +
@@ -96,12 +95,12 @@ abstract class LogicalNode extends TreeNode[LogicalNode] {
    * Resolves the given strings to a [[NamedExpression]] using the input from all child
    * nodes of this LogicalPlan.
   */
-  def resolveReference(name: String): Option[NamedExpression] = {
+  def resolveReference(tableEnv: TableEnvironment, name: String): Option[NamedExpression] = {
     val childrenOutput = children.flatMap(_.output)
     val candidates = childrenOutput.filter(_.name.equalsIgnoreCase(name))
     if (candidates.length > 1) {
-      failValidation(s"Reference $name is ambiguous")
-    } else if (candidates.length == 0) {
+      failValidation(s"Reference $name is ambiguous.")
+    } else if (candidates.isEmpty) {
       None
     } else {
       Some(candidates.head.withName(name))
@@ -133,6 +132,7 @@ abstract class LogicalNode extends TreeNode[LogicalNode] {
         case e: Expression => expressionPostOrderTransform(e)
         case other => other
       }
+      case r: Resolvable[_] => r.resolveExpressions(e => expressionPostOrderTransform(e))
      case other: AnyRef => other
    }.toArray
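Threading the TableEnvironment into resolveReference lets subclasses resolve environment-specific names (WindowAggregate below uses it for 'rowtime'), while the base lookup stays a plain search over the children's output. A condensed free-standing sketch of that base behavior, with resolveAgainstChildren as an assumed helper name rather than an API from this commit:

def resolveAgainstChildren(
    children: Seq[LogicalNode],
    name: String): Option[NamedExpression] = {
  val candidates = children.flatMap(_.output).filter(_.name.equalsIgnoreCase(name))
  if (candidates.length > 1) {
    throw new ValidationException(s"Reference $name is ambiguous.")
  } else if (candidates.isEmpty) {
    None // reported later as "Cannot resolve [...] given input [...]."
  } else {
    Some(candidates.head.withName(name))
  }
}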
http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/LogicalWindow.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/LogicalWindow.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/LogicalWindow.scala
new file mode 100644
index 0000000..19fd603
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/LogicalWindow.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.plan.logical
+
+import org.apache.flink.api.table.TableEnvironment
+import org.apache.flink.api.table.expressions.{Expression, WindowReference}
+import org.apache.flink.api.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}
+
+abstract class LogicalWindow(val alias: Option[Expression]) extends Resolvable[LogicalWindow] {
+
+  def resolveExpressions(resolver: (Expression) => Expression): LogicalWindow = this
+
+  def validate(tableEnv: TableEnvironment): ValidationResult = alias match {
+    case Some(WindowReference(_)) => ValidationSuccess
+    case Some(_) => ValidationFailure("Window reference for window expected.")
+    case None => ValidationSuccess
+  }
+
+  override def toString: String = getClass.getSimpleName
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/Resolvable.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/Resolvable.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/Resolvable.scala
new file mode 100644
index 0000000..7540d43
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/Resolvable.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.plan.logical
+
+import org.apache.flink.api.table.expressions.Expression
+
+/**
+  * A class implementing this interface can resolve the expressions of its parameters and
+  * return a new instance with resolved parameters. This is necessary if expression are nested in
+  * a not supported structure. By default, the validation of a logical node can resolve common
+  * structures like `Expression`, `Option[Expression]`, `Traversable[Expression]`.
+  *
+  * See also [[LogicalNode.expressionPostOrderTransform(scala.PartialFunction)]].
+  *
+  * @tparam T class which expression parameters need to be resolved
+  */
+trait Resolvable[T <: AnyRef] {
+
+  /**
+    * An implementing class can resolve its expressions by applying the given resolver
+    * function on its parameters.
+    *
+    * @param resolver function that can resolve an expression
+    * @return class with resolved expression parameters
+    */
+  def resolveExpressions(resolver: (Expression) => Expression): T
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/groupWindows.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/groupWindows.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/groupWindows.scala
new file mode 100644
index 0000000..aeb9676
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/groupWindows.scala
@@ -0,0 +1,258 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.plan.logical
+
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo
+import org.apache.flink.api.table.{BatchTableEnvironment, StreamTableEnvironment, TableEnvironment}
+import org.apache.flink.api.table.expressions._
+import org.apache.flink.api.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo, TypeCoercion}
+import org.apache.flink.api.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}
+
+abstract class EventTimeGroupWindow(
+    name: Option[Expression],
+    time: Expression)
+  extends LogicalWindow(name) {
+
+  override def validate(tableEnv: TableEnvironment): ValidationResult = {
+    val valid = super.validate(tableEnv)
+    if (valid.isFailure) {
+      return valid
+    }
+
+    tableEnv match {
+      case _: StreamTableEnvironment =>
+        time match {
+          case RowtimeAttribute() =>
+            ValidationSuccess
+          case _ =>
+            ValidationFailure("Event-time window expects a 'rowtime' time field.")
+        }
+      case _: BatchTableEnvironment =>
+        if (!TypeCoercion.canCast(time.resultType, BasicTypeInfo.LONG_TYPE_INFO)) {
+          ValidationFailure(s"Event-time window expects a time field that can be safely cast " +
+            s"to Long, but is ${time.resultType}")
+        } else {
+          ValidationSuccess
+        }
+    }
+
+  }
+}
+
+abstract class ProcessingTimeGroupWindow(name: Option[Expression]) extends LogicalWindow(name)
+
+// ------------------------------------------------------------------------------------------------
+// Tumbling group windows
+// ------------------------------------------------------------------------------------------------
+
+object TumblingGroupWindow {
+  def validate(tableEnv: TableEnvironment, size: Expression): ValidationResult = size match {
+    case Literal(_, TimeIntervalTypeInfo.INTERVAL_MILLIS) =>
+      ValidationSuccess
+    case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) =>
+      ValidationSuccess
+    case _ =>
+      ValidationFailure("Tumbling window expects size literal of type Interval of Milliseconds " +
+        "or Interval of Rows.")
+  }
+}
+
+case class ProcessingTimeTumblingGroupWindow(
+    name: Option[Expression],
+    size: Expression)
+  extends ProcessingTimeGroupWindow(name) {
+
+  override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
+    ProcessingTimeTumblingGroupWindow(
+      name.map(resolve),
+      resolve(size))
+
+  override def validate(tableEnv: TableEnvironment): ValidationResult =
+    super.validate(tableEnv).orElse(TumblingGroupWindow.validate(tableEnv, size))
+
+  override def toString: String = s"ProcessingTimeTumblingGroupWindow($name, $size)"
+}
+
+case class EventTimeTumblingGroupWindow(
+    name: Option[Expression],
+    timeField: Expression,
+    size: Expression)
+  extends EventTimeGroupWindow(
+    name,
+    timeField) {
+
+  override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
+    EventTimeTumblingGroupWindow(
+      name.map(resolve),
+      resolve(timeField),
+      resolve(size))
+
+  override def validate(tableEnv: TableEnvironment): ValidationResult =
+    super.validate(tableEnv)
+      .orElse(TumblingGroupWindow.validate(tableEnv, size))
+      .orElse(size match {
+        case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) =>
+          ValidationFailure(
+            "Event-time grouping windows on row intervals are currently not supported.")
+        case _ =>
+          ValidationSuccess
+      })
+
+  override def toString: String = s"EventTimeTumblingGroupWindow($name, $timeField, $size)"
+}
+
+// ------------------------------------------------------------------------------------------------
+// Sliding group windows
+// ------------------------------------------------------------------------------------------------
+
+object SlidingGroupWindow {
+  def validate(
+      tableEnv: TableEnvironment,
+      size: Expression,
+      slide: Expression)
+    : ValidationResult = {
+
+    val checkedSize = size match {
+      case Literal(_, TimeIntervalTypeInfo.INTERVAL_MILLIS) =>
+        ValidationSuccess
+      case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) =>
+        ValidationSuccess
+      case _ =>
+        ValidationFailure("Sliding window expects size literal of type Interval of " +
+          "Milliseconds or Interval of Rows.")
+    }
+
+    val checkedSlide = slide match {
+      case Literal(_, TimeIntervalTypeInfo.INTERVAL_MILLIS) =>
+        ValidationSuccess
+      case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) =>
+        ValidationSuccess
+      case _ =>
+        ValidationFailure("Sliding window expects slide literal of type Interval of " +
+          "Milliseconds or Interval of Rows.")
+    }
+
+    checkedSize
+      .orElse(checkedSlide)
+      .orElse {
+        if (size.resultType != slide.resultType) {
+          ValidationFailure("Sliding window expects same type of size and slide.")
+        } else {
+          ValidationSuccess
+        }
+      }
+  }
+}
+
+case class ProcessingTimeSlidingGroupWindow(
+    name: Option[Expression],
+    size: Expression,
+    slide: Expression)
+  extends ProcessingTimeGroupWindow(name) {
+
+  override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
+    ProcessingTimeSlidingGroupWindow(
+      name.map(resolve),
+      resolve(size),
+      resolve(slide))
+
+  override def validate(tableEnv: TableEnvironment): ValidationResult =
+    super.validate(tableEnv).orElse(SlidingGroupWindow.validate(tableEnv, size, slide))
+
+  override def toString: String = s"ProcessingTimeSlidingGroupWindow($name, $size, $slide)"
+}
+
+case class EventTimeSlidingGroupWindow(
+    name: Option[Expression],
+    timeField: Expression,
+    size: Expression,
+    slide: Expression)
+  extends EventTimeGroupWindow(name, timeField) {
+
+  override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
+    EventTimeSlidingGroupWindow(
+      name.map(resolve),
+      resolve(timeField),
+      resolve(size),
+      resolve(slide))
+
+  override def validate(tableEnv: TableEnvironment): ValidationResult =
+    super.validate(tableEnv)
+      .orElse(SlidingGroupWindow.validate(tableEnv, size, slide))
+      .orElse(size match {
+        case Literal(_, RowIntervalTypeInfo.INTERVAL_ROWS) =>
+          ValidationFailure(
+            "Event-time grouping windows on row intervals are currently not supported.")
+        case _ =>
+          ValidationSuccess
+      })
+
+  override def toString: String = s"EventTimeSlidingGroupWindow($name, $timeField, $size, $slide)"
+}
+
+// ------------------------------------------------------------------------------------------------
+// Session group windows
+// ------------------------------------------------------------------------------------------------
+
+object SessionGroupWindow {
+
+  def validate(tableEnv: TableEnvironment, gap: Expression): ValidationResult = gap match {
+    case Literal(timeInterval: Long, TimeIntervalTypeInfo.INTERVAL_MILLIS) =>
+      ValidationSuccess
+    case _ =>
+      ValidationFailure(
+        "Session window expects gap literal of type Interval of Milliseconds.")
+  }
+}
+
+case class ProcessingTimeSessionGroupWindow(
+    name: Option[Expression],
+    gap: Expression)
+  extends ProcessingTimeGroupWindow(name) {
+
+  override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
+    ProcessingTimeSessionGroupWindow(
+      name.map(resolve),
+      resolve(gap))
+
+  override def validate(tableEnv: TableEnvironment): ValidationResult =
+    super.validate(tableEnv).orElse(SessionGroupWindow.validate(tableEnv, gap))
+
+  override def toString: String = s"ProcessingTimeSessionGroupWindow($name, $gap)"
+}
+
+case class EventTimeSessionGroupWindow(
+    name: Option[Expression],
+    timeField: Expression,
+    gap: Expression)
+  extends EventTimeGroupWindow(
+    name,
+    timeField) {
+
+  override def resolveExpressions(resolve: (Expression) => Expression): LogicalWindow =
+    EventTimeSessionGroupWindow(
+      name.map(resolve),
+      resolve(timeField),
+      resolve(gap))
+
+  override def validate(tableEnv: TableEnvironment): ValidationResult =
+    super.validate(tableEnv).orElse(SessionGroupWindow.validate(tableEnv, gap))
+
+  override def toString: String = s"EventTimeSessionGroupWindow($name, $timeField, $gap)"
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
index 066e9d6..1d7ed5f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/operators.scala
@@ -28,6 +28,7 @@ import org.apache.flink.api.java.operators.join.JoinType
 import org.apache.flink.api.table._
 import org.apache.flink.api.table.expressions._
 import org.apache.flink.api.table.typeutils.TypeConverter
+import org.apache.flink.api.table.validate.{ValidationFailure, ValidationSuccess}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -55,28 +56,27 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalNode) extend
 
   override def validate(tableEnv: TableEnvironment): LogicalNode = {
     val resolvedProject = super.validate(tableEnv).asInstanceOf[Project]
+    val names: mutable.Set[String] = mutable.Set()
 
-    def checkUniqueNames(exprs: Seq[Expression]): Unit = {
-      val names: mutable.Set[String] = mutable.Set()
-      exprs.foreach {
-        case n: Alias =>
-          // explicit name
-          if (names.contains(n.name)) {
-            throw ValidationException(s"Duplicate field name ${n.name}.")
-          } else {
-            names.add(n.name)
-          }
-        case r: ResolvedFieldReference =>
-          // simple field forwarding
-          if (names.contains(r.name)) {
-            throw ValidationException(s"Duplicate field name ${r.name}.")
-          } else {
-            names.add(r.name)
-          }
-        case _ => // Do nothing
+    def checkName(name: String): Unit = {
+      if (names.contains(name)) {
+        failValidation(s"Duplicate field name $name.")
+      } else if (tableEnv.isInstanceOf[StreamTableEnvironment] && name == "rowtime") {
+        failValidation("'rowtime' cannot be used as field name in a streaming environment.")
+      } else {
+        names.add(name)
       }
     }
-    checkUniqueNames(resolvedProject.projectList)
+
+    resolvedProject.projectList.foreach {
+      case n: Alias =>
+        // explicit name
+        checkName(n.name)
+      case r: ResolvedFieldReference =>
+        // simple field forwarding
+        checkName(r.name)
+      case _ => // Do nothing
+    }
 
     resolvedProject
   }
@@ -112,6 +112,10 @@ case class AliasNode(aliasList: Seq[Expression], child: LogicalNode) extends Una
       failValidation("Alias only accept name expressions as arguments")
     } else if (!aliasList.forall(_.asInstanceOf[UnresolvedFieldReference].name != "*")) {
       failValidation("Alias can not accept '*' as name")
+    } else if (tableEnv.isInstanceOf[StreamTableEnvironment] && !aliasList.forall {
+      case UnresolvedFieldReference(name) => name != "rowtime"
+    }) {
+      failValidation("'rowtime' cannot be used as field name in a streaming environment.")
     } else {
       val names = aliasList.map(_.asInstanceOf[UnresolvedFieldReference].name)
       val input = child.output
@@ -498,3 +502,105 @@ case class LogicalRelNode(
 
   override def validate(tableEnv: TableEnvironment): LogicalNode = this
 }
+
+case class WindowAggregate(
+    groupingExpressions: Seq[Expression],
+    window: LogicalWindow,
+    propertyExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[NamedExpression],
+    child: LogicalNode)
+  extends UnaryNode {
+
+  override def output: Seq[Attribute] = {
+    (groupingExpressions ++ aggregateExpressions ++ propertyExpressions) map {
+      case ne: NamedExpression => ne.toAttribute
+      case e => Alias(e, e.toString).toAttribute
+    }
+  }
+
+  // resolve references of this operator's parameters
+  override def resolveReference(
+      tableEnv: TableEnvironment,
+      name: String)
+    : Option[NamedExpression] = tableEnv match {
+    // resolve reference to rowtime attribute in a streaming environment
+    case _: StreamTableEnvironment if name == "rowtime" =>
+      Some(RowtimeAttribute())
+    case _ =>
+      window.alias match {
+        // resolve reference to this window's alias
+        case Some(UnresolvedFieldReference(alias)) if name == alias =>
+          // check if reference can already be resolved by input fields
+          val found = super.resolveReference(tableEnv, name)
+          if (found.isDefined) {
+            failValidation(s"Reference $name is ambiguous.")
+          } else {
+            Some(WindowReference(name))
+          }
+        case _ =>
+          // resolve references as usual
+          super.resolveReference(tableEnv, name)
+      }
+  }
+
+  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
+    val flinkRelBuilder = relBuilder.asInstanceOf[FlinkRelBuilder]
+    child.construct(flinkRelBuilder)
+    flinkRelBuilder.aggregate(
+      window,
+      relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava),
+      propertyExpressions.map {
+        case Alias(prop: WindowProperty, name) => prop.toNamedWindowProperty(name)(relBuilder)
+        case _ => throw new RuntimeException("This should never happen.")
+      },
+      aggregateExpressions.map {
+        case Alias(agg: Aggregation, name) => agg.toAggCall(name)(relBuilder)
+        case _ => throw new RuntimeException("This should never happen.")
+      }.asJava)
+  }
+
+  override def validate(tableEnv: TableEnvironment): LogicalNode = {
+    val resolvedWindowAggregate = super.validate(tableEnv).asInstanceOf[WindowAggregate]
+    val groupingExprs = resolvedWindowAggregate.groupingExpressions
+    val aggregateExprs = resolvedWindowAggregate.aggregateExpressions
+    aggregateExprs.foreach(validateAggregateExpression)
+    groupingExprs.foreach(validateGroupingExpression)
+
+    def validateAggregateExpression(expr: Expression): Unit = expr match {
+      // check no nested aggregation exists.
+      case aggExpr: Aggregation =>
+        aggExpr.children.foreach { child =>
+          child.preOrderVisit {
+            case agg: Aggregation =>
+              failValidation(
+                "It's not allowed to use an aggregate function as " +
+                  "input of another aggregate function")
+            case _ => // ok
+          }
+        }
+      case a: Attribute if !groupingExprs.exists(_.checkEquals(a)) =>
+        failValidation(
+          s"Expression '$a' is invalid because it is neither" +
+            " present in group by nor an aggregate function")
+      case e if groupingExprs.exists(_.checkEquals(e)) => // ok
+      case e => e.children.foreach(validateAggregateExpression)
+    }
+
+    def validateGroupingExpression(expr: Expression): Unit = {
+      if (!expr.resultType.isKeyType) {
+        failValidation(
+          s"Expression $expr cannot be used as a grouping expression " +
+            "because it's not a valid key type which must be hashable and comparable")
+      }
+    }
+
+    // validate window
+    resolvedWindowAggregate.window.validate(tableEnv) match {
+      case ValidationFailure(msg) =>
+        failValidation(s"$window is invalid: $msg")
+      case ValidationSuccess => // ok
+    }
+
+    resolvedWindowAggregate
+  }
+}
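Taken together, the group windows and the new WindowAggregate node back a Table API call chain roughly like the following. This is a hedged usage sketch: the DSL entry points (Tumble, over, on, as) live in parts of this change not shown in this excerpt, so the exact call shape may differ:

val result = table
  .groupBy('key)
  .window(Tumble over 10.minutes on 'rowtime as 'w) // becomes an EventTimeTumblingGroupWindow
  .select('key, 'value.sum, 'w.start, 'w.end)       // 'w resolves to WindowReference("w")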
http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/rel/LogicalWindowAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/rel/LogicalWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/rel/LogicalWindowAggregate.scala
new file mode 100644
index 0000000..9615168
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/logical/rel/LogicalWindowAggregate.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.plan.logical.rel
+
+import java.util
+
+import org.apache.calcite.plan.{Convention, RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.{Aggregate, AggregateCall}
+import org.apache.calcite.rel.{RelNode, RelShuttle}
+import org.apache.calcite.util.ImmutableBitSet
+import org.apache.flink.api.table.FlinkRelBuilder.NamedWindowProperty
+import org.apache.flink.api.table.FlinkTypeFactory
+import org.apache.flink.api.table.plan.logical.LogicalWindow
+
+class LogicalWindowAggregate(
+    window: LogicalWindow,
+    namedProperties: Seq[NamedWindowProperty],
+    cluster: RelOptCluster,
+    traitSet: RelTraitSet,
+    child: RelNode,
+    indicator: Boolean,
+    groupSet: ImmutableBitSet,
+    groupSets: util.List[ImmutableBitSet],
+    aggCalls: util.List[AggregateCall])
+  extends Aggregate(
+    cluster,
+    traitSet,
+    child,
+    indicator,
+    groupSet,
+    groupSets,
+    aggCalls) {
+
+  def getWindow = window
+
+  def getNamedProperties = namedProperties
+
+  override def copy(
+      traitSet: RelTraitSet,
+      input: RelNode,
+      indicator: Boolean,
+      groupSet: ImmutableBitSet,
+      groupSets: util.List[ImmutableBitSet],
+      aggCalls: util.List[AggregateCall])
+    : Aggregate = {
+
+    new LogicalWindowAggregate(
+      window,
+      namedProperties,
+      cluster,
+      traitSet,
+      input,
+      indicator,
+      groupSet,
+      groupSets,
+      aggCalls)
+  }
+
+  override def accept(shuttle: RelShuttle): RelNode = shuttle.visit(this)
+
+  override def deriveRowType(): RelDataType = {
+    val aggregateRowType = super.deriveRowType()
+    val typeFactory = getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+    val builder = typeFactory.builder
+    builder.addAll(aggregateRowType.getFieldList)
+    namedProperties.foreach { namedProp =>
+      builder.add(
+        namedProp.name,
+        typeFactory.createTypeFromTypeInfo(namedProp.property.resultType)
+      )
+    }
+    builder.build()
+  }
+}
+
+object LogicalWindowAggregate {
+
+  def create(
+      window: LogicalWindow,
+      namedProperties: Seq[NamedWindowProperty],
+      aggregate: Aggregate)
+    : LogicalWindowAggregate = {
+
+    val cluster: RelOptCluster = aggregate.getCluster
+    val traitSet: RelTraitSet = cluster.traitSetOf(Convention.NONE)
+    new LogicalWindowAggregate(
+      window,
+      namedProperties,
+      cluster,
+      traitSet,
+      aggregate.getInput,
+      aggregate.indicator,
+      aggregate.getGroupSet,
+      aggregate.getGroupSets,
+      aggregate.getAggCallList)
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkAggregate.scala
new file mode 100644
index 0000000..85129c4
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkAggregate.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.plan.nodes
+
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.AggregateCall
+import org.apache.flink.api.table.FlinkRelBuilder.NamedWindowProperty
+import org.apache.flink.api.table.runtime.aggregate.AggregateUtil._
+
+import scala.collection.JavaConverters._
+
+trait FlinkAggregate {
+
+  private[flink] def groupingToString(inputType: RelDataType, grouping: Array[Int]): String = {
+
+    val inFields = inputType.getFieldNames.asScala
+    grouping.map( inFields(_) ).mkString(", ")
+  }
+
+  private[flink] def aggregationToString(
+      inputType: RelDataType,
+      grouping: Array[Int],
+      rowType: RelDataType,
+      namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+      namedProperties: Seq[NamedWindowProperty])
+    : String = {
+
+    val inFields = inputType.getFieldNames.asScala
+    val outFields = rowType.getFieldNames.asScala
+
+    val groupStrings = grouping.map( inFields(_) )
+
+    val aggs = namedAggregates.map(_.getKey)
+    val aggStrings = aggs.map( a => s"${a.getAggregation}(${
+      if (a.getArgList.size() > 0) {
+        inFields(a.getArgList.get(0))
+      } else {
+        "*"
+      }
+    })")
+
+    val propStrings = namedProperties.map(_.property.toString)
+
+    (groupStrings ++ aggStrings ++ propStrings).zip(outFields).map {
+      case (f, o) => if (f == o) {
+        f
+      } else {
+        s"$f AS $o"
+      }
+    }.mkString(", ")
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkRel.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkRel.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkRel.scala
index dad50a3..a4c7589 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkRel.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/FlinkRel.scala
@@ -19,6 +19,12 @@ package org.apache.flink.api.table.plan.nodes
 
 import org.apache.calcite.rex._
+import org.apache.flink.api.common.functions.MapFunction
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.TableConfig
+import org.apache.flink.api.table.codegen.CodeGenerator
+import org.apache.flink.api.table.runtime.MapRunner
+
 import scala.collection.JavaConversions._
 
 trait FlinkRel {
@@ -44,4 +50,41 @@ trait FlinkRel {
       case _ =>
         throw new IllegalArgumentException("Unknown expression type: " + expr)
     }
   }
+
+  private[flink] def getConversionMapper(
+      config: TableConfig,
+      nullableInput: Boolean,
+      inputType: TypeInformation[Any],
+      expectedType: TypeInformation[Any],
+      conversionOperatorName: String,
+      fieldNames: Seq[String],
+      inputPojoFieldMapping: Option[Array[Int]] = None)
+    : MapFunction[Any, Any] = {
+
+    val generator = new CodeGenerator(
+      config,
+      nullableInput,
+      inputType,
+      None,
+      inputPojoFieldMapping)
+    val conversion = generator.generateConverterResultExpression(expectedType, fieldNames)
+
+    val body =
+      s"""
+        |${conversion.code}
+        |return ${conversion.resultTerm};
+        |""".stripMargin
+
+    val genFunction = generator.generateFunction(
+      conversionOperatorName,
+      classOf[MapFunction[Any, Any]],
+      body,
+      expectedType)
+
+    new MapRunner[Any, Any](
+      genFunction.name,
+      genFunction.code,
+      genFunction.returnType)
+
+  }
 }
http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetAggregate.scala
index c826d83..c73d781 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetAggregate.scala
@@ -25,10 +25,11 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery
 import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.DataSet
+import org.apache.flink.api.table.plan.nodes.FlinkAggregate
 import org.apache.flink.api.table.runtime.aggregate.AggregateUtil
 import org.apache.flink.api.table.runtime.aggregate.AggregateUtil.CalcitePair
 import org.apache.flink.api.table.typeutils.{RowTypeInfo, TypeConverter}
-import org.apache.flink.api.table.{FlinkTypeFactory, BatchTableEnvironment, Row}
+import org.apache.flink.api.table.{BatchTableEnvironment, FlinkTypeFactory, Row}
 
 import scala.collection.JavaConverters._
 
@@ -38,12 +39,13 @@ import scala.collection.JavaConverters._
 class DataSetAggregate(
     cluster: RelOptCluster,
     traitSet: RelTraitSet,
-    input: RelNode,
+    inputNode: RelNode,
     namedAggregates: Seq[CalcitePair[AggregateCall, String]],
     rowRelDataType: RelDataType,
    inputType: RelDataType,
    grouping: Array[Int])
-  extends SingleRel(cluster, traitSet, input)
+  extends SingleRel(cluster, traitSet, inputNode)
+  with FlinkAggregate
   with DataSetRel {
 
   override def deriveRowType() = rowRelDataType
@@ -61,16 +63,16 @@ class DataSetAggregate(
   override def toString: String = {
     s"Aggregate(${
       if (!grouping.isEmpty) {
-        s"groupBy: ($groupingToString), "
+        s"groupBy: (${groupingToString(inputType, grouping)}), "
       } else {
        ""
-      }}}select:($aggregationToString))"
+      }}select: (${aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil)}))"
   }
 
   override def explainTerms(pw: RelWriter): RelWriter = {
     super.explainTerms(pw)
-      .itemIf("groupBy",groupingToString, !grouping.isEmpty)
-      .item("select", aggregationToString)
+      .itemIf("groupBy", groupingToString(inputType, grouping), !grouping.isEmpty)
+      .item("select", aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil))
   }
 
   override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
@@ -90,8 +92,11 @@ class DataSetAggregate(
     val groupingKeys = grouping.indices.toArray
     // add grouping fields, position keys in the input, and input type
-    val aggregateResult = AggregateUtil.createOperatorFunctionsForAggregates(namedAggregates,
-      inputType, getRowType, grouping, config)
+    val aggregateResult = AggregateUtil.createOperatorFunctionsForAggregates(
+      namedAggregates,
+      inputType,
+      getRowType,
+      grouping)
 
     val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(
       tableEnv,
@@ -103,7 +108,7 @@ class DataSetAggregate(
       .map(field => FlinkTypeFactory.toTypeInfo(field.getType))
       .toArray
 
-    val aggString = aggregationToString
+    val aggString = aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil)
     val prepareOpName = s"prepare select: ($aggString)"
     val mappedInput = inputDS
       .map(aggregateResult._1)
@@ -115,7 +120,8 @@ class DataSetAggregate(
     val result = {
       if (groupingKeys.length > 0) {
         // grouped aggregation
-        val aggOpName = s"groupBy: ($groupingToString), select:($aggString)"
+        val aggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " +
+          s"select: ($aggString)"
 
         mappedInput.asInstanceOf[DataSet[Row]]
           .groupBy(groupingKeys: _*)
@@ -151,36 +157,4 @@ class DataSetAggregate(
       case _ => result
     }
   }
-
-  private def groupingToString: String = {
-
-    val inFields = inputType.getFieldNames.asScala
-    grouping.map( inFields(_) ).mkString(", ")
-  }
-
-  private def aggregationToString: String = {
-
-    val inFields = inputType.getFieldNames.asScala
-    val outFields = getRowType.getFieldNames.asScala
-
-    val groupStrings = grouping.map( inFields(_) )
-
-    val aggs = namedAggregates.map(_.getKey)
-    val aggStrings = aggs.map( a => s"${a.getAggregation}(${
-      if (a.getArgList.size() > 0) {
-        inFields(a.getArgList.get(0))
-      } else {
-        "*"
-      }
-    })")
-
-    (groupStrings ++ aggStrings).zip(outFields).map {
-      case (f, o) => if (f == o) {
-        f
-      } else {
-        s"$f AS $o"
-      }
-    }.mkString(", ")
-  }
-
 }
http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetRel.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetRel.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetRel.scala
index 39532f0..82c75e1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetRel.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetRel.scala
@@ -71,40 +71,4 @@ trait DataSetRel extends RelNode with FlinkRel {
 
   }
 
-  private[dataset] def getConversionMapper(
-      config: TableConfig,
-      nullableInput: Boolean,
-      inputType: TypeInformation[Any],
-      expectedType: TypeInformation[Any],
-      conversionOperatorName: String,
-      fieldNames: Seq[String],
-      inputPojoFieldMapping: Option[Array[Int]] = None): MapFunction[Any, Any] = {
-
-    val generator = new CodeGenerator(
-      config,
-      nullableInput,
-      inputType,
-      None,
-      inputPojoFieldMapping)
-    val conversion = generator.generateConverterResultExpression(expectedType, fieldNames)
-
-    val body =
-      s"""
-        |${conversion.code}
-        |return ${conversion.resultTerm};
-        |""".stripMargin
-
-    val genFunction = generator.generateFunction(
-      conversionOperatorName,
-      classOf[MapFunction[Any, Any]],
-      body,
-      expectedType)
-
-    new MapRunner[Any, Any](
-      genFunction.name,
-      genFunction.code,
-      genFunction.returnType)
-
-  }
-
 }
http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamAggregate.scala
new file mode 100644
index 0000000..b9b4561
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/datastream/DataStreamAggregate.scala
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.plan.nodes.datastream
+
+import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.AggregateCall
+import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.tuple.Tuple
+import org.apache.flink.api.table.FlinkRelBuilder.NamedWindowProperty
+import org.apache.flink.api.table.expressions._
+import org.apache.flink.api.table.plan.logical._
+import org.apache.flink.api.table.plan.nodes.FlinkAggregate
+import org.apache.flink.api.table.plan.nodes.datastream.DataStreamAggregate.{createKeyedWindowedStream, createNonKeyedWindowedStream, transformToPropertyReads}
+import org.apache.flink.api.table.runtime.aggregate.AggregateUtil._
+import org.apache.flink.api.table.runtime.aggregate._
+import org.apache.flink.api.table.typeutils.TypeCheckUtils.isTimeInterval
+import org.apache.flink.api.table.typeutils.{RowIntervalTypeInfo, RowTypeInfo, TimeIntervalTypeInfo, TypeConverter}
+import org.apache.flink.api.table.{FlinkTypeFactory, Row, StreamTableEnvironment}
+import org.apache.flink.streaming.api.datastream.{AllWindowedStream, DataStream, KeyedStream, WindowedStream}
+import org.apache.flink.streaming.api.windowing.assigners._
+import org.apache.flink.streaming.api.windowing.time.Time
+import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow}
+
+import scala.collection.JavaConverters._
+
+class DataStreamAggregate(
+    window: LogicalWindow,
+    namedProperties: Seq[NamedWindowProperty],
+    cluster: RelOptCluster,
+    traitSet: RelTraitSet,
+    inputNode: RelNode,
+    namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+    rowRelDataType: RelDataType,
+    inputType: RelDataType,
+    grouping: Array[Int])
+  extends SingleRel(cluster, traitSet, inputNode)
+  with FlinkAggregate
+  with DataStreamRel {
+
+  override def deriveRowType() = rowRelDataType
+
+  override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
+    new DataStreamAggregate(
+      window,
+      namedProperties,
+      cluster,
+      traitSet,
+      inputs.get(0),
+      namedAggregates,
+      getRowType,
+      inputType,
+      grouping)
+  }
+
+  override def toString: String = {
+    s"Aggregate(${
+      if (!grouping.isEmpty) {
+        s"groupBy: (${groupingToString(inputType, grouping)}), "
+      } else {
+        ""
+      }
+    }window: ($window), " +
+      s"select: (${
+        aggregationToString(
+          inputType,
+          grouping,
+          getRowType,
+          namedAggregates,
+          namedProperties)
+      }))"
+  }
+
+  override def explainTerms(pw: RelWriter): RelWriter = {
+    super.explainTerms(pw)
+      .itemIf("groupBy", groupingToString(inputType, grouping), !grouping.isEmpty)
+      .item("window", window)
+      .item("select", aggregationToString(
+        inputType,
+        grouping,
+        getRowType,
+        namedAggregates,
+        namedProperties))
+  }
+
+  override def translateToPlan(
+      tableEnv: StreamTableEnvironment,
+      expectedType: Option[TypeInformation[Any]])
+    : DataStream[Any] = {
+
+    val config = tableEnv.getConfig
+
+    val groupingKeys = grouping.indices.toArray
+    // add grouping fields, position keys in the input, and input type
+    val aggregateResult = AggregateUtil.createOperatorFunctionsForAggregates(
+      namedAggregates,
+      inputType,
+      getRowType,
+      grouping)
+
+    val propertyReads = transformToPropertyReads(namedProperties.map(_.property))
+
+    val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(
+      tableEnv,
+      // tell the input operator that this operator currently only supports Rows as input
+      Some(TypeConverter.DEFAULT_ROW_TYPE))
+
+    // get the output types
+    val fieldTypes: Array[TypeInformation[_]] = getRowType.getFieldList.asScala
+      .map(field => FlinkTypeFactory.toTypeInfo(field.getType))
+      .toArray
+
+    val aggString = aggregationToString(
+      inputType,
+      grouping,
+      getRowType,
+      namedAggregates,
+      namedProperties)
+
+    val prepareOpName = s"prepare select: ($aggString)"
+    val mappedInput = inputDS
+      .map(aggregateResult._1)
+      .name(prepareOpName)
+
+    val groupReduceFunction = aggregateResult._2
+    val rowTypeInfo = new RowTypeInfo(fieldTypes)
+
+    val result = {
+      // grouped / keyed aggregation
+      if (groupingKeys.length > 0) {
+        val aggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " +
+          s"window: ($window), " +
+          s"select: ($aggString)"
+        val aggregateFunction = new AggregateWindowFunction(propertyReads, groupReduceFunction)
+
+        val keyedStream = mappedInput.keyBy(groupingKeys: _*)
+
+        val windowedStream = createKeyedWindowedStream(window, keyedStream)
+          .asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]]
+
+        windowedStream
+          .apply(aggregateFunction)
+          .returns(rowTypeInfo)
+          .name(aggOpName)
+          .asInstanceOf[DataStream[Any]]
+      }
+      // global / non-keyed aggregation
+      else {
+        val aggOpName = s"window: ($window), select: ($aggString)"
+        val aggregateFunction = new AggregateAllWindowFunction(propertyReads, groupReduceFunction)
+
+        val windowedStream = createNonKeyedWindowedStream(window, mappedInput)
+          .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]]
+
+        windowedStream
+          .apply(aggregateFunction)
+          .returns(rowTypeInfo)
+          .name(aggOpName)
+          .asInstanceOf[DataStream[Any]]
+      }
+    }
+
+    // if the expected type is not a Row, inject a mapper to convert to the expected type
+    expectedType match {
+      case Some(typeInfo) if typeInfo.getTypeClass != classOf[Row] =>
+        val mapName = s"convert: (${getRowType.getFieldNames.asScala.toList.mkString(", ")})"
+        result.map(getConversionMapper(
+          config = config,
+          nullableInput = false,
+          inputType = rowTypeInfo.asInstanceOf[TypeInformation[Any]],
+          expectedType = expectedType.get,
+          conversionOperatorName = "DataStreamAggregateConversion",
+          fieldNames = getRowType.getFieldNames.asScala
+        ))
+        .name(mapName)
+      case _ => result
+    }
+  }
+}
+
+object DataStreamAggregate {
+
+  private def transformToPropertyReads(namedProperties: Seq[WindowProperty])
+    : Array[WindowPropertyRead[_ <: Any]] =
namedProperties.map { + case WindowStart(_) => new WindowStartRead() + case WindowEnd(_) => new WindowEndRead() + }.toArray + + private def createKeyedWindowedStream(groupWindow: LogicalWindow, stream: KeyedStream[Row, Tuple]) + : WindowedStream[Row, Tuple, _ <: DataStreamWindow] = groupWindow match { + + case ProcessingTimeTumblingGroupWindow(_, size) if isTimeInterval(size.resultType) => + stream.window(TumblingProcessingTimeWindows.of(asTime(size))) + + case ProcessingTimeTumblingGroupWindow(_, size) => + stream.countWindow(asCount(size)) + + case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => + stream.window(TumblingEventTimeWindows.of(asTime(size))) + + case EventTimeTumblingGroupWindow(_, _, size) => + // TODO: EventTimeTumblingGroupWindow should sort the stream on event time + // before applying the windowing logic. Otherwise, this would be the same as a + // ProcessingTimeTumblingGroupWindow + throw new UnsupportedOperationException("Event-time grouping windows on row intervals are " + + "currently not supported.") + + case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) => + stream.window(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide))) + + case ProcessingTimeSlidingGroupWindow(_, size, slide) => + stream.countWindow(asCount(size), asCount(slide)) + + case EventTimeSlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) => + stream.window(SlidingEventTimeWindows.of(asTime(size), asTime(slide))) + + case EventTimeSlidingGroupWindow(_, _, size, slide) => + // TODO: EventTimeTumblingGroupWindow should sort the stream on event time + // before applying the windowing logic. Otherwise, this would be the same as a + // ProcessingTimeTumblingGroupWindow + throw new UnsupportedOperationException("Event-time grouping windows on row intervals are " + + "currently not supported.") + + case ProcessingTimeSessionGroupWindow(_, gap: Expression) => + stream.window(ProcessingTimeSessionWindows.withGap(asTime(gap))) + + case EventTimeSessionGroupWindow(_, _, gap) => + stream.window(EventTimeSessionWindows.withGap(asTime(gap))) + } + + private def createNonKeyedWindowedStream(groupWindow: LogicalWindow, stream: DataStream[Row]) + : AllWindowedStream[Row, _ <: DataStreamWindow] = groupWindow match { + + case ProcessingTimeTumblingGroupWindow(_, size) if isTimeInterval(size.resultType) => + stream.windowAll(TumblingProcessingTimeWindows.of(asTime(size))) + + case ProcessingTimeTumblingGroupWindow(_, size) => + stream.countWindowAll(asCount(size)) + + case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => + stream.windowAll(TumblingEventTimeWindows.of(asTime(size))) + + case EventTimeTumblingGroupWindow(_, _, size) => + // TODO: EventTimeTumblingGroupWindow should sort the stream on event time + // before applying the windowing logic. 
Otherwise, this would be the same as a + // ProcessingTimeTumblingGroupWindow + throw new UnsupportedOperationException("Event-time grouping windows on row intervals are " + + "currently not supported.") + + case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) => + stream.windowAll(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide))) + + case ProcessingTimeSlidingGroupWindow(_, size, slide) => + stream.countWindowAll(asCount(size), asCount(slide)) + + case EventTimeSlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) => + stream.windowAll(SlidingEventTimeWindows.of(asTime(size), asTime(slide))) + + case EventTimeSlidingGroupWindow(_, _, size, slide) => + // TODO: EventTimeTumblingGroupWindow should sort the stream on event time + // before applying the windowing logic. Otherwise, this would be the same as a + // ProcessingTimeTumblingGroupWindow + throw new UnsupportedOperationException("Event-time grouping windows on row intervals are " + + "currently not supported.") + + case ProcessingTimeSessionGroupWindow(_, gap) => + stream.windowAll(ProcessingTimeSessionWindows.withGap(asTime(gap))) + + case EventTimeSessionGroupWindow(_, _, gap) => + stream.windowAll(EventTimeSessionWindows.withGap(asTime(gap))) + } + + def asTime(expr: Expression): Time = expr match { + case Literal(value: Long, TimeIntervalTypeInfo.INTERVAL_MILLIS) => Time.milliseconds(value) + case _ => throw new IllegalArgumentException() + } + + def asCount(expr: Expression): Long = expr match { + case Literal(value: Long, RowIntervalTypeInfo.INTERVAL_ROWS) => value + case _ => throw new IllegalArgumentException() + } +} + http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala index 3ed4385..638deac 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala @@ -146,6 +146,7 @@ object FlinkRuleSets { UnionEliminatorRule.INSTANCE, // translate to DataStream nodes + DataStreamAggregateRule.INSTANCE, DataStreamCalcRule.INSTANCE, DataStreamScanRule.INSTANCE, DataStreamUnionRule.INSTANCE, http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala index 9f78adb..72ed27e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala @@ -40,13 +40,13 @@ class DataSetAggregateRule // check if we have distinct aggregates val distinctAggs = agg.getAggCallList.exists(_.isDistinct) if (distinctAggs) { - throw new 
TableException("DISTINCT aggregates are currently not supported.") + throw TableException("DISTINCT aggregates are currently not supported.") } // check if we have grouping sets val groupSets = agg.getGroupSets.size() != 1 || agg.getGroupSets.get(0) != agg.getGroupSet if (groupSets || agg.indicator) { - throw new TableException("GROUPING SETS are currently not supported.") + throw TableException("GROUPING SETS are currently not supported.") } !distinctAggs && !groupSets && !agg.indicator http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamAggregateRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamAggregateRule.scala new file mode 100644 index 0000000..dff2adc --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/datastream/DataStreamAggregateRule.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.flink.api.table.plan.rules.datastream
+
+import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall, RelTraitSet}
+import org.apache.calcite.rel.RelNode
+import org.apache.calcite.rel.convert.ConverterRule
+import org.apache.flink.api.table.TableException
+import org.apache.flink.api.table.expressions.Alias
+import org.apache.flink.api.table.plan.logical.rel.LogicalWindowAggregate
+import org.apache.flink.api.table.plan.nodes.datastream.{DataStreamAggregate, DataStreamConvention}
+
+import scala.collection.JavaConversions._
+
+class DataStreamAggregateRule
+  extends ConverterRule(
+    classOf[LogicalWindowAggregate],
+    Convention.NONE,
+    DataStreamConvention.INSTANCE,
+    "DataStreamAggregateRule")
+  {
+
+  override def matches(call: RelOptRuleCall): Boolean = {
+    val agg: LogicalWindowAggregate = call.rel(0).asInstanceOf[LogicalWindowAggregate]
+
+    // check if we have distinct aggregates
+    val distinctAggs = agg.getAggCallList.exists(_.isDistinct)
+    if (distinctAggs) {
+      throw TableException("DISTINCT aggregates are currently not supported.")
+    }
+
+    // check if we have grouping sets
+    val groupSets = agg.getGroupSets.size() != 1 || agg.getGroupSets.get(0) != agg.getGroupSet
+    if (groupSets || agg.indicator) {
+      throw TableException("GROUPING SETS are currently not supported.")
+    }
+
+    !distinctAggs && !groupSets && !agg.indicator
+  }
+
+  override def convert(rel: RelNode): RelNode = {
+    val agg: LogicalWindowAggregate = rel.asInstanceOf[LogicalWindowAggregate]
+    val traitSet: RelTraitSet = rel.getTraitSet.replace(DataStreamConvention.INSTANCE)
+    val convInput: RelNode = RelOptRule.convert(agg.getInput, DataStreamConvention.INSTANCE)
+
+    new DataStreamAggregate(
+      agg.getWindow,
+      agg.getNamedProperties,
+      rel.getCluster,
+      traitSet,
+      convInput,
+      agg.getNamedAggCalls,
+      rel.getRowType,
+      agg.getInput.getRowType,
+      agg.getGroupSet.toArray)
+  }
+  }
+
+object DataStreamAggregateRule {
+  val INSTANCE: RelOptRule = new DataStreamAggregateRule
+}
+

http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateAllWindowFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateAllWindowFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateAllWindowFunction.scala
new file mode 100644
index 0000000..86f8a20
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateAllWindowFunction.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.runtime.aggregate
+
+import java.lang.Iterable
+
+import org.apache.flink.api.common.functions.RichGroupReduceFunction
+import org.apache.flink.api.table.Row
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.functions.windowing.RichAllWindowFunction
+import org.apache.flink.streaming.api.windowing.windows.Window
+import org.apache.flink.util.Collector
+
+class AggregateAllWindowFunction(
+    propertyReads: Array[WindowPropertyRead[_ <: Any]],
+    groupReduceFunction: RichGroupReduceFunction[Row, Row])
+  extends RichAllWindowFunction[Row, Row, Window] {
+
+  private var propertyCollector: PropertyCollector = _
+
+  override def open(parameters: Configuration): Unit = {
+    groupReduceFunction.open(parameters)
+    propertyCollector = new PropertyCollector(propertyReads)
+  }
+
+  override def apply(window: Window, input: Iterable[Row], out: Collector[Row]): Unit = {
+
+    // extract the properties from window
+    propertyReads.foreach(_.extract(window))
+
+    // set final collector
+    propertyCollector.finalCollector = out
+
+    // call wrapped reduce function with property collector
+    groupReduceFunction.reduce(input, propertyCollector)
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/44f3977e/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateReduceCombineFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateReduceCombineFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateReduceCombineFunction.scala
new file mode 100644
index 0000000..ca074cc
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateReduceCombineFunction.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.runtime.aggregate
+
+import java.lang.Iterable
+
+import org.apache.flink.api.common.functions.{CombineFunction, RichGroupReduceFunction}
+import org.apache.flink.api.table.Row
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.util.{Collector, Preconditions}
+
+import scala.collection.JavaConversions._
+
+
+/**
+ * It wraps the aggregate logic inside of
+ * [[org.apache.flink.api.java.operators.GroupReduceOperator]] and
+ * [[org.apache.flink.api.java.operators.GroupCombineOperator]]
+ *
+ * @param aggregates The aggregate functions.
+ * @param groupKeysMapping The index mapping of group keys between intermediate aggregate Row
+ *                         and output Row.
+ * @param aggregateMapping The index mapping between aggregate function list and aggregated value
+ *                         index in output Row.
+ */
+class AggregateReduceCombineFunction(
+    private val aggregates: Array[Aggregate[_ <: Any]],
+    private val groupKeysMapping: Array[(Int, Int)],
+    private val aggregateMapping: Array[(Int, Int)],
+    private val intermediateRowArity: Int,
+    private val finalRowArity: Int)
+  extends RichGroupReduceFunction[Row, Row] with CombineFunction[Row, Row] {
+
+  private var aggregateBuffer: Row = _
+  private var output: Row = _
+
+  override def open(config: Configuration): Unit = {
+    Preconditions.checkNotNull(aggregates)
+    Preconditions.checkNotNull(groupKeysMapping)
+    aggregateBuffer = new Row(intermediateRowArity)
+    output = new Row(finalRowArity)
+  }
+
+  /**
+   * For grouped intermediate aggregate Rows, merge all of them into the aggregate buffer,
+   * calculate the final aggregated values from the aggregate buffer, and set them into the
+   * output Row based on the mapping between the intermediate aggregate Row and the output Row.
+   *
+   * @param records Grouped intermediate aggregate Rows iterator.
+   * @param out The collector to hand results to.
+   *
+   */
+  override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
+
+    // Initiate the intermediate aggregate value.
+    aggregates.foreach(_.initiate(aggregateBuffer))
+
+    // Merge the intermediate aggregate values into the buffer.
+    var last: Row = null
+    records.foreach((record) => {
+      aggregates.foreach(_.merge(record, aggregateBuffer))
+      last = record
+    })
+
+    // Set the group key values into the final output.
+    groupKeysMapping.foreach {
+      case (after, previous) =>
+        output.setField(after, last.productElement(previous))
+    }
+
+    // Evaluate the final aggregate values and set them in the output.
+    aggregateMapping.foreach {
+      case (after, previous) =>
+        output.setField(after, aggregates(previous).evaluate(aggregateBuffer))
+    }
+
+    out.collect(output)
+  }
+
+  /**
+   * For sub-grouped intermediate aggregate Rows, merge all of them into the aggregate buffer.
+   *
+   * @param records Sub-grouped intermediate aggregate Rows iterator.
+   * @return Combined intermediate aggregate Row.
+   *
+   */
+  override def combine(records: Iterable[Row]): Row = {
+
+    // Initiate the intermediate aggregate value.
+    aggregates.foreach(_.initiate(aggregateBuffer))
+
+    // Merge the intermediate aggregate values into the buffer.
+    var last: Row = null
+    records.foreach((record) => {
+      aggregates.foreach(_.merge(record, aggregateBuffer))
+      last = record
+    })
+
+    // Set the group keys into the aggregate buffer.
+    for (i <- groupKeysMapping.indices) {
+      aggregateBuffer.setField(i, last.productElement(i))
+    }
+
+    aggregateBuffer
+  }
+}
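
Note on the window translation in DataStreamAggregate above (createKeyedWindowedStream / createNonKeyedWindowedStream): each logical group window maps onto a stock DataStream window assigner. The following is a minimal, hedged sketch of the calls the pattern match boils down to, written against the Flink 1.x Scala DataStream API; the stream `rows` and the key position are illustrative stand-ins, not part of this commit.

import org.apache.flink.streaming.api.scala._
import org.apache.flink.streaming.api.windowing.assigners.{ProcessingTimeSessionWindows, SlidingEventTimeWindows, TumblingProcessingTimeWindows}
import org.apache.flink.streaming.api.windowing.time.Time

object WindowTranslationSketch {
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val rows = env.fromElements(("a", 1L), ("a", 2L), ("b", 3L))

    // ProcessingTimeTumblingGroupWindow with a time-interval size
    rows.keyBy(0).window(TumblingProcessingTimeWindows.of(Time.seconds(10)))
    // ProcessingTimeTumblingGroupWindow with a row-interval size
    rows.keyBy(0).countWindow(100)
    // EventTimeSlidingGroupWindow with a time size and slide
    rows.keyBy(0).window(SlidingEventTimeWindows.of(Time.minutes(1), Time.seconds(10)))
    // ProcessingTimeSessionGroupWindow with a gap
    rows.keyBy(0).window(ProcessingTimeSessionWindows.withGap(Time.seconds(30)))
    // non-keyed variants use windowAll / countWindowAll on the raw stream
    rows.windowAll(TumblingProcessingTimeWindows.of(Time.seconds(10)))
  }
}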
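
The asTime and asCount helpers only accept literal interval expressions; the planner is expected to have reduced window sizes to literals before translation, so anything else fails fast. A hedged usage sketch follows: constructing Literal values by hand like this is illustrative only and assumes they mirror what the Table API produces for a 10-minute, respectively 5-row, interval.

import org.apache.flink.api.table.expressions.Literal
import org.apache.flink.api.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo}

val tenMinutes = Literal(10L * 60 * 1000, TimeIntervalTypeInfo.INTERVAL_MILLIS)
val fiveRows   = Literal(5L, RowIntervalTypeInfo.INTERVAL_ROWS)

DataStreamAggregate.asTime(tenMinutes)   // Time.milliseconds(600000)
DataStreamAggregate.asCount(fiveRows)    // 5L
// DataStreamAggregate.asTime(fiveRows)  // would throw IllegalArgumentException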
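
The delegation trick in AggregateAllWindowFunction deserves a spelled-out model: the wrapped GroupReduceFunction never sees the window, yet every emitted row ends up carrying the requested window properties, because output is routed through a collector that appends them. Below is a self-contained sketch of that pattern in plain Scala with no Flink types; the committed PropertyCollector and WindowPropertyRead classes live elsewhere in this package and may differ in detail.

object WrapAndAppendSketch {
  trait Collector[T] { def collect(t: T): Unit }

  // appends window metadata to every row the wrapped function emits
  class AppendingCollector(start: Long, end: Long, out: Collector[Seq[Any]])
      extends Collector[Seq[Any]] {
    def collect(row: Seq[Any]): Unit = out.collect(row :+ start :+ end)
  }

  // stands in for the wrapped GroupReduceFunction: emits (key, count)
  def reduce(rows: Seq[Seq[Any]], out: Collector[Seq[Any]]): Unit =
    out.collect(Seq(rows.head.head, rows.size))

  // stands in for AggregateAllWindowFunction.apply: extract properties, delegate
  def apply(start: Long, end: Long, rows: Seq[Seq[Any]], out: Collector[Seq[Any]]): Unit =
    reduce(rows, new AppendingCollector(start, end, out))

  def main(args: Array[String]): Unit = {
    val sink = new Collector[Seq[Any]] { def collect(t: Seq[Any]): Unit = println(t) }
    apply(0L, 3600000L, Seq(Seq("a", 1L), Seq("a", 2L)), sink)
    // prints: List(a, 2, 0, 3600000)  -> (key, count, windowStart, windowEnd)
  }
}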
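
Finally, a hedged sketch of how AggregateReduceCombineFunction slots into the generated batch plan: because it implements CombineFunction, Flink's GroupReduceOperator calls combine() on partial groups before the shuffle and reduce() once per final group. The names `sumAggregate` and `mappedInput` are illustrative stand-ins, not part of this commit; the real instances come out of AggregateUtil.createOperatorFunctionsForAggregates, as in DataSetAggregate.translateToPlan above.

// intermediate row layout assumed here: field 0 = group key, field 1 = partial sum
val reduceFunction = new AggregateReduceCombineFunction(
  Array(sumAggregate),        // some Aggregate[_] implementation (assumed)
  Array((0, 0)),              // output field 0 <- intermediate group-key field 0
  Array((1, 0)),              // output field 1 <- evaluate() of aggregate 0
  intermediateRowArity = 2,
  finalRowArity = 2)

mappedInput                   // DataSet[Row] produced by the prepare mapper
  .groupBy(0)
  .reduceGroup(reduceFunction)   // combine() runs automatically before the shuffle
  .name("aggregate")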
