tether.functional package
Submodules
tether.functional.alif module
tether.functional.lif module
- class tether.functional.lif.LIFSubFunction(*args, **kwargs)
Bases: Function
- static backward(ctx, grad_spikes, grad_v_final, grad_v_seq)
Backward pass of the LIF function.
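Computes gradients for the inputs and parameters using backpropagation through time (BPTT), with surrogate gradients standing in for the derivative of the non-differentiable spike function.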
- Parameters:
ctx (context) – Context object with saved tensors.
grad_spikes (torch.Tensor) – Gradient of loss w.r.t. output spikes.
grad_v_final (torch.Tensor) – Gradient of loss w.r.t. final membrane potentials.
grad_v_seq (torch.Tensor) – Gradient of loss w.r.t. voltage sequence (unused).
- Returns:
Gradients w.r.t. (x_seq, v_init, decay, threshold, alpha, surrogate_type).
- Return type:
tuple
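The backward pass relies on a surrogate derivative in place of the Heaviside step's true derivative, which is zero almost everywhere. A minimal sketch of that mechanism in plain PyTorch follows, using textbook forms of the three surrogates named in the PLIFSubFunction docs. It assumes the same ID mapping (0: Arctan, 1: Sigmoid, 2: FastSigmoid) holds for LIFSubFunction. The exact formulas and scaling used by the Triton kernels are not documented here, and `surrogate_grad` and `SurrogateSpike` are hypothetical names.

```python
import math

import torch

# Surrogate IDs as listed for PLIFSubFunction (0: Arctan, 1: Sigmoid, 2: FastSigmoid);
# assumed, not confirmed, to apply to LIFSubFunction as well.
ARCTAN, SIGMOID, FAST_SIGMOID = 0, 1, 2


def surrogate_grad(u: torch.Tensor, alpha: float, surrogate_type: int) -> torch.Tensor:
    """Textbook surrogate derivatives of the spike w.r.t. u = v - threshold (assumed forms)."""
    if surrogate_type == ARCTAN:
        # d/du [atan(pi/2 * alpha * u) / pi + 1/2]
        return (alpha / 2) / (1 + (math.pi / 2 * alpha * u) ** 2)
    if surrogate_type == SIGMOID:
        # d/du [sigmoid(alpha * u)]
        s = torch.sigmoid(alpha * u)
        return alpha * s * (1 - s)
    if surrogate_type == FAST_SIGMOID:
        # SuperSpike-style fast sigmoid: 1 / (alpha * |u| + 1)^2
        return 1.0 / (alpha * u.abs() + 1) ** 2
    raise ValueError(f"unknown surrogate_type: {surrogate_type}")


class SurrogateSpike(torch.autograd.Function):
    """Heaviside step in the forward pass, surrogate derivative in the backward pass."""

    @staticmethod
    def forward(ctx, u, alpha, surrogate_type):
        ctx.save_for_backward(u)
        ctx.alpha = alpha
        ctx.surrogate_type = surrogate_type
        return (u >= 0).to(u.dtype)

    @staticmethod
    def backward(ctx, grad_out):
        (u,) = ctx.saved_tensors
        grad_u = grad_out * surrogate_grad(u, ctx.alpha, ctx.surrogate_type)
        # alpha (a Python float in this sketch) and surrogate_type (an int) get no gradient
        return grad_u, None, None


u = torch.tensor([-1.0, 0.0, 1.0], requires_grad=True)
spikes = SurrogateSpike.apply(u, 2.0, ARCTAN)
spikes.sum().backward()
# At the threshold (u = 0) the Arctan surrogate equals alpha / 2, so u.grad[1] is 1.0
```

The forward output stays an exact step, while the backward returns a smooth bump centred on the threshold, so gradients reach neurons whose potential is close to firing. A larger `alpha` narrows the bump for all three forms; for Arctan and Sigmoid it also raises the peak, which is alpha / 2 and alpha / 4 respectively.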
- static forward(ctx, x_seq, v_init, decay, threshold, alpha, surrogate_type)
Forward pass of the Leaky Integrate-and-Fire (LIF) function.
Uses fused Triton kernels for high-performance execution.
- Parameters:
ctx (context) – Context object for saving tensors for the backward pass.
x_seq (torch.Tensor) – Input spike/current sequence. Shape: (n_steps, batch_size * n_neurons).
v_init (torch.Tensor) – Initial membrane potentials. Shape: (batch_size * n_neurons,).
decay (torch.Tensor) – Decay factor (scalar).
threshold (torch.Tensor) – Spiking threshold (scalar).
alpha (torch.Tensor) – Surrogate gradient scaling parameter (scalar).
surrogate_type (int) – Integer ID representing the surrogate gradient function.
- Returns:
spikes (torch.Tensor): Output spike sequence. Shape same as x_seq.
v_final (torch.Tensor): Final membrane potentials. Shape same as v_init.
v_seq (torch.Tensor): Membrane potential sequence. Shape same as x_seq.
- Return type:
tuple
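For readability and testing, the recurrence that the fused kernels compute can be written as an explicit loop over time. The sketch below is not the Tether implementation. It assumes v_t = decay * v_{t-1} + x_t, a spike whenever v_t reaches the threshold, a soft reset that subtracts the threshold, and that v_seq records the pre-reset potential. Other LIF conventions (for example a hard reset to zero, or scaling the input by 1 - decay) would change those lines. `lif_forward_reference` is a hypothetical name.

```python
import torch


@torch.no_grad()
def lif_forward_reference(x_seq, v_init, decay, threshold):
    """Sequential (non-fused) LIF forward under assumed dynamics.

    x_seq: (n_steps, batch_size * n_neurons); v_init: (batch_size * n_neurons,).
    decay / threshold: scalars here; anything broadcastable to v_init also works.
    """
    v = v_init.clone()
    spikes = torch.empty_like(x_seq)
    v_seq = torch.empty_like(x_seq)
    for t in range(x_seq.shape[0]):
        v = decay * v + x_seq[t]               # assumed leaky integration
        s = (v >= threshold).to(x_seq.dtype)   # Heaviside spike
        v_seq[t] = v                           # assumed: record the pre-reset potential
        spikes[t] = s
        v = v - s * threshold                  # assumed: soft reset by subtraction
    return spikes, v, v_seq


# One neuron, constant input of 0.6 for three steps
x_seq = torch.full((3, 1), 0.6)
spikes, v_final, v_seq = lif_forward_reference(x_seq, torch.zeros(1), decay=0.5, threshold=1.0)
```

With decay 0.5 and threshold 1.0 the potential evolves 0.6, 0.9, 1.05, so the neuron spikes only at the third step and ends near 0.05 after the soft reset. The returned shapes follow the documented contract: spikes and v_seq match x_seq, and v_final matches v_init.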
tether.functional.plif module
- class tether.functional.plif.PLIFSubFunction(*args, **kwargs)
Bases: Function
- static backward(ctx, grad_spikes, grad_v_final, grad_v_seq)
Backward pass of the PLIF function.
Computes gradients for inputs and parameters (decay, threshold, alpha) using the fused Triton backward kernel.
- Parameters:
ctx (context) – Context object with saved tensors.
grad_spikes (torch.Tensor) – Gradient of loss w.r.t. output spikes.
grad_v_final (torch.Tensor) – Gradient of loss w.r.t. final membrane potentials.
grad_v_seq (torch.Tensor) – Gradient of loss w.r.t. membrane potential sequence (usually None or unused).
- Returns:
Gradients w.r.t. (x_seq, v_init, decay, threshold, alpha, surrogate_type). Note: surrogate_type gradient is None.
- Return type:
tuple
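The reverse-time recurrence that a BPTT backward pass evaluates can be sketched in plain PyTorch and checked against autograd. This is not the fused Triton kernel. It assumes the dynamics v_pre = decay * v + x, a spike when v_pre reaches the threshold, and a soft reset v = v_pre - spike * threshold. It uses only the sigmoid surrogate with `alpha` held fixed, so unlike the documented backward it returns no gradient for alpha, and, matching the docs, it ignores grad_v_seq. `plif_forward`, `plif_backward_bptt` and `SpikeFn` are hypothetical names.

```python
import torch


def sigmoid_surrogate(u, alpha):
    """Derivative of sigmoid(alpha * u); assumed stand-in for the 'Sigmoid' surrogate (ID 1)."""
    s = torch.sigmoid(alpha * u)
    return alpha * s * (1 - s)


class SpikeFn(torch.autograd.Function):
    """Heaviside forward, sigmoid-surrogate backward."""

    @staticmethod
    def forward(ctx, u, alpha):
        ctx.save_for_backward(u)
        ctx.alpha = alpha
        return (u >= 0).to(u.dtype)

    @staticmethod
    def backward(ctx, grad_out):
        (u,) = ctx.saved_tensors
        return grad_out * sigmoid_surrogate(u, ctx.alpha), None


def plif_forward(x_seq, v_init, decay, threshold, alpha):
    """Assumed dynamics with per-neuron decay/threshold vectors and a soft reset."""
    v, spikes, v_pre_seq = v_init, [], []
    for x in x_seq:
        v_pre = decay * v + x
        s = SpikeFn.apply(v_pre - threshold, alpha)
        v = v_pre - s * threshold
        spikes.append(s)
        v_pre_seq.append(v_pre)
    return torch.stack(spikes), v, torch.stack(v_pre_seq)


@torch.no_grad()
def plif_backward_bptt(v_init, decay, threshold, alpha, spikes, v_pre_seq, grad_spikes, grad_v_final):
    """Manual reverse-time sweep; returns grads for (x_seq, v_init, decay, threshold)."""
    grad_x = torch.empty_like(v_pre_seq)
    grad_decay = torch.zeros_like(decay)
    grad_th = torch.zeros_like(threshold)
    g_v = grad_v_final.clone()  # dL/dv_t for the post-reset potential
    for t in reversed(range(v_pre_seq.shape[0])):
        sg = sigmoid_surrogate(v_pre_seq[t] - threshold, alpha)
        g_s = grad_spikes[t] - g_v * threshold   # spike feeds the loss and the soft reset
        g_pre = g_v + g_s * sg                   # dL/dv_pre_t
        grad_th += -g_s * sg - g_v * spikes[t]   # threshold enters the spike and the reset
        v_prev = v_init if t == 0 else v_pre_seq[t - 1] - spikes[t - 1] * threshold
        grad_decay += g_pre * v_prev             # per-neuron sum over time steps
        grad_x[t] = g_pre
        g_v = g_pre * decay                      # carry the gradient back to v_{t-1}
    return grad_x, g_v, grad_decay, grad_th


torch.manual_seed(0)
T, N, alpha = 5, 4, 4.0
x_seq = torch.randn(T, N, dtype=torch.float64, requires_grad=True)
v_init = torch.zeros(N, dtype=torch.float64, requires_grad=True)
decay = torch.full((N,), 0.8, dtype=torch.float64, requires_grad=True)
threshold = torch.full((N,), 0.5, dtype=torch.float64, requires_grad=True)

spikes, v_final, v_pre_seq = plif_forward(x_seq, v_init, decay, threshold, alpha)
grad_spikes = torch.randn_like(spikes)
grad_v_final = torch.randn_like(v_final)

# Reference: autograd also performs BPTT through the unrolled loop
auto = torch.autograd.grad((spikes, v_final), (x_seq, v_init, decay, threshold),
                           (grad_spikes, grad_v_final))
manual = plif_backward_bptt(v_init, decay, threshold, alpha, spikes, v_pre_seq,
                            grad_spikes, grad_v_final)
```

Because decay and threshold are vectors, their gradients keep shape (n_total_neurons,): each entry accumulates contributions over all time steps for its own neuron, which is what lets PLIF learn a separate time constant and threshold per neuron.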
- static forward(ctx, x_seq, v_init, decay, threshold, alpha, surrogate_type)
Forward pass of the Parametric LIF (PLIF) function, with per-neuron (vector-valued) decay and threshold.
- Parameters:
ctx (context) – Context object for saving tensors for backward pass.
x_seq (torch.Tensor) – Input spike/current sequence. Shape: (n_steps, batch_size * n_neurons) or (n_steps, n_total_neurons).
v_init (torch.Tensor) – Initial membrane potentials. Shape: (n_total_neurons,).
decay (torch.Tensor) – Decay factor vector. Shape: (n_total_neurons,).
threshold (torch.Tensor) – Threshold vector. Shape: (n_total_neurons,).
alpha (torch.Tensor) – Surrogate gradient scaling parameter (scalar).
surrogate_type (int) – Integer ID representing the surrogate gradient function (0: Arctan, 1: Sigmoid, 2: FastSigmoid).
- Returns:
spikes (torch.Tensor): Output spike sequence. Shape same as x_seq.
v_final (torch.Tensor): Final membrane potentials. Shape same as v_init.
v_seq (torch.Tensor): Membrane potential sequence. Shape same as x_seq.
- Return type:
tuple
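The PLIF forward takes flat tensors whose decay and threshold vectors have one entry per flattened neuron. The docs list both (n_steps, batch_size * n_neurons) and (n_steps, n_total_neurons) but do not state how they relate. Assuming n_total_neurons equals batch_size * n_neurons for batched input, per-neuron parameters of shape (n_neurons,) must be tiled across the batch in the same order as the flattened input. The hypothetical helper below shows one way to do that; `flatten_for_plif` is not part of the package, and passing v_init as (batch_size, n_neurons) is also an assumption.

```python
import torch


def flatten_for_plif(x, v_init, decay, threshold):
    """Hypothetical helper: (n_steps, B, N) input and (N,) params -> flat PLIF layout.

    Assumes n_total_neurons == B * N with flat index b * N + n (row-major reshape).
    """
    n_steps, batch_size, n_neurons = x.shape
    x_seq = x.reshape(n_steps, batch_size * n_neurons)
    v_flat = v_init.reshape(batch_size * n_neurons)   # v_init assumed given as (B, N)
    # repeat() tiles [p_0, ..., p_{N-1}] B times, matching the b * N + n ordering
    decay_flat = decay.repeat(batch_size)
    threshold_flat = threshold.repeat(batch_size)
    return x_seq, v_flat, decay_flat, threshold_flat


torch.manual_seed(0)
T, B, N = 4, 3, 5
x = torch.randn(T, B, N)
v_init = torch.randn(B, N)
decay = torch.rand(N, requires_grad=True)       # learnable per-neuron decay
threshold = torch.ones(N, requires_grad=True)   # learnable per-neuron threshold

x_seq, v_flat, decay_flat, threshold_flat = flatten_for_plif(x, v_init, decay, threshold)

# One leaky-integration step in the flat layout matches the batched (B, N) computation
v1_flat = decay_flat * v_flat + x_seq[0]
v1_batched = decay * v_init + x[0]

# Gradients flowing back through repeat() are summed over the batch, one entry per neuron
v1_flat.sum().backward()
```

Because `repeat` is differentiable, a learnable (n_neurons,) parameter still receives one gradient entry per neuron. The flat tensors would then be passed in the documented argument order, e.g. `PLIFSubFunction.apply(x_seq, v_flat, decay_flat, threshold_flat, alpha, surrogate_type)`; that call needs the package's Triton kernels and is not exercised here.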