This is an automated email from the ASF dual-hosted git repository.
feiwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kyuubi.git
The following commit(s) were added to refs/heads/master by this push:
new 50910ae6c [KYUUBI #5877] Support Python magic syntax for notebook usage
50910ae6c is described below
commit 50910ae6c4317156f29e53893e9e236d9f6b0f15
Author: Fei Wang <[email protected]>
AuthorDate: Wed Dec 20 19:31:47 2023 -0800
[KYUUBI #5877] Support Python magic syntax for notebook usage
# :mag: Description
## Issue References :link:
Support Python magic syntax, for example:
```
%table
%json
%matplot
```
Refer to:
https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-studio-magics.html
https://github.com/jupyter-incubator/sparkmagic
https://github.com/apache/incubator-livy/blob/master/repl/src/main/resources/fake_shell.py
This pull request fixes #5877
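For example, a statement can mix regular Python with a trailing magic line (mirroring the unit tests below; `x` is an illustrative variable):
```
x = [[1, 'a'], [3, 'b']]
%json x
```
The `%json` and `%table` magics look up the named variable in the interpreter's global dict and return it as an `application/json` or `application/vnd.livy.table.v1+json` payload instead of plain text; `%matplot` returns the current figure as a base64-encoded `image/png`.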
## Describe Your Solution :wrench:
Introduce a `MagicNode` in `execute_python.py` that parses lines starting with `%`
and routes them to the `%table`, `%json`, and `%matplot` handlers, guarded by the
internal config `kyuubi.engine.spark.python.magic.enabled` (default: true).
## Types of changes :bookmark:
- [ ] Bugfix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
## Test Plan :test_tube:
Tested with the following Python code:
```
import matplotlib.pyplot as plt
plt.plot([3,4,5],[6,7,8])
%matplot plt;
```
Screenshot of the notebook output: https://github.com/apache/kyuubi/assets/6757692/9a1176c0-8eb0-4a64-83e4-35e74e33d2f0
Decode the "image/png" payload and save it as a PNG file to verify the plot.
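A minimal decoding sketch (the `statement_output` value below is a hypothetical captured response; the embedded base64 string is just a placeholder 1x1 PNG):
```
import base64
import json

# Hypothetical response captured from the engine for the %matplot statement;
# the real payload carries the base64-encoded PNG under the "image/png" key.
statement_output = (
    '{"image/png": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJ'
    'AAAADUlEQVR4nGNgYGBgAAAABQABh6FO1AAAAABJRU5ErkJggg=="}'
)

payload = json.loads(statement_output)
with open("figure.png", "wb") as f:
    f.write(base64.b64decode(payload["image/png"]))
```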

---
# Checklists
## :memo: Author Self Checklist
- [ ] My code follows the [style guidelines](https://kyuubi.readthedocs.io/en/master/contributing/code/style.html) of this project
- [ ] I have performed a self-review
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] This patch was not authored or co-authored using [Generative Tooling](https://www.apache.org/legal/generative-tooling.html)
## :memo: Committer Pre-Merge Checklist
- [ ] Pull request title is okay.
- [ ] No license issues.
- [ ] Milestone is correctly set.
- [ ] Test coverage is okay.
- [ ] Assignees are selected.
- [ ] Minimum number of approvals reached.
- [ ] No changes are requested.
**Be nice. Be informative.**
Closes #5881 from turboFei/magic_command.
Closes #5877
6f2b193a9 [Fei Wang] ut
877c7d108 [Fei Wang] internal config
012dfe44f [Fei Wang] nit
3e0f324f4 [Fei Wang] except other exceptions
24352d2c6 [Fei Wang] raise execution error
085316111 [Fei Wang] raise ExecutionError instead of execute_reply_error
c058defc5 [Fei Wang] add more ut
4da52153b [Fei Wang] Dumps python object to json at last
35127537b [Fei Wang] add ut for json and table
48735ebd9 [Fei Wang] the data should be Map[String, Object]
3a3ba0a49 [Fei Wang] return other data fields
54d680090 [Fei Wang] reformat
87ded6e8d [Fei Wang] add config to disable
44f88ef74 [Fei Wang] add magic node back
Authored-by: Fei Wang <[email protected]>
Signed-off-by: Fei Wang <[email protected]>
---
.../src/main/resources/python/execute_python.py | 239 ++++++++++++++++++++-
.../engine/spark/operation/ExecutePython.scala | 21 +-
.../org/apache/kyuubi/config/KyuubiConf.scala | 9 +
.../apache/kyuubi/engine/spark/PySparkTests.scala | 55 +++++
4 files changed, 312 insertions(+), 12 deletions(-)
diff --git a/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py b/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py
index e6fe7f92b..6729092f7 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py
+++ b/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py
@@ -16,6 +16,8 @@
#
import ast
+import datetime
+import decimal
import io
import json
@@ -23,6 +25,7 @@ import os
import re
import sys
import traceback
+import base64
from glob import glob
if sys.version_info[0] < 3:
@@ -70,6 +73,8 @@ TOP_FRAME_REGEX = re.compile(r'\s*File "<stdin>".*in <module>')
global_dict = {}
+MAGIC_ENABLED = os.environ.get("MAGIC_ENABLED") == "true"
+
class NormalNode(object):
def __init__(self, code):
@@ -94,6 +99,36 @@ class NormalNode(object):
raise ExecutionError(sys.exc_info())
+class UnknownMagic(Exception):
+ pass
+
+
+class MagicNode(object):
+ def __init__(self, line):
+ parts = line[1:].split(" ", 1)
+ if len(parts) == 1:
+ self.magic, self.rest = parts[0], ()
+ else:
+ self.magic, self.rest = parts[0], (parts[1],)
+
+ def execute(self):
+ if not self.magic:
+ raise UnknownMagic("magic command not specified")
+
+ try:
+ handler = magic_router[self.magic]
+ except KeyError:
+ raise UnknownMagic("unknown magic command '%s'" % self.magic)
+
+ try:
+ return handler(*self.rest)
+ except ExecutionError as e:
+ raise e
+ except Exception:
+ exc_type, exc_value, tb = sys.exc_info()
+ raise ExecutionError((exc_type, exc_value, None))
+
+
class ExecutionError(Exception):
def __init__(self, exc_info):
self.exc_info = exc_info
@@ -118,6 +153,14 @@ def parse_code_into_nodes(code):
try:
nodes.append(NormalNode(code))
except SyntaxError:
+ # It's possible we hit a syntax error because of a magic command. Split the code into
+ # chunks of normal code and possibly-magic code (lines that start with a '%'), removing
+ # lines until a chunk parses, then check whether the next line is a magic line.
+
normal = []
chunks = []
for i, line in enumerate(code.rstrip().split("\n")):
@@ -135,24 +178,22 @@ def parse_code_into_nodes(code):
# Convert the chunks into AST nodes. Let exceptions propagate.
for chunk in chunks:
- # TODO: look back here when Jupyter and sparkmagic are supported
- # if chunk.startswith('%'):
- # nodes.append(MagicNode(chunk))
-
- nodes.append(NormalNode(chunk))
+ if MAGIC_ENABLED and chunk.startswith("%"):
+ nodes.append(MagicNode(chunk))
+ else:
+ nodes.append(NormalNode(chunk))
return nodes
def execute_reply(status, content):
- msg = {
+ return {
"msg_type": "execute_reply",
"content": dict(
content,
status=status,
),
}
- return json.dumps(msg)
def execute_reply_ok(data):
@@ -211,6 +252,9 @@ def execute_request(content):
try:
for node in nodes:
result = node.execute()
+ except UnknownMagic:
+ exc_type, exc_value, tb = sys.exc_info()
+ return execute_reply_error(exc_type, exc_value, None)
except ExecutionError as e:
return execute_reply_error(*e.exc_info)
@@ -239,6 +283,171 @@ def execute_request(content):
return execute_reply_ok(result)
+def magic_table_convert(value):
+ try:
+ converter = magic_table_types[type(value)]
+ except KeyError:
+ converter = magic_table_types[str]
+
+ return converter(value)
+
+
+def magic_table_convert_seq(items):
+ last_item_type = None
+ converted_items = []
+
+ for item in items:
+ item_type, item = magic_table_convert(item)
+
+ if last_item_type is None:
+ last_item_type = item_type
+ elif last_item_type != item_type:
+ raise ValueError("value has inconsistent types")
+
+ converted_items.append(item)
+
+ return "ARRAY_TYPE", converted_items
+
+
+def magic_table_convert_map(m):
+ last_key_type = None
+ last_value_type = None
+ converted_items = {}
+
+ for key, value in m.items():
+ key_type, key = magic_table_convert(key)
+ value_type, value = magic_table_convert(value)
+
+ if last_key_type is None:
+ last_key_type = key_type
+ elif last_key_type != key_type:
+ raise ValueError("value has inconsistent types")
+
+ if last_value_type is None:
+ last_value_type = value_type
+ elif last_value_type != value_type:
+ raise ValueError("value has inconsistent types")
+
+ converted_items[key] = value
+
+ return "MAP_TYPE", converted_items
+
+
+magic_table_types = {
+ type(None): lambda x: ("NULL_TYPE", x),
+ bool: lambda x: ("BOOLEAN_TYPE", x),
+ int: lambda x: ("INT_TYPE", x),
+ float: lambda x: ("DOUBLE_TYPE", x),
+ str: lambda x: ("STRING_TYPE", str(x)),
+ datetime.date: lambda x: ("DATE_TYPE", str(x)),
+ datetime.datetime: lambda x: ("TIMESTAMP_TYPE", str(x)),
+ decimal.Decimal: lambda x: ("DECIMAL_TYPE", str(x)),
+ tuple: magic_table_convert_seq,
+ list: magic_table_convert_seq,
+ dict: magic_table_convert_map,
+}
+
+
+def magic_table(name):
+ try:
+ value = global_dict[name]
+ except KeyError:
+ exc_type, exc_value, tb = sys.exc_info()
+ raise ExecutionError((exc_type, exc_value, None))
+
+ if not isinstance(value, (list, tuple)):
+ value = [value]
+
+ headers = {}
+ data = []
+
+ for row in value:
+ cols = []
+ data.append(cols)
+
+ if "Row" == row.__class__.__name__:
+ row = row.asDict()
+
+ if not isinstance(row, (list, tuple, dict)):
+ row = [row]
+
+ if isinstance(row, (list, tuple)):
+ iterator = enumerate(row)
+ else:
+ iterator = sorted(row.items())
+
+ for name, col in iterator:
+ col_type, col = magic_table_convert(col)
+
+ try:
+ header = headers[name]
+ except KeyError:
+ header = {
+ "name": str(name),
+ "type": col_type,
+ }
+ headers[name] = header
+ else:
+ # Reject columns that have a different type. (allow none value)
+ if col_type != "NULL_TYPE" and header["type"] != col_type:
+ if header["type"] == "NULL_TYPE":
+ header["type"] = col_type
+ else:
+ exc_type = Exception
+ exc_value = Exception("table rows have different types")
+ raise ExecutionError((exc_type, exc_value, None))
+
+ cols.append(col)
+
+ headers = [v for k, v in sorted(headers.items())]
+
+ return {
+ "application/vnd.livy.table.v1+json": {
+ "headers": headers,
+ "data": data,
+ }
+ }
+
+
+def magic_json(name):
+ try:
+ value = global_dict[name]
+ except KeyError:
+ exc_type, exc_value, tb = sys.exc_info()
+ raise ExecutionError((exc_type, exc_value, None))
+
+ return {
+ "application/json": value,
+ }
+
+
+def magic_matplot(name):
+ try:
+ value = global_dict[name]
+ fig = value.gcf()
+ imgdata = io.BytesIO()
+ fig.savefig(imgdata, format="png")
+ imgdata.seek(0)
+ encode = base64.b64encode(imgdata.getvalue())
+ if sys.version >= "3":
+ encode = encode.decode()
+
+ except:
+ exc_type, exc_value, tb = sys.exc_info()
+ raise ExecutionError((exc_type, exc_value, None))
+
+ return {
+ "image/png": encode,
+ }
+
+
+magic_router = {
+ "table": magic_table,
+ "json": magic_json,
+ "matplot": magic_matplot,
+}
+
+
# get or create spark session
spark_session = kyuubi_util.get_spark_session(
os.environ.get("KYUUBI_SPARK_SESSION_UUID")
@@ -278,6 +487,22 @@ def main():
break
result = execute_request(content)
+
+ try:
+ result = json.dumps(result)
+ except ValueError:
+ result = json.dumps(
+ {
+ "msg_type": "inspect_reply",
+ "content": {
+ "status": "error",
+ "ename": "ValueError",
+ "evalue": "cannot json-ify %s" % response,
+ "traceback": [],
+ },
+ }
+ )
+
print(result, file=sys_stdout)
sys_stdout.flush()
clearOutputs()
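For reference, a runnable sketch of the payload shape the `%table` handler above emits for the two-column list exercised in the tests below (positional columns are named "0" and "1"):
```
import json

# Expected %table payload for x = [[1, 'a'], [3, 'b']], matching the
# assertion in PySparkTests below.
expected = {
    "application/vnd.livy.table.v1+json": {
        "headers": [
            {"name": "0", "type": "INT_TYPE"},
            {"name": "1", "type": "STRING_TYPE"},
        ],
        "data": [[1, "a"], [3, "b"]],
    }
}
print(json.dumps(expected, separators=(",", ":")))
```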
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala
index b3643a7ae..d35c3fbd4 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.StructType
import org.apache.kyuubi.{KyuubiSQLException, Logging, Utils}
-import org.apache.kyuubi.config.KyuubiConf.{ENGINE_SPARK_PYTHON_ENV_ARCHIVE, ENGINE_SPARK_PYTHON_ENV_ARCHIVE_EXEC_PATH, ENGINE_SPARK_PYTHON_HOME_ARCHIVE}
+import org.apache.kyuubi.config.KyuubiConf.{ENGINE_SPARK_PYTHON_ENV_ARCHIVE, ENGINE_SPARK_PYTHON_ENV_ARCHIVE_EXEC_PATH, ENGINE_SPARK_PYTHON_HOME_ARCHIVE, ENGINE_SPARK_PYTHON_MAGIC_ENABLED}
import org.apache.kyuubi.config.KyuubiReservedKeys.{KYUUBI_SESSION_USER_KEY, KYUUBI_STATEMENT_ID_KEY}
import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._
import org.apache.kyuubi.operation.{ArrayFetchIterator, OperationHandle, OperationState}
@@ -233,6 +233,7 @@ object ExecutePython extends Logging {
final val PY4J_REGEX = "py4j-[\\S]*.zip$".r
final val PY4J_PATH = "PY4J_PATH"
final val IS_PYTHON_APP_KEY = "spark.yarn.isPython"
+ final val MAGIC_ENABLED = "MAGIC_ENABLED"
private val isPythonGatewayStart = new AtomicBoolean(false)
private val kyuubiPythonPath = Utils.createTempDir()
@@ -280,6 +281,7 @@ object ExecutePython extends Logging {
}
env.put("KYUUBI_SPARK_SESSION_UUID", sessionId)
env.put("PYTHON_GATEWAY_CONNECTION_INFO",
KyuubiPythonGatewayServer.CONNECTION_FILE_PATH)
+ env.put(MAGIC_ENABLED, getSessionConf(ENGINE_SPARK_PYTHON_MAGIC_ENABLED,
spark).toString)
logger.info(
s"""
|launch python worker command: ${builder.command().asScala.mkString(" ")}
@@ -409,15 +411,24 @@ object PythonResponse {
}
case class PythonResponseContent(
- data: Map[String, String],
+ data: Map[String, Object],
ename: String,
evalue: String,
traceback: Seq[String],
status: String) {
def getOutput(): String = {
- Option(data)
- .map(_.getOrElse("text/plain", ""))
- .getOrElse("")
+ if (data == null) return ""
+
+ // If data does not contain any field other than `text/plain`, keep backward
+ // compatibility; otherwise, return all the data as JSON.
+ if (data.filterNot(_._1 == "text/plain").isEmpty) {
+ data.get("text/plain").map {
+ case str: String => str
+ case obj => ExecutePython.toJson(obj)
+ }.getOrElse("")
+ } else {
+ ExecutePython.toJson(data)
+ }
}
def getEname(): String = {
Option(ename).getOrElse("")
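To make the backward-compatibility rule concrete, a small Python sketch mirroring the Scala logic above (the function name is illustrative, not part of the patch):
```
import json

def get_output(data):
    # Mirrors PythonResponseContent.getOutput: a text/plain-only payload is
    # returned as-is; any richer payload is serialized to JSON as a whole.
    if data is None:
        return ""
    if set(data) <= {"text/plain"}:
        return data.get("text/plain", "")
    return json.dumps(data)

assert get_output({"text/plain": "hello"}) == "hello"
assert get_output(None) == ""
assert get_output({"application/json": [1, 2]}) == '{"application/json": [1, 2]}'
```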
diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala
index 00c1b8995..9dbc483b0 100644
--- a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala
+++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala
@@ -3246,6 +3246,15 @@ object KyuubiConf {
.stringConf
.createWithDefault("bin/python")
+ val ENGINE_SPARK_PYTHON_MAGIC_ENABLED: ConfigEntry[Boolean] =
+ buildConf("kyuubi.engine.spark.python.magic.enabled")
+ .internal
+ .doc("Whether to enable pyspark magic node, which is helpful for
notebook." +
+ " See details in KYUUBI #5877")
+ .version("1.9.0")
+ .booleanConf
+ .createWithDefault(true)
+
val ENGINE_SPARK_REGISTER_ATTRIBUTES: ConfigEntry[Seq[String]] =
buildConf("kyuubi.engine.spark.register.attributes")
.internal
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/spark/PySparkTests.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/spark/PySparkTests.scala
index 16a7f728e..c723dcf4a 100644
--- a/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/spark/PySparkTests.scala
+++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/spark/PySparkTests.scala
@@ -132,6 +132,61 @@ class PySparkTests extends WithKyuubiServer with HiveJDBCTestHelper {
})
}
+ test("Support python magic syntax for python notebook") {
+ checkPythonRuntimeAndVersion()
+ withSessionConf()(Map(KyuubiConf.ENGINE_SPARK_PYTHON_MAGIC_ENABLED.key -> "true"))() {
+ withMultipleConnectionJdbcStatement()({ stmt =>
+ val statement = stmt.asInstanceOf[KyuubiStatement]
+ statement.executePython("x = [[1, 'a'], [3, 'b']]")
+
+ val resultSet1 = statement.executePython("%json x")
+ assert(resultSet1.next())
+ val output1 = resultSet1.getString("output")
+ assert(output1 == "{\"application/json\":[[1,\"a\"],[3,\"b\"]]}")
+
+ val resultSet2 = statement.executePython("%table x")
+ assert(resultSet2.next())
+ val output2 = resultSet2.getString("output")
+ assert(output2 == "{\"application/vnd.livy.table.v1+json\":{" +
+ "\"headers\":[" +
+ "{\"name\":\"0\",\"type\":\"INT_TYPE\"},{\"name\":\"1\",\"type\":\"STRING_TYPE\"}" +
+ "]," +
+ "\"data\":[" +
+ "[1,\"a\"],[3,\"b\"]" +
+ "]}}")
+
+ Seq("table", "json", "matplot").foreach { magic =>
+ val e = intercept[KyuubiSQLException] {
+ statement.executePython(s"%$magic invalid_value")
+ }.getMessage
+ assert(e.contains("KeyError: 'invalid_value'"))
+ }
+
+ statement.executePython("y = [[1, 2], [3, 'b']]")
+ var e = intercept[KyuubiSQLException] {
+ statement.executePython("%table y")
+ }.getMessage
+ assert(e.contains("table rows have different types"))
+
+ e = intercept[KyuubiSQLException] {
+ statement.executePython("%magic_unknown")
+ }.getMessage
+ assert(e.contains("unknown magic command 'magic_unknown'"))
+ })
+ }
+
+ withSessionConf()(Map(KyuubiConf.ENGINE_SPARK_PYTHON_MAGIC_ENABLED.key -> "false"))() {
+ withMultipleConnectionJdbcStatement()({ stmt =>
+ val statement = stmt.asInstanceOf[KyuubiStatement]
+ statement.executePython("x = [[1, 'a'], [3, 'b']]")
+ val e = intercept[KyuubiSQLException] {
+ statement.executePython("%json x")
+ }.getMessage
+ assert(e.contains("SyntaxError: invalid syntax"))
+ })
+ }
+ }
+
private def runPySparkTest(
pyCode: String,
output: String): Unit = {