Basic single trial fNIRS finger tapping classification
This notebook sketches the analysis of a finger tapping dataset with multiple subjects. A simple Linear Discriminant Analysis (LDA) classifier is trained to distinguish left from right finger tapping.
PLEASE NOTE: For simplicity's sake we are skipping many preprocessing steps (e.g. pruning, artifact removal, physiology removal). These are the subject of other example notebooks. For a rigorous analysis you will want to include such steps. The purpose of this notebook is only to demonstrate how easily cedalion data can be passed to the scikit-learn package.
[1]:
# This cell sets up the environment when executed in Google Colab.
try:
    import google.colab
    !curl -s https://raw.githubusercontent.com/ibs-lab/cedalion/dev/scripts/colab_setup.py -o colab_setup.py
    # Select branch with --branch "branch name" (default is "dev")
    %run colab_setup.py
except ImportError:
    pass
[2]:
import cedalion
import cedalion.nirs
from cedalion.datasets import get_multisubject_fingertapping_snirf_paths
import cedalion.sigproc.quality as quality
from cedalion.sigproc.frequency import freq_filter
import cedalion.plots as plots
import numpy as np
import xarray as xr
import matplotlib.pyplot as p
from sklearn.preprocessing import LabelEncoder
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import train_test_split, cross_val_score, cross_val_predict, StratifiedKFold
from sklearn.metrics import accuracy_score,roc_curve, roc_auc_score, auc
from cedalion import units
xr.set_options(display_max_rows=3, display_values_threshold=50)
np.set_printoptions(precision=4)
Loading raw CW-NIRS data from SNIRF files
This notebook uses a finger-tapping dataset in BIDS layout provided by Rob Luke. It can be downloaded via cedalion.datasets.
Cedalion's read_snirf method returns a list of Recording objects. These are containers for timeseries and adjunct data objects.
[3]:
fnames = get_multisubject_fingertapping_snirf_paths()
subjects = [f"sub-{i:02d}" for i in [1, 2, 3]]
# store data of different subjects in a dictionary
data = {}
for subject, fname in zip(subjects, fnames):
    records = cedalion.io.read_snirf(fname)
    rec = records[0]
    # Cedalion registers an accessor (attribute .cd) on pandas DataFrames.
    # Use this to rename trial_types inplace.
    rec.stim.cd.rename_events(
        {"1.0": "control", "2.0": "Tapping/Left", "3.0": "Tapping/Right"}
    )
    # differential pathlength factor, one value per wavelength
    dpf = xr.DataArray(
        [6, 6],
        dims="wavelength",
        coords={"wavelength": rec["amp"].wavelength},
    )
    # raw amplitudes -> optical density -> concentration changes
    rec["od"] = cedalion.nirs.int2od(rec["amp"])
    rec["conc"] = cedalion.nirs.od2conc(rec["od"], rec.geo3d, dpf)
    display(subject, rec)
    data[subject] = rec
'sub-01'
<Recording | timeseries: ['amp', 'od', 'conc'], masks: [], stim: ['control', '15.0', 'Tapping/Left', 'Tapping/Right'], aux_ts: [], aux_obj: []>
'sub-02'
<Recording | timeseries: ['amp', 'od', 'conc'], masks: [], stim: ['control', '15.0', 'Tapping/Left', 'Tapping/Right'], aux_ts: [], aux_obj: []>
'sub-03'
<Recording | timeseries: ['amp', 'od', 'conc'], masks: [], stim: ['control', '15.0', 'Tapping/Left', 'Tapping/Right'], aux_ts: [], aux_obj: []>
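The first conversion step above, from raw intensity to optical density, can be illustrated outside of cedalion. The following is a minimal NumPy sketch, assuming the common definition OD = -ln(I / mean(I)) with the temporal mean of each channel as reference. The helper `intensity_to_od` and the synthetic data are hypothetical and are not cedalion's implementation of int2od.

```python
import numpy as np


def intensity_to_od(amplitude: np.ndarray) -> np.ndarray:
    """Convert raw intensities (channel x time) to optical density changes."""
    # normalize each channel by its temporal mean, then take the negative log
    baseline = amplitude.mean(axis=-1, keepdims=True)
    return -np.log(amplitude / baseline)


# synthetic example: 2 channels, 5 samples of strictly positive intensity
amp = np.array(
    [
        [1.0, 1.1, 0.9, 1.0, 1.0],
        [2.0, 2.0, 2.0, 2.0, 2.0],
    ]
)
od = intensity_to_od(amp)
```

A channel with constant intensity yields an optical density of zero everywhere, and samples brighter than the mean yield negative values.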
[4]:
display(
data["sub-01"]
.stim.groupby("trial_type")[["trial_type"]]
.count()
.rename({"trial_type": "# trials"}, axis=1) # rename column
)
trial_type | # trials
---|---
15.0 | 2
Tapping/Left | 30
Tapping/Right | 30
control | 30
Preprocessing
Frequency filtering and splitting into epochs
[5]:
for subject, rec in data.items():
    # bandpass filter the concentration time series between 0.01 Hz and 0.5 Hz
    rec["conc_freqfilt"] = freq_filter(
        rec["conc"], fmin=0.01 * units.Hz, fmax=0.5 * units.Hz
    )
    # cedalion registers the accessor .cd on DataArrays
    # to provide common functionality like splitting time series into epochs
    rec["cfepochs"] = rec["conc_freqfilt"].cd.to_epochs(
        rec.stim,  # stimulus dataframe
        ["Tapping/Left", "Tapping/Right"],  # select trials
        before=5 * units.s,  # seconds before stimulus
        after=20 * units.s,  # seconds after stimulus
    )
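To make the epoching step concrete, here is a minimal NumPy sketch of cutting fixed windows around stimulus onsets. It assumes a uniformly sampled 1D signal. The function `split_into_epochs` and the synthetic data are hypothetical and do not reproduce the internals of cedalion's to_epochs.

```python
import numpy as np


def split_into_epochs(signal, time, onsets, before, after):
    """Cut windows [onset - before, onset + after] from a uniformly sampled signal."""
    dt = time[1] - time[0]  # assumes uniform sampling
    n_before = int(round(before / dt))
    n_after = int(round(after / dt))
    epochs = []
    for onset in onsets:
        i0 = int(np.argmin(np.abs(time - onset)))  # sample closest to the onset
        if i0 - n_before < 0 or i0 + n_after >= len(signal):
            continue  # skip windows that would run past the recording
        epochs.append(signal[i0 - n_before : i0 + n_after + 1])
    # time axis relative to the onset, shared by all epochs
    reltime = np.arange(-n_before, n_after + 1) * dt
    return np.array(epochs), reltime


time = np.arange(0, 100, 0.5)
signal = np.sin(time)
epochs, reltime = split_into_epochs(
    signal, time, onsets=[20.0, 50.0, 99.0], before=5, after=20
)
```

The last onset lies too close to the end of the recording and is dropped, so `epochs` has the layout (epoch, reltime) with two rows.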
Show the time series before and after to_epochs:
[6]:
display(data["sub-01"]["conc_freqfilt"])
display(data["sub-01"]["cfepochs"])
<xarray.DataArray 'concentration' (chromo: 2, channel: 28, time: 23239)> Size: 10MB <Quantity([[[-0.0475 -0.0538 -0.0598 ... 0.0401 0.0322 0.024 ] [-0.1 -0.1145 -0.1281 ... 0.0387 0.0342 0.03 ] [-0.0702 -0.084 -0.097 ... 0.0252 0.0232 0.0215] ... [-0.0813 -0.1201 -0.1563 ... -0.0219 -0.0108 0.0018] [-0.0558 -0.1322 -0.2031 ... -0.0873 -0.0708 -0.0521] [-0.1566 -0.383 -0.5956 ... -0.0796 -0.0633 -0.0433]] [[ 0.0132 0.014 0.0147 ... -0.0149 -0.0122 -0.0093] [ 0.0274 0.0324 0.0372 ... -0.0069 -0.0062 -0.0056] [ 0.0154 0.0187 0.0218 ... -0.0047 -0.0038 -0.0029] ... [ 0.0238 0.0341 0.0437 ... -0.011 -0.0101 -0.0093] [-0.0157 -0.0071 0.0009 ... 0.0026 0.0029 0.0031] [-0.0084 0.0567 0.1184 ... -0.0084 -0.0021 0.004 ]]], 'micromolar')> Coordinates: (3/6) * chromo (chromo) <U3 24B 'HbO' 'HbR' * time (time) float64 186kB 0.0 0.128 0.256 ... 2.974e+03 2.974e+03 ... ... detector (channel) object 224B 'D1' 'D2' 'D3' 'D9' ... 'D7' 'D8' 'D16'
<xarray.DataArray (epoch: 60, chromo: 2, channel: 28, reltime: 198)> Size: 5MB <Quantity([[[[-6.0991e-04 -3.9017e-04 -9.2706e-05 ... 1.6313e-02 2.1562e-02 2.6736e-02] [-1.2727e-02 -1.4438e-02 -1.5744e-02 ... 1.4890e-01 1.5447e-01 1.5879e-01] [-2.1286e-02 -2.2470e-02 -2.3656e-02 ... 5.1021e-02 5.5714e-02 6.0325e-02] ... [-1.1650e-01 -1.1826e-01 -1.1993e-01 ... 1.3764e-01 1.4620e-01 1.5430e-01] [-3.1655e-01 -3.1766e-01 -3.1920e-01 ... 3.1914e-01 3.2441e-01 3.2954e-01] [-1.7606e-01 -1.7166e-01 -1.6920e-01 ... 5.2047e-01 5.4246e-01 5.6296e-01]] [[-3.6935e-03 -3.0874e-03 -2.4405e-03 ... 1.0394e-02 8.8207e-03 7.0695e-03] [-2.1270e-03 -2.2010e-03 -2.4108e-03 ... -1.9207e-02 -1.9699e-02 -2.0153e-02] [-2.3742e-03 -1.7401e-03 -1.0668e-03 ... 1.1681e-03 -2.2401e-04 -1.7251e-03] ... [-2.1053e-01 -2.1518e-01 -2.1919e-01 ... -6.6813e-02 -7.0985e-02 -7.4982e-02] [-6.9048e-01 -7.5107e-01 -8.1072e-01 ... 2.0459e-01 1.9986e-01 1.9542e-01] [-1.1743e+00 -1.1628e+00 -1.1513e+00 ... -3.3815e-01 -3.3827e-01 -3.3734e-01]] [[-1.7871e-03 -1.2240e-03 -5.6485e-04 ... 9.1667e-03 9.5571e-03 9.9147e-03] [-9.9304e-03 -8.8731e-03 -7.5464e-03 ... 3.8110e-02 3.7036e-02 3.5999e-02] [-2.8111e-03 -1.7362e-03 -7.0393e-04 ... 1.4618e-02 1.4164e-02 1.3620e-02] ... [ 2.5693e-02 2.7554e-02 2.9413e-02 ... 1.6841e-02 1.7701e-02 1.8645e-02] [ 9.6264e-03 1.5051e-02 2.0240e-02 ... 1.1630e-02 8.6480e-03 6.1833e-03] [ 1.0219e-01 1.0239e-01 1.0304e-01 ... 3.9592e-02 3.7872e-02 3.6623e-02]]]], 'micromolar')> Coordinates: (3/6) * reltime (reltime) float64 2kB -5.12 -4.992 -4.864 ... 19.84 19.97 20.1 trial_type (epoch) <U13 3kB 'Tapping/Left' ... 'Tapping/Right' ... ... detector (channel) object 224B 'D1' 'D2' 'D3' 'D9' ... 'D7' 'D8' 'D16' Dimensions without coordinates: epoch
Plot frequency filtered data
For a single subject and channel, illustrate the effect of the bandpass filter. The lowpass removes the cardiac component; the highpass removes slow drifts and the DC offset.
[7]:
rec = data["sub-01"]
channel = "S5D7"
f, ax = p.subplots(2, 1, figsize=(12, 4), sharex=True)
ax[0].plot(rec["conc"].time, rec["conc"].sel(channel=channel, chromo="HbO"), "r-", label="HbO")
ax[0].plot(rec["conc"].time, rec["conc"].sel(channel=channel, chromo="HbR"), "b-", label="HbR")
ax[1].plot(
rec["conc_freqfilt"].time,
rec["conc_freqfilt"].sel(channel=channel, chromo="HbO"),
"r-",
label="HbO",
)
ax[1].plot(
rec["conc_freqfilt"].time,
rec["conc_freqfilt"].sel(channel=channel, chromo="HbR"),
"b-",
label="HbR",
)
ax[0].set_xlim(1000, 1040)
ax[1].set_xlabel("time / s")
# raw strings avoid invalid escape sequence warnings for the LaTeX backslashes
ax[0].set_ylabel(r"$\Delta c$ / $\mu M$")
ax[1].set_ylabel(r"$\Delta c$ / $\mu M$")
ax[0].legend(loc="upper left")
ax[1].legend(loc="upper left");

Baseline removal
Calculate a baseline by averaging all samples before the stimulus onset (reltime < 0) and subtract it:
[8]:
for subject, rec in data.items():
    # calculate baseline
    baseline_conc = rec["cfepochs"].sel(reltime=(rec["cfepochs"].reltime < 0)).mean("reltime")
    # subtract baseline
    rec["cfbl_epochs"] = rec["cfepochs"] - baseline_conc
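The same baseline correction can be written in plain NumPy, which makes the broadcasting explicit. This is a small sketch on synthetic data; the array layout (epoch, channel, reltime) and all variable names are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
reltime = np.linspace(-5, 20, 51)
# synthetic epochs with layout (epoch, channel, reltime) and a constant offset
epochs = rng.normal(size=(4, 3, reltime.size)) + 10.0

# average all pre-stimulus samples per epoch and channel
pre_stimulus = reltime < 0
baseline = epochs[..., pre_stimulus].mean(axis=-1, keepdims=True)

# keepdims=True lets the baseline broadcast along the reltime axis
corrected = epochs - baseline
```

After the subtraction, the pre-stimulus mean of every epoch and channel is zero up to floating point error.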
[9]:
display(data["sub-01"]["cfbl_epochs"])
<xarray.DataArray (epoch: 60, chromo: 2, channel: 28, reltime: 198)> Size: 5MB <Quantity([[[[-1.3579e-02 -1.3360e-02 -1.3062e-02 ... 3.3438e-03 8.5929e-03 1.3767e-02] [-4.7307e-02 -4.9018e-02 -5.0324e-02 ... 1.1432e-01 1.1990e-01 1.2422e-01] [-1.2940e-02 -1.4123e-02 -1.5310e-02 ... 5.9368e-02 6.4061e-02 6.8671e-02] ... [-1.6980e-02 -1.8739e-02 -2.0412e-02 ... 2.3716e-01 2.4572e-01 2.5382e-01] [-3.2259e-02 -3.3362e-02 -3.4900e-02 ... 6.0343e-01 6.0870e-01 6.1383e-01] [-1.0204e-01 -9.7645e-02 -9.5183e-02 ... 5.9449e-01 6.1648e-01 6.3698e-01]] [[ 4.1874e-04 1.0249e-03 1.6717e-03 ... 1.4507e-02 1.2933e-02 1.1182e-02] [ 1.2053e-02 1.1980e-02 1.1770e-02 ... -5.0260e-03 -5.5182e-03 -5.9722e-03] [ 4.9959e-04 1.1337e-03 1.8070e-03 ... 4.0420e-03 2.6498e-03 1.1488e-03] ... [-7.9217e-02 -8.3868e-02 -8.7874e-02 ... 6.4499e-02 6.0328e-02 5.6331e-02] [ 5.4297e-01 4.8238e-01 4.2273e-01 ... 1.4380e+00 1.4333e+00 1.4289e+00] [-4.5178e-01 -4.4024e-01 -4.2880e-01 ... 3.8440e-01 3.8428e-01 3.8521e-01]] [[ 4.2527e-03 4.8158e-03 5.4749e-03 ... 1.5207e-02 1.5597e-02 1.5954e-02] [ 1.7197e-02 1.8254e-02 1.9581e-02 ... 6.5237e-02 6.4163e-02 6.3126e-02] [ 1.2834e-02 1.3909e-02 1.4941e-02 ... 3.0263e-02 2.9809e-02 2.9264e-02] ... [ 1.4614e-02 1.6476e-02 1.8335e-02 ... 5.7627e-03 6.6223e-03 7.5667e-03] [-7.0962e-02 -6.5538e-02 -6.0348e-02 ... -6.8959e-02 -7.1941e-02 -7.4405e-02] [ 4.9307e-02 4.9503e-02 5.0154e-02 ... -1.3293e-02 -1.5012e-02 -1.6261e-02]]]], 'micromolar')> Coordinates: (3/6) * reltime (reltime) float64 2kB -5.12 -4.992 -4.864 ... 19.84 19.97 20.1 trial_type (epoch) <U13 3kB 'Tapping/Left' ... 'Tapping/Right' ... ... detector (channel) object 224B 'D1' 'D2' 'D3' 'D9' ... 'D7' 'D8' 'D16' Dimensions without coordinates: epoch
Block Averages of trials for one participant per condition
[10]:
# we use subject 1 as an example here
subject = "sub-01"
# group trials by trial_type. For each group individually average the epoch dimension
blockaverage = data[subject]["cfbl_epochs"].groupby("trial_type").mean("epoch")
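What the groupby and mean combination computes can be shown with a tiny NumPy example: epochs are grouped by their label and averaged along the epoch axis. The data and names below are made up for illustration.

```python
import numpy as np

trial_type = np.array(
    ["Tapping/Left", "Tapping/Right", "Tapping/Left", "Tapping/Right"]
)
# synthetic epochs with layout (epoch, reltime)
epochs = np.array(
    [
        [1.0, 2.0],
        [10.0, 20.0],
        [3.0, 4.0],
        [30.0, 40.0],
    ]
)

# one averaged time course per condition
blockaverage = {
    str(label): epochs[trial_type == label].mean(axis=0)
    for label in np.unique(trial_type)
}
```

Here `blockaverage["Tapping/Left"]` is the mean of the first and third row, and `blockaverage["Tapping/Right"]` the mean of the second and fourth.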
Plotting averaged epochs
[11]:
f, ax = p.subplots(5, 6, figsize=(12, 10))
ax = ax.flatten()
for i_ch, ch in enumerate(blockaverage.channel):
    for ls, trial_type in zip(["-", "--"], blockaverage.trial_type):
        ax[i_ch].plot(
            blockaverage.reltime,
            blockaverage.sel(chromo="HbO", trial_type=trial_type, channel=ch),
            "r",
            lw=2,
            ls=ls,
        )
        ax[i_ch].plot(
            blockaverage.reltime,
            blockaverage.sel(chromo="HbR", trial_type=trial_type, channel=ch),
            "b",
            lw=2,
            ls=ls,
        )
    ax[i_ch].grid(1)
    ax[i_ch].set_title(ch.values)
    ax[i_ch].set_ylim(-0.3, 0.6)
# add legend
ax[0].legend(["HbO Tapping/Left", "HbR Tapping/Left", "HbO Tapping/Right", "HbR Tapping/Right"])
p.tight_layout()

LDA classification with Scikit-Learn
Feature Extraction
We define simple features to characterize an epoch: the minimum, maximum and average concentration.
All of these are calculated in time windows after stimulus onset, for each channel and chromophore.
The features are calculated from the baseline-corrected single-trial epochs (not from the block averages). The layout of the resulting array is adjusted to match the layout scikit-learn expects.
[12]:
for subject, rec in data.items():
    # avg signal between 0 and 10 seconds after stimulus onset
    fmean = rec["cfbl_epochs"].sel(reltime=slice(0, 10)).mean("reltime")
    # min signal between 0 and 15 seconds after stimulus onset
    fmin = rec["cfbl_epochs"].sel(reltime=slice(0, 15)).min("reltime")
    # max signal between 0 and 15 seconds after stimulus onset
    fmax = rec["cfbl_epochs"].sel(reltime=slice(0, 15)).max("reltime")
    # concatenate features and stack them into a new dimension ('feature')
    # X will have shape (feature, epoch, chromo, channel)
    X = xr.concat([fmean, fmin, fmax], dim="feature")
    X = X.assign_coords({"feature" : ["mean", "min", "max"]})
    display(X)
    # afterwards stack the 3 per channel and chromo features together into a new
    # dimension ('features')
    X = X.stack(features=["chromo", "channel", "feature"])
    #X = X.drop_vars("features_tmp") # coordinate created by concat. not needed
    # strip units. sklearn would strip them anyway but issue a warning about it.
    X = X.pint.dequantify()
    # need to manually tell xarray to create an index for trial_type
    X = X.set_xindex("trial_type")
    # save in recording container
    rec.aux_obj["X"] = X
<xarray.DataArray (feature: 3, epoch: 60, chromo: 2, channel: 28)> Size: 81kB <Quantity([[[[-9.3709e-02 -2.1889e-01 -9.4072e-02 ... -1.3173e-01 -1.1169e-01 -3.9138e-01] [ 2.7945e-02 6.5939e-02 3.5147e-02 ... 3.8594e-02 3.9230e-02 8.3013e-02]] [[-6.6110e-03 -3.6142e-03 -5.9760e-03 ... -1.0574e-01 -1.9792e-02 -1.7878e-01] [ 3.9941e-03 9.9413e-04 1.5060e-05 ... 3.2567e-02 2.8291e-02 2.7717e-02]] [[-1.8850e-02 -3.5395e-02 -1.7609e-02 ... -8.1159e-02 -1.6830e-01 -1.2537e-01] [ 8.3640e-03 6.9491e-03 7.5421e-03 ... 2.7515e-02 3.3584e-02 1.5711e-02]] ... [[-1.5387e-02 2.9948e-02 -7.4940e-03 ... 1.2255e-02 -3.3193e-02 5.1774e-01] [ 1.6051e-02 -1.7036e-03 9.8583e-03 ... 9.7127e-03 -7.6687e-03 ... 4.5626e-02] [ 2.0435e-02 5.9914e-02 2.6551e-02 ... 6.3754e-02 8.6744e-02 9.2641e-02]] ... [[ 4.3553e-02 1.6984e-01 7.3282e-02 ... 1.4835e-01 1.6543e-01 9.4178e-01] [ 3.1116e-02 2.9925e-02 2.7178e-02 ... 2.5866e-02 8.8768e-03 -4.3604e-02]] [[ 1.0048e-01 2.3613e-01 1.0076e-01 ... 1.3942e-01 -2.1243e-02 1.1223e+00] [ 4.1572e-02 9.8431e-02 5.4100e-02 ... 1.1013e-01 3.4744e-01 3.3735e-01]] [[ 7.8479e-02 1.8025e-01 8.5591e-02 ... 2.1315e-01 1.4708e+00 1.0362e+00] [ 2.8270e-02 6.1231e-02 4.5518e-02 ... 2.7968e-02 1.7058e-01 1.3357e-02]]]], 'micromolar')> Coordinates: (3/6) trial_type (epoch) <U13 3kB 'Tapping/Left' ... 'Tapping/Right' * chromo (chromo) <U3 24B 'HbO' 'HbR' ... ... * feature (feature) <U4 48B 'mean' 'min' 'max' Dimensions without coordinates: epoch
<xarray.DataArray (feature: 3, epoch: 60, chromo: 2, channel: 28)> Size: 81kB <Quantity([[[[-7.6058e-02 -1.9232e-02 -1.4292e-02 ... -6.0005e-03 -7.6004e-02 1.7201e-01] [ 7.2357e-02 1.8826e-03 1.0727e-02 ... 3.3683e-03 4.4195e-02 -1.0959e-02]] [[-1.7217e-02 3.9898e-02 5.1261e-02 ... 2.0060e-02 7.3336e-02 2.1333e-01] [ 1.4846e-02 -1.1545e-02 -1.6651e-02 ... -7.2794e-03 -1.4634e-02 -1.7267e-02]] [[ 2.3526e-01 5.8611e-02 4.6127e-02 ... 2.2957e-02 -2.8230e-02 3.8286e-01] [-1.1019e-01 -1.9586e-02 -1.0426e-02 ... -4.4624e-03 8.7338e-03 -6.0083e-02]] ... [[ 3.3141e-03 -5.0424e-02 -4.7069e-02 ... -2.9215e-02 -1.4680e-01 3.8087e-02] [ 5.8451e-02 1.6974e-02 1.7244e-02 ... 1.0881e-02 1.7931e-02 ... 1.7918e+00] [ 8.7512e-02 3.4332e-02 3.8019e-02 ... 1.2858e-02 5.1005e-02 1.0093e-01]] ... [[ 4.6768e-01 3.0753e-01 1.4946e-01 ... 4.6692e-02 1.5994e-01 1.3850e+00] [ 2.3556e-01 1.0041e-01 8.0922e-02 ... 2.6365e-02 6.5300e-02 2.2437e-01]] [[ 9.3439e-02 -1.5462e-02 3.3084e-02 ... 4.7123e-02 4.2476e-02 7.0692e-01] [ 1.0978e-01 3.8357e-01 2.9367e-01 ... 2.3907e-02 4.9559e-02 3.1058e-01]] [[ 1.0077e-01 4.7262e-01 1.1056e-01 ... -2.4398e-02 -1.1307e-01 3.4578e-01] [ 5.2073e-01 1.1032e-01 8.9784e-02 ... 4.5333e-02 1.2414e-01 3.0696e-01]]]], 'micromolar')> Coordinates: (3/6) trial_type (epoch) <U13 3kB 'Tapping/Left' ... 'Tapping/Right' * chromo (chromo) <U3 24B 'HbO' 'HbR' ... ... * feature (feature) <U4 48B 'mean' 'min' 'max' Dimensions without coordinates: epoch
<xarray.DataArray (feature: 3, epoch: 60, chromo: 2, channel: 28)> Size: 81kB <Quantity([[[[ 9.4151e-04 1.3353e-01 6.9375e-03 ... -3.2409e-02 -4.5654e-02 1.7955e-01] [ 2.8453e-03 -2.9226e-02 -9.2354e-04 ... 1.5962e-02 2.6193e-02 -4.2567e-02]] [[ 5.8593e-02 4.9796e-01 2.0111e-03 ... 6.5618e-02 4.3618e-02 3.5037e-01] [-1.7008e-02 -1.0509e-01 -2.5580e-03 ... -1.0156e-02 1.6038e-02 -8.4651e-02]] [[ 3.5433e-02 -1.6243e-01 1.6429e-02 ... -3.8725e-02 -9.6947e-02 -2.8258e-02] [ 7.5374e-03 4.4581e-02 -3.1688e-03 ... 1.5996e-04 -1.3312e-02 1.1759e-02]] ... [[ 8.8595e-02 4.3148e-01 3.5772e-02 ... 6.3133e-02 9.6649e-02 6.9241e-01] [-2.3717e-02 -9.9746e-02 -6.6261e-03 ... -4.5810e-04 3.8229e-02 ... 8.1524e-01] [ 3.4804e-02 1.7144e-01 1.3450e-02 ... 2.2054e-02 2.1947e-02 1.7357e-01]] ... [[ 1.7425e-01 8.3785e-01 6.4069e-02 ... 9.2185e-02 2.1503e-01 1.0599e+00] [-4.4987e-03 7.3010e-02 1.9523e-03 ... 1.5761e-02 9.0248e-02 4.4553e-02]] [[ 1.5211e-01 5.3921e-01 -1.5892e-02 ... 4.6031e-02 1.0641e-01 4.4798e-01] [ 4.5688e-02 2.2547e-01 1.3587e-02 ... 2.7739e-02 5.5008e-02 4.6410e-01]] [[ 3.2286e-02 8.1447e-01 2.5225e-02 ... 1.5702e-01 3.0574e-01 1.4467e+00] [ 3.0450e-03 1.1200e-02 2.1338e-04 ... 3.8020e-03 9.0089e-03 1.1847e-01]]]], 'micromolar')> Coordinates: (3/6) trial_type (epoch) <U13 3kB 'Tapping/Left' ... 'Tapping/Right' * chromo (chromo) <U3 24B 'HbO' 'HbR' ... ... * feature (feature) <U4 48B 'mean' 'min' 'max' Dimensions without coordinates: epoch
The arrays 'X' now have the layout that scikit-learn expects: (#samples, #features).
Features in columns can still be identified via their corresponding coordinates.
Epochs in rows carry the trial_type coordinate to distinguish the conditions.
Further coordinates can be added as needed (e.g. the subject when epochs from different subjects are pooled).
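The reshaping into (#samples, #features) can be sketched with plain NumPy as well. The example below uses synthetic epochs with the assumed layout (epoch, chromo, channel, reltime) and the same time windows as above; all names are hypothetical and the coordinate bookkeeping that xarray provides is lost in this simplified form.

```python
import numpy as np

rng = np.random.default_rng(0)
n_epoch, n_chromo, n_channel = 6, 2, 4
reltime = np.linspace(-5, 20, 51)
epochs = rng.normal(size=(n_epoch, n_chromo, n_channel, reltime.size))

# time windows after stimulus onset
win_mean = (reltime >= 0) & (reltime <= 10)
win_minmax = (reltime >= 0) & (reltime <= 15)

fmean = epochs[..., win_mean].mean(axis=-1)
fmin = epochs[..., win_minmax].min(axis=-1)
fmax = epochs[..., win_minmax].max(axis=-1)

# stack to (epoch, chromo, channel, feature), then flatten all axes but epoch
features = np.stack([fmean, fmin, fmax], axis=-1)
X = features.reshape(n_epoch, -1)
```

Each row of `X` holds one epoch. The columns run over chromophore, then channel, then feature type, which matches the stacking order ["chromo", "channel", "feature"] used in the cell above.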
[13]:
display(data["sub-01"].aux_obj["X"])
<xarray.DataArray (epoch: 60, features: 168)> Size: 81kB array([[-0.0937, -0.1668, -0.002 , ..., 0.083 , -0.0384, 0.1765], [-0.0066, -0.0583, 0.0708, ..., 0.0277, -0.0441, 0.096 ], [-0.0188, -0.0746, 0.0099, ..., 0.0157, -0.0196, 0.0926], ..., [-0.0154, -0.0674, 0.0436, ..., -0.0757, -0.1599, -0.0436], [-0.1029, -0.1692, 0.1005, ..., 0.1566, -0.1096, 0.3374], [-0.0113, -0.0878, 0.0785, ..., -0.0435, -0.139 , 0.0134]], shape=(60, 168)) Coordinates: (3/7) * trial_type (epoch) <U13 3kB 'Tapping/Left' ... 'Tapping/Right' source (features) object 1kB 'S1' 'S1' 'S1' 'S1' ... 'S8' 'S8' 'S8' ... ... * feature (features) object 1kB 'mean' 'min' 'max' ... 'mean' 'min' 'max' Dimensions without coordinates: epoch Attributes: units: micromolar
[14]:
# Encode trial_type into 0,1 labels for use in scikit-learn
# LabelEncoder returns a numpy array. Use apply_ufunc to wrap the result into a xr.DataArray
for subject, rec in data.items():
    rec.aux_obj["y"] = xr.apply_ufunc(LabelEncoder().fit_transform, rec.aux_obj["X"].trial_type)
display(data["sub-01"].aux_obj["y"])
<xarray.DataArray 'trial_type' (epoch: 60)> Size: 480B array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) Coordinates: * trial_type (epoch) <U13 3kB 'Tapping/Left' ... 'Tapping/Right' Dimensions without coordinates: epoch
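The integer encoding shown above follows from sorting the distinct labels and replacing each label by its position in that sorted list. A NumPy-only sketch of the same mapping, on made-up labels, uses `np.unique` with `return_inverse=True`:

```python
import numpy as np

trial_type = np.array(["Tapping/Left", "Tapping/Right", "Tapping/Left"])

# classes: sorted distinct labels, y: index of each label within classes
classes, y = np.unique(trial_type, return_inverse=True)
```

With these labels, "Tapping/Left" maps to 0 and "Tapping/Right" to 1, the same assignment as in the output above.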
Train LDA classifier for each subject using 5-fold cross-validation
[15]:
# initialize dictionaries for key metrics for each subject to plot
scores = {}
fpr = {}
tpr = {}
roc_auc = {}
for subject, rec in data.items():
    X = rec.aux_obj["X"]
    y = rec.aux_obj["y"]
    classifier = LinearDiscriminantAnalysis(n_components=1)
    # Define the cross-validation strategy (stratified k-fold with 5 folds)
    cv = StratifiedKFold(n_splits=5)
    # Perform cross-validation and get accuracy scores
    scores[subject] = cross_val_score(classifier, X, y, cv=cv, scoring="accuracy")
    # Get predicted probabilities using cross-validation
    pred_prob = cross_val_predict(classifier, X, y, cv=cv, method="predict_proba")[:, 1]
    # Calculate ROC curve and AUC
    fpr[subject], tpr[subject], thresholds = roc_curve(y, pred_prob)
    roc_auc[subject] = auc(fpr[subject], tpr[subject])
    # Print the mean accuracy across folds
    print(
        f"Cross-validated accuracy for subject {subject}: {scores[subject].mean():.2f}+-{scores[subject].std():.4f}"
    )
# barplot of accuracies
f, ax = p.subplots()
# use a separate loop variable so that the scores dictionary is not shadowed
ax.bar(data.keys(), [s.mean() for s in scores.values()])
ax.set_ylabel("Accuracy")
ax.set_xlabel("Subject")
Cross-validated accuracy for subject sub-01: 0.92+-0.0000
Cross-validated accuracy for subject sub-02: 0.83+-0.0913
Cross-validated accuracy for subject sub-03: 0.58+-0.1491
[15]:
Text(0.5, 0, 'Subject')
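The cross-validation scheme can be tried in isolation on synthetic data. The sketch below uses only scikit-learn calls that also appear in this notebook (LinearDiscriminantAnalysis, StratifiedKFold, cross_val_score); the two Gaussian classes, their separation and the sample sizes are assumptions made for illustration and say nothing about the fNIRS results.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import StratifiedKFold, cross_val_score

rng = np.random.default_rng(0)
n_per_class, n_features = 30, 5
# two synthetic classes whose means differ by 3 standard deviations per feature
X = np.vstack(
    [
        rng.normal(0.0, 1.0, size=(n_per_class, n_features)),
        rng.normal(3.0, 1.0, size=(n_per_class, n_features)),
    ]
)
y = np.repeat([0, 1], n_per_class)

# stratification keeps the 50/50 class balance in every fold
cv = StratifiedKFold(n_splits=5)
scores = cross_val_score(
    LinearDiscriminantAnalysis(), X, y, cv=cv, scoring="accuracy"
)
print(f"accuracy: {scores.mean():.2f} +- {scores.std():.2f}")
```

With 60 epochs and 5 folds, each test fold contains 12 epochs, so fold accuracies can only change in steps of 1/12. This explains values such as 0.92 (11 of 12 correct) in the output above.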

Plot ROC curves for subjects
[16]:
# Initialize the ROC plot
p.figure(figsize=(10, 8))
# Plot the ROC curve computed in the previous cell for each subject
for subject, rec in data.items():
    # Plotting the ROC curve
    p.plot(fpr[subject], tpr[subject], lw=2, label=f'Subject {subject} (AUC = {roc_auc[subject]:.2f})')
# Plot the diagonal line for random guessing
p.plot([0, 1], [0, 1], color='gray', linestyle='--')
# Adding labels and title
p.xlabel('False Positive Rate')
p.ylabel('True Positive Rate')
p.title('ROC Curves for All Subjects')
p.legend(loc='lower right')
p.grid(True)
p.show()
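How the points of a ROC curve and the area under it arise from predicted scores can be sketched in a few lines of NumPy. The helper `roc_points` is hypothetical, assumes binary labels without tied scores, and is a simplified stand-in for illustration, not scikit-learn's roc_curve.

```python
import numpy as np


def roc_points(y_true, score):
    """Return false and true positive rates while lowering the threshold."""
    order = np.argsort(-score, kind="stable")  # highest score first
    y_sorted = y_true[order]
    tps = np.cumsum(y_sorted == 1)  # true positives accepted so far
    fps = np.cumsum(y_sorted == 0)  # false positives accepted so far
    # prepend the point (0, 0) for a threshold above all scores
    tpr = np.concatenate([[0.0], tps / tps[-1]])
    fpr = np.concatenate([[0.0], fps / fps[-1]])
    return fpr, tpr


y_true = np.array([0, 0, 1, 1])
score = np.array([0.1, 0.4, 0.35, 0.8])
fpr_demo, tpr_demo = roc_points(y_true, score)

# area under the curve via the trapezoidal rule
auc_demo = np.sum(np.diff(fpr_demo) * (tpr_demo[1:] + tpr_demo[:-1]) / 2)
```

For this toy example the area evaluates to 0.75: three of the four possible (positive, negative) pairs are ranked correctly.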
