// C++ code file

#include "stateval.h"

#include <math.h>
#include <assert.h>
#include <fstream.h>
#include <stdlib.h>
#include <sys/types.h>
#include <netinet/in.h>


// Creates an evaluator with no site models loaded.  All three model pointers
// start out null: the destructor only releases non-null models, and each
// burgeEval* method returns LOW_SCORE while its model pointer is null.
StatEvaluator::StatEvaluator() {
  ATGsite = NULL;
  GTsite  = NULL;
  AGsite  = NULL;
}


// Releases the AG, GT and ATG site models owned by this evaluator (in that
// order).  Deleting a null pointer is a no-op, so models that were never
// loaded need no special handling.
// NOTE(review): the class owns these raw pointers; if the header does not
// disable copying, a copied StatEvaluator would delete them twice - confirm.
StatEvaluator::~StatEvaluator() {
  delete AGsite;
  delete GTsite;
  delete ATGsite;
}


// Acceptor ("AG") post-filter mode used by StatEvaluator::evaluateStarts:
//   0       - "left rule" penalty relative to the previous candidate;
//   1, 2, 3 - hard filters that set rejected candidates to -Infinity;
//   other   - no post-filter, the raw burgeEvalAG score is kept.
int LENIENT = 0; // 0, 1, 2, or 3

/*
 * Scores every candidate 3' splice (acceptor) site in `seq`.
 *
 * A candidate is a position i in [41, seql-20] whose two preceding bases are
 * "AG" (seq[i-2] == A, seq[i-1] == G); its score is written to scores[i].
 * Each candidate first receives score3 (the branch-point model log ratio)
 * from burgeEvalAG(); the global LENIENT then selects a post-filter:
 *   0       - "left rule": a candidate scoring below the previous candidate
 *             is lowered by that shortfall once more;
 *   1, 2, 3 - hard filters that set rejected candidates to -Infinity, based
 *             on score thresholds and the distance to earlier candidates.
 * Any other LENIENT value leaves the burgeEvalAG score unchanged.
 *
 * @param seq     sequence to scan
 * @param scores  output array; its first seq->get_length() entries are
 *                written, with -Infinity at every non-candidate position
 * @return always 1
 */
