Skip to content

Commit

Permalink
Redesign custom_insn_r
Browse files Browse the repository at this point in the history
  • Loading branch information
Golovanov399 committed Jan 8, 2025
1 parent 6fdbf70 commit 76f4a47
Show file tree
Hide file tree
Showing 19 changed files with 437 additions and 212 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ openvm-transpiler = { path = "crates/toolchain/transpiler", default-features = f
openvm-circuit = { path = "crates/vm", default-features = false }
openvm-circuit-derive = { path = "crates/vm/derive", default-features = false }
openvm-toolchain-tests = { path = "crates/toolchain/tests", default-features = false }
openvm-custom-insn = { path = "crates/toolchain/platform/custom_insn", default-features = false }

# Extensions
openvm-rv32im-circuit = { path = "extensions/rv32im/circuit", default-features = false }
Expand Down
1 change: 1 addition & 0 deletions crates/toolchain/platform/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ repository.workspace = true
[dependencies]
stability = "0.2"
strum_macros.workspace = true
openvm-custom-insn.workspace = true

# This crate should have as few dependencies as possible so it can be
# used as many places as possible to share the platform definitions.
Expand Down
12 changes: 12 additions & 0 deletions crates/toolchain/platform/custom_insn/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "openvm-custom-insn"
version = "0.1.0"
edition = "2021"

[dependencies]
syn = { version = "2.0", features = ["full"] }
quote = "1.0"
proc-macro2 = "1.0"

[lib]
proc-macro = true
198 changes: 198 additions & 0 deletions crates/toolchain/platform/custom_insn/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
use proc_macro2::{Span, TokenStream};
use syn::{
parse::{Parse, ParseStream},
Ident, Token,
};

enum AsmArg {
In(TokenStream),
Out(TokenStream),
InOut(TokenStream),
ConstExpr(TokenStream),
ConstLit(syn::LitStr),
}

struct CustomInsnR {
pub rd: Option<AsmArg>,
pub rs1: Option<AsmArg>,
pub rs2: Option<AsmArg>,
pub opcode: TokenStream,
pub funct3: TokenStream,
pub funct7: TokenStream,
}

impl Parse for CustomInsnR {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut rd = None;
let mut rs1 = None;
let mut rs2 = None;
let mut opcode = None;
let mut funct3 = None;
let mut funct7 = None;
while !input.is_empty() {
let key: Ident = input.parse()?;
input.parse::<Token![=]>()?;

let value = if key == "opcode" || key == "funct3" || key == "funct7" {
let mut tokens = TokenStream::new();
while !input.is_empty() && !input.peek(Token![,]) {
tokens.extend(TokenStream::from(input.parse::<proc_macro2::TokenTree>()?));
}
match key.to_string().as_str() {
"opcode" => opcode = Some(tokens),
"funct3" => funct3 = Some(tokens),
"funct7" => funct7 = Some(tokens),
_ => unreachable!(),
}
None
} else {
let lookahead = input.lookahead1();
Some(if lookahead.peek(kw::In) {
input.parse::<kw::In>()?;
let mut tokens = TokenStream::new();
while !input.is_empty() && !input.peek(Token![,]) {
tokens.extend(TokenStream::from(input.parse::<proc_macro2::TokenTree>()?));
}
AsmArg::In(tokens)
} else if lookahead.peek(kw::Out) {
input.parse::<kw::Out>()?;
let mut tokens = TokenStream::new();
while !input.is_empty() && !input.peek(Token![,]) {
tokens.extend(TokenStream::from(input.parse::<proc_macro2::TokenTree>()?));
}
AsmArg::Out(tokens)
} else if lookahead.peek(kw::InOut) {
input.parse::<kw::InOut>()?;
let mut tokens = TokenStream::new();
while !input.is_empty() && !input.peek(Token![,]) {
tokens.extend(TokenStream::from(input.parse::<proc_macro2::TokenTree>()?));
}
AsmArg::InOut(tokens)
} else if lookahead.peek(kw::Const) {
input.parse::<kw::Const>()?;
if input.peek(syn::LitStr) {
AsmArg::ConstLit(input.parse()?)
} else {
let mut tokens = TokenStream::new();
while !input.is_empty() && !input.peek(Token![,]) {
tokens.extend(TokenStream::from(
input.parse::<proc_macro2::TokenTree>()?,
));
}
AsmArg::ConstExpr(tokens)
}
} else {
return Err(lookahead.error());
})
};

match key.to_string().as_str() {
"rd" => rd = value,
"rs1" => rs1 = value,
"rs2" => rs2 = value,
"opcode" | "funct3" | "funct7" => (),
_ => return Err(syn::Error::new(key.span(), "unexpected field")),
}

if !input.is_empty() {
input.parse::<Token![,]>()?;
}
}

let opcode = opcode.ok_or_else(|| syn::Error::new(input.span(), "missing opcode field"))?;
let funct3 = funct3.ok_or_else(|| syn::Error::new(input.span(), "missing funct3 field"))?;
let funct7 = funct7.ok_or_else(|| syn::Error::new(input.span(), "missing funct7 field"))?;

Ok(CustomInsnR {
rd,
rs1,
rs2,
opcode,
funct3,
funct7,
})
}
}

