This is an automated email from the ASF dual-hosted git repository. reuvenlax pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 05305ede453 Merge pull request #27617: Support withFanout and withHotKeyFanout on schema group transform 05305ede453 is described below commit 05305ede45366f158f27fc2b83b9ce00db4df2ab Author: Reuven Lax <re...@google.com> AuthorDate: Sat Jul 22 10:40:43 2023 -0700 Merge pull request #27617: Support withFanout and withHotKeyFanout on schema group transform --- .../apache/beam/sdk/schemas/transforms/Group.java | 162 ++++++++++++++++----- .../beam/sdk/schemas/transforms/GroupTest.java | 96 +++++++++--- .../sdk/extensions/sql/impl/rel/BeamWindowRel.java | 5 +- 3 files changed, 201 insertions(+), 62 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java index fb48ceed311..fe8933d24d5 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java @@ -17,7 +17,9 @@ */ package org.apache.beam.sdk.schemas.transforms; +import com.google.auto.value.AutoOneOf; import com.google.auto.value.AutoValue; +import java.io.Serializable; import java.util.List; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.schemas.FieldAccessDescriptor; @@ -34,6 +36,7 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.Values; import org.apache.beam.sdk.transforms.WithKeys; import org.apache.beam.sdk.transforms.windowing.GlobalWindows; @@ -42,6 +45,7 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.Nullable; /** * A generic grouping transform for schema {@link PCollection}s. @@ -153,7 +157,7 @@ public class Group { */ public <OutputT> CombineGlobally<InputT, OutputT> aggregate( CombineFn<InputT, ?, OutputT> combineFn) { - return new CombineGlobally<>(combineFn); + return new CombineGlobally<>(combineFn, 0); } /** @@ -169,10 +173,8 @@ public class Group { return new CombineFieldsGlobally<>( SchemaAggregateFn.create() .aggregateFields( - FieldAccessDescriptor.withFieldNames(inputFieldName), - false, - fn, - outputFieldName)); + FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputFieldName), + 0); } public <CombineInputT, AccumT, CombineOutputT> @@ -183,7 +185,8 @@ public class Group { return new CombineFieldsGlobally<>( SchemaAggregateFn.create() .aggregateFields( - FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputFieldName)); + FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputFieldName), + 0); } /** The same as {@link #aggregateField} but using field id. */ @@ -194,7 +197,8 @@ public class Group { return new CombineFieldsGlobally<>( SchemaAggregateFn.create() .aggregateFields( - FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputFieldName)); + FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputFieldName), + 0); } public <CombineInputT, AccumT, CombineOutputT> @@ -205,7 +209,8 @@ public class Group { return new CombineFieldsGlobally<>( SchemaAggregateFn.create() .aggregateFields( - FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputFieldName)); + FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputFieldName), + 0); } /** @@ -221,7 +226,8 @@ public class Group { return new CombineFieldsGlobally<>( SchemaAggregateFn.create() .aggregateFields( - FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputField)); + FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputField), + 0); } public <CombineInputT, AccumT, CombineOutputT> @@ -232,7 +238,8 @@ public class Group { return new CombineFieldsGlobally<>( SchemaAggregateFn.create() .aggregateFields( - FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField)); + FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField), + 0); } /** The same as {@link #aggregateField} but using field id. */ @@ -241,7 +248,8 @@ public class Group { return new CombineFieldsGlobally<>( SchemaAggregateFn.create() .aggregateFields( - FieldAccessDescriptor.withFieldIds(inputFielId), false, fn, outputField)); + FieldAccessDescriptor.withFieldIds(inputFielId), false, fn, outputField), + 0); } public <CombineInputT, AccumT, CombineOutputT> @@ -252,7 +260,8 @@ public class Group { return new CombineFieldsGlobally<>( SchemaAggregateFn.create() .aggregateFields( - FieldAccessDescriptor.withFieldIds(inputFielId), true, fn, outputField)); + FieldAccessDescriptor.withFieldIds(inputFielId), true, fn, outputField), + 0); } /** @@ -298,8 +307,8 @@ public class Group { CombineFn<CombineInputT, AccumT, CombineOutputT> fn, String outputFieldName) { return new CombineFieldsGlobally<>( - SchemaAggregateFn.create() - .aggregateFields(fieldsToAggregate, false, fn, outputFieldName)); + SchemaAggregateFn.create().aggregateFields(fieldsToAggregate, false, fn, outputFieldName), + 0); } /** @@ -335,7 +344,7 @@ public class Group { CombineFn<CombineInputT, AccumT, CombineOutputT> fn, Field outputField) { return new CombineFieldsGlobally<>( - SchemaAggregateFn.create().aggregateFields(fieldsToAggregate, false, fn, outputField)); + SchemaAggregateFn.create().aggregateFields(fieldsToAggregate, false, fn, outputField), 0); } @Override @@ -351,14 +360,20 @@ public class Group { public static class CombineGlobally<InputT, OutputT> extends PTransform<PCollection<InputT>, PCollection<OutputT>> { final CombineFn<InputT, ?, OutputT> combineFn; + int fanout; - CombineGlobally(CombineFn<InputT, ?, OutputT> combineFn) { + CombineGlobally(CombineFn<InputT, ?, OutputT> combineFn, int fanout) { this.combineFn = combineFn; + this.fanout = fanout; + } + + public CombineGlobally<InputT, OutputT> withFanout(int fanout) { + return new CombineGlobally<>(combineFn, fanout); } @Override public PCollection<OutputT> expand(PCollection<InputT> input) { - return input.apply("globalCombine", Combine.globally(combineFn)); + return input.apply("globalCombine", Combine.globally(combineFn).withFanout(fanout)); } } @@ -420,9 +435,11 @@ public class Group { */ public static class CombineFieldsGlobally<InputT> extends AggregateCombiner<InputT> { private final SchemaAggregateFn.Inner schemaAggregateFn; + private final int fanout; - CombineFieldsGlobally(SchemaAggregateFn.Inner schemaAggregateFn) { + CombineFieldsGlobally(SchemaAggregateFn.Inner schemaAggregateFn, int fanout) { this.schemaAggregateFn = schemaAggregateFn; + this.fanout = fanout; } /** @@ -431,7 +448,7 @@ public class Group { * determined by the output types of all the composed combiners. */ public static CombineFieldsGlobally<?> create() { - return new CombineFieldsGlobally<>(SchemaAggregateFn.create()); + return new CombineFieldsGlobally<>(SchemaAggregateFn.create(), 0); } /** @@ -450,7 +467,8 @@ public class Group { String outputFieldName) { return new CombineFieldsGlobally<>( schemaAggregateFn.aggregateFields( - FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputFieldName)); + FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputFieldName), + fanout); } public <CombineInputT, AccumT, CombineOutputT> @@ -460,7 +478,8 @@ public class Group { String outputFieldName) { return new CombineFieldsGlobally<>( schemaAggregateFn.aggregateFields( - FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputFieldName)); + FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputFieldName), + fanout); } public <CombineInputT, AccumT, CombineOutputT> CombineFieldsGlobally<InputT> aggregateField( @@ -469,7 +488,8 @@ public class Group { String outputFieldName) { return new CombineFieldsGlobally<>( schemaAggregateFn.aggregateFields( - FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputFieldName)); + FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputFieldName), + fanout); } public <CombineInputT, AccumT, CombineOutputT> @@ -479,7 +499,8 @@ public class Group { String outputFieldName) { return new CombineFieldsGlobally<>( schemaAggregateFn.aggregateFields( - FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputFieldName)); + FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputFieldName), + fanout); } /** @@ -495,7 +516,8 @@ public class Group { Field outputField) { return new CombineFieldsGlobally<>( schemaAggregateFn.aggregateFields( - FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputField)); + FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputField), + fanout); } public <CombineInputT, AccumT, CombineOutputT> @@ -505,7 +527,8 @@ public class Group { Field outputField) { return new CombineFieldsGlobally<>( schemaAggregateFn.aggregateFields( - FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField)); + FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField), + fanout); } @Override @@ -513,7 +536,8 @@ public class Group { int inputFieldId, CombineFn<CombineInputT, AccumT, CombineOutputT> fn, Field outputField) { return new CombineFieldsGlobally<>( schemaAggregateFn.aggregateFields( - FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputField)); + FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputField), + fanout); } public <CombineInputT, AccumT, CombineOutputT> @@ -523,7 +547,8 @@ public class Group { Field outputField) { return new CombineFieldsGlobally<>( schemaAggregateFn.aggregateFields( - FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputField)); + FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputField), + fanout); } /** @@ -568,7 +593,8 @@ public class Group { CombineFn<CombineInputT, AccumT, CombineOutputT> fn, String outputFieldName) { return new CombineFieldsGlobally<>( - schemaAggregateFn.aggregateFields(fieldAccessDescriptor, false, fn, outputFieldName)); + schemaAggregateFn.aggregateFields(fieldAccessDescriptor, false, fn, outputFieldName), + fanout); } /** @@ -605,13 +631,17 @@ public class Group { CombineFn<CombineInputT, AccumT, CombineOutputT> fn, Field outputField) { return new CombineFieldsGlobally<>( - schemaAggregateFn.aggregateFields(fieldAccessDescriptor, false, fn, outputField)); + schemaAggregateFn.aggregateFields(fieldAccessDescriptor, false, fn, outputField), fanout); + } + + public CombineFieldsGlobally<InputT> withFanout(int fanout) { + return new CombineFieldsGlobally<>(schemaAggregateFn, fanout); } @Override public PCollection<Row> expand(PCollection<InputT> input) { SchemaAggregateFn.Inner fn = schemaAggregateFn.withSchema(input.getSchema()); - Combine.Globally<Row, Row> combineFn = Combine.globally(fn); + Combine.Globally<Row, Row> combineFn = Combine.globally(fn).withFanout(fanout); if (!(input.getWindowingStrategy().getWindowFn() instanceof GlobalWindows)) { combineFn = combineFn.withoutDefaults(); } @@ -631,6 +661,7 @@ public class Group { */ @AutoValue public abstract static class ByFields<InputT> extends AggregateCombiner<InputT> { + abstract FieldAccessDescriptor getFieldAccessDescriptor(); abstract String getKeyField(); @@ -651,11 +682,11 @@ public class Group { abstract ByFields<InputT> build(); } - class ToKv extends PTransform<PCollection<InputT>, PCollection<KV<Row, Iterable<Row>>>> { + class ToKV extends PTransform<PCollection<InputT>, PCollection<KV<Row, Row>>> { private RowSelector rowSelector; @Override - public PCollection<KV<Row, Iterable<Row>>> expand(PCollection<InputT> input) { + public PCollection<KV<Row, Row>> expand(PCollection<InputT> input) { Schema schema = input.getSchema(); FieldAccessDescriptor resolved = getFieldAccessDescriptor().resolve(schema); rowSelector = new RowSelectorContainer(schema, resolved, true); @@ -666,13 +697,12 @@ public class Group { .apply( "selectKeys", WithKeys.of((Row e) -> rowSelector.select(e)).withKeyType(TypeDescriptors.rows())) - .setCoder(KvCoder.of(SchemaCoder.of(keySchema), SchemaCoder.of(schema))) - .apply("GroupByKey", GroupByKey.create()); + .setCoder(KvCoder.of(SchemaCoder.of(keySchema), SchemaCoder.of(schema))); } } - public ToKv getToKvs() { - return new ToKv(); + public ToKV getToKV() { + return new ToKV(); } private static <InputT> ByFields<InputT> of(FieldAccessDescriptor fieldAccessDescriptor) { @@ -919,7 +949,8 @@ public class Group { .build(); return input - .apply("ToKvs", getToKvs()) + .apply("ToKvs", getToKV()) + .apply("GroupByKey", GroupByKey.create()) .apply( "ToRow", ParDo.of( @@ -942,6 +973,31 @@ public class Group { */ @AutoValue public abstract static class CombineFieldsByFields<InputT> extends AggregateCombiner<InputT> { + + @AutoOneOf(Fanout.Kind.class) + public abstract static class Fanout implements Serializable { + public enum Kind { + NUMBER, + FUNCTION + } + + public abstract Kind getKind(); + + public abstract Integer getNumber(); + + public abstract SerializableFunction<Row, Integer> getFunction(); + + public static Fanout of(int n) { + return AutoOneOf_Group_CombineFieldsByFields_Fanout.number(n); + } + + public static Fanout of(SerializableFunction<Row, Integer> f) { + return AutoOneOf_Group_CombineFieldsByFields_Fanout.function(f); + } + } + + abstract @Nullable Fanout getFanout(); + abstract ByFields<InputT> getByFields(); abstract SchemaAggregateFn.Inner getSchemaAggregateFn(); @@ -954,6 +1010,8 @@ public class Group { @AutoValue.Builder abstract static class Builder<InputT> { + public abstract Builder<InputT> setFanout(@Nullable Fanout value); + abstract Builder<InputT> setByFields(ByFields<InputT> byFields); abstract Builder<InputT> setSchemaAggregateFn(SchemaAggregateFn.Inner schemaAggregateFn); @@ -988,6 +1046,14 @@ public class Group { return toBuilder().setValueField(valueField).build(); } + public CombineFieldsByFields<InputT> withHotKeyFanout(int n) { + return toBuilder().setFanout(Fanout.of(n)).build(); + } + + public CombineFieldsByFields<InputT> withHotKeyFanout(SerializableFunction<Row, Integer> f) { + return toBuilder().setFanout(Fanout.of(f)).build(); + } + /** * Build up an aggregation function over the input elements. * @@ -1187,9 +1253,25 @@ public class Group { .build(); } + PTransform<PCollection<KV<Row, Row>>, PCollection<KV<Row, Row>>> getCombineTransform( + Schema schema) { + SchemaAggregateFn.Inner fn = getSchemaAggregateFn().withSchema(schema); + @Nullable Fanout fanout = getFanout(); + if (fanout != null) { + switch (fanout.getKind()) { + case NUMBER: + return Combine.<Row, Row, Row>perKey(fn).withHotKeyFanout(fanout.getNumber()); + case FUNCTION: + return Combine.<Row, Row, Row>perKey(fn).withHotKeyFanout(fanout.getFunction()); + default: + throw new RuntimeException("Unexpected kind: " + fanout.getKind()); + } + } + return Combine.perKey(fn); + } + @Override public PCollection<Row> expand(PCollection<InputT> input) { - SchemaAggregateFn.Inner fn = getSchemaAggregateFn().withSchema(input.getSchema()); Schema keySchema = getByFields().getKeySchema(input.getSchema()); Schema outputSchema = @@ -1199,8 +1281,8 @@ public class Group { .build(); return input - .apply("ToKvs", getByFields().getToKvs()) - .apply("Combine", Combine.groupedValues(fn)) + .apply("ToKvs", getByFields().getToKV()) + .apply("Combine", getCombineTransform(input.getSchema())) .apply( "ToRow", ParDo.of( diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java index b4b074e0ca0..f6f33208a10 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java @@ -30,6 +30,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Iterator; import java.util.List; +import javax.annotation.Nullable; import org.apache.beam.sdk.schemas.AutoValueSchema; import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.schemas.Schema; @@ -50,6 +51,7 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Sample; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.SerializableFunctions; import org.apache.beam.sdk.transforms.Sum; import org.apache.beam.sdk.transforms.Top; import org.apache.beam.sdk.values.PCollection; @@ -260,17 +262,30 @@ public class GroupTest implements Serializable { @Test @Category(NeedsRunner.class) - public void testGlobalAggregation() { + public void testGlobalAggregationWithoutFanout() { + globalAggregationWithFanout(false); + } + + @Test + @Category(NeedsRunner.class) + public void testGlobalAggregationWithFanout() { + globalAggregationWithFanout(true); + } + + public void globalAggregationWithFanout(boolean withFanout) { Collection<Basic> elements = ImmutableList.of( Basic.of("key1", 1, "value1"), Basic.of("key1", 1, "value2"), Basic.of("key2", 2, "value3"), Basic.of("key2", 2, "value4")); - PCollection<Long> count = - pipeline - .apply(Create.of(elements)) - .apply(Group.<Basic>globally().aggregate(Count.combineFn())); + + Group.CombineGlobally<Basic, Long> transform = + Group.<Basic>globally().aggregate(Count.combineFn()); + if (withFanout) { + transform = transform.withFanout(10); + } + PCollection<Long> count = pipeline.apply(Create.of(elements)).apply(transform); PAssert.that(count).containsInAnyOrder(4L); pipeline.run(); @@ -426,7 +441,17 @@ public class GroupTest implements Serializable { @Test @Category(NeedsRunner.class) - public void testAggregateByMultipleFields() { + public void testAggregateByMultipleFieldsWithoutFanout() { + aggregateByMultipleFieldsWithFanout(false); + } + + @Test + @Category(NeedsRunner.class) + public void testAggregateByMultipleFieldsWithFanout() { + aggregateByMultipleFieldsWithFanout(true); + } + + public void aggregateByMultipleFieldsWithFanout(boolean withFanout) { Collection<Aggregate> elements = ImmutableList.of( Aggregate.of(1, 1, 2), @@ -435,12 +460,14 @@ public class GroupTest implements Serializable { Aggregate.of(4, 2, 5)); List<String> fieldNames = Lists.newArrayList("field1", "field2"); - PCollection<Row> aggregate = - pipeline - .apply(Create.of(elements)) - .apply( - Group.<Aggregate>globally() - .aggregateFields(fieldNames, new MultipleFieldCombineFn(), "field1+field2")); + + Group.CombineFieldsGlobally<Aggregate> transform = + Group.<Aggregate>globally() + .aggregateFields(fieldNames, new MultipleFieldCombineFn(), "field1+field2"); + if (withFanout) { + transform = transform.withFanout(10); + } + PCollection<Row> aggregate = pipeline.apply(Create.of(elements)).apply(transform); Schema outputSchema = Schema.builder().addInt64Field("field1+field2").build(); Row expectedRow = Row.withSchema(outputSchema).addValues(16L).build(); @@ -462,7 +489,25 @@ public class GroupTest implements Serializable { @Test @Category(NeedsRunner.class) - public void testByKeyWithSchemaAggregateFnNestedFields() { + public void testByKeyWithSchemaAggregateFnNestedFieldsNoFanout() { + byKeyWithSchemaAggregateFnNestedFieldsWithFanout(null); + } + + @Test + @Category(NeedsRunner.class) + public void testByKeyWithSchemaAggregateFnNestedFieldsWithNumberFanout() { + byKeyWithSchemaAggregateFnNestedFieldsWithFanout(Group.CombineFieldsByFields.Fanout.of(10)); + } + + @Test + @Category(NeedsRunner.class) + public void testByKeyWithSchemaAggregateFnNestedFieldsWithFunctionFanout() { + byKeyWithSchemaAggregateFnNestedFieldsWithFanout( + Group.CombineFieldsByFields.Fanout.of(SerializableFunctions.constant(10))); + } + + public void byKeyWithSchemaAggregateFnNestedFieldsWithFanout( + @Nullable Group.CombineFieldsByFields.Fanout fanout) { Collection<OuterAggregate> elements = ImmutableList.of( OuterAggregate.of(Aggregate.of(1, 1, 2)), @@ -470,14 +515,23 @@ public class GroupTest implements Serializable { OuterAggregate.of(Aggregate.of(3, 2, 4)), OuterAggregate.of(Aggregate.of(4, 2, 5))); - PCollection<Row> aggregations = - pipeline - .apply(Create.of(elements)) - .apply( - Group.<OuterAggregate>byFieldNames("inner.field2") - .aggregateField("inner.field1", Sum.ofLongs(), "field1_sum") - .aggregateField("inner.field3", Sum.ofIntegers(), "field3_sum") - .aggregateField("inner.field1", Top.largestLongsFn(1), "field1_top")); + Group.CombineFieldsByFields<OuterAggregate> transform = + Group.<OuterAggregate>byFieldNames("inner.field2") + .aggregateField("inner.field1", Sum.ofLongs(), "field1_sum") + .aggregateField("inner.field3", Sum.ofIntegers(), "field3_sum") + .aggregateField("inner.field1", Top.largestLongsFn(1), "field1_top"); + if (fanout != null) { + switch (fanout.getKind()) { + case NUMBER: + transform = transform.withHotKeyFanout(fanout.getNumber()); + break; + case FUNCTION: + transform = transform.withHotKeyFanout(fanout.getFunction()); + break; + } + } + + PCollection<Row> aggregations = pipeline.apply(Create.of(elements)).apply(transform); Schema keySchema = Schema.builder().addInt64Field("field2").build(); Schema valueSchema = diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java index 906258b164a..be88229e755 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java @@ -35,6 +35,7 @@ import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.KV; @@ -254,7 +255,9 @@ public class BeamWindowRel extends Window implements BeamRelNode { org.apache.beam.sdk.schemas.transforms.Group.ByFields<Row> myg = org.apache.beam.sdk.schemas.transforms.Group.byFieldIds(af.partitionKeys); PCollection<KV<Row, Iterable<Row>>> partitionBy = - inputData.apply(prefix + "partitionBy", myg.getToKvs()); + inputData + .apply(prefix + "partitionByKV", myg.getToKV()) + .apply(prefix + "partitionByGK", GroupByKey.create()); partitioned = partitionBy .apply(prefix + "selectOnlyValues", ParDo.of(new SelectOnlyValues()))