Source code for mne_hfo.detect

import collections
from typing import Tuple, Union

import mne
import numpy as np
from joblib import Parallel, delayed, cpu_count
from mne.utils import warn
from scipy.signal import hilbert
from tqdm import tqdm

from mne_hfo.base import Detector
from mne_hfo.config import ACCEPTED_BAND_METHODS
from mne_hfo.posthoc import _check_detection_overlap


def _band_z_score_detect(x_cond, sfreq, band_idx, l_freq, h_freq,
                         cycle_threshold, gap_threshold, threshold):
    """Iterate bandpass and z-score a channel's signal.

    Creates a bandpass filter (order=3) between ``l_freq`` and ``h_freq``.
    Then it z-scores the data

    Parameters
    ----------
    x_cond :
    sfreq :
    band_idx :
    l_freq :
    h_freq :
    cycle_threshold :
    gap_threshold :
    threshold :

    Returns
    -------
    tdetects : list of tuples
        All HFO events that passed the bandpass, zscore. It will store
        the band index (``band_idx``), timepoint of start, endpoint of start,
        and the maximum value of the Hilbert envelope in this event window.
    """
    tdetects = []

    # create highpass and lowpass filters
    x_cond = mne.filter.filter_data(
        x_cond, sfreq=sfreq,
        l_freq=l_freq, h_freq=h_freq, method='iir')
    # b, a = butter(N=3, Wn=l_freq / (sfreq / 2), btype='highpass')
    # x_cond = filtfilt(b, a, x_cond)
    # b, a = butter(N=3, Wn=h_freq / (sfreq / 2), btype='lowpass')
    # x_cond = filtfilt(b, a, x_cond)

    # Compute the z-scores
    x_cond = (x_cond - np.mean(x_cond)) / np.std(x_cond)

    # compute the absolute value of the Hilbert transform
    # (i.e. the envelope)
    hfx = np.abs(hilbert(x_cond))

    # threshold the Hilbert envelope to create a
    # thresholded mask
    thresh_sig = np.zeros(len(x_cond), dtype='bool')
    thresh_sig[hfx > threshold] = 1

    # Now get the lengths of each detected HFO
    idx = 0

    # indices where the threshold was met
    thresh_idxs = np.where(thresh_sig == 1)[0]
    gap_samp = round(gap_threshold * sfreq / l_freq)

    # loop through all significant zscore time points
    while idx < len(thresh_idxs) - 1:
        if (thresh_idxs[idx + 1] - thresh_idxs[idx]) == 1:
            start_idx = thresh_idxs[idx]
            while idx < len(thresh_idxs) - 1:
                if (thresh_idxs[idx + 1] - thresh_idxs[idx]) == 1:
                    idx += 1  # Move to the end of the detection
                    if idx == len(thresh_idxs) - 1:
                        stop_idx = thresh_idxs[idx]
                        # Check for number of cycles
                        dur = (stop_idx - start_idx) / sfreq
                        cycs = l_freq * dur
                        if cycs > cycle_threshold:
                            # Carry the amplitude and frequency info
                            tdetects.append([band_idx, start_idx, stop_idx,
                                             max(hfx[start_idx:stop_idx])])
                else:  # Check for gap
                    if (thresh_idxs[idx + 1] - thresh_idxs[idx]) < gap_samp:
                        idx += 1
                    else:
                        stop_idx = thresh_idxs[idx]
                        # Check for number of cycles
                        dur = (stop_idx - start_idx) / sfreq
                        cycs = l_freq * dur
                        if cycs > cycle_threshold:
                            tdetects.append([band_idx, start_idx, stop_idx,
                                             max(hfx[start_idx:stop_idx])])
                        idx += 1
                        break
        else:
            idx += 1

    return tdetects


def _run_detect_branch(detects, det_idx, HFO_outline):
    """
    Process detections from hilbert detector.

    HFO_outline structure:
    [0] - bands in which the detection happened
    [1] - starts for each band
    [2] - stop for each band
    [3] -
    """
    HFO_outline.append(np.copy(detects[det_idx, :]))

    # Create a subset for next band
    next_band_idcs = np.where(detects[:, 0] == detects[det_idx, 0] + 1)
    if not len((next_band_idcs)[0]):
        # No detects in band - finish the branch
        detects[det_idx, 0] = 0  # Set the processed detect to zero
        return HFO_outline
    else:
        # Get overllaping detects
        for next_det_idx in next_band_idcs[0]:
            if _check_detection_overlap([detects[det_idx, 1], detects[det_idx,
                                                                      2]],
                                        [detects[next_det_idx, 1],
                                         detects[next_det_idx,
                                                 2]]):
                # Go up the tree
                _run_detect_branch(detects, next_det_idx, HFO_outline)

        detects[det_idx, 0] = 0
        return HFO_outline


