Skip to content

Commit

Permalink
fix: fix signed integer casting (#261)
Browse files Browse the repository at this point in the history
  • Loading branch information
shpface authored Jun 19, 2024
1 parent a4d7f98 commit 302ecab
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 6 deletions.
8 changes: 5 additions & 3 deletions src/braket/default_simulator/openqasm/_helpers/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,14 @@ def _(into: IntType, variable: LiteralType) -> IntegerLiteral:
if isinstance(variable, ArrayLiteral):
value = int("".join("01"[x.value] for x in variable.values[1:]), base=2)
if variable.values[0].value:
value *= -1
value -= 2 ** (len(variable.values) - 1)
else:
value = variable.value
if into.size is not None:
limit = 2 ** (into.size.value - 1)
value = int(np.sign(value) * (np.abs(int(value)) % limit))
limit = 2**into.size.value
value = int(value) % limit
if (value) >= limit / 2:
value -= limit
if value != variable.value:
warnings.warn(
f"Integer overflow for value {variable.value} and size {into.size.value}."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def test_int_declaration():
int[8] uninitialized;
int[8] pos = 10;
int[5] neg = -4;
int[8] int_min = -128;
int[8] int_max = 127;
int[3] pos_overflow = 5;
int[3] neg_overflow = -6;
int no_size = 1e9;
Expand All @@ -133,8 +135,10 @@ def test_int_declaration():
assert context.get_value("uninitialized") is None
assert context.get_value("pos") == IntegerLiteral(10)
assert context.get_value("neg") == IntegerLiteral(-4)
assert context.get_value("pos_overflow") == IntegerLiteral(1)
assert context.get_value("neg_overflow") == IntegerLiteral(-2)
assert context.get_value("int_min") == IntegerLiteral(-128)
assert context.get_value("int_max") == IntegerLiteral(127)
assert context.get_value("pos_overflow") == IntegerLiteral(-3)
assert context.get_value("neg_overflow") == IntegerLiteral(2)
assert context.get_value("no_size") == IntegerLiteral(1_000_000_000)

warnings = {(warn.category, warn.message.args[0]) for warn in warn_info}
Expand Down Expand Up @@ -173,6 +177,34 @@ def test_uint_declaration():
assert context.get_value("no_size") == IntegerLiteral(1_000_000_000)


def test_signed_int_cast():
qasm = """
uint[8] x0 = 255;
int[8] x1 = x0;
uint[8] x2 = x1;
uint[8] y0 = 128;
int[8] y1 = y0;
uint[8] y2 = y1;
int[3] z0 = "100";
int[3] z1 = "111";
"""

context = Interpreter().run(qasm)

assert context.get_value("x0") == IntegerLiteral(255)
assert context.get_value("x1") == IntegerLiteral(-1)
assert context.get_value("x2") == IntegerLiteral(255)

assert context.get_value("y0") == IntegerLiteral(128)
assert context.get_value("y1") == IntegerLiteral(-128)
assert context.get_value("y2") == IntegerLiteral(128)

assert context.get_value("z0") == IntegerLiteral(-4)
assert context.get_value("z1") == IntegerLiteral(-1)


def test_float_declaration():
qasm = """
float[16] uninitialized;
Expand Down Expand Up @@ -674,7 +706,7 @@ def test_update_bits_int():
"""
context = Interpreter().run(qasm)
assert context.get_value("x") == IntegerLiteral(3)
assert context.get_value("y") == IntegerLiteral(-2)
assert context.get_value("y") == IntegerLiteral(-6)
assert context.get_value("z") == IntegerLiteral(10)


Expand Down

0 comments on commit 302ecab

Please sign in to comment.