# Source code for sys_mapping.inference

"""MCMC inference using emcee for systematic contamination parameters.

Uses the JIT-compiled log-likelihood from likelihood.py.
emcee walkers are initialized near reasonable starting values and
the JAX computation is called from within a numpy-compatible wrapper.

GPU / vectorised mode
---------------------
Pass ``vectorize=True`` to :func:`run_mcmc` (or :func:`make_log_prob`) to
enable batched walker evaluation.  emcee then calls the log-prob with a whole
``(n_batch, n_dim)`` matrix of walker positions instead of one walker at a
time.  With the default stretch move this is two calls per step, each covering
roughly half of the ensemble.  A single ``jax.jit``-compiled ``jax.vmap``
kernel evaluates each batch at once.  This replaces ``n_walkers`` sequential
JAX dispatches per step with a couple of batched GPU kernels — the primary
source of GPU speedup.  Requires emcee ≥ 3.0 (which supports
``EnsembleSampler(vectorize=True)``).
"""

from __future__ import annotations

import numpy as np
import emcee
import jax
import jax.numpy as jnp
from jax import Array

from .likelihood import make_log_likelihood
from .contamination import n_free_params, pack_params


def make_log_prob(
    n_sys: int,
    model: str,
    delta_g_obs: np.ndarray,
    delta_t: np.ndarray,
    use_skewed: bool = False,
    sigma_min: float = 1e-6,
    vectorize: bool = False,
) -> tuple[callable, int]:
    """Build an emcee-compatible log_prob function with a hard lower bound on sigma.

    Wraps the JIT-compiled log-likelihood from
    :func:`~likelihood.make_log_likelihood` in a Python callable that emcee
    can call.

    The returned function gives ``-inf`` in two cases:

    * ``sigma <= sigma_min`` or sigma is NaN (hard prior);
    * the likelihood evaluates to a non-finite value.

    Parameter layout of ``theta``:

    * the first ``n_cont = n_free_params(n_sys, model)`` entries are the
      contamination parameters;
    * entry ``n_cont`` is sigma;
    * when ``use_skewed`` is true, one further entry follows, used by the
      skewed likelihood.

    Parameters
    ----------
    n_sys : int
        number of systematic templates
    model : str
        ``'additive'``, ``'multiplicative'``, or ``'combined'``
    delta_g_obs : (n_pix,)
        observed overdensity
    delta_t : (n_sys, n_pix)
        template values
    use_skewed : bool
        use the skewed likelihood (adds one extra parameter)
    sigma_min : float
        lower bound on sigma (hard prior); default ``1e-6``
    vectorize : bool
        If True, return a *batched* log_prob that accepts an
        ``(n_batch, n_dim)`` array of walker positions and returns
        ``(n_batch,)`` log-probabilities.  Each batch is evaluated by a single
        ``jax.jit``-compiled ``jax.vmap`` kernel.  Pass ``vectorize=True`` to
        ``emcee.EnsembleSampler`` as well.  With emcee's default stretch move
        each batch is roughly half of the ensemble, two calls per step.

    Returns
    -------
    log_prob : callable
        ``callable(theta) -> float`` (scalar mode), or
        ``callable(thetas: (n_batch, n_dim)) -> (n_batch,)`` (vectorized mode)
    n_dim : int
        dimensionality of the parameter space: ``n_cont + 1``, plus one more
        if ``use_skewed``

    Precision
    ---------
    All floating-point arithmetic is delegated to the JAX log-likelihood (see
    :func:`~likelihood.make_log_likelihood` for precision details).  Inputs
    are handed to JAX without an explicit dtype request, so the working
    precision follows JAX's x64 setting in both scalar and vectorized mode.

    Examples
    --------
    >>> import numpy as np
    >>> from sys_mapping import make_log_prob
    >>> rng = np.random.default_rng(0)
    >>> delta_g = rng.standard_normal(2000) * 0.1
    >>> delta_t = rng.standard_normal((3, 2000))
    >>> log_prob, n_dim = make_log_prob(3, "combined", delta_g, delta_t)
    >>> n_dim  # 3a + 3b + 1sigma = 7
    7
    >>> theta0 = np.zeros(n_dim); theta0[-1] = 0.1  # sigma = 0.1
    >>> bool(np.isfinite(log_prob(theta0)))
    True
    >>> theta0[-1] = 0.0  # violates the sigma > sigma_min prior
    >>> log_prob(theta0)
    -inf
    """
    log_likelihood = make_log_likelihood(n_sys, model, use_skewed)
    _delta_g = jnp.asarray(delta_g_obs)
    _delta_t = jnp.asarray(delta_t)

    n_cont = n_free_params(n_sys, model)
    n_dim = n_cont + 1 + (1 if use_skewed else 0)

    if vectorize:
        # Single jit-compiled vmap kernel: evaluates a whole batch of walkers
        # in one dispatch.
        _batched_ll = jax.jit(
            jax.vmap(lambda theta: log_likelihood(theta, _delta_g, _delta_t))
        )

        def log_prob_vec(thetas: np.ndarray) -> np.ndarray:
            # Cast on the numpy side rather than requesting float64 from JAX.
            # An explicit jnp float64 request only warns and truncates when
            # x64 is disabled.  This matches the scalar path's dtype handling.
            batch = np.asarray(thetas, dtype=np.float64)
            lps = np.array(_batched_ll(jnp.asarray(batch)), dtype=np.float64)
            # Hard sigma prior.  Written as ``~(sigma > sigma_min)`` so a NaN
            # sigma is rejected by the prior itself.
            lps[~(batch[:, n_cont] > sigma_min)] = -np.inf
            lps[~np.isfinite(lps)] = -np.inf
            return lps

        return log_prob_vec, n_dim

    def log_prob(theta: np.ndarray) -> float:
        sigma = theta[n_cont]
        # Hard sigma prior; ``not >`` also rejects NaN.
        if not sigma > sigma_min:
            return -np.inf
        lp = float(log_likelihood(jnp.asarray(theta), _delta_g, _delta_t))
        return lp if np.isfinite(lp) else -np.inf

    return log_prob, n_dim
