Frequency-Resolved Notebook
This notebook uses a simulated event-related response with a time-locked alpha burst around 150 ms. Because the ground truth is structured, the frequency-resolved views are easy to interpret: the ordinary kernel shows the burst in time, and the band-limited views isolate it near 10 Hz.
In [1]:
Copied!
import matplotlib
matplotlib.use("module://matplotlib_inline.backend_inline")
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import fftconvolve
from fftrf import TRF
def lag_times(fs: float, tmin: float, tmax: float) -> np.ndarray:
    """Convert a lag window ``[tmin, tmax)`` in seconds to a sampled lag axis.

    Both endpoints are rounded to the nearest whole sample at rate ``fs``;
    the upper endpoint is exclusive. Returns lag times in seconds.
    """
    first_sample = int(round(tmin * fs))
    stop_sample = int(round(tmax * fs))
    return np.array([index / fs for index in range(first_sample, stop_sample)])
def make_event_driver(
    n_samples: int,
    fs: float,
    rng: np.random.Generator,
) -> np.ndarray:
    """Synthesize a non-negative, unit-variance event-driven stimulus.

    An impulse train with jittered inter-event intervals is smoothed by a
    short Hann taper, lightly corrupted with Gaussian noise, rectified, and
    standardized. Returns an array of shape ``(n_samples, 1)``.
    """
    eps = np.finfo(float).eps
    impulses = np.zeros(n_samples)
    # First event ~180 ms in; later events are 0.50-0.90 s apart (jittered).
    shortest_gap = int(round(0.50 * fs))
    longest_gap = int(round(0.90 * fs))
    position = int(round(0.18 * fs))
    while position < n_samples:
        impulses[position] = 1.0
        position += int(rng.integers(shortest_gap, longest_gap + 1))
    # ~20 ms Hann taper (at least 5 samples), normalized to unit area.
    taper = np.hanning(max(5, int(round(0.020 * fs))))
    taper = taper / np.clip(taper.sum(), eps, None)
    signal = fftconvolve(impulses, taper, mode="full")[:n_samples]
    signal = signal + 0.015 * rng.standard_normal(n_samples)
    # Rectify, then scale to unit std (eps guard for a degenerate signal).
    signal = np.clip(signal, 0.0, None)
    signal = signal / np.clip(signal.std(), eps, None)
    return signal.reshape(-1, 1)
def simulate_response(
    stimulus: np.ndarray,
    kernel: np.ndarray,
    *,
    noise_scale: float,
    rng: np.random.Generator,
) -> np.ndarray:
    """Convolve a ``(n, 1)`` stimulus with a kernel and add Gaussian noise.

    Returns the noisy response as an ``(n, 1)`` array aligned with the
    stimulus (the convolution tail beyond ``n`` samples is discarded).
    """
    n_samples = stimulus.shape[0]
    clean = fftconvolve(stimulus[:, 0], kernel, mode="full")[:n_samples]
    noisy = clean + noise_scale * rng.standard_normal(n_samples)
    return noisy.reshape(-1, 1)
# ---------------------------------------------------------------------------
# Simulate event-related data, fit a forward TRF, and evaluate it.
# ---------------------------------------------------------------------------
rng = np.random.default_rng(11)
fs = 128.0  # sampling rate, Hz
tmin, tmax = 0.0, 0.320  # kernel lag window in seconds (upper edge exclusive)
times = lag_times(fs, tmin, tmax)

# Ground truth: two fast early deflections, a 10 Hz Gabor-like burst
# centered at 150 ms, and a small late bump at 240 ms.
early_peak = 0.38 * np.exp(-0.5 * ((times - 0.035) / 0.010) ** 2)
early_trough = 0.24 * np.exp(-0.5 * ((times - 0.072) / 0.016) ** 2)
alpha_burst = (
    0.52
    * np.exp(-0.5 * ((times - 0.150) / 0.055) ** 2)
    * np.cos(2.0 * np.pi * 10.0 * (times - 0.150))
)
late_bump = 0.10 * np.exp(-0.5 * ((times - 0.240) / 0.030) ** 2)
true_kernel = early_peak - early_trough + alpha_burst + late_bump

# Eight independent trials; each pairs a stimulus with its noisy response.
trials = []
for _ in range(8):
    drive = make_event_driver(n_samples=4_096, fs=fs, rng=rng)
    trials.append(
        (drive, simulate_response(drive, true_kernel, noise_scale=0.05, rng=rng))
    )
