Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20750#discussion_r174873582
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
 ---
    @@ -0,0 +1,366 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.spark.sql.catalyst.expressions
    +
    +import org.apache.spark.SparkException
    +import org.apache.spark.sql.catalyst.InternalRow
    +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, 
UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter}
    +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
    +import org.apache.spark.sql.types.{UserDefinedType, _}
    +import org.apache.spark.unsafe.Platform
    +
    +/**
    + * An interpreted unsafe projection. This class reuses the [[UnsafeRow]] 
it produces, so a consumer
    + * should copy the row if it is being buffered. This class is not thread 
safe.
    + *
    + * @param expressions that produce the resulting fields. These 
expressions must be bound
    + *                    to a schema.
    + */
    +class InterpretedUnsafeProjection(expressions: Array[Expression]) extends 
UnsafeProjection {
    +  import InterpretedUnsafeProjection._
    +
    +  /** Number of (top level) fields in the resulting row. */
    +  private[this] val numFields = expressions.length
    +
    +  /** Array that holds the expression results. */
    +  private[this] val values = new Array[Any](numFields)
    +
    +  /** The row representing the expression results. */
    +  private[this] val intermediate = new GenericInternalRow(values)
    +
    +  /** The row returned by the projection. */
    +  private[this] val result = new UnsafeRow(numFields)
    +
    +  /** The buffer which holds the resulting row's backing data. */
    +  private[this] val holder = new BufferHolder(result, numFields * 32)
    +
    +  /** The writer that writes the intermediate result to the result row. */
    +  private[this] val writer: InternalRow => Unit = {
    +    val rowWriter = new UnsafeRowWriter(holder, numFields)
    +    val baseWriter = generateStructWriter(
    +      holder,
    +      rowWriter,
    +      expressions.map(e => StructField("", e.dataType, e.nullable)))
    +    if (!expressions.exists(_.nullable)) {
    +      // No nullable fields. The top-level null bit mask will always be 
zeroed out.
    +      baseWriter
    +    } else {
    +      // Zero out the null bit mask before we write the row.
    +      row => {
    +        rowWriter.zeroOutNullBytes()
    +        baseWriter(row)
    +      }
    +    }
    +  }
    +
    +  override def initialize(partitionIndex: Int): Unit = {
    +    expressions.foreach(_.foreach {
    +      case n: Nondeterministic => n.initialize(partitionIndex)
    +      case _ =>
    +    })
    +  }
    +
    +  override def apply(row: InternalRow): UnsafeRow = {
    +    // Put the expression results in the intermediate row.
    +    var i = 0
    +    while (i < numFields) {
    +      values(i) = expressions(i).eval(row)
    +      i += 1
    +    }
    +
    +    // Write the intermediate row to an unsafe row.
    +    holder.reset()
    +    writer(intermediate)
    +    result.setTotalSize(holder.totalSize())
    +    result
    +  }
    +}
    +
    +/**
    + * Helper functions for creating an [[InterpretedUnsafeProjection]].
    + */
    +object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
    +
    +  /**
    +   * Returns an [[UnsafeProjection]] for the given sequence of bound 
Expressions.
    +   */
    +  override protected def createProjection(exprs: Seq[Expression]): 
UnsafeProjection = {
    +    // We need to make sure that we do not reuse stateful expressions.
    +    val cleanedExpressions = exprs.map(_.transform {
    +      case s: Stateful => s.freshCopy()
    +    })
    +    new InterpretedUnsafeProjection(cleanedExpressions.toArray)
    +  }
    +
    +  /**
    +   * Generate a struct writer function. The generated function writes an 
[[InternalRow]] to the
    +   * given buffer using the given [[UnsafeRowWriter]].
    +   */
    +  private def generateStructWriter(
    +      bufferHolder: BufferHolder,
    +      rowWriter: UnsafeRowWriter,
    --- End diff --
    
    We can add `UnsafeWriter#getBufferHolder`, so that we don't need to pass 
both `bufferHolder` and `rowWriter` as separate parameters.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to