#include <math.h>
#include "seq.h"
#include "quads.h"
#include "tree.h"

#define DEPENDENCY   16.3
#define MINIMUM_NUCS 175


// Branch functions

// Builds a leaf branch that owns `arraysize` freshly constructed QProb
// tables.  Both children start out NULL, both branch-right nucleotides start
// as BASE_UNKNOWN, and the sequence count and split index start at zero.
//
//   arraysize - number of probability tables to allocate; a negative value
//               is treated as zero so it can never become an array bound.
Branch::Branch(int arraysize) {
  
  leaf = true;
  left = NULL;
  right = NULL;
  branchRightOn1 = BASE_UNKNOWN;
  branchRightOn2 = BASE_UNKNOWN;
  count = indexOfNuc = 0;

  size = (arraysize > 0) ? arraysize : 0;

  // "new (QProb*)[size]" is not a valid array new-expression: a
  // parenthesized type-id may not be followed by an array bound.  The
  // unparenthesized form allocates the intended array of QProb pointers.
  probtable = new QProb*[size];
  for (int i=0; i < size; i++)
    probtable[i] = new QProb();
}


// Releases every probability table and, for a non-leaf branch, both child
// subtrees.  A non-leaf branch that is missing a child is reported on cerr,
// but whichever child does exist is still deleted so that a malformed tree
// does not leak the surviving subtree.
Branch::~Branch() {

  for (int i=0; i < size; i++)
    delete probtable[i];
  delete[] probtable;
  
  if (!leaf) {
    if (!( left && right )) cerr << "BAD TREE !!!!!!!!!!!!!!!!!!!!";

    // deleting a NULL pointer is a no-op, so no further checks are needed
    delete left;
    delete right;
  }
}

// Forwards one observation of nucleotide n to the probability table stored
// in slot `index` of this branch.  The index is not range-checked.
void Branch::addToCount(Nucleotide n, int index) {

  QProb &table = *probtable[index];
  table.addToCount(n);
}

// Asks every table owned by this branch to compute its probabilities, then
// does the same for both children when this branch is not a leaf.
void Branch::computeProbs() {

  int slot = 0;
  while (slot < size) {
    probtable[slot]->computeProb();
    slot++;
  }

  // leaves have no subtrees to visit
  if (leaf)
    return;

  left->computeProbs();
  right->computeProbs();
}


// Returns what the table in slot `index` reports as the probability of
// nucleotide n.  The index is not range-checked.
double Branch::getProb(Nucleotide n, int index) {

  QProb &table = *probtable[index];
  return table.getProb(n);
}


// Returns what the table in slot `index` reports as the count of
// nucleotide n.  The index is not range-checked.
double Branch::getCount(Nucleotide n, int index) {

  QProb &table = *probtable[index];
  return table.getCount(n);
}


// Reads a branch (and, recursively, its children) in the layout written by
// the matching operator<<: count, leaf flag and table count; the two
// branch-right nucleotides (as integers) and the split index; one record
// per probability table; then, for a non-leaf, the left and right children.
//
// The table count found in the stream may differ from the one the branch
// was constructed with, so the table array is rebuilt to match before any
// table is read.  If the header cannot be read (or carries a negative table
// count) the stream is marked failed, the tables are left untouched, and a
// branch without two children is forced back to being a leaf.
istream& operator>>(istream& cin, Branch& b) {

  // initialized so that a failed read never casts an indeterminate value
  int temp1 = (int)BASE_UNKNOWN, temp2 = (int)BASE_UNKNOWN;
  int newSize = 0;

  cin >> b.count >> b.leaf >> newSize;
  cin >> temp1 >> temp2 >> b.indexOfNuc;

  if (!cin || newSize < 0) {
    cin.setstate(istream::failbit);
    // a failed read may have clobbered the leaf flag; a branch without
    // both children must not claim to be an internal node
    if (!( b.left && b.right )) b.leaf = true;
    return cin;
  }

  // keep probtable and size in step: build the replacement first, then
  // release the old tables using the old size
  if (newSize != b.size) {
    QProb **fresh = new QProb*[newSize];
    for (int i=0; i < newSize; i++)
      fresh[i] = new QProb();

    for (int i=0; i < b.size; i++)
      delete b.probtable[i];
    delete[] b.probtable;

    b.probtable = fresh;
    b.size = newSize;
  }

  for (int i=0; i < b.size; i++) 
    cin >> *(b.probtable[i]);

  if (!b.leaf) {
    // a child's real table count comes from its own header; this is only
    // the initial allocation and must never be negative
    int childSize = (b.size > 0) ? b.size - 1 : 0;
    if (!b.left)  b.left  = new Branch(childSize);
    if (!b.right) b.right = new Branch(childSize);
    cin >> *(b.left) >> *(b.right);
  }
  b.branchRightOn1 = (Nucleotide)temp1;
  b.branchRightOn2 = (Nucleotide)temp2;
  return cin;
}


