"""Regression-based systematic decontamination methods.
Implements:
- ElasticNet regularised regression (Weaverdyck & Huterer 2021)
- Iterative Systematics Decontamination via polynomial OLS (Rodríguez-Monroy+2025)
- Multi-method comparison framework (Weaverdyck & Huterer 2021)
- Unified run_decontamination() interface for all 6 methods
"""
from __future__ import annotations
import itertools
import time
import warnings
from typing import Callable
import numpy as np
# Method names accepted by run_decontamination(); any other value makes that
# function raise ValueError before doing any work.
_VALID_METHODS = frozenset({"OLS", "ElasticNet", "ISD-1", "ISD-3", "MCMC-add", "MCMC-comb"})
# Maximum weight allowed during ISD iterations and in the final ISD output.
# Pixels where 1 + a@t ≈ 0 would otherwise produce weights → 1/eps = 1e6,
# making all subsequent re-weighting steps meaningless. A cap of 20 corresponds
# to a denominator floor of 0.05 — generous enough not to bias well-behaved
# scenarios while preventing numerical blow-up under strong contamination.
# Despite the name, the same bounds [1/_ISD_MAX_WEIGHT, _ISD_MAX_WEIGHT] are
# also used to clip the OLS, ElasticNet and MCMC weights in this module.
_ISD_MAX_WEIGHT: float = 20.0
def _compute_weights(
    alpha: np.ndarray,
    delta_t: np.ndarray,
    eps: float = 1e-6,
) -> np.ndarray:
    """Per-pixel systematic weights w(p) = 1 / max(1 + alpha · t(p), eps).

    Parameters
    ----------
    alpha:
        Regression coefficients (shape ``(n_sys,)``). Any array-like is
        accepted and converted with ``np.asarray``.
    delta_t:
        Template values (shape ``(n_sys, n_pix)``). Any array-like is
        accepted and converted with ``np.asarray``.
    eps:
        Minimum denominator to prevent division by zero.

    Returns
    -------
    weights:
        Per-pixel weights (shape ``(n_pix,)``). They are positive for a
        positive ``eps``; pixels where ``1 + alpha · t <= eps`` saturate at
        ``1 / eps``, so callers that need bounded weights must clip the
        result themselves.
    """
    # Coerce once so that plain sequences (and other array-likes) work with
    # the ``@`` operator and the arithmetic below runs in NumPy.
    coefficients = np.asarray(alpha)
    templates = np.asarray(delta_t)
    systematic_field = coefficients @ templates  # (n_pix,)
    # Floor the denominator: without it a pixel with 1 + alpha·t <= 0 would
    # give an infinite or negative weight.
    denominator = np.maximum(1.0 + systematic_field, eps)
    return 1.0 / denominator
