diff --git a/DatabaseServer/database_server.py b/DatabaseServer/database_server.py index 08d8c63e..07665252 100644 --- a/DatabaseServer/database_server.py +++ b/DatabaseServer/database_server.py @@ -25,6 +25,7 @@ from functools import partial from threading import RLock, Thread from time import sleep +from typing import Callable, Literal from pcaspy import Driver @@ -33,6 +34,7 @@ from genie_python.mysql_abstraction_layer import SQLAbstraction from DatabaseServer.exp_data import ExpData, ExpDataSource +from DatabaseServer.mocks.mock_exp_data import MockExpData from DatabaseServer.moxa_data import MoxaData, MoxaDataSource from DatabaseServer.options_holder import OptionsHolder from DatabaseServer.options_loader import OptionsLoader @@ -42,6 +44,7 @@ from server_common.ioc_data import IOCData from server_common.ioc_data_source import IocDataSource from server_common.loggers.isis_logger import IsisLogger +from server_common.mocks.mock_ca_server import MockCAServer from server_common.pv_names import DatabasePVNames as DbPVNames from server_common.utilities import ( char_waveform, @@ -72,14 +75,14 @@ class DatabaseServer(Driver): def __init__( self, - ca_server: CAServer, + ca_server: CAServer | MockCAServer, ioc_data: IOCData, - exp_data: ExpData, + exp_data: ExpData | MockExpData, moxa_data: MoxaData, options_folder: str, blockserver_prefix: str, test_mode: bool = False, - ): + ) -> None: """ Constructor. @@ -87,7 +90,8 @@ def __init__( ca_server: The CA server used for generating PVs on the fly ioc_data: The data source for IOC information exp_data: The data source for experiment information - options_folder: The location of the folder containing the config.xml file that holds IOC options + options_folder: The location of the folder containing the config.xml file that holds IOC + options blockserver_prefix: The PV prefix to use test_mode: Enables starting the server in a mode suitable for unit tests """ @@ -118,7 +122,7 @@ def _generate_pv_acquisition_info(self) -> dict: """ enhanced_info = DatabaseServer.generate_pv_info() - def add_get_method(pv, get_function): + def add_get_method(pv: str, get_function: Callable[[], list | str | dict]) -> None: enhanced_info[pv]["get"] = get_function add_get_method(DbPVNames.IOCS, self._get_iocs_info) @@ -187,15 +191,17 @@ def get_data_for_pv(self, pv: str) -> bytes: self._check_pv_capacity(pv, len(data), self._blockserver_prefix) return data - def read(self, reason: str) -> str: + def read(self, reason: str) -> bytes: """ - A method called by SimpleServer when a PV is read from the DatabaseServer over Channel Access. + A method called by SimpleServer when a PV is read from the DatabaseServer over Channel + Access. Args: reason: The PV that is being requested (without the PV prefix) Returns: - A compressed and hexed JSON formatted string that gives the desired information based on reason. + A compressed and hexed JSON formatted string that gives the desired information based on + reason. """ return ( self.get_data_for_pv(reason) @@ -203,9 +209,10 @@ def read(self, reason: str) -> str: else self.getParam(reason) ) - def write(self, reason: str, value: str) -> bool: + def write(self, reason: str, value: str) -> Literal[True]: """ - A method called by SimpleServer when a PV is written to the DatabaseServer over Channel Access. + A method called by SimpleServer when a PV is written to the DatabaseServer over Channel + Access. Args: reason: The PV that is being requested (without the PV prefix) @@ -224,8 +231,10 @@ def write(self, reason: str, value: str) -> bool: elif reason == "UPDATE_MM": self._moxa_data.update_mappings() except Exception as e: - value = compress_and_hex(convert_to_json("Error: " + str(e))) + value_bytes = compress_and_hex(convert_to_json("Error: " + str(e))) print_and_log(str(e), MAJOR_MSG) + self.setParam(reason, value_bytes) + return True # store the values self.setParam(reason, value) return True @@ -266,9 +275,8 @@ def _check_pv_capacity(self, pv: str, size: int, prefix: str) -> None: """ if size > self._pv_info[pv]["count"]: print_and_log( - "Too much data to encode PV {0}. Current size is {1} characters but {2} are required".format( - prefix + pv, self._pv_info[pv]["count"], size - ), + "Too much data to encode PV {0}. Current size is {1} characters " + "but {2} are required".format(prefix + pv, self._pv_info[pv]["count"], size), MAJOR_MSG, LOG_TARGET, ) @@ -281,10 +289,12 @@ def _get_iocs_info(self) -> dict: iocs[iocname].update(options[iocname]) return iocs - def _get_pvs(self, get_method: callable, replace_pv_prefix: bool, *get_args: list) -> list: + def _get_pvs( + self, get_method: Callable[[], list | str | dict], replace_pv_prefix: bool, *get_args: str + ) -> list | str | dict: """ - Method to get pv data using the given method called with the given arguments and optionally remove instrument - prefixes from pv names. + Method to get pv data using the given method called with the given arguments and optionally + remove instrument prefixes from pv names. Args: get_method: The method used to get pv data. @@ -301,17 +311,19 @@ def _get_pvs(self, get_method: callable, replace_pv_prefix: bool, *get_args: lis else: return [] - def _get_interesting_pvs(self, level) -> list: + def _get_interesting_pvs(self, level: str) -> list: """ Gets interesting pvs of the current instrument. Args: - level: The level of high interesting pvs, can be high, low, medium or facility. If level is an empty - string, it returns all interesting pvs of all levels. + level: The level of high interesting pvs, can be high, low, medium or facility. + If level is an empty string, it returns all interesting pvs of all levels. Returns: a list of names of pvs with given level of interest. """ - return self._get_pvs(self._iocs.get_interesting_pvs, False, level) + result = self._get_pvs(self._iocs.get_interesting_pvs, False, level) + assert isinstance(result, list) + return result def _get_active_pvs(self) -> list: """ @@ -320,7 +332,9 @@ def _get_active_pvs(self) -> list: Returns: a list of names of pvs. """ - return self._get_pvs(self._iocs.get_active_pvs, False) + result = self._get_pvs(self._iocs.get_active_pvs, False) + assert isinstance(result, list) + return result def _get_sample_par_names(self) -> list: """ @@ -329,7 +343,9 @@ def _get_sample_par_names(self) -> list: Returns: A list of sample parameter names, an empty list if the database does not exist """ - return self._get_pvs(self._iocs.get_sample_pars, True) + result = self._get_pvs(self._iocs.get_sample_pars, True) + assert isinstance(result, list) + return result def _get_beamline_par_names(self) -> list: """ @@ -338,7 +354,9 @@ def _get_beamline_par_names(self) -> list: Returns: A list of beamline parameter names, an empty list if the database does not exist """ - return self._get_pvs(self._iocs.get_beamline_pars, True) + result = self._get_pvs(self._iocs.get_beamline_pars, True) + assert isinstance(result, list) + return result def _get_user_par_names(self) -> list: """ @@ -347,19 +365,25 @@ def _get_user_par_names(self) -> list: Returns: A list of user parameter names, an empty list if the database does not exist """ - return self._get_pvs(self._iocs.get_user_pars, True) + result = self._get_pvs(self._iocs.get_user_pars, True) + assert isinstance(result, list) + return result - def _get_moxa_mappings(self) -> list: + def _get_moxa_mappings(self) -> dict: """ Returns the user parameters from the database, replacing the MYPVPREFIX macro. Returns: An ordered dict of moxa models and their respective COM mappings """ - return self._get_pvs(self._moxa_data._get_mappings_str, False) + result = self._get_pvs(self._moxa_data._get_mappings_str, False) + assert isinstance(result, dict) + return result - def _get_num_of_moxas(self): - return self._get_pvs(self._moxa_data._get_moxa_num, True) + def _get_num_of_moxas(self) -> str: + result = self._get_pvs(self._moxa_data._get_moxa_num, True) + assert isinstance(result, str) + return result @staticmethod def _get_iocs_not_to_stop() -> list: @@ -369,7 +393,7 @@ def _get_iocs_not_to_stop() -> list: Returns: A list of IOCs not to stop """ - return IOCS_NOT_TO_STOP + return list(IOCS_NOT_TO_STOP) if __name__ == "__main__": @@ -390,7 +414,8 @@ def _get_iocs_not_to_stop() -> list: nargs=1, type=str, default=["."], - help="The directory from which to load the configuration options(default=current directory)", + help="The directory from which to load the configuration options" + "(default=current directory)", ) args = parser.parse_args() diff --git a/DatabaseServer/exp_data.py b/DatabaseServer/exp_data.py index 289424e6..f3f22d4d 100644 --- a/DatabaseServer/exp_data.py +++ b/DatabaseServer/exp_data.py @@ -17,22 +17,27 @@ # http://opensource.org/licenses/eclipse-1.0.php import json import traceback -import typing import unicodedata -from typing import Union +from typing import TYPE_CHECKING, Union from genie_python.mysql_abstraction_layer import SQLAbstraction from server_common.channel_access import ChannelAccess +from server_common.mocks.mock_ca import MockChannelAccess from server_common.utilities import char_waveform, compress_and_hex, print_and_log +if TYPE_CHECKING: + from DatabaseServer.test_modules.test_exp_data import MockExpDataSource + class User(object): """ A user class to allow for easier conversions from database to json. """ - def __init__(self, name: str = "UNKNOWN", institute: str = "UNKNOWN", role: str = "UNKNOWN"): + def __init__( + self, name: str = "UNKNOWN", institute: str = "UNKNOWN", role: str = "UNKNOWN" + ) -> None: self.name = name self.institute = institute self.role = role @@ -43,7 +48,7 @@ class ExpDataSource(object): This is a humble object containing all the code for accessing the database. """ - def __init__(self): + def __init__(self) -> None: self._db = SQLAbstraction("exp_data", "exp_data", "$exp_data") def get_team(self, experiment_id: str) -> list: @@ -64,7 +69,10 @@ def get_team(self, experiment_id: str) -> list: sqlquery += " AND experimentteams.experimentID = %s" sqlquery += " GROUP BY user.userID" sqlquery += " ORDER BY role.priority" - team = [list(element) for element in self._db.query(sqlquery, (experiment_id,))] + result = self._db.query(sqlquery, (experiment_id,)) + if result is None: + return [] + team = [list(element) for element in result] if len(team) == 0: raise ValueError( "unable to find team details for experiment ID {}".format(experiment_id) @@ -90,6 +98,8 @@ def experiment_exists(self, experiment_id: str) -> bool: sqlquery += " FROM experiment " sqlquery += " WHERE experiment.experimentID = %s" id = self._db.query(sqlquery, (experiment_id,)) + if id is None: + return False return len(id) >= 1 except Exception: print_and_log(traceback.format_exc()) @@ -109,8 +119,8 @@ def __init__( self, prefix: str, db: Union[ExpDataSource, "MockExpDataSource"], - ca: Union[ChannelAccess, "MockChannelAccess"] = ChannelAccess(), - ): + ca: Union[ChannelAccess, MockChannelAccess] = ChannelAccess(), + ) -> None: """ Constructor. @@ -148,7 +158,7 @@ def _make_ascii_mappings() -> dict: d[ord("\xe6")] = "ae" return d - def encode_for_return(self, data: typing.Any) -> bytes: + def encode_for_return(self, data: dict | list) -> bytes: """ Converts data to JSON, compresses it and converts it to hex. @@ -163,7 +173,7 @@ def encode_for_return(self, data: typing.Any) -> bytes: def _get_surname_from_fullname(self, fullname: str) -> str: try: return fullname.split(" ")[-1] - except: + except ValueError | IndexError: return fullname def update_experiment_id(self, experiment_id: str) -> None: @@ -210,11 +220,11 @@ def update_experiment_id(self, experiment_id: str) -> None: self.ca.caput(self._simnames, self.encode_for_return(names)) self.ca.caput(self._surnamepv, self.encode_for_return(surnames)) self.ca.caput(self._orgspv, self.encode_for_return(orgs)) - # The value put to the dae names pv will need changing in time to use compressed and hexed json etc. but - # this is not available at this time in the ICP + # The value put to the dae names pv will need changing in time to use compressed and + # hexed json etc. but this is not available at this time in the ICP self.ca.caput(self._daenamespv, ExpData.make_name_list_ascii(surnames)) - def update_username(self, users: str) -> None: + def update_username(self, user_str: str) -> None: """ Updates the associated PVs when the User Names are altered. @@ -229,7 +239,7 @@ def update_username(self, users: str) -> None: surnames = [] orgs = [] - users = json.loads(users) if users else [] + users = json.loads(user_str) if user_str else [] # Find user details in deserialized json user data for team_member in users: @@ -249,8 +259,8 @@ def update_username(self, users: str) -> None: self.ca.caput(self._simnames, self.encode_for_return(names)) self.ca.caput(self._surnamepv, self.encode_for_return(surnames)) self.ca.caput(self._orgspv, self.encode_for_return(orgs)) - # The value put to the dae names pv will need changing in time to use compressed and hexed json etc. but - # this is not available at this time in the ICP + # The value put to the dae names pv will need changing in time to use compressed and hexed + # json etc. but this is not available at this time in the ICP if not surnames: self.ca.caput(self._daenamespv, " ") else: @@ -259,8 +269,8 @@ def update_username(self, users: str) -> None: @staticmethod def make_name_list_ascii(names: list) -> bytes: """ - Takes a unicode list of names and creates a best ascii comma separated list this implementation is a temporary - fix until we install the PyPi unidecode module. + Takes a unicode list of names and creates a best ascii comma separated list this + implementation is a temporary fix until we install the PyPi unidecode module. Args: names: list of unicode names diff --git a/DatabaseServer/ioc_options.py b/DatabaseServer/ioc_options.py index 78c452b0..55f8aaac 100644 --- a/DatabaseServer/ioc_options.py +++ b/DatabaseServer/ioc_options.py @@ -20,7 +20,7 @@ class IocOptions(object): """Contains the possible macros and pvsets of an IOC.""" - def __init__(self, name: str): + def __init__(self, name: str) -> None: """Constructor Args: @@ -28,7 +28,8 @@ def __init__(self, name: str): """ self.name = name - # The possible macros, pvsets and pvs for an IOC, along with associated parameters such as description + # The possible macros, pvsets and pvs for an IOC, along with associated parameters such as + # description self.macros = dict() self.pvsets = dict() self.pvs = dict() diff --git a/DatabaseServer/mocks/mock_exp_data.py b/DatabaseServer/mocks/mock_exp_data.py index bae643ef..2f737d3c 100644 --- a/DatabaseServer/mocks/mock_exp_data.py +++ b/DatabaseServer/mocks/mock_exp_data.py @@ -12,7 +12,7 @@ def encode_for_return(self, data: str) -> bytes: def _get_surname_from_fullname(self, fullname: str) -> str: try: return fullname.split(" ")[-1] - except: + except ValueError | IndexError: return fullname def update_experiment_id(self, experiment_id: str) -> None: diff --git a/DatabaseServer/mocks/mock_procserv_utils.py b/DatabaseServer/mocks/mock_procserv_utils.py index 1933f4ce..78ccf1db 100644 --- a/DatabaseServer/mocks/mock_procserv_utils.py +++ b/DatabaseServer/mocks/mock_procserv_utils.py @@ -19,12 +19,10 @@ class MockProcServWrapper(object): """ - Note: this file cannot currently be given type hints as it is included from the "server_common" tests. - - It can get type hints added once server_common has been migrated to Py3. + Mock ProcServer """ - def __init__(self): + def __init__(self) -> None: self.ps_status = dict() self.ps_status["simple1"] = "SHUTDOWN" self.ps_status["simple2"] = "SHUTDOWN" @@ -32,28 +30,28 @@ def __init__(self): self.ps_status["stopdioc"] = "SHUTDOWN" @staticmethod - def generate_prefix(prefix, ioc): - return "%sCS:PS:%s" % (prefix, ioc) + def generate_prefix(prefix: str, ioc: str) -> str: + return f"{prefix}CS:PS:{ioc}" - def start_ioc(self, prefix, ioc): + def start_ioc(self, prefix: str, ioc: str) -> None: self.ps_status[ioc.lower()] = "RUNNING" - def stop_ioc(self, prefix, ioc): + def stop_ioc(self, prefix: str, ioc: str) -> None: """Stops the specified IOC""" self.ps_status[ioc.lower()] = "SHUTDOWN" - def restart_ioc(self, prefix, ioc): + def restart_ioc(self, prefix: str, ioc: str) -> None: self.ps_status[ioc.lower()] = "RUNNING" - def get_ioc_status(self, prefix, ioc): + def get_ioc_status(self, prefix: str, ioc: str) -> str: if ioc.lower() not in self.ps_status.keys(): - raise Exception("Could not find IOC (%s)" % self.generate_prefix(prefix, ioc)) + raise TimeoutError("Could not find IOC (%s)" % self.generate_prefix(prefix, ioc)) else: return self.ps_status[ioc.lower()] - def ioc_exists(self, prefix, ioc): + def ioc_exists(self, prefix: str, ioc: str) -> bool: try: self.get_ioc_status(prefix, ioc) return True - except: + except TimeoutError: return False diff --git a/DatabaseServer/moxa_data.py b/DatabaseServer/moxa_data.py index 41dd8b8d..18f6d682 100644 --- a/DatabaseServer/moxa_data.py +++ b/DatabaseServer/moxa_data.py @@ -3,7 +3,12 @@ import time from collections import OrderedDict from threading import RLock, Thread -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple, TypeGuard + +from genie_python.mysql_abstraction_layer import ( # type: ignore + AbstractSQLCommands, + ParamsSequenceOrDictType, +) from server_common.snmpWalker import walk from server_common.utilities import SEVERITY, print_and_log @@ -38,23 +43,36 @@ PORT_MIBS = ["IF-MIB::ifOperStatus", "IF-MIB::ifSpeed", "IF-MIB::ifInOctets", "IF-MIB::ifOutOctets"] +def _is_list_of_str_list(val: list[list]) -> TypeGuard[list[list[str]]]: + sublist_vals = [_is_list_of_str(sublist) for sublist in val] + return all(sublist_vals) + + +def _is_list_of_str(val: list) -> TypeGuard[list[str]]: + return all(isinstance(val, str) for x in val) + + class MoxaDataSource(object): """ A source for IOC data from the database """ - def __init__(self, mysql_abstraction_layer): + def __init__(self, mysql_abstraction_layer: AbstractSQLCommands) -> None: """ Constructor. Args: - mysql_abstraction_layer(genie_python.mysql_abstraction_layer.AbstractSQLCommands): contact database with sql + mysql_abstraction_layer(genie_python.mysql_abstraction_layer.AbstractSQLCommands): + contact database with sql """ self.mysql_abstraction_layer = mysql_abstraction_layer - def _query_and_normalise(self, sqlquery, bind_vars=None): + def _query_and_normalise( + self, sqlquery: str, bind_vars: Optional[ParamsSequenceOrDictType] = None + ) -> list[list[str]]: """ - Executes the given query to the database and converts the data in each row from bytearray to a normal string. + Executes the given query to the database and converts the data in each row from bytearray to + a normal string. Args: sqlquery: The query to execute. @@ -64,30 +82,34 @@ def _query_and_normalise(self, sqlquery, bind_vars=None): A list of lists of strings, representing the data from the table. """ # Get as a plain list of lists - values = [ - list(element) for element in self.mysql_abstraction_layer.query(sqlquery, bind_vars) - ] + result = self.mysql_abstraction_layer.query(sqlquery, bind_vars) + if result is None: + return [[]] + values = [list(element) for element in result] # Convert any bytearrays for i, pv in enumerate(values): for j, element in enumerate(pv): - if type(element) == bytearray: + if type(element) is bytearray: values[i][j] = element.decode("utf-8") + + assert _is_list_of_str_list(values) return values - def _delete_all(self): + def _delete_all(self) -> None: self.mysql_abstraction_layer.update(DELETE_PORTS) self.mysql_abstraction_layer.update(DELETE_IPS) """ - Iterates through the map of ip to hostname and physical port to COM ports and inserts the mappings into the sql instance. + Iterates through the map of ip to hostname and physical port to COM ports and inserts the + mappings into the sql instance. Args: moxa_ip_name_dict: The map of IP addresses to hostnames of Moxa Nports moxa_ports_dict: The map of IP addresses to physical and COM port mappings """ - def insert_mappings(self, moxa_ip_name_dict, moxa_ports_dict): + def insert_mappings(self, moxa_ip_name_dict: dict, moxa_ports_dict: dict) -> None: print_and_log("inserting moxa mappings to SQL") self._delete_all() for moxa_name, moxa_ip in moxa_ip_name_dict.items(): @@ -109,13 +131,13 @@ def insert_mappings(self, moxa_ip_name_dict, moxa_ports_dict): class MoxaData: MDPV = {"UPDATE_MM": {"type": "int"}} - def __init__(self, data_source, prefix): + def __init__(self, data_source: MoxaDataSource, prefix: str) -> None: """Constructor Args: - data_source (IocDataSource): The wrapper for the database that holds IOC information - procserver (ProcServWrapper): An instance of ProcServWrapper, used to start and stop IOCs - prefix (string): The pv prefix of the instrument the server is being run on + data_source: The wrapper for the database that holds IOC information + procserver: An instance of ProcServWrapper, used to start and stop IOCs + prefix: The pv prefix of the instrument the server is being run on """ self._moxa_data_source = data_source self._prefix = prefix @@ -132,16 +154,17 @@ def __init__(self, data_source, prefix): Gets the mappings and inserts them into SQL """ - def update_mappings(self): + def update_mappings(self) -> None: print_and_log("updating moxa mappings") self._mappings = self._get_mappings() self._moxa_data_source.insert_mappings(*self._get_mappings()) """ - Returns the IP to hostname and IP to port mappings as a string representation for use with the MOXA_MAPPINGS PV + Returns the IP to hostname and IP to port mappings as a string representation for use with the + MOXA_MAPPINGS PV """ - def _get_mappings_str(self): + def _get_mappings_str(self) -> Dict: with self._snmp_lock: return self._snmp_map @@ -149,33 +172,34 @@ def _get_mappings_str(self): ran as background thread to update _snmp_map """ - def _update_snmp(self): + def _update_snmp(self) -> None: while True: - # it is much easier to parse the mappings if they just look like a key:{key, val} list, so lets do that now rather than in the GUI + # it is much easier to parse the mappings if they just look like a key:{key, val} list, + # so lets do that now rather than in the GUI newmap = dict() for hostname, mappings in self._mappings[1].items(): ip_addr = self._mappings[0][hostname] mibmap = walk(ip_addr, "1.3.6.1.2.1", SYSTEM_MIBS + PORT_MIBS) # Some defensive coding to avoid errors if SNMP walk fails - upTime = "" + up_time = "" if "DISMAN-EXPRESSION-MIB::sysUpTimeInstance" in mibmap: - upTime = mibmap["DISMAN-EXPRESSION-MIB::sysUpTimeInstance"] - sysName = "" - if "SNMPv2-MIB::sysName.0" in mibmap: - sysName = mibmap["SNMPv2-MIB::sysName.0"] + up_time = mibmap["DISMAN-EXPRESSION-MIB::sysUpTimeInstance"] + sys_name = "" + if "SNMPv2-MIB::sys_name.0" in mibmap: + sys_name = mibmap["SNMPv2-MIB::sys_name.0"] newkey = f"{hostname}({ip_addr})" - if len(upTime) > 0: - newkey = f"{hostname}({ip_addr} - {sysName})({upTime})" + if len(up_time) > 0: + newkey = f"{hostname}({ip_addr} - {sys_name})({up_time})" newmap[newkey] = [] for coms in mappings: - additionalInfo = "" + additional_info = "" for mib in PORT_MIBS: - portMIB = int(str(coms[0])) + 1 - key = mib + "." + str(portMIB) + port_mib = int(str(coms[0])) + 1 + key = mib + "." + str(port_mib) if key in mibmap: - additionalInfo += mib + "=" + mibmap[key] + "~" - if len(additionalInfo) > 0: - newmap[newkey].append([str(coms[0]), f"COM{coms[1]}~{additionalInfo}"]) + additional_info += mib + "=" + mibmap[key] + "~" + if len(additional_info) > 0: + newmap[newkey].append([str(coms[0]), f"COM{coms[1]}~{additional_info}"]) else: newmap[newkey].append([str(coms[0]), f"COM{coms[1]}"]) @@ -194,51 +218,48 @@ def _get_hostname(self, ip_addr: str) -> str: print(f"unknown hostname for IP address {ip_addr}") return "unknown" - def _get_mappings(self) -> Tuple[Dict[str, str], Dict[int, List[Tuple[int, int]]]]: + def _get_mappings(self) -> Tuple[Dict[str, str], Dict[str, List[Tuple[int, int]]]]: # moxa_name_ip_dict: HOSTNAME:IPADDR # moxa_ports_dict: HOSTNAME:[(PHYSPORT:COMPORT),...] - moxa_name_ip_dict = dict() - moxa_ports_dict = dict() + moxa_name_ip_dict: Dict[str, str] = {} + moxa_ports_dict: Dict[str, List[Tuple[int, int]]] = {} if os.name == "nt": import winreg as wrg location = wrg.HKEY_LOCAL_MACHINE - - using_npdrv2 = False - ports_count = 0 try: # Try and find whether the npdrv2 subkey exists to determine whether we are using # the Nport Driver manager as opposed to Nport Administrator ports_path = wrg.OpenKeyEx(location, REG_DIR_NPDRV2) - using_npdrv2 = True + ports_count = wrg.QueryInfoKey(ports_path)[0] + + # This is what Nport Windows Driver manager uses. It uses a subkey for each port + # mappping, each of which has an ip address referenced. It doesn't seem to have + # a physical port number as the ports are added individually, so we have to + # modulo the port number. + for port_num in range(0, ports_count): + port_subkey = f"{port_num:04d}" + port_reg = wrg.OpenKeyEx(ports_path, port_subkey) + device_params = wrg.OpenKeyEx(port_reg, "Device Parameters") + ip_addr = wrg.QueryValueEx(device_params, "IPAddress1")[0] + com_num = wrg.QueryValueEx(device_params, "COMNO")[0] + hostname = self._get_hostname(ip_addr) + + moxa_name_ip_dict[hostname] = ip_addr + + if hostname not in moxa_ports_dict.keys(): + moxa_ports_dict[hostname] = list() + # Modulo by 16 here as we want the 2nd moxa's first port_num to be 1 rather + # than 17 as it's the first port on the second moxa + port_num_respective = port_num % 16 + moxa_ports_dict[hostname].append((port_num_respective + 1, com_num)) except FileNotFoundError: print_and_log("using old style registry for moxas", severity=SEVERITY.MINOR) - try: - if using_npdrv2: - # This is what Nport Windows Driver manager uses. It uses a subkey for each port mappping, - # each of which has an ip address referenced. It doesn't seem to have a physical port number - # as the ports are added individually, so we have to modulo the port number. - for port_num in range(0, ports_count): - port_subkey = f"{port_num:04d}" - port_reg = wrg.OpenKeyEx(ports_path, port_subkey) - device_params = wrg.OpenKeyEx(port_reg, "Device Parameters") - ip_addr = wrg.QueryValueEx(device_params, "IPAddress1")[0] - com_num = wrg.QueryValueEx(device_params, "COMNO")[0] - hostname = self._get_hostname(ip_addr) - - moxa_name_ip_dict[hostname] = ip_addr - - if hostname not in moxa_ports_dict.keys(): - moxa_ports_dict[hostname] = list() - # Modulo by 16 here as we want the 2nd moxa's first port_num to be 1 rather - # than 17 as it's the first port on the second moxa - port_num_respective = port_num % 16 - moxa_ports_dict[hostname].append((port_num_respective + 1, com_num)) - - else: - # This is what Nport Administrator uses. It lays out each Moxa that is added to "Servers" which contains a few bytes - # and lays things out in a subkey for each. + else: + try: + # This is what Nport Administrator uses. It lays out each Moxa that is added to + # "Servers" which contains a few bytes and lays things out in a subkey for each. params = wrg.OpenKeyEx(location, REG_KEY_NPDRV) server_count = wrg.QueryValueEx(params, "Servers")[0] @@ -254,10 +275,10 @@ def _get_mappings(self) -> Tuple[Dict[str, str], Dict[int, List[Tuple[int, int]] moxa_ports_dict[hostname] = list(com_nums) for count, value in com_nums: print_and_log(f"physical port {count} COM number {value}") - except FileNotFoundError as e: - print_and_log( - f"Error reading registry for moxa mapping information: {str(e)}", - severity=SEVERITY.MAJOR, - ) + except FileNotFoundError as e: + print_and_log( + f"Error reading registry for moxa mapping information: {str(e)}", + severity=SEVERITY.MAJOR, + ) return moxa_name_ip_dict, moxa_ports_dict diff --git a/DatabaseServer/options_holder.py b/DatabaseServer/options_holder.py index 517a72f2..3d0d024b 100644 --- a/DatabaseServer/options_holder.py +++ b/DatabaseServer/options_holder.py @@ -21,7 +21,7 @@ class OptionsHolder(object): """Holds all the IOC options""" - def __init__(self, options_folder: str, options_loader: OptionsLoader): + def __init__(self, options_folder: str, options_loader: OptionsLoader) -> None: """Constructor Args: diff --git a/DatabaseServer/options_loader.py b/DatabaseServer/options_loader.py index cad45d8b..117f2e8b 100644 --- a/DatabaseServer/options_loader.py +++ b/DatabaseServer/options_loader.py @@ -16,8 +16,8 @@ # https://www.eclipse.org/org/documents/epl-v10.php or # http://opensource.org/licenses/eclipse-1.0.php import os -import xml from collections import OrderedDict +from xml.etree.ElementTree import Element from DatabaseServer.ioc_options import IocOptions from server_common.utilities import parse_xml_removing_namespace, print_and_log @@ -61,7 +61,7 @@ def get_options(path: str) -> OrderedDict: return iocs @staticmethod - def _options_from_xml(root_xml: xml.etree.ElementTree.Element, iocs: OrderedDict) -> None: + def _options_from_xml(root_xml: Element, iocs: OrderedDict) -> None: """Populates the supplied list of iocs based on an XML tree within a config.xml file""" for ioc in root_xml.findall("./" + TAG_IOC_CONFIG): name = ioc.attrib[TAG_NAME] diff --git a/DatabaseServer/procserv_utils.py b/DatabaseServer/procserv_utils.py index bb2142f4..6b530e62 100644 --- a/DatabaseServer/procserv_utils.py +++ b/DatabaseServer/procserv_utils.py @@ -76,6 +76,7 @@ def get_ioc_status(self, prefix: str, ioc: str) -> str: ans = ChannelAccess.caget(pv, as_string=True) if ans is None: raise IOError("Could not find IOC (%s)" % pv) + assert isinstance(ans, str) return ans.upper() def ioc_exists(self, prefix: str, ioc: str) -> bool: @@ -91,5 +92,5 @@ def ioc_exists(self, prefix: str, ioc: str) -> bool: try: self.get_ioc_status(prefix, ioc) return True - except: + except IOError: return False diff --git a/DatabaseServer/test_modules/__init__.py b/DatabaseServer/test_modules/__init__.py index 16f1eab9..99648879 100644 --- a/DatabaseServer/test_modules/__init__.py +++ b/DatabaseServer/test_modules/__init__.py @@ -1,5 +1,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from unittest import TestLoader, TestSuite + # This file is part of the ISIS IBEX application. # Copyright (C) 2012-2016 Science & Technology Facilities Council. # All rights reserved. @@ -16,10 +21,9 @@ # https://www.eclipse.org/org/documents/epl-v10.php or # http://opensource.org/licenses/eclipse-1.0.php import os -import unittest -def load_tests(loader, standard_tests, pattern): +def load_tests(loader: "TestLoader", standard_tests: "TestSuite", pattern: str) -> "TestSuite": """ This function is needed by the load_tests protocol described at https://docs.python.org/3/library/unittest.html#load-tests-protocol diff --git a/DatabaseServer/test_modules/test_exp_data.py b/DatabaseServer/test_modules/test_exp_data.py index 2d72c70c..193b2858 100644 --- a/DatabaseServer/test_modules/test_exp_data.py +++ b/DatabaseServer/test_modules/test_exp_data.py @@ -68,7 +68,7 @@ def test_update_experiment_id_throws_if_experiment_does_not_exists(self): try: self.exp_data.update_experiment_id("000000") self.fail("Setting invalid experiment id did not throw") - except: + except Exception: pass def test_single_surname_returns_surname(self):