Source code for tether.nn.plif
import torch
import torch.nn as nn
from ..functional.plif import PLIFSubFunction
from .surrogates import Surrogate, Arctan
[docs]
class PLIF(nn.Module):
def __init__(
self,
n_neurons,
init_decay=0.9,
init_threshold=1.0,
alpha=2.0,
surrogate: Surrogate = None,
store_traces=False,
):
"""
Initialize the Parametric LIF (PLIF) module.
Decay and Threshold are learnable vectors per neuron.
Parameters
----------
n_neurons : int
Number of neurons.
init_decay : float, optional
Initial decay factor (default is 0.9).
init_threshold : float, optional
Initial spiking threshold (default is 1.0).
alpha : float, optional
Surrogate gradient parameter (default is 2.0).
surrogate : Surrogate, optional
Surrogate gradient module. If None, uses Arctan.
store_traces : bool, optional
Whether to store the membrane potential trace.
"""
super().__init__()
self.n_neurons = n_neurons
# Learnable vector parameters
self.decay = nn.Parameter(torch.full((n_neurons,), init_decay))
self.threshold = nn.Parameter(torch.full((n_neurons,), init_threshold))
if surrogate is None:
self.surrogate = Arctan(alpha=alpha, trainable=True)
else:
self.surrogate = surrogate
self.store_traces = store_traces
self.register_buffer("v", torch.zeros(n_neurons))
self.v_seq = None
self.firing_rate = 0.0
[docs]
def forward(self, x_seq):
"""
Forward pass of the PLIF module.
Parameters
----------
x_seq : torch.Tensor
Input sequence of shape (Time, Batch, Neurons) or (Time, Batch * Neurons).
Returns
-------
torch.Tensor
Output spikes with the same shape as input.
"""
orig_shape = x_seq.shape
x_flat = x_seq.reshape(orig_shape[0], -1)
if self.v.shape[0] != x_flat.shape[1]:
self.v = torch.zeros(x_flat.shape[1], device=x_seq.device)
spikes, v_next, v_seq = PLIFSubFunction.apply(
x_flat,
self.v,
self.decay,
self.threshold,
self.surrogate.alpha,
self.surrogate.get_id(),
)
if self.store_traces:
self.v_seq = v_seq.detach().reshape(orig_shape)
else:
self.v_seq = None
self.firing_rate = spikes.mean().item()
self.v = v_next.detach()
return spikes.reshape(orig_shape)