import re
import warnings

from  RNA.MSARi.MSA.Jaynes import match_distribution, count_prob
reload(match_distribution); reload(count_prob)
from RNA.structure.ViennaOdds import pair_probs

class Sequence:

    """Keep track of a putative RNA sequence as part of a MSA ensemble.

    Attributes set by the constructor:

    seq           -- the gapped sequence (characters from 'ACGT-').
    un_gap_map    -- un_gap_map[u] is the gapped index of the u-th
                     nucleotide (non-gap character) of seq.
    gap_un_map    -- gap_un_map[g] is the number of nucleotides that
                     precede gapped position g; for a nucleotide this
                     is its ungapped index.
    distributions -- result of match_distribution.dist(seq, otherseqs,
                     winlength); indexed by gapped position in
                     pairing_prob.
    otherseqs     -- the other sequences of the alignment.
    pairs         -- result of pair_probs(seq).
    offset_range  -- registration offsets tried by best_pairing.
    """

    def __init__(self, seq, otherseqs, winlength):

        """Build the gap maps and the per-position statistics.

        seq must consist only of the characters 'ACGT-'.  otherseqs
        and winlength are forwarded to match_distribution.dist."""

        self.seq = seq

        # Maps  between  the  gapped  and ungapped  positions  in  the
        # sequence.
        self.un_gap_map, self.gap_un_map = [], []
        for seqposidx, char in enumerate(self.seq):
            # Recorded before the append below, so this is the count
            # of nucleotides strictly before this position.
            self.gap_un_map.append(len(self.un_gap_map))
            if char in 'ACGT':
                self.un_gap_map.append(seqposidx)
            else:
                assert char == '-', "Sequences should match '^[ACGT-]*$'"
        # Sanity check: the two maps are inverse on nucleotides.
        for ungapidx, gapidx in enumerate(self.un_gap_map):
            assert self.gap_un_map[gapidx] == ungapidx

        self.distributions = match_distribution.dist(
            seq, otherseqs, winlength)

        warnings.warn('Get rid of the otherseqs attribute.')
        self.otherseqs = otherseqs
        self.pairs = pair_probs(self.seq)

        # Set of registration offsets to consider when looking for the
        # best local match.
        self.offset_range = range(-2, 3)

    def best_pairing(self, pos1, pos2, winlength, verbose=False):

        """Find the best-scoring window pairing near pos1 and pos2.

        The first window start is fixed at winlength // 2 nucleotides
        before pos1.  The second start is pos2 shifted by each offset
        in self.offset_range.  Both are clamped to the ungapped
        sequence.  Each candidate is scored with complementarity().

        Returns a tuple (count, startpos1, startpos2), where count is
        the best number of complementary nucleotides, or
        (0, None, None) when no candidate scores above zero.  Both the
        argument and return indices refer to positions in the gapped
        sequence.  Requires pos1 < pos2."""

        # Clamping can map several offsets onto the same start2; each
        # distinct start2 only needs to be scored once.
        seen_start2s = set()
        best_score, best_start1, best_start2 = 0, None, None

        assert pos1 < pos2

        # Largest start index allowed by the original clamping rule.
        last_start = len(self.un_gap_map) - winlength - 1

        start1 = max(0, self.gap_un_map[pos1] - (winlength // 2))
        start1 = min(last_start, start1)
        for offset in self.offset_range:
            start2 = max(0, self.gap_un_map[pos2] + offset)
            start2 = min(last_start, start2)
            if start2 in seen_start2s:

                # Don't consider offsets that have already been looked
                # at.
                continue
            seen_start2s.add(start2)
            score = self.complementarity(start1, start2, winlength)
            if verbose:
                print('%s %s %s' % (pos1, pos2, score))
            if score > best_score:
                best_score = score
                best_start1, best_start2 = start1, start2
        if best_start2 is None:
            return 0, None, None

        # Convert the winning ungapped starts back to gapped indices.
        pos1 = self.un_gap_map[best_start1]
        pos2 = self.un_gap_map[best_start2]
        # The returned probability is not used here; the call is kept
        # from the original code, where it runs pairing_prob's
        # internal consistency assertions on the chosen windows.
        self.gapped_pairing_prob(pos1, pos2, winlength, best_score)
        return best_score, pos1, pos2

    def get_ungapped_positions(self, start1, start2, winlength):

        """Return the gapped indices covered by two facing windows.

        start1 and start2 are indices into the ungapped sequence.  The
        window length is reduced to at most (start2 - start1) // 2 so
        the two windows cannot overlap.

        Returns (seq1pos, seq2pos): the *gapped* indices of the
        nucleotides at ungapped positions start1 .. start1+winlength-1
        in increasing order, and of those at start2-winlength ..
        start2-1 in decreasing order (going backwards from start2)."""

        # Floor division keeps the window length an integer index.
        winlength = min(winlength, (start2 - start1) // 2)
        seq1pos = self.un_gap_map[start1:start1 + winlength]
        seq2pos = self.un_gap_map[start2 - winlength:start2]
        seq2pos.reverse()
        return seq1pos, seq2pos

    def complementarity(self, start1, start2, winlength):

        """Return the number  of complementary nucleotides between the
        sequence starting at start1 and ending at start1+winlength and
        the  sequence  starting  at  start2-winlength  and  ending  at
        start2 (read backwards).  The  indices start1 and start2 refer
        to positions in the ungapped sequence.  Complementarity is
        decided by the match_distribution.complements table."""

        # Get the actual nucleotides
        idx1, idx2 = self.get_ungapped_positions(start1, start2, winlength)
        nukes1 = [self.seq[i] for i in idx1]
        nukes2 = [self.seq[i] for i in idx2]

        # Count how many are complementary.
        score = 0
        for nuke1, nuke2 in zip(nukes1, nukes2):
            if nuke1 in match_distribution.complements[nuke2]:
                score += 1
        return score

    def pairing_prob(self, start1, start2, winlength, score,
                     verbose=False):

        """Return  the   probability  of  getting   at  least  <score>
        complementary  nucleotides  between   the  two  windows.   The
        indices start1  and start2 refer to positions  in the ungapped
        sequence.

        The value is read from count_prob.VarCount(...).significances
        and then corrected for the number of offsets in
        self.offset_range that best_pairing chooses between.  score
        must not exceed the number of complementary pairs actually
        present in the windows (checked with an assert).  With verbose
        set, the per-position alignment columns and distributions are
        printed."""

        # Get the probabilities of complementary pairs for each of the
        # paired positions in the windows
        s1, s2 = self.get_ungapped_positions(start1, start2, winlength)
        allseqs = [self.seq] + self.otherseqs
        # One alignment column per window position; entry 0 of each
        # column belongs to this sequence.
        chars1 = [[seq[i] for seq in allseqs] for i in s1]
        chars2 = [[seq[i] for seq in allseqs] for i in s2]

        # The observed number of complementary pairs in this sequence
        # must be at least the score being tested.  This is the same
        # quantity complementarity() returns for these windows.
        count = 0
        for c1, c2 in zip(chars1, chars2):
            if c1[0] in match_distribution.complements[c2[0]]:
                count += 1
        assert count >= score

        vars1 = [self.distributions[i] for i in s1]
        vars2 = [self.distributions[i] for i in s2]
        probabilities = [match_distribution.pairing_prob(v1, v2)
                         for v1, v2 in zip(vars1, vars2)]

        countprob = count_prob.VarCount(probabilities)
        rv = countprob.significances[score]

        # Take into account the fact that this was chosen as the best
        # of len(self.offset_range) candidate windows.
        rv = 1 - ((1 - rv) ** len(self.offset_range))

        if verbose:
            print(80 * '*')
            print('%s %s %s %s' % (start1, start2, score, rv))
            for c1, c2, v1, v2 in zip(chars1, chars2, vars1, vars2):
                print(match_distribution.pairing_prob(v1, v2))
                print(c1)
                print(c2)
                print(v1)
                print(v2)
        return rv

    def gapped_pairing_prob(self, pos1, pos2, winlength, score,
                            verbose=False):

        """Return  the   probability  of  getting   at  least  <score>
        complementary  nucleotides  between   the  two  windows.   The
        indices  pos1  and  pos2  refer  to positions  in  the  gapped
        sequence; they are converted with gap_un_map and passed on to
        pairing_prob."""

        return self.pairing_prob(self.gap_un_map[pos1],
                                 self.gap_un_map[pos2],
                                 winlength, score, verbose)

def cleanseq(seq):

    """Return seq upper-cased, with every character other than A, C,
    G, T or '-' replaced by the gap character '-'."""

    # Relies on the module-level `import re`.
    return re.sub(r'[^ACGT-]', '-', seq.upper())

if __name__ == '__main__':

    # Ad-hoc driver: build a Sequence from the first RNase sequence
    # (with the next nine as the rest of the alignment) and evaluate
    # the best local pairing around each predicted pair.
    from RNA.data.sequences.rnase import sequences
    from RNA.MSARi.MSA.Jaynes import Sequence
    reload(Sequence)

    winlength = 10
    _seq = cleanseq(sequences.values()[0])
    otherseqs = map(cleanseq, sequences.values()[1:10])
    seq = Sequence.Sequence(_seq, otherseqs, winlength)

    for seqidx, pairs in enumerate(seq.pairs):
        for prob, complement in pairs:
            # Only pairs with prob <= 0.05 whose partner index is not
            # before seqidx are examined.
            if prob > 0.05 or not (complement >= seqidx):
                continue
            count, pos1, pos2 = seq.best_pairing(
                seqidx, complement, winlength)
            if count > 0:
                # The result is computed but not reported anywhere.
                prob = seq.gapped_pairing_prob(
                    pos1, pos2, winlength, count)
