Skip to content

Commit

Permalink
Require environmental variable SPECTRAL_CONNECTIVITY_ENABLE_GPU to …
Browse files Browse the repository at this point in the history
…be set to 'true' in order to use GPU along with installation of cupy
  • Loading branch information
edeno committed May 22, 2022
1 parent ec9d57f commit a36553c
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 34 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
+ it implements the canonical coherence, which can
efficiently summarize brain-area level coherences from multielectrode recordings.
+ easier user interface for the multitaper fourier transform
+ all function are GPU-enabled if `cupy` is installed and the environmental variable `SPECTRAL_CONNECTIVITY_ENABLE_GPU` is set to 'true'.

See the notebooks ([\#1](examples/Tutorial_On_Simulated_Examples.ipynb), [\#2](examples/Tutorial_Using_Paper_Examples.ipynb)) for more information on how to use the package.

Expand All @@ -26,7 +27,7 @@ m = Multitaper(time_series=signals,
time_window_duration=0.060,
time_window_step=0.060,
start_time=time[0])

# Sets up computing connectivity measures/power from multitaper spectral estimate
c = Connectivity.from_multitaper(m)

Expand Down
7 changes: 3 additions & 4 deletions spectral_connectivity/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# flake8: noqa
from .connectivity import Connectivity
from .wrapper import multitaper_connectivity
from .transforms import Multitaper
from .wrapper import multitaper_connectivity
from spectral_connectivity.connectivity import Connectivity
from spectral_connectivity.transforms import Multitaper
from spectral_connectivity.wrapper import multitaper_connectivity
34 changes: 21 additions & 13 deletions spectral_connectivity/connectivity.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
import os
from functools import partial, wraps
from inspect import signature
from itertools import combinations

try:
import cupy as xp
from cupyx.scipy.fft import ifft
from cupyx.scipy.sparse.linalg import svds
except ImportError:
import numpy as xp
from scipy.fft import ifft
from scipy.sparse.linalg import svds

import numpy as np
from scipy.ndimage import label
from scipy.stats.mstats import linregress

from .minimum_phase_decomposition import minimum_phase_decomposition
from .statistics import (adjust_for_multiple_comparisons, coherence_bias,
fisher_z_transform, get_normal_distribution_p_values)
from spectral_connectivity.minimum_phase_decomposition import \
minimum_phase_decomposition
from spectral_connectivity.statistics import (adjust_for_multiple_comparisons,
coherence_bias,
fisher_z_transform,
get_normal_distribution_p_values)

if os.environ.get('SPEC_CON_ENABLE_GPU') == 'true':
try:
import cupy as xp
from cupyx.scipy.fft import ifft
from cupyx.scipy.sparse.linalg import svds
except ImportError:
import numpy as xp
from scipy.fft import ifft
from scipy.sparse.linalg import svds
else:
import numpy as xp
from scipy.fft import ifft
from scipy.sparse.linalg import svds

EXPECTATION = {
'time': partial(xp.mean, axis=0),
Expand Down
16 changes: 11 additions & 5 deletions spectral_connectivity/minimum_phase_decomposition.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import os
from logging import getLogger

try:
import cupy as xp
from cupyx.scipy.fft import fft, ifft
except ImportError:
import numpy as np

if os.environ.get('SPECTRAL_CONNECTIVITY_ENABLE_GPU') == 'true':
try:
import cupy as xp
from cupyx.scipy.fft import fft, ifft
except ImportError:
import numpy as xp
from scipy.fft import fft, ifft
else:
import numpy as xp
from scipy.fft import fft, ifft

import numpy as np

logger = getLogger(__name__)

Expand Down
30 changes: 19 additions & 11 deletions spectral_connectivity/transforms.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
import os
from logging import getLogger

from scipy import interpolate

try:
import cupy as xp
from cupy.linalg import lstsq
from cupyx.scipy.fft import fft, fftfreq, ifft, next_fast_len
except ImportError:
import numpy as xp
from scipy.fftpack import fft, ifft, next_fast_len, fftfreq
from scipy.linalg import lstsq

import numpy as np
from scipy import interpolate
from scipy.linalg import eigvals_banded

logger = getLogger(__name__)

if os.environ.get('SPECTRAL_CONNECTIVITY_ENABLE_GPU') == 'true':
try:
logger.info('Using GPU for spectral_connectivity...')
import cupy as xp
from cupy.linalg import lstsq
from cupyx.scipy.fft import fft, fftfreq, ifft, next_fast_len
except ImportError:
print('Cupy not installed. Cupy is needed to use GPU for '
'spectral_connectivity.')
import numpy as xp
from scipy.fftpack import fft, fftfreq, ifft, next_fast_len
from scipy.linalg import lstsq
else:
import numpy as xp
from scipy.fftpack import fft, fftfreq, ifft, next_fast_len
from scipy.linalg import lstsq


class Multitaper(object):
'''Transform time-domain signal(s) to the frequency domain by using
Expand Down

0 comments on commit a36553c

Please sign in to comment.