Skip to content

Commit f4c9506

Browse files
authored
Merge pull request #1 from xenanetworks/dev
Switch to pydantic v2
2 parents 30cc40a + 9ff0422 commit f4c9506

25 files changed

+310
-437
lines changed

plugin2544/dataset.py

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from typing import Any, List, Tuple, Dict
2-
from pydantic import BaseModel, validator
1+
from typing import Any, List, Tuple, Dict, Annotated
2+
from pydantic import BaseModel, field_validator, ValidationInfo, Field
33
from .utils import exceptions, constants as const
44
from .model.m_test_config import TestConfigModel
55
from .model.m_test_type_config import TestTypesConfiguration
@@ -11,9 +11,9 @@
1111

1212

1313
class PluginModel2544(BaseModel): # Main Model
14-
test_configuration: TestConfigModel
14+
test_configuration: Annotated[TestConfigModel, Field(validate_default=True)]
1515
protocol_segments: List[ProtocolSegmentProfileConfig]
16-
ports_configuration: PortConfType
16+
ports_configuration: Annotated[PortConfType, Field(validate_default=True)]
1717
test_types_configuration: TestTypesConfiguration
1818

1919
def set_ports_rx_tx_type(self) -> None:
@@ -45,33 +45,33 @@ def __init__(self, **data: Dict[str, Any]) -> None:
4545
self.check_port_groups_and_peers()
4646
self.set_profile()
4747

48-
@validator("ports_configuration", always=True)
49-
def check_ip_properties(cls, v: "PortConfType", values) -> "PortConfType":
50-
pro_map = {v.id: v.protocol_version for v in values['protocol_segments']}
51-
for i, port_config in enumerate(v):
48+
49+
@field_validator("ports_configuration")
50+
def check_ip_properties(cls, value: "PortConfType", info: ValidationInfo) -> "PortConfType":
51+
pro_map = {v.id: v.protocol_version for v in info.data['protocol_segments']}
52+
for i, port_config in enumerate(value):
5253
if port_config.protocol_segment_profile_id not in pro_map:
5354
raise exceptions.PSPMissing()
5455
if (
5556
pro_map[port_config.protocol_segment_profile_id].is_l3
5657
and (not port_config.ip_address or port_config.ip_address.address.is_empty)
5758
):
5859
raise exceptions.IPAddressMissing()
59-
return v
60+
return value
6061

