10
10
from builtins import str
11
11
from collections import OrderedDict
12
12
from io import open
13
- from typing import TYPE_CHECKING , Callable
13
+ from typing import TYPE_CHECKING , Any , Callable
14
14
15
15
from genie_python .block_names import BlockNames , BlockNamesManager
16
16
from genie_python .channel_access_exceptions import UnableToConnectToPVException
22
22
from genie_python .genie_logging import filter as logging_filter
23
23
from genie_python .genie_pre_post_cmd_manager import PrePostCmdManager
24
24
from genie_python .genie_wait_for_move import WaitForMoveController
25
- from genie_python .genie_waitfor import WaitForController
25
+ from genie_python .genie_waitfor import WAITFOR_VALUE , WaitForController
26
26
from genie_python .utilities import (
27
27
EnvironmentDetails ,
28
28
crc8 ,
46
46
47
47
class API (object ):
48
48
def __init__ (
49
- self , pv_prefix : str , globs : dict , environment_details : EnvironmentDetails | None = None
49
+ self ,
50
+ pv_prefix : str ,
51
+ globs : dict [str , Any ],
52
+ environment_details : EnvironmentDetails | None = None ,
50
53
) -> None :
51
54
"""
52
55
Constructor for the EPICS enabled API.
@@ -56,12 +59,12 @@ def __init__(
56
59
globs: globals
57
60
environment_details: details of the computer environment
58
61
"""
59
- self .waitfor = None # type: WaitForController
60
- self .wait_for_move = None
61
- self .dae = None # type: Dae
62
- self .blockserver = None # type: BlockServer
63
- self .exp_data = None # type: GetExperimentData
64
- self .inst_prefix = ""
62
+ self .waitfor : WaitForController | None = None
63
+ self .wait_for_move : WaitForMoveController | None = None
64
+ self .dae : Dae | None = None
65
+ self .blockserver : BlockServer | None = None
66
+ self .exp_data : GetExperimentData | None = None
67
+ self .inst_prefix : str = ""
65
68
self .instrument_name = ""
66
69
self .machine_name = ""
67
70
self .localmod = None
@@ -108,7 +111,9 @@ def get_instrument_py_name(self) -> str:
108
111
"""
109
112
return self .instrument_name .lower ().replace ("-" , "_" )
110
113
111
- def _get_machine_details_from_identifier (self , machine_identifier : str ) -> tuple [str , str , str ]:
114
+ def _get_machine_details_from_identifier (
115
+ self , machine_identifier : str | None
116
+ ) -> tuple [str , str , str ]:
112
117
"""
113
118
Gets the details of a machine by looking it up in the instrument list first.
114
119
If there is no match it calculates the details as usual.
@@ -137,7 +142,9 @@ def _get_machine_details_from_identifier(self, machine_identifier: str) -> tuple
137
142
# that's been passed to this function if it is not found instrument_details will be None
138
143
instrument_details = None
139
144
try :
140
- instrument_list = dehex_decompress_and_dejson (self .get_pv_value ("CS:INSTLIST" ))
145
+ input_list = self .get_pv_value ("CS:INSTLIST" )
146
+ assert isinstance (input_list , str | bytes )
147
+ instrument_list = dehex_decompress_and_dejson (input_list )
141
148
instrument_details = next (
142
149
(inst for inst in instrument_list if inst ["pvPrefix" ] == machine_identifier ), None
143
150
)
@@ -182,7 +189,7 @@ def get_instrument_full_name(self) -> str:
182
189
return self .machine_name
183
190
184
191
def set_instrument (
185
- self , machine_identifier : str , globs : dict , import_instrument_init : bool = True
192
+ self , machine_identifier : str , globs : dict [ str , Any ] , import_instrument_init : bool = True
186
193
) -> None :
187
194
"""
188
195
Set the instrument being used by setting the PV prefix or by the
@@ -262,12 +269,14 @@ def prefix_pv_name(self, name: str) -> str:
262
269
"""
263
270
Adds the instrument prefix to the specified PV.
264
271
"""
265
- if self .inst_prefix is not None :
266
- return self .inst_prefix + name
267
- return name
272
+ return self .inst_prefix + name
268
273
269
274
def init_instrument (
270
- self , instrument : str , machine_name : str , globs : dict , import_instrument_init : bool
275
+ self ,
276
+ instrument : str ,
277
+ machine_name : str ,
278
+ globs : dict [str , Any ],
279
+ import_instrument_init : bool ,
271
280
) -> None :
272
281
"""
273
282
Initialise an instrument using the default init file followed by the machine specific init.
@@ -298,10 +307,14 @@ def init_instrument(
298
307
# Load the instrument init file
299
308
self .localmod = importlib .import_module ("init_{}" .format (instrument ))
300
309
301
- if self .localmod .__file__ .endswith (".pyc" ):
302
- file_loc = self .localmod .__file__ [:- 1 ]
310
+ _file = self .localmod .__file__
311
+ assert _file is not None
312
+
313
+ if _file .endswith (".pyc" ):
314
+ file_loc = _file [:- 1 ]
303
315
else :
304
- file_loc = self .localmod .__file__
316
+ file_loc = _file
317
+ assert isinstance (file_loc , str )
305
318
# execfile - this puts any imports in the init file into the globals namespace
306
319
# Note: Anything loose in the module like print statements will be run twice
307
320
exec (compile (open (file_loc ).read (), file_loc , "exec" ), globs )
@@ -351,6 +364,26 @@ def set_pv_value(
351
364
self .logger .log_error_msg ("set_pv_value exception {!r}" .format (e ))
352
365
raise e
353
366
367
+ @typing .overload
368
+ def get_pv_value (
369
+ self ,
370
+ name : str ,
371
+ to_string : typing .Literal [True ] = True ,
372
+ attempts : int = 3 ,
373
+ is_local : bool = True ,
374
+ use_numpy : bool | None = None ,
375
+ ) -> str : ...
376
+
377
+ @typing .overload
378
+ def get_pv_value (
379
+ self ,
380
+ name : str ,
381
+ to_string : bool = False ,
382
+ attempts : int = 3 ,
383
+ is_local : bool = False ,
384
+ use_numpy : bool | None = None ,
385
+ ) -> "PVValue" : ...
386
+
354
387
def get_pv_value (
355
388
self ,
356
389
name : str ,
@@ -448,6 +481,7 @@ def reload_current_config(self) -> None:
448
481
"""
449
482
Reload the current configuration.
450
483
"""
484
+ assert self .blockserver is not None
451
485
self .blockserver .reload_current_config ()
452
486
453
487
def correct_blockname (self , name : str , add_prefix : bool = True ) -> str :
@@ -472,7 +506,7 @@ def get_block_names(self) -> list[str]:
472
506
"""
473
507
return [name for name in BLOCK_NAMES .__dict__ .keys ()]
474
508
475
- def block_exists (self , name : str , fail_fast : bool = False ) -> str | None :
509
+ def block_exists (self , name : str , fail_fast : bool = False ) -> bool :
476
510
"""
477
511
Checks whether the block exists.
478
512
@@ -510,6 +544,8 @@ def set_block_value(
510
544
full_name = self .get_pv_from_block (name )
511
545
512
546
if lowlimit is not None and highlimit is not None :
547
+ if not isinstance (value , (float , int )):
548
+ raise ValueError ("Both limits provided but value is not a number" )
513
549
if lowlimit > highlimit :
514
550
print (
515
551
"Low limit ({}) higher than high limit ({}), "
@@ -531,6 +567,9 @@ def set_block_value(
531
567
self .set_pv_value (full_name , value )
532
568
533
569
if wait :
570
+ if not isinstance (value , WAITFOR_VALUE ):
571
+ raise ValueError (f"Wait value is not a WAITFOR_VALUE: { value } " )
572
+ assert self .waitfor is not None
534
573
self .waitfor .start_waiting (name , value , lowlimit , highlimit )
535
574
return
536
575
@@ -599,14 +638,18 @@ def get_block_units(self, block_name: str) -> str | None:
599
638
return typing .cast (str | None , Wrapper .get_pv_value (unit_name ))
600
639
601
640
def _get_pars (
602
- self , pv_prefix_identifier : str , get_names_from_blockserver : Callable [[], list [ str ] ]
603
- ) -> dict :
641
+ self , pv_prefix_identifier : str , get_names_from_blockserver : Callable [[], Any ]
642
+ ) -> dict [ str , "PVValue" ] :
604
643
"""
605
644
Get the current parameter values for a given pv subset as a dictionary.
606
645
"""
607
646
names = get_names_from_blockserver ()
608
647
ans = {}
609
- if names is not None :
648
+ if (
649
+ names is not None
650
+ and isinstance (names , list )
651
+ and all (isinstance (elem , str ) for elem in names )
652
+ ):
610
653
for n in names :
611
654
val = self .get_pv_value (self .prefix_pv_name (n ))
612
655
m = re .match (".+:" + pv_prefix_identifier + ":(.+)" , n )
@@ -618,10 +661,11 @@ def _get_pars(
618
661
)
619
662
return ans
620
663
621
- def get_sample_pars (self ) -> dict :
664
+ def get_sample_pars (self ) -> dict [ str , "PVValue" ] :
622
665
"""
623
666
Get the current sample parameter values as a dictionary.
624
667
"""
668
+ assert self .blockserver is not None
625
669
return self ._get_pars ("SAMPLE" , self .blockserver .get_sample_par_names )
626
670
627
671
def set_sample_par (self , name : str , value : "PVValue" ) -> None :
@@ -632,8 +676,13 @@ def set_sample_par(self, name: str, value: "PVValue") -> None:
632
676
name: the name of the parameter to change
633
677
value: the new value
634
678
"""
679
+ assert self .blockserver is not None
635
680
names = self .blockserver .get_sample_par_names ()
636
- if names is not None :
681
+ if (
682
+ names is not None
683
+ and isinstance (names , list )
684
+ and all (isinstance (elem , str ) for elem in names )
685
+ ):
637
686
for n in names :
638
687
m = re .match (".+:SAMPLE:%s" % name .upper (), n )
639
688
if m is not None :
@@ -642,10 +691,11 @@ def set_sample_par(self, name: str, value: "PVValue") -> None:
642
691
return
643
692
raise Exception ("Sample parameter %s does not exist" % name )
644
693
645
- def get_beamline_pars (self ) -> dict :
694
+ def get_beamline_pars (self ) -> dict [ str , "PVValue" ] :
646
695
"""
647
696
Get the current beamline parameter values as a dictionary.
648
697
"""
698
+ assert self .blockserver is not None
649
699
return self ._get_pars ("BL" , self .blockserver .get_beamline_par_names )
650
700
651
701
def set_beamline_par (self , name : str , value : "PVValue" ) -> None :
@@ -656,6 +706,7 @@ def set_beamline_par(self, name: str, value: "PVValue") -> None:
656
706
name: the name of the parameter to change
657
707
value: the new value
658
708
"""
709
+ assert self .blockserver is not None
659
710
names = self .blockserver .get_beamline_par_names ()
660
711
if names is not None :
661
712
for n in names :
@@ -665,7 +716,7 @@ def set_beamline_par(self, name: str, value: "PVValue") -> None:
665
716
return
666
717
raise Exception ("Beamline parameter %s does not exist" % name )
667
718
668
- def get_runcontrol_settings (self , block_name : str ) -> tuple [bool , float , float ]:
719
+ def get_runcontrol_settings (self , block_name : str ) -> tuple ["PVValue" , "PVValue" , "PVValue" ]:
669
720
"""
670
721
Gets the current run-control settings for a block.
671
722
@@ -711,19 +762,20 @@ def check_limit_violations(self, blocks: list[str]) -> list[str]:
711
762
list: the blocks which have soft limit violations
712
763
"""
713
764
violation_states = self ._get_fields_from_blocks (blocks , "LVIO" , "limit violation" )
714
- return [t [0 ] for t in violation_states if t [1 ] == 1 ]
765
+
766
+ return [t [0 ] for t in violation_states if t [1 ]]
715
767
716
768
def _get_fields_from_blocks (
717
769
self , blocks : list [str ], field_name : str , field_description : str
718
- ) -> list ["PVValue" ]:
770
+ ) -> list [tuple [ str , "PVValue" ] ]:
719
771
field_values = list ()
720
772
for block in blocks :
721
773
if self .block_exists (block ):
722
774
block_name = self .correct_blockname (block , False )
723
775
full_block_pv = self .get_pv_from_block (block )
724
776
try :
725
777
field_value = self .get_pv_value (full_block_pv + "." + field_name , attempts = 1 )
726
- field_values .append ([ block_name , field_value ] )
778
+ field_values .append (( block_name , field_value ) )
727
779
except IOError :
728
780
# Could not get value
729
781
print ("Could not get {} for block: {}" .format (field_description , block ))
@@ -817,7 +869,7 @@ def send_email(self, address: str, message: str) -> None:
817
869
except Exception as e :
818
870
raise Exception ("Could not send email: {}" .format (e ))
819
871
820
- def send_alert (self , message : str , inst : str ) -> None :
872
+ def send_alert (self , message : str , inst : str | None ) -> None :
821
873
"""
822
874
Sends an alert message for a specified instrument.
823
875
@@ -860,13 +912,15 @@ def get_pv_alarm(self, pv_name: str) -> str:
860
912
alarm status could not be determined
861
913
"""
862
914
try :
863
- return self .get_pv_value (
915
+ alarm_val = self .get_pv_value (
864
916
"{}.SEVR" .format (remove_field_from_pv (pv_name )), to_string = True
865
917
)
918
+ return alarm_val
919
+
866
920
except Exception :
867
921
return "UNKNOWN"
868
922
869
- def get_block_data (self , block : str , fail_fast : bool = False ) -> dict :
923
+ def get_block_data (self , block : str , fail_fast : bool = False ) -> dict [ str , "PVValue" ] :
870
924
"""
871
925
Gets the useful values associated with a block.
872
926
0 commit comments