#include <fstream.h>
#include <iostream.h>
#include <string.h>
#include <stdlib.h>
#include <math.h>
#include "fseq.h"
#include "splice.h"
#include "tree.h"


// Splice functions

// Allocate the statistics containers that belong to the requested site type.
//
//   AG_S  : 41 single and 41 pair tables, 23 "branchnormal" tables and
//           18 "branchpoint" tables.
//   GT_S  : two trees constructed with 7, plus two 1000-entry count buffers.
//   ATG_S : two trees constructed with 5, plus two 1000-entry count buffers.
//
// Every container exists twice: the plain member and an "N"-prefixed twin.
// Members that belong to a different site type are left untouched, and any
// other value of `site` allocates nothing.
Splice::Splice(SpliceSite site) {

  SiteType = site;

  if (SiteType == AG_S) {

    for (int pos = 0; pos != 41; ++pos) {
      singles[pos]  = new QProb();
      Nsingles[pos] = new QProb();
      pairs[pos]    = new QuadProb();
      Npairs[pos]   = new QuadProb();
    }

    for (int pos = 0; pos != 23; ++pos) {
      branchnormal[pos]  = new QuadProb();
      Nbranchnormal[pos] = new QuadProb();
    }

    for (int pos = 0; pos != 18; ++pos) {
      branchpoint[pos]  = new Quad4Prob();
      Nbranchpoint[pos] = new Quad4Prob();
    }

    return;
  }

  // Only the two tree-based site types remain.
  if (SiteType != GT_S && SiteType != ATG_S)
    return;

  // The two tree-based site types differ only in the Tree constructor argument.
  const int treeSize = (SiteType == GT_S) ? 7 : 5;

  decisionTree  = new Tree(treeSize);
  NdecisionTree = new Tree(treeSize);

  // Both count buffers start empty with room for 1000 nucleotides.
  Ncountnum = 0;
  countnum  = 0;
  Nmaxcount = 1000;
  maxcount  = 1000;

  counts  = new Nucleotide[maxcount];
  Ncounts = new Nucleotide[Nmaxcount];
}

// Release the containers that the constructor allocates for this object's
// current SiteType.  Members that belong to other site types are never
// allocated for this object, so they are not touched here.
Splice::~Splice() {
  
  switch (SiteType) {

  case AG_S: 
    
    // per-position single and pair tables
    for (int i=0; i < 41; i++) {
      delete singles[i];
      delete Nsingles[i];
      delete pairs[i];
      delete Npairs[i];
    }
    // pair tables used for the non-branchpoint positions
    for (int i=0; i < 23; i++) {
      delete branchnormal[i];
      delete Nbranchnormal[i];
    }
    // triple tables used for the branchpoint positions
    for (int i=0; i < 18; i++) {
      delete branchpoint[i];
      delete Nbranchpoint[i];
    }
    return;
    
  case GT_S:
  case ATG_S:
    delete decisionTree;
    delete NdecisionTree;
    // counts and Ncounts are created with new Nucleotide[...], so they have
    // to be released with the array form of delete.
    delete [] counts;
    delete [] Ncounts;
    return;
  }
}


// Training entry point: hand the sequence to the finder that matches this
// object's site type.  Any other site type is ignored.
void Splice::findSplices(GeneSequence *seq) {

  if (SiteType == AG_S)
    findAcceptorSplices(seq);
  else if (SiteType == GT_S)
    findDonorSplices(seq);
  else if (SiteType == ATG_S)
    findATGSplices(seq);
}


