"""
Copyright (c) 2011 Idiap Research Institute, http://www.idiap.ch/
Written by Carl Scheffler <carl.scheffler@gmail.com>

This file is part of FaceColorModel.

FaceColorModel is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License version 3 as
published by the Free Software Foundation.

FaceColorModel is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
General Public License for more details.

You should have received a copy of the GNU General Public License
along with FaceColorModel. If not, see <http://www.gnu.org/licenses/>.
"""
from __future__ import division
import numpy as np
import scipy as S

def render_discrete_color_distribution(iDistribution, iOutputSize=(64,64)):
    """
    Visualize a discrete distribution over RGB colours.

    The input is a 3-d histogram of RGB colours with the same number of
    bins B per channel. It is given either as a (B, B, B) array or ravel'ed
    (C order) into a 1-d array of length B**3. Axis 0 is treated as red,
    axis 1 as green and axis 2 as blue. The histogram does not need to be
    normalized.

    Returns a colour image of shape tuple(iOutputSize) + (3,). For the
    default 2-d size this is an MxNx3 array that can be displayed using
    pylab.imshow().
    - Pixels are filled in row-major order, and the colour with greatest
      probability is shown first.
    - The number of pixels devoted to each colour is proportional to its
      probability mass.
    - Each colour is drawn at the centre of its bin, (bin + 0.5) / B, so
      all values lie in (0, 1).

    Raises ValueError if the input is not a 1-d cubic-length or cubic 3-d
    histogram, contains negative entries, or has no positive total mass.
    """
    distribution = np.asarray(iDistribution, dtype=float)
    if distribution.ndim == 1:
        # A ravel'ed histogram must have a perfect-cube number of bins.
        binsPerChannel = int(round(distribution.shape[0] ** (1/3)))
        if distribution.shape[0] != binsPerChannel**3:
            raise ValueError("a 1-d colour histogram must have a cubic number "
                             "of bins, got %d" % distribution.shape[0])
    elif (distribution.ndim == 3 and
          distribution.shape[0] == distribution.shape[1] == distribution.shape[2]):
        binsPerChannel = distribution.shape[0]
        distribution = distribution.ravel()
    else:
        raise ValueError("expected a 1-d or cubic 3-d colour histogram, got "
                         "shape %r" % (distribution.shape,))

    if np.any(distribution < 0):
        raise ValueError("colour histogram must not contain negative entries")
    totalMass = distribution.sum()
    if not totalMass > 0:
        # Also rejects NaN totals; normalizing would yield NaN pixel indices.
        raise ValueError("colour histogram must have positive total mass")
    distribution = distribution / totalMass

    outputSize = tuple(iOutputSize)
    numPixels = int(np.prod(outputSize))

    # Colour bins in order of decreasing probability.
    sortedIndices = np.argsort(distribution)[::-1]
    # Pixel index at which each colour's run ends. This is the running
    # probability mass scaled to the pixel count and rounded, so consecutive
    # runs tile the image without gaps or overlaps.
    stopIndices = np.round(np.cumsum(distribution[sortedIndices]) * numPixels).astype(int)
    # Floating point round-off must neither overshoot the image nor leave
    # trailing pixels unassigned.
    stopIndices = np.minimum(stopIndices, numPixels)
    stopIndices[-1] = numPixels

    pixelCounts = np.diff(np.concatenate(([0], stopIndices)))
    pixelBins = np.repeat(sortedIndices, pixelCounts)
    # Split each flat bin index into (red, green, blue) bin indices. This
    # matches the C-order ravel of a (B, B, B) histogram.
    binCoordinates = np.column_stack(np.unravel_index(pixelBins, (binsPerChannel,) * 3))
    pixels = (binCoordinates + 0.5) / binsPerChannel
    return pixels.reshape(outputSize + (3,))

def render_continuous_color_distribution(iParameters, *args, **kwargs):
    """
    Visualize a continuous colour distribution by Monte Carlo sampling.

    The function works in four steps:
    1. Draw colour samples from a per-channel hierarchical model in YPbPr
       space and convert them to RGB with colorlib.ypbpr2rgb.
    2. Keep only samples whose three RGB coordinates all lie in [0, 1).
       Out-of-gamut samples are redrawn until 2000 samples have been
       accepted.
    3. Histogram the accepted samples into 16 bins per RGB channel.
    4. Render the histogram with render_discrete_color_distribution().

    For each YPbPr channel i in (0, 1, 2), iParameters[4*i : 4*i + 4] is
    read as (m, k, a, b). The values are sampled as
        lambda ~ scipy_gamma_samples(a, b)
        mu     ~ Normal(mean=m, std=(k * lambda) ** -0.5)
        x      ~ Normal(mean=mu, std=lambda ** -0.5), clipped to [0, 1]
    This looks like a normal-gamma model per channel. The gamma
    parameterization is whatever gamma.scipy_gamma_samples uses (TODO
    confirm). iParameters is presumably a length-12 numpy array; it must
    support extended slicing and broadcasting against an (n, 3) array.

    Extra positional/keyword arguments (e.g. iOutputSize) are passed through
    to render_discrete_color_distribution(), and its image is returned.

    NOTE(review): the loop does not terminate if the parameters (almost)
    never produce in-gamut RGB colours.
    """
    from gamma import scipy_gamma_samples
    from colorlib import ypbpr2rgb

    BIN_COUNT = 16
    numSamples = 2000
    binShape = (BIN_COUNT,) * 3
    distribution = np.zeros(BIN_COUNT**3, dtype=int)
    # Per-channel prior means and precision scales; invariant across batches.
    means = iParameters[0::4]
    precisionScales = iParameters[1::4]
    toSample = numSamples
    while toSample > 0:
        lambda_ = np.empty((toSample, 3), dtype=float)
        for i in range(3):
            lambda_[:, i] = scipy_gamma_samples(iParameters[2+i*4], iParameters[3+i*4], size=toSample)
        mu = np.random.normal(means, (precisionScales*lambda_)**-0.5, size=(toSample, 3))
        # NOTE(review): all three YPbPr coordinates are clipped to [0, 1];
        # confirm this matches the Pb/Pr range ypbpr2rgb expects.
        ypbpr = np.random.normal(mu, lambda_**-0.5).clip(0, 1).transpose()
        # ypbpr2rgb receives a (3, n) array; its result is indexed as (3, n).
        rgb = np.asarray(ypbpr2rgb(ypbpr))
        # Accept a sample only if every RGB coordinate is in [0, 1). NaNs are
        # rejected too, so the bin indices below are always in range.
        inGamut = np.all((rgb >= 0) & (rgb < 1), axis=0)
        bins = (rgb[:, inGamut] * BIN_COUNT).astype(int)
        flatBins = np.ravel_multi_index(tuple(bins), binShape)
        distribution += np.bincount(flatBins, minlength=BIN_COUNT**3)
        toSample -= bins.shape[1]

    return render_discrete_color_distribution(distribution, *args, **kwargs)
