Skip to content

Commit

Permalink
Router validation (#637)
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-fowler authored Dec 28, 2024
1 parent 53c231f commit 17d8de6
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 24 deletions.
105 changes: 105 additions & 0 deletions Sources/Hummingbird/Router/Router+validation.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import HTTPTypes

#if canImport(FoundationEssentials)
import FoundationEssentials
#else
import Foundation
#endif

extension Router {
/// Route description
public struct RouteDescription: CustomStringConvertible {
/// Route path
public let path: RouterPath
/// Route method
public let method: HTTPRequest.Method

public var description: String { "\(method) \(path)" }
}

/// List of routes added to router
public var routes: [RouteDescription] {
let trieValues = self.trie.root.values()
return trieValues.flatMap { endpoint in
endpoint.value.methods.keys
.sorted { $0.rawValue < $1.rawValue }
.map { RouteDescription(path: endpoint.path, method: $0) }
}
}

/// Validate router
///
/// Verify that routes are not clashing
public func validate() throws {
try self.trie.root.validate()
}
}

extension RouterPathTrieBuilder.Node {
func validate(_ root: String = "") throws {
let sortedChildren = children.sorted { $0.key.priority > $1.key.priority }
if sortedChildren.count > 1 {
for index in 1..<sortedChildren.count {
let exampleElement =
switch sortedChildren[index].key.value {
case .path(let path):
String(path)
case .capture:
UUID().uuidString
case .prefixCapture(let suffix, _):
"\(UUID().uuidString)\(suffix)"
case .suffixCapture(let prefix, _):
"\(prefix)/\(UUID().uuidString)"
case .wildcard:
UUID().uuidString
case .prefixWildcard(let suffix):
"\(UUID().uuidString)\(suffix)"
case .suffixWildcard(let prefix):
"\(prefix)/\(UUID().uuidString)"
case .recursiveWildcard:
UUID().uuidString
case .null:
""
}
// test path element against all the previous trie entries in this node
for trieEntry in sortedChildren[0..<index] {
if case trieEntry.key = exampleElement {
throw RouterValidationError(
path: "\(root)/\(sortedChildren[index].key)",
override: "\(root)/\(trieEntry.key)"
)
}
}

}
}

for child in self.children {
try child.validate("\(root)/\(child.key)")
}
}
}

/// Router validation error
public struct RouterValidationError: Error, CustomStringConvertible {
let path: RouterPath
let override: RouterPath

public var description: String {
"Route \(override) overrides \(path)"
}
}
29 changes: 7 additions & 22 deletions Sources/Hummingbird/Router/Router.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ public final class Router<Context: RequestContext>: RouterMethods, HTTPResponder

/// build responder from router
public func buildResponder() -> RouterResponder<Context> {
#if DEBUG
do {
try self.validate()
} catch {
assertionFailure("\(error)")
}
#endif
if self.options.contains(.autoGenerateHeadEndpoints) {
// swift-format-ignore: ReplaceForEachWithForLoop
self.trie.forEach { node in
Expand Down Expand Up @@ -128,25 +135,3 @@ public struct RouterOptions: OptionSet, Sendable {
/// For every GET request that does not have a HEAD request, auto generate the HEAD request
public static var autoGenerateHeadEndpoints: Self { .init(rawValue: 1 << 1) }
}

extension Router {
/// Route description
public struct RouteDescription: CustomStringConvertible {
/// Route path
public let path: RouterPath
/// Route method
public let method: HTTPRequest.Method

public var description: String { "\(method) \(path)" }
}

/// List of routes added to router
public var routes: [RouteDescription] {
let trieValues = self.trie.root.values()
return trieValues.flatMap { endpoint in
endpoint.value.methods.keys
.sorted { $0.rawValue < $1.rawValue }
.map { RouteDescription(path: endpoint.path, method: $0) }
}
}
}
44 changes: 42 additions & 2 deletions Tests/HummingbirdTests/RouterTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ final class RouterTests: XCTestCase {
router.get("test/this") { _, _ in "" }
router.put("test") { _, _ in "" }
router.post("{test}/{what}") { _, _ in "" }
router.get("wildcard/*") { _, _ in "" }
router.get("wildcard/*/*") { _, _ in "" }
router.get("recursive_wildcard/**") { _, _ in "" }
router.patch("/test/longer/path/name") { _, _ in "" }
let routes = router.routes
Expand All @@ -750,11 +750,51 @@ final class RouterTests: XCTestCase {
XCTAssertEqual(routes[3].method, .patch)
XCTAssertEqual(routes[4].path.description, "/{test}/{what}")
XCTAssertEqual(routes[4].method, .post)
XCTAssertEqual(routes[5].path.description, "/wildcard/*")
XCTAssertEqual(routes[5].path.description, "/wildcard/*/*")
XCTAssertEqual(routes[5].method, .get)
XCTAssertEqual(routes[6].path.description, "/recursive_wildcard/**")
XCTAssertEqual(routes[6].method, .get)
}

func testValidateOrdering() throws {
let router = Router()
router.post("{test}/{what}") { _, _ in "" }
router.get("test/this") { _, _ in "" }
try router.validate()
}

func testValidateParametersVsWildcards() throws {
let router = Router()
router.get("test/*") { _, _ in "" }
router.get("test/{what}") { _, _ in "" }
XCTAssertThrowsError(try router.validate()) { error in
guard let error = error as? RouterValidationError else {
XCTFail()
return
}
XCTAssertEqual(error.description, "Route /test/{what} overrides /test/*")
}
}

func testValidateParametersVsRecursiveWildcard() throws {
let router = Router()
router.get("test/**") { _, _ in "" }
router.get("test/{what}") { _, _ in "" }
XCTAssertThrowsError(try router.validate()) { error in
guard let error = error as? RouterValidationError else {
XCTFail()
return
}
XCTAssertEqual(error.description, "Route /test/{what} overrides /test/**")
}
}

func testValidateDifferentParameterNames() throws {
let router = Router()
router.get("test/{this}") { _, _ in "" }
router.get("test/{what}") { _, _ in "" }
XCTAssertThrowsError(try router.validate())
}
}

struct TestRouterContext2: RequestContext {
Expand Down

0 comments on commit 17d8de6

Please sign in to comment.