diff --git a/Sources/InProcessClient/InProcessSourceKitLSPClient.swift b/Sources/InProcessClient/InProcessSourceKitLSPClient.swift index dbcd25d6b..eaa4e4878 100644 --- a/Sources/InProcessClient/InProcessSourceKitLSPClient.swift +++ b/Sources/InProcessClient/InProcessSourceKitLSPClient.swift @@ -92,16 +92,41 @@ public final class InProcessSourceKitLSPClient: Sendable { /// necessary and the response of the request is not awaited, use the version of the function that takes a /// completion handler public func send(_ request: R) async throws -> R.Response { - return try await withCheckedThrowingContinuation { continuation in - self.send(request) { - continuation.resume(with: $0) + let requestId = ThreadSafeBox(initialValue: nil) + return try await withTaskCancellationHandler { + return try await withCheckedThrowingContinuation { continuation in + if Task.isCancelled { + // Check if the task has been cancelled before we send the request to LSP to avoid any kind of work if + // possible. + return continuation.resume(throwing: CancellationError()) + } + requestId.value = self.send(request) { + continuation.resume(with: $0) + } + if Task.isCancelled, let requestId = requestId.takeValue() { + // The task might have been cancelled after the above cancellation check but before `requestId` was assigned + // a value. To cover that case, check for cancellation here again. Note that we won't cancel twice from here + // and the `onCancel` handler because we take the request ID out of the `ThreadSafeBox` before sending the + // `CancelRequestNotification`. + self.send(CancelRequestNotification(id: requestId)) + } + } + } onCancel: { + if let requestId = requestId.takeValue() { + self.send(CancelRequestNotification(id: requestId)) } } } /// Send the request to `server` and return the request result via a completion handler. - public func send(_ request: R, reply: @Sendable @escaping (LSPResult) -> Void) { - server.handle(request, id: .number(Int(nextRequestID.fetchAndIncrement())), reply: reply) + @discardableResult + public func send( + _ request: R, + reply: @Sendable @escaping (LSPResult) -> Void + ) -> RequestID { + let requestID = RequestID.number(Int(nextRequestID.fetchAndIncrement())) + server.handle(request, id: requestID, reply: reply) + return requestID } /// Send the notification to `server`.