# Source code for sys_mapping.contamination

"""Contamination model for galaxy overdensity systematics.

Implements the forward model (Eq. 11-13 of Berlfein et al. 2024):
    delta_g_obs = delta_g * (1 + b·delta_t) + a·delta_t

Three nested models:
  - 'additive':       b = 0, free params: [a_0,...,a_{N-1}]
  - 'multiplicative': b = a, free params: [a_0,...,a_{N-1}]
  - 'combined':       a, b free, params: [a_0,...,a_{N-1}, b_0,...,b_{N-1}]
"""

from __future__ import annotations

import jax.numpy as jnp
import numpy as np
from jax import Array

# ---------------------------------------------------------------------------
# Parameter layout helpers
# ---------------------------------------------------------------------------

def n_free_params(n_sys: int, model: str) -> int:
    """Return how many contamination coefficients a model carries.

    The noise width ``sigma`` and the optional skewness ``gamma`` are *not*
    part of this count.

    Parameters
    ----------
    n_sys : int
        Number of systematic templates.
    model : str
        One of ``'additive'``, ``'multiplicative'``, or ``'combined'``.

    Returns
    -------
    int
        ``n_sys`` for the additive and multiplicative models (one coefficient
        per template), ``2 * n_sys`` for the combined model (independent
        ``a`` and ``b`` vectors).

    Raises
    ------
    ValueError
        If ``model`` is not one of the three recognised names.

    Precision
    ---------
    Exact integer arithmetic; no floating-point error.

    Examples
    --------
    >>> from sys_mapping import n_free_params
    >>> n_free_params(4, "additive")
    4
    >>> n_free_params(4, "multiplicative")
    4
    >>> n_free_params(4, "combined")
    8
    """
    # The multiplicative model ties b to a, so like the additive model it
    # only needs one coefficient per template.
    if model in ("additive", "multiplicative"):
        return n_sys
    if model == "combined":
        return 2 * n_sys
    raise ValueError(f"Unknown model '{model}'. Choose additive, multiplicative, or combined.")
def pack_params(
    a: np.ndarray | Array,
    b: np.ndarray | Array | None,
    sigma: float,
    gamma: float | None = None,
    model: str = "combined",
) -> np.ndarray:
    """Pack (a, b, sigma[, gamma]) into a flat parameter vector.

    Layout is ``[a_0..a_{N-1}, (b_0..b_{N-1} if combined), sigma, (gamma)]``.
    This is the layout that ``unpack_params`` reads back.

    Parameters
    ----------
    a : (n_sys,)
        additive coefficients
    b : (n_sys,) or None
        multiplicative coefficients. Required for ``combined``, where it must
        have the same shape as ``a``. Ignored for the other models.
    sigma : float
        noise standard deviation
    gamma : float or None
        skewness parameter (appended last when provided)
    model : str
        ``'additive'``, ``'multiplicative'``, or ``'combined'``

    Returns
    -------
    (n_dim,) flat numpy array suitable for emcee / scipy optimizers

    Raises
    ------
    ValueError
        If ``model`` is unknown, or if ``model == 'combined'`` and ``b`` is
        missing or its shape differs from ``a``.

    Precision
    ---------
    Pure concatenation; values are preserved to machine precision (``float64``).

    Examples
    --------
    >>> import numpy as np
    >>> from sys_mapping import pack_params
    >>> a = np.array([0.05, -0.03, 0.02])
    >>> b = np.array([0.04, 0.01, -0.02])
    >>> theta = pack_params(a, b, sigma=0.12, model="combined")
    >>> theta  # [a0, a1, a2, b0, b1, b2, sigma]
    array([ 0.05, -0.03,  0.02,  0.04,  0.01, -0.02,  0.12])
    >>> pack_params(a, None, 0.12, gamma=1.5, model="additive")
    array([ 0.05, -0.03,  0.02,  0.12,  1.5 ])
    """
    # Reject unknown models up front. The siblings n_free_params and
    # unpack_params also raise on them, so packing silently would produce
    # a vector nothing else can interpret.
    if model not in ("additive", "multiplicative", "combined"):
        raise ValueError(f"Unknown model '{model}'. Choose additive, multiplicative, or combined.")

    a_arr = np.asarray(a)
    parts: list[np.ndarray] = [a_arr]
    if model == "combined":
        if b is None:
            raise ValueError("combined model requires b")
        b_arr = np.asarray(b)
        # A length mismatch would shift sigma/gamma to the wrong slots when
        # the vector is unpacked with a single n_sys.
        if b_arr.shape != a_arr.shape:
            raise ValueError(
                f"b must have the same shape as a; got {b_arr.shape} and {a_arr.shape}"
            )
        parts.append(b_arr)
    parts.append(np.array([sigma]))
    if gamma is not None:
        parts.append(np.array([gamma]))
    return np.concatenate(parts)
