Github user concretevitamin commented on a diff in the pull request:
https://github.com/apache/spark/pull/1055#discussion_r13683959
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
---
@@ -202,3 +203,139 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
override def toString = s"if ($predicate) $trueValue else $falseValue"
}
+
+// scalastyle:off
+/**
+ * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
+ * Refer to this link for the corresponding semantics:
+ * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
+ *
+ * Note that branches are considered in consecutive pairs (cond, val), and the optional last element
+ * is the val for the default catch-all case (if provided). Hence, `branches` consists of at least
+ * two elements, and can have an odd or even length.
+ *
+ * When no condition holds and no ELSE value is provided (even length), the result is null.
+ */
+// scalastyle:on
+case class CaseWhen(branches: Seq[Expression]) extends Expression {
+  type EvaluatedType = Any
+  def children = branches
+  def references = children.flatMap(_.references).toSet
+
+  /** Result type, taken from the first THEN value; once resolved, all values share it. */
+  def dataType = {
+    if (!resolved) {
+      throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
+    }
+    branches(1).dataType
+  }
+
+  /** The condition of every (cond, val) pair; the trailing ELSE value, if any, is excluded. */
+  private def branchConditions: Seq[Expression] = branches.grouped(2).collect {
+    case Seq(cond, _) => cond
+  }.toList
+
+  /** The THEN value of every (cond, val) pair, followed by the ELSE value if present. */
+  private def thenAndElseValues: Seq[Expression] = branches.grouped(2).map(_.last).toList
+
+  /**
+   * Nullable if any THEN/ELSE value is nullable, or if there is no ELSE value at all: in the
+   * latter case a row that matches none of the conditions evaluates to null (see `eval`).
+   */
+  override def nullable = branches.length % 2 == 0 || thenAndElseValues.exists(_.nullable)
+
+  /**
+   * Resolved when all children are resolved, every condition is a boolean, and all THEN/ELSE
+   * values share one data type. `&&` short-circuits, so child data types are only inspected
+   * after the children are known to be resolved.
+   */
+  override lazy val resolved =
+    childrenResolved &&
+      branchConditions.forall(_.dataType == BooleanType) &&
+      thenAndElseValues.map(_.dataType).distinct.size <= 1
+
+  /** Array copy of `branches`, built once per instance so eval() indexes in O(1) per access. */
+  @transient private lazy val branchesArr = branches.toArray
+
+  /** Written in imperative fashion for performance considerations. Same for CaseKeyWhen. */
+  override def eval(input: Row): Any = {
+    val len = branchesArr.length
+    var i = 0
+    var matched = false
+    // If all branches fail and an elseVal is not provided, the whole statement
+    // defaults to null, according to Hive's semantics.
+    var res: Any = null
+    while (!matched && i < len - 1) {
+      if (branchesArr(i).eval(input) == true) {
+        res = branchesArr(i + 1).eval(input)
+        matched = true
+      }
+      i += 2
+    }
+    // Odd length and nothing matched: the trailing element is the ELSE value.
+    if (!matched && i == len - 1) {
+      res = branchesArr(i).eval(input)
+    }
+    res
+  }
+
+  override def toString = {
+    val firstBranch = s"if (${branches(0)} == true) { ${branches(1)} }"
+    val otherBranches = branches.grouped(2).drop(1).map {
+      case Seq(cond, value) => s" else if ($cond == true) { $value }"
+      case Seq(elseValue) => s" else { $elseValue }"
+    }.mkString
+    firstBranch + otherBranches
+  }
+}
+
+/**
+ * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". This type
+ * of case statement is separated out from the other type mainly for performance reasons: this
+ * approach avoids branching (based on whether or not the key is provided) in eval().
+ */
+case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends
Expression {
--- End diff --
I have to admit there is a lot of code duplication here. Directly
subclassing a case class is prohibited, so we need some other way to achieve
code reuse without paying too much of a performance penalty (e.g., as discussed
before, we don't want a single case class that carries a field like
`key: Option[Expression]`). Any suggestions are welcome!
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---