Repository: spark Updated Branches: refs/heads/master 960298ee6 -> 742da0868
[SPARK-19439][PYSPARK][SQL] PySpark's registerJavaFunction Should Support UDAFs ## What changes were proposed in this pull request? Support registering Java UDAFs in PySpark so that users can use Java UDAFs in PySpark. Besides that, I also added APIs to `UDFRegistration`. ## How was this patch tested? Unit tests are added. Author: Jeff Zhang <[email protected]> Closes #17222 from zjffdu/SPARK-19439. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/742da086 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/742da086 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/742da086 Branch: refs/heads/master Commit: 742da0868534dab3d4d7b7edbe5ba9dc8bf26cc8 Parents: 960298e Author: Jeff Zhang <[email protected]> Authored: Wed Jul 5 10:59:10 2017 -0700 Committer: gatorsmile <[email protected]> Committed: Wed Jul 5 10:59:10 2017 -0700 ---------------------------------------------------------------------- python/pyspark/sql/context.py | 23 ++++ python/pyspark/sql/tests.py | 10 ++ .../org/apache/spark/sql/UDFRegistration.scala | 33 ++++- .../org/apache/spark/sql/JavaUDAFSuite.java | 55 ++++++++ .../test/org/apache/spark/sql/MyDoubleAvg.java | 129 +++++++++++++++++++ .../test/org/apache/spark/sql/MyDoubleSum.java | 118 +++++++++++++++++ sql/hive/pom.xml | 7 + .../spark/sql/hive/JavaDataFrameSuite.java | 2 +- .../spark/sql/hive/aggregate/MyDoubleAvg.java | 129 ------------------- .../spark/sql/hive/aggregate/MyDoubleSum.java | 118 ----------------- .../hive/execution/AggregationQuerySuite.scala | 5 +- 11 files changed, 374 insertions(+), 255 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/python/pyspark/sql/context.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 426f07c..c44ab24 100644 --- 
a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -232,6 +232,23 @@ class SQLContext(object): jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + @ignore_unicode_prefix + @since(2.3) + def registerJavaUDAF(self, name, javaClassName): + """Register a java UDAF so it can be used in SQL statements. + + :param name: name of the UDAF + :param javaClassName: fully qualified name of java class + + >>> sqlContext.registerJavaUDAF("javaUDAF", + ... "test.org.apache.spark.sql.MyDoubleAvg") + >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) + >>> df.registerTempTable("df") + >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() + [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] + """ + self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): """ @@ -551,6 +568,12 @@ class UDFRegistration(object): def register(self, name, f, returnType=StringType()): return self.sqlContext.registerFunction(name, f, returnType) + def registerJavaFunction(self, name, javaClassName, returnType=None): + self.sqlContext.registerJavaFunction(name, javaClassName, returnType) + + def registerJavaUDAF(self, name, javaClassName): + self.sqlContext.registerJavaUDAF(name, javaClassName) + register.__doc__ = SQLContext.registerFunction.__doc__ http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 16ba8bd..c0e3b8d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -481,6 +481,16 @@ class SQLTests(ReusedPySparkTestCase): 
df.select(add_three("id").alias("plus_three")).collect() ) + def test_non_existed_udf(self): + spark = self.spark + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", + lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf")) + + def test_non_existed_udaf(self): + spark = self.spark + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", + lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) + def test_multiLine_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index ad01b88..8bdc022 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql -import java.io.IOException import java.lang.reflect.{ParameterizedType, Type} import scala.reflect.runtime.universe.TypeTag @@ -456,9 +455,9 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends .map(_.asInstanceOf[ParameterizedType]) .filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF")) if (udfInterfaces.length == 0) { - throw new IOException(s"UDF class ${className} doesn't implement any UDF interface") + throw new AnalysisException(s"UDF class ${className} doesn't implement any UDF interface") } else if (udfInterfaces.length > 1) { - throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class 
${className}") + throw new AnalysisException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}") } else { try { val udf = clazz.newInstance() @@ -491,20 +490,42 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) - case n => logError(s"UDF class with ${n} type arguments is not supported ") + case n => + throw new AnalysisException(s"UDF class with ${n} type arguments is not supported.") } } catch { case e @ (_: InstantiationException | _: IllegalArgumentException) => - logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") + throw new AnalysisException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") } } } catch { - case e: ClassNotFoundException => logError(s"Can not load class ${className}, please make sure it is on the classpath") + case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath") } } /** + * Register a Java UDAF class using reflection, for use from pyspark + * + * @param name UDAF name + * @param className fully qualified class name of UDAF + */ + private[sql] def registerJavaUDAF(name: String, className: String): Unit = { + try { + val clazz = Utils.classForName(className) + if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) { + throw new AnalysisException(s"class $className doesn't implement interface UserDefinedAggregateFunction") + } + val udaf = clazz.newInstance().asInstanceOf[UserDefinedAggregateFunction] + 
register(name, udaf) + } catch { + case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath") + case e @ (_: InstantiationException | _: IllegalArgumentException) => + throw new AnalysisException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") + } + } + + /** * Register a user-defined function with 1 arguments. * @since 1.3.0 */ http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java new file mode 100644 index 0000000..ddbaa45 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + + +public class JavaUDAFSuite { + + private transient SparkSession spark; + + @Before + public void setUp() { + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); + } + + @After + public void tearDown() { + spark.stop(); + spark = null; + } + + @SuppressWarnings("unchecked") + @Test + public void udf1Test() { + spark.range(1, 10).toDF("value").registerTempTable("df"); + spark.udf().registerJavaUDAF("myDoubleAvg", MyDoubleAvg.class.getName()); + Row result = spark.sql("SELECT myDoubleAvg(value) as my_avg from df").head(); + Assert.assertEquals(105.0, result.getDouble(0), 1.0e-6); + } + +} http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java new file mode 100644 index 0000000..447a71d --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.expressions.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * An example {@link UserDefinedAggregateFunction} to calculate a special average value of a + * {@link org.apache.spark.sql.types.DoubleType} column. This special average value is the sum + * of the average value of input values and 100.0. + */ +public class MyDoubleAvg extends UserDefinedAggregateFunction { + + private StructType _inputDataType; + + private StructType _bufferSchema; + + private DataType _returnDataType; + + public MyDoubleAvg() { + List<StructField> inputFields = new ArrayList<>(); + inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputFields); + + // The buffer has two values, bufferSum for storing the current sum and + // bufferCount for storing the number of non-null input values that have been contribuetd + // to the current sum. 
+ List<StructField> bufferFields = new ArrayList<>(); + bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true)); + bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true)); + _bufferSchema = DataTypes.createStructType(bufferFields); + + _returnDataType = DataTypes.DoubleType; + } + + @Override public StructType inputSchema() { + return _inputDataType; + } + + @Override public StructType bufferSchema() { + return _bufferSchema; + } + + @Override public DataType dataType() { + return _returnDataType; + } + + @Override public boolean deterministic() { + return true; + } + + @Override public void initialize(MutableAggregationBuffer buffer) { + // The initial value of the sum is null. + buffer.update(0, null); + // The initial value of the count is 0. + buffer.update(1, 0L); + } + + @Override public void update(MutableAggregationBuffer buffer, Row input) { + // This input Row only has a single column storing the input value in Double. + // We only update the buffer when the input value is not null. + if (!input.isNullAt(0)) { + // If the buffer value (the intermediate result of the sum) is still null, + // we set the input value to the buffer and set the bufferCount to 1. + if (buffer.isNullAt(0)) { + buffer.update(0, input.getDouble(0)); + buffer.update(1, 1L); + } else { + // Otherwise, update the bufferSum and increment bufferCount. + Double newValue = input.getDouble(0) + buffer.getDouble(0); + buffer.update(0, newValue); + buffer.update(1, buffer.getLong(1) + 1L); + } + } + } + + @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + // buffer1 and buffer2 have the same structure. + // We only update the buffer1 when the input buffer2's sum value is not null. + if (!buffer2.isNullAt(0)) { + if (buffer1.isNullAt(0)) { + // If the buffer value (intermediate result of the sum) is still null, + // we set the it as the input buffer's value. 
+ buffer1.update(0, buffer2.getDouble(0)); + buffer1.update(1, buffer2.getLong(1)); + } else { + // Otherwise, we update the bufferSum and bufferCount. + Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); + buffer1.update(0, newValue); + buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1)); + } + } + } + + @Override public Object evaluate(Row buffer) { + if (buffer.isNullAt(0)) { + // If the bufferSum is still null, we return null because this function has not got + // any input row. + return null; + } else { + // Otherwise, we calculate the special average value. + return buffer.getDouble(0) / buffer.getLong(1) + 100.0; + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java new file mode 100644 index 0000000..93d2033 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.expressions.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * An example {@link UserDefinedAggregateFunction} to calculate the sum of a + * {@link org.apache.spark.sql.types.DoubleType} column. + */ +public class MyDoubleSum extends UserDefinedAggregateFunction { + + private StructType _inputDataType; + + private StructType _bufferSchema; + + private DataType _returnDataType; + + public MyDoubleSum() { + List<StructField> inputFields = new ArrayList<>(); + inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputFields); + + List<StructField> bufferFields = new ArrayList<>(); + bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true)); + _bufferSchema = DataTypes.createStructType(bufferFields); + + _returnDataType = DataTypes.DoubleType; + } + + @Override public StructType inputSchema() { + return _inputDataType; + } + + @Override public StructType bufferSchema() { + return _bufferSchema; + } + + @Override public DataType dataType() { + return _returnDataType; + } + + @Override public boolean deterministic() { + return true; + } + + @Override public void initialize(MutableAggregationBuffer buffer) { + // The initial value of the sum is null. + buffer.update(0, null); + } + + @Override public void update(MutableAggregationBuffer buffer, Row input) { + // This input Row only has a single column storing the input value in Double. + // We only update the buffer when the input value is not null. 
+ if (!input.isNullAt(0)) { + if (buffer.isNullAt(0)) { + // If the buffer value (the intermediate result of the sum) is still null, + // we set the input value to the buffer. + buffer.update(0, input.getDouble(0)); + } else { + // Otherwise, we add the input value to the buffer value. + Double newValue = input.getDouble(0) + buffer.getDouble(0); + buffer.update(0, newValue); + } + } + } + + @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + // buffer1 and buffer2 have the same structure. + // We only update the buffer1 when the input buffer2's value is not null. + if (!buffer2.isNullAt(0)) { + if (buffer1.isNullAt(0)) { + // If the buffer value (intermediate result of the sum) is still null, + // we set the it as the input buffer's value. + buffer1.update(0, buffer2.getDouble(0)); + } else { + // Otherwise, we add the input buffer's value (buffer1) to the mutable + // buffer's value (buffer2). + Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); + buffer1.update(0, newValue); + } + } + } + + @Override public Object evaluate(Row buffer) { + if (buffer.isNullAt(0)) { + // If the buffer value is still null, we return null. + return null; + } else { + // Otherwise, the intermediate sum is the final result. 
+ return buffer.getDouble(0); + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/hive/pom.xml ---------------------------------------------------------------------- diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 09dcc40..f9462e7 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -59,6 +59,13 @@ </dependency> <dependency> <groupId>org.apache.spark</groupId> + <artifactId>spark-sql_${scala.binary.version}</artifactId> + <version>${project.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> <artifactId>spark-tags_${scala.binary.version}</artifactId> <type>test-jar</type> <scope>test</scope> http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java index aefc9cc..636ce10 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -31,7 +31,7 @@ import org.apache.spark.sql.expressions.Window; import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.hive.test.TestHive$; -import org.apache.spark.sql.hive.aggregate.MyDoubleSum; +import test.org.apache.spark.sql.MyDoubleSum; public class JavaDataFrameSuite { private transient SQLContext hc; http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java 
b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java deleted file mode 100644 index ae0c097..0000000 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.aggregate; - -import java.util.ArrayList; -import java.util.List; - -import org.apache.spark.sql.Row; -import org.apache.spark.sql.expressions.MutableAggregationBuffer; -import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; - -/** - * An example {@link UserDefinedAggregateFunction} to calculate a special average value of a - * {@link org.apache.spark.sql.types.DoubleType} column. This special average value is the sum - * of the average value of input values and 100.0. 
- */ -public class MyDoubleAvg extends UserDefinedAggregateFunction { - - private StructType _inputDataType; - - private StructType _bufferSchema; - - private DataType _returnDataType; - - public MyDoubleAvg() { - List<StructField> inputFields = new ArrayList<>(); - inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); - _inputDataType = DataTypes.createStructType(inputFields); - - // The buffer has two values, bufferSum for storing the current sum and - // bufferCount for storing the number of non-null input values that have been contribuetd - // to the current sum. - List<StructField> bufferFields = new ArrayList<>(); - bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true)); - bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true)); - _bufferSchema = DataTypes.createStructType(bufferFields); - - _returnDataType = DataTypes.DoubleType; - } - - @Override public StructType inputSchema() { - return _inputDataType; - } - - @Override public StructType bufferSchema() { - return _bufferSchema; - } - - @Override public DataType dataType() { - return _returnDataType; - } - - @Override public boolean deterministic() { - return true; - } - - @Override public void initialize(MutableAggregationBuffer buffer) { - // The initial value of the sum is null. - buffer.update(0, null); - // The initial value of the count is 0. - buffer.update(1, 0L); - } - - @Override public void update(MutableAggregationBuffer buffer, Row input) { - // This input Row only has a single column storing the input value in Double. - // We only update the buffer when the input value is not null. - if (!input.isNullAt(0)) { - // If the buffer value (the intermediate result of the sum) is still null, - // we set the input value to the buffer and set the bufferCount to 1. 
- if (buffer.isNullAt(0)) { - buffer.update(0, input.getDouble(0)); - buffer.update(1, 1L); - } else { - // Otherwise, update the bufferSum and increment bufferCount. - Double newValue = input.getDouble(0) + buffer.getDouble(0); - buffer.update(0, newValue); - buffer.update(1, buffer.getLong(1) + 1L); - } - } - } - - @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { - // buffer1 and buffer2 have the same structure. - // We only update the buffer1 when the input buffer2's sum value is not null. - if (!buffer2.isNullAt(0)) { - if (buffer1.isNullAt(0)) { - // If the buffer value (intermediate result of the sum) is still null, - // we set the it as the input buffer's value. - buffer1.update(0, buffer2.getDouble(0)); - buffer1.update(1, buffer2.getLong(1)); - } else { - // Otherwise, we update the bufferSum and bufferCount. - Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); - buffer1.update(0, newValue); - buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1)); - } - } - } - - @Override public Object evaluate(Row buffer) { - if (buffer.isNullAt(0)) { - // If the bufferSum is still null, we return null because this function has not got - // any input row. - return null; - } else { - // Otherwise, we calculate the special average value. 
- return buffer.getDouble(0) / buffer.getLong(1) + 100.0; - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java deleted file mode 100644 index d17fb3e..0000000 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.hive.aggregate; - -import java.util.ArrayList; -import java.util.List; - -import org.apache.spark.sql.expressions.MutableAggregationBuffer; -import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.Row; - -/** - * An example {@link UserDefinedAggregateFunction} to calculate the sum of a - * {@link org.apache.spark.sql.types.DoubleType} column. - */ -public class MyDoubleSum extends UserDefinedAggregateFunction { - - private StructType _inputDataType; - - private StructType _bufferSchema; - - private DataType _returnDataType; - - public MyDoubleSum() { - List<StructField> inputFields = new ArrayList<>(); - inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); - _inputDataType = DataTypes.createStructType(inputFields); - - List<StructField> bufferFields = new ArrayList<>(); - bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true)); - _bufferSchema = DataTypes.createStructType(bufferFields); - - _returnDataType = DataTypes.DoubleType; - } - - @Override public StructType inputSchema() { - return _inputDataType; - } - - @Override public StructType bufferSchema() { - return _bufferSchema; - } - - @Override public DataType dataType() { - return _returnDataType; - } - - @Override public boolean deterministic() { - return true; - } - - @Override public void initialize(MutableAggregationBuffer buffer) { - // The initial value of the sum is null. - buffer.update(0, null); - } - - @Override public void update(MutableAggregationBuffer buffer, Row input) { - // This input Row only has a single column storing the input value in Double. - // We only update the buffer when the input value is not null. 
- if (!input.isNullAt(0)) { - if (buffer.isNullAt(0)) { - // If the buffer value (the intermediate result of the sum) is still null, - // we set the input value to the buffer. - buffer.update(0, input.getDouble(0)); - } else { - // Otherwise, we add the input value to the buffer value. - Double newValue = input.getDouble(0) + buffer.getDouble(0); - buffer.update(0, newValue); - } - } - } - - @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { - // buffer1 and buffer2 have the same structure. - // We only update the buffer1 when the input buffer2's value is not null. - if (!buffer2.isNullAt(0)) { - if (buffer1.isNullAt(0)) { - // If the buffer value (intermediate result of the sum) is still null, - // we set the it as the input buffer's value. - buffer1.update(0, buffer2.getDouble(0)); - } else { - // Otherwise, we add the input buffer's value (buffer1) to the mutable - // buffer's value (buffer2). - Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); - buffer1.update(0, newValue); - } - } - } - - @Override public Object evaluate(Row buffer) { - if (buffer.isNullAt(0)) { - // If the buffer value is still null, we return null. - return null; - } else { - // Otherwise, the intermediate sum is the final result. 
- return buffer.getDouble(0); - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 84f9159..f245a79 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -20,16 +20,19 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConverters._ import scala.util.Random +import test.org.apache.spark.sql.MyDoubleAvg +import test.org.apache.spark.sql.MyDoubleSum + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ + class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction { def inputSchema: StructType = schema --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
