import numpy as np
from scipy.stats import pearsonr

from .._shared.utils import check_shape_equality, as_binary_ndarray

# Names exported by ``from <module> import *``: the four colocalization
# metrics defined below.
__all__ = [
    'pearson_corr_coeff',
    'manders_coloc_coeff',
    'manders_overlap_coeff',
    'intersection_coeff',
]


def pearson_corr_coeff(image0, image1, mask=None):
    r"""Calculate Pearson's Correlation Coefficient between pixel intensities
    in channels.

    Parameters
    ----------
    image0 : (M, N) ndarray
        Image of channel A.
    image1 : (M, N) ndarray
        Image of channel B, to be correlated with channel A.
        Must have same dimensions as `image0`.
    mask : (M, N) ndarray of dtype bool, optional
        Region of interest. When given, only the `image0` and `image1` pixels
        inside this mask enter the calculation. Must have same dimensions as
        `image0`.

    Returns
    -------
    pcc : float
        Pearson's correlation coefficient of the pixel intensities between
        the two images, within the mask if provided.
    p-value : float
        Two-tailed p-value.

    Notes
    -----
    Pearson's Correlation Coefficient (PCC) measures the linear correlation
    between the pixel intensities of the two images. Its value ranges from -1
    for perfect linear anti-correlation to +1 for perfect linear correlation.
    The calculation of the p-value assumes that the intensities of pixels in
    each input image are normally distributed.

    Scipy's implementation of Pearson's correlation coefficient is used. Please
    refer to it for further information and caveats [1]_.

    .. math::
        r = \frac{\sum (A_i - m_A) (B_i - m_B)}
        {\sqrt{\sum (A_i - m_A)^2 \sum (B_i - m_B)^2}}

    where

    - :math:`A_i` is the value of the :math:`i^{th}` pixel in `image0`,
    - :math:`B_i` is the value of the :math:`i^{th}` pixel in `image1`,
    - :math:`m_A` is the mean of the pixel values in `image0`, and
    - :math:`m_B` is the mean of the pixel values in `image1`.

    A low PCC value does not necessarily mean that there is no correlation
    between the two channel intensities, just that there is no linear
    correlation. You may wish to plot the pixel intensities of each of the two
    channels in a 2D scatterplot and use Spearman's rank correlation if a
    non-linear correlation is visually identified [2]_. Also consider if you
    are interested in correlation or co-occurrence, in which case a method
    involving segmentation masks (e.g. MCC or intersection coefficient) may be
    more suitable [3]_ [4]_.

    Providing the mask of only relevant sections of the image (e.g., cells, or
    particular cellular compartments) and removing noise is important as the
    PCC is sensitive to these measures [3]_ [4]_.

    References
    ----------
    .. [1] https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.pearsonr.html
    .. [2] https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.spearmanr.html
    .. [3] Dunn, K. W., Kamocka, M. M., & McDonald, J. H. (2011). A practical
           guide to evaluating colocalization in biological microscopy.
           American journal of physiology. Cell physiology, 300(4), C723–C742.
           https://doi.org/10.1152/ajpcell.00462.2010
    .. [4] Bolte, S. and Cordelières, F.P. (2006), A guided tour into
           subcellular colocalization analysis in light microscopy. Journal of
           Microscopy, 224: 213-232.
           https://doi.org/10.1111/j.1365-2818.2006.01706.x
    """
    image0 = np.asarray(image0)
    image1 = np.asarray(image1)

    if mask is None:
        check_shape_equality(image0, image1)
        # pearsonr works on 1-D samples, so flatten both channels.
        pixels0 = image0.reshape(-1)
        pixels1 = image1.reshape(-1)
    else:
        mask = as_binary_ndarray(mask, variable_name="mask")
        check_shape_equality(image0, image1, mask)
        # Boolean indexing already yields 1-D arrays of the selected pixels.
        pixels0 = image0[mask]
        pixels1 = image1[mask]

    pcc, p_value = pearsonr(pixels0, pixels1)
    return float(pcc), float(p_value)