[docs]
def elasticnet_contamination_fit(
    delta_g_obs: np.ndarray,
    delta_t: np.ndarray,
    *,
    l1_ratio: float = 0.5,
    alpha_reg: float | None = None,
    cv_folds: int = 5,
    max_iter: int = 10_000,
    fit_intercept: bool = False,
) -> tuple[np.ndarray, np.ndarray, dict]:
    """Estimate contamination amplitudes with an ElasticNet regression.

    The templates are regressed onto the observed overdensity under the
    penalised least-squares loss of Weaverdyck & Huterer 2021 (Eq. 9):

    .. math::

        \\mathcal{L} = \\frac{1}{2N_{\\rm pix}}
        \\left\\|\\delta_g - \\sum_i\\alpha_i\\,t_i\\right\\|_2^2
        + \\lambda_1\\|\\boldsymbol\\alpha\\|_1
        + \\lambda_2\\|\\boldsymbol\\alpha\\|_2^2

    with the overall penalty strength given by ``alpha_reg`` and the L1/L2
    split by ``l1_ratio``.

    Parameters
    ----------
    delta_g_obs:
        Observed galaxy overdensity (shape ``(n_pix,)``).
    delta_t:
        Template maps (shape ``(n_sys, n_pix)``).
    l1_ratio:
        ElasticNet mixing parameter: 1.0 = Lasso, 0.0 = Ridge.
    alpha_reg:
        Regularisation strength. When ``None`` it is selected by
        ``cv_folds``-fold cross-validation (``ElasticNetCV``, run with
        ``n_jobs=-1``).
    cv_folds:
        Number of cross-validation folds; only used when ``alpha_reg`` is
        ``None``.
    max_iter:
        Maximum coordinate descent iterations.
    fit_intercept:
        Whether to fit a constant offset. Typically ``False`` because the
        overdensity field has zero mean by construction. The intercept is
        not used when the weights are formed.

    Returns
    -------
    alpha_hat:
        Fitted contamination amplitudes (shape ``(n_sys,)``).
    weights:
        Per-pixel systematic weights ``1 / (1 + alpha_hat · t(p))``, clipped
        to ``[1 / _ISD_MAX_WEIGHT, _ISD_MAX_WEIGHT]`` (shape ``(n_pix,)``).
    cv_info:
        Dictionary with keys ``alpha_reg`` (the value actually used),
        ``l1_ratio`` and ``cv_scores`` (cross-validation MSE averaged over
        folds, or ``None`` when ``alpha_reg`` was supplied).

    Raises
    ------
    ImportError
        If scikit-learn cannot be imported.

    Notes
    -----
    Requires ``scikit-learn >= 1.3``. Install via::

        pip install "sys_mapping[regression]"

    The weight follows the DES-Y1 / Weaverdyck+2021 form
    :math:`w(p) = (1 + \\hat{\\boldsymbol\\alpha}\\cdot\\mathbf{t}(p))^{-1}`.

    Examples
    --------
    >>> import numpy as np
    >>> from sys_mapping import elasticnet_contamination_fit
    >>> rng = np.random.default_rng(42)
    >>> n_pix, n_sys = 5000, 3
    >>> delta_t = rng.standard_normal((n_sys, n_pix))
    >>> true_alpha = np.array([0.2, -0.1, 0.05])
    >>> delta_g = true_alpha @ delta_t + rng.standard_normal(n_pix) * 0.3
    >>> alpha_hat, weights, info = elasticnet_contamination_fit(
    ...     delta_g, delta_t, alpha_reg=1e-4)
    >>> weights.shape
    (5000,)
    >>> np.all(weights > 0)
    True

    References
    ----------
    Weaverdyck & Huterer 2021, MNRAS 503, 5061.
    """
    try:
        from sklearn.linear_model import ElasticNet, ElasticNetCV
    except ImportError as exc:
        raise ImportError(
            "scikit-learn is required for elasticnet_contamination_fit. "
            "Install it via: pip install 'sys_mapping[regression]'"
        ) from exc

    design = delta_t.T  # (n_pix, n_sys): one row per pixel
    shared_options = {
        "l1_ratio": l1_ratio,
        "max_iter": max_iter,
        "fit_intercept": fit_intercept,
    }

    if alpha_reg is not None:
        # Penalty strength supplied by the caller: single fit, no CV scores.
        estimator = ElasticNet(alpha=alpha_reg, **shared_options)
        estimator.fit(design, delta_g_obs)
        cv_scores = None
    else:
        # Let cross-validation pick the penalty strength.
        estimator = ElasticNetCV(cv=cv_folds, n_jobs=-1, **shared_options)
        estimator.fit(design, delta_g_obs)
        alpha_reg = float(estimator.alpha_)
        cv_scores = estimator.mse_path_.mean(axis=-1)

    alpha_hat = estimator.coef_  # (n_sys,)
    lower, upper = 1.0 / _ISD_MAX_WEIGHT, _ISD_MAX_WEIGHT
    weights = np.clip(_compute_weights(alpha_hat, delta_t), lower, upper)

    return alpha_hat, weights, {
        "alpha_reg": alpha_reg,
        "l1_ratio": l1_ratio,
        "cv_scores": cv_scores,
    }