mod kw {
syn::custom_keyword!(In);
syn::custom_keyword!(Out);
syn::custom_keyword!(InOut);
syn::custom_keyword!(Const);
}

#[proc_macro]
pub fn custom_insn_r(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let CustomInsnR {
rd,
rs1,
rs2,
opcode,
funct3,
funct7,
} = syn::parse_macro_input!(input as CustomInsnR);

let mut template = String::from(".insn r {opcode}, {funct3}, {funct7}");
let mut args = vec![];

// Helper function to handle register arguments
fn handle_reg_arg(
template: &mut String,
args: &mut Vec<proc_macro2::TokenStream>,
arg: &Option<AsmArg>,
reg_name: &str,
) {
let reg_ident = syn::Ident::new(reg_name, Span::call_site());
if let Some(arg) = arg {
match arg {
AsmArg::ConstLit(lit) => {
template.push_str(", ");
template.push_str(&lit.value());
}
AsmArg::In(tokens) => {
template.push_str(", {");
template.push_str(reg_name);
template.push('}');
args.push(quote::quote! { #reg_ident = in(reg) #tokens });
}
AsmArg::Out(tokens) => {
template.push_str(", {");
template.push_str(reg_name);
template.push('}');
args.push(quote::quote! { #reg_ident = out(reg) #tokens });
}
AsmArg::InOut(tokens) => {
template.push_str(", {");
template.push_str(reg_name);
template.push('}');
args.push(quote::quote! { #reg_ident = inout(reg) #tokens });
}
AsmArg::ConstExpr(tokens) => {
template.push_str(", {");
template.push_str(reg_name);
template.push('}');
args.push(quote::quote! { #reg_ident = const #tokens });
}
}
}
}

// Build the template string and args based on which parameters are literals vs expressions
handle_reg_arg(&mut template, &mut args, &rd, "rd");
handle_reg_arg(&mut template, &mut args, &rs1, "rs1");
handle_reg_arg(&mut template, &mut args, &rs2, "rs2");

let expanded = quote::quote! {
unsafe {
core::arch::asm!(
#template,
opcode = const #opcode,
funct3 = const #funct3,
funct7 = const #funct7,
#(#args),*
)
}
};

expanded.into()
}
88 changes: 45 additions & 43 deletions crates/toolchain/platform/src/custom_insn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,46 +31,48 @@ macro_rules! custom_insn_i {
};
}

#[macro_export]
macro_rules! custom_insn_r {
($opcode:expr, $funct3:expr, $funct7:expr, $rd:literal, $rs1:literal, $rs2:literal) => {
unsafe {
core::arch::asm!(concat!(
".insn r {opcode}, {funct3}, {funct7}, ",
$rd,
", ",
$rs1,
", ",
$rs2,
), opcode = const $opcode, funct3 = const $funct3, funct7 = const $funct7)
}
};
($opcode:expr, $funct3:expr, $funct7:expr, $rd:ident, $rs1:literal, $rs2:literal) => {
unsafe {
core::arch::asm!(concat!(
".insn r {opcode}, {funct3}, {funct7}, {rd}, ",
$rs1,
", ",
$rs2,
), opcode = const $opcode, funct3 = const $funct3, funct7 = const $funct7, rd = out(reg) $rd)
}
};
($opcode:expr, $funct3:expr, $funct7:expr, $rd:expr, $rs1:expr, $rs2:literal) => {
// Note: rd = in(reg) because we expect rd to be a pointer
unsafe {
core::arch::asm!(concat!(
".insn r {opcode}, {funct3}, {funct7}, {rd}, {rs1}, ",
$rs2,
), opcode = const $opcode, funct3 = const $funct3, funct7 = const $funct7, rd = in(reg) $rd, rs1 = in(reg) $rs1)
}
};
($opcode:expr, $funct3:expr, $funct7:expr, $rd:expr, $rs1:expr, $rs2:expr) => {
// Note: rd = in(reg) because we expect rd to be a pointer
unsafe {
core::arch::asm!(
".insn r {opcode}, {funct3}, {funct7}, {rd}, {rs1}, {rs2}",
opcode = const $opcode, funct3 = const $funct3, funct7 = const $funct7, rd = in(reg) $rd, rs1 = in(reg) $rs1, rs2 = in(reg) $rs2)
}
};
// TODO: implement more variants with like rs1 = in(reg) $y etc
}
pub use openvm_custom_insn::custom_insn_r;

// #[macro_export]
// macro_rules! custom_insn_r {
// ($opcode:expr, $funct3:expr, $funct7:expr, $rd:literal, $rs1:literal, $rs2:literal) => {
// unsafe {
// core::arch::asm!(concat!(
// ".insn r {opcode}, {funct3}, {funct7}, ",
// $rd,
// ", ",
// $rs1,
// ", ",
// $rs2,
// ), opcode = const $opcode, funct3 = const $funct3, funct7 = const $funct7)
// }
// };
// ($opcode:expr, $funct3:expr, $funct7:expr, $rd:ident, $rs1:literal, $rs2:literal) => {
// unsafe {
// core::arch::asm!(concat!(
// ".insn r {opcode}, {funct3}, {funct7}, {rd}, ",
// $rs1,
// ", ",
// $rs2,
// ), opcode = const $opcode, funct3 = const $funct3, funct7 = const $funct7, rd = out(reg) $rd)
// }
// };
// ($opcode:expr, $funct3:expr, $funct7:expr, $rd:expr, $rs1:expr, $rs2:literal) => {
// // Note: rd = in(reg) because we expect rd to be a pointer
// unsafe {
// core::arch::asm!(concat!(
// ".insn r {opcode}, {funct3}, {funct7}, {rd}, {rs1}, ",
// $rs2,
// ), opcode = const $opcode, funct3 = const $funct3, funct7 = const $funct7, rd = in(reg) $rd, rs1 = in(reg) $rs1)
// }
// };
// ($opcode:expr, $funct3:expr, $funct7:expr, $rd:expr, $rs1:expr, $rs2:expr) => {
// // Note: rd = in(reg) because we expect rd to be a pointer
// unsafe {
// core::arch::asm!(
// ".insn r {opcode}, {funct3}, {funct7}, {rd}, {rs1}, {rs2}",
// opcode = const $opcode, funct3 = const $funct3, funct7 = const $funct7, rd = in(reg) $rd, rs1 = in(reg) $rs1, rs2 = in(reg) $rs2)
// }
// };
// // TODO: implement more variants with like rs1 = in(reg) $y etc
// }
2 changes: 2 additions & 0 deletions crates/toolchain/platform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#[cfg(all(feature = "rust-runtime", target_os = "zkvm"))]
pub mod custom_insn;
#[cfg(all(feature = "rust-runtime", target_os = "zkvm"))]
pub use custom_insn::*;
#[cfg(all(feature = "export-getrandom", target_os = "zkvm"))]
mod getrandom;
#[cfg(all(feature = "rust-runtime", target_os = "zkvm"))]
Expand Down
Loading

0 comments on commit 76f4a47

Please sign in to comment.