
TRF Estimator

fftrf.TRF is the main public API of the toolbox. This page explains the shared semantics of its parameters and attributes before showing the generated function-by-function reference.

Constructor

Create a forward model with:

from fftrf import TRF
model = TRF(direction=1)

Create a backward model with:

model = TRF(direction=-1)

The metric argument controls how predictions are scored whenever you call predict(...) with observed targets or score(...). It is also the criterion used to pick the best regularization value during cross-validation. It does not change the actual fitting objective, which remains ridge-regularized frequency-domain TRF estimation.

The same metric also defines the observed and surrogate scores returned by permutation_test(...) and refit_permutation_test(...).
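The metric contract in the parameter reference is a callable that takes (y_true, y_pred) and returns one score per output column. The sketch below is a hypothetical column-wise Pearson correlation in plain NumPy. It shows how to write your own metric and is not the code behind the built-in "pearsonr".

```python
import numpy as np


def columnwise_pearsonr(y_true, y_pred):
    """Hypothetical metric: one Pearson correlation per output column."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Center each column so the column-wise dot products become covariances.
    t = y_true - y_true.mean(axis=0)
    p = y_pred - y_pred.mean(axis=0)
    denom = np.sqrt((t ** 2).sum(axis=0) * (p ** 2).sum(axis=0))
    return (t * p).sum(axis=0) / denom


rng = np.random.default_rng(0)
y_true = rng.standard_normal((500, 3))
y_pred = y_true + 0.5 * rng.standard_normal((500, 3))
scores = columnwise_pearsonr(y_true, y_pred)  # one score per output column
```

The metric description allows callables, so such a function would be passed as TRF(direction=1, metric=columnwise_pearsonr).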

Common Parameter Meanings

These parameters appear repeatedly across the API:

  • stimulus, response: one trial as a NumPy array or multiple trials as a list of arrays
  • fs: sampling rate in Hz used to interpret lags and frequency bins
  • tmin, tmax: lag window in seconds for extracting the time-domain kernel
  • regularization: ridge value or candidate grid; can also describe banded regularization
  • bands: contiguous feature-group sizes for grouped ridge penalties
  • segment_length: segment size in samples for spectral estimation
  • segment_duration: segment size in seconds; a friendlier alternative to segment_length
  • overlap: fractional overlap between neighboring segments
  • n_fft: FFT size used when constructing sufficient statistics
  • spectral_method: "standard" or "multitaper"
  • time_bandwidth, n_tapers: DPSS settings used in multi-taper mode
  • window: optional window applied before the FFT in standard mode
  • detrend: optional per-segment detrending
  • k: number of cross-validation folds or "loo" for leave-one-out over trials
  • average: how output-channel scores are reduced
  • trial_weights: optional weighting over trials during aggregation
  • input_index, output_index: which predictor/target pair to inspect or plot
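The bands and regularization parameters together determine one ridge penalty per feature, which is what the feature_regularization attribute stores. The sketch below shows that expansion. It assumes each contiguous band receives exactly one coefficient. expand_banded_penalty is a hypothetical helper, not part of fftrf.

```python
import numpy as np


def expand_banded_penalty(coefficients, bands):
    """Hypothetical helper: one ridge coefficient per band -> one per feature."""
    coefficients = np.asarray(coefficients, dtype=float)
    bands = np.asarray(bands, dtype=int)
    if coefficients.shape != bands.shape:
        raise ValueError("expected exactly one coefficient per band")
    # Contiguous bands: the first bands[0] features share coefficients[0], and so on.
    return np.repeat(coefficients, bands)


# First three features lightly penalized, last two features heavily penalized.
penalty = expand_banded_penalty((1e-2, 1.0), bands=(3, 2))
# penalty holds [0.01, 0.01, 0.01, 1.0, 1.0]
```

In a banded ridge solve, the diagonal matrix np.diag(penalty) replaces the scalar λ·I of ordinary ridge.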

If you want a first guess for the segment-related arguments, see suggest_segment_settings and the Choosing Segment Settings guide.
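The sketch below does standard Welch-style bookkeeping to show how the segment arguments interact. It rests on three assumptions: segment_duration is converted with round(segment_duration * fs), n_fft defaults to the segment length, and neighboring segments start segment_length * (1 - overlap) samples apart. fftrf's exact rounding and padding rules may differ, and segment_plan is hypothetical.

```python
def segment_plan(n_samples, fs, segment_duration, overlap=0.5, n_fft=None):
    """Hypothetical Welch-style bookkeeping for the segment arguments."""
    segment_length = int(round(segment_duration * fs))  # seconds -> samples
    if n_fft is None:
        n_fft = segment_length  # assumption: no zero-padding unless requested
    hop = max(1, int(round(segment_length * (1.0 - overlap))))
    if n_samples < segment_length:
        n_segments = 0
    else:
        n_segments = 1 + (n_samples - segment_length) // hop
    return {
        "segment_length": segment_length,
        "hop": hop,
        "n_segments": n_segments,
        "frequency_resolution_hz": fs / n_fft,
        # An n_fft-point inverse FFT spans n_fft circular lags.
        "kernel_span_s": n_fft / fs,
    }


plan = segment_plan(n_samples=60_000, fs=1000, segment_duration=2.0, overlap=0.5)
```

For one minute of 1 kHz data with 2 s segments and 50% overlap, this gives 59 segments at 0.5 Hz resolution. Longer segments sharpen the frequency axis but leave fewer segments to average.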

What Training Stores

After a successful fit, the estimator keeps enough state to support prediction, plotting, and diagnostics without re-running the fit:

  • transfer_function: complex frequency-domain solution
  • frequencies: frequency axis in Hz
  • weights: lag-domain kernel
  • times: lag axis in seconds
  • regularization: chosen scalar ridge or banded tuple
  • regularization_candidates: evaluated grid, if applicable
  • segment_length, segment_duration, n_fft, overlap: spectral settings
  • spectral_method, time_bandwidth, n_tapers, window, detrend: estimation settings
  • bootstrap_interval, bootstrap_level, bootstrap_samples: uncertainty information when bootstrap intervals are computed

Reading the Generated API

The generated reference below is the authoritative source for signatures, arguments, return values, and stored attributes. Use the guides for conceptual advice and the reference for exact behavior.

fftrf.TRF

Estimate stimulus-response mappings in the frequency domain.

TRF is the main estimator of this toolbox. Its public API is intentionally close to mTRF:

  • call train(...) to fit the model
  • call predict(...) to generate predicted responses or stimuli
  • call score(...) to evaluate predictions
  • call permutation_test(...) to assess held-out prediction scores against surrogate nulls
  • call plot(...) to visualize the fitted kernel
  • call plot_grid(...) to visualize all input/output kernels at once
  • call frequency_resolved_weights(...) or plot_frequency_resolved_weights(...) for a spectrogram-like kernel view
  • call time_frequency_power(...) or plot_time_frequency_power(...) for a smoothed spectrogram-like power view of the kernel
  • call plot_transfer_function(...) to inspect magnitude, phase, or group delay
  • call cross_spectral_diagnostics(...), plot_coherence(...), and plot_cross_spectrum(...) for spectral prediction diagnostics
  • inspect weights and times as the time-domain kernel

Unlike a classic time-domain TRF, the fit is performed through ridge-regularized spectral deconvolution. This is often attractive for high-rate continuous data where explicitly building large lag matrices is cumbersome. When multiple regularization values are supplied, the estimator caches per-trial spectral statistics so cross-validation can reuse them instead of recomputing FFTs for every fold and candidate value. In direct single-lambda fits, the estimator automatically uses an aggregated lower-memory spectral path because no per-trial cache is needed.
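The core idea of ridge-regularized spectral deconvolution can be sketched in NumPy in three steps:

1. Accumulate windowed auto- and cross-spectra over segments.
2. Solve a small ridge system in every frequency bin.
3. Inverse-FFT the result to get a lag-domain kernel.

This minimal sketch assumes Hann-windowed segments, spectra averaged over segments, and lags extracted by indexing the circular inverse FFT. ridge_spectral_trf is a hypothetical name, and fftrf's normalization, windowing, and lag handling may differ.

```python
import numpy as np


def ridge_spectral_trf(x, y, fs, tmin, tmax, n_fft=256, overlap=0.5, lam=1e-3):
    """Sketch of ridge-regularized spectral deconvolution for one trial.

    x: (n_samples, n_inputs) predictor; y: (n_samples, n_outputs) target.
    Returns (H, freqs, weights, times) with H shaped (n_freqs, n_inputs, n_outputs)
    and weights shaped (n_inputs, n_lags, n_outputs).
    """
    hop = int(n_fft * (1 - overlap))
    window = np.hanning(n_fft)[:, None]
    n_in, n_out = x.shape[1], y.shape[1]
    n_freq = n_fft // 2 + 1
    sxx = np.zeros((n_freq, n_in, n_in), dtype=complex)
    sxy = np.zeros((n_freq, n_in, n_out), dtype=complex)
    n_seg = 0
    for start in range(0, x.shape[0] - n_fft + 1, hop):
        X = np.fft.rfft(window * x[start:start + n_fft], axis=0)  # (n_freq, n_in)
        Y = np.fft.rfft(window * y[start:start + n_fft], axis=0)  # (n_freq, n_out)
        # Accumulate auto- and cross-spectra: the sufficient statistics of the fit.
        sxx += np.conj(X)[:, :, None] * X[:, None, :]
        sxy += np.conj(X)[:, :, None] * Y[:, None, :]
        n_seg += 1
    sxx /= n_seg
    sxy /= n_seg
    # Ridge solve in every frequency bin: H(f) = (Sxx(f) + lam * I)^-1 Sxy(f).
    H = np.linalg.solve(sxx + lam * np.eye(n_in), sxy)
    # The inverse FFT gives a circular kernel; negative lags wrap to the end.
    kernel = np.fft.irfft(H, n=n_fft, axis=0)  # (n_fft, n_in, n_out)
    lags = np.arange(int(round(tmin * fs)), int(round(tmax * fs)) + 1)
    weights = kernel[lags % n_fft].transpose(1, 0, 2)
    freqs = np.fft.rfftfreq(n_fft, d=1.0 / fs)
    return H, freqs, weights, lags / fs


# Synthetic check: y[n] = x[n - 1] + 0.5 * x[n - 2], i.e. lags of 1 ms and 2 ms at 1 kHz.
rng = np.random.default_rng(0)
x = rng.standard_normal((8000, 1))
y = np.convolve(x[:, 0], [0.0, 1.0, 0.5])[: len(x)][:, None]
H, freqs, weights, times = ridge_spectral_trf(x, y, fs=1000, tmin=0.0, tmax=0.01)
```

With several trials, the per-trial sxx and sxy accumulators are natural sufficient statistics to cache. Summing them over the training trials of a fold and re-running only the solve step for each candidate λ avoids recomputing FFTs, which is the kind of reuse the paragraph above describes.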

Parameters:

Name Type Description Default
direction int

Modeling direction. Use 1 for a forward model (stimulus -> neural response) and -1 for a backward model (neural response -> stimulus).

1
metric MetricSpec

Callable or built-in metric name used to score predictions. It must accept (y_true, y_pred) and return one score per output column. Built-ins currently include "pearsonr", "r2", "explained_variance", and "neg_mse". This metric does not change the underlying TRF solver: fitting is still ridge-regularized spectral deconvolution. The metric is only used to score predictions, to select among regularization candidates during cross-validation, to return scores from predict(...) or score(...), and to compute the observed and surrogate scores of permutation_test(...) and refit_permutation_test(...).

'pearsonr'

Attributes:

Name Type Description
transfer_function ndarray | None

Complex-valued frequency-domain mapping with shape (n_frequencies, n_inputs, n_outputs).

frequencies ndarray | None

Frequency vector in Hz corresponding to transfer_function.

weights ndarray | None

Time-domain kernel extracted from transfer_function over the fitted lag window. Shape is (n_inputs, n_lags, n_outputs).

times ndarray | None

Lag values in seconds corresponding to the second axis of weights.

regularization RegularizationSpec | None

Selected ridge parameter. In ordinary ridge mode this is a scalar. When bands are used, it stores one coefficient per band.

bands tuple[int, ...] | None

Optional contiguous feature-group definition used for banded regularization.

feature_regularization ndarray | None

Expanded per-feature penalty vector actually used by the spectral solver. This is especially useful when banded regularization is active.

regularization_candidates list[RegularizationSpec] | None

Candidate ridge values or banded coefficient tuples evaluated during training. This lets cross-validation scores be mapped back to the tested values.

fs float | None

Sampling rate used during fitting.

segment_duration float | None

Segment length expressed in seconds. This mirrors segment_length in a more user-friendly unit.

bootstrap_interval ndarray | None

Optional trial-bootstrap confidence interval with shape (2, n_inputs, n_lags, n_outputs).

bootstrap_level float | None

Confidence level used for bootstrap_interval.

spectral_method SpectralMethod

Spectral estimator used during fitting. "standard" denotes the default windowed FFT estimator and "multitaper" activates DPSS multi-taper averaging.

time_bandwidth, n_tapers

Multi-taper settings stored for fitted models that use spectral_method="multitaper".

Examples:

>>> import numpy as np
>>> from fftrf import TRF
>>> x = np.random.randn(2000, 1)
>>> y = np.random.randn(2000, 1)
>>> model = TRF(direction=1)
>>> model.train(x, y, fs=1000, tmin=0.0, tmax=0.03, regularization=1e-3)
>>> prediction = model.predict(stimulus=x)

bootstrap_confidence_interval(stimulus, response, *, n_bootstraps=200, level=0.95, seed=None, n_jobs=1, trial_weights=_USE_STORED_TRIAL_WEIGHTS)

Estimate and store a trial-bootstrap confidence interval.

The estimator must already be fitted. By default the method reuses the same fit settings and the same trial-weighting strategy as the model. Bootstrap resampling is performed over trials, so at least two trials are required. n_jobs controls optional parallel execution across bootstrap resamples.
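A trial bootstrap resamples whole trials with replacement, refits, and summarizes the spread of the refitted kernels. The sketch below assumes a percentile interval. It replaces the real spectral refit with a toy fit callable that simply averages per-trial kernels. trial_bootstrap_interval is hypothetical and only illustrates the resampling scheme and the (2, ...) output layout.

```python
import numpy as np


def trial_bootstrap_interval(trials, fit, n_bootstraps=200, level=0.95, seed=None):
    """Percentile bootstrap over whole trials (illustrative sketch).

    trials: list of per-trial data; fit: callable mapping a list of trials to a
    kernel array. Returns an array of shape (2, *kernel.shape): lower, upper.
    """
    if len(trials) < 2:
        raise ValueError("a trial bootstrap needs at least two trials")
    rng = np.random.default_rng(seed)
    replicates = []
    for _ in range(n_bootstraps):
        # Resample whole trials with replacement, then refit on the resample.
        idx = rng.integers(0, len(trials), size=len(trials))
        replicates.append(fit([trials[i] for i in idx]))
    replicates = np.stack(replicates)
    alpha = 1.0 - level
    lower = np.percentile(replicates, 100 * alpha / 2, axis=0)
    upper = np.percentile(replicates, 100 * (1 - alpha / 2), axis=0)
    return np.stack([lower, upper])


# Toy stand-in: each "trial" is already a (1, 11, 1) kernel and fit() averages them.
rng = np.random.default_rng(1)
per_trial = [1.0 + 0.1 * rng.standard_normal((1, 11, 1)) for _ in range(10)]
interval = trial_bootstrap_interval(per_trial, fit=lambda ts: np.mean(ts, axis=0), seed=0)
```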

Parameters:

Name Type Description Default
stimulus ndarray | Sequence[ndarray]

Trial data used for the bootstrap resampling. They follow the same single-trial vs multi-trial conventions as train(...).

required
response ndarray | Sequence[ndarray]

Trial data used for the bootstrap resampling. They follow the same single-trial vs multi-trial conventions as train(...).

required
n_bootstraps int

Number of bootstrap resamples. Larger values yield a more stable interval estimate but increase runtime.

200
level float

Confidence level between 0 and 1. For example, 0.95 stores a 95% interval.

0.95
seed int | None

Optional random seed for reproducible bootstrap resampling.

None
n_jobs int | None

Number of worker threads used to evaluate bootstrap resamples. 1 runs serially and -1 uses all available cores.

1
trial_weights None | str | Sequence[float] | object

Trial-weighting strategy used while aggregating spectra inside each bootstrap replicate. The default reuses the model's stored weighting behavior.

_USE_STORED_TRIAL_WEIGHTS

Returns:

Type Description
interval, times:

Stored confidence interval with shape (2, n_inputs, n_lags, n_outputs) and the corresponding lag axis in seconds.

bootstrap_interval_at(*, tmin=None, tmax=None)

Return the stored bootstrap interval over the requested lag window.
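Restricting a stored interval to a lag window amounts to masking the lag axis. The hypothetical sketch below assumes the window is inclusive at both ends. A small tolerance guards against floating-point lag values that sit exactly on a boundary.

```python
import numpy as np


def interval_window(interval, times, tmin=None, tmax=None, eps=1e-9):
    """Slice a (2, n_inputs, n_lags, n_outputs) interval to [tmin, tmax]."""
    mask = np.ones(times.shape, dtype=bool)
    if tmin is not None:
        mask &= times >= tmin - eps  # tolerance for float lag values
    if tmax is not None:
        mask &= times <= tmax + eps
    return interval[:, :, mask, :], times[mask]


times = np.arange(0, 11) / 1000.0  # 0 ... 10 ms
interval = np.zeros((2, 1, times.size, 1))
sub, sub_times = interval_window(interval, times, tmin=0.002, tmax=0.005)
```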

Parameters:

Name Type Description Default
tmin float | None

Optional lag window in seconds to extract from the stored interval. If omitted, the full stored interval is returned.

None
tmax float | None

Optional lag window in seconds to extract from the stored interval. If omitted, the full stored interval is returned.

None

Returns:

Type Description
interval, times:

interval has shape (2, n_inputs, n_lags, n_outputs) where the first axis contains lower and upper bounds. times contains the matching lag values in seconds.

copy()

Return a copy of the estimator and all learned arrays.

Returns:

Type Description
TRF

Independent copy of the estimator state. NumPy arrays and other stored values are copied so that subsequent mutations on the copy do not affect the original instance.

cross_spectral_diagnostics(*, stimulus=None, response=None, tmin=None, tmax=None, trial_weights=_USE_STORED_TRIAL_WEIGHTS)

Compute observed-vs-predicted cross-spectral diagnostics.

This method reuses the fitted kernel to generate predictions for the provided data, then compares predicted and observed targets in the frequency domain. The returned diagnostics include:

  • the learned complex transfer function
  • predicted and observed output spectra
  • matched predicted-vs-observed cross-spectra
  • magnitude-squared coherence between prediction and target
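Magnitude-squared coherence compares segment-averaged spectra: |S_po|^2 / (S_pp · S_oo), where p is the prediction and o the observed target. The minimal Welch-style sketch below handles one output. It assumes Hann-windowed segments with 50% overlap, whereas fftrf reuses its fitted spectral settings. magnitude_squared_coherence is hypothetical.

```python
import numpy as np


def magnitude_squared_coherence(pred, obs, n_fft=256, overlap=0.5):
    """Welch-style magnitude-squared coherence between two 1D signals."""
    hop = int(n_fft * (1 - overlap))
    window = np.hanning(n_fft)
    n_freq = n_fft // 2 + 1
    s_pp, s_oo = np.zeros(n_freq), np.zeros(n_freq)
    s_po = np.zeros(n_freq, dtype=complex)
    for start in range(0, len(pred) - n_fft + 1, hop):
        fp = np.fft.rfft(window * pred[start:start + n_fft])
        fo = np.fft.rfft(window * obs[start:start + n_fft])
        s_pp += np.abs(fp) ** 2
        s_oo += np.abs(fo) ** 2
        s_po += np.conj(fp) * fo
    # Average over segments before taking the ratio; a single segment always gives 1.
    return np.abs(s_po) ** 2 / (s_pp * s_oo)


rng = np.random.default_rng(0)
obs = rng.standard_normal(8000)
coh_good = magnitude_squared_coherence(obs + 0.1 * rng.standard_normal(8000), obs)
coh_bad = magnitude_squared_coherence(rng.standard_normal(8000), obs)
```

Values near 1 mean the prediction tracks the target linearly at that frequency. Values near 0 mean it does not.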

Parameters:

Name Type Description Default
stimulus ndarray | Sequence[ndarray] | None

Data to evaluate. The method follows the same directional conventions as predict(...): forward models require stimulus and compare predictions against response, while backward models require response and compare predictions against stimulus.

None
response ndarray | Sequence[ndarray] | None

Data to evaluate. The method follows the same directional conventions as predict(...): forward models require stimulus and compare predictions against response, while backward models require response and compare predictions against stimulus.

None
tmin float | None

Optional lag window used when generating predictions for the diagnostics. If omitted, the fitted lag window is reused.

None
tmax float | None

Optional lag window used when generating predictions for the diagnostics. If omitted, the fitted lag window is reused.

None
trial_weights None | str | Sequence[float] | object

Trial weights used when aggregating the diagnostic spectra. The default sentinel value means "reuse the weighting strategy stored on the model". You can also pass None, "inverse_variance", or an explicit vector of weights.

_USE_STORED_TRIAL_WEIGHTS

Returns:

Type Description
TRFDiagnostics

Container holding transfer-function slices, spectra, cross-spectra, and coherence.

Notes

These diagnostics answer a different question than the lag-domain kernel plots: they show whether the model captures the spectral content of the observed targets, output by output, under the same spectral estimation settings used during fitting.

diagnostics(*, stimulus=None, response=None, tmin=None, tmax=None, trial_weights=_USE_STORED_TRIAL_WEIGHTS)

Compatibility alias for cross_spectral_diagnostics(...).

frequency_resolved_weights(*, n_bands=24, fmin=None, fmax=None, tmin=None, tmax=None, scale='linear', bandwidth=None, value_mode='real')

Return a spectrotemporal decomposition of the fitted kernel.

The learned transfer function is partitioned into smooth frequency bands, and each band is transformed back into the lag domain. This yields a frequency-by-lag representation that can be plotted like a spectrogram. In the default value_mode="real" setting, summing the returned weights across the band axis reconstructs the ordinary time-domain kernel, provided the full fitted frequency range is used.
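The reconstruction property holds whenever the band gains sum to one at every frequency bin, because the inverse FFT is linear. The sketch below uses Gaussian bands normalized to a partition of unity. The exact band shapes and normalization used by fftrf are assumptions here, and gaussian_band_split is hypothetical.

```python
import numpy as np


def gaussian_band_split(H, freqs, n_bands=8, bandwidth=None):
    """Split a one-sided transfer function into Gaussian bands that sum to H."""
    centers = np.linspace(freqs[0], freqs[-1], n_bands)
    if bandwidth is None:
        bandwidth = centers[1] - centers[0]  # infer width from the band spacing
    gains = np.exp(-0.5 * ((freqs[None, :] - centers[:, None]) / bandwidth) ** 2)
    gains /= gains.sum(axis=0, keepdims=True)  # bands sum to exactly 1 in every bin
    return centers, gains * H[None, :]  # (n_bands, n_freqs)


fs, n_fft = 1000, 256
rng = np.random.default_rng(0)
kernel = rng.standard_normal(n_fft)
H = np.fft.rfft(kernel)
freqs = np.fft.rfftfreq(n_fft, d=1.0 / fs)
centers, H_bands = gaussian_band_split(H, freqs)
band_kernels = np.fft.irfft(H_bands, n=n_fft, axis=1)  # one lag-domain kernel per band
```

In "magnitude" or "power" mode each band is made non-negative, so summing over bands no longer reconstructs the kernel.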

Parameters:

Name Type Description Default
n_bands int

Number of analysis bands used for the decomposition.

24
fmin float | None

Frequency range in Hz to analyze. The default covers the full fitted range from DC to Nyquist.

None
fmax float | None

Frequency range in Hz to analyze. The default covers the full fitted range from DC to Nyquist.

None
tmin float | None

Optional lag window to extract. If omitted, the fitted lag window is reused.

None
tmax float | None

Optional lag window to extract. If omitted, the fitted lag window is reused.

None
scale str

Spacing of the band centers. Use "linear" for evenly spaced bands or "log" for logarithmic spacing.

'linear'
bandwidth float | None

Gaussian band width in Hz. When omitted, it is inferred from the spacing between neighboring band centers.

None
value_mode str

"real" returns the signed band-limited kernels, "magnitude" returns their absolute value, and "power" returns squared magnitude.

'real'

Returns:

Type Description
FrequencyResolvedWeights

Container holding the filter bank, lag axis, and resolved kernel tensor with shape (n_inputs, n_bands, n_lags, n_outputs).

load(path)

Load estimator state from a pickle file into this instance.

Parameters:

Name Type Description Default
path str | Path

Pickle file previously created by save(...).

required
Notes

The current instance is updated in place. After loading, compatibility checks fill in any newer attributes that may not have existed in older saved files.

permutation_test(stimulus=None, response=None, *, n_permutations=1000, average=True, tmin=None, tmax=None, surrogate='circular_shift', min_shift=None, tail='greater', seed=None, n_jobs=1)

Estimate score significance against a surrogate null distribution.

This method evaluates a fitted model on aligned data, then compares the observed prediction score against surrogate scores obtained from the same predictions and a permuted target side. It answers the question "is this held-out score larger than would be expected under a null alignment?" rather than the stronger and slower "would the entire training pipeline beat chance if retrained on permuted data?".
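The circular-shift null can be sketched for a single trial. Score the aligned predictions once, then rescore them against targets rolled by random non-zero offsets, without refitting. In this hypothetical sketch, min_shift is given in samples, whereas the documented argument is in seconds. The one-sided p-value uses the common (1 + count) / (1 + n_permutations) convention, which may not match fftrf's exact formula.

```python
import numpy as np


def circular_shift_null(y_pred, y_true, score, n_permutations=1000, min_shift=1, seed=None):
    """Score fixed predictions against circularly shifted targets (single trial).

    y_pred, y_true: (n_samples, n_outputs); score(y_true, y_pred) returns a scalar.
    """
    rng = np.random.default_rng(seed)
    n = y_true.shape[0]
    observed = score(y_true, y_pred)
    null = np.empty(n_permutations)
    for i in range(n_permutations):
        # Offset in [min_shift, n - min_shift], so the roll is never a no-op.
        shift = rng.integers(min_shift, n - min_shift + 1)
        null[i] = score(np.roll(y_true, shift, axis=0), y_pred)
    # One-sided ("greater") p-value with the usual +1 correction.
    p_value = (1 + np.sum(null >= observed)) / (1 + n_permutations)
    z_score = (observed - null.mean()) / null.std(ddof=1)
    return observed, null, p_value, z_score


def mean_r(a, b):
    """Average column-wise Pearson correlation."""
    return float(np.mean([np.corrcoef(a[:, j], b[:, j])[0, 1] for j in range(a.shape[1])]))


rng = np.random.default_rng(0)
y_true = rng.standard_normal((2000, 2))
y_pred = y_true + rng.standard_normal((2000, 2))
observed, null, p_value, z_score = circular_shift_null(y_pred, y_true, mean_r, n_permutations=200, seed=0)
```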

Parameters:

Name Type Description Default
stimulus ndarray | Sequence[ndarray] | None

Evaluation data, following the same directional conventions as score(...). Both stimulus and response are required, for forward and backward models alike.

None
response ndarray | Sequence[ndarray] | None

Evaluation data, following the same directional conventions as score(...). Both stimulus and response are required, for forward and backward models alike.

None
n_permutations int

Number of surrogate scores used to form the null distribution.

1000
average bool | Sequence[int]

Score reduction applied to the observed and surrogate scores. This follows the same rules as score(...).

True
tmin float | None

Optional lag window used during prediction before scoring.

None
tmax float | None

Optional lag window used during prediction before scoring.

None
surrogate str

Strategy used to break the alignment between predictions and observed targets. "circular_shift" rolls each evaluation trial by a random non-zero offset and works for single-trial or multi-trial data. "trial_shuffle" permutes whole target trials and therefore requires at least two equal-length evaluation trials.

'circular_shift'
min_shift float | None

Minimum circular shift, in seconds, used when surrogate="circular_shift". None allows any non-zero shift.

None
tail str

Tail convention for the p-value calculation: "greater", "less", or "two-sided".

'greater'
seed int | None

Optional random seed for reproducible surrogate generation.

None
n_jobs int | None

Number of worker threads used to score the surrogate targets. 1 runs serially and -1 uses all available cores.

1

Returns:

Type Description
PermutationTestResult

Container with the observed score, surrogate null scores, p-value, and z-score.

Notes

All built-in metrics in ffTRF follow the "larger is better" convention, so tail="greater" is the natural default. The returned null scores follow the same scalar-vs-array contract as score(...): aggregated scores are stored as a 1D array of length n_permutations, while average=False keeps one surrogate score per output for every permutation.
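The average contract can be summarized by a small helper that reduces along the output axis. The same code then handles a single evaluation, shaped (n_outputs,), and a null distribution, shaped (n_permutations, n_outputs). reduce_scores is a hypothetical illustration of the documented rules, not fftrf's implementation.

```python
import numpy as np


def reduce_scores(scores, average=True):
    """Hypothetical helper mirroring the documented `average` contract.

    scores has outputs on its last axis: (n_outputs,) for one evaluation or
    (n_permutations, n_outputs) for a null distribution.
    """
    scores = np.asarray(scores, dtype=float)
    if isinstance(average, (bool, np.bool_)):
        # True: mean over outputs; False: keep one score per output.
        return scores.mean(axis=-1) if average else scores
    # A sequence of indices: average only over the selected outputs.
    return scores[..., list(average)].mean(axis=-1)


per_output = np.array([0.1, 0.4, 0.7])
null = np.zeros((1000, 3))
```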

plot(*, input_index=0, output_index=0, tmin=None, tmax=None, ax=None, time_unit='ms', color=None, linewidth=2.0, show_bootstrap_interval=False, interval_color=None, interval_alpha=0.2, title=None, label=None)

Plot one fitted time-domain kernel.

Parameters:

Name Type Description Default
input_index int

Select which input/output pair should be plotted. input_index refers to the predictor feature dimension and output_index to the predicted target channel.

0
output_index int

Select which input/output pair should be plotted. input_index refers to the predictor feature dimension and output_index to the predicted target channel.

0
tmin float | None

Optional lag window to visualize. If omitted, the fitted window is used. This is applied by re-extracting the impulse response from the stored transfer function.

None
tmax float | None

Optional lag window to visualize. If omitted, the fitted window is used. This is applied by re-extracting the impulse response from the stored transfer function.

None
ax

Existing matplotlib axes. When omitted, a new figure and axes are created.

None
time_unit str

Either "ms" or "s" for the x-axis labels and plotted lag values.

'ms'
color str | None

Standard matplotlib styling arguments for the kernel line.

None
linewidth float

Line width of the kernel trace, passed to matplotlib.

2.0
title str | None

Optional plot title.

None
label str | None

Optional legend label for the kernel line.

None
show_bootstrap_interval bool

If True, shade the stored bootstrap confidence interval behind the kernel. This requires that the model already contains a stored interval from training or from bootstrap_confidence_interval(...).

False
interval_color str | None

Color of the bootstrap interval shading.

None
interval_alpha float

Opacity of the bootstrap interval shading.

0.2

Returns:

Type Description
fig, ax:

The matplotlib figure and axes containing the plot.

Notes

This helper is a thin wrapper around matplotlib. It imports the plotting backend lazily so that the core toolbox stays usable without installing plotting dependencies.

plot_coherence(*, stimulus=None, response=None, diagnostics=None, output_index=0, ax=None, color=None, linewidth=2.0, title=None)

Plot magnitude-squared coherence between predictions and targets.

Parameters:

Name Type Description Default
stimulus ndarray | Sequence[ndarray] | None

Evaluation data used when diagnostics is not supplied.

None
response ndarray | Sequence[ndarray] | None

Evaluation data used when diagnostics is not supplied.

None
diagnostics TRFDiagnostics | None

Precomputed diagnostics from cross_spectral_diagnostics(...). Passing this is useful when plotting several outputs from the same evaluation result.

None
output_index int

Target output channel to display.

0
ax

Existing matplotlib axes. When omitted, a new figure is created.

None
color str | None

Standard matplotlib styling options for the coherence curve.

None
linewidth float

Line width of the coherence curve.

2.0
title str | None

Optional plot title.

None

Returns:

Type Description
fig, ax:

The matplotlib figure and axes containing the coherence plot.

plot_cross_spectrum(*, stimulus=None, response=None, diagnostics=None, output_index=0, kind='both', ax=None, color=None, phase_color=None, linewidth=2.0, phase_unit='rad', title=None)

Plot the predicted-vs-observed cross spectrum for one output.

Parameters:

Name Type Description Default
stimulus ndarray | Sequence[ndarray] | None

Evaluation data used when diagnostics is not supplied.

None
response ndarray | Sequence[ndarray] | None

Evaluation data used when diagnostics is not supplied.

None
diagnostics TRFDiagnostics | None

Precomputed diagnostics from cross_spectral_diagnostics(...).

None
output_index int

Target output channel to display.

0
kind str

Which quantity to plot. Use "magnitude", "phase", or "both".

'both'
ax

Existing matplotlib axes. For kind="both" provide two stacked axes. When omitted, a new figure is created.

None
color str | None

Standard plotting options passed through to the underlying matplotlib helper.

None
phase_color str | None

Standard plotting options passed through to the underlying matplotlib helper.

None
linewidth float

Line width passed to the underlying matplotlib helper.

2.0
phase_unit str

Unit for plotted phase values, either "rad" or "deg".

'rad'
title str | None

Optional plot title.

None

Returns:

Type Description
fig, ax:

The matplotlib figure and axes containing the requested view.

plot_frequency_resolved_weights(*, resolved=None, input_index=0, output_index=0, n_bands=24, fmin=None, fmax=None, tmin=None, tmax=None, scale='linear', bandwidth=None, value_mode='real', ax=None, time_unit='ms', cmap=None, colorbar=True, title=None, vmin=None, vmax=None, frequency_axis_scale=None)

Plot one frequency-resolved kernel map as a heatmap.

Parameters:

Name Type Description Default
resolved FrequencyResolvedWeights | None

Precomputed output of frequency_resolved_weights(...). Supplying this is useful when you want to plot several channels from the same decomposition without recomputing it. If omitted, the decomposition is computed on demand from the remaining keyword arguments.

None
input_index int

Select which input/output pair to display.

0
output_index int

Select which input/output pair to display.

0
n_bands int

Used only when resolved is not supplied. Maps directly to the same argument of frequency_resolved_weights(...).

24
fmin float | None

Used only when resolved is not supplied. Maps directly to the same argument of frequency_resolved_weights(...).

None
fmax float | None

Used only when resolved is not supplied. Maps directly to the same argument of frequency_resolved_weights(...).

None
tmin float | None

Used only when resolved is not supplied. Maps directly to the same argument of frequency_resolved_weights(...).

None
tmax float | None

Used only when resolved is not supplied. Maps directly to the same argument of frequency_resolved_weights(...).

None
scale str

Used only when resolved is not supplied. Maps directly to the same argument of frequency_resolved_weights(...).

'linear'
bandwidth float | None

Used only when resolved is not supplied. Maps directly to the same argument of frequency_resolved_weights(...).

None
value_mode str

Used only when resolved is not supplied. Maps directly to the same argument of frequency_resolved_weights(...).

'real'
ax

Existing matplotlib axes. When omitted, a new figure is created.

None
time_unit str

Either "ms" or "s" for the lag axis.

'ms'
cmap str | None

Optional matplotlib colormap. When omitted, a diverging map is used for signed weights and a sequential map is used for non-negative views.

None
colorbar bool

If True, add a colorbar describing the displayed values.

True
title str | None

Optional plot title.

None
vmin float | None

Optional explicit color limits.

None
vmax float | None

Optional explicit color limits.

None
frequency_axis_scale str | None

Optional display scaling for the frequency axis. If omitted, the scale stored in resolved is used.

None

Returns:

Type Description
fig, ax:

The matplotlib figure and axes containing the heatmap.

plot_grid(*, tmin=None, tmax=None, ax=None, time_unit='ms', color=None, linewidth=1.8, show_bootstrap_interval=False, interval_color=None, interval_alpha=0.2, input_labels=None, output_labels=None, title=None, sharey=False)

Plot every input/output kernel in a grid.

Parameters:

Name Type Description Default
tmin float | None

Optional lag window to visualize. If omitted, the fitted lag window is used for every panel.

None
tmax float | None

Optional lag window to visualize. If omitted, the fitted lag window is used for every panel.

None
ax

Optional array of pre-created matplotlib axes with shape (n_inputs, n_outputs). When omitted, a matching grid is created automatically.

None
time_unit str

Either "ms" or "s" for the shared lag axis.

'ms'
color str | None

Styling applied to every kernel trace in the grid.

None
linewidth float

Line width applied to every kernel trace in the grid.

1.8
show_bootstrap_interval bool

If True, overlay the stored bootstrap interval in every panel.

False
interval_color str | None

Color of the interval shading.

None
interval_alpha float

Opacity of the interval shading.

0.2
input_labels Sequence[str] | None

Optional human-readable labels used to title the individual panels. When omitted, generic Input N / Output N labels are used.

None
output_labels Sequence[str] | None

Optional human-readable labels used to title the individual panels. When omitted, generic Input N / Output N labels are used.

None
title str | None

Optional figure-level title.

None
sharey bool

If True, all axes share the same y limits. This is useful for visual comparison across channels but can hide small kernels when one panel contains much larger values than the others.

False

Returns:

Type Description
fig, axes:

The matplotlib figure and the full axes grid.

Notes

This is convenient for multifeature or multichannel models where calling plot(...) repeatedly would be cumbersome.

plot_time_frequency_power(*, power=None, input_index=0, output_index=0, n_bands=24, fmin=None, fmax=None, tmin=None, tmax=None, scale='linear', bandwidth=None, method='hilbert', ax=None, time_unit='ms', cmap=None, colorbar=True, title=None, vmin=None, vmax=None, frequency_axis_scale=None)

Plot a spectrogram-like time-frequency power map.

Parameters:

Name Type Description Default
power TimeFrequencyPower | None

Precomputed output of time_frequency_power(...). Supplying this avoids recomputing the decomposition when plotting multiple panels. If omitted, the power map is computed on demand from the remaining keyword arguments.

None
input_index int

Select which input/output pair to display.

0
output_index int

Select which input/output pair to display.

0
n_bands int

Used only when power is not supplied. Maps directly to the same argument of time_frequency_power(...).

24
fmin float | None

Used only when power is not supplied. Maps directly to the same argument of time_frequency_power(...).

None
fmax float | None

Used only when power is not supplied. Maps directly to the same argument of time_frequency_power(...).

None
tmin float | None

Used only when power is not supplied. Maps directly to the same argument of time_frequency_power(...).

None
tmax float | None

Used only when power is not supplied. Maps directly to the same argument of time_frequency_power(...).

None
scale str

Used only when power is not supplied. Maps directly to the same argument of time_frequency_power(...).

'linear'
bandwidth float | None

Used only when power is not supplied. Maps directly to the same argument of time_frequency_power(...).

None
method str

Used only when power is not supplied. Maps directly to the same argument of time_frequency_power(...).

'hilbert'
ax

Existing matplotlib axes. When omitted, a new figure is created.

None
time_unit str

Either "ms" or "s" for the lag axis.

'ms'
cmap str | None

Optional matplotlib colormap for the heatmap.

None
colorbar bool

If True, add a colorbar labeled as power.

True
title str | None

Optional plot title.

None
vmin float | None

Optional explicit color limits.

None
vmax float | None

Optional explicit color limits.

None
frequency_axis_scale str | None

Optional display scaling for the frequency axis. If omitted, the scale stored in power is used.

None

Returns:

Type Description
fig, ax:

The matplotlib figure and axes containing the heatmap.

plot_transfer_function(*, input_index=0, output_index=0, kind='both', ax=None, color=None, phase_color=None, group_delay_color=None, linewidth=2.0, phase_unit='rad', group_delay_unit='ms', title=None)

Plot magnitude, phase, and/or group delay of one transfer function.
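Group delay is the negative derivative of the unwrapped phase with respect to angular frequency, τ(f) = -dφ/dω. The sketch below computes it from a transfer-function slice and its frequency axis. It assumes the phase is unwrapped along frequency and differentiated with np.gradient; fftrf's exact scheme is not documented here. For group_delay_unit="ms", multiply the result by 1000.

```python
import numpy as np


def group_delay(H, freqs):
    """Group delay in seconds: -d(phase)/d(omega), with omega = 2 * pi * f."""
    phase = np.unwrap(np.angle(H))  # remove 2*pi jumps along the frequency axis
    return -np.gradient(phase, 2 * np.pi * freqs)


fs, n_fft = 1000, 256
freqs = np.fft.rfftfreq(n_fft, d=1.0 / fs)
H = np.exp(-2j * np.pi * freqs * 0.012)  # pure 12 ms delay
tau = group_delay(H, freqs)  # approximately 0.012 s in every bin
```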

Parameters:

Name Type Description Default
input_index int

Select the predictor-target pair to display.

0
output_index int

Select the predictor-target pair to display.

0
kind str

Which quantity to plot. Use "magnitude", "phase", "group_delay", "both" (magnitude plus phase), or "all" (magnitude, phase, and group delay).

'both'
ax

Existing matplotlib axes. For kind="both" provide two stacked axes, and for kind="all" provide three. When omitted, a new figure is created.

None
color str | None

Optional line colors for the displayed curves.

None
phase_color str | None

Optional line colors for the displayed curves.

None
group_delay_color str | None

Optional line colors for the displayed curves.

None
linewidth float

Line width passed to matplotlib.

2.0
phase_unit str

Unit for plotted phase values, either "rad" or "deg".

'rad'
group_delay_unit str

Unit for plotted group delay values, either "s" or "ms".

'ms'
title str | None

Optional plot title.

None

Returns:

Type Description
fig, ax:

The matplotlib figure and axes containing the requested view.

predict(stimulus=None, response=None, *, average=True, tmin=None, tmax=None)

Generate predictions from a fitted model.

Parameters:

Name Type Description Default
stimulus ndarray | Sequence[ndarray] | None

Inputs follow the same single-trial / multi-trial conventions as train(...). For forward models, stimulus is required. For backward models, response is required. If the corresponding observed target is also provided, the method additionally returns a prediction score.

None
response ndarray | Sequence[ndarray] | None

Inputs follow the same single-trial / multi-trial conventions as train(...). For forward models, stimulus is required. For backward models, response is required. If the corresponding observed target is also provided, the method additionally returns a prediction score.

None
average bool | Sequence[int]

Reduction strategy for the returned score. True averages over outputs, False returns one score per output, and a sequence of indices averages only over selected outputs.

True
tmin float | None

Optional lag window used during prediction. If omitted, the fitted lag window is used.

None
tmax float | None

Optional lag window used during prediction. If omitted, the fitted lag window is used.

None

Returns:

Type Description
prediction or (prediction, metric):

Predicted trials are returned in the same single-trial vs list form as the predictor input. When observed targets are supplied, the method also returns the metric defined on the estimator. The score shape is controlled by average.

Notes

Prediction is performed by convolving the predictor with the extracted time-domain kernel over the requested lag window. This means you can evaluate alternative lag windows on a fitted model without repeating spectral training, as long as the requested window stays within what is representable by the stored transfer function.
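Lag-domain prediction is a sum of shifted copies of the predictor, one per lag, each weighted by the kernel. The direct (non-FFT) sketch below assumes lags are converted from the times axis by rounding times * fs, and that samples outside a trial count as zero. predict_from_kernel is hypothetical, and fftrf may use a faster convolution.

```python
import numpy as np


def predict_from_kernel(x, weights, lags):
    """y_hat[n, o] = sum_i sum_l weights[i, l, o] * x[n - lags[l], i].

    x: (n_samples, n_inputs); weights: (n_inputs, n_lags, n_outputs); lags in samples.
    Samples that would fall outside the trial are treated as zeros.
    """
    n_samples = x.shape[0]
    y_hat = np.zeros((n_samples, weights.shape[2]))
    for j, lag in enumerate(lags):
        shifted = np.zeros(x.shape, dtype=float)
        if lag >= 0:
            shifted[lag:] = x[: n_samples - lag]  # positive lag: response follows predictor
        else:
            shifted[:lag] = x[-lag:]  # negative lag: response precedes predictor
        y_hat += shifted @ weights[:, j, :]
    return y_hat


fs = 1000
times = np.array([-0.001, 0.0, 0.001, 0.002])  # lag axis in seconds
lags = np.round(times * fs).astype(int)  # -> samples: [-1, 0, 1, 2]
weights = np.zeros((1, lags.size, 1))
weights[0, 2, 0], weights[0, 3, 0] = 1.0, 0.5  # y[n] = x[n-1] + 0.5 * x[n-2]
rng = np.random.default_rng(0)
x = rng.standard_normal((100, 1))
y_hat = predict_from_kernel(x, weights, lags)
lead = np.zeros_like(weights)
lead[0, 0, 0] = 1.0  # lag -1 only: y[n] = x[n+1]
y_lead = predict_from_kernel(x, lead, lags)
```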

refit_permutation_test(*, train_stimulus, train_response, test_stimulus, test_response, n_permutations=100, average=True, surrogate='circular_shift', min_shift=None, tail='greater', seed=None, n_jobs=1, fit_n_jobs=1, fit_kwargs=None)

Estimate a stronger null by retraining on surrogate-aligned data.

This method fits one model on the original training alignment and then fits one fresh surrogate model per permutation after breaking the predictor-target alignment on the training set. All models are scored on the same held-out aligned evaluation data.

Compared with permutation_test(...), this is slower but answers a stronger question: whether the full training pipeline, including regularization selection, outperforms chance under a surrogate null.

Parameters:

Name Type Description Default
train_stimulus ndarray | Sequence[ndarray]

Training data used to fit the observed and surrogate models.

required
train_response ndarray | Sequence[ndarray]

Training data used to fit the observed and surrogate models.

required
test_stimulus ndarray | Sequence[ndarray]

Held-out aligned evaluation data used to score every fitted model.

required
test_response ndarray | Sequence[ndarray]

Held-out aligned evaluation data used to score every fitted model.

required
n_permutations int

Number of surrogate refits used to form the null distribution.

100
average bool | Sequence[int]

Score reduction applied to the observed and surrogate scores. This follows the same rules as score(...).

True
surrogate str

Strategy used to break the training alignment. "circular_shift" rolls each training target trial by a random non-zero offset. "trial_shuffle" permutes whole training target trials and therefore requires at least two equal-length training trials.

'circular_shift'
min_shift float | None

Minimum circular shift, in seconds, used when surrogate="circular_shift".

None
tail str

Tail convention for the p-value calculation: "greater", "less", or "two-sided".

'greater'
seed int | None

Optional random seed for reproducible surrogate generation.

None
n_jobs int | None

Number of worker threads used across surrogate refits.

1
fit_n_jobs int | None

Number of worker threads used inside each individual refit. The default 1 avoids nested oversubscription.

1
fit_kwargs dict[str, object] | None

Optional overrides for the stored fit configuration. When omitted, the most recent training configuration of this estimator is reused, except that bootstrap estimation is disabled and progress output is suppressed during the surrogate refits.

None

Returns:

Type Description
PermutationTestResult

Container with the observed held-out score, surrogate null scores, p-value, and z-score.
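A sketch of the two ingredients of such a surrogate null, using NumPy only: a circular-shift surrogate that honours min_shift, and a permutation p-value supporting the three tail options. The helper names are hypothetical. The add-one p-value (count + 1) / (n + 1) and the sample-standard-deviation z-score are common conventions assumed here, not necessarily the exact formulas behind PermutationTestResult.

```python
import numpy as np


def circular_shift_surrogate(target, fs, min_shift=None, rng=None):
    """Hypothetical helper: roll one (n_samples, ...) trial by a random non-zero circular offset."""
    rng = np.random.default_rng(rng)
    n = target.shape[0]
    # At least 1 sample; min_shift (seconds) is enforced in both circular directions.
    m = 1 if min_shift is None else max(1, int(np.ceil(min_shift * fs)))
    if 2 * m > n:
        raise ValueError("min_shift is too large for this trial length")
    shift = int(rng.integers(m, n - m + 1))  # offsets drawn from [m, n - m]
    return np.roll(target, shift, axis=0)


def permutation_p_value(observed, null, tail="greater"):
    """Hypothetical helper: add-one permutation p-value and z-score of `observed`."""
    null = np.asarray(null, dtype=float)
    if tail == "greater":
        count = np.sum(null >= observed)
    elif tail == "less":
        count = np.sum(null <= observed)
    elif tail == "two-sided":
        center = null.mean()
        count = np.sum(np.abs(null - center) >= abs(observed - center))
    else:
        raise ValueError('tail must be "greater", "less", or "two-sided"')
    p_value = (count + 1) / (null.size + 1)
    z_score = (observed - null.mean()) / null.std(ddof=1)
    return float(p_value), float(z_score)
```

In a refit-style test, each surrogate training target would pass through circular_shift_surrogate before refitting, and every refit model would be scored on the untouched held-out pair.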

save(path)

Serialize the estimator to disk using pickle.

Parameters:

Name Type Description Default
path str | Path

Destination file. Parent directories must already exist.

required

Notes

The entire estimator instance is serialized, including fitted weights, spectral settings, chosen regularization, bootstrap intervals, and the configured scoring metric. Pickle files should only be loaded from trusted sources.
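Because save(...) writes the whole instance with pickle, a saved file can be read back with the standard library's pickle.load. This page does not document a dedicated loader, so the sketch below uses pickle directly on a hypothetical stand-in class (StandInEstimator, save_pickle, and load_trusted_pickle are illustrative names, not fftrf API).

```python
import pickle
import tempfile
from pathlib import Path


class StandInEstimator:
    """Hypothetical stand-in for a fitted TRF, used only to show the round trip."""

    def __init__(self, regularization):
        self.regularization = regularization


def save_pickle(obj, path):
    path = Path(path)
    # Same contract as save(...): the parent directory must already exist.
    if not path.parent.is_dir():
        raise FileNotFoundError(f"missing parent directory: {path.parent}")
    with path.open("wb") as fh:
        pickle.dump(obj, fh)


def load_trusted_pickle(path):
    # pickle.load can execute arbitrary code: only use it on files you trust.
    with Path(path).open("rb") as fh:
        return pickle.load(fh)


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "model.pkl"
    save_pickle(StandInEstimator(regularization=10.0), target)
    restored = load_trusted_pickle(target)
    print(restored.regularization)  # 10.0
```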

score(stimulus=None, response=None, *, average=True, tmin=None, tmax=None)

Score predictions without returning the predicted signals.

This is a convenience wrapper around predict(...) for workflows where only the metric is needed. Unlike predict(...), this method always requires the observed target (response for a forward model, stimulus for a backward model) because it returns only the score, not the predicted signals.

Parameters:

Name Type Description Default
stimulus ndarray | Sequence[ndarray] | None

Identical to predict(...).

None
response ndarray | Sequence[ndarray] | None

Identical to predict(...).

None
average bool | Sequence[int]

Identical to predict(...).

True
tmin float | None

Identical to predict(...).

None
tmax float | None

Identical to predict(...).

None

Returns:

Type Description
ndarray or float

Prediction score computed with the estimator's configured metric.
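The average convention shared by predict(...) and score(...) can be sketched with Pearson correlation as the per-output metric. Pearson r is only an assumed example of a configured metric, and pearson_per_output and reduce_scores are hypothetical helpers, not part of fftrf.

```python
import numpy as np


def pearson_per_output(y_true, y_pred):
    """Column-wise Pearson correlation for (n_samples, n_outputs) arrays."""
    yt = y_true - y_true.mean(axis=0)
    yp = y_pred - y_pred.mean(axis=0)
    num = (yt * yp).sum(axis=0)
    den = np.sqrt((yt ** 2).sum(axis=0) * (yp ** 2).sum(axis=0))
    return num / den


def reduce_scores(scores, average=True):
    """Apply the documented `average` rule to per-output scores."""
    scores = np.asarray(scores)
    if average is True:
        return float(scores.mean())  # one scalar over all outputs
    if average is False:
        return scores  # one score per output
    return float(scores[list(average)].mean())  # mean over the selected outputs only
```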

time_frequency_power(*, n_bands=24, fmin=None, fmax=None, tmin=None, tmax=None, scale='linear', bandwidth=None, method='hilbert')

Estimate spectrogram-like power from the fitted kernel.

This method starts from the signed band-limited kernels returned by frequency_resolved_weights(...) and converts each frequency band into a smoother power representation. With the default method="hilbert", power is the squared magnitude of the analytic signal of each band-limited kernel. The result is closer to what users expect from a spectrogram than simply squaring the oscillatory kernel itself.

Parameters:

Name Type Description Default
n_bands int

Number of analysis bands used to partition the fitted transfer function.

24
fmin float | None

Frequency range in Hz to analyze. By default the full fitted range is used.

None
fmax float | None

Frequency range in Hz to analyze. By default the full fitted range is used.

None
tmin float | None

Optional lag window in seconds to extract from the reconstructed band-limited kernels.

None
tmax float | None

Optional lag window in seconds to extract from the reconstructed band-limited kernels.

None
scale str

Placement of band centers. "linear" uses evenly spaced bands and "log" uses logarithmic spacing.

'linear'
bandwidth float | None

Width of the Gaussian analysis filters in Hz. If omitted, it is inferred from neighboring band centers.

None
method str

Power estimation method. Currently only "hilbert" is implemented, which converts each band-limited kernel to an analytic signal and then squares its magnitude.

'hilbert'

Returns:

Type Description
TimeFrequencyPower

Container holding band centers, lag axis, and the power tensor with shape (n_inputs, n_bands, n_lags, n_outputs).

Notes

This representation is best interpreted as a descriptive view of the fitted kernel, not as a spectrogram of the original stimulus or response recordings.
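A NumPy sketch of the Hilbert power step for a single 1D kernel: each Gaussian band filter and the analytic-signal mask are applied in the frequency domain, and power is the squared magnitude of the inverse FFT. The helper name, the Gaussian parametrisation (standard deviation equal to bandwidth, in Hz), and the explicit centers are assumptions for illustration, not fftrf internals.

```python
import numpy as np


def band_power_hilbert(kernel, fs, centers, bandwidth):
    """Hypothetical helper: Gaussian band-limit a 1D kernel, return power of shape (n_bands, n_lags)."""
    kernel = np.asarray(kernel, dtype=float)
    n = kernel.size
    spec = np.fft.fft(kernel)
    freqs = np.fft.fftfreq(n, d=1.0 / fs)
    # Analytic-signal mask: keep DC (and Nyquist for even n), double positive bins, drop negative bins.
    analytic = np.zeros(n)
    analytic[freqs > 0] = 2.0
    analytic[0] = 1.0
    if n % 2 == 0:
        analytic[n // 2] = 1.0
    power = np.empty((len(centers), n))
    for b, fc in enumerate(centers):
        # Symmetric Gaussian filter, so its real-valued output is the signed band-limited kernel.
        gauss = np.exp(-0.5 * ((np.abs(freqs) - fc) / bandwidth) ** 2)
        band = np.fft.ifft(spec * gauss * analytic)
        power[b] = np.abs(band) ** 2
    return power
```

Band centers for scale="linear" and scale="log" could, for instance, come from np.linspace(fmin, fmax, n_bands) and np.geomspace(fmin, fmax, n_bands); this placement is likewise an assumption.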

to_impulse_response(tmin=None, tmax=None)

Extract a time-domain kernel from the fitted transfer function.

Parameters:

Name Type Description Default
tmin float | None

Optional lag window in seconds. If omitted, the window used during train(...) is reused. These values only control which portion of the already fitted transfer function is transformed back to the lag domain; they do not trigger a refit.

None
tmax float | None

Optional lag window in seconds. If omitted, the window used during train(...) is reused. These values only control which portion of the already fitted transfer function is transformed back to the lag domain; they do not trigger a refit.

None

Returns:

Type Description
weights, times:

weights has shape (n_inputs, n_lags, n_outputs) and times contains the corresponding lag values in seconds. The time axis is sampled at the model's stored sampling rate.

Notes

This method is useful when you want to inspect a different lag window without refitting the spectral model. It is the same operation used internally to populate the weights and times attributes after training, and it powers downstream helpers such as plot(...), predict(...), and bootstrap_interval_at(...).
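A minimal sketch of re-extracting a lag window from a stored transfer function. It assumes the one-sided spectrum lives on the np.fft.rfftfreq(n_fft, 1/fs) grid and uses the convention in which np.fft.irfft returns the kernel directly, so negative lags wrap to the end of the array. The helper is hypothetical and handles a single input/output pair; fftrf's storage layout and scaling may differ.

```python
import numpy as np


def kernel_from_transfer_function(H, n_fft, fs, tmin, tmax):
    """Hypothetical helper: return (weights, times) for one input/output pair."""
    h = np.fft.irfft(H, n=n_fft)  # circular kernel; negative lags sit at the end
    lags = np.arange(int(np.round(tmin * fs)), int(np.round(tmax * fs)) + 1)
    if lags.size > n_fft:
        raise ValueError("lag window is longer than n_fft and would wrap onto itself")
    weights = h[lags % n_fft]  # modulo indexing maps negative lags to the wrapped tail
    return weights, lags / fs
```

Calling it twice with different tmin/tmax on the same H only re-slices the inverse FFT, mirroring the no-refit behaviour described above.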

train(stimulus, response, fs, tmin, tmax, regularization, *, bands=None, segment_length=None, segment_duration=None, overlap=0.0, n_fft=None, spectral_method='standard', time_bandwidth=3.5, n_tapers=None, window=None, detrend='constant', k=-1, average=True, seed=None, show_progress=False, n_jobs=1, trial_weights=None, bootstrap_samples=0, bootstrap_level=0.95, bootstrap_seed=None)

Fit the frequency-domain TRF.

Parameters:

Name Type Description Default
stimulus ndarray | Sequence[ndarray]

One trial as a 1D/2D array or multiple trials as a list of arrays. A 2D trial has shape (n_samples, n_features); a 1D vector is treated as a single-feature input.

required
response ndarray | Sequence[ndarray]

Neural data as one trial (1D/2D array) or multiple trials (list of arrays). A 2D trial has shape (n_samples, n_outputs); a 1D vector is treated as a single output.

required
fs float

Sampling rate in Hz shared by stimulus and response.

required
tmin float

Start of the lag window, in seconds, extracted from the learned transfer function as a time-domain kernel.

required
tmax float

End of the lag window, in seconds, extracted from the learned transfer function as a time-domain kernel.

required
regularization float | Sequence[float] | Sequence[Sequence[float]]

Regularization specification. The default behavior matches a standard ridge TRF fit:

  • scalar: fit directly with one ridge value
  • 1D sequence of scalars: cross-validate over those candidates

When bands is provided, each feature group gets its own ridge coefficient. In that mode, a 1D scalar sequence follows the mTRF banded-ridge convention: the Cartesian product across bands is evaluated during cross-validation. You can also pass an explicit sequence of per-band coefficient tuples.

required
bands None | Sequence[int]

Optional contiguous feature-group sizes for banded ridge regularization. For example, if the predictor contains one envelope feature followed by a 16-band spectrogram, use bands=[1, 16]. Leaving this as None keeps the estimator in ordinary scalar ridge mode.

None
segment_length int | None

Segment size, in samples, used to estimate cross-spectra. If None (and segment_duration is also None), each trial is treated as a single segment.

None
segment_duration float | None

Segment duration in seconds. This is a user-friendly alternative to segment_length for workflows that prefer time-based settings. Provide either segment_length or segment_duration, not both.

None
overlap float

Fractional overlap between neighboring segments. Must lie in [0, 1).

0.0
n_fft int | None

FFT size used for spectral estimation. If omitted, a fast FFT length is chosen automatically from segment_length.

None
spectral_method SpectralMethod

Spectral estimator used to compute the sufficient statistics. "standard" computes one FFT per segment, optionally windowed via window. "multitaper" averages DPSS-tapered spectra and is often more stable for noisy continuous data.

'standard'
time_bandwidth float

Time-bandwidth product used when spectral_method="multitaper". Larger values produce broader spectral smoothing and allow more tapers.

3.5
n_tapers int | None

Number of DPSS tapers used for spectral_method="multitaper". If omitted, the default 2 * time_bandwidth - 1 rule is used.

None
window None | str | tuple[str, float] | ndarray

Window applied to each segment before FFT. By default no window is applied, which keeps the behavior closer to a standard lagged ridge fit. When using short overlapping segments, window="hann" is often a good choice. In multi-taper mode this must be None because the DPSS tapers already define the segment weighting.

None
detrend None | str

Optional per-segment detrending passed to scipy.signal.detrend (for example "constant" or "linear"). None disables detrending.

'constant'
k int | str

Number of cross-validation folds when multiple regularization values are supplied. -1 or "loo" means leave-one-out over trials.

-1
average bool | Sequence[int]

How cross-validation scores should be reduced across output channels/features. True returns a single score per regularization value, False returns one score per output, and a sequence of indices averages only over the selected outputs.

True
seed int | None

Optional random seed for shuffling trial order before creating folds.

None
show_progress bool

If True and cross-validation is active, print a small progress indicator to standard error while fold/candidate evaluations run.

False
n_jobs int | None

Number of worker threads used for cross-validation folds and bootstrap resamples. 1 keeps the serial behavior. -1 uses all available CPU cores.

1
trial_weights None | str | Sequence[float]

Optional trial weights. Use "inverse_variance" for inverse-variance weighting or pass an explicit vector with one weight per training trial.

None
bootstrap_samples int

Number of trial-bootstrap resamples used to estimate a confidence interval for the fitted kernel. 0 disables the bootstrap.

0
bootstrap_level float

Confidence level used for the stored bootstrap interval.

0.95
bootstrap_seed int | None

Optional random seed used for bootstrap resampling.

None

Returns:

Type Description
None or ndarray

None when a single regularization value is fitted directly. Otherwise returns cross-validation scores for each candidate regularization setting, in the order stored in the regularization_candidates attribute.

Notes

The fitted model is always stored on the instance, even when cross-validation is used. In that case the final fit uses the selected regularization value and all provided trials. When multiple regularization values are supplied, the per-trial spectra are cached so the FFT work is performed only once. Direct single-lambda fits use a lower-memory aggregated-spectra path automatically because no trialwise cache is needed. Banded regularization is entirely opt-in through bands; leaving it unset preserves the default "mTRF in Fourier space" workflow.
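For orientation, here is a minimal NumPy sketch of the textbook per-frequency ridge solution behind a frequency-domain TRF, H(f) = inverse(Sxx(f) + diag(lambda)) @ Sxy(f), together with the banded penalty expansion and the Cartesian-product candidate grid described for bands. Everything here is an illustrative assumption. The helper names are hypothetical. Segments are not windowed, detrended, or trial-weighted. Spectra are summed rather than normalized, which only rescales the effective ridge values. A segment_duration in seconds presumably maps to seg_len = round(segment_duration * fs) samples.

```python
import itertools

import numpy as np


def _as_2d(a):
    a = np.asarray(a, dtype=float)
    return a[:, None] if a.ndim == 1 else a


def cross_spectra(x, y, seg_len, overlap=0.0):
    """Summed segment cross-spectra Sxx (F, p, p) and Sxy (F, p, q) for one trial."""
    x, y = _as_2d(x), _as_2d(y)
    step = max(1, int(round(seg_len * (1.0 - overlap))))  # hop between segment starts
    starts = range(0, x.shape[0] - seg_len + 1, step)
    if len(starts) == 0:
        raise ValueError("trial is shorter than seg_len")
    Sxx, Sxy = 0.0, 0.0
    for s in starts:
        X = np.fft.rfft(x[s:s + seg_len], axis=0)  # (F, p)
        Y = np.fft.rfft(y[s:s + seg_len], axis=0)  # (F, q)
        Sxx = Sxx + np.einsum("fi,fj->fij", X.conj(), X)
        Sxy = Sxy + np.einsum("fi,fj->fij", X.conj(), Y)
    return Sxx, Sxy


def ridge_transfer_function(Sxx, Sxy, lambdas_per_feature):
    """Solve (Sxx + diag(lambda)) H = Sxy independently at every frequency bin."""
    penalty = np.diag(np.asarray(lambdas_per_feature, dtype=float))
    return np.linalg.solve(Sxx + penalty, Sxy)  # (F, p, q), broadcast over F


def banded_penalty(bands, band_lambdas):
    """Expand one ridge value per contiguous feature group to one value per feature."""
    return np.repeat(np.asarray(band_lambdas, dtype=float), bands)


def banded_grid(candidates, n_bands):
    """Cartesian product of scalar candidates across bands (mTRF-style grid)."""
    return list(itertools.product(candidates, repeat=n_bands))
```

With bands=[1, 16], banded_penalty([1, 16], (l1, l2)) gives l1 for the envelope column and l2 for all 16 spectrogram columns, and banded_grid(candidates, 2) enumerates every (l1, l2) pair that cross-validation would score.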

train_multitaper(stimulus, response, fs, tmin, tmax, regularization, *, bands=None, segment_length=None, segment_duration=None, overlap=0.0, n_fft=None, time_bandwidth=3.5, n_tapers=None, detrend='constant', k=-1, average=True, seed=None, show_progress=False, n_jobs=1, trial_weights=None, bootstrap_samples=0, bootstrap_level=0.95, bootstrap_seed=None)

Fit the model with DPSS multi-taper spectral estimation.

This is a convenience wrapper around train(...) for users who want a named multi-taper estimation path without manually setting spectral_method="multitaper" and window=None.

Parameters:

Name Type Description Default
stimulus ndarray | Sequence[ndarray]

Identical to train(...).

required
response ndarray | Sequence[ndarray]

Identical to train(...).

required
fs float

Identical to train(...).

required
tmin float

Identical to train(...).

required
tmax float

Identical to train(...).

required
regularization float | Sequence[float] | Sequence[Sequence[float]]

Identical to train(...).

required
bands None | Sequence[int]

Optional grouped-feature definition for banded ridge, exactly as in train(...).

None
segment_length int | None

Segmentation and FFT settings for the cross-spectral estimates. These behave the same as in train(...).

None
segment_duration float | None

Segmentation and FFT settings for the cross-spectral estimates. These behave the same as in train(...).

None
overlap float

Segmentation and FFT settings for the cross-spectral estimates. These behave the same as in train(...).

0.0
n_fft int | None

Segmentation and FFT settings for the cross-spectral estimates. These behave the same as in train(...).

None
time_bandwidth float

Time-bandwidth product of the DPSS tapers. Larger values create broader spectral smoothing and permit more orthogonal tapers.

3.5
n_tapers int | None

Number of DPSS tapers to use. If omitted, the conventional 2 * time_bandwidth - 1 rule is used.

None
detrend None | str

Optional per-segment detrending passed to the underlying spectral estimator.

'constant'
k int | str

Cross-validation controls with the same semantics as in train(...).

-1
average bool | Sequence[int]

Cross-validation controls with the same semantics as in train(...).

True
seed int | None

Cross-validation controls with the same semantics as in train(...).

None
show_progress bool

Cross-validation controls with the same semantics as in train(...).

False
n_jobs int | None

Cross-validation controls with the same semantics as in train(...).

1
trial_weights None | str | Sequence[float]

Optional trial weights used when aggregating training spectra.

None
bootstrap_samples int

Optional trial-bootstrap interval settings with the same semantics as in train(...).

0
bootstrap_level float

Optional trial-bootstrap interval settings with the same semantics as in train(...).

0.95
bootstrap_seed int | None

Optional trial-bootstrap interval settings with the same semantics as in train(...).

None

Returns:

Type Description
None or ndarray

Same return contract as train(...): None for a direct fit, or cross-validation scores when multiple regularization candidates are evaluated.

Notes

Multi-taper fitting is often useful when trial data are noisy or when short segments make ordinary single-window spectra unstable. Because the DPSS tapers already define the segment weighting, this wrapper always disables the ordinary window parameter.
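The DPSS tapers themselves can be computed with NumPy alone from the standard tridiagonal eigenproblem (the same formulation SciPy uses in scipy.signal.windows.dpss). The sketch below is an illustration under assumptions: dpss_tapers and multitaper_cross_spectrum are hypothetical helpers, taper signs are arbitrary (harmless for spectra), and the default taper count follows the 2 * time_bandwidth - 1 rule quoted above.

```python
import numpy as np


def dpss_tapers(n, time_bandwidth, n_tapers=None):
    """Hypothetical helper: DPSS tapers of shape (n_tapers, n), most concentrated first."""
    if n_tapers is None:
        n_tapers = int(2 * time_bandwidth - 1)
    half_bandwidth = time_bandwidth / n  # W in cycles per sample
    idx = np.arange(n)
    diag = ((n - 1 - 2 * idx) / 2.0) ** 2 * np.cos(2 * np.pi * half_bandwidth)
    off = idx[1:] * (n - idx[1:]) / 2.0
    T = np.diag(diag) + np.diag(off, 1) + np.diag(off, -1)
    _, vecs = np.linalg.eigh(T)  # eigenvalues in ascending order, unit-norm eigenvectors
    return vecs[:, ::-1][:, :n_tapers].T  # largest eigenvalues = best-concentrated tapers


def multitaper_cross_spectrum(x, y, tapers):
    """Average conj(X_k) * Y_k over K tapers for 1D segments x and y."""
    X = np.fft.rfft(tapers * x, axis=1)  # (K, F)
    Y = np.fft.rfft(tapers * y, axis=1)  # (K, F)
    return (X.conj() * Y).mean(axis=0)
```

Since the tapers already weight each segment, applying an extra window on top would double-taper the data, which is consistent with the wrapper forcing window=None.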

transfer_function_at(*, input_index=0, output_index=0)

Return one complex-valued transfer function slice.

Parameters:

Name Type Description Default
input_index int

Select the predictor-target pair to inspect.

0
output_index int

Select the predictor-target pair to inspect.

0

Returns:

Type Description
frequencies, transfer_function:

Frequency vector in Hz and the matching complex transfer-function values for the selected input/output pair.

Notes

The returned complex values encode both gain and phase. For ready-to-plot derived quantities such as magnitude, phase, or group delay, use transfer_function_components_at(...) instead.

transfer_function_components_at(*, input_index=0, output_index=0, phase_unit='rad')

Return magnitude, phase, and group delay for one transfer function.

Parameters:

Name Type Description Default
input_index int

Select the predictor-target pair to inspect.

0
output_index int

Select the predictor-target pair to inspect.

0
phase_unit str

Unit used for the returned phase values. Must be "rad" or "deg".

'rad'

Returns:

Type Description
TransferFunctionComponents

Container with the raw complex transfer function plus magnitude, unwrapped phase, and group delay for the selected pair.
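The derived quantities can be reproduced from the complex values returned by transfer_function_at(...). In this sketch the group delay is estimated as -(1 / 2π) dφ/df using np.gradient on the unwrapped phase; the finite-difference scheme and the helper name are assumptions, not a description of how TransferFunctionComponents is computed.

```python
import numpy as np


def transfer_function_components(freqs, H, phase_unit="rad"):
    """Hypothetical helper: magnitude, unwrapped phase, and group delay (seconds) of H(f)."""
    if phase_unit not in ("rad", "deg"):
        raise ValueError('phase_unit must be "rad" or "deg"')
    magnitude = np.abs(H)
    phase = np.unwrap(np.angle(H))  # remove 2*pi jumps between neighbouring bins
    # Group delay tau(f) = -(1 / (2*pi)) * dphi/df, always computed from the phase in radians.
    group_delay = -np.gradient(phase, freqs) / (2 * np.pi)
    if phase_unit == "deg":
        phase = np.degrees(phase)
    return magnitude, phase, group_delay
```

For a pure delay of tau seconds, H(f) = exp(-2j*pi*f*tau), so the sketch returns a flat group delay equal to tau, a quick sanity check when reading real kernels.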