61-
@validator("ports_configuration", always=True)
62-
def check_port_count(
63-
cls, v: "PortConfType", values: Dict[str, Any]
64-
) -> "PortConfType":
62+
63+
@field_validator("ports_configuration")
64+
def check_port_count(cls, value: "PortConfType", info: ValidationInfo) -> "PortConfType":
6565
require_ports = 2
66-
if "test_configuration" in values:
67-
topology: const.TestTopology = values[
66+
if "test_configuration" in info.data:
67+
topology: const.TestTopology = info.data[
6868
"test_configuration"
6969
].topology_config.topology
7070
if topology.is_pair_topology:
7171
require_ports = 1
72-
if len(v) < require_ports:
72+
if len(value) < require_ports:
7373
raise exceptions.PortConfigNotEnough(require_ports)
74-
return v
74+
return value
7575

7676
def check_port_groups_and_peers(self) -> None:
7777
topology = self.test_configuration.topology_config.topology
@@ -89,37 +89,35 @@ def check_port_groups_and_peers(self) -> None:
8989
if not i:
9090
raise exceptions.PortGroupError(group)
9191

92-
@validator("ports_configuration", always=True)
93-
def check_modifier_mode_and_segments(
94-
cls, v: "PortConfType", values: Dict[str, Any]
95-
) -> "PortConfType":
96-
if "test_configuration" in values:
97-
flow_creation_type = values[
92+
93+
@field_validator("ports_configuration")
94+
def check_modifier_mode_and_segments(cls, value: PortConfType, info: ValidationInfo) -> PortConfType:
95+
if "test_configuration" in info.data:
96+
flow_creation_type = info.data[
9897
"test_configuration"
9998
].test_execution_config.flow_creation_config.flow_creation_type
100-
for port_config in v:
99+
for port_config in value:
101100
if (
102101
not flow_creation_type.is_stream_based
103102
) and port_config.profile.protocol_version.is_l3:
104103
raise exceptions.ModifierBasedNotSupportL3()
105-
return v
104+
return value
105+
106106

107-
@validator("ports_configuration", always=True)
108-
def check_port_group(
109-
cls, v: "PortConfiguration", values: Dict[str, Any]
110-
) -> "PortConfiguration":
111-
if "ports_configuration" in values and "test_configuration" in values:
112-
for k, p in values["ports_configuration"].items():
107+
@field_validator("ports_configuration")
108+
def check_port_group(cls, value: PortConfiguration, info: ValidationInfo) -> PortConfiguration:
109+
if "ports_configuration" in info.data and "test_configuration" in info.data:
110+
for k, p in info.data["ports_configuration"].items():
113111
if (
114112
p.port_group == const.PortGroup.UNDEFINED
115-
and not values[
113+
and not info.data[
116114
"test_configuration"
117115
].topology_config.topology.is_mesh_topology
118116
):
119117
raise exceptions.PortGroupNeeded()
120-
return v
118+
return value
121119

122-
@validator("test_types_configuration", always=True)
120+
@field_validator("test_types_configuration")
123121
def check_test_type_enable(
124122
cls, v: "TestTypesConfiguration"
125123
) -> "TestTypesConfiguration":
@@ -134,22 +132,21 @@ def check_test_type_enable(
134132
raise exceptions.TestTypesError()
135133
return v
136134

137-
@validator("test_types_configuration", always=True)
138-
def check_result_scope(
139-
cls, v: "TestTypesConfiguration", values: Dict[str, Any]
140-
) -> "TestTypesConfiguration":
141-
if "test_configuration" not in values:
142-
return v
135+
136+
@field_validator("test_types_configuration")
137+
def check_result_scope(cls, value: "TestTypesConfiguration", info: ValidationInfo) -> "TestTypesConfiguration":
138+
if "test_configuration" not in info.data:
139+
return value
143140
if (
144-
v.throughput_test.enabled
145-
and v.throughput_test.rate_iteration_options.result_scope
141+
value.throughput_test.enabled
142+
and value.throughput_test.rate_iteration_options.result_scope
146143
== const.RateResultScopeType.PER_SOURCE_PORT
147-
and not values[
144+
and not info.data[
148145
"test_configuration"
149146
].test_execution_config.flow_creation_config.flow_creation_type.is_stream_based
150147
):
151148
raise exceptions.ModifierBasedNotSupportPerPortResult()
152-
return v
149+
return value
153150

154151
@staticmethod
155152
def count_port_group(

plugin2544/model/m_port_config.py

Lines changed: 21 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -7,95 +7,54 @@
77
IPv6Address as OriginIPv6Address,
88
)
99
from typing import Union, Optional
10-
from pydantic import BaseModel, validator, Field
10+
from pydantic import BaseModel, field_validator, Field
1111
from ..utils import constants as const
1212
from ..utils.field import MacAddress, IPv4Address, IPv6Address, Prefix
1313
from .m_protocol_segment import ProtocolSegmentProfileConfig
1414

1515

16-
class IPAddressProperties(BaseModel):
17-
address: Union[IPv4Address, IPv6Address] = IPv4Address("0.0.0.0")
18-
routing_prefix: Prefix = Prefix(24)
19-
public_address: Union[IPv4Address, IPv6Address] = IPv4Address("0.0.0.0")
20-
public_routing_prefix: Prefix = Prefix(24)
21-
gateway: Union[IPv4Address, IPv6Address] = IPv4Address("0.0.0.0")
22-
remote_loop_address: Union[IPv4Address, IPv6Address] = IPv4Address("0.0.0.0")
16+
class IPAddressProperties(BaseModel, arbitrary_types_allowed=True):
17+
address: Union[IPv4Address, IPv6Address, str] = "0.0.0.0"
18+
routing_prefix: Prefix | int = Prefix(24)
19+
public_address: Union[IPv4Address, IPv6Address, str] = "0.0.0.0"
20+
public_routing_prefix: Prefix | int = Prefix(24)
21+
gateway: Union[IPv4Address, IPv6Address, str] = "0.0.0.0"
22+
remote_loop_address: Union[IPv4Address, IPv6Address, str] = "0.0.0.0"
2323
# ip_version: const.IPVersion = const.IPVersion.IPV6
2424

2525
@property
2626
def network(self) -> Union["IPv4Network", "IPv6Network"]:
2727
return ip_network(f"{self.address}/{self.routing_prefix}", strict=False)
2828

29-
@validator(
30-
"address",
31-
"public_address",
32-
"gateway",
33-
"remote_loop_address",
34-
pre=True,
35-
allow_reuse=True,
36-
)
29+
@field_validator("address", "public_address", "gateway", "remote_loop_address", mode="before")
3730
def set_address(
38-
cls, origin_addr: Union[str, "IPv4Address", "IPv6Address"]
31+
cls, value: Union[str, "IPv4Address", "IPv6Address"]
3932
) -> Union["IPv4Address", "IPv6Address"]:
40-
address = ip_address(origin_addr)
33+
address = ip_address(value)
4134
return (
4235
IPv4Address(address)
4336
if isinstance(address, OriginIPv4Address)
4437
else IPv6Address(address)
4538
)
4639

47-
@validator("routing_prefix", "public_routing_prefix", pre=True, allow_reuse=True)
48-
def set_prefix(cls, v: int) -> Prefix:
49-
return Prefix(v)
40+
@field_validator("routing_prefix", "public_routing_prefix", mode="before")
41+
def set_prefix(cls, value: int) -> Prefix:
42+
return Prefix(value)
5043

5144
@property
52-
def dst_addr(self) -> Union["IPv4Address", "IPv6Address"]:
45+
def dst_addr(self) -> Union["IPv4Address", "IPv6Address", str]:
5346
return self.public_address if not self.public_address.is_empty else self.address
5447

55-
56-
# class IPV4AddressProperties(BaseModel):
57-
# address: IPv4Address = IPv4Address("0.0.0.0")
58-
# routing_prefix: Prefix = Prefix(24)
59-
# public_address: IPv4Address = IPv4Address("0.0.0.0")
60-
# public_routing_prefix: Prefix = Prefix(24)
61-
# gateway: IPv4Address = IPv4Address("0.0.0.0")
62-
# remote_loop_address: IPv4Address = IPv4Address("0.0.0.0")
63-
# ip_version: const.IPVersion = const.IPVersion.IPV4
64-
65-
# @property
66-
# def network(self) -> "IPv4Network":
67-
# return IPv4Network(f"{self.address}/{self.routing_prefix}", strict=False)
68-
69-
# @validator(
70-
# "address",
71-
# "public_address",
72-
# "gateway",
73-
# "remote_loop_address",
74-
# pre=True,
75-
# allow_reuse=True,
76-
# )
77-
# def set_address(cls, v: Union[str, "IPv4Address"]) -> "IPv4Address":
78-
# return IPv4Address(v)
79-
80-
# @validator("routing_prefix", "public_routing_prefix", pre=True, allow_reuse=True)
81-
# def set_prefix(cls, v: int) -> Prefix:
82-
# return Prefix(v)
83-
84-
# @property
85-
# def dst_addr(self) -> "IPv4Address":
86-
# return self.public_address if not self.public_address.is_empty else self.address
87-
88-
89-
class PortConfiguration(BaseModel):
48+
class PortConfiguration(BaseModel, arbitrary_types_allowed=True):
9049
port_slot: int
9150
peer_slot: Optional[int]
9251
port_group: const.PortGroup
9352
port_speed_mode: const.PortSpeedStr
9453
ip_address: Optional[IPAddressProperties]
95-
ip_gateway_mac_address: MacAddress
54+
ip_gateway_mac_address: MacAddress | str
9655
reply_arp_requests: bool
9756
reply_ping_requests: bool
98-
remote_loop_mac_address: MacAddress
57+
remote_loop_mac_address: MacAddress | str
9958
inter_frame_gap: float
10059
speed_reduction_ppm: int = Field(ge=0)
10160
pause_mode_enabled: bool
@@ -121,12 +80,9 @@ class PortConfiguration(BaseModel):
12180
_is_tx: bool = True
12281
_is_rx: bool = True
12382

124-
class Config:
125-
underscore_attrs_are_private = True
126-
127-
@validator("ip_gateway_mac_address", pre=True)
128-
def set_ip_gateway_mac_address(cls, ip_gateway_mac_address: str) -> "MacAddress":
129-
return MacAddress(ip_gateway_mac_address)
83+
@field_validator("ip_gateway_mac_address", mode="before")
84+
def set_ip_gateway_mac_address(cls, value: str) -> "MacAddress":
85+
return MacAddress(value)
13086

13187
@property
13288
def is_tx_port(self) -> bool:

plugin2544/model/m_protocol_segment.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,27 @@
11
import re
22
from enum import Enum
33
from random import randint
4-
from typing import Any, Callable, Dict, Generator, List, Optional
5-
from pydantic import BaseModel, Field
6-
from pydantic.class_validators import validator
4+
from typing import Any, Callable, Dict, Generator, List, Optional, Annotated
5+
from pydantic import BaseModel, Field, field_validator, ValidationInfo
6+
from pydantic_core import CoreSchema, core_schema
7+
from pydantic import GetCoreSchemaHandler, TypeAdapter
78
from xoa_driver.enums import ProtocolOption, ModifierAction
89
from ..utils.exceptions import ModifierRangeError
910

1011
class BinaryString(str):
1112
@classmethod
12-
def __get_validators__(cls) -> Generator[Callable, None, None]:
13-
yield cls.validate
14-
15-
@classmethod
16-
def validate(cls, v: str) -> "BinaryString":
17-
if not re.search("^[01]+$", v):
18-
raise ValueError("binary string must zero or one")
19-
return cls(v)
13+
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
14+
return core_schema.no_info_after_validator_function(cls, handler(str))
15+
16+
# @classmethod
17+
# def __get_validators__(cls) -> Generator[Callable, None, None]:
18+
# yield cls.validate
19+
20+
# @classmethod
21+
# def validate(cls, v: str) -> "BinaryString":
22+
# if not re.search("^[01]+$", v):
23+
# raise ValueError("binary string must zero or one")
24+
# return cls(v)
2025

2126
@property
2227
def is_all_zero(self) -> bool:
@@ -91,13 +96,17 @@ class SegmentType(Enum):
9196

9297
@property
9398
def is_raw(self) -> bool:
94-
return self.value.lower().startswith("raw")
99+
if isinstance(self.value, str):
100+
return self.value.lower().startswith("raw")
101+
return False
95102

96103
@property
97104
def raw_length(self) -> int:
98105
if not self.is_raw:
99106
return 0
100-
return int(self.value.split("_")[-1])
107+
if isinstance(self.value, str):
108+
return int(self.value.split("_")[-1])
109+
return 0
101110

102111
def to_xmp(self) -> "ProtocolOption":
103112
return ProtocolOption[self.name]
@@ -123,9 +132,6 @@ class ValueRange(BaseModel):
123132
restart_for_each_port: bool
124133
_current_count: int = 0 # counter start from 0
125134

126-
class Config:
127-
underscore_attrs_are_private = True
128-
129135
def reset(self) -> None:
130136
self._current_count = 0
131137

@@ -157,21 +163,20 @@ def get_current_value(self) -> int:
157163
class HWModifier(BaseModel):
158164
start_value: int
159165
step_value: int = Field(gt=0)
160-
stop_value: int
166+
stop_value: Annotated[int, Field(validate_default=True)]
161167
repeat: int
162168
offset: int
163169
action: ModifierActionOption
164170
mask: str # hex string as 'FFFF'
165171
_byte_segment_position: int = 0 # byte position of all header segments
166172

167-
class Config:
168-
underscore_attrs_are_private = True
169-
170-
@validator('stop_value', pre=True, always=True)
171-
def validate_modifier_value(cls, v: int, values: Dict[str, Any]):
172-
if (v - values['start_value']) % values['step_value']:
173-
raise ModifierRangeError(values['start_value'], v, values['step_value'])
174-
return v
173+
@field_validator('stop_value', mode="before")
174+
def validate_modifier_value(cls, value: int, info: ValidationInfo):
175+
start_value = info.data['start_value']
176+
step_value = info.data['step_value']
177+
if (value - start_value) % step_value:
178+
raise ModifierRangeError(start_value, value, step_value)
179+
return value
175180

176181

177182
def set_byte_segment_position(self, position: int) -> None:
@@ -241,7 +246,7 @@ def hw_modifiers(self) -> Generator["HWModifier", None, None]:
241246
def value_ranges(self) -> Generator["ValueRange", None, None]:
242247
return (f.value_range for f in self.fields if f.value_range)
243248

244-
@validator("checksum_offset")
249+
@field_validator("checksum_offset")
245250
def is_digit(cls, value: int) -> int:
246251
if value and not isinstance(value, int):
247252
raise ValueError("checksum offset must digit")

0 commit comments

Comments
 (0)