diff --git a/README.md b/README.md
index bf21d796..3a0a8900 100644
--- a/README.md
+++ b/README.md
@@ -46,6 +46,10 @@ The next steps differ based on what kind of project you have.
All scripts will become available to use in your terminal with the `reccmp-` prefix. Note that these scripts need to be executed in the directory where `reccmp-build.yml` is located.
+* [`aggregate`](/reccmp/tools/aggregate.py): Combines JSON reports into a single file.
+ * Aggregate using highest accuracy score: `reccmp-aggregate --samples ./sample0.json ./sample1.json ./sample2.json --output ./combined.json`
+ * Diff two saved reports: `reccmp-aggregate --diff ./before.json ./after.json`
+ * Diff against the aggregate: `reccmp-aggregate --samples ./sample0.json ./sample1.json ./sample2.json --diff ./before.json`
* [`decomplint`](/reccmp/tools/decomplint.py): Checks the decompilation annotations (see above)
* e.g. `reccmp-decomplint --module LEGO1 LEGO1`
* [`reccmp`](/reccmp/tools/asmcmp.py): Compares an original binary with a recompiled binary, provided a PDB file. For example:
diff --git a/reccmp/assets/template.html b/reccmp/assets/template.html
index 3ac15b45..6518c5a6 100644
--- a/reccmp/assets/template.html
+++ b/reccmp/assets/template.html
@@ -180,7 +180,10 @@
margin-bottom: 0;
}
-
+
diff --git a/reccmp/isledecomp/compare/diff.py b/reccmp/isledecomp/compare/diff.py
index aebe1b6a..19c412bd 100644
--- a/reccmp/isledecomp/compare/diff.py
+++ b/reccmp/isledecomp/compare/diff.py
@@ -1,6 +1,5 @@
from difflib import SequenceMatcher
-from typing import TypedDict
-from typing_extensions import NotRequired
+from typing_extensions import NotRequired, TypedDict
CombinedDiffInput = list[tuple[str, str]]
diff --git a/reccmp/isledecomp/compare/report.py b/reccmp/isledecomp/compare/report.py
new file mode 100644
index 00000000..fe5ece9b
--- /dev/null
+++ b/reccmp/isledecomp/compare/report.py
@@ -0,0 +1,189 @@
+from datetime import datetime
+from dataclasses import dataclass
+from typing import Literal, Iterable, Iterator
+from pydantic import BaseModel, ValidationError
+from pydantic_core import from_json
+from .diff import CombinedDiffOutput
+
+
+class ReccmpReportDeserializeError(Exception):
+ """The given file is not a serialized reccmp report file"""
+
+
+class ReccmpReportSameSourceError(Exception):
+ """Tried to aggregate reports derived from different source files."""
+
+
+@dataclass
+class ReccmpComparedEntity:
+ orig_addr: str
+ name: str
+ accuracy: float
+ recomp_addr: str | None = None
+ is_effective_match: bool = False
+ is_stub: bool = False
+ diff: CombinedDiffOutput | None = None
+
+
+class ReccmpStatusReport:
+ # The filename of the original binary.
+ # This is here to avoid comparing reports derived from different files.
+ # TODO: in the future, we may want to use the hash instead
+ filename: str
+
+ # Creation date of the report file.
+ timestamp: datetime
+
+ # Using orig addr as the key.
+ entities: dict[str, ReccmpComparedEntity]
+
+ def __init__(self, filename: str, timestamp: datetime | None = None) -> None:
+ self.filename = filename
+ if timestamp is not None:
+ self.timestamp = timestamp
+ else:
+ self.timestamp = datetime.now().replace(microsecond=0)
+
+ self.entities = {}
+
+
+def _get_entity_for_addr(
+ samples: Iterable[ReccmpStatusReport], addr: str
+) -> Iterator[ReccmpComparedEntity]:
+    """Helper to return entities from the given reports that have the given address."""
+ for sample in samples:
+ if addr in sample.entities:
+ yield sample.entities[addr]
+
+
+def _accuracy_sort_key(entity: ReccmpComparedEntity) -> float:
+ """Helper to sort entity samples by accuracy score.
+ 100% match is preferred over effective match.
+ Effective match is preferred over any accuracy.
+ Stubs rank lower than any accuracy score."""
+ if entity.is_stub:
+ return -1.0
+
+ if entity.accuracy == 1.0:
+ if not entity.is_effective_match:
+ return 1000.0
+
+ if entity.is_effective_match:
+ return 1.0
+
+ return entity.accuracy
+
+
+def combine_reports(samples: list[ReccmpStatusReport]) -> ReccmpStatusReport:
+ """Combines the sample reports into a single report.
+ The current strategy is to use the entity with the highest
+ accuracy score from any report."""
+ assert len(samples) > 0
+
+ if not all(samples[0].filename == s.filename for s in samples):
+ raise ReccmpReportSameSourceError
+
+ output = ReccmpStatusReport(filename=samples[0].filename)
+
+ # Combine every orig addr used in any of the reports.
+ orig_addr_set = {key for sample in samples for key in sample.entities.keys()}
+
+ all_orig_addrs = sorted(list(orig_addr_set))
+
+ for addr in all_orig_addrs:
+ e_list = list(_get_entity_for_addr(samples, addr))
+ assert len(e_list) > 0
+
+ # Our aggregate accuracy score is the highest from any report.
+ e_list.sort(key=_accuracy_sort_key, reverse=True)
+
+ output.entities[addr] = e_list[0]
+
+ # Recomp addr will most likely vary between samples, so clear it
+ output.entities[addr].recomp_addr = None
+
+ return output
+
+
+#### JSON schemas and conversion functions ####
+
+
+@dataclass
+class JSONEntityVersion1:
+ address: str
+ name: str
+ matching: float
+ # Optional fields
+ recomp: str | None = None
+ stub: bool = False
+ effective: bool = False
+ diff: CombinedDiffOutput | None = None
+
+
+class JSONReportVersion1(BaseModel):
+ file: str
+ format: Literal[1]
+ timestamp: float
+ data: list[JSONEntityVersion1]
+
+
+def _serialize_version_1(
+ report: ReccmpStatusReport, diff_included: bool = False
+) -> JSONReportVersion1:
+    """Convert the report to the version 1 schema. The HTML file needs the diff data, but it is omitted from the JSON report."""
+ entities = [
+ JSONEntityVersion1(
+ address=addr, # prefer dict key over redundant value in entity
+ name=e.name,
+ matching=e.accuracy,
+ recomp=e.recomp_addr,
+ stub=e.is_stub,
+ effective=e.is_effective_match,
+ diff=e.diff if diff_included else None,
+ )
+ for addr, e in report.entities.items()
+ ]
+
+ return JSONReportVersion1(
+ file=report.filename,
+ format=1,
+ timestamp=report.timestamp.timestamp(),
+ data=entities,
+ )
+
+
+def _deserialize_version_1(obj: JSONReportVersion1) -> ReccmpStatusReport:
+ report = ReccmpStatusReport(
+ filename=obj.file, timestamp=datetime.fromtimestamp(obj.timestamp)
+ )
+
+ for e in obj.data:
+ report.entities[e.address] = ReccmpComparedEntity(
+ orig_addr=e.address,
+ name=e.name,
+ accuracy=e.matching,
+ recomp_addr=e.recomp,
+ is_stub=e.stub,
+ is_effective_match=e.effective,
+ )
+
+ return report
+
+
+def deserialize_reccmp_report(json_str: str) -> ReccmpStatusReport:
+ try:
+ obj = JSONReportVersion1.model_validate(from_json(json_str))
+ return _deserialize_version_1(obj)
+ except ValidationError as ex:
+ raise ReccmpReportDeserializeError from ex
+
+
+def serialize_reccmp_report(
+ report: ReccmpStatusReport, diff_included: bool = False
+) -> str:
+ """Create a JSON string for the report so it can be written to a file."""
+ now = datetime.now().replace(microsecond=0)
+ report.timestamp = now
+ obj = _serialize_version_1(report, diff_included=diff_included)
+
+ return obj.model_dump_json(exclude_defaults=True)
diff --git a/reccmp/isledecomp/utils.py b/reccmp/isledecomp/utils.py
index fc7215d9..26b8f37a 100644
--- a/reccmp/isledecomp/utils.py
+++ b/reccmp/isledecomp/utils.py
@@ -1,7 +1,7 @@
from datetime import datetime
import logging
-from pathlib import Path
import colorama
+from reccmp.isledecomp.compare.report import ReccmpStatusReport, ReccmpComparedEntity
def print_combined_diff(udiff, plain: bool = False, show_both: bool = False):
@@ -129,7 +129,9 @@ def diff_json_display(show_both_addrs: bool = False, is_plain: bool = False):
"""Generate a function that will display the diff according to
the reccmp display preferences."""
- def formatter(orig_addr, saved, new) -> str:
+ def formatter(
+ orig_addr, saved: ReccmpComparedEntity, new: ReccmpComparedEntity
+ ) -> str:
old_pct = "new"
new_pct = "gone"
name = ""
@@ -138,29 +140,25 @@ def formatter(orig_addr, saved, new) -> str:
if new is not None:
new_pct = (
"stub"
- if new.get("stub", False)
- else percent_string(
- new["matching"], new.get("effective", False), is_plain
- )
+ if new.is_stub
+ else percent_string(new.accuracy, new.is_effective_match, is_plain)
)
# Prefer the current name of this function if we have it.
# We are using the original address as the key.
# A function being renamed is not of interest here.
- name = new.get("name", "")
- recomp_addr = new.get("recomp", "n/a")
+ name = new.name
+ recomp_addr = new.recomp_addr or "n/a"
if saved is not None:
old_pct = (
"stub"
- if saved.get("stub", False)
- else percent_string(
- saved["matching"], saved.get("effective", False), is_plain
- )
+ if saved.is_stub
+ else percent_string(saved.accuracy, saved.is_effective_match, is_plain)
)
if name == "":
- name = saved.get("name", "")
+ name = saved.name
if show_both_addrs:
addr_string = f"{orig_addr} / {recomp_addr:10}"
@@ -176,29 +174,25 @@ def formatter(orig_addr, saved, new) -> str:
def diff_json(
- saved_data,
- new_data,
- orig_file: Path,
+ saved_data: ReccmpStatusReport,
+ new_data: ReccmpStatusReport,
show_both_addrs: bool = False,
is_plain: bool = False,
):
- """Using a saved copy of the diff summary and the current data, print a
- report showing which functions/symbols have changed match percentage."""
+ """Compare two status report files, determine what items changed, and print the result."""
# Don't try to diff a report generated for a different binary file
- base_file = orig_file.name.lower()
-
- if saved_data.get("file") != base_file:
+ if saved_data.filename != new_data.filename:
logging.getLogger().error(
"Diff report for '%s' does not match current file '%s'",
- saved_data.get("file"),
- base_file,
+ saved_data.filename,
+ new_data.filename,
)
return
- if "timestamp" in saved_data:
+ if saved_data.timestamp is not None:
now = datetime.now().replace(microsecond=0)
- then = datetime.fromtimestamp(saved_data["timestamp"]).replace(microsecond=0)
+ then = saved_data.timestamp.replace(microsecond=0)
print(
" ".join(
@@ -213,8 +207,8 @@ def diff_json(
print()
# Convert to dict, using orig_addr as key
- saved_invert = {obj["address"]: obj for obj in saved_data["data"]}
- new_invert = {obj["address"]: obj for obj in new_data}
+ saved_invert = saved_data.entities
+ new_invert = new_data.entities
all_addrs = set(saved_invert.keys()).union(new_invert.keys())
@@ -227,60 +221,56 @@ def diff_json(
for addr in sorted(all_addrs)
}
+ DiffSubsectionType = dict[
+ str, tuple[ReccmpComparedEntity | None, ReccmpComparedEntity | None]
+ ]
+
# The criteria for diff judgement is in these dict comprehensions:
# Any function not in the saved file
- new_functions = {
+ new_functions: DiffSubsectionType = {
key: (saved, new) for key, (saved, new) in combined.items() if saved is None
}
# Any function now missing from the saved file
# or a non-stub -> stub conversion
- dropped_functions = {
+ dropped_functions: DiffSubsectionType = {
key: (saved, new)
for key, (saved, new) in combined.items()
if new is None
- or (
- new is not None
- and saved is not None
- and new.get("stub", False)
- and not saved.get("stub", False)
- )
+ or (new is not None and saved is not None and new.is_stub and not saved.is_stub)
}
# TODO: move these two into functions if the assessment gets more complex
# Any function with increased match percentage
# or stub -> non-stub conversion
- improved_functions = {
+ improved_functions: DiffSubsectionType = {
key: (saved, new)
for key, (saved, new) in combined.items()
if saved is not None
and new is not None
- and (
- new["matching"] > saved["matching"]
- or (not new.get("stub", False) and saved.get("stub", False))
- )
+ and (new.accuracy > saved.accuracy or (not new.is_stub and saved.is_stub))
}
# Any non-stub function with decreased match percentage
- degraded_functions = {
+ degraded_functions: DiffSubsectionType = {
key: (saved, new)
for key, (saved, new) in combined.items()
if saved is not None
and new is not None
- and new["matching"] < saved["matching"]
- and not saved.get("stub")
- and not new.get("stub")
+ and new.accuracy < saved.accuracy
+ and not saved.is_stub
+ and not new.is_stub
}
# Any function with former or current "effective" match
- entropy_functions = {
+ entropy_functions: DiffSubsectionType = {
key: (saved, new)
for key, (saved, new) in combined.items()
if saved is not None
and new is not None
- and new["matching"] == 1.0
- and saved["matching"] == 1.0
- and new.get("effective", False) != saved.get("effective", False)
+ and new.accuracy == 1.0
+ and saved.accuracy == 1.0
+ and new.is_effective_match != saved.is_effective_match
}
get_diff_str = diff_json_display(show_both_addrs, is_plain)
diff --git a/reccmp/tools/aggregate.py b/reccmp/tools/aggregate.py
new file mode 100644
index 00000000..563d5cad
--- /dev/null
+++ b/reccmp/tools/aggregate.py
@@ -0,0 +1,170 @@
+#!/usr/bin/env python3
+
+import argparse
+import logging
+from typing import Sequence
+from pathlib import Path
+from reccmp.isledecomp.utils import diff_json
+from reccmp.isledecomp.compare.report import (
+ ReccmpStatusReport,
+ combine_reports,
+ ReccmpReportDeserializeError,
+ ReccmpReportSameSourceError,
+ deserialize_reccmp_report,
+ serialize_reccmp_report,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+def write_report_file(output_file: Path, report: ReccmpStatusReport):
+ """Convert the status report to JSON and write to a file."""
+ json_str = serialize_reccmp_report(report)
+
+ with open(output_file, "w+", encoding="utf-8") as f:
+ f.write(json_str)
+
+
+def load_report_file(report_path: Path) -> ReccmpStatusReport:
+ """Deserialize from JSON at the given filename and return the report."""
+
+ with report_path.open("r", encoding="utf-8") as f:
+ return deserialize_reccmp_report(f.read())
+
+
+def deserialize_sample_files(paths: list[Path]) -> list[ReccmpStatusReport]:
+ """Deserialize all sample files and return the list of reports.
+ Does not remove duplicates."""
+ samples = []
+
+ for path in paths:
+ if path.is_file():
+ try:
+ report = load_report_file(path)
+ samples.append(report)
+ except ReccmpReportDeserializeError:
+ logger.warning("Skipping '%s' due to import error", path)
+ elif not path.exists():
+ logger.warning("File not found: '%s'", path)
+
+ return samples
+
+
+class TwoOrMoreArgsAction(argparse.Action):
+ """Support nargs=2+"""
+
+ def __call__(
+ self, parser, namespace, values: Sequence[str] | None, option_string=None
+ ):
+ assert isinstance(values, Sequence)
+ if len(values) < 2:
+ raise argparse.ArgumentError(self, "expected two or more arguments")
+
+ setattr(namespace, self.dest, values)
+
+
+class TwoOrFewerArgsAction(argparse.Action):
+ """Support nargs=(1,2)"""
+
+ def __call__(
+ self, parser, namespace, values: Sequence[str] | None, option_string=None
+ ):
+ assert isinstance(values, Sequence)
+ if len(values) not in (1, 2):
+ raise argparse.ArgumentError(self, "expected one or two arguments")
+
+ setattr(namespace, self.dest, values)
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ allow_abbrev=False,
+ description="Aggregate saved accuracy reports.",
+ )
+ parser.add_argument(
+ "--diff",
+ type=Path,
+ metavar="",
+ nargs="+",
+ action=TwoOrFewerArgsAction,
+ help="Report files to diff.",
+ )
+ parser.add_argument(
+ "--output",
+ "-o",
+ type=Path,
+ metavar="",
+ help="Where to save the aggregate file.",
+ )
+ parser.add_argument(
+ "--samples",
+ type=Path,
+ metavar="",
+ nargs="+",
+ action=TwoOrMoreArgsAction,
+ help="Report files to aggregate.",
+ )
+ parser.add_argument(
+ "--no-color", "-n", action="store_true", help="Do not color the output"
+ )
+
+ args = parser.parse_args()
+
+ if not (args.samples or args.diff):
+ parser.error(
+            "expected arguments for --samples or --diff. (No input files specified)"
+ )
+
+ if not (args.output or args.diff):
+ parser.error(
+ "expected arguments for --output or --diff. (No output action specified)"
+ )
+
+ agg_report: ReccmpStatusReport | None = None
+
+ if args.samples is not None:
+ samples = deserialize_sample_files(args.samples)
+
+ if len(samples) < 2:
+ logger.error("Not enough samples to aggregate!")
+ return 1
+
+ try:
+ agg_report = combine_reports(samples)
+ except ReccmpReportSameSourceError:
+ filename_list = sorted({s.filename for s in samples})
+ logger.error(
+ "Aggregate samples are not from the same source file!\nFilenames used: %s",
+ filename_list,
+ )
+ return 1
+
+ if args.output is not None:
+ write_report_file(args.output, agg_report)
+
+ # If --diff has at least one file and we aggregated some samples this run, diff the first file and the aggregate.
+ # If --diff has two files and we did not aggregate this run, diff the files in the list.
+ if args.diff is not None:
+ saved_data = load_report_file(args.diff[0])
+
+ if agg_report is None:
+ if len(args.diff) > 1:
+ agg_report = load_report_file(args.diff[1])
+ else:
+ logger.error("Not enough files to diff!")
+ return 1
+ elif len(args.diff) == 2:
+ logger.warning(
+ "Ignoring second --diff argument '%s'.\nDiff of '%s' and aggregate report follows.",
+ args.diff[1],
+ args.diff[0],
+ )
+
+ diff_json(saved_data, agg_report, show_both_addrs=False, is_plain=args.no_color)
+
+ return 0
+
+
+if __name__ == "__main__":
+ raise SystemExit(main())
diff --git a/reccmp/tools/asmcmp.py b/reccmp/tools/asmcmp.py
index 0e5eb16c..59c11781 100755
--- a/reccmp/tools/asmcmp.py
+++ b/reccmp/tools/asmcmp.py
@@ -2,11 +2,8 @@
import argparse
import base64
-import json
import logging
import os
-from datetime import datetime
-from pathlib import Path
from pystache import Renderer # type: ignore[import-untyped]
import colorama
@@ -18,6 +15,12 @@
)
from reccmp.isledecomp.compare import Compare as IsleCompare
+from reccmp.isledecomp.compare.report import (
+ ReccmpStatusReport,
+ ReccmpComparedEntity,
+ deserialize_reccmp_report,
+ serialize_reccmp_report,
+)
from reccmp.isledecomp.formats.detect import detect_image
from reccmp.isledecomp.formats.pe import PEImage
from reccmp.isledecomp.types import EntityType
@@ -34,39 +37,21 @@
colorama.just_fix_windows_console()
-def gen_json(json_file: str, orig_file: Path, data):
- """Create a JSON file that contains the comparison summary"""
-
- # If the structure of the JSON file ever changes, we would run into a problem
- # reading an older format file in the CI action. Mark which version we are
- # generating so we could potentially address this down the road.
- json_format_version = 1
-
- # Remove the diff field
- reduced_data = [
- {key: value for (key, value) in obj.items() if key != "diff"} for obj in data
- ]
+def gen_json(json_file: str, json_str: str):
+ """Convert the status report to JSON and write to a file."""
with open(json_file, "w", encoding="utf-8") as f:
- json.dump(
- {
- "file": orig_file.name.lower(),
- "format": json_format_version,
- "timestamp": datetime.now().timestamp(),
- "data": reduced_data,
- },
- f,
- )
+ f.write(json_str)
-def gen_html(html_file, data):
+def gen_html(html_file: str, report: str):
js_path = get_asset_file("../assets/reccmp.js")
with open(js_path, "r", encoding="utf-8") as f:
reccmp_js = f.read()
output_data = Renderer().render_path(
get_asset_file("../assets/template.html"),
- {"data": data, "reccmp_js": reccmp_js},
+ {"report": report, "reccmp_js": reccmp_js},
)
with open(html_file, "w", encoding="utf-8") as htmlfile:
@@ -267,7 +252,8 @@ def main():
function_count = 0
total_accuracy = 0.0
total_effective_accuracy = 0.0
- htmlinsert = []
+
+ report = ReccmpStatusReport(filename=target.original_path.name.lower())
for match in isle_compare.compare_all():
if not args.silent and args.diff is None:
@@ -287,44 +273,38 @@ def main():
total_effective_accuracy += match.effective_ratio
# If html, record the diffs to an HTML file
- html_obj = {
- "address": f"0x{match.orig_addr:x}",
- "recomp": f"0x{match.recomp_addr:x}",
- "name": match.name,
- "matching": match.effective_ratio,
- }
-
- if match.is_effective_match:
- html_obj["effective"] = True
-
- if match.udiff is not None:
- html_obj["diff"] = match.udiff
-
- if match.is_stub:
- html_obj["stub"] = True
-
- htmlinsert.append(html_obj)
+ orig_addr = f"0x{match.orig_addr:x}"
+ recomp_addr = f"0x{match.recomp_addr:x}"
+
+ report.entities[orig_addr] = ReccmpComparedEntity(
+ orig_addr=orig_addr,
+ name=match.name,
+ accuracy=match.effective_ratio,
+ recomp_addr=recomp_addr,
+ is_effective_match=match.is_effective_match,
+ is_stub=match.is_stub,
+ diff=match.udiff,
+ )
# Compare with saved diff report.
if args.diff is not None:
with open(args.diff, "r", encoding="utf-8") as f:
- saved_data = json.load(f)
-
- diff_json(
- saved_data,
- htmlinsert,
- target.original_path,
- show_both_addrs=args.print_rec_addr,
- is_plain=args.no_color,
- )
+ saved_data = deserialize_reccmp_report(f.read())
+
+ diff_json(
+ saved_data,
+ report,
+ show_both_addrs=args.print_rec_addr,
+ is_plain=args.no_color,
+ )
## Generate files and show summary.
if args.json is not None:
- gen_json(args.json, target.original_path, htmlinsert)
+ gen_json(args.json, serialize_reccmp_report(report))
if args.html is not None:
- gen_html(args.html, json.dumps(htmlinsert))
+ gen_html(args.html, serialize_reccmp_report(report, diff_included=True))
implemented_funcs = function_count
diff --git a/setup.cfg b/setup.cfg
index 6458a6f5..629c9862 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -21,6 +21,7 @@ exclude =
[options.entry_points]
console_scripts =
+ reccmp-aggregate = reccmp.tools.aggregate:main
reccmp-datacmp = reccmp.tools.datacmp:main
reccmp-decomplint = reccmp.tools.decomplint:main
reccmp-project = reccmp.tools.project:main
diff --git a/tests/test_report.py b/tests/test_report.py
new file mode 100644
index 00000000..11510d3c
--- /dev/null
+++ b/tests/test_report.py
@@ -0,0 +1,118 @@
+"""Reccmp reports: files that contain the comparison result from asmcmp."""
+import pytest
+from reccmp.isledecomp.compare.report import (
+ ReccmpStatusReport,
+ ReccmpComparedEntity,
+ combine_reports,
+ ReccmpReportSameSourceError,
+)
+
+
+def create_report(
+ entities: list[tuple[str, float]] | None = None
+) -> ReccmpStatusReport:
+ """Helper to quickly set up a report to be customized further for each test."""
+ report = ReccmpStatusReport(filename="test.exe")
+ if entities is not None:
+ for addr, accuracy in entities:
+ report.entities[addr] = ReccmpComparedEntity(addr, "test", accuracy)
+
+ return report
+
+
+def test_aggregate_identity():
+ """Combine a list of one report. Should get the same report back,
+ except for expected differences like the timestamp."""
+ report = create_report([("100", 1.0), ("200", 0.5)])
+ combined = combine_reports([report])
+
+ for (a_key, a_entity), (b_key, b_entity) in zip(
+ report.entities.items(), combined.entities.items()
+ ):
+ assert a_key == b_key
+ assert a_entity.orig_addr == b_entity.orig_addr
+ assert a_entity.accuracy == b_entity.accuracy
+
+
+def test_aggregate_simple():
+ """Should choose the best score from the sample reports."""
+ x = create_report([("100", 0.8), ("200", 0.2)])
+ y = create_report([("100", 0.2), ("200", 0.8)])
+
+ combined = combine_reports([x, y])
+ assert combined.entities["100"].accuracy == 0.8
+ assert combined.entities["200"].accuracy == 0.8
+
+
+def test_aggregate_union_all_addrs():
+ """Should combine all addresses from any report."""
+ x = create_report([("100", 0.8)])
+ y = create_report([("200", 0.8)])
+
+ combined = combine_reports([x, y])
+ assert "100" in combined.entities
+ assert "200" in combined.entities
+
+
+def test_aggregate_stubs():
+ """Stub functions (i.e. do not compare asm) are considered to have 0 percent accuracy."""
+ x = create_report([("100", 0.9)])
+ y = create_report([("100", 0.5)])
+
+ # In a real report, accuracy would be zero for a stub.
+ x.entities["100"].is_stub = True
+ y.entities["100"].is_stub = False
+
+ combined = combine_reports([x, y])
+ assert combined.entities["100"].is_stub is False
+
+ # Choose the lower non-stub value
+ assert combined.entities["100"].accuracy == 0.5
+
+
+def test_aggregate_all_stubs():
+ """If all samples are stubs, preserve that setting."""
+ x = create_report([("100", 1.0)])
+
+ x.entities["100"].is_stub = True
+
+ combined = combine_reports([x, x])
+ assert combined.entities["100"].is_stub is True
+
+
+def test_aggregate_100_over_effective():
+ """Prefer 100% match over effective."""
+ x = create_report([("100", 0.9)])
+ y = create_report([("100", 1.0)])
+ x.entities["100"].is_effective_match = True
+
+ combined = combine_reports([x, y])
+ assert combined.entities["100"].is_effective_match is False
+
+
+def test_aggregate_effective_over_any():
+ """Prefer effective match over any accuracy."""
+ x = create_report([("100", 0.5)])
+ y = create_report([("100", 0.6)])
+ x.entities["100"].is_effective_match = True
+ # Y has higher accuracy score, but we could not confirm an effective match.
+
+ combined = combine_reports([x, y])
+ assert combined.entities["100"].is_effective_match is True
+
+ # Should retain original accuracy for effective match.
+ assert combined.entities["100"].accuracy == 0.5
+
+
+def test_aggregate_different_files():
+ """Should raise an exception if we try to aggregate reports
+ where the orig filename does not match."""
+ x = create_report()
+ y = create_report()
+
+ # Make sure they are different, regardless of what is set by create_report().
+ x.filename = "test.exe"
+ y.filename = "hello.exe"
+
+ with pytest.raises(ReccmpReportSameSourceError):
+ combine_reports([x, y])