[docs] def unpack_params( theta: np.ndarray | Array, n_sys: int, model: str, use_skewed: bool = False, ) -> tuple[Array, Array, Array, Array | None]: """Unpack flat parameter vector into (a, b, sigma, gamma). Parameters ---------- theta : (n_dim,) flat parameter vector from :func:`pack_params` n_sys : int number of systematic templates model : str ``'additive'``, ``'multiplicative'``, or ``'combined'`` use_skewed : bool if True, extract gamma from the last element Returns ------- a : (n_sys,) additive coefficients (JAX array) b : (n_sys,) multiplicative coefficients; zeros for additive, equals a for multiplicative sigma : scalar noise standard deviation gamma : scalar skewness parameter or None Precision --------- Slice/view operations; values are preserved to machine precision. Examples -------- >>> import numpy as np, jax.numpy as jnp >>> from sys_mapping import pack_params, unpack_params >>> a = np.array([0.05, -0.03]) >>> b = np.array([0.04, 0.01]) >>> theta = pack_params(a, b, sigma=0.12, model="combined") >>> a2, b2, sigma2, _ = unpack_params(jnp.asarray(theta), n_sys=2, model="combined") >>> float(sigma2) 0.12 """ theta = jnp.asarray(theta) match model: case "additive": a = theta[:n_sys] b = jnp.zeros(n_sys) rest = theta[n_sys:] case "multiplicative": a = theta[:n_sys] b = a rest = theta[n_sys:] case "combined": a = theta[:n_sys] b = theta[n_sys : 2 * n_sys] rest = theta[2 * n_sys :] case _: raise ValueError(f"Unknown model '{model}'.") sigma = rest[0] gamma = rest[1] if use_skewed else None return a, b, sigma, gamma
# ---------------------------------------------------------------------------
# Forward / inverse contamination
# ---------------------------------------------------------------------------
def apply_contamination(
    delta_g: Array,
    delta_t: Array,
    a: Array,
    b: Array,
) -> Array:
    """Contaminate a clean overdensity map with systematic templates (Eq. 12).

    Evaluates ``δ_g_obs = δ_g * (1 + Σ b_i δ_{t,i}) + Σ a_i δ_{t,i}``.

    Parameters
    ----------
    delta_g : (n_pix,)
        true galaxy overdensity
    delta_t : (n_sys, n_pix)
        systematic templates, one row per template
    a : (n_sys,)
        additive coefficients
    b : (n_sys,)
        multiplicative coefficients

    Returns
    -------
    (n_pix,) observed (contaminated) overdensity

    Performance
    -----------
    Measured on CPU (JAX, n_pix=10_000, n_sys=5): **~830 μs/call** after
    warmup. Scales as O(n_sys × n_pix).

    Precision
    ---------
    Forward–inverse roundtrip recovers ``delta_g`` with max absolute error
    < 1e-5 (float32 accumulation). RMS error < 1e-6 for n_pix ≥ 50_000.

    Examples
    --------
    >>> import numpy as np, jax.numpy as jnp
    >>> from sys_mapping import apply_contamination
    >>> rng = np.random.default_rng(0)
    >>> delta_g = jnp.asarray(rng.standard_normal(1000) * 0.1)
    >>> delta_t = jnp.asarray(rng.standard_normal((2, 1000)))
    >>> a = jnp.array([0.05, -0.03])
    >>> b = jnp.array([0.04, 0.01])
    >>> delta_g_obs = apply_contamination(delta_g, delta_t, a, b)
    >>> delta_g_obs.shape
    (1000,)
    """
    # Per-pixel weighted sums of the templates, contracting over the template axis.
    additive = jnp.einsum("i,ij->j", a, delta_t)
    multiplicative = jnp.einsum("i,ij->j", b, delta_t)

    gain = 1.0 + multiplicative
    return gain * delta_g + additive
