Skip to content

Commit fc4938a

Browse files
committed
Transfer map is now method of the MagneticLattice
1 parent 32a1c5b commit fc4938a

7 files changed

+91
-43
lines changed

ocelot/cpbd/magnetic_lattice.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ocelot.cpbd.latticeIO import LatticeIO
88
from ocelot.cpbd.transformations.transfer_map import TransferMap
99
from ocelot.cpbd.optics import lattice_transfer_map
10+
from ocelot.cpbd.tm_utils import transfer_maps_mult
1011

1112
import logging
1213
import re
@@ -103,11 +104,8 @@ def merger(lat, remaining_types=None, remaining_elems=None, init_energy=0.):
103104
else:
104105
delta_e = np.sum([tm.get_delta_e() for elem in elem_list for tm in elem.tms])
105106
lattice = MagneticLattice(elem_list, method=lat.method)
106-
R = lattice_transfer_map(lattice, energy=E)
107107
m = Matrix()
108-
m.r = lattice.R
109-
m.t = lattice.T
110-
m.b = lattice.B
108+
m.b, m.r, m.t = lattice.transfer_maps(energy=E)
111109
m.l = lattice.totalLen
112110
m.delta_e = delta_e
113111
E += delta_e
@@ -117,6 +115,7 @@ def merger(lat, remaining_types=None, remaining_elems=None, init_energy=0.):
117115
_logger.debug("element numbers after: " + str(len(new_lat.sequence)))
118116
return new_lat
119117

118+
120119
def flatten(iterable: Iterator[Any]) -> Generator[Any, None, None]:
121120
"""Flatten arbitrarily nested iterable.
122121
Special case for strings that avoids infinite recursion. Non iterables passed
@@ -280,6 +279,23 @@ def save_as_py_file(self, file_name: str, remove_rep_drifts=True, power_supply=F
280279
LatticeIO.save_lattice(self, tws0=None, file_name=file_name, remove_rep_drifts=remove_rep_drifts,
281280
power_supply=power_supply)
282281

282+
def transfer_maps(self, energy):
283+
"""
284+
Function calculates transfer maps, the first and second orders (R, T), for the whole lattice.
285+
286+
:param energy: the initial electron beam energy [GeV]
287+
:return: B, R, T - matrices
288+
"""
289+
Ra = np.eye(6)
290+
Ta = np.zeros((6, 6, 6))
291+
Ba = np.zeros((6, 1))
292+
E = energy
293+
for elem in self.sequence:
294+
for Rb, Bb, Tb, tm in zip(elem.R(E), elem.B(E), elem.T(E), elem.tms):
295+
Ba, Ra, Ta = transfer_maps_mult(Ba, Ra, Ta, Bb, Rb, Tb)
296+
E += tm.get_delta_e()
297+
return Ba, Ra, Ta
298+
283299

284300
class EndElements:
285301
suffix_1 = "_1"

ocelot/cpbd/optics.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,19 +84,9 @@ def lattice_transfer_map(lattice, energy):
8484
:param energy: the initial electron beam energy [GeV]
8585
:return: R - matrix
8686
"""
87-
88-
Ra = np.eye(6)
89-
Ta = np.zeros((6, 6, 6))
90-
Ba = np.zeros((6, 1))
91-
E = energy
92-
for elem in lattice.sequence:
93-
for Rb, Bb, Tb, tm in zip(elem.R(E), elem.B(E), elem.T(E), elem.tms):
94-
Ba, Ra, Ta = transfer_maps_mult(Ba, Ra, Ta, Bb, Rb, Tb)
95-
#Ba = np.dot(Rb, Ba) + Bb
96-
E += tm.get_delta_e()
87+
Ba, Ra, Ta = lattice.transfer_maps(energy)
9788

9889
# TODO: Adding Attributes at runtime should be avoided
99-
lattice.E = E
10090
lattice.T_sym = Ta
10191
lattice.T = Ta #unsym_matrix(deepcopy(Ta))
10292
lattice.R = Ra

unit_tests/ebeam_test/io_lattice/io_lattice_test.py

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import copy
1616

1717

18-
def test_original_lattice_transfer_map(lattice, tws0, method, parametr=None, update_ref_values=False):
18+
def test_original_lattice_transfer_map(lattice, tws0, method, parameter=None, update_ref_values=False):
1919
"""R maxtrix calculation test"""
2020

