#include "FaceColorModel.h"

#include <stdio.h>

#include <cassert>
#include <cstring>
#include <stdexcept>
#include <string>
#include <vector>

namespace FaceColorModel {

  // --------------------------------------------------

  // Support functions

  // --------------------------------------------------

  /// Square of a scalar (x times itself).
  inline double sqr(double x) {
    const double squared = x * x;
    return squared;
  }

  // --------------------------------------------------

  /// Approximation to the digamma function psi(x) = d/dx log Gamma(x), for
  /// x > 0.
  ///
  /// The recurrence psi(x) = psi(x+1) - 1/x first shifts the argument to
  /// x >= 7. Then the asymptotic expansion about y = x - 1/2 is applied:
  ///   psi(x) ~ log(y) + 1/(24 y^2) - 7/(960 y^4) + 31/(8064 y^6)
  ///            - 127/(30720 y^8)
  /// For y >= 6.5 the first omitted term is below 1e-10.
  double digamma(double x) {
    assert(x > 0);
    double result = 0;
    while(x < 7) {
      // The expansion is only accurate for large x, so use the identity
      // [digamma(x) = digamma(x+1) - 1/x] to get x in the right range.
      result -= 1/x;
      x++;
    }
    // The series is in powers of 1/(x - 1/2), so its leading term must be
    // log(x - 1/2), not log(x). The signs alternate term by term.
    const double y = x - 0.5;
    const double invY2 = 1 / (y * y);
    result +=
      log(y) + invY2 * (
        (1./24.) - invY2 * (
          (7./960.) - invY2 * (
            (31./8064.) - invY2*(127./30720.))));
    return result;
  }

  // --------------------------------------------------

  /// Convert interleaved BGR pixels to interleaved YPbPr.
  ///
  /// @param ipBgr    iLength pixels stored as consecutive (B, G, R) bytes,
  ///                 each in [0,255].
  /// @param opYpbpr  receives iLength pixels as consecutive (Y, Pb, Pr)
  ///                 doubles, each in [0,1].
  /// @param iLength  number of pixels; nothing is written when <= 0.
  void bgr2ypbpr(unsigned char const* ipBgr, double* opYpbpr,
		 int iLength) {
    // Row k gives output component k from (R, G, B). The 1/255 factor folds
    // the rescaling of byte values to [0,1] into the matrix.
    const double kMatrix[3][3] = {
      { 0.299/255,     0.587/255,     0.114/255},
      {-0.168736/255, -0.331264/255,  0.5/255},
      { 0.5/255,      -0.418688/255, -0.081312/255}};
    // Shifts the chroma components from [-0.5,0.5] to [0,1].
    const double kOffset[3] = {0, 0.5, 0.5};

    unsigned char const* pIn = ipBgr;
    double* pOut = opYpbpr;
    for(int remaining = iLength; remaining > 0;
	--remaining, pIn += 3, pOut += 3) {
      const double blue  = pIn[0];
      const double green = pIn[1];
      const double red   = pIn[2];
      for(int component = 0; component < 3; ++component) {
	double const* row = kMatrix[component];
	pOut[component] =
	  red * row[0] + green * row[1] + blue * row[2] + kOffset[component];
      }
    }
  }

  // --------------------------------------------------

  /// Compute the color-histogram bin index of every pixel.
  ///
  /// @param ipImage      8-bit, 3-channel image with pixels stored as
  ///                     (B, G, R) bytes, iRows rows and iColumns columns.
  /// @param opHistogram  receives iRows*iColumns bin indices in row-major
  ///                     order.
  /// Each channel value in [0,255] is quantized to
  /// (value * FCM_HISTOGRAM_BINS) >> 8. The bin index is
  /// r*FCM_HISTOGRAM_BINS^2 + g*FCM_HISTOGRAM_BINS + b.
  void image_to_histogram(IplImage const* ipImage, int* opHistogram,
			  int iRows, int iColumns) {
    int histogramIndex = 0;
    for(int row = 0; row < iRows; row++) {
      // IplImage rows may be padded, so advance by widthStep bytes per row
      // instead of assuming 3*iColumns tightly packed bytes.
      unsigned char const* pPixel = reinterpret_cast<unsigned char const*>(
	ipImage->imageData + row * ipImage->widthStep);
      for(int col = 0; col < iColumns; col++, pPixel += 3) {
	int bBin = ((int)pPixel[0] * FCM_HISTOGRAM_BINS) >> 8;
	int gBin = ((int)pPixel[1] * FCM_HISTOGRAM_BINS) >> 8;
	int rBin = ((int)pPixel[2] * FCM_HISTOGRAM_BINS) >> 8;
	opHistogram[histogramIndex++] =
	  (rBin * FCM_HISTOGRAM_BINS + gBin) * FCM_HISTOGRAM_BINS + bBin;
      }
    }
  }

