This is an automated email from the ASF dual-hosted git repository.
libenchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 42b7e74ab20 [FLINK-33263][table-planner] Implement ParallelismProvider for sources in the table planner
42b7e74ab20 is described below
commit 42b7e74ab20785289b62f5dd68d566995ba9dcfc
Author: SuDewei <[email protected]>
AuthorDate: Thu Jan 18 16:05:40 2024 +0800
[FLINK-33263][table-planner] Implement ParallelismProvider for sources in the table planner
Close apache/flink#24128
---
.../org/apache/flink/api/dag/Transformation.java | 15 ++
.../streaming/api/graph/StreamGraphGenerator.java | 3 +
.../SourceTransformationWrapper.java | 72 ++++++++++
.../exec/common/CommonExecTableSourceScan.java | 154 ++++++++++++++++++---
.../table/planner/delegation/BatchPlanner.scala | 2 +-
.../table/planner/delegation/PlannerBase.scala | 3 +-
.../table/planner/delegation/StreamPlanner.scala | 2 +-
.../planner/factories/TestValuesTableFactory.java | 33 +++--
.../planner/plan/stream/sql/TableScanTest.xml | 42 ++++++
.../planner/plan/stream/sql/TableScanTest.scala | 38 +++++
.../runtime/stream/sql/TableSourceITCase.scala | 80 +++++++++++
.../flink/table/planner/utils/TableTestBase.scala | 51 ++++++-
12 files changed, 463 insertions(+), 32 deletions(-)
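For readers skimming the patch, here is a minimal, illustrative sketch (not part of this commit; the class and method names are hypothetical) of the connector-side contract this planner change consumes: a ScanRuntimeProvider that also implements ParallelismProvider, such as the one built by SourceProvider.of(source, parallelism), reports a fixed source parallelism, which CommonExecTableSourceScan below validates and applies to the source transformation.

    // Illustrative sketch only -- class and method names are made up for this email.
    import org.apache.flink.api.connector.source.Source;
    import org.apache.flink.table.connector.source.ScanTableSource;
    import org.apache.flink.table.connector.source.SourceProvider;
    import org.apache.flink.table.data.RowData;

    public final class FixedParallelismProviderSketch {

        /** Wraps an existing FLIP-27 source into a provider that reports the given parallelism. */
        public static ScanTableSource.ScanRuntimeProvider withParallelism(
                Source<RowData, ?, ?> source, int parallelism) {
            // SourceProvider.of(source, parallelism) also implements ParallelismProvider, so
            // getParallelism() returns Optional.of(parallelism); the planner rejects values <= 0.
            return SourceProvider.of(source, parallelism);
        }
    }

In SQL, the same value is configured through the 'scan.parallelism' option (FactoryUtil.SOURCE_PARALLELISM), as exercised by the tests further down in this patch.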
diff --git a/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java b/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java
index a0448697dd1..6256f9624f6 100644
--- a/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java
+++ b/flink-core/src/main/java/org/apache/flink/api/dag/Transformation.java
@@ -19,6 +19,7 @@
package org.apache.flink.api.dag;
import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.InvalidTypesException;
import org.apache.flink.api.common.operators.ResourceSpec;
@@ -602,6 +603,20 @@ public abstract class Transformation<T> {
+ '}';
}
+ @VisibleForTesting
+ public String toStringWithoutId() {
+ return getClass().getSimpleName()
+ + "{"
+ + "name='"
+ + name
+ + '\''
+ + ", outputType="
+ + outputType
+ + ", parallelism="
+ + parallelism
+ + '}';
+ }
+
@Override
public boolean equals(Object o) {
if (this == o) {
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
index 5929a2a5e8e..8e267ff84d6 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphGenerator.java
@@ -65,6 +65,7 @@ import org.apache.flink.streaming.api.transformations.ReduceTransformation;
import org.apache.flink.streaming.api.transformations.SideOutputTransformation;
import org.apache.flink.streaming.api.transformations.SinkTransformation;
import org.apache.flink.streaming.api.transformations.SourceTransformation;
+import org.apache.flink.streaming.api.transformations.SourceTransformationWrapper;
import org.apache.flink.streaming.api.transformations.TimestampsAndWatermarksTransformation;
import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
import org.apache.flink.streaming.api.transformations.UnionTransformation;
@@ -553,6 +554,8 @@ public class StreamGraphGenerator {
transformedIds = transformFeedback((FeedbackTransformation<?>) transform);
} else if (transform instanceof CoFeedbackTransformation<?>) {
transformedIds = transformCoFeedback((CoFeedbackTransformation<?>) transform);
+ } else if (transform instanceof SourceTransformationWrapper<?>) {
+ transformedIds = transform(((SourceTransformationWrapper<?>) transform).getInput());
} else {
throw new IllegalStateException("Unknown transformation: " + transform);
}
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SourceTransformationWrapper.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SourceTransformationWrapper.java
new file mode 100644
index 00000000000..d536000fde2
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/transformations/SourceTransformationWrapper.java
@@ -0,0 +1,72 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package org.apache.flink.streaming.api.transformations;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.dag.Transformation;
+import org.apache.flink.streaming.api.graph.TransformationTranslator;
+
+import org.apache.flink.shaded.guava31.com.google.common.collect.Lists;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * This Transformation is a phantom transformation which is used to expose a default parallelism to
+ * downstream.
+ *
+ * <p>It is used only when the parallelism of the source transformation differs from the default
+ * parallelism, ensuring that the parallelism of downstream operations is not affected.
+ *
+ * <p>Moreover, this transformation does not have a corresponding {@link TransformationTranslator},
+ * meaning it will not become a node in the StreamGraph.
+ *
+ * @param <T> The type of the elements in the input {@code Transformation}
+ */
+@Internal
+public class SourceTransformationWrapper<T> extends Transformation<T> {
+
+ private final Transformation<T> input;
+
+ public SourceTransformationWrapper(Transformation<T> input) {
+ super(
+ "ChangeToDefaultParallel",
+ input.getOutputType(),
+ ExecutionConfig.PARALLELISM_DEFAULT);
+ this.input = input;
+ }
+
+ public Transformation<T> getInput() {
+ return input;
+ }
+
+ @Override
+ public List<Transformation<?>> getTransitivePredecessors() {
+ List<Transformation<?>> result = Lists.newArrayList();
+ result.add(this);
+ result.addAll(input.getTransitivePredecessors());
+ return result;
+ }
+
+ @Override
+ public List<Transformation<?>> getInputs() {
+ return Collections.singletonList(input);
+ }
+}
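As a side note, not part of the patch itself: a small runnable sketch of what the wrapper above does. The NumberSequenceSource and the standalone main method are only for illustration; the point is that the wrapper reports the default parallelism (-1) to whatever is chained after it, while the wrapped source keeps its explicitly configured value.

    // Illustrative sketch only; not part of this commit.
    import org.apache.flink.api.common.ExecutionConfig;
    import org.apache.flink.api.common.eventtime.WatermarkStrategy;
    import org.apache.flink.api.connector.source.lib.NumberSequenceSource;
    import org.apache.flink.api.dag.Transformation;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
    import org.apache.flink.streaming.api.transformations.SourceTransformationWrapper;

    public final class WrapperParallelismSketch {
        public static void main(String[] args) {
            StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
            Transformation<Long> source =
                    env.fromSource(
                                    new NumberSequenceSource(0, 9),
                                    WatermarkStrategy.noWatermarks(),
                                    "numbers")
                            .getTransformation();
            source.setParallelism(3);

            SourceTransformationWrapper<Long> wrapper = new SourceTransformationWrapper<>(source);
            // The wrapper exposes PARALLELISM_DEFAULT (-1), so downstream operators do not inherit 3.
            System.out.println(wrapper.getParallelism() == ExecutionConfig.PARALLELISM_DEFAULT); // true
            // The actual source transformation keeps its configured parallelism.
            System.out.println(wrapper.getInput().getParallelism()); // 3
        }
    }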
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecTableSourceScan.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecTableSourceScan.java
index dc69543cd28..be5b46ba973 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecTableSourceScan.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecTableSourceScan.java
@@ -18,6 +18,7 @@
package org.apache.flink.table.planner.plan.nodes.exec.common;
+import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.api.common.io.InputFormat;
import org.apache.flink.api.common.typeinfo.TypeInformation;
@@ -30,6 +31,13 @@ import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.operators.StreamSource;
import org.apache.flink.streaming.api.transformations.LegacySourceTransformation;
+import org.apache.flink.streaming.api.transformations.PartitionTransformation;
+import org.apache.flink.streaming.api.transformations.SourceTransformationWrapper;
+import org.apache.flink.streaming.runtime.partitioner.KeyGroupStreamPartitioner;
+import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.connector.ChangelogMode;
+import org.apache.flink.table.connector.ParallelismProvider;
import org.apache.flink.table.connector.ProviderContext;
import org.apache.flink.table.connector.source.DataStreamScanProvider;
import org.apache.flink.table.connector.source.InputFormatProvider;
@@ -48,17 +56,22 @@ import org.apache.flink.table.planner.plan.nodes.exec.MultipleTransformationTran
import org.apache.flink.table.planner.plan.nodes.exec.spec.DynamicTableSourceSpec;
import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecNode;
import org.apache.flink.table.planner.plan.nodes.exec.utils.TransformationMetadata;
+import org.apache.flink.table.planner.plan.utils.KeySelectorUtil;
import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.flink.table.runtime.connector.source.ScanRuntimeProviderContext;
+import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.types.RowKind;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;
import java.util.Optional;
+import static org.apache.flink.runtime.state.KeyGroupRangeAssignment.DEFAULT_LOWER_BOUND_MAX_PARALLELISM;
+
/**
* Base {@link ExecNode} to read data from an external source defined by a {@link ScanTableSource}.
*/
@@ -96,6 +109,7 @@ public abstract class CommonExecTableSourceScan extends ExecNodeBase<RowData>
@Override
protected Transformation<RowData> translateToPlanInternal(
PlannerBase planner, ExecNodeConfig config) {
+ final Transformation<RowData> sourceTransform;
final StreamExecutionEnvironment env = planner.getExecEnv();
final TransformationMetadata meta =
createTransformationMeta(SOURCE_TRANSFORMATION, config);
final InternalTypeInfo<RowData> outputTypeInfo =
@@ -105,54 +119,149 @@ public abstract class CommonExecTableSourceScan extends ExecNodeBase<RowData>
planner.getFlinkContext(),
ShortcutUtils.unwrapTypeFactory(planner));
ScanTableSource.ScanRuntimeProvider provider =
tableSource.getScanRuntimeProvider(ScanRuntimeProviderContext.INSTANCE);
+ final int sourceParallelism = deriveSourceParallelism(provider);
+ final boolean sourceParallelismConfigured = isParallelismConfigured(provider);
if (provider instanceof SourceFunctionProvider) {
final SourceFunctionProvider sourceFunctionProvider =
(SourceFunctionProvider) provider;
final SourceFunction<RowData> function =
sourceFunctionProvider.createSourceFunction();
- final Transformation<RowData> transformation =
+ sourceTransform =
createSourceFunctionTransformation(
env,
function,
sourceFunctionProvider.isBounded(),
meta.getName(),
- outputTypeInfo);
- return meta.fill(transformation);
+ outputTypeInfo,
+ sourceParallelism,
+ sourceParallelismConfigured);
+ if (function instanceof ParallelSourceFunction && sourceParallelismConfigured) {
+ meta.fill(sourceTransform);
+ return new SourceTransformationWrapper<>(sourceTransform);
+ } else {
+ return meta.fill(sourceTransform);
+ }
} else if (provider instanceof InputFormatProvider) {
final InputFormat<RowData, ?> inputFormat =
((InputFormatProvider) provider).createInputFormat();
- final Transformation<RowData> transformation =
+ sourceTransform =
createInputFormatTransformation(
env, inputFormat, outputTypeInfo, meta.getName());
- return meta.fill(transformation);
+ meta.fill(sourceTransform);
} else if (provider instanceof SourceProvider) {
final Source<RowData, ?, ?> source = ((SourceProvider) provider).createSource();
// TODO: Push down watermark strategy to source scan
- final Transformation<RowData> transformation =
+ sourceTransform =
env.fromSource(
source,
WatermarkStrategy.noWatermarks(),
meta.getName(),
outputTypeInfo)
.getTransformation();
- return meta.fill(transformation);
+ meta.fill(sourceTransform);
} else if (provider instanceof DataStreamScanProvider) {
- Transformation<RowData> transformation =
+ sourceTransform =
((DataStreamScanProvider) provider)
.produceDataStream(createProviderContext(config), env)
.getTransformation();
- meta.fill(transformation);
- transformation.setOutputType(outputTypeInfo);
- return transformation;
+ meta.fill(sourceTransform);
+ sourceTransform.setOutputType(outputTypeInfo);
} else if (provider instanceof TransformationScanProvider) {
- final Transformation<RowData> transformation =
+ sourceTransform =
((TransformationScanProvider) provider)
.createTransformation(createProviderContext(config));
- meta.fill(transformation);
- transformation.setOutputType(outputTypeInfo);
- return transformation;
+ meta.fill(sourceTransform);
+ sourceTransform.setOutputType(outputTypeInfo);
} else {
throw new UnsupportedOperationException(
provider.getClass().getSimpleName() + " is unsupported
now.");
}
+
+ if (sourceParallelismConfigured) {
+ return applySourceTransformationWrapper(
+ sourceTransform,
+ planner.getFlinkContext().getClassLoader(),
+ outputTypeInfo,
+ config,
+ tableSource.getChangelogMode(),
+ sourceParallelism);
+ } else {
+ return sourceTransform;
+ }
+ }
+
+ private boolean isParallelismConfigured(ScanTableSource.ScanRuntimeProvider runtimeProvider) {
+ return runtimeProvider instanceof ParallelismProvider
+ && ((ParallelismProvider) runtimeProvider).getParallelism().isPresent();
+ }
+
+ private int deriveSourceParallelism(ScanTableSource.ScanRuntimeProvider runtimeProvider) {
+ if (isParallelismConfigured(runtimeProvider)) {
+ int sourceParallelism = ((ParallelismProvider) runtimeProvider).getParallelism().get();
+ if (sourceParallelism <= 0) {
+ throw new TableException(
+ String.format(
+ "Invalid configured parallelism %s for table
'%s'.",
+ sourceParallelism,
+ tableSourceSpec
+ .getContextResolvedTable()
+ .getIdentifier()
+ .asSummaryString()));
+ }
+ return sourceParallelism;
+ } else {
+ return ExecutionConfig.PARALLELISM_DEFAULT;
+ }
+ }
+
+ protected RowType getPhysicalRowType(ResolvedSchema schema) {
+ return (RowType) schema.toPhysicalRowDataType().getLogicalType();
+ }
+
+ protected int[] getPrimaryKeyIndices(RowType sourceRowType, ResolvedSchema schema) {
+ return schema.getPrimaryKey()
+ .map(k -> k.getColumns().stream().mapToInt(sourceRowType::getFieldIndex).toArray())
+ .orElse(new int[0]);
+ }
+
+ private Transformation<RowData> applySourceTransformationWrapper(
+ Transformation<RowData> sourceTransform,
+ ClassLoader classLoader,
+ InternalTypeInfo<RowData> outputTypeInfo,
+ ExecNodeConfig config,
+ ChangelogMode changelogMode,
+ int sourceParallelism) {
+ sourceTransform.setParallelism(sourceParallelism, true);
+ Transformation<RowData> sourceTransformationWrapper =
+ new SourceTransformationWrapper<>(sourceTransform);
+
+ if (!changelogMode.containsOnly(RowKind.INSERT)) {
+ final ResolvedSchema schema =
+ tableSourceSpec.getContextResolvedTable().getResolvedSchema();
+ final RowType physicalRowType = getPhysicalRowType(schema);
+ final int[] primaryKeys = getPrimaryKeyIndices(physicalRowType, schema);
+ final boolean hasPk = primaryKeys.length > 0;
+ if (!hasPk) {
+ throw new TableException(
+ String.format(
+ "Configured parallelism %s for upsert table
'%s' while can not find primary key field. "
+ + "This is a bug, please file an
issue.",
+ sourceParallelism,
+ tableSourceSpec
+ .getContextResolvedTable()
+ .getIdentifier()
+ .asSummaryString()));
+ }
+ final RowDataKeySelector selector =
+ KeySelectorUtil.getRowDataSelector(classLoader, primaryKeys, outputTypeInfo);
+ final KeyGroupStreamPartitioner<RowData, RowData> partitioner =
+ new KeyGroupStreamPartitioner<>(selector, DEFAULT_LOWER_BOUND_MAX_PARALLELISM);
+ Transformation<RowData> partitionedTransform =
+ new PartitionTransformation<>(sourceTransformationWrapper, partitioner);
+ createTransformationMeta("partitioner", "Partitioner", "Partitioner", config)
+ .fill(partitionedTransform);
+ return partitionedTransform;
+ } else {
+ return sourceTransformationWrapper;
+ }
}
private ProviderContext createProviderContext(ExecNodeConfig config) {
@@ -178,17 +287,22 @@ public abstract class CommonExecTableSourceScan extends ExecNodeBase<RowData>
SourceFunction<RowData> function,
boolean isBounded,
String operatorName,
- TypeInformation<RowData> outputTypeInfo) {
+ TypeInformation<RowData> outputTypeInfo,
+ int sourceParallelism,
+ boolean sourceParallelismConfigured) {
env.clean(function);
final int parallelism;
- boolean parallelismConfigured = false;
if (function instanceof ParallelSourceFunction) {
- parallelism = env.getParallelism();
+ if (sourceParallelismConfigured) {
+ parallelism = sourceParallelism;
+ } else {
+ parallelism = env.getParallelism();
+ }
} else {
parallelism = 1;
- parallelismConfigured = true;
+ sourceParallelismConfigured = true;
}
final Boundedness boundedness;
@@ -205,7 +319,7 @@ public abstract class CommonExecTableSourceScan extends ExecNodeBase<RowData>
outputTypeInfo,
parallelism,
boundedness,
- parallelismConfigured);
+ sourceParallelismConfigured);
}
/**
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala
index bb4c1b75a28..cea10f7bb81 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/BatchPlanner.scala
@@ -84,7 +84,7 @@ class BatchPlanner(
processors
}
- override protected def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]] = {
+ override def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]] = {
beforeTranslation()
val planner = createDummyPlanner()
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala
index 45788e6278e..b36edaa21d7 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/PlannerBase.scala
@@ -367,7 +367,8 @@ abstract class PlannerBase(
* @return
* The [[Transformation]] DAG that corresponds to the node DAG.
*/
- protected def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]]
+ @VisibleForTesting
+ def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]]
def addExtraTransformation(transformation: Transformation[_]): Unit = {
if (!extraTransformations.contains(transformation)) {
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/StreamPlanner.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/StreamPlanner.scala
index fb32326f117..894a37c8cf9 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/StreamPlanner.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/delegation/StreamPlanner.scala
@@ -78,7 +78,7 @@ class StreamPlanner(
override protected def getExecNodeGraphProcessors: Seq[ExecNodeGraphProcessor] = Seq()
- override protected def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]] = {
+ override def translateToPlan(execGraph: ExecNodeGraph): util.List[Transformation[_]] = {
beforeTranslation()
val planner = createDummyPlanner()
val transformations = execGraph.getRootNodes.map {
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java
index 3dbf4d5b9c0..db64847b75b 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java
@@ -466,6 +466,8 @@ public final class TestValuesTableFactory
private static final ConfigOption<String> SINK_CHANGELOG_MODE_ENFORCED =
ConfigOptions.key("sink-changelog-mode-enforced").stringType().noDefaultValue();
+ private static final ConfigOption<Integer> SOURCE_PARALLELISM = FactoryUtil.SOURCE_PARALLELISM;
+
private static final ConfigOption<Integer> SINK_PARALLELISM = FactoryUtil.SINK_PARALLELISM;
@Override
@@ -497,6 +499,7 @@ public final class TestValuesTableFactory
int lookupThreshold = helper.getOptions().get(LOOKUP_THRESHOLD);
int sleepAfterElements = helper.getOptions().get(SOURCE_SLEEP_AFTER_ELEMENTS);
long sleepTimeMillis = helper.getOptions().get(SOURCE_SLEEP_TIME).toMillis();
+ Integer parallelism = helper.getOptions().get(SOURCE_PARALLELISM);
DefaultLookupCache cache = null;
if (helper.getOptions().get(CACHE_TYPE).equals(LookupOptions.LookupCacheType.PARTIAL)) {
cache = DefaultLookupCache.fromConfig(helper.getOptions());
@@ -571,7 +574,8 @@ public final class TestValuesTableFactory
Long.MAX_VALUE,
partitions,
readableMetadata,
- null);
+ null,
+ parallelism);
}
if (disableLookup) {
@@ -746,6 +750,7 @@ public final class TestValuesTableFactory
SOURCE_NUM_ELEMENT_TO_SKIP,
SOURCE_SLEEP_AFTER_ELEMENTS,
SOURCE_SLEEP_TIME,
+ SOURCE_PARALLELISM,
INTERNAL_DATA,
CACHE_TYPE,
PARTIAL_CACHE_EXPIRE_AFTER_ACCESS,
@@ -916,6 +921,7 @@ public final class TestValuesTableFactory
private @Nullable int[] groupingSet;
private List<AggregateExpression> aggregateExpressions;
private List<String> acceptedPartitionFilterFields;
+ private final Integer parallelism;
private TestValuesScanTableSourceWithoutProjectionPushDown(
DataType producedDataType,
@@ -934,7 +940,8 @@ public final class TestValuesTableFactory
long limit,
List<Map<String, String>> allPartitions,
Map<String, DataType> readableMetadata,
- @Nullable int[] projectedMetadataFields) {
+ @Nullable int[] projectedMetadataFields,
+ @Nullable Integer parallelism) {
this.producedDataType = producedDataType;
this.changelogMode = changelogMode;
this.boundedness = boundedness;
@@ -954,6 +961,7 @@ public final class TestValuesTableFactory
this.projectedMetadataFields = projectedMetadataFields;
this.groupingSet = null;
this.aggregateExpressions = Collections.emptyList();
+ this.parallelism = parallelism;
}
@Override
@@ -987,7 +995,7 @@ public final class TestValuesTableFactory
sourceFunction = new FromElementsFunction<>(serializer, values);
}
return SourceFunctionProvider.of(
- sourceFunction, boundedness == Boundedness.BOUNDED);
+ sourceFunction, boundedness == Boundedness.BOUNDED, parallelism);
} catch (IOException e) {
throw new TableException("Fail to init source function", e);
}
@@ -999,7 +1007,8 @@ public final class TestValuesTableFactory
terminating == TerminatingLogic.FINITE,
"Values Source doesn't support infinite
InputFormat.");
Collection<RowData> values = convertToRowData(converter);
- return InputFormatProvider.of(new
CollectionInputFormat<>(values, serializer));
+ return InputFormatProvider.of(
+ new CollectionInputFormat<>(values, serializer),
parallelism);
case "DataStream":
checkArgument(
!failingSource,
@@ -1024,6 +1033,11 @@ public final class TestValuesTableFactory
return sourceStream;
}
+ @Override
+ public Optional<Integer> getParallelism() {
+ return Optional.ofNullable(parallelism);
+ }
+
@Override
public boolean isBounded() {
return boundedness == Boundedness.BOUNDED;
@@ -1039,7 +1053,8 @@ public final class TestValuesTableFactory
|| acceptedPartitionFilterFields.isEmpty()) {
Collection<RowData> values2 = convertToRowData(converter);
return SourceProvider.of(
- new ValuesSource(terminating, boundedness, values2, serializer));
+ new ValuesSource(terminating, boundedness, values2, serializer),
+ parallelism);
} else {
Map<Map<String, String>, Collection<RowData>> partitionValues =
convertToPartitionedRowData(converter);
@@ -1050,7 +1065,7 @@ public final class TestValuesTableFactory
partitionValues,
serializer,
acceptedPartitionFilterFields);
- return SourceProvider.of(source);
+ return SourceProvider.of(source, parallelism);
}
default:
throw new IllegalArgumentException(
@@ -1114,7 +1129,8 @@ public final class TestValuesTableFactory
limit,
allPartitions,
readableMetadata,
- projectedMetadataFields);
+ projectedMetadataFields,
+ parallelism);
}
@Override
@@ -1477,7 +1493,8 @@ public final class TestValuesTableFactory
limit,
allPartitions,
readableMetadata,
- projectedMetadataFields);
+ projectedMetadataFields,
+ null);
}
@Override
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.xml
index 52d21087262..8fe6835213c 100644
--- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.xml
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.xml
@@ -727,4 +727,46 @@ Calc(select=[ts, a, b], where=[>(a, 1)], changelogMode=[I,UB,UA,D])
]]>
</Resource>
</TestCase>
+
+ <TestCase name="testSetParallelismForSource">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM src LEFT JOIN changelog_src on src.id = changelog_src.id WHERE src.c > 1]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(id=[$0], b=[$1], c=[$2], id0=[$3], a=[$4])
++- LogicalFilter(condition=[>($2, 1)])
+ +- LogicalJoin(condition=[=($0, $3)], joinType=[left])
+ :- LogicalTableScan(table=[[default_catalog, default_database, src]])
+ +- LogicalTableScan(table=[[default_catalog, default_database, changelog_src]])
+]]>
+ </Resource>
+ <Resource name="optimized exec plan">
+ <![CDATA[
+Join(joinType=[LeftOuterJoin], where=[(id = id0)], select=[id, b, c, id0, a], leftInputSpec=[NoUniqueKey], rightInputSpec=[JoinKeyContainsUniqueKey])
+:- Exchange(distribution=[hash[id]])
+:  +- Calc(select=[id, b, c], where=[(c > 1)])
+:     +- TableSourceScan(table=[[default_catalog, default_database, src, filter=[]]], fields=[id, b, c])
++- Exchange(distribution=[hash[id]])
+   +- ChangelogNormalize(key=[id])
+      +- Exchange(distribution=[hash[id]])
+         +- TableSourceScan(table=[[default_catalog, default_database, changelog_src]], fields=[id, a])
+]]>
+ </Resource>
+ <Resource name="transformation">
+ <![CDATA[
+TwoInputTransformation{name='Join(joinType=[LeftOuterJoin], where=[(id = id0)], select=[id, b, c, id0, a], leftInputSpec=[NoUniqueKey], rightInputSpec=[JoinKeyContainsUniqueKey])', outputType=ROW<`id` INT, `b` STRING, `c` INT, `id0` INT, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+	+- PartitionTransformation{name='Exchange(distribution=[hash[id]])', outputType=ROW<`id` INT, `b` STRING, `c` INT>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+		+- OneInputTransformation{name='Calc(select=[id, b, c], where=[(c > 1)])', outputType=ROW<`id` INT, `b` STRING, `c` INT>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+			+- SourceTransformationWrapper{name='ChangeToDefaultParallel', outputType=ROW<`id` INT, `b` STRING, `c` INT>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+				+- LegacySourceTransformation{name='TableSourceScan(table=[[default_catalog, default_database, src, filter=[]]], fields=[id, b, c])', outputType=ROW<`id` INT, `b` STRING, `c` INT>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=3}
+	+- PartitionTransformation{name='Exchange(distribution=[hash[id]])', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+		+- OneInputTransformation{name='ChangelogNormalize(key=[id])', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+			+- PartitionTransformation{name='Exchange(distribution=[hash[id]])', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+				+- PartitionTransformation{name='Partitioner', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+					+- SourceTransformationWrapper{name='ChangeToDefaultParallel', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=-1}
+						+- LegacySourceTransformation{name='TableSourceScan(table=[[default_catalog, default_database, changelog_src]], fields=[id, a])', outputType=ROW<`id` INT NOT NULL, `a` STRING>(org.apache.flink.table.data.RowData, org.apache.flink.table.runtime.typeutils.RowDataSerializer), parallelism=5}
+]]>
+ </Resource>
+ </TestCase>
</Root>
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.scala
index 0a31589b61c..be1ae70d3fa 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableScanTest.scala
@@ -775,4 +775,42 @@ class TableScanTest extends TableTestBase {
"expression type is CHAR(0) NOT NULL")
.isInstanceOf[ValidationException]
}
+
+ @Test
+ def testSetParallelismForSource(): Unit = {
+ val config = TableConfig.getDefault
+ config.set(ExecutionConfigOptions.TABLE_EXEC_SIMPLIFY_OPERATOR_NAME_ENABLED, Boolean.box(false))
+ val util = streamTestUtil(config)
+
+ util.addTable("""
+ |CREATE TABLE changelog_src (
+ | id INT,
+ | a STRING,
+ | PRIMARY KEY (id) NOT ENFORCED
+ |) WITH (
+ | 'connector' = 'values',
+ | 'bounded' = 'true',
+ | 'runtime-source' = 'DataStream',
+ | 'scan.parallelism' = '5',
+ | 'enable-projection-push-down' = 'false',
+ | 'changelog-mode' = 'I,UA,D'
+ |)
+ """.stripMargin)
+ util.addTable("""
+ |CREATE TABLE src (
+ | id INT,
+ | b STRING,
+ | c INT
+ |) WITH (
+ | 'connector' = 'values',
+ | 'bounded' = 'true',
+ | 'runtime-source' = 'DataStream',
+ | 'scan.parallelism' = '3',
+ | 'enable-projection-push-down' = 'false'
+ |)
+ """.stripMargin)
+ util.verifyTransformation(
+ "SELECT * FROM src LEFT JOIN changelog_src " +
+ "on src.id = changelog_src.id WHERE src.c > 1")
+ }
}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TableSourceITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TableSourceITCase.scala
index 26b5d3a1709..a2089ee404e 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TableSourceITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TableSourceITCase.scala
@@ -24,6 +24,7 @@ import org.apache.flink.table.api.{DataTypes, TableException}
import org.apache.flink.table.api.bridge.scala._
import org.apache.flink.table.planner.factories.TestValuesTableFactory
import org.apache.flink.table.planner.runtime.utils.{StreamingTestBase, TestData, TestingAppendSink, TestingRetractSink}
+import org.apache.flink.table.planner.runtime.utils.TestData.data1
import org.apache.flink.table.planner.utils._
import org.apache.flink.table.runtime.functions.scalar.SourceWatermarkFunction
import org.apache.flink.table.utils.LegacyRowExtension
@@ -33,6 +34,8 @@ import org.assertj.core.api.Assertions.{assertThat, assertThatThrownBy}
import org.junit.jupiter.api.{BeforeEach, Test}
import org.junit.jupiter.api.extension.RegisterExtension
+import java.util.concurrent.atomic.AtomicInteger
+
class TableSourceITCase extends StreamingTestBase {
@RegisterExtension private val _: EachCallbackWrapper[LegacyRowExtension] =
@@ -421,4 +424,81 @@ class TableSourceITCase extends StreamingTestBase {
val expected = Seq("1,Sarah,1", "2,Rob,1", "3,Mike,1")
assertThat(sink.getAppendResults.sorted).isEqualTo(expected.sorted)
}
+
+ private def innerTestSetParallelism(provider: String, parallelism: Int, index: Int): Unit = {
+ val dataId = TestValuesTableFactory.registerData(data1)
+ val sourceTableName = s"test_para_source_${provider.toLowerCase.trim}_$index"
+ val sinkTableName = s"test_para_sink_${provider.toLowerCase.trim}_$index"
+ tEnv.executeSql(s"""
+ |CREATE TABLE $sourceTableName (
+ | the_month INT,
+ | area STRING,
+ | product INT
+ |) WITH (
+ | 'connector' = 'values',
+ | 'data-id' = '$dataId',
+ | 'bounded' = 'true',
+ | 'runtime-source' = '$provider',
+ | 'scan.parallelism' = '$parallelism',
+ | 'enable-projection-push-down' = 'false'
+ |)
+ |""".stripMargin)
+ tEnv.executeSql(s"""
+ |CREATE TABLE $sinkTableName (
+ | the_month INT,
+ | area STRING,
+ | product INT
+ |) WITH (
+ | 'connector' = 'values',
+ | 'sink-insert-only' = 'true'
+ |)
+ |""".stripMargin)
+ tEnv.executeSql(s"INSERT INTO $sinkTableName SELECT * FROM
$sourceTableName").await()
+ }
+
+ @Test
+ def testParallelismWithSourceFunction(): Unit = {
+ val negativeParallelism = -1
+ val validParallelism = 3
+ val index = new AtomicInteger(1)
+
+ assertThatThrownBy(
+ () =>
+ innerTestSetParallelism(
+ "SourceFunction",
+ negativeParallelism,
+ index = index.getAndIncrement))
+ .hasMessageContaining(s"Invalid configured parallelism")
+
+ innerTestSetParallelism("SourceFunction", validParallelism, index =
index.getAndIncrement)
+ }
+
+ @Test
+ def testParallelismWithInputFormat(): Unit = {
+ val negativeParallelism = -1
+ val validParallelism = 3
+ val index = new AtomicInteger(2)
+
+ assertThatThrownBy(
+ () =>
+ innerTestSetParallelism("InputFormat", negativeParallelism, index =
index.getAndIncrement))
+ .hasMessageContaining(s"Invalid configured parallelism")
+
+ innerTestSetParallelism("InputFormat", validParallelism, index =
index.getAndIncrement)
+ }
+
+ @Test
+ def testParallelismWithDataStream(): Unit = {
+ val negativeParallelism = -1
+ val validParallelism = 3
+ val index = new AtomicInteger(3)
+
+ assertThatThrownBy(
+ () =>
+ innerTestSetParallelism("DataStream", negativeParallelism, index =
index.getAndIncrement))
+ .hasMessageContaining(s"Invalid configured parallelism")
+
+ innerTestSetParallelism("DataStream", validParallelism, index =
index.getAndIncrement)
+ }
+
}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
index 1e006f3d94b..e5b418365be 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/utils/TableTestBase.scala
@@ -19,6 +19,7 @@ package org.apache.flink.table.planner.utils
import org.apache.flink.FlinkVersion
import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
+import org.apache.flink.api.dag.Transformation
import org.apache.flink.api.java.typeutils.{PojoTypeInfo, RowTypeInfo, TupleTypeInfo}
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
import org.apache.flink.configuration.BatchExecutionOptions
@@ -87,7 +88,7 @@ import org.junit.jupiter.api.extension.{BeforeEachCallback, ExtendWith, Extensio
import org.junit.jupiter.api.io.TempDir
import org.junit.platform.commons.support.AnnotationSupport
-import java.io.{File, IOException}
+import java.io.{File, IOException, PrintWriter, StringWriter}
import java.net.URL
import java.nio.file.{Files, Path, Paths}
import java.time.Duration
@@ -702,6 +703,20 @@ abstract class TableTestUtilBase(test: TableTestBase, isStreamingMode: Boolean)
withQueryBlockAlias = false)
}
+ /**
+ * Verify the AST (abstract syntax tree), the optimized exec plan and the transformation for the
+ * given SELECT query. Note: An exception will be thrown if the given SQL can't be translated to an
+ * exec plan or the transformation result is wrong.
+ */
+ def verifyTransformation(query: String): Unit = {
+ doVerifyPlan(
+ query,
+ Array.empty[ExplainDetail],
+ withRowType = false,
+ Array(PlanKind.AST, PlanKind.OPT_EXEC, PlanKind.TRANSFORM),
+ withQueryBlockAlias = false)
+ }
+
/** Verify the explain result for the given SELECT query. See more about [[Table#explain()]]. */
def verifyExplain(query: String): Unit = verifyExplain(getTableEnv.sqlQuery(query))
@@ -1040,6 +1055,14 @@ abstract class TableTestUtilBase(test: TableTestBase, isStreamingMode: Boolean)
""
}
+ // build transformation graph if `expectedPlans` contains TRANSFORM
+ val transformation = if (expectedPlans.contains(PlanKind.TRANSFORM)) {
+ val optimizedNodes = getPlanner.translateToExecNodeGraph(optimizedRels, true)
+ System.lineSeparator + getTransformations(getPlanner.translateToPlan(optimizedNodes))
+ } else {
+ ""
+ }
+
// check whether the sql equals to the expected if the `relNodes` are translated from sql
assertSqlEqualsOrExpandFunc()
// check ast plan
@@ -1058,6 +1081,10 @@ abstract class TableTestUtilBase(test: TableTestBase, isStreamingMode: Boolean)
if (expectedPlans.contains(PlanKind.OPT_EXEC)) {
assertEqualsOrExpand("optimized exec plan", optimizedExecPlan, expand = false)
}
+ // check transformation graph
+ if (expectedPlans.contains(PlanKind.TRANSFORM)) {
+ assertEqualsOrExpand("transformation", transformation, expand = false)
+ }
}
private def doVerifyExplain(explainResult: String, extraDetails: ExplainDetail*): Unit = {
@@ -1117,6 +1144,25 @@ abstract class TableTestUtilBase(test: TableTestBase, isStreamingMode: Boolean)
replaceEstimatedCost(optimizedPlan)
}
+ private def getTransformations(transformations: java.util.List[Transformation[_]]): String = {
+ val stringWriter = new StringWriter()
+ val printWriter = new PrintWriter(stringWriter)
+ transformations.foreach(transformation => getTransformation(printWriter, transformation, 0))
+ stringWriter.toString
+ }
+
+ private def getTransformation(
+ printWriter: PrintWriter,
+ transformation: Transformation[_],
+ level: Int): Unit = {
+ if (level == 0) {
+ printWriter.println(transformation.toStringWithoutId)
+ } else {
+ printWriter.println(("\t" * level) + "+- " +
transformation.toStringWithoutId)
+ }
+ transformation.getInputs.foreach(child => getTransformation(printWriter,
child, level + 1))
+ }
+
/** Replace the estimated costs for the given plan, because it may be unstable. */
protected def replaceEstimatedCost(s: String): String = {
var str = s.replaceAll("\\r\\n", "\n")
@@ -1624,6 +1670,9 @@ object PlanKind extends Enumeration {
/** Optimized Execution Plan */
val OPT_EXEC: Value = Value("OPT_EXEC")
+
+ /** Transformation */
+ val TRANSFORM: Value = Value("TRANSFORM")
}
object TableTestUtil {