int StatEvaluator::evaluateStarts(FilterSequence *seq, double *scores) {
  int i, seql = seq->get_length();
  
  arrayInit(-Infinity, scores, seql);
  int *startptr = new int[seql];  arrayZero(startptr, seql);
  int *leftnuc  = new int[seql];  arrayZero(leftnuc,  seql);
  int *rightnuc = new int[seql];  arrayZero(rightnuc, seql);
  int startcnt = 0;

  // Collect the candidates together with the base preceding the AG
  // (leftnuc) and the first base following it (rightnuc).
  for (i=41; i<=seql-20; ++i)
    if (seq->get(i-2)==BASE_A && seq->get(i-1)==BASE_G) {
      startptr[startcnt] = i;
      leftnuc[startcnt]  = seq->get(i-3);
      rightnuc[startcnt] = seq->get(i);
      startcnt++;
    }

  // NOTE(review): pt is only used by the disabled legacy rules at the end of
  // the loop below.  It is kept because PyrTable's constructor may have side
  // effects; confirm and remove it if it does not.
  PyrTable pt;
  
  for (i=0; i<startcnt; ++i) {
    const int pos = startptr[i];
    double singlesScore, pairsScore;
    burgeEvalAG(seq, pos, singlesScore, pairsScore, scores[pos]);

    switch (LENIENT) {
    case 0:
      if (i) {
        // Compute the left-rule penalty once, before applying it, so the
        // verbose log reports the amount actually added to the score
        // (re-evaluating it after the update would report twice the value).
        const double penalty = MIN(0.0, scores[pos] - scores[startptr[i-1]]);
        scores[pos] += penalty;
        if (VERBOSE && penalty < -1)
          cout << "\t\t leftRule: " << pos << " plus " << penalty << endl;
      }
      break;

    case 1:
      if (scores[pos] < -4.8) scores[pos] = -Infinity;
      if (i >= 1 && scores[pos] < scores[startptr[i-1]] &&
          pos - startptr[i-1] < 18)
        scores[pos] = -Infinity;
      break;

    case 2:
      if (scores[pos] < -3.5) scores[pos] = -Infinity;
      if (pairsScore < -3.5)  scores[pos] = -Infinity;
      if (i >= 1 && scores[pos] < scores[startptr[i-1]] &&
          pos - startptr[i-1] < 20 &&
          scores[pos] < 5)
        scores[pos] = -Infinity;
      break;

    case 3:
      if (scores[pos] < -4.5) scores[pos] = -Infinity;
      if (i >= 1 && scores[pos] < scores[startptr[i-1]] &&
          pos - startptr[i-1] < 18)
        scores[pos] = -Infinity;

      // This candidate and the two before it lie within 14 bp: keep it only
      // if it scores highly or has a C before the AG and an A/G after it.
      if (i > 1 && (pos - startptr[i-2] <= 14) && scores[pos] < 5 &&
          (leftnuc[i] != BASE_C || (rightnuc[i] != BASE_A && rightnuc[i] != BASE_G)))
        scores[pos] = -Infinity;

      if (i >= 1 && leftnuc[i] == BASE_A && pos - startptr[i-1] < 30 &&
          scores[pos] < 0)
        scores[pos] = -Infinity;
      break;

    default:
      // Unknown mode: keep the raw burgeEvalAG score.
      break;
    }

#if 0
    // Legacy rule-based acceptor filter.  It used to follow an unconditional
    // `continue;` and therefore could never execute; it is fenced off here so
    // the compiler no longer sees unreachable code.  Notes for anyone
    // re-enabling it:
    //  - `leftnuc[i+1] == BASE_C | leftnuc[i+1] == BASE_T` uses bitwise `|`
    //    (same result as `||` for these bool operands);
    //  - `seq->get(i+1)` indexes the sequence by the candidate number i, not
    //    by the position startptr[i] - probably unintended, confirm;
    //  - the `else if` after `if (pt.countAG(...) > 2)` binds to that inner if.
    if (leftnuc[i] == BASE_G) {
      int pyrt = pt.countPyr(seq,startptr[i]);
      if (pyrt < 15) {
	GAGbadPyr++;
	continue;
      }
    }
    if (i>1 && (startptr[i] - startptr[i-2] <= 14) &&
	(leftnuc[i] != BASE_C || (rightnuc[i] != BASE_A && rightnuc[i] != BASE_G))) {
      
      continue;
    }
    
    if (i < startcnt-1 && (leftnuc[i] == BASE_A || leftnuc[i] == BASE_G) && 
	(leftnuc[i+1] == BASE_C | leftnuc[i+1] == BASE_T) &&
	startptr[i+1] - startptr[i] < 5) {
      
      continue;
    }
    if (i>0) {
      if ((leftnuc[i] == BASE_A || leftnuc[i] == BASE_G) &&
	  ( ( startptr[i] - startptr[i-1] <= 17)                               ||
	    
	    ( startptr[i] - startptr[i-1] <= 24 && 
	      ( (leftnuc[i-1] == BASE_C && pt.countPyr(seq,startptr[i]) < 13)  ||
		(rightnuc[i] == BASE_C || rightnuc[i] == BASE_T))))) {
	continue;
      }
      if (leftnuc[i] == BASE_T && 
	  ( (startptr[i] - startptr[i-1] <= 12)            ||
	    (startptr[i] - startptr[i-1] <= 17 &&
	     rightnuc[i] == BASE_C)) ) {
	continue;
      }
      if (leftnuc[i] == BASE_C)
	if (pt.countAG(seq,startptr[i]) > 2) {
	  continue;
	}
	else if (rightnuc[i] != BASE_G && rightnuc[i] != BASE_A && leftnuc[i-1] == BASE_C &&
		 rightnuc[i-1] == BASE_G &&
		 (startptr[i]-startptr[i-1] < 20)) {
	  continue;
	}
      scores[startptr[i]] = pt.countPyr(seq,startptr[i]);
      
      if (scores[startptr[i]] < 8) {
	scores[startptr[i]] = -Infinity;
	continue;
      }
      else if (scores[startptr[i]] < 12 && (leftnuc[i] == BASE_A || leftnuc[i] == BASE_G))  {
	scores[startptr[i]] = -Infinity;
	continue;
      }
      else if (scores[startptr[i]] < 9  && leftnuc[i]  == BASE_T)           {
	scores[startptr[i]] = -Infinity;
	continue;
      }  
      else if (scores[startptr[i]] < 11 && rightnuc[i] == BASE_C)            {
	scores[startptr[i]] = -Infinity;
	continue;
      }  
      else if (scores[startptr[i]] < 11 && rightnuc[i] == BASE_T
	       && pt.countGA(seq,startptr[i]) > 0 && leftnuc[i] != BASE_C)   {
	scores[startptr[i]] = -Infinity;
	continue;
      }  
      else if (scores[startptr[i]] < 15 && leftnuc[i]  == BASE_A && 
	       (rightnuc[i] == BASE_C || 
		(rightnuc[i] != BASE_G && seq->get(i+1) == BASE_A)))          {
	scores[startptr[i]] = -Infinity;
	continue;
      }
      
      if (scores[startptr[i]] > -Infinity) {
	switch (rightnuc[i]) {
	case BASE_C: scores[startptr[i]] = 1; break;
	case BASE_T: scores[startptr[i]] = 2; break;
	case BASE_A: scores[startptr[i]] = 3; break;
	case BASE_G: scores[startptr[i]] = 4; break;
	}
	switch (leftnuc[i]) {
	case BASE_G: scores[startptr[i]]+=  0; break;
	case BASE_A: scores[startptr[i]]+=  4; break;
	case BASE_T: scores[startptr[i]]+=  8; break;
	case BASE_C: scores[startptr[i]]+= 12; break;
	}
      }
    }
#endif
  }
  delete[] startptr;
  delete[] leftnuc;
  delete[] rightnuc;

  return 1;
}

