Motion Artefact Detection and Correction

This notebook shows how to identify and correct motion artefacts using xarray-based masks and cedalion’s correction functionality.

[1]:
# This cell sets up the environment when executed in Google Colab.
try:
    import google.colab
    !curl -s https://raw.githubusercontent.com/ibs-lab/cedalion/dev/scripts/colab_setup.py -o colab_setup.py
    # Select branch with --branch "branch name" (default is "dev")
    %run colab_setup.py
except ImportError:
    pass
[2]:
import matplotlib.pyplot as p

import cedalion
import cedalion.datasets as datasets
import cedalion.nirs
import cedalion.sigproc.motion_correct as motion_correct
import cedalion.sigproc.quality as quality
import cedalion.sim.synthetic_artifact as synthetic_artifact
from cedalion import units
[3]:
# get example finger tapping dataset

rec = datasets.get_fingertapping()
rec["od"] = cedalion.nirs.int2od(rec["amp"])

# Add some synthetic spikes and baseline shifts
artifacts = {
    "spike": synthetic_artifact.gen_spike,
    "bl_shift": synthetic_artifact.gen_bl_shift,
}
timing = synthetic_artifact.random_events_perc(rec["od"].time, 0.01, ["spike"])
timing = synthetic_artifact.add_event_timing(
    [(200, 0), (400, 0)], "bl_shift", None, timing
)
rec["od"] = synthetic_artifact.add_artifacts(rec["od"], timing, artifacts)

# Plot some data for visual validation
f, ax = p.subplots(1, 1, figsize=(12, 4))
ax.plot(
    rec["od"].time, rec["od"].sel(channel="S3D3", wavelength="850"), "r-", label="850nm"
)
ax.plot(
    rec["od"].time, rec["od"].sel(channel="S3D3", wavelength="760"), "g-", label="760nm"
)

# indicate added artefacts
for _,row in timing.iterrows():
    p.axvline(row["onset"], c="k", alpha=.2)

p.legend()
ax.set_xlim(0, 500)
ax.set_xlabel("time / s")
ax.set_ylabel("OD")


display(rec["od"])
<xarray.DataArray (channel: 28, wavelength: 2, time: 23239)> Size: 10MB
<Quantity([[[ 0.04042072  0.04460046  0.04421587 ...  0.08227263  0.08328687
    0.07824392]
  [ 0.0238205   0.02007699  0.03480909 ...  0.11612429  0.11917232
    0.12386444]]

 [[-0.00828006 -0.01784406 -0.00219874 ...  0.05383301  0.05068234
    0.05268067]
  [-0.03725579 -0.04067296 -0.02826115 ...  0.08155008  0.07904583
    0.07842621]]

 [[ 0.10055823  0.09914287  0.11119026 ...  0.03696155  0.04202579
    0.04167814]
  [ 0.049938    0.04755176  0.06016311 ...  0.0744283   0.07835364
    0.07515144]]

 ...

 [[ 0.0954341   0.11098679  0.10684828 ...  0.02764187  0.03057966
    0.02245211]
  [ 0.03858011  0.06286433  0.0612825  ...  0.10304278  0.10240304
    0.09205354]]

 [[ 0.1550658   0.17214468  0.16880747 ... -0.00790466 -0.00773703
   -0.01269059]
  [ 0.10250045  0.12616269  0.12619078 ...  0.04663763  0.04687754
    0.03974277]]

 [[ 0.05805322  0.06125157  0.06083507 ... -0.00101062 -0.000856
   -0.00219674]
  [ 0.02437702  0.03088664  0.03219055 ...  0.01326252  0.01341195
    0.01118119]]], 'dimensionless')>
Coordinates:
  * time        (time) float64 186kB 0.0 0.128 0.256 ... 2.974e+03 2.974e+03
    samples     (time) int64 186kB 0 1 2 3 4 5 ... 23234 23235 23236 23237 23238
  * channel     (channel) object 224B 'S1D1' 'S1D2' 'S1D3' ... 'S8D8' 'S8D16'
    source      (channel) object 224B 'S1' 'S1' 'S1' 'S1' ... 'S8' 'S8' 'S8'
    detector    (channel) object 224B 'D1' 'D2' 'D3' 'D9' ... 'D7' 'D8' 'D16'
  * wavelength  (wavelength) float64 16B 760.0 850.0
../../_images/examples_signal_quality_22_motion_artefacts_and_correction_3_1.png
[4]:
display(timing)
onset duration trial_type value channel
0 201.756856 0.354829 spike 1 None
1 8.712008 0.181029 spike 1 None
2 2247.869207 0.382840 spike 1 None
3 2442.594614 0.341627 spike 1 None
4 1189.751050 0.381130 spike 1 None
... ... ... ... ... ...
118 1080.114881 0.184603 spike 1 None
119 757.496150 0.399591 spike 1 None
120 2903.104484 0.398869 spike 1 None
121 200.000000 0.000000 bl_shift 1 None
122 400.000000 0.000000 bl_shift 1 None

123 rows × 5 columns

Detecting Motion Artifacts and Generating the MA Mask

The example below shows how to check channels for motion artefacts using standard thresholds from Homer2/3. The output is a mask that can be handed to motion correction algorithms that require segments flagged as artefacts.

[5]:
# we use Optical Density data for motion artifact detection
fnirs_data = rec["od"]

# define parameters for motion artifact detection. We follow the method from Homer2/3:
# "hmrR_MotionArtifactByChannel" and "hmrR_MotionArtifact".
t_motion = 0.5 * units.s  # time window for motion artifact detection
t_mask = 1.0 * units.s    # time window for masking motion artifacts
                          # (+- t_mask s before/after detected motion artifact)
stdev_thresh = 7.0        # threshold for std. deviation of the signal used to detect
                          # motion artifacts. Default is 50. We set it very low to find
                          # something in our good data for demonstration purposes.
amp_thresh = 5.0          # threshold for amplitude of the signal used to detect motion
                          # artifacts. Default is 5.

# to identify motion artifacts with these parameters we call the following function
ma_mask = quality.id_motion(fnirs_data, t_motion, t_mask, stdev_thresh, amp_thresh)

# it returns a boolean mask (an xarray object) with the input's dimensions, where
# False indicates a motion artifact at a given time point:
ma_mask
[5]:
<xarray.DataArray (channel: 28, wavelength: 2, time: 23239)> Size: 1MB
array([[[ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True]],

       [[ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True]],

       [[ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True]],

       ...,

       [[ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True]],

       [[ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True]],

       [[ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True]]],
      shape=(28, 2, 23239))
Coordinates:
  * time        (time) float64 186kB 0.0 0.128 0.256 ... 2.974e+03 2.974e+03
    samples     (time) int64 186kB 0 1 2 3 4 5 ... 23234 23235 23236 23237 23238
  * channel     (channel) object 224B 'S1D1' 'S1D2' 'S1D3' ... 'S8D8' 'S8D16'
    source      (channel) object 224B 'S1' 'S1' 'S1' 'S1' ... 'S8' 'S8' 'S8'
    detector    (channel) object 224B 'D1' 'D2' 'D3' 'D9' ... 'D7' 'D8' 'D16'
  * wavelength  (wavelength) float64 16B 760.0 850.0

The output mask is quite detailed: it retains all original dimensions (e.g. individual wavelengths), which allows us to combine it with a mask from another motion artifact detection method. This is the same approach as for the channel quality metrics above.
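Combining two masks amounts to an elementwise logical AND; on xarray masks the same `&` operator additionally aligns coordinates automatically. A minimal sketch of the idea with plain NumPy arrays (the masks here are hypothetical, standing in for the output of two different detectors):

```python
import numpy as np

# Two hypothetical boolean masks for one channel over five time points,
# e.g. produced by two different artifact detectors. False marks an artifact.
mask_a = np.array([True, True, False, True, True])
mask_b = np.array([True, False, False, True, True])

# A sample is clean only if both detectors agree that it is clean.
combined = mask_a & mask_b
print(combined)  # [ True False False  True  True]
```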

Let us now plot the result for an example channel. Note that a different number of artifacts was identified for the two wavelengths, which can sometimes happen:

[6]:
p.figure()
p.plot(ma_mask.time, ma_mask.sel(channel="S3D3", wavelength="760"), "b-")
p.plot(ma_mask.time, ma_mask.sel(channel="S3D3", wavelength="850"), "r-")

# indicate added artefacts
for _,row in timing.iterrows():
    p.axvline(row["onset"], c="k", alpha=.2)

p.xlim(0, 500)
p.xlabel("time / s")
p.ylabel("Motion artifact mask")


p.show()
../../_images/examples_signal_quality_22_motion_artefacts_and_correction_8_0.png

Plotting the mask and the data together (we have to rescale a bit to make both fit):

[7]:
p.figure()
p.plot(fnirs_data.time, fnirs_data.sel(channel="S3D3", wavelength="760"), "r-")
p.plot(ma_mask.time, ma_mask.sel(channel="S3D3", wavelength="760") / 10, "k-")

# indicate added artefacts
for _,row in timing.iterrows():
    p.axvline(row["onset"], c="k", alpha=.2)

p.xlim(0, 500)
p.xlabel("time / s")
p.ylabel("fNIRS Signal / Motion artifact mask")
p.show()

../../_images/examples_signal_quality_22_motion_artefacts_and_correction_10_0.png

Refining the MA Mask

By the time we correct motion artifacts, we usually no longer need the level of granularity that the mask provides. For instance, we typically want to treat a motion artifact detected in either wavelength or chromophore of a channel as a single artifact that gets flagged for both. We might also want to flag motion artifacts globally, i.e. mask time points for all channels even if only some of them show an artifact. This can easily be done with the “id_motion_refine” function, which also returns useful per-channel information about the motion artifacts in “ma_info”.
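Underneath, this refinement is a logical reduction along the collapsed dimension(s). A minimal NumPy sketch of the two operators on a made-up (channel, wavelength, time) mask, not cedalion's implementation:

```python
import numpy as np

# Hypothetical mask with shape (channels=2, wavelengths=2, time=4);
# False marks an artifact at that sample.
mask = np.array([
    [[True, False, True, True],   # channel 0, wavelength 0
     [True, True,  True, True]],  # channel 0, wavelength 1
    [[True, True,  True, True],   # channel 1, wavelength 0
     [True, True,  False, True]], # channel 1, wavelength 1
])

# "by_channel": a sample in a channel is clean only if it is clean
# at every wavelength -> AND-reduce along the wavelength axis.
by_channel = mask.all(axis=1)        # shape (2, 4)

# "all": a sample is clean only if every channel and wavelength is clean.
global_mask = mask.all(axis=(0, 1))  # shape (4,)
print(by_channel)
print(global_mask)
```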

[8]:
# refine the motion artifact mask. This function collapses the mask along dimensions
# that are chosen by the "operator" argument. Here we use "by_channel", which will yield
# a mask for each channel by collapsing the masks along either the wavelength or
# concentration dimension.
ma_mask_refined, ma_info = quality.id_motion_refine(ma_mask, "by_channel")

# show the refined mask
ma_mask_refined
[8]:
<xarray.DataArray (channel: 28, time: 23239)> Size: 651kB
array([[ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       ...,
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True]], shape=(28, 23239))
Coordinates:
  * time      (time) float64 186kB 0.0 0.128 0.256 ... 2.974e+03 2.974e+03
    samples   (time) int64 186kB 0 1 2 3 4 5 ... 23234 23235 23236 23237 23238
  * channel   (channel) object 224B 'S1D1' 'S1D2' 'S1D3' ... 'S8D8' 'S8D16'
    source    (channel) object 224B 'S1' 'S1' 'S1' 'S1' ... 'S7' 'S8' 'S8' 'S8'
    detector  (channel) object 224B 'D1' 'D2' 'D3' 'D9' ... 'D7' 'D8' 'D16'

Now the mask no longer has the “wavelength” or “concentration” dimension; the masks along these dimensions have been combined:

[9]:
# plot the figure
p.figure()
p.plot(fnirs_data.time, fnirs_data.sel(channel="S3D3", wavelength="760"), "r-")
p.plot(ma_mask_refined.time, ma_mask_refined.sel(channel="S3D3") / 10, "k-")

# indicate added artefacts
for _,row in timing.iterrows():
    p.axvline(row["onset"], c="k", alpha=.2)

p.xlim(0, 500)
p.xlabel("time / s")
p.ylabel("fNIRS Signal / Refined Motion artifact mask")
p.show()

# show the information about the motion artifacts: we get a pandas dataframe telling us
# 1) for which channels artifacts were detected,
# 2) what fraction of time points was marked as artifact and
# 3) how many artifacts were detected
ma_info

../../_images/examples_signal_quality_22_motion_artefacts_and_correction_14_0.png
[9]:
channel ma_fraction ma_count
0 S1D1 0.086579 97
1 S1D2 0.056070 67
2 S1D3 0.054391 64
3 S1D9 0.042601 54
4 S2D1 0.046689 57
5 S2D3 0.059641 70
6 S2D4 0.031413 36
7 S2D10 0.067301 79
8 S3D2 0.035113 44
9 S3D3 0.058479 71
10 S3D11 0.057490 69
11 S4D3 0.020999 26
12 S4D4 0.028960 27
13 S4D12 0.046301 56
14 S5D5 0.090322 100
15 S5D6 0.061448 74
16 S5D7 0.041869 52
17 S5D13 0.093980 106
18 S6D5 0.058049 72
19 S6D7 0.051121 64
20 S6D8 0.056113 65
21 S6D14 0.060287 74
22 S7D6 0.084599 96
23 S7D7 0.076983 90
24 S7D15 0.056457 69
25 S8D7 0.026722 33
26 S8D8 0.035587 39
27 S8D16 0.043978 52

Now we look at the “all” operator, which collapses the mask across all dimensions except time, leading to a single motion artifact mask:

[10]:
# "all", yields a mask that flags an artifact at any given time if flagged for
# any channetransl, wavelength, chromophore, etc.
ma_mask_refined, ma_info = quality.id_motion_refine(ma_mask, 'all')

# show the refined mask
ma_mask_refined
[10]:
<xarray.DataArray (time: 23239)> Size: 23kB
array([ True,  True,  True, ...,  True,  True,  True], shape=(23239,))
Coordinates:
  * time     (time) float64 186kB 0.0 0.128 0.256 ... 2.974e+03 2.974e+03
    samples  (time) int64 186kB 0 1 2 3 4 5 ... 23234 23235 23236 23237 23238
[11]:
# plot the figure
p.figure()
p.plot(fnirs_data.time, fnirs_data.sel(channel="S3D3", wavelength="760"), "r-")
p.plot(ma_mask_refined.time, ma_mask_refined/10, "k-")
p.xlim(0,500)
p.xlabel("time / s")
p.ylabel("fNIRS Signal / Refined Motion artifact mask")
p.show()

# show the information about the motion artifacts: we get a pandas dataframe telling us
# 1) that the mask is for all channels,
# 2) the fraction of time points that were marked as artifacts for this mask across all
#    channels and
# 3) how many artifacts were detected in total
ma_info
../../_images/examples_signal_quality_22_motion_artefacts_and_correction_17_0.png
[11]:
channel ma_fraction ma_count
0 all channels combined [1.0, 0.9941477688368691] [0, 8]

Motion Correction

Here we illustrate the effect of different motion correction methods. Cedalion may offer further methods, so make sure to check the API documentation.

[12]:
def compare_raw_cleaned(rec, key_raw, key_cleaned, title):
    chwl = dict(channel="S3D3", wavelength="850")
    f, ax = p.subplots(1, 1, figsize=(12, 4))
    ax.plot(
        rec[key_raw].time,
        rec[key_raw].sel(**chwl),
        "r-",
        label="850nm raw",
    )
    ax.plot(
        rec[key_cleaned].time,
        rec[key_cleaned].sel(**chwl),
        "g-",
        label="850nm cleaned",
    )
    ax.set_xlim(0, 500)
    ax.set_ylabel("OD")
    ax.set_xlabel("time / s")
    ax.set_title(title)
    ax.legend()

    # indicate added artefacts
    for _,row in timing.iterrows():
        p.axvline(row["onset"], c="k", alpha=.2)

SplineSG method:

  1. identifies baseline shifts in the data and uses spline interpolation to correct them

  2. uses a Savitzky-Golay filter to remove spikes
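The Savitzky-Golay step can be illustrated in isolation with `scipy.signal.savgol_filter`. This is only a sketch of the spike-smoothing idea, not cedalion's implementation; the window length and polynomial order below are assumptions chosen for demonstration:

```python
import numpy as np
from scipy.signal import savgol_filter

rng = np.random.default_rng(0)
t = np.linspace(0, 10, 500)
signal = np.sin(t) + 0.01 * rng.standard_normal(t.size)
signal[250] += 2.0  # inject a spike artifact

# Savitzky-Golay: fit a low-order polynomial in a sliding window.
# Short spikes are poorly fit by the polynomial and get suppressed.
smoothed = savgol_filter(signal, window_length=31, polyorder=3)

print(abs(signal[250] - np.sin(t[250])))    # large residual at the spike
print(abs(smoothed[250] - np.sin(t[250])))  # much smaller after filtering
```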

[13]:
frame_size = 10 * units.s
rec["od_splineSG"] = motion_correct.motion_correct_splineSG(
    rec["od"], frame_size=frame_size, p=1
)

compare_raw_cleaned(rec, "od", "od_splineSG", "SplineSG")

../../_images/examples_signal_quality_22_motion_artefacts_and_correction_21_0.png

TDDR:

  • Temporal Derivative Distribution Repair (TDDR) is a robust regression-based motion correction algorithm.

  • Doesn’t require any user-supplied parameters

  • See [FLVM19]
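The core idea of TDDR is robust reweighting of the temporal derivative: samples whose derivative is an outlier receive (near-)zero weight before the signal is re-integrated. A simplified, single-iteration sketch of that idea (the actual algorithm iterates the weights to convergence; see [FLVM19]):

```python
import numpy as np

rng = np.random.default_rng(1)
signal = np.cumsum(0.01 * rng.standard_normal(300))
signal[150:] += 1.0  # inject a baseline shift (a large derivative outlier)

deriv = np.diff(signal)

# Robust scale estimate via the median absolute deviation.
sigma = 1.4826 * np.median(np.abs(deriv - np.median(deriv)))

# Tukey biweight: zero weight for residuals beyond 4.685 * sigma.
r = (deriv - np.median(deriv)) / (4.685 * sigma)
weights = np.where(np.abs(r) < 1, (1 - r**2) ** 2, 0.0)

# The shift sample gets weight 0; re-integrating the weighted derivative
# removes the step while preserving the slow drift.
corrected = signal[0] + np.concatenate([[0.0], np.cumsum(weights * deriv)])
print(weights[149])  # the injected shift is fully downweighted
```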

[14]:
rec["od_tddr"] = motion_correct.tddr(rec["od"])

compare_raw_cleaned(rec, "od", "od_tddr", "TDDR")
../../_images/examples_signal_quality_22_motion_artefacts_and_correction_23_0.png

PCA

  • Applies motion correction using a PCA filter to motion artefact segments (identified by the mask).

  • The implementation is based on Homer3 v1.80.2, “hmrR_MotionCorrectPCA.m”

[15]:
rec["od_pca"], nSV_ret, svs = motion_correct.motion_correct_PCA(
    rec["od"], ma_mask_refined
)

compare_raw_cleaned(rec, "od", "od_pca", "PCA")
../../_images/examples_signal_quality_22_motion_artefacts_and_correction_25_0.png

Recursive PCA

  • If any active channel exhibits a signal change greater than STDEVthresh or AMPthresh, that segment of data is marked as a motion artefact.

  • motion_correct_PCA is applied to all segments of data identified as motion artefacts.

  • This is repeated until maxIter is reached or no motion artefacts are identified.
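The recursive scheme is a detect-correct-repeat loop. A toy sketch of the control flow with a hypothetical threshold detector and a median-replacement "correction" standing in for the PCA step (none of these helpers are cedalion functions):

```python
import numpy as np

def detect_spikes(x, amp_thresh):
    """Flag samples whose deviation from the median exceeds amp_thresh."""
    return np.abs(x - np.median(x)) > amp_thresh

def correct(x, flagged):
    """Toy stand-in for the PCA correction: replace flagged samples."""
    x = x.copy()
    x[flagged] = np.median(x)
    return x

x = np.zeros(100)
x[10], x[50] = 5.0, 3.0  # two artificial spikes
max_iter, amp_thresh = 5, 1.0

for _ in range(max_iter):
    flagged = detect_spikes(x, amp_thresh)
    if not flagged.any():  # stop early: no artifacts left
        break
    x = correct(x, flagged)

print(flagged.any(), np.abs(x).max())
```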

[16]:
rec["od_pca_r"], svs, nSV, tInc = motion_correct.motion_correct_PCA_recurse(
    rec["od"], t_motion, t_mask, stdev_thresh, amp_thresh
)

compare_raw_cleaned(rec, "od", "od_pca_r", "Recursive PCA")
../../_images/examples_signal_quality_22_motion_artefacts_and_correction_27_0.png

Wavelet Motion Correction

  • Focused on spike artifacts

  • The IQR factor, wavelet, and wavelet decomposition level can be set.

  • A higher IQR factor leads to more coefficients being discarded, i.e. more drastic correction.
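The IQR criterion can be illustrated on its own: coefficients lying outside an interquartile-range fence are treated as artefact-driven and zeroed. A toy NumPy sketch of that idea, with simple random values standing in for wavelet detail coefficients; this is not cedalion's wavelet implementation:

```python
import numpy as np

rng = np.random.default_rng(2)
coeffs = 0.1 * rng.standard_normal(200)  # stand-in for wavelet detail coeffs
coeffs[100] = 5.0                        # spike -> one huge coefficient

iqr_factor = 1.5
q1, q3 = np.percentile(coeffs, [25, 75])
iqr = q3 - q1

# Zero out coefficients outside the IQR fence; the surviving coefficients
# would then be used to reconstruct the corrected signal.
lo, hi = q1 - iqr_factor * iqr, q3 + iqr_factor * iqr
cleaned = np.where((coeffs < lo) | (coeffs > hi), 0.0, coeffs)
print(cleaned[100])
```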

[17]:
rec["od_wavelet"] = motion_correct.motion_correct_wavelet(rec["od"])

compare_raw_cleaned(rec, "od", "od_wavelet", "Wavelet")
../../_images/examples_signal_quality_22_motion_artefacts_and_correction_29_0.png