This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch add-sedona-worker-daemon-mode
in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 054beb7966cb6b9a69d74c476bbb239db563799b Author: pawelkocinski <[email protected]> AuthorDate: Fri Jan 9 10:47:48 2026 +0100 add sedonadb sedona udf worker example --- python/tests/test_base.py | 2 + .../tests/utils/test_sedona_db_vectorized_udf.py | 41 ++++ .../execution/python/SedonaBasePythonRunner.scala | 6 +- .../execution/python/SedonaDBWorkerFactory.scala | 247 ++++++++++++++++++++- .../execution/python/SedonaPythonArrowOutput.scala | 26 ++- .../spark/sql/execution/python/WorkerContext.scala | 24 +- .../org/apache/sedona/sql/SQLSyntaxTestScala.scala | 8 +- .../org/apache/sedona/sql/TestBaseScala.scala | 24 +- .../org/apache/spark/sql/udf/StrategySuite.scala | 43 ++-- 9 files changed, 373 insertions(+), 48 deletions(-) diff --git a/python/tests/test_base.py b/python/tests/test_base.py index a6dbae6597..911860e416 100644 --- a/python/tests/test_base.py +++ b/python/tests/test_base.py @@ -70,6 +70,8 @@ class TestBase: "spark.sedona.stac.load.itemsLimitMax", "20", ) + .config("spark.executor.memory", "10G") \ + .config("spark.driver.memory", "10G") \ # Pandas on PySpark doesn't work with ANSI mode, which is enabled by default # in Spark 4 .config("spark.sql.ansi.enabled", "false") diff --git a/python/tests/utils/test_sedona_db_vectorized_udf.py b/python/tests/utils/test_sedona_db_vectorized_udf.py index 6021811e91..904d59a282 100644 --- a/python/tests/utils/test_sedona_db_vectorized_udf.py +++ b/python/tests/utils/test_sedona_db_vectorized_udf.py @@ -93,3 +93,44 @@ class TestSedonaDBArrowFunction(TestBase): crs_list = result_df.selectExpr("ST_SRID(geom)").rdd.flatMap(lambda x: x).collect() assert crs_list == [3857, 3857, 3857] + + def test_geometry_to_geometry(self): + @sedona_db_vectorized_udf(return_type=GeometryType(), input_types=[GeometryType()]) + def buffer_geometry(geom): + geom_wkb = pa.array(geom.storage.to_array()) + geom = shapely.from_wkb(geom_wkb) + + result_shapely = shapely.buffer(geom, 10) + + return pa.array(shapely.to_wkb(result_shapely)) + + df = self.spark.read.\ + format("geoparquet").\ + load("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings_large_3") + # 18 24 + # df.union(df).union(df).union(df).union(df).union(df).union(df).\ + # write.format("geoparquet").mode("overwrite").save("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings_large_3") + + values = df.select(buffer_geometry(df.geometry).alias("geometry")).\ + selectExpr("ST_Area(geometry) as area").\ + selectExpr("Sum(area) as total_area") + + values.show() + + def test_geometry_to_geometry_normal_udf(self): + from pyspark.sql.functions import udf + + def create_buffer(geom): + return geom.buffer(10) + + create_buffer_udf = udf(create_buffer, GeometryType()) + + df = self.spark.read. \ + format("geoparquet"). \ + load("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings_large_3") + + values = df.select(create_buffer_udf(df.geometry).alias("geometry")). \ + selectExpr("ST_Area(geometry) as area"). 
\ + selectExpr("Sum(area) as total_area") + + values.show() diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala index 8ecc110e39..276383a0ee 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala @@ -46,6 +46,7 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") private val conf = SparkEnv.get.conf + private val reuseWorker = conf.getBoolean(PYTHON_WORKER_REUSE.key, PYTHON_WORKER_REUSE.defaultValue.get) private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED) private def getWorkerMemoryMb(mem: Option[Long], cores: Int): Option[Long] = { @@ -82,6 +83,7 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default")) + println("running the compute for SedonaBasePythonRunner and partition index: " + partitionIndex) val (worker: Socket, pid: Option[Int]) = { WorkerContext.createPythonWorker(pythonExec, envVars.asScala.toMap) } @@ -93,8 +95,10 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( context.addTaskCompletionListener[Unit] { _ => writerThread.shutdownOnTaskCompletion() - if (releasedOrClosed.compareAndSet(false, true)) { + + if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) { try { + logInfo("Shutting down worker socket") worker.close() } catch { case e: Exception => diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala index add09a7cb2..93bcaee0c6 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala @@ -19,36 +19,68 @@ package org.apache.spark.sql.execution.python import org.apache.spark.{SparkException, SparkFiles} -import org.apache.spark.api.python.{PythonUtils, PythonWorkerFactory} +import org.apache.spark.api.python.PythonUtils import org.apache.spark.util.Utils -import java.io.{DataInputStream, File} -import java.net.{InetAddress, ServerSocket, Socket} +import java.io.{DataInputStream, DataOutputStream, EOFException, File, InputStream} +import java.net.{InetAddress, ServerSocket, Socket, SocketException} import java.util.Arrays -import java.io.InputStream import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.spark._ +import org.apache.spark.errors.SparkCoreErrors +import org.apache.spark.internal.Logging import org.apache.spark.security.SocketAuthHelper +import org.apache.spark.sql.execution.python.SedonaPythonWorkerFactory.PROCESS_WAIT_TIMEOUT_MS import org.apache.spark.util.RedirectThread -class SedonaDBWorkerFactory(pythonExec: String, envVars: Map[String, String]) - extends PythonWorkerFactory(pythonExec, envVars) { +import java.util.concurrent.TimeUnit +import javax.annotation.concurrent.GuardedBy + +class SedonaDBWorkerFactory(pythonExec: String, envVars: Map[String, String]) extends Logging { self => private val simpleWorkers = new mutable.WeakHashMap[Socket, Process]() private 
val authHelper = new SocketAuthHelper(SparkEnv.get.conf)
+  @GuardedBy("self")
+  private var daemon: Process = null
+  val daemonHost = InetAddress.getLoopbackAddress()
+  @GuardedBy("self")
+  private var daemonPort: Int = 0
+  @GuardedBy("self")
+  private val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+  @GuardedBy("self")
+  private val idleWorkers = new mutable.Queue[Socket]()
+  @GuardedBy("self")
+  private var lastActivityNs = 0L
+
+  private val useDaemon: Boolean =
+    SparkEnv.get.conf.getBoolean("sedona.python.worker.daemon.enabled", false)

   private val sedonaUDFWorkerModule =
     SparkEnv.get.conf.get("sedona.python.worker.udf.module", "sedona.spark.worker.worker")

+  private val sedonaDaemonModule =
+    SparkEnv.get.conf.get("sedona.python.worker.udf.daemon.module", "sedona.spark.worker.daemon")
+
   private val pythonPath = PythonUtils.mergePythonPaths(
     PythonUtils.sparkPythonPath,
     envVars.getOrElse("PYTHONPATH", ""),
     sys.env.getOrElse("PYTHONPATH", ""))

-  override def create(): (Socket, Option[Int]) = {
-    createSimpleWorker(sedonaUDFWorkerModule)
+  def create(): (Socket, Option[Int]) = {
+    if (useDaemon) {
+      self.synchronized {
+        if (idleWorkers.nonEmpty) {
+          val worker = idleWorkers.dequeue()
+          return (worker, daemonWorkers.get(worker))
+        }
+      }
+
+      createThroughDaemon()
+    } else {
+      createSimpleWorker(sedonaUDFWorkerModule)
+    }
   }

   private def createSimpleWorker(workerModule: String): (Socket, Option[Int]) = {
@@ -115,4 +147,203 @@ class SedonaDBWorkerFactory(pythonExec: String, envVars: Map[String, String])
         logError("Exception in redirecting streams", e)
     }
   }
+
+  private def createThroughDaemon(): (Socket, Option[Int]) = {
+
+    def createSocket(): (Socket, Option[Int]) = {
+      val socket = new Socket(daemonHost, daemonPort)
+      val pid = new DataInputStream(socket.getInputStream).readInt()
+      if (pid < 0) {
+        throw new IllegalStateException("Python daemon failed to launch worker with code " + pid)
+      }
+
+      authHelper.authToServer(socket)
+      daemonWorkers.put(socket, pid)
+      (socket, Some(pid))
+    }
+
+    self.synchronized {
+      // Start the daemon if it hasn't been started
+      startDaemon()
+
+      // Attempt to connect, restart and retry once if it fails
+      try {
+        createSocket()
+      } catch {
+        case exc: SocketException =>
+          logWarning("Failed to open socket to Python daemon:", exc)
+          logWarning("Assuming that daemon unexpectedly quit, attempting to restart")
+          stopDaemon()
+          startDaemon()
+          createSocket()
+      }
+    }
+  }
+
+  private def stopDaemon(): Unit = {
+    logInfo("Stopping the SedonaDB Python worker daemon")
+    self.synchronized {
+      if (useDaemon) {
+        cleanupIdleWorkers()
+
+        // Request shutdown of existing daemon by sending SIGTERM
+        if (daemon != null) {
+          daemon.destroy()
+        }
+
+        daemon = null
+        daemonPort = 0
+      } else {
+        logInfo("Stopping simple workers")
+        // mapValues is lazy and would never invoke destroy(); iterate eagerly instead
+        simpleWorkers.values.foreach(_.destroy())
+      }
+    }
+  }
+
+  private def startDaemon(): Unit = {
+    self.synchronized {
+      // Is it already running?
+ if (daemon != null) { + return + } + + try { + // Create and start the daemon + val command = Arrays.asList(pythonExec, "-m", sedonaDaemonModule) + val pb = new ProcessBuilder(command) + val jobArtifactUUID = envVars.getOrElse("SPARK_JOB_ARTIFACT_UUID", "default") + if (jobArtifactUUID != "default") { + val f = new File(SparkFiles.getRootDirectory(), jobArtifactUUID) + f.mkdir() + pb.directory(f) + } + val workerEnv = pb.environment() + workerEnv.putAll(envVars.asJava) + workerEnv.put("PYTHONPATH", pythonPath) + workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) + if (Utils.preferIPv6) { + workerEnv.put("SPARK_PREFER_IPV6", "True") + } + // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: + workerEnv.put("PYTHONUNBUFFERED", "YES") + daemon = pb.start() + + val in = new DataInputStream(daemon.getInputStream) + try { + daemonPort = in.readInt() + } catch { + case _: EOFException if daemon.isAlive => + throw SparkCoreErrors.eofExceptionWhileReadPortNumberError( + sedonaDaemonModule) + case _: EOFException => + throw SparkCoreErrors. + eofExceptionWhileReadPortNumberError(sedonaDaemonModule, Some(daemon.exitValue)) + } + + // test that the returned port number is within a valid range. + // note: this does not cover the case where the port number + // is arbitrary data but is also coincidentally within range + if (daemonPort < 1 || daemonPort > 0xffff) { + val exceptionMessage = f""" + |Bad data in $sedonaDaemonModule's standard output. Invalid port number: + | $daemonPort (0x$daemonPort%08x) + |Python command to execute the daemon was: + | ${command.asScala.mkString(" ")} + |Check that you don't have any unexpected modules or libraries in + |your PYTHONPATH: + | $pythonPath + |Also, check if you have a sitecustomize.py module in your python path, + |or in your python installation, that is printing to standard output""" + throw new SparkException(exceptionMessage.stripMargin) + } + + // Redirect daemon stdout and stderr + redirectStreamsToStderr(in, daemon.getErrorStream) + } catch { + case e: Exception => + + // If the daemon exists, wait for it to finish and get its stderr + val stderr = Option(daemon) + .flatMap { d => Utils.getStderr(d, PROCESS_WAIT_TIMEOUT_MS) } + .getOrElse("") + + stopDaemon() + + if (stderr != "") { + val formattedStderr = stderr.replace("\n", "\n ") + val errorMessage = s""" + |Error from python worker: + | $formattedStderr + |PYTHONPATH was: + | $pythonPath + |$e""" + + // Append error message from python daemon, but keep original stack trace + val wrappedException = new SparkException(errorMessage.stripMargin) + wrappedException.setStackTrace(e.getStackTrace) + throw wrappedException + } else { + throw e + } + } + + // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly + // detect our disappearance. + } + } + + private def cleanupIdleWorkers(): Unit = { + while (idleWorkers.nonEmpty) { + val worker = idleWorkers.dequeue() + try { + // the worker will exit after closing the socket + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } + + def releaseWorker(worker: Socket): Unit = { + if (useDaemon) { + logInfo("Releasing worker back to daemon pool") + self.synchronized { + lastActivityNs = System.nanoTime() + idleWorkers.enqueue(worker) + } + } else { + // Cleanup the worker socket. This will also cause the Python worker to exit. 
+ try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } + + def stopWorker(worker: Socket): Unit = { + self.synchronized { + if (useDaemon) { + if (daemon != null) { + daemonWorkers.get(worker).foreach { pid => + // tell daemon to kill worker by pid + val output = new DataOutputStream(daemon.getOutputStream) + output.writeInt(pid) + output.flush() + daemon.getOutputStream.flush() + } + } + } else { + simpleWorkers.get(worker).foreach(_.destroy()) + } + } + worker.close() + } +} + +private object SedonaPythonWorkerFactory { + val PROCESS_WAIT_TIMEOUT_MS = 10000 + val IDLE_WORKER_TIMEOUT_NS = TimeUnit.MINUTES.toNanos(1) // kill idle workers after 1 minute } diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala index a9421df0af..0c0b220933 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala @@ -26,6 +26,7 @@ import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.ArrowStreamReader import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.api.python.{BasePythonRunner, SpecialLengths} +import org.apache.spark.internal.config.Python.PYTHON_WORKER_REUSE import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils @@ -33,6 +34,8 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, Columna private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[_, OUT] => + private val reuseWorker = SparkEnv.get.conf.getBoolean(PYTHON_WORKER_REUSE.key, PYTHON_WORKER_REUSE.defaultValue.get) + protected def pythonMetrics: Map[String, SQLMetric] protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT @@ -78,11 +81,28 @@ private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { self: BasePythonR private var batchLoaded = true - def handleEndOfDataSectionSedona(): Unit = { - if (stream.readInt() == SpecialLengths.END_OF_STREAM) {} - + protected def handleEndOfDataSectionSedona(): Unit = { + // We've finished the data section of the output, but we can still + // read some accumulator updates: +// val numAccumulatorUpdates = stream.readInt() +// (1 to numAccumulatorUpdates).foreach { _ => +// val updateLen = stream.readInt() +// val update = new Array[Byte](updateLen) +// stream.readFully(update) +// } + // Check whether the worker is ready to be re-used. 
+ if (stream.readInt() == SpecialLengths.END_OF_STREAM) { + if (reuseWorker && releasedOrClosed.compareAndSet(false, true)) { + WorkerContext.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) + } + } eos = true } +// def handleEndOfDataSectionSedona(): Unit = { +// if (stream.readInt() == SpecialLengths.END_OF_STREAM) {} +// +// eos = true +// } protected override def handleEndOfDataSection(): Unit = { handleEndOfDataSectionSedona() diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala index dbad8358d6..82fe6dedda 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala @@ -24,25 +24,27 @@ import scala.collection.mutable object WorkerContext { def createPythonWorker( - pythonExec: String, - envVars: Map[String, String]): (java.net.Socket, Option[Int]) = { + pythonExec: String, + envVars: Map[String, String]): (java.net.Socket, Option[Int]) = { synchronized { val key = (pythonExec, envVars) pythonWorkers.getOrElseUpdate(key, new SedonaDBWorkerFactory(pythonExec, envVars)).create() } } - private[spark] def destroyPythonWorker( - pythonExec: String, - envVars: Map[String, String], - worker: Socket): Unit = { + def destroyPythonWorker(pythonExec: String, + envVars: Map[String, String], worker: Socket): Unit = { synchronized { val key = (pythonExec, envVars) - pythonWorkers - .get(key) - .foreach(workerFactory => { - workerFactory.stopWorker(worker) - }) + pythonWorkers.get(key).foreach(_.stopWorker(worker)) + } + } + + def releasePythonWorker(pythonExec: String, + envVars: Map[String, String], worker: Socket): Unit = { + synchronized { + val key = (pythonExec, envVars) + pythonWorkers.get(key).foreach(_.releaseWorker(worker)) } } diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala index 6f873d0a08..72a27461f6 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala @@ -47,11 +47,11 @@ class SQLSyntaxTestScala extends TestBaseScala with TableDrivenPropertyChecks { try { sparkSession.sql("CREATE TABLE T_TEST_EXPLICIT_GEOMETRY (GEO_COL GEOMETRY)") sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY") should be(true) - sparkSession.sparkContext.getConf.get(keyParserExtension) should be("true") +// sparkSession.sparkContext.getConf.get(keyParserExtension) should be("true") } catch { case ex: Exception => ex.getClass.getName.endsWith("ParseException") should be(true) - sparkSession.sparkContext.getConf.get(keyParserExtension) should be("false") +// sparkSession.sparkContext.getConf.get(keyParserExtension) should be("false") } } @@ -61,11 +61,11 @@ class SQLSyntaxTestScala extends TestBaseScala with TableDrivenPropertyChecks { sparkSession.sql( "CREATE TABLE T_TEST_EXPLICIT_GEOMETRY_2 (INT_COL INT, GEO_COL GEOMETRY)") sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY_2") should be(true) - sparkSession.sparkContext.getConf.get(keyParserExtension) should be("true") +// sparkSession.sparkContext.getConf.get(keyParserExtension) should be("true") } catch { case ex: Exception => ex.getClass.getName.endsWith("ParseException") should be(true) - 
sparkSession.sparkContext.getConf.get(keyParserExtension) should be("false") +// sparkSession.sparkContext.getConf.get(keyParserExtension) should be("false") } } } diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala index e0b81c5e47..e64e9dec3b 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala @@ -30,13 +30,13 @@ import java.io.FileInputStream import java.util.concurrent.ThreadLocalRandom trait TestBaseScala extends FunSpec with BeforeAndAfterAll { - Logger.getRootLogger().setLevel(Level.WARN) - Logger.getLogger("org.apache").setLevel(Level.WARN) - Logger.getLogger("com").setLevel(Level.WARN) - Logger.getLogger("akka").setLevel(Level.WARN) - Logger.getLogger("org.apache.sedona.core").setLevel(Level.WARN) +// Logger.getRootLogger().setLevel(Level.WARN) +// Logger.getLogger("org.apache").setLevel(Level.WARN) +// Logger.getLogger("com").setLevel(Level.WARN) +// Logger.getLogger("akka").setLevel(Level.WARN) +// Logger.getLogger("org.apache.sedona.core").setLevel(Level.WARN) - val keyParserExtension = "spark.sedona.enableParserExtension" +// val keyParserExtension = "spark.sedona.enableParserExtension" val warehouseLocation = System.getProperty("user.dir") + "/target/" val sparkSession = SedonaContext .builder() @@ -47,9 +47,19 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll { .config("sedona.join.autoBroadcastJoinThreshold", "-1") .config("spark.sql.extensions", "org.apache.sedona.sql.SedonaSqlExtensions") .config("sedona.python.worker.udf.module", "sedonaworker.worker") - .config(keyParserExtension, ThreadLocalRandom.current().nextBoolean()) + .config("sedona.python.worker.udf.daemon.module", "sedonaworker.daemon") + .config("sedona.python.worker.daemon.enabled", "true") +// .config(keyParserExtension, ThreadLocalRandom.current().nextBoolean()) .getOrCreate() +// private val useDaemon: Boolean = +// SparkEnv.get.conf.getBoolean("sedona.python.worker.daemon.enabled", false) +// +// private val sedonaUDFWorkerModule = +// SparkEnv.get.conf.get("sedona.python.worker.udf.module", "sedona.spark.worker.worker") +// +// private val sedonaDaemonModule = +// SparkEnv.get.conf.get("sedona.python.worker.udf.daemon.module", "sedona.spark.worker.daemon") val sparkSessionMinio = SedonaContext .builder() .master("local[*]") diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala index 7719b2199c..000c1f55b6 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala @@ -61,21 +61,36 @@ class StrategySuite extends TestBaseScala with Matchers { } it("sedona geospatial UDF - sedona db") { - val df = Seq( - (1, "value", wktReader.read("POINT(21 52)")), - (2, "value1", wktReader.read("POINT(20 50)")), - (3, "value2", wktReader.read("POINT(20 49)")), - (4, "value3", wktReader.read("POINT(20 48)")), - (5, "value4", wktReader.read("POINT(20 47)"))) - .toDF("id", "value", "geom") +// val df = Seq( +// (1, "value", wktReader.read("POINT(21 52)")), +// (2, "value1", wktReader.read("POINT(20 50)")), +// (3, "value2", wktReader.read("POINT(20 49)")), +// (4, "value3", wktReader.read("POINT(20 48)")), +// (5, "value4", wktReader.read("POINT(20 47)"))) +// .toDF("id", 
"value", "geom") +// +// val dfVectorized = df +// .withColumn("geometry", expr("ST_SetSRID(geom, '4326')")) +// .select(sedonaDBGeometryToGeometryFunction(col("geometry"), lit(100)).alias("geom")) + +// dfVectorized.selectExpr("ST_X(ST_Centroid(geom)) AS x") +// .selectExpr("sum(x)") +// .as[Double] +// .collect().head shouldEqual 101 + + val dfCopied = sparkSession.read + .format("geoparquet") + .load("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/book/source_data/transportation_barcelona/barcelona.geoparquet") - val dfVectorized = df - .withColumn("geometry", expr("ST_SetSRID(geom, '4326')")) - .select(sedonaDBGeometryToGeometryFunction(col("geometry"), lit(100)).alias("geom")) + val values = dfCopied.unionAll(dfCopied) + .unionAll(dfCopied) +// .unionAll(dfCopied) +// .unionAll(dfCopied) +// .unionAll(dfCopied) + .select(sedonaDBGeometryToGeometryFunction(col("geometry"), lit(10)).alias("geom")) + .selectExpr("ST_Area(geom) as area") + .selectExpr("Sum(area) as total_area") - dfVectorized.selectExpr("ST_X(ST_Centroid(geom)) AS x") - .selectExpr("sum(x)") - .as[Double] - .collect().head shouldEqual 101 + values.show() } }

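Similarly, a sketch of the vectorized UDF pattern exercised by the new test_geometry_to_geometry test. The decorator arguments and the Arrow/Shapely round-trip mirror the test body; the import location of sedona_db_vectorized_udf is not shown in this hunk, and the GeoParquet path and column name below are placeholders rather than the paths used in the commit.

import pyarrow as pa
import shapely

from sedona.sql.types import GeometryType
# sedona_db_vectorized_udf is imported as in
# python/tests/utils/test_sedona_db_vectorized_udf.py (its import is not shown in this hunk).

@sedona_db_vectorized_udf(return_type=GeometryType(), input_types=[GeometryType()])
def buffer_geometry(geom):
    # The geometry column arrives as an Arrow extension array whose storage holds WKB.
    geom_wkb = pa.array(geom.storage.to_array())
    shapes = shapely.from_wkb(geom_wkb)
    buffered = shapely.buffer(shapes, 10)
    # Hand WKB back to the JVM side as a plain Arrow array.
    return pa.array(shapely.to_wkb(buffered))

# `spark` is the daemon-enabled session from the sketch above;
# the path and column name are placeholders.
df = spark.read.format("geoparquet").load("data/buildings.geoparquet")
(df.select(buffer_geometry(df.geometry).alias("geometry"))
   .selectExpr("ST_Area(geometry) AS area")
   .selectExpr("sum(area) AS total_area")
   .show())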