  // --------------------------------------------------

  // Lifecycle methods

  // --------------------------------------------------

  /// Read exactly iCount doubles, stored as raw binary, from iFilename into
  /// opData (which must have room for iCount values).
  /// @throws std::runtime_error if the file cannot be opened or holds fewer
  ///         than iCount values.
  static void read_doubles_from_file(std::string const& iFilename,
				     double* opData, int iCount) {
    FILE* f = fopen(iFilename.c_str(), "rb");
    if(f == NULL)
      throw std::runtime_error("Cannot open data file " + iFilename);
    size_t numRead = fread(opData, sizeof(double), iCount, f);
    fclose(f);
    if(numRead != static_cast<size_t>(iCount))
      throw std::runtime_error("Data file " + iFilename +
			       " is shorter than expected");
  }

  /// Build a model with MRF strength iKMrf and the given skin/hair color
  /// model types (FCM_DISCRETE or FCM_CONTINUOUS). Clothes and background
  /// are always discrete. Prior color models and the PIM prior are read
  /// from files under the relative path "storage/".
  /// @throws std::runtime_error on an invalid model type, or if a prior
  ///         file cannot be opened or is truncated.
  FaceColorModel::FaceColorModel(double iKMrf/*=2*/,
				 int iSkinType/*=FCM_CONTINUOUS*/,
				 int iHairType/*=FCM_CONTINUOUS*/) {
    // Elsewhere any type other than FCM_CONTINUOUS is treated as discrete.
    // An unknown value would therefore make allocation disagree with use
    // and deletion, so reject it up front.
    if(((iSkinType != FCM_DISCRETE) and (iSkinType != FCM_CONTINUOUS)) or
       ((iHairType != FCM_DISCRETE) and (iHairType != FCM_CONTINUOUS)))
      throw std::runtime_error("Skin and hair model types must be "
			       "FCM_DISCRETE or FCM_CONTINUOUS.");

    // Set color model types
    mpModelTypes[FCM_CHANNEL_SKIN]       = iSkinType;
    mpModelTypes[FCM_CHANNEL_HAIR]       = iHairType;
    mpModelTypes[FCM_CHANNEL_CLOTHES]    = FCM_DISCRETE;
    mpModelTypes[FCM_CHANNEL_BACKGROUND] = FCM_DISCRETE;

    mVariationalIterations = 20;

    mKMrf = iKMrf;

    // These constants are read by __precompute_discrete_color() and
    // __precompute_continuous_color(). They must be set BEFORE the prior
    // pre-computation below; otherwise it uses uninitialized values.
    // log(0.23628...) = log |det| of the RGB -> YPbPr coefficient matrix
    // (the bgr2ypbpr() coefficients without the 1/255 factor).
    mLogDetTransformYpbpr = -1.4427388438401436;
    // log of the volume of one histogram bin, 1/FCM_HISTOGRAM_BINS^3.
    mLogBinVolume = -3*log(FCM_HISTOGRAM_BINS);

    const int numEntries =
      FCM_NUM_CHANNELS * FCM_SCALED_SIZE * FCM_SCALED_SIZE;

    // Null every owned pointer first, so the failure path below can delete
    // whatever has been allocated so far.
    for(int channel = 0; channel < FCM_NUM_CHANNELS; channel++) {
      mpColorPriors[channel] = NULL;
      mpColorPosteriors[channel] = NULL;
      mpPrecomputeColorPriors[channel] = NULL;
      mpPrecomputeColorPosteriors[channel] = NULL;
    }
    mpChannelPrior = NULL;
    mpChannelPosterior = NULL;
    mpLogChannelPrior = NULL;

    try {
      // Load the prior color models from file
      for(int channel = 0; channel < FCM_NUM_CHANNELS; channel++) {
	std::string filename = "storage/color_prior_";
	switch(channel) {
	case FCM_CHANNEL_SKIN:
	  filename += "skin";       break;
	case FCM_CHANNEL_HAIR:
	  filename += "hair";       break;
	case FCM_CHANNEL_CLOTHES:
	  filename += "clothes";    break;
	case FCM_CHANNEL_BACKGROUND:
	  filename += "background"; break;
	}
	int dataSize;
	if(mpModelTypes[channel] == FCM_DISCRETE) {
	  filename += "_discrete.data";
	  dataSize = FCM_HISTOGRAM_SIZE;
	} else {
	  filename += "_continuous.data";
	  // [mu, tau, alpha, beta] for each of the 3 YPbPr components.
	  dataSize = 12;
	}

	mpColorPriors[channel] = new double [dataSize];
	mpColorPosteriors[channel] = new double [dataSize];
	read_doubles_from_file(filename, mpColorPriors[channel], dataSize);

	// Pre-computation for color priors
	if(mpModelTypes[channel] == FCM_DISCRETE) {
	  mpPrecomputeColorPriors[channel] = new DiscreteColorPrecompute;
	  mpPrecomputeColorPosteriors[channel] = new DiscreteColorPrecompute;
	  __precompute_discrete_color(
	    mpColorPriors[channel],
	    static_cast<DiscreteColorPrecompute*>(
	      mpPrecomputeColorPriors[channel]));
	} else {
	  mpPrecomputeColorPriors[channel] = new ContinuousColorPrecompute;
	  mpPrecomputeColorPosteriors[channel] = new ContinuousColorPrecompute;
	  __precompute_continuous_color(
	    mpColorPriors[channel],
	    static_cast<ContinuousColorPrecompute*>(
	      mpPrecomputeColorPriors[channel]));
	}
      }

      // Load the PIM prior from file
      mpChannelPrior = new double [numEntries];
      mpChannelPosterior = new double [numEntries];
      read_doubles_from_file("storage/pim_prior.data", mpChannelPrior,
			     numEntries);

      reset_to_prior();

      // Pre-computation for variational inference
      mpLogChannelPrior = new double [numEntries];
      for(int i = 0; i < numEntries; i++)
	mpLogChannelPrior[i] = log(mpChannelPrior[i]);
    } catch(...) {
      // The destructor does not run for a partially constructed object,
      // so release everything allocated so far before propagating.
      for(int channel = 0; channel < FCM_NUM_CHANNELS; channel++) {
	delete [] mpColorPriors[channel];
	delete [] mpColorPosteriors[channel];
	if(mpModelTypes[channel] == FCM_DISCRETE) {
	  delete static_cast<DiscreteColorPrecompute*>(
	    mpPrecomputeColorPriors[channel]);
	  delete static_cast<DiscreteColorPrecompute*>(
	    mpPrecomputeColorPosteriors[channel]);
	} else {
	  delete static_cast<ContinuousColorPrecompute*>(
	    mpPrecomputeColorPriors[channel]);
	  delete static_cast<ContinuousColorPrecompute*>(
	    mpPrecomputeColorPosteriors[channel]);
	}
      }
      delete [] mpChannelPrior;
      delete [] mpChannelPosterior;
      delete [] mpLogChannelPrior;
      throw;
    }
  }

