Skip to content

Commit

Permalink
Add tests for removing double inverts
Browse files Browse the repository at this point in the history
  • Loading branch information
gaborszita committed Jan 17, 2025
1 parent 378f86a commit 3ab98c5
Showing 1 changed file with 80 additions and 0 deletions.
80 changes: 80 additions & 0 deletions tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,87 @@ def test_slice_net_removal_4(self):
block = pyrtl.working_block()
self.num_net_of_type('s', 1, block)
self.num_net_of_type('w', 2, block)

def test_remove_double_inverts_1_invert(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~inwire
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(2, block)
self.assert_num_wires(3, block)

def test_remove_double_inverts_3_inverts(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~(~(~inwire))
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(2, block)
self.assert_num_wires(3, block)

def test_remove_double_inverts_5_inverts(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~(~(~(~(~inwire))))
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(2, block)
self.assert_num_wires(3, block)

def test_remove_double_inverts_2_inverts(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~(~inwire)
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(1, block)
self.assert_num_wires(2, block)

def test_remove_double_inverts_4_inverts(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~(~(~(~inwire)))
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(1, block)
self.assert_num_wires(2, block)

def test_remove_double_inverts_6_inverts(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire <<= ~(~(~(~(~(~inwire)))))
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(1, block)
self.assert_num_wires(2, block)

def test_dont_remove_double_inverts_another_user(self):
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire2 = pyrtl.Output(bitwidth=1)
tempwire = pyrtl.WireVector()
tempwire <<= ~inwire
outwire <<= ~tempwire
outwire2 <<= tempwire
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(4, block)
self.assert_num_wires(5, block)

def test_multiple_double_invert_chains(self):
# _remove_double_inverts removes double inverts by chains,
# so it is useful to make sure it can remove
# double inverts from multiple chains
inwire = pyrtl.Input(bitwidth=1)
outwire = pyrtl.Output(bitwidth=1)
outwire2 = pyrtl.Output(bitwidth=1)
outwire <<= ~(~inwire)
outwire2 <<= ~(~(~(~(inwire))))
pyrtl.optimize()
block = pyrtl.working_block()
self.assert_num_net(2, block)
self.assert_num_wires(3, block)

class TestConstFolding(NetWireNumTestCases):
def setUp(self):
Expand Down

0 comments on commit 3ab98c5

Please sign in to comment.