Github user yijieshen commented on a diff in the pull request:

    https://github.com/apache/spark/pull/7365#discussion_r34538299
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
 ---
    @@ -421,48 +421,515 @@ case class Cast(child: Expression, dataType: 
DataType) extends UnaryExpression w
     
       protected override def nullSafeEval(input: Any): Any = cast(input)
     
    +  private[this] class CodeHolder private() {
    +    // expression that can be directly assigned to result's primitive
    +    // similar to f in `defineCodeGen`
    +    private var _cg: String => String = null
    +    // statements to put in null safety section
    +    // similar to f in `nullSafeCodeGen`
    +    private var _ns: (String, String, String) => String = null
    +
    +    // child.primitive
    +    def set(f: String => String): CodeHolder = {_cg = f; this}
    +
    +    // child.primitive, result.primitive, result.isNull
    +    def set(f: (String, String, String) => String): CodeHolder = {_ns = f; 
this}
    +
    +    def code(ctx: CodeGenContext, childPrim: String, childNull: String,
    +      resultPrim: String, resultNull: String, resultType: DataType): 
String = {
    +      if (_cg != null) {
    +        s"""
    +          boolean $resultNull = $childNull;
    +          ${ctx.javaType(resultType)} $resultPrim = 
${ctx.defaultValue(resultType)};
    +          if (!${childNull}) {
    +            $resultPrim = ${_cg(childPrim)};
    +          }
    +        """
    +      } else {
    +        s"""
    +          boolean $resultNull = $childNull;
    +          ${ctx.javaType(resultType)} $resultPrim = 
${ctx.defaultValue(resultType)};
    +          if (!${childNull}) {
    +            ${_ns(childPrim, resultPrim, resultNull)}
    +          }
    +        """
    +      }
    +    }
    +  }
    +
    +  private[this] object CodeHolder {
    +    def apply(f: String => String): CodeHolder = (new CodeHolder).set(f)
    +    def apply(f: (String, String, String) => String): CodeHolder = (new 
CodeHolder).set(f)
    +  }
    +
    +  private[this] def getCodeHolder(from: DataType, to: DataType, ctx: 
CodeGenContext) = to match {
    +    case StringType => castToStringCode(from, ctx)
    +    case BinaryType => castToBinaryCode(from)
    +    case DateType => castToDateCode(from)
    +    case decimal: DecimalType => castToDecimalCode(from, decimal)
    +    case TimestampType => castToTimestampCode(from)
    +    case BooleanType => castToBooleanCode(from)
    +    case ByteType => castToByteCode(from)
    +    case ShortType => castToShortCode(from)
    +    case IntegerType => castToIntCode(from)
    +    case FloatType => castToFloatCode(from)
    +    case LongType => castToLongCode(from)
    +    case DoubleType => castToDoubleCode(from)
    +
    +    case array: ArrayType => castArrayCode(from.asInstanceOf[ArrayType], 
array, ctx)
    +    case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
    +    case struct: StructType => 
castStructCode(from.asInstanceOf[StructType], struct, ctx)
    +    case other => null
    +  }
    +
       override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
    -    // TODO: Add support for more data types.
    -    (child.dataType, dataType) match {
    +    val eval = child.gen(ctx)
    +    val holder = getCodeHolder(child.dataType, dataType, ctx)
    +    if (holder != null) {
    +      eval.code + holder.code(ctx, eval.primitive, eval.isNull, 
ev.primitive, ev.isNull, dataType)
    +    } else {
    +      super.genCode(ctx, ev)
    +    }
    +  }
    +
    +  private[this] def castToStringCode(from: DataType, ctx: CodeGenContext): 
CodeHolder = {
    +    from match {
    +      case BinaryType =>
    +        CodeHolder(c => s"${ctx.stringType}.fromBytes($c)")
    +      case DateType =>
    +        CodeHolder(c => s"""${ctx.stringType}.fromString(
    +        
org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""")
    +      case TimestampType =>
    +        CodeHolder(c => s"""${ctx.stringType}.fromString(
    +        
org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""")
    +      case _ =>
    +        CodeHolder(c => 
s"${ctx.stringType}.fromString(String.valueOf($c))")
    +    }
    +  }
    +
    +  private[this] def castToBinaryCode(from: DataType): CodeHolder = from 