// Writes a branch in the layout the matching operator>> expects:
//   line 1: count, leaf flag, table count (tab separated)
//   line 2: both branch-right nucleotides as integers, then the split index
//   then every probability table in slot order,
//   then, for a non-leaf, the left child, a newline and the right child.
ostream& operator<<(ostream& cout, Branch& b) {

  cout << b.count << "\t" << b.leaf << "\t" << b.size;
  cout << endl;

  const int rightOn1 = (int)b.branchRightOn1;
  const int rightOn2 = (int)b.branchRightOn2;
  cout << rightOn1 << "\t" << rightOn2 << "\t";
  cout << b.indexOfNuc << endl;

  int slot = 0;
  while (slot < b.size) {
    cout << *(b.probtable[slot]);
    slot++;
  }

  // leaves have no children to serialize
  if (b.leaf)
    return cout;

  cout << *(b.left);
  cout << "\n";
  cout << *(b.right);
  cout << endl;

  return cout;
}



// Tree functions

// Creates a tree consisting of a single root branch that holds
// maxArraySize probability tables; maxArraySize is also remembered as the
// tree's size.
Tree::Tree(int maxArraySize) {

  root = new Branch(maxArraySize);
  size = maxArraySize;
}


// Destroys the tree by deleting its root branch; the Branch destructor
// takes care of the child branches of a non-leaf.
Tree::~Tree() {

  Branch *top = root;
  delete top;
}


// Reads a tree: first its size, then the root branch (which in turn reads
// its own children).
istream& operator>>(istream& cin, Tree& tree) {

  cin >> tree.size;
  Branch &top = *(tree.root);
  cin >> top;
  return cin;
}


// Writes a tree: its size on a line of its own, then the root branch,
// followed by a final end-of-line.
ostream& operator<<(ostream& cout, Tree& tree) {

  cout << tree.size << "\n";
  Branch &top = *(tree.root);
  cout << top;
  cout << endl;
  return cout;
}


// Computes the probability the tree assigns to one window of `size`
// nucleotides.
//
//   tides       - the window; must hold at least `size` nucleotides
//   probability - output; overwritten with the result
//
// Starting at the root, each internal branch looks at the nucleotide at its
// split index and descends right when it equals one of the branch-right
// nucleotides, left otherwise; each step multiplies the result by
// (child count / parent count).  At the leaf, every position that was not
// used as a split index contributes the leaf's table probability for the
// nucleotide at that position.
//
// A split index outside [0, size) is reported on cerr and the product
// accumulated so far is returned.  A branch whose count is zero gives a
// probability of 0 rather than dividing by zero.
void Tree::computeProb(Nucleotide tides[], double &probability) {

  probability = 1.0;
  Branch *branch = root;

  // keep track of indices we've already seen (heap array: a runtime-sized
  // stack array is a compiler extension, not standard C++)
  const int slots = (size > 0) ? size : 0;
  bool *seen = new bool[slots];
  for (int i=0; i < slots; i++)
    seen[i] = false;
  
  while (!branch->leaf) {
    
    long oldcount = branch->count;
    int index = branch->indexOfNuc;

    if (index < 0 || index >= size) {
      cerr << "Bad index: " << branch->indexOfNuc << endl;
      delete[] seen;
      return;
    }

    // nothing was ever observed at this branch; dividing by its count
    // would produce NaN or infinity
    if (oldcount == 0) {
      probability = 0.0;
      delete[] seen;
      return;
    }
    
    if (branch->branchRightOn1 == tides[index] || 
        (branch->branchRightOn2 != BASE_UNKNOWN &&
         branch->branchRightOn2 == tides[index]))
      branch = branch->right;
    else
      branch = branch->left;
    
    probability *= (double)branch->count / (double)oldcount;
    seen[index] = true;
  }

  // the leaf's tables are packed: slot k belongs to the k-th unseen position
  int counter = 0;
  for (int i=0; i < size; i++) 
    if (!seen[i]) 
      probability *= branch->getProb(tides[i], counter++);

  delete[] seen;
}


