Skip to content

Commit

Permalink
Handle ARM32 jump tables properly (#163)
Browse files Browse the repository at this point in the history
  • Loading branch information
abaresk authored Jun 7, 2024
1 parent 54d4d82 commit eeb204d
Showing 1 changed file with 123 additions and 6 deletions.
129 changes: 123 additions & 6 deletions diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,7 +1211,8 @@ def preprocess_objdump_out(
+ out
)

return out
processor = config.arch.proc(config)
return processor.preprocess_objdump(out)


def search_build_objects(objname: str, project: ProjectSettings) -> Optional[str]:
Expand Down Expand Up @@ -1602,15 +1603,16 @@ def dump_binary(
)


# Matches a PC-relative `ldr` into r0-r13 together with the trailing
# "; (addr <symbol+offset>)" annotation that follows it on the same line.
# Group 1 is the instruction text up to and including the ";" separator,
# group 2 is the register number, and group 3 is the parenthesised reference.
# Example: "ldr r4, [pc, #56] ; (4c <AddCoins+0x4c>)"
ARM32_LOAD_POOL_PATTERN = r"(ldr\s+r([0-9]|1[0-3]),\s+\[pc,.*;\s*)(\([a-fA-F0-9]+.*\))"


# The base class is a no-op.
class AsmProcessor:
def __init__(self, config: Config) -> None:
    """Keep a reference to the active configuration on the processor."""
    self.config = config

def preprocess_objdump(self, objdump: str) -> str:
    """Arch-specific normalization hook for raw objdump text.

    Called during run_objdump(), before diff-processing (i.e. process()).
    This base implementation is a no-op and hands the text back unchanged.
    """
    return objdump

def pre_process(
self, mnemonic: str, args: str, next_row: Optional[str]
) -> Tuple[str, str]:
Expand Down Expand Up @@ -1780,8 +1782,100 @@ def process_reloc(self, row: str, prev: str) -> Tuple[str, Optional[str]]:
def is_end_of_function(self, mnemonic: str, args: str) -> bool:
    """Report whether an instruction ends a function.

    Only the exact mnemonic "blr" counts; `args` is not consulted.
    """
    ends_function = mnemonic == "blr"
    return ends_function

# Compare of a register (r0-r13) against an immediate. Group 1 is the register
# name, group 2 is the immediate without its leading "#". The register
# alternation is wrapped in a non-capturing group so that the "r" prefix
# applies to r10-r13 as well, while the group numbering stays unchanged.
# Example: "cmp r0, #0x10"
ARM32_COMPARE_IMM_PATTERN = r"cmp\s+(r(?:[0-9]|1[0-3])),\s+#(\w+)"

# Instruction that dispatches through a jump table by adding a register to pc.
# Example: "add pc, r1"
ARM32_JUMP_TABLE_START = r"add\s+pc,\s*r"

# One line inside a jump table. Groups 1 and 2 (address and raw hex bytes) are
# optional and are None when the line has no "addr: bytes" prefix; group 3 is
# the mnemonic or directive and group 4 is the operand text.
# Examples:
# - "44: 00060032 .word 0x00060032"
# - "48: 0032 .short 0x0032"
# - "4a: 0032 movs r2, r6"
# - "9e: 00be lsls r6, r7, #2"
# - ".short 0x0032 ; 0x64"
ARM32_JUMP_TABLE_ENTRY_PATTERN = r"(?:(\w+):\s+([0-9a-f]+)\s+)?([\w\.]+)\s+([\w,\ ]+)"

# PC-relative load from a literal pool, with the trailing address annotation.
# Group 1 is the instruction text up to and including the ";" separator.
# Example: "ldr r4, [pc, #56] ; (4c <AddCoins+0x4c>)"
ARM32_LOAD_POOL_PATTERN = r"(ldr\s+r([0-9]|1[0-3]),\s+\[pc,.*;\s*)(\([a-fA-F0-9]+.*\))"

class AsmProcessorARM32(AsmProcessor):
# Metadata describing a single jump table entry found while scanning
# disassembly lines.
@dataclass
class JumpTableEntry:
# Address parsed from the entry's own line, or -1 when the line carries no
# "addr:" prefix.
cur_addr: int
# Address parsed from the `add pc, rN` line that introduced the table.
table_start_addr: int
# Entry contents, parsed as a hexadecimal integer.
value: int
# True when the entry was emitted as a ".word" (occupies two table slots).
is_word: bool

def preprocess_objdump(self, objdump: str) -> str:
    """Rewrite jump table entries in objdump text as `.short` lines.

    Every line reported as a jump table entry by `_lines_iterator` is
    replaced by a synthetic ".short" line annotated with its branch target;
    a ".word" entry becomes two ".short" lines (low half first, then the
    high half two bytes later). All other lines are copied through as-is.
    The result is the rewritten lines joined with "\\n".
    """

    def render_short(address: int, table_base: int, halfword: int) -> str:
        # Target is table start + entry + 4; the +4 presumably accounts for
        # pc reading ahead of the `add pc, rN` instruction -- TODO confirm.
        target = table_base + halfword + 4
        return f" {address:x}: {halfword:04x} .short 0x{halfword:04x} ; 0x{target:x}"

    source_lines = objdump.splitlines()
    rewritten = []
    for index, entry in self._lines_iterator(source_lines):
        if entry is None:
            rewritten.append(source_lines[index])
        elif entry.is_word:
            # Split into two ".short" entries.
            low_half = entry.value & 0xFFFF
            high_half = entry.value >> 16
            rewritten.append(
                render_short(entry.cur_addr, entry.table_start_addr, low_half)
            )
            rewritten.append(
                render_short(entry.cur_addr + 2, entry.table_start_addr, high_half)
            )
        else:
            rewritten.append(
                render_short(entry.cur_addr, entry.table_start_addr, entry.value)
            )
    return "\n".join(rewritten)

def _lines_iterator(
    self, lines: List[str]
) -> Iterator[Tuple[int, Optional[JumpTableEntry]]]:
    """Iterate over assembly lines, flagging the ones that are jump table entries.

    Yields `(index, entry)` for every line in `lines`, in order. `entry` is
    a `JumpTableEntry` when the line is consumed as part of a jump table and
    None otherwise. A table begins after a line matching
    ARM32_JUMP_TABLE_START; its length comes from
    `_jump_table_entries_count`, and a ".word" entry uses up two slots.
    """
    remaining_entries = 0
    table_start_addr = 0
    for i, line in enumerate(lines):
        addr_match = re.match(r"^\s*([0-9a-f]+):", line)
        # -1 marks a line without an "addr:" prefix.
        addr = int(addr_match.group(1), 16) if addr_match else -1

        # The entry pattern is only relevant while a table is being consumed.
        if remaining_entries > 0:
            entry_match = re.search(ARM32_JUMP_TABLE_ENTRY_PATTERN, line)
            if entry_match:
                operand = entry_match.group(4)
                # Prefer the operand when it is the literal value (".short
                # 0x0032"); otherwise the entry was disassembled as an
                # instruction and the raw bytes column holds the value.
                value = operand if is_hexstring(operand) else entry_match.group(2)
                # The raw bytes group is optional, so it can be None. Such a
                # line carries no usable value: treat it as an ordinary line
                # instead of crashing in int(None, 16).
                if value is not None:
                    table_entry = self.JumpTableEntry(
                        cur_addr=addr,
                        table_start_addr=table_start_addr,
                        value=int(value, 16),
                        is_word=entry_match.group(3) == ".word",
                    )
                    remaining_entries -= 2 if table_entry.is_word else 1

                    yield i, table_entry
                    continue

        # Check for jump tables.
        if re.search(ARM32_JUMP_TABLE_START, line):
            remaining_entries = self._jump_table_entries_count(lines, i)
            table_start_addr = addr
        yield i, None

def _jump_table_entries_count(self, raw_lines: List[str], line_no: int) -> int:
    """Return the number of entries in the jump table starting at `line_no`.

    Scans backwards from the line just before `line_no` for instructions
    matching ARM32_COMPARE_IMM_PATTERN. The nearest one whose immediate is
    greater than zero decides the result, which is that immediate plus one;
    compares with a non-positive immediate are skipped. Returns 0 when no
    such compare exists, i.e. the line does not start a jump table.
    """
    index = line_no - 1
    while index >= 0:
        found = re.search(ARM32_COMPARE_IMM_PATTERN, raw_lines[index])
        if found is not None:
            bound = immediate_to_int(found.group(2))
            if bound > 0:
                return bound + 1
        index -= 1
    return 0

def process_reloc(self, row: str, prev: str) -> Tuple[str, Optional[str]]:
arch = self.config.arch
if "R_ARM_V4BX" in row:
Expand Down Expand Up @@ -1814,7 +1908,16 @@ def _normalize_data_pool(self, row: str) -> str:
pool_match = re.search(ARM32_LOAD_POOL_PATTERN, row)
return pool_match.group(1) if pool_match else row

def post_process(self, lines: List["Line"]) -> None:
def _post_process_jump_tables(self, lines: List["Line"]) -> None:
raw_lines = [f"{line.line_num:x}: {line.original}" for line in lines]
for i, jump_table_entry in self._lines_iterator(raw_lines):
if jump_table_entry is None:
continue

entry = jump_table_entry
lines[i].branch_target = entry.table_start_addr + entry.value + 4

def _post_process_data_pools(self, lines: List["Line"]) -> None:
lines_by_line_number = {}
for line in lines:
lines_by_line_number[line.line_num] = line
Expand All @@ -1828,6 +1931,9 @@ def post_process(self, lines: List["Line"]) -> None:
addr = "{:x}".format(line.data_pool_addr)
line.original = line.normalized_original + f"={value} ({addr})"

def post_process(self, lines: List["Line"]) -> None:
    """Run the ARM32 post-processing passes over `lines`, in order.

    Jump tables are handled first, then data pools; both passes receive the
    same list.
    """
    passes = (self._post_process_jump_tables, self._post_process_data_pools)
    for run_pass in passes:
        run_pass(lines)

class AsmProcessorAArch64(AsmProcessor):
def __init__(self, config: Config) -> None:
Expand Down Expand Up @@ -2491,6 +2597,17 @@ class ArchSettings:
M68K_SETTINGS,
]

def immediate_to_int(immediate: str) -> int:
    """Convert an assembly immediate such as "#0x10", "0x10" or "#16" to an int.

    A leading "#" is optional. A "0x" prefix selects base 16; without it the
    digits are read as base 10. Matching is case-insensitive so that
    uppercase hex digits or an uppercase "0X" prefix are not silently cut
    short.

    Raises:
        ValueError: if `immediate` does not start with an immediate, or if
            the digits are not valid in the selected base.
    """
    imm_match = re.match(r"#?(0x)?([0-9a-f]+)", immediate, re.IGNORECASE)
    if imm_match is None:
        # Fail with a descriptive error rather than dereferencing None.
        raise ValueError(f"not an immediate operand: {immediate!r}")
    base = 16 if imm_match.group(1) else 10
    return int(imm_match.group(2), base)

def is_hexstring(value: str) -> bool:
    """Return True if `value` is accepted by `int(value, 16)`, False otherwise."""
    try:
        int(value, 16)
    except ValueError:
        return False
    return True

def hexify_int(row: str, pat: Match[str], arch: ArchSettings) -> str:
full = pat.group(0)
Expand Down

0 comments on commit eeb204d

Please sign in to comment.