Skip to content

Commit

Permalink
Rename Nftables to NFTables
Browse files Browse the repository at this point in the history
  • Loading branch information
badrogger committed Nov 15, 2024
1 parent 43a1e63 commit 6980a49
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 27 deletions.
8 changes: 4 additions & 4 deletions core/schains/firewall/firewall_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import Iterable, Optional

from core.schains.firewall.iptables import IptablesController
from core.schains.firewall.nftables import NftablesController
from core.schains.firewall.nftables import NFTablesController
from core.schains.firewall.types import (
IFirewallManager,
IHostFirewallController,
Expand Down Expand Up @@ -91,9 +91,9 @@ def create_host_controller(self) -> IptablesController:
return IptablesController()


class NftSchainFirewallManager(SChainFirewallManager):
def create_host_controller(self) -> NftablesController:
nc_controller = NftablesController(chain=self.name)
class NFTSchainFirewallManager(SChainFirewallManager):
def create_host_controller(self) -> NFTablesController:
nc_controller = NFTablesController(chain=self.name)
nc_controller.create_table()
nc_controller.create_chain()
return nc_controller
18 changes: 9 additions & 9 deletions core/schains/firewall/nftables.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,19 @@ def is_like_number(value):
return True


class NftablesCmdFailedError(Exception):
class NFTablesCmdFailedError(Exception):
pass


class NftablesController(IHostFirewallController):
class NFTablesController(IHostFirewallController):
plock = multiprocessing.Lock()
FAMILY = 'inet'

def __init__(self, table: str = TABLE, chain: str = CHAIN) -> None:
self.table = table
self.chain = chain
self._nftables = importlib.import_module('nftables')
self.nft = self._nftables.Nftables()
self.nft = self._nftables.NFTables()
self.nft.set_json_output(True)

def _compose_json(self, commands: list[dict]) -> dict:
Expand Down Expand Up @@ -105,25 +105,25 @@ def create_chain(self) -> None:
def chains(self) -> list[dict]:
output = self.run_cmd('list chains')
if output[0] != 0:
raise NftablesCmdFailedError(output)
raise NFTablesCmdFailedError(output)
parsed = json.loads(output[1])['nftables']
return [record['chain']['name'] for record in parsed if 'chain' in record]

@property
def tables(self) -> list[dict]:
output = self.run_cmd('list tables')
if output[0] != 0:
raise NftablesCmdFailedError(output)
raise NFTablesCmdFailedError(output)
parsed = json.loads(output[1])['nftables']
return [record['table']['name'] for record in parsed if 'table' in record]

def run_json_cmd(self, cmd: dict) -> tuple:
logger.debug('Nftables json cmd %s', cmd)
logger.debug('NFTables json cmd %s', cmd)
with self.plock:
return self.nft.json_cmd(cmd)

def run_cmd(self, cmd: str) -> tuple:
logger.debug('Nftables cmd %s', cmd)
logger.debug('NFTables cmd %s', cmd)
with self.plock:
return self.nft.cmd(cmd)

Expand Down Expand Up @@ -155,7 +155,7 @@ def add_rule(self, rule: SChainRule) -> None:

rc, output, error = self.run_json_cmd(json_cmd)
if rc != 0:
raise NftablesCmdFailedError(f'Failed to add allow rule: {error}')
raise NFTablesCmdFailedError(f'Failed to add allow rule: {error}')

@classmethod
def rule_to_expr(cls, rule: SChainRule) -> list:
Expand Down Expand Up @@ -256,7 +256,7 @@ def remove_rule(self, rule: SChainRule) -> None:

rc, output, error = self.run_json_cmd(json_cmd)
if rc != 0:
raise NftablesCmdFailedError(f'Failed to delete rule: {error}')
raise NFTablesCmdFailedError(f'Failed to delete rule: {error}')

@property # type: ignore
def rules(self) -> Iterable[SChainRule]:
Expand Down
8 changes: 4 additions & 4 deletions core/schains/firewall/rule_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from functools import wraps
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, TypeVar

from .firewall_manager import IptablesSChainFirewallManager, NftSchainFirewallManager
from .firewall_manager import IptablesSChainFirewallManager, NFTSchainFirewallManager
from .types import (
IFirewallManager,
IpRange,
Expand Down Expand Up @@ -216,10 +216,10 @@ def create_firewall_manager(self) -> IptablesSChainFirewallManager:
)


class NftSchainRuleController(SChainRuleController):
class NFTSchainRuleController(SChainRuleController):
@configured_only
def create_firewall_manager(self) -> NftSchainFirewallManager:
return NftSchainFirewallManager(
def create_firewall_manager(self) -> NFTSchainFirewallManager:
return NFTSchainFirewallManager(
self.name,
self.base_port, # type: ignore
self.base_port + self.ports_per_schain - 1 # type: ignore
Expand Down
6 changes: 3 additions & 3 deletions core/schains/firewall/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from skale import Skale

from .types import IpRange
from .rule_controller import IptablesSChainRuleController, NftSchainRuleController
from .rule_controller import IptablesSChainRuleController, NFTSchainRuleController


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -72,11 +72,11 @@ def get_nftables_rule_controller(
own_ip: Optional[str] = None,
node_ips: List[str] = [],
sync_agent_ranges: Optional[List[IpRange]] = []
) -> NftSchainRuleController:
) -> NFTSchainRuleController:
sync_agent_ranges = sync_agent_ranges or []
logger.info('Creating rule controller for %s', name)
logger.debug('Rule controller ranges for %s: %s', name, sync_agent_ranges)
return NftSchainRuleController(
return NFTSchainRuleController(
name=name,
base_port=base_port,
own_ip=own_ip,
Expand Down
13 changes: 6 additions & 7 deletions tests/firewall/nftables_test.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import concurrent.futures
import importlib
import subprocess
import time

import pytest

from core.schains.firewall.nftables import NftablesController
from core.schains.firewall.nftables import NFTablesController
from core.schains.firewall.types import SChainRule


@pytest.fixture
def nf_test_tables():
nft = importlib.import_module('nftables').Nftables()
nft = importlib.import_module('nftables').NFTables()
nft.cmd('flush ruleset')
return nft

Expand All @@ -28,7 +27,7 @@ def custom_chain(nf_test_tables, filter_table):


def test_nftables_controller(custom_chain):
nft_controller = NftablesController(chain='test-chain')
nft_controller = NFTablesController(chain='test-chain')
rule_a = SChainRule(10000, '1.1.1.1', '2.2.2.2')
rule_b = SChainRule(10001, '3.3.3.3')
nft_controller.add_rule(rule_a)
Expand All @@ -46,7 +45,7 @@ def test_nftables_controller(custom_chain):

def test_nftables_controller_duplicates(custom_chain):
rule_a = SChainRule(10000, '1.1.1.1', '2.2.2.2')
manager = NftablesController(chain='test-chain')
manager = NFTablesController(chain='test-chain')
manager.add_rule(rule_a)
rule_b = SChainRule(10001, '3.3.3.3', '4.4.4.4')
manager.add_rule(rule_b)
Expand All @@ -68,7 +67,7 @@ def test_nftables_controller_duplicates(custom_chain):


def add_remove_rule(srule, refresh):
manager = NftablesController()
manager = NFTablesController()
manager.add_rule(srule)
time.sleep(1)
if not manager.has_rule(srule):
Expand Down Expand Up @@ -100,6 +99,6 @@ def test_nftables_manager_parallel(custom_chain):

for future in concurrent.futures.as_completed(futures):
assert future.result
manager = NftablesController(custom_chain)
manager = NFTablesController(custom_chain)
time.sleep(10)
assert len(list(manager.rules)) == 0

0 comments on commit 6980a49

Please sign in to comment.