torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.

https://github.com/rtqichen/torchdiffeq

Science Score: 64.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
    Links to: arxiv.org
  • Committers with academic emails
    1 of 22 committers (4.5%) from academic institutions
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (13.9%) to scientific vocabulary
Last synced: 6 months ago

Repository

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.

Basic Info
  • Host: GitHub
  • Owner: rtqichen
  • License: mit
  • Language: Python
  • Default Branch: master
  • Homepage:
  • Size: 8.16 MB
Statistics
  • Stars: 5,972
  • Watchers: 126
  • Forks: 963
  • Open Issues: 84
  • Releases: 0
Created over 7 years ago · Last pushed 11 months ago
Metadata Files
Readme License Citation

README.md

PyTorch Implementation of Differentiable ODE Solvers

This library provides ordinary differential equation (ODE) solvers implemented in PyTorch. Backpropagation through ODE solutions is supported using the adjoint method for constant memory cost. For usage of ODE solvers in deep learning applications, see reference [1].

As the solvers are implemented in PyTorch, algorithms in this repository are fully supported to run on the GPU.

Installation

To install the latest stable version: pip install torchdiffeq

To install the latest version from GitHub: pip install git+https://github.com/rtqichen/torchdiffeq

Examples

Examples are placed in the examples directory.

We encourage those who are interested in using this library to take a look at examples/ode_demo.py to understand how to use torchdiffeq to fit a simple spiral ODE.

ODE Demo

Basic usage

This library provides one main interface, odeint, which contains general-purpose algorithms for solving initial value problems (IVPs), with gradients implemented for all main arguments. An initial value problem consists of an ODE and an initial value: dy/dt = f(t, y), y(t_0) = y_0. The goal of an ODE solver is to find a continuous trajectory that satisfies the ODE and passes through the initial condition.

To solve an IVP using the default solver:

```python
from torchdiffeq import odeint

odeint(func, y0, t)
```

where `func` is any callable implementing the ordinary differential equation f(t, y), `y0` is an any-D Tensor representing the initial values, and `t` is a 1-D Tensor containing the evaluation points. The initial time is taken to be `t[0]`.

Backpropagation through odeint goes through the internals of the solver. Note that this is not numerically stable for all solvers (but should probably be fine with the default dopri5 method). Instead, we encourage the use of the adjoint method explained in [1], which will allow solving with as many steps as necessary due to O(1) memory usage.

To use the adjoint method:

```python
from torchdiffeq import odeint_adjoint as odeint

odeint(func, y0, t)
```

`odeint_adjoint` simply wraps around `odeint`, but will use only O(1) memory in exchange for solving an adjoint ODE in the backward call.

The biggest gotcha is that func must be a nn.Module when using the adjoint method. This is used to collect parameters of the differential equation.

Differentiable event handling

We allow terminating an ODE solution based on an event function. Backpropagation through most solvers is supported. For usage of event handling in deep learning applications, see reference [2].

This can be invoked with `odeint_event`:

```python
from torchdiffeq import odeint_event
odeint_event(func, y0, t0, *, event_fn, reverse_time=False, odeint_interface=odeint, **kwargs)
```

  • `func` and `y0` are the same as `odeint`.
  • `t0` is a scalar representing the initial time value.
  • `event_fn(t, y)` returns a tensor, and is a required keyword argument.
  • `reverse_time` is a boolean specifying whether we should solve in reverse time. Default is `False`.
  • `odeint_interface` is one of `odeint` or `odeint_adjoint`, specifying whether adjoint mode should be used for differentiating through the ODE solution. Default is `odeint`.
  • `**kwargs`: any remaining keyword arguments are passed to `odeint_interface`.

The solve is terminated at an event time t and state y when an element of event_fn(t, y) is equal to zero. Multiple outputs from event_fn can be used to specify multiple event functions, of which the first to trigger will terminate the solve.

Both the event time and final state are returned from odeint_event, and can be differentiated. Gradients will be backpropagated through the event function. NOTE: parameters for the event function must be in the state itself to obtain gradients.