class HilbertDetector(Detector):  # noqa
    """2D HFO hilbert detection used in Kucewicz et al. 2014.

    A multi-taper method with: 4 Hz bandwidth, 1 sec sliding window,
    stepsize 100 ms, for the 1-500 Hz range, no padding, 2 tapers.

    Parameters
    ----------
    sfreq: float
        Sampling frequency of the signal
    l_freq: float
        Low cut-off frequency
    h_freq: float
        High cut-off frequency
    threshold: float
        Threshold for detection (default=3)
    band_method: str
        Spacing of hilbert frequency bands - options: 'linear' or 'log'
        (default='linear'). Linear provides better frequency resolution but
        is slower.
    n_bands: int
        Number of bands if band_spacing = log (default=300)
    cycle_threshold: float
        Minimum number of cycles to detect (default=1)
    gap_threshold: float
        Number of cycles for gaps (default=1)
    n_jobs: int
        Number of cores to use (default=1)
    offset: int
        Offset which is added to the final detection. This is used when the
        function is run in separate windows. Default = 0

    References
    ----------
    [1] M. T. Kucewicz, J. Cimbalnik, J. Y. Matsumoto, B. H. Brinkmann,
    M. Bower, V. Vasoli, V. Sulc, F. Meyer, W. R. Marsh, S. M. Stead, and
    G. A. Worrell, “High frequency oscillations are associated with
    cognitive processing in human recognition memory.,” Brain, pp. 1–14,
    Jun. 2014.
    """

    def __init__(self, sfreq: float, l_freq: float, h_freq: float,
                 threshold: float = 3, band_method: str = 'linear',
                 n_bands: int = 300, cycle_threshold: float = 1,
                 gap_threshold: float = 1, n_jobs: int = 1,
                 offset: int = 0, scoring_func: str = 'f1',
                 verbose: bool = False):
        if band_method not in ACCEPTED_BAND_METHODS:
            raise ValueError(f'Band method {band_method} is not '
                             f'an acceptable parameter. Please use '
                             f'one of {ACCEPTED_BAND_METHODS}')

        super(HilbertDetector, self).__init__(
            threshold, win_size=None, overlap=None,
            scoring_func=scoring_func, n_jobs=n_jobs, verbose=verbose)

        self.sfreq = sfreq
        self.l_freq = l_freq
        self.h_freq = h_freq
        self.band_method = band_method
        self.n_bands = n_bands
        self.cycle_threshold = cycle_threshold
        self.gap_threshold = gap_threshold
        self.n_jobs = n_jobs
        self.offset = offset

    def fit(self, X):
        """Override ``Detector.fit`` function."""
        n_chs, n_times = X.shape

        # store all hfo occurrences as an array of channels X windows
        n_windows = int(np.ceil((n_times - self.win_size) /
                                self.step_size)) + 1
        hfo_event_arr = np.empty((n_chs, n_windows))

        # store contiguous hfos as one occurrence, so we store
        # them as a dictionary of lists
        chs_hfos = collections.defaultdict(list)

        # bandpass the signal using FIR filter
        # X = mne.filter.filter_data(X, sfreq=self.sfreq,
        #                            l_freq=self.l_freq,
        #                            h_freq=self.h_freq, picks=picks,
        #                            method='iir', verbose=self.verbose)

        # Construct filter cut offs that are either logarithmically
        # or linearly spaced
        if self.band_method == 'log':
            low_fc = float(self.l_freq)
            high_fc = float(self.h_freq)
            freq_cutoffs = np.logspace(0, np.log10(high_fc), self.n_bands)
            freq_cutoffs = freq_cutoffs[(freq_cutoffs > low_fc) &
                                        (freq_cutoffs < high_fc)]
            freq_span = len(freq_cutoffs) - 1
        elif self.band_method == 'linear':
            freq_cutoffs = np.arange(self.l_freq, self.h_freq)
            freq_span = (self.h_freq - self.l_freq) - 1

        # run detector on all channels
        for idx in range(self.n_chs):
            sig = X[idx, :]

            output = []

            # run detection algorithm possibly in parallel on this channel
            tdetects_concat = []
            if self.n_jobs > 1 or self.n_jobs == -1:
                # for a single job, we don't need joblib
                if self.n_jobs != 1:
                    try:
                        from joblib import Parallel, delayed
                    except ImportError:
                        try:
                            from sklearn.externals.joblib import (
                                Parallel, delayed)
                        except ImportError:
                            warn('joblib not installed. '
                                 'Cannot run in parallel.')
                            self.n_jobs = 1

                # Run the filters in their threads and return the result
                iter_mat = [(sig, self.sfreq, i,
                             freq_cutoffs[i],
                             freq_cutoffs[i + 1],
                             self.cycle_threshold, self.gap_threshold,
                             self.threshold) for i in range(freq_span)]
                tdetects_concat = Parallel(n_jobs=self.n_jobs)(
                    delayed(_band_z_score_detect)(
                        *iter_args
                    )
                    for iter_args in tqdm(iter_mat, unit='HFO-first-phase')
                )
            else:
                # OPTIMIZE - check if there is a better way to do this
                # (S transform?+ spectra zeroing?)
                for i in tqdm(range(freq_span), unit='HFO-first-phase'):
                    bot = freq_cutoffs[i]
                    top = freq_cutoffs[i + 1]

                    args = [sig, self.sfreq, i, bot, top,
                            self.cycle_threshold, self.gap_threshold,
                            self.threshold]

                    tdetects_concat.append(_band_z_score_detect(args))

            # post process detected HFO events by detecting outline of events
            detects = np.array([det for band in tdetects_concat
                                for det in band])
            outlines = []
            if len(detects):
                while sum(detects[:, 0] != 0):
                    det_idx = np.where(detects[:, 0] != 0)[0][0]
                    HFO_outline = []
                    outlines.append(np.array(_run_detect_branch(detects,
                                                                det_idx,
                                                                HFO_outline)))

            # Get the detections
            for outline in outlines:
                start = min(outline[:, 1])
                stop = max(outline[:, 2])
                freq_min = freq_cutoffs[int(outline[0, 0])]
                freq_max = freq_cutoffs[int(outline[-1, 0])]
                frequency_at_max = freq_cutoffs[
                    int(outline[np.argmax(outline[:, 3]), 0])]
                max_amplitude = max(outline[:, 3])

                output.append((start, stop,
                               freq_min, freq_max, frequency_at_max,
                               max_amplitude))
            chs_hfos[idx] = output

        self.hfo_event_arr_ = hfo_event_arr
        self.chs_hfos_ = chs_hfos
        return chs_hfos