def invert_contamination(
    delta_g_obs: Array,
    delta_t: Array,
    a: Array,
    b: Array,
) -> Array:
    """Remove contamination from observed overdensity (Eq. 12 inverted).

    Computes ``δ_g = (δ_g_obs − Σ a_i δ_{t,i}) / (1 + Σ b_i δ_{t,i})``.

    Parameters
    ----------
    delta_g_obs : (n_pix,)
        observed galaxy overdensity
    delta_t : (n_sys, n_pix)
        systematic templates
    a : (n_sys,)
        additive coefficients
    b : (n_sys,)
        multiplicative coefficients

    Returns
    -------
    (n_pix,) cleaned galaxy overdensity

    Performance
    -----------
    Measured on CPU (JAX, n_pix=10_000, n_sys=5): **~1730 μs/call** after
    warmup.

    Precision
    ---------
    Roundtrip (apply → invert) max absolute error < 1e-5; RMS < 1e-6 for
    large n_pix. Assumes ``1 + b·δ_t ≠ 0`` at every pixel (satisfied for
    small \\|b\\|). The division is not guarded, so a zero denominator
    yields inf/nan in that pixel rather than an exception.

    Examples
    --------
    >>> import numpy as np, jax.numpy as jnp
    >>> from sys_mapping import apply_contamination, invert_contamination
    >>> rng = np.random.default_rng(0)
    >>> delta_g = jnp.asarray(rng.standard_normal(1000) * 0.1)
    >>> delta_t = jnp.asarray(rng.standard_normal((2, 1000)))
    >>> a = jnp.array([0.05, -0.03])
    >>> b = jnp.array([0.04, 0.01])
    >>> obs = apply_contamination(delta_g, delta_t, a, b)
    >>> recovered = invert_contamination(obs, delta_t, a, b)
    >>> bool(jnp.max(jnp.abs(recovered - delta_g)) < 1e-5)
    True
    """
    # Same template combinations as the forward model; undo the additive
    # offset first, then divide out the multiplicative gain.
    mult_term = jnp.einsum("i,ij->j", b, delta_t)
    add_term = jnp.einsum("i,ij->j", a, delta_t)
    return (delta_g_obs - add_term) / (1.0 + mult_term)
# ---------------------------------------------------------------------------
# Two-point correction (Eq. 15-16)
# ---------------------------------------------------------------------------
def compute_two_point_correction(
    w_obs: Array,
    a_sq: Array,
    b_sq: Array,
    template_correlations: Array,
) -> Array:
    """Correct an observed angular correlation function for systematics (Eq. 15-16).

    Evaluates ``w_corr = (w_obs − Σ a²_i ξ_i(θ)) / (1 + Σ b²_i ξ_i(θ))``,
    where ``ξ_i(θ) = ⟨δ_{t,i} δ_{t,i}⟩(θ)`` is the auto-correlation of
    template ``i``.

    Parameters
    ----------
    w_obs : (n_bins,)
        observed angular correlation function
    a_sq : (n_sys,)
        debiased squared additive coefficients
    b_sq : (n_sys,)
        debiased squared multiplicative coefficients
    template_correlations : (n_sys, n_bins)
        template auto-correlations per angular bin

    Returns
    -------
    (n_bins,) corrected angular correlation function

    Performance
    -----------
    Measured on CPU (JAX, n_sys=5, n_bins=15): **~576 μs/call** after warmup.

    Precision
    ---------
    Matches the analytic formula to ``rtol=1e-6`` (single float32 einsum).

    Examples
    --------
    >>> import numpy as np, jax.numpy as jnp
    >>> from sys_mapping import compute_two_point_correction
    >>> n_sys, n_bins = 3, 15
    >>> rng = np.random.default_rng(0)
    >>> w_obs = jnp.asarray(rng.standard_normal(n_bins) * 0.01)
    >>> a_sq = jnp.asarray(np.abs(rng.standard_normal(n_sys)) * 1e-3)
    >>> b_sq = jnp.asarray(np.abs(rng.standard_normal(n_sys)) * 1e-3)
    >>> tcorr = jnp.asarray(np.abs(rng.standard_normal((n_sys, n_bins))) * 1e-3)
    >>> w_corr = compute_two_point_correction(w_obs, a_sq, b_sq, tcorr)
    >>> w_corr.shape
    (15,)
    """
    # Per-bin bias contributions, weighting each template's auto-correlation.
    additive_bias = jnp.einsum("i,ij->j", a_sq, template_correlations)
    multiplicative_bias = jnp.einsum("i,ij->j", b_sq, template_correlations)

    numerator = w_obs - additive_bias
    denominator = 1.0 + multiplicative_bias
    return numerator / denominator