// Builds the tree from training data.
//
//   nucs   - consecutive windows of `size` nucleotides each
//   length - number of nucleotides in nucs; a trailing partial window is
//            ignored
//
// First a consensus is computed per position: the most frequent nucleotide
// and, unless the most frequent one is dominant, the second most frequent
// (BASE_UNKNOWN otherwise).  BASE_UNKNOWN inputs are counted as BASE_C.
// The tree is then grown recursively from the root and every branch
// computes the probabilities for its tables.
//
// Does nothing when the tree's size is not positive.
void Tree::createTree(Nucleotide *nucs, long length) {

  // a zero stride would never advance the window loops below
  if (size <= 0)
    return;

  // only whole windows are usable
  long usable = (length > 0) ? length - (length % size) : 0;
  
  // determine the consensus basepairs: two entries per position
  Nucleotide *consensus = new Nucleotide[2*size];
  
  const int cells = 4 * size;
  long *counts = new long[cells]; 
  for (int i=0; i < cells; i++) 
    counts[i] = 0;
  long *tots = new long[size]; 
  for (int i=0; i < size; i++) 
    tots[i] = 0;
  
  // for each set of size nucleotides
  for (long i=0; i < usable; i+=size) {
    
    // for each nucleotide in the set
    for (int j=0; j < size; j++) {
      
      Nucleotide n = nucs[i+j];
      if (n == BASE_UNKNOWN)
        n = BASE_C;
      
      counts[j * 4 + (int)n]++;
      tots[j]++;
    }
  }
  
  for (int i=0; i < size; i++) {

    // largest and second largest count at this position, with the
    // nucleotide each belongs to
    long val1 = -1, val2 = -1;
    int index1 = 0, index2 = 0;

    for (int j=0; j < 4; j++) {
      
      const long c = counts[i*4+j];
      if (c > val1) {
        val2 = val1;
        index2 = index1;
        val1 = c; 
        index1 = j;
      }
      else if (c > val2) {
        val2 = c; 
        index2 = j;
      }
    }
    
    consensus[i*2] = (Nucleotide)index1; 
    // no second consensus nucleotide when the first one is dominant
    if ((double)tots[i] < 1.8 * (double)val1 || val1 > 2 * val2)
      consensus[i*2+1] = BASE_UNKNOWN; 
    else
      consensus[i*2+1] = (Nucleotide)index2; 
  }    
  
  // arrays from new[] must be released with delete[]
  delete[] counts;
  delete[] tots;

  // keep track of indices we've already seen
  bool *seen = new bool[size];
  for (int i=0; i < size; i++)
    seen[i] = false;

  // now let's create the tree
  recursiveFormTree(nucs, usable, root, consensus, seen);
  
  // have each branch compute the probs for its tables
  root->computeProbs();

  delete[] seen;
  delete[] consensus;
}


