Bayesian Sérsic profile fitting for galaxy photometry using JAX/NumPyro. Use when fitting galaxy surface brightness profiles, extracting structural parameters, or performing multi-component galaxy decomposition.
Fit Sérsic (and related) profiles to galaxy images using JAX-accelerated Bayesian inference via NumPyro.
```bash
pip install pysersic
# If arviz 1.0+ breaks the install or import, pin an older release:
pip install 'arviz<1.0'
# If photutils is too old for NumPy 2, upgrade it:
pip install --upgrade photutils
```
JAX backend (install one):
```bash
# CPU only
pip install 'jax[cpu]'
# NVIDIA GPU (CUDA 12)
pip install --upgrade 'jax[cuda12]'
```
Dependencies: JAX, NumPyro, astropy, photutils, scipy, matplotlib, arviz (<1.0), asdf, corner, equinox, interpax, tqdm.
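Before reinstalling anything, you can check which versions are present using the standard library's `importlib.metadata`. The sketch below is minimal: the helper names `major_version` and `needs_arviz_pin` are hypothetical, and the code compares major version numbers only.

```python
from importlib.metadata import PackageNotFoundError, version


def major_version(dist_name: str):
    """Return the installed major version of a distribution, or None if it is absent."""
    try:
        return int(version(dist_name).split(".")[0])
    except PackageNotFoundError:
        return None


def needs_arviz_pin(arviz_version: str) -> bool:
    """True if an arviz version string is 1.0 or newer (the range pinned away from above)."""
    return int(arviz_version.split(".")[0]) >= 1


for pkg in ("pysersic", "arviz", "numpy", "photutils", "jax"):
    print(pkg, major_version(pkg))
```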
| `profile_type` | Description |
|---|---|
| `"sersic"` | General Sérsic (free n) |
| `"exp"` | Exponential (n = 1, disk) |
| `"dev"` | de Vaucouleurs (n = 4, bulge) |
| `"doublesersic"` | Two Sérsic components |
| `"pointsource"` | Unresolved source |
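The single-component profile types share one formula. Assume the standard Sérsic law, I(r) = I_e exp(-b_n [(r / r_eff)^(1/n) - 1]), where I_e is the surface brightness at r_eff and b_n is chosen so that r_eff encloses half the total light. Under that law, `"exp"` and `"dev"` are the cases n = 1 and n = 4. The sketch below evaluates the profile using the Ciotti & Bertin (1999) asymptotic approximation for b_n. This is an illustration only, and pysersic may compute b_n differently internally.

```python
import math


def sersic_b(n: float) -> float:
    """Asymptotic approximation to b_n; accurate for moderate to large n."""
    return 2.0 * n - 1.0 / 3.0 + 4.0 / (405.0 * n) + 46.0 / (25515.0 * n**2)


def sersic_intensity(r: float, i_e: float, r_eff: float, n: float) -> float:
    """Sérsic surface brightness I(r) = I_e * exp(-b_n * ((r / r_eff)**(1/n) - 1))."""
    return i_e * math.exp(-sersic_b(n) * ((r / r_eff) ** (1.0 / n) - 1.0))


# "exp" and "dev" are the same law with n fixed to 1 and 4.
for label, n in [("exp", 1.0), ("dev", 4.0), ("sersic n=2.5", 2.5)]:
    profile = [sersic_intensity(r, 1.0, 10.0, n) for r in (1.0, 10.0, 30.0)]
    print(label, round(sersic_b(n), 3), [round(v, 3) for v in profile])
```

Every profile passes through I_e at r = r_eff. Higher n gives a brighter core and more extended wings.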
```python
import numpy as np
from astropy.io import fits
from pysersic import check_input_data

# Image cutout, RMS (1-sigma noise) map, PSF, and mask (True = bad pixel)
image = fits.getdata("image.fits").astype(float)
rms = np.sqrt(fits.getdata("variance.fits").astype(float))  # RMS = sqrt(variance)
psf = fits.getdata("psf.fits").astype(float)
psf /= psf.sum()  # normalize the PSF to unit sum
mask = fits.getdata("mask.fits").astype(bool)  # True = masked

# Sanity-check the inputs before fitting
check_input_data(data=image, rms=rms, psf=psf, mask=mask)
```
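Real variance maps and cutouts often contain NaNs, zeros, or negative values, which break `np.sqrt` or the likelihood. The hypothetical helper `prepare_inputs` below folds such pixels into the mask, keeping the `True` = bad convention, before the arrays go to `check_input_data`. The placeholder values it writes into masked pixels are an assumption. They are harmless only because masked pixels are excluded from the fit.

```python
import numpy as np


def prepare_inputs(image, variance, psf, bad=None):
    """Hypothetical helper: return (image, rms, psf, mask) with mask True = bad pixel.

    Pixels with non-finite data, or with non-finite or non-positive variance, are
    added to the mask and given placeholder values so no NaN or inf reaches the fitter.
    """
    image = np.asarray(image, dtype=float)
    variance = np.asarray(variance, dtype=float)
    mask = np.zeros(image.shape, dtype=bool) if bad is None else np.asarray(bad, dtype=bool).copy()
    mask |= ~np.isfinite(image) | ~np.isfinite(variance)
    mask |= np.nan_to_num(variance, nan=0.0, posinf=0.0) <= 0  # zero or negative variance
    rms = np.sqrt(np.where(mask, 1.0, variance))  # placeholder rms = 1 in masked pixels
    image = np.where(mask, 0.0, image)            # placeholder data in masked pixels
    psf = np.asarray(psf, dtype=float)
    psf = psf / psf.sum()                         # unit-sum PSF
    return image, rms, psf, mask


img = np.array([[1.0, np.nan], [2.0, 3.0]])
var = np.array([[1.0, 1.0], [-1.0, 4.0]])
im, rms_map, psf_n, m = prepare_inputs(img, var, np.ones((3, 3)))
print(m)  # the NaN pixel and the negative-variance pixel are masked
print(rms_map)
```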
Auto-prior (recommended for quick start):
```python
from pysersic.priors import autoprior

prior = autoprior(
    image=image,
    profile_type="sersic",
    mask=mask,
    sky_type="none",  # "none", "flat", or "tilted-plane"
)
```
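`autoprior` derives its starting guesses from the image itself. The sketch below is not pysersic's implementation. It is a hypothetical illustration of the simplest such estimates: total flux summed over unmasked pixels, and a flux-weighted centroid. It assumes x is the column index and y the row index, which is the NumPy/astropy array convention.

```python
import numpy as np


def moment_guesses(image, mask=None):
    """Hypothetical helper: (flux, xc, yc) guesses from unmasked pixels."""
    image = np.asarray(image, dtype=float)
    good = np.ones(image.shape, dtype=bool) if mask is None else ~np.asarray(mask, dtype=bool)
    flux = image[good].sum()                                   # total-flux guess
    weights = np.where(good, np.clip(image, 0.0, None), 0.0)  # ignore negative noise and masked pixels
    yy, xx = np.indices(image.shape)                           # yy = row (y), xx = column (x)
    xc = (weights * xx).sum() / weights.sum()
    yc = (weights * yy).sum() / weights.sum()
    return float(flux), float(xc), float(yc)


toy = np.zeros((11, 11))
toy[3, 7] = toy[3, 8] = 5.0      # two bright pixels in row 3, columns 7 and 8
bad = np.zeros_like(toy, dtype=bool)
toy[0, 0] = 1e6                  # a hot pixel ...
bad[0, 0] = True                 # ... that the mask removes
print(moment_guesses(toy, bad))
```

Without the mask, the hot pixel would dominate both the flux and the centroid. This is one reason masking matters even for the prior.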
Manual prior via `SourceProperties`:
```python
from pysersic.priors import SourceProperties

props = SourceProperties(image, mask=mask)
# Override the automatic guesses if needed:
# props.set_flux_guess(flux, flux_err)
# props.set_r_eff_guess(r_eff, r_eff_err)
prior = props.generate_prior("sersic", sky_type="none")
```
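If you override the size guess with `props.set_r_eff_guess`, one rough source for `r_eff` is a circular curve of growth. Sort the pixels by distance from the centre and find the radius at which the cumulative flux first reaches half the total. `half_light_radius` below is a hypothetical helper. It assumes a background-subtracted image and ignores ellipticity, so treat its value as a loose guess with a generous `r_eff_err`.

```python
import numpy as np


def half_light_radius(image, xc, yc, mask=None):
    """Hypothetical helper: circular radius (pixels) enclosing half the unmasked flux."""
    image = np.asarray(image, dtype=float)
    good = np.ones(image.shape, dtype=bool) if mask is None else ~np.asarray(mask, dtype=bool)
    yy, xx = np.indices(image.shape)
    r = np.hypot(xx - xc, yy - yc)[good]
    order = np.argsort(r)                        # pixels from the centre outward
    growth = np.cumsum(image[good][order])       # curve of growth
    idx = np.argmax(growth >= 0.5 * growth[-1])  # first pixel reaching half the light
    return float(r[order][idx])


toy = np.zeros((21, 21))
toy[10, 10] = 1.0   # centre, r = 0
toy[10, 13] = 1.0   # r = 3
toy[10, 16] = 2.0   # r = 6; total = 4, so half the light is reached at r = 3
print(half_light_radius(toy, xc=10, yc=10))
```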
```python
from jax.random import PRNGKey
from pysersic import FitSingle
from pysersic.loss import gaussian_loss, student_t_loss

fitter = FitSingle(
    data=image,
    rms=rms,
    psf=psf,
    prior=prior,
    mask=mask,
    loss_func=student_t_loss,  # robust to outliers; gaussian_loss suits clean data
)

# MAP estimate (fast, seconds)
map_params = fitter.find_MAP(rkey=PRNGKey(42))

# Approximate posterior via SVI with a normalizing-flow guide (minutes)
svi_res = fitter.estimate_posterior(rkey=PRNGKey(43), method="svi-flow")

# Full NUTS sampling (slowest, most accurate)
fitter.sample(rkey=PRNGKey(44))
sampling_res = fitter.sampling_results
```
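The `loss_func` choice matters because each likelihood penalizes a large normalized residual z = (data - model) / rms differently. The sketch below compares the per-pixel negative log-likelihoods. It assumes ν = 5 degrees of freedom for the Student-t purely for illustration, since the value pysersic's `student_t_loss` uses is not stated here. It also drops the log(rms) term, which is constant for a fixed RMS map.

```python
import math


def gaussian_nll(z: float) -> float:
    """Gaussian negative log-likelihood of a normalized residual z (log rms dropped)."""
    return 0.5 * z * z + 0.5 * math.log(2.0 * math.pi)


def student_t_nll(z: float, nu: float = 5.0) -> float:
    """Student-t negative log-likelihood of z with nu degrees of freedom (log rms dropped)."""
    log_norm = math.lgamma(0.5 * (nu + 1)) - math.lgamma(0.5 * nu) - 0.5 * math.log(nu * math.pi)
    return -log_norm + 0.5 * (nu + 1) * math.log1p(z * z / nu)


# Extra cost of a residual z relative to a perfect fit (z = 0)
for z in (1.0, 3.0, 20.0):
    g = gaussian_nll(z) - gaussian_nll(0.0)
    t = student_t_nll(z) - student_t_nll(0.0)
    print(f"z={z:>4}: gaussian +{g:.2f}, student-t +{t:.2f}")
```

A 20σ outlier adds 200 to the Gaussian loss but only about 13 to the Student-t loss with ν = 5. A single cosmic ray or artifact therefore cannot dominate the Student-t fit. As ν grows, the Student-t loss approaches the Gaussian one.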
```python
from pysersic.results import plot_residual

# Residual plot: data | model | residual
fig, ax = plot_residual(image, map_params["model"], mask=mask, vmin=-1, vmax=1)

# Parameter summary
sampling_res.summary()
# Corner plot
fig = sampling_res.corner()
# Quantiles as a DataFrame
df = sampling_res.retrieve_param_quantiles(return_dataframe=True)
# LaTeX table
sampling_res.latex_table()
# Save to ASDF
sampling_res.save_result("fit_result.asdf")
```
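`plot_residual` gives a visual check. A numeric companion is the normalized residual map together with the reduced χ² over unmasked pixels. The helper below is a hypothetical sketch. Its default `n_params=7` assumes a single Sérsic component with the seven structural parameters `xc`, `yc`, `flux`, `r_eff`, `n`, `ellip`, `theta`, and no sky terms.

```python
import numpy as np


def residual_stats(image, model, rms, mask=None, n_params=7):
    """Return the normalized residual map and reduced chi^2 over unmasked pixels.

    n_params=7 assumes one Sérsic component and no sky parameters; adjust as needed.
    """
    image, model, rms = (np.asarray(a, dtype=float) for a in (image, model, rms))
    good = np.ones(image.shape, dtype=bool) if mask is None else ~np.asarray(mask, dtype=bool)
    chi = np.zeros_like(image)
    chi[good] = (image[good] - model[good]) / rms[good]  # masked pixels stay 0
    dof = good.sum() - n_params
    return chi, float((chi[good] ** 2).sum() / dof)


# Synthetic check: a perfect model plus pure noise should give reduced chi^2 near 1.
rng = np.random.default_rng(0)
sim_model = np.ones((100, 100))
sim_rms = np.full((100, 100), 0.1)
sim_image = sim_model + sim_rms * rng.standard_normal((100, 100))
chi, red_chi2 = residual_stats(sim_image, sim_model, sim_rms)
print(round(red_chi2, 3))
```

A reduced χ² near 1 with structure-free residuals suggests that both the RMS map and the model are adequate. Values well above 1 point to an underestimated RMS or a missing component.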
```python
from pysersic import FitMulti
from pysersic.priors import PySersicMultiPrior, estimate_sky
from pysersic.results import parse_multi_results

# Source catalog (from sep, photutils, or manual)
catalog = {
    "flux": [1000, 500, 50],
    "x": [50.0, 30.0, 70.0],
    "y": [50.0, 60.0, 40.0],
    "r": [5.0, 3.0, 1.0],  # effective-radius guess (pixels)
    "type": ["sersic", "sersic", "pointsource"],
}

# Sky estimate from unmasked pixels
med_sky, std_sky, n_pix = estimate_sky(image, mask)

prior = PySersicMultiPrior(
    catalog=catalog,
    sky_type="flat",
    sky_guess=med_sky,
    sky_guess_err=2 * std_sky / np.sqrt(n_pix),
)

fm = FitMulti(data=image, rms=rms, psf=psf, prior=prior, mask=mask)
map_dict = fm.find_MAP(rkey=PRNGKey(99))

# Laplace posterior approximation for all sources at once
fm.estimate_posterior(method="laplace", rkey=PRNGKey(100))
# Extract a single source (by catalog index) from the multi-source results
source_2 = parse_multi_results(fm.svi_results, 2)
source_2.summary()
```
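`estimate_sky` returns a median, a standard deviation, and a pixel count. The prior width used above, `2 * std_sky / np.sqrt(n_pix)`, is twice the standard error of the mean sky level, which makes it a deliberately loose prior. The sketch below is a hypothetical stand-in, not pysersic's implementation. It produces the same three quantities with iterative sigma clipping. `astropy.stats.sigma_clipped_stats` is a maintained alternative.

```python
import numpy as np


def clipped_sky(image, mask=None, n_sigma=3.0, max_iter=10):
    """Hypothetical sky estimate via iterative sigma clipping -> (median, std, n_pix)."""
    vals = np.asarray(image, dtype=float)
    if mask is not None:
        vals = vals[~np.asarray(mask, dtype=bool)]  # keep unmasked pixels only
    vals = vals[np.isfinite(vals)].ravel()
    for _ in range(max_iter):
        med, std = np.median(vals), np.std(vals)
        keep = np.abs(vals - med) <= n_sigma * std   # drop sources and outliers
        if keep.all():
            break
        vals = vals[keep]
    return float(np.median(vals)), float(np.std(vals)), int(vals.size)


rng = np.random.default_rng(1)
sim = 5.0 + rng.standard_normal((200, 200))
sim[95:105, 95:105] += 500.0                          # a bright "source"
med_sky, std_sky, n_pix = clipped_sky(sim)
sky_guess_err = 2 * std_sky / np.sqrt(n_pix)          # same prior width as in the fit above
print(med_sky, std_sky, n_pix, sky_guess_err)
```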
| Parameter | Description | Typical range |
|---|---|---|
| `xc`, `yc` | Centroid (pixels) | Image center ± a few pixels |
| `flux` / `Ftot` | Total flux | > 0 |
| `r_eff` | Half-light radius (pixels) | 1–100 |
| `n` | Sérsic index | 0.5–8 |
| `ellip` | Ellipticity (1 - b/a) | 0–0.9 |
| `theta` | Position angle (radians) | 0–π |
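Together, the centroid, ellipticity, and position angle define elliptical isophotes. The sketch below assumes `theta` is measured counterclockwise from the +x axis; this is an assumption, so check pysersic's documentation for its actual convention. It also takes `ellip` = 1 - b/a from the table. Under those assumptions it computes a pixel's elliptical (semi-major-axis) radius, which is the r that enters the Sérsic profile. It also computes the circularized radius r_eff * sqrt(b/a) that catalogs often quote.

```python
import numpy as np


def elliptical_radius(x, y, xc, yc, ellip, theta):
    """Semi-major-axis distance of (x, y) from (xc, yc) for axis ratio q = 1 - ellip.

    Assumes theta is measured counterclockwise from the +x axis (an assumption).
    """
    q = 1.0 - ellip                                   # axis ratio b/a
    dx = np.asarray(x, dtype=float) - xc
    dy = np.asarray(y, dtype=float) - yc
    x_maj = dx * np.cos(theta) + dy * np.sin(theta)   # offset along the major axis
    y_min = -dx * np.sin(theta) + dy * np.cos(theta)  # offset along the minor axis
    return np.hypot(x_maj, y_min / q)


def circularized_radius(r_eff, ellip):
    """Circularized half-light radius r_eff * sqrt(b/a)."""
    return r_eff * np.sqrt(1.0 - ellip)


# With ellip = 0.5, a point 5 px out along the major axis and one 2.5 px out
# along the minor axis lie on the same isophote.
print(elliptical_radius(15, 10, 10, 10, 0.5, 0.0))   # on the major axis
print(elliptical_radius(10, 12.5, 10, 10, 0.5, 0.0))  # on the minor axis
print(circularized_radius(8.0, 0.5))
```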
- Use `student_t_loss` for robustness to bad pixels and artifacts; use `gaussian_loss` for clean data.
- Set `sky_type="none"` if the background is already subtracted; use `"flat"` to fit a constant residual sky level.
- Mask convention: `True` = bad/masked pixel (excluded from the fit).
- Runtime: `svi-flow` takes minutes; NUTS sampling takes 5–30 min per source.