pinn_kfac

taylor-mode-pinns

https://github.com/babayara/pinn_kfac

Science Score: 44.0%

This score indicates how likely this project is to be science-related based on various indicators:

  • CITATION.cff file
    Found CITATION.cff file
  • codemeta.json file
    Found codemeta.json file
  • .zenodo.json file
    Found .zenodo.json file
  • DOI references
  • Academic publication links
  • Academic email domains
  • Institutional organization owner
  • JOSS paper metadata
  • Scientific vocabulary similarity
    Low similarity (11.5%) to scientific vocabulary
Last synced: 6 months ago

Repository

taylor-mode-pinns

Basic Info
  • Host: GitHub
  • Owner: BabaYara
  • Language: Jupyter Notebook
  • Default Branch: main
  • Size: 195 KB
Statistics
  • Stars: 0
  • Watchers: 0
  • Forks: 0
  • Open Issues: 0
  • Releases: 0
Created 9 months ago · Last pushed 7 months ago
Metadata Files
  • Readme
  • Citation

README.md

Taylor Mode PINNs

This repository provides a small Python package offering Taylor-mode automatic differentiation utilities and a simple Kronecker-Factored Approximate Curvature (KFAC) optimizer for physics-informed neural networks (PINNs).
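For orientation, Taylor-mode AD propagates a truncated Taylor series through a function in a single forward pass instead of nesting first-order derivatives. A minimal sketch using JAX's experimental `jet` transform (shown for illustration only; the package's own `taylor_mode` interface may differ):

```python
import jax.numpy as jnp
from jax.experimental import jet

x0 = 0.3
# Push the input path t -> x0 + t through sin: the output series carries the
# higher derivatives of sin at x0, computed in one forward pass.
y0, (y1, y2, y3) = jet.jet(jnp.sin, (x0,), ([1.0, 0.0, 0.0],))
print(y0, y1, y2, y3)  # sin(x0), cos(x0), -sin(x0), -cos(x0)
```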

Installation

The package follows a standard Python layout. Install it in editable mode with

```bash
pip install -e .
```

This will make the taylor_mode, kron_utils, networks, and pinns modules available. The pinns module now also exposes a simple train_pinn routine for quick experiments with KFAC training. The pinns.operators submodule includes helpers like poisson_residual for assembling common PDE losses, as well as convenience functions such as divergence.
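The exact signatures of these helpers live in the notebooks; as a rough sketch of the underlying computation (the function and argument names below are assumptions, not the package's API), a Poisson residual evaluates Δu − f at a collocation point:

```python
import jax
import jax.numpy as jnp

def poisson_residual_sketch(u, x, source):
    # Pointwise residual  Δu(x) - f(x)  for a scalar field u at one point x.
    laplacian = jnp.trace(jax.hessian(u)(x))  # Δu = tr(Hessian of u)
    return laplacian - source(x)

# Sanity check: u(x) = |x|^2 has Δu = 2*dim, so against f ≡ 4 the 2-D residual is ~0.
u = lambda x: jnp.sum(x**2)
print(poisson_residual_sketch(u, jnp.array([0.3, -0.7]), lambda x: 4.0))
```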

The operators module now also provides heat_residual for the 1D heat equation, and burgers_residual for Burgers' equation, making it easy to experiment with time-dependent problems.
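The time-dependent residuals follow the same pattern with a time derivative added. A hedged sketch of what such helpers compute (conventions, coefficient names, and signatures here are assumptions, not the repo's actual helpers):

```python
import jax
import jax.numpy as jnp

def heat_residual_sketch(u, t, x, alpha=1.0):
    # u_t - alpha * u_xx for a scalar u(t, x) at a single point.
    u_t  = jax.grad(u, argnums=0)(t, x)
    u_xx = jax.grad(jax.grad(u, argnums=1), argnums=1)(t, x)
    return u_t - alpha * u_xx

def burgers_residual_sketch(u, t, x, nu=0.01):
    # u_t + u * u_x - nu * u_xx for a scalar u(t, x) at a single point.
    u_t  = jax.grad(u, argnums=0)(t, x)
    u_x  = jax.grad(u, argnums=1)(t, x)
    u_xx = jax.grad(jax.grad(u, argnums=1), argnums=1)(t, x)
    return u_t + u(t, x) * u_x - nu * u_xx

# Sanity check: u(t, x) = exp(-t) * sin(x) solves the heat equation with alpha = 1.
u = lambda t, x: jnp.exp(-t) * jnp.sin(x)
print(heat_residual_sketch(u, 0.5, 1.2))  # ~0.0
```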

Example

Several notebooks in the notebooks/ folder demonstrate the library:

  • 02_gradient_operator.ipynb computes gradients with the Taylor-mode utilities.
  • 04_PINN_loss_demo.ipynb builds a simple Poisson PINN using pinn_loss.
  • 08_KFAC_implementation.ipynb walks through a short linear-regression example with the KFACOptimizer; a background sketch of the Kronecker-factored solve follows this list.
  • 10_pinn_with_kfac.ipynb trains a tiny PINN with the KFAC optimizer and the simple MLP utilities from networks.
  • 11_kfac_training.ipynb shows the train_pinn helper in action.
  • 12_poisson_residual_demo.ipynb computes Poisson residuals with the convenience function.
  • 13_heat_residual_demo.ipynb evaluates the heat-equation residual with heat_residual.
  • 14_burgers_residual_demo.ipynb evaluates the Burgers-equation residual with burgers_residual.
  • examples/v6_two_tree_solver.py provides an end-to-end solver in one file.
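As background for the KFAC notebooks: KFAC approximates a layer's curvature matrix as a Kronecker product F ≈ A ⊗ G of input and output-gradient second moments, so the preconditioned gradient can be applied with two small solves instead of one large one. A minimal numerical check of that identity, independent of this repo's KFACOptimizer:

```python
import jax
import jax.numpy as jnp

# For F ≈ A ⊗ G (A: input second moments, G: output-gradient second moments),
# F^{-1} vec(V) = vec(G^{-1} V A^{-1}): two small solves replace one big one.
key = jax.random.PRNGKey(0)
ka, kg, kv = jax.random.split(key, 3)
d_in, d_out = 4, 3
A = (lambda m: m @ m.T + d_in * jnp.eye(d_in))(jax.random.normal(ka, (d_in, d_in)))
G = (lambda m: m @ m.T + d_out * jnp.eye(d_out))(jax.random.normal(kg, (d_out, d_out)))
V = jax.random.normal(kv, (d_out, d_in))  # a weight-shaped gradient

big  = jnp.linalg.solve(jnp.kron(A, G), V.T.reshape(-1))  # naive (d_in*d_out)^2 solve
fast = jnp.linalg.solve(G, V) @ jnp.linalg.inv(A)         # Kronecker-factored solve
print(jnp.allclose(big, fast.T.reshape(-1), atol=1e-4))   # True
```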

End-to-end example (two agents, two trees, EZW preferences)

```python """

examples/v6twotree_solver.py

A single-file replica of the full v6 solver we prototyped in the chat. This script does not import the still-to-be-built deepbsde/ package; instead it lays out every step explicitly, with line-by-line comments so Codex agents can translate each block into its future modular home.

Run: python examples/v6twotreesolver.py --device cpu Estimated wall-time on CPU ≈ 3 min (NFINE=150, batch=256) """

︙ 1. Imports -------------------------------------------------------------

import jax, jax.numpy as jnp import equinox as eqx, optax, functools, argparse, time

︙ 2. Command-line flags --------------------------------------------------

P = argparse.ArgumentParser() P.addargument("--device", default="cpu") P.addargument("--depth", type=int, default=8) P.addargument("--width", type=int, default=128) P.addargument("--steps", type=int, default=5000) ARGS = P.parseargs() jax.config.update("jaxplatformname", ARGS.device)

︙ 3. Economic primitives -------------------------------------------------

γA, ψA, ρA = 7.0, .9, .02 # agent A EZW γB, ψB, ρB =10.0,1.3, .02 # agent B EZW θA = (1-γA)/(1-1/ψA) # homogeneity coeffs θB = (1-γB)/(1-1/ψB)

κu, μu, σ_u = 2.5, 0.0, 0.35 # OU drift in logit of dividend share def logistic(u): return jnp.exp(u)/(1+jnp.exp(u))

T, N = 20., 150 # time-horizon & grid BATCH = 256 eta = 0.0 # log Pareto weight λA/λB ← will stay fixed here

helper to cache algebraic functions

def cacheeta(eta): λA, λB = jnp.exp(.5eta), jnp.exp(-.5eta) λAψ, λBψ = λAψA, λBψB ratio = (λA/λB)**(1/θB) def shareJB(JA): JB = ratio * JA(-θA/θB) num = λAψ * JA(ψAθA) den = num + λBψ * JB(ψBθB) return num/den, JB βlin = ρA + (γA-1)*σu*2/(2ψA) # linearised discount exponent return λA, λB, shareJB, βlin λA, λB, shareJB, βlin = cache_eta(eta)

︙ 4. Residual network (depth & width from CLI) ---------------------------

def resblock(width, key): k1,k2 = jax.random.split(key) return (eqx.nn.Linear(width,width,key=k1), eqx.nn.Linear(width,width,key=k2)) class ResNet(eqx.Module): inproj : eqx.nn.Linear blocks : tuple headY : eqx.nn.Linear headZ : eqx.nn.Linear def call(self,t,u): h = jax.nn.silu(self.inproj(jnp.stack([t,u],-1))) for w1,w2 in self.blocks: h = h + jax.nn.silu(w2(jax.nn.silu(w1(h)))) y = self.headY(h)[:,0] z = self.headZ(h)[:,0] return y,z

key = jax.random.PRNGKey(0) k0,*ks = jax.random.split(key, ARGS.depth+3) blocks = tuple(resblock(ARGS.width,k) for k in ks[:ARGS.depth]) net = ResNet(eqx.nn.Linear(2,ARGS.width,key=k0), blocks, eqx.nn.Linear(ARGS.width,1,key=ks[-2]), eqx.nn.Linear(ARGS.width,1,key=ks[-1]))

︙ 5. Generator f(u,y,z) (social-planner HJB ⇔ BSDE driver) --------------

@functools.partial(jax.jit, staticargnums=0) def makef(eta): λA, , shareJB, _ = cacheeta(eta) @jax.jit def f(u,y,z): x = logistic(u) shareA, = shareJB(y) M = λAshareA(-γA)y**θA # marginal utility index def Mu(uu): s,_ = shareJB(y) # JA treated const wrt u return λAs(-γA)y**θA dM = jax.grad(Mu)(u); d2M = jax.grad(jax.grad(Mu))(u) μu = κu(μ_u-u); LuM = μudM + .5σ_u2d2M r = -LuM/M k = -σu*dM/M σx = σu * x(1-x) σlnC = σx * (1-shareA) return θA/(1-γA)(ρA-γAr-.5(θA-1)(z/σx)2)y \ + θAy(z/σx)*(k-σlnC) return f f = make_f(eta)

︙ 6. Brownian sampler (Sobol + antithetic) ------------------------------

def brownianbatch(N, key): sob = jax.random.sobolsample(N, BATCH//2, dtype=jnp.float32) sob = jnp.clip(sob,1e-6,1-1e-6) g = jax.scipy.stats.norm.ppf(sob) g = jnp.concatenate([g,-g],0) return jnp.sqrt(T/N)*g # scale by √Δt

︙ 7. Loss (Euler + full control-variate) -------------------------------

def bsdeloss(net, key, step): dW = brownianbatch(N, key) dt = T/N def body(carry, inp): u,y,yl = carry; i, dwi = inp t = idtjnp.oneslike(u) yhat, zhat = net(t,u) # control-variate pair ycv = yhat - yl # tamed Euler μu = κu*(μu-u); μu = μu/(1+dtjnp.abs(μu)) u1 = u + μudt + σu*dwi y1 = yhat - f(u,yhat,zhat)dt + z_hatdwi yl1= yl - β_linyldt return (u1,y1,yl1), ycv[-1] # pen = last ycv (uT,yT,ylT), pen = jax.lax.scan( body, (jnp.zeros((BATCH,)), net(0.,0.)[0], jnp.ones((BATCH,))), (jnp.arange(N), dW.T)) return jnp.mean((yT-ylT)2) + .01*jnp.mean(pen2)

optim = optax.adam(learningrate=3e-4) optstate = optim.init(net)

︙ 8. Training loop ------------------------------------------------------

for step in range(ARGS.steps): key, sub = jax.random.split(key) loss, grads = eqx.filtervalueandgrad(bsdeloss)(net, sub, step) updates, optstate = optim.update(grads, optstate, net) net = eqx.apply_updates(net, updates) if step % 500 == 0: print(f"step {step:5d} loss {float(loss):.3e}")

︙ 9. Quick PDE residual check ------------------------------------------

grid = 2.5*jnp.cos(jnp.linspace(0,jnp.pi,400)) def residual(u): y,_ = net(0.,u) z = jax.grad(lambda uu: net(0.,uu)[0])(u) return f(u,y,z) print("max PDE residual", float(jnp.max(jnp.abs(jax.vmap(residual)(grid))))) ```
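Run it as described in the module docstring, e.g. `python examples/v6_two_tree_solver.py --device cpu --steps 2000`. As a design note, the auxiliary process yl decays at the linearised rate β_lin and appears intended as a control variate: the loss penalises both the terminal mismatch (yT − ylT)² and the 0.01-weighted running gap, a common device for stabilising deep-BSDE training.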

Owner

  • Name: Baba-yara
  • Login: BabaYara
  • Kind: user
  • Location: Portugal
  • Company: Nova School of Business and Economics

I am a Ph.D. candidate at NOVA SBE who combines machine learning with econometrics in the study of asset pricing.

Citation (CITATION.cff)

cff-version: 1.2.0
message: "If you use this software, please cite it as below."
title: "Taylor Mode PINNs"
version: "0.1.0"
authors:
  - family-names: Baba-Yara
    given-names: Fahiz
    orcid: "0000-0000-0000-0000"
year: 2024
license: MIT
url: "https://github.com/example/TaylorModePINNs"

GitHub Events

Total
  • Push event: 35
  • Pull request event: 30
  • Create event: 16
Last Year
  • Push event: 35
  • Pull request event: 30
  • Create event: 16

Dependencies

pyproject.toml (pypi)
  • jax *
  • jaxlib *
  • numpy *