@@ -19,16 +19,15 @@ class is the base class for all models. The ``Subcircuit`` class is where
19
19
"""
20
20
21
21
import os
22
- from typing import TYPE_CHECKING , ClassVar , Dict , List , Optional , Tuple , Union
22
+ from typing import ClassVar , Dict , List , Optional , Tuple , Union
23
+
24
+ import numpy as np
23
25
24
26
from simphony .connect import create_block_diagonal , innerconnect_s
25
27
from simphony .formatters import ModelFormatter , ModelJSONFormatter
26
28
from simphony .layout import Circuit
27
29
from simphony .pins import Pin , PinList
28
30
29
- if TYPE_CHECKING :
30
- import numpy as np
31
-
32
31
33
32
class Model :
34
33
"""The basic element type describing the model for a component with
@@ -230,7 +229,7 @@ def _on_disconnect_recursive(self, circuit: Circuit) -> None:
230
229
if circuit ._add (component ):
231
230
component ._on_disconnect_recursive (circuit )
232
231
233
- def connect (self , component_or_pin : Union ["Model" , Pin ]) -> None :
232
+ def connect (self , component_or_pin : Union ["Model" , Pin ]) -> "Model" :
234
233
"""Connects the next available (unconnected) pin from this component to
235
234
the component/pin passed in as the argument.
236
235
@@ -239,13 +238,14 @@ def connect(self, component_or_pin: Union["Model", Pin]) -> None:
239
238
component.
240
239
"""
241
240
self ._get_next_unconnected_pin ().connect (component_or_pin )
241
+ return self
242
242
243
243
def disconnect (self ) -> None :
244
244
"""Disconnects this component from all other components."""
245
245
for pin in self .pins :
246
246
pin .disconnect ()
247
247
248
- def interface (self , component : "Model" ) -> None :
248
+ def interface (self , component : "Model" ) -> "Model" :
249
249
"""Interfaces this component to the component passed in by connecting
250
250
pins with the same names.
251
251
@@ -256,6 +256,8 @@ def interface(self, component: "Model") -> None:
256
256
if selfpin .name [0 :3 ] != "pin" and selfpin .name == componentpin .name :
257
257
selfpin .connect (componentpin )
258
258
259
+ return self
260
+
259
261
def monte_carlo_s_parameters (self , freqs : "np.array" ) -> "np.ndarray" :
260
262
"""Implements the monte carlo routine for the given Model.
261
263
@@ -278,7 +280,7 @@ def monte_carlo_s_parameters(self, freqs: "np.array") -> "np.ndarray":
278
280
"""
279
281
return self .s_parameters (freqs )
280
282
281
- def multiconnect (self , * connections : Union ["Model" , Pin , None ]) -> None :
283
+ def multiconnect (self , * connections : Union ["Model" , Pin , None ]) -> "Model" :
282
284
"""Connects this component to the specified connections by looping
283
285
through each connection and connecting it with the corresponding pin.
284
286
@@ -293,6 +295,8 @@ def multiconnect(self, *connections: Union["Model", Pin, None]) -> None:
293
295
if connection is not None :
294
296
self .pins [index ].connect (connection )
295
297
298
+ return self
299
+
296
300
def regenerate_monte_carlo_parameters (self ) -> None :
297
301
"""Regenerates parameters used to generate monte carlo s-matrices.
298
302
@@ -540,6 +544,7 @@ def _s_parameters(
540
544
The method name to call to get the scattering parameters.
541
545
Either 's_parameters' or 'monte_carlo_s_parameters'
542
546
"""
547
+ from simphony .simulation import SimulationModel
543
548
from simphony .simulators import Simulator
544
549
545
550
all_pins = []
@@ -549,16 +554,36 @@ def _s_parameters(
549
554
# merge all of the s_params into one giant block diagonal matrix
550
555
for component in self ._wrapped_circuit :
551
556
# simulators don't have scattering parameters
552
- if isinstance (component , Simulator ):
557
+ if isinstance (component , Simulator ) or isinstance (
558
+ component , SimulationModel
559
+ ):
553
560
continue
554
561
555
562
# get the s_params from the cache if possible
556
563
if s_parameters_method == "s_parameters" :
557
- try :
558
- s_params = self .__class__ .scache [component ]
559
- except KeyError :
560
- s_params = getattr (component , s_parameters_method )(freqs )
561
- self .__class__ .scache [component ] = s_params
564
+ # each frequency has a different s-matrix, so we need to cache
565
+ # the s-matrices by frequency as well as component
566
+ s_params = []
567
+ for freq in freqs :
568
+ try :
569
+ # use the cached s-matrix if available
570
+ s_matrix = self .__class__ .scache [component ][freq ]
571
+ except KeyError :
572
+ # make sure the frequency dict is created
573
+ if component not in self .__class__ .scache :
574
+ self .__class__ .scache [component ] = {}
575
+
576
+ # store the s-matrix for the frequency and component
577
+ s_matrix = getattr (component , s_parameters_method )(
578
+ np .array ([freq ])
579
+ )[0 ]
580
+ self .__class__ .scache [component ][freq ] = s_matrix
581
+
582
+ # add the s-matrix to our list of s-matrices
583
+ s_params .append (s_matrix )
584
+
585
+ # convert to numpy array for the rest of the function
586
+ s_params = np .array (s_params )
562
587
elif s_parameters_method == "monte_carlo_s_parameters" :
563
588
# don't cache Monte Carlo scattering parameters
564
589
s_params = getattr (component , s_parameters_method )(freqs )
0 commit comments