Skip to content

Commit

Permalink
New matching and event reporting module
Browse files Browse the repository at this point in the history
  • Loading branch information
disinvite committed Feb 1, 2025
1 parent 0fbd24c commit 3b0df42
Show file tree
Hide file tree
Showing 6 changed files with 1,092 additions and 41 deletions.
124 changes: 83 additions & 41 deletions reccmp/isledecomp/compare/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,19 @@
from reccmp.isledecomp.parser import DecompCodebase
from reccmp.isledecomp.dir import walk_source_dir
from reccmp.isledecomp.types import EntityType
from reccmp.isledecomp.compare.event import create_logging_wrapper
from reccmp.isledecomp.compare.asm import ParseAsm
from reccmp.isledecomp.compare.asm.replacement import create_name_lookup
from reccmp.isledecomp.compare.asm.fixes import assert_fixup, find_effective_match
from reccmp.isledecomp.analysis import find_float_consts
from .match_msvc import (
match_symbols,
match_functions,
match_vtables,
match_static_variables,
match_variables,
match_strings,
)
from .db import EntityDb, ReccmpEntity, ReccmpMatch
from .diff import combined_diff, CombinedDiffOutput
from .lines import LinesDb
Expand Down Expand Up @@ -142,7 +151,7 @@ def _load_cvdump(self):

# Build the list of entries to insert to the DB.
# In the rare case we have duplicate symbols for an address, ignore them.
dataset = {}
seen_addrs = set()

batch = self._db.batch()

Expand All @@ -162,9 +171,11 @@ def _load_cvdump(self):
addr = self.recomp_bin.get_abs_addr(sym.section, sym.offset)
sym.addr = addr

if addr in dataset:
if addr in seen_addrs:
continue

seen_addrs.add(addr)

# If this symbol is the final one in its section, we were not able to
# estimate its size because we didn't have the total size of that section.
# We can get this estimate now and assume that the final symbol occupies
Expand Down Expand Up @@ -262,51 +273,82 @@ def orig_bin_checker(addr: int) -> bool:
# If we have two functions that share the same name, and one is
# a lineref, we can match the nameref correctly because the lineref
# was already removed from consideration.
for fun in codebase.iter_line_functions():
assert fun.filename is not None
recomp_addr = self._lines_db.search_line(
fun.filename, fun.line_number, fun.end_line
)
if recomp_addr is not None:
self._db.set_function_pair(fun.offset, recomp_addr)
if fun.should_skip():
self._db.mark_stub(fun.offset)

for fun in codebase.iter_name_functions():
self._db.match_function(fun.offset, fun.name)
if fun.should_skip():
self._db.mark_stub(fun.offset)