  // --------------------------------------------------

  /// Release the per-channel palette arrays, their pre-computation buffers
  /// and the pixel-channel prior/posterior arrays.
  FaceColorModel::~FaceColorModel() {
    for(int channel = 0; channel < FCM_NUM_CHANNELS; ++channel) {
      delete [] mpColorPriors[channel];
      delete [] mpColorPosteriors[channel];

      // Pre-computation buffers are held as untyped pointers. Cast each
      // back to the type matching the channel's model before deleting it.
      void* pPriorPrecompute = mpPrecomputeColorPriors[channel];
      void* pPosteriorPrecompute = mpPrecomputeColorPosteriors[channel];
      if(mpModelTypes[channel] != FCM_CONTINUOUS) {
	delete static_cast<DiscreteColorPrecompute*>(pPriorPrecompute);
	delete static_cast<DiscreteColorPrecompute*>(pPosteriorPrecompute);
      } else {
	delete static_cast<ContinuousColorPrecompute*>(pPriorPrecompute);
	delete static_cast<ContinuousColorPrecompute*>(pPosteriorPrecompute);
      }
    }

    delete [] mpChannelPrior;
    delete [] mpChannelPosterior;
    delete [] mpLogChannelPrior;
  }

  // --------------------------------------------------

  // Getters and setters

