Skip to content

Commit

Permalink
fixes after review
Browse files Browse the repository at this point in the history
  • Loading branch information
popenta committed Jan 13, 2025
1 parent 6a37cdb commit e1f3a41
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 41 deletions.
8 changes: 3 additions & 5 deletions multiversx_sdk/abi/abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,13 +323,11 @@ def _create_prototype(self, type_formula: TypeFormula) -> Any:
scale = type_formula.type_parameters[0].name

if scale == "usize":
scale = 0
is_variable = True
return ManagedDecimalValue(scale=0, is_variable=True)
else:
scale = int(scale)
is_variable = False
return ManagedDecimalValue(scale=int(scale), is_variable=False)

return ManagedDecimalValue(scale=scale, is_variable=is_variable)

if name == "ManagedDecimalSigned":
scale = type_formula.type_parameters[0].name

Expand Down
42 changes: 42 additions & 0 deletions multiversx_sdk/abi/abi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,45 @@ def test_managed_decimals():
assert second_input.is_variable
assert second_input.scale == 0
assert second_input.value == Decimal(0)


def test_encode_decode_managed_decimals():
abi_definition = AbiDefinition.from_dict(
{
"endpoints": [
{
"name": "dummy",
"inputs": [{"type": "ManagedDecimal<18>"}],
"outputs": [],
},
{
"name": "foo",
"inputs": [{"name": "x", "type": "ManagedDecimal<usize>"}],
"outputs": [{"type": "ManagedDecimalSigned<9>"}],
},
{
"name": "foobar",
"inputs": [{"name": "x", "type": "ManagedDecimal<usize>"}],
"outputs": [{"type": "ManagedDecimal<usize>"}],
},
]
}
)

abi = Abi(abi_definition)

values = abi.encode_endpoint_input_parameters("dummy", [1])
assert values[0].hex() == "01"

values = abi.encode_endpoint_input_parameters("foo", [ManagedDecimalValue(7, 2, True)])
assert values[0].hex() == "0702"

values = abi.decode_endpoint_output_parameters("foo", [bytes.fromhex("07")])
assert values[0].get_payload() == Decimal("7")
assert values[0].scale == Decimal("7")
assert values[0].to_string() == "7.000000000"

values = abi.decode_endpoint_output_parameters("foobar", [bytes.fromhex("0700000003")])
assert values[0].get_payload() == Decimal("7")
assert values[0].scale == Decimal("3")
assert values[0].to_string() == "7.000"
31 changes: 13 additions & 18 deletions multiversx_sdk/abi/managed_decimal_signed_value.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import io
from decimal import ROUND_DOWN, Decimal
from decimal import Decimal
from typing import Any, Union

