Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 32 additions & 37 deletions Sources/LanguageServerProtocolTransport/JSONRPCConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
//
//===----------------------------------------------------------------------===//

public import Dispatch
public import Foundation
public import LanguageServerProtocol
@_spi(SourceKitLSP) import SKLogging
Expand Down Expand Up @@ -49,15 +48,15 @@ public final class JSONRPCConnection: Connection {
nonisolated(unsafe) private var receiveHandler: MessageHandler? = nil

/// Queue for synchronizing all messages to ensure they remain in order
private let queue: DispatchQueue = DispatchQueue(label: "jsonrpc-queue", qos: .userInitiated)
private let queue: DispatchQueue

/// Queue for reading off of `receiveFD`
private let readQueue: DispatchQueue = DispatchQueue(label: "jsonrpc-read-queue", qos: .userInitiated)
/// Queue for the read loop (effectively just a separate thread - we never yield from the initial task)
private let receiveQueue: DispatchQueue

/// Queue for sending any data through `sendFD`. This is currently needed as the read loop is blocked on messages
/// being parsed on `queue` (in order to not add an extra copy), so we must perform any corresponding sends off of
/// `queue`. If we ever change that, we can likely remove this queue.
private let sendQueue: DispatchQueue = DispatchQueue(label: "jsonrpc-send-queue", qos: .userInitiated)
private let sendQueue: DispatchQueue

/// File descriptor for reading input (eg. stdin for an LSP server)
private let receiveFD: FileHandle
Expand Down Expand Up @@ -137,6 +136,9 @@ public final class JSONRPCConnection: Connection {
sendMirrorFile: FileHandle? = nil
) {
self.name = name
self.queue = DispatchQueue(label: "\(name)-jsonrpc-queue", qos: .userInitiated)
self.receiveQueue = DispatchQueue(label: "\(name)-jsonrpc-read-queue", qos: .userInitiated)
self.sendQueue = DispatchQueue(label: "\(name)-jsonrpc-send-queue", qos: .userInitiated)
self.receiveFD = receiveFD
self.receiveMirrorFile = receiveMirrorFile
self.sendFD = sendFD
Expand Down Expand Up @@ -248,7 +250,7 @@ public final class JSONRPCConnection: Connection {
self.closeHandler = closeHandler
}

self.readQueue.async {
self.receiveQueue.async {
let parser = JSONMessageParser(decoder: self.decodeJSONRPCMessage)
while true {
let data = orLog("Reading from \(self.name)") { try self.receiveFD.read(upToCount: parser.nextReadLength) }
Expand Down Expand Up @@ -286,7 +288,7 @@ public final class JSONRPCConnection: Connection {
\(message). Please run 'sourcekit-lsp diagnose' to file an issue.
"""
)
self.send(.notification(showMessage))
self.sendAssumingOnQueue(.notification(showMessage))
}

/// Decode a single JSONRPC message from the given `messageBytes`.
Expand Down Expand Up @@ -339,7 +341,7 @@ public final class JSONRPCConnection: Connection {
logger.fault(
"Replying to request \(id, privacy: .public) with error response because we failed to decode the request"
)
self.send(.errorResponse(ResponseError(error), id: id))
self.sendAssumingOnQueue(.errorResponse(ResponseError(error), id: id))
return nil
}
// If we don't know the ID of the request, ignore it and show a notification to the user.
Expand Down Expand Up @@ -458,36 +460,33 @@ public final class JSONRPCConnection: Connection {
/// If an unrecoverable error occurred on the channel's file descriptor, the connection gets closed.
///
/// - Important: Must be called on `queue`
private func send(data dispatchData: DispatchData) {
private func sendAssumingOnQueue(data: Data) {
dispatchPrecondition(condition: .onQueue(queue))

guard readyToSend() else { return }

#if !os(macOS)
nonisolated(unsafe) let dispatchData = dispatchData
#endif
sendQueue.async {
orLog("Writing send mirror file") {
try self.sendMirrorFile?.write(contentsOf: dispatchData)
try self.sendMirrorFile?.write(contentsOf: data)
}

do {
try self.sendFD.write(contentsOf: dispatchData)
try self.sendFD.write(contentsOf: data)
} catch {
logger.fault("IO error sending message to \(self.name): \(error.forLogging)")
self.close()
}
}
}

/// Wrapper of `send(data:)` that automatically switches to `queue`.
/// Wrapper of `sendAssumingOnQueue(data:)` that automatically switches to `queue`.
///
/// This should only be used to test that the client decodes messages correctly if data is delivered to it
/// byte-by-byte instead of in larger chunks that contain entire messages.
@_spi(Testing)
public func send(_rawData dispatchData: DispatchData) {
public func send(data: Data) {
queue.sync {
self.send(data: dispatchData)
self.sendAssumingOnQueue(data: data)
}
}

Expand All @@ -496,14 +495,12 @@ public final class JSONRPCConnection: Connection {
/// If an unrecoverable error occurred on the channel's file descriptor, the connection gets closed.
///
/// - Important: Must be called on `queue`
private func send(_ message: JSONRPCMessage) {
private func sendAssumingOnQueue(_ message: JSONRPCMessage) {
dispatchPrecondition(condition: .onQueue(queue))

let encoder = JSONEncoder()

let data: Data
let content: Data
do {
data = try encoder.encode(message)
content = try JSONEncoder().encode(message)
} catch {
logger.fault("Failed to encode message: \(error.forLogging)")
logger.fault("Malformed message: \(String(describing: message))")
Expand Down Expand Up @@ -541,16 +538,9 @@ public final class JSONRPCConnection: Connection {
}
}

var dispatchData = DispatchData.empty
let header = "Content-Length: \(data.count)\r\n\r\n"
header.utf8.map { $0 }.withUnsafeBytes { buffer in
dispatchData.append(buffer)
}
data.withUnsafeBytes { rawBufferPointer in
dispatchData.append(rawBufferPointer)
}

send(data: dispatchData)
let header = "Content-Length: \(content.count)\r\n\r\n"
sendAssumingOnQueue(data: Data(header.utf8))
sendAssumingOnQueue(data: content)
}

/// Close the connection.
Expand Down Expand Up @@ -585,6 +575,12 @@ public final class JSONRPCConnection: Connection {
orLog("Closing receiveFD to \(name)") {
try receiveFD.close()
}
orLog("Closing sendMirrorFile to \(name)") {
try sendMirrorFile?.close()
}
orLog("Closing receiveMirrorFile to \(name)") {
try receiveMirrorFile?.close()
}
}

self.receiveHandler = nil
Expand Down Expand Up @@ -614,7 +610,7 @@ public final class JSONRPCConnection: Connection {
\(notification.forLogging)
"""
)
self.send(.notification(notification))
self.sendAssumingOnQueue(.notification(notification))
}
}

Expand Down Expand Up @@ -664,8 +660,7 @@ public final class JSONRPCConnection: Connection {
"""
)

self.send(.request(request, id: id))
return
self.sendAssumingOnQueue(.request(request, id: id))
}
}

Expand All @@ -674,9 +669,9 @@ public final class JSONRPCConnection: Connection {
queue.async {
switch response {
case .success(let result):
self.send(.response(result, id: id))
self.sendAssumingOnQueue(.response(result, id: id))
case .failure(let error):
self.send(.errorResponse(error, id: id))
self.sendAssumingOnQueue(.errorResponse(error, id: id))
}
}
}
Expand Down
15 changes: 12 additions & 3 deletions Sources/LanguageServerProtocolTransport/MessageSplitting.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,19 @@ package class JSONMessageParser<MessageType> {
package var nextReadLength: Int {
switch state {
case .header:
// Only ever read a single byte for the header to better handle invalid cases
return 1
// The header and content is split by `\r\n\r\n`. If we had the full separator, then we would be in `.content`
// state.
if requestBuffer.last == UInt8(ascii: "\n") {
// Can always read at least 2 bytes (we're either at `\r\n` or a lone `\n`)
return 2
} else if requestBuffer.last == UInt8(ascii: "\r") {
// Could be at `\r\n\r`, so can only read a single byte
return 1
}
// Don't have any part of the header separator, so can read at least its length
return 4
case .content(let remaining):
// Up until the message, where we should read its entire length (or anything remaining if we had a partial read)
// Read up until the end of the message (or anything remaining if we had a partial read)
return remaining
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ class ConnectionTests: XCTestCase {
"Content-Length: \(notification2.count)\r\n\r\n\(String(data: notification2, encoding: .utf8)!)"

for b in notification1Str.utf8.dropLast() {
clientConnection.send(_rawData: [b].withUnsafeBytes { DispatchData(bytes: $0) })
clientConnection.send(data: Data([b]))
}

clientConnection.send(
_rawData: [notification1Str.utf8.last!, notfication2Str.utf8.first!].withUnsafeBytes { DispatchData(bytes: $0) }
data: Data([notification1Str.utf8.last!, notfication2Str.utf8.first!])
)

try await fulfillmentOfOrThrow(expectation)
Expand All @@ -91,7 +91,7 @@ class ConnectionTests: XCTestCase {
}

for b in notfication2Str.utf8.dropFirst() {
clientConnection.send(_rawData: [b].withUnsafeBytes { DispatchData(bytes: $0) })
clientConnection.send(data: Data([b]))
}

try await fulfillmentOfOrThrow(expectation2)
Expand Down Expand Up @@ -305,8 +305,6 @@ class ConnectionTests: XCTestCase {
fileprivate extension JSONRPCConnection {
func send(message: String) {
let messageWithHeader = "Content-Length: \(message.utf8.count)\r\n\r\n\(message)".data(using: .utf8)!
messageWithHeader.withUnsafeBytes { bytes in
send(_rawData: DispatchData(bytes: bytes))
}
send(data: messageWithHeader)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,27 @@ import XCTest
final class MessageParsingTests: XCTestCase {
func testBasicMessage() {
let parser = parserForTesting()
XCTAssertEqual(parser.nextReadLength, 1)
XCTAssertEqual(parser.nextReadLength, 4)
XCTAssertNil(parser.parse(chunk: "Content-Length: 2\r\n\r\n".data))
XCTAssertEqual(parser.nextReadLength, 2)
XCTAssertEqual(parser.parse(chunk: "{}".data), "{}")
XCTAssertEqual(parser.nextReadLength, 1)
XCTAssertEqual(parser.nextReadLength, 4)
}

func testSplitMessage() {
let parser = parserForTesting()
XCTAssertEqual(parser.nextReadLength, 1)
XCTAssertEqual(parser.nextReadLength, 4)
XCTAssertNil(parser.parse(chunk: "Content".data))
XCTAssertEqual(parser.nextReadLength, 4)
XCTAssertNil(parser.parse(chunk: "-Length: 2\r".data))
XCTAssertEqual(parser.nextReadLength, 1)
XCTAssertNil(parser.parse(chunk: "-Length: 2\r\n".data))
XCTAssertEqual(parser.nextReadLength, 1)
XCTAssertNil(parser.parse(chunk: "\n".data))
XCTAssertEqual(parser.nextReadLength, 2)
XCTAssertNil(parser.parse(chunk: "\r\n".data))
XCTAssertEqual(parser.nextReadLength, 2)
XCTAssertNil(parser.parse(chunk: "{".data))
XCTAssertEqual(parser.parse(chunk: "}".data), "{}")
XCTAssertEqual(parser.nextReadLength, 4)
}

func testMultipleMessage() {
Expand Down