// Fills in one branch from the windows that reached it and, when
// warranted, splits it into two children and recurses.
//
//   nucs      - consecutive windows of `size` nucleotides
//   length    - number of nucleotides in nucs; a trailing partial window is
//               ignored
//   branch    - the branch to fill in
//   consensus - two consensus nucleotides per position (the second may be
//               BASE_UNKNOWN)
//   seen      - `size` flags marking positions already used as a split
//               index on the path to this branch; modified in place
//
// The split position is the one with the largest chi_squared value.
// Windows matching the consensus at that position go to the right child,
// all others to the left child.  Splitting stops when the branch has a
// single table left, when the best value is below DEPENDENCY, or when
// either side would receive fewer than MINIMUM_NUCS windows.
void Tree::recursiveFormTree(Nucleotide *nucs, long length, Branch *branch, 
                             Nucleotide *consensus, bool seen[]) {

  // nothing sensible can be done without at least one position per window
  if (size <= 0)
    return;
  
  // the values for each possible consensus nucleotide
  double *values = new double[size]; 
  for (int i=0; i < size; i++) 
    values[i] = 0.0;
  
  // do a "chi-squared" computation on the data
  chi_squared(nucs, length, consensus, values);
  
  // find the max value in values
  int max = 0;
  double maxval = values[0];
  for (int i=1; i < size; i++) {
    if (values[i] > maxval) {
      max = i;
      maxval = values[i];
    }
  }
  
  // arrays from new[] must be released with delete[]
  delete[] values;

  // fill in the branch's tables; they are packed, so table k belongs to
  // the k-th position that has not been seen yet
  for (long i=0; i + size <= length; i+=size) {
    int counter = 0;
    for (int j=0; j < size; j++) 
      if (!seen[j]) 
        branch->addToCount(nucs[i+j], counter++);
  }
  
  // set up the branch info
  branch->indexOfNuc = max;
  branch->branchRightOn1 = consensus[2*max];
  branch->branchRightOn2 = consensus[2*max+1];  
  branch->count = length / size;

  // do we want to continue branching now?
  //    3 rules for stopping:
  //  Rule #1: lowest level of tree
  if (branch->size <= 1)
    return;

  //  Rule #2: no dependencies
  if (maxval < DEPENDENCY)
    return;
  

  // go through the nucleotide sets and separate them into two 
  //  subsets: one for each of the branches
  Nucleotide *subnucs1 = new Nucleotide[length];
  Nucleotide *subnucs2 = new Nucleotide[length];
  long sublength1 = 0, sublength2 = 0;

  const int y = 2*max;
  for (long i=0; i + size <= length; i+=size) {
    
    const long x = i+max;
    // if it's the consensus
    if (nucs[x] == consensus[y] || 
        (consensus[y+1] != BASE_UNKNOWN &&
         nucs[x] == consensus[y+1]))
      for (int j=0; j < size; j++) 
        subnucs1[sublength1++] = nucs[i+j];
    else
      for (int j=0; j < size; j++) 
        subnucs2[sublength2++] = nucs[i+j];
  }


  //  Rule #3: not enough sequences
  if (sublength1 / size < MINIMUM_NUCS || 
      sublength2 / size < MINIMUM_NUCS) {
    delete[] subnucs1;
    delete[] subnucs2;
    return;
  }

  seen[max] = true;
  branch->leaf = false;
  branch->left = new Branch(branch->size -1);
  branch->right = new Branch(branch->size -1);
  
  
  // both subtrees need their own copy of seen: each recursion marks
  // further positions in the array it is given
  bool *seencopy = new bool[size];
  for (int i=0; i < size; i++)
    seencopy[i] = seen[i];
  
  recursiveFormTree(subnucs1, sublength1, branch->right, consensus, seen);
  delete[] subnucs1;
  recursiveFormTree(subnucs2, sublength2, branch->left, consensus, seencopy);
  delete[] subnucs2;
  delete[] seencopy;
}


// Gives this tree the same branching structure as `tree`, but with counts
// taken from a different data set, then has every branch compute the
// probabilities for its tables.
//
//   tree   - the tree whose structure is copied; must have the same size
//            as this tree
//   nucs   - consecutive windows of `size` nucleotides
//   length - number of nucleotides in nucs
//
// A NULL source or a source of a different size is reported on cerr and
// leaves this tree untouched: the copy indexes this tree's tables with the
// source's split indices and table counts, so the sizes must agree.
void Tree::copyTree(Tree *tree, Nucleotide *nucs, long length) {

  if (tree == NULL || size <= 0 || tree->size != size) {
    cerr << "copyTree: source tree is missing or has a different size"
         << endl;
    return;
  }

  Branch *thisbranch = root;
  Branch *thatbranch = tree->root;
  
  // keep track of indices we've already seen (heap array: a runtime-sized
  // stack array is a compiler extension, not standard C++)
  bool *seen = new bool[size];
  for (int i=0; i < size; i++)
    seen[i] = false;
  recursiveCopy(thisbranch, thatbranch, nucs, length, seen);
  delete[] seen;
  
  // have each branch compute the probs for its tables
  root->computeProbs();
}


