tether.functional package

Submodules

tether.functional.alif module

class tether.functional.alif.ALIFSubFunction(*args, **kwargs)

Bases: torch.autograd.Function

static backward(ctx, grad_spikes, grad_v_final, grad_a_final, grad_v_seq, grad_a_seq)

Backward pass of the Adaptive Leaky Integrate-and-Fire (ALIF) function.

static forward(ctx, x_seq, v_init, a_init, decay_v, decay_a, threshold, beta, alpha)

Forward pass of the ALIF function.
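The ALIF signature lists its parameters without describing the dynamics. As a rough guide, the plain-Python sketch below traces one neuron under a common ALIF formulation. The update order, the `>=` spike condition, the subtractive reset, and the roles of `beta` (adaptation strength) and `decay_a` (adaptation decay) are all assumptions, and `alif_forward_reference` is a hypothetical helper, not part of tether. `alpha` is omitted on the assumption that, as in the LIF and PLIF functions, it only scales the surrogate gradient in the backward pass.

```python
from typing import List, Tuple


def alif_forward_reference(
    x_seq: List[float],
    v_init: float,
    a_init: float,
    decay_v: float,
    decay_a: float,
    threshold: float,
    beta: float,
) -> Tuple[List[float], float, float, List[float], List[float]]:
    """Hypothetical single-neuron ALIF recurrence (not tether's Triton kernel)."""
    v, a = v_init, a_init
    spikes, v_seq, a_seq = [], [], []
    for x in x_seq:
        v = decay_v * v + x               # leaky integration of the input
        thr_eff = threshold + beta * a    # assumed adaptive threshold
        s = 1.0 if v >= thr_eff else 0.0  # Heaviside spike
        v = v - s * thr_eff               # assumed soft (subtractive) reset
        a = decay_a * a + s               # adaptation trace decays, grows on each spike
        spikes.append(s)
        v_seq.append(v)
        a_seq.append(a)
    # Mirrors the five outputs implied by backward: spikes, v_final, a_final, v_seq, a_seq
    return spikes, v, a, v_seq, a_seq


spikes, v_final, a_final, v_seq, a_seq = alif_forward_reference(
    [1.0] * 5, 0.0, 0.0, 0.5, 0.9, 1.0, 0.5
)
print(spikes)  # [1.0, 0.0, 1.0, 0.0, 0.0]
```

In this trace the first spike raises the effective threshold to 1.5, so the neuron stays silent at step 1, where a non-adaptive neuron with the same decay and threshold would have fired again.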

tether.functional.lif module

class tether.functional.lif.LIFSubFunction(*args, **kwargs)

Bases: torch.autograd.Function

static backward(ctx, grad_spikes, grad_v_final, grad_v_seq)

Backward pass of the LIF function.

Computes gradients for the inputs and parameters using backpropagation through time (BPTT) with surrogate gradients.

Parameters:
  • ctx (context) – Context object with saved tensors.

  • grad_spikes (torch.Tensor) – Gradient of loss w.r.t. output spikes.

  • grad_v_final (torch.Tensor) – Gradient of loss w.r.t. final membrane potentials.

  • grad_v_seq (torch.Tensor) – Gradient of loss w.r.t. voltage sequence (unused).

Returns:

Gradients w.r.t. (x_seq, v_init, decay, threshold, alpha, surrogate_type). Note: surrogate_type gradient is None.

Return type:

tuple
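To make the BPTT description concrete, here is a plain-Python sketch for a single neuron with scalar parameters. It is an illustration only: the soft reset (kept inside the graph rather than detached), the `>=` spike test, and the arctan surrogate with derivative alpha / (2 * (1 + (pi/2 * alpha * x)^2)) are assumptions, and `lif_bptt_reference` is a hypothetical helper rather than tether's Triton kernel. Like the documented backward, it ignores `grad_v_seq`.

```python
import math
from typing import List, Tuple


def atan_surrogate_grad(x: float, alpha: float) -> float:
    """Assumed arctan surrogate derivative, used in place of dH/dx."""
    return alpha / (2.0 * (1.0 + (math.pi / 2.0 * alpha * x) ** 2))


def lif_bptt_reference(
    x_seq: List[float],
    v_init: float,
    decay: float,
    threshold: float,
    alpha: float,
    grad_spikes: List[float],
    grad_v_final: float,
) -> Tuple[List[float], float, float, float]:
    """Hypothetical BPTT for one LIF neuron; returns grads for x_seq, v_init, decay, threshold."""
    # Forward pass: record the pre-reset potential u_t and the previous potential v_{t-1}.
    v = v_init
    u_hist, v_prev_hist, s_hist = [], [], []
    for x in x_seq:
        v_prev_hist.append(v)
        u = decay * v + x
        s = 1.0 if u >= threshold else 0.0
        v = u - s * threshold  # assumed soft reset, kept in the graph (not detached)
        u_hist.append(u)
        s_hist.append(s)

    # Backward pass: walk the time steps in reverse order.
    grad_x = [0.0] * len(x_seq)
    grad_decay = 0.0
    grad_threshold = 0.0
    dv = grad_v_final  # dL/dv_t, starting from the final (post-reset) potential
    for t in reversed(range(len(x_seq))):
        sg = atan_surrogate_grad(u_hist[t] - threshold, alpha)
        ds = grad_spikes[t] - dv * threshold  # spike feeds the loss and the reset
        du = dv + ds * sg                     # surrogate stands in for the step derivative
        grad_threshold += -dv * s_hist[t] - ds * sg
        grad_x[t] = du
        grad_decay += du * v_prev_hist[t]
        dv = du * decay                       # propagate to v_{t-1}
    return grad_x, dv, grad_decay, grad_threshold


grad_x, grad_v_init, grad_decay, grad_threshold = lif_bptt_reference(
    [1.0], 0.0, 0.5, 1.0, 2.0, grad_spikes=[1.0], grad_v_final=0.0
)
print(grad_x, grad_threshold)  # [1.0] -1.0
```

Because `decay` and `threshold` are shared across time steps, their gradients accumulate over every step; in the batched function a scalar parameter would also accumulate over all `batch_size * n_neurons` entries.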

static forward(ctx, x_seq, v_init, decay, threshold, alpha, surrogate_type)

Forward pass of the Leaky Integrate-and-Fire (LIF) function.

Uses fused Triton kernels for high-performance execution.

Parameters:
  • ctx (context) – Context object for saving tensors for the backward pass.

  • x_seq (torch.Tensor) – Input spike/current sequence. Shape: (n_steps, batch_size * n_neurons).

  • v_init (torch.Tensor) – Initial membrane potentials. Shape: (batch_size * n_neurons,).

  • decay (torch.Tensor) – Decay factor (scalar).

  • threshold (torch.Tensor) – Spiking threshold (scalar).

  • alpha (torch.Tensor) – Surrogate gradient scaling parameter (scalar).

  • surrogate_type (int) – Integer ID representing the surrogate gradient function.

Returns:

  • spikes (torch.Tensor): Output spike sequence. Shape same as x_seq.

  • v_final (torch.Tensor): Final membrane potentials. Shape same as v_init.

  • v_seq (torch.Tensor): Membrane potential sequence. Shape same as x_seq.

Return type:

tuple
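A loop-based reference for the documented shapes can help when checking the fused kernel's outputs. This sketch assumes a soft (subtractive) reset, a `>=` threshold test, and that `v_seq` records the post-reset potential; none of these details are stated above. `lif_forward_reference` is a hypothetical helper operating on nested lists shaped `(n_steps, batch_size * n_neurons)`, and `alpha`/`surrogate_type` are left out because they only affect the backward pass.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def lif_forward_reference(
    x_seq: Matrix, v_init: List[float], decay: float, threshold: float
) -> Tuple[Matrix, List[float], Matrix]:
    """Hypothetical LIF forward; returns (spikes, v_final, v_seq) with the documented shapes."""
    v = list(v_init)  # copy so the caller's v_init is not modified
    spikes, v_seq = [], []
    for x_t in x_seq:  # one row per time step
        s_t, v_t = [], []
        for i, x in enumerate(x_t):  # one column per flattened neuron
            v[i] = decay * v[i] + x
            s = 1.0 if v[i] >= threshold else 0.0
            v[i] -= s * threshold  # assumed soft reset
            s_t.append(s)
            v_t.append(v[i])
        spikes.append(s_t)
        v_seq.append(v_t)
    return spikes, v, v_seq


# n_steps = 3, batch_size = 2, n_neurons = 2 -> 4 flattened neurons
x_seq = [[0.6, 1.2, 0.0, 0.3] for _ in range(3)]
spikes, v_final, v_seq = lif_forward_reference(x_seq, [0.0] * 4, decay=0.5, threshold=1.0)
print(spikes)  # [[0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]
```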

tether.functional.plif module

class tether.functional.plif.PLIFSubFunction(*args, **kwargs)

Bases: torch.autograd.Function

static backward(ctx, grad_spikes, grad_v_final, grad_v_seq)

Backward pass of the Parametric Leaky Integrate-and-Fire (PLIF) function.

Computes gradients for inputs and parameters (decay, threshold, alpha) using the fused Triton backward kernel.

Parameters:
  • ctx (context) – Context object with saved tensors.

  • grad_spikes (torch.Tensor) – Gradient of loss w.r.t. output spikes.

  • grad_v_final (torch.Tensor) – Gradient of loss w.r.t. final membrane potentials.

  • grad_v_seq (torch.Tensor) – Gradient of loss w.r.t. membrane potential sequence (usually None or unused).

Returns:

Gradients w.r.t. (x_seq, v_init, decay, threshold, alpha, surrogate_type). Note: surrogate_type gradient is None.

Return type:

tuple
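The backward kernel evaluates a surrogate derivative in place of the Heaviside step's derivative, selected by `surrogate_type`. The sketch below maps the IDs documented for `PLIFSubFunction.forward` (0: Arctan, 1: Sigmoid, 2: FastSigmoid) to commonly used derivative formulas, evaluated at x = v - threshold. The exact formulas and how tether scales them by `alpha` are assumptions, and all helper names are hypothetical.

```python
import math
from typing import Callable, Dict


def arctan_grad(x: float, alpha: float) -> float:
    # d/dx [atan(pi/2 * alpha * x) / pi + 1/2]
    return alpha / (2.0 * (1.0 + (math.pi / 2.0 * alpha * x) ** 2))


def _sigmoid(z: float) -> float:
    # Numerically stable logistic function (avoids overflow in math.exp).
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def sigmoid_grad(x: float, alpha: float) -> float:
    # d/dx sigmoid(alpha * x)
    s = _sigmoid(alpha * x)
    return alpha * s * (1.0 - s)


def fast_sigmoid_grad(x: float, alpha: float) -> float:
    # A common fast-sigmoid surrogate derivative: 1 / (1 + alpha * |x|)^2
    return 1.0 / (1.0 + alpha * abs(x)) ** 2


# Hypothetical lookup keyed by the documented surrogate_type IDs.
SURROGATE_GRADS: Dict[int, Callable[[float, float], float]] = {
    0: arctan_grad,
    1: sigmoid_grad,
    2: fast_sigmoid_grad,
}


def surrogate_grad(surrogate_type: int, x: float, alpha: float) -> float:
    """Dispatch on surrogate_type; raises ValueError for unknown IDs."""
    try:
        fn = SURROGATE_GRADS[surrogate_type]
    except KeyError:
        raise ValueError(f"unknown surrogate_type: {surrogate_type}") from None
    return fn(x, alpha)


print([surrogate_grad(t, 0.0, 2.0) for t in (0, 1, 2)])  # [1.0, 0.5, 1.0]
```

All three derivatives peak when the membrane potential sits exactly at the threshold and fall off symmetrically on either side; `alpha` controls how sharply they fall off.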

static forward(ctx, x_seq, v_init, decay, threshold, alpha, surrogate_type)

Forward pass of the PLIF function with per-neuron (vector-valued) decay and threshold.

Parameters:
  • ctx (context) – Context object for saving tensors for backward pass.

  • x_seq (torch.Tensor) – Input spike/current sequence. Shape: (n_steps, batch_size * n_neurons) or (n_steps, n_total_neurons).

  • v_init (torch.Tensor) – Initial membrane potentials. Shape: (n_total_neurons,).

  • decay (torch.Tensor) – Decay factor vector. Shape: (n_total_neurons,).

  • threshold (torch.Tensor) – Threshold vector. Shape: (n_total_neurons,).

  • alpha (torch.Tensor) – Surrogate gradient scaling parameter (scalar).

  • surrogate_type (int) – Integer ID representing the surrogate gradient function (0: Arctan, 1: Sigmoid, 2: FastSigmoid).

Returns:

  • spikes (torch.Tensor): Output spike sequence. Shape same as x_seq.

  • v_final (torch.Tensor): Final membrane potentials. Shape same as v_init.

  • v_seq (torch.Tensor): Membrane potential sequence. Shape same as x_seq.

Return type:

tuple
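When `x_seq` is laid out as `(n_steps, batch_size * n_neurons)`, per-neuron vectors such as `decay` and `threshold` need one entry per flattened neuron. A sketch, assuming the flattening is row-major over `(batch_size, n_neurons)` so that flat index `b * n_neurons + i` belongs to neuron `i` of sample `b`; the helper names are hypothetical. In PyTorch the tiling step would correspond to `decay.repeat(batch_size)` on a 1-D tensor, with autograd performing the gradient reduction.

```python
from typing import List, Sequence


def tile_per_neuron(param: Sequence[float], batch_size: int) -> List[float]:
    """Expand a (n_neurons,) vector to (batch_size * n_neurons,), batch-major."""
    return [float(p) for _ in range(batch_size) for p in param]


def reduce_tiled_grad(grad_flat: Sequence[float], n_neurons: int) -> List[float]:
    """Sum the gradients of the tiled copies back onto the shared per-neuron parameter."""
    if n_neurons <= 0 or len(grad_flat) % n_neurons != 0:
        raise ValueError("len(grad_flat) must be a multiple of a positive n_neurons")
    out = [0.0] * n_neurons
    for idx, g in enumerate(grad_flat):
        out[idx % n_neurons] += g  # flat index b * n_neurons + i -> neuron i
    return out


decay = [0.9, 0.8]  # one decay factor per neuron (n_neurons = 2)
print(tile_per_neuron(decay, batch_size=3))  # [0.9, 0.8, 0.9, 0.8, 0.9, 0.8]
print(reduce_tiled_grad([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], n_neurons=2))  # [9.0, 12.0]
```

Summing over the batch follows from the chain rule: a parameter shared by every sample receives the total of its copies' gradients.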

Module contents