Skip to content

Pyright and ruff checking for Database Server #399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 56 additions & 31 deletions DatabaseServer/database_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from functools import partial
from threading import RLock, Thread
from time import sleep
from typing import Callable, Literal

from pcaspy import Driver

Expand All @@ -33,6 +34,7 @@
from genie_python.mysql_abstraction_layer import SQLAbstraction

from DatabaseServer.exp_data import ExpData, ExpDataSource
from DatabaseServer.mocks.mock_exp_data import MockExpData
from DatabaseServer.moxa_data import MoxaData, MoxaDataSource
from DatabaseServer.options_holder import OptionsHolder
from DatabaseServer.options_loader import OptionsLoader
Expand All @@ -42,6 +44,7 @@
from server_common.ioc_data import IOCData
from server_common.ioc_data_source import IocDataSource
from server_common.loggers.isis_logger import IsisLogger
from server_common.mocks.mock_ca_server import MockCAServer
from server_common.pv_names import DatabasePVNames as DbPVNames
from server_common.utilities import (
char_waveform,
Expand Down Expand Up @@ -72,22 +75,23 @@ class DatabaseServer(Driver):

def __init__(
self,
ca_server: CAServer,
ca_server: CAServer | MockCAServer,
ioc_data: IOCData,
exp_data: ExpData,
exp_data: ExpData | MockExpData,
moxa_data: MoxaData,
options_folder: str,
blockserver_prefix: str,
test_mode: bool = False,
):
) -> None:
"""
Constructor.

Args:
ca_server: The CA server used for generating PVs on the fly
ioc_data: The data source for IOC information
exp_data: The data source for experiment information
options_folder: The location of the folder containing the config.xml file that holds IOC options
options_folder: The location of the folder containing the config.xml file that holds IOC
options
blockserver_prefix: The PV prefix to use
test_mode: Enables starting the server in a mode suitable for unit tests
"""
Expand Down Expand Up @@ -118,7 +122,7 @@ def _generate_pv_acquisition_info(self) -> dict:
"""
enhanced_info = DatabaseServer.generate_pv_info()

def add_get_method(pv, get_function):
def add_get_method(pv: str, get_function: Callable[[], list | str | dict]) -> None:
enhanced_info[pv]["get"] = get_function

add_get_method(DbPVNames.IOCS, self._get_iocs_info)
Expand Down Expand Up @@ -187,25 +191,28 @@ def get_data_for_pv(self, pv: str) -> bytes:
self._check_pv_capacity(pv, len(data), self._blockserver_prefix)
return data

def read(self, reason: str) -> str:
def read(self, reason: str) -> bytes:
"""
A method called by SimpleServer when a PV is read from the DatabaseServer over Channel Access.
A method called by SimpleServer when a PV is read from the DatabaseServer over Channel
Access.

Args:
reason: The PV that is being requested (without the PV prefix)

Returns:
A compressed and hexed JSON formatted string that gives the desired information based on reason.
A compressed and hexed JSON formatted string that gives the desired information based on
reason.
"""
return (
self.get_data_for_pv(reason)
if reason in self._pv_info.keys()
else self.getParam(reason)
)

def write(self, reason: str, value: str) -> bool:
def write(self, reason: str, value: str) -> Literal[True]:
"""
A method called by SimpleServer when a PV is written to the DatabaseServer over Channel Access.
A method called by SimpleServer when a PV is written to the DatabaseServer over Channel
Access.

Args:
reason: The PV that is being requested (without the PV prefix)
Expand All @@ -224,8 +231,10 @@ def write(self, reason: str, value: str) -> bool:
elif reason == "UPDATE_MM":
self._moxa_data.update_mappings()
except Exception as e:
value = compress_and_hex(convert_to_json("Error: " + str(e)))
value_bytes = compress_and_hex(convert_to_json("Error: " + str(e)))
print_and_log(str(e), MAJOR_MSG)
self.setParam(reason, value_bytes)
return True
# store the values
self.setParam(reason, value)
return True
Expand Down Expand Up @@ -266,9 +275,8 @@ def _check_pv_capacity(self, pv: str, size: int, prefix: str) -> None:
"""
if size > self._pv_info[pv]["count"]:
print_and_log(
"Too much data to encode PV {0}. Current size is {1} characters but {2} are required".format(
prefix + pv, self._pv_info[pv]["count"], size
),
"Too much data to encode PV {0}. Current size is {1} characters "
"but {2} are required".format(prefix + pv, self._pv_info[pv]["count"], size),
MAJOR_MSG,
LOG_TARGET,
)
Expand All @@ -281,10 +289,12 @@ def _get_iocs_info(self) -> dict:
iocs[iocname].update(options[iocname])
return iocs

def _get_pvs(self, get_method: callable, replace_pv_prefix: bool, *get_args: list) -> list:
def _get_pvs(
self, get_method: Callable[[], list | str | dict], replace_pv_prefix: bool, *get_args: str
) -> list | str | dict:
"""
Method to get pv data using the given method called with the given arguments and optionally remove instrument
prefixes from pv names.
Method to get pv data using the given method called with the given arguments and optionally
remove instrument prefixes from pv names.

