使用枚举和结构体，通过链式调用生成字符串

问题描述 投票:0回答:1

我正在尝试编写过程宏（proc-macro），目标是：对于一个字段为枚举类型的结构体，可以通过链式调用方法来设置各个字段。

如下面的测试所示，期望得到的结果是：

permission::auth::moderator::execute::conditional

main.rs

use ::enum_display::RbacBuilder;

/// Builder that collects one value of each RBAC enum and renders them as a path.
///
/// `#[derive(RbacBuilder)]` generates `new()`, one chainable setter per enum
/// variant (e.g. `.auth()`, `.moderator()`), and `finish()`.
/// NOTE(review): the `#[enum_def(...)]` helper attributes are declared by the
/// derive but never read by it; the enum is taken from the field type instead.
#[derive(Debug, RbacBuilder)]
pub struct PermissionBuilder {
    // Rendered as the first segment after `permission::` by `finish()`.
    #[enum_def(Object)]
    object: Option<Object>,
    // Rendered as the second segment.
    #[enum_def(Subject)]
    subject: Option<Subject>,
    // Rendered as the third segment.
    #[enum_def(Action)]
    action: Option<Action>,
    // Rendered as the last segment.
    #[enum_def(Permission)]
    permission: Option<Permission>,
}

/// Subject values stored in `PermissionBuilder::subject`.
///
/// The `RbacBuilder` derive cannot read these variants; they must be kept in
/// sync by hand with the table in `get_enum_variants`.
#[derive(Debug)]
pub enum Subject {
    User,
    Admin,
    Guest,
    Moderator,
}

/// Object values stored in `PermissionBuilder::object`.
///
/// Keep in sync with the `"Object"` entry of `get_enum_variants`.
#[derive(Debug)]
pub enum Object {
    Auth,
    Role,
    Resource,
    Policy,
}

/// Action values stored in `PermissionBuilder::action`.
///
/// Keep in sync with the `"Action"` entry of `get_enum_variants`.
#[derive(Debug)]
pub enum Action {
    Create,
    Read,
    Update,
    Delete,
    List,
    Execute,
}

/// Permission values stored in `PermissionBuilder::permission`.
///
/// Keep in sync with the `"Permission"` entry of `get_enum_variants`.
#[derive(Debug)]
pub enum Permission {
    Allow,
    Deny,
    Conditional,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Chains one generated setter per field and checks the rendered path.
    #[test]
    fn test_permission() {
        let result = PermissionBuilder::new()
            .auth() // sets object = Some(Object::Auth)
            .moderator() // sets subject = Some(Subject::Moderator)
            .execute() // sets action = Some(Action::Execute)
            .conditional() // sets permission = Some(Permission::Conditional)
            .finish();

        // Assert the exact output so a regression fails the test instead of
        // only being printed.
        assert_eq!(result, "permission::auth::moderator::execute::conditional");
    }
}

我的过程宏库（`enum_display/src/lib.rs`）如下：

extern crate proc_macro;
extern crate proc_macro2;
extern crate quote;
extern crate syn;

use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};

#[proc_macro_derive(RbacBuilder, attributes(enum_def))]
pub fn rbac_builder_derive(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let name = input.ident;

    let mut methods = vec![];

    if let Data::Struct(data_struct) = &input.data {
        if let Fields::Named(fields_named) = &data_struct.fields {
            for field in fields_named.named.iter() {
                if let Some(ident) = &field.ident {
                    let ty = &field.ty;

                    // Extract the enum name and its variants
                    let enum_name = match ty {
                        Type::Path(type_path) => {
                            if let Some(segment) = type_path.path.segments.last() {
                                // Check if it's an Option<T>
                                if segment.ident == "Option" {
                                    if let syn::PathArguments::AngleBracketed(args) =
                                        &segment.arguments
                                    {
                                        if let Some(syn::GenericArgument::Type(Type::Path(
                                            inner_type_path,
                                        ))) = args.args.first()
                                        {
                                            if let Some(inner_segment) =
                                                inner_type_path.path.segments.last()
                                            {
                                                Some(inner_segment.ident.clone())
                                            } else {
                                                None
                                            }
                                        } else {
                                            None
                                        }
                                    } else {
                                        None
                                    }
                                } else {
                                    Some(segment.ident.clone())
                                }
                            } else {
                                None
                            }
                        }
                        _ => None,
                    };

                    if let Some(enum_name) = enum_name {
                        eprintln!("Detected enum: {:?}", enum_name);

                        // Directly reference the enums instead of finding them in the AST
                        let enum_variants = get_enum_variants(&enum_name.to_string());

                        if let Some(variants) = enum_variants {
                            for variant in variants {
                                let method_name = format_ident!("{}", variant.to_lowercase());
                                let variant_ident = format_ident!("{}", variant);

                                methods.push(quote! {
                                    pub fn #method_name(mut self) -> Self {
                                        self.#ident = Some(#enum_name::#variant_ident);
                                        self
                                    }
                                });
                            }
                        } else {
                            eprintln!("Enum definition not found for {:?}", enum_name);
                        }
                    } else {
                        eprintln!("Field type is not a recognized path: {:?}", ty);
                    }
                } else {
                    eprintln!("Field without an identifier");
                }
            }
        } else {
            eprintln!("Struct does not have named fields");
        }
    } else {
        eprintln!("Not a struct: {:?}", input.data);
    }

    let expanded = quote! {
        impl #name {
            pub fn new() -> Self {
                Self {
                    object: None,
                    subject: None,
                    action: None,
                    permission: None,
                }
            }

            #(#methods)*

            pub fn finish(self) -> String {
                format!(
                    "permission::{:?}::{:?}::{:?}::{:?}",
                    self.object.unwrap(),
                    self.subject.unwrap(),
                    self.action.unwrap(),
                    self.permission.unwrap()
                ).to_lowercase().replace("::", "::")
            }
        }
    };

    TokenStream::from(expanded)
}

/// Looks up the variant names of one of the known RBAC enums.
///
/// A derive macro cannot inspect other items in the user's crate, so this
/// table has to be kept in sync with the enum definitions by hand.
/// Returns `None` for any enum name that is not listed.
fn get_enum_variants(enum_name: &str) -> Option<Vec<&'static str>> {
    const KNOWN_ENUMS: &[(&str, &[&str])] = &[
        ("Subject", &["User", "Admin", "Guest", "Moderator"]),
        ("Object", &["Auth", "Role", "Resource", "Policy"]),
        ("Action", &["Create", "Read", "Update", "Delete", "List", "Execute"]),
        ("Permission", &["Allow", "Deny", "Conditional"]),
    ];

    KNOWN_ENUMS
        .iter()
        .find(|(known, _)| *known == enum_name)
        .map(|(_, variants)| variants.to_vec())
}

问题在于：我找不到自动获取枚举变体的方法，只能在库中把它们硬编码。

如何去掉 `lib.rs` 中的这些硬编码？

rust enums quote rust-proc-macros syn
1个回答
0
投票

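经过反复试错，我按照 @Jmb 在评论中的建议修改，效果非常好。思路是：派生宏只能看到它所标注的那个条目本身的 token，无法从结构体上读取其他枚举的定义；因此改为在每个枚举上派生宏，由宏读取该枚举自己的变体，并为 `PermissionBuilder` 生成对应的方法：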

// lib.rs

extern crate proc_macro;
extern crate proc_macro2;
extern crate quote;
extern crate syn;

use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::quote;
use quote::ToTokens;
use syn::{parse_macro_input, Attribute, DeriveInput, Ident};

#[proc_macro_derive(RbacEnum, attributes(enum_field))]
pub fn rbac_enum_derive(input: TokenStream) -> TokenStream {
  let input = parse_macro_input!(input as DeriveInput);
  let name = &input.ident;

  // Extract the enum variants
  let data = match input.data {
    syn::Data::Enum(data) => data,
    _ => {
      return TokenStream::from(quote! {
          compile_error!("#[derive(RbacEnum)] can only be used with enums");
      })
    }
  };

  let field_name = match extract_field_name(&input.attrs) {
    Ok(field) => field,
    Err(err) => {
      return TokenStream::from(quote! {
          compile_error!(#err);
      })
    }
  };

  let mut methods = Vec::new();
  for variant in data.variants.iter() {
    let variant_name = &variant.ident;
    let method_name = Ident::new(
      &variant_name.to_string().to_lowercase(),
      variant_name.span(),
    );

    methods.push(quote! {
        impl #name {
            pub fn #method_name() -> Self {
                #name::#variant_name
            }
        }

        impl PermissionBuilder {
            pub fn #method_name(mut self) -> Self {
                self.#field_name = Some(#name::#variant_name);
                self
            }
        }
    });
  }

  let expanded = quote! {
      #(#methods)*
  };

  TokenStream::from(expanded)
}

fn extract_field_name(attrs: &[Attribute]) -> Result<Ident, String> {
  for attr in attrs {
    if attr.path.is_ident("enum_field") {
      if let Ok(meta) = attr.parse_meta() {
        if let syn::Meta::List(list) = meta {
          if let Some(syn::NestedMeta::Meta(syn::Meta::Path(path))) =
            list.nested.first()
          {
            let path_string = path.to_token_stream().to_string();
            let parts: Vec<&str> = path_string.split("::").collect();
            if parts.len() == 2 {
              let field_name = parts[1].trim();
              if !field_name.is_empty() {
                return Ok(Ident::new(field_name, Span::call_site()));
              }
            }
          }
        }
      }
    }
  }
  Err("Missing or invalid enum_field attribute".to_string())
}

现在，`lib.rs` 可以直接从枚举定义中读取变体，并根据 `#[enum_field(...)]` 属性生成 impl 扩展方法。

// main.rs

use ::enum_display::RbacEnum;

/// Builder that collects one value of each RBAC enum and renders them as a path.
///
/// The chainable setters (e.g. `.auth()`, `.moderator()`) are not written here.
/// They are generated by `#[derive(RbacEnum)]` on each enum, which targets a
/// field of this struct through its `#[enum_field(...)]` attribute.
#[derive(Debug)]
pub struct PermissionBuilder {
  // Set by the `Object` setters; first segment after `permission::` in `finish()`.
  object: Option<Object>,
  // Set by the `Subject` setters; second segment.
  subject: Option<Subject>,
  // Set by the `Action` setters; third segment.
  action: Option<Action>,
  // Set by the `Permission` setters; last segment.
  permission: Option<Permission>,
}
impl PermissionBuilder {
  pub fn new() -> Self {
    Self {
      object: None,
      subject: None,
      action: None,
      permission: None,
    }
  }

  pub fn finish(self) -> String {
    format!(
      "permission::{:?}::{:?}::{:?}::{:?}",
      self.object.unwrap(),
      self.subject.unwrap(),
      self.action.unwrap(),
      self.permission.unwrap()
    )
    .to_lowercase()
    .replace("::", "::")
  }
}

/// Subject values stored in `PermissionBuilder::subject`.
///
/// `RbacEnum` generates `Subject::user()` etc. and the chainable
/// `PermissionBuilder::user()` etc., one per variant (lowercased name).
#[derive(Debug, RbacEnum)]
#[enum_field(PermissionBuilder::subject)]
pub enum Subject {
  User,
  Admin,
  Guest,
  Moderator,
}

/// Object values stored in `PermissionBuilder::object`.
///
/// `RbacEnum` generates one lowercased setter per variant, e.g. `.auth()`.
#[derive(Debug, RbacEnum)]
#[enum_field(PermissionBuilder::object)]
pub enum Object {
  Auth,
  Role,
  Resource,
  Policy,
}

/// Action values stored in `PermissionBuilder::action`.
///
/// `RbacEnum` generates one lowercased setter per variant, e.g. `.execute()`.
#[derive(Debug, RbacEnum)]
#[enum_field(PermissionBuilder::action)]
pub enum Action {
  Create,
  Read,
  Update,
  Delete,
  List,
  Execute,
}

/// Permission values stored in `PermissionBuilder::permission`.
///
/// `RbacEnum` generates one lowercased setter per variant, e.g. `.conditional()`.
/// Variant names must not collide (after lowercasing) with those of the other
/// enums targeting `PermissionBuilder`, or the generated methods would clash.
#[derive(Debug, RbacEnum)]
#[enum_field(PermissionBuilder::permission)]
pub enum Permission {
  Allow,
  Deny,
  Conditional,
}

#[cfg(test)]
mod tests {
  use super::*;

  /// Chains one derived setter per enum and checks the rendered path.
  #[test]
  fn test_permission() {
    let result = PermissionBuilder::new()
      .auth() // generated by `RbacEnum` on `Object`
      .moderator() // generated by `RbacEnum` on `Subject`
      .execute() // generated by `RbacEnum` on `Action`
      .conditional() // generated by `RbacEnum` on `Permission`
      .finish();

    // Assert the exact output so a regression fails the test instead of
    // only being printed.
    assert_eq!(result, "permission::auth::moderator::execute::conditional");
  }
}

使用时，在枚举上派生 `RbacEnum` 过程宏，并添加 `enum_field` 属性（例如 `#[enum_field(PermissionBuilder::subject)]`）；该属性指明宏要为 `PermissionBuilder` 的哪个字段生成设置方法。

再次感谢@Jmb带路,感谢@cafce25让我对这个实验充满信心!

© www.soinside.com 2019 - 2024. All rights reserved.