Skip to content

Commit

Permalink
Handle exceptions in batch context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
disinvite committed Jan 27, 2025
1 parent 0063313 commit d8ee4b2
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 9 deletions.
23 changes: 14 additions & 9 deletions reccmp/isledecomp/compare/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,15 @@ def __init__(self, backref: "EntityDb") -> None:
self._orig_to_recomp = {}
self._recomp_to_orig = {}

def reset(self):
"""Clear all pending changes"""
self._orig_insert.clear()
self._recomp_insert.clear()
self._orig.clear()
self._recomp.clear()
self._orig_to_recomp.clear()
self._recomp_to_orig.clear()

def insert_orig(self, addr: int, **kwargs):
self._orig_insert.setdefault(addr, {}).update(kwargs)

Expand Down Expand Up @@ -201,20 +210,16 @@ def commit(self):
if self._orig_to_recomp:
self.base.bulk_match(self._orig_to_recomp.items())

self._orig_insert.clear()
self._recomp_insert.clear()

self._orig.clear()
self._recomp.clear()

self._orig_to_recomp.clear()
self._recomp_to_orig.clear()
self.reset()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
self.commit()
if exc_type is not None:
self.reset()
else:
self.commit()


class EntityDb:
Expand Down
30 changes: 30 additions & 0 deletions tests/test_compare_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,33 @@ def test_batch_match_repeat_recomp_addr(db):

assert db.get_by_recomp(200).orig_addr == 101
assert db.get_by_orig(100) is None


def test_batch_exception_uncaught(db):
"""When using batch context manager, an uncaught exception should clear the staged changes."""
try:
with db.batch() as batch:
batch.set_orig(100, name="Test")
batch.set_recomp(200, test=123)
batch.match(100, 200)
_ = 1 / 0
except ZeroDivisionError:
pass

assert db.get_by_orig(100) is None
assert db.get_by_orig(200) is None


def test_batch_exception_caught(db):
"""If the exception is caught, allow the batch to go through."""
with db.batch() as batch:
batch.set_orig(100, name="Test")
batch.set_recomp(200, test=123)
batch.match(100, 200)
try:
_ = 1 / 0
except ZeroDivisionError:
pass

assert db.get_by_orig(100) is not None
assert db.get_by_recomp(200) is not None

0 comments on commit d8ee4b2

Please sign in to comment.