sys_mapping.likelihood
JAX-compiled Gaussian and skew-normal log-likelihoods for the contamination model.
The factory function make_log_likelihood() returns
a JIT-compiled callable that evaluates \(\ln\mathcal{L}(\boldsymbol\theta)\)
for a given combination of (n_sys, model, use_skewed). The Jacobian
correction \(-\sum_p \ln|1 + \sum_i b_i t_{ip}|\) is included for the
combined model (Berlfein et al. 2024, Eq. 17).
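The Jacobian term follows from a change of variables. The short derivation below assumes the combined model writes the observed overdensity in pixel \(p\) as a multiplicative and additive distortion of the clean overdensity (the inverse of the clean-overdensity definition used in Eq. 17) and treats pixels as independent:

```latex
\begin{aligned}
\delta^{\mathrm{obs}}_{g,p} &= \Bigl(1 + \sum_i b_i t_{ip}\Bigr)\,\delta_{g,p} + \sum_i a_i t_{ip}
  \;\Longrightarrow\;
  \left|\frac{\partial \delta_{g,p}}{\partial \delta^{\mathrm{obs}}_{g,p}}\right|
  = \Bigl|1 + \sum_i b_i t_{ip}\Bigr|^{-1},\\
\ln\mathcal{L}(\boldsymbol\theta) &= \sum_p \ln p\bigl(\delta_{g,p}\bigr)
  - \sum_p \ln\Bigl|1 + \sum_i b_i t_{ip}\Bigr|.
\end{aligned}
```

With a Gaussian \(p(\delta_{g,p})\) of variance \(\sigma^2\), this reproduces the Gaussian log-likelihood of Eq. 17.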
Key paper: Berlfein et al. 2024 — see also Methods Reference.
JAX JIT-compiled log-likelihood factory (Eq. 17-18, Berlfein et al. 2024).
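- Gaussian log-likelihood (Eq. 17):
\[ \ln\mathcal{L}_{\mathrm{gauss}} = -\frac{N}{2}\ln(2\pi\sigma^2) - \sum_j \ln\Bigl|1 + \sum_i b_i t_{ij}\Bigr| - \frac{1}{2\sigma^2}\sum_j \delta_{g,j}^2, \]
where \(\delta_{g,j} = \bigl(\delta^{\mathrm{obs}}_{g,j} - \sum_i a_i t_{ij}\bigr) / \bigl(1 + \sum_i b_i t_{ij}\bigr)\) is the clean overdensity in pixel \(j\) and \(t_{ij}\) is the value of template \(i\) in that pixel.
- Skew-normal log-likelihood (Eq. 18):
\[ \ln\mathcal{L}_{\mathrm{skew}} = \ln\mathcal{L}_{\mathrm{gauss}}\big|_{\delta_{g,j}\to\delta_{g,j}-\xi} + N\ln 2 + \sum_j \ln\Phi\!\left(\frac{\gamma\,(\delta_{g,j}-\xi)}{\sigma}\right), \]
where \(\xi = -\sigma\,\delta_\gamma\sqrt{2/\pi}\), with \(\delta_\gamma = \gamma/\sqrt{1+\gamma^2}\), is the location shift that gives the skew-normal zero mean, and the standard-normal log-CDF \(\ln\Phi(z)\) is evaluated with log_ndtr(z), which stays numerically stable where \(\Phi(z)\) would underflow.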
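A minimal JAX sketch of Eqs. 17 and 18, for illustration only. It assumes the combined contamination model above, takes a, b, sigma and gamma as separate arguments instead of the packed theta vector, and applies the zero-mean shift ξ to the residuals in both terms of the skew-normal. The helpers clean_overdensity, gaussian_loglike and skew_loglike are hypothetical names, not part of sys_mapping.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import log_ndtr

jax.config.update("jax_enable_x64", True)  # float64 for tight comparisons


def clean_overdensity(a, b, delta_g_obs, delta_t):
    """Return the clean overdensity and the factor 1 + sum_i b_i t_ij per pixel."""
    mult = 1.0 + b @ delta_t                      # (n_pix,)
    return (delta_g_obs - a @ delta_t) / mult, mult


def gaussian_loglike(a, b, sigma, delta_g_obs, delta_t, shift=0.0):
    """Eq. 17 including the Jacobian term; `shift` offsets the residuals."""
    delta_g, mult = clean_overdensity(a, b, delta_g_obs, delta_t)
    resid = delta_g - shift
    n = delta_g_obs.shape[0]
    return (-0.5 * n * jnp.log(2.0 * jnp.pi * sigma**2)
            - jnp.sum(jnp.log(jnp.abs(mult)))     # Jacobian correction
            - jnp.sum(resid**2) / (2.0 * sigma**2))


def skew_loglike(a, b, sigma, gamma, delta_g_obs, delta_t):
    """Eq. 18 sketch: zero-mean skew-normal with scale sigma and shape gamma."""
    d = gamma / jnp.sqrt(1.0 + gamma**2)
    xi = -sigma * d * jnp.sqrt(2.0 / jnp.pi)      # location that gives zero mean
    delta_g, _ = clean_overdensity(a, b, delta_g_obs, delta_t)
    n = delta_g_obs.shape[0]
    return (gaussian_loglike(a, b, sigma, delta_g_obs, delta_t, shift=xi)
            + n * jnp.log(2.0)
            + jnp.sum(log_ndtr(gamma * (delta_g - xi) / sigma)))
```

At gamma = 0 the shift vanishes and the extra terms cancel (N ln 2 + N ln ½ = 0), which is why the skew-normal reduces to the Gaussian. Wrapping either function in jax.jit gives the same compile-once behaviour as the callable returned by make_log_likelihood().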
- sys_mapping.likelihood.make_log_likelihood(n_sys, model, use_skewed=False)
Return a JIT-compiled log-likelihood function for the given model.
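The returned function is decorated with @jax.jit, so it is traced and compiled on its first call and retraced only when argument shapes or dtypes change. Reuse the same callable across MCMC iterations to amortize the compilation cost.
- Parameters:
  - n_sys (int): number of systematic templates, i.e. the number of rows of delta_t.
  - model (str): contamination model, for example "combined".
  - use_skewed (bool, default False): if True, use the skew-normal likelihood (Eq. 18) instead of the Gaussian one (Eq. 17).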
- Returns:
log_likelihood(theta, delta_g_obs, delta_t) -> scalar, where
  - theta: flat parameter vector (see pack_params());
  - delta_g_obs: (n_pix,) observed overdensity;
  - delta_t: (n_sys, n_pix) template values at unmasked pixels.
Performance
Measured on CPU with n_pix=10_000 and n_sys=5:
  - First call (JIT compile): ~227 ms. Make this call once before entering the MCMC loop.
  - Subsequent calls (Gaussian): ~284 μs/call.
  - Subsequent calls (skew-normal): ~592 μs/call (extra log_ndtr sum).
Precision
  - At zero contamination (a = b = 0), the result matches the sum of scipy.stats.norm.logpdf to rtol=1e-5.
  - The JAX gradient at the additive OLS MLE is below 1e-4 for n_pix ≥ 10_000.
  - Setting gamma=0 in skew mode recovers the Gaussian result to atol=1e-6.
- Return type:
Callable[[Array, Array, Array], Array]
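The compile-once, reuse-many pattern described above can be sketched as a factory that closes over n_sys and returns a jitted function. The factory make_gaussian_loglike, the theta layout [a, b, sigma] and the timing helper time_calls are hypothetical illustrations; the real pack_params() may order parameters differently. block_until_ready() is needed because JAX dispatches work asynchronously, so without it the timings would only measure dispatch.

```python
import time

import jax
import jax.numpy as jnp


def make_gaussian_loglike(n_sys):
    """Hypothetical factory returning a jitted Gaussian log-likelihood.

    theta is assumed to be packed as [a (n_sys), b (n_sys), sigma].
    """

    @jax.jit
    def log_likelihood(theta, delta_g_obs, delta_t):
        a = theta[:n_sys]                  # additive coefficients
        b = theta[n_sys:2 * n_sys]         # multiplicative coefficients
        sigma = theta[2 * n_sys]
        mult = 1.0 + b @ delta_t           # (n_pix,)
        delta_g = (delta_g_obs - a @ delta_t) / mult
        n = delta_g_obs.shape[0]           # static under jit
        return (-0.5 * n * jnp.log(2.0 * jnp.pi * sigma**2)
                - jnp.sum(jnp.log(jnp.abs(mult)))
                - jnp.sum(delta_g**2) / (2.0 * sigma**2))

    return log_likelihood


def time_calls(fn, *args, repeats=100):
    """Return (first-call seconds, mean seconds per later call)."""
    t0 = time.perf_counter()
    fn(*args).block_until_ready()          # trace + compile, then run
    first = time.perf_counter() - t0
    t0 = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()      # reuses the compiled executable
    return first, (time.perf_counter() - t0) / repeats
```

Calling time_calls(make_gaussian_loglike(5), theta, delta_g, delta_t) with n_pix = 10_000 separates the one-off trace-and-compile cost from the steady-state per-call cost; absolute numbers depend on the hardware and need not match the figures quoted above.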
Examples
>>> import numpy as np, jax.numpy as jnp
>>> from sys_mapping import make_log_likelihood, pack_params
>>> rng = np.random.default_rng(0)
>>> n_pix, n_sys = 5000, 3
>>> delta_g = rng.standard_normal(n_pix) * 0.1
>>> delta_t = rng.standard_normal((n_sys, n_pix))
>>> a = np.zeros(n_sys); b = np.zeros(n_sys); sigma = 0.1
>>> log_lik = make_log_likelihood(n_sys, "combined")  # build once; compiles on first call
>>> theta = jnp.asarray(pack_params(a, b, sigma, model="combined"))
>>> ll = float(log_lik(theta, jnp.asarray(delta_g), jnp.asarray(delta_t)))
>>> ll  # positive here, roughly N(-ln(2πσ²)/2 - 1/2) ≈ 4.4e3, because σ = 0.1 makes the density exceed 1
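The gradient check in the Precision notes can be reproduced in spirit with jax.grad. This sketch uses a hypothetical additive-only Gaussian log-likelihood (b = 0, so the Jacobian term vanishes) parameterized by (a, sigma), computes the closed-form maximum-likelihood point (least-squares a, then sigma² = RSS/N) and evaluates the gradient there. additive_loglike and ols_mle are illustrative names, not sys_mapping functions.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 keeps the check tight


def additive_loglike(params, delta_g_obs, delta_t):
    """Hypothetical additive-only Gaussian log-likelihood (b = 0, no Jacobian term)."""
    a, sigma = params
    resid = delta_g_obs - a @ delta_t
    n = delta_g_obs.shape[0]
    return (-0.5 * n * jnp.log(2.0 * jnp.pi * sigma**2)
            - jnp.sum(resid**2) / (2.0 * sigma**2))


def ols_mle(delta_g_obs, delta_t):
    """Closed-form MLE: least-squares a, then sigma = sqrt(RSS / N)."""
    a, *_ = jnp.linalg.lstsq(delta_t.T, delta_g_obs)
    rss = jnp.sum((delta_g_obs - a @ delta_t) ** 2)
    return a, jnp.sqrt(rss / delta_g_obs.shape[0])


# Gradient with respect to the (a, sigma) tuple, compiled once and reused.
grad_loglike = jax.jit(jax.grad(additive_loglike))
```

At the exact MLE both gradient components vanish analytically (the residuals are orthogonal to every template, and -N/σ + RSS/σ³ = 0), so any value well below 1e-4 reflects only floating-point error.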