Skip to content
This repository has been archived by the owner on Aug 10, 2024. It is now read-only.

Ownership system for callbacks to address memory leak #611

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions buildSrc/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ dependencies {
// files in the project.
// Use their Maven coordinates (plus versions), not Gradle plugin IDs!
// This should be the only place that Gradle plugin versions are defined, so they are aligned across all build scripts
implementation("org.jetbrains.kotlin:kotlin-gradle-plugin:1.9.22")
implementation("org.jetbrains.kotlin:kotlin-serialization:1.9.22")
implementation("org.jetbrains.kotlin:kotlin-gradle-plugin:1.9.23")
implementation("org.jetbrains.kotlin:kotlin-serialization:1.9.23")
}

val gradleJvmTarget = 17
Expand Down
153 changes: 75 additions & 78 deletions src/main/kotlin/kweb/Kweb.kt
Original file line number Diff line number Diff line change
Expand Up @@ -251,96 +251,96 @@ class Kweb private constructor(
}
}

private suspend fun RemoteClientState?.ensureSessionExists(
sock: DefaultWebSocketSession,
sessionId: String
): RemoteClientState {
if (this == null) {
sock.close(CloseReason(CloseReason.Codes.NOT_CONSISTENT, "Session not found. Please reload"))
error("Unable to find server state corresponding to client id ${sessionId}")
}
return this
}