def run_mcmc(
    n_sys: int,
    *,
    model: str = "combined",
    delta_g_obs: np.ndarray,
    delta_t: np.ndarray,
    use_skewed: bool = False,
    n_walkers: int = 250,
    n_steps: int = 1500,
    n_burn: int = 300,
    seed: int = 42,
    progress: bool = True,
    vectorize: bool = False,
) -> tuple[np.ndarray, emcee.EnsembleSampler]:
    """Run emcee MCMC to infer contamination parameters.

    Walker initialization:

    * contamination parameters are drawn from ``N(0, 0.05)``;
    * sigma starts at ``std(delta_g_obs) * (1 + N(0, 0.1))``;
    * when ``use_skewed`` is true, the extra parameter is drawn from
      ``N(0, 0.1)``.

    The first ``n_burn`` steps are discarded as burn-in.

    The **combined** model (joint :math:`a_i, b_i` inference) is the default
    and the recommended choice; it is strictly superior to the additive-only
    model across all contamination regimes (see
    :doc:`/results_systematic_tests`).

    Parameters
    ----------
    n_sys : int
        number of systematic templates
    model : str
        ``'combined'`` (default), ``'additive'``, or ``'multiplicative'``
    delta_g_obs : (n_pix,)
        observed overdensity (keyword-only)
    delta_t : (n_sys, n_pix)
        template values (keyword-only)
    use_skewed : bool
        use the skewed likelihood (adds one extra parameter)
    n_walkers : int
        number of emcee ensemble walkers (≥ 2 × n_dim)
    n_steps : int
        total MCMC steps (burn-in included)
    n_burn : int
        steps to discard; ``flat_chain`` has
        ``n_walkers*(n_steps-n_burn)`` rows
    seed : int
        Seeds both the initial walker positions and emcee's internal random
        state, so a given seed reproduces the whole run.
    progress : bool
        show tqdm progress bar
    vectorize : bool
        Evaluate walkers in batched JAX/GPU calls (see
        :func:`make_log_prob`).  Requires emcee ≥ 3.0.  Recommended on GPU.

    Returns
    -------
    flat_chain : (n_walkers × (n_steps − n_burn), n_dim)
        posterior samples
    sampler : emcee.EnsembleSampler
        the sampler after running

    Raises
    ------
    ValueError
        if ``n_burn`` is not in ``[0, n_steps)``, or if ``std(delta_g_obs)``
        is not a finite positive number.  Either condition would otherwise
        give an empty chain, or walkers stuck at ``-inf`` log-probability.

    Performance
    -----------
    Wall time scales as O(n_walkers × n_steps × n_pix × n_sys).  Typical
    runs take **minutes** for n_pix ~ 50_000, n_walkers=250, n_steps=1500.
    The JAX likelihood is JIT-compiled on the first walker evaluation
    (~227 ms), then executes at ~284 μs/eval (Gaussian) or ~592 μs/eval
    (skewed).

    Precision
    ---------
    The posterior median of the chain is taken as the point estimate via
    :func:`get_mle_params`.  Posterior variance is estimated via
    :func:`get_param_variance_from_chain`.

    Examples
    --------
    >>> import numpy as np
    >>> from sys_mapping import run_mcmc
    >>> rng = np.random.default_rng(0)
    >>> delta_g = rng.standard_normal(500) * 0.1
    >>> delta_t = rng.standard_normal((2, 500))
    >>> chain, sampler = run_mcmc(
    ...     n_sys=2, model="additive",
    ...     delta_g_obs=delta_g, delta_t=delta_t,
    ...     n_walkers=50, n_steps=200, n_burn=50,
    ...     progress=False,
    ... )
    >>> chain.shape  # (50 walkers * 150 steps, 3 params: a0, a1, sigma)
    (7500, 3)
    """
    if not 0 <= n_burn < n_steps:
        raise ValueError(
            "n_burn must satisfy 0 <= n_burn < n_steps; "
            f"got n_burn={n_burn}, n_steps={n_steps}"
        )
    sigma0 = float(np.std(delta_g_obs))
    if not (np.isfinite(sigma0) and sigma0 > 0.0):
        # Every walker would start at sigma <= 0 or NaN.  All would then have
        # log_prob = -inf and the ensemble could never move.
        raise ValueError(
            "std(delta_g_obs) must be finite and positive to initialise sigma; "
            f"got {sigma0!r}"
        )

    log_prob, n_dim = make_log_prob(
        n_sys, model, delta_g_obs, delta_t, use_skewed, vectorize=vectorize
    )
    rng = np.random.default_rng(seed)
    n_cont = n_free_params(n_sys, model)

    # Initial walker positions: spread over a range that covers typical
    # contamination amplitudes (0.02–0.10).  A scale of 0.05 ensures walkers
    # explore the full posterior from the start, rather than having to
    # diffuse away from an overly tight ball near zero.
    p0 = np.zeros((n_walkers, n_dim))
    p0[:, :n_cont] = rng.normal(0.0, 0.05, (n_walkers, n_cont))
    p0[:, n_cont] = sigma0 * (1.0 + rng.normal(0.0, 0.1, n_walkers))
    if use_skewed:
        p0[:, n_cont + 1] = rng.normal(0.0, 0.1, n_walkers)

    sampler = emcee.EnsembleSampler(n_walkers, n_dim, log_prob, vectorize=vectorize)
    # emcee's moves draw from the sampler's own legacy RandomState.  That
    # state is seeded from OS entropy unless set explicitly.  Derive it from
    # ``seed`` (after drawing p0, so p0 is unchanged) to make the whole run
    # reproducible.
    rstate0 = np.random.RandomState(int(rng.integers(2**32))).get_state()
    sampler.run_mcmc(p0, n_steps, progress=progress, rstate0=rstate0)

    flat_chain = sampler.get_chain(discard=n_burn, flat=True)
    return flat_chain, sampler
