Github user cloud-fan commented on a diff in the pull request:
https://github.com/apache/spark/pull/22468#discussion_r225521397
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala
---
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData,
GenericArrayData, MapData}
+import org.apache.spark.sql.types._
+
+
+/**
+ * An interpreted version of a safe projection.
+ *
+ * @param expressions the expressions that produce the resulting fields. These
+ * expressions must be bound to a schema.
+ */
+class InterpretedSafeProjection(expressions: Seq[Expression]) extends
Projection {
+
+ // Row whose field types are taken from the data types of all the expressions.
+ private[this] val mutableRow = new SpecificInternalRow(expressions.map(_.dataType))
+
+ // Pairs every expression other than NoOp with a function that handles its evaluated value.
+ // The ordinal comes from the position in the full expression list, so skipping NoOp entries
+ // does not shift the ordinals of the remaining expressions.
+ private[this] val exprsWithWriters = expressions.zipWithIndex.filterNot {
+ case (NoOp, _) => true
+ case _ => false
+ }.map { case (expr, ordinal) =>
+ val toSafe = generateSafeValueConverter(expr.dataType)
+ // NOTE(review): generateRowWriter is defined outside this hunk; presumably the returned
+ // function stores the value into mutableRow at `ordinal` -- confirm against its definition.
+ val store = generateRowWriter(ordinal, expr.dataType)
+ val handle = if (expr.nullable) {
+ // A nullable expression may evaluate to null: mark the field as null instead of converting.
+ (value: Any) => {
+ if (value != null) {
+ store(toSafe(value))
+ } else {
+ mutableRow.setNullAt(ordinal)
+ }
+ }
+ } else {
+ // Non-nullable: convert and store without a null check.
+ (value: Any) => store(toSafe(value))
+ }
+ (expr, handle)
+ }
+
+ /**
+ * Returns true when `dataType` is one of BooleanType, ByteType, ShortType, IntegerType,
+ * LongType, FloatType or DoubleType, and false for every other data type.
+ */
+ private def isPrimitive(dataType: DataType): Boolean = dataType match {
+ case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
+ true
+ case _ =>
+ false
+ }
+
+ private def generateSafeValueConverter(dt: DataType): Any => Any = dt
match {
+ case ArrayType(elemType, _) =>
+ if (isPrimitive(elemType)) {
+ v => {
+ val arrayValue = v.asInstanceOf[ArrayData]
+ new GenericArrayData(arrayValue.toArray[Any](elemType))
+ }
+ } else {
+ val elementConverter = generateSafeValueConverter(elemType)
+ v => {
+ val arrayValue = v.asInstanceOf[ArrayData]
+ val result = new Array[Any](arrayValue.numElements())
+ arrayValue.foreach(elemType, (i, e) => {
+ result(i) = elementConverter(e)
+ })
+ new GenericArrayData(result)
+ }
+ }
+
+ case st: StructType =>
+ val fieldTypes = st.fields.map(_.dataType)
+ val fieldConverters = fieldTypes.map(generateSafeValueConverter)
+ v => {
+ val row = v.asInstanceOf[InternalRow]
+ val ar = new Array[Any](row.numFields)
+ var idx = 0
+ while (idx < row.numFields) {
+ ar(idx) = fieldConverters(idx)(row.get(idx, fieldTypes(idx)))
+ idx += 1
+ }
+ new GenericInternalRow(ar)
+ }
+
+ case MapType(keyType, valueType, _) =>
+ lazy val keyConverter = generateSafeValueConverter(keyType)
+ lazy val valueConverter = generateSafeValueConverter(valueType)
+ v => {
+ val mapValue = v.asInstanceOf[MapData]
+ val keys = mapValue.keyArray().toArray[Any](keyType)
+ val values = mapValue.valueArray().toArray[Any](valueType)
+ val convertedKeys =
+ if (isPrimitive(keyType)) keys else keys.map(keyConverter)
--- End diff --
ditto
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]