match {
    +    case StringType =>
    +      CodeHolder(c => s"$c.getBytes()")
    +  }
    +
    +  private[this] def castToDateCode(from: DataType): CodeHolder = from 
match {
    +    case StringType =>
    +      CodeHolder((c, evPrim, evNull) => s"""
    +        try {
    +          $evPrim = 
org.apache.spark.sql.catalyst.util.DateTimeUtils.fromJavaDate(
    +            java.sql.Date.valueOf($c.toString()));
    +        } catch (java.lang.IllegalArgumentException e) {
    +         $evNull = true;
    +        }
    +       """)
    +    case TimestampType =>
    +      CodeHolder(c => 
s"org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L)")
    +    case _ =>
    +      CodeHolder((c, evPrim, evNull) => s"$evNull = true;")
    +  }
    +
    +  private[this] def changePrecision(d: String, decimalType: DecimalType,
    +      evPrim: String, evNull: String): String = {
    +    decimalType match {
    +      case DecimalType.Unlimited =>
    +        s"$evPrim = $d;"
    +      case DecimalType.Fixed(precision, scale) =>
    +        s"""
    +          if ($d.changePrecision($precision, $scale)) {
    +            $evPrim = $d;
    +          } else {
    +            $evNull = true;
    +          }
    +        """
    +    }
    +  }
    +
    +  private[this] def castToDecimalCode(from: DataType, target: 
DecimalType): CodeHolder = {
    +    from match {
    +      case StringType =>
    +        CodeHolder((c, evPrim, evNull) =>
    +          s"""
    +            try {
    +              org.apache.spark.sql.types.Decimal tmpDecimal =
    +                new org.apache.spark.sql.types.Decimal().set(
    +                  new scala.math.BigDecimal(
    +                    new java.math.BigDecimal($c.toString())));
    +              ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +            } catch (java.lang.NumberFormatException e) {
    +              $evNull = true;
    +            }
    +          """)
    +      case BooleanType =>
    +        CodeHolder((c, evPrim, evNull) =>
    +          s"""
    +            org.apache.spark.sql.types.Decimal tmpDecimal = null;
    +            if ($c) {
    +              tmpDecimal = new org.apache.spark.sql.types.Decimal().set(1);
    +            } else {
    +              tmpDecimal = new org.apache.spark.sql.types.Decimal().set(0);
    +            }
    +            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +          """)
    +      case DateType =>
    +        // date can't cast to decimal in Hive
    +        CodeHolder((c, evPrim, evNull) => s"$evNull = true;")
    +      case TimestampType =>
    +        // Note that we lose precision here.
    +        CodeHolder((c, evPrim, evNull) =>
    +          s"""
    +            org.apache.spark.sql.types.Decimal tmpDecimal =
    +              new org.apache.spark.sql.types.Decimal().set(
    +                
scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)}));
    +            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +          """)
    +      case DecimalType() =>
    +        CodeHolder((c, evPrim, evNull) =>
    +          s"""
    +            org.apache.spark.sql.types.Decimal tmpDecimal = $c.clone();
    +            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +          """)
    +      case LongType =>
    +        CodeHolder((c, evPrim, evNull) =>
    +          s"""
    +            org.apache.spark.sql.types.Decimal tmpDecimal =
    +              new org.apache.spark.sql.types.Decimal().set($c);
    +            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +          """)
    +      case x: NumericType =>
    +        // All other numeric types can be represented precisely as Doubles
    +        CodeHolder((c, evPrim, evNull) =>
    +          s"""
    +            try {
    +              org.apache.spark.sql.types.Decimal tmpDecimal =
    +                new org.apache.spark.sql.types.Decimal().set(
    +                  scala.math.BigDecimal.valueOf((double) $c));
    +              ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +            } catch (java.lang.NumberFormatException e) {
    +              $evNull = true;
    +            }
    +          """
    +        )
    +    }
    +  }
    +
    +  private[this] def castToTimestampCode(from: DataType): CodeHolder = from 
