// C++ code file

#include "pairtable.h"

#include <math.h>
#include <assert.h>
#include <iostream.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <netinet/in.h>


// Square of an arithmetic expression.  The argument is evaluated twice, so
// do not pass expressions with side effects.
#define sqr(x)	((x)*(x))

// Natural logarithm that maps non-positive inputs (e.g. zero counts) to -INF
// instead of a domain error.  The argument is fully parenthesized so that
// compound expressions such as `my_log(a ? b : c)` parse as one operand; it
// is evaluated up to twice, so avoid side effects.
#define my_log(x) ((x) > 0.0 ? log(x) : (-INF))

// Positive infinity; my_log() yields -INF for log(0).
const double INF = HUGE_VAL;


// Deserialize a table written by store().  The stream holds, in order: the
// bounds of the two position ranges [from1, to1) and [from2, to2); the
// per-position single totals and 4 nucleotide counts for range 1, then for
// range 2; and the per-pair totals and 4x4 counts.  The log-probability
// tables are then recomputed from the counts.
PairTable::PairTable(istream& istr) {

  int i;
  from1 = read_long(istr);
  to1 = read_long(istr);
  len1 = to1 - from1;

  from2 = read_long(istr);
  to2 = read_long(istr);
  len2 = to2 - from2;

  // Do not size the tables from a truncated or corrupt header.
  assert(!istr.fail());
  assert(len1 >= 0 && len2 >= 0);

  // The stream format does not record the sample counter.  Give it a defined
  // value instead of leaving it uninitialized (remove_data() decrements it).
  total = 0;

  single_counts1 = create_table(len1 * 4);
  single_totals1 = create_table(len1);
  single_probs_log1 = new double[len1 * 4];

  single_counts2 = create_table(len2 * 4);
  single_totals2 = create_table(len2);
  single_probs_log2 = new double[len2 * 4];

  double_counts = create_table(len1 * len2 * 4 * 4);
  double_totals = create_table(len1 * len2);
  double_probs_log = new double[len1 * len2 * 4 * 4];

  for (i = from1; i < to1; i++) {
    single_total1(i) = read_long(istr);
    for (int k = 0; k < 4; k++)
      single_count1(i, Nucleotide(k)) = read_long(istr);
  }

  for (i = from2; i < to2; i++) {
    single_total2(i) = read_long(istr);
    for (int k = 0; k < 4; k++)
      single_count2(i, Nucleotide(k)) = read_long(istr);
  }

  for (i = from1; i < to1; i++)
    for (int j = from2; j < to2; j++) {
      double_total(i, j) = read_long(istr);
      for (int k = 0; k < 4; k++)
        for (int l = 0; l < 4; l++)
          double_count(i, j, Nucleotide(k), Nucleotide(l)) = read_long(istr);
    }

  // Every count must have been read in full before probabilities are derived.
  assert(!istr.fail());

  calculate_probabilities();
}

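// Note: use the constructor below only for the Minus One option, which removes
// a sequence's contribution from the loaded counts before probabilities are computed.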

// Same stream format as PairTable(istream&).  Before the log-probabilities
// are derived, the contribution of `seq` is subtracted from the loaded counts
// via minusOne(seq, start_or_stop).
PairTable::PairTable(istream& istr, FilterSequence *seq, char* start_or_stop) {

  int i;
  from1 = read_long(istr);
  to1 = read_long(istr);
  len1 = to1 - from1;

  from2 = read_long(istr);
  to2 = read_long(istr);
  len2 = to2 - from2;

  // Do not size the tables from a truncated or corrupt header.
  assert(!istr.fail());
  assert(len1 >= 0 && len2 >= 0);

  // Not stored in the stream, so start from a defined value.  minusOne()
  // calls remove_data(), which decrements it once per successful removal.
  total = 0;

  single_counts1 = create_table(len1 * 4);
  single_totals1 = create_table(len1);
  single_probs_log1 = new double[len1 * 4];

  single_counts2 = create_table(len2 * 4);
  single_totals2 = create_table(len2);
  single_probs_log2 = new double[len2 * 4];

  double_counts = create_table(len1 * len2 * 4 * 4);
  double_totals = create_table(len1 * len2);
  double_probs_log = new double[len1 * len2 * 4 * 4];

  for (i = from1; i < to1; i++) {
    single_total1(i) = read_long(istr);
    for (int k = 0; k < 4; k++)
      single_count1(i, Nucleotide(k)) = read_long(istr);
  }

  for (i = from2; i < to2; i++) {
    single_total2(i) = read_long(istr);
    for (int k = 0; k < 4; k++)
      single_count2(i, Nucleotide(k)) = read_long(istr);
  }

  for (i = from1; i < to1; i++)
    for (int j = from2; j < to2; j++) {
      double_total(i, j) = read_long(istr);
      for (int k = 0; k < 4; k++)
        for (int l = 0; l < 4; l++)
          double_count(i, j, Nucleotide(k), Nucleotide(l)) = read_long(istr);
    }

  // Every count must have been read in full before it is modified.
  assert(!istr.fail());

  /* Minus One Option: Remove a sequence from the pairtable. */
  this->minusOne(seq, start_or_stop);

  calculate_probabilities();
}

// Subtract the contribution of `seq` from the current counts.
//
// start_or_stop selects which windows are removed (each via remove_data,
// which silently skips windows that do not fit inside the sequence):
//   "rand"  - positions 0, 500, 1000, ... up to seq->get_length()
//   "start" - reg->start of every REGION_CEXON region
//   "stop"  - reg->start - 1 of every REGION_INTRON region
//   "none"  - nothing
// For the region-based modes, a region is skipped when its start or stop is
// inexact, or when it is an interior region (neither first nor last) whose
// previous or next neighbour is a REGION_NCEXON.
void PairTable::minusOne(FilterSequence* seq, char* start_or_stop) {
  const bool mode_rand = (strcmp(start_or_stop, "rand") == 0);
  if (mode_rand) {
    for (int pos = 0; pos <= seq->get_length(); pos += 500)
      this->remove_data(seq, pos, pos);
    return;
  }

  const bool mode_none  = (strcmp(start_or_stop, "none") == 0);
  const bool mode_start = (strcmp(start_or_stop, "start") == 0);
  const bool mode_stop  = (strcmp(start_or_stop, "stop") == 0);

  for (int r = 1; r <= seq->get_region_num(); r++) {
    Region* cur = seq->get_region(r);

    const bool interior = r > 1 && r < seq->get_region_num();
    if (interior) {
      Region* before = seq->get_region(r - 1);
      Region* after  = seq->get_region(r + 1);
      if (before->type == REGION_NCEXON || after->type == REGION_NCEXON)
        continue;
    }

    if (cur->is_start_inexact() || cur->is_stop_inexact())
      continue;
    if (mode_none)
      continue;

    if (cur->type == REGION_CEXON && mode_start) {
      this->remove_data(seq, cur->start, cur->start);
      continue;
    }
    if (cur->type == REGION_INTRON && mode_stop)
      this->remove_data(seq, (cur->start) - 1, (cur->start) - 1);
  }
}




// Build an empty table for the position ranges [afrom1, ato1) and
// [afrom2, ato2).  All counts and totals start at zero.  The log-probability
// arrays are allocated but not initialized; calculate_probabilities() writes
// them.
PairTable::PairTable(int afrom1, int ato1, int afrom2, int ato2) {
  // Range bounds first, then the derived lengths.
  from1 = afrom1;
  from2 = afrom2;
  to1   = ato1;
  to2   = ato2;
  len1  = to1 - from1;
  len2  = to2 - from2;

  total = 0;

  // Range 1: 4 counts per position, one total per position.
  single_counts1    = create_table(len1 * 4);
  single_totals1    = create_table(len1);
  single_probs_log1 = new double[len1 * 4];

  // Range 2: same layout.
  single_counts2    = create_table(len2 * 4);
  single_totals2    = create_table(len2);
  single_probs_log2 = new double[len2 * 4];

  // Pairs: a 4x4 nucleotide grid and one total for each (pos1, pos2).
  double_counts    = create_table(len1 * len2 * 4 * 4);
  double_totals    = create_table(len1 * len2);
  double_probs_log = new double[len1 * len2 * 4 * 4];
}

// Release every table owned by this object.  All of them are arrays
// (create_table() uses new long[], the probability tables use new double[]),
// so each must be freed with delete[]; plain delete on them is undefined
// behaviour.
// NOTE(review): the class owns raw arrays, but this file defines no copy
// constructor or assignment operator.  Copying a PairTable would double-free
// here; confirm copies are never made, or disable them in pairtable.h.
PairTable::~PairTable() {
  delete[] double_counts;
  delete[] double_totals;
  delete[] double_probs_log;

  delete[] single_counts1;
  delete[] single_totals1;
  delete[] single_probs_log1;

  delete[] single_counts2;
  delete[] single_totals2;
  delete[] single_probs_log2;
}

// Record one sample: sequence positions pos1+from1 .. pos1+to1-1 (range 1)
// against pos2+from2 .. pos2+to2-1 (range 2).  Bases that are not below
// BASE_UNKNOWN are not counted.  Returns FALSE and records nothing if either
// window leaves positions 1..seq->get_length(); otherwise returns TRUE.
bool PairTable::add_data(Sequence *seq, long pos1, long pos2) {
  if (!(pos1 + from1 > 0 && pos1 + to1 - 1 <= seq->get_length() &&
        pos2 + from2 > 0 && pos2 + to2 - 1 <= seq->get_length()))
    return FALSE;

  total++;

  // Single-position counts, range 1.
  for (int a = from1; a < to1; a++) {
    if (!(seq->get(a + pos1) < BASE_UNKNOWN))
      continue;
    single_total1(a)++;
    single_count1(a, seq->get(a + pos1))++;
  }

  // Single-position counts, range 2.
  for (int b = from2; b < to2; b++) {
    if (!(seq->get(b + pos2) < BASE_UNKNOWN))
      continue;
    single_total2(b)++;
    single_count2(b, seq->get(b + pos2))++;
  }

  // Pair counts: only when both positions hold a known base.
  for (int a = from1; a < to1; a++) {
    if (!(seq->get(a + pos1) < BASE_UNKNOWN))
      continue;
    for (int b = from2; b < to2; b++) {
      if (!(seq->get(b + pos2) < BASE_UNKNOWN))
        continue;
      double_total(a, b)++;
      double_count(a, b, seq->get(a + pos1), seq->get(b + pos2))++;
    }
  }
  return TRUE;
}

// Windowed variant of add_data().  A sequence position p of range 1 is only
// counted when prevpos1 < p < nextpos1, and likewise for range 2 with
// prevpos2/nextpos2.  For pair counts with pos1 != pos2, a pair is counted
// only when the range-1 position lies strictly before the range-2 position,
// so overlapping windows (e.g. on short exons) do not pair a base with
// itself.  Returns FALSE and records nothing if either full window leaves
// positions 1..seq->get_length().
bool PairTable::add_data(Sequence *seq, long pos1, long prevpos1, 
			 long nextpos1, long pos2, long prevpos2, 
			 long nextpos2) {
  // The whole window must fit even if only part of it lies between
  // prevpos and nextpos.  This could be more lenient; some information is
  // thrown away here.
  if (!(pos1 + from1 > 0 && pos1 + to1 - 1 <= seq->get_length() &&
        pos2 + from2 > 0 && pos2 + to2 - 1 <= seq->get_length()))
    return FALSE;

  total++;

  for (int a = from1; a < to1; a++) {
    const long p = a + pos1;
    if (p <= prevpos1 || p >= nextpos1)
      continue;
    if (!(seq->get(p) < BASE_UNKNOWN))
      continue;
    single_total1(a)++;
    single_count1(a, seq->get(p))++;
  }

  for (int b = from2; b < to2; b++) {
    const long q = b + pos2;
    if (q <= prevpos2 || q >= nextpos2)
      continue;
    if (!(seq->get(q) < BASE_UNKNOWN))
      continue;
    single_total2(b)++;
    single_count2(b, seq->get(q))++;
  }

  for (int a = from1; a < to1; a++) {
    const long p = a + pos1;
    if (!(seq->get(p) < BASE_UNKNOWN))
      continue;
    for (int b = from2; b < to2; b++) {
      const long q = b + pos2;
      const bool inside = p > prevpos1 && p < nextpos1 &&
                          q > prevpos2 && q < nextpos2;
      if (!inside)
        continue;
      if (!(seq->get(q) < BASE_UNKNOWN))
        continue;
      // Self comparison (pos1 == pos2) counts every pair; otherwise only
      // pairs whose first position precedes the second.
      if (pos1 != pos2 && !(p < q))
        continue;
      double_total(a, b)++;
      double_count(a, b, seq->get(p), seq->get(q))++;
    }
  }
  return TRUE;
}


// Inverse of the 3-argument add_data(): decrement the same counters for the
// sample at (pos1, pos2).  Returns FALSE and changes nothing if either window
// leaves positions 1..seq->get_length().  There is no check that the sample
// was previously added, so counts can go negative.
bool PairTable::remove_data(Sequence *seq, long pos1, long pos2) {
  if (!(pos1 + from1 > 0 && pos1 + to1 - 1 <= seq->get_length() &&
        pos2 + from2 > 0 && pos2 + to2 - 1 <= seq->get_length()))
    return FALSE;

  total--;

  for (int a = from1; a < to1; a++) {
    if (!(seq->get(a + pos1) < BASE_UNKNOWN))
      continue;
    single_total1(a)--;
    single_count1(a, seq->get(a + pos1))--;
  }

  for (int b = from2; b < to2; b++) {
    if (!(seq->get(b + pos2) < BASE_UNKNOWN))
      continue;
    single_total2(b)--;
    single_count2(b, seq->get(b + pos2))--;
  }

  for (int a = from1; a < to1; a++) {
    if (!(seq->get(a + pos1) < BASE_UNKNOWN))
      continue;
    for (int b = from2; b < to2; b++) {
      if (!(seq->get(b + pos2) < BASE_UNKNOWN))
        continue;
      double_total(a, b)--;
      double_count(a, b, seq->get(a + pos1), seq->get(b + pos2))--;
    }
  }
  return TRUE;
}


// Allocate a zero-filled array of `tlen` longs.  The caller owns the result
// and must release it with delete[].
long *PairTable::create_table(long tlen) {
  // A negative length means the range bounds were inconsistent (to < from).
  assert(tlen >= 0);
  long *table = new long[tlen];
  // Use a long counter to match tlen; an int counter could overflow on very
  // large tables.
  for (long i = 0; i < tlen; i++)
    table[i] = 0;
  return table;
}

// Mutable count of nucleotide n at position idx of range 1.
// Valid indices form the half-open range [from1, to1), so callers can loop
// with `i < to1`.  Storage holds 4 consecutive counts per position.
long& PairTable::single_count1(long idx, Nucleotide n) {
  assert(idx >= from1 && idx < to1);
  const long slot = (idx - from1) * 4 + n;
  return single_counts1[slot];
}

// Mutable per-position total for range 1; idx must lie in [from1, to1).
long& PairTable::single_total1(long idx) {
  assert(idx >= from1 && idx < to1);
  const long offset = idx - from1;
  return single_totals1[offset];
}

// Mutable count of nucleotide n at position idx of range 2; idx must lie in
// [from2, to2).  Storage holds 4 consecutive counts per position.
long& PairTable::single_count2(long idx, Nucleotide n) {
  assert(idx >= from2 && idx < to2);
  const long slot = (idx - from2) * 4 + n;
  return single_counts2[slot];
}

// Mutable per-position total for range 2; idx must lie in [from2, to2).
long& PairTable::single_total2(long idx) {
  assert(idx >= from2 && idx < to2);
  const long offset = idx - from2;
  return single_totals2[offset];
}


// Mutable count of the nucleotide pair (n1 at idx1, n2 at idx2).
// idx1 must lie in [from1, to1) and idx2 in [from2, to2).  Cells are laid
// out row-major by (idx1, idx2), each holding a 4x4 grid indexed [n1][n2].
long& PairTable::double_count(long idx1, long idx2,
				  Nucleotide n1, Nucleotide n2) {
  assert(idx1 >= from1 && idx1 < to1);
  assert(idx2 >= from2 && idx2 < to2);

  const long cell = (idx1 - from1) * len2 + (idx2 - from2);
  return double_counts[(cell * 4 + n1) * 4 + n2];
}

// Mutable total for the position pair (idx1, idx2), stored row-major.
// idx1 must lie in [from1, to1) and idx2 in [from2, to2).
long& PairTable::double_total(long idx1, long idx2) {
  assert(idx1 >= from1 && idx1 < to1);
  assert(idx2 >= from2 && idx2 < to2);

  const long row = idx1 - from1;
  const long col = idx2 - from2;
  return double_totals[row * len2 + col];
}



// Convert the raw counts into natural-log probabilities:
//   double_prob_log(i, j, k, l) = log(double_count(i, j, k, l) / double_total(i, j))
//   single_prob_logN(i, k)      = log(single_countN(i, k) / single_totalN(i))
// Zero counts become -INF (via my_log).  A position pair with a zero total is
// fatal: a diagnostic is printed to cout and the program aborts.
void PairTable::calculate_probabilities() {
  for (int i = from1; i < to1; i++)
    for (int j = from2; j < to2; j++) {

      if (double_total(i, j) == 0) {
        cout << "problem at pos " << i << ", " << j << " - bailing out\n";
        // abort() does not flush stream buffers.  Push the diagnostic out
        // first, or it may be lost when stdout is redirected.
        cout.flush();
        assert(0);
        // assert() is compiled out under NDEBUG.  Still bail out rather than
        // fill the table with NaN (log(0) - log(0)).
        abort();
      }

      double den = log(double_total(i, j));
      for (int k = 0; k < 4; k++)
        for (int l = 0; l < 4; l++)
          double_prob_log(i, j, Nucleotide(k), Nucleotide(l)) =
            my_log(double_count(i, j, Nucleotide(k), Nucleotide(l))) - den;
    }

  // NOTE(review): single totals are not checked for zero.  If one is zero,
  // den is -inf and the corresponding entries become non-finite.
  for (int i = from1; i < to1; i++) {
    double den = log(single_total1(i));
    for (int k = 0; k < 4; k++)
      single_prob_log1(i, Nucleotide(k)) =
        my_log(single_count1(i, Nucleotide(k))) - den;
  }

  for (int j = from2; j < to2; j++) {
    double den = log(single_total2(j));
    for (int k = 0; k < 4; k++)
      single_prob_log2(j, Nucleotide(k)) =
        my_log(single_count2(j, Nucleotide(k))) - den;
  }
}



double& PairTable::double_prob_log(long idx1, long idx2,
				   Nucleotide n1, Nucleotide n2) {
  assert(idx1 >= from1 && idx1 < to1);
  assert(idx2 >= from2 && idx2 < to2);
  
  return double_probs_log[(((idx1-from1)*len2 + (idx2-from2))*4 + n1)*4 + n2];
}


// Mutable log-probability of nucleotide n at position idx of range 1;
// idx must lie in [from1, to1).
double& PairTable::single_prob_log1(long idx, Nucleotide n) {
  assert(idx >= from1 && idx < to1);
  const long slot = (idx - from1) * 4 + n;
  return single_probs_log1[slot];
}

// Mutable log-probability of nucleotide n at position idx of range 2;
// idx must lie in [from2, to2).
double& PairTable::single_prob_log2(long idx, Nucleotide n) {
  assert(idx >= from2 && idx < to2);
  const long slot = (idx - from2) * 4 + n;
  return single_probs_log2[slot];
}


// Write `l` with its low 32 bits in network byte order (htonl).  Exactly
// sizeof(long) bytes are written, so the on-disk size of each value depends
// on the platform's long.  read_long() is the matching reader.
void PairTable::write_long(ostream& ostr, long l) {
  long wire = htonl(l);
  ostr.write(reinterpret_cast<char *>(&wire), sizeof wire);
}

// Read one value written by write_long(): sizeof(long) bytes whose low 32
// bits are in network byte order.  The 32-bit payload is sign-extended, so
// negative values (e.g. range bounds such as from1 < 0) round-trip where long
// is 64 bits.  Previously ntohl's unsigned result was zero-extended there,
// turning -5 into 4294967291.  Returns 0 if the read fails; callers can
// detect the failure through the stream state.
long PairTable::read_long(istream& istr) {
  long l = 0;
  istr.read((char *) &l, sizeof(long));
  if (istr.fail())
    return 0;  // never hand back a partially/un-initialized value
  return static_cast<long>(static_cast<int32_t>(ntohl(static_cast<uint32_t>(l))));
}

// Serialize the table in the binary layout read back by the
// PairTable(istream&) constructor:
//   from1, to1, from2, to2;
//   for each range-1 position: total, then 4 nucleotide counts;
//   for each range-2 position: total, then 4 nucleotide counts;
//   for each (pos1, pos2) pair: total, then the 4x4 pair counts.
// Probabilities are not written; the reading constructor recomputes them.
void PairTable::store(ostream& ostr) {
  const long bounds[4] = { from1, to1, from2, to2 };
  for (int b = 0; b < 4; b++)
    write_long(ostr, bounds[b]);

  for (int p = from1; p < to1; p++) {
    write_long(ostr, single_total1(p));
    for (int k = 0; k < 4; k++)
      write_long(ostr, single_count1(p, Nucleotide(k)));
  }

  for (int p = from2; p < to2; p++) {
    write_long(ostr, single_total2(p));
    for (int k = 0; k < 4; k++)
      write_long(ostr, single_count2(p, Nucleotide(k)));
  }

  for (int p = from1; p < to1; p++)
    for (int q = from2; q < to2; q++) {
      write_long(ostr, double_total(p, q));
      for (int k = 0; k < 4; k++)
        for (int l = 0; l < 4; l++)
          write_long(ostr, double_count(p, q, Nucleotide(k), Nucleotide(l)));
    }
}

// Human-readable dump of the raw counts, in the same order as store():
// the four range bounds, the "SINGLES 1" and "SINGLES 2" sections (one line
// per position: total followed by the 4 counts), and a "DOUBLES" section
// (per pair: a header line with the total, then a 4x4 grid of counts).
void PairTable::write_text(ostream& ostr) {
  ostr << from1 << endl
       << to1 << endl
       << from2 << endl
       << to2 << endl;

  ostr << "SINGLES 1" << endl;
  for (int pos = from1; pos < to1; pos++) {
    ostr << pos << ": " << single_total1(pos) << "\t";
    for (int base = 0; base < 4; base++)
      ostr << single_count1(pos, Nucleotide(base)) << "\t";
    ostr << endl;
  }
  ostr << "\n\n\n";

  ostr << "SINGLES 2" << endl;
  for (int pos = from2; pos < to2; pos++) {
    ostr << pos << ": " << single_total2(pos) << "\t";
    for (int base = 0; base < 4; base++)
      ostr << single_count2(pos, Nucleotide(base)) << "\t";
    ostr << endl;
  }
  ostr << "\n\n\n";

  ostr << "DOUBLES" << endl;
  for (int p = from1; p < to1; p++) {
    for (int q = from2; q < to2; q++) {
      ostr << "(" << p << "," << q << ")" << double_total(p, q) << "\t" << endl;
      for (int row = 0; row < 4; row++) {
        for (int col = 0; col < 4; col++)
          ostr << double_count(p, q, Nucleotide(row), Nucleotide(col)) << "\t";
        ostr << endl;
      }
      ostr << endl;
    }
  }
  ostr << endl;
}

void PairTable::store_write(ostream& ostr1, ostream& ostr2) {
  store(ostr1);
  write_text(ostr2);
}

// Side length of a square table (len1 == len2).  For a non-square table an
// error is printed to cerr and -1 is returned.
int PairTable::getSize() {
  if (len1 == len2)
    return len1;

  cerr << "Pairtable is not square! Use getHeight and getWidth!" << endl;
  return -1;
}

// Element-wise "division" of this table's probabilities by p's.  Both tables
// hold log-probabilities, so this subtracts p's log arrays from ours.  The
// result in *this is the log of the probability ratio.  Counts and totals are
// left untouched.
// Both tables must have the same dimensions; otherwise the loops below would
// read past the end of p's arrays.
// NOTE(review): only the lengths are checked; tables with different from1 /
// from2 offsets would be combined position-misaligned.
const PairTable& PairTable::operator/=(PairTable &p) {
  assert(len1 == p.len1 && len2 == p.len2);

  int i;
  for (i = 0; i < len1 * 4; i++)
    single_probs_log1[i] -= p.single_probs_log1[i];
  for (i = 0; i < len2 * 4; i++)
    single_probs_log2[i] -= p.single_probs_log2[i];
  for (i = 0; i < len1 * len2 * 4 * 4; i++)
    double_probs_log[i] -= p.double_probs_log[i];

  return *this;
}



// Print a grid of Cramer's V^2 values, one row per range-1 position and one
// space-separated column per range-2 position.  For each pair, X^2 compares
// the observed 4x4 counts with the expectation under independence, built
// from the single-position frequencies:
//   e(k,l) = p1(k) * p2(l) * N,   V^2 = X^2 / (3 * N),   N = double_total.
// The parameter deliberately shadows the global cout.
void PairTable::unparse_table(ostream& cout) {
  for (int a = from1; a < to1; a++) {
    for (int b = from2; b < to2; b++) {
      double chi2 = 0.0;

      for (int k = 0; k < 4; k++) {
        const double p1 = (double) single_count1(a, Nucleotide(k)) / single_total1(a);
        for (int l = 0; l < 4; l++) {
          const double expected =
            p1 * (double) single_count2(b, Nucleotide(l)) / single_total2(b)
            * double_total(a, b);
          const double diff = double_count(a, b, Nucleotide(k), Nucleotide(l)) - expected;
          chi2 += diff * diff / expected;
        }
      }

      const double cramer_v2 = chi2 / (3 * double_total(a, b));
      cout << cramer_v2 << " ";
    }
    cout << "\n";
  }
}






/******************* Deprecated: use store and write_text instead *****************/


// Deprecated text dump.  It prints a header with the range bounds, the V^2
// grid from unparse_table(), and then for each (pos1, pos2) pair: N, X^2,
// V^2, and the observed 4x4 counts beside the expected counts (truncated to
// int).  The parameter shadows the global cout.
void PairTable::unparse(ostream& cout) {
  // NOTE(review): single_total1(0) assumes position 0 lies in [from1, to1).
  // Otherwise the accessor's assert fires (or, under NDEBUG, reads out of
  // bounds).  Confirm whether the sample counter `total` was intended here.
  cout << " (" << from1 << ", " << to1 << "; "
               << from2 << ", " << to2 << ") " 
    << single_total1(0) << "\n\n";

  unparse_table(cout);
  
  cout << "\n";

  // Expected counts for the current pair, kept for printing next to the
  // observed counts.
  double ee[4][4];

  for (int i = from1; i < to1; i++)
    for (int j = from2; j < to2; j++) {
      double x2 = 0.0;
      double v2 = 0.0;
      int k;
      // Chi-square against the independence expectation built from the
      // single-position frequencies (same formula as unparse_table).
      for (k = 0; k < 4; k++)
	for (int l = 0; l < 4; l++) {
	  double e = 0.0;
	  e = (double) single_count1(i, Nucleotide(k)) / single_total1(i)
	    * (double) single_count2(j, Nucleotide(l)) / single_total2(j) 
	    * double_total(i, j);
	  x2 += sqr(double_count(i, j, Nucleotide(k), Nucleotide(l))-e) / e;
	  ee[k][l] = e;
        }
      
      double num = double_total(i,j);
      // Cramer's V^2 for a 4x4 table: X^2 / (N * (4 - 1)).
      v2 = x2 / (3*num);
      // Non-negative positions are printed shifted by +1; negative offsets
      // are printed unchanged.
      cout << "( " << (i >= 0 ? i + 1 : i) << ", " << (j >= 0 ? j + 1 : j) << " )\tN = " << long(num) << 
	"\tX2 = " << x2 << "\tV2 = " << v2 << "\n";
      int oldprec = cout.precision(3);
      num = 1; 
      // num is reset to 1 so the "/num" divisions below do not rescale the
      // printed values; they only convert them to double.
      // Header row: observed columns A C G T, then expected columns A C G T.
      cout << "\t";
      for (k = 0; k < 4; k++)
	cout << Sequence::base2char(Nucleotide(k)) << "\t";
      cout << "\t";
      for (k = 0; k < 4; k++)
	cout << Sequence::base2char(Nucleotide(k)) << "\t";
      cout << "\n";
      
      // One row per first-position base: observed counts, then expected
      // counts truncated to int.
      for (k = 0; k < 4; k++) {
	cout << Sequence::base2char(Nucleotide(k)) << "\t";
	cout << double_count(i,j,Nucleotide(k),BASE_A)/num << "\t"
	     << double_count(i,j,Nucleotide(k),BASE_C)/num << "\t" 
	     << double_count(i,j,Nucleotide(k),BASE_G)/num << "\t" 
	     << double_count(i,j,Nucleotide(k),BASE_T)/num << "\t\t";
	cout << int(ee[k][Nucleotide(0)])/num << "\t"
	     << int(ee[k][Nucleotide(1)])/num << "\t"
	     << int(ee[k][Nucleotide(2)])/num << "\t"
	     << int(ee[k][Nucleotide(3)])/num << "\n";
      }
      // Restore the caller's stream precision.
      cout.precision(oldprec);
      cout << "\n";
    }
}

