HeartSaVioR commented on code in PR #43961: URL: https://github.com/apache/spark/pull/43961#discussion_r1429468098
########## sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala: ########## @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.io.Serializable + +import org.apache.spark.annotation.{Evolving, Experimental} + +/** + * Represents the arbitrary stateful logic that needs to be provided by the user to perform + * stateful manipulations on keyed streams. + */ +@Experimental +@Evolving +trait StatefulProcessor[K, I, O] extends Serializable { + + /** + * Function that will be invoked as the first method that allows for users to + * initialize all their state variables and perform other init actions before handling data. + * @param handle - reference to the statefulProcessorHandle that the user can use to perform + * future actions Review Comment: nit: future actions is too generic. Are we going to explain what they can do in StatefulProcessorHandle? If then probably better to refer to the classdoc e.g. 
`[[StatefulProcessorHandle]]` ########## sql/api/src/main/scala/org/apache/spark/sql/streaming/TimerValues.scala: ########## Review Comment: We could address this in a later review phase, but since this is a public API, we may want to put considerable effort into describing each method in its doc. ########## sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala: ########## @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.io.Serializable + +import org.apache.spark.annotation.{Evolving, Experimental} + +@Experimental +@Evolving +/** + * Interface used for arbitrary stateful operations with the v2 API to capture + * single value state. + */ +trait ValueState[S] extends Serializable { + + /** Whether state exists or not. */ + def exists(): Boolean + + /** Get the state value if it exists, or throw NoSuchElementException.
*/ Review Comment: We can leverage annotation to describe known exceptions being thrown via `@throws[<Exception>](description)` ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala: ########## @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.streaming.{StatefulProcessorHandle, ValueState} + +/** Review Comment: I don't get why this is needed as of now, but I'll look into the code more and check the necessity. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala: ########## @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.io.Serializable + +import org.apache.commons.lang3.SerializationUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.types._ + +/** + * Class that provides a concrete implementation for a single value state associated with state + * variables used in the streaming transformWithState operator. + * @param store - reference to the StateStore instance to be used for storing state + * @param stateName - name of logical state partition + * @tparam S - data type of object that will be stored + */ +class ValueStateImpl[S]( + store: StateStore, + stateName: String) extends ValueState[S] with Logging{ + + private def encodeKey(): UnsafeRow = { + val keyOption = ImplicitKeyTracker.getImplicitKeyOption + if (!keyOption.isDefined) { + throw new UnsupportedOperationException("Implicit key not found for operation on" + + s"stateName=$stateName") + } + + val schemaForKeyRow: StructType = new StructType().add("key", BinaryType) + val keyByteArr = SerializationUtils.serialize(keyOption.get.asInstanceOf[Serializable]) Review Comment: Please correct me if I'm mistaken, but this enforces the type of key to ensure both compatibility with Spark SQL and Serializable. Can we enforce either one, preferably Spark SQL? 
########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala: ########## @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit.NANOSECONDS + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.physical.Distribution +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor} +import org.apache.spark.sql.types._ +import org.apache.spark.util.CompletionIterator + +/** + * Physical operator for executing `TransformWithState` + * + * @param statefulProcessor processor methods called on underlying data + * @param keyDeserializer used to extract the key object for each group. + * @param valueDeserializer used to extract the items in the iterator from an input row. 
+ * @param groupingAttributes used to group the data + * @param dataAttributes used to read the data + * @param outputObjAttr Defines the output object + * @param batchTimestampMs processing timestamp of the current batch. + * @param eventTimeWatermarkForLateEvents event time watermark for filtering late events + * @param eventTimeWatermarkForEviction event time watermark for state eviction + * @param child the physical plan for the underlying data + */ +case class TransformWithStateExec( + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + statefulProcessor: StatefulProcessor[Any, Any, Any], + outputMode: OutputMode, + outputObjAttr: Attribute, + stateInfo: Option[StatefulOperatorStateInfo], + batchTimestampMs: Option[Long], + eventTimeWatermarkForLateEvents: Option[Long], + eventTimeWatermarkForEviction: Option[Long], + child: SparkPlan) + extends UnaryExecNode + with StateStoreWriter + with WatermarkSupport + with ObjectProducerExec { + + override def shortName: String = "transformWithStateExec" + + override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = false + + override protected def withNewChildInternal( + newChild: SparkPlan): TransformWithStateExec = copy(child = newChild) + + override def keyExpressions: Seq[Attribute] = groupingAttributes + + protected val schemaForKeyRow: StructType = new StructType().add("key", BinaryType) + + protected val schemaForValueRow: StructType = new StructType().add("value", BinaryType) + + override def requiredChildDistribution: Seq[Distribution] = { + StatefulOperatorPartitioning.getCompatibleDistribution(groupingAttributes, + getStateInfo, conf) :: + Nil + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( + groupingAttributes.map(SortOrder(_, Ascending))) + + private def handleInputRows(keyRow: UnsafeRow, valueRowIter: Iterator[InternalRow]): + Iterator[InternalRow] = { + val getKeyObj = + 
ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) + + val getValueObj = + ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) + + val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType) + + val keyObj = getKeyObj(keyRow) // convert key to objects + ImplicitKeyTracker.setImplicitKey(keyObj) Review Comment: I see what you're trying to do. It'd be nice if we can figure out better approach, but that's OK if we can't find as I see the necessity. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala: ########## @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.streaming + +import java.io.Serializable + +import org.apache.commons.lang3.SerializationUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.types._ + +/** + * Class that provides a concrete implementation for a single value state associated with state + * variables used in the streaming transformWithState operator. + * @param store - reference to the StateStore instance to be used for storing state + * @param stateName - name of logical state partition + * @tparam S - data type of object that will be stored + */ +class ValueStateImpl[S]( + store: StateStore, + stateName: String) extends ValueState[S] with Logging{ + + private def encodeKey(): UnsafeRow = { Review Comment: This approach seems to end up with calling this multiple times. Should we employ some cache for latest key to the encoded key? ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala: ########## @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.io.Serializable + +import org.apache.commons.lang3.SerializationUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.types._ + +/** + * Class that provides a concrete implementation for a single value state associated with state + * variables used in the streaming transformWithState operator. + * @param store - reference to the StateStore instance to be used for storing state + * @param stateName - name of logical state partition + * @tparam S - data type of object that will be stored + */ +class ValueStateImpl[S]( + store: StateStore, + stateName: String) extends ValueState[S] with Logging{ + + private def encodeKey(): UnsafeRow = { + val keyOption = ImplicitKeyTracker.getImplicitKeyOption + if (!keyOption.isDefined) { + throw new UnsupportedOperationException("Implicit key not found for operation on" + + s"stateName=$stateName") + } + + val schemaForKeyRow: StructType = new StructType().add("key", BinaryType) + val keyByteArr = SerializationUtils.serialize(keyOption.get.asInstanceOf[Serializable]) + val keyEncoder = UnsafeProjection.create(schemaForKeyRow) + val keyRow = keyEncoder(InternalRow(keyByteArr)) + keyRow + } + + private def encodeValue(value: S): UnsafeRow = { + val schemaForValueRow: StructType = new StructType().add("value", BinaryType) + val valueByteArr = SerializationUtils.serialize(value.asInstanceOf[Serializable]) + val valueEncoder = UnsafeProjection.create(schemaForValueRow) + val valueRow = valueEncoder(InternalRow(valueByteArr)) + valueRow + } + + /** Function to check if state exists. 
Returns true if present and false otherwise */ + override def exists(): Boolean = { + val retRow = store.get(encodeKey(), stateName) + if (retRow == null) { + false + } else { + true + } + } + + /** Function to return Option of value if exists and None otherwise */ + override def getOption(): Option[S] = { + if (exists()) { Review Comment: Shall we refine getOption() and get() to reduce unnecessary store.get() calls? It's now called twice but it should have been done at once. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala: ########## @@ -109,18 +119,94 @@ class StateStoreChangelogWriter( } } +/** + * Write changes to the key value state store instance to a changelog file. + * There are 2 types of records, put and delete. + * A put record is written as: | key length | key content | value length | value content | + * A delete record is written as: | key length | key content | -1 | + * Write an Int -1 to signal the end of file. + * The overall changelog format is: | put record | delete record | ... | put record | -1 | + */ +class StateStoreChangelogWriterV1( + fm: CheckpointFileManager, + file: Path, + compressionCodec: CompressionCodec) + extends StateStoreChangelogWriter(fm, file, compressionCodec) { + + override def put(key: Array[Byte], value: Array[Byte]): Unit = { + assert(compressedStream != null) + compressedStream.writeInt(key.size) + compressedStream.write(key) + compressedStream.writeInt(value.size) + compressedStream.write(value) + size += 1 + } + + override def delete(key: Array[Byte]): Unit = { + assert(compressedStream != null) + compressedStream.writeInt(key.size) + compressedStream.write(key) + // -1 in the value field means record deletion. + compressedStream.writeInt(-1) + size += 1 + } +} /** - * Read an iterator of change record from the changelog file. 
- * A record is represented by ByteArrayPair(key: Array[Byte], value: Array[Byte]) - * A put record is returned as a ByteArrayPair(key, value) - * A delete record is return as a ByteArrayPair(key, null) + * Write changes to the key value state store instance to a changelog file. + * There are 2 types of records, put and delete. + * A put record is written as: | record type | key length + * | key content | value length | value content | col family name length | col family name | -1 | + * A delete record is written as: | record type | key length | key content | -1 + * | col family name length | col family name | -1 | + * Write an Int -1 to signal the end of file. + * The overall changelog format is: | put record | delete record | ... | put record | -1 | + */ +class StateStoreChangelogWriterV2( + fm: CheckpointFileManager, + file: Path, + compressionCodec: CompressionCodec) + extends StateStoreChangelogWriter(fm, file, compressionCodec) { + + override def put(key: Array[Byte], value: Array[Byte], colFamilyName: String): Unit = { + assert(compressedStream != null) + compressedStream.writeInt(RecordType.PUT_RECORD.toString.getBytes.size) Review Comment: This already takes a non-trivial number of bytes just to mark whether a record is a put or a delete, and it is written for every change. Let's just use a numeric code, or even a boolean. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala: ########## @@ -33,12 +33,20 @@ import org.apache.spark.sql.execution.streaming.CheckpointFileManager.Cancellabl import org.apache.spark.util.NextIterator /** - * Write changes to the key value state store instance to a changelog file. - * There are 2 types of records, put and delete. - * A put record is written as: | key length | key content | value length | value content | - * A delete record is written as: | key length | key content | -1 | - * Write an Int -1 to signal the end of file.
- * The overall changelog format is: | put record | delete record | ... | put record | -1 | + * Enum used to write record types to changelog files used with RocksDBStateStoreProvider. + */ +object RecordType extends Enumeration { + type RecordType = Value + + val PUT_RECORD = Value("put_record") + val DELETE_RECORD = Value("delete_record") +} + +/** + * Base class for state store changelog writer + * @param fm - checkpoint file manager used to manage streaming query checkpoint + * @param file - name of file to use to write changelog + * @param compressionCodec - compression method using for writing changelog file */ class StateStoreChangelogWriter( Review Comment: We can make this class abstract if we only use it as the base implementation, leaving put and delete abstract. ########## sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala: ########## @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.streaming + +import java.io.Serializable + +import org.apache.spark.annotation.{Evolving, Experimental} + +/** + * Represents the operation handle provided to the stateful processor used in the + * arbitrary state API v2.
+ */ +@Experimental +@Evolving +trait StatefulProcessorHandle extends Serializable { + + /** Function to create new or return existing single value state variable of given type */ Review Comment: Are there any requirements on stateName, e.g. uniqueness within the query or across the entire cluster? Are there any characters which won't work? ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala: ########## @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.io.Serializable + +import org.apache.commons.lang3.SerializationUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.types._ + +/** + * Class that provides a concrete implementation for a single value state associated with state + * variables used in the streaming transformWithState operator.
+ * @param store - reference to the StateStore instance to be used for storing state + * @param stateName - name of logical state partition + * @tparam S - data type of object that will be stored + */ +class ValueStateImpl[S]( + store: StateStore, + stateName: String) extends ValueState[S] with Logging{ + + private def encodeKey(): UnsafeRow = { + val keyOption = ImplicitKeyTracker.getImplicitKeyOption + if (!keyOption.isDefined) { + throw new UnsupportedOperationException("Implicit key not found for operation on" + + s"stateName=$stateName") + } + + val schemaForKeyRow: StructType = new StructType().add("key", BinaryType) + val keyByteArr = SerializationUtils.serialize(keyOption.get.asInstanceOf[Serializable]) Review Comment: You can rely on binary format of UnsafeRow, as we have been relying on this for existing state store. This is arguably an issue if we want to change UnsafeRow, but it's done already so at least isn't worse. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala: ########## @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.streaming.{StatefulProcessorHandle, ValueState} + +/** + * Object used to assign/retrieve/remove grouping key passed implicitly for various state + * manipulation actions using the store handle. + */ +object ImplicitKeyTracker { + val implicitKey: InheritableThreadLocal[Any] = new InheritableThreadLocal[Any] + + def getImplicitKeyOption: Option[Any] = Option(implicitKey.get()) + + def setImplicitKey(key: Any): Unit = implicitKey.set(key) + + def removeImplicitKey(): Unit = implicitKey.remove() +} + +/** + * Enum used to track valid states for the StatefulProcessorHandle + */ +object StatefulProcessorHandleState extends Enumeration { + type StatefulProcessorHandleState = Value + val CREATED, INITIALIZED, DATA_PROCESSED, CLOSED = Value +} + +/** + * Class that provides a concrete implementation of a StatefulProcessorHandle. Note that we keep + * track of valid transitions as various functions are invoked to track object lifecycle. + * @param store - instance of state store + */ +class StatefulProcessorHandleImpl(store: StateStore) + extends StatefulProcessorHandle + with Logging { + import StatefulProcessorHandleState._ + + private var currState: StatefulProcessorHandleState = CREATED + + private def verify(condition: => Boolean, msg: String): Unit = { + if (!condition) { + throw new IllegalStateException(msg) + } + } + + def setHandleState(newState: StatefulProcessorHandleState): Unit = { + currState = newState + } + + def getHandleState: StatefulProcessorHandleState = currState + + override def getValueState[T](stateName: String): ValueState[T] = { + verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " + Review Comment: Shall we provide this (when user should define the state) as contract to users? 
I guess add contract to interface method doc would work. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala: ########## @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit.NANOSECONDS + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.physical.Distribution +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor} +import org.apache.spark.sql.types._ +import org.apache.spark.util.CompletionIterator + +/** + * Physical operator for executing `TransformWithState` + * + * @param statefulProcessor processor methods called on underlying data + * @param keyDeserializer used to extract the key object for each group. + * @param valueDeserializer used to extract the items in the iterator from an input row. 
+ * @param groupingAttributes used to group the data + * @param dataAttributes used to read the data + * @param outputObjAttr Defines the output object + * @param batchTimestampMs processing timestamp of the current batch. + * @param eventTimeWatermarkForLateEvents event time watermark for filtering late events + * @param eventTimeWatermarkForEviction event time watermark for state eviction + * @param child the physical plan for the underlying data + */ +case class TransformWithStateExec( + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + statefulProcessor: StatefulProcessor[Any, Any, Any], + outputMode: OutputMode, + outputObjAttr: Attribute, + stateInfo: Option[StatefulOperatorStateInfo], + batchTimestampMs: Option[Long], + eventTimeWatermarkForLateEvents: Option[Long], + eventTimeWatermarkForEviction: Option[Long], + child: SparkPlan) + extends UnaryExecNode + with StateStoreWriter + with WatermarkSupport + with ObjectProducerExec { + + override def shortName: String = "transformWithStateExec" + + override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = false Review Comment: Are we going to disallow no-data batch for new API? ########## sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala: ########## @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.io.Serializable + +import org.apache.spark.annotation.{Evolving, Experimental} + +/** + * Represents the arbitrary stateful logic that needs to be provided by the user to perform + * stateful manipulations on keyed streams. + */ +@Experimental +@Evolving +trait StatefulProcessor[K, I, O] extends Serializable { + + /** + * Function that will be invoked as the first method that allows for users to + * initialize all their state variables and perform other init actions before handling data. + * @param handle - reference to the statefulProcessorHandle that the user can use to perform + * future actions + * @param outputMode - output mode for the stateful processor + */ + def init( + handle: StatefulProcessorHandle, + outputMode: OutputMode): Unit + + /** + * Function that will allow users to interact with input data rows along with the grouping key + * and current timer values and optionally provide output rows. + * @param key - grouping key + * @param inputRows - iterator of input rows associated with grouping key + * @param timerValues - instance of TimerValues that provides access to current processing/event + * time if available + * @return - Zero or more output rows + */ + def handleInputRows( Review Comment: As I stated in the SPIP doc, I'd like to see the ability to chain operators as part of the MVP.
We'd need to have an origin input row for each output row to tag the event time timestamp, and then we should either 1) require users to provide the origin input row for each output row, or 2) change the method signature to handle a single input row, so that we can implicitly associate the origin input row with the output rows in the same method call. The well-known approach is 2), but I'm also open to 1) if handling multiple values is considered a great UX advantage. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala: ########## @@ -133,11 +219,30 @@ class StateStoreChangelogReader( case f: FileNotFoundException => throw QueryExecutionErrors.failedToReadStreamingStateFileError(fileToRead, f) } - private val input: DataInputStream = decompressStream(sourceStream) + protected val input: DataInputStream = decompressStream(sourceStream) def close(): Unit = { if (input != null) input.close() } - override def getNext(): (Array[Byte], Array[Byte]) = { + override def getNext(): (RecordType.Value, Array[Byte], Array[Byte], String) = { + throw new UnsupportedOperationException("Iterator operations not supported on base " + + "changelog reader implementation") + } +} + +/** + * Read an iterator of change record from the changelog file. + * A record is represented by ByteArrayPair(recordType: RecordType.Value, + * key: Array[Byte], value: Array[Byte], colFamilyName: String) + * A put record is returned as a ByteArrayPair(recordType, key, value, colFamilyName) + * A delete record is return as a ByteArrayPair(recordType, key, null, colFamilyName) + */ +class StateStoreChangelogReaderV1( + fm: CheckpointFileManager, + fileToRead: Path, + compressionCodec: CompressionCodec) extends StateStoreChangelogReader(fm, fileToRead, Review Comment: nit: looks like you can pull the extends part down one line and fit the rest on one line.
########## sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala: ########## @@ -642,6 +642,37 @@ class KeyValueGroupedDataset[K, V] private[sql]( outputMode, timeoutConf, initialState)(f)(stateEncoder, outputEncoder) } + /** + * (Scala-specific) + * Invokes methods defined in the stateful processor used in arbitrary state API v2. + * We allow the user to act on per-group set of input rows along with keyed state and the + * user can choose to output/return 0 or more rows. + * For a static/batch dataset, this operator is not supported and will throw an exception. + * For a streaming dataframe, we will repeatedly invoke the interface methods for new rows + * in each trigger and the user's state/state variables will be stored persistently across + * invocations. + * + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param statefulProcessor Instance of statefulProcessor whose functions will be invoked by the + * operator. + * @param outputMode The output mode of the stateful processor. Defaults to APPEND mode. + * + */ + def transformWithState[U: Encoder] + (statefulProcessor: StatefulProcessor[K, V, U], Review Comment: nit: place `(` on the line above and use 4 spaces for params ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala: ########## @@ -569,6 +569,41 @@ case class FlatMapGroupsWithState( copy(child = newLeft, initialState = newRight) } +object TransformWithState { + def apply[K: Encoder, V: Encoder, U: Encoder]( + groupingAttributes: Seq[Attribute], Review Comment: nit: 2 more spaces for params ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerValuesImpl.scala: ########## @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.streaming.TimerValues + +/** + * Class that provides a concrete implementation for TimerValues used for fetching + * processing time and event time (watermark). + * @param currentProcessingTimeOpt - option to current processing time + * @param currentWatermarkOpt - option to current watermark + */ +class TimerValuesImpl( + currentProcessingTimeOpt: Option[Long], + currentWatermarkOpt: Option[Long]) extends TimerValues { + + // Return available processing time or -1 otherwise + override def getCurrentProcessingTimeInMs(): Long = if (currentProcessingTimeOpt.isDefined) { Review Comment: nit: `currentProcessingTimeOpt.getOrElse(-1L)` ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala: ########## @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit.NANOSECONDS + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.physical.Distribution +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor} +import org.apache.spark.sql.types._ +import org.apache.spark.util.CompletionIterator + +/** + * Physical operator for executing `TransformWithState` + * + * @param statefulProcessor processor methods called on underlying data + * @param keyDeserializer used to extract the key object for each group. + * @param valueDeserializer used to extract the items in the iterator from an input row. + * @param groupingAttributes used to group the data + * @param dataAttributes used to read the data + * @param outputObjAttr Defines the output object + * @param batchTimestampMs processing timestamp of the current batch. 
+ * @param eventTimeWatermarkForLateEvents event time watermark for filtering late events + * @param eventTimeWatermarkForEviction event time watermark for state eviction + * @param child the physical plan for the underlying data + */ +case class TransformWithStateExec( + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + statefulProcessor: StatefulProcessor[Any, Any, Any], + outputMode: OutputMode, + outputObjAttr: Attribute, + stateInfo: Option[StatefulOperatorStateInfo], + batchTimestampMs: Option[Long], + eventTimeWatermarkForLateEvents: Option[Long], + eventTimeWatermarkForEviction: Option[Long], + child: SparkPlan) + extends UnaryExecNode + with StateStoreWriter Review Comment: nit: 2 spaces for extends/with ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerValuesImpl.scala: ########## @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.streaming.TimerValues + +/** + * Class that provides a concrete implementation for TimerValues used for fetching + * processing time and event time (watermark). 
+ * @param currentProcessingTimeOpt - option to current processing time + * @param currentWatermarkOpt - option to current watermark + */ +class TimerValuesImpl( + currentProcessingTimeOpt: Option[Long], + currentWatermarkOpt: Option[Long]) extends TimerValues { + + // Return available processing time or -1 otherwise + override def getCurrentProcessingTimeInMs(): Long = if (currentProcessingTimeOpt.isDefined) { + currentProcessingTimeOpt.get + } else { + -1L + } + + // Return available watermark or -1 otherwise + override def getCurrentWatermarkInMs(): Long = if (currentWatermarkOpt.isDefined) { Review Comment: ditto ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala: ########## @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.streaming + +import java.io.Serializable + +import org.apache.commons.lang3.SerializationUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.types._ + +/** + * Class that provides a concrete implementation for a single value state associated with state + * variables used in the streaming transformWithState operator. + * @param store - reference to the StateStore instance to be used for storing state + * @param stateName - name of logical state partition + * @tparam S - data type of object that will be stored + */ +class ValueStateImpl[S]( + store: StateStore, + stateName: String) extends ValueState[S] with Logging{ Review Comment: nit: space between g and { ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala: ########## @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.TimeUnit.NANOSECONDS + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.physical.Distribution +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor} +import org.apache.spark.sql.types._ +import org.apache.spark.util.CompletionIterator + +/** + * Physical operator for executing `TransformWithState` + * + * @param statefulProcessor processor methods called on underlying data + * @param keyDeserializer used to extract the key object for each group. + * @param valueDeserializer used to extract the items in the iterator from an input row. + * @param groupingAttributes used to group the data + * @param dataAttributes used to read the data + * @param outputObjAttr Defines the output object + * @param batchTimestampMs processing timestamp of the current batch. 
+ * @param eventTimeWatermarkForLateEvents event time watermark for filtering late events + * @param eventTimeWatermarkForEviction event time watermark for state eviction + * @param child the physical plan for the underlying data + */ +case class TransformWithStateExec( + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + statefulProcessor: StatefulProcessor[Any, Any, Any], + outputMode: OutputMode, + outputObjAttr: Attribute, + stateInfo: Option[StatefulOperatorStateInfo], + batchTimestampMs: Option[Long], + eventTimeWatermarkForLateEvents: Option[Long], + eventTimeWatermarkForEviction: Option[Long], + child: SparkPlan) + extends UnaryExecNode + with StateStoreWriter + with WatermarkSupport + with ObjectProducerExec { + + override def shortName: String = "transformWithStateExec" + + override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = false + + override protected def withNewChildInternal( + newChild: SparkPlan): TransformWithStateExec = copy(child = newChild) + + override def keyExpressions: Seq[Attribute] = groupingAttributes + + protected val schemaForKeyRow: StructType = new StructType().add("key", BinaryType) + + protected val schemaForValueRow: StructType = new StructType().add("value", BinaryType) + + override def requiredChildDistribution: Seq[Distribution] = { + StatefulOperatorPartitioning.getCompatibleDistribution(groupingAttributes, + getStateInfo, conf) :: + Nil + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( + groupingAttributes.map(SortOrder(_, Ascending))) + + private def handleInputRows(keyRow: UnsafeRow, valueRowIter: Iterator[InternalRow]): + Iterator[InternalRow] = { + val getKeyObj = + ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) + + val getValueObj = + ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) + + val getOutputRow = 
ObjectOperator.wrapObjectToRow(outputObjectType) + + val keyObj = getKeyObj(keyRow) // convert key to objects + ImplicitKeyTracker.setImplicitKey(keyObj) + val valueObjIter = valueRowIter.map(getValueObj.apply) + val mappedIterator = statefulProcessor.handleInputRows(keyObj, valueObjIter, + new TimerValuesImpl(batchTimestampMs, eventTimeWatermarkForLateEvents)).map { obj => + getOutputRow(obj) + } + ImplicitKeyTracker.removeImplicitKey() + mappedIterator + } + + private def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) + groupedIter.flatMap { case (keyRow, valueRowIter) => + val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] + handleInputRows(keyUnsafeRow, valueRowIter) + } + } + + private def processDataWithPartition( + iter: Iterator[InternalRow], + store: StateStore, + processorHandle: StatefulProcessorHandleImpl): + CompletionIterator[InternalRow, Iterator[InternalRow]] = { + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val commitTimeMs = longMetric("commitTimeMs") + + val currentTimeNs = System.nanoTime + val updatesStartTimeNs = currentTimeNs + + // If timeout is based on event time, then filter late data based on watermark + val filteredIter = watermarkPredicateForDataForLateEvents match { + case Some(predicate) => + applyRemovingRowsOlderThanWatermark(iter, predicate) + case _ => + iter + } + + val outputIterator = processNewData(filteredIter) + processorHandle.setHandleState(StatefulProcessorHandleState.DATA_PROCESSED) + // Return an iterator of all the rows generated by all the keys, such that when fully + // consumed, all the state updates will be committed by the state store + CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, { + // Note: Due to the iterator lazy execution, this metric also captures the time taken + // by the upstream (consumer) operators in addition to the processing in this operator. 
+ allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + commitTimeMs += timeTakenMs { + store.commit() + } + setStoreMetrics(store) + setOperatorMetrics() + statefulProcessor.close() + processorHandle.setHandleState(StatefulProcessorHandleState.CLOSED) + }) + } + + override protected def doExecute(): RDD[InternalRow] = { + metrics Review Comment: nit: Is this to enforce initialization? If so, please leave a code comment. Otherwise, remove this. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala: ########## @@ -118,6 +121,10 @@ class RocksDB( dbOptions.setWriteBufferManager(writeBufferManager) } + // Maintain mapping of column family name to handle + @volatile private var colFamilyNameToHandleMap = Review Comment: Do we ever change the reference? I'm not sure you need `@volatile` and `var` here. If you meant to synchronize access to the map, you need either a synchronized block (for synchronizing multiple accesses) or a ConcurrentHashMap from the Java collections (for single accesses, including CAS).
########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala: ########## @@ -219,49 +232,122 @@ class RocksDB( loadedVersion = endVersion } + private def checkColFamilyExists(colFamilyName: String): Boolean = { + colFamilyNameToHandleMap.contains(colFamilyName) + } + + /** + * Create RocksDB column family, if not created already + */ + def createColFamilyIfAbsent(colFamilyName: String): Unit = { + if (colFamilyName == StateStore.DEFAULT_COL_FAMILY_NAME) { + throw new UnsupportedOperationException("Failed to create column family with reserved " + + s"name=$colFamilyName") + } + + if (!checkColFamilyExists(colFamilyName)) { + assert(db != null) + val descriptor = new ColumnFamilyDescriptor(colFamilyName.getBytes, columnFamilyOptions) + val handle = db.createColumnFamily(descriptor) + colFamilyNameToHandleMap(handle.getName.map(_.toChar).mkString) = handle + } + } + /** * Get the value for the given key if present, or null. * @note This will return the last written value even if it was uncommitted. */ - def get(key: Array[Byte]): Array[Byte] = { - db.get(readOptions, key) + def get(key: Array[Byte], Review Comment: nit: if it does not fit on one line, it is clearer to follow the multi-line style, even with 2 lines.
pull key to below line, 4 spaces for params in overall ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala: ########## @@ -219,49 +232,122 @@ class RocksDB( loadedVersion = endVersion } + private def checkColFamilyExists(colFamilyName: String): Boolean = { + colFamilyNameToHandleMap.contains(colFamilyName) + } + + /** + * Create RocksDB column family, if not created already + */ + def createColFamilyIfAbsent(colFamilyName: String): Unit = { + if (colFamilyName == StateStore.DEFAULT_COL_FAMILY_NAME) { + throw new UnsupportedOperationException("Failed to create column family with reserved " + + s"name=$colFamilyName") + } + + if (!checkColFamilyExists(colFamilyName)) { + assert(db != null) + val descriptor = new ColumnFamilyDescriptor(colFamilyName.getBytes, columnFamilyOptions) + val handle = db.createColumnFamily(descriptor) + colFamilyNameToHandleMap(handle.getName.map(_.toChar).mkString) = handle + } + } + /** * Get the value for the given key if present, or null. * @note This will return the last written value even if it was uncommitted. */ - def get(key: Array[Byte]): Array[Byte] = { - db.get(readOptions, key) + def get(key: Array[Byte], + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Array[Byte] = { + if (useColumnFamilies) { + // if col family is not created, throw an exception + if (!checkColFamilyExists(colFamilyName)) { + throw new RuntimeException(s"Column family with name=$colFamilyName does not exist") + } + db.get(colFamilyNameToHandleMap(colFamilyName), readOptions, key) + } else { + db.get(readOptions, key) + } } /** * Put the given value for the given key. * @note This update is not committed to disk until commit() is called. 
*/ - def put(key: Array[Byte], value: Array[Byte]): Unit = { - if (conf.trackTotalNumberOfRows) { - val oldValue = db.get(readOptions, key) - if (oldValue == null) { - numKeysOnWritingVersion += 1 + def put(key: Array[Byte], value: Array[Byte], Review Comment: ditto ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala: ########## @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.io.Serializable + +import org.apache.commons.lang3.SerializationUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.streaming.ValueState +import org.apache.spark.sql.types._ + +/** + * Class that provides a concrete implementation for a single value state associated with state + * variables used in the streaming transformWithState operator. 
+ * @param store - reference to the StateStore instance to be used for storing state + * @param stateName - name of logical state partition + * @tparam S - data type of object that will be stored + */ +class ValueStateImpl[S]( + store: StateStore, + stateName: String) extends ValueState[S] with Logging{ + + private def encodeKey(): UnsafeRow = { + val keyOption = ImplicitKeyTracker.getImplicitKeyOption + if (!keyOption.isDefined) { + throw new UnsupportedOperationException("Implicit key not found for operation on" + + s"stateName=$stateName") + } + + val schemaForKeyRow: StructType = new StructType().add("key", BinaryType) + val keyByteArr = SerializationUtils.serialize(keyOption.get.asInstanceOf[Serializable]) + val keyEncoder = UnsafeProjection.create(schemaForKeyRow) + val keyRow = keyEncoder(InternalRow(keyByteArr)) + keyRow + } + + private def encodeValue(value: S): UnsafeRow = { + val schemaForValueRow: StructType = new StructType().add("value", BinaryType) + val valueByteArr = SerializationUtils.serialize(value.asInstanceOf[Serializable]) + val valueEncoder = UnsafeProjection.create(schemaForValueRow) + val valueRow = valueEncoder(InternalRow(valueByteArr)) + valueRow + } + + /** Function to check if state exists. 
Returns true if present and false otherwise */ + override def exists(): Boolean = { + val retRow = store.get(encodeKey(), stateName) + if (retRow == null) { Review Comment: nit: `retRow != null` - make it a one-liner ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala: ########## @@ -204,12 +211,18 @@ class RocksDB( for (v <- loadedVersion + 1 to endVersion) { var changelogReader: StateStoreChangelogReader = null try { - changelogReader = fileManager.getChangelogReader(v) - changelogReader.foreach { case (key, value) => - if (value != null) { - put(key, value) - } else { - remove(key) + changelogReader = fileManager.getChangelogReader(v, useColumnFamilies) + changelogReader.foreach { case (recordType, key, value, colFamilyName) => Review Comment: Will we allow putting a `null` value in the state store? Just checking whether we need an additional recordType or not. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala: ########## @@ -604,17 +728,57 @@ class RocksDB( } private def getDBProperty(property: String): Long = { - db.getProperty(property).toLong + if (useColumnFamilies) { + // get cumulative sum across all available column families + assert(!colFamilyNameToHandleMap.isEmpty) + colFamilyNameToHandleMap + .values + .map(handle => db.getProperty(handle, property).toLong) + .sum + } else { + db.getProperty(property).toLong + } } private def openDB(): Unit = { assert(db == null) - db = NativeRocksDB.open(dbOptions, workingDir.toString) - logInfo(s"Opened DB with conf ${conf}") + if (useColumnFamilies) { + val colFamilies = NativeRocksDB.listColumnFamilies(dbOptions, workingDir.toString) + + var colFamilyDescriptors: Seq[ColumnFamilyDescriptor] = Seq.empty[ColumnFamilyDescriptor] + // populate the list of available col family descriptors + colFamilies.asScala.toList.foreach(family => { + val descriptor = new ColumnFamilyDescriptor(family, columnFamilyOptions) + colFamilyDescriptors = colFamilyDescriptors
:+ descriptor + }) + + if (colFamilyDescriptors.isEmpty) { + colFamilyDescriptors = colFamilyDescriptors :+ + new ColumnFamilyDescriptor(NativeRocksDB.DEFAULT_COLUMN_FAMILY, columnFamilyOptions) + } + + val colFamilyHandles = new java.util.ArrayList[ColumnFamilyHandle]() + db = NativeRocksDB.open(new DBOptions(dbOptions), workingDir.toString, + colFamilyDescriptors.asJava, colFamilyHandles) + + // Store the mapping of names to handles in the internal map + colFamilyHandles.asScala.toList.map( Review Comment: nit: foreach? also `{ handle =>`. you can save a couple of lines. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala: ########## @@ -604,17 +728,57 @@ class RocksDB( } private def getDBProperty(property: String): Long = { - db.getProperty(property).toLong + if (useColumnFamilies) { + // get cumulative sum across all available column families + assert(!colFamilyNameToHandleMap.isEmpty) + colFamilyNameToHandleMap + .values + .map(handle => db.getProperty(handle, property).toLong) + .sum + } else { + db.getProperty(property).toLong + } } private def openDB(): Unit = { assert(db == null) - db = NativeRocksDB.open(dbOptions, workingDir.toString) - logInfo(s"Opened DB with conf ${conf}") + if (useColumnFamilies) { + val colFamilies = NativeRocksDB.listColumnFamilies(dbOptions, workingDir.toString) + + var colFamilyDescriptors: Seq[ColumnFamilyDescriptor] = Seq.empty[ColumnFamilyDescriptor] + // populate the list of available col family descriptors + colFamilies.asScala.toList.foreach(family => { Review Comment: nit: `{ family =>` should be sufficient, no need to do `(family => {` ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala: ########## @@ -149,21 +149,31 @@ class RocksDBFileManager( @volatile private var rootDirChecked: Boolean = false - def getChangeLogWriter(version: Long): StateStoreChangelogWriter = { + def getChangeLogWriter(version: Long, Review Comment: 
ditto about style ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala: ########## @@ -219,49 +232,122 @@ class RocksDB( loadedVersion = endVersion } + private def checkColFamilyExists(colFamilyName: String): Boolean = { + colFamilyNameToHandleMap.contains(colFamilyName) + } + + /** + * Create RocksDB column family, if not created already + */ + def createColFamilyIfAbsent(colFamilyName: String): Unit = { + if (colFamilyName == StateStore.DEFAULT_COL_FAMILY_NAME) { + throw new UnsupportedOperationException("Failed to create column family with reserved " + + s"name=$colFamilyName") + } + + if (!checkColFamilyExists(colFamilyName)) { + assert(db != null) + val descriptor = new ColumnFamilyDescriptor(colFamilyName.getBytes, columnFamilyOptions) + val handle = db.createColumnFamily(descriptor) + colFamilyNameToHandleMap(handle.getName.map(_.toChar).mkString) = handle + } + } + /** * Get the value for the given key if present, or null. * @note This will return the last written value even if it was uncommitted. */ - def get(key: Array[Byte]): Array[Byte] = { - db.get(readOptions, key) + def get(key: Array[Byte], + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Array[Byte] = { + if (useColumnFamilies) { + // if col family is not created, throw an exception + if (!checkColFamilyExists(colFamilyName)) { + throw new RuntimeException(s"Column family with name=$colFamilyName does not exist") + } + db.get(colFamilyNameToHandleMap(colFamilyName), readOptions, key) + } else { + db.get(readOptions, key) + } } /** * Put the given value for the given key. * @note This update is not committed to disk until commit() is called. 
*/ - def put(key: Array[Byte], value: Array[Byte]): Unit = { - if (conf.trackTotalNumberOfRows) { - val oldValue = db.get(readOptions, key) - if (oldValue == null) { - numKeysOnWritingVersion += 1 + def put(key: Array[Byte], value: Array[Byte], + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { + if (useColumnFamilies) { + // if col family is not created, throw an exception + if (!checkColFamilyExists(colFamilyName)) { + throw new RuntimeException(s"Column family with name=$colFamilyName does not exist") + } + + if (conf.trackTotalNumberOfRows) { + val oldValue = db.get(colFamilyNameToHandleMap(colFamilyName), readOptions, key) + if (oldValue == null) { + numKeysOnWritingVersion += 1 + } + } + db.put(colFamilyNameToHandleMap(colFamilyName), writeOptions, key, value) + changelogWriter.foreach(_.put(key, value, colFamilyName)) + } else { + if (conf.trackTotalNumberOfRows) { + val oldValue = db.get(readOptions, key) + if (oldValue == null) { + numKeysOnWritingVersion += 1 + } } + db.put(writeOptions, key, value) + changelogWriter.foreach(_.put(key, value)) } - db.put(writeOptions, key, value) - changelogWriter.foreach(_.put(key, value)) } /** * Remove the key if present. * @note This update is not committed to disk until commit() is called. 
*/ - def remove(key: Array[Byte]): Unit = { - if (conf.trackTotalNumberOfRows) { - val value = db.get(readOptions, key) - if (value != null) { - numKeysOnWritingVersion -= 1 + def remove(key: Array[Byte], colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = { Review Comment: ditto ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala: ########## @@ -153,12 +258,70 @@ class StateStoreChangelogReader( val valueSize = input.readInt() if (valueSize < 0) { // A deletion record - (keyBuffer, null) + (RecordType.DELETE_RECORD, keyBuffer, null, StateStore.DEFAULT_COL_FAMILY_NAME) } else { val valueBuffer = new Array[Byte](valueSize) ByteStreams.readFully(input, valueBuffer, 0, valueSize) // A put record. - (keyBuffer, valueBuffer) + (RecordType.PUT_RECORD, keyBuffer, valueBuffer, StateStore.DEFAULT_COL_FAMILY_NAME) + } + } + } +} + +/** + * Read an iterator of change record from the changelog file. + * A record is represented by ByteArrayPair(recordType: RecordType.Value, + * key: Array[Byte], value: Array[Byte], colFamilyName: String) + * A put record is returned as a ByteArrayPair(recordType, key, value, colFamilyName) + * A delete record is return as a ByteArrayPair(recordType, key, null, colFamilyName) + */ +class StateStoreChangelogReaderV2( + fm: CheckpointFileManager, + fileToRead: Path, + compressionCodec: CompressionCodec) extends StateStoreChangelogReader(fm, fileToRead, + compressionCodec) { + + private def parseBuffer(input: DataInputStream): Array[Byte] = { + val blockSize = input.readInt() + val blockBuffer = new Array[Byte](blockSize) + ByteStreams.readFully(input, blockBuffer, 0, blockSize) + blockBuffer + } + + override def getNext(): (RecordType.Value, Array[Byte], Array[Byte], String) = { + val recordTypeSize = input.readInt() + // A -1 key size mean end of file. 
+ if (recordTypeSize == -1) { + finished = true + null + } else if (recordTypeSize < 0) { + throw new IOException( + s"Error reading streaming state file $fileToRead: " + + s"record type size cannot be $recordTypeSize") + } else { + val recordTypeBuffer = new Array[Byte](recordTypeSize) Review Comment: This proves that the overhead is significant. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala: ########## @@ -109,18 +119,94 @@ class StateStoreChangelogWriter( } } +/** + * Write changes to the key value state store instance to a changelog file. + * There are 2 types of records, put and delete. + * A put record is written as: | key length | key content | value length | value content | + * A delete record is written as: | key length | key content | -1 | + * Write an Int -1 to signal the end of file. + * The overall changelog format is: | put record | delete record | ... | put record | -1 | + */ +class StateStoreChangelogWriterV1( + fm: CheckpointFileManager, + file: Path, + compressionCodec: CompressionCodec) + extends StateStoreChangelogWriter(fm, file, compressionCodec) { + + override def put(key: Array[Byte], value: Array[Byte]): Unit = { + assert(compressedStream != null) + compressedStream.writeInt(key.size) + compressedStream.write(key) + compressedStream.writeInt(value.size) + compressedStream.write(value) + size += 1 + } + + override def delete(key: Array[Byte]): Unit = { + assert(compressedStream != null) + compressedStream.writeInt(key.size) + compressedStream.write(key) + // -1 in the value field means record deletion. + compressedStream.writeInt(-1) + size += 1 + } +} /** - * Read an iterator of change record from the changelog file. 
- * A record is represented by ByteArrayPair(key: Array[Byte], value: Array[Byte]) - * A put record is returned as a ByteArrayPair(key, value) - * A delete record is return as a ByteArrayPair(key, null) + * Write changes to the key value state store instance to a changelog file. + * There are 2 types of records, put and delete. + * A put record is written as: | record type | key length + * | key content | value length | value content | col family name length | col family name | -1 | + * A delete record is written as: | record type | key length | key content | -1 + * | col family name length | col family name | -1 | + * Write an Int -1 to signal the end of file. + * The overall changelog format is: | put record | delete record | ... | put record | -1 | + */ +class StateStoreChangelogWriterV2( + fm: CheckpointFileManager, + file: Path, + compressionCodec: CompressionCodec) + extends StateStoreChangelogWriter(fm, file, compressionCodec) { Review Comment: nit: 2 spaces ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala: ########## @@ -109,18 +119,94 @@ class StateStoreChangelogWriter( } } +/** + * Write changes to the key value state store instance to a changelog file. + * There are 2 types of records, put and delete. + * A put record is written as: | key length | key content | value length | value content | + * A delete record is written as: | key length | key content | -1 | + * Write an Int -1 to signal the end of file. + * The overall changelog format is: | put record | delete record | ... 
| put record | -1 | + */ +class StateStoreChangelogWriterV1( + fm: CheckpointFileManager, + file: Path, + compressionCodec: CompressionCodec) + extends StateStoreChangelogWriter(fm, file, compressionCodec) { + + override def put(key: Array[Byte], value: Array[Byte]): Unit = { + assert(compressedStream != null) + compressedStream.writeInt(key.size) + compressedStream.write(key) + compressedStream.writeInt(value.size) + compressedStream.write(value) + size += 1 + } + + override def delete(key: Array[Byte]): Unit = { + assert(compressedStream != null) + compressedStream.writeInt(key.size) + compressedStream.write(key) + // -1 in the value field means record deletion. + compressedStream.writeInt(-1) + size += 1 + } +} /** - * Read an iterator of change record from the changelog file. - * A record is represented by ByteArrayPair(key: Array[Byte], value: Array[Byte]) - * A put record is returned as a ByteArrayPair(key, value) - * A delete record is return as a ByteArrayPair(key, null) + * Write changes to the key value state store instance to a changelog file. + * There are 2 types of records, put and delete. + * A put record is written as: | record type | key length + * | key content | value length | value content | col family name length | col family name | -1 | + * A delete record is written as: | record type | key length | key content | -1 + * | col family name length | col family name | -1 | + * Write an Int -1 to signal the end of file. + * The overall changelog format is: | put record | delete record | ... 
| put record | -1 | + */ +class StateStoreChangelogWriterV2( + fm: CheckpointFileManager, + file: Path, + compressionCodec: CompressionCodec) + extends StateStoreChangelogWriter(fm, file, compressionCodec) { + + override def put(key: Array[Byte], value: Array[Byte], colFamilyName: String): Unit = { + assert(compressedStream != null) + compressedStream.writeInt(RecordType.PUT_RECORD.toString.getBytes.size) + compressedStream.write(RecordType.PUT_RECORD.toString.getBytes) + compressedStream.writeInt(key.size) + compressedStream.write(key) + compressedStream.writeInt(value.size) + compressedStream.write(value) + compressedStream.writeInt(colFamilyName.getBytes.size) + compressedStream.write(colFamilyName.getBytes) + size += 1 + } + + override def delete(key: Array[Byte], colFamilyName: String): Unit = { + assert(compressedStream != null) + compressedStream.writeInt(RecordType.DELETE_RECORD.toString.getBytes.size) + compressedStream.write(RecordType.DELETE_RECORD.toString.getBytes) + compressedStream.writeInt(key.size) + compressedStream.write(key) + // -1 in the value field means record deletion. 
+ compressedStream.writeInt(-1) + compressedStream.writeInt(colFamilyName.getBytes.size) + compressedStream.write(colFamilyName.getBytes) + size += 1 + } +} + +/** + * Base class for state store changelog reader + * @param fm - checkpoint file manager used to manage streaming query checkpoint + * @param fileToRead - name of file to use to read changelog + * @param compressionCodec - de-compression method using for reading changelog file */ class StateStoreChangelogReader( fm: CheckpointFileManager, fileToRead: Path, compressionCodec: CompressionCodec) - extends NextIterator[(Array[Byte], Array[Byte])] with Logging { + extends NextIterator[(RecordType.Value, Array[Byte], Array[Byte], String)] Review Comment: nit: 2 spaces ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala: ########## @@ -149,21 +149,31 @@ class RocksDBFileManager( @volatile private var rootDirChecked: Boolean = false - def getChangeLogWriter(version: Long): StateStoreChangelogWriter = { + def getChangeLogWriter(version: Long, + useColumnFamilies: Boolean = false): StateStoreChangelogWriter = { val changelogFile = dfsChangelogFile(version) if (!rootDirChecked) { val rootDir = new Path(dfsRootDir) if (!fm.exists(rootDir)) fm.mkdirs(rootDir) rootDirChecked = true } - val changelogWriter = new StateStoreChangelogWriter(fm, changelogFile, codec) + val changelogWriter = if (useColumnFamilies) { + new StateStoreChangelogWriterV2(fm, changelogFile, codec) + } else { + new StateStoreChangelogWriterV1(fm, changelogFile, codec) + } changelogWriter } // Get the changelog file at version - def getChangelogReader(version: Long): StateStoreChangelogReader = { + def getChangelogReader(version: Long, Review Comment: ditto about style ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala: ########## @@ -109,18 +119,94 @@ class StateStoreChangelogWriter( } } +/** + * Write changes to the key value state store 
instance to a changelog file. + * There are 2 types of records, put and delete. + * A put record is written as: | key length | key content | value length | value content | + * A delete record is written as: | key length | key content | -1 | + * Write an Int -1 to signal the end of file. + * The overall changelog format is: | put record | delete record | ... | put record | -1 | + */ +class StateStoreChangelogWriterV1( + fm: CheckpointFileManager, + file: Path, + compressionCodec: CompressionCodec) + extends StateStoreChangelogWriter(fm, file, compressionCodec) { Review Comment: nit: 2 spaces ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala: ########## @@ -153,12 +258,70 @@ class StateStoreChangelogReader( val valueSize = input.readInt() if (valueSize < 0) { // A deletion record - (keyBuffer, null) + (RecordType.DELETE_RECORD, keyBuffer, null, StateStore.DEFAULT_COL_FAMILY_NAME) } else { val valueBuffer = new Array[Byte](valueSize) ByteStreams.readFully(input, valueBuffer, 0, valueSize) // A put record. - (keyBuffer, valueBuffer) + (RecordType.PUT_RECORD, keyBuffer, valueBuffer, StateStore.DEFAULT_COL_FAMILY_NAME) + } + } + } +} + +/** + * Read an iterator of change record from the changelog file. + * A record is represented by ByteArrayPair(recordType: RecordType.Value, + * key: Array[Byte], value: Array[Byte], colFamilyName: String) + * A put record is returned as a ByteArrayPair(recordType, key, value, colFamilyName) + * A delete record is return as a ByteArrayPair(recordType, key, null, colFamilyName) + */ +class StateStoreChangelogReaderV2( + fm: CheckpointFileManager, + fileToRead: Path, + compressionCodec: CompressionCodec) extends StateStoreChangelogReader(fm, fileToRead, Review Comment: ditto -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
