Skip to content

Commit

Permalink
feat: eDSL support for all ALU operations + rename LT to SLTU (#483)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenh-axiom-xyz authored Oct 1, 2024
1 parent cad2916 commit 71b3f15
Show file tree
Hide file tree
Showing 14 changed files with 520 additions and 90 deletions.
66 changes: 58 additions & 8 deletions compiler/src/asm/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,21 +349,21 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> AsmCo
debug_info,
);
}
DslIr::AddU256(dst, lhs, rhs) => {
DslIr::Add256(dst, lhs, rhs) => {
self.push(
AsmInstruction::AddU256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()),
AsmInstruction::Add256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()),
debug_info,
);
}
DslIr::SubU256(dst, lhs, rhs) => {
DslIr::Sub256(dst, lhs, rhs) => {
self.push(
AsmInstruction::SubU256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()),
AsmInstruction::Sub256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()),
debug_info,
);
}
DslIr::MulU256(dst, lhs, rhs) => {
DslIr::Mul256(dst, lhs, rhs) => {
self.push(
AsmInstruction::MulU256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()),
AsmInstruction::Mul256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()),
debug_info,
);
}
Expand All @@ -373,9 +373,59 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> AsmCo
debug_info,
);
}
DslIr::EqualToU256(dst, lhs, rhs) => {
DslIr::EqualTo256(dst, lhs, rhs) => {
self.push(
AsmInstruction::EqualToU256(dst.fp(), lhs.ptr_fp(), rhs.ptr_fp()),
AsmInstruction::EqualTo256(dst.fp(), lhs.ptr_fp(), rhs.ptr_fp()),
debug_info,
);
}
DslIr::Xor256(dst, lhs, rhs) => {
self.push(
AsmInstruction::Xor256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()),
debug_info,
);
}
DslIr::And256(dst, lhs, rhs) => {
self.push(
AsmInstruction::And256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()),
debug_info,
);
}
DslIr::Or256(dst, lhs, rhs) => {
self.push(
AsmInstruction::Or256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()),
debug_info,
);
}
DslIr::LessThanI256(dst, lhs, rhs) => {
self.push(
AsmInstruction::LessThanI256(dst.fp(), lhs.ptr_fp(), rhs.ptr_fp()),
debug_info,
);
}
DslIr::ShiftLeft256(dst, lhs, rhs) => {
self.push(
AsmInstruction::ShiftLeft256(dst.ptr_fp(), lhs.ptr_fp(), rhs.ptr_fp()),
debug_info,
);
}
DslIr::ShiftRightLogic256(dst, lhs, rhs) => {
self.push(
AsmInstruction::ShiftRightLogic256(
dst.ptr_fp(),
lhs.ptr_fp(),
rhs.ptr_fp(),
),
debug_info,
);
}
DslIr::ShiftRightArith256(dst, lhs, rhs) => {
self.push(
AsmInstruction::ShiftRightArith256(
dst.ptr_fp(),
lhs.ptr_fp(),
rhs.ptr_fp(),
),
debug_info,
);
}
Expand Down
76 changes: 59 additions & 17 deletions compiler/src/asm/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,20 +96,41 @@ pub enum AsmInstruction<F, EF> {
/// Modular divide, dst = lhs / rhs.
DivSecp256k1Scalar(i32, i32, i32),

/// uint add, dst = lhs + rhs.
AddU256(i32, i32, i32),
/// int add, dst = lhs + rhs.
Add256(i32, i32, i32),

/// uint subtract, dst = lhs - rhs.
SubU256(i32, i32, i32),
/// int subtract, dst = lhs - rhs.
Sub256(i32, i32, i32),

/// uint multiply, dst = lhs * rhs.
MulU256(i32, i32, i32),
/// int multiply, dst = lhs * rhs.
Mul256(i32, i32, i32),

/// uint less than, dst = lhs < rhs.
LessThanU256(i32, i32, i32),

/// uint equal to, dst = lhs == rhs.
EqualToU256(i32, i32, i32),
/// int equal to, dst = lhs == rhs.
EqualTo256(i32, i32, i32),

/// int bitwise XOR, dst = lhs ^ rhs
Xor256(i32, i32, i32),

/// int bitwise AND, dst = lhs & rhs
And256(i32, i32, i32),

/// int bitwise OR, dst = lhs | rhs
Or256(i32, i32, i32),

/// signed int less than, dst = lhs < rhs
LessThanI256(i32, i32, i32),

/// int shift left, dst = lhs << rhs
ShiftLeft256(i32, i32, i32),

/// int shift right logical, dst = lhs >> rhs
ShiftRightLogic256(i32, i32, i32),

/// int shift right arithmetic, dst = lhs >> rhs
ShiftRightArith256(i32, i32, i32),

/// Jump.
Jump(i32, F),
Expand Down Expand Up @@ -473,20 +494,41 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
dst, src1, src2
)
}
AsmInstruction::AddU256(dst, src1, src2) => {
write!(f, "add_u256 ({})fp ({})fp ({})fp", dst, src1, src2)
AsmInstruction::Add256(dst, src1, src2) => {
write!(f, "add_256 ({})fp ({})fp ({})fp", dst, src1, src2)
}
AsmInstruction::SubU256(dst, src1, src2) => {
write!(f, "sub_u256 ({})fp ({})fp ({})fp", dst, src1, src2)
AsmInstruction::Sub256(dst, src1, src2) => {
write!(f, "sub_256 ({})fp ({})fp ({})fp", dst, src1, src2)
}
AsmInstruction::MulU256(dst, src1, src2) => {
write!(f, "mul_u256 ({})fp ({})fp ({})fp", dst, src1, src2)
AsmInstruction::Mul256(dst, src1, src2) => {
write!(f, "mul_256 ({})fp ({})fp ({})fp", dst, src1, src2)
}
AsmInstruction::LessThanU256(dst, src1, src2) => {
write!(f, "lt_u256 ({})fp ({})fp ({})fp", dst, src1, src2)
write!(f, "sltu_256 ({})fp ({})fp ({})fp", dst, src1, src2)
}
AsmInstruction::EqualTo256(dst, src1, src2) => {
write!(f, "eq_256 ({})fp ({})fp ({})fp", dst, src1, src2)
}
AsmInstruction::Xor256(dst, src1, src2) => {
write!(f, "xor_256 ({})fp ({})fp ({})fp", dst, src1, src2)
}
AsmInstruction::And256(dst, src1, src2) => {
write!(f, "and_256 ({})fp ({})fp ({})fp", dst, src1, src2)
}
AsmInstruction::Or256(dst, src1, src2) => {
write!(f, "or_256 ({})fp ({})fp ({})fp", dst, src1, src2)
}
AsmInstruction::LessThanI256(dst, src1, src2) => {
write!(f, "slt_256 ({})fp ({})fp ({})fp", dst, src1, src2)
}
AsmInstruction::ShiftLeft256(dst, src1, src2) => {
write!(f, "sll_256 ({})fp ({})fp ({})fp", dst, src1, src2)
}
AsmInstruction::ShiftRightLogic256(dst, src1, src2) => {
write!(f, "srl_256 ({})fp ({})fp ({})fp", dst, src1, src2)
}
AsmInstruction::EqualToU256(dst, src1, src2) => {
write!(f, "eq_u256 ({})fp ({})fp ({})fp", dst, src1, src2)
AsmInstruction::ShiftRightArith256(dst, src1, src2) => {
write!(f, "sra_256 ({})fp ({})fp ({})fp", dst, src1, src2)
}
}
}
Expand Down
76 changes: 62 additions & 14 deletions compiler/src/conversion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -759,53 +759,101 @@ fn convert_instruction<F: PrimeField32, EF: ExtensionField<F>>(
AS::Memory,
AS::Memory,
)],
AsmInstruction::AddU256(dst, src1, src2) => vec![inst_large(
AsmInstruction::Add256(dst, src1, src2) => vec![inst(
ADD256,
i32_f(dst),
i32_f(src1),
i32_f(src2),
AS::Memory,
AS::Memory,
AS::Memory.to_field(),
AS::Memory.to_field(),
)],
AsmInstruction::SubU256(dst, src1, src2) => vec![inst_large(
AsmInstruction::Sub256(dst, src1, src2) => vec![inst(
SUB256,
i32_f(dst),
i32_f(src1),
i32_f(src2),
AS::Memory,
AS::Memory,
AS::Memory.to_field(),
AS::Memory.to_field(),
)],
AsmInstruction::MulU256(dst, src1, src2) => vec![inst(
AsmInstruction::Mul256(dst, src1, src2) => vec![inst(
MUL256,
i32_f(dst),
i32_f(src1),
i32_f(src2),
AS::Memory,
AS::Memory,
)],
AsmInstruction::LessThanU256(dst, src1, src2) => vec![inst_large(
LT256,
AsmInstruction::LessThanU256(dst, src1, src2) => vec![inst(
SLTU256,
i32_f(dst),
i32_f(src1),
i32_f(src2),
AS::Memory,
AS::Memory,
AS::Memory.to_field(),
AS::Memory.to_field(),
)],
AsmInstruction::EqualToU256(dst, src1, src2) => vec![inst_large(
AsmInstruction::EqualTo256(dst, src1, src2) => vec![inst(
EQ256,
i32_f(dst),
i32_f(src1),
i32_f(src2),
AS::Memory,
AS::Memory,
AS::Memory.to_field(),
AS::Memory.to_field(),
)],
AsmInstruction::Xor256(dst, src1, src2) => vec![inst(
XOR256,
i32_f(dst),
i32_f(src1),
i32_f(src2),
AS::Memory,
AS::Memory,
)],
AsmInstruction::And256(dst, src1, src2) => vec![inst(
AND256,
i32_f(dst),
i32_f(src1),
i32_f(src2),
AS::Memory,
AS::Memory,
)],
AsmInstruction::Or256(dst, src1, src2) => vec![inst(
OR256,
i32_f(dst),
i32_f(src1),
i32_f(src2),
AS::Memory,
AS::Memory,
)],
AsmInstruction::LessThanI256(dst, src1, src2) => vec![inst(
SLT256,
i32_f(dst),
i32_f(src1),
i32_f(src2),
AS::Memory,
AS::Memory,
)],
AsmInstruction::ShiftLeft256(dst, src1, src2) => vec![inst(
SLL256,
i32_f(dst),
i32_f(src1),
i32_f(src2),
AS::Memory,
AS::Memory,
)],
AsmInstruction::ShiftRightLogic256(dst, src1, src2) => vec![inst(
SRL256,
i32_f(dst),
i32_f(src1),
i32_f(src2),
AS::Memory,
AS::Memory,
)],
AsmInstruction::ShiftRightArith256(dst, src1, src2) => vec![inst(
SRA256,
i32_f(dst),
i32_f(src1),
i32_f(src2),
AS::Memory,
AS::Memory,
)],
AsmInstruction::Keccak256(dst, src, len) => vec![inst_med(
KECCAK256,
Expand Down
35 changes: 27 additions & 8 deletions compiler/src/ir/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ pub enum DslIr<C: Config> {
AddSecp256k1Coord(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),
/// Add two modular BigInts over scalar field.
AddSecp256k1Scalar(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),
/// Add two u256
AddU256(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),
/// Add two 256-bit integers
Add256(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),

// Subtractions.
/// Subtracts two variables (var = var - var).
Expand Down Expand Up @@ -68,8 +68,8 @@ pub enum DslIr<C: Config> {
SubSecp256k1Coord(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),
/// Subtracts two modular BigInts over scalar field.
SubSecp256k1Scalar(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),
/// Subtract two u256
SubU256(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),
/// Subtract two 256-bit integers
Sub256(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),

// Multiplications.
/// Multiplies two variables (var = var * var).
Expand All @@ -92,8 +92,8 @@ pub enum DslIr<C: Config> {
MulSecp256k1Coord(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),
/// Multiplies two modular BigInts over scalar field.
MulSecp256k1Scalar(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),
/// Multiply two u256
MulU256(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),
/// Multiply two 256-bit integers
Mul256(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),

// Divisions.
/// Divides two variables (var = var / var).
Expand Down Expand Up @@ -132,8 +132,27 @@ pub enum DslIr<C: Config> {
LessThanVI(Var<C::N>, Var<C::N>, C::N),
/// Compare two u256 for <
LessThanU256(Ptr<C::N>, BigUintVar<C>, BigUintVar<C>),
/// Compare two u256 for ==
EqualToU256(Ptr<C::N>, BigUintVar<C>, BigUintVar<C>),
/// Compare two 256-bit integers for ==
EqualTo256(Ptr<C::N>, BigUintVar<C>, BigUintVar<C>),
/// Compare two signed 256-bit integers for <
LessThanI256(Ptr<C::N>, BigUintVar<C>, BigUintVar<C>),

// Bitwise operations.
/// Bitwise XOR on two 256-bit integers
Xor256(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),
/// Bitwise AND on two 256-bit integers
And256(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),
/// Bitwise OR on two 256-bit integers
Or256(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),

// Shifts.
/// Shift left on 256-bit integers
ShiftLeft256(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),
/// Shift right logical on 256-bit integers
ShiftRightLogic256(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),
/// Shift right arithmetic on 256-bit integers
ShiftRightArith256(BigUintVar<C>, BigUintVar<C>, BigUintVar<C>),

// =======

// Control flow.
Expand Down
Loading

0 comments on commit 71b3f15

Please sign in to comment.