This is an automated email from the git hooks/post-receive script. yoh pushed a commit to annotated tag v0.2 in repository python-mne.
commit 6799065d5f906de43c2ffa2bf297a26b7aa419d5 Author: Martin Luessi <[email protected]> Date: Wed Sep 28 15:26:21 2011 -0400 Multiple fixes: - indexing and slicing now always returns Epochs object - fixed bugs that occurs when Epochs only has one event (len(self.events) is 3 for single event) - still using shallow copy to avoid copying raw --- mne/epochs.py | 82 +++++++++++++++++++++++++----------------------- mne/tests/test_epochs.py | 6 ++-- 2 files changed, 45 insertions(+), 43 deletions(-) diff --git a/mne/epochs.py b/mne/epochs.py index f4c44e4..f61f393 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -88,9 +88,9 @@ class Epochs(object): ------- epochs = Epochs(...) - epochs[idx] : Return epoch with index idx (2D array, [n_channels, n_times]) - - epochs[start:stop] : Return Epochs object with a subset of epochs + epochs[idx] : Epochs + Return Epochs object with a subset of epochs (supports single + index and python style slicing) """ def __init__(self, raw, events, event_id, tmin, tmax, baseline=(None, 0), @@ -177,7 +177,7 @@ class Epochs(object): # Select the desired events selected = np.logical_and(events[:, 1] == 0, events[:, 2] == event_id) self.events = events[selected] - n_events = len(self.events) + n_events = len(self) if n_events > 0: print '%d matching events found' % n_events @@ -232,7 +232,7 @@ class Epochs(object): return good_events = [] - n_events = len(self.events) + n_events = len(self) for idx in range(n_events): epoch = self._get_epoch_from_disk(idx) if self._is_good_epoch(epoch): @@ -246,7 +246,12 @@ class Epochs(object): def _get_epoch_from_disk(self, idx): """Load one epoch from disk""" sfreq = self.raw.info['sfreq'] - event_samp = self.events[idx, 0] + + if self.events.ndim == 1: + #single event + event_samp = self.events[0] + else: + event_samp = self.events[idx, 0] # Read a data segment first_samp = self.raw.first_samp @@ -268,15 +273,15 @@ class Epochs(object): """ n_channels = len(self.ch_names) n_times = len(self.times) - n_events = len(self.events) + n_events = len(self) data = np.empty((n_events, n_channels, n_times)) cnt = 0 n_reject = 0 event_idx = [] for k in range(n_events): - e = self._get_epoch_from_disk(k) - if self._is_good_epoch(e): - data[cnt] = self._get_epoch_from_disk(k) + epoch = self._get_epoch_from_disk(k) + if self._is_good_epoch(epoch): + data[cnt] = epoch event_idx.append(k) cnt += 1 else: @@ -342,7 +347,7 @@ class Epochs(object): epoch = self._data[self._current] self._current += 1 else: - if self._current >= len(self.events): + if self._current >= len(self): raise StopIteration epoch = self._get_epoch_from_disk(self._current) self._current += 1 @@ -353,9 +358,9 @@ class Epochs(object): def __repr__(self): if not self.bad_dropped: - s = "n_events : %s (good & bad)" % len(self.events) + s = "n_events : %s (good & bad)" % len(self) else: - s = "n_events : %s (all good)" % len(self.events) + s = "n_events : %s (all good)" % len(self) s += ", tmin : %s (s)" % self.tmin s += ", tmax : %s (s)" % self.tmax s += ", baseline : %s" % str(self.baseline) @@ -364,38 +369,35 @@ class Epochs(object): def __len__(self): """Return length (number of events) """ - return len(self.events) + if self.events.ndim == 1: + return 1 + else: + return len(self.events) - def __getitem__(self, index): - """Return epoch at index or an Epochs object with a slice of epochs + def __getitem__(self, key): + """Return an Epochs object with a subset of epochs """ - if isinstance(index, slice): - # return Epochs object with slice of epochs - if not self.bad_dropped: - warnings.warn("Bad epochs have not been dropped, indexing " - "will be inccurate. Use drop_bad_epochs() " - "or preload=True") - - epoch_slice = copy.copy(self) - epoch_slice.events = self.events[index] - - if self.preload: - epoch_slice._data = self._data[index] + print key + if not self.bad_dropped: + warnings.warn("Bad epochs have not been dropped, indexing " + "will be inccurate. Use drop_bad_epochs() " + "or preload=True") - return epoch_slice + epochs = copy.copy(self) + epochs.events = self.events[key] - # return single epoch as 2D array if self.preload: - epoch = epoch = self._data[index] - else: - epoch = self._get_epoch_from_disk(index) - - if not self._is_good_epoch(epoch): - warnings.warn("Bad epoch with index %d returned. " - "Use drop_bad_epochs() or preload=True " - "to prevent this." % (index)) + if isinstance(key, slice): + epochs._data = self._data[key] + else: + #make sure data remains a 3D array + n_channels = len(self.ch_names) + n_times = len(self.times) + data = np.empty((1, n_channels, n_times)) + data[0, :, :] = self._data[key] + epochs._data = data - return epoch + return epochs def average(self): """Compute average of epochs @@ -409,7 +411,7 @@ class Epochs(object): evoked.info = copy.deepcopy(self.info) n_channels = len(self.ch_names) n_times = len(self.times) - n_events = len(self.events) + n_events = len(self) if self.preload: data = np.mean(self._data, axis=0) else: diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 545d91a..120cf87 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -96,10 +96,9 @@ def test_indexing_slicing(): if not preload: epochs2.drop_bad_epochs() - # get slice + # using slicing epochs2_sliced = epochs2[start_index:end_index] - # using get_data() data_epochs2_sliced = epochs2_sliced.get_data() assert_array_equal(data_epochs2_sliced, \ data_normal[start_index:end_index]) @@ -107,7 +106,8 @@ def test_indexing_slicing(): # using indexing pos = 0 for idx in range(start_index, end_index): - assert_array_equal(epochs2_sliced[pos], data_normal[idx]) + data = epochs2_sliced[pos].get_data() + assert_array_equal(data[0], data_normal[idx]) pos += 1 -- Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-med/python-mne.git _______________________________________________ debian-med-commit mailing list [email protected] http://lists.alioth.debian.org/cgi-bin/mailman/listinfo/debian-med-commit
