Welcome to Tether’s documentation!

Tether

Tether is a Triton-powered framework for training and deploying Spiking Transformers and deep Spiking Neural Networks (SNNs).

We handle the non-differentiability of discrete spikes with high-performance Triton kernels and modular surrogate gradients.

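Because the spike function is a hard threshold, its true derivative is zero almost everywhere; surrogate gradients replace it with a smooth approximation in the backward pass only. Below is a minimal plain-PyTorch sketch of that idea using an arctan-shaped surrogate. The class name and constants are illustrative, not Tether's API; Tether fuses this logic directly into its Triton kernels.

import torch

class ArctanSpike(torch.autograd.Function):
    # Hard Heaviside step in forward, smooth arctan-shaped derivative in backward.

    @staticmethod
    def forward(ctx, v, alpha):
        ctx.save_for_backward(v)
        ctx.alpha = alpha
        # Spike if the membrane potential crosses the threshold (threshold folded into v)
        return (v >= 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (v,) = ctx.saved_tensors
        alpha = ctx.alpha
        # Surrogate derivative: alpha / (2 * (1 + (pi/2 * alpha * v)^2))
        surrogate = alpha / (2 * (1 + (torch.pi / 2 * alpha * v) ** 2))
        return grad_output * surrogate, None

# Usage: spikes = ArctanSpike.apply(membrane - threshold, 2.0)
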
Key Features

  • High-Performance Neurons:

    • LIF (Leaky Integrate-and-Fire): Standard spiking neuron with fused Triton kernels.

    • ALIF (Adaptive LIF): Neurons with adaptive thresholds for better temporal dynamics.

    • PLIF (Parametric LIF): Neurons with learnable, per-channel decay and threshold parameters.

  • Modular Surrogate Gradients: Choose from Arctan, Sigmoid, or FastSigmoid to train your SNNs effectively.

  • Linear Spike-Driven Attention: Eliminates the $O(N^2)$ Softmax bottleneck, allowing for massive context windows with significantly lower energy per inference; a sketch of the idea follows this list.

  • Data Utilities: SpikingDatasetWrapper and encoding functions (rate_encoding, latency_encoding) to convert static datasets to spike trains.

  • Triton-Powered: Leverages OpenAI’s Triton language for custom GPU kernels, enabling 60x+ speedups over vanilla PyTorch.

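To make the attention claim concrete, here is a hedged plain-PyTorch sketch of spike-driven linear attention: because queries and keys are binary spike tensors, the softmax can be dropped and the matrix products reordered so that no N×N attention matrix is ever materialized. Function and variable names here are illustrative, not Tether's SpikingSelfAttention API.

import torch

def spike_linear_attention(q_spikes, k_spikes, v):
    # q_spikes, k_spikes: (Batch, N, d) binary spike tensors; v: (Batch, N, d) real-valued
    # K^T V is a small (d x d) summary, so the cost is O(N * d^2) instead of O(N^2 * d).
    kv = torch.einsum("bnd,bne->bde", k_spikes, v)
    out = torch.einsum("bnd,bde->bne", q_spikes, kv)
    # Normalize each query by how much key mass it attends to (avoid division by zero).
    denom = torch.einsum("bnd,bd->bn", q_spikes, k_spikes.sum(dim=1)).clamp(min=1.0)
    return out / denom.unsqueeze(-1)

# Example: 1024-token sequences, 64-dim heads, ~20% spike density.
q = (torch.rand(2, 1024, 64) > 0.8).float()
k = (torch.rand(2, 1024, 64) > 0.8).float()
v = torch.randn(2, 1024, 64)
y = spike_linear_attention(q, k, v)  # (2, 1024, 64); no 1024 x 1024 matrix is built
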
Installation

This project is managed with uv.

uv sync

Or install dependencies manually:

pip install torch triton numpy

Usage

Using PLIF with Sigmoid Surrogate

import torch
from tether import PLIF, Sigmoid

# Create a Parametric LIF layer with Sigmoid surrogate
# Decay and threshold are learnable vectors per neuron
layer = PLIF(
    n_neurons=128,
    init_decay=0.9,
    surrogate=Sigmoid(alpha=4.0)
).cuda()

# Input sequence: (Time, Batch, Neurons)
x = torch.randn(32, 16, 128).cuda()
spikes = layer(x)
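# spikes is expected to be a spike train with the same (Time, Batch, Neurons) shape as x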

Training a Spiking Language Model

The train_stories.py script demonstrates training a Spiking-LLM on the TinyShakespeare dataset.

python train_stories.py

Data Encoding

from tether.data import SpikingDatasetWrapper, rate_encoding
from torchvision.datasets import MNIST

# Wrap MNIST to output spike trains
spiking_mnist = SpikingDatasetWrapper(
    MNIST(root="./data", download=True, train=True),
    encode_fn=lambda x: rate_encoding(x, n_steps=10)
)
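
The wrapped dataset can then be consumed by a standard DataLoader. A minimal sketch, assuming each item yields a (spike_train, label) pair:

from torch.utils.data import DataLoader

loader = DataLoader(spiking_mnist, batch_size=64, shuffle=True)
spike_batch, labels = next(iter(loader))
# Each sample carries an extra time dimension of length n_steps=10
# (the exact axis order depends on the encoding function).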

Architecture

  • tether.kernels: Custom Triton kernels for LIF, ALIF, and PLIF (the sketch after this list illustrates the kernel-plus-wrapper pattern).

  • tether.functional: PyTorch autograd functions wrapping the Triton kernels.

  • tether.nn: Neural network modules including LIF, ALIF, PLIF, SpikingSelfAttention.

  • tether.data: Utilities for spike encoding and dataset wrapping.

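To make the layering concrete, here is a hedged sketch of the kernel-plus-wrapper pattern: an elementwise Triton kernel performing one LIF membrane update, with a thin Python launcher. The kernel name, signature, and constants are illustrative, not Tether's actual kernels; in the real package, tether.functional wraps such launchers in autograd functions that substitute a surrogate gradient for the hard threshold, and tether.nn exposes them as modules.

import torch
import triton
import triton.language as tl

@triton.jit
def lif_step_kernel(x_ptr, v_ptr, spk_ptr, v_out_ptr,
                    decay, threshold, n_elements,
                    BLOCK_SIZE: tl.constexpr):
    # One leaky integrate-and-fire update for a single timestep, elementwise over neurons.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    v = tl.load(v_ptr + offsets, mask=mask)
    v_new = decay * v + x                          # leaky integration
    spk = tl.where(v_new >= threshold, 1.0, 0.0)   # hard threshold (Heaviside)
    tl.store(spk_ptr + offsets, spk, mask=mask)
    tl.store(v_out_ptr + offsets, v_new * (1.0 - spk), mask=mask)  # reset where a spike fired

def lif_step(x, v, decay=0.9, threshold=1.0):
    # Launch the kernel over all neurons for one timestep; returns (spikes, new membrane).
    x, v = x.contiguous(), v.contiguous()
    spk, v_out = torch.empty_like(x), torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    lif_step_kernel[grid](x, v, spk, v_out, decay, threshold, n, BLOCK_SIZE=1024)
    return spk, v_out
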
License

Apache-2.0
