[SPARK-50128][SS] Add stateful processor handle APIs using implicit encoders in Scala

### What changes were proposed in this pull request?
Add stateful processor handle APIs using implicit encoders in Scala

### Why are the changes needed?
Without these changes, users have to pass explicit SQL encoders for state types when acquiring an instance of the underlying state variable.

### Does this PR introduce _any_ user-facing change?
Yes

Users can now use the implicits available in Scala through `import spark.implicits._` and provide only the type when getting the state objects. For example:

```
override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
  _myValueState = getHandle.getValueState[Long]("myValueState", TTLConfig.NONE)
}
```
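
The call-site difference can be illustrated with a minimal, self-contained sketch. `SketchEncoder`, `SketchImplicits`, and `SketchHandle` are hypothetical stand-ins, not Spark classes; they only model how a context bound lets the compiler supply an encoder that previously had to be passed by hand.

```scala
object ImplicitEncoderCallSite {
  // Hypothetical stand-in for org.apache.spark.sql.Encoder (assumption, not the Spark type).
  trait SketchEncoder[T] { def typeName: String }

  // Hypothetical stand-in for the instances that `import spark.implicits._` brings into scope.
  object SketchImplicits {
    implicit val longEncoder: SketchEncoder[Long] =
      new SketchEncoder[Long] { val typeName = "Long" }
  }

  // Hypothetical stand-in for the handle; it returns a description instead of a state object.
  class SketchHandle {
    // Explicit style: the caller passes the encoder as an argument.
    def getValueState[T](stateName: String, enc: SketchEncoder[T]): String =
      s"$stateName:${enc.typeName}"

    // Implicit style: the context bound asks the compiler to find the encoder.
    def getValueState[T: SketchEncoder](stateName: String): String =
      getValueState(stateName, implicitly[SketchEncoder[T]])
  }

  def main(args: Array[String]): Unit = {
    import SketchImplicits._
    val handle = new SketchHandle
    println(handle.getValueState("myValueState", longEncoder)) // explicit encoder
    println(handle.getValueState[Long]("myValueState"))        // encoder resolved implicitly
  }
}
```

Both calls print `myValueState:Long`; the second needs only the type argument because the encoder instance is found in implicit scope.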

### How was this patch tested?
Existing unit tests

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#48728 from anishshri-db/task/SPARK-50128.

Authored-by: Anish Shrigondekar <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
anishshri-db authored and HeartSaVioR committed Nov 5, 2024
1 parent 0d2d031 commit 47063a6
Showing 25 changed files with 367 additions and 255 deletions.
@@ -29,25 +29,12 @@ import org.apache.spark.sql.Encoder
@Evolving
private[sql] trait StatefulProcessorHandle extends Serializable {

/**
* Function to create new or return existing single value state variable of given type. The user
* must ensure to call this function only within the `init()` method of the StatefulProcessor.
*
* @param stateName
* - name of the state variable
* @param valEncoder
* - SQL encoder for state variable
* @tparam T
* - type of state variable
* @return
* - instance of ValueState of type T that can be used to store state persistently
*/
def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T]

/**
* Function to create new or return existing single value state variable of given type with ttl.
* State values will not be returned past ttlDuration, and will be eventually removed from the
* state store. Any state update resets the ttl to current processing time plus ttlDuration.
* Users can use the helper method `TTLConfig.NONE` in Scala or `TTLConfig.NONE()` in Java for
* the TTLConfig parameter to disable TTL for the state variable.
*
* The user must ensure to call this function only within the `init()` method of the
* StatefulProcessor.
@@ -69,25 +56,34 @@ private[sql] trait StatefulProcessorHandle extends Serializable {
ttlConfig: TTLConfig): ValueState[T]

/**
* Creates new or returns existing list state associated with stateName. The ListState persists
* values of type T.
* (Scala-specific) Function to create new or return existing single value state variable of
* given type with ttl. State values will not be returned past ttlDuration, and will be
* eventually removed from the state store. Any state update resets the ttl to current
* processing time plus ttlDuration. Users can use the helper method `TTLConfig.NONE` in Scala
* or `TTLConfig.NONE()` in Java for the TTLConfig parameter to disable TTL for the state
* variable.
*
* The user must ensure to call this function only within the `init()` method of the
* StatefulProcessor. Note that this API uses the implicit SQL encoder in Scala.
*
* @param stateName
* - name of the state variable
* @param valEncoder
* - SQL encoder for state variable
* @param ttlConfig
* - the ttl configuration (time to live duration etc.)
* @tparam T
* - type of state variable
* @return
* - instance of ListState of type T that can be used to store state persistently
* - instance of ValueState of type T that can be used to store state persistently
*/
def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T]
def getValueState[T: Encoder](stateName: String, ttlConfig: TTLConfig): ValueState[T]

/**
* Function to create new or return existing list state variable of given type with ttl. State
* values will not be returned past ttlDuration, and will be eventually removed from the state
* store. Any values in listState which have expired after ttlDuration will not be returned on
* get() and will be eventually removed from the state.
* get() and will be eventually removed from the state. Users can use the helper method
* `TTLConfig.NONE` in Scala or `TTLConfig.NONE()` in Java for the TTLConfig parameter to
* disable TTL for the state variable.
*
* The user must ensure to call this function only within the `init()` method of the
* StatefulProcessor.
@@ -109,32 +105,34 @@ private[sql] trait StatefulProcessorHandle extends Serializable {
ttlConfig: TTLConfig): ListState[T]

/**
* Creates new or returns existing map state associated with stateName. The MapState persists
* Key-Value pairs of type [K, V].
* (Scala-specific) Function to create new or return existing list state variable of given type
* with ttl. State values will not be returned past ttlDuration, and will be eventually removed
* from the state store. Any values in listState which have expired after ttlDuration will not
* be returned on get() and will be eventually removed from the state. Users can use the helper
* method `TTLConfig.NONE` in Scala or `TTLConfig.NONE()` in Java for the TTLConfig parameter to
* disable TTL for the state variable.
*
* The user must ensure to call this function only within the `init()` method of the
* StatefulProcessor. Note that this API uses the implicit SQL encoder in Scala.
*
* @param stateName
* - name of the state variable
* @param userKeyEnc
* - spark sql encoder for the map key
* @param valEncoder
* - spark sql encoder for the map value
* @tparam K
* - type of key for map state variable
* @tparam V
* - type of value for map state variable
* @param ttlConfig
* - the ttl configuration (time to live duration etc.)
* @tparam T
* - type of state variable
* @return
* - instance of MapState of type [K,V] that can be used to store state persistently
* - instance of ListState of type T that can be used to store state persistently
*/
def getMapState[K, V](
stateName: String,
userKeyEnc: Encoder[K],
valEncoder: Encoder[V]): MapState[K, V]
def getListState[T: Encoder](stateName: String, ttlConfig: TTLConfig): ListState[T]

/**
* Function to create new or return existing map state variable of given type with ttl. State
* values will not be returned past ttlDuration, and will be eventually removed from the state
* store. Any values in mapState which have expired after ttlDuration will not returned on get()
* and will be eventually removed from the state.
* and will be eventually removed from the state. Users can use the helper method
* `TTLConfig.NONE` in Scala or `TTLConfig.NONE()` in Java for the TTLConfig parameter to
* disable TTL for the state variable.
*
* The user must ensure to call this function only within the `init()` method of the
* StatefulProcessor.
@@ -160,6 +158,30 @@ private[sql] trait StatefulProcessorHandle extends Serializable {
valEncoder: Encoder[V],
ttlConfig: TTLConfig): MapState[K, V]

/**
* (Scala-specific) Function to create new or return existing map state variable of given type
* with ttl. State values will not be returned past ttlDuration, and will be eventually removed
* from the state store. Any values in mapState which have expired after ttlDuration will not be
* returned on get() and will be eventually removed from the state. Users can use the helper
* method `TTLConfig.NONE` in Scala or `TTLConfig.NONE()` in Java for the TTLConfig parameter to
* disable TTL for the state variable.
*
* The user must ensure to call this function only within the `init()` method of the
* StatefulProcessor. Note that this API uses the implicit SQL encoder in Scala.
*
* @param stateName
* - name of the state variable
* @param ttlConfig
* - the ttl configuration (time to live duration etc.)
* @tparam K
* - type of key for map state variable
* @tparam V
* - type of value for map state variable
* @return
* - instance of MapState of type [K,V] that can be used to store state persistently
*/
def getMapState[K: Encoder, V: Encoder](stateName: String, ttlConfig: TTLConfig): MapState[K, V]

/** Function to return queryInfo for currently running task */
def getQueryInfo(): QueryInfo
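
A context bound such as `K: Encoder, V: Encoder` in the `getMapState` signature above is shorthand for a trailing implicit parameter list. The sketch below shows that equivalence under stated assumptions: `Enc` is a hypothetical stand-in type class (not Spark's `Encoder`), and `describeBound` and `describeExplicit` are illustrative functions that are not part of this change.

```scala
object ContextBoundDesugaring {
  // Hypothetical stand-in type class; Spark's Encoder is not used here.
  trait Enc[T] { def name: String }
  implicit val stringEnc: Enc[String] = new Enc[String] { val name = "String" }
  implicit val longEnc: Enc[Long] = new Enc[Long] { val name = "Long" }

  // Context-bound form, shaped like getMapState[K: Encoder, V: Encoder].
  def describeBound[K: Enc, V: Enc](stateName: String): String =
    s"$stateName[${implicitly[Enc[K]].name}, ${implicitly[Enc[V]].name}]"

  // Equivalent form with the implicit parameter list written out.
  def describeExplicit[K, V](stateName: String)(implicit k: Enc[K], v: Enc[V]): String =
    s"$stateName[${k.name}, ${v.name}]"

  def main(args: Array[String]): Unit = {
    println(describeBound[String, Long]("myMapState"))
    println(describeExplicit[String, Long]("myMapState"))
    // The implicit list can still be supplied by hand when needed.
    println(describeExplicit("myMapState")(stringEnc, longEnc))
  }
}
```

All three calls print `myMapState[String, Long]`, which is why the Scala-specific overloads can drop the encoder arguments without losing the information they carried.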

@@ -24,7 +24,22 @@ import java.time.Duration
* will be eventually removed from the state store. Any state update resets the ttl to current
* processing time plus ttlDuration.
*
* Passing a TTL duration of zero will disable the TTL for the state variable. Users can also use
* the helper method `TTLConfig.NONE` in Scala or `TTLConfig.NONE()` in Java to disable TTL for
* the state variable.
*
* @param ttlDuration
* time to live duration for state stored in the state variable.
*/
case class TTLConfig(ttlDuration: Duration)

object TTLConfig {

/**
* Helper method to create a TTLConfig with expiry duration as Zero
* @return
* - TTLConfig with expiry duration as Zero
*/
def NONE: TTLConfig = TTLConfig(Duration.ZERO)

}
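
The documented semantics (a zero duration disables TTL; otherwise a state update sets expiry to the current processing time plus `ttlDuration`) can be sketched as follows. The `TTLConfig` definitions mirror the diff above, while `isTtlEnabled` and `expirationMs` are hypothetical helpers written for illustration only and are not Spark APIs.

```scala
import java.time.Duration

object TtlConfigSketch {
  // Mirrors the definitions shown in the diff above.
  case class TTLConfig(ttlDuration: Duration)
  object TTLConfig {
    def NONE: TTLConfig = TTLConfig(Duration.ZERO)
  }

  // Hypothetical helper (not a Spark API): a zero duration means TTL is disabled.
  def isTtlEnabled(config: TTLConfig): Boolean = !config.ttlDuration.isZero

  // Hypothetical helper: expiry is the current processing time plus ttlDuration,
  // or None when TTL is disabled.
  def expirationMs(config: TTLConfig, processingTimeMs: Long): Option[Long] =
    if (isTtlEnabled(config)) Some(processingTimeMs + config.ttlDuration.toMillis)
    else None

  def main(args: Array[String]): Unit = {
    println(expirationMs(TTLConfig.NONE, 1000L))                    // None
    println(expirationMs(TTLConfig(Duration.ofMillis(500)), 1000L)) // Some(1500)
  }
}
```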
@@ -591,20 +591,23 @@ class TransformWithStateInPandasStateServer(
stateType match {
case StateVariableType.ValueState => if (!valueStates.contains(stateName)) {
val state = if (ttlDurationMs.isEmpty) {
statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema))
statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema),
TTLConfig.NONE)
} else {
statefulProcessorHandle.getValueState(
stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
}
valueStates.put(stateName,
ValueStateInfo(state, schema, expressionEncoder.createDeserializer()))
sendResponse(0)
} else {
statefulProcessorHandle.getValueState(
stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
sendResponse(1, s"Value state $stateName already exists")
}
valueStates.put(stateName,
ValueStateInfo(state, schema, expressionEncoder.createDeserializer()))
sendResponse(0)
} else {
sendResponse(1, s"Value state $stateName already exists")
}

case StateVariableType.ListState => if (!listStates.contains(stateName)) {
val state = if (ttlDurationMs.isEmpty) {
statefulProcessorHandle.getListState[Row](stateName, Encoders.row(schema))
statefulProcessorHandle.getListState[Row](stateName, Encoders.row(schema),
TTLConfig.NONE)
} else {
statefulProcessorHandle.getListState(
stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
@@ -616,12 +619,13 @@ class TransformWithStateInPandasStateServer(
} else {
sendResponse(1, s"List state $stateName already exists")
}

case StateVariableType.MapState => if (!mapStates.contains(stateName)) {
val valueSchema = StructType.fromString(mapStateValueSchemaString)
val valueExpressionEncoder = ExpressionEncoder(valueSchema).resolveAndBind()
val state = if (ttlDurationMs.isEmpty) {
statefulProcessorHandle.getMapState[Row, Row](stateName,
Encoders.row(schema), Encoders.row(valueSchema))
Encoders.row(schema), Encoders.row(valueSchema), TTLConfig.NONE)
} else {
statefulProcessorHandle.getMapState[Row, Row](stateName, Encoders.row(schema),
Encoders.row(valueSchema), TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.streaming

import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
@@ -39,7 +38,7 @@ class ListStateImpl[S](
store: StateStore,
stateName: String,
keyExprEnc: ExpressionEncoder[Any],
valEncoder: Encoder[S],
valEncoder: ExpressionEncoder[Any],
metrics: Map[String, SQLMetric] = Map.empty)
extends ListStateMetricsImpl
with ListState[S]
@@ -75,7 +74,7 @@ class ListStateImpl[S](

override def next(): S = {
val valueUnsafeRow = unsafeRowValuesIterator.next()
stateTypesEncoder.decodeValue(valueUnsafeRow)
stateTypesEncoder.decodeValue(valueUnsafeRow).asInstanceOf[S]
}
}
}
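
Because the state implementations now receive an encoder typed as `ExpressionEncoder[Any]`, decoded values come back as `Any` and are cast to the state's type parameter. A minimal sketch of that pattern, using the hypothetical stand-ins `AnyCodec` and `TypedValues` rather than Spark classes:

```scala
object ErasedEncoderCastSketch {
  // Hypothetical stand-in for a codec typed over Any, similar in spirit to ExpressionEncoder[Any].
  final class AnyCodec {
    def encode(value: Any): Vector[Any] = Vector(value)
    def decode(row: Vector[Any]): Any = row.head
  }

  // Hypothetical stand-in for a typed state wrapper such as ListStateImpl[S].
  final class TypedValues[S](codec: AnyCodec) {
    private var rows: Vector[Vector[Any]] = Vector.empty

    def append(value: S): Unit = {
      rows = rows :+ codec.encode(value)
    }

    // decode returns Any, so each element is cast back to S.
    // The cast is unchecked because S is erased at runtime.
    def get(): Iterator[S] =
      rows.iterator.map(row => codec.decode(row).asInstanceOf[S])
  }

  def main(args: Array[String]): Unit = {
    val values = new TypedValues[Long](new AnyCodec)
    values.append(1L)
    values.append(2L)
    println(values.get().sum) // 3
  }
}
```

Since the cast is unchecked, a mismatch between the encoder and `S` would surface only where the value is later used, so the caller is responsible for pairing them correctly.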
@@ -16,7 +16,6 @@
*/
package org.apache.spark.sql.execution.streaming

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -43,7 +42,7 @@ class ListStateImplWithTTL[S](
store: StateStore,
stateName: String,
keyExprEnc: ExpressionEncoder[Any],
valEncoder: Encoder[S],
valEncoder: ExpressionEncoder[Any],
ttlConfig: TTLConfig,
batchTimestampMs: Long,
metrics: Map[String, SQLMetric] = Map.empty)
@@ -91,7 +90,7 @@ class ListStateImplWithTTL[S](

if (iter.hasNext) {
val currentRow = iter.next()
stateTypesEncoder.decodeValue(currentRow)
stateTypesEncoder.decodeValue(currentRow).asInstanceOf[S]
} else {
finished = true
null.asInstanceOf[S]
@@ -223,7 +222,7 @@ class ListStateImplWithTTL[S](
val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey, stateName)
unsafeRowValuesIterator.map { valueUnsafeRow =>
stateTypesEncoder.decodeValue(valueUnsafeRow)
stateTypesEncoder.decodeValue(valueUnsafeRow).asInstanceOf[S]
}
}

@@ -234,7 +233,7 @@ class ListStateImplWithTTL[S](
val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey, stateName)
unsafeRowValuesIterator.map { valueUnsafeRow =>
(stateTypesEncoder.decodeValue(valueUnsafeRow),
(stateTypesEncoder.decodeValue(valueUnsafeRow).asInstanceOf[S],
stateTypesEncoder.decodeTtlExpirationMs(valueUnsafeRow).get)
}
}
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.streaming

import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
@@ -40,8 +39,8 @@ class MapStateImpl[K, V](
store: StateStore,
stateName: String,
keyExprEnc: ExpressionEncoder[Any],
userKeyEnc: Encoder[K],
valEncoder: Encoder[V],
userKeyEnc: ExpressionEncoder[Any],
valEncoder: ExpressionEncoder[Any],
metrics: Map[String, SQLMetric] = Map.empty) extends MapState[K, V] with Logging {

// Pack grouping key and user key together as a prefixed composite key
@@ -67,7 +66,7 @@ class MapStateImpl[K, V](
val unsafeRowValue = store.get(encodedCompositeKey, stateName)

if (unsafeRowValue == null) return null.asInstanceOf[V]
stateTypesEncoder.decodeValue(unsafeRowValue)
stateTypesEncoder.decodeValue(unsafeRowValue).asInstanceOf[V]
}

/** Check if the user key is contained in the map */
@@ -92,8 +91,8 @@ class MapStateImpl[K, V](
store.prefixScan(encodedGroupingKey, stateName)
.map {
case iter: UnsafeRowPair =>
(stateTypesEncoder.decodeCompositeKey(iter.key),
stateTypesEncoder.decodeValue(iter.value))
(stateTypesEncoder.decodeCompositeKey(iter.key).asInstanceOf[K],
stateTypesEncoder.decodeValue(iter.value).asInstanceOf[V])
}
}

@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.streaming

import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -45,8 +44,8 @@ class MapStateImplWithTTL[K, V](
store: StateStore,
stateName: String,
keyExprEnc: ExpressionEncoder[Any],
userKeyEnc: Encoder[K],
valEncoder: Encoder[V],
userKeyEnc: ExpressionEncoder[Any],
valEncoder: ExpressionEncoder[Any],
ttlConfig: TTLConfig,
batchTimestampMs: Long,
metrics: Map[String, SQLMetric] = Map.empty)
@@ -83,7 +82,7 @@ class MapStateImplWithTTL[K, V](

if (retRow != null) {
if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
stateTypesEncoder.decodeValue(retRow)
stateTypesEncoder.decodeValue(retRow).asInstanceOf[V]
} else {
null.asInstanceOf[V]
}
@@ -126,7 +125,9 @@ class MapStateImplWithTTL[K, V](
if (iter.hasNext) {
val currentRowPair = iter.next()
val key = stateTypesEncoder.decodeCompositeKey(currentRowPair.key)
.asInstanceOf[K]
val value = stateTypesEncoder.decodeValue(currentRowPair.value)
.asInstanceOf[V]
(key, value)
} else {
finished = true
@@ -213,7 +214,7 @@ class MapStateImplWithTTL[K, V](
val retRow = store.get(encodedCompositeKey, stateName)

if (retRow != null) {
val resState = stateTypesEncoder.decodeValue(retRow)
val resState = stateTypesEncoder.decodeValue(retRow).asInstanceOf[V]
Some(resState)
} else {
None
@@ -231,7 +232,9 @@ class MapStateImplWithTTL[K, V](
// ttlExpiration
Option(retRow).flatMap { row =>
val ttlExpiration = stateTypesEncoder.decodeTtlExpirationMs(row)
ttlExpiration.map(expiration => (stateTypesEncoder.decodeValue(row), expiration))
ttlExpiration.map { expiration =>
(stateTypesEncoder.decodeValue(row).asInstanceOf[V], expiration)
}
}
}

@@ -253,7 +256,7 @@ class MapStateImplWithTTL[K, V](
0, keyExprEnc.schema.length)) {
val userKey = stateTypesEncoder.decodeUserKey(
nextTtlValue.userKey)
nextValue = Some(userKey, nextTtlValue.expirationMs)
nextValue = Some(userKey.asInstanceOf[K], nextTtlValue.expirationMs)
}
}
nextValue.isDefined