sys_mapping.inference
MCMC posterior sampling and point estimation via the emcee ensemble sampler.
run_mcmc() wraps emcee.EnsembleSampler with sensible defaults (250 walkers, 1500 steps, 300 burn-in) and handles the rotated-template basis internally. get_mle_params() returns the posterior median as the point estimate; this coincides with the MAP estimate only for symmetric, unimodal posteriors.
Key paper: Berlfein et al. 2024, Sec. 4; see also Methods Reference.
MCMC inference using emcee for systematic contamination parameters.
Uses the JIT-compiled log-likelihood from likelihood.py. Walkers are initialized near zero contamination with sigma near std(delta_g_obs), and the JAX computation is called from within a NumPy-compatible wrapper that emcee can evaluate.
GPU / vectorized mode
Pass vectorize=True to run_mcmc() (or make_log_prob()) to
enable batched walker evaluation. emcee then calls the log-prob once per step
with the full (n_walkers, n_dim) position matrix, and a single
jax.vmap-over-jax.jit kernel evaluates all walkers simultaneously.
This turns n_walkers sequential JAX dispatches into one batched GPU kernel
per step — the primary source of GPU speedup. Requires emcee ≥ 3.0 (which
supports EnsembleSampler(vectorize=True)).
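The scalar versus batched contract can be illustrated without JAX or emcee. The sketch below is a minimal NumPy stand-in, assuming a hypothetical additive-only Gaussian likelihood with theta = [a_0, ..., a_{n_sys-1}, sigma]; it is not the package's likelihood, and NumPy broadcasting plays the role that jax.vmap plays in the real code. It checks that one batched call over an (n_walkers, n_dim) matrix matches a Python loop of scalar calls.

```python
import numpy as np


def log_like_scalar(theta, delta_g_obs, delta_t):
    """One walker: theta has shape (n_dim,); returns a float.

    Hypothetical additive-only Gaussian model (not the package's likelihood):
    delta_g_obs = delta_true + a @ delta_t, with delta_true ~ N(0, sigma^2).
    """
    a, sigma = theta[:-1], theta[-1]
    resid = delta_g_obs - a @ delta_t                      # (n_pix,)
    return float(-0.5 * np.sum(resid**2) / sigma**2 - delta_g_obs.size * np.log(sigma))


def log_like_batched(thetas, delta_g_obs, delta_t):
    """All walkers at once: thetas has shape (n_walkers, n_dim); returns (n_walkers,)."""
    a, sigma = thetas[:, :-1], thetas[:, -1]               # (n_walkers, n_sys), (n_walkers,)
    resid = delta_g_obs[None, :] - a @ delta_t             # (n_walkers, n_pix)
    return -0.5 * np.sum(resid**2, axis=1) / sigma**2 - delta_g_obs.size * np.log(sigma)


rng = np.random.default_rng(0)
n_sys, n_pix, n_walkers = 3, 500, 8
delta_t = rng.standard_normal((n_sys, n_pix))
delta_g_obs = 0.1 * rng.standard_normal(n_pix)
thetas = np.column_stack([0.01 * rng.standard_normal((n_walkers, n_sys)),
                          np.full(n_walkers, 0.1)])        # sigma = 0.1 for every walker

batched = log_like_batched(thetas, delta_g_obs, delta_t)  # one call per step
looped = np.array([log_like_scalar(t, delta_g_obs, delta_t) for t in thetas])
```

With emcee ≥ 3.0, the batched form is the shape of callable that EnsembleSampler(..., vectorize=True) expects; without that flag, emcee calls the scalar form once per walker.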
- sys_mapping.inference.make_log_prob(n_sys, model, delta_g_obs, delta_t, use_skewed=False, sigma_min=1e-06, vectorize=False)
Build an emcee-compatible log_prob function with a sigma > 0 prior.
Wraps the JIT-compiled log-likelihood from make_log_likelihood() in a Python callable that emcee can call. Returns -inf for sigma ≤ sigma_min.
- Parameters:
n_sys (int)
model (str) – 'additive', 'multiplicative', or 'combined'
delta_g_obs ((n_pix,) array) – observed overdensity
delta_t ((n_sys, n_pix) array) – systematic template maps
use_skewed (bool)
sigma_min (float) – lower bound on sigma (hard prior); default 1e-6
vectorize (bool) – If True, return a batched log_prob that accepts (n_walkers, n_dim) and returns (n_walkers,). Uses jax.vmap so all walkers are evaluated in a single GPU kernel. Pass vectorize=True to emcee.EnsembleSampler as well.
- Returns:
log_prob (callable) – callable(theta) -> float in scalar mode; callable(thetas: (n_walkers, n_dim)) -> (n_walkers,) in vectorized mode
n_dim (int) – dimensionality of the parameter space
Precision
---------
All floating-point arithmetic is delegated to the JAX log-likelihood (see make_log_likelihood() for precision details).
Examples
>>> import numpy as np
>>> from sys_mapping import make_log_prob
>>> rng = np.random.default_rng(0)
>>> delta_g = rng.standard_normal(2000) * 0.1
>>> delta_t = rng.standard_normal((3, 2000))
>>> log_prob, n_dim = make_log_prob(3, "combined", delta_g, delta_t)
>>> n_dim  # 3 a + 3 b + 1 sigma = 7
7
>>> theta0 = np.zeros(n_dim); theta0[-1] = 0.1  # no contamination, sigma = 0.1
>>> bool(np.isfinite(log_prob(theta0)))  # finite log-probability
True
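The structure described above (count the parameters, then guard the likelihood with the sigma > sigma_min hard prior) can be sketched in plain NumPy. Both helper names are hypothetical, the parameter counts are read off the examples (the 'multiplicative' count is an assumption), sigma is assumed to be the last entry of theta, and the quadratic log_like is a placeholder for the JIT-compiled likelihood.

```python
import numpy as np


def n_dim_for(model, n_sys):
    """Coefficients plus one sigma, as in the examples ('combined': 2*n_sys + 1,
    'additive': n_sys + 1). The 'multiplicative' count is an assumption."""
    n_coef = {"additive": n_sys, "multiplicative": n_sys, "combined": 2 * n_sys}[model]
    return n_coef + 1


def make_guarded_log_prob(log_like, n_dim, sigma_min=1e-6):
    """Wrap any log-likelihood with the hard prior sigma > sigma_min.

    Assumes sigma is the last entry of theta (as in the make_log_prob example).
    """
    def log_prob(theta):
        theta = np.asarray(theta, dtype=float)
        if theta.shape != (n_dim,):
            raise ValueError(f"expected shape ({n_dim},), got {theta.shape}")
        if theta[-1] <= sigma_min:
            return -np.inf    # zero prior mass, so emcee rejects the proposal
        return float(log_like(theta))
    return log_prob


# Placeholder quadratic likelihood, standing in for the JIT-compiled one.
n_dim = n_dim_for("combined", 3)
log_prob = make_guarded_log_prob(lambda th: -0.5 * np.sum(th**2), n_dim)
theta0 = np.zeros(n_dim)
theta0[-1] = 0.1
lp_ok = log_prob(theta0)
lp_bad = log_prob(np.zeros(n_dim))     # sigma = 0 violates the prior
```

Returning -inf rather than raising keeps the sampler running: a proposal outside the prior simply has zero acceptance probability.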
- sys_mapping.inference.run_mcmc(n_sys, *, model='combined', delta_g_obs, delta_t, use_skewed=False, n_walkers=250, n_steps=1500, n_burn=300, seed=42, progress=True, vectorize=False)
Run emcee MCMC to infer contamination parameters.
Walkers are initialized near zero contamination with sigma near std(delta_g_obs). The first n_burn steps are discarded as burn-in. The combined model (joint \(a_i, b_i\) inference) is the default and the recommended choice; it is strictly superior to the additive-only model across all contamination regimes (see Model test matrix).
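Initialization and burn-in removal can be sketched as follows. The helper names and the 1e-3 scatter are hypothetical, and sigma is assumed to be the last parameter. The chain layout (n_steps, n_walkers, n_dim) is the one returned by emcee 3's EnsembleSampler.get_chain(); its get_chain(discard=n_burn, flat=True) performs the same flattening.

```python
import numpy as np


def init_walkers(n_walkers, n_dim, delta_g_obs, seed=42, scatter=1e-3):
    """Hypothetical initialization: coefficients near zero, sigma near std(delta_g_obs).

    Assumes sigma is the last parameter; the scatter scale is an assumption.
    """
    rng = np.random.default_rng(seed)
    p0 = scatter * rng.standard_normal((n_walkers, n_dim))          # contamination ~ 0
    sigma0 = np.std(delta_g_obs)
    p0[:, -1] = sigma0 * (1.0 + scatter * rng.standard_normal(n_walkers))
    return p0


def flatten_after_burn_in(chain, n_burn):
    """chain has emcee's get_chain() layout (n_steps, n_walkers, n_dim)."""
    kept = chain[n_burn:]                        # drop the first n_burn steps
    return kept.reshape(-1, chain.shape[-1])     # (n_walkers * (n_steps - n_burn), n_dim)


rng = np.random.default_rng(0)
delta_g_obs = 0.1 * rng.standard_normal(500)
p0 = init_walkers(n_walkers=50, n_dim=3, delta_g_obs=delta_g_obs)
mock_chain = rng.standard_normal((200, 50, 3))   # stands in for a 200-step run
flat_chain = flatten_after_burn_in(mock_chain, n_burn=50)
```

With 50 walkers, 200 steps and 50 burn-in steps this gives 7500 rows, matching the shape in the run_mcmc example.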
- Parameters:
n_sys (int)
model (str) – 'combined' (default), 'additive', or 'multiplicative'
delta_g_obs ((n_pix,) array) – observed overdensity (keyword-only)
delta_t ((n_sys, n_pix) array) – systematic template maps (keyword-only)
use_skewed (bool)
n_walkers (int) – number of emcee ensemble walkers (≥ 2 × n_dim)
n_steps (int) – total MCMC steps (burn-in included)
n_burn (int) – steps to discard; flat_chain has n_walkers * (n_steps - n_burn) rows
seed (int)
progress (bool) – show tqdm progress bar
vectorize (bool) – Evaluate all walkers in one batched JAX/GPU call per step (see make_log_prob()). Requires emcee ≥ 3.0. Recommended on GPU.
- Returns:
flat_chain ((n_walkers × (n_steps − n_burn), n_dim) array) – posterior samples
sampler (emcee.EnsembleSampler)
Performance
-----------
Wall time scales as O(n_walkers × n_steps × n_pix × n_sys).
Typical runs take minutes for n_pix ~ 50_000, n_walkers=250, n_steps=1500.
The JAX likelihood is JIT-compiled on the first walker evaluation (~227 ms),
then executes at ~284 μs/eval (Gaussian) or ~592 μs/eval (skewed).
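As a sanity check on these figures, scalar-mode wall time is roughly one likelihood call per walker per step plus the one-off compile. This back-of-envelope sketch assumes the quoted per-evaluation timings hold for the run in question (the n_pix they were measured at is not stated) and ignores emcee's own bookkeeping overhead; the helper name is hypothetical.

```python
def estimate_wall_time(n_walkers, n_steps, sec_per_eval, compile_sec=0.227):
    """Back-of-envelope scalar-mode estimate: one likelihood call per walker per step,
    plus the one-off JIT compile. Ignores emcee bookkeeping overhead."""
    return compile_sec + n_walkers * n_steps * sec_per_eval


gaussian_sec = estimate_wall_time(250, 1500, 284e-6)   # about 106.7 s (~1.8 min)
skewed_sec = estimate_wall_time(250, 1500, 592e-6)     # about 222.2 s (~3.7 min)
```

Both estimates land in the "minutes" range quoted above; vectorize=True replaces the n_walkers sequential calls per step with a single batched call.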
Precision
---------
Posterior median of the chain is taken as the point estimate via
get_mle_params(). Posterior variance is estimated via
Examples
>>> import numpy as np
>>> from sys_mapping import run_mcmc
>>> rng = np.random.default_rng(0)
>>> delta_g = rng.standard_normal(500) * 0.1
>>> delta_t = rng.standard_normal((2, 500))
>>> chain, sampler = run_mcmc(
...     n_sys=2, model="additive",
...     delta_g_obs=delta_g, delta_t=delta_t,
...     n_walkers=50, n_steps=200, n_burn=50,
...     progress=False,
... )
>>> chain.shape  # (50 walkers * 150 steps, 3 params: a0, a1, sigma)
(7500, 3)
- sys_mapping.inference.get_mle_params(flat_chain)
Return the posterior median as a robust point estimate.
- Parameters:
flat_chain ((n_samples, n_dim) array) – posterior samples from run_mcmc()
- Returns:
(n_dim,) parameter vector (median of the flat chain)
Performance
-----------
O(n_samples × n_dim) numpy median; negligible compared to MCMC runtime.
Precision
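---------
The median is a robust estimator. The chain median converges to the posterior median at rate O(1/√n_eff), where n_eff ≤ n_samples is the effective sample size (MCMC samples are autocorrelated); for symmetric posteriors the posterior median coincides with the posterior mean.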
Examples
>>> import numpy as np
>>> from sys_mapping import get_mle_params
>>> rng = np.random.default_rng(0)
>>> flat_chain = rng.standard_normal((5000, 4))  # mock chain
>>> theta_hat = get_mle_params(flat_chain)
>>> theta_hat.shape
(4,)
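The robustness of the median can be seen with a mock chain: a small fraction of outlying samples (for example from a stuck walker; the 1% fraction here is hypothetical) drags the mean far off but barely moves the median. np.median(flat_chain, axis=0) is used as a stand-in for get_mle_params(), matching its documented behavior (median of the flat chain).

```python
import numpy as np

rng = np.random.default_rng(0)
flat_chain = rng.standard_normal((5000, 2))      # mock symmetric posterior, centered at 0

contaminated = flat_chain.copy()
contaminated[:50] = 100.0                        # hypothetical: 1% of samples far out

# np.median(..., axis=0) stands in for get_mle_params (median of the flat chain).
mean_shift = np.abs(contaminated.mean(axis=0) - flat_chain.mean(axis=0))
median_shift = np.abs(np.median(contaminated, axis=0) - np.median(flat_chain, axis=0))
```

The mean moves by about 1 (50 samples × 100 / 5000), while the median moves by only a few hundredths, because 50 reordered samples shift the middle rank by at most 50 positions.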
- sys_mapping.inference.get_param_variance_from_chain(flat_chain, n_sys, model)
Estimate variance of contamination parameters from MCMC chain.
Returns only the diagonal (per-parameter variance). For the full covariance matrix needed to propagate uncertainty through a basis rotation, use get_param_covariance_from_chain().
- Parameters:
flat_chain ((n_samples, n_dim) array) – posterior samples from run_mcmc()
n_sys (int)
model (str) – 'additive', 'multiplicative', or 'combined'
- Returns:
var_a ((n_sys,) array) – posterior variance of additive parameters
var_b ((n_sys,) array) – posterior variance of multiplicative parameters (all zeros for the 'additive' model)
Examples
>>> import numpy as np
>>> from sys_mapping import get_param_variance_from_chain
>>> rng = np.random.default_rng(0)
>>> n_sys, n_samples = 3, 2000
>>> flat_chain = rng.standard_normal((n_samples, 2 * n_sys + 1)) * 0.01
>>> var_a, var_b = get_param_variance_from_chain(flat_chain, n_sys, "combined")
>>> var_a.shape
(3,)
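A minimal sketch of how the diagonal variances can be sliced out of the flat chain. The column layout is an assumption read off the examples ([a..., b..., sigma] for 'combined', [a..., sigma] for 'additive'); the 'multiplicative' layout [b..., sigma], the zero var_a it implies, and the helper name are all hypothetical.

```python
import numpy as np


def variance_from_chain_sketch(flat_chain, n_sys, model):
    """Per-parameter posterior variances sliced from a (n_samples, n_dim) chain.

    Assumed column layouts: [a..., b..., sigma] ('combined'), [a..., sigma]
    ('additive'), [b..., sigma] ('multiplicative', hypothetical).
    """
    zeros = np.zeros(n_sys)
    first = flat_chain[:, :n_sys].var(axis=0)
    if model == "combined":
        return first, flat_chain[:, n_sys:2 * n_sys].var(axis=0)
    if model == "additive":
        return first, zeros                 # no b parameters: var_b is all zeros
    if model == "multiplicative":
        return zeros, first                 # no a parameters (assumed): var_a is all zeros
    raise ValueError(f"unknown model: {model!r}")


rng = np.random.default_rng(0)
n_sys, n_samples = 3, 2000
scales = np.array([0.01] * n_sys + [0.02] * n_sys + [0.001])   # a, b, sigma
flat_chain = rng.standard_normal((n_samples, 2 * n_sys + 1)) * scales
var_a, var_b = variance_from_chain_sketch(flat_chain, n_sys, "combined")
var_a_add, var_b_add = variance_from_chain_sketch(flat_chain[:, [0, 1, 2, 6]], n_sys, "additive")
```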
- sys_mapping.inference.get_param_covariance_from_chain(flat_chain, n_sys, model)
Estimate the full covariance matrix of contamination parameters.
Slices the flat chain directly rather than unpacking sample by sample. The returned (n_sys, n_sys) matrices are needed for correct variance propagation when back-transforming parameters from the PCA-rotated basis to the original template basis via R.T @ cov_rot @ R.
- Parameters:
flat_chain ((n_samples, n_dim) array) – posterior samples from run_mcmc()
n_sys (int)
model (str) – 'additive', 'multiplicative', or 'combined'
- Returns:
cov_a ((n_sys, n_sys) array) – posterior covariance of additive parameters
cov_b ((n_sys, n_sys) array) – posterior covariance of multiplicative parameters
Examples
>>> import numpy as np
>>> from sys_mapping import get_param_covariance_from_chain
>>> rng = np.random.default_rng(0)
>>> n_sys, n_samples = 3, 2000
>>> flat_chain = rng.standard_normal((n_samples, 2 * n_sys + 1)) * 0.01
>>> cov_a, cov_b = get_param_covariance_from_chain(flat_chain, n_sys, "combined")
>>> cov_a.shape
(3, 3)
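The back-transform R.T @ cov_rot @ R can be checked numerically. This sketch assumes the convention a_rot = R @ a with R orthogonal (rows as PCA directions), under which a = R.T @ a_rot and, for row-stacked samples, a_orig = a_rot @ R. R and the correlated mock samples are hypothetical, and np.cov (ddof=1) may differ in normalization from the package's implementation.

```python
import numpy as np

rng = np.random.default_rng(1)
n_sys, n_samples = 3, 20_000

# Hypothetical orthogonal rotation (rows are orthonormal directions).
R = np.linalg.qr(rng.standard_normal((n_sys, n_sys)))[0]

# Correlated mock posterior samples of the rotated additive coefficients.
L = np.array([[0.010, 0.000, 0.000],
              [0.015, 0.010, 0.000],
              [0.000, -0.020, 0.010]])
a_rot = rng.standard_normal((n_samples, n_sys)) @ L.T     # rows ~ N(0, L @ L.T)

cov_rot = np.cov(a_rot, rowvar=False)    # (n_sys, n_sys), rotated basis
cov_orig = R.T @ cov_rot @ R             # propagated to the template basis

# Transforming every sample first and then taking the covariance gives the
# same matrix, including the contributions from the off-diagonal entries of
# cov_rot that a diagonal-only propagation would miss.
a_orig = a_rot @ R
cov_direct = np.cov(a_orig, rowvar=False)
```

Because the sample covariance is linear in this sense, the identity holds to floating-point precision, not just approximately.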