from multiversx_sdk.abi.bigint_value import BigIntValue
Expand All @@ -17,7 +17,7 @@ def __init__(self, value: Union[int, str] = 0, scale: int = 0, is_variable: bool
def set_payload(self, value: Any):
if isinstance(value, ManagedDecimalSignedValue):
if self.is_variable != value.is_variable:
raise Exception("Cannot set payload! Both ManagedDecimalValues should be variable.")
raise Exception("Cannot set payload! Both managed decimal values should be variable.")

self.value = value.value

Expand Down Expand Up @@ -46,53 +46,48 @@ def decode_top_level(self, data: bytes):
self.scale = 0
return

bigint = BigIntValue()
value = BigIntValue()
scale = U32Value()

if self.is_variable:
# read biguint value length in bytes
big_int_size = self._unsigned_from_bytes(data[:U32_SIZE_IN_BYTES])
value_length = self._unsigned_from_bytes(data[:U32_SIZE_IN_BYTES])

# remove biguint length; data is only biguint value and scale
data = data[U32_SIZE_IN_BYTES:]

# read biguint value
bigint.decode_top_level(data[:big_int_size])
value.decode_top_level(data[:value_length])

# remove biguintvalue; data contains only scale
data = data[big_int_size:]
data = data[value_length:]

# read scale
scale.decode_top_level(data)
self.scale = scale.get_payload()
else:
bigint.decode_top_level(data)
value.decode_top_level(data)

self.value = self._convert_to_decimal(bigint.get_payload())
self.value = self._convert_to_decimal(value.get_payload())

def decode_nested(self, reader: io.BytesIO):
length = self._unsigned_from_bytes(read_bytes_exactly(reader, U32_SIZE_IN_BYTES))
payload = read_bytes_exactly(reader, length)
self.decode_top_level(payload)

def to_string(self) -> str:
value_str = str(self._convert_value_to_int())
if self.scale == 0:
return value_str
if len(value_str) <= self.scale:
# If the value is smaller than the scale, prepend zeros
value_str = "0" * (self.scale - len(value_str) + 1) + value_str
return f"{value_str[:-self.scale]}.{value_str[-self.scale:]}"
scaled_value = self._convert_value_to_int()
return f"{scaled_value / (10 ** self.scale):.{self.scale}f}"

def get_precision(self) -> int:
return len(str(self._convert_value_to_int()).lstrip("0"))
value_str = f"{self.value:.{self.scale}f}"
return len(value_str.replace(".", ""))

def _unsigned_from_bytes(self, data: bytes) -> int:
return int.from_bytes(data, byteorder="big", signed=False)

def _convert_value_to_int(self) -> int:
scaled_value: Decimal = self.value * (10**self.scale)
return int(scaled_value.quantize(Decimal("1."), rounding=ROUND_DOWN))
return int(self.value.scaleb(self.scale))

def _convert_to_decimal(self, value: Union[int, str]) -> Decimal:
return Decimal(value) / (10**self.scale)
Expand Down
29 changes: 12 additions & 17 deletions multiversx_sdk/abi/managed_decimal_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, value: Union[int, str] = 0, scale: int = 0, is_variable: bool
def set_payload(self, value: Any):
if isinstance(value, ManagedDecimalValue):
if self.is_variable != value.is_variable:
raise Exception("Cannot set payload! Both ManagedDecimalValues should be variable.")
raise Exception("Cannot set payload! Both managed decimal values should be variable.")

self.value = value.value

Expand Down Expand Up @@ -46,53 +46,48 @@ def decode_top_level(self, data: bytes):
self.scale = 0
return

biguint = BigUIntValue()
value = BigUIntValue()
scale = U32Value()

if self.is_variable:
# read biguint value length in bytes
big_uint_size = self._unsigned_from_bytes(data[:U32_SIZE_IN_BYTES])
value_length = self._unsigned_from_bytes(data[:U32_SIZE_IN_BYTES])

# remove biguint length; data is only biguint value and scale
data = data[U32_SIZE_IN_BYTES:]

# read biguint value
biguint.decode_top_level(data[:big_uint_size])
value.decode_top_level(data[:value_length])

# remove biguintvalue; data contains only scale
data = data[big_uint_size:]
data = data[value_length:]

# read scale
scale.decode_top_level(data)
self.scale = scale.get_payload()
else:
biguint.decode_top_level(data)
value.decode_top_level(data)

self.value = self._convert_to_decimal(biguint.get_payload())
self.value = self._convert_to_decimal(value.get_payload())

def decode_nested(self, reader: io.BytesIO):
length = self._unsigned_from_bytes(read_bytes_exactly(reader, U32_SIZE_IN_BYTES))
payload = read_bytes_exactly(reader, length)
self.decode_top_level(payload)

def to_string(self) -> str:
value_str = str(self._convert_value_to_int())
if self.scale == 0:
return value_str
if len(value_str) <= self.scale:
# If the value is smaller than the scale, prepend zeros
value_str = "0" * (self.scale - len(value_str) + 1) + value_str
return f"{value_str[:-self.scale]}.{value_str[-self.scale:]}"
scaled_value = self._convert_value_to_int()
return f"{scaled_value / (10 ** self.scale):.{self.scale}f}"

def get_precision(self) -> int:
return len(str(self._convert_value_to_int()).lstrip("0"))
value_str = f"{self.value:.{self.scale}f}"
return len(value_str.replace(".", ""))

def _unsigned_from_bytes(self, data: bytes) -> int:
return int.from_bytes(data, byteorder="big", signed=False)

def _convert_value_to_int(self) -> int:
scaled_value: Decimal = self.value * (10**self.scale)
return int(scaled_value.quantize(Decimal("1."), rounding=ROUND_DOWN))
return int(self.value.scaleb(self.scale))

def _convert_to_decimal(self, value: Union[int, str]) -> Decimal:
return Decimal(value) / (10**self.scale)
Expand Down
55 changes: 54 additions & 1 deletion multiversx_sdk/abi/managed_decimal_value_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from decimal import Decimal

from multiversx_sdk.abi.managed_decimal_value import ManagedDecimalValue
from multiversx_sdk.abi.managed_decimal_signed_value import ManagedDecimalSignedValue


class TestManagedDecimalValueTest:

def test_expected_values(self):
value = ManagedDecimalValue(1, 2)

assert not value.is_variable
assert value.get_precision() == 3
assert value.to_string() == "1.00"
Expand All @@ -33,9 +33,62 @@ def test_expected_values(self):
assert value.to_string() == "2.70"
assert value.get_payload() == Decimal("2.7")

value = ManagedDecimalValue(value="0.000000000000000001", scale=18)
assert value.get_precision() == 19
assert value.to_string() == "0.000000000000000001"
assert value.get_payload() == Decimal("0.000000000000000001")

def test_compare_values(self):
value = ManagedDecimalValue(1, 2)

assert value != ManagedDecimalValue(2, 2)
assert value != ManagedDecimalValue(1, 3)
assert value == ManagedDecimalValue(1, 2)


class TestManagedDecimalSignedValueTest:

def test_expected_values(self):
value = ManagedDecimalSignedValue(1, 2)
assert not value.is_variable
assert value.get_precision() == 3
assert value.to_string() == "1.00"
assert value.get_payload() == Decimal(1)

value = ManagedDecimalSignedValue(-1, 2)
assert not value.is_variable
assert value.get_precision() == 4
assert value.to_string() == "-1.00"
assert value.get_payload() == Decimal(-1)

value = ManagedDecimalSignedValue("1.234", 3)
assert value.get_precision() == 4
assert value.to_string() == "1.234"
assert value.get_payload() == Decimal("1.234")

value = ManagedDecimalSignedValue("1.3", 2)
assert value.get_precision() == 3
assert value.to_string() == "1.30"
assert value.get_payload() == Decimal("1.3")

value = ManagedDecimalSignedValue(13, 2)
assert value.get_precision() == 4
assert value.to_string() == "13.00"
assert value.get_payload() == Decimal(13)

value = ManagedDecimalSignedValue("2.7", 2)
assert value.get_precision() == 3
assert value.to_string() == "2.70"
assert value.get_payload() == Decimal("2.7")

value = ManagedDecimalSignedValue(value="0.000000000000000001", scale=18)
assert value.get_precision() == 19
assert value.to_string() == "0.000000000000000001"
assert value.get_payload() == Decimal("0.000000000000000001")

def test_compare_values(self):
value = ManagedDecimalSignedValue(1, 2)

assert value != ManagedDecimalSignedValue(2, 2)
assert value != ManagedDecimalSignedValue(1, 3)
assert value == ManagedDecimalSignedValue(1, 2)

0 comments on commit e1f3a41

Please sign in to comment.