Github user davies commented on a diff in the pull request:

    https://github.com/apache/spark/pull/7605#discussion_r35396814
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
 ---
    @@ -334,144 +334,94 @@ object HiveTypeCoercion {
        * - SHORT gets turned into DECIMAL(5, 0)
        * - INT gets turned into DECIMAL(10, 0)
        * - LONG gets turned into DECIMAL(20, 0)
    -   * - FLOAT and DOUBLE
    -   *   1. Union, Intersect and Except operations:
    -   *      FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into 
DECIMAL(15, 15) (this is the
    -   *      same as Hive)
    -   *   2. Other operation:
    -   *      FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE 
(this is the same as Hive,
    -   *   but note that unlimited decimals are considered bigger than doubles 
in WidenTypes)
    +   * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
    +   *
    +   * Note: Union/Except/Intersect are handled by WidenTypes
        */
       // scalastyle:on
       object DecimalPrecision extends Rule[LogicalPlan] {
         import scala.math.{max, min}
     
    -    // Conversion rules for integer types into fixed-precision decimals
    -    private val intTypeToFixed: Map[DataType, DecimalType] = Map(
    -      ByteType -> DecimalType(3, 0),
    -      ShortType -> DecimalType(5, 0),
    -      IntegerType -> DecimalType(10, 0),
    -      LongType -> DecimalType(20, 0)
    -    )
    -
         private def isFloat(t: DataType): Boolean = t == FloatType || t == 
DoubleType
     
    -    // Conversion rules for float and double into fixed-precision decimals
    -    private val floatTypeToFixed: Map[DataType, DecimalType] = Map(
    -      FloatType -> DecimalType(7, 7),
    -      DoubleType -> DecimalType(15, 15)
    -    )
    -
    -    private def castDecimalPrecision(
    -        left: LogicalPlan,
    -        right: LogicalPlan): (LogicalPlan, LogicalPlan) = {
    -      val castedInput = left.output.zip(right.output).map {
    -        case (lhs, rhs) if lhs.dataType != rhs.dataType =>
    -          (lhs.dataType, rhs.dataType) match {
    -            case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
    -              // Decimals with precision/scale p1/s2 and p2/s2  will be 
promoted to
    -              // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))
    -              val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - 
s2), max(s1, s2))
    -              (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, 
fixedType), rhs.name)())
    -            case (t, DecimalType.Fixed(p, s)) if 
intTypeToFixed.contains(t) =>
    -              (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs)
    -            case (DecimalType.Fixed(p, s), t) if 
intTypeToFixed.contains(t) =>
    -              (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)())
    -            case (t, DecimalType.Fixed(p, s)) if 
floatTypeToFixed.contains(t) =>
    -              (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs)
    -            case (DecimalType.Fixed(p, s), t) if 
floatTypeToFixed.contains(t) =>
    -              (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)())
    -            case _ => (lhs, rhs)
    -          }
    -        case other => other
    -      }
    -
    -      val (castedLeft, castedRight) = castedInput.unzip
    +    // Returns the wider decimal type that's wider than both of them
    +    def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = {
    +      widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale)
    +    }
    +    // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
    +    def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType 
= {
    +      val scale = max(s1, s2)
    +      val range = max(p1 - s1, p2 - s2)
    +      DecimalType.bounded(range + scale, scale)
    +    }
     
    -      val newLeft =
    -        if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
    -          Project(castedLeft, left)
    -        } else {
    -          left
    -        }
    +    /**
    +     * An expression used to wrap the children when promoting the precision 
of DecimalType, to avoid
    +     * promoting them multiple times.
    +     */
    +    case class ChangePrecision(child: Expression) extends UnaryExpression {
    +      override def dataType: DataType = child.dataType
    +      override def eval(input: InternalRow): Any = child.eval(input)
    +      override def gen(ctx: CodeGenContext): GeneratedExpressionCode = 
child.gen(ctx)
    +      override protected def genCode(ctx: CodeGenContext, ev: 
GeneratedExpressionCode): String = ""
    +      override def prettyName: String = "change_precision"
    +    }
     
    -      val newRight =
    -        if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
    -          Project(castedRight, right)
    -        } else {
    -          right
    -        }
    -      (newLeft, newRight)
    +    def changePrecision(e: Expression, dataType: DataType): Expression = {
    +      ChangePrecision(Cast(e, dataType))
         }
     
         def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    -      // fix decimal precision for union, intersect and except
    -      case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
    -        val (newLeft, newRight) = castDecimalPrecision(left, right)
    -        Union(newLeft, newRight)
    -      case i @ Intersect(left, right) if i.childrenResolved && !i.resolved 
