Science Score: 44.0%
This score indicates how likely this project is to be science-related based on various indicators:
- ✓ CITATION.cff file: found
- ✓ codemeta.json file: found
- ✓ .zenodo.json file: found
- ○ DOI references
- ○ Academic publication links
- ○ Academic email domains
- ○ Institutional organization owner
- ○ JOSS paper metadata
- ○ Scientific vocabulary similarity: low (11.5%) similarity to scientific vocabulary
Repository
taylor-mode-pinns
Basic Info
- Host: GitHub
- Owner: BabaYara
- Language: Jupyter Notebook
- Default Branch: main
- Size: 195 KB
Statistics
- Stars: 0
- Watchers: 0
- Forks: 0
- Open Issues: 0
- Releases: 0
Metadata Files
README.md
Taylor Mode PINNs
This repository provides a small Python package offering Taylor-mode automatic differentiation utilities and a simple Kronecker-Factored Approximate Curvature (KFAC) optimizer for physics-informed neural networks (PINNs).
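The README does not show the package's Taylor-mode API. As a hedged illustration of the underlying idea only, the sketch below uses JAX's own experimental jet transformation (jax.experimental.jet, not this package's taylor_mode module) to obtain a Laplacian from second-order Taylor coefficients. Propagating the path x + t·e_i through the function yields the Hessian diagonal entry H_ii in one forward pass per direction. The test function f and the helper laplacian_taylor are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental.jet import jet


def f(x):
    # Hypothetical scalar field R^d -> R, used only for illustration.
    return jnp.sum(jnp.sin(x) * x * x)


def laplacian_taylor(fn, x):
    """Laplacian as a sum of second-order directional derivatives.

    Seeding jet with first-order term e_i and second-order term 0 describes
    the path x + t*e_i; the returned second-order term is
    d^2/dt^2 fn(x + t*e_i) at t = 0, i.e. the Hessian diagonal entry H_ii.
    """
    total = 0.0
    zeros = jnp.zeros_like(x)
    for e_i in jnp.eye(x.shape[0], dtype=x.dtype):
        _, (_, second) = jet(fn, (x,), ((e_i, zeros),))
        total = total + second
    return total


x = jnp.array([0.3, -1.2, 0.7])
lap_jet = laplacian_taylor(f, x)
lap_hess = jnp.trace(jax.hessian(f)(x))  # dense-Hessian cross-check
print(float(lap_jet), float(lap_hess))
```

Summing H_ii over the coordinate directions avoids materialising the full d × d Hessian, which is one motivation for Taylor mode in PINN losses; the dense jax.hessian trace appears here only as a reference value.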
Installation
The package follows a standard Python layout. Install it in editable mode with:

```bash
pip install -e .
```
This makes the taylor_mode, kron_utils, networks, and pinns modules
available. The pinns module also exposes a simple train_pinn routine
for quick experiments with KFAC training. The pinns.operators submodule
includes helpers for assembling common PDE losses: poisson_residual for the
Poisson equation, heat_residual for the 1D heat equation, and
burgers_residual for Burgers' equation, as well as convenience functions
such as divergence. The heat and Burgers residuals make it easy to
experiment with time-dependent problems.
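The exact signatures of poisson_residual, heat_residual, burgers_residual and divergence are not documented here. The sketch below is an independent JAX version, not the package's implementation. It assumes the solution is a plain scalar function u(x) or u(t, x) and the sign conventions -Δu = f, u_t = ν u_xx and u_t + u u_x = ν u_xx; every *_sketch name is hypothetical. Each residual is evaluated on a known exact solution, so it should vanish up to float32 round-off.

```python
import jax
import jax.numpy as jnp


def divergence_sketch(v, x):
    # div v(x) = trace of the Jacobian of a vector field v: R^d -> R^d.
    return jnp.trace(jax.jacfwd(v)(x))


def poisson_residual_sketch(u, f, x):
    # Residual of -Δu = f at a point x in R^d, using Δu = div(grad u).
    return -divergence_sketch(jax.grad(u), x) - f(x)


def heat_residual_sketch(u, t, x, nu):
    # Residual of u_t = nu * u_xx for a scalar u(t, x) in 1D.
    u_t = jax.grad(u, argnums=0)(t, x)
    u_xx = jax.grad(jax.grad(u, argnums=1), argnums=1)(t, x)
    return u_t - nu * u_xx


def burgers_residual_sketch(u, t, x, nu):
    # Residual of the viscous Burgers equation u_t + u * u_x = nu * u_xx.
    u_t = jax.grad(u, argnums=0)(t, x)
    u_x = jax.grad(u, argnums=1)(t, x)
    u_xx = jax.grad(jax.grad(u, argnums=1), argnums=1)(t, x)
    return u_t + u(t, x) * u_x - nu * u_xx


# Exact solutions, so each residual should be ~0.
pi = jnp.pi
nu = 0.1
u_poisson = lambda x: jnp.sin(pi * x[0]) * jnp.sin(pi * x[1])
f_poisson = lambda x: 2 * pi**2 * jnp.sin(pi * x[0]) * jnp.sin(pi * x[1])
u_heat = lambda t, x: jnp.exp(-nu * pi**2 * t) * jnp.sin(pi * x)
u_burgers = lambda t, x: x / (1.0 + t)

r_p = poisson_residual_sketch(u_poisson, f_poisson, jnp.array([0.3, 0.6]))
r_h = heat_residual_sketch(u_heat, 0.5, 0.25, nu)
r_b = burgers_residual_sketch(u_burgers, 0.5, 0.25, nu)
print(float(r_p), float(r_h), float(r_b))
```

In a PINN these pointwise residuals would be vmapped over collocation points and squared; the closed-form solutions above simply make the sign conventions easy to check.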
Example
Several notebooks in the notebooks/ folder demonstrate the library:
- 02_gradient_operator.ipynb demonstrates computing gradients using Taylor-mode utilities.
- 04_PINN_loss_demo.ipynb shows how to build a simple Poisson PINN using pinn_loss.
- 08_KFAC_implementation.ipynb shows a short linear-regression example using the KFACOptimizer.
- 10_pinn_with_kfac.ipynb demonstrates training a tiny PINN using the KFAC optimizer and the simple MLP utilities from networks.
- 11_kfac_training.ipynb shows the train_pinn helper in action.
- 12_poisson_residual_demo.ipynb demonstrates computing Poisson residuals with the poisson_residual convenience function.
- 13_heat_residual_demo.ipynb shows how to evaluate the heat equation residual using heat_residual.
- 14_burgers_residual_demo.ipynb illustrates evaluating the Burgers equation residual with burgers_residual.

In addition, examples/v6_two_tree_solver.py provides an end-to-end solver in a single file.
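The KFACOptimizer interface itself is not documented in this README. As a hedged sketch of the idea behind the 08_KFAC_implementation.ipynb example, the code below applies one Kronecker-factored preconditioned step to a single linear layer W (shape out × in) under a mean squared loss. It assumes the curvature is approximated as A ⊗ G, with input factor A = E[a aᵀ] and output factor G equal to the Hessian of the loss with respect to the layer output, which is the identity for 0.5‖Wa - y‖². The name kfac_step_sketch and the damping value are illustrative choices. Because this factorisation is exact for one linear layer with squared loss, a nearly undamped step with learning rate 1 lands on the least-squares solution.

```python
import jax
import jax.numpy as jnp


def kfac_step_sketch(W, X, Y, lr=1.0, damping=1e-6):
    """One K-FAC-style step for a linear layer y_hat = W @ x.

    Loss: 0.5 * mean_i ||W x_i - y_i||^2, with W of shape (out, in),
    inputs X of shape (n, in) and targets Y of shape (n, out).
    """
    n, d_in = X.shape
    d_out = W.shape[0]
    E = X @ W.T - Y                                # residuals, (n, out)
    grad_W = E.T @ X / n                           # dL/dW, (out, in)
    A = X.T @ X / n + damping * jnp.eye(d_in)      # input factor E[a a^T] + damping
    G = jnp.eye(d_out)                             # output factor: Hessian of 0.5||e||^2
    # Curvature ~ A kron G; for symmetric A (column-major vec),
    # (A kron G)^{-1} vec(grad) = vec(G^{-1} grad A^{-1}), so only small solves are needed.
    step = jnp.linalg.solve(G, grad_W)
    step = jnp.linalg.solve(A, step.T).T
    return W - lr * step


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (200, 3))
W_true = jnp.array([[1.0, -2.0, 0.5], [0.0, 3.0, -1.0]])
Y = X @ W_true.T                                   # noise-free targets
W0 = jax.random.normal(key_w, (2, 3))
W1 = kfac_step_sketch(W0, X, Y)
print(float(jnp.max(jnp.abs(W1 - W_true))))        # close to 0
```

With several layers or a non-quadratic loss the Kronecker factorisation is only an approximation, and G is typically estimated from back-propagated output gradients rather than fixed to the identity.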
13 End-to-end example (two agents, two trees, EZW preferences)
```python
"""
examples/v6_two_tree_solver.py

A single-file replica of the full v6 solver we prototyped in the chat.
This script does not import the still-to-be-built deepbsde/ package;
instead it lays out every step explicitly, with line-by-line comments so
Codex agents can translate each block into its future modular home.

Run:
    python examples/v6_two_tree_solver.py --device cpu
Estimated wall-time on CPU ≈ 3 min (N=150, BATCH=256)
"""

# 1. Imports -------------------------------------------------------------
import argparse

import equinox as eqx
import jax
import jax.numpy as jnp
import optax

# 2. Command-line flags --------------------------------------------------
P = argparse.ArgumentParser()
P.add_argument("--device", default="cpu")
P.add_argument("--depth", type=int, default=8)
P.add_argument("--width", type=int, default=128)
P.add_argument("--steps", type=int, default=5000)
ARGS = P.parse_args()
jax.config.update("jax_platform_name", ARGS.device)

# 3. Economic primitives -------------------------------------------------
γA, ψA, ρA = 7.0, 0.9, 0.02      # agent A EZW
γB, ψB, ρB = 10.0, 1.3, 0.02     # agent B EZW
θA = (1 - γA) / (1 - 1 / ψA)     # homogeneity coeffs
θB = (1 - γB) / (1 - 1 / ψB)

κ_u, μ_u, σ_u = 2.5, 0.0, 0.35   # OU drift in logit of dividend share

def logistic(u):
    return jax.nn.sigmoid(u)     # = exp(u) / (1 + exp(u)), overflow-safe

T, N = 20.0, 150                 # time horizon & grid
BATCH = 256
eta = 0.0                        # log Pareto weight log(λA/λB); stays fixed here

# helper to cache algebraic functions
def cache_eta(eta):
    λA, λB = jnp.exp(0.5 * eta), jnp.exp(-0.5 * eta)
    λAψ, λBψ = λA**ψA, λB**ψB
    ratio = (λA / λB) ** (1 / θB)

    def share_JB(JA):
        JB = ratio * JA ** (-θA / θB)
        num = λAψ * JA ** (ψA * θA)
        den = num + λBψ * JB ** (ψB * θB)
        return num / den, JB

    β_lin = ρA + (γA - 1) * σ_u**2 / (2 * ψA)   # linearised discount exponent
    return λA, λB, share_JB, β_lin

λA, λB, share_JB, β_lin = cache_eta(eta)

# 4. Residual network (depth & width from CLI) ---------------------------
def res_block(width, key):
    k1, k2 = jax.random.split(key)
    return (eqx.nn.Linear(width, width, key=k1),
            eqx.nn.Linear(width, width, key=k2))

class ResNet(eqx.Module):
    in_proj: eqx.nn.Linear
    blocks: tuple
    head_Y: eqx.nn.Linear
    head_Z: eqx.nn.Linear

    def __call__(self, t, u):
        # Equinox layers act on one unbatched sample: scalars (t, u) -> (y, z).
        # Use jax.vmap(net) to evaluate a whole batch.
        h = jax.nn.silu(self.in_proj(jnp.stack([t, u])))
        for w1, w2 in self.blocks:
            h = h + jax.nn.silu(w2(jax.nn.silu(w1(h))))
        return self.head_Y(h)[0], self.head_Z(h)[0]

key, init_key = jax.random.split(jax.random.PRNGKey(0))
k0, *ks = jax.random.split(init_key, ARGS.depth + 3)
blocks = tuple(res_block(ARGS.width, k) for k in ks[:ARGS.depth])
net = ResNet(eqx.nn.Linear(2, ARGS.width, key=k0), blocks,
             eqx.nn.Linear(ARGS.width, 1, key=ks[-2]),
             eqx.nn.Linear(ARGS.width, 1, key=ks[-1]))

# 5. Generator f(u,y,z) (social-planner HJB ⇔ BSDE driver) --------------
def make_f(eta):
    # Returns a Python closure, so make_f itself is not jitted.
    λA, _, share_JB, _ = cache_eta(eta)

    @jax.jit
    def f(u, y, z):
        # Per-sample driver: u, y, z are scalars; vmap it for a batch.
        x = logistic(u)
        shareA, _ = share_JB(y)
        M = λA * shareA ** (-γA) * y**θA          # marginal utility index

        def Mu(uu):
            s, _ = share_JB(y)                    # JA treated const wrt u
            return λA * s ** (-γA) * y**θA

        # As written, Mu ignores uu, so dM and d2M evaluate to 0.
        dM = jax.grad(Mu)(u)
        d2M = jax.grad(jax.grad(Mu))(u)
        drift_u = κ_u * (μ_u - u)
        LuM = drift_u * dM + 0.5 * σ_u**2 * d2M
        r = -LuM / M
        k = -σ_u * dM / M
        σx = σ_u * x * (1 - x)
        σlnC = σx * (1 - shareA)
        return (θA / (1 - γA) * (ρA - γA * r - 0.5 * (θA - 1) * (z / σx) ** 2) * y
                + θA * y * (z / σx) * (k - σlnC))

    return f

f = make_f(eta)

# 6. Brownian sampler (Gaussian + antithetic) ----------------------------
def brownian_batch(N, key):
    # jax.random has no Sobol generator, so plain N(0, 1) draws are used;
    # the second half of the batch mirrors the first (antithetic pairs).
    g = jax.random.normal(key, (BATCH // 2, N))
    g = jnp.concatenate([g, -g], 0)              # shape (BATCH, N)
    return jnp.sqrt(T / N) * g                   # scale by √Δt

# 7. Loss (Euler + full control-variate) -------------------------------
def bsde_loss(net, key):
    dW = brownian_batch(N, key)
    dt = T / N
    f_batch = jax.vmap(f)

    def body(carry, inp):
        u, y, yl = carry
        i, dw_i = inp
        t = i * dt * jnp.ones_like(u)
        y_hat, z_hat = jax.vmap(net)(t, u)
        ycv = y_hat - yl                         # control-variate pair
        drift_u = κ_u * (μ_u - u)
        drift_u = drift_u / (1 + dt * jnp.abs(drift_u))   # tamed Euler
        u1 = u + drift_u * dt + σ_u * dw_i
        y1 = y_hat - f_batch(u, y_hat, z_hat) * dt + z_hat * dw_i
        yl1 = yl - β_lin * yl * dt
        return (u1, y1, yl1), ycv[-1]            # pen = last ycv

    y0 = jnp.full((BATCH,), net(0.0, 0.0)[0])    # carry must keep shape (BATCH,)
    init = (jnp.zeros((BATCH,)), y0, jnp.ones((BATCH,)))
    (uT, yT, ylT), pen = jax.lax.scan(body, init, (jnp.arange(N), dW.T))
    return jnp.mean((yT - ylT) ** 2) + 0.01 * jnp.mean(pen**2)

optim = optax.adam(learning_rate=3e-4)
opt_state = optim.init(eqx.filter(net, eqx.is_array))

# 8. Training loop ------------------------------------------------------
@eqx.filter_jit
def train_step(net, opt_state, key):
    loss, grads = eqx.filter_value_and_grad(bsde_loss)(net, key)
    updates, opt_state = optim.update(grads, opt_state, eqx.filter(net, eqx.is_array))
    return eqx.apply_updates(net, updates), opt_state, loss

for step in range(ARGS.steps):
    key, sub = jax.random.split(key)
    net, opt_state, loss = train_step(net, opt_state, sub)
    if step % 500 == 0:
        print(f"step {step:5d} loss {float(loss):.3e}")

# 9. Quick PDE residual check ------------------------------------------
# Evaluates the driver f at t = 0 with z = ∂y/∂u on a cosine-spaced grid.
grid = 2.5 * jnp.cos(jnp.linspace(0, jnp.pi, 400))

def residual(u):
    y, _ = net(0.0, u)
    z = jax.grad(lambda uu: net(0.0, uu)[0])(u)
    return f(u, y, z)

print("max PDE residual", float(jnp.max(jnp.abs(jax.vmap(residual)(grid)))))
```
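The sampler in step 6 pairs every Gaussian path with its mirror image. A small standalone check of that construction is sketched below, assuming plain pseudo-random jax.random.normal draws rather than a quasi-random sequence; the helper name antithetic_increments is hypothetical. Each antithetic pair sums to zero, so the batch mean of every increment is zero up to round-off, while the variance stays at Δt = T/N.

```python
import jax
import jax.numpy as jnp


def antithetic_increments(key, batch, n_steps, horizon):
    """Brownian increments of shape (batch, n_steps) built from antithetic pairs.

    The first half of the batch is i.i.d. N(0, 1); the second half is its
    negation. Everything is scaled by sqrt(dt) with dt = horizon / n_steps.
    """
    assert batch % 2 == 0, "batch must be even for antithetic pairing"
    g = jax.random.normal(key, (batch // 2, n_steps))
    g = jnp.concatenate([g, -g], axis=0)
    return jnp.sqrt(horizon / n_steps) * g


dW = antithetic_increments(jax.random.PRNGKey(1), batch=256, n_steps=150, horizon=20.0)
print(dW.shape)                                 # (256, 150)
print(float(jnp.abs(dW.mean(axis=0)).max()))    # zero up to round-off
print(float(dW.var()), 20.0 / 150)              # sample variance vs dt
```

Antithetic pairing removes the odd-order sampling error of the driving noise at no extra cost in random draws, which is why it pairs naturally with the control-variate loss in step 7.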
Owner
- Name: Baba-yara
- Login: BabaYara
- Kind: user
- Location: Portugal
- Company: Nova School of Business and Economics
- Website: www.babayara.com
- Twitter: baba_yara
- Repositories: 103
- Profile: https://github.com/BabaYara
I am a Ph.D. candidate at NOVA SBE who combines machine learning with econometrics in the study of asset pricing.
Citation (CITATION.cff)
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
title: "Taylor Mode PINNs"
version: "0.1.0"
authors:
- family-names: Baba-Yara
given-names: Fahiz
orcid: "0000-0000-0000-0000"
year: 2024
license: MIT
url: "https://github.com/example/TaylorModePINNs"
GitHub Events
Total
- Push event: 35
- Pull request event: 30
- Create event: 16
Last Year
- Push event: 35
- Pull request event: 30
- Create event: 16
Dependencies
- jax *
- jaxlib *
- numpy *