This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 1973515  [SPARK-32999][SQL][2.4] Use Utils.getSimpleName to avoid 
hitting Malformed class name in TreeNode
1973515 is described below

commit 1973515d069633d214dd4cb5f83f7b6f6b59766c
Author: Kris Mok <[email protected]>
AuthorDate: Mon Sep 28 17:15:43 2020 -0700

    [SPARK-32999][SQL][2.4] Use Utils.getSimpleName to avoid hitting Malformed 
class name in TreeNode
    
    ### What changes were proposed in this pull request?
    
    Use `Utils.getSimpleName` to avoid hitting `Malformed class name` error in 
`TreeNode`.
    
    ### Why are the changes needed?
    
    On older JDK versions (e.g. JDK8u), nested Scala classes may trigger 
`java.lang.Class.getSimpleName` to throw a `java.lang.InternalError: Malformed 
class name` error.
    
    Similar to https://github.com/apache/spark/pull/29050, we should use 
Spark's `Utils.getSimpleName` utility function in place of 
`Class.getSimpleName` to avoid hitting the issue.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Fixes a bug where invoking `TreeNode.nodeName` throws an error; 
otherwise no changes.
    
    ### How was this patch tested?
    
    Added a new unit test case in `TreeNodeSuite`. Note that the test case 
first checks that the test code can trigger the expected error; otherwise the 
test is safely skipped, for compatibility with newer JDKs.
    
    Manually tested on JDK8u and JDK11u and observed expected behavior:
    - JDK8u: the test case triggers the "Malformed class name" issue and the 
fix works;
    - JDK11u: the test case does not trigger the "Malformed class name" issue, 
and the test case is safely skipped.
    
    Closes #29896 from rednaxelafx/spark-32999-getsimplename-2.4.
    
    Authored-by: Kris Mok <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../apache/spark/sql/catalyst/trees/TreeNode.scala |  6 ++--
 .../spark/sql/catalyst/trees/TreeNodeSuite.scala   | 37 +++++++++++++++++++++-
 2 files changed, 40 insertions(+), 3 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index a924f10..5e59eb3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -419,11 +419,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] 
extends Product {
     }
   }
 
+  private def simpleClassName: String = Utils.getSimpleName(this.getClass)
+
   /**
    * Returns the name of this type of TreeNode.  Defaults to the class name.
    * Note that we remove the "Exec" suffix for physical operators here.
    */
-  def nodeName: String = getClass.getSimpleName.replaceAll("Exec$", "")
+  def nodeName: String = simpleClassName.replaceAll("Exec$", "")
 
   /**
    * The arguments that should be included in the arg string.  Defaults to the 
`productIterator`.
@@ -610,7 +612,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] 
extends Product {
   protected def jsonFields: List[JField] = {
     val fieldNames = getConstructorParameterNames(getClass)
     val fieldValues = productIterator.toSeq ++ otherCopyArgs
-    assert(fieldNames.length == fieldValues.length, 
s"${getClass.getSimpleName} fields: " +
+    assert(fieldNames.length == fieldValues.length, s"$simpleClassName fields: 
" +
       fieldNames.mkString(", ") + s", values: " + 
fieldValues.map(_.toString).mkString(", "))
 
     fieldNames.zip(fieldValues).map {
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index e37cf8a..883f673 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.{AliasIdentifier, 
FunctionIdentifier, Inter
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.dsl.expressions.DslString
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
CodegenFallback, ExprCode}
 import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, SubqueryAlias, 
Union}
 import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, 
RoundRobinPartitioning, SinglePartition}
@@ -617,4 +617,39 @@ class TreeNodeSuite extends SparkFunSuite {
     val expected = Coalesce(Stream(Literal(1), Literal(3)))
     assert(result === expected)
   }
+
+  object MalformedClassObject extends Serializable {
+    // Backport notes: this class inline-expands TaggingExpression from Spark 
3.1
+    case class MalformedNameExpression(child: Expression) extends 
UnaryExpression {
+      override def nullable: Boolean = child.nullable
+      override def dataType: DataType = child.dataType
+
+      override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode =
+        child.genCode(ctx)
+
+      override def eval(input: InternalRow): Any = child.eval(input)
+    }
+  }
+
+  test("SPARK-32999: TreeNode.nodeName should not throw malformed class name 
error") {
+    val testTriggersExpectedError = try {
+      classOf[MalformedClassObject.MalformedNameExpression].getSimpleName
+      false
+    } catch {
+      case ex: java.lang.InternalError if ex.getMessage.contains("Malformed 
class name") =>
+        true
+      case ex: Throwable => throw ex
+    }
+    // This test case only applies on older JDK versions (e.g. JDK8u), and 
doesn't trigger the
+    // issue on newer JDK versions (e.g. JDK11u).
+    assume(testTriggersExpectedError, "the test case didn't trigger malformed 
class name error")
+
+    val expr = MalformedClassObject.MalformedNameExpression(Literal(1))
+    try {
+      expr.nodeName
+    } catch {
+      case ex: java.lang.InternalError if ex.getMessage.contains("Malformed 
class name") =>
+        fail("TreeNode.nodeName should not throw malformed class name error")
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to