import os
import tempfile
from copy import deepcopy
from pathlib import Path
import numpy as np
from scipy.ndimage import labeled_comprehension
from astropy.table import Table
from . import image_analysis
from .utils.io import (
IO_FORMATS_READ,
IO_FORMATS_WRITE,
read_image_trace_file,
read_parameter_trace_file,
)
from .utils.plot import (
plot_parameter_distributions,
plot_parameter_traces,
plot_pixel_trace,
plot_pixel_trace_neighbours,
)
# Floating point type that `alpha_init` and the input data arrays are converted to.
DTYPE_DEFAULT = np.float64
# Names exported by ``from <module> import *``.
# NOTE(review): `LIRASignificanceEstimator` is defined in this module but is not
# listed here - confirm whether it is meant to be part of the public API.
__all__ = ["LIRADeconvolver", "LIRADeconvolverResult"]
class LIRADeconvolver:
    """LIRA image deconvolution method.

    Parameters
    ----------
    alpha_init : `~numpy.ndarray`
        Initial alpha parameters. The length must be n for an input image
        of size 2^n x 2^n.
    n_iter_max : int
        Max. number of iterations.
    n_burn_in : int
        Number of burn-in iterations.
    fit_background_scale : bool
        Fit background scale.
    save_thin : bool or int
        Thinning setting, forwarded unchanged to the sampler as ``save_thin``.
    ms_ttlcnt_pr : float
        Multiscale prior TODO: improve description
    ms_ttlcnt_exp : float
        Multiscale prior TODO: improve description
    ms_al_kap1 : float
        Multiscale prior TODO: improve description
    ms_al_kap2 : float
        Multiscale prior TODO: improve description
    ms_al_kap3 : float
        Multiscale prior TODO: improve description
    filename_out : str, optional
        Accepted for backward compatibility, currently not used.
    filename_out_par : str, optional
        Accepted for backward compatibility, currently not used.
    random_state : `~numpy.random.RandomState`
        Random state used to draw the seed handed to the sampler. A new,
        unseeded state is created if None.

    Examples
    --------
    This is how to use the class:

    .. code::

        import numpy as np
        from pylira import LIRADeconvolver
        from pylira.data import point_source_gauss_psf

        data = point_source_gauss_psf()
        data["flux_init"] = data["flux"]
        deconvolve = LIRADeconvolver(
            alpha_init=np.ones(np.log2(data["counts"].shape[0]).astype(int))
        )
        result = deconvolve.run(data=data)
    """

    def __init__(
        self,
        alpha_init,
        n_iter_max=3000,
        n_burn_in=1000,
        fit_background_scale=False,
        save_thin=True,
        ms_ttlcnt_pr=1,
        ms_ttlcnt_exp=0.05,
        ms_al_kap1=0.0,
        ms_al_kap2=1000.0,
        ms_al_kap3=3.0,
        filename_out=None,
        filename_out_par=None,
        random_state=None,
    ):
        # NOTE: every attribute set here ends up in `to_dict()` (it copies
        # `__dict__`), which is why the unused filename arguments are not stored.
        self.alpha_init = np.array(alpha_init, dtype=DTYPE_DEFAULT)
        self.n_iter_max = n_iter_max
        self.n_burn_in = n_burn_in
        self.fit_background_scale = fit_background_scale
        self.save_thin = save_thin
        self.ms_ttlcnt_pr = ms_ttlcnt_pr
        self.ms_ttlcnt_exp = ms_ttlcnt_exp
        self.ms_al_kap1 = ms_al_kap1
        self.ms_al_kap2 = ms_al_kap2
        self.ms_al_kap3 = ms_al_kap3

        if random_state is None:
            random_state = np.random.RandomState()

        self.random_state = random_state

    def __str__(self):
        """String representation"""
        cls_name = self.__class__.__name__
        lines = [cls_name, "-" * len(cls_name), ""]
        lines.extend(
            f"\t{key:21s}: {value}" for key, value in self.to_dict().items()
        )
        return ("\n".join(lines) + "\n").expandtabs(tabsize=4)

    def _check_input_sizes(self, shape):
        """Check that the image shape is compatible with the configuration.

        Parameters
        ----------
        shape : tuple of int
            Shape of the observed counts image.

        Raises
        ------
        ValueError
            If the image is not a square 2d image, if its size is not a
            power of 2, or if the length of ``alpha_init`` is not log2(size).
        """
        if len(shape) != 2 or shape[0] != shape[1]:
            raise ValueError(
                f"Input observation must be a square 2d image. Shape given: {shape}"
            )

        obs_shape = int(shape[0])

        # a positive power of 2 has exactly one bit set
        if obs_shape < 1 or (obs_shape & (obs_shape - 1)) != 0:
            raise ValueError(
                f"Size of the input observation must be a power of 2. Size given: {obs_shape}"
            )

        # exact integer log2 for a power of 2
        n_alpha = obs_shape.bit_length() - 1

        if len(self.alpha_init) != n_alpha:
            raise ValueError(
                f"Number of elements in alpha_init must be {n_alpha}. "
                f"Size given: {len(self.alpha_init)}"
            )

    def to_dict(self):
        """Convert deconvolver configuration to dict, with simple data types.

        Returns
        -------
        data : dict
            Parameter dict. The random state is not included.
        """
        data = dict(self.__dict__)
        data["alpha_init"] = self.alpha_init.tolist()
        # TODO: serialise random state for reproducibility?
        data.pop("random_state")
        return data

    def run(self, data):
        """Run the algorithm

        Parameters
        ----------
        data : dict of `~numpy.ndarray`
            Data. The keys "counts", "flux_init", "psf", "exposure" and
            "background" are read. An optional "wcs" entry is not converted
            to an array but passed on to the result object.

        Returns
        -------
        result : `LIRADeconvolverResult`
            Result object.

        Raises
        ------
        ValueError
            If the counts image has an invalid size or if "background",
            "exposure" or "flux_init" do not match the shape of "counts".
        """
        # keep the WCS so that it is available on the result, e.g. for plotting
        wcs = data.get("wcs")

        data = {
            name: arr.astype(DTYPE_DEFAULT)
            for name, arr in data.items()
            if name != "wcs"
        }

        shape_counts = data["counts"].shape
        self._check_input_sizes(shape_counts)

        for name in ["background", "exposure", "flux_init"]:
            shape = data[name].shape
            if shape != shape_counts:
                raise ValueError(
                    f"Quantity '{name}' has a shape of {shape}, however {shape_counts} is expected."
                )

        random_seed = self.random_state.randint(1, np.iinfo(np.uint32).max)

        # The sampler writes its traces to files; they are read back while the
        # temporary directory still exists.
        with tempfile.TemporaryDirectory() as tmpdir:
            filename_image_trace = str(os.path.join(tmpdir, "image-trace.tx"))
            filename_parameter_trace = str(os.path.join(tmpdir, "parameter-trace.txt"))

            posterior_mean = image_analysis(
                observed_im=data["counts"],
                start_im=data["flux_init"],
                psf_im=data["psf"],
                expmap_im=data["exposure"],
                baseline_im=data["background"],
                out_img_file=filename_image_trace,
                out_param_file=filename_parameter_trace,
                max_iter=self.n_iter_max,
                burn_in=self.n_burn_in,
                save_thin=self.save_thin,
                fit_bkgscl=int(self.fit_background_scale),
                alpha_init=self.alpha_init,
                ms_ttlcnt_pr=self.ms_ttlcnt_pr,
                ms_ttlcnt_exp=self.ms_ttlcnt_exp,
                ms_al_kap1=self.ms_al_kap1,
                ms_al_kap2=self.ms_al_kap2,
                ms_al_kap3=self.ms_al_kap3,
                random_seed=random_seed,
            )

            parameter_trace = read_parameter_trace_file(
                filename=filename_parameter_trace, format="ascii"
            )
            image_trace = read_image_trace_file(
                filename=filename_image_trace, format="ascii"
            )

        config = self.to_dict()
        config["random_seed"] = random_seed

        # NOTE(review): assumes the image trace holds one image per iteration
        # including burn-in - confirm this also holds for save_thin > 1.
        posterior_std = np.nanstd(image_trace[self.n_burn_in :], axis=0)

        return LIRADeconvolverResult(
            posterior_mean=posterior_mean,
            posterior_std=posterior_std,
            parameter_trace=parameter_trace,
            image_trace=image_trace,
            config=config,
            wcs=wcs,
        )
