diff --git a/core/schains/cleaner.py b/core/schains/cleaner.py index 9d1bd1f4..881a2e63 100644 --- a/core/schains/cleaner.py +++ b/core/schains/cleaner.py @@ -30,15 +30,9 @@ from core.node import get_current_nodes, get_skale_node_version from core.schains.checks import SChainChecks -from core.schains.config.file_manager import ConfigFileManager from core.schains.config.directory import schain_config_dir from core.schains.dkg.utils import get_secret_key_share_filepath -from core.schains.firewall.utils import get_default_rule_controller -from core.schains.config.helper import ( - get_base_port_from_config, - get_node_ips_from_config, - get_own_ip_from_config, -) +from core.schains.firewall.utils import cleanup_firewall_for_schain, get_default_rule_controller from core.schains.process import ProcessReport, terminate_process from core.schains.runner import get_container_name, is_exited from core.schains.external_config import ExternalConfig @@ -152,8 +146,10 @@ def get_schains_on_node(dutils=None): schains_with_container = get_schains_with_containers(dutils) schains_active_records = get_schains_names() schains_firewall_configs = list( - map(lambda name: name.removeprefix('skale-'), - get_schains_firewall_configs()) + map( + lambda name: name.removeprefix('skale-'), + get_schains_firewall_configs() + ) ) logger.info( 'dirs %s, containers: %s, records: %s, firewall configs: %s', @@ -281,15 +277,8 @@ def cleanup_schain( if check_status['volume']: remove_schain_volume(schain_name, dutils=dutils) if any(checks.firewall_rules.data): - conf = ConfigFileManager(schain_name).skaled_config - base_port = get_base_port_from_config(conf) - own_ip = get_own_ip_from_config(conf) - node_ips = get_node_ips_from_config(conf) - ranges = [] - if estate is not None: - ranges = estate.ranges - rc.configure(base_port=base_port, own_ip=own_ip, node_ips=node_ips, sync_ip_ranges=ranges) - rc.cleanup() + logger.info('Cleaning firewall for %s', schain_name) + cleanup_firewall_for_schain(schain_name) if estate is not None and estate.ima_linked: if check_status.get('ima_container', False) or is_exited( diff --git a/core/schains/firewall/firewall_manager.py b/core/schains/firewall/firewall_manager.py index 6b33ab3b..f9d1bad2 100644 --- a/core/schains/firewall/firewall_manager.py +++ b/core/schains/firewall/firewall_manager.py @@ -87,15 +87,14 @@ def remove_rules(self, rules: Iterable[SChainRule]) -> None: for rule in rules: self.host_controller.remove_rule(rule) - def flush(self) -> None: - self.remove_rules(self.rules) - self.host_controller.cleanup() - class IptablesSChainFirewallManager(SChainFirewallManager): def create_host_controller(self) -> IptablesController: return IptablesController() + def cleanup(self) -> None: + self.remove_rules(self.rules) + class NFTSchainFirewallManager(SChainFirewallManager): def create_host_controller(self) -> NFTablesController: @@ -111,3 +110,7 @@ def rules_saved(self) -> bool: def base_config_applied(self) -> bool: return self.host_controller.has_chain(self.host_controller.chain) and \ self.host_controller.has_drop_rule(self.first_port, self.last_port) + + def cleanup(self) -> None: + self.host_controller.cleanup() + self.host_controller.remove_saved_rules() diff --git a/core/schains/firewall/nftables.py b/core/schains/firewall/nftables.py index a2977c4b..e2a56e39 100644 --- a/core/schains/firewall/nftables.py +++ b/core/schains/firewall/nftables.py @@ -44,7 +44,7 @@ class NFTablesController(IHostFirewallController): plock = multiprocessing.Lock() FAMILY = 'inet' - def __init__(self, table: str = TABLE, chain: str = CHAIN) -> None: + def __init__(self, chain: str, table: str = TABLE) -> None: self.table = table self.chain = f'skale-{chain}' self._nftables = importlib.import_module('nftables') @@ -384,9 +384,8 @@ def get_saved_rules(self) -> str: return nft_chain_file.read() def remove_saved_rules(self) -> None: - if os.isfile(self.nft_chain_path): + if os.path.isfile(self.nft_chain_path): os.remove(self.nft_chain_path) def cleanup(self) -> None: - self.remove_saved_rules() self.delete_chain() diff --git a/core/schains/firewall/rule_controller.py b/core/schains/firewall/rule_controller.py index a109d30b..68620526 100644 --- a/core/schains/firewall/rule_controller.py +++ b/core/schains/firewall/rule_controller.py @@ -202,9 +202,6 @@ def sync(self) -> None: logger.debug('Syncing firewall rules with %s', erules) self.firewall_manager.update_rules(erules) - def cleanup(self) -> None: - self.firewall_manager.flush() - class IptablesSChainRuleController(SChainRuleController): @configured_only @@ -223,6 +220,10 @@ def is_persistent(self) -> bool: def is_inited(self) -> bool: return True + @configured_only + def cleanup(self) -> None: + self.firewall_manager.cleanup() + class NFTSchainRuleController(SChainRuleController): @configured_only @@ -240,3 +241,6 @@ def is_persistent(self) -> bool: @configured_only def is_inited(self) -> bool: return self.firewall_manager.base_config_applied() + + def cleanup(self) -> None: + self.firewall_manager.cleanup() diff --git a/core/schains/firewall/types.py b/core/schains/firewall/types.py index ecb076c6..c30bfc11 100644 --- a/core/schains/firewall/types.py +++ b/core/schains/firewall/types.py @@ -108,7 +108,7 @@ def update_rules(self, rules: Iterable[SChainRule]) -> None: # pragma: no cover pass @abstractmethod - def flush(self) -> None: # pragma: no cover # noqa + def cleanup(self) -> None: # pragma: no cover # noqa pass diff --git a/core/schains/firewall/utils.py b/core/schains/firewall/utils.py index 1f94694f..0788c6df 100644 --- a/core/schains/firewall/utils.py +++ b/core/schains/firewall/utils.py @@ -25,6 +25,7 @@ from skale import Skale from .types import IpRange +from .nftables import NFTablesController from .rule_controller import IptablesSChainRuleController, NFTSchainRuleController @@ -101,3 +102,9 @@ def save_sync_ranges(sync_agent_ranges: List[IpRange], path: str) -> None: def ranges_from_plain_tuples(plain_ranges: List[Tuple]) -> List[IpRange]: return list(sorted(map(lambda r: IpRange(*r), plain_ranges))) + + +def cleanup_firewall_for_schain(schain_name: str) -> None: + nft = NFTablesController(chain=schain_name) + nft.cleanup() + nft.remove_saved_rules() diff --git a/tests/firewall/firewall_manager_test.py b/tests/firewall/firewall_manager_test.py index 719ad1bf..04203acc 100644 --- a/tests/firewall/firewall_manager_test.py +++ b/tests/firewall/firewall_manager_test.py @@ -53,7 +53,7 @@ def test_firewall_manager_update_existed(): assert fm.host_controller.remove_rule.call_count == 0 -def test_firewall_manager_flush(): +def test_firewall_manager_cleanup(): fm = SChainTestFirewallManager('test', 10000, 10064) rules = [ SChainRule(10000, '2.2.2.2'), @@ -63,6 +63,6 @@ def test_firewall_manager_flush(): fm.add_rules(rules) fm.host_controller.add_rule(SChainRule(10072, '2.2.2.2')) - fm.flush() + fm.cleanup() assert list(fm.rules) == [] assert fm.host_controller.has_rule(SChainRule(10072, '2.2.2.2')) diff --git a/tests/firewall/nftables_test.py b/tests/firewall/nftables_test.py index cee70eed..e77af1fa 100644 --- a/tests/firewall/nftables_test.py +++ b/tests/firewall/nftables_test.py @@ -7,6 +7,7 @@ from core.schains.firewall.nftables import NFTablesController, NFT_CHAIN_BASE_PATH from core.schains.firewall.types import SChainRule +from core.schains.firewall.utils import cleanup_firewall_for_schain from tools.helper import run_cmd @@ -87,6 +88,9 @@ def test_create_delete_chain(filter_table, nft_chain_folder): manager.cleanup() chains = run_cmd(['nft', 'list', 'chains']).stdout.decode('utf-8') assert chains == 'table inet firewall {\n}\n' + assert os.path.isfile(nft_chain_path) + + manager.remove_saved_rules() assert not os.path.isfile(nft_chain_path) @@ -106,8 +110,21 @@ def test_saved_rules(filter_table, nft_chain_folder): assert not os.path.isfile(nft_chain_path) +def test_cleanup_firewall_for_schain(filter_table, nft_chain_folder): + chain_name = 'test-chain' + nft_chain_path = os.path.join(NFT_CHAIN_BASE_PATH, f'skale-{chain_name}.conf') + + manager = NFTablesController(chain=chain_name) + manager.create_chain(first_port=10000, last_port=10063) + + cleanup_firewall_for_schain(schain_name=chain_name) + chains = run_cmd(['nft', 'list', 'chains']).stdout.decode('utf-8') + assert chains == 'table inet firewall {\n}\n' + assert not os.path.isfile(nft_chain_path) + + def add_remove_rule(srule, refresh): - manager = NFTablesController() + manager = NFTablesController(chain='test') manager.add_rule(srule) time.sleep(1) if not manager.has_rule(srule): diff --git a/tests/schains/cleaner_test.py b/tests/schains/cleaner_test.py index d16b41fd..ea45b25b 100644 --- a/tests/schains/cleaner_test.py +++ b/tests/schains/cleaner_test.py @@ -239,7 +239,8 @@ def test_get_schains_on_node(schain_dirs_for_monitor, ]).issubset(set(result)) -def test_remove_schain(skale, schain_db, node_config, dutils): +@mock.patch('core.schains.cleaner.cleanup_firewall_for_schain') +def test_remove_schain(cleanup_firewall_for_schain, skale, schain_db, node_config, dutils): schain_name = schain_db remove_schain(skale, node_config.id, schain_name, msg='Test remove_schain', dutils=dutils) container_name = SCHAIN_CONTAINER_NAME_TEMPLATE.format(schain_name) @@ -250,7 +251,9 @@ def test_remove_schain(skale, schain_db, node_config, dutils): assert record.is_deleted is True +@mock.patch('core.schains.cleaner.cleanup_firewall_for_schain') def test_cleanup_schain( + cleanup_firewall_rules, schain_db, node_config, schain_on_contracts, diff --git a/tests/utils.py b/tests/utils.py index 06b7e515..ac1f7f7d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -233,6 +233,9 @@ class SChainTestFirewallManager(SChainFirewallManager): def create_host_controller(self): return HostTestFirewallController() + def cleanup(self): + self.remove_rules(self.rules) + class SChainTestRuleController(SChainRuleController): def create_firewall_manager(self): @@ -248,6 +251,9 @@ def is_persistent(self) -> bool: def is_inited(self) -> bool: return True + def cleanup(self) -> None: + self.firewall_manager.cleanup() + def get_test_rule_controller( name,