Github user rxin commented on a diff in the pull request:
https://github.com/apache/spark/pull/7605#discussion_r35397548
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
---
@@ -334,144 +334,94 @@ object HiveTypeCoercion {
* - SHORT gets turned into DECIMAL(5, 0)
* - INT gets turned into DECIMAL(10, 0)
* - LONG gets turned into DECIMAL(20, 0)
- * - FLOAT and DOUBLE
- * 1. Union, Intersect and Except operations:
- * FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into
DECIMAL(15, 15) (this is the
- * same as Hive)
- * 2. Other operation:
- * FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
(this is the same as Hive,
- * but note that unlimited decimals are considered bigger than doubles
in WidenTypes)
+ * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
+ *
+ * Note: Union/Except/Intersect are handled by WidenTypes
*/
// scalastyle:on
object DecimalPrecision extends Rule[LogicalPlan] {
import scala.math.{max, min}
- // Conversion rules for integer types into fixed-precision decimals
- private val intTypeToFixed: Map[DataType, DecimalType] = Map(
- ByteType -> DecimalType(3, 0),
- ShortType -> DecimalType(5, 0),
- IntegerType -> DecimalType(10, 0),
- LongType -> DecimalType(20, 0)
- )
-
private def isFloat(t: DataType): Boolean = t == FloatType || t ==
DoubleType
- // Conversion rules for float and double into fixed-precision decimals
- private val floatTypeToFixed: Map[DataType, DecimalType] = Map(
- FloatType -> DecimalType(7, 7),
- DoubleType -> DecimalType(15, 15)
- )
-
- private def castDecimalPrecision(
- left: LogicalPlan,
- right: LogicalPlan): (LogicalPlan, LogicalPlan) = {
- val castedInput = left.output.zip(right.output).map {
- case (lhs, rhs) if lhs.dataType != rhs.dataType =>
- (lhs.dataType, rhs.dataType) match {
- case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
- // Decimals with precision/scale p1/s2 and p2/s2 will be
promoted to
- // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))
- val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 -
s2), max(s1, s2))
- (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs,
fixedType), rhs.name)())
- case (t, DecimalType.Fixed(p, s)) if
intTypeToFixed.contains(t) =>
- (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs)
- case (DecimalType.Fixed(p, s), t) if
intTypeToFixed.contains(t) =>
- (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)())
- case (t, DecimalType.Fixed(p, s)) if
floatTypeToFixed.contains(t) =>
- (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs)
- case (DecimalType.Fixed(p, s), t) if
floatTypeToFixed.contains(t) =>
- (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)())
- case _ => (lhs, rhs)
- }
- case other => other
- }
-
- val (castedLeft, castedRight) = castedInput.unzip
+ // Returns a decimal type wide enough to hold both d1 and d2
+ def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = {
+ widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale)
+ }
+ // precision = max(s1, s2) + max(p1 - s1, p2 - s2), scale = max(s1, s2)
+ def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType
= {
+ val scale = max(s1, s2)
+ val range = max(p1 - s1, p2 - s2)
+ DecimalType.bounded(range + scale, scale)
+ }
- val newLeft =
- if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
- Project(castedLeft, left)
- } else {
- left
- }
+ /**
+ * An expression used to wrap the children when promoting the precision of DecimalType, to avoid
+ * promoting them multiple times.
+ */
+ case class ChangePrecision(child: Expression) extends UnaryExpression {
+ override def dataType: DataType = child.dataType
+ override def eval(input: InternalRow): Any = child.eval(input)
+ override def gen(ctx: CodeGenContext): GeneratedExpressionCode =
child.gen(ctx)
+ override protected def genCode(ctx: CodeGenContext, ev:
GeneratedExpressionCode): String = ""
+ override def prettyName: String = "change_precision"
+ }
- val newRight =
- if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
- Project(castedRight, right)
- } else {
- right
- }
- (newLeft, newRight)
+ def changePrecision(e: Expression, dataType: DataType): Expression = {
+ ChangePrecision(Cast(e, dataType))
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- // fix decimal precision for union, intersect and except
- case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
- val (newLeft, newRight) = castDecimalPrecision(left, right)
- Union(newLeft, newRight)
- case i @ Intersect(left, right) if i.childrenResolved && !i.resolved
=>
- val (newLeft, newRight) = castDecimalPrecision(left, right)
- Intersect(newLeft, newRight)
- case e @ Except(left, right) if e.childrenResolved && !e.resolved =>
- val (newLeft, newRight) = castDecimalPrecision(left, right)
- Except(newLeft, newRight)
-
// fix decimal precision for expressions
case q => q.transformExpressions {
// Skip nodes whose children have not been resolved yet
case e if !e.childrenResolved => e
+ // Skip nodes that have already been promoted
+ case e: BinaryArithmetic if e.left.isInstanceOf[ChangePrecision]
=> e
+
case Add(e1 @ DecimalType.Expression(p1, s1), e2 @
DecimalType.Expression(p2, s2)) =>
- Cast(
- Add(Cast(e1, DecimalType.Unlimited), Cast(e2,
DecimalType.Unlimited)),
- DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1,
s2))
- )
+ val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2)
+ 1, max(s1, s2))
+ Add(changePrecision(e1, dt), changePrecision(e2, dt))
case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @
DecimalType.Expression(p2, s2)) =>
- Cast(
- Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2,
DecimalType.Unlimited)),
- DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1,
s2))
- )
+ val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2)
+ 1, max(s1, s2))
+ Subtract(changePrecision(e1, dt), changePrecision(e2, dt))
case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @
DecimalType.Expression(p2, s2)) =>
- Cast(
- Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2,
DecimalType.Unlimited)),
- DecimalType(p1 + p2 + 1, s1 + s2)
- )
+ val dt = DecimalType.bounded(p1 + p2 + 1, s1 + s2)
+ Multiply(changePrecision(e1, dt), changePrecision(e2, dt))
case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @
DecimalType.Expression(p2, s2)) =>
- Cast(
- Divide(Cast(e1, DecimalType.Unlimited), Cast(e2,
DecimalType.Unlimited)),
- DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2
+ 1))
- )
+ val dt = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1),
max(6, s1 + p2 + 1))
+ Divide(changePrecision(e1, dt), changePrecision(e2, dt))
case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @
DecimalType.Expression(p2, s2)) =>
- Cast(
- Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2,
DecimalType.Unlimited)),
- DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
- )
+ val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) +
max(s1, s2), max(s1, s2))
+ // resultType may have lower precision, so we cast them into the wider type first.
+ val widerType = widerDecimalType(p1, s1, p2, s2)
+ Cast(Remainder(changePrecision(e1, widerType),
changePrecision(e2, widerType)),
+ resultType)
case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @
DecimalType.Expression(p2, s2)) =>
- Cast(
- Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2,
DecimalType.Unlimited)),
- DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
- )
+ val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) +
max(s1, s2), max(s1, s2))
+ // resultType may have lower precision, so we cast them into the wider type first.
+ val widerType = widerDecimalType(p1, s1, p2, s2)
+ Cast(Pmod(changePrecision(e1, widerType), changePrecision(e2,
widerType)), resultType)
- // When we compare 2 decimal types with different precisions, cast
them to the smallest
- // common precision.
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if
p1 != p2 || s1 != s2 =>
- val resultType = DecimalType(max(p1, p2), max(s1, s2))
+ val resultType = widerDecimalType(p1, s1, p2, s2)
--- End diff --
What I'm saying is that the entire BinaryComparison case here should be moved
to ImplicitTypeCasts. The DecimalPrecision rule can then handle only the
expressions that involve decimal precision changes.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]