"""
Adapt the face color model to a single input image and produce a
labeled output image.

Copyright (c) 2011 Idiap Research Institute, http://www.idiap.ch/
Written by Carl Scheffler <carl.scheffler@gmail.com>

This file is part of FaceColorModel.

FaceColorModel is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License version 3 as
published by the Free Software Foundation.

FaceColorModel is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
General Public License for more details.

You should have received a copy of the GNU General Public License
along with FaceColorModel. If not, see <http://www.gnu.org/licenses/>.
"""
from __future__ import division
from scipy import *
import opencv
from opencv import highgui
from viola_jones_opencv import viola_jones_opencv
import FaceColorModelWrapper as FaceColorModule
from FaceColorModelWrapper import FaceColorModelWrapper as FaceColorModel
from colorlib import load_image
import command_line
import sys, os

# Option table handed to command_line.usage() and command_line.parse().
# Each entry is, in order: the key under which the value is stored in
# PARAMS, the command-line flag, the value type, the default value, and
# the help text shown by --help.
commandLineSpec = (
    # Markov random field strength, passed first to the FaceColorModel.
    (
        "kMrf",
        "--kmrf",
        float,
        2,
        """
      The strength constant of the Markov random field.""",
    ),
    # Color model variant ('disc' or 'cont') for the skin palette.
    (
        "skinType",
        "--skin",
        str,
        "cont",
        """
      Whether to use the discrete (disc) or continuous (cont) color
      model for the skin palette.""",
    ),
    # Color model variant ('disc' or 'cont') for the hair palette.
    (
        "hairType",
        "--hair",
        str,
        "cont",
        """
      Whether to use the discrete (disc) or continuous (cont) color
      model for the hair palette.""",
    ),
    # Minimum class posterior required before a pixel is labeled.
    (
        "threshold",
        "--threshold",
        float,
        0.75,
        """
      Certainty threshold. The posterior probability that a pixel
      belongs to a class should be above this threshold before the
      output pixel will be labelled as belonging to that class.""",
    ),
)

# Explicit help request: print usage and stop with a success status.
if "--help" in sys.argv[1:]:
    print(command_line.usage(commandLineSpec, sys.argv))
    sys.exit()

# Parse command line arguments. PARAMS is indexed by the names declared in
# commandLineSpec; filenames presumably holds the remaining positional
# arguments (NOTE(review): confirm against command_line.parse).
try:
    PARAMS, filenames = command_line.parse(commandLineSpec, sys.argv)
except ValueError as msg:
    # Bad arguments are an error: report on stderr, show usage and exit
    # non-zero so calling scripts can detect the failure.
    sys.stderr.write("%s\n" % (msg,))
    sys.stderr.write("%s\n" % (command_line.usage(commandLineSpec, sys.argv),))
    sys.exit(1)

# The script processes exactly one image (the first positional argument).
# Without one, indexing filenames[0] would crash with an IndexError.
if not filenames:
    sys.stderr.write("Error: No input image filename given.\n")
    sys.stderr.write("%s\n" % (command_line.usage(commandLineSpec, sys.argv),))
    sys.exit(1)
inputFilename = filenames[0]

# Side length, in pixels, of the square image the face color model operates on.
scaledSize = FaceColorModule.FCM_SCALED_SIZE

# Segmentation class name -> channel of the class posterior returned by
# the face color model for that class.
classIndex = dict([
    ('skin',       FaceColorModule.FCM_CHANNEL_SKIN),
    ('hair',       FaceColorModule.FCM_CHANNEL_HAIR),
    ('clothes',    FaceColorModule.FCM_CHANNEL_CLOTHES),
    ('background', FaceColorModule.FCM_CHANNEL_BACKGROUND),
])

# Segmentation class name -> 3-channel color painted onto the output image
# for pixels assigned to that class.
labelColors = dict([
    ('skin',       (255, 255,   0)),
    ('hair',       (255,   0,   0)),
    ('background', (  0,   0, 255)),
    ('clothes',    (  0, 255,   0)),
])

# Load the input image in color; stop with a non-zero exit status if it
# cannot be read.
inputImage = load_image(inputFilename)
if inputImage is None:
    sys.stderr.write("Error: Could not load image.\n")
    sys.exit(1)

# Run the Viola-Jones face detector on a grayscale copy of the same file.
grayImage = load_image(inputFilename, highgui.CV_LOAD_IMAGE_GRAYSCALE)
if grayImage is None:
    sys.stderr.write("Error: Could not load image.\n")
    sys.exit(1)
