AI for Condensed Matter & Materials Physics: Complete Guide

Materials science sits at the epicenter of the AI-in-science revolution. In 2023, DeepMind's GNoME discovered 2.2 million new crystal structures — more than all previously known materials combined. Graph Neural Networks predict properties of crystals never synthesised. ML force fields simulate molecular dynamics 1,000× faster than quantum chemistry. This cluster is your complete guide to how it all works — and how to build it yourself.

🔬 Crystal Property Prediction ⚡ ML Force Fields 🌀 Phase Transitions 🧪 Generative Design


Continuing from Cluster 4 where we applied ML to the largest scales in nature — galaxies, gravitational waves, cosmological simulations — we now shift to the smallest: individual atoms, crystal lattices, and quantum many-body systems. The tools overlap significantly (GNNs, generative models, Bayesian inference), but the physics is entirely different.

📋 What You Will Learn in This Cluster
  1. Why Materials Science Needs ML Now
  2. Representing Crystals as Graphs
  3. CGCNN: Predicting Formation Energy
  4. ML Force Fields & Fast MD Simulations
  5. Phase Transition Detection with ML
  6. Generative Models for New Materials
  7. The Materials Project API
  8. End-to-End Discovery Pipeline

Section 1 — Why Condensed Matter Physics Needs Machine Learning Right Now

There is a number that every materials scientist knows: 10²³. It is the order of magnitude of Avogadro's number (6.022 × 10²³, the number of atoms in a mole), and a single gram of a typical solid already contains around 10²² atoms. Every macroscopic property you care about (electrical conductivity, magnetic susceptibility, thermal expansion, mechanical strength) emerges from quantum mechanical interactions between those atoms. Computing those properties from first principles is possible in theory. In practice, density functional theory (DFT), the workhorse of computational materials science, scales as O(N³) with the number of atoms N, so doubling the unit cell makes a calculation 8× more expensive. Realistic interfaces, defects, and disordered systems need large supercells, which puts them beyond the reach of routine DFT.
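To make the O(N³) scaling concrete, here is a back-of-the-envelope sketch. It assumes an idealised pure N³ cost law and ignores prefactors, parallel efficiency, and linear-scaling DFT variants, so the ratios are indicative only.

```python
def relative_dft_cost(n_atoms, n_ref=100, exponent=3):
    """Cost of a DFT calculation on n_atoms relative to one on n_ref atoms,
    assuming an idealised O(N^exponent) scaling law."""
    return (n_atoms / n_ref) ** exponent

for n in (100, 200, 1_000, 10_000):
    print(f"{n:>6} atoms: {relative_dft_cost(n):>12,.0f}x the cost of 100 atoms")
```

Under this idealisation a 10,000-atom interface model costs a million times more than a 100-atom cell, which is why such systems are out of reach for routine DFT.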

Machine learning changes this by learning a fast surrogate for the expensive quantum calculation. Once trained on DFT-computed properties of thousands of materials, a GNN can predict the properties of a new material in milliseconds, with accuracy approaching that of the DFT calculations it was trained on. The model does not replace the underlying physics; it learns the mapping from structure to property from data and uses it to explore materials space at a scale DFT alone cannot reach.

The scale of the opportunity is staggering. The number of possible inorganic crystal structures with up to four elements is estimated at ~10²⁰. DFT can compute perhaps 10⁴ to 10⁵ of these per year with current compute budgets. Machine learning-guided exploration has already expanded the number of known stable crystals by roughly an order of magnitude in a single project (GNoME, Merchant et al. 2023). We are, quite literally, mapping a continent that was previously invisible.

2.2M: new crystal structures discovered by DeepMind's GNoME, more than all previously known materials combined
1,000×: speedup from ML force fields over DFT for molecular dynamics simulations
150k+: DFT-computed structures in the Materials Project, free, open, and ready for ML training
🧠 Concept: The Materials Genome Initiative
In 2011, the US government launched the Materials Genome Initiative with the goal of halving the time to discover, develop, and deploy new materials. The philosophy: treat materials discovery as a big-data problem. Build open databases (Materials Project, AFLOW, OQMD), develop computational high-throughput methods, and apply machine learning to find patterns across millions of DFT calculations. This is the scientific infrastructure that makes ML in materials science possible today.

Section 2 — Representing Crystal Structures as Graphs

Before applying machine learning to crystal property prediction, you need to decide how to represent a crystal in a way that captures its physics. This is the most important design decision in the entire pipeline. Get it wrong and your model will miss the structure that determines the properties you care about.

A crystal is a periodic arrangement of atoms. Its unit cell contains atoms at specific positions, and the crystal is the infinite repetition of this cell in three dimensions. To fully characterise a crystal you need: (1) the types of atoms (chemical species), (2) their positions within the unit cell, (3) the lattice vectors defining periodicity, and (4) the topology of bonding — which atoms interact with which.

Why a Graph is the Right Representation

A crystal maps naturally onto a graph where atoms are nodes and bonds (or proximity connections) are edges. This has three crucial properties aligned with physics:

1. Permutation invariance: Crystal properties don't depend on the arbitrary ordering of atoms in your input list. GNNs with invariant aggregation (sum or mean over neighbours) are automatically permutation-invariant — no special handling required.
2. Variable system size: Different crystals have different numbers of atoms per unit cell — from 1 (pure metals) to hundreds (complex oxides). Graphs handle variable-size systems natively. Fixed-size descriptors like Coulomb matrices require padding and lose structural information.
3. Physical locality: Most material properties are governed by local chemical environment — the atom type, bond lengths, and coordination of each atom and its near neighbours. Graph message-passing naturally encodes this locality, aggregating information across k-hop neighbourhoods in k rounds of message passing.
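The permutation-invariance property can be checked numerically. This NumPy sketch uses a toy random graph and a single sum-aggregation layer (the features, weights, and tanh update are illustrative assumptions, not CGCNN itself). Relabelling the atoms permutes the per-atom outputs but leaves the pooled crystal-level vector unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
n_atoms, n_feat = 6, 4
h = rng.normal(size=(n_atoms, n_feat))          # toy per-atom features
adj = rng.random((n_atoms, n_atoms)) < 0.5      # toy neighbour graph
adj = np.triu(adj, 1)
adj = adj | adj.T                               # symmetric, no self-loops
W = rng.normal(size=(n_feat, n_feat))           # shared "learned" weight matrix

def message_pass_and_pool(h, adj, W):
    """One sum-aggregation layer, then sum pooling to a crystal vector."""
    messages = adj.astype(float) @ h            # sum over each atom's neighbours
    h_new = np.tanh((h + messages) @ W)         # residual-style update
    return h_new, h_new.sum(axis=0)             # per-atom and pooled outputs

perm = rng.permutation(n_atoms)                 # arbitrary relabelling of atoms
h_atoms, crystal = message_pass_and_pool(h, adj, W)
h_atoms_p, crystal_p = message_pass_and_pool(h[perm], adj[np.ix_(perm, perm)], W)

print(np.allclose(h_atoms[perm], h_atoms_p))    # per-atom outputs are permuted
print(np.allclose(crystal, crystal_p))          # pooled output is invariant
```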

Node and Edge Features: What Goes in the Graph

Feature type | What it encodes | Representation
Atom type (node) | Chemical species identity | Learnable embedding (64–128 dim), indexed by atomic number
Electronegativity (node) | Charge-transfer tendency | Pauling scale, continuous scalar
Bond distance (edge) | Most important single feature | Gaussian RBF expansion exp(−(d − μ_k)²/(2σ²)) over 40 centres μ_k
Bond vector (edge) | Directionality for anisotropic materials | Unit vector; fully exploited only by equivariant GNNs (e3nn/NequIP)
Periodic image offset (edge) | Which periodic image of the neighbour the bond points to | Integer triple of lattice-vector shifts, e.g. (0,0,0) or (1,0,−1); handles periodicity without unrolling the crystal
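The periodic-image offset can be illustrated with a brute-force neighbour search. This NumPy sketch is a simplified stand-in for what a library routine such as pymatgen's get_all_neighbors provides, not its actual algorithm. It only searches images with offsets in {−1, 0, 1}³, which is enough when the cutoff is smaller than the cell's shortest perpendicular width. The CsCl-type cell is a hypothetical example.

```python
import numpy as np
from itertools import product

def periodic_neighbours(frac_coords, lattice, cutoff):
    """Brute-force neighbour list with periodic-image offsets.

    frac_coords: (N, 3) fractional coordinates in [0, 1)
    lattice:     (3, 3) array whose rows are the lattice vectors [Angstrom]
    Searches only images with offsets in {-1, 0, 1}^3, which is enough when
    cutoff is smaller than the cell's shortest perpendicular width.
    Returns (i, j, image, distance) tuples: atom i bonds to j's copy in `image`.
    """
    cart = np.asarray(frac_coords, float) @ lattice
    edges = []
    for image in product((-1, 0, 1), repeat=3):
        shift = np.array(image) @ lattice              # Cartesian shift of the image
        for i, ri in enumerate(cart):
            for j, rj in enumerate(cart):
                if i == j and image == (0, 0, 0):
                    continue                            # skip the atom itself
                d = np.linalg.norm(rj + shift - ri)
                if d <= cutoff:
                    edges.append((i, j, image, d))
    return edges

# Hypothetical CsCl-type cubic cell: a = 4 Angstrom, atoms at (0,0,0) and (1/2,1/2,1/2)
lattice = 4.0 * np.eye(3)
frac = np.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]])
edges = periodic_neighbours(frac, lattice, cutoff=3.6)
print(len(edges))                                        # 16 directed edges
print(sorted(e[2] for e in edges if (e[0], e[1]) == (0, 1)))
```

Atom 0's eight nearest neighbours are all copies of atom 1 in different periodic images. With distance as the only edge feature those eight edges would look identical; the stored offset is what tells them apart.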

Section 3 — CGCNN: The Crystal Graph Convolutional Neural Network

The Crystal Graph Convolutional Neural Network (CGCNN), introduced by Xie & Grossman (2018), was the paper that demonstrated GNNs could predict DFT-computed material properties with errors comparable to the typical discrepancy between DFT and experiment, at a fraction of the cost. It remains one of the most important papers in ML for materials science, and a beautiful example of physics knowledge embedded directly in network architecture.

The key idea: represent a crystal as a graph, run graph convolutions to build atom-level representations that incorporate local chemical environment, pool all atom representations into a crystal-level embedding, then predict the target property from that embedding.

$$E_{\mathrm{form}} = E_{\mathrm{crystal}} - \sum_i n_i \mu_i$$

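Formation energy is the energy change when a crystal forms from its constituent elements. Here n_i is the number of atoms of element i in the cell and μ_i is the energy per atom of that element in its reference elemental phase; dividing by the total number of atoms gives the per-atom value (eV/atom) that the models below predict. It is the most commonly predicted property in materials ML because it gates everything else. A negative formation energy means the compound is stable with respect to its elements. Full thermodynamic stability additionally requires that it not decompose into any combination of competing phases, i.e. that it lie on the convex hull (energy above hull ≈ 0). A material with positive formation energy is unstable even with respect to its elements and is generally not expected to form under equilibrium conditions.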
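The gap between "negative formation energy" and "stable" is easy to miss, so here is a small NumPy sketch for a hypothetical binary A–B system. All energies and phases are made-up numbers for illustration; real workflows would build the hull from DFT energies with a tool such as pymatgen's PhaseDiagram class. The sketch computes a per-atom formation energy, builds the lower convex hull, and measures each phase's energy above it.

```python
import numpy as np

def formation_energy_per_atom(e_total, counts, mu):
    """E_form per atom = (E_crystal - sum_i n_i * mu_i) / sum_i n_i."""
    n_atoms = sum(counts.values())
    reference = sum(n * mu[el] for el, n in counts.items())
    return (e_total - reference) / n_atoms

def lower_hull(points):
    """Lower convex hull of (x_B, E_form) points (Andrew's monotone chain)."""
    hull = []
    for p in sorted(points):
        while len(hull) >= 2:
            (x1, y1), (x2, y2) = hull[-2], hull[-1]
            cross = (x2 - x1) * (p[1] - y1) - (y2 - y1) * (p[0] - x1)
            if cross <= 0:   # hull[-1] lies on/above the chord hull[-2] -> p
                hull.pop()
            else:
                break
        hull.append(p)
    return hull

def energy_above_hull(x, e_form, hull):
    hx, hy = zip(*hull)
    return e_form - np.interp(x, hx, hy)   # linear interpolation along the hull

# Hypothetical numbers [eV]: AB cell with 1 A + 1 B atom, elemental references mu
mu = {"A": -3.0, "B": -5.0}
print(formation_energy_per_atom(-10.0, {"A": 1, "B": 1}, mu))   # -1.0 eV/atom

# (fraction of B, formation energy per atom); pure elements sit at 0 by definition
phases = {"A": (0.0, 0.0), "B": (1.0, 0.0), "AB": (0.5, -1.0), "A3B": (0.25, -0.3)}
hull = lower_hull(phases.values())
for name, (x, ef) in phases.items():
    print(f"{name:>3}: E_form = {ef:+.2f}, "
          f"E_above_hull = {energy_above_hull(x, ef, hull):.2f} eV/atom")
```

In this toy system A3B has a negative formation energy (−0.30 eV/atom) yet lies 0.20 eV/atom above the hull, so it would decompose into A and AB. That distance is the quantity the Materials Project reports as energy_above_hull.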

$$h_v^{(t+1)} = h_v^{(t)} + \sum_{i \in \mathcal{N}(v)} \phi\big(h_v^{(t)},\, h_i^{(t)},\, e_{vi}\big)$$

This is the core message-passing update. At each layer, every atom v aggregates messages from its neighbours i, where φ is a learned function of the sender features h_i, receiver features h_v, and edge features e_vi (the Gaussian-expanded bond distance). After several rounds, each atom’s representation encodes information about its multi-shell local chemical environment — exactly the information that determines reactivity, bonding, and thermodynamic stability.

Python — Full CGCNN: structure → graph → convolutions → formation energy
# pip install torch-geometric pymatgen mp-api
import torch, torch.nn as nn
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader   # DataLoader lives here since PyG 2.0
from pymatgen.core import Structure
import numpy as np

# ── Step 1: Convert pymatgen Structure to PyG graph ──────────
def structure_to_graph(structure, cutoff_radius=6.0):
    # Node features: [atomic_number, electronegativity, radius, IE, EA]
    atom_features = []
    for site in structure:
        el = site.specie
        atom_features.append([
            el.Z / 100.0,
            float(np.nan_to_num(el.X, nan=0.0)),   # X is NaN for e.g. noble gases
            float(el.atomic_radius or 1.5),
            float(el.ionization_energy or 8.0),
            float(el.electron_affinity or 0.0)
        ])
    x = torch.tensor(atom_features, dtype=torch.float)

    # Build edges: all atom pairs within cutoff_radius [Angstroms]
    all_nbrs = structure.get_all_neighbors(cutoff_radius, include_index=True)
    src, dst, edge_attr = [], [], []
    centers = np.linspace(0, cutoff_radius, 40)   # 40 Gaussian RBF centres
    sigma   = cutoff_radius / 40 * 2.0
    for i, nbrs in enumerate(all_nbrs):
        for nbr in nbrs:
            j, dist = nbr.index, nbr.nn_distance   # PeriodicNeighbor attributes
            src.append(i); dst.append(j)
            # RBF: each distance becomes a 40-dim smooth feature vector
            rbf = np.exp(-((dist - centers)**2) / (2 * sigma**2))
            edge_attr.append(rbf)
    edge_index = torch.tensor([src, dst], dtype=torch.long)
    edge_attr  = torch.tensor(np.array(edge_attr), dtype=torch.float)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

# ── Step 2: CGCNN gated graph convolution layer ──────────────
class CGCNNConv(MessagePassing):
    def __init__(self, atom_fea_dim=64, edge_fea_dim=40):
        super().__init__(aggr='add')
        full_in = 2 * atom_fea_dim + edge_fea_dim
        self.gate_linear   = nn.Linear(full_in, atom_fea_dim)   # learned gate
        self.filter_linear = nn.Linear(full_in, atom_fea_dim)   # learned filter
        self.bn            = nn.BatchNorm1d(atom_fea_dim)

    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        z       = torch.cat([x_i, x_j, edge_attr], dim=1)
        gate    = torch.sigmoid(self.gate_linear(z))    # what to keep
        message = nn.functional.softplus(self.filter_linear(z)) # what to pass
        return gate * message

    def update(self, aggr_out, x):
        return self.bn(torch.relu(aggr_out + x))  # residual connection

# ── Step 3: Full CGCNN model ──────────────────────────────────
class CGCNN(nn.Module):
    def __init__(self, orig_atom_fea_len=5, nbr_fea_len=40,
                 atom_fea_len=64, n_conv=4):
        super().__init__()
        self.embedding = nn.Linear(orig_atom_fea_len, atom_fea_len)
        self.convs = nn.ModuleList([
            CGCNNConv(atom_fea_len, nbr_fea_len) for _ in range(n_conv)
        ])
        self.fc_out = nn.Sequential(
            nn.Linear(atom_fea_len, 128), nn.Softplus(),
            nn.Dropout(0.0),
            nn.Linear(128, 1)   # predict formation energy
        )

    def forward(self, data):
        x = torch.relu(self.embedding(data.x))
        for conv in self.convs:
            x = conv(x, data.edge_index, data.edge_attr)
        x = global_mean_pool(x, data.batch)  # average all atoms → crystal vector
        return self.fc_out(x)

model = CGCNN(n_conv=4)
print(f"CGCNN params: {sum(p.numel() for p in model.parameters()):,}")
print("Input: crystal graph → Output: formation energy [eV/atom]")
🌐 Data Source: The Materials Project API
The Materials Project provides free API access to 150,000+ DFT-computed crystal structures. Install mp-api and query summary data with MPRester(api_key).materials.summary.search(elements=["Fe","O"]), which returns every material containing both Fe and O. Download formation energies, band gaps, elastic moduli and structures for any entry in the database.
⚠ Common Mistake
A very common mistake: using raw distance values as edge features rather than Gaussian RBF expansion. The RBF expansion creates a smooth, continuous 40-dimensional representation of distance that generalises much better. Always expand bond distances into 20–40 Gaussian basis functions before feeding them as edge features. This single change typically reduces formation energy MAE by 15–25%.
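To see what the expansion does, here is the same Gaussian basis used in structure_to_graph() (40 centres from 0 to 6 Å, σ = 0.3 Å) applied in isolation with NumPy. A scalar distance becomes a 40-dimensional vector peaked at the nearest centre. Nearby distances map to strongly overlapping vectors and distant ones to nearly orthogonal vectors, which gives the network a smooth, localised input to fit.

```python
import numpy as np

def gaussian_rbf(distances, cutoff=6.0, n_centres=40):
    """Expand distances [Angstrom] into n_centres Gaussians, as in structure_to_graph()."""
    centres = np.linspace(0, cutoff, n_centres)
    sigma = cutoff / n_centres * 2.0
    d = np.asarray(distances, dtype=float)[..., None]     # add a basis axis
    return np.exp(-((d - centres) ** 2) / (2 * sigma ** 2)), centres

rbf, centres = gaussian_rbf([2.00, 2.05, 4.00])
print(rbf.shape)                                          # (3, 40)
print(centres[rbf.argmax(axis=1)])                        # centre nearest each distance

def cos(u, v):
    """Cosine similarity between two expanded distance vectors."""
    return u @ v / (np.linalg.norm(u) * np.linalg.norm(v))

print(f"similarity(2.00, 2.05) = {cos(rbf[0], rbf[1]):.3f}")   # close to 1
print(f"similarity(2.00, 4.00) = {cos(rbf[0], rbf[2]):.3f}")   # close to 0
```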

Section 4 — Machine Learning Force Fields: 1,000× Faster Molecular Dynamics

Molecular dynamics (MD) simulation tracks how a collection of atoms moves through time by integrating Newton's equations of motion. At every time step you need the force on every atom. The most accurate route is DFT, but DFT scales as O(N³) and a single force evaluation takes seconds to minutes. A nanosecond simulation with femtosecond time steps needs 10⁶ evaluations, which makes DFT-driven MD impractical for systems much larger than ~200 atoms.
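The arithmetic behind that claim is worth spelling out. This sketch takes the text's range of roughly 1 second to 1 minute per DFT force evaluation (assumed values; real timings depend on system size, code, and hardware) and a 1 fs time step.

```python
SECONDS_PER_DAY = 86_400

def md_wall_time_days(sim_time_s, dt_s, seconds_per_force_eval):
    """Steps and wall-clock days for an MD run with one force call per step."""
    n_steps = round(sim_time_s / dt_s)
    return n_steps, n_steps * seconds_per_force_eval / SECONDS_PER_DAY

for cost in (1, 10, 60):                      # seconds per DFT force call (assumed)
    n_steps, days = md_wall_time_days(1e-9, 1e-15, cost)
    print(f"{n_steps:,} steps at {cost:>2} s/step -> {days:,.0f} days")
```

Even at the optimistic end a single nanosecond needs about 12 days of sequential compute. A force field that is 1,000× faster brings the 60 s/step case down from roughly 694 days to under a day.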

Classical force fields (AMBER and CHARMM for biomolecules, or the embedded-atom and Tersoff potentials commonly run in LAMMPS) solve this with fixed analytic expressions. They are fast, but they sacrifice quantum mechanical accuracy, and most cannot reliably describe chemical reactions, charge transfer, or unusual bonding. Machine learning force fields (MLFFs) find the middle ground: they learn the potential energy surface from DFT calculations, achieving near-DFT accuracy at a small fraction of DFT's cost, though they are typically still slower than classical potentials. The speedup over DFT is typically 3–4 orders of magnitude.

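$$\mathbf{F}_i = -\frac{\partial E_\theta(\mathbf{R})}{\partial \mathbf{r}_i}, \qquad \mathbf{R} = (\mathbf{r}_1, \ldots, \mathbf{r}_N)$$
where E_θ(R) is the energy the ML force field predicts for atomic positions R.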

The critical design constraint: forces are derived as the negative gradient of the energy via automatic differentiation, rather than predicted independently. This makes the force field conservative by construction, so MD trajectories conserve energy up to integration error. A model that predicts forces directly (without an underlying energy) is not guaranteed to be conservative: small non-conservative errors do net work along a trajectory, causing energy drift and unphysical, unstable dynamics in long simulations.
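The energy-conservation argument can be demonstrated without any neural network. This NumPy sketch integrates a 2D unit-mass particle with velocity Verlet under two force fields. The first is the exact conservative force F = −∇E of a harmonic potential. The second adds a small rotational term, an illustrative stand-in for the errors of a direct-force model, which need not be the gradient of any energy. Only the first keeps the energy bounded.

```python
import numpy as np

def velocity_verlet(force_fn, r0, v0, dt=0.01, n_steps=10_000):
    """Unit-mass velocity Verlet; returns E = 0.5|v|^2 + 0.5|r|^2 at each step."""
    r, v = np.array(r0, dtype=float), np.array(v0, dtype=float)
    f = force_fn(r)
    energies = []
    for _ in range(n_steps):
        v_half = v + 0.5 * dt * f        # half kick
        r = r + dt * v_half              # drift
        f = force_fn(r)
        v = v_half + 0.5 * dt * f        # second half kick
        energies.append(0.5 * v @ v + 0.5 * r @ r)
    return np.array(energies)

def conservative(r):
    return -r                                            # F = -grad(0.5 |r|^2)

def non_conservative(r):
    return -r + 0.01 * np.array([-r[1], r[0]])           # extra term has nonzero curl

E0 = 1.0                                 # initial energy for r0=(1,0), v0=(0,1)
E_cons = velocity_verlet(conservative, [1, 0], [0, 1])
E_nonc = velocity_verlet(non_conservative, [1, 0], [0, 1])
print(f"max |E - E0|, conservative:     {np.abs(E_cons - E0).max():.2e}")
print(f"final E - E0, non-conservative: {E_nonc[-1] - E0:.2f}")
```

Velocity Verlet is symplectic, so the conservative run's energy error stays tiny and bounded. In the non-conservative run the extra force does net work around every closed orbit, so the energy climbs steadily, which is exactly the failure mode that deriving forces from an energy rules out.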

Energy-Conserving MLFF with Automatic Differentiation

Python — Energy-conserving ML force field: autograd forces, SiLU activations, RBF expansion
# Modern MLFFs: NequIP, MACE, SchNet, PaiNN, DimeNet
# All share this key design: energy -> autograd -> forces
import torch, torch.nn as nn
class MLForceField(nn.Module):
    def __init__(self, n_atom_types=100, n_features=128,
                 n_interactions=6, cutoff=5.0):
        super().__init__()
        self.cutoff     = cutoff
        self.n_rbf      = 50
        # Learnable atom type embeddings
        self.embedding  = nn.Embedding(n_atom_types, n_features)
        # Fixed RBF centres (evenly spaced from 0.5 Angstrom to cutoff)
        self.register_buffer('rbf_centers',
            torch.linspace(0.5, cutoff, self.n_rbf))
        # Interaction blocks: aggregate neighbour info into atom representations
        self.interactions = nn.ModuleList([
            nn.Sequential(
                nn.Linear(n_features + self.n_rbf, n_features),
                nn.SiLU(),   # SiLU/Swish outperforms ReLU for smooth energy surfaces
                nn.Linear(n_features, n_features),
                nn.SiLU()
            ) for _ in range(n_interactions)
        ])
        # Atomic energy output (sum over atoms → total energy)
        self.output = nn.Sequential(
            nn.Linear(n_features, 64), nn.SiLU(),
            nn.Linear(64, 1)
        )

    def rbf_expansion(self, distances):
        # Smooth cutoff envelope × Gaussian basis (ensures forces → 0 at cutoff)
        envelope = (1 - distances/self.cutoff).clamp(0, 1) ** 2
        gauss    = torch.exp(-((distances.unsqueeze(-1) - self.rbf_centers)**2)
                         / (2 * (self.cutoff/50) ** 2))
        return gauss * envelope.unsqueeze(-1)  # [N_edges, 50]

    def forward(self, atomic_numbers, positions, edge_index):
        h        = self.embedding(atomic_numbers)     # [N_atoms, n_features]
        src, dst = edge_index
        disp     = positions[dst] - positions[src]    # displacement vectors
        dist     = disp.norm(dim=-1)              # interatomic distances [Angstrom]
        rbf_feat = self.rbf_expansion(dist)       # [N_edges, 50]
        for interaction in self.interactions:
            msg = interaction(torch.cat([h[src], rbf_feat], dim=-1))
            h   = h.scatter_add(0, dst.unsqueeze(-1).expand_as(msg), msg)
        atomic_E = self.output(h).squeeze(-1)   # energy per atom
        return atomic_E.sum()                  # total system energy

# ── Force computation: ALWAYS via autograd — never predict directly ─
def compute_forces_and_energy(model, atomic_numbers, positions, edge_index):
    positions = positions.clone().requires_grad_(True)
    energy    = model(atomic_numbers, positions, edge_index)
    forces    = -torch.autograd.grad(
        energy, positions,
        create_graph=True,   # keep graph so a force loss can backprop to the weights
        retain_graph=True
    )[0]
    return energy, forces       # both have physical units: eV and eV/Angstrom

# Training loss: per-atom energy MAE + force MAE (forces weighted ~100x)
def mlff_loss(E_pred, F_pred, E_true, F_true, n_atoms, force_weight=100.0):
    loss_E = nn.L1Loss()(E_pred / n_atoms, E_true / n_atoms)   # eV/atom
    loss_F = nn.L1Loss()(F_pred, F_true)                       # eV/Angstrom
    return loss_E + force_weight * loss_F
# Why weight forces ~100x: each structure has 3N force labels but only 1 energy,
# and forces directly constrain the PES gradients that drive MD trajectories
🧠 Concept: Why force weight matters in MLFF training
You might wonder why forces get 100× more weight in the loss. Part of the answer is information density: a structure with N atoms has 1 energy value but 3N force components (for N = 50, that's 1 energy vs 150 forces), and those forces directly constrain the local gradients of the potential energy surface that drive MD. If forces are underweighted, the model can learn a qualitatively correct energy landscape with inaccurate local gradients, leading to MD trajectories that rapidly become unphysical. Force weights 10–100× the energy weight are common in practice; the best ratio depends on the dataset and is worth tuning against validation energy and force errors.
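A cheap sanity check for any energy-based force field is to compare its forces with a finite-difference gradient of its energy, and to confirm the net force vanishes (a translation-invariant energy implies ΣF_i = 0). The NumPy sketch below does this for a Lennard-Jones pair potential used as a stand-in for a learned E(R); ε, σ, and the five-atom cluster are arbitrary illustrative choices.

```python
import numpy as np

def lj_energy(pos, eps=1.0, sigma=1.0):
    """Total Lennard-Jones energy of a cluster (no cutoff, no periodicity)."""
    i, j = np.triu_indices(len(pos), k=1)
    r = np.linalg.norm(pos[i] - pos[j], axis=1)
    return np.sum(4 * eps * ((sigma / r) ** 12 - (sigma / r) ** 6))

def lj_forces(pos, eps=1.0, sigma=1.0):
    """Analytic forces F_a = -dE/dr_a."""
    forces = np.zeros_like(pos)
    for a in range(len(pos)):
        for b in range(a + 1, len(pos)):
            rij = pos[a] - pos[b]
            r = np.linalg.norm(rij)
            dE_dr = 4 * eps * (-12 * sigma**12 / r**13 + 6 * sigma**6 / r**7)
            f = -dE_dr * rij / r          # force on atom a from atom b
            forces[a] += f
            forces[b] -= f                # Newton's third law
    return forces

def finite_difference_forces(energy_fn, pos, h=1e-6):
    """Central-difference estimate of -dE/dr for every coordinate."""
    forces = np.zeros_like(pos)
    for idx in np.ndindex(pos.shape):
        plus, minus = pos.copy(), pos.copy()
        plus[idx] += h
        minus[idx] -= h
        forces[idx] = -(energy_fn(plus) - energy_fn(minus)) / (2 * h)
    return forces

pos = np.array([[0.0, 0.0, 0.0], [1.1, 0.0, 0.0], [0.0, 1.2, 0.0],
                [0.0, 0.0, 1.15], [1.0, 1.0, 1.0]])
F = lj_forces(pos)
F_fd = finite_difference_forces(lj_energy, pos)
print("max |F_analytic - F_fd|:", np.abs(F - F_fd).max())
print("net force:", F.sum(axis=0))
```

For a PyTorch MLFF the same test applies with the autograd forces from compute_forces_and_energy() in place of lj_forces(); a mismatch usually points to a bug such as a detached tensor.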

Section 5 — Detecting Phase Transitions with Machine Learning

Phase transitions — the abrupt change from one state of matter to another — are among the most fascinating phenomena in condensed matter physics. The ferromagnetic transition in iron, the superconducting transition in cuprates, the Mott metal-insulator transition, topological phase transitions: all involve a qualitative change in the state of the system at a critical temperature or coupling strength.

Identifying phase boundaries in complex systems (frustrated magnets, strongly correlated electrons, disordered systems) remains one of the hardest open problems in physics. Traditional approaches require defining an order parameter by hand — which requires knowing which symmetry is broken. Machine learning approaches this without that prior knowledge: train a classifier to distinguish phases, and look for where its confidence changes. The critical point emerges from the data.

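Phase classification loss (cross-entropy over N spin configurations s_n):
$$\mathcal{L}_{\mathrm{CE}} = -\frac{1}{N}\sum_{n=1}^{N}\sum_{c} y_{n,c}\,\log p_\theta(c \mid \mathbf{s}_n)$$
where y_{n,c} = 1 if configuration n belongs to phase c (0 otherwise) and p_θ(c | s_n) is the classifier's softmax output.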

Approach 1: Supervised CNN Classifier on the Ising Model

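The 2D Ising model is the perfect training ground for this. Above the critical temperature T_c = 2/ln(1+√2) ≈ 2.269 J/k_B, spins are disordered (paramagnetic phase); below it, they align (ferromagnetic phase). A CNN trained on raw spin configurations labelled ordered or disordered is never given the temperature, the energy, or the magnetisation, yet its output probability switches sharply near T_c. The training labels do use the known T_c, so the real test is transfer: on held-out configurations, or on a different lattice whose T_c was not used in training, the temperature where the network's confidence crosses 50% locates the transition.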

Python — Ising model phase classifier: CNN discovers T_c from raw spin configurations
# ── Generate Ising configurations via Metropolis Monte Carlo ────
import numpy as np
import torch, torch.nn as nn

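def ising_mc(L=32, T=2.0, n_sweeps=200, rng=None):
    """2D Ising model Metropolis Monte Carlo on an L×L periodic lattice (J = k_B = 1).
    T_c = 2/ln(1+√2) ≈ 2.269 J/k_B  (Onsager exact solution)
    Checkerboard updates: spins on one sublattice only have neighbours on the
    other, so each half-lattice is updated in one vectorised step (L even).
    A random start quenched far below T_c can trap stripe domains; more
    sweeps or an all-up start reduce this.
    """
    rng    = np.random.default_rng() if rng is None else rng
    spins  = rng.choice([-1, 1], size=(L, L))
    ii, jj = np.indices((L, L))
    for sweep in range(n_sweeps):
        for parity in (0, 1):
            mask = (ii + jj) % 2 == parity              # one checkerboard sublattice
            S_nb = (np.roll(spins, 1, 0) + np.roll(spins, -1, 0) +
                    np.roll(spins, 1, 1) + np.roll(spins, -1, 1))
            dE   = 2 * spins * S_nb                     # energy cost of each single flip
            flip = mask & ((dE <= 0) | (rng.random((L, L)) < np.exp(-dE / T)))
            spins[flip] *= -1                           # Metropolis acceptance
    return spins                                        # thermalised configuration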

# ── Build dataset: label 0 = ordered (T < T_c), 1 = disordered (T > T_c) ──
T_c = 2.269
configs, labels, temps = [], [], []
for T in np.linspace(1.0, 4.0, 30):   # 30 temperatures
    for _ in range(200):                # 200 configs per T
        configs.append(ising_mc(L=32, T=T))
        labels.append(0 if T < T_c else 1)
        temps.append(T)

X = torch.tensor(np.array(configs), dtype=torch.float).unsqueeze(1)  # [N,1,32,32]
y = torch.tensor(labels, dtype=torch.long)

# ── CNN classifier ───────────────────────────────────────────────
class PhaseClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1,  32,  3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64,  3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten()
        )
        self.head = nn.Sequential(
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 2)
        )
    def forward(self, x): return self.head(self.cnn(x))

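# ── After training: plot softmax(output)[:,0] vs temperature ────
# The ordered-phase probability drops sharply near T_c (finite-size effects
# on a 32x32 lattice shift and broaden the crossover slightly)
# Key result (Carrasquilla & Melko 2017): a network trained on labelled
# square-lattice configurations, applied to the triangular lattice without
# retraining, still locates that lattice's T_c from raw spin configurations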
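To turn the classifier's output into a number, one common recipe is to average P(ordered) over all configurations at each temperature and find where that curve crosses 0.5. The sketch below does this by linear interpolation. Since it runs without the trained network, it feeds in a synthetic logistic curve centred at 2.27, hypothetical data standing in for real classifier outputs.

```python
import numpy as np

def crossing_temperature(temps, p_ordered, threshold=0.5):
    """Temperature where the mean P(ordered) first drops below threshold,
    located by linear interpolation between the bracketing temperatures."""
    temps, p = np.asarray(temps, float), np.asarray(p_ordered, float)
    below = np.nonzero(p < threshold)[0]
    if len(below) == 0 or below[0] == 0:
        raise ValueError("curve does not cross the threshold from above")
    k = below[0]
    t0, t1, p0, p1 = temps[k - 1], temps[k], p[k - 1], p[k]
    return t0 + (threshold - p0) * (t1 - t0) / (p1 - p0)

# Synthetic stand-in for per-temperature mean classifier outputs
temps = np.linspace(1.0, 4.0, 30)
p_ordered = 1.0 / (1.0 + np.exp((temps - 2.27) / 0.1))
print(f"Estimated T_c = {crossing_temperature(temps, p_ordered):.3f}  (exact 2.269)")
```

With the real network, p_ordered would be the mean of softmax(output)[:, 0] over the 200 configurations at each temperature; on a finite 32×32 lattice, expect the crossing near, not exactly at, 2.269.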

Approach 2: Unsupervised Phase Detection with PCA

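What if you don't have labels? What if you don't even know how many phases exist? Unsupervised dimensionality reduction can reveal phase structure without any prior knowledge. Configurations in the same phase cluster together in representation space. For the Ising model the first principal component is essentially the magnetisation: below T_c its values split into two branches (±|m|, one per ordering direction), above T_c they collapse towards zero, and the crossover between the two regimes marks the phase boundary.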

Python — Unsupervised phase detection: PCA separates Ising phases without any labels
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# Flatten 32x32 spin configs → 1024-dimensional vectors
X_flat = np.array(configs).reshape(len(configs), -1).astype(float)
T_arr  = np.array(temps)

# PCA: find directions of maximum variance across all configurations
pca   = PCA(n_components=5)
X_pca = pca.fit_transform(X_flat)

# Explained variance: how much of config variability each PC captures
print(f"PC1 explains {pca.explained_variance_ratio_[0]*100:.1f}% of variance")
print(f"PC2 explains {pca.explained_variance_ratio_[1]*100:.1f}% of variance")

# Plot PC1 vs temperature: two branches (±m) below T_c merge near zero above it
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
sc = ax1.scatter(T_arr, X_pca[:,0], c=T_arr, cmap='coolwarm', s=8, alpha=0.6)
ax1.axvline(2.269, color='k', linestyle='--', lw=2, label='T_c (exact)')
ax1.set_xlabel('Temperature T/J'); ax1.set_ylabel('First Principal Component')
ax1.legend(); plt.colorbar(sc, ax=ax1, label='Temperature')

# 2D PCA: ordered configs form two clusters at large |PC1|, disordered ones sit near the origin
ax2.scatter(X_pca[:,0], X_pca[:,1], c=T_arr, cmap='coolwarm', s=8, alpha=0.5)
ax2.set_xlabel('PC 1')
ax2.set_ylabel('PC 2')
ax2.set_title('2D PCA: Phase Clusters in Configuration Space')
plt.tight_layout(); plt.show()

# The two phases are visually separated in PC space — no labels used
# This technique extends to topological phases, glass transitions, etc.
# where traditional order parameters are hard or impossible to define
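Why does PCA work here? For Ising-like data the leading principal component is, to a very good approximation, the uniform direction, so a configuration's PC1 score is proportional to its magnetisation. This sketch checks that with NumPy's SVD on toy configurations in which each spin is +1 with probability (1 + m)/2 for a random target magnetisation m. These are independent-spin stand-ins, not Monte Carlo samples.

```python
import numpy as np

rng = np.random.default_rng(0)
n_configs, L = 500, 32
m_target = rng.uniform(-1, 1, size=n_configs)             # toy "order" per config
p_up = (1 + m_target[:, None]) / 2
spins = np.where(rng.random((n_configs, L * L)) < p_up, 1.0, -1.0)

# PCA via SVD of the mean-centred data matrix
X = spins - spins.mean(axis=0)
_, _, Vt = np.linalg.svd(X, full_matrices=False)
pc1 = X @ Vt[0]                                           # PC1 score of each config

magnetisation = spins.mean(axis=1)
corr = np.corrcoef(pc1, magnetisation)[0, 1]
print(f"|corr(PC1, m)| = {abs(corr):.4f}")                # sign of PC1 is arbitrary
print(f"PC1 weights span {Vt[0].min():.3f} to {Vt[0].max():.3f} "
      f"(a uniform unit vector has entries of magnitude {1 / L:.3f})")
```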

Section 6 — Generative Models for Designing New Materials

Predicting properties of known materials is valuable. But the ultimate goal is inverse design: given a target property (band gap of 1.5 eV for a solar cell absorber, high Debye temperature for a superconductor, specific elasticity for an aerospace alloy), generate novel crystal structures that have it. This is the materials science equivalent of drug discovery.

The challenge is formidable. The space of possible crystal structures is enormous and combinatorially complex. The representation must respect physical symmetries (rotational and translational invariance, periodicity). And critically, generated structures must be both property-optimal and synthesisable — a thermodynamically stable material that can actually be made in the lab, not just a computer-generated fiction.

Crystal VAE: Learning a Continuous Latent Space for Materials

A Variational Autoencoder (VAE) compresses crystal representations into a continuous, low-dimensional latent space. Once trained, you can perform gradient-based optimisation in latent space: start from the latent vector of any crystal, move it in the direction that brings the predicted property closer to its target, and decode the optimised latent vector back into a new crystal structure proposal.

Python — Crystal VAE with latent-space inverse design: gradient-based search toward a target band gap
# Crystal VAE — learning a navigable latent space for materials
class CrystalEncoder(nn.Module):
    def __init__(self, orig_atom_fea_len=5, atom_fea_dim=64, latent_dim=32):
        super().__init__()
        # Project the 5 raw atom features up to the width CGCNNConv expects
        self.embedding = nn.Linear(orig_atom_fea_len, atom_fea_dim)
        # Reuse CGCNN-style message passing for encoding
        self.conv1 = CGCNNConv(atom_fea_dim, 40)
        self.conv2 = CGCNNConv(atom_fea_dim, 40)
        self.conv3 = CGCNNConv(atom_fea_dim, 40)
        self.mu_head     = nn.Linear(atom_fea_dim, latent_dim)  # mean of latent Gaussian
        self.logvar_head = nn.Linear(atom_fea_dim, latent_dim)  # log-variance

    def forward(self, data):
        x = torch.relu(self.embedding(data.x))
        x = self.conv1(x,      data.edge_index, data.edge_attr)
        x = self.conv2(x,      data.edge_index, data.edge_attr)
        x = self.conv3(x,      data.edge_index, data.edge_attr)
        x = global_mean_pool(x, data.batch)
        return self.mu_head(x), self.logvar_head(x)

class CrystalVAE(nn.Module):
    def __init__(self, latent_dim=32, n_atom_types=100):
        super().__init__()
        self.encoder = CrystalEncoder(latent_dim=latent_dim)
        # Decoder: z → composition vector + 6 lattice parameters (a,b,c,α,β,γ)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.ReLU(),
            nn.Linear(128, 256),         nn.ReLU(),
            nn.Linear(256, n_atom_types + 6)
        )                                          # atom probs + 6 lattice params
        # Property predictor in latent space
        self.property_head = nn.Linear(latent_dim, 1)  # e.g. band gap [eV]

    def reparameterise(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std   # reparameterisation trick

    def forward(self, data):
        mu, logvar = self.encoder(data)
        z          = self.reparameterise(mu, logvar)
        recon      = self.decoder(z)
        prop_pred  = self.property_head(z)
        return recon, mu, logvar, prop_pred

# VAE loss: reconstruction + KL divergence + property MSE
def vae_loss(recon, target, mu, logvar, prop_pred, prop_true, beta=1.0):
    recon_loss = nn.MSELoss()(recon, target)
    kl_loss    = -0.5 * torch.mean(1 + logvar - mu**2 - logvar.exp())
    prop_loss  = nn.MSELoss()(prop_pred, prop_true)
    return recon_loss + beta * kl_loss + prop_loss

# ── Inverse design: gradient descent in latent space toward a target property ──
def optimise_latent(model, z_start, target_gap=1.5, steps=300, lr=0.01):
    """Move z so the predicted band gap approaches target_gap by minimising
    (pred - target)^2 with Adam; only z is updated, the model stays fixed."""
    z   = z_start.clone().requires_grad_(True)
    opt = torch.optim.Adam([z], lr=lr)
    for step in range(steps):
        opt.zero_grad()
        pred = model.property_head(z)
        loss = ((pred - target_gap)**2).mean()
        loss.backward(); opt.step()
        if step % 100 == 0:
            print(f"Step {step}: predicted gap = {pred.mean().item():.3f} eV")
    return model.decoder(z.detach())  # decode optimised z → composition + lattice proposal

# Real-world: CDVAE and DiffCSP use diffusion-based generative models, while GNoME
# generates candidates by substitution and random structure search instead.
# This VAE still captures the core idea: a continuous, navigable crystal space
✅ The GNoME Connection
DeepMind's GNoME (Graph Networks for Materials Exploration, Merchant et al. 2023) applies the same screen-then-verify philosophy at massive scale, though its candidates come from element substitution and random structure search rather than a VAE. GNN energy predictors trained on Materials Project data filter the candidates, DFT verifies the survivors, and the verified results are fed back into training over successive active-learning rounds. The result: 2.2 million new crystal structures predicted to be stable relative to previously known materials, about 380,000 of which lie on the updated convex hull; 736 of them had already been independently realised in experiments.

Section 7 — Working with the Materials Project API

Before training any of the models above, you need data. The Materials Project is the most comprehensive open database of computed material properties — 150,000+ crystal structures with DFT-computed formation energies, band gaps, elastic properties, dielectric constants, magnetic moments, and much more. Here is how to access it programmatically.

Python — Materials Project query, DataFrame, graph dataset, CGCNN training loop end-to-end
# pip install mp-api pymatgen
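from mp_api.client import MPRester
from pymatgen.core import Structure
import pandas as pd
import torch, torch.nn as nn
from torch_geometric.loader import DataLoader
# Assumes structure_to_graph() and CGCNN from Section 3 are already defined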

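# ── Query: stable/near-stable oxides with 3–4 elements containing Fe, Ti or Mn ──
# summary.search(elements=[...]) matches materials containing ALL listed
# elements, so query each metal separately and de-duplicate by material_id
fields = ["material_id", "formula_pretty", "formation_energy_per_atom",
          "band_gap", "structure", "nsites", "bulk_modulus", "theoretical"]
results = {}
with MPRester("YOUR_API_KEY") as mpr:
    for metal in ["Fe", "Ti", "Mn"]:
        for doc in mpr.materials.summary.search(
                elements=["O", metal],
                num_elements=(3, 4),
                energy_above_hull=(0, 0.05),   # stable + nearly stable [eV/atom]
                fields=fields):
            results[str(doc.material_id)] = doc
results = list(results.values())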

# ── Convert to pandas DataFrame ──────────────────────────────
data = [{
    "mp_id":     str(r.material_id),
    "formula":   r.formula_pretty,
    "E_form":    r.formation_energy_per_atom,
    "band_gap":  r.band_gap,
    "n_sites":   r.nsites,
    "structure": r.structure    # pymatgen Structure object; use structure_to_graph()
} for r in results]
df = pd.DataFrame(data)
print(f"Found {len(df)} materials")
print(df[["formula", "E_form", "band_gap"]].describe())

# ── Train/val/test split: 80/10/10, one row per material ──────
# (for a stricter test, keep all polymorphs of a formula in the same split)
from sklearn.model_selection import train_test_split
df_train, df_test = train_test_split(df, test_size=0.10, random_state=42)
df_train, df_val  = train_test_split(df_train, test_size=1/9, random_state=42)
print(f"Train: {len(df_train)}  Val: {len(df_val)}  Test: {len(df_test)}")

# ── Build graph datasets and DataLoaders ─────────────────────
def make_graphs(frame):
    graphs = []
    for _, row in frame.iterrows():
        try:
            g   = structure_to_graph(row["structure"], cutoff_radius=6.0)
            g.y = torch.tensor([row["E_form"]], dtype=torch.float)
            graphs.append(g)
        except Exception as e:
            print(f"Skipped {row['mp_id']}: {e}")
    return graphs

train_loader = DataLoader(make_graphs(df_train), batch_size=64, shuffle=True)
val_loader   = DataLoader(make_graphs(df_val),   batch_size=64)
print(f"Dataset ready: {len(train_loader.dataset)} training crystals")

# ── Training loop: schedule and checkpoint on validation MAE ──
model     = CGCNN(n_conv=4)
optim     = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=10, factor=0.5)
best_mae  = float('inf')
for epoch in range(300):
    model.train()
    for batch in train_loader:
        optim.zero_grad()
        pred = model(batch).view(-1)                 # [batch_size]
        loss = nn.L1Loss()(pred, batch.y.view(-1))   # MAE in eV/atom
        loss.backward(); optim.step()
    model.eval()
    with torch.no_grad():
        abs_err = sum((model(b).view(-1) - b.y.view(-1)).abs().sum().item()
                      for b in val_loader)
    val_mae = abs_err / len(val_loader.dataset)
    scheduler.step(val_mae)
    if val_mae < best_mae:
        best_mae = val_mae
        torch.save(model.state_dict(), "cgcnn_best.pt")   # keep best checkpoint
    if (epoch+1) % 50 == 0:
        print(f"Epoch {epoch+1}: val MAE = {val_mae:.4f} eV/atom (best {best_mae:.4f})")
🔑 API Key: Getting your Materials Project API key
Register free at materialsproject.org. After login go to Dashboard → API Keys. The free tier allows 1,000 requests/day — enough to download tens of thousands of structures. For bulk downloads of the entire database, use their bulk data download page which provides the full dataset as compact JSON files. Aim for MAE < 0.1 eV/atom on formation energy — that is the DFT-level accuracy benchmark.

Section 8 — The End-to-End Materials Discovery Pipeline

Let’s now assemble all the pieces into a practical discovery workflow. This is the pipeline that research groups actually use. It runs iteratively, with a DFT validation step in every iteration, so ML predictions are checked against physics before expensive compute or lab resources are committed to a candidate.

🔄 The Materials Discovery Loop — 7 Stages
  1. Data collection: Query the Materials Project for all structures in your target element space that meet your stability criterion. Download formation energies, band gaps and elastic moduli. A typical starting dataset has 5,000–150,000 structures.
  2. Graph construction: Convert every structure to a PyG graph using structure_to_graph(). Apply a Gaussian RBF expansion to bond distances. Cache the graphs to disk so you don't recompute them on every training run.
  3. CGCNN training: Use an 80% train, 10% val, 10% test split. Targets are a formation energy MAE < 0.1 eV/atom and a band gap MAE < 0.3 eV. Train for 200–500 epochs with ReduceLROnPlateau. Use an ensemble of 5 independently initialised models for uncertainty estimates.
  4. Candidate generation: Propose millions of new structures using either substitution (replacing atoms in stable structures with chemically similar elements) or a generative model (VAE, diffusion). Both approaches work; substitution is faster to implement.
  5. ML screening: Run CGCNN on all candidates (millions in minutes). Keep only structures that are predicted to be thermodynamically stable AND whose predicted target property falls in the desired range. A predicted formation energy < 0 eV/atom only means the compound is stable relative to its elements. True thermodynamic stability requires the structure to lie on or near the convex hull (energy above hull ≈ 0), so filter on that when you can. Ensemble disagreement gives an uncertainty estimate; prioritise low-uncertainty candidates.
  6. DFT validation: Run full DFT on the top 100–1,000 ML-selected candidates. This is the expensive step, so budget accordingly. Add all the new DFT results to your training set and retrain (active learning). Include the unstable results as well as the stable ones, because they teach the model where the stability boundary lies. Repeat from Stage 4.
  7. Experimental validation: The most promising DFT-confirmed candidates go to synthesis. ML + DFT raises the synthesis success rate from ∼5% (random search) to >50% (targeted ML). That 10× improvement in hit rate is the economic justification for the entire pipeline.
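Once the ensemble from Stage 3 has scored every candidate, Stage 5 comes down to a few array operations. The sketch below is illustrative:

- It assumes the predictions are stacked into arrays of shape (n_models, n_candidates).
- It applies the Stage 5 criteria: predicted formation energy below 0 eV/atom plus a band-gap window.
- It ranks the survivors by ensemble spread.

If you also train a model for energy above the convex hull, the same mask works on that array with a small positive tolerance in place of 0. That is the stricter stability test. The function name `screen_candidates`, the 1.0–2.0 eV gap window, the 0.05 eV/atom spread cap and the random stand-in predictions are all assumptions made for illustration.

```python
import numpy as np

def screen_candidates(e_form, gap, gap_window=(1.0, 2.0), max_std=0.05):
    """Illustrative Stage 5 filter on ensemble predictions.

    e_form, gap: arrays of shape (n_models, n_candidates) holding each ensemble
    member's predicted formation energy (eV/atom) and band gap (eV).
    Returns indices of passing candidates, lowest ensemble spread first.
    """
    e_mean, e_std = e_form.mean(axis=0), e_form.std(axis=0)  # consensus and disagreement
    g_mean = gap.mean(axis=0)

    keep = (
        (e_mean < 0.0)                                           # stable vs. the elements
        & (g_mean >= gap_window[0]) & (g_mean <= gap_window[1])  # target property window
        & (e_std < max_std)                                      # drop disputed candidates
    )
    idx = np.flatnonzero(keep)
    return idx[np.argsort(e_std[idx])]                           # most confident first

# Stand-in predictions: 5 models x 10,000 candidates (random, illustration only)
rng = np.random.default_rng(0)
true_e = rng.normal(0.2, 0.4, size=10_000)
true_g = rng.uniform(0.0, 4.0, size=10_000)
e_pred = true_e + rng.normal(0, 0.03, size=(5, 10_000))
g_pred = true_g + rng.normal(0, 0.10, size=(5, 10_000))

shortlist = screen_candidates(e_pred, g_pred)
print(f"{len(shortlist)} of 10,000 candidates pass; top of the list: {shortlist[:5]}")
```

The shortlist, truncated to your DFT budget, is what goes forward to Stage 6.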
💡 Key Mindset Shift: ML as a filter, not an oracle
ML doesn’t replace DFT or experiment; it makes them far more efficient. Compare two approaches:
- Without ML: run DFT on 1,000 random candidates and find 10 stable ones (a 1% hit rate).
- With ML: use CGCNN to pre-screen 1,000,000 candidates in minutes, run DFT on the top 1,000, and find 200+ stable ones (20%+).

The discovery rate improves by a factor of 20, while the DFT budget stays the same: 1,000 calculations either way, plus a negligible ML inference cost. That’s the promise, and it’s being realised right now in industrial labs designing battery materials, catalysts and semiconductors.

External References & Further Reading

  • Xie & Grossman (2018). Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties. PRL 120, 145301. arXiv:1710.10324. The original CGCNN; start here.
  • Merchant et al. (2023). Scaling deep learning for materials discovery. Nature 624, 80–85. doi.org/10.1038/s41586-023-06735-9. GNoME: 2.2M new crystal structures; essential reading.
  • Batzner et al. (2022). E(3)-equivariant graph neural networks for data-efficient and accurate interatomic potentials. Nature Communications. arXiv:2101.03164. The NequIP equivariant MLFF.
  • Batatia et al. (2022). MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields. NeurIPS. arXiv:2206.07697. One of the most accurate MLFF architectures available.
  • Carrasquilla & Melko (2017). Machine learning phases of matter. Nature Physics 13, 431. arXiv:1605.01735. The foundational phase-transition ML paper.
  • Jain et al. (2013). Commentary: The Materials Project: A materials genome approach to accelerating materials innovation. APL Materials. doi.org/10.1063/1.4812323. The Materials Project paper.
  • Open Catalyst Project. opencatalystproject.org. 250M DFT calculations for catalysis ML; one of the largest open materials ML datasets, with benchmark leaderboards.
📋 Key Takeaways — Cluster 5
  • The materials space is almost entirely unexplored. Of ~10²⁰ possible crystals, we have computed properties for ~150,000. ML-guided exploration is expanding this by orders of magnitude. GNoME is the most dramatic example, but dozens of groups are doing this across different material classes.
  • Crystals are graphs. Always. Atoms are the nodes, and edges connect each atom to its neighbours within a cutoff radius, including periodic images. This representation respects permutation invariance, handles variable unit-cell sizes, and encodes the local chemical environment through message passing, which is exactly the structure that determines material properties.
  • Always use RBF expansion for bond distances. Raw distance values as edge features underperform by 15–25% compared to 40-Gaussian RBF expansion. This is a free, significant improvement.
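  • ML force fields must derive forces via autograd. Never predict forces directly. Taking F = −∂E/∂r guarantees a conservative force field, so total energy is conserved in MD. This is a physical requirement, and it also keeps long simulations stable. Use smooth activations such as SiLU so that the energy surface, and therefore the forces, are smooth.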
  • Phase transitions can be detected without defining an order parameter. Train a CNN on spin configurations labelled as ordered or disordered. T_c emerges where the network's predicted probability of the ordered phase crosses 0.5. Alternatively, use PCA: the two phases separate along the first principal component without any labels.
  • Inverse design uses gradient ascent in latent space. Train a VAE to encode crystals, attach a property predictor to the latent space, then optimise z toward your target property and decode the result.
  • ML is a filter, not an oracle. Pre-screen millions of candidates with CGCNN (minutes), validate the top 1,000 with DFT (days), and synthesise the best 10 (months). The total hit rate improves 10–20× over random search.
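The takeaway on bond distances recommends a 40-Gaussian RBF expansion. Below is a minimal NumPy sketch. It assumes evenly spaced centres from 0 to the 6 Å cutoff used in the training script, and a width σ equal to the centre spacing. Both are common choices, but they are not necessarily what the earlier structure_to_graph() uses. Each distance d becomes a 40-dimensional vector with components e_k(d) = exp(−(d − μ_k)²/σ²).

```python
import numpy as np

def gaussian_rbf(distances, d_min=0.0, d_max=6.0, n_centers=40):
    """Expand interatomic distances (Å) into Gaussian radial-basis features.

    Returns shape (n_edges, n_centers); feature k is exp(-(d - mu_k)^2 / sigma^2)
    with centres mu_k evenly spaced on [d_min, d_max].
    """
    centers = np.linspace(d_min, d_max, n_centers)
    sigma = centers[1] - centers[0]                   # width = centre spacing
    d = np.asarray(distances, dtype=float)[:, None]   # (n_edges, 1) for broadcasting
    return np.exp(-((d - centers) ** 2) / sigma**2)

features = gaussian_rbf([1.9, 2.8, 4.05])
print(features.shape)             # → (3, 40)
print(features.argmax(axis=1))    # → [12 18 26], the nearest centre for each distance
```

Nearby distances activate overlapping sets of Gaussians. This gives the network a smooth, localised encoding of distance instead of a single raw number.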
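The force-field takeaway says forces should come from the energy by automatic differentiation. The PyTorch sketch below shows the mechanics. It uses a deliberately simple stand-in for a trained network: a Lennard-Jones pair energy, with parameters chosen for illustration. `pair_energy` is a hypothetical placeholder, not part of any MLFF library, and any differentiable energy model plugs into `energy_and_forces` in the same way.

```python
import torch

def pair_energy(pos, epsilon=1.0, sigma=1.0):
    """Hypothetical stand-in for an ML energy model: total Lennard-Jones energy."""
    i, j = torch.triu_indices(len(pos), len(pos), offset=1)  # each pair counted once
    r = (pos[i] - pos[j]).norm(dim=1)
    sr6 = (sigma / r) ** 6
    return (4 * epsilon * (sr6**2 - sr6)).sum()

def energy_and_forces(energy_fn, pos):
    """Return E and F = -dE/dpos via autograd (conservative by construction)."""
    pos = pos.detach().requires_grad_(True)
    energy = energy_fn(pos)
    # create_graph=True keeps the graph so a force loss can be backpropagated in training
    (grad,) = torch.autograd.grad(energy, pos, create_graph=True)
    return energy, -grad

pos = torch.tensor([[0.0, 0.0, 0.0],
                    [1.1, 0.0, 0.0],
                    [0.0, 1.2, 0.0],
                    [0.9, 0.9, 0.8]], dtype=torch.float64)
E, F = energy_and_forces(pair_energy, pos)
print(f"E = {E.item():.4f}")
print("net force:", F.sum(dim=0))   # ~0: pairwise forces obey Newton's third law
```

Because F is the exact gradient of E, a finite-difference check of the energy reproduces it. This consistency is what keeps MD with the learned potential energy-conserving.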
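The phase-transition takeaway locates T_c where the classifier's confidence changes. In practice this is usually read off as the temperature at which the mean predicted probability of the ordered phase crosses 0.5. The minimal NumPy sketch below assumes you have already averaged the classifier's output over configurations at each temperature. The logistic curve is synthetic stand-in data centred on the exact 2D Ising value T_c = 2/ln(1+√2) ≈ 2.269 (in units of J/k_B); it is not real network output.

```python
import numpy as np

def estimate_tc(temps, p_ordered, threshold=0.5):
    """Linearly interpolate the temperature where P(ordered) first drops below threshold.

    temps must be increasing; p_ordered is the classifier's mean output per temperature.
    """
    temps, p = np.asarray(temps, float), np.asarray(p_ordered, float)
    below = np.flatnonzero(p < threshold)
    if below.size == 0 or below[0] == 0:
        raise ValueError("threshold is not crossed inside the temperature grid")
    k = below[0]                                 # first point on the disordered side
    t0, t1, p0, p1 = temps[k - 1], temps[k], p[k - 1], p[k]
    return t0 + (p0 - threshold) * (t1 - t0) / (p0 - p1)

# Synthetic classifier curve (illustration only): logistic step at the exact 2D Ising T_c
tc_exact = 2 / np.log(1 + np.sqrt(2))           # ≈ 2.269 J/k_B
temps = np.linspace(1.0, 3.5, 26)
p_ordered = 1 / (1 + np.exp((temps - tc_exact) / 0.1))
print(f"Estimated T_c = {estimate_tc(temps, p_ordered):.3f}")
```

With real Monte Carlo data, you would repeat this for several lattice sizes, because the crossing shifts with finite-size effects.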
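The inverse-design takeaway describes optimising the latent vector z toward a target property. That step can be sketched without a full VAE. In the sketch below:

- A linear map f(z) = w·z + b stands in for the trained latent-space property predictor. This is an assumption that makes the gradient analytic; with a neural predictor you would get ∇f from autograd instead.
- Because the goal here is to hit a target value, the code descends the squared error. To maximise a property instead, you would ascend f(z) directly.
- A small L2 penalty keeps z near the origin, where a VAE's Gaussian prior places most of the decodable mass.

The names `optimise_latent`, `w`, `b` and the target value are hypothetical.

```python
import numpy as np

def optimise_latent(z0, w, b, target, lr=0.05, l2=1e-2, steps=500):
    """Gradient descent on (f(z) - target)^2 + l2 * ||z||^2 with f(z) = w.z + b.

    The linear f is a hypothetical stand-in for a trained property predictor.
    """
    z = z0.copy()
    for _ in range(steps):
        err = w @ z + b - target
        grad = 2 * err * w + 2 * l2 * z           # analytic gradient of the objective
        z -= lr * grad
    return z

rng = np.random.default_rng(1)
latent_dim = 16
w, b = rng.normal(size=latent_dim) / np.sqrt(latent_dim), 0.0  # hypothetical predictor
z0 = rng.normal(size=latent_dim)                  # e.g. the encoding of a known crystal
target = 3.0                                      # desired property value, e.g. band gap (eV)

z_opt = optimise_latent(z0, w, b, target)
print(f"f(z0) = {w @ z0 + b:.3f}  ->  f(z_opt) = {w @ z_opt + b:.3f}  (target {target})")
# z_opt would next be passed to the VAE decoder to produce a candidate crystal.
```

The L2 weight trades off two things: matching the target exactly, and staying in regions of latent space that the decoder maps to realistic structures.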