  // --------------------------------------------------

  /// Color-model parameters for channel iChannelIndex. While
  /// mPosteriorsEqualPriors is set (no adaptation since construction or
  /// reset_to_prior()), this returns the prior parameters.
  const double* FaceColorModel::get_color_posterior(int iChannelIndex) {
    return mPosteriorsEqualPriors ? mpColorPriors[iChannelIndex]
				  : mpColorPosteriors[iChannelIndex];
  }

  // --------------------------------------------------

  /// Per-pixel channel probabilities. While mPosteriorsEqualPriors is set,
  /// this returns the prior (PIM) array instead of the adapted posterior.
  const double* FaceColorModel::get_class_posterior() {
    return mPosteriorsEqualPriors ? mpChannelPrior : mpChannelPosterior;
  }

  // --------------------------------------------------

  int FaceColorModel::get_channel_type(int iChannelIndex) {
    return mpModelTypes[iChannelIndex];
  }

  // --------------------------------------------------

  // General methods

  // --------------------------------------------------

  void FaceColorModel::adapt_to(IplImage const* ipImage) {
    // Check that image is of the right data type
    if((ipImage->nChannels != 3) or (ipImage->depth != IPL_DEPTH_8U))
      throw std::runtime_error("Input image mush have 3 channels and"
			       "a depth of 1 byte (IPL_DEPTH_8U).");

    // This flag needs to be cleared early so that
    // __add_color_log_likelihood() works correctly.
    mPosteriorsEqualPriors = false;

    // Check if input image has the right size
    IplImage* pScaledImage;
    IplImage const* pInputImage;
    if((ipImage->width != FCM_SCALED_SIZE) or
       (ipImage->height != FCM_SCALED_SIZE)) {
      pScaledImage = cvCreateImage(cvSize(FCM_SCALED_SIZE,
					  FCM_SCALED_SIZE),
				   IPL_DEPTH_8U, 3);
      cvResize(ipImage, pScaledImage, CV_INTER_LINEAR);
      pInputImage = pScaledImage;
    } else {
      pInputImage = ipImage;
    }

    // Compute color histogram bin index of each pixel
    int* pInputBins = new int [FCM_SCALED_SIZE*FCM_SCALED_SIZE];
    image_to_histogram(pInputImage, pInputBins,
		       FCM_SCALED_SIZE, FCM_SCALED_SIZE);

    // Convert input image to YPbPr color space
    double* pInputYpbpr = new double [FCM_SCALED_SIZE*FCM_SCALED_SIZE*3];
    bgr2ypbpr((unsigned char*)pInputImage->imageData, pInputYpbpr,
	      FCM_SCALED_SIZE*FCM_SCALED_SIZE);

    // Initialize channel beliefs to the prior
    memcpy(
      mpChannelPosterior, mpChannelPrior,
      FCM_NUM_CHANNELS * FCM_SCALED_SIZE * FCM_SCALED_SIZE * sizeof(double));

    // Compute (variational) posteriors over all parameters --- i.e., all
    // palettes in all images and the class of each pixel in each image.
    double* pLogChannelPosterior =
      new double [FCM_NUM_CHANNELS * FCM_SCALED_SIZE * FCM_SCALED_SIZE];
    for(int variationalIteration = 0;
	variationalIteration < mVariationalIterations;
	variationalIteration++) {

      // Update belief over palettes
      for(int channel = 0; channel < FCM_NUM_CHANNELS; channel++) {
	double* pChannelMask =
	  &(mpChannelPosterior[channel * FCM_SCALED_SIZE * FCM_SCALED_SIZE]);
	double* pColorPrior = mpColorPriors[channel];
	double* pColorPosterior = mpColorPosteriors[channel];

	if(mpModelTypes[channel] == FCM_CONTINUOUS) {
	  double cumulant0 = 0;
	  double cumulant1[3] = {0,0,0};
	  double cumulant2[3] = {0,0,0};
	  for(int i = 0; i < FCM_SCALED_SIZE*FCM_SCALED_SIZE; i++) {
	    double mask = pChannelMask[i];
	    cumulant0 += mask;
	    for(int j = 0; j < 3; j++) {
	      double col = pInputYpbpr[3*i+j];
	      double x = mask * col;
	      cumulant1[j] += x;
	      cumulant2[j] += x * col;
	    }
	  }

	  for(int i = 0; i < 3; i++) {
	    pColorPosterior[1+4*i] = pColorPrior[1+4*i] + cumulant0;
	    pColorPosterior[0+4*i] =
	      (pColorPrior[0+4*i] * pColorPrior[1+4*i] + cumulant1[i])
	      / pColorPosterior[1+4*i];
	    pColorPosterior[2+4*i] = pColorPrior[2+4*i] + cumulant0/2;
	    pColorPosterior[3+4*i] =
	      pColorPrior[3+4*i] + cumulant2[i]/2 +
	      sqr(pColorPrior[0+4*i]) * pColorPrior[1+4*i]/2 -
	      sqr(pColorPosterior[0+4*i]) * pColorPosterior[1+4*i]/2;
	  }
	  __precompute_continuous_color(
	    pColorPosterior,
	    (ContinuousColorPrecompute*)mpPrecomputeColorPosteriors[channel]);
	} else { // if(mpModelTypes[channel] == FCM_DISCRETE)
	  memcpy(pColorPosterior, pColorPrior,
		 FCM_HISTOGRAM_SIZE * sizeof(double));
	  for(int i = 0; i < FCM_SCALED_SIZE*FCM_SCALED_SIZE; i++) {
	    pColorPosterior[pInputBins[i]] +=
	      pChannelMask[i];
	  }
	  __precompute_discrete_color(
	    pColorPosterior,
	    (DiscreteColorPrecompute*)mpPrecomputeColorPosteriors[channel]);
	}
      }

      // Update belief over pixel channels
      double kMrf = (mKMrf*variationalIteration)/mVariationalIterations;
      // Initialize to prior
      memcpy(
	pLogChannelPosterior, mpLogChannelPrior,
	FCM_NUM_CHANNELS * FCM_SCALED_SIZE * FCM_SCALED_SIZE * sizeof(double));
      // Add contribution from Markov random field
      for(int channel = 0; channel < FCM_NUM_CHANNELS; channel++) {
	int prevRowIndex;
	int thisRowIndex = channel*FCM_SCALED_SIZE*FCM_SCALED_SIZE;
	int nextRowIndex = thisRowIndex + FCM_SCALED_SIZE;
	// First row
	pLogChannelPosterior[thisRowIndex] +=
	  kMrf * (
		  mpChannelPosterior[thisRowIndex + 1] +
		  mpChannelPosterior[nextRowIndex]);
	thisRowIndex++;
	nextRowIndex++;
	for(int col = 1; col < FCM_SCALED_SIZE-1; col++) {
	  pLogChannelPosterior[thisRowIndex] +=
	    kMrf * (
		    mpChannelPosterior[thisRowIndex - 1] +
		    mpChannelPosterior[thisRowIndex + 1] +
		    mpChannelPosterior[nextRowIndex]);
	  thisRowIndex++;
	  nextRowIndex++;
	}
	pLogChannelPosterior[thisRowIndex] +=
	  kMrf * (
		  mpChannelPosterior[thisRowIndex - 1] +
		  mpChannelPosterior[nextRowIndex]);
	thisRowIndex++;
	nextRowIndex++;
	// Up to penultimate row
	prevRowIndex = thisRowIndex - FCM_SCALED_SIZE;
	for(int row = 1; row < FCM_SCALED_SIZE-1; row++) {
	  pLogChannelPosterior[thisRowIndex] +=
	    kMrf * (
		    mpChannelPosterior[prevRowIndex] +
		    mpChannelPosterior[thisRowIndex + 1] +
		    mpChannelPosterior[nextRowIndex]);
	  prevRowIndex++;
	  thisRowIndex++;
	  nextRowIndex++;
	  for(int col = 1; col < FCM_SCALED_SIZE-1; col++) {
	    pLogChannelPosterior[thisRowIndex] +=
	      kMrf * (
		      mpChannelPosterior[prevRowIndex] +
		      mpChannelPosterior[thisRowIndex - 1] +
		      mpChannelPosterior[thisRowIndex + 1] +
		      mpChannelPosterior[nextRowIndex]);
	    prevRowIndex++;
	    thisRowIndex++;
	    nextRowIndex++;
	  }
	  pLogChannelPosterior[thisRowIndex] +=
	    kMrf * (
		    mpChannelPosterior[prevRowIndex] +
		    mpChannelPosterior[thisRowIndex - 1] +
		    mpChannelPosterior[nextRowIndex]);
	  prevRowIndex++;
	  thisRowIndex++;
	  nextRowIndex++;
	}
	// Last row
	pLogChannelPosterior[thisRowIndex] +=
	  kMrf * (
		  mpChannelPosterior[prevRowIndex] +
		  mpChannelPosterior[thisRowIndex + 1]);
	prevRowIndex++;
	thisRowIndex++;
	for(int col = 1; col < FCM_SCALED_SIZE-1; col++) {
	  pLogChannelPosterior[thisRowIndex] +=
	    kMrf * (
		    mpChannelPosterior[prevRowIndex] +
		    mpChannelPosterior[thisRowIndex - 1] +
		    mpChannelPosterior[thisRowIndex + 1]);
	  prevRowIndex++;
	  thisRowIndex++;
	}
	pLogChannelPosterior[thisRowIndex] +=
	  kMrf * (
		  mpChannelPosterior[prevRowIndex] +
		  mpChannelPosterior[thisRowIndex - 1]);
	prevRowIndex++;
	thisRowIndex++;
      }

      // Add contributions from color likelihoods
      __add_color_log_likelihood(pInputYpbpr, pInputBins,
				 pLogChannelPosterior,
				 FCM_SCALED_SIZE, FCM_SCALED_SIZE);

      // Convert log channel posterior to channel posterior
      for(int pixel = 0; pixel < FCM_SCALED_SIZE*FCM_SCALED_SIZE; pixel++) {
	double maxLog = -1e300;
	for(int channel = 0; channel < FCM_NUM_CHANNELS; channel++) {
	  int index = pixel + channel*(FCM_SCALED_SIZE*FCM_SCALED_SIZE);
	  if(pLogChannelPosterior[index] > maxLog)
	    maxLog = pLogChannelPosterior[index];
	}

	double sum = 0;
	for(int channel = 0, index = pixel;
	    channel < FCM_NUM_CHANNELS;
	    channel++, index += FCM_SCALED_SIZE*FCM_SCALED_SIZE) {
	  mpChannelPosterior[index] = exp(pLogChannelPosterior[index]-maxLog);
	  sum += mpChannelPosterior[index];
	}

	for(int channel = 0, index = pixel;
	    channel < FCM_NUM_CHANNELS;
	    channel++, index += FCM_SCALED_SIZE*FCM_SCALED_SIZE)
	  mpChannelPosterior[index] /= sum;
      }
    }

    delete [] pInputBins;
    delete [] pInputYpbpr;
    delete [] pLogChannelPosterior;
  }

