diff --git a/complexnn/__init__.py b/complexnn/__init__.py
index c337296..4d6c1b1 100644
--- a/complexnn/__init__.py
+++ b/complexnn/__init__.py
@@ -4,7 +4,6 @@
 from . import bn, conv, dense, init, norm, pool
 
 # from . import fft
-
 from .bn import ComplexBatchNormalization as ComplexBN
 from .conv import (
     ComplexConv,
@@ -18,17 +17,17 @@
 # from .fft import (fft, ifft, fft2, ifft2, FFT, IFFT, FFT2, IFFT2)
 from .init import (
     ComplexIndependentFilters,
-    IndependentFilters,
     ComplexInit,
+    IndependentFilters,
     SqrtInit,
 )
-from .norm import LayerNormalization, ComplexLayerNorm
+from .norm import ComplexLayerNorm, LayerNormalization
 from .pool import SpectralPooling1D, SpectralPooling2D
 from .utils import (
-    get_realpart,
-    get_imagpart,
-    getpart_output_shape,
+    GetAbs,
     GetImag,
     GetReal,
-    GetAbs,
+    get_imagpart,
+    get_realpart,
+    getpart_output_shape,
 )
diff --git a/complexnn/bn.py b/complexnn/bn.py
index 090ee45..582177d 100644
--- a/complexnn/bn.py
+++ b/complexnn/bn.py
@@ -7,9 +7,9 @@
 # https://github.com/fchollet/keras/blob/master/keras/layers/normalization.py
 
 import numpy as np
-from tensorflow.keras.layers import Layer, InputSpec
-from tensorflow.keras import initializers, regularizers, constraints
 import tensorflow.keras.backend as K
+from tensorflow.keras import constraints, initializers, regularizers
+from tensorflow.keras.layers import InputSpec, Layer
 
 
 def sqrt_init(shape, dtype=None):
@@ -139,7 +139,18 @@ def complex_standardization(input_centred, Vrr, Vii, Vri, layernorm=False, axis=
 
 
 def ComplexBN(
-    input_centred, Vrr, Vii, Vri, beta, gamma_rr, gamma_ri, gamma_ii, scale=True, center=True, layernorm=False, axis=-1
+    input_centred,
+    Vrr,
+    Vii,
+    Vri,
+    beta,
+    gamma_rr,
+    gamma_ri,
+    gamma_ii,
+    scale=True,
+    center=True,
+    layernorm=False,
+    axis=-1,
 ):
     """Complex Batch Normalization
 
@@ -176,7 +187,9 @@ def ComplexBN(
     broadcast_beta_shape[axis] = input_dim * 2
 
     if scale:
-        standardized_output = complex_standardization(input_centred, Vrr, Vii, Vri, layernorm, axis=axis)
+        standardized_output = complex_standardization(
+            input_centred, Vrr, Vii, Vri, layernorm, axis=axis
+        )
 
         # Now we perform the scaling and shifting of the normalized x using
        # the scaling parameter
@@ -194,8 +207,12 @@ def ComplexBN(
         broadcast_gamma_ri = K.reshape(gamma_ri, gamma_broadcast_shape)
         broadcast_gamma_ii = K.reshape(gamma_ii, gamma_broadcast_shape)
 
-        cat_gamma_4_real = K.concatenate([broadcast_gamma_rr, broadcast_gamma_ii], axis=axis)
-        cat_gamma_4_imag = K.concatenate([broadcast_gamma_ri, broadcast_gamma_ri], axis=axis)
+        cat_gamma_4_real = K.concatenate(
+            [broadcast_gamma_rr, broadcast_gamma_ii], axis=axis
+        )
+        cat_gamma_4_imag = K.concatenate(
+            [broadcast_gamma_ri, broadcast_gamma_ri], axis=axis
+        )
         if (axis == 1 and ndim != 3) or ndim == 2:
             centred_real = standardized_output[:, :input_dim]
             centred_imag = standardized_output[:, input_dim:]
@@ -214,14 +231,21 @@ def ComplexBN(
                 " should be either 1 or -1. "
                 "axis: " + str(axis) + "; ndim: " + str(ndim) + "."
             )
-        rolled_standardized_output = K.concatenate([centred_imag, centred_real], axis=axis)
+        rolled_standardized_output = K.concatenate(
+            [centred_imag, centred_real], axis=axis
+        )
         if center:
             broadcast_beta = K.reshape(beta, broadcast_beta_shape)
             return (
-                cat_gamma_4_real * standardized_output + cat_gamma_4_imag * rolled_standardized_output + broadcast_beta
+                cat_gamma_4_real * standardized_output
+                + cat_gamma_4_imag * rolled_standardized_output
+                + broadcast_beta
             )
         else:
-            return cat_gamma_4_real * standardized_output + cat_gamma_4_imag * rolled_standardized_output
+            return (
+                cat_gamma_4_real * standardized_output
+                + cat_gamma_4_imag * rolled_standardized_output
+            )
     else:
         if center:
             broadcast_beta = K.reshape(beta, broadcast_beta_shape)
@@ -294,7 +318,7 @@ def __init__(
         beta_constraint=None,
         gamma_diag_constraint=None,
         gamma_off_constraint=None,
-        **kwargs
+        **kwargs,
     ):
         super(ComplexBatchNormalization, self).__init__(**kwargs)
         self.supports_masking = True
@@ -308,7 +332,9 @@ def __init__(
         self.gamma_off_initializer = sanitizedInitGet(gamma_off_initializer)
         self.moving_mean_initializer = sanitizedInitGet(moving_mean_initializer)
         self.moving_variance_initializer = sanitizedInitGet(moving_variance_initializer)
-        self.moving_covariance_initializer = sanitizedInitGet(moving_covariance_initializer)
+        self.moving_covariance_initializer = sanitizedInitGet(
+            moving_covariance_initializer
+        )
         self.beta_regularizer = regularizers.get(beta_regularizer)
         self.gamma_diag_regularizer = regularizers.get(gamma_diag_regularizer)
         self.gamma_off_regularizer = regularizers.get(gamma_off_regularizer)
@@ -317,7 +343,6 @@ def __init__(
         self.gamma_off_constraint = constraints.get(gamma_off_constraint)
 
     def build(self, input_shape):
-
         ndim = len(input_shape)
         dim = input_shape[self.axis]
 
@@ -354,13 +379,22 @@ def build(self, input_shape):
                 constraint=self.gamma_off_constraint,
             )
             self.moving_Vrr = self.add_weight(
-                shape=param_shape, initializer=self.moving_variance_initializer, name="moving_Vrr", trainable=False
+                shape=param_shape,
+                initializer=self.moving_variance_initializer,
+                name="moving_Vrr",
+                trainable=False,
             )
             self.moving_Vii = self.add_weight(
-                shape=param_shape, initializer=self.moving_variance_initializer, name="moving_Vii", trainable=False
+                shape=param_shape,
+                initializer=self.moving_variance_initializer,
+                name="moving_Vii",
+                trainable=False,
             )
             self.moving_Vri = self.add_weight(
-                shape=param_shape, initializer=self.moving_covariance_initializer, name="moving_Vri", trainable=False
+                shape=param_shape,
+                initializer=self.moving_covariance_initializer,
+                name="moving_Vri",
+                trainable=False,
             )
         else:
             self.gamma_rr = None
@@ -443,7 +477,9 @@ def call(self, inputs, training=None):
             Vii = None
             Vri = None
         else:
-            raise ValueError("Error. Both scale and center in batchnorm are set to False.")
+            raise ValueError(
+                "Error. Both scale and center in batchnorm are set to False."
+            )
 
         input_bn = ComplexBN(
             input_centred,
@@ -460,34 +496,43 @@ def call(self, inputs, training=None):
 
         if training in {0, False}:
             return input_bn
-        else:
-            update_list = []
+        update_list = []
+        if self.center:
+            update_list.append(
+                K.moving_average_update(self.moving_mean, mu, self.momentum)
+            )
+        if self.scale:
+            update_list.append(
+                K.moving_average_update(self.moving_Vrr, Vrr, self.momentum)
+            )
+            update_list.append(
+                K.moving_average_update(self.moving_Vii, Vii, self.momentum)
+            )
+            update_list.append(
+                K.moving_average_update(self.moving_Vri, Vri, self.momentum)
+            )
+        self.add_update(update_list)
+
+        def normalize_inference():
             if self.center:
-                update_list.append(K.moving_average_update(self.moving_mean, mu, self.momentum))
-            if self.scale:
-                update_list.append(K.moving_average_update(self.moving_Vrr, Vrr, self.momentum))
-                update_list.append(K.moving_average_update(self.moving_Vii, Vii, self.momentum))
-                update_list.append(K.moving_average_update(self.moving_Vri, Vri, self.momentum))
-            self.add_update(update_list)
-
-            def normalize_inference():
-                if self.center:
-                    inference_centred = inputs - K.reshape(self.moving_mean, broadcast_mu_shape)
-                else:
-                    inference_centred = inputs
-                return ComplexBN(
-                    inference_centred,
-                    self.moving_Vrr,
-                    self.moving_Vii,
-                    self.moving_Vri,
-                    self.beta,
-                    self.gamma_rr,
-                    self.gamma_ri,
-                    self.gamma_ii,
-                    self.scale,
-                    self.center,
-                    axis=self.axis,
+                inference_centred = inputs - K.reshape(
+                    self.moving_mean, broadcast_mu_shape
                 )
+            else:
+                inference_centred = inputs
+            return ComplexBN(
+                inference_centred,
+                self.moving_Vrr,
+                self.moving_Vii,
+                self.moving_Vri,
+                self.beta,
+                self.gamma_rr,
+                self.gamma_ri,
+                self.gamma_ii,
+                self.scale,
+                self.center,
+                axis=self.axis,
+            )
 
         # Pick the normalized form corresponding to the training phase.
         return K.in_train_phase(input_bn, normalize_inference, training=training)
 
@@ -503,10 +548,16 @@ def get_config(self):
             "gamma_diag_initializer": sanitizedInitSer(self.gamma_diag_initializer),
             "gamma_off_initializer": sanitizedInitSer(self.gamma_off_initializer),
             "moving_mean_initializer": sanitizedInitSer(self.moving_mean_initializer),
-            "moving_variance_initializer": sanitizedInitSer(self.moving_variance_initializer),
-            "moving_covariance_initializer": sanitizedInitSer(self.moving_covariance_initializer),
+            "moving_variance_initializer": sanitizedInitSer(
+                self.moving_variance_initializer
+            ),
+            "moving_covariance_initializer": sanitizedInitSer(
+                self.moving_covariance_initializer
+            ),
             "beta_regularizer": regularizers.serialize(self.beta_regularizer),
-            "gamma_diag_regularizer": regularizers.serialize(self.gamma_diag_regularizer),
+            "gamma_diag_regularizer": regularizers.serialize(
+                self.gamma_diag_regularizer
+            ),
             "gamma_off_regularizer": regularizers.serialize(self.gamma_off_regularizer),
             "beta_constraint": constraints.serialize(self.beta_constraint),
             "gamma_diag_constraint": constraints.serialize(self.gamma_diag_constraint),
diff --git a/complexnn/conv.py b/complexnn/conv.py
index 11bf21f..1f6303d 100644
--- a/complexnn/conv.py
+++ b/complexnn/conv.py
@@ -1,20 +1,21 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import numpy as np
 import tensorflow as tf
+from tensorflow.keras import activations, constraints, initializers, regularizers
 from tensorflow.keras import backend as K
-from tensorflow.keras import activations, initializers, regularizers, constraints
 from tensorflow.keras.layers import (
-    Layer,
     InputSpec,
+    Layer,
 )
 from tensorflow.python.keras.layers.convolutional import Conv
 from tensorflow.python.keras.utils import conv_utils
-import numpy as np
-from .fft import fft, ifft, fft2, ifft2
+
 from .bn import ComplexBN as complex_normalization
 from .bn import sqrt_init
-from .init import ComplexInit, ComplexIndependentFilters
+from .fft import fft, fft2, ifft, ifft2
+from .init import ComplexIndependentFilters, ComplexInit
 
 
 def conv1d_transpose(
@@ -116,7 +117,9 @@ def conv2d_transpose(
     output_shape = (batch_size, out_height, out_width, filters)
     filter = K.permute_dimensions(filter, (0, 1, 3, 2))
-    return K.conv2d_transpose(inputs, filter, output_shape, strides, padding=padding, data_format=data_format)
+    return K.conv2d_transpose(
+        inputs, filter, output_shape, strides, padding=padding, data_format=data_format
+    )
 
 
 def ifft(f):
@@ -129,7 +132,9 @@ def ifft2(f):
     raise NotImplementedError(str(f))
 
 
-def conv_transpose_output_length(input_length, filter_size, padding, stride, dilation=1, output_padding=None):
+def conv_transpose_output_length(
+    input_length, filter_size, padding, stride, dilation=1, output_padding=None
+):
     """Rearrange arguments for compatibility with conv_output_length."""
     if dilation != 1:
         msg = f"Dilation must be 1 for transposed convolution. "
" @@ -278,8 +283,14 @@ def __init__( self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank, "kernel_size") self.strides = conv_utils.normalize_tuple(strides, rank, "strides") self.padding = conv_utils.normalize_padding(padding) - self.data_format = "channels_last" if rank == 1 else conv_utils.normalize_data_format(data_format) - self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, rank, "dilation_rate") + self.data_format = ( + "channels_last" + if rank == 1 + else conv_utils.normalize_data_format(data_format) + ) + self.dilation_rate = conv_utils.normalize_tuple( + dilation_rate, rank, "dilation_rate" + ) self.activation = activations.get(activation) self.use_bias = use_bias self.normalize_weight = normalize_weight @@ -321,7 +332,10 @@ def build(self, input_shape): else: channel_axis = -1 if input_shape[channel_axis] is None: - raise ValueError("The channel dimension of the inputs " "should be defined. Found `None`.") + raise ValueError( + "The channel dimension of the inputs " + "should be defined. Found `None`." + ) # Divide by 2 for real and complex input. input_dim = input_shape[channel_axis] // 2 if False and self.transposed: @@ -354,7 +368,7 @@ def build(self, input_shape): actual_kernel_shape[-1] *= 2 self.kernel = self.add_weight( - "kernel", + name="kernel", shape=tuple(actual_kernel_shape), initializer=kern_init, regularizer=self.kernel_regularizer, @@ -392,8 +406,8 @@ def build(self, input_shape): if self.use_bias: bias_shape = (2 * self.filters,) self.bias = self.add_weight( - "bias", - bias_shape, + name="bias", + shape=bias_shape, initializer=self.bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint, @@ -403,10 +417,12 @@ def build(self, input_shape): self.bias = None # Set input spec. 
-        self.input_spec = InputSpec(ndim=self.rank + 2, axes={channel_axis: input_dim * 2})
+        self.input_spec = InputSpec(
+            ndim=self.rank + 2, axes={channel_axis: input_dim * 2}
+        )
         self.built = True
 
-    def call(self, inputs, **kwargs):
+    def call(self, inputs):
         if self.data_format == "channels_first":
             channel_axis = 1
         else:
@@ -437,7 +453,9 @@ def call(self, inputs, **kwargs):
             "strides": self.strides[0] if self.rank == 1 else self.strides,
             "padding": self.padding,
             "data_format": self.data_format,
-            "dilation_rate": self.dilation_rate[0] if self.rank == 1 else self.dilation_rate,
+            "dilation_rate": self.dilation_rate[0]
+            if self.rank == 1
+            else self.dilation_rate,
         }
         if self.transposed:
             convArgs.pop("dilation_rate", None)
@@ -532,7 +550,9 @@ def call(self, inputs, **kwargs):
         cat_kernels_4_real = K.concatenate([f_real, -f_imag], axis=-2)
         cat_kernels_4_imag = K.concatenate([f_imag, f_real], axis=-2)
-        cat_kernels_4_complex = K.concatenate([cat_kernels_4_real, cat_kernels_4_imag], axis=-1)
+        cat_kernels_4_complex = K.concatenate(
+            [cat_kernels_4_real, cat_kernels_4_imag], axis=-1
+        )
 
         if False and self.transposed:
             cat_kernels_4_complex._keras_shape = self.kernel_size + (
                 2 * self.filters,
@@ -604,7 +624,9 @@ def get_config(self):
             "gamma_off_initializer": sanitizedInitSer(self.gamma_off_initializer),
             "kernel_regularizer": regularizers.serialize(self.kernel_regularizer),
             "bias_regularizer": regularizers.serialize(self.bias_regularizer),
-            "gamma_diag_regularizer": regularizers.serialize(self.gamma_diag_regularizer),
+            "gamma_diag_regularizer": regularizers.serialize(
+                self.gamma_diag_regularizer
+            ),
             "gamma_off_regularizer": regularizers.serialize(self.gamma_off_regularizer),
             "activity_regularizer": regularizers.serialize(self.activity_regularizer),
             "kernel_constraint": constraints.serialize(self.kernel_constraint),
@@ -1083,7 +1105,10 @@ def build(self, input_shape):
         else:
             channel_axis = -1
         if input_shape[channel_axis] is None:
-            raise ValueError("The channel dimension of the inputs " "should be defined. Found `None`.")
+            raise ValueError(
+                "The channel dimension of the inputs "
+                "should be defined. Found `None`."
+            )
         input_dim = input_shape[channel_axis]
         gamma_shape = (input_dim * self.filters,)
         self.gamma = self.add_weight(
@@ -1101,14 +1126,22 @@ def call(self, inputs):
         else:
             channel_axis = -1
         if input_shape[channel_axis] is None:
-            raise ValueError("The channel dimension of the inputs " "should be defined. Found `None`.")
+            raise ValueError(
+                "The channel dimension of the inputs "
+                "should be defined. Found `None`."
+            )
         input_dim = input_shape[channel_axis]
         ker_shape = self.kernel_size + (input_dim, self.filters)
         nb_kernels = ker_shape[-2] * ker_shape[-1]
         kernel_shape_4_norm = (np.prod(self.kernel_size), nb_kernels)
         reshaped_kernel = K.reshape(self.kernel, kernel_shape_4_norm)
-        normalized_weight = K.l2_normalize(reshaped_kernel, axis=0, epsilon=self.epsilon)
-        normalized_weight = K.reshape(self.gamma, (1, ker_shape[-2] * ker_shape[-1])) * normalized_weight
+        normalized_weight = K.l2_normalize(
+            reshaped_kernel, axis=0, epsilon=self.epsilon
+        )
+        normalized_weight = (
+            K.reshape(self.gamma, (1, ker_shape[-2] * ker_shape[-1]))
+            * normalized_weight
+        )
 
         shaped_kernel = K.reshape(normalized_weight, ker_shape)
         shaped_kernel._keras_shape = ker_shape
@@ -1116,7 +1149,9 @@ def call(self, inputs):
             "strides": self.strides[0] if self.rank == 1 else self.strides,
             "padding": self.padding,
             "data_format": self.data_format,
-            "dilation_rate": self.dilation_rate[0] if self.rank == 1 else self.dilation_rate,
+            "dilation_rate": self.dilation_rate[0]
+            if self.rank == 1
+            else self.dilation_rate,
         }
         convFunc = {1: K.conv1d, 2: K.conv2d, 3: K.conv3d}[self.rank]
         output = convFunc(inputs, shaped_kernel, **convArgs)
diff --git a/complexnn/dense.py b/complexnn/dense.py
index 9b31dc0..274ffd5 100644
--- a/complexnn/dense.py
+++ b/complexnn/dense.py
@@ -1,12 +1,12 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-from tensorflow.keras import backend as K
-from tensorflow.keras import backend as K
-from tensorflow.keras import activations, initializers, regularizers, constraints
-from tensorflow.keras.layers import Layer, InputSpec
 import numpy as np
 from numpy.random import RandomState
+from tensorflow.keras import activations, constraints, initializers, regularizers
+from tensorflow.keras import backend as K
+from tensorflow.keras.layers import InputSpec, Layer
+
 from .utils import _compute_fans
@@ -69,7 +69,7 @@ def __init__(
         kernel_constraint=None,
         bias_constraint=None,
         seed=None,
-        **kwargs
+        **kwargs,
     ):
         if "input_shape" not in kwargs and "input_dim" in kwargs:
             kwargs["input_shape"] = (kwargs.pop("input_dim"),)
@@ -163,15 +163,21 @@ def init_w_imag(shape, dtype=None):
         self.input_spec = InputSpec(ndim=2, axes={-1: 2 * input_dim})
         self.built = True
 
-    def call(self, inputs, **kwargs):
+    def call(self, inputs):
         input_shape = K.shape(inputs)
         input_dim = input_shape[-1] // 2
         real_input = inputs[:, :input_dim]
         imag_input = inputs[:, input_dim:]
 
-        cat_kernels_4_real = K.concatenate([self.real_kernel, -self.imag_kernel], axis=-1)
-        cat_kernels_4_imag = K.concatenate([self.imag_kernel, self.real_kernel], axis=-1)
-        cat_kernels_4_complex = K.concatenate([cat_kernels_4_real, cat_kernels_4_imag], axis=0)
+        cat_kernels_4_real = K.concatenate(
+            [self.real_kernel, -self.imag_kernel], axis=-1
+        )
+        cat_kernels_4_imag = K.concatenate(
+            [self.imag_kernel, self.real_kernel], axis=-1
+        )
+        cat_kernels_4_complex = K.concatenate(
+            [cat_kernels_4_real, cat_kernels_4_imag], axis=0
+        )
 
         output = K.dot(inputs, cat_kernels_4_complex)
diff --git a/complexnn/fft.py b/complexnn/fft.py
index 867b7d8..1a00162 100644
--- a/complexnn/fft.py
+++ b/complexnn/fft.py
@@ -2,12 +2,11 @@
 # -*- coding: utf-8 -*-
 
 # import tensorflow.keras.engine as KE
+import numpy as np
+import tensorflow as tf
 import tensorflow.keras.backend as KB
 import tensorflow.keras.layers as KL
 import tensorflow.keras.optimizers as KO
-import tensorflow as tf
-import numpy as np
-
 
 #
 # FFT functions:
@@ -26,10 +25,14 @@ def fft(z):
     Zr, Zi = tf.signal.rfft(z[:B]), tf.signal.rfft(z[B:])
     isOdd = tf.equal(L % 2, 1)
     Zr = tf.cond(
-        isOdd, tf.concat([Zr, C * Zr[:, 1:][:, ::-1]], axis=1), tf.concat([Zr, C * Zr[:, 1:-1][:, ::-1]], axis=1)
+        isOdd,
+        lambda: tf.concat([Zr, C * Zr[:, 1:][:, ::-1]], axis=1),
+        lambda: tf.concat([Zr, C * Zr[:, 1:-1][:, ::-1]], axis=1),
     )
     Zi = tf.cond(
-        isOdd, tf.concat([Zi, C * Zi[:, 1:][:, ::-1]], axis=1), tf.concat([Zi, C * Zi[:, 1:-1][:, ::-1]], axis=1)
+        isOdd,
+        lambda: tf.concat([Zi, C * Zi[:, 1:][:, ::-1]], axis=1),
+        lambda: tf.concat([Zi, C * Zi[:, 1:-1][:, ::-1]], axis=1),
     )
     Zi = (C * Zi)[:, :, ::-1]  # Zi * i
     Z = Zr + Zi
@@ -43,10 +46,14 @@ def ifft(z):
     Zr, Zi = tf.signal.rfft(z[:B]), tf.signal.rfft(z[B:] * -1)
     isOdd = tf.equal(L % 2, 1)
     Zr = tf.cond(
-        isOdd, tf.concat([Zr, C * Zr[:, 1:][:, ::-1]], axis=1), tf.concat([Zr, C * Zr[:, 1:-1][:, ::-1]], axis=1)
+        isOdd,
+        lambda: tf.concat([Zr, C * Zr[:, 1:][:, ::-1]], axis=1),
+        lambda: tf.concat([Zr, C * Zr[:, 1:-1][:, ::-1]], axis=1),
    )
     Zi = tf.cond(
-        isOdd, tf.concat([Zi, C * Zi[:, 1:][:, ::-1]], axis=1), tf.concat([Zi, C * Zi[:, 1:-1][:, ::-1]], axis=1)
+        isOdd,
+        lambda: tf.concat([Zi, C * Zi[:, 1:][:, ::-1]], axis=1),
+        lambda: tf.concat([Zi, C * Zi[:, 1:-1][:, ::-1]], axis=1),
     )
     Zi = (C * Zi)[:, :, ::-1]  # Zi * i
     Z = Zr + Zi
diff --git a/complexnn/init.py b/complexnn/init.py
index 480f41a..6ebd7fc 100644
--- a/complexnn/init.py
+++ b/complexnn/init.py
@@ -2,10 +2,14 @@
 # -*- coding: utf-8 -*-
 
 import numpy as np
-from numpy.random import RandomState
 import tensorflow.keras.backend as K
+from numpy.random import RandomState
 from tensorflow.keras.initializers import Initializer
-from tensorflow.python.keras.utils.generic_utils import serialize_keras_object, deserialize_keras_object
+from tensorflow.python.keras.utils.generic_utils import (
+    deserialize_keras_object,
+    serialize_keras_object,
+)
+
 from .utils import _compute_fans
@@ -13,8 +17,15 @@ class IndependentFilters(Initializer):
     # This initialization constructs real-valued kernels
     # that are independent as much as possible from each other
     # while respecting either the He or the Glorot criterion.
-    def __init__(self, kernel_size, input_dim, weight_dim, nb_filters=None, criterion="glorot", seed=None):
-
+    def __init__(
+        self,
+        kernel_size,
+        input_dim,
+        weight_dim,
+        nb_filters=None,
+        criterion="glorot",
+        seed=None,
+    ):
         # `weight_dim` is used as a parameter for sanity check
         # as we should not pass an integer as kernel_size when
         # the weight dimension is >= 2.
@@ -35,7 +46,6 @@ def __init__(self, kernel_size, input_dim, weight_dim, nb_filters=None, criterio
         self.seed = 1337 if seed is None else seed
 
     def __call__(self, shape, dtype=None):
-
         if self.nb_filters is not None:
             num_rows = self.nb_filters * self.input_dim
             num_cols = np.prod(self.kernel_size)
@@ -51,8 +61,12 @@ def __call__(self, shape, dtype=None):
         u, _, v = np.linalg.svd(x)
         orthogonal_x = np.dot(u, np.dot(np.eye(num_rows, num_cols), v.T))
         if self.nb_filters is not None:
-            independent_filters = np.reshape(orthogonal_x, (num_rows,) + tuple(self.kernel_size))
-            fan_in, fan_out = _compute_fans(tuple(self.kernel_size) + (self.input_dim, self.nb_filters))
+            independent_filters = np.reshape(
+                orthogonal_x, (num_rows,) + tuple(self.kernel_size)
+            )
+            fan_in, fan_out = _compute_fans(
+                tuple(self.kernel_size) + (self.input_dim, self.nb_filters)
+            )
         else:
             independent_filters = orthogonal_x
             fan_in, fan_out = (self.input_dim, self.kernel_size[-1])
@@ -67,7 +81,6 @@ def __call__(self, shape, dtype=None):
         multip_constant = np.sqrt(desired_var / np.var(independent_filters))
         scaled_indep = multip_constant * independent_filters
 
-
         if self.weight_dim == 2 and self.nb_filters is None:
             weight = scaled_indep
         else:
@@ -98,8 +111,15 @@ class ComplexIndependentFilters(Initializer):
     # This initialization constructs complex-valued kernels
     # that are independent as much as possible from each other
     # while respecting either the He or the Glorot criterion.
-    def __init__(self, kernel_size, input_dim, weight_dim, nb_filters=None, criterion="glorot", seed=None):
-
+    def __init__(
+        self,
+        kernel_size,
+        input_dim,
+        weight_dim,
+        nb_filters=None,
+        criterion="glorot",
+        seed=None,
+    ):
         # `weight_dim` is used as a parameter for sanity check
         # as we should not pass an integer as kernel_size when
         # the weight dimension is >= 2.
@@ -120,7 +140,6 @@ def __init__(self, kernel_size, input_dim, weight_dim, nb_filters=None, criterio
         self.seed = 1337 if seed is None else seed
 
     def __call__(self, shape, dtype=None):
-
         if self.nb_filters is not None:
             num_rows = self.nb_filters * self.input_dim
             num_cols = np.prod(self.kernel_size)
@@ -136,13 +155,17 @@ def __call__(self, shape, dtype=None):
         i = rng.uniform(size=flat_shape)
         z = r + 1j * i
         u, _, v = np.linalg.svd(z)
-        unitary_z = np.dot(u, np.dot(np.eye(int(num_rows), int(num_cols)), np.conjugate(v).T))
+        unitary_z = np.dot(
+            u, np.dot(np.eye(int(num_rows), int(num_cols)), np.conjugate(v).T)
+        )
         real_unitary = unitary_z.real
         imag_unitary = unitary_z.imag
         if self.nb_filters is not None:
             indep_real = np.reshape(real_unitary, (num_rows,) + tuple(self.kernel_size))
             indep_imag = np.reshape(imag_unitary, (num_rows,) + tuple(self.kernel_size))
-            fan_in, fan_out = _compute_fans(tuple(self.kernel_size) + (int(self.input_dim), self.nb_filters))
+            fan_in, fan_out = _compute_fans(
+                tuple(self.kernel_size) + (int(self.input_dim), self.nb_filters)
+            )
         else:
             indep_real = real_unitary
             indep_imag = imag_unitary
@@ -164,7 +187,10 @@ def __call__(self, shape, dtype=None):
             weight_real = scaled_real
             weight_imag = scaled_imag
         else:
-            kernel_shape = tuple(self.kernel_size) + (int(self.input_dim), self.nb_filters)
+            kernel_shape = tuple(self.kernel_size) + (
+                int(self.input_dim),
+                self.nb_filters,
+            )
             if self.weight_dim == 1:
                 transpose_shape = (1, 0)
             elif self.weight_dim == 2 and self.nb_filters is not None:
@@ -194,8 +220,15 @@ def get_config(self):
 
 
 class ComplexInit(Initializer):
     # The standard complex initialization using
     # either the He or the Glorot criterion.
-    def __init__(self, kernel_size, input_dim, weight_dim, nb_filters=None, criterion="glorot", seed=None):
-
+    def __init__(
+        self,
+        kernel_size,
+        input_dim,
+        weight_dim,
+        nb_filters=None,
+        criterion="glorot",
+        seed=None,
+    ):
         # `weight_dim` is used as a parameter for sanity check
         # as we should not pass an integer as kernel_size when
         # the weight dimension is >= 2.
@@ -216,7 +249,6 @@ def __init__(self, kernel_size, input_dim, weight_dim, nb_filters=None, criterio
         self.seed = 1337 if seed is None else seed
 
     def __call__(self, shape, dtype=None):
-
         if self.nb_filters is not None:
             kernel_shape = shape
             # kernel_shape = tuple(self.kernel_size) + (int(self.input_dim),
diff --git a/complexnn/norm.py b/complexnn/norm.py
index 2806083..d66e193 100644
--- a/complexnn/norm.py
+++ b/complexnn/norm.py
@@ -5,9 +5,10 @@
 
 import numpy as np
-from tensorflow.keras.layers import Layer, InputSpec
-from tensorflow.keras import initializers, regularizers, constraints
 import tensorflow.keras.backend as K
+from tensorflow.keras import constraints, initializers, regularizers
+from tensorflow.keras.layers import InputSpec, Layer
+
 from .bn import ComplexBN as complex_normalization
 from .bn import sqrt_init
@@ -53,9 +54,8 @@ def __init__(
         gamma_init="ones",
         gamma_regularizer=None,
         beta_regularizer=None,
-        **kwargs
+        **kwargs,
     ):
-
         self.supports_masking = True
         self.beta_init = initializers.get(beta_init)
         self.gamma_init = initializers.get(gamma_init)
@@ -67,14 +67,22 @@ def build(self, input_shape):
-        self.input_spec = InputSpec(ndim=len(input_shape), axes={self.axis: input_shape[self.axis]})
+        self.input_spec = InputSpec(
+            ndim=len(input_shape), axes={self.axis: input_shape[self.axis]}
+        )
         shape = (input_shape[self.axis],)
 
         self.gamma = self.add_weight(
-            shape, initializer=self.gamma_init, regularizer=self.gamma_regularizer, name="{}_gamma".format(self.name)
+            shape=shape,
+            initializer=self.gamma_init,
+            regularizer=self.gamma_regularizer,
+            name="{}_gamma".format(self.name),
         )
         self.beta = self.add_weight(
-            shape, initializer=self.beta_init, regularizer=self.beta_regularizer, name="{}_beta".format(self.name)
+            shape=shape,
+            initializer=self.beta_init,
+            regularizer=self.beta_regularizer,
+            name="{}_beta".format(self.name),
         )
         self.built = True
@@ -87,8 +95,12 @@ def get_config(self):
         config = {
             "epsilon": self.epsilon,
             "axis": self.axis,
-            "gamma_regularizer": self.gamma_regularizer.get_config() if self.gamma_regularizer else None,
-            "beta_regularizer": self.beta_regularizer.get_config() if self.beta_regularizer else None,
+            "gamma_regularizer": self.gamma_regularizer.get_config()
+            if self.gamma_regularizer
+            else None,
+            "beta_regularizer": self.beta_regularizer.get_config()
+            if self.beta_regularizer
+            else None,
         }
         base_config = super(LayerNormalization, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
@@ -110,9 +122,8 @@ def __init__(
         beta_constraint=None,
         gamma_diag_constraint=None,
         gamma_off_constraint=None,
-        **kwargs
+        **kwargs,
     ):
-
         self.supports_masking = True
         self.epsilon = epsilon
         self.axis = axis
@@ -130,7 +141,6 @@ def __init__(
         super(ComplexLayerNorm, self).__init__(**kwargs)
 
     def build(self, input_shape):
-
         ndim = len(input_shape)
         dim = input_shape[self.axis]
         if dim is None:
@@ -182,7 +192,7 @@ def build(self, input_shape):
 
         self.built = True
 
-    def call(self, inputs, **kwargs):
+    def call(self, inputs):
         input_shape = K.shape(inputs)
         ndim = K.ndim(inputs)
         reduction_axes = list(range(ndim))
@@ -237,7 +247,9 @@ def call(self, inputs, **kwargs):
             Vii = None
             Vri = None
         else:
-            raise ValueError("Error. Both scale and center in batchnorm are set to False.")
+            raise ValueError(
+                "Error. Both scale and center in batchnorm are set to False."
+            )
 
         return complex_normalization(
             input_centred,
@@ -261,10 +273,14 @@ def get_config(self):
             "center": self.center,
             "scale": self.scale,
             "beta_initializer": initializers.serialize(self.beta_initializer),
-            "gamma_diag_initializer": initializers.serialize(self.gamma_diag_initializer),
+            "gamma_diag_initializer": initializers.serialize(
+                self.gamma_diag_initializer
+            ),
             "gamma_off_initializer": initializers.serialize(self.gamma_off_initializer),
             "beta_regularizer": regularizers.serialize(self.beta_regularizer),
-            "gamma_diag_regularizer": regularizers.serialize(self.gamma_diag_regularizer),
+            "gamma_diag_regularizer": regularizers.serialize(
+                self.gamma_diag_regularizer
+            ),
             "gamma_off_regularizer": regularizers.serialize(self.gamma_off_regularizer),
             "beta_constraint": constraints.serialize(self.beta_constraint),
             "gamma_diag_constraint": constraints.serialize(self.gamma_diag_constraint),
diff --git a/complexnn/pool.py b/complexnn/pool.py
index e07405f..3eac79f 100644
--- a/complexnn/pool.py
+++ b/complexnn/pool.py
@@ -1,12 +1,11 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-import tensorflow.keras.backend as KB
+import numpy as np
 import tensorflow.keras as KE
+import tensorflow.keras.backend as KB
 import tensorflow.keras.layers as KL
 import tensorflow.keras.optimizers as KO
-import numpy as np
-
 
 #
 # Spectral Pooling Layer
@@ -37,14 +36,18 @@ def call(self, x, mask=None):
 
         if KB.image_data_format() == "channels_first":
             if topf[0] > 0 and xshape[2] >= 2 * topf[0]:
-                mask = [1] * (topf[0]) + [0] * (xshape[2] - 2 * topf[0]) + [1] * (topf[0])
+                mask = (
+                    [1] * (topf[0]) + [0] * (xshape[2] - 2 * topf[0]) + [1] * (topf[0])
+                )
                 mask = [[mask]]
                 mask = np.asarray(mask, dtype=KB.floatx()).transpose((0, 1, 2))
                 mask = KB.constant(mask)
                 x *= mask
         else:
             if topf[0] > 0 and xshape[1] >= 2 * topf[0]:
-                mask = [1] * (topf[0]) + [0] * (xshape[1] - 2 * topf[0]) + [1] * (topf[0])
+                mask = (
+                    [1] * (topf[0]) + [0] * (xshape[1] - 2 * topf[0]) + [1] * (topf[0])
+                )
                 mask = [[mask]]
                 mask = np.asarray(mask, dtype=KB.floatx()).transpose((0, 2, 1))
                 mask = KB.constant(mask)
@@ -77,26 +80,34 @@ def call(self, x, mask=None):
 
         if KB.image_data_format() == "channels_first":
             if topf[0] > 0 and xshape[2] >= 2 * topf[0]:
-                mask = [1] * (topf[0]) + [0] * (xshape[2] - 2 * topf[0]) + [1] * (topf[0])
+                mask = (
+                    [1] * (topf[0]) + [0] * (xshape[2] - 2 * topf[0]) + [1] * (topf[0])
+                )
                 mask = [[[mask]]]
                 mask = np.asarray(mask, dtype=KB.floatx()).transpose((0, 1, 3, 2))
                 mask = KB.constant(mask)
                 x *= mask
             if topf[1] > 0 and xshape[3] >= 2 * topf[1]:
-                mask = [1] * (topf[1]) + [0] * (xshape[3] - 2 * topf[1]) + [1] * (topf[1])
+                mask = (
+                    [1] * (topf[1]) + [0] * (xshape[3] - 2 * topf[1]) + [1] * (topf[1])
+                )
                 mask = [[[mask]]]
                 mask = np.asarray(mask, dtype=KB.floatx()).transpose((0, 1, 2, 3))
                 mask = KB.constant(mask)
                 x *= mask
         else:
             if topf[0] > 0 and xshape[1] >= 2 * topf[0]:
-                mask = [1] * (topf[0]) + [0] * (xshape[1] - 2 * topf[0]) + [1] * (topf[0])
+                mask = (
+                    [1] * (topf[0]) + [0] * (xshape[1] - 2 * topf[0]) + [1] * (topf[0])
+                )
                 mask = [[[mask]]]
                 mask = np.asarray(mask, dtype=KB.floatx()).transpose((0, 3, 1, 2))
                 mask = KB.constant(mask)
                 x *= mask
             if topf[1] > 0 and xshape[2] >= 2 * topf[1]:
-                mask = [1] * (topf[1]) + [0] * (xshape[2] - 2 * topf[1]) + [1] * (topf[1])
+                mask = (
+                    [1] * (topf[1]) + [0] * (xshape[2] - 2 * topf[1]) + [1] * (topf[1])
+                )
                 mask = [[[mask]]]
                 mask = np.asarray(mask, dtype=KB.floatx()).transpose((0, 1, 3, 2))
                 mask = KB.constant(mask)
@@ -106,10 +117,13 @@ def call(self, x, mask=None):
 
 
 if __name__ == "__main__":
-    import cv2, sys
-    import __main__ as SP
+    import sys
+
+    import cv2
     import fft as CF
+
+    import __main__ as SP
+
     # Build Model
     x = i = KL.Input(shape=(6, 512, 512))
     f = CF.FFT2()(x)
diff --git a/complexnn/utils.py b/complexnn/utils.py
index 55290c7..2ce9c90 100644
--- a/complexnn/utils.py
+++ b/complexnn/utils.py
@@ -1,10 +1,9 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
-import tensorflow.keras.backend as K
-from tensorflow.keras.layers import Layer, Lambda
 import numpy as np
-
+import tensorflow.keras.backend as K
+from tensorflow.keras.layers import Lambda, Layer
 
 #
 # GetReal/GetImag Lambda layer Implementation
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 6fdd59e..da654f1 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -27,7 +27,12 @@
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-extensions = ["sphinxcontrib.apidoc", "sphinx.ext.autodoc", "recommonmark", "sphinxcontrib.bibtex"]
+extensions = [
+    "sphinxcontrib.apidoc",
+    "sphinx.ext.autodoc",
+    "recommonmark",
+    "sphinxcontrib.bibtex",
+]
 
 master_doc = "index"  # Needed by RTD
diff --git a/pyproject.toml b/pyproject.toml
index cc62436..6c5846b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,6 +38,9 @@ dependencies = [
 ]
 dynamic = ["version"]
 
+[project.optional-dependencies]
+test = ["pytest"]
+
 [project.urls]
 homepage = "https://github.com/JesperDramsch/keras-complex/"
 documentation = "https://keras-complex.readthedocs.org"
@@ -46,4 +49,4 @@ changelog = "https://github.com/JesperDramsch/keras-complex/releases"
 bugtracker = "https://github.com/JesperDramsch/keras-complex/issues"
 
 [tool.setuptools.packages.find]
-exclude = ["tests*", "docs*"]
\ No newline at end of file
+exclude = ["tests*", "docs*"]
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..6068493
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,3 @@
+from setuptools import setup
+
+setup()
diff --git a/tests/test_conv.py b/tests/test_conv.py
index fd012f1..4d82ecb 100644
--- a/tests/test_conv.py
+++ b/tests/test_conv.py
@@ -1,10 +1,10 @@
 import unittest
 
-from tensorflow.keras.layers import Input, MaxPooling2D, Dense
-from tensorflow.keras.models import Model, Sequential
+import numpy as np
 import tensorflow as tf
+from tensorflow.keras.layers import Dense, Input, MaxPooling2D
+from tensorflow.keras.models import Model, Sequential
 
-import numpy as np
 import complexnn as conn
@@ -13,7 +13,9 @@ class TestConvMethods(unittest.TestCase):
     def test_conv_outputs_forward(self):
         """Test computed shape of forward convolution output"""
-        layer = conn.ComplexConv2D(filters=4, kernel_size=3, strides=2, padding="same", transposed=False)
+        layer = conn.ComplexConv2D(
+            filters=4, kernel_size=3, strides=2, padding="same", transposed=False
+        )
         input_shape = (None, 128, 128, 2)
         true = (None, 64, 64, 8)
         calc = layer.compute_output_shape(input_shape)
@@ -21,7 +23,9 @@ def test_outputs_transpose(self):
         """Test computed shape of transposed convolution output"""
-        layer = conn.ComplexConv2D(filters=2, kernel_size=3, strides=2, padding="same", transposed=True)
+        layer = conn.ComplexConv2D(
+            filters=2, kernel_size=3, strides=2, padding="same", transposed=True
+        )
         input_shape = (None, 64, 64, 4)
         true = (None, 128, 128, 4)
         calc = layer.compute_output_shape(input_shape)
@@ -30,7 +34,9 @@ def test_conv2D_forward(self):
         """Test shape of model output, forward"""
         inputs = Input(shape=(128, 128, 2))
-        outputs = conn.ComplexConv2D(filters=4, kernel_size=3, strides=2, padding="same", transposed=False)(inputs)
+        outputs = conn.ComplexConv2D(
+            filters=4, kernel_size=3, strides=2, padding="same", transposed=False
+        )(inputs)
         model = Model(inputs=inputs, outputs=outputs)
         true = (None, 64, 64, 8)
         calc = model.output_shape
@@ -40,7 +46,11 @@ def test_conv2Dtranspose(self):
         """Test shape of model output, transposed"""
         inputs = Input(shape=(64, 64, 20))  # = 10 CDN filters
         outputs = conn.ComplexConv2D(
-            filters=2, kernel_size=3, strides=2, padding="same", transposed=True  # = 4 Keras filters
+            filters=2,
+            kernel_size=3,
+            strides=2,
+            padding="same",
+            transposed=True,  # = 4 Keras filters
         )(inputs)
         model = Model(inputs=inputs, outputs=outputs)
         true = (None, 128, 128, 4)
diff --git a/tests/test_dense.py b/tests/test_dense.py
index 4f8364c..87eb0f8 100644
--- a/tests/test_dense.py
+++ b/tests/test_dense.py
@@ -1,10 +1,10 @@
 import unittest
 
-from tensorflow.keras.layers import Input, MaxPooling2D, Dense
-from tensorflow.keras.models import Model, Sequential
+import numpy as np
 import tensorflow as tf
+from tensorflow.keras.layers import Dense, Input, MaxPooling2D
+from tensorflow.keras.models import Model, Sequential
 
-import numpy as np
 import complexnn as conn
diff --git a/tests/test_readme.py b/tests/test_readme.py
index 795e274..9f436cd 100644
--- a/tests/test_readme.py
+++ b/tests/test_readme.py
@@ -1,10 +1,10 @@
 import unittest
 
-from tensorflow.keras.layers import Input, MaxPooling2D, Dense
-from tensorflow.keras.models import Model, Sequential
+import numpy as np
 import tensorflow as tf
+from tensorflow.keras.layers import Dense, Input, MaxPooling2D
+from tensorflow.keras.models import Model, Sequential
 
-import numpy as np
 import complexnn as conn
@@ -14,8 +14,12 @@ class TestDNCMethods(unittest.TestCase):
     def test_github_example(self):
         # example from the repository README: https://github.com/JesperDramsch/keras-complex/blob/master/README.md
         model = tf.keras.models.Sequential()
-        model.add(conn.conv.ComplexConv2D(32, (3, 3), activation="relu", padding="same", input_shape=(28, 28, 2)))
+        model.add(
+            conn.conv.ComplexConv2D(
+                32, (3, 3), activation="relu", padding="same", input_shape=(28, 28, 2)
+            )
+        )
         model.add(conn.bn.ComplexBatchNormalization())
         model.add(MaxPooling2D((2, 2), padding="same"))
         model.compile(optimizer=tf.keras.optimizers.Adam(), loss="mse")
-        model.summary()
\ No newline at end of file
+        model.summary()
diff --git a/tests/test_train.py b/tests/test_train.py
index 2bb5940..75452f0 100644
--- a/tests/test_train.py
+++ b/tests/test_train.py
@@ -1,10 +1,10 @@
 import unittest
 
-from tensorflow.keras.layers import Input, MaxPooling2D, Dense
-from tensorflow.keras.models import Model, Sequential
+import numpy as np
 import tensorflow as tf
+from tensorflow.keras.layers import Dense, Input, MaxPooling2D
+from tensorflow.keras.models import Model, Sequential
 
-import numpy as np
 import complexnn as conn
@@ -20,10 +20,18 @@ def test_train_transpose(self):
         Y = X
         inputs = Input(shape=(64, 64, 2))
         conv1 = conn.ComplexConv2D(
-            filters=2, kernel_size=3, strides=2, padding="same", transposed=False  # = 4 Keras filters
+            filters=2,
+            kernel_size=3,
+            strides=2,
+            padding="same",
+            transposed=False,  # = 4 Keras filters
         )(inputs)
         outputs = conn.ComplexConv2D(
-            filters=1, kernel_size=3, strides=2, padding="same", transposed=True  # = 2 Keras filters => 1 complex layer
+            filters=1,
+            kernel_size=3,
+            strides=2,
+            padding="same",
+            transposed=True,  # = 2 Keras filters => 1 complex layer
         )(conv1)
         model = Model(inputs=inputs, outputs=outputs)
         model.compile(optimizer="adam", loss="mean_squared_error", metrics=["accuracy"])
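To smoke-test the patch locally, `pip install -e .[test]` picks up the `test` extra added in pyproject.toml above, and `pytest` runs the reformatted suite. The sketch below mirrors the encoder/decoder round trip exercised by tests/test_train.py; the random input array, batch size, and single epoch are illustrative assumptions, not values taken from the tests.

```python
# Minimal round-trip sketch based on tests/test_train.py: a forward complex
# convolution followed by its transposed counterpart autoencodes the input.
import numpy as np
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

import complexnn as conn

# 2 input channels = stacked real and imaginary parts of 1 complex channel.
X = np.random.randn(8, 64, 64, 2).astype("float32")  # illustrative data

inputs = Input(shape=(64, 64, 2))
# filters=2 complex filters -> 4 Keras channels; strides=2 halves each spatial dim.
conv1 = conn.ComplexConv2D(
    filters=2, kernel_size=3, strides=2, padding="same", transposed=False
)(inputs)
# The transposed conv restores the input shape: (32, 32, 4) -> (64, 64, 2).
outputs = conn.ComplexConv2D(
    filters=1, kernel_size=3, strides=2, padding="same", transposed=True
)(conv1)

model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="mean_squared_error")
model.fit(X, X, epochs=1, batch_size=2)  # should train without shape errors
```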