[SPARK-18296][SQL] Use consistent naming for expression test suites ## What changes were proposed in this pull request? We have an undocumented naming convention to call expression unit tests ExpressionsSuite, and the end-to-end tests FunctionsSuite. It'd be great to make all test suites consistent with this naming convention.
## How was this patch tested? This is a test-only naming change. Author: Reynold Xin <r...@databricks.com> Closes #15793 from rxin/SPARK-18296. (cherry picked from commit 9db06c442cf85e41d51c7b167817f4e7971bf0da) Signed-off-by: Reynold Xin <r...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2fa1a632 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2fa1a632 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2fa1a632 Branch: refs/heads/branch-2.1 Commit: 2fa1a632ae4e68ffa01fad0d6150219c13355724 Parents: 9ebd5e5 Author: Reynold Xin <r...@databricks.com> Authored: Sun Nov 6 22:44:55 2016 -0800 Committer: Reynold Xin <r...@databricks.com> Committed: Sun Nov 6 22:45:02 2016 -0800 ---------------------------------------------------------------------- .../expressions/BitwiseExpressionsSuite.scala | 134 +++++ .../expressions/BitwiseFunctionsSuite.scala | 134 ----- .../CollectionExpressionsSuite.scala | 108 ++++ .../expressions/CollectionFunctionsSuite.scala | 109 ---- .../expressions/MathExpressionsSuite.scala | 582 +++++++++++++++++++ .../expressions/MathFunctionsSuite.scala | 582 ------------------- .../expressions/MiscExpressionsSuite.scala | 42 ++ .../expressions/MiscFunctionsSuite.scala | 42 -- .../expressions/NullExpressionsSuite.scala | 136 +++++ .../expressions/NullFunctionsSuite.scala | 136 ----- .../apache/spark/sql/MathExpressionsSuite.scala | 424 -------------- .../apache/spark/sql/MathFunctionsSuite.scala | 424 ++++++++++++++ 12 files changed, 1426 insertions(+), 1427 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2fa1a632/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala new file mode 100644 index 0000000..4188dad --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + + +class BitwiseExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + import IntegralLiteralTestUtils._ + + test("BitwiseNOT") { + def check(input: Any, expected: Any): Unit = { + val expr = BitwiseNot(Literal(input)) + assert(expr.dataType === Literal(input).dataType) + checkEvaluation(expr, expected) + } + + // Need the extra toByte even though IntelliJ thought it's not needed. 
+ check(1.toByte, (~1.toByte).toByte) + check(1000.toShort, (~1000.toShort).toShort) + check(1000000, ~1000000) + check(123456789123L, ~123456789123L) + + checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null) + checkEvaluation(BitwiseNot(positiveShortLit), (~positiveShort).toShort) + checkEvaluation(BitwiseNot(negativeShortLit), (~negativeShort).toShort) + checkEvaluation(BitwiseNot(positiveIntLit), ~positiveInt) + checkEvaluation(BitwiseNot(negativeIntLit), ~negativeInt) + checkEvaluation(BitwiseNot(positiveLongLit), ~positiveLong) + checkEvaluation(BitwiseNot(negativeLongLit), ~negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(BitwiseNot, dt) + } + } + + test("BitwiseAnd") { + def check(input1: Any, input2: Any, expected: Any): Unit = { + val expr = BitwiseAnd(Literal(input1), Literal(input2)) + assert(expr.dataType === Literal(input1).dataType) + checkEvaluation(expr, expected) + } + + // Need the extra toByte even though IntelliJ thought it's not needed. 
+ check(1.toByte, 2.toByte, (1.toByte & 2.toByte).toByte) + check(1000.toShort, 2.toShort, (1000.toShort & 2.toShort).toShort) + check(1000000, 4, 1000000 & 4) + check(123456789123L, 5L, 123456789123L & 5L) + + val nullLit = Literal.create(null, IntegerType) + checkEvaluation(BitwiseAnd(nullLit, Literal(1)), null) + checkEvaluation(BitwiseAnd(Literal(1), nullLit), null) + checkEvaluation(BitwiseAnd(nullLit, nullLit), null) + checkEvaluation(BitwiseAnd(positiveShortLit, negativeShortLit), + (positiveShort & negativeShort).toShort) + checkEvaluation(BitwiseAnd(positiveIntLit, negativeIntLit), positiveInt & negativeInt) + checkEvaluation(BitwiseAnd(positiveLongLit, negativeLongLit), positiveLong & negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(BitwiseAnd, dt, dt) + } + } + + test("BitwiseOr") { + def check(input1: Any, input2: Any, expected: Any): Unit = { + val expr = BitwiseOr(Literal(input1), Literal(input2)) + assert(expr.dataType === Literal(input1).dataType) + checkEvaluation(expr, expected) + } + + // Need the extra toByte even though IntelliJ thought it's not needed. 
+ check(1.toByte, 2.toByte, (1.toByte | 2.toByte).toByte) + check(1000.toShort, 2.toShort, (1000.toShort | 2.toShort).toShort) + check(1000000, 4, 1000000 | 4) + check(123456789123L, 5L, 123456789123L | 5L) + + val nullLit = Literal.create(null, IntegerType) + checkEvaluation(BitwiseOr(nullLit, Literal(1)), null) + checkEvaluation(BitwiseOr(Literal(1), nullLit), null) + checkEvaluation(BitwiseOr(nullLit, nullLit), null) + checkEvaluation(BitwiseOr(positiveShortLit, negativeShortLit), + (positiveShort | negativeShort).toShort) + checkEvaluation(BitwiseOr(positiveIntLit, negativeIntLit), positiveInt | negativeInt) + checkEvaluation(BitwiseOr(positiveLongLit, negativeLongLit), positiveLong | negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(BitwiseOr, dt, dt) + } + } + + test("BitwiseXor") { + def check(input1: Any, input2: Any, expected: Any): Unit = { + val expr = BitwiseXor(Literal(input1), Literal(input2)) + assert(expr.dataType === Literal(input1).dataType) + checkEvaluation(expr, expected) + } + + // Need the extra toByte even though IntelliJ thought it's not needed. 
+ check(1.toByte, 2.toByte, (1.toByte ^ 2.toByte).toByte) + check(1000.toShort, 2.toShort, (1000.toShort ^ 2.toShort).toShort) + check(1000000, 4, 1000000 ^ 4) + check(123456789123L, 5L, 123456789123L ^ 5L) + + val nullLit = Literal.create(null, IntegerType) + checkEvaluation(BitwiseXor(nullLit, Literal(1)), null) + checkEvaluation(BitwiseXor(Literal(1), nullLit), null) + checkEvaluation(BitwiseXor(nullLit, nullLit), null) + checkEvaluation(BitwiseXor(positiveShortLit, negativeShortLit), + (positiveShort ^ negativeShort).toShort) + checkEvaluation(BitwiseXor(positiveIntLit, negativeIntLit), positiveInt ^ negativeInt) + checkEvaluation(BitwiseXor(positiveLongLit, negativeLongLit), positiveLong ^ negativeLong) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(BitwiseXor, dt, dt) + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/2fa1a632/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala deleted file mode 100644 index 3a310c0..0000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types._ - - -class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { - - import IntegralLiteralTestUtils._ - - test("BitwiseNOT") { - def check(input: Any, expected: Any): Unit = { - val expr = BitwiseNot(Literal(input)) - assert(expr.dataType === Literal(input).dataType) - checkEvaluation(expr, expected) - } - - // Need the extra toByte even though IntelliJ thought it's not needed. - check(1.toByte, (~1.toByte).toByte) - check(1000.toShort, (~1000.toShort).toShort) - check(1000000, ~1000000) - check(123456789123L, ~123456789123L) - - checkEvaluation(BitwiseNot(Literal.create(null, IntegerType)), null) - checkEvaluation(BitwiseNot(positiveShortLit), (~positiveShort).toShort) - checkEvaluation(BitwiseNot(negativeShortLit), (~negativeShort).toShort) - checkEvaluation(BitwiseNot(positiveIntLit), ~positiveInt) - checkEvaluation(BitwiseNot(negativeIntLit), ~negativeInt) - checkEvaluation(BitwiseNot(positiveLongLit), ~positiveLong) - checkEvaluation(BitwiseNot(negativeLongLit), ~negativeLong) - - DataTypeTestUtils.integralType.foreach { dt => - checkConsistencyBetweenInterpretedAndCodegen(BitwiseNot, dt) - } - } - - test("BitwiseAnd") { - def check(input1: Any, input2: Any, expected: Any): Unit = { - val expr = BitwiseAnd(Literal(input1), Literal(input2)) - assert(expr.dataType === Literal(input1).dataType) - checkEvaluation(expr, expected) - } - - // Need the extra toByte even though IntelliJ thought it's not 
needed. - check(1.toByte, 2.toByte, (1.toByte & 2.toByte).toByte) - check(1000.toShort, 2.toShort, (1000.toShort & 2.toShort).toShort) - check(1000000, 4, 1000000 & 4) - check(123456789123L, 5L, 123456789123L & 5L) - - val nullLit = Literal.create(null, IntegerType) - checkEvaluation(BitwiseAnd(nullLit, Literal(1)), null) - checkEvaluation(BitwiseAnd(Literal(1), nullLit), null) - checkEvaluation(BitwiseAnd(nullLit, nullLit), null) - checkEvaluation(BitwiseAnd(positiveShortLit, negativeShortLit), - (positiveShort & negativeShort).toShort) - checkEvaluation(BitwiseAnd(positiveIntLit, negativeIntLit), positiveInt & negativeInt) - checkEvaluation(BitwiseAnd(positiveLongLit, negativeLongLit), positiveLong & negativeLong) - - DataTypeTestUtils.integralType.foreach { dt => - checkConsistencyBetweenInterpretedAndCodegen(BitwiseAnd, dt, dt) - } - } - - test("BitwiseOr") { - def check(input1: Any, input2: Any, expected: Any): Unit = { - val expr = BitwiseOr(Literal(input1), Literal(input2)) - assert(expr.dataType === Literal(input1).dataType) - checkEvaluation(expr, expected) - } - - // Need the extra toByte even though IntelliJ thought it's not needed. 
- check(1.toByte, 2.toByte, (1.toByte | 2.toByte).toByte) - check(1000.toShort, 2.toShort, (1000.toShort | 2.toShort).toShort) - check(1000000, 4, 1000000 | 4) - check(123456789123L, 5L, 123456789123L | 5L) - - val nullLit = Literal.create(null, IntegerType) - checkEvaluation(BitwiseOr(nullLit, Literal(1)), null) - checkEvaluation(BitwiseOr(Literal(1), nullLit), null) - checkEvaluation(BitwiseOr(nullLit, nullLit), null) - checkEvaluation(BitwiseOr(positiveShortLit, negativeShortLit), - (positiveShort | negativeShort).toShort) - checkEvaluation(BitwiseOr(positiveIntLit, negativeIntLit), positiveInt | negativeInt) - checkEvaluation(BitwiseOr(positiveLongLit, negativeLongLit), positiveLong | negativeLong) - - DataTypeTestUtils.integralType.foreach { dt => - checkConsistencyBetweenInterpretedAndCodegen(BitwiseOr, dt, dt) - } - } - - test("BitwiseXor") { - def check(input1: Any, input2: Any, expected: Any): Unit = { - val expr = BitwiseXor(Literal(input1), Literal(input2)) - assert(expr.dataType === Literal(input1).dataType) - checkEvaluation(expr, expected) - } - - // Need the extra toByte even though IntelliJ thought it's not needed. 
- check(1.toByte, 2.toByte, (1.toByte ^ 2.toByte).toByte) - check(1000.toShort, 2.toShort, (1000.toShort ^ 2.toShort).toShort) - check(1000000, 4, 1000000 ^ 4) - check(123456789123L, 5L, 123456789123L ^ 5L) - - val nullLit = Literal.create(null, IntegerType) - checkEvaluation(BitwiseXor(nullLit, Literal(1)), null) - checkEvaluation(BitwiseXor(Literal(1), nullLit), null) - checkEvaluation(BitwiseXor(nullLit, nullLit), null) - checkEvaluation(BitwiseXor(positiveShortLit, negativeShortLit), - (positiveShort ^ negativeShort).toShort) - checkEvaluation(BitwiseXor(positiveIntLit, negativeIntLit), positiveInt ^ negativeInt) - checkEvaluation(BitwiseXor(positiveLongLit, negativeLongLit), positiveLong ^ negativeLong) - - DataTypeTestUtils.integralType.foreach { dt => - checkConsistencyBetweenInterpretedAndCodegen(BitwiseXor, dt, dt) - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/2fa1a632/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala new file mode 100644 index 0000000..020687e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("Array and Map Size") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) + val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType)) + + checkEvaluation(Size(a0), 3) + checkEvaluation(Size(a1), 0) + checkEvaluation(Size(a2), 2) + + val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType)) + val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) + val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType)) + + checkEvaluation(Size(m0), 2) + checkEvaluation(Size(m1), 0) + checkEvaluation(Size(m2), 1) + + checkEvaluation(Size(Literal.create(null, MapType(StringType, StringType))), -1) + checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), -1) + } + + test("MapKeys/MapValues") { + val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) + val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) + val m2 = Literal.create(null, MapType(StringType, StringType)) + + checkEvaluation(MapKeys(m0), Seq("a", "b")) + checkEvaluation(MapValues(m0), Seq("1", "2")) + checkEvaluation(MapKeys(m1), Seq()) + checkEvaluation(MapValues(m1), Seq()) + checkEvaluation(MapKeys(m2), null) + checkEvaluation(MapValues(m2), null) + } + + 
test("Sort Array") { + val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) + val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) + val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) + val a4 = Literal.create(Seq(null, null), ArrayType(NullType)) + + checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) + checkEvaluation(new SortArray(a1), Seq[Integer]()) + checkEvaluation(new SortArray(a2), Seq("a", "b")) + checkEvaluation(new SortArray(a3), Seq(null, "a", "b")) + checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3)) + checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]()) + checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b")) + checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b")) + checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1)) + checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]()) + checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a")) + checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null)) + + checkEvaluation(Literal.create(null, ArrayType(StringType)), null) + checkEvaluation(new SortArray(a4), Seq(null, null)) + + val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS) + + checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2))) + } + + test("Array contains") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) + + checkEvaluation(ArrayContains(a0, Literal(1)), true) + checkEvaluation(ArrayContains(a0, Literal(0)), false) + checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null) + + checkEvaluation(ArrayContains(a1, Literal("")), true) + 
checkEvaluation(ArrayContains(a1, Literal("a")), null) + checkEvaluation(ArrayContains(a1, Literal.create(null, StringType)), null) + + checkEvaluation(ArrayContains(a2, Literal(1L)), null) + checkEvaluation(ArrayContains(a2, Literal.create(null, LongType)), null) + + checkEvaluation(ArrayContains(a3, Literal("")), null) + checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/2fa1a632/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala deleted file mode 100644 index c76dad2..0000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types._ - - -class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { - - test("Array and Map Size") { - val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) - val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) - val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType)) - - checkEvaluation(Size(a0), 3) - checkEvaluation(Size(a1), 0) - checkEvaluation(Size(a2), 2) - - val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType)) - val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) - val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType)) - - checkEvaluation(Size(m0), 2) - checkEvaluation(Size(m1), 0) - checkEvaluation(Size(m2), 1) - - checkEvaluation(Size(Literal.create(null, MapType(StringType, StringType))), -1) - checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), -1) - } - - test("MapKeys/MapValues") { - val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) - val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) - val m2 = Literal.create(null, MapType(StringType, StringType)) - - checkEvaluation(MapKeys(m0), Seq("a", "b")) - checkEvaluation(MapValues(m0), Seq("1", "2")) - checkEvaluation(MapKeys(m1), Seq()) - checkEvaluation(MapValues(m1), Seq()) - checkEvaluation(MapKeys(m2), null) - checkEvaluation(MapValues(m2), null) - } - - test("Sort Array") { - val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) - val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) - val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) - val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) - val a4 = Literal.create(Seq(null, null), ArrayType(NullType)) - - checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) - checkEvaluation(new 
SortArray(a1), Seq[Integer]()) - checkEvaluation(new SortArray(a2), Seq("a", "b")) - checkEvaluation(new SortArray(a3), Seq(null, "a", "b")) - checkEvaluation(SortArray(a0, Literal(true)), Seq(1, 2, 3)) - checkEvaluation(SortArray(a1, Literal(true)), Seq[Integer]()) - checkEvaluation(SortArray(a2, Literal(true)), Seq("a", "b")) - checkEvaluation(new SortArray(a3, Literal(true)), Seq(null, "a", "b")) - checkEvaluation(SortArray(a0, Literal(false)), Seq(3, 2, 1)) - checkEvaluation(SortArray(a1, Literal(false)), Seq[Integer]()) - checkEvaluation(SortArray(a2, Literal(false)), Seq("b", "a")) - checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null)) - - checkEvaluation(Literal.create(null, ArrayType(StringType)), null) - checkEvaluation(new SortArray(a4), Seq(null, null)) - - val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) - val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS) - - checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2))) - } - - test("Array contains") { - val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) - val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) - val a2 = Literal.create(Seq(null), ArrayType(LongType)) - val a3 = Literal.create(null, ArrayType(StringType)) - - checkEvaluation(ArrayContains(a0, Literal(1)), true) - checkEvaluation(ArrayContains(a0, Literal(0)), false) - checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null) - - checkEvaluation(ArrayContains(a1, Literal("")), true) - checkEvaluation(ArrayContains(a1, Literal("a")), null) - checkEvaluation(ArrayContains(a1, Literal.create(null, StringType)), null) - - checkEvaluation(ArrayContains(a2, Literal(1L)), null) - checkEvaluation(ArrayContains(a2, Literal.create(null, LongType)), null) - - checkEvaluation(ArrayContains(a3, Literal("")), null) - checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) - } -} 
http://git-wip-us.apache.org/repos/asf/spark/blob/2fa1a632/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala new file mode 100644 index 0000000..6b5bfac --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -0,0 +1,582 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import java.nio.charset.StandardCharsets + +import com.google.common.math.LongMath + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} +import org.apache.spark.sql.types._ + +class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + import IntegralLiteralTestUtils._ + + /** + * Used for testing leaf math expressions. + * + * @param e expression + * @param c The constants in scala.math + * @tparam T Generic type for primitives + */ + private def testLeaf[T]( + e: () => Expression, + c: T): Unit = { + checkEvaluation(e(), c, EmptyRow) + checkEvaluation(e(), c, create_row(null)) + } + + /** + * Used for testing unary math expressions. 
+ * + * @param c expression + * @param f The functions in scala.math or elsewhere used to generate expected results + * @param domain The set of values to run the function with + * @param expectNull Whether the given values should return null or not + * @param expectNaN Whether the given values should eval to NaN or not + * @tparam T Generic type for primitives + * @tparam U Generic type for the output of the given function `f` + */ + private def testUnary[T, U]( + c: Expression => Expression, + f: T => U, + domain: Iterable[T] = (-20 to 20).map(_ * 0.1), + expectNull: Boolean = false, + expectNaN: Boolean = false, + evalType: DataType = DoubleType): Unit = { + if (expectNull) { + domain.foreach { value => + checkEvaluation(c(Literal(value)), null, EmptyRow) + } + } else if (expectNaN) { + domain.foreach { value => + checkNaN(c(Literal(value)), EmptyRow) + } + } else { + domain.foreach { value => + checkEvaluation(c(Literal(value)), f(value), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, evalType)), null, create_row(null)) + } + + /** + * Used for testing binary math expressions. 
+ * + * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + * @param expectNull Whether the given values should return null or not + * @param expectNaN Whether the given values should eval to NaN or not + */ + private def testBinary( + c: (Expression, Expression) => Expression, + f: (Double, Double) => Double, + domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), + expectNull: Boolean = false, expectNaN: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { case (v1, v2) => + checkEvaluation(c(Literal(v1), Literal(v2)), null, create_row(null)) + } + } else if (expectNaN) { + domain.foreach { case (v1, v2) => + checkNaN(c(Literal(v1), Literal(v2)), EmptyRow) + } + } else { + domain.foreach { case (v1, v2) => + checkEvaluation(c(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow) + checkEvaluation(c(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType), Literal(1.0)), null, create_row(null)) + checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) + } + + private def checkNaN( + expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { + checkNaNWithoutCodegen(expression, inputRow) + checkNaNWithGeneratedProjection(expression, inputRow) + checkNaNWithOptimization(expression, inputRow) + } + + private def checkNaNWithoutCodegen( + expression: Expression, + inputRow: InternalRow = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + if (!actual.asInstanceOf[Double].isNaN) { + fail(s"Incorrect evaluation (codegen off): $expression, " + + s"actual: $actual, " + + s"expected: NaN") + } + } + + private def checkNaNWithGeneratedProjection( + expression: Expression, + inputRow: InternalRow = EmptyRow): Unit = { + + val plan = 
generateProject( + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), + expression) + + val actual = plan(inputRow).get(0, expression.dataType) + if (!actual.asInstanceOf[Double].isNaN) { + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: NaN") + } + } + + private def checkNaNWithOptimization( + expression: Expression, + inputRow: InternalRow = EmptyRow): Unit = { + val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) + val optimizedPlan = SimpleTestOptimizer.execute(plan) + checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) + } + + test("conv") { + checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11") + checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") + checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1") + checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48") + checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null) + checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null) + checkEvaluation(Conv(Literal("3"), Literal(16), Literal.create(null, IntegerType)), null) + checkEvaluation( + Conv(Literal("1234"), Literal(10), Literal(37)), null) + checkEvaluation( + Conv(Literal(""), Literal(10), Literal(16)), null) + checkEvaluation( + Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF") + // If there is an invalid digit in the number, the longest valid prefix should be converted. 
+ checkEvaluation( + Conv(Literal("11abc"), Literal(10), Literal(16)), "B") + } + + test("e") { + testLeaf(EulerNumber, math.E) + } + + test("pi") { + testLeaf(Pi, math.Pi) + } + + test("sin") { + testUnary(Sin, math.sin) + checkConsistencyBetweenInterpretedAndCodegen(Sin, DoubleType) + } + + test("asin") { + testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1)) + testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNaN = true) + checkConsistencyBetweenInterpretedAndCodegen(Asin, DoubleType) + } + + test("sinh") { + testUnary(Sinh, math.sinh) + checkConsistencyBetweenInterpretedAndCodegen(Sinh, DoubleType) + } + + test("cos") { + testUnary(Cos, math.cos) + checkConsistencyBetweenInterpretedAndCodegen(Cos, DoubleType) + } + + test("acos") { + testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) + testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true) + checkConsistencyBetweenInterpretedAndCodegen(Acos, DoubleType) + } + + test("cosh") { + testUnary(Cosh, math.cosh) + checkConsistencyBetweenInterpretedAndCodegen(Cosh, DoubleType) + } + + test("tan") { + testUnary(Tan, math.tan) + checkConsistencyBetweenInterpretedAndCodegen(Tan, DoubleType) + } + + test("atan") { + testUnary(Atan, math.atan) + checkConsistencyBetweenInterpretedAndCodegen(Atan, DoubleType) + } + + test("tanh") { + testUnary(Tanh, math.tanh) + checkConsistencyBetweenInterpretedAndCodegen(Tanh, DoubleType) + } + + test("toDegrees") { + testUnary(ToDegrees, math.toDegrees) + checkConsistencyBetweenInterpretedAndCodegen(Acos, DoubleType) + } + + test("toRadians") { + testUnary(ToRadians, math.toRadians) + checkConsistencyBetweenInterpretedAndCodegen(ToRadians, DoubleType) + } + + test("cbrt") { + testUnary(Cbrt, math.cbrt) + checkConsistencyBetweenInterpretedAndCodegen(Cbrt, DoubleType) + } + + test("ceil") { + testUnary(Ceil, (d: Double) => math.ceil(d).toLong) + checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType) + + testUnary(Ceil, (d: Decimal) => d.ceil, (-20 to 20).map(x 
=> Decimal(x * 0.1))) + checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3)) + checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0)) + checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0)) + } + + test("floor") { + testUnary(Floor, (d: Double) => math.floor(d).toLong) + checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType) + + testUnary(Floor, (d: Decimal) => d.floor, (-20 to 20).map(x => Decimal(x * 0.1))) + checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3)) + checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0)) + checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0)) + } + + test("factorial") { + (0 to 20).foreach { value => + checkEvaluation(Factorial(Literal(value)), LongMath.factorial(value), EmptyRow) + } + checkEvaluation(Literal.create(null, IntegerType), null, create_row(null)) + checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow) + checkEvaluation(Factorial(Literal(21)), null, EmptyRow) + checkConsistencyBetweenInterpretedAndCodegen(Factorial.apply _, IntegerType) + } + + test("rint") { + testUnary(Rint, math.rint) + checkConsistencyBetweenInterpretedAndCodegen(Rint, DoubleType) + } + + test("exp") { + testUnary(Exp, math.exp) + checkConsistencyBetweenInterpretedAndCodegen(Exp, DoubleType) + } + + test("expm1") { + testUnary(Expm1, math.expm1) + checkConsistencyBetweenInterpretedAndCodegen(Expm1, DoubleType) + } + + test("signum") { + testUnary[Double, Double](Signum, math.signum) + checkConsistencyBetweenInterpretedAndCodegen(Signum, DoubleType) + } + + test("log") { + testUnary(Log, math.log, (1 to 20).map(_ * 0.1)) + testUnary(Log, math.log, (-5 to 0).map(_ * 0.1), expectNull = true) + checkConsistencyBetweenInterpretedAndCodegen(Log, DoubleType) + } + + test("log10") { + testUnary(Log10, math.log10, (1 to 20).map(_ * 0.1)) + testUnary(Log10, math.log10, (-5 to 0).map(_ * 0.1), expectNull = true) + 
checkConsistencyBetweenInterpretedAndCodegen(Log10, DoubleType) + } + + test("log1p") { + testUnary(Log1p, math.log1p, (0 to 20).map(_ * 0.1)) + testUnary(Log1p, math.log1p, (-10 to -1).map(_ * 1.0), expectNull = true) + checkConsistencyBetweenInterpretedAndCodegen(Log1p, DoubleType) + } + + test("bin") { + testUnary(Bin, java.lang.Long.toBinaryString, (-20 to 20).map(_.toLong), evalType = LongType) + + val row = create_row(null, 12L, 123L, 1234L, -123L) + val l1 = 'a.long.at(0) + val l2 = 'a.long.at(1) + val l3 = 'a.long.at(2) + val l4 = 'a.long.at(3) + val l5 = 'a.long.at(4) + + checkEvaluation(Bin(l1), null, row) + checkEvaluation(Bin(l2), java.lang.Long.toBinaryString(12), row) + checkEvaluation(Bin(l3), java.lang.Long.toBinaryString(123), row) + checkEvaluation(Bin(l4), java.lang.Long.toBinaryString(1234), row) + checkEvaluation(Bin(l5), java.lang.Long.toBinaryString(-123), row) + + checkEvaluation(Bin(positiveLongLit), java.lang.Long.toBinaryString(positiveLong)) + checkEvaluation(Bin(negativeLongLit), java.lang.Long.toBinaryString(negativeLong)) + + checkConsistencyBetweenInterpretedAndCodegen(Bin, LongType) + } + + test("log2") { + def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2) + testUnary(Log2, f, (1 to 20).map(_ * 0.1)) + testUnary(Log2, f, (-5 to 0).map(_ * 1.0), expectNull = true) + checkConsistencyBetweenInterpretedAndCodegen(Log2, DoubleType) + } + + test("sqrt") { + testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1)) + testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNaN = true) + + checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) + checkNaN(Sqrt(Literal(-1.0)), EmptyRow) + checkNaN(Sqrt(Literal(-1.5)), EmptyRow) + checkConsistencyBetweenInterpretedAndCodegen(Sqrt, DoubleType) + } + + test("pow") { + testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) + testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) + 
checkConsistencyBetweenInterpretedAndCodegen(Pow, DoubleType, DoubleType) + } + + test("shift left") { + checkEvaluation(ShiftLeft(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(ShiftLeft(Literal(21), Literal.create(null, IntegerType)), null) + checkEvaluation( + ShiftLeft(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) + checkEvaluation(ShiftLeft(Literal(21), Literal(1)), 42) + + checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong) + checkEvaluation(ShiftLeft(Literal(-21.toLong), Literal(1)), -42.toLong) + + checkEvaluation(ShiftLeft(positiveIntLit, positiveIntLit), positiveInt << positiveInt) + checkEvaluation(ShiftLeft(positiveIntLit, negativeIntLit), positiveInt << negativeInt) + checkEvaluation(ShiftLeft(negativeIntLit, positiveIntLit), negativeInt << positiveInt) + checkEvaluation(ShiftLeft(negativeIntLit, negativeIntLit), negativeInt << negativeInt) + checkEvaluation(ShiftLeft(positiveLongLit, positiveIntLit), positiveLong << positiveInt) + checkEvaluation(ShiftLeft(positiveLongLit, negativeIntLit), positiveLong << negativeInt) + checkEvaluation(ShiftLeft(negativeLongLit, positiveIntLit), negativeLong << positiveInt) + checkEvaluation(ShiftLeft(negativeLongLit, negativeIntLit), negativeLong << negativeInt) + + checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, LongType, IntegerType) + } + + test("shift right") { + checkEvaluation(ShiftRight(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(ShiftRight(Literal(42), Literal.create(null, IntegerType)), null) + checkEvaluation( + ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) + checkEvaluation(ShiftRight(Literal(42), Literal(1)), 21) + + checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong) + checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong) + + 
checkEvaluation(ShiftRight(positiveIntLit, positiveIntLit), positiveInt >> positiveInt) + checkEvaluation(ShiftRight(positiveIntLit, negativeIntLit), positiveInt >> negativeInt) + checkEvaluation(ShiftRight(negativeIntLit, positiveIntLit), negativeInt >> positiveInt) + checkEvaluation(ShiftRight(negativeIntLit, negativeIntLit), negativeInt >> negativeInt) + checkEvaluation(ShiftRight(positiveLongLit, positiveIntLit), positiveLong >> positiveInt) + checkEvaluation(ShiftRight(positiveLongLit, negativeIntLit), positiveLong >> negativeInt) + checkEvaluation(ShiftRight(negativeLongLit, positiveIntLit), negativeLong >> positiveInt) + checkEvaluation(ShiftRight(negativeLongLit, negativeIntLit), negativeLong >> negativeInt) + + checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, LongType, IntegerType) + } + + test("shift right unsigned") { + checkEvaluation(ShiftRightUnsigned(Literal.create(null, IntegerType), Literal(1)), null) + checkEvaluation(ShiftRightUnsigned(Literal(42), Literal.create(null, IntegerType)), null) + checkEvaluation( + ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) + checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21) + + checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong) + checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L) + + checkEvaluation(ShiftRightUnsigned(positiveIntLit, positiveIntLit), + positiveInt >>> positiveInt) + checkEvaluation(ShiftRightUnsigned(positiveIntLit, negativeIntLit), + positiveInt >>> negativeInt) + checkEvaluation(ShiftRightUnsigned(negativeIntLit, positiveIntLit), + negativeInt >>> positiveInt) + checkEvaluation(ShiftRightUnsigned(negativeIntLit, negativeIntLit), + negativeInt >>> negativeInt) + checkEvaluation(ShiftRightUnsigned(positiveLongLit, positiveIntLit), + positiveLong >>> positiveInt) + 
checkEvaluation(ShiftRightUnsigned(positiveLongLit, negativeIntLit), + positiveLong >>> negativeInt) + checkEvaluation(ShiftRightUnsigned(negativeLongLit, positiveIntLit), + negativeLong >>> positiveInt) + checkEvaluation(ShiftRightUnsigned(negativeLongLit, negativeIntLit), + negativeLong >>> negativeInt) + + checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, IntegerType, IntegerType) + checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, LongType, IntegerType) + } + + test("hex") { + checkEvaluation(Hex(Literal.create(null, LongType)), null) + checkEvaluation(Hex(Literal(28L)), "1C") + checkEvaluation(Hex(Literal(-28L)), "FFFFFFFFFFFFFFE4") + checkEvaluation(Hex(Literal(100800200404L)), "177828FED4") + checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C") + checkEvaluation(Hex(Literal.create(null, BinaryType)), null) + checkEvaluation(Hex(Literal("helloHex".getBytes(StandardCharsets.UTF_8))), "68656C6C6F486578") + // scalastyle:off + // Turn off scala style for non-ascii chars + checkEvaluation(Hex(Literal("三重的".getBytes(StandardCharsets.UTF_8))), "E4B889E9878DE79A84") + // scalastyle:on + Seq(LongType, BinaryType, StringType).foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen(Hex.apply _, dt) + } + } + + test("unhex") { + checkEvaluation(Unhex(Literal.create(null, StringType)), null) + checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes(StandardCharsets.UTF_8)) + checkEvaluation(Unhex(Literal("")), new Array[Byte](0)) + checkEvaluation(Unhex(Literal("F")), Array[Byte](15)) + checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1)) + checkEvaluation(Unhex(Literal("GG")), null) + // scalastyle:off + // Turn off scala style for non-ascii chars + checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes(StandardCharsets.UTF_8)) + checkEvaluation(Unhex(Literal("三重的")), null) + // scalastyle:on + checkConsistencyBetweenInterpretedAndCodegen(Unhex, StringType) + } + + test("hypot") { +
testBinary(Hypot, math.hypot) + checkConsistencyBetweenInterpretedAndCodegen(Hypot, DoubleType, DoubleType) + } + + test("atan2") { + testBinary(Atan2, math.atan2) + checkConsistencyBetweenInterpretedAndCodegen(Atan2, DoubleType, DoubleType) + } + + test("binary log") { + val f = (c1: Double, c2: Double) => math.log(c2) / math.log(c1) + val domain = (1 to 20).map(v => (v * 0.1, v * 0.2)) + + domain.foreach { case (v1, v2) => + checkEvaluation(Logarithm(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow) + checkEvaluation(Logarithm(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow) + checkEvaluation(new Logarithm(Literal(v1)), f(math.E, v1 + 0.0), EmptyRow) + } + + // null input should yield null output + checkEvaluation( + Logarithm(Literal.create(null, DoubleType), Literal(1.0)), + null, + create_row(null)) + checkEvaluation( + Logarithm(Literal(1.0), Literal.create(null, DoubleType)), + null, + create_row(null)) + + // negative input should yield null output + checkEvaluation( + Logarithm(Literal(-1.0), Literal(1.0)), + null, + create_row(null)) + checkEvaluation( + Logarithm(Literal(1.0), Literal(-1.0)), + null, + create_row(null)) + checkConsistencyBetweenInterpretedAndCodegen(Logarithm, DoubleType, DoubleType) + } + + test("round/bround") { + val scales = -6 to 6 + val doublePi: Double = math.Pi + val shortPi: Short = 31415 + val intPi: Int = 314159265 + val longPi: Long = 31415926535897932L + val bdPi: BigDecimal = BigDecimal(31415927L, 7) + + val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142, + 3.1416, 3.14159, 3.141593) + + val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++ + Seq.fill[Short](7)(31415) + + val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300, + 314159270) ++ Seq.fill(7)(314159265) + + val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L, + 31415926535900000L, 31415926535898000L, 31415926535897900L, 
31415926535897930L) ++ + Seq.fill(7)(31415926535897932L) + + val intResultsB: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300, + 314159260) ++ Seq.fill(7)(314159265) + + scales.zipWithIndex.foreach { case (scale, i) => + checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow) + checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) + checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) + checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) + checkEvaluation(BRound(doublePi, scale), doubleResults(i), EmptyRow) + checkEvaluation(BRound(shortPi, scale), shortResults(i), EmptyRow) + checkEvaluation(BRound(intPi, scale), intResultsB(i), EmptyRow) + checkEvaluation(BRound(longPi, scale), longResults(i), EmptyRow) + } + + val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), + BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), + BigDecimal(3.141593), BigDecimal(3.1415927)) + // round_scale > current_scale would result in precision increase + // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null + (0 to 7).foreach { i => + checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) + checkEvaluation(BRound(bdPi, i), bdResults(i), EmptyRow) + } + (8 to 10).foreach { scale => + checkEvaluation(Round(bdPi, scale), null, EmptyRow) + checkEvaluation(BRound(bdPi, scale), null, EmptyRow) + } + + DataTypeTestUtils.numericTypes.foreach { dataType => + checkEvaluation(Round(Literal.create(null, dataType), Literal(2)), null) + checkEvaluation(Round(Literal.create(null, dataType), + Literal.create(null, IntegerType)), null) + checkEvaluation(BRound(Literal.create(null, dataType), Literal(2)), null) + checkEvaluation(BRound(Literal.create(null, dataType), + Literal.create(null, IntegerType)), null) + } + + checkEvaluation(Round(2.5, 0), 3.0) + checkEvaluation(Round(3.5, 0), 4.0) + checkEvaluation(Round(-2.5, 0), -3.0) + checkEvaluation(Round(-3.5, 0), -4.0) + 
checkEvaluation(Round(-0.35, 1), -0.4) + checkEvaluation(Round(-35, -1), -40) + checkEvaluation(BRound(2.5, 0), 2.0) + checkEvaluation(BRound(3.5, 0), 4.0) + checkEvaluation(BRound(-2.5, 0), -2.0) + checkEvaluation(BRound(-3.5, 0), -4.0) + checkEvaluation(BRound(-0.35, 1), -0.4) + checkEvaluation(BRound(-35, -1), -40) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/2fa1a632/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala deleted file mode 100644 index f88c9e8..0000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ /dev/null @@ -1,582 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions - -import java.nio.charset.StandardCharsets - -import com.google.common.math.LongMath - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer -import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} -import org.apache.spark.sql.types._ - -class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { - - import IntegralLiteralTestUtils._ - - /** - * Used for testing leaf math expressions. - * - * @param e expression - * @param c The constants in scala.math - * @tparam T Generic type for primitives - */ - private def testLeaf[T]( - e: () => Expression, - c: T): Unit = { - checkEvaluation(e(), c, EmptyRow) - checkEvaluation(e(), c, create_row(null)) - } - - /** - * Used for testing unary math expressions. 
- * - * @param c expression - * @param f The functions in scala.math or elsewhere used to generate expected results - * @param domain The set of values to run the function with - * @param expectNull Whether the given values should return null or not - * @param expectNaN Whether the given values should eval to NaN or not - * @tparam T Generic type for primitives - * @tparam U Generic type for the output of the given function `f` - */ - private def testUnary[T, U]( - c: Expression => Expression, - f: T => U, - domain: Iterable[T] = (-20 to 20).map(_ * 0.1), - expectNull: Boolean = false, - expectNaN: Boolean = false, - evalType: DataType = DoubleType): Unit = { - if (expectNull) { - domain.foreach { value => - checkEvaluation(c(Literal(value)), null, EmptyRow) - } - } else if (expectNaN) { - domain.foreach { value => - checkNaN(c(Literal(value)), EmptyRow) - } - } else { - domain.foreach { value => - checkEvaluation(c(Literal(value)), f(value), EmptyRow) - } - } - checkEvaluation(c(Literal.create(null, evalType)), null, create_row(null)) - } - - /** - * Used for testing binary math expressions. 
- * - * @param c The DataFrame function - * @param f The functions in scala.math - * @param domain The set of values to run the function with - * @param expectNull Whether the given values should return null or not - * @param expectNaN Whether the given values should eval to NaN or not - */ - private def testBinary( - c: (Expression, Expression) => Expression, - f: (Double, Double) => Double, - domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), - expectNull: Boolean = false, expectNaN: Boolean = false): Unit = { - if (expectNull) { - domain.foreach { case (v1, v2) => - checkEvaluation(c(Literal(v1), Literal(v2)), null, create_row(null)) - } - } else if (expectNaN) { - domain.foreach { case (v1, v2) => - checkNaN(c(Literal(v1), Literal(v2)), EmptyRow) - } - } else { - domain.foreach { case (v1, v2) => - checkEvaluation(c(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow) - checkEvaluation(c(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow) - } - } - checkEvaluation(c(Literal.create(null, DoubleType), Literal(1.0)), null, create_row(null)) - checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) - } - - private def checkNaN( - expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { - checkNaNWithoutCodegen(expression, inputRow) - checkNaNWithGeneratedProjection(expression, inputRow) - checkNaNWithOptimization(expression, inputRow) - } - - private def checkNaNWithoutCodegen( - expression: Expression, - inputRow: InternalRow = EmptyRow): Unit = { - val actual = try evaluate(expression, inputRow) catch { - case e: Exception => fail(s"Exception evaluating $expression", e) - } - if (!actual.asInstanceOf[Double].isNaN) { - fail(s"Incorrect evaluation (codegen off): $expression, " + - s"actual: $actual, " + - s"expected: NaN") - } - } - - private def checkNaNWithGeneratedProjection( - expression: Expression, - inputRow: InternalRow = EmptyRow): Unit = { - - val plan = 
generateProject( - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), - expression) - - val actual = plan(inputRow).get(0, expression.dataType) - if (!actual.asInstanceOf[Double].isNaN) { - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: NaN") - } - } - - private def checkNaNWithOptimization( - expression: Expression, - inputRow: InternalRow = EmptyRow): Unit = { - val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = SimpleTestOptimizer.execute(plan) - checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) - } - - test("conv") { - checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11") - checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") - checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1") - checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48") - checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null) - checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null) - checkEvaluation(Conv(Literal("3"), Literal(16), Literal.create(null, IntegerType)), null) - checkEvaluation( - Conv(Literal("1234"), Literal(10), Literal(37)), null) - checkEvaluation( - Conv(Literal(""), Literal(10), Literal(16)), null) - checkEvaluation( - Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF") - // If there is an invalid digit in the number, the longest valid prefix should be converted. 
- checkEvaluation( - Conv(Literal("11abc"), Literal(10), Literal(16)), "B") - } - - test("e") { - testLeaf(EulerNumber, math.E) - } - - test("pi") { - testLeaf(Pi, math.Pi) - } - - test("sin") { - testUnary(Sin, math.sin) - checkConsistencyBetweenInterpretedAndCodegen(Sin, DoubleType) - } - - test("asin") { - testUnary(Asin, math.asin, (-10 to 10).map(_ * 0.1)) - testUnary(Asin, math.asin, (11 to 20).map(_ * 0.1), expectNaN = true) - checkConsistencyBetweenInterpretedAndCodegen(Asin, DoubleType) - } - - test("sinh") { - testUnary(Sinh, math.sinh) - checkConsistencyBetweenInterpretedAndCodegen(Sinh, DoubleType) - } - - test("cos") { - testUnary(Cos, math.cos) - checkConsistencyBetweenInterpretedAndCodegen(Cos, DoubleType) - } - - test("acos") { - testUnary(Acos, math.acos, (-10 to 10).map(_ * 0.1)) - testUnary(Acos, math.acos, (11 to 20).map(_ * 0.1), expectNaN = true) - checkConsistencyBetweenInterpretedAndCodegen(Acos, DoubleType) - } - - test("cosh") { - testUnary(Cosh, math.cosh) - checkConsistencyBetweenInterpretedAndCodegen(Cosh, DoubleType) - } - - test("tan") { - testUnary(Tan, math.tan) - checkConsistencyBetweenInterpretedAndCodegen(Tan, DoubleType) - } - - test("atan") { - testUnary(Atan, math.atan) - checkConsistencyBetweenInterpretedAndCodegen(Atan, DoubleType) - } - - test("tanh") { - testUnary(Tanh, math.tanh) - checkConsistencyBetweenInterpretedAndCodegen(Tanh, DoubleType) - } - - test("toDegrees") { - testUnary(ToDegrees, math.toDegrees) - checkConsistencyBetweenInterpretedAndCodegen(Acos, DoubleType) - } - - test("toRadians") { - testUnary(ToRadians, math.toRadians) - checkConsistencyBetweenInterpretedAndCodegen(ToRadians, DoubleType) - } - - test("cbrt") { - testUnary(Cbrt, math.cbrt) - checkConsistencyBetweenInterpretedAndCodegen(Cbrt, DoubleType) - } - - test("ceil") { - testUnary(Ceil, (d: Double) => math.ceil(d).toLong) - checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType) - - testUnary(Ceil, (d: Decimal) => d.ceil, (-20 to 20).map(x 
=> Decimal(x * 0.1))) - checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3)) - checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0)) - checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0)) - } - - test("floor") { - testUnary(Floor, (d: Double) => math.floor(d).toLong) - checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType) - - testUnary(Floor, (d: Decimal) => d.floor, (-20 to 20).map(x => Decimal(x * 0.1))) - checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3)) - checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0)) - checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0)) - } - - test("factorial") { - (0 to 20).foreach { value => - checkEvaluation(Factorial(Literal(value)), LongMath.factorial(value), EmptyRow) - } - checkEvaluation(Literal.create(null, IntegerType), null, create_row(null)) - checkEvaluation(Factorial(Literal(20)), 2432902008176640000L, EmptyRow) - checkEvaluation(Factorial(Literal(21)), null, EmptyRow) - checkConsistencyBetweenInterpretedAndCodegen(Factorial.apply _, IntegerType) - } - - test("rint") { - testUnary(Rint, math.rint) - checkConsistencyBetweenInterpretedAndCodegen(Rint, DoubleType) - } - - test("exp") { - testUnary(Exp, math.exp) - checkConsistencyBetweenInterpretedAndCodegen(Exp, DoubleType) - } - - test("expm1") { - testUnary(Expm1, math.expm1) - checkConsistencyBetweenInterpretedAndCodegen(Expm1, DoubleType) - } - - test("signum") { - testUnary[Double, Double](Signum, math.signum) - checkConsistencyBetweenInterpretedAndCodegen(Signum, DoubleType) - } - - test("log") { - testUnary(Log, math.log, (1 to 20).map(_ * 0.1)) - testUnary(Log, math.log, (-5 to 0).map(_ * 0.1), expectNull = true) - checkConsistencyBetweenInterpretedAndCodegen(Log, DoubleType) - } - - test("log10") { - testUnary(Log10, math.log10, (1 to 20).map(_ * 0.1)) - testUnary(Log10, math.log10, (-5 to 0).map(_ * 0.1), expectNull = true) - 
checkConsistencyBetweenInterpretedAndCodegen(Log10, DoubleType) - } - - test("log1p") { - testUnary(Log1p, math.log1p, (0 to 20).map(_ * 0.1)) - testUnary(Log1p, math.log1p, (-10 to -1).map(_ * 1.0), expectNull = true) - checkConsistencyBetweenInterpretedAndCodegen(Log1p, DoubleType) - } - - test("bin") { - testUnary(Bin, java.lang.Long.toBinaryString, (-20 to 20).map(_.toLong), evalType = LongType) - - val row = create_row(null, 12L, 123L, 1234L, -123L) - val l1 = 'a.long.at(0) - val l2 = 'a.long.at(1) - val l3 = 'a.long.at(2) - val l4 = 'a.long.at(3) - val l5 = 'a.long.at(4) - - checkEvaluation(Bin(l1), null, row) - checkEvaluation(Bin(l2), java.lang.Long.toBinaryString(12), row) - checkEvaluation(Bin(l3), java.lang.Long.toBinaryString(123), row) - checkEvaluation(Bin(l4), java.lang.Long.toBinaryString(1234), row) - checkEvaluation(Bin(l5), java.lang.Long.toBinaryString(-123), row) - - checkEvaluation(Bin(positiveLongLit), java.lang.Long.toBinaryString(positiveLong)) - checkEvaluation(Bin(negativeLongLit), java.lang.Long.toBinaryString(negativeLong)) - - checkConsistencyBetweenInterpretedAndCodegen(Bin, LongType) - } - - test("log2") { - def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2) - testUnary(Log2, f, (1 to 20).map(_ * 0.1)) - testUnary(Log2, f, (-5 to 0).map(_ * 1.0), expectNull = true) - checkConsistencyBetweenInterpretedAndCodegen(Log2, DoubleType) - } - - test("sqrt") { - testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1)) - testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNaN = true) - - checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) - checkNaN(Sqrt(Literal(-1.0)), EmptyRow) - checkNaN(Sqrt(Literal(-1.5)), EmptyRow) - checkConsistencyBetweenInterpretedAndCodegen(Sqrt, DoubleType) - } - - test("pow") { - testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) - testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNaN = true) - 
checkConsistencyBetweenInterpretedAndCodegen(Pow, DoubleType, DoubleType) - } - - test("shift left") { - checkEvaluation(ShiftLeft(Literal.create(null, IntegerType), Literal(1)), null) - checkEvaluation(ShiftLeft(Literal(21), Literal.create(null, IntegerType)), null) - checkEvaluation( - ShiftLeft(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) - checkEvaluation(ShiftLeft(Literal(21), Literal(1)), 42) - - checkEvaluation(ShiftLeft(Literal(21.toLong), Literal(1)), 42.toLong) - checkEvaluation(ShiftLeft(Literal(-21.toLong), Literal(1)), -42.toLong) - - checkEvaluation(ShiftLeft(positiveIntLit, positiveIntLit), positiveInt << positiveInt) - checkEvaluation(ShiftLeft(positiveIntLit, negativeIntLit), positiveInt << negativeInt) - checkEvaluation(ShiftLeft(negativeIntLit, positiveIntLit), negativeInt << positiveInt) - checkEvaluation(ShiftLeft(negativeIntLit, negativeIntLit), negativeInt << negativeInt) - checkEvaluation(ShiftLeft(positiveLongLit, positiveIntLit), positiveLong << positiveInt) - checkEvaluation(ShiftLeft(positiveLongLit, negativeIntLit), positiveLong << negativeInt) - checkEvaluation(ShiftLeft(negativeLongLit, positiveIntLit), negativeLong << positiveInt) - checkEvaluation(ShiftLeft(negativeLongLit, negativeIntLit), negativeLong << negativeInt) - - checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, IntegerType, IntegerType) - checkConsistencyBetweenInterpretedAndCodegen(ShiftLeft, LongType, IntegerType) - } - - test("shift right") { - checkEvaluation(ShiftRight(Literal.create(null, IntegerType), Literal(1)), null) - checkEvaluation(ShiftRight(Literal(42), Literal.create(null, IntegerType)), null) - checkEvaluation( - ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) - checkEvaluation(ShiftRight(Literal(42), Literal(1)), 21) - - checkEvaluation(ShiftRight(Literal(42.toLong), Literal(1)), 21.toLong) - checkEvaluation(ShiftRight(Literal(-42.toLong), Literal(1)), -21.toLong) - - 
checkEvaluation(ShiftRight(positiveIntLit, positiveIntLit), positiveInt >> positiveInt) - checkEvaluation(ShiftRight(positiveIntLit, negativeIntLit), positiveInt >> negativeInt) - checkEvaluation(ShiftRight(negativeIntLit, positiveIntLit), negativeInt >> positiveInt) - checkEvaluation(ShiftRight(negativeIntLit, negativeIntLit), negativeInt >> negativeInt) - checkEvaluation(ShiftRight(positiveLongLit, positiveIntLit), positiveLong >> positiveInt) - checkEvaluation(ShiftRight(positiveLongLit, negativeIntLit), positiveLong >> negativeInt) - checkEvaluation(ShiftRight(negativeLongLit, positiveIntLit), negativeLong >> positiveInt) - checkEvaluation(ShiftRight(negativeLongLit, negativeIntLit), negativeLong >> negativeInt) - - checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, IntegerType, IntegerType) - checkConsistencyBetweenInterpretedAndCodegen(ShiftRight, LongType, IntegerType) - } - - test("shift right unsigned") { - checkEvaluation(ShiftRightUnsigned(Literal.create(null, IntegerType), Literal(1)), null) - checkEvaluation(ShiftRightUnsigned(Literal(42), Literal.create(null, IntegerType)), null) - checkEvaluation( - ShiftRight(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null) - checkEvaluation(ShiftRightUnsigned(Literal(42), Literal(1)), 21) - - checkEvaluation(ShiftRightUnsigned(Literal(42.toLong), Literal(1)), 21.toLong) - checkEvaluation(ShiftRightUnsigned(Literal(-42.toLong), Literal(1)), 9223372036854775787L) - - checkEvaluation(ShiftRightUnsigned(positiveIntLit, positiveIntLit), - positiveInt >>> positiveInt) - checkEvaluation(ShiftRightUnsigned(positiveIntLit, negativeIntLit), - positiveInt >>> negativeInt) - checkEvaluation(ShiftRightUnsigned(negativeIntLit, positiveIntLit), - negativeInt >>> positiveInt) - checkEvaluation(ShiftRightUnsigned(negativeIntLit, negativeIntLit), - negativeInt >>> negativeInt) - checkEvaluation(ShiftRightUnsigned(positiveLongLit, positiveIntLit), - positiveLong >>> positiveInt) - 
checkEvaluation(ShiftRightUnsigned(positiveLongLit, negativeIntLit), - positiveLong >>> negativeInt) - checkEvaluation(ShiftRightUnsigned(negativeLongLit, positiveIntLit), - negativeLong >>> positiveInt) - checkEvaluation(ShiftRightUnsigned(negativeLongLit, negativeIntLit), - negativeLong >>> negativeInt) - - checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, IntegerType, IntegerType) - checkConsistencyBetweenInterpretedAndCodegen(ShiftRightUnsigned, LongType, IntegerType) - } - - test("hex") { - checkEvaluation(Hex(Literal.create(null, LongType)), null) - checkEvaluation(Hex(Literal(28L)), "1C") - checkEvaluation(Hex(Literal(-28L)), "FFFFFFFFFFFFFFE4") - checkEvaluation(Hex(Literal(100800200404L)), "177828FED4") - checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C") - checkEvaluation(Hex(Literal.create(null, BinaryType)), null) - checkEvaluation(Hex(Literal("helloHex".getBytes(StandardCharsets.UTF_8))), "68656C6C6F486578") - // scalastyle:off - // Turn off scala style for non-ascii chars - checkEvaluation(Hex(Literal("三重的".getBytes(StandardCharsets.UTF_8))), "E4B889E9878DE79A84") - // scalastyle:on - Seq(LongType, BinaryType, StringType).foreach { dt => - checkConsistencyBetweenInterpretedAndCodegen(Hex.apply _, dt) - } - } - - test("unhex") { - checkEvaluation(Unhex(Literal.create(null, StringType)), null) - checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes(StandardCharsets.UTF_8)) - checkEvaluation(Unhex(Literal("")), new Array[Byte](0)) - checkEvaluation(Unhex(Literal("F")), Array[Byte](15)) - checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1)) - checkEvaluation(Unhex(Literal("GG")), null) - // scalastyle:off - // Turn off scala style for non-ascii chars - checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes(StandardCharsets.UTF_8)) - checkEvaluation(Unhex(Literal("三重的")), null) - // scalastyle:on - checkConsistencyBetweenInterpretedAndCodegen(Unhex, StringType) - } - - test("hypot") { - 
testBinary(Hypot, math.hypot) - checkConsistencyBetweenInterpretedAndCodegen(Hypot, DoubleType, DoubleType) - } - - test("atan2") { - testBinary(Atan2, math.atan2) - checkConsistencyBetweenInterpretedAndCodegen(Atan2, DoubleType, DoubleType) - } - - test("binary log") { - val f = (c1: Double, c2: Double) => math.log(c2) / math.log(c1) - val domain = (1 to 20).map(v => (v * 0.1, v * 0.2)) - - domain.foreach { case (v1, v2) => - checkEvaluation(Logarithm(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow) - checkEvaluation(Logarithm(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow) - checkEvaluation(new Logarithm(Literal(v1)), f(math.E, v1 + 0.0), EmptyRow) - } - - // null input should yield null output - checkEvaluation( - Logarithm(Literal.create(null, DoubleType), Literal(1.0)), - null, - create_row(null)) - checkEvaluation( - Logarithm(Literal(1.0), Literal.create(null, DoubleType)), - null, - create_row(null)) - - // negative input should yield null output - checkEvaluation( - Logarithm(Literal(-1.0), Literal(1.0)), - null, - create_row(null)) - checkEvaluation( - Logarithm(Literal(1.0), Literal(-1.0)), - null, - create_row(null)) - checkConsistencyBetweenInterpretedAndCodegen(Logarithm, DoubleType, DoubleType) - } - - test("round/bround") { - val scales = -6 to 6 - val doublePi: Double = math.Pi - val shortPi: Short = 31415 - val intPi: Int = 314159265 - val longPi: Long = 31415926535897932L - val bdPi: BigDecimal = BigDecimal(31415927L, 7) - - val doubleResults: Seq[Double] = Seq(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142, - 3.1416, 3.14159, 3.141593) - - val shortResults: Seq[Short] = Seq[Short](0, 0, 30000, 31000, 31400, 31420) ++ - Seq.fill[Short](7)(31415) - - val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300, - 314159270) ++ Seq.fill(7)(314159265) - - val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L, - 31415926535900000L, 31415926535898000L, 31415926535897900L, 
31415926535897930L) ++ - Seq.fill(7)(31415926535897932L) - - val intResultsB: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300, - 314159260) ++ Seq.fill(7)(314159265) - - scales.zipWithIndex.foreach { case (scale, i) => - checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow) - checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) - checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) - checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) - checkEvaluation(BRound(doublePi, scale), doubleResults(i), EmptyRow) - checkEvaluation(BRound(shortPi, scale), shortResults(i), EmptyRow) - checkEvaluation(BRound(intPi, scale), intResultsB(i), EmptyRow) - checkEvaluation(BRound(longPi, scale), longResults(i), EmptyRow) - } - - val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), - BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), - BigDecimal(3.141593), BigDecimal(3.1415927)) - // round_scale > current_scale would result in precision increase - // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null - (0 to 7).foreach { i => - checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) - checkEvaluation(BRound(bdPi, i), bdResults(i), EmptyRow) - } - (8 to 10).foreach { scale => - checkEvaluation(Round(bdPi, scale), null, EmptyRow) - checkEvaluation(BRound(bdPi, scale), null, EmptyRow) - } - - DataTypeTestUtils.numericTypes.foreach { dataType => - checkEvaluation(Round(Literal.create(null, dataType), Literal(2)), null) - checkEvaluation(Round(Literal.create(null, dataType), - Literal.create(null, IntegerType)), null) - checkEvaluation(BRound(Literal.create(null, dataType), Literal(2)), null) - checkEvaluation(BRound(Literal.create(null, dataType), - Literal.create(null, IntegerType)), null) - } - - checkEvaluation(Round(2.5, 0), 3.0) - checkEvaluation(Round(3.5, 0), 4.0) - checkEvaluation(Round(-2.5, 0), -3.0) - checkEvaluation(Round(-3.5, 0), -4.0) - 
checkEvaluation(Round(-0.35, 1), -0.4) - checkEvaluation(Round(-35, -1), -40) - checkEvaluation(BRound(2.5, 0), 2.0) - checkEvaluation(BRound(3.5, 0), 4.0) - checkEvaluation(BRound(-2.5, 0), -2.0) - checkEvaluation(BRound(-3.5, 0), -4.0) - checkEvaluation(BRound(-0.35, 1), -0.4) - checkEvaluation(BRound(-35, -1), -40) - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/2fa1a632/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala new file mode 100644 index 0000000..a26d070 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +class MiscExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("assert_true") { + intercept[RuntimeException] { + checkEvaluation(AssertTrue(Literal.create(false, BooleanType)), null) + } + intercept[RuntimeException] { + checkEvaluation(AssertTrue(Cast(Literal(0), BooleanType)), null) + } + intercept[RuntimeException] { + checkEvaluation(AssertTrue(Literal.create(null, NullType)), null) + } + intercept[RuntimeException] { + checkEvaluation(AssertTrue(Literal.create(null, BooleanType)), null) + } + checkEvaluation(AssertTrue(Literal.create(true, BooleanType)), null) + checkEvaluation(AssertTrue(Cast(Literal(1), BooleanType)), null) + } + +} http://git-wip-us.apache.org/repos/asf/spark/blob/2fa1a632/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala deleted file mode 100644 index ed82efe..0000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types._ - -class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { - - test("assert_true") { - intercept[RuntimeException] { - checkEvaluation(AssertTrue(Literal.create(false, BooleanType)), null) - } - intercept[RuntimeException] { - checkEvaluation(AssertTrue(Cast(Literal(0), BooleanType)), null) - } - intercept[RuntimeException] { - checkEvaluation(AssertTrue(Literal.create(null, NullType)), null) - } - intercept[RuntimeException] { - checkEvaluation(AssertTrue(Literal.create(null, BooleanType)), null) - } - checkEvaluation(AssertTrue(Literal.create(true, BooleanType)), null) - checkEvaluation(AssertTrue(Cast(Literal(1), BooleanType)), null) - } - -} http://git-wip-us.apache.org/repos/asf/spark/blob/2fa1a632/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala new file mode 100644 index 0000000..5064a1f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.types._ + +class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + def testAllTypes(testFunc: (Any, DataType) => Unit): Unit = { + testFunc(false, BooleanType) + testFunc(1.toByte, ByteType) + testFunc(1.toShort, ShortType) + testFunc(1, IntegerType) + testFunc(1L, LongType) + testFunc(1.0F, FloatType) + testFunc(1.0, DoubleType) + testFunc(Decimal(1.5), DecimalType(2, 1)) + testFunc(new java.sql.Date(10), DateType) + testFunc(new java.sql.Timestamp(10), TimestampType) + testFunc("abcd", StringType) + } + + test("isnull and isnotnull") { + testAllTypes { (value: Any, tpe: DataType) => + checkEvaluation(IsNull(Literal.create(value, tpe)), false) + checkEvaluation(IsNotNull(Literal.create(value, tpe)), true) + checkEvaluation(IsNull(Literal.create(null, tpe)), true) + checkEvaluation(IsNotNull(Literal.create(null, tpe)), false) + } + } + + test("AssertNotNUll") { + val ex = intercept[RuntimeException] { + 
evaluate(AssertNotNull(Literal(null), Seq.empty[String])) + }.getMessage + assert(ex.contains("Null value appeared in non-nullable field")) + } + + test("IsNaN") { + checkEvaluation(IsNaN(Literal(Double.NaN)), true) + checkEvaluation(IsNaN(Literal(Float.NaN)), true) + checkEvaluation(IsNaN(Literal(math.log(-3))), true) + checkEvaluation(IsNaN(Literal.create(null, DoubleType)), false) + checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false) + checkEvaluation(IsNaN(Literal(Float.MaxValue)), false) + checkEvaluation(IsNaN(Literal(5.5f)), false) + } + + test("nanvl") { + checkEvaluation(NaNvl(Literal(5.0), Literal.create(null, DoubleType)), 5.0) + checkEvaluation(NaNvl(Literal.create(null, DoubleType), Literal(5.0)), null) + checkEvaluation(NaNvl(Literal.create(null, DoubleType), Literal(Double.NaN)), null) + checkEvaluation(NaNvl(Literal(Double.NaN), Literal(5.0)), 5.0) + checkEvaluation(NaNvl(Literal(Double.NaN), Literal.create(null, DoubleType)), null) + assert(NaNvl(Literal(Double.NaN), Literal(Double.NaN)). 
+ eval(EmptyRow).asInstanceOf[Double].isNaN) + } + + test("coalesce") { + testAllTypes { (value: Any, tpe: DataType) => + val lit = Literal.create(value, tpe) + val nullLit = Literal.create(null, tpe) + checkEvaluation(Coalesce(Seq(nullLit)), null) + checkEvaluation(Coalesce(Seq(lit)), value) + checkEvaluation(Coalesce(Seq(nullLit, lit)), value) + checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value) + checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value) + } + } + + test("SPARK-16602 Nvl should support numeric-string cases") { + def analyze(expr: Expression): Expression = { + val relation = LocalRelation() + SimpleAnalyzer.execute(Project(Seq(Alias(expr, "c")()), relation)).expressions.head + } + + val intLit = Literal.create(1, IntegerType) + val doubleLit = Literal.create(2.2, DoubleType) + val stringLit = Literal.create("c", StringType) + val nullLit = Literal.create(null, NullType) + + assert(analyze(new Nvl(intLit, doubleLit)).dataType == DoubleType) + assert(analyze(new Nvl(intLit, stringLit)).dataType == StringType) + assert(analyze(new Nvl(stringLit, doubleLit)).dataType == StringType) + + assert(analyze(new Nvl(nullLit, intLit)).dataType == IntegerType) + assert(analyze(new Nvl(doubleLit, nullLit)).dataType == DoubleType) + assert(analyze(new Nvl(nullLit, stringLit)).dataType == StringType) + } + + test("AtLeastNNonNulls") { + val mix = Seq(Literal("x"), + Literal.create(null, StringType), + Literal.create(null, DoubleType), + Literal(Double.NaN), + Literal(5f)) + + val nanOnly = Seq(Literal("x"), + Literal(10.0), + Literal(Float.NaN), + Literal(math.log(-2)), + Literal(Double.MaxValue)) + + val nullOnly = Seq(Literal("x"), + Literal.create(null, DoubleType), + Literal.create(null, DecimalType.USER_DEFAULT), + Literal(Float.MaxValue), + Literal(false)) + + checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow) + 
checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow) + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org