def get_mle_params(flat_chain: np.ndarray) -> np.ndarray:
    """Summarise a flat posterior chain by its per-parameter median.

    Despite the ``mle`` in its name, this function does not maximise the
    likelihood.  It returns the column-wise median of the samples, a robust
    point estimate that is insensitive to a few outlying samples.

    Parameters
    ----------
    flat_chain : (n_samples, n_dim)
        posterior samples, e.g. the first return value of :func:`run_mcmc`

    Returns
    -------
    (n_dim,) ndarray
        median of each column of ``flat_chain``

    Performance
    -----------
    A single ``np.median`` reduction along axis 0, costing
    O(n_samples × n_dim).  This is negligible next to the MCMC run itself.

    Precision
    ---------
    The sample median estimates the posterior median.  Its Monte-Carlo error
    shrinks as the number of effectively independent samples grows.

    Examples
    --------
    >>> import numpy as np
    >>> from sys_mapping import get_mle_params
    >>> rng = np.random.default_rng(0)
    >>> flat_chain = rng.standard_normal((5000, 4))  # mock chain
    >>> theta_hat = get_mle_params(flat_chain)
    >>> theta_hat.shape
    (4,)
    """
    column_medians = np.median(flat_chain, axis=0)
    return column_medians
def get_param_variance_from_chain(
    flat_chain: np.ndarray,
    n_sys: int,
    model: str,
) -> tuple[np.ndarray, np.ndarray]:
    """Per-parameter posterior variance of the contamination parameters.

    This is the diagonal of the matrices returned by
    :func:`get_param_covariance_from_chain`, and uses the same column slicing
    and normalisation.  When uncertainty must be propagated through a basis
    rotation, the off-diagonal terms matter: call
    :func:`get_param_covariance_from_chain` directly instead.

    Parameters
    ----------
    flat_chain : (n_samples, n_dim)
        posterior samples from :func:`run_mcmc`
    n_sys : int
        number of systematic templates
    model : str
        ``'additive'``, ``'multiplicative'``, or ``'combined'``

    Returns
    -------
    var_a : (n_sys,)
        variance of the additive parameters
    var_b : (n_sys,)
        variance of the multiplicative parameters (all zeros for the
        ``'additive'`` model)

    Examples
    --------
    >>> import numpy as np
    >>> from sys_mapping import get_param_variance_from_chain
    >>> rng = np.random.default_rng(0)
    >>> n_sys, n_samples = 3, 2000
    >>> flat_chain = rng.standard_normal((n_samples, 2 * n_sys + 1)) * 0.01
    >>> var_a, var_b = get_param_variance_from_chain(flat_chain, n_sys, "combined")
    >>> var_a.shape
    (3,)
    """
    covariances = get_param_covariance_from_chain(flat_chain, n_sys, model)
    var_a, var_b = (np.diag(cov) for cov in covariances)
    return var_a, var_b
