Skip to content

Fix pyright and ruff for genie_epics_api #46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 23, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/genie_python/block_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class BlockNamesManager:

def __init__(
self,
block_names: "PVValue",
block_names: "BlockNames",
delay_before_retry_add_monitor: float = DELAY_BEFORE_RETRYING_BLOCK_NAMES_PV_ON_FAIL,
) -> None:
"""
Expand Down Expand Up @@ -100,6 +100,7 @@ def _update_block_names(self, value: "PVValue", _: Optional[str], _1: Optional[s

# add new block as attributes to class
try:
assert isinstance(value, (str, bytes)), value
block_names = dehex_decompress_and_dejson(value)
for name in block_names:
attribute_name = name
Expand Down
120 changes: 87 additions & 33 deletions src/genie_python/genie_epics_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from builtins import str
from collections import OrderedDict
from io import open
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Callable

from genie_python.block_names import BlockNames, BlockNamesManager
from genie_python.channel_access_exceptions import UnableToConnectToPVException
Expand All @@ -22,7 +22,7 @@
from genie_python.genie_logging import filter as logging_filter
from genie_python.genie_pre_post_cmd_manager import PrePostCmdManager
from genie_python.genie_wait_for_move import WaitForMoveController
from genie_python.genie_waitfor import WaitForController
from genie_python.genie_waitfor import WAITFOR_VALUE, WaitForController
from genie_python.utilities import (
EnvironmentDetails,
crc8,
Expand All @@ -46,7 +46,10 @@

class API(object):
def __init__(
self, pv_prefix: str, globs: dict, environment_details: EnvironmentDetails | None = None
self,
pv_prefix: str,
globs: dict[str, Any],
environment_details: EnvironmentDetails | None = None,
) -> None:
"""
Constructor for the EPICS enabled API.
Expand All @@ -56,12 +59,12 @@
globs: globals
environment_details: details of the computer environment
"""
self.waitfor = None # type: WaitForController
self.wait_for_move = None
self.dae = None # type: Dae
self.blockserver = None # type: BlockServer
self.exp_data = None # type: GetExperimentData
self.inst_prefix = ""
self.waitfor: WaitForController | None = None
self.wait_for_move: WaitForMoveController | None = None
self.dae: Dae | None = None
self.blockserver: BlockServer | None = None
self.exp_data: GetExperimentData | None = None
self.inst_prefix: str = ""
self.instrument_name = ""
self.machine_name = ""
self.localmod = None
Expand Down Expand Up @@ -108,7 +111,9 @@
"""
return self.instrument_name.lower().replace("-", "_")

def _get_machine_details_from_identifier(self, machine_identifier: str) -> tuple[str, str, str]:
def _get_machine_details_from_identifier(
self, machine_identifier: str | None
) -> tuple[str, str, str]:
"""
Gets the details of a machine by looking it up in the instrument list first.
If there is no match it calculates the details as usual.
Expand Down Expand Up @@ -137,7 +142,9 @@
# that's been passed to this function if it is not found instrument_details will be None
instrument_details = None
try:
instrument_list = dehex_decompress_and_dejson(self.get_pv_value("CS:INSTLIST"))
input_list = self.get_pv_value("CS:INSTLIST")
assert isinstance(input_list, str | bytes)
instrument_list = dehex_decompress_and_dejson(input_list)
instrument_details = next(
(inst for inst in instrument_list if inst["pvPrefix"] == machine_identifier), None
)
Expand Down Expand Up @@ -182,7 +189,7 @@
return self.machine_name

def set_instrument(
self, machine_identifier: str, globs: dict, import_instrument_init: bool = True
self, machine_identifier: str, globs: dict[str, Any], import_instrument_init: bool = True
) -> None:
"""
Set the instrument being used by setting the PV prefix or by the
Expand Down Expand Up @@ -262,12 +269,14 @@
"""
Adds the instrument prefix to the specified PV.
"""
if self.inst_prefix is not None:
return self.inst_prefix + name
return name
return self.inst_prefix + name

def init_instrument(
self, instrument: str, machine_name: str, globs: dict, import_instrument_init: bool
self,
instrument: str,
machine_name: str,
globs: dict[str, Any],
import_instrument_init: bool,
) -> None:
"""
Initialise an instrument using the default init file followed by the machine specific init.
Expand Down Expand Up @@ -298,10 +307,14 @@
# Load the instrument init file
self.localmod = importlib.import_module("init_{}".format(instrument))

if self.localmod.__file__.endswith(".pyc"):
file_loc = self.localmod.__file__[:-1]
_file = self.localmod.__file__
assert _file is not None

if _file.endswith(".pyc"):
file_loc = _file[:-1]
else:
file_loc = self.localmod.__file__
file_loc = _file
assert isinstance(file_loc, str)
# execfile - this puts any imports in the init file into the globals namespace
# Note: Anything loose in the module like print statements will be run twice
exec(compile(open(file_loc).read(), file_loc, "exec"), globs)
Expand Down Expand Up @@ -351,6 +364,26 @@
self.logger.log_error_msg("set_pv_value exception {!r}".format(e))
raise e

@typing.overload
def get_pv_value(
self,
name: str,
to_string: typing.Literal[True] = True,
attempts: int = 3,
is_local: bool = True,
use_numpy: None = None,
) -> str: ...

@typing.overload
def get_pv_value(
self,
name: str,
to_string: bool = False,
attempts: int = 3,
is_local: bool = False,
use_numpy: bool | None = None,
) -> "PVValue": ...

def get_pv_value(
self,
name: str,
Expand Down Expand Up @@ -448,6 +481,7 @@
"""
Reload the current configuration.
"""
assert self.blockserver is not None
self.blockserver.reload_current_config()

def correct_blockname(self, name: str, add_prefix: bool = True) -> str:
Expand All @@ -472,7 +506,7 @@
"""
return [name for name in BLOCK_NAMES.__dict__.keys()]