[docs]
def iterative_systematics_decontamination(
    delta_g_obs: np.ndarray,
    delta_t: np.ndarray,
    *,
    poly_order: int = 3,
    max_iter: int = 20,
    tol: float = 1e-5,
    lambda_poly: float = 0.0,
) -> tuple[np.ndarray, np.ndarray, int]:
    """Iterative Systematics Decontamination (ISD) via polynomial OLS.

    Expands templates to include cross-products up to ``poly_order``, then
    iteratively fits OLS weights until convergence. Implements the approach
    of Rodríguez-Monroy et al. 2025 (DES Y6), Sec. 3.2.

    The ISD cleansed overdensity after convergence is:

    .. math::

        \\delta_g^{\\rm clean}(p) =
        \\frac{\\delta_g(p) - f_{\\rm add}(p)}{1 + f_{\\rm mult}(p)}

    where :math:`f_{\\rm add}` and :math:`f_{\\rm mult}` are the additive and
    multiplicative systematic components estimated from polynomial template
    regression (here approximated by a single linear OLS pass per iteration).

    Parameters
    ----------
    delta_g_obs:
        Observed galaxy overdensity (shape ``(n_pix,)``).
    delta_t:
        Template maps (shape ``(n_sys, n_pix)``).
    poly_order:
        Maximum polynomial order for template expansion. Order 1 = linear
        templates only; order 2 adds pairwise products; order 3 adds triple
        products. Automatically reduced (with a warning per step) while
        ``n_pix < 10 * n_expanded``.
    max_iter:
        Maximum number of OLS iterations. With ``max_iter <= 0`` no fit is
        performed and unit weights / zero coefficients are returned.
    tol:
        Convergence tolerance on the relative change in weights
        (``||w_new - w_old|| / ||w_old||``).
    lambda_poly:
        Ridge penalty applied exclusively to the polynomial (non-linear)
        expansion columns — the first ``n_sys`` linear columns are never
        penalised. A small positive value (e.g. ``1e-3 * np.var(delta_g_obs)``)
        prevents the polynomial terms from absorbing signal that belongs to the
        linear coefficients, which are the only ones used for the weight update.
        Defaults to ``0.0`` (plain OLS, backward-compatible).

    Returns
    -------
    weights:
        Per-pixel systematic weights (shape ``(n_pix,)``), clipped to
        ``[1 / _ISD_MAX_WEIGHT, _ISD_MAX_WEIGHT]`` and therefore positive.
    alpha_hat_all:
        Fitted coefficients for the expanded template matrix at the final
        iteration (shape ``(n_expanded,)``). The first ``n_sys`` entries are
        the linear coefficients.
    n_iterations:
        Number of iterations performed (1 if converged on first pass,
        ``max_iter`` if the tolerance was never reached).

    Notes
    -----
    Template expansion uses
    ``itertools.combinations_with_replacement`` — no new dependencies.
    For ``n_sys=5``, ``poly_order=3`` the expanded set has
    :math:`\\binom{n+k}{k} - 1 = 55` columns. The column count of each
    candidate order is obtained from the index tuples alone, so the
    ``(n_expanded, n_pix)`` matrix is materialised only once, for the order
    that is finally accepted.

    **Speed:** the design matrix ``X`` is constant across iterations; only the
    diagonal weight matrix changes. The per-iteration cost is therefore one
    matmul ``X^T diag(w) X`` of shape ``(n_expanded, n_expanded)`` plus a
    small ``lstsq`` solve.

    Examples
    --------
    >>> import numpy as np
    >>> from sys_mapping import iterative_systematics_decontamination
    >>> rng = np.random.default_rng(0)
    >>> n_pix, n_sys = 8000, 3
    >>> delta_t = rng.standard_normal((n_sys, n_pix))
    >>> true_alpha = np.array([0.15, -0.08, 0.04])
    >>> delta_g = true_alpha @ delta_t + rng.standard_normal(n_pix) * 0.4
    >>> weights, alpha_all, n_it = iterative_systematics_decontamination(
    ...     delta_g, delta_t, poly_order=2, max_iter=10)
    >>> weights.shape
    (8000,)
    >>> np.all(weights > 0)
    True

    References
    ----------
    Rodríguez-Monroy et al. 2025, arXiv:2509.07943.
    """
    n_sys, n_pix = delta_t.shape

    def _product_terms(order: int) -> list[tuple[int, ...]]:
        """Template-index tuples of the product columns for degrees 2..order."""
        return [
            combo
            for deg in range(2, order + 1)
            for combo in itertools.combinations_with_replacement(range(n_sys), deg)
        ]

    # Reduce poly_order while the problem is underdetermined. Only the index
    # tuples are enumerated here; no (n_expanded, n_pix) array is allocated
    # for orders that end up being rejected.
    current_order = poly_order
    terms = _product_terms(current_order)
    while current_order > 1:
        n_expanded = n_sys + len(terms)
        if n_pix >= 10 * n_expanded:
            break
        warnings.warn(
            f"poly_order={current_order} gives {n_expanded} columns but n_pix={n_pix}. "
            f"Reducing poly_order to {current_order - 1}.",
            stacklevel=2,
        )
        current_order -= 1
        terms = _product_terms(current_order)

    # Build the expansion once: linear templates first, then the products in
    # increasing degree (the ridge mask below relies on this ordering).
    rows = [delta_t]
    for combo in terms:
        product = np.ones(n_pix)
        for idx in combo:
            product *= delta_t[idx]
        rows.append(product[np.newaxis, :])
    delta_t_expanded = np.vstack(rows)  # (n_expanded, n_pix)
    n_expanded = delta_t_expanded.shape[0]

    weights = np.ones(n_pix)
    alpha_hat_all = np.zeros(n_expanded)

    # X never changes between iterations; keep both orientations at hand so
    # each iteration is a single matmul plus a small solve.
    X = delta_t_expanded.T  # (n_pix, n_expanded)
    Xt = delta_t_expanded  # (n_expanded, n_pix)

    # Ridge penalty on polynomial-only columns: prevents them from absorbing
    # signal that belongs to the linear coefficients used in the weight update.
    # Linear columns (first n_sys) are never penalised.
    ridge_diag = np.zeros(n_expanded)
    ridge_diag[n_sys:] = lambda_poly

    for iteration in range(1, max_iter + 1):
        weights_old = weights.copy()

        # Weighted normal equations: (X^T W X + diag(ridge)) alpha = X^T W y.
        Xw = Xt * weights  # (n_expanded, n_pix): column p scaled by w(p)
        XtWX = Xw @ X  # (n_expanded, n_expanded)
        XtWX.flat[:: n_expanded + 1] += ridge_diag  # add ridge to the diagonal
        XtWy = Xw @ delta_g_obs  # (n_expanded,)
        # lstsq rather than solve: tolerates a singular system when
        # lambda_poly == 0 and the expansion columns are collinear.
        alpha_hat_all, *_ = np.linalg.lstsq(XtWX, XtWy, rcond=None)

        # Only the linear coefficients enter the weight update.
        weights = _compute_weights(alpha_hat_all[:n_sys], delta_t)
        # Clip to prevent divergent re-weighting: pixels where 1 + a@t -> 0
        # would otherwise get weight 1/eps and dominate every later fit.
        weights = np.clip(weights, 1.0 / _ISD_MAX_WEIGHT, _ISD_MAX_WEIGHT)

        rel_change = np.linalg.norm(weights - weights_old) / (
            np.linalg.norm(weights_old) + 1e-30
        )
        if rel_change < tol:
            return weights, alpha_hat_all, iteration

    return weights, alpha_hat_all, max_iter
