"""
Copyright (c) 2011 Idiap Research Institute, http://www.idiap.ch/
Written by Carl Scheffler <carl.scheffler@gmail.com>

This file is part of FaceColorModel.

FaceColorModel is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License version 3 as
published by the Free Software Foundation.

FaceColorModel is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
General Public License for more details.

You should have received a copy of the GNU General Public License
along with FaceColorModel. If not, see <http://www.gnu.org/licenses/>.
"""
from __future__ import division
from scipy import *
import opencv
from opencv import highgui
from viola_jones_opencv import viola_jones_opencv
from colorlib import load_image
import sys, command_line

import FaceColorModelWrapper as FaceColorModule
from FaceColorModelWrapper import FaceColorModelWrapper as FaceColorModel

# Command-line specification handed to command_line.parse / command_line.usage.
# Each entry is (parameter name, flag, type, default, help text). The
# three-element "quiet" entry has no type/default and is presumably a boolean
# switch -- command_line's exact handling is not visible here. NOTE(review):
# PARAMS['quiet'] is not read anywhere in the code visible in this script.
commandLineSpec = (
    ("inputListFilename", "--input_list", str, "input_list.txt", """
      A text file listing the relative path to each input image
      in the test set."""),
    ("labelListFilename", "--label_list", str, "label_list.txt", """
      A text file listing the relative path to each label image
      in the test set. Each line in the label list should
      correspond to a line in the input list."""),
    # Float default to match the declared float type, so the default value has
    # the right type even if command_line does not convert defaults itself.
    ("kMrf", "--kmrf", float, 2.0, """
      The strength constant of the Markov random field."""),
    ("skinType", "--skin", str, "cont", """
      Whether to use the discrete (disc) or continuous (cont) color
      model for the skin palette."""),
    ("hairType", "--hair", str, "cont", """
      Whether to use the discrete (disc) or continuous (cont) color
      model for the hair palette."""),
    ("quiet", "--quiet", """
      Tell the script *not* to be verbose."""),
)

# Show the usage text and stop when help is requested.
if "--help" in sys.argv[1:]:
    print(command_line.usage(commandLineSpec, sys.argv))
    sys.exit()

# Parse command line arguments. Bad arguments surface as ValueError from
# command_line.parse; report the problem and the usage text, then exit with a
# non-zero status so shell scripts and batch jobs can detect the failure
# (a bare sys.exit() would report success).
try:
    PARAMS, _ = command_line.parse(commandLineSpec, sys.argv)
except ValueError as msg:
    print(msg)
    print(command_line.usage(commandLineSpec, sys.argv))
    sys.exit(1)

# Read the input and label lists: one relative image path per line, where
# line i of the label list belongs to line i of the input list.
fileLists = {}
for typ in ['input', 'label']:
    with open(PARAMS[typ + "ListFilename"], 'rt') as fp:
        # splitlines() -- unlike split('\n') -- also removes the '\r' of CRLF
        # list files, and yields [] rather than [''] for an empty file.
        fileLists[typ] = fp.read().strip().splitlines()
if len(fileLists['input']) != len(fileLists['label']):
    print("Error: input and label file lists have different numbers of entries")
    sys.exit(1)


# Side length, in pixels, of the square face crop the model works on. Each
# crop is centred on the Viola-Jones face bounding box and spans twice the
# size of that box.
from FaceColorModelWrapper import FCM_SCALED_SIZE as scaledSize

# Pixel classes. Their position in this list is the row used for the class
# in both the model's class posterior and the ground-truth arrays.
classes = ['skin', 'hair', 'clothes', 'background']
# Class name -> position in `classes`.
classIndex = dict((name, position) for position, name in enumerate(classes))
# RGB colour that marks each class in the label images.
labelColors = dict(
    skin=(255, 255, 0),
    hair=(255, 0, 0),
    background=(0, 0, 255),
    clothes=(0, 255, 0),
)

PAPER_PLOT = False  # Produce plots for BMVC paper (not read in the visible script)


# Map the --skin / --hair command-line values onto the wrapper's model types.
typeMap = {'disc': FaceColorModule.FCM_DISCRETE,
           'cont': FaceColorModule.FCM_CONTINUOUS}

# Reject unknown palette types up front with a readable message and the usage
# text, instead of a bare KeyError traceback from the lookups below.
for _option in ('skinType', 'hairType'):
    if PARAMS[_option] not in typeMap:
        print('Error: %s must be one of %s, got "%s"' % (
            _option, ', '.join(sorted(typeMap)), PARAMS[_option]))
        print(command_line.usage(commandLineSpec, sys.argv))
        sys.exit(1)

faceColorModel = FaceColorModel(PARAMS['kMrf'], typeMap[PARAMS['skinType']],
                                typeMap[PARAMS['hairType']])

