Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20750#discussion_r174890611
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala ---
    @@ -0,0 +1,366 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.spark.sql.catalyst.expressions
    +
    +import org.apache.spark.SparkException
    +import org.apache.spark.sql.catalyst.InternalRow
    +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, 
UnsafeArrayWriter, UnsafeRowWriter, UnsafeWriter}
    +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
    +import org.apache.spark.sql.types.{UserDefinedType, _}
    +import org.apache.spark.unsafe.Platform
    +
    +/**
    + * An interpreted unsafe projection. This class reuses the [[UnsafeRow]] it produces, so a
    + * consumer should copy the row if it is being buffered. This class is not thread safe: the
    + * evaluation buffer, the [[BufferHolder]] and the result row are all reused across calls.
    + *
    + * @param expressions the expressions that produce the resulting fields. These expressions
    + *                    must be bound to a schema.
    + */
    +class InterpretedUnsafeProjection(expressions: Array[Expression]) extends 
UnsafeProjection {
    +  // Brings the companion object's writer factories (e.g. generateStructWriter) into scope.
    +  import InterpretedUnsafeProjection._
    +
    +  /** Number of (top level) fields in the resulting row. */
    +  private[this] val numFields = expressions.length
    +
    +  /** Array that holds the expression results; overwritten on every call to `apply`. */
    +  private[this] val values = new Array[Any](numFields)
    +
    +  /**
    +   * The row representing the expression results. It is backed by `values`, so it always
    +   * reflects the most recently evaluated input row.
    +   */
    +  private[this] val intermediate = new GenericInternalRow(values)
    +
    +  /** The row returned by the projection; the same instance is returned by every call. */
    +  private[this] val result = new UnsafeRow(numFields)
    +
    +  /** The buffer which holds the resulting row's backing data; reset before each row. */
    +  private[this] val holder = new BufferHolder(result, numFields * 32)
    +
    +  /**
    +   * The writer that writes the intermediate result to the result row. It is assembled once,
    +   * from one field writer per expression chosen by the expression's data type and nullability.
    +   */
    +  private[this] val writer: InternalRow => Unit = {
    +    val rowWriter = new UnsafeRowWriter(holder, numFields)
    +    val baseWriter = generateStructWriter(
    +      holder,
    +      rowWriter,
    +      expressions.map(e => StructField("", e.dataType, e.nullable)))
    +    if (!expressions.exists(_.nullable)) {
    +      // No nullable fields. The top-level null bit mask will always be zeroed out.
    +      baseWriter
    +    } else {
    +      // Nullable fields can set null bits, and the result row is reused across calls, so
    +      // zero out the null bit mask before we write each row.
    +      row => {
    +        rowWriter.zeroOutNullBytes()
    +        baseWriter(row)
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Visits the nodes of each expression tree (via `foreach`) and initializes every
    +   * [[Nondeterministic]] node with the given partition index.
    +   *
    +   * @param partitionIndex the index of the partition this projection will process.
    +   */
    +  override def initialize(partitionIndex: Int): Unit = {
    +    expressions.foreach(_.foreach {
    +      case n: Nondeterministic => n.initialize(partitionIndex)
    +      case _ =>
    +    })
    +  }
    +
    +  /**
    +   * Evaluates all expressions against `row` and writes the results into an [[UnsafeRow]].
    +   *
    +   * @param row the input row that the (bound) expressions are evaluated against.
    +   * @return the projected row. The same instance is returned and overwritten by every call,
    +   *         so copy it if it must be retained.
    +   */
    +  override def apply(row: InternalRow): UnsafeRow = {
    +    // Put the expression results in the intermediate row.
    +    var i = 0
    +    while (i < numFields) {
    +      values(i) = expressions(i).eval(row)
    +      i += 1
    +    }
    +
    +    // Write the intermediate row to an unsafe row, reusing the shared buffer.
    +    holder.reset()
    +    writer(intermediate)
    +    // Record the number of bytes written as the result row's size.
    +    result.setTotalSize(holder.totalSize())
    +    result
    +  }
    +}
    +
    +/**
    + * Helper functions for creating an [[InterpretedUnsafeProjection]].
    + */
    +object InterpretedUnsafeProjection extends UnsafeProjectionCreator {
    +
    +  /**
    +   * Returns an [[UnsafeProjection]] for given sequence of bound 
Expressions.
    +   */
    +  override protected def createProjection(exprs: Seq[Expression]): 
UnsafeProjection = {
    +    // We need to make sure that we do not reuse stateful expressions.
    +    val cleanedExpressions = exprs.map(_.transform {
    +      case s: Stateful => s.freshCopy()
    +    })
    +    new InterpretedUnsafeProjection(cleanedExpressions.toArray)
    +  }
    +
    +  /**
    +   * Generate a struct writer function. The generated function writes an 
[[InternalRow]] to the
    +   * given buffer using the given [[UnsafeRowWriter]].
    +   */
    +  private def generateStructWriter(
    +      bufferHolder: BufferHolder,
    +      rowWriter: UnsafeRowWriter,
    +      fields: Array[StructField]): InternalRow => Unit = {
    +    val numFields = fields.length
    +
    +    // Create field writers.
    +    val fieldWriters = fields.map { field =>
    +      generateFieldWriter(bufferHolder, rowWriter, field.dataType, 
field.nullable)
    +    }
    +    // Create basic writer.
    +    row => {
    +      var i = 0
    +      while (i < numFields) {
    +        fieldWriters(i).apply(row, i)
    +        i += 1
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Generate a writer function for a struct field, array element, map key 
or map value. The
    +   * generated function writes the element at an index in a 
[[SpecializedGetters]] object (row
    +   * or array) to the given buffer using the given [[UnsafeWriter]].
    +   */
    +  private def generateFieldWriter(
    +      bufferHolder: BufferHolder,
    +      writer: UnsafeWriter,
    +      dt: DataType,
    +      nullable: Boolean): (SpecializedGetters, Int) => Unit = {
    +
    +    // Create the the basic writer.
    +    val unsafeWriter: (SpecializedGetters, Int) => Unit = dt match {
    +      case BooleanType =>
    +        (v, i) => writer.write(i, v.getBoolean(i))
    +
    +      case ByteType =>
    +        (v, i) => writer.write(i, v.getByte(i))
    +
    +      case ShortType =>
    +        (v, i) => writer.write(i, v.getShort(i))
    +
    +      case IntegerType | DateType =>
    +        (v, i) => writer.write(i, v.getInt(i))
    +
    +      case LongType | TimestampType =>
    +        (v, i) => writer.write(i, v.getLong(i))
    +
    +      case FloatType =>
    +        (v, i) => writer.write(i, v.getFloat(i))
    +
    +      case DoubleType =>
    +        (v, i) => writer.write(i, v.getDouble(i))
    +
    +      case DecimalType.Fixed(precision, scale) =>
    +        (v, i) => writer.write(i, v.getDecimal(i, precision, scale), 
precision, scale)
    +
    +      case CalendarIntervalType =>
    +        (v, i) => writer.write(i, v.getInterval(i))
    +
    +      case BinaryType =>
    +        (v, i) => writer.write(i, v.getBinary(i))
    +
    +      case StringType =>
    +        (v, i) => writer.write(i, v.getUTF8String(i))
    +
    +      case StructType(fields) =>
    +        val numFields = fields.length
    +        val rowWriter = new UnsafeRowWriter(bufferHolder, numFields)
    +        val structWriter = generateStructWriter(bufferHolder, rowWriter, 
fields)
    +        (v, i) => {
    +          val tmpCursor = bufferHolder.cursor
    +          v.getStruct(i, fields.length) match {
    +            case row: UnsafeRow =>
    +              writeUnsafeData(
    +                bufferHolder,
    +                row.getBaseObject,
    +                row.getBaseOffset,
    +                row.getSizeInBytes)
    +            case row =>
    +              // Nested struct. We don't know where this will start 
because a row can be
    +              // variable length, so we need to update the offsets and 
zero out the bit mask.
    +              rowWriter.reset()
    +              structWriter.apply(row)
    +          }
    +          writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - 
tmpCursor)
    +        }
    +
    +      case ArrayType(elementType, containsNull) =>
    +        val arrayWriter = new UnsafeArrayWriter
    +        val elementSize = getElementSize(elementType)
    +        val elementWriter = generateFieldWriter(
    +          bufferHolder,
    +          arrayWriter,
    +          elementType,
    +          containsNull)
    +        (v, i) => {
    +          val tmpCursor = bufferHolder.cursor
    +          writeArray(bufferHolder, arrayWriter, elementWriter, 
v.getArray(i), elementSize)
    +          writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - 
tmpCursor)
    +        }
    +
    +      case MapType(keyType, valueType, valueContainsNull) =>
    +        val keyArrayWriter = new UnsafeArrayWriter
    +        val keySize = getElementSize(keyType)
    +        val keyWriter = generateFieldWriter(
    +          bufferHolder,
    +          keyArrayWriter,
    +          keyType,
    +          nullable = false)
    +        val valueArrayWriter = new UnsafeArrayWriter
    +        val valueSize = getElementSize(valueType)
    +        val valueWriter = generateFieldWriter(
    +          bufferHolder,
    +          valueArrayWriter,
    +          valueType,
    +          valueContainsNull)
    +        (v, i) => {
    +          val tmpCursor = bufferHolder.cursor
    +          v.getMap(i) match {
    +            case map: UnsafeMapData =>
    +              writeUnsafeData(
    +                bufferHolder,
    +                map.getBaseObject,
    +                map.getBaseOffset,
    +                map.getSizeInBytes)
    +            case map =>
    +              // preserve 8 bytes to write the key array numBytes later.
    +              bufferHolder.grow(8)
    +              bufferHolder.cursor += 8
    +
    +              // Write the keys and write the numBytes of key array into 
the first 8 bytes.
    +              writeArray(bufferHolder, keyArrayWriter, keyWriter, 
map.keyArray(), keySize)
    +              Platform.putLong(bufferHolder.buffer, tmpCursor, 
bufferHolder.cursor - tmpCursor - 8)
    +
    +              // Write the values.
    +              writeArray(bufferHolder, valueArrayWriter, valueWriter, 
map.valueArray(), valueSize)
    +          }
    +          writer.setOffsetAndSize(i, tmpCursor, bufferHolder.cursor - 
tmpCursor)
    +        }
    +
    +      case udt: UserDefinedType[_] =>
    +        generateFieldWriter(bufferHolder, writer, udt.sqlType, nullable)
    +
    +      case NullType =>
    +        (_, _) => {}
    +
    +      case _ =>
    +        throw new SparkException(s"Unsupported data type $dt")
    +    }
    +
    +    // Always wrap the writer with a null safe version.
    +    dt match {
    +      case _: UserDefinedType[_] =>
    +        // The null wrapper depends on the sql type and not on the UDT.
    +        unsafeWriter
    +      case DecimalType.Fixed(precision, _) if precision > 
Decimal.MAX_LONG_DIGITS =>
    +        // We can't call setNullAt() for DecimalType with precision larger 
than 18, we call write
    +        // directly. We can use the unwrapped writer directly.
    +        unsafeWriter
    +      case BooleanType | ByteType =>
    +        (v, i) => {
    +          if (!v.isNullAt(i)) {
    +            unsafeWriter(v, i)
    +          } else {
    +            writer.setNull1Bytes(i)
    +          }
    +        }
    +      case ShortType =>
    +        (v, i) => {
    +          if (!v.isNullAt(i)) {
    +            unsafeWriter(v, i)
    +          } else {
    +            writer.setNull2Bytes(i)
    +          }
    +        }
    +      case IntegerType | DateType | FloatType =>
    +        (v, i) => {
    +          if (!v.isNullAt(i)) {
    +            unsafeWriter(v, i)
    +          } else {
    +            writer.setNull4Bytes(i)
    +          }
    +        }
    +      case _ =>
    +        (v, i) => {
    +          if (!v.isNullAt(i)) {
    +            unsafeWriter(v, i)
    +          } else {
    +            writer.setNull8Bytes(i)
    --- End diff --
    
    `NullType` does not match any of the cases above, so it will hit this default branch. Can we add a test to make sure it works?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to