From bd6972a90c5214a204ddb7637f7fcab2576df1c5 Mon Sep 17 00:00:00 2001 From: Lowri Jenkins Date: Thu, 22 May 2025 09:50:35 +0100 Subject: [PATCH 1/6] Fix pyright and ruff for genie_epics_api --- src/genie_python/block_names.py | 3 +- src/genie_python/genie_epics_api.py | 98 +++++++++++++++++++---------- 2 files changed, 67 insertions(+), 34 deletions(-) diff --git a/src/genie_python/block_names.py b/src/genie_python/block_names.py index 09f4882..fb38738 100644 --- a/src/genie_python/block_names.py +++ b/src/genie_python/block_names.py @@ -21,7 +21,7 @@ class BlockNamesManager: def __init__( self, - block_names: "PVValue", + block_names: "BlockNames", delay_before_retry_add_monitor: float = DELAY_BEFORE_RETRYING_BLOCK_NAMES_PV_ON_FAIL, ) -> None: """ @@ -100,6 +100,7 @@ def _update_block_names(self, value: "PVValue", _: Optional[str], _1: Optional[s # add new block as attributes to class try: + assert isinstance(value, str | bytes) block_names = dehex_decompress_and_dejson(value) for name in block_names: attribute_name = name diff --git a/src/genie_python/genie_epics_api.py b/src/genie_python/genie_epics_api.py index 40fb2b0..bb47a18 100644 --- a/src/genie_python/genie_epics_api.py +++ b/src/genie_python/genie_epics_api.py @@ -10,7 +10,7 @@ from builtins import str from collections import OrderedDict from io import open -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Any, Callable from genie_python.block_names import BlockNames, BlockNamesManager from genie_python.channel_access_exceptions import UnableToConnectToPVException @@ -22,7 +22,7 @@ from genie_python.genie_logging import filter as logging_filter from genie_python.genie_pre_post_cmd_manager import PrePostCmdManager from genie_python.genie_wait_for_move import WaitForMoveController -from genie_python.genie_waitfor import WaitForController +from genie_python.genie_waitfor import WAITFOR_VALUE, WaitForController from genie_python.utilities import ( EnvironmentDetails, crc8, @@ -46,7 +46,10 @@ class API(object): def __init__( - self, pv_prefix: str, globs: dict, environment_details: EnvironmentDetails | None = None + self, + pv_prefix: str, + globs: dict[str, Any], + environment_details: EnvironmentDetails | None = None, ) -> None: """ Constructor for the EPICS enabled API. @@ -56,12 +59,12 @@ def __init__( globs: globals environment_details: details of the computer environment """ - self.waitfor = None # type: WaitForController - self.wait_for_move = None - self.dae = None # type: Dae - self.blockserver = None # type: BlockServer - self.exp_data = None # type: GetExperimentData - self.inst_prefix = "" + self.waitfor: WaitForController | None = None # type: WaitForController + self.wait_for_move: WaitForMoveController | None = None + self.dae: Dae | None = None # type: Dae + self.blockserver: BlockServer | None = None # type: BlockServer + self.exp_data: GetExperimentData | None = None # type: GetExperimentData + self.inst_prefix: str = "" self.instrument_name = "" self.machine_name = "" self.localmod = None @@ -108,7 +111,9 @@ def get_instrument_py_name(self) -> str: """ return self.instrument_name.lower().replace("-", "_") - def _get_machine_details_from_identifier(self, machine_identifier: str) -> tuple[str, str, str]: + def _get_machine_details_from_identifier( + self, machine_identifier: str | None + ) -> tuple[str, str, str]: """ Gets the details of a machine by looking it up in the instrument list first. If there is no match it calculates the details as usual. @@ -137,7 +142,9 @@ def _get_machine_details_from_identifier(self, machine_identifier: str) -> tuple # that's been passed to this function if it is not found instrument_details will be None instrument_details = None try: - instrument_list = dehex_decompress_and_dejson(self.get_pv_value("CS:INSTLIST")) + input_list = self.get_pv_value("CS:INSTLIST") + assert isinstance(input_list, str | bytes) + instrument_list = dehex_decompress_and_dejson(input_list) instrument_details = next( (inst for inst in instrument_list if inst["pvPrefix"] == machine_identifier), None ) @@ -182,7 +189,7 @@ def get_instrument_full_name(self) -> str: return self.machine_name def set_instrument( - self, machine_identifier: str, globs: dict, import_instrument_init: bool = True + self, machine_identifier: str, globs: dict[str, Any], import_instrument_init: bool = True ) -> None: """ Set the instrument being used by setting the PV prefix or by the @@ -262,12 +269,14 @@ def prefix_pv_name(self, name: str) -> str: """ Adds the instrument prefix to the specified PV. """ - if self.inst_prefix is not None: - return self.inst_prefix + name - return name + return self.inst_prefix + name def init_instrument( - self, instrument: str, machine_name: str, globs: dict, import_instrument_init: bool + self, + instrument: str, + machine_name: str, + globs: dict[str, Any], + import_instrument_init: bool, ) -> None: """ Initialise an instrument using the default init file followed by the machine specific init. @@ -298,10 +307,14 @@ def init_instrument( # Load the instrument init file self.localmod = importlib.import_module("init_{}".format(instrument)) - if self.localmod.__file__.endswith(".pyc"): - file_loc = self.localmod.__file__[:-1] + _file = self.localmod.__file__ + assert _file is not None + + if _file.endswith(".pyc"): + file_loc = _file[:-1] else: - file_loc = self.localmod.__file__ + file_loc = _file + assert isinstance(file_loc, str) # execfile - this puts any imports in the init file into the globals namespace # Note: Anything loose in the module like print statements will be run twice exec(compile(open(file_loc).read(), file_loc, "exec"), globs) @@ -448,6 +461,7 @@ def reload_current_config(self) -> None: """ Reload the current configuration. """ + assert self.blockserver is not None self.blockserver.reload_current_config() def correct_blockname(self, name: str, add_prefix: bool = True) -> str: @@ -472,7 +486,7 @@ def get_block_names(self) -> list[str]: """ return [name for name in BLOCK_NAMES.__dict__.keys()] - def block_exists(self, name: str, fail_fast: bool = False) -> str | None: + def block_exists(self, name: str, fail_fast: bool = False) -> bool: """ Checks whether the block exists. @@ -510,6 +524,7 @@ def set_block_value( full_name = self.get_pv_from_block(name) if lowlimit is not None and highlimit is not None: + assert isinstance(value, (float, int)) if lowlimit > highlimit: print( "Low limit ({}) higher than high limit ({}), " @@ -531,6 +546,8 @@ def set_block_value( self.set_pv_value(full_name, value) if wait: + assert isinstance(value, WAITFOR_VALUE) + assert self.waitfor is not None self.waitfor.start_waiting(name, value, lowlimit, highlimit) return @@ -599,14 +616,18 @@ def get_block_units(self, block_name: str) -> str | None: return typing.cast(str | None, Wrapper.get_pv_value(unit_name)) def _get_pars( - self, pv_prefix_identifier: str, get_names_from_blockserver: Callable[[], list[str]] - ) -> dict: + self, pv_prefix_identifier: str, get_names_from_blockserver: Callable[[], Any] + ) -> dict[str, PVValue]: """ Get the current parameter values for a given pv subset as a dictionary. """ names = get_names_from_blockserver() ans = {} - if names is not None: + if ( + names is not None + and isinstance(names, list) + and all(isinstance(elem, str) for elem in names) + ): for n in names: val = self.get_pv_value(self.prefix_pv_name(n)) m = re.match(".+:" + pv_prefix_identifier + ":(.+)", n) @@ -618,10 +639,11 @@ def _get_pars( ) return ans - def get_sample_pars(self) -> dict: + def get_sample_pars(self) -> dict[str, PVValue]: """ Get the current sample parameter values as a dictionary. """ + assert self.blockserver is not None return self._get_pars("SAMPLE", self.blockserver.get_sample_par_names) def set_sample_par(self, name: str, value: "PVValue") -> None: @@ -632,8 +654,13 @@ def set_sample_par(self, name: str, value: "PVValue") -> None: name: the name of the parameter to change value: the new value """ + assert self.blockserver is not None names = self.blockserver.get_sample_par_names() - if names is not None: + if ( + names is not None + and isinstance(names, list) + and all(isinstance(elem, str) for elem in names) + ): for n in names: m = re.match(".+:SAMPLE:%s" % name.upper(), n) if m is not None: @@ -642,10 +669,11 @@ def set_sample_par(self, name: str, value: "PVValue") -> None: return raise Exception("Sample parameter %s does not exist" % name) - def get_beamline_pars(self) -> dict: + def get_beamline_pars(self) -> dict[str, PVValue]: """ Get the current beamline parameter values as a dictionary. """ + assert self.blockserver is not None return self._get_pars("BL", self.blockserver.get_beamline_par_names) def set_beamline_par(self, name: str, value: "PVValue") -> None: @@ -656,6 +684,7 @@ def set_beamline_par(self, name: str, value: "PVValue") -> None: name: the name of the parameter to change value: the new value """ + assert self.blockserver is not None names = self.blockserver.get_beamline_par_names() if names is not None: for n in names: @@ -665,7 +694,7 @@ def set_beamline_par(self, name: str, value: "PVValue") -> None: return raise Exception("Beamline parameter %s does not exist" % name) - def get_runcontrol_settings(self, block_name: str) -> tuple[bool, float, float]: + def get_runcontrol_settings(self, block_name: str) -> tuple[PVValue, PVValue, PVValue]: """ Gets the current run-control settings for a block. @@ -711,11 +740,11 @@ def check_limit_violations(self, blocks: list[str]) -> list[str]: list: the blocks which have soft limit violations """ violation_states = self._get_fields_from_blocks(blocks, "LVIO", "limit violation") - return [t[0] for t in violation_states if t[1] == 1] + return [t[0] for t in violation_states if t[1] == "1"] def _get_fields_from_blocks( self, blocks: list[str], field_name: str, field_description: str - ) -> list["PVValue"]: + ) -> list[str]: field_values = list() for block in blocks: if self.block_exists(block): @@ -723,7 +752,7 @@ def _get_fields_from_blocks( full_block_pv = self.get_pv_from_block(block) try: field_value = self.get_pv_value(full_block_pv + "." + field_name, attempts=1) - field_values.append([block_name, field_value]) + field_values.append([block_name, str(field_value)]) except IOError: # Could not get value print("Could not get {} for block: {}".format(field_description, block)) @@ -817,7 +846,7 @@ def send_email(self, address: str, message: str) -> None: except Exception as e: raise Exception("Could not send email: {}".format(e)) - def send_alert(self, message: str, inst: str) -> None: + def send_alert(self, message: str, inst: str | None) -> None: """ Sends an alert message for a specified instrument. @@ -860,13 +889,16 @@ def get_pv_alarm(self, pv_name: str) -> str: alarm status could not be determined """ try: - return self.get_pv_value( + alarm_val = self.get_pv_value( "{}.SEVR".format(remove_field_from_pv(pv_name)), to_string=True ) + assert alarm_val is str + return alarm_val + except Exception: return "UNKNOWN" - def get_block_data(self, block: str, fail_fast: bool = False) -> dict: + def get_block_data(self, block: str, fail_fast: bool = False) -> dict[str, PVValue]: """ Gets the useful values associated with a block. From 58e55c84d560ecf7c3e5baa09bf006d48c9f453a Mon Sep 17 00:00:00 2001 From: Lowri Jenkins Date: Thu, 22 May 2025 09:56:10 +0100 Subject: [PATCH 2/6] Made PVValue a string --- src/genie_python/genie_epics_api.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/genie_python/genie_epics_api.py b/src/genie_python/genie_epics_api.py index bb47a18..29499d9 100644 --- a/src/genie_python/genie_epics_api.py +++ b/src/genie_python/genie_epics_api.py @@ -617,7 +617,7 @@ def get_block_units(self, block_name: str) -> str | None: def _get_pars( self, pv_prefix_identifier: str, get_names_from_blockserver: Callable[[], Any] - ) -> dict[str, PVValue]: + ) -> dict[str, "PVValue"]: """ Get the current parameter values for a given pv subset as a dictionary. """ @@ -639,7 +639,7 @@ def _get_pars( ) return ans - def get_sample_pars(self) -> dict[str, PVValue]: + def get_sample_pars(self) -> dict[str, "PVValue"]: """ Get the current sample parameter values as a dictionary. """ @@ -669,7 +669,7 @@ def set_sample_par(self, name: str, value: "PVValue") -> None: return raise Exception("Sample parameter %s does not exist" % name) - def get_beamline_pars(self) -> dict[str, PVValue]: + def get_beamline_pars(self) -> dict[str, "PVValue"]: """ Get the current beamline parameter values as a dictionary. """ @@ -694,7 +694,7 @@ def set_beamline_par(self, name: str, value: "PVValue") -> None: return raise Exception("Beamline parameter %s does not exist" % name) - def get_runcontrol_settings(self, block_name: str) -> tuple[PVValue, PVValue, PVValue]: + def get_runcontrol_settings(self, block_name: str) -> tuple["PVValue", "PVValue", "PVValue"]: """ Gets the current run-control settings for a block. @@ -898,7 +898,7 @@ def get_pv_alarm(self, pv_name: str) -> str: except Exception: return "UNKNOWN" - def get_block_data(self, block: str, fail_fast: bool = False) -> dict[str, PVValue]: + def get_block_data(self, block: str, fail_fast: bool = False) -> dict[str, "PVValue"]: """ Gets the useful values associated with a block. From 7eece59acfa4d654bd7210f38c168d212157a181 Mon Sep 17 00:00:00 2001 From: Jack Doughty <56323305+jackbdoughty@users.noreply.github.com> Date: Fri, 23 May 2025 13:09:56 +0100 Subject: [PATCH 3/6] Fix tests --- src/genie_python/block_names.py | 2 +- src/genie_python/genie_epics_api.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/genie_python/block_names.py b/src/genie_python/block_names.py index fb38738..5f7f6a5 100644 --- a/src/genie_python/block_names.py +++ b/src/genie_python/block_names.py @@ -100,7 +100,7 @@ def _update_block_names(self, value: "PVValue", _: Optional[str], _1: Optional[s # add new block as attributes to class try: - assert isinstance(value, str | bytes) + assert isinstance(value, (str, bytes)), value block_names = dehex_decompress_and_dejson(value) for name in block_names: attribute_name = name diff --git a/src/genie_python/genie_epics_api.py b/src/genie_python/genie_epics_api.py index 29499d9..838688f 100644 --- a/src/genie_python/genie_epics_api.py +++ b/src/genie_python/genie_epics_api.py @@ -892,7 +892,7 @@ def get_pv_alarm(self, pv_name: str) -> str: alarm_val = self.get_pv_value( "{}.SEVR".format(remove_field_from_pv(pv_name)), to_string=True ) - assert alarm_val is str + assert isinstance(alarm_val, str) return alarm_val except Exception: From a7b400e9dc8d245ddab2fb428660068f116ae7b2 Mon Sep 17 00:00:00 2001 From: Jack Doughty Date: Fri, 23 May 2025 14:15:30 +0100 Subject: [PATCH 4/6] fix stuff --- src/genie_python/genie_epics_api.py | 42 ++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/src/genie_python/genie_epics_api.py b/src/genie_python/genie_epics_api.py index 838688f..0834d30 100644 --- a/src/genie_python/genie_epics_api.py +++ b/src/genie_python/genie_epics_api.py @@ -59,11 +59,11 @@ def __init__( globs: globals environment_details: details of the computer environment """ - self.waitfor: WaitForController | None = None # type: WaitForController + self.waitfor: WaitForController | None = None self.wait_for_move: WaitForMoveController | None = None - self.dae: Dae | None = None # type: Dae - self.blockserver: BlockServer | None = None # type: BlockServer - self.exp_data: GetExperimentData | None = None # type: GetExperimentData + self.dae: Dae | None = None + self.blockserver: BlockServer | None = None + self.exp_data: GetExperimentData | None = None self.inst_prefix: str = "" self.instrument_name = "" self.machine_name = "" @@ -364,6 +364,26 @@ def set_pv_value( self.logger.log_error_msg("set_pv_value exception {!r}".format(e)) raise e + @typing.overload + def get_pv_value( + self, + name: str, + to_string: typing.Literal[True] = True, + attempts: int = 3, + is_local: bool = True, + use_numpy: None = None, + ) -> str: ... + + @typing.overload + def get_pv_value( + self, + name: str, + to_string: bool = False, + attempts: int = 3, + is_local: bool = False, + use_numpy: bool | None = None, + ) -> "PVValue": ... + def get_pv_value( self, name: str, @@ -524,7 +544,8 @@ def set_block_value( full_name = self.get_pv_from_block(name) if lowlimit is not None and highlimit is not None: - assert isinstance(value, (float, int)) + if not isinstance(value, (float, int)): + raise ValueError("Both limits provided but value is not a number") if lowlimit > highlimit: print( "Low limit ({}) higher than high limit ({}), " @@ -546,7 +567,8 @@ def set_block_value( self.set_pv_value(full_name, value) if wait: - assert isinstance(value, WAITFOR_VALUE) + if not isinstance(value, WAITFOR_VALUE): + raise ValueError(f"Wait value is not a WAITFOR_VALUE: {value}") assert self.waitfor is not None self.waitfor.start_waiting(name, value, lowlimit, highlimit) return @@ -740,11 +762,12 @@ def check_limit_violations(self, blocks: list[str]) -> list[str]: list: the blocks which have soft limit violations """ violation_states = self._get_fields_from_blocks(blocks, "LVIO", "limit violation") - return [t[0] for t in violation_states if t[1] == "1"] + + return [t[0] for t in violation_states if typing.cast(bool, t[1]) == 1] def _get_fields_from_blocks( self, blocks: list[str], field_name: str, field_description: str - ) -> list[str]: + ) -> list[tuple[str, "PVValue"]]: field_values = list() for block in blocks: if self.block_exists(block): @@ -752,7 +775,7 @@ def _get_fields_from_blocks( full_block_pv = self.get_pv_from_block(block) try: field_value = self.get_pv_value(full_block_pv + "." + field_name, attempts=1) - field_values.append([block_name, str(field_value)]) + field_values.append((block_name, field_value)) except IOError: # Could not get value print("Could not get {} for block: {}".format(field_description, block)) @@ -892,7 +915,6 @@ def get_pv_alarm(self, pv_name: str) -> str: alarm_val = self.get_pv_value( "{}.SEVR".format(remove_field_from_pv(pv_name)), to_string=True ) - assert isinstance(alarm_val, str) return alarm_val except Exception: From 7a3c8e66afd917972a597aef177fb84916281874 Mon Sep 17 00:00:00 2001 From: Jack Doughty <56323305+jackbdoughty@users.noreply.github.com> Date: Fri, 23 May 2025 14:18:39 +0100 Subject: [PATCH 5/6] Update src/genie_python/genie_epics_api.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/genie_python/genie_epics_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genie_python/genie_epics_api.py b/src/genie_python/genie_epics_api.py index 0834d30..b8325cb 100644 --- a/src/genie_python/genie_epics_api.py +++ b/src/genie_python/genie_epics_api.py @@ -763,7 +763,7 @@ def check_limit_violations(self, blocks: list[str]) -> list[str]: """ violation_states = self._get_fields_from_blocks(blocks, "LVIO", "limit violation") - return [t[0] for t in violation_states if typing.cast(bool, t[1]) == 1] + return [t[0] for t in violation_states if t[1]] def _get_fields_from_blocks( self, blocks: list[str], field_name: str, field_description: str From 59c3a2ce638d141628a452c471a07a854b3bbf28 Mon Sep 17 00:00:00 2001 From: Jack Doughty Date: Fri, 23 May 2025 14:23:07 +0100 Subject: [PATCH 6/6] copilot suggestions --- src/genie_python/genie_epics_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/genie_python/genie_epics_api.py b/src/genie_python/genie_epics_api.py index 0834d30..7807e32 100644 --- a/src/genie_python/genie_epics_api.py +++ b/src/genie_python/genie_epics_api.py @@ -371,7 +371,7 @@ def get_pv_value( to_string: typing.Literal[True] = True, attempts: int = 3, is_local: bool = True, - use_numpy: None = None, + use_numpy: bool | None = None, ) -> str: ... @typing.overload