From f8f8dc6587527928afc03b94465d3e04cb7b084f Mon Sep 17 00:00:00 2001 From: Aditya Mohan <54040096+TraXIcoN@users.noreply.github.com> Date: Sun, 3 Nov 2024 23:52:15 -0500 Subject: [PATCH] Add Unit Tests for shared_utils.py (#309) * Added test cases for shared_utils Signed-off-by: Aditya Mohan * Added definition for logging an error shared_utils.py Signed-off-by: Jay Dev Jha * Pre-commit fixes for shared_utils.py Signed-off-by: Jay Dev Jha * pyright fixes for test.py Signed-off-by: Jay Dev Jha * Minor changes to shared_utils.py Signed-off-by: Jay Dev Jha * Updated test.py Signed-off-by: Jay Dev Jha --------- Signed-off-by: Aditya Mohan Signed-off-by: Jay Dev Jha Co-authored-by: Jay Dev Jha --- shared_utils.py | 5 +- test.py | 245 +++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 223 insertions(+), 27 deletions(-) diff --git a/shared_utils.py b/shared_utils.py index fb91b330..34482b1b 100644 --- a/shared_utils.py +++ b/shared_utils.py @@ -17,10 +17,9 @@ logging.basicConfig(level=LOG_LEVEL, format=LOG_FORMAT) +# Log an error message def log_and_exit(message: str): - """ - Log an error message and then exit with EXIT_FAILURE. - """ + """Log an error message and exit the program.""" logging.error(message) raise SystemExit(1) diff --git a/test.py b/test.py index d1c9c89b..aa091b72 100644 --- a/test.py +++ b/test.py @@ -2,41 +2,238 @@ import logging import unittest +from unittest.mock import Mock, patch -from parse import * from shared_utils import * -class EncodingLineTest(unittest.TestCase): +class EncodingUtilsTest(unittest.TestCase): + """Tests for basic encoding utilities""" + def setUp(self): - logger = logging.getLogger() - logger.disabled = True + self.logger = logging.getLogger() + self.logger.disabled = True + + def test_initialize_encoding(self): + """Test encoding initialization with different bit lengths""" + self.assertEqual(initialize_encoding(32), ["-"] * 32) + self.assertEqual(initialize_encoding(16), ["-"] * 16) + self.assertEqual(initialize_encoding(), ["-"] * 32) # default case + + def test_validate_bit_range(self): + """Test bit range validation""" + # Valid cases + validate_bit_range(7, 3, 15, "test_instr") # 15 fits in 5 bits + validate_bit_range(31, 0, 0xFFFFFFFF, "test_instr") # max 32-bit value - def assertError(self, string: str): - self.assertRaises(SystemExit, process_enc_line, string, "rv_i") + # Invalid cases + with self.assertRaises(SystemExit): + validate_bit_range(3, 7, 1, "test_instr") # msb < lsb + with self.assertRaises(SystemExit): + validate_bit_range(3, 0, 16, "test_instr") # value too large for range + + def test_parse_instruction_line(self): + """Test instruction line parsing""" + name, remaining = parse_instruction_line("add.w r1, r2, r3") + self.assertEqual(name, "add_w") + self.assertEqual(remaining, "r1, r2, r3") + + name, remaining = parse_instruction_line("lui rd imm20 6..2=0x0D") + self.assertEqual(name, "lui") + self.assertEqual(remaining, "rd imm20 6..2=0x0D") + + +class BitManipulationTest(unittest.TestCase): + """Tests for bit manipulation and checking functions""" + + def setUp(self): + self.logger = logging.getLogger() + self.logger.disabled = True + self.test_encoding = initialize_encoding() + + def test_check_overlapping_bits(self): + """Test overlapping bits detection""" + # Valid case - no overlap + self.test_encoding[31 - 5] = "-" + check_overlapping_bits(self.test_encoding, 5, "test_instr") + + # Invalid case - overlap + self.test_encoding[31 - 5] = "1" + with self.assertRaises(SystemExit): + check_overlapping_bits(self.test_encoding, 5, "test_instr") + + def test_update_encoding_for_fixed_range(self): + """Test encoding updates for fixed ranges""" + encoding = initialize_encoding() + update_encoding_for_fixed_range(encoding, 6, 2, 0x0D, "test_instr") + + # Check specific bits are set correctly + self.assertEqual(encoding[31 - 6 : 31 - 1], ["0", "1", "1", "0", "1"]) + + def test_process_fixed_ranges(self): + """Test processing of fixed bit ranges""" + encoding = initialize_encoding() + remaining = "rd imm20 6..2=0x0D 1..0=3" + + result = process_fixed_ranges(remaining, encoding, "test_instr") + self.assertNotIn("6..2=0x0D", result) + self.assertNotIn("1..0=3", result) + + +class EncodingArgsTest(unittest.TestCase): + """Tests for encoding arguments handling""" + + def setUp(self): + self.logger = logging.getLogger() + self.logger.disabled = True + + @patch.dict("shared_utils.arg_lut", {"rd": (11, 7), "rs1": (19, 15)}) + def test_check_arg_lut(self): + """Test argument lookup table checking""" + encoding_args = initialize_encoding() + args = ["rd", "rs1"] + check_arg_lut(args, encoding_args, "test_instr") + + # Verify encoding_args has been updated correctly + self.assertEqual(encoding_args[31 - 11 : 31 - 6], ["rd"] * 5) + self.assertEqual(encoding_args[31 - 19 : 31 - 14], ["rs1"] * 5) + + @patch.dict("shared_utils.arg_lut", {"rs1": (19, 15)}) + def test_handle_arg_lut_mapping(self): + """Test handling of argument mappings""" + # Valid mapping + result = handle_arg_lut_mapping("rs1=new_arg", "test_instr") + self.assertEqual(result, "rs1=new_arg") + + # Invalid mapping + with self.assertRaises(SystemExit): + handle_arg_lut_mapping("invalid_arg=new_arg", "test_instr") + + +class ISAHandlingTest(unittest.TestCase): + """Tests for ISA type handling and validation""" + + def test_extract_isa_type(self): + """Test ISA type extraction""" + self.assertEqual(extract_isa_type("rv32_i"), "rv32") + self.assertEqual(extract_isa_type("rv64_m"), "rv64") + self.assertEqual(extract_isa_type("rv_c"), "rv") + + def test_is_rv_variant(self): + """Test RV variant checking""" + self.assertTrue(is_rv_variant("rv32", "rv")) + self.assertTrue(is_rv_variant("rv", "rv64")) + self.assertFalse(is_rv_variant("rv32", "rv64")) + + def test_same_base_isa(self): + """Test base ISA comparison""" + self.assertTrue(same_base_isa("rv32_i", ["rv32_m", "rv32_a"])) + self.assertTrue(same_base_isa("rv_i", ["rv32_i", "rv64_i"])) + self.assertFalse(same_base_isa("rv32_i", ["rv64_m"])) + + +class StringManipulationTest(unittest.TestCase): + """Tests for string manipulation utilities""" + + def test_pad_to_equal_length(self): + """Test string padding""" + str1, str2 = pad_to_equal_length("101", "1101") + self.assertEqual(len(str1), len(str2)) + self.assertEqual(str1, "-101") + self.assertEqual(str2, "1101") + + def test_overlaps(self): + """Test string overlap checking""" + self.assertTrue(overlaps("1-1", "101")) + self.assertTrue(overlaps("---", "101")) + self.assertFalse(overlaps("111", "101")) + + +class InstructionProcessingTest(unittest.TestCase): + """Tests for instruction processing and validation""" + + def setUp(self): + self.logger = logging.getLogger() + self.logger.disabled = True + # Create a patch for arg_lut + self.arg_lut_patcher = patch.dict( + "shared_utils.arg_lut", {"rd": (11, 7), "imm20": (31, 12)} + ) + self.arg_lut_patcher.start() + + def tearDown(self): + self.arg_lut_patcher.stop() + + @patch("shared_utils.fixed_ranges") + @patch("shared_utils.single_fixed") + def test_process_enc_line(self, mock_single_fixed: Mock, mock_fixed_ranges: Mock): + """Test processing of encoding lines""" + # Setup mock return values + mock_fixed_ranges.findall.return_value = [(6, 2, "0x0D")] + mock_fixed_ranges.sub.return_value = "rd imm20" + mock_single_fixed.findall.return_value = [] + mock_single_fixed.sub.return_value = "rd imm20" + + # Create a mock for split() that returns the expected list + mock_split = Mock(return_value=["rd", "imm20"]) + mock_single_fixed.sub.return_value = Mock(split=mock_split) + + name, data = process_enc_line("lui rd imm20 6..2=0x0D", "rv_i") - def test_lui(self): - name, data = process_enc_line("lui rd imm20 6..2=0x0D 1=1 0=1", "rv_i") self.assertEqual(name, "lui") self.assertEqual(data["extension"], ["rv_i"]) - self.assertEqual(data["match"], "0x37") - self.assertEqual(data["mask"], "0x7f") + self.assertIn("rd", data["variable_fields"]) + self.assertIn("imm20", data["variable_fields"]) + + @patch("os.path.exists") + @patch("shared_utils.logging.error") + def test_find_extension_file(self, mock_logging: Mock, mock_exists: Mock): + """Test extension file finding""" + # Test successful case - file exists in main directory + mock_exists.side_effect = [True, False] + result = find_extension_file("rv32i", "/path/to/opcodes") + self.assertEqual(result, "/path/to/opcodes/rv32i") + + # Test successful case - file exists in unratified directory + mock_exists.side_effect = [False, True] + result = find_extension_file("rv32i", "/path/to/opcodes") + self.assertEqual(result, "/path/to/opcodes/unratified/rv32i") + + # Test failure case - file doesn't exist anywhere + mock_exists.side_effect = [False, False] + with self.assertRaises(SystemExit): + find_extension_file("rv32i", "/path/to/opcodes") + mock_logging.assert_called_with("Extension rv32i not found.") + + def test_process_standard_instructions(self): + """Test processing of standard instructions""" + lines = [ + "add rd rs1 rs2 31..25=0 14..12=0 6..2=0x0C 1..0=3", + "sub rd rs1 rs2 31..25=0x20 14..12=0 6..2=0x0C 1..0=3", + "$pseudo add_pseudo rd rs1 rs2", # Should be skipped + "$import rv32i::mul", # Should be skipped + ] + + instr_dict: InstrDict = {} + file_name = "rv32i" + + with patch("shared_utils.process_enc_line") as mock_process_enc: + # Setup mock return values + mock_process_enc.side_effect = [ + ("add", {"extension": ["rv32i"], "encoding": "encoding1"}), + ("sub", {"extension": ["rv32i"], "encoding": "encoding2"}), + ] - def test_overlapping(self): - self.assertError("jol rd jimm20 6..2=0x00 3..0=7") - self.assertError("jol rd jimm20 6..2=0x00 3=1") - self.assertError("jol rd jimm20 6..2=0x00 10=1") - self.assertError("jol rd jimm20 6..2=0x00 31..10=1") + process_standard_instructions(lines, instr_dict, file_name) - def test_invalid_order(self): - self.assertError("jol 2..6=0x1b") + # Verify process_enc_line was called twice (skipping pseudo and import) + self.assertEqual(mock_process_enc.call_count, 2) - def test_illegal_value(self): - self.assertError("jol rd jimm20 2..0=10") - self.assertError("jol rd jimm20 2..0=0xB") + # Verify the instruction dictionary was updated correctly + self.assertEqual(len(instr_dict), 2) + self.assertIn("add", instr_dict) + self.assertIn("sub", instr_dict) - def test_overlapping_field(self): - self.assertError("jol rd rs1 jimm20 6..2=0x1b 1..0=3") - def test_illegal_field(self): - self.assertError("jol rd jimm128 2..0=3") +if __name__ == "__main__": + unittest.main()