def block_exists(self, name: str, fail_fast: bool = False) -> str | None:
def block_exists(self, name: str, fail_fast: bool = False) -> bool:
"""
Checks whether the block exists.

Expand Down Expand Up @@ -510,6 +544,8 @@
full_name = self.get_pv_from_block(name)

if lowlimit is not None and highlimit is not None:
if not isinstance(value, (float, int)):
raise ValueError("Both limits provided but value is not a number")
if lowlimit > highlimit:
print(
"Low limit ({}) higher than high limit ({}), "
Expand All @@ -531,6 +567,9 @@
self.set_pv_value(full_name, value)

if wait:
if not isinstance(value, WAITFOR_VALUE):
raise ValueError(f"Wait value is not a WAITFOR_VALUE: {value}")
assert self.waitfor is not None
self.waitfor.start_waiting(name, value, lowlimit, highlimit)
return

Expand Down Expand Up @@ -599,14 +638,18 @@
return typing.cast(str | None, Wrapper.get_pv_value(unit_name))

def _get_pars(
self, pv_prefix_identifier: str, get_names_from_blockserver: Callable[[], list[str]]
) -> dict:
self, pv_prefix_identifier: str, get_names_from_blockserver: Callable[[], Any]
) -> dict[str, "PVValue"]:
"""
Get the current parameter values for a given pv subset as a dictionary.
"""
names = get_names_from_blockserver()
ans = {}
if names is not None:
if (
names is not None
and isinstance(names, list)
and all(isinstance(elem, str) for elem in names)
):
for n in names:
val = self.get_pv_value(self.prefix_pv_name(n))
m = re.match(".+:" + pv_prefix_identifier + ":(.+)", n)
Expand All @@ -618,10 +661,11 @@
)
return ans

def get_sample_pars(self) -> dict:
def get_sample_pars(self) -> dict[str, "PVValue"]:
"""
Get the current sample parameter values as a dictionary.
"""
assert self.blockserver is not None
return self._get_pars("SAMPLE", self.blockserver.get_sample_par_names)

def set_sample_par(self, name: str, value: "PVValue") -> None:
Expand All @@ -632,8 +676,13 @@
name: the name of the parameter to change
value: the new value
"""
assert self.blockserver is not None
names = self.blockserver.get_sample_par_names()
if names is not None:
if (
names is not None
and isinstance(names, list)
and all(isinstance(elem, str) for elem in names)
):
for n in names:
m = re.match(".+:SAMPLE:%s" % name.upper(), n)
if m is not None:
Expand All @@ -642,10 +691,11 @@
return
raise Exception("Sample parameter %s does not exist" % name)

def get_beamline_pars(self) -> dict:
def get_beamline_pars(self) -> dict[str, "PVValue"]:
"""
Get the current beamline parameter values as a dictionary.
"""
assert self.blockserver is not None
return self._get_pars("BL", self.blockserver.get_beamline_par_names)

def set_beamline_par(self, name: str, value: "PVValue") -> None:
Expand All @@ -656,6 +706,7 @@
name: the name of the parameter to change
value: the new value
"""
assert self.blockserver is not None
names = self.blockserver.get_beamline_par_names()
if names is not None:
for n in names:
Expand All @@ -665,7 +716,7 @@
return
raise Exception("Beamline parameter %s does not exist" % name)

def get_runcontrol_settings(self, block_name: str) -> tuple[bool, float, float]:
def get_runcontrol_settings(self, block_name: str) -> tuple["PVValue", "PVValue", "PVValue"]:
"""
Gets the current run-control settings for a block.

Expand Down Expand Up @@ -711,19 +762,20 @@
list: the blocks which have soft limit violations
"""
violation_states = self._get_fields_from_blocks(blocks, "LVIO", "limit violation")
return [t[0] for t in violation_states if t[1] == 1]

return [t[0] for t in violation_states if typing.cast(bool, t[1]) == 1]

def _get_fields_from_blocks(
self, blocks: list[str], field_name: str, field_description: str
) -> list["PVValue"]:
) -> list[tuple[str, "PVValue"]]:
field_values = list()
for block in blocks:
if self.block_exists(block):
block_name = self.correct_blockname(block, False)
full_block_pv = self.get_pv_from_block(block)
try:
field_value = self.get_pv_value(full_block_pv + "." + field_name, attempts=1)
field_values.append([block_name, field_value])
field_values.append((block_name, field_value))
except IOError:
# Could not get value
print("Could not get {} for block: {}".format(field_description, block))
Expand Down Expand Up @@ -817,7 +869,7 @@
except Exception as e:
raise Exception("Could not send email: {}".format(e))

def send_alert(self, message: str, inst: str) -> None:
def send_alert(self, message: str, inst: str | None) -> None:
"""
Sends an alert message for a specified instrument.

Expand Down Expand Up @@ -860,13 +912,15 @@
alarm status could not be determined
"""
try:
return self.get_pv_value(
alarm_val = self.get_pv_value(
"{}.SEVR".format(remove_field_from_pv(pv_name)), to_string=True
)
return alarm_val

except Exception:
return "UNKNOWN"

def get_block_data(self, block: str, fail_fast: bool = False) -> dict:
def get_block_data(self, block: str, fail_fast: bool = False) -> dict[str, "PVValue"]:
"""
Gets the useful values associated with a block.

Expand Down
Loading