=>
    -        val (newLeft, newRight) = castDecimalPrecision(left, right)
    -        Intersect(newLeft, newRight)
    -      case e @ Except(left, right) if e.childrenResolved && !e.resolved =>
    -        val (newLeft, newRight) = castDecimalPrecision(left, right)
    -        Except(newLeft, newRight)
    -
           // fix decimal precision for expressions
           case q => q.transformExpressions {
             // Skip nodes whose children have not been resolved yet
             case e if !e.childrenResolved => e
     
    +        // Skip nodes that are already promoted
    +        case e: BinaryArithmetic if e.left.isInstanceOf[ChangePrecision] 
=> e
    +
             case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ 
DecimalType.Expression(p2, s2)) =>
    -          Cast(
    -            Add(Cast(e1, DecimalType.Unlimited), Cast(e2, 
DecimalType.Unlimited)),
    -            DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, 
s2))
    -          )
    +          val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) 
+ 1, max(s1, s2))
    +          Add(changePrecision(e1, dt), changePrecision(e2, dt))
     
             case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ 
DecimalType.Expression(p2, s2)) =>
    -          Cast(
    -            Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, 
DecimalType.Unlimited)),
    -            DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, 
s2))
    -          )
    +          val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) 
+ 1, max(s1, s2))
    +          Subtract(changePrecision(e1, dt), changePrecision(e2, dt))
     
             case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ 
DecimalType.Expression(p2, s2)) =>
    -          Cast(
    -            Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, 
DecimalType.Unlimited)),
    -            DecimalType(p1 + p2 + 1, s1 + s2)
    -          )
    +          val dt = DecimalType.bounded(p1 + p2 + 1, s1 + s2)
    +          Multiply(changePrecision(e1, dt), changePrecision(e2, dt))
     
             case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ 
DecimalType.Expression(p2, s2)) =>
    -          Cast(
    -            Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, 
DecimalType.Unlimited)),
    -            DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 
+ 1))
    -          )
    +          val dt = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), 
max(6, s1 + p2 + 1))
    +          Divide(changePrecision(e1, dt), changePrecision(e2, dt))
     
             case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ 
DecimalType.Expression(p2, s2)) =>
    -          Cast(
    -            Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, 
DecimalType.Unlimited)),
    -            DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
    -          )
    +          val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + 
max(s1, s2), max(s1, s2))
    +          // resultType may have lower precision, so we cast them into 
wider type first.
    +          val widerType = widerDecimalType(p1, s1, p2, s2)
    +          Cast(Remainder(changePrecision(e1, widerType), 
changePrecision(e2, widerType)),
    +            resultType)
     
             case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ 
DecimalType.Expression(p2, s2)) =>
    -          Cast(
    -            Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2, 
DecimalType.Unlimited)),
    -            DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
    -          )
    +          val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + 
max(s1, s2), max(s1, s2))
    +          // resultType may have lower precision, so we cast them into 
wider type first.
    +          val widerType = widerDecimalType(p1, s1, p2, s2)
    +          Cast(Pmod(changePrecision(e1, widerType), changePrecision(e2, 
widerType)), resultType)
     
    -        // When we compare 2 decimal types with different precisions, cast 
them to the smallest
    -        // common precision.
             case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
                                       e2 @ DecimalType.Expression(p2, s2)) if 
p1 != p2 || s1 != s2 =>
    -          val resultType = DecimalType(max(p1, p2), max(s1, s2))
    +          val resultType = widerDecimalType(p1, s1, p2, s2)
    --- End diff --
    
    It's a different thing: findTightestCommonTypeOfTwo will return one datatype 
from the two, but here it could be a third one, which is wider than both of them.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to