/*
 * Scores every candidate 5' splice (donor) site in `seq`.
 *
 * A candidate is a position i in [41, seql-20) that is the last base before
 * either a "GT" (seq[i+1] == G, seq[i+2] == T) or a "GC" occurring in the
 * single hard-coded context AAG|GCAAGTGG.  The candidate's score is written
 * to scores[i].  GT candidates are scored with score2 from burgeEvalGT().
 * GC candidates get the fixed score 5.  Every candidate after the first then
 * gets the "left rule": if it scores below the previous candidate, it is
 * lowered by that shortfall once more.
 *
 * The large commented-out sections are earlier experimental rule sets
 * ("exonLengths", "Removed January 16").  They reference names that are not
 * declared in this function.
 *
 * @param seq     sequence to scan
 * @param scores  output array; its first seq->get_length() entries are
 *                written, with -Infinity at every non-candidate position
 * @return always 1
 */
int StatEvaluator::evaluateStops(FilterSequence *seq, double *scores) {
  int i, seql = seq->get_length();
  
  arrayInit(-Infinity, scores, seql);
  int *stopptr = new int[seql];  arrayZero(stopptr, seql);
  //  int *ham0    = new int[seql];  arrayZero(ham0,    seql);
  //  int *ranks   = new int[seql];  arrayZero(ranks,   seql);
  int stopcnt  = 0;

  //  HamEvaluator hEval0(0), hEval1(1), hEval2(2);
  //  rankTable rt("/data/tables/SpliceSites/rank_istart.bin");

  for (i=41; i<seql-20; ++i) {
    if (seq->get(i+1)==BASE_G && seq->get(i+2)==BASE_T) {
      stopptr[stopcnt] = i;
      /* exonLengths
      ham0[stopcnt]    = hEval0.evaluate(seq, i+1);
      ranks[stopcnt]   = rt.rank(seq, i);
      */
      stopcnt++;
    }
    if (seq->get(i+1)==BASE_G && seq->get(i+2)==BASE_C) {
      // GC donors are only accepted in the exact context AAG|GCAAGTGG.
      if (seq->get(i+3) == BASE_A && seq->get(i+4) == BASE_A && seq->get(i+5) == BASE_G &&
	  seq->get(i+6) == BASE_T && seq->get(i+7) == BASE_G && seq->get(i+8) == BASE_G &&
	  seq->get(i)   == BASE_G && seq->get(i-1) == BASE_A && seq->get(i-2) == BASE_A) {
	stopptr[stopcnt] = i;
	stopcnt++;
      }
      //      cout << "found gc splice site in position " << i+1 << endl;
    }
  }
  for (i=0; i<stopcnt; ++i) {
    int pos = stopptr[i];
    
    /*exonLengths
    int dist[8];
    for (j=1; j<=8; ++j)
      if (i-j>=0) dist[j-1] = pos - stopptr[i-j];
      else dist[j-1] = 10000;
    */

    /* exonLengths
    int isTrue = 0;
    for (j=0; j<tsc; ++j)
      if (trueStops[j] == pos) {
	isTrue = 1;
      }
    */

    // GT candidates get the decision-tree score; GC candidates a fixed 5.
    double burgeScore1=-1000, burgeScore2=-1000;
    if (seq->get(pos+2) == BASE_T) burgeEvalGT(seq, pos, burgeScore1, burgeScore2);
    else burgeScore2 = 5;
    /*
      if (burgeScore2 < -5.5) { // Was -6.2;
      if (isTrue) GTfn << stopptr[i] << "\t Burge Rule" << endl;
      GTGoodRules[21]++;
      
      continue;
      }
    */
    
    /* exonLengths
    if (i < stopcnt-1 && stopptr[i+1] - pos < 60 && 
	( (ranks[i+1] == 1 && ham0[i] > 9) ||
	  (ranks[i+1] == 2 && ranks[i] > 2))) {
      if (isTrue)
	GTfn << stopptr[i] << "\t rule 1" << endl;
      GTGoodRules[0]++;
     
      continue;
    }
    */
    /* Removed January 16;
       if (i < stopcnt-2 && stopptr[i+2] - pos < 60 && 
       (ranks[i+2] == 2 && ranks[i] > 2)) {
       if (isTrue)
       GTfn << stopptr[i] << "\t rule 2" << endl;
       GTGoodRules[1]++;
       
       continue;
       }
    */
    /* exonLengths
    if (i < stopcnt-1 && i > 0 && stopptr[i+2] - pos < 60 
	&& dist[0] < 60 && ham0[i] > 11 &&
	ranks[i+1] < ranks[i] && ranks[i-1] < ranks[i]) {
      if (isTrue)
	GTfn << stopptr[i] << "\t rule 3" << endl;
	GTGoodRules[2]++;
     
      continue;
    }
    verbC("0, ",isTrue);
    if (i>0 && dist[0] < 30  && ham0[i] > ham0[i-1] && ranks[i] > 16) {
      if (isTrue)
	GTfn << stopptr[i] << "\t rule 4" << endl;
	GTGoodRules[3]++;
     
      continue;
    }

    if (i>1 && dist[1] < 80 && ham0[i] > ham0[i-2] && ranks[i] > 30) {
      if (isTrue) 
	GTfn << stopptr[i] << "\t rule 5" << endl;
	GTGoodRules[4]++;
     
      continue;
    }

    if (i>2 && dist[2] < 140 && ham0[i] > ham0[i-3] && ranks[i] > 32) {
      if (isTrue)
	GTfn << stopptr[i] << "\t rule 6" << endl;
	GTGoodRules[5]++;
     
      continue;
    }

    if (i>3 && dist[3] < 60  && ham0[i] > ham0[i-4] && ranks[i] > 32) {
      if (isTrue)
	GTfn << stopptr[i] << "\t rule 7" << endl;
	GTGoodRules[6]++;
     
      continue;
    }

    if (i>4 && dist[4] < 80  && ham0[i] > ham0[i-5] && ranks[i] > 40) {
      if (isTrue)
	GTfn << stopptr[i] << "\t rule 8" << endl;
	GTGoodRules[7]++;
     
      continue;
    }
    verbC("5, ",isTrue);
    if (i>5 && dist[5] < 120 && ham0[i] > ham0[i-6] && ranks[i] > 35) {
      if (isTrue)
	GTfn << stopptr[i] << "\t rule 9" << endl;
	GTGoodRules[8]++;
     
      continue;
    }
    */
    /* removed January 16;
       if (i>6 && dist[6] < 80 && ham0[i] > ham0[i-7] && ranks[i] > 16) { // was dist[6]<150: relaxed dec 16;
       if (isTrue) 
       GTfn << stopptr[i] << "\t rule 10" << endl;
       GTGoodRules[9]++;
       continue;
       }
       
    */
    /* exonLengths
    if (i>7 && dist[7] < 150 && ham0[i] > ham0[i-8] && ranks[i] > 24) {
      if (isTrue)
	GTfn << stopptr[i] << "\t rule 11" << endl;
      GTGoodRules[10]++;
      
      continue;
    }

    int ham1 = hEval1.evaluate(seq,pos+1);
    if (ham0[i] > 14 && ham1 > 0) {
      if (isTrue)
	GTfn << stopptr[i] << "\t rule 12" << endl;
	GTGoodRules[11]++;
     
      continue;  // Rule 2;
    } 
    if (minCluster > 2) {
      if (isTrue)
	GTfn << stopptr[i] << "\t rule 13" << endl;
      GTGoodRules[12]++;
      
      continue;  // Rule 6.a;
    }
    int    ham2       = hEval2.evaluate(seq,pos+1);
    double score12    = smallEvaluate(seq, pos);
    assert(score12 < 100000);

    if (score12 < -3.2 && ham0[i] > 10 && ham1 > 0
	&& ham2 > 0 && minCluster > 0 && ranks[i] > 10 ) {
      if (isTrue)
	GTfn << stopptr[i] << "\t rule 14" << endl;
      GTGoodRules[13]++;
      continue;  // Rule 3;
    } 
    double score80    = largeEvaluate(seq,pos);
    if (ranks[i] > 8 && ham0[i] > 7 && 
	minCluster > 1 && score12 < 0) {
      if (isTrue)
	GTfn << stopptr[i] << "\t rule 15" << endl;
      GTGoodRules[14]++;
      
      continue;  // Rule 5;
    }
    if (score80 > -Infinity && score80 < -47 && ham0[i] > 9) {
      if (isTrue)
	GTfn << stopptr[i] << "\t rule 16" << endl;
      GTGoodRules[15]++;
      
      continue;  // Rule 4;
    }
    if (ranks[i] >= 45 && score80 < 40 && ham0[i] > 9 &&
	minCluster > 0 && ham2 > 0) {
      if (isTrue)
	GTfn << stopptr[i] << "\t rule 17" << endl;
      GTGoodRules[16]++;
      
      continue;   // Old Rule, added;
    }
    if (score80 > -Infinity && score80 < -60) {
      if (isTrue)
	GTfn << stopptr[i] << "\t rule 18" << endl;
      GTGoodRules[17]++;
      
      continue;
    }
    */
    /* Removed January 16;
       if (ranks[i] > 5 && score12 > -Infinity && score12 < -6.4) {
       if (isTrue)
       GTfn << stopptr[i] << "\t rule 19" << endl;
       GTGoodRules[18]++;
       
       continue;
       }
    */
    /* Removed January 16;
       if (score80 < -27 && ranks[i] > 5 && ham0[i] >= 14) {
       if (isTrue)
       GTfn << stopptr[i] << "\t rule 20" << endl;
       GTGoodRules[19]++;
       
       continue;
       }
    */
    /* exonLengths
    if (score80 == -Infinity && minCluster > 0 &&
	ham0[i] > 9 && ranks[i] > 5) {
      if (isTrue)
	GTfn << stopptr[i] << "\t rule 21" << endl;
	GTGoodRules[20]++;
     
      continue;
    }
    verbC("14, ",isTrue);

    */
    scores[pos] = burgeScore2;
    if (i) {
      // Compute the left-rule penalty once, before applying it, so the
      // verbose log reports the amount actually added to the score
      // (re-evaluating it after the update would report twice the value).
      const double penalty = MIN(0.0, scores[pos] - scores[stopptr[i-1]]);
      scores[pos] += penalty;
      if (VERBOSE && penalty < -1)
	cout << "leftRule for GT: " << pos << " minus " << penalty << endl;
    }
  }
  /* exonLengths
  for (i=0; i<=21; ++i) {
    cout << GTGoodRules[i] << " ";
    GTeff << GTGoodRules[i] << " ";
  }
  
  GTeff << endl;
  */
  //  cout << endl;

  delete[] stopptr;
  //  delete[] ham0;
  //  delete[] ranks;

  return 1;
}

