"""Greedily choose sequences with high scores w.r.t. the scheme tested
in RNA/MSARi/tests/bacteria/check_variation.py"""

from RNA.structure.ViennaOdds import distinct_prob_pairings

# Width, in sequence positions, of every window sliced out by Portion and
# of the position ranges reserved by MSAChooser; also passed as the second
# argument to distinct_prob_pairings.
winlength = 7

class Portion:

    """Portion(seq, start) keeps track of variation in a portion,
    starting at index start, of a MSA during its assembly.  Assembly
    begins with sequence seq.

    Attributes:
        start: index of the first position of the window (may be
            negative when the window hangs over the start of the
            sequence; see portion()).
        length: explicit window width, or None to use the module-wide
            winlength.
        portions: windows of every sequence seen so far, seed first.
        portions_list: windows of the sequences added with addseq()
            (i.e. portions without the seed).
        scores: variation count recorded for each added sequence.
    """

    def __init__(self, seedseq, start, length=None):
        """Start tracking the window at `start`, seeded with `seedseq`.

        `length` optionally overrides the module-wide winlength for
        this instance; leaving it as None keeps the original behaviour.
        """
        self.start = start
        self.length = length
        self.portions_list = []
        self.portions = [self.portion(seedseq)]
        self.scores = []

    def portion(self, seq):
        """Return the part of `seq` covered by this window.

        The window is clipped to the sequence: a negative start is
        treated as 0 rather than being passed to the slice, where it
        would be interpreted as an offset from the END of `seq` and
        (for any realistic sequence) produce an empty window.
        """
        length = winlength if self.length is None else self.length
        end = self.start + length
        if end <= 0:
            # The whole window lies before the sequence.
            return seq[:0]
        return seq[max(self.start, 0):end]

    def varcount(self, seq):
        """Return the smallest number of differing positions between
        the window of `seq` and the window of any sequence seen so far.

        A position only counts when the character in `seq` is one of
        'ACGT', so other symbols in `seq` never count as variation.
        """
        window = self.portion(seq)
        return min(
            sum(1 for (c1, c2) in zip(window, existing)
                if c1 != c2 and c1 in 'ACGT')
            for existing in self.portions)

    def total_score(self, scores):
        """Return the Euclidean norm of `scores`."""
        return sum(s ** 2 for s in scores) ** 0.5

    def varscore(self, seq):
        """Return the score this portion would have if `seq` were
        added, without modifying any state."""
        return self.total_score(self.scores + [self.varcount(seq)])

    def addseq(self, seq):
        """Record `seq`: store its variation count and its window."""
        window = self.portion(seq)
        # The count must be taken before the window is stored,
        # otherwise seq would be compared against itself and score 0.
        self.scores.append(self.varcount(seq))
        self.portions.append(window)
        self.portions_list.append(window)

    def score(self):
        """Return the score over all sequences added so far."""
        return self.total_score(self.scores)