// Collect acceptor-site statistics from one sequence.
//
// Every "AG" found at position i contributes a 41-base window covering
// positions i-36 .. i+4.  When whatRegion(i) is REGION_INTRON and
// whatRegion(i+2) is REGION_CEXON the window is counted into the plain
// tables; every other AG is counted into the "N"-prefixed tables.
//
// For each window slot:
//   - the base goes into the single table of that slot,
//   - (previous base, base) goes into the pair table of that slot,
//   - slots 0..17 add the five overlapping triples that lie between
//     pos-4 and pos+2 to the branchpoint table of that slot,
//   - slots 18..40 add (previous base, base) to branchnormal[slot-18].
void Splice::findAcceptorSplices(GeneSequence *seq) {

  // The last four positions cannot start a complete window.
  const long stop = seq->get_length() - 4;

  // Positions before 41 are not scanned.
  for (long i = 41; i < stop; i++) {

    if (seq->get(i) != BASE_A)
      continue;
    if (seq->get(i+1) != BASE_G)
      continue;

    // Decide once per site which family of tables receives the counts.
    const bool genuine = seq->whatRegion(i) == REGION_INTRON &&
                         seq->whatRegion(i+2) == REGION_CEXON;

    // Base immediately in front of the window; pairs with slot 0.
    Nucleotide prior = seq->get(i-37);

    for (int slot = 0; slot < 41; slot++) {

      const long pos = i - 36 + slot;
      Nucleotide here = seq->get(pos);

      if (genuine) {

        singles[slot]->addToCount(here);
        pairs[slot]->addToCount(prior, here);

        if (slot < 18) {
          // branchpoint region: five triples sliding across this position
          branchpoint[slot]->addToCount(seq->get(pos-4),
                                        seq->get(pos-3),
                                        seq->get(pos-2));
          branchpoint[slot]->addToCount(seq->get(pos-3),
                                        seq->get(pos-2),
                                        prior);
          branchpoint[slot]->addToCount(seq->get(pos-2),
                                        prior, here);
          branchpoint[slot]->addToCount(prior, here,
                                        seq->get(pos+1));
          branchpoint[slot]->addToCount(here, seq->get(pos+1),
                                        seq->get(pos+2));
        }
        else {
          branchnormal[slot-18]->addToCount(prior, here);
        }
      }
      else {

        Nsingles[slot]->addToCount(here);
        Npairs[slot]->addToCount(prior, here);

        if (slot < 18) {
          // branchpoint region: five triples sliding across this position
          Nbranchpoint[slot]->addToCount(seq->get(pos-4),
                                         seq->get(pos-3),
                                         seq->get(pos-2));
          Nbranchpoint[slot]->addToCount(seq->get(pos-3),
                                         seq->get(pos-2),
                                         prior);
          Nbranchpoint[slot]->addToCount(seq->get(pos-2),
                                         prior, here);
          Nbranchpoint[slot]->addToCount(prior, here,
                                         seq->get(pos+1));
          Nbranchpoint[slot]->addToCount(here, seq->get(pos+1),
                                         seq->get(pos+2));
        }
        else {
          Nbranchnormal[slot-18]->addToCount(prior, here);
        }
      }

      prior = here;
    }
  }
}


// Collect donor-site training data from one sequence.
//
// Every "GT" found at position i contributes seven context bases: the three
// in front of the G (i-3 .. i-1) and the four behind the T (i+2 .. i+5).
// When whatRegion(i) is REGION_INTRON and whatRegion(i-1) is REGION_CEXON
// the seven bases are appended to `counts`; every other GT is appended to
// `Ncounts`.  Both buffers are doubled in size whenever fewer than seven
// free entries remain.
void Splice::findDonorSplices(GeneSequence *seq) {
  
  // search through the sequence for GT_S pairs
  long length = seq->get_length() - 5;  // i+5 must still be a valid position
  
  // don't need to start looking until 4th base
  for (long i=4; i < length; i++) {
    
    // only a G followed by a T is a candidate
    if (seq->get(i) != BASE_G || seq->get(i+1) != BASE_T)
      continue;

    // the seven bases surrounding the GT
    Nucleotide tides[7];
    tides[0] = seq->get(i-3);  
    tides[1] = seq->get(i-2);  
    tides[2] = seq->get(i-1);  
    tides[3] = seq->get(i+2);  
    tides[4] = seq->get(i+3);  
    tides[5] = seq->get(i+4);  
    tides[6] = seq->get(i+5);  

    // if it's a true splice site:
    if (seq->whatRegion(i) == REGION_INTRON && 
        seq->whatRegion(i-1) == REGION_CEXON) {

      // make sure enough memory is allocated
      if (countnum > maxcount-7) {

        // create the larger array
        long newmax = 2*maxcount;
        Nucleotide *newcount = new Nucleotide[newmax];

        // copy the old stuff into the new array
        for (long k=0; k < countnum; k++)
          newcount[k] = counts[k];

        // counts was created with new[], so release it with delete []
        delete [] counts;
        counts = newcount;
        maxcount = newmax;
      }

      // add to the array the 7 nucleotides
      for (int a=0; a < 7; a++) 
        counts[countnum++] = tides[a];
    }

    // else it's a pseudo one...
    else { 

      // make sure enough memory is allocated
      if (Ncountnum > Nmaxcount-7) {

        // create the larger array
        long newmax = 2*Nmaxcount;
        Nucleotide *newcount = new Nucleotide[newmax];

        // copy the old stuff into the new array
        for (long k=0; k < Ncountnum; k++)
          newcount[k] = Ncounts[k];

        // Ncounts was created with new[], so release it with delete []
        delete [] Ncounts;
        Ncounts = newcount;
        Nmaxcount = newmax;
      }

      // add to the array the 7 nucleotides
      for (int a=0; a < 7; a++) 
        Ncounts[Ncountnum++] = tides[a];
    }
  }
}