for var in codebase.iter_variables():
if var.is_static and var.parent_function is not None:
self._db.match_static_variable(
var.offset, var.name, var.parent_function
with self._db.batch() as batch:
for fun in codebase.iter_line_functions():
assert fun.filename is not None
recomp_addr = self._lines_db.search_line(
fun.filename, fun.line_number, fun.end_line
)
if recomp_addr is not None:
batch.match(fun.offset, recomp_addr)
batch.set_recomp(
recomp_addr, type=EntityType.FUNCTION, stub=fun.should_skip()
)

with self._db.batch() as batch:
for fun in codebase.iter_name_functions():
batch.set_orig(
fun.offset, type=EntityType.FUNCTION, stub=fun.should_skip()
)
else:
self._db.match_variable(var.offset, var.name)

for tbl in codebase.iter_vtables():
self._db.match_vtable(tbl.offset, tbl.name, tbl.base_class)
if fun.name.startswith("?"):
batch.set_orig(fun.offset, symbol=fun.name)
else:
batch.set_orig(fun.offset, name=fun.name)

for var in codebase.iter_variables():
batch.set_orig(var.offset, name=var.name, type=EntityType.DATA)
if var.is_static and var.parent_function is not None:
batch.set_orig(
var.offset, static_var=True, parent_function=var.parent_function
)

for string in codebase.iter_strings():
# Not that we don't trust you, but we're checking the string
# annotation to make sure it is accurate.
try:
# TODO: would presumably fail for wchar_t strings
orig = self.orig_bin.read_string(string.offset).decode("latin1")
string_correct = string.name == orig
except UnicodeDecodeError:
string_correct = False

if not string_correct:
logger.error(
"Data at 0x%x does not match string %s",
for tbl in codebase.iter_vtables():
batch.set_orig(
tbl.offset,
name=tbl.name,
base_class=tbl.base_class,
type=EntityType.VTABLE,
)

# For now, just redirect match alerts to the logger.
report = create_logging_wrapper(logger)

# Now match
match_symbols(self._db, report)
match_functions(self._db, report)
match_vtables(self._db, report)
match_static_variables(self._db, report)
match_variables(self._db, report)

with self._db.batch() as batch:
for string in codebase.iter_strings():
# Not that we don't trust you, but we're checking the string
# annotation to make sure it is accurate.
try:
# TODO: would presumably fail for wchar_t strings
orig = self.orig_bin.read_string(string.offset).decode("latin1")
string_correct = string.name == orig
except UnicodeDecodeError:
string_correct = False

if not string_correct:
logger.error(
"Data at 0x%x does not match string %s",
string.offset,
repr(string.name),
)
continue

batch.set_orig(
string.offset,
repr(string.name),
name=string.name,
type=EntityType.STRING,
size=len(string.name),
)
continue
# self._db.match_string(string.offset, string.name)

self._db.match_string(string.offset, string.name)
match_strings(self._db, report)

def _match_array_elements(self):
"""
Expand Down
14 changes: 14 additions & 0 deletions reccmp/isledecomp/compare/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@
matched int as (orig_addr is not null and recomp_addr is not null),
kvstore text default '{}'
);
CREATE VIEW orig_unmatched (orig_addr, kvstore) AS
SELECT orig_addr, kvstore FROM entities
WHERE orig_addr is not null and recomp_addr is null
ORDER by orig_addr;
CREATE VIEW recomp_unmatched (recomp_addr, kvstore) AS
SELECT recomp_addr, kvstore FROM entities
WHERE recomp_addr is not null and orig_addr is null
ORDER by recomp_addr;
"""


Expand Down Expand Up @@ -238,6 +248,10 @@ def sql(self) -> sqlite3.Connection:
def batch(self) -> EntityBatch:
return EntityBatch(self)

def count(self) -> int:
(count,) = self._sql.execute("SELECT count(1) from entities").fetchone()
return count

def set_orig_symbol(self, addr: int, **kwargs):
self.bulk_orig_insert(iter([(addr, kwargs)]))

Expand Down
48 changes: 48 additions & 0 deletions reccmp/isledecomp/compare/event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import enum
import logging
from typing import Protocol


class LoggingSeverity(enum.IntEnum):
"""To improve type checking. There isn't an enum to import from the logging module."""

DEBUG = logging.DEBUG
INFO = logging.INFO
WARNING = logging.WARNING
ERROR = logging.ERROR


class ReccmpEvent(enum.Enum):
NO_MATCH = enum.auto()

# Symbol (or designated unique attribute) was found not to be unique
NON_UNIQUE_SYMBOL = enum.auto()

# Match by name/type not unique
AMBIGUOUS_MATCH = enum.auto()


def event_to_severity(event: ReccmpEvent) -> LoggingSeverity:
return {
ReccmpEvent.NO_MATCH: LoggingSeverity.ERROR,
ReccmpEvent.NON_UNIQUE_SYMBOL: LoggingSeverity.WARNING,
ReccmpEvent.AMBIGUOUS_MATCH: LoggingSeverity.WARNING,
}.get(event, LoggingSeverity.INFO)


class ReccmpReportProtocol(Protocol):
def __call__(self, event: ReccmpEvent, orig_addr: int, /, msg: str = ""):
...


def reccmp_report_nop(*_, **__):
"""Reporting no-op function"""


def create_logging_wrapper(logger: logging.Logger) -> ReccmpReportProtocol:
"""Return a function to use when you just want to redirect events to the given logger"""

def wrap(event: ReccmpEvent, _: int, msg: str = ""):
logger.log(event_to_severity(event), msg)

return wrap
Loading

0 comments on commit 3b0df42

Please sign in to comment.