Examining and thresholding sensitivity of a probe to the cortex using the Schaefer parcellation scheme

This notebook shows how to examine the theoretical sensitivity of a probe on a head model to brain areas (here we use parcel coordinates from the Schaefer 2018 atlas) and how to identify parcels that should be dropped because changes in them cannot be observed. For this, the originally designed probe can also be reduced to an effective probe by dropping channels that were pruned due to bad signal quality.

[1]:
# This cell sets up the environment when executed in Google Colab.
try:
    import google.colab
    !curl -s https://raw.githubusercontent.com/ibs-lab/cedalion/dev/scripts/colab_setup.py -o colab_setup.py
    # Select branch with --branch "branch name" (default is "dev")
    %run colab_setup.py
except ImportError:
    pass
[2]:
# set this flag to True to enable interactive 3D plots
INTERACTIVE_PLOTS = False
[3]:
import pyvista as pv

import cedalion.sigproc

if INTERACTIVE_PLOTS:
    pv.set_jupyter_backend('html')
else:
    pv.set_jupyter_backend('static')

import os

import matplotlib.pyplot as p
import numpy as np
import xarray as xr

import cedalion
import cedalion.dataclasses as cdc
import cedalion.datasets
import cedalion.imagereco.forward_model as fw
import cedalion.plots  # provides plot_labeled_points, used further below
import cedalion.sigproc.quality as quality
import cedalion.xrutils as xrutils
from cedalion import units
from cedalion.vis import plot_sensitivity_matrix
from cedalion.io import load_Adot

xr.set_options(display_expand_data=False)

#%matplotlib widget

# for dev purposes
%load_ext autoreload
%autoreload 2

Load a DOT finger-tapping dataset

Then perform some very basic quality checks to identify bad channels.

[4]:
# load example dataset
rec = cedalion.datasets.get_fingertappingDOT()

# check signal quality using a simple SNR threshold
snr_thresh = 30 # minimum SNR (mean/std) of a channel. Set very high here for demonstration purposes

# SNR thresholding using the "snr" function of the quality subpackage
snr, snr_mask = quality.snr(rec["amp"], snr_thresh)

# identify channels with bad signal quality (here we only need the list of dropped channels):
# prune channels using the list of masks and the operator "all", which keeps only channels that pass all metrics (here only SNR)
_, snr_ch_droplist = quality.prune_ch(rec["amp"], [snr_mask], "all")

# print list of dropped channels
print(f"{len(snr_ch_droplist)} channels pruned. List of pruned channels due to bad SNR: {snr_ch_droplist}")
37 channels pruned. List of pruned channels due to bad SNR: ['S1D1' 'S1D4' 'S1D5' 'S1D6' 'S1D8' 'S2D5' 'S2D6' 'S2D9' 'S3D1' 'S4D2'
 'S4D7' 'S4D13' 'S5D3' 'S5D6' 'S8D18' 'S8D20' 'S8D21' 'S8D22' 'S8D24'
 'S9D18' 'S9D21' 'S9D22' 'S9D25' 'S10D17' 'S10D21' 'S11D18' 'S11D20'
 'S11D23' 'S11D24' 'S11D29' 'S12D19' 'S12D22' 'S12D25' 'S12D28' 'S12D29'
 'S12D32' 'S14D27']
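For intuition, the SNR check above can be sketched in plain NumPy. This is not cedalion's implementation: it assumes that the SNR is computed as mean/std over time, that a value must exceed the threshold to pass, and that a channel is kept only if all of its wavelengths pass. The function name `snr_prune_sketch` and the toy data are hypothetical.

```python
import numpy as np


def snr_prune_sketch(amp, channel_names, snr_thresh):
    """Hypothetical sketch of SNR-based channel pruning (not cedalion's code).

    amp: raw intensities with shape (channel, wavelength, time)
    Returns the SNR (channel, wavelength), a keep mask (channel,)
    and the names of the dropped channels.
    """
    snr = amp.mean(axis=-1) / amp.std(axis=-1)  # SNR = mean / std over time
    good = snr > snr_thresh                     # pass/fail per channel and wavelength
    keep = good.all(axis=1)                     # assumption: every wavelength must pass
    dropped = [name for name, k in zip(channel_names, keep) if not k]
    return snr, keep, dropped


# toy data: one stable channel and one channel that is noisy at the second wavelength
good_trace = [0.99, 1.01, 0.99, 1.01]  # mean 1, std 0.01 -> SNR about 100
bad_trace = [0.5, 1.5, 0.5, 1.5]       # mean 1, std 0.5  -> SNR 2
amp = np.array([[good_trace, good_trace],
                [good_trace, bad_trace]])

snr, keep, dropped = snr_prune_sketch(amp, ["S1D1", "S1D2"], snr_thresh=30)
print(dropped)
```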

Load a head model and precalculated fluence profile

[5]:
# load paths to segmentation data and the parcel file for the ICBM-152 atlas
SEG_DATADIR, mask_files, landmarks_file = cedalion.datasets.get_icbm152_segmentation()
PARCEL_FILE = cedalion.datasets.get_icbm152_parcel_file()

# create a two-surface head model (brain and scalp) for the ICBM-152 atlas
head = fw.TwoSurfaceHeadModel.from_surfaces(
    segmentation_dir=SEG_DATADIR,
    mask_files=mask_files,
    brain_surface_file=os.path.join(SEG_DATADIR, "mask_brain.obj"),
    scalp_surface_file=os.path.join(SEG_DATADIR, "mask_scalp.obj"),
    landmarks_ras_file=landmarks_file,
    parcel_file=PARCEL_FILE,
    brain_face_count=None,
    scalp_face_count=None,
)

# snap probe to head and create forward model
geo3D_snapped = head.align_and_snap_to_scalp(rec.geo3d)
fwm = fw.ForwardModel(head, geo3D_snapped, rec._measurement_lists["amp"])

Load the precomputed sensitivity on the cortex and plot it on the head model

[6]:
# load precomputed sensitivity for this dataset and headmodel
Adot = cedalion.datasets.get_precomputed_sensitivity("fingertappingDOT", "icbm152")

# plot on head model
plotter = plot_sensitivity_matrix.Main(
    sensitivity=Adot,
    brain_surface=head.brain,
    head_surface=head.scalp,
    labeled_points=geo3D_snapped,
)
plotter.plot(high_th=0, low_th=-3)
plotter.plt.show()
Downloading file 'sensitivity_fingertappingDOT_icbm152.nc' from 'https://doc.ibs.tu-berlin.de/cedalion/datasets/v25.1.0/sensitivity_fingertappingDOT_icbm152.nc' to '/home/runner/.cache/cedalion/v25.1.0'.
[Figure: probe sensitivity plotted on the brain surface of the head model]

Investigation of parcels and effective parcel sensitivity

First, plot the full parcellation scheme on the head.

[7]:
# read the parcellation table (with "Label" and "Color" columns) for the brain surface
parcels = cedalion.io.read_parcellations(PARCEL_FILE)
[8]:
b = cdc.VTKSurface.from_trimeshsurface(head.brain)
b = pv.wrap(b.mesh)
b["parcels"] = parcels.Color.tolist()

plt = pv.Plotter()

plt.add_mesh(
    b,
    scalars="parcels",
    rgb=True
)


cog = head.brain.vertices.pint.dequantify().mean("label").values
plt.camera.position = cog + [400,0,400]
plt.camera.focal_point = cog
plt.camera.left = [0, 1, 0]
plt.reset_camera()

plt.show()
[Figure: Schaefer parcellation on the brain surface]

Calculate parcel sensitivity mask

Parcels are considered good if a change in HbO and HbR [µM] in the parcel leads to an observable change of at least dOD_thresh in at least one wavelength of at least minCh channels. To compute this, the sensitivities of all vertices belonging to a parcel are summed up from the sensitivity matrix Adot. Channels that are pruned in an actual measurement (e.g. due to bad signal quality) can be excluded by passing a list of dropped channel names; these channels are then not considered for parcel sensitivity. This requires a head model with parcellation coordinates.
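Written out, the criterion above can be sketched as follows. The notation is ours and assumes that dOD follows from the parcel-summed sensitivity via the extinction coefficients $\varepsilon_{\mathrm{HbO}}(\lambda)$ and $\varepsilon_{\mathrm{HbR}}(\lambda)$ (modified Beer-Lambert law), that absolute values are compared with the threshold, with $V_p$ the vertices of parcel $p$ and $C$ the set of channels that are not dropped:

$$
\Delta\mathrm{OD}_{c,\lambda,p} = \sum_{v \in V_p} A_{c,v,\lambda}\,\bigl(\varepsilon_{\mathrm{HbO}}(\lambda)\,\Delta\mathrm{HbO} + \varepsilon_{\mathrm{HbR}}(\lambda)\,\Delta\mathrm{HbR}\bigr)
$$

$$
\mathrm{mask}_p = \Bigl[\,\bigl|\{\, c \in C : \max_{\lambda} |\Delta\mathrm{OD}_{c,\lambda,p}| \ge \mathrm{dOD\_thresh} \,\}\bigr| \ge \mathrm{minCh}\,\Bigr]
$$

A parcel is thus kept when at least minCh retained channels see a change above dOD_thresh at any wavelength.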

The following input arguments are passed to the forward model's parcel_sensitivity():

  • Adot (channel, vertex, wavelength): sensitivity matrix with a parcel coordinate assigning each vertex to a parcel

  • chan_droplist: list of channel names to be excluded from the sensitivity calculation (e.g. channels pruned due to bad signal quality), or None to consider all channels

  • dOD_thresh: threshold for minimum dOD change in a channel that should be observed from a hemodynamic change in a parcel

  • minCh: minimum number of channels per parcel that should see a change above dOD_thresh

  • dHbO: change in HbO concentration in the parcel in [µM] used to calculate dOD

  • dHbR: change in HbR concentration in the parcel in [µM] used to calculate dOD

Output is a tuple (parcel_dOD, parcel_mask), where

  • parcel_dOD (wavelength, channel, parcel) contains the change in OD observed in a channel given the assumed Hb change in a parcel, and

  • parcel_mask is a boolean DataArray with the parcel coordinates from Adot that is True for parcels for which dOD_thresh is met in at least minCh channels.
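To make the thresholding logic concrete, here is a minimal NumPy sketch of the computation described above. It is a hypothetical re-implementation for illustration, not cedalion's parcel_sensitivity: it assumes dOD is the parcel-summed sensitivity multiplied by the absorption change eps_HbO·dHbO + eps_HbR·dHbR, compares absolute values with the threshold, and takes a boolean keep mask instead of a list of channel names. The extinction coefficients in the usage example are toy values.

```python
import numpy as np


def parcel_sensitivity_sketch(adot, vertex_parcel, parcel_names, eps_hbo, eps_hbr,
                              d_hbo=10.0, d_hbr=-3.0, dod_thresh=1e-3, min_ch=1,
                              chan_keep=None):
    """Hypothetical sketch of parcel sensitivity thresholding.

    adot:          sensitivities, shape (channel, vertex, wavelength)
    vertex_parcel: parcel name of each vertex, shape (vertex,)
    parcel_names:  parcel names defining the output order
    eps_hbo/hbr:   extinction coefficients per wavelength (assumed consistent units)
    chan_keep:     optional boolean array (channel,), False for dropped channels
    Returns dod with shape (wavelength, channel, parcel) and mask (parcel,).
    """
    # absorption change per wavelength caused by the assumed Hb changes
    d_mua = eps_hbo * d_hbo + eps_hbr * d_hbr          # (wavelength,)
    n_ch, _, n_wl = adot.shape
    dod = np.zeros((n_wl, n_ch, len(parcel_names)))
    for j, name in enumerate(parcel_names):
        in_parcel = vertex_parcel == name
        # sum the sensitivities of all vertices in the parcel -> (channel, wavelength)
        summed = adot[:, in_parcel, :].sum(axis=1)
        dod[:, :, j] = (summed * d_mua).T
    if chan_keep is not None:
        dod[:, ~chan_keep, :] = 0.0  # dropped channels cannot observe anything
    # a channel "sees" a parcel if any wavelength reaches the threshold
    sees = (np.abs(dod) >= dod_thresh).any(axis=0)     # (channel, parcel)
    mask = sees.sum(axis=0) >= min_ch                  # (parcel,)
    return dod, mask


# toy example: 2 channels, 3 vertices, 2 wavelengths, parcels "A" and "B"
adot = np.zeros((2, 3, 2))
adot[0, 0, :] = 1e-3   # channel 0 is sensitive to vertex 0 (parcel A)
adot[1, 2, :] = 1e-9   # channel 1 barely sees vertex 2 (parcel B)
vertex_parcel = np.array(["A", "A", "B"])
eps_hbo = np.array([1.0, 1.0])  # toy values, not real extinction coefficients
eps_hbr = np.array([1.0, 1.0])

dod, mask = parcel_sensitivity_sketch(adot, vertex_parcel, ["A", "B"], eps_hbo, eps_hbr)
print(mask)  # parcel A passes, parcel B does not

# dropping channel 0 removes the only channel that sees parcel A
_, mask_pruned = parcel_sensitivity_sketch(
    adot, vertex_parcel, ["A", "B"], eps_hbo, eps_hbr,
    chan_keep=np.array([False, True]),
)
print(mask_pruned)
```

Setting dropped channels to zero rather than removing them keeps the channel axis intact, which matches the zeros visible in the pruned parcel_dOD output further below.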

Example without channel pruning

[9]:
# set input parameters for parcel sensitivity calculation.
# Here we do not (yet) drop bad channels to investigate the general
# sensitivity of the probe to parcel space independent of channel quality

dOD_thresh = 0.001
minCh = 1
dHbO = 10 #µM
dHbR = -3 #µM

parcel_dOD, parcel_mask = fwm.parcel_sensitivity(Adot, None, dOD_thresh, minCh, dHbO, dHbR)

# display results
display(parcel_dOD)
display(parcel_mask)

# collect the names of parcels that pass (sensitive) and fail (dropped) the threshold
sensitive_parcels = parcel_mask.where(parcel_mask, drop=True)["parcel"].values.tolist()
dropped_parcels = parcel_mask.where(~parcel_mask, drop=True)["parcel"].values.tolist()
print(f"Number of sensitive parcels: {len(sensitive_parcels)}")
print(f"Number of dropped parcels: {len(dropped_parcels)}")
/home/runner/miniconda3/envs/cedalion/lib/python3.11/site-packages/xarray/core/variable.py:315: UnitStrippedWarning: The unit of the quantity is stripped when downcasting to ndarray.
  data = np.asarray(data)
/home/runner/miniconda3/envs/cedalion/lib/python3.11/site-packages/xarray/core/variable.py:315: UnitStrippedWarning: The unit of the quantity is stripped when downcasting to ndarray.
  data = np.asarray(data)
<xarray.DataArray (wavelength: 2, channel: 100, parcel: 602)> Size: 963kB
7.192e-09 8.348e-13 6.043e-14 1.443e-16 ... 1.228e-13 2.644e-13 5.771e-13
Coordinates:
  * channel     (channel) object 800B 'S1D1' 'S1D2' 'S1D4' ... 'S14D31' 'S14D32'
  * parcel      (parcel) object 5kB 'Background+FreeSurfer_Defined_Medial_Wal...
  * wavelength  (wavelength) float64 16B 760.0 850.0
<xarray.DataArray (parcel: 602)> Size: 602B
False False False False False False True ... False False False False False False
Coordinates:
  * parcel   (parcel) object 5kB 'Background+FreeSurfer_Defined_Medial_Wall_L...
Number of sensitive parcels: 112
Number of dropped parcels: 490

Visualize results

[10]:
# plot log10(dOD) as a heat map (rows: channels, columns: parcels) for both wavelengths, 760 and 850 nm
fig, axes = p.subplots(1, 2, figsize=(12, 6))

for i, wl in enumerate([760.0, 850.0]):
    ax = axes[i]
    im = ax.imshow(np.log10(parcel_dOD.sel(wavelength=wl).values), aspect="auto")
    im.set_clim(-10, 0)
    fig.colorbar(im, ax=ax)
    ax.set_xlabel("parcel")
    ax.set_ylabel("channel")
    ax.set_title(f"log(dOD) for wavelength {wl}")

p.tight_layout()
p.show()
[Figure: log10(dOD) per channel and parcel at 760 nm and 850 nm, without channel pruning]
[11]:
# highlight only the sensitive parcels: color all dropped parcels white
# boolean mask of rows belonging to dropped parcels
mask = parcels["Label"].isin(dropped_parcels)
# use .apply so that each selected row receives its own [1, 1, 1] RGB list
parcels_plotsens = parcels.copy()
parcels_plotsens.loc[mask, "Color"] = parcels_plotsens.loc[mask, "Color"].apply(lambda _: [1, 1, 1])


b = cdc.VTKSurface.from_trimeshsurface(head.brain)
b = pv.wrap(b.mesh)
b["parcels"] = parcels_plotsens.Color.tolist()

plt = pv.Plotter()
plt.add_mesh(
    b,
    scalars="parcels",
    rgb=True
)


cog = head.brain.vertices.mean("label").values
plt.camera.position = cog + [400,0,400]
plt.camera.focal_point = cog
plt.camera.left = [0,1,0]
plt.reset_camera()
# add probe optodes (labels containing "S" or "D")
geo3D_snapped_o = geo3D_snapped.where(geo3D_snapped.label.str.contains("S|D"), drop=True)
cedalion.plots.plot_labeled_points(plt, geo3D_snapped_o)
plt.show()
/home/runner/miniconda3/envs/cedalion/lib/python3.11/site-packages/xarray/core/variable.py:315: UnitStrippedWarning: The unit of the quantity is stripped when downcasting to ndarray.
  data = np.asarray(data)
[Figure: sensitive parcels in color, dropped parcels in white, with probe optodes]
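The recoloring above uses `.apply(lambda _: [1, 1, 1])` because pandas generally treats a plain list on the right-hand side of a `.loc` assignment as one value per selected row, not as a single RGB triple. A minimal, equivalent sketch with a list comprehension, on a hypothetical three-parcel table (names and colors are made up for illustration):

```python
import pandas as pd

# hypothetical parcellation table with one RGB list per row
parcels = pd.DataFrame({
    "Label": ["parcel_A", "parcel_B", "parcel_C"],
    "Color": [[0.8, 0.2, 0.2], [0.2, 0.8, 0.2], [0.2, 0.2, 0.8]],
})
dropped_parcels = ["parcel_B"]

mask = parcels["Label"].isin(dropped_parcels)
recolored = parcels.copy()
# replace the color of dropped rows by white, keep all other colors
recolored["Color"] = [
    [1, 1, 1] if drop else color
    for drop, color in zip(mask, recolored["Color"])
]
print(recolored)
```

Building a new column avoids the per-row `.apply` and leaves the original table untouched.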

Example with channel pruning

Same as before, but now a list of “bad” channels is excluded from the sensitivity calculation.

[12]:
# set input parameters for parcel sensitivity calculation.
# Now we use the list of channels dropped due to low SNR (with the artificially
# high threshold set above) to exclude them from the parcel sensitivity calculation

dOD_thresh = 0.001
minCh = 1
dHbO = 10 #µM
dHbR = -3 #µM
chan_droplist = snr_ch_droplist # list of dropped channels due to bad SNR, effectively reducing probe


parcel_dOD, parcel_mask = fwm.parcel_sensitivity(
    Adot, chan_droplist, dOD_thresh, minCh, dHbO, dHbR
)

# display results
display(parcel_dOD)
display(parcel_mask)

# collect the names of parcels that pass (sensitive) and fail (dropped) the threshold
sensitive_parcels = parcel_mask.where(parcel_mask, drop=True)["parcel"].values.tolist()
dropped_parcels = parcel_mask.where(~parcel_mask, drop=True)["parcel"].values.tolist()
print(f"Number of sensitive parcels: {len(sensitive_parcels)}")
print(f"Number of dropped parcels: {len(dropped_parcels)}")
/home/runner/miniconda3/envs/cedalion/lib/python3.11/site-packages/xarray/core/variable.py:315: UnitStrippedWarning: The unit of the quantity is stripped when downcasting to ndarray.
  data = np.asarray(data)
/home/runner/miniconda3/envs/cedalion/lib/python3.11/site-packages/xarray/core/variable.py:315: UnitStrippedWarning: The unit of the quantity is stripped when downcasting to ndarray.
  data = np.asarray(data)
<xarray.DataArray (wavelength: 2, channel: 100, parcel: 602)> Size: 963kB
0.0 0.0 0.0 0.0 0.0 0.0 ... 8.939e-14 3.946e-13 1.228e-13 2.644e-13 5.771e-13
Coordinates:
  * channel     (channel) object 800B 'S1D1' 'S1D2' 'S1D4' ... 'S14D31' 'S14D32'
  * parcel      (parcel) object 5kB 'Background+FreeSurfer_Defined_Medial_Wal...
  * wavelength  (wavelength) float64 16B 760.0 850.0
<xarray.DataArray (parcel: 602)> Size: 602B
False False False False False False True ... False False False False False False
Coordinates:
  * parcel   (parcel) object 5kB 'Background+FreeSurfer_Defined_Medial_Wall_L...
Number of sensitive parcels: 86
Number of dropped parcels: 516
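Since dOD_thresh, minCh and the assumed Hb changes are identical in both runs, the drop from 112 to 86 sensitive parcels is caused by channel pruning alone. To list which parcels were lost, one could keep the first result under a separate name before it is overwritten. A sketch with hypothetical variable names and toy parcel names:

```python
# Hypothetical names: in the notebook, copy `sensitive_parcels` from the unpruned
# run (e.g. to `sensitive_all_channels`) before re-running with channel pruning.
sensitive_all_channels = ["parcel_A", "parcel_B", "parcel_C"]  # toy values
sensitive_after_pruning = ["parcel_A", "parcel_C"]             # toy values

# parcels that are only observable when the pruned channels are included
lost_by_pruning = sorted(set(sensitive_all_channels) - set(sensitive_after_pruning))
# dropping channels can only remove contributions, so this should be empty
gained_by_pruning = sorted(set(sensitive_after_pruning) - set(sensitive_all_channels))

print(f"{len(lost_by_pruning)} parcels lost due to channel pruning: {lost_by_pruning}")
```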

Visualize results

[13]:
# plot log10(dOD) as a heat map (rows: channels, columns: parcels) for both wavelengths, 760 and 850 nm
fig, axes = p.subplots(1, 2, figsize=(12, 6))

for i, wl in enumerate([760.0, 850.0]):
    ax = axes[i]
    im = ax.imshow(np.log10(parcel_dOD.sel(wavelength=wl).values), aspect="auto")
    im.set_clim(-10, 0)
    fig.colorbar(im, ax=ax)
    ax.set_xlabel("parcel")
    ax.set_ylabel("channel")
    ax.set_title(f"log(dOD) for wavelength {wl}")

p.tight_layout()
p.show()


# highlight only the sensitive parcels: color all dropped parcels white
# boolean mask of rows belonging to dropped parcels
mask = parcels["Label"].isin(dropped_parcels)
# use .apply so that each selected row receives its own [1, 1, 1] RGB list
parcels_plotsens = parcels.copy()
parcels_plotsens.loc[mask, "Color"] = parcels_plotsens.loc[mask, "Color"].apply(lambda _: [1, 1, 1])

b = cdc.VTKSurface.from_trimeshsurface(head.brain)
b = pv.wrap(b.mesh)
b["parcels"] = parcels_plotsens.Color.tolist()

plt = pv.Plotter()
plt.add_mesh(
    b,
    scalars="parcels",
    rgb=True
)


cog = head.brain.vertices.mean("label").values
plt.camera.position = cog + [400,0,400]
plt.camera.focal_point = cog
plt.camera.left = [0,1,0]
plt.reset_camera()
# add probe optodes (labels containing "S" or "D")
geo3D_snapped_o = geo3D_snapped.where(geo3D_snapped.label.str.contains("S|D"), drop=True)
cedalion.plots.plot_labeled_points(plt, geo3D_snapped_o)
plt.show()


/tmp/ipykernel_4941/1229230510.py:6: RuntimeWarning: divide by zero encountered in log10
  im = ax.imshow(np.log10(parcel_dOD.sel(wavelength=wl).values), aspect="auto")
[Figure: log10(dOD) per channel and parcel at 760 nm and 850 nm, with channel pruning]
/home/runner/miniconda3/envs/cedalion/lib/python3.11/site-packages/xarray/core/variable.py:315: UnitStrippedWarning: The unit of the quantity is stripped when downcasting to ndarray.
  data = np.asarray(data)
[Figure: sensitive parcels after channel pruning in color, dropped parcels in white, with probe optodes]
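The "divide by zero encountered in log10" warning above arises because pruned channels have a dOD of exactly 0 (see the leading zeros in the pruned parcel_dOD output), and log10(0) is -inf. One way to avoid it, sketched here with a hypothetical helper `safe_log10`, is to map non-positive values to NaN before plotting; matplotlib's `imshow` draws NaN pixels in the colormap's "bad" color, which is transparent by default.

```python
import numpy as np


def safe_log10(values):
    """Hypothetical helper: log10 of values, with non-positive entries set to NaN."""
    values = np.asarray(values, dtype=float)
    out = np.full(values.shape, np.nan)
    positive = values > 0
    # only positive entries are passed to log10, so no warning is raised
    out[positive] = np.log10(values[positive])
    return out


print(safe_log10([0.0, 1e-3, 100.0]))
```

In the plotting loop, this would replace the direct call, e.g. `im = ax.imshow(safe_log10(parcel_dOD.sel(wavelength=wl).values), aspect="auto")`.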