match {
    +    case StringType =>
    +      CodeHolder((c, evPrim, evNull) =>
    +        s"""
    +          try {
    +            $evPrim = 
org.apache.spark.sql.catalyst.util.DateTimeUtils.fromJavaTimestamp(
    +              java.sql.Timestamp.valueOf($c.toString()));
    +          } catch (java.lang.IllegalArgumentException e) {
    +            $evNull = true;
    +          }
    +         """
    +      )
    +    case BooleanType =>
    +      CodeHolder(c => s"$c ? 1L : 0")
    +    case _: IntegralType =>
    +      CodeHolder(c => longToTimeStampCode(c))
    +    case DateType =>
    +      CodeHolder(c => 
s"org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c) * 1000")
    +    case DecimalType() =>
    +      CodeHolder(c => decimalToTimestampCode(c))
    +    case DoubleType =>
    +      CodeHolder((c, evPrim, evNull) =>
    +        s"""
    +          if (Double.isNaN($c) || Double.isInfinite($c)) {
    +            $evNull = true;
    +          } else {
    +            $evPrim = (long)($c * 1000000L);
    +          }
    +        """
    +      )
    +    case FloatType =>
    +      CodeHolder((c, evPrim, evNull) =>
    +        s"""
    +          if (Float.isNaN($c) || Float.isInfinite($c)) {
    +            $evNull = true;
    +          } else {
    +            $evPrim = (long)($c * 1000000L);
    +          }
    +        """)
    +  }
    +
    +  private[this] def decimalToTimestampCode(d: String): String =
    +    s"($d.toBigDecimal().bigDecimal().multiply(new 
java.math.BigDecimal(1000000L))).longValue()"
    +  private[this] def longToTimeStampCode(l: String): String = s"$l * 1000L"
    +  private[this] def timestampToIntegerCode(ts: String): String =
    +    s"java.lang.Math.floor((double) $ts / 1000000L)"
    +  private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 
1000000.0"
     
    -      case (BinaryType, StringType) =>
    -        defineCodeGen (ctx, ev, c =>
    -          s"${ctx.stringType}.fromBytes($c)")
    +  private[this] def castToBooleanCode(from: DataType): CodeHolder = from 
match {
    +    case StringType =>
    +      CodeHolder(c => s"$c.numBytes() != 0")
    +    case TimestampType =>
    +      CodeHolder(c => s"$c != 0")
    +    case DateType =>
    +      // Hive would return null when cast from date to boolean
    +      CodeHolder((c, evPrim, evNull) => s"$evNull = true;")
    +    case DecimalType() =>
    +      CodeHolder(c => s"!$c.isZero()")
    +    case n: NumericType =>
    +      CodeHolder(c => s"$c != 0")
    +  }
     
    -      case (DateType, StringType) =>
    -        defineCodeGen(ctx, ev, c =>
    -          s"""${ctx.stringType}.fromString(
    -                
org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""")
    +  private[this] def castToByteCode(from: DataType): CodeHolder = from 
match {
    +    case StringType =>
    +      CodeHolder((c, evPrim, evNull) =>
    +        s"""
    +          try {
    +            $evPrim = Byte.valueOf($c.toString());
    +          } catch (java.lang.NumberFormatException e) {
    +            $evNull = true;
    +          }
    +        """)
    +    case BooleanType =>
    +      CodeHolder(c => s"$c ? 1 : 0")
    +    case DateType =>
    +      CodeHolder((c, evPrim, evNull) => s"$evNull = true;")
    +    case TimestampType =>
    +      CodeHolder(c => s"(byte) ${timestampToIntegerCode(c)}")
    +    case DecimalType() =>
    +      CodeHolder(c => s"$c.toByte()")
    +    case x: NumericType =>
    +      CodeHolder(c => s"(byte) $c")
    +  }
     
    -      case (TimestampType, StringType) =>
    -        defineCodeGen(ctx, ev, c =>
    -          s"""${ctx.stringType}.fromString(
    -                
org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""")
    +  private[this] def castToShortCode(from: DataType): CodeHolder = from 
