Source code for aiapy.psf.deconvolve

"""
Deconvolve an AIA image with the channel point spread function.
"""

import copy
import warnings

from numpy import asarray

from aiapy.psf import _jit_over_iterations, lax, np
from aiapy.psf.psf import calculate_psf
from aiapy.utils import AIApyUserWarning

__all__ = ["deconvolve"]


@_jit_over_iterations
def _rl_deconvolve(img, psf_fft, psf_conj, *, iterations: int):
    """
    Run ``iterations`` Richardson-Lucy update steps on ``img``.

    Parameters
    ----------
    img : array
        The observed image. It is also used as the initial estimate.
    psf_fft : array
        ``rfft2`` of the point spread function, already shifted so that its
        center sits at pixel (0, 0) (this is how `deconvolve` prepares it).
    psf_conj : array
        Complex conjugate of ``psf_fft``. Multiplying by it in Fourier space
        correlates with the PSF, which is what the RL back-projection needs.
    iterations : int
        Number of Richardson-Lucy iterations (keyword-only).

    Returns
    -------
    array
        The deconvolved estimate, with the same shape as ``img``.
    """
    # irfft2 cannot infer an odd output length from a half-spectrum, so the
    # output shape must be passed explicitly; otherwise odd-width images come
    # back one column short and ``img / est`` fails to broadcast.
    shape = img.shape

    def body(_, current):
        # Forward model: current estimate convolved with the PSF.
        est = np.fft.irfft2(np.fft.rfft2(current) * psf_fft, s=shape)
        ratio = img / est
        # Back-project the ratio (correlation with the PSF) to form the
        # multiplicative correction for this step.
        update = np.fft.irfft2(np.fft.rfft2(ratio) * psf_conj, s=shape)
        return current * update

    return lax.fori_loop(0, iterations, body, img)


def deconvolve(
    smap,
    *,
    psf: np.ndarray | None = None,
    iterations: int = 25,
    clip_negative: bool = True,
):
    """
    Deconvolve an AIA image with the point spread function.

    Perform image deconvolution on an AIA image with the instrument point
    spread function using the Richardson-Lucy deconvolution algorithm [1]_.

    .. note:: If the jax package is installed it will be used to accelerate
              the computation. jax can use CPUs or GPUs, `see their
              documentation for instructions
              <https://docs.jax.dev/en/latest/installation.html>`__.
              For more information on PSF deconvolution on a GPU, see [2]_.

    Parameters
    ----------
    smap : `~sunpy.map.Map`
        An AIA image.
    psf : array-like, optional
        The point spread function. It must have the same shape as the image
        data and is assumed to be centered at pixel ``(ny // 2, nx // 2)``.
        Defaults to `None` and it will be calculated with
        `aiapy.psf.calculate_psf`.
    iterations : `int`, optional
        Number of iterations in the Richardson-Lucy algorithm, defaults to 25.
    clip_negative : `bool`, optional
        If the image has negative intensity values, set them to zero.
        Defaults to `True`. If `False` and negative values are present, an
        `~aiapy.utils.AIApyUserWarning` is emitted.

    Returns
    -------
    `~sunpy.map.Map`
        Deconvolved AIA image

    Raises
    ------
    ValueError
        If the shape of ``psf`` does not match the shape of the image data.

    See Also
    --------
    calculate_psf

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Richardson%E2%80%93Lucy_deconvolution
    .. [2] Cheung, M., 2015, *GPU Technology Conference Silicon Valley*,
           `GPU-Accelerated Image Processing for NASA's Solar Dynamics Observatory
           <https://on-demand.gputechconf.com/gtc/2015/presentation/S5209-Mark-Cheung.pdf>`__
    """
    # TODO: Should check to make sure this is a full-frame image?
    # We need to promote to float64 for JAX only
    img = smap.data.astype(np.float64) if np.__name__ == "jax.numpy" else smap.data
    if clip_negative:
        img = np.where(img < 0, 0, img)
    elif np.any(img < 0):
        # Once clipped, no negative values can remain, so the check (and
        # warning) is only needed when clipping was not requested.
        warnings.warn(
            "Image contains negative intensity values. Consider setting clip_negative to True",
            AIApyUserWarning,
            stacklevel=2,
        )
    if psf is None:
        psf = calculate_psf(smap.wavelength)
    # Accept any array-like (e.g. nested lists), which jax.numpy functions
    # would otherwise reject.
    psf = np.asarray(psf)
    if psf.shape != img.shape:
        msg = f"PSF shape {psf.shape} does not match image shape {img.shape}."
        raise ValueError(msg)
    # Center PSF at pixel (0,0). ifftshift moves index N // 2 to 0 for both
    # even and odd sizes; fftshift is off by one pixel for odd sizes and is
    # identical to ifftshift for even sizes.
    psf = np.fft.ifftshift(psf)
    # Convolution is done in Fourier space: precompute the real FFT of the
    # PSF and its conjugate (used for the correlation step of each update).
    psf = np.fft.rfft2(psf)
    psf_conj = np.conj(psf)
    img_decon = _rl_deconvolve(img, psf, psf_conj, iterations=iterations)
    return smap._new_instance(
        # Always convert back to numpy array
        asarray(img_decon),
        copy.deepcopy(smap.meta),
        plot_settings=copy.deepcopy(smap.plot_settings),
        mask=smap.mask,
    )