diff --git a/README.md b/README.md
index bf21d796..3a0a8900 100644
--- a/README.md
+++ b/README.md
@@ -46,6 +46,10 @@ The next steps differ based on what kind of project you have.
All scripts will become available to use in your terminal with the `reccmp-` prefix. Note that these scripts need to be executed in the directory where `reccmp-build.yml` is located.
+* [`aggregate`](/reccmp/tools/aggregate.py): Combines JSON reports into a single file.
+ * Aggregate using highest accuracy score: `reccmp-aggregate --samples ./sample0.json ./sample1.json ./sample2.json --output ./combined.json`
+ * Diff two saved reports: `reccmp-aggregate --diff ./before.json ./after.json`
+ * Diff against the aggregate: `reccmp-aggregate --samples ./sample0.json ./sample1.json ./sample2.json --diff ./before.json`
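+ * Write an HTML report based on the aggregate: `reccmp-aggregate --samples ./sample0.json ./sample1.json ./sample2.json --html ./combined.html`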
* [`decomplint`](/reccmp/tools/decomplint.py): Checks the decompilation annotations (see above)
* e.g. `reccmp-decomplint --module LEGO1 LEGO1`
* [`reccmp`](/reccmp/tools/asmcmp.py): Compares an original binary with a recompiled binary, provided a PDB file. For example:
diff --git a/reccmp/assets/template.html b/reccmp/assets/template.html
index 3ac15b45..6518c5a6 100644
--- a/reccmp/assets/template.html
+++ b/reccmp/assets/template.html
@@ -180,7 +180,10 @@
margin-bottom: 0;
}
-
+
diff --git a/reccmp/isledecomp/compare/asm/parse.py b/reccmp/isledecomp/compare/asm/parse.py
index aa048040..48c0434a 100644
--- a/reccmp/isledecomp/compare/asm/parse.py
+++ b/reccmp/isledecomp/compare/asm/parse.py
@@ -7,9 +7,7 @@
placeholder string."""
import re
-import struct
from functools import cache
-from typing import Callable
from .const import JUMP_MNEMONICS, SINGLE_OPERAND_INSTS
from .instgen import InstructGen, SectionType
from .replacement import AddrTestProtocol, NameReplacementProtocol
@@ -33,28 +31,22 @@ def from_hex(string: str) -> int | None:
return None
-def bytes_to_dword(b: bytes) -> int | None:
- if len(b) == 4:
- return struct.unpack("<L", b)[0]
-
- return None
-
-
class ParseAsm:
def __init__(
self,
addr_test: AddrTestProtocol | None = None,
name_lookup: NameReplacementProtocol | None = None,
- bin_lookup: Callable[[int, int], bytes | None] | None = None,
) -> None:
self.addr_test = addr_test
self.name_lookup = name_lookup
- self.bin_lookup = bin_lookup
+
self.replacements: dict[int, str] = {}
+ self.indirect_replacements: dict[int, str] = {}
self.number_placeholders = True
def reset(self):
self.replacements = {}
+ self.indirect_replacements = {}
def is_addr(self, value: int) -> bool:
"""Wrapper for user-provided address test"""
@@ -63,13 +55,22 @@ def is_addr(self, value: int) -> bool:
return False
- def lookup(self, addr: int, exact: bool = False) -> str | None:
+ def lookup(
+ self, addr: int, exact: bool = False, indirect: bool = False
+ ) -> str | None:
"""Wrapper for user-provided name lookup"""
if callable(self.name_lookup):
- return self.name_lookup(addr, exact=exact)
+ return self.name_lookup(addr, exact=exact, indirect=indirect)
return None
+ def _next_placeholder(self) -> str:
+ """The placeholder number corresponds to the number of addresses we have
+ already replaced. This is so the number will be consistent across the diff
+ if we can replace some symbols with actual names in recomp but not orig."""
+ number = len(self.replacements) + len(self.indirect_replacements) + 1
+ return f"" if self.number_placeholders else ""
+
def replace(self, addr: int, exact: bool = False) -> str:
"""Provide a replacement name for the given address."""
if addr in self.replacements:
@@ -79,14 +80,22 @@ def replace(self, addr: int, exact: bool = False) -> str:
self.replacements[addr] = name
return name
- # The placeholder number corresponds to the number of addresses we have
- # already replaced. This is so the number will be consistent across the diff
- # if we can replace some symbols with actual names in recomp but not orig.
- idx = len(self.replacements) + 1
- placeholder = f"" if self.number_placeholders else ""
+ placeholder = self._next_placeholder()
self.replacements[addr] = placeholder
return placeholder
+ def indirect_replace(self, addr: int) -> str:
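+ """Same as replace() but the address is read through a pointer (indirect lookup).
+ Uses a separate cache so direct and indirect replacements do not collide."""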
+ if addr in self.indirect_replacements:
+ return self.indirect_replacements[addr]
+
+ if (name := self.lookup(addr, exact=True, indirect=True)) is not None:
+ self.indirect_replacements[addr] = name
+ return name
+
+ placeholder = self._next_placeholder()
+ self.indirect_replacements[addr] = placeholder
+ return placeholder
+
def hex_replace_always(self, match: re.Match) -> str:
"""If a pointer value was matched, always insert a placeholder"""
value = int(match.group(1), 16)
@@ -119,17 +128,7 @@ def hex_replace_indirect(self, match: re.Match) -> str:
If we cannot identify the indirect address, fall back to a lookup
on the original pointer value so we might display something useful."""
value = int(match.group(1), 16)
- indirect_value = None
-
- if callable(self.bin_lookup):
- indirect_value = self.bin_lookup(value, 4)
-
- if indirect_value is not None:
- indirect_addr = bytes_to_dword(indirect_value)
- if indirect_addr is not None and self.lookup(indirect_addr) is not None:
- return "->" + self.replace(indirect_addr)
-
- return self.replace(value)
+ return self.indirect_replace(value)
def sanitize(self, inst: DisasmLiteInst) -> tuple[str, str]:
# For jumps or calls, if the entire op_str is a hex number, the value
diff --git a/reccmp/isledecomp/compare/asm/replacement.py b/reccmp/isledecomp/compare/asm/replacement.py
index 4c5833ce..89e8c0e6 100644
--- a/reccmp/isledecomp/compare/asm/replacement.py
+++ b/reccmp/isledecomp/compare/asm/replacement.py
@@ -10,28 +10,92 @@ def __call__(self, addr: int, /) -> bool:
class NameReplacementProtocol(Protocol):
- def __call__(self, addr: int, exact: bool = False) -> str | None:
+ def __call__(
+ self, addr: int, exact: bool = False, indirect: bool = False
+ ) -> str | None:
...
def create_name_lookup(
- db_getter: Callable[[int, bool], ReccmpEntity | None], addr_attribute: str
+ db_getter: Callable[[int, bool], ReccmpEntity | None],
+ bin_read: Callable[[int], int | None],
+ addr_attribute: str,
) -> NameReplacementProtocol:
"""Function generator for name replacement"""
- @cache
- def lookup(addr: int, exact: bool = False) -> str | None:
- m = db_getter(addr, exact)
- if m is None:
+ def follow_indirect(pointer: int) -> ReccmpEntity | None:
+ """Read the pointer address and open the entity (if it exists) at the indirect location."""
+ addr = bin_read(pointer)
+ if addr is not None:
+ return db_getter(addr, True)
+
+ return None
+
+ def get_name(entity: ReccmpEntity, offset: int = 0) -> str | None:
+ """The offset is the difference between the input search address and the entity's
+ starting address. Decide whether to return the base name (match_name) or
+ a string with the base name plus the offset.
+ Returns None if there is no suitable name."""
+ if offset == 0:
+ return entity.match_name()
+
+ # We will not return an offset name if this is not a variable
+ # or if the offset is outside the range of the entity.
+ if entity.entity_type != EntityType.DATA or offset >= entity.size:
return None
- if getattr(m, addr_attribute) == addr:
- return m.match_name()
+ return entity.offset_name(offset)
+
+ def indirect_lookup(addr: int) -> str | None:
+ """Same as regular lookup but aware of the fact that the address is a pointer.
+ Indirect implies exact search, so we drop both parameters from the lookup entry point.
+ """
+ entity = db_getter(addr, True)
+ if entity is not None:
+ # If the indirect call points at a variable initialized to a function,
+ # prefer the variable name as this is more useful.
+ if entity.entity_type == EntityType.DATA:
+ return entity.match_name()
+
+ if entity.entity_type == EntityType.IMPORT:
+ import_name = entity.get("import_name")
+ if import_name is not None:
+ return "->" + import_name + " (FUNCTION)"
+
+ return entity.match_name()
+
+ # No suitable entity at the base address. Read the pointer and see what we get.
+ entity = follow_indirect(addr)
+
+ if entity is None:
+ return None
+
+ # Exact match only for indirect.
+ # The 'addr' variable still points at the indirect addr.
+ name = get_name(entity, offset=0)
+ if name is not None:
+ return "->" + name
+
+ return None
+
+ @cache
+ def lookup(addr: int, exact: bool = False, indirect: bool = False) -> str | None:
+ """Returns the name that represents the entity at the given addresss.
+ If there is no suitable name, return None and let the caller choose one (i.e. placeholder).
+ * exact: If the addr is an offset of an entity (e.g. struct/array) we may return
+ a name like 'variable+8'. If exact is True, return a name only if the entity's addr
+ matches the addr parameter.
+ * indirect: If True, the given addr is a pointer so we have the option to read the address
+ from the binary to find the name."""
+ if indirect:
+ return indirect_lookup(addr)
+
+ entity = db_getter(addr, exact)
- offset = addr - getattr(m, addr_attribute)
- if m.entity_type != EntityType.DATA or offset >= m.size:
+ if entity is None:
return None
- return m.offset_name(offset)
+ offset = addr - getattr(entity, addr_attribute)
+ return get_name(entity, offset)
return lookup
diff --git a/reccmp/isledecomp/compare/core.py b/reccmp/isledecomp/compare/core.py
index 05a0e358..f5fab329 100644
--- a/reccmp/isledecomp/compare/core.py
+++ b/reccmp/isledecomp/compare/core.py
@@ -6,7 +6,10 @@
import uuid
from dataclasses import dataclass
from typing import Callable, Iterable, Iterator
-from reccmp.isledecomp.formats.exceptions import InvalidVirtualAddressError
+from reccmp.isledecomp.formats.exceptions import (
+ InvalidVirtualAddressError,
+ InvalidVirtualReadError,
+)
from reccmp.isledecomp.formats.pe import PEImage
from reccmp.isledecomp.cvdump.demangler import (
demangle_string_const,
@@ -75,13 +78,14 @@ def lookup(addr: int) -> bool:
return lookup
-def create_bin_lookup(bin_file: PEImage) -> Callable[[int, int], bytes | None]:
- """Function generator for reading from the bin file"""
+def create_bin_lookup(bin_file: PEImage) -> Callable[[int], int | None]:
+ """Function generator to read a pointer from the bin file"""
- def lookup(addr: int, size: int) -> bytes | None:
+ def lookup(addr: int) -> int | None:
try:
- return bin_file.read(addr, size)
- except InvalidVirtualAddressError:
+ (ptr,) = struct.unpack("<L", bin_file.read(addr, 4))
+ return ptr
+ except (InvalidVirtualAddressError, InvalidVirtualReadError):
return None
return lookup
diff --git a/reccmp/isledecomp/compare/report.py b/reccmp/isledecomp/compare/report.py
new file mode 100644
--- /dev/null
+++ b/reccmp/isledecomp/compare/report.py
+"""Data classes and serialization for the reccmp comparison report."""
+from datetime import datetime
+from dataclasses import dataclass
+from typing import Iterable, Iterator, Literal
+from pydantic import BaseModel, ValidationError
+from pydantic_core import from_json
+from reccmp.isledecomp.compare.diff import CombinedDiffOutput
+
+
+class ReccmpReportDeserializeError(Exception):
+ """The file is not a serialized reccmp report."""
+
+
+class ReccmpReportSameSourceError(Exception):
+ """Tried to combine reports derived from different original binaries."""
+
+
+@dataclass
+class ReccmpComparedEntity:
+ """Comparison result for a single entity (i.e. function) in the report."""
+
+ orig_addr: str
+ name: str
+ accuracy: float
+ recomp_addr: str | None = None
+ is_stub: bool = False
+ is_effective_match: bool = False
+ diff: CombinedDiffOutput | None = None
+
+
+class ReccmpStatusReport:
+ """Status report for a single binary, suitable for saving and aggregating."""
+
+ # The lower-case filename of the original binary.
+ filename: str
+
+ # When the report was created.
+ timestamp: datetime
+
+ # Compared entities, keyed by orig addr.
+ entities: dict[str, ReccmpComparedEntity]
+
+ def __init__(self, filename: str, timestamp: datetime | None = None) -> None:
+ self.filename = filename
+ if timestamp is not None:
+ self.timestamp = timestamp
+ else:
+ self.timestamp = datetime.now().replace(microsecond=0)
+
+ self.entities = {}
+
+
+def _get_entity_for_addr(
+ samples: Iterable[ReccmpStatusReport], addr: str
+) -> Iterator[ReccmpComparedEntity]:
+ """Helper to return entities from xreports that have the given address."""
+ for sample in samples:
+ if addr in sample.entities:
+ yield sample.entities[addr]
+
+
+def _accuracy_sort_key(entity: ReccmpComparedEntity) -> float:
+ """Helper to sort entity samples by accuracy score.
+ 100% match is preferred over effective match.
+ Effective match is preferred over any accuracy.
+ Stubs rank lower than any accuracy score."""
+ if entity.is_stub:
+ return -1.0
+
+ if entity.accuracy == 1.0:
+ if not entity.is_effective_match:
+ return 1000.0
+
+ if entity.is_effective_match:
+ return 1.0
+
+ return entity.accuracy
+
+
+def combine_reports(samples: list[ReccmpStatusReport]) -> ReccmpStatusReport:
+ """Combines the sample reports into a single report.
+ The current strategy is to use the entity with the highest
+ accuracy score from any report."""
+ assert len(samples) > 0
+
+ if not all(samples[0].filename == s.filename for s in samples):
+ raise ReccmpReportSameSourceError
+
+ output = ReccmpStatusReport(filename=samples[0].filename)
+
+ # Combine every orig addr used in any of the reports.
+ orig_addr_set = {key for sample in samples for key in sample.entities.keys()}
+
+ all_orig_addrs = sorted(list(orig_addr_set))
+
+ for addr in all_orig_addrs:
+ e_list = list(_get_entity_for_addr(samples, addr))
+ assert len(e_list) > 0
+
+ # Our aggregate accuracy score is the highest from any report.
+ e_list.sort(key=_accuracy_sort_key, reverse=True)
+
+ output.entities[addr] = e_list[0]
+
+ # Keep the recomp_addr if it is the same across all samples.
+ # i.e. to detect where function alignment ends
+ if not all(e_list[0].recomp_addr == e.recomp_addr for e in e_list):
+ output.entities[addr].recomp_addr = "various"
+
+ return output
+
+
+#### JSON schemas and conversion functions ####
+
+
+@dataclass
+class JSONEntityVersion1:
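+ """A single compared entity in the version 1 JSON report format."""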
+ address: str
+ name: str
+ matching: float
+ # Optional fields
+ recomp: str | None = None
+ stub: bool = False
+ effective: bool = False
+ diff: CombinedDiffOutput | None = None
+
+
+class JSONReportVersion1(BaseModel):
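+ """Pydantic schema for version 1 of the JSON report file."""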
+ file: str
+ format: Literal[1]
+ timestamp: float
+ data: list[JSONEntityVersion1]
+
+
+def _serialize_version_1(
+ report: ReccmpStatusReport, diff_included: bool = False
+) -> JSONReportVersion1:
+ """The HTML file needs the diff data, but it is omitted from the JSON report."""
+ entities = [
+ JSONEntityVersion1(
+ address=addr, # prefer dict key over redundant value in entity
+ name=e.name,
+ matching=e.accuracy,
+ recomp=e.recomp_addr,
+ stub=e.is_stub,
+ effective=e.is_effective_match,
+ diff=e.diff if diff_included else None,
+ )
+ for addr, e in report.entities.items()
+ ]
+
+ return JSONReportVersion1(
+ file=report.filename,
+ format=1,
+ timestamp=report.timestamp.timestamp(),
+ data=entities,
+ )
+
+
+def _deserialize_version_1(obj: JSONReportVersion1) -> ReccmpStatusReport:
+ report = ReccmpStatusReport(
+ filename=obj.file, timestamp=datetime.fromtimestamp(obj.timestamp)
+ )
+
+ for e in obj.data:
+ report.entities[e.address] = ReccmpComparedEntity(
+ orig_addr=e.address,
+ name=e.name,
+ accuracy=e.matching,
+ recomp_addr=e.recomp,
+ is_stub=e.stub,
+ is_effective_match=e.effective,
+ diff=e.diff,
+ )
+
+ return report
+
+
+def deserialize_reccmp_report(json_str: str) -> ReccmpStatusReport:
+ try:
+ obj = JSONReportVersion1.model_validate(from_json(json_str))
+ return _deserialize_version_1(obj)
+ except ValidationError as ex:
+ raise ReccmpReportDeserializeError from ex
+
+
+def serialize_reccmp_report(
+ report: ReccmpStatusReport, diff_included: bool = False
+) -> str:
+ """Create a JSON string for the report so it can be written to a file."""
+ now = datetime.now().replace(microsecond=0)
+ report.timestamp = now
+ obj = _serialize_version_1(report, diff_included=diff_included)
+
+ return obj.model_dump_json(exclude_defaults=True)
diff --git a/reccmp/isledecomp/types.py b/reccmp/isledecomp/types.py
index afbc7341..34e1e64a 100644
--- a/reccmp/isledecomp/types.py
+++ b/reccmp/isledecomp/types.py
@@ -12,3 +12,4 @@ class EntityType(IntEnum):
STRING = 4
VTABLE = 5
FLOAT = 6
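+ # Function imported from another binary, i.e. an entry in the import table.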
+ IMPORT = 7
diff --git a/reccmp/isledecomp/utils.py b/reccmp/isledecomp/utils.py
index fc7215d9..4a04269f 100644
--- a/reccmp/isledecomp/utils.py
+++ b/reccmp/isledecomp/utils.py
@@ -1,7 +1,31 @@
from datetime import datetime
import logging
-from pathlib import Path
import colorama
+from pystache import Renderer # type: ignore[import-untyped]
+from reccmp.assets import get_asset_file
+from reccmp.isledecomp.compare.report import (
+ ReccmpStatusReport,
+ ReccmpComparedEntity,
+ serialize_reccmp_report,
+)
+
+
+def write_html_report(html_file: str, report: ReccmpStatusReport):
+ """Create the interactive HTML diff viewer with the given report."""
+ js_path = get_asset_file("../assets/reccmp.js")
+ with open(js_path, "r", encoding="utf-8") as f:
+ reccmp_js = f.read()
+
+ # Convert the report to a JSON string to insert in the HTML template.
+ report_str = serialize_reccmp_report(report, diff_included=True)
+
+ output_data = Renderer().render_path(
+ get_asset_file("../assets/template.html"),
+ {"report": report_str, "reccmp_js": reccmp_js},
+ )
+
+ with open(html_file, "w", encoding="utf-8") as htmlfile:
+ htmlfile.write(output_data)
def print_combined_diff(udiff, plain: bool = False, show_both: bool = False):
@@ -129,7 +153,9 @@ def diff_json_display(show_both_addrs: bool = False, is_plain: bool = False):
"""Generate a function that will display the diff according to
the reccmp display preferences."""
- def formatter(orig_addr, saved, new) -> str:
+ def formatter(
+ orig_addr, saved: ReccmpComparedEntity | None, new: ReccmpComparedEntity | None
+ ) -> str:
old_pct = "new"
new_pct = "gone"
name = ""
@@ -138,29 +164,25 @@ def formatter(orig_addr, saved, new) -> str:
if new is not None:
new_pct = (
"stub"
- if new.get("stub", False)
- else percent_string(
- new["matching"], new.get("effective", False), is_plain
- )
+ if new.is_stub
+ else percent_string(new.accuracy, new.is_effective_match, is_plain)
)
# Prefer the current name of this function if we have it.
# We are using the original address as the key.
# A function being renamed is not of interest here.
- name = new.get("name", "")
- recomp_addr = new.get("recomp", "n/a")
+ name = new.name
+ recomp_addr = new.recomp_addr or "n/a"
if saved is not None:
old_pct = (
"stub"
- if saved.get("stub", False)
- else percent_string(
- saved["matching"], saved.get("effective", False), is_plain
- )
+ if saved.is_stub
+ else percent_string(saved.accuracy, saved.is_effective_match, is_plain)
)
if name == "":
- name = saved.get("name", "")
+ name = saved.name
if show_both_addrs:
addr_string = f"{orig_addr} / {recomp_addr:10}"
@@ -176,29 +198,25 @@ def formatter(orig_addr, saved, new) -> str:
def diff_json(
- saved_data,
- new_data,
- orig_file: Path,
+ saved_data: ReccmpStatusReport,
+ new_data: ReccmpStatusReport,
show_both_addrs: bool = False,
is_plain: bool = False,
):
- """Using a saved copy of the diff summary and the current data, print a
- report showing which functions/symbols have changed match percentage."""
+ """Compare two status report files, determine what items changed, and print the result."""
# Don't try to diff a report generated for a different binary file
- base_file = orig_file.name.lower()
-
- if saved_data.get("file") != base_file:
+ if saved_data.filename != new_data.filename:
logging.getLogger().error(
"Diff report for '%s' does not match current file '%s'",
- saved_data.get("file"),
- base_file,
+ saved_data.filename,
+ new_data.filename,
)
return
- if "timestamp" in saved_data:
+ if saved_data.timestamp is not None:
now = datetime.now().replace(microsecond=0)
- then = datetime.fromtimestamp(saved_data["timestamp"]).replace(microsecond=0)
+ then = saved_data.timestamp.replace(microsecond=0)
print(
" ".join(
@@ -213,8 +231,8 @@ def diff_json(
print()
# Convert to dict, using orig_addr as key
- saved_invert = {obj["address"]: obj for obj in saved_data["data"]}
- new_invert = {obj["address"]: obj for obj in new_data}
+ saved_invert = saved_data.entities
+ new_invert = new_data.entities
all_addrs = set(saved_invert.keys()).union(new_invert.keys())
@@ -227,60 +245,56 @@ def diff_json(
for addr in sorted(all_addrs)
}
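+ # Each bucket below maps orig addr -> (saved entity, new entity); either side may be None.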
+ DiffSubsectionType = dict[
+ str, tuple[ReccmpComparedEntity | None, ReccmpComparedEntity | None]
+ ]
+
# The criteria for diff judgement is in these dict comprehensions:
# Any function not in the saved file
- new_functions = {
+ new_functions: DiffSubsectionType = {
key: (saved, new) for key, (saved, new) in combined.items() if saved is None
}
# Any function now missing from the saved file
# or a non-stub -> stub conversion
- dropped_functions = {
+ dropped_functions: DiffSubsectionType = {
key: (saved, new)
for key, (saved, new) in combined.items()
if new is None
- or (
- new is not None
- and saved is not None
- and new.get("stub", False)
- and not saved.get("stub", False)
- )
+ or (new is not None and saved is not None and new.is_stub and not saved.is_stub)
}
# TODO: move these two into functions if the assessment gets more complex
# Any function with increased match percentage
# or stub -> non-stub conversion
- improved_functions = {
+ improved_functions: DiffSubsectionType = {
key: (saved, new)
for key, (saved, new) in combined.items()
if saved is not None
and new is not None
- and (
- new["matching"] > saved["matching"]
- or (not new.get("stub", False) and saved.get("stub", False))
- )
+ and (new.accuracy > saved.accuracy or (not new.is_stub and saved.is_stub))
}
# Any non-stub function with decreased match percentage
- degraded_functions = {
+ degraded_functions: DiffSubsectionType = {
key: (saved, new)
for key, (saved, new) in combined.items()
if saved is not None
and new is not None
- and new["matching"] < saved["matching"]
- and not saved.get("stub")
- and not new.get("stub")
+ and new.accuracy < saved.accuracy
+ and not saved.is_stub
+ and not new.is_stub
}
# Any function with former or current "effective" match
- entropy_functions = {
+ entropy_functions: DiffSubsectionType = {
key: (saved, new)
for key, (saved, new) in combined.items()
if saved is not None
and new is not None
- and new["matching"] == 1.0
- and saved["matching"] == 1.0
- and new.get("effective", False) != saved.get("effective", False)
+ and new.accuracy == 1.0
+ and saved.accuracy == 1.0
+ and new.is_effective_match != saved.is_effective_match
}
get_diff_str = diff_json_display(show_both_addrs, is_plain)
diff --git a/reccmp/tools/aggregate.py b/reccmp/tools/aggregate.py
new file mode 100644
index 00000000..e7661e9a
--- /dev/null
+++ b/reccmp/tools/aggregate.py
@@ -0,0 +1,179 @@
+#!/usr/bin/env python3
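+"""Aggregate saved reccmp accuracy reports: combine samples, diff them, and export JSON or HTML."""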
+
+import argparse
+import logging
+from typing import Sequence
+from pathlib import Path
+from reccmp.isledecomp.utils import diff_json, write_html_report
+from reccmp.isledecomp.compare.report import (
+ ReccmpStatusReport,
+ combine_reports,
+ ReccmpReportDeserializeError,
+ ReccmpReportSameSourceError,
+ deserialize_reccmp_report,
+ serialize_reccmp_report,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+def write_report_file(output_file: Path, report: ReccmpStatusReport):
+ """Convert the status report to JSON and write to a file."""
+ json_str = serialize_reccmp_report(report)
+
+ with open(output_file, "w+", encoding="utf-8") as f:
+ f.write(json_str)
+
+
+def load_report_file(report_path: Path) -> ReccmpStatusReport:
+ """Deserialize from JSON at the given filename and return the report."""
+
+ with report_path.open("r", encoding="utf-8") as f:
+ return deserialize_reccmp_report(f.read())
+
+
+def deserialize_sample_files(paths: list[Path]) -> list[ReccmpStatusReport]:
+ """Deserialize all sample files and return the list of reports.
+ Does not remove duplicates."""
+ samples = []
+
+ for path in paths:
+ if path.is_file():
+ try:
+ report = load_report_file(path)
+ samples.append(report)
+ except ReccmpReportDeserializeError:
+ logger.warning("Skipping '%s' due to import error", path)
+ elif not path.exists():
+ logger.warning("File not found: '%s'", path)
+
+ return samples
+
+
+class TwoOrMoreArgsAction(argparse.Action):
+ """Support nargs=2+"""
+
+ def __call__(
+ self, parser, namespace, values: Sequence[str] | None, option_string=None
+ ):
+ assert isinstance(values, Sequence)
+ if len(values) < 2:
+ raise argparse.ArgumentError(self, "expected two or more arguments")
+
+ setattr(namespace, self.dest, values)
+
+
+class TwoOrFewerArgsAction(argparse.Action):
+ """Support nargs=(1,2)"""
+
+ def __call__(
+ self, parser, namespace, values: Sequence[str] | None, option_string=None
+ ):
+ assert isinstance(values, Sequence)
+ if len(values) not in (1, 2):
+ raise argparse.ArgumentError(self, "expected one or two arguments")
+
+ setattr(namespace, self.dest, values)
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ allow_abbrev=False,
+ description="Aggregate saved accuracy reports.",
+ )
+ parser.add_argument(
+ "--diff",
+ type=Path,
+ metavar="",
+ nargs="+",
+ action=TwoOrFewerArgsAction,
+ help="Report files to diff.",
+ )
+ parser.add_argument(
+ "--html",
+ type=Path,
+ metavar="",
+ help="Location for HTML report based on aggregate.",
+ )
+ parser.add_argument(
+ "--output",
+ "-o",
+ type=Path,
+ metavar="",
+ help="Where to save the aggregate file.",
+ )
+ parser.add_argument(
+ "--samples",
+ type=Path,
+ metavar="",
+ nargs="+",
+ action=TwoOrMoreArgsAction,
+ help="Report files to aggregate.",
+ )
+ parser.add_argument(
+ "--no-color", "-n", action="store_true", help="Do not color the output"
+ )
+
+ args = parser.parse_args()
+
+ if not (args.samples or args.diff):
+ parser.error(
+ "exepected arguments for --samples or --diff. (No input files specified)"
+ )
+
+ if not (args.output or args.diff or args.html):
+ parser.error(
+ "expected arguments for --output, --html, or --diff. (No output action specified)"
+ )
+
+ agg_report: ReccmpStatusReport | None = None
+
+ if args.samples is not None:
+ samples = deserialize_sample_files(args.samples)
+
+ if len(samples) < 2:
+ logger.error("Not enough samples to aggregate!")
+ return 1
+
+ try:
+ agg_report = combine_reports(samples)
+ except ReccmpReportSameSourceError:
+ filename_list = sorted({s.filename for s in samples})
+ logger.error(
+ "Aggregate samples are not from the same source file!\nFilenames used: %s",
+ filename_list,
+ )
+ return 1
+
+ if args.output is not None:
+ write_report_file(args.output, agg_report)
+
+ if args.html is not None:
+ write_html_report(args.html, agg_report)
+
+ # If --diff has at least one file and we aggregated some samples this run, diff the first file and the aggregate.
+ # If --diff has two files and we did not aggregate this run, diff the files in the list.
+ if args.diff is not None:
+ saved_data = load_report_file(args.diff[0])
+
+ if agg_report is None:
+ if len(args.diff) > 1:
+ agg_report = load_report_file(args.diff[1])
+ else:
+ logger.error("Not enough files to diff!")
+ return 1
+ elif len(args.diff) == 2:
+ logger.warning(
+ "Ignoring second --diff argument '%s'.\nDiff of '%s' and aggregate report follows.",
+ args.diff[1],
+ args.diff[0],
+ )
+
+ diff_json(saved_data, agg_report, show_both_addrs=False, is_plain=args.no_color)
+
+ return 0
+
+
+if __name__ == "__main__":
+ raise SystemExit(main())
diff --git a/reccmp/tools/asmcmp.py b/reccmp/tools/asmcmp.py
index 0e5eb16c..e83d95ce 100755
--- a/reccmp/tools/asmcmp.py
+++ b/reccmp/tools/asmcmp.py
@@ -2,11 +2,8 @@
import argparse
import base64
-import json
import logging
import os
-from datetime import datetime
-from pathlib import Path
from pystache import Renderer # type: ignore[import-untyped]
import colorama
@@ -15,9 +12,16 @@
print_combined_diff,
diff_json,
percent_string,
+ write_html_report,
)
from reccmp.isledecomp.compare import Compare as IsleCompare
+from reccmp.isledecomp.compare.report import (
+ ReccmpStatusReport,
+ ReccmpComparedEntity,
+ deserialize_reccmp_report,
+ serialize_reccmp_report,
+)
from reccmp.isledecomp.formats.detect import detect_image
from reccmp.isledecomp.formats.pe import PEImage
from reccmp.isledecomp.types import EntityType
@@ -34,43 +38,11 @@
colorama.just_fix_windows_console()
-def gen_json(json_file: str, orig_file: Path, data):
- """Create a JSON file that contains the comparison summary"""
-
- # If the structure of the JSON file ever changes, we would run into a problem
- # reading an older format file in the CI action. Mark which version we are
- # generating so we could potentially address this down the road.
- json_format_version = 1
-
- # Remove the diff field
- reduced_data = [
- {key: value for (key, value) in obj.items() if key != "diff"} for obj in data
- ]
+def gen_json(json_file: str, json_str: str):
+ """Convert the status report to JSON and write to a file."""
with open(json_file, "w", encoding="utf-8") as f:
- json.dump(
- {
- "file": orig_file.name.lower(),
- "format": json_format_version,
- "timestamp": datetime.now().timestamp(),
- "data": reduced_data,
- },
- f,
- )
-
-
-def gen_html(html_file, data):
- js_path = get_asset_file("../assets/reccmp.js")
- with open(js_path, "r", encoding="utf-8") as f:
- reccmp_js = f.read()
-
- output_data = Renderer().render_path(
- get_asset_file("../assets/template.html"),
- {"data": data, "reccmp_js": reccmp_js},
- )
-
- with open(html_file, "w", encoding="utf-8") as htmlfile:
- htmlfile.write(output_data)
+ f.write(json_str)
def gen_svg(svg_file, name_svg, icon, svg_implemented_funcs, total_funcs, raw_accuracy):
@@ -177,6 +149,11 @@ def virtual_address(value) -> int:
metavar="",
help="Generate JSON file with match summary",
)
+ parser.add_argument(
+ "--json-diet",
+ action="store_true",
+ help="Exclude diff from JSON report.",
+ )
parser.add_argument(
"--diff",
metavar="",
@@ -267,7 +244,8 @@ def main():
function_count = 0
total_accuracy = 0.0
total_effective_accuracy = 0.0
- htmlinsert = []
+
+ report = ReccmpStatusReport(filename=target.original_path.name.lower())
for match in isle_compare.compare_all():
if not args.silent and args.diff is None:
@@ -287,44 +265,42 @@ def main():
total_effective_accuracy += match.effective_ratio
# If html, record the diffs to an HTML file
- html_obj = {
- "address": f"0x{match.orig_addr:x}",
- "recomp": f"0x{match.recomp_addr:x}",
- "name": match.name,
- "matching": match.effective_ratio,
- }
-
- if match.is_effective_match:
- html_obj["effective"] = True
-
- if match.udiff is not None:
- html_obj["diff"] = match.udiff
-
- if match.is_stub:
- html_obj["stub"] = True
-
- htmlinsert.append(html_obj)
+ orig_addr = f"0x{match.orig_addr:x}"
+ recomp_addr = f"0x{match.recomp_addr:x}"
+
+ report.entities[orig_addr] = ReccmpComparedEntity(
+ orig_addr=orig_addr,
+ name=match.name,
+ accuracy=match.effective_ratio,
+ recomp_addr=recomp_addr,
+ is_effective_match=match.is_effective_match,
+ is_stub=match.is_stub,
+ diff=match.udiff,
+ )
# Compare with saved diff report.
if args.diff is not None:
with open(args.diff, "r", encoding="utf-8") as f:
- saved_data = json.load(f)
-
- diff_json(
- saved_data,
- htmlinsert,
- target.original_path,
- show_both_addrs=args.print_rec_addr,
- is_plain=args.no_color,
- )
+ saved_data = deserialize_reccmp_report(f.read())
+
+ diff_json(
+ saved_data,
+ report,
+ show_both_addrs=args.print_rec_addr,
+ is_plain=args.no_color,
+ )
## Generate files and show summary.
if args.json is not None:
- gen_json(args.json, target.original_path, htmlinsert)
+ # If we're on a diet, hold the diff.
+ diff_included = not bool(args.json_diet)
+ gen_json(
+ args.json, serialize_reccmp_report(report, diff_included=diff_included)
+ )
if args.html is not None:
- gen_html(args.html, json.dumps(htmlinsert))
+ write_html_report(args.html, report)
implemented_funcs = function_count
diff --git a/setup.cfg b/setup.cfg
index 6458a6f5..629c9862 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -21,6 +21,7 @@ exclude =
[options.entry_points]
console_scripts =
+ reccmp-aggregate = reccmp.tools.aggregate:main
reccmp-datacmp = reccmp.tools.datacmp:main
reccmp-decomplint = reccmp.tools.decomplint:main
reccmp-project = reccmp.tools.project:main
diff --git a/tests/test_name_replacement.py b/tests/test_name_replacement.py
new file mode 100644
index 00000000..e62c3f92
--- /dev/null
+++ b/tests/test_name_replacement.py
@@ -0,0 +1,217 @@
+import pytest
+from reccmp.isledecomp.types import EntityType
+from reccmp.isledecomp.compare.db import EntityDb
+from reccmp.isledecomp.compare.asm.replacement import (
+ create_name_lookup,
+ NameReplacementProtocol,
+)
+
+
+@pytest.fixture(name="db")
+def fixture_db() -> EntityDb:
+ return EntityDb()
+
+
+def create_lookup(
+ db, addrs: dict[int, int] | None = None, is_orig: bool = True
+) -> NameReplacementProtocol:
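+ """Build the name lookup function under test. The addrs dict mocks pointer reads from the binary."""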
+ if addrs is None:
+ addrs = {}
+
+ def bin_lookup(addr: int) -> int | None:
+ return addrs.get(addr)
+
+ if is_orig:
+ return create_name_lookup(db.get_by_orig, bin_lookup, "orig_addr")
+
+ return create_name_lookup(db.get_by_recomp, bin_lookup, "recomp_addr")
+
+
+####
+
+
+def test_name_replacement(db):
+ """Should return a name for an entity that has one.
+ Return None if there are no name attributes set or the entity does not exist."""
+ with db.batch() as batch:
+ batch.set_orig(100, name="Test")
+ batch.set_orig(200, computed_name="Hello")
+ batch.set_orig(300) # No name
+
+ lookup = create_lookup(db)
+
+ # Using "in" here because the returned string may contain other information.
+ # e.g. the entity type
+ assert "Test" in lookup(100)
+ assert "Hello" in lookup(200)
+ assert lookup(300) is None
+
+
+def test_name_hierarchy(db):
+ """Use the "best" entity name. Currently there are only two.
+ 'computed_name' is preferred over just 'name'."""
+ with db.batch() as batch:
+ batch.set_orig(100, name="Test", computed_name="Hello")
+
+ lookup = create_lookup(db)
+
+ # Should prefer 'computed_name' over 'name'
+ assert "Hello" in lookup(100)
+ assert "Test" not in lookup(100)
+
+
+def test_string_escape_newlines(db):
+ """Make sure newlines are removed from the string.
+ This overlaps with tests on the ReccmpEntity name functions, but it is more vital
+ to ensure there are no newlines at this stage because they will disrupt the asm diff.
+ """
+ with db.batch() as batch:
+ batch.set_orig(100, name="Test\nTest", type=EntityType.STRING)
+
+ lookup = create_lookup(db)
+
+ assert "\n" not in lookup(100)
+
+
+def test_offset_name(db):
+ """For some entities (i.e. variables) we will return a name if the search address
+ is inside the address range of the entity. This is determined by the size attribute.
+ """
+ with db.batch() as batch:
+ batch.set_orig(100, name="Hello", type=EntityType.DATA, size=10)
+
+ lookup = create_lookup(db)
+
+ assert lookup(100) is not None
+ assert lookup(101) is not None
+
+ # Outside the range = no name
+ assert lookup(110) is None
+
+
+def test_offset_name_non_variables(db):
+ """Do not return an offset name for non-variable entities. (e.g. functions)."""
+ with db.batch() as batch:
+ batch.set_orig(100, name="Hello", type=EntityType.FUNCTION, size=10)
+ batch.set_orig(200, name="Hello", size=10) # No type
+
+ lookup = create_lookup(db)
+
+ assert lookup(100) is not None
+ assert lookup(101) is None
+
+ assert lookup(200) is not None
+ assert lookup(201) is None
+
+
+def test_offset_name_no_size(db):
+ """An enity with no size attribute is considered to have size=0.
+ Meaning: match only against the address value."""
+ with db.batch() as batch:
+ batch.set_orig(100, name="Hello", type=EntityType.DATA)
+
+ lookup = create_lookup(db)
+
+ assert lookup(100) is not None
+ assert lookup(101) is None
+
+
+def test_exact_restriction(db):
+ """If exact=True, return a name only if the entity's address matches the search address.
+ Otherwise we might return a name if the entity contains the search address."""
+ with db.batch() as batch:
+ batch.set_orig(100, name="Hello", type=EntityType.DATA, size=10)
+
+ lookup = create_lookup(db)
+
+ assert lookup(100, exact=True) is not None
+ assert lookup(101, exact=True) is None
+
+ # Proof that the exact parameter controls whether we get a name.
+ assert lookup(101, exact=False) is not None
+
+
+def test_indirect_function(db):
+ """An instruction like `call dword ptr [0x1234]` means that we call the function
+ whose address is at address 0x1234. This is an indirect lookup."""
+ with db.batch() as batch:
+ batch.set_orig(100, name="Hello", type=EntityType.FUNCTION)
+
+ # Mock lookup so we will read 100 from address 200.
+ lookup = create_lookup(db, {200: 100})
+
+ # No entity at 200
+ assert lookup(200) is None
+ assert lookup(200, indirect=True) is not None
+
+ # Imitating Ghidra asm display. Not every indirect lookup gets the arrow.
+ assert "->" in lookup(200, indirect=True)
+
+
+def test_indirect_function_variable(db):
+ """If the indirect call instruction has the address of a variable in our database,
+ prefer the variable name rather than reading the pointer."""
+ with db.batch() as batch:
+ batch.set_orig(100, name="Hello", type=EntityType.FUNCTION)
+ batch.set_orig(200, name="Test", type=EntityType.DATA)
+
+ # Mock lookup so we will read 100 from address 200.
+ lookup = create_lookup(db, {200: 100})
+
+ name = lookup(200, indirect=True)
+ assert name is not None
+ assert "Hello" not in name
+ assert "Test" in name
+ assert "->" not in name
+
+
+def test_indirect_import(db):
+ """If we are indirectly calling an imported funtion, we should see the import_name
+ attribute used in the result. This will probably contain the DLL and function name.
+ """
+ with db.batch() as batch:
+ batch.set_orig(100, import_name="Hello", name="Test", type=EntityType.IMPORT)
+
+ # No mock needed here because we will not need to read any data.
+ lookup = create_lookup(db)
+
+ # Should use import name with arrow to suggest indirect call.
+ name = lookup(100, indirect=True)
+ assert name is not None
+ assert "Hello" in name
+ assert "->" in name
+
+ # Show the entity name instead. (e.g. __imp__ symbol)
+ name = lookup(100, indirect=False)
+ assert name is not None
+ assert "Test" in name
+ assert "->" not in name
+
+
+def test_indirect_import_missing_data(db):
+ """Edge cases for indirect lookup on an IMPORT entity.."""
+ with db.batch() as batch:
+ batch.set_orig(100, name="Test", type=EntityType.IMPORT)
+
+ lookup = create_lookup(db)
+
+ # No import name. Use the regular entity name instead (i.e. match indirect=False lookup)
+ name = lookup(100, indirect=True)
+ assert name is not None
+ assert "Test" in name
+ assert "->" not in name
+
+
+def test_indirect_failed_lookup(db):
+ """In the general case (i.e. we do not use the base entity to get the name)
+ if there is no entity at the pointer location, return None."""
+ with db.batch() as batch:
+ batch.set_orig(200, name="Hello", type=EntityType.FUNCTION)
+
+ # Mock lookup so we will read 100 from address 200.
+ lookup = create_lookup(db, {200: 100})
+
+ # There is an entity at 200 but we can't use it to generate a name.
+ # There is no entity at 100 (indirect location)
+ assert lookup(200, indirect=False) is not None
+ assert lookup(200, indirect=True) is None
diff --git a/tests/test_report.py b/tests/test_report.py
new file mode 100644
index 00000000..71a71f47
--- /dev/null
+++ b/tests/test_report.py
@@ -0,0 +1,136 @@
+"""Reccmp reports: files that contain the comparison result from asmcmp."""
+import pytest
+from reccmp.isledecomp.compare.report import (
+ ReccmpStatusReport,
+ ReccmpComparedEntity,
+ combine_reports,
+ ReccmpReportSameSourceError,
+)
+
+
+def create_report(
+ entities: list[tuple[str, float]] | None = None
+) -> ReccmpStatusReport:
+ """Helper to quickly set up a report to be customized further for each test."""
+ report = ReccmpStatusReport(filename="test.exe")
+ if entities is not None:
+ for addr, accuracy in entities:
+ report.entities[addr] = ReccmpComparedEntity(addr, "test", accuracy)
+
+ return report
+
+
+def test_aggregate_identity():
+ """Combine a list of one report. Should get the same report back,
+ except for expected differences like the timestamp."""
+ report = create_report([("100", 1.0), ("200", 0.5)])
+ combined = combine_reports([report])
+
+ for (a_key, a_entity), (b_key, b_entity) in zip(
+ report.entities.items(), combined.entities.items()
+ ):
+ assert a_key == b_key
+ assert a_entity.orig_addr == b_entity.orig_addr
+ assert a_entity.accuracy == b_entity.accuracy
+
+
+def test_aggregate_simple():
+ """Should choose the best score from the sample reports."""
+ x = create_report([("100", 0.8), ("200", 0.2)])
+ y = create_report([("100", 0.2), ("200", 0.8)])
+
+ combined = combine_reports([x, y])
+ assert combined.entities["100"].accuracy == 0.8
+ assert combined.entities["200"].accuracy == 0.8
+
+
+def test_aggregate_union_all_addrs():
+ """Should combine all addresses from any report."""
+ x = create_report([("100", 0.8)])
+ y = create_report([("200", 0.8)])
+
+ combined = combine_reports([x, y])
+ assert "100" in combined.entities
+ assert "200" in combined.entities
+
+
+def test_aggregate_stubs():
+ """Stub functions (i.e. do not compare asm) are considered to have 0 percent accuracy."""
+ x = create_report([("100", 0.9)])
+ y = create_report([("100", 0.5)])
+
+ # In a real report, accuracy would be zero for a stub.
+ x.entities["100"].is_stub = True
+ y.entities["100"].is_stub = False
+
+ combined = combine_reports([x, y])
+ assert combined.entities["100"].is_stub is False
+
+ # Choose the lower non-stub value
+ assert combined.entities["100"].accuracy == 0.5
+
+
+def test_aggregate_all_stubs():
+ """If all samples are stubs, preserve that setting."""
+ x = create_report([("100", 1.0)])
+
+ x.entities["100"].is_stub = True
+
+ combined = combine_reports([x, x])
+ assert combined.entities["100"].is_stub is True
+
+
+def test_aggregate_100_over_effective():
+ """Prefer 100% match over effective."""
+ x = create_report([("100", 0.9)])
+ y = create_report([("100", 1.0)])
+ x.entities["100"].is_effective_match = True
+
+ combined = combine_reports([x, y])
+ assert combined.entities["100"].is_effective_match is False
+
+
+def test_aggregate_effective_over_any():
+ """Prefer effective match over any accuracy."""
+ x = create_report([("100", 0.5)])
+ y = create_report([("100", 0.6)])
+ x.entities["100"].is_effective_match = True
+ # Y has higher accuracy score, but we could not confirm an effective match.
+
+ combined = combine_reports([x, y])
+ assert combined.entities["100"].is_effective_match is True
+
+ # Should retain original accuracy for effective match.
+ assert combined.entities["100"].accuracy == 0.5
+
+
+def test_aggregate_different_files():
+ """Should raise an exception if we try to aggregate reports
+ where the orig filename does not match."""
+ x = create_report()
+ y = create_report()
+
+ # Make sure they are different, regardless of what is set by create_report().
+ x.filename = "test.exe"
+ y.filename = "hello.exe"
+
+ with pytest.raises(ReccmpReportSameSourceError):
+ combine_reports([x, y])
+
+
+def test_aggregate_recomp_addr():
+ """We combine the entity data based on the orig addr because this will not change.
+ The recomp addr may vary a lot. If it is the same in all samples, use the value.
+ Otherwise use a placeholder value."""
+ x = create_report([("100", 0.8), ("200", 0.2)])
+ y = create_report([("100", 0.2), ("200", 0.8)])
+ # These recomp addrs match:
+ x.entities["100"].recomp_addr = "500"
+ y.entities["100"].recomp_addr = "500"
+ # Y report has no addr for this
+ x.entities["200"].recomp_addr = "600"
+
+ combined = combine_reports([x, y])
+ assert combined.entities["100"].recomp_addr == "500"
+ assert combined.entities["200"].recomp_addr != "600"
+ assert combined.entities["200"].recomp_addr == "various"
diff --git a/tests/test_sanitize32.py b/tests/test_sanitize32.py
index 4dfd4201..033437ba 100644
--- a/tests/test_sanitize32.py
+++ b/tests/test_sanitize32.py
@@ -75,7 +75,7 @@ def test_pointer_instructions_with_name(inst: DisasmLiteInst):
(_, op_str) = p.sanitize(inst)
# Using sample instructions where exact match is not required
- name_lookup.assert_called_with(0x1234, exact=False)
+ name_lookup.assert_called_with(0x1234, exact=False, indirect=False)
assert "[0x1234]" not in op_str
assert "[Hello]" in op_str
@@ -258,7 +258,7 @@ def test_jmp_with_name_lookup():
(_, op_str) = p.sanitize(DisasmLiteInst(0x1000, 5, "jmp", "0x2000"))
- name_lookup.assert_called_with(0x2000, exact=True)
+ name_lookup.assert_called_with(0x2000, exact=True, indirect=False)
assert op_str == "Hello"
@@ -302,7 +302,7 @@ def test_cmp_with_name_lookup():
(_, op_str) = p.sanitize(inst)
- name_lookup.assert_called_with(0x2000, exact=False)
+ name_lookup.assert_called_with(0x2000, exact=False, indirect=False)
assert op_str == "eax, Hello"
@@ -327,7 +327,7 @@ def test_call_with_name_lookup():
(_, op_str) = p.sanitize(inst)
- name_lookup.assert_called_with(0x1234, exact=True)
+ name_lookup.assert_called_with(0x1234, exact=True, indirect=False)
assert op_str == "Hello"
@@ -348,42 +348,72 @@ def substitute_1234(addr: int, **_) -> str | None:
def test_absolute_indirect():
- """**** Held over from previous test file. This behavior may change soon. ****
- The instruction `call dword ptr [0x1234]` means we call the function
- whose address is at 0x1234. (i.e. absolute indirect addressing mode)
- It is probably more useful to show the name of the function itself if
- we have it, but there are some circumstances where we want to replace
- with the pointer's name (i.e. an import function)."""
+ """Read the given pointer and replace its value with a name or placeholder.
+ Previously we handled reading from the binary inside the sanitize function.
+ This is now delegated to the name lookup function, so we just need to check
+ that it was called with the indirect parameter set."""
+ name_lookup = Mock(spec=NameReplacementProtocol, return_value=None)
+ p = ParseAsm(name_lookup=name_lookup)
+ inst = DisasmLiteInst(0x1000, 5, "call", "dword ptr [0x1234]")
- def name_lookup(addr: int, **_) -> str | None:
- return {
- 0x1234: "Hello",
- 0x4321: "xyz",
- 0x5555: "Test",
- }.get(addr)
-
- def bin_lookup(addr: int, _: int) -> bytes | None:
- return (
- {
- 0x1234: b"\x55\x55\x00\x00",
- 0x4321: b"\x99\x99\x00\x00",
- }
- ).get(addr)
-
- p = ParseAsm(name_lookup=name_lookup, bin_lookup=bin_lookup)
-
- # If we know the indirect address (0x5555)
- # Arrow to indicate this is an indirect replacement
- (_, op_str) = p.sanitize(DisasmLiteInst(0x1000, 5, "call", "dword ptr [0x1234]"))
- assert op_str == "dword ptr [->Test]"
-
- # If we do not know the indirect address (0x9999)
- (_, op_str) = p.sanitize(DisasmLiteInst(0x1000, 5, "call", "dword ptr [0x4321]"))
- assert op_str == "dword ptr [xyz]"
-
- # If we can't read the indirect address
- (_, op_str) = p.sanitize(DisasmLiteInst(0x1000, 5, "call", "dword ptr [0x5555]"))
- assert op_str == "dword ptr [Test]"
+ (_, op_str) = p.sanitize(inst)
+
+ name_lookup.assert_called_with(0x1234, exact=True, indirect=True)
+ assert op_str == "dword ptr []"
+
+
+def test_direct_and_indirect_different_names():
+ """Indirect pointers should not collide with
+ cached lookups on direct pointers and vice versa"""
+
+ # Create a lookup that checks indirect access
+ def lookup(_, indirect: bool = False, **__) -> str:
+ return "Indirect" if indirect else "Direct"
+
+ indirect_inst = DisasmLiteInst(0x1000, 5, "call", "dword ptr [0x1234]")
+ direct_inst = DisasmLiteInst(0x1000, 5, "mov", "eax, dword ptr [0x1234]")
+
+ # Indirect first
+ p = ParseAsm(name_lookup=lookup)
+ (_, op_str) = p.sanitize(indirect_inst)
+ assert op_str == "dword ptr [Indirect]"
+
+ (_, op_str) = p.sanitize(direct_inst)
+ assert op_str == "eax, dword ptr [Direct]"
+
+ # Direct first
+ p = ParseAsm(name_lookup=lookup)
+ (_, op_str) = p.sanitize(direct_inst)
+ assert op_str == "eax, dword ptr [Direct]"
+
+ (_, op_str) = p.sanitize(indirect_inst)
+ assert op_str == "dword ptr [Indirect]"
+
+ # Now verify that we use cached values for each
+ name_lookup = Mock(spec=NameReplacementProtocol, return_value=None)
+ p.name_lookup = name_lookup
+ (_, op_str) = p.sanitize(indirect_inst)
+ assert op_str == "dword ptr [Indirect]"
+
+ (_, op_str) = p.sanitize(direct_inst)
+ assert op_str == "eax, dword ptr [Direct]"
+
+ name_lookup.assert_not_called()
+
+
+def test_direct_and_indirect_placeholders():
+ """If no addresses are known, placeholders for direct and indirect lookup must be distinct"""
+ indirect_inst = DisasmLiteInst(0x1000, 5, "call", "dword ptr [0x1234]")
+ direct_inst = DisasmLiteInst(0x1000, 5, "mov", "eax, dword ptr [0x1234]")
+
+ name_lookup = Mock(spec=NameReplacementProtocol, return_value=None)
+ p = ParseAsm(name_lookup=name_lookup)
+
+ (_, indirect_op_str) = p.sanitize(indirect_inst)
+ (_, direct_op_str) = p.sanitize(direct_inst)
+
+ # Must use two different placeholders
+ assert indirect_op_str != direct_op_str
def test_consistent_numbering():