[docs]
def method_comparison(
delta_g_obs: np.ndarray,
delta_t: np.ndarray,
good_pixels: np.ndarray,
nside: int,
methods: tuple[str, ...] = ("combined_mcmc", "additive_mcmc", "elasticnet", "ols"),
seed: int = 42,
mcmc_kwargs: dict | None = None,
) -> dict[str, dict]:
"""Run multiple decontamination methods on the same data and compare.
Implements the multi-method comparison framework of Weaverdyck & Huterer
2021 (Sec. 4). Each method returns coefficients, per-pixel weights, and
a goodness-of-fit metric, enabling direct comparison.
Parameters
----------
delta_g_obs:
Observed overdensity (shape ``(n_pix,)``).
delta_t:
Template maps (shape ``(n_sys, n_pix)``).
good_pixels:
Boolean mask (shape ``(n_full_pix,)``). Passed through for spatial
context but not used in computation directly (templates already masked).
nside:
HEALPix NSIDE of the full map.
methods:
Subset of ``{"combined_mcmc", "additive_mcmc", "elasticnet", "ols"}``.
``"combined_mcmc"`` (the default first entry) is the recommended method;
it infers both additive and multiplicative amplitudes jointly.
seed:
Random seed for MCMC initialisation.
mcmc_kwargs:
Keyword arguments forwarded to :func:`~sys_mapping.inference.run_mcmc`
for the ``"additive_mcmc"`` method.
Returns
-------
results:
Dictionary keyed by method name. Each value is a dict with:
- ``alpha_hat`` or ``a_hat``: contamination amplitudes
- ``weights``: per-pixel weights (shape ``(n_pix,)``)
- ``chi2_residual``: sum-of-squared residuals / n_pix
Notes
-----
``"additive_mcmc"`` requires no optional dependencies (uses the built-in
MCMC pipeline). ``"elasticnet"`` requires ``scikit-learn``.
Examples
--------
>>> import numpy as np
>>> from sys_mapping import method_comparison
>>> rng = np.random.default_rng(0)
>>> n_pix, n_sys = 3000, 2
>>> delta_t = rng.standard_normal((n_sys, n_pix))
>>> delta_g = 0.1 * delta_t[0] + rng.standard_normal(n_pix) * 0.5
>>> good = np.ones(n_pix, dtype=bool)
>>> results = method_comparison(delta_g, delta_t, good, nside=16,
... methods=("elasticnet", "ols"))
>>> set(results.keys()) == {"elasticnet", "ols"}
True
>>> "weights" in results["ols"]
True
References
----------
Weaverdyck & Huterer 2021, MNRAS 503, 5061.
"""
results: dict[str, dict] = {}
n_pix = delta_g_obs.shape[0]
def _chi2(alpha: np.ndarray) -> float:
residual = delta_g_obs - alpha @ delta_t
return float(np.sum(residual**2) / n_pix)
if "ols" in methods:
X = delta_t.T # (n_pix, n_sys)
alpha_ols, *_ = np.linalg.lstsq(X, delta_g_obs, rcond=None)
weights_ols = np.clip(
_compute_weights(alpha_ols, delta_t),
1.0 / _ISD_MAX_WEIGHT, _ISD_MAX_WEIGHT,
)
results["ols"] = {
"alpha_hat": alpha_ols,
"weights": weights_ols,
"chi2_residual": _chi2(alpha_ols),
}
if "elasticnet" in methods:
alpha_en, weights_en, cv_info = elasticnet_contamination_fit(
delta_g_obs, delta_t, alpha_reg=None
)
results["elasticnet"] = {
"alpha_hat": alpha_en,
"weights": weights_en,
"chi2_residual": _chi2(alpha_en),
"cv_info": cv_info,
}
if "combined_mcmc" in methods or "additive_mcmc" in methods:
from .inference import run_mcmc, get_mle_params
from .contamination import unpack_params
kw = {"seed": seed}
if mcmc_kwargs:
kw.update(mcmc_kwargs)
if "combined_mcmc" in methods:
flat_chain, _ = run_mcmc(
n_sys=delta_t.shape[0],
model="combined",
delta_g_obs=delta_g_obs,
delta_t=delta_t,
**kw,
)
theta_hat = get_mle_params(flat_chain)
n_sys = delta_t.shape[0]
a_hat, b_hat, _, _ = unpack_params(theta_hat, n_sys, "combined", use_skewed=False)
from .contamination import invert_contamination as _inv_cont
import jax.numpy as _jnp
_dg_clean = np.asarray(_inv_cont(
_jnp.asarray(delta_g_obs), _jnp.asarray(delta_t),
_jnp.asarray(a_hat), _jnp.asarray(b_hat),
))
weights_comb = np.clip(
(1.0 + _dg_clean) / np.maximum(1.0 + delta_g_obs, 1e-6),
1.0 / _ISD_MAX_WEIGHT, _ISD_MAX_WEIGHT,
)
results["combined_mcmc"] = {
"a_hat": np.asarray(a_hat),
"b_hat": np.asarray(b_hat),
"weights": weights_comb,
"chi2_residual": _chi2(np.asarray(a_hat)),
}
if "additive_mcmc" in methods:
flat_chain, _ = run_mcmc(
n_sys=delta_t.shape[0],
model="additive",
delta_g_obs=delta_g_obs,
delta_t=delta_t,
**kw,
)
theta_hat = get_mle_params(flat_chain)
n_sys = delta_t.shape[0]
a_hat, _, _, _ = unpack_params(theta_hat, n_sys, "additive", use_skewed=False)
weights_mcmc = np.clip(
_compute_weights(a_hat, delta_t),
1.0 / _ISD_MAX_WEIGHT, _ISD_MAX_WEIGHT,
)
results["additive_mcmc"] = {
"a_hat": np.asarray(a_hat),
"weights": weights_mcmc,
"chi2_residual": _chi2(np.asarray(a_hat)),
}
return results
[docs]
def run_decontamination(
    method: str,
    delta_g_obs: np.ndarray,
    delta_t: np.ndarray,
    *,
    n_walkers: int = 100,
    n_steps: int = 1200,
    n_burn: int = 200,
    seed: int = 42,
    use_skewed: bool = False,
    progress: bool = False,
    cv_folds: int = 5,
    isd_max_iter: int = 50,
    isd_lambda_poly: float = 0.0,
) -> dict:
    """Run a single decontamination method and return a standardised result dict.

    This is the unified entry point used by both validation scripts and
    production pipelines. All 6 methods share the same interface so callers
    can iterate over a list of method names without any method-specific
    boilerplate.

    Parameters
    ----------
    method:
        One of ``"OLS"``, ``"ElasticNet"``, ``"ISD-1"``, ``"ISD-3"``,
        ``"MCMC-add"``, ``"MCMC-comb"``.
    delta_g_obs:
        Observed galaxy overdensity at unmasked pixels, shape ``(n_pix,)``.
    delta_t:
        Template maps in the *original* (non-rotated) basis,
        shape ``(n_sys, n_pix)``. MCMC methods rotate internally.
    n_walkers, n_steps, n_burn:
        MCMC sampler settings. ``n_walkers`` is raised to ``2 * n_dim + 2``
        when it is smaller than that.
    seed:
        Random seed for MCMC walker initialisation.
    use_skewed:
        Use skew-normal likelihood for MCMC-comb. Ignored for all other
        methods.
    progress:
        Show emcee progress bar (MCMC methods only).
    cv_folds:
        Cross-validation folds for ElasticNet (the regularisation strength
        is always chosen automatically here).
    isd_max_iter:
        Maximum iterations for ISD methods (per pass).
    isd_lambda_poly:
        Ridge penalty on the polynomial expansion columns for ISD-3 (see
        :func:`iterative_systematics_decontamination`). A value of
        ``1e-3 * np.var(delta_g_obs)`` is recommended when ISD-3 converges
        to implausible solutions. Has no effect for ISD-1 (poly_order=1 has
        no polynomial-only columns). Default ``0.0`` (plain OLS).

    Returns
    -------
    dict with keys:
        All methods
            ``a_hat`` (n_sys,), ``b_hat``, ``weights`` (n_pix,),
            ``elapsed_s`` (float). ``b_hat`` is an all-zero ``(n_sys,)``
            array for the regression methods.
        MCMC methods additionally
            ``flat_chain``, ``sampler``, ``a_rot``, ``b_rot``,
            ``cov_a_rot``, ``cov_b_rot``, ``cov_a``, ``cov_b``,
            ``R``, ``acceptance_fraction``, ``sigma_hat``.
        ISD methods additionally
            ``n_iterations`` (int, summed over both passes),
            ``isd_outlier_mask`` (bool, (n_pix,): pixels whose pass-1 weight
            reached the clip boundary) and ``isd_masked_fraction`` (float:
            fraction of such pixels).
        ElasticNet additionally
            ``cv_info`` (dict).
        Non-applicable keys are ``None``.

    Raises
    ------
    ValueError
        If ``method`` is not one of the six supported names.

    Notes
    -----
    **Weight convention** (all weights are clipped to
    ``[1 / _ISD_MAX_WEIGHT, _ISD_MAX_WEIGHT]``)

    * OLS / ElasticNet / ISD-1 / ISD-3:
      ``w(p) = 1 / (1 + a_hat @ t(p))`` — additive weight.
    * MCMC-add / MCMC-comb:
      ``w(p) = (1 + δ_g_clean(p)) / (1 + δ_g_obs(p))`` — exact pixel-level
      inverse of the contamination weight; cancels ``weight_cont`` exactly
      when (a_hat, b_hat) equal the true parameters.

    **ISD outlier handling** — pixels whose pass-1 weight sits on the clip
    boundary are masked and the fit is repeated on the remaining pixels,
    provided at least ``max(10 * n_sys, 50)`` pixels remain. The mask and
    masked fraction are reported even when that second pass is skipped.

    References
    ----------
    Berlfein et al. 2024, MNRAS 531, 4954.
    Weaverdyck & Huterer 2021, MNRAS 503, 5061.
    Rodríguez-Monroy et al. 2025, arXiv:2509.07943.
    """
    if method not in _VALID_METHODS:
        raise ValueError(
            f"method must be one of {sorted(_VALID_METHODS)}, got {method!r}"
        )
    n_sys, _n_pix = delta_t.shape  # also enforces a 2-D template array
    t0 = time.perf_counter()

    # Default return scaffold: every key is present for every method, and
    # keys that do not apply to the chosen method stay None.
    result: dict = {
        "a_hat": None, "b_hat": None, "weights": None, "elapsed_s": None,
        # MCMC-only
        "flat_chain": None, "sampler": None,
        "a_rot": None, "b_rot": None,
        "cov_a_rot": None, "cov_b_rot": None,
        "cov_a": None, "cov_b": None,
        "R": None,
        "acceptance_fraction": None, "sigma_hat": None,
        # ISD-only
        "n_iterations": None,
        "isd_outlier_mask": None,
        "isd_masked_fraction": None,
        # ElasticNet-only
        "cv_info": None,
    }

    if method == "OLS":
        a_hat, *_ = np.linalg.lstsq(delta_t.T, delta_g_obs, rcond=None)
        weights = np.clip(
            _compute_weights(a_hat, delta_t),
            1.0 / _ISD_MAX_WEIGHT, _ISD_MAX_WEIGHT,
        )
        result.update({"a_hat": a_hat, "b_hat": np.zeros(n_sys), "weights": weights})
    elif method == "ElasticNet":
        a_hat, weights, cv_info = elasticnet_contamination_fit(
            delta_g_obs, delta_t, cv_folds=cv_folds
        )
        result.update({"a_hat": a_hat, "b_hat": np.zeros(n_sys), "weights": weights,
                       "cv_info": cv_info})
    elif method in ("ISD-1", "ISD-3"):
        result.update(_run_isd_two_pass(
            method, delta_g_obs, delta_t,
            max_iter=isd_max_iter, lambda_poly=isd_lambda_poly,
        ))
    else:  # "MCMC-add" or "MCMC-comb"
        result.update(_run_mcmc_decontamination(
            method, delta_g_obs, delta_t,
            n_walkers=n_walkers, n_steps=n_steps, n_burn=n_burn,
            seed=seed, use_skewed=use_skewed, progress=progress,
        ))

    result["elapsed_s"] = time.perf_counter() - t0
    return result