stimulus = [pair[0] for pair in trials]
response = [pair[1] for pair in trials]

# Hold the last trial out for evaluation.
train_stimulus, test_stimulus = stimulus[:-1], stimulus[-1]
train_response, test_response = response[:-1], response[-1]

# Fit a forward (stimulus -> response) model on the training trials.
model = TRF(direction=1)
model.train(
    stimulus=train_stimulus,
    response=train_response,
    fs=fs,
    tmin=tmin,
    tmax=tmax,
    regularization=1e-2,
    segment_duration=2.0,
    overlap=0.5,
    window=None,
)

# Score on the held-out trial and compare the learned kernel to ground truth.
_, score = model.predict(stimulus=test_stimulus, response=test_response)
kernel_correlation = np.corrcoef(true_kernel, model.weights[0, :, 0])[0, 1]
print(f"Held-out Pearson r: {float(score):.3f}")
print(f"Kernel correlation to ground truth: {kernel_correlation:.3f}")
print(f"Selected regularization: {model.regularization}")
import matplotlib
matplotlib.use("module://matplotlib_inline.backend_inline")
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import fftconvolve
from fftrf import TRF
def lag_times(fs: float, tmin: float, tmax: float) -> np.ndarray:
    """Return the sampled lag axis in seconds for the window ``[tmin, tmax)``.

    Endpoints are rounded to the nearest whole sample at rate ``fs``; the
    upper endpoint is exclusive.
    """
    lag_start = int(round(tmin * fs))
    lag_stop = int(round(tmax * fs))
    return np.arange(lag_start, lag_stop, dtype=int) / fs
def make_event_driver(
    n_samples: int,
    fs: float,
    rng: np.random.Generator,
) -> np.ndarray:
    """Synthesize a non-negative event-driven stimulus of shape (n_samples, 1).

    An impulse train with jittered inter-event intervals is smoothed by a
    short Hann window, perturbed with low-level Gaussian noise, rectified,
    and scaled to unit standard deviation.
    """
    event_train = np.zeros(n_samples, dtype=float)
    # First event ~180 ms in; later events spaced 0.50-0.90 s apart (jittered).
    event_index = int(round(0.18 * fs))
    min_interval = int(round(0.50 * fs))
    max_interval = int(round(0.90 * fs))
    while event_index < n_samples:
        event_train[event_index] = 1.0
        event_index += int(rng.integers(min_interval, max_interval + 1))
    # ~20 ms Hann taper (at least 5 samples), normalized to unit area.
    smoothing = np.hanning(max(5, int(round(0.020 * fs))))
    smoothing /= np.clip(smoothing.sum(), np.finfo(float).eps, None)
    driver = fftconvolve(event_train, smoothing, mode="full")[:n_samples]
    driver += 0.015 * rng.standard_normal(n_samples)
    driver = np.clip(driver, 0.0, None)  # rectify: keep the driver non-negative
    # Standardize; the eps guard avoids division by zero for a flat signal.
    driver = driver / np.clip(driver.std(), np.finfo(float).eps, None)
    return driver[:, np.newaxis]
def simulate_response(
    stimulus: np.ndarray,
    kernel: np.ndarray,
    *,
    noise_scale: float,
    rng: np.random.Generator,
) -> np.ndarray:
    """Convolve an (n, 1) stimulus with ``kernel`` and add Gaussian noise.

    The convolution tail beyond the stimulus length is discarded; the noisy
    response is returned as an (n, 1) array aligned with the stimulus.
    """
    response = fftconvolve(stimulus[:, 0], kernel, mode="full")[: stimulus.shape[0]]
    response += noise_scale * rng.standard_normal(stimulus.shape[0])
    return response[:, np.newaxis]
# --- Simulation setup -------------------------------------------------------
rng = np.random.default_rng(11)
fs = 128.0  # sampling rate, Hz
tmin = 0.0  # kernel lag window start (s)
tmax = 0.320  # kernel lag window end (s, exclusive)
times = lag_times(fs, tmin, tmax)
# Ground-truth kernel: early positive/negative peaks, a 10 Hz Gabor-like
# alpha burst centered at 150 ms, and a small late bump at 240 ms.
true_kernel = (
    0.38 * np.exp(-0.5 * ((times - 0.035) / 0.010) ** 2)
    - 0.24 * np.exp(-0.5 * ((times - 0.072) / 0.016) ** 2)
    + 0.52
    * np.exp(-0.5 * ((times - 0.150) / 0.055) ** 2)
    * np.cos(2.0 * np.pi * 10.0 * (times - 0.150))
    + 0.10 * np.exp(-0.5 * ((times - 0.240) / 0.030) ** 2)
)
# Generate 8 independent stimulus/response trial pairs.
stimulus = []
response = []
for _ in range(8):
    trial_stimulus = make_event_driver(n_samples=4_096, fs=fs, rng=rng)
    trial_response = simulate_response(
        trial_stimulus,
        true_kernel,
        noise_scale=0.05,
        rng=rng,
    )
    stimulus.append(trial_stimulus)
    response.append(trial_response)
