Machine Learning for Astrophysics & Cosmology: Complete Guide

The universe generates data at a scale that humbles even the most ambitious particle physicist. Billions of galaxies, millions of transients, continuous gravitational wave streams, terabytes of stellar spectra — and now, AI systems that can read all of it. This cluster covers the algorithms, pipelines, and breakthroughs putting machine learning at the heart of observational astronomy.

🌌 Galaxy Classification 🌊 Gravitational Waves 🔭 Exoplanet Detection 🧮 Symbolic Regression


📋 In This Article
  1. The Astrophysics Data Tsunami
  2. Galaxy Morphology with CNNs
  3. Photometric Redshift Estimation
  4. Gravitational Wave Detection
  5. Exoplanet Transit Detection
  6. Symbolic Regression for Cosmological Laws
  7. Simulation-Based Inference
  8. Working with astropy + ML

Section 1 — The Astrophysics Data Tsunami

Astrophysics has always been a data-rich science. Every photon that reaches Earth carries information about the universe's history. But for most of astronomical history, the bottleneck was collecting enough signal. That constraint has inverted completely.

The Vera C. Rubin Observatory, coming online in the late 2020s, will image the entire southern sky every three nights for ten years, generating roughly 20 terabytes of data per night and cataloguing 17 billion galaxies. The Square Kilometre Array (SKA) will produce 700 terabytes per second of raw radio data. The Gaia mission has already catalogued 1.7 billion stars. The TESS exoplanet telescope produces 100 GB of light curves per day.
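Those per-night rates compound quickly. A back-of-envelope tally (assuming Rubin observes every single night, so this is an upper bound) makes the point:

```python
TB_PER_NIGHT = 20          # Rubin Observatory figure quoted above
NIGHTS_PER_YEAR = 365
SURVEY_YEARS = 10

total_tb = TB_PER_NIGHT * NIGHTS_PER_YEAR * SURVEY_YEARS
print(f"Rubin 10-year raw total: ~{total_tb / 1000:.0f} PB")   # ~73 PB upper bound
```

No human team inspects tens of petabytes by hand; that is the scale argument for the ML pipelines in the rest of this cluster.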

The bottleneck is now analysis — and the solution is machine learning at every stage: automated classification, anomaly flagging, parameter estimation, and even law discovery. This isn't a future aspiration; it's the present operational reality of modern observatories.

  • 17B: galaxies catalogued by Rubin — impossible to classify by eye
  • 20 TB: per night from the Vera Rubin Observatory — requires automated ML pipelines
  • 5,000+: confirmed exoplanets, most found using ML-assisted transit detection

Section 2 — Galaxy Morphology Classification with CNNs

The Galaxy Zoo project launched in 2007 with a bold experiment: crowdsource galaxy classification to volunteers. Over 150,000 people classified over 900,000 galaxies — elliptical, spiral, merging, and dozens of subcategories. It worked beautifully, and produced the largest labelled galaxy dataset ever assembled. It also produced a problem: humans can't scale to 17 billion galaxies.

Convolutional neural networks solved this scaling problem. Trained on the Galaxy Zoo labels, CNNs can classify galaxy morphology at a fraction of a second per image, with accuracy matching or exceeding the volunteer consensus. More importantly, CNNs discover morphological features that humans never explicitly defined — subtle asymmetries, colour gradients, and structural details that correlate with galaxy age, star formation rate, and merger history.

Galaxy Morphology CNN from Scratch

Python — Galaxy morphology CNN: ResNet18 transfer learning + astronomy augmentations
import torch, torch.nn as nn
from torchvision import transforms, models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# ── Transfer learning from ResNet18 (pre-trained on ImageNet) ─
# Galaxy images share low-level features (edges, textures) with natural images
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Freeze early layers — keep learned edges/textures
for name, param in model.named_parameters():
    if 'layer1' in name or 'layer2' in name:
        param.requires_grad = False

# Replace final classifier for galaxy morphology (4 classes)
# Elliptical | Spiral | Merging | Irregular
n_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.4),
    nn.Linear(n_features, 256), nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, 4)  # 4 morphology classes
)

# ── Data augmentation (critical for galaxy images) ────────────
train_transform = transforms.Compose([
    transforms.RandomRotation(360),          # galaxies have no preferred orientation
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.3, contrast=0.3),
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# ── Training ─────────────────────────────────────────────────
dataset  = ImageFolder('galaxy_zoo/train', transform=train_transform)
loader   = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
optim    = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler= torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=50)
criterion= nn.CrossEntropyLoss(label_smoothing=0.1)  # prevents overconfidence

# Public dataset: Galaxy Zoo 2 on Kaggle / AstroML
# or use: astroML.datasets.fetch_sdss_galaxy_colors() for photometric data
🌐 Public Data: Galaxy Zoo datasets. The full Galaxy Zoo challenge dataset is on Kaggle with 61,578 galaxy images and continuous morphology scores. The Sloan Digital Sky Survey (SDSS) provides free access to hundreds of millions of galaxy images and spectra via its SkyServer API.

Section 3 — Photometric Redshift Estimation

Every galaxy's light is redshifted by the expansion of the universe — and the redshift tells us the galaxy's distance. The gold standard is spectroscopic redshift: disperse the galaxy's light into a spectrum and measure the shift of known spectral lines. This is extremely accurate but slow — a large telescope can measure maybe 1,000 spectra per night.

With billions of galaxies to survey, spectroscopy can only cover a tiny fraction. The solution is photometric redshift: estimate the redshift from multi-band photometry (brightness measured through several colour filters). It's less accurate, but 1,000× faster. Neural networks have become the dominant method for this task.

Photometric redshift via template fitting: $z_{\rm phot} = \arg\min_z \chi^2(z)$, with $\chi^2(z) = \sum_i \left[F_i^{\rm obs} - a\,F_i^{\rm temp}(z)\right]^2 / \sigma_i^2$, where the sum runs over the photometric bands and $a$ is a free amplitude.

Classical template-fitting methods like that equation above have been mostly superseded by neural networks that learn the mapping from photometric colours to redshift directly from spectroscopic training sets.
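For contrast, a toy version of that classical template fit takes only a few lines: slide a template through redshift, fit the free amplitude analytically, and take the chi-squared minimum. Everything here is invented for illustration; the single-bump "SED" and the 5% noise level stand in for a real template library and real photometric errors.

```python
import numpy as np

# Effective wavelengths of the SDSS ugriz filters, in nm
bands_nm = np.array([355.0, 468.0, 616.0, 748.0, 893.0])

def template_flux(z):
    """Hypothetical galaxy SED (a single broad bump) sampled at the
    rest-frame wavelengths each filter probes at redshift z."""
    lam_rest = bands_nm / (1 + z)
    return np.exp(-((lam_rest - 400.0) / 150.0) ** 2)

# Mock observation: a galaxy at z_true = 0.3 with 5% photometric noise
rng = np.random.default_rng(0)
z_true = 0.3
f_obs = template_flux(z_true) * (1 + 0.05 * rng.standard_normal(5))
sigma = 0.05 * np.abs(f_obs) + 1e-6

# Grid search: at each trial z, fit the amplitude a analytically,
# then record chi^2; the photo-z is the argmin
z_grid = np.linspace(0.0, 1.0, 1001)
chi2 = np.empty_like(z_grid)
for j, z in enumerate(z_grid):
    f_t = template_flux(z)
    a = np.sum(f_obs * f_t / sigma**2) / np.sum(f_t**2 / sigma**2)
    chi2[j] = np.sum(((f_obs - a * f_t) / sigma) ** 2)

z_phot = z_grid[np.argmin(chi2)]
print(f"photo-z = {z_phot:.3f}  (true z = {z_true})")
```

The neural-network approach below replaces the hand-built template with a mapping learned directly from spectroscopic training data, and returns a full probability distribution rather than a single argmin.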

Python — Mixture Density Network for photometric redshifts: full p(z) not just a point estimate
import numpy as np
import torch, torch.nn as nn
from astroML.datasets import fetch_sdss_specgals

# ── Load SDSS photometric data ─────────────────────────────────
data = fetch_sdss_specgals()              # ~660k galaxies with ugriz photometry
# Features: model magnitudes in u, g, r, i, z bands + colour indices
u, g, r, i, z = (data['modelMag_%s' % b] for b in 'ugriz')
X = np.column_stack([
    u, g, r, i, z,
    u - g,   # u-g colour
    g - r,   # g-r colour
    r - i,   # r-i colour
    i - z,   # i-z colour — most redshift-sensitive
]).astype(np.float32)
y = data['z'].astype(np.float32)          # spectroscopic z — ground truth

# ── Mixture Density Network: predicts full p(z|photometry) ────
# Instead of a point estimate, outputs a mixture of Gaussians
# This is physically motivated: photo-z has multimodal uncertainties
class MDN(nn.Module):
    def __init__(self, n_input=9, n_hidden=256, n_components=5):
        super().__init__()
        self.K = n_components
        self.shared = nn.Sequential(
            nn.Linear(n_input, n_hidden), nn.ReLU(), nn.BatchNorm1d(n_hidden),
            nn.Linear(n_hidden, n_hidden), nn.ReLU(), nn.BatchNorm1d(n_hidden),
            nn.Linear(n_hidden, n_hidden), nn.ReLU()
        )
        self.pi    = nn.Linear(n_hidden, n_components)   # mixture weights
        self.mu    = nn.Linear(n_hidden, n_components)   # Gaussian means
        self.sigma = nn.Linear(n_hidden, n_components)   # Gaussian widths

    def forward(self, x):
        h = self.shared(x)
        pi    = torch.softmax(self.pi(h), dim=1)
        mu    = self.mu(h)
        sigma = torch.exp(self.sigma(h)).clamp(min=1e-4)  # must be positive
        return pi, mu, sigma

# MDN loss: negative log-likelihood of a Gaussian mixture
def mdn_loss(pi, mu, sigma, y):
    y  = y.unsqueeze(1).expand_as(mu)
    log_probs = -(0.5 * ((y-mu)/sigma)**2 + sigma.log() + 0.5*np.log(2*np.pi))
    log_pi    = torch.log(pi + 1e-10)
    return -torch.logsumexp(log_probs + log_pi, dim=1).mean()

# Training: same as any neural network — minimise mdn_loss
# Evaluation metric: NMAD = 1.48 * median(|Δz/(1+z_spec)|)
# State-of-art neural nets achieve NMAD < 0.01 on SDSS
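A minimal training loop for an MDN like the one above might look like the sketch below. The data is a synthetic stand-in (random "colours" with a linear redshift trend) so the example runs end to end; on real SDSS features you would swap in the X, y arrays built earlier.

```python
import numpy as np
import torch, torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

# Synthetic stand-in for (X, y): 9 colour features, redshift depends on the first
N = 2000
X = torch.randn(N, 9)
y = (0.3 + 0.1 * X[:, 0] + 0.02 * torch.randn(N)).clamp(min=0.0)

class MDN(nn.Module):                       # same structure as above, trimmed
    def __init__(self, n_input=9, n_hidden=64, K=5):
        super().__init__()
        self.shared = nn.Sequential(nn.Linear(n_input, n_hidden), nn.ReLU(),
                                    nn.Linear(n_hidden, n_hidden), nn.ReLU())
        self.pi = nn.Linear(n_hidden, K)
        self.mu = nn.Linear(n_hidden, K)
        self.sigma = nn.Linear(n_hidden, K)
    def forward(self, x):
        h = self.shared(x)
        return (torch.softmax(self.pi(h), dim=1), self.mu(h),
                torch.exp(self.sigma(h)).clamp(min=1e-4))

def mdn_loss(pi, mu, sigma, y):
    y = y.unsqueeze(1).expand_as(mu)
    logp = -(0.5 * ((y - mu) / sigma) ** 2 + sigma.log() + 0.5 * np.log(2 * np.pi))
    return -torch.logsumexp(logp + torch.log(pi + 1e-10), dim=1).mean()

model = MDN()
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
loader = DataLoader(TensorDataset(X, y), batch_size=256, shuffle=True)

for epoch in range(30):
    for xb, yb in loader:
        opt.zero_grad()
        mdn_loss(*model(xb), yb).backward()
        opt.step()

# Point estimate = mixture mean; evaluate with the NMAD metric from above
with torch.no_grad():
    pi, mu, sigma = model(X)
    z_pred = (pi * mu).sum(dim=1)
nmad = 1.48 * torch.median(torch.abs(z_pred - y) / (1 + y))
print(f"NMAD = {nmad:.4f}")
```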

Section 4 — Gravitational Wave Detection with Machine Learning

On September 14th 2015, LIGO detected the first gravitational wave signal from a binary black hole merger 1.3 billion light years away. The signal spent 0.2 seconds in the detector's sensitive band. It was buried in noise 10,000 times larger than the signal itself. Finding it required matched filtering against a bank of 250,000 theoretical waveform templates — a technique that works brilliantly for known signal shapes but becomes computationally prohibitive as template banks grow.
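The idea behind matched filtering is easy to sketch: correlate the data stream against a template and look for a peak in signal-to-noise. Here is a toy version with a synthetic chirp injected into white noise (not real strain data, and the chirp parameters are invented):

```python
import numpy as np

fs = 4096                                   # sample rate [Hz]
t = np.arange(0, 1.0, 1 / fs)

# Toy chirp template: frequency sweeping upward, loosely like an inspiral
template = np.sin(2 * np.pi * (50 * t + 80 * t ** 2))

# Synthetic "detector output": white noise with the chirp injected
# at 0.3x the noise amplitude, invisible by eye
rng = np.random.default_rng(1)
data = rng.standard_normal(3 * fs)
t0 = 5000                                   # injection start (sample index)
data[t0:t0 + len(template)] += 0.3 * template

# Matched filter for white noise: correlate, normalise by template energy
corr = np.correlate(data, template, mode='valid')
snr = corr / np.sqrt(np.sum(template ** 2))
peak = int(np.argmax(np.abs(snr)))
print(f"peak |SNR| = {abs(snr[peak]):.1f} at sample {peak} (injected at {t0})")
```

The catch is that a real search must do this against every template in the bank, for every detector, continuously, which is exactly the cost that neural networks amortise.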

Machine learning is now transforming every stage of gravitational wave astronomy: real-time detection, parameter estimation, noise transient classification, and the search for continuous waves from spinning neutron stars. The key advantage: a trained neural network can classify an event in milliseconds, versus minutes for matched filtering — critical for enabling rapid electromagnetic follow-up before the afterglow fades.

CNN Classifier for Gravitational Wave Signals

Python — 1D CNN for GW signal detection: whitened strain from LIGO/GWOSC real data
# Gravitational wave signal classification
# Input: 1-second strain time series from LIGO/Virgo detector
# Task: binary classification — GW signal present (1) or noise only (0)

from gwpy.timeseries import TimeSeries   # pip install gwpy
import torch, torch.nn as nn
import numpy as np

# ── Load real LIGO data via GWOSC (Gravitational Wave Open Science Center) ─
# GW150914: the first ever detection
data = TimeSeries.fetch_open_data(
    'H1',                    # LIGO Hanford
    1126259462, 1126259478,   # GPS time around GW150914
    sample_rate=4096,
    cache=True
)
# Whitening: flatten the noise power spectrum so signal stands out
whitened = data.whiten(fftlength=4, method='median')

# ── 1D CNN on time-domain strain data ─────────────────────────
class GWDetector(nn.Module):
    def __init__(self, seq_len=4096):
        super().__init__()
        self.conv = nn.Sequential(
            # Multi-scale convolutions capture different GW chirp rates
            nn.Conv1d(1,  64,  kernel_size=32,  stride=2), nn.BatchNorm1d(64),  nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=16,  stride=2), nn.BatchNorm1d(128), nn.ReLU(),
            nn.Conv1d(128,256, kernel_size=8,   stride=2), nn.BatchNorm1d(256), nn.ReLU(),
            nn.Conv1d(256,512, kernel_size=4,   stride=2), nn.BatchNorm1d(512), nn.ReLU(),
            nn.AdaptiveAvgPool1d(1)
        )
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, 128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 1),   nn.Sigmoid()
        )
    def forward(self, x): return self.head(self.conv(x))

# Training data: injected signals into real LIGO noise segments
# Public dataset: MLGWSC-1 (Machine Learning GW Search Comparison) — arxiv:2209.11146
gw_model = GWDetector()
print(f"Parameters: {sum(p.numel() for p in gw_model.parameters()):,}")
🌐 Open Data: LIGO Open Science Center. All published LIGO/Virgo gravitational wave events — including GW150914, GW170817, and 90+ others — are freely available at gwosc.org. The gwpy library makes loading and processing these time series trivial in Python.

Section 5 — Exoplanet Transit Detection

When a planet passes in front of its host star, it blocks a tiny fraction of the star's light — typically 0.01% to 1%. This periodic dimming is a transit, and detecting it in noisy photometric light curves is an exquisitely sensitive signal-detection problem. The Kepler telescope alone produced light curves for over 150,000 stars, making manual vetting impossible.
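That 0.01%-to-1% range follows directly from geometry: the transit depth is the ratio of the planet's and star's projected areas.

```python
R_sun_km   = 695_700
R_earth_km = 6_371
R_jup_km   = 69_911

for name, r_km in [("Earth", R_earth_km), ("Jupiter", R_jup_km)]:
    depth = (r_km / R_sun_km) ** 2          # fraction of starlight blocked
    print(f"{name} transiting the Sun: depth = {depth:.3%}")
# Earth transiting the Sun: depth = 0.008%
# Jupiter transiting the Sun: depth = 1.010%
```

An Earth analogue dims its star by less than one part in ten thousand, which is why detrending and signal detection have to be so careful.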

NASA's Kepler mission used machine learning to identify planet candidates from the raw light curves. A notable 2018 paper by Shallue & Vanderburg trained a CNN on Kepler data and discovered two new exoplanets — including one in an 8-planet system — that had been missed by conventional pipelines. The approach is now standard practice for TESS exoplanet vetting.

Python — Dual-view CNN for exoplanet transit detection (global + local light curve views)
import lightkurve as lk       # pip install lightkurve
import numpy as np
import torch, torch.nn as nn

# ── Download a Kepler light curve ─────────────────────────────
search = lk.search_lightcurve('Kepler-22', mission='Kepler')
lcs    = search.download_all()
lc     = lcs.stitch().remove_nans().flatten(window_length=401)
print(f"Light curve: {len(lc)} cadences, baseline {lc.time.value.max() - lc.time.value.min():.0f} days")

# ── Phase-fold on known period to build transit template ───────
period_days = 289.86           # Kepler-22b orbital period
# epoch_time must be given in the light curve's own time system (BKJD for
# Kepler, i.e. BJD - 2454833); if omitted, the fold is about the first cadence
lc_folded   = lc.fold(period=period_days)

# ── CNN for transit classification ────────────────────────────
# Two inputs per candidate, following Shallue & Vanderburg:
# a global view (full phase-folded curve, binned to 2001 points)
# and a local view (zoom on the transit window, 201 points)
class TransitClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        # Separate CNNs for global and local view
        self.global_cnn = self._make_cnn(seq_len=2001)
        self.local_cnn  = self._make_cnn(seq_len=201)
        self.classifier = nn.Sequential(
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(128, 1),   nn.Sigmoid()
        )

    def _make_cnn(self, seq_len):
        return nn.Sequential(
            nn.Conv1d(1, 16, 5), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(16,32, 5), nn.ReLU(), nn.MaxPool1d(2),
            nn.AdaptiveAvgPool1d(4), nn.Flatten(),
            nn.Linear(128, 128), nn.ReLU()
        )

    def forward(self, x_global, x_local):
        g = self.global_cnn(x_global.unsqueeze(1))
        l = self.local_cnn(x_local.unsqueeze(1))
        return self.classifier(torch.cat([g, l], dim=1))

# Based on: Shallue & Vanderburg (2018) — Identifying Exoplanets with DL, AJ 155:94

Section 6 — Symbolic Regression: Discovering Cosmological Laws from Data

This is perhaps the most philosophically exciting application of ML to astrophysics. Rather than fitting a neural network to data, symbolic regression searches the space of mathematical expressions to find the analytic formula that best describes your data. It literally tries to discover physical laws.

The PySR library (Cranmer 2023) has been used to recover Kepler's third law, the NFW dark matter profile, the Hubble-Lemaître law, and even previously unknown relationships in cosmological simulations — equations that were later validated theoretically.

NFW dark matter density profile: $\rho(r) = \dfrac{\rho_0}{(r/r_s)\,(1 + r/r_s)^2}$
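The profile is a one-liner to evaluate; the characteristic density ρ₀ and scale radius r_s are the two free parameters any fit, or a symbolic regression run, would have to recover (the values below are illustrative):

```python
import numpy as np

def nfw_density(r, rho_0=1.0, r_s=20.0):
    """rho(r) = rho_0 / [(r/r_s) * (1 + r/r_s)^2]; r and r_s in the same units."""
    x = np.asarray(r, dtype=float) / r_s
    return rho_0 / (x * (1 + x) ** 2)

# Limiting behaviour: rho ~ r^-1 well inside r_s, rho ~ r^-3 far outside
r = np.array([0.2, 2.0, 200.0, 2000.0])    # e.g. kpc, with r_s = 20 kpc
print(nfw_density(r))
```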
Python — PySR symbolic regression: recovering Kepler's third law T² ∝ a³ from noisy data
# pip install pysr
from pysr import PySRRegressor
import numpy as np

# ── Simulate: recover Kepler's third law from orbital data ─────
# T² ∝ a³  where T = period, a = semi-major axis
np.random.seed(42)
a = np.random.uniform(0.3, 40, 500)  # semi-major axis [AU]
T = a**1.5 + np.random.normal(0, 0.1, 500)  # T [years] + noise

# ── Run symbolic regression ──────────────────────────────────
model = PySRRegressor(
    niterations=50,
    binary_operators=['+', '*', '/', '-', '^'],   # note: PySR uses '^' for power
    unary_operators=['sqrt', 'log', 'exp'],
    maxsize=10,             # limit expression complexity
    populations=30,         # number of independent populations
    parsimony=0.001,        # penalise complexity — prefer simpler laws
    verbosity=1
)
model.fit(a.reshape(-1,1), T)

# ── Results ─────────────────────────────────────────────────
print(model.get_best())   # should return: x0^1.5 or equivalent
print(model.equations_)

# In real astrophysics, run PySR on simulation outputs
# where the true law is UNKNOWN — that's where it gets exciting
# Example: Cranmer et al. 2020 combined graph networks with symbolic
# regression on cosmological N-body sims and recovered a new analytic
# formula for dark matter halo overdensity

Section 7 — Simulation-Based Inference for Cosmological Parameters

Many of the most important questions in cosmology — the matter power spectrum, the Hubble constant, the dark energy equation of state — involve comparing observed data to expensive simulations. The problem is that running a full cosmological simulation can take thousands of CPU-hours, making traditional grid-based parameter estimation completely impractical.

Simulation-Based Inference (SBI), also called likelihood-free inference, sidesteps this by training a neural density estimator on simulation outputs. The network learns to map from data summaries to posterior distributions over parameters — without ever evaluating the likelihood analytically.

SBI training loss: $\mathcal{L}(\phi) = -\frac{1}{N}\sum_{n=1}^{N} \log q_\phi(\theta_n \mid x_n)$, the negative log-likelihood of the neural density estimator $q_\phi$ averaged over simulated $(\theta_n, x_n)$ pairs.
Python — Simulation-Based Inference: neural posterior estimation for cosmological parameters
# pip install sbi
from sbi import utils as sbi_utils
from sbi import inference as sbi_inference
import torch

# ── Define cosmological simulator (simplified) ────────────────
# In practice this would call CLASS, CAMB, or an N-body code
def cosmology_simulator(theta):
    # theta: [Omega_m, sigma_8] — matter density and clustering amplitude
    Omega_m, sigma_8 = theta[:, 0], theta[:, 1]
    # Simplified: return 5-bin power spectrum summary statistic
    k = torch.linspace(0.01, 1.0, 5)
    P_k = sigma_8.unsqueeze(1) * Omega_m.unsqueeze(1) * k.pow(-2.0)
    return P_k + 0.05 * torch.randn_like(P_k)  # add noise

# ── Prior over cosmological parameters ───────────────────────
prior = sbi_utils.BoxUniform(
    low=torch.tensor([0.1, 0.5]),   # [Omega_m_min, sigma_8_min]
    high=torch.tensor([0.6, 1.5])   # [Omega_m_max, sigma_8_max]
)

# ── Run SBI: sample prior, simulate, train neural posterior ──
inference = sbi_inference.SNPE(prior=prior)     # Sequential Neural Posterior Estimation

theta_sim = prior.sample((10000,))       # 10k prior samples
x_sim     = cosmology_simulator(theta_sim)     # run simulator for each

inference = inference.append_simulations(theta_sim, x_sim)
density_estimator = inference.train()           # train neural posterior

# ── Given observed data, draw posterior samples ───────────────
x_observed = cosmology_simulator(torch.tensor([[0.3, 0.8]]))  # mock observation
posterior  = inference.build_posterior(density_estimator)
samples    = posterior.sample((5000,), x=x_observed)
print(f"Omega_m posterior: {samples[:,0].mean():.3f} ± {samples[:,0].std():.3f}")

External References & Further Reading

  • Baron (2019). Machine Learning in Astronomy: A Practical Overview. arXiv:1904.07248 — Comprehensive survey of ML across all subfields of astronomy.
  • Dieleman et al. (2015). Rotation-invariant CNNs for galaxy morphology prediction. MNRAS. arXiv:1503.07077 — The paper that brought CNNs to galaxy classification.
  • Shallue & Vanderburg (2018). Identifying Exoplanets with Deep Learning. AJ 155:94. arXiv:1712.05044 — CNN discovers new exoplanets missed by conventional pipelines.
  • George & Huerta (2018). Deep Learning for Real-Time GW Detection. Physics Letters B. arXiv:1711.03121 — Real-time GW classification with CNNs at LIGO sensitivity.
  • Cranmer et al. (2020). The frontier of simulation-based inference. PNAS. doi.org/10.1073/pnas.1912789117 — Definitive review of likelihood-free inference methods.
  • Cranmer (2023). PySR: Fast & Parallelized Symbolic Regression in Python/Julia. arXiv:2305.01582 — The PySR library paper. Used by ESA/NASA teams.
  • GWOSC. gwosc.org — Gravitational Wave Open Science Center: all published GW events freely available.
📋 Key Takeaways — Cluster 4
  • The data bottleneck has inverted. Modern telescopes produce more data than humans can analyse. ML pipelines are no longer optional — they are load-bearing infrastructure.
  • Transfer learning works for astronomy. ResNet pre-trained on ImageNet transfers effectively to galaxy images — the low-level features (edges, textures) are universal. Always try fine-tuning before training from scratch.
  • Point estimates are not enough. Photometric redshifts, parameter estimates, and transit classifications all require uncertainty quantification. Use MDNs, Bayesian NNs, or SBI to get full posteriors.
  • GW detection in milliseconds. CNNs on whitened strain data achieve matched-filter sensitivity at 1,000× the speed — essential for multi-messenger alerts.
  • Symbolic regression is law discovery. PySR recovers analytic expressions from data — Kepler's laws, NFW profiles, Hubble's law — and has found relationships in simulations that were later theoretically explained.
  • SBI handles intractable likelihoods. When your simulator is expensive and your likelihood is intractable, neural posterior estimation gives you full Bayesian inference without ever computing the likelihood.