def manders_coloc_coeff(image0, image1_mask, mask=None):
    r"""Manders' colocalization coefficient between two image channels.

    Parameters
    ----------
    image0 : (M, N) ndarray
        Input image (first channel). All pixel values should be non-negative.
    image1_mask : (M, N) ndarray of dtype bool
        Binary image giving the regions of interest in the second channel.
        Must have same shape as `image0`.
    mask : (M, N) ndarray of dtype bool, optional
        Only `image0` pixel values within `mask` are included in the calculation.
        Must have same shape as `image0`.

    Returns
    -------
    mcc : float
        Manders' colocalization coefficient. ``0.0`` is returned when the
        total intensity of the considered `image0` pixels is zero, or when no
        pixels are considered at all (e.g., `mask` is entirely ``False``).

    Raises
    ------
    ValueError
        If any considered pixel of `image0` is negative.

    Notes
    -----
    Manders' colocalization coefficient (MCC) was developed in the context of
    confocal biological microscopy, to measure the fraction of colocalizing
    objects in each component of a dual-channel image. Out of the total
    intensity of, say, channel A, how much is found within the features
    (objects) of, say, channel B [1]_? The measure thus ranges from 0 for no
    colocalization to 1 for complete colocalization.

    MCC is commonly used to measure the colocalization of a particular protein
    in a subcellular compartment. Typically, the mask for channel B is
    obtained by thresholding, to segment the features from the background.
    In this implementation, channel B is passed directly as a mask
    (`image1_mask`), leaving the segmentation step to the user (upstream).

    The implemented equation is:

    .. math::

       mcc = \frac{\sum_i A_{i,coloc}}{\sum_i A_i}

    where

    - :math:`A_i` is the value of the :math:`i^{th}` pixel in `image0`, and
    - :math:`A_{i, coloc} = A_i B_i`, considering that :math:`B_i` is the
      (``True`` or ``False``) value of the :math:`i^{th}` pixel in
      `image1_mask` cast into int or float (``1`` or ``0``, respectively).

    MCC is sensitive to noise, with diffuse signal in the first channel
    inflating its value. Therefore, images should be processed beforehand to
    remove out-of-focus and background light [2]_.

    References
    ----------
    .. [1] Manders, E.M.M., Verbeek, F.J. and Aten, J.A. (1993), Measurement of
           co-localization of objects in dual-colour confocal images. Journal
           of Microscopy, 169: 375-382.
           https://doi.org/10.1111/j.1365-2818.1993.tb03313.x
           https://imagej.net/media/manders.pdf
    .. [2] Dunn, K. W., Kamocka, M. M., & McDonald, J. H. (2011). A practical
           guide to evaluating colocalization in biological microscopy.
           American journal of physiology. Cell physiology, 300(4), C723–C742.
           https://doi.org/10.1152/ajpcell.00462.2010

    """
    image0 = np.asarray(image0)
    image1_mask = as_binary_ndarray(image1_mask, variable_name="image1_mask")
    if mask is not None:
        mask = as_binary_ndarray(mask, variable_name="mask")
        check_shape_equality(image0, image1_mask, mask)
        image0 = image0[mask]
        image1_mask = image1_mask[mask]
    else:
        check_shape_equality(image0, image1_mask)

    # Nothing selected: there is no intensity to colocalize. Handling this
    # here also keeps ``min()`` below from failing on a zero-size array.
    if image0.size == 0:
        return 0.0

    # check non-negative image
    if image0.min() < 0:
        raise ValueError("image contains negative values")

    total_intensity = np.sum(image0)
    if total_intensity == 0:
        # Avoid 0/0; an all-zero channel has no colocalizing signal.
        return 0.0
    coloc_intensity = np.sum(image0 * image1_mask)
    # Return a plain Python float, as the other coefficients in this module do.
    return float(coloc_intensity / total_intensity)