// Collect ATG-site training data from one sequence.
//
// Every "ATG" found at position i contributes five context bases: the four
// at i-6 .. i-3 and the one at i+3.  The site is treated as true when
// whatRegion(i) is REGION_CEXON, whatRegion(i-1) is not REGION_CEXON and
// nuc2reg(i) equals firstCExon(); its bases are then appended to `counts`.
// Every other ATG is appended to `Ncounts`.  Both buffers are doubled in
// size whenever fewer than five free entries remain.
void Splice::findATGSplices(GeneSequence *seq) {
  
  // search through the sequence for ATG_S triples
  long length = seq->get_length() - 5;  // don't need to look at last 5
  
  // don't need to start looking until 7th base
  for (long i=7; i < length; i++) {
    
    // only an A followed by T and G is a candidate
    if (seq->get(i) != BASE_A || 
        seq->get(i+1) != BASE_T || 
        seq->get(i+2) != BASE_G)
      continue;

    // the five context bases; i-2, i-1, i+4 and i+5 are not used
    Nucleotide tides[5];
    tides[0] = seq->get(i-6);  
    tides[1] = seq->get(i-5);  
    tides[2] = seq->get(i-4);  
    tides[3] = seq->get(i-3);  
    tides[4] = seq->get(i+3);  

    // if it's a true site:
    if (seq->whatRegion(i) == REGION_CEXON && 
        seq->whatRegion(i-1) != REGION_CEXON &&
        seq->nuc2reg(i) == seq->firstCExon()) {

      // make sure enough memory is allocated
      if (countnum > maxcount-5) {

        // create the larger array
        long newmax = 2*maxcount;
        Nucleotide *newcount = new Nucleotide[newmax];

        // copy the old stuff into the new array
        for (long k=0; k < countnum; k++)
          newcount[k] = counts[k];

        // counts was created with new[], so release it with delete []
        delete [] counts;
        counts = newcount;
        maxcount = newmax;
      }

      // add to the array the 5 nucleotides
      for (int a=0; a < 5; a++) 
        counts[countnum++] = tides[a];
    }

    // else it's a pseudo one...
    else { 

      // make sure enough memory is allocated
      if (Ncountnum > Nmaxcount-5) {

        // create the larger array
        long newmax = 2*Nmaxcount;
        Nucleotide *newcount = new Nucleotide[newmax];

        // copy the old stuff into the new array
        for (long k=0; k < Ncountnum; k++)
          newcount[k] = Ncounts[k];

        // Ncounts was created with new[], so release it with delete []
        delete [] Ncounts;
        Ncounts = newcount;
        Nmaxcount = newmax;
      }

      // add to the array the 5 nucleotides
      for (int a=0; a < 5; a++) 
        Ncounts[Ncountnum++] = tides[a];
    }
  }
}