def get_param_covariance_from_chain(
    flat_chain: np.ndarray,
    n_sys: int,
    model: str,
) -> tuple[np.ndarray, np.ndarray]:
    """Estimate the full covariance matrix of contamination parameters.

    Slices the flat chain directly rather than unpacking sample by sample.
    The returned (n_sys, n_sys) matrices are needed for correct variance
    propagation when back-transforming parameters from the PCA-rotated basis
    to the original template basis via ``R.T @ cov_rot @ R``.

    All covariances use the unbiased (``ddof=1``) normalisation of
    ``np.cov``, including the ``n_sys == 1`` case, which previously fell
    back to the biased ``np.var``.

    Column layout assumed for each model:

    * ``'combined'``: columns ``[:n_sys]`` are a and ``[n_sys:2*n_sys]`` are b.
    * ``'additive'``: columns ``[:n_sys]`` are a; ``cov_b`` is all zeros.
    * ``'multiplicative'``: columns ``[:n_sys]`` are used for both ``cov_a``
      and ``cov_b``.

    Parameters
    ----------
    flat_chain : (n_samples, n_dim)
        posterior samples from :func:`run_mcmc`
    n_sys : int
        number of systematic templates
    model : str
        ``'additive'``, ``'multiplicative'``, or ``'combined'``

    Returns
    -------
    cov_a : (n_sys, n_sys)
        posterior covariance of additive parameters
    cov_b : (n_sys, n_sys)
        posterior covariance of multiplicative parameters

    Raises
    ------
    ValueError
        if ``model`` is not one of the three supported names

    Examples
    --------
    >>> import numpy as np
    >>> from sys_mapping import get_param_covariance_from_chain
    >>> rng = np.random.default_rng(0)
    >>> n_sys, n_samples = 3, 2000
    >>> flat_chain = rng.standard_normal((n_samples, 2 * n_sys + 1)) * 0.01
    >>> cov_a, cov_b = get_param_covariance_from_chain(flat_chain, n_sys, "combined")
    >>> cov_a.shape
    (3, 3)
    """
    flat_chain = np.asarray(flat_chain)
    first_block = flat_chain[:, :n_sys]

    if model == "combined":
        a_arr = first_block
        b_arr = flat_chain[:, n_sys:2 * n_sys]
    elif model == "multiplicative":
        # NOTE(review): the first n_sys columns hold the multiplicative
        # parameters here, and they are also reported as cov_a (existing
        # behaviour, kept for compatibility).  Confirm whether cov_a should
        # be zeros for this model, mirroring cov_b for 'additive'.
        a_arr = first_block
        b_arr = first_block
    elif model == "additive":
        a_arr = first_block
        b_arr = np.zeros_like(first_block)
    else:
        raise ValueError(
            f"unknown model {model!r}; expected 'additive', "
            "'multiplicative', or 'combined'"
        )

    # np.cov returns a 0-d array for a single column.  atleast_2d restores
    # the (1, 1) shape while keeping the same ddof=1 normalisation for every
    # n_sys.
    cov_a = np.atleast_2d(np.cov(a_arr, rowvar=False))
    cov_b = np.atleast_2d(np.cov(b_arr, rowvar=False))
    return cov_a, cov_b