Tools for modeling brain responses using (multivariate) temporal response functions.
mTRFpy
This is an adaptation of the MATLAB mTRF-Toolbox implemented in basic Python and NumPy. It aims to provide the same functionality as the original toolbox and eventually extend it. The package is written and maintained by Jin Dou and Ole Bialas at the University of Rochester.
Installation
You can get the stable release from PyPI:
pip install mtrf
Or get the latest version from this repo:
pip install git+https://github.com/powerfulbean/mTRFpy.git
Tutorial
Here, we provide an overview of mTRFpy's core functions.
The TRF class
The TRF class is the core of the toolbox. We import it along with some sample data (the data will be downloaded when you call the loading function for the first time).
from mtrf.model import TRF, load_sample_data
# stimulus is a 16-band spectrogram, response a 128-channel EEG
stimulus, response, samplerate = load_sample_data()
trf = TRF(direction=1) # create a forward TRF
The TRF is fitted to the data using the train method, which requires the range of time lags and the regularization parameter, often called lambda. To test the model's accuracy, we can use the trained TRF to predict the EEG from the stimulus and compute the correlation between the prediction and the actual data.
trf.train(stimulus, response, samplerate, tmin=0, tmax=0.3, regularization=1000)
# add the argument `average=False` to get one correlation coefficient per channel
prediction, correlation, error = trf.predict(stimulus, response)
print(f"Pearson's correlation between actual brain response and prediction: {correlation.round(3)}")
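To make the role of the time lags and the regularization parameter more concrete, here is a minimal NumPy sketch of time-lagged ridge regression on synthetic data. It is not mTRFpy's implementation: the helper names `lag_matrix` and `ridge_fit`, the 64 Hz sample rate, and the simulated kernel are assumptions made for illustration, and the sketch leaves out details such as an intercept term.

```python
import numpy as np

def lag_matrix(x, lags):
    """Stack time-shifted copies of x (samples x features) side by side."""
    n_samples, n_features = x.shape
    lagged = np.zeros((n_samples, n_features * len(lags)))
    for i, lag in enumerate(lags):
        cols = slice(i * n_features, (i + 1) * n_features)
        if lag >= 0:
            lagged[lag:, cols] = x[: n_samples - lag]  # delay by `lag` samples
        else:
            lagged[:lag, cols] = x[-lag:]  # advance by `-lag` samples
    return lagged

def ridge_fit(x, y, lags, regularization):
    """Solve (X'X + lambda * I) w = X'y for the lagged design matrix X."""
    X = lag_matrix(x, lags)
    XtX = X.T @ X
    return np.linalg.solve(XtX + regularization * np.eye(XtX.shape[0]), X.T @ y)

# Synthetic data (assumed values, not the mTRFpy sample data)
rng = np.random.default_rng(0)
samplerate = 64
stimulus = rng.standard_normal((2000, 1))
lags = np.arange(0, int(0.3 * samplerate) + 1)  # lags from 0 to 0.3 s, in samples
true_kernel = np.exp(-lags / 5.0)
response = lag_matrix(stimulus, lags) @ true_kernel[:, None]
response += 0.1 * rng.standard_normal(response.shape)

weights = ridge_fit(stimulus, response, lags, regularization=1.0)
prediction = lag_matrix(stimulus, lags) @ weights
correlation = np.corrcoef(prediction[:, 0], response[:, 0])[0, 1]
print(f"correlation: {correlation:.3f}")
```

Each weight corresponds to one feature at one time lag, which is why the trained weights can be plotted as a function of time.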
The TRF class also has a plotting method to visualize the weights across time. Using the trained TRF we could, for example, plot the weights for each spectral band at one channel, or plot the weights for each channel, averaged across all spectral bands.
from matplotlib import pyplot as plt
fig, ax = plt.subplots(2)
trf.plot(channel=60, axes=ax[0], show=False, kind='line')
ax[0].set_title('16-band spectrogram TRF at channel 60')
trf.plot(feature='avg', axes=ax[1], show=False, kind='image')
ax[1].set_title('Average TRF at every channel')
plt.tight_layout()
plt.show()
Prevent overfitting
TRFs can also be used as backward models that reconstruct the stimulus envelope (i.e., the average of the spectrogram across bands) from the recorded EEG.
trf = TRF(direction=-1) # create a backward TRF
envelope = stimulus.mean(axis=-1, keepdims=True)
trf.train(envelope, response, samplerate, tmin=0, tmax=0.3, regularization=1000)
prediction, correlation, error = trf.predict(envelope, response)
print(f"Pearson's correlation between actual envelope and prediction: {correlation.round(3)}")
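The `keepdims=True` argument in the envelope computation above matters because it keeps the envelope two-dimensional (samples x 1 feature). A short check on a dummy array, whose size is an arbitrary assumption, shows the difference:

```python
import numpy as np

rng = np.random.default_rng(0)
spectrogram = rng.random((1000, 16))  # dummy data: samples x spectral bands

envelope = spectrogram.mean(axis=-1, keepdims=True)
flat = spectrogram.mean(axis=-1)

print(envelope.shape)  # (1000, 1): still samples x features
print(flat.shape)      # (1000,): the feature axis is gone
```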
The correlation between the predicted and actual envelope is 0.56, which is implausibly high. This is the result of overfitting: the model has many free parameters (one per channel and time lag) and a single estimand (the envelope), and it was tested on the same data it was trained on. To prevent overfitting, we need to train the TRF on one part of the dataset and test it on another. This can be done systematically using the cross_validate function. To use it, we must split stimulus and response into trials, either as a 3-D array of shape trials x samples x features or as a list of 2-D arrays with one entry per trial.
import numpy as np
from mtrf.crossval import cross_validate
# split stimulus and response into 10 trials
envelope, response = np.array_split(envelope, 10), np.array_split(response, 10)
correlation, error = cross_validate(TRF(direction=-1), envelope, response, samplerate, tmin=0, tmax=0.3, regularization=1000)
print(f"Pearson's correlation between actual envelope and prediction: {correlation.round(3)}")
The correlation estimated via cross-validation is a more realistic estimate of the decoder's accuracy.
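The gap between in-sample and cross-validated accuracy can be reproduced with a small NumPy sketch on synthetic data. This is only an illustration of leave-one-trial-out cross-validation with plain ridge regression (no time lags); the helper names, data sizes, and noise level are assumptions, and it does not reproduce how cross_validate works internally.

```python
import numpy as np

def ridge(X, y, lam):
    """Ridge regression weights for design matrix X and target y."""
    return np.linalg.solve(X.T @ X + lam * np.eye(X.shape[1]), X.T @ y)

def pearson(a, b):
    return np.corrcoef(a.ravel(), b.ravel())[0, 1]

# Synthetic decoding problem: recover a 1-D envelope from noisy channels
rng = np.random.default_rng(1)
n_trials, n_samples, n_channels, lam = 10, 30, 64, 1.0
mixing = rng.standard_normal((1, n_channels))
envelope = [rng.standard_normal((n_samples, 1)) for _ in range(n_trials)]
response = [e @ mixing + 10 * rng.standard_normal((n_samples, n_channels))
            for e in envelope]

# Train and test on the same data: optimistic estimate
X_all, y_all = np.vstack(response), np.vstack(envelope)
in_sample = pearson(X_all @ ridge(X_all, y_all, lam), y_all)

# Leave one trial out, train on the rest, test on the left-out trial
fold_correlations = []
for test in range(n_trials):
    X_train = np.vstack([r for i, r in enumerate(response) if i != test])
    y_train = np.vstack([e for i, e in enumerate(envelope) if i != test])
    w = ridge(X_train, y_train, lam)
    fold_correlations.append(pearson(response[test] @ w, envelope[test]))
cross_validated = float(np.mean(fold_correlations))

print(f"in-sample: {in_sample:.2f}, cross-validated: {cross_validated:.2f}")
```

Because every test trial is excluded from training, the cross-validated value is not inflated by the model memorizing noise in the data it is scored on.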
Fitting hyperparameters
So far, we used a regularization value of 1000 in all examples, which worked reasonably well, judging from the correlation values and visual inspection of the TRFs. However, a more principled way is to find the regularization value yielding the most accurate predictions. This can be done using the fit method. This method takes a list of regularization values, creates a TRF model for each one and tests its accuracy with cross-validation. Then, the value yielding the highest correlation is selected to train the final model.
trf = TRF(direction=1) # create a forward TRF
regularization = np.logspace(-1, 6, 10) # try 10 values between 0.1 and 1,000,000
stimulus = np.array_split(stimulus, 10) # split stimulus as well
correlation, error = trf.fit(stimulus, response, samplerate, tmin=0, tmax=0.3, regularization=regularization)
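The selection procedure described above (score each candidate value with cross-validation, then keep the best one) can be sketched in plain NumPy. This is a simplified stand-in, not the fit method itself: it uses ridge regression without time lags, a single output, and synthetic data, and the helper name `cv_correlation` is hypothetical.

```python
import numpy as np

def cv_correlation(X_trials, y_trials, lam):
    """Mean leave-one-trial-out correlation for one regularization value."""
    scores = []
    for test in range(len(X_trials)):
        X_train = np.vstack([x for i, x in enumerate(X_trials) if i != test])
        y_train = np.vstack([y for i, y in enumerate(y_trials) if i != test])
        w = np.linalg.solve(
            X_train.T @ X_train + lam * np.eye(X_train.shape[1]),
            X_train.T @ y_train,
        )
        prediction = X_trials[test] @ w
        scores.append(np.corrcoef(prediction.ravel(), y_trials[test].ravel())[0, 1])
    return float(np.mean(scores))

# Synthetic trials (assumed sizes): 10 trials, 100 samples, 16 features
rng = np.random.default_rng(2)
true_weights = rng.standard_normal((16, 1))
stimulus = [rng.standard_normal((100, 16)) for _ in range(10)]
response = [s @ true_weights + 5 * rng.standard_normal((100, 1)) for s in stimulus]

candidates = np.logspace(-1, 6, 10)
scores = [cv_correlation(stimulus, response, lam) for lam in candidates]
best = candidates[int(np.argmax(scores))]
print(f"best regularization: {best:.3g}")
```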
The TRF class also implements banded ridge regression. This allows us to split our features into bands and fit a separate regularization parameter to each band. When using this method, you need to define the bands as an argument of the fit method. For example, we could fit the regularization to the first and second half of the spectrogram separately (this is just for demonstration purposes; you would not actually do this). Note that the computational cost increases exponentially with the number of bands because the total number of iterations is given by $n_{regularization}^{n_{bands}}$.
trf = TRF(direction=1, method='banded') # create a forward TRF with banded ridge regression
bands = [8, 8] # first and second half of the spectrogram
regularization = np.logspace(-1, 6, 5) # only 5 values to reduce computation time
correlation, error = trf.fit(stimulus, response, samplerate, tmin=0, tmax=0.3, regularization=regularization, bands=bands)
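The exponential growth in the number of iterations can be checked by enumerating the combinations directly. The sketch below only counts candidate combinations and expands one of them into a per-feature penalty; how mTRFpy turns such a combination into its regularization matrix is not shown here, and the variable names are assumptions.

```python
import numpy as np
from itertools import product

regularization = np.logspace(-1, 6, 5)
bands = [8, 8]  # sizes of the feature bands

# One candidate value per band: every combination must be evaluated
combinations = list(product(regularization, repeat=len(bands)))
print(len(combinations))  # 25, i.e. 5 ** 2

# Expand one combination into one penalty value per feature
penalty = np.repeat(combinations[1], bands)
print(penalty.shape)  # (16,)

# Adding a third band multiplies the work by the number of candidates again
print(len(regularization) ** 3)  # 125
```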
Note that fitting the regularization on the data that the model is being tested on also constitutes a (less severe) form of overfitting. To avoid this, you should test the final model on data that was withheld from fitting.
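One way to withhold data is to set aside a few trials before fitting. The following is a minimal sketch under the assumption that stimulus and response are lists with one array per trial; `split_trials` is a hypothetical helper, not part of mTRFpy, and the dummy arrays only serve to show the bookkeeping.

```python
import numpy as np

def split_trials(stimulus, response, n_test, seed=0):
    """Randomly set aside n_test trials; return (fit, test) pairs of lists."""
    order = np.random.default_rng(seed).permutation(len(stimulus))
    test_idx, fit_idx = order[:n_test], order[n_test:]
    fit = ([stimulus[i] for i in fit_idx], [response[i] for i in fit_idx])
    test = ([stimulus[i] for i in test_idx], [response[i] for i in test_idx])
    return fit, test

# Dummy trials filled with their own index so the split can be inspected
stimulus = [np.full((50, 16), i, dtype=float) for i in range(10)]
response = [np.full((50, 128), i, dtype=float) for i in range(10)]

(fit_stim, fit_resp), (test_stim, test_resp) = split_trials(stimulus, response, n_test=2)
print(len(fit_stim), len(test_stim))  # 8 2
```

The fit lists would then be passed to trf.fit to select the regularization, and the held-out lists to trf.predict to score the final model, assuming predict accepts lists of trials in the same way fit does.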
Found a bug or missing a feature?
If you want to report a bug or request the implementation of a feature, please take a moment to review the guidelines for contributing.
License
The project is licensed under the BSD 3-Clause License.