diff --git a/README.md b/README.md index bf21d796..3a0a8900 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,10 @@ The next steps differ based on what kind of project you have. All scripts will become available to use in your terminal with the `reccmp-` prefix. Note that these scripts need to be executed in the directory where `reccmp-build.yml` is located. +* [`aggregate`](/reccmp/tools/aggregate.py): Combines JSON reports into a single file. + * Aggregate using highest accuracy score: `reccmp-aggregate --samples ./sample0.json ./sample1.json ./sample2.json --output ./combined.json` + * Diff two saved reports: `reccmp-aggregate --diff ./before.json ./after.json` + * Diff against the aggregate: `reccmp-aggregate --samples ./sample0.json ./sample1.json ./sample2.json --diff ./before.json` * [`decomplint`](/reccmp/tools/decomplint.py): Checks the decompilation annotations (see above) * e.g. `reccmp-decomplint --module LEGO1 LEGO1` * [`reccmp`](/reccmp/tools/asmcmp.py): Compares an original binary with a recompiled binary, provided a PDB file. For example: diff --git a/reccmp/assets/template.html b/reccmp/assets/template.html index 3ac15b45..6518c5a6 100644 --- a/reccmp/assets/template.html +++ b/reccmp/assets/template.html @@ -180,7 +180,10 @@ margin-bottom: 0; } - + diff --git a/reccmp/isledecomp/compare/asm/parse.py b/reccmp/isledecomp/compare/asm/parse.py index aa048040..48c0434a 100644 --- a/reccmp/isledecomp/compare/asm/parse.py +++ b/reccmp/isledecomp/compare/asm/parse.py @@ -7,9 +7,7 @@ placeholder string.""" import re -import struct from functools import cache -from typing import Callable from .const import JUMP_MNEMONICS, SINGLE_OPERAND_INSTS from .instgen import InstructGen, SectionType from .replacement import AddrTestProtocol, NameReplacementProtocol @@ -33,28 +31,22 @@ def from_hex(string: str) -> int | None: return None -def bytes_to_dword(b: bytes) -> int | None: - if len(b) == 4: - return struct.unpack(" None: self.addr_test = addr_test self.name_lookup = name_lookup - self.bin_lookup = bin_lookup + self.replacements: dict[int, str] = {} + self.indirect_replacements: dict[int, str] = {} self.number_placeholders = True def reset(self): self.replacements = {} + self.indirect_replacements = {} def is_addr(self, value: int) -> bool: """Wrapper for user-provided address test""" @@ -63,13 +55,22 @@ def is_addr(self, value: int) -> bool: return False - def lookup(self, addr: int, exact: bool = False) -> str | None: + def lookup( + self, addr: int, exact: bool = False, indirect: bool = False + ) -> str | None: """Wrapper for user-provided name lookup""" if callable(self.name_lookup): - return self.name_lookup(addr, exact=exact) + return self.name_lookup(addr, exact=exact, indirect=indirect) return None + def _next_placeholder(self) -> str: + """The placeholder number corresponds to the number of addresses we have + already replaced. This is so the number will be consistent across the diff + if we can replace some symbols with actual names in recomp but not orig.""" + number = len(self.replacements) + len(self.indirect_replacements) + 1 + return f"" if self.number_placeholders else "" + def replace(self, addr: int, exact: bool = False) -> str: """Provide a replacement name for the given address.""" if addr in self.replacements: @@ -79,14 +80,22 @@ def replace(self, addr: int, exact: bool = False) -> str: self.replacements[addr] = name return name - # The placeholder number corresponds to the number of addresses we have - # already replaced. 
This is so the number will be consistent across the diff - # if we can replace some symbols with actual names in recomp but not orig. - idx = len(self.replacements) + 1 - placeholder = f"" if self.number_placeholders else "" + placeholder = self._next_placeholder() self.replacements[addr] = placeholder return placeholder + def indirect_replace(self, addr: int) -> str: + if addr in self.indirect_replacements: + return self.indirect_replacements[addr] + + if (name := self.lookup(addr, exact=True, indirect=True)) is not None: + self.indirect_replacements[addr] = name + return name + + placeholder = self._next_placeholder() + self.indirect_replacements[addr] = placeholder + return placeholder + def hex_replace_always(self, match: re.Match) -> str: """If a pointer value was matched, always insert a placeholder""" value = int(match.group(1), 16) @@ -119,17 +128,7 @@ def hex_replace_indirect(self, match: re.Match) -> str: If we cannot identify the indirect address, fall back to a lookup on the original pointer value so we might display something useful.""" value = int(match.group(1), 16) - indirect_value = None - - if callable(self.bin_lookup): - indirect_value = self.bin_lookup(value, 4) - - if indirect_value is not None: - indirect_addr = bytes_to_dword(indirect_value) - if indirect_addr is not None and self.lookup(indirect_addr) is not None: - return "->" + self.replace(indirect_addr) - - return self.replace(value) + return self.indirect_replace(value) def sanitize(self, inst: DisasmLiteInst) -> tuple[str, str]: # For jumps or calls, if the entire op_str is a hex number, the value diff --git a/reccmp/isledecomp/compare/asm/replacement.py b/reccmp/isledecomp/compare/asm/replacement.py index 4c5833ce..89e8c0e6 100644 --- a/reccmp/isledecomp/compare/asm/replacement.py +++ b/reccmp/isledecomp/compare/asm/replacement.py @@ -10,28 +10,92 @@ def __call__(self, addr: int, /) -> bool: class NameReplacementProtocol(Protocol): - def __call__(self, addr: int, exact: bool = False) -> str | None: + def __call__( + self, addr: int, exact: bool = False, indirect: bool = False + ) -> str | None: ... def create_name_lookup( - db_getter: Callable[[int, bool], ReccmpEntity | None], addr_attribute: str + db_getter: Callable[[int, bool], ReccmpEntity | None], + bin_read: Callable[[int], int | None], + addr_attribute: str, ) -> NameReplacementProtocol: """Function generator for name replacement""" - @cache - def lookup(addr: int, exact: bool = False) -> str | None: - m = db_getter(addr, exact) - if m is None: + def follow_indirect(pointer: int) -> ReccmpEntity | None: + """Read the pointer address and open the entity (if it exists) at the indirect location.""" + addr = bin_read(pointer) + if addr is not None: + return db_getter(addr, True) + + return None + + def get_name(entity: ReccmpEntity, offset: int = 0) -> str | None: + """The offset is the difference between the input search address and the entity's + starting address. Decide whether to return the base name (match_name) or + a string wtih the base name plus the offset. + Returns None if there is no suitable name.""" + if offset == 0: + return entity.match_name() + + # We will not return an offset name if this is not a variable + # or if the offset is outside the range of the entity. 
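A minimal sketch of the offset-name rule described in the comment above (hypothetical standalone helper, not part of this patch; the real decision is made by `get_name` here):

```python
from reccmp.isledecomp.types import EntityType

def may_use_offset_name(entity_type: EntityType, entity_size: int, offset: int) -> bool:
    """Mirror of the rule above: the plain name at offset 0, an offset name
    (e.g. 'g_table+8') only for DATA entities and only inside the entity."""
    if offset == 0:
        return True
    return entity_type == EntityType.DATA and offset < entity_size

# A 16-byte variable: offset 8 may get an offset name, offset 16 falls outside.
assert may_use_offset_name(EntityType.DATA, 16, 8) is True
assert may_use_offset_name(EntityType.DATA, 16, 16) is False
assert may_use_offset_name(EntityType.FUNCTION, 16, 8) is False
```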
+ if entity.entity_type != EntityType.DATA or offset >= entity.size: return None - if getattr(m, addr_attribute) == addr: - return m.match_name() + return entity.offset_name(offset) + + def indirect_lookup(addr: int) -> str | None: + """Same as regular lookup but aware of the fact that the address is a pointer. + Indirect implies exact search, so we drop both parameters from the lookup entry point. + """ + entity = db_getter(addr, True) + if entity is not None: + # If the indirect call points at a variable initialized to a function, + # prefer the variable name as this is more useful. + if entity.entity_type == EntityType.DATA: + return entity.match_name() + + if entity.entity_type == EntityType.IMPORT: + import_name = entity.get("import_name") + if import_name is not None: + return "->" + import_name + " (FUNCTION)" + + return entity.match_name() + + # No suitable entity at the base address. Read the pointer and see what we get. + entity = follow_indirect(addr) + + if entity is None: + return None + + # Exact match only for indirect. + # The 'addr' variable still points at the indirect addr. + name = get_name(entity, offset=0) + if name is not None: + return "->" + name + + return None + + @cache + def lookup(addr: int, exact: bool = False, indirect: bool = False) -> str | None: + """Returns the name that represents the entity at the given addresss. + If there is no suitable name, return None and let the caller choose one (i.e. placeholder). + * exact: If the addr is an offset of an entity (e.g. struct/array) we may return + a name like 'variable+8'. If exact is True, return a name only if the entity's addr + matches the addr parameter. + * indirect: If True, the given addr is a pointer so we have the option to read the address + from the binary to find the name.""" + if indirect: + return indirect_lookup(addr) + + entity = db_getter(addr, exact) - offset = addr - getattr(m, addr_attribute) - if m.entity_type != EntityType.DATA or offset >= m.size: + if entity is None: return None - return m.offset_name(offset) + offset = addr - getattr(entity, addr_attribute) + return get_name(entity, offset) return lookup diff --git a/reccmp/isledecomp/compare/core.py b/reccmp/isledecomp/compare/core.py index 05a0e358..f5fab329 100644 --- a/reccmp/isledecomp/compare/core.py +++ b/reccmp/isledecomp/compare/core.py @@ -6,7 +6,10 @@ import uuid from dataclasses import dataclass from typing import Callable, Iterable, Iterator -from reccmp.isledecomp.formats.exceptions import InvalidVirtualAddressError +from reccmp.isledecomp.formats.exceptions import ( + InvalidVirtualAddressError, + InvalidVirtualReadError, +) from reccmp.isledecomp.formats.pe import PEImage from reccmp.isledecomp.cvdump.demangler import ( demangle_string_const, @@ -75,13 +78,14 @@ def lookup(addr: int) -> bool: return lookup -def create_bin_lookup(bin_file: PEImage) -> Callable[[int, int], bytes | None]: - """Function generator for reading from the bin file""" +def create_bin_lookup(bin_file: PEImage) -> Callable[[int], int | None]: + """Function generator to read a pointer from the bin file""" - def lookup(addr: int, size: int) -> bytes | None: + def lookup(addr: int) -> int | None: try: - return bin_file.read(addr, size) - except InvalidVirtualAddressError: + (ptr,) = struct.unpack(" None: + self.filename = filename + if timestamp is not None: + self.timestamp = timestamp + else: + self.timestamp = datetime.now().replace(microsecond=0) + + self.entities = {} + + +def _get_entity_for_addr( + samples: Iterable[ReccmpStatusReport], addr: 
str +) -> Iterator[ReccmpComparedEntity]: + """Helper to return entities from xreports that have the given address.""" + for sample in samples: + if addr in sample.entities: + yield sample.entities[addr] + + +def _accuracy_sort_key(entity: ReccmpComparedEntity) -> float: + """Helper to sort entity samples by accuracy score. + 100% match is preferred over effective match. + Effective match is preferred over any accuracy. + Stubs rank lower than any accuracy score.""" + if entity.is_stub: + return -1.0 + + if entity.accuracy == 1.0: + if not entity.is_effective_match: + return 1000.0 + + if entity.is_effective_match: + return 1.0 + + return entity.accuracy + + +def combine_reports(samples: list[ReccmpStatusReport]) -> ReccmpStatusReport: + """Combines the sample reports into a single report. + The current strategy is to use the entity with the highest + accuracy score from any report.""" + assert len(samples) > 0 + + if not all(samples[0].filename == s.filename for s in samples): + raise ReccmpReportSameSourceError + + output = ReccmpStatusReport(filename=samples[0].filename) + + # Combine every orig addr used in any of the reports. + orig_addr_set = {key for sample in samples for key in sample.entities.keys()} + + all_orig_addrs = sorted(list(orig_addr_set)) + + for addr in all_orig_addrs: + e_list = list(_get_entity_for_addr(samples, addr)) + assert len(e_list) > 0 + + # Our aggregate accuracy score is the highest from any report. + e_list.sort(key=_accuracy_sort_key, reverse=True) + + output.entities[addr] = e_list[0] + + # Keep the recomp_addr if it is the same across all samples. + # i.e. to detect where function alignment ends + if not all(e_list[0].recomp_addr == e.recomp_addr for e in e_list): + output.entities[addr].recomp_addr = "various" + + return output + + +#### JSON schemas and conversion functions #### + + +@dataclass +class JSONEntityVersion1: + address: str + name: str + matching: float + # Optional fields + recomp: str | None = None + stub: bool = False + effective: bool = False + diff: CombinedDiffOutput | None = None + + +class JSONReportVersion1(BaseModel): + file: str + format: Literal[1] + timestamp: float + data: list[JSONEntityVersion1] + + +def _serialize_version_1( + report: ReccmpStatusReport, diff_included: bool = False +) -> JSONReportVersion1: + """The HTML file needs the diff data, but it is omitted from the JSON report.""" + entities = [ + JSONEntityVersion1( + address=addr, # prefer dict key over redundant value in entity + name=e.name, + matching=e.accuracy, + recomp=e.recomp_addr, + stub=e.is_stub, + effective=e.is_effective_match, + diff=e.diff if diff_included else None, + ) + for addr, e in report.entities.items() + ] + + return JSONReportVersion1( + file=report.filename, + format=1, + timestamp=report.timestamp.timestamp(), + data=entities, + ) + + +def _deserialize_version_1(obj: JSONReportVersion1) -> ReccmpStatusReport: + report = ReccmpStatusReport( + filename=obj.file, timestamp=datetime.fromtimestamp(obj.timestamp) + ) + + for e in obj.data: + report.entities[e.address] = ReccmpComparedEntity( + orig_addr=e.address, + name=e.name, + accuracy=e.matching, + recomp_addr=e.recomp, + is_stub=e.stub, + is_effective_match=e.effective, + diff=e.diff, + ) + + return report + + +def deserialize_reccmp_report(json_str: str) -> ReccmpStatusReport: + try: + obj = JSONReportVersion1.model_validate(from_json(json_str)) + return _deserialize_version_1(obj) + except ValidationError as ex: + raise ReccmpReportDeserializeError from ex + + +def 
serialize_reccmp_report( + report: ReccmpStatusReport, diff_included: bool = False +) -> str: + """Create a JSON string for the report so it can be written to a file.""" + now = datetime.now().replace(microsecond=0) + report.timestamp = now + obj = _serialize_version_1(report, diff_included=diff_included) + + return obj.model_dump_json(exclude_defaults=True) diff --git a/reccmp/isledecomp/types.py b/reccmp/isledecomp/types.py index afbc7341..34e1e64a 100644 --- a/reccmp/isledecomp/types.py +++ b/reccmp/isledecomp/types.py @@ -12,3 +12,4 @@ class EntityType(IntEnum): STRING = 4 VTABLE = 5 FLOAT = 6 + IMPORT = 7 diff --git a/reccmp/isledecomp/utils.py b/reccmp/isledecomp/utils.py index fc7215d9..4a04269f 100644 --- a/reccmp/isledecomp/utils.py +++ b/reccmp/isledecomp/utils.py @@ -1,7 +1,31 @@ from datetime import datetime import logging -from pathlib import Path import colorama +from pystache import Renderer # type: ignore[import-untyped] +from reccmp.assets import get_asset_file +from reccmp.isledecomp.compare.report import ( + ReccmpStatusReport, + ReccmpComparedEntity, + serialize_reccmp_report, +) + + +def write_html_report(html_file: str, report: ReccmpStatusReport): + """Create the interactive HTML diff viewer with the given report.""" + js_path = get_asset_file("../assets/reccmp.js") + with open(js_path, "r", encoding="utf-8") as f: + reccmp_js = f.read() + + # Convert the report to a JSON string to insert in the HTML template. + report_str = serialize_reccmp_report(report, diff_included=True) + + output_data = Renderer().render_path( + get_asset_file("../assets/template.html"), + {"report": report_str, "reccmp_js": reccmp_js}, + ) + + with open(html_file, "w", encoding="utf-8") as htmlfile: + htmlfile.write(output_data) def print_combined_diff(udiff, plain: bool = False, show_both: bool = False): @@ -129,7 +153,9 @@ def diff_json_display(show_both_addrs: bool = False, is_plain: bool = False): """Generate a function that will display the diff according to the reccmp display preferences.""" - def formatter(orig_addr, saved, new) -> str: + def formatter( + orig_addr, saved: ReccmpComparedEntity, new: ReccmpComparedEntity + ) -> str: old_pct = "new" new_pct = "gone" name = "" @@ -138,29 +164,25 @@ def formatter(orig_addr, saved, new) -> str: if new is not None: new_pct = ( "stub" - if new.get("stub", False) - else percent_string( - new["matching"], new.get("effective", False), is_plain - ) + if new.is_stub + else percent_string(new.accuracy, new.is_effective_match, is_plain) ) # Prefer the current name of this function if we have it. # We are using the original address as the key. # A function being renamed is not of interest here. 
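For context, the formatter now reads typed attributes instead of dict keys; a hedged usage sketch with made-up names and addresses (the exact output format comes from `percent_string`):

```python
from reccmp.isledecomp.compare.report import ReccmpComparedEntity
from reccmp.isledecomp.utils import diff_json_display

saved = ReccmpComparedEntity(orig_addr="0x10001000", name="Example::Method", accuracy=0.92)
new = ReccmpComparedEntity(
    orig_addr="0x10001000",
    name="Example::Method",
    accuracy=1.0,
    recomp_addr="0x10002000",
)
get_diff_str = diff_json_display(show_both_addrs=False, is_plain=True)
# One line per changed entity: address, old/new match percentage, and name.
print(get_diff_str("0x10001000", saved, new))
```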
- name = new.get("name", "") - recomp_addr = new.get("recomp", "n/a") + name = new.name + recomp_addr = new.recomp_addr or "n/a" if saved is not None: old_pct = ( "stub" - if saved.get("stub", False) - else percent_string( - saved["matching"], saved.get("effective", False), is_plain - ) + if saved.is_stub + else percent_string(saved.accuracy, saved.is_effective_match, is_plain) ) if name == "": - name = saved.get("name", "") + name = saved.name if show_both_addrs: addr_string = f"{orig_addr} / {recomp_addr:10}" @@ -176,29 +198,25 @@ def formatter(orig_addr, saved, new) -> str: def diff_json( - saved_data, - new_data, - orig_file: Path, + saved_data: ReccmpStatusReport, + new_data: ReccmpStatusReport, show_both_addrs: bool = False, is_plain: bool = False, ): - """Using a saved copy of the diff summary and the current data, print a - report showing which functions/symbols have changed match percentage.""" + """Compare two status report files, determine what items changed, and print the result.""" # Don't try to diff a report generated for a different binary file - base_file = orig_file.name.lower() - - if saved_data.get("file") != base_file: + if saved_data.filename != new_data.filename: logging.getLogger().error( "Diff report for '%s' does not match current file '%s'", - saved_data.get("file"), - base_file, + saved_data.filename, + new_data.filename, ) return - if "timestamp" in saved_data: + if saved_data.timestamp is not None: now = datetime.now().replace(microsecond=0) - then = datetime.fromtimestamp(saved_data["timestamp"]).replace(microsecond=0) + then = saved_data.timestamp.replace(microsecond=0) print( " ".join( @@ -213,8 +231,8 @@ def diff_json( print() # Convert to dict, using orig_addr as key - saved_invert = {obj["address"]: obj for obj in saved_data["data"]} - new_invert = {obj["address"]: obj for obj in new_data} + saved_invert = saved_data.entities + new_invert = new_data.entities all_addrs = set(saved_invert.keys()).union(new_invert.keys()) @@ -227,60 +245,56 @@ def diff_json( for addr in sorted(all_addrs) } + DiffSubsectionType = dict[ + str, tuple[ReccmpComparedEntity | None, ReccmpComparedEntity | None] + ] + # The criteria for diff judgement is in these dict comprehensions: # Any function not in the saved file - new_functions = { + new_functions: DiffSubsectionType = { key: (saved, new) for key, (saved, new) in combined.items() if saved is None } # Any function now missing from the saved file # or a non-stub -> stub conversion - dropped_functions = { + dropped_functions: DiffSubsectionType = { key: (saved, new) for key, (saved, new) in combined.items() if new is None - or ( - new is not None - and saved is not None - and new.get("stub", False) - and not saved.get("stub", False) - ) + or (new is not None and saved is not None and new.is_stub and not saved.is_stub) } # TODO: move these two into functions if the assessment gets more complex # Any function with increased match percentage # or stub -> non-stub conversion - improved_functions = { + improved_functions: DiffSubsectionType = { key: (saved, new) for key, (saved, new) in combined.items() if saved is not None and new is not None - and ( - new["matching"] > saved["matching"] - or (not new.get("stub", False) and saved.get("stub", False)) - ) + and (new.accuracy > saved.accuracy or (not new.is_stub and saved.is_stub)) } # Any non-stub function with decreased match percentage - degraded_functions = { + degraded_functions: DiffSubsectionType = { key: (saved, new) for key, (saved, new) in combined.items() if saved is not 
None and new is not None - and new["matching"] < saved["matching"] - and not saved.get("stub") - and not new.get("stub") + and new.accuracy < saved.accuracy + and not saved.is_stub + and not new.is_stub } # Any function with former or current "effective" match - entropy_functions = { + entropy_functions: DiffSubsectionType = { key: (saved, new) for key, (saved, new) in combined.items() if saved is not None and new is not None - and new["matching"] == 1.0 - and saved["matching"] == 1.0 - and new.get("effective", False) != saved.get("effective", False) + and new.accuracy == 1.0 + and saved.accuracy == 1.0 + and new.is_effective_match != saved.is_effective_match } get_diff_str = diff_json_display(show_both_addrs, is_plain) diff --git a/reccmp/tools/aggregate.py b/reccmp/tools/aggregate.py new file mode 100644 index 00000000..e7661e9a --- /dev/null +++ b/reccmp/tools/aggregate.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 + +import argparse +import logging +from typing import Sequence +from pathlib import Path +from reccmp.isledecomp.utils import diff_json, write_html_report +from reccmp.isledecomp.compare.report import ( + ReccmpStatusReport, + combine_reports, + ReccmpReportDeserializeError, + ReccmpReportSameSourceError, + deserialize_reccmp_report, + serialize_reccmp_report, +) + + +logger = logging.getLogger(__name__) + + +def write_report_file(output_file: Path, report: ReccmpStatusReport): + """Convert the status report to JSON and write to a file.""" + json_str = serialize_reccmp_report(report) + + with open(output_file, "w+", encoding="utf-8") as f: + f.write(json_str) + + +def load_report_file(report_path: Path) -> ReccmpStatusReport: + """Deserialize from JSON at the given filename and return the report.""" + + with report_path.open("r", encoding="utf-8") as f: + return deserialize_reccmp_report(f.read()) + + +def deserialize_sample_files(paths: list[Path]) -> list[ReccmpStatusReport]: + """Deserialize all sample files and return the list of reports. 
+ Does not remove duplicates.""" + samples = [] + + for path in paths: + if path.is_file(): + try: + report = load_report_file(path) + samples.append(report) + except ReccmpReportDeserializeError: + logger.warning("Skipping '%s' due to import error", path) + elif not path.exists(): + logger.warning("File not found: '%s'", path) + + return samples + + +class TwoOrMoreArgsAction(argparse.Action): + """Support nargs=2+""" + + def __call__( + self, parser, namespace, values: Sequence[str] | None, option_string=None + ): + assert isinstance(values, Sequence) + if len(values) < 2: + raise argparse.ArgumentError(self, "expected two or more arguments") + + setattr(namespace, self.dest, values) + + +class TwoOrFewerArgsAction(argparse.Action): + """Support nargs=(1,2)""" + + def __call__( + self, parser, namespace, values: Sequence[str] | None, option_string=None + ): + assert isinstance(values, Sequence) + if len(values) not in (1, 2): + raise argparse.ArgumentError(self, "expected one or two arguments") + + setattr(namespace, self.dest, values) + + +def main(): + parser = argparse.ArgumentParser( + allow_abbrev=False, + description="Aggregate saved accuracy reports.", + ) + parser.add_argument( + "--diff", + type=Path, + metavar="", + nargs="+", + action=TwoOrFewerArgsAction, + help="Report files to diff.", + ) + parser.add_argument( + "--html", + type=Path, + metavar="", + help="Location for HTML report based on aggregate.", + ) + parser.add_argument( + "--output", + "-o", + type=Path, + metavar="", + help="Where to save the aggregate file.", + ) + parser.add_argument( + "--samples", + type=Path, + metavar="", + nargs="+", + action=TwoOrMoreArgsAction, + help="Report files to aggregate.", + ) + parser.add_argument( + "--no-color", "-n", action="store_true", help="Do not color the output" + ) + + args = parser.parse_args() + + if not (args.samples or args.diff): + parser.error( + "exepected arguments for --samples or --diff. (No input files specified)" + ) + + if not (args.output or args.diff or args.html): + parser.error( + "expected arguments for --output, --html, or --diff. (No output action specified)" + ) + + agg_report: ReccmpStatusReport | None = None + + if args.samples is not None: + samples = deserialize_sample_files(args.samples) + + if len(samples) < 2: + logger.error("Not enough samples to aggregate!") + return 1 + + try: + agg_report = combine_reports(samples) + except ReccmpReportSameSourceError: + filename_list = sorted({s.filename for s in samples}) + logger.error( + "Aggregate samples are not from the same source file!\nFilenames used: %s", + filename_list, + ) + return 1 + + if args.output is not None: + write_report_file(args.output, agg_report) + + if args.html is not None: + write_html_report(args.html, agg_report) + + # If --diff has at least one file and we aggregated some samples this run, diff the first file and the aggregate. + # If --diff has two files and we did not aggregate this run, diff the files in the list. 
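Roughly, the `--diff` handling described in that comment behaves like this sketch (hypothetical helper name, not part of the patch):

```python
def resolve_diff_targets(diff_paths, agg_report):
    """Mirror of the branch below: what gets compared by --diff."""
    baseline = load_report_file(diff_paths[0])
    if agg_report is not None:
        # Samples were aggregated this run; a second --diff file would be ignored.
        return baseline, agg_report
    if len(diff_paths) > 1:
        # No aggregation this run: diff the two saved report files directly.
        return baseline, load_report_file(diff_paths[1])
    raise ValueError("Not enough files to diff!")
```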
+ if args.diff is not None: + saved_data = load_report_file(args.diff[0]) + + if agg_report is None: + if len(args.diff) > 1: + agg_report = load_report_file(args.diff[1]) + else: + logger.error("Not enough files to diff!") + return 1 + elif len(args.diff) == 2: + logger.warning( + "Ignoring second --diff argument '%s'.\nDiff of '%s' and aggregate report follows.", + args.diff[1], + args.diff[0], + ) + + diff_json(saved_data, agg_report, show_both_addrs=False, is_plain=args.no_color) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/reccmp/tools/asmcmp.py b/reccmp/tools/asmcmp.py index 0e5eb16c..e83d95ce 100755 --- a/reccmp/tools/asmcmp.py +++ b/reccmp/tools/asmcmp.py @@ -2,11 +2,8 @@ import argparse import base64 -import json import logging import os -from datetime import datetime -from pathlib import Path from pystache import Renderer # type: ignore[import-untyped] import colorama @@ -15,9 +12,16 @@ print_combined_diff, diff_json, percent_string, + write_html_report, ) from reccmp.isledecomp.compare import Compare as IsleCompare +from reccmp.isledecomp.compare.report import ( + ReccmpStatusReport, + ReccmpComparedEntity, + deserialize_reccmp_report, + serialize_reccmp_report, +) from reccmp.isledecomp.formats.detect import detect_image from reccmp.isledecomp.formats.pe import PEImage from reccmp.isledecomp.types import EntityType @@ -34,43 +38,11 @@ colorama.just_fix_windows_console() -def gen_json(json_file: str, orig_file: Path, data): - """Create a JSON file that contains the comparison summary""" - - # If the structure of the JSON file ever changes, we would run into a problem - # reading an older format file in the CI action. Mark which version we are - # generating so we could potentially address this down the road. 
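The version marker mentioned in the removed comment now lives in the pydantic schema (`format: Literal[1]` on `JSONReportVersion1`), so an unsupported format fails validation instead of being read silently. A small sketch with made-up input:

```python
from reccmp.isledecomp.compare.report import (
    ReccmpReportDeserializeError,
    deserialize_reccmp_report,
)

bad = '{"file": "test.exe", "format": 2, "timestamp": 0, "data": []}'
try:
    deserialize_reccmp_report(bad)
except ReccmpReportDeserializeError:
    print("unsupported or malformed report format")
```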
- json_format_version = 1 - - # Remove the diff field - reduced_data = [ - {key: value for (key, value) in obj.items() if key != "diff"} for obj in data - ] +def gen_json(json_file: str, json_str: str): + """Convert the status report to JSON and write to a file.""" with open(json_file, "w", encoding="utf-8") as f: - json.dump( - { - "file": orig_file.name.lower(), - "format": json_format_version, - "timestamp": datetime.now().timestamp(), - "data": reduced_data, - }, - f, - ) - - -def gen_html(html_file, data): - js_path = get_asset_file("../assets/reccmp.js") - with open(js_path, "r", encoding="utf-8") as f: - reccmp_js = f.read() - - output_data = Renderer().render_path( - get_asset_file("../assets/template.html"), - {"data": data, "reccmp_js": reccmp_js}, - ) - - with open(html_file, "w", encoding="utf-8") as htmlfile: - htmlfile.write(output_data) + f.write(json_str) def gen_svg(svg_file, name_svg, icon, svg_implemented_funcs, total_funcs, raw_accuracy): @@ -177,6 +149,11 @@ def virtual_address(value) -> int: metavar="", help="Generate JSON file with match summary", ) + parser.add_argument( + "--json-diet", + action="store_true", + help="Exclude diff from JSON report.", + ) parser.add_argument( "--diff", metavar="", @@ -267,7 +244,8 @@ def main(): function_count = 0 total_accuracy = 0.0 total_effective_accuracy = 0.0 - htmlinsert = [] + + report = ReccmpStatusReport(filename=target.original_path.name.lower()) for match in isle_compare.compare_all(): if not args.silent and args.diff is None: @@ -287,44 +265,42 @@ def main(): total_effective_accuracy += match.effective_ratio # If html, record the diffs to an HTML file - html_obj = { - "address": f"0x{match.orig_addr:x}", - "recomp": f"0x{match.recomp_addr:x}", - "name": match.name, - "matching": match.effective_ratio, - } - - if match.is_effective_match: - html_obj["effective"] = True - - if match.udiff is not None: - html_obj["diff"] = match.udiff - - if match.is_stub: - html_obj["stub"] = True - - htmlinsert.append(html_obj) + orig_addr = f"0x{match.orig_addr:x}" + recomp_addr = f"0x{match.recomp_addr:x}" + + report.entities[orig_addr] = ReccmpComparedEntity( + orig_addr=orig_addr, + name=match.name, + accuracy=match.effective_ratio, + recomp_addr=recomp_addr, + is_effective_match=match.is_effective_match, + is_stub=match.is_stub, + diff=match.udiff, + ) # Compare with saved diff report. if args.diff is not None: with open(args.diff, "r", encoding="utf-8") as f: - saved_data = json.load(f) - - diff_json( - saved_data, - htmlinsert, - target.original_path, - show_both_addrs=args.print_rec_addr, - is_plain=args.no_color, - ) + saved_data = deserialize_reccmp_report(f.read()) + + diff_json( + saved_data, + report, + show_both_addrs=args.print_rec_addr, + is_plain=args.no_color, + ) ## Generate files and show summary. if args.json is not None: - gen_json(args.json, target.original_path, htmlinsert) + # If we're on a diet, hold the diff. 
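In other words, `--json-diet` only controls whether the per-function diff payload is serialized; a sketch, assuming `report` is the `ReccmpStatusReport` built above:

```python
from reccmp.isledecomp.compare.report import serialize_reccmp_report

full = serialize_reccmp_report(report, diff_included=True)   # plain --json (and the HTML viewer)
slim = serialize_reccmp_report(report, diff_included=False)  # --json with --json-diet
# The slim form omits the "diff" field entirely (None defaults are excluded on dump),
# which keeps the report file small.
assert len(slim) <= len(full)
```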
+ diff_included = not bool(args.json_diet) + gen_json( + args.json, serialize_reccmp_report(report, diff_included=diff_included) + ) if args.html is not None: - gen_html(args.html, json.dumps(htmlinsert)) + write_html_report(args.html, report) implemented_funcs = function_count diff --git a/setup.cfg b/setup.cfg index 6458a6f5..629c9862 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,6 +21,7 @@ exclude = [options.entry_points] console_scripts = + reccmp-aggregate = reccmp.tools.aggregate:main reccmp-datacmp = reccmp.tools.datacmp:main reccmp-decomplint = reccmp.tools.decomplint:main reccmp-project = reccmp.tools.project:main diff --git a/tests/test_name_replacement.py b/tests/test_name_replacement.py new file mode 100644 index 00000000..e62c3f92 --- /dev/null +++ b/tests/test_name_replacement.py @@ -0,0 +1,217 @@ +import pytest +from reccmp.isledecomp.types import EntityType +from reccmp.isledecomp.compare.db import EntityDb +from reccmp.isledecomp.compare.asm.replacement import ( + create_name_lookup, + NameReplacementProtocol, +) + + +@pytest.fixture(name="db") +def fixture_db() -> EntityDb: + return EntityDb() + + +def create_lookup( + db, addrs: dict[int, int] | None = None, is_orig: bool = True +) -> NameReplacementProtocol: + if addrs is None: + addrs = {} + + def bin_lookup(addr: int) -> int | None: + return addrs.get(addr) + + if is_orig: + return create_name_lookup(db.get_by_orig, bin_lookup, "orig_addr") + + return create_name_lookup(db.get_by_recomp, bin_lookup, "recomp_addr") + + +#### + + +def test_name_replacement(db): + """Should return a name for an entity that has one. + Return None if there are no name attributes set or the entity does not exist.""" + with db.batch() as batch: + batch.set_orig(100, name="Test") + batch.set_orig(200, computed_name="Hello") + batch.set_orig(300) # No name + + lookup = create_lookup(db) + + # Using "in" here because the returned string may contain other information. + # e.g. the entity type + assert "Test" in lookup(100) + assert "Hello" in lookup(200) + assert lookup(300) is None + + +def test_name_hierarchy(db): + """Use the "best" entity name. Currently there are only two. + 'computed_name' is preferred over just 'name'.""" + with db.batch() as batch: + batch.set_orig(100, name="Test", computed_name="Hello") + + lookup = create_lookup(db) + + # Should prefer 'computed_name' over 'name' + assert "Hello" in lookup(100) + assert "Test" not in lookup(100) + + +def test_string_escape_newlines(db): + """Make sure newlines are removed from the string. + This overlap with tests on the ReccmpEntity name functions, but it is more vital + to ensure there are no newlines at this stage because they will disrupt the asm diff. + """ + with db.batch() as batch: + batch.set_orig(100, name="Test\nTest", type=EntityType.STRING) + + lookup = create_lookup(db) + + assert "\n" not in lookup(100) + + +def test_offset_name(db): + """For some entities (i.e. variables) we will return a name if the search address + is inside the address range of the entity. This is determined by the size attribute. + """ + with db.batch() as batch: + batch.set_orig(100, name="Hello", type=EntityType.DATA, size=10) + + lookup = create_lookup(db) + + assert lookup(100) is not None + assert lookup(101) is not None + + # Outside the range = no name + assert lookup(110) is None + + +def test_offset_name_non_variables(db): + """Do not return an offset name for non-variable entities. (e.g. 
functions).""" + with db.batch() as batch: + batch.set_orig(100, name="Hello", type=EntityType.FUNCTION, size=10) + batch.set_orig(200, name="Hello", size=10) # No type + + lookup = create_lookup(db) + + assert lookup(100) is not None + assert lookup(101) is None + + assert lookup(200) is not None + assert lookup(201) is None + + +def test_offset_name_no_size(db): + """An entity with no size attribute is considered to have size=0. + Meaning: match only against the address value.""" + with db.batch() as batch: + batch.set_orig(100, name="Hello", type=EntityType.DATA) + + lookup = create_lookup(db) + + assert lookup(100) is not None + assert lookup(101) is None + + +def test_exact_restriction(db): + """If exact=True, return a name only if the entity's address matches the search address. + Otherwise we might return a name if the entity contains the search address.""" + with db.batch() as batch: + batch.set_orig(100, name="Hello", type=EntityType.DATA, size=10) + + lookup = create_lookup(db) + + assert lookup(100, exact=True) is not None + assert lookup(101, exact=True) is None + + # Proof that the exact parameter controls whether we get a name. + assert lookup(101, exact=False) is not None + + +def test_indirect_function(db): + """An instruction like `call dword ptr [0x1234]` means that we call the function + whose address is at address 0x1234. This is an indirect lookup.""" + with db.batch() as batch: + batch.set_orig(100, name="Hello", type=EntityType.FUNCTION) + + # Mock lookup so we will read 100 from address 200. + lookup = create_lookup(db, {200: 100}) + + # No entity at 200 + assert lookup(200) is None + assert lookup(200, indirect=True) is not None + + # Imitating Ghidra asm display. Not every indirect lookup gets the arrow. + assert "->" in lookup(200, indirect=True) + + +def test_indirect_function_variable(db): + """If the indirect call instruction has the address of a variable in our database, + prefer the variable name rather than reading the pointer.""" + with db.batch() as batch: + batch.set_orig(100, name="Hello", type=EntityType.FUNCTION) + batch.set_orig(200, name="Test", type=EntityType.DATA) + + # Mock lookup so we will read 100 from address 200. + lookup = create_lookup(db, {200: 100}) + + name = lookup(200, indirect=True) + assert name is not None + assert "Hello" not in name + assert "Test" in name + assert "->" not in name + + +def test_indirect_import(db): + """If we are indirectly calling an imported function, we should see the import_name + attribute used in the result. This will probably contain the DLL and function name. + """ + with db.batch() as batch: + batch.set_orig(100, import_name="Hello", name="Test", type=EntityType.IMPORT) + + # No mock needed here because we will not need to read any data. + lookup = create_lookup(db) + + # Should use import name with arrow to suggest indirect call. + name = lookup(100, indirect=True) + assert name is not None + assert "Hello" in name + assert "->" in name + + # Show the entity name instead. (e.g. __imp__ symbol) + name = lookup(100, indirect=False) + assert name is not None + assert "Test" in name + assert "->" not in name + + +def test_indirect_import_missing_data(db): + """Edge cases for indirect lookup on an IMPORT entity.""" + with db.batch() as batch: + batch.set_orig(100, name="Test", type=EntityType.IMPORT) + + lookup = create_lookup(db) + + # No import name. Use the regular entity name instead (i.e. 
match indirect=False lookup) + name = lookup(100, indirect=True) + assert name is not None + assert "Test" in name + assert "->" not in name + + +def test_indirect_failed_lookup(db): + """In the general case (i.e. we do not use the base entity to get the name) + if there is no entity at the pointer location, return None.""" + with db.batch() as batch: + batch.set_orig(200, name="Hello", type=EntityType.FUNCTION) + + # Mock lookup so we will read 100 from address 200. + lookup = create_lookup(db, {200: 100}) + + # There is an entity at 200 but we can't use it to generate a name. + # There is no entity at 100 (indirect location) + assert lookup(200, indirect=False) is not None + assert lookup(200, indirect=True) is None diff --git a/tests/test_report.py b/tests/test_report.py new file mode 100644 index 00000000..71a71f47 --- /dev/null +++ b/tests/test_report.py @@ -0,0 +1,136 @@ +"""Reccmp reports: files that contain the comparison result from asmcmp.""" +import pytest +from reccmp.isledecomp.compare.report import ( + ReccmpStatusReport, + ReccmpComparedEntity, + combine_reports, + ReccmpReportSameSourceError, +) + + +def create_report( + entities: list[tuple[str, float]] | None = None +) -> ReccmpStatusReport: + """Helper to quickly set up a report to be customized further for each test.""" + report = ReccmpStatusReport(filename="test.exe") + if entities is not None: + for addr, accuracy in entities: + report.entities[addr] = ReccmpComparedEntity(addr, "test", accuracy) + + return report + + +def test_aggregate_identity(): + """Combine a list of one report. Should get the same report back, + except for expected differences like the timestamp.""" + report = create_report([("100", 1.0), ("200", 0.5)]) + combined = combine_reports([report]) + + for (a_key, a_entity), (b_key, b_entity) in zip( + report.entities.items(), combined.entities.items() + ): + assert a_key == b_key + assert a_entity.orig_addr == b_entity.orig_addr + assert a_entity.accuracy == b_entity.accuracy + + +def test_aggregate_simple(): + """Should choose the best score from the sample reports.""" + x = create_report([("100", 0.8), ("200", 0.2)]) + y = create_report([("100", 0.2), ("200", 0.8)]) + + combined = combine_reports([x, y]) + assert combined.entities["100"].accuracy == 0.8 + assert combined.entities["200"].accuracy == 0.8 + + +def test_aggregate_union_all_addrs(): + """Should combine all addresses from any report.""" + x = create_report([("100", 0.8)]) + y = create_report([("200", 0.8)]) + + combined = combine_reports([x, y]) + assert "100" in combined.entities + assert "200" in combined.entities + + +def test_aggregate_stubs(): + """Stub functions (i.e. do not compare asm) are considered to have 0 percent accuracy.""" + x = create_report([("100", 0.9)]) + y = create_report([("100", 0.5)]) + + # In a real report, accuracy would be zero for a stub. 
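That choice follows from the ranking in `_accuracy_sort_key` above; a quick sketch using that (private) helper for illustration:

```python
from reccmp.isledecomp.compare.report import ReccmpComparedEntity, _accuracy_sort_key

stub = ReccmpComparedEntity("100", "test", 0.0)
stub.is_stub = True
partial = ReccmpComparedEntity("100", "test", 0.9)
effective = ReccmpComparedEntity("100", "test", 0.5)
effective.is_effective_match = True
perfect = ReccmpComparedEntity("100", "test", 1.0)

# stub < partial accuracy < effective match < true 100% match
assert _accuracy_sort_key(stub) < _accuracy_sort_key(partial)
assert _accuracy_sort_key(partial) < _accuracy_sort_key(effective)
assert _accuracy_sort_key(effective) < _accuracy_sort_key(perfect)
```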
+ x.entities["100"].is_stub = True + y.entities["100"].is_stub = False + + combined = combine_reports([x, y]) + assert combined.entities["100"].is_stub is False + + # Choose the lower non-stub value + assert combined.entities["100"].accuracy == 0.5 + + +def test_aggregate_all_stubs(): + """If all samples are stubs, preserve that setting.""" + x = create_report([("100", 1.0)]) + + x.entities["100"].is_stub = True + + combined = combine_reports([x, x]) + assert combined.entities["100"].is_stub is True + + +def test_aggregate_100_over_effective(): + """Prefer 100% match over effective.""" + x = create_report([("100", 0.9)]) + y = create_report([("100", 1.0)]) + x.entities["100"].is_effective_match = True + + combined = combine_reports([x, y]) + assert combined.entities["100"].is_effective_match is False + + +def test_aggregate_effective_over_any(): + """Prefer effective match over any accuracy.""" + x = create_report([("100", 0.5)]) + y = create_report([("100", 0.6)]) + x.entities["100"].is_effective_match = True + # Y has higher accuracy score, but we could not confirm an effective match. + + combined = combine_reports([x, y]) + assert combined.entities["100"].is_effective_match is True + + # Should retain original accuracy for effective match. + assert combined.entities["100"].accuracy == 0.5 + + +def test_aggregate_different_files(): + """Should raise an exception if we try to aggregate reports + where the orig filename does not match.""" + x = create_report() + y = create_report() + + # Make sure they are different, regardless of what is set by create_report(). + x.filename = "test.exe" + y.filename = "hello.exe" + + with pytest.raises(ReccmpReportSameSourceError): + combine_reports([x, y]) + + +def test_aggregate_recomp_addr(): + """We combine the entity data based on the orig addr because this will not change. + The recomp addr may vary a lot. If it is the same in all samples, use the value. 
+ Otherwise use a placeholder value.""" + x = create_report([("100", 0.8), ("200", 0.2)]) + y = create_report([("100", 0.2), ("200", 0.8)]) + # These recomp addrs match: + x.entities["100"].recomp_addr = "500" + y.entities["100"].recomp_addr = "500" + # Y report has no addr for this + x.entities["200"].recomp_addr = "600" + + combined = combine_reports([x, y]) + assert combined.entities["100"].recomp_addr == "500" + assert combined.entities["200"].recomp_addr != "600" + assert combined.entities["200"].recomp_addr == "various" diff --git a/tests/test_sanitize32.py b/tests/test_sanitize32.py index 4dfd4201..033437ba 100644 --- a/tests/test_sanitize32.py +++ b/tests/test_sanitize32.py @@ -75,7 +75,7 @@ def test_pointer_instructions_with_name(inst: DisasmLiteInst): (_, op_str) = p.sanitize(inst) # Using sample instructions where exact match is not required - name_lookup.assert_called_with(0x1234, exact=False) + name_lookup.assert_called_with(0x1234, exact=False, indirect=False) assert "[0x1234]" not in op_str assert "[Hello]" in op_str @@ -258,7 +258,7 @@ def test_jmp_with_name_lookup(): (_, op_str) = p.sanitize(DisasmLiteInst(0x1000, 5, "jmp", "0x2000")) - name_lookup.assert_called_with(0x2000, exact=True) + name_lookup.assert_called_with(0x2000, exact=True, indirect=False) assert op_str == "Hello" @@ -302,7 +302,7 @@ def test_cmp_with_name_lookup(): (_, op_str) = p.sanitize(inst) - name_lookup.assert_called_with(0x2000, exact=False) + name_lookup.assert_called_with(0x2000, exact=False, indirect=False) assert op_str == "eax, Hello" @@ -327,7 +327,7 @@ def test_call_with_name_lookup(): (_, op_str) = p.sanitize(inst) - name_lookup.assert_called_with(0x1234, exact=True) + name_lookup.assert_called_with(0x1234, exact=True, indirect=False) assert op_str == "Hello" @@ -348,42 +348,72 @@ def substitute_1234(addr: int, **_) -> str | None: def test_absolute_indirect(): - """**** Held over from previous test file. This behavior may change soon. **** - The instruction `call dword ptr [0x1234]` means we call the function - whose address is at 0x1234. (i.e. absolute indirect addressing mode) - It is probably more useful to show the name of the function itself if - we have it, but there are some circumstances where we want to replace - with the pointer's name (i.e. an import function).""" + """Read the given pointer and replace its value with a name or placeholder. + Previously we handled reading from the binary inside the sanitize function. 
+ This is now delegated to the name lookup function, so we just need to check + that it was called with the indirect parameter set.""" + name_lookup = Mock(spec=NameReplacementProtocol, return_value=None) + p = ParseAsm(name_lookup=name_lookup) + inst = DisasmLiteInst(0x1000, 5, "call", "dword ptr [0x1234]") - def name_lookup(addr: int, **_) -> str | None: - return { - 0x1234: "Hello", - 0x4321: "xyz", - 0x5555: "Test", - }.get(addr) - - def bin_lookup(addr: int, _: int) -> bytes | None: - return ( - { - 0x1234: b"\x55\x55\x00\x00", - 0x4321: b"\x99\x99\x00\x00", - } - ).get(addr) - - p = ParseAsm(name_lookup=name_lookup, bin_lookup=bin_lookup) - - # If we know the indirect address (0x5555) - # Arrow to indicate this is an indirect replacement - (_, op_str) = p.sanitize(DisasmLiteInst(0x1000, 5, "call", "dword ptr [0x1234]")) - assert op_str == "dword ptr [->Test]" - - # If we do not know the indirect address (0x9999) - (_, op_str) = p.sanitize(DisasmLiteInst(0x1000, 5, "call", "dword ptr [0x4321]")) - assert op_str == "dword ptr [xyz]" - - # If we can't read the indirect address - (_, op_str) = p.sanitize(DisasmLiteInst(0x1000, 5, "call", "dword ptr [0x5555]")) - assert op_str == "dword ptr [Test]" + (_, op_str) = p.sanitize(inst) + + name_lookup.assert_called_with(0x1234, exact=True, indirect=True) + assert op_str == "dword ptr []" + + +def test_direct_and_indirect_different_names(): + """Indirect pointers should not collide with + cached lookups on direct pointers and vice versa""" + + # Create a lookup that checks indirect access + def lookup(_, indirect: bool = False, **__) -> str: + return "Indirect" if indirect else "Direct" + + indirect_inst = DisasmLiteInst(0x1000, 5, "call", "dword ptr [0x1234]") + direct_inst = DisasmLiteInst(0x1000, 5, "mov", "eax, dword ptr [0x1234]") + + # Indirect first + p = ParseAsm(name_lookup=lookup) + (_, op_str) = p.sanitize(indirect_inst) + assert op_str == "dword ptr [Indirect]" + + (_, op_str) = p.sanitize(direct_inst) + assert op_str == "eax, dword ptr [Direct]" + + # Direct first + p = ParseAsm(name_lookup=lookup) + (_, op_str) = p.sanitize(direct_inst) + assert op_str == "eax, dword ptr [Direct]" + + (_, op_str) = p.sanitize(indirect_inst) + assert op_str == "dword ptr [Indirect]" + + # Now verify that we use cached values for each + name_lookup = Mock(spec=NameReplacementProtocol, return_value=None) + p.name_lookup = name_lookup + (_, op_str) = p.sanitize(indirect_inst) + assert op_str == "dword ptr [Indirect]" + + (_, op_str) = p.sanitize(direct_inst) + assert op_str == "eax, dword ptr [Direct]" + + name_lookup.assert_not_called() + + +def test_direct_and_indirect_placeholders(): + """If no addresses are known, placeholders for direct and indirect lookup must be distinct""" + indirect_inst = DisasmLiteInst(0x1000, 5, "call", "dword ptr [0x1234]") + direct_inst = DisasmLiteInst(0x1000, 5, "mov", "eax, dword ptr [0x1234]") + + name_lookup = Mock(spec=NameReplacementProtocol, return_value=None) + p = ParseAsm(name_lookup=name_lookup) + + (_, indirect_op_str) = p.sanitize(indirect_inst) + (_, direct_op_str) = p.sanitize(direct_inst) + + # Must use two different placeholders + assert indirect_op_str != direct_op_str def test_consistent_numbering():