Skip to content
This repository has been archived by the owner on Jan 9, 2025. It is now read-only.

Commit

Permalink
fix test; make storage mock work on iterable values
Browse files Browse the repository at this point in the history
  • Loading branch information
enitrat committed Nov 14, 2024
1 parent 9d3ff74 commit 4101768
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 43 deletions.
2 changes: 1 addition & 1 deletion cairo_zero/kakarot/library.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ namespace Kakarot {
return ();
}

Kakarot_base_fee.write('current_block', (base_fee, starting_block));
Kakarot_base_fee.write('current_block', (next_base_fee, starting_block));
return ();
}

Expand Down
45 changes: 5 additions & 40 deletions cairo_zero/tests/src/kakarot/test_kakarot.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,7 @@ def test_set_base_fee_should_set_next_block_fee(self, cairo_run):
get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"next_block", "big")
),
value=1,
)
@SyscallHandler.patch(
get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"next_block", "big")
)
+ 1,
value=0x101,
value=[1, 0x101],
)
@patch.object(SyscallHandler, "block_number", 0x102)
def test_set_base_fee_should_overwrite_current_block_fee_if_next_block_is_applicable(
Expand Down Expand Up @@ -213,27 +206,13 @@ def test_set_base_fee_should_overwrite_current_block_fee_if_next_block_is_applic
get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"current_block", "big")
),
value=1,
)
@SyscallHandler.patch(
get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"current_block", "big")
)
+ 1,
value=0x100,
value=[1, 0x100],
)
@SyscallHandler.patch(
get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"next_block", "big")
),
value=2,
)
@SyscallHandler.patch(
get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"next_block", "big")
)
+ 1,
value=0x101,
value=[2, 0x101],
)
@patch.object(SyscallHandler, "block_number", 0x100)
def test_get_base_fee_should_return_current_block_fee_if_next_block_is_not_applicable(
Expand All @@ -247,27 +226,13 @@ def test_get_base_fee_should_return_current_block_fee_if_next_block_is_not_appli
get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"current_block", "big")
),
value=1,
)
@SyscallHandler.patch(
get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"current_block", "big")
)
+ 1,
value=0x100,
value=[1, 0x100],
)
@SyscallHandler.patch(
get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"next_block", "big")
),
value=2,
)
@SyscallHandler.patch(
get_storage_var_address(
"Kakarot_base_fee", int.from_bytes(b"next_block", "big")
)
+ 1,
value=0x101,
value=[2, 0x101],
)
@patch.object(SyscallHandler, "block_number", 0x101)
def test_get_base_fee_should_return_next_block_fee_if_applicable_and_update_current_block(
Expand Down
10 changes: 8 additions & 2 deletions tests/utils/syscall_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from hashlib import sha256
from typing import Optional, Union
from typing import Iterable, Optional, Union
from unittest import mock

import ecdsa
Expand Down Expand Up @@ -609,7 +609,13 @@ def patch(
selector_if_storage = get_storage_var_address(target, *args)
else:
selector_if_storage = target
cls.patches[selector_if_storage] = value

if isinstance(value, Iterable):
for i, v in enumerate(value):
cls.patches[selector_if_storage + i] = v
else:
cls.patches[selector_if_storage] = value

except AssertionError:
pass

Expand Down

0 comments on commit 4101768

Please sign in to comment.