https://github.com/aaltoml/sfr-experiments
Code accompanying ICLR 2024 paper "Function-space Parameterization of Neural Networks for Sequential Learning"
Science Score: 23.0%
This score indicates how likely this project is to be science-related, based on the following indicators:
- ○ CITATION.cff file
- ✓ codemeta.json file (codemeta.json file found)
- ○ .zenodo.json file
- ○ DOI references
- ✓ Academic publication links (links to: arxiv.org)
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity (low similarity, 10.0%, to scientific vocabulary)
Basic Info
- Host: GitHub
- Owner: AaltoML
- License: MIT
- Language: Jupyter Notebook
- Default Branch: main
- Homepage: https://aaltoml.github.io/sfr/
- Size: 2.14 MB
Statistics
- Stars: 0
- Watchers: 4
- Forks: 0
- Open Issues: 0
- Releases: 0
Metadata Files
README.md
Function-space Parameterization of Neural Networks for Sequential Learning
Code accompanying the ICLR 2024 paper "Function-space Parameterization of Neural Networks for Sequential Learning". This repository contains the code for reproducing the experiments in the paper. Please see this repo for a clean and minimal implementation of Sparse Function-space Representation of Neural Networks (SFR); we recommend starting from that minimal implementation.
Function-space Parameterization of Neural Networks for Sequential Learning
Aidan Scannell*, Riccardo Mereu*, Paul Chang, Ella Tamir, Joni Pajarinen, Arno Solin
International Conference on Learning Representations (ICLR 2024)

Sparse Function-space Representation of Neural Networks
Aidan Scannell*, Riccardo Mereu*, Paul Chang, Ella Tamir, Joni Pajarinen, Arno Solin
ICML 2023 Workshop on Duality Principles for Modern Machine Learning
Install
Install using virtual environment
Make a virtual environment:
```sh
python -m venv .venv
```
Activate it with:
```sh
source .venv/bin/activate
```
Install the dependencies with:
```sh
python -m pip install --upgrade pip
pip install laplace-torch==0.1a2
pip install -e ".[experiments]"
```
We install laplace-torch separately due to version conflicts with backpack-for-pytorch.
Note that laplace-torch is only used for running the baselines.
Install using pip
Alternatively, manually install the dependencies with:
```sh
pip install laplace-torch==0.1a2
pip install -r requirements.txt
```
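After either install route, a quick import check can confirm the main dependencies resolved. This is a minimal sketch, assuming the repo is installed in editable mode under the package name `src` (as in the example below) and that laplace-torch exposes the `laplace` module:
```python
# Sanity check: verify that the core dependencies import cleanly.
# Assumes editable install (package name `src`) and laplace-torch
# (import name `laplace`, only needed for the baselines).
import torch
import laplace
import src

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
```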
Reproducing experiments
See experiments for details on how to reproduce the results in the paper. This includes code for generating the tables and figures.
Usage
See the notebooks/README.md for how to use our code for both regression and classification.
Example
Here's a short example:
```python
import src
import torch

torch.set_default_dtype(torch.float64)


def func(x, noise=True):
    return torch.sin(x * 5) / x + torch.cos(x * 10)


# Toy data set
X_train = torch.rand((100, 1)) * 2
Y_train = func(X_train, noise=True)
data = (X_train, Y_train)

# Training config
width = 64
num_epochs = 1000
batch_size = 16
learning_rate = 1e-3
delta = 0.00005  # prior precision
data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(*data), batch_size=batch_size
)

# Create a neural network
network = torch.nn.Sequential(
    torch.nn.Linear(1, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, 1),
)

# Instantiate SFR (handles NN training/prediction as they're coupled via the prior/likelihood)
sfr = src.sfr.SFR(
    network=network,
    prior=src.priors.Gaussian(params=network.parameters, delta=delta),
    likelihood=src.likelihoods.Gaussian(sigma_noise=2),
    output_dim=1,
    num_inducing=32,
    dual_batch_size=None,  # this reduces the memory required for computing dual parameters
    jitter=1e-4,
)

sfr.train()
optimizer = torch.optim.Adam([{"params": sfr.parameters()}], lr=learning_rate)
for epoch_idx in range(num_epochs):
    for batch_idx, batch in enumerate(data_loader):
        x, y = batch
        loss = sfr.loss(x, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

sfr.set_data(data)  # This builds the dual parameters

# Make predictions in function space
X_test = torch.linspace(-0.7, 3.5, 300, dtype=torch.float64).reshape(-1, 1)
f_mean, f_var = sfr.predict_f(X_test)

# Make predictions in output space
y_mean, y_var = sfr.predict(X_test)
```
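Since the predictions return marginal means and variances, approximate 95% credible intervals follow directly for the Gaussian likelihood used here. A minimal sketch, assuming the `f_mean`/`f_var` and `y_mean`/`y_var` tensors from the example above:
```python
# Sketch: 95% credible intervals from the predictive mean/variance
# (assumes the tensors produced by the example above).
f_lower = f_mean - 1.96 * f_var.sqrt()  # function-space lower bound
f_upper = f_mean + 1.96 * f_var.sqrt()  # function-space upper bound

y_lower = y_mean - 1.96 * y_var.sqrt()  # output-space lower bound
y_upper = y_mean + 1.96 * y_var.sqrt()  # output-space upper bound
```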
Citation
Please consider citing our ICLR 2024 paper.
```bibtex
@inproceedings{scannellFunction2024,
  title = {Function-space Parameterization of Neural Networks for Sequential Learning},
  booktitle = {Proceedings of the Twelfth International Conference on Learning Representations (ICLR 2024)},
  author = {Aidan Scannell and Riccardo Mereu and Paul Chang and Ella Tamir and Joni Pajarinen and Arno Solin},
  year = {2024},
  month = {5},
}
```
Owner
- Name: AaltoML
- Login: AaltoML
- Kind: organization
- Location: Finland
- Website: http://arno.solin.fi
- Repositories: 20
- Profile: https://github.com/AaltoML
Machine learning group at Aalto University led by Prof. Solin
Dependencies
- chex ==0.0.7
- dm-haiku ==0.0.4
- dm-sonnet ==2.0.0
- dm-tree ==0.1.6
- gpflow ==2.2.1
- gpustat ==0.6.0
- jax ==0.2.11
- jaxlib ==0.1.64
- joypy *
- keras ==2.6.0
- latex ==0.7.0
- neural-tangents ==0.3.6
- optax ==0.0.2
- pandas ==1.3.5
- pyreadr *
- pytest *
- retry ==0.9.2
- scikit-learn ==0.21.3
- scipy ==1.6.2
- seqtools ==1.1.0
- tensorflow-addons ==0.13.0
- tensorflow-datasets ==4.2.0
- tensorflow-probability ==0.13.0
- tfp-nightly ==0.12.0.dev20201117
- uncertainty-metrics ==0.0.81
- uncertainty_metrics *
- Babel ==2.14.0
- GitPython ==3.1.42
- Jinja2 ==3.1.3
- Markdown ==3.5.2
- MarkupSafe ==2.1.5
- PyOpenGL ==3.1.7
- PyYAML ==6.0.1
- Pygments ==2.17.2
- QtPy ==2.4.1
- Send2Trash ==1.8.2
- Werkzeug ==3.0.1
- absl-py ==2.1.0
- antlr4-python3-runtime ==4.9.3
- anyio ==4.3.0
- appdirs ==1.4.4
- argon2-cffi ==23.1.0
- argon2-cffi-bindings ==21.2.0
- arrow ==1.3.0
- asdfghjkl ==0.1a2
- asttokens ==2.4.1
- async-lru ==2.0.4
- attrs ==23.2.0
- ax-platform ==0.3.3
- backpack-for-pytorch ==1.6.0
- beautifulsoup4 ==4.12.3
- bleach ==6.1.0
- botorch ==0.8.5
- certifi ==2024.2.2
- cffi ==1.16.0
- charset-normalizer ==3.3.2
- click ==8.1.7
- cloudpickle ==3.0.0
- cmake ==3.28.3
- comm ==0.2.1
- cycler ==0.12.1
- debugpy ==1.8.1
- decorator ==4.4.2
- defusedxml ==0.7.1
- dm-control ==1.0.11
- dm-env ==1.6
- dm-tree ==0.1.8
- docker-pycreds ==0.4.0
- einops ==0.7.0
- exceptiongroup ==1.2.0
- executing ==2.0.1
- fastjsonschema ==2.19.1
- filelock ==3.13.1
- fonttools ==4.49.0
- fqdn ==1.5.1
- fsspec ==2024.2.0
- gitdb ==4.0.11
- glfw ==2.6.5
- gpytorch ==1.10
- grpcio ==1.60.1
- gym ==0.26.2
- gym-notices ==0.0.8
- h11 ==0.14.0
- httpcore ==1.0.3
- httpx ==0.26.0
- hydra-core ==1.3.2
- hydra-submitit-launcher ==1.2.0
- idna ==3.6
- imageio ==2.34.0
- imageio-ffmpeg ==0.4.9
- importlib-metadata ==7.0.1
- ipykernel ==6.29.2
- ipython ==8.18.1
- ipywidgets ==8.1.2
- isoduration ==20.11.0
- jedi ==0.19.1
- joblib ==1.3.2
- json5 ==0.9.17
- jsonpointer ==2.4
- jsonschema ==4.21.1
- jsonschema-specifications ==2023.12.1
- jupyter ==1.0.0
- jupyter-console ==6.6.3
- jupyter-events ==0.9.0
- jupyter-lsp ==2.2.2
- jupyter_client ==8.6.0
- jupyter_core ==5.7.1
- jupyter_server ==2.12.5
- jupyter_server_terminals ==0.5.2
- jupyterlab ==4.1.2
- jupyterlab_pygments ==0.3.0
- jupyterlab_server ==2.25.3
- jupyterlab_widgets ==3.0.10
- kiwisolver ==1.4.5
- labmaze ==1.0.6
- laplace-torch ==0.1a2
- linear-operator ==0.4.0
- lit ==17.0.6
- lxml ==5.1.0
- matplotlib ==3.5.1
- matplotlib-inline ==0.1.6
- mistune ==3.0.2
- moviepy ==1.0.3
- mpmath ==1.3.0
- mujoco ==2.3.3
- multipledispatch ==1.0.0
- nbclient ==0.9.0
- nbconvert ==7.16.1
- nbformat ==5.9.2
- nest-asyncio ==1.6.0
- netcal ==1.3.5
- networkx ==3.2.1
- notebook ==7.1.0
- notebook_shim ==0.2.4
- numpy ==1.24.2
- nvidia-cublas-cu11 ==11.10.3.66
- nvidia-cublas-cu12 ==12.1.3.1
- nvidia-cuda-cupti-cu11 ==11.7.101
- nvidia-cuda-cupti-cu12 ==12.1.105
- nvidia-cuda-nvrtc-cu11 ==11.7.99
- nvidia-cuda-nvrtc-cu12 ==12.1.105
- nvidia-cuda-runtime-cu11 ==11.7.99
- nvidia-cuda-runtime-cu12 ==12.1.105
- nvidia-cudnn-cu11 ==8.5.0.96
- nvidia-cudnn-cu12 ==8.9.2.26
- nvidia-cufft-cu11 ==10.9.0.58
- nvidia-cufft-cu12 ==11.0.2.54
- nvidia-curand-cu11 ==10.2.10.91
- nvidia-curand-cu12 ==10.3.2.106
- nvidia-cusolver-cu11 ==11.4.0.1
- nvidia-cusolver-cu12 ==11.4.5.107
- nvidia-cusparse-cu11 ==11.7.4.91
- nvidia-cusparse-cu12 ==12.1.0.106
- nvidia-nccl-cu11 ==2.14.3
- nvidia-nccl-cu12 ==2.19.3
- nvidia-nvjitlink-cu12 ==12.3.101
- nvidia-nvtx-cu11 ==11.7.91
- nvidia-nvtx-cu12 ==12.1.105
- omegaconf ==2.3.0
- opencv-python ==4.7.0.72
- opt-einsum ==3.3.0
- overrides ==7.7.0
- packaging ==23.2
- pandas ==2.2.0
- pandocfilters ==1.5.1
- parso ==0.8.3
- pexpect ==4.9.0
- pillow ==10.2.0
- platformdirs ==4.2.0
- plotly ==5.19.0
- proglog ==0.1.10
- prometheus_client ==0.20.0
- prompt-toolkit ==3.0.43
- protobuf ==4.25.3
- psutil ==5.9.8
- ptyprocess ==0.7.0
- pure-eval ==0.2.2
- pycparser ==2.21
- pygame ==2.1.0
- pyparsing ==3.1.1
- pyro-api ==0.1.2
- pyro-ppl ==1.9.0
- python-dateutil ==2.8.2
- python-json-logger ==2.0.7
- pytz ==2024.1
- pyzmq ==25.1.2
- qtconsole ==5.5.1
- referencing ==0.33.0
- requests ==2.31.0
- rfc3339-validator ==0.1.4
- rfc3986-validator ==0.1.1
- rpds-py ==0.18.0
- scikit-learn ==1.4.1.post1
- scipy ==1.12.0
- seaborn ==0.13.2
- sentry-sdk ==1.40.5
- setproctitle ==1.3.3
- six ==1.16.0
- smmap ==5.0.1
- sniffio ==1.3.0
- soupsieve ==2.5
- stack-data ==0.6.3
- submitit ==1.5.1
- sympy ==1.12
- tenacity ==8.2.3
- tensorboard ==2.16.2
- tensorboard-data-server ==0.7.2
- terminado ==0.18.0
- threadpoolctl ==3.3.0
- tikzplotlib ==0.9.8
- tinycss2 ==1.2.1
- tomli ==2.0.1
- torch ==2.0.0
- torchaudio ==2.0.1
- torchtyping ==0.1.4
- torchvision ==0.15.1
- tornado ==6.4
- tqdm ==4.66.2
- traitlets ==5.14.1
- triton ==2.0.0
- typeguard ==2.13.3
- types-python-dateutil ==2.8.19.20240106
- typing_extensions ==4.9.0
- tzdata ==2024.1
- unfoldNd ==0.2.1
- uri-template ==1.3.0
- urllib3 ==2.2.1
- wandb ==0.16.3
- wcwidth ==0.2.13
- webcolors ==1.13
- webencodings ==0.5.1
- websocket-client ==1.7.0
- widgetsnbextension ==4.0.10
- zipp ==3.17.0
- ax-platform ==0.3.3
- laplace-torch *
- matplotlib ==3.5.1
- numpy ==1.24.2
- torch ==2.0.0
- torchtyping ==0.1.4
- torchvision ==0.15.1