[GitHub] [spark] cloud-fan commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

2020-10-12 Thread GitBox


cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r503135640



##
File path: sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
##
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.Decimal
+
+class UnwrapCastInComparisonEndToEndSuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  test("cases when literal is max") {
+    val t = "test_table"
+    withTable(t) {
+      Seq[(Integer, java.lang.Short, java.lang.Float)](
+        (1, 100.toShort, 3.14.toFloat), (2, Short.MaxValue, Float.NaN), (3, null, null))
+        .toDF("c1", "c2", "c3").write.saveAsTable(t)
+      val df = spark.table(t)
+
+      var lit = Short.MaxValue.toInt
+      checkAnswer(df.select("c1").where(s"c2 > $lit"), Seq.empty)
+      checkAnswer(df.select("c1").where(s"c2 >= $lit"), Row(2))
+      checkAnswer(df.select("c1").where(s"c2 == $lit"), Row(2))
+      checkAnswer(df.select("c1").where(s"c2 <=> $lit"), Row(2))
+      checkAnswer(df.select("c1").where(s"c2 != $lit"), Row(1))
+      checkAnswer(df.select("c1").where(s"c2 <= $lit"), Row(1) :: Row(2) :: Nil)
+      checkAnswer(df.select("c1").where(s"c2 < $lit"), Row(1))
+
+      checkAnswer(df.select("c1").where(s"c3 > double('nan')"), Seq.empty)
+      checkAnswer(df.select("c1").where(s"c3 >= double('nan')"), Row(2))
+      checkAnswer(df.select("c1").where(s"c3 == double('nan')"), Row(2))
+      checkAnswer(df.select("c1").where(s"c3 <=> double('nan')"), Row(2))
+      checkAnswer(df.select("c1").where(s"c3 != double('nan')"), Row(1))
+      checkAnswer(df.select("c1").where(s"c3 <= double('nan')"), Row(1) :: Row(2) :: Nil)
+      checkAnswer(df.select("c1").where(s"c3 < double('nan')"), Row(1))
+
+      lit = positiveInt
+      checkAnswer(df.select("c1").where(s"c2 > $lit"), Seq.empty)
+      checkAnswer(df.select("c1").where(s"c2 >= $lit"), Seq.empty)
+      checkAnswer(df.select("c1").where(s"c2 == $lit"), Seq.empty)
+      checkAnswer(df.select("c1").where(s"c2 <=> $lit"), Seq.empty)
+      checkAnswer(df.select("c1").where(s"c2 != $lit"), Row(1) :: Row(2) :: Nil)
+      checkAnswer(df.select("c1").where(s"c2 <= $lit"), Row(1) :: Row(2) :: Nil)
+      checkAnswer(df.select("c1").where(s"c2 < $lit"), Row(1) :: Row(2) :: Nil)
+    }
+  }
+
+  test("cases when literal is min") {
+    val t = "test_table"
+    withTable(t) {
+      Seq[(Integer, java.lang.Short, java.lang.Float)](
+        (1, 100.toShort, 3.14.toFloat), (2, Short.MinValue, Float.NegativeInfinity),
+        (3, null, null))
+        .toDF("c1", "c2", "c3").write.saveAsTable(t)
+      val df = spark.table(t)
+
+      var lit = Short.MinValue.toInt
+      checkAnswer(df.select("c1").where(s"c2 > $lit"), Row(1))

Review comment:
   can we put the `where` before `select`? It's weird that we select only `c1` and then filter on `c2`.
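
   A minimal sketch of the suggested reordering (hypothetical, reusing the `df` and `lit` from the quoted test):

```scala
// Filter on c2 first, then project c1; `where` takes the same SQL predicate string.
checkAnswer(df.where(s"c2 > $lit").select("c1"), Seq.empty)
checkAnswer(df.where(s"c2 >= $lit").select("c1"), Row(2))
```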








[GitHub] [spark] cloud-fan commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

2020-10-12 Thread GitBox


cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r503134929



##
File path: sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
##
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.Decimal
+
+class UnwrapCastInComparisonEndToEndSuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  test("cases when literal is max") {
+    val t = "test_table"
+    withTable(t) {
+      Seq[(Integer, java.lang.Short, java.lang.Float)](
+        (1, 100.toShort, 3.14.toFloat), (2, Short.MaxValue, Float.NaN), (3, null, null))
+        .toDF("c1", "c2", "c3").write.saveAsTable(t)
+      val df = spark.table(t)
+
+      var lit = Short.MaxValue.toInt
+      checkAnswer(df.select("c1").where(s"c2 > $lit"), Seq.empty)
+      checkAnswer(df.select("c1").where(s"c2 >= $lit"), Row(2))
+      checkAnswer(df.select("c1").where(s"c2 == $lit"), Row(2))
+      checkAnswer(df.select("c1").where(s"c2 <=> $lit"), Row(2))
+      checkAnswer(df.select("c1").where(s"c2 != $lit"), Row(1))
+      checkAnswer(df.select("c1").where(s"c2 <= $lit"), Row(1) :: Row(2) :: Nil)
+      checkAnswer(df.select("c1").where(s"c2 < $lit"), Row(1))
+
+      checkAnswer(df.select("c1").where(s"c3 > double('nan')"), Seq.empty)
+      checkAnswer(df.select("c1").where(s"c3 >= double('nan')"), Row(2))
+      checkAnswer(df.select("c1").where(s"c3 == double('nan')"), Row(2))
+      checkAnswer(df.select("c1").where(s"c3 <=> double('nan')"), Row(2))
+      checkAnswer(df.select("c1").where(s"c3 != double('nan')"), Row(1))
+      checkAnswer(df.select("c1").where(s"c3 <= double('nan')"), Row(1) :: Row(2) :: Nil)
+      checkAnswer(df.select("c1").where(s"c3 < double('nan')"), Row(1))
+
+      lit = positiveInt
+      checkAnswer(df.select("c1").where(s"c2 > $lit"), Seq.empty)
+      checkAnswer(df.select("c1").where(s"c2 >= $lit"), Seq.empty)
+      checkAnswer(df.select("c1").where(s"c2 == $lit"), Seq.empty)
+      checkAnswer(df.select("c1").where(s"c2 <=> $lit"), Seq.empty)
+      checkAnswer(df.select("c1").where(s"c2 != $lit"), Row(1) :: Row(2) :: Nil)
+      checkAnswer(df.select("c1").where(s"c2 <= $lit"), Row(1) :: Row(2) :: Nil)
+      checkAnswer(df.select("c1").where(s"c2 < $lit"), Row(1) :: Row(2) :: Nil)
+    }
+  }
+
+  test("cases when literal is min") {
+    val t = "test_table"
+    withTable(t) {
+      Seq[(Integer, java.lang.Short, java.lang.Float)](
+        (1, 100.toShort, 3.14.toFloat), (2, Short.MinValue, Float.NegativeInfinity),
+        (3, null, null))
+        .toDF("c1", "c2", "c3").write.saveAsTable(t)
+      val df = spark.table(t)
+
+      var lit = Short.MinValue.toInt
+      checkAnswer(df.select("c1").where(s"c2 > $lit"), Row(1))
+      checkAnswer(df.select("c1").where(s"c2 >= $lit"), Row(1) :: Row(2) :: Nil)
+      checkAnswer(df.select("c1").where(s"c2 == $lit"), Row(2))
+      checkAnswer(df.select("c1").where(s"c2 <=> $lit"), Row(2))
+      checkAnswer(df.select("c1").where(s"c2 != $lit"), Row(1))
+      checkAnswer(df.select("c1").where(s"c2 <= $lit"), Row(2))
+      checkAnswer(df.select("c1").where(s"c2 < $lit"), Seq.empty)
+
+      checkAnswer(df.select("c1").where(s"c3 > double('-inf')"), Row(1))
+      checkAnswer(df.select("c1").where(s"c3 >= double('-inf')"), Row(1) :: Row(2) :: Nil)
+      checkAnswer(df.select("c1").where(s"c3 == double('-inf')"), Row(2))
+      checkAnswer(df.select("c1").where(s"c3 <=> double('-inf')"), Row(2))
+      checkAnswer(df.select("c1").where(s"c3 != double('-inf')"), Row(1))
+      checkAnswer(df.select("c1").where(s"c3 <= double('-inf')"), Row(2))
+      checkAnswer(df.select("c1").where(s"c3 < double('-inf')"), Seq.empty)
+
+      lit = negativeInt
+      checkAnswer(df.select("c1").where(s"c2 > $lit"), Row(1) :: Row(2) :: Nil)
+      checkAnswer(df.select("c1").where(s"c2 >= $lit"), Row(1) :: Row(2) :: Nil)
+  

[GitHub] [spark] cloud-fan commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

2020-10-12 Thread GitBox


cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r503133967



##
File path: sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
##
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.Decimal
+
+class UnwrapCastInComparisonEndToEndSuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  test("cases when literal is max") {
+    val t = "test_table"
+    withTable(t) {
+      Seq[(Integer, java.lang.Short, java.lang.Float)](
+        (1, 100.toShort, 3.14.toFloat), (2, Short.MaxValue, Float.NaN), (3, null, null))
+        .toDF("c1", "c2", "c3").write.saveAsTable(t)
+      val df = spark.table(t)
+
+      var lit = Short.MaxValue.toInt
+      checkAnswer(df.select("c1").where(s"c2 > $lit"), Seq.empty)
+      checkAnswer(df.select("c1").where(s"c2 >= $lit"), Row(2))
+      checkAnswer(df.select("c1").where(s"c2 == $lit"), Row(2))
+      checkAnswer(df.select("c1").where(s"c2 <=> $lit"), Row(2))
+      checkAnswer(df.select("c1").where(s"c2 != $lit"), Row(1))
+      checkAnswer(df.select("c1").where(s"c2 <= $lit"), Row(1) :: Row(2) :: Nil)
+      checkAnswer(df.select("c1").where(s"c2 < $lit"), Row(1))
+
+      checkAnswer(df.select("c1").where(s"c3 > double('nan')"), Seq.empty)
+      checkAnswer(df.select("c1").where(s"c3 >= double('nan')"), Row(2))
+      checkAnswer(df.select("c1").where(s"c3 == double('nan')"), Row(2))
+      checkAnswer(df.select("c1").where(s"c3 <=> double('nan')"), Row(2))
+      checkAnswer(df.select("c1").where(s"c3 != double('nan')"), Row(1))
+      checkAnswer(df.select("c1").where(s"c3 <= double('nan')"), Row(1) :: Row(2) :: Nil)
+      checkAnswer(df.select("c1").where(s"c3 < double('nan')"), Row(1))
+
+      lit = positiveInt

Review comment:
   this doesn't match the test case name `cases when literal is max`.

   We can put it in a new test, `cases when literal exceeds max`.
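
   A sketch of the suggested split (assuming the same table fixture, and that `positiveInt` from `IntegralLiteralTestUtils` is an `Int` literal above `Short.MaxValue`, which is what the expected answers imply):

```scala
  test("cases when literal exceeds max") {
    val t = "test_table"
    withTable(t) {
      Seq[(Integer, java.lang.Short)]((1, 100.toShort), (2, Short.MaxValue), (3, null))
        .toDF("c1", "c2").write.saveAsTable(t)
      val df = spark.table(t)

      // The literal cannot fit in a short, so the optimizer can fold each
      // comparison to a constant (or a null check) instead of casting c2 per row.
      val lit = positiveInt
      checkAnswer(df.where(s"c2 > $lit").select("c1"), Seq.empty)
      checkAnswer(df.where(s"c2 == $lit").select("c1"), Seq.empty)
      checkAnswer(df.where(s"c2 != $lit").select("c1"), Row(1) :: Row(2) :: Nil)
    }
  }
```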

##
File path: sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
##
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.Decimal
+
+class UnwrapCastInComparisonEndToEndSuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  test("cases when literal is max") {
+    val t = "test_table"
+    withTable(t) {
+      Seq[(Integer, java.lang.Short, java.lang.Float)](
+        (1, 100.toShort, 3.14.toFloat), (2, Short.MaxValue, Float.NaN), (3, null, null))
+        .toDF("c1", "c2", "c3").write.saveAsTable(t)
+      val df = spark.table(t)
+
+      var lit = Short.MaxValue.toInt
+      checkAnswer(df.select("c1").where(s"c2 > $lit"), Seq.empty)
+      checkAnswer(df.select("c1").where(s"c2 >= $lit"), Row(2))
+      checkAnswer(df.select("c1").where(s"c2 == $lit"), Row(2))
+      checkAnswer(df.select("c1").where(s"c2 <=> $lit"), Row(2))
+      checkAnswer(df.select("c1").where(s"c2 != $lit"), Row(1))
+  

[GitHub] [spark] cloud-fan commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

2020-10-12 Thread GitBox


cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r503132659



##
File path: sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
##
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.Decimal
+
+class UnwrapCastInComparisonEndToEndSuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  test("cases when literal is max") {
+    val t = "test_table"
+    withTable(t) {
+      Seq[(Integer, java.lang.Short, java.lang.Float)](
+        (1, 100.toShort, 3.14.toFloat), (2, Short.MaxValue, Float.NaN), (3, null, null))

Review comment:
   can we test `Float.PositiveInfinity` as well?
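
   A hedged sketch of what that could look like (the extra row and its expected answers are hypothetical):

```scala
Seq[(Integer, java.lang.Short, java.lang.Float)](
  (1, 100.toShort, 3.14.toFloat), (2, Short.MaxValue, Float.NaN),
  (3, null, null), (4, 200.toShort, Float.PositiveInfinity))
  .toDF("c1", "c2", "c3").write.saveAsTable(t)

// Since NaN sorts above PositiveInfinity in Spark's ordering:
//   c3 > double('inf')    -- should return only the NaN row (c1 = 2)
//   c3 <=> double('inf')  -- should return only the new row (c1 = 4)
```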








[GitHub] [spark] cloud-fan commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

2020-10-12 Thread GitBox


cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r503130781



##
File path: sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
##
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.Decimal
+
+class UnwrapCastInComparisonEndToEndSuite extends QueryTest with SharedSparkSession {

Review comment:
   I think it's fine to follow ReplaceNullWithFalseInPredicateEndToEndSuite here.








[GitHub] [spark] cloud-fan commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

2020-10-06 Thread GitBox


cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r500247652



##
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
##
@@ -116,82 +128,118 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
    * optimizes the expression by moving the cast to the literal side. Otherwise if result is not
    * true, this replaces the input binary comparison `exp` with simpler expressions.
    */
-  private def simplifyIntegralComparison(
+  private def simplifyNumericComparison(
       exp: BinaryComparison,
       fromExp: Expression,
-      toType: IntegralType,
+      toType: NumericType,
       value: Any): Expression = {
 
     val fromType = fromExp.dataType
-    val (min, max) = getRange(fromType)
-    val (minInToType, maxInToType) = {
-      (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
-    }
     val ordering = toType.ordering.asInstanceOf[Ordering[Any]]
-    val minCmp = ordering.compare(value, minInToType)
-    val maxCmp = ordering.compare(value, maxInToType)
+    val range = getRange(fromType)
 
-    if (maxCmp > 0) {
-      exp match {
-        case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThan(_, _) | LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        // make sure the expression is evaluated if it is non-deterministic
-        case EqualNullSafe(_, _) if exp.deterministic =>
-          FalseLiteral
-        case _ => exp
+    if (range.isDefined) {
+      val (min, max) = range.get
+      val (minInToType, maxInToType) = {
+        (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
       }
-    } else if (maxCmp == 0) {
-      exp match {
-        case GreaterThan(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        case LessThan(_, _) =>
-          Not(EqualTo(fromExp, Literal(max, fromType)))
-        case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
-          EqualTo(fromExp, Literal(max, fromType))
-        case EqualNullSafe(_, _) =>
-          EqualNullSafe(fromExp, Literal(max, fromType))
-        case _ => exp
+      val minCmp = ordering.compare(value, minInToType)
+      val maxCmp = ordering.compare(value, maxInToType)
+
+      if (maxCmp >= 0 || minCmp <= 0) {
+        return if (maxCmp > 0) {
+          exp match {
+            case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else if (maxCmp == 0) {
+          exp match {
+            case GreaterThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(max, fromType)))
+            case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(max, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(max, fromType))
+            case _ => exp
+          }
+        } else if (minCmp < 0) {
+          exp match {
+            case GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              falseIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else { // minCmp == 0
+          exp match {
+            case LessThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case GreaterThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(min, fromType)))
+            case LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(min, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(min, fromType))
+            case _ => exp
+          }
+        }
       }
-    } else if (minCmp < 0) {
+    }
+
+    // When we reach to this point, it means either there is no min/max for the `fromType` (e.g.,
+    // decimal type), or that the literal `value` is within range `(min, max)`. For these, we
+    // optimize by moving the cast to the literal side.
+
+    val newValue = Cast(Literal(value), fromType).eval()
+    if

[GitHub] [spark] cloud-fan commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

2020-10-06 Thread GitBox


cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r500137824



##
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
##
@@ -116,82 +132,118 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
    * optimizes the expression by moving the cast to the literal side. Otherwise if result is not
    * true, this replaces the input binary comparison `exp` with simpler expressions.
    */
-  private def simplifyIntegralComparison(
+  private def simplifyNumericComparison(
      exp: BinaryComparison,
      fromExp: Expression,
-      toType: IntegralType,
+      toType: NumericType,
      value: Any): Expression = {
 
     val fromType = fromExp.dataType
-    val (min, max) = getRange(fromType)
-    val (minInToType, maxInToType) = {
-      (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
-    }
     val ordering = toType.ordering.asInstanceOf[Ordering[Any]]
-    val minCmp = ordering.compare(value, minInToType)
-    val maxCmp = ordering.compare(value, maxInToType)
+    val range = getRange(fromType)
 
-    if (maxCmp > 0) {
-      exp match {
-        case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThan(_, _) | LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        // make sure the expression is evaluated if it is non-deterministic
-        case EqualNullSafe(_, _) if exp.deterministic =>
-          FalseLiteral
-        case _ => exp
+    if (range.isDefined) {
+      val (min, max) = range.get
+      val (minInToType, maxInToType) = {
+        (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
      }
-    } else if (maxCmp == 0) {
-      exp match {
-        case GreaterThan(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        case LessThan(_, _) =>
-          Not(EqualTo(fromExp, Literal(max, fromType)))
-        case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
-          EqualTo(fromExp, Literal(max, fromType))
-        case EqualNullSafe(_, _) =>
-          EqualNullSafe(fromExp, Literal(max, fromType))
-        case _ => exp
+      val minCmp = ordering.compare(value, minInToType)
+      val maxCmp = ordering.compare(value, maxInToType)
+
+      if (maxCmp >= 0 || minCmp <= 0) {
+        return if (maxCmp > 0) {
+          exp match {
+            case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else if (maxCmp == 0) {
+          exp match {
+            case GreaterThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(max, fromType)))
+            case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(max, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(max, fromType))
+            case _ => exp
+          }
+        } else if (minCmp < 0) {
+          exp match {
+            case GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              falseIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else { // minCmp == 0
+          exp match {
+            case LessThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case GreaterThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(min, fromType)))
+            case LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(min, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(min, fromType))
+            case _ => exp
+          }
+        }
      }
-    } else if (minCmp < 0) {
+    }
+
+    // When we reach to this point, it means either there is no min/max for the `fromType` (e.g.,
+    // decimal type), or that the literal `value` is within range `(min, max)`. For these, we

Review comment:
   makes sense.






[GitHub] [spark] cloud-fan commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

2020-10-06 Thread GitBox


cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r500243158



##
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
##
@@ -116,82 +132,118 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
    * optimizes the expression by moving the cast to the literal side. Otherwise if result is not
    * true, this replaces the input binary comparison `exp` with simpler expressions.
    */
-  private def simplifyIntegralComparison(
+  private def simplifyNumericComparison(
      exp: BinaryComparison,
      fromExp: Expression,
-      toType: IntegralType,
+      toType: NumericType,
      value: Any): Expression = {
 
     val fromType = fromExp.dataType
-    val (min, max) = getRange(fromType)
-    val (minInToType, maxInToType) = {
-      (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
-    }
     val ordering = toType.ordering.asInstanceOf[Ordering[Any]]
-    val minCmp = ordering.compare(value, minInToType)
-    val maxCmp = ordering.compare(value, maxInToType)
+    val range = getRange(fromType)
 
-    if (maxCmp > 0) {
-      exp match {
-        case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThan(_, _) | LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        // make sure the expression is evaluated if it is non-deterministic
-        case EqualNullSafe(_, _) if exp.deterministic =>
-          FalseLiteral
-        case _ => exp
+    if (range.isDefined) {
+      val (min, max) = range.get
+      val (minInToType, maxInToType) = {
+        (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
      }
-    } else if (maxCmp == 0) {
-      exp match {
-        case GreaterThan(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        case LessThan(_, _) =>
-          Not(EqualTo(fromExp, Literal(max, fromType)))
-        case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
-          EqualTo(fromExp, Literal(max, fromType))
-        case EqualNullSafe(_, _) =>
-          EqualNullSafe(fromExp, Literal(max, fromType))
-        case _ => exp
+      val minCmp = ordering.compare(value, minInToType)
+      val maxCmp = ordering.compare(value, maxInToType)
+
+      if (maxCmp >= 0 || minCmp <= 0) {
+        return if (maxCmp > 0) {
+          exp match {
+            case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else if (maxCmp == 0) {
+          exp match {
+            case GreaterThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(max, fromType)))
+            case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(max, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(max, fromType))
+            case _ => exp
+          }
+        } else if (minCmp < 0) {
+          exp match {
+            case GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              falseIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else { // minCmp == 0
+          exp match {
+            case LessThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case GreaterThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(min, fromType)))
+            case LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(min, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(min, fromType))
+            case _ => exp
+          }
+        }
      }
-    } else if (minCmp < 0) {
+    }
+
+    // When we reach to this point, it means either there is no min/max for the `fromType` (e.g.,
+    // decimal type), or that the literal `value` is within range `(min, max)`. For these, we

Review comment:
   is it only useful to `NaN`?






[GitHub] [spark] cloud-fan commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

2020-10-06 Thread GitBox


cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r500138374



##
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
##
@@ -116,82 +128,118 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
    * optimizes the expression by moving the cast to the literal side. Otherwise if result is not
    * true, this replaces the input binary comparison `exp` with simpler expressions.
    */
-  private def simplifyIntegralComparison(
+  private def simplifyNumericComparison(
      exp: BinaryComparison,
      fromExp: Expression,
-      toType: IntegralType,
+      toType: NumericType,
      value: Any): Expression = {
 
     val fromType = fromExp.dataType
-    val (min, max) = getRange(fromType)
-    val (minInToType, maxInToType) = {
-      (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
-    }
     val ordering = toType.ordering.asInstanceOf[Ordering[Any]]
-    val minCmp = ordering.compare(value, minInToType)
-    val maxCmp = ordering.compare(value, maxInToType)
+    val range = getRange(fromType)
 
-    if (maxCmp > 0) {
-      exp match {
-        case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThan(_, _) | LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        // make sure the expression is evaluated if it is non-deterministic
-        case EqualNullSafe(_, _) if exp.deterministic =>
-          FalseLiteral
-        case _ => exp
+    if (range.isDefined) {
+      val (min, max) = range.get
+      val (minInToType, maxInToType) = {
+        (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
      }
-    } else if (maxCmp == 0) {
-      exp match {
-        case GreaterThan(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        case LessThan(_, _) =>
-          Not(EqualTo(fromExp, Literal(max, fromType)))
-        case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
-          EqualTo(fromExp, Literal(max, fromType))
-        case EqualNullSafe(_, _) =>
-          EqualNullSafe(fromExp, Literal(max, fromType))
-        case _ => exp
+      val minCmp = ordering.compare(value, minInToType)
+      val maxCmp = ordering.compare(value, maxInToType)
+
+      if (maxCmp >= 0 || minCmp <= 0) {
+        return if (maxCmp > 0) {
+          exp match {
+            case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else if (maxCmp == 0) {
+          exp match {
+            case GreaterThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(max, fromType)))
+            case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(max, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(max, fromType))
+            case _ => exp
+          }
+        } else if (minCmp < 0) {
+          exp match {
+            case GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              falseIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else { // minCmp == 0
+          exp match {
+            case LessThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case GreaterThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(min, fromType)))
+            case LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(min, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(min, fromType))
+            case _ => exp
+          }
+        }
      }
-    } else if (minCmp < 0) {
+    }
+
+    // When we reach to this point, it means either there is no min/max for the `fromType` (e.g.,
+    // decimal type), or that the literal `value` is within range `(min, max)`. For these, we
+    // optimize by moving the cast to the literal side.
+
+    val newValue = Cast(Literal(value), fromType).eval()
+    if

[GitHub] [spark] cloud-fan commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

2020-10-06 Thread GitBox


cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r500137824



##
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
##
@@ -116,82 +132,118 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
    * optimizes the expression by moving the cast to the literal side. Otherwise if result is not
    * true, this replaces the input binary comparison `exp` with simpler expressions.
    */
-  private def simplifyIntegralComparison(
+  private def simplifyNumericComparison(
      exp: BinaryComparison,
      fromExp: Expression,
-      toType: IntegralType,
+      toType: NumericType,
      value: Any): Expression = {
 
     val fromType = fromExp.dataType
-    val (min, max) = getRange(fromType)
-    val (minInToType, maxInToType) = {
-      (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
-    }
     val ordering = toType.ordering.asInstanceOf[Ordering[Any]]
-    val minCmp = ordering.compare(value, minInToType)
-    val maxCmp = ordering.compare(value, maxInToType)
+    val range = getRange(fromType)
 
-    if (maxCmp > 0) {
-      exp match {
-        case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThan(_, _) | LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        // make sure the expression is evaluated if it is non-deterministic
-        case EqualNullSafe(_, _) if exp.deterministic =>
-          FalseLiteral
-        case _ => exp
+    if (range.isDefined) {
+      val (min, max) = range.get
+      val (minInToType, maxInToType) = {
+        (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
      }
-    } else if (maxCmp == 0) {
-      exp match {
-        case GreaterThan(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        case LessThan(_, _) =>
-          Not(EqualTo(fromExp, Literal(max, fromType)))
-        case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
-          EqualTo(fromExp, Literal(max, fromType))
-        case EqualNullSafe(_, _) =>
-          EqualNullSafe(fromExp, Literal(max, fromType))
-        case _ => exp
+      val minCmp = ordering.compare(value, minInToType)
+      val maxCmp = ordering.compare(value, maxInToType)
+
+      if (maxCmp >= 0 || minCmp <= 0) {
+        return if (maxCmp > 0) {
+          exp match {
+            case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else if (maxCmp == 0) {
+          exp match {
+            case GreaterThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(max, fromType)))
+            case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(max, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(max, fromType))
+            case _ => exp
+          }
+        } else if (minCmp < 0) {
+          exp match {
+            case GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              falseIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else { // minCmp == 0
+          exp match {
+            case LessThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case GreaterThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(min, fromType)))
+            case LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(min, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(min, fromType))
+            case _ => exp
+          }
+        }
      }
-    } else if (minCmp < 0) {
+    }
+
+    // When we reach to this point, it means either there is no min/max for the `fromType` (e.g.,
+    // decimal type), or that the literal `value` is within range `(min, max)`. For these, we

Review comment:
   makes sense. Can you give an example of how the range of float/double helps the optimization?
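
   One illustration, taken from the end-to-end tests in this PR (a sketch; it relies on Spark's SQL ordering, where NaN sorts above every other float/double value):

```scala
// Float's range is modeled as (Float.NegativeInfinity, Float.NaN), so for
//   cast(c3 as double) > double('nan')
// the literal equals the casted maximum (maxCmp == 0) and the cast can be folded away:
//   c3 > double('nan')   ==>  if(isnull(c3), null, false)  // nothing exceeds NaN
//   c3 >= double('nan')  ==>  c3 = float('nan')            // only the NaN row matches
```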





[GitHub] [spark] cloud-fan commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

2020-09-29 Thread GitBox


cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r496547966



##
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
##
@@ -116,82 +132,118 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
    * optimizes the expression by moving the cast to the literal side. Otherwise if result is not
    * true, this replaces the input binary comparison `exp` with simpler expressions.
    */
-  private def simplifyIntegralComparison(
+  private def simplifyNumericComparison(
      exp: BinaryComparison,
      fromExp: Expression,
-      toType: IntegralType,
+      toType: NumericType,
      value: Any): Expression = {
 
     val fromType = fromExp.dataType
-    val (min, max) = getRange(fromType)
-    val (minInToType, maxInToType) = {
-      (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
-    }
     val ordering = toType.ordering.asInstanceOf[Ordering[Any]]
-    val minCmp = ordering.compare(value, minInToType)
-    val maxCmp = ordering.compare(value, maxInToType)
+    val range = getRange(fromType)
 
-    if (maxCmp > 0) {
-      exp match {
-        case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThan(_, _) | LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        // make sure the expression is evaluated if it is non-deterministic
-        case EqualNullSafe(_, _) if exp.deterministic =>
-          FalseLiteral
-        case _ => exp
+    if (range.isDefined) {
+      val (min, max) = range.get
+      val (minInToType, maxInToType) = {
+        (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
      }
-    } else if (maxCmp == 0) {
-      exp match {
-        case GreaterThan(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        case LessThan(_, _) =>
-          Not(EqualTo(fromExp, Literal(max, fromType)))
-        case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
-          EqualTo(fromExp, Literal(max, fromType))
-        case EqualNullSafe(_, _) =>
-          EqualNullSafe(fromExp, Literal(max, fromType))
-        case _ => exp
+      val minCmp = ordering.compare(value, minInToType)
+      val maxCmp = ordering.compare(value, maxInToType)
+
+      if (maxCmp >= 0 || minCmp <= 0) {
+        return if (maxCmp > 0) {
+          exp match {
+            case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else if (maxCmp == 0) {
+          exp match {
+            case GreaterThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(max, fromType)))
+            case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(max, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(max, fromType))
+            case _ => exp
+          }
+        } else if (minCmp < 0) {
+          exp match {
+            case GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              falseIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else { // minCmp == 0
+          exp match {
+            case LessThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case GreaterThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(min, fromType)))
+            case LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(min, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(min, fromType))
+            case _ => exp
+          }
+        }
      }
-    } else if (minCmp < 0) {
+    }
+
+    // When we reach to this point, it means either there is no min/max for the `fromType` (e.g.,
+    // decimal type), or that the literal `value` is within range `(min, max)`. For these, we
+    // optimize by moving the cast to the literal side.
+
+    val newValue = Cast(Literal(value), fromType).eval()
+    if

[GitHub] [spark] cloud-fan commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

2020-09-29 Thread GitBox


cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r496536630



##
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
##
@@ -116,82 +132,118 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
    * optimizes the expression by moving the cast to the literal side. Otherwise if result is not
    * true, this replaces the input binary comparison `exp` with simpler expressions.
    */
-  private def simplifyIntegralComparison(
+  private def simplifyNumericComparison(
      exp: BinaryComparison,
      fromExp: Expression,
-      toType: IntegralType,
+      toType: NumericType,
      value: Any): Expression = {
 
     val fromType = fromExp.dataType
-    val (min, max) = getRange(fromType)
-    val (minInToType, maxInToType) = {
-      (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
-    }
     val ordering = toType.ordering.asInstanceOf[Ordering[Any]]
-    val minCmp = ordering.compare(value, minInToType)
-    val maxCmp = ordering.compare(value, maxInToType)
+    val range = getRange(fromType)
 
-    if (maxCmp > 0) {
-      exp match {
-        case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThan(_, _) | LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        // make sure the expression is evaluated if it is non-deterministic
-        case EqualNullSafe(_, _) if exp.deterministic =>
-          FalseLiteral
-        case _ => exp
+    if (range.isDefined) {
+      val (min, max) = range.get
+      val (minInToType, maxInToType) = {
+        (Cast(Literal(min), toType).eval(), Cast(Literal(max), toType).eval())
      }
-    } else if (maxCmp == 0) {
-      exp match {
-        case GreaterThan(_, _) =>
-          falseIfNotNull(fromExp)
-        case LessThanOrEqual(_, _) =>
-          trueIfNotNull(fromExp)
-        case LessThan(_, _) =>
-          Not(EqualTo(fromExp, Literal(max, fromType)))
-        case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
-          EqualTo(fromExp, Literal(max, fromType))
-        case EqualNullSafe(_, _) =>
-          EqualNullSafe(fromExp, Literal(max, fromType))
-        case _ => exp
+      val minCmp = ordering.compare(value, minInToType)
+      val maxCmp = ordering.compare(value, maxInToType)
+
+      if (maxCmp >= 0 || minCmp <= 0) {
+        return if (maxCmp > 0) {
+          exp match {
+            case EqualTo(_, _) | GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else if (maxCmp == 0) {
+          exp match {
+            case GreaterThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case LessThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(max, fromType)))
+            case GreaterThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(max, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(max, fromType))
+            case _ => exp
+          }
+        } else if (minCmp < 0) {
+          exp match {
+            case GreaterThan(_, _) | GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case LessThan(_, _) | LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              falseIfNotNull(fromExp)
+            // make sure the expression is evaluated if it is non-deterministic
+            case EqualNullSafe(_, _) if exp.deterministic =>
+              FalseLiteral
+            case _ => exp
+          }
+        } else { // minCmp == 0
+          exp match {
+            case LessThan(_, _) =>
+              falseIfNotNull(fromExp)
+            case GreaterThanOrEqual(_, _) =>
+              trueIfNotNull(fromExp)
+            case GreaterThan(_, _) =>
+              Not(EqualTo(fromExp, Literal(min, fromType)))
+            case LessThanOrEqual(_, _) | EqualTo(_, _) =>
+              EqualTo(fromExp, Literal(min, fromType))
+            case EqualNullSafe(_, _) =>
+              EqualNullSafe(fromExp, Literal(min, fromType))
+            case _ => exp
+          }
+        }
      }
-    } else if (minCmp < 0) {
+    }
+
+    // When we reach to this point, it means either there is no min/max for the `fromType` (e.g.,
+    // decimal type), or that the literal `value` is within range `(min, max)`. For these, we

Review comment:
   why is it safe to skip the range check for decimal type?
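
   For context, the quoted `getRange` returns an `Option` precisely so decimal can opt out of the range check; a sketch of the presumable default branch (the quoted hunk is cut off before it):

```scala
private def getRange(dt: DataType): Option[(Any, Any)] = dt match {
  // ... integral and float/double cases quoted above ...
  case _ => None // e.g. DecimalType: no usable (min, max), so the caller skips the range check
}
```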





[GitHub] [spark] cloud-fan commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

2020-09-29 Thread GitBox


cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r496529759



##
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
##
@@ -200,25 +252,27 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
   /**
    * Check if the input `fromExp` can be safely cast to `toType` without any loss of precision,
    * i.e., the conversion is injective. Note this only handles the case when both sides are of
-   * integral type.
+   * numeric type.
    */
   private def canImplicitlyCast(
      fromExp: Expression,
      toType: DataType,
      literalType: DataType): Boolean = {
     toType.sameType(literalType) &&
      !fromExp.foldable &&
-      fromExp.dataType.isInstanceOf[IntegralType] &&
-      toType.isInstanceOf[IntegralType] &&
+      fromExp.dataType.isInstanceOf[NumericType] &&
+      toType.isInstanceOf[NumericType] &&
      Cast.canUpCast(fromExp.dataType, toType)
   }
 
-  private def getRange(dt: DataType): (Any, Any) = dt match {
-    case ByteType => (Byte.MinValue, Byte.MaxValue)
-    case ShortType => (Short.MinValue, Short.MaxValue)
-    case IntegerType => (Int.MinValue, Int.MaxValue)
-    case LongType => (Long.MinValue, Long.MaxValue)
-    case other => throw new IllegalArgumentException(s"Unsupported type: ${other.catalogString}")
+  private def getRange(dt: DataType): Option[(Any, Any)] = dt match {
+    case ByteType => Some((Byte.MinValue, Byte.MaxValue))
+    case ShortType => Some((Short.MinValue, Short.MaxValue))
+    case IntegerType => Some((Int.MinValue, Int.MaxValue))
+    case LongType => Some((Long.MinValue, Long.MaxValue))
+    case FloatType => Some((Float.NegativeInfinity, Float.NaN))

Review comment:
   why the upper bound is not `PositiveInfinity`?
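
   For reference, a plain-Scala check of the ordering this relies on (java.lang.Float's total order, where NaN sorts above PositiveInfinity):

```scala
java.lang.Float.compare(Float.NaN, Float.PositiveInfinity) // > 0: NaN is the ordering maximum
Float.NaN > Float.PositiveInfinity                         // false under plain IEEE comparison
```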








[GitHub] [spark] cloud-fan commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

2020-09-29 Thread GitBox


cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r496523301



##
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
##
@@ -35,18 +35,34 @@ import org.apache.spark.sql.types._
  * to be optimized away later and pushed down to data sources.
  *
  * Currently this only handles cases where:
- *   1). `fromType` (of `fromExp`) and `toType` are of integral types (i.e., byte, short, int and
- *     long)
+ *   1). `fromType` (of `fromExp`) and `toType` are of numeric types (i.e., short, int, float,
+ *     decimal, etc)
  *   2). `fromType` can be safely coerced to `toType` without precision loss (e.g., short to int,
  *     int to long, but not long to int)
  *
  * If the above conditions are satisfied, the rule checks to see if the literal `value` is within
  * range `(min, max)`, where `min` and `max` are the minimum and maximum value of `fromType`,
- * respectively. If this is true then it means we can safely cast `value` to `fromType` and thus
+ * respectively. If this is true then it means we may safely cast `value` to `fromType` and thus
  * able to move the cast to the literal side. That is:
 *
 *   `cast(fromExp, toType) op value` ==> `fromExp op cast(value, fromType)`
 *
+ * Note there are some exceptions to the above: if casting from `value` to `fromType` causes
+ * rounding up or down, the above conversion will no longer be valid. Instead, the rule does the
+ * following:
+ *
+ * if casting `value` to `fromType` causes rounding up:
+ *  - `cast(fromExp, toType) > value` ==> `fromExp >= cast(value, fromType)`
+ *  - `cast(fromExp, toType) >= value` ==> `fromExp >= cast(value, fromType)`
+ *  - `cast(fromExp, toType) === value` ==> if(isnull(fromExp), null, false)
+ *  - `cast(fromExp, toType) <=> value` ==> false (if `fromExp` is deterministic)
+ *  - `cast(fromExp, toType) <=> value` ==> `cast(fromExp, toType) <=> value` (if `fromExp` is

Review comment:
   We can remove this because the rule does nothing for it.








[GitHub] [spark] cloud-fan commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

2020-09-25 Thread GitBox


cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r495008728



##
File path: sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
##
@@ -79,13 +106,65 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
     assertEquivalent(castInt(f) < v, falseIfNotNull(f))
   }
 
-  test("unwrap casts when literal is within range (min, max)") {
-    assertEquivalent(castInt(f) > 300, f > 300.toShort)
-    assertEquivalent(castInt(f) >= 500, f >= 500.toShort)
-    assertEquivalent(castInt(f) === 32766, f === 32766.toShort)
-    assertEquivalent(castInt(f) <=> 32766, f <=> 32766.toShort)
-    assertEquivalent(castInt(f) <= -6000, f <= -6000.toShort)
-    assertEquivalent(castInt(f) < -32767, f < -32767.toShort)
+  test("unwrap casts when literal is within range (min, max) or fromType has no range") {
+    Seq(300, 500, 32766, -6000, -32767).foreach(v => {
+      assertEquivalent(castInt(f) > v, f > v.toShort)
+      assertEquivalent(castInt(f) >= v, f >= v.toShort)
+      assertEquivalent(castInt(f) === v, f === v.toShort)
+      assertEquivalent(castInt(f) <=> v, f <=> v.toShort)
+      assertEquivalent(castInt(f) <= v, f <= v.toShort)
+      assertEquivalent(castInt(f) < v, f < v.toShort)
+    })
+
+    Seq(3.14.toFloat.toDouble, -1000.0.toFloat.toDouble,
+      20.0.toFloat.toDouble, -2.414.toFloat.toDouble,
+      Float.MinValue.toDouble, Float.MaxValue.toDouble, Float.PositiveInfinity.toDouble
+    ).foreach(v => {
+      assertEquivalent(castDouble(f2) > v, f2 > v.toFloat)
+      assertEquivalent(castDouble(f2) >= v, f2 >= v.toFloat)
+      assertEquivalent(castDouble(f2) === v, f2 === v.toFloat)
+      assertEquivalent(castDouble(f2) <=> v, f2 <=> v.toFloat)
+      assertEquivalent(castDouble(f2) <= v, f2 <= v.toFloat)
+      assertEquivalent(castDouble(f2) < v, f2 < v.toFloat)
+    })
+
+    Seq(decimal2(100.20), decimal2(-200.50)).foreach(v => {
+      assertEquivalent(castDecimal2(f3) > v, f3 > decimal(v))
+      assertEquivalent(castDecimal2(f3) >= v, f3 >= decimal(v))
+      assertEquivalent(castDecimal2(f3) === v, f3 === decimal(v))
+      assertEquivalent(castDecimal2(f3) <=> v, f3 <=> decimal(v))
+      assertEquivalent(castDecimal2(f3) <= v, f3 <= decimal(v))
+      assertEquivalent(castDecimal2(f3) < v, f3 < decimal(v))
+    })
+  }
+
+  test("unwrap cast when literal is within range (min, max) AND has round up or down") {
+    // Cases for rounding down
+    var doubleValue = 100.6
+    assertEquivalent(castDouble(f) > doubleValue, f > doubleValue.toShort)
+    assertEquivalent(castDouble(f) > doubleValue, f > doubleValue.toShort)
+    assertEquivalent(castDouble(f) === doubleValue, falseIfNotNull(f))
+    assertEquivalent(castDouble(f) <=> doubleValue, false)
+    assertEquivalent(castDouble(f) <= doubleValue, f <= doubleValue.toShort)
+    assertEquivalent(castDouble(f) < doubleValue, f <= doubleValue.toShort)
+
+    // Cases for rounding up: 3.14 will be rounded to 3.1400001... after casting to float

Review comment:
   This is an important point. Can we explain in the PR description how we know whether it's rounding up or down?








[GitHub] [spark] cloud-fan commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

2020-09-23 Thread GitBox


cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r493731636



##
File path: sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
##
@@ -79,13 +106,65 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
     assertEquivalent(castInt(f) < v, falseIfNotNull(f))
   }
 
-  test("unwrap casts when literal is within range (min, max)") {
-    assertEquivalent(castInt(f) > 300, f > 300.toShort)
-    assertEquivalent(castInt(f) >= 500, f >= 500.toShort)
-    assertEquivalent(castInt(f) === 32766, f === 32766.toShort)
-    assertEquivalent(castInt(f) <=> 32766, f <=> 32766.toShort)
-    assertEquivalent(castInt(f) <= -6000, f <= -6000.toShort)
-    assertEquivalent(castInt(f) < -32767, f < -32767.toShort)
+  test("unwrap casts when literal is within range (min, max) or fromType has no range") {
+    Seq(300, 500, 32766, -6000, -32767).foreach(v => {
+      assertEquivalent(castInt(f) > v, f > v.toShort)
+      assertEquivalent(castInt(f) >= v, f >= v.toShort)
+      assertEquivalent(castInt(f) === v, f === v.toShort)
+      assertEquivalent(castInt(f) <=> v, f <=> v.toShort)
+      assertEquivalent(castInt(f) <= v, f <= v.toShort)
+      assertEquivalent(castInt(f) < v, f < v.toShort)
+    })
+
+    Seq(3.14.toFloat.toDouble, -1000.0.toFloat.toDouble,
+      20.0.toFloat.toDouble, -2.414.toFloat.toDouble,
+      Float.MinValue.toDouble, Float.MaxValue.toDouble, Float.PositiveInfinity.toDouble
+    ).foreach(v => {
+      assertEquivalent(castDouble(f2) > v, f2 > v.toFloat)
+      assertEquivalent(castDouble(f2) >= v, f2 >= v.toFloat)
+      assertEquivalent(castDouble(f2) === v, f2 === v.toFloat)
+      assertEquivalent(castDouble(f2) <=> v, f2 <=> v.toFloat)
+      assertEquivalent(castDouble(f2) <= v, f2 <= v.toFloat)
+      assertEquivalent(castDouble(f2) < v, f2 < v.toFloat)
+    })
+
+    Seq(decimal2(100.20), decimal2(-200.50)).foreach(v => {
+      assertEquivalent(castDecimal2(f3) > v, f3 > decimal(v))
+      assertEquivalent(castDecimal2(f3) >= v, f3 >= decimal(v))
+      assertEquivalent(castDecimal2(f3) === v, f3 === decimal(v))
+      assertEquivalent(castDecimal2(f3) <=> v, f3 <=> decimal(v))
+      assertEquivalent(castDecimal2(f3) <= v, f3 <= decimal(v))
+      assertEquivalent(castDecimal2(f3) < v, f3 < decimal(v))
+    })
+  }
+
+  test("unwrap cast when literal is within range (min, max) AND has round up or down") {
+    // Cases for rounding down
+    var doubleValue = 100.6
+    assertEquivalent(castDouble(f) > doubleValue, f > doubleValue.toShort)
+    assertEquivalent(castDouble(f) > doubleValue, f > doubleValue.toShort)
+    assertEquivalent(castDouble(f) === doubleValue, falseIfNotNull(f))
+    assertEquivalent(castDouble(f) <=> doubleValue, false)
+    assertEquivalent(castDouble(f) <= doubleValue, f <= doubleValue.toShort)
+    assertEquivalent(castDouble(f) < doubleValue, f <= doubleValue.toShort)
+
+    // Cases for rounding up: 3.14 will be rounded to 3.1400001... after casting to float

Review comment:
   so casting double to float can be either rounding up or rounding down, depending on the value?
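
   A quick plain-Scala check of the two literals used in this test:

```scala
100.6.toFloat.toDouble // 100.5999984... -> the nearest float is below, i.e. rounded down
3.14.toFloat.toDouble  // 3.1400001...   -> the nearest float is above, i.e. rounded up
```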








[GitHub] [spark] cloud-fan commented on a change in pull request #29792: [SPARK-32858][SQL] UnwrapCastInBinaryComparison: support other numeric types

2020-09-23 Thread GitBox


cloud-fan commented on a change in pull request #29792:
URL: https://github.com/apache/spark/pull/29792#discussion_r493216123



##
File path: sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
##
@@ -67,6 +77,23 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
     assertEquivalent(castInt(f) <=> v.toInt, f <=> v)
     assertEquivalent(castInt(f) <= v.toInt, f === v)
     assertEquivalent(castInt(f) < v.toInt, falseIfNotNull(f))
+
+    val d = Float.NegativeInfinity
+    assertEquivalent(castDouble(f2) > d.toDouble, f2 =!= d)

Review comment:
   is casting double to float rounding up or rounding down?




