Skip to content

Commit fe21795

Browse files
authored
Merge pull request #54 from BYUCamachoLab/simulation
Simulation contexts
2 parents dcea736 + 56315fd commit fe21795

File tree

8 files changed

+1818
-16
lines changed

8 files changed

+1818
-16
lines changed

.github/workflows/build-and-test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
- name: Lint with flake8
1818
run: |
1919
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
20-
flake8 . --count --ignore=E501,E741,W503,W605 --max-complexity=10 --statistics
20+
flake8 . --count --ignore=E501,E741,W503,W605 --max-complexity=12 --statistics
2121
- name: Run Tox
2222
run: tox -e py
2323
strategy:

docs/source/reference/api.rst

+5
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ API
2828
:inherited-members:
2929
:show-inheritance:
3030

31+
.. automodule:: simphony.simulation
32+
:members:
33+
:inherited-members:
34+
:show-inheritance:
35+
3136
.. automodule:: simphony.simulators
3237
:members:
3338
:inherited-members:

simphony/formatters.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -220,12 +220,15 @@ class CircuitJSONFormatter:
220220
"""This class handles converting a circuit to JSON and vice-versa."""
221221

222222
def format(self, circuit: "Circuit", freqs: np.array) -> str:
223+
from simphony.simulation import SimulationModel
223224
from simphony.simulators import Simulator
224225

225226
data = {"components": [], "connections": []}
226227
for i, component in enumerate(circuit):
227228
# skip simulators
228-
if isinstance(component, Simulator):
229+
if isinstance(component, Simulator) or isinstance(
230+
component, SimulationModel
231+
):
229232
continue
230233

231234
# get a representation for each component

simphony/layout.py

+9
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,15 @@ def pins(self) -> List["Pin"]:
136136
"""Returns the pins for the circuit."""
137137
return self.to_subcircuit(permanent=False).pins
138138

139+
def get_pin_index(self, pin: "Pin") -> int:
140+
"""Gets the pin index for the specified pin in the scattering
141+
parameters."""
142+
for i, _pin in enumerate(self.pins):
143+
if _pin == pin:
144+
return i
145+
146+
raise ValueError("The pin must belong to the circuit.")
147+
139148
def s_parameters(self, freqs: "np.array") -> "np.ndarray":
140149
"""Returns the scattering parameters for the circuit."""
141150
return self.to_subcircuit(permanent=False).s_parameters(freqs)

simphony/models.py

+38-13
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,15 @@ class is the base class for all models. The ``Subcircuit`` class is where
1919
"""
2020

2121
import os
22-
from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, Tuple, Union
22+
from typing import ClassVar, Dict, List, Optional, Tuple, Union
23+
24+
import numpy as np
2325

2426
from simphony.connect import create_block_diagonal, innerconnect_s
2527
from simphony.formatters import ModelFormatter, ModelJSONFormatter
2628
from simphony.layout import Circuit
2729
from simphony.pins import Pin, PinList
2830

29-
if TYPE_CHECKING:
30-
import numpy as np
31-
3231

3332
class Model:
3433
"""The basic element type describing the model for a component with
@@ -230,7 +229,7 @@ def _on_disconnect_recursive(self, circuit: Circuit) -> None:
230229
if circuit._add(component):
231230
component._on_disconnect_recursive(circuit)
232231

233-
def connect(self, component_or_pin: Union["Model", Pin]) -> None:
232+
def connect(self, component_or_pin: Union["Model", Pin]) -> "Model":
234233
"""Connects the next available (unconnected) pin from this component to
235234
the component/pin passed in as the argument.
236235
@@ -239,13 +238,14 @@ def connect(self, component_or_pin: Union["Model", Pin]) -> None:
239238
component.
240239
"""
241240
self._get_next_unconnected_pin().connect(component_or_pin)
241+
return self
242242

243243
def disconnect(self) -> None:
244244
"""Disconnects this component from all other components."""
245245
for pin in self.pins:
246246
pin.disconnect()
247247

