From edd4aa99cfbf0200d76555936ef646905727b4ae Mon Sep 17 00:00:00 2001
From: Lars Wrenger <lars@wrenger.net>
Date: Fri, 16 Feb 2024 22:57:59 +0100
Subject: [PATCH] :sparkles: Conversion functions can use smaller ints than the
 bitfield

---
 src/lib.rs    | 66 +++++++++++++++++++++++++--------------------------
 tests/test.rs |  6 ++---
 2 files changed, 35 insertions(+), 37 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 319ee96..1ea3e39 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,11 +5,15 @@
 use proc_macro as pc;
 use proc_macro2::{Ident, TokenStream};
 use quote::{format_ident, quote, ToTokens};
-use std::stringify;
+use std::{fmt, stringify};
 use syn::parse::{Parse, ParseStream};
 use syn::spanned::Spanned;
 use syn::Token;
 
+fn s_err(span: proc_macro2::Span, msg: impl fmt::Display) -> syn::Error {
+    syn::Error::new(span, msg)
+}
+
 /// Creates a bitfield for this struct.
 ///
 /// The arguments first, have to begin with the underlying type of the bitfield:
@@ -54,7 +58,7 @@ fn bitfield_inner(args: TokenStream, input: TokenStream) -> syn::Result<TokenStr
     let attrs: TokenStream = input.attrs.iter().map(ToTokens::to_token_stream).collect();
 
     let syn::Fields::Named(fields) = input.fields else {
-        return Err(syn::Error::new(span, "only named fields are supported"));
+        return Err(s_err(span, "only named fields are supported"));
     };
 
     let mut offset = 0;