// Read a Splice from a stream.
//
// The first value is the numeric site type; it is stored into
// splice.SiteType.  For AG_S the tables follow in this order: singles,
// Nsingles, pairs, Npairs (41 each), branchnormal, Nbranchnormal (23 each),
// branchpoint, Nbranchpoint (18 each).  For GT_S and ATG_S the two decision
// trees follow.
//
// NOTE(review): SiteType is overwritten with the value from the stream but
// no containers are allocated here, so this appears to rely on `splice`
// having been constructed with the same site type -- confirm with callers.
istream& operator>>(istream& cin, Splice& splice) {

  int siteCode;
  cin >> siteCode;
  splice.SiteType = (SpliceSite)siteCode;

  if (splice.SiteType == AG_S) {

    int k;

    for (k = 0; k < 41; ++k)
      cin >> *splice.singles[k];
    for (k = 0; k < 41; ++k)
      cin >> *splice.Nsingles[k];
    for (k = 0; k < 41; ++k)
      cin >> *splice.pairs[k];
    for (k = 0; k < 41; ++k)
      cin >> *splice.Npairs[k];

    for (k = 0; k < 23; ++k)
      cin >> *splice.branchnormal[k];
    for (k = 0; k < 23; ++k)
      cin >> *splice.Nbranchnormal[k];

    for (k = 0; k < 18; ++k)
      cin >> *splice.branchpoint[k];
    for (k = 0; k < 18; ++k)
      cin >> *splice.Nbranchpoint[k];
  }
  else if (splice.SiteType == GT_S || splice.SiteType == ATG_S) {

    cin >> *splice.decisionTree;
    cin >> *splice.NdecisionTree;
  }

  return cin;
}

// Write a Splice to a stream.
//
// The numeric site type goes out first on its own line.  For AG_S every
// table follows on its own line, in this order: singles, Nsingles, pairs,
// Npairs (41 each), branchnormal, Nbranchnormal (23 each), branchpoint,
// Nbranchpoint (18 each).  For GT_S and ATG_S the two decision trees
// follow, separated by a newline.
ostream& operator<<(ostream& cout, Splice& splice) {

  cout << (int)splice.SiteType << endl;

  if (splice.SiteType == AG_S) {

    int k;

    for (k = 0; k < 41; ++k)
      cout << *splice.singles[k] << endl;
    for (k = 0; k < 41; ++k)
      cout << *splice.Nsingles[k] << endl;
    for (k = 0; k < 41; ++k)
      cout << *splice.pairs[k] << endl;
    for (k = 0; k < 41; ++k)
      cout << *splice.Npairs[k] << endl;

    for (k = 0; k < 23; ++k)
      cout << *splice.branchnormal[k] << endl;
    for (k = 0; k < 23; ++k)
      cout << *splice.Nbranchnormal[k] << endl;

    for (k = 0; k < 18; ++k)
      cout << *splice.branchpoint[k] << endl;
    for (k = 0; k < 18; ++k)
      cout << *splice.Nbranchpoint[k] << endl;
  }
  else if (splice.SiteType == GT_S || splice.SiteType == ATG_S) {

    cout << *splice.decisionTree << "\n";
    cout << *splice.NdecisionTree << endl;
  }

  return cout;
}


// Turn the collected training data into the model used by the test
// functions.
//
//   AG_S         : calls computeProb() on every table.
//   GT_S / ATG_S : builds decisionTree from `counts`, then has
//                  NdecisionTree copy that tree using `Ncounts`.
void Splice::computations() {

  if (SiteType == AG_S) {

    for (int pos = 0; pos != 41; ++pos) {
      singles[pos]->computeProb();
      Nsingles[pos]->computeProb();
      pairs[pos]->computeProb();
      Npairs[pos]->computeProb();
    }

    for (int pos = 0; pos != 23; ++pos) {
      branchnormal[pos]->computeProb();
      Nbranchnormal[pos]->computeProb();
    }

    for (int pos = 0; pos != 18; ++pos) {
      branchpoint[pos]->computeProb();
      Nbranchpoint[pos]->computeProb();
    }

    return;
  }

  if (SiteType == GT_S || SiteType == ATG_S) {
    decisionTree->createTree(counts, countnum);
    NdecisionTree->copyTree(decisionTree, Ncounts, Ncountnum);
  }
}