def _run_isd_two_pass(
    method: str,
    delta_g_obs: np.ndarray,
    delta_t: np.ndarray,
    *,
    max_iter: int,
    lambda_poly: float,
) -> dict:
    """Run ISD-1/ISD-3 with an outlier-masked second pass; return the ISD result keys."""
    n_sys, n_pix = delta_t.shape
    poly_order = 1 if method == "ISD-1" else 3

    # Pass 1: all pixels (weights are clipped to [1/W, W] inside the fit).
    weights_p1, alpha_all_p1, n_iter_p1 = iterative_systematics_decontamination(
        delta_g_obs, delta_t, poly_order=poly_order, max_iter=max_iter,
        lambda_poly=lambda_poly,
    )

    # Pixels sitting on the clip boundary are pathological (1 + a@t ~ 0).
    # The small relative margin makes the comparison robust to rounding.
    boundary = _ISD_MAX_WEIGHT * (1.0 - 1e-6)
    outlier_mask = (weights_p1 >= boundary) | (weights_p1 <= 1.0 / boundary)
    n_outlier = int(outlier_mask.sum())
    masked_fraction = n_outlier / n_pix

    if n_outlier > 0 and (n_pix - n_outlier) >= max(10 * n_sys, 50):
        # Pass 2: refit on the pixels that stayed away from the boundary so
        # the final coefficients are not biased by the pathological ones.
        good = ~outlier_mask
        _weights_p2, alpha_all_p2, n_iter_p2 = iterative_systematics_decontamination(
            delta_g_obs[good], delta_t[:, good],
            poly_order=poly_order, max_iter=max_iter,
            lambda_poly=lambda_poly,
        )
        # Apply the clean coefficients to ALL pixels (outliers included).
        a_hat = np.asarray(alpha_all_p2)[:n_sys]
        weights = np.clip(
            _compute_weights(a_hat, delta_t),
            1.0 / _ISD_MAX_WEIGHT, _ISD_MAX_WEIGHT,
        )
        n_iterations = n_iter_p1 + n_iter_p2
        # stacklevel=3: attribute the warning to the caller of
        # run_decontamination(), not to this private helper.
        warnings.warn(
            f"{method}: {n_outlier}/{n_pix} outlier pixels masked for pass 2 "
            f"(pass1={n_iter_p1} iters, pass2={n_iter_p2} iters).",
            UserWarning,
            stacklevel=3,
        )
    else:
        a_hat = np.asarray(alpha_all_p1)[:n_sys]
        weights = weights_p1
        n_iterations = n_iter_p1

    # Guard: non-finite coefficients (residual edge case) fall back to a
    # no-op correction rather than propagating NaN/inf weights.
    if not np.all(np.isfinite(a_hat)):
        warnings.warn(
            f"{method} produced non-finite coefficients; "
            "a_hat zeroed, weights set to ones.",
            stacklevel=3,
        )
        a_hat = np.zeros(n_sys)
        weights = np.ones(n_pix)
        outlier_mask = np.zeros(n_pix, dtype=bool)
        masked_fraction = 0.0

    return {
        "a_hat": a_hat, "b_hat": np.zeros(n_sys), "weights": weights,
        "n_iterations": n_iterations,
        "isd_outlier_mask": outlier_mask,
        "isd_masked_fraction": masked_fraction,
    }


