Dmitry Lapshin created SPARK-48965:
--------------------------------------

             Summary: toJSON produces wrong values if DecimalType information is lost in as[Product]
                 Key: SPARK-48965
                 URL: https://issues.apache.org/jira/browse/SPARK-48965
             Project: Spark
          Issue Type: Bug
          Components: SQL
    Affects Versions: 3.5.1, 3.1.1
            Reporter: Dmitry Lapshin


Consider this example:
{code:scala}
package com.jetbrains.jetstat.etl

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.DecimalType

object A {
  case class Example(x: BigDecimal)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .getOrCreate()

    import spark.implicits._

    val originalRaw = BigDecimal("123.456")
    val original = Example(originalRaw)

    val ds1 = spark.createDataset(Seq(original))
    val ds2 = ds1
      .withColumn("x", $"x" cast DecimalType(12, 6))

    val ds3 = ds2
      .as[Example]

    println(s"DS1: schema=${ds1.schema}, encoder.schema=${ds1.encoder.schema}")
    println(s"DS2: schema=${ds1.schema}, encoder.schema=${ds2.encoder.schema}")
    println(s"DS3: schema=${ds1.schema}, encoder.schema=${ds3.encoder.schema}")

    val json1 = ds1.toJSON.collect().head
    val json2 = ds2.toJSON.collect().head
    val json3 = ds3.toJSON.collect().head

    val collect1 = ds1.collect().head
    val collect2_ = ds2.collect().head
    val collect2 = collect2_.getDecimal(collect2_.fieldIndex("x"))
    val collect3 = ds3.collect().head

    println(s"Original: $original (scale = ${original.x.scale}, precision = 
${original.x.precision})")
    println(s"Collect1: $collect1 (scale = ${collect1.x.scale}, precision = 
${collect1.x.precision})")
    println(s"Collect2: $collect2 (scale = ${collect2.scale}, precision = 
${collect2.precision})")
    println(s"Collect3: $collect3 (scale = ${collect3.x.scale}, precision = 
${collect3.x.precision})")
    println(s"json1: $json1")
    println(s"json2: $json2")
    println(s"json3: $json3")
  }
}
{code}
Running it, you'll see that json3 contains badly wrong data. After a bit of debugging (apologies, I'm not well versed in Spark internals), I've found that:
 * The in-memory representation of the data in this example is {{UnsafeRow}}, which stores small decimal values compactly as unscaled longs and does not remember their precision and scale; its {{.getDecimal}} requires the caller to supply them (see the sketch after this list),
 * However, there are at least two sources for the precision and scale to pass to that method: {{Dataset.schema}} (which is based on the query execution and always contains (38, 18) for me) and {{Dataset.encoder.schema}} (which is updated to (12, 6) in {{ds2}} but then reset in {{ds3}}). There is also a {{Dataset.deserializer}} that seems to combine those two non-trivially,
 * This doesn't seem to affect the {{Dataset.collect()}} methods, since they use the {{deserializer}}, but {{Dataset.toJSON}} only uses the first schema.
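
To make the first point concrete, here is a minimal sketch of that behaviour (it drives Spark's internal catalyst classes directly, so treat it as illustrative only, not as the exact code path {{toJSON}} takes): the same {{UnsafeRow}} bytes decode to different numbers depending on the precision and scale the reader passes in. I deliberately picked (15, 9) as the mismatched pair so that both reads stay on the long-backed storage path; with (38, 18), as in the repro above, {{getDecimal}} switches to the byte-array read path, which would explain getting outright garbage rather than a merely mis-scaled number.

{code:scala}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{Decimal, DecimalType, StructField, StructType}

object GetDecimalSketch {
  def main(args: Array[String]): Unit = {
    // Write one decimal(12, 6) value into an UnsafeRow. For precision <= 18
    // the row stores only the unscaled long (here 123456000), not the scale.
    val schema = StructType(Seq(StructField("x", DecimalType(12, 6))))
    val toUnsafe = UnsafeProjection.create(schema)
    val row = toUnsafe(InternalRow(Decimal(BigDecimal("123.456"), 12, 6)))

    // Reading back with the same (precision, scale) reconstructs 123.456000...
    println(row.getDecimal(0, 12, 6))

    // ...but the same stored long read with scale 9 is reinterpreted as
    // 123456000 * 10^-9, i.e. 0.123456000: a different number entirely.
    println(row.getDecimal(0, 15, 9))
  }
}
{code}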

It seems to me that either {{.toJSON}} should be aware of this schema mismatch, or {{.as[]}} should reconcile the two schemas itself.
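
If that analysis is right, forcing a full round trip through the encoder before serializing should make the physical rows agree again with the schema that {{.toJSON}} consults. A possible user-side workaround along those lines (an assumption on my part, not a verified fix) would be:

{code:scala}
// Hypothetical workaround: map through the typed deserializer so the rows
// are re-serialized with the encoder's schema before toJSON reads them.
val json3Fixed = ds3.map(identity).toJSON.collect().head
{code}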


