This is an automated email from the ASF dual-hosted git repository.

dwysakowicz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 1d167a3670614901e4ef011af92b4045c7eb1612 Author: Dawid Wysakowicz <[email protected]> AuthorDate: Sat Jun 1 19:08:56 2019 +0200 [FLINK-12906][table-planner][table-api-java] Ported OperationTreeBuilder to table-api-java module --- .../java/internal/StreamTableEnvironmentImpl.java | 8 +- .../flink/table/api/EnvironmentSettings.java | 7 + .../table/api/internal/TableEnvironmentImpl.java | 28 +- .../expressions/utils/ApiExpressionUtils.java | 29 +- .../operations/OperationExpressionsUtils.java | 4 +- .../table/operations/OperationTreeBuilder.java | 680 +++++++++++++++++++-- .../internal/StreamTableEnvironmentImpl.scala | 10 +- .../flink/table/expressions/ExpressionUtils.java | 8 +- .../operations/OperationTreeBuilderFactory.java | 44 -- .../flink/table/api/internal/TableEnvImpl.scala | 17 +- .../operations/OperationTreeBuilderImpl.scala | 600 ------------------ .../api/stream/StreamTableEnvironmentTest.scala | 3 +- .../flink/table/api/stream/sql/AggregateTest.scala | 3 +- .../apache/flink/table/utils/TableTestBase.scala | 6 +- 14 files changed, 702 insertions(+), 745 deletions(-) diff --git a/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/api/java/internal/StreamTableEnvironmentImpl.java b/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/api/java/internal/StreamTableEnvironmentImpl.java index 6b37690..05815dd 100644 --- a/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/api/java/internal/StreamTableEnvironmentImpl.java +++ b/flink-table/flink-table-api-java-bridge/src/main/java/org/apache/flink/table/api/java/internal/StreamTableEnvironmentImpl.java @@ -85,8 +85,9 @@ public final class StreamTableEnvironmentImpl extends TableEnvironmentImpl imple TableConfig tableConfig, StreamExecutionEnvironment executionEnvironment, Planner planner, - Executor executor) { - super(catalogManager, tableConfig, executor, functionCatalog, planner); + Executor executor, + boolean isStreaming) { + super(catalogManager, tableConfig, executor, functionCatalog, planner, isStreaming); this.executionEnvironment = executionEnvironment; } @@ -119,7 +120,8 @@ public final class StreamTableEnvironmentImpl extends TableEnvironmentImpl imple tableConfig, executionEnvironment, planner, - executor + executor, + !settings.isBatchMode() ); } diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/EnvironmentSettings.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/EnvironmentSettings.java index 37ba179..70b7ffd 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/EnvironmentSettings.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/EnvironmentSettings.java @@ -114,6 +114,13 @@ public class EnvironmentSettings { return builtInDatabaseName; } + /** + * Tells if the {@link TableEnvironment} should work in a batch or streaming mode. 
+ */ + public boolean isBatchMode() { + return isBatchMode; + } + @Internal public Map<String, String> toPlannerProperties() { Map<String, String> properties = new HashMap<>(toCommonProperties()); diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableEnvironmentImpl.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableEnvironmentImpl.java index 727727a..9b04f56 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableEnvironmentImpl.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/internal/TableEnvironmentImpl.java @@ -36,7 +36,6 @@ import org.apache.flink.table.catalog.CatalogManager; import org.apache.flink.table.catalog.ConnectorCatalogTable; import org.apache.flink.table.catalog.ExternalCatalog; import org.apache.flink.table.catalog.FunctionCatalog; -import org.apache.flink.table.catalog.FunctionLookup; import org.apache.flink.table.catalog.ObjectPath; import org.apache.flink.table.catalog.QueryOperationCatalogView; import org.apache.flink.table.catalog.exceptions.DatabaseNotExistException; @@ -46,7 +45,6 @@ import org.apache.flink.table.descriptors.ConnectorDescriptor; import org.apache.flink.table.descriptors.StreamTableDescriptor; import org.apache.flink.table.descriptors.TableDescriptor; import org.apache.flink.table.expressions.TableReferenceExpression; -import org.apache.flink.table.expressions.resolver.lookups.TableReferenceLookup; import org.apache.flink.table.functions.ScalarFunction; import org.apache.flink.table.operations.CatalogQueryOperation; import org.apache.flink.table.operations.CatalogSinkModifyOperation; @@ -60,7 +58,6 @@ import org.apache.flink.table.sources.TableSource; import org.apache.flink.table.sources.TableSourceValidation; import org.apache.flink.util.StringUtils; -import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -92,7 +89,8 @@ public class TableEnvironmentImpl implements TableEnvironment { TableConfig tableConfig, Executor executor, FunctionCatalog functionCatalog, - Planner planner) { + Planner planner, + boolean isStreaming) { this.catalogManager = catalogManager; this.execEnv = executor; @@ -103,32 +101,16 @@ public class TableEnvironmentImpl implements TableEnvironment { this.functionCatalog = functionCatalog; this.planner = planner; - this.operationTreeBuilder = lookupTreeBuilder( + this.operationTreeBuilder = OperationTreeBuilder.create( + functionCatalog, path -> { Optional<CatalogQueryOperation> catalogTableOperation = scanInternal(path); return catalogTableOperation.map(tableOperation -> new TableReferenceExpression(path, tableOperation)); }, - functionCatalog + isStreaming ); } - private static OperationTreeBuilder lookupTreeBuilder( - TableReferenceLookup tableReferenceLookup, - FunctionLookup functionDefinitionCatalog) { - try { - Class<?> clazz = Class.forName("org.apache.flink.table.operations.OperationTreeBuilderFactory"); - Method createMethod = clazz.getMethod( - "create", - TableReferenceLookup.class, - FunctionLookup.class); - - return (OperationTreeBuilder) createMethod.invoke(null, tableReferenceLookup, functionDefinitionCatalog); - } catch (Exception e) { - throw new TableException( - "Could not instantiate the operation builder. 
Make sure the planner module is on the classpath"); - } - } - @VisibleForTesting public Planner getPlanner() { return planner; diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionUtils.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionUtils.java index f04878f..5d0c2c3 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionUtils.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionUtils.java @@ -31,6 +31,7 @@ import org.apache.flink.table.expressions.TypeLiteralExpression; import org.apache.flink.table.expressions.UnresolvedCallExpression; import org.apache.flink.table.expressions.UnresolvedReferenceExpression; import org.apache.flink.table.expressions.ValueLiteralExpression; +import org.apache.flink.table.functions.BuiltInFunctionDefinition; import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.functions.FunctionKind; import org.apache.flink.table.types.DataType; @@ -116,18 +117,34 @@ public final class ApiExpressionUtils { /** * Checks if the expression is a function call of given type. * - * @param expr expression to check + * @param expression expression to check * @param kind expected type of function * @return true if the expression is function call of given type, false otherwise */ - public static boolean isFunctionOfKind(Expression expr, FunctionKind kind) { - if (expr instanceof UnresolvedCallExpression) { - return ((UnresolvedCallExpression) expr).getFunctionDefinition().getKind() == kind; + public static boolean isFunctionOfKind(Expression expression, FunctionKind kind) { + if (expression instanceof UnresolvedCallExpression) { + return ((UnresolvedCallExpression) expression).getFunctionDefinition().getKind() == kind; } - if (expr instanceof CallExpression) { - return ((CallExpression) expr).getFunctionDefinition().getKind() == kind; + if (expression instanceof CallExpression) { + return ((CallExpression) expression).getFunctionDefinition().getKind() == kind; } return false; + } + /** + * Checks if the given expression is a given builtin function. 
+ * + * @param expression expression to check + * @param functionDefinition expected function definition + * @return true if the given expression is a given function call + */ + public static boolean isFunction(Expression expression, BuiltInFunctionDefinition functionDefinition) { + if (expression instanceof UnresolvedCallExpression) { + return ((UnresolvedCallExpression) expression).getFunctionDefinition() == functionDefinition; + } + if (expression instanceof CallExpression) { + return ((CallExpression) expression).getFunctionDefinition() == functionDefinition; + } + return false; } } diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/OperationExpressionsUtils.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/OperationExpressionsUtils.java index 8e9912b..eb2030d 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/OperationExpressionsUtils.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/OperationExpressionsUtils.java @@ -155,8 +155,8 @@ public class OperationExpressionsUtils { private final Map<Expression, String> properties; private AggregationAndPropertiesReplacer( - Map<Expression, String> aggregates, - Map<Expression, String> properties) { + Map<Expression, String> aggregates, + Map<Expression, String> properties) { this.aggregates = aggregates; this.properties = properties; } diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/OperationTreeBuilder.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/OperationTreeBuilder.java index 50eb8d9..37de6da 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/OperationTreeBuilder.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/operations/OperationTreeBuilder.java @@ -19,92 +19,676 @@ package org.apache.flink.table.operations; import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.table.api.GroupWindow; import org.apache.flink.table.api.OverWindow; +import org.apache.flink.table.api.TableException; +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.catalog.FunctionLookup; import org.apache.flink.table.expressions.Expression; +import org.apache.flink.table.expressions.ExpressionUtils; +import org.apache.flink.table.expressions.LocalReferenceExpression; +import org.apache.flink.table.expressions.ResolvedExpression; +import org.apache.flink.table.expressions.UnresolvedCallExpression; +import org.apache.flink.table.expressions.UnresolvedReferenceExpression; +import org.apache.flink.table.expressions.resolver.ExpressionResolver; +import org.apache.flink.table.expressions.resolver.LookupCallResolver; +import org.apache.flink.table.expressions.resolver.lookups.TableReferenceLookup; +import org.apache.flink.table.expressions.utils.ApiExpressionDefaultVisitor; +import org.apache.flink.table.expressions.utils.ApiExpressionUtils; +import org.apache.flink.table.functions.AggregateFunctionDefinition; +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.table.functions.FunctionDefinition; +import org.apache.flink.table.functions.FunctionKind; +import org.apache.flink.table.functions.TableFunctionDefinition; import 
org.apache.flink.table.operations.JoinQueryOperation.JoinType; +import org.apache.flink.table.operations.WindowAggregateQueryOperation.ResolvedGroupWindow; +import org.apache.flink.table.operations.utils.factories.AggregateOperationFactory; +import org.apache.flink.table.operations.utils.factories.AliasOperationUtils; +import org.apache.flink.table.operations.utils.factories.CalculatedTableFactory; +import org.apache.flink.table.operations.utils.factories.ColumnOperationUtils; +import org.apache.flink.table.operations.utils.factories.JoinOperationFactory; +import org.apache.flink.table.operations.utils.factories.ProjectionOperationFactory; +import org.apache.flink.table.operations.utils.factories.SetOperationFactory; +import org.apache.flink.table.operations.utils.factories.SortOperationFactory; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.flink.table.types.logical.utils.LogicalTypeChecks; +import org.apache.flink.table.typeutils.FieldInfoUtils; +import org.apache.flink.util.Preconditions; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.table.expressions.utils.ApiExpressionUtils.unresolvedCall; +import static org.apache.flink.table.expressions.utils.ApiExpressionUtils.unresolvedRef; +import static org.apache.flink.table.expressions.utils.ApiExpressionUtils.valueLiteral; +import static org.apache.flink.table.operations.SetQueryOperation.SetQueryOperationType.INTERSECT; +import static org.apache.flink.table.operations.SetQueryOperation.SetQueryOperationType.MINUS; +import static org.apache.flink.table.operations.SetQueryOperation.SetQueryOperationType.UNION; /** - * Builder for validated {@link QueryOperation}s. - * - * <p>TODO. This is a temporary solution. The actual implementation should be ported. + * A builder for constructing validated {@link QueryOperation}s. */ @Internal -public interface OperationTreeBuilder { - QueryOperation project(List<Expression> projectList, QueryOperation child); +public final class OperationTreeBuilder { + + private final FunctionLookup functionCatalog; + private final TableReferenceLookup tableReferenceLookup; + private final LookupCallResolver lookupResolver; + + /** + * Utility classes for constructing a validated operation of certain type. 
+ */ + private final ProjectionOperationFactory projectionOperationFactory; + private final SortOperationFactory sortOperationFactory; + private final CalculatedTableFactory calculatedTableFactory; + private final SetOperationFactory setOperationFactory; + private final AggregateOperationFactory aggregateOperationFactory; + private final JoinOperationFactory joinOperationFactory; + + private OperationTreeBuilder( + FunctionLookup functionLookup, + TableReferenceLookup tableReferenceLookup, + ProjectionOperationFactory projectionOperationFactory, + SortOperationFactory sortOperationFactory, + CalculatedTableFactory calculatedTableFactory, + SetOperationFactory setOperationFactory, + AggregateOperationFactory aggregateOperationFactory, + JoinOperationFactory joinOperationFactory) { + this.functionCatalog = functionLookup; + this.tableReferenceLookup = tableReferenceLookup; + this.projectionOperationFactory = projectionOperationFactory; + this.sortOperationFactory = sortOperationFactory; + this.calculatedTableFactory = calculatedTableFactory; + this.setOperationFactory = setOperationFactory; + this.aggregateOperationFactory = aggregateOperationFactory; + this.joinOperationFactory = joinOperationFactory; + this.lookupResolver = new LookupCallResolver(functionLookup); + } + + public static OperationTreeBuilder create( + FunctionLookup functionCatalog, + TableReferenceLookup tableReferenceLookup, + boolean isStreaming) { + return new OperationTreeBuilder( + functionCatalog, + tableReferenceLookup, + new ProjectionOperationFactory(), + new SortOperationFactory(isStreaming), + new CalculatedTableFactory(), + new SetOperationFactory(isStreaming), + new AggregateOperationFactory(isStreaming), + new JoinOperationFactory() + ); + } + + public QueryOperation project(List<Expression> projectList, QueryOperation child) { + return project(projectList, child, false); + } + + public QueryOperation project(List<Expression> projectList, QueryOperation child, boolean explicitAlias) { + projectList.forEach(p -> p.accept(new NoAggregateChecker( + "Aggregate functions are not supported in the select right after the aggregate" + + " or flatAggregate operation."))); + projectList.forEach(p -> p.accept(new NoWindowPropertyChecker( + "Window properties can only be used on windowed tables."))); + return projectInternal(projectList, child, explicitAlias, Collections.emptyList()); + } + + public QueryOperation project(List<Expression> projectList, QueryOperation child, List<OverWindow> overWindows) { + + Preconditions.checkArgument(!overWindows.isEmpty()); + + projectList.forEach(p -> p.accept(new NoWindowPropertyChecker( + "Window start and end properties are not available for Over windows."))); + + return projectInternal( + projectList, + child, + true, + overWindows); + } + + private QueryOperation projectInternal( + List<Expression> projectList, + QueryOperation child, + boolean explicitAlias, + List<OverWindow> overWindows) { + + ExpressionResolver resolver = ExpressionResolver.resolverFor(tableReferenceLookup, functionCatalog, child) + .withOverWindows(overWindows) + .build(); + List<ResolvedExpression> projections = resolver.resolve(projectList); + return projectionOperationFactory.create(projections, child, explicitAlias, resolver.postResolverFactory()); + } + + /** + * Adds additional columns. Existing fields will be replaced if replaceIfExist is true. 
+ */ + public QueryOperation addColumns(boolean replaceIfExist, List<Expression> fieldLists, QueryOperation child) { + final List<Expression> newColumns; + if (replaceIfExist) { + String[] fieldNames = child.getTableSchema().getFieldNames(); + newColumns = ColumnOperationUtils.addOrReplaceColumns(Arrays.asList(fieldNames), fieldLists); + } else { + newColumns = new ArrayList<>(fieldLists); + newColumns.add(0, new UnresolvedReferenceExpression("*")); + } + return project(newColumns, child, false); + } + + public QueryOperation renameColumns(List<Expression> aliases, QueryOperation child) { + + ExpressionResolver resolver = getResolver(child); + String[] inputFieldNames = child.getTableSchema().getFieldNames(); + List<Expression> validateAliases = ColumnOperationUtils.renameColumns( + Arrays.asList(inputFieldNames), + resolver.resolveExpanding(aliases)); + + return project(validateAliases, child, false); + } + + public QueryOperation dropColumns(List<Expression> fieldLists, QueryOperation child) { + + ExpressionResolver resolver = getResolver(child); + String[] inputFieldNames = child.getTableSchema().getFieldNames(); + List<Expression> finalFields = ColumnOperationUtils.dropFields( + Arrays.asList(inputFieldNames), + resolver.resolveExpanding(fieldLists)); + + return project(finalFields, child, false); + } + + public QueryOperation aggregate(List<Expression> groupingExpressions, List<Expression> aggregates, QueryOperation child) { - QueryOperation project(List<Expression> projectList, QueryOperation child, boolean explicitAlias); + ExpressionResolver resolver = getResolver(child); - QueryOperation project(List<Expression> projectList, QueryOperation child, List<OverWindow> overWindows); + List<ResolvedExpression> resolvedGroupings = resolver.resolve(groupingExpressions); + List<ResolvedExpression> resolvedAggregates = resolver.resolve(aggregates); - QueryOperation windowAggregate( + return aggregateOperationFactory.createAggregate(resolvedGroupings, resolvedAggregates, child); + } + + public QueryOperation windowAggregate( List<Expression> groupingExpressions, GroupWindow window, List<Expression> windowProperties, List<Expression> aggregates, - QueryOperation child); + QueryOperation child) { - QueryOperation join( - QueryOperation left, - QueryOperation right, - JoinType joinType, - Optional<Expression> condition, - boolean correlated); + ExpressionResolver resolver = getResolver(child); + ResolvedGroupWindow resolvedWindow = aggregateOperationFactory.createResolvedWindow(window, resolver); - QueryOperation joinLateral( - QueryOperation left, - Expression tableFunction, - JoinType joinType, - Optional<Expression> condition); + ExpressionResolver resolverWithWindowReferences = ExpressionResolver.resolverFor( + tableReferenceLookup, + functionCatalog, + child) + .withLocalReferences( + new LocalReferenceExpression( + resolvedWindow.getAlias(), + resolvedWindow.getTimeAttribute().getOutputDataType())) + .build(); - Expression resolveExpression(Expression expression, QueryOperation... 
tableOperation); + List<ResolvedExpression> convertedGroupings = resolverWithWindowReferences.resolve(groupingExpressions); + List<ResolvedExpression> convertedAggregates = resolverWithWindowReferences.resolve(aggregates); + List<ResolvedExpression> convertedProperties = resolverWithWindowReferences.resolve(windowProperties); - QueryOperation sort(List<Expression> fields, QueryOperation child); + return aggregateOperationFactory.createWindowAggregate( + convertedGroupings, + convertedAggregates, + convertedProperties, + resolvedWindow, + child); + } - QueryOperation limitWithOffset(int offset, QueryOperation child); + public QueryOperation join( + QueryOperation left, + QueryOperation right, + JoinType joinType, + Optional<Expression> condition, + boolean correlated) { + ExpressionResolver resolver = ExpressionResolver.resolverFor(tableReferenceLookup, functionCatalog, left, right) + .build(); + Optional<ResolvedExpression> resolvedCondition = condition.map(expr -> resolveSingleExpression(expr, resolver)); - QueryOperation limitWithFetch(int fetch, QueryOperation child); + return joinOperationFactory.create( + left, + right, + joinType, + resolvedCondition.orElse(valueLiteral(true)), + correlated); + } - QueryOperation alias(List<Expression> fields, QueryOperation child); + public QueryOperation joinLateral( + QueryOperation left, + Expression tableFunction, + JoinType joinType, + Optional<Expression> condition) { + ExpressionResolver resolver = getResolver(left); + ResolvedExpression resolvedFunction = resolveSingleExpression(tableFunction, resolver); - QueryOperation filter(Expression condition, QueryOperation child); + QueryOperation temporalTable = calculatedTableFactory.create( + resolvedFunction, + left.getTableSchema().getFieldNames()); - QueryOperation distinct(QueryOperation child); + return join(left, temporalTable, joinType, condition, true); + } - QueryOperation minus(QueryOperation left, QueryOperation right, boolean all); + public Expression resolveExpression(Expression expression, QueryOperation... 
tableOperation) { + ExpressionResolver resolver = ExpressionResolver.resolverFor( + tableReferenceLookup, + functionCatalog, + tableOperation).build(); - QueryOperation intersect(QueryOperation left, QueryOperation right, boolean all); + return resolveSingleExpression(expression, resolver); + } - QueryOperation union(QueryOperation left, QueryOperation right, boolean all); + private ResolvedExpression resolveSingleExpression(Expression expression, ExpressionResolver resolver) { + List<ResolvedExpression> resolvedExpression = resolver.resolve(Collections.singletonList(expression)); + if (resolvedExpression.size() != 1) { + throw new ValidationException("Expected single expression"); + } else { + return resolvedExpression.get(0); + } + } - /* Extensions */ + public QueryOperation sort(List<Expression> fields, QueryOperation child) { - QueryOperation addColumns(boolean replaceIfExist, List<Expression> fieldLists, QueryOperation child); + ExpressionResolver resolver = getResolver(child); + List<ResolvedExpression> resolvedFields = resolver.resolve(fields); - QueryOperation renameColumns(List<Expression> aliases, QueryOperation child); + return sortOperationFactory.createSort(resolvedFields, child, resolver.postResolverFactory()); + } - QueryOperation dropColumns(List<Expression> fieldLists, QueryOperation child); + public QueryOperation limitWithOffset(int offset, QueryOperation child) { + return sortOperationFactory.createLimitWithOffset(offset, child); + } - QueryOperation aggregate(List<Expression> groupingExpressions, List<Expression> aggregates, QueryOperation child); + public QueryOperation limitWithFetch(int fetch, QueryOperation child) { + return sortOperationFactory.createLimitWithFetch(fetch, child); + } - QueryOperation map(Expression mapFunction, QueryOperation child); + public QueryOperation alias(List<Expression> fields, QueryOperation child) { + List<Expression> newFields = AliasOperationUtils.createAliasList(fields, child); - QueryOperation flatMap(Expression tableFunction, QueryOperation child); + return project(newFields, child, true); + } - QueryOperation aggregate(List<Expression> groupingExpressions, Expression aggregate, QueryOperation child); + public QueryOperation filter(Expression condition, QueryOperation child) { - QueryOperation tableAggregate( - List<Expression> groupingExpressions, - Expression tableAggFunction, - QueryOperation child); + ExpressionResolver resolver = getResolver(child); + ResolvedExpression resolvedExpression = resolveSingleExpression(condition, resolver); + DataType conditionType = resolvedExpression.getOutputDataType(); + if (!LogicalTypeChecks.hasRoot(conditionType.getLogicalType(), LogicalTypeRoot.BOOLEAN)) { + throw new ValidationException("Filter operator requires a boolean expression as input," + + " but " + condition + " is of type " + conditionType); + } - QueryOperation windowTableAggregate( - List<Expression> groupingExpressions, - GroupWindow window, - List<Expression> windowProperties, - Expression tableAggFunction, - QueryOperation child); + return new FilterQueryOperation(resolvedExpression, child); + } + + public QueryOperation distinct(QueryOperation child) { + return new DistinctQueryOperation(child); + } + + public QueryOperation minus(QueryOperation left, QueryOperation right, boolean all) { + return setOperationFactory.create(MINUS, left, right, all); + } + + public QueryOperation intersect(QueryOperation left, QueryOperation right, boolean all) { + return setOperationFactory.create(INTERSECT, left, right, all); + } + + public 
QueryOperation union(QueryOperation left, QueryOperation right, boolean all) { + return setOperationFactory.create(UNION, left, right, all); + } + + public QueryOperation map(Expression mapFunction, QueryOperation child) { + + Expression resolvedMapFunction = mapFunction.accept(lookupResolver); + + if (!ApiExpressionUtils.isFunctionOfKind(resolvedMapFunction, FunctionKind.SCALAR)) { + throw new ValidationException("Only a scalar function can be used in the map operator."); + } + + Expression expandedFields = unresolvedCall(BuiltInFunctionDefinitions.FLATTEN, resolvedMapFunction); + return project(Collections.singletonList(expandedFields), child, false); + } + + public QueryOperation flatMap(Expression tableFunction, QueryOperation child) { + + Expression resolvedTableFunction = tableFunction.accept(lookupResolver); + + if (!ApiExpressionUtils.isFunctionOfKind(resolvedTableFunction, FunctionKind.TABLE)) { + throw new ValidationException("Only a table function can be used in the flatMap operator."); + } + + TypeInformation<?> resultType = ((TableFunctionDefinition) ((UnresolvedCallExpression) resolvedTableFunction) + .getFunctionDefinition()) + .getResultType(); + List<String> originFieldNames = Arrays.asList(FieldInfoUtils.getFieldNames(resultType)); + + List<String> childFields = Arrays.asList(child.getTableSchema().getFieldNames()); + Set<String> usedFieldNames = new HashSet<>(childFields); + + List<Expression> args = new ArrayList<>(); + for (String originFieldName : originFieldNames) { + String resultName = getUniqueName(originFieldName, usedFieldNames); + usedFieldNames.add(resultName); + args.add(valueLiteral(resultName)); + } + + args.add(0, resolvedTableFunction); + Expression renamedTableFunction = unresolvedCall( + BuiltInFunctionDefinitions.AS, + args.toArray(new Expression[0])); + QueryOperation joinNode = joinLateral(child, renamedTableFunction, JoinType.INNER, Optional.empty()); + QueryOperation rightNode = dropColumns( + childFields.stream().map(UnresolvedReferenceExpression::new).collect(Collectors.toList()), + joinNode); + return alias( + originFieldNames.stream().map(UnresolvedReferenceExpression::new).collect(Collectors.toList()), + rightNode); + } + + public QueryOperation aggregate(List<Expression> groupingExpressions, Expression aggregate, QueryOperation child) { + Expression resolvedAggregate = aggregate.accept(lookupResolver); + AggregateWithAlias aggregateWithAlias = resolvedAggregate.accept(new ExtractAliasAndAggregate()); + + // turn agg to a named agg, because it will be verified later. 
+ String[] childNames = child.getTableSchema().getFieldNames(); + Expression aggregateRenamed = addAliasToTheCallInGroupings( + Arrays.asList(childNames), + Collections.singletonList(aggregateWithAlias.aggregate)).get(0); + + // get agg table + QueryOperation aggregateOperation = this.aggregate( + groupingExpressions, + Collections.singletonList(aggregateRenamed), + child); + + // flatten the aggregate function + String[] aggNames = aggregateOperation.getTableSchema().getFieldNames(); + List<Expression> flattenedExpressions = Arrays.asList(aggNames) + .subList(0, groupingExpressions.size()) + .stream() + .map(ApiExpressionUtils::unresolvedRef) + .collect(Collectors.toCollection(ArrayList::new)); + + flattenedExpressions.add(unresolvedCall( + BuiltInFunctionDefinitions.FLATTEN, + unresolvedRef(aggNames[aggNames.length - 1]))); + + QueryOperation flattenedProjection = this.project(flattenedExpressions, aggregateOperation); + + // add alias + return aliasBackwardFields(flattenedProjection, aggregateWithAlias.aliases, groupingExpressions.size()); + } + + private static class AggregateWithAlias { + private final UnresolvedCallExpression aggregate; + private final List<String> aliases; + + private AggregateWithAlias(UnresolvedCallExpression aggregate, List<String> aliases) { + this.aggregate = aggregate; + this.aliases = aliases; + } + } + + private static class ExtractAliasAndAggregate extends ApiExpressionDefaultVisitor<AggregateWithAlias> { + @Override + public AggregateWithAlias visit(UnresolvedCallExpression unresolvedCall) { + if (ApiExpressionUtils.isFunction(unresolvedCall, BuiltInFunctionDefinitions.AS)) { + Expression expression = unresolvedCall.getChildren().get(0); + if (expression instanceof UnresolvedCallExpression) { + List<String> aliases = extractAliases(unresolvedCall); + + return getAggregate((UnresolvedCallExpression) expression, aliases) + .orElseGet(() -> defaultMethod(unresolvedCall)); + } else { + return defaultMethod(unresolvedCall); + } + } + + return getAggregate(unresolvedCall, Collections.emptyList()).orElseGet(() -> defaultMethod(unresolvedCall)); + } + + private List<String> extractAliases(UnresolvedCallExpression unresolvedCall) { + return unresolvedCall.getChildren() + .subList(1, unresolvedCall.getChildren().size()) + .stream() + .map(ex -> ExpressionUtils.extractValue(ex, String.class) + .orElseThrow(() -> new TableException("Expected string literal as alias."))) + .collect(Collectors.toList()); + } + + private Optional<AggregateWithAlias> getAggregate( + UnresolvedCallExpression unresolvedCall, + List<String> aliases) { + FunctionDefinition functionDefinition = unresolvedCall.getFunctionDefinition(); + if (ApiExpressionUtils.isFunctionOfKind(unresolvedCall, FunctionKind.AGGREGATE)) { + final List<String> fieldNames; + if (aliases.isEmpty()) { + if (functionDefinition instanceof AggregateFunctionDefinition) { + TypeInformation<?> resultTypeInfo = ((AggregateFunctionDefinition) functionDefinition) + .getResultTypeInfo(); + fieldNames = Arrays.asList(FieldInfoUtils.getFieldNames(resultTypeInfo)); + } else { + fieldNames = Collections.emptyList(); + } + } else { + fieldNames = aliases; + } + return Optional.of(new AggregateWithAlias(unresolvedCall, fieldNames)); + } else { + return Optional.empty(); + } + } + + @Override + protected AggregateWithAlias defaultMethod(Expression expression) { + throw new ValidationException("Aggregate function expected. 
Got: " + expression); + } + } + + public QueryOperation tableAggregate( + List<Expression> groupingExpressions, + Expression tableAggFunction, + QueryOperation child) { + + // Step1: add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to + // groupBy(a % 5 as TMP_0). We need a name for every column so that to perform alias for the + // table aggregate function in Step4. + List<Expression> newGroupingExpressions = addAliasToTheCallInGroupings( + Arrays.asList(child.getTableSchema().getFieldNames()), + groupingExpressions); + + // Step2: resolve expressions + ExpressionResolver resolver = getResolver(child); + List<ResolvedExpression> resolvedGroupings = resolver.resolve(newGroupingExpressions); + Tuple2<ResolvedExpression, List<String>> resolvedFunctionAndAlias = + aggregateOperationFactory.extractTableAggFunctionAndAliases( + resolveSingleExpression(tableAggFunction, resolver)); + + // Step3: create table agg operation + QueryOperation tableAggOperation = aggregateOperationFactory + .createAggregate(resolvedGroupings, Collections.singletonList(resolvedFunctionAndAlias.f0), child); + + // Step4: add a top project to alias the output fields of the table aggregate. + return aliasBackwardFields(tableAggOperation, resolvedFunctionAndAlias.f1, groupingExpressions.size()); + } + + public QueryOperation windowTableAggregate( + List<Expression> groupingExpressions, + GroupWindow window, + List<Expression> windowProperties, + Expression tableAggFunction, + QueryOperation child) { + + // Step1: add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to + // groupBy(a % 5 as TMP_0). We need a name for every column so that to perform alias for the + // table aggregate function in Step4. + List<Expression> newGroupingExpressions = addAliasToTheCallInGroupings( + Arrays.asList(child.getTableSchema().getFieldNames()), + groupingExpressions); + + // Step2: resolve expressions, including grouping, aggregates and window properties. + ExpressionResolver resolver = getResolver(child); + ResolvedGroupWindow resolvedWindow = aggregateOperationFactory.createResolvedWindow(window, resolver); + + ExpressionResolver resolverWithWindowReferences = ExpressionResolver.resolverFor( + tableReferenceLookup, + functionCatalog, + child) + .withLocalReferences( + new LocalReferenceExpression( + resolvedWindow.getAlias(), + resolvedWindow.getTimeAttribute().getOutputDataType())) + .build(); + + List<ResolvedExpression> convertedGroupings = resolverWithWindowReferences.resolve(newGroupingExpressions); + List<ResolvedExpression> convertedAggregates = resolverWithWindowReferences.resolve(Collections.singletonList( + tableAggFunction)); + List<ResolvedExpression> convertedProperties = resolverWithWindowReferences.resolve(windowProperties); + Tuple2<ResolvedExpression, List<String>> resolvedFunctionAndAlias = aggregateOperationFactory + .extractTableAggFunctionAndAliases(convertedAggregates.get(0)); + + // Step3: create window table agg operation + QueryOperation tableAggOperation = aggregateOperationFactory.createWindowAggregate( + convertedGroupings, + Collections.singletonList(resolvedFunctionAndAlias.f0), + convertedProperties, + resolvedWindow, + child); + + // Step4: add a top project to alias the output fields of the table aggregate. Also, project the + // window attribute. + return aliasBackwardFields(tableAggOperation, resolvedFunctionAndAlias.f1, groupingExpressions.size()); + } + + /** + * Rename fields in the input {@link QueryOperation}. 
+ */ + private QueryOperation aliasBackwardFields( + QueryOperation inputOperation, + List<String> alias, + int aliasStartIndex) { + + if (!alias.isEmpty()) { + String[] namesBeforeAlias = inputOperation.getTableSchema().getFieldNames(); + List<String> namesAfterAlias = new ArrayList<>(Arrays.asList(namesBeforeAlias)); + for (int i = 0; i < alias.size(); i++) { + int withOffset = aliasStartIndex + i; + namesAfterAlias.remove(withOffset); + namesAfterAlias.add(withOffset, alias.get(i)); + } + + return this.alias(namesAfterAlias.stream() + .map(UnresolvedReferenceExpression::new) + .collect(Collectors.toList()), inputOperation); + } else { + return inputOperation; + } + } + + /** + * Add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to + * groupBy(a % 5 as TMP_0). + */ + private List<Expression> addAliasToTheCallInGroupings( + List<String> inputFieldNames, + List<Expression> groupingExpressions) { + + int attrNameCntr = 0; + Set<String> usedFieldNames = new HashSet<>(inputFieldNames); + + List<Expression> result = new ArrayList<>(); + for (Expression groupingExpression : groupingExpressions) { + if (groupingExpression instanceof UnresolvedCallExpression && + !ApiExpressionUtils.isFunction(groupingExpression, BuiltInFunctionDefinitions.AS)) { + String tempName = getUniqueName("TMP_" + attrNameCntr, usedFieldNames); + attrNameCntr += 1; + usedFieldNames.add(tempName); + result.add(unresolvedCall( + BuiltInFunctionDefinitions.AS, + groupingExpression, + valueLiteral(tempName))); + } else { + result.add(groupingExpression); + } + } + + return result; + } + + /** + * Return a unique name that does not exist in usedFieldNames according to the input name. + */ + private String getUniqueName(String inputName, Collection<String> usedFieldNames) { + int i = 0; + String resultName = inputName; + while (usedFieldNames.contains(resultName)) { + resultName = resultName + "_" + i; + i += 1; + } + return resultName; + } + + private ExpressionResolver getResolver(QueryOperation child) { + return ExpressionResolver.resolverFor(tableReferenceLookup, functionCatalog, child).build(); + } + + private static class NoWindowPropertyChecker extends ApiExpressionDefaultVisitor<Void> { + private final String exceptionMessage; + + private NoWindowPropertyChecker(String exceptionMessage) { + this.exceptionMessage = exceptionMessage; + } + + @Override + public Void visit(UnresolvedCallExpression call) { + FunctionDefinition functionDefinition = call.getFunctionDefinition(); + if (BuiltInFunctionDefinitions.WINDOW_PROPERTIES.contains(functionDefinition)) { + throw new ValidationException(exceptionMessage); + } + call.getChildren().forEach(expr -> expr.accept(this)); + return null; + } + + @Override + protected Void defaultMethod(Expression expression) { + return null; + } + } + + private static class NoAggregateChecker extends ApiExpressionDefaultVisitor<Void> { + private final String exceptionMessage; + + private NoAggregateChecker(String exceptionMessage) { + this.exceptionMessage = exceptionMessage; + } + + @Override + public Void visit(UnresolvedCallExpression call) { + if (ApiExpressionUtils.isFunctionOfKind(call, FunctionKind.AGGREGATE)) { + throw new ValidationException(exceptionMessage); + } + call.getChildren().forEach(expr -> expr.accept(this)); + return null; + } + + @Override + protected Void defaultMethod(Expression expression) { + return null; + } + } } diff --git 
a/flink-table/flink-table-api-scala-bridge/src/main/scala/org/apache/flink/table/api/scala/internal/StreamTableEnvironmentImpl.scala b/flink-table/flink-table-api-scala-bridge/src/main/scala/org/apache/flink/table/api/scala/internal/StreamTableEnvironmentImpl.scala index 96b7403..990f043 100644 --- a/flink-table/flink-table-api-scala-bridge/src/main/scala/org/apache/flink/table/api/scala/internal/StreamTableEnvironmentImpl.scala +++ b/flink-table/flink-table-api-scala-bridge/src/main/scala/org/apache/flink/table/api/scala/internal/StreamTableEnvironmentImpl.scala @@ -55,13 +55,15 @@ class StreamTableEnvironmentImpl ( config: TableConfig, scalaExecutionEnvironment: StreamExecutionEnvironment, planner: Planner, - executor: Executor) + executor: Executor, + isStreaming: Boolean) extends TableEnvironmentImpl( catalogManager, config, executor, functionCatalog, - planner) + planner, + isStreaming) with org.apache.flink.table.api.scala.StreamTableEnvironment { override def fromDataStream[T](dataStream: DataStream[T]): Table = { @@ -262,7 +264,9 @@ object StreamTableEnvironmentImpl { tableConfig, executionEnvironment, planner, - executor) + executor, + !settings.isBatchMode + ) } private def lookupExecutor( diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/ExpressionUtils.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/ExpressionUtils.java index 90aece5..951be99 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/ExpressionUtils.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/ExpressionUtils.java @@ -32,14 +32,14 @@ public final class ExpressionUtils { * Extracts the value (excluding null) of a given class from an expression assuming it is a * {@link ValueLiteralExpression}. * - * @param expr literal to extract the value from + * @param expression literal to extract the value from * @param targetClass expected class to extract from the literal * @param <V> type of extracted value * @return extracted value or empty if could not extract value of given type */ - public static <V> Optional<V> extractValue(Expression expr, Class<V> targetClass) { - if (expr instanceof ValueLiteralExpression) { - final ValueLiteralExpression valueLiteral = (ValueLiteralExpression) expr; + public static <V> Optional<V> extractValue(Expression expression, Class<V> targetClass) { + if (expression instanceof ValueLiteralExpression) { + final ValueLiteralExpression valueLiteral = (ValueLiteralExpression) expression; return valueLiteral.getValueAs(targetClass); } return Optional.empty(); diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/operations/OperationTreeBuilderFactory.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/operations/OperationTreeBuilderFactory.java deleted file mode 100644 index 1a01307..0000000 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/operations/OperationTreeBuilderFactory.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.operations; - -import org.apache.flink.annotation.Internal; -import org.apache.flink.table.catalog.FunctionLookup; -import org.apache.flink.table.expressions.resolver.lookups.TableReferenceLookup; - -/** - * Temporary solution for looking up the {@link OperationTreeBuilder}. The tree builder - * should be moved to api module once the type inference is in place. - */ -@Internal -public final class OperationTreeBuilderFactory { - - public static OperationTreeBuilder create( - TableReferenceLookup tableReferenceLookup, - FunctionLookup functionLookup) { - return new OperationTreeBuilderImpl( - tableReferenceLookup, - functionLookup, - true - ); - } - - private OperationTreeBuilderFactory() { - } -} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/internal/TableEnvImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/internal/TableEnvImpl.scala index efa29ed..0041141 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/internal/TableEnvImpl.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/api/internal/TableEnvImpl.scala @@ -18,12 +18,6 @@ package org.apache.flink.table.api.internal -import _root_.java.util.Optional - -import org.apache.calcite.jdbc.CalciteSchemaBuilder.asRootSchema -import org.apache.calcite.sql._ -import org.apache.calcite.sql.parser.SqlParser -import org.apache.calcite.tools.FrameworkConfig import org.apache.flink.annotation.VisibleForTesting import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.api._ @@ -40,6 +34,13 @@ import org.apache.flink.table.sources.TableSource import org.apache.flink.table.util.JavaScalaConversionUtil import org.apache.flink.util.StringUtils +import org.apache.calcite.jdbc.CalciteSchemaBuilder.asRootSchema +import org.apache.calcite.sql._ +import org.apache.calcite.sql.parser.SqlParser +import org.apache.calcite.tools.FrameworkConfig + +import _root_.java.util.Optional + import _root_.scala.collection.JavaConverters._ /** @@ -73,9 +74,9 @@ abstract class TableEnvImpl( } } - private[flink] val operationTreeBuilder = new OperationTreeBuilderImpl( - tableLookup, + private[flink] val operationTreeBuilder = OperationTreeBuilder.create( functionCatalog, + tableLookup, !isBatch) protected val planningConfigurationBuilder: PlanningConfigurationBuilder = diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/operations/OperationTreeBuilderImpl.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/operations/OperationTreeBuilderImpl.scala deleted file mode 100644 index 03a9bc3..0000000 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/operations/OperationTreeBuilderImpl.scala +++ /dev/null @@ -1,600 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.operations - -import java.util.{Collections, Optional, List => JList} -import org.apache.flink.table.api._ -import org.apache.flink.table.catalog.FunctionLookup -import org.apache.flink.table.expressions._ -import org.apache.flink.table.expressions.resolver.ExpressionResolver.resolverFor -import org.apache.flink.table.expressions.resolver.lookups.TableReferenceLookup -import org.apache.flink.table.expressions.resolver.{ExpressionResolver, LookupCallResolver} -import org.apache.flink.table.expressions.utils.ApiExpressionUtils.{isFunctionOfKind, unresolvedCall, unresolvedRef, valueLiteral} -import org.apache.flink.table.expressions.utils.{ApiExpressionDefaultVisitor, ApiExpressionUtils} -import org.apache.flink.table.functions.FunctionKind.{SCALAR, TABLE} -import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils -import org.apache.flink.table.functions.{AggregateFunctionDefinition, BuiltInFunctionDefinitions, TableFunctionDefinition} -import org.apache.flink.table.operations.JoinQueryOperation.JoinType -import org.apache.flink.table.operations.OperationExpressionsUtils.extractAggregationsAndProperties -import org.apache.flink.table.operations.SetQueryOperation.SetQueryOperationType._ -import org.apache.flink.table.operations.utils.factories.AliasOperationUtils.createAliasList -import org.apache.flink.table.operations.utils.factories._ -import org.apache.flink.table.types.logical.LogicalTypeRoot -import org.apache.flink.table.types.logical.utils.LogicalTypeChecks -import org.apache.flink.table.util.JavaScalaConversionUtil.toScala -import org.apache.flink.util.Preconditions - -import _root_.scala.collection.JavaConversions._ -import _root_.scala.collection.JavaConverters._ - -/** - * Builder for [[[Operation]] tree. - * - * The operation tree builder resolves expressions such that factories only work with fully - * [[ResolvedExpression]]s. 
- */ -class OperationTreeBuilderImpl( - tableCatalog: TableReferenceLookup, - functionCatalog: FunctionLookup, - isStreaming: Boolean) - extends OperationTreeBuilder{ - - private val lookupResolver = new LookupCallResolver(functionCatalog) - private val projectionOperationFactory = new ProjectionOperationFactory() - private val sortOperationFactory = new SortOperationFactory(isStreaming) - private val calculatedTableFactory = new CalculatedTableFactory() - private val setOperationFactory = new SetOperationFactory(isStreaming) - private val aggregateOperationFactory = new AggregateOperationFactory(isStreaming) - private val joinOperationFactory = new JoinOperationFactory() - - private val noWindowPropertyChecker = new NoWindowPropertyChecker( - "Window start and end properties are not available for Over windows.") - - override def project( - projectList: JList[Expression], - child: QueryOperation) - : QueryOperation = { - project(projectList, child, explicitAlias = false) - } - - override def project( - projectList: JList[Expression], - child: QueryOperation, - explicitAlias: Boolean) - : QueryOperation = { - projectInternal(projectList, child, explicitAlias, Collections.emptyList()) - } - - override def project( - projectList: JList[Expression], - child: QueryOperation, - overWindows: JList[OverWindow]) - : QueryOperation = { - - Preconditions.checkArgument(!overWindows.isEmpty) - - projectList.asScala.map(_.accept(noWindowPropertyChecker)) - - projectInternal(projectList, - child, - explicitAlias = true, - overWindows) - } - - private def projectInternal( - projectList: JList[Expression], - child: QueryOperation, - explicitAlias: Boolean, - overWindows: JList[OverWindow]) - : QueryOperation = { - - validateProjectList(projectList, overWindows) - - val resolver = resolverFor(tableCatalog, functionCatalog, child) - .withOverWindows(overWindows) - .build - val projections = resolver.resolve(projectList) - projectionOperationFactory.create( - projections, - child, - explicitAlias, - resolver.postResolverFactory()) - } - - /** - * Window properties and aggregate function should not exist in the plain project, i.e., window - * properties should exist in the window operators and aggregate functions should exist in - * the aggregate operators. - */ - private def validateProjectList( - projectList: JList[Expression], - overWindows: JList[OverWindow]) - : Unit = { - - val expressionsWithResolvedCalls = projectList.map(_.accept(lookupResolver)).asJava - val extracted = extractAggregationsAndProperties(expressionsWithResolvedCalls) - if (!extracted.getWindowProperties.isEmpty) { - throw new ValidationException("Window properties can only be used on windowed tables.") - } - - // aggregate functions can't exist in the plain project except for the over window case - if (!extracted.getAggregations.isEmpty && overWindows.isEmpty) { - throw new ValidationException("Aggregate functions are not supported in the select right" + - " after the aggregate or flatAggregate operation.") - } - } - - /** - * Adds additional columns. Existing fields will be replaced if replaceIfExist is true. 
- */ - override def addColumns( - replaceIfExist: Boolean, - fieldLists: JList[Expression], - child: QueryOperation) - : QueryOperation = { - val newColumns = if (replaceIfExist) { - val fieldNames = child.getTableSchema.getFieldNames.toList.asJava - ColumnOperationUtils.addOrReplaceColumns(fieldNames, fieldLists) - } else { - (unresolvedRef("*") +: fieldLists.asScala).asJava - } - project(newColumns, child) - } - - override def renameColumns( - aliases: JList[Expression], - child: QueryOperation) - : QueryOperation = { - - val resolver = resolverFor(tableCatalog, functionCatalog, child) - .build - - val inputFieldNames = child.getTableSchema.getFieldNames.toList.asJava - val validateAliases = ColumnOperationUtils.renameColumns( - inputFieldNames, - resolver.resolveExpanding(aliases)) - - project(validateAliases, child) - } - - override def dropColumns( - fieldList: JList[Expression], - child: QueryOperation) - : QueryOperation = { - - val resolver = resolverFor(tableCatalog, functionCatalog, child) - .build - - val inputFieldNames = child.getTableSchema.getFieldNames.toList.asJava - val finalFields = ColumnOperationUtils.dropFields( - inputFieldNames, - resolver.resolveExpanding(fieldList)) - - project(finalFields, child) - } - - override def aggregate( - groupingExpressions: JList[Expression], - aggregates: JList[Expression], - child: QueryOperation) - : QueryOperation = { - - val resolver = resolverFor(tableCatalog, functionCatalog, child).build - - val resolvedGroupings = resolver.resolve(groupingExpressions) - val resolvedAggregates = resolver.resolve(aggregates) - - aggregateOperationFactory.createAggregate(resolvedGroupings, resolvedAggregates, child) - } - - /** - * Row based aggregate that will flatten the output if it is a composite type. - */ - override def aggregate( - groupingExpressions: JList[Expression], - aggregate: Expression, - child: QueryOperation) - : QueryOperation = { - val resolvedAggregate = aggregate.accept(lookupResolver) - - // extract alias and aggregate function - var alias: Seq[String] = Seq() - val aggWithoutAlias = resolvedAggregate match { - case c: UnresolvedCallExpression - if c.getFunctionDefinition == BuiltInFunctionDefinitions.AS => - alias = c.getChildren - .drop(1) - .map(e => ExpressionUtils.extractValue(e, classOf[String]).get()) - c.getChildren.get(0) - case c: UnresolvedCallExpression - if c.getFunctionDefinition.isInstanceOf[AggregateFunctionDefinition] => - if (alias.isEmpty) alias = UserDefinedFunctionUtils.getFieldInfo( - c.getFunctionDefinition.asInstanceOf[AggregateFunctionDefinition].getResultTypeInfo)._1 - c - case e => e - } - - // turn agg to a named agg, because it will be verified later. 
- var cnt = 0 - val childNames = child.getTableSchema.getFieldNames - while (childNames.contains("TMP_" + cnt)) { - cnt += 1 - } - val aggWithNamedAlias = unresolvedCall( - BuiltInFunctionDefinitions.AS, - aggWithoutAlias, - valueLiteral("TMP_" + cnt)) - - // get agg table - val aggQueryOperation = this.aggregate(groupingExpressions, Seq(aggWithNamedAlias), child) - - // flatten the aggregate function - val aggNames = aggQueryOperation.getTableSchema.getFieldNames - val flattenExpressions = aggNames.take(groupingExpressions.size()) - .map(e => unresolvedRef(e)) ++ - Seq(unresolvedCall(BuiltInFunctionDefinitions.FLATTEN, unresolvedRef(aggNames.last))) - val flattenedOperation = this.project(flattenExpressions.toList, aggQueryOperation) - - // add alias - aliasBackwardFields(flattenedOperation, alias, groupingExpressions.size()) - } - - override def tableAggregate( - groupingExpressions: JList[Expression], - tableAggFunction: Expression, - child: QueryOperation) - : QueryOperation = { - - // Step1: add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to - // groupBy(a % 5 as TMP_0). We need a name for every column so that to perform alias for the - // table aggregate function in Step4. - val newGroupingExpressions = addAliasToTheCallInGroupings( - child.getTableSchema.getFieldNames, - groupingExpressions) - - // Step2: resolve expressions - val resolver = resolverFor(tableCatalog, functionCatalog, child).build - val resolvedGroupings = resolver.resolve(newGroupingExpressions) - val resolvedFunctionAndAlias = aggregateOperationFactory.extractTableAggFunctionAndAliases( - resolveSingleExpression(tableAggFunction, resolver)) - - // Step3: create table agg operation - val tableAggOperation = aggregateOperationFactory - .createAggregate(resolvedGroupings, Seq(resolvedFunctionAndAlias.f0), child) - - // Step4: add a top project to alias the output fields of the table aggregate. - aliasBackwardFields(tableAggOperation, resolvedFunctionAndAlias.f1, groupingExpressions.size()) - } - - override def windowAggregate( - groupingExpressions: JList[Expression], - window: GroupWindow, - windowProperties: JList[Expression], - aggregates: JList[Expression], - child: QueryOperation) - : QueryOperation = { - - val resolver = resolverFor(tableCatalog, functionCatalog, child).build() - val resolvedWindow = aggregateOperationFactory.createResolvedWindow(window, resolver) - - val resolverWithWindowReferences = resolverFor(tableCatalog, functionCatalog, child) - .withLocalReferences( - new LocalReferenceExpression( - resolvedWindow.getAlias, - resolvedWindow.getTimeAttribute.getOutputDataType)) - .build - - val convertedGroupings = resolverWithWindowReferences.resolve(groupingExpressions) - - val convertedAggregates = resolverWithWindowReferences.resolve(aggregates) - - val convertedProperties = resolverWithWindowReferences.resolve(windowProperties) - - aggregateOperationFactory.createWindowAggregate( - convertedGroupings, - convertedAggregates, - convertedProperties, - resolvedWindow, - child) - } - - override def windowTableAggregate( - groupingExpressions: JList[Expression], - window: GroupWindow, - windowProperties: JList[Expression], - tableAggFunction: Expression, - child: QueryOperation) - : QueryOperation = { - - // Step1: add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to - // groupBy(a % 5 as TMP_0). We need a name for every column so that to perform alias for the - // table aggregate function in Step4. 
-  override def windowTableAggregate(
-      groupingExpressions: JList[Expression],
-      window: GroupWindow,
-      windowProperties: JList[Expression],
-      tableAggFunction: Expression,
-      child: QueryOperation)
-    : QueryOperation = {
-
-    // Step1: add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to
-    // groupBy(a % 5 as TMP_0). We need a name for every column so that to perform alias for the
-    // table aggregate function in Step4.
-    val newGroupingExpressions = addAliasToTheCallInGroupings(
-      child.getTableSchema.getFieldNames,
-      groupingExpressions)
-
-    // Step2: resolve expressions, including grouping, aggregates and window properties.
-    val resolver = resolverFor(tableCatalog, functionCatalog, child).build()
-    val resolvedWindow = aggregateOperationFactory.createResolvedWindow(window, resolver)
-
-    val resolverWithWindowReferences = resolverFor(tableCatalog, functionCatalog, child)
-      .withLocalReferences(
-        new LocalReferenceExpression(
-          resolvedWindow.getAlias,
-          resolvedWindow.getTimeAttribute.getOutputDataType))
-      .build
-
-    val convertedGroupings = resolverWithWindowReferences.resolve(newGroupingExpressions)
-    val convertedAggregates = resolverWithWindowReferences.resolve(Seq(tableAggFunction))
-    val convertedProperties = resolverWithWindowReferences.resolve(windowProperties)
-    val resolvedFunctionAndAlias = aggregateOperationFactory.extractTableAggFunctionAndAliases(
-      convertedAggregates.get(0))
-
-    // Step3: create window table agg operation
-    val tableAggOperation = aggregateOperationFactory.createWindowAggregate(
-      convertedGroupings,
-      Seq(resolvedFunctionAndAlias.f0),
-      convertedProperties,
-      resolvedWindow,
-      child)
-
-    // Step4: add a top project to alias the output fields of the table aggregate. Also, project the
-    // window attribute.
-    aliasBackwardFields(tableAggOperation, resolvedFunctionAndAlias.f1, groupingExpressions.size())
-  }
-
-  override def join(
-      left: QueryOperation,
-      right: QueryOperation,
-      joinType: JoinType,
-      condition: Optional[Expression],
-      correlated: Boolean)
-    : QueryOperation = {
-    val resolver = resolverFor(tableCatalog, functionCatalog, left, right).build()
-    val resolvedCondition = toScala(condition).map(expr => resolveSingleExpression(expr, resolver))
-
-    joinOperationFactory
-      .create(left, right, joinType, resolvedCondition.getOrElse(valueLiteral(true)), correlated)
-  }
-
-  override def joinLateral(
-      left: QueryOperation,
-      tableFunction: Expression,
-      joinType: JoinType,
-      condition: Optional[Expression])
-    : QueryOperation = {
-    val resolver = resolverFor(tableCatalog, functionCatalog, left).build()
-    val resolvedFunction = resolveSingleExpression(tableFunction, resolver)
-
-    val temporalTable =
-      calculatedTableFactory.create(resolvedFunction, left.getTableSchema.getFieldNames)
-
-    join(left, temporalTable, joinType, condition, correlated = true)
-  }
-
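join() above treats an absent condition as a literal TRUE before handing off to the join operation factory. The same Optional-to-default pattern in isolation, with Boolean standing in for Flink's ResolvedExpression; JoinConditionSketch and effectiveCondition are illustrative names, and toScala mirrors the conversion helper the builder uses.

    import java.util.Optional

    object JoinConditionSketch {
      def toScala[T](opt: Optional[T]): Option[T] =
        if (opt.isPresent) Some(opt.get()) else None

      // A missing condition defaults to TRUE, i.e. an unconditioned join.
      def effectiveCondition(condition: Optional[Boolean]): Boolean =
        toScala(condition).getOrElse(true)

      def main(args: Array[String]): Unit = {
        println(effectiveCondition(Optional.empty[Boolean]())) // true
        println(effectiveCondition(Optional.of(false)))        // false
      }
    }
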
-  override def resolveExpression(expression: Expression, queryOperation: QueryOperation*)
-    : Expression = {
-    val resolver = resolverFor(tableCatalog, functionCatalog, queryOperation: _*).build()
-
-    resolveSingleExpression(expression, resolver)
-  }
-
-  private def resolveSingleExpression(
-      expression: Expression,
-      resolver: ExpressionResolver)
-    : ResolvedExpression = {
-    val resolvedExpression = resolver.resolve(List(expression).asJava)
-    if (resolvedExpression.size() != 1) {
-      throw new ValidationException("Expected single expression")
-    } else {
-      resolvedExpression.get(0)
-    }
-  }
-
-  override def sort(
-      fields: JList[Expression],
-      child: QueryOperation)
-    : QueryOperation = {
-
-    val resolver = resolverFor(tableCatalog, functionCatalog, child).build()
-    val resolvedFields = resolver.resolve(fields)
-
-    sortOperationFactory.createSort(resolvedFields, child, resolver.postResolverFactory())
-  }
-
-  override def limitWithOffset(offset: Int, child: QueryOperation): QueryOperation = {
-    sortOperationFactory.createLimitWithOffset(offset, child)
-  }
-
-  override def limitWithFetch(fetch: Int, child: QueryOperation): QueryOperation = {
-    sortOperationFactory.createLimitWithFetch(fetch, child)
-  }
-
-  override def alias(
-      fields: JList[Expression],
-      child: QueryOperation)
-    : QueryOperation = {
-
-    val newFields = createAliasList(fields, child)
-
-    project(newFields, child, explicitAlias = true)
-  }
-
-  override def filter(
-      condition: Expression,
-      child: QueryOperation)
-    : QueryOperation = {
-
-    val resolver = resolverFor(tableCatalog, functionCatalog, child).build()
-    val resolvedExpression = resolveSingleExpression(condition, resolver)
-    val conditionType = resolvedExpression.getOutputDataType
-    if (!LogicalTypeChecks.hasRoot(conditionType.getLogicalType, LogicalTypeRoot.BOOLEAN)) {
-      throw new ValidationException(s"Filter operator requires a boolean expression as input," +
-        s" but $condition is of type ${conditionType}")
-    }
-
-    new FilterQueryOperation(resolvedExpression, child)
-  }
-
-  override def distinct(
-      child: QueryOperation)
-    : QueryOperation = {
-    new DistinctQueryOperation(child)
-  }
-
-  override def minus(
-      left: QueryOperation,
-      right: QueryOperation,
-      all: Boolean)
-    : QueryOperation = {
-    setOperationFactory.create(MINUS, left, right, all)
-  }
-
-  override def intersect(
-      left: QueryOperation,
-      right: QueryOperation,
-      all: Boolean)
-    : QueryOperation = {
-    setOperationFactory.create(INTERSECT, left, right, all)
-  }
-
-  override def union(
-      left: QueryOperation,
-      right: QueryOperation,
-      all: Boolean)
-    : QueryOperation = {
-    setOperationFactory.create(UNION, left, right, all)
-  }
-
-  override def map(mapFunction: Expression, child: QueryOperation): QueryOperation = {
-
-    val resolvedMapFunction = mapFunction.accept(lookupResolver)
-
-    if (!isFunctionOfKind(resolvedMapFunction, SCALAR)) {
-      throw new ValidationException("Only a scalar function can be used in the map operator.")
-    }
-
-    val expandedFields = unresolvedCall(BuiltInFunctionDefinitions.FLATTEN, resolvedMapFunction)
-    project(Collections.singletonList(expandedFields), child)
-  }
-
-  override def flatMap(tableFunction: Expression, child: QueryOperation): QueryOperation = {
-
-    val resolvedTableFunction = tableFunction.accept(lookupResolver)
-
-    if (!isFunctionOfKind(resolvedTableFunction, TABLE)) {
-      throw new ValidationException("Only a table function can be used in the flatMap operator.")
-    }
-
-    val originFieldNames: Seq[String] =
-      resolvedTableFunction.asInstanceOf[UnresolvedCallExpression].getFunctionDefinition match {
-        case tfd: TableFunctionDefinition =>
-          UserDefinedFunctionUtils.getFieldInfo(tfd.getResultType)._1
-      }
-
-    val usedFieldNames = child.getTableSchema.getFieldNames.toBuffer
-    val newFieldNames = originFieldNames.map({ e =>
-      val resultName = getUniqueName(e, usedFieldNames)
-      usedFieldNames.append(resultName)
-      resultName
-    })
-
-    val renamedTableFunction = unresolvedCall(
-      BuiltInFunctionDefinitions.AS,
-      resolvedTableFunction +: newFieldNames.map(ApiExpressionUtils.valueLiteral(_)): _*)
-    val joinNode = joinLateral(child, renamedTableFunction, JoinType.INNER, Optional.empty())
-    val rightNode = dropColumns(
-      child.getTableSchema.getFieldNames.map(ApiExpressionUtils.unresolvedRef).toList,
-      joinNode)
-    alias(originFieldNames.map(a => unresolvedRef(a)), rightNode)
-  }
-
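flatMap above avoids clashes between the table function's output fields and the input fields by renaming the output before the lateral join, then dropping the input fields and aliasing the result back to the original names. A self-contained sketch of the renaming pass with plain strings; FlatMapRenameSketch is an illustrative name, and uniqueName mirrors the getUniqueName helper that follows.

    object FlatMapRenameSketch {
      // Appends "_<i>" until the candidate no longer collides.
      def uniqueName(name: String, used: Seq[String]): String = {
        var i = 0
        var result = name
        while (used.contains(result)) {
          result = result + "_" + i
          i += 1
        }
        result
      }

      def main(args: Array[String]): Unit = {
        val inputFields = Seq("a", "b")
        val functionFields = Seq("a", "c") // "a" collides with the input
        val used = inputFields.toBuffer
        val renamed = functionFields.map { f =>
          val r = uniqueName(f, used)
          used.append(r)
          r
        }
        println(renamed) // List(a_0, c)
      }
    }
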
-  /**
-    * Return a unique name that does not exist in usedFieldNames according to the input name.
-    */
-  private def getUniqueName(inputName: String, usedFieldNames: Seq[String]): String = {
-    var i = 0
-    var resultName = inputName
-    while (usedFieldNames.contains(resultName)) {
-      resultName = resultName + "_" + i
-      i += 1
-    }
-    resultName
-  }
-
-  /**
-    * Add a default name to the call in the grouping expressions, e.g., groupBy(a % 5) to
-    * groupBy(a % 5 as TMP_0).
-    */
-  private def addAliasToTheCallInGroupings(
-      inputFieldNames: Seq[String],
-      groupingExpressions: JList[Expression])
-    : JList[Expression] = {
-
-    var attrNameCntr: Int = 0
-    val usedFieldNames = inputFieldNames.toBuffer
-    groupingExpressions.map {
-      case c: UnresolvedCallExpression
-        if c.getFunctionDefinition != BuiltInFunctionDefinitions.AS =>
-        val tempName = getUniqueName("TMP_" + attrNameCntr, usedFieldNames)
-        usedFieldNames.append(tempName)
-        attrNameCntr += 1
-        unresolvedCall(
-          BuiltInFunctionDefinitions.AS,
-          c,
-          valueLiteral(tempName)
-        )
-      case e => e
-    }
-  }
-
-  /**
-    * Rename fields in the input [[QueryOperation]].
-    */
-  private def aliasBackwardFields(
-      inputOperation: QueryOperation,
-      alias: Seq[String],
-      aliasStartIndex: Int)
-    : QueryOperation = {
-
-    if (alias.nonEmpty) {
-      val namesBeforeAlias = inputOperation.getTableSchema.getFieldNames
-      val namesAfterAlias = namesBeforeAlias.take(aliasStartIndex) ++ alias ++
-        namesBeforeAlias.takeRight(namesBeforeAlias.length - alias.size - aliasStartIndex)
-      this.alias(namesAfterAlias.map(e => unresolvedRef(e)).toList, inputOperation)
-    } else {
-      inputOperation
-    }
-  }
-
-  class NoWindowPropertyChecker(val exceptionMessage: String)
-    extends ApiExpressionDefaultVisitor[Void] {
-    override def visit(unresolvedCall: UnresolvedCallExpression): Void = {
-      val functionDefinition = unresolvedCall.getFunctionDefinition
-      if (BuiltInFunctionDefinitions.WINDOW_PROPERTIES
-        .contains(functionDefinition)) {
-        throw new ValidationException(exceptionMessage)
-      }
-      unresolvedCall.getChildren.asScala.foreach(expr => expr.accept(this))
-      null
-    }
-
-    override protected def defaultMethod(expression: Expression): Void = null
-  }
-}
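With the planner-side implementation removed above, the remaining hunks only adapt the planner tests to the ported builder. One helper worth keeping in mind while reading them is aliasBackwardFields, whose splicing rule is easy to state in isolation; AliasBackwardSketch and aliasFields are illustrative names, not Flink API.

    object AliasBackwardSketch {
      // Splices the user aliases into the schema right after the grouping
      // columns, leaving any trailing fields (e.g. a window attribute) as is.
      def aliasFields(
          fieldNames: Seq[String],
          alias: Seq[String],
          aliasStartIndex: Int): Seq[String] = {
        if (alias.nonEmpty) {
          fieldNames.take(aliasStartIndex) ++ alias ++
            fieldNames.takeRight(fieldNames.length - alias.size - aliasStartIndex)
        } else {
          fieldNames
        }
      }

      def main(args: Array[String]): Unit = {
        // one grouping column "g"; aggregate output renamed to ("x", "y")
        println(aliasFields(Seq("g", "TMP_0", "TMP_1"), Seq("x", "y"), 1))
        // List(g, x, y)
      }
    }
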
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/StreamTableEnvironmentTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/StreamTableEnvironmentTest.scala
index 44dc4ef..20e8ba6 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/StreamTableEnvironmentTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/StreamTableEnvironmentTest.scala
@@ -215,7 +215,8 @@ class StreamTableEnvironmentTest extends TableTestBase {
       config,
       jStreamExecEnv,
       streamPlanner,
-      executor)
+      executor,
+      true)
 
     val sType = new TupleTypeInfo(Types.LONG, Types.INT, Types.STRING, Types.INT, Types.LONG)
       .asInstanceOf[TupleTypeInfo[JTuple5[JLong, JInt, String, JInt, JLong]]]
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala
index d435809..b56919d 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala
@@ -73,7 +73,8 @@ class AggregateTest extends TableTestBase {
       new TableConfig,
       Mockito.mock(classOf[StreamExecutionEnvironment]),
       Mockito.mock(classOf[Planner]),
-      Mockito.mock(classOf[Executor])
+      Mockito.mock(classOf[Executor]),
+      true
     )
 
     tablEnv.registerFunction("udag", new MyAgg)
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
index e72225d..e85e7e4 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/TableTestBase.scala
@@ -338,7 +338,8 @@ case class StreamTableTestUtil(
     tableConfig,
     javaEnv,
     streamPlanner,
-    executor)
+    executor,
+    true)
 
   val env = new StreamExecutionEnvironment(javaEnv)
   val tableEnv = new ScalaStreamTableEnvironmentImpl(
@@ -347,7 +348,8 @@ case class StreamTableTestUtil(
     tableConfig,
     env,
     streamPlanner,
-    executor)
+    executor,
+    true)
 
   def addTable[T: TypeInformation](
       name: String,
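Each test hunk above appends a bare `true` for the new trailing boolean constructor argument, which these streaming tests enable. A small sketch of how a named argument keeps such boolean flags readable at call sites; IsStreamingFlagExample and FakeTableEnvironment are stand-ins, not the real environment classes.

    object IsStreamingFlagExample {
      // Stand-in for an environment whose behavior toggles on one flag.
      final class FakeTableEnvironment(val isStreaming: Boolean) {
        def mode: String = if (isStreaming) "streaming" else "batch"
      }

      def main(args: Array[String]): Unit = {
        // a named argument documents what the bare `true` in the hunks means
        val env = new FakeTableEnvironment(isStreaming = true)
        println(env.mode) // streaming
      }
    }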