From c0261d68b538f3e9a5783d61401432ea1fcff54b Mon Sep 17 00:00:00 2001 From: Hugo Osvaldo Barrera Date: Mon, 27 Mar 2023 19:14:10 +0200 Subject: [PATCH] Add Method::from_static Allows creating constant `Method` instances, e.g.: const PROPFIND: Method = Method::from_static(b"PROPFIND"); Fixes: https://github.com/hyperium/http/issues/587 --- src/method.rs | 118 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) diff --git a/src/method.rs b/src/method.rs index 7b4584ab..0538d84f 100644 --- a/src/method.rs +++ b/src/method.rs @@ -15,6 +15,8 @@ //! assert_eq!(Method::POST.as_str(), "POST"); //! ``` +use extension::StaticExtension; + use self::extension::{AllocatedExtension, InlineExtension}; use self::Inner::*; @@ -64,6 +66,8 @@ enum Inner { ExtensionInline(InlineExtension), // Otherwise, allocate it ExtensionAllocated(AllocatedExtension), + // Statically allocated data + ExtensionStatic(StaticExtension), } impl Method { @@ -134,6 +138,30 @@ impl Method { } } + /// Convert static bytes into a `Method`. + pub const fn from_static(src: &'static [u8]) -> Method { + match src { + b"OPTIONS" => Method::OPTIONS, + b"GET" => Method::GET, + b"POST" => Method::POST, + b"PUT" => Method::PUT, + b"DELETE" => Method::DELETE, + b"HEAD" => Method::HEAD, + b"TRACE" => Method::TRACE, + b"CONNECT" => Method::CONNECT, + b"PATCH" => Method::PATCH, + src => { + if src.len() <= 15 { + let inline = InlineExtension::from_static(src); + Method(ExtensionInline(inline)) + } else { + let allocated = StaticExtension::from_static(src); + Method(ExtensionStatic(allocated)) + } + } + } + } + fn extension_inline(src: &[u8]) -> Result { let inline = InlineExtension::new(src)?; @@ -176,6 +204,7 @@ impl Method { Patch => "PATCH", ExtensionInline(ref inline) => inline.as_str(), ExtensionAllocated(ref allocated) => allocated.as_str(), + ExtensionStatic(ref s) => s.as_str(), } } } @@ -316,6 +345,9 @@ mod extension { // Invariant: self.0 contains valid UTF-8. pub struct AllocatedExtension(Box<[u8]>); + #[derive(Clone, PartialEq, Eq, Hash)] + pub struct StaticExtension(&'static [u8]); + impl InlineExtension { // Method::from_bytes() assumes this is at least 7 pub const MAX: usize = 15; @@ -330,6 +362,34 @@ mod extension { Ok(InlineExtension(data, src.len() as u8)) } + /// Convert static bytes into an `InlineExtension`. + /// + /// # Panics + /// + /// If the input bytes are not a valid method name or if the method name is over 15 bytes. + pub const fn from_static(src: &'static [u8]) -> InlineExtension { + let mut i = 0; + let mut dst = [0u8; 15]; + if src.len() > 15 { + // panicking in const requires Rust 1.57.0 + #[allow(unconditional_panic)] + ([] as [u8; 0])[0]; + } + while i < src.len() { + let byte = src[i]; + let v = METHOD_CHARS[byte as usize]; + if v == 0 { + // panicking in const requires Rust 1.57.0 + #[allow(unconditional_panic)] + ([] as [u8; 0])[0]; + } + dst[i] = byte; + i += 1; + } + + InlineExtension(dst, i as u8) + } + pub fn as_str(&self) -> &str { let InlineExtension(ref data, len) = self; // Safety: the invariant of InlineExtension ensures that the first @@ -356,6 +416,32 @@ mod extension { } } + impl StaticExtension { + pub const fn from_static(src: &'static [u8]) -> StaticExtension { + let mut i = 0; + while i < src.len() { + let byte = src[i]; + let v = METHOD_CHARS[byte as usize]; + if v == 0 { + // panicking in const requires Rust 1.57.0 + #[allow(unconditional_panic)] + ([] as [u8; 0])[0]; + } + i += 1; + } + + // Invariant: data is exactly src.len() long and write_checked + // ensures that the first src.len() bytes of data are valid UTF-8. + StaticExtension(src) + } + + pub fn as_str(&self) -> &str { + // Safety: the invariant of StaticExtension ensures that self.0 + // contains valid UTF-8. + unsafe { str::from_utf8_unchecked(&self.0) } + } + } + // From the RFC 9110 HTTP Semantics, section 9.1, the HTTP method is case-sensitive and can // contain the following characters: // @@ -436,6 +522,38 @@ mod test { assert_eq!(Method::GET, &Method::GET); } + #[test] + fn test_from_static() { + // First class variant + assert_eq!( + Method::from_static(b"GET"), + Method::from_bytes(b"GET").unwrap() + ); + // Inline, len < 15 + assert_eq!( + Method::from_static(b"PROPFIND"), + Method::from_bytes(b"PROPFIND").unwrap() + ); + // Inline, len == 15 + assert_eq!(Method::from_static(b"GET"), Method::GET); + assert_eq!( + Method::from_static(b"123456789012345").to_string(), + "123456789012345".to_string() + ); + // Ref, len > 15 + Method::from_static(b"1234567890123456"); + assert_eq!( + Method::from_static(b"1234567890123456").to_string(), + "1234567890123456".to_string() + ); + } + + #[test] + #[should_panic] + fn test_from_static_bad() { + Method::from_static(b"\0"); + } + #[test] fn test_invalid_method() { assert!(Method::from_str("").is_err());