Skip to content

Commit

Permalink
Fix mypy and black issues
Browse files Browse the repository at this point in the history
  • Loading branch information
berendbutje committed Jan 4, 2024
1 parent 5199884 commit d9f2f29
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 32 deletions.
57 changes: 38 additions & 19 deletions src/splat/platforms/wasm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from wasm_tob import (
format_instruction,
format_function,
Section,
SEC_TYPE,
TypeSection,
Expand All @@ -16,8 +15,9 @@
DataSegment,
CodeSection,
FunctionBody,
INSN_LEAVE_BLOCK, INSN_ENTER_BLOCK,
decode_bytecode
INSN_LEAVE_BLOCK,
INSN_ENTER_BLOCK,
decode_bytecode,
)

from enum import IntEnum
Expand Down Expand Up @@ -107,6 +107,7 @@ def data_section_to_wat(section: DataSection) -> str:
data_segment_to_wat(index, entry) for index, entry in enumerate(section.entries)
)


def format_function(
fname: str,
type_index: int,
Expand All @@ -119,40 +120,58 @@ def format_function(

func_fmt = '(func "{name}" (type {type_index}){param_section}{result_section}'

param_section = ' (param {})'.format(' '.join(
map(format_lang_type, func_type.param_types)
)) if func_type.param_types else ''
result_section = ' (result {})'.format(
format_lang_type(func_type.return_type)
) if func_type.return_type else ''
param_section = (
" (param {})".format(" ".join(map(format_lang_type, func_type.param_types)))
if func_type.param_types
else ""
)
result_section = (
" (result {})".format(format_lang_type(func_type.return_type))
if func_type.return_type
else ""
)

yield func_fmt.format(name=fname, type_index=type_index, param_section=param_section, result_section=result_section)
yield func_fmt.format(
name=fname,
type_index=type_index,
param_section=param_section,
result_section=result_section,
)

if format_locals and func_body.locals:
yield ' ' * indent + '(local {})'.format(' '.join(itertools.chain.from_iterable(
itertools.repeat(format_lang_type(x.type), x.count)
for x in func_body.locals
)))
yield " " * indent + "(local {})".format(
" ".join(
itertools.chain.from_iterable(
itertools.repeat(format_lang_type(x.type), x.count)
for x in func_body.locals
)
)
)

level = 1
for cur_insn in decode_bytecode(func_body.code):
if cur_insn.op.flags & INSN_LEAVE_BLOCK:
level -= 1
yield ' ' * (level * indent) + format_instruction(cur_insn)
yield " " * (level * indent) + format_instruction(cur_insn)
if cur_insn.op.flags & INSN_ENTER_BLOCK:
level += 1


def code_section_to_wat(code: CodeSection, funcs: FunctionSection, types: TypeSection) -> str:

def code_section_to_wat(
code: CodeSection, funcs: FunctionSection, types: TypeSection
) -> str:
src = ""
for index, body in enumerate(code.bodies):
type_index = funcs.types[index]

try:
src += "\n".join(format_function(f'func_{index:04d}', type_index, body, types.entries[type_index]))
src += "\n".join(
format_function(
f"func_{index:04d}", type_index, body, types.entries[type_index]
)
)
except Exception as e:
src += f';; Exception: {str(e)}'
src += f";; Exception: {str(e)}"
pass

src += "\n"
Expand Down
15 changes: 7 additions & 8 deletions src/splat/segtypes/wasm/module.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from pathlib import Path
from typing import Optional
from typing import Optional, Dict

from ...util import options

from ..common.group import CommonSegGroup

from wasm_tob import (
ModuleHeader,
Section,
SEC_UNK,
SEC_TYPE,
SEC_IMPORT,
Expand All @@ -21,7 +22,8 @@
SEC_DATA,
)

class WasmSegModule(CommonSegGroup):

class WasmSegModule(CommonSegGroup):
def __init__(
self,
rom_start: Optional[int],
Expand All @@ -41,10 +43,7 @@ def __init__(
args=args,
yaml=yaml,
)
self.magic = None
self.version = None

self.sections = { }
self.sections: Dict[int, Section] = {}

@staticmethod
def is_text() -> bool:
Expand All @@ -69,7 +68,7 @@ def get_section_payload(self, id: int):
@property
def type_section(self):
return self.get_section_payload(SEC_TYPE)

@property
def import_section(self):
return self.get_section_payload(SEC_IMPORT)
Expand All @@ -88,4 +87,4 @@ def data_section(self):

@property
def code_section(self):
return self.get_section_payload(SEC_CODE)
return self.get_section_payload(SEC_CODE)
16 changes: 11 additions & 5 deletions src/splat/segtypes/wasm/section.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,17 @@
code_section_to_wat,
)

from .module import WasmSegModule

SECTION_TO_WAT = {
SEC_TYPE: lambda mod: type_section_to_wat(mod.type_section),
SEC_IMPORT: lambda mod: import_section_to_wat(mod.import_section),
SEC_FUNCTION: lambda mod: function_section_to_wat(mod.function_section),
SEC_EXPORT: lambda mod: export_section_to_wat(mod.export_section),
SEC_DATA: lambda mod: data_section_to_wat(mod.data_section),
SEC_CODE: lambda mod: code_section_to_wat(mod.code_section, mod.function_section, mod.type_section),
SEC_CODE: lambda mod: code_section_to_wat(
mod.code_section, mod.function_section, mod.type_section
),
}

SECTION_TO_STR = {
Expand All @@ -58,7 +61,10 @@
SEC_DATA: "data",
}


class WasmSegSection(CommonSegment):
parent: WasmSegModule

def __init__(
self,
rom_start: Optional[int],
Expand All @@ -78,8 +84,8 @@ def __init__(
args=args,
yaml=yaml,
)
self.section = None
self.section: Section = None

@staticmethod
def is_text() -> bool:
return True
Expand All @@ -100,7 +106,7 @@ def scan(self, rom_bytes: bytes):
)
else:
self.parent.sections[self.section.id] = self.section

pass

def split(self, rom_bytes: bytes):
Expand All @@ -110,6 +116,6 @@ def split(self, rom_bytes: bytes):
out_path = self.out_path()
if out_path:
out_path.parent.mkdir(parents=True, exist_ok=True)

with open(out_path, "w") as f:
f.write(SECTION_TO_WAT[self.section.id](self.parent))

0 comments on commit d9f2f29

Please sign in to comment.