# Hold the last trial out for evaluation.
train_stimulus = stimulus[:-1]
train_response = response[:-1]
test_stimulus = stimulus[-1]
test_response = response[-1]
# Fit a forward (stimulus -> response) TRF on the training trials.
model = TRF(direction=1)
model.train(
    stimulus=train_stimulus,
    response=train_response,
    fs=fs,
    tmin=tmin,
    tmax=tmax,
    regularization=1e-2,
    segment_duration=2.0,
    overlap=0.5,
    window=None,
)
# Score on the held-out trial and compare the learned kernel to ground truth.
_, score = model.predict(stimulus=test_stimulus, response=test_response)
kernel_correlation = np.corrcoef(true_kernel, model.weights[0, :, 0])[0, 1]
print(f"Held-out Pearson r: {float(score):.3f}")
print(f"Kernel correlation to ground truth: {kernel_correlation:.3f}")
print(f"Selected regularization: {model.regularization}")
Held-out Pearson r: 0.999 Kernel correlation to ground truth: 0.999 Selected regularization: 0.01
Why this example is useful
- The early positive and negative peaks act like a small event-related response.
- The oscillatory packet centered near 150 ms creates a clear 10 Hz alpha burst.
- Because the target structure is known, the frequency-resolved views are straightforward to interpret.
In [2]:
Copied!
# Band-limited views of the fitted kernel, up to 30 Hz in 20 bands.
band_weights = model.frequency_resolved_weights(
    n_bands=20,
    fmax=30.0,
    value_mode="real",
)
tf_power = model.time_frequency_power(
    n_bands=20,
    fmax=30.0,
)

# Stack the time-domain kernel above the two frequency-resolved panels.
fig, axes = plt.subplots(
    3,
    1,
    figsize=(9, 10),
    constrained_layout=True,
    gridspec_kw={"height_ratios": [1.0, 1.4, 1.4]},
)
kernel_ax, band_ax, power_ax = axes

model.plot(
    input_index=0,
    output_index=0,
    ax=kernel_ax,
    title=f"Recovered time-domain kernel (held-out r = {float(score):.3f})",
)
model.plot_frequency_resolved_weights(
    resolved=band_weights,
    input_index=0,
    output_index=0,
    ax=band_ax,
    title="Frequency-resolved weights (alpha burst near 10 Hz)",
    time_unit="ms",
)
model.plot_time_frequency_power(
    power=tf_power,
    input_index=0,
    output_index=0,
    ax=power_ax,
    title="Time-frequency power",
    time_unit="ms",
)
plt.show()
# Band-limited views of the fitted kernel: per-band weights and power,
# computed over 20 frequency bands up to 30 Hz.
resolved = model.frequency_resolved_weights(
    n_bands=20,
    fmax=30.0,
    value_mode="real",
)
power = model.time_frequency_power(
    n_bands=20,
    fmax=30.0,
)
# Three stacked panels: time-domain kernel on top, then the two
# frequency-resolved views (given extra height via the ratios).
fig, axes = plt.subplots(
    3,
    1,
    figsize=(9, 10),
    constrained_layout=True,
    gridspec_kw={"height_ratios": [1.0, 1.4, 1.4]},
)
model.plot(
    input_index=0,
    output_index=0,
    ax=axes[0],
    title=f"Recovered time-domain kernel (held-out r = {float(score):.3f})",
)
model.plot_frequency_resolved_weights(
    resolved=resolved,
    input_index=0,
    output_index=0,
    ax=axes[1],
    title="Frequency-resolved weights (alpha burst near 10 Hz)",
    time_unit="ms",
)
model.plot_time_frequency_power(
    power=power,
    input_index=0,
    output_index=0,
    ax=axes[2],
    title="Time-frequency power",
    time_unit="ms",
)
plt.show()