  // --------------------------------------------------

  void FaceColorModel::channel_log_likelihood(
    const IplImage* ipImage, double* opLogLikelihood) {

    // Check that image is of the right data type
    if((ipImage->nChannels != 3) or (ipImage->depth != IPL_DEPTH_8U))
      throw std::runtime_error("Input image mush have 3 channels and"
			       "a depth of 1 byte (IPL_DEPTH_8U).");

    // Compute color histogram bin index of each pixel
    int* pInputBins = new int [ipImage->height * ipImage->width];
    image_to_histogram(ipImage, pInputBins,
		       ipImage->height, ipImage->width);

    // Convert input image to YPbPr color space
    double* pInputYpbpr = new double [ipImage->height * ipImage->width * 3];
    bgr2ypbpr((unsigned char*)ipImage->imageData, pInputYpbpr,
	      ipImage->width * ipImage->height);

    // Update belief over pixel channels
    memset(
      opLogLikelihood, 0,
      FCM_NUM_CHANNELS * ipImage->height * ipImage->width * sizeof(double));
    __add_color_log_likelihood(pInputYpbpr, pInputBins, opLogLikelihood,
			       ipImage->height, ipImage->width);

    delete [] pInputBins;
    delete [] pInputYpbpr;
  }

  // --------------------------------------------------

