Basic single trial fNIRS finger tapping classification
This notebook sketches the analysis of a finger tapping dataset with multiple subjects. A simple Linear Discriminant Analysis (LDA) classifier is trained to distinguish left and right finger tapping.
PLEASE NOTE: For simplicity’s sake we skip many preprocessing steps (e.g. pruning, artifact removal, physiology removal). These are the subject of other example notebooks. For a rigorous analysis you will want to include such steps. The purpose of this notebook is only to demonstrate how easily cedalion data can be passed to the scikit-learn package.
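For intuition about what an LDA classifier learns, here is a minimal two-class Fisher discriminant in plain NumPy on synthetic 2-D data. It projects each sample onto w = S⁻¹(μ₁ - μ₀), where S is the pooled within-class covariance, and thresholds halfway between the projected class means (equal priors assumed). The helper names and the data are illustrative assumptions; the notebook itself uses scikit-learn's LinearDiscriminantAnalysis.

```python
import numpy as np


def fit_fisher_lda(X, y):
    """Two-class Fisher LDA with a pooled within-class covariance.

    Returns (w, b); a sample x is assigned to class 1 when x @ w + b > 0.
    """
    X0, X1 = X[y == 0], X[y == 1]
    mu0, mu1 = X0.mean(axis=0), X1.mean(axis=0)
    # pooled within-class covariance (unbiased estimate for two classes)
    S = ((X0 - mu0).T @ (X0 - mu0) + (X1 - mu1).T @ (X1 - mu1)) / (len(X) - 2)
    # discriminant direction w = S^-1 (mu1 - mu0)
    w = np.linalg.solve(S, mu1 - mu0)
    # threshold halfway between the projected class means (equal priors)
    b = -0.5 * (mu0 + mu1) @ w
    return w, b


def predict_lda(X, w, b):
    return (X @ w + b > 0).astype(int)


# synthetic, hypothetical two-class data with a shared unit covariance
rng = np.random.default_rng(0)
X = np.vstack([
    rng.normal(loc=[0.0, 0.0], scale=1.0, size=(100, 2)),
    rng.normal(loc=[3.0, 1.0], scale=1.0, size=(100, 2)),
])
y = np.repeat([0, 1], 100)

w, b = fit_fisher_lda(X, y)
acc = (predict_lda(X, w, b) == y).mean()
print(f"training accuracy: {acc:.2f}")
```

With the 168 features used later and only 48 training epochs per fold, the pooled covariance has rank at most 46, so this direct solve is not usable there. scikit-learn's default solver="svd" avoids computing the covariance matrix, and the "lsqr" and "eigen" solvers accept shrinkage="auto" as another remedy.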
[1]:
import cedalion
import cedalion.io
import cedalion.nirs
from cedalion.datasets import get_multisubject_fingertapping_snirf_paths
import cedalion.sigproc.quality as quality
import cedalion.plots as plots
import numpy as np
import xarray as xr
import matplotlib.pyplot as p
from IPython.display import display
from sklearn.preprocessing import LabelEncoder
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split, cross_val_score, cross_val_predict, StratifiedKFold
from sklearn.metrics import accuracy_score, roc_curve, roc_auc_score, auc
from cedalion import units
xr.set_options(display_max_rows=3, display_values_threshold=50)
np.set_printoptions(precision=4)
Loading raw CW-NIRS data from a SNIRF file
This notebook uses a finger-tapping dataset in BIDS layout provided by Rob Luke. It can be downloaded via cedalion.datasets.
Cedalion’s read_snirf method returns a list of Recording objects. These are containers for time series and adjunct data objects.
[2]:
fnames = get_multisubject_fingertapping_snirf_paths()
subjects = [f"sub-{i:02d}" for i in [1, 2, 3]]

# store data of different subjects in a dictionary
data = {}
for subject, fname in zip(subjects, fnames):
    records = cedalion.io.read_snirf(fname)
    rec = records[0]
    display(rec)

    # Cedalion registers an accessor (attribute .cd ) on pandas DataFrames.
    # Use this to rename trial_types in place.
    rec.stim.cd.rename_events(
        {"1.0": "control", "2.0": "Tapping/Left", "3.0": "Tapping/Right"}
    )

    # differential pathlength factor, one value per wavelength
    dpf = xr.DataArray(
        [6, 6],
        dims="wavelength",
        coords={"wavelength": rec["amp"].wavelength},
    )

    # optical density relative to the temporal mean
    rec["od"] = -np.log(rec["amp"] / rec["amp"].mean("time"))
    rec["conc"] = cedalion.nirs.beer_lambert(rec["amp"], rec.geo3d, dpf)

    data[subject] = rec
Downloading file 'multisubject-fingertapping.zip' from 'https://doc.ibs.tu-berlin.de/cedalion/datasets/multisubject-fingertapping.zip' to '/home/runner/.cache/cedalion'.
Unzipping contents of '/home/runner/.cache/cedalion/multisubject-fingertapping.zip' to '/home/runner/.cache/cedalion/multisubject-fingertapping.zip.unzip'
<Recording | timeseries: ['amp'], masks: [], stim: ['1.0', '15.0', '2.0', '3.0'], aux_ts: [], aux_obj: []>
<Recording | timeseries: ['amp'], masks: [], stim: ['1.0', '15.0', '2.0', '3.0'], aux_ts: [], aux_obj: []>
<Recording | timeseries: ['amp'], masks: [], stim: ['1.0', '15.0', '2.0', '3.0'], aux_ts: [], aux_obj: []>
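The loop above converts intensities to optical density, OD(t) = -ln(I(t)/Ī), and then to concentration changes with cedalion.nirs.beer_lambert. As a hedged illustration of the modified Beer-Lambert law behind that step (not of cedalion's implementation), the sketch below solves ΔOD(λ) = d · DPF(λ) · Σᵢ εᵢ(λ) Δcᵢ for HbO and HbR at two wavelengths. The extinction coefficients and the source-detector distance are hypothetical placeholders in arbitrary but mutually consistent units; real analyses use tabulated spectra.

```python
import numpy as np

rng = np.random.default_rng(1)

# toy raw intensities for one channel: 2 wavelengths x 500 samples (arbitrary units)
intensity = np.array([[1.0], [0.8]]) * (1.0 + 0.01 * rng.standard_normal((2, 500)))

# optical density relative to the temporal mean, same formula as in the notebook
od = -np.log(intensity / intensity.mean(axis=1, keepdims=True))

# HYPOTHETICAL extinction coefficients: rows = wavelengths, columns = (HbO, HbR)
eps = np.array([[0.6, 1.5],
                [1.1, 0.8]])
d = 3.0                     # hypothetical source-detector distance
dpf = np.array([6.0, 6.0])  # DPF of 6 per wavelength, as in the notebook

# modified Beer-Lambert law: od = A @ dc with A[w, i] = eps[w, i] * d * dpf[w]
A = eps * (d * dpf)[:, None]

# invert the 2x2 system for every time sample at once
dc = np.linalg.solve(A, od)  # shape (2 chromophores, 500 samples)

# sanity check: a known (hypothetical) HbO/HbR change is recovered exactly
true_dc = np.array([[0.5], [-0.2]])
recovered = np.linalg.solve(A, A @ true_dc)
```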
Illustrate the dataset of one subject
[3]:
display(data["sub-01"])
<Recording | timeseries: ['amp', 'od', 'conc'], masks: [], stim: ['control', '15.0', 'Tapping/Left', 'Tapping/Right'], aux_ts: [], aux_obj: []>
Frequency filtering and splitting into epochs
[4]:
for subject, rec in data.items():
    # cedalion registers the accessor .cd on DataArrays
    # to provide common functionality like frequency filters...
    rec["conc_freqfilt"] = rec["conc"].cd.freq_filter(
        fmin=0.01, fmax=0.5, butter_order=4
    )

    # ... or epoch splitting
    rec["cfepochs"] = rec["conc_freqfilt"].cd.to_epochs(
        rec.stim,  # stimulus dataframe
        ["Tapping/Left", "Tapping/Right"],  # select events
        before=5 * units.s,  # seconds before stimulus
        after=20 * units.s,  # seconds after stimulus
    )
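Conceptually, to_epochs cuts a fixed window around every selected stimulus onset. A minimal NumPy version, assuming a regularly sampled signal whose first sample is at t = 0 and onsets given in seconds, might look as follows. to_epochs_sketch is a hypothetical helper: its window edges may differ by a sample from cedalion's (the notebook's epochs have 198 samples from -5.12 s to 20.1 s), and it simply skips onsets whose window would leave the recording, which is an assumption rather than cedalion's documented behavior.

```python
import numpy as np


def to_epochs_sketch(signal, fs, onsets, before, after):
    """Cut windows [onset - before, onset + after] (in seconds) out of signal.

    signal: array (..., time) sampled at fs Hz, first sample at t = 0
    Returns (epochs, reltime) with epochs shaped (n_epochs, ..., n_samples).
    """
    n_before = int(round(before * fs))
    n_after = int(round(after * fs))
    reltime = np.arange(-n_before, n_after + 1) / fs
    epochs = []
    for onset in onsets:
        i0 = int(round(onset * fs))
        start, stop = i0 - n_before, i0 + n_after + 1
        if start < 0 or stop > signal.shape[-1]:
            continue  # incomplete window: skip this onset
        epochs.append(signal[..., start:stop])
    return np.stack(epochs), reltime


sig = np.arange(1000.0)  # 100 s at 10 Hz; each value equals its sample index
epochs, reltime = to_epochs_sketch(sig, fs=10.0, onsets=[10.0, 50.0, 99.0], before=2.0, after=5.0)
print(epochs.shape)  # (2, 71): the window around 99 s runs past the end and is dropped
```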
Plot frequency-filtered data
Illustrate the effect of the bandpass filter for a single subject and channel.
[5]:
rec = data["sub-01"]
channel = "S5D7"
f, ax = p.subplots(2, 1, figsize=(12, 4), sharex=True)
ax[0].plot(rec["conc"].time, rec["conc"].sel(channel=channel, chromo="HbO"), "r-", label="HbO")
ax[0].plot(rec["conc"].time, rec["conc"].sel(channel=channel, chromo="HbR"), "b-", label="HbR")
ax[1].plot(
rec["conc_freqfilt"].time,
rec["conc_freqfilt"].sel(channel=channel, chromo="HbO"),
"r-",
label="HbO",
)
ax[1].plot(
rec["conc_freqfilt"].time,
rec["conc_freqfilt"].sel(channel=channel, chromo="HbR"),
"b-",
label="HbR",
)
ax[0].set_xlim(1000, 1100)
ax[1].set_xlabel("time / s")
ax[0].set_ylabel(r"$\Delta c$ / $\mu M$")
ax[1].set_ylabel(r"$\Delta c$ / $\mu M$")
ax[0].legend(loc="upper left")
ax[1].legend(loc="upper left");
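The freq_filter call is parameterised by fmin, fmax and butter_order, i.e. a Butterworth band-pass. As a hedged sketch of what such a filter does, the block below applies a 4th-order Butterworth band-pass (0.01 to 0.5 Hz) from SciPy forwards and backwards (zero phase) to a synthetic signal containing a constant offset, a 0.1 Hz component and a 2 Hz component. The sampling rate and the signal are assumptions, and whether cedalion filters in exactly this zero-phase way is not shown in this notebook.

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt

fs = 10.0  # hypothetical sampling rate in Hz
t = np.arange(0, 1000, 1 / fs)
slow = np.sin(2 * np.pi * 0.1 * t)        # 0.1 Hz: inside the pass band
fast = 0.5 * np.sin(2 * np.pi * 2.0 * t)  # 2 Hz: above the pass band
x = 5.0 + slow + fast                     # constant offset: below the pass band

# 4th-order Butterworth band-pass as second-order sections,
# run forward and backward so the output has no phase shift
sos = butter(4, [0.01, 0.5], btype="bandpass", fs=fs, output="sos")
y = sosfiltfilt(sos, x)

# compare away from the edges, where start-up transients have decayed
mid = (t > 400) & (t < 600)
print(np.abs(y[mid] - slow[mid]).max())
```

The offset is removed and the 2 Hz component is suppressed, while the 0.1 Hz component passes essentially unchanged and without delay.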

Baseline removal
[6]:
for subject, rec in data.items():
    # calculate baseline
    baseline_conc = rec["cfepochs"].sel(reltime=(rec["cfepochs"].reltime < 0)).mean("reltime")
    # subtract baseline
    rec["cfbl_epochs"] = rec["cfepochs"] - baseline_conc
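The baseline step subtracts, separately for each epoch, chromophore and channel, the mean over all pre-stimulus samples (reltime < 0). The same operation on a toy NumPy array of shape (epoch, reltime), with made-up values, looks like this:

```python
import numpy as np

reltime = np.linspace(-5, 20, 26)  # 1 s steps from -5 s to 20 s
epochs = np.vstack([np.full(26, 2.0), np.full(26, -1.0)])  # two toy epochs
epochs[:, reltime >= 0] += 0.5     # a made-up response after stimulus onset

# mean over the pre-stimulus samples, kept as a length-1 axis for broadcasting
baseline = epochs[:, reltime < 0].mean(axis=1, keepdims=True)
corrected = epochs - baseline
```

After correction every epoch starts at zero before onset, so only the post-stimulus change (0.5 here) remains, regardless of each epoch's original offset.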
[7]:
display(data["sub-01"]["cfbl_epochs"])
<xarray.DataArray (epoch: 60, chromo: 2, channel: 28, reltime: 198)> Size: 5MB
<Quantity([[[[ 5.2902e-02 5.5043e-02 5.7083e-02 ... 5.6429e-02 2.9234e-02 1.5282e-03]
   [ 6.3262e-02 6.4291e-02 6.4695e-02 ... -9.3166e-02 -9.8304e-02 -1.0246e-01]
   [ 2.6691e-02 3.2027e-02 3.7556e-02 ... -9.2581e-02 -1.0783e-01 -1.2340e-01]
   ...
   [ 3.7579e-02 4.4702e-02 5.1556e-02 ... -4.4831e-01 -4.6680e-01 -4.8436e-01]
   [ 4.0865e-02 4.3057e-02 4.5466e-02 ... -6.5807e-01 -6.6514e-01 -6.7172e-01]
   [ 4.8372e-02 4.7851e-02 4.8783e-02 ... -3.9799e-01 -4.1427e-01 -4.2880e-01]]
  [[ 3.2768e-02 2.4241e-02 1.4962e-02 ... -1.9548e-01 -1.8971e-01 -1.8149e-01]
   [ 2.2247e-03 5.0735e-03 8.0807e-03 ... -1.3482e-01 -1.4002e-01 -1.4366e-01]
   [ 3.0611e-02 2.6840e-02 2.2641e-02 ... -2.0681e-01 -2.0450e-01 -2.0081e-01]
   ...
   [ 1.3317e-01 1.4323e-01 1.5242e-01 ... -7.2660e-02 -6.5427e-02 -5.8264e-02]
   [-5.3550e-01 -4.7794e-01 -4.2124e-01 ... -1.3102e+00 -1.3077e+00 -1.3052e+00]
   [ 2.8455e-01 2.7816e-01 2.7210e-01 ... -2.2472e-01 -2.2569e-01 -2.2698e-01]]
  [[ 4.8799e-02 4.1770e-02 3.4004e-02 ... -6.2999e-02 -6.3845e-02 -6.5198e-02]
   [ 3.7193e-02 3.6186e-02 3.3858e-02 ... -9.2909e-02 -8.9262e-02 -8.5327e-02]
   [ 1.7907e-02 8.2854e-03 -8.6835e-04 ... -1.0254e-01 -9.8087e-02 -9.3929e-02]
   ...
   [ 1.6617e-02 1.2406e-02 7.4503e-03 ... -1.0295e-01 -1.0261e-01 -1.0289e-01]
   [-2.8990e-02 -2.0466e-02 -1.1887e-02 ... -3.0189e-01 -2.9509e-01 -2.8930e-01]
   [ 7.1953e-02 6.7632e-02 6.2574e-02 ... -1.0922e-01 -1.0625e-01 -1.0445e-01]]]], 'micromolar')>
Coordinates: (3/6)
  * reltime     (reltime) float64 2kB -5.12 -4.992 -4.864 ... 19.84 19.97 20.1
    trial_type  (epoch) <U13 3kB 'Tapping/Left' ... 'Tapping/Right'
    ...          ...
    detector    (channel) object 224B 'D1' 'D2' 'D3' 'D9' ... 'D7' 'D8' 'D16'
Dimensions without coordinates: epoch
Block averages of trials per condition for one participant
[8]:
# we use subject 1 as an example here
subject = "sub-01"
# group trials by trial_type. For each group individually average the epoch dimension
blockaverage = data[subject]["cfbl_epochs"].groupby("trial_type").mean("epoch")
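groupby("trial_type").mean("epoch") returns one average per trial type. Written out with boolean masks on a toy (epoch, reltime) array, with values made up for illustration:

```python
import numpy as np

# four toy epochs x two time samples, alternating between the two conditions
epochs = np.array([[1.0, 2.0], [3.0, 4.0], [10.0, 20.0], [30.0, 40.0]])
trial_type = np.array(["Tapping/Left", "Tapping/Right", "Tapping/Left", "Tapping/Right"])

# one mean over the epoch axis per unique label (np.unique returns sorted labels)
labels = np.unique(trial_type)
block_average = np.stack([epochs[trial_type == lab].mean(axis=0) for lab in labels])
```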
Plotting averaged epochs
[9]:
f, ax = p.subplots(5, 6, figsize=(12, 10))
ax = ax.flatten()
for i_ch, ch in enumerate(blockaverage.channel):
    for ls, trial_type in zip(["-", "--"], blockaverage.trial_type):
        ax[i_ch].plot(
            blockaverage.reltime,
            blockaverage.sel(chromo="HbO", trial_type=trial_type, channel=ch),
            "r",
            lw=2,
            ls=ls,
        )
        ax[i_ch].plot(
            blockaverage.reltime,
            blockaverage.sel(chromo="HbR", trial_type=trial_type, channel=ch),
            "b",
            lw=2,
            ls=ls,
        )
    ax[i_ch].grid(1)
    ax[i_ch].set_title(ch.values)
    ax[i_ch].set_ylim(-0.3, 0.6)

# add legend
ax[0].legend(["HbO Tapping/Left", "HbR Tapping/Left", "HbO Tapping/Right", "HbR Tapping/Right"])
p.tight_layout()

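Training an LDA classifier with scikit-learn
Feature Extraction
We use very simple features: for each chromophore and channel, the mean of the baseline-corrected signal between 0 and 10 s after stimulus onset, and its minimum and maximum between 0 and 15 s.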
[10]:
for subject, rec in data.items():
    # avg signal between 0 and 10 seconds after stimulus onset
    fmean = rec["cfbl_epochs"].sel(reltime=slice(0, 10)).mean("reltime")
    # min signal between 0 and 15 seconds after stimulus onset
    fmin = rec["cfbl_epochs"].sel(reltime=slice(0, 15)).min("reltime")
    # max signal between 0 and 15 seconds after stimulus onset
    fmax = rec["cfbl_epochs"].sel(reltime=slice(0, 15)).max("reltime")

    # concatenate features and stack them into a single dimension
    X = xr.concat([fmean, fmin, fmax], dim="reltime")
    X = X.stack(features=["chromo", "channel", "reltime"])

    # strip units. sklearn would strip them anyway and issue a warning about it.
    X = X.pint.dequantify()

    # need to manually tell xarray to create an index for trial_type
    X = X.set_xindex("trial_type")

    # save in recording container
    rec.aux_obj["X"] = X
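To make the shape of the feature matrix concrete, this NumPy sketch computes the same three features on random placeholder data with the notebook's dimensions (60 epochs, 2 chromophores, 28 channels, 198 samples at the 0.128 s step seen in the epoch output). Flattening (chromo, channel, feature) in C order gives the same column order as the stacked features coordinate, where the feature index varies fastest. The random data is an assumption, not the real recordings.

```python
import numpy as np

fs = 7.8125                          # sampling rate implied by the 0.128 s step
reltime = np.arange(-40, 158) / fs   # 198 samples, -5.12 s to about 20.1 s
rng = np.random.default_rng(0)
epochs = rng.standard_normal((60, 2, 28, reltime.size))  # (epoch, chromo, channel, reltime)

win10 = (reltime >= 0) & (reltime <= 10)
win15 = (reltime >= 0) & (reltime <= 15)
fmean = epochs[..., win10].mean(axis=-1)  # (60, 2, 28)
fmin = epochs[..., win15].min(axis=-1)
fmax = epochs[..., win15].max(axis=-1)

# put the three feature types on a new last axis, then flatten all but epochs
X = np.stack([fmean, fmin, fmax], axis=-1).reshape(len(epochs), -1)
print(X.shape)  # (60, 168)
```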
[11]:
display(data["sub-01"].aux_obj["X"])
<xarray.DataArray (epoch: 60, features: 168)> Size: 81kB
array([[ 0.4964, 0.0127, 0.8439, ..., -0.0041, -0.0903, 0.0671],
       [ 0.0445, -0.3305, 0.2788, ..., 0.0166, -0.0359, 0.0614],
       [ 0.109 , -0.0486, 0.3628, ..., 0.018 , -0.0568, 0.0678],
       ...,
       [ 0.1284, -0.1551, 0.3873, ..., -0.05 , -0.1053, -0.0204],
       [ 0.4959, -0.4342, 0.7733, ..., -0.0512, -0.4268, 0.6003],
       [ 0.0728, -0.3504, 0.4382, ..., -0.0664, -0.1677, -0.0176]])
Coordinates: (3/7)
  * trial_type  (epoch) <U13 3kB 'Tapping/Left' ... 'Tapping/Right'
    source      (features) object 1kB 'S1' 'S1' 'S1' 'S1' ... 'S8' 'S8' 'S8'
    ...          ...
  * reltime     (features) int64 1kB 0 1 2 0 1 2 0 1 2 0 ... 2 0 1 2 0 1 2 0 1 2
Dimensions without coordinates: epoch
Attributes:
    units:    micromolar
[12]:
# Encode labels for use in scikit-learn
for subject, rec in data.items():
    rec.aux_obj["y"] = xr.apply_ufunc(LabelEncoder().fit_transform, rec.aux_obj["X"].trial_type)

display(data["sub-01"].aux_obj["y"])
<xarray.DataArray 'trial_type' (epoch: 60)> Size: 480B
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
Coordinates:
  * trial_type  (epoch) <U13 3kB 'Tapping/Left' ... 'Tapping/Right'
Dimensions without coordinates: epoch
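LabelEncoder assigns integer codes following the sorted order of the class names, which is why Tapping/Left becomes 0 and Tapping/Right becomes 1. np.unique with return_inverse=True yields the same codes and can serve as a quick cross-check (the four labels below are made up):

```python
import numpy as np

trial_type = np.array(["Tapping/Right", "Tapping/Left", "Tapping/Left", "Tapping/Right"])

# classes come back sorted; y holds each label's index into classes
classes, y = np.unique(trial_type, return_inverse=True)
print(classes)  # ['Tapping/Left' 'Tapping/Right']
print(y)        # [1 0 0 1]
```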
Train LDA classifier for each subject using 5-fold cross-validation
[13]:
# initialize dictionaries for key metrics for each subject to plot
scores = {}
fpr = {}
tpr = {}
roc_auc = {}

for subject, rec in data.items():
    X = rec.aux_obj["X"]
    y = rec.aux_obj["y"]
    classifier = LinearDiscriminantAnalysis(n_components=1)

    # Define the cross-validation strategy (stratified k-fold with 5 folds)
    cv = StratifiedKFold(n_splits=5)

    # Perform cross-validation and get accuracy scores
    scores[subject] = cross_val_score(classifier, X, y, cv=cv, scoring='accuracy')

    # Get predicted probabilities using cross-validation
    pred_prob = cross_val_predict(classifier, X, y, cv=cv, method='predict_proba')[:, 1]

    # Calculate ROC curve and AUC
    fpr[subject], tpr[subject], thresholds = roc_curve(y, pred_prob)
    roc_auc[subject] = auc(fpr[subject], tpr[subject])

    # Print the mean accuracy across folds
    print(f"Cross-validated accuracy for subject {subject}: {scores[subject].mean():.2f}")

# barplot of mean accuracies per subject
f, ax = p.subplots()
ax.bar(list(scores.keys()), [s.mean() for s in scores.values()])
ax.set_ylabel("Accuracy")
ax.set_xlabel("Subject");
Cross-validated accuracy for subject sub-01: 0.82
Cross-validated accuracy for subject sub-02: 0.68
Cross-validated accuracy for subject sub-03: 0.52
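Stratified k-fold cross-validation splits the epochs so that every fold keeps roughly the class balance of the whole set (30 left and 30 right epochs here), and the reported accuracy is the mean of the per-fold test accuracies; with two balanced classes, chance level is 0.5. The round-robin assignment below only illustrates the idea: stratified_folds is a hypothetical helper, and scikit-learn's StratifiedKFold uses its own allocation rule, so its fold memberships differ.

```python
import numpy as np


def stratified_folds(y, n_splits):
    """Return a fold index per sample so that each class is spread evenly over folds."""
    fold = np.empty(len(y), dtype=int)
    for cls in np.unique(y):
        idx = np.flatnonzero(y == cls)
        # deal the samples of this class out to the folds in turn
        fold[idx] = np.arange(len(idx)) % n_splits
    return fold


y = np.repeat([0, 1], 30)  # 30 left and 30 right epochs, as in the label output above
fold = stratified_folds(y, n_splits=5)
for k in range(5):
    test = fold == k
    # class counts in the test fold, and number of training epochs
    print(k, np.bincount(y[test]), (~test).sum())
```

Each test fold then holds 6 left and 6 right epochs, and the classifier is trained on the remaining 48.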

Plot ROC curves for subjects
[14]:
# Initialize the ROC plot
p.figure(figsize=(10, 8))

# Plot the cross-validated ROC curve computed above for each subject
for subject in data:
    p.plot(fpr[subject], tpr[subject], lw=2, label=f'Subject {subject} (AUC = {roc_auc[subject]:.2f})')

# Plot the diagonal line for random guessing
p.plot([0, 1], [0, 1], color='gray', linestyle='--')

# Adding labels and title
p.xlabel('False Positive Rate')
p.ylabel('True Positive Rate')
p.title('ROC Curves for All Subjects')
p.legend(loc='lower right')
p.grid(True)
p.show()
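The AUC has a direct interpretation: it is the probability that a randomly chosen Tapping/Right epoch receives a higher predicted probability than a randomly chosen Tapping/Left epoch, with ties counted as one half. The sketch below computes it that way and, for comparison, as the trapezoidal area under a step ROC curve. Both helpers are hypothetical, the simple curve assumes there are no tied scores, and the four scores are made up.

```python
import numpy as np


def auc_by_ranking(y_true, score):
    """P(score of a random positive > score of a random negative); ties count 1/2."""
    pos = score[y_true == 1]
    neg = score[y_true == 0]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))


def roc_points(y_true, score):
    """Step ROC curve: lower the threshold past one sample at a time (no tied scores)."""
    order = np.argsort(-score, kind="stable")
    hits = y_true[order] == 1
    tpr = np.concatenate([[0.0], np.cumsum(hits) / hits.sum()])
    fpr = np.concatenate([[0.0], np.cumsum(~hits) / (~hits).sum()])
    return fpr, tpr


y = np.array([0, 0, 1, 1])
s = np.array([0.1, 0.4, 0.35, 0.8])  # made-up predicted probabilities for class 1

fpr, tpr = roc_points(y, s)
auc_trapezoid = np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2)

print(auc_by_ranking(y, s))  # 0.75
print(auc_trapezoid)         # 0.75
```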
