From 3b0df42013cd66e745046689bd677d23372fc4f0 Mon Sep 17 00:00:00 2001 From: disinvite Date: Sat, 1 Feb 2025 12:52:08 -0500 Subject: [PATCH] New matching and event reporting module --- reccmp/isledecomp/compare/core.py | 124 +++-- reccmp/isledecomp/compare/db.py | 14 + reccmp/isledecomp/compare/event.py | 48 ++ reccmp/isledecomp/compare/match_msvc.py | 311 ++++++++++++ tests/test_compare_db.py | 16 + tests/test_match_msvc.py | 620 ++++++++++++++++++++++++ 6 files changed, 1092 insertions(+), 41 deletions(-) create mode 100644 reccmp/isledecomp/compare/event.py create mode 100644 reccmp/isledecomp/compare/match_msvc.py create mode 100644 tests/test_match_msvc.py diff --git a/reccmp/isledecomp/compare/core.py b/reccmp/isledecomp/compare/core.py index d51db0e6..4a72d697 100644 --- a/reccmp/isledecomp/compare/core.py +++ b/reccmp/isledecomp/compare/core.py @@ -16,10 +16,19 @@ from reccmp.isledecomp.parser import DecompCodebase from reccmp.isledecomp.dir import walk_source_dir from reccmp.isledecomp.types import EntityType +from reccmp.isledecomp.compare.event import create_logging_wrapper from reccmp.isledecomp.compare.asm import ParseAsm from reccmp.isledecomp.compare.asm.replacement import create_name_lookup from reccmp.isledecomp.compare.asm.fixes import assert_fixup, find_effective_match from reccmp.isledecomp.analysis import find_float_consts +from .match_msvc import ( + match_symbols, + match_functions, + match_vtables, + match_static_variables, + match_variables, + match_strings, +) from .db import EntityDb, ReccmpEntity, ReccmpMatch from .diff import combined_diff, CombinedDiffOutput from .lines import LinesDb @@ -142,7 +151,7 @@ def _load_cvdump(self): # Build the list of entries to insert to the DB. # In the rare case we have duplicate symbols for an address, ignore them. - dataset = {} + seen_addrs = set() batch = self._db.batch() @@ -162,9 +171,11 @@ def _load_cvdump(self): addr = self.recomp_bin.get_abs_addr(sym.section, sym.offset) sym.addr = addr - if addr in dataset: + if addr in seen_addrs: continue + seen_addrs.add(addr) + # If this symbol is the final one in its section, we were not able to # estimate its size because we didn't have the total size of that section. # We can get this estimate now and assume that the final symbol occupies @@ -262,51 +273,82 @@ def orig_bin_checker(addr: int) -> bool: # If we have two functions that share the same name, and one is # a lineref, we can match the nameref correctly because the lineref # was already removed from consideration. - for fun in codebase.iter_line_functions(): - assert fun.filename is not None - recomp_addr = self._lines_db.search_line( - fun.filename, fun.line_number, fun.end_line - ) - if recomp_addr is not None: - self._db.set_function_pair(fun.offset, recomp_addr) - if fun.should_skip(): - self._db.mark_stub(fun.offset) - - for fun in codebase.iter_name_functions(): - self._db.match_function(fun.offset, fun.name) - if fun.should_skip(): - self._db.mark_stub(fun.offset) - - for var in codebase.iter_variables(): - if var.is_static and var.parent_function is not None: - self._db.match_static_variable( - var.offset, var.name, var.parent_function + with self._db.batch() as batch: + for fun in codebase.iter_line_functions(): + assert fun.filename is not None + recomp_addr = self._lines_db.search_line( + fun.filename, fun.line_number, fun.end_line + ) + if recomp_addr is not None: + batch.match(fun.offset, recomp_addr) + batch.set_recomp( + recomp_addr, type=EntityType.FUNCTION, stub=fun.should_skip() + ) + + with self._db.batch() as batch: + for fun in codebase.iter_name_functions(): + batch.set_orig( + fun.offset, type=EntityType.FUNCTION, stub=fun.should_skip() ) - else: - self._db.match_variable(var.offset, var.name) - for tbl in codebase.iter_vtables(): - self._db.match_vtable(tbl.offset, tbl.name, tbl.base_class) + if fun.name.startswith("?"): + batch.set_orig(fun.offset, symbol=fun.name) + else: + batch.set_orig(fun.offset, name=fun.name) + + for var in codebase.iter_variables(): + batch.set_orig(var.offset, name=var.name, type=EntityType.DATA) + if var.is_static and var.parent_function is not None: + batch.set_orig( + var.offset, static_var=True, parent_function=var.parent_function + ) - for string in codebase.iter_strings(): - # Not that we don't trust you, but we're checking the string - # annotation to make sure it is accurate. - try: - # TODO: would presumably fail for wchar_t strings - orig = self.orig_bin.read_string(string.offset).decode("latin1") - string_correct = string.name == orig - except UnicodeDecodeError: - string_correct = False - - if not string_correct: - logger.error( - "Data at 0x%x does not match string %s", + for tbl in codebase.iter_vtables(): + batch.set_orig( + tbl.offset, + name=tbl.name, + base_class=tbl.base_class, + type=EntityType.VTABLE, + ) + + # For now, just redirect match alerts to the logger. + report = create_logging_wrapper(logger) + + # Now match + match_symbols(self._db, report) + match_functions(self._db, report) + match_vtables(self._db, report) + match_static_variables(self._db, report) + match_variables(self._db, report) + + with self._db.batch() as batch: + for string in codebase.iter_strings(): + # Not that we don't trust you, but we're checking the string + # annotation to make sure it is accurate. + try: + # TODO: would presumably fail for wchar_t strings + orig = self.orig_bin.read_string(string.offset).decode("latin1") + string_correct = string.name == orig + except UnicodeDecodeError: + string_correct = False + + if not string_correct: + logger.error( + "Data at 0x%x does not match string %s", + string.offset, + repr(string.name), + ) + continue + + batch.set_orig( string.offset, - repr(string.name), + name=string.name, + type=EntityType.STRING, + size=len(string.name), ) - continue + # self._db.match_string(string.offset, string.name) - self._db.match_string(string.offset, string.name) + match_strings(self._db, report) def _match_array_elements(self): """ diff --git a/reccmp/isledecomp/compare/db.py b/reccmp/isledecomp/compare/db.py index bd268586..dee412a2 100644 --- a/reccmp/isledecomp/compare/db.py +++ b/reccmp/isledecomp/compare/db.py @@ -16,6 +16,16 @@ matched int as (orig_addr is not null and recomp_addr is not null), kvstore text default '{}' ); + + CREATE VIEW orig_unmatched (orig_addr, kvstore) AS + SELECT orig_addr, kvstore FROM entities + WHERE orig_addr is not null and recomp_addr is null + ORDER by orig_addr; + + CREATE VIEW recomp_unmatched (recomp_addr, kvstore) AS + SELECT recomp_addr, kvstore FROM entities + WHERE recomp_addr is not null and orig_addr is null + ORDER by recomp_addr; """ @@ -238,6 +248,10 @@ def sql(self) -> sqlite3.Connection: def batch(self) -> EntityBatch: return EntityBatch(self) + def count(self) -> int: + (count,) = self._sql.execute("SELECT count(1) from entities").fetchone() + return count + def set_orig_symbol(self, addr: int, **kwargs): self.bulk_orig_insert(iter([(addr, kwargs)])) diff --git a/reccmp/isledecomp/compare/event.py b/reccmp/isledecomp/compare/event.py new file mode 100644 index 00000000..b402dfa5 --- /dev/null +++ b/reccmp/isledecomp/compare/event.py @@ -0,0 +1,48 @@ +import enum +import logging +from typing import Protocol + + +class LoggingSeverity(enum.IntEnum): + """To improve type checking. There isn't an enum to import from the logging module.""" + + DEBUG = logging.DEBUG + INFO = logging.INFO + WARNING = logging.WARNING + ERROR = logging.ERROR + + +class ReccmpEvent(enum.Enum): + NO_MATCH = enum.auto() + + # Symbol (or designated unique attribute) was found not to be unique + NON_UNIQUE_SYMBOL = enum.auto() + + # Match by name/type not unique + AMBIGUOUS_MATCH = enum.auto() + + +def event_to_severity(event: ReccmpEvent) -> LoggingSeverity: + return { + ReccmpEvent.NO_MATCH: LoggingSeverity.ERROR, + ReccmpEvent.NON_UNIQUE_SYMBOL: LoggingSeverity.WARNING, + ReccmpEvent.AMBIGUOUS_MATCH: LoggingSeverity.WARNING, + }.get(event, LoggingSeverity.INFO) + + +class ReccmpReportProtocol(Protocol): + def __call__(self, event: ReccmpEvent, orig_addr: int, /, msg: str = ""): + ... + + +def reccmp_report_nop(*_, **__): + """Reporting no-op function""" + + +def create_logging_wrapper(logger: logging.Logger) -> ReccmpReportProtocol: + """Return a function to use when you just want to redirect events to the given logger""" + + def wrap(event: ReccmpEvent, _: int, msg: str = ""): + logger.log(event_to_severity(event), msg) + + return wrap diff --git a/reccmp/isledecomp/compare/match_msvc.py b/reccmp/isledecomp/compare/match_msvc.py new file mode 100644 index 00000000..1f1a75c6 --- /dev/null +++ b/reccmp/isledecomp/compare/match_msvc.py @@ -0,0 +1,311 @@ +from reccmp.isledecomp.types import EntityType +from reccmp.isledecomp.compare.db import EntityDb +from reccmp.isledecomp.compare.event import ( + ReccmpEvent, + ReccmpReportProtocol, + reccmp_report_nop, +) + + +class EntityIndex: + """One-to-many index. Maps string value to address.""" + + _dict: dict[str, list[int]] + + def __init__(self) -> None: + self._dict = {} + + def __contains__(self, key: str) -> bool: + return key in self._dict + + def add(self, key: str, value: int): + self._dict.setdefault(key, []).append(value) + + def count(self, key: str) -> int: + return len(self._dict.get(key, [])) + + def pop(self, key: str) -> int: + value = self._dict[key].pop(0) + if len(self._dict[key]) == 0: + del self._dict[key] + + return value + + +def match_symbols(db: EntityDb, report: ReccmpReportProtocol = reccmp_report_nop): + """Match all entities using the symbol attribute. We expect this value to be unique.""" + + symbol_index = EntityIndex() + + for recomp_addr, symbol in db.sql.execute( + """SELECT recomp_addr, json_extract(kvstore, '$.symbol') as symbol + from recomp_unmatched where symbol is not null""" + ): + # Max symbol length in MSVC is 255 chars. See also: Warning C4786. + symbol_index.add(symbol[:255], recomp_addr) + + with db.batch() as batch: + for orig_addr, symbol in db.sql.execute( + """SELECT orig_addr, json_extract(kvstore, '$.symbol') as symbol + from orig_unmatched where symbol is not null""" + ): + # Same truncate to 255 chars as above. + symbol = symbol[:255] + if symbol in symbol_index: + recomp_addr = symbol_index.pop(symbol) + + # If match was not unique: + if symbol in symbol_index: + report( + ReccmpEvent.NON_UNIQUE_SYMBOL, + orig_addr, + msg=f"Matched 0x{orig_addr:x} using non-unique symbol '{symbol}'", + ) + + batch.match(orig_addr, recomp_addr) + + else: + report( + ReccmpEvent.NO_MATCH, + orig_addr, + msg=f"Failed to match function at 0x{orig_addr:x} with symbol '{symbol}'", + ) + + +def match_functions(db: EntityDb, report: ReccmpReportProtocol = reccmp_report_nop): + # addr->symbol map. Used later in error message for non-unique match. + recomp_symbols: dict[int, str] = {} + + name_index = EntityIndex() + + # TODO: We allow a match if entity_type is null. + # This can be removed if we can more confidently declare a symbol is a function + # when adding from the PDB. + for recomp_addr, name, symbol in db.sql.execute( + """SELECT recomp_addr, json_extract(kvstore, '$.name') as name, json_extract(kvstore, '$.symbol') + from recomp_unmatched where name is not null + and (json_extract(kvstore, '$.type') = ? or json_extract(kvstore, '$.type') is null)""", + (EntityType.FUNCTION,), + ): + # Truncate the name to 255 characters. It will not be possible to match a name + # longer than that because MSVC truncates to this length. + # See also: warning C4786. + name = name[:255] + name_index.add(name, recomp_addr) + + # Get the symbol for the error message later. + if symbol is not None: + recomp_symbols[recomp_addr] = symbol + + # Report if the name used in the match is not unique. + # If the name list contained multiple addreses at the start, + # we should report even for the last address in the list. + non_unique_names = set() + + with db.batch() as batch: + for orig_addr, name in db.sql.execute( + """SELECT orig_addr, json_extract(kvstore, '$.name') as name + from orig_unmatched where name is not null + and json_extract(kvstore, '$.type') = ?""", + (EntityType.FUNCTION,), + ): + # Repeat the truncate for our match search + name = name[:255] + + if name in name_index: + recomp_addr = name_index.pop(name) + # If match was not unique + if name in name_index: + non_unique_names.add(name) + + # If this name was ever matched non-uniquely + if name in non_unique_names: + symbol = recomp_symbols.get(recomp_addr, "None") + report( + ReccmpEvent.AMBIGUOUS_MATCH, + orig_addr, + msg=f"Ambiguous match 0x{orig_addr:x} on name '{name}' to '{symbol}'", + ) + + batch.match(orig_addr, recomp_addr) + else: + report( + ReccmpEvent.NO_MATCH, + orig_addr, + msg=f"Failed to match function at 0x{orig_addr:x} with name '{name}'", + ) + + +def match_vtables(db: EntityDb, report: ReccmpReportProtocol = reccmp_report_nop): + """The requirements for matching are: + 1. Recomp entity has name attribute in this format: "Pizza::`vftable'" + This is derived from the symbol: "??_7Pizza@@6B@" + 2. Orig entity has name attribute with class name only. (e.g. "Pizza") + 3. If multiple inheritance is used, the orig entity has the base_class attribute set. + + For multiple inheritance, the vtable name references the base class like this: + + - X::`vftable'{for `Y'} + + The vtable for the derived class will take one of these forms: + + - X::`vftable'{for `X'} + - X::`vftable' + + We assume only one of the above will appear for a given class.""" + + vtable_name_index = EntityIndex() + + for recomp_addr, name in db.sql.execute( + """SELECT recomp_addr, json_extract(kvstore, '$.name') as name + from recomp_unmatched where name is not null + and json_extract(kvstore, '$.type') = ?""", + (EntityType.VTABLE,), + ): + vtable_name_index.add(name, recomp_addr) + + with db.batch() as batch: + for orig_addr, class_name, base_class in db.sql.execute( + """SELECT orig_addr, json_extract(kvstore, '$.name') as name, json_extract(kvstore, '$.base_class') + from orig_unmatched where name is not null + and json_extract(kvstore, '$.type') = ?""", + (EntityType.VTABLE,), + ): + # Most classes will not use multiple inheritance, so try the regular vtable + # first, unless a base class is provided. + if base_class is None or base_class == class_name: + bare_vftable = f"{class_name}::`vftable'" + + if bare_vftable in vtable_name_index: + recomp_addr = vtable_name_index.pop(bare_vftable) + batch.match(orig_addr, recomp_addr) + continue + + # If we didn't find a match above, search for the multiple inheritance vtable. + for_name = base_class if base_class is not None else class_name + for_vftable = f"{class_name}::`vftable'{{for `{for_name}'}}" + + if for_vftable in vtable_name_index: + recomp_addr = vtable_name_index.pop(for_vftable) + batch.match(orig_addr, recomp_addr) + continue + + report( + ReccmpEvent.NO_MATCH, + orig_addr, + msg=f"Failed to match vtable at 0x{orig_addr:x} for class '{class_name}' (base={base_class or 'None'})", + ) + + +def match_static_variables( + db: EntityDb, report: ReccmpReportProtocol = reccmp_report_nop +): + """To match a static variable, we need the following: + 1. Orig entity function with symbol + 2. Orig entity variable with: + - name = name of variable + - static_var = True + - parent_function = orig address of function + 3. Recomp entity for the static variable with symbol""" + with db.batch() as batch: + for ( + variable_addr, + variable_name, + function_name, + function_symbol, + ) in db.sql.execute( + """SELECT var.orig_addr, json_extract(var.kvstore, '$.name') as name, + json_extract(func.kvstore, '$.name'), json_extract(func.kvstore, '$.symbol') + from orig_unmatched var left join entities func on json_extract(var.kvstore, '$.parent_function') = func.orig_addr + where json_extract(var.kvstore, '$.static_var') = 1 + and name is not null""" + ): + # If we could not find the parent function, or if it has no symbol: + if function_symbol is None: + report( + ReccmpEvent.NO_MATCH, + variable_addr, + msg=f"No function for static variable '{variable_name}'", + ) + continue + + # If the static variable has a symbol, it will contain the parent function's symbol. + # e.g. Static variable "g_startupDelay" from function "IsleApp::Tick" + # The function symbol is: "?Tick@IsleApp@@QAEXH@Z" + # The variable symbol is: "?g_startupDelay@?1??Tick@IsleApp@@QAEXH@Z@4HA" + for (recomp_addr,) in db.sql.execute( + """SELECT recomp_addr FROM recomp_unmatched + where (json_extract(kvstore, '$.type') = ? OR json_extract(kvstore, '$.type') IS NULL) + and json_extract(kvstore, '$.symbol') LIKE '%' || ? || '%' || ? || '%'""", + (EntityType.DATA, variable_name, function_symbol), + ): + batch.match(variable_addr, recomp_addr) + break + else: + report( + ReccmpEvent.NO_MATCH, + variable_addr, + msg=f"Failed to match static variable {variable_name} from function {function_name} annotated with 0x{variable_addr:x}", + ) + + +def match_variables(db: EntityDb, report: ReccmpReportProtocol = reccmp_report_nop): + var_name_index = EntityIndex() + + # TODO: We allow a match if entity_type is null. + # This can be removed if we can more confidently declare a symbol is a variable + # when adding from the PDB. + for name, recomp_addr in db.sql.execute( + """SELECT json_extract(kvstore, '$.name') as name, recomp_addr + from recomp_unmatched where name is not null + and (json_extract(kvstore, '$.type') = ? or json_extract(kvstore, '$.type') is null)""", + (EntityType.DATA,), + ): + var_name_index.add(name, recomp_addr) + + with db.batch() as batch: + for orig_addr, name in db.sql.execute( + """SELECT orig_addr, json_extract(kvstore, '$.name') as name + from orig_unmatched where name is not null + and json_extract(kvstore, '$.type') = ? + and coalesce(json_extract(kvstore, '$.static_var'), 0) != 1""", + (EntityType.DATA,), + ): + if name in var_name_index: + recomp_addr = var_name_index.pop(name) + batch.match(orig_addr, recomp_addr) + else: + report( + ReccmpEvent.NO_MATCH, + orig_addr, + msg=f"Failed to match variable {name} at 0x{orig_addr:x}", + ) + + +def match_strings(db: EntityDb, report: ReccmpReportProtocol = reccmp_report_nop): + string_index = EntityIndex() + + for recomp_addr, text in db.sql.execute( + """SELECT recomp_addr, json_extract(kvstore, '$.name') as name + from recomp_unmatched where name is not null + and json_extract(kvstore,'$.type') = ?""", + (EntityType.STRING,), + ): + string_index.add(text, recomp_addr) + + with db.batch() as batch: + for orig_addr, text in db.sql.execute( + """SELECT orig_addr, json_extract(kvstore, '$.name') as name + from orig_unmatched where name is not null + and json_extract(kvstore,'$.type') = ?""", + (EntityType.STRING,), + ): + if text in string_index: + recomp_addr = string_index.pop(text) + batch.match(orig_addr, recomp_addr) + else: + report( + ReccmpEvent.NO_MATCH, + orig_addr, + msg=f"Failed to match string {repr(text)} at 0x{orig_addr:x}", + ) diff --git a/tests/test_compare_db.py b/tests/test_compare_db.py index c1268d0b..7c03ca9c 100644 --- a/tests/test_compare_db.py +++ b/tests/test_compare_db.py @@ -88,6 +88,22 @@ def test_dynamic_metadata(db): assert obj.get("option") is True +def test_db_count(db): + """Wrapper around SELECT COUNT""" + assert db.count() == 0 + + with db.batch() as batch: + batch.set_orig(100) + batch.set_recomp(100) + + assert db.count() == 2 + + with db.batch() as batch: + batch.match(100, 100) + + assert db.count() == 1 + + #### Testing new batch API #### diff --git a/tests/test_match_msvc.py b/tests/test_match_msvc.py new file mode 100644 index 00000000..b9a42357 --- /dev/null +++ b/tests/test_match_msvc.py @@ -0,0 +1,620 @@ +"""Tests MSVC-specific match strategies""" + +from unittest.mock import Mock, ANY +import pytest +from reccmp.isledecomp.types import EntityType +from reccmp.isledecomp.compare.db import EntityDb +from reccmp.isledecomp.compare.match_msvc import ( + match_functions, + match_static_variables, + match_strings, + match_symbols, + match_variables, + match_vtables, +) +from reccmp.isledecomp.compare.event import ReccmpEvent, ReccmpReportProtocol + + +@pytest.fixture(name="db") +def fixture_db() -> EntityDb: + return EntityDb() + + +@pytest.fixture(name="report") +def fixture_report_mock() -> ReccmpReportProtocol: + return Mock(spec=ReccmpReportProtocol) + + +#### match_symbols #### + + +def test_match_symbols(db): + """Should combine entities with the same symbol""" + with db.batch() as batch: + batch.set_orig(123, symbol="hello") + batch.set_recomp(555, symbol="hello") + + match_symbols(db) + + assert db.get_by_orig(123).recomp_addr == 555 + assert db.get_by_recomp(555).orig_addr == 123 + + # Should combine entities + assert db.count() == 1 + + +def test_match_symbols_no_match(db): + """Should not affect entities with no symbol or no matching symbol.""" + with db.batch() as batch: + batch.set_orig(123) + batch.set_recomp(555, symbol="hello") + + match_symbols(db) + + assert db.get_by_orig(123).recomp_addr is None + assert db.get_by_recomp(555).orig_addr is None + assert db.count() == 2 + + +def test_match_symbols_no_match_report(db, report): + """Should report if we cannot match a symbol on the orig side.""" + with db.batch() as batch: + batch.set_orig(123, symbol="test") + + match_symbols(db, report) + + report.assert_called_with(ReccmpEvent.NO_MATCH, 123, msg=ANY) + + +def test_match_symbols_stable_match_order(db): + """Match in ascending address order on both sides for duplicate symbols.""" + with db.batch() as batch: + # Descending order + batch.set_orig(200, symbol="test") + batch.set_orig(100, symbol="test") + batch.set_recomp(555, symbol="test") + batch.set_recomp(333, symbol="test") + + match_symbols(db) + + assert db.get_by_orig(100).recomp_addr == 333 + assert db.get_by_orig(200).recomp_addr == 555 + + +def test_match_symbols_recomp_not_unique(db, report): + """Alert when symbol match is non-unique on the recomp side.""" + with db.batch() as batch: + batch.set_orig(123, symbol="hello") + batch.set_recomp(555, symbol="hello") + batch.set_recomp(222, symbol="hello") + + match_symbols(db, report) + + # Should match first occurrence. + assert db.get_by_orig(123).recomp_addr == 222 + + # Report non-unique match for orig_addr 123 + report.assert_called_with(ReccmpEvent.NON_UNIQUE_SYMBOL, 123, msg=ANY) + + +def test_match_symbols_over_255(db): + """MSVC truncates symbols to 255 characters in the PDB. + Match entities where the symbols are equal up to the 255th character.""" + long_name = "x" * 255 + with db.batch() as batch: + batch.set_orig(123, symbol=long_name + "y") + batch.set_recomp(555, symbol=long_name + "z") + + match_symbols(db) + + assert db.get_by_orig(123).recomp_addr == 555 + assert db.get_by_recomp(555).orig_addr == 123 + + +#### match_functions #### + + +def test_match_functions(db): + """Simple match by name and type""" + with db.batch() as batch: + batch.set_orig(123, name="hello", type=EntityType.FUNCTION) + batch.set_recomp(555, name="hello", type=EntityType.FUNCTION) + + match_functions(db) + + assert db.get_by_orig(123).recomp_addr == 555 + assert db.get_by_recomp(555).orig_addr == 123 + + # Should combine entities + assert db.count() == 1 + + +def test_match_functions_no_match(db): + """Skip entities with no match""" + with db.batch() as batch: + batch.set_orig(123, name="hello", type=EntityType.FUNCTION) + batch.set_recomp(555, name="test", type=EntityType.FUNCTION) + + match_functions(db) + + assert db.get_by_orig(123).recomp_addr is None + assert db.get_by_recomp(555).orig_addr is None + assert db.count() == 2 + + +def test_match_functions_no_match_report(db, report): + """Should report if we cannot match a name on the orig side.""" + with db.batch() as batch: + batch.set_orig(123, name="test", type=EntityType.FUNCTION) + + match_functions(db, report) + + report.assert_called_with(ReccmpEvent.NO_MATCH, 123, msg=ANY) + + +def test_match_function_stable_order(db): + """If name is not unique, match according to orig and recomp address order. + i.e. insertion order does not matter""" + with db.batch() as batch: + # Descending order + batch.set_orig(101, name="hello", type=EntityType.FUNCTION) + batch.set_orig(100, name="hello", type=EntityType.FUNCTION) + batch.set_recomp(501, name="hello", type=EntityType.FUNCTION) + batch.set_recomp(500, name="hello", type=EntityType.FUNCTION) + + match_functions(db) + + assert db.get_by_orig(100).recomp_addr == 500 + assert db.get_by_orig(101).recomp_addr == 501 + + +def test_match_functions_type_null(db): + """Will allow a function match if the recomp side has type=null""" + with db.batch() as batch: + batch.set_orig(123, name="hello", type=EntityType.FUNCTION) + batch.set_recomp(555, name="hello") + + match_functions(db) + + assert db.get_by_orig(123).recomp_addr == 555 + assert db.get_by_recomp(555).orig_addr == 123 + assert db.count() == 1 + + +def test_match_functions_ambiguous(db, report): + """Report if a name match had multiple options. + If there is only one option left, but previous matches were ambiguous, report it anyway. + """ + with db.batch() as batch: + batch.set_orig(100, name="hello", type=EntityType.FUNCTION) + batch.set_orig(101, name="hello", type=EntityType.FUNCTION) + batch.set_recomp(500, name="hello", type=EntityType.FUNCTION) + batch.set_recomp(501, name="hello", type=EntityType.FUNCTION) + + match_functions(db, report) + + # Report for both ambiguous matches + report.assert_any_call(ReccmpEvent.AMBIGUOUS_MATCH, 100, msg=ANY) + report.assert_any_call(ReccmpEvent.AMBIGUOUS_MATCH, 101, msg=ANY) + + # Should match regardless + assert db.count() == 2 + + +def test_match_functions_ignore_already_matched(db, report): + """If the name is non-unique but there is only one option available to match + (i.e. if previous entities were matched by line number) + do not report an ambiguous match.""" + with db.batch() as batch: + batch.set_orig(101, name="hello", type=EntityType.FUNCTION) + batch.set_recomp(500, name="hello", type=EntityType.FUNCTION) + batch.set_recomp(501, name="hello", type=EntityType.FUNCTION) + # Match these addrs before calling match_functions() + batch.match(100, 500) + + # 1 matched, 2 unmatched + assert db.count() == 3 + + match_functions(db, report) + + # Do not report + report.assert_not_called() + + # Should combine the two unmatched entities + assert db.get_by_recomp(501).orig_addr == 101 + assert db.count() == 2 + + +def test_match_function_names_over_255(db): + """MSVC truncates names to 255 characters in the PDB. + Match function entities where the names are are equal up to the 255th character.""" + long_name = "x" * 255 + with db.batch() as batch: + batch.set_orig(123, name=long_name + "y", type=EntityType.FUNCTION) + batch.set_recomp(555, name=long_name + "z", type=EntityType.FUNCTION) + + match_functions(db) + + assert db.get_by_orig(123).recomp_addr == 555 + assert db.get_by_recomp(555).orig_addr == 123 + + +#### match_vtables #### + + +def test_match_vtables(db): + """Matching with the specific requirements on attributes for orig and recomp entities""" + with db.batch() as batch: + # Orig has class name and type + batch.set_orig(100, name="Pizza", type=EntityType.VTABLE) + # Recomp has full vtable name and type + batch.set_recomp(200, name="Pizza::`vftable'", type=EntityType.VTABLE) + + match_vtables(db) + + assert db.get_by_orig(100).recomp_addr == 200 + assert db.count() == 1 + + +def test_match_vtables_no_match_recomp_name(db): + """Recomp entity name must be in a specific format""" + with db.batch() as batch: + batch.set_orig(100, name="Pizza", type=EntityType.VTABLE) + batch.set_recomp(200, name="Pizza", type=EntityType.VTABLE) + + match_vtables(db) + + assert db.get_by_orig(100).recomp_addr is None + + +def test_match_vtables_no_match_recomp_type(db): + """Recomp entity must have type=EntityType.VTABLE""" + with db.batch() as batch: + batch.set_orig(100, name="Pizza", type=EntityType.VTABLE) + batch.set_recomp(200, name="Pizza::`vftable'") + + match_vtables(db) + + assert db.get_by_orig(100).recomp_addr is None + + +def test_match_vtables_no_match_orig_type(db): + """Orig entity must have type=EntityType.VTABLE""" + with db.batch() as batch: + batch.set_orig(100, name="Pizza") + batch.set_recomp(200, name="Pizza::`vftable'", type=EntityType.VTABLE) + + match_vtables(db) + + assert db.get_by_orig(100).recomp_addr is None + + +def test_match_vtables_no_match_report(db, report): + """Report a failure to match a vtable from the orig side.""" + with db.batch() as batch: + batch.set_orig(100, name="Pizza", type=EntityType.VTABLE) + + match_vtables(db, report) + + report.assert_called_with(ReccmpEvent.NO_MATCH, 100, msg=ANY) + + +def test_match_vtables_base_class(db): + """Match a vtable with a base class""" + with db.batch() as batch: + batch.set_orig(100, name="Pizza", type=EntityType.VTABLE, base_class="Lunch") + batch.set_recomp( + 200, name="Pizza::`vftable'{for `Lunch'}", type=EntityType.VTABLE + ) + + match_vtables(db) + + assert db.get_by_orig(100).recomp_addr == 200 + + +def test_match_vtables_base_class_orig_none(db): + """Do not match a multiple-inheritance vtable if the base class is not specified on the orig entity.""" + with db.batch() as batch: + batch.set_orig(100, name="Pizza", type=EntityType.VTABLE) + batch.set_recomp( + 200, name="Pizza::`vftable'{for `Lunch'}", type=EntityType.VTABLE + ) + + match_vtables(db) + + assert db.get_by_orig(100).recomp_addr is None + + +def test_match_vtables_base_class_same_as_derived(db): + """Matching a vtable with the same base class and derived class. + The base_class attribute is set on the orig entity.""" + with db.batch() as batch: + batch.set_orig(100, name="Pizza", type=EntityType.VTABLE, base_class="Pizza") + batch.set_recomp( + 200, name="Pizza::`vftable'{for `Pizza'}", type=EntityType.VTABLE + ) + + match_vtables(db) + + assert db.get_by_orig(100).recomp_addr == 200 + + +def test_match_vtables_base_class_same_as_derived_orig_none(db): + """If orig does not have the base_class attribute set, we can still match if + the recomp vtable has the same base and derived class.""" + with db.batch() as batch: + batch.set_orig(100, name="Pizza", type=EntityType.VTABLE) + batch.set_recomp( + 200, name="Pizza::`vftable'{for `Pizza'}", type=EntityType.VTABLE + ) + + match_vtables(db) + + assert db.get_by_orig(100).recomp_addr == 200 + + +def test_match_vtables_incompatible_base_class(db): + """If the orig entity has a base_class, do not match with a recomp vtable that does not use multiple-inheritance.""" + with db.batch() as batch: + batch.set_orig(100, name="Pizza", type=EntityType.VTABLE, base_class="Lunch") + batch.set_recomp(200, name="Pizza::`vftable'", type=EntityType.VTABLE) + + match_vtables(db) + + assert db.get_by_orig(100).recomp_addr is None + + +#### match_static_variables #### + + +def test_match_static_var(db): + """Match a static variable with all requirements satisfied.""" + with db.batch() as batch: + # Orig entity function with symbol + batch.set_orig(200, symbol="?Tick@IsleApp@@QAEXH@Z", type=EntityType.FUNCTION) + # Static variable with symbol + batch.set_recomp(500, symbol="?g_startupDelay@?1??Tick@IsleApp@@QAEXH@Z@4HA") + # Orig entity with variable name and link to orig function addr + batch.set_orig( + 600, + name="g_startupDelay", + parent_function=200, + static_var=True, + type=EntityType.DATA, + ) + + match_static_variables(db) + + assert db.get_by_orig(600).recomp_addr == 500 + + +def test_match_static_var_no_parent_function(db): + """Cannot match static variable without a reference to its parent function""" + with db.batch() as batch: + batch.set_orig(200, symbol="?Tick@IsleApp@@QAEXH@Z", type=EntityType.FUNCTION) + batch.set_recomp(500, symbol="?g_startupDelay@?1??Tick@IsleApp@@QAEXH@Z@4HA") + # No parent function + batch.set_orig( + 600, + name="g_startupDelay", + static_var=True, + type=EntityType.DATA, + ) + + match_static_variables(db) + + assert db.get_by_orig(600).recomp_addr is None + + +def test_match_static_var_static_false(db): + """Cannot match static variable unless the static_var attribute is True""" + with db.batch() as batch: + batch.set_orig(200, symbol="?Tick@IsleApp@@QAEXH@Z", type=EntityType.FUNCTION) + batch.set_recomp(500, symbol="?g_startupDelay@?1??Tick@IsleApp@@QAEXH@Z@4HA") + # static_var is not set + batch.set_orig( + 600, + name="g_startupDelay", + parent_function=200, + type=EntityType.DATA, + ) + + match_static_variables(db) + + assert db.get_by_orig(600).recomp_addr is None + + +def test_match_static_var_no_symbol_function(db): + """Cannot match static variable if the parent function has no symbol""" + with db.batch() as batch: + # No symbol on parent function + batch.set_orig(200, type=EntityType.FUNCTION) + batch.set_recomp(500, symbol="?g_startupDelay@?1??Tick@IsleApp@@QAEXH@Z@4HA") + batch.set_orig( + 600, + name="g_startupDelay", + parent_function=200, + static_var=True, + type=EntityType.DATA, + ) + + match_static_variables(db) + + assert db.get_by_orig(600).recomp_addr is None + + +def test_match_static_var_no_symbol_variable(db): + """Cannot match static variable without a symbol.""" + with db.batch() as batch: + batch.set_orig(200, symbol="?Tick@IsleApp@@QAEXH@Z", type=EntityType.FUNCTION) + # No symbol on variable + batch.set_recomp(500, name="g_startupDelay") + batch.set_orig( + 600, + name="g_startupDelay", + parent_function=200, + static_var=True, + type=EntityType.DATA, + ) + + match_static_variables(db) + + assert db.get_by_orig(600).recomp_addr is None + + +def test_match_static_var_no_match_report(db, report): + """Report match failure for any orig entities with static_var=True""" + with db.batch() as batch: + batch.set_orig(600, name="test", static_var=True, type=EntityType.DATA) + + match_static_variables(db, report) + + report.assert_called_with(ReccmpEvent.NO_MATCH, 600, msg=ANY) + + +#### match_variables #### + + +def test_match_variables(db): + """Simple match by name and type""" + with db.batch() as batch: + batch.set_orig(123, name="hello", type=EntityType.DATA) + batch.set_recomp(555, name="hello", type=EntityType.DATA) + + match_variables(db) + + assert db.get_by_orig(123).recomp_addr == 555 + assert db.get_by_recomp(555).orig_addr == 123 + + # Should combine entities + assert db.count() == 1 + + +def test_match_variables_no_match(db): + """Skip entities with no match""" + with db.batch() as batch: + batch.set_orig(123, name="hello", type=EntityType.DATA) + batch.set_recomp(555, name="test", type=EntityType.DATA) + + match_variables(db) + + assert db.get_by_orig(123).recomp_addr is None + assert db.get_by_recomp(555).orig_addr is None + assert db.count() == 2 + + +def test_match_variables_no_match_report(db, report): + """Should report if we cannot match a name on the orig side.""" + with db.batch() as batch: + batch.set_orig(123, name="test", type=EntityType.DATA) + + match_variables(db, report) + + report.assert_called_with(ReccmpEvent.NO_MATCH, 123, msg=ANY) + + +def test_match_variables_type_null(db): + """Will allow a variable match if the recomp side has type=null""" + with db.batch() as batch: + batch.set_orig(123, name="hello", type=EntityType.DATA) + batch.set_recomp(555, name="hello") + + match_variables(db) + + assert db.get_by_orig(123).recomp_addr == 555 + assert db.get_by_recomp(555).orig_addr == 123 + assert db.count() == 1 + + +#### match_strings #### + + +def test_match_strings(db): + with db.batch() as batch: + batch.set_orig(123, name="hello", type=EntityType.STRING) + batch.set_recomp(555, name="hello", type=EntityType.STRING) + + match_strings(db) + + assert db.get_by_orig(123).recomp_addr == 555 + assert db.get_by_recomp(555).orig_addr == 123 + + # Should combine entities + assert db.count() == 1 + + +def test_match_strings_no_match(db): + """Skip strings with no match""" + with db.batch() as batch: + batch.set_orig(123, name="hello", type=EntityType.STRING) + batch.set_recomp(555, name="test", type=EntityType.STRING) + + match_strings(db) + + assert db.get_by_orig(123).recomp_addr is None + assert db.get_by_recomp(555).orig_addr is None + assert db.count() == 2 + + +def test_match_strings_type_required(db): + """Do not match if one side is missing the type. + This is a concern because we use the name attribute for the string's text.""" + with db.batch() as batch: + batch.set_orig(100, name="hello", type=EntityType.STRING) + batch.set_orig(200, name="test") + batch.set_recomp(500, name="hello") + batch.set_recomp(600, name="test", type=EntityType.STRING) + + match_strings(db) + + assert db.get_by_orig(100).recomp_addr is None + assert db.get_by_orig(200).recomp_addr is None + + +def test_match_strings_no_match_report(db, report): + """Should report if we cannot match a string on the orig side.""" + with db.batch() as batch: + batch.set_orig(123, name="test", type=EntityType.STRING) + + match_strings(db, report) + + report.assert_called_with(ReccmpEvent.NO_MATCH, 123, msg=ANY) + + +def test_match_strings_duplicates(db, report): + """Binaries that do not de-dupe string should match duplicates by address order.""" + with db.batch() as batch: + batch.set_orig(100, name="hello", type=EntityType.STRING) + batch.set_orig(200, name="hello", type=EntityType.STRING) + batch.set_orig(300, name="hello", type=EntityType.STRING) + batch.set_recomp(500, name="hello", type=EntityType.STRING) + batch.set_recomp(600, name="hello", type=EntityType.STRING) + batch.set_recomp(700, name="hello", type=EntityType.STRING) + + match_strings(db, report) + + assert db.get_by_orig(100).recomp_addr == 500 + assert db.get_by_orig(200).recomp_addr == 600 + assert db.get_by_orig(300).recomp_addr == 700 + assert db.count() == 3 + + # Do not alert for duplicate string matches. + report.assert_not_called() + + +def test_match_strings_stable_order(db): + """Duplicates are matched by address order, not db insertion order.""" + with db.batch() as batch: + # Descending order + batch.set_orig(300, name="hello", type=EntityType.STRING) + batch.set_orig(200, name="hello", type=EntityType.STRING) + batch.set_orig(100, name="hello", type=EntityType.STRING) + batch.set_recomp(700, name="hello", type=EntityType.STRING) + batch.set_recomp(600, name="hello", type=EntityType.STRING) + batch.set_recomp(500, name="hello", type=EntityType.STRING) + + match_strings(db) + + assert db.get_by_orig(100).recomp_addr == 500 + assert db.get_by_orig(200).recomp_addr == 600 + assert db.get_by_orig(300).recomp_addr == 700