// Scoring entry point: hand the sequence and the output stream to the test
// routine that matches this object's site type.  Any other site type is
// ignored.
void Splice::testSplices(GeneSequence *seq, ostream& outfile) {

  if (SiteType == AG_S)
    testAcceptorSplices(seq, outfile);
  else if (SiteType == GT_S)
    testDonorSplices(seq, outfile);
  else if (SiteType == ATG_S)
    testATGSplices(seq, outfile);
}


// Score every candidate acceptor site ("AG") of one sequence.
//
// For each AG at position i the 41-base window i-36 .. i+4 is scored three
// ways, each as log(product over plain tables / product over N tables):
//   Sprob - single-base tables,
//   Pprob - pair tables (slot 0 is marginalised over the previous base),
//   Bprob - branchpoint triple tables for slots 0..17, branchnormal pair
//           tables for slots 18..40.
//
// A site is skipped when all three scores are below -5.0.  Otherwise one
// tab-separated row is written to `outfile`:
//   trueSite, Bprob, (int) base at i-1, (int) base at i+2, lookedAt, i
// where trueSite is set when whatRegion(i) is REGION_INTRON and
// whatRegion(i+2) is REGION_CEXON, and lookedAt is the number of earlier
// calls to this function.
void Splice::testAcceptorSplices(GeneSequence *seq, ostream& outfile) {
  
  // counts calls to this function; shared by all Splice objects
  static int lookedAt = 0;
  
  // search through the sequence for AG_S pairs
  long length = seq->get_length() - 4;  // don't need to look at last 4
  
  // don't need to start looking until 41st base
  for (long i=41; i < length; i++) {
    
    // only an A followed by a G is a candidate
    if (seq->get(i) != BASE_A || seq->get(i+1) != BASE_G)
      continue;

    double sprob = 1.0, pprob = 1.0, bprob= 1.0;
    double nsprob = 1.0, npprob = 1.0, nbprob= 1.0;
    int counter = 0;
    Nucleotide previous_base = seq->get(i-37);

    for (long index = i - 36; index < i + 5; index++) {

      Nucleotide base = seq->get(index);

      sprob *= singles[counter]->getProb(base);
      nsprob *= Nsingles[counter]->getProb(base);

      // first pairs entry isn't really a pair -- it's a single, so sum
      // over the previous base.
      // NOTE(review): assumes the values 0..3 cover every Nucleotide that
      // can occur here -- confirm against the Nucleotide definition.
      if (counter == 0) {

        double temp = 0.0, ntemp = 0.0;
        for (int j=0; j < 4; j++) {
          temp += pairs[0]->getProb((Nucleotide)j, base);
          ntemp += Npairs[0]->getProb((Nucleotide)j, base);
        }

        pprob = temp;
        npprob = ntemp;
      }
      else {
        pprob *= pairs[counter]->getProb(previous_base, base);
        npprob *= Npairs[counter]->getProb(previous_base, base);
      }

      // Open question kept from the original author: should the first
      // branchpoint entry be marginalised like the first pairs entry?
      if (counter < 18) { // branchpoint region
        Nucleotide otherbase = seq->get(index-2);
        bprob *= branchpoint[counter]->getProb(otherbase, previous_base,
                                               base);
        nbprob *= Nbranchpoint[counter]->getProb(otherbase, previous_base,
                                                 base);
      }
      else {
        bprob *= branchnormal[counter-18]->getProb(previous_base, base);
        nbprob *= Nbranchnormal[counter-18]->getProb(previous_base, base);
      }

      counter++;
      previous_base = base;
    }

    // log-odds of the plain model against the N model
    double Sprob = log(sprob / nsprob);
    double Pprob = log(pprob / npprob);
    double Bprob = log(bprob / nbprob);

    // The original copied the locus name into a freshly allocated 500-byte
    // buffer here that was never read and never freed; that leak (and the
    // unbounded strcpy) has been removed.

    bool trueSite = false;
    if (seq->whatRegion(i) == REGION_INTRON && 
        seq->whatRegion(i+2) == REGION_CEXON) 
      trueSite = true;

    // report the site unless all three scores fall below the cutoff
    if (!(Sprob < -5.0 && Pprob < -5.0 && Bprob < -5.0)) {
      outfile << trueSite << "\t" << Bprob << "\t" << (int)seq->get(i-1) 
              << "\t" << (int)seq->get(i+2) << "\t" << lookedAt << "\t" 
              << i << endl; 
    }
  }
  
  lookedAt++;
}


