Slax: a composable JAX library for rapid and flexible prototyping of spiking neural networks


Bibliographic Details
Main Authors: Thomas M Summe, Siddharth Joshi
Format: Article
Language: English
Published: IOP Publishing 2025-01-01
Series: Neuromorphic Computing and Engineering
Subjects:
Online Access: https://doi.org/10.1088/2634-4386/ada9a8
Description
Summary: Spiking neural networks (SNNs) offer rich temporal dynamics and unique capabilities, but their training presents challenges. While backpropagation through time with surrogate gradients is the de facto standard for training SNNs, it scales poorly with long time sequences. Alternative learning rules and algorithms could help further develop models and systems across the spectrum of performance, bio-plausibility, and complexity. However, these alternatives are rarely implemented within a common SNN framework, if any, which complicates their comparison and use. To address this, we introduce Slax, a JAX-based library designed to accelerate SNN algorithm design and evaluation. Slax is compatible with the broader JAX and Flax ecosystem and provides optimized implementations of diverse training algorithms, enabling direct performance comparisons. Its toolkit includes methods to visualize and debug algorithms through loss landscapes, gradient similarities, and other metrics of model behavior during training. By streamlining the implementation and evaluation of novel SNN learning algorithms, Slax aims to facilitate research and development in this promising field.
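The surrogate-gradient technique the abstract refers to can be illustrated in plain JAX (this is a generic sketch of the idea, not Slax's actual API): the spike is a hard threshold in the forward pass, while the backward pass substitutes the derivative of a smooth "fast sigmoid" so gradients can propagate through the non-differentiable step.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def spike(v):
    # Forward pass: hard Heaviside threshold on the membrane potential.
    return (v > 0.0).astype(v.dtype)

def spike_fwd(v):
    # Save the pre-activation for use in the backward pass.
    return spike(v), v

def spike_bwd(v, g):
    # Backward pass: fast-sigmoid surrogate derivative 1 / (1 + |v|)^2
    # in place of the Heaviside's zero-almost-everywhere gradient.
    return (g / (1.0 + jnp.abs(v)) ** 2,)

spike.defvjp(spike_fwd, spike_bwd)

v = jnp.array([-0.5, 0.2, 1.0])
print(spike(v))                                # -> [0. 1. 1.]
print(jax.grad(lambda x: spike(x).sum())(v))   # nonzero surrogate gradients
```

Unrolling such a neuron over many time steps and differentiating the whole trajectory is exactly the backpropagation-through-time setup whose poor scaling with sequence length motivates the alternative learning rules the library collects.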
ISSN: 2634-4386