faceBox = viola_jones_opencv(grayImage)
del grayImage
if faceBox is None:
    sys.stderr.write("Error: No faces found\n")
    sys.exit(1)

# Build the 2x3 affine map from input to output coordinates. The map scales
# uniformly so that faceBox[2]-faceBox[0] becomes half of scaledSize, and
# translates so the box centre lands on the centre of the output square.
# NOTE(review): faceBox[1]/[3] feed the x translation and [0]/[2] the y
# translation, so faceBox appears to be (top, left, bottom, right). Confirm
# against viola_jones_opencv.
scale = scaledSize/(2*(faceBox[2]-faceBox[0]))
affineMap = opencv.cvCreateMat(2, 3, opencv.CV_32F)
affineMap[0,0] = scale
affineMap[0,1] = 0
affineMap[0,2] = scaledSize/2 - scale*(faceBox[1]+faceBox[3])/2
affineMap[1,0] = 0
affineMap[1,1] = scale
affineMap[1,2] = scaledSize/2 - scale*(faceBox[0]+faceBox[2])/2

# Resize the input into a square 8-bit, 3-channel image. Output pixels that
# map to outside the input are filled with cvScalar(0, 255, 0).
resizedImage = opencv.cvCreateImage(opencv.cvSize(scaledSize, scaledSize),
                                    opencv.IPL_DEPTH_8U, 3)
opencv.cvWarpAffine(inputImage, resizedImage, affineMap,
                    opencv.CV_INTER_AREA+opencv.CV_WARP_FILL_OUTLIERS,
                    opencv.cvScalar(0, 255, 0))
del inputImage, affineMap, faceBox
# Save the normalized crop that the face color model is adapted to.
highgui.cvSaveImage("segment_single_image_cropped.png", resizedImage)

# Do face color adaptation. Each palette (skin, hair) uses either the
# discrete ('disc') or continuous ('cont') color model, chosen on the
# command line.
typeMap = {'disc': FaceColorModule.FCM_DISCRETE,
           'cont': FaceColorModule.FCM_CONTINUOUS}

# Reject unknown model types with a readable message instead of letting
# the typeMap lookup below crash with a bare KeyError.
for paramName, flag in (('skinType', '--skin'), ('hairType', '--hair')):
    if PARAMS[paramName] not in typeMap:
        sys.stderr.write("Error: Invalid value '%s' for %s; expected one of: %s\n"
                         % (PARAMS[paramName], flag,
                            ", ".join(sorted(typeMap))))
        sys.exit(1)

faceColorModel = FaceColorModel(PARAMS['kMrf'], typeMap[PARAMS['skinType']],
                                typeMap[PARAMS['hairType']])
faceColorModel.adapt_to(resizedImage)

# Generate and save the labeled output image. The class posterior returned
# by the model is reshaped to one scaledSize x scaledSize plane per class.
estimatedClass = reshape(faceColorModel.get_class_posterior(),
                         (len(classIndex), scaledSize, scaledSize))

# Start from an all-white (255) image. For every class, pixels whose
# posterior exceeds the threshold are shifted from white to that class's
# label color by subtracting (255 - color).
outputImageArray = ones((scaledSize, scaledSize, 3), dtype=float) * 255
for className, channel in classIndex.items():
    isLabeled = array(estimatedClass[channel] > PARAMS['threshold'], dtype=int)
    outputImageArray -= isLabeled[:,:,newaxis] * \
                        (255-array(labelColors[className]))[newaxis,newaxis,:]

# If several classes exceed the threshold at the same pixel (possible for a
# low --threshold), that pixel is darkened more than once and can go
# negative. Clip so the uint8 conversion below cannot wrap around. This has
# no effect when at most one class fires per pixel.
outputImageArray = outputImageArray.clip(0, 255)
outputImage = opencv.adaptors.NumPy2Ipl(array(outputImageArray, dtype=uint8))
highgui.cvSaveImage("segment_single_image_output.png", outputImage)

# Dump each class's color posterior as raw float64 bytes (native byte
# order, no header), one file per class.
for class_ in ["skin","hair","clothes","background"]:
    with open("segment_single_image_%s.data"%class_, "wb") as fp:
        fp.write(array(faceColorModel.get_color_posterior(class_), dtype=float).data)