diff --git a/complexnn/bn.py b/complexnn/bn.py index 75c1e7b..66bd53d 100644 --- a/complexnn/bn.py +++ b/complexnn/bn.py @@ -147,6 +147,30 @@ def complex_standardization(input_centred, Vrr, Vii, Vri, def ComplexBN(input_centred, Vrr, Vii, Vri, beta, gamma_rr, gamma_ri, gamma_ii, scale=True, center=True, layernorm=False, axis=-1): + """Complex Batch Normalization + + Arguments: + input_centred -- input data + Vrr -- Real component of covariance matrix V + Vii -- Imaginary component of covariance matrix V + Vri -- Non-diagonal component of covariance matrix V + beta -- Lernable shift parameter beta + gamma_rr -- Scaling parameter gamma - rr component of 2x2 matrix + gamma_ri -- Scaling parameter gamma - ri component of 2x2 matrix + gamma_ii -- Scaling parameter gamma - ii component of 2x2 matrix + + Keyword Arguments: + scale {bool} {bool} -- Standardization of input (default: {True}) + center {bool} -- Mean-shift correction (default: {True}) + layernorm {bool} -- Normalization (default: {False}) + axis {int} -- Axis for Standardization (default: {-1}) + + Raises: + ValueError: Dimonsional mismatch + + Returns: + Batch-Normalized Input + """ ndim = K.ndim(input_centred) input_dim = K.shape(input_centred)[axis] // 2