From 723dda2772fe51b55c5ff07f1b00a2f8a090deaa Mon Sep 17 00:00:00 2001 From: Aditya Mohan Date: Thu, 31 Oct 2024 03:54:53 -0400 Subject: [PATCH 1/6] Added test cases for shared_utils Signed-off-by: Aditya Mohan --- test.py | 245 ++++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 221 insertions(+), 24 deletions(-) diff --git a/test.py b/test.py index eb9b6783..c09e9227 100644 --- a/test.py +++ b/test.py @@ -2,41 +2,238 @@ import logging import unittest +from unittest.mock import Mock, mock_open, patch -from parse import * from shared_utils import * -class EncodingLineTest(unittest.TestCase): +class EncodingUtilsTest(unittest.TestCase): + """Tests for basic encoding utilities""" + def setUp(self): - logger = logging.getLogger() - logger.disabled = True + self.logger = logging.getLogger() + self.logger.disabled = True + + def test_initialize_encoding(self): + """Test encoding initialization with different bit lengths""" + self.assertEqual(initialize_encoding(32), ["-"] * 32) + self.assertEqual(initialize_encoding(16), ["-"] * 16) + self.assertEqual(initialize_encoding(), ["-"] * 32) # default case + + def test_validate_bit_range(self): + """Test bit range validation""" + # Valid cases + validate_bit_range(7, 3, 15, "test_instr") # 15 fits in 5 bits + validate_bit_range(31, 0, 0xFFFFFFFF, "test_instr") # max 32-bit value - def assertError(self, string): - self.assertRaises(SystemExit, process_enc_line, string, "rv_i") + # Invalid cases + with self.assertRaises(SystemExit): + validate_bit_range(3, 7, 1, "test_instr") # msb < lsb + with self.assertRaises(SystemExit): + validate_bit_range(3, 0, 16, "test_instr") # value too large for range + + def test_parse_instruction_line(self): + """Test instruction line parsing""" + name, remaining = parse_instruction_line("add.w r1, r2, r3") + self.assertEqual(name, "add_w") + self.assertEqual(remaining, "r1, r2, r3") + + name, remaining = parse_instruction_line("lui rd imm20 6..2=0x0D") + self.assertEqual(name, "lui") + self.assertEqual(remaining, "rd imm20 6..2=0x0D") + + +class BitManipulationTest(unittest.TestCase): + """Tests for bit manipulation and checking functions""" + + def setUp(self): + self.logger = logging.getLogger() + self.logger.disabled = True + self.test_encoding = initialize_encoding() + + def test_check_overlapping_bits(self): + """Test overlapping bits detection""" + # Valid case - no overlap + self.test_encoding[31 - 5] = "-" + check_overlapping_bits(self.test_encoding, 5, "test_instr") + + # Invalid case - overlap + self.test_encoding[31 - 5] = "1" + with self.assertRaises(SystemExit): + check_overlapping_bits(self.test_encoding, 5, "test_instr") + + def test_update_encoding_for_fixed_range(self): + """Test encoding updates for fixed ranges""" + encoding = initialize_encoding() + update_encoding_for_fixed_range(encoding, 6, 2, 0x0D, "test_instr") + + # Check specific bits are set correctly + self.assertEqual(encoding[31 - 6 : 31 - 1], ["0", "1", "1", "0", "1"]) + + def test_process_fixed_ranges(self): + """Test processing of fixed bit ranges""" + encoding = initialize_encoding() + remaining = "rd imm20 6..2=0x0D 1..0=3" + + result = process_fixed_ranges(remaining, encoding, "test_instr") + self.assertNotIn("6..2=0x0D", result) + self.assertNotIn("1..0=3", result) + + +class EncodingArgsTest(unittest.TestCase): + """Tests for encoding arguments handling""" + + def setUp(self): + self.logger = logging.getLogger() + self.logger.disabled = True + + @patch.dict("shared_utils.arg_lut", {"rd": (11, 7), "rs1": (19, 
15)}) + def test_check_arg_lut(self): + """Test argument lookup table checking""" + encoding_args = initialize_encoding() + args = ["rd", "rs1"] + check_arg_lut(args, encoding_args, "test_instr") + + # Verify encoding_args has been updated correctly + self.assertEqual(encoding_args[31 - 11 : 31 - 6], ["rd"] * 5) + self.assertEqual(encoding_args[31 - 19 : 31 - 14], ["rs1"] * 5) + + @patch.dict("shared_utils.arg_lut", {"rs1": (19, 15)}) + def test_handle_arg_lut_mapping(self): + """Test handling of argument mappings""" + # Valid mapping + result = handle_arg_lut_mapping("rs1=new_arg", "test_instr") + self.assertEqual(result, "rs1=new_arg") + + # Invalid mapping + with self.assertRaises(SystemExit): + handle_arg_lut_mapping("invalid_arg=new_arg", "test_instr") + + +class ISAHandlingTest(unittest.TestCase): + """Tests for ISA type handling and validation""" + + def test_extract_isa_type(self): + """Test ISA type extraction""" + self.assertEqual(extract_isa_type("rv32_i"), "rv32") + self.assertEqual(extract_isa_type("rv64_m"), "rv64") + self.assertEqual(extract_isa_type("rv_c"), "rv") + + def test_is_rv_variant(self): + """Test RV variant checking""" + self.assertTrue(is_rv_variant("rv32", "rv")) + self.assertTrue(is_rv_variant("rv", "rv64")) + self.assertFalse(is_rv_variant("rv32", "rv64")) + + def test_same_base_isa(self): + """Test base ISA comparison""" + self.assertTrue(same_base_isa("rv32_i", ["rv32_m", "rv32_a"])) + self.assertTrue(same_base_isa("rv_i", ["rv32_i", "rv64_i"])) + self.assertFalse(same_base_isa("rv32_i", ["rv64_m"])) + + +class StringManipulationTest(unittest.TestCase): + """Tests for string manipulation utilities""" + + def test_pad_to_equal_length(self): + """Test string padding""" + str1, str2 = pad_to_equal_length("101", "1101") + self.assertEqual(len(str1), len(str2)) + self.assertEqual(str1, "-101") + self.assertEqual(str2, "1101") + + def test_overlaps(self): + """Test string overlap checking""" + self.assertTrue(overlaps("1-1", "101")) + self.assertTrue(overlaps("---", "101")) + self.assertFalse(overlaps("111", "101")) + + +class InstructionProcessingTest(unittest.TestCase): + """Tests for instruction processing and validation""" + + def setUp(self): + self.logger = logging.getLogger() + self.logger.disabled = True + # Create a patch for arg_lut + self.arg_lut_patcher = patch.dict( + "shared_utils.arg_lut", {"rd": (11, 7), "imm20": (31, 12)} + ) + self.arg_lut_patcher.start() + + def tearDown(self): + self.arg_lut_patcher.stop() + + @patch("shared_utils.fixed_ranges") + @patch("shared_utils.single_fixed") + def test_process_enc_line(self, mock_single_fixed, mock_fixed_ranges): + """Test processing of encoding lines""" + # Setup mock return values + mock_fixed_ranges.findall.return_value = [(6, 2, "0x0D")] + mock_fixed_ranges.sub.return_value = "rd imm20" + mock_single_fixed.findall.return_value = [] + mock_single_fixed.sub.return_value = "rd imm20" + + # Create a mock for split() that returns the expected list + mock_split = Mock(return_value=["rd", "imm20"]) + mock_single_fixed.sub.return_value = Mock(split=mock_split) + + name, data = process_enc_line("lui rd imm20 6..2=0x0D", "rv_i") - def test_lui(self): - name, data = process_enc_line("lui rd imm20 6..2=0x0D 1=1 0=1", "rv_i") self.assertEqual(name, "lui") self.assertEqual(data["extension"], ["rv_i"]) - self.assertEqual(data["match"], "0x37") - self.assertEqual(data["mask"], "0x7f") + self.assertIn("rd", data["variable_fields"]) + self.assertIn("imm20", data["variable_fields"]) + + @patch("os.path.exists") + 
@patch("shared_utils.logging.error") + def test_find_extension_file(self, mock_logging, mock_exists): + """Test extension file finding""" + # Test successful case - file exists in main directory + mock_exists.side_effect = [True, False] + result = find_extension_file("rv32i", "/path/to/opcodes") + self.assertEqual(result, "/path/to/opcodes/rv32i") + + # Test successful case - file exists in unratified directory + mock_exists.side_effect = [False, True] + result = find_extension_file("rv32i", "/path/to/opcodes") + self.assertEqual(result, "/path/to/opcodes/unratified/rv32i") + + # Test failure case - file doesn't exist anywhere + mock_exists.side_effect = [False, False] + with self.assertRaises(SystemExit): + find_extension_file("rv32i", "/path/to/opcodes") + mock_logging.assert_called_with("Extension rv32i not found.") + + def test_process_standard_instructions(self): + """Test processing of standard instructions""" + lines = [ + "add rd rs1 rs2 31..25=0 14..12=0 6..2=0x0C 1..0=3", + "sub rd rs1 rs2 31..25=0x20 14..12=0 6..2=0x0C 1..0=3", + "$pseudo add_pseudo rd rs1 rs2", # Should be skipped + "$import rv32i::mul", # Should be skipped + ] + + instr_dict = {} + file_name = "rv32i" + + with patch("shared_utils.process_enc_line") as mock_process_enc: + # Setup mock return values + mock_process_enc.side_effect = [ + ("add", {"extension": ["rv32i"], "encoding": "encoding1"}), + ("sub", {"extension": ["rv32i"], "encoding": "encoding2"}), + ] - def test_overlapping(self): - self.assertError("jol rd jimm20 6..2=0x00 3..0=7") - self.assertError("jol rd jimm20 6..2=0x00 3=1") - self.assertError("jol rd jimm20 6..2=0x00 10=1") - self.assertError("jol rd jimm20 6..2=0x00 31..10=1") + process_standard_instructions(lines, instr_dict, file_name) - def test_invalid_order(self): - self.assertError("jol 2..6=0x1b") + # Verify process_enc_line was called twice (skipping pseudo and import) + self.assertEqual(mock_process_enc.call_count, 2) - def test_illegal_value(self): - self.assertError("jol rd jimm20 2..0=10") - self.assertError("jol rd jimm20 2..0=0xB") + # Verify the instruction dictionary was updated correctly + self.assertEqual(len(instr_dict), 2) + self.assertIn("add", instr_dict) + self.assertIn("sub", instr_dict) - def test_overlapping_field(self): - self.assertError("jol rd rs1 jimm20 6..2=0x1b 1..0=3") - def test_illegal_field(self): - self.assertError("jol rd jimm128 2..0=3") +if __name__ == "__main__": + unittest.main() From 6895d598ce8d27dba6a4582649cfb6374c66bc84 Mon Sep 17 00:00:00 2001 From: Jay Dev Jha Date: Thu, 31 Oct 2024 16:55:43 +0530 Subject: [PATCH 2/6] Added definition for logging an error shared_utils.py Signed-off-by: Jay Dev Jha --- shared_utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/shared_utils.py b/shared_utils.py index 5c925151..a45d2554 100644 --- a/shared_utils.py +++ b/shared_utils.py @@ -566,3 +566,11 @@ def instr_dict_2_extensions(instr_dict): # Returns signed interpretation of a value within a given width def signed(value, width): return value if 0 <= value < (1 << (width - 1)) else value - (1 << width) + + +# Log an error message +def log_and_exit(message): + """Log an error message and exit the program.""" + logging.error(message) + raise SystemExit(1) + From 06ff7f8d71b085c7415e09adc100ad5d0dacd781 Mon Sep 17 00:00:00 2001 From: Jay Dev Jha Date: Thu, 31 Oct 2024 16:58:32 +0530 Subject: [PATCH 3/6] Pre-commit fixes for shared_utils.py Signed-off-by: Jay Dev Jha --- shared_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/shared_utils.py b/shared_utils.py index a45d2554..4e777932 100644 --- a/shared_utils.py +++ b/shared_utils.py @@ -573,4 +573,3 @@ def log_and_exit(message): """Log an error message and exit the program.""" logging.error(message) raise SystemExit(1) - From 827cd1107711ac39a39c25304ff16e6d86c7620f Mon Sep 17 00:00:00 2001 From: Jay Dev Jha Date: Sat, 2 Nov 2024 20:03:41 +0530 Subject: [PATCH 4/6] pyright fixes for test.py Signed-off-by: Jay Dev Jha --- test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test.py b/test.py index c09e9227..aa091b72 100644 --- a/test.py +++ b/test.py @@ -2,7 +2,7 @@ import logging import unittest -from unittest.mock import Mock, mock_open, patch +from unittest.mock import Mock, patch from shared_utils import * @@ -166,7 +166,7 @@ def tearDown(self): @patch("shared_utils.fixed_ranges") @patch("shared_utils.single_fixed") - def test_process_enc_line(self, mock_single_fixed, mock_fixed_ranges): + def test_process_enc_line(self, mock_single_fixed: Mock, mock_fixed_ranges: Mock): """Test processing of encoding lines""" # Setup mock return values mock_fixed_ranges.findall.return_value = [(6, 2, "0x0D")] @@ -187,7 +187,7 @@ def test_process_enc_line(self, mock_single_fixed, mock_fixed_ranges): @patch("os.path.exists") @patch("shared_utils.logging.error") - def test_find_extension_file(self, mock_logging, mock_exists): + def test_find_extension_file(self, mock_logging: Mock, mock_exists: Mock): """Test extension file finding""" # Test successful case - file exists in main directory mock_exists.side_effect = [True, False] @@ -214,7 +214,7 @@ def test_process_standard_instructions(self): "$import rv32i::mul", # Should be skipped ] - instr_dict = {} + instr_dict: InstrDict = {} file_name = "rv32i" with patch("shared_utils.process_enc_line") as mock_process_enc: From 9d7b4b6d873724cf20e1588796161e1a62760791 Mon Sep 17 00:00:00 2001 From: Jay Dev Jha Date: Sat, 2 Nov 2024 20:04:15 +0530 Subject: [PATCH 5/6] Minor changes to shared_utils.py Signed-off-by: Jay Dev Jha --- shared_utils.py | 161 +++++++++++++++++++++++++++++------------------- 1 file changed, 97 insertions(+), 64 deletions(-) diff --git a/shared_utils.py b/shared_utils.py index 4e777932..34482b1b 100644 --- a/shared_utils.py +++ b/shared_utils.py @@ -6,6 +6,7 @@ import pprint import re from itertools import chain +from typing import Dict, TypedDict from constants import * @@ -16,30 +17,35 @@ logging.basicConfig(level=LOG_LEVEL, format=LOG_FORMAT) +# Log an error message +def log_and_exit(message: str): + """Log an error message and exit the program.""" + logging.error(message) + raise SystemExit(1) + + # Initialize encoding to 32-bit '-' values -def initialize_encoding(bits=32): +def initialize_encoding(bits: int = 32) -> "list[str]": """Initialize encoding with '-' to represent don't care bits.""" return ["-"] * bits # Validate bit range and value -def validate_bit_range(msb, lsb, entry_value, line): +def validate_bit_range(msb: int, lsb: int, entry_value: int, line: str): """Validate the bit range and entry value.""" if msb < lsb: - logging.error( + log_and_exit( f'{line.split(" ")[0]:<10} has position {msb} less than position {lsb} in its encoding' ) - raise SystemExit(1) if entry_value >= (1 << (msb - lsb + 1)): - logging.error( + log_and_exit( f'{line.split(" ")[0]:<10} has an illegal value {entry_value} assigned as per the bit width {msb - lsb}' ) - raise SystemExit(1) # Split the instruction line into name and remaining part -def parse_instruction_line(line): +def 
parse_instruction_line(line: str) -> "tuple[str, str]": """Parse the instruction name and the remaining encoding details.""" name, remaining = line.split(" ", 1) name = name.replace(".", "_") # Replace dots for compatibility @@ -48,17 +54,18 @@ def parse_instruction_line(line): # Verify Overlapping Bits -def check_overlapping_bits(encoding, ind, line): +def check_overlapping_bits(encoding: "list[str]", ind: int, line: str): """Check for overlapping bits in the encoding.""" if encoding[31 - ind] != "-": - logging.error( + log_and_exit( f'{line.split(" ")[0]:<10} has {ind} bit overlapping in its opcodes' ) - raise SystemExit(1) # Update encoding for fixed ranges -def update_encoding_for_fixed_range(encoding, msb, lsb, entry_value, line): +def update_encoding_for_fixed_range( + encoding: "list[str]", msb: int, lsb: int, entry_value: int, line: str +): """ Update encoding bits for a given bit range. Checks for overlapping bits and assigns the value accordingly. @@ -70,7 +77,7 @@ def update_encoding_for_fixed_range(encoding, msb, lsb, entry_value, line): # Process fixed bit patterns -def process_fixed_ranges(remaining, encoding, line): +def process_fixed_ranges(remaining: str, encoding: "list[str]", line: str): """Process fixed bit ranges in the encoding.""" for s2, s1, entry in fixed_ranges.findall(remaining): msb, lsb, entry_value = int(s2), int(s1), int(entry, 0) @@ -83,9 +90,9 @@ def process_fixed_ranges(remaining, encoding, line): # Process single bit assignments -def process_single_fixed(remaining, encoding, line): +def process_single_fixed(remaining: str, encoding: "list[str]", line: str): """Process single fixed assignments in the encoding.""" - for lsb, value, drop in single_fixed.findall(remaining): + for lsb, value, _drop in single_fixed.findall(remaining): lsb = int(lsb, 0) value = int(value, 0) @@ -94,7 +101,7 @@ def process_single_fixed(remaining, encoding, line): # Main function to check argument look-up table -def check_arg_lut(args, encoding_args, name): +def check_arg_lut(args: "list[str]", encoding_args: "list[str]", name: str): """Check if arguments are present in arg_lut.""" for arg in args: if arg not in arg_lut: @@ -104,30 +111,28 @@ def check_arg_lut(args, encoding_args, name): # Handle missing argument mappings -def handle_arg_lut_mapping(arg, name): +def handle_arg_lut_mapping(arg: str, name: str): """Handle cases where an argument needs to be mapped to an existing one.""" parts = arg.split("=") if len(parts) == 2: - existing_arg, new_arg = parts + existing_arg, _new_arg = parts if existing_arg in arg_lut: arg_lut[arg] = arg_lut[existing_arg] else: - logging.error( + log_and_exit( f" Found field {existing_arg} in variable {arg} in instruction {name} " f"whose mapping in arg_lut does not exist" ) - raise SystemExit(1) else: - logging.error( + log_and_exit( f" Found variable {arg} in instruction {name} " f"whose mapping in arg_lut does not exist" ) - raise SystemExit(1) return arg # Update encoding args with variables -def update_encoding_args(encoding_args, arg, msb, lsb): +def update_encoding_args(encoding_args: "list[str]", arg: str, msb: int, lsb: int): """Update encoding arguments and ensure no overlapping.""" for ind in range(lsb, msb + 1): check_overlapping_bits(encoding_args, ind, arg) @@ -135,15 +140,26 @@ def update_encoding_args(encoding_args, arg, msb, lsb): # Compute match and mask -def convert_encoding_to_match_mask(encoding): +def convert_encoding_to_match_mask(encoding: "list[str]") -> "tuple[str, str]": """Convert the encoding list to match and mask 
strings.""" match = "".join(encoding).replace("-", "0") mask = "".join(encoding).replace("0", "1").replace("-", "0") return hex(int(match, 2)), hex(int(mask, 2)) +class SingleInstr(TypedDict): + encoding: str + variable_fields: "list[str]" + extension: "list[str]" + match: str + mask: str + + +InstrDict = Dict[str, SingleInstr] + + # Processing main function for a line in the encoding file -def process_enc_line(line, ext): +def process_enc_line(line: str, ext: str) -> "tuple[str, SingleInstr]": """ This function processes each line of the encoding files (rv*). As part of the processing, the function ensures that the encoding is legal through the @@ -199,13 +215,13 @@ def process_enc_line(line, ext): # Extract ISA Type -def extract_isa_type(ext_name): +def extract_isa_type(ext_name: str) -> str: """Extracts the ISA type from the extension name.""" return ext_name.split("_")[0] # Verify the types for RV* -def is_rv_variant(type1, type2): +def is_rv_variant(type1: str, type2: str) -> bool: """Checks if the types are RV variants (rv32/rv64).""" return (type2 == "rv" and type1 in {"rv32", "rv64"}) or ( type1 == "rv" and type2 in {"rv32", "rv64"} @@ -213,77 +229,79 @@ def is_rv_variant(type1, type2): # Check for same base ISA -def has_same_base_isa(type1, type2): +def has_same_base_isa(type1: str, type2: str) -> bool: """Determines if the two ISA types share the same base.""" return type1 == type2 or is_rv_variant(type1, type2) # Compare the base ISA type of a given extension name against a list of extension names -def same_base_isa(ext_name, ext_name_list): +def same_base_isa(ext_name: str, ext_name_list: "list[str]") -> bool: """Checks if the base ISA type of ext_name matches any in ext_name_list.""" type1 = extract_isa_type(ext_name) return any(has_same_base_isa(type1, extract_isa_type(ext)) for ext in ext_name_list) # Pad two strings to equal length -def pad_to_equal_length(str1, str2, pad_char="-"): +def pad_to_equal_length(str1: str, str2: str, pad_char: str = "-") -> "tuple[str, str]": """Pads two strings to equal length using the given padding character.""" max_len = max(len(str1), len(str2)) return str1.rjust(max_len, pad_char), str2.rjust(max_len, pad_char) # Check compatibility for two characters -def has_no_conflict(char1, char2): +def has_no_conflict(char1: str, char2: str) -> bool: """Checks if two characters are compatible (either matching or don't-care).""" return char1 == "-" or char2 == "-" or char1 == char2 # Conflict check between two encoded strings -def overlaps(x, y): +def overlaps(x: str, y: str) -> bool: """Checks if two encoded strings overlap without conflict.""" x, y = pad_to_equal_length(x, y) return all(has_no_conflict(x[i], y[i]) for i in range(len(x))) # Check presence of keys in dictionary. 
-def is_in_nested_dict(a, key1, key2): +def is_in_nested_dict(a: "dict[str, set[str]]", key1: str, key2: str) -> bool: """Checks if key2 exists in the dictionary under key1.""" return key1 in a and key2 in a[key1] # Overlap allowance -def overlap_allowed(a, x, y): +def overlap_allowed(a: "dict[str, set[str]]", x: str, y: str) -> bool: """Determines if overlap is allowed between x and y based on nested dictionary checks""" return is_in_nested_dict(a, x, y) or is_in_nested_dict(a, y, x) # Check overlap allowance between extensions -def extension_overlap_allowed(x, y): +def extension_overlap_allowed(x: str, y: str) -> bool: """Checks if overlap is allowed between two extensions using the overlapping_extensions dictionary.""" return overlap_allowed(overlapping_extensions, x, y) # Check overlap allowance between instructions -def instruction_overlap_allowed(x, y): +def instruction_overlap_allowed(x: str, y: str) -> bool: """Checks if overlap is allowed between two instructions using the overlapping_instructions dictionary.""" return overlap_allowed(overlapping_instructions, x, y) # Check 'nf' field -def is_segmented_instruction(instruction): +def is_segmented_instruction(instruction: SingleInstr) -> bool: """Checks if an instruction contains the 'nf' field.""" return "nf" in instruction["variable_fields"] # Expand 'nf' fields -def update_with_expanded_instructions(updated_dict, key, value): +def update_with_expanded_instructions( + updated_dict: InstrDict, key: str, value: SingleInstr +): """Expands 'nf' fields in the instruction dictionary and updates it with new instructions.""" for new_key, new_value in expand_nf_field(key, value): updated_dict[new_key] = new_value # Process instructions, expanding segmented ones and updating the dictionary -def add_segmented_vls_insn(instr_dict): +def add_segmented_vls_insn(instr_dict: InstrDict) -> InstrDict: """Processes instructions, expanding segmented ones and updating the dictionary.""" # Use dictionary comprehension for efficiency return dict( @@ -299,7 +317,9 @@ def add_segmented_vls_insn(instr_dict): # Expand the 'nf' field in the instruction dictionary -def expand_nf_field(name, single_dict): +def expand_nf_field( + name: str, single_dict: SingleInstr +) -> "list[tuple[str, SingleInstr]]": """Validate and prepare the instruction dictionary.""" validate_nf_field(single_dict, name) remove_nf_field(single_dict) @@ -322,29 +342,33 @@ def expand_nf_field(name, single_dict): # Validate the presence of 'nf' -def validate_nf_field(single_dict, name): +def validate_nf_field(single_dict: SingleInstr, name: str): """Validates the presence of 'nf' in variable fields before expansion.""" if "nf" not in single_dict["variable_fields"]: - logging.error(f"Cannot expand nf field for instruction {name}") - raise SystemExit(1) + log_and_exit(f"Cannot expand nf field for instruction {name}") # Remove 'nf' from variable fields -def remove_nf_field(single_dict): +def remove_nf_field(single_dict: SingleInstr): """Removes 'nf' from variable fields in the instruction dictionary.""" single_dict["variable_fields"].remove("nf") # Update the mask to include the 'nf' field -def update_mask(single_dict): +def update_mask(single_dict: SingleInstr): """Updates the mask to include the 'nf' field in the instruction dictionary.""" single_dict["mask"] = hex(int(single_dict["mask"], 16) | 0b111 << 29) # Create an expanded instruction def create_expanded_instruction( - name, single_dict, nf, name_expand_index, base_match, encoding_prefix -): + name: str, + single_dict: SingleInstr, + 
nf: int, + name_expand_index: int, + base_match: int, + encoding_prefix: str, +) -> "tuple[str, SingleInstr]": """Creates an expanded instruction based on 'nf' value.""" new_single_dict = copy.deepcopy(single_dict) @@ -363,7 +387,7 @@ def create_expanded_instruction( # Return a list of relevant lines from the specified file -def read_lines(file): +def read_lines(file: str) -> "list[str]": """Reads lines from a file and returns non-blank, non-comment lines.""" with open(file) as fp: lines = (line.rstrip() for line in fp) @@ -371,7 +395,9 @@ def read_lines(file): # Update the instruction dictionary -def process_standard_instructions(lines, instr_dict, file_name): +def process_standard_instructions( + lines: "list[str]", instr_dict: InstrDict, file_name: str +): """Processes standard instructions from the given lines and updates the instruction dictionary.""" for line in lines: if "$import" in line or "$pseudo" in line: @@ -409,7 +435,12 @@ def process_standard_instructions(lines, instr_dict, file_name): # Incorporate pseudo instructions into the instruction dictionary based on given conditions def process_pseudo_instructions( - lines, instr_dict, file_name, opcodes_dir, include_pseudo, include_pseudo_ops + lines: "list[str]", + instr_dict: InstrDict, + file_name: str, + opcodes_dir: str, + include_pseudo: bool, + include_pseudo_ops: "list[str]", ): """Processes pseudo instructions from the given lines and updates the instruction dictionary.""" for line in lines: @@ -433,12 +464,15 @@ def process_pseudo_instructions( else: if single_dict["match"] != instr_dict[name]["match"]: instr_dict[f"{name}_pseudo"] = single_dict - elif single_dict["extension"] not in instr_dict[name]["extension"]: + # TODO: This expression is always false since both sides are list[str]. + elif single_dict["extension"] not in instr_dict[name]["extension"]: # type: ignore instr_dict[name]["extension"].extend(single_dict["extension"]) # Integrate imported instructions into the instruction dictionary -def process_imported_instructions(lines, instr_dict, file_name, opcodes_dir): +def process_imported_instructions( + lines: "list[str]", instr_dict: InstrDict, file_name: str, opcodes_dir: str +): """Processes imported instructions from the given lines and updates the instruction dictionary.""" for line in lines: if "$import" not in line: @@ -464,7 +498,7 @@ def process_imported_instructions(lines, instr_dict, file_name, opcodes_dir): # Locate the path of the specified extension file, checking fallback directories -def find_extension_file(ext, opcodes_dir): +def find_extension_file(ext: str, opcodes_dir: str): """Finds the extension file path, considering the unratified directory if necessary.""" ext_file = f"{opcodes_dir}/{ext}" if not os.path.exists(ext_file): @@ -475,7 +509,9 @@ def find_extension_file(ext, opcodes_dir): # Confirm the presence of an original instruction in the corresponding extension file. 
-def validate_instruction_in_extension(inst, ext_file, file_name, pseudo_inst): +def validate_instruction_in_extension( + inst: str, ext_file: str, file_name: str, pseudo_inst: str +): """Validates if the original instruction exists in the dependent extension.""" found = False for oline in open(ext_file): @@ -489,7 +525,11 @@ def validate_instruction_in_extension(inst, ext_file, file_name, pseudo_inst): # Construct a dictionary of instructions filtered by specified criteria -def create_inst_dict(file_filter, include_pseudo=False, include_pseudo_ops=[]): +def create_inst_dict( + file_filter: "list[str]", + include_pseudo: bool = False, + include_pseudo_ops: "list[str]" = [], +) -> InstrDict: """Creates a dictionary of instructions based on the provided file filters.""" """ @@ -522,7 +562,7 @@ def create_inst_dict(file_filter, include_pseudo=False, include_pseudo_ops=[]): is not already present; otherwise, it is skipped. """ opcodes_dir = os.path.dirname(os.path.realpath(__file__)) - instr_dict = {} + instr_dict: InstrDict = {} file_names = [ file @@ -559,17 +599,10 @@ def create_inst_dict(file_filter, include_pseudo=False, include_pseudo_ops=[]): # Extracts the extensions used in an instruction dictionary -def instr_dict_2_extensions(instr_dict): +def instr_dict_2_extensions(instr_dict: InstrDict) -> "list[str]": return list({item["extension"][0] for item in instr_dict.values()}) # Returns signed interpretation of a value within a given width -def signed(value, width): +def signed(value: int, width: int) -> int: return value if 0 <= value < (1 << (width - 1)) else value - (1 << width) - - -# Log an error message -def log_and_exit(message): - """Log an error message and exit the program.""" - logging.error(message) - raise SystemExit(1) From 49a0e3e9ec54faf12e7fc0e90695771f1e9e0b2d Mon Sep 17 00:00:00 2001 From: Jay Dev Jha Date: Sat, 2 Nov 2024 20:17:35 +0530 Subject: [PATCH 6/6] Updated test.py Signed-off-by: Jay Dev Jha --- test.py | 833 ++++++++++++++++---------------------------------------- 1 file changed, 232 insertions(+), 601 deletions(-) diff --git a/test.py b/test.py index 34482b1b..aa091b72 100644 --- a/test.py +++ b/test.py @@ -1,608 +1,239 @@ #!/usr/bin/env python3 -import copy -import glob -import logging -import os -import pprint -import re -from itertools import chain -from typing import Dict, TypedDict - -from constants import * - -LOG_FORMAT = "%(levelname)s:: %(message)s" -LOG_LEVEL = logging.INFO - -pretty_printer = pprint.PrettyPrinter(indent=2) -logging.basicConfig(level=LOG_LEVEL, format=LOG_FORMAT) - - -# Log an error message -def log_and_exit(message: str): - """Log an error message and exit the program.""" - logging.error(message) - raise SystemExit(1) - - -# Initialize encoding to 32-bit '-' values -def initialize_encoding(bits: int = 32) -> "list[str]": - """Initialize encoding with '-' to represent don't care bits.""" - return ["-"] * bits - - -# Validate bit range and value -def validate_bit_range(msb: int, lsb: int, entry_value: int, line: str): - """Validate the bit range and entry value.""" - if msb < lsb: - log_and_exit( - f'{line.split(" ")[0]:<10} has position {msb} less than position {lsb} in its encoding' - ) - - if entry_value >= (1 << (msb - lsb + 1)): - log_and_exit( - f'{line.split(" ")[0]:<10} has an illegal value {entry_value} assigned as per the bit width {msb - lsb}' - ) - - -# Split the instruction line into name and remaining part -def parse_instruction_line(line: str) -> "tuple[str, str]": - """Parse the instruction name and the remaining 
encoding details.""" - name, remaining = line.split(" ", 1) - name = name.replace(".", "_") # Replace dots for compatibility - remaining = remaining.lstrip() # Remove leading whitespace - return name, remaining - - -# Verify Overlapping Bits -def check_overlapping_bits(encoding: "list[str]", ind: int, line: str): - """Check for overlapping bits in the encoding.""" - if encoding[31 - ind] != "-": - log_and_exit( - f'{line.split(" ")[0]:<10} has {ind} bit overlapping in its opcodes' - ) - - -# Update encoding for fixed ranges -def update_encoding_for_fixed_range( - encoding: "list[str]", msb: int, lsb: int, entry_value: int, line: str -): - """ - Update encoding bits for a given bit range. - Checks for overlapping bits and assigns the value accordingly. - """ - for ind in range(lsb, msb + 1): - check_overlapping_bits(encoding, ind, line) - bit = str((entry_value >> (ind - lsb)) & 1) - encoding[31 - ind] = bit - - -# Process fixed bit patterns -def process_fixed_ranges(remaining: str, encoding: "list[str]", line: str): - """Process fixed bit ranges in the encoding.""" - for s2, s1, entry in fixed_ranges.findall(remaining): - msb, lsb, entry_value = int(s2), int(s1), int(entry, 0) - - # Validate bit range and entry value - validate_bit_range(msb, lsb, entry_value, line) - update_encoding_for_fixed_range(encoding, msb, lsb, entry_value, line) - - return fixed_ranges.sub(" ", remaining) - - -# Process single bit assignments -def process_single_fixed(remaining: str, encoding: "list[str]", line: str): - """Process single fixed assignments in the encoding.""" - for lsb, value, _drop in single_fixed.findall(remaining): - lsb = int(lsb, 0) - value = int(value, 0) - - check_overlapping_bits(encoding, lsb, line) - encoding[31 - lsb] = str(value) - - -# Main function to check argument look-up table -def check_arg_lut(args: "list[str]", encoding_args: "list[str]", name: str): - """Check if arguments are present in arg_lut.""" - for arg in args: - if arg not in arg_lut: - arg = handle_arg_lut_mapping(arg, name) - msb, lsb = arg_lut[arg] - update_encoding_args(encoding_args, arg, msb, lsb) - - -# Handle missing argument mappings -def handle_arg_lut_mapping(arg: str, name: str): - """Handle cases where an argument needs to be mapped to an existing one.""" - parts = arg.split("=") - if len(parts) == 2: - existing_arg, _new_arg = parts - if existing_arg in arg_lut: - arg_lut[arg] = arg_lut[existing_arg] - else: - log_and_exit( - f" Found field {existing_arg} in variable {arg} in instruction {name} " - f"whose mapping in arg_lut does not exist" - ) - else: - log_and_exit( - f" Found variable {arg} in instruction {name} " - f"whose mapping in arg_lut does not exist" - ) - return arg - - -# Update encoding args with variables -def update_encoding_args(encoding_args: "list[str]", arg: str, msb: int, lsb: int): - """Update encoding arguments and ensure no overlapping.""" - for ind in range(lsb, msb + 1): - check_overlapping_bits(encoding_args, ind, arg) - encoding_args[31 - ind] = arg - - -# Compute match and mask -def convert_encoding_to_match_mask(encoding: "list[str]") -> "tuple[str, str]": - """Convert the encoding list to match and mask strings.""" - match = "".join(encoding).replace("-", "0") - mask = "".join(encoding).replace("0", "1").replace("-", "0") - return hex(int(match, 2)), hex(int(mask, 2)) - - -class SingleInstr(TypedDict): - encoding: str - variable_fields: "list[str]" - extension: "list[str]" - match: str - mask: str - - -InstrDict = Dict[str, SingleInstr] - - -# Processing main function for a 
line in the encoding file -def process_enc_line(line: str, ext: str) -> "tuple[str, SingleInstr]": - """ - This function processes each line of the encoding files (rv*). As part of - the processing, the function ensures that the encoding is legal through the - following checks:: - - there is no over specification (same bits assigned different values) - - there is no under specification (some bits not assigned values) - - bit ranges are in the format hi..lo=val where hi > lo - - value assigned is representable in the bit range - - also checks that the mapping of arguments of an instruction exists in - arg_lut. - If the above checks pass, then the function returns a tuple of the name and - a dictionary containing basic information of the instruction which includes: - - variables: list of arguments used by the instruction whose mapping - exists in the arg_lut dictionary - - encoding: this contains the 32-bit encoding of the instruction where - '-' is used to represent position of arguments and 1/0 is used to - reprsent the static encoding of the bits - - extension: this field contains the rv* filename from which this - instruction was included - - match: hex value representing the bits that need to match to detect - this instruction - - mask: hex value representin the bits that need to be masked to extract - the value required for matching. - """ - encoding = initialize_encoding() - - # Parse the instruction line - name, remaining = parse_instruction_line(line) - - # Process fixed ranges - remaining = process_fixed_ranges(remaining, encoding, line) - - # Process single fixed assignments - process_single_fixed(remaining, encoding, line) - - # Convert the list of encodings into a match and mask - match, mask = convert_encoding_to_match_mask(encoding) - # Check arguments in arg_lut - args = single_fixed.sub(" ", remaining).split() - encoding_args = encoding.copy() - - check_arg_lut(args, encoding_args, name) - - # Return single_dict - return name, { - "encoding": "".join(encoding), - "variable_fields": args, - "extension": [os.path.basename(ext)], - "match": match, - "mask": mask, - } - - -# Extract ISA Type -def extract_isa_type(ext_name: str) -> str: - """Extracts the ISA type from the extension name.""" - return ext_name.split("_")[0] - - -# Verify the types for RV* -def is_rv_variant(type1: str, type2: str) -> bool: - """Checks if the types are RV variants (rv32/rv64).""" - return (type2 == "rv" and type1 in {"rv32", "rv64"}) or ( - type1 == "rv" and type2 in {"rv32", "rv64"} - ) - - -# Check for same base ISA -def has_same_base_isa(type1: str, type2: str) -> bool: - """Determines if the two ISA types share the same base.""" - return type1 == type2 or is_rv_variant(type1, type2) - - -# Compare the base ISA type of a given extension name against a list of extension names -def same_base_isa(ext_name: str, ext_name_list: "list[str]") -> bool: - """Checks if the base ISA type of ext_name matches any in ext_name_list.""" - type1 = extract_isa_type(ext_name) - return any(has_same_base_isa(type1, extract_isa_type(ext)) for ext in ext_name_list) - - -# Pad two strings to equal length -def pad_to_equal_length(str1: str, str2: str, pad_char: str = "-") -> "tuple[str, str]": - """Pads two strings to equal length using the given padding character.""" - max_len = max(len(str1), len(str2)) - return str1.rjust(max_len, pad_char), str2.rjust(max_len, pad_char) - - -# Check compatibility for two characters -def has_no_conflict(char1: str, char2: str) -> bool: - """Checks if two characters are compatible 
(either matching or don't-care).""" - return char1 == "-" or char2 == "-" or char1 == char2 - - -# Conflict check between two encoded strings -def overlaps(x: str, y: str) -> bool: - """Checks if two encoded strings overlap without conflict.""" - x, y = pad_to_equal_length(x, y) - return all(has_no_conflict(x[i], y[i]) for i in range(len(x))) - - -# Check presence of keys in dictionary. -def is_in_nested_dict(a: "dict[str, set[str]]", key1: str, key2: str) -> bool: - """Checks if key2 exists in the dictionary under key1.""" - return key1 in a and key2 in a[key1] - - -# Overlap allowance -def overlap_allowed(a: "dict[str, set[str]]", x: str, y: str) -> bool: - """Determines if overlap is allowed between x and y based on nested dictionary checks""" - return is_in_nested_dict(a, x, y) or is_in_nested_dict(a, y, x) - - -# Check overlap allowance between extensions -def extension_overlap_allowed(x: str, y: str) -> bool: - """Checks if overlap is allowed between two extensions using the overlapping_extensions dictionary.""" - return overlap_allowed(overlapping_extensions, x, y) - - -# Check overlap allowance between instructions -def instruction_overlap_allowed(x: str, y: str) -> bool: - """Checks if overlap is allowed between two instructions using the overlapping_instructions dictionary.""" - return overlap_allowed(overlapping_instructions, x, y) - - -# Check 'nf' field -def is_segmented_instruction(instruction: SingleInstr) -> bool: - """Checks if an instruction contains the 'nf' field.""" - return "nf" in instruction["variable_fields"] - - -# Expand 'nf' fields -def update_with_expanded_instructions( - updated_dict: InstrDict, key: str, value: SingleInstr -): - """Expands 'nf' fields in the instruction dictionary and updates it with new instructions.""" - for new_key, new_value in expand_nf_field(key, value): - updated_dict[new_key] = new_value - - -# Process instructions, expanding segmented ones and updating the dictionary -def add_segmented_vls_insn(instr_dict: InstrDict) -> InstrDict: - """Processes instructions, expanding segmented ones and updating the dictionary.""" - # Use dictionary comprehension for efficiency - return dict( - chain.from_iterable( - ( - expand_nf_field(key, value) - if is_segmented_instruction(value) - else [(key, value)] - ) - for key, value in instr_dict.items() - ) - ) - - -# Expand the 'nf' field in the instruction dictionary -def expand_nf_field( - name: str, single_dict: SingleInstr -) -> "list[tuple[str, SingleInstr]]": - """Validate and prepare the instruction dictionary.""" - validate_nf_field(single_dict, name) - remove_nf_field(single_dict) - update_mask(single_dict) - - name_expand_index = name.find("e") +import logging +import unittest +from unittest.mock import Mock, patch + +from shared_utils import * + + +class EncodingUtilsTest(unittest.TestCase): + """Tests for basic encoding utilities""" - # Pre compute the base match value and encoding prefix - base_match = int(single_dict["match"], 16) - encoding_prefix = single_dict["encoding"][3:] + def setUp(self): + self.logger = logging.getLogger() + self.logger.disabled = True - expanded_instructions = [ - create_expanded_instruction( - name, single_dict, nf, name_expand_index, base_match, encoding_prefix - ) - for nf in range(8) # Range of 0 to 7 - ] - - return expanded_instructions - - -# Validate the presence of 'nf' -def validate_nf_field(single_dict: SingleInstr, name: str): - """Validates the presence of 'nf' in variable fields before expansion.""" - if "nf" not in single_dict["variable_fields"]: - 
log_and_exit(f"Cannot expand nf field for instruction {name}") - - -# Remove 'nf' from variable fields -def remove_nf_field(single_dict: SingleInstr): - """Removes 'nf' from variable fields in the instruction dictionary.""" - single_dict["variable_fields"].remove("nf") - - -# Update the mask to include the 'nf' field -def update_mask(single_dict: SingleInstr): - """Updates the mask to include the 'nf' field in the instruction dictionary.""" - single_dict["mask"] = hex(int(single_dict["mask"], 16) | 0b111 << 29) - - -# Create an expanded instruction -def create_expanded_instruction( - name: str, - single_dict: SingleInstr, - nf: int, - name_expand_index: int, - base_match: int, - encoding_prefix: str, -) -> "tuple[str, SingleInstr]": - """Creates an expanded instruction based on 'nf' value.""" - new_single_dict = copy.deepcopy(single_dict) - - # Update match value in one step - new_single_dict["match"] = hex(base_match | (nf << 29)) - new_single_dict["encoding"] = format(nf, "03b") + encoding_prefix - - # Construct new instruction name - new_name = ( - name - if nf == 0 - else f"{name[:name_expand_index]}seg{nf + 1}{name[name_expand_index:]}" - ) - - return (new_name, new_single_dict) - - -# Return a list of relevant lines from the specified file -def read_lines(file: str) -> "list[str]": - """Reads lines from a file and returns non-blank, non-comment lines.""" - with open(file) as fp: - lines = (line.rstrip() for line in fp) - return [line for line in lines if line and not line.startswith("#")] - - -# Update the instruction dictionary -def process_standard_instructions( - lines: "list[str]", instr_dict: InstrDict, file_name: str -): - """Processes standard instructions from the given lines and updates the instruction dictionary.""" - for line in lines: - if "$import" in line or "$pseudo" in line: - continue - logging.debug(f"Processing line: {line}") - name, single_dict = process_enc_line(line, file_name) - ext_name = os.path.basename(file_name) - - if name in instr_dict: - var = instr_dict[name]["extension"] - if same_base_isa(ext_name, var): - log_and_exit( - f"Instruction {name} from {ext_name} is already added from {var} in same base ISA" - ) - elif instr_dict[name]["encoding"] != single_dict["encoding"]: - log_and_exit( - f"Instruction {name} from {ext_name} has different encodings in different base ISAs" - ) - - instr_dict[name]["extension"].extend(single_dict["extension"]) - else: - for key, item in instr_dict.items(): - if ( - overlaps(item["encoding"], single_dict["encoding"]) - and not extension_overlap_allowed(ext_name, item["extension"][0]) - and not instruction_overlap_allowed(name, key) - and same_base_isa(ext_name, item["extension"]) - ): - log_and_exit( - f'Instruction {name} in extension {ext_name} overlaps with {key} in {item["extension"]}' - ) - - instr_dict[name] = single_dict - - -# Incorporate pseudo instructions into the instruction dictionary based on given conditions -def process_pseudo_instructions( - lines: "list[str]", - instr_dict: InstrDict, - file_name: str, - opcodes_dir: str, - include_pseudo: bool, - include_pseudo_ops: "list[str]", -): - """Processes pseudo instructions from the given lines and updates the instruction dictionary.""" - for line in lines: - if "$pseudo" not in line: - continue - logging.debug(f"Processing pseudo line: {line}") - ext, orig_inst, pseudo_inst, line_content = pseudo_regex.findall(line)[0] - ext_file = find_extension_file(ext, opcodes_dir) - - validate_instruction_in_extension(orig_inst, ext_file, file_name, pseudo_inst) - - 
name, single_dict = process_enc_line(f"{pseudo_inst} {line_content}", file_name) - if ( - orig_inst.replace(".", "_") not in instr_dict - or include_pseudo - or name in include_pseudo_ops - ): - if name not in instr_dict: - instr_dict[name] = single_dict - logging.debug(f"Including pseudo_op: {name}") - else: - if single_dict["match"] != instr_dict[name]["match"]: - instr_dict[f"{name}_pseudo"] = single_dict - # TODO: This expression is always false since both sides are list[str]. - elif single_dict["extension"] not in instr_dict[name]["extension"]: # type: ignore - instr_dict[name]["extension"].extend(single_dict["extension"]) - - -# Integrate imported instructions into the instruction dictionary -def process_imported_instructions( - lines: "list[str]", instr_dict: InstrDict, file_name: str, opcodes_dir: str -): - """Processes imported instructions from the given lines and updates the instruction dictionary.""" - for line in lines: - if "$import" not in line: - continue - logging.debug(f"Processing imported line: {line}") - import_ext, reg_instr = imported_regex.findall(line)[0] - ext_file = find_extension_file(import_ext, opcodes_dir) - - validate_instruction_in_extension(reg_instr, ext_file, file_name, line) - - for oline in open(ext_file): - if re.findall(f"^\\s*{reg_instr}\\s+", oline): - name, single_dict = process_enc_line(oline, file_name) - if name in instr_dict: - if instr_dict[name]["encoding"] != single_dict["encoding"]: - log_and_exit( - f"Imported instruction {name} from {os.path.basename(file_name)} has different encodings" - ) - instr_dict[name]["extension"].extend(single_dict["extension"]) - else: - instr_dict[name] = single_dict - break - - -# Locate the path of the specified extension file, checking fallback directories -def find_extension_file(ext: str, opcodes_dir: str): - """Finds the extension file path, considering the unratified directory if necessary.""" - ext_file = f"{opcodes_dir}/{ext}" - if not os.path.exists(ext_file): - ext_file = f"{opcodes_dir}/unratified/{ext}" - if not os.path.exists(ext_file): - log_and_exit(f"Extension {ext} not found.") - return ext_file - - -# Confirm the presence of an original instruction in the corresponding extension file. 
-def validate_instruction_in_extension( - inst: str, ext_file: str, file_name: str, pseudo_inst: str -): - """Validates if the original instruction exists in the dependent extension.""" - found = False - for oline in open(ext_file): - if re.findall(f"^\\s*{inst}\\s+", oline): - found = True - break - if not found: - log_and_exit( - f"Original instruction {inst} required by pseudo_op {pseudo_inst} in {file_name} not found in {ext_file}" - ) + def test_initialize_encoding(self): + """Test encoding initialization with different bit lengths""" + self.assertEqual(initialize_encoding(32), ["-"] * 32) + self.assertEqual(initialize_encoding(16), ["-"] * 16) + self.assertEqual(initialize_encoding(), ["-"] * 32) # default case + def test_validate_bit_range(self): + """Test bit range validation""" + # Valid cases + validate_bit_range(7, 3, 15, "test_instr") # 15 fits in 5 bits + validate_bit_range(31, 0, 0xFFFFFFFF, "test_instr") # max 32-bit value + + # Invalid cases + with self.assertRaises(SystemExit): + validate_bit_range(3, 7, 1, "test_instr") # msb < lsb + with self.assertRaises(SystemExit): + validate_bit_range(3, 0, 16, "test_instr") # value too large for range -# Construct a dictionary of instructions filtered by specified criteria -def create_inst_dict( - file_filter: "list[str]", - include_pseudo: bool = False, - include_pseudo_ops: "list[str]" = [], -) -> InstrDict: - """Creates a dictionary of instructions based on the provided file filters.""" - - """ - This function return a dictionary containing all instructions associated - with an extension defined by the file_filter input. - Allowed input extensions: needs to be rv* file name without the 'rv' prefix i.e. '_i', '32_i', etc. - Each node of the dictionary will correspond to an instruction which again is - a dictionary. The dictionary contents of each instruction includes: - - variables: list of arguments used by the instruction whose mapping - exists in the arg_lut dictionary - - encoding: this contains the 32-bit encoding of the instruction where - '-' is used to represent position of arguments and 1/0 is used to - reprsent the static encoding of the bits - - extension: this field contains the rv* filename from which this - instruction was included - - match: hex value representing the bits that need to match to detect - this instruction - - mask: hex value representin the bits that need to be masked to extract - the value required for matching. - In order to build this dictionary, the function does 2 passes over the same - rv file: - - First pass: extracts all standard instructions, skipping pseudo ops - and imported instructions. For each selected line, the `process_enc_line` - function is called to create the dictionary contents of the instruction. - Checks are performed to ensure that the same instruction is not added - twice to the overall dictionary. - - Second pass: parses only pseudo_ops. For each pseudo_op, the function: - - Checks if the dependent extension and instruction exist. - - Adds the pseudo_op to the dictionary if the dependent instruction - is not already present; otherwise, it is skipped. 
- """ - opcodes_dir = os.path.dirname(os.path.realpath(__file__)) - instr_dict: InstrDict = {} - - file_names = [ - file - for fil in file_filter - for file in sorted(glob.glob(f"{opcodes_dir}/{fil}"), reverse=True) - ] - - logging.debug("Collecting standard instructions") - for file_name in file_names: - logging.debug(f"Parsing File: {file_name} for standard instructions") - lines = read_lines(file_name) - process_standard_instructions(lines, instr_dict, file_name) - - logging.debug("Collecting pseudo instructions") - for file_name in file_names: - logging.debug(f"Parsing File: {file_name} for pseudo instructions") - lines = read_lines(file_name) - process_pseudo_instructions( - lines, - instr_dict, - file_name, - opcodes_dir, - include_pseudo, - include_pseudo_ops, + def test_parse_instruction_line(self): + """Test instruction line parsing""" + name, remaining = parse_instruction_line("add.w r1, r2, r3") + self.assertEqual(name, "add_w") + self.assertEqual(remaining, "r1, r2, r3") + + name, remaining = parse_instruction_line("lui rd imm20 6..2=0x0D") + self.assertEqual(name, "lui") + self.assertEqual(remaining, "rd imm20 6..2=0x0D") + + +class BitManipulationTest(unittest.TestCase): + """Tests for bit manipulation and checking functions""" + + def setUp(self): + self.logger = logging.getLogger() + self.logger.disabled = True + self.test_encoding = initialize_encoding() + + def test_check_overlapping_bits(self): + """Test overlapping bits detection""" + # Valid case - no overlap + self.test_encoding[31 - 5] = "-" + check_overlapping_bits(self.test_encoding, 5, "test_instr") + + # Invalid case - overlap + self.test_encoding[31 - 5] = "1" + with self.assertRaises(SystemExit): + check_overlapping_bits(self.test_encoding, 5, "test_instr") + + def test_update_encoding_for_fixed_range(self): + """Test encoding updates for fixed ranges""" + encoding = initialize_encoding() + update_encoding_for_fixed_range(encoding, 6, 2, 0x0D, "test_instr") + + # Check specific bits are set correctly + self.assertEqual(encoding[31 - 6 : 31 - 1], ["0", "1", "1", "0", "1"]) + + def test_process_fixed_ranges(self): + """Test processing of fixed bit ranges""" + encoding = initialize_encoding() + remaining = "rd imm20 6..2=0x0D 1..0=3" + + result = process_fixed_ranges(remaining, encoding, "test_instr") + self.assertNotIn("6..2=0x0D", result) + self.assertNotIn("1..0=3", result) + + +class EncodingArgsTest(unittest.TestCase): + """Tests for encoding arguments handling""" + + def setUp(self): + self.logger = logging.getLogger() + self.logger.disabled = True + + @patch.dict("shared_utils.arg_lut", {"rd": (11, 7), "rs1": (19, 15)}) + def test_check_arg_lut(self): + """Test argument lookup table checking""" + encoding_args = initialize_encoding() + args = ["rd", "rs1"] + check_arg_lut(args, encoding_args, "test_instr") + + # Verify encoding_args has been updated correctly + self.assertEqual(encoding_args[31 - 11 : 31 - 6], ["rd"] * 5) + self.assertEqual(encoding_args[31 - 19 : 31 - 14], ["rs1"] * 5) + + @patch.dict("shared_utils.arg_lut", {"rs1": (19, 15)}) + def test_handle_arg_lut_mapping(self): + """Test handling of argument mappings""" + # Valid mapping + result = handle_arg_lut_mapping("rs1=new_arg", "test_instr") + self.assertEqual(result, "rs1=new_arg") + + # Invalid mapping + with self.assertRaises(SystemExit): + handle_arg_lut_mapping("invalid_arg=new_arg", "test_instr") + + +class ISAHandlingTest(unittest.TestCase): + """Tests for ISA type handling and validation""" + + def test_extract_isa_type(self): + 
"""Test ISA type extraction""" + self.assertEqual(extract_isa_type("rv32_i"), "rv32") + self.assertEqual(extract_isa_type("rv64_m"), "rv64") + self.assertEqual(extract_isa_type("rv_c"), "rv") + + def test_is_rv_variant(self): + """Test RV variant checking""" + self.assertTrue(is_rv_variant("rv32", "rv")) + self.assertTrue(is_rv_variant("rv", "rv64")) + self.assertFalse(is_rv_variant("rv32", "rv64")) + + def test_same_base_isa(self): + """Test base ISA comparison""" + self.assertTrue(same_base_isa("rv32_i", ["rv32_m", "rv32_a"])) + self.assertTrue(same_base_isa("rv_i", ["rv32_i", "rv64_i"])) + self.assertFalse(same_base_isa("rv32_i", ["rv64_m"])) + + +class StringManipulationTest(unittest.TestCase): + """Tests for string manipulation utilities""" + + def test_pad_to_equal_length(self): + """Test string padding""" + str1, str2 = pad_to_equal_length("101", "1101") + self.assertEqual(len(str1), len(str2)) + self.assertEqual(str1, "-101") + self.assertEqual(str2, "1101") + + def test_overlaps(self): + """Test string overlap checking""" + self.assertTrue(overlaps("1-1", "101")) + self.assertTrue(overlaps("---", "101")) + self.assertFalse(overlaps("111", "101")) + + +class InstructionProcessingTest(unittest.TestCase): + """Tests for instruction processing and validation""" + + def setUp(self): + self.logger = logging.getLogger() + self.logger.disabled = True + # Create a patch for arg_lut + self.arg_lut_patcher = patch.dict( + "shared_utils.arg_lut", {"rd": (11, 7), "imm20": (31, 12)} ) - - logging.debug("Collecting imported instructions") - for file_name in file_names: - logging.debug(f"Parsing File: {file_name} for imported instructions") - lines = read_lines(file_name) - process_imported_instructions(lines, instr_dict, file_name, opcodes_dir) - - return instr_dict - - -# Extracts the extensions used in an instruction dictionary -def instr_dict_2_extensions(instr_dict: InstrDict) -> "list[str]": - return list({item["extension"][0] for item in instr_dict.values()}) - - -# Returns signed interpretation of a value within a given width -def signed(value: int, width: int) -> int: - return value if 0 <= value < (1 << (width - 1)) else value - (1 << width) + self.arg_lut_patcher.start() + + def tearDown(self): + self.arg_lut_patcher.stop() + + @patch("shared_utils.fixed_ranges") + @patch("shared_utils.single_fixed") + def test_process_enc_line(self, mock_single_fixed: Mock, mock_fixed_ranges: Mock): + """Test processing of encoding lines""" + # Setup mock return values + mock_fixed_ranges.findall.return_value = [(6, 2, "0x0D")] + mock_fixed_ranges.sub.return_value = "rd imm20" + mock_single_fixed.findall.return_value = [] + mock_single_fixed.sub.return_value = "rd imm20" + + # Create a mock for split() that returns the expected list + mock_split = Mock(return_value=["rd", "imm20"]) + mock_single_fixed.sub.return_value = Mock(split=mock_split) + + name, data = process_enc_line("lui rd imm20 6..2=0x0D", "rv_i") + + self.assertEqual(name, "lui") + self.assertEqual(data["extension"], ["rv_i"]) + self.assertIn("rd", data["variable_fields"]) + self.assertIn("imm20", data["variable_fields"]) + + @patch("os.path.exists") + @patch("shared_utils.logging.error") + def test_find_extension_file(self, mock_logging: Mock, mock_exists: Mock): + """Test extension file finding""" + # Test successful case - file exists in main directory + mock_exists.side_effect = [True, False] + result = find_extension_file("rv32i", "/path/to/opcodes") + self.assertEqual(result, "/path/to/opcodes/rv32i") + + # Test successful case - 
file exists in unratified directory + mock_exists.side_effect = [False, True] + result = find_extension_file("rv32i", "/path/to/opcodes") + self.assertEqual(result, "/path/to/opcodes/unratified/rv32i") + + # Test failure case - file doesn't exist anywhere + mock_exists.side_effect = [False, False] + with self.assertRaises(SystemExit): + find_extension_file("rv32i", "/path/to/opcodes") + mock_logging.assert_called_with("Extension rv32i not found.") + + def test_process_standard_instructions(self): + """Test processing of standard instructions""" + lines = [ + "add rd rs1 rs2 31..25=0 14..12=0 6..2=0x0C 1..0=3", + "sub rd rs1 rs2 31..25=0x20 14..12=0 6..2=0x0C 1..0=3", + "$pseudo add_pseudo rd rs1 rs2", # Should be skipped + "$import rv32i::mul", # Should be skipped + ] + + instr_dict: InstrDict = {} + file_name = "rv32i" + + with patch("shared_utils.process_enc_line") as mock_process_enc: + # Setup mock return values + mock_process_enc.side_effect = [ + ("add", {"extension": ["rv32i"], "encoding": "encoding1"}), + ("sub", {"extension": ["rv32i"], "encoding": "encoding2"}), + ] + + process_standard_instructions(lines, instr_dict, file_name) + + # Verify process_enc_line was called twice (skipping pseudo and import) + self.assertEqual(mock_process_enc.call_count, 2) + + # Verify the instruction dictionary was updated correctly + self.assertEqual(len(instr_dict), 2) + self.assertIn("add", instr_dict) + self.assertIn("sub", instr_dict) + + +if __name__ == "__main__": + unittest.main()
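
For reviewers who want to exercise the refactored helpers outside the unit tests, below is a minimal usage sketch; it is not part of the patch series. It assumes it runs from a riscv-opcodes checkout so that shared_utils can import arg_lut and the fixed_ranges/single_fixed regexes from constants.py, and the expected match/mask values simply mirror the lui example already used by the tests in this series.

    #!/usr/bin/env python3
    # Illustrative only: drive process_enc_line end to end on the lui example.
    # Assumes constants.py (arg_lut, fixed_ranges, single_fixed) is importable,
    # as in the existing riscv-opcodes layout.
    from shared_utils import process_enc_line

    # Same encoding line shape the tests build their expectations around.
    name, data = process_enc_line("lui rd imm20 6..2=0x0D 1..0=3", "rv_i")

    assert name == "lui"
    assert data["extension"] == ["rv_i"]
    assert data["match"] == "0x37"  # fixed bits 6..0 = 0110111
    assert data["mask"] == "0x7f"   # only the low 7 opcode bits are fixed
    assert set(data["variable_fields"]) == {"rd", "imm20"}
    print(name, data["match"], data["mask"])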