def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero.

    Precision, recall and F-score are undefined (0/0) when, for example, a
    class is never predicted on labelled pixels. Plain Python numbers raise
    ZeroDivisionError in that case (e.g. the overall totals when no image
    could be evaluated), so NaN is returned to keep the report going.
    """
    if denominator == 0:
        return float('nan')
    return numerator / denominator


def _print_class_scores(clas, tp, fp, tn, fn):
    """Print F-score, precision, recall and the raw counts for one class."""
    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    fScore = _safe_ratio(2*precision*recall, precision + recall)
    print('%-10s: f %.3f, p %.3f, r %.3f -- tp %.3f, fp %.3f, tn %.3f, fn %.3f' % (
        clas, fScore, precision, recall, tp, fp, tn, fn))


# Evaluate every (input image, label image) pair: detect the face, warp the
# image and its labels into a common scaledSize x scaledSize crop, adapt the
# colour model to the crop and compare its class posterior with the labels.
fullTestResults = []  # [true_pos, false_pos, true_neg, false_neg] dicts per evaluated image
for inputFilename, labelFilename in zip(fileLists['input'], fileLists['label']):
    inputImage = load_image(inputFilename)
    if inputImage is None:
        print('Error: Input image "%s" is None' % inputFilename)
        continue

    labelImage = load_image(labelFilename)
    if labelImage is None:
        print('Error: Label image "%s" is None' % labelFilename)
        continue

    if (inputImage.width != labelImage.width) or (inputImage.height != labelImage.height):
        print('Error: Input and label image shape mismatch (%s)' % inputFilename)
        continue

    # Face detection runs on a grayscale reload of the input image.
    grayImage = load_image(inputFilename, highgui.CV_LOAD_IMAGE_GRAYSCALE)
    if grayImage is None:
        print('Error: Grayscale input image "%s" is None' % inputFilename)
        continue
    face = viola_jones_opencv(grayImage)
    if face is None:
        print('Error: No faces found (%s)' % inputFilename)
        continue

    # Scale + translation from the original image to the crop: the centre of
    # the face box lands on the crop centre, and the box extent
    # face[2]-face[0] becomes half the crop size. As used here, face[0]/face[2]
    # drive the vertical (second) row of the map and face[1]/face[3] the
    # horizontal one -- TODO confirm against viola_jones_opencv's output order.
    scale = scaledSize/(2*(face[2]-face[0]))
    affineMap = opencv.cvCreateMat(2, 3, opencv.CV_32F)
    affineMap[0,0] = scale
    affineMap[0,1] = 0
    affineMap[0,2] = scaledSize/2 - scale*(face[1]+face[3])/2
    affineMap[1,0] = 0
    affineMap[1,1] = scale
    affineMap[1,2] = scaledSize/2 - scale*(face[0]+face[2])/2

    # Warp one binary mask per class into the crop. CV_WARP_FILL_OUTLIERS with
    # fill value 0 leaves crop pixels that fall outside the original image at
    # 0 in every mask, so they count as unlabelled below.
    groundTruth = empty((len(classes), scaledSize**2), dtype=float)
    labelArray = array(labelImage, dtype=uint8)[:,:,::-1] # Reverse colour channels: BGR -> RGB
    resizedMask = opencv.cvCreateImage(opencv.cvSize(scaledSize, scaledSize),
                                       opencv.IPL_DEPTH_8U, 1)
    for clas in classes:
        index = classIndex[clas]
        classMask = opencv.adaptors.NumPy2Ipl(
            array((labelArray == array(labelColors[clas], dtype=uint8)).all(axis=2) * 255, dtype=uint8))
        opencv.cvWarpAffine(classMask, resizedMask, affineMap,
                            opencv.CV_INTER_AREA+opencv.CV_WARP_FILL_OUTLIERS, 0)
        # Divide by 255 to get a per-pixel class weight; area interpolation
        # can make it fractional at class boundaries.
        groundTruth[index] = ravel(array(resizedMask) / 255)

    # Warp the input image into the same crop.
    resizedImage = opencv.cvCreateImage(opencv.cvSize(scaledSize, scaledSize),
                                        opencv.IPL_DEPTH_8U, 3)
    opencv.cvWarpAffine(inputImage, resizedImage, affineMap,
                        opencv.CV_INTER_AREA+opencv.CV_WARP_FILL_OUTLIERS, 0)
    del inputImage, affineMap, face

    # Do face color adaptation; one posterior row per class.
    faceColorModel.adapt_to(resizedImage)
    estimatedClass = reshape(faceColorModel.get_class_posterior(), (len(classes), scaledSize**2))

    # Label weight per pixel summed over classes. A weight of 0 means the
    # pixel was not labelled, so it contributes to none of the counts.
    groundTruthWeight = sum(groundTruth, axis=0)
    true_pos = {}
    false_pos = {}
    true_neg = {}
    false_neg = {}
    for clas in classes:
        index = classIndex[clas]
        true_pos[clas]  = sum(estimatedClass[index] * groundTruth[index])
        false_pos[clas] = sum(estimatedClass[index] * (groundTruthWeight-groundTruth[index]))
        true_neg[clas]  = sum((1-estimatedClass[index]) * (groundTruthWeight-groundTruth[index]))
        false_neg[clas] = sum((1-estimatedClass[index]) * groundTruth[index])
    fullTestResults.append([true_pos, false_pos, true_neg, false_neg])

    # Name the image so its score lines can be told apart from the others
    # (skipped images would otherwise make the output ambiguous).
    print('%s:' % inputFilename)
    for clas in classes:
        _print_class_scores(clas, true_pos[clas], false_pos[clas],
                            true_neg[clas], false_neg[clas])

# Pool the counts of all evaluated images and report the overall scores.
# With no evaluated images every count stays 0 and the scores print as nan.
print('Overall:')
true_pos = {}
false_pos = {}
true_neg = {}
false_neg = {}
for clas in classes:
    true_pos[clas] = false_pos[clas] = true_neg[clas] = false_neg[clas] = 0
    for result in fullTestResults:
        true_pos[clas]  += result[0][clas]
        false_pos[clas] += result[1][clas]
        true_neg[clas]  += result[2][clas]
        false_neg[clas] += result[3][clas]
    _print_class_scores(clas, true_pos[clas], false_pos[clas],
                        true_neg[clas], false_neg[clas])
