Skip to content

Commit af27df0

Browse files
authored
Merge pull request #46 from ISISComputingGroup/Ticket8553_Fix_ruff_and_pyright_in_genie_epics_api
Fix pyright and ruff for genie_epics_api
2 parents 07ad083 + b621085 commit af27df0

File tree

2 files changed

+89
-34
lines changed

2 files changed

+89
-34
lines changed

src/genie_python/block_names.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class BlockNamesManager:
2121

2222
def __init__(
2323
self,
24-
block_names: "PVValue",
24+
block_names: "BlockNames",
2525
delay_before_retry_add_monitor: float = DELAY_BEFORE_RETRYING_BLOCK_NAMES_PV_ON_FAIL,
2626
) -> None:
2727
"""
@@ -100,6 +100,7 @@ def _update_block_names(self, value: "PVValue", _: Optional[str], _1: Optional[s
100100

101101
# add new block as attributes to class
102102
try:
103+
assert isinstance(value, (str, bytes)), value
103104
block_names = dehex_decompress_and_dejson(value)
104105
for name in block_names:
105106
attribute_name = name

src/genie_python/genie_epics_api.py

Lines changed: 87 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from builtins import str
1111
from collections import OrderedDict
1212
from io import open
13-
from typing import TYPE_CHECKING, Callable
13+
from typing import TYPE_CHECKING, Any, Callable
1414

1515
from genie_python.block_names import BlockNames, BlockNamesManager
1616
from genie_python.channel_access_exceptions import UnableToConnectToPVException
@@ -22,7 +22,7 @@
2222
from genie_python.genie_logging import filter as logging_filter
2323
from genie_python.genie_pre_post_cmd_manager import PrePostCmdManager
2424
from genie_python.genie_wait_for_move import WaitForMoveController
25-
from genie_python.genie_waitfor import WaitForController
25+
from genie_python.genie_waitfor import WAITFOR_VALUE, WaitForController
2626
from genie_python.utilities import (
2727
EnvironmentDetails,
2828
crc8,
@@ -46,7 +46,10 @@
4646

4747
class API(object):
4848
def __init__(
49-
self, pv_prefix: str, globs: dict, environment_details: EnvironmentDetails | None = None
49+
self,
50+
pv_prefix: str,
51+
globs: dict[str, Any],
52+
environment_details: EnvironmentDetails | None = None,
5053
) -> None:
5154
"""
5255
Constructor for the EPICS enabled API.
@@ -56,12 +59,12 @@ def __init__(
5659
globs: globals
5760
environment_details: details of the computer environment
5861
"""
59-
self.waitfor = None # type: WaitForController
60-
self.wait_for_move = None
61-
self.dae = None # type: Dae
62-
self.blockserver = None # type: BlockServer
63-
self.exp_data = None # type: GetExperimentData
64-
self.inst_prefix = ""
62+
self.waitfor: WaitForController | None = None
63+
self.wait_for_move: WaitForMoveController | None = None
64+
self.dae: Dae | None = None
65+
self.blockserver: BlockServer | None = None
66+
self.exp_data: GetExperimentData | None = None
67+
self.inst_prefix: str = ""
6568
self.instrument_name = ""
6669
self.machine_name = ""
6770
self.localmod = None
@@ -108,7 +111,9 @@ def get_instrument_py_name(self) -> str:
108111
"""
109112
return self.instrument_name.lower().replace("-", "_")
110113

111-
def _get_machine_details_from_identifier(self, machine_identifier: str) -> tuple[str, str, str]:
114+
def _get_machine_details_from_identifier(
115+
self, machine_identifier: str | None
116+
) -> tuple[str, str, str]:
112117
"""
113118
Gets the details of a machine by looking it up in the instrument list first.
114119
If there is no match it calculates the details as usual.
@@ -137,7 +142,9 @@ def _get_machine_details_from_identifier(self, machine_identifier: str) -> tuple
137142
# that's been passed to this function if it is not found instrument_details will be None
138143
instrument_details = None
139144
try:
140-
instrument_list = dehex_decompress_and_dejson(self.get_pv_value("CS:INSTLIST"))
145+
input_list = self.get_pv_value("CS:INSTLIST")
146+
assert isinstance(input_list, str | bytes)
147+
instrument_list = dehex_decompress_and_dejson(input_list)
141148
instrument_details = next(
142149
(inst for inst in instrument_list if inst["pvPrefix"] == machine_identifier), None
143150
)
@@ -182,7 +189,7 @@ def get_instrument_full_name(self) -> str:
182189
return self.machine_name
183190

184191
def set_instrument(
185-
self, machine_identifier: str, globs: dict, import_instrument_init: bool = True
192+
self, machine_identifier: str, globs: dict[str, Any], import_instrument_init: bool = True
186193
) -> None:
187194
"""
188195
Set the instrument being used by setting the PV prefix or by the
@@ -262,12 +269,14 @@ def prefix_pv_name(self, name: str) -> str:
262269
"""
263270
Adds the instrument prefix to the specified PV.
264271
"""
265-
if self.inst_prefix is not None:
266-
return self.inst_prefix + name
267-
return name
272+
return self.inst_prefix + name
268273

269274
def init_instrument(
270-
self, instrument: str, machine_name: str, globs: dict, import_instrument_init: bool
275+
self,
276+
instrument: str,
277+
machine_name: str,
278+
globs: dict[str, Any],
279+
import_instrument_init: bool,
271280
) -> None:
272281
"""
273282
Initialise an instrument using the default init file followed by the machine specific init.
@@ -298,10 +307,14 @@ def init_instrument(
298307
# Load the instrument init file
299308
self.localmod = importlib.import_module("init_{}".format(instrument))
300309

301-
if self.localmod.__file__.endswith(".pyc"):
302-
file_loc = self.localmod.__file__[:-1]
310+
_file = self.localmod.__file__
311+
assert _file is not None
312+
313+
if _file.endswith(".pyc"):
314+
file_loc = _file[:-1]
303315
else:
304-
file_loc = self.localmod.__file__
316+
file_loc = _file
317+
assert isinstance(file_loc, str)
305318
# execfile - this puts any imports in the init file into the globals namespace
306319
# Note: Anything loose in the module like print statements will be run twice
307320
exec(compile(open(file_loc).read(), file_loc, "exec"), globs)
@@ -351,6 +364,26 @@ def set_pv_value(
351364
self.logger.log_error_msg("set_pv_value exception {!r}".format(e))
352365
raise e
353366

367+
@typing.overload
368+
def get_pv_value(
369+
self,
370+
name: str,
371+
to_string: typing.Literal[True] = True,
372+
attempts: int = 3,
373+
is_local: bool = True,
374+
use_numpy: bool | None = None,
375+
) -> str: ...
376+
377+
@typing.overload
378+
def get_pv_value(
379+
self,
380+
name: str,
381+
to_string: bool = False,
382+
attempts: int = 3,
383+
is_local: bool = False,
384+
use_numpy: bool | None = None,
385+
) -> "PVValue": ...
386+
354387
def get_pv_value(
355388
self,
356389
name: str,
@@ -448,6 +481,7 @@ def reload_current_config(self) -> None:
448481
"""
449482
Reload the current configuration.
450483
"""
484+
assert self.blockserver is not None
451485
self.blockserver.reload_current_config()
452486

453487
def correct_blockname(self, name: str, add_prefix: bool = True) -> str:
@@ -472,7 +506,7 @@ def get_block_names(self) -> list[str]:
472506
"""
473507
return [name for name in BLOCK_NAMES.__dict__.keys()]
474508

475-
def block_exists(self, name: str, fail_fast: bool = False) -> str | None:
509+
def block_exists(self, name: str, fail_fast: bool = False) -> bool:
476510
"""
477511
Checks whether the block exists.
478512
@@ -510,6 +544,8 @@ def set_block_value(
510544
full_name = self.get_pv_from_block(name)
511545

512546
if lowlimit is not None and highlimit is not None:
547+
if not isinstance(value, (float, int)):
548+
raise ValueError("Both limits provided but value is not a number")
513549
if lowlimit > highlimit:
514550
print(
515551
"Low limit ({}) higher than high limit ({}), "
@@ -531,6 +567,9 @@ def set_block_value(
531567
self.set_pv_value(full_name, value)
532568

533569
if wait:
570+
if not isinstance(value, WAITFOR_VALUE):
571+
raise ValueError(f"Wait value is not a WAITFOR_VALUE: {value}")
572+
assert self.waitfor is not None
534573
self.waitfor.start_waiting(name, value, lowlimit, highlimit)
535574
return
536575

@@ -599,14 +638,18 @@ def get_block_units(self, block_name: str) -> str | None:
599638
return typing.cast(str | None, Wrapper.get_pv_value(unit_name))
600639

601640
def _get_pars(
602-
self, pv_prefix_identifier: str, get_names_from_blockserver: Callable[[], list[str]]
603-
) -> dict:
641+
self, pv_prefix_identifier: str, get_names_from_blockserver: Callable[[], Any]
642+
) -> dict[str, "PVValue"]:
604643
"""
605644
Get the current parameter values for a given pv subset as a dictionary.
606645
"""
607646
names = get_names_from_blockserver()
608647
ans = {}
609-
if names is not None:
648+
if (
649+
names is not None
650+
and isinstance(names, list)
651+
and all(isinstance(elem, str) for elem in names)
652+
):
610653
for n in names:
611654
val = self.get_pv_value(self.prefix_pv_name(n))
612655
m = re.match(".+:" + pv_prefix_identifier + ":(.+)", n)
@@ -618,10 +661,11 @@ def _get_pars(
618661
)
619662
return ans
620663

621-
def get_sample_pars(self) -> dict:
664+
def get_sample_pars(self) -> dict[str, "PVValue"]:
622665
"""
623666
Get the current sample parameter values as a dictionary.
624667
"""
668+
assert self.blockserver is not None
625669
return self._get_pars("SAMPLE", self.blockserver.get_sample_par_names)
626670

627671
def set_sample_par(self, name: str, value: "PVValue") -> None:
@@ -632,8 +676,13 @@ def set_sample_par(self, name: str, value: "PVValue") -> None:
632676
name: the name of the parameter to change
633677
value: the new value
634678
"""
679+
assert self.blockserver is not None
635680
names = self.blockserver.get_sample_par_names()
636-
if names is not None:
681+
if (
682+
names is not None
683+
and isinstance(names, list)
684+
and all(isinstance(elem, str) for elem in names)
685+
):
637686
for n in names:
638687
m = re.match(".+:SAMPLE:%s" % name.upper(), n)
639688
if m is not None:
@@ -642,10 +691,11 @@ def set_sample_par(self, name: str, value: "PVValue") -> None:
642691
return
643692
raise Exception("Sample parameter %s does not exist" % name)
644693

645-
def get_beamline_pars(self) -> dict:
694+
def get_beamline_pars(self) -> dict[str, "PVValue"]:
646695
"""
647696
Get the current beamline parameter values as a dictionary.
648697
"""
698+
assert self.blockserver is not None
649699
return self._get_pars("BL", self.blockserver.get_beamline_par_names)
650700

651701
def set_beamline_par(self, name: str, value: "PVValue") -> None:
@@ -656,6 +706,7 @@ def set_beamline_par(self, name: str, value: "PVValue") -> None:
656706
name: the name of the parameter to change
657707
value: the new value
658708
"""
709+
assert self.blockserver is not None
659710
names = self.blockserver.get_beamline_par_names()
660711
if names is not None:
661712
for n in names:
@@ -665,7 +716,7 @@ def set_beamline_par(self, name: str, value: "PVValue") -> None:
665716
return
666717
raise Exception("Beamline parameter %s does not exist" % name)
667718

668-
def get_runcontrol_settings(self, block_name: str) -> tuple[bool, float, float]:
719+
def get_runcontrol_settings(self, block_name: str) -> tuple["PVValue", "PVValue", "PVValue"]:
669720
"""
670721
Gets the current run-control settings for a block.
671722
@@ -711,19 +762,20 @@ def check_limit_violations(self, blocks: list[str]) -> list[str]:
711762
list: the blocks which have soft limit violations
712763
"""
713764
violation_states = self._get_fields_from_blocks(blocks, "LVIO", "limit violation")
714-
return [t[0] for t in violation_states if t[1] == 1]
765+
766+
return [t[0] for t in violation_states if t[1]]
715767

716768
def _get_fields_from_blocks(
717769
self, blocks: list[str], field_name: str, field_description: str
718-
) -> list["PVValue"]:
770+
) -> list[tuple[str, "PVValue"]]:
719771
field_values = list()
720772
for block in blocks:
721773
if self.block_exists(block):
722774
block_name = self.correct_blockname(block, False)
723775
full_block_pv = self.get_pv_from_block(block)
724776
try:
725777
field_value = self.get_pv_value(full_block_pv + "." + field_name, attempts=1)
726-
field_values.append([block_name, field_value])
778+
field_values.append((block_name, field_value))
727779
except IOError:
728780
# Could not get value
729781
print("Could not get {} for block: {}".format(field_description, block))
@@ -817,7 +869,7 @@ def send_email(self, address: str, message: str) -> None:
817869
except Exception as e:
818870
raise Exception("Could not send email: {}".format(e))
819871

820-
def send_alert(self, message: str, inst: str) -> None:
872+
def send_alert(self, message: str, inst: str | None) -> None:
821873
"""
822874
Sends an alert message for a specified instrument.
823875
@@ -860,13 +912,15 @@ def get_pv_alarm(self, pv_name: str) -> str:
860912
alarm status could not be determined
861913
"""
862914
try:
863-
return self.get_pv_value(
915+
alarm_val = self.get_pv_value(
864916
"{}.SEVR".format(remove_field_from_pv(pv_name)), to_string=True
865917
)
918+
return alarm_val
919+
866920
except Exception:
867921
return "UNKNOWN"
868922

869-
def get_block_data(self, block: str, fail_fast: bool = False) -> dict:
923+
def get_block_data(self, block: str, fail_fast: bool = False) -> dict[str, "PVValue"]:
870924
"""
871925
Gets the useful values associated with a block.
872926

0 commit comments

Comments
 (0)