The numerical precision for the event time is determined by the atol argument.

See example of simulating and differentiating through a bouncing ball in examples/bouncing_ball.py. See example code for learning a simple event function in examples/learn_physics.py.

Bouncing Ball

Keyword arguments for odeint(_adjoint)

Keyword arguments:

  • rtol Relative tolerance.
  • atol Absolute tolerance.
  • method One of the solvers listed below.
  • options A dictionary of solver-specific options, see the further documentation.

List of ODE Solvers:

Adaptive-step:
  • dopri8 Runge-Kutta of order 8 of Dormand-Prince-Shampine.
  • dopri5 Runge-Kutta of order 5 of Dormand-Prince-Shampine [default].
  • bosh3 Runge-Kutta of order 3 of Bogacki-Shampine.
  • fehlberg2 Runge-Kutta-Fehlberg of order 2.
  • adaptive_heun Runge-Kutta of order 2.

Fixed-step:
  • euler Euler method.
  • midpoint Midpoint method.
  • rk4 Fourth-order Runge-Kutta with 3/8 rule.
  • explicit_adams Explicit Adams-Bashforth.
  • implicit_adams Implicit Adams-Bashforth-Moulton.

Additionally, all solvers available through SciPy are wrapped for use with scipy_solver.

For most problems, good choices are the default dopri5, or rk4 with options=dict(step_size=...) set appropriately small. Adjusting the tolerances (adaptive solvers) or the step size (fixed solvers) allows trading off speed against accuracy.

Frequently Asked Questions

Take a look at our FAQ for frequently asked questions.

Further documentation

For details of the adjoint-specific and solver-specific options, check out the further documentation.

References

Applications of differentiable ODE solvers and event handling are discussed in these two papers:

Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." Advances in Neural Information Processing Systems. 2018. [arxiv]

```bibtex
@article{chen2018neuralode,
  title={Neural Ordinary Differential Equations},
  author={Chen, Ricky T. Q. and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, David},
  journal={Advances in Neural Information Processing Systems},
  year={2018}
}
```

Ricky T. Q. Chen, Brandon Amos, Maximilian Nickel. "Learning Neural Event Functions for Ordinary Differential Equations." International Conference on Learning Representations. 2021. [arxiv]

```bibtex
@article{chen2021eventfn,
  title={Learning Neural Event Functions for Ordinary Differential Equations},
  author={Chen, Ricky T. Q. and Amos, Brandon and Nickel, Maximilian},
  journal={International Conference on Learning Representations},
  year={2021}
}
```

The seminorm option for computing adjoints is discussed in

Patrick Kidger, Ricky T. Q. Chen, Terry Lyons. "'Hey, that's not an ODE': Faster ODE Adjoints via Seminorms." International Conference on Machine Learning. 2021. [arxiv]

```bibtex
@article{kidger2021hey,
  title={"Hey, that's not an ODE": Faster ODE Adjoints via Seminorms},
  author={Kidger, Patrick and Chen, Ricky T. Q. and Lyons, Terry J.},
  journal={International Conference on Machine Learning},
  year={2021}
}
```


If you found this library useful in your research, please consider citing:

```bibtex
@misc{torchdiffeq,
  author={Chen, Ricky T. Q.},
  title={torchdiffeq},
  year={2018},
  url={https://github.com/rtqichen/torchdiffeq},
}
```

Owner

  • Name: Ricky Chen
  • Login: rtqichen
  • Kind: user
  • Company: FAIR Labs, Meta AI

Citation (CITATION.cff)

# YAML 1.2
---
abstract: |
    "This library provides ordinary differential equation (ODE) solvers implemented in PyTorch. Backpropagation through ODE solutions is supported using the adjoint method for constant memory cost. We also allow terminating an ODE solution based on an event function, with exact gradient computed.
    
    As the solvers are implemented in PyTorch, algorithms in this repository are fully supported to run on the GPU."
