Data Loading and Preprocessing
[1]:
# This cell sets up the environment when executed in Google Colab.
try:
    import google.colab
    !curl -s https://raw.githubusercontent.com/ibs-lab/cedalion/dev/scripts/colab_setup.py -o colab_setup.py
    # Select branch with --branch "branch name" (default is "dev")
    %run colab_setup.py
except ImportError:
    pass
[2]:
import cedalion
import cedalion.sigproc.quality as quality
import cedalion.sigproc.motion_correct as motion_correct
from cedalion.plots import segmented_cmap
from cedalion import units
import cedalion.xrutils as xrutils
import cedalion.datasets
from pathlib import Path
import numpy as np
import xarray as xr
import matplotlib.pyplot as p
Load Data
Example datasets are accessible through functions in cedalion.datasets. These take care of downloading, caching, and updating the data files. Often they also load the data directly.
Here we load a single-subject DOT dataset with a motor task.
[3]:
rec = cedalion.datasets.get_fingertappingDOT()
This recording object holds a single NIRS time series, 'amp':
[4]:
rec.timeseries.keys()
[4]:
odict_keys(['amp'])
It contains several auxiliary time series from additional sensors:
[5]:
rec.aux_ts.keys()
[5]:
odict_keys(['ACCEL_X_1', 'ACCEL_Y_1', 'ACCEL_Z_1', 'GYRO_X_1', 'GYRO_Y_1', 'GYRO_Z_1', 'ExGa1', 'ExGa2', 'ExGa3', 'ExGa4', 'ECG', 'Respiration', 'PPG', 'SpO2', 'Heartrate', 'GSR', 'Temperature'])
Inspecting the Datasets
Raw Amplitude Time Series
[6]:
rec["amp"]
[6]:
<xarray.DataArray (channel: 100, wavelength: 2, time: 8794)> Size: 14MB <Quantity([[[0.08740092 0.08734962 0.08818625 ... 0.09035587 0.09098899 0.09272738] [0.13985697 0.13982265 0.14141524 ... 0.13390421 0.13521901 0.13864038]] [[0.27071937 0.27030255 0.27172273 ... 0.26408907 0.26584981 0.26889485] [0.65636219 0.65455566 0.65930072 ... 0.62508687 0.63123411 0.64150497]] [[0.12522511 0.1251687 0.12585573 ... 0.1244254 0.1247638 0.12649158] [0.20260527 0.2023983 0.20387637 ... 0.19096156 0.19242489 0.19605342]] ... [[0.12742799 0.12771832 0.12856142 ... 0.12475135 0.1240544 0.1230771 ] [0.28363299 0.28471081 0.2877631 ... 0.26401558 0.26345736 0.26146422]] [[0.03967372 0.03983601 0.04008131 ... 0.03961476 0.03932456 0.03901398] [0.1705261 0.17133473 0.17316591 ... 0.16186236 0.16202994 0.16063383]] [[0.08932346 0.08954745 0.0901911 ... 0.08869942 0.08842191 0.08802665] [0.09587179 0.09623634 0.09746928 ... 0.09036516 0.09042745 0.08998772]]], 'volt')> Coordinates: * time (time) float64 70kB 0.0 0.2294 0.4588 ... 2.017e+03 2.017e+03 samples (time) int64 70kB 0 1 2 3 4 5 ... 8788 8789 8790 8791 8792 8793 * channel (channel) object 800B 'S1D1' 'S1D2' 'S1D4' ... 'S14D31' 'S14D32' source (channel) object 800B 'S1' 'S1' 'S1' 'S1' ... 'S14' 'S14' 'S14' detector (channel) object 800B 'D1' 'D2' 'D4' 'D5' ... 'D29' 'D31' 'D32' * wavelength (wavelength) float64 16B 760.0 850.0 Attributes: data_type_group: unprocessed raw
Stimulus event information
[7]:
rec.stim
[7]:
 | onset | duration | value | trial_type |
---|---|---|---|---|
0 | 23.855104 | 10.0 | 1.0 | 1 |
1 | 54.132736 | 10.0 | 1.0 | 1 |
2 | 84.410368 | 10.0 | 1.0 | 1 |
3 | 114.688000 | 10.0 | 1.0 | 1 |
4 | 146.112512 | 10.0 | 1.0 | 1 |
... | ... | ... | ... | ... |
125 | 1431.535616 | 10.0 | 1.0 | 5 |
126 | 1526.038528 | 10.0 | 1.0 | 5 |
127 | 1650.819072 | 10.0 | 1.0 | 5 |
128 | 1805.418496 | 10.0 | 1.0 | 5 |
129 | 1931.116544 | 10.0 | 1.0 | 5 |
130 rows × 4 columns
Montage
[8]:
rec.geo3d
[8]:
<xarray.DataArray (label: 346, digitized: 3)> Size: 8kB <Quantity([[-77.817871 15.680614 23.17227 ] [-61.906841 21.227732 56.492802] [-85.37146 -16.079958 8.900885] ... [ 77.521 28.883 -39.113 ] [ 80.59 14.229 -38.278 ] [ 81.95 -0.678 -37.027 ]], 'millimeter')> Coordinates: type (label) object 3kB PointType.SOURCE ... PointType.LANDMARK * label (label) <U6 8kB 'S1' 'S2' 'S3' 'S4' ... 'FFT10h' 'FT10h' 'FTT10h' Dimensions without coordinates: digitized
[9]:
cedalion.plots.plot_montage3D(rec["amp"], rec.geo3d)

Channel Distances
[10]:
distances = cedalion.nirs.channel_distances(rec["amp"], rec.geo3d)
p.figure(figsize=(8,4))
p.hist(distances, 40)
p.xlabel("channel distance / mm")
p.ylabel("channel count");
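A channel distance is the Euclidean distance between the source and the detector that form the channel. The following is a minimal NumPy sketch of that calculation; the optode coordinates and the names `optodes` and `channels` are made up for illustration and are not taken from this recording or from cedalion's implementation.

```python
import numpy as np

# Hypothetical optode positions in millimeters (not from the dataset).
optodes = {
    "S1": np.array([0.0, 0.0, 0.0]),
    "D1": np.array([30.0, 0.0, 0.0]),
    "D2": np.array([0.0, 6.0, 8.0]),
}

# Each channel pairs one source with one detector.
channels = [("S1", "D1"), ("S1", "D2")]

# Euclidean norm of the source-detector difference vector.
distances = {
    src + det: float(np.linalg.norm(optodes[src] - optodes[det]))
    for src, det in channels
}
print(distances)  # {'S1D1': 30.0, 'S1D2': 10.0}
```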

Plot raw amplitude for one channel
[11]:
# example time trace
amp = rec["amp"]
ch = "S12D25"
f, ax = p.subplots(1,1, figsize=(12,4))
ax.set_prop_cycle("color", cedalion.plots.COLORBREWER_Q8)
ax.plot(amp.time, amp.sel(channel=ch, wavelength=760), label="amp. 760 nm")
ax.plot(amp.time, amp.sel(channel=ch, wavelength=850), label="amp. 850 nm")
cedalion.plots.plot_stim_markers(ax, rec.stim, y=1)
ax.set_xlabel("time / s")
ax.set_ylabel("amplitude / V")
ax.set_xlim(0,150)
ax.legend()
ax.set_title(ch);

Quality Metrics: SCI & PSP
Using functions from cedalion.sigproc.quality we calculate two metrics:
scalp coupling index (SCI)
peak spectral power (PSP)
Note the different time axis: both metrics are calculated in sliding windows.
Both functions return the metric and a boolean array (mask) that is True where the metric is above the threshold.
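To illustrate what a windowed metric with a threshold mask looks like, here is a rough NumPy sketch of an SCI-like quantity. It assumes the common definition of the SCI as the zero-lag correlation between the two wavelengths after band-pass filtering to the cardiac band. The band limits (0.5 to 2.5 Hz), the crude FFT filter, the sampling rate, and the helper names `bandpass_fft` and `windowed_sci` are assumptions for this example, not cedalion's implementation.

```python
import numpy as np

def bandpass_fft(x, fs, fmin, fmax):
    """Crude band-pass: zero all Fourier components outside [fmin, fmax]."""
    spec = np.fft.rfft(x)
    freqs = np.fft.rfftfreq(len(x), d=1.0 / fs)
    spec[(freqs < fmin) | (freqs > fmax)] = 0.0
    return np.fft.irfft(spec, n=len(x))

def windowed_sci(wl1, wl2, fs, window_s, fmin=0.5, fmax=2.5):
    """Zero-lag correlation of two band-passed signals in consecutive windows."""
    a = bandpass_fft(wl1, fs, fmin, fmax)
    b = bandpass_fft(wl2, fs, fmin, fmax)
    n = int(round(window_s * fs))  # samples per window
    values = []
    for start in range(0, len(a) - n + 1, n):
        values.append(np.corrcoef(a[start:start + n], b[start:start + n])[0, 1])
    return np.array(values)

rng = np.random.default_rng(0)
fs = 10.0                      # Hz, assumed sampling rate of the toy signals
t = np.arange(600) / fs        # 60 s of data
cardiac = np.sin(2 * np.pi * 1.2 * t)

# Well-coupled channel: both wavelengths share the cardiac pulsation.
good = windowed_sci(cardiac + 0.05 * rng.standard_normal(t.size),
                    0.8 * cardiac + 0.05 * rng.standard_normal(t.size),
                    fs, window_s=10)
# Poorly coupled channel: two independent noise signals.
bad = windowed_sci(rng.standard_normal(t.size),
                   rng.standard_normal(t.size),
                   fs, window_s=10)

sci_threshold = 0.75
good_mask = good > sci_threshold   # True where the window passes
bad_mask = bad > sci_threshold
```

Each 60 s toy signal yields six 10 s windows, so the metric lives on a coarser time axis than the raw samples.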
[12]:
sci_threshold = 0.75
window_length = 10*units.s
sci, sci_mask = quality.sci(rec["amp"], window_length, sci_threshold)
psp_threshold = 0.03
psp, psp_mask = quality.psp(rec["amp"], window_length, psp_threshold)
display(sci.rename("sci"))
display(sci_mask.rename("sci_mask"))
<xarray.DataArray 'sci' (channel: 100, time: 200)> Size: 160kB array([[0.9949109 , 0.9949109 , 0.97750268, ..., 0.99755867, 0.99673787, 0.99561847], [0.99159096, 0.99159096, 0.93520281, ..., 0.99238156, 0.99570154, 0.98828878], [0.99206498, 0.99206498, 0.9939492 , ..., 0.99527559, 0.99615776, 0.99315244], ..., [0.94375225, 0.94375225, 0.95649474, ..., 0.95639974, 0.97469584, 0.92564555], [0.86450677, 0.86450677, 0.91515934, ..., 0.94323422, 0.94480114, 0.87333273], [0.96496945, 0.96496945, 0.97428204, ..., 0.96280239, 0.97176867, 0.95451752]], shape=(100, 200)) Coordinates: * time (time) float64 2kB 0.0 10.09 20.19 ... 1.998e+03 2.008e+03 samples (time) int64 2kB 0 44 88 132 176 220 ... 8580 8624 8668 8712 8756 * channel (channel) object 800B 'S1D1' 'S1D2' 'S1D4' ... 'S14D31' 'S14D32' source (channel) object 800B 'S1' 'S1' 'S1' 'S1' ... 'S14' 'S14' 'S14' detector (channel) object 800B 'D1' 'D2' 'D4' 'D5' ... 'D29' 'D31' 'D32'
<xarray.DataArray 'sci_mask' (channel: 100, time: 200)> Size: 20kB array([[ True, True, True, ..., True, True, True], [ True, True, True, ..., True, True, True], [ True, True, True, ..., True, True, True], ..., [ True, True, True, ..., True, True, True], [ True, True, True, ..., True, True, True], [ True, True, True, ..., True, True, True]], shape=(100, 200)) Coordinates: * time (time) float64 2kB 0.0 10.09 20.19 ... 1.998e+03 2.008e+03 samples (time) int64 2kB 0 44 88 132 176 220 ... 8580 8624 8668 8712 8756 * channel (channel) object 800B 'S1D1' 'S1D2' 'S1D4' ... 'S14D31' 'S14D32' source (channel) object 800B 'S1' 'S1' 'S1' 'S1' ... 'S14' 'S14' 'S14' detector (channel) object 800B 'D1' 'D2' 'D4' 'D5' ... 'D29' 'D31' 'D32'
[13]:
# define three colormaps: reddish below a threshold, bluish above
sci_norm, sci_cmap = segmented_cmap(
    "sci_cmap",
    0,
    1.0,
    [(0.0, "#000000"), (sci_threshold, "#DC3220"), (sci_threshold, "#5D3A9B"), (1.0, "#0C7BDC")],
    bad="magenta", over="magenta", under="magenta"
)
psp_norm, psp_cmap = segmented_cmap(
    "psp_cmap",
    0,
    1.0,
    [(0.0, "#000000"), (psp_threshold, "#DC3220"), (psp_threshold, "#5D3A9B"), (1.0, "#0C7BDC")],
    bad="magenta", over="magenta", under="magenta"
)
mask_norm, mask_cmap = segmented_cmap(
    "mask_cmap",
    0,
    1.0,
    [(0.0, "#DC3220"), (0.5, "#DC3220"), (0.5, "#0C7BDC"), (1.0, "#0C7BDC")],
)
def plot_sci(sci):
    # plot the heatmap
    f, ax = p.subplots(1, 1, figsize=(17, 10))
    m = ax.pcolormesh(sci.time, np.arange(len(sci.channel)), sci, shading="nearest", cmap=sci_cmap, norm=sci_norm)
    cb = p.colorbar(m, ax=ax)
    cb.set_label("SCI")
    ax.set_xlabel("time / s")
    p.tight_layout()
    ax.yaxis.set_ticks(np.arange(len(sci.channel)))
    ax.yaxis.set_ticklabels(sci.channel.values, fontsize=7)

def plot_psp(psp):
    # plot the heatmap
    f, ax = p.subplots(1, 1, figsize=(17, 10))
    m = ax.pcolormesh(psp.time, np.arange(len(psp.channel)), psp, shading="nearest", cmap=psp_cmap, norm=psp_norm)
    cb = p.colorbar(m, ax=ax)
    cb.set_label("PSP")
    ax.set_xlabel("time / s")
    p.tight_layout()
    ax.yaxis.set_ticks(np.arange(len(psp.channel)))
    ax.yaxis.set_ticklabels(psp.channel.values, fontsize=7)

def plot_quality_mask(mask, cb_label : str, bool_labels = ["TAINTED", "CLEAN"]):
    # plot the binary heatmap
    f, ax = p.subplots(1, 1, figsize=(17, 10))
    m = ax.pcolormesh(mask.time, np.arange(len(mask.channel)), mask, shading="nearest", cmap=mask_cmap, norm=mask_norm)
    cb = p.colorbar(m, ax=ax)
    p.tight_layout()
    ax.yaxis.set_ticks(np.arange(len(mask.channel)))
    ax.yaxis.set_ticklabels(mask.channel.values, fontsize=7)
    cb.set_label(cb_label)
    cb.set_ticks([.25, .75])
    cb.set_ticklabels(bool_labels)
    ax.set_xlabel("time / s")
[14]:
plot_sci(sci)
plot_quality_mask(sci > sci_threshold, f"SCI > {sci_threshold}")
plot_psp(psp)
plot_quality_mask(psp > psp_threshold, f"PSP > {psp_threshold}")




Combining Signal Quality Masks
We want both SCI and PSP to be above their respective thresholds for a window to be considered clean. We can combine the two masks with the Boolean AND operation (&) and then look at the fraction of time during which both metrics are above their thresholds.
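As a small illustration of this combination, the following NumPy sketch uses two toy masks (2 channels by 4 windows, values invented for the example) and computes the fraction of clean windows per channel.

```python
import numpy as np

# Toy masks for 2 channels x 4 windows; True means the metric passed.
sci_ok = np.array([[True, True, False, True],
                   [True, True, True, True]])
psp_ok = np.array([[True, False, False, True],
                   [True, True, True, False]])

# A window counts as clean only if both metrics passed.
combined = sci_ok & psp_ok

# Fraction of clean windows per channel (mean over the time axis).
frac_clean = combined.mean(axis=1)
print(frac_clean.tolist())  # [0.5, 0.75]
```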
[15]:
combined_mask = sci_mask & psp_mask
display(combined_mask)
plot_quality_mask(combined_mask, "combined_mask")
<xarray.DataArray (channel: 100, time: 200)> Size: 20kB array([[ True, True, False, ..., True, True, True], [ True, True, True, ..., True, True, True], [ True, True, True, ..., True, True, True], ..., [ True, True, True, ..., True, True, True], [ True, True, True, ..., True, True, True], [ True, True, True, ..., True, True, True]], shape=(100, 200)) Coordinates: * time (time) float64 2kB 0.0 10.09 20.19 ... 1.998e+03 2.008e+03 samples (time) int64 2kB 0 44 88 132 176 220 ... 8580 8624 8668 8712 8756 * channel (channel) object 800B 'S1D1' 'S1D2' 'S1D4' ... 'S14D31' 'S14D32' source (channel) object 800B 'S1' 'S1' 'S1' 'S1' ... 'S14' 'S14' 'S14' detector (channel) object 800B 'D1' 'D2' 'D4' 'D5' ... 'D29' 'D31' 'D32'

Calculate percentage of clean time per channel
[16]:
perc_time_clean = combined_mask.sum(dim="time") / len(sci.time)
display(perc_time_clean)
f, ax = p.subplots(1,1,figsize=(6.5,6.5))
cedalion.plots.scalp_plot(
    rec["amp"],
    rec.geo3d,
    perc_time_clean,
    ax,
    cmap="RdYlGn",
    vmin=0.80,
    vmax=1,
    title=None,
    cb_label="Percentage of clean time",
    channel_lw=2,
    optode_labels=True
)
f.tight_layout()
<xarray.DataArray (channel: 100)> Size: 800B array([0.955, 0.99 , 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.95 , 0.985, 0.99 , 0.995, 0.995, 0.985, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.99 , 0.995, 0.97 , 0.995, 0.995, 1. , 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.99 , 0.995, 0.995, 0.995, 0.995, 0.99 , 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.99 , 1. , 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.975, 0.995, 0.995, 0.995, 0.995, 0.995, 1. , 0.995, 0.995, 0.995, 0.995, 0.995, 0.97 , 0.815, 0.995, 1. , 0.985, 0.995, 0.995, 0.995, 0.91 , 1. , 0.98 , 0.995]) Coordinates: * channel (channel) object 800B 'S1D1' 'S1D2' 'S1D4' ... 'S14D31' 'S14D32' source (channel) object 800B 'S1' 'S1' 'S1' 'S1' ... 'S14' 'S14' 'S14' detector (channel) object 800B 'D1' 'D2' 'D4' 'D5' ... 'D29' 'D31' 'D32'

Correct Motion Artefacts
Use cedalion.nirs.int2od to get optical densities.
Apply Temporal Derivative Distribution Repair (TDDR) first to correct jumps.
Then apply wavelet motion artifact correction.
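The conversion between intensity and optical density can be sketched in a few lines of NumPy. This example assumes the common definition OD = -ln(I / mean(I)) with the temporal mean as baseline; that cedalion's functions use exactly this definition is an assumption here, and the toy data are random numbers.

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy intensity data: 3 channels x 500 samples, strictly positive.
amp = 0.2 + 0.01 * rng.standard_normal((3, 500))

# Baseline: temporal mean of each channel.
baseline = amp.mean(axis=1, keepdims=True)

# Intensity -> optical density (assumed definition, see text above).
od = -np.log(amp / baseline)

# Optical density -> intensity, using the same baseline.
amp_back = baseline * np.exp(-od)
```

Because the two steps are exact inverses, `amp_back` reproduces `amp` up to floating-point error. This is why the mean of the original amplitude is needed when corrected optical densities are converted back to intensities.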
[17]:
rec["od"] = cedalion.nirs.int2od(rec["amp"])
rec["od_tddr"] = motion_correct.tddr(rec["od"])
rec["od_wavelet"] = motion_correct.wavelet(rec["od_tddr"])
rec["amp_corrected"] = cedalion.nirs.od2int(rec["od_wavelet"], rec["amp"].mean("time"))
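To give an intuition for how TDDR removes jumps, here is a simplified NumPy sketch of its core idea: robustly down-weight outlying values of the temporal derivative with Tukey's biweight, then integrate the reweighted derivative again. It follows the published description of the method but omits the split into low- and high-frequency parts, and it is not cedalion's `motion_correct.tddr`. The function name `tddr_simplified` and the toy signal are hypothetical.

```python
import numpy as np

def tddr_simplified(signal, tune=4.685, max_iter=50, tol=1e-6):
    """Down-weight outlying temporal derivatives, then re-integrate."""
    deriv = np.diff(signal)
    w = np.ones_like(deriv)
    mu = np.inf
    for _ in range(max_iter):
        mu_prev = mu
        mu = np.sum(w * deriv) / np.sum(w)          # weighted mean derivative
        dev = np.abs(deriv - mu)
        sigma = 1.4826 * np.median(dev)             # robust scale estimate
        r = dev / (sigma * tune)
        w = np.where(r < 1, (1 - r**2) ** 2, 0.0)   # Tukey biweight weights
        if abs(mu - mu_prev) < tol * max(abs(mu), abs(mu_prev)):
            break
    # Integrate the reweighted, centered derivative.
    corrected = np.cumsum(np.concatenate([[0.0], w * (deriv - mu)]))
    # Restore the mean of the input signal.
    return corrected - corrected.mean() + signal.mean()

rng = np.random.default_rng(2)
t = np.arange(400) / 10.0
clean = 0.1 * np.sin(2 * np.pi * 0.05 * t) + 0.005 * rng.standard_normal(t.size)

jumped = clean.copy()
jumped[200:] += 1.0            # simulated baseline shift at sample 200

fixed = tddr_simplified(jumped)
```

The single large derivative at the jump receives a weight of zero, so the step disappears after re-integration while the slow oscillation is largely preserved.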
[18]:
# recalculate sci & psp on cleaned data
sci_corr, sci_corr_mask = quality.sci(rec["amp_corrected"], window_length, sci_threshold)
psp_corr, psp_corr_mask = quality.psp(rec["amp_corrected"], window_length, psp_threshold)
combined_corr_mask = sci_corr_mask & psp_corr_mask
[19]:
plot_quality_mask(combined_mask, "combined mask")
plot_quality_mask(combined_corr_mask, "combined corrected mask")


Compare masks before and after motion artifact correction
[20]:
changed_windows = (combined_mask == quality.TAINTED) & (combined_corr_mask == quality.CLEAN)
plot_quality_mask(changed_windows, "mask of time windows cleaned by motion correction", bool_labels=["unchanged", "improved"])
changed_windows = (combined_mask == quality.CLEAN) & (combined_corr_mask == quality.TAINTED)
plot_quality_mask(changed_windows, "mask of time windows corrupted by motion correction", bool_labels=["unchanged", "worsened"])
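The comparison logic can be shown on a toy example. The sketch below assumes that True encodes a clean window and False a tainted one, matching the tick labels used in `plot_quality_mask`; the mask values are invented.

```python
import numpy as np

# True = clean, False = tainted (toy values: one channel, five windows).
before = np.array([True, False, False, True, True])
after = np.array([True, True, False, False, True])

improved = ~before & after   # tainted before, clean after correction
worsened = before & ~after   # clean before, tainted after correction

print(improved.tolist())  # [False, True, False, False, False]
print(worsened.tolist())  # [False, False, False, True, False]
```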


Recalculate percentage of clean time
[21]:
perc_time_clean_corr = combined_corr_mask.sum(dim="time") / len(sci.time)
f, ax = p.subplots(1,1,figsize=(6.5,6.5))
cedalion.plots.scalp_plot(
    rec["amp"],
    rec.geo3d,
    perc_time_clean_corr,
    ax,
    cmap="RdYlGn",
    vmin=0.80,
    vmax=1,
    title=None,
    cb_label="Percentage of clean time",
    channel_lw=2,
    optode_labels=True
)
f.tight_layout()

Global Variance of the Temporal Derivative (GVTD) for identifying global bad time segments
[22]:
gvtd, gvtd_mask = quality.gvtd(rec["amp"])
gvtd_corr, gvtd_corr_mask = quality.gvtd(rec["amp_corrected"])
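The idea behind GVTD can be sketched as the root mean square, taken over all channels, of the sample-to-sample difference of the signal. The NumPy example below is a simplified illustration on random toy data. The function name `gvtd_sketch` is hypothetical, and details of cedalion's implementation (for example the input quantity or how the mask is thresholded) are not reproduced.

```python
import numpy as np

def gvtd_sketch(data):
    """RMS over channels of the backward difference along time.

    data has shape (channels, time). The first output sample is set to 0
    because no backward difference exists there.
    """
    diff = np.diff(data, axis=1)
    rms = np.sqrt(np.mean(diff**2, axis=0))
    return np.concatenate([[0.0], rms])

rng = np.random.default_rng(3)
data = 0.01 * rng.standard_normal((20, 300))   # 20 channels, 300 samples
data[:, 150:] += 0.5                           # global shift at sample 150

g = gvtd_sketch(data)
peak_index = int(np.argmax(g))
print(peak_index)  # 150
```

A disturbance that hits many channels at once produces a sharp peak in this trace, which is what makes the metric useful for flagging globally bad time segments.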
[23]:
# select the 10 segments with the highest GVTD
top10_bad_segments = sorted(
    [seg for seg in quality.mask_to_segments(combined_mask.all("channel"))],
    key=lambda t: gvtd.sel(time=slice(t[0], t[1])).max(),
    reverse=True,
)[:10]
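The pattern used here (turn a boolean mask into segments, then rank the segments by the maximum of a metric) can be sketched with plain NumPy. The helper `runs_of` is hypothetical and returns index pairs rather than times. Which mask value `quality.mask_to_segments` extracts and how it defines the segment end are not shown in this notebook, so the sketch makes both explicit.

```python
import numpy as np

def runs_of(mask, value):
    """Return (start, end) index pairs of consecutive entries equal to value.

    end is exclusive, as in Python slicing.
    """
    hits = np.asarray(mask) == value
    # Pad with False so that every run has a rising and a falling edge.
    padded = np.concatenate([[False], hits, [False]])
    edges = np.flatnonzero(padded[1:] != padded[:-1])
    return list(zip(edges[0::2].tolist(), edges[1::2].tolist()))

# Toy mask (True = clean) and a toy metric on the same time axis.
mask = np.array([True, False, False, True, True, False, True, False, False, False])
metric = np.array([0.1, 0.9, 0.4, 0.1, 0.1, 0.3, 0.1, 0.2, 0.7, 0.5])

# Segments in which the mask is False (tainted).
segments = runs_of(mask, False)

# Rank segments by the largest metric value they contain and keep the top 2.
ranked = sorted(segments, key=lambda seg: metric[seg[0]:seg[1]].max(), reverse=True)
top2 = ranked[:2]

print(segments)  # [(1, 3), (5, 6), (7, 10)]
print(top2)      # [(1, 3), (7, 10)]
```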
Plot GVTD and the combined masks for the original and corrected time series
[24]:
f,ax = p.subplots(4,1,figsize=(16,6), sharex=True)
ax[0].plot(gvtd.time, gvtd)
ax[1].plot(combined_mask.time, combined_mask.all("channel"))
ax[2].plot(gvtd_corr.time, gvtd_corr)
ax[3].plot(combined_corr_mask.time, combined_corr_mask.all("channel"))
ax[0].set_ylim(0, 0.02)
ax[2].set_ylim(0, 0.02)
ax[0].set_ylabel("GVTD")
ax[2].set_ylabel("GVTD")
ax[1].set_ylabel("combined_mask")
ax[3].set_ylabel("combined_corr_mask")
ax[3].set_xlabel("time / s")
for i in range(4):
    cedalion.plots.plot_segments(ax[i], top10_bad_segments)

Highlight motion correction in selected segments
[25]:
example_channels = ["S4D10", "S13D26"]
f, ax = p.subplots(5,4, figsize=(16,16), sharex=False)
ax = ax.T.flatten()
padding = 15
i = 0
for ch in example_channels:
    for (start, end) in top10_bad_segments:
        ax[i].set_prop_cycle(color=["#e41a1c", "#ff7f00", "#377eb8", "#984ea3"])
        for wl in rec["od"].wavelength.values:
            sel = rec["od"].sel(time=slice(start-padding, end+padding), channel=ch, wavelength=wl)
            ax[i].plot(sel.time, sel, label=f"{wl:.0f} nm orig")
            sel = rec["od_wavelet"].sel(time=slice(start-padding, end+padding), channel=ch, wavelength=wl)
            ax[i].plot(sel.time, sel, label=f"{wl:.0f} nm corr")
        ax[i].set_title(ch)
        ax[i].legend(ncol=2, loc="upper center")
        ylim = ax[i].get_ylim()
        ax[i].set_ylim(ylim[0], ylim[1]+0.25*(ylim[1]-ylim[0])) # make space for legend
        i += 1
p.tight_layout()

Final channel selection
[26]:
perc_time_clean_corr[perc_time_clean_corr < 0.95]
[26]:
<xarray.DataArray (channel: 2)> Size: 16B array([0.88, 0.92]) Coordinates: * channel (channel) object 16B 'S13D26' 'S14D28' source (channel) object 16B 'S13' 'S14' detector (channel) object 16B 'D26' 'D28'
[27]:
signal_quality_selection_masks = [perc_time_clean >= .95]
rec["amp_pruned"], pruned_channels = quality.prune_ch(
    rec["amp"], signal_quality_selection_masks, "all"
)
display(rec["amp_pruned"])
display(pruned_channels)
<xarray.DataArray (channel: 98, wavelength: 2, time: 8794)> Size: 14MB <Quantity([[[0.08740092 0.08734962 0.08818625 ... 0.09035587 0.09098899 0.09272738] [0.13985697 0.13982265 0.14141524 ... 0.13390421 0.13521901 0.13864038]] [[0.27071937 0.27030255 0.27172273 ... 0.26408907 0.26584981 0.26889485] [0.65636219 0.65455566 0.65930072 ... 0.62508687 0.63123411 0.64150497]] [[0.12522511 0.1251687 0.12585573 ... 0.1244254 0.1247638 0.12649158] [0.20260527 0.2023983 0.20387637 ... 0.19096156 0.19242489 0.19605342]] ... [[0.12742799 0.12771832 0.12856142 ... 0.12475135 0.1240544 0.1230771 ] [0.28363299 0.28471081 0.2877631 ... 0.26401558 0.26345736 0.26146422]] [[0.03967372 0.03983601 0.04008131 ... 0.03961476 0.03932456 0.03901398] [0.1705261 0.17133473 0.17316591 ... 0.16186236 0.16202994 0.16063383]] [[0.08932346 0.08954745 0.0901911 ... 0.08869942 0.08842191 0.08802665] [0.09587179 0.09623634 0.09746928 ... 0.09036516 0.09042745 0.08998772]]], 'volt')> Coordinates: * time (time) float64 70kB 0.0 0.2294 0.4588 ... 2.017e+03 2.017e+03 samples (time) int64 70kB 0 1 2 3 4 5 ... 8788 8789 8790 8791 8792 8793 * channel (channel) object 784B 'S1D1' 'S1D2' 'S1D4' ... 'S14D31' 'S14D32' source (channel) object 784B 'S1' 'S1' 'S1' 'S1' ... 'S14' 'S14' 'S14' detector (channel) object 784B 'D1' 'D2' 'D4' 'D5' ... 'D29' 'D31' 'D32' * wavelength (wavelength) float64 16B 760.0 850.0 Attributes: data_type_group: unprocessed raw
array(['S13D26', 'S14D28'], dtype=object)
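The selection step boils down to thresholding a per-channel value and using the resulting boolean mask to keep or drop channels. The following NumPy sketch shows that logic on four channels. The clean-time fractions are illustrative toy values, and the code is not a reimplementation of `quality.prune_ch`.

```python
import numpy as np

# Toy per-channel fractions of clean time (illustrative values only).
channels = np.array(["S1D1", "S1D2", "S13D26", "S14D28"])
frac_clean = np.array([0.995, 0.99, 0.88, 0.92])
threshold = 0.95

# Boolean mask: True for channels that meet the quality criterion.
keep = frac_clean >= threshold

kept_channels = channels[keep]
pruned = channels[~keep]

print(kept_channels.tolist())  # ['S1D1', 'S1D2']
print(pruned.tolist())         # ['S13D26', 'S14D28']
```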