  void FaceColorModel::reset_to_prior() {
    mPosteriorsEqualPriors = true;
  }

  // --------------------------------------------------

  // Support methods

  // --------------------------------------------------

  void FaceColorModel::__add_color_log_likelihood(
    double* ipInputYpbpr, int* ipInputBins,
    double* opLikelihood, int iRows, int iColumns) {

    double** pColorDistributions;
    void** pPrecomputeColorDistributions;
    if(mPosteriorsEqualPriors) {
      pColorDistributions = mpColorPriors;
      pPrecomputeColorDistributions = mpPrecomputeColorPriors;
    } else {
      pColorDistributions = mpColorPosteriors;
      pPrecomputeColorDistributions = mpPrecomputeColorPosteriors;
    }

    for(int channel = 0; channel < FCM_NUM_CHANNELS; channel++) {
      if(mpModelTypes[channel] == FCM_CONTINUOUS) {
	ContinuousColorPrecompute* pPrecompute =
	  (ContinuousColorPrecompute*)pPrecomputeColorDistributions[channel];
	int index = channel * iRows * iColumns;
	for(int j = 0; j < iRows * iColumns; j++) {
	  double accumulator = 0;
	  for(int i = 0; i < 3; i++)
	    accumulator +=
	      pPrecompute->alphaOver2Beta[i] *
	      sqr(ipInputYpbpr[i+3*j] - pPrecompute->mu[i]);
	  opLikelihood[index++] += pPrecompute->logConstant - accumulator;
	}
      } else { // FCM_DISCRETE
	DiscreteColorPrecompute* pPrecompute =
	  (DiscreteColorPrecompute*)pPrecomputeColorDistributions[channel];
	int index = channel * iRows * iColumns;
	for(int j = 0; j < iRows * iColumns; j++, index++) {
	  opLikelihood[index] +=
	    + pPrecompute->digammaDistribution[ipInputBins[j]]
	    - pPrecompute->logConstant;
	}
      }
    }
  }