double StatEvaluator::burgeEvalAG(char *seqchar, int pos, double &score1, double &score2, double &score3) {
  
  // if the first bases match, let's compute...
  double sprob = 1.0, pprob = 1.0, bprob = 1.0;
  double nsprob = 1.0, npprob = 1.0, nbprob = 1.0;
  int counter = 0;
  Nucleotide previous_base = c2b(seqchar[pos - 39]);
  
  // figure out the probabilities
  for (long index = pos - 38; index < pos + 3; index++) {
    
    Nucleotide base = c2b(seqchar[index]);
    
    sprob *= AGsite->singles[counter]->getProb(base);
    nsprob *= AGsite->Nsingles[counter]->getProb(base);
    
    // first pairs entry isn't really a pair -- it's a single
    if (counter == 0) {
      
      double temp = 0.0, ntemp = 0.0;
      for (int j=0; j < 4; j++) {
	temp += AGsite->pairs[0]->getProb((Nucleotide)j, base);
	ntemp += AGsite->Npairs[0]->getProb((Nucleotide)j, base);
      }
      
      pprob = temp;
      npprob = ntemp;
    }
    else {
      pprob *= AGsite->pairs[counter]->getProb(previous_base, base);
      npprob *= AGsite->Npairs[counter]->getProb(previous_base, base);
    }
    
    // *** DO WE DO THE SAME THING HERE AS WE DID FOR THE FIRST PAIRS
    // ***  ENTRY???  MAYBE WE SHOULD....
    if (counter < 18) { // branchpoint region
      Nucleotide otherbase = c2b(seqchar[index-2]);
      bprob *= AGsite->branchpoint[counter]->getProb(otherbase, previous_base,
						     base);
      nbprob *= AGsite->Nbranchpoint[counter]->getProb(otherbase, previous_base,
						       base);
    }
    else {
      bprob *= AGsite->branchnormal[counter-18]->getProb(previous_base, base);
      nbprob *= AGsite->Nbranchnormal[counter-18]->getProb(previous_base, base);
    }
    
    counter++;
    previous_base = base;
  }
  
  score1 = log(sprob / nsprob);
  score2 = log(pprob / npprob);
  score3 = log(bprob / nbprob);
  
  return score3;
}

