Github user ueshin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21954#discussion_r208282941
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala ---
    @@ -0,0 +1,166 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.catalyst.analysis
    +
    +import java.util.Locale
    +
    +import org.apache.spark.sql.catalyst.catalog.SessionCatalog
    +import org.apache.spark.sql.catalyst.expressions._
    +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    +import org.apache.spark.sql.catalyst.rules.Rule
    +import org.apache.spark.sql.internal.SQLConf
    +import org.apache.spark.sql.types.DataType
    +
    +/**
    + * Resolves higher order functions from the catalog. This differs from regular function
    + * resolution: a lambda function can only be resolved after the function receiving it has been
    + * resolved. So an [[UnresolvedFunction]] is looked up here as soon as each of its arguments is
    + * either resolved or a lambda function, and at least one argument is a lambda function.
    + */
    +case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] {
    +
    +  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
    +    case q: LogicalPlan => q.transformExpressions {
    +      case u @ UnresolvedFunction(fn, children, false)
    +          if hasLambdaAndResolvedArguments(children) =>
    +        withPosition(u) {
    +          catalog.lookupFunction(fn, children) match {
    +            case higherOrder: HigherOrderFunction =>
    +              higherOrder
    +            case notHigherOrder =>
    +              // Lambda arguments are only meaningful to a higher order function, so any other
    +              // kind of function found in the catalog is an analysis error.
    +              notHigherOrder.failAnalysis(
    +                "A lambda function should only be used in a higher order function. However, " +
    +                  s"its class is ${notHigherOrder.getClass.getCanonicalName}, which is not a " +
    +                  "higher order function.")
    +          }
    +        }
    +    }
    +  }
    +
    +  /**
    +   * Returns true when `expressions` contains at least one [[LambdaFunction]] and every
    +   * expression that is not a [[LambdaFunction]] is resolved.
    +   */
    +  private def hasLambdaAndResolvedArguments(expressions: Seq[Expression]): Boolean = {
    +    val containsLambda = expressions.exists {
    +      case _: LambdaFunction => true
    +      case _ => false
    +    }
    +    // Lambda functions are exempt from the resolution check; they are resolved later.
    +    containsLambda && expressions.forall {
    +      case _: LambdaFunction => true
    +      case other => other.resolved
    +    }
    +  }
    +}
    +
    +/**
    + * Resolve the lambda variables exposed by a higher order function.
    + *
    + * This rule works in two steps:
    + * [1]. Bind the anonymous variables exposed by the higher order function to the lambda function's
    + *      arguments; this creates named and typed lambda variables. The argument names are checked
    + *      for duplicates and the number of arguments is checked during this step.
    + * [2]. Resolve the lambda variables used in the lambda function's function expression tree.
    + *      Note that we allow the use of variables from outside the current lambda; this can either
    + *      be a lambda variable defined in an outer scope, or an attribute produced by the plan's
    + *      child. If names are duplicated, the name defined in the innermost scope is used.
    + */
    +case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] 
{
    +
    +  /** Registry of the lambda variables visible in a scope, keyed by variable name. */
    +  type LambdaVariableMap = Map[String, NamedExpression]
    +
    +  /**
    +   * Normalizes a lambda argument name before names are compared for duplicates. When analysis
    +   * is case insensitive (`conf.caseSensitiveAnalysis` is false), names are lower-cased with
    +   * `Locale.ROOT`. This keeps the result independent of the JVM's default locale; for example,
    +   * a Turkish locale would map 'I' to a dotless 'ı'. Otherwise names are returned unchanged.
    +   */
    +  private val canonicalizer: String => String = {
    +    if (!conf.caseSensitiveAnalysis) {
    +      s: String => s.toLowerCase(Locale.ROOT)
    +    } else {
    +      s: String => s
    +    }
    +  }
    +
    +  /**
    +   * Resolves lambda variables in the expressions of the plan's operators (traversed with
    +   * `resolveOperators`). Each top-level expression starts with an empty variable registry.
    +   */
    +  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
    +    case q: LogicalPlan =>
    +      q.mapExpressions(expr => resolve(expr, Map.empty))
    +  }
    +
    +  /**
    +   * Produces a bound [[LambdaFunction]] for a higher order function argument. The bound function
    +   * uses the (dataType, nullable) pairs in `partialArguments`, one per variable the higher order
    +   * function supplies.
    +   *
    +   *  - A lambda function that is already bound is returned as is; we assume it was bound to the
    +   *    correct arguments.
    +   *  - An unbound lambda function has its argument names paired, in order, with
    +   *    `partialArguments` to form typed [[NamedLambdaVariable]]s. Analysis fails if the number
    +   *    of names differs from `partialArguments.size`. It also fails if two names are equal after
    +   *    canonicalization.
    +   *  - Any other expression is wrapped in a hidden lambda function with variables named
    +   *    `col0`, `col1`, ...
    +   */
    +  private def createLambda(
    +      e: Expression,
    +      partialArguments: Seq[(DataType, Boolean)]): LambdaFunction = e match {
    +    case lambda: LambdaFunction if lambda.bound =>
    +      lambda
    +
    +    case LambdaFunction(function, names, _) =>
    +      val expectedCount = partialArguments.size
    +      if (names.size != expectedCount) {
    +        e.failAnalysis(
    +          s"The number of lambda function arguments '${names.size}' does not " +
    +            "match the number of arguments expected by the higher order function " +
    +            s"'$expectedCount'.")
    +      }
    +
    +      val hasDuplicateNames = {
    +        val canonicalNames = names.map(argument => canonicalizer(argument.name))
    +        canonicalNames.distinct.size < canonicalNames.size
    +      }
    +      if (hasDuplicateNames) {
    +        e.failAnalysis(
    +          "Lambda function arguments should not have names that are semantically the same.")
    +      }
    +
    +      val typedArguments = partialArguments.zip(names).map {
    +        case ((argumentType, argumentNullable), argument) =>
    +          NamedLambdaVariable(argument.name, argumentType, argumentNullable)
    +      }
    +      LambdaFunction(function, typedArguments)
    +
    +    case _ =>
    +      // The expression is independent of the lambda's arguments, but the higher order function
    +      // still expects a lambda function, so wrap it in one with default parameters. The
    +      // variables are hidden to prevent accidental naming collisions.
    +      val hiddenArguments = partialArguments.zipWithIndex.map {
    +        case ((argumentType, argumentNullable), position) =>
    +          NamedLambdaVariable(s"col$position", argumentType, argumentNullable)
    +      }
    +      LambdaFunction(e, hiddenArguments, hidden = true)
    +  }
    +
    +  /**
    +   * Resolve lambda variables in the expression subtree, using the passed 
lambda variable registry.
    +   */
    +  private def resolve(e: Expression, parentLambdaMap: LambdaVariableMap): 
Expression = e match {
    +    case _ if e.resolved => e
    +
    +    case h: HigherOrderFunction if h.inputResolved =>
    --- End diff --
    
    Let me think about it later.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to