From b1d1f10f96b1a92037a0205854745efeac5717ea Mon Sep 17 00:00:00 2001
From: Anish Shrigondekar
Date: Sat, 19 Oct 2024 06:22:02 +0900
Subject: [PATCH] [SPARK-49846][SS] Add numUpdatedStateRows and numRemovedStateRows metrics for use with transformWithState operator

### What changes were proposed in this pull request?

Add `numUpdatedStateRows` and `numRemovedStateRows` metrics for use with the transformWithState operator.

### Why are the changes needed?

Without this change, metrics for these state operations are not available in the streaming query progress.

### Does this PR introduce _any_ user-facing change?

No. Metrics are updated as part of the streaming query progress:

```
"operatorName" : "transformWithStateExec",
"numRowsTotal" : 1,
"numRowsUpdated" : 1,
"numRowsRemoved" : 1,
```

### How was this patch tested?

Added unit tests:

```
[info] Run completed in 25 seconds, 697 milliseconds.
[info] Total number of tests run: 2
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 2, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48317 from anishshri-db/task/SPARK-49846.

Authored-by: Anish Shrigondekar
Signed-off-by: Jungtaek Lim
---
 .../execution/streaming/ListStateImpl.scala   | 37 ++++++--
 .../streaming/ListStateImplWithTTL.scala      | 38 +++++++-
 .../streaming/ListStateMetricsImpl.scala      | 86 +++++++++++++++++++
 .../execution/streaming/MapStateImpl.scala    | 19 +++-
 .../streaming/MapStateImplWithTTL.scala       | 10 ++-
 .../StatefulProcessorHandleImpl.scala         | 54 ++++++++----
 .../execution/streaming/ValueStateImpl.scala  |  7 +-
 .../streaming/ValueStateImplWithTTL.scala     |  8 +-
 .../TransformWithListStateSuite.scala         |  2 +
 .../TransformWithListStateTTLSuite.scala      |  9 +-
 .../TransformWithMapStateSuite.scala          |  4 +
 .../TransformWithMapStateTTLSuite.scala       | 11 ++-
 .../streaming/TransformWithStateSuite.scala   |  5 ++
 .../TransformWithValueStateTTLSuite.scala     | 14 ++-
 14 files changed, 266 insertions(+), 38 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateMetricsImpl.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index 497472ce63676..77c481a8ba0ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -19,8 +19,10 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.ListState +import org.apache.spark.sql.types.StructType /** * Provides concrete implementation for list of values associated with a state variable * @@ -30,14 +32,22 @@ import org.apache.spark.sql.streaming.ListState * @param stateName - name of logical state partition * @param keyExprEnc - Spark SQL encoder for key * @param valEncoder - Spark SQL encoder for value + * @param metrics - metrics to be updated as part of stateful processing * @tparam S - data type of object that will be stored in the list */ class ListStateImpl[S]( store: StateStore, stateName: String, keyExprEnc: 
ExpressionEncoder[Any], - valEncoder: Encoder[S]) - extends ListState[S] with Logging { + valEncoder: Encoder[S], + metrics: Map[String, SQLMetric] = Map.empty) + extends ListStateMetricsImpl + with ListState[S] + with Logging { + + override def stateStore: StateStore = store + override def baseStateName: String = stateName + override def exprEncSchema: StructType = keyExprEnc.schema private val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder, stateName) @@ -76,6 +86,8 @@ class ListStateImpl[S]( val encodedKey = stateTypesEncoder.encodeGroupingKey() var isFirst = true + var entryCount = 0L + TWSMetricsUtils.resetMetric(metrics, "numUpdatedStateRows") newState.foreach { v => val encodedValue = stateTypesEncoder.encodeValue(v) @@ -83,16 +95,23 @@ class ListStateImpl[S]( store.put(encodedKey, encodedValue, stateName) isFirst = false } else { - store.merge(encodedKey, encodedValue, stateName) + store.merge(encodedKey, encodedValue, stateName) } + entryCount += 1 + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") } + updateEntryCount(encodedKey, entryCount) } /** Append an entry to the list. */ override def appendValue(newState: S): Unit = { StateStoreErrors.requireNonNullStateValue(newState, stateName) - store.merge(stateTypesEncoder.encodeGroupingKey(), + val encodedKey = stateTypesEncoder.encodeGroupingKey() + val entryCount = getEntryCount(encodedKey) + store.merge(encodedKey, stateTypesEncoder.encodeValue(newState), stateName) + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") + updateEntryCount(encodedKey, entryCount + 1) } /** Append an entire list to the existing value. */ @@ -100,15 +119,23 @@ class ListStateImpl[S]( validateNewState(newState) val encodedKey = stateTypesEncoder.encodeGroupingKey() + var entryCount = getEntryCount(encodedKey) newState.foreach { v => val encodedValue = stateTypesEncoder.encodeValue(v) store.merge(encodedKey, encodedValue, stateName) + entryCount += 1 + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") } + updateEntryCount(encodedKey, entryCount) } /** Remove this state. 
*/ override def clear(): Unit = { - store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) + val encodedKey = stateTypesEncoder.encodeGroupingKey() + store.remove(encodedKey, stateName) + val entryCount = getEntryCount(encodedKey) + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", entryCount) + removeEntryCount(encodedKey) } private def validateNewState(newState: Array[S]): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala index faeec7cb93863..be47f566bc6a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.{ListState, TTLConfig} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.NextIterator /** @@ -34,6 +36,7 @@ import org.apache.spark.util.NextIterator * @param valEncoder - Spark SQL encoder for value * @param ttlConfig - TTL configuration for values stored in this state * @param batchTimestampMs - current batch processing timestamp. + * @param metrics - metrics to be updated as part of stateful processing * @tparam S - data type of object that will be stored */ class ListStateImplWithTTL[S]( @@ -42,9 +45,15 @@ class ListStateImplWithTTL[S]( keyExprEnc: ExpressionEncoder[Any], valEncoder: Encoder[S], ttlConfig: TTLConfig, - batchTimestampMs: Long) - extends SingleKeyTTLStateImpl( - stateName, store, keyExprEnc, batchTimestampMs) with ListState[S] { + batchTimestampMs: Long, + metrics: Map[String, SQLMetric] = Map.empty) + extends SingleKeyTTLStateImpl(stateName, store, keyExprEnc, batchTimestampMs) + with ListStateMetricsImpl + with ListState[S] { + + override def stateStore: StateStore = store + override def baseStateName: String = stateName + override def exprEncSchema: StructType = keyExprEnc.schema private lazy val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder, stateName, hasTtl = true) @@ -99,6 +108,8 @@ class ListStateImplWithTTL[S]( val encodedKey = stateTypesEncoder.encodeGroupingKey() var isFirst = true + var entryCount = 0L + TWSMetricsUtils.resetMetric(metrics, "numUpdatedStateRows") newState.foreach { v => val encodedValue = stateTypesEncoder.encodeValue(v, ttlExpirationMs) @@ -108,17 +119,23 @@ class ListStateImplWithTTL[S]( } else { store.merge(encodedKey, encodedValue, stateName) } + entryCount += 1 + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") } upsertTTLForStateKey(encodedKey) + updateEntryCount(encodedKey, entryCount) } /** Append an entry to the list. 
*/ override def appendValue(newState: S): Unit = { StateStoreErrors.requireNonNullStateValue(newState, stateName) val encodedKey = stateTypesEncoder.encodeGroupingKey() + val entryCount = getEntryCount(encodedKey) store.merge(encodedKey, stateTypesEncoder.encodeValue(newState, ttlExpirationMs), stateName) + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") upsertTTLForStateKey(encodedKey) + updateEntryCount(encodedKey, entryCount + 1) } /** Append an entire list to the existing value. */ @@ -126,16 +143,24 @@ class ListStateImplWithTTL[S]( validateNewState(newState) val encodedKey = stateTypesEncoder.encodeGroupingKey() + var entryCount = getEntryCount(encodedKey) newState.foreach { v => val encodedValue = stateTypesEncoder.encodeValue(v, ttlExpirationMs) store.merge(encodedKey, encodedValue, stateName) + entryCount += 1 + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") } upsertTTLForStateKey(encodedKey) + updateEntryCount(encodedKey, entryCount) } /** Remove this state. */ override def clear(): Unit = { - store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) + val encodedKey = stateTypesEncoder.encodeGroupingKey() + store.remove(encodedKey, stateName) + val entryCount = getEntryCount(encodedKey) + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", entryCount) + removeEntryCount(encodedKey) clearTTLState() } @@ -158,7 +183,9 @@ class ListStateImplWithTTL[S]( val unsafeRowValuesIterator = store.valuesIterator(groupingKey, stateName) // We clear the list, and use the iterator to put back all of the non-expired values store.remove(groupingKey, stateName) + removeEntryCount(groupingKey) var isFirst = true + var entryCount = 0L unsafeRowValuesIterator.foreach { encodedValue => if (!stateTypesEncoder.isExpired(encodedValue, batchTimestampMs)) { if (isFirst) { @@ -167,10 +194,13 @@ class ListStateImplWithTTL[S]( } else { store.merge(groupingKey, encodedValue, stateName) } + entryCount += 1 } else { numValuesExpired += 1 } } + updateEntryCount(groupingKey, entryCount) + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", numValuesExpired) numValuesExpired } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateMetricsImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateMetricsImpl.scala new file mode 100644 index 0000000000000..ea43c3f4fcd3b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateMetricsImpl.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.types._ + +/** + * Trait that provides helper methods to maintain metrics for a list state. + * For list state, we keep track of the count of entries in the list in a separate column family + * to get an accurate view of the number of entries that are updated/removed from the list and + * reported as part of the query progress metrics. + */ +trait ListStateMetricsImpl { + def stateStore: StateStore + + def baseStateName: String + + def exprEncSchema: StructType + + // We keep track of the count of entries in the list in a separate column family + // to avoid scanning the entire list to get the count. + private val counterCFValueSchema: StructType = + StructType(Seq(StructField("count", LongType, nullable = false))) + + private val counterCFProjection = UnsafeProjection.create(counterCFValueSchema) + + private val updatedCountRow = new GenericInternalRow(1) + + private def getRowCounterCFName(stateName: String) = "$rowCounter_" + stateName + + stateStore.createColFamilyIfAbsent(getRowCounterCFName(baseStateName), exprEncSchema, + counterCFValueSchema, NoPrefixKeyStateEncoderSpec(exprEncSchema), isInternal = true) + + /** + * Function to get the number of entries in the list state for a given grouping key + * @param encodedKey - encoded grouping key + * @return - number of entries in the list state + */ + def getEntryCount(encodedKey: UnsafeRow): Long = { + val countRow = stateStore.get(encodedKey, getRowCounterCFName(baseStateName)) + if (countRow != null) { + countRow.getLong(0) + } else { + 0L + } + } + + /** + * Function to update the number of entries in the list state for a given grouping key + * @param encodedKey - encoded grouping key + * @param updatedCount - updated count of entries in the list state + */ + def updateEntryCount( + encodedKey: UnsafeRow, + updatedCount: Long): Unit = { + updatedCountRow.setLong(0, updatedCount) + stateStore.put(encodedKey, + counterCFProjection(updatedCountRow.asInstanceOf[InternalRow]), + getRowCounterCFName(baseStateName)) + } + + /** + * Function to remove the number of entries in the list state for a given grouping key + * @param encodedKey - encoded grouping key + */ + def removeEntryCount(encodedKey: UnsafeRow): Unit = { + stateStore.remove(encodedKey, getRowCounterCFName(baseStateName)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala index 2fa6fa415a77b..cb3db19496dd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala @@ -19,17 +19,30 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair} import org.apache.spark.sql.streaming.MapState 
import org.apache.spark.sql.types.StructType +/** + * Class that provides a concrete implementation for map state associated with state + * variables used in the streaming transformWithState operator. + * @param store - reference to the StateStore instance to be used for storing state + * @param stateName - name of logical state partition + * @param keyExprEnc - Spark SQL encoder for key + * @param valEncoder - Spark SQL encoder for value + * @param metrics - metrics to be updated as part of stateful processing + * @tparam K - type of key for map state variable + * @tparam V - type of value for map state variable + */ class MapStateImpl[K, V]( store: StateStore, stateName: String, keyExprEnc: ExpressionEncoder[Any], userKeyEnc: Encoder[K], - valEncoder: Encoder[V]) extends MapState[K, V] with Logging { + valEncoder: Encoder[V], + metrics: Map[String, SQLMetric] = Map.empty) extends MapState[K, V] with Logging { // Pack grouping key and user key together as a prefixed composite key private val schemaForCompositeKeyRow: StructType = { @@ -70,6 +83,7 @@ class MapStateImpl[K, V]( val encodedValue = stateTypesEncoder.encodeValue(value) val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key) store.put(encodedCompositeKey, encodedValue, stateName) + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") } /** Get the map associated with grouping key */ @@ -98,6 +112,9 @@ class MapStateImpl[K, V]( StateStoreErrors.requireNonNullStateValue(key, stateName) val compositeKey = stateTypesEncoder.encodeCompositeKey(key) store.remove(compositeKey, stateName) + // Note that for mapState, the rows are flattened. So we count the number of rows removed + // proportional to the number of keys in the map per grouping key. + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows") } /** Remove this state. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala index a6234636a94f7..6a3685ad6c46c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala @@ -20,6 +20,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.{MapState, TTLConfig} @@ -35,6 +36,7 @@ import org.apache.spark.util.NextIterator * @param valEncoder - SQL encoder for state variable * @param ttlConfig - the ttl configuration (time to live duration etc.) * @param batchTimestampMs - current batch processing timestamp. 
+ * @param metrics - metrics to be updated as part of stateful processing * @tparam K - type of key for map state variable * @tparam V - type of value for map state variable * @return - instance of MapState of type [K,V] that can be used to store state persistently @@ -46,7 +48,8 @@ class MapStateImplWithTTL[K, V]( userKeyEnc: Encoder[K], valEncoder: Encoder[V], ttlConfig: TTLConfig, - batchTimestampMs: Long) + batchTimestampMs: Long, + metrics: Map[String, SQLMetric] = Map.empty) extends CompositeKeyTTLStateImpl[K](stateName, store, keyExprEnc, userKeyEnc, batchTimestampMs) with MapState[K, V] with Logging { @@ -106,6 +109,7 @@ class MapStateImplWithTTL[K, V]( val encodedValue = stateTypesEncoder.encodeValue(value, ttlExpirationMs) val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key) store.put(encodedCompositeKey, encodedValue, stateName) + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") upsertTTLForStateKey(ttlExpirationMs, encodedGroupingKey, encodedUserKey) } @@ -149,6 +153,9 @@ class MapStateImplWithTTL[K, V]( StateStoreErrors.requireNonNullStateValue(key, stateName) val compositeKey = stateTypesEncoder.encodeCompositeKey(key) store.remove(compositeKey, stateName) + // Note that for mapState, the rows are flattened. So we count the number of rows removed + // proportional to the number of keys in the map per grouping key. + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows") } /** Remove this state. */ @@ -184,6 +191,7 @@ class MapStateImplWithTTL[K, V]( if (stateTypesEncoder.isExpired(retRow, batchTimestampMs)) { store.remove(compositeKeyRow, stateName) numRemovedElements += 1 + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows") } } numRemovedElements diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 8beacbec7e6ef..762dfc7d08920 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -45,6 +45,24 @@ object ImplicitGroupingKeyTracker { def removeImplicitKey(): Unit = implicitKey.remove() } +/** + * Utility object to perform metrics updates + */ +object TWSMetricsUtils { + def resetMetric( + metrics: Map[String, SQLMetric], + metricName: String): Unit = { + metrics.get(metricName).foreach(_.reset()) + } + + def incrementMetric( + metrics: Map[String, SQLMetric], + metricName: String, + countValue: Long = 1L): Unit = { + metrics.get(metricName).foreach(_.add(countValue)) + } +} + /** * Enum used to track valid states for the StatefulProcessorHandle */ @@ -117,16 +135,12 @@ class StatefulProcessorHandleImpl( private lazy val currQueryInfo: QueryInfo = buildQueryInfo() - private def incrementMetric(metricName: String): Unit = { - metrics.get(metricName).foreach(_.add(1)) - } - override def getValueState[T]( stateName: String, valEncoder: Encoder[T]): ValueState[T] = { verifyStateVarOperations("get_value_state", CREATED) - incrementMetric("numValueStateVars") - val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) + val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder, metrics) + TWSMetricsUtils.incrementMetric(metrics, "numValueStateVars") resultState } @@ -139,9 +153,10 @@ class StatefulProcessorHandleImpl( assert(batchTimestampMs.isDefined) val 
valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName, - keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) - incrementMetric("numValueStateWithTTLVars") + keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get, metrics) ttlStates.add(valueStateWithTTL) + TWSMetricsUtils.incrementMetric(metrics, "numValueStateWithTTLVars") + valueStateWithTTL } @@ -155,8 +170,8 @@ class StatefulProcessorHandleImpl( */ override def registerTimer(expiryTimestampMs: Long): Unit = { verifyTimerOperations("register_timer") - incrementMetric("numRegisteredTimers") timerState.registerTimer(expiryTimestampMs) + TWSMetricsUtils.incrementMetric(metrics, "numRegisteredTimers") } /** @@ -165,8 +180,8 @@ class StatefulProcessorHandleImpl( */ override def deleteTimer(expiryTimestampMs: Long): Unit = { verifyTimerOperations("delete_timer") - incrementMetric("numDeletedTimers") timerState.deleteTimer(expiryTimestampMs) + TWSMetricsUtils.incrementMetric(metrics, "numDeletedTimers") } /** @@ -211,14 +226,14 @@ class StatefulProcessorHandleImpl( override def deleteIfExists(stateName: String): Unit = { verifyStateVarOperations("delete_if_exists", CREATED) if (store.removeColFamilyIfExists(stateName)) { - incrementMetric("numDeletedStateVars") + TWSMetricsUtils.incrementMetric(metrics, "numDeletedStateVars") } } override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = { verifyStateVarOperations("get_list_state", CREATED) - incrementMetric("numListStateVars") - val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) + val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder, metrics) + TWSMetricsUtils.incrementMetric(metrics, "numListStateVars") resultState } @@ -247,8 +262,8 @@ class StatefulProcessorHandleImpl( assert(batchTimestampMs.isDefined) val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName, - keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) - incrementMetric("numListStateWithTTLVars") + keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get, metrics) + TWSMetricsUtils.incrementMetric(metrics, "numListStateWithTTLVars") ttlStates.add(listStateWithTTL) listStateWithTTL @@ -259,8 +274,9 @@ class StatefulProcessorHandleImpl( userKeyEnc: Encoder[K], valEncoder: Encoder[V]): MapState[K, V] = { verifyStateVarOperations("get_map_state", CREATED) - incrementMetric("numMapStateVars") - val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder) + val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, + userKeyEnc, valEncoder, metrics) + TWSMetricsUtils.incrementMetric(metrics, "numMapStateVars") resultState } @@ -274,8 +290,8 @@ class StatefulProcessorHandleImpl( assert(batchTimestampMs.isDefined) val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc, - valEncoder, ttlConfig, batchTimestampMs.get) - incrementMetric("numMapStateWithTTLVars") + valEncoder, ttlConfig, batchTimestampMs.get, metrics) + TWSMetricsUtils.incrementMetric(metrics, "numMapStateWithTTLVars") ttlStates.add(mapStateWithTTL) mapStateWithTTL diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index 63cac4a3b68cb..b1b87feeb263b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ 
-19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} import org.apache.spark.sql.streaming.ValueState @@ -29,13 +30,15 @@ import org.apache.spark.sql.streaming.ValueState * @param stateName - name of logical state partition * @param keyExprEnc - Spark SQL encoder for key * @param valEncoder - Spark SQL encoder for value + * @param metrics - metrics to be updated as part of stateful processing * @tparam S - data type of object that will be stored */ class ValueStateImpl[S]( store: StateStore, stateName: String, keyExprEnc: ExpressionEncoder[Any], - valEncoder: Encoder[S]) + valEncoder: Encoder[S], + metrics: Map[String, SQLMetric] = Map.empty) extends ValueState[S] with Logging { private val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder, stateName) @@ -74,10 +77,12 @@ class ValueStateImpl[S]( val encodedValue = stateTypesEncoder.encodeValue(newState) store.put(stateTypesEncoder.encodeGroupingKey(), encodedValue, stateName) + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") } /** Function to remove state for given key */ override def clear(): Unit = { store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala index c6d11b155866b..145cd90264910 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} import org.apache.spark.sql.streaming.{TTLConfig, ValueState} @@ -33,6 +34,7 @@ import org.apache.spark.sql.streaming.{TTLConfig, ValueState} * @param valEncoder - Spark SQL encoder for value * @param ttlConfig - TTL configuration for values stored in this state * @param batchTimestampMs - current batch processing timestamp. 
+ * @param metrics - metrics to be updated as part of stateful processing * @tparam S - data type of object that will be stored */ class ValueStateImplWithTTL[S]( @@ -41,7 +43,8 @@ class ValueStateImplWithTTL[S]( keyExprEnc: ExpressionEncoder[Any], valEncoder: Encoder[S], ttlConfig: TTLConfig, - batchTimestampMs: Long) + batchTimestampMs: Long, + metrics: Map[String, SQLMetric] = Map.empty) extends SingleKeyTTLStateImpl( stateName, store, keyExprEnc, batchTimestampMs) with ValueState[S] { @@ -92,12 +95,14 @@ class ValueStateImplWithTTL[S]( val serializedGroupingKey = stateTypesEncoder.encodeGroupingKey() store.put(serializedGroupingKey, encodedValue, stateName) + TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows") upsertTTLForStateKey(ttlExpirationMs, serializedGroupingKey) } /** Function to remove state for given key */ override def clear(): Unit = { store.remove(stateTypesEncoder.encodeGroupingKey(), stateName) + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows") clearTTLState() } @@ -108,6 +113,7 @@ class ValueStateImplWithTTL[S]( if (retRow != null) { if (stateTypesEncoder.isExpired(retRow, batchTimestampMs)) { store.remove(groupingKey, stateName) + TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows") result = 1L } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala index dea16e5298975..71b8c8ac923d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala @@ -301,6 +301,8 @@ class TransformWithListStateSuite extends StreamTest CheckNewAnswer(("k5", "v5"), ("k5", "v6")), Execute { q => assert(q.lastProgress.stateOperators(0).customMetrics.get("numListStateVars") > 0) + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 2) + assert(q.lastProgress.stateOperators(0).numRowsRemoved === 2) } ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala index 299a3346b2e51..d11d8ef9a9b36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateTTLSuite.scala @@ -147,6 +147,9 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest { ), AdvanceManualClock(1 * 1000), CheckNewAnswer(), + Execute { q => + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 3) + }, // get ttl values AddData(inputStream, InputEvent("k1", "get_ttl_value_from_state", -1, null)), AdvanceManualClock(1 * 1000), @@ -158,15 +161,17 @@ class TransformWithListStateTTLSuite extends TransformWithStateTTLTest { OutputEvent("k1", 5, isTTLValue = true, 109000), OutputEvent("k1", 6, isTTLValue = true, 109000) ), + AddData(inputStream, InputEvent("k1", "get", -1, null)), // advance clock to expire the first three elements AdvanceManualClock(15 * 1000), // batch timestamp: 65000 - AddData(inputStream, InputEvent("k1", "get", -1, null)), - AdvanceManualClock(1 * 1000), CheckNewAnswer( OutputEvent("k1", 4, isTTLValue = false, -1), OutputEvent("k1", 5, isTTLValue = false, -1), OutputEvent("k1", 6, isTTLValue = false, -1) ), + Execute { q => + assert(q.lastProgress.stateOperators(0).numRowsRemoved === 3) + }, // ensure that 
expired elements are no longer in state AddData(inputStream, InputEvent("k1", "get_without_enforcing_ttl", -1, null)), AdvanceManualClock(1 * 1000), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala index fe88fbaa91cb7..e4e6862f7f937 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala @@ -209,9 +209,13 @@ class TransformWithMapStateSuite extends StreamTest AddData(inputData, InputMapRow("k2", "iterator", ("", ""))), CheckNewAnswer(), AddData(inputData, InputMapRow("k2", "exists", ("", ""))), + AddData(inputData, InputMapRow("k1", "clear", ("", ""))), + AddData(inputData, InputMapRow("k3", "updateValue", ("v7", "11"))), CheckNewAnswer(("k2", "exists", "false")), Execute { q => assert(q.lastProgress.stateOperators(0).customMetrics.get("numMapStateVars") > 0) + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) + assert(q.lastProgress.stateOperators(0).numRowsRemoved === 1) } ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala index bf46c802fdea4..3794bcc9ea271 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateTTLSuite.scala @@ -210,12 +210,17 @@ class TransformWithMapStateTTLSuite extends TransformWithStateTTLTest { AddData(inputStream, MapInputEvent("k1", "key2", "put", 2)), AdvanceManualClock(1 * 1000), CheckNewAnswer(), - // advance clock to expire first key - AdvanceManualClock(30 * 1000), + Execute { q => + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) + }, AddData(inputStream, MapInputEvent("k1", "key1", "get", -1), MapInputEvent("k1", "key2", "get", -1)), - AdvanceManualClock(1 * 1000), + // advance clock to expire first key + AdvanceManualClock(30 * 1000), CheckNewAnswer(MapOutputEvent("k1", "key2", 2, isTTLValue = false, -1)), + Execute { q => + assert(q.lastProgress.stateOperators(0).numRowsRemoved === 1) + }, StopStream ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 0c02fbf97820b..257578ee65447 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -528,6 +528,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest Execute { q => assert(q.lastProgress.stateOperators(0).customMetrics.get("numValueStateVars") > 0) assert(q.lastProgress.stateOperators(0).customMetrics.get("numRegisteredTimers") == 0) + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) }, AddData(inputData, "a", "b"), CheckNewAnswer(("a", "2"), ("b", "1")), @@ -536,6 +537,10 @@ class TransformWithStateSuite extends StateStoreMetricsTest AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a CheckNewAnswer(("b", "2")), StopStream, + Execute { q => + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) + assert(q.lastProgress.stateOperators(0).numRowsRemoved === 1) + }, StartStream(), 
AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and CheckNewAnswer(("a", "1"), ("c", "1")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala index 1fbeaeb817bd9..e2b31de1f66b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala @@ -247,7 +247,19 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { // validate ttl value is removed in the value state column family AddData(inputStream, InputEvent(ttlKey, "get_ttl_value_from_state", -1)), AdvanceManualClock(1 * 1000), - CheckNewAnswer() + CheckNewAnswer(), + AddData(inputStream, InputEvent(ttlKey, "put", 3)), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(), + Execute { q => + assert(q.lastProgress.stateOperators(0).numRowsUpdated === 1) + }, + AddData(inputStream, InputEvent(noTtlKey, "get", -1)), + AdvanceManualClock(60 * 1000), + CheckNewAnswer(OutputEvent(noTtlKey, 2, isTTLValue = false, -1)), + Execute { q => + assert(q.lastProgress.stateOperators(0).numRowsRemoved === 1) + } ) } }
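
As a usage illustration beyond the unit tests above: the sketch below is not part of this patch, and the processor, query shape, socket source, host, and port in it are all hypothetical. It shows how the counters added here surface through the existing `StateOperatorProgress` fields (`numRowsUpdated`, `numRowsRemoved`) of a running query, assuming the Spark 4.0 `transformWithState` API, which requires the RocksDB state store provider.

```
import org.apache.spark.sql.{Encoders, SparkSession}
import org.apache.spark.sql.streaming.{ListState, OutputMode, StatefulProcessor, TimeMode, TimerValues}

// Hypothetical processor: appends each event to a list state and clears the
// list on a "reset" marker, so both numUpdatedStateRows (put/merge paths) and
// numRemovedStateRows (clear path) are exercised.
class EventListProcessor extends StatefulProcessor[String, String, (String, Int)] {
  @transient private var events: ListState[String] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    events = getHandle.getListState[String]("events", Encoders.STRING)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues): Iterator[(String, Int)] = {
    var appended = 0
    inputRows.foreach {
      case "reset" => events.clear() // removal count comes from the tracked entry count
      case value   => events.appendValue(value); appended += 1
    }
    Iterator.single((key, appended))
  }
}

object MetricsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      // transformWithState requires the RocksDB state store provider.
      .config("spark.sql.streaming.stateStore.providerClass",
        "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
      .getOrCreate()
    import spark.implicits._

    val query = spark.readStream
      .format("socket").option("host", "localhost").option("port", 9999).load()
      .as[String]
      .groupByKey(identity)
      .transformWithState(new EventListProcessor, TimeMode.None(), OutputMode.Update())
      .writeStream
      .outputMode("update")
      .format("console")
      .start()

    // After at least one batch has completed, lastProgress is non-null and the
    // operator-level counters reflect the state rows touched in that batch.
    query.awaitTermination(30000)
    val op = query.lastProgress.stateOperators(0)
    println(s"numRowsUpdated=${op.numRowsUpdated}, numRowsRemoved=${op.numRowsRemoved}")
  }
}
```

Note the design choice visible in `ListStateMetricsImpl`: the per-key entry count is kept in an internal column family (named `$rowCounter_<stateName>`), so `clear()` can report exactly how many rows were removed without scanning the whole list.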