Repository: beam Updated Branches: refs/heads/DSL_SQL 25fea4e1e -> d89d1ee1a
support UDF/UDAF in BeamSql Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/5ca54e95 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/5ca54e95 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/5ca54e95 Branch: refs/heads/DSL_SQL Commit: 5ca54e956e80f3059a9e67bf9b3d34af08569ff1 Parents: 25fea4e Author: mingmxu <[email protected]> Authored: Sun Jul 2 21:24:07 2017 -0700 Committer: Tyler Akidau <[email protected]> Committed: Wed Jul 12 15:54:03 2017 -0700 ---------------------------------------------------------------------- .../java/org/apache/beam/dsls/sql/BeamSql.java | 114 ++++++++++----- .../org/apache/beam/dsls/sql/BeamSqlEnv.java | 6 +- .../beam/dsls/sql/BeamSqlDslUdfUdafTest.java | 137 +++++++++++++++++++ 3 files changed, 221 insertions(+), 36 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/5ca54e95/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSql.java ---------------------------------------------------------------------- diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSql.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSql.java index a0e7cbc..ec3799c 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSql.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSql.java @@ -17,10 +17,12 @@ */ package org.apache.beam.dsls.sql; +import com.google.auto.value.AutoValue; import org.apache.beam.dsls.sql.rel.BeamRelNode; import org.apache.beam.dsls.sql.schema.BeamPCollectionTable; import org.apache.beam.dsls.sql.schema.BeamSqlRow; import org.apache.beam.dsls.sql.schema.BeamSqlRowCoder; +import org.apache.beam.dsls.sql.schema.BeamSqlUdaf; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.PCollection; @@ -51,7 +53,9 @@ PCollection<BeamSqlRow> inputTableB = 
p.apply(TextIO.read().from("/my/input/path //run a simple query, and register the output as a table in BeamSql; String sql1 = "select MY_FUNC(c1), c2 from PCOLLECTION"; -PCollection<BeamSqlRow> outputTableA = inputTableA.apply(BeamSql.simpleQuery(sql1)); +PCollection<BeamSqlRow> outputTableA = inputTableA.apply( + BeamSql.simpleQuery(sql1) + .withUdf("MY_FUNC", MY_FUNC.class, "FUNC")); //run a JOIN with one table from TextIO, and one table from another query PCollection<BeamSqlRow> outputTableB = PCollectionTuple.of( @@ -60,7 +64,7 @@ PCollection<BeamSqlRow> outputTableB = PCollectionTuple.of( .apply(BeamSql.query("select * from TABLE_O_A JOIN TABLE_B where ...")); //output the final result with TextIO -outputTableB.apply(BeamSql.toTextRow()).apply(TextIO.write().to("/my/output/path")); +outputTableB.apply(...).apply(TextIO.write().to("/my/output/path")); p.run().waitUntilFinish(); * } @@ -68,7 +72,6 @@ p.run().waitUntilFinish(); */ @Experimental public class BeamSql { - /** * Transforms a SQL query into a {@link PTransform} representing an equivalent execution plan. * @@ -80,9 +83,11 @@ public class BeamSql { * <p>It is an error to apply a {@link PCollectionTuple} missing any {@code table names} * referenced within the query. */ - public static PTransform<PCollectionTuple, PCollection<BeamSqlRow>> query(String sqlQuery) { - return new QueryTransform(sqlQuery); - + public static QueryTransform query(String sqlQuery) { + return QueryTransform.builder() + .setSqlEnv(new BeamSqlEnv()) + .setSqlQuery(sqlQuery) + .build(); } /** @@ -93,42 +98,62 @@ public class BeamSql { * * <p>Make sure to query it from a static table name <em>PCOLLECTION</em>. 
*/ - public static PTransform<PCollection<BeamSqlRow>, PCollection<BeamSqlRow>> - simpleQuery(String sqlQuery) throws Exception { - return new SimpleQueryTransform(sqlQuery); + public static SimpleQueryTransform simpleQuery(String sqlQuery) throws Exception { + return SimpleQueryTransform.builder() + .setSqlEnv(new BeamSqlEnv()) + .setSqlQuery(sqlQuery) + .build(); } /** * A {@link PTransform} representing an execution plan for a SQL query. */ - private static class QueryTransform extends + @AutoValue + public abstract static class QueryTransform extends PTransform<PCollectionTuple, PCollection<BeamSqlRow>> { - private transient BeamSqlEnv sqlEnv; - private String sqlQuery; + abstract BeamSqlEnv getSqlEnv(); + abstract String getSqlQuery(); - public QueryTransform(String sqlQuery) { - this.sqlQuery = sqlQuery; - sqlEnv = new BeamSqlEnv(); + static Builder builder() { + return new AutoValue_BeamSql_QueryTransform.Builder(); } - public QueryTransform(String sqlQuery, BeamSqlEnv sqlEnv) { - this.sqlQuery = sqlQuery; - this.sqlEnv = sqlEnv; + @AutoValue.Builder + abstract static class Builder { + abstract Builder setSqlQuery(String sqlQuery); + abstract Builder setSqlEnv(BeamSqlEnv sqlEnv); + abstract QueryTransform build(); } + /** + * register a UDF function used in this query. + */ + public QueryTransform withUdf(String functionName, Class<?> clazz, String methodName){ + getSqlEnv().registerUdf(functionName, clazz, methodName); + return this; + } + + /** + * register a UDAF function used in this query. + */ + public QueryTransform withUdaf(String functionName, Class<? 
extends BeamSqlUdaf> clazz){ + getSqlEnv().registerUdaf(functionName, clazz); + return this; + } + @Override public PCollection<BeamSqlRow> expand(PCollectionTuple input) { registerTables(input); BeamRelNode beamRelNode = null; try { - beamRelNode = sqlEnv.planner.convertToBeamRel(sqlQuery); + beamRelNode = getSqlEnv().planner.convertToBeamRel(getSqlQuery()); } catch (ValidationException | RelConversionException | SqlParseException e) { throw new IllegalStateException(e); } try { - return beamRelNode.buildBeamPipeline(input, sqlEnv); + return beamRelNode.buildBeamPipeline(input, getSqlEnv()); } catch (Exception e) { throw new IllegalStateException(e); } @@ -140,7 +165,7 @@ public class BeamSql { PCollection<BeamSqlRow> sourceStream = (PCollection<BeamSqlRow>) input.get(sourceTag); BeamSqlRowCoder sourceCoder = (BeamSqlRowCoder) sourceStream.getCoder(); - sqlEnv.registerTable(sourceTag.getId(), + getSqlEnv().registerTable(sourceTag.getId(), new BeamPCollectionTable(sourceStream, sourceCoder.getTableSchema())); } } @@ -150,26 +175,45 @@ public class BeamSql { * A {@link PTransform} representing an execution plan for a SQL query referencing * a single table. 
*/ - private static class SimpleQueryTransform + @AutoValue + public abstract static class SimpleQueryTransform extends PTransform<PCollection<BeamSqlRow>, PCollection<BeamSqlRow>> { private static final String PCOLLECTION_TABLE_NAME = "PCOLLECTION"; - private transient BeamSqlEnv sqlEnv = new BeamSqlEnv(); - private String sqlQuery; + abstract BeamSqlEnv getSqlEnv(); + abstract String getSqlQuery(); - public SimpleQueryTransform(String sqlQuery) { - this.sqlQuery = sqlQuery; - validateQuery(); + static Builder builder() { + return new AutoValue_BeamSql_SimpleQueryTransform.Builder(); } - // public SimpleQueryTransform withUdf(String udfName){ - // throw new UnsupportedOperationException("Pending for UDF support"); - // } + @AutoValue.Builder + abstract static class Builder { + abstract Builder setSqlQuery(String sqlQuery); + abstract Builder setSqlEnv(BeamSqlEnv sqlEnv); + abstract SimpleQueryTransform build(); + } + + /** + * register a UDF function used in this query. + */ + public SimpleQueryTransform withUdf(String functionName, Class<?> clazz, String methodName){ + getSqlEnv().registerUdf(functionName, clazz, methodName); + return this; + } + + /** + * register a UDAF function used in this query. + */ + public SimpleQueryTransform withUdaf(String functionName, Class<? 
extends BeamSqlUdaf> clazz){ + getSqlEnv().registerUdaf(functionName, clazz); + return this; + } private void validateQuery() { SqlNode sqlNode; try { - sqlNode = sqlEnv.planner.parseQuery(sqlQuery); - sqlEnv.planner.getPlanner().close(); + sqlNode = getSqlEnv().planner.parseQuery(getSqlQuery()); + getSqlEnv().planner.getPlanner().close(); } catch (SqlParseException e) { throw new IllegalStateException(e); } @@ -188,8 +232,12 @@ public class BeamSql { @Override public PCollection<BeamSqlRow> expand(PCollection<BeamSqlRow> input) { + validateQuery(); return PCollectionTuple.of(new TupleTag<BeamSqlRow>(PCOLLECTION_TABLE_NAME), input) - .apply(new QueryTransform(sqlQuery, sqlEnv)); + .apply(QueryTransform.builder() + .setSqlEnv(getSqlEnv()) + .setSqlQuery(getSqlQuery()) + .build()); } } } http://git-wip-us.apache.org/repos/asf/beam/blob/5ca54e95/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java ---------------------------------------------------------------------- diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java index 078d9d3..61f0355 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java @@ -43,9 +43,9 @@ import org.apache.calcite.tools.Frameworks; * <p>It contains a {@link SchemaPlus} which holds the metadata of tables/UDF functions, and * a {@link BeamQueryPlanner} which parse/validate/optimize/translate input SQL queries. 
*/ -public class BeamSqlEnv { - SchemaPlus schema; - BeamQueryPlanner planner; +public class BeamSqlEnv implements Serializable{ + transient SchemaPlus schema; + transient BeamQueryPlanner planner; public BeamSqlEnv() { schema = Frameworks.createRootSchema(true); http://git-wip-us.apache.org/repos/asf/beam/blob/5ca54e95/dsls/sql/src/test/java/org/apache/beam/dsls/sql/BeamSqlDslUdfUdafTest.java ---------------------------------------------------------------------- diff --git a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/BeamSqlDslUdfUdafTest.java b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/BeamSqlDslUdfUdafTest.java new file mode 100644 index 0000000..ba3e87e --- /dev/null +++ b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/BeamSqlDslUdfUdafTest.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.dsls.sql; + +import java.sql.Types; +import java.util.Arrays; +import java.util.Iterator; +import org.apache.beam.dsls.sql.schema.BeamSqlRecordType; +import org.apache.beam.dsls.sql.schema.BeamSqlRow; +import org.apache.beam.dsls.sql.schema.BeamSqlUdaf; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.TupleTag; +import org.junit.Test; + +/** + * Tests for UDF/UDAF. + */ +public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase { + /** + * GROUP-BY with UDAF. + */ + @Test + public void testUdaf() throws Exception { + BeamSqlRecordType resultType = BeamSqlRecordType.create(Arrays.asList("f_int2", "squaresum"), + Arrays.asList(Types.INTEGER, Types.INTEGER)); + + BeamSqlRow record = new BeamSqlRow(resultType); + record.addField("f_int2", 0); + record.addField("squaresum", 30); + + String sql1 = "SELECT f_int2, squaresum1(f_int) AS `squaresum`" + + " FROM PCOLLECTION GROUP BY f_int2"; + PCollection<BeamSqlRow> result1 = + boundedInput1.apply("testUdaf1", + BeamSql.simpleQuery(sql1).withUdaf("squaresum1", SquareSum.class)); + PAssert.that(result1).containsInAnyOrder(record); + + String sql2 = "SELECT f_int2, squaresum2(f_int) AS `squaresum`" + + " FROM PCOLLECTION GROUP BY f_int2"; + PCollection<BeamSqlRow> result2 = + PCollectionTuple.of(new TupleTag<BeamSqlRow>("PCOLLECTION"), boundedInput1) + .apply("testUdaf2", + BeamSql.query(sql2).withUdaf("squaresum2", SquareSum.class)); + PAssert.that(result2).containsInAnyOrder(record); + + pipeline.run().waitUntilFinish(); + } + + /** + * test UDF. 
+ */ + @Test + public void testUdf() throws Exception{ + BeamSqlRecordType resultType = BeamSqlRecordType.create(Arrays.asList("f_int", "cubicvalue"), + Arrays.asList(Types.INTEGER, Types.INTEGER)); + + BeamSqlRow record = new BeamSqlRow(resultType); + record.addField("f_int", 2); + record.addField("cubicvalue", 8); + + String sql1 = "SELECT f_int, cubic1(f_int) as cubicvalue FROM PCOLLECTION WHERE f_int = 2"; + PCollection<BeamSqlRow> result1 = + boundedInput1.apply("testUdf1", + BeamSql.simpleQuery(sql1).withUdf("cubic1", CubicInteger.class, "cubic")); + PAssert.that(result1).containsInAnyOrder(record); + + String sql2 = "SELECT f_int, cubic2(f_int) as cubicvalue FROM PCOLLECTION WHERE f_int = 2"; + PCollection<BeamSqlRow> result2 = + PCollectionTuple.of(new TupleTag<BeamSqlRow>("PCOLLECTION"), boundedInput1) + .apply("testUdf2", + BeamSql.query(sql2).withUdf("cubic2", CubicInteger.class, "cubic")); + PAssert.that(result2).containsInAnyOrder(record); + + pipeline.run().waitUntilFinish(); + } + + /** + * UDAF for test, which returns the sum of squares. + */ + public static class SquareSum extends BeamSqlUdaf<Integer, Integer, Integer> { + + public SquareSum() { + } + + @Override + public Integer init() { + return 0; + } + + @Override + public Integer add(Integer accumulator, Integer input) { + return accumulator + input * input; + } + + @Override + public Integer merge(Iterable<Integer> accumulators) { + int v = 0; + Iterator<Integer> ite = accumulators.iterator(); + while (ite.hasNext()) { + v += ite.next(); + } + return v; + } + + @Override + public Integer result(Integer accumulator) { + return accumulator; + } + + } + + /** + * An example UDF for test. + */ + public static class CubicInteger{ + public static Integer cubic(Integer input){ + return input * input * input; + } + } +}
