bytemuck_derive/
traits.rs

1#![allow(unused_imports)]
2use std::{cmp, convert::TryFrom};
3
4use proc_macro2::{Ident, Span, TokenStream, TokenTree};
5use quote::{quote, ToTokens};
6use syn::{
7  parse::{Parse, ParseStream, Parser},
8  punctuated::Punctuated,
9  spanned::Spanned,
10  Result, *,
11};
12
/// Returns early from the enclosing function with a `syn` error.
///
/// * `bail!(message)` attributes the error to `Span::call_site()`.
/// * `bail!(message => tokens)` attributes the error to the span of `tokens`.
///
/// A trailing comma is accepted in both forms.
macro_rules! bail {
  ($message:expr $(,)?) => {
    return Err(Error::new(Span::call_site(), &$message[..]))
  };

  ($message:expr => $blamed_tokens:expr $(,)?) => {
    return Err(Error::new_spanned(&$blamed_tokens, $message))
  };
}
22
23pub trait Derivable {
24  fn ident(input: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path>;
25  fn implies_trait(_crate_name: &TokenStream) -> Option<TokenStream> {
26    None
27  }
28  fn asserts(
29    _input: &DeriveInput, _crate_name: &TokenStream,
30  ) -> Result<TokenStream> {
31    Ok(quote!())
32  }
33  fn check_attributes(_ty: &Data, _attributes: &[Attribute]) -> Result<()> {
34    Ok(())
35  }
36  fn trait_impl(
37    _input: &DeriveInput, _crate_name: &TokenStream,
38  ) -> Result<(TokenStream, TokenStream)> {
39    Ok((quote!(), quote!()))
40  }
41  fn requires_where_clause() -> bool {
42    true
43  }
44  fn explicit_bounds_attribute_name() -> Option<&'static str> {
45    None
46  }
47
48  /// If this trait has a custom meaning for "perfect derive", this function
49  /// should be overridden to return `Some`.
50  ///
51  /// The default is "the fields of a struct; unions and enums not supported".
52  fn perfect_derive_fields(_input: &DeriveInput) -> Option<Fields> {
53    None
54  }
55}
56
57pub struct Pod;
58
59impl Derivable for Pod {
60  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
61    Ok(syn::parse_quote!(#crate_name::Pod))
62  }
63
64  fn asserts(
65    input: &DeriveInput, crate_name: &TokenStream,
66  ) -> Result<TokenStream> {
67    let repr = get_repr(&input.attrs)?;
68
69    let completly_packed =
70      repr.packed == Some(1) || repr.repr == Repr::Transparent;
71
72    if !completly_packed && !input.generics.params.is_empty() {
73      bail!("\
74        Pod requires cannot be derived for non-packed types containing \
75        generic parameters because the padding requirements can't be verified \
76        for generic non-packed structs\
77      " => input.generics.params.first().unwrap());
78    }
79
80    match &input.data {
81      Data::Struct(_) => {
82        let assert_no_padding = if !completly_packed {
83          Some(generate_assert_no_padding(input, None, "Pod")?)
84        } else {
85          None
86        };
87        let assert_fields_are_pod = generate_fields_are_trait(
88          input,
89          None,
90          Self::ident(input, crate_name)?,
91        )?;
92
93        Ok(quote!(
94          #assert_no_padding
95          #assert_fields_are_pod
96        ))
97      }
98      Data::Enum(_) => bail!("Deriving Pod is not supported for enums"),
99      Data::Union(_) => bail!("Deriving Pod is not supported for unions"),
100    }
101  }
102
103  fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
104    let repr = get_repr(attributes)?;
105    match repr.repr {
106      Repr::C => Ok(()),
107      Repr::Transparent => Ok(()),
108      _ => {
109        bail!("Pod requires the type to be #[repr(C)] or #[repr(transparent)]")
110      }
111    }
112  }
113}
114
115pub struct AnyBitPattern;
116
117impl Derivable for AnyBitPattern {
118  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
119    Ok(syn::parse_quote!(#crate_name::AnyBitPattern))
120  }
121
122  fn implies_trait(crate_name: &TokenStream) -> Option<TokenStream> {
123    Some(quote!(#crate_name::Zeroable))
124  }
125
126  fn asserts(
127    input: &DeriveInput, crate_name: &TokenStream,
128  ) -> Result<TokenStream> {
129    match &input.data {
130      Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern`
131      Data::Struct(_) => {
132        generate_fields_are_trait(input, None, Self::ident(input, crate_name)?)
133      }
134      Data::Enum(_) => {
135        bail!("Deriving AnyBitPattern is not supported for enums")
136      }
137    }
138  }
139}
140
141pub struct Zeroable;
142
143/// Helper function to get the variant with discriminant zero (implicit or
144/// explicit).
145fn get_zero_variant(enum_: &DataEnum) -> Result<Option<&Variant>> {
146  let iter = VariantDiscriminantIterator::new(enum_.variants.iter());
147  let mut zero_variant = None;
148  for res in iter {
149    let (discriminant, variant) = res?;
150    if discriminant == 0 {
151      zero_variant = Some(variant);
152      break;
153    }
154  }
155  Ok(zero_variant)
156}
157
158impl Derivable for Zeroable {
159  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
160    Ok(syn::parse_quote!(#crate_name::Zeroable))
161  }
162
163  fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
164    let repr = get_repr(attributes)?;
165    match ty {
166      Data::Struct(_) => Ok(()),
167      Data::Enum(_) => {
168        if !matches!(
169          repr.repr,
170          Repr::C | Repr::Integer(_) | Repr::CWithDiscriminant(_)
171        ) {
172          bail!("Zeroable requires the enum to be an explicit #[repr(Int)] and/or #[repr(C)]")
173        }
174
175        // We ensure there is a zero variant in `asserts`, since it is needed
176        // there anyway.
177
178        Ok(())
179      }
180      Data::Union(_) => Ok(()),
181    }
182  }
183
184  fn asserts(
185    input: &DeriveInput, crate_name: &TokenStream,
186  ) -> Result<TokenStream> {
187    match &input.data {
188      Data::Union(_) => Ok(quote!()), // unions are always `Zeroable`
189      Data::Struct(_) => {
190        generate_fields_are_trait(input, None, Self::ident(input, crate_name)?)
191      }
192      Data::Enum(enum_) => {
193        let zero_variant = get_zero_variant(enum_)?;
194
195        if zero_variant.is_none() {
196          bail!("No variant's discriminant is 0")
197        };
198
199        generate_fields_are_trait(
200          input,
201          zero_variant,
202          Self::ident(input, crate_name)?,
203        )
204      }
205    }
206  }
207
208  fn explicit_bounds_attribute_name() -> Option<&'static str> {
209    Some("zeroable")
210  }
211
212  fn perfect_derive_fields(input: &DeriveInput) -> Option<Fields> {
213    match &input.data {
214      Data::Struct(struct_) => Some(struct_.fields.clone()),
215      Data::Enum(enum_) => {
216        // We handle `Err` returns from `get_zero_variant` in `asserts`, so it's
217        // fine to just ignore them here and return `None`.
218        // Otherwise, we clone the `fields` of the zero variant (if any).
219        Some(get_zero_variant(enum_).ok()??.fields.clone())
220      }
221      Data::Union(_) => None,
222    }
223  }
224}
225
226pub struct NoUninit;
227
228impl Derivable for NoUninit {
229  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
230    Ok(syn::parse_quote!(#crate_name::NoUninit))
231  }
232
233  fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
234    let repr = get_repr(attributes)?;
235    match ty {
236      Data::Struct(_) => match repr.repr {
237        Repr::C | Repr::Transparent => Ok(()),
238        _ => bail!("NoUninit requires the struct to be #[repr(C)] or #[repr(transparent)]"),
239      },
240      Data::Enum(DataEnum { variants,.. }) => {
241        if !enum_has_fields(variants.iter()) {
242          if matches!(repr.repr, Repr::C | Repr::Integer(_)) {
243            Ok(())
244          } else {
245            bail!("NoUninit requires the enum to be #[repr(C)] or #[repr(Int)]")
246          }
247        } else if matches!(repr.repr, Repr::Rust) {
248          bail!("NoUninit requires an explicit repr annotation because `repr(Rust)` doesn't have a specified type layout")
249        } else {
250          Ok(())
251        }
252      },
253      Data::Union(_) => bail!("NoUninit can only be derived on enums and structs")
254    }
255  }
256
257  fn asserts(
258    input: &DeriveInput, crate_name: &TokenStream,
259  ) -> Result<TokenStream> {
260    if !input.generics.params.is_empty() {
261      bail!("NoUninit cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs");
262    }
263
264    match &input.data {
265      Data::Struct(DataStruct { .. }) => {
266        let assert_no_padding =
267          generate_assert_no_padding(&input, None, "NoUninit")?;
268        let assert_fields_are_no_padding = generate_fields_are_trait(
269          &input,
270          None,
271          Self::ident(input, crate_name)?,
272        )?;
273
274        Ok(quote!(
275            #assert_no_padding
276            #assert_fields_are_no_padding
277        ))
278      }
279      Data::Enum(DataEnum { variants, .. }) => {
280        if enum_has_fields(variants.iter()) {
281          // There are two different C representations for enums with fields:
282          // There's `#[repr(C)]`/`[repr(C, int)]` and `#[repr(int)]`.
283          // `#[repr(C)]` is equivalent to a struct containing the discriminant
284          // and a union of structs representing each variant's fields.
285          // `#[repr(int)]` is equivalent to a union containing structs of the
286          // discriminant and the fields.
287          //
288          // See https://doc.rust-lang.org/reference/type-layout.html#r-layout.repr.c.adt
289          // and https://doc.rust-lang.org/reference/type-layout.html#r-layout.repr.primitive.adt
290          //
291          // In practice the only difference between the two is whether and
292          // where padding bytes are placed. For `#[repr(C)]` enums, the first
293          // enum fields of all variants start at the same location (the first
294          // byte in the union). For `#[repr(int)]` enums, the structs
295          // representing each variant are layed out individually and padding
296          // does not depend on other variants, but only on the size of the
297          // discriminant and the alignment of the first field. The location of
298          // the first field might differ between variants, potentially
299          // resulting in less padding or padding placed later in the enum.
300          //
301          // The `NoUninit` derive macro asserts that no padding exists by
302          // removing all padding with `#[repr(packed)]` and checking that this
303          // doesn't change the size. Since the location and presence of
304          // padding bytes is the only difference between the two
305          // representations and we're removing all padding bytes, the resuling
306          // layout would identical for both representations. This means that
307          // we can just pick one of the representations and don't have to
308          // implement desugaring for both. We chose to implement the
309          // desugaring for `#[repr(int)]`.
310
311          let enum_discriminant = generate_enum_discriminant(input)?;
312          let variant_assertions = variants
313            .iter()
314            .map(|variant| {
315              let assert_no_padding =
316                generate_assert_no_padding(&input, Some(variant), "NoUninit")?;
317              let assert_fields_are_no_padding = generate_fields_are_trait(
318                &input,
319                Some(variant),
320                Self::ident(input, crate_name)?,
321              )?;
322
323              Ok(quote!(
324                  #assert_no_padding
325                  #assert_fields_are_no_padding
326              ))
327            })
328            .collect::<Result<Vec<_>>>()?;
329          Ok(quote! {
330            const _: () = {
331              #enum_discriminant
332              #(#variant_assertions)*
333            };
334          })
335        } else {
336          Ok(quote!())
337        }
338      }
339      Data::Union(_) => bail!("NoUninit cannot be derived for unions"), /* shouldn't be possible since we already error in attribute check for this case */
340    }
341  }
342
343  fn trait_impl(
344    _input: &DeriveInput, _crate_name: &TokenStream,
345  ) -> Result<(TokenStream, TokenStream)> {
346    Ok((quote!(), quote!()))
347  }
348}
349
350pub struct CheckedBitPattern;
351
352impl Derivable for CheckedBitPattern {
353  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
354    Ok(syn::parse_quote!(#crate_name::CheckedBitPattern))
355  }
356
357  fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
358    let repr = get_repr(attributes)?;
359    match ty {
360      Data::Struct(_) => match repr.repr {
361        Repr::C | Repr::Transparent => Ok(()),
362        _ => bail!("CheckedBitPattern derive requires the struct to be #[repr(C)] or #[repr(transparent)]"),
363      },
364      Data::Enum(DataEnum { variants,.. }) => {
365        if !enum_has_fields(variants.iter()){
366          if matches!(repr.repr, Repr::C | Repr::Integer(_)) {
367            Ok(())
368          } else {
369            bail!("CheckedBitPattern requires the enum to be #[repr(C)] or #[repr(Int)]")
370          }
371        } else if matches!(repr.repr, Repr::Rust) {
372          bail!("CheckedBitPattern requires an explicit repr annotation because `repr(Rust)` doesn't have a specified type layout")
373        } else {
374          Ok(())
375        }
376      }
377      Data::Union(_) => bail!("CheckedBitPattern can only be derived on enums and structs")
378    }
379  }
380
381  fn asserts(
382    input: &DeriveInput, crate_name: &TokenStream,
383  ) -> Result<TokenStream> {
384    if !input.generics.params.is_empty() {
385      bail!("CheckedBitPattern cannot be derived for structs containing generic parameters");
386    }
387
388    match &input.data {
389      Data::Struct(DataStruct { .. }) => {
390        let assert_fields_are_maybe_pod = generate_fields_are_trait(
391          &input,
392          None,
393          Self::ident(input, crate_name)?,
394        )?;
395
396        Ok(assert_fields_are_maybe_pod)
397      }
398      // nothing needed, already guaranteed OK by NoUninit.
399      Data::Enum(_) => Ok(quote!()),
400      Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
401    }
402  }
403
404  fn trait_impl(
405    input: &DeriveInput, crate_name: &TokenStream,
406  ) -> Result<(TokenStream, TokenStream)> {
407    match &input.data {
408      Data::Struct(DataStruct { fields, .. }) => {
409        generate_checked_bit_pattern_struct(
410          &input.ident,
411          fields,
412          &input.attrs,
413          crate_name,
414        )
415      }
416      Data::Enum(DataEnum { variants, .. }) => {
417        generate_checked_bit_pattern_enum(input, variants, crate_name)
418      }
419      Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
420    }
421  }
422}
423
424pub struct TransparentWrapper;
425
426struct WrappedType {
427  wrapped_type: syn::Type,
428  /// Was the type given with a #[transparent(Type)] attribute.
429  explicit: bool,
430}
431
432impl TransparentWrapper {
433  fn get_wrapped_type(
434    attributes: &[Attribute], fields: &Fields,
435  ) -> Option<WrappedType> {
436    let transparent_param =
437      get_type_from_simple_attr(attributes, "transparent")
438        .map(|wrapped_type| WrappedType { wrapped_type, explicit: true });
439    transparent_param.or_else(|| {
440      let mut types = get_field_types(&fields);
441      let first_type = types.next();
442      if let Some(_) = types.next() {
443        // can't guess param type if there is more than one field
444        return None;
445      } else {
446        first_type
447          .cloned()
448          .map(|wrapped_type| WrappedType { wrapped_type, explicit: false })
449      }
450    })
451  }
452}
453
454impl Derivable for TransparentWrapper {
455  fn ident(input: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
456    let fields = get_struct_fields(input)?;
457
458    let WrappedType { wrapped_type: ty, .. } =
459      match Self::get_wrapped_type(&input.attrs, &fields) {
460        Some(ty) => ty,
461        None => bail!("when deriving TransparentWrapper for a struct with more \
462                       than one field, you need to specify the transparent field \
463                       using #[transparent(T)]"),
464      };
465
466    Ok(syn::parse_quote!(#crate_name::TransparentWrapper<#ty>))
467  }
468
469  fn asserts(
470    input: &DeriveInput, crate_name: &TokenStream,
471  ) -> Result<TokenStream> {
472    let (impl_generics, _ty_generics, where_clause) =
473      input.generics.split_for_impl();
474    let fields = get_struct_fields(input)?;
475    let (wrapped_type, explicit) =
476      match Self::get_wrapped_type(&input.attrs, &fields) {
477        Some(WrappedType { wrapped_type, explicit }) => {
478          (wrapped_type.to_token_stream().to_string(), explicit)
479        }
480        None => unreachable!(), /* other code will already reject this derive */
481      };
482    let mut wrapped_field_ty = None;
483    let mut nonwrapped_field_tys = vec![];
484    for field in fields.iter() {
485      let field_ty = &field.ty;
486      if field_ty.to_token_stream().to_string() == wrapped_type {
487        if wrapped_field_ty.is_some() {
488          if explicit {
489            bail!("TransparentWrapper must have one field of the wrapped type. \
490                   The type given in `#[transparent(Type)]` must match tokenwise \
491                   with the type in the struct definition, not just be the same type. \
492                   You may be able to use a type alias to work around this limitation.");
493          } else {
494            bail!("TransparentWrapper must have one field of the wrapped type");
495          }
496        }
497        wrapped_field_ty = Some(field_ty);
498      } else {
499        nonwrapped_field_tys.push(field_ty);
500      }
501    }
502    if let Some(wrapped_field_ty) = wrapped_field_ty {
503      Ok(quote!(
504        const _: () = {
505          #[repr(transparent)]
506          #[allow(clippy::multiple_bound_locations)]
507          struct AssertWrappedIsWrapped #impl_generics((u8, ::core::marker::PhantomData<#wrapped_field_ty>), #(#nonwrapped_field_tys),*) #where_clause;
508          fn assert_zeroable<Z: #crate_name::Zeroable>() {}
509          #[allow(clippy::multiple_bound_locations)]
510          fn check #impl_generics () #where_clause {
511            #(
512              assert_zeroable::<#nonwrapped_field_tys>();
513            )*
514          }
515        };
516      ))
517    } else {
518      bail!("TransparentWrapper must have one field of the wrapped type")
519    }
520  }
521
522  fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
523    let repr = get_repr(attributes)?;
524
525    match repr.repr {
526      Repr::Transparent => Ok(()),
527      _ => {
528        bail!(
529          "TransparentWrapper requires the struct to be #[repr(transparent)]"
530        )
531      }
532    }
533  }
534
535  fn requires_where_clause() -> bool {
536    false
537  }
538}
539
540pub struct Contiguous;
541
542impl Derivable for Contiguous {
543  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
544    Ok(syn::parse_quote!(#crate_name::Contiguous))
545  }
546
547  fn trait_impl(
548    input: &DeriveInput, _crate_name: &TokenStream,
549  ) -> Result<(TokenStream, TokenStream)> {
550    let repr = get_repr(&input.attrs)?;
551
552    let integer_ty = if let Some(integer_ty) = repr.repr.as_integer() {
553      integer_ty
554    } else {
555      bail!("Contiguous requires the enum to be #[repr(Int)]");
556    };
557
558    let variants = get_enum_variants(input)?;
559    if enum_has_fields(variants.clone()) {
560      return Err(Error::new_spanned(
561        &input,
562        "Only fieldless enums are supported",
563      ));
564    }
565
566    let mut variants_with_discriminant =
567      VariantDiscriminantIterator::new(variants);
568
569    let (min, max, count) = variants_with_discriminant.try_fold(
570      (i128::MAX, i128::MIN, 0),
571      |(min, max, count), res| {
572        let (discriminant, _variant) = res?;
573        Ok::<_, Error>((
574          i128::min(min, discriminant),
575          i128::max(max, discriminant),
576          count + 1,
577        ))
578      },
579    )?;
580
581    if max - min != count - 1 {
582      bail! {
583        "Contiguous requires the enum discriminants to be contiguous",
584      }
585    }
586
587    let min_lit = LitInt::new(&format!("{}", min), input.span());
588    let max_lit = LitInt::new(&format!("{}", max), input.span());
589
590    // `from_integer` and `into_integer` are usually provided by the trait's
591    // default implementation. We override this implementation because it
592    // goes through `transmute_copy`, which can lead to inefficient assembly as seen in https://github.com/Lokathor/bytemuck/issues/175 .
593
594    Ok((
595      quote!(),
596      quote! {
597          type Int = #integer_ty;
598
599          #[allow(clippy::missing_docs_in_private_items)]
600          const MIN_VALUE: #integer_ty = #min_lit;
601
602          #[allow(clippy::missing_docs_in_private_items)]
603          const MAX_VALUE: #integer_ty = #max_lit;
604
605          #[inline]
606          fn from_integer(value: Self::Int) -> Option<Self> {
607            #[allow(clippy::manual_range_contains)]
608            if Self::MIN_VALUE <= value && value <= Self::MAX_VALUE {
609              Some(unsafe { ::core::mem::transmute(value) })
610            } else {
611              None
612            }
613          }
614
615          #[inline]
616          fn into_integer(self) -> Self::Int {
617              self as #integer_ty
618          }
619      },
620    ))
621  }
622}
623
624fn get_struct_fields(input: &DeriveInput) -> Result<&Fields> {
625  if let Data::Struct(DataStruct { fields, .. }) = &input.data {
626    Ok(fields)
627  } else {
628    bail!("deriving this trait is only supported for structs")
629  }
630}
631
632/// Extract the `Fields` off a `DeriveInput`, or, in the `enum` case, off
633/// those of the `enum_variant`, when provided (e.g., for `Zeroable`).
634///
635/// We purposely allow not providing an `enum_variant` for cases where
636/// the caller wants to reject supporting `enum`s (e.g., `NoPadding`).
637fn get_fields(
638  input: &DeriveInput, enum_variant: Option<&Variant>,
639) -> Result<Fields> {
640  match &input.data {
641    Data::Struct(DataStruct { fields, .. }) => Ok(fields.clone()),
642    Data::Union(DataUnion { fields, .. }) => Ok(Fields::Named(fields.clone())),
643    Data::Enum(_) => match enum_variant {
644      Some(variant) => Ok(variant.fields.clone()),
645      None => bail!("deriving this trait is not supported for enums"),
646    },
647  }
648}
649
650fn get_enum_variants<'a>(
651  input: &'a DeriveInput,
652) -> Result<impl Iterator<Item = &'a Variant> + Clone + 'a> {
653  if let Data::Enum(DataEnum { variants, .. }) = &input.data {
654    Ok(variants.iter())
655  } else {
656    bail!("deriving this trait is only supported for enums")
657  }
658}
659
660fn get_field_types<'a>(
661  fields: &'a Fields,
662) -> impl Iterator<Item = &'a Type> + 'a {
663  fields.iter().map(|field| &field.ty)
664}
665
666fn generate_checked_bit_pattern_struct(
667  input_ident: &Ident, fields: &Fields, attrs: &[Attribute],
668  crate_name: &TokenStream,
669) -> Result<(TokenStream, TokenStream)> {
670  let bits_ty = Ident::new(&format!("{}Bits", input_ident), input_ident.span());
671
672  let repr = get_repr(attrs)?;
673
674  let field_names = fields
675    .iter()
676    .enumerate()
677    .map(|(i, field)| {
678      field.ident.clone().unwrap_or_else(|| {
679        Ident::new(&format!("field{}", i), input_ident.span())
680      })
681    })
682    .collect::<Vec<_>>();
683  let field_tys = fields.iter().map(|field| &field.ty).collect::<Vec<_>>();
684
685  let field_name = &field_names[..];
686  let field_ty = &field_tys[..];
687
688  Ok((
689    quote! {
690        #[doc = #GENERATED_TYPE_DOCUMENTATION]
691        #repr
692        #[derive(Clone, Copy, #crate_name::AnyBitPattern)]
693        #[allow(missing_docs)]
694        pub struct #bits_ty {
695            #(#field_name: <#field_ty as #crate_name::CheckedBitPattern>::Bits,)*
696        }
697
698        #[allow(unexpected_cfgs)]
699        const _: () = {
700          #[cfg(not(target_arch = "spirv"))]
701          impl ::core::fmt::Debug for #bits_ty {
702            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
703              let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty));
704              #(::core::fmt::DebugStruct::field(&mut debug_struct, ::core::stringify!(#field_name), &{ self.#field_name });)*
705              ::core::fmt::DebugStruct::finish(&mut debug_struct)
706            }
707          }
708        };
709    },
710    quote! {
711        type Bits = #bits_ty;
712
713        #[inline]
714        #[allow(clippy::double_comparisons, unused)]
715        fn is_valid_bit_pattern(bits: &#bits_ty) -> bool {
716            #(<#field_ty as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(&{ bits.#field_name }) && )* true
717        }
718    },
719  ))
720}
721
722fn generate_checked_bit_pattern_enum(
723  input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
724  crate_name: &TokenStream,
725) -> Result<(TokenStream, TokenStream)> {
726  if enum_has_fields(variants.iter()) {
727    generate_checked_bit_pattern_enum_with_fields(input, variants, crate_name)
728  } else {
729    generate_checked_bit_pattern_enum_without_fields(
730      input, variants, crate_name,
731    )
732  }
733}
734
735fn generate_checked_bit_pattern_enum_without_fields(
736  input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
737  crate_name: &TokenStream,
738) -> Result<(TokenStream, TokenStream)> {
739  let span = input.span();
740  let mut variants_with_discriminant =
741    VariantDiscriminantIterator::new(variants.iter());
742
743  let (min, max, count) = variants_with_discriminant.try_fold(
744    (i128::MAX, i128::MIN, 0),
745    |(min, max, count), res| {
746      let (discriminant, _variant) = res?;
747      Ok::<_, Error>((
748        i128::min(min, discriminant),
749        i128::max(max, discriminant),
750        count + 1,
751      ))
752    },
753  )?;
754
755  let check = if count == 0 {
756    quote!(false)
757  } else if max - min == count - 1 {
758    // contiguous range
759    let min_lit = LitInt::new(&format!("{}", min), span);
760    let max_lit = LitInt::new(&format!("{}", max), span);
761
762    quote!(*bits >= #min_lit && *bits <= #max_lit)
763  } else {
764    // not contiguous range, check for each
765    let variant_discriminant_lits =
766      VariantDiscriminantIterator::new(variants.iter())
767        .map(|res| {
768          let (discriminant, _variant) = res?;
769          Ok(LitInt::new(&format!("{}", discriminant), span))
770        })
771        .collect::<Result<Vec<_>>>()?;
772
773    // count is at least 1
774    let first = &variant_discriminant_lits[0];
775    let rest = &variant_discriminant_lits[1..];
776
777    quote!(matches!(*bits, #first #(| #rest )*))
778  };
779
780  let (integer, defs) = get_enum_discriminant(input, crate_name)?;
781  Ok((
782    quote!(#defs),
783    quote! {
784        type Bits = #integer;
785
786        #[inline]
787        #[allow(clippy::double_comparisons)]
788        fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
789            #check
790        }
791    },
792  ))
793}
794
795fn generate_checked_bit_pattern_enum_with_fields(
796  input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
797  crate_name: &TokenStream,
798) -> Result<(TokenStream, TokenStream)> {
799  let representation = get_repr(&input.attrs)?;
800  let vis = &input.vis;
801
802  match representation.repr {
803    Repr::Rust => unreachable!(),
804    Repr::C | Repr::CWithDiscriminant(_) => {
805      let (integer, defs) = get_enum_discriminant(input, crate_name)?;
806      let input_ident = &input.ident;
807
808      let bits_repr = Representation { repr: Repr::C, ..representation };
809
810      // the enum manually re-configured as the actual tagged union it
811      // represents, thus circumventing the requirements rust imposes on
812      // the tag even when using #[repr(C)] enum layout
813      // see: https://doc.rust-lang.org/reference/type-layout.html#reprc-enums-with-fields
814      let bits_ty_ident =
815        Ident::new(&format!("{input_ident}Bits"), input.span());
816
817      // the variants union part of the tagged union. These get put into a union
818      // which gets the AnyBitPattern derive applied to it, thus checking
819      // that the fields of the union obey the requriements of AnyBitPattern.
820      // The types that actually go in the union are one more level of
821      // indirection deep: we generate new structs for each variant
822      // (`variant_struct_definitions`) which themselves have the
823      // `CheckedBitPattern` derive applied, thus generating
824      // `{variant_struct_ident}Bits` structs, which are the ones that go
825      // into this union.
826      let variants_union_ident =
827        Ident::new(&format!("{}Variants", input.ident), input.span());
828
829      let variant_struct_idents = variants.iter().map(|v| {
830        Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span())
831      });
832
833      let variant_struct_definitions =
834        variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
835          let fields = v.fields.iter().map(|v| &v.ty);
836
837          quote! {
838            #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
839            #[repr(C)]
840            #vis struct #variant_struct_ident(#(#fields),*);
841          }
842        });
843
844      let union_fields = variant_struct_idents
845        .clone()
846        .zip(variants.iter())
847        .map(|(variant_struct_ident, v)| {
848          let variant_struct_bits_ident =
849            Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
850          let field_ident = &v.ident;
851          quote! {
852            #field_ident: #variant_struct_bits_ident
853          }
854        });
855
856      let variant_checks = variant_struct_idents
857        .clone()
858        .zip(VariantDiscriminantIterator::new(variants.iter()))
859        .zip(variants.iter())
860        .map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
861          let (discriminant, _variant) = discriminant?;
862          let discriminant = LitInt::new(&discriminant.to_string(), v.span());
863          let ident = &v.ident;
864          Ok(quote! {
865            #discriminant => {
866              let payload = unsafe { &bits.payload.#ident };
867              <#variant_struct_ident as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(payload)
868            }
869          })
870        })
871        .collect::<Result<Vec<_>>>()?;
872
873      Ok((
874        quote! {
875          #defs
876
877          #[doc = #GENERATED_TYPE_DOCUMENTATION]
878          #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
879          #bits_repr
880          #vis struct #bits_ty_ident {
881            tag: #integer,
882            payload: #variants_union_ident,
883          }
884
885          #[allow(unexpected_cfgs)]
886          const _: () = {
887            #[cfg(not(target_arch = "spirv"))]
888            impl ::core::fmt::Debug for #bits_ty_ident {
889              fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
890                let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty_ident));
891                ::core::fmt::DebugStruct::field(&mut debug_struct, "tag", &self.tag);
892                ::core::fmt::DebugStruct::field(&mut debug_struct, "payload", &self.payload);
893                ::core::fmt::DebugStruct::finish(&mut debug_struct)
894              }
895            }
896          };
897
898          #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
899          #[repr(C)]
900          #[allow(non_snake_case)]
901          #vis union #variants_union_ident {
902            #(#union_fields,)*
903          }
904
905          #[allow(unexpected_cfgs)]
906          const _: () = {
907            #[cfg(not(target_arch = "spirv"))]
908            impl ::core::fmt::Debug for #variants_union_ident {
909              fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
910                let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#variants_union_ident));
911                ::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
912              }
913            }
914          };
915
916          #(#variant_struct_definitions)*
917        },
918        quote! {
919          type Bits = #bits_ty_ident;
920
921          #[inline]
922          #[allow(clippy::double_comparisons)]
923          fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
924            match bits.tag {
925              #(#variant_checks)*
926              _ => false,
927            }
928          }
929        },
930      ))
931    }
932    Repr::Transparent => {
933      if variants.len() != 1 {
934        bail!("enums with more than one variant cannot be transparent")
935      }
936
937      let variant = &variants[0];
938
939      let bits_ty = Ident::new(&format!("{}Bits", input.ident), input.span());
940      let fields = variant.fields.iter().map(|v| &v.ty);
941
942      Ok((
943        quote! {
944          #[doc = #GENERATED_TYPE_DOCUMENTATION]
945          #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
946          #[repr(C)]
947          #vis struct #bits_ty(#(#fields),*);
948        },
949        quote! {
950          type Bits = <#bits_ty as #crate_name::CheckedBitPattern>::Bits;
951
952          #[inline]
953          #[allow(clippy::double_comparisons)]
954          fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
955            <#bits_ty as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(bits)
956          }
957        },
958      ))
959    }
960    Repr::Integer(integer) => {
961      let bits_repr = Representation { repr: Repr::C, ..representation };
962      let input_ident = &input.ident;
963
964      // the enum manually re-configured as the union it represents. such a
965      // union is the union of variants as a repr(c) struct with the
966      // discriminator type inserted at the beginning. in our case we
967      // union the `Bits` representation of each variant rather than the variant
968      // itself, which we generate via a nested `CheckedBitPattern` derive
969      // on the `variant_struct_definitions` generated below.
970      //
971      // see: https://doc.rust-lang.org/reference/type-layout.html#primitive-representation-of-enums-with-fields
972      let bits_ty_ident =
973        Ident::new(&format!("{input_ident}Bits"), input.span());
974
975      let variant_struct_idents = variants.iter().map(|v| {
976        Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span())
977      });
978
979      let variant_struct_definitions =
980        variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
981          let fields = v.fields.iter().map(|v| &v.ty);
982
983          // adding the discriminant repr integer as first field, as described above
984          quote! {
985            #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
986            #[repr(C)]
987            #vis struct #variant_struct_ident(#integer, #(#fields),*);
988          }
989        });
990
991      let union_fields = variant_struct_idents
992        .clone()
993        .zip(variants.iter())
994        .map(|(variant_struct_ident, v)| {
995          let variant_struct_bits_ident =
996            Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
997          let field_ident = &v.ident;
998          quote! {
999            #field_ident: #variant_struct_bits_ident
1000          }
1001        });
1002
1003      let variant_checks = variant_struct_idents
1004        .clone()
1005        .zip(VariantDiscriminantIterator::new(variants.iter()))
1006        .zip(variants.iter())
1007        .map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
1008          let (discriminant, _variant) = discriminant?;
1009          let discriminant = LitInt::new(&discriminant.to_string(), v.span());
1010          let ident = &v.ident;
1011          Ok(quote! {
1012            #discriminant => {
1013              let payload = unsafe { &bits.#ident };
1014              <#variant_struct_ident as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(payload)
1015            }
1016          })
1017        })
1018        .collect::<Result<Vec<_>>>()?;
1019
1020      Ok((
1021        quote! {
1022          #[doc = #GENERATED_TYPE_DOCUMENTATION]
1023          #[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
1024          #bits_repr
1025          #[allow(non_snake_case)]
1026          #vis union #bits_ty_ident {
1027            __tag: #integer,
1028            #(#union_fields,)*
1029          }
1030
1031          #[allow(unexpected_cfgs)]
1032          const _: () = {
1033            #[cfg(not(target_arch = "spirv"))]
1034            impl ::core::fmt::Debug for #bits_ty_ident {
1035              fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
1036                let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty_ident));
1037                ::core::fmt::DebugStruct::field(&mut debug_struct, "tag", unsafe { &self.__tag });
1038                ::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
1039              }
1040            }
1041          };
1042
1043          #(#variant_struct_definitions)*
1044        },
1045        quote! {
1046          type Bits = #bits_ty_ident;
1047
1048          #[inline]
1049          #[allow(clippy::double_comparisons)]
1050          fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
1051            match unsafe { bits.__tag } {
1052              #(#variant_checks)*
1053              _ => false,
1054            }
1055          }
1056        },
1057      ))
1058    }
1059  }
1060}
1061
1062/// Check that a struct or enum has no padding by asserting that the size of
1063/// the type is equal to the sum of the size of it's fields and discriminant
1064/// (for enums, this must be asserted for each variant).
1065fn generate_assert_no_padding(
1066  input: &DeriveInput, enum_variant: Option<&Variant>, for_trait: &str,
1067) -> Result<TokenStream> {
1068  let struct_type = &input.ident;
1069  let fields = get_fields(input, enum_variant)?;
1070
1071  // If the type is an enum, determine the type of its discriminant.
1072  let enum_discriminant = if matches!(input.data, Data::Enum(_)) {
1073    let ident =
1074      Ident::new(&format!("{}Discriminant", input.ident), input.ident.span());
1075    Some(ident.into_token_stream())
1076  } else {
1077    None
1078  };
1079
1080  // Prepend the type of the discriminant to the types of the fields.
1081  let mut field_types = enum_discriminant
1082    .into_iter()
1083    .chain(get_field_types(&fields).map(ToTokens::to_token_stream));
1084  let size_sum = if let Some(first) = field_types.next() {
1085    let size_first = quote!(::core::mem::size_of::<#first>());
1086    let size_rest = quote!(#( + ::core::mem::size_of::<#field_types>() )*);
1087
1088    quote!(#size_first #size_rest)
1089  } else {
1090    quote!(0)
1091  };
1092
1093  let message =
1094    format!("derive({for_trait}) was applied to a type with padding");
1095
1096  Ok(quote! {const _: () = {
1097    assert!(
1098        ::core::mem::size_of::<#struct_type>() == (#size_sum),
1099        #message,
1100    );
1101  };})
1102}
1103
1104/// Check that all fields implement a given trait
1105fn generate_fields_are_trait(
1106  input: &DeriveInput, enum_variant: Option<&Variant>, trait_: syn::Path,
1107) -> Result<TokenStream> {
1108  let (impl_generics, _ty_generics, where_clause) =
1109    input.generics.split_for_impl();
1110  let fields = get_fields(input, enum_variant)?;
1111  let field_types = get_field_types(&fields);
1112  Ok(quote! {#(const _: fn() = || {
1113      #[allow(clippy::missing_const_for_fn)]
1114      #[doc(hidden)]
1115      fn check #impl_generics () #where_clause {
1116        fn assert_impl<T: #trait_>() {}
1117        assert_impl::<#field_types>();
1118      }
1119    };)*
1120  })
1121}
1122
1123/// Get the type of an enum's discriminant.
1124///
1125/// For `repr(int)` and `repr(C, int)` enums, this will return the known bare
1126/// integer type specified.
1127///
1128/// For `repr(C)` enums, this will extract the underlying size chosen by rustc.
1129/// It will return a token stream which is a type expression that evaluates to
1130/// a primitive integer type of this size, using our `EnumTagIntegerBytes`
1131/// trait.
1132///
1133/// For fieldless `repr(C)` enums, we can feed the size of the enum directly
1134/// into the trait.
1135///
1136/// For `repr(C)` enums with fields, we generate a new fieldless `repr(C)` enum
1137/// with the same variants, then use that in the calculation. This is the
1138/// specified behavior, see https://doc.rust-lang.org/stable/reference/type-layout.html#reprc-enums-with-fields
1139///
1140/// Returns a tuple of (type ident, auxiliary definitions)
1141fn get_enum_discriminant(
1142  input: &DeriveInput, crate_name: &TokenStream,
1143) -> Result<(TokenStream, TokenStream)> {
1144  let repr = get_repr(&input.attrs)?;
1145  match repr.repr {
1146    Repr::C => {
1147      let e = if let Data::Enum(e) = &input.data { e } else { unreachable!() };
1148      if enum_has_fields(e.variants.iter()) {
1149        // If the enum has fields, we must first isolate the discriminant by
1150        // removing all the fields.
1151        let enum_discriminant = generate_enum_discriminant(input)?;
1152        let discriminant_ident = Ident::new(
1153          &format!("{}Discriminant", input.ident),
1154          input.ident.span(),
1155        );
1156        Ok((
1157          quote!(<[::core::primitive::u8; ::core::mem::size_of::<#discriminant_ident>()] as #crate_name::derive::EnumTagIntegerBytes>::Integer),
1158          quote! {
1159            #enum_discriminant
1160          },
1161        ))
1162      } else {
1163        // If the enum doesn't have fields, we can just use it directly.
1164        let ident = &input.ident;
1165        Ok((
1166          quote!(<[::core::primitive::u8; ::core::mem::size_of::<#ident>()] as #crate_name::derive::EnumTagIntegerBytes>::Integer),
1167          quote!(),
1168        ))
1169      }
1170    }
1171    Repr::Integer(integer) | Repr::CWithDiscriminant(integer) => {
1172      Ok((quote!(#integer), quote!()))
1173    }
1174    _ => unreachable!(),
1175  }
1176}
1177
1178fn generate_enum_discriminant(input: &DeriveInput) -> Result<TokenStream> {
1179  let e = if let Data::Enum(e) = &input.data { e } else { unreachable!() };
1180  let repr = get_repr(&input.attrs)?;
1181  let repr = match repr.repr {
1182    Repr::C => quote!(#[repr(C)]),
1183    Repr::Integer(int) | Repr::CWithDiscriminant(int) => quote!(#[repr(#int)]),
1184    Repr::Rust | Repr::Transparent => unreachable!(),
1185  };
1186  let ident =
1187    Ident::new(&format!("{}Discriminant", input.ident), input.ident.span());
1188  let variants = e.variants.iter().cloned().map(|mut e| {
1189    e.fields = Fields::Unit;
1190    e
1191  });
1192  Ok(quote! {
1193    #repr
1194    #[allow(dead_code)]
1195    enum #ident {
1196      #(#variants,)*
1197    }
1198  })
1199}
1200
1201fn get_wrapped_type_from_stream(tokens: TokenStream) -> Option<syn::Type> {
1202  let mut tokens = tokens.into_iter().peekable();
1203  match tokens.peek() {
1204    Some(TokenTree::Group(group)) => {
1205      let res = get_wrapped_type_from_stream(group.stream());
1206      tokens.next(); // remove the peeked token tree
1207      match tokens.next() {
1208        // If there were more tokens, the input was invalid
1209        Some(_) => None,
1210        None => res,
1211      }
1212    }
1213    _ => syn::parse2(tokens.collect()).ok(),
1214  }
1215}
1216
1217/// get a simple `#[foo(bar)]` attribute, returning `bar`
1218fn get_type_from_simple_attr(
1219  attributes: &[Attribute], attr_name: &str,
1220) -> Option<syn::Type> {
1221  for attr in attributes {
1222    if let (AttrStyle::Outer, Meta::List(list)) = (&attr.style, &attr.meta) {
1223      if list.path.is_ident(attr_name) {
1224        if let Some(ty) = get_wrapped_type_from_stream(list.tokens.clone()) {
1225          return Some(ty);
1226        }
1227      }
1228    }
1229  }
1230
1231  None
1232}
1233
1234fn get_repr(attributes: &[Attribute]) -> Result<Representation> {
1235  attributes
1236    .iter()
1237    .filter_map(|attr| {
1238      if attr.path().is_ident("repr") {
1239        Some(attr.parse_args::<Representation>())
1240      } else {
1241        None
1242      }
1243    })
1244    .try_fold(Representation::default(), |a, b| {
1245      let b = b?;
1246      Ok(Representation {
1247        repr: match (a.repr, b.repr) {
1248          (a, Repr::Rust) => a,
1249          (Repr::Rust, b) => b,
1250          _ => bail!("conflicting representation hints"),
1251        },
1252        packed: match (a.packed, b.packed) {
1253          (a, None) => a,
1254          (None, b) => b,
1255          _ => bail!("conflicting representation hints"),
1256        },
1257        align: match (a.align, b.align) {
1258          (Some(a), Some(b)) => Some(cmp::max(a, b)),
1259          (a, None) => a,
1260          (None, b) => b,
1261        },
1262      })
1263    })
1264}
1265
1266mk_repr! {
1267  U8 => u8,
1268  I8 => i8,
1269  U16 => u16,
1270  I16 => i16,
1271  U32 => u32,
1272  I32 => i32,
1273  U64 => u64,
1274  I64 => i64,
1275  I128 => i128,
1276  U128 => u128,
1277  Usize => usize,
1278  Isize => isize,
1279}
1280// where
1281macro_rules! mk_repr {(
1282  $(
1283    $Xn:ident => $xn:ident
1284  ),* $(,)?
1285) => (
1286  #[derive(Debug, Clone, Copy, PartialEq, Eq)]
1287  enum IntegerRepr {
1288    $($Xn),*
1289  }
1290
1291  impl<'a> TryFrom<&'a str> for IntegerRepr {
1292    type Error = &'a str;
1293
1294    fn try_from(value: &'a str) -> std::result::Result<Self, &'a str> {
1295      match value {
1296        $(
1297          stringify!($xn) => Ok(Self::$Xn),
1298        )*
1299        _ => Err(value),
1300      }
1301    }
1302  }
1303
1304  impl ToTokens for IntegerRepr {
1305    fn to_tokens(&self, tokens: &mut TokenStream) {
1306      match self {
1307        $(
1308          Self::$Xn => tokens.extend(quote!($xn)),
1309        )*
1310      }
1311    }
1312  }
1313)}
1314use mk_repr;
1315
1316#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1317enum Repr {
1318  Rust,
1319  C,
1320  Transparent,
1321  Integer(IntegerRepr),
1322  CWithDiscriminant(IntegerRepr),
1323}
1324
1325impl Repr {
1326  fn as_integer(&self) -> Option<IntegerRepr> {
1327    if let Self::Integer(v) = self {
1328      Some(*v)
1329    } else {
1330      None
1331    }
1332  }
1333}
1334
1335#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1336struct Representation {
1337  packed: Option<u32>,
1338  align: Option<u32>,
1339  repr: Repr,
1340}
1341
1342impl Default for Representation {
1343  fn default() -> Self {
1344    Self { packed: None, align: None, repr: Repr::Rust }
1345  }
1346}
1347
1348impl Parse for Representation {
1349  fn parse(input: ParseStream<'_>) -> Result<Representation> {
1350    let mut ret = Representation::default();
1351    while !input.is_empty() {
1352      let keyword = input.parse::<Ident>()?;
1353      // preƫmptively call `.to_string()` *once* (rather than on `is_ident()`)
1354      let keyword_str = keyword.to_string();
1355      let new_repr = match keyword_str.as_str() {
1356        "C" => Repr::C,
1357        "transparent" => Repr::Transparent,
1358        "packed" => {
1359          ret.packed = Some(if input.peek(token::Paren) {
1360            let contents;
1361            parenthesized!(contents in input);
1362            LitInt::base10_parse::<u32>(&contents.parse()?)?
1363          } else {
1364            1
1365          });
1366          let _: Option<Token![,]> = input.parse()?;
1367          continue;
1368        }
1369        "align" => {
1370          let contents;
1371          parenthesized!(contents in input);
1372          let new_align = LitInt::base10_parse::<u32>(&contents.parse()?)?;
1373          ret.align = Some(
1374            ret
1375              .align
1376              .map_or(new_align, |old_align| cmp::max(old_align, new_align)),
1377          );
1378          let _: Option<Token![,]> = input.parse()?;
1379          continue;
1380        }
1381        ident => {
1382          let primitive = IntegerRepr::try_from(ident)
1383            .map_err(|_| input.error("unrecognized representation hint"))?;
1384          Repr::Integer(primitive)
1385        }
1386      };
1387      ret.repr = match (ret.repr, new_repr) {
1388        (Repr::Rust, new_repr) => {
1389          // This is the first explicit repr.
1390          new_repr
1391        }
1392        (Repr::C, Repr::Integer(integer))
1393        | (Repr::Integer(integer), Repr::C) => {
1394          // Both the C repr and an integer repr have been specified
1395          // -> merge into a C wit discriminant.
1396          Repr::CWithDiscriminant(integer)
1397        }
1398        (_, _) => {
1399          return Err(input.error("duplicate representation hint"));
1400        }
1401      };
1402      let _: Option<Token![,]> = input.parse()?;
1403    }
1404    Ok(ret)
1405  }
1406}
1407
1408impl ToTokens for Representation {
1409  fn to_tokens(&self, tokens: &mut TokenStream) {
1410    let mut meta = Punctuated::<_, Token![,]>::new();
1411
1412    match self.repr {
1413      Repr::Rust => {}
1414      Repr::C => meta.push(quote!(C)),
1415      Repr::Transparent => meta.push(quote!(transparent)),
1416      Repr::Integer(primitive) => meta.push(quote!(#primitive)),
1417      Repr::CWithDiscriminant(primitive) => {
1418        meta.push(quote!(C));
1419        meta.push(quote!(#primitive));
1420      }
1421    }
1422
1423    if let Some(packed) = self.packed.as_ref() {
1424      let lit = LitInt::new(&packed.to_string(), Span::call_site());
1425      meta.push(quote!(packed(#lit)));
1426    }
1427
1428    if let Some(align) = self.align.as_ref() {
1429      let lit = LitInt::new(&align.to_string(), Span::call_site());
1430      meta.push(quote!(align(#lit)));
1431    }
1432
1433    tokens.extend(quote!(
1434      #[repr(#meta)]
1435    ));
1436  }
1437}
1438
1439fn enum_has_fields<'a>(
1440  mut variants: impl Iterator<Item = &'a Variant>,
1441) -> bool {
1442  variants.any(|v| matches!(v.fields, Fields::Named(_) | Fields::Unnamed(_)))
1443}
1444
1445struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> {
1446  inner: I,
1447  last_value: i128,
1448}
1449
1450impl<'a, I: Iterator<Item = &'a Variant> + 'a>
1451  VariantDiscriminantIterator<'a, I>
1452{
1453  fn new(inner: I) -> Self {
1454    VariantDiscriminantIterator { inner, last_value: -1 }
1455  }
1456}
1457
1458impl<'a, I: Iterator<Item = &'a Variant> + 'a> Iterator
1459  for VariantDiscriminantIterator<'a, I>
1460{
1461  type Item = Result<(i128, &'a Variant)>;
1462
1463  fn next(&mut self) -> Option<Self::Item> {
1464    let variant = self.inner.next()?;
1465
1466    if let Some((_, discriminant)) = &variant.discriminant {
1467      let discriminant_value = match parse_int_expr(discriminant) {
1468        Ok(value) => value,
1469        Err(e) => return Some(Err(e)),
1470      };
1471      self.last_value = discriminant_value;
1472    } else {
1473      // If this wraps, then either:
1474      // 1. the enum is using repr(u128), so wrapping is correct
1475      // 2. the enum is using repr(i<=128 or u<128), so the compiler will
1476      //    already emit a "wrapping discriminant" E0370 error.
1477      self.last_value = self.last_value.wrapping_add(1);
1478      // Static assert that there is no integer repr > 128 bits. If that
1479      // changes, the above comment is inaccurate and needs to be updated!
1480      // FIXME(zachs18): maybe should also do something to ensure `isize::BITS
1481      // <= 128`?
1482      if let Some(repr) = None::<IntegerRepr> {
1483        match repr {
1484          IntegerRepr::U8
1485          | IntegerRepr::I8
1486          | IntegerRepr::U16
1487          | IntegerRepr::I16
1488          | IntegerRepr::U32
1489          | IntegerRepr::I32
1490          | IntegerRepr::U64
1491          | IntegerRepr::I64
1492          | IntegerRepr::I128
1493          | IntegerRepr::U128
1494          | IntegerRepr::Usize
1495          | IntegerRepr::Isize => (),
1496        }
1497      }
1498    }
1499
1500    Some(Ok((self.last_value, variant)))
1501  }
1502}
1503
1504fn parse_int_expr(expr: &Expr) -> Result<i128> {
1505  match expr {
1506    Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr, .. }) => {
1507      parse_int_expr(expr).map(|int| -int)
1508    }
1509    Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => int.base10_parse(),
1510    Expr::Lit(ExprLit { lit: Lit::Byte(byte), .. }) => Ok(byte.value().into()),
1511    _ => bail!("Not an integer expression"),
1512  }
1513}
1514
#[cfg(test)]
mod tests {
  use syn::{parse_quote, Attribute};

  use super::{get_repr, IntegerRepr, Repr, Representation};

  /// Runs `get_repr` over `attrs` and checks the merged representation.
  fn assert_repr(attrs: &[Attribute], expected: Representation) {
    assert_eq!(get_repr(attrs).unwrap(), expected);
  }

  #[test]
  fn parse_basic_repr() {
    // Layout hints.
    assert_repr(
      &[parse_quote!(#[repr(C)])],
      Representation { repr: Repr::C, ..Default::default() },
    );
    assert_repr(
      &[parse_quote!(#[repr(transparent)])],
      Representation { repr: Repr::Transparent, ..Default::default() },
    );
    assert_repr(
      &[parse_quote!(#[repr(u8)])],
      Representation {
        repr: Repr::Integer(IntegerRepr::U8),
        ..Default::default()
      },
    );

    // Packing, with and without an explicit argument.
    assert_repr(
      &[parse_quote!(#[repr(packed)])],
      Representation { packed: Some(1), ..Default::default() },
    );
    assert_repr(
      &[parse_quote!(#[repr(packed(1))])],
      Representation { packed: Some(1), ..Default::default() },
    );
    assert_repr(
      &[parse_quote!(#[repr(packed(2))])],
      Representation { packed: Some(2), ..Default::default() },
    );

    // Alignment.
    assert_repr(
      &[parse_quote!(#[repr(align(2))])],
      Representation { align: Some(2), ..Default::default() },
    );
  }

  #[test]
  fn parse_advanced_repr() {
    // The largest alignment wins, within one attribute...
    assert_repr(
      &[parse_quote!(#[repr(align(4), align(2))])],
      Representation { align: Some(4), ..Default::default() },
    );

    // ...and across several attributes.
    assert_repr(
      &[
        parse_quote!(#[repr(align(1))]),
        parse_quote!(#[repr(align(4))]),
        parse_quote!(#[repr(align(2))]),
      ],
      Representation { align: Some(4), ..Default::default() },
    );

    // `C` and an integer type combine regardless of their order.
    let c_with_u8 = Representation {
      repr: Repr::CWithDiscriminant(IntegerRepr::U8),
      ..Default::default()
    };
    assert_repr(&[parse_quote!(#[repr(C, u8)])], c_with_u8);
    assert_repr(&[parse_quote!(#[repr(u8, C)])], c_with_u8);
  }
}
1594
1595pub fn bytemuck_crate_name(input: &DeriveInput) -> TokenStream {
1596  const ATTR_NAME: &'static str = "crate";
1597
1598  let mut crate_name = quote!(::bytemuck);
1599  for attr in &input.attrs {
1600    if !attr.path().is_ident("bytemuck") {
1601      continue;
1602    }
1603
1604    attr.parse_nested_meta(|meta| {
1605      if meta.path.is_ident(ATTR_NAME) {
1606        let expr: syn::Expr = meta.value()?.parse()?;
1607        let mut value = &expr;
1608        while let syn::Expr::Group(e) = value {
1609          value = &e.expr;
1610        }
1611        if let syn::Expr::Lit(syn::ExprLit {
1612          lit: syn::Lit::Str(lit), ..
1613        }) = value
1614        {
1615          let suffix = lit.suffix();
1616          if !suffix.is_empty() {
1617            bail!(format!("Unexpected suffix `{}` on string literal", suffix))
1618          }
1619          let path: syn::Path = match lit.parse() {
1620            Ok(path) => path,
1621            Err(_) => {
1622              bail!(format!("Failed to parse path: {:?}", lit.value()))
1623            }
1624          };
1625          crate_name = path.into_token_stream();
1626        } else {
1627          bail!(
1628            "Expected bytemuck `crate` attribute to be a string: `crate = \"...\"`",
1629          )
1630        }
1631      }
1632      Ok(())
1633    }).unwrap();
1634  }
1635
1636  return crate_name;
1637}
1638
/// Doc text attached, via `#[doc = ...]`, to the helper types that the derives
/// generate next to the user's type.
const GENERATED_TYPE_DOCUMENTATION: &str =
  " `bytemuck`-generated type for internal purposes only.";