Args:
get_method: The method used to get pv data.
Expand All @@ -301,17 +311,19 @@ def _get_pvs(self, get_method: callable, replace_pv_prefix: bool, *get_args: lis
else:
return []

def _get_interesting_pvs(self, level) -> list:
def _get_interesting_pvs(self, level: str) -> list:
"""
Gets interesting pvs of the current instrument.

Args:
level: The level of high interesting pvs, can be high, low, medium or facility. If level is an empty
string, it returns all interesting pvs of all levels.
level: The level of high interesting pvs, can be high, low, medium or facility.
If level is an empty string, it returns all interesting pvs of all levels.
Returns:
a list of names of pvs with given level of interest.
"""
return self._get_pvs(self._iocs.get_interesting_pvs, False, level)
result = self._get_pvs(self._iocs.get_interesting_pvs, False, level)
assert isinstance(result, list)
return result

def _get_active_pvs(self) -> list:
"""
Expand All @@ -320,7 +332,9 @@ def _get_active_pvs(self) -> list:
Returns:
a list of names of pvs.
"""
return self._get_pvs(self._iocs.get_active_pvs, False)
result = self._get_pvs(self._iocs.get_active_pvs, False)
assert isinstance(result, list)
return result

def _get_sample_par_names(self) -> list:
"""
Expand All @@ -329,7 +343,9 @@ def _get_sample_par_names(self) -> list:
Returns:
A list of sample parameter names, an empty list if the database does not exist
"""
return self._get_pvs(self._iocs.get_sample_pars, True)
result = self._get_pvs(self._iocs.get_sample_pars, True)
assert isinstance(result, list)
return result

def _get_beamline_par_names(self) -> list:
"""
Expand All @@ -338,7 +354,9 @@ def _get_beamline_par_names(self) -> list:
Returns:
A list of beamline parameter names, an empty list if the database does not exist
"""
return self._get_pvs(self._iocs.get_beamline_pars, True)
result = self._get_pvs(self._iocs.get_beamline_pars, True)
assert isinstance(result, list)
return result

def _get_user_par_names(self) -> list:
"""
Expand All @@ -347,19 +365,25 @@ def _get_user_par_names(self) -> list:
Returns:
A list of user parameter names, an empty list if the database does not exist
"""
return self._get_pvs(self._iocs.get_user_pars, True)
result = self._get_pvs(self._iocs.get_user_pars, True)
assert isinstance(result, list)
return result

def _get_moxa_mappings(self) -> list:
def _get_moxa_mappings(self) -> dict:
"""
Returns the user parameters from the database, replacing the MYPVPREFIX macro.

Returns:
An ordered dict of moxa models and their respective COM mappings
"""
return self._get_pvs(self._moxa_data._get_mappings_str, False)
result = self._get_pvs(self._moxa_data._get_mappings_str, False)
assert isinstance(result, dict)
return result

def _get_num_of_moxas(self):
return self._get_pvs(self._moxa_data._get_moxa_num, True)
def _get_num_of_moxas(self) -> str:
result = self._get_pvs(self._moxa_data._get_moxa_num, True)
assert isinstance(result, str)
return result

@staticmethod
def _get_iocs_not_to_stop() -> list:
Expand All @@ -369,7 +393,7 @@ def _get_iocs_not_to_stop() -> list:
Returns:
A list of IOCs not to stop
"""
return IOCS_NOT_TO_STOP
return list(IOCS_NOT_TO_STOP)


if __name__ == "__main__":
Expand All @@ -390,7 +414,8 @@ def _get_iocs_not_to_stop() -> list:
nargs=1,
type=str,
default=["."],
help="The directory from which to load the configuration options(default=current directory)",
help="The directory from which to load the configuration options"
"(default=current directory)",
)

args = parser.parse_args()
Expand Down
46 changes: 28 additions & 18 deletions DatabaseServer/exp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,27 @@
# http://opensource.org/licenses/eclipse-1.0.php
import json
import traceback
import typing
import unicodedata
from typing import Union
from typing import TYPE_CHECKING, Union

from genie_python.mysql_abstraction_layer import SQLAbstraction

from server_common.channel_access import ChannelAccess
from server_common.mocks.mock_ca import MockChannelAccess
from server_common.utilities import char_waveform, compress_and_hex, print_and_log

if TYPE_CHECKING:
from DatabaseServer.test_modules.test_exp_data import MockExpDataSource


class User(object):
"""
A user class to allow for easier conversions from database to json.
"""

def __init__(self, name: str = "UNKNOWN", institute: str = "UNKNOWN", role: str = "UNKNOWN"):
def __init__(
self, name: str = "UNKNOWN", institute: str = "UNKNOWN", role: str = "UNKNOWN"
) -> None:
self.name = name
self.institute = institute
self.role = role
Expand All @@ -43,7 +48,7 @@ class ExpDataSource(object):
This is a humble object containing all the code for accessing the database.
"""

def __init__(self):
def __init__(self) -> None:
self._db = SQLAbstraction("exp_data", "exp_data", "$exp_data")

def get_team(self, experiment_id: str) -> list:
Expand All @@ -64,7 +69,10 @@ def get_team(self, experiment_id: str) -> list:
sqlquery += " AND experimentteams.experimentID = %s"
sqlquery += " GROUP BY user.userID"
sqlquery += " ORDER BY role.priority"
team = [list(element) for element in self._db.query(sqlquery, (experiment_id,))]
result = self._db.query(sqlquery, (experiment_id,))
if result is None:
return []
team = [list(element) for element in result]
if len(team) == 0:
raise ValueError(
"unable to find team details for experiment ID {}".format(experiment_id)
Expand All @@ -90,6 +98,8 @@ def experiment_exists(self, experiment_id: str) -> bool:
sqlquery += " FROM experiment "
sqlquery += " WHERE experiment.experimentID = %s"
id = self._db.query(sqlquery, (experiment_id,))
if id is None:
return False
return len(id) >= 1
except Exception:
print_and_log(traceback.format_exc())
Expand All @@ -108,9 +118,9 @@ class ExpData(object):
def __init__(
self,
prefix: str,
db: Union[ExpDataSource, "MockExpDataSource"],
ca: Union[ChannelAccess, "MockChannelAccess"] = ChannelAccess(),
):
db: Union[ExpDataSource, MockExpDataSource],
ca: Union[ChannelAccess, MockChannelAccess] = ChannelAccess(),
) -> None:
"""
Constructor.

Expand Down Expand Up @@ -148,7 +158,7 @@ def _make_ascii_mappings() -> dict:
d[ord("\xe6")] = "ae"
return d

def encode_for_return(self, data: typing.Any) -> bytes:
def encode_for_return(self, data: dict | list) -> bytes:
"""
Converts data to JSON, compresses it and converts it to hex.

Expand All @@ -163,7 +173,7 @@ def encode_for_return(self, data: typing.Any) -> bytes:
def _get_surname_from_fullname(self, fullname: str) -> str:
try:
return fullname.split(" ")[-1]
except:
except ValueError | IndexError:
return fullname

def update_experiment_id(self, experiment_id: str) -> None:
Expand Down Expand Up @@ -210,11 +220,11 @@ def update_experiment_id(self, experiment_id: str) -> None:
self.ca.caput(self._simnames, self.encode_for_return(names))
self.ca.caput(self._surnamepv, self.encode_for_return(surnames))
self.ca.caput(self._orgspv, self.encode_for_return(orgs))
# The value put to the dae names pv will need changing in time to use compressed and hexed json etc. but
# this is not available at this time in the ICP
# The value put to the dae names pv will need changing in time to use compressed and
# hexed json etc. but this is not available at this time in the ICP
self.ca.caput(self._daenamespv, ExpData.make_name_list_ascii(surnames))

def update_username(self, users: str) -> None:
def update_username(self, user_str: str) -> None:
"""
Updates the associated PVs when the User Names are altered.

Expand All @@ -229,7 +239,7 @@ def update_username(self, users: str) -> None:
surnames = []
orgs = []

users = json.loads(users) if users else []
users = json.loads(user_str) if user_str else []

# Find user details in deserialized json user data
for team_member in users:
Expand All @@ -249,8 +259,8 @@ def update_username(self, users: str) -> None:
self.ca.caput(self._simnames, self.encode_for_return(names))
self.ca.caput(self._surnamepv, self.encode_for_return(surnames))
self.ca.caput(self._orgspv, self.encode_for_return(orgs))
# The value put to the dae names pv will need changing in time to use compressed and hexed json etc. but
# this is not available at this time in the ICP
# The value put to the dae names pv will need changing in time to use compressed and hexed
# json etc. but this is not available at this time in the ICP
if not surnames:
self.ca.caput(self._daenamespv, " ")
else:
Expand All @@ -259,8 +269,8 @@ def update_username(self, users: str) -> None:
@staticmethod
def make_name_list_ascii(names: list) -> bytes:
"""
Takes a unicode list of names and creates a best ascii comma separated list this implementation is a temporary
fix until we install the PyPi unidecode module.
Takes a unicode list of names and creates a best ascii comma separated list this
implementation is a temporary fix until we install the PyPi unidecode module.

Args:
names: list of unicode names
Expand Down
5 changes: 3 additions & 2 deletions DatabaseServer/ioc_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
class IocOptions(object):
"""Contains the possible macros and pvsets of an IOC."""

def __init__(self, name: str):
def __init__(self, name: str) -> None:
"""Constructor

Args:
name: The name of the IOC the options are associated with
"""
self.name = name

# The possible macros, pvsets and pvs for an IOC, along with associated parameters such as description
# The possible macros, pvsets and pvs for an IOC, along with associated parameters such as
# description
self.macros = dict()
self.pvsets = dict()
self.pvs = dict()
Expand Down
Loading