GitHub user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/20858#discussion_r176901843
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
---
@@ -699,3 +699,88 @@ abstract class TernaryExpression extends Expression {
* and Hive function wrappers.
*/
trait UserDefinedExpression
+
+/**
+ * The trait covers logic for performing null save evaluation and code
generation.
+ */
+trait NullSafeEvaluation extends Expression
+{
+ override def foldable: Boolean = children.forall(_.foldable)
+
+ override def nullable: Boolean = children.exists(_.nullable)
+
+ /**
+ * Default behavior of evaluation according to the default nullability
of NullSafeEvaluation.
+ * If a class utilizing NullSaveEvaluation override [[nullable]],
probably should also
+ * override this.
+ */
+ override def eval(input: InternalRow): Any =
+ {
+ val values = children.map(_.eval(input))
+ if (values.contains(null)) null
+ else nullSafeEval(values)
+ }
+
+ /**
+ * Called by default [[eval]] implementation. If a class utilizing
NullSaveEvaluation keep
+ * the default nullability, they can override this method to save
null-check code. If we need
+ * full control of evaluation process, we should override [[eval]].
+ */
+ protected def nullSafeEval(inputs: Seq[Any]): Any =
+ sys.error(s"The class utilizing NullSaveEvaluation must override
either eval or nullSafeEval")
+
+ /**
+ * Short hand for generating of null save evaluation code.
+ * If either of the sub-expressions is null, the result of this
computation
+ * is assumed to be null.
+ *
+ * @param f accepts a sequence of variable names and returns Java code
to compute the output.
+ */
+ protected def defineCodeGen(
+ ctx: CodegenContext,
+ ev: ExprCode,
+ f: Seq[String] => String): ExprCode = {
+ nullSafeCodeGen(ctx, ev, values => {
+ s"${ev.value} = ${f(values)};"
+ })
+ }
+
+ /**
+ * Called by expressions to generate null safe evaluation code.
+ * If either of the sub-expressions is null, the result of this
computation
+ * is assumed to be null.
+ *
+ * @param f a function that accepts a sequence of non-null evaluation
result names of children
+ * and returns Java code to compute the output.
+ */
+ protected def nullSafeCodeGen(
+ ctx: CodegenContext,
+ ev: ExprCode,
+ f: Seq[String] => String): ExprCode = {
+ val gens = children.map(_.genCode(ctx))
+ val resultCode = f(gens.map(_.value))
+
+ if (nullable) {
+ val nullSafeEval =
+ (s"""
+ ${ev.isNull} = false; // resultCode could change nullability.
+ $resultCode
+ """ /: children.zip(gens)) {
--- End diff --
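Use `foldLeft` instead of the symbolic `/:` operator for readability: `/:` is only an alias for `foldLeft`, it is easy to misread, and it is deprecated as of Scala 2.13.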
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]