def manders_overlap_coeff(image0, image1, mask=None):
    r"""Manders' overlap coefficient

    Parameters
    ----------
    image0 : (M, N) ndarray
        Image of channel A. All pixel values should be non-negative.
    image1 : (M, N) ndarray
        Image of channel B. All pixel values should be non-negative.
        Must have same dimensions as `image0`
    mask : (M, N) ndarray of dtype bool, optional
        Only `image0` and `image1` pixel values within this region of interest
        mask are included in the calculation.
        Must have same dimensions as `image0`.

    Returns
    -------
    moc : float
        Manders' Overlap Coefficient of pixel intensities between the two
        images. ``0.0`` is returned when the denominator of the equation
        below is zero (at least one channel is entirely zero within the
        considered region) or when no pixels are considered at all (e.g.,
        `mask` is entirely ``False``).

    Raises
    ------
    ValueError
        If any considered pixel of `image0` or `image1` is negative.

    Notes
    -----
    Manders' Overlap Coefficient (MOC) is given by the equation [1]_:

    .. math::
        r = \frac{\sum A_i B_i}{\sqrt{\sum A_i^2 \sum B_i^2}}

    where

    - :math:`A_i` is the value of the :math:`i^{th}` pixel in `image0`, and
    - :math:`B_i` is the value of the :math:`i^{th}` pixel in `image1`.

    It ranges between 0 for no colocalization and 1 for complete colocalization
    of all pixels.

    The products and squares are computed in double precision, so integer
    images (e.g., ``uint8`` or ``uint16``) do not wrap around.

    MOC does not take into account pixel intensities, just the fraction of
    pixels that have positive values for both channels [2]_ [3]_. Its
    usefulness has been criticized as it changes in response to differences in
    both co-occurrence and correlation and so a particular MOC value could
    indicate a wide range of colocalization patterns [4]_ [5]_.

    References
    ----------
    .. [1] Manders, E.M.M., Verbeek, F.J. and Aten, J.A. (1993), Measurement of
           co-localization of objects in dual-colour confocal images. Journal
           of Microscopy, 169: 375-382.
           https://doi.org/10.1111/j.1365-2818.1993.tb03313.x
           https://imagej.net/media/manders.pdf
    .. [2] Dunn, K. W., Kamocka, M. M., & McDonald, J. H. (2011). A practical
           guide to evaluating colocalization in biological microscopy.
           American journal of physiology. Cell physiology, 300(4), C723–C742.
           https://doi.org/10.1152/ajpcell.00462.2010
    .. [3] Bolte, S. and Cordelières, F.P. (2006), A guided tour into
           subcellular colocalization analysis in light microscopy. Journal of
           Microscopy, 224: 213-232.
           https://doi.org/10.1111/j.1365-2818.2006.01706.x
    .. [4] Adler J, Parmryd I. (2010), Quantifying colocalization by
           correlation: the Pearson correlation coefficient is
           superior to the Mander's overlap coefficient. Cytometry A.
           Aug;77(8):733-42. https://doi.org/10.1002/cyto.a.20896
    .. [5] Adler, J, Parmryd, I. Quantifying colocalization: The case for
           discarding the Manders overlap coefficient. Cytometry. 2021; 99:
           910– 920. https://doi.org/10.1002/cyto.a.24336

    """
    image0 = np.asarray(image0)
    image1 = np.asarray(image1)
    if mask is not None:
        mask = as_binary_ndarray(mask, variable_name="mask")
        check_shape_equality(image0, image1, mask)
        image0 = image0[mask]
        image1 = image1[mask]
    else:
        check_shape_equality(image0, image1)

    # Nothing selected: no overlap can be measured. Handling this here also
    # keeps ``min()`` below from failing on a zero-size array.
    if image0.size == 0:
        return 0.0

    # check non-negative image
    if image0.min() < 0:
        raise ValueError("image0 contains negative values")
    if image1.min() < 0:
        raise ValueError("image1 contains negative values")

    # Force the element-wise products into float64. Without this, NumPy keeps
    # the input dtype, so e.g. squaring a uint8 pixel of value 16 wraps to 0
    # and the coefficient comes out wrong (possibly even greater than 1).
    sum_sq0 = np.sum(np.square(image0, dtype=np.float64))
    sum_sq1 = np.sum(np.square(image1, dtype=np.float64))
    denom = (sum_sq0 * sum_sq1) ** 0.5
    if denom == 0:
        # At least one channel is all zeros; report no overlap instead of the
        # NaN (and RuntimeWarning) that 0/0 would produce.
        return 0.0

    cross = np.sum(np.multiply(image0, image1, dtype=np.float64))
    # Return a plain Python float, as the other coefficients in this module do.
    return float(cross / denom)


def intersection_coeff(image0_mask, image1_mask, mask=None):
    r"""Fraction of a channel's segmented binary mask that overlaps with a
    second channel's segmented binary mask.

    Parameters
    ----------
    image0_mask : (M, N) ndarray of dtype bool
        Image mask of channel A.
    image1_mask : (M, N) ndarray of dtype bool
        Image mask of channel B.
        Must have same dimensions as `image0_mask`.
    mask : (M, N) ndarray of dtype bool, optional
        Only `image0_mask` and `image1_mask` pixels within this region of
        interest mask are included in the calculation.
        Must have same dimensions as `image0_mask`.

    Returns
    -------
    intersection_coefficient : float
        Fraction of `image0_mask` that overlaps with `image1_mask`:
        the number of pixels set in both masks divided by the number of
        pixels set in `image0_mask`. ``0.0`` is returned when `image0_mask`
        has no pixel set within the considered region.

    """
    image0_mask = as_binary_ndarray(image0_mask, variable_name="image0_mask")
    image1_mask = as_binary_ndarray(image1_mask, variable_name="image1_mask")
    if mask is not None:
        mask = as_binary_ndarray(mask, variable_name="mask")
        check_shape_equality(image0_mask, image1_mask, mask)
        image0_mask = image0_mask[mask]
        image1_mask = image1_mask[mask]
    else:
        check_shape_equality(image0_mask, image1_mask)

    nonzero_image0 = np.count_nonzero(image0_mask)
    if nonzero_image0 == 0:
        # Avoid 0/0; return a float so the result type matches the
        # non-degenerate path below.
        return 0.0
    nonzero_joint = np.count_nonzero(np.logical_and(image0_mask, image1_mask))
    return nonzero_joint / nonzero_image0