// Score every candidate donor site ("GT") of one sequence.
//
// For each GT at position i the seven context bases (i-3 .. i-1 and
// i+2 .. i+5) are passed to both decision trees, and one tab-separated row
// is written to `outfile`:
//   trueSite, log(decisionTree result / NdecisionTree result), lookedAt, i
// trueSite is set when whatRegion(i) is REGION_INTRON and whatRegion(i-1)
// is REGION_CEXON; lookedAt is the number of earlier calls to this function.
void Splice::testDonorSplices(GeneSequence *seq, ostream& outfile) {

  // counts calls to this function; shared by all Splice objects
  static int lookedAt = 0;

  // where the seven context bases sit relative to the G
  static const int offsets[7] = { -3, -2, -1, 2, 3, 4, 5 };

  // i+5 has to stay inside the sequence
  const long stop = seq->get_length() - 5;

  // positions before 4 are not scanned
  for (long i = 4; i < stop; i++) {

    if (seq->get(i) != BASE_G)
      continue;
    if (seq->get(i+1) != BASE_T)
      continue;

    // is this an annotated exon/intron boundary?
    const bool trueSite = seq->whatRegion(i) == REGION_INTRON &&
                          seq->whatRegion(i-1) == REGION_CEXON;

    Nucleotide tides[7];
    for (int k = 0; k < 7; k++)
      tides[k] = seq->get(i + offsets[k]);

    // both start at 1.0 before being handed to the trees
    double probability = 1.0;
    double badprob = 1.0;
    decisionTree->computeProb(tides, probability);
    NdecisionTree->computeProb(tides, badprob);

    const double answer = log(probability / badprob);

    outfile << trueSite << "\t" << answer << "\t" << lookedAt
            << "\t" << i << endl;
  }

  lookedAt++;
}


// Score every candidate "ATG" of one sequence.
//
// For each ATG at position i the five context bases (i-6 .. i-3 and i+3)
// are passed to both decision trees, and one tab-separated row is written
// to `outfile`:
//   trueSite, log(decisionTree result / NdecisionTree result), lookedAt, i
// trueSite is set when whatRegion(i) is REGION_CEXON, whatRegion(i-1) is
// not REGION_CEXON and nuc2reg(i) equals firstCExon(); lookedAt is the
// number of earlier calls to this function.
void Splice::testATGSplices(GeneSequence *seq, ostream& outfile) {

  // counts calls to this function; shared by all Splice objects
  static int lookedAt = 0;

  // where the five context bases sit relative to the A;
  // i-2, i-1, i+4 and i+5 are not part of the context
  static const int offsets[5] = { -6, -5, -4, -3, 3 };

  // the last five positions are not scanned
  const long stop = seq->get_length() - 5;

  // start at 7 so that i-6 is always a valid position
  for (long i = 7; i < stop; i++) {

    if (seq->get(i) != BASE_A)
      continue;
    if (seq->get(i+1) != BASE_T)
      continue;
    if (seq->get(i+2) != BASE_G)
      continue;

    // is this the ATG that opens the first coding exon?
    const bool trueSite = seq->whatRegion(i) == REGION_CEXON &&
                          seq->whatRegion(i-1) != REGION_CEXON &&
                          seq->nuc2reg(i) == seq->firstCExon();

    Nucleotide tides[5];
    for (int k = 0; k < 5; k++)
      tides[k] = seq->get(i + offsets[k]);

    // both start at 1.0 before being handed to the trees
    double probability = 1.0;
    double badprob = 1.0;
    decisionTree->computeProb(tides, probability);
    NdecisionTree->computeProb(tides, badprob);

    const double answer = log(probability / badprob);

    outfile << trueSite << "\t" << answer << "\t" << lookedAt
            << "\t" << i << endl;
  }

  lookedAt++;
}

