TRF Estimator
fftrf.TRF is the main public API of the toolbox. This page explains the
shared semantics of its parameters and attributes before showing the generated
function-by-function reference.
Constructor
Create a forward model with:
model = TRF(direction=1)
Create a backward model with:
model = TRF(direction=-1)
The metric argument controls how predictions are scored whenever you call
predict(...) with observed targets or score(...). It is also the criterion
used to pick the best regularization value during cross-validation. It does not
change the actual fitting objective, which remains ridge-regularized
frequency-domain TRF estimation.
The same metric also defines the observed and surrogate scores returned by
permutation_test(...) and refit_permutation_test(...).
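The exact signature a custom metric must accept is not spelled out on this page. As a rough sketch only, the code below assumes a metric that takes (y_true, y_pred) arrays shaped (n_samples, n_outputs) and returns one Pearson r per output, plus a hypothetical reduce_scores helper that mimics the average argument. Both names are illustrative, and treating a Sequence[int] as the list of output indices to average is an assumption borrowed from mTRF-style APIs, not confirmed fftrf behavior.

```python
import numpy as np


def pearson_per_output(y_true, y_pred):
    """Column-wise Pearson r for arrays shaped (n_samples, n_outputs).

    Hypothetical metric: the (y_true, y_pred) signature is an assumption.
    """
    yt = y_true - y_true.mean(axis=0)
    yp = y_pred - y_pred.mean(axis=0)
    num = (yt * yp).sum(axis=0)
    den = np.sqrt((yt ** 2).sum(axis=0) * (yp ** 2).sum(axis=0))
    return num / den  # one score per output channel; larger is better


def reduce_scores(scores, average=True):
    """Hypothetical reduction mirroring the `average` argument."""
    if average is True:
        return float(np.mean(scores))  # single aggregated score
    if average is False:
        return scores  # keep one score per output
    # assumed: a sequence of ints selects which outputs to average
    return float(np.mean(scores[list(average)]))


rng = np.random.default_rng(0)
y_true = rng.standard_normal((500, 3))
y_pred = 2.0 * y_true + 1.0  # rescaled and shifted, so r = 1 per output
y_pred[:, 2] *= -1.0         # flip the sign of the third output, so r = -1
scores = pearson_per_output(y_true, y_pred)
print(scores, reduce_scores(scores, True), reduce_scores(scores, [0, 1]))
```

Because Pearson r ignores scale and offset, a prediction can score perfectly while having the wrong amplitude; that is one reason to inspect the kernel itself as well as the score.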
Common Parameter Meanings
These parameters appear repeatedly across the API:
- stimulus, response: one trial as a NumPy array or multiple trials as a list of arrays
- fs: sampling rate in Hz used to interpret lags and frequency bins
- tmin, tmax: lag window in seconds for extracting the time-domain kernel
- regularization: ridge value or candidate grid; can also describe banded regularization
- bands: contiguous feature-group sizes for grouped ridge penalties
- segment_length: segment size in samples for spectral estimation
- segment_duration: segment size in seconds; a friendlier alternative to segment_length
- overlap: fractional overlap between neighboring segments
- n_fft: FFT size used when constructing sufficient statistics
- spectral_method: "standard" or "multitaper"
- time_bandwidth, n_tapers: DPSS settings used in multi-taper mode
- window: optional window applied before the FFT in standard mode
- detrend: optional per-segment detrending
- k: number of cross-validation folds, or "loo" for leave-one-out over trials
- average: how output-channel scores are reduced
- trial_weights: optional weighting over trials during aggregation
- input_index, output_index: which predictor/target pair to inspect or plot
If you want a first guess for the segment-related arguments, see
suggest_segment_settings and the
Choosing Segment Settings guide.
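For intuition about how the segment arguments interact, here is a small hypothetical helper (segment_plan is not part of fftrf) that converts segment_duration to samples, applies a fractional overlap, and counts the resulting segments. It assumes n_fft equals the segment length, so the spacing between frequency bins is fs / segment_length; the toolbox's own rounding and padding rules may differ.

```python
def segment_plan(n_samples, fs, segment_duration, overlap=0.0):
    """Hypothetical helper: convert seconds to samples and count segments.

    Assumes n_fft equals the segment length, so bin spacing is fs / segment_length.
    """
    segment_length = int(round(segment_duration * fs))
    step = segment_length - int(round(overlap * segment_length))  # hop between starts
    if not 0 < segment_length <= n_samples or step <= 0:
        raise ValueError("segment must fit in the trial and overlap must be < 1")
    n_segments = 1 + (n_samples - segment_length) // step
    frequency_resolution = fs / segment_length  # Hz between neighboring bins
    return segment_length, step, n_segments, frequency_resolution


# 100 s of data at 100 Hz, 2 s segments with 50 % overlap
print(segment_plan(n_samples=10_000, fs=100.0, segment_duration=2.0, overlap=0.5))
# (200, 100, 99, 0.5)
```

The trade-off is visible in the output: longer segments give finer frequency resolution (and can represent longer lags), while shorter segments give more segments to average over.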
What Training Stores
After a successful fit, the estimator keeps enough state to support prediction, plotting, and diagnostics without re-running the fit:
- transfer_function: complex frequency-domain solution
- frequencies: frequency axis in Hz
- weights: lag-domain kernel
- times: lag axis in seconds
- regularization: chosen scalar ridge or banded tuple
- regularization_candidates: evaluated grid, if applicable
- segment_length, segment_duration, n_fft, overlap: spectral settings
- spectral_method, time_bandwidth, n_tapers, window, detrend: estimation settings
- bootstrap_interval, bootstrap_level, bootstrap_samples: uncertainty information when bootstrap intervals are computed
Reading the Generated API
The generated reference below is the authoritative source for signatures, arguments, return values, and stored attributes. Use the guides for conceptual advice and the reference for exact behavior.
fftrf.TRF
Estimate stimulus-response mappings in the frequency domain.
TRF is the main estimator of this toolbox. Its public API is
intentionally close to mTRF:
- call train to fit the model
- call predict to generate predicted responses or stimuli
- call score to evaluate predictions
- call permutation_test to assess held-out prediction scores against surrogate nulls
- call plot to visualize the fitted kernel
- call plot_grid to visualize all input/output kernels at once
- call frequency_resolved_weights or plot_frequency_resolved_weights for a spectrogram-like kernel view
- call time_frequency_power or plot_time_frequency_power for a smoothed spectrogram-like power view of the kernel
- call plot_transfer_function to inspect magnitude, phase, or group delay
- call cross_spectral_diagnostics, plot_coherence, and plot_cross_spectrum for spectral prediction diagnostics
- inspect weights and times as the time-domain kernel
Unlike a classic time-domain TRF, the fit is performed through ridge-regularized spectral deconvolution. This is often attractive for high-rate continuous data where explicitly building large lag matrices is cumbersome. When multiple regularization values are supplied, the estimator caches per-trial spectral statistics so cross-validation can reuse them instead of recomputing FFTs for every fold and candidate value. In direct single-lambda fits, the estimator automatically uses an aggregated lower-memory spectral path because no per-trial cache is needed.
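The sketch below illustrates the idea for the simplest case of one input and one output. It assumes each trial is treated as a single segment (no window, no detrending, n_fft equal to the trial length) and uses the ridge solution H(f) = S_xy(f) / (S_xx(f) + λ). Per-trial auto- and cross-spectra are computed once and then reused for every leave-one-trial-out fold and every candidate λ, which is the kind of reuse described above. All function names are hypothetical, and this is not fftrf's implementation, which also handles segmentation, multi-taper estimation, multiple features, and banded penalties.

```python
import numpy as np


def trial_spectra(x, y):
    """Per-trial auto-spectrum S_xx and cross-spectrum S_xy (whole trial = one segment)."""
    X, Y = np.fft.rfft(x), np.fft.rfft(y)
    return np.abs(X) ** 2, np.conj(X) * Y


def solve(sxx, sxy, lam):
    """Ridge-regularized spectral deconvolution: H(f) = S_xy(f) / (S_xx(f) + lam)."""
    return sxy / (sxx + lam)


def extract_kernel(H, n, fs, tmin, tmax):
    """Inverse FFT of H, then keep lags tmin..tmax (negative lags wrap around)."""
    h = np.fft.irfft(H, n)
    lags = np.arange(int(round(tmin * fs)), int(round(tmax * fs)) + 1)
    return h[lags % n], lags / fs


def loo_select(trials, stats, lams):
    """Leave-one-trial-out choice of lambda; training statistics come from the cache."""
    n = trials[0][0].size
    mean_r = []
    for lam in lams:
        rs = []
        for k, (x, y) in enumerate(trials):
            sxx = sum(s[0] for i, s in enumerate(stats) if i != k)
            sxy = sum(s[1] for i, s in enumerate(stats) if i != k)
            y_hat = np.fft.irfft(np.fft.rfft(x) * solve(sxx, sxy, lam), n)
            rs.append(np.corrcoef(y, y_hat)[0, 1])
        mean_r.append(np.mean(rs))
    return lams[int(np.argmax(mean_r))], mean_r


# Simulated data: 4 trials where y is x circularly convolved with a known kernel.
rng = np.random.default_rng(0)
fs, n = 100.0, 512
true_h = np.zeros(n)
true_h[1:5] = [1.0, 0.5, -0.3, 0.1]  # non-zero at lags of 10 to 40 ms
trials = []
for _ in range(4):
    x = rng.standard_normal(n)
    trials.append((x, np.fft.irfft(np.fft.rfft(x) * np.fft.rfft(true_h), n)))

stats = [trial_spectra(x, y) for x, y in trials]  # computed once, reused below
best_lam, mean_r = loo_select(trials, stats, [1e-3, 1e1, 1e5])
sxx = sum(s[0] for s in stats)
sxy = sum(s[1] for s in stats)
weights, times = extract_kernel(solve(sxx, sxy, best_lam), n, fs, tmin=-0.02, tmax=0.05)
print(best_lam, np.round(weights, 3))
```

With a single λ there is nothing to cross-validate, so the per-trial list is unnecessary: S_xx and S_xy can simply be accumulated trial by trial, which is the lower-memory path mentioned above.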
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| direction | int | Modeling direction. Use 1 for a forward model and -1 for a backward model. | 1 |
| metric | MetricSpec | Callable or built-in metric name used to score predictions. It must accept … | pearsonr |
Attributes:
| Name | Type | Description |
|---|---|---|
| transfer_function | ndarray \| None | Complex-valued frequency-domain mapping with shape … |
| frequencies | ndarray \| None | Frequency vector in Hz corresponding to transfer_function. |
| weights | ndarray \| None | Time-domain kernel extracted from transfer_function. |
| times | ndarray \| None | Lag values in seconds corresponding to the second axis of weights. |
| regularization | RegularizationSpec \| None | Selected ridge parameter. In ordinary ridge mode this is a scalar. When banded regularization is used, it is a tuple of per-band coefficients. |
| bands | tuple[int, ...] \| None | Optional contiguous feature-group definition used for banded regularization. |
| feature_regularization | ndarray \| None | Expanded per-feature penalty vector actually used by the spectral solver. This is especially useful when banded regularization is active. |
| regularization_candidates | list[RegularizationSpec] \| None | Candidate ridge values or banded coefficient tuples evaluated during training. This lets cross-validation scores be mapped back to the tested values. |
| fs | float \| None | Sampling rate used during fitting. |
| segment_duration | float \| None | Segment length expressed in seconds. This mirrors segment_length. |
| bootstrap_interval | ndarray \| None | Optional trial-bootstrap confidence interval with shape … |
| bootstrap_level | float \| None | Confidence level used for bootstrap_interval. |
| spectral_method | SpectralMethod | Spectral estimator used during fitting. |
| time_bandwidth, n_tapers | | Multi-taper settings stored for fitted models that use spectral_method="multitaper". |
Examples:
>>> import numpy as np
>>> from fftrf import TRF
>>> x = np.random.randn(2000, 1)
>>> y = np.random.randn(2000, 1)
>>> model = TRF(direction=1)
>>> model.train(x, y, fs=1000, tmin=0.0, tmax=0.03, regularization=1e-3)
>>> prediction = model.predict(stimulus=x)
bootstrap_confidence_interval(stimulus, response, *, n_bootstraps=200, level=0.95, seed=None, n_jobs=1, trial_weights=_USE_STORED_TRIAL_WEIGHTS)
Estimate and store a trial-bootstrap confidence interval.
The estimator must already be fitted. By default the method reuses the
same fit settings and the same trial-weighting strategy as the model.
Bootstrap resampling is performed over trials, so at least two trials
are required. n_jobs controls optional parallel execution across
bootstrap resamples.
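The resampling idea can be sketched for one input and one output under the same whole-trial, single-segment simplification used for a basic spectral ridge fit. The hypothetical function below caches per-trial spectra once, resamples trial indices with replacement, refits the kernel for each replicate, and takes percentile bounds. Percentile intervals, a fixed λ, and equal trial weights are assumptions of this sketch; fftrf's own interval construction and trial_weights handling are not shown here.

```python
import numpy as np


def bootstrap_kernel_interval(X, Y, lam, lags, n_bootstraps=200, level=0.95, seed=None):
    """Percentile interval of a ridge kernel over trial resamples.

    X, Y: arrays shaped (n_trials, n_samples) for one input and one output.
    """
    n_trials, n = X.shape
    if n_trials < 2:
        raise ValueError("a trial bootstrap needs at least two trials")
    rng = np.random.default_rng(seed)
    FX, FY = np.fft.rfft(X, axis=1), np.fft.rfft(Y, axis=1)
    sxx = np.abs(FX) ** 2   # per-trial auto-spectra, computed once
    sxy = np.conj(FX) * FY  # per-trial cross-spectra, computed once
    kernels = np.empty((n_bootstraps, lags.size))
    for b in range(n_bootstraps):
        idx = rng.integers(0, n_trials, size=n_trials)  # resample trials with replacement
        H = sxy[idx].sum(axis=0) / (sxx[idx].sum(axis=0) + lam)
        kernels[b] = np.fft.irfft(H, n)[lags % n]
    alpha = (1.0 - level) / 2.0
    lower, upper = np.quantile(kernels, [alpha, 1.0 - alpha], axis=0)
    return np.stack([lower, upper])  # shape (2, n_lags)


rng = np.random.default_rng(1)
X = rng.standard_normal((6, 256))
h = np.zeros(256)
h[2] = 1.0  # pure 2-sample delay
Y = np.fft.irfft(np.fft.rfft(X, axis=1) * np.fft.rfft(h), 256, axis=1)
Y += 0.5 * rng.standard_normal(Y.shape)  # measurement noise
interval = bootstrap_kernel_interval(X, Y, lam=1.0, lags=np.arange(0, 5), seed=0)
print(interval.shape)  # (2, 5)
```

Resampling whole trials, rather than individual samples, keeps the temporal autocorrelation inside each trial intact, which is why at least two trials are required.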
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stimulus | ndarray \| Sequence[ndarray] | Trial data used for the bootstrap resampling. It follows the same single-trial vs multi-trial conventions as train. | required |
| response | ndarray \| Sequence[ndarray] | Trial data used for the bootstrap resampling. It follows the same single-trial vs multi-trial conventions as train. | required |
| n_bootstraps | int | Number of bootstrap resamples. Larger values yield a more stable interval estimate but increase runtime. | 200 |
| level | float | Confidence level between 0 and 1. For example, 0.95 requests a 95% interval. | 0.95 |
| seed | int \| None | Optional random seed for reproducible bootstrap resampling. | None |
| n_jobs | int \| None | Number of worker threads used to evaluate bootstrap resamples. | 1 |
| trial_weights | None \| str \| Sequence[float] \| object | Trial-weighting strategy used while aggregating spectra inside each bootstrap replicate. The default reuses the model's stored weighting behavior. | _USE_STORED_TRIAL_WEIGHTS |
Returns:
| Type | Description |
|---|---|
| interval, times | Stored confidence interval with shape … |
bootstrap_interval_at(*, tmin=None, tmax=None)
Return the stored bootstrap interval over the requested lag window.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tmin | float \| None | Optional lag window in seconds to extract from the stored interval. If omitted, the full stored interval is returned. | None |
| tmax | float \| None | Optional lag window in seconds to extract from the stored interval. If omitted, the full stored interval is returned. | None |
Returns:
| Type | Description |
|---|---|
| interval, times | Stored bootstrap interval over the requested lag window, and the matching lag values in seconds. |
copy()
Return a copy of the estimator and all learned arrays.
Returns:
| Type | Description |
|---|---|
| TRF | Independent copy of the estimator state. NumPy arrays and other stored values are copied so that subsequent mutations on the copy do not affect the original instance. |
cross_spectral_diagnostics(*, stimulus=None, response=None, tmin=None, tmax=None, trial_weights=_USE_STORED_TRIAL_WEIGHTS)
Compute observed-vs-predicted cross-spectral diagnostics.
This method reuses the fitted kernel to generate predictions for the provided data, then compares predicted and observed targets in the frequency domain. The returned diagnostics include:
- the learned complex transfer function
- predicted and observed output spectra
- matched predicted-vs-observed cross-spectra
- magnitude-squared coherence between prediction and target
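Magnitude-squared coherence is conventionally |S_py|² / (S_pp · S_yy), with each spectrum averaged over segments. A minimal sketch, assuming prediction and target are already cut into equal-length segments with no window and no overlap; the function name is hypothetical and fftrf applies its own spectral settings.

```python
import numpy as np


def magnitude_squared_coherence(pred_segments, obs_segments):
    """Coherence between predicted and observed segments shaped (n_segments, n_samples)."""
    P = np.fft.rfft(pred_segments, axis=1)
    O = np.fft.rfft(obs_segments, axis=1)
    s_po = (np.conj(P) * O).mean(axis=0)  # segment-averaged cross-spectrum
    s_pp = (np.abs(P) ** 2).mean(axis=0)  # segment-averaged prediction auto-spectrum
    s_oo = (np.abs(O) ** 2).mean(axis=0)  # segment-averaged target auto-spectrum
    return np.abs(s_po) ** 2 / (s_pp * s_oo)


rng = np.random.default_rng(2)
obs = rng.standard_normal((50, 128))
good = magnitude_squared_coherence(obs + 0.1 * rng.standard_normal(obs.shape), obs)
bad = magnitude_squared_coherence(rng.standard_normal(obs.shape), obs)
print(good.mean(), bad.mean())  # close to 1 versus close to 1 / n_segments
```

Segment averaging is essential: with a single segment the ratio is identically 1 at every frequency, whatever the two signals are, and with few segments unrelated signals still show a positive bias of roughly 1 / n_segments.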
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stimulus | ndarray \| Sequence[ndarray] \| None | Data to evaluate. The method follows the same directional conventions as predict. | None |
| response | ndarray \| Sequence[ndarray] \| None | Data to evaluate. The method follows the same directional conventions as predict. | None |
| tmin | float \| None | Optional lag window used when generating predictions for the diagnostics. If omitted, the fitted lag window is reused. | None |
| tmax | float \| None | Optional lag window used when generating predictions for the diagnostics. If omitted, the fitted lag window is reused. | None |
| trial_weights | None \| str \| Sequence[float] \| object | Trial weights used when aggregating the diagnostic spectra. The default sentinel value means "reuse the weighting strategy stored on the model". You can also pass … | _USE_STORED_TRIAL_WEIGHTS |
Returns:
| Type | Description |
|---|---|
| TRFDiagnostics | Container holding transfer-function slices, spectra, cross-spectra, and coherence. |
Notes
These diagnostics answer a different question than the lag-domain kernel plots: they show whether the model captures the spectral content of the observed targets, output by output, under the same spectral estimation settings used during fitting.
diagnostics(*, stimulus=None, response=None, tmin=None, tmax=None, trial_weights=_USE_STORED_TRIAL_WEIGHTS)
Compatibility alias for cross_spectral_diagnostics.
frequency_resolved_weights(*, n_bands=24, fmin=None, fmax=None, tmin=None, tmax=None, scale='linear', bandwidth=None, value_mode='real')
Return a spectrotemporal decomposition of the fitted kernel.
The learned transfer function is partitioned into smooth frequency
bands, and each band is transformed back into the lag domain. This
yields a frequency-by-lag representation that can be plotted like a
spectrogram. In the default value_mode="real" setting, summing the
returned weights across the band axis reconstructs the ordinary
time-domain kernel, provided the full fitted frequency range is used.
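One way to obtain that reconstruction property is to normalize a Gaussian filter bank so the filters sum to one at every frequency bin; because the inverse FFT is linear, the band-limited kernels then add back up to the full kernel. The sketch below assumes linear band spacing and treats bandwidth as the Gaussian standard deviation in Hz. The normalization and the helper names are assumptions for illustration, not a description of fftrf's internals.

```python
import numpy as np


def gaussian_filter_bank(freqs, n_bands=24, bandwidth=None):
    """Linearly spaced Gaussian bands normalized to sum to one at every frequency."""
    centers = np.linspace(freqs[0], freqs[-1], n_bands)
    if bandwidth is None:
        bandwidth = centers[1] - centers[0]  # inferred from neighboring centers
    bank = np.exp(-0.5 * ((freqs[None, :] - centers[:, None]) / bandwidth) ** 2)
    bank /= bank.sum(axis=0, keepdims=True)  # partition of unity across bands
    return centers, bank


def frequency_resolved(H, n, fs, n_bands=24, bandwidth=None):
    """Band-limited real kernels shaped (n_bands, n), one per Gaussian band."""
    freqs = np.fft.rfftfreq(n, d=1.0 / fs)
    centers, bank = gaussian_filter_bank(freqs, n_bands, bandwidth)
    return centers, np.fft.irfft(bank * H[None, :], n, axis=1)


rng = np.random.default_rng(3)
n, fs = 128, 200.0
kernel = rng.standard_normal(n)
centers, resolved = frequency_resolved(np.fft.rfft(kernel), n, fs)
print(resolved.shape, np.allclose(resolved.sum(axis=0), kernel))  # (24, 128) True
```

If fmin and fmax restrict the analysis to part of the spectrum, the bands no longer cover every bin, which is why the reconstruction only holds over the full fitted range.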
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_bands | int | Number of analysis bands used for the decomposition. | 24 |
| fmin | float \| None | Frequency range in Hz to analyze. The default covers the full fitted range from DC to Nyquist. | None |
| fmax | float \| None | Frequency range in Hz to analyze. The default covers the full fitted range from DC to Nyquist. | None |
| tmin | float \| None | Optional lag window to extract. If omitted, the fitted lag window is reused. | None |
| tmax | float \| None | Optional lag window to extract. If omitted, the fitted lag window is reused. | None |
| scale | str | Spacing of the band centers. Use … | 'linear' |
| bandwidth | float \| None | Gaussian band width in Hz. When omitted, it is inferred from the spacing between neighboring band centers. | None |
| value_mode | str | Representation of the band-limited kernels. The default 'real' returns signed kernels that sum across bands to the ordinary time-domain kernel when the full fitted frequency range is used. | 'real' |
Returns:
| Type | Description |
|---|---|
| FrequencyResolvedWeights | Container holding the filter bank, lag axis, and resolved kernel tensor with shape … |
load(path)
Load estimator state from a pickle file into this instance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str \| Path | Pickle file previously created by save. | required |
Notes
The current instance is updated in place. After loading, compatibility checks fill in any newer attributes that may not have existed in older saved files.
permutation_test(stimulus=None, response=None, *, n_permutations=1000, average=True, tmin=None, tmax=None, surrogate='circular_shift', min_shift=None, tail='greater', seed=None, n_jobs=1)
Estimate score significance against a surrogate null distribution.
This method evaluates a fitted model on aligned data, then compares the observed prediction score against surrogate scores obtained from the same predictions and a permuted target side. It answers the question "is this held-out score larger than would be expected under a null alignment?" rather than the stronger and slower "would the entire training pipeline beat chance if retrained on permuted data?".
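This cheaper null can be sketched for a single output as follows: predictions stay fixed, and only the observed target is circularly shifted. In this hypothetical function, treating min_shift as a minimum distance from zero shift on both sides, the +1 correction in the p-value, and the sample standard deviation in the z-score are assumptions of the sketch rather than documented fftrf formulas.

```python
import numpy as np


def circular_shift_test(pred, obs, fs, n_permutations=1000, min_shift=None, seed=None):
    """Pearson r of aligned data versus circularly shifted targets (one output)."""
    rng = np.random.default_rng(seed)
    n = obs.size
    lo = 1 if min_shift is None else max(1, int(round(min_shift * fs)))
    if lo > n - lo:
        raise ValueError("min_shift is too large for this signal length")
    observed = np.corrcoef(pred, obs)[0, 1]
    shifts = rng.integers(lo, n - lo + 1, size=n_permutations)  # shifts in [lo, n - lo]
    null = np.array([np.corrcoef(pred, np.roll(obs, s))[0, 1] for s in shifts])
    # one-sided ("greater") p-value with the common +1 correction (assumed convention)
    p_value = (1 + np.sum(null >= observed)) / (n_permutations + 1)
    z_score = (observed - null.mean()) / null.std(ddof=1)
    return observed, null, p_value, z_score


rng = np.random.default_rng(4)
obs = rng.standard_normal(2000)
pred = obs + 0.5 * rng.standard_normal(2000)  # a decent but imperfect prediction
observed, null, p, z = circular_shift_test(pred, obs, fs=100.0, n_permutations=200,
                                           min_shift=1.0, seed=0)
print(round(observed, 2), p)  # p is 1/201 when no surrogate reaches the observed score
```

Circular shifts preserve each signal's own spectrum and autocorrelation while breaking their alignment, which makes them a more conservative null than shuffling individual samples.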
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stimulus | ndarray \| Sequence[ndarray] \| None | Evaluation data using the same directional conventions as predict. | None |
| response | ndarray \| Sequence[ndarray] \| None | Evaluation data using the same directional conventions as predict. | None |
| n_permutations | int | Number of surrogate scores used to form the null distribution. | 1000 |
| average | bool \| Sequence[int] | Score reduction applied to the observed and surrogate scores. This follows the same rules as score. | True |
| tmin | float \| None | Optional lag window used during prediction before scoring. | None |
| tmax | float \| None | Optional lag window used during prediction before scoring. | None |
| surrogate | str | Strategy used to break the alignment between predictions and observed targets. | 'circular_shift' |
| min_shift | float \| None | Minimum circular shift, in seconds, used when surrogate='circular_shift'. | None |
| tail | str | Tail convention for the p-value calculation: … | 'greater' |
| seed | int \| None | Optional random seed for reproducible surrogate generation. | None |
| n_jobs | int \| None | Number of worker threads used to score the surrogate targets. | 1 |
Returns:
| Type | Description |
|---|---|
| PermutationTestResult | Container with the observed score, surrogate null scores, p-value, and z-score. |
Notes
All built-in metrics in ffTRF use the "larger is better" convention,
so tail="greater" is the natural default. The returned null scores
follow the same scalar-vs-array contract as score: aggregated
scores are stored as a 1D array of length n_permutations, while
average=False keeps one surrogate score per output.
plot(*, input_index=0, output_index=0, tmin=None, tmax=None, ax=None, time_unit='ms', color=None, linewidth=2.0, show_bootstrap_interval=False, interval_color=None, interval_alpha=0.2, title=None, label=None)
Plot one fitted time-domain kernel.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_index | int | Select which input/output pair should be plotted. | 0 |
| output_index | int | Select which input/output pair should be plotted. | 0 |
| tmin | float \| None | Optional lag window to visualize. If omitted, the fitted window is used. This is applied by re-extracting the impulse response from the stored transfer function. | None |
| tmax | float \| None | Optional lag window to visualize. If omitted, the fitted window is used. This is applied by re-extracting the impulse response from the stored transfer function. | None |
| ax | | Existing matplotlib axes. When omitted, a new figure and axes are created. | None |
| time_unit | str | Either … | 'ms' |
| color | str \| None | Standard matplotlib styling argument for the kernel line. | None |
| linewidth | float | Standard matplotlib styling argument for the kernel line. | 2.0 |
| title | str \| None | Optional plot title. | None |
| label | str \| None | Optional legend label for the kernel line. | None |
| show_bootstrap_interval | bool | If True, shade the stored bootstrap interval around the kernel. | False |
| interval_color | str \| None | Styling for the bootstrap interval shading. | None |
| interval_alpha | float | Styling for the bootstrap interval shading. | 0.2 |
Returns:
| Type | Description |
|---|---|
| fig, ax | The matplotlib figure and axes containing the plot. |
Notes
This helper is a thin wrapper around matplotlib. It imports the
plotting backend lazily so that the core toolbox stays usable without
installing plotting dependencies.
plot_coherence(*, stimulus=None, response=None, diagnostics=None, output_index=0, ax=None, color=None, linewidth=2.0, title=None)
Plot magnitude-squared coherence between predictions and targets.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stimulus | ndarray \| Sequence[ndarray] \| None | Evaluation data used when diagnostics is not supplied. | None |
| response | ndarray \| Sequence[ndarray] \| None | Evaluation data used when diagnostics is not supplied. | None |
| diagnostics | TRFDiagnostics \| None | Precomputed diagnostics from cross_spectral_diagnostics. | None |
| output_index | int | Target output channel to display. | 0 |
| ax | | Existing matplotlib axes. When omitted, a new figure is created. | None |
| color | str \| None | Standard matplotlib styling option for the coherence curve. | None |
| linewidth | float | Standard matplotlib styling option for the coherence curve. | 2.0 |
| title | str \| None | Optional plot title. | None |
Returns:
| Type | Description |
|---|---|
| fig, ax | The matplotlib figure and axes containing the coherence plot. |
plot_cross_spectrum(*, stimulus=None, response=None, diagnostics=None, output_index=0, kind='both', ax=None, color=None, phase_color=None, linewidth=2.0, phase_unit='rad', title=None)
Plot the predicted-vs-observed cross spectrum for one output.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stimulus | ndarray \| Sequence[ndarray] \| None | Evaluation data used when diagnostics is not supplied. | None |
| response | ndarray \| Sequence[ndarray] \| None | Evaluation data used when diagnostics is not supplied. | None |
| diagnostics | TRFDiagnostics \| None | Precomputed diagnostics from cross_spectral_diagnostics. | None |
| output_index | int | Target output channel to display. | 0 |
| kind | str | Which quantity to plot. Use … | 'both' |
| ax | | Existing matplotlib axes. For … | None |
| color | str \| None | Standard plotting option passed through to the underlying matplotlib helper. | None |
| phase_color | str \| None | Standard plotting option passed through to the underlying matplotlib helper. | None |
| linewidth | float | Standard plotting option passed through to the underlying matplotlib helper. | 2.0 |
| phase_unit | str | Unit for plotted phase values. | 'rad' |
| title | str \| None | Optional plot title. | None |
Returns:
| Type | Description |
|---|---|
| fig, ax | The matplotlib figure and axes containing the requested view. |
plot_frequency_resolved_weights(*, resolved=None, input_index=0, output_index=0, n_bands=24, fmin=None, fmax=None, tmin=None, tmax=None, scale='linear', bandwidth=None, value_mode='real', ax=None, time_unit='ms', cmap=None, colorbar=True, title=None, vmin=None, vmax=None, frequency_axis_scale=None)
Plot one frequency-resolved kernel map as a heatmap.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
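| resolved | FrequencyResolvedWeights \| None | Precomputed output of frequency_resolved_weights. | None |
| input_index | int | Select which input/output pair to display. | 0 |
| output_index | int | Select which input/output pair to display. | 0 |
| n_bands | int | Used when resolved is not supplied; same meaning as in frequency_resolved_weights. | 24 |
| fmin | float \| None | Used when resolved is not supplied; same meaning as in frequency_resolved_weights. | None |
| fmax | float \| None | Used when resolved is not supplied; same meaning as in frequency_resolved_weights. | None |
| tmin | float \| None | Used when resolved is not supplied; same meaning as in frequency_resolved_weights. | None |
| tmax | float \| None | Used when resolved is not supplied; same meaning as in frequency_resolved_weights. | None |
| scale | str | Used when resolved is not supplied; same meaning as in frequency_resolved_weights. | 'linear' |
| bandwidth | float \| None | Used when resolved is not supplied; same meaning as in frequency_resolved_weights. | None |
| value_mode | str | Used when resolved is not supplied; same meaning as in frequency_resolved_weights. | 'real' |
| ax | | Existing matplotlib axes. When omitted, a new figure is created. | None |
| time_unit | str | Either … | 'ms' |
| cmap | str \| None | Optional matplotlib colormap. When omitted, a diverging map is used for signed weights and a sequential map is used for non-negative views. | None |
| colorbar | bool | If True, a colorbar is added to the figure. | True |
| title | str \| None | Optional plot title. | None |
| vmin | float \| None | Optional explicit color limits. | None |
| vmax | float \| None | Optional explicit color limits. | None |
| frequency_axis_scale | str \| None | Optional display scaling for the frequency axis. If omitted, the scale stored in … | None |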
Returns:
| Type | Description |
|---|---|
| fig, ax | The matplotlib figure and axes containing the heatmap. |
plot_grid(*, tmin=None, tmax=None, ax=None, time_unit='ms', color=None, linewidth=1.8, show_bootstrap_interval=False, interval_color=None, interval_alpha=0.2, input_labels=None, output_labels=None, title=None, sharey=False)
Plot every input/output kernel in a grid.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tmin | float \| None | Optional lag window to visualize. If omitted, the fitted lag window is used for every panel. | None |
| tmax | float \| None | Optional lag window to visualize. If omitted, the fitted lag window is used for every panel. | None |
| ax | | Optional array of pre-created matplotlib axes with shape … | None |
| time_unit | str | Either … | 'ms' |
| color | str \| None | Styling applied to every kernel trace in the grid. | None |
| linewidth | float | Styling applied to every kernel trace in the grid. | 1.8 |
| show_bootstrap_interval | bool | If True, shade the stored bootstrap interval in every panel. | False |
| interval_color | str \| None | Styling for the interval shading. | None |
| interval_alpha | float | Styling for the interval shading. | 0.2 |
| input_labels | Sequence[str] \| None | Optional human-readable labels used to title the individual panels. When omitted, generic … | None |
| output_labels | Sequence[str] \| None | Optional human-readable labels used to title the individual panels. When omitted, generic … | None |
| title | str \| None | Optional figure-level title. | None |
| sharey | bool | If True, all panels share the same y-axis. | False |
Returns:
| Type | Description |
|---|---|
| fig, axes | The matplotlib figure and the full axes grid. |
Notes
This is convenient for multifeature or multichannel models where
calling plot repeatedly would be cumbersome.
plot_time_frequency_power(*, power=None, input_index=0, output_index=0, n_bands=24, fmin=None, fmax=None, tmin=None, tmax=None, scale='linear', bandwidth=None, method='hilbert', ax=None, time_unit='ms', cmap=None, colorbar=True, title=None, vmin=None, vmax=None, frequency_axis_scale=None)
Plot a spectrogram-like time-frequency power map.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
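| power | TimeFrequencyPower \| None | Precomputed output of time_frequency_power. | None |
| input_index | int | Select which input/output pair to display. | 0 |
| output_index | int | Select which input/output pair to display. | 0 |
| n_bands | int | Used when power is not supplied; same meaning as in time_frequency_power. | 24 |
| fmin | float \| None | Used when power is not supplied; same meaning as in time_frequency_power. | None |
| fmax | float \| None | Used when power is not supplied; same meaning as in time_frequency_power. | None |
| tmin | float \| None | Used when power is not supplied; same meaning as in time_frequency_power. | None |
| tmax | float \| None | Used when power is not supplied; same meaning as in time_frequency_power. | None |
| scale | str | Used when power is not supplied; same meaning as in time_frequency_power. | 'linear' |
| bandwidth | float \| None | Used when power is not supplied; same meaning as in time_frequency_power. | None |
| method | str | Used when power is not supplied; same meaning as in time_frequency_power. | 'hilbert' |
| ax | | Existing matplotlib axes. When omitted, a new figure is created. | None |
| time_unit | str | Either … | 'ms' |
| cmap | str \| None | Optional matplotlib colormap for the heatmap. | None |
| colorbar | bool | If True, a colorbar is added to the figure. | True |
| title | str \| None | Optional plot title. | None |
| vmin | float \| None | Optional explicit color limits. | None |
| vmax | float \| None | Optional explicit color limits. | None |
| frequency_axis_scale | str \| None | Optional display scaling for the frequency axis. If omitted, the scale stored in … | None |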
Returns:
| Type | Description |
|---|---|
| fig, ax | The matplotlib figure and axes containing the heatmap. |
plot_transfer_function(*, input_index=0, output_index=0, kind='both', ax=None, color=None, phase_color=None, group_delay_color=None, linewidth=2.0, phase_unit='rad', group_delay_unit='ms', title=None)
Plot magnitude, phase, and/or group delay of one transfer function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_index | int | Select the predictor-target pair to display. | 0 |
| output_index | int | Select the predictor-target pair to display. | 0 |
| kind | str | Which quantity to plot. Use … | 'both' |
| ax | | Existing matplotlib axes. For … | None |
| color | str \| None | Optional line color for the displayed curves. | None |
| phase_color | str \| None | Optional line color for the displayed curves. | None |
| group_delay_color | str \| None | Optional line color for the displayed curves. | None |
| linewidth | float | Line width passed to matplotlib. | 2.0 |
| phase_unit | str | Unit for plotted phase values, either … | 'rad' |
| group_delay_unit | str | Unit for plotted group delay values, either … | 'ms' |
| title | str \| None | Optional plot title. | None |
Returns:
| Type | Description |
|---|---|
| fig, ax | The matplotlib figure and axes containing the requested view. |
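Group delay is conventionally defined as τ_g(f) = -dφ/dω with ω = 2πf, where φ is the unwrapped phase of the transfer function. A minimal sketch of that definition follows; the helper name and the use of np.gradient are assumptions, and fftrf may differentiate differently. It is checked on a pure delay, whose group delay is constant and known.

```python
import numpy as np


def group_delay(H, freqs):
    """Group delay in seconds: minus the derivative of unwrapped phase w.r.t. 2*pi*f."""
    phase = np.unwrap(np.angle(H))
    return -np.gradient(phase, 2.0 * np.pi * freqs)


fs, n, tau = 1000.0, 200, 0.010
freqs = np.fft.rfftfreq(n, d=1.0 / fs)  # 0, 5, 10, ..., 500 Hz
H = np.exp(-2j * np.pi * freqs * tau)   # transfer function of a pure 10 ms delay
print(np.allclose(group_delay(H, freqs) * 1e3, 10.0))  # True (converted to ms)
```

Unwrapping only works when the true phase changes by less than π between neighboring bins, so very long delays or coarse frequency grids can make group-delay plots unreliable.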
predict(stimulus=None, response=None, *, average=True, tmin=None, tmax=None)
Generate predictions from a fitted model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stimulus | ndarray \| Sequence[ndarray] \| None | Inputs follow the same single-trial / multi-trial conventions as train. | None |
| response | ndarray \| Sequence[ndarray] \| None | Inputs follow the same single-trial / multi-trial conventions as train. | None |
| average | bool \| Sequence[int] | Reduction strategy for the returned score. | True |
| tmin | float \| None | Optional lag window used during prediction. If omitted, the fitted lag window is used. | None |
| tmax | float \| None | Optional lag window used during prediction. If omitted, the fitted lag window is used. | None |
Returns:
| Type | Description |
|---|---|
| prediction or (prediction, metric) | Predicted trials are returned in the same single-trial vs list form as the predictor input. When observed targets are supplied, the method also returns the metric defined on the estimator. The score shape is controlled by average. |
Notes
Prediction is performed by convolving the predictor with the extracted time-domain kernel over the requested lag window. This means you can evaluate alternative lag windows on a fitted model without repeating spectral training, as long as the requested window stays within what is representable by the stored transfer function.
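For one input and one output, that convolution amounts to a lagged sum over the kernel taps. The sketch below is hypothetical: it assumes lags are given in samples and that the predictor is zero outside the trial, which may not match fftrf's edge handling.

```python
import numpy as np


def predict_from_kernel(x, weights, lags):
    """y_hat[t] = sum_j weights[j] * x[t - lags[j]], treating x as zero outside the trial."""
    n = x.size
    y_hat = np.zeros(n)
    for w, lag in zip(weights, lags):
        if abs(lag) >= n:
            continue  # a lag longer than the trial contributes nothing
        if lag >= 0:
            y_hat[lag:] += w * x[: n - lag]   # past input drives the current output
        else:
            y_hat[: n + lag] += w * x[-lag:]  # negative lag: future input (acausal part)
    return y_hat


rng = np.random.default_rng(5)
x = rng.standard_normal(100)
w = np.array([0.0, 1.0, 0.5, -0.3])
lags = np.arange(0, 4)  # tmin = 0 and tmax = 3 samples
y_hat = predict_from_kernel(x, w, lags)
print(np.allclose(y_hat, np.convolve(x, w)[: x.size]))  # True

# A narrower window (lags 1..2) simply drops terms from the sum.
y_short = predict_from_kernel(x, w[1:3], lags[1:3])
```

Because the lag window only selects which taps enter the sum, evaluating a different window on a fitted model requires re-extracting the kernel, not re-estimating the spectra.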
refit_permutation_test(*, train_stimulus, train_response, test_stimulus, test_response, n_permutations=100, average=True, surrogate='circular_shift', min_shift=None, tail='greater', seed=None, n_jobs=1, fit_n_jobs=1, fit_kwargs=None)
Estimate a stronger null by retraining on surrogate-aligned data.
This method fits one model on the original training alignment and then fits one fresh surrogate model per permutation after breaking the predictor-target alignment on the training set. All models are scored on the same held-out aligned evaluation data.
Compared with permutation_test, this is slower but answers a
stronger question: whether the full training pipeline, including
regularization selection, outperforms chance under a surrogate null.
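The structure of a refit-based null can be sketched with a fixed λ and one input/one output. Every surrogate model is fit from scratch on training targets that were circularly shifted trial by trial, and all models are scored on the same untouched test trials. This is a simplified illustration with hypothetical helper names: the real method reuses the stored fit configuration and can therefore reselect regularization inside each refit, which this sketch omits.

```python
import numpy as np


def fit_transfer_function(X, Y, lam):
    """Ridge spectral fit pooled over trials; X, Y shaped (n_trials, n_samples)."""
    FX, FY = np.fft.rfft(X, axis=1), np.fft.rfft(Y, axis=1)
    return (np.conj(FX) * FY).sum(axis=0) / ((np.abs(FX) ** 2).sum(axis=0) + lam)


def mean_r(H, X, Y):
    """Mean per-trial Pearson r of predictions made with H (circular within each trial)."""
    pred = np.fft.irfft(np.fft.rfft(X, axis=1) * H, X.shape[1], axis=1)
    return float(np.mean([np.corrcoef(p, y)[0, 1] for p, y in zip(pred, Y)]))


def refit_null(X_tr, Y_tr, X_te, Y_te, lam, n_permutations=100, seed=None):
    rng = np.random.default_rng(seed)
    n = X_tr.shape[1]
    observed = mean_r(fit_transfer_function(X_tr, Y_tr, lam), X_te, Y_te)
    null = np.empty(n_permutations)
    for i in range(n_permutations):
        shifts = rng.integers(1, n, size=len(Y_tr))  # independent shift per training trial
        Y_sur = np.stack([np.roll(y, s) for y, s in zip(Y_tr, shifts)])
        # a fresh fit on misaligned training data, scored on the untouched test set
        null[i] = mean_r(fit_transfer_function(X_tr, Y_sur, lam), X_te, Y_te)
    p_value = (1 + np.sum(null >= observed)) / (n_permutations + 1)
    return observed, null, p_value


rng = np.random.default_rng(6)
h = np.zeros(256)
h[1:4] = [1.0, 0.5, -0.3]
X = rng.standard_normal((6, 256))
Y = np.fft.irfft(np.fft.rfft(X, axis=1) * np.fft.rfft(h), 256, axis=1)
Y += 0.3 * rng.standard_normal(Y.shape)
observed, null, p = refit_null(X[:4], Y[:4], X[4:], Y[4:], lam=10.0,
                               n_permutations=50, seed=0)
print(round(observed, 2), p)
```

The cost grows linearly with n_permutations because each surrogate requires a full fit, which is why the default number of permutations here is lower than for permutation_test.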
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| train_stimulus | ndarray \| Sequence[ndarray] | Training data used to fit the observed and surrogate models. | required |
| train_response | ndarray \| Sequence[ndarray] | Training data used to fit the observed and surrogate models. | required |
| test_stimulus | ndarray \| Sequence[ndarray] | Held-out aligned evaluation data used to score every fitted model. | required |
| test_response | ndarray \| Sequence[ndarray] | Held-out aligned evaluation data used to score every fitted model. | required |
| n_permutations | int | Number of surrogate refits used to form the null distribution. | 100 |
| average | bool \| Sequence[int] | Score reduction applied to the observed and surrogate scores. This follows the same rules as score. | True |
| surrogate | str | Strategy used to break the training alignment. | 'circular_shift' |
| min_shift | float \| None | Minimum circular shift, in seconds, used when surrogate='circular_shift'. | None |
| tail | str | Tail convention for the p-value calculation: … | 'greater' |
| seed | int \| None | Optional random seed for reproducible surrogate generation. | None |
| n_jobs | int \| None | Number of worker threads used across surrogate refits. | 1 |
| fit_n_jobs | int \| None | Number of worker threads used inside each individual refit. The default … | 1 |
| fit_kwargs | dict[str, object] \| None | Optional overrides for the stored fit configuration. When omitted, the most recent training configuration of this estimator is reused, except that bootstrap estimation is disabled and progress output is suppressed during the surrogate refits. | None |
Returns:
| Type | Description |
|---|---|
| PermutationTestResult | Container with the observed held-out score, surrogate null scores, p-value, and z-score. |
save(path)
Serialize the estimator to disk using pickle.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str \| Path | Destination file. Parent directories must already exist. | required |
Notes
The entire estimator instance is serialized, including fitted weights, spectral settings, chosen regularization, bootstrap intervals, and the configured scoring metric. Pickle files should only be loaded from trusted sources.
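The save/load round trip can be illustrated with a hypothetical stand-in class. This sketch pickles the instance's attribute dictionary rather than the object itself (so loading does not depend on where the class is defined), updates the target instance in place, and backfills an attribute that an older file might lack. Those choices are assumptions for illustration; fftrf's own save serializes the whole estimator instance.

```python
import pickle
import tempfile
from pathlib import Path


class TinyEstimator:
    """Hypothetical stand-in used only to illustrate a pickle round trip."""

    def __init__(self):
        self.weights = None
        self.metric = "pearsonr"

    def save(self, path):
        with open(path, "wb") as f:
            pickle.dump(self.__dict__, f)  # plain dict of attributes

    def load(self, path):
        with open(path, "rb") as f:
            state = pickle.load(f)
        self.__dict__.update(state)  # update this instance in place
        # backfill an attribute that older files may not contain (illustrative name)
        self.__dict__.setdefault("feature_regularization", None)
        return self


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model.pkl"
    original = TinyEstimator()
    original.weights = [0.0, 1.0, 0.5]
    original.save(path)
    restored = TinyEstimator().load(path)

print(restored.weights, restored.feature_regularization)  # [0.0, 1.0, 0.5] None
```

The trusted-source warning applies to any pickle: unpickling can execute arbitrary code, regardless of what was pickled.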
score(stimulus=None, response=None, *, average=True, tmin=None, tmax=None)
Score predictions without returning the predicted signals.
This is a convenience wrapper around predict for workflows where
only the metric is needed. Unlike predict, this method always
requires the observed targets, because it returns only the metric
and never the predicted signals.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stimulus | ndarray \| Sequence[ndarray] \| None | Identical to predict. | None |
| response | ndarray \| Sequence[ndarray] \| None | Identical to predict. | None |
| average | bool \| Sequence[int] | Identical to predict. | True |
| tmin | float \| None | Identical to predict. | None |
| tmax | float \| None | Identical to predict. | None |
Returns:
| Type | Description |
|---|---|
| ndarray or float | Prediction score computed with the estimator's configured metric. |
time_frequency_power(*, n_bands=24, fmin=None, fmax=None, tmin=None, tmax=None, scale='linear', bandwidth=None, method='hilbert')
Estimate spectrogram-like power from the fitted kernel.
This method starts from the signed band-limited kernels returned by
frequency_resolved_weights and converts each frequency band into
a smoother power representation. With the default method="hilbert",
power is the squared magnitude of the analytic signal of each
band-limited kernel. The result is closer to what users expect from a
spectrogram than simply squaring the oscillatory kernel itself.
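The analytic-signal step can be written with NumPy alone, using the same FFT weighting that scipy.signal.hilbert applies: keep the DC bin (and, for even lengths, the Nyquist bin), double the positive frequencies, and zero the negative ones. This is a sketch of the method="hilbert" idea applied to rows of band-limited kernels, not fftrf's code. The example shows why the resulting power is smoother than the squared kernel.

```python
import numpy as np


def hilbert_power(band_kernels):
    """Squared magnitude of the analytic signal along the last (lag) axis."""
    n = band_kernels.shape[-1]
    gain = np.zeros(n)  # FFT-domain weights that build the analytic signal
    gain[0] = 1.0
    if n % 2 == 0:
        gain[n // 2] = 1.0
        gain[1 : n // 2] = 2.0
    else:
        gain[1 : (n + 1) // 2] = 2.0
    analytic = np.fft.ifft(np.fft.fft(band_kernels, axis=-1) * gain, axis=-1)
    return np.abs(analytic) ** 2


t = np.arange(256)
band = 2.0 * np.cos(2 * np.pi * 8 * t / 256)  # 8 full cycles, amplitude 2
power = hilbert_power(band[None, :])
print(np.allclose(power, 4.0), np.allclose(band ** 2, 4.0))  # True False
```

The squared kernel oscillates between 0 and 4 at twice the band frequency, whereas the analytic-signal power stays at the envelope value of 4, which is the smoother, spectrogram-like behavior described above.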
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_bands | int | Number of analysis bands used to partition the fitted transfer function. | 24 |
| fmin | float \| None | Frequency range in Hz to analyze. By default the full fitted range is used. | None |
| fmax | float \| None | Frequency range in Hz to analyze. By default the full fitted range is used. | None |
| tmin | float \| None | Optional lag window in seconds to extract from the reconstructed band-limited kernels. | None |
| tmax | float \| None | Optional lag window in seconds to extract from the reconstructed band-limited kernels. | None |
| scale | str | Placement of band centers. | 'linear' |
| bandwidth | float \| None | Width of the Gaussian analysis filters in Hz. If omitted, it is inferred from neighboring band centers. | None |
| method | str | Power estimation method. Currently only 'hilbert' is supported. | 'hilbert' |
Returns:
| Type | Description |
|---|---|
| `TimeFrequencyPower` | Container holding the band centers, the lag axis, and the power tensor. |
Notes
This representation is best interpreted as a descriptive view of the fitted kernel, not as a spectrogram of the original stimulus or response recordings.
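The sketch below mirrors the description above under stated assumptions: it band-pass filters a kernel with Gaussian filters in the frequency domain, builds the analytic signal by zeroing negative frequencies, and takes its squared magnitude. The helper `band_power`, the bandwidth rule (mean spacing of band centers, used as the Gaussian standard deviation), and the toy kernel are all hypothetical; this is not fftrf's implementation.

```python
import numpy as np


def band_power(kernel, fs, centers, bandwidth=None):
    """Gaussian band-pass a kernel around each center, then take |analytic|^2.

    Hypothetical helper; returns an array of shape (n_bands, n_lags).
    """
    kernel = np.asarray(kernel, dtype=float)
    n = kernel.size
    freqs = np.fft.fftfreq(n, d=1.0 / fs)  # two-sided frequency axis in Hz
    spectrum = np.fft.fft(kernel)
    if bandwidth is None:
        # Assumption: Gaussian standard deviation = mean spacing of centers.
        bandwidth = float(np.mean(np.diff(centers)))
    # Analytic-signal mask: keep DC once, double positive bins, drop negatives.
    mask = np.zeros(n)
    mask[0] = 1.0
    mask[freqs > 0] = 2.0
    if n % 2 == 0:
        mask[n // 2] = 1.0  # Nyquist bin (fftfreq labels it as negative)
    power = np.empty((len(centers), n))
    for i, fc in enumerate(centers):
        gaussian = np.exp(-0.5 * ((np.abs(freqs) - fc) / bandwidth) ** 2)
        analytic = np.fft.ifft(spectrum * gaussian * mask)
        power[i] = np.abs(analytic) ** 2
    return power


# Toy kernel: a 10 Hz oscillation under a Gaussian envelope, lags -0.5..0.5 s.
fs = 100.0
lags = np.arange(-50, 50) / fs
kernel = np.exp(-lags**2 / (2 * 0.05**2)) * np.cos(2 * np.pi * 10.0 * lags)
centers = np.arange(2.0, 40.0, 2.0)
power = band_power(kernel, fs, centers)
```

Summing `power` over lags and taking the argmax lands on the 10 Hz band for this toy kernel, while each row stays non-negative and smooth instead of oscillating like the kernel itself.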
to_impulse_response(tmin=None, tmax=None)
Extract a time-domain kernel from the fitted transfer function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `tmin` | `float \| None` | Optional start of the lag window in seconds. If omitted, the window used during `train` is reused. | `None` |
| `tmax` | `float \| None` | Optional end of the lag window in seconds. If omitted, the window used during `train` is reused. | `None` |
Returns:
| Type | Description |
|---|---|
| `weights, times` | The time-domain kernel and its lag axis in seconds. |
Notes
This method is useful when you want to inspect a different lag window
without refitting the spectral model. It is the same operation used
internally to populate `weights` and `times` after
training, and it powers downstream helpers such as `plot`,
`predict`, and `bootstrap_interval_at`.
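A minimal sketch of the extraction step, assuming the transfer function is stored on a one-sided `rfft` grid of length `n_fft` and that negative lags wrap around to the end of the circular kernel. `impulse_response` is a hypothetical helper; fftrf's scaling and rounding conventions may differ.

```python
import numpy as np


def impulse_response(transfer_function, fs, n_fft, tmin, tmax):
    """Inverse-FFT a one-sided transfer function and cut out a lag window.

    Hypothetical helper: fftrf's scaling and rounding conventions may differ.
    """
    kernel = np.fft.irfft(transfer_function, n=n_fft)  # circular kernel
    lags = np.arange(int(np.round(tmin * fs)), int(np.round(tmax * fs)) + 1)
    # Negative lags live at the end of the circular kernel, so index modulo n_fft.
    weights = kernel[lags % n_fft]
    times = lags / fs
    return weights, times


fs, n_fft = 100.0, 64
true_values = np.array([0.1, -0.3, 1.0, 0.6, 0.2, -0.1])  # lags -2..3 samples
circular = np.zeros(n_fft)
circular[np.arange(-2, 4) % n_fft] = true_values
transfer_function = np.fft.rfft(circular)

weights, times = impulse_response(transfer_function, fs, n_fft, tmin=-0.02, tmax=0.03)
# A narrower window reuses the same transfer function; nothing is refitted.
w_causal, t_causal = impulse_response(transfer_function, fs, n_fft, tmin=0.0, tmax=0.01)
```

Changing `tmin` and `tmax` only changes which samples of the circular kernel are read, which is why a different lag window does not require another spectral fit.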
train(stimulus, response, fs, tmin, tmax, regularization, *, bands=None, segment_length=None, segment_duration=None, overlap=0.0, n_fft=None, spectral_method='standard', time_bandwidth=3.5, n_tapers=None, window=None, detrend='constant', k=-1, average=True, seed=None, show_progress=False, n_jobs=1, trial_weights=None, bootstrap_samples=0, bootstrap_level=0.95, bootstrap_seed=None)
Fit the frequency-domain TRF.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `stimulus` | `ndarray \| Sequence[ndarray]` | One trial as a 1D/2D array or multiple trials as a list of arrays. | required |
| `response` | `ndarray \| Sequence[ndarray]` | Neural data: one trial as a 1D/2D array or multiple trials as a list of arrays. | required |
| `fs` | `float` | Sampling rate in Hz shared by stimulus and response. | required |
| `tmin` | `float` | Start, in seconds, of the lag window extracted from the learned transfer function as a time-domain kernel. | required |
| `tmax` | `float` | End, in seconds, of the lag window extracted from the learned transfer function as a time-domain kernel. | required |
| `regularization` | `float \| Sequence[float] \| Sequence[Sequence[float]]` | Ridge regularization. A single value gives a standard ridge TRF fit; multiple candidate values are compared by cross-validation and the best one is used for the final fit. It can also describe banded regularization (see `bands`). | required |
| `bands` | `None \| Sequence[int]` | Optional contiguous feature-group sizes for banded ridge regularization. For example, if the predictor contains one envelope feature followed by a 16-band spectrogram, use `bands=[1, 16]`. | `None` |
| `segment_length` | `int \| None` | Segment size, in samples, used to estimate cross-spectra. | `None` |
| `segment_duration` | `float \| None` | Segment duration in seconds; a more convenient alternative to `segment_length`. | `None` |
| `overlap` | `float` | Fractional overlap between neighboring segments. | `0.0` |
| `n_fft` | `int \| None` | FFT size used for spectral estimation. If omitted, a fast FFT length is chosen automatically. | `None` |
| `spectral_method` | `SpectralMethod` | Spectral estimator used to compute the sufficient statistics: `"standard"` or `"multitaper"`. | `'standard'` |
| `time_bandwidth` | `float` | Time-bandwidth product of the DPSS tapers, used when `spectral_method="multitaper"`. | `3.5` |
| `n_tapers` | `int \| None` | Number of DPSS tapers used for multi-taper estimation. | `None` |
| `window` | `None \| str \| tuple[str, float] \| ndarray` | Window applied to each segment before the FFT in standard mode. By default no window is applied, which keeps the behavior closer to a standard lagged ridge fit. | `None` |
| `detrend` | `None \| str` | Optional per-segment detrending passed to the underlying spectral routine. | `'constant'` |
| `k` | `int \| str` | Number of cross-validation folds used when multiple regularization values are supplied, or `"loo"` for leave-one-out over trials. | `-1` |
| `average` | `bool \| Sequence[int]` | How cross-validation scores are reduced across output channels/features. | `True` |
| `seed` | `int \| None` | Optional random seed for shuffling trial order before creating folds. | `None` |
| `show_progress` | `bool` | If `True`, report progress while fitting. | `False` |
| `n_jobs` | `int \| None` | Number of worker threads used for cross-validation folds and bootstrap resamples. | `1` |
| `trial_weights` | `None \| str \| Sequence[float]` | Optional trial weights used when aggregating training spectra. | `None` |
| `bootstrap_samples` | `int` | Number of trial-bootstrap resamples used to estimate a confidence interval for the fitted kernel. | `0` |
| `bootstrap_level` | `float` | Confidence level of the stored bootstrap interval. | `0.95` |
| `bootstrap_seed` | `int \| None` | Optional random seed for bootstrap resampling. | `None` |
Returns:
| Type | Description |
|---|---|
| `None` or `ndarray` | |
Notes
The fitted model is always stored on the instance, even when
cross-validation is used. In that case the final fit uses the selected
regularization value and all provided trials. When multiple
regularization values are supplied, the per-trial spectra are cached so
the FFT work is performed only once. Direct single-lambda fits use a
lower-memory aggregated-spectra path automatically because no trialwise
cache is needed. Banded regularization is entirely opt-in through
`bands`; leaving it unset preserves the default "mTRF in Fourier
space" workflow.
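The core of a frequency-domain ridge fit can be sketched in numpy, assuming non-overlapping, unwindowed segments and a penalty added directly to the summed auto-spectra: at every frequency, H(f) = inv(Sxx(f) + Λ) Sxy(f), where Λ is diagonal. Per-trial spectra are computed once and summed over whichever trials form the training set, which is the idea behind the spectral cache described above. The helpers `trial_spectra`, `ridge_transfer_function`, and `banded_penalty` are hypothetical, not fftrf functions, and the penalty scaling is an assumption.

```python
import numpy as np


def trial_spectra(x, y, seg_len):
    """Sum auto- and cross-spectra over non-overlapping segments of one trial.

    x has shape (n_samples, n_inputs) and y has shape (n_samples, n_outputs).
    Returns Sxx with shape (n_freqs, n_inputs, n_inputs) and Sxy with shape
    (n_freqs, n_inputs, n_outputs). Hypothetical helper: no window, no
    detrending, no overlap.
    """
    n_seg = x.shape[0] // seg_len
    sxx, sxy = 0.0, 0.0
    for s in range(n_seg):
        seg = slice(s * seg_len, (s + 1) * seg_len)
        X = np.fft.rfft(x[seg], axis=0)  # (n_freqs, n_inputs)
        Y = np.fft.rfft(y[seg], axis=0)  # (n_freqs, n_outputs)
        sxx = sxx + np.einsum("fi,fj->fij", X.conj(), X)
        sxy = sxy + np.einsum("fi,fj->fij", X.conj(), Y)
    return sxx, sxy


def banded_penalty(bands, lambdas):
    """Expand one ridge value per contiguous feature group to one per feature."""
    return np.repeat(np.asarray(lambdas, dtype=float), bands)


def ridge_transfer_function(sxx, sxy, penalty):
    """Solve (Sxx + diag(penalty)) H = Sxy independently at every frequency."""
    return np.linalg.solve(sxx + np.diag(penalty), sxy)


rng = np.random.default_rng(1)
weights_true = np.array([2.0, 0.5, 0.0])  # toy instantaneous mapping, 3 features

# Compute per-trial spectra once (the expensive FFT step) ...
cache = []
for _ in range(3):
    x = rng.standard_normal((2048, 3))
    y = (x @ weights_true)[:, None]  # one output channel
    cache.append(trial_spectra(x, y, seg_len=256))

# ... then any training subset just sums the cached spectra.
train_idx = [0, 1]
sxx = sum(cache[i][0] for i in train_idx)
sxy = sum(cache[i][1] for i in train_idx)

# Feature groups: 1 feature, then 2 features (analogous to bands=[1, 2]).
H_plain = ridge_transfer_function(sxx, sxy, banded_penalty([1, 2], [0.0, 0.0]))
H_banded = ridge_transfer_function(sxx, sxy, banded_penalty([1, 2], [0.0, 1e4]))
```

With zero penalty the sketch recovers the toy weights at every frequency; penalizing only the second group shrinks those coefficients while the first group stays unpenalized.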
train_multitaper(stimulus, response, fs, tmin, tmax, regularization, *, bands=None, segment_length=None, segment_duration=None, overlap=0.0, n_fft=None, time_bandwidth=3.5, n_tapers=None, detrend='constant', k=-1, average=True, seed=None, show_progress=False, n_jobs=1, trial_weights=None, bootstrap_samples=0, bootstrap_level=0.95, bootstrap_seed=None)
Fit the model with DPSS multi-taper spectral estimation.
This is a convenience wrapper around `train` for users who want
a named multi-taper estimation path without manually setting
`spectral_method="multitaper"` and `window=None` themselves.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `stimulus` | `ndarray \| Sequence[ndarray]` | Same as in `train`. | required |
| `response` | `ndarray \| Sequence[ndarray]` | Same as in `train`. | required |
| `fs` | `float` | Same as in `train`. | required |
| `tmin` | `float` | Same as in `train`. | required |
| `tmax` | `float` | Same as in `train`. | required |
| `regularization` | `float \| Sequence[float] \| Sequence[Sequence[float]]` | Same as in `train`. | required |
| `bands` | `None \| Sequence[int]` | Optional grouped-feature definition for banded ridge, exactly as in `train`. | `None` |
| `segment_length` | `int \| None` | Segmentation and FFT setting for the cross-spectral estimates; behaves as in `train`. | `None` |
| `segment_duration` | `float \| None` | Segmentation and FFT setting for the cross-spectral estimates; behaves as in `train`. | `None` |
| `overlap` | `float` | Segmentation and FFT setting for the cross-spectral estimates; behaves as in `train`. | `0.0` |
| `n_fft` | `int \| None` | Segmentation and FFT setting for the cross-spectral estimates; behaves as in `train`. | `None` |
| `time_bandwidth` | `float` | Time-bandwidth product of the DPSS tapers. Larger values give broader spectral smoothing and permit more orthogonal tapers. | `3.5` |
| `n_tapers` | `int \| None` | Number of DPSS tapers to use. If omitted, the conventional number of tapers for the chosen time-bandwidth product is used. | `None` |
| `detrend` | `None \| str` | Optional per-segment detrending passed to the underlying spectral estimator. | `'constant'` |
| `k` | `int \| str` | Cross-validation control with the same semantics as in `train`. | `-1` |
| `average` | `bool \| Sequence[int]` | Cross-validation control with the same semantics as in `train`. | `True` |
| `seed` | `int \| None` | Cross-validation control with the same semantics as in `train`. | `None` |
| `show_progress` | `bool` | Same as in `train`. | `False` |
| `n_jobs` | `int \| None` | Same as in `train`. | `1` |
| `trial_weights` | `None \| str \| Sequence[float]` | Optional trial weights used when aggregating training spectra. | `None` |
| `bootstrap_samples` | `int` | Trial-bootstrap interval setting with the same semantics as in `train`. | `0` |
| `bootstrap_level` | `float` | Trial-bootstrap interval setting with the same semantics as in `train`. | `0.95` |
| `bootstrap_seed` | `int \| None` | Trial-bootstrap interval setting with the same semantics as in `train`. | `None` |
Returns:
| Type | Description |
|---|---|
| `None` or `ndarray` | Same return contract as `train`. |
Notes
Multi-taper fitting is often useful when trial data are noisy or when
short segments make ordinary single-window spectra unstable. Because
the DPSS tapers already define the segment weighting, this wrapper
does not accept a window argument and always fits with window=None.
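For intuition about what the multi-taper path computes, the following sketch builds DPSS tapers from Slepian's tridiagonal eigenvalue problem (the same construction `scipy.signal.windows.dpss` uses) and averages tapered cross-spectra over tapers. The helper names are hypothetical, the taper count is chosen by hand, and fftrf may normalize or weight the tapers differently.

```python
import numpy as np


def dpss_tapers(n, time_bandwidth, n_tapers):
    """DPSS tapers via Slepian's symmetric tridiagonal eigenvalue problem.

    Returns shape (n_tapers, n) with orthonormal rows. Taper signs are left
    arbitrary; they cancel in the cross-spectrum below.
    """
    w = time_bandwidth / n  # half-bandwidth in cycles per sample
    idx = np.arange(n)
    diag = ((n - 1 - 2 * idx) / 2.0) ** 2 * np.cos(2 * np.pi * w)
    off = idx[1:] * (n - idx[1:]) / 2.0
    matrix = np.diag(diag) + np.diag(off, 1) + np.diag(off, -1)
    _, vectors = np.linalg.eigh(matrix)  # eigenvalues in ascending order
    return vectors[:, ::-1][:, :n_tapers].T  # most concentrated tapers first


def multitaper_cross_spectrum(x, y, tapers):
    """Average conj(X_k) * Y_k over tapers k for one 1D segment pair."""
    X = np.fft.rfft(tapers * x, axis=-1)  # (n_tapers, n_freqs)
    Y = np.fft.rfft(tapers * y, axis=-1)
    return (X.conj() * Y).mean(axis=0)


rng = np.random.default_rng(2)
x = rng.standard_normal(256)
tapers = dpss_tapers(256, time_bandwidth=3.5, n_tapers=5)  # count chosen by hand
sxx = multitaper_cross_spectrum(x, x, tapers)
sxy = multitaper_cross_spectrum(x, 2.0 * x, tapers)
```

Each taper already shapes the segment, which is why an additional window is redundant in this mode; averaging over several orthogonal tapers is what reduces the variance of the spectral estimates.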
transfer_function_at(*, input_index=0, output_index=0)
Return one complex-valued transfer function slice.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `input_index` | `int` | Select the predictor-target pair to inspect. | `0` |
| `output_index` | `int` | Select the predictor-target pair to inspect. | `0` |
Returns:
| Type | Description |
|---|---|
| `frequencies, transfer_function` | Frequency vector in Hz and the matching complex transfer-function values for the selected input/output pair. |
Notes
The returned complex values encode both gain and phase. If you want
ready-to-plot derived quantities such as magnitude, phase, or group
delay, use `transfer_function_components_at` instead.
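To see how one complex value carries both gain and phase, this sketch uses a hypothetical kernel that scales its input by 0.5 and delays it by three samples, rather than a fitted model. The resulting frequency/transfer-function pair is analogous in form to the return value described above; the one-sided `rfft` grid is an assumption.

```python
import numpy as np

fs, n = 100.0, 128
kernel = np.zeros(n)
kernel[3] = 0.5  # hypothetical kernel: gain 0.5, delay of 3 samples

frequencies = np.fft.rfftfreq(n, d=1.0 / fs)  # 0 Hz up to Nyquist (50 Hz)
transfer_function = np.fft.rfft(kernel)       # complex values, one per frequency

gain = np.abs(transfer_function)     # 0.5 at every frequency
phase = np.angle(transfer_function)  # wrapped to (-pi, pi]; encodes the delay
```

The gain is flat because the kernel only scales its input, while the phase falls linearly with frequency because of the delay; this wrapped phase is what the components helper unwraps.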
transfer_function_components_at(*, input_index=0, output_index=0, phase_unit='rad')
Return magnitude, phase, and group delay for one transfer function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `input_index` | `int` | Select the predictor-target pair to inspect. | `0` |
| `output_index` | `int` | Select the predictor-target pair to inspect. | `0` |
| `phase_unit` | `str` | Unit used for the returned phase values. The default returns radians. | `'rad'` |
Returns:
| Type | Description |
|---|---|
| `TransferFunctionComponents` | Container with the raw complex transfer function plus magnitude, unwrapped phase, and group delay for the selected pair. |
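The derived quantities can be sketched as follows: magnitude is the absolute value, phase is the unwrapped angle, and group delay is the negative derivative of phase with respect to angular frequency. The helper `transfer_components` is hypothetical and fftrf's numerical differentiation may differ; for a pure delay of 7 samples the group delay equals 7/fs seconds at every frequency.

```python
import numpy as np


def transfer_components(frequencies, transfer_function):
    """Magnitude, unwrapped phase (rad), and group delay (s) of a complex response.

    Hypothetical helper; fftrf's numerical differentiation may differ.
    """
    magnitude = np.abs(transfer_function)
    phase = np.unwrap(np.angle(transfer_function))
    omega = 2 * np.pi * np.asarray(frequencies)  # angular frequency in rad/s
    group_delay = -np.gradient(phase, omega)     # -dphi/domega, in seconds
    return magnitude, phase, group_delay


fs, n, delay = 100.0, 256, 7
frequencies = np.fft.rfftfreq(n, d=1.0 / fs)
transfer_function = np.exp(-2j * np.pi * frequencies * delay / fs)  # pure delay
magnitude, phase, group_delay = transfer_components(frequencies, transfer_function)
```

Unwrapping matters here: the wrapped angle jumps by 2π several times across the band, and differentiating it directly would produce spurious spikes in the group delay.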