/********************************************************************************/
/* Statistics routines, copied from pattern_arraymodule.c.  If you find a
   mistake here, check for the same mistake there, too. */
/********************************************************************************/

#include "Python.h"
#include <assert.h>
#include <math.h>

/********************************************************************************/
/* Memoize log-factorials (log n!), to speed things up a bit. */
/********************************************************************************/

float *stats_factorials;

float stats_factorial(
		 /* Number to take factorial of. */
		 int n
		 ) {

  /********************************************************************************/
  /* Return log(n!) */
  /********************************************************************************/

  assert(n < 500000);
  return stats_factorials[n];
}

/* Return the natural log of the binomial coefficient C(total, subset),
   i.e. the log of the number of subset-element subsets of a
   total-element set:

     log(total!) - (log(subset!) + log((total - subset)!))

   Each log-factorial is obtained from stats_factorial(). */
float stats_choose(
		   /* Total set size */
		   int total,

		   /* Subset size */
		   int subset
		   ) {
  /* Number of elements left out of the subset. */
  const int left_out = total - subset;

  return stats_factorial(total)
    - (stats_factorial(subset) + stats_factorial(left_out));
}

/* Return the log of one binomial probability term,
     log( C(numevents, count) * p^count * (1-p)^(numevents-count) ),
   where the caller passes logprob == log(p) and logcompprob == log(1-p).

   Deliberately not declared `inline`: under C99/C11 rules a bare `inline`
   definition (without `static` or `extern`) provides no external
   definition.  Any call the compiler chose not to inline (e.g. at -O0)
   then failed to link with "undefined reference".  A plain definition
   links in every dialect, and an optimizing compiler may still inline it. */
float stats_log_binomial_term(int count, int numevents,
			      float logprob, float logcompprob) {
  return logprob*count + logcompprob*(numevents-count) +
    stats_choose(numevents, count);
}

float stats_binomial_likelihood(int count, int numevents, float prob) {

  /********************************************************************************/
  /* Return the tail-likelihood of seeing 'count' events of
     probability 'prob' in a sample of size numevents.  With
     X ~ Binomial(numevents, prob) the result is
       P(X <= count)   if count < numevents*prob,
       P(X >= count)   otherwise.
     Expects 0 <= count <= numevents < 500000 and 0 <= prob <= 1; these
     are not re-checked here. */
  /********************************************************************************/

  int count_idx, computing_complement;

  /* Integer part of the expected count numevents*prob. */
  int exp_floor;

  /* When computing the odds for each count, want to move away from
     the expected count, so that probabilities are decreasing, and
     know we can abandon the loop when an individual prob is too small
     to matter.  These are the boundaries for the two loops, from the
     expected count down, and from the expected count up. */
  int down_count_start, down_count_end, up_count_start, up_count_end;
  float logprob, logcompprob, exp_cnt, log_count_odds;

  /* Accumulate in double: the terms are summed and, in the complement
     case, subtracted from 1, and both steps lose precision in float. */
  double rv, tail;

  /* Boundary cases: the outcome is certain, and log(prob) or
     log(1-prob) below would be -inf. */
  if (prob == 0) {
    return (count == 0) ? 1 : 0;
  }
  if (prob == 1) {
    return (count == numevents) ? 1 : 0;
  }

  exp_cnt = numevents * prob;
  exp_floor = (int)exp_cnt;

  /* Decide which tail to compute (the one with the least number of
     counts), and whether to sum it directly or as 1 minus the other
     tail. */
  if (count < exp_cnt) {
    if (count < (numevents/2)) {
      /* Direct: P(X <= count), summed from count down to 0. */
      down_count_start = count;
      down_count_end = 0;

      /* It's all covered in the down-count.  Make the up-count
	 trivial. */
      up_count_start = 1;
      up_count_end = -1;
      computing_complement = 0;
    } else {
      /* Complement: P(X <= count) = 1 - P(X >= count+1).  Because
	 count < exp_cnt we have count <= exp_floor, so the two loops
	 together visit count+1 .. numevents exactly once. */
      down_count_start = exp_floor;
      down_count_end = count+1;
      up_count_start = exp_floor+1;
      up_count_end = numevents;
      computing_complement = 1;
    }
  } else {
    if (count < (numevents/2)) {
      /* Complement: P(X >= count) = 1 - P(X <= count-1).  count may
	 equal exp_cnt exactly (e.g. numevents=8, prob=0.25, count=2).
	 The down loop must then start at count-1.  Otherwise the term
	 for 'count' itself would be subtracted, and the result would be
	 P(X > count). */
      down_count_start = (exp_floor < count) ? exp_floor : count-1;
      down_count_end = 0;
      up_count_start = down_count_start+1;
      up_count_end = count-1;
      computing_complement = 1;
    } else {
      /* Direct: P(X >= count), summed from count up to numevents. */
      up_count_start = count;
      up_count_end = numevents;

      /* It's all covered in the up-count.  Make the down-count
	 trivial. */
      down_count_start = -1;
      down_count_end = 1;
      computing_complement = 0;
    }
  }

  rv = 0;
  logprob = log(prob);
  logcompprob = log(1-prob);

  /* Each loop walks away from the expected count, where terms shrink.
     Once a term falls below exp(-300) the rest are treated as
     negligible. */
  for (count_idx = up_count_start; count_idx <= up_count_end; count_idx++) {
    log_count_odds = stats_log_binomial_term(count_idx, numevents,
					     logprob, logcompprob);
    if (log_count_odds < -300) {
      break;
    }
    rv += exp(log_count_odds);
  }

  for (count_idx = down_count_start; count_idx >= down_count_end; count_idx--) {
    log_count_odds = stats_log_binomial_term(count_idx, numevents,
					     logprob, logcompprob);
    if (log_count_odds < -300) {
      break;
    }
    rv += exp(log_count_odds);
  }

  tail = computing_complement ? 1 - rv : rv;

  /* Rounding in the log-factorial values may push the sum slightly past
     1; keep the result a valid probability. */
  if (tail < 0) {
    tail = 0;
  } else if (tail > 1) {
    tail = 1;
  }
  return (float)tail;
}