  // --------------------------------------------------

  /// Cache the digamma values needed by the discrete likelihood.
  /// digammaDistribution[i] = digamma(count_i), and
  /// logConstant = digamma(sum of counts) + mLogBinVolume.
  void FaceColorModel::__precompute_discrete_color(
    double const* ipDistribution, DiscreteColorPrecompute* opPrecompute) {

    double totalCount = 0;
    for(int bin = 0; bin < FCM_HISTOGRAM_SIZE; ++bin) {
      const double count = ipDistribution[bin];
      totalCount += count;
      opPrecompute->digammaDistribution[bin] = digamma(count);
    }
    opPrecompute->logConstant = digamma(totalCount) + mLogBinVolume;
  }

  // --------------------------------------------------

  /// Cache the per-component terms needed by the continuous likelihood.
  /// ipDistribution holds [mu, tau, alpha, beta] for each of the 3 YPbPr
  /// components. Per component it stores mu and alpha/(2*beta), plus
  ///   logConstant = log|A| - 1.5*log(2*pi)
  ///                 + 1/2 * sum(digamma(alpha) - log(beta))
  ///                 - 1/2 * sum(1/tau)
  /// (the sum is accumulated as three separate running totals).
  void FaceColorModel::__precompute_continuous_color(
    double const* ipDistribution, ContinuousColorPrecompute* opPrecompute) {

    double digammaAlphaTotal = 0;
    double logBetaTotal = 0;
    double invTauTotal = 0;
    for(int component = 0; component < 3; ++component) {
      double const* pParams = ipDistribution + 4 * component;
      const double mu    = pParams[0];
      const double tau   = pParams[1];
      const double alpha = pParams[2];
      const double beta  = pParams[3];

      opPrecompute->mu[component] = mu;
      opPrecompute->alphaOver2Beta[component] = alpha/beta/2;

      digammaAlphaTotal += digamma(alpha);
      logBetaTotal += log(beta);
      invTauTotal += 1/tau;
    }

    double logConstant = mLogDetTransformYpbpr;   // + log |A|
    logConstant -= 2.756815599614018;             // - 1.5*log(2*pi)
    logConstant += digammaAlphaTotal / 2;
    logConstant -= logBetaTotal / 2;
    logConstant -= invTauTotal / 2;
    opPrecompute->logConstant = logConstant;
  }

  // --------------------------------------------------

} // namespace FaceColorModel