class LIRADeconvolverResult:
    """LIRA deconvolution result object.

    Parameters
    ----------
    config : `dict`
        Configuration from the `LIRADeconvolver`
    posterior_mean : `~numpy.ndarray`
        Posterior mean
    posterior_std : `~numpy.ndarray`
        Posterior standard deviation
    parameter_trace : `~astropy.table.Table` or dict
        Parameter trace. If a dict is provided it triggers the lazy loading.
        The dict must contain the arguments to `read_parameter_trace_file`.
    image_trace : `~numpy.ndarray` or dict
        Image trace. If a dict is provided it triggers the lazy loading.
        The dict must contain the arguments to `read_image_trace_file`.
    wcs : `~astropy.wcs.WCS`
        World coordinate transform object
    """

    def __init__(
        self,
        config,
        posterior_mean=None,
        posterior_std=None,
        parameter_trace=None,
        image_trace=None,
        wcs=None,
    ):
        self._config = config
        self._posterior_mean = posterior_mean
        self._posterior_std = posterior_std
        self._wcs = wcs
        self._image_trace = image_trace
        self._parameter_trace = parameter_trace

    @property
    def config(self):
        """Configuration data (`dict`)"""
        return self._config

    @property
    def config_table(self):
        """Configuration data as table (`~astropy.table.Table`)"""
        config = Table()

        for key, value in self.config.items():
            if key == "alpha_init":
                # wrapped in a list, presumably so that the whole vector is
                # stored in a single table row
                value = [value]
            config[key] = value

        return config

    @property
    def wcs(self):
        """Optional wcs"""
        return self._wcs

    @property
    def n_burn_in(self):
        """Number of burn in iterations, 0 if not given in the config"""
        return self.config.get("n_burn_in", 0)

    @property
    def n_iter_max(self):
        """Number of max. iterations"""
        return self.config["n_iter_max"]

    @property
    def posterior_mean(self):
        """Posterior mean (`~numpy.ndarray`)"""
        return self._posterior_mean

    @property
    def posterior_std(self):
        """Posterior standard deviation (`~numpy.ndarray`)"""
        return self._posterior_std

    @property
    def posterior_mean_from_trace(self):
        """Posterior mean computed from trace (`~numpy.ndarray`)"""
        if self.image_trace is None:
            raise ValueError("No image trace available.")

        return np.nanmean(self.image_trace[self.n_burn_in :], axis=0)

    @property
    def posterior_std_from_trace(self):
        """Posterior std computed from trace (`~numpy.ndarray`)"""
        if self.image_trace is None:
            raise ValueError("No image trace available.")

        return np.nanstd(self.image_trace[self.n_burn_in :], axis=0)

    @property
    def image_trace(self):
        """Image trace (`~numpy.ndarray`), or None if not available"""
        # TODO: this currently handles only in memory data, this might not scale for
        # many iterations and/or large images
        if isinstance(self._image_trace, dict):
            # lazy loading: the dict holds the reader arguments
            self._image_trace = read_image_trace_file(**self._image_trace)

        return self._image_trace

    @property
    def parameter_trace(self):
        """Parameter trace (`~astropy.table.Table`), or None if not available"""
        if isinstance(self._parameter_trace, dict):
            # lazy loading: the dict holds the reader arguments
            self._parameter_trace = read_parameter_trace_file(**self._parameter_trace)

        if self._parameter_trace is not None:
            # TODO: add config to meta data of table, not sure whether it's the right place.
            self._parameter_trace.meta.update(self.config)

        return self._parameter_trace

    def _default_center_pix(self):
        """Central pixel in (x, y) order, used when no position is given."""
        if self.posterior_mean is not None:
            shape = self.posterior_mean.shape
        else:
            # first axis of the trace enumerates the samples
            shape = self.image_trace.shape[1:]

        # array shapes are (rows, columns), pixel positions are (x, y)
        return tuple(np.array(shape[::-1]) // 2)

    def plot_pixel_traces_region(
        self, center_pix=None, radius_pix=0, figsize=(16, 6), **kwargs
    ):
        """Plot pixel traces in a given region.

        Parameters
        ----------
        center_pix : tuple of int
            Pixel indices, order is (x, y). By default the trace at the center is plotted.
        radius_pix : float
            Radius in which the traces are plotted.
        figsize : tuple of float
            Figure size
        **kwargs : dict
            Keyword arguments forwarded to `~matplotlib.pyplot.imshow`

        Returns
        -------
        axes : dict
            Dict with the image axes ("ax-image") and the trace axes ("ax-trace").
        """
        import matplotlib.pyplot as plt
        from matplotlib.patches import Circle

        # fail before an empty figure is created
        if self.image_trace is None:
            raise ValueError("No image trace available.")

        if center_pix is None:
            center_pix = self._default_center_pix()

        fig = plt.figure(figsize=figsize)

        data = self.posterior_mean_from_trace

        ax_image = plt.subplot(1, 2, 1, projection=self.wcs)
        im = ax_image.imshow(data, origin="lower", **kwargs)
        fig.colorbar(im, ax=ax_image, label="Posterior Mean")

        # draw at least a one pixel circle so the marker stays visible
        radius = max(radius_pix, 1)
        artist = Circle(center_pix, radius=radius, color="w", fc="None")
        ax_image.add_artist(artist)

        ax_trace = plt.subplot(1, 2, 2)

        plot_pixel_trace(
            image_trace=self.image_trace,
            center_pix=center_pix,
            ax=ax_trace,
            config=self.config,
        )

        plot_pixel_trace_neighbours(
            image_trace=self.image_trace,
            center_pix=center_pix,
            radius_pix=radius_pix,
            ax=ax_trace,
        )

        return {
            "ax-image": ax_image,
            "ax-trace": ax_trace,
        }

    def plot_pixel_trace(self, center_pix=None, **kwargs):
        """Plot pixel trace at a given position.

        Parameters
        ----------
        center_pix : tuple of int
            Pixel indices, order is (x, y). By default the trace at the center is plotted.
        **kwargs : dict
            Keyword arguments forwarded to `plot_pixel_trace`
        """
        if self.image_trace is None:
            raise ValueError("No image trace available.")

        if center_pix is None:
            # choose center as default
            center_pix = self._default_center_pix()

        plot_pixel_trace(
            image_trace=self.image_trace,
            config=self.config,
            center_pix=center_pix,
            **kwargs,
        )

    def plot_pixel_trace_neighbours(self, center_pix=None, radius_pix=0, **kwargs):
        """Plot pixel traces in a given region.

        Parameters
        ----------
        center_pix : tuple of int
            Pixel indices, order is (x, y). By default the trace at the center is plotted.
        radius_pix : float
            Radius in which the traces are plotted.
        **kwargs : dict
            Keyword arguments forwarded to `plot_pixel_trace_neighbours`
        """
        if self.image_trace is None:
            raise ValueError("No image trace available.")

        if center_pix is None:
            # choose center as default
            center_pix = self._default_center_pix()

        plot_pixel_trace_neighbours(
            image_trace=self.image_trace,
            center_pix=center_pix,
            radius_pix=radius_pix,
            **kwargs,
        )

    def plot_posterior_mean(self, from_image_trace=False, **kwargs):
        """Plot posterior mean

        Parameters
        ----------
        from_image_trace : bool
            Recompute posterior from image trace.
        **kwargs : dict
            Keyword arguments forwarded to `~matplotlib.pyplot.imshow`
        """
        import matplotlib.pyplot as plt

        fig = plt.gcf()

        if from_image_trace:
            data = self.posterior_mean_from_trace
        else:
            data = self.posterior_mean

        ax = plt.subplot(projection=self.wcs)
        im = ax.imshow(data, origin="lower", **kwargs)
        fig.colorbar(im, ax=ax, label="Posterior Mean")

    def plot_image_trace_interactive(self, **kwargs):
        """Plot image trace interactively

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments forwarded to `~matplotlib.pyplot.imshow`
        """
        if self.image_trace is None:
            raise ValueError("No image trace available.")

        import matplotlib.pyplot as plt
        from ipywidgets import IntSlider
        from ipywidgets.widgets.interaction import interact

        kwargs.setdefault("interpolation", "nearest")
        kwargs.setdefault("origin", "lower")

        slider = IntSlider(
            value=0,
            min=0,
            max=self.image_trace.shape[0] - 1,
            description="Select idx: ",
            continuous_update=False,
            style={"description_width": "initial"},
            layout={"width": "50%"},
        )

        @interact(idx=slider)
        def _plot_interactive(idx):
            ax = plt.subplot(projection=self.wcs)
            im = ax.imshow(self.image_trace[idx], **kwargs)
            plt.colorbar(im, ax=ax, label="Flux")

    def plot_image_trace_animation(
        self,
        ax=None,
        interval=20,
        repeat=True,
        n_frames=None,
        cumulative=False,
        label_dxy=(20, 10),
        **kwargs,
    ):
        """Plot image trace animation

        Parameters
        ----------
        ax : `~matplotlib.pyplot.Axes`
            Plot axes
        interval : int
            Interval in ms.
        repeat : bool
            Repeat animation
        n_frames : int
            Number of frames. By default the max. number of iterations,
            limited to the number of images in the trace.
        cumulative : bool
            Cumulated mean of the samples.
        label_dxy : tuple of int
            Shift of the frame label.
        **kwargs : dict
            Keyword arguments forwarded to `~matplotlib.pyplot.imshow`.
            The first frame is all zeros, so pass ``vmin`` / ``vmax`` or
            ``norm`` to get a meaningful color scale.

        Returns
        -------
        anim : `~matplotlib.animation.FuncAnimation`
            Func animation object.
        """
        image_trace = self.image_trace

        if image_trace is None:
            raise ValueError("No image trace available.")

        import matplotlib.pyplot as plt
        from matplotlib.animation import FuncAnimation

        if ax is None:
            ax = plt.subplot(projection=self.wcs)

        if n_frames is None:
            # never request more frames than there are images in the trace
            n_frames = min(self.n_iter_max, image_trace.shape[0])

        kwargs.setdefault("origin", "lower")

        # take the image shape from the trace, the posterior mean is optional
        image_shape = image_trace.shape[1:]
        image = ax.imshow(np.zeros(image_shape), **kwargs)

        y, x = image_shape
        dx, dy = label_dxy
        text = ax.text(x - dx, y - dy, s="", color="w", va="center", ha="center")

        def animate(idx, im, txt, result):
            if cumulative:
                # include sample `idx`, otherwise frame 0 is the mean of nothing
                data = np.mean(result.image_trace[: idx + 1], axis=0)
            else:
                data = result.image_trace[idx]

            im.set_data(data)
            txt.set_text(f"$N_{{Iter}} = {idx}$")
            # with blitting only the returned artists are redrawn
            return im, txt

        anim = FuncAnimation(
            fig=ax.figure,
            func=animate,
            fargs=[image, text, self],
            frames=n_frames,
            interval=interval,
            blit=True,
            repeat=repeat,
        )
        return anim

    def plot_parameter_traces(self, **kwargs):
        """Plot parameter traces

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments forwarded to `plot_parameter_traces`
        """
        if self.parameter_trace is None:
            raise ValueError("No parameter trace available.")

        plot_parameter_traces(self.parameter_trace, config=self.config, **kwargs)

    def plot_parameter_distributions(self, **kwargs):
        """Plot parameter distributions

        Parameters
        ----------
        **kwargs : dict
            Keyword arguments forwarded to `plot_parameter_distributions`
        """
        if self.parameter_trace is None:
            raise ValueError("No parameter trace available.")

        plot_parameter_distributions(self.parameter_trace, config=self.config, **kwargs)

    def write(self, filename, overwrite=False, format="fits"):
        """Write result to file

        Parameters
        ----------
        filename : str or `Path`
            Output filename
        overwrite : bool
            Overwrite file.
        format : {"fits"}
            Format to use.

        Raises
        ------
        ValueError
            If the format is not a supported write format.
        """
        filename = Path(filename)

        if format not in IO_FORMATS_WRITE:
            raise ValueError(
                f"Not a valid format '{format}', choose from {list(IO_FORMATS_WRITE)}"
            )

        writer = IO_FORMATS_WRITE[format]
        writer(result=self, filename=filename, overwrite=overwrite)

    @classmethod
    def read(cls, filename, format="fits"):
        """Read result from file

        Parameters
        ----------
        filename : str or `Path`
            Input filename
        format : {"fits"}
            Format to use.

        Returns
        -------
        result : `~LIRADeconvolverResult`
            Result object

        Raises
        ------
        ValueError
            If the format is not a supported read format.
        """
        filename = Path(filename)

        if format not in IO_FORMATS_READ:
            raise ValueError(
                f"Not a valid format '{format}', choose from {list(IO_FORMATS_READ)}"
            )

        reader = IO_FORMATS_READ[format]
        # the reader returns the keyword arguments of the constructor
        kwargs = reader(filename=filename)
        return cls(**kwargs)

    def reduce_to_mean_std(self):
        """Reduce to mean and std

        The posterior mean and std are recomputed from the image trace and
        the image trace itself is dropped.

        Returns
        -------
        result : `~LIRADeconvolverResult`
            Reduced result object

        Raises
        ------
        ValueError
            If no image trace is available.
        """
        return self.__class__(
            config=deepcopy(self.config),
            posterior_mean=self.posterior_mean_from_trace,
            posterior_std=self.posterior_std_from_trace,
            wcs=deepcopy(self.wcs),
            parameter_trace=deepcopy(self.parameter_trace),
        )
class LIRASignificanceEstimator:
    """
    Estimate the significance of emission from specified regions
    using the method described in Stein et al. (2015)

    Parameters
    ----------
    result_observed_im : `~LIRADeconvolverResult`
        LIRA result for the observed image
    result_replicates : list
        LIRA results for the baseline (null) images
    labels_im : `~numpy.ndarray`
        Image with regions where each region is indicated with a unique integer
    """

    def __init__(
        self,
        result_observed_im,
        result_replicates,
        labels_im,
    ):
        self._result_observed_im = result_observed_im
        self._result_replicates = result_replicates
        self._labels_im = labels_im
        # region labels are kept as strings, they are the keys of all
        # per-region dicts returned by this class
        self._labels = np.array([str(i) for i in np.unique(labels_im)])

    def _sum_per_region(self, image):
        """Sum of the image values in every labelled region."""
        return labeled_comprehension(
            image, self._labels_im, self._labels, np.sum, float, 0
        )

    def _estimate_xi(self, result, data):
        """Compute the xi samples per region for one LIRA result.

        For every used image of the trace, xi = tau_1 / (tau_1 + tau_0),
        where tau_1 is the summed image and tau_0 the summed background in
        the region. tau_0 is multiplied by the background scale of the same
        sample if the background scale was fitted.

        Parameters
        ----------
        result : `~LIRADeconvolverResult`
            LIRA result
        data : dict
            Data dict, only the "background" image is read.

        Returns
        -------
        xi : dict
            Array of xi samples per region label.
        """
        config = result.config
        burnin = int(config["n_burn_in"])
        n_iter = int(config["n_iter_max"])
        # may be stored as bool, `range` needs a real integer step
        thin = int(config["save_thin"])
        fit_bkgscl = config["fit_background_scale"]

        parameter_trace = result.parameter_trace

        if "bkgScale" in parameter_trace.keys():
            bkg_scale_trace = parameter_trace["bkgScale"]
        else:
            bkg_scale_trace = np.ones(parameter_trace["iteration"].shape[0])

        image_trace = result.image_trace
        baseline_sum = self._sum_per_region(data["background"])

        # NOTE(review): the trace is indexed from 0, i.e. the first
        # `n_samples` stored images are used - confirm that this matches how
        # burn-in and thinning are reflected in the stored trace.
        n_samples = len(range(burnin, n_iter, thin))

        xi_regions = []

        for idx in range(n_samples):
            tau_1 = self._sum_per_region(image_trace[idx, :, :])

            tau_0 = baseline_sum
            if fit_bkgscl == 1:
                tau_0 = baseline_sum * bkg_scale_trace[idx]

            xi_regions.append(tau_1 / (tau_1 + tau_0))

        # each row is a distribution of xi for one region
        xi_regions = np.array(xi_regions).T
        return {label: xi_regions[idx] for idx, label in enumerate(self._labels)}

    def _estimate_test_statistic(self, tail, observed_dist):
        """Fraction of the observed xi samples at or above the tail value."""
        return (observed_dist >= tail).sum() / observed_dist.shape[0]

    def _estimate_pval_ul(self, gamma, test_stat):
        """
        Stein et al. (2015) eq. 22
        """
        return gamma / test_stat

    def estimate_p_values(self, data, gamma=0.005):
        """Estimate upper limits on the p-values for all regions.

        Parameters
        ----------
        data : dict
            Data dict, only the "background" image is read.
        gamma : float
            Upper tail probability of the null distribution used to define
            the threshold on xi.

        Returns
        -------
        p_value_ul : dict
            Upper limit on the p-value per region, computed as gamma divided
            by the test statistic.
        xi_dist_merged_replicates : dict
            xi samples of all replicates per region, concatenated.
        xi_dist_observed_im : dict
            xi samples of the observed image per region.
        tail_1_gamma : dict
            The (1 - gamma) percentile of the merged replicate samples per region.
        test_statistic : dict
            Fraction of observed xi samples at or above the percentile per region.

        Raises
        ------
        ValueError
            If no replicate results are available.
        """
        xi_dist_observed_im = self._estimate_xi(self._result_observed_im, data)

        xi_dist_replicates = [
            self._estimate_xi(result_replicate, data)
            for result_replicate in self._result_replicates
        ]

        if not xi_dist_replicates:
            raise ValueError(
                "At least one replicate result is required to estimate p-values."
            )

        xi_dist_merged_replicates = {
            label: np.concatenate(
                [xi_replicate[label] for xi_replicate in xi_dist_replicates]
            )
            for label in self._labels
        }

        # find the 1-gamma percentile
        tail_1_gamma = {
            k: np.percentile(v, (1 - gamma) * 100)
            for k, v in xi_dist_merged_replicates.items()
        }

        # find the number of values in the xi_dist_observed beyond these percentiles
        test_statistic = {
            k: self._estimate_test_statistic(v, xi_dist_observed_im[k])
            for k, v in tail_1_gamma.items()
        }

        # estimate upper limit on p-values
        p_value_ul = {
            k: self._estimate_pval_ul(gamma, v) for k, v in test_statistic.items()
        }

        return (
            p_value_ul,
            xi_dist_merged_replicates,
            xi_dist_observed_im,
            tail_1_gamma,
            test_statistic,
        )

    def _plot_xi(self, xi_dist, ax, ls="--", c="gray", tol=1e-10, label=None):
        """Plot a kernel density estimate of log10(xi) on the given axes."""
        from scipy import stats

        # floor the values at `tol` to keep the logarithm finite
        log_xi = np.log10(np.clip(xi_dist, tol, None))

        kernel = stats.gaussian_kde(log_xi)
        # xi <= 1, so log10(xi) = 0 is the upper end of the range
        eval_points = np.linspace(np.min(log_xi), 0, 100)
        ax.plot(eval_points, kernel(eval_points), ls=ls, c=c, label=label)

    def plot_xi_dist(self, xi_obs, xi_repl, region_id, figsize=(8, 5)):
        """
        Plot the posterior distributions of xi for a region

        Parameters
        ----------
        xi_obs : dict of `~numpy.ndarray`
            Posterior distribution of xi for the observation, per region
        xi_repl : dict of `~numpy.ndarray`
            Posterior distribution of xi for all the replicates, per region
        region_id : int or str
            Label of the region. If it is not a key of ``xi_obs`` its string
            representation is used.
        figsize : tuple
            Figure size

        Returns
        -------
        ax : `~matplotlib.pyplot.Axes`
            Plot axes
        """
        import matplotlib.pyplot as plt

        # the dicts from `estimate_p_values` are keyed by the label as string
        key = region_id if region_id in xi_obs else str(region_id)
        xi_obs_region = xi_obs[key]
        xi_repl_region = xi_repl[key]

        fig, ax = plt.subplots(1, 1, figsize=figsize)

        n_iters = xi_obs_region.shape[0]
        n_replicates = int(xi_repl_region.shape[0] / n_iters)

        # plot the replicate distribution, one chunk of samples per replicate
        for start in range(0, n_replicates * n_iters, n_iters):
            self._plot_xi(xi_repl_region[start : start + n_iters], ax)

        # plot the mean distribution
        self._plot_xi(
            xi_repl_region, ax, ls="-", c="black", label="Mean null distribution"
        )

        # plot the observed distribution
        self._plot_xi(
            xi_obs_region, ax, ls="-", c="blue", label="Best fit distribution"
        )

        ax.set_xlabel(r"Posterior distribution (log$_{10}\xi$)")
        ax.set_ylabel("Density")
        ax.set_title(f"Region: {region_id}")
        ax.legend()

        return ax