This is an automated email from the ASF dual-hosted git repository.

feiwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kyuubi.git


The following commit(s) were added to refs/heads/master by this push:
     new 50910ae6c [KYUUBI #5877] Support Python magic syntax for notebook usage
50910ae6c is described below

commit 50910ae6c4317156f29e53893e9e236d9f6b0f15
Author: Fei Wang <[email protected]>
AuthorDate: Wed Dec 20 19:31:47 2023 -0800

    [KYUUBI #5877] Support Python magic syntax for notebook usage
    
    # :mag: Description
    ## Issue References 🔗
    
    Support Python magic syntax, for example:
    ```
    %table
    %json
    %matplot
    ```
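    
    For instance, one assigns a variable and then invokes a magic on it (an
    illustrative sketch based on the unit tests below; exact rendering depends
    on the notebook client):
    ```
    x = [[1, 'a'], [3, 'b']]
    %json x    # -> {"application/json": [[1, "a"], [3, "b"]]}
    %table x   # -> an "application/vnd.livy.table.v1+json" payload with headers and rows
    ```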
    
    References:
    - https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-studio-magics.html
    - https://github.com/jupyter-incubator/sparkmagic
    - https://github.com/apache/incubator-livy/blob/master/repl/src/main/resources/fake_shell.py
    
    This pull request fixes #5877
    
    ## Describe Your Solution 🔧
    
    When `kyuubi.engine.spark.python.magic.enabled` is on (default: true), the Python
    worker's `parse_code_into_nodes` splits submitted code into chunks of normal code
    and `%`-prefixed lines. Each `%`-prefixed chunk becomes a `MagicNode`, which looks
    up its handler (`%table`, `%json`, or `%matplot`) in `magic_router` and executes it.
    Handlers return typed data fields (e.g. `application/vnd.livy.table.v1+json`,
    `application/json`, `image/png`), and the engine-side
    `PythonResponseContent.getOutput()` keeps backward compatibility when only
    `text/plain` is present.
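    
    For reference, the reply payload for `%table x` is shaped like this (reconstructed
    from the unit test below, following the Livy table format):
    ```
    {
        "application/vnd.livy.table.v1+json": {
            "headers": [
                {"name": "0", "type": "INT_TYPE"},
                {"name": "1", "type": "STRING_TYPE"}
            ],
            "data": [[1, "a"], [3, "b"]]
        }
    }
    ```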
    
    ## Types of changes :bookmark:
    
    - [ ] Bugfix (non-breaking change which fixes an issue)
    - [ ] New feature (non-breaking change which adds functionality)
    - [ ] Breaking change (fix or feature that would cause existing functionality to change)
    
    ## Test Plan 🧪
    
    Tested with the following Python code:
    ```
    import matplotlib.pyplot as plt
    plt.plot([3,4,5],[6,7,8])
    %matplot plt
    ```
    <img width="1723" alt="image" src="https://github.com/apache/kyuubi/assets/6757692/9a1176c0-8eb0-4a64-83e4-35e74e33d2f0">
    
    Decode the returned "image/png" field and save it as a PNG:
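    
    A minimal sketch of that decode step (assuming `output` holds the statement's
    output string; the file name is illustrative):
    ```
    import base64
    import json
    
    reply = json.loads(output)  # output returned by the %matplot statement
    with open("matplot.png", "wb") as f:
        f.write(base64.b64decode(reply["image/png"]))
    ```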
    
    
![matplot](https://github.com/apache/kyuubi/assets/6757692/9139f9d3-7822-43b0-8959-261ed8e79d22)
    
    #### Behavior Without This Pull Request :coffin:
    
    `%`-prefixed lines in Python statements fail with `SyntaxError: invalid syntax`.
    
    #### Behavior With This Pull Request :tada:
    
    `%table`, `%json`, and `%matplot` magic commands are recognized and return typed
    data fields (e.g. `application/json`, `image/png`) in addition to `text/plain`.
    
    #### Related Unit Tests
    
    See the new test in `PySparkTests`: "Support python magic syntax for python notebook".
    
    ---
    
    # Checklists
    ## ๐Ÿ“ Author Self Checklist
    
    - [ ] My code follows the [style guidelines](https://kyuubi.readthedocs.io/en/master/contributing/code/style.html) of this project
    - [ ] I have performed a self-review
    - [ ] I have commented my code, particularly in hard-to-understand areas
    - [ ] I have made corresponding changes to the documentation
    - [ ] My changes generate no new warnings
    - [ ] I have added tests that prove my fix is effective or that my feature works
    - [ ] New and existing unit tests pass locally with my changes
    - [ ] This patch was not authored or co-authored using [Generative Tooling](https://www.apache.org/legal/generative-tooling.html)
    
    ## ๐Ÿ“ Committer Pre-Merge Checklist
    
    - [ ] Pull request title is okay.
    - [ ] No license issues.
    - [ ] Milestone correctly set?
    - [ ] Test coverage is ok
    - [ ] Assignees are selected.
    - [ ] Minimum number of approvals
    - [ ] No changes are requested
    
    **Be nice. Be informative.**
    
    Closes #5881 from turboFei/magic_command.
    
    Closes #5877
    
    6f2b193a9 [Fei Wang] ut
    877c7d108 [Fei Wang] internal config
    012dfe44f [Fei Wang] nit
    3e0f324f4 [Fei Wang] except other exceptions
    24352d2c6 [Fei Wang] raise execution error
    085316111 [Fei Wang] raise ExecutionError instead of execute_reply_error
    c058defc5 [Fei Wang] add more ut
    4da52153b [Fei Wang] Dumps python object to json at last
    35127537b [Fei Wang] add ut for json and table
    48735ebd9 [Fei Wang] the data should be Map[String, Object]
    3a3ba0a49 [Fei Wang] return other data fields
    54d680090 [Fei Wang] reformat
    87ded6e8d [Fei Wang] add config to disable
    44f88ef74 [Fei Wang] add magic node back
    
    Authored-by: Fei Wang <[email protected]>
    Signed-off-by: Fei Wang <[email protected]>
---
 .../src/main/resources/python/execute_python.py    | 239 ++++++++++++++++++++-
 .../engine/spark/operation/ExecutePython.scala     |  21 +-
 .../org/apache/kyuubi/config/KyuubiConf.scala      |   9 +
 .../apache/kyuubi/engine/spark/PySparkTests.scala  |  55 +++++
 4 files changed, 312 insertions(+), 12 deletions(-)

diff --git a/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py b/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py
index e6fe7f92b..6729092f7 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py
+++ b/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py
@@ -16,6 +16,8 @@
 #
 
 import ast
+import datetime
+import decimal
 import io
 import json
 
@@ -23,6 +25,7 @@ import os
 import re
 import sys
 import traceback
+import base64
 from glob import glob
 
 if sys.version_info[0] < 3:
@@ -70,6 +73,8 @@ TOP_FRAME_REGEX = re.compile(r'\s*File "<stdin>".*in <module>')
 
 global_dict = {}
 
+MAGIC_ENABLED = os.environ.get("MAGIC_ENABLED") == "true"
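+# MAGIC_ENABLED is set by the Spark engine (see ExecutePython.scala) from the
+# internal config kyuubi.engine.spark.python.magic.enabled.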
+
 
 class NormalNode(object):
     def __init__(self, code):
@@ -94,6 +99,36 @@ class NormalNode(object):
             raise ExecutionError(sys.exc_info())
 
 
+class UnknownMagic(Exception):
+    pass
+
+
+class MagicNode(object):
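+    # A parsed '%<magic> [argument]' line; execute() dispatches to the matching
+    # handler registered in magic_router.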
+    def __init__(self, line):
+        parts = line[1:].split(" ", 1)
+        if len(parts) == 1:
+            self.magic, self.rest = parts[0], ()
+        else:
+            self.magic, self.rest = parts[0], (parts[1],)
+
+    def execute(self):
+        if not self.magic:
+            raise UnknownMagic("magic command not specified")
+
+        try:
+            handler = magic_router[self.magic]
+        except KeyError:
+            raise UnknownMagic("unknown magic command '%s'" % self.magic)
+
+        try:
+            return handler(*self.rest)
+        except ExecutionError as e:
+            raise e
+        except Exception:
+            exc_type, exc_value, tb = sys.exc_info()
+            raise ExecutionError((exc_type, exc_value, None))
+
+
 class ExecutionError(Exception):
     def __init__(self, exc_info):
         self.exc_info = exc_info
@@ -118,6 +153,14 @@ def parse_code_into_nodes(code):
     try:
         nodes.append(NormalNode(code))
     except SyntaxError:
+        # It's possible we hit a syntax error because of a magic command. Split the
+        # code into chunks of normal code and possibly-magic code (lines starting
+        # with '%'), removing lines until we find a chunk that parses, then check
+        # whether the next line is a magic line.
+
         normal = []
         chunks = []
         for i, line in enumerate(code.rstrip().split("\n")):
@@ -135,24 +178,22 @@ def parse_code_into_nodes(code):
 
         # Convert the chunks into AST nodes. Let exceptions propagate.
         for chunk in chunks:
-            # TODO: look back here when Jupyter and sparkmagic are supported
-            # if chunk.startswith('%'):
-            #     nodes.append(MagicNode(chunk))
-
-            nodes.append(NormalNode(chunk))
+            if MAGIC_ENABLED and chunk.startswith("%"):
+                nodes.append(MagicNode(chunk))
+            else:
+                nodes.append(NormalNode(chunk))
 
     return nodes
 
 
 def execute_reply(status, content):
-    msg = {
+    return {
         "msg_type": "execute_reply",
         "content": dict(
             content,
             status=status,
         ),
     }
-    return json.dumps(msg)
 
 
 def execute_reply_ok(data):
@@ -211,6 +252,9 @@ def execute_request(content):
     try:
         for node in nodes:
             result = node.execute()
+    except UnknownMagic:
+        exc_type, exc_value, tb = sys.exc_info()
+        return execute_reply_error(exc_type, exc_value, None)
     except ExecutionError as e:
         return execute_reply_error(*e.exc_info)
 
@@ -239,6 +283,171 @@ def execute_request(content):
     return execute_reply_ok(result)
 
 
+def magic_table_convert(value):
+    try:
+        converter = magic_table_types[type(value)]
+    except KeyError:
+        converter = magic_table_types[str]
+
+    return converter(value)
+
+
+def magic_table_convert_seq(items):
+    last_item_type = None
+    converted_items = []
+
+    for item in items:
+        item_type, item = magic_table_convert(item)
+
+        if last_item_type is None:
+            last_item_type = item_type
+        elif last_item_type != item_type:
+            raise ValueError("value has inconsistent types")
+
+        converted_items.append(item)
+
+    return "ARRAY_TYPE", converted_items
+
+
+def magic_table_convert_map(m):
+    last_key_type = None
+    last_value_type = None
+    converted_items = {}
+
+    for key, value in m.items():
+        key_type, key = magic_table_convert(key)
+        value_type, value = magic_table_convert(value)
+
+        if last_key_type is None:
+            last_key_type = key_type
+        elif last_key_type != key_type:
+            raise ValueError("value has inconsistent types")
+
+        if last_value_type is None:
+            last_value_type = value_type
+        elif last_value_type != value_type:
+            raise ValueError("value has inconsistent types")
+
+        converted_items[key] = value
+
+    return "MAP_TYPE", converted_items
+
+
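+# Maps Python value types to Livy/Hive column type names for %table output;
+# magic_table_convert falls back to the str converter for unknown types.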
+magic_table_types = {
+    type(None): lambda x: ("NULL_TYPE", x),
+    bool: lambda x: ("BOOLEAN_TYPE", x),
+    int: lambda x: ("INT_TYPE", x),
+    float: lambda x: ("DOUBLE_TYPE", x),
+    str: lambda x: ("STRING_TYPE", str(x)),
+    datetime.date: lambda x: ("DATE_TYPE", str(x)),
+    datetime.datetime: lambda x: ("TIMESTAMP_TYPE", str(x)),
+    decimal.Decimal: lambda x: ("DECIMAL_TYPE", str(x)),
+    tuple: magic_table_convert_seq,
+    list: magic_table_convert_seq,
+    dict: magic_table_convert_map,
+}
+
+
+def magic_table(name):
+    try:
+        value = global_dict[name]
+    except KeyError:
+        exc_type, exc_value, tb = sys.exc_info()
+        raise ExecutionError((exc_type, exc_value, None))
+
+    if not isinstance(value, (list, tuple)):
+        value = [value]
+
+    headers = {}
+    data = []
+
+    for row in value:
+        cols = []
+        data.append(cols)
+
+        if "Row" == row.__class__.__name__:
+            row = row.asDict()
+
+        if not isinstance(row, (list, tuple, dict)):
+            row = [row]
+
+        if isinstance(row, (list, tuple)):
+            iterator = enumerate(row)
+        else:
+            iterator = sorted(row.items())
+
+        for name, col in iterator:
+            col_type, col = magic_table_convert(col)
+
+            try:
+                header = headers[name]
+            except KeyError:
+                header = {
+                    "name": str(name),
+                    "type": col_type,
+                }
+                headers[name] = header
+            else:
+                # Reject columns that have a different type (None values are allowed).
+                if col_type != "NULL_TYPE" and header["type"] != col_type:
+                    if header["type"] == "NULL_TYPE":
+                        header["type"] = col_type
+                    else:
+                        exc_type = Exception
+                        exc_value = Exception("table rows have different types")
+                        raise ExecutionError((exc_type, exc_value, None))
+
+            cols.append(col)
+
+    headers = [v for k, v in sorted(headers.items())]
+
+    return {
+        "application/vnd.livy.table.v1+json": {
+            "headers": headers,
+            "data": data,
+        }
+    }
+
+
+def magic_json(name):
+    try:
+        value = global_dict[name]
+    except KeyError:
+        exc_type, exc_value, tb = sys.exc_info()
+        raise ExecutionError((exc_type, exc_value, None))
+
+    return {
+        "application/json": value,
+    }
+
+
+def magic_matplot(name):
+    try:
+        value = global_dict[name]
+        fig = value.gcf()
+        imgdata = io.BytesIO()
+        fig.savefig(imgdata, format="png")
+        imgdata.seek(0)
+        encode = base64.b64encode(imgdata.getvalue())
+        if sys.version_info[0] >= 3:
+            encode = encode.decode()
+
+    except Exception:
+        exc_type, exc_value, tb = sys.exc_info()
+        raise ExecutionError((exc_type, exc_value, None))
+
+    return {
+        "image/png": encode,
+    }
+
+
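+# Routes a magic name to its handler; each handler takes a variable name,
+# looks it up in global_dict, and returns a {mime-type: value} dict.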
+magic_router = {
+    "table": magic_table,
+    "json": magic_json,
+    "matplot": magic_matplot,
+}
+
+
 # get or create spark session
 spark_session = kyuubi_util.get_spark_session(
     os.environ.get("KYUUBI_SPARK_SESSION_UUID")
@@ -278,6 +487,22 @@ def main():
                 break
 
             result = execute_request(content)
+
+            try:
+                result = json.dumps(result)
+            except ValueError:
+                result = json.dumps(
+                    {
+                        "msg_type": "inspect_reply",
+                        "content": {
+                            "status": "error",
+                            "ename": "ValueError",
+                            "evalue": "cannot json-ify %s" % result,
+                            "traceback": [],
+                        },
+                    }
+                )
+
             print(result, file=sys_stdout)
             sys_stdout.flush()
             clearOutputs()
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala
index b3643a7ae..d35c3fbd4 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.types.StructType
 
 import org.apache.kyuubi.{KyuubiSQLException, Logging, Utils}
-import org.apache.kyuubi.config.KyuubiConf.{ENGINE_SPARK_PYTHON_ENV_ARCHIVE, ENGINE_SPARK_PYTHON_ENV_ARCHIVE_EXEC_PATH, ENGINE_SPARK_PYTHON_HOME_ARCHIVE}
+import org.apache.kyuubi.config.KyuubiConf.{ENGINE_SPARK_PYTHON_ENV_ARCHIVE, ENGINE_SPARK_PYTHON_ENV_ARCHIVE_EXEC_PATH, ENGINE_SPARK_PYTHON_HOME_ARCHIVE, ENGINE_SPARK_PYTHON_MAGIC_ENABLED}
 import org.apache.kyuubi.config.KyuubiReservedKeys.{KYUUBI_SESSION_USER_KEY, KYUUBI_STATEMENT_ID_KEY}
 import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._
 import org.apache.kyuubi.operation.{ArrayFetchIterator, OperationHandle, OperationState}
@@ -233,6 +233,7 @@ object ExecutePython extends Logging {
   final val PY4J_REGEX = "py4j-[\\S]*.zip$".r
   final val PY4J_PATH = "PY4J_PATH"
   final val IS_PYTHON_APP_KEY = "spark.yarn.isPython"
+  final val MAGIC_ENABLED = "MAGIC_ENABLED"
 
   private val isPythonGatewayStart = new AtomicBoolean(false)
   private val kyuubiPythonPath = Utils.createTempDir()
@@ -280,6 +281,7 @@ object ExecutePython extends Logging {
     }
     env.put("KYUUBI_SPARK_SESSION_UUID", sessionId)
     env.put("PYTHON_GATEWAY_CONNECTION_INFO", 
KyuubiPythonGatewayServer.CONNECTION_FILE_PATH)
+    env.put(MAGIC_ENABLED, getSessionConf(ENGINE_SPARK_PYTHON_MAGIC_ENABLED, 
spark).toString)
     logger.info(
       s"""
         |launch python worker command: ${builder.command().asScala.mkString(" ")}
@@ -409,15 +411,24 @@ object PythonResponse {
 }
 
 case class PythonResponseContent(
-    data: Map[String, String],
+    data: Map[String, Object],
     ename: String,
     evalue: String,
     traceback: Seq[String],
     status: String) {
   def getOutput(): String = {
-    Option(data)
-      .map(_.getOrElse("text/plain", ""))
-      .getOrElse("")
+    if (data == null) return ""
+
+    // If data does not contain fields other than `text/plain`, keep backward
+    // compatibility; otherwise, return all the data.
+    if (data.filterNot(_._1 == "text/plain").isEmpty) {
+      data.get("text/plain").map {
+        case str: String => str
+        case obj => ExecutePython.toJson(obj)
+      }.getOrElse("")
+    } else {
+      ExecutePython.toJson(data)
+    }
   }
   def getEname(): String = {
     Option(ename).getOrElse("")
diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala
index 00c1b8995..9dbc483b0 100644
--- a/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala
+++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala
@@ -3246,6 +3246,15 @@ object KyuubiConf {
       .stringConf
       .createWithDefault("bin/python")
 
+  val ENGINE_SPARK_PYTHON_MAGIC_ENABLED: ConfigEntry[Boolean] =
+    buildConf("kyuubi.engine.spark.python.magic.enabled")
+      .internal
+      .doc("Whether to enable pyspark magic node, which is helpful for 
notebook." +
+        " See details in KYUUBI #5877")
+      .version("1.9.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val ENGINE_SPARK_REGISTER_ATTRIBUTES: ConfigEntry[Seq[String]] =
     buildConf("kyuubi.engine.spark.register.attributes")
       .internal
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/spark/PySparkTests.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/spark/PySparkTests.scala
index 16a7f728e..c723dcf4a 100644
--- a/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/spark/PySparkTests.scala
+++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/engine/spark/PySparkTests.scala
@@ -132,6 +132,61 @@ class PySparkTests extends WithKyuubiServer with HiveJDBCTestHelper {
     })
   }
 
+  test("Support python magic syntax for python notebook") {
+    checkPythonRuntimeAndVersion()
+    withSessionConf()(Map(KyuubiConf.ENGINE_SPARK_PYTHON_MAGIC_ENABLED.key -> "true"))() {
+      withMultipleConnectionJdbcStatement()({ stmt =>
+        val statement = stmt.asInstanceOf[KyuubiStatement]
+        statement.executePython("x = [[1, 'a'], [3, 'b']]")
+
+        val resultSet1 = statement.executePython("%json x")
+        assert(resultSet1.next())
+        val output1 = resultSet1.getString("output")
+        assert(output1 == "{\"application/json\":[[1,\"a\"],[3,\"b\"]]}")
+
+        val resultSet2 = statement.executePython("%table x")
+        assert(resultSet2.next())
+        val output2 = resultSet2.getString("output")
+        assert(output2 == "{\"application/vnd.livy.table.v1+json\":{" +
+          "\"headers\":[" +
+          "{\"name\":\"0\",\"type\":\"INT_TYPE\"},{\"name\":\"1\",\"type\":\"STRING_TYPE\"}" +
+          "]," +
+          "\"data\":[" +
+          "[1,\"a\"],[3,\"b\"]" +
+          "]}}")
+
+        Seq("table", "json", "matplot").foreach { magic =>
+          val e = intercept[KyuubiSQLException] {
+            statement.executePython(s"%$magic invalid_value")
+          }.getMessage
+          assert(e.contains("KeyError: 'invalid_value'"))
+        }
+
+        statement.executePython("y = [[1, 2], [3, 'b']]")
+        var e = intercept[KyuubiSQLException] {
+          statement.executePython("%table y")
+        }.getMessage
+        assert(e.contains("table rows have different types"))
+
+        e = intercept[KyuubiSQLException] {
+          statement.executePython("%magic_unknown")
+        }.getMessage
+        assert(e.contains("unknown magic command 'magic_unknown'"))
+      })
+    }
+
+    withSessionConf()(Map(KyuubiConf.ENGINE_SPARK_PYTHON_MAGIC_ENABLED.key -> "false"))() {
+      withMultipleConnectionJdbcStatement()({ stmt =>
+        val statement = stmt.asInstanceOf[KyuubiStatement]
+        statement.executePython("x = [[1, 'a'], [3, 'b']]")
+        val e = intercept[KyuubiSQLException] {
+          statement.executePython("%json x")
+        }.getMessage
+        assert(e.contains("SyntaxError: invalid syntax"))
+      })
+    }
+  }
+
   private def runPySparkTest(
       pyCode: String,
       output: String): Unit = {