248-
def interface(self, component: "Model") -> None:
248+
def interface(self, component: "Model") -> "Model":
249249
"""Interfaces this component to the component passed in by connecting
250250
pins with the same names.
251251
@@ -256,6 +256,8 @@ def interface(self, component: "Model") -> None:
256256
if selfpin.name[0:3] != "pin" and selfpin.name == componentpin.name:
257257
selfpin.connect(componentpin)
258258

259+
return self
260+
259261
def monte_carlo_s_parameters(self, freqs: "np.array") -> "np.ndarray":
260262
"""Implements the monte carlo routine for the given Model.
261263
@@ -278,7 +280,7 @@ def monte_carlo_s_parameters(self, freqs: "np.array") -> "np.ndarray":
278280
"""
279281
return self.s_parameters(freqs)
280282

281-
def multiconnect(self, *connections: Union["Model", Pin, None]) -> None:
283+
def multiconnect(self, *connections: Union["Model", Pin, None]) -> "Model":
282284
"""Connects this component to the specified connections by looping
283285
through each connection and connecting it with the corresponding pin.
284286
@@ -293,6 +295,8 @@ def multiconnect(self, *connections: Union["Model", Pin, None]) -> None:
293295
if connection is not None:
294296
self.pins[index].connect(connection)
295297

298+
return self
299+
296300
def regenerate_monte_carlo_parameters(self) -> None:
297301
"""Regenerates parameters used to generate monte carlo s-matrices.
298302
@@ -540,6 +544,7 @@ def _s_parameters(
540544
The method name to call to get the scattering parameters.
541545
Either 's_parameters' or 'monte_carlo_s_parameters'
542546
"""
547+
from simphony.simulation import SimulationModel
543548
from simphony.simulators import Simulator
544549

545550
all_pins = []
@@ -549,16 +554,36 @@ def _s_parameters(
549554
# merge all of the s_params into one giant block diagonal matrix
550555
for component in self._wrapped_circuit:
551556
# simulators don't have scattering parameters
552-
if isinstance(component, Simulator):
557+
if isinstance(component, Simulator) or isinstance(
558+
component, SimulationModel
559+
):
553560
continue
554561

555562
# get the s_params from the cache if possible
556563
if s_parameters_method == "s_parameters":
557-
try:
558-
s_params = self.__class__.scache[component]
559-
except KeyError:
560-
s_params = getattr(component, s_parameters_method)(freqs)
561-
self.__class__.scache[component] = s_params
564+
# each frequency has a different s-matrix, so we need to cache
565+
# the s-matrices by frequency as well as component
566+
s_params = []
567+
for freq in freqs:
568+
try:
569+
# use the cached s-matrix if available
570+
s_matrix = self.__class__.scache[component][freq]
571+
except KeyError:
572+
# make sure the frequency dict is created
573+
if component not in self.__class__.scache:
574+
self.__class__.scache[component] = {}
575+
576+
# store the s-matrix for the frequency and component
577+
s_matrix = getattr(component, s_parameters_method)(
578+
np.array([freq])
579+
)[0]
580+
self.__class__.scache[component][freq] = s_matrix
581+
582+
# add the s-matrix to our list of s-matrices
583+
s_params.append(s_matrix)
584+
585+
# convert to numpy array for the rest of the function
586+
s_params = np.array(s_params)
562587
elif s_parameters_method == "monte_carlo_s_parameters":
563588
# don't cache Monte Carlo scattering parameters
564589
s_params = getattr(component, s_parameters_method)(freqs)

simphony/pins.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,12 @@ def _isconnected(self, *, include_simulators: bool = True) -> bool:
5353
if include_simulators:
5454
return True
5555

56+
from simphony.simulation import SimulationModel
5657
from simphony.simulators import Simulator
5758

58-
return not isinstance(self._connection._component, Simulator)
59+
return not isinstance(
60+
self._connection._component, Simulator
61+
) and not isinstance(self._connection._component, SimulationModel)
5962

6063
def connect(self, pin_or_component: Union["Pin", "Model"]) -> None:
6164
"""Connects this pin to the pin/component that is passed in.

0 commit comments

Comments
 (0)