/*
 * Burge-style acceptor-site scores for the candidate at `pos`, i.e. the
 * position directly after an "AG" (seq[pos-2] == A, seq[pos-1] == G).
 * Bases seq[pos-38 .. pos+2] are scored, and seq[pos-39] / seq[pos-40] are
 * read as context for the first position.
 *
 * Each score is the log ratio of an AGsite model to its N* counterpart:
 *   score1 - per-position "singles" tables;
 *   score2 - "pairs" tables (previous base, base);
 *   score3 - "branchpoint" tables over (base[-2], base[-1], base) for the
 *            first 18 window positions, "branchnormal" pairs afterwards.
 *
 * All three outputs are set to LOW_SCORE when no AG model is loaded, the
 * window does not fit in the sequence, or pos-2/pos-1 are not "AG".
 *
 * @return score3
 */
double StatEvaluator::burgeEvalAG(FilterSequence *seq, int pos, double &score1, double &score2, double &score3) {
  
  // Range checks come first so seq->get() is never called out of range.  The
  // first iteration reads seq->get((pos-38)-2) == seq->get(pos-40), so pos
  // must be at least 40 (a bound of 39 let seq->get(-1) through).  The last
  // iteration reads seq->get(pos+2).
  if (!AGsite ||
      pos < 40 || pos > seq->get_length()-3 ||
      seq->get(pos-2) != BASE_A || seq->get(pos-1) != BASE_G)
    {
      score1 = score2 = score3 = (double)LOW_SCORE;
      return score3;
    }
  
  double sprob = 1.0, pprob = 1.0, bprob = 1.0;
  double nsprob = 1.0, npprob = 1.0, nbprob = 1.0;
  int counter = 0;
  Nucleotide previous_base = seq->get(pos - 39);
  
  // figure out the probabilities
  for (long index = pos - 38; index < pos + 3; index++) {
    
    Nucleotide base = seq->get(index);
    
    sprob *= AGsite->singles[counter]->getProb(base);
    nsprob *= AGsite->Nsingles[counter]->getProb(base);
    
    // first pairs entry isn't really a pair -- it's a single
    if (counter == 0) {
      
      double temp = 0.0, ntemp = 0.0;
      for (int j=0; j < 4; j++) {
	temp += AGsite->pairs[0]->getProb((Nucleotide)j, base);
	ntemp += AGsite->Npairs[0]->getProb((Nucleotide)j, base);
      }
      
      pprob = temp;
      npprob = ntemp;
    }
    else {
      pprob *= AGsite->pairs[counter]->getProb(previous_base, base);
      npprob *= AGsite->Npairs[counter]->getProb(previous_base, base);
    }
    
    // *** DO WE DO THE SAME THING HERE AS WE DID FOR THE FIRST PAIRS
    // ***  ENTRY???  MAYBE WE SHOULD....
    if (counter < 18) { // branchpoint region
      Nucleotide otherbase = seq->get(index-2);
      bprob *= AGsite->branchpoint[counter]->getProb(otherbase, previous_base,
						     base);
      nbprob *= AGsite->Nbranchpoint[counter]->getProb(otherbase, previous_base,
						       base);
    }
    else {
      bprob *= AGsite->branchnormal[counter-18]->getProb(previous_base, base);
      nbprob *= AGsite->Nbranchnormal[counter-18]->getProb(previous_base, base);
    }
    
    counter++;
    previous_base = base;
  }
  
  score1 = log(sprob / nsprob);
  score2 = log(pprob / npprob);
  score3 = log(bprob / nbprob);
  
  return score3;
}



