Skip to content

Commit c6d48ab

Browse files
Merge pull request #399 from ISISComputingGroup/Ticket8527_Database_server_Type_checking
Pyright and ruff checking for Database Server
2 parents 366673e + 77992f7 commit c6d48ab

11 files changed

+202
-142
lines changed

DatabaseServer/database_server.py

Lines changed: 56 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from functools import partial
2626
from threading import RLock, Thread
2727
from time import sleep
28+
from typing import Callable, Literal
2829

2930
from pcaspy import Driver
3031

@@ -33,6 +34,7 @@
3334
from genie_python.mysql_abstraction_layer import SQLAbstraction
3435

3536
from DatabaseServer.exp_data import ExpData, ExpDataSource
37+
from DatabaseServer.mocks.mock_exp_data import MockExpData
3638
from DatabaseServer.moxa_data import MoxaData, MoxaDataSource
3739
from DatabaseServer.options_holder import OptionsHolder
3840
from DatabaseServer.options_loader import OptionsLoader
@@ -42,6 +44,7 @@
4244
from server_common.ioc_data import IOCData
4345
from server_common.ioc_data_source import IocDataSource
4446
from server_common.loggers.isis_logger import IsisLogger
47+
from server_common.mocks.mock_ca_server import MockCAServer
4548
from server_common.pv_names import DatabasePVNames as DbPVNames
4649
from server_common.utilities import (
4750
char_waveform,
@@ -72,22 +75,23 @@ class DatabaseServer(Driver):
7275

7376
def __init__(
7477
self,
75-
ca_server: CAServer,
78+
ca_server: CAServer | MockCAServer,
7679
ioc_data: IOCData,
77-
exp_data: ExpData,
80+
exp_data: ExpData | MockExpData,
7881
moxa_data: MoxaData,
7982
options_folder: str,
8083
blockserver_prefix: str,
8184
test_mode: bool = False,
82-
):
85+
) -> None:
8386
"""
8487
Constructor.
8588
8689
Args:
8790
ca_server: The CA server used for generating PVs on the fly
8891
ioc_data: The data source for IOC information
8992
exp_data: The data source for experiment information
90-
options_folder: The location of the folder containing the config.xml file that holds IOC options
93+
options_folder: The location of the folder containing the config.xml file that holds IOC
94+
options
9195
blockserver_prefix: The PV prefix to use
9296
test_mode: Enables starting the server in a mode suitable for unit tests
9397
"""
@@ -118,7 +122,7 @@ def _generate_pv_acquisition_info(self) -> dict:
118122
"""
119123
enhanced_info = DatabaseServer.generate_pv_info()
120124

121-
def add_get_method(pv, get_function):
125+
def add_get_method(pv: str, get_function: Callable[[], list | str | dict]) -> None:
122126
enhanced_info[pv]["get"] = get_function
123127

124128
add_get_method(DbPVNames.IOCS, self._get_iocs_info)
@@ -187,25 +191,28 @@ def get_data_for_pv(self, pv: str) -> bytes:
187191
self._check_pv_capacity(pv, len(data), self._blockserver_prefix)
188192
return data
189193

190-
def read(self, reason: str) -> str:
194+
def read(self, reason: str) -> bytes:
191195
"""
192-
A method called by SimpleServer when a PV is read from the DatabaseServer over Channel Access.
196+
A method called by SimpleServer when a PV is read from the DatabaseServer over Channel
197+
Access.
193198
194199
Args:
195200
reason: The PV that is being requested (without the PV prefix)
196201
197202
Returns:
198-
A compressed and hexed JSON formatted string that gives the desired information based on reason.
203+
A compressed and hexed JSON formatted string that gives the desired information based on
204+
reason.
199205
"""
200206
return (
201207
self.get_data_for_pv(reason)
202208
if reason in self._pv_info.keys()
203209
else self.getParam(reason)
204210
)
205211

206-
def write(self, reason: str, value: str) -> bool:
212+
def write(self, reason: str, value: str) -> Literal[True]:
207213
"""
208-
A method called by SimpleServer when a PV is written to the DatabaseServer over Channel Access.
214+
A method called by SimpleServer when a PV is written to the DatabaseServer over Channel
215+
Access.
209216
210217
Args:
211218
reason: The PV that is being requested (without the PV prefix)
@@ -224,8 +231,10 @@ def write(self, reason: str, value: str) -> bool:
224231
elif reason == "UPDATE_MM":
225232
self._moxa_data.update_mappings()
226233
except Exception as e:
227-
value = compress_and_hex(convert_to_json("Error: " + str(e)))
234+
value_bytes = compress_and_hex(convert_to_json("Error: " + str(e)))
228235
print_and_log(str(e), MAJOR_MSG)
236+
self.setParam(reason, value_bytes)
237+
return True
229238
# store the values
230239
self.setParam(reason, value)
231240
return True
@@ -266,9 +275,8 @@ def _check_pv_capacity(self, pv: str, size: int, prefix: str) -> None:
266275
"""
267276
if size > self._pv_info[pv]["count"]:
268277
print_and_log(
269-
"Too much data to encode PV {0}. Current size is {1} characters but {2} are required".format(
270-
prefix + pv, self._pv_info[pv]["count"], size
271-
),
278+
"Too much data to encode PV {0}. Current size is {1} characters "
279+
"but {2} are required".format(prefix + pv, self._pv_info[pv]["count"], size),
272280
MAJOR_MSG,
273281
LOG_TARGET,
274282
)
@@ -281,10 +289,12 @@ def _get_iocs_info(self) -> dict:
281289
iocs[iocname].update(options[iocname])
282290
return iocs
283291

284-
def _get_pvs(self, get_method: callable, replace_pv_prefix: bool, *get_args: list) -> list:
292+
def _get_pvs(
293+
self, get_method: Callable[[], list | str | dict], replace_pv_prefix: bool, *get_args: str
294+
) -> list | str | dict:
285295
"""
286-
Method to get pv data using the given method called with the given arguments and optionally remove instrument
287-
prefixes from pv names.
296+
Method to get pv data using the given method called with the given arguments and optionally
297+
remove instrument prefixes from pv names.
288298
289299
Args:
290300
get_method: The method used to get pv data.
@@ -301,17 +311,19 @@ def _get_pvs(self, get_method: callable, replace_pv_prefix: bool, *get_args: lis
301311
else:
302312
return []
303313

304-
def _get_interesting_pvs(self, level) -> list:
314+
def _get_interesting_pvs(self, level: str) -> list:
305315
"""
306316
Gets interesting pvs of the current instrument.
307317
308318
Args:
309-
level: The level of high interesting pvs, can be high, low, medium or facility. If level is an empty
310-
string, it returns all interesting pvs of all levels.
319+
level: The level of high interesting pvs, can be high, low, medium or facility.
320+
If level is an empty string, it returns all interesting pvs of all levels.
311321
Returns:
312322
a list of names of pvs with given level of interest.
313323
"""
314-
return self._get_pvs(self._iocs.get_interesting_pvs, False, level)
324+
result = self._get_pvs(self._iocs.get_interesting_pvs, False, level)
325+
assert isinstance(result, list)
326+
return result
315327

316328
def _get_active_pvs(self) -> list:
317329
"""
@@ -320,7 +332,9 @@ def _get_active_pvs(self) -> list:
320332
Returns:
321333
a list of names of pvs.
322334
"""
323-
return self._get_pvs(self._iocs.get_active_pvs, False)
335+
result = self._get_pvs(self._iocs.get_active_pvs, False)
336+
assert isinstance(result, list)
337+
return result
324338

325339
def _get_sample_par_names(self) -> list:
326340
"""
@@ -329,7 +343,9 @@ def _get_sample_par_names(self) -> list:
329343
Returns:
330344
A list of sample parameter names, an empty list if the database does not exist
331345
"""
332-
return self._get_pvs(self._iocs.get_sample_pars, True)
346+
result = self._get_pvs(self._iocs.get_sample_pars, True)
347+
assert isinstance(result, list)
348+
return result
333349

334350
def _get_beamline_par_names(self) -> list:
335351
"""
@@ -338,7 +354,9 @@ def _get_beamline_par_names(self) -> list:
338354
Returns:
339355
A list of beamline parameter names, an empty list if the database does not exist
340356
"""
341-
return self._get_pvs(self._iocs.get_beamline_pars, True)
357+
result = self._get_pvs(self._iocs.get_beamline_pars, True)
358+
assert isinstance(result, list)
359+
return result
342360

343361
def _get_user_par_names(self) -> list:
344362
"""
@@ -347,19 +365,25 @@ def _get_user_par_names(self) -> list:
347365
Returns:
348366
A list of user parameter names, an empty list if the database does not exist
349367
"""
350-
return self._get_pvs(self._iocs.get_user_pars, True)
368+
result = self._get_pvs(self._iocs.get_user_pars, True)
369+
assert isinstance(result, list)
370+
return result
351371

352-
def _get_moxa_mappings(self) -> list:
372+
def _get_moxa_mappings(self) -> dict:
353373
"""
354374
Returns the user parameters from the database, replacing the MYPVPREFIX macro.
355375
356376
Returns:
357377
An ordered dict of moxa models and their respective COM mappings
358378
"""
359-
return self._get_pvs(self._moxa_data._get_mappings_str, False)
379+
result = self._get_pvs(self._moxa_data._get_mappings_str, False)
380+
assert isinstance(result, dict)
381+
return result
360382

361-
def _get_num_of_moxas(self):
362-
return self._get_pvs(self._moxa_data._get_moxa_num, True)
383+
def _get_num_of_moxas(self) -> str:
384+
result = self._get_pvs(self._moxa_data._get_moxa_num, True)
385+
assert isinstance(result, str)
386+
return result
363387

364388
@staticmethod
365389
def _get_iocs_not_to_stop() -> list:
@@ -369,7 +393,7 @@ def _get_iocs_not_to_stop() -> list:
369393
Returns:
370394
A list of IOCs not to stop
371395
"""
372-
return IOCS_NOT_TO_STOP
396+
return list(IOCS_NOT_TO_STOP)
373397

374398

375399
if __name__ == "__main__":
@@ -390,7 +414,8 @@ def _get_iocs_not_to_stop() -> list:
390414
nargs=1,
391415
type=str,
392416
default=["."],
393-
help="The directory from which to load the configuration options(default=current directory)",
417+
help="The directory from which to load the configuration options"
418+
"(default=current directory)",
394419
)
395420

396421
args = parser.parse_args()

DatabaseServer/exp_data.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,27 @@
1717
# http://opensource.org/licenses/eclipse-1.0.php
1818
import json
1919
import traceback
20-
import typing
2120
import unicodedata
22-
from typing import Union
21+
from typing import TYPE_CHECKING, Union
2322

2423
from genie_python.mysql_abstraction_layer import SQLAbstraction
2524

2625
from server_common.channel_access import ChannelAccess
26+
from server_common.mocks.mock_ca import MockChannelAccess
2727
from server_common.utilities import char_waveform, compress_and_hex, print_and_log
2828

29+
if TYPE_CHECKING:
30+
from DatabaseServer.test_modules.test_exp_data import MockExpDataSource
31+
2932

3033
class User(object):
3134
"""
3235
A user class to allow for easier conversions from database to json.
3336
"""
3437

35-
def __init__(self, name: str = "UNKNOWN", institute: str = "UNKNOWN", role: str = "UNKNOWN"):
38+
def __init__(
39+
self, name: str = "UNKNOWN", institute: str = "UNKNOWN", role: str = "UNKNOWN"
40+
) -> None:
3641
self.name = name
3742
self.institute = institute
3843
self.role = role
@@ -43,7 +48,7 @@ class ExpDataSource(object):
4348
This is a humble object containing all the code for accessing the database.
4449
"""
4550

46-
def __init__(self):
51+
def __init__(self) -> None:
4752
self._db = SQLAbstraction("exp_data", "exp_data", "$exp_data")
4853

4954
def get_team(self, experiment_id: str) -> list:
@@ -64,7 +69,10 @@ def get_team(self, experiment_id: str) -> list:
6469
sqlquery += " AND experimentteams.experimentID = %s"
6570
sqlquery += " GROUP BY user.userID"
6671
sqlquery += " ORDER BY role.priority"
67-
team = [list(element) for element in self._db.query(sqlquery, (experiment_id,))]
72+
result = self._db.query(sqlquery, (experiment_id,))
73+
if result is None:
74+
return []
75+
team = [list(element) for element in result]
6876
if len(team) == 0:
6977
raise ValueError(
7078
"unable to find team details for experiment ID {}".format(experiment_id)
@@ -90,6 +98,8 @@ def experiment_exists(self, experiment_id: str) -> bool:
9098
sqlquery += " FROM experiment "
9199
sqlquery += " WHERE experiment.experimentID = %s"
92100
id = self._db.query(sqlquery, (experiment_id,))
101+
if id is None:
102+
return False
93103
return len(id) >= 1
94104
except Exception:
95105
print_and_log(traceback.format_exc())
@@ -109,8 +119,8 @@ def __init__(
109119
self,
110120
prefix: str,
111121
db: Union[ExpDataSource, "MockExpDataSource"],
112-
ca: Union[ChannelAccess, "MockChannelAccess"] = ChannelAccess(),
113-
):
122+
ca: Union[ChannelAccess, MockChannelAccess] = ChannelAccess(),
123+
) -> None:
114124
"""
115125
Constructor.
116126
@@ -148,7 +158,7 @@ def _make_ascii_mappings() -> dict:
148158
d[ord("\xe6")] = "ae"
149159
return d
150160

151-
def encode_for_return(self, data: typing.Any) -> bytes:
161+
def encode_for_return(self, data: dict | list) -> bytes:
152162
"""
153163
Converts data to JSON, compresses it and converts it to hex.
154164
@@ -163,7 +173,7 @@ def encode_for_return(self, data: typing.Any) -> bytes:
163173
def _get_surname_from_fullname(self, fullname: str) -> str:
164174
try:
165175
return fullname.split(" ")[-1]
166-
except:
176+
except ValueError | IndexError:
167177
return fullname
168178

169179
def update_experiment_id(self, experiment_id: str) -> None:
@@ -210,11 +220,11 @@ def update_experiment_id(self, experiment_id: str) -> None:
210220
self.ca.caput(self._simnames, self.encode_for_return(names))
211221
self.ca.caput(self._surnamepv, self.encode_for_return(surnames))
212222
self.ca.caput(self._orgspv, self.encode_for_return(orgs))
213-
# The value put to the dae names pv will need changing in time to use compressed and hexed json etc. but
214-
# this is not available at this time in the ICP
223+
# The value put to the dae names pv will need changing in time to use compressed and
224+
# hexed json etc. but this is not available at this time in the ICP
215225
self.ca.caput(self._daenamespv, ExpData.make_name_list_ascii(surnames))
216226

217-
def update_username(self, users: str) -> None:
227+
def update_username(self, user_str: str) -> None:
218228
"""
219229
Updates the associated PVs when the User Names are altered.
220230
@@ -229,7 +239,7 @@ def update_username(self, users: str) -> None:
229239
surnames = []
230240
orgs = []
231241

232-
users = json.loads(users) if users else []
242+
users = json.loads(user_str) if user_str else []
233243

234244
# Find user details in deserialized json user data
235245
for team_member in users:
@@ -249,8 +259,8 @@ def update_username(self, users: str) -> None:
249259
self.ca.caput(self._simnames, self.encode_for_return(names))
250260
self.ca.caput(self._surnamepv, self.encode_for_return(surnames))
251261
self.ca.caput(self._orgspv, self.encode_for_return(orgs))
252-
# The value put to the dae names pv will need changing in time to use compressed and hexed json etc. but
253-
# this is not available at this time in the ICP
262+
# The value put to the dae names pv will need changing in time to use compressed and hexed
263+
# json etc. but this is not available at this time in the ICP
254264
if not surnames:
255265
self.ca.caput(self._daenamespv, " ")
256266
else:
@@ -259,8 +269,8 @@ def update_username(self, users: str) -> None:
259269
@staticmethod
260270
def make_name_list_ascii(names: list) -> bytes:
261271
"""
262-
Takes a unicode list of names and creates a best ascii comma separated list this implementation is a temporary
263-
fix until we install the PyPi unidecode module.
272+
Takes a unicode list of names and creates a best ascii comma separated list this
273+
implementation is a temporary fix until we install the PyPi unidecode module.
264274
265275
Args:
266276
names: list of unicode names

DatabaseServer/ioc_options.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,16 @@
2020
class IocOptions(object):
2121
"""Contains the possible macros and pvsets of an IOC."""
2222

23-
def __init__(self, name: str):
23+
def __init__(self, name: str) -> None:
2424
"""Constructor
2525
2626
Args:
2727
name: The name of the IOC the options are associated with
2828
"""
2929
self.name = name
3030

31-
# The possible macros, pvsets and pvs for an IOC, along with associated parameters such as description
31+
# The possible macros, pvsets and pvs for an IOC, along with associated parameters such as
32+
# description
3233
self.macros = dict()
3334
self.pvsets = dict()
3435
self.pvs = dict()

0 commit comments

Comments
 (0)