"""
Use skin images from the Compaq skin database to construct a
normal-gamma belief over a Gaussian distribution of skin color in
YPbPr color space.
Requires: Compaq skin database at path compaq_database/
Creates:  storage/color_prior_skin_continuous.data

Copyright (c) 2011 Idiap Research Institute, http://www.idiap.ch/
Written by Carl Scheffler <carl.scheffler@gmail.com>

This file is part of FaceColorModel.

FaceColorModel is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License version 3 as
published by the Free Software Foundation.

FaceColorModel is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
General Public License for more details.

You should have received a copy of the GNU General Public License
along with FaceColorModel. If not, see <http://www.gnu.org/licenses/>.
"""
from __future__ import division
from scipy import *
from opencv import highgui
import os, sys
from viola_jones_opencv import viola_jones_opencv
from colorlib import rgb2ypbpr, load_image
from normal_gamma_distribution import compute_normal_gamma_parameters_from_sufficient_stats

basePath = 'compaq_database/'

# Read command line arguments
verbose = ("--quiet" not in sys.argv[1:])

# Read mask list. Each non-blank line of the list file names one mask,
# stored on disk as a PBM file. splitlines() accepts both LF and CRLF
# line endings; a plain split('\n') would leave a stray '\r' on every
# name read from a CRLF file, and every subsequent file open would fail.
maskPath = basePath + 'masks/'
with open(basePath + 'mask-list', 'rt') as fp:
    maskList = [line.strip() + '.pbm'
                for line in fp.read().splitlines() if line.strip()]
maskList.sort()

# Read skin image list (same format, names already carry their extension).
skinPath = basePath + 'skin-images/'
with open(basePath + 'skin-list-good', 'rt') as fp:
    skinList = [line.strip()
                for line in fp.read().splitlines() if line.strip()]
skinList.sort()

# Sanity check that the lists pair up one-to-one: mask i must belong to
# skin image i. Names are compared up to their first '.', so "foo.pbm"
# matches "foo.jpg". split('.', 1)[0] also handles names without a dot
# (str.find would return -1 and chop off the last character). This is an
# explicit check rather than an assert so it still runs under "python -O";
# a mismatch would otherwise silently pair masks with the wrong images.
maskStems = [name.split('.', 1)[0] for name in maskList]
skinStems = [name.split('.', 1)[0] for name in skinList]
if maskStems != skinStems:
    raise ValueError("mask-list and skin-list-good do not name the same "
                     "images; cannot pair masks with skin images")

# Compute the sample mean and standard deviation along each of the Y,
# Pb and Pr color axes of the skin pixels in each image in the Compaq
# database.
mu = [[], [], []]
sigma = [[], [], []]
for imageIndex in range(len(maskList)):
    if verbose and ((imageIndex+1)%100 == 0):
        print imageIndex+1, '/', len(maskList)

    maskImage = load_image(maskPath + maskList[imageIndex],
                           highgui.CV_LOAD_IMAGE_GRAYSCALE)
    if maskImage is None:
        if verbose:
            print "Mask is None:", maskList[imageIndex]
        continue

    grayImage = load_image(skinPath + skinList[imageIndex],
                           highgui.CV_LOAD_IMAGE_GRAYSCALE)
    if grayImage is None:
        if verbose:
            print "Skin is None:", skinList[imageIndex]
        continue

    face = viola_jones_opencv(grayImage)
    if face is None:
        continue
    
    mask = ravel(array(maskImage)[face[0]:face[2], face[1]:face[3]]) > 0
    skin = array(load_image(skinPath + skinList[imageIndex]))[face[0]:face[2], face[1]:face[3], :]
    blue  = compress(mask, ravel(skin[:,:,0]))
    if len(blue) == 0:
        if verbose:
            print 'No skin pixels in image:', skinList[imageIndex]
        continue
    if len(blue) < 12:
        # Not enough skin pixels for reliable maximum likelihood
        # Gaussian estimate
        continue
    green = compress(mask, ravel(skin[:,:,1]))
    red   = compress(mask, ravel(skin[:,:,2]))
    rgb = vstack((red, green, blue))
    ypbpr = rgb2ypbpr(rgb/255)

    # Check for outliers: if the variance in color space is very small
    # it usually means that we have a grayscale image or something
    # close to it.
    passed = True
    for j in range(3):
        if var(ypbpr[j,:]) < 1./20000:
            passed = False
            break
    if not passed:
        continue

    # Store statistics: mean and standard deviation
    for j in range(3):
        mu[j].append(mean(ypbpr[j,:]))
        sigma[j].append(std(ypbpr[j,:]))

if verbose:
    print "Computing normal-gamma distribution from YPbPr means and variances..."
colorStats = empty((6, len(mu[0])), dtype=float) # Each column is [mean, precision] * 3
                                                 # (for the 3 color channels)
for j in range(3):
    colorStats[0+j*2,:] = array(mu[j])
    colorStats[1+j*2,:] = array(sigma[j])**-2
skinColorPrior = compute_normal_gamma_parameters_from_sufficient_stats(colorStats)

if verbose:
    print "Saving..."
with open("storage/color_prior_skin_continuous.data", "wb") as fp:
    fp.write(skinColorPrior.data)
