https://github.com/ceyron/trainax
Training methodologies for autoregressive neural operators in JAX.
Science Score: 10.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ○ CITATION.cff file
- ○ codemeta.json file
- ○ .zenodo.json file
- ○ DOI references
- ✓ Academic publication links (links to: arxiv.org)
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low similarity (11.2%) to scientific vocabulary
Keywords
autoregressive-models
deep-learning
differentiable-physics
optax
optimization
timeseries-forecasting
unrolled-training
Last synced: 6 months ago
Repository
Basic Info
- Host: GitHub
- Owner: Ceyron
- License: MIT
- Language: Python
- Default Branch: main
- Homepage: https://fkoehler.site/trainax/
- Size: 6.43 MB
Statistics
- Stars: 2
- Watchers: 3
- Forks: 0
- Open Issues: 1
- Releases: 0
Created about 2 years ago · Last pushed over 1 year ago
https://github.com/Ceyron/trainax/blob/main/
Learning Methodologies for Autoregressive Neural Emulators.
Convenience abstractions using `optax` to train neural networks to autoregressively emulate time-dependent problems. They take care of trajectory subsampling and offer a wide range of training methodologies (varying the unrolling length and optionally including differentiable physics).

## Installation

```bash
pip install trainax
```

Requires Python 3.10+ and JAX 0.4.13+. See the [JAX install guide](https://jax.readthedocs.io/en/latest/installation.html).

## Documentation

The documentation is available at [fkoehler.site/trainax](https://fkoehler.site/trainax/).

## Quickstart

Train a linear convolution with kernel size 2 (no bias) to become an emulator for the 1D advection problem.

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax  # pip install optax
import trainax as tx

CFL = -0.75

# Reference trajectories of the periodic 1D advection problem
ref_data = tx.sample_data.advection_1d_periodic(
    cfl=CFL,
    key=jax.random.PRNGKey(0),
)

# Emulator: one input channel, one output channel, kernel size 2
linear_conv_kernel_2 = eqx.nn.Conv1d(
    1, 1, 2,
    padding="SAME", padding_mode="CIRCULAR", use_bias=False,
    key=jax.random.PRNGKey(73),
)

# Three supervised trainers that differ only in the number of unrolled steps
sup_1_trainer, sup_5_trainer, sup_20_trainer = (
    tx.trainer.SupervisedTrainer(
        ref_data,
        num_rollout_steps=r,
        optimizer=optax.adam(1e-2),
        num_training_steps=1000,
        batch_size=32,
    )
    for r in (1, 5, 20)
)

# Each trainer starts from the same initial network and the same key
sup_1_conv, sup_1_loss_history = sup_1_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)
sup_5_conv, sup_5_loss_history = sup_5_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)
sup_20_conv, sup_20_loss_history = sup_20_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)

# First-order upwind (FOU) stencil as the reference for the learned weights
FOU_STENCIL = jnp.array([1 + CFL, -CFL])

print(jnp.linalg.norm(sup_1_conv.weight - FOU_STENCIL))   # 0.033
print(jnp.linalg.norm(sup_5_conv.weight - FOU_STENCIL))   # 0.025
print(jnp.linalg.norm(sup_20_conv.weight - FOU_STENCIL))  # 0.017
```

Increasing the number of supervised unrolling steps during training brings the learned stencil closer to the numerical FOU stencil.
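To make the comparison target concrete, here is a minimal sketch of the textbook first-order upwind (FOU) update for a negative CFL number, written in plain JAX without Trainax. The helper names `fou_stencil` and `fou_step` are hypothetical and only illustrate the stencil `[1 + CFL, -CFL]` from the quickstart. The sketch assumes the two weights act on a cell and its right neighbour, and it is not necessarily how the reference data is generated inside the package.

```python
import jax.numpy as jnp

CFL = -0.75  # same value as in the quickstart; negative means transport to the left


def fou_stencil(cfl):
    """Weights applied to (u_i, u_{i+1}) by first-order upwind for cfl <= 0."""
    return jnp.array([1 + cfl, -cfl])


def fou_step(u, cfl):
    """One periodic first-order upwind step on a 1D state of shape (N,).

    u_i_new = (1 + cfl) * u_i - cfl * u_{i+1}
    """
    weights = fou_stencil(cfl)
    u_right = jnp.roll(u, -1)  # u_{i+1} with periodic wrap-around
    return weights[0] * u + weights[1] * u_right


num_points = 64
x = jnp.linspace(0.0, 1.0, num_points, endpoint=False)
u0 = jnp.sin(2 * jnp.pi * x)

u1 = fou_step(u0, CFL)

print(fou_stencil(CFL))  # weights 0.25 and 0.75
```

The two weights sum to one, so the update preserves the mean of the state. For `cfl = -1.0` the step reduces to an exact shift by one cell.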
## Background

After discretizing space and time, simulating a time-dependent partial differential equation amounts to repeatedly applying a simulation operator $\mathcal{P}_h$. Here, we are interested in emulating this physical/numerical operator with a neural network $f_\theta$. This repository provides an abstract implementation of the ways a learning problem can be framed to inject "knowledge" from $\mathcal{P}_h$ into $f_\theta$.

Assume we have a distribution of initial conditions $\mathcal{Q}$ from which we sample $S$ initial states, $u^{[0]} \sim \mathcal{Q}$. We can store them in an array of shape $(S, C, *N)$ (with $C$ channels and an arbitrary number of spatial axes of dimension $N$) and repeatedly apply $\mathcal{P}_h$ for $T$ steps to obtain the training trajectories of shape $(S, T+1, C, *N)$.

For a one-step supervised learning task, we substack the training trajectories into windows of size $2$ (giving $T$ overlapping windows per trajectory) and merge the two leftover batch axes to get a data array of shape $(S \cdot T, 2, C, *N)$ that can be used in a supervised learning scenario

$$
L(\theta) = \mathbb{E}_{(u^{[0]}, u^{[1]}) \sim \mathcal{Q}} \left[ l\left( f_\theta(u^{[0]}), u^{[1]} \right) \right]
$$

where $l$ is a **time-level loss**. In the simplest case, $l = \text{MSE}$.

`Trainax` supports much more than one-step supervised learning. For example, it can train with unrolled steps, include the reference simulator $\mathcal{P}_h$ in training, train on residuum conditions instead of resolved reference states, and cut or modify the gradient flow.
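The windowing and the time-level loss described above can be sketched in plain JAX as follows. This is an illustration under stated assumptions and not Trainax's API: `stack_windows` and `rollout_loss` are hypothetical helper names, the "simulator" is a simple periodic shift, and the emulator is chosen to match it exactly so that the loss is zero.

```python
import jax
import jax.numpy as jnp


def stack_windows(trajectories, window_size):
    """Slice (S, T+1, C, *N) trajectories into overlapping windows and merge
    the sample axis and the window axis into a single batch axis."""
    num_samples, num_snapshots = trajectories.shape[:2]
    num_windows = num_snapshots - window_size + 1
    # Shape (S, num_windows, window_size, C, *N)
    windows = jnp.stack(
        [trajectories[:, i : i + window_size] for i in range(num_windows)],
        axis=1,
    )
    new_shape = (num_samples * num_windows, window_size) + trajectories.shape[2:]
    return windows.reshape(new_shape)


def rollout_loss(emulator, window):
    """Average the time-level MSE over an autoregressive rollout that starts
    from the first state of the window."""
    num_steps = window.shape[0] - 1
    u = window[0]
    loss = 0.0
    for t in range(1, num_steps + 1):
        u = emulator(u)  # feed the prediction back in
        loss = loss + jnp.mean((u - window[t]) ** 2)
    return loss / num_steps


# Toy setup: S = 3 samples, T = 4 steps, C = 1 channel, N = 16 points
S, T, C, N = 3, 4, 1, 16
u0 = jax.random.normal(jax.random.PRNGKey(0), (S, C, N))


def simulator(u):
    # Stand-in for the reference operator: shift by one cell
    return jnp.roll(u, -1, axis=-1)


trajectories = jnp.stack(
    [jnp.roll(u0, -t, axis=-1) for t in range(T + 1)], axis=1
)  # shape (S, T+1, C, N)

one_step_data = stack_windows(trajectories, 2)  # shape (S * T, 2, C, N)
two_step_data = stack_windows(trajectories, 3)  # shape (S * (T - 1), 3, C, N)

# Batch loss: vmap the per-window loss over the merged batch axis
one_step_loss = jnp.mean(
    jax.vmap(lambda w: rollout_loss(simulator, w))(one_step_data)
)
two_step_loss = jnp.mean(
    jax.vmap(lambda w: rollout_loss(simulator, w))(two_step_data)
)
```

A window size of 2 corresponds to the one-step loss $L(\theta)$ above. Larger windows give an unrolled supervised loss, at the cost of fewer windows per trajectory.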
## Features

* Wide collection of unrolled training methodologies:
    * Supervised
    * Diverted Chain
    * Mix Chain
    * Residuum
* Based on [JAX](https://github.com/google/jax):
    * One of the best Automatic Differentiation engines (forward & reverse)
    * Automatic vectorization
    * Backend-agnostic code (run on CPU, GPU, and TPU)
* Built on top of and compatible with [Equinox](https://github.com/patrick-kidger/equinox)
* Batch-Parallel Training
* Collection of Callbacks
* Composability

## Citation

This package was developed as part of the [APEBench paper (arxiv.org/abs/2411.00180)](https://arxiv.org/abs/2411.00180) (accepted at NeurIPS 2024). If you find it useful for your research, please consider citing it:

```bibtex
@article{koehler2024apebench,
  title={{APEBench}: A Benchmark for Autoregressive Neural Emulators of {PDE}s},
  author={Felix Koehler and Simon Niedermayr and R{\"u}diger Westermann and Nils Thuerey},
  journal={Advances in Neural Information Processing Systems (NeurIPS)},
  volume={38},
  year={2024}
}
```

(Feel free to also give the project a star on GitHub if you like it.)

[Here](https://github.com/tum-pbs/apebench) you can find the APEBench benchmark suite.

## Funding

The main author (Felix Koehler) is a PhD student in the group of [Prof. Thuerey at TUM](https://ge.in.tum.de/) and his research is funded by the [Munich Center for Machine Learning](https://mcml.ai/).

## License

MIT, see [here](https://github.com/Ceyron/trainax/blob/main/LICENSE.txt)

---

> [fkoehler.site](https://fkoehler.site/) ·
> GitHub [@ceyron](https://github.com/ceyron) ·
> X [@felix_m_koehler](https://twitter.com/felix_m_koehler) ·
> LinkedIn [Felix Köhler](www.linkedin.com/in/felix-koehler)
Owner
- Name: Felix Köhler
- Login: Ceyron
- Kind: user
- Location: Munich
- Website: www.linkedin.com/in/felix-koehler
- Twitter: felix_m_koehler
- Repositories: 6
- Profile: https://github.com/Ceyron
🤖 Machine Learning & 🌊 Simulation. I love open science and open education.
GitHub Events
Total
- Create event: 3
- Issues event: 1
- Release event: 2
- Watch event: 7
- Delete event: 1
- Public event: 1
- Push event: 5
- Pull request event: 2
Last Year
- Create event: 3
- Issues event: 1
- Release event: 2
- Watch event: 7
- Delete event: 1
- Public event: 1
- Push event: 5
- Pull request event: 2