[docs]class LineLengthDetector(Detector): """Line-length detection algorithm. Original paper defines HFOS as "(HFOs), which we collectively term as all activity >40 Hz (including gamma, high-gamma, ripple, and fast ripple oscillations), may have a fundamental role in the generation and spread of focal seizures" In the paper, data were sampled at 200 Hz and bandpass-filtered (0.1 – 100 Hz) during acquisition. Data were further digitally bandpass-filtered (4th-order Butterworth, forward-backward filtering, ``0.1 – 85 Hz``) to minimize potential artifacts due to aliasing. (IIR for forward-backward pass). Compared to RMS detector, they utilize line-length metric. Parameters ---------- filter_band : tuple(float, float) | None Low cut-off frequency at index 0 and high cut-off frequency at index 1. threshold: float Number of standard deviations to use as a threshold win_size: int Sliding window size in samples overlap: float Fraction of the window overlap (0 to 1) offset: int Offset which is added to the final detection. This is used when the function is run in separate windows. Default = 0 hfo_name: str What to name the events detected (i.e. fast ripple if freq_band is (250, 500)). Notes ----- For processing, a sliding window is used. For post-processing, any events that overlap are considered to be the same. References ---------- .. [1] A. B. Gardner, G. A. Worrell, E. Marsh, D. Dlugos, and B. Litt, “Human and automated detection of high-frequency oscillations in clinical intracranial EEG recordings,” Clin. Neurophysiol., vol. 118, no. 5, pp. 1134–1143, May 2007. .. [2] Esteller, R. et al. (2001). Line length: an efficient feature for seizure onset detection. In Engineering in Medicine and Biology Society, 2001. Proceedings of the 23rd Annual International Conference of the IEEE (Vol. 2, pp. 1707-1710). IEEE. """ def __init__(self, threshold: Union[int, float] = 3, win_size: int = 100, overlap: float = 0.25, sfreq: int = None, filter_band: Tuple[int, int] = (30, 100), scoring_func: str = 'f1', n_jobs: int = -1, hfo_name: str = "hfo", verbose: bool = False): super(LineLengthDetector, self).__init__( threshold, win_size=win_size, overlap=overlap, scoring_func=scoring_func, n_jobs=n_jobs, verbose=verbose) self.filter_band = filter_band self.sfreq = sfreq self.hfo_name = hfo_name @property def l_freq(self): """Lower frequency band for HFO definition.""" if self.filter_band is None: return None return self.filter_band[0] @property def h_freq(self): """Higher frequency band for HFO definition.""" if self.filter_band is None: return None return self.filter_band[1] def _compute_hfo(self, X): """Override ``Detector._compute_hfo`` function.""" # store all hfo occurrences as an array of channels X windows n_windows = self._compute_n_wins(self.win_size, self.step_size, self.n_times) hfo_event_arr = np.empty((self.n_chs, n_windows)) # bandpass the signal using FIR filter if self.filter_band is not None: X = mne.filter.filter_data(X, sfreq=self.sfreq, l_freq=self.l_freq, h_freq=self.h_freq, method='iir', verbose=self.verbose) # run HFO detection on all the channels if self.n_jobs == 1: for idx in tqdm(range(self.n_chs)): sig = X[idx, :] # compute sliding window RMS hfo_event_arr[idx, :] = \ self._compute_sliding_window_detection( sig, method='line_length') else: if self.n_jobs == -1: n_jobs = cpu_count() else: n_jobs = self.n_jobs # run joblib parallelization over channels results = Parallel(n_jobs=n_jobs)( delayed(self._compute_sliding_window_detection)( X[idx, :], 'line_length' ) for idx in tqdm(range(self.n_chs)) ) for idx in range(len(results)): hfo_event_arr[idx, :] = results[idx] return hfo_event_arr
[docs]class RMSDetector(Detector): """Root mean square (RMS) detection algorithm (Staba Detector). The original algorithm described in the reference, takes a sliding window of 3 ms, computes the RMS values between 100 and 500 Hz. Then events separated by less than 10 ms were combined into one event. Then events not having a minimum of 6 peaks (i.e. band-pass signal rectified above 0 V) with greater then 3 std above mean baseline were removed. A finite impulse response (FIR) filter with a Hamming window was used. Parameters ---------- filter_band : tuple(float, float) | None Low cut-off frequency at index 0 and high cut-off frequency at index 1. threshold: float Number of standard deviations to use as a threshold. Default = 3. win_size: int Sliding window size in samples. Default = 100. The original paper uses a window size equivalent to 3 ms. overlap: float Fraction of the window overlap (0 to 1). Default = 0.25. The original paper uses an overlap of 0. offset: int Offset which is added to the final detection. This is used when the function is run in separate windows. Default = 0 References ---------- [1] R. J. Staba, C. L. Wilson, A. Bragin, I. Fried, and J. Engel, “Quantitative Analysis of High-Frequency Oscillations (80 − 500 Hz) Recorded in Human Epileptic Hippocampus and Entorhinal Cortex,” J. Neurophysiol., vol. 88, pp. 1743–1752, 2002. """ def __init__(self, threshold: Union[int, float] = 3, win_size: int = 100, overlap: float = 0.25, sfreq=None, filter_band: Tuple[int, int] = (100, 500), scoring_func='f1', n_jobs: int = -1, hfo_name: str = "hfo", verbose: bool = False): super(RMSDetector, self).__init__( threshold, win_size, overlap, scoring_func, n_jobs=n_jobs, verbose=verbose) # hyperparameters self.filter_band = filter_band self.sfreq = sfreq self.hfo_name = hfo_name @property def l_freq(self): """Lower frequency band for HFO definition.""" if self.filter_band is None: return None return self.filter_band[0] @property def h_freq(self): """Higher frequency band for HFO definition.""" if self.filter_band is None: return None return self.filter_band[1] def _compute_hfo(self, X): """Override ``Detector._compute_hfo`` function. Returns ------- hfo_event_arr : np.ndarray (channels x windows) The HFO metric value within each window per channel. """ # store all hfo occurrences as an array of channels X windows n_windows = self._compute_n_wins(self.win_size, self.step_size, self.n_times) hfo_event_arr = np.empty((self.n_chs, n_windows)) if self.l_freq is not None or self.h_freq is not None: # bandpass the signal using FIR filter X = mne.filter.filter_data(X, sfreq=self.sfreq, l_freq=self.l_freq, h_freq=self.h_freq, method='fir', verbose=self.verbose) # run HFO detection on all the channels if self.n_jobs == 1: for idx in tqdm(range(self.n_chs)): sig = X[idx, :] # compute sliding window RMS hfo_event_arr[idx, :] = \ self._compute_sliding_window_detection(sig, method='rms') else: if self.n_jobs == -1: n_jobs = cpu_count() else: n_jobs = self.n_jobs results = Parallel(n_jobs=n_jobs)( delayed(self._compute_sliding_window_detection)( X[idx, :], 'rms' ) for idx in tqdm(range(self.n_chs)) ) for idx in range(len(results)): hfo_event_arr[idx, :] = results[idx] return hfo_event_arr