ueshin commented on code in PR #44504:
URL: https://github.com/apache/spark/pull/44504#discussion_r1437201436
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala:
##########
@@ -404,15 +404,68 @@ object UserDefinedPythonDataSource {
* The schema of the output to the Python data source write function.
*/
val writeOutputSchema: StructType = new StructType().add("message", BinaryType)
+
+ /**
+ * (Driver-side) Look up all available Python Data Sources.
+ */
+ def lookupAllDataSourcesInPython(): PythonAllDataSourcesCreationResult = {
+ new UserDefinedPythonDataSourceLookupRunner(
+ PythonUtils.createPythonFunction(Array.empty[Byte])).runInPython()
+ }
}
+/**
+ * All Data Sources in Python
+ */
+case class PythonAllDataSourcesCreationResult(
+ names: Array[String], dataSources: Array[Array[Byte]])
+
/**
* Used to store the result of creating a Python data source in the Python process.
*/
case class PythonDataSourceCreationResult(
dataSource: Array[Byte],
schema: StructType)
+/**
+ * A runner used to look up Python Data Sources available in Python path.
+ */
+class UserDefinedPythonDataSourceLookupRunner(lookupSources: PythonFunction)
+ extends PythonPlannerRunner[PythonAllDataSourcesCreationResult](lookupSources) {
+
+ override val workerModule = "pyspark.sql.worker.lookup_data_sources"
+
+ override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
+ PythonWorkerUtils.writePythonFunction(lookupSources, dataOut)
Review Comment:
We don't need to send anything here, do we? There doesn't seem to be a corresponding read in the Python worker.
##########
core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala:
##########
@@ -145,4 +149,48 @@ private[spark] object PythonUtils extends Logging {
listOfPackages.foreach(x => logInfo(s"List of Python packages :- ${formatOutput(x)}"))
}
}
+
+ // Only for testing.
+ private[spark] var additionalTestingPath: Option[String] = None
+
+ private[spark] def createPythonFunction(command: Array[Byte]): SimplePythonFunction = {
+ val pythonExec: String = sys.env.getOrElse(
+ "PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python3"))
+
+ val sourcePython = if (Utils.isTesting) {
+ // Put PySpark source code instead of the build zip archive so we don't need
+ // to build PySpark every time during development.
+ val sparkHome: String = {
+ require(
+ sys.props.contains("spark.test.home") || sys.env.contains("SPARK_HOME"),
+ "spark.test.home or SPARK_HOME is not set.")
+ sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
+ }
+ val sourcePath = Paths.get(sparkHome, "python").toAbsolutePath
+ val py4jPath = Paths.get(
+ sparkHome, "python", "lib", PythonUtils.PY4J_ZIP_NAME).toAbsolutePath
Review Comment:
Do we need the Py4J path? The Python functions aren't supposed to use Py4J, are they?
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala:
##########
@@ -40,6 +44,52 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
| yield (1, partition.value)
| yield (2, partition.value)
|""".stripMargin
+ private val staticSourceName = "custom_source"
+ private var tempDir: File = _
+
+ override def beforeAll(): Unit = {
+ // Create a Python Data Source package before starting up the Spark Session
+ // that triggers automatic registration of the Python Data Source.
+ val dataSourceScript =
+ s"""
+ |from pyspark.sql.datasource import DataSource, DataSourceReader
+ |$simpleDataSourceReaderScript
+ |
+ |class DefaultSource(DataSource):
+ | def schema(self) -> str:
+ | return "id INT, partition INT"
+ |
+ | def reader(self, schema):
+ | return SimpleDataSourceReader()
+ |
+ | @classmethod
+ | def name(cls):
+ | return "$staticSourceName"
+ |""".stripMargin
+ tempDir = Utils.createTempDir()
+ // Write a temporary package to test.
+ // tmp/my_source
+ // tmp/my_source/__init__.py
+ val packageDir = new File(tempDir, "pyspark_mysource")
+ assert(packageDir.mkdir())
+ Utils.tryWithResource(
+ new FileWriter(new File(packageDir, "__init__.py")))(_.write(dataSourceScript))
+ // So Spark Session initialization can lookup this temporary directory.
+ PythonUtils.additionalTestingPath = Some(tempDir.toString)
+ super.beforeAll()
+ }
+
+ override def afterAll(): Unit = {
+ Utils.deleteRecursively(tempDir)
+ PythonUtils.additionalTestingPath = None
+ super.afterAll()
Review Comment:
Just in case, wrap the cleanup so that `super.afterAll()` still runs even if it throws:
```suggestion
try {
Utils.deleteRecursively(tempDir)
PythonUtils.additionalTestingPath = None
} finally {
super.afterAll()
}
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]