Bryan Jeffrey created SPARK-30986:
-------------------------------------

             Summary: Structured Streaming: mapGroupsWithState UDT 
serialization does not work
                 Key: SPARK-30986
                 URL: https://issues.apache.org/jira/browse/SPARK-30986
             Project: Spark
          Issue Type: Bug
          Components: Structured Streaming
    Affects Versions: 2.3.0
         Environment: We're using Spark 2.3.0 on Ubuntu Linux and Windows w/ 
Scala 2.11.8
            Reporter: Bryan Jeffrey


Hello.  
  
 I'm running Scala 2.11 with Spark 2.3.0. I've encountered a problem with
mapGroupsWithState and was wondering if anyone had insight. We use Joda-Time
in a number of data structures, so we wrote a custom user-defined type (UDT)
to serialize Joda DateTime values. This works well in most Dataset/DataFrame
Structured Streaming operations. However, when running mapGroupsWithState we
observed that incorrect dates were being returned from the state.
  
 Simple example:
 1. Input A has date D.
 2. Input A updates the state in mapGroupsWithState; the date stored in the state is D.
 3. Input A is added again. Input A still has the correct date D, but the existing
 state now has an invalid date.
  
 Here is a simple repro:
  
 Joda Time UDT:
  
{code:scala}
/**
 * Spark SQL user-defined type that stores a Joda `DateTime` in a `LongType`
 * column holding epoch milliseconds.
 *
 * Only `getMillis` is persisted, so the original time zone is not preserved:
 * every deserialized value comes back in UTC.
 */
private[sql] class JodaTimeUDT extends UserDefinedType[DateTime] {
  override def sqlType: DataType = LongType

  /** Encodes the instant as milliseconds since the epoch. */
  override def serialize(obj: DateTime): Long = obj.getMillis

  /**
   * Decodes epoch milliseconds back into a `DateTime` in UTC.
   *
   * @param datum the stored SQL value; expected to be a `Long`
   * @throws IllegalArgumentException if `datum` is not a `Long`
   */
  override def deserialize(datum: Any): DateTime = datum match {
    case millis: Long => new DateTime(millis, DateTimeZone.UTC)
    case other =>
      // Explicit failure instead of an opaque scala.MatchError.
      throw new IllegalArgumentException(
        s"JodaTimeUDT expects a Long (epoch millis) but got: $other")
  }

  override def userClass: Class[DateTime] = classOf[DateTime]

  // Returns this same instance, typed as JodaTimeUDT rather than the parent's UserDefinedType.
  private[spark] override def asNullable: JodaTimeUDT = this
}

/** Entry point for registering [[JodaTimeUDT]] with Spark SQL. */
object JodaTimeUDTRegister {

  /**
   * Associates the Joda `DateTime` class with `JodaTimeUDT` in Spark's
   * `UDTRegistration`.
   *
   * NOTE(review): presumably this must run before any encoder involving
   * `DateTime` is derived — confirm against Spark's UDTRegistration docs.
   */
  def register: Unit = {
    val userClassName = classOf[DateTime].getName
    val udtClassName = classOf[JodaTimeUDT].getName
    UDTRegistration.register(userClassName, udtClassName)
  }
}
{code}
 
 Test Leveraging Joda UDT:
  
{code:scala}
/**
 * Record used by the mapGroupsWithState reproduction.
 *
 * @param date timestamp; a Joda `DateTime`, encoded through the registered UDT
 * @param s    string payload
 * @param i    integer payload
 */
case class FooWithDate(
  date: DateTime,
  s: String,
  i: Int
)

/**
 * Reproduction for SPARK-30986: a Joda `DateTime` carried through a UDT comes
 * back with the wrong value after a `mapGroupsWithState` stage.
 */
@RunWith(classOf[JUnitRunner])
class TestJodaTimeUdt extends FlatSpec with Matchers with MockFactory with BeforeAndAfterAll {
  val application = this.getClass.getName
  var session: SparkSession = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    // NOTE(review): presumably a Windows/Hadoop (winutils) workaround — confirm.
    System.setProperty("hadoop.home.dir", getClass.getResource("/").getPath)
    val sparkConf = new SparkConf()
      .set("spark.driver.allowMultipleContexts", "true")
      .set("spark.testing", "true")
      .set("spark.memory.fraction", "1")
      .set("spark.ui.enabled", "false")
      .set("spark.streaming.gracefulStopTimeout", "1000")
      .setAppName(application)
      .setMaster("local[*]")

    session = SparkSession.builder().config(sparkConf).getOrCreate()
    // setCheckpointDir creates a subdirectory under the given path; using "/"
    // needs write access to the filesystem root, which ordinary users on Linux
    // lack. A fresh temporary directory works on every platform.
    val checkpointDir = java.nio.file.Files.createTempDirectory("spark-checkpoint")
    session.sparkContext.setCheckpointDir(checkpointDir.toString)
    JodaTimeUDTRegister.register
  }

  override def afterAll(): Unit = {
    try {
      // session is still null if beforeAll failed before creating it; don't
      // mask that failure with a NullPointerException.
      Option(session).foreach(_.stop())
    } finally {
      super.afterAll()
    }
  }

  it should "work correctly for a streaming input with stateful transformation" in {
    val date = new DateTime(2020, 1, 2, 3, 4, 5, 6, DateTimeZone.UTC)
    val sqlContext = session.sqlContext
    import sqlContext.implicits._

    val input = List(
      FooWithDate(date, "Foo", 1),
      FooWithDate(date, "Foo", 3),
      FooWithDate(date, "Foo", 3))
    val streamInput: MemoryStream[FooWithDate] =
      new MemoryStream[FooWithDate](42, session.sqlContext)
    streamInput.addData(input)
    val ds: Dataset[FooWithDate] = streamInput.toDS()

    val mapGroupsWithStateFunction: (Int, Iterator[FooWithDate], GroupState[FooWithDate]) => FooWithDate =
      TestJodaTimeUdt.updateFooState
    val result: Dataset[FooWithDate] = ds
      .groupByKey(x => x.i)
      .mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout())(mapGroupsWithStateFunction)
    val writeTo = "random_table_name"

    result.writeStream
      .outputMode(OutputMode.Update)
      .format("memory")
      .queryName(writeTo)
      .trigger(Trigger.Once())
      .start()
      .awaitTermination()

    val combinedResults: Array[FooWithDate] =
      session.sql(sqlText = s"select * from $writeTo").as[FooWithDate].collect()
    // Group 1 holds its single input; group 3 merges its two inputs.
    val expected = Array(FooWithDate(date, "Foo", 1), FooWithDate(date, "FooFoo", 6))
    combinedResults should contain theSameElementsAs(expected)
  }
}

/** State-update logic for the SPARK-30986 reproduction. */
object TestJodaTimeUdt {

  /**
   * `mapGroupsWithState` function that merges all records of a group into a
   * single running [[FooWithDate]] kept in the group state.
   *
   * @param id     grouping key (not used)
   * @param inputs records for this key in the current trigger; ignored when
   *               the state has timed out
   * @param state  per-key state holding the merged record so far
   * @return the updated merged record, or, on timeout, the last stored record
   *         (the state is removed in that case)
   */
  def updateFooState(id: Int, inputs: Iterator[FooWithDate], state: GroupState[FooWithDate]): FooWithDate = {
    if (state.hasTimedOut) {
      // Read the value BEFORE remove(): afterwards getOption is None, so the
      // previous `state.getOption.get` after remove() always threw.
      val expired = state.get
      state.remove()
      expired
    } else {
      // Materialize eagerly; Iterator.toSeq yields a lazy Stream in Scala 2.11.
      val batch = inputs.toList
      val updatedFoo = state.getOption match {
        case Some(existing) => batch.foldLeft(existing)(concatFoo)
        // No prior state: seed with the first record. Throws if the batch is
        // empty as well (the original code failed there too, via head).
        case None => batch.reduceLeft(concatFoo)
      }

      state.update(updatedFoo)
      state.setTimeoutDuration("1 minute")
      updatedFoo
    }
  }

  /**
   * Merges two records: keeps `b`'s date, concatenates the strings, sums the ints.
   */
  def concatFoo(a: FooWithDate, b: FooWithDate): FooWithDate =
    FooWithDate(b.date, a.s + b.s, a.i + b.i)
}

{code}
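The test output shows an invalid date in both result rows, including the i = 1 row, which was produced from a single input with no prior state: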
{quote}   
 org.scalatest.exceptions.TestFailedException: 
 Array(FooWithDate(*2021-02-02T19:26:23.374Z*,Foo,1), 
FooWithDate(2021-02-02T19:26:23.374Z,FooFoo,6)) did not contain the same 
elements as 
 Array(FooWithDate(2020-01-02T03:04:05.006Z,Foo,1), 
FooWithDate(2020-01-02T03:04:05.006Z,FooFoo,6))
{quote}



--
This message was sent by Atlassian Jira
(v8.3.4#803005)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org
For additional commands, e-mail: issues-h...@spark.apache.org

Reply via email to