sys_mapping.likelihood

JAX-compiled Gaussian and skew-normal log-likelihoods for the contamination model.

The factory function make_log_likelihood() returns a JIT-compiled callable that evaluates \(\ln\mathcal{L}(\boldsymbol\theta)\) for a given combination of (n_sys, model, use_skewed). The Jacobian correction \(-\sum_p \ln|1 + \sum_i b_i t_{ip}|\) is included for the combined model (Berlfein et al. 2024, Eq. 17).

Key paper: Berlfein et al. 2024 — see also Methods Reference.

JAX JIT-compiled log-likelihood factory (Eq. 17-18, Berlfein et al. 2024).

Gaussian log-likelihood (Eq. 17):
\[
\ln\mathcal{L} = -\frac{N}{2}\ln(2\pi\sigma^2) - \sum_j \ln\Bigl|1 + \sum_i b_i t_{ij}\Bigr| - \frac{1}{2\sigma^2}\sum_j \delta_{g,j}^2
\]

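where \(\delta_{g,j} = \bigl(\delta^{\rm obs}_{g,j} - \sum_i a_i t_{ij}\bigr) / \bigl(1 + \sum_i b_i t_{ij}\bigr)\) is the clean overdensity in pixel \(j\), \(t_{ij}\) is the value of template \(i\) in that pixel, and \(N\) is the number of unmasked pixels.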
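A minimal JAX sketch of Eq. 17 for the combined model, written directly from the formula and the definition of the clean overdensity above. The function name `gaussian_loglike_sketch` and its argument layout (separate `a`, `b`, `sigma` instead of the packed `theta` produced by `pack_params()`) are hypothetical choices for illustration, not the package's internal code.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 for a clean comparison


def gaussian_loglike_sketch(a, b, sigma, delta_g_obs, delta_t):
    """Hypothetical Eq. 17 sketch: Gaussian ln L for the combined model.

    a, b        : (n_sys,) additive and multiplicative coefficients
    sigma       : scalar width of the clean overdensity
    delta_g_obs : (n_pix,) observed overdensity
    delta_t     : (n_sys, n_pix) template values t_ij
    """
    n_pix = delta_g_obs.shape[0]
    additive = a @ delta_t              # sum_i a_i t_ij, shape (n_pix,)
    scale = 1.0 + b @ delta_t           # 1 + sum_i b_i t_ij, shape (n_pix,)
    delta_clean = (delta_g_obs - additive) / scale
    log_norm = -0.5 * n_pix * jnp.log(2.0 * jnp.pi * sigma**2)
    log_jacobian = -jnp.sum(jnp.log(jnp.abs(scale)))   # change of variables
    chi2_half = jnp.sum(delta_clean**2) / (2.0 * sigma**2)
    return log_norm + log_jacobian - chi2_half
```

With `a = b = 0` the Jacobian term vanishes and the result reduces to a sum of independent N(0, σ²) log-densities.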

Skew-normal log-likelihood (Eq. 18):

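\[
\ln\mathcal{L}_{\rm skew} = \ln\mathcal{L}_{\rm gauss} + N\ln 2 + \sum_j \ln\Phi\!\left(\frac{\gamma\,\delta_{g,j}}{\sigma}\right)
\]

where \(\gamma\) is the skew (shape) parameter and \(\xi = -\sigma\,\delta_\gamma\sqrt{2/\pi}\), with \(\delta_\gamma = \gamma/\sqrt{1+\gamma^2}\), is the location offset that gives the skew-normal zero mean. The \(\ln\Phi\) sum is evaluated with log_ndtr(z) = ln Φ(z), the numerically stable standard-normal log-CDF. At \(\gamma = 0\), \(\ln\Phi(0) = -\ln 2\) cancels the \(N\ln 2\) term and Eq. 17 is recovered.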
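A sketch of the skew-normal term using `jax.scipy.special.log_ndtr`. It assumes, as the description of ξ suggests, that the zero-mean offset ξ shifts the clean residual inside both the Gaussian and the log-CDF terms (the formula above writes the CDF argument as γ·δ_{g,j}/σ, so the placement of ξ is an assumption here). The helper name `skew_normal_logpdf_sum` is hypothetical, and the Jacobian term of the full likelihood is left out.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import log_ndtr

jax.config.update("jax_enable_x64", True)


def skew_normal_logpdf_sum(delta_clean, sigma, gamma):
    """Hypothetical helper: sum of zero-mean skew-normal log-densities.

    Assumption: xi shifts the residual in both the Gaussian and log-CDF
    terms, so the density is (2/sigma) * phi(z) * Phi(gamma * z),
    with z = (x - xi) / sigma. No Jacobian term is included.
    """
    n = delta_clean.shape[0]
    delta_gamma = gamma / jnp.sqrt(1.0 + gamma**2)
    xi = -sigma * delta_gamma * jnp.sqrt(2.0 / jnp.pi)  # makes the mean zero
    z = (delta_clean - xi) / sigma
    log_gauss = -0.5 * n * jnp.log(2.0 * jnp.pi * sigma**2) - 0.5 * jnp.sum(z**2)
    # log_ndtr stays finite for very negative gamma * z, unlike log(ndtr(.))
    return log_gauss + n * jnp.log(2.0) + jnp.sum(log_ndtr(gamma * z))
```

With `gamma=0` the offset ξ is zero and `log_ndtr(0) = -ln 2` cancels the `n * ln 2` term, so the Gaussian sum is recovered exactly.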

sys_mapping.likelihood.make_log_likelihood(n_sys, model, use_skewed=False)

Return a JIT-compiled log-likelihood function for the given model.

The returned function is decorated with @jax.jit, so it is traced and compiled on its first call. Reuse it across MCMC iterations to amortize the compilation cost.
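A sketch of the factory pattern described here: Python-level choices such as `model` are resolved when the factory runs, so the traced function contains only the selected branch. The name `make_log_likelihood_sketch` and the packed layout of `theta` below are hypothetical stand-ins (the real layout is whatever `pack_params()` produces), and the `use_skewed` branch is omitted for brevity.

```python
import jax
import jax.numpy as jnp


def make_log_likelihood_sketch(n_sys, model):
    """Hypothetical factory returning a jitted Gaussian ln L (Eq. 17).

    Assumed theta layout (NOT necessarily pack_params()'s):
      'additive'       -> [a_1..a_n, sigma]
      'multiplicative' -> [b_1..b_n, sigma]
      'combined'       -> [a_1..a_n, b_1..b_n, sigma]
    """
    if model not in ("additive", "multiplicative", "combined"):
        raise ValueError(f"unknown model: {model!r}")
    has_a = model in ("additive", "combined")
    has_b = model in ("multiplicative", "combined")

    @jax.jit
    def log_likelihood(theta, delta_g_obs, delta_t):
        # has_a / has_b are plain Python bools captured by the closure, so
        # the if-statements run at trace time and are not compiled in.
        offset = 0
        additive = jnp.zeros_like(delta_g_obs)
        scale = jnp.ones_like(delta_g_obs)
        if has_a:
            additive = theta[offset:offset + n_sys] @ delta_t
            offset += n_sys
        if has_b:
            scale = 1.0 + theta[offset:offset + n_sys] @ delta_t
            offset += n_sys
        sigma = theta[offset]
        clean = (delta_g_obs - additive) / scale
        n_pix = delta_g_obs.shape[0]
        return (-0.5 * n_pix * jnp.log(2.0 * jnp.pi * sigma**2)
                - jnp.sum(jnp.log(jnp.abs(scale)))
                - jnp.sum(clean**2) / (2.0 * sigma**2))

    return log_likelihood
```

Because `n_sys` and `model` are closed over rather than passed as traced arguments, changing either one means calling the factory again, which triggers a fresh compilation.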

Parameters:
  • n_sys (int) – Number of systematic templates.

  • model (str) – One of 'additive', 'multiplicative', or 'combined'.

  • use_skewed (bool) – If True, use skew-normal likelihood (Eq. 18); otherwise Gaussian (Eq. 17).

Returns:

  • log_likelihood(theta, delta_g_obs, delta_t) -> scalar, where
      theta : flat parameter vector (see pack_params())
      delta_g_obs : (n_pix,) observed overdensity
      delta_t : (n_sys, n_pix) template values at unmasked pixels

Performance

Measured on CPU (n_pix=10_000, n_sys=5):

  • First call (JIT compile): ~227 ms; call it once before entering MCMC loops.
  • Subsequent calls (Gaussian): ~284 μs/call.
  • Subsequent calls (skew-normal): ~592 μs/call (extra log_ndtr sum).
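Figures like the ones quoted here can be measured with a small harness such as the sketch below, which times a stand-in jitted Gaussian (`loglike` and `time_calls` are hypothetical names, not the package's functions). JAX dispatches work asynchronously, so each call is followed by `block_until_ready()`; without it, the timer would mostly capture dispatch overhead. Absolute numbers depend on hardware and will not match the list above.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def loglike(sigma, delta_clean):
    # Stand-in Gaussian ln L at zero contamination (no Jacobian term).
    n = delta_clean.shape[0]
    return (-0.5 * n * jnp.log(2.0 * jnp.pi * sigma**2)
            - jnp.sum(delta_clean**2) / (2.0 * sigma**2))


def time_calls(fn, *args, n_repeat=100):
    """Return (first_call_seconds, mean_subsequent_call_seconds)."""
    t0 = time.perf_counter()
    fn(*args).block_until_ready()          # includes tracing and compilation
    first = time.perf_counter() - t0
    t0 = time.perf_counter()
    for _ in range(n_repeat):
        fn(*args).block_until_ready()      # reuses the cached executable
    per_call = (time.perf_counter() - t0) / n_repeat
    return first, per_call


rng = np.random.default_rng(0)
x = jnp.asarray(rng.standard_normal(10_000) * 0.1)
first, per_call = time_calls(loglike, 0.1, x)
print(f"first call: {first * 1e3:.1f} ms, then {per_call * 1e6:.1f} us/call")
```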

Precision

  • At zero contamination (a = b = 0), the result matches the sum of scipy.stats.norm.logpdf to rtol=1e-5.
  • The JAX gradient at the additive OLS MLE is below 1e-4 for n_pix ≥ 10_000.
  • Setting gamma=0 in skew mode recovers the Gaussian result to atol=1e-6.
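The gradient check can be reproduced in a self-contained way. For the additive model (b = 0), maximizing Eq. 17 over a is ordinary least squares, and the MLE of σ² is the mean squared residual, so the gradient of ln L should vanish at that point. The sketch below uses a hypothetical stand-in `additive_loglike` with an assumed parameter layout `[a_1..a_n, sigma]`, not the package's function, and synthetic data.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)


def additive_loglike(params, delta_g_obs, delta_t):
    # params = [a_1..a_n, sigma] (hypothetical layout for this sketch)
    a, sigma = params[:-1], params[-1]
    resid = delta_g_obs - a @ delta_t
    n_pix = delta_g_obs.shape[0]
    return (-0.5 * n_pix * jnp.log(2.0 * jnp.pi * sigma**2)
            - jnp.sum(resid**2) / (2.0 * sigma**2))


rng = np.random.default_rng(0)
n_sys, n_pix = 3, 10_000
delta_t = rng.standard_normal((n_sys, n_pix))
a_true = np.array([0.02, -0.01, 0.005])
delta_g = a_true @ delta_t + 0.1 * rng.standard_normal(n_pix)

# Closed-form MLE: OLS for a, root-mean-square residual for sigma.
a_hat, *_ = np.linalg.lstsq(delta_t.T, delta_g, rcond=None)
sigma_hat = np.sqrt(np.mean((delta_g - a_hat @ delta_t) ** 2))

params_hat = jnp.asarray(np.append(a_hat, sigma_hat))
grad = jax.grad(additive_loglike)(params_hat, jnp.asarray(delta_g), jnp.asarray(delta_t))
print(np.max(np.abs(np.asarray(grad))))   # close to zero at the MLE
```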

Return type:

Callable[[Array, Array, Array], Array]

Examples

>>> import numpy as np, jax.numpy as jnp
>>> from sys_mapping import make_log_likelihood, pack_params
>>> rng = np.random.default_rng(0)
>>> n_pix, n_sys = 5000, 3
>>> delta_g = rng.standard_normal(n_pix) * 0.1
>>> delta_t = rng.standard_normal((n_sys, n_pix))
>>> a = np.zeros(n_sys); b = np.zeros(n_sys); sigma = 0.1
>>> log_lik = make_log_likelihood(n_sys, "combined")   # compile once
>>> theta = jnp.asarray(pack_params(a, b, sigma, model="combined"))
>>> ll = float(log_lik(theta, jnp.asarray(delta_g), jnp.asarray(delta_t)))
>>> ll > 0  # about +4.4e3 here: with sigma=0.1 the Gaussian density near 0 exceeds 1
True