Copilot commented on code in PR #2357:
URL: https://github.com/apache/sedona/pull/2357#discussion_r2358266196


##########
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/BarrierFunction.scala:
##########
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.expressions
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.{DataType, BooleanType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+import scala.util.parsing.combinator._
+
+/**
+ * Barrier function to prevent filter pushdown and control predicate 
evaluation order. Takes a
+ * boolean expression string followed by pairs of variable names and their 
values.
+ *
+ * Usage: barrier(expression, var_name1, var_value1, var_name2, var_value2, 
...) Example:
+ * barrier('rating > 4.0 AND stars >= 4', 'rating', r.rating, 'stars', h.stars)
+ *
+ * Extends CodegenFallback to prevent Catalyst optimizer from pushing this 
filter through joins.
+ * CodegenFallback makes this expression opaque to optimization rules, 
ensuring it evaluates at
+ * runtime in its original position within the query plan.
+ */
+private[apache] case class Barrier(inputExpressions: Seq[Expression])
+    extends Expression
+    with CodegenFallback {
+
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = BooleanType
+
+  override def children: Seq[Expression] = inputExpressions
+
+  override def eval(input: InternalRow): Any = {
+    // Get the expression string
+    val exprString = inputExpressions.head.eval(input) match {
+      case s: UTF8String => s.toString
+      case null => throw new IllegalArgumentException("Barrier expression 
cannot be null")
+      case other =>
+        throw new IllegalArgumentException(
+          s"Barrier expression must be a string, got: ${other.getClass}")
+    }
+
+    // Build variable map from pairs
+    val varMap = scala.collection.mutable.Map[String, Any]()
+    var i = 1
+    while (i < inputExpressions.length) {
+      if (i + 1 >= inputExpressions.length) {
+        throw new IllegalArgumentException(
+          "Barrier function requires pairs of variable names and values")
+      }
+
+      val varName = inputExpressions(i).eval(input) match {
+        case s: UTF8String => s.toString
+        case null => throw new IllegalArgumentException("Variable name cannot 
be null")
+        case other =>
+          throw new IllegalArgumentException(
+            s"Variable name must be a string, got: ${other.getClass}")
+      }
+
+      val varValue = inputExpressions(i + 1).eval(input)
+      varMap(varName) = varValue
+      i += 2
+    }
+
+    // Evaluate the expression with variable substitution
+    evaluateBooleanExpression(exprString, varMap.toMap)
+  }
+
+  /**
+   * Evaluates a boolean expression string with variable substitution. 
Supports basic comparison
+   * operators and logical operators (AND, OR, NOT).
+   */
+  private def evaluateBooleanExpression(
+      expression: String,
+      variables: Map[String, Any]): Boolean = {
+    val parser = new BooleanExpressionParser(variables)
+    parser.parseExpression(expression) match {
+      case parser.Success(result, _) => result
+      case parser.NoSuccess(msg, _) =>
+        throw new IllegalArgumentException(s"Failed to parse barrier 
expression: $msg")
+    }
+  }
+
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): 
Expression = {
+    copy(inputExpressions = newChildren)
+  }
+}
+
+/**
+ * Parser for boolean expressions in barrier function. Supports comparison 
operators: =, !=, <>,
+ * <, <=, >, >= Supports logical operators: AND, OR, NOT Supports parentheses 
for grouping
+ */
+private class BooleanExpressionParser(variables: Map[String, Any]) extends 
JavaTokenParsers {
+
+  def parseExpression(expr: String): ParseResult[Boolean] = parseAll(boolExpr, 
expr)
+
+  def boolExpr: Parser[Boolean] = orExpr
+
+  def orExpr: Parser[Boolean] = andExpr ~ rep("(?i)OR".r ~> andExpr) ^^ { case 
left ~ rights =>
+    rights.foldLeft(left)(_ || _)
+  }
+
+  def andExpr: Parser[Boolean] = notExpr ~ rep("(?i)AND".r ~> notExpr) ^^ { 
case left ~ rights =>
+    rights.foldLeft(left)(_ && _)
+  }
+
+  def notExpr: Parser[Boolean] =
+    "(?i)NOT".r ~> notExpr ^^ (!_) |
+      primaryExpr
+
+  def primaryExpr: Parser[Boolean] =
+    "(" ~> boolExpr <~ ")" |
+      attempt(comparison) |
+      booleanValue
+
+  def comparison: Parser[Boolean] = value ~ compOp ~ value ^^ { case left ~ op 
~ right =>
+    compareValues(left, op, right)
+  }
+
+  def attempt[T](p: Parser[T]): Parser[T] = Parser { in =>
+    p(in) match {
+      case s @ Success(_, _) => s
+      case _ => Failure("", in)
+    }
+  }
+
+  def booleanValue: Parser[Boolean] =
+    "(?i)true".r ^^ (_ => true) |
+      "(?i)false".r ^^ (_ => false) |
+      ident.filter(id => !id.toUpperCase.matches("AND|OR|NOT")) ^^ { name =>

Review Comment:
   The regular expression patterns `"(?i)true".r` and `"(?i)false".r` are 
recompiled each time `booleanValue` is evaluated during parsing, because it is 
a `def`. Consider pre-compiling these regex patterns as class constants to 
improve performance.



##########
spark/common/src/test/scala/org/apache/sedona/sql/BarrierFunctionTest.scala:
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType, 
StructField, StructType, BooleanType}
+import org.apache.spark.sql.functions.expr
+import org.scalatest.matchers.should.Matchers.{convertToAnyShouldWrapper, be}
+import org.scalatest.prop.TableDrivenPropertyChecks
+import scala.collection.JavaConverters._
+
+class BarrierFunctionTest extends TestBaseScala with TableDrivenPropertyChecks 
{
+
+  describe("Barrier Function Test") {
+
+    it("should evaluate simple comparison expressions") {
+      // Create test data
+      val testDf = sparkSession
+        .createDataFrame(
+          Seq(Row(1, 4.5, 5), Row(2, 3.0, 3), Row(3, 5.0, 4), Row(4, 2.0, 
2)).asJava,
+          StructType(
+            Seq(
+              StructField("id", IntegerType, false),
+              StructField("rating", DoubleType, false),
+              StructField("stars", IntegerType, false))))
+
+      testDf.createOrReplaceTempView("test_table")
+
+      // Test simple greater than
+      val result1 =
+        sparkSession.sql("""SELECT id, barrier('rating > 4.0', 'rating', 
rating) as result
+           FROM test_table""")
+      val expected1 = Seq(true, false, true, false)
+      result1.collect().map(_.getBoolean(1)) should be(expected1)
+
+      // Test AND condition
+      val result2 = sparkSession.sql("""SELECT id, barrier('rating > 4.0 AND 
stars >= 4',
+                              'rating', rating,
+                              'stars', stars) as result
+           FROM test_table""")
+      val expected2 = Seq(true, false, true, false)
+      result2.collect().map(_.getBoolean(1)) should be(expected2)
+
+      // Test OR condition
+      val result3 = sparkSession.sql("""SELECT id, barrier('rating < 3.0 OR 
stars >= 5',
+                              'rating', rating,
+                              'stars', stars) as result
+           FROM test_table""")
+      val expected3 = Seq(true, false, false, true)
+      result3.collect().map(_.getBoolean(1)) should be(expected3)
+    }
+
+    it("should handle different comparison operators") {
+      val testDf = sparkSession
+        .createDataFrame(
+          Seq(Row(1, 10, 10), Row(2, 20, 30), Row(3, 15, 15), Row(4, 25, 
20)).asJava,
+          StructType(
+            Seq(
+              StructField("id", IntegerType, false),
+              StructField("val1", IntegerType, false),
+              StructField("val2", IntegerType, false))))
+
+      testDf.createOrReplaceTempView("test_table")
+
+      // Test equals
+      val result1 = sparkSession.sql("""SELECT id, barrier('val1 = val2',
+                              'val1', val1,
+                              'val2', val2) as result
+           FROM test_table""")
+      val expected1 = Seq(true, false, true, false)
+      result1.collect().map(_.getBoolean(1)) should be(expected1)
+
+      // Test not equals
+      val result2 = sparkSession.sql("""SELECT id, barrier('val1 != val2',
+                              'val1', val1,
+                              'val2', val2) as result
+           FROM test_table""")
+      val expected2 = Seq(false, true, false, true)
+      result2.collect().map(_.getBoolean(1)) should be(expected2)
+
+      // Test less than or equal
+      val result3 = sparkSession.sql("""SELECT id, barrier('val1 <= val2',
+                              'val1', val1,
+                              'val2', val2) as result
+           FROM test_table""")
+      val expected3 = Seq(true, true, true, false)
+      result3.collect().map(_.getBoolean(1)) should be(expected3)
+    }
+
+    it("should handle NOT operator") {
+      val testDf = sparkSession
+        .createDataFrame(
+          Seq(Row(1, true), Row(2, false), Row(3, true), Row(4, false)).asJava,
+          StructType(
+            Seq(StructField("id", IntegerType, false), StructField("flag", 
BooleanType, false))))
+
+      testDf.createOrReplaceTempView("test_table")
+
+      val result = sparkSession.sql("""SELECT id, barrier('NOT flag', 'flag', 
flag) as result
+           FROM test_table""")
+      val expected = Seq(false, true, false, true)
+      result.collect().map(_.getBoolean(1)) should be(expected)
+    }
+
+    it("should handle parentheses for grouping") {
+      val testDf = sparkSession
+        .createDataFrame(
+          Seq(
+            Row(1, 10, 20, 30),
+            Row(2, 15, 25, 35),
+            Row(3, 5, 15, 25),
+            Row(4, 20, 10, 5)).asJava,
+          StructType(
+            Seq(
+              StructField("id", IntegerType, false),
+              StructField("a", IntegerType, false),
+              StructField("b", IntegerType, false),
+              StructField("c", IntegerType, false))))
+
+      testDf.createOrReplaceTempView("test_table")
+
+      // Test with parentheses
+      val result =
+        sparkSession.sql("""SELECT id, barrier('(a < b AND b < c) OR (a > b 
AND b > c)',
+                             'a', a, 'b', b, 'c', c) as result
+           FROM test_table""")
+      val expected = Seq(true, true, true, true)
+      result.collect().map(_.getBoolean(1)) should be(expected)
+    }
+
+    it("should handle string comparisons") {
+      val testDf = sparkSession
+        .createDataFrame(
+          Seq(
+            Row(1, "apple", "banana"),
+            Row(2, "zebra", "apple"),
+            Row(3, "cat", "cat"),
+            Row(4, "dog", "cat")).asJava,
+          StructType(
+            Seq(
+              StructField("id", IntegerType, false),
+              StructField("str1", StringType, false),
+              StructField("str2", StringType, false))))
+
+      testDf.createOrReplaceTempView("test_table")
+
+      val result = sparkSession.sql("""SELECT id, barrier('str1 < str2',
+                             'str1', str1,
+                             'str2', str2) as result
+           FROM test_table""")
+      val expected = Seq(true, false, false, false)
+      result.collect().map(_.getBoolean(1)) should be(expected)
+    }
+
+    it("should handle null values") {
+      val testDf = sparkSession
+        .createDataFrame(
+          Seq(Row(1, 10, 20), Row(2, null, 20), Row(3, 10, null), Row(4, null, 
null)).asJava,
+          StructType(
+            Seq(
+              StructField("id", IntegerType, false),
+              StructField("val1", IntegerType, true),
+              StructField("val2", IntegerType, true))))
+
+      testDf.createOrReplaceTempView("test_table")
+
+      // Test null comparisons
+      val result = sparkSession.sql("""SELECT id, barrier('val1 = val2',
+                             'val1', val1,
+                             'val2', val2) as result
+           FROM test_table""")
+      val expected = Seq(false, false, false, true)

Review Comment:
   This test expects `null = null` to return `true` for `Row(4, null, null)`, 
which should match the comparison logic in the barrier function. However, the 
`// Test null comparisons` comment and the test setup suggest that null 
handling should be tested more comprehensively than this single equality check.



##########
spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/BarrierFunction.scala:
##########
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.expressions
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.{DataType, BooleanType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+import scala.util.parsing.combinator._
+
+/**
+ * Barrier function to prevent filter pushdown and control predicate 
evaluation order. Takes a
+ * boolean expression string followed by pairs of variable names and their 
values.
+ *
+ * Usage: barrier(expression, var_name1, var_value1, var_name2, var_value2, 
...) Example:
+ * barrier('rating > 4.0 AND stars >= 4', 'rating', r.rating, 'stars', h.stars)
+ *
+ * Extends CodegenFallback to prevent Catalyst optimizer from pushing this 
filter through joins.
+ * CodegenFallback makes this expression opaque to optimization rules, 
ensuring it evaluates at
+ * runtime in its original position within the query plan.
+ */
+private[apache] case class Barrier(inputExpressions: Seq[Expression])
+    extends Expression
+    with CodegenFallback {
+
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = BooleanType
+
+  override def children: Seq[Expression] = inputExpressions
+
+  override def eval(input: InternalRow): Any = {
+    // Get the expression string
+    val exprString = inputExpressions.head.eval(input) match {
+      case s: UTF8String => s.toString
+      case null => throw new IllegalArgumentException("Barrier expression 
cannot be null")
+      case other =>
+        throw new IllegalArgumentException(
+          s"Barrier expression must be a string, got: ${other.getClass}")
+    }
+
+    // Build variable map from pairs
+    val varMap = scala.collection.mutable.Map[String, Any]()
+    var i = 1
+    while (i < inputExpressions.length) {
+      if (i + 1 >= inputExpressions.length) {
+        throw new IllegalArgumentException(
+          "Barrier function requires pairs of variable names and values")
+      }
+
+      val varName = inputExpressions(i).eval(input) match {
+        case s: UTF8String => s.toString
+        case null => throw new IllegalArgumentException("Variable name cannot 
be null")
+        case other =>
+          throw new IllegalArgumentException(
+            s"Variable name must be a string, got: ${other.getClass}")
+      }
+
+      val varValue = inputExpressions(i + 1).eval(input)
+      varMap(varName) = varValue
+      i += 2
+    }
+
+    // Evaluate the expression with variable substitution
+    evaluateBooleanExpression(exprString, varMap.toMap)
+  }
+
+  /**
+   * Evaluates a boolean expression string with variable substitution. 
Supports basic comparison
+   * operators and logical operators (AND, OR, NOT).
+   */
+  private def evaluateBooleanExpression(
+      expression: String,
+      variables: Map[String, Any]): Boolean = {
+    val parser = new BooleanExpressionParser(variables)
+    parser.parseExpression(expression) match {
+      case parser.Success(result, _) => result
+      case parser.NoSuccess(msg, _) =>
+        throw new IllegalArgumentException(s"Failed to parse barrier 
expression: $msg")
+    }
+  }
+
+  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): 
Expression = {
+    copy(inputExpressions = newChildren)
+  }
+}
+
+/**
+ * Parser for boolean expressions in barrier function. Supports comparison 
operators: =, !=, <>,
+ * <, <=, >, >= Supports logical operators: AND, OR, NOT Supports parentheses 
for grouping
+ */
+private class BooleanExpressionParser(variables: Map[String, Any]) extends 
JavaTokenParsers {
+
+  def parseExpression(expr: String): ParseResult[Boolean] = parseAll(boolExpr, 
expr)
+
+  def boolExpr: Parser[Boolean] = orExpr
+
+  def orExpr: Parser[Boolean] = andExpr ~ rep("(?i)OR".r ~> andExpr) ^^ { case 
left ~ rights =>
+    rights.foldLeft(left)(_ || _)
+  }
+
+  def andExpr: Parser[Boolean] = notExpr ~ rep("(?i)AND".r ~> notExpr) ^^ { 
case left ~ rights =>
+    rights.foldLeft(left)(_ && _)
+  }
+
+  def notExpr: Parser[Boolean] =
+    "(?i)NOT".r ~> notExpr ^^ (!_) |
+      primaryExpr
+
+  def primaryExpr: Parser[Boolean] =
+    "(" ~> boolExpr <~ ")" |
+      attempt(comparison) |
+      booleanValue
+
+  def comparison: Parser[Boolean] = value ~ compOp ~ value ^^ { case left ~ op 
~ right =>
+    compareValues(left, op, right)
+  }
+
+  def attempt[T](p: Parser[T]): Parser[T] = Parser { in =>
+    p(in) match {
+      case s @ Success(_, _) => s
+      case _ => Failure("", in)
+    }
+  }
+
+  def booleanValue: Parser[Boolean] =
+    "(?i)true".r ^^ (_ => true) |
+      "(?i)false".r ^^ (_ => false) |
+      ident.filter(id => !id.toUpperCase.matches("AND|OR|NOT")) ^^ { name =>
+        variables.get(name) match {
+          case Some(b: Boolean) => b
+          case Some(other) =>
+            throw new IllegalArgumentException(s"Expected boolean value for 
$name, got: $other")
+          case None =>
+            throw new IllegalArgumentException(s"Unknown variable: $name")
+        }
+      }
+
+  def compOp: Parser[String] = ">=" | "<=" | "!=" | "<>" | "=" | ">" | "<"
+
+  def value: Parser[Any] =
+    floatingPointNumber ^^ (_.toDouble) |
+      wholeNumber ^^ (_.toLong) |
+      stringLiteral ^^ (s => s.substring(1, s.length - 1)) | // Remove quotes
+      "(?i)true".r ^^ (_ => true) |
+      "(?i)false".r ^^ (_ => false) |
+      "(?i)null".r ^^ (_ => null) |

Review Comment:
   Similar to the previous comment, the regex patterns `"(?i)true".r`, 
`"(?i)false".r`, and `"(?i)null".r` in `value` should be pre-compiled as 
constants to avoid repeated compilation during parsing.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to