Repository: beam
Updated Branches:
  refs/heads/DSL_SQL 25fea4e1e -> d89d1ee1a


support UDF/UDAF in BeamSql


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/5ca54e95
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/5ca54e95
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/5ca54e95

Branch: refs/heads/DSL_SQL
Commit: 5ca54e956e80f3059a9e67bf9b3d34af08569ff1
Parents: 25fea4e
Author: mingmxu <[email protected]>
Authored: Sun Jul 2 21:24:07 2017 -0700
Committer: Tyler Akidau <[email protected]>
Committed: Wed Jul 12 15:54:03 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/beam/dsls/sql/BeamSql.java  | 114 ++++++++++-----
 .../org/apache/beam/dsls/sql/BeamSqlEnv.java    |   6 +-
 .../beam/dsls/sql/BeamSqlDslUdfUdafTest.java    | 137 +++++++++++++++++++
 3 files changed, 221 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/5ca54e95/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSql.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSql.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSql.java
index a0e7cbc..ec3799c 100644
--- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSql.java
+++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSql.java
@@ -17,10 +17,12 @@
  */
 package org.apache.beam.dsls.sql;
 
+import com.google.auto.value.AutoValue;
 import org.apache.beam.dsls.sql.rel.BeamRelNode;
 import org.apache.beam.dsls.sql.schema.BeamPCollectionTable;
 import org.apache.beam.dsls.sql.schema.BeamSqlRow;
 import org.apache.beam.dsls.sql.schema.BeamSqlRowCoder;