/*
 * Decision-tree scores for a donor ("GT") candidate.  `pos` is the base
 * directly before the GT, so the GT occupies pos+1 and pos+2.  The model
 * input is the seven surrounding bases seq[pos-2 .. pos] and
 * seq[pos+3 .. pos+6].
 *
 *   score1 - log(P_decisionTree / 0.25^7), i.e. against a uniform background;
 *   score2 - log(P_decisionTree / P_NdecisionTree).
 *
 * Both outputs are set to LOW_SCORE when no GT model is loaded, when
 * pos < 3 or pos > length-7, or when pos+1/pos+2 are not "GT".
 *
 * @return score2
 */
double StatEvaluator::burgeEvalGT(FilterSequence *seq, int pos, 
				  double &score1, double &score2) {
  
  // Range checks come first so seq->get() is never called out of range.  The
  // window ends at seq->get(pos+6), so pos may be at most length-7 (a bound
  // of length-6 let seq->get(length) through).
  if (!GTsite ||
      pos < 3 || pos > seq->get_length()-7 ||
      seq->get(pos+1) != BASE_G || seq->get(pos+2) != BASE_T)
    {
      score1 = score2 = (double)LOW_SCORE;
      return score2;
    }

  Nucleotide tides[7];
  tides[0] = seq->get(pos-2);  
  tides[1] = seq->get(pos-1);
  tides[2] = seq->get(pos);
  tides[3] = seq->get(pos+3);
  tides[4] = seq->get(pos+4);
  tides[5] = seq->get(pos+5);
  tides[6] = seq->get(pos+6);
  
  if (VERBOSE) {
    cout << "Estimating GT site " << pos << ": ";
    for (int i=-2; i<=6; ++i) cout << b2c(seq->get(pos+i));
  }
  double probability = 1.0, badprob = 1.0;
  GTsite->decisionTree->computeProb(tides, probability);
  if (VERBOSE) cout << "\t prob: " << probability;
  GTsite->NdecisionTree->computeProb(tides, badprob);
  if (VERBOSE) cout << "\t badprob: " << badprob;

  // Uniform background probability of the seven model bases.
  double otherbadprob = .25 * .25 * .25 * .25 * .25 * .25 * .25;
  
  score1 = log(probability / otherbadprob); 
  score2 = log(probability / badprob); 
  
  if (VERBOSE) cout << "\t logprobs: " << score2 << endl;

  return score2;
}


