[ 
https://issues.apache.org/jira/browse/BEAM-4461?focusedWorklogId=152156&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-152156
 ]

ASF GitHub Bot logged work on BEAM-4461:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 08/Oct/18 08:12
            Start Date: 08/Oct/18 08:12
    Worklog Time Spent: 10m 
      Work Description: reuvenlax closed pull request #6298: [BEAM-4461] 
Introduce Group transform.
URL: https://github.com/apache/beam/pull/6298
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/CoderRegistry.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/CoderRegistry.java
index 865b2488608..b9e86649124 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/CoderRegistry.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/CoderRegistry.java
@@ -322,8 +322,7 @@ public void registerCoderForType(TypeDescriptor<?> type, 
Coder<?> coder) {
     if (paramCoderOrNull != null) {
       return paramCoderOrNull;
     } else {
-      throw new CannotProvideCoderException(
-          "Cannot infer coder for type parameter " + param.getName());
+      throw new CannotProvideCoderException("Cannot infer coder for type 
parameter " + param);
     }
   }
 
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java
index 7e677a51531..bcb87adfcee 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java
@@ -29,6 +29,7 @@
 import javax.annotation.Nullable;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.Field;
 import org.apache.beam.sdk.schemas.Schema.FieldType;
 import org.apache.beam.sdk.schemas.Schema.TypeName;
 import org.apache.beam.sdk.util.SerializableUtils;
@@ -116,7 +117,35 @@ public Schema getSchema() {
 
   @Override
   public void verifyDeterministic()
-      throws org.apache.beam.sdk.coders.Coder.NonDeterministicException {}
+      throws org.apache.beam.sdk.coders.Coder.NonDeterministicException {
+    // Reports non-determinism only when some field, at any depth reachable through ROW and ARRAY
+    // types, is map-valued; the two private helpers below perform that recursive walk.
+    verifyDeterministic(schema);
+  }
+
+  /** Checks every field of {@code schema}, delegating each field type to the type-level check. */
+  private void verifyDeterministic(Schema schema)
+      throws org.apache.beam.sdk.coders.Coder.NonDeterministicException {
+    for (Field field : schema.getFields()) {
+      verifyDeterministic(field.getType());
+    }
+  }
+
+  /**
+   * Throws {@link NonDeterministicException} for a MAP type, recurses into the row schema of a ROW
+   * type and into the element type of an ARRAY type, and accepts every other type.
+   */
+  private void verifyDeterministic(FieldType fieldType)
+      throws org.apache.beam.sdk.coders.Coder.NonDeterministicException {
+    switch (fieldType.getTypeName()) {
+      case MAP:
+        throw new NonDeterministicException(
+            this,
+            "Map-valued fields cannot be used in keys as Beam requires 
deterministic encoding for"
+                + " keys.");
+      case ROW:
+        verifyDeterministic(fieldType.getRowSchema());
+        break;
+      case ARRAY:
+        verifyDeterministic(fieldType.getCollectionElementType());
+        break;
+      default:
+        break;
+    }
+  }
 
   @Override
   public boolean consistentWithEquals() {
@@ -124,8 +153,20 @@ public boolean consistentWithEquals() {
   }
 
   /** Returns the coder used for a given primitive type. */
-  public static <T> Coder<T> coderForPrimitiveType(TypeName typeName) {
-    return (Coder<T>) CODER_MAP.get(typeName);
+  // NOTE(review): the Javadoc line above predates this change. The method now also builds coders
+  // for ROW, ARRAY and MAP field types, and it replaces the removed public
+  // coderForPrimitiveType(TypeName), so any external caller of that method must migrate.
+  public static <T> Coder<T> coderForFieldType(FieldType fieldType) {
+    switch (fieldType.getTypeName()) {
+      case ROW:
+        return (Coder<T>) RowCoder.of(fieldType.getRowSchema());
+      case ARRAY:
+        // The element coder is derived recursively, so arrays of composite types are handled too.
+        return (Coder<T>) 
ListCoder.of(coderForFieldType(fieldType.getCollectionElementType()));
+      case MAP:
+        // Key and value coders are both derived recursively from their field types.
+        return (Coder<T>)
+            MapCoder.of(
+                coderForFieldType(fieldType.getMapKeyType()),
+                coderForFieldType(fieldType.getMapValueType()));
+      default:
+        // Remaining type names are looked up in CODER_MAP; presumably this yields null for an
+        // unmapped name — TODO confirm CODER_MAP covers every non-composite TypeName.
+        return (Coder<T>) CODER_MAP.get(fieldType.getTypeName());
+    }
   }
 
   /** Return the estimated serialized size of a give row object. */
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldTypeDescriptors.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldTypeDescriptors.java
index 252ea5e1831..39eddb96ba5 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldTypeDescriptors.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldTypeDescriptors.java
@@ -102,7 +102,7 @@ private static FieldType getArrayFieldType(TypeDescriptor 
typeDescriptor) {
         return 
FieldType.array(fieldTypeForJavaType(TypeDescriptor.of(params[0])));
       }
     }
-    throw new RuntimeException("Coupld not determine array parameter type for 
field.");
+    throw new RuntimeException("Could not determine array parameter type for 
field.");
   }
 
   private static FieldType getMapFieldType(TypeDescriptor typeDescriptor) {
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java
index 582068b0c40..183a578cf93 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Convert.java
@@ -94,10 +94,6 @@
       extends PTransform<PCollection<InputT>, PCollection<OutputT>> {
     TypeDescriptor<OutputT> outputTypeDescriptor;
 
-    ConvertTransform(Class<OutputT> outputClass) {
-      this(TypeDescriptor.of(outputClass));
-    }
-
+    // NOTE(review): the Class-based constructor above is removed by this change; a caller holding
+    // a Class presumably now wraps it with TypeDescriptor.of(...) itself — confirm call sites.
     ConvertTransform(TypeDescriptor<OutputT> outputTypeDescriptor) {
       this.outputTypeDescriptor = outputTypeDescriptor;
     }
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java
new file mode 100644
index 00000000000..ffd343e88fc
--- /dev/null
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java
@@ -0,0 +1,684 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.sdk.schemas.transforms;
+
+import java.util.List;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.annotations.Experimental.Kind;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.Field;
+import org.apache.beam.sdk.schemas.SchemaCoder;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.transforms.WithKeys;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+
+/**
+ * A generic grouping transform for schema {@link PCollection}s.
+ *
+ * <p>When used without a combiner, this transform simply acts as a {@link GroupByKey} but without
+ * the need for the user to explicitly extract the keys. For example, consider the following input
+ * type:
+ *
+ * <pre>{@code
+ * {@literal @DefaultSchema(JavaFieldSchema.class)}
+ * public class UserPurchase {
+ *   public String userId;
+ *   public String country;
+ *   public long cost;
+ *   public double transactionDuration;
+ * }
+ *
+ * {@literal PCollection<UserPurchase>} purchases = readUserPurchases();
+ * }</pre>
+ *
+ * <p>You can group all purchases by user and country as follows:
+ *
+ * <pre>{@code
+ * {@literal PCollection<KV<Row, Iterable<UserPurchase>>>} byUser =
+ *   purchases.apply(Group.byFieldNames("userId", "country"));
+ * }</pre>
+ *
+ * <p>However, often an aggregation of some form is desired. The builder methods inside the Group
+ * class allow building up separate aggregations for every field (or set of fields) on the input
+ * schema, and generating an output schema based on these aggregations. For example:
+ *
+ * <pre>{@code
+ * PCollection<KV<Row, Row>> aggregated = purchases
+ *      .apply(Group.byFieldNames("userId", "country")
+ *          .aggregateField("cost", Sum.ofLongs(), "total_cost")
+ *          .aggregateField("cost", Top.largestLongsFn(10), "top_purchases")
+ *          .aggregateField("transactionDuration", ApproximateQuantilesCombineFn.create(21),
+ *            Field.of("transactionDurations", FieldType.array(FieldType.DOUBLE))));
+ * }</pre>
+ *
+ * <p>The result will be a new row schema containing the fields total_cost, top_purchases, and
+ * transactionDurations, containing the sum of all purchase costs (for that user and country), the
+ * top ten purchases, and a histogram (21 approximate quantiles) of transaction durations.
+ *
+ * <p>Note that usually the field type can be automatically inferred from the 
{@link CombineFn}
+ * passed in. However sometimes it cannot be inferred, due to Java type 
erasure, in which case a
+ * {@link Field} object containing the field type must be passed in. This is 
currently the case for
+ * ApproximateQuantilesCombineFn in the above example.
+ */
+@Experimental(Kind.SCHEMAS)
+public class Group {
+  /**
+   * Returns a transform that groups all elements in the input {@link PCollection}. The returned
+   * transform contains further builder methods to control how the grouping is done.
+   */
+  public static <T> Global<T> globally() {
+    // Takes no configuration: every element is grouped together when the transform is expanded.
+    return new Global<>();
+  }
+
+  /**
+   * Returns a transform that groups all elements in the input {@link PCollection} keyed by the list
+   * of fields specified. The output of this transform will be a {@link KV} keyed by a {@link Row}
+   * containing the specified extracted fields. The returned transform contains further builder
+   * methods to control how the grouping is done.
+   *
+   * @param fieldNames names of the input schema fields that make up the grouping key
+   */
+  public static <T> ByFields<T> byFieldNames(String... fieldNames) {
+    // Names are not validated here; they are resolved against the input schema when the returned
+    // transform is expanded.
+    return new ByFields<>(FieldAccessDescriptor.withFieldNames(fieldNames));
+  }
+
+  /**
+   * Returns a transform that groups all elements in the input {@link PCollection} keyed by the list
+   * of fields specified. The output of this transform will be a {@link KV} keyed by a {@link Row}
+   * containing the specified extracted fields. The returned transform contains further builder
+   * methods to control how the grouping is done.
+   *
+   * @param fieldIds ids of the input schema fields that make up the grouping key
+   */
+  public static <T> ByFields<T> byFieldIds(Integer... fieldIds) {
+    // Like byFieldNames, the ids are only resolved against the input schema at expansion time.
+    return new ByFields<>(FieldAccessDescriptor.withFieldIds(fieldIds));
+  }
+
+  /**
+   * Returns a transform that groups all elements in the input {@link PCollection} keyed by the
+   * fields specified. The output of this transform will be a {@link KV} keyed by a {@link Row}
+   * containing the specified extracted fields. The returned transform contains further builder
+   * methods to control how the grouping is done.
+   *
+   * @param fieldAccess descriptor of the input schema fields that make up the grouping key
+   */
+  public static <T> ByFields<T> byFieldAccessDescriptor(FieldAccessDescriptor 
fieldAccess) {
+    // The descriptor is stored as given and resolved against the input schema at expansion time.
+    return new ByFields<>(fieldAccess);
+  }
+
+  /**
+   * A {@link PTransform} for doing global aggregations on schema PCollections.
+   *
+   * <p>Applied as-is, it collects the input elements into a single {@link Iterable}. The {@code
+   * aggregate}, {@code aggregateField} and {@code aggregateFields} builders instead return
+   * transforms that combine the elements; each builder call on this class starts a new, empty
+   * schema aggregation.
+   */
+  public static class Global<InputT>
+      extends PTransform<PCollection<InputT>, PCollection<Iterable<InputT>>> {
+    /**
+     * Aggregate the grouped data using the specified {@link CombineFn}. The resulting {@link
+     * PCollection} will have type OutputT.
+     */
+    public <OutputT> CombineGlobally<InputT, OutputT> aggregate(
+        CombineFn<InputT, ?, OutputT> combineFn) {
+      return new CombineGlobally<>(combineFn);
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over a single field of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     *
+     * <p>Field types in the output schema will be inferred from the provided combine function.
+     * Sometimes the field type cannot be inferred due to Java's type erasure. In that case, use the
+     * overload that allows setting the output field type explicitly.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> 
CombineFieldsGlobally<InputT> aggregateField(
+        String inputFieldName,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        String outputFieldName) {
+      return new CombineFieldsGlobally<>(
+          SchemaAggregateFn.<InputT>create()
+              .aggregateFields(
+                  FieldAccessDescriptor.withFieldNames(inputFieldName), fn, 
outputFieldName));
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over a single field of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema. The output
+     * field is given explicitly by {@code outputField}.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> 
CombineFieldsGlobally<InputT> aggregateField(
+        String inputFieldName,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        Field outputField) {
+      return new CombineFieldsGlobally<>(
+          SchemaAggregateFn.<InputT>create()
+              .aggregateFields(
+                  FieldAccessDescriptor.withFieldNames(inputFieldName), fn, 
outputField));
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over multiple fields of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     *
+     * <p>Field types in the output schema will be inferred from the provided combine function.
+     * Sometimes the field type cannot be inferred due to Java's type erasure. In that case, use the
+     * overload that allows setting the output field type explicitly.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> 
CombineFieldsGlobally<InputT> aggregateFields(
+        List<String> inputFieldNames,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        String outputFieldName) {
+      return aggregateFields(
+          FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, 
outputFieldName);
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over multiple fields of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     *
+     * <p>Field types in the output schema will be inferred from the provided combine function.
+     * Sometimes the field type cannot be inferred due to Java's type erasure. In that case, use the
+     * overload that allows setting the output field type explicitly.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> 
CombineFieldsGlobally<InputT> aggregateFields(
+        FieldAccessDescriptor fieldsToAggregate,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        String outputFieldName) {
+      return new CombineFieldsGlobally<>(
+          SchemaAggregateFn.<InputT>create()
+              .aggregateFields(fieldsToAggregate, fn, outputFieldName));
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over multiple fields of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema. The output
+     * field is given explicitly by {@code outputField}.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> 
CombineFieldsGlobally<InputT> aggregateFields(
+        List<String> inputFieldNames,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        Field outputField) {
+      return aggregateFields(
+          FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, 
outputField);
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over multiple fields of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema. The output
+     * field is given explicitly by {@code outputField}.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> 
CombineFieldsGlobally<InputT> aggregateFields(
+        FieldAccessDescriptor fieldsToAggregate,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        Field outputField) {
+      return new CombineFieldsGlobally<>(
+          
SchemaAggregateFn.<InputT>create().aggregateFields(fieldsToAggregate, fn, 
outputField));
+    }
+
+    @Override
+    public PCollection<Iterable<InputT>> expand(PCollection<InputT> input) {
+      // Give every element the same null key so GroupByKey puts them all under a single key, then
+      // drop the key so only the grouped Iterable of elements remains.
+      return input
+          .apply(WithKeys.of((Void) null))
+          .apply(GroupByKey.create())
+          .apply(Values.create());
+    }
+  }
+
+  /** A {@link PTransform} that does a global combine using a provided {@link CombineFn}. */
+  public static class CombineGlobally<InputT, OutputT>
+      extends PTransform<PCollection<InputT>, PCollection<OutputT>> {
+    // The user-supplied combiner, applied unchanged through Combine.globally in expand().
+    final CombineFn<InputT, ?, OutputT> combineFn;
+
+    CombineGlobally(CombineFn<InputT, ?, OutputT> combineFn) {
+      this.combineFn = combineFn;
+    }
+
+    @Override
+    public PCollection<OutputT> expand(PCollection<InputT> input) {
+      // No schema handling is involved: the elements are fed to the CombineFn as-is.
+      return input.apply(Combine.globally(combineFn));
+    }
+  }
+
+  /**
+   * a {@link PTransform} that does a global combine using an aggregation 
built up by calls to
+   * aggregateField and aggregateFields. The output of this transform will 
have a schema that is
+   * determined by the output types of all the composed combiners.
+   */
+  public static class CombineFieldsGlobally<InputT>
+      extends PTransform<PCollection<InputT>, PCollection<Row>> {
+    // Aggregations registered so far. Every builder method below wraps the fn returned by
+    // schemaAggregateFn.aggregateFields(...) in a new CombineFieldsGlobally.
+    private final SchemaAggregateFn.Inner<InputT> schemaAggregateFn;
+
+    CombineFieldsGlobally(SchemaAggregateFn.Inner<InputT> schemaAggregateFn) {
+      this.schemaAggregateFn = schemaAggregateFn;
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over a single field of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     *
+     * <p>Field types in the output schema will be inferred from the provided combine function.
+     * Sometimes the field type cannot be inferred due to Java's type erasure. In that case, use the
+     * overload that allows setting the output field type explicitly.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> 
CombineFieldsGlobally<InputT> aggregateField(
+        String inputFieldName,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        String outputFieldName) {
+      return new CombineFieldsGlobally<>(
+          schemaAggregateFn.aggregateFields(
+              FieldAccessDescriptor.withFieldNames(inputFieldName), fn, 
outputFieldName));
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over a single field of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema. The output
+     * field is given explicitly by {@code outputField}.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> 
CombineFieldsGlobally<InputT> aggregateField(
+        String inputFieldName,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        Field outputField) {
+      return new CombineFieldsGlobally<>(
+          schemaAggregateFn.aggregateFields(
+              FieldAccessDescriptor.withFieldNames(inputFieldName), fn, 
outputField));
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over multiple fields of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     *
+     * <p>Field types in the output schema will be inferred from the provided combine function.
+     * Sometimes the field type cannot be inferred due to Java's type erasure. In that case, use the
+     * overload that allows setting the output field type explicitly.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> 
CombineFieldsGlobally<InputT> aggregateFields(
+        List<String> inputFieldNames,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        String outputFieldName) {
+      return aggregateFields(
+          FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, 
outputFieldName);
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over multiple fields of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     *
+     * <p>Field types in the output schema will be inferred from the provided combine function.
+     * Sometimes the field type cannot be inferred due to Java's type erasure. In that case, use the
+     * overload that allows setting the output field type explicitly.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> 
CombineFieldsGlobally<InputT> aggregateFields(
+        FieldAccessDescriptor fieldAccessDescriptor,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        String outputFieldName) {
+      return new CombineFieldsGlobally<>(
+          schemaAggregateFn.aggregateFields(fieldAccessDescriptor, fn, 
outputFieldName));
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over multiple fields of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema. The output
+     * field is given explicitly by {@code outputField}.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> 
CombineFieldsGlobally<InputT> aggregateFields(
+        List<String> inputFieldNames,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        Field outputField) {
+      return aggregateFields(
+          FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, 
outputField);
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over multiple fields of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema. The output
+     * field is given explicitly by {@code outputField}.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> 
CombineFieldsGlobally<InputT> aggregateFields(
+        FieldAccessDescriptor fieldAccessDescriptor,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        Field outputField) {
+      return new CombineFieldsGlobally<>(
+          schemaAggregateFn.aggregateFields(fieldAccessDescriptor, fn, 
outputField));
+    }
+
+    @Override
+    public PCollection<Row> expand(PCollection<InputT> input) {
+      // Bind the accumulated aggregations to the concrete input schema (and its to-Row
+      // conversion); the bound fn's output schema becomes the schema of the result rows.
+      SchemaAggregateFn.Inner<InputT> fn =
+          schemaAggregateFn.withSchema(input.getSchema(), 
input.getToRowFunction());
+      return 
input.apply(Combine.globally(fn)).setRowSchema(fn.getOutputSchema());
+    }
+  }
+
+  /**
+   * a {@link PTransform} that groups schema elements based on the given 
fields.
+   *
+   * <p>The output of this transform is a KV where the key type is a {@link 
Row} containing the
+   * extracted fields.
+   */
+  public static class ByFields<InputT>
+      extends PTransform<PCollection<InputT>, PCollection<KV<Row, Iterable<InputT>>>> {
+    private final FieldAccessDescriptor fieldAccessDescriptor;
+
+    // Key schema computed by the most recent expand() call; null until then. Kept on the
+    // transform so getKeySchema() can be read after this transform has been applied.
+    @Nullable private Schema keySchema = null;
+
+    private ByFields(FieldAccessDescriptor fieldAccessDescriptor) {
+      this.fieldAccessDescriptor = fieldAccessDescriptor;
+    }
+
+    /**
+     * Returns the key schema computed by the most recent expansion of this transform, or null if
+     * it has not been expanded yet.
+     */
+    @Nullable
+    Schema getKeySchema() {
+      return keySchema;
+    }
+
+    /**
+     * Aggregate the grouped data using the specified {@link CombineFn}. The resulting {@link
+     * PCollection} will have type {@literal PCollection<KV<Row, OutputT>>}.
+     */
+    public <OutputT> CombineByFields<InputT, OutputT> aggregate(
+        CombineFn<InputT, ?, OutputT> combineFn) {
+      return new CombineByFields<>(this, combineFn);
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over a single field of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     *
+     * <p>Field types in the output schema will be inferred from the provided combine function.
+     * Sometimes the field type cannot be inferred due to Java's type erasure. In that case, use the
+     * overload that allows setting the output field type explicitly.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateField(
+        String inputFieldName,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        String outputFieldName) {
+      return new CombineFieldsByFields<>(
+          this,
+          SchemaAggregateFn.<InputT>create()
+              .aggregateFields(
+                  FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputFieldName));
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over a single field of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateField(
+        String inputFieldName,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        Field outputField) {
+      return new CombineFieldsByFields<>(
+          this,
+          SchemaAggregateFn.<InputT>create()
+              .aggregateFields(
+                  FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputField));
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over multiple fields of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     *
+     * <p>Field types in the output schema will be inferred from the provided combine function.
+     * Sometimes the field type cannot be inferred due to Java's type erasure. In that case, use the
+     * overload that allows setting the output field type explicitly.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateFields(
+        List<String> inputFieldNames,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        String outputFieldName) {
+      return aggregateFields(
+          FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, outputFieldName);
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over multiple fields of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     *
+     * <p>Field types in the output schema will be inferred from the provided combine function.
+     * Sometimes the field type cannot be inferred due to Java's type erasure. In that case, use the
+     * overload that allows setting the output field type explicitly.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateFields(
+        FieldAccessDescriptor fieldsToAggregate,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        String outputFieldName) {
+      return new CombineFieldsByFields<>(
+          this,
+          SchemaAggregateFn.<InputT>create()
+              .aggregateFields(fieldsToAggregate, fn, outputFieldName));
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over multiple fields of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateFields(
+        List<String> inputFieldNames,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        Field outputField) {
+      return aggregateFields(
+          FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, outputField);
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over multiple fields of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateFields(
+        FieldAccessDescriptor fieldsToAggregate,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        Field outputField) {
+      return new CombineFieldsByFields<>(
+          this,
+          SchemaAggregateFn.<InputT>create().aggregateFields(fieldsToAggregate, fn, outputField));
+    }
+
+    @Override
+    public PCollection<KV<Row, Iterable<InputT>>> expand(PCollection<InputT> input) {
+      Schema schema = input.getSchema();
+      FieldAccessDescriptor resolved = fieldAccessDescriptor.resolve(schema);
+      // The key schema is handed to the DoFn as a value rather than read from the mutable field at
+      // processing time, so re-applying this ByFields instance to another input cannot change the
+      // key schema used by this application.
+      Schema expandedKeySchema = Select.getOutputSchema(schema, resolved);
+      keySchema = expandedKeySchema;
+      return input
+          .apply(
+              "Group by fields",
+              ParDo.of(new ExtractKeyFn<InputT>(schema, resolved, expandedKeySchema)))
+          .setCoder(KvCoder.of(SchemaCoder.of(expandedKeySchema), input.getCoder()))
+          .apply(GroupByKey.create());
+    }
+
+    /**
+     * Pairs each element with a key {@link Row} holding the selected fields.
+     *
+     * <p>Declared static so the serialized DoFn carries only the input schema, the resolved field
+     * descriptor and the key schema, not a reference to the enclosing transform.
+     */
+    private static class ExtractKeyFn<T> extends DoFn<T, KV<Row, T>> {
+      private final Schema schema;
+      private final FieldAccessDescriptor resolved;
+      private final Schema keySchema;
+
+      ExtractKeyFn(Schema schema, FieldAccessDescriptor resolved, Schema keySchema) {
+        this.schema = schema;
+        this.resolved = resolved;
+        this.keySchema = keySchema;
+      }
+
+      @ProcessElement
+      public void process(@Element T element, @Element Row row, OutputReceiver<KV<Row, T>> o) {
+        o.output(KV.of(Select.selectRow(row, resolved, schema, keySchema), element));
+      }
+    }
+  }
+
+  /**
+   * A {@link PTransform} that does a per-key combine using a specified {@link CombineFn}.
+   *
+   * <p>The output of this transform is a {@literal KV<Row, OutputT>} where the key type is a
+   * {@link Row} containing the extracted fields.
+   */
+  public static class CombineByFields<InputT, OutputT>
+      extends PTransform<PCollection<InputT>, PCollection<KV<Row, OutputT>>> {
+    // The grouping step; it produces the key Row for each group.
+    private final ByFields<InputT> byFields;
+    // Applied to the Iterable of elements in each group.
+    private final CombineFn<InputT, ?, OutputT> combineFn;
+
+    CombineByFields(ByFields<InputT> byFields, CombineFn<InputT, ?, OutputT> 
combineFn) {
+      this.byFields = byFields;
+      this.combineFn = combineFn;
+    }
+
+    @Override
+    public PCollection<KV<Row, OutputT>> expand(PCollection<InputT> input) {
+      // Group by the key fields first, then combine each group's values with combineFn, keeping
+      // the key Row as the output key.
+      return input.apply(byFields).apply(Combine.groupedValues(combineFn));
+    }
+  }
+
+  /**
+   * A {@link PTransform} that does a per-key combine using an aggregation built up by calls to
+   * aggregateField and aggregateFields. The output of this transform will have a schema that is
+   * determined by the output types of all the composed combiners.
+   */
+  public static class CombineFieldsByFields<InputT>
+      extends PTransform<PCollection<InputT>, PCollection<KV<Row, Row>>> {
+    // Produces the per-key groups that the aggregation is applied to.
+    private final ByFields<InputT> byFields;
+    // The aggregation accumulated so far; bound to the input schema in expand().
+    private final SchemaAggregateFn.Inner<InputT> schemaAggregateFn;
+
+    CombineFieldsByFields(
+        ByFields<InputT> byFields, SchemaAggregateFn.Inner<InputT> schemaAggregateFn) {
+      this.byFields = byFields;
+      this.schemaAggregateFn = schemaAggregateFn;
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over a single field of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     *
+     * <p>Field types in the output schema will be inferred from the provided combine function.
+     * Sometimes the field type cannot be inferred due to Java's type erasure. In that case, use the
+     * overload that allows setting the output field type explicitly.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateField(
+        String inputFieldName,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        String outputFieldName) {
+      return aggregateFields(
+          FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputFieldName);
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over a single field of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     *
+     * <p>Use this overload to set the output field (and therefore its type) explicitly.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateField(
+        String inputFieldName,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        Field outputField) {
+      return aggregateFields(
+          FieldAccessDescriptor.withFieldNames(inputFieldName), fn, outputField);
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over multiple fields of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     *
+     * <p>Field types in the output schema will be inferred from the provided combine function.
+     * Sometimes the field type cannot be inferred due to Java's type erasure. In that case, use the
+     * overload that allows setting the output field type explicitly.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateFields(
+        List<String> inputFieldNames,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        String outputFieldName) {
+      return aggregateFields(
+          FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, outputFieldName);
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over multiple fields of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     *
+     * <p>Field types in the output schema will be inferred from the provided combine function.
+     * Sometimes the field type cannot be inferred due to Java's type erasure. In that case, use the
+     * overload that allows setting the output field type explicitly.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateFields(
+        FieldAccessDescriptor fieldsToAggregate,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        String outputFieldName) {
+      return new CombineFieldsByFields<>(
+          byFields, schemaAggregateFn.aggregateFields(fieldsToAggregate, fn, outputFieldName));
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over multiple fields of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     *
+     * <p>Use this overload to set the output field (and therefore its type) explicitly.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateFields(
+        List<String> inputFieldNames,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        Field outputField) {
+      return aggregateFields(
+          FieldAccessDescriptor.withFieldNames(inputFieldNames), fn, outputField);
+    }
+
+    /**
+     * Build up an aggregation function over the input elements.
+     *
+     * <p>This method specifies an aggregation over multiple fields of the input. The union of all
+     * calls to aggregateField and aggregateFields will determine the output schema.
+     *
+     * <p>Use this overload to set the output field (and therefore its type) explicitly.
+     */
+    public <CombineInputT, AccumT, CombineOutputT> CombineFieldsByFields<InputT> aggregateFields(
+        FieldAccessDescriptor fieldsToAggregate,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        Field outputField) {
+      return new CombineFieldsByFields<>(
+          byFields, schemaAggregateFn.aggregateFields(fieldsToAggregate, fn, outputField));
+    }
+
+    @Override
+    public PCollection<KV<Row, Row>> expand(PCollection<InputT> input) {
+      // The input schema is only known now, so bind the aggregation to it before combining.
+      SchemaAggregateFn.Inner<InputT> fn =
+          schemaAggregateFn.withSchema(input.getSchema(), input.getToRowFunction());
+      return input.apply(byFields).apply(Combine.groupedValues(fn));
+    }
+  }
+}
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/SchemaAggregateFn.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/SchemaAggregateFn.java
new file mode 100644
index 00000000000..32addcc4670
--- /dev/null
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/SchemaAggregateFn.java
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.sdk.schemas.transforms;
+
+import com.google.auto.value.AutoValue;
+import com.google.common.collect.Lists;
+import java.io.Serializable;
+import java.util.List;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.annotations.Experimental.Kind;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.RowCoder;
+import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
+import org.apache.beam.sdk.schemas.FieldTypeDescriptors;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.Field;
+import org.apache.beam.sdk.schemas.SchemaCoder;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.CombineFns;
+import org.apache.beam.sdk.transforms.CombineFns.CoCombineResult;
+import org.apache.beam.sdk.transforms.CombineFns.ComposedCombineFn;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.SerializableFunctions;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.TupleTag;
+
+/** This is the builder used by {@link Group} to build up a composed {@link CombineFn}. */
+@Experimental(Kind.SCHEMAS)
+class SchemaAggregateFn {
+  /** Returns an empty aggregation; field aggregations are added with {@code aggregateFields}. */
+  static <T> Inner<T> create() {
+    return new AutoValue_SchemaAggregateFn_Inner.Builder<T>()
+        .setFieldAggregations(Lists.newArrayList())
+        .build();
+  }
+
+  /**
+   * Implementation of {@link #create}.
+   *
+   * <p>{@code aggregateFields} and {@code withSchema} each return a new instance without modifying
+   * the receiver, so a partially-built aggregation can safely be extended in more than one way.
+   */
+  @AutoValue
+  abstract static class Inner<T> extends CombineFn<T, Object[], Row> {
+    /** Represents an aggregation of one or more input fields into a single output field. */
+    static class FieldAggregation<FieldT, AccumT, OutputT> implements Serializable {
+      // The input fields consumed by this aggregation; resolved once the input schema is known.
+      final FieldAccessDescriptor fieldsToAggregate;
+      // The specification of the output field.
+      private final Field outputField;
+      // The combine function.
+      private final CombineFn<FieldT, AccumT, OutputT> fn;
+      // The TupleTag identifying this aggregation element in the composed combine fn.
+      private final TupleTag<Object> combineTag;
+      // The schema corresponding to the subset of input fields being aggregated (null until
+      // resolved against an input schema).
+      @Nullable private final Schema inputSubSchema;
+      // The flattened version of inputSubSchema (null until resolved).
+      @Nullable private final Schema unnestedInputSubSchema;
+      // The output schema resulting from the aggregation.
+      private final Schema aggregationSchema;
+      // Whether a selected sub-row must be unnested before its single field is extracted.
+      private final boolean needsUnnesting;
+
+      FieldAggregation(
+          FieldAccessDescriptor fieldsToAggregate,
+          Field outputField,
+          CombineFn<FieldT, AccumT, OutputT> fn,
+          TupleTag<Object> combineTag) {
+        this(
+            fieldsToAggregate,
+            outputField,
+            fn,
+            combineTag,
+            Schema.builder().addField(outputField).build(),
+            null);
+      }
+
+      FieldAggregation(
+          FieldAccessDescriptor fieldsToAggregate,
+          Field outputField,
+          CombineFn<FieldT, AccumT, OutputT> fn,
+          TupleTag<Object> combineTag,
+          Schema aggregationSchema,
+          @Nullable Schema inputSchema) {
+        if (inputSchema != null) {
+          this.fieldsToAggregate = fieldsToAggregate.resolve(inputSchema);
+          this.inputSubSchema = Select.getOutputSchema(inputSchema, this.fieldsToAggregate);
+          this.unnestedInputSubSchema = Unnest.getUnnestedSchema(inputSubSchema);
+          // Compare the selected sub-schema (not the whole input schema) with its flattened form:
+          // unnesting is only needed when flattening actually changes the selected sub-schema.
+          this.needsUnnesting = !inputSubSchema.equals(unnestedInputSubSchema);
+        } else {
+          this.fieldsToAggregate = fieldsToAggregate;
+          this.inputSubSchema = null;
+          this.unnestedInputSubSchema = null;
+          this.needsUnnesting = false;
+        }
+        this.outputField = outputField;
+        this.fn = fn;
+        this.combineTag = combineTag;
+        this.aggregationSchema = aggregationSchema;
+      }
+
+      // The Schema is not necessarily known when the SchemaAggregateFn is created. Once the schema
+      // is known, resolve will be called with the proper schema.
+      FieldAggregation<FieldT, AccumT, OutputT> resolve(Schema schema) {
+        return new FieldAggregation<>(
+            fieldsToAggregate, outputField, fn, combineTag, aggregationSchema, schema);
+      }
+    }
+
+    abstract Builder<T> toBuilder();
+
+    @AutoValue.Builder
+    abstract static class Builder<T> {
+      abstract Builder<T> setInputSchema(@Nullable Schema inputSchema);
+
+      abstract Builder<T> setOutputSchema(@Nullable Schema outputSchema);
+
+      abstract Builder<T> setComposedCombineFn(@Nullable ComposedCombineFn<T> composedCombineFn);
+
+      abstract Builder<T> setFieldAggregations(List<FieldAggregation> fieldAggregations);
+
+      abstract Inner<T> build();
+    }
+
+    abstract @Nullable Schema getInputSchema();
+
+    abstract @Nullable Schema getOutputSchema();
+
+    abstract @Nullable ComposedCombineFn<T> getComposedCombineFn();
+
+    abstract List<FieldAggregation> getFieldAggregations();
+
+    /**
+     * Returns a copy of this aggregation bound to {@code inputSchema}: every field aggregation is
+     * resolved against the schema and all of them are composed into one {@link ComposedCombineFn}.
+     * Once the schema is known, this function is called by the {@link Group} transform.
+     *
+     * @param inputSchema the schema of the elements being aggregated
+     * @param toRowFunction converts an input element into a {@link Row} of {@code inputSchema}
+     */
+    Inner<T> withSchema(Schema inputSchema, SerializableFunction<T, Row> toRowFunction) {
+      List<FieldAggregation> fieldAggregations =
+          getFieldAggregations()
+              .stream()
+              .map(f -> f.resolve(inputSchema))
+              .collect(Collectors.toList());
+
+      ComposedCombineFn<T> composedCombineFn = null;
+      for (int i = 0; i < fieldAggregations.size(); ++i) {
+        FieldAggregation fieldAggregation = fieldAggregations.get(i);
+        SimpleFunction<T, ?> extractFunction;
+        Coder extractOutputCoder;
+        // A single (flattened) field is handed to the combiner as a bare value; otherwise the
+        // selected sub-row is handed over as a Row.
+        if (fieldAggregation.unnestedInputSubSchema.getFieldCount() == 1) {
+          extractFunction = new ExtractSingleFieldFunction<>(fieldAggregation, toRowFunction);
+          extractOutputCoder =
+              RowCoder.coderForFieldType(
+                  fieldAggregation.unnestedInputSubSchema.getField(0).getType());
+        } else {
+          extractFunction = new ExtractFieldsFunction<>(fieldAggregation, toRowFunction);
+          extractOutputCoder = RowCoder.of(fieldAggregation.inputSubSchema);
+        }
+        if (i == 0) {
+          composedCombineFn =
+              CombineFns.compose()
+                  .with(
+                      extractFunction,
+                      extractOutputCoder,
+                      fieldAggregation.fn,
+                      fieldAggregation.combineTag);
+        } else {
+          composedCombineFn =
+              composedCombineFn.with(
+                  extractFunction,
+                  extractOutputCoder,
+                  fieldAggregation.fn,
+                  fieldAggregation.combineTag);
+        }
+      }
+
+      return toBuilder()
+          .setInputSchema(inputSchema)
+          .setComposedCombineFn(composedCombineFn)
+          .setFieldAggregations(fieldAggregations)
+          .build();
+    }
+
+    /** Aggregate all values of a set of fields into an output field. */
+    <CombineInputT, AccumT, CombineOutputT> Inner<T> aggregateFields(
+        FieldAccessDescriptor fieldsToAggregate,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        String outputFieldName) {
+      return aggregateFields(
+          fieldsToAggregate,
+          fn,
+          Field.of(outputFieldName, FieldTypeDescriptors.fieldTypeForJavaType(fn.getOutputType())));
+    }
+
+    /** Aggregate all values of a set of fields into an output field. */
+    <CombineInputT, AccumT, CombineOutputT> Inner<T> aggregateFields(
+        FieldAccessDescriptor fieldsToAggregate,
+        CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
+        Field outputField) {
+      // Copy rather than mutate: the current list belongs to this instance, and adding to it in
+      // place would also alter this instance and any other Inner sharing the same list.
+      List<FieldAggregation> fieldAggregations = Lists.newArrayList(getFieldAggregations());
+      TupleTag<Object> combineTag = new TupleTag<>(Integer.toString(fieldAggregations.size()));
+      FieldAggregation fieldAggregation =
+          new FieldAggregation<>(fieldsToAggregate, outputField, fn, combineTag);
+      fieldAggregations.add(fieldAggregation);
+
+      return toBuilder()
+          .setOutputSchema(getOutputSchema(fieldAggregations))
+          .setFieldAggregations(fieldAggregations)
+          .build();
+    }
+
+    // The output schema has one field per registered aggregation, in registration order.
+    private Schema getOutputSchema(List<FieldAggregation> fieldAggregations) {
+      Schema.Builder outputSchema = Schema.builder();
+      for (FieldAggregation aggregation : fieldAggregations) {
+        outputSchema.addField(aggregation.outputField);
+      }
+      return outputSchema.build();
+    }
+
+    /** Extract a single field from an input {@link Row}. */
+    private static class ExtractSingleFieldFunction<InputT, OutputT>
+        extends SimpleFunction<InputT, OutputT> {
+      private final FieldAggregation fieldAggregation;
+      private final SerializableFunction<InputT, Row> toRowFunction;
+
+      private ExtractSingleFieldFunction(
+          FieldAggregation fieldAggregation, SerializableFunction<InputT, Row> toRowFunction) {
+        this.fieldAggregation = fieldAggregation;
+        this.toRowFunction = toRowFunction;
+      }
+
+      @Override
+      public OutputT apply(InputT input) {
+        Row row = toRowFunction.apply(input);
+        Row selected =
+            Select.selectRow(
+                row,
+                fieldAggregation.fieldsToAggregate,
+                row.getSchema(),
+                fieldAggregation.inputSubSchema);
+        if (fieldAggregation.needsUnnesting) {
+          selected = Unnest.unnestRow(selected, fieldAggregation.unnestedInputSubSchema);
+        }
+        return selected.getValue(0);
+      }
+    }
+
+    /** Extract multiple fields from an input {@link Row}. */
+    private static class ExtractFieldsFunction<T> extends SimpleFunction<T, Row> {
+      private final FieldAggregation fieldAggregation;
+      private final SerializableFunction<T, Row> toRowFunction;
+
+      private ExtractFieldsFunction(
+          FieldAggregation fieldAggregation, SerializableFunction<T, Row> toRowFunction) {
+        this.fieldAggregation = fieldAggregation;
+        this.toRowFunction = toRowFunction;
+      }
+
+      @Override
+      public Row apply(T input) {
+        Row row = toRowFunction.apply(input);
+        return Select.selectRow(
+            row,
+            fieldAggregation.fieldsToAggregate,
+            row.getSchema(),
+            fieldAggregation.inputSubSchema);
+      }
+    }
+
+    @Override
+    public Object[] createAccumulator() {
+      return getComposedCombineFn().createAccumulator();
+    }
+
+    @Override
+    public Object[] addInput(Object[] accumulator, T input) {
+      return getComposedCombineFn().addInput(accumulator, input);
+    }
+
+    @Override
+    public Object[] mergeAccumulators(Iterable<Object[]> accumulator) {
+      return getComposedCombineFn().mergeAccumulators(accumulator);
+    }
+
+    @Override
+    public Coder<Object[]> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder)
+        throws CannotProvideCoderException {
+      return getComposedCombineFn().getAccumulatorCoder(registry, inputCoder);
+    }
+
+    @Override
+    public Coder<Row> getDefaultOutputCoder(CoderRegistry registry, Coder<T> inputCoder) {
+      return SchemaCoder.of(
+          getOutputSchema(), SerializableFunctions.identity(), SerializableFunctions.identity());
+    }
+
+    @Override
+    public Row extractOutput(Object[] accumulator) {
+      // Build a row containing a field for every aggregate that was registered.
+      CoCombineResult coCombineResult = getComposedCombineFn().extractOutput(accumulator);
+      Row.Builder output = Row.withSchema(getOutputSchema());
+      for (FieldAggregation fieldAggregation : getFieldAggregations()) {
+        Object aggregate = coCombineResult.get(fieldAggregation.combineTag);
+        output.addValue(aggregate);
+      }
+      return output.build();
+    }
+  }
+}
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Select.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Select.java
index a5024ef7400..c8ead62cf4d 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Select.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Select.java
@@ -61,7 +61,7 @@
  *
  * <pre>{@code
  * PCollection<UserEvent> events = readUserEvents();
- * PCollection<Row> rows = event.apply(Select.fieldNameFilters("userId", 
"eventId"));
+ * PCollection<Row> rows = events.apply(Select.fieldNames("userId", "eventId"));
  * }</pre>
  *
  * It's possible to select a nested field as well. For example, if you want 
just the location
@@ -132,6 +132,8 @@ public void process(
     return selected;
   }
 
+  // Currently we don't flatten selected nested fields. We should consider 
whether to flatten them
+  // or leave them as is.
   static Schema getOutputSchema(Schema inputSchema, FieldAccessDescriptor 
fieldAccessDescriptor) {
     if (fieldAccessDescriptor.allFields()) {
       return inputSchema;
@@ -140,12 +142,14 @@ static Schema getOutputSchema(Schema inputSchema, 
FieldAccessDescriptor fieldAcc
     for (int fieldId : fieldAccessDescriptor.fieldIdsAccessed()) {
       builder.addField(inputSchema.getField(fieldId));
     }
+
     for (Map.Entry<Integer, FieldAccessDescriptor> nested :
         fieldAccessDescriptor.nestedFields().entrySet()) {
       Field field = inputSchema.getField(nested.getKey());
       FieldAccessDescriptor nestedDescriptor = nested.getValue();
       FieldType nestedType =
           FieldType.row(getOutputSchema(field.getType().getRowSchema(), 
nestedDescriptor));
+
       if (field.getNullable()) {
         builder.addNullableField(field.getName(), nestedType);
       } else {
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java
index 20314fcba3c..cff3f6874a8 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java
@@ -407,6 +407,17 @@ public OutputT defaultValue() {
     public TypeDescriptor<OutputT> getOutputType() {
       return new TypeDescriptor<OutputT>(getClass()) {};
     }
+
+    /**
+     * Returns a {@link TypeDescriptor} capturing what is known statically about the input type of
+     * this {@code CombineFn} instance's most-derived class.
+     *
+     * <p>In the normal case of a concrete {@code CombineFn} subclass with no generic type
+     * parameters of its own, this will be a complete non-generic type.
+     *
+     * <p>This mirrors {@link #getOutputType()}: the anonymous {@code TypeDescriptor} subclass
+     * resolves {@code InputT} against {@code getClass()}, so a subclass that leaves {@code InputT}
+     * generic may get back a descriptor that still contains a type variable.
+     *
+     * @return a descriptor for this combine function's input type
+     */
+    public TypeDescriptor<InputT> getInputType() {
+      return new TypeDescriptor<InputT>(getClass()) {};
+    }
   }
 
   /////////////////////////////////////////////////////////////////////////////
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java
index bd2eec1371b..532cbfac775 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java
@@ -321,7 +321,7 @@ public boolean hasSchema() {
     return getCoder() instanceof SchemaCoder;
   }
 
-  /** Returns the attached schema, or null if there is none. */
+  /** Returns the attached schema. */
   @Experimental(Kind.SCHEMAS)
   public Schema getSchema() {
     if (!hasSchema()) {
@@ -330,6 +330,24 @@ public Schema getSchema() {
     return ((SchemaCoder) getCoder()).getSchema();
   }
 
+  /**
+   * Returns the attached schema's toRowFunction.
+   *
+   * @throws IllegalStateException if this collection has no schema attached
+   */
+  @Experimental(Kind.SCHEMAS)
+  public SerializableFunction<T, Row> getToRowFunction() {
+    return schemaCoderOrThrow("getToRowFunction").getToRowFunction();
+  }
+
+  /**
+   * Returns the attached schema's fromRowFunction.
+   *
+   * @throws IllegalStateException if this collection has no schema attached
+   */
+  @Experimental(Kind.SCHEMAS)
+  public SerializableFunction<Row, T> getFromRowFunction() {
+    return schemaCoderOrThrow("getFromRowFunction").getFromRowFunction();
+  }
+
+  /**
+   * Returns this collection's coder as a {@link SchemaCoder}.
+   *
+   * @param callerName name of the public accessor, used in the error message
+   * @throws IllegalStateException if this collection has no schema attached
+   */
+  // Safe: hasSchema() checks that the coder is a SchemaCoder, and a PCollection<T>'s coder
+  // encodes T.
+  @SuppressWarnings("unchecked")
+  private SchemaCoder<T> schemaCoderOrThrow(String callerName) {
+    if (!hasSchema()) {
+      throw new IllegalStateException("Cannot call " + callerName + " when there is no schema");
+    }
+    return (SchemaCoder<T>) getCoder();
+  }
+
   /**
    * of the {@link PTransform}.
    *
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java
 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java
new file mode 100644
index 00000000000..b37b4ba998c
--- /dev/null
+++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java
@@ -0,0 +1,646 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.sdk.schemas.transforms;
+
+import static org.apache.beam.sdk.TestUtils.KvMatcher.isKv;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static 
org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Objects;
+import org.apache.beam.sdk.schemas.DefaultSchema;
+import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
+import org.apache.beam.sdk.schemas.JavaFieldSchema;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.Flatten;
+import org.apache.beam.sdk.transforms.Keys;
+import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.transforms.Top;
+import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+import org.hamcrest.Matcher;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+
+/** Test for {@link Group}. */
+public class GroupTest implements Serializable {
+  @Rule public final transient TestPipeline pipeline = TestPipeline.create();
+
+  /**
+   * A simple POJO used as test input. Its schema is inferred from the public fields via {@link
+   * JavaFieldSchema}.
+   */
+  @DefaultSchema(JavaFieldSchema.class)
+  public static class POJO implements Serializable {
+    public String field1;
+    public long field2;
+    public String field3;
+
+    public POJO(String field1, long field2, String field3) {
+      this.field1 = field1;
+      this.field2 = field2;
+      this.field3 = field3;
+    }
+
+    // No-arg constructor; presumably needed by the schema machinery to instantiate POJOs — confirm.
+    public POJO() {}
+
+    @Override
+    public boolean equals(Object o) {
+      if (o == this) {
+        return true;
+      }
+      if (o == null || o.getClass() != getClass()) {
+        return false;
+      }
+      POJO other = (POJO) o;
+      return Objects.equals(field1, other.field1)
+          && field2 == other.field2
+          && Objects.equals(field3, other.field3);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(field1, field2, field3);
+    }
+
+    @Override
+    public String toString() {
+      return "POJO{field1='" + field1 + "', field2=" + field2 + ", field3='" + field3 + "'}";
+    }
+  }
+
+  /** Schema with the same field names and types as the public fields of {@link POJO}. */
+  private static final Schema POJO_SCHEMA =
+      Schema.of(
+          Schema.Field.of("field1", FieldType.STRING),
+          Schema.Field.of("field2", FieldType.INT64),
+          Schema.Field.of("field3", FieldType.STRING));
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testGroupByOneField() {
+    // Two elements share each value of field1; grouping on field1 alone yields two groups.
+    List<POJO> input =
+        ImmutableList.of(
+            new POJO("key1", 1, "value1"),
+            new POJO("key1", 2, "value2"),
+            new POJO("key2", 3, "value3"),
+            new POJO("key2", 4, "value4"));
+    PCollection<KV<Row, Iterable<POJO>>> grouped =
+        pipeline.apply(Create.of(input)).apply(Group.byFieldNames("field1"));
+
+    Schema keySchema = Schema.builder().addStringField("field1").build();
+    Row key1 = Row.withSchema(keySchema).addValue("key1").build();
+    Row key2 = Row.withSchema(keySchema).addValue("key2").build();
+    KV<Row, Collection<POJO>> group1 =
+        KV.of(key1, ImmutableList.of(new POJO("key1", 1L, "value1"), new POJO("key1", 2L, "value2")));
+    KV<Row, Collection<POJO>> group2 =
+        KV.of(key2, ImmutableList.of(new POJO("key2", 3L, "value3"), new POJO("key2", 4L, "value4")));
+    List<KV<Row, Collection<POJO>>> expected = ImmutableList.of(group1, group2);
+
+    PAssert.that(grouped).satisfies(actual -> containsKIterableVs(expected, actual, new POJO[0]));
+    pipeline.run();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testGroupByMultiple() {
+    // Group on the (field1, field2) pair; field3 is the only field that differs within a group.
+    List<POJO> input =
+        ImmutableList.of(
+            new POJO("key1", 1, "value1"),
+            new POJO("key1", 1, "value2"),
+            new POJO("key2", 2, "value3"),
+            new POJO("key2", 2, "value4"));
+    PCollection<KV<Row, Iterable<POJO>>> grouped =
+        pipeline.apply(Create.of(input)).apply(Group.byFieldNames("field1", "field2"));
+
+    Schema keySchema = Schema.builder().addStringField("field1").addInt64Field("field2").build();
+    Row key1 = Row.withSchema(keySchema).addValues("key1", 1L).build();
+    Row key2 = Row.withSchema(keySchema).addValues("key2", 2L).build();
+    KV<Row, Collection<POJO>> group1 =
+        KV.of(key1, ImmutableList.of(new POJO("key1", 1L, "value1"), new POJO("key1", 1L, "value2")));
+    KV<Row, Collection<POJO>> group2 =
+        KV.of(key2, ImmutableList.of(new POJO("key2", 2L, "value3"), new POJO("key2", 2L, "value4")));
+    List<KV<Row, Collection<POJO>>> expected = ImmutableList.of(group1, group2);
+
+    PAssert.that(grouped).satisfies(actual -> containsKIterableVs(expected, actual, new POJO[0]));
+    pipeline.run();
+  }
+
+  /** A wrapper around {@link POJO} used to test grouping on nested key fields. */
+  @DefaultSchema(JavaFieldSchema.class)
+  public static class OuterPOJO implements Serializable {
+    public POJO inner;
+
+    public OuterPOJO(POJO inner) {
+      this.inner = inner;
+    }
+
+    // No-arg constructor; presumably needed by the schema machinery to instantiate — confirm.
+    public OuterPOJO() {}
+
+    @Override
+    public boolean equals(Object o) {
+      if (o == this) {
+        return true;
+      }
+      if (o == null || o.getClass() != getClass()) {
+        return false;
+      }
+      return Objects.equals(inner, ((OuterPOJO) o).inner);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(inner);
+    }
+
+    @Override
+    public String toString() {
+      return "OuterPOJO{inner=" + inner + "}";
+    }
+  }
+
+  /** Test grouping by a set of fields that are nested. */
+  @Test
+  @Category(NeedsRunner.class)
+  public void testGroupByNestedKey() {
+    // Group on inner.field1 and inner.field2; inner.field3 differs within each group.
+    FieldAccessDescriptor groupKeys =
+        FieldAccessDescriptor.create()
+            .withNestedField("inner", FieldAccessDescriptor.withFieldNames("field1", "field2"));
+    List<OuterPOJO> input =
+        ImmutableList.of(
+            new OuterPOJO(new POJO("key1", 1L, "value1")),
+            new OuterPOJO(new POJO("key1", 1L, "value2")),
+            new OuterPOJO(new POJO("key2", 2L, "value3")),
+            new OuterPOJO(new POJO("key2", 2L, "value4")));
+    PCollection<KV<Row, Iterable<OuterPOJO>>> grouped =
+        pipeline.apply(Create.of(input)).apply(Group.byFieldAccessDescriptor(groupKeys));
+
+    // The key keeps the nesting: a single "inner" row field holding the selected fields.
+    Schema selectedSchema =
+        Schema.builder().addStringField("field1").addInt64Field("field2").build();
+    Schema keySchema = Schema.builder().addRowField("inner", selectedSchema).build();
+    Row innerKey1 = Row.withSchema(selectedSchema).addValues("key1", 1L).build();
+    Row innerKey2 = Row.withSchema(selectedSchema).addValues("key2", 2L).build();
+    Row key1 = Row.withSchema(keySchema).addValue(innerKey1).build();
+    Row key2 = Row.withSchema(keySchema).addValue(innerKey2).build();
+    KV<Row, Collection<OuterPOJO>> group1 =
+        KV.of(
+            key1,
+            ImmutableList.of(
+                new OuterPOJO(new POJO("key1", 1L, "value1")),
+                new OuterPOJO(new POJO("key1", 1L, "value2"))));
+    KV<Row, Collection<OuterPOJO>> group2 =
+        KV.of(
+            key2,
+            ImmutableList.of(
+                new OuterPOJO(new POJO("key2", 2L, "value3")),
+                new OuterPOJO(new POJO("key2", 2L, "value4"))));
+    List<KV<Row, Collection<OuterPOJO>>> expected = ImmutableList.of(group1, group2);
+
+    PAssert.that(grouped)
+        .satisfies(actual -> containsKIterableVs(expected, actual, new OuterPOJO[0]));
+    pipeline.run();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testGroupGlobally() {
+    Collection<POJO> elements =
+        ImmutableList.of(
+            new POJO("key1", 1, "value1"),
+            new POJO("key1", 1, "value2"),
+            new POJO("key2", 2, "value3"),
+            new POJO("key2", 2, "value4"));
+
+    PCollection<Iterable<POJO>> grouped =
+        pipeline.apply(Create.of(elements)).apply(Group.globally());
+    PAssert.that(grouped).satisfies(actual -> containsSingleIterable(elements, 
actual));
+    pipeline.run();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testGlobalAggregation() {
+    Collection<POJO> elements =
+        ImmutableList.of(
+            new POJO("key1", 1, "value1"),
+            new POJO("key1", 1, "value2"),
+            new POJO("key2", 2, "value3"),
+            new POJO("key2", 2, "value4"));
+    PCollection<Long> count =
+        pipeline
+            .apply(Create.of(elements))
+            .apply(Group.<POJO>globally().aggregate(Count.combineFn()));
+    PAssert.that(count).containsInAnyOrder(4L);
+
+    pipeline.run();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testPerKeyAggregation() {
+    Collection<POJO> elements =
+        ImmutableList.of(
+            new POJO("key1", 1, "value1"),
+            new POJO("key1", 1, "value2"),
+            new POJO("key2", 2, "value3"),
+            new POJO("key2", 2, "value4"),
+            new POJO("key2", 2, "value4"));
+    PCollection<KV<Row, Long>> count =
+        pipeline
+            .apply(Create.of(elements))
+            
.apply(Group.<POJO>byFieldNames("field1").aggregate(Count.combineFn()));
+
+    Schema keySchema = Schema.builder().addStringField("field1").build();
+
+    Collection<KV<Row, Long>> expectedCounts =
+        ImmutableList.of(
+            KV.of(Row.withSchema(keySchema).addValue("key1").build(), 2L),
+            KV.of(Row.withSchema(keySchema).addValue("key2").build(), 3L));
+    PAssert.that(count).containsInAnyOrder(expectedCounts);
+
+    pipeline.run();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testOutputCoders() {
+    Schema keySchema = Schema.builder().addStringField("field1").build();
+
+    PCollection<KV<Row, Iterable<POJO>>> grouped =
+        pipeline
+            .apply(Create.of(new POJO("key1", 1, "value1")))
+            .apply(Group.byFieldNames("field1"));
+
+    // Make sure that the key has the right schema.
+    PCollection<Row> keys = grouped.apply(Keys.create());
+    assertTrue(keys.getSchema().equivalent(keySchema));
+
+    // Make sure that the value has the right schema.
+    PCollection<POJO> values = 
grouped.apply(Values.create()).apply(Flatten.iterables());
+    assertTrue(values.getSchema().equivalent(POJO_SCHEMA));
+    pipeline.run();
+  }
+
+  /** A class for testing field aggregation. */
+  @DefaultSchema(JavaFieldSchema.class)
+  public static class AggregatePojos implements Serializable {
+    public long field1;
+    public long field2;
+    public int field3;
+
+    public AggregatePojos(long field1, long field2, int field3) {
+      this.field1 = field1;
+      this.field2 = field2;
+      this.field3 = field3;
+    }
+
+    public AggregatePojos() {}
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      AggregatePojos agg = (AggregatePojos) o;
+      return field1 == agg.field1 && field2 == agg.field2 && field3 == 
agg.field3;
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(field1, field2, field3);
+    }
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testByKeyWithSchemaAggregateFn() {
+    Collection<AggregatePojos> elements =
+        ImmutableList.of(
+            new AggregatePojos(1, 1, 2),
+            new AggregatePojos(2, 1, 3),
+            new AggregatePojos(3, 2, 4),
+            new AggregatePojos(4, 2, 5));
+
+    PCollection<KV<Row, Row>> aggregations =
+        pipeline
+            .apply(Create.of(elements))
+            .apply(
+                Group.<AggregatePojos>byFieldNames("field2")
+                    .aggregateField("field1", Sum.ofLongs(), "field1_sum")
+                    .aggregateField("field3", Sum.ofIntegers(), "field3_sum")
+                    .aggregateField("field1", Top.largestLongsFn(1), 
"field1_top"));
+
+    Schema keySchema = Schema.builder().addInt64Field("field2").build();
+    Schema valueSchema =
+        Schema.builder()
+            .addInt64Field("field1_sum")
+            .addInt32Field("field3_sum")
+            .addArrayField("field1_top", FieldType.INT64)
+            .build();
+
+    List<KV<Row, Row>> expected =
+        ImmutableList.of(
+            KV.of(
+                Row.withSchema(keySchema).addValue(1L).build(),
+                
Row.withSchema(valueSchema).addValue(3L).addValue(5).addArray(2L).build()),
+            KV.of(
+                Row.withSchema(keySchema).addValue(2L).build(),
+                
Row.withSchema(valueSchema).addValue(7L).addValue(9).addArray(4L).build()));
+    PAssert.that(aggregations).satisfies(actual -> containsKvs(expected, 
actual));
+
+    pipeline.run();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testGloballyWithSchemaAggregateFn() {
+    Collection<AggregatePojos> elements =
+        ImmutableList.of(
+            new AggregatePojos(1, 1, 2),
+            new AggregatePojos(2, 1, 3),
+            new AggregatePojos(3, 2, 4),
+            new AggregatePojos(4, 2, 5));
+
+    PCollection<Row> aggregate =
+        pipeline
+            .apply(Create.of(elements))
+            .apply(
+                Group.<AggregatePojos>globally()
+                    .aggregateField("field1", Sum.ofLongs(), "field1_sum")
+                    .aggregateField("field3", Sum.ofIntegers(), "field3_sum")
+                    .aggregateField("field1", Top.largestLongsFn(1), 
"field1_top"));
+
+    Schema aggregateSchema =
+        Schema.builder()
+            .addInt64Field("field1_sum")
+            .addInt32Field("field3_sum")
+            .addArrayField("field1_top", FieldType.INT64)
+            .build();
+    Row expectedRow = Row.withSchema(aggregateSchema).addValues(10L, 
14).addArray(4L).build();
+    PAssert.that(aggregate).containsInAnyOrder(expectedRow);
+
+    pipeline.run();
+  }
+
+  /** A combine function that adds all long fields in the Row. */
+  public static class MultipleFieldCombineFn extends CombineFn<Row, long[], 
Long> {
+    @Override
+    public long[] createAccumulator() {
+      return new long[] {0};
+    }
+
+    @Override
+    public long[] addInput(long[] accumulator, Row input) {
+      for (Object o : input.getValues()) {
+        if (o instanceof Long) {
+          accumulator[0] += (Long) o;
+        }
+      }
+      return accumulator;
+    }
+
+    @Override
+    public long[] mergeAccumulators(Iterable<long[]> accumulators) {
+      Iterator<long[]> iter = accumulators.iterator();
+      if (!iter.hasNext()) {
+        return createAccumulator();
+      } else {
+        long[] running = iter.next();
+        while (iter.hasNext()) {
+          running[0] += iter.next()[0];
+        }
+        return running;
+      }
+    }
+
+    @Override
+    public Long extractOutput(long[] accumulator) {
+      return accumulator[0];
+    }
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testAggregateByMultipleFields() {
+    Collection<AggregatePojos> elements =
+        ImmutableList.of(
+            new AggregatePojos(1, 1, 2),
+            new AggregatePojos(2, 1, 3),
+            new AggregatePojos(3, 2, 4),
+            new AggregatePojos(4, 2, 5));
+
+    List<String> fieldNames = Lists.newArrayList("field1", "field2");
+    PCollection<Row> aggregate =
+        pipeline
+            .apply(Create.of(elements))
+            .apply(
+                Group.<AggregatePojos>globally()
+                    .aggregateFields(fieldNames, new MultipleFieldCombineFn(), 
"field1+field2"));
+
+    Schema outputSchema = 
Schema.builder().addInt64Field("field1+field2").build();
+    Row expectedRow = Row.withSchema(outputSchema).addValues(16L).build();
+    PAssert.that(aggregate).containsInAnyOrder(expectedRow);
+
+    pipeline.run();
+  }
+
+  /** A class for testing nested aggregation. */
+  @DefaultSchema(JavaFieldSchema.class)
+  public static class OuterAggregate implements Serializable {
+    public AggregatePojos inner;
+
+    public OuterAggregate(AggregatePojos inner) {
+      this.inner = inner;
+    }
+
+    public OuterAggregate() {}
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      OuterAggregate that = (OuterAggregate) o;
+      return Objects.equals(inner, that.inner);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(inner);
+    }
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testByKeyWithSchemaAggregateFnNestedFields() {
+    Collection<OuterAggregate> elements =
+        ImmutableList.of(
+            new OuterAggregate(new AggregatePojos(1, 1, 2)),
+            new OuterAggregate(new AggregatePojos(2, 1, 3)),
+            new OuterAggregate(new AggregatePojos(3, 2, 4)),
+            new OuterAggregate(new AggregatePojos(4, 2, 5)));
+
+    FieldAccessDescriptor field1Selector =
+        FieldAccessDescriptor.create()
+            .withNestedField("inner", 
FieldAccessDescriptor.withFieldNames("field1"));
+    FieldAccessDescriptor field2Selector =
+        FieldAccessDescriptor.create()
+            .withNestedField("inner", 
FieldAccessDescriptor.withFieldNames("field2"));
+    FieldAccessDescriptor field3Selector =
+        FieldAccessDescriptor.create()
+            .withNestedField("inner", 
FieldAccessDescriptor.withFieldNames("field3"));
+
+    PCollection<KV<Row, Row>> aggregations =
+        pipeline
+            .apply(Create.of(elements))
+            .apply(
+                Group.<OuterAggregate>byFieldAccessDescriptor(field2Selector)
+                    .aggregateFields(field1Selector, Sum.ofLongs(), 
"field1_sum")
+                    .aggregateFields(field3Selector, Sum.ofIntegers(), 
"field3_sum")
+                    .aggregateFields(field1Selector, Top.largestLongsFn(1), 
"field1_top"));
+
+    Schema innerKeySchema = Schema.builder().addInt64Field("field2").build();
+    Schema keySchema = Schema.builder().addRowField("inner", 
innerKeySchema).build();
+    Schema valueSchema =
+        Schema.builder()
+            .addInt64Field("field1_sum")
+            .addInt32Field("field3_sum")
+            .addArrayField("field1_top", FieldType.INT64)
+            .build();
+
+    List<KV<Row, Row>> expected =
+        ImmutableList.of(
+            KV.of(
+                Row.withSchema(keySchema)
+                    
.addValue(Row.withSchema(innerKeySchema).addValue(1L).build())
+                    .build(),
+                
Row.withSchema(valueSchema).addValue(3L).addValue(5).addArray(2L).build()),
+            KV.of(
+                Row.withSchema(keySchema)
+                    
.addValue(Row.withSchema(innerKeySchema).addValue(2L).build())
+                    .build(),
+                
Row.withSchema(valueSchema).addValue(7L).addValue(9).addArray(4L).build()));
+    PAssert.that(aggregations).satisfies(actual -> containsKvs(expected, 
actual));
+
+    pipeline.run();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testGloballyWithSchemaAggregateFnNestedFields() {
+    Collection<OuterAggregate> elements =
+        ImmutableList.of(
+            new OuterAggregate(new AggregatePojos(1, 1, 2)),
+            new OuterAggregate(new AggregatePojos(2, 1, 3)),
+            new OuterAggregate(new AggregatePojos(3, 2, 4)),
+            new OuterAggregate(new AggregatePojos(4, 2, 5)));
+
+    FieldAccessDescriptor field1Selector =
+        FieldAccessDescriptor.create()
+            .withNestedField("inner", 
FieldAccessDescriptor.withFieldNames("field1"));
+    FieldAccessDescriptor field3Selector =
+        FieldAccessDescriptor.create()
+            .withNestedField("inner", 
FieldAccessDescriptor.withFieldNames("field3"));
+
+    PCollection<Row> aggregate =
+        pipeline
+            .apply(Create.of(elements))
+            .apply(
+                Group.<OuterAggregate>globally()
+                    .aggregateFields(field1Selector, Sum.ofLongs(), 
"field1_sum")
+                    .aggregateFields(field3Selector, Sum.ofIntegers(), 
"field3_sum")
+                    .aggregateFields(field1Selector, Top.largestLongsFn(1), 
"field1_top"));
+    Schema aggregateSchema =
+        Schema.builder()
+            .addInt64Field("field1_sum")
+            .addInt32Field("field3_sum")
+            .addArrayField("field1_top", FieldType.INT64)
+            .build();
+    Row expectedRow = Row.withSchema(aggregateSchema).addValues(10L, 
14).addArray(4L).build();
+    PAssert.that(aggregate).containsInAnyOrder(expectedRow);
+
+    pipeline.run();
+  }
+
+  private static <T> Void containsKIterableVs(
+      List<KV<Row, Collection<T>>> expectedKvs,
+      Iterable<KV<Row, Iterable<T>>> actualKvs,
+      T[] emptyArray) {
+    List<KV<Row, Iterable<T>>> list = Lists.newArrayList(actualKvs);
+    List<Matcher<? super KV<Row, Iterable<POJO>>>> matchers = new 
ArrayList<>();
+    for (KV<Row, Collection<T>> expected : expectedKvs) {
+      T[] values = expected.getValue().toArray(emptyArray);
+      matchers.add(isKv(equalTo(expected.getKey()), 
containsInAnyOrder(values)));
+    }
+    assertThat(actualKvs, containsInAnyOrder(matchers.toArray(new 
Matcher[0])));
+    return null;
+  }
+
+  private static <T> Void containsKvs(
+      List<KV<Row, Row>> expectedKvs, Iterable<KV<Row, Row>> actualKvs) {
+    List<Matcher<? super KV<Row, Iterable<POJO>>>> matchers = new 
ArrayList<>();
+    for (KV<Row, Row> expected : expectedKvs) {
+      matchers.add(isKv(equalTo(expected.getKey()), 
equalTo(expected.getValue())));
+    }
+    assertThat(actualKvs, containsInAnyOrder(matchers.toArray(new 
Matcher[0])));
+    return null;
+  }
+
+  private static Void containsSingleIterable(
+      Collection<POJO> expected, Iterable<Iterable<POJO>> actual) {
+    POJO[] values = expected.toArray(new POJO[0]);
+    assertThat(actual, containsInAnyOrder(containsInAnyOrder(values)));
+    return null;
+  }
+}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/SelectTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/SelectTest.java
index 3b6aed68497..25109645e3b 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/SelectTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/SelectTest.java
@@ -150,7 +150,6 @@ public boolean equals(Object o) {
 
     @Override
     public int hashCode() {
-
       return Objects.hash(field2);
     }
   }
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
index 53fdee894b9..0a143690a9d 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
@@ -294,8 +294,7 @@ private Object getAggregatorOutput(AggregationAccumulator accumulator, int idx)
         if (sourceFieldExps.get(idx) instanceof Integer) {
           int srcFieldIndex = (Integer) sourceFieldExps.get(idx);
           Coder srcFieldCoder =
-              RowCoder.coderForPrimitiveType(
-                  
sourceSchema.getField(srcFieldIndex).getType().getTypeName());
+              
RowCoder.coderForFieldType(sourceSchema.getField(srcFieldIndex).getType());
           
aggAccuCoderList.add(aggregators.get(idx).getAccumulatorCoder(registry, 
srcFieldCoder));
         } else if (sourceFieldExps.get(idx) instanceof KV) {
           // extract coder of two expressions separately.
@@ -305,11 +304,9 @@ private Object getAggregatorOutput(AggregationAccumulator accumulator, int idx)
           int srcFieldIndexValue = exp.getValue();
 
           Coder srcFieldCoderKey =
-              RowCoder.coderForPrimitiveType(
-                  
sourceSchema.getField(srcFieldIndexKey).getType().getTypeName());
+              
RowCoder.coderForFieldType(sourceSchema.getField(srcFieldIndexKey).getType());
           Coder srcFieldCoderValue =
-              RowCoder.coderForPrimitiveType(
-                  
sourceSchema.getField(srcFieldIndexValue).getType().getTypeName());
+              
RowCoder.coderForFieldType(sourceSchema.getField(srcFieldIndexValue).getType());
 
           aggAccuCoderList.add(
               aggregators


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


Issue Time Tracking
-------------------

    Worklog Id:     (was: 152156)
    Time Spent: 13h  (was: 12h 50m)

> Create a library of useful transforms that use schemas
> ------------------------------------------------------
>
>                 Key: BEAM-4461
>                 URL: https://issues.apache.org/jira/browse/BEAM-4461
>             Project: Beam
>          Issue Type: Sub-task
>          Components: sdk-java-core
>            Reporter: Reuven Lax
>            Assignee: Reuven Lax
>            Priority: Major
>          Time Spent: 13h
>  Remaining Estimate: 0h
>
> e.g. JoinBy(fields). Project, Filter, etc.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to