2121
r_matrix = lattice_transfer_map(lattice, tws0.E)
@@ -29,7 +29,7 @@ def test_original_lattice_transfer_map(lattice, tws0, method, parametr=None, upd
2929
assert check_result(result)
3030

3131

32-
def test_lattice_save_as_py_file(lattice, tws0, method, parametr=None, update_ref_values=False):
32+
def test_lattice_save_as_py_file(lattice, tws0, method, parameter=None, update_ref_values=False):
3333
"""R maxtrix calculation test"""
3434

3535
lattice.save_as_py_file(file_name="tmp_lattice.py")
@@ -48,7 +48,7 @@ def test_lattice_save_as_py_file(lattice, tws0, method, parametr=None, update_re
4848
assert check_result(res)
4949

5050

51-
def test_lattice_save_as_py_file_w_coupler(lattice, tws0, method, parametr=None, update_ref_values=False):
51+
def test_lattice_save_as_py_file_w_coupler(lattice, tws0, method, parameter=None, update_ref_values=False):
5252
"""R maxtrix calculation test"""
5353
lattice0 = copy.deepcopy(lattice)
5454
for elem in lattice0.sequence:
@@ -87,7 +87,7 @@ def test_lattice_save_as_py_file_w_coupler(lattice, tws0, method, parametr=None,
8787

8888
assert check_result(res)
8989

90-
def test_original_twiss(lattice, tws0, method, parametr=None, update_ref_values=False):
90+
def test_original_twiss(lattice, tws0, method, parameter=None, update_ref_values=False):
9191
"""Twiss parameters calculation function test"""
9292

9393
tws = twiss(lattice, tws0, nPoints=None)
@@ -103,8 +103,8 @@ def test_original_twiss(lattice, tws0, method, parametr=None, update_ref_values=
103103
assert check_result(result)
104104

105105

106-
@pytest.mark.parametrize('parametr', [False, True])
107-
def test_lat2input(lattice, tws0, method, parametr, update_ref_values=False):
106+
@pytest.mark.parametrize('parameter', [False, True])
107+
def test_lat2input(lattice, tws0, method, parameter, update_ref_values=False):
108108
"""lat2input with tws0 saving function test"""
109109

110110
lines_arr = LatticeIO.lat2input(lattice, tws0=tws0)
@@ -114,13 +114,13 @@ def test_lat2input(lattice, tws0, method, parametr, update_ref_values=False):
114114
try:
115115
exec(lines, globals(), loc_dict)
116116
except Exception as err:
117-
assert check_result(['Exception error during the lattice file execution, parametr is ' + str(parametr)])
117+
assert check_result(['Exception error during the lattice file execution, parameter is ' + str(parameter)])
118118

119-
if parametr:
119+
if parameter:
120120
if "tws0" in loc_dict:
121121
tws0_new = loc_dict['tws0']
122122
else:
123-
assert check_result(['No tws0 in the lattice file, parametr is ' + str(parametr)])
123+
assert check_result(['No tws0 in the lattice file, parameter is ' + str(parameter)])
124124
else:
125125
tws0_new = tws0
126126

@@ -130,7 +130,7 @@ def test_lat2input(lattice, tws0, method, parametr, update_ref_values=False):
130130
lattice_new_transfer_map_check(lattice_new, tws0_new)
131131
twiss_new_check(lattice_new, tws0_new)
132132
else:
133-
assert check_result(['No cell variable in the lattice file, parametr is ' + str(parametr)])
133+
assert check_result(['No cell variable in the lattice file, parameter is ' + str(parameter)])
134134

135135

136136
def lattice_new_transfer_map_check(lattice, tws0):
@@ -142,6 +142,21 @@ def lattice_new_transfer_map_check(lattice, tws0):
142142
result = check_matrix(r_matrix, r_matrix_ref, TOL, assert_info=' r_matrix for new lattice - ')
143143
assert check_result(result)
144144

145+
@pytest.mark.parametrize('parameter', [0, 1, 2])
146+
def test_lattice_transfer_maps_check(lattice, tws0, method, parameter, update_ref_values=False):
147+
matrices = lattice.transfer_maps(tws0.E)
148+
149+
150+
#r_matrix_ref = json2numpy(json_read(REF_RES_DIR + 'test_original_lattice_transfer_map.json'))
151+
m = matrices[parameter]
152+
if update_ref_values:
153+
return numpyBRT2json(m)
154+
155+
m_ref = json2numpyBRT(json_read(REF_RES_DIR + sys._getframe().f_code.co_name + str(parameter) + '.json'))
156+
157+
result = check_matrix(m, m_ref, TOL, assert_info=' B_matrix for the lattice - ')
158+
159+
assert check_result(result)
145160

146161
def twiss_new_check(lattice, tws0):
147162

@@ -155,7 +170,7 @@ def twiss_new_check(lattice, tws0):
155170
assert check_result(result)
156171

157172

158-
def test_merger(lattice, tws0, method, parametr=None, update_ref_values=False):
173+
def test_merger(lattice, tws0, method, parameter=None, update_ref_values=False):
159174
"""R maxtrix calculation test"""
160175
d = Drift(l=0.5)
161176
q = Quadrupole(l=0.3, k1=3, k2=3.3, eid="quad")
@@ -185,7 +200,7 @@ def test_merger(lattice, tws0, method, parametr=None, update_ref_values=False):
185200
assert check_result(result + result2)
186201

187202

188-
def test_merger_elem(lattice, tws0, method, parametr=None, update_ref_values=False):
203+
def test_merger_elem(lattice, tws0, method, parameter=None, update_ref_values=False):
189204
"""R maxtrix calculation test"""
190205
d = Drift(l=0.5)
191206
q = Quadrupole(l=0.3, k1=3, k2=3.3, eid="quad")
@@ -215,7 +230,7 @@ def test_merger_elem(lattice, tws0, method, parametr=None, update_ref_values=Fal
215230
assert check_result(result + result2)
216231

217232

218-
def test_merger_elem_w_coupler(lattice, tws0, method, parametr=None, update_ref_values=False):
233+
def test_merger_elem_w_coupler(lattice, tws0, method, parameter=None, update_ref_values=False):
219234
"""R maxtrix calculation test"""
220235
d = Drift(l=0.5)
221236
q = Quadrupole(l=0.3, k1=3, k2=3.3, eid="quad")
@@ -252,7 +267,7 @@ def test_merger_elem_w_coupler(lattice, tws0, method, parametr=None, update_ref_
252267
result2 = check_matrix(lat.T, new_lat.T, TOL, assert_info=' t_matrix - ')
253268
assert check_result(result + result2)
254269

255-
def test_merger_type(lattice, tws0, method, parametr=None, update_ref_values=False):
270+
def test_merger_type(lattice, tws0, method, parameter=None, update_ref_values=False):
256271
"""R maxtrix calculation test"""
257272
d = Drift(l=0.5)
258273
q = Quadrupole(l=0.3, k1=3, k2=3.3, eid="quad")
@@ -284,7 +299,7 @@ def test_merger_type(lattice, tws0, method, parametr=None, update_ref_values=Fal
284299
assert check_result(result + result2)
285300

286301

287-
def test_merger_extensive(lattice, tws0, method, parametr=None, update_ref_values=False):
302+
def test_merger_extensive(lattice, tws0, method, parameter=None, update_ref_values=False):
288303
"""R maxtrix calculation test"""
289304

290305

@@ -300,7 +315,7 @@ def test_merger_extensive(lattice, tws0, method, parametr=None, update_ref_value
300315
assert check_result(result + result2)
301316

302317

303-
def test_merger_tilt(lattice, tws0, method, parametr=None, update_ref_values=False):
318+
def test_merger_tilt(lattice, tws0, method, parameter=None, update_ref_values=False):
304319
"""R maxtrix calculation test"""
305320
d = Drift(l=0.5)
306321
q = Quadrupole(l=0.3, k1=3, k2=3.3, eid="quad", tilt=1.)
@@ -331,7 +346,7 @@ def test_merger_tilt(lattice, tws0, method, parametr=None, update_ref_values=Fal
331346
assert check_result(result + result2)
332347

333348

334-
def test_merger_write_read(lattice, tws0, method, parametr=None, update_ref_values=False):
349+
def test_merger_write_read(lattice, tws0, method, parameter=None, update_ref_values=False):
335350
"""R maxtrix calculation test"""
336351

337352
R = lattice_transfer_map(lattice, energy=tws0.E)
@@ -348,7 +363,7 @@ def test_merger_write_read(lattice, tws0, method, parametr=None, update_ref_valu
348363
result2 = check_matrix(lattice.T, new_lat2.T, tolerance=1.0e-8, tolerance_type='absolute', assert_info=' t_matrix - ')
349364
assert check_result(result + result2)
350365

351-
def test_matrix_write_read(lattice, tws0, method, parametr=None, update_ref_values=False):
366+
def test_matrix_write_read(lattice, tws0, method, parameter=None, update_ref_values=False):
352367
"""R maxtrix calculation test"""
353368
m = Matrix(l=0.3, delta_e=0.1)
354369
m.r = np.random.random((6, 6))
@@ -368,11 +383,11 @@ def test_matrix_write_read(lattice, tws0, method, parametr=None, update_ref_valu
368383

369384
result = check_matrix(R, R2, TOL, assert_info='r_matrix - ')
370385
result2 = check_matrix(lat.T, lat2.T, TOL, assert_info='t_matrix - ')
371-
result3 = check_matrix(np.array([lat.E, lat.totalLen]), np.array([lat2.E,lat2.totalLen ]), TOL, assert_info='t_matrix - ')
386+
result3 = check_matrix(np.array([ lat.totalLen]), np.array([lat2.totalLen ]), TOL, assert_info='t_matrix - ')
372387
assert check_result(result + result2 + result3)
373388

374389

375-
def test_matrix_b_vector(lattice, tws0, method, parametr=None, update_ref_values=False):
390+
def test_matrix_b_vector(lattice, tws0, method, parameter=None, update_ref_values=False):
376391
"""R maxtrix calculation test"""
377392
d = Drift(l=0.5)
378393
q = Quadrupole(l=0.3, k1=3, k2=3.3, eid="quad")
@@ -407,7 +422,7 @@ def test_matrix_b_vector(lattice, tws0, method, parametr=None, update_ref_values
407422
assert check_result(result + result2 + result3)
408423

409424

410-
def test_matrix_b_vector_read_write(lattice, tws0, method, parametr=None, update_ref_values=False):
425+
def test_matrix_b_vector_read_write(lattice, tws0, method, parameter=None, update_ref_values=False):
411426
"""R maxtrix calculation test"""
412427
d = Drift(l=0.5)
413428
q = Quadrupole(l=0.3, k1=3, k2=3.3, eid="quad")
@@ -481,16 +496,22 @@ def test_update_ref_values(lattice, tws0, method, cmdopt):
481496
update_functions = []
482497
update_functions.append('test_original_lattice_transfer_map')
483498
update_functions.append('test_original_twiss')
499+
update_functions.append("test_lattice_transfer_maps_check")
484500

485501
# function test_lat2input function need not be added here.
486502
# It is used reference results from test_original_lattice_transfer_map and test_original_twiss functions
503+
update_function_parameters = {}
504+
update_function_parameters['test_lattice_transfer_maps_check'] = [0, 1, 2]
505+
506+
parameter = update_function_parameters[cmdopt] if cmdopt in update_function_parameters.keys() else ['']
487507

488508
if cmdopt in update_functions:
489-
result = eval(cmdopt)(lattice, tws0, method, None, True)
490-
if result is None:
491-
return
509+
for p in parameter:
510+
result = eval(cmdopt)(lattice, tws0, method, p, True)
511+
if result is None:
512+
return
492513

493-
if os.path.isfile(REF_RES_DIR + cmdopt + '.json'):
494-
os.rename(REF_RES_DIR + cmdopt + '.json', REF_RES_DIR + cmdopt + '.old')
514+
if os.path.isfile(REF_RES_DIR + cmdopt + '.json'):
515+
os.rename(REF_RES_DIR + cmdopt + '.json', REF_RES_DIR + cmdopt + str(p) + '.old')
495516

496-
json_save(result, REF_RES_DIR + cmdopt + '.json')
517+
json_save(result, REF_RES_DIR + cmdopt + str(p) + '.json')
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"0": 0.0, "1": 0.0, "2": 0.0, "3": 0.0, "4": 0.0, "5": 0.0}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"0": -0.3669735365228965, "1": -1.1490635074188513, "2": 1.190282708588602e-10, "3": 3.1482812735754887e-10, "4": 1.132202512078879e-22, "5": 2.8128820894915816e-09, "6": 0.1347434520651659, "7": 0.3173235504195618, "8": 2.2852523065905037e-11, "9": 7.028882971223359e-11, "10": 6.818683204158491e-23, "11": 1.6940566422179824e-09, "12": 2.6540664561105906e-10, "13": 7.45020103770362e-10, "14": 1.0266211358897732, "15": 2.4501975772363154, "16": 3.5480318787650113e-22, "17": 8.81484999284599e-09, "18": -9.44327601378983e-12, "19": -2.209322846125279e-11, "20": 0.055499305735640936, "21": 0.1698423074184009, "22": 6.328982779931208e-23, "23": 1.5723938178315115e-09, "24": -2.6210674684298977e-08, "25": -7.43652080033237e-08, "26": 2.9467543946615216e-08, "27": 6.169766289532141e-08, "28": 1.0, "29": -0.04189680038806823, "30": 0.0, "31": 0.0, "32": 0.0, "33": 0.0, "34": 1.5367193273535444e-15, "35": 0.03817877302615135}

0 commit comments

Comments
 (0)