PyObject *stats_binomial_likelihood_wrapper(PyObject *self, PyObject *args) {

  /* Python entry point: binomial_likelihood(count, numevents, prob).
     Validates the arguments, raising ValueError on bad input, and
     otherwise returns stats_binomial_likelihood(count, numevents, prob)
     as a Python float. */

  /* Default values to shut compiler up */
  int count=0, numevents=0;
  float prob=0, rv;

  if (!PyArg_ParseTuple(args, "iif", &count, &numevents, &prob)) {
    return NULL;
  }

  /* Every failed check must return NULL immediately.  Setting an
     exception and carrying on would index the log-factorial table out of
     bounds, then hand Python a result with an exception still pending. */
  if (numevents < 0) {
    PyErr_SetString(PyExc_ValueError, "numevents should be non-negative");
    return NULL;
  }
  /* The log-factorial table built by init_statistics has 500000 entries. */
  if (numevents >= 500000) {
    PyErr_SetString(PyExc_ValueError,
		    "numevents should be less than 500000");
    return NULL;
  }
  if ((count < 0) || (count > numevents)) {
    PyErr_SetString(PyExc_ValueError,
		    "count should be between 0 and numevents, inclusive");
    return NULL;
  }
  /* Written as a negated range test so that NaN is rejected too. */
  if (!((prob >= 0) && (prob <= 1))) {
    PyErr_SetString(PyExc_ValueError,
		    "prob should be between 0 and 1");
    return NULL;
  }

  rv = stats_binomial_likelihood(count, numevents, prob);
  return Py_BuildValue("f", rv);
}

PyObject *stats_log_choose(PyObject *self, PyObject *args) {

  /* Python entry point: log_choose(m, n) returns log(C(m, n)), the
     natural log of the number of n-element subsets of an m-element set,
     as a Python float.  Raises ValueError on out-of-range arguments. */

  int m=0, n=0;
  if (!PyArg_ParseTuple(args, "ii", &m, &n)) {
    return NULL;
  }
  if ((m < 0) || (n < 0) || (n > m)) {
    PyErr_SetString(PyExc_ValueError, "Should have m >= n >= 0");
    return NULL;
  }
  /* stats_choose reads the log-factorial table at m, n and m-n.  The
     table built by init_statistics holds only indices below 500000. */
  if (m >= 500000) {
    PyErr_SetString(PyExc_ValueError, "m should be less than 500000");
    return NULL;
  }
  return Py_BuildValue("f", stats_choose(m, n));
}

/* Method table for the _statistics module.  Each entry is
   {Python name, C function, calling convention, docstring}; the all-NULL
   entry terminates the table. */
static PyMethodDef _statistics_methods[] = {
  {"binomial_likelihood", stats_binomial_likelihood_wrapper, METH_VARARGS,
   "binomial_likelihood(count, numevents, prob) -> float\n\n"
   "Tail probability of a Binomial(numevents, prob) variable at count:\n"
   "P(X <= count) if count < numevents*prob, otherwise P(X >= count)."},
  {"log_choose",          stats_log_choose,                  METH_VARARGS,
   "log_choose(m, n) -> float\n\n"
   "Natural log of the binomial coefficient C(m, n); requires m >= n >= 0."},
  {NULL, NULL, 0, NULL}
};

void init_statistics(void){

  /* Iterates over factorials to compute. */
  int factorial_idx;

  (void)Py_InitModule("_statistics", _statistics_methods);
  
  /* Record a bunch of factorials here, to speed things up. */
  stats_factorials = (float *)calloc(500000, sizeof(float));
  assert(stats_factorials); /* failure means memory overflow. */

  stats_factorials[0] = 0;
  for (factorial_idx = 1; factorial_idx < 500000; factorial_idx++) {
    stats_factorials[factorial_idx] = stats_factorials[factorial_idx-1] + \
                                log(factorial_idx);
  }
}