class MSAChooser:

    """MSAChooser(seedseq) greedily assembles a set of sequences,
    starting from seedseq, scoring each candidate by the variation it
    adds in a set of non-overlapping windows (Portion objects).

    Attributes:
        seqs: the sequences chosen so far, seed first.
        portions: the Portion objects tracking each chosen window.
    """

    def __init__(self, seedseq):
        """Choose the windows to track from the pairings reported by
        distinct_prob_pairings for `seedseq`."""
        self.seqs = [seedseq]
        self.portions = []
        # Items are unpacked below as (prob, pair); sorting in reverse
        # visits the largest items first.
        pairs = list(distinct_prob_pairings(seedseq, winlength, 0.9))
        pairs.sort(reverse=True)

        # Keeps track of the  positions covered by the portions chosen
        # so  far.    distinct_prob_pairings  explicitly  allows  more
        # overlap  for the  context  of MSARi  than  makes sense  when
        # choosing MSAs.
        #
        # The built-in set type is used: the `sets` module was never
        # imported by this file.
        seen_pos = set()
        # Explicit floor division so each window is centred on an
        # integer index regardless of the division semantics in force.
        half = winlength // 2
        for _prob, pair in pairs:
            for pos in pair:
                start = pos - half
                portionpos = set(range(start, start + winlength))
                # Only accept windows sharing no position with an
                # already accepted one.
                if not portionpos.intersection(seen_pos):
                    self.portions.append(Portion(seedseq, start))
                    seen_pos.update(portionpos)

    def total_scores(self, scores):
        """Return the sum of the squares of `scores`."""
        return sum(s ** 2 for s in scores)

    def score(self, seq):
        """Return the overall score that would result from adding
        `seq`, without modifying any state."""
        return self.total_scores(
            [p.varscore(seq) for p in self.portions])

    def addseq(self, seq):
        """Add `seq` to the chosen sequences and to every portion."""
        for portion in self.portions:
            portion.addseq(seq)
        self.seqs.append(seq)

    def bestseq(self, seqs):
        """Return the element of `seqs` with the highest score().

        Among equally scored candidates the one latest in `seqs` is
        returned (stable sort, last element).  `seqs` itself is not
        modified.  Raises IndexError when `seqs` is empty.
        """
        return sorted(seqs, key=self.score)[-1]

    def add_best_seq(self, seqs):
        """Add the best candidate from `seqs` and return a new list
        holding the remaining candidates.

        Every occurrence equal to the chosen sequence is dropped from
        the returned list; the `seqs` argument is left untouched.
        """
        bestseq = self.bestseq(seqs)
        self.addseq(bestseq)
        return [s for s in seqs if s != bestseq]

    def add_random_seq(self, seqs):
        """Placeholder: not implemented, does nothing, returns None."""
        pass

    def size(self):
        """Return the number of sequences chosen so far (seed
        included)."""
        return len(self.seqs)

    def current_score(self):
        """Return the overall score of the sequences chosen so far."""
        return self.total_scores(
            [p.score() for p in self.portions])

def check_scores(msa):

    """Debugging aid: compare the score tracked by `msa` with the one
    computed by check_variation.PortionScores on the same sequences.

    Groups the chooser's portions and the reference implementation's
    variations by start position, drops into the debugger so that they
    can be inspected side by side (in the local `positions` dict), and
    then prints both totals.
    """

    import pdb
    from RNA.MSARi.tests.bacteria import check_variation
    # Pick up edits to the reference module made during an interactive
    # session.
    reload(check_variation)
    portions = check_variation.PortionScores(msa.seqs, None)
    # NOTE: the equality below is the property being investigated; it
    # is left disabled so that the inspection code can run.
    # assert msa.current_score() == portions.total()
    positions = {}
    for p in msa.portions:
        positions[p.start] = [p]
    for p in portions.all_variations:
        # setdefault: a start position known only to the reference
        # implementation is itself worth inspecting, not a reason to
        # abort with KeyError.
        positions.setdefault(p.start, []).append(p)
    # The original called an undefined name `st()` here; presumably a
    # debugger shortcut from the author's environment.
    pdb.set_trace()
    print(msa.current_score())
    print(portions.total())

if __name__ == '__main__':

    # Script entry point: greedily pick sequences from the rnase data
    # set until 15 are chosen, then compare the resulting score with
    # the reference implementation.
    from RNA.data.sequences import rnase, srp
    from RNA.MSARi.MSA import choose_MSA
    # Re-import this module by its package name so that the classes
    # used below are the freshly reloaded ones.
    reload(choose_MSA)
    # Upper-case every sequence and keep its first 785 characters.
    seqs = [s.upper()[:785] for s in rnase.sequences.values()]
    msa = choose_MSA.MSAChooser(seqs.pop())
    # Stop early when the candidates run out; otherwise add_best_seq
    # would fail with IndexError on an empty list.
    while msa.size() < 15 and seqs:
        print(msa.size())
        seqs = msa.add_best_seq(seqs)
    choose_MSA.check_scores(msa)
