Physics is written in the language of differential equations. Neural networks can now solve them — from the heat equation to the Schrödinger equation — with a technique called Physics-Informed Neural Networks (PINNs). This is the complete guide: theory, implementation, worked examples, and where the field is going next.
AI for Physics Students › Cluster 2: Neural Networks for Differential Equations
- Why Neural Networks for DEs?
- The Mathematical Foundation
- Automatic Differentiation Explained
- Building Your First PINN from Scratch
- Solving the Heat Equation
- Solving the Schrödinger Equation
- Inverse Problems: Learning Physics from Data
- DeepXDE: The PINN Framework
- Neural ODEs
- When PINNs Fail — and What To Do
Section 1 — Why Would You Solve a Differential Equation With a Neural Network?
That's a fair question. We've been solving differential equations for centuries. We have finite difference methods, finite element methods, spectral methods, Runge-Kutta integrators. They work. So why add neural networks to the picture?
The answer depends on your problem. For simple, well-behaved equations on regular grids with known boundary conditions, classic numerical methods are often better — faster, more accurate, better understood. But modern physics is full of problems where classical methods struggle, and this is exactly where PINNs shine.
The Four Cases Where PINNs Win
1. Inverse problems. You have data but you want to learn the parameters of the governing equation. A PINN can simultaneously fit data and identify unknown physical coefficients — things like diffusivity, viscosity, or decay rates — that a classical forward solver alone cannot recover.
2. Complex geometries. Classical PDE solvers require a mesh. Complex, irregular domains — a patient's brain geometry, an aerofoil at an odd angle — need expensive mesh generation. PINNs are meshless: they evaluate the equation at scattered collocation points, no mesh required.
3. High-dimensional problems. Classical methods suffer from the curse of dimensionality: a 10-dimensional PDE on a 100-point-per-dimension grid requires 100¹⁰ = 10²⁰ grid points. Neural networks don't use a grid at all — they scale far better to high-dimensional problems like quantum many-body systems.
4. Partial data plus known physics. When you have both partial data and a governing equation, neither pure ML nor pure numerics is optimal. PINNs combine both: the data constrains the solution where measurements exist, and the physics constrains it everywhere else — especially in regions with no data.
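To make the curse-of-dimensionality argument concrete, here is a quick back-of-the-envelope comparison (a sketch, not a benchmark — the collocation budget is an illustrative figure):

```python
# Grid-based solvers: point count grows exponentially with dimension
points_per_dim = 100
for d in (1, 2, 4, 10):
    print(f"{d}D grid: {points_per_dim**d:.0e} points")

# A meshless PINN instead uses a fixed budget of scattered collocation
# points (say 10^4 to 10^5) regardless of dimension — the network, not
# a grid, carries the resolution.
```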
Section 2 — The Mathematical Foundation of PINNs
Let's build the idea from first principles. Suppose you want to solve a general PDE of the form:

𝒩[u](x, t) = 0 for (x, t) in the domain Ω

where 𝒩 is some differential operator that encodes your physics. We approximate the solution by a neural network u_θ(x, t), where θ are the parameters of the network that we're training to satisfy this equation.
The PINN Loss Function
The PINN loss has three components that must all be minimised simultaneously:

ℒ(θ) = ℒ_data + λ·ℒ_ODE + μ·ℒ_BC
Each term has a clear physical meaning:
- ℒ_data — standard mean squared error between network predictions and observed measurements. This is exactly what you'd minimise in ordinary regression. If you have no data at all (pure forward problem), this term is zero.
- ℒ_ODE — the physics residual. Evaluated at a set of collocation points scattered throughout the domain, this penalises the network for violating the governing equation. If the PDE says ∂u/∂t = α ∂²u/∂x², the residual is (∂u/∂t − α ∂²u/∂x²)² — it should be zero everywhere.
- ℒ_BC — the boundary and initial condition loss. This enforces u(x, 0) = u₀(x) and u(boundary) = known values. Without this, the network might satisfy the PDE but with the wrong solution (there are infinitely many).
The weights λ and μ control the relative importance of physics vs boundary conditions vs data. Tuning these is one of the practical challenges of PINNs — we'll discuss strategies below.
What Are Collocation Points?
This concept is central and worth spending time on. In classical finite difference methods, you discretize the domain onto a regular grid and apply your equation at each grid point. PINNs don't use a grid. Instead, you randomly sample points throughout the domain (called collocation points) and evaluate the PDE residual at each one.
You typically need thousands of collocation points, but you choose them freely — uniformly random, Latin hypercube sampling, or adaptively concentrated in regions where the residual is largest. This flexibility is one of PINNs' major advantages over mesh-based methods.
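The sampling strategies above take only a few lines; `LatinHypercube` here comes from SciPy's quasi-Monte Carlo module (an illustrative choice — nothing about PINNs requires this particular sampler):

```python
import numpy as np
from scipy.stats import qmc

n_points = 1000  # collocation budget

# Uniform random sampling in the unit square (x, t) ∈ [0,1]²
rng = np.random.default_rng(seed=0)
uniform_pts = rng.random((n_points, 2))

# Latin hypercube sampling — same budget, more even space coverage
sampler = qmc.LatinHypercube(d=2, seed=0)
lhs_pts = sampler.random(n=n_points)

print(uniform_pts.shape, lhs_pts.shape)  # (1000, 2) (1000, 2)
```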
Section 3 — Automatic Differentiation: The Engine Under the Hood
PINNs require computing derivatives of a neural network with respect to its inputs — not its parameters. This is different from what backpropagation does (which computes derivatives with respect to weights). The tool that makes this possible is automatic differentiation, and it's built into PyTorch via torch.autograd.
Automatic differentiation is neither symbolic differentiation (like Mathematica) nor numerical finite differences. It's exact, and it works by tracking every operation in a computational graph and applying the chain rule precisely. For a PINN, this means you can compute ∂u/∂x, ∂²u/∂x², ∂u/∂t — any derivative of the network output with respect to any input — exactly and efficiently.
Computing Derivatives in PyTorch
```python
import torch
import torch.nn as nn

# Any differentiable function — could be a neural network
def u_network(x, t):
    return torch.sin(x) * torch.exp(-t)  # toy example

# Create input tensors with requires_grad=True
# This tells PyTorch to track all operations for differentiation
x = torch.linspace(0, 1, 100, requires_grad=True).reshape(-1, 1)
t = torch.linspace(0, 1, 100, requires_grad=True).reshape(-1, 1)

u = u_network(x, t)

# First derivative ∂u/∂x
du_dx = torch.autograd.grad(
    u, x,
    grad_outputs=torch.ones_like(u),
    create_graph=True  # CRITICAL: lets us differentiate again
)[0]

# Second derivative ∂²u/∂x² (differentiate du_dx again)
d2u_dx2 = torch.autograd.grad(
    du_dx, x,
    grad_outputs=torch.ones_like(du_dx),
    create_graph=True
)[0]

# First derivative ∂u/∂t
du_dt = torch.autograd.grad(
    u, t,
    grad_outputs=torch.ones_like(u),
    create_graph=True
)[0]

print(f"u shape: {u.shape}")
print(f"du_dx shape: {du_dx.shape}")
print(f"d2u_dx2 shape: {d2u_dx2.shape}")
print(f"du_dt shape: {du_dt.shape}")
```
The key detail is create_graph=True. This tells PyTorch to build a new computational graph for the derivative itself, which allows you to differentiate that derivative again to get second-order derivatives. Without it, you can't compute ∂²u/∂x² — which most PDEs require.
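A minimal demonstration of why the flag matters — differentiating a toy function twice with plain autograd, no network involved:

```python
import torch

x = torch.linspace(0.0, 1.0, 5, requires_grad=True)
u = (x**3).sum()

# create_graph=True keeps the first derivative itself differentiable
du = torch.autograd.grad(u, x, create_graph=True)[0]   # 3x²
d2u = torch.autograd.grad(du.sum(), x)[0]              # 6x

print(torch.allclose(du, 3 * x**2))  # True
print(torch.allclose(d2u, 6 * x))    # True

# Without it, the derivative is cut off from the graph and a second
# grad call raises a RuntimeError
du_detached = torch.autograd.grad((x**3).sum(), x)[0]
try:
    torch.autograd.grad(du_detached.sum(), x)
except RuntimeError:
    print("cannot differentiate a graph-less derivative")
```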
Section 4 — Building Your First PINN From Scratch
Let's build a complete PINN to solve the simplest physical ODE: exponential decay. du/dt = −ku, with u(0) = 1 and unknown decay rate k. This is the prototype for radioactive decay, RC circuits, Newton's law of cooling — all governed by the same equation.
We'll solve this as an inverse problem: given a handful of noisy measurements, the PINN will simultaneously reconstruct the full solution and recover the unknown parameter k.
```python
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# ── 1. Generate synthetic "experimental" data ─────────────────
k_true = 2.5  # true decay rate (unknown to the PINN)
t_data = torch.tensor([[0.0], [0.5], [1.0], [1.5], [2.0]])
u_data = torch.exp(-k_true * t_data) + 0.02 * torch.randn_like(t_data)

# Collocation points — where we enforce the ODE
t_phys = torch.linspace(0, 2, 1000).reshape(-1, 1)

# ── 2. Define the PINN architecture ──────────────────────────
class PINN(nn.Module):
    def __init__(self):
        super().__init__()
        # Network approximates u(t)
        self.net = nn.Sequential(
            nn.Linear(1, 64), nn.Tanh(),
            nn.Linear(64, 64), nn.Tanh(),
            nn.Linear(64, 64), nn.Tanh(),
            nn.Linear(64, 1)
        )
        # k is a LEARNABLE PARAMETER — this is the inverse problem part
        self.k = nn.Parameter(torch.tensor([1.0]))

    def forward(self, t):
        return self.net(t)

    def ode_residual(self, t):
        t = t.clone().requires_grad_(True)
        u = self.forward(t)
        du_dt = torch.autograd.grad(
            u, t, torch.ones_like(u), create_graph=True
        )[0]
        return du_dt + self.k * u  # du/dt + k*u = 0

# ── 3. Training loop ─────────────────────────────────────────
model = PINN()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
mse = nn.MSELoss()
losses = []

for epoch in range(8000):
    optim.zero_grad()
    # Data loss — match measurements
    loss_data = mse(model(t_data), u_data)
    # Physics loss — enforce ODE at collocation points
    residual = model.ode_residual(t_phys)
    loss_phys = torch.mean(residual**2)
    # IC loss — enforce u(0) = 1
    t0 = torch.tensor([[0.0]])
    loss_ic = (model(t0) - 1.0)**2
    loss = loss_data + 0.1 * loss_phys + 10.0 * loss_ic
    loss.backward()
    optim.step()
    losses.append(loss.item())
    if (epoch + 1) % 2000 == 0:
        print(f"Epoch {epoch+1}: loss={loss.item():.6f}  k={model.k.item():.4f}")

print(f"\nFinal: k = {model.k.item():.4f}  (true k = {k_true})")
```
Section 5 — Solving the Heat Equation: A Full PDE Example
Now let's step up to a proper partial differential equation. The heat equation describes how temperature evolves in time and space:

∂u/∂t = α ∂²u/∂x²
Here α is the thermal diffusivity. We'll solve it on the domain x ∈ [0, 1], t ∈ [0, 1] with boundary conditions u(0,t) = u(1,t) = 0 and initial condition u(x,0) = sin(πx). The exact solution is u(x,t) = sin(πx)·exp(−αt·π²) — so we can verify our PINN precisely.
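Before training, it's worth confirming the claimed exact solution actually satisfies the PDE and its conditions — a quick symbolic sanity check with sympy:

```python
import sympy as sp

x, t, alpha = sp.symbols("x t alpha", positive=True)
u = sp.sin(sp.pi * x) * sp.exp(-alpha * sp.pi**2 * t)

# Heat equation residual: du/dt - alpha * d2u/dx2 — should simplify to 0
residual = sp.diff(u, t) - alpha * sp.diff(u, x, 2)
print(sp.simplify(residual))  # 0

# Boundary conditions u(0,t) = u(1,t) = 0 and initial condition u(x,0)
print(u.subs(x, 0), u.subs(x, 1))  # 0 0
print(u.subs(t, 0))                # sin(pi*x)
```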
```python
import torch
import torch.nn as nn
import numpy as np

alpha = 0.4  # thermal diffusivity

class HeatPINN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 64), nn.Tanh(),    # 2 inputs: x and t
            nn.Linear(64, 128), nn.Tanh(),
            nn.Linear(128, 128), nn.Tanh(),
            nn.Linear(128, 64), nn.Tanh(),
            nn.Linear(64, 1)
        )

    def forward(self, x, t):
        xt = torch.cat([x, t], dim=1)
        return self.net(xt)

    def pde_residual(self, x, t):
        x = x.clone().requires_grad_(True)
        t = t.clone().requires_grad_(True)
        u = self.forward(x, t)
        du_dt = torch.autograd.grad(u, t, torch.ones_like(u), create_graph=True)[0]
        du_dx = torch.autograd.grad(u, x, torch.ones_like(u), create_graph=True)[0]
        d2u_dx2 = torch.autograd.grad(du_dx, x, torch.ones_like(du_dx), create_graph=True)[0]
        return du_dt - alpha * d2u_dx2  # should = 0

# ── Collocation points (random in domain) ───────────────────
N_col = 5000
x_col = torch.rand(N_col, 1)
t_col = torch.rand(N_col, 1)

# ── Boundary conditions: u=0 at x=0 and x=1 ────────────────
N_bc = 200
t_bc = torch.rand(N_bc, 1)
x_bc0 = torch.zeros(N_bc, 1)  # x=0
x_bc1 = torch.ones(N_bc, 1)   # x=1

# ── Initial condition: u(x,0) = sin(pi*x) ───────────────────
N_ic = 200
x_ic = torch.rand(N_ic, 1)
t_ic = torch.zeros(N_ic, 1)
u_ic = torch.sin(np.pi * x_ic)

# ── Train ────────────────────────────────────────────────────
model = HeatPINN()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
mse = nn.MSELoss()

for epoch in range(15000):
    optim.zero_grad()
    loss_pde = torch.mean(model.pde_residual(x_col, t_col)**2)
    loss_bc = mse(model(x_bc0, t_bc), torch.zeros_like(t_bc)) + \
              mse(model(x_bc1, t_bc), torch.zeros_like(t_bc))
    loss_ic = mse(model(x_ic, t_ic), u_ic)
    loss = loss_pde + 10 * loss_bc + 10 * loss_ic
    loss.backward()
    optim.step()

# ── Verify against analytic solution ─────────────────────────
x_test = torch.linspace(0, 1, 100).reshape(-1, 1)
t_test = torch.full_like(x_test, 0.5)  # evaluate at t=0.5
u_pred = model(x_test, t_test).detach().numpy()
u_exact = np.sin(np.pi * x_test.numpy()) * np.exp(-alpha * 0.5 * np.pi**2)
err = np.max(np.abs(u_pred - u_exact))
print(f"Max error at t=0.5: {err:.6f}")
```
Section 6 — Solving the Schrödinger Equation
The time-dependent Schrödinger equation is the crown jewel of quantum mechanics:

iℏ ∂ψ/∂t = −(ℏ²/2m) ∂²ψ/∂x² + V(x)ψ
The wave function ψ is complex-valued, which adds a layer of complexity to PINNs. The standard approach is to split into real and imaginary parts: ψ = u + iv, and have the network output two real channels. The Schrödinger equation then splits into two coupled real PDEs:
- ∂u/∂t = −(ℏ/2m)∂²v/∂x² + (V/ℏ)v
- ∂v/∂t = +(ℏ/2m)∂²u/∂x² − (V/ℏ)u
```python
import torch
import torch.nn as nn

# Schrödinger PINN — particle in a box (V=0 inside, infinite outside)
# Units: hbar=1, m=1 (so hbar/2m = 1/2). Domain: x in [0,1], t in [0,T]
class SchrodingerPINN(nn.Module):
    def __init__(self):
        super().__init__()
        # Two outputs: real (u) and imaginary (v) parts of psi
        self.net = nn.Sequential(
            nn.Linear(2, 100), nn.Tanh(),
            nn.Linear(100, 100), nn.Tanh(),
            nn.Linear(100, 100), nn.Tanh(),
            nn.Linear(100, 2)  # output: [u, v]
        )

    def forward(self, x, t):
        xt = torch.cat([x, t], dim=1)
        return self.net(xt)  # returns [u, v] per point

    def pde_residual(self, x, t):
        x = x.clone().requires_grad_(True)
        t = t.clone().requires_grad_(True)
        psi = self.forward(x, t)
        u, v = psi[:, 0:1], psi[:, 1:2]

        def D(f, var):
            return torch.autograd.grad(f, var, torch.ones_like(f), create_graph=True)[0]

        u_t = D(u, t); v_t = D(v, t)
        u_xx = D(D(u, x), x); v_xx = D(D(v, x), x)
        # Coupled real PDEs from Schrodinger (V=0, hbar=1, m=1)
        res_u = u_t + 0.5 * v_xx  # ∂u/∂t = -(1/2)∂²v/∂x²
        res_v = v_t - 0.5 * u_xx  # ∂v/∂t = +(1/2)∂²u/∂x²
        return res_u, res_v

    def probability_conservation(self, x, t):
        # |psi|^2 should integrate to 1 at every time step
        psi = self.forward(x, t)
        prob = psi[:, 0]**2 + psi[:, 1]**2
        return (prob.mean() - 1.0)**2  # normalization constraint (unit-length domain)
```
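The normalization constraint uses the mean of |ψ|² as a stand-in for the integral, which works because the box has unit length. A natural initial condition for this problem is the ground state of the box, ψ(x, 0) = √2 sin(πx) — an illustrative choice on my part, not part of the class above. A quick numerical check that it is normalised:

```python
import numpy as np

# Ground state of the infinite well on [0, 1]: psi_1(x) = sqrt(2) sin(pi x)
x = np.linspace(0.0, 1.0, 10001)
psi = np.sqrt(2.0) * np.sin(np.pi * x)

# Integral of |psi|^2 over a unit-length domain ≈ mean of |psi|^2
norm = np.mean(psi**2)
print(round(norm, 3))  # 1.0
```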
Section 7 — Navier-Stokes & the DeepXDE Framework
Writing a PINN from scratch in PyTorch is great for understanding, but for production use — especially for complex PDEs like the Navier-Stokes equations — you want a dedicated framework. DeepXDE is the leading open-source library for this.
```python
# pip install deepxde
import deepxde as dde
import numpy as np

# ── Solve 1D diffusion-reaction: du/dt = D*d2u/dx2 - k*u ────
D, k = 1.0, 1.0

def pde(x, y):
    # x[:,0]=position, x[:,1]=time. y=u. dde computes gradients automatically
    dy_t = dde.grad.jacobian(y, x, i=0, j=1)   # du/dt
    dy_xx = dde.grad.hessian(y, x, i=0, j=0)   # d2u/dx2
    return dy_t - D * dy_xx + k * y

# ── Domain: x in [-1,1], t in [0,1] ─────────────────────────
geom = dde.geometry.Interval(-1, 1)
timedomain = dde.geometry.TimeDomain(0, 1)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)

# ── Boundary and initial conditions ─────────────────────────
def boundary(x, on_boundary):
    return on_boundary

def ic_func(x):
    return np.sin(np.pi * x[:, 0:1])

bc = dde.DirichletBC(geomtime, lambda x: 0, boundary)
ic = dde.IC(geomtime, ic_func, lambda x, on_initial: on_initial)

# ── Assemble problem and train ───────────────────────────────
data = dde.data.TimePDE(
    geomtime, pde, [bc, ic],
    num_domain=3000,    # collocation points inside domain
    num_boundary=300,   # boundary points
    num_initial=300     # initial condition points
)
net = dde.nn.FNN([2, 64, 64, 64, 1], "tanh", "Glorot normal")
model = dde.Model(data, net)

model.compile("adam", lr=1e-3)
model.train(iterations=20000)
model.compile("L-BFGS")  # refine with L-BFGS after Adam
model.train()
print("Training complete. Predict with model.predict(X_test)")
```
DeepXDE handles the collocation point sampling, gradient computation, boundary conditions, and loss weighting automatically. It supports backends including TensorFlow, PyTorch, and JAX, and has pre-built implementations for dozens of classic PDEs. The two-phase training strategy — Adam first for fast convergence, then L-BFGS for precision — is best practice for PINN training.
Section 8 — Neural ODEs: A Different Paradigm
PINNs train a network to satisfy a known equation. Neural ODEs, introduced in the landmark Chen et al. (NeurIPS 2018) paper, take a different perspective: what if the dynamics themselves are a neural network?
The key idea is that a residual network (ResNet) update, h_{k+1} = h_k + f(h_k), is exactly one step of Euler's method for numerical integration. As you add more layers with smaller steps, a ResNet converges to a continuous-depth model — an ODE whose right-hand side is parameterised by a neural network: dh/dt = f(h(t), t; θ). Integrating it gives the forward pass:
h(t₁) = h(t₀) + ∫[t₀→t₁] f(h(t), t; θ) dt ← solve with any ODE solver
This is powerful for physics in several ways. First, it provides a natural way to model continuous-time dynamics from irregularly-sampled time series data — something traditional RNNs struggle with. Second, it's memory-efficient: instead of storing all intermediate activations, you can compute gradients using the adjoint method. Third, it's interpretable: the learned f(h, t; θ) is literally the equation of motion of your system.
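The ResNet-Euler correspondence can be made concrete in a few lines — a forward Euler integrator is literally a stack of identical residual blocks. This is a sketch with names of my own choosing; `torchdiffeq` below does the real work with adaptive solvers:

```python
import math

def euler_resnet(f, h0, t0, t1, n_steps):
    """Forward Euler = a ResNet with n_steps residual blocks h ← h + dt·f(h, t)."""
    h, dt = h0, (t1 - t0) / n_steps
    for i in range(n_steps):
        h = h + dt * f(h, t0 + i * dt)  # one 'residual block'
    return h

# Integrate dh/dt = -h from h(0)=1 over [0,1]: exact answer is e^{-1}
approx = euler_resnet(lambda h, t: -h, 1.0, 0.0, 1.0, n_steps=1000)
print(approx, math.exp(-1))  # agree to ~1e-4 with 1000 'layers'
```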
```python
# pip install torchdiffeq
from torchdiffeq import odeint
import torch
import torch.nn as nn

# ── Neural ODE that learns the dynamics of a damped oscillator ─
class ODEFunc(nn.Module):
    """Neural network that approximates the RHS of dh/dt = f(h, t)"""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 64), nn.Tanh(),
            nn.Linear(64, 64), nn.Tanh(),
            nn.Linear(64, 2)  # output: [dx/dt, dv/dt]
        )

    def forward(self, t, h):
        return self.net(h)  # autonomous system (no explicit t)

# ── Generate training data from a known damped oscillator ──────
def true_dynamics(t, h):
    x, v = h[:, 0], h[:, 1]
    return torch.stack([v, -0.3 * v - 4.0 * x], dim=1)  # γ=0.3, ω₀²=4

t_span = torch.linspace(0, 5, 100)
h0 = torch.tensor([[1.0, 0.0]])  # x(0)=1, v(0)=0
traj_true = odeint(true_dynamics, h0, t_span)  # [T, batch, 2]

# ── Train the Neural ODE to match the trajectory ───────────────
odefunc = ODEFunc()
optim = torch.optim.Adam(odefunc.parameters(), lr=1e-3)

for epoch in range(3000):
    optim.zero_grad()
    traj_pred = odeint(odefunc, h0, t_span)  # predict trajectory
    loss = nn.MSELoss()(traj_pred, traj_true)
    loss.backward()
    optim.step()
    if (epoch + 1) % 1000 == 0:
        print(f"Epoch {epoch+1}: loss = {loss.item():.6f}")

# The learned odefunc.net now encodes the damped oscillator dynamics
# Extrapolate beyond training range — physics generalises!
t_extrap = torch.linspace(0, 10, 200)  # 2x training range
traj_ext = odeint(odefunc, h0, t_extrap)
```
Section 9 — When PINNs Fail (and What to Do About It)
PINNs are not magic. They have well-known failure modes, and understanding them is essential before you deploy one on a real problem. Here are the four most common issues:
1. Loss Term Imbalance (The #1 Problem)
The three loss terms (data, physics, boundary conditions) can have very different magnitudes. If ℒ_data ≈ 10⁻² and ℒ_physics ≈ 10⁴, gradient descent will completely ignore the physics loss in early training. This leads to a network that fits the data perfectly but violates the governing equation everywhere else.
Fix: Normalize loss terms. A good heuristic is to weight each term so its contribution to the initial total loss is roughly equal. The NTK-based adaptive weighting scheme (Wang et al., 2022) does this automatically during training.
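The equal-contribution heuristic takes only a few lines. This sketch (the function and its names are my own, not from Wang et al.) rescales each term against the largest one at initialisation:

```python
def initial_loss_weights(initial_losses):
    """Heuristic: scale each loss term so all contribute equally at step 0.

    initial_losses: dict name -> float, loss values from the first forward
    pass. Returns dict name -> weight. This is not the NTK-based adaptive
    scheme — just a reasonable starting point for manual tuning.
    """
    ref = max(initial_losses.values())
    return {name: ref / max(val, 1e-12) for name, val in initial_losses.items()}

# The imbalanced scenario from the text: data ≈ 1e-2, physics ≈ 1e4
weights = initial_loss_weights({"data": 1e-2, "physics": 1e4, "bc": 1.0})
print(weights)  # data is boosted by 1e6, physics gets weight 1.0
```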
2. Spectral Bias (Slow Convergence for High-Frequency Solutions)
Neural networks preferentially learn low-frequency components first (this is called spectral bias or frequency principle). If your PDE solution has sharp gradients or high-frequency oscillations — like in turbulent flow or a rapidly oscillating wavefunction — a standard PINN will converge very slowly or not at all.
Fix: Use Fourier feature embeddings (Tancik et al., 2020) to preprocess your coordinates before feeding them to the network. Encoding x → [sin(2πBx), cos(2πBx)] with random frequencies B gives the network access to high-frequency components from the start.
```python
import numpy as np
import torch
import torch.nn as nn

class FourierFeatureNet(nn.Module):
    def __init__(self, sigma=10.0, n_freq=128):
        super().__init__()
        # Random frequency matrix — fixed, not trained
        self.register_buffer('B', sigma * torch.randn(2, n_freq))  # 2 inputs: x,t
        self.net = nn.Sequential(
            nn.Linear(2 * n_freq, 128), nn.Tanh(),
            nn.Linear(128, 128), nn.Tanh(),
            nn.Linear(128, 1)
        )

    def encode(self, x, t):
        xt = torch.cat([x, t], dim=1)   # [N, 2]
        proj = 2 * np.pi * xt @ self.B  # [N, n_freq]
        return torch.cat([proj.sin(), proj.cos()], dim=1)  # [N, 2*n_freq]

    def forward(self, x, t):
        return self.net(self.encode(x, t))
```
3. Causality Violation in Time-Dependent Problems
Standard PINNs sample collocation points uniformly in time. This means the network tries to satisfy the PDE at t = 0.9 before it has correctly learned the solution at t = 0.1. This violates the causal structure of time-dependent problems and leads to incorrect solutions.
Fix: Use causal training (Wang et al., 2022 — arXiv:2203.07404). Weight collocation points near t = 0 more heavily early in training, and progressively extend the time window as the solution converges. This is analogous to shooting methods in classical numerical analysis.
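A minimal sketch of the causal weighting rule: each time slice tᵢ is down-weighted by the accumulated residual at all earlier slices, wᵢ = exp(−ε Σ_{k<i} ℒₖ). The function and variable names here are my own, and this omits the progressive window-extension part of the full method:

```python
import torch

def causal_weights(residuals_per_slice, eps=1.0):
    """residuals_per_slice: [T] mean PDE residual per time slice, ordered in t.
    Returns w_i = exp(-eps * sum of residuals at earlier slices): later times
    only receive full training signal once earlier times are nearly solved."""
    cum_prev = torch.cumsum(residuals_per_slice, dim=0) - residuals_per_slice
    return torch.exp(-eps * cum_prev)

res = torch.tensor([0.1, 0.5, 2.0, 2.0])
print(causal_weights(res))  # first weight is 1; later weights shrink
```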
4. When to Just Use a Classical Solver
Sometimes the right fix is not a better PINN but no PINN at all: if your problem is low-dimensional, well-posed on a regular grid, and all coefficients are known, a classical solver will usually beat a PINN on both speed and accuracy.
Quick Reference: Which Method for Which Problem?
| Problem Type | Recommended Method | Why |
|---|---|---|
| Simple ODE, known coefficients | scipy.solve_ivp / RK45 | Faster, exact, well-validated |
| ODE, unknown parameters + data | PINN (inverse problem) | Simultaneous fit + parameter discovery |
| PDE on regular grid, smooth solution | FEniCS / Dedalus / FDM | Superior accuracy and speed |
| PDE on irregular domain / sparse data | PINN / DeepXDE | Meshless, handles partial observations |
| High-dimensional PDE (d > 4) | PINN / Deep Galerkin Method | No curse of dimensionality |
| Dynamics from trajectory data | Neural ODE (torchdiffeq) | Learns equations of motion from observations |
| Need fast repeated solves (many parameters) | Neural surrogate model | Orders-of-magnitude speedup per query once trained |
External References & Further Reading
- Raissi, Perdikaris & Karniadakis (2019) — Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear PDEs. Journal of Computational Physics. doi.org/10.1016/j.jcp.2018.10.045 — The original PINN paper. Start here.
- Chen et al. (2018) — Neural Ordinary Differential Equations. NeurIPS. arXiv:1806.07366 — The Neural ODE paper. One of the most cited ML papers of the decade.
- Lu et al. (2021) — DeepXDE: A deep learning library for solving differential equations. SIAM Review. doi.org/10.1137/19M1274067 — The DeepXDE framework paper.
- Wang, Teng & Perdikaris (2022) — Understanding and mitigating gradient pathologies in physics-informed neural networks. SIAM J. Sci. Comput. arXiv:2001.04536 — Essential reading on PINN failure modes and NTK-based fixes.
- Wang et al. (2022) — Respecting causality is all you need for training physics-informed neural networks. arXiv:2203.07404 — Causal training for time-dependent PDEs.
- Tancik et al. (2020) — Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains. NeurIPS. arXiv:2006.10739 — The Fourier feature embedding trick that fixes spectral bias.
- DeepXDE documentation — deepxde.readthedocs.io — Comprehensive tutorials for 40+ PDE types, inverse problems, and complex geometries.
Key Takeaways
- PINNs solve the forward problem (given a PDE, find u(x,t)) and the inverse problem (given data + a PDE structure, find unknown parameters) simultaneously.
- Three loss terms. Data loss + physics residual + boundary conditions. Weighting matters enormously.
- Automatic differentiation via PyTorch autograd computes exact PDE residuals — no finite differences needed.
- DeepXDE abstracts all the boilerplate. Use it for production. Write from scratch to understand internals.
- Neural ODEs parameterise the dynamics themselves as a neural network — powerful for learning equations of motion from data.
- Know the failure modes: loss imbalance, spectral bias, causality violation. All have known fixes.
- Don't use PINNs for simple problems. Use classical solvers when they're sufficient. PINNs earn their complexity on inverse problems, irregular domains, and high-dimensional PDEs.