+import org.apache.beam.dsls.sql.schema.BeamSqlUdaf;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.values.PCollection;
@@ -51,7 +53,9 @@ PCollection<BeamSqlRow> inputTableB = p.apply(TextIO.read().from("/my/input/path
 
 //run a simple query, and register the output as a table in BeamSql;
 String sql1 = "select MY_FUNC(c1), c2 from PCOLLECTION";
-PCollection<BeamSqlRow> outputTableA = 
inputTableA.apply(BeamSql.simpleQuery(sql1));
+PCollection<BeamSqlRow> outputTableA = inputTableA.apply(
+    BeamSql.simpleQuery(sql1)
+    .withUdf("MY_FUNC", MY_FUNC.class, "FUNC"));
 
 //run a JOIN with one table from TextIO, and one table from another query
 PCollection<BeamSqlRow> outputTableB = PCollectionTuple.of(
@@ -60,7 +64,7 @@ PCollection<BeamSqlRow> outputTableB = PCollectionTuple.of(
     .apply(BeamSql.query("select * from TABLE_O_A JOIN TABLE_B where ..."));
 
 //output the final result with TextIO
-outputTableB.apply(BeamSql.toTextRow()).apply(TextIO.write().to("/my/output/path"));
+outputTableB.apply(...).apply(TextIO.write().to("/my/output/path"));
 
 p.run().waitUntilFinish();
  * }
@@ -68,7 +72,6 @@ p.run().waitUntilFinish();
  */
 @Experimental
 public class BeamSql {
-
   /**
    * Transforms a SQL query into a {@link PTransform} representing an equivalent execution plan.
    *
@@ -80,9 +83,11 @@ public class BeamSql {
    * <p>It is an error to apply a {@link PCollectionTuple} missing any {@code table names}
    * referenced within the query.
    */
-  public static PTransform<PCollectionTuple, PCollection<BeamSqlRow>> 
query(String sqlQuery) {
-    return new QueryTransform(sqlQuery);
-
+  public static QueryTransform query(String sqlQuery) {
+    return QueryTransform.builder()
+        .setSqlEnv(new BeamSqlEnv())
+        .setSqlQuery(sqlQuery)
+        .build();
   }
 
   /**
@@ -93,42 +98,62 @@ public class BeamSql {
    *
    * <p>Make sure to query it from a static table name <em>PCOLLECTION</em>.
    */
-  public static PTransform<PCollection<BeamSqlRow>, PCollection<BeamSqlRow>>
-  simpleQuery(String sqlQuery) throws Exception {
-    return new SimpleQueryTransform(sqlQuery);
+  public static SimpleQueryTransform simpleQuery(String sqlQuery) throws 
Exception {
+    return SimpleQueryTransform.builder()
+        .setSqlEnv(new BeamSqlEnv())
+        .setSqlQuery(sqlQuery)
+        .build();
   }
 
   /**
    * A {@link PTransform} representing an execution plan for a SQL query.
    */
-  private static class QueryTransform extends
+  @AutoValue
+  public abstract static class QueryTransform extends
       PTransform<PCollectionTuple, PCollection<BeamSqlRow>> {
-    private transient BeamSqlEnv sqlEnv;
-    private String sqlQuery;
+    abstract BeamSqlEnv getSqlEnv();
+    abstract String getSqlQuery();
 
-    public QueryTransform(String sqlQuery) {
-      this.sqlQuery = sqlQuery;
-      sqlEnv = new BeamSqlEnv();
+    static Builder builder() {
+      return new AutoValue_BeamSql_QueryTransform.Builder();
     }
 
-    public QueryTransform(String sqlQuery, BeamSqlEnv sqlEnv) {
-      this.sqlQuery = sqlQuery;
-      this.sqlEnv = sqlEnv;
+    @AutoValue.Builder
+    abstract static class Builder {
+      abstract Builder setSqlQuery(String sqlQuery);
+      abstract Builder setSqlEnv(BeamSqlEnv sqlEnv);
+      abstract QueryTransform build();
     }
 
+    /**
+     * register a UDF function used in this query.
+     */
+     public QueryTransform withUdf(String functionName, Class<?> clazz, String 
methodName){
+       getSqlEnv().registerUdf(functionName, clazz, methodName);
+       return this;
+     }
+
+     /**
+      * register a UDAF function used in this query.
+      */
+     public QueryTransform withUdaf(String functionName, Class<? extends 
BeamSqlUdaf> clazz){
+       getSqlEnv().registerUdaf(functionName, clazz);
+       return this;
+     }
+
     @Override
     public PCollection<BeamSqlRow> expand(PCollectionTuple input) {
       registerTables(input);
 
       BeamRelNode beamRelNode = null;
       try {
-        beamRelNode = sqlEnv.planner.convertToBeamRel(sqlQuery);
+        beamRelNode = getSqlEnv().planner.convertToBeamRel(getSqlQuery());
       } catch (ValidationException | RelConversionException | 
SqlParseException e) {
         throw new IllegalStateException(e);
       }
 
       try {
-        return beamRelNode.buildBeamPipeline(input, sqlEnv);
+        return beamRelNode.buildBeamPipeline(input, getSqlEnv());
       } catch (Exception e) {
         throw new IllegalStateException(e);
       }
@@ -140,7 +165,7 @@ public class BeamSql {
         PCollection<BeamSqlRow> sourceStream = (PCollection<BeamSqlRow>) 
input.get(sourceTag);
         BeamSqlRowCoder sourceCoder = (BeamSqlRowCoder) 
sourceStream.getCoder();
 
-        sqlEnv.registerTable(sourceTag.getId(),
+        getSqlEnv().registerTable(sourceTag.getId(),
             new BeamPCollectionTable(sourceStream, 
sourceCoder.getTableSchema()));
       }
     }
@@ -150,26 +175,45 @@ public class BeamSql {
    * A {@link PTransform} representing an execution plan for a SQL query referencing
    * a single table.
    */
-  private static class SimpleQueryTransform
+  @AutoValue
+  public abstract static class SimpleQueryTransform
       extends PTransform<PCollection<BeamSqlRow>, PCollection<BeamSqlRow>> {
     private static final String PCOLLECTION_TABLE_NAME = "PCOLLECTION";
-    private transient BeamSqlEnv sqlEnv = new BeamSqlEnv();
-    private String sqlQuery;
+    abstract BeamSqlEnv getSqlEnv();
+    abstract String getSqlQuery();
 
-    public SimpleQueryTransform(String sqlQuery) {
-      this.sqlQuery = sqlQuery;
-      validateQuery();
+    static Builder builder() {
+      return new AutoValue_BeamSql_SimpleQueryTransform.Builder();
     }
 
-    // public SimpleQueryTransform withUdf(String udfName){
-    // throw new UnsupportedOperationException("Pending for UDF support");
-    // }
+    @AutoValue.Builder
+    abstract static class Builder {
+      abstract Builder setSqlQuery(String sqlQuery);
+      abstract Builder setSqlEnv(BeamSqlEnv sqlEnv);
+      abstract SimpleQueryTransform build();
+    }
+
+    /**
+     * register a UDF function used in this query.
+     */
+     public SimpleQueryTransform withUdf(String functionName, Class<?> clazz, 
String methodName){
+       getSqlEnv().registerUdf(functionName, clazz, methodName);
+       return this;
+     }
+
+     /**
+      * register a UDAF function used in this query.
+      */
+     public SimpleQueryTransform withUdaf(String functionName, Class<? extends 
BeamSqlUdaf> clazz){
+       getSqlEnv().registerUdaf(functionName, clazz);
+       return this;
+     }
 
     private void validateQuery() {
       SqlNode sqlNode;
       try {
-        sqlNode = sqlEnv.planner.parseQuery(sqlQuery);
-        sqlEnv.planner.getPlanner().close();
+        sqlNode = getSqlEnv().planner.parseQuery(getSqlQuery());
+        getSqlEnv().planner.getPlanner().close();
       } catch (SqlParseException e) {
         throw new IllegalStateException(e);
       }
@@ -188,8 +232,12 @@ public class BeamSql {
 
     @Override
     public PCollection<BeamSqlRow> expand(PCollection<BeamSqlRow> input) {
+      validateQuery();
       return PCollectionTuple.of(new 
TupleTag<BeamSqlRow>(PCOLLECTION_TABLE_NAME), input)
-          .apply(new QueryTransform(sqlQuery, sqlEnv));
+          .apply(QueryTransform.builder()
+              .setSqlEnv(getSqlEnv())
+              .setSqlQuery(getSqlQuery())
+              .build());
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/5ca54e95/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java
index 078d9d3..61f0355 100644
--- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java
+++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java
@@ -43,9 +43,9 @@ import org.apache.calcite.tools.Frameworks;
  * <p>It contains a {@link SchemaPlus} which holds the metadata of tables/UDF functions, and
  * a {@link BeamQueryPlanner} which parse/validate/optimize/translate input SQL queries.
  */
-public class BeamSqlEnv {
-  SchemaPlus schema;
-  BeamQueryPlanner planner;
+public class BeamSqlEnv implements Serializable{
+  transient SchemaPlus schema;
+  transient BeamQueryPlanner planner;
 
   public BeamSqlEnv() {
     schema = Frameworks.createRootSchema(true);

http://git-wip-us.apache.org/repos/asf/beam/blob/5ca54e95/dsls/sql/src/test/java/org/apache/beam/dsls/sql/BeamSqlDslUdfUdafTest.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/BeamSqlDslUdfUdafTest.java b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/BeamSqlDslUdfUdafTest.java
new file mode 100644
index 0000000..ba3e87e
--- /dev/null
+++ b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/BeamSqlDslUdfUdafTest.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.dsls.sql;
+
+import java.sql.Types;
+import java.util.Arrays;
+import java.util.Iterator;
+import org.apache.beam.dsls.sql.schema.BeamSqlRecordType;
+import org.apache.beam.dsls.sql.schema.BeamSqlRow;
+import org.apache.beam.dsls.sql.schema.BeamSqlUdaf;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.TupleTag;
+import org.junit.Test;
+
+/**
+ * Tests for UDF/UDAF.
+ */
+public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase {
+  /**
+   * GROUP-BY with UDAF.
+   */
+  @Test
+  public void testUdaf() throws Exception {
+    BeamSqlRecordType resultType = 
BeamSqlRecordType.create(Arrays.asList("f_int2", "squaresum"),
+        Arrays.asList(Types.INTEGER, Types.INTEGER));
+
+    BeamSqlRow record = new BeamSqlRow(resultType);
+    record.addField("f_int2", 0);
+    record.addField("squaresum", 30);
+
+    String sql1 = "SELECT f_int2, squaresum1(f_int) AS `squaresum`"
+        + " FROM PCOLLECTION GROUP BY f_int2";
+    PCollection<BeamSqlRow> result1 =
+        boundedInput1.apply("testUdaf1",
+            BeamSql.simpleQuery(sql1).withUdaf("squaresum1", SquareSum.class));
+    PAssert.that(result1).containsInAnyOrder(record);
+
+    String sql2 = "SELECT f_int2, squaresum2(f_int) AS `squaresum`"
+        + " FROM PCOLLECTION GROUP BY f_int2";
+    PCollection<BeamSqlRow> result2 =
+        PCollectionTuple.of(new TupleTag<BeamSqlRow>("PCOLLECTION"), 
boundedInput1)
+        .apply("testUdaf2",
+            BeamSql.query(sql2).withUdaf("squaresum2", SquareSum.class));
+    PAssert.that(result2).containsInAnyOrder(record);
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  /**
+   * test UDF.
+   */
+  @Test
+  public void testUdf() throws Exception{
+    BeamSqlRecordType resultType = 
BeamSqlRecordType.create(Arrays.asList("f_int", "cubicvalue"),
+        Arrays.asList(Types.INTEGER, Types.INTEGER));
+
+    BeamSqlRow record = new BeamSqlRow(resultType);
+    record.addField("f_int", 2);
+    record.addField("cubicvalue", 8);
+
+    String sql1 = "SELECT f_int, cubic1(f_int) as cubicvalue FROM PCOLLECTION 
WHERE f_int = 2";
+    PCollection<BeamSqlRow> result1 =
+        boundedInput1.apply("testUdf1",
+            BeamSql.simpleQuery(sql1).withUdf("cubic1", CubicInteger.class, 
"cubic"));
+    PAssert.that(result1).containsInAnyOrder(record);
+
+    String sql2 = "SELECT f_int, cubic2(f_int) as cubicvalue FROM PCOLLECTION 
WHERE f_int = 2";
+    PCollection<BeamSqlRow> result2 =
+        PCollectionTuple.of(new TupleTag<BeamSqlRow>("PCOLLECTION"), 
boundedInput1)
+        .apply("testUdf2",
+            BeamSql.query(sql2).withUdf("cubic2", CubicInteger.class, 
"cubic"));
+    PAssert.that(result2).containsInAnyOrder(record);
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  /**
+   * UDAF for test, which returns the sum of square.
+   */
+  public static class SquareSum extends BeamSqlUdaf<Integer, Integer, Integer> 
{
+
+    public SquareSum() {
+    }
+
+    @Override
+    public Integer init() {
+      return 0;
+    }
+
+    @Override
+    public Integer add(Integer accumulator, Integer input) {
+      return accumulator + input * input;
+    }
+
+    @Override
+    public Integer merge(Iterable<Integer> accumulators) {
+      int v = 0;
+      Iterator<Integer> ite = accumulators.iterator();
+      while (ite.hasNext()) {
+        v += ite.next();
+      }
+      return v;
+    }
+
+    @Override
+    public Integer result(Integer accumulator) {
+      return accumulator;
+    }
+
+  }
+
+  /**
+   * A example UDF for test.
+   */
+  public static class CubicInteger{
+    public static Integer cubic(Integer input){
+      return input * input * input;
+    }
+  }
+}

Reply via email to