-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcomplex.py
42 lines (32 loc) · 1.03 KB
/
complex.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
def complex_multiply_astar_b(a, b):
    """
    Multiply the complex conjugate of ``a`` by ``b`` (``out = conj(a) * b``).

    Complex values are stored as real tensors whose last axis holds the
    two components:

        [..., 0] -> real part
        [..., 1] -> imaginary part

    Parameters
    ----------
    a : torch tensor
        First operand; it is conjugated before the multiplication.
    b : torch tensor
        Second operand, in the same real/imaginary layout as ``a``.

    Returns
    -------
    torch tensor
        The product in the same layout: the element-wise (broadcast)
        shape of the operands' components followed by a last axis of
        size 2.
    """
    a_re, a_im = a[..., 0], a[..., 1]
    b_re, b_im = b[..., 0], b[..., 1]

    # (a_re - i*a_im) * (b_re + i*b_im)
    real = a_re * b_re + a_im * b_im
    imag = a_re * b_im - a_im * b_re

    # Re-assemble the (real, imag) pair on a new trailing axis.
    return torch.stack((real, imag), dim=-1)
def complex_division(a, b):
    """
    Divide one complex tensor by another (``out = a / b``).

    Both operands are real tensors whose last axis stores the complex
    components: index 0 is the real part, index 1 the imaginary part.
    The quotient is evaluated as ``a * conj(b) / |b|**2``.

    Parameters
    ----------
    a : torch tensor
        Numerator.
    b : torch tensor
        Denominator. No zero check is performed on ``|b|**2``.

    Returns
    -------
    torch tensor
        The quotient in the same layout: the element-wise (broadcast)
        shape of the operands' components followed by a last axis of
        size 2.
    """
    a_re, a_im = a[..., 0], a[..., 1]
    b_re, b_im = b[..., 0], b[..., 1]

    # Squared magnitude of b, shared by both output components.
    b_norm_sq = b_re * b_re + b_im * b_im

    # Numerator a * conj(b), split into real and imaginary parts.
    num_re = a_re * b_re + a_im * b_im
    num_im = a_im * b_re - a_re * b_im

    return torch.stack((num_re / b_norm_sq, num_im / b_norm_sq), dim=-1)