match {
    +    case StringType =>
    +      CodeHolder((c, evPrim, evNull) =>
    +        s"""
    +          try {
    +            $evPrim = Short.valueOf($c.toString());
    +          } catch (java.lang.NumberFormatException e) {
    +            $evNull = true;
    +          }
    +        """)
    +    case BooleanType =>
    +      CodeHolder(c => s"$c ? 1 : 0")
    +    case DateType =>
    +      CodeHolder((c, evPrim, evNull) => s"$evNull = true;")
    +    case TimestampType =>
    +      CodeHolder(c => s"(short) ${timestampToIntegerCode(c)}")
    +    case DecimalType() =>
    +      CodeHolder(c => s"$c.toShort()")
    +    case x: NumericType =>
    +      CodeHolder(c => s"(short) $c")
    +  }
     
    -      case (_, StringType) =>
    -        defineCodeGen(ctx, ev, c => 
s"${ctx.stringType}.fromString(String.valueOf($c))")
    +  private[this] def castToIntCode(from: DataType): CodeHolder = from match 
{
    +    case StringType =>
    +      CodeHolder((c, evPrim, evNull) =>
    +        s"""
    +          try {
    +            $evPrim = Integer.valueOf($c.toString());
    +          } catch (java.lang.NumberFormatException e) {
    +            $evNull = true;
    +          }
    +        """)
    +    case BooleanType =>
    +      CodeHolder(c => s"$c ? 1 : 0")
    +    case DateType =>
    +      CodeHolder((c, evPrim, evNull) => s"$evNull = true;")
    +    case TimestampType =>
    +      CodeHolder(c => s"(int) ${timestampToIntegerCode(c)}")
    +    case DecimalType() =>
    +      CodeHolder(c => s"$c.toInt()")
    +    case x: NumericType =>
    +      CodeHolder(c => s"(int) $c")
    +  }
     
    -      // fallback for DecimalType, this must be before other numeric types
    -      case (_, dt: DecimalType) =>
    -        super.genCode(ctx, ev)
    +  private[this] def castToLongCode(from: DataType): CodeHolder = from 
match {
    +    case StringType =>
    +      CodeHolder((c, evPrim, evNull) =>
    +        s"""
    +          try {
    +            $evPrim = Long.valueOf($c.toString());
    +          } catch (java.lang.NumberFormatException e) {
    +            $evNull = true;
    +          }
    +        """)
    +    case BooleanType =>
    +      CodeHolder(c => s"$c ? 1 : 0")
    +    case DateType =>
    +      CodeHolder((c, evPrim, evNull) => s"$evNull = true;")
    +    case TimestampType =>
    +      CodeHolder(c => s"(long) ${timestampToIntegerCode(c)}")
    +    case DecimalType() =>
    +      CodeHolder(c => s"$c.toLong()")
    +    case x: NumericType =>
    +      CodeHolder(c => s"(long) $c")
    +  }
     
    -      case (BooleanType, dt: NumericType) =>
    -        defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
    +  private[this] def castToFloatCode(from: DataType): CodeHolder = from 
