Basic single-trial fNIRS finger tapping classification

This notebook sketches the analysis of a finger tapping dataset with multiple subjects. A simple Linear Discriminant Analysis (LDA) classifier is trained to distinguish left from right finger tapping.

PLEASE NOTE: For simplicity’s sake we skip many preprocessing steps (e.g. channel pruning, artifact removal, physiology removal); these are the subject of other example notebooks, and a rigorous analysis should include them. The purpose of this notebook is only to demonstrate how easily cedalion interfaces with the scikit-learn package.
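For orientation only, here is a minimal sketch of what one such skipped step (SNR-based channel pruning) could look like with cedalion’s quality module. The helper names quality.snr and quality.prune_ch follow the dedicated data-quality example notebook; treat the exact signatures as an assumption and consult that notebook for the authoritative API.

import cedalion.sigproc.quality as quality

# assumption: `rec` is a loaded Recording with raw amplitudes in rec["amp"],
# as produced in cell [2] below
snr_thresh = 16  # flag channels whose SNR (mean/std over time) falls below this
snr, snr_mask = quality.snr(rec["amp"], snr_thresh)

# combine the masks (here only one) and drop the flagged channels
rec["amp_pruned"], drop_list = quality.prune_ch(rec["amp"], [snr_mask], "all")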

[1]:
import cedalion
import cedalion.nirs
from cedalion.datasets import get_multisubject_fingertapping_snirf_paths
import cedalion.sigproc.quality as quality
import cedalion.plots as plots
import numpy as np
import xarray as xr
import matplotlib.pyplot as p

from sklearn.preprocessing import LabelEncoder
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split, cross_val_score, cross_val_predict, StratifiedKFold
from sklearn.metrics import accuracy_score, roc_curve, roc_auc_score, auc

from cedalion import units

xr.set_options(display_max_rows=3, display_values_threshold=50)
np.set_printoptions(precision=4)

Loading raw CW-NIRS data from a SNIRF file

This notebook uses a finger-tapping dataset in BIDS layout provided by Rob Luke. It can be downloaded via cedalion.datasets.

Cedalion’s read_snirf method returns a list of Recording objects. These are containers for timeseries and adjunct data objects.

[2]:
fnames = get_multisubject_fingertapping_snirf_paths()
subjects = [f"sub-{i:02d}" for i in [1, 2, 3]]

# store data of different subjects in a dictionary
data = {}
for subject, fname in zip(subjects, fnames):
    records = cedalion.io.read_snirf(fname)
    rec = records[0]
    display(rec)

    # Cedalion registers an accessor (attribute .cd) on pandas DataFrames.
    # Use it to rename the trial types in place.
    rec.stim.cd.rename_events(
        {"1.0": "control", "2.0": "Tapping/Left", "3.0": "Tapping/Right"}
    )

    dpf = xr.DataArray(
        [6, 6],
        dims="wavelength",
        coords={"wavelength": rec["amp"].wavelength},
    )

    rec["od"] = -np.log(rec["amp"] / rec["amp"].mean("time")),
    rec["conc"] = cedalion.nirs.beer_lambert(rec["amp"], rec.geo3d, dpf)

    data[subject] = rec
Downloading file 'multisubject-fingertapping.zip' from 'https://doc.ibs.tu-berlin.de/cedalion/datasets/multisubject-fingertapping.zip' to '/home/runner/.cache/cedalion'.
Unzipping contents of '/home/runner/.cache/cedalion/multisubject-fingertapping.zip' to '/home/runner/.cache/cedalion/multisubject-fingertapping.zip.unzip'
<Recording |  timeseries: ['amp'],  masks: [],  stim: ['1.0', '15.0', '2.0', '3.0'],  aux_ts: [],  aux_obj: []>
<Recording |  timeseries: ['amp'],  masks: [],  stim: ['1.0', '15.0', '2.0', '3.0'],  aux_ts: [],  aux_obj: []>
<Recording |  timeseries: ['amp'],  masks: [],  stim: ['1.0', '15.0', '2.0', '3.0'],  aux_ts: [],  aux_obj: []>

Illustrate the dataset of one subject

[3]:
display(data["sub-01"])
<Recording |  timeseries: ['amp', 'od', 'conc'],  masks: [],  stim: ['control', '15.0', 'Tapping/Left', 'Tapping/Right'],  aux_ts: [],  aux_obj: []>

Frequency filtering and splitting into epochs

[4]:
for subject, rec in data.items():
    # cedalion registers the accessor .cd on DataArrays
    # to provide common functionality like frequency filters...
    rec["conc_freqfilt"] = rec["conc"].cd.freq_filter(
        fmin=0.01, fmax=0.5, butter_order=4
    )

    # ... or epoch splitting
    rec["cfepochs"] = rec["conc_freqfilt"].cd.to_epochs(
        rec.stim,  # stimulus dataframe
        ["Tapping/Left", "Tapping/Right"],  # select events
        before=5 * units.s,  # seconds before stimulus
        after=20 * units.s,  # seconds after stimulus
    )
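to_epochs returns a new DataArray with an epoch dimension and a reltime axis relative to stimulus onset. A quick look (a small sketch; the sizes match the display further below):

epochs = data["sub-01"]["cfepochs"]
print(epochs.sizes)  # e.g. {'epoch': 60, 'chromo': 2, 'channel': 28, 'reltime': 198}
print(epochs.trial_type.values[:3])  # the event label attached to each epoch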

Plot frequency-filtered data

Illustrate the effect of the bandpass filter for a single subject and channel.

[5]:
rec = data["sub-01"]
channel = "S5D7"

f, ax = p.subplots(2, 1, figsize=(12, 4), sharex=True)
ax[0].plot(rec["conc"].time, rec["conc"].sel(channel=channel, chromo="HbO"), "r-", label="HbO")
ax[0].plot(rec["conc"].time, rec["conc"].sel(channel=channel, chromo="HbR"), "b-", label="HbR")
ax[1].plot(
    rec["conc_freqfilt"].time,
    rec["conc_freqfilt"].sel(channel=channel, chromo="HbO"),
    "r-",
    label="HbO",
)
ax[1].plot(
    rec["conc_freqfilt"].time,
    rec["conc_freqfilt"].sel(channel=channel, chromo="HbR"),
    "b-",
    label="HbR",
)
ax[0].set_xlim(1000, 1100)
ax[1].set_xlabel("time / s")
ax[0].set_ylabel(r"$\Delta c$ / $\mu M$")
ax[1].set_ylabel(r"$\Delta c$ / $\mu M$")
ax[0].legend(loc="upper left")
ax[1].legend(loc="upper left");
[Figure: unfiltered (top) vs. frequency-filtered (bottom) HbO/HbR concentration time courses for channel S5D7]

Baseline removal

Subtract from each epoch its mean concentration over the pre-stimulus window (reltime < 0), so that all trials start from a common baseline.

[6]:
for subject, rec in data.items():
    # calculate baseline
    baseline_conc = rec["cfepochs"].sel(reltime=(rec["cfepochs"].reltime < 0)).mean("reltime")
    # subtract baseline
    rec["cfbl_epochs"] = rec["cfepochs"] - baseline_conc
[7]:
display(data["sub-01"]["cfbl_epochs"])
<xarray.DataArray (epoch: 60, chromo: 2, channel: 28, reltime: 198)> Size: 5MB
<Quantity([[[[ 5.2902e-02  5.5043e-02  5.7083e-02 ...  5.6429e-02  2.9234e-02
     1.5282e-03]
   [ 6.3262e-02  6.4291e-02  6.4695e-02 ... -9.3166e-02 -9.8304e-02
    -1.0246e-01]
   [ 2.6691e-02  3.2027e-02  3.7556e-02 ... -9.2581e-02 -1.0783e-01
    -1.2340e-01]
   ...
   [ 3.7579e-02  4.4702e-02  5.1556e-02 ... -4.4831e-01 -4.6680e-01
    -4.8436e-01]
   [ 4.0865e-02  4.3057e-02  4.5466e-02 ... -6.5807e-01 -6.6514e-01
    -6.7172e-01]
   [ 4.8372e-02  4.7851e-02  4.8783e-02 ... -3.9799e-01 -4.1427e-01
    -4.2880e-01]]

  [[ 3.2768e-02  2.4241e-02  1.4962e-02 ... -1.9548e-01 -1.8971e-01
    -1.8149e-01]
   [ 2.2247e-03  5.0735e-03  8.0807e-03 ... -1.3482e-01 -1.4002e-01
    -1.4366e-01]
   [ 3.0611e-02  2.6840e-02  2.2641e-02 ... -2.0681e-01 -2.0450e-01
    -2.0081e-01]
...
   [ 1.3317e-01  1.4323e-01  1.5242e-01 ... -7.2660e-02 -6.5427e-02
    -5.8264e-02]
   [-5.3550e-01 -4.7794e-01 -4.2124e-01 ... -1.3102e+00 -1.3077e+00
    -1.3052e+00]
   [ 2.8455e-01  2.7816e-01  2.7210e-01 ... -2.2472e-01 -2.2569e-01
    -2.2698e-01]]

  [[ 4.8799e-02  4.1770e-02  3.4004e-02 ... -6.2999e-02 -6.3845e-02
    -6.5198e-02]
   [ 3.7193e-02  3.6186e-02  3.3858e-02 ... -9.2909e-02 -8.9262e-02
    -8.5327e-02]
   [ 1.7907e-02  8.2854e-03 -8.6835e-04 ... -1.0254e-01 -9.8087e-02
    -9.3929e-02]
   ...
   [ 1.6617e-02  1.2406e-02  7.4503e-03 ... -1.0295e-01 -1.0261e-01
    -1.0289e-01]
   [-2.8990e-02 -2.0466e-02 -1.1887e-02 ... -3.0189e-01 -2.9509e-01
    -2.8930e-01]
   [ 7.1953e-02  6.7632e-02  6.2574e-02 ... -1.0922e-01 -1.0625e-01
    -1.0445e-01]]]], 'micromolar')>
Coordinates: (3/6)
  * reltime     (reltime) float64 2kB -5.12 -4.992 -4.864 ... 19.84 19.97 20.1
    trial_type  (epoch) <U13 3kB 'Tapping/Left' ... 'Tapping/Right'
    ...          ...
    detector    (channel) object 224B 'D1' 'D2' 'D3' 'D9' ... 'D7' 'D8' 'D16'
Dimensions without coordinates: epoch
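As a quick sanity check (a minimal sketch in plain xarray), the mean of the corrected epochs over the pre-stimulus window should vanish by construction:

bl = data["sub-01"]["cfbl_epochs"]
pre = bl.sel(reltime=bl.reltime < 0).mean("reltime").pint.dequantify()
print(float(np.abs(pre).max()))  # ~0 up to floating point error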

Block averages per condition for one participant

[8]:
# we use subject 1 as an example here
subject = "sub-01"

# group trials by trial_type and average across the epoch dimension within each group
blockaverage = data[subject]["cfbl_epochs"].groupby("trial_type").mean("epoch")

Plotting averaged epochs

[9]:
f, ax = p.subplots(5, 6, figsize=(12, 10))
ax = ax.flatten()
for i_ch, ch in enumerate(blockaverage.channel):
    for ls, trial_type in zip(["-", "--"], blockaverage.trial_type):
        ax[i_ch].plot(
            blockaverage.reltime,
            blockaverage.sel(chromo="HbO", trial_type=trial_type, channel=ch),
            "r",
            lw=2,
            ls=ls,
        )
        ax[i_ch].plot(
            blockaverage.reltime,
            blockaverage.sel(chromo="HbR", trial_type=trial_type, channel=ch),
            "b",
            lw=2,
            ls=ls,
        )
    ax[i_ch].grid(True)
    ax[i_ch].set_title(ch.values)
    ax[i_ch].set_ylim(-0.3, 0.6)

# add legend
ax[0].legend(["HbO Tapping/Left", "HbR Tapping/Left",  "HbO Tapping/Right", "HbR Tapping/Right"])
p.tight_layout()
[Figure: block-averaged epochs per channel; red HbO, blue HbR; solid Tapping/Left, dashed Tapping/Right]

Training an LDA classifier with scikit-learn

Feature Extraction

We use very simple features: the mean of the response between 0 and 10 s after stimulus onset, and its minimum and maximum between 0 and 15 s, per channel and chromophore. With 2 chromophores × 28 channels × 3 statistics this yields 168 features per trial (cf. the shape of X below).

[10]:
for subject, rec in data.items():

    # avg signal between 0 and 10 seconds after stimulus onset
    fmean = rec["cfbl_epochs"].sel(reltime=slice(0, 10)).mean("reltime")
    # min signal between 0 and 15 seconds after stimulus onset
    fmin = rec["cfbl_epochs"].sel(reltime=slice(0, 15)).min("reltime")
    # max signal between 0 and 15 seconds after stimulus onset
    fmax = rec["cfbl_epochs"].sel(reltime=slice(0, 15)).max("reltime")

    # concatenate the three statistics along the reduced reltime dimension
    # (reltime=0: mean, 1: min, 2: max), then stack into one feature dimension
    X = xr.concat([fmean, fmin, fmax], dim="reltime")
    X = X.stack(features=["chromo", "channel", "reltime"])

    # strip units. sklearn would strip them anyway and issue a warning about it.
    X = X.pint.dequantify()

    # need to manually tell xarray to create an index for trial_type
    X = X.set_xindex("trial_type")

    # save in recording container
    rec.aux_obj["X"] = X
[11]:
display(data["sub-01"].aux_obj["X"])
<xarray.DataArray (epoch: 60, features: 168)> Size: 81kB
array([[ 0.4964,  0.0127,  0.8439, ..., -0.0041, -0.0903,  0.0671],
       [ 0.0445, -0.3305,  0.2788, ...,  0.0166, -0.0359,  0.0614],
       [ 0.109 , -0.0486,  0.3628, ...,  0.018 , -0.0568,  0.0678],
       ...,
       [ 0.1284, -0.1551,  0.3873, ..., -0.05  , -0.1053, -0.0204],
       [ 0.4959, -0.4342,  0.7733, ..., -0.0512, -0.4268,  0.6003],
       [ 0.0728, -0.3504,  0.4382, ..., -0.0664, -0.1677, -0.0176]])
Coordinates: (3/7)
  * trial_type  (epoch) <U13 3kB 'Tapping/Left' ... 'Tapping/Right'
    source      (features) object 1kB 'S1' 'S1' 'S1' 'S1' ... 'S8' 'S8' 'S8'
    ...          ...
  * reltime     (features) int64 1kB 0 1 2 0 1 2 0 1 2 0 ... 2 0 1 2 0 1 2 0 1 2
Dimensions without coordinates: epoch
Attributes:
    units:    micromolar
[12]:
# Encode labels for use in scikit-learn
for subject, rec in data.items():
    rec.aux_obj["y"] = xr.apply_ufunc(LabelEncoder().fit_transform, rec.aux_obj["X"].trial_type)

display(data["sub-01"].aux_obj["y"])
<xarray.DataArray 'trial_type' (epoch: 60)> Size: 480B
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
Coordinates:
  * trial_type  (epoch) <U13 3kB 'Tapping/Left' ... 'Tapping/Right'
Dimensions without coordinates: epoch
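LabelEncoder assigns integer codes in alphabetical order of the class names, so 'Tapping/Left' maps to 0 and 'Tapping/Right' to 1. This can be verified directly:

le = LabelEncoder().fit(data["sub-01"].aux_obj["X"].trial_type.values)
print(le.classes_)  # ['Tapping/Left' 'Tapping/Right'] -> codes 0 and 1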

Train LDA classifier for each subject using 5-fold cross-validation

[13]:
# initialize dictionaries for key metrics for each subject to plot
scores = {}
fpr = {}
tpr = {}
roc_auc = {}

for subject, rec in data.items():

    X = rec.aux_obj["X"]
    y = rec.aux_obj["y"]
    classifier = LinearDiscriminantAnalysis(n_components=1)

    # Define the cross-validation strategy (e.g., stratified k-fold with 5 folds)
    cv = StratifiedKFold(n_splits=5)

    # Perform cross-validation and get accuracy scores
    scores[subject] = cross_val_score(classifier, X, y, cv=cv, scoring='accuracy')
    # Get predicted probabilities using cross-validation
    pred_prob = cross_val_predict(classifier, X, y, cv=cv, method='predict_proba')[:, 1]

    # Calculate ROC curve and AUC
    fpr[subject], tpr[subject], thresholds = roc_curve(y, pred_prob)
    roc_auc[subject] = auc(fpr[subject], tpr[subject])


    # Print the mean accuracy across folds
    print(f"Cross-validated accuracy for subject {subject}: {scores[subject].mean():.2f}")

# barplot of accuracies
f, ax = p.subplots()
ax.bar(data.keys(), [s.mean() for s in scores.values()])
ax.set_ylabel("Accuracy")
ax.set_xlabel("Subject")

Cross-validated accuracy for subject sub-01: 0.82
Cross-validated accuracy for subject sub-02: 0.68
Cross-validated accuracy for subject sub-03: 0.52
[13]:
Text(0.5, 0, 'Subject')
[Figure: bar plot of cross-validated accuracy per subject]
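StratifiedKFold preserves the class balance in every fold: with 60 trials (30 per condition) and 5 folds, each test fold holds 6 left and 6 right trials. A small sketch to confirm this:

X = data["sub-01"].aux_obj["X"]
y = data["sub-01"].aux_obj["y"]
for train_idx, test_idx in StratifiedKFold(n_splits=5).split(X, y):
    print(np.bincount(y.values[test_idx]))  # -> [6 6] in each fold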

Plot ROC curves for subjects

[14]:
# Initialize the ROC plot
p.figure(figsize=(10, 8))
# Plot the ROC curve for each subject (classifiers were already cross-validated above)
for subject in data:
    # plot this subject's ROC curve with its AUC in the legend
    p.plot(fpr[subject], tpr[subject], lw=2, label=f'Subject {subject} (AUC = {roc_auc[subject]:.2f})')
# Plot the diagonal line for random guessing
p.plot([0, 1], [0, 1], color='gray', linestyle='--')
# Add labels and title
p.xlabel('False Positive Rate')
p.ylabel('True Positive Rate')
p.title('ROC Curves for All Subjects')
p.legend(loc='lower right')
p.grid(True)
p.show()
[Figure: ROC curves for all subjects]