// Copies the split information of `thatbranch` into `thisbranch`, fills
// thisbranch's tables from the given windows, and recurses into both
// children when the branch is not a leaf.
//
//   thisbranch - branch being filled in
//   thatbranch - branch whose leaf flag, table count, split index and
//                branch-right nucleotides are copied
//   nucs       - consecutive windows of `size` nucleotides
//   length     - number of nucleotides in nucs; a trailing partial window
//                is ignored
//   seen       - `size` flags marking positions already used as a split
//                index on the path to this branch; modified in place
//
// Windows whose nucleotide at the split index equals one of the
// branch-right nucleotides go to the right child, all others to the left.
// A split index outside [0, size) is reported on cerr and the branch is
// kept as a leaf instead of indexing out of bounds.
void Tree::recursiveCopy(Branch *thisbranch, Branch *thatbranch, 
                         Nucleotide *nucs, long length, bool seen[]) {
  
  // copy this branch's info over
  thisbranch->leaf = thatbranch->leaf;
  thisbranch->size = thatbranch->size;
  thisbranch->indexOfNuc = thatbranch->indexOfNuc;
  thisbranch->branchRightOn1 = thatbranch->branchRightOn1;
  thisbranch->branchRightOn2 = thatbranch->branchRightOn2;

  // fill this branch's tables; they are packed, so table k belongs to the
  // k-th position that has not been seen yet
  for (long i=0; i + size <= length; i+=size) {
    int counter = 0;
    for (int j=0; j < size; j++) 
      if (!seen[j]) 
        thisbranch->addToCount(nucs[i+j], counter++);
  }
  
  thisbranch->count = length / size; 

  if (thisbranch->leaf) 
    return;

  // same sanity check computeProb applies before using a split index
  const int split = thisbranch->indexOfNuc;
  if (split < 0 || split >= size) {
    cerr << "Bad index: " << thisbranch->indexOfNuc << endl;
    thisbranch->leaf = true;
    return;
  }
  
  seen[split] = true;
  thisbranch->left = new Branch(thisbranch->size -1);
  thisbranch->right = new Branch(thisbranch->size -1);

  // go through the nucleotide sets and separate them into two 
  //  subsets: one for each of the branches
  Nucleotide *subnucs1 = new Nucleotide[length];
  Nucleotide *subnucs2 = new Nucleotide[length];
  long sublength1 = 0, sublength2 = 0;

  for (long i=0; i + size <= length; i+=size) {
    
    const long x = i + split;
    
    // if it's the consensus
    if (thisbranch->branchRightOn1 == nucs[x] || 
        (thisbranch->branchRightOn2 != BASE_UNKNOWN &&
         thisbranch->branchRightOn2 == nucs[x]))
      for (int j=0; j < size; j++) 
        subnucs1[sublength1++] = nucs[i+j];
    else
      for (int j=0; j < size; j++) 
        subnucs2[sublength2++] = nucs[i+j];
  }
  
  // both subtrees need their own copy of seen: each recursion marks
  // further positions in the array it is given
  bool *seencopy = new bool[size];
  for (int i=0; i < size; i++)
    seencopy[i] = seen[i];
  
  recursiveCopy(thisbranch->right, thatbranch->right, subnucs1, sublength1,
                seencopy);
  // arrays from new[] must be released with delete[]
  delete[] subnucs1;
  recursiveCopy(thisbranch->left, thatbranch->left, subnucs2, sublength2,seen);
  delete[] subnucs2;
  delete[] seencopy;
}


