Skip to content

Commit

Permalink
Merge pull request #372 from Chia-Network/small-int
Browse files Browse the repository at this point in the history
`SmallAtom` optimization
  • Loading branch information
arvidn authored Feb 9, 2024
2 parents bf3c86c + 8f27e57 commit da3d0bd
Show file tree
Hide file tree
Showing 10 changed files with 773 additions and 189 deletions.
601 changes: 514 additions & 87 deletions src/allocator.rs

Large diffs are not rendered by default.

28 changes: 15 additions & 13 deletions src/chia_dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ impl Dialect for ChiaDialect {
max_cost: Cost,
extension: OperatorSet,
) -> Response {
let b = allocator.atom(o);
if b.len() == 4 {
let op_len = allocator.atom_len(o);
if op_len == 4 {
// these are unkown operators with assigned cost
// the formula is:
// +---+---+---+------------+
Expand All @@ -83,6 +83,7 @@ impl Dialect for ChiaDialect {
// (3 bytes) + 2 bits
// cost_function

let b = allocator.atom(o);
let opcode = u32::from_be_bytes(b.try_into().unwrap());

// the secp operators have a fixed cost of 1850000 and 1300000,
Expand All @@ -97,10 +98,13 @@ impl Dialect for ChiaDialect {
};
return f(allocator, argument_list, max_cost);
}
if b.len() != 1 {
if op_len != 1 {
return unknown_operator(allocator, o, argument_list, self.flags, max_cost);
}
let f = match b[0] {
let Some(op) = allocator.small_number(o) else {
return unknown_operator(allocator, o, argument_list, self.flags, max_cost);
};
let f = match op {
// 1 = quote
// 2 = apply
3 => op_if,
Expand Down Expand Up @@ -146,7 +150,7 @@ impl Dialect for ChiaDialect {
_ => {
if extension == OperatorSet::BLS || (self.flags & ENABLE_BLS_OPS_OUTSIDE_GUARD) != 0
{
match b[0] {
match op {
48 => op_coinid,
49 => op_bls_g1_subtract,
50 => op_bls_g1_multiply,
Expand Down Expand Up @@ -179,16 +183,14 @@ impl Dialect for ChiaDialect {
f(allocator, argument_list, max_cost)
}

fn quote_kw(&self) -> &[u8] {
&[1]
fn quote_kw(&self) -> u32 {
1
}

fn apply_kw(&self) -> &[u8] {
&[2]
fn apply_kw(&self) -> u32 {
2
}

fn softfork_kw(&self) -> &[u8] {
&[36]
fn softfork_kw(&self) -> u32 {
36
}

// interpret the extension argument passed to the softfork operator, and
Expand Down
6 changes: 3 additions & 3 deletions src/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ pub enum OperatorSet {
}

pub trait Dialect {
fn quote_kw(&self) -> &[u8];
fn apply_kw(&self) -> &[u8];
fn softfork_kw(&self) -> &[u8];
fn quote_kw(&self) -> u32;
fn apply_kw(&self) -> u32;
fn softfork_kw(&self) -> u32;
fn softfork_extension(&self, ext: u32) -> OperatorSet;
fn op(
&self,
Expand Down
77 changes: 61 additions & 16 deletions src/more_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::ops::BitAndAssign;
use std::ops::BitOrAssign;
use std::ops::BitXorAssign;

use crate::allocator::{Allocator, NodePtr, SExp};
use crate::allocator::{len_for_value, Allocator, NodePtr, NodeVisitor, SExp};
use crate::cost::{check_cost, Cost};
use crate::err_utils::err;
use crate::number::Number;
Expand Down Expand Up @@ -365,9 +365,21 @@ pub fn op_add(a: &mut Allocator, mut input: NodePtr, max_cost: Cost) -> Response
cost + (byte_count as Cost * ARITH_COST_PER_BYTE),
max_cost,
)?;
let (v, len) = int_atom(a, arg, "+")?;
byte_count += len;
total += v;

match a.node(arg) {
NodeVisitor::Buffer(buf) => {
use crate::number::number_from_u8;
total += number_from_u8(buf);
byte_count += buf.len();
}
NodeVisitor::U32(val) => {
total += val;
byte_count += len_for_value(val);
}
NodeVisitor::Pair(_, _) => {
return err(arg, "+ requires int args");
}
}
}
let total = a.new_number(total)?;
cost += byte_count as Cost * ARITH_COST_PER_BYTE;
Expand All @@ -383,12 +395,25 @@ pub fn op_subtract(a: &mut Allocator, mut input: NodePtr, max_cost: Cost) -> Res
input = rest;
cost += ARITH_COST_PER_ARG;
check_cost(a, cost + byte_count as Cost * ARITH_COST_PER_BYTE, max_cost)?;
let (v, len) = int_atom(a, arg, "-")?;
byte_count += len;
if is_first {
total += v;
let (v, len) = int_atom(a, arg, "-")?;
byte_count = len;
total = v;
} else {
total -= v;
match a.node(arg) {
NodeVisitor::Buffer(buf) => {
use crate::number::number_from_u8;
total -= number_from_u8(buf);
byte_count += buf.len();
}
NodeVisitor::U32(val) => {
total -= val;
byte_count += len_for_value(val);
}
NodeVisitor::Pair(_, _) => {
return err(arg, "- requires int args");
}
}
};
is_first = false;
}
Expand All @@ -411,14 +436,24 @@ pub fn op_multiply(a: &mut Allocator, mut input: NodePtr, max_cost: Cost) -> Res
continue;
}

let (v0, l1) = int_atom(a, arg, "*")?;
let l1 = match a.node(arg) {
NodeVisitor::Buffer(buf) => {
use crate::number::number_from_u8;
total *= number_from_u8(buf);
buf.len()
}
NodeVisitor::U32(val) => {
total *= val;
len_for_value(val)
}
NodeVisitor::Pair(_, _) => {
return err(arg, "* requires int args");
}
};

total *= v0;
cost += MUL_COST_PER_OP;

cost += (l0 + l1) as Cost * MUL_LINEAR_COST_PER_BYTE;
cost += (l0 * l1) as Cost / MUL_SQUARE_COST_PER_BYTE_DIVIDER;

l0 = limbs_for_int(&total);
}
let total = a.new_number(total)?;
Expand Down Expand Up @@ -490,10 +525,20 @@ pub fn op_mod(a: &mut Allocator, input: NodePtr, _max_cost: Cost) -> Response {

pub fn op_gr(a: &mut Allocator, input: NodePtr, _max_cost: Cost) -> Response {
let [v0, v1] = get_args::<2>(a, input, ">")?;
let (v0, v0_len) = int_atom(a, v0, ">")?;
let (v1, v1_len) = int_atom(a, v1, ">")?;
let cost = GR_BASE_COST + (v0_len + v1_len) as Cost * GR_COST_PER_BYTE;
Ok(Reduction(cost, if v0 > v1 { a.one() } else { a.nil() }))

match (a.small_number(v0), a.small_number(v1)) {
(Some(lhs), Some(rhs)) => {
let cost =
GR_BASE_COST + (len_for_value(lhs) + len_for_value(rhs)) as Cost * GR_COST_PER_BYTE;
Ok(Reduction(cost, if lhs > rhs { a.one() } else { a.nil() }))
}
_ => {
let (v0, v0_len) = int_atom(a, v0, ">")?;
let (v1, v1_len) = int_atom(a, v1, ">")?;
let cost = GR_BASE_COST + (v0_len + v1_len) as Cost * GR_COST_PER_BYTE;
Ok(Reduction(cost, if v0 > v1 { a.one() } else { a.nil() }))
}
}
}

pub fn op_gr_bytes(a: &mut Allocator, input: NodePtr, _max_cost: Cost) -> Response {
Expand Down
81 changes: 39 additions & 42 deletions src/op_utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::allocator::{Allocator, NodePtr, SExp};
use crate::allocator::{Allocator, NodePtr, NodeVisitor, SExp};
use crate::cost::Cost;
use crate::err_utils::err;
use crate::number::Number;
Expand Down Expand Up @@ -279,37 +279,36 @@ pub fn uint_atom<const SIZE: usize>(
args: NodePtr,
op_name: &str,
) -> Result<u64, EvalErr> {
let bytes = match a.sexp(args) {
SExp::Atom => a.atom(args),
_ => {
return err(args, &format!("{op_name} requires int arg"));
match a.node(args) {
NodeVisitor::Buffer(bytes) => {
if bytes.is_empty() {
return Ok(0);
}

if (bytes[0] & 0x80) != 0 {
return err(args, &format!("{op_name} requires positive int arg"));
}

// strip leading zeros
let mut buf: &[u8] = bytes;
while !buf.is_empty() && buf[0] == 0 {
buf = &buf[1..];
}

if buf.len() > SIZE {
return err(args, &format!("{op_name} requires u{} arg", SIZE * 8));
}

let mut ret = 0;
for b in buf {
ret <<= 8;
ret |= *b as u64;
}
Ok(ret)
}
};

if bytes.is_empty() {
return Ok(0);
}

if (bytes[0] & 0x80) != 0 {
return err(args, &format!("{op_name} requires positive int arg"));
}

// strip leading zeros
let mut buf: &[u8] = bytes;
while !buf.is_empty() && buf[0] == 0 {
buf = &buf[1..];
}

if buf.len() > SIZE {
return err(args, &format!("{op_name} requires u{} arg", SIZE * 8));
}

let mut ret = 0;
for b in buf {
ret <<= 8;
ret |= *b as u64;
NodeVisitor::U32(val) => Ok(val as u64),
NodeVisitor::Pair(_, _) => err(args, &format!("{op_name} requires int arg")),
}
Ok(ret)
}

#[cfg(test)]
Expand Down Expand Up @@ -532,18 +531,16 @@ fn test_u64_from_bytes() {
}

pub fn i32_atom(a: &Allocator, args: NodePtr, op_name: &str) -> Result<i32, EvalErr> {
let buf = match a.sexp(args) {
SExp::Atom => a.atom(args),
_ => {
return err(args, &format!("{op_name} requires int32 args"));
}
};
match i32_from_u8(buf) {
Some(v) => Ok(v),
_ => err(
args,
&format!("{op_name} requires int32 args (with no leading zeros)"),
),
match a.node(args) {
NodeVisitor::Buffer(buf) => match i32_from_u8(buf) {
Some(v) => Ok(v),
_ => err(
args,
&format!("{op_name} requires int32 args (with no leading zeros)"),
),
},
NodeVisitor::U32(val) => Ok(val as i32),
NodeVisitor::Pair(_, _) => err(args, &format!("{op_name} requires int32 args")),
}
}

Expand Down
Loading

0 comments on commit da3d0bd

Please sign in to comment.