authors: 
  -
    family-names: Chen
    given-names: "Ricky T. Q."
cff-version: "1.1.0"
date-released: 2021-06-02
license: MIT
message: "PyTorch Implementation of Differentiable ODE Solvers"
repository-code: "https://github.com/rtqichen/torchdiffeq"
title: torchdiffeq
version: "0.2.2"
...

GitHub Events

Total
  • Issues event: 9
  • Watch event: 523
  • Issue comment event: 16
  • Push event: 4
  • Pull request review event: 6
  • Pull request review comment event: 4
  • Pull request event: 10
  • Fork event: 62
Last Year
  • Issues event: 9
  • Watch event: 523
  • Issue comment event: 16
  • Push event: 4
  • Pull request review event: 6
  • Pull request review comment event: 4
  • Pull request event: 10
  • Fork event: 62

Committers

Last synced: 9 months ago

All Time
  • Total Commits: 202
  • Total Committers: 22
  • Avg Commits per committer: 9.182
  • Development Distribution Score (DDS): 0.436
Past Year
  • Commits: 7
  • Committers: 2
  • Avg Commits per committer: 3.5
  • Development Distribution Score (DDS): 0.286
Top Committers
Name Email Commits
Ricky Tian Qi Chen r****n@g****m 114
Patrick Kidger 3****r 39
Rafael Valle j****e@g****m 9
Sam Lishak s****m@l****m 6
psv4 4****4 5
timudk t****n@g****m 4
James Morrill j****6@g****m 4
Komal Gupta k****a@a****u 3
lbq 4****7@q****m 2
Rajat Vadiraj Dwaraknath r****d@g****m 2
JamesAllingham j****m@g****m 2
Brandon Amos b****s@g****m 2
Emerson Castaneda e****s 1
Chris Finlay c****y@g****m 1
Brett Koonce k****e@h****m 1
Adam Golinski a****3@g****m 1
MaricelaM 3****M 1
Samet Demir d****t@h****m 1
Shiva s****l@g****m 1
StefOe s****e@g****m 1
Xuechen 1****n 1
haowggit 5****t 1

Issues and Pull Requests

Last synced: 6 months ago

All Time
  • Total issues: 121
  • Total pull requests: 23
  • Average time to close issues: 4 months
  • Average time to close pull requests: 3 months
  • Total issue authors: 115
  • Total pull request authors: 16
  • Average comments per issue: 1.79
  • Average comments per pull request: 1.26
  • Merged pull requests: 8
  • Bot issues: 0
  • Bot pull requests: 0
Past Year
  • Issues: 11
  • Pull requests: 11
  • Average time to close issues: 7 months
  • Average time to close pull requests: 6 days
  • Issue authors: 11
  • Pull request authors: 6
  • Average comments per issue: 0.36
  • Average comments per pull request: 0.64
  • Merged pull requests: 3
  • Bot issues: 0
  • Bot pull requests: 0
Top Authors
Issue Authors
  • raabuchanan (2)
  • agjignesh (2)
  • azournas (2)
  • MaricelaM (2)
  • EyalRozenberg1 (2)
  • chrisyxue (2)
  • zhuqunxi (1)
  • hellomynameisjiji (1)
  • roflmaostc (1)
  • Yunyi-learner (1)
  • gitGksgk (1)
  • tomzhu0225 (1)
  • xiaotai-yang (1)
  • rid-sun (1)
  • bfs18 (1)
Pull Request Authors
  • psv4 (5)
  • pollycoder (2)
  • slishak (2)
  • patrick-kidger (2)
  • rayanirban (2)
  • mbaddar1 (1)
  • pianpwk (1)
  • westny (1)
  • shivak (1)
  • jambo6 (1)
  • MaricelaM (1)
  • Zymrael (1)
  • varunagrawal (1)
  • halduaij (1)
  • ChrisDeGrendele (1)

Dependencies

setup.py pypi
  • scipy >=1.4.0
  • torch >=1.5.0