// Scores every window position by how strongly "matches the consensus at
// this position" is associated with the nucleotides found at the other
// positions.
//
//   nucs      - consecutive windows of `size` nucleotides
//   length    - number of nucleotides in nucs; a trailing partial window is
//               ignored
//   consensus - two consensus nucleotides per position (the second may be
//               BASE_UNKNOWN)
//   values    - output array of `size` doubles; values[i] receives the sum,
//               over every other position j and every nucleotide, of
//               (Fa-Ea)^2/Ea + (NotFa-NotEa)^2/NotEa, where Fa/NotFa are
//               the observed counts among windows that do / do not match
//               the consensus at i, and Ea/NotEa the counts expected if
//               position j did not depend on position i
//
// BASE_UNKNOWN nucleotides are counted as BASE_C.  Does nothing when the
// tree's size is not positive.
void Tree::chi_squared(Nucleotide *nucs, long length, Nucleotide *consensus, 
                       double *values) {

  // a zero stride would never advance the window loops below
  if (size <= 0)
    return;
  
  // chi-squared output table
  const int cells = size * size;
  double *table = new double[cells];
  for (int i=0; i < cells; i++) 
    table[i] = 0.0;
  
  // number of whole windows (integer division on purpose)
  double SeqTot = length / size;
  double I;
  double *J = new double[4*size];
  double *IandJ = new double[4*size]; 
  
  Nucleotide n1, n2;
  
  // NOTE: this is NOT the most efficient implementation.  
  //  I purposely do this, however, for clarity!
  
  // Step 1: initialize variables
  for (int a=0; a < size; a++) 
    for (int b=0; b < 4; b++)  
      J[a*4+b] = 0.0;
  
  
  // Step 2: calculate how often each basepair occurs in each position

  // for each k = set of nucleotides
  for (long k=0; k + size <= length; k+=size) { 
    // for each l = nucleotides in set
    for (int l=0; l < size; l++) {

      // make sure it's not unknown
      Nucleotide n = nucs[k+l]; 
      if (n == BASE_UNKNOWN)
        n = BASE_C;

      // counts of basepair at a given position
      J [l * 4 + (int)n]++; 
    }
  }
  
  
  // Step 3: for each basepair that fits the given consensus at that
  //   spot, we then calculate how often each basepair occurs in each 
  //   other position

  // for each i = consensus
  for (int i=0; i < size; i++) {
    
    n1 = consensus[2*i];
    n2 = consensus[2*i+1];
    
    // Step 3a: reset vars
    I = 0.0;
    for (int j=0; j < size; j++) {
      for (int k=0; k < 4; k++) 
        IandJ[j*4+k] = 0.0;
    }
    
    // Step 3b: find consensus matches
    // for each k = set of nucleotides
    for (long k=0; k + size <= length; k+=size) { 
      
      // does I = consensus?
      if (nucs[k+i] == n1 ||
          (n2 != BASE_UNKNOWN && nucs[k+i] == n2)) {
        
        I++;
        
        // Step 3c: for each consensus match, count other pairs
        // for each l = nucleotides in set
        for (int l=0; l < size; l++) {
          Nucleotide n = nucs[k+l]; 
          if (n == BASE_UNKNOWN)
            n = BASE_C;
          IandJ [l * 4 + (int)n]++; 
        }
      }
    }

    // Step 4: do the math for each consensus
    // for each j = other positions
    for (int j=0; j < size; j++) {
      
      if (i == j || I == 0)
        table[i*size+j] = 0.0;
      
      else {
        
        double total = 0.0;
        
        // for each k = all base pairs
        for (int k=0; k < 4; k++) {

          double Fa = IandJ[j*4+k];
          double NotFa = J[j*4+k] - Fa;
          double Ea = J[j*4+k] * (I / SeqTot);
          double NotEa = J[j*4+k] * ((SeqTot - I) / SeqTot);

          // degenerate expectations carry no information (and would
          // divide by zero), so they contribute nothing
          if (Ea == 0.0 || Ea == SeqTot || NotEa == 0.0 || NotEa == SeqTot)
            continue;

          // plain multiplication instead of pow(x, 2)
          const double dFa = Fa - Ea;
          const double dNotFa = NotFa - NotEa;
          total += (dFa * dFa / Ea) + (dNotFa * dNotFa / NotEa);
        }
        
        table[i*size+j] = total;
      }
    }
  }
  
  for (int i=0; i < size; i++) {
    values[i] = 0.0;
    for (int j=0; j < size; j++) 
      values[i] += table[i*size+j];
  }
  
  // arrays from new[] must be released with delete[]
  delete[] J;
  delete[] IandJ;
  delete[] table;
}



