This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 45eea269e8 [Spark Connector] Escape column names when querying Pinot (#10663)
45eea269e8 is described below
commit 45eea269e8942dd5c031c5b1789199a6c5397cd0
Author: Caner Balci <[email protected]>
AuthorDate: Tue Apr 25 12:54:27 2023 -0700
[Spark Connector] Escape column names when querying Pinot (#10663)
---
.../spark/datasource/query/FilterPushDown.scala | 32 ++++++++++++----------
.../datasource/query/FilterPushDownTest.scala | 21 +++++++++-----
.../spark/v3/datasource/query/FilterPushDown.scala | 32 ++++++++++++----------
.../ExampleSparkPinotConnectorTest.scala | 1 +
.../v3/datasource/query/FilterPushDownTest.scala | 21 +++++++++-----
.../spark/common/PinotClusterClient.scala | 4 +--
.../spark/common/query/ScanQueryGenerator.scala | 6 +++-
.../common/query/ScanQueryGeneratorTest.scala | 24 ++++++++--------
8 files changed, 84 insertions(+), 57 deletions(-)
diff --git a/pinot-connectors/pinot-spark-2-connector/src/main/scala/org/apache/pinot/connector/spark/datasource/query/FilterPushDown.scala b/pinot-connectors/pinot-spark-2-connector/src/main/scala/org/apache/pinot/connector/spark/datasource/query/FilterPushDown.scala
index 30954b566c..331594663a 100644
--- a/pinot-connectors/pinot-spark-2-connector/src/main/scala/org/apache/pinot/connector/spark/datasource/query/FilterPushDown.scala
+++ b/pinot-connectors/pinot-spark-2-connector/src/main/scala/org/apache/pinot/connector/spark/datasource/query/FilterPushDown.scala
@@ -81,25 +81,29 @@ private[pinot] object FilterPushDown {
case _ => value
}
+ private def escapeAttr(attr: String): String = {
+ if (attr.contains("\"")) attr else s""""$attr""""
+ }
+
private def compileFilter(filter: Filter): Option[String] = {
val whereCondition = filter match {
- case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
+ case EqualTo(attr, value) => s"${escapeAttr(attr)} = ${compileValue(value)}"
case EqualNullSafe(attr, value) =>
- s"NOT ($attr != ${compileValue(value)} OR $attr IS NULL OR " +
+ s"NOT (${escapeAttr(attr)} != ${compileValue(value)} OR ${escapeAttr(attr)} IS NULL OR " +
s"${compileValue(value)} IS NULL) OR " +
- s"($attr IS NULL AND ${compileValue(value)} IS NULL)"
- case LessThan(attr, value) => s"$attr < ${compileValue(value)}"
- case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}"
- case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}"
- case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}"
- case IsNull(attr) => s"$attr IS NULL"
- case IsNotNull(attr) => s"$attr IS NOT NULL"
- case StringStartsWith(attr, value) => s"$attr LIKE '$value%'"
- case StringEndsWith(attr, value) => s"$attr LIKE '%$value'"
- case StringContains(attr, value) => s"$attr LIKE '%$value%'"
+ s"(${escapeAttr(attr)} IS NULL AND ${compileValue(value)} IS NULL)"
+ case LessThan(attr, value) => s"${escapeAttr(attr)} < ${compileValue(value)}"
+ case GreaterThan(attr, value) => s"${escapeAttr(attr)} > ${compileValue(value)}"
+ case LessThanOrEqual(attr, value) => s"${escapeAttr(attr)} <= ${compileValue(value)}"
+ case GreaterThanOrEqual(attr, value) => s"${escapeAttr(attr)} >= ${compileValue(value)}"
+ case IsNull(attr) => s"${escapeAttr(attr)} IS NULL"
+ case IsNotNull(attr) => s"${escapeAttr(attr)} IS NOT NULL"
+ case StringStartsWith(attr, value) => s"${escapeAttr(attr)} LIKE '$value%'"
+ case StringEndsWith(attr, value) => s"${escapeAttr(attr)} LIKE '%$value'"
+ case StringContains(attr, value) => s"${escapeAttr(attr)} LIKE '%$value%'"
case In(attr, value) if value.isEmpty =>
- s"CASE WHEN $attr IS NULL THEN NULL ELSE FALSE END"
- case In(attr, value) => s"$attr IN (${compileValue(value)})"
+ s"CASE WHEN ${escapeAttr(attr)} IS NULL THEN NULL ELSE FALSE END"
+ case In(attr, value) => s"${escapeAttr(attr)} IN (${compileValue(value)})"
case Not(f) => compileFilter(f).map(p => s"NOT ($p)").orNull
case Or(f1, f2) =>
val or = Seq(f1, f2).flatMap(compileFilter)
diff --git a/pinot-connectors/pinot-spark-2-connector/src/test/scala/org/apache/pinot/connector/spark/datasource/query/FilterPushDownTest.scala b/pinot-connectors/pinot-spark-2-connector/src/test/scala/org/apache/pinot/connector/spark/datasource/query/FilterPushDownTest.scala
index 127aab211f..eeb961ef25 100644
--- a/pinot-connectors/pinot-spark-2-connector/src/test/scala/org/apache/pinot/connector/spark/datasource/query/FilterPushDownTest.scala
+++ b/pinot-connectors/pinot-spark-2-connector/src/test/scala/org/apache/pinot/connector/spark/datasource/query/FilterPushDownTest.scala
@@ -61,15 +61,22 @@ class FilterPushDownTest extends BaseTest {
test("SQL query should be created from spark filters") {
val whereClause = FilterPushDown.compileFiltersToSqlWhereClause(filters)
val expectedOutput =
- s"(attr1 = 1) AND (attr2 IN ('1', '2', '''5''')) AND (attr3 < 1) AND (attr4 <= 3) AND (attr5 > 10) AND " +
- s"(attr6 >= 15) AND (NOT (attr7 = '1')) AND ((attr8 < 10) AND (attr9 <= 3)) AND " +
- s"((attr10 = 'hello') OR (attr11 >= 13)) AND (attr12 LIKE '%pinot%') AND (attr13 IN (10, 20)) AND " +
- s"(NOT (attr20 != '123' OR attr20 IS NULL OR '123' IS NULL) OR (attr20 IS NULL AND '123' IS NULL)) AND " +
- s"(attr14 IS NULL) AND (attr15 IS NOT NULL) AND (attr16 LIKE 'pinot1%') AND (attr17 LIKE '%pinot2') AND " +
- s"(attr18 = '2020-01-01 00:00:15.0') AND (attr19 < '2020-01-01') AND (attr21 = List(1, 2)) AND " +
- s"(attr22 = 10.5)"
+ s"""("attr1" = 1) AND ("attr2" IN ('1', '2', '''5''')) AND ("attr3" < 1) AND ("attr4" <= 3) AND ("attr5" > 10) AND """ +
+ s"""("attr6" >= 15) AND (NOT ("attr7" = '1')) AND (("attr8" < 10) AND ("attr9" <= 3)) AND """ +
+ s"""(("attr10" = 'hello') OR ("attr11" >= 13)) AND ("attr12" LIKE '%pinot%') AND ("attr13" IN (10, 20)) AND """ +
+ s"""(NOT ("attr20" != '123' OR "attr20" IS NULL OR '123' IS NULL) OR ("attr20" IS NULL AND '123' IS NULL)) AND """ +
+ s"""("attr14" IS NULL) AND ("attr15" IS NOT NULL) AND ("attr16" LIKE 'pinot1%') AND ("attr17" LIKE '%pinot2') AND """ +
+ s"""("attr18" = '2020-01-01 00:00:15.0') AND ("attr19" < '2020-01-01') AND ("attr21" = List(1, 2)) AND """ +
+ s"""("attr22" = 10.5)"""
whereClause.get shouldEqual expectedOutput
}
+ test("Shouldn't escape column names which are already escaped") {
+ val whereClause = FilterPushDown.compileFiltersToSqlWhereClause(
+ Array(EqualTo("\"some\".\"nested\".\"column\"", 1)))
+ val expectedOutput = "(\"some\".\"nested\".\"column\" = 1)"
+
+ whereClause.get shouldEqual expectedOutput
+ }
}
diff --git a/pinot-connectors/pinot-spark-3-connector/src/main/scala/org/apache/pinot/connector/spark/v3/datasource/query/FilterPushDown.scala b/pinot-connectors/pinot-spark-3-connector/src/main/scala/org/apache/pinot/connector/spark/v3/datasource/query/FilterPushDown.scala
index 3d4e3f658d..cac50ec031 100644
--- a/pinot-connectors/pinot-spark-3-connector/src/main/scala/org/apache/pinot/connector/spark/v3/datasource/query/FilterPushDown.scala
+++ b/pinot-connectors/pinot-spark-3-connector/src/main/scala/org/apache/pinot/connector/spark/v3/datasource/query/FilterPushDown.scala
@@ -81,25 +81,29 @@ private[pinot] object FilterPushDown {
case _ => value
}
+ private def escapeAttr(attr: String): String = {
+ if (attr.contains("\"")) attr else s""""$attr""""
+ }
+
private def compileFilter(filter: Filter): Option[String] = {
val whereCondition = filter match {
- case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
+ case EqualTo(attr, value) => s"${escapeAttr(attr)} = ${compileValue(value)}"
case EqualNullSafe(attr, value) =>
- s"NOT ($attr != ${compileValue(value)} OR $attr IS NULL OR " +
+ s"NOT (${escapeAttr(attr)} != ${compileValue(value)} OR ${escapeAttr(attr)} IS NULL OR " +
s"${compileValue(value)} IS NULL) OR " +
- s"($attr IS NULL AND ${compileValue(value)} IS NULL)"
- case LessThan(attr, value) => s"$attr < ${compileValue(value)}"
- case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}"
- case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}"
- case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}"
- case IsNull(attr) => s"$attr IS NULL"
- case IsNotNull(attr) => s"$attr IS NOT NULL"
- case StringStartsWith(attr, value) => s"$attr LIKE '$value%'"
- case StringEndsWith(attr, value) => s"$attr LIKE '%$value'"
- case StringContains(attr, value) => s"$attr LIKE '%$value%'"
+ s"(${escapeAttr(attr)} IS NULL AND ${compileValue(value)} IS NULL)"
+ case LessThan(attr, value) => s"${escapeAttr(attr)} < ${compileValue(value)}"
+ case GreaterThan(attr, value) => s"${escapeAttr(attr)} > ${compileValue(value)}"
+ case LessThanOrEqual(attr, value) => s"${escapeAttr(attr)} <= ${compileValue(value)}"
+ case GreaterThanOrEqual(attr, value) => s"${escapeAttr(attr)} >= ${compileValue(value)}"
+ case IsNull(attr) => s"${escapeAttr(attr)} IS NULL"
+ case IsNotNull(attr) => s"${escapeAttr(attr)} IS NOT NULL"
+ case StringStartsWith(attr, value) => s"${escapeAttr(attr)} LIKE '$value%'"
+ case StringEndsWith(attr, value) => s"${escapeAttr(attr)} LIKE '%$value'"
+ case StringContains(attr, value) => s"${escapeAttr(attr)} LIKE '%$value%'"
case In(attr, value) if value.isEmpty =>
- s"CASE WHEN $attr IS NULL THEN NULL ELSE FALSE END"
- case In(attr, value) => s"$attr IN (${compileValue(value)})"
+ s"CASE WHEN ${escapeAttr(attr)} IS NULL THEN NULL ELSE FALSE END"
+ case In(attr, value) => s"${escapeAttr(attr)} IN (${compileValue(value)})"
case Not(f) => compileFilter(f).map(p => s"NOT ($p)").orNull
case Or(f1, f2) =>
val or = Seq(f1, f2).flatMap(compileFilter)
diff --git a/pinot-connectors/pinot-spark-3-connector/src/test/scala/org/apache/pinot/connector/spark/v3/datasource/ExampleSparkPinotConnectorTest.scala b/pinot-connectors/pinot-spark-3-connector/src/test/scala/org/apache/pinot/connector/spark/v3/datasource/ExampleSparkPinotConnectorTest.scala
index 48692ae0f7..3c2755baf7 100644
--- a/pinot-connectors/pinot-spark-3-connector/src/test/scala/org/apache/pinot/connector/spark/v3/datasource/ExampleSparkPinotConnectorTest.scala
+++ b/pinot-connectors/pinot-spark-3-connector/src/test/scala/org/apache/pinot/connector/spark/v3/datasource/ExampleSparkPinotConnectorTest.scala
@@ -49,6 +49,7 @@ object ExampleSparkPinotConnectorTest extends Logging {
}
def readOffline()(implicit spark: SparkSession): Unit = {
+ import spark.implicits._
log.info("## Reading `airlineStats_OFFLINE` table... ##")
val data = spark.read
.format("pinot")
diff --git a/pinot-connectors/pinot-spark-3-connector/src/test/scala/org/apache/pinot/connector/spark/v3/datasource/query/FilterPushDownTest.scala b/pinot-connectors/pinot-spark-3-connector/src/test/scala/org/apache/pinot/connector/spark/v3/datasource/query/FilterPushDownTest.scala
index 6202257b9b..1bf889ddee 100644
--- a/pinot-connectors/pinot-spark-3-connector/src/test/scala/org/apache/pinot/connector/spark/v3/datasource/query/FilterPushDownTest.scala
+++ b/pinot-connectors/pinot-spark-3-connector/src/test/scala/org/apache/pinot/connector/spark/v3/datasource/query/FilterPushDownTest.scala
@@ -61,15 +61,22 @@ class FilterPushDownTest extends BaseTest {
test("SQL query should be created from spark filters") {
val whereClause = FilterPushDown.compileFiltersToSqlWhereClause(filters)
val expectedOutput =
- s"(attr1 = 1) AND (attr2 IN ('1', '2', '''5''')) AND (attr3 < 1) AND (attr4 <= 3) AND (attr5 > 10) AND " +
- s"(attr6 >= 15) AND (NOT (attr7 = '1')) AND ((attr8 < 10) AND (attr9 <= 3)) AND " +
- s"((attr10 = 'hello') OR (attr11 >= 13)) AND (attr12 LIKE '%pinot%') AND (attr13 IN (10, 20)) AND " +
- s"(NOT (attr20 != '123' OR attr20 IS NULL OR '123' IS NULL) OR (attr20 IS NULL AND '123' IS NULL)) AND " +
- s"(attr14 IS NULL) AND (attr15 IS NOT NULL) AND (attr16 LIKE 'pinot1%') AND (attr17 LIKE '%pinot2') AND " +
- s"(attr18 = '2020-01-01 00:00:15.0') AND (attr19 < '2020-01-01') AND (attr21 = List(1, 2)) AND " +
- s"(attr22 = 10.5)"
+ s"""("attr1" = 1) AND ("attr2" IN ('1', '2', '''5''')) AND ("attr3" < 1) AND ("attr4" <= 3) AND ("attr5" > 10) AND """ +
+ s"""("attr6" >= 15) AND (NOT ("attr7" = '1')) AND (("attr8" < 10) AND ("attr9" <= 3)) AND """ +
+ s"""(("attr10" = 'hello') OR ("attr11" >= 13)) AND ("attr12" LIKE '%pinot%') AND ("attr13" IN (10, 20)) AND """ +
+ s"""(NOT ("attr20" != '123' OR "attr20" IS NULL OR '123' IS NULL) OR ("attr20" IS NULL AND '123' IS NULL)) AND """ +
+ s"""("attr14" IS NULL) AND ("attr15" IS NOT NULL) AND ("attr16" LIKE 'pinot1%') AND ("attr17" LIKE '%pinot2') AND """ +
+ s"""("attr18" = '2020-01-01 00:00:15.0') AND ("attr19" < '2020-01-01') AND ("attr21" = List(1, 2)) AND """ +
+ s"""("attr22" = 10.5)"""
whereClause.get shouldEqual expectedOutput
}
+ test("Shouldn't escape column names which are already escaped") {
+ val whereClause = FilterPushDown.compileFiltersToSqlWhereClause(
+ Array(EqualTo("\"some\".\"nested\".\"column\"", 1)))
+ val expectedOutput = "(\"some\".\"nested\".\"column\" = 1)"
+
+ whereClause.get shouldEqual expectedOutput
+ }
}
diff --git a/pinot-connectors/pinot-spark-common/src/main/scala/org/apache/pinot/connector/spark/common/PinotClusterClient.scala b/pinot-connectors/pinot-spark-common/src/main/scala/org/apache/pinot/connector/spark/common/PinotClusterClient.scala
index 52e86cbaf5..1c5dafe2a5 100644
--- a/pinot-connectors/pinot-spark-common/src/main/scala/org/apache/pinot/connector/spark/common/PinotClusterClient.scala
+++ b/pinot-connectors/pinot-spark-common/src/main/scala/org/apache/pinot/connector/spark/common/PinotClusterClient.scala
@@ -213,9 +213,9 @@ private[pinot] object PinotClusterClient extends Logging {
private[pinot] case class TimeBoundaryInfo(timeColumn: String, timeValue: String) {
- def getOfflinePredicate: String = s"$timeColumn < $timeValue"
+ def getOfflinePredicate: String = s""""$timeColumn" < $timeValue"""
- def getRealtimePredicate: String = s"$timeColumn >= $timeValue"
+ def getRealtimePredicate: String = s""""$timeColumn" >= $timeValue"""
}
private[pinot] case class InstanceInfo(instanceName: String,
diff --git a/pinot-connectors/pinot-spark-common/src/main/scala/org/apache/pinot/connector/spark/common/query/ScanQueryGenerator.scala b/pinot-connectors/pinot-spark-common/src/main/scala/org/apache/pinot/connector/spark/common/query/ScanQueryGenerator.scala
index 4616bbd7c7..e6c1afb9c8 100644
--- a/pinot-connectors/pinot-spark-common/src/main/scala/org/apache/pinot/connector/spark/common/query/ScanQueryGenerator.scala
+++ b/pinot-connectors/pinot-spark-common/src/main/scala/org/apache/pinot/connector/spark/common/query/ScanQueryGenerator.scala
@@ -46,7 +46,11 @@ private[pinot] class ScanQueryGenerator(
/** Get all columns if selecting columns empty(eg: resultDataFrame.count()) */
private def columnsAsExpression(): String = {
- if (columns.isEmpty) "*" else columns.mkString(",")
+ if (columns.isEmpty) "*" else columns.map(escapeCol).mkString(",")
+ }
+
+ private def escapeCol(col: String): String = {
+ if (col.contains("\"")) col else s""""$col""""
}
/** Build realtime or offline SQL selection query. */
diff --git a/pinot-connectors/pinot-spark-common/src/test/scala/org/apache/pinot/connector/spark/common/query/ScanQueryGeneratorTest.scala b/pinot-connectors/pinot-spark-common/src/test/scala/org/apache/pinot/connector/spark/common/query/ScanQueryGeneratorTest.scala
index 9be29e3c44..73fc9d8584 100644
--- a/pinot-connectors/pinot-spark-common/src/test/scala/org/apache/pinot/connector/spark/common/query/ScanQueryGeneratorTest.scala
+++ b/pinot-connectors/pinot-spark-common/src/test/scala/org/apache/pinot/connector/spark/common/query/ScanQueryGeneratorTest.scala
@@ -25,19 +25,19 @@ import org.apache.pinot.spi.config.table.TableType
* Test SQL query generation from spark push down filters, selection columns etc.
*/
class ScanQueryGeneratorTest extends BaseTest {
- private val columns = Array("c1, c2")
+ private val columns = Array("c1","c2")
private val tableName = "tbl"
private val tableType = Some(TableType.OFFLINE)
private val whereClause = Some("c1 = 5 OR c2 = 'hello'")
- private val limit = s"LIMIT ${Int.MaxValue}"
+ private val limit = s"""LIMIT ${Int.MaxValue}"""
test("Queries should be created with given filters") {
val pinotQueries =
ScanQueryGenerator.generate(tableName, tableType, None, columns, whereClause, Set())
val expectedRealtimeQuery =
- s"SELECT c1, c2 FROM ${tableName}_REALTIME WHERE ${whereClause.get} $limit"
+ s"""SELECT "c1","c2" FROM ${tableName}_REALTIME WHERE ${whereClause.get} $limit"""
val expectedOfflineQuery =
- s"SELECT c1, c2 FROM ${tableName}_OFFLINE WHERE ${whereClause.get} $limit"
+ s"""SELECT "c1","c2" FROM ${tableName}_OFFLINE WHERE ${whereClause.get} $limit"""
pinotQueries.realtimeSelectQuery shouldEqual expectedRealtimeQuery
pinotQueries.offlineSelectQuery shouldEqual expectedOfflineQuery
@@ -48,12 +48,12 @@ class ScanQueryGeneratorTest extends BaseTest {
val pinotQueries = ScanQueryGenerator
.generate(tableName, tableType, Some(timeBoundaryInfo), columns, whereClause, Set())
- val realtimeWhereClause = s"${whereClause.get} AND timeCol >= 12345"
- val offlineWhereClause = s"${whereClause.get} AND timeCol < 12345"
+ val realtimeWhereClause = s"""${whereClause.get} AND "timeCol" >= 12345"""
+ val offlineWhereClause = s"""${whereClause.get} AND "timeCol" < 12345"""
val expectedRealtimeQuery =
- s"SELECT c1, c2 FROM ${tableName}_REALTIME WHERE $realtimeWhereClause $limit"
+ s"""SELECT "c1","c2" FROM ${tableName}_REALTIME WHERE $realtimeWhereClause $limit"""
val expectedOfflineQuery =
- s"SELECT c1, c2 FROM ${tableName}_OFFLINE WHERE $offlineWhereClause $limit"
+ s"""SELECT "c1","c2" FROM ${tableName}_OFFLINE WHERE $offlineWhereClause $limit"""
pinotQueries.realtimeSelectQuery shouldEqual expectedRealtimeQuery
pinotQueries.offlineSelectQuery shouldEqual expectedOfflineQuery
@@ -64,12 +64,12 @@ class ScanQueryGeneratorTest extends BaseTest {
val pinotQueries = ScanQueryGenerator
.generate(tableName, tableType, Some(timeBoundaryInfo), columns, None, Set())
- val realtimeWhereClause = s"timeCol >= 12345"
- val offlineWhereClause = s"timeCol < 12345"
+ val realtimeWhereClause = s""""timeCol" >= 12345"""
+ val offlineWhereClause = s""""timeCol" < 12345"""
val expectedRealtimeQuery =
- s"SELECT c1, c2 FROM ${tableName}_REALTIME WHERE $realtimeWhereClause $limit"
+ s"""SELECT "c1","c2" FROM ${tableName}_REALTIME WHERE $realtimeWhereClause $limit"""
val expectedOfflineQuery =
- s"SELECT c1, c2 FROM ${tableName}_OFFLINE WHERE $offlineWhereClause $limit"
+ s"""SELECT "c1","c2" FROM ${tableName}_OFFLINE WHERE $offlineWhereClause $limit"""
pinotQueries.realtimeSelectQuery shouldEqual expectedRealtimeQuery
pinotQueries.offlineSelectQuery shouldEqual expectedOfflineQuery
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]