Skip to content

Commit f350c90

Browse files
committed
Use context aware bandwidth
1 parent 31f073e commit f350c90

File tree

2 files changed

+87
-45
lines changed

2 files changed

+87
-45
lines changed

simphony/simulation.py

+65-23
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"""
1212

1313
from cmath import rect
14-
from typing import TYPE_CHECKING, ClassVar, List, Optional
14+
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple
1515

1616
import numpy as np
1717
from scipy.constants import epsilon_0, h, mu_0
@@ -552,7 +552,7 @@ def _detect(self, power: List[np.ndarray]) -> List[np.ndarray]:
552552
for source in self.context.sources:
553553
# we let the laser handle the RIN distribution
554554
# so the same noise is injected in all the signals
555-
rin = source.get_rin(self.high_fc - self.low_fc)
555+
rin = source.get_rin(self._get_bandwidth()[0])
556556
dist = source.get_rin_dist(i, j)
557557

558558
# calculate the standard deviation of the RIN noise
@@ -602,15 +602,15 @@ def _filter(self, signal: np.ndarray) -> np.ndarray:
602602
signal :
603603
The signal to filter.
604604
"""
605-
high = min(self.high_fc, 0.5 * (self.context.fs - 1))
605+
bw = self._get_bandwidth()
606606
sos = (
607-
butter(6, high, "lowpass", fs=self.context.fs, output="sos")
608-
if self.low_fc == 0
607+
butter(6, bw[2], "lowpass", fs=self.context.fs, output="sos")
608+
if bw[1] == 0
609609
else butter(
610610
6,
611611
[
612-
max(self.low_fc, self.context.fs / self.context.num_samples * 30),
613-
high,
612+
bw[1],
613+
bw[2],
614614
],
615615
"bandpass",
616616
fs=self.context.fs,
@@ -619,6 +619,20 @@ def _filter(self, signal: np.ndarray) -> np.ndarray:
619619
)
620620
return sosfiltfilt(sos, signal)
621621

622+
def _get_bandwidth(self):
623+
"""Gets the bandwidth of the detector w.r.t.
624+
625+
the sampling frequency.
626+
"""
627+
low = (
628+
0
629+
if self.low_fc == 0
630+
else max(self.low_fc, self.context.fs / self.context.num_samples * 30)
631+
)
632+
high = min(self.high_fc, 0.5 * (self.context.fs - 1))
633+
634+
return (high - low, low, high)
635+
622636

623637
class DifferentialDetector(Detector):
624638
"""A differential detector takes two connections and provides three outputs
@@ -697,10 +711,8 @@ def _detect(self, powers: List[np.ndarray]) -> List[np.ndarray]:
697711
for source in self.context.sources:
698712
# get the RIN specs from the laser to ensure that the
699713
# same noise is injected across all signals
700-
monitor_rin = source.get_rin(
701-
self.monitor_high_fc - self.monitor_low_fc
702-
)
703-
rf_rin = source.get_rin(self.rf_high_fc - self.rf_low_fc)
714+
monitor_rin = source.get_rin(self._get_monitor_bandwidth()[0])
715+
rf_rin = source.get_rin(self._get_rf_bandwidth()[0])
704716
dist = source.get_rin_dist(i, j)
705717

706718
# only calculate the noise if there is power
@@ -756,27 +768,24 @@ def _detect(self, powers: List[np.ndarray]) -> List[np.ndarray]:
756768
self._monitor(p2, self.monitor_rin_dists2),
757769
)
758770

759-
def _filter(self, signal: np.ndarray, low_fc: float, high_fc: float) -> np.ndarray:
771+
def _filter(self, signal: np.ndarray, bw: Tuple[float, float, float]) -> np.ndarray:
760772
"""Filters the signal.
761773
762774
Parameters
763775
----------
764776
signal :
765777
The signal to filter.
766-
low_fc :
767-
The lower cut-off frequency.
768-
high_fc :
769-
The higher cut-off frequency.
778+
bw :
779+
The bandwidth of the filter. (difference, low_fc, high_fc)
770780
"""
771-
high = min(high_fc, 0.5 * (self.context.fs - 1))
772781
sos = (
773-
butter(6, high, "lowpass", fs=self.context.fs, output="sos")
774-
if low_fc == 0
782+
butter(6, bw[2], "lowpass", fs=self.context.fs, output="sos")
783+
if bw[1] == 0
775784
else butter(
776785
6,
777786
[
778-
max(low_fc, self.context.fs / self.context.num_samples * 30),
779-
high,
787+
bw[1],
788+
bw[2],
780789
],
781790
"bandpass",
782791
fs=self.context.fs,
@@ -785,6 +794,39 @@ def _filter(self, signal: np.ndarray, low_fc: float, high_fc: float) -> np.ndarr
785794
)
786795
return sosfiltfilt(sos, signal)
787796

797+
def _get_bandwidth(self, low_fc, high_fc):
798+
"""Gets the bandwidth of the detector w.r.t. the sampling frequency.
799+
800+
Parameters
801+
----------
802+
low_fc :
803+
The lower cut-off frequency.
804+
high_fc :
805+
The higher cut-off frequency.
806+
"""
807+
low = (
808+
0
809+
if low_fc == 0
810+
else max(low_fc, self.context.fs / self.context.num_samples * 30)
811+
)
812+
high = min(high_fc, 0.5 * (self.context.fs - 1))
813+
814+
return (high - low, low, high)
815+
816+
def _get_monitor_bandwidth(self):
817+
"""Gets the bandwidth of the monitor w.r.t.
818+
819+
the sampling frequency.
820+
"""
821+
return self._get_bandwidth(self.monitor_low_fc, self.monitor_high_fc)
822+
823+
def _get_rf_bandwidth(self):
824+
"""Gets the bandwidth of the rf w.r.t.
825+
826+
the sampling frequency.
827+
"""
828+
return self._get_bandwidth(self.rf_low_fc, self.rf_high_fc)
829+
788830
def _monitor(self, power: np.ndarray, rin_dists: np.ndarray) -> np.ndarray:
789831
"""Takes a signal and turns it into a monitor output.
790832
@@ -813,7 +855,7 @@ def _monitor_filter(self, signal: np.ndarray) -> np.ndarray:
813855
signal :
814856
The signal to filter.
815857
"""
816-
return self._filter(signal, self.monitor_low_fc, self.monitor_high_fc)
858+
return self._filter(signal, self._get_monitor_bandwidth())
817859

818860
def _rf(self, p1: np.ndarray, p2: np.ndarray) -> np.ndarray:
819861
"""Takes two signals and generates the differential RF signal. p1 - p2.
@@ -850,4 +892,4 @@ def _rf_filter(self, signal: np.ndarray) -> np.ndarray:
850892
signal :
851893
The signal to filter.
852894
"""
853-
return self._filter(signal, self.rf_low_fc, self.rf_high_fc)
895+
return self._filter(signal, self._get_rf_bandwidth())

simphony/tests/test_simulation.py

+22-22
Original file line numberDiff line numberDiff line change
@@ -1231,28 +1231,28 @@ class TestLaser:
12311231

12321232
rin_results = [
12331233
0.00017534,
1234-
0.00019688,
1235-
0.00014356,
1236-
0.00021228,
1237-
0.00016978,
1238-
0.00016308,
1239-
0.00019001,
1240-
0.00018387,
1241-
0.00015724,
1242-
0.00018694,
1243-
0.00015188,
1244-
0.00019778,
1245-
0.00017628,
1246-
0.00017727,
1247-
0.00016445,
1248-
0.00019244,
1249-
0.0001659,
1250-
0.00014431,
1251-
0.00013288,
1252-
0.00023053,
1253-
0.00016564,
1254-
0.0001742,
1255-
0.00016614,
1234+
0.00019062,
1235+
0.00015289,
1236+
0.00020163,
1237+
0.00017138,
1238+
0.00016665,
1239+
0.0001857,
1240+
0.00018142,
1241+
0.00016255,
1242+
0.00018353,
1243+
0.00015875,
1244+
0.0001913,
1245+
0.00017599,
1246+
0.00017676,
1247+
0.00016767,
1248+
0.00018743,
1249+
0.00016871,
1250+
0.00015348,
1251+
0.00014547,
1252+
0.00021439,
1253+
0.00016853,
1254+
0.00017464,
1255+
0.00016882,
12561256
]
12571257

12581258
def test_wlsweep(self, mzi):

0 commit comments

Comments
 (0)