This is an automated email from the ASF dual-hosted git repository.
jiafengzheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
The following commit(s) were added to refs/heads/master by this push:
new 466cd5d [Optimize] Spark connector supports multiple spark versions:2.1.x/2.3.x/2.4.x/3.x (#6956)
466cd5d is described below
commit 466cd5dd0903ae023f7ee8c2271c02ecf3d21e15
Author: wei zhao <[email protected]>
AuthorDate: Fri Oct 29 17:06:05 2021 +0800
[Optimize] Spark connector supports multiple spark versions:2.1.x/2.3.x/2.4.x/3.x (#6956)
* Spark connector supports multiple spark versions:2.1.x/2.3.x/2.4.x/3.x
Co-authored-by: wei.zhao <[email protected]>
---
.../org/apache/doris/spark/DorisStreamLoad.java | 6 +-
.../doris/spark/rdd/AbstractDorisRDDIterator.scala | 12 +-
.../apache/doris/spark/rdd/ScalaValueReader.scala | 2 +-
.../doris/spark/sql/DorisSourceProvider.scala | 26 +++--
.../doris/spark/sql/DorisStreamLoadSink.scala | 98 +++++++++++++++++
.../apache/doris/spark/sql/DorisStreamWriter.scala | 122 ---------------------
.../doris/spark/sql/ScalaDorisRowValueReader.scala | 10 +-
.../scala/org/apache/doris/spark/sql/Utils.scala | 38 +++++--
8 files changed, 158 insertions(+), 156 deletions(-)
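
Not part of the patch, for orientation only: a minimal structured-streaming sketch of the Sink-based write path this change switches to. It assumes an existing SparkSession named `spark`; the "doris" format name and the option keys shown are illustrative assumptions, the authoritative names live in DorisSourceProvider.SHORT_NAME and ConfigurationOptions.

  // Minimal sketch: writing a streaming DataFrame to Doris via the new Sink path.
  val query = spark.readStream
    .format("rate")                                  // any streaming source; "rate" used for brevity
    .load()
    .selectExpr("CAST(value AS STRING) AS value")
    .writeStream
    .format("doris")                                 // resolved via DataSourceRegister / StreamSinkProvider
    .option("doris.fenodes", "fe_host:8030")         // assumed option key
    .option("doris.table.identifier", "db.table")    // assumed option key
    .option("user", "root")                          // assumed option key
    .option("password", "")                          // assumed option key
    .option("checkpointLocation", "/tmp/doris-ckpt")
    .start()
  query.awaitTermination()
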
diff --git a/extension/spark-doris-connector/src/main/java/org/apache/doris/spark/DorisStreamLoad.java b/extension/spark-doris-connector/src/main/java/org/apache/doris/spark/DorisStreamLoad.java
index dcf569f..ccf3a5e 100644
--- a/extension/spark-doris-connector/src/main/java/org/apache/doris/spark/DorisStreamLoad.java
+++ b/extension/spark-doris-connector/src/main/java/org/apache/doris/spark/DorisStreamLoad.java
@@ -150,6 +150,7 @@ public class DorisStreamLoad implements Serializable{
}
public void load(String value) throws StreamLoadException {
+ LOG.debug("Streamload Request:{} ,Body:{}", loadUrlStr, value);
LoadResponse loadResponse = loadBatch(value);
LOG.info("Streamload Response:{}",loadResponse);
if(loadResponse.status != 200){
@@ -169,7 +170,7 @@ public class DorisStreamLoad implements Serializable{
private LoadResponse loadBatch(String value) {
Calendar calendar = Calendar.getInstance();
- String label = String.format("audit_%s%02d%02d_%02d%02d%02d_%s",
+ String label = String.format("spark_streamload_%s%02d%02d_%02d%02d%02d_%s",
calendar.get(Calendar.YEAR), calendar.get(Calendar.MONTH) + 1, calendar.get(Calendar.DAY_OF_MONTH),
calendar.get(Calendar.HOUR_OF_DAY), calendar.get(Calendar.MINUTE), calendar.get(Calendar.SECOND),
UUID.randomUUID().toString().replaceAll("-", ""));
@@ -194,12 +195,11 @@ public class DorisStreamLoad implements Serializable{
while ((line = br.readLine()) != null) {
response.append(line);
}
-// log.info("AuditLoader plugin load with label: {}, response code: {}, msg: {}, content: {}",label, status, respMsg, response.toString());
return new LoadResponse(status, respMsg, response.toString());
} catch (Exception e) {
e.printStackTrace();
- String err = "failed to load audit via AuditLoader plugin with label: " + label;
+ String err = "failed to execute spark streamload with label: " + label;
LOG.warn(err, e);
return new LoadResponse(-1, e.getMessage(), err);
} finally {
diff --git a/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractDorisRDDIterator.scala b/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractDorisRDDIterator.scala
index dc39773..5b2b36f 100644
--- a/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractDorisRDDIterator.scala
+++ b/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractDorisRDDIterator.scala
@@ -20,15 +20,15 @@ package org.apache.doris.spark.rdd
import org.apache.doris.spark.cfg.ConfigurationOptions.DORIS_VALUE_READER_CLASS
import org.apache.doris.spark.cfg.Settings
import org.apache.doris.spark.rest.PartitionDefinition
-
import org.apache.spark.util.TaskCompletionListener
-import org.apache.spark.internal.Logging
import org.apache.spark.{TaskContext, TaskKilledException}
+import org.slf4j.{Logger, LoggerFactory}
private[spark] abstract class AbstractDorisRDDIterator[T](
context: TaskContext,
- partition: PartitionDefinition) extends Iterator[T] with Logging {
+ partition: PartitionDefinition) extends Iterator[T] {
+ private val logger: Logger = LoggerFactory.getLogger(this.getClass.getName.stripSuffix("$"))
private var initialized = false
private var closed = false
@@ -38,7 +38,7 @@ private[spark] abstract class AbstractDorisRDDIterator[T](
val settings = partition.settings()
initReader(settings)
val valueReaderName = settings.getProperty(DORIS_VALUE_READER_CLASS)
- logDebug(s"Use value reader '$valueReaderName'.")
+ logger.debug(s"Use value reader '$valueReaderName'.")
val cons = Class.forName(valueReaderName).getDeclaredConstructor(classOf[PartitionDefinition], classOf[Settings])
cons.newInstance(partition, settings).asInstanceOf[ScalaValueReader]
}
@@ -65,7 +65,7 @@ private[spark] abstract class AbstractDorisRDDIterator[T](
}
def closeIfNeeded(): Unit = {
- logTrace(s"Close status is '$closed' when close Doris RDD Iterator")
+ logger.trace(s"Close status is '$closed' when close Doris RDD Iterator")
if (!closed) {
close()
closed = true
@@ -73,7 +73,7 @@ private[spark] abstract class AbstractDorisRDDIterator[T](
}
protected def close(): Unit = {
- logTrace(s"Initialize status is '$initialized' when close Doris RDD
Iterator")
+ logger.trace(s"Initialize status is '$initialized' when close Doris RDD
Iterator")
if (initialized) {
reader.close()
}
diff --git a/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala b/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala
index a1b26e4..03643b2 100644
--- a/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala
+++ b/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala
@@ -44,7 +44,7 @@ import scala.util.control.Breaks
* @param settings request configuration
*/
class ScalaValueReader(partition: PartitionDefinition, settings: Settings) {
- protected val logger = Logger.getLogger(classOf[ScalaValueReader])
+ private val logger = Logger.getLogger(classOf[ScalaValueReader])
protected val client = new BackendClient(new Routing(partition.getBeAddress), settings)
protected val clientLock =
diff --git a/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala b/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
index 3ac087d..ee77ce6 100644
--- a/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
+++ b/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
@@ -21,25 +21,29 @@ import org.apache.doris.spark.DorisStreamLoad
import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
import org.apache.doris.spark.sql.DorisSourceProvider.SHORT_NAME
import org.apache.spark.SparkConf
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
+import org.apache.spark.sql.execution.streaming.Sink
+import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
-import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, Filter, RelationProvider}
+import org.slf4j.{Logger, LoggerFactory}
import java.io.IOException
import java.util
-import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.JavaConverters.mapAsJavaMapConverter
import scala.util.control.Breaks
-private[sql] class DorisSourceProvider extends DataSourceRegister with RelationProvider with CreatableRelationProvider with StreamWriteSupport with Logging {
+private[sql] class DorisSourceProvider extends DataSourceRegister
+ with RelationProvider
+ with CreatableRelationProvider
+ with StreamSinkProvider {
+
+ private val logger: Logger = LoggerFactory.getLogger(classOf[DorisSourceProvider].getName)
+
override def shortName(): String = SHORT_NAME
override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
- new DorisRelation(sqlContext, Utils.params(parameters, log))
+ new DorisRelation(sqlContext, Utils.params(parameters, logger))
}
@@ -51,7 +55,7 @@ private[sql] class DorisSourceProvider extends DataSourceRegister with RelationP
data: DataFrame): BaseRelation = {
val sparkSettings = new SparkSettings(sqlContext.sparkContext.getConf)
- sparkSettings.merge(Utils.params(parameters, log).asJava)
+ sparkSettings.merge(Utils.params(parameters, logger).asJava)
// init stream loader
val dorisStreamLoader = new DorisStreamLoad(sparkSettings)
@@ -124,10 +128,10 @@ private[sql] class DorisSourceProvider extends DataSourceRegister with RelationP
}
}
- override def createStreamWriter(queryId: String, structType: StructType, outputMode: OutputMode, dataSourceOptions: DataSourceOptions): StreamWriter = {
+ override def createSink(sqlContext: SQLContext, parameters: Map[String, String], partitionColumns: Seq[String], outputMode: OutputMode): Sink = {
val sparkSettings = new SparkSettings(new SparkConf())
- sparkSettings.merge(Utils.params(dataSourceOptions.asMap().toMap, log).asJava)
- new DorisStreamWriter(sparkSettings)
+ sparkSettings.merge(Utils.params(parameters, logger).asJava)
+ new DorisStreamLoadSink(sqlContext, sparkSettings)
}
}
diff --git a/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala b/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala
new file mode 100644
index 0000000..409325d
--- /dev/null
+++ b/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala
@@ -0,0 +1,98 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.sql
+
+import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
+import org.apache.doris.spark.{CachedDorisStreamLoadClient, DorisStreamLoad}
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.streaming.Sink
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.slf4j.{Logger, LoggerFactory}
+
+import java.io.IOException
+import java.util
+import scala.util.control.Breaks
+
+private[sql] class DorisStreamLoadSink(sqlContext: SQLContext, settings: SparkSettings) extends Sink with Serializable {
+
+ private val logger: Logger = LoggerFactory.getLogger(classOf[DorisStreamLoadSink].getName)
+ @volatile private var latestBatchId = -1L
+ val maxRowCount: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_BATCH_SIZE, ConfigurationOptions.DORIS_BATCH_SIZE_DEFAULT)
+ val maxRetryTimes: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_REQUEST_RETRIES, ConfigurationOptions.DORIS_REQUEST_RETRIES_DEFAULT)
+ val dorisStreamLoader: DorisStreamLoad = CachedDorisStreamLoadClient.getOrCreate(settings)
+
+ override def addBatch(batchId: Long, data: DataFrame): Unit = {
+ if (batchId <= latestBatchId) {
+ logger.info(s"Skipping already committed batch $batchId")
+ } else {
+ write(data.queryExecution)
+ latestBatchId = batchId
+ }
+ }
+
+ def write(queryExecution: QueryExecution): Unit = {
+ queryExecution.toRdd.foreachPartition(iter => {
+ val rowsBuffer: util.List[util.List[Object]] = new util.ArrayList[util.List[Object]]()
+ iter.foreach(row => {
+ val line: util.List[Object] = new util.ArrayList[Object](maxRowCount)
+ for (i <- 0 until row.numFields) {
+ val field = row.copy().getUTF8String(i)
+ line.add(field.asInstanceOf[AnyRef])
+ }
+ rowsBuffer.add(line)
+ if (rowsBuffer.size > maxRowCount - 1) {
+ flush
+ }
+ })
+ // flush buffer
+ if (!rowsBuffer.isEmpty) {
+ flush
+ }
+
+ /**
+ * flush data to Doris and do retry when flush error
+ *
+ */
+ def flush = {
+ val loop = new Breaks
+ loop.breakable {
+
+ for (i <- 1 to maxRetryTimes) {
+ try {
+ dorisStreamLoader.load(rowsBuffer)
+ rowsBuffer.clear()
+ loop.break()
+ }
+ catch {
+ case e: Exception =>
+ try {
+ Thread.sleep(1000 * i)
+ } catch {
+ case ex: InterruptedException =>
+ Thread.currentThread.interrupt()
+ throw new IOException("unable to flush; interrupted while doing another attempt", e)
+ }
+ }
+ }
+ }
+ }
+ })
+ }
+
+ override def toString: String = "DorisStreamLoadSink"
+}
diff --git a/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamWriter.scala b/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamWriter.scala
deleted file mode 100644
index 60d2c78..0000000
--- a/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamWriter.scala
+++ /dev/null
@@ -1,122 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.spark.sql
-
-import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
-import org.apache.doris.spark.{CachedDorisStreamLoadClient, DorisStreamLoad}
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
-import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
-
-import java.io.IOException
-import java.util
-import scala.util.control.Breaks
-
-/**
- * A [[StreamWriter]] for Apache Doris streaming writing.
- *
- * @param settings params for writing doris table
- */
-class DorisStreamWriter(settings: SparkSettings) extends StreamWriter {
- override def createWriterFactory(): DorisStreamWriterFactory = DorisStreamWriterFactory(settings)
-
- override def commit(l: Long, writerCommitMessages: Array[WriterCommitMessage]): Unit = {}
-
- override def abort(l: Long, writerCommitMessages: Array[WriterCommitMessage]): Unit = {}
-
-}
-
-/**
- * A [[DataWriterFactory]] for Apache Doris streaming writing. Will be serialized and sent to executors to generate
- * the per-task data writers.
- *
- * @param settings params for writing doris table
- */
-case class DorisStreamWriterFactory(settings: SparkSettings) extends DataWriterFactory[Row] {
- override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = {
- new DorisStreamDataWriter(settings)
- }
-}
-
-/**
- * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we
- * don't need to really send one.
- */
-case object DorisWriterCommitMessage extends WriterCommitMessage
-
-/**
- * A [[DataWriter]] for Apache Doris streaming writing. One data writer will be created in each partition to
- * process incoming rows.
- *
- * @param settings params for writing doris table
- */
-class DorisStreamDataWriter(settings: SparkSettings) extends DataWriter[Row] {
- val maxRowCount: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_BATCH_SIZE, ConfigurationOptions.DORIS_BATCH_SIZE_DEFAULT)
- val maxRetryTimes: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_REQUEST_RETRIES, ConfigurationOptions.DORIS_REQUEST_RETRIES_DEFAULT)
- val dorisStreamLoader: DorisStreamLoad = CachedDorisStreamLoadClient.getOrCreate(settings)
- val rowsBuffer: util.List[util.List[Object]] = new util.ArrayList[util.List[Object]](maxRowCount)
-
- override def write(row: Row): Unit = {
- val line: util.List[Object] = new util.ArrayList[Object]()
- for (i <- 0 until row.size) {
- val field = row.get(i)
- line.add(field.asInstanceOf[AnyRef])
- }
- if (!line.isEmpty) {
- rowsBuffer.add(line)
- }
- if (rowsBuffer.size >= maxRowCount) {
- // commit when buffer is full
- commit()
- }
- }
-
- override def commit(): WriterCommitMessage = {
- // we don't commit request until rows-buffer received some rows
- val loop = new Breaks
- loop.breakable {
- for (i <- 1 to maxRetryTimes) {
- try {
- if (!rowsBuffer.isEmpty) {
- dorisStreamLoader.load(rowsBuffer)
- }
- rowsBuffer.clear()
- loop.break()
- }
- catch {
- case e: Exception =>
- try {
- Thread.sleep(1000 * i)
- if (!rowsBuffer.isEmpty) {
- dorisStreamLoader.load(rowsBuffer)
- }
- rowsBuffer.clear()
- } catch {
- case ex: InterruptedException =>
- Thread.currentThread.interrupt()
- throw new IOException("unable to flush; interrupted while doing another attempt", e)
- }
- }
- }
- }
- DorisWriterCommitMessage
- }
-
- override def abort(): Unit = {
- }
-}
\ No newline at end of file
diff --git a/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowValueReader.scala b/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowValueReader.scala
index 7825fcf..5b01854 100644
--- a/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowValueReader.scala
+++ b/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowValueReader.scala
@@ -18,26 +18,26 @@
package org.apache.doris.spark.sql
import scala.collection.JavaConverters._
-
import org.apache.doris.spark.cfg.ConfigurationOptions.DORIS_READ_FIELD
import org.apache.doris.spark.cfg.Settings
import org.apache.doris.spark.exception.ShouldNeverHappenException
import org.apache.doris.spark.rdd.ScalaValueReader
import org.apache.doris.spark.rest.PartitionDefinition
import org.apache.doris.spark.util.ErrorMessages.SHOULD_NOT_HAPPEN_MESSAGE
-
-import org.apache.spark.internal.Logging
+import org.slf4j.{Logger, LoggerFactory}
class ScalaDorisRowValueReader(
partition: PartitionDefinition,
settings: Settings)
- extends ScalaValueReader(partition, settings) with Logging {
+ extends ScalaValueReader(partition, settings) {
+
+ private val logger: Logger = LoggerFactory.getLogger(classOf[ScalaDorisRowValueReader].getName)
val rowOrder: Seq[String] = settings.getProperty(DORIS_READ_FIELD).split(",")
override def next: AnyRef = {
if (!hasNext) {
- logError(SHOULD_NOT_HAPPEN_MESSAGE)
+ logger.error(SHOULD_NOT_HAPPEN_MESSAGE)
throw new ShouldNeverHappenException
}
val row: ScalaDorisRow = new ScalaDorisRow(rowOrder)
diff --git a/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala b/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
index f5b5af1..6b66646 100644
--- a/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
+++ b/extension/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
@@ -17,14 +17,15 @@
package org.apache.doris.spark.sql
+import org.apache.commons.lang3.StringUtils
import org.apache.doris.spark.cfg.ConfigurationOptions
import org.apache.doris.spark.exception.DorisException
-
import org.apache.spark.sql.jdbc.JdbcDialect
import org.apache.spark.sql.sources._
-
import org.slf4j.Logger
+import java.sql.{Date, Timestamp}
+
private[sql] object Utils {
/**
* quote column name
@@ -42,16 +43,16 @@ private[sql] object Utils {
*/
def compileFilter(filter: Filter, dialect: JdbcDialect, inValueLengthLimit: Int): Option[String] = {
Option(filter match {
- case EqualTo(attribute, value) => s"${quote(attribute)} = ${dialect.compileValue(value)}"
- case GreaterThan(attribute, value) => s"${quote(attribute)} > ${dialect.compileValue(value)}"
- case GreaterThanOrEqual(attribute, value) => s"${quote(attribute)} >= ${dialect.compileValue(value)}"
- case LessThan(attribute, value) => s"${quote(attribute)} < ${dialect.compileValue(value)}"
- case LessThanOrEqual(attribute, value) => s"${quote(attribute)} <= ${dialect.compileValue(value)}"
+ case EqualTo(attribute, value) => s"${quote(attribute)} = ${compileValue(value)}"
+ case GreaterThan(attribute, value) => s"${quote(attribute)} > ${compileValue(value)}"
+ case GreaterThanOrEqual(attribute, value) => s"${quote(attribute)} >= ${compileValue(value)}"
+ case LessThan(attribute, value) => s"${quote(attribute)} < ${compileValue(value)}"
+ case LessThanOrEqual(attribute, value) => s"${quote(attribute)} <= ${compileValue(value)}"
case In(attribute, values) =>
if (values.isEmpty || values.length >= inValueLengthLimit) {
null
} else {
- s"${quote(attribute)} in (${dialect.compileValue(values)})"
+ s"${quote(attribute)} in (${compileValue(values)})"
}
case IsNull(attribute) => s"${quote(attribute)} is null"
case IsNotNull(attribute) => s"${quote(attribute)} is not null"
@@ -74,6 +75,27 @@ private[sql] object Utils {
}
/**
+ * Escape special characters in SQL string literals.
+ * @param value The string to be escaped.
+ * @return Escaped string.
+ */
+ private def escapeSql(value: String): String =
+ if (value == null) null else StringUtils.replace(value, "'", "''")
+
+ /**
+ * Converts value to SQL expression.
+ * @param value The value to be converted.
+ * @return Converted value.
+ */
+ private def compileValue(value: Any): Any = value match {
+ case stringValue: String => s"'${escapeSql(stringValue)}'"
+ case timestampValue: Timestamp => "'" + timestampValue + "'"
+ case dateValue: Date => "'" + dateValue + "'"
+ case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
+ case _ => value
+ }
+
+ /**
* check parameters validation and process it.
* @param parameters parameters from rdd and spark conf
* @param logger slf4j logger
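
Not part of the patch, for orientation only: a rough sketch of what the new compileValue helper produces when a filter is pushed down. The calls below are illustrative assumptions (compileValue and escapeSql are private to Utils, and quote() is assumed to backtick the column name).

  // Illustrative expected results, based on the pattern match shown in the diff above:
  // compileValue("O'Brien")                            -> "'O''Brien'"   (escapeSql doubles the single quote)
  // compileValue(java.sql.Date.valueOf("2021-10-29"))  -> "'2021-10-29'"
  // compileValue(Array[Any]("a", "b"))                 -> "'a', 'b'"     (elements joined with ", " for IN clauses)
  // compileValue(42)                                   -> 42             (non-string values pass through unchanged)
  // so EqualTo("name", "O'Brien") would compile to:       `name` = 'O''Brien'
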
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]