"""tether.functional.lif: LIF autograd function backed by fused Triton kernels."""

import torch
import triton
from ..kernels.lif import lif_fwd_kernel, lif_bwd_kernel


class LIFSubFunction(torch.autograd.Function):
    """Autograd function for a Leaky Integrate-and-Fire (LIF) neuron layer.

    Both passes launch fused Triton kernels (``lif_fwd_kernel`` and
    ``lif_bwd_kernel``) over a 1-D grid with one program per block of
    ``_BLOCK_SIZE`` neurons.
    """

    # Neurons handled per Triton program. Used both for the launch grid and
    # as the kernels' BLOCK_SIZE so the two can never disagree.
    _BLOCK_SIZE = 1024

    @staticmethod
    def forward(ctx, x_seq, v_init, decay, threshold, alpha, surrogate_type):
        """
        Forward pass of the Leaky Integrate-and-Fire (LIF) function.

        Uses fused Triton kernels for high-performance execution.

        Parameters
        ----------
        ctx : context
            Context object for saving tensors for the backward pass.
        x_seq : torch.Tensor
            Input spike/current sequence. Shape: (n_steps, batch_size * n_neurons).
        v_init : torch.Tensor
            Initial membrane potentials. Must contain exactly
            ``x_seq.shape[1]`` elements, e.g. shape (batch_size * n_neurons,).
        decay : torch.Tensor
            Decay factor (scalar). Read on the host with ``.item()``, which
            synchronises with the device.
        threshold : torch.Tensor
            Spiking threshold (scalar). Read on the host with ``.item()``.
        alpha : torch.Tensor
            Surrogate gradient scaling parameter (scalar). Only used by the
            backward pass.
        surrogate_type : int
            Integer ID representing the surrogate gradient function.

        Returns
        -------
        tuple
            - spikes (torch.Tensor): Output spike sequence. Shape same as x_seq.
            - v_final (torch.Tensor): Final membrane potentials. Shape same as v_init.
            - v_seq (torch.Tensor): Membrane potential sequence. Shape same as
              x_seq. Marked non-differentiable.

        Raises
        ------
        ValueError
            If ``x_seq`` is not 2-D or ``v_init`` does not hold one value per
            neuron column of ``x_seq``. Checked before any kernel launch, since
            the kernel would otherwise index past the end of ``v_init``.
        """
        x_seq = x_seq.contiguous()
        v_init = v_init.contiguous()

        if x_seq.dim() != 2:
            raise ValueError(
                "x_seq must be 2-D (n_steps, n_neurons), "
                f"got shape {tuple(x_seq.shape)}"
            )
        n_steps, n_neurons = x_seq.shape
        if v_init.numel() != n_neurons:
            raise ValueError(
                f"v_init must have {n_neurons} elements to match x_seq, "
                f"got shape {tuple(v_init.shape)}"
            )

        block = LIFSubFunction._BLOCK_SIZE

        out_spikes = torch.empty_like(x_seq)
        # Spikes are bit-packed for the backward pass: 32 time steps per
        # int32 word. NOTE(review): allocated uninitialised, which relies on
        # lif_fwd_kernel writing every word rather than OR-ing bits into it
        # -- confirm against kernels/lif.py.
        out_spikes_packed = torch.empty(
            (triton.cdiv(n_steps, 32), n_neurons),
            device=x_seq.device,
            dtype=torch.int32,
        )
        v_seq = torch.empty_like(x_seq)  # membrane potentials, kept for backward
        v_final = torch.empty_like(v_init)

        grid = (triton.cdiv(n_neurons, block),)
        lif_fwd_kernel[grid](
            x_seq,
            v_init,
            out_spikes,
            out_spikes_packed,
            v_seq,
            v_final,
            n_neurons,
            n_steps,
            decay.item(),
            threshold.item(),
            BLOCK_SIZE=block,
        )

        # Save the packed spikes (not the dense float spikes) to cut memory.
        ctx.save_for_backward(out_spikes_packed, v_seq, v_init, decay, threshold, alpha)
        ctx.surrogate_type = surrogate_type
        ctx.mark_non_differentiable(v_seq)
        return out_spikes, v_final, v_seq

    @staticmethod
    def backward(ctx, grad_spikes, grad_v_final, grad_v_seq):
        """
        Backward pass of the LIF function.

        Computes gradients for inputs and parameters using BPTT and surrogate
        gradients, via ``lif_bwd_kernel``.

        Parameters
        ----------
        ctx : context
            Context object with saved tensors.
        grad_spikes : torch.Tensor or None
            Gradient of loss w.r.t. output spikes (zeros if None).
        grad_v_final : torch.Tensor or None
            Gradient of loss w.r.t. final membrane potentials (zeros if None).
            Never modified by this method.
        grad_v_seq : torch.Tensor
            Gradient of loss w.r.t. voltage sequence (unused; ``v_seq`` is
            marked non-differentiable in forward).

        Returns
        -------
        tuple
            Gradients w.r.t. (x_seq, v_init, decay, threshold, alpha,
            surrogate_type). The parameter gradients are 1-element float32
            tensors; ``surrogate_type`` gets ``None``.
        """
        out_spikes_packed, v_seq, v_init, decay, threshold, alpha = ctx.saved_tensors
        n_steps, n_neurons = v_seq.shape
        device = v_seq.device
        block = LIFSubFunction._BLOCK_SIZE

        if grad_spikes is None:
            grad_spikes = torch.zeros_like(v_seq)

        # lif_bwd_kernel has no dedicated output for dL/dv_init; the buffer
        # passed in the grad_v_final slot is what gets returned for v_init
        # (presumably the kernel back-propagates into it in place -- TODO
        # confirm against kernels/lif.py). Always hand it a private,
        # contiguous copy so that (a) autograd's incoming gradient tensor is
        # never written to, and (b) in-place kernel writes are not lost to a
        # temporary, as would happen if a non-contiguous gradient (e.g. the
        # stride-0 grad of `v_final.sum()`) were passed through .contiguous().
        if grad_v_final is None:
            grad_v = torch.zeros_like(v_init)
        else:
            grad_v = grad_v_final.clone(memory_format=torch.contiguous_format)

        grad_x = torch.empty_like(v_seq)
        # Parameter gradients, reduced across neurons by the kernel.
        grad_decay = torch.zeros(1, device=device, dtype=torch.float32)
        grad_threshold = torch.zeros(1, device=device, dtype=torch.float32)
        grad_alpha = torch.zeros(1, device=device, dtype=torch.float32)

        grid = (triton.cdiv(n_neurons, block),)
        lif_bwd_kernel[grid](
            grad_spikes.contiguous(),
            out_spikes_packed,
            v_seq.contiguous(),
            grad_x,
            grad_v,
            v_init.contiguous(),
            n_neurons,
            n_steps,
            decay,
            threshold,
            alpha,
            grad_decay,
            grad_threshold,
            grad_alpha,
            ctx.surrogate_type,
            BLOCK_SIZE=block,
        )

        # Order matches forward's inputs:
        # (x_seq, v_init, decay, threshold, alpha, surrogate_type).
        return grad_x, grad_v, grad_decay, grad_threshold, grad_alpha, None