"""
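Use labelled images from the compaq_labels directory to construct a
normal-gamma belief over a Gaussian distribution of hair color in
YPbPr color space.

Usage:    --input_list FILE --label_list FILE [--quiet]
          Each FILE lists one image path per line; the input and label
          lists are paired line by line.
Requires: Compaq skin database, labels in compaq_labels/
Creates:  storage/color_prior_hair_continuous.data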

Copyright (c) 2011 Idiap Research Institute, http://www.idiap.ch/
Written by Carl Scheffler <carl.scheffler@gmail.com>

This file is part of FaceColorModel.

FaceColorModel is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License version 3 as
published by the Free Software Foundation.

FaceColorModel is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
General Public License for more details.

You should have received a copy of the GNU General Public License
along with FaceColorModel. If not, see <http://www.gnu.org/licenses/>.
"""
from __future__ import division
from scipy import *
import os, sys
from colorlib import rgb2ypbpr, load_image
from normal_gamma_distribution import compute_normal_gamma_parameters_from_sufficient_stats

# Label images mark hair pixels with exactly this RGB color.
hairLabelColor = array([255, 0, 0], dtype=uint8)


def _argument_value(flag, error_message):
    """Return the command-line argument that immediately follows `flag`.

    Exits the script with `error_message` on stderr and a non-zero exit
    status when `flag` is absent from sys.argv, or when it is the last
    argument and so has no value after it.
    """
    try:
        return sys.argv[sys.argv.index(flag) + 1]
    except (ValueError, IndexError):    # flag missing / flag without a value
        sys.exit(error_message)


# Read command line arguments
inputListFilename = _argument_value("--input_list", "ERROR: input list not specified")
labelListFilename = _argument_value("--label_list", "ERROR: label list not specified")
verbose = ("--quiet" not in sys.argv[1:])    # Verbosity

# Read training lists: one image path per line, with the input and label
# lists paired line by line. Surrounding whitespace (including the '\r'
# left by CRLF line endings) and blank lines are ignored.
def _read_path_list(filename):
    """Return the whitespace-stripped, non-blank lines of `filename`."""
    with open(filename, 'rt') as fp:
        return [line.strip() for line in fp if line.strip()]


inputList = _read_path_list(inputListFilename)
labelList = _read_path_list(labelListFilename)
if len(inputList) != len(labelList):
    # sys.exit(str) reports on stderr and exits with a non-zero status.
    sys.exit("Error: input and label lists are not of equal length (%d vs %d)"
             % (len(inputList), len(labelList)))

# For every labelled image, compute the sample mean and standard deviation
# of its hair pixels along each of the Y, Pb and Pr color axes.
# mu[j] / sigma[j] collect, for color channel j (0=Y, 1=Pb, 2=Pr), one value
# per image that contributed.
mu = [[], [], []]
sigma = [[], [], []]
numImages = len(inputList)
for i, (inputFilename, labelFilename) in enumerate(zip(inputList, labelList)):
    if verbose and ((i + 1) % 100 == 0):
        print("%d / %d" % (i + 1, numImages))

    # Read label image and extract the hair mask. Image channels are in BGR
    # order, so they are reversed to RGB before comparing with
    # hairLabelColor; slicing 2::-1 also ignores a trailing alpha channel.
    labelImage = load_image(labelFilename)
    if labelImage is None:
        if verbose:
            print("Warning: Could not read label image: %s" % labelFilename)
        continue
    isHair = (array(labelImage, dtype=uint8)[:, :, 2::-1] == hairLabelColor).all(axis=2)

    # Read color input image
    inputImage = load_image(inputFilename)
    if inputImage is None:
        if verbose:
            print("Warning: Could not read input image: %s" % inputFilename)
        continue
    inputArray = array(inputImage)
    if inputArray.ndim != 3 or inputArray.shape[:2] != isHair.shape:
        # Without this check a size mismatch aborts the whole run.
        if verbose:
            print("Warning: Input image does not match label image size: %s" % inputFilename)
        continue
    if not isHair.any():
        if verbose:
            print("No hair pixels in label image: %s" % labelFilename)
        continue

    # Hair pixels as a 3 x N array with rows R, G, B (input is BGR).
    rgb = inputArray[:, :, 2::-1][isHair].T
    ypbpr = rgb2ypbpr(rgb/255)
    channelMeans = [mean(ypbpr[j, :]) for j in range(3)]
    channelStds = [std(ypbpr[j, :]) for j in range(3)]
    if min(channelStds) == 0:
        # Zero spread (e.g. a single hair pixel) would turn into an infinite
        # precision (std**-2) in the sufficient statistics below.
        if verbose:
            print("Warning: Hair pixels have zero variance in a color channel: %s" % labelFilename)
        continue
    for j in range(3):
        mu[j].append(channelMeans[j])
        sigma[j].append(channelStds[j])

# Turn the per-image statistics into a 6 x numImages matrix whose rows
# alternate [mean, precision] for the Y, Pb and Pr channels (one column
# per contributing image), then fit the normal-gamma prior to it.
if len(mu[0]) == 0:
    # Nothing to fit: fail clearly instead of passing an empty matrix on.
    sys.exit("ERROR: no usable hair pixels found in any labelled image")
if verbose:
    print("Computing normal-gamma distribution from YPbPr means and variances...")
colorStats = empty((6, len(mu[0])), dtype=float)
colorStats[0::2, :] = array(mu)            # rows 0, 2, 4: means
colorStats[1::2, :] = array(sigma)**-2     # rows 1, 3, 5: precisions (1/std^2)
hairColorPrior = compute_normal_gamma_parameters_from_sufficient_stats(colorStats)

# Write the prior's raw data buffer (no header) to the storage directory,
# creating that directory first if needed so that open() does not fail.
# NOTE(review): assumes hairColorPrior is a numpy array (uses its .data buffer).
outputFilename = "storage/color_prior_hair_continuous.data"
outputDir = os.path.dirname(outputFilename)
if not os.path.isdir(outputDir):
    os.makedirs(outputDir)
if verbose:
    print("Saving...")
with open(outputFilename, "wb") as fp:
    fp.write(hairColorPrior.data)
