https://github.com/aaltoml/sfr
PyTorch implementation of Sparse Function-space Representation of Neural Networks
Science Score: 10.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ○ codemeta.json file
- ○ .zenodo.json file
- ○ DOI references
- ✓ Academic publication links
  Links to: arxiv.org
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity
  Low similarity (9.9%) to scientific vocabulary
Keywords
Repository
Basic Info
- Host: GitHub
- Owner: AaltoML
- License: mit
- Language: Jupyter Notebook
- Default Branch: main
- Homepage: https://aaltoml.github.io/sfr/
- Size: 100 MB
Statistics
- Stars: 4
- Watchers: 3
- Forks: 0
- Open Issues: 0
- Releases: 0
Topics
Metadata Files
README.md
SFR - Sparse Function-space Representation of Neural Networks
This repository contains a clean and minimal PyTorch implementation of Sparse Function-space Representation (SFR) of Neural Networks. If you'd like to use SFR, we recommend using this repo. To reproduce the experiments in the ICLR 2024 paper, see sfr-experiments.
Install
CPU
Create an environment with:
```sh
conda env create -f env_cpu.yaml
```
Activate the environment with:
```sh
source activate sfr
```
NVIDIA GPU
Create an environment with:
```sh
conda env create -f env_nvidia.yaml
```
Activate the environment with:
```sh
source activate sfr
```
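Once the environment is active, a quick sanity check can confirm that PyTorch is importable and whether it sees a GPU. The following is a minimal sketch, not part of the repository; the helper name `describe_torch_install` is hypothetical, and it relies only on `importlib.util.find_spec`, `torch.__version__`, and `torch.cuda.is_available()`.

```python
import importlib.util


def describe_torch_install() -> str:
    """Return a one-line summary of the PyTorch install in the active environment."""
    # find_spec returns None when the package cannot be found, so nothing raises here.
    if importlib.util.find_spec("torch") is None:
        return "torch is not installed in this environment"
    import torch  # imported lazily, only once we know the package is present

    device = "cuda" if torch.cuda.is_available() else "cpu"
    return f"torch {torch.__version__} (device: {device})"


if __name__ == "__main__":
    print(describe_torch_install())
```

With the CPU environment the reported device should be `cpu`; with the NVIDIA environment and a working driver it should be `cuda`.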
Usage
See the notebooks for how to use our code for both regression and classification.
Image Classification
We provide a minimal training script in train.py, which can be used to train a CNN and fit SFR
on MNIST/Fashion-MNIST/CIFAR-10. We recommend running it on a GPU.
Example
Here's a short example:
```python
import src
import torch

torch.set_default_dtype(torch.float64)


def func(x, noise=True):
    return torch.sin(x * 5) / x + torch.cos(x * 10)


# Toy data set
X_train = torch.rand((100, 1)) * 2
Y_train = func(X_train, noise=True)
data = (X_train, Y_train)

# Training config
width = 64
num_epochs = 1000
batch_size = 16
learning_rate = 1e-3
delta = 0.00005  # prior precision
data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(*data), batch_size=batch_size
)

# Create a neural network
network = torch.nn.Sequential(
    torch.nn.Linear(1, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, 1),
)

# Instantiate SFR (handles NN training/prediction as they're coupled via the prior/likelihood)
sfr = src.SFR(
    network=network,
    prior=src.priors.Gaussian(params=network.parameters, delta=delta),
    likelihood=src.likelihoods.Gaussian(sigma_noise=2),
    output_dim=1,
    num_inducing=32,
    dual_batch_size=None,  # this reduces the memory required for computing dual parameters
    jitter=1e-4,
)

# Train the network
sfr.train()
optimizer = torch.optim.Adam([{"params": sfr.parameters()}], lr=learning_rate)
for epoch_idx in range(num_epochs):
    for batch_idx, batch in enumerate(data_loader):
        x, y = batch
        loss = sfr.loss(x, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

sfr.set_data(data)  # This builds the dual parameters

# Make predictions in function space
X_test = torch.linspace(-0.7, 3.5, 300, dtype=torch.float64).reshape(-1, 1)
f_mean, f_var = sfr.predict_f(X_test)

# Make predictions in output space
y_mean, y_var = sfr.predict(X_test)
```
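The predictions above come back as a mean and a variance, which are usually turned into an uncertainty band before plotting or reporting. The sketch below shows that arithmetic on plain floats using only the standard library. Both helper names are hypothetical. The relation in `add_observation_noise` is the textbook one for a Gaussian likelihood and is an assumption here: the README does not state how `sfr.predict` derives its output-space variance.

```python
import math


def credible_interval(mean: float, var: float, z: float = 1.96) -> tuple[float, float]:
    """Return the central Gaussian interval mean ± z * sqrt(var); z=1.96 covers about 95%."""
    half_width = z * math.sqrt(var)
    return mean - half_width, mean + half_width


def add_observation_noise(f_var: float, sigma_noise: float) -> float:
    """Assumed textbook relation for a Gaussian likelihood: Var[y] = Var[f] + sigma_noise**2."""
    return f_var + sigma_noise**2


# Example with made-up numbers: a function-space variance of 0.5 and sigma_noise=2.
y_var_example = add_observation_noise(0.5, sigma_noise=2.0)
lower, upper = credible_interval(0.0, y_var_example)
```

The same expressions apply elementwise to tensors, for instance `f_mean - 1.96 * f_var.sqrt()` for the lower edge of a function-space band.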
Development
Set up pre-commit by running:
```sh
pre-commit install
```
Now the formatter, linter, etc. will run automatically whenever you commit.
Citation
Please consider citing our conference paper:
```bibtex
@inproceedings{scannellFunction2024,
  title = {Function-space Parameterization of Neural Networks for Sequential Learning},
  booktitle = {Proceedings of The Twelfth International Conference on Learning Representations (ICLR 2024)},
  author = {Aidan Scannell and Riccardo Mereu and Paul Chang and Ella Tami and Joni Pajarinen and Arno Solin},
  year = {2024},
  month = {5},
}
```
Or our workshop paper:
```bibtex
@inproceedings{scannellSparse2023,
  title = {Sparse Function-space Representation of Neural Networks},
  maintitle = {ICML 2023 Workshop on Duality Principles for Modern Machine Learning},
  author = {Aidan Scannell and Riccardo Mereu and Paul Chang and Ella Tami and Joni Pajarinen and Arno Solin},
  year = {2023},
  month = {7},
}
```
Owner
- Name: AaltoML
- Login: AaltoML
- Kind: organization
- Location: Finland
- Website: http://arno.solin.fi
- Repositories: 20
- Profile: https://github.com/AaltoML
Machine learning group at Aalto University led by Prof. Solin
GitHub Events
Total
- Watch event: 2
Last Year
- Watch event: 2
Dependencies
- ax-platform *
- matplotlib ==3.5.1
- numpy ==1.24.2
- torch ==2.0.0
- torchtyping ==0.1.4
- torchvision ==0.15.1
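Because most of these dependencies are pinned to exact versions, it can help to check an existing environment against the pins before running the code. The sketch below uses the standard-library `importlib.metadata`; the `PINNED` mapping copies the versions listed above (the unpinned `ax-platform` is left out), and the helper name `find_mismatches` is hypothetical rather than part of the repository.

```python
from importlib import metadata

# Exact pins copied from the dependency list; ax-platform has no pinned version.
PINNED = {
    "matplotlib": "3.5.1",
    "numpy": "1.24.2",
    "torch": "2.0.0",
    "torchtyping": "0.1.4",
    "torchvision": "0.15.1",
}


def find_mismatches(pins: dict[str, str]) -> dict[str, str]:
    """Map each package whose installed version differs from its pin to a short reason."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            mismatches[name] = f"not installed (expected {wanted})"
            continue
        # Local build tags such as "2.0.0+cu118" still count as the pinned release.
        if installed.split("+")[0] != wanted:
            mismatches[name] = f"installed {installed}, expected {wanted}"
    return mismatches


if __name__ == "__main__":
    for package, reason in find_mismatches(PINNED).items():
        print(f"{package}: {reason}")
```

An empty result means every pinned package is present at the expected version.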