def _run_mcmc_decontamination(
    method: str,
    delta_g_obs: np.ndarray,
    delta_t: np.ndarray,
    *,
    n_walkers: int,
    n_steps: int,
    n_burn: int,
    seed: int,
    use_skewed: bool,
    progress: bool,
) -> dict:
    """Run MCMC-add/MCMC-comb in the rotated template basis; return the MCMC result keys."""
    # Lazy imports: the regression methods must work without the MCMC stack.
    from .inference import run_mcmc, get_mle_params, get_param_covariance_from_chain
    from .contamination import invert_contamination, unpack_params
    from .correction import rotate_templates, transform_params_from_rotated
    import jax.numpy as _jnp

    n_sys = delta_t.shape[0]
    mcmc_model = "additive" if method == "MCMC-add" else "combined"
    # The skew-normal likelihood is only offered for the combined model.
    skewed = use_skewed if mcmc_model == "combined" else False

    # Sample in the rotated template basis returned by rotate_templates().
    delta_t_rot, R, _eigenvalues = rotate_templates(delta_t)

    # Make sure there are at least 2 * n_dim + 2 walkers.
    n_dim = n_sys + 1 if mcmc_model == "additive" else 2 * n_sys + 1
    nw = max(n_walkers, 2 * n_dim + 2)

    flat_chain, sampler = run_mcmc(
        n_sys=n_sys, model=mcmc_model,
        delta_g_obs=delta_g_obs, delta_t=delta_t_rot,
        n_walkers=nw, n_steps=n_steps, n_burn=n_burn,
        seed=seed, progress=progress,
        use_skewed=skewed,
    )
    theta_hat = get_mle_params(flat_chain)
    a_rot_raw, b_rot_raw, sigma_hat_val, _ = unpack_params(
        theta_hat, n_sys, mcmc_model, use_skewed=skewed
    )

    # Map point estimates and covariances back to the original basis.
    a_hat, b_hat = transform_params_from_rotated(
        np.asarray(a_rot_raw), np.asarray(b_rot_raw), R
    )
    cov_a_rot, cov_b_rot = get_param_covariance_from_chain(
        flat_chain, n_sys, mcmc_model
    )
    cov_a = R.T @ cov_a_rot @ R
    cov_b = R.T @ cov_b_rot @ R

    # Exact pixel-level decontamination weight:
    #     w(p) = (1 + delta_g_clean(p)) / (1 + delta_g_obs(p))
    # with delta_g_clean = invert_contamination(delta_g_obs, delta_t, a, b).
    # The 1/(1 + alpha·t) approximation overcorrects when a noisy b_hat != 0
    # in additive scenarios; the exact inverse avoids that bias.
    delta_g_clean = np.asarray(invert_contamination(
        _jnp.asarray(delta_g_obs),
        _jnp.asarray(delta_t),
        _jnp.asarray(a_hat),
        _jnp.asarray(b_hat),
    ))
    denom = np.maximum(1.0 + delta_g_obs, 1e-6)  # floor avoids division by zero
    weights = np.clip(
        (1.0 + delta_g_clean) / denom,
        1.0 / _ISD_MAX_WEIGHT, _ISD_MAX_WEIGHT,
    )

    return {
        "a_hat": a_hat, "b_hat": b_hat, "weights": weights,
        "flat_chain": flat_chain, "sampler": sampler,
        "a_rot": np.asarray(a_rot_raw), "b_rot": np.asarray(b_rot_raw),
        "cov_a_rot": cov_a_rot, "cov_b_rot": cov_b_rot,
        "cov_a": cov_a, "cov_b": cov_b,
        "R": R,
        "acceptance_fraction": float(np.mean(sampler.acceptance_fraction)),
        "sigma_hat": float(sigma_hat_val),
    }