@@ -66,7 +70,7 @@ fn bitfield_inner(args: TokenStream, input: TokenStream) -> syn::Result<TokenStr
     }
 
     if offset < bits {
-        return Err(syn::Error::new(
+        return Err(s_err(
             span,
             format!(
                 "The bitfield size ({bits} bits) has to be equal to the sum of its members ({offset} bits)!. \
@@ -76,7 +80,7 @@ fn bitfield_inner(args: TokenStream, input: TokenStream) -> syn::Result<TokenStr
         ));
     }
     if offset > bits {
-        return Err(syn::Error::new(
+        return Err(s_err(
             span,
             format!(
                 "The size of the members ({offset} bits) is larger than the type ({bits} bits)!."
@@ -183,7 +187,7 @@ impl Member {
             ..
         } = f;
 
-        let ident = ident.ok_or_else(|| syn::Error::new(span, "Not supported"))?;
+        let ident = ident.ok_or_else(|| s_err(span, "Not supported"))?;
         let ignore = ident.to_string().starts_with('_');
 
         let Field {
@@ -200,9 +204,9 @@ impl Member {
         if bits > 0 && !ignore {
             // overflow check
             if offset + bits > base_bits {
-                return Err(syn::Error::new(
+                return Err(s_err(
                     ty.span(),
-                    "The total size of the members is too large!",
+                    "The sum of the members overflows the type size",
                 ));
             };
 
@@ -400,7 +404,7 @@ fn parse_field(
     ignore: bool,
 ) -> syn::Result<Field> {
     fn malformed(mut e: syn::Error, attr: &syn::Attribute) -> syn::Error {
-        e.combine(syn::Error::new(attr.span(), "malformed #[bits] attribute"));
+        e.combine(s_err(attr.span(), "malformed #[bits] attribute"));
         e
     }
 
@@ -441,8 +445,8 @@ fn parse_field(
             bits: ty_bits,
             ty: ty.clone(),
             default: quote!(),
-            into: quote!(<#ty>::into_bits(this)),
-            from: quote!(<#ty>::from_bits(this)),
+            into: quote!(<#ty>::into_bits(this) as _),
+            from: quote!(<#ty>::from_bits(this as _)),
             access,
         },
     };
@@ -470,10 +474,10 @@ fn parse_field(
             // bit size
             if let Some(bits) = bits {
                 if bits == 0 {
-                    return Err(syn::Error::new(span, "bits cannot bit 0"));
+                    return Err(s_err(span, "bits cannot bit 0"));
                 }
                 if ty_bits != 0 && bits > ty_bits {
-                    return Err(syn::Error::new(span, "overflowing field type"));
+                    return Err(s_err(span, "overflowing field type"));
                 }
                 ret.bits = bits;
             }
@@ -481,7 +485,7 @@ fn parse_field(
             // read/write access
             if let Some(access) = access {
                 if ignore {
-                    return Err(syn::Error::new(
+                    return Err(s_err(
                         tokens.span(),
                         "'access' is not supported for padding",
                     ));
@@ -492,21 +496,15 @@ fn parse_field(
             // conversion
             if let Some(into) = into {
                 if ret.access == Access::None {
-                    return Err(syn::Error::new(
-                        into.span(),
-                        "'into' and 'from' are not supported on padding",
-                    ));
+                    return Err(s_err(into.span(), "'into' is not supported on padding"));
                 }
-                ret.into = quote!(#into(this));
+                ret.into = quote!(#into(this) as _);
             }
             if let Some(from) = from {
                 if ret.access == Access::None {
-                    return Err(syn::Error::new(
-                        from.span(),
-                        "'into' and 'from' are not supported on padding",
-                    ));
+                    return Err(s_err(from.span(), "'from' is not supported on padding"));
                 }
-                ret.from = quote!(#from(this));
+                ret.from = quote!(#from(this as _));
             }
             if let Some(default) = default {
                 ret.default = default.into_token_stream();
@@ -515,7 +513,7 @@ fn parse_field(
     }
 
     if ret.bits == 0 {
-        return Err(syn::Error::new(
+        return Err(s_err(
             ty.span(),
             "Custom types and isize/usize require an explicit bit size",
         ));
@@ -612,15 +610,15 @@ impl Parse for Access {
         let mode = input.parse::<Ident>()?;
 
         if mode == "RW" {
-            Ok(Access::ReadWrite)
+            Ok(Self::ReadWrite)
         } else if mode == "RO" {
-            Ok(Access::ReadOnly)
+            Ok(Self::ReadOnly)
         } else if mode == "WO" {
-            Ok(Access::WriteOnly)
+            Ok(Self::WriteOnly)
         } else if mode == "None" {
-            Ok(Access::None)
+            Ok(Self::None)
         } else {
-            Err(syn::Error::new(
+            Err(s_err(
                 mode.span(),
                 "Invalid access mode, only RW, RO, WO, or None are allowed",
             ))
@@ -646,11 +644,11 @@ struct Params {
 impl Parse for Params {
     fn parse(input: ParseStream) -> syn::Result<Self> {
         let Ok(ty) = syn::Type::parse(input) else {
-            return Err(syn::Error::new(input.span(), "unknown type"));
+            return Err(s_err(input.span(), "unknown type"));
         };
         let (class, bits) = type_bits(&ty);
         if class != TypeClass::UInt {
-            return Err(syn::Error::new(input.span(), "unsupported type"));
+            return Err(s_err(input.span(), "unsupported type"));
         }
 
         let mut debug = true;
@@ -674,15 +672,15 @@ impl Parse for Params {
                     let value = match syn::Ident::parse(input)?.to_string().as_str() {
                         "Msb" | "msb" => Order::Msb,
                         "Lsb" | "lsb" => Order::Lsb,
-                        _ => return Err(syn::Error::new(ident.span(), "unknown value for order")),
+                        _ => return Err(s_err(ident.span(), "unknown value for order")),
                     };
                     order = value;
                 }
-                _ => return Err(syn::Error::new(ident.span(), "unknown argument")),
+                _ => return Err(s_err(ident.span(), "unknown argument")),
             };
         }
 
-        Ok(Params {
+        Ok(Self {
             ty,
             bits,
             debug,
diff --git a/tests/test.rs b/tests/test.rs
index f98c9dd..a003d00 100644
--- a/tests/test.rs
+++ b/tests/test.rs
@@ -292,7 +292,7 @@ fn defaults() {
 
     /// A custom enum
     #[derive(Debug, PartialEq, Eq)]
-    #[repr(u16)]
+    #[repr(u8)]
     enum CustomEnum {
         A = 0,
         B = 1,
@@ -300,10 +300,10 @@ fn defaults() {
     }
     impl CustomEnum {
         // This has to be const eval
-        const fn into_bits(self) -> u16 {
+        const fn into_bits(self) -> u8 {
             self as _
         }
-        const fn my_from_bits(value: u16) -> Self {
+        const fn my_from_bits(value: u8) -> Self {
             match value {
                 0 => Self::A,
                 1 => Self::B,