match {
    +    case StringType =>
    +      CodeHolder((c, evPrim, evNull) =>
    +        s"""
    +          try {
    +            $evPrim = Float.valueOf($c.toString());
    +          } catch (java.lang.NumberFormatException e) {
    +            $evNull = true;
    +          }
    +        """)
    +    case BooleanType =>
    +      CodeHolder(c => s"$c ? 1 : 0")
    +    case DateType =>
    +      CodeHolder((c, evPrim, evNull) => s"$evNull = true;")
    +    case TimestampType =>
    +      CodeHolder(c => s"(float) (${timestampToDoubleCode(c)})")
    +    case DecimalType() =>
    +      CodeHolder(c => s"$c.toFloat()")
    +    case x: NumericType =>
    +      CodeHolder(c => s"(float) $c")
    +  }
     
    -      case (dt: DecimalType, BooleanType) =>
    -        defineCodeGen(ctx, ev, c => s"!$c.isZero()")
    +  private[this] def castToDoubleCode(from: DataType): CodeHolder = from 
match {
    +    case StringType =>
    +      CodeHolder((c, evPrim, evNull) =>
    +        s"""
    +          try {
    +            $evPrim = Double.valueOf($c.toString());
    +          } catch (java.lang.NumberFormatException e) {
    +            $evNull = true;
    +          }
    +        """)
    +    case BooleanType =>
    +      CodeHolder(c => s"$c ? 1 : 0")
    +    case DateType =>
    +      CodeHolder((c, evPrim, evNull) => s"$evNull = true;")
    +    case TimestampType =>
    +      CodeHolder(c => timestampToDoubleCode(c))
    +    case DecimalType() =>
    +      CodeHolder(c => s"$c.toDouble()")
    +    case x: NumericType =>
    +      CodeHolder(c => s"(double) $c")
    +  }
     
    -      case (dt: NumericType, BooleanType) =>
    -        defineCodeGen(ctx, ev, c => s"$c != 0")
    +  private[this] def castArrayCode(
    +      from: ArrayType, to: ArrayType, ctx: CodeGenContext): CodeHolder = {
    +    val elementCodeHolder = getCodeHolder(from.elementType, 
to.elementType, ctx)
    +
    +    val arraySeqClass = "scala.collection.mutable.ArraySeq"
    +    val fromElementNull = ctx.freshName("feNull")
    +    val fromElementPrim = ctx.freshName("fePrim")
    +    val toElementNull = ctx.freshName("teNull")
    +    val toElementPrim = ctx.freshName("tePrim")
    +    val size = ctx.freshName("n")
    +    val j = ctx.freshName("j")
    +    val result = ctx.freshName("result")
    +
    +    CodeHolder((c, evPrim, evNull) =>
    +      s"""
    +        final int $size = $c.size();
    +        final $arraySeqClass<Object> $result = new 
$arraySeqClass<Object>($size);
    +        for (int $j = 0; $j < $size; $j ++) {
    +          if ($c.apply($j) == null) {
    +            $result.update($j, null);
    +          } else {
    +            boolean $fromElementNull = false;
    +            ${ctx.boxedType(from.elementType)} $fromElementPrim =
    +              (${ctx.boxedType(from.elementType)}) $c.apply($j);
    +            ${elementCodeHolder.code(ctx, fromElementPrim,
    +              fromElementNull, toElementPrim, toElementNull, 
to.elementType)}
    +            if ($toElementNull) {
    +              $result.update($j, null);
    +            } else {
    +              $result.update($j, $toElementPrim);
    +            }
    +          }
    +        }
    +        $evPrim = $result;
    +      """)
    +  }
     
    -      case (_: DecimalType, dt: NumericType) =>
    -        defineCodeGen(ctx, ev, c => 
s"($c).to${ctx.primitiveTypeName(dt)}()")
    +  private[this] def castMapCode(from: MapType, to: MapType, ctx: 
CodeGenContext): CodeHolder = {
    +    val keyCodeHolder = getCodeHolder(from.keyType, to.keyType, ctx)
    +    val valueCodeHolder = getCodeHolder(from.valueType, to.valueType, ctx)
    +
    +    val hashMapClass = "scala.collection.mutable.HashMap"
    +    val fromKeyPrim = ctx.freshName("fkp")
    +    val fromKeyNull = ctx.freshName("fkn")
    +    val fromValuePrim = ctx.freshName("fvp")
    +    val fromValueNull = ctx.freshName("fvn")
    +    val toKeyPrim = ctx.freshName("tkp")
    +    val toKeyNull = ctx.freshName("tkn")
    +    val toValuePrim = ctx.freshName("tvp")
    +    val toValueNull = ctx.freshName("tvn")
    +    val result = ctx.freshName("result")
    +
    +
    +    CodeHolder((c, evPrim, evNull) =>
    +      s"""
    +        final $hashMapClass $result = new $hashMapClass();
    +        scala.collection.Iterator iter = $c.iterator();
    +        while (iter.hasNext()) {
    +          scala.Tuple2 kv = (scala.Tuple2) iter.next();
    +          boolean $fromKeyNull = false;
    +          ${ctx.boxedType(from.keyType)} $fromKeyPrim =
    +            (${ctx.boxedType(from.keyType)}) kv._1();
    +          ${keyCodeHolder.code(ctx, fromKeyPrim,
    +            fromKeyNull, toKeyPrim, toKeyNull, to.keyType)}
    +
    +          boolean $fromValueNull = kv._2() == null;
    +          if ($fromValueNull) {
    +            $result.put($toKeyPrim, null);
    +          } else {
    +            ${ctx.boxedType(from.valueType)} $fromValuePrim =
    +              (${ctx.boxedType(from.valueType)}) kv._2();
    +            ${valueCodeHolder.code(ctx, fromValuePrim,
    +              fromValueNull, toValuePrim, toValueNull, to.valueType)}
    +            if ($toValueNull) {
    +              $result.put($toKeyPrim, null);
    +            } else {
    +              $result.put($toKeyPrim, $toValuePrim);
    +            }
    +          }
    +        }
    +        $evPrim = $result;
    +      """)
    +  }
     
    -      case (_: NumericType, dt: NumericType) =>
    -        defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)")
    +  private[this] def castStructCode(
    +      from: StructType, to: StructType, ctx: CodeGenContext): CodeHolder = 
{
     
    -      case other =>
    -        super.genCode(ctx, ev)
    +    val fieldsCodeHolder = from.fields.zip(to.fields).map {
    +      case (fromField, toField) => getCodeHolder(fromField.dataType, 
toField.dataType, ctx)
         }
    +    val rowClass = 
"org.apache.spark.sql.catalyst.expressions.GenericMutableRow"
    +    val result = ctx.freshName("result")
    +    val tmpRow = ctx.freshName("tmpRow")
    +
    +    val fieldsEvalCode = fieldsCodeHolder.zipWithIndex.map { case (holder, 
i) => {
    +      val fromFieldPrim = ctx.freshName("ffp")
    +      val fromFieldNull = ctx.freshName("ffn")
    +      val toFieldPrim = ctx.freshName("tfp")
    +      val toFieldNull = ctx.freshName("tfn")
    +      val fromType = ctx.boxedType(from.fields(i).dataType)
    +      s"""
    +        boolean $fromFieldNull = $tmpRow.isNullAt($i);
    +        if ($fromFieldNull) {
    +          $result.update($i, null);
    +        } else {
    +          $fromType $fromFieldPrim = ($fromType) $tmpRow.apply($i);
    +          ${holder.code(ctx, fromFieldPrim,
    +            fromFieldNull, toFieldPrim, toFieldNull, 
to.fields(i).dataType)}
    +          if ($toFieldNull) {
    +            $result.update($i, null);
    +          } else {
    +            $result.update($i, $toFieldPrim);
    +          }
    +        }
    +       """
    +      }
    +    }.mkString("\n")
    +
    +    CodeHolder((c, evPrim, evNull) =>
    +      s"""
    +        final $rowClass $result = new $rowClass(${fieldsCodeHolder.size});
    +        final InternalRow $tmpRow = $c;
    +        $fieldsEvalCode
    --- End diff --
    
    I'm unfolding each field's evaluation here, because the calculation code 
for each field can only be determined at this point.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it to be, or if the feature is enabled but not working,
please contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to