Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] instantaneous rate with list[neo.SpikeTrain] #649

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
7 changes: 3 additions & 4 deletions elephant/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
import elephant.trials
from elephant.conversion import BinnedSpikeTrain
from elephant.utils import deprecated_alias, check_neo_consistency, \
is_time_quantity, round_binning_errors
is_time_quantity, round_binning_errors, is_list_spiketrains

# do not import unicode_literals
# (quantities rescale does not work with unicodes)
Expand Down Expand Up @@ -613,7 +613,7 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto',

Parameters
----------
spiketrains : neo.SpikeTrain, list of neo.SpikeTrain or elephant.trials.Trials # noqa
spiketrains : neo.SpikeTrain, list of neo.SpikeTrain or elephant.trials.Trials
Input spike train(s) for which the instantaneous firing rate is
calculated. If a list of spike trains is supplied, the parameter
pool_spike_trains determines the behavior of the function. If a Trials
Expand Down Expand Up @@ -1031,8 +1031,7 @@ def optimal_kernel(st):
sigma=str(kernel.sigma),
invert=kernel.invert)

if isinstance(spiketrains, neo.core.spiketrainlist.SpikeTrainList) and (
pool_spike_trains):
if is_list_spiketrains(spiketrains) and (pool_spike_trains):
rate = np.mean(rate, axis=1)

rate = neo.AnalogSignal(signal=rate,
Expand Down
58 changes: 53 additions & 5 deletions elephant/test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,15 +482,15 @@ def test_cv2_raise_error(self):
self.assertRaises(ValueError, statistics.cv2, np.array([seq, seq]))


class InstantaneousRateTest(unittest.TestCase):
class InstantaneousRateTestCase(unittest.TestCase):

@classmethod
def setUpClass(cls) -> None:
"""
Run once before tests:
"""

block = _create_trials_block(n_trials=36)
block = _create_trials_block(n_trials=36, n_spiketrains=5)
cls.block = block
cls.trial_object = TrialsFromBlock(block,
description='trials are segments')
Expand Down Expand Up @@ -988,8 +988,44 @@ def test_instantaneous_rate_trials_pool_trials(self):
pool_spike_trains=False,
pool_trials=True)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.shape[1], self.trial_object.n_spiketrains_trial_by_trial[0])

def test_instantaneous_rate_list_pool_spike_trains(self):
def test_instantaneous_rate_trials_pool_spiketrains(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(self.trial_object,
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=True,
pool_trials=False)
self.assertIsInstance(rate, list)
self.assertEqual(len(rate), self.trial_object.n_trials)
self.assertEqual(rate[0].shape[1], 1)

def test_instantaneous_rate_trials_pool_spiketrains_pool_trials(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(self.trial_object,
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=True,
pool_trials=True)
self.assertIsInstance(rate, neo.AnalogSignal)
self.assertEqual(rate.shape[1], 1)

def test_instantaneous_rate_trials_pool_spiketrains_false_pool_trials_false(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(self.trial_object,
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=False,
pool_trials=False)
self.assertIsInstance(rate, list)
self.assertEqual(len(rate), self.trial_object.n_trials)
self.assertEqual(rate[0].shape[1], self.trial_object.n_spiketrains_trial_by_trial[0])

def test_instantaneous_rate_spiketrainlist_pool_spike_trains(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(
Expand All @@ -999,7 +1035,19 @@ def test_instantaneous_rate_list_pool_spike_trains(self):
pool_spike_trains=True,
pool_trials=False)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.magnitude.shape[1], 1)
self.assertEqual(rate.shape[1], 1)

def test_instantaneous_rate_list_pool_spike_trains(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(
list(self.trial_object.get_spiketrains_from_trial_as_list(0)),
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=True,
pool_trials=False)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.shape[1], 1)

def test_instantaneous_rate_list_of_spike_trains(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)
Expand All @@ -1010,7 +1058,7 @@ def test_instantaneous_rate_list_of_spike_trains(self):
pool_spike_trains=False,
pool_trials=False)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.magnitude.shape[1], 2)
self.assertEqual(rate.magnitude.shape[1], self.trial_object.n_spiketrains_trial_by_trial[0])


class TimeHistogramTestCase(unittest.TestCase):
Expand Down
37 changes: 37 additions & 0 deletions elephant/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,42 @@ def test_decorator_return_with_list_of_lists_input_as_kwarg(self):
self.assertIsInstance(spiketrain, SpikeTrain)


class TestIsListNeoSpiketrains(unittest.TestCase):
def setUp(self):
# Set up common test spiketrains.
self.spiketrain1 = neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=4 * pq.s)
self.spiketrain2 = neo.SpikeTrain([2, 3, 4] * pq.s, t_stop=5 * pq.s)

def test_valid_list_input(self):
valid_list = [self.spiketrain1, self.spiketrain2]
self.assertTrue(utils.is_list_spiketrains(valid_list))

def test_valid_tuple_input(self):
valid_tuple = (self.spiketrain1, self.spiketrain2)
self.assertTrue(utils.is_list_spiketrains(valid_tuple))

def test_valid_spiketrainlist_input(self):
valid_spiketrainlist = neo.core.spiketrainlist.SpikeTrainList(items=(self.spiketrain1, self.spiketrain2))
self.assertTrue(utils.is_list_spiketrains(valid_spiketrainlist))

def test_non_iterable_input(self):
self.assertFalse(utils.is_list_spiketrains(42))

def test_non_spiketrain_objects(self):
invalid_list = [self.spiketrain1, "not a spiketrain"]
self.assertFalse(utils.is_list_spiketrains(invalid_list))

def test_mixed_types_input(self):
invalid_mixed = [self.spiketrain1, 42, self.spiketrain2]
self.assertFalse(utils.is_list_spiketrains(invalid_mixed))

def test_none_input(self):
self.assertFalse(utils.is_list_spiketrains(None))

def test_single_spiketrain_input(self):
single_spiketrain = neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=4 * pq.s)
self.assertFalse(utils.is_list_spiketrains(single_spiketrain))


if __name__ == '__main__':
unittest.main()
33 changes: 32 additions & 1 deletion elephant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
check_neo_consistency
check_same_units
round_binning_errors
is_list_spiketrains
"""

from __future__ import division, print_function, unicode_literals
Expand All @@ -21,7 +22,8 @@
import quantities as pq

from elephant.trials import Trials

import collections.abc
import neo

__all__ = [
"deprecated_alias",
Expand All @@ -31,6 +33,7 @@
"check_neo_consistency",
"check_same_units",
"round_binning_errors",
"is_list_spiketrains",
]


Expand Down Expand Up @@ -446,3 +449,31 @@ def wrapper(*args, **kwargs):
return method(*new_args, **new_kwargs)

return wrapper


def is_list_spiketrains(obj: object) -> bool:
"""
Check if input is an iterable containing only neo.SpikeTrain objects.

Parameters
----------
obj : object
The object to check.

Returns
-------
bool
True if obj is an iterable containing only neo.SpikeTrain objects. A single `neo.SpikeTrain` object (not inside
an iterable) will return `False`.

"""

if not isinstance(obj, collections.abc.Iterable):
# Input must be an iterable (list, tuple, etc.)
return False

if not all(isinstance(st, neo.SpikeTrain) for st in obj):
# All elements must be neo.SpikeTrain objects
return False

return True
Loading