diff --git a/diff.py b/diff.py index 249c3ca..2612938 100755 --- a/diff.py +++ b/diff.py @@ -1211,7 +1211,8 @@ def preprocess_objdump_out( + out ) - return out + processor = config.arch.proc(config) + return processor.preprocess_objdump(out) def search_build_objects(objname: str, project: ProjectSettings) -> Optional[str]: @@ -1602,15 +1603,16 @@ def dump_binary( ) -# Example: "ldr r4, [pc, #56] ; (4c )" -ARM32_LOAD_POOL_PATTERN = r"(ldr\s+r([0-9]|1[0-3]),\s+\[pc,.*;\s*)(\([a-fA-F0-9]+.*\))" - - # The base class is a no-op. class AsmProcessor: def __init__(self, config: Config) -> None: self.config = config + # Called during run_objdump() for arch-specific normalization. Runs before + # diff-processing, i.e. process(). + def preprocess_objdump(self, objdump: str) -> str: + return objdump + def pre_process( self, mnemonic: str, args: str, next_row: Optional[str] ) -> Tuple[str, str]: @@ -1780,8 +1782,100 @@ def process_reloc(self, row: str, prev: str) -> Tuple[str, Optional[str]]: def is_end_of_function(self, mnemonic: str, args: str) -> bool: return mnemonic == "blr" +# Example: "cmp r0, #0x10" +ARM32_COMPARE_IMM_PATTERN = r"cmp\s+(r[0-9]|1[0-3]),\s+#(\w+)" + +# Example: "add pc, r1" +ARM32_JUMP_TABLE_START = r"add\s+pc,\s*r" + +# Examples: +# - "44: 00060032 .word 0x00060032" +# - "48: 0032 .short 0x0032" +# - "4a: 0032 movs r2, r6" +# - "9e: 00be lsls r6, r7, #2" +# - ".short 0x0032 ; 0x64" +ARM32_JUMP_TABLE_ENTRY_PATTERN = r"(?:(\w+):\s+([0-9a-f]+)\s+)?([\w\.]+)\s+([\w,\ ]+)" + +# Example: "ldr r4, [pc, #56] ; (4c )" +ARM32_LOAD_POOL_PATTERN = r"(ldr\s+r([0-9]|1[0-3]),\s+\[pc,.*;\s*)(\([a-fA-F0-9]+.*\))" class AsmProcessorARM32(AsmProcessor): + @dataclass + class JumpTableEntry: + cur_addr: int + table_start_addr: int + value: int + is_word: bool + + def preprocess_objdump(self, objdump: str) -> str: + def short_table_entry(cur_addr: int, jump_table_start_addr: int, value: int) -> str: + branch_target = jump_table_start_addr + value + 4 + return f" {cur_addr:x}: {value:04x} .short 0x{value:04x} ; 0x{branch_target:x}" + + new_lines = [] + lines = objdump.splitlines() + for i, jump_table_entry in self._lines_iterator(lines): + if jump_table_entry is None: + new_lines.append(lines[i]) + continue + + entry = jump_table_entry + if entry.is_word: + # Split into two ".short" entries. + hi, lo = entry.value >> 16, entry.value & 0xffff + new_lines.append( + short_table_entry(entry.cur_addr, entry.table_start_addr, lo) + ) + new_lines.append( + short_table_entry(entry.cur_addr + 2, entry.table_start_addr, hi) + ) + else: + new_lines.append( + short_table_entry(entry.cur_addr, entry.table_start_addr, entry.value) + ) + return "\n".join(new_lines) + + # An iterator for each line of assembly, returning the line index and optional + # metadata if the line is a jump table entry. + def _lines_iterator(self, lines: List[str]) -> Iterator[Tuple[int, Optional[JumpTableEntry]]]: + jump_table_entries = 0 + table_start_addr = 0 + for i, line in enumerate(lines): + addr_match = re.match(r"^\s*([0-9a-f]+):", line) + addr = int(addr_match.group(1), 16) if addr_match else -1 + entry_match = re.search(ARM32_JUMP_TABLE_ENTRY_PATTERN, line) + if jump_table_entries > 0 and entry_match: + value = entry_match.group(4) if is_hexstring(entry_match.group(4)) else entry_match.group(2) + table_entry = self.JumpTableEntry( + cur_addr=addr, + table_start_addr=table_start_addr, + value = int(value, 16), + is_word = entry_match.group(3) == ".word", + ) + jump_table_entries -= 2 if table_entry.is_word else 1 + + yield i, table_entry + continue + + # Check for jump tables. + if re.search(ARM32_JUMP_TABLE_START, line): + jump_table_entries = self._jump_table_entries_count(lines, i) + table_start_addr = addr + yield i, None + + # Returns the number of entries in the jump table starting at `line_no`, or + # 0 if it's not a jump table. + def _jump_table_entries_count(self, raw_lines: List[str], line_no: int) -> int: + # The number of entries should be in the most recent `cmp` before the + # jump table. + for i in reversed(range(line_no)): + cmp_match = re.search(ARM32_COMPARE_IMM_PATTERN, raw_lines[i]) + if cmp_match: + value = immediate_to_int(cmp_match.group(2)) + if value > 0: + return value + 1 + return 0 + def process_reloc(self, row: str, prev: str) -> Tuple[str, Optional[str]]: arch = self.config.arch if "R_ARM_V4BX" in row: @@ -1814,7 +1908,16 @@ def _normalize_data_pool(self, row: str) -> str: pool_match = re.search(ARM32_LOAD_POOL_PATTERN, row) return pool_match.group(1) if pool_match else row - def post_process(self, lines: List["Line"]) -> None: + def _post_process_jump_tables(self, lines: List["Line"]) -> None: + raw_lines = [f"{line.line_num:x}: {line.original}" for line in lines] + for i, jump_table_entry in self._lines_iterator(raw_lines): + if jump_table_entry is None: + continue + + entry = jump_table_entry + lines[i].branch_target = entry.table_start_addr + entry.value + 4 + + def _post_process_data_pools(self, lines: List["Line"]) -> None: lines_by_line_number = {} for line in lines: lines_by_line_number[line.line_num] = line @@ -1828,6 +1931,9 @@ def post_process(self, lines: List["Line"]) -> None: addr = "{:x}".format(line.data_pool_addr) line.original = line.normalized_original + f"={value} ({addr})" + def post_process(self, lines: List["Line"]) -> None: + self._post_process_jump_tables(lines) + self._post_process_data_pools(lines) class AsmProcessorAArch64(AsmProcessor): def __init__(self, config: Config) -> None: @@ -2491,6 +2597,17 @@ class ArchSettings: M68K_SETTINGS, ] +def immediate_to_int(immediate: str) -> int: + imm_match = re.match(r"#?(0x)?([0-9a-f]+)", immediate) + base = 16 if imm_match.group(1) else 10 + return int(imm_match.group(2), base) + +def is_hexstring(value: str) -> bool: + try: + int(value, 16) + return True + except ValueError: + return False def hexify_int(row: str, pat: Match[str], arch: ArchSettings) -> str: full = pat.group(0)