diff --git a/Cargo.lock b/Cargo.lock index a025f7e29..ca1511b57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -531,12 +531,14 @@ dependencies = [ "ceno_rt", "elf", "itertools 0.13.0", + "num-bigint", "num-derive", "num-traits", "rrs-succinct", "secp", "strum", "strum_macros", + "substrate-bn", "tiny-keccak", "tracing", ] @@ -551,6 +553,7 @@ dependencies = [ "itertools 0.13.0", "rand", "rkyv", + "substrate-bn", "tiny-keccak", ] @@ -1038,7 +1041,9 @@ name = "examples" version = "0.1.0" dependencies = [ "ceno_rt", + "rand", "rkyv", + "substrate-bn", ] [[package]] @@ -2418,6 +2423,12 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc-hex" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e75f6a532d0fd9f7f13144f392b6ad56a32696bfcd9c78f797f16bbb6f072d6" + [[package]] name = "rustix" version = "0.38.44" @@ -2637,6 +2648,19 @@ dependencies = [ "syn 2.0.98", ] +[[package]] +name = "substrate-bn" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b5bbfa79abbae15dd642ea8176a21a635ff3c00059961d1ea27ad04e5b441c" +dependencies = [ + "byteorder", + "crunchy", + "lazy_static", + "rand", + "rustc-hex", +] + [[package]] name = "subtle" version = "2.6.1" diff --git a/Cargo.toml b/Cargo.toml index 615ed199f..7218aef05 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ cfg-if = "1.0" criterion = { version = "0.5", features = ["html_reports"] } crossbeam-channel = "0.5" itertools = "0.13" +num-bigint = { version = "0.4.6" } num-derive = "0.4" num-traits = "0.2" p3-challenger = { git = "https://github.com/scroll-tech/plonky3", rev = "8d2be81" } @@ -57,12 +58,14 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" strum = "0.26" strum_macros = "0.26" +substrate-bn = { version = "0.6.0" } tiny-keccak = { version = "2.0.2", features = ["keccak"] } tracing = { version = "0.1", features = [ "attributes", ] } tracing-forest = { version = "0.1.6" } tracing-subscriber = { version = "0.3", features = ["env-filter"] } +uint = "0.8" [profile.dev] lto = "thin" diff --git a/ceno_emul/Cargo.toml b/ceno_emul/Cargo.toml index 6478a1597..fe2384032 100644 --- a/ceno_emul/Cargo.toml +++ b/ceno_emul/Cargo.toml @@ -14,12 +14,14 @@ anyhow.workspace = true ceno_rt = { path = "../ceno_rt" } elf = "0.7" itertools.workspace = true +num-bigint.workspace = true num-derive.workspace = true num-traits.workspace = true rrs_lib = { package = "rrs-succinct", version = "0.1.0" } secp.workspace = true strum.workspace = true strum_macros.workspace = true +substrate-bn.workspace = true tiny-keccak.workspace = true tracing.workspace = true diff --git a/ceno_emul/src/host_utils.rs b/ceno_emul/src/host_utils.rs index 5a5956317..a507f66c5 100644 --- a/ceno_emul/src/host_utils.rs +++ b/ceno_emul/src/host_utils.rs @@ -1,6 +1,8 @@ use std::iter::from_fn; -use crate::{ByteAddr, EmuContext, VMState, WordAddr}; +use itertools::Itertools; + +use crate::{ByteAddr, EmuContext, VMState, Word, WordAddr}; const WORD_SIZE: usize = 4; const INFO_OUT_ADDR: WordAddr = ByteAddr(0xC000_0000).waddr(); @@ -17,6 +19,19 @@ pub fn read_all_messages(state: &VMState) -> Vec> { .collect() } +pub fn read_all_messages_as_words(state: &VMState) -> Vec> { + read_all_messages(state) + .iter() + .map(|message| { + assert_eq!(message.len() % WORD_SIZE, 0); + message + .chunks_exact(WORD_SIZE) + .map(|chunk| Word::from_le_bytes(chunk.try_into().unwrap())) + .collect_vec() + }) + .collect_vec() +} + fn read_message(state: &VMState, offset: WordAddr) -> Vec { let out_addr = INFO_OUT_ADDR + offset; let byte_len = state.peek_memory(out_addr) as usize; diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 60a4a3037..eb374c783 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -25,6 +25,10 @@ pub mod disassemble; mod syscalls; pub use syscalls::{ KECCAK_PERMUTE, SECP256K1_ADD, SECP256K1_DECOMPRESS, SECP256K1_DOUBLE, SHA_EXTEND, SyscallSpec, + bn254::{ + BN254_FP_WORDS, BN254_FP2_WORDS, BN254_POINT_WORDS, Bn254AddSpec, Bn254DoubleSpec, + Bn254Fp2AddSpec, Bn254Fp2MulSpec, Bn254FpAddSpec, Bn254FpMulSpec, + }, keccak_permute::{KECCAK_WORDS, KeccakSpec}, secp256k1::{ COORDINATE_WORDS, SECP256K1_ARG_WORDS, Secp256k1AddSpec, Secp256k1DecompressSpec, diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs index d99bd9025..b6fd82ac2 100644 --- a/ceno_emul/src/syscalls.rs +++ b/ceno_emul/src/syscalls.rs @@ -1,6 +1,7 @@ use crate::{RegIdx, Tracer, VMState, Word, WordAddr, WriteOp}; use anyhow::Result; +pub mod bn254; pub mod keccak_permute; pub mod secp256k1; pub mod sha256; @@ -9,6 +10,7 @@ pub mod sha256; // https://github.com/succinctlabs/sp1/blob/013c24ea2fa15a0e7ed94f7d11a7ada4baa39ab9/crates/core/executor/src/syscalls/code.rs pub use ceno_rt::syscalls::{ + BN254_ADD, BN254_DOUBLE, BN254_FP_ADD, BN254_FP_MUL, BN254_FP2_ADD, BN254_FP2_MUL, KECCAK_PERMUTE, SECP256K1_ADD, SECP256K1_DECOMPRESS, SECP256K1_DOUBLE, SHA_EXTEND, }; @@ -28,12 +30,19 @@ pub fn handle_syscall(vm: &VMState, function_code: u32) -> Result Ok(secp256k1::secp256k1_double(vm)), SECP256K1_DECOMPRESS => Ok(secp256k1::secp256k1_decompress(vm)), SHA_EXTEND => Ok(sha256::extend(vm)), + BN254_ADD => Ok(bn254::bn254_add(vm)), + BN254_DOUBLE => Ok(bn254::bn254_double(vm)), + BN254_FP_ADD => Ok(bn254::bn254_fp_add(vm)), + BN254_FP_MUL => Ok(bn254::bn254_fp_mul(vm)), + BN254_FP2_ADD => Ok(bn254::bn254_fp2_add(vm)), + BN254_FP2_MUL => Ok(bn254::bn254_fp2_mul(vm)), // TODO: introduce error types. _ => Err(anyhow::anyhow!("Unknown syscall: {}", function_code)), } } /// A syscall event, available to the circuit witness generators. +/// TODO: separate mem_ops into two stages: reads-and-writes #[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct SyscallWitness { pub mem_ops: Vec, @@ -43,13 +52,6 @@ pub struct SyscallWitness { impl SyscallWitness { fn new(mem_ops: Vec, reg_ops: Vec) -> SyscallWitness { - for (i, op) in mem_ops.iter().enumerate() { - assert_eq!( - op.addr, - mem_ops[0].addr + i, - "Dummy circuit expects that mem_ops addresses are consecutive." - ); - } SyscallWitness { mem_ops, reg_ops, diff --git a/ceno_emul/src/syscalls/bn254/bn254_curve.rs b/ceno_emul/src/syscalls/bn254/bn254_curve.rs new file mode 100644 index 000000000..fb9be520d --- /dev/null +++ b/ceno_emul/src/syscalls/bn254/bn254_curve.rs @@ -0,0 +1,109 @@ +use crate::{ + Change, EmuContext, Platform, SyscallSpec, VMState, Word, WriteOp, + syscalls::{SyscallEffects, SyscallWitness, bn254::types::Bn254Point}, + utils::MemoryView, +}; + +use super::types::BN254_POINT_WORDS; +use itertools::Itertools; + +pub struct Bn254AddSpec; +impl SyscallSpec for Bn254AddSpec { + const NAME: &'static str = "BN254_ADD"; + + const REG_OPS_COUNT: usize = 2; + const MEM_OPS_COUNT: usize = 2 * BN254_POINT_WORDS; + const CODE: u32 = ceno_rt::syscalls::BN254_ADD; +} + +pub struct Bn254DoubleSpec; +impl SyscallSpec for Bn254DoubleSpec { + const NAME: &'static str = "BN254_DOUBLE"; + + const REG_OPS_COUNT: usize = 2; + const MEM_OPS_COUNT: usize = BN254_POINT_WORDS; + const CODE: u32 = ceno_rt::syscalls::BN254_DOUBLE; +} + +pub fn bn254_add(vm: &VMState) -> SyscallEffects { + let p_ptr = vm.peek_register(Platform::reg_arg0()); + let q_ptr = vm.peek_register(Platform::reg_arg1()); + + // Read the argument pointers + let reg_ops = vec![ + WriteOp::new_register_op( + Platform::reg_arg0(), + Change::new(p_ptr, p_ptr), + 0, // Cycle set later in finalize(). + ), + WriteOp::new_register_op( + Platform::reg_arg1(), + Change::new(q_ptr, q_ptr), + 0, // Cycle set later in finalize(). + ), + ]; + + // Memory segments of P and Q + let [mut p_view, q_view] = + [p_ptr, q_ptr].map(|start| MemoryView::::new(vm, start)); + + // Read P and Q from words via wrapper type + let [p, q] = [&p_view, &q_view].map(|view| Bn254Point::from(view.words())); + + // TODO: what does sp1 do with invalid points? equal points? + // Compute the sum and convert back to words + let output_words: [Word; BN254_POINT_WORDS] = (p + q).into(); + + p_view.write(output_words); + + let mem_ops = p_view + .mem_ops() + .into_iter() + .chain(q_view.mem_ops()) + .collect_vec(); + + assert_eq!(mem_ops.len(), 2 * BN254_POINT_WORDS); + SyscallEffects { + witness: SyscallWitness::new(mem_ops, reg_ops), + next_pc: None, + } +} + +pub fn bn254_double(vm: &VMState) -> SyscallEffects { + let p_ptr = vm.peek_register(Platform::reg_arg0()); + + // for compatibility with sp1 spec + assert_eq!(vm.peek_register(Platform::reg_arg1()), 0); + + // Read the argument pointers + let reg_ops = vec![ + WriteOp::new_register_op( + Platform::reg_arg0(), + Change::new(p_ptr, p_ptr), + 0, // Cycle set later in finalize(). + ), + WriteOp::new_register_op( + Platform::reg_arg1(), + Change::new(0, 0), + 0, // Cycle set later in finalize(). + ), + ]; + + // P's memory segment + let mut p_view = MemoryView::::new(vm, p_ptr); + // Create point from words via wrapper type + let p = Bn254Point::from(p_view.words()); + + let result = p.double(); + let output_words: [Word; BN254_POINT_WORDS] = result.into(); + + p_view.write(output_words); + + let mem_ops = p_view.mem_ops().to_vec(); + + assert_eq!(mem_ops.len(), BN254_POINT_WORDS); + SyscallEffects { + witness: SyscallWitness::new(mem_ops, reg_ops), + next_pc: None, + } +} diff --git a/ceno_emul/src/syscalls/bn254/bn254_fptower.rs b/ceno_emul/src/syscalls/bn254/bn254_fptower.rs new file mode 100644 index 000000000..0e7c21db6 --- /dev/null +++ b/ceno_emul/src/syscalls/bn254/bn254_fptower.rs @@ -0,0 +1,114 @@ +use itertools::Itertools; + +use crate::{ + Change, EmuContext, Platform, SyscallSpec, VMState, Word, WriteOp, + syscalls::{ + SyscallEffects, SyscallWitness, + bn254::types::{Bn254Fp, Bn254Fp2}, + }, + utils::MemoryView, +}; + +use super::types::{BN254_FP_WORDS, BN254_FP2_WORDS}; + +pub struct Bn254FpAddSpec; +impl SyscallSpec for Bn254FpAddSpec { + const NAME: &'static str = "BN254_FP_ADD"; + + const REG_OPS_COUNT: usize = 2; + const MEM_OPS_COUNT: usize = 2 * BN254_FP_WORDS; + const CODE: u32 = ceno_rt::syscalls::BN254_FP_ADD; +} + +pub struct Bn254Fp2AddSpec; +impl SyscallSpec for Bn254Fp2AddSpec { + const NAME: &'static str = "BN254_FP2_ADD"; + + const REG_OPS_COUNT: usize = 2; + const MEM_OPS_COUNT: usize = 2 * BN254_FP2_WORDS; + const CODE: u32 = ceno_rt::syscalls::BN254_FP2_ADD; +} + +pub struct Bn254FpMulSpec; +impl SyscallSpec for Bn254FpMulSpec { + const NAME: &'static str = "BN254_FP_MUL"; + + const REG_OPS_COUNT: usize = 2; + const MEM_OPS_COUNT: usize = 2 * BN254_FP_WORDS; + const CODE: u32 = ceno_rt::syscalls::BN254_FP_MUL; +} + +pub struct Bn254Fp2MulSpec; +impl SyscallSpec for Bn254Fp2MulSpec { + const NAME: &'static str = "BN254_FP2_MUL"; + + const REG_OPS_COUNT: usize = 2; + const MEM_OPS_COUNT: usize = 2 * BN254_FP2_WORDS; + const CODE: u32 = ceno_rt::syscalls::BN254_FP2_MUL; +} + +fn bn254_fptower_binary_op< + const WORDS: usize, + const IS_ADD: bool, + F: From<[Word; WORDS]> + + Into<[Word; WORDS]> + + std::ops::Add + + std::ops::Mul, +>( + vm: &VMState, +) -> SyscallEffects { + let p_ptr = vm.peek_register(Platform::reg_arg0()); + let q_ptr = vm.peek_register(Platform::reg_arg1()); + + // Read the argument pointers + let reg_ops = vec![ + WriteOp::new_register_op( + Platform::reg_arg0(), + Change::new(p_ptr, p_ptr), + 0, // Cycle set later in finalize(). + ), + WriteOp::new_register_op( + Platform::reg_arg1(), + Change::new(q_ptr, q_ptr), + 0, // Cycle set later in finalize(). + ), + ]; + let [mut p_view, q_view] = [p_ptr, q_ptr].map(|start| MemoryView::::new(vm, start)); + + let p = F::from(p_view.words()); + let q = F::from(q_view.words()); + let result = match IS_ADD { + true => p + q, + false => p * q, + }; + p_view.write(result.into()); + + let mem_ops = p_view + .mem_ops() + .into_iter() + .chain(q_view.mem_ops()) + .collect_vec(); + + assert_eq!(mem_ops.len(), 2 * WORDS); + + SyscallEffects { + witness: SyscallWitness::new(mem_ops, reg_ops), + next_pc: None, + } +} + +pub fn bn254_fp_add(vm: &VMState) -> SyscallEffects { + bn254_fptower_binary_op::(vm) +} + +pub fn bn254_fp_mul(vm: &VMState) -> SyscallEffects { + bn254_fptower_binary_op::(vm) +} + +pub fn bn254_fp2_add(vm: &VMState) -> SyscallEffects { + bn254_fptower_binary_op::(vm) +} + +pub fn bn254_fp2_mul(vm: &VMState) -> SyscallEffects { + bn254_fptower_binary_op::(vm) +} diff --git a/ceno_emul/src/syscalls/bn254/mod.rs b/ceno_emul/src/syscalls/bn254/mod.rs new file mode 100644 index 000000000..9d15fd41c --- /dev/null +++ b/ceno_emul/src/syscalls/bn254/mod.rs @@ -0,0 +1,12 @@ +mod bn254_curve; +mod bn254_fptower; +mod types; + +pub use bn254_curve::{Bn254AddSpec, Bn254DoubleSpec, bn254_add, bn254_double}; + +pub use bn254_fptower::{ + Bn254Fp2AddSpec, Bn254Fp2MulSpec, Bn254FpAddSpec, Bn254FpMulSpec, bn254_fp_add, bn254_fp_mul, + bn254_fp2_add, bn254_fp2_mul, +}; + +pub use types::{BN254_FP_WORDS, BN254_FP2_WORDS, BN254_POINT_WORDS}; diff --git a/ceno_emul/src/syscalls/bn254/types.rs b/ceno_emul/src/syscalls/bn254/types.rs new file mode 100644 index 000000000..d6312d978 --- /dev/null +++ b/ceno_emul/src/syscalls/bn254/types.rs @@ -0,0 +1,127 @@ +use itertools::Itertools; +use substrate_bn::{AffineG1, Fq, Fq2, Fr, G1}; + +use crate::Word; + +pub const BN254_FP_WORDS: usize = 8; +pub const BN254_FP2_WORDS: usize = 2 * BN254_FP_WORDS; +pub const BN254_POINT_WORDS: usize = 2 * BN254_FP_WORDS; + +pub struct Bn254Fp(substrate_bn::Fq); + +impl From<[Word; BN254_FP_WORDS]> for Bn254Fp { + fn from(value: [Word; BN254_FP_WORDS]) -> Self { + let bytes_be = value + .iter() + .flat_map(|word| word.to_le_bytes()) + .rev() + .collect_vec(); + Bn254Fp(Fq::from_slice(&bytes_be).expect("cannot parse Fq")) + } +} + +impl From for [Word; BN254_FP_WORDS] { + fn from(value: Bn254Fp) -> Self { + let mut bytes_be = [0u8; 32]; + value + .0 + .to_big_endian(&mut bytes_be) + .expect("cannot serialize Fq"); + bytes_be.reverse(); + + bytes_be + .chunks_exact(4) + .map(|chunk| Word::from_le_bytes(chunk.try_into().unwrap())) + .collect_vec() + .try_into() + .unwrap() + } +} + +impl std::ops::Add for Bn254Fp { + type Output = Bn254Fp; + fn add(self, rhs: Self) -> Self::Output { + Bn254Fp(self.0 + rhs.0) + } +} + +impl std::ops::Mul for Bn254Fp { + type Output = Bn254Fp; + fn mul(self, rhs: Self) -> Self::Output { + Bn254Fp(self.0 * rhs.0) + } +} + +pub struct Bn254Fp2(substrate_bn::Fq2); + +impl From<[Word; BN254_FP2_WORDS]> for Bn254Fp2 { + fn from(value: [Word; BN254_FP2_WORDS]) -> Self { + let first_half: [Word; BN254_FP_WORDS] = value[..BN254_FP_WORDS].try_into().unwrap(); + let second_half: [Word; BN254_FP_WORDS] = value[BN254_FP_WORDS..].try_into().unwrap(); + // notation: Fq2 is a + bi (a real and b imaginary) + let a = Bn254Fp::from(first_half).0; + let b = Bn254Fp::from(second_half).0; + Bn254Fp2(Fq2::new(a, b)) + } +} + +impl From for [Word; BN254_FP2_WORDS] { + fn from(value: Bn254Fp2) -> Self { + // notation: Fq2 is a + bi (a real and b imaginary) + let first_half: [Word; BN254_FP_WORDS] = Bn254Fp(value.0.real()).into(); + let second_half: [Word; BN254_FP_WORDS] = Bn254Fp(value.0.imaginary()).into(); + + [first_half, second_half].concat().try_into().unwrap() + } +} + +impl std::ops::Add for Bn254Fp2 { + type Output = Bn254Fp2; + fn add(self, rhs: Self) -> Self::Output { + Bn254Fp2(self.0 + rhs.0) + } +} + +impl std::ops::Mul for Bn254Fp2 { + type Output = Bn254Fp2; + fn mul(self, rhs: Self) -> Self::Output { + Bn254Fp2(self.0 * rhs.0) + } +} + +#[derive(Debug)] +pub struct Bn254Point(substrate_bn::G1); + +impl From<[Word; BN254_POINT_WORDS]> for Bn254Point { + fn from(value: [Word; BN254_POINT_WORDS]) -> Self { + let first_half: [Word; BN254_FP_WORDS] = value[..BN254_FP_WORDS].try_into().unwrap(); + let second_half: [Word; BN254_FP_WORDS] = value[BN254_FP_WORDS..].try_into().unwrap(); + let a = Bn254Fp::from(first_half).0; + let b = Bn254Fp::from(second_half).0; + Bn254Point(G1::new(a, b, Fq::one())) + } +} + +impl From for [Word; BN254_POINT_WORDS] { + fn from(value: Bn254Point) -> Self { + let affine = AffineG1::from_jacobian(value.0).expect("cannot unpack affine"); + let first_half: [Word; BN254_FP_WORDS] = Bn254Fp(affine.x()).into(); + let second_half: [Word; BN254_FP_WORDS] = Bn254Fp(affine.y()).into(); + + [first_half, second_half].concat().try_into().unwrap() + } +} + +impl std::ops::Add for Bn254Point { + type Output = Bn254Point; + fn add(self, rhs: Self) -> Self::Output { + Bn254Point(self.0 + rhs.0) + } +} + +impl Bn254Point { + pub fn double(&self) -> Self { + let two = Fr::from_str("2").unwrap(); + Bn254Point(self.0 * two) + } +} diff --git a/ceno_host/Cargo.toml b/ceno_host/Cargo.toml index 5307a1fef..8b45c96b6 100644 --- a/ceno_host/Cargo.toml +++ b/ceno_host/Cargo.toml @@ -14,6 +14,7 @@ anyhow.workspace = true ceno_emul = { path = "../ceno_emul" } itertools.workspace = true rkyv = { version = "0.8", features = ["pointer_width_32"] } +substrate-bn.workspace = true tiny-keccak.workspace = true [dev-dependencies] diff --git a/ceno_host/tests/test_elf.rs b/ceno_host/tests/test_elf.rs index 4268a6f0e..6353a0cf9 100644 --- a/ceno_host/tests/test_elf.rs +++ b/ceno_host/tests/test_elf.rs @@ -2,8 +2,10 @@ use std::{collections::BTreeSet, iter::from_fn, sync::Arc}; use anyhow::Result; use ceno_emul::{ - CENO_PLATFORM, COORDINATE_WORDS, EmuContext, InsnKind, Platform, Program, SECP256K1_ARG_WORDS, - SHA_EXTEND_WORDS, StepRecord, VMState, WORD_SIZE, WordAddr, host_utils::read_all_messages, + BN254_FP_WORDS, BN254_FP2_WORDS, BN254_POINT_WORDS, CENO_PLATFORM, COORDINATE_WORDS, + EmuContext, InsnKind, Platform, Program, SECP256K1_ARG_WORDS, SHA_EXTEND_WORDS, StepRecord, + VMState, WORD_SIZE, Word, WordAddr, WriteOp, + host_utils::{read_all_messages, read_all_messages_as_words}, }; use ceno_host::CenoStdin; use itertools::{Itertools, enumerate, izip}; @@ -469,6 +471,88 @@ fn test_sha256_extend() -> Result<()> { Ok(()) } +#[test] +fn test_bn254_fptower_syscalls() -> Result<()> { + let program_elf = ceno_examples::bn254_fptower_syscalls; + let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; + let steps = run(&mut state)?; + + const RUNS: usize = 10; + let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + assert_eq!(syscalls.len(), 4 * RUNS); + + for witness in syscalls.iter() { + assert_eq!(witness.reg_ops.len(), 2); + assert_eq!(witness.reg_ops[0].register_index(), Platform::reg_arg0()); + assert_eq!(witness.reg_ops[1].register_index(), Platform::reg_arg1()); + } + + let messages = read_all_messages_as_words(&state); + let mut m_iter = messages.iter(); + + // just Fp syscalls + for witness in syscalls.iter().take(2 * RUNS) { + assert_eq!(witness.mem_ops.len(), 2 * BN254_FP_WORDS); + let [a_before, b, a_after] = [ + m_iter.next().unwrap(), + m_iter.next().unwrap(), + m_iter.next().unwrap(), + ]; + check_writes(&witness.mem_ops[0..BN254_FP_WORDS], a_before, a_after); + check_reads(&witness.mem_ops[BN254_FP_WORDS..], b); + } + + // just Fp2 syscalls + for witness in syscalls.iter().skip(2 * RUNS) { + assert_eq!(witness.mem_ops.len(), 2 * BN254_FP2_WORDS); + let [a_before, b, a_after] = [ + m_iter.next().unwrap(), + m_iter.next().unwrap(), + m_iter.next().unwrap(), + ]; + check_writes(&witness.mem_ops[0..BN254_FP2_WORDS], a_before, a_after); + check_reads(&witness.mem_ops[BN254_FP2_WORDS..], b); + } + + Ok(()) +} + +#[test] +fn test_bn254_curve() -> Result<()> { + let program_elf = ceno_examples::bn254_curve_syscalls; + let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; + let steps = run(&mut state)?; + + let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + assert_eq!(syscalls.len(), 3); + + for witness in syscalls.iter() { + assert_eq!(witness.reg_ops.len(), 2); + assert_eq!(witness.reg_ops[0].register_index(), Platform::reg_arg0()); + assert_eq!(witness.reg_ops[1].register_index(), Platform::reg_arg1()); + } + + let messages = read_all_messages_as_words(&state); + let [a1, b, a2, c1, c2, one, c3]: [Vec; 7] = messages.try_into().unwrap(); + + { + assert_eq!(syscalls[0].mem_ops.len(), 2 * BN254_POINT_WORDS); + check_writes(&syscalls[0].mem_ops[..BN254_POINT_WORDS], &a1, &a2); + check_reads(&syscalls[0].mem_ops[BN254_POINT_WORDS..], &b); + } + { + assert_eq!(syscalls[1].mem_ops.len(), BN254_POINT_WORDS); + check_writes(&syscalls[1].mem_ops, &c1, &c2); + } + { + assert_eq!(syscalls[2].mem_ops.len(), 2 * BN254_POINT_WORDS); + check_writes(&syscalls[2].mem_ops[..BN254_POINT_WORDS], &c2, &c3); + check_reads(&syscalls[2].mem_ops[BN254_POINT_WORDS..], &one); + } + + Ok(()) +} + #[test] fn test_syscalls_compatibility() -> Result<()> { let program_elf = ceno_examples::syscalls; @@ -484,6 +568,19 @@ fn unsafe_platform() -> Platform { platform } +fn check_writes(ops: &[WriteOp], before: &[Word], after: &[Word]) { + assert!(ops.len() == before.len() && ops.len() == after.len()); + for (i, _) in ops.iter().enumerate() { + assert_eq!(ops[0].addr + i, ops[i].addr); + assert_eq!(ops[i].value.before, before[i]); + assert_eq!(ops[i].value.after, after[i]); + } +} + +fn check_reads(ops: &[WriteOp], before: &[Word]) { + check_writes(ops, before, before); +} + fn sample_keccak_f(count: usize) -> Vec> { let mut state = [0_u64; 25]; diff --git a/ceno_rt/src/syscalls.rs b/ceno_rt/src/syscalls.rs index 32a87bfb2..5fbeee96d 100644 --- a/ceno_rt/src/syscalls.rs +++ b/ceno_rt/src/syscalls.rs @@ -2,6 +2,16 @@ use core::arch::asm; pub const KECCAK_PERMUTE: u32 = 0x00_01_01_09; +pub const SECP256K1_ADD: u32 = 0x00_01_01_0A; +pub const SECP256K1_DOUBLE: u32 = 0x00_00_01_0B; +pub const SECP256K1_DECOMPRESS: u32 = 0x00_00_01_0C; +pub const SHA_EXTEND: u32 = 0x00_30_01_05; +pub const BN254_ADD: u32 = 0x00_01_01_0E; +pub const BN254_DOUBLE: u32 = 0x00_00_01_0F; +pub const BN254_FP_ADD: u32 = 0x00_01_01_26; +pub const BN254_FP_MUL: u32 = 0x00_01_01_28; +pub const BN254_FP2_ADD: u32 = 0x00_01_01_29; +pub const BN254_FP2_MUL: u32 = 0x00_01_01_2B; /// Based on https://github.com/succinctlabs/sp1/blob/013c24ea2fa15a0e7ed94f7d11a7ada4baa39ab9/crates/zkvm/entrypoint/src/syscalls/keccak_permute.rs /// Executes the Keccak256 permutation on the given state. @@ -25,7 +35,6 @@ pub fn syscall_keccak_permute(state: &mut [u64; 25]) { unreachable!() } -pub const SECP256K1_ADD: u32 = 0x00_01_01_0A; /// Based on https://github.com/succinctlabs/sp1/blob/dbe622aa4a6a33c88d76298c2a29a1d7ef7e90df/crates/zkvm/entrypoint/src/syscalls/secp256k1.rs /// Adds two Secp256k1 points. /// @@ -53,8 +62,6 @@ pub fn syscall_secp256k1_add(p: *mut [u32; 16], q: *mut [u32; 16]) { unreachable!() } -pub const SECP256K1_DOUBLE: u32 = 0x00_00_01_0B; - /// Based on: https://github.com/succinctlabs/sp1/blob/dbe622aa4a6a33c88d76298c2a29a1d7ef7e90df/crates/zkvm/entrypoint/src/syscalls/secp256k1.rs /// Double a Secp256k1 point. /// @@ -80,8 +87,6 @@ pub fn syscall_secp256k1_double(p: *mut [u32; 16]) { unreachable!() } -pub const SECP256K1_DECOMPRESS: u32 = 0x00_00_01_0C; - /// Decompresses a compressed Secp256k1 point. /// /// ### Spec @@ -111,7 +116,6 @@ pub fn syscall_secp256k1_decompress(point: &mut [u8; 64], is_odd: bool) { unreachable!() } -pub const SHA_EXTEND: u32 = 0x00_30_01_05; /// Based on: https://github.com/succinctlabs/sp1/blob/2aed8fea16a67a5b2983ffc471b2942c2f2512c8/crates/zkvm/entrypoint/src/syscalls/sha_extend.rs#L12 /// Executes the SHA256 extend operation on the given word array. /// @@ -134,3 +138,133 @@ pub fn syscall_sha256_extend(w: *mut [u32; 64]) { #[cfg(not(target_os = "zkvm"))] unreachable!() } + +/// Adds two Bn254 points. +/// +/// The result is stored in the first point. +/// +/// ### Safety +/// +/// The caller must ensure that `p` and `q` are valid pointers to data that is aligned along a four +/// byte boundary. +#[allow(unused_variables)] +#[no_mangle] +pub extern "C" fn syscall_bn254_add(p: *mut [u32; 16], q: *const [u32; 16]) { + #[cfg(target_os = "zkvm")] + unsafe { + asm!( + "ecall", + in("t0") BN254_ADD, + in("a0") p, + in("a1") q, + ); + } + + #[cfg(not(target_os = "zkvm"))] + unreachable!() +} + +/// Double a Bn254 point. +/// +/// The result is stored in the first point. +/// +/// ### Safety +/// +/// The caller must ensure that `p` is valid pointer to data that is aligned along a four byte +/// boundary. +#[allow(unused_variables)] +#[no_mangle] +pub extern "C" fn syscall_bn254_double(p: *mut [u32; 16]) { + #[cfg(target_os = "zkvm")] + unsafe { + asm!( + "ecall", + in("t0") BN254_DOUBLE, + in("a0") p, + in("a1") 0, + ); + } + + #[cfg(not(target_os = "zkvm"))] + unreachable!() +} + +/// Fp addition operation. +/// +/// The result is written over the first input. +#[allow(unused_variables)] +#[no_mangle] +pub extern "C" fn syscall_bn254_fp_addmod(x: *mut u32, y: *const u32) { + #[cfg(target_os = "zkvm")] + unsafe { + asm!( + "ecall", + in("t0") BN254_FP_ADD, + in("a0") x, + in("a1") y, + ); + } + + #[cfg(not(target_os = "zkvm"))] + unreachable!() +} + +/// Fp multiplication operation. +/// +/// The result is written over the first input. +#[allow(unused_variables)] +#[no_mangle] +pub extern "C" fn syscall_bn254_fp_mulmod(x: *mut u32, y: *const u32) { + #[cfg(target_os = "zkvm")] + unsafe { + asm!( + "ecall", + in("t0") BN254_FP_MUL, + in("a0") x, + in("a1") y, + ); + } + + #[cfg(not(target_os = "zkvm"))] + unreachable!() +} + +/// BN254 Fp2 addition operation. +/// +/// The result is written over the first input. +#[allow(unused_variables)] +#[no_mangle] +pub extern "C" fn syscall_bn254_fp2_addmod(x: *mut u32, y: *const u32) { + #[cfg(target_os = "zkvm")] + unsafe { + asm!( + "ecall", + in("t0") BN254_FP2_ADD, + in("a0") x, + in("a1") y, + ); + } + + #[cfg(not(target_os = "zkvm"))] + unreachable!() +} + +/// BN254 Fp2 multiplication operation. +/// +/// The result is written over the first input. +#[allow(unused_variables)] +#[no_mangle] +pub extern "C" fn syscall_bn254_fp2_mulmod(x: *mut u32, y: *const u32) { + #[cfg(target_os = "zkvm")] + unsafe { + asm!( + "ecall", + in("t0") BN254_FP2_MUL, + in("a0") x, + in("a1") y, + ); + } + + #[cfg(not(target_os = "zkvm"))] + unreachable!() +} diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index 0325e786e..a4e572935 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use ceno_emul::{Change, InsnKind, StepRecord, SyscallSpec, WORD_SIZE}; +use ceno_emul::{Change, InsnKind, StepRecord, SyscallSpec}; use ff_ext::ExtensionField; use itertools::Itertools; @@ -31,7 +31,6 @@ impl Instruction for LargeEcallDummy fn name() -> String { format!("{}_DUMMY", S::NAME) } - fn construct_circuit(cb: &mut CircuitBuilder) -> Result { let dummy_insn = DummyConfig::construct_circuit( cb, @@ -59,15 +58,15 @@ impl Instruction for LargeEcallDummy .map(|i| { let val_before = cb.create_witin(|| format!("mem_before_{}", i)); let val_after = cb.create_witin(|| format!("mem_after_{}", i)); - + let addr = cb.create_witin(|| format!("addr_{}", i)); WriteMEM::construct_circuit( cb, - start_addr.expr() + (i * WORD_SIZE) as u64, + addr.expr(), val_before.expr(), val_after.expr(), dummy_insn.ts(), ) - .map(|writer| (Change::new(val_before, val_after), writer)) + .map(|writer| (addr, Change::new(val_before, val_after), writer)) }) .collect::, _>>()?; @@ -101,9 +100,10 @@ impl Instruction for LargeEcallDummy } // Assign memory. - for ((value, writer), op) in config.mem_writes.iter().zip_eq(&ops.mem_ops) { + for ((addr, value, writer), op) in config.mem_writes.iter().zip_eq(&ops.mem_ops) { set_val!(instance, value.before, op.value.before as u64); set_val!(instance, value.after, op.value.after as u64); + set_val!(instance, addr, u64::from(op.addr)); writer.assign_op(instance, lk_multiplicity, step.cycle(), op)?; } @@ -118,5 +118,5 @@ pub struct LargeEcallConfig { reg_writes: Vec<(UInt, WriteRD)>, start_addr: WitIn, - mem_writes: Vec<(Change, WriteMEM)>, + mem_writes: Vec<(WitIn, Change, WriteMEM)>, } diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index da1df2caf..0b6e5a824 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -24,6 +24,8 @@ use crate::{ }, }; use ceno_emul::{ + Bn254AddSpec, Bn254DoubleSpec, Bn254Fp2AddSpec, Bn254Fp2MulSpec, Bn254FpAddSpec, + Bn254FpMulSpec, InsnKind::{self, *}, KeccakSpec, Platform, Secp256k1AddSpec, Secp256k1DecompressSpec, Secp256k1DoubleSpec, Sha256ExtendSpec, StepRecord, SyscallSpec, @@ -451,6 +453,14 @@ pub struct DummyExtraConfig { as Instruction>::InstructionConfig, sha256_extend_config: as Instruction>::InstructionConfig, + bn254_add_config: as Instruction>::InstructionConfig, + bn254_double_config: as Instruction>::InstructionConfig, + bn254_fp_add_config: as Instruction>::InstructionConfig, + bn254_fp_mul_config: as Instruction>::InstructionConfig, + bn254_fp2_add_config: + as Instruction>::InstructionConfig, + bn254_fp2_mul_config: + as Instruction>::InstructionConfig, } impl DummyExtraConfig { @@ -465,6 +475,18 @@ impl DummyExtraConfig { cs.register_opcode_circuit::>(); let sha256_extend_config = cs.register_opcode_circuit::>(); + let bn254_add_config = cs.register_opcode_circuit::>(); + let bn254_double_config = + cs.register_opcode_circuit::>(); + let bn254_fp_add_config = + cs.register_opcode_circuit::>(); + let bn254_fp_mul_config = + cs.register_opcode_circuit::>(); + let bn254_fp2_add_config = + cs.register_opcode_circuit::>(); + let bn254_fp2_mul_config = + cs.register_opcode_circuit::>(); + Self { ecall_config, keccak_config, @@ -472,6 +494,12 @@ impl DummyExtraConfig { secp256k1_double_config, secp256k1_decompress_config, sha256_extend_config, + bn254_add_config, + bn254_double_config, + bn254_fp_add_config, + bn254_fp_mul_config, + bn254_fp2_add_config, + bn254_fp2_mul_config, } } @@ -486,6 +514,12 @@ impl DummyExtraConfig { fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); } pub fn assign_opcode_circuit( @@ -501,6 +535,12 @@ impl DummyExtraConfig { let mut secp256k1_double_steps = Vec::new(); let mut secp256k1_decompress_steps = Vec::new(); let mut sha256_extend_steps = Vec::new(); + let mut bn254_add_steps = Vec::new(); + let mut bn254_double_steps = Vec::new(); + let mut bn254_fp_add_steps = Vec::new(); + let mut bn254_fp_mul_steps = Vec::new(); + let mut bn254_fp2_add_steps = Vec::new(); + let mut bn254_fp2_mul_steps = Vec::new(); let mut other_steps = Vec::new(); if let Some(ecall_steps) = steps.remove(&ECALL) { @@ -511,6 +551,12 @@ impl DummyExtraConfig { Secp256k1DoubleSpec::CODE => secp256k1_double_steps.push(step), Secp256k1DecompressSpec::CODE => secp256k1_decompress_steps.push(step), Sha256ExtendSpec::CODE => sha256_extend_steps.push(step), + Bn254AddSpec::CODE => bn254_add_steps.push(step), + Bn254DoubleSpec::CODE => bn254_double_steps.push(step), + Bn254FpAddSpec::CODE => bn254_fp_add_steps.push(step), + Bn254FpMulSpec::CODE => bn254_fp_mul_steps.push(step), + Bn254Fp2AddSpec::CODE => bn254_fp2_add_steps.push(step), + Bn254Fp2MulSpec::CODE => bn254_fp2_mul_steps.push(step), _ => other_steps.push(step), } } @@ -541,6 +587,36 @@ impl DummyExtraConfig { &self.sha256_extend_config, sha256_extend_steps, )?; + witness.assign_opcode_circuit::>( + cs, + &self.bn254_add_config, + bn254_add_steps, + )?; + witness.assign_opcode_circuit::>( + cs, + &self.bn254_double_config, + bn254_double_steps, + )?; + witness.assign_opcode_circuit::>( + cs, + &self.bn254_fp_add_config, + bn254_fp_add_steps, + )?; + witness.assign_opcode_circuit::>( + cs, + &self.bn254_fp_mul_config, + bn254_fp_mul_steps, + )?; + witness.assign_opcode_circuit::>( + cs, + &self.bn254_fp2_add_config, + bn254_fp2_add_steps, + )?; + witness.assign_opcode_circuit::>( + cs, + &self.bn254_fp2_mul_config, + bn254_fp2_mul_steps, + )?; witness.assign_opcode_circuit::>(cs, &self.ecall_config, other_steps)?; let _ = steps.remove(&INVALID); diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 6be7ab795..6dccced38 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -11,7 +11,9 @@ version = "0.1.0" [dependencies] ceno_rt = { path = "../ceno_rt" } +rand.workspace = true rkyv = { version = "0.8", default-features = false, features = [ "alloc", "bytecheck", ] } +substrate-bn.workspace = true diff --git a/examples/examples/bn254_curve_syscalls.rs b/examples/examples/bn254_curve_syscalls.rs new file mode 100644 index 000000000..3f5715fe3 --- /dev/null +++ b/examples/examples/bn254_curve_syscalls.rs @@ -0,0 +1,75 @@ +// Test addition of two curve points. Assert result inside the guest +extern crate ceno_rt; +use ceno_rt::{ + info_out, + syscalls::{syscall_bn254_add, syscall_bn254_double}, +}; +use std::slice; + +use substrate_bn::{AffineG1, Fr, G1, Group}; +fn bytes_to_words(bytes: [u8; 64]) -> [u32; 16] { + let mut bytes = bytes; + // Reverse the order of bytes for each coordinate + bytes[0..32].reverse(); + bytes[32..].reverse(); + std::array::from_fn(|i| u32::from_le_bytes(bytes[4 * i..4 * (i + 1)].try_into().unwrap())) +} + +fn g1_to_words(elem: G1) -> [u32; 16] { + let elem = AffineG1::from_jacobian(elem).unwrap(); + let mut x_bytes = [0u8; 32]; + elem.x().to_big_endian(&mut x_bytes).unwrap(); + let mut y_bytes = [0u8; 32]; + elem.y().to_big_endian(&mut y_bytes).unwrap(); + + let mut bytes = [0u8; 64]; + bytes[..32].copy_from_slice(&x_bytes); + bytes[32..].copy_from_slice(&y_bytes); + + bytes_to_words(bytes) +} + +fn main() { + let log_flag = true; + let log_state = |state: &[u32]| { + if log_flag { + let out = unsafe { + slice::from_raw_parts(state.as_ptr() as *const u8, std::mem::size_of_val(state)) + }; + info_out().write_frame(out); + } + }; + + let a = G1::one() * Fr::from_str("237").unwrap(); + let b = G1::one() * Fr::from_str("450").unwrap(); + let mut a = g1_to_words(a); + let b = g1_to_words(b); + + log_state(&a); + log_state(&b); + + syscall_bn254_add(&mut a, &b); + + assert_eq!(a, [ + 3533671058, 384027398, 1667527989, 405931240, 1244739547, 3008185164, 3438692308, + 533547881, 4111479971, 1966599592, 1118334819, 3045025257, 3188923637, 1210932908, + 947531184, 656119894 + ]); + log_state(&a); + + let c = G1::one() * Fr::from_str("343").unwrap(); + let mut c = g1_to_words(c); + log_state(&c); + + syscall_bn254_double(&mut c); + log_state(&c); + + let one = g1_to_words(G1::one()); + log_state(&one); + + syscall_bn254_add(&mut c, &one); + log_state(&c); + + // 2 * 343 + 1 == 237 + 450, one hopes + assert_eq!(a, c); +} diff --git a/examples/examples/bn254_fptower_syscalls.rs b/examples/examples/bn254_fptower_syscalls.rs new file mode 100644 index 000000000..aabbf81f6 --- /dev/null +++ b/examples/examples/bn254_fptower_syscalls.rs @@ -0,0 +1,102 @@ +extern crate ceno_rt; +use ceno_rt::{ + info_out, + syscalls::{ + syscall_bn254_fp_addmod, syscall_bn254_fp_mulmod, syscall_bn254_fp2_addmod, + syscall_bn254_fp2_mulmod, + }, +}; +use rand::{SeedableRng, rngs::StdRng}; +use std::slice; +use substrate_bn::{Fq, Fq2}; + +fn bytes_to_words(bytes: [u8; 32]) -> [u32; 8] { + std::array::from_fn(|i| u32::from_le_bytes(bytes[4 * i..4 * (i + 1)].try_into().unwrap())) +} + +fn fq_to_words(val: Fq) -> [u32; 8] { + let mut bytes = [0u8; 32]; + val.to_big_endian(&mut bytes).unwrap(); + bytes.reverse(); + bytes_to_words(bytes) +} + +fn fq2_to_words(val: Fq2) -> [u32; 16] { + [fq_to_words(val.real()), fq_to_words(val.imaginary())] + .concat() + .try_into() + .unwrap() +} + +fn main() { + let log_flag = true; + + let log_state = |state: &[u32]| { + if log_flag { + let out = unsafe { + slice::from_raw_parts(state.as_ptr() as *const u8, std::mem::size_of_val(state)) + }; + info_out().write_frame(out); + } + }; + + let mut a = Fq::one(); + let mut b = Fq::one(); + let seed = [0u8; 32]; + let mut rng = StdRng::from_seed(seed); + const RUNS: usize = 10; + + for _ in 0..RUNS { + let mut a_words = fq_to_words(a); + + let a_backup = a_words; + let b_words = fq_to_words(b); + + log_state(&a_words); + log_state(&b_words); + syscall_bn254_fp_addmod(&mut a_words[0], &b_words[0]); + let sum_words = fq_to_words(a + b); + assert_eq!(a_words, sum_words); + log_state(&a_words); + + a_words.copy_from_slice(&a_backup); + + log_state(&a_words); + log_state(&b_words); + syscall_bn254_fp_mulmod(&mut a_words[0], &b_words[0]); + let prod_words = fq_to_words(a * b); + assert_eq!(a_words, prod_words); + log_state(&a_words); + + a = Fq::random(&mut rng); + b = Fq::random(&mut rng); + } + + let mut a = Fq2::one(); + let mut b = Fq2::one(); + + for _ in 0..RUNS { + let mut a_words = fq2_to_words(a); + let a_backup = a_words; + let b_words = fq2_to_words(b); + + log_state(&a_words); + log_state(&b_words); + syscall_bn254_fp2_addmod(&mut a_words[0], &b_words[0]); + let sum_words = fq2_to_words(a + b); + assert_eq!(a_words, sum_words); + log_state(&a_words); + + a_words.copy_from_slice(&a_backup); + + log_state(&a_words); + log_state(&b_words); + syscall_bn254_fp2_mulmod(&mut a_words[0], &b_words[0]); + let prod_words = fq2_to_words(a * b); + assert_eq!(a_words, prod_words); + log_state(&a_words); + + a = Fq2::new(Fq::random(&mut rng), Fq::random(&mut rng)); + b = Fq2::new(Fq::random(&mut rng), Fq::random(&mut rng)); + } +}