Github user JoshRosen commented on a diff in the pull request:
https://github.com/apache/spark/pull/6479#discussion_r31688022
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
---
@@ -163,133 +190,111 @@ abstract class CodeGenerator[InType <: AnyRef,
OutType <: AnyRef] extends Logging
*
* @param f a function from two primitive term names to a tree that
evaluates them.
*/
- def evaluate(f: (TermName, TermName) => Tree): Seq[Tree] =
+ def evaluate(f: (String, String) => String): String =
evaluateAs(expressions._1.dataType)(f)
- def evaluateAs(resultType: DataType)(f: (TermName, TermName) =>
Tree): Seq[Tree] = {
+ def evaluateAs(resultType: DataType)(f: (String, String) => String):
String = {
// TODO: Right now some timestamp tests fail if we enforce this...
if (expressions._1.dataType != expressions._2.dataType) {
log.warn(s"${expressions._1.dataType} !=
${expressions._2.dataType}")
}
- val eval1 = expressionEvaluator(expressions._1)
- val eval2 = expressionEvaluator(expressions._2)
+ val eval1 = expressionEvaluator(expressions._1, ctx)
+ val eval2 = expressionEvaluator(expressions._2, ctx)
val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)
- eval1.code ++ eval2.code ++
- q"""
- val $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm}
- val $primitiveTerm: ${termForType(resultType)} =
- if($nullTerm) {
- ${defaultPrimitive(resultType)}
- } else {
- $resultCode.asInstanceOf[${termForType(resultType)}]
- }
- """.children : Seq[Tree]
+ eval1.code + eval2.code +
+ s"""
+ boolean $nullTerm = ${eval1.nullTerm} || ${eval2.nullTerm};
+ ${primitiveForType(resultType)} $primitiveTerm =
${defaultPrimitive(resultType)};
+ if(!$nullTerm) {
+ $primitiveTerm =
(${primitiveForType(resultType)})($resultCode);
+ }
+ """
}
}
val inputTuple = newTermName(s"i")
// TODO: Skip generation of null handling code when expression are not
nullable.
- val primitiveEvaluation: PartialFunction[Expression, Seq[Tree]] = {
+ val primitiveEvaluation: PartialFunction[Expression, String] = {
case b @ BoundReference(ordinal, dataType, nullable) =>
- val nullValue = q"$inputTuple.isNullAt($ordinal)"
- q"""
- val $nullTerm: Boolean = $nullValue
- val $primitiveTerm: ${termForType(dataType)} =
- if($nullTerm)
- ${defaultPrimitive(dataType)}
- else
- ${getColumn(inputTuple, dataType, ordinal)}
- """.children
+ s"""
+ final boolean $nullTerm = $inputTuple.isNullAt($ordinal);
+ final ${primitiveForType(dataType)} $primitiveTerm = $nullTerm ?
+ ${defaultPrimitive(dataType)} : (${getColumn(inputTuple,
dataType, ordinal)});
+ """
case expressions.Literal(null, dataType) =>
- q"""
- val $nullTerm = true
- val $primitiveTerm: ${termForType(dataType)} =
null.asInstanceOf[${termForType(dataType)}]
- """.children
-
- case expressions.Literal(value: Boolean, dataType) =>
- q"""
- val $nullTerm = ${value == null}
- val $primitiveTerm: ${termForType(dataType)} = $value
- """.children
-
- case expressions.Literal(value: UTF8String, dataType) =>
- q"""
- val $nullTerm = ${value == null}
- val $primitiveTerm: ${termForType(dataType)} =
- org.apache.spark.sql.types.UTF8String(${value.getBytes})
- """.children
-
- case expressions.Literal(value: Int, dataType) =>
- q"""
- val $nullTerm = ${value == null}
- val $primitiveTerm: ${termForType(dataType)} = $value
- """.children
-
- case expressions.Literal(value: Long, dataType) =>
- q"""
- val $nullTerm = ${value == null}
- val $primitiveTerm: ${termForType(dataType)} = $value
- """.children
-
- case Cast(e @ BinaryType(), StringType) =>
- val eval = expressionEvaluator(e)
- eval.code ++
- q"""
- val $nullTerm = ${eval.nullTerm}
- val $primitiveTerm =
- if($nullTerm)
- ${defaultPrimitive(StringType)}
- else
-
org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]])
- """.children
+ s"""
+ final boolean $nullTerm = true;
+ ${primitiveForType(dataType)} $primitiveTerm =
${defaultPrimitive(dataType)};
+ """
+
+ case expressions.Literal(value: UTF8String, StringType) =>
+ val arr = s"new
byte[]{${value.getBytes.map(_.toString).mkString(", ")}}"
+ s"""
+ final boolean $nullTerm = false;
+ org.apache.spark.sql.types.UTF8String $primitiveTerm =
+ new org.apache.spark.sql.types.UTF8String().set(${arr});
+ """
+
+ case expressions.Literal(value, FloatType) =>
+ s"""
+ final boolean $nullTerm = false;
+ float $primitiveTerm = ${value}f;
+ """
+
+ case expressions.Literal(value, dt @ DecimalType()) =>
+ s"""
+ final boolean $nullTerm = false;
+ ${primitiveForType(dt)} $primitiveTerm = new
${primitiveForType(dt)}().set($value);
+ """
+
+ case expressions.Literal(value, dataType) =>
+ s"""
+ final boolean $nullTerm = false;
+ ${primitiveForType(dataType)} $primitiveTerm = $value;
+ """
+
+ case Cast(child @ BinaryType(), StringType) =>
+ child.castOrNull(c =>
+ s"new org.apache.spark.sql.types.UTF8String().set($c)",
+ StringType)
case Cast(child @ DateType(), StringType) =>
child.castOrNull(c =>
- q"""org.apache.spark.sql.types.UTF8String(
+ s"""new org.apache.spark.sql.types.UTF8String().set(
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""",
StringType)
- case Cast(child @ NumericType(), IntegerType) =>
- child.castOrNull(c => q"$c.toInt", IntegerType)
+ case Cast(child @ BooleanType(), dt: NumericType) if
!dt.isInstanceOf[DecimalType] =>
+ child.castOrNull(c => s"(${primitiveForType(dt)})($c?1:0)", dt)
- case Cast(child @ NumericType(), LongType) =>
- child.castOrNull(c => q"$c.toLong", LongType)
+ case Cast(child @ DecimalType(), IntegerType) =>
+ child.castOrNull(c => s"($c).toInt()", IntegerType)
- case Cast(child @ NumericType(), DoubleType) =>
- child.castOrNull(c => q"$c.toDouble", DoubleType)
+ case Cast(child @ DecimalType(), dt: NumericType) if
!dt.isInstanceOf[DecimalType] =>
+ child.castOrNull(c => s"($c).to${termForType(dt)}()", dt)
- case Cast(child @ NumericType(), FloatType) =>
- child.castOrNull(c => q"$c.toFloat", FloatType)
+ case Cast(child @ NumericType(), dt: NumericType) if
!dt.isInstanceOf[DecimalType] =>
+ child.castOrNull(c => s"(${primitiveForType(dt)})($c)", dt)
// Special handling required for timestamps in hive test cases since
the toString function
// does not match the expected output.
case Cast(e, StringType) if e.dataType != TimestampType =>
- val eval = expressionEvaluator(e)
- eval.code ++
- q"""
- val $nullTerm = ${eval.nullTerm}
- val $primitiveTerm =
- if($nullTerm)
- ${defaultPrimitive(StringType)}
- else
-
org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString)
- """.children
+ e.castOrNull(c =>
+ s"new
org.apache.spark.sql.types.UTF8String().set(String.valueOf($c))",
+ StringType)
case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) =>
(e1, e2).evaluateAs (BooleanType) {
case (eval1, eval2) =>
- q"""
- java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]],
- $eval2.asInstanceOf[Array[Byte]])
- """
+ s"java.util.Arrays.equals((byte[])$eval1, (byte[])$eval2)"
}
case EqualTo(e1, e2) =>
- (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) =>
q"$eval1 == $eval2" }
+ (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) =>
s"$eval1 == $eval2" }
/* TODO: Fix null semantics.
--- End diff --
This block of commented-out code looks like dead Scala code that can be
removed in this PR. If there's still a TODO task for the null-handling
semantics, then we should file a follow-up JIRA.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]