Slax: a composable JAX library for rapid and flexible prototyping of spiking neural networks
Main Authors: | , |
---|---|
Format: | Article |
Language: | English |
Published: | IOP Publishing, 2025-01-01 |
Series: | Neuromorphic Computing and Engineering |
Subjects: | |
Online Access: | https://doi.org/10.1088/2634-4386/ada9a8 |
Summary: | Spiking neural networks (SNNs) offer rich temporal dynamics and unique capabilities, but their training presents challenges. While backpropagation through time with surrogate gradients is the de facto standard for training SNNs, it scales poorly with long time sequences. Alternative learning rules and algorithms could help further develop models and systems across the spectrum of performance, bio-plausibility, and complexity. However, these alternatives are not consistently implemented in the same SNN framework, if in any, which often complicates their comparison and use. To address this, we introduce Slax, a JAX-based library designed to accelerate SNN algorithm design and evaluation. Slax is compatible with the broader JAX and Flax ecosystem and provides optimized implementations of diverse training algorithms, enabling direct performance comparisons. Its toolkit includes methods to visualize and debug algorithms through loss landscapes, gradient similarities, and other metrics of model behavior during training. By streamlining the implementation and evaluation of novel SNN learning algorithms, Slax aims to facilitate research and development in this promising field. |
---|---|
ISSN: | 2634-4386 |
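
The summary names backpropagation through time with surrogate gradients as the standard training method for SNNs. The following is a minimal sketch of that idea in plain JAX, not Slax's API: every name here (`spike`, `lif_step`, `spike_count`, `loss`) is hypothetical, and the leaky integrate-and-fire dynamics, the decay of 0.9, the threshold of 1.0, and the fast-sigmoid-style surrogate with slope 10.0 are illustrative assumptions rather than details taken from the article.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def spike(v):
    # Forward pass: Heaviside step, 1.0 where v is above zero, else 0.0.
    return (v > 0.0).astype(v.dtype)


@spike.defjvp
def spike_jvp(primals, tangents):
    (v,), (dv,) = primals, tangents
    # Backward pass: replace the zero-almost-everywhere derivative of the
    # step with a smooth surrogate. The slope 10.0 is an arbitrary choice.
    surrogate = 1.0 / (1.0 + 10.0 * jnp.abs(v)) ** 2
    return spike(v), surrogate * dv


def lif_step(v, current, decay=0.9, threshold=1.0):
    # One step of a leaky integrate-and-fire neuron (assumed dynamics).
    v = decay * v + current
    s = spike(v - threshold)
    v = v - s * threshold  # soft reset: subtract the threshold after a spike
    return v, s


def spike_count(w, inputs):
    # inputs has shape (T, n_in); w has shape (n_in, n_out).
    v0 = jnp.zeros(w.shape[1])
    # lax.scan unrolls the recurrence over time, so jax.grad of anything
    # built on it performs backpropagation through time.
    _, spikes = jax.lax.scan(lambda v, x: lif_step(v, x @ w), v0, inputs)
    return spikes.sum()


def loss(w, inputs, target_count):
    # Toy objective: squared error between emitted and desired spike counts.
    return (spike_count(w, inputs) - target_count) ** 2
```

Calling `jax.grad(loss)(w, inputs, target_count)` then returns a gradient with the same shape as `w`. Because the whole sequence is differentiated at once, memory grows with the number of time steps, which is the scaling cost the summary attributes to this method.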