private suspend fun DefaultWebSocketSession.listenForWebsocketConnection() {
val hello = Json.decodeFromString<Client2ServerMessage>((incoming.receive() as Text).readText())

if (hello.hello == null) {
error("First message from client isn't 'hello'")
}

val remoteClientState = clientState.getIfPresent(hello.id).ensureSessionExists(this, hello.id)
val remoteClientState = clientState.getIfPresent(hello.id)

val currentCC = remoteClientState.clientConnection
remoteClientState.clientConnection = when (currentCC) {
is Caching -> {
val webSocketClientConnection = ClientConnection.WebSocket(this)
currentCC.read().forEach {
webSocketClientConnection.send(it)
}
remoteClientState.addCloseHandler {
webSocketClientConnection.close(CloseReason(4002, "RemoteClientState closed by server, likely due to cache expiry"))
if (remoteClientState == null) {
logger.warn("Client id ${hello.id} not found, closing connection")
this.close(CloseReason(4000, "Client id ${hello.id} not found"))
} else {

val newClientConnection = when (val currentCC = remoteClientState.getClientConnection()) {
is Caching -> {
val webSocketClientConnection = ClientConnection.WebSocket(this)
currentCC.read().forEach {
webSocketClientConnection.send(it)
}
remoteClientState.addCloseHandler {
webSocketClientConnection.close(
CloseReason(
4002,
"RemoteClientState closed by server, likely due to cache expiry"
)
)
}
webSocketClientConnection
}
webSocketClientConnection
}

is ClientConnection.WebSocket -> {
currentCC.close(CloseReason(4001, "Client reconnected via another connection"))
val ws = ClientConnection.WebSocket(this)
remoteClientState.addCloseHandler {
ws.close(CloseReason(4002, "RemoteClientState closed by server, likely due to cache expiry"))
is ClientConnection.WebSocket -> {
currentCC.close(CloseReason(4001, "Client reconnected via another connection"))
val ws = ClientConnection.WebSocket(this)
remoteClientState.addCloseHandler {
ws.close(CloseReason(4002, "RemoteClientState closed by server, likely due to cache expiry"))
}
ws
}
ws
}
}
remoteClientState.updateClientConnection(newClientConnection)

try {
for (frame in incoming) {
try {
for (frame in incoming) {

logger.debug { "WebSocket frame of type ${frame.frameType} received" }
logger.debug { "WebSocket frame of type ${frame.frameType} received" }

// Retrieve the clientState so that it doesn't expire, replace it if it
// has expired.
clientState.get(hello.id) { remoteClientState }
// Retrieve the clientState so that it doesn't expire, replace it if it
// has expired.
clientState.get(hello.id) { remoteClientState }

try {
logger.debug { "Message received from client" }

if (frame is Text) {
val message = Json.decodeFromString<Client2ServerMessage>(frame.readText())

logger.debug { "Message received: $message" }
if (message.error != null) {
handleError(message.error, remoteClientState)
} else {
when {
message.callback != null -> {
val (resultId, result) = message.callback
val resultHandler = remoteClientState.eventHandlers[resultId]
?: error("No resultHandler for $resultId, for client ${remoteClientState.id}")
resultHandler(result)
}

message.keepalive -> {
logger.debug { "keepalive received from client ${hello.id}" }
}
try {
logger.debug { "Message received from client" }

if (frame is Text) {
val message = Json.decodeFromString<Client2ServerMessage>(frame.readText())

logger.debug { "Message received: $message" }
if (message.error != null) {
handleError(message.error, remoteClientState)
} else {
when {
message.callback != null -> {
val (resultId, result) = message.callback
val resultHandler = remoteClientState.eventHandlers[resultId]
?: error("No resultHandler for $resultId, for client ${remoteClientState.id}")
resultHandler(result)
}

message.keepalive -> {
logger.debug { "keepalive received from client ${hello.id}" }
}

message.onMessageData != null -> {
val data = message.onMessageData
remoteClientState.onMessageFunction!!.invoke(data)
}

message.onMessageData != null -> {
val data = message.onMessageData
remoteClientState.onMessageFunction!!.invoke(data)
}

}
}
} catch (e: Exception) {
logger.error("Exception while receiving websocket message", e)
kwebConfig.onWebsocketMessageHandlingFailure(e)
}
} catch (e: Exception) {
logger.error("Exception while receiving websocket message", e)
kwebConfig.onWebsocketMessageHandlingFailure(e)
}
} finally {
logger.info("WS session disconnected for client id: ${remoteClientState.id}")
remoteClientState.updateClientConnection(Caching("After WS disconnection"))
}
} finally {
logger.info("WS session disconnected for client id: ${remoteClientState.id}")
remoteClientState.clientConnection = Caching()
}
}

Expand All @@ -365,12 +365,8 @@ class Kweb private constructor(

val httpRequestInfo = HttpRequestInfo(call.request)


//this doesn't work. I get this error when running the todo Demo
//Caused by: java.lang.IllegalStateException: Client id Nb9_U7:eJ5dw4 not found
//The debugger says that the remoteClientState ID here matches the clientPrefix and the kwebSessionID from a few lines ago.
val remoteClientState = clientState.get(kwebSessionId) {
RemoteClientState(id = kwebSessionId, clientConnection = Caching())
RemoteClientState(id = kwebSessionId, initialClientConnection = Caching("Initial render"))
}


Expand All @@ -396,15 +392,15 @@ class Kweb private constructor(
} catch (e: Exception) {
logger.error("Exception thrown building page", e)
}
logger.debug { "Outbound message queue size after buildPage is ${(remoteClientState.clientConnection as Caching).queueSize()}" }
logger.debug { "Outbound message queue size after buildPage is ${(remoteClientState.getClientConnection() as Caching).queueSize()}" }
}
} else {
try {
buildPage(webBrowser)
} catch (e: Exception) {
logger.error("Exception thrown building page", e)
}
logger.debug { "Outbound message queue size after buildPage is ${(remoteClientState.clientConnection as Caching).queueSize()}" }
logger.debug { "Outbound message queue size after buildPage is ${(remoteClientState.getClientConnection() as Caching).queueSize()}" }
}
for (plugin in plugins) {
//this code block looks a little funny now, but I still think moving the message creation out of Kweb.callJs() was the right move
Expand All @@ -418,11 +414,12 @@ class Kweb private constructor(

webBrowser.htmlDocument.set(null) // Don't think this webBrowser will be used again, but not going to risk it

val initialCachedMessages = remoteClientState.clientConnection as Caching
val initialCachedMessages = remoteClientState.getClientConnection() as Caching

remoteClientState.clientConnection = Caching()
// TODO: Verify that this is correct
remoteClientState.updateClientConnection(Caching("Awaiting WS connection cache"))

val initialMessages = initialCachedMessages.read()//the initialCachedMessages queue can only be read once
val initialMessages = initialCachedMessages.read()

val cachedFunctions = mutableListOf<String>()
val cachedIds = mutableListOf<Int>()
Expand Down Expand Up @@ -504,7 +501,7 @@ class Kweb private constructor(
for (client in clientState.asMap().values) {
val refreshCall = FunctionCall(js = "window.location.reload(true);")
val message = Server2ClientMessage(client.id, refreshCall)
client.clientConnection.send(Json.encodeToString(message))
client.getClientConnection().send(Json.encodeToString(message))
}
}
}
Expand Down
13 changes: 10 additions & 3 deletions src/main/kotlin/kweb/client/ClientConnection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,13 @@ sealed class ClientConnection {
sendBuffer.close()
}
}

override fun toString(): String {
return "WebSocket()"
}
}

class Caching : ClientConnection() {
class Caching(val description : String) : ClientConnection() {
private val queue = ConcurrentLinkedQueue<String>()
private val lock = ReentrantLock()
private val isRead = AtomicBoolean(false)
Expand All @@ -58,7 +62,7 @@ sealed class ClientConnection {
if (isRead.get()) {
error("Can't write to queue after it has been read")
} else {
logger.debug("Caching '$message' as websocket isn't yet available")
logger.debug("Caching \"${message.take(20)}...\" in $description")
queue.add(message)
}
}
Expand All @@ -82,7 +86,10 @@ sealed class ClientConnection {
return queue.size
}
}
}

override fun toString(): String {
return "Caching($description)"
}
}

}
15 changes: 14 additions & 1 deletion src/main/kotlin/kweb/client/RemoteClientState.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,27 @@ import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonElement
import kweb.DebugInfo
import kweb.util.random
import mu.two.KotlinLogging
import java.time.Instant
import java.util.concurrent.ConcurrentHashMap

data class RemoteClientState(val id: String, @Volatile var clientConnection: ClientConnection,
private var logger = KotlinLogging.logger {}

data class RemoteClientState(val id: String, val initialClientConnection: ClientConnection,
val eventHandlers: MutableMap<Int, (JsonElement) -> Unit> = HashMap(),
val onCloseHandlers : ConcurrentHashMap<Int, () -> Unit> = ConcurrentHashMap(),
val debugTokens: MutableMap<String, DebugInfo> = HashMap(), var lastModified: Instant = Instant.now(),
var onMessageFunction: ((data: JsonElement?) -> Unit)? = null) {

private @Volatile var clientConnection = initialClientConnection

fun getClientConnection() = clientConnection

fun updateClientConnection(newClientConnection: ClientConnection) {
logger.debug { "Updating client connection from $clientConnection to $newClientConnection" }
clientConnection = newClientConnection
}

fun send(message: Server2ClientMessage) {
clientConnection.send(Json.encodeToString(message))
}
Expand Down
10 changes: 6 additions & 4 deletions src/main/kotlin/kweb/state/KVal.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package kweb.state

import kweb.util.CallbackOwner
import kweb.util.random
import mu.two.KotlinLogging
import java.util.concurrent.ConcurrentHashMap
Expand All @@ -11,7 +12,7 @@ private val logger = KotlinLogging.logger {}
* A KVal is a **read-only** observable container for a value of type T. These are typically created by
* [KVal.map] or [KVar.map], but can also be created directly.
*/
open class KVal<T : Any?>(value: T) : AutoCloseable{
open class KVal<T : Any?>(private val kvalOwner : CallbackOwner, value: T) : AutoCloseable{

@Volatile
protected var closeReason: CloseReason? = null
Expand All @@ -24,10 +25,11 @@ open class KVal<T : Any?>(value: T) : AutoCloseable{
/**
* Add a listener to this KVar. The listener will be called whenever the [value] property changes.
*/
fun addListener(listener: (T, T) -> Unit): Long {
fun addListener(owner : CallbackOwner = kvalOwner, listener: (T, T) -> Unit) : Long {
verifyNotClosed("add a listener")
val handle = random.nextLong()
listeners[handle] = listener
owner.onClose(listeners, handle)
return handle
}

Expand Down Expand Up @@ -59,11 +61,11 @@ open class KVal<T : Any?>(value: T) : AutoCloseable{
*
* For bi-directional mappings, see [KVar.map].
*/
fun <O : Any?> map(mapper: (T) -> O): KVal<O> {
fun <O : Any?> map(owner: CallbackOwner = kvalOwner, mapper: (T) -> O): KVal<O> {
if (isClosed) {
error("Can't map this var because it was closed due to $closeReason")
}
val mappedKVal = KVal(mapper(value))
val mappedKVal = KVal(owner.child("map"), mapper(value))
val handle = addListener { old, new ->
if (!isClosed && !mappedKVal.isClosed) {
if (old != new) {
Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/kweb/state/render.kt
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ fun <T : Any?> ElementCreator<*>.render(
renderLoop()

this.onCleanup(true) {
value.removeListener(listenerHandle)
previousElementCreatorLock.withLock {
previousElementCreator.getAndSet(null)?.cleanup()
}
value.removeListener(listenerHandle)
}

return renderFragment
Expand Down
Loading
Loading