Skip to content

Commit

Permalink
Move message encoding into PSQLChannelHandler (#181)
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianfett authored Sep 23, 2021
1 parent 28ab2df commit 131deb3
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 54 deletions.
109 changes: 65 additions & 44 deletions Sources/PostgresNIO/New/PSQLChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ protocol PSQLChannelHandlerNotificationDelegate: AnyObject {
final class PSQLChannelHandler: ChannelDuplexHandler {
typealias OutboundIn = PSQLTask
typealias InboundIn = ByteBuffer
typealias OutboundOut = PSQLFrontendMessage
typealias OutboundOut = ByteBuffer

private let logger: Logger
private var state: ConnectionStateMachine {
Expand All @@ -25,32 +25,33 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
private var handlerContext: ChannelHandlerContext!
private var rowStream: PSQLRowStream?
private var decoder: NIOSingleStepByteToMessageProcessor<PSQLBackendMessageDecoder>
private let authentificationConfiguration: PSQLConnection.Configuration.Authentication?
private var encoder: BufferedMessageEncoder<PSQLFrontendMessageEncoder>!
private let configuration: PSQLConnection.Configuration
private let configureSSLCallback: ((Channel) throws -> Void)?

/// this delegate should only be accessed on the connections `EventLoop`
weak var notificationDelegate: PSQLChannelHandlerNotificationDelegate?

init(authentification: PSQLConnection.Configuration.Authentication?,
init(configuration: PSQLConnection.Configuration,
logger: Logger,
configureSSLCallback: ((Channel) throws -> Void)?)
{
self.state = ConnectionStateMachine()
self.authentificationConfiguration = authentification
self.configuration = configuration
self.configureSSLCallback = configureSSLCallback
self.logger = logger
self.decoder = NIOSingleStepByteToMessageProcessor(PSQLBackendMessageDecoder())
}

#if DEBUG
/// for testing purposes only
init(authentification: PSQLConnection.Configuration.Authentication?,
init(configuration: PSQLConnection.Configuration,
state: ConnectionStateMachine = .init(.initialized),
logger: Logger = .psqlNoOpLogger,
configureSSLCallback: ((Channel) throws -> Void)?)
{
self.state = state
self.authentificationConfiguration = authentification
self.configuration = configuration
self.configureSSLCallback = configureSSLCallback
self.logger = logger
self.decoder = NIOSingleStepByteToMessageProcessor(PSQLBackendMessageDecoder())
Expand All @@ -61,6 +62,11 @@ final class PSQLChannelHandler: ChannelDuplexHandler {

func handlerAdded(context: ChannelHandlerContext) {
self.handlerContext = context
self.encoder = BufferedMessageEncoder(
buffer: context.channel.allocator.buffer(capacity: 256),
encoder: PSQLFrontendMessageEncoder(jsonEncoder: self.configuration.coders.jsonEncoder)
)

if context.channel.isActive {
self.connected(context: context)
}
Expand Down Expand Up @@ -222,15 +228,19 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
case .wait:
break
case .sendStartupMessage(let authContext):
context.writeAndFlush(.startup(.versionThree(parameters: authContext.toStartupParameters())), promise: nil)
try! self.encoder.encode(.startup(.versionThree(parameters: authContext.toStartupParameters())))
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
case .sendSSLRequest:
context.writeAndFlush(.sslRequest(.init()), promise: nil)
try! self.encoder.encode(.sslRequest(.init()))
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
case .sendPasswordMessage(let mode, let authContext):
self.sendPasswordMessage(mode: mode, authContext: authContext, context: context)
case .sendSaslInitialResponse(let name, let initialResponse):
context.writeAndFlush(.saslInitialResponse(.init(saslMechanism: name, initialData: initialResponse)))
try! self.encoder.encode(.saslInitialResponse(.init(saslMechanism: name, initialData: initialResponse)))
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
case .sendSaslResponse(let bytes):
context.writeAndFlush(.saslResponse(.init(data: bytes)))
try! self.encoder.encode(.saslResponse(.init(data: bytes)))
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
case .closeConnectionAndCleanup(let cleanupContext):
self.closeConnectionAndCleanup(cleanupContext, context: context)
case .fireChannelInactive:
Expand Down Expand Up @@ -277,7 +287,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
case .provideAuthenticationContext:
context.fireUserInboundEventTriggered(PSQLEvent.readyForStartup)

if let authentication = self.authentificationConfiguration {
if let authentication = self.configuration.authentication {
let authContext = AuthContext(
username: authentication.username,
password: authentication.password,
Expand All @@ -293,7 +303,8 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
// The normal, graceful termination procedure is that the frontend sends a Terminate
// message and immediately closes the connection. On receipt of this message, the
// backend closes the connection and terminates.
context.write(.terminate, promise: nil)
try! self.encoder.encode(.terminate)
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
}
context.close(mode: .all, promise: promise)
case .succeedPreparedStatementCreation(let preparedContext, with: let rowDescription):
Expand Down Expand Up @@ -357,22 +368,26 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
hash2.append(salt.3)
let hash = "md5" + Insecure.MD5.hash(data: hash2).hexdigest()

context.writeAndFlush(.password(.init(value: hash)), promise: nil)
try! self.encoder.encode(.password(.init(value: hash)))
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)

case .cleartext:
context.writeAndFlush(.password(.init(value: authContext.password ?? "")), promise: nil)
try! self.encoder.encode(.password(.init(value: authContext.password ?? "")))
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
}
}

private func sendCloseAndSyncMessage(_ sendClose: CloseTarget, context: ChannelHandlerContext) {
switch sendClose {
case .preparedStatement(let name):
context.write(.close(.preparedStatement(name)), promise: nil)
context.write(.sync, promise: nil)
context.flush()
try! self.encoder.encode(.close(.preparedStatement(name)))
try! self.encoder.encode(.sync)
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)

case .portal(let name):
context.write(.close(.portal(name)), promise: nil)
context.write(.sync, promise: nil)
context.flush()
try! self.encoder.encode(.close(.portal(name)))
try! self.encoder.encode(.sync)
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
}
}

Expand All @@ -387,10 +402,16 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
query: query,
parameters: [])

context.write(.parse(parse), promise: nil)
context.write(.describe(.preparedStatement(statementName)), promise: nil)
context.write(.sync, promise: nil)
context.flush()

do {
try self.encoder.encode(.parse(parse))
try self.encoder.encode(.describe(.preparedStatement(statementName)))
try self.encoder.encode(.sync)
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
} catch {
let action = self.state.errorHappened(.channel(underlying: error))
self.run(action, with: context)
}
}

private func sendBindExecuteAndSyncMessage(
Expand All @@ -403,10 +424,15 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
preparedStatementName: statementName,
parameters: binds)

context.write(.bind(bind), promise: nil)
context.write(.execute(.init(portalName: "")), promise: nil)
context.write(.sync, promise: nil)
context.flush()
do {
try self.encoder.encode(.bind(bind))
try self.encoder.encode(.execute(.init(portalName: "")))
try self.encoder.encode(.sync)
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
} catch {
let action = self.state.errorHappened(.channel(underlying: error))
self.run(action, with: context)
}
}

private func sendParseDescribeBindExecuteAndSyncMessage(
Expand All @@ -424,12 +450,17 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
preparedStatementName: unnamedStatementName,
parameters: binds)

context.write(wrapOutboundOut(.parse(parse)), promise: nil)
context.write(wrapOutboundOut(.describe(.preparedStatement(""))), promise: nil)
context.write(wrapOutboundOut(.bind(bind)), promise: nil)
context.write(wrapOutboundOut(.execute(.init(portalName: ""))), promise: nil)
context.write(wrapOutboundOut(.sync), promise: nil)
context.flush()
do {
try self.encoder.encode(.parse(parse))
try self.encoder.encode(.describe(.preparedStatement("")))
try self.encoder.encode(.bind(bind))
try self.encoder.encode(.execute(.init(portalName: "")))
try self.encoder.encode(.sync)
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
} catch {
let action = self.state.errorHappened(.channel(underlying: error))
self.run(action, with: context)
}
}

private func succeedQueryWithRowStream(
Expand Down Expand Up @@ -503,16 +534,6 @@ extension PSQLChannelHandler: PSQLRowsDataSource {
}
}

extension ChannelHandlerContext {
func write(_ psqlMessage: PSQLFrontendMessage, promise: EventLoopPromise<Void>? = nil) {
self.write(NIOAny(psqlMessage), promise: promise)
}

func writeAndFlush(_ psqlMessage: PSQLFrontendMessage, promise: EventLoopPromise<Void>? = nil) {
self.writeAndFlush(NIOAny(psqlMessage), promise: promise)
}
}

extension PSQLConnection.Configuration.Authentication {
func toAuthContext() -> AuthContext {
AuthContext(
Expand Down
3 changes: 1 addition & 2 deletions Sources/PostgresNIO/New/PSQLConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,8 @@ final class PSQLConnection {
}

return channel.pipeline.addHandlers([
MessageToByteHandler(PSQLFrontendMessageEncoder(jsonEncoder: configuration.coders.jsonEncoder)),
PSQLChannelHandler(
authentification: configuration.authentication,
configuration: configuration,
logger: logger,
configureSSLCallback: configureSSLCallback),
PSQLEventsHandler(logger: logger)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder {
return nil
}

guard var messageSlice = buffer.getSlice(at: buffer.readerIndex &+ 4, length: Int(length)) else {
guard var messageSlice = buffer.getSlice(at: buffer.readerIndex + 4, length: Int(length) - 4) else {
return nil
}
buffer.moveReaderIndex(forwardBy: 4 &+ Int(length))
buffer.moveReaderIndex(to: Int(length))
let finalIndex = buffer.readerIndex

guard let code = buffer.readInteger(as: UInt32.self) else {
guard let code = messageSlice.readInteger(as: UInt32.self) else {
throw PSQLPartialDecodingError.fieldNotDecodable(type: UInt32.self)
}

Expand Down
16 changes: 11 additions & 5 deletions Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ class PSQLChannelHandlerTests: XCTestCase {

func testHandlerAddedWithoutSSL() {
let config = self.testConnectionConfiguration()
let handler = PSQLChannelHandler(configuration: config, configureSSLCallback: nil)
let embedded = EmbeddedChannel(handlers: [
ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()),
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
PSQLChannelHandler(authentification: config.authentication, configureSSLCallback: nil)
handler
])
defer { XCTAssertNoThrow(try embedded.finish()) }

Expand All @@ -38,10 +40,11 @@ class PSQLChannelHandlerTests: XCTestCase {
var config = self.testConnectionConfiguration()
config.tlsConfiguration = .makeClientConfiguration()
var addSSLCallbackIsHit = false
let handler = PSQLChannelHandler(authentification: config.authentication) { channel in
let handler = PSQLChannelHandler(configuration: config) { channel in
addSSLCallbackIsHit = true
}
let embedded = EmbeddedChannel(handlers: [
ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()),
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
handler
])
Expand Down Expand Up @@ -79,11 +82,12 @@ class PSQLChannelHandlerTests: XCTestCase {
var config = self.testConnectionConfiguration()
config.tlsConfiguration = .makeClientConfiguration()

let handler = PSQLChannelHandler(authentification: config.authentication) { channel in
let handler = PSQLChannelHandler(configuration: config) { channel in
XCTFail("This callback should never be exectuded")
throw PSQLError.sslUnsupported
}
let embedded = EmbeddedChannel(handlers: [
ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()),
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
handler
])
Expand Down Expand Up @@ -114,8 +118,9 @@ class PSQLChannelHandlerTests: XCTestCase {
database: config.authentication?.database
)
let state = ConnectionStateMachine(.waitingToStartAuthentication)
let handler = PSQLChannelHandler(authentification: config.authentication, state: state, configureSSLCallback: nil)
let handler = PSQLChannelHandler(configuration: config, state: state, configureSSLCallback: nil)
let embedded = EmbeddedChannel(handlers: [
ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()),
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
handler
])
Expand All @@ -142,8 +147,9 @@ class PSQLChannelHandlerTests: XCTestCase {
database: config.authentication?.database
)
let state = ConnectionStateMachine(.waitingToStartAuthentication)
let handler = PSQLChannelHandler(authentification: config.authentication, state: state, configureSSLCallback: nil)
let handler = PSQLChannelHandler(configuration: config, state: state, configureSSLCallback: nil)
let embedded = EmbeddedChannel(handlers: [
ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()),
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
handler
])
Expand Down

0 comments on commit 131deb3

Please sign in to comment.