diff --git a/Cargo.lock b/Cargo.lock index 5b9f76c3032..ef34de28d60 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5886,6 +5886,7 @@ dependencies = [ "fuel-tx", "lazy_static", "num-bigint", + "num-traits", "serde", "thiserror", ] diff --git a/sway-core/src/ir_generation/const_eval.rs b/sway-core/src/ir_generation/const_eval.rs index d9e4504fcef..5777020d0d1 100644 --- a/sway-core/src/ir_generation/const_eval.rs +++ b/sway-core/src/ir_generation/const_eval.rs @@ -1,4 +1,4 @@ -use std::ops::{BitAnd, BitOr, BitXor}; +use std::ops::{BitAnd, BitOr, BitXor, Not, Rem}; use crate::{ asm_generation::from_ir::{ir_type_size_in_bytes, ir_type_str_size_in_bytes}, @@ -730,82 +730,124 @@ fn const_eval_intrinsic( assert!(args.len() == intrinsic.arguments.len()); match intrinsic.kind { - sway_ast::Intrinsic::Add - | sway_ast::Intrinsic::Sub - | sway_ast::Intrinsic::Mul - | sway_ast::Intrinsic::Div - | sway_ast::Intrinsic::And - | sway_ast::Intrinsic::Or - | sway_ast::Intrinsic::Xor - | sway_ast::Intrinsic::Mod => { + Intrinsic::Add + | Intrinsic::Sub + | Intrinsic::Mul + | Intrinsic::Div + | Intrinsic::And + | Intrinsic::Or + | Intrinsic::Xor + | Intrinsic::Mod => { let ty = args[0].ty; - assert!( - args.len() == 2 && ty.is_uint(lookup.context) && ty.eq(lookup.context, &args[1].ty) - ); - let (ConstantValue::Uint(arg1), ConstantValue::Uint(ref arg2)) = - (&args[0].value, &args[1].value) - else { - panic!("Type checker allowed incorrect args to binary op"); - }; - - // All arithmetic is done as if it were u64 - let result = match intrinsic.kind { - Intrinsic::Add => arg1.checked_add(*arg2), - Intrinsic::Sub => arg1.checked_sub(*arg2), - Intrinsic::Mul => arg1.checked_mul(*arg2), - Intrinsic::Div => arg1.checked_div(*arg2), - Intrinsic::And => Some(arg1.bitand(arg2)), - Intrinsic::Or => Some(arg1.bitor(*arg2)), - Intrinsic::Xor => Some(arg1.bitxor(*arg2)), - Intrinsic::Mod => arg1.checked_rem(*arg2), - _ => unreachable!(), - }; - - match result { - Some(sum) => Ok(Some(Constant { - ty, - value: ConstantValue::Uint(sum), - })), - None => Err(ConstEvalError::CannotBeEvaluatedToConst { - span: intrinsic.span.clone(), - }), + assert!(args.len() == 2 && ty.eq(lookup.context, &args[1].ty)); + + use ConstantValue::*; + match (&args[0].value, &args[1].value) { + (Uint(arg1), Uint(ref arg2)) => { + // All arithmetic is done as if it were u64 + let result = match intrinsic.kind { + Intrinsic::Add => arg1.checked_add(*arg2), + Intrinsic::Sub => arg1.checked_sub(*arg2), + Intrinsic::Mul => arg1.checked_mul(*arg2), + Intrinsic::Div => arg1.checked_div(*arg2), + Intrinsic::And => Some(arg1.bitand(arg2)), + Intrinsic::Or => Some(arg1.bitor(*arg2)), + Intrinsic::Xor => Some(arg1.bitxor(*arg2)), + Intrinsic::Mod => arg1.checked_rem(*arg2), + _ => unreachable!(), + }; + + match result { + Some(sum) => Ok(Some(Constant { + ty, + value: ConstantValue::Uint(sum), + })), + None => Err(ConstEvalError::CannotBeEvaluatedToConst { + span: intrinsic.span.clone(), + }), + } + } + (U256(arg1), U256(arg2)) => { + let result = match intrinsic.kind { + Intrinsic::Add => arg1.checked_add(arg2), + Intrinsic::Sub => arg1.checked_sub(arg2), + Intrinsic::Mul => arg1.checked_mul(arg2), + Intrinsic::Div => arg1.checked_div(arg2), + Intrinsic::And => Some(arg1.bitand(arg2)), + Intrinsic::Or => Some(arg1.bitor(arg2)), + Intrinsic::Xor => Some(arg1.bitxor(arg2)), + Intrinsic::Mod => Some(arg1.rem(arg2)), + _ => unreachable!(), + }; + + match result { + Some(sum) => Ok(Some(Constant { + ty, + value: ConstantValue::U256(sum), + })), + None => Err(ConstEvalError::CannotBeEvaluatedToConst { + span: intrinsic.span.clone(), + }), + } + } + _ => { + panic!("Type checker allowed incorrect args to binary op"); + } } } - sway_ast::Intrinsic::Lsh | sway_ast::Intrinsic::Rsh => { - let ty = args[0].ty; - assert!( - args.len() == 2 - && ty.is_uint(lookup.context) - && args[1].ty.is_uint64(lookup.context) - ); - - let (ConstantValue::Uint(arg1), ConstantValue::Uint(ref arg2)) = - (&args[0].value, &args[1].value) - else { - panic!("Type checker allowed incorrect args to binary op"); - }; + Intrinsic::Lsh | Intrinsic::Rsh => { + assert!(args.len() == 2); + assert!(args[0].ty.is_uint(lookup.context)); + assert!(args[1].ty.is_uint64(lookup.context)); - let result = match intrinsic.kind { - Intrinsic::Lsh => u32::try_from(*arg2) - .ok() - .and_then(|arg2| arg1.checked_shl(arg2)), - Intrinsic::Rsh => u32::try_from(*arg2) - .ok() - .and_then(|arg2| arg1.checked_shr(arg2)), - _ => unreachable!(), - }; + let ty = args[0].ty; - match result { - Some(sum) => Ok(Some(Constant { - ty, - value: ConstantValue::Uint(sum), - })), - None => Err(ConstEvalError::CannotBeEvaluatedToConst { - span: intrinsic.span.clone(), - }), + use ConstantValue::*; + match (&args[0].value, &args[1].value) { + (Uint(arg1), Uint(ref arg2)) => { + let result = match intrinsic.kind { + Intrinsic::Lsh => u32::try_from(*arg2) + .ok() + .and_then(|arg2| arg1.checked_shl(arg2)), + Intrinsic::Rsh => u32::try_from(*arg2) + .ok() + .and_then(|arg2| arg1.checked_shr(arg2)), + _ => unreachable!(), + }; + + match result { + Some(sum) => Ok(Some(Constant { + ty, + value: ConstantValue::Uint(sum), + })), + None => Err(ConstEvalError::CannotBeEvaluatedToConst { + span: intrinsic.span.clone(), + }), + } + } + (U256(arg1), Uint(ref arg2)) => { + let result = match intrinsic.kind { + Intrinsic::Lsh => arg1.checked_shl(arg2), + Intrinsic::Rsh => Some(arg1.shr(arg2)), + _ => unreachable!(), + }; + + match result { + Some(value) => Ok(Some(Constant { + ty, + value: ConstantValue::U256(value), + })), + None => Err(ConstEvalError::CannotBeEvaluatedToConst { + span: intrinsic.span.clone(), + }), + } + } + _ => { + panic!("Type checker allowed incorrect args to binary op"); + } } } - sway_ast::Intrinsic::SizeOfType => { + Intrinsic::SizeOfType => { let targ = &intrinsic.type_arguments[0]; let ir_type = convert_resolved_typeid( lookup.engines.te(), @@ -820,7 +862,7 @@ fn const_eval_intrinsic( value: ConstantValue::Uint(ir_type_size_in_bytes(lookup.context, &ir_type)), })) } - sway_ast::Intrinsic::SizeOfVal => { + Intrinsic::SizeOfVal => { let val = &intrinsic.arguments[0]; let type_id = val.return_type; let ir_type = convert_resolved_typeid( @@ -836,7 +878,7 @@ fn const_eval_intrinsic( value: ConstantValue::Uint(ir_type_size_in_bytes(lookup.context, &ir_type)), })) } - sway_ast::Intrinsic::SizeOfStr => { + Intrinsic::SizeOfStr => { let targ = &intrinsic.type_arguments[0]; let ir_type = convert_resolved_typeid( lookup.engines.te(), @@ -851,7 +893,7 @@ fn const_eval_intrinsic( value: ConstantValue::Uint(ir_type_str_size_in_bytes(lookup.context, &ir_type)), })) } - sway_ast::Intrinsic::CheckStrType => { + Intrinsic::CheckStrType => { let targ = &intrinsic.type_arguments[0]; let ir_type = convert_resolved_typeid( lookup.engines.te(), @@ -873,77 +915,88 @@ fn const_eval_intrinsic( )), } } - sway_ast::Intrinsic::Eq => { + Intrinsic::Eq => { assert!(args.len() == 2); Ok(Some(Constant { ty: Type::get_bool(lookup.context), value: ConstantValue::Bool(args[0].eq(lookup.context, &args[1])), })) } - sway_ast::Intrinsic::Gt => { - let (ConstantValue::Uint(val1), ConstantValue::Uint(val2)) = - (&args[0].value, &args[1].value) - else { - unreachable!("Type checker allowed non integer value for GreaterThan") - }; - Ok(Some(Constant { + Intrinsic::Gt => match (&args[0].value, &args[1].value) { + (ConstantValue::Uint(val1), ConstantValue::Uint(val2)) => Ok(Some(Constant { ty: Type::get_bool(lookup.context), value: ConstantValue::Bool(val1 > val2), - })) - } - sway_ast::Intrinsic::Lt => { - let (ConstantValue::Uint(val1), ConstantValue::Uint(val2)) = - (&args[0].value, &args[1].value) - else { - unreachable!("Type checker allowed non integer value for LessThan") - }; - Ok(Some(Constant { + })), + (ConstantValue::U256(val1), ConstantValue::U256(val2)) => Ok(Some(Constant { + ty: Type::get_bool(lookup.context), + value: ConstantValue::Bool(val1 > val2), + })), + _ => { + unreachable!("Type checker allowed non integer value for GreaterThan") + } + }, + Intrinsic::Lt => match (&args[0].value, &args[1].value) { + (ConstantValue::Uint(val1), ConstantValue::Uint(val2)) => Ok(Some(Constant { ty: Type::get_bool(lookup.context), value: ConstantValue::Bool(val1 < val2), - })) - } - sway_ast::Intrinsic::AddrOf - | sway_ast::Intrinsic::PtrAdd - | sway_ast::Intrinsic::PtrSub - | sway_ast::Intrinsic::IsReferenceType - | sway_ast::Intrinsic::IsStrType - | sway_ast::Intrinsic::Gtf - | sway_ast::Intrinsic::StateClear - | sway_ast::Intrinsic::StateLoadWord - | sway_ast::Intrinsic::StateStoreWord - | sway_ast::Intrinsic::StateLoadQuad - | sway_ast::Intrinsic::StateStoreQuad - | sway_ast::Intrinsic::Log - | sway_ast::Intrinsic::Revert - | sway_ast::Intrinsic::Smo => Err(ConstEvalError::CannotBeEvaluatedToConst { + })), + (ConstantValue::U256(val1), ConstantValue::U256(val2)) => Ok(Some(Constant { + ty: Type::get_bool(lookup.context), + value: ConstantValue::Bool(val1 < val2), + })), + _ => { + unreachable!("Type checker allowed non integer value for LessThan") + } + }, + Intrinsic::AddrOf + | Intrinsic::PtrAdd + | Intrinsic::PtrSub + | Intrinsic::IsReferenceType + | Intrinsic::IsStrType + | Intrinsic::Gtf + | Intrinsic::StateClear + | Intrinsic::StateLoadWord + | Intrinsic::StateStoreWord + | Intrinsic::StateLoadQuad + | Intrinsic::StateStoreQuad + | Intrinsic::Log + | Intrinsic::Revert + | Intrinsic::Smo => Err(ConstEvalError::CannotBeEvaluatedToConst { span: intrinsic.span.clone(), }), - sway_ast::Intrinsic::Not => { - // Not works only with uint at the moment + Intrinsic::Not => { + // `not` works only with uint/u256 at the moment // `bool` ops::Not implementation uses `__eq`. - assert!(args.len() == 1 && args[0].ty.is_uint(lookup.context)); + assert!(args.len() == 1); + assert!(args[0].ty.is_uint(lookup.context)); let Some(arg) = args.into_iter().next() else { unreachable!("Unexpected 'not' without any arguments"); }; - let ConstantValue::Uint(v) = arg.value else { - unreachable!("Type checker allowed non integer value for Not"); - }; - - let v = match arg.ty.get_uint_width(lookup.context) { - Some(8) => !(v as u8) as u64, - Some(16) => !(v as u16) as u64, - Some(32) => !(v as u32) as u64, - Some(64) => !v, - _ => unreachable!("Invalid unsigned integer width"), - }; - - Ok(Some(Constant { - ty: arg.ty, - value: ConstantValue::Uint(v), - })) + match arg.value { + ConstantValue::Uint(v) => { + let v = match arg.ty.get_uint_width(lookup.context) { + Some(8) => !(v as u8) as u64, + Some(16) => !(v as u16) as u64, + Some(32) => !(v as u32) as u64, + Some(64) => !v, + _ => unreachable!("Invalid unsigned integer width"), + }; + Ok(Some(Constant { + ty: arg.ty, + value: ConstantValue::Uint(v), + })) + } + ConstantValue::U256(v) => Ok(Some(Constant { + ty: arg.ty, + value: ConstantValue::U256(v.not()), + })), + _ => { + unreachable!("Type checker allowed non integer value for Not"); + } + } } } } @@ -1054,6 +1107,43 @@ mod tests { assert_is_constant(true, "", "(0,1).0"); assert_is_constant(true, "", "[0,1][0]"); + // u256 + assert_is_constant( + true, + "", + "0x0000000000000000000000000000000000000000000000000000000000000001u256", + ); + assert_is_constant( + true, + "", + "__add(0x0000000000000000000000000000000000000000000000000000000000000001u256, 0x0000000000000000000000000000000000000000000000000000000000000001u256)", + ); + assert_is_constant( + true, + "", + "__eq(0x0000000000000000000000000000000000000000000000000000000000000001u256, 0x0000000000000000000000000000000000000000000000000000000000000001u256)", + ); + assert_is_constant( + true, + "", + "__gt(0x0000000000000000000000000000000000000000000000000000000000000001u256, 0x0000000000000000000000000000000000000000000000000000000000000001u256)", + ); + assert_is_constant( + true, + "", + "__lt(0x0000000000000000000000000000000000000000000000000000000000000001u256, 0x0000000000000000000000000000000000000000000000000000000000000001u256)", + ); + assert_is_constant( + true, + "", + "__lsh(0x0000000000000000000000000000000000000000000000000000000000000001u256, 2)", + ); + assert_is_constant( + true, + "", + "__not(0x0000000000000000000000000000000000000000000000000000000000000001u256)", + ); + // Expressions that cannot be converted to constant assert_is_constant(false, "", "{ return 1; }"); assert_is_constant(false, "", "{ return 1; 1}"); diff --git a/sway-ir/src/optimize/constants.rs b/sway-ir/src/optimize/constants.rs index e371530df74..4a5d4b1ac64 100644 --- a/sway-ir/src/optimize/constants.rs +++ b/sway-ir/src/optimize/constants.rs @@ -119,25 +119,33 @@ fn combine_cmp(context: &mut Context, function: &Function) -> bool { { let val1 = val1.get_constant(context).unwrap(); let val2 = val2.get_constant(context).unwrap(); + + use ConstantValue::*; match pred { Predicate::Equal => Some((inst_val, block, val1.eq(context, val2))), Predicate::GreaterThan => { - let (ConstantValue::Uint(val1), ConstantValue::Uint(val2)) = - (&val1.value, &val2.value) - else { - unreachable!( - "Type checker allowed non integer value for GreaterThan" - ) + let r = match (&val1.value, &val2.value) { + (Uint(val1), Uint(val2)) => val1 > val2, + (U256(val1), U256(val2)) => val1 > val2, + _ => { + unreachable!( + "Type checker allowed non integer value for GreaterThan" + ) + } }; - Some((inst_val, block, val1 > val2)) + Some((inst_val, block, r)) } Predicate::LessThan => { - let (ConstantValue::Uint(val1), ConstantValue::Uint(val2)) = - (&val1.value, &val2.value) - else { - unreachable!("Type checker allowed non integer value for LessThan") + let r = match (&val1.value, &val2.value) { + (Uint(val1), Uint(val2)) => val1 < val2, + (U256(val1), U256(val2)) => val1 < val2, + _ => { + unreachable!( + "Type checker allowed non integer value for GreaterThan" + ) + } }; - Some((inst_val, block, val1 < val2)) + Some((inst_val, block, r)) } } } @@ -166,63 +174,45 @@ fn combine_binary_op(context: &mut Context, function: &Function) -> bool { { let val1 = arg1.get_constant(context).unwrap(); let val2 = arg2.get_constant(context).unwrap(); - let v = match op { - crate::BinaryOpKind::Add => match (&val1.value, &val2.value) { - (ConstantValue::Uint(l), ConstantValue::Uint(r)) => l.checked_add(*r), - _ => None, - }, - crate::BinaryOpKind::Sub => match (&val1.value, &val2.value) { - (ConstantValue::Uint(l), ConstantValue::Uint(r)) => l.checked_sub(*r), - _ => None, - }, - crate::BinaryOpKind::Mul => match (&val1.value, &val2.value) { - (ConstantValue::Uint(l), ConstantValue::Uint(r)) => l.checked_mul(*r), - _ => None, - }, - crate::BinaryOpKind::Div => match (&val1.value, &val2.value) { - (ConstantValue::Uint(l), ConstantValue::Uint(r)) => l.checked_div(*r), - _ => None, - }, - crate::BinaryOpKind::And => match (&val1.value, &val2.value) { - (ConstantValue::Uint(l), ConstantValue::Uint(r)) => Some(l & r), - _ => None, - }, - crate::BinaryOpKind::Or => match (&val1.value, &val2.value) { - (ConstantValue::Uint(l), ConstantValue::Uint(r)) => Some(l | r), - _ => None, - }, - crate::BinaryOpKind::Xor => match (&val1.value, &val2.value) { - (ConstantValue::Uint(l), ConstantValue::Uint(r)) => Some(l ^ r), - _ => None, - }, - crate::BinaryOpKind::Mod => match (&val1.value, &val2.value) { - (ConstantValue::Uint(l), ConstantValue::Uint(r)) => Some(l % r), - _ => None, - }, - crate::BinaryOpKind::Rsh => match (&val1.value, &val2.value) { - (ConstantValue::Uint(l), ConstantValue::Uint(r)) => { - u32::try_from(*r).ok().and_then(|r| l.checked_shr(r)) - } - _ => None, - }, - crate::BinaryOpKind::Lsh => match (&val1.value, &val2.value) { - (ConstantValue::Uint(l), ConstantValue::Uint(r)) => { - u32::try_from(*r).ok().and_then(|r| l.checked_shl(r)) - } - _ => None, - }, - }; + use crate::BinaryOpKind::*; + use ConstantValue::*; + let v = match (op, &val1.value, &val2.value) { + (Add, Uint(l), Uint(r)) => l.checked_add(*r).map(Uint), + (Add, U256(l), U256(r)) => l.checked_add(r).map(U256), - v.map(|v| { - ( - inst_val, - block, - Constant { - ty: val1.ty, - value: ConstantValue::Uint(v), - }, - ) - }) + (Sub, Uint(l), Uint(r)) => l.checked_sub(*r).map(Uint), + (Sub, U256(l), U256(r)) => l.checked_sub(r).map(U256), + + (Mul, Uint(l), Uint(r)) => l.checked_mul(*r).map(Uint), + (Mul, U256(l), U256(r)) => l.checked_mul(r).map(U256), + + (Div, Uint(l), Uint(r)) => l.checked_div(*r).map(Uint), + (Div, U256(l), U256(r)) => l.checked_div(r).map(U256), + + (And, Uint(l), Uint(r)) => Some(Uint(l & r)), + (And, U256(l), U256(r)) => Some(U256(l & r)), + + (Or, Uint(l), Uint(r)) => Some(Uint(l | r)), + (Or, U256(l), U256(r)) => Some(U256(l | r)), + + (Xor, Uint(l), Uint(r)) => Some(Uint(l ^ r)), + (Xor, U256(l), U256(r)) => Some(U256(l ^ r)), + + (Mod, Uint(l), Uint(r)) => Some(Uint(l % r)), + (Mod, U256(l), U256(r)) => Some(U256(l % r)), + + (Rsh, Uint(l), Uint(r)) => u32::try_from(*r) + .ok() + .and_then(|r| l.checked_shr(r).map(Uint)), + (Rsh, U256(l), Uint(r)) => Some(U256(l.shr(r))), + + (Lsh, Uint(l), Uint(r)) => u32::try_from(*r) + .ok() + .and_then(|r| l.checked_shl(r).map(Uint)), + (Lsh, U256(l), Uint(r)) => l.checked_shl(r).map(U256), + _ => None, + }; + v.map(|value| (inst_val, block, Constant { ty: val1.ty, value })) } _ => None, }, @@ -245,30 +235,23 @@ fn combine_unary_op(context: &mut Context, function: &Function) -> bool { if arg.is_constant(context) => { let val = arg.get_constant(context).unwrap(); - match op { - crate::UnaryOpKind::Not => match &val.value { - ConstantValue::Uint(v) => { - val.ty.get_uint_width(context).and_then(|width| { - let max = match width { - 8 => u8::MAX as u64, - 16 => u16::MAX as u64, - 32 => u32::MAX as u64, - 64 => u64::MAX, - _ => return None, - }; - Some(( - inst_val, - block, - Constant { - ty: val.ty, - value: ConstantValue::Uint((!v) & max), - }, - )) - }) - } - _ => None, - }, - } + use crate::UnaryOpKind::*; + use ConstantValue::*; + let v = match (op, &val.value) { + (Not, Uint(v)) => val.ty.get_uint_width(context).and_then(|width| { + let max = match width { + 8 => u8::MAX as u64, + 16 => u16::MAX as u64, + 32 => u32::MAX as u64, + 64 => u64::MAX, + _ => return None, + }; + Some(Uint((!v) & max)) + }), + (Not, U256(v)) => Some(U256(!v)), + _ => None, + }; + v.map(|value| (inst_val, block, Constant { ty: val.ty, value })) } _ => None, }, @@ -286,20 +269,20 @@ fn combine_unary_op(context: &mut Context, function: &Function) -> bool { mod tests { use crate::optimize::tests::*; - fn assert_operator(opcode: &str, l: &str, r: Option<&str>, result: Option<&str>) { - let expected = result.map(|result| format!("v0 = const u64 {result}")); + fn assert_operator(t: &str, opcode: &str, l: &str, r: Option<&str>, result: Option<&str>) { + let expected = result.map(|result| format!("v0 = const {t} {result}")); let expected = expected.as_ref().map(|x| vec![x.as_str()]); let body = format!( " - entry fn main() -> u64 {{ + entry fn main() -> {t} {{ entry(): - l = const u64 {l} + l = const {t} {l} {r_inst} result = {opcode} l, {result_inst} !0 - ret u64 result + ret {t} result }} ", - r_inst = r.map_or("".into(), |r| format!("r = const u64 {r}")), + r_inst = r.map_or("".into(), |r| format!("r = const {t} {r}")), result_inst = r.map_or("", |_| " r,") ); assert_optimization(&["constcombine"], &body, expected); @@ -307,26 +290,29 @@ mod tests { #[test] fn unary_op_are_optimized() { - assert_operator("not", &u64::MAX.to_string(), None, Some("0")); + assert_operator("u64", "not", &u64::MAX.to_string(), None, Some("0")); } #[test] fn binary_op_are_optimized() { - assert_operator("add", "1", Some("1"), Some("2")); - assert_operator("sub", "1", Some("1"), Some("0")); - assert_operator("mul", "2", Some("2"), Some("4")); - assert_operator("div", "10", Some("5"), Some("2")); - assert_operator("mod", "12", Some("5"), Some("2")); - assert_operator("rsh", "16", Some("1"), Some("8")); - assert_operator("lsh", "16", Some("1"), Some("32")); + // u64 + assert_operator("u64", "add", "1", Some("1"), Some("2")); + assert_operator("u64", "sub", "1", Some("1"), Some("0")); + assert_operator("u64", "mul", "2", Some("2"), Some("4")); + assert_operator("u64", "div", "10", Some("5"), Some("2")); + assert_operator("u64", "mod", "12", Some("5"), Some("2")); + assert_operator("u64", "rsh", "16", Some("1"), Some("8")); + assert_operator("u64", "lsh", "16", Some("1"), Some("32")); assert_operator( + "u64", "and", &0x00FFF.to_string(), Some(&0xFFF00.to_string()), Some(&0xF00.to_string()), ); assert_operator( + "u64", "or", &0x00FFF.to_string(), Some(&0xFFF00.to_string()), @@ -334,6 +320,7 @@ mod tests { ); assert_operator( + "u64", "xor", &0x00FFF.to_string(), Some(&0xFFF00.to_string()), @@ -343,13 +330,13 @@ mod tests { #[test] fn binary_op_are_not_optimized() { - assert_operator("add", &u64::MAX.to_string(), Some("1"), None); - assert_operator("sub", "0", Some("1"), None); - assert_operator("mul", &u64::MAX.to_string(), Some("2"), None); - assert_operator("div", "1", Some("0"), None); + assert_operator("u64", "add", &u64::MAX.to_string(), Some("1"), None); + assert_operator("u64", "sub", "0", Some("1"), None); + assert_operator("u64", "mul", &u64::MAX.to_string(), Some("2"), None); + assert_operator("u64", "div", "1", Some("0"), None); - assert_operator("rsh", "1", Some("64"), None); - assert_operator("lsh", "1", Some("64"), None); + assert_operator("u64", "rsh", "1", Some("64"), None); + assert_operator("u64", "lsh", "1", Some("64"), None); } #[test] diff --git a/sway-ir/tests/constants/u256_cmp.ir b/sway-ir/tests/constants/u256_cmp.ir new file mode 100644 index 00000000000..5f4642751d9 --- /dev/null +++ b/sway-ir/tests/constants/u256_cmp.ir @@ -0,0 +1,11 @@ +script { + fn main() -> bool { + entry(): + v0 = const u256 0x0000000000000000000000000000000000000000000000000000000000000000 + v1 = const u256 0x0000000000000000000000000000000000000000000000000000000000000001 + + v10 = cmp eq v0 v0 +//check: v0 = const bool true + ret bool v10 + } +} \ No newline at end of file diff --git a/sway-ir/tests/constants/u256_ops.ir b/sway-ir/tests/constants/u256_ops.ir new file mode 100644 index 00000000000..fce5fd4379c --- /dev/null +++ b/sway-ir/tests/constants/u256_ops.ir @@ -0,0 +1,28 @@ +script { + fn main() -> u256 { + entry(): + v0 = const u256 0x0000000000000000000000000000000000000000000000000000000000000000 + v1 = const u256 0x0000000000000000000000000000000000000000000000000000000000000001 + v2 = const u256 0x0000000000000000000000000000000000000000000000000000000000000002 + v3 = const u256 0x0000000000000000000000000000000000000000000000000000000000000003 + v4 = const u256 0x0000000000000000000000000000000000000000000000000000000000000004 + v5 = const u256 0x0000000000000000000000000000000000000000000000000000000000000005 + v6 = const u256 0x0000000000000000000000000000000000000000000000000000000000000006 + v7 = const u256 0x0000000000000000000000000000000000000000000000000000000000000007 + v8 = const u64 2 + + v10 = rsh v4, v8 + v11 = lsh v2, v8 + v12 = add v10, v11 + v13 = sub v12, v1 + v14 = mul v13, v2 + v15 = div v14, v4 + v16 = or v15, v2 + v17 = and v16, v4 + v18 = not v17 + v19 = not v18 + v20 = xor v19, v6 +//check: v0 = const u256 0x0000000000000000000000000000000000000000000000000000000000000002 + ret u256 v20 + } +} \ No newline at end of file diff --git a/sway-types/Cargo.toml b/sway-types/Cargo.toml index a99a6279946..10e32679e14 100644 --- a/sway-types/Cargo.toml +++ b/sway-types/Cargo.toml @@ -14,6 +14,7 @@ fuel-crypto = { workspace = true } fuel-tx = { workspace = true } lazy_static = "1.4" num-bigint = "0.4.3" +num-traits = "0.2.16" serde = { version = "1.0", features = ["derive"] } thiserror = "1" diff --git a/sway-types/src/u256.rs b/sway-types/src/u256.rs index 072fb1dc42b..85e853156a5 100644 --- a/sway-types/src/u256.rs +++ b/sway-types/src/u256.rs @@ -1,4 +1,7 @@ +use std::ops::{Not, Shl, Shr}; + use num_bigint::{BigUint, ParseBigIntError, TryFromBigIntError}; +use num_traits::Zero; use thiserror::Error; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Hash)] @@ -17,6 +20,33 @@ impl U256 { assert!(bytes.len() == 32); bytes.try_into().expect("unexpected vector size") } + + pub fn checked_add(&self, other: &U256) -> Option { + let r = &self.0 + &other.0; + (r.bits() <= 256).then_some(Self(r)) + } + + pub fn checked_sub(&self, other: &U256) -> Option { + (self.0 >= other.0).then(|| Self(&self.0 - &other.0)) + } + + pub fn checked_mul(&self, other: &U256) -> Option { + let r = &self.0 * &other.0; + (r.bits() <= 256).then_some(Self(r)) + } + + pub fn checked_div(&self, other: &U256) -> Option { + other.0.is_zero().not().then(|| Self(&self.0 / &other.0)) + } + + pub fn shr(&self, other: &u64) -> U256 { + U256((&self.0).shr(other)) + } + + pub fn checked_shl(&self, other: &u64) -> Option { + let r = (&self.0).shl(other); + (r.bits() <= 256).then_some(Self(r)) + } } impl std::fmt::Display for U256 { @@ -61,3 +91,47 @@ impl std::str::FromStr for U256 { Ok(Self(v)) } } + +impl<'a> std::ops::BitAnd<&'a U256> for &'a U256 { + type Output = U256; + + fn bitand(self, rhs: Self) -> Self::Output { + U256((&self.0).bitand(&rhs.0)) + } +} + +impl<'a> std::ops::BitOr<&'a U256> for &'a U256 { + type Output = U256; + + fn bitor(self, rhs: Self) -> Self::Output { + U256((&self.0).bitor(&rhs.0)) + } +} + +impl<'a> std::ops::BitXor<&'a U256> for &'a U256 { + type Output = U256; + + fn bitxor(self, rhs: Self) -> Self::Output { + U256((&self.0).bitxor(&rhs.0)) + } +} + +impl<'a> std::ops::Rem<&'a U256> for &'a U256 { + type Output = U256; + + fn rem(self, rhs: Self) -> Self::Output { + U256((&self.0).rem(&rhs.0)) + } +} + +impl<'a> std::ops::Not for &'a U256 { + type Output = U256; + + fn not(self) -> Self::Output { + let mut bytes = self.to_be_bytes(); + for b in bytes.iter_mut() { + *b = !*b; + } + U256(BigUint::from_bytes_be(&bytes)) + } +}