/*
 * Decision-tree scores for an "ATG" at seq[pos .. pos+2] (presumably a
 * translation start).  The model input is seq[pos-6 .. pos-3] and seq[pos+3].
 *
 *   score1 - log(P_decisionTree / 0.25^5), i.e. against a uniform background;
 *   score2 - log(P_decisionTree / P_NdecisionTree).
 *
 * Both outputs are set to LOW_SCORE when no ATG model is loaded, when
 * pos < 7 or pos > length-4, or when pos..pos+2 is not "ATG".
 *
 * @return score2
 */
double StatEvaluator::burgeEvalATG(FilterSequence *seq, int pos, 
				   double &score1, double &score2) {
  
  // Range checks come first so seq->get() is never called out of range
  // (previously seq->get(pos) was read even for negative pos).
  if (!ATGsite ||
      pos < 7 || pos > seq->get_length()-4 ||
      seq->get(pos) != BASE_A || seq->get(pos+1) != BASE_T || 
      seq->get(pos+2) != BASE_G)
    {
      score1 = score2 = (double)LOW_SCORE;
      return score2;
    }

  Nucleotide tides[5];
  tides[0] = seq->get(pos-6);  
  tides[1] = seq->get(pos-5);  
  tides[2] = seq->get(pos-4);  
  tides[3] = seq->get(pos-3);  
  tides[4] = seq->get(pos+3);  
  
  double probability = 1.0, badprob = 1.0;
  ATGsite->decisionTree->computeProb(tides, probability);
  ATGsite->NdecisionTree->computeProb(tides, badprob);
  
  // Uniform background probability of the five model bases.
  double otherbadprob = .25 * .25 * .25 * .25 * .25;
  
  score1 = log(probability / otherbadprob); 
  score2 = log(probability / badprob); 
  
  return score2;
}
