Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] instantaneous rate with list[neo.SpikeTrain] #649

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
4 changes: 2 additions & 2 deletions elephant/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto',

Parameters
----------
spiketrains : neo.SpikeTrain, list of neo.SpikeTrain or elephant.trials.Trials # noqa
spiketrains : neo.SpikeTrain, list of neo.SpikeTrain or elephant.trials.Trials
Input spike train(s) for which the instantaneous firing rate is
calculated. If a list of spike trains is supplied, the parameter
pool_spike_trains determines the behavior of the function. If a Trials
Expand Down Expand Up @@ -1031,7 +1031,7 @@ def optimal_kernel(st):
sigma=str(kernel.sigma),
invert=kernel.invert)

if isinstance(spiketrains, neo.core.spiketrainlist.SpikeTrainList) and (
if isinstance(spiketrains, (neo.core.spiketrainlist.SpikeTrainList, list, tuple)) and (
pool_spike_trains):
rate = np.mean(rate, axis=1)

Expand Down
58 changes: 53 additions & 5 deletions elephant/test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,15 +482,15 @@ def test_cv2_raise_error(self):
self.assertRaises(ValueError, statistics.cv2, np.array([seq, seq]))


class InstantaneousRateTest(unittest.TestCase):
class InstantaneousRateTestCase(unittest.TestCase):

@classmethod
def setUpClass(cls) -> None:
"""
Run once before tests:
"""

block = _create_trials_block(n_trials=36)
block = _create_trials_block(n_trials=36, n_spiketrains=5)
cls.block = block
cls.trial_object = TrialsFromBlock(block,
description='trials are segments')
Expand Down Expand Up @@ -988,8 +988,44 @@ def test_instantaneous_rate_trials_pool_trials(self):
pool_spike_trains=False,
pool_trials=True)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.shape[1], self.trial_object.n_spiketrains_trial_by_trial[0])

def test_instantaneous_rate_list_pool_spike_trains(self):
def test_instantaneous_rate_trials_pool_spiketrains(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(self.trial_object,
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=True,
pool_trials=False)
self.assertIsInstance(rate, list)
self.assertEqual(len(rate), self.trial_object.n_trials)
self.assertEqual(rate[0].shape[1], 1)

def test_instantaneous_rate_trials_pool_spiketrains_pool_trials(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(self.trial_object,
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=True,
pool_trials=True)
self.assertIsInstance(rate, neo.AnalogSignal)
self.assertEqual(rate.shape[1], 1)

def test_instantaneous_rate_trials_pool_spiketrains_false_pool_trials_false(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(self.trial_object,
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=False,
pool_trials=False)
self.assertIsInstance(rate, list)
self.assertEqual(len(rate), self.trial_object.n_trials)
self.assertEqual(rate[0].shape[1], self.trial_object.n_spiketrains_trial_by_trial[0])

def test_instantaneous_rate_spiketrainlist_pool_spike_trains(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(
Expand All @@ -999,7 +1035,19 @@ def test_instantaneous_rate_list_pool_spike_trains(self):
pool_spike_trains=True,
pool_trials=False)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.magnitude.shape[1], 1)
self.assertEqual(rate.shape[1], 1)

def test_instantaneous_rate_list_pool_spike_trains(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)

rate = statistics.instantaneous_rate(
list(self.trial_object.get_spiketrains_from_trial_as_list(0)),
sampling_period=0.1 * pq.ms,
kernel=kernel,
pool_spike_trains=True,
pool_trials=False)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.shape[1], 1)

def test_instantaneous_rate_list_of_spike_trains(self):
kernel = kernels.GaussianKernel(sigma=500 * pq.ms)
Expand All @@ -1010,7 +1058,7 @@ def test_instantaneous_rate_list_of_spike_trains(self):
pool_spike_trains=False,
pool_trials=False)
self.assertIsInstance(rate, neo.core.AnalogSignal)
self.assertEqual(rate.magnitude.shape[1], 2)
self.assertEqual(rate.magnitude.shape[1], self.trial_object.n_spiketrains_trial_by_trial[0])


class TimeHistogramTestCase(unittest.TestCase):
Expand Down
Loading