import random, sys

# Filesystem locations used by this script:
#   msaripath   -- directory that must come first on sys.path
#   msadir      -- directory that receives one dump file per processed MSA
#   output_path -- results file that main() appends to
msaripath = '/var/tmp/MSARi'
msadir = '/var/tmp/MSAs/'
output_path = '/var/tmp/msari.results'

# Make msaripath the first sys.path entry (presumably so the project
# packages imported below resolve from there -- confirm).  The constant is
# used rather than a second copy of the literal so the two cannot drift.
if sys.path[0] != msaripath:
    sys.path.insert(0, msaripath)

import admin.python.pythonrc

from RNA.MSARi.MSA.Jaynes import MSA
reload(MSA)

from RNA.data.sequences import rnase, srp
from RNA.MSARi.tests import scramble
from tools.Alignment import clustalw
reload(clustalw)

# Number of sequences choose_sequences() must collect for one MSA; it is
# also the number of selection rounds it attempts before giving up.
MSASIZE = 10

class DevNull:

    """Minimal  write-only sink:  anything written  to it  is silently
    discarded."""

    def write(self, output):
        """Accept `output` and throw it away."""
        return None

def shuffleMSA(msa):

    """Build a randomised, re-aligned version of `msa`.

    `msa` is  a sequence of  (name, aligned_sequence) pairs.   The aligned
    sequences  are handed  to scramble.random_ensemble  (presumably  the
    step that  does the shuffling --  confirm against that  module), the
    gaps are  stripped from what  comes back, and the  results, keyed by
    the original  names, are re-aligned with  clustalw.get_MSA.  Returns
    the cleaned list of aligned sequence strings (names are dropped)."""

    labels = [entry[0] for entry in msa]
    aligned = [entry[1] for entry in msa]

    # DevNull is passed as random_ensemble's second argument; presumably
    # it is an output stream being silenced -- confirm.
    scrambled = scramble.random_ensemble(aligned, DevNull())
    ungapped = [s.replace('-', '') for s in scrambled]

    realigned = clustalw.get_MSA(dict(zip(labels, ungapped)))
    return cleanMSA([entry[1] for entry in realigned])

def NullHypothesisMSA(msa):

    """Return a  column-wise null-model  'MSA' derived from  a shuffled
    copy of `msa`.

    The input  is first  passed through shuffleMSA.   Then, for  every
    column of  the shuffled alignment,  each output sequence  receives a
    character drawn  independently (with replacement) from  the A/C/G/T
    characters found  in that  column; gaps and  any other  symbols are
    never drawn.   Returns one  string per sequence  of the  shuffled
    alignment."""

    shuffled = shuffleMSA(msa)
    width = len(shuffled[0])
    sampled_columns = []
    for position in range(width):
        observed = [row[position] for row in shuffled]
        # The column's nucleotides, grouped in A, C, G, T order.
        pool = []
        for base in 'ACGT':
            pool += [base] * observed.count(base)
        # One draw per output sequence, in sequence order.
        sampled_columns.append([random.choice(pool) for row in shuffled])
    return [''.join([col[rowidx] for col in sampled_columns])
            for rowidx in range(len(shuffled))]

def cleanMSA(msa):

    """Normalise an MSA given as a list of aligned sequence strings.

    Steps, in order:
      * truncate every sequence to the length of the shortest one;
      * change U/u to T/t;
      * drop sequences that contain no A/C/G/T (either case) at all;
      * drop columns that contain no upper-case A/C/G/T.

    Returns the  list of cleaned  sequences.  An empty input,  or one in
    which every sequence is dropped, yields an empty list."""

    # Imported here because this file has no module-level `import re`.
    import re

    if not msa:
        return []

    minlen = min(map(len, msa))
    msa = [seq[:minlen].replace('U', 'T').replace('u', 't') for seq in msa]
    msa = [seq for seq in msa if not re.match('[^ACGTacgt]+$', seq)]
    if not msa:
        return []

    new_msa = [[] for seq in msa]
    for colidx in range(minlen):
        column = [seq[colidx] for seq in msa]
        # Only upper-case nucleotides keep a column alive.
        if max([column.count(nuke) for nuke in 'ACGT']) > 0:
            for char, seq in zip(column, new_msa):
                seq.append(char)
    return [''.join(seq) for seq in new_msa]

def stripMSAhomology(msa, window=40, step=5, threshold=0.85):

    """Mask  regions  of  sequences  in  the  MSA  that  are  strongly
    homologous to earlier sequences in the MSA.

    The first sequence is kept as is.  Every later sequence is scanned
    with  a  sliding  window  (`window`   columns  wide,  advanced  by
    `step`, plus  one final window  flush with the sequence  end).  If
    the fraction of identical  characters between the unmasked sequence
    and any  earlier, already  processed sequence reaches  `threshold`
    within  a window,  that window  is replaced  by '-'  in the  output.
    Characters are compared literally, so  a gap aligned to a gap counts
    as a match.

    Sequences shorter than `window` are  treated as a single window that
    spans the whole sequence, so  the output always has the same length
    as the input.  An empty MSA yields an empty list.

    Returns the list of (possibly masked) sequences, in input order."""

    if not msa:
        return []

    new_msa = [msa[0]]
    for seq in msa[1:]:
        span = min(window, len(seq))
        last_start = len(seq) - span
        starts = list(range(0, last_start, step)) + [last_start]
        newseq = list(seq)
        for winidx in starts:
            winend = winidx + span
            for prior_seq in new_msa:
                match_count = 0
                for charidx in range(winidx, winend):
                    if seq[charidx] == prior_seq[charidx]:
                        match_count += 1
                # A zero-width window (empty sequence) is never masked.
                if span and float(match_count) / span >= threshold:
                    newseq[winidx:winend] = span * ['-']
                    break
        new_msa.append(''.join(newseq))
    return new_msa

def score_msa(msa):

    """Score `msa` with the Jaynes MSA model and return its combined
    odds.

    Builds MSA.MSA(msa,  7), computes  its matches, pretty-prints  the
    first 20 of  them, and returns combined_odds().   The meaning of the
    second constructor  argument (7) is  not visible from this  file --
    TODO confirm against RNA.MSARi.MSA.Jaynes.MSA.

    NOTE(review): `pp` is  not defined or imported in  this file; it is
    presumably provided by admin.python.pythonrc -- confirm.

    The column- and helix-based scoring  that used to follow the return
    statement was  unreachable and  referred to  modules (ColumnMSA,
    HelixMSA) that this file never imports, so it has been removed."""

    tupmsa = MSA.MSA(msa, 7)
    tupmsa.get_matches()
    pp(tupmsa.matches[:20])
    return tupmsa.combined_odds()

def content(msa):

    """Return (mean column identity, GC fraction) for `msa`.

    `msa` is  a list of  equal-length aligned sequence  strings.  Only
    upper-case A/C/G/T are counted.

      * mean column identity: for each  column, the share of its counted
        nucleotides  taken by  the  most common  one,  averaged over  the
        columns that contain at least one counted nucleotide;
      * GC fraction: G and C as a share of all counted nucleotides.

    Columns without any  counted nucleotide are skipped rather  than
    causing a division by zero.  An MSA that is empty or has no counted
    nucleotides at all yields (0.0, 0.0)."""

    if not msa:
        return 0., 0.

    identity_sum = 0.
    scored_columns = 0
    gc_count = 0
    total_count = 0.
    for colidx in range(len(msa[0])):
        column = [seq[colidx] for seq in msa]
        nuke_counts = dict([(nuke, column.count(nuke)) for nuke in 'ACGT'])
        total = sum(nuke_counts.values())
        if not total:
            # Gap-only (or otherwise nucleotide-free) column.
            continue
        identity_sum += max(nuke_counts.values()) / float(total)
        gc_count += nuke_counts['G'] + nuke_counts['C']
        total_count += total
        scored_columns += 1

    if not scored_columns:
        return 0., 0.
    return identity_sum / scored_columns, gc_count / total_count

def makeMSAs(seqs):

    """Align `seqs`, build one shuffled control alignment, and score
    both.

    `seqs` maps sequence names to sequences.  Existing gaps are removed
    before alignment so that clustalw  gets no help from them.  Returns
    a 3-tuple:  (scores, raw_msa,  control_msas) where `scores`  holds
    the score  of the  real alignment followed  by the scores  of the
    controls, `raw_msa`  is the cleaned real alignment  and
    `control_msas` is a list holding the single control alignment."""

    ungapped = dict([(name, seq.replace('-', ''))
                     for name, seq in seqs.items()])
    aligned = clustalw.get_MSA(ungapped)
    raw_msa = cleanMSA([entry[1] for entry in aligned])

    # Only one shuffled control is generated at present.
    control_msas = [cleanMSA(shuffleMSA(aligned))]

    scores = [score_msa(candidate) for candidate in [raw_msa] + control_msas]
    pp(scores)
    return scores, raw_msa, control_msas

def simextremes(seqname, chosen_seqs, seqs):

    """Return the (lowest, highest) similarity between the sequence
    named `seqname` and the already chosen sequences.

    seqs[seqname]  and  every  sequence in  `chosen_seqs`  are  ungapped,
    upper-cased and  aligned together with clustalw.get_MSA.   For each
    chosen  sequence the  similarity  is the  number  of aligned  columns
    where  the query  has a  nucleotide (A/C/G/T/U)  and the  chosen
    sequence  has the  same character,  divided by  the number  of
    nucleotides in the query.

    Returns (0.0, 0.0) when the query contains no nucleotides or when
    there  is nothing  to  compare against,  instead  of dividing  by
    zero or indexing an empty list."""

    # Imported here because this file has no module-level `import re`.
    import re

    align_seqs = dict([(n, seq.replace('-', '').upper())
                       for n, seq in chosen_seqs.items()])
    align_seqs[seqname] = seqs[seqname].replace('-', '').upper()
    align_seqs = dict(clustalw.get_MSA(align_seqs))
    seq = align_seqs.pop(seqname).upper()

    charcount = len(re.sub('[^ACGTU]', '', seq))
    if not charcount or not align_seqs:
        return 0., 0.

    sims = []
    for cseq in align_seqs.values():
        cseq = cseq.upper()
        count = 0.  # float, so the divisions below are true divisions
        for posidx, char in enumerate(seq):
            if char in 'ACGTU':
                count += (char == cseq[posidx])
        sims.append(count)
    return min(sims) / charcount, max(sims) / charcount

def choose_sequences(seqs):

    """Randomly pick MSASIZE mutually 'moderately similar' sequences.

    `seqs` maps names to sequences.  Starting from a random sequence,
    candidates are  drawn at random and  accepted when their  highest
    similarity to  the sequences  chosen so far  (see simextremes)  lies
    in [0.5, 0.95].  After every acceptance, previously rejected
    candidates become eligible again.  A  round ends when MSASIZE
    sequences have been collected or no candidates remain; up to MSASIZE
    rounds are attempted.

    Returns a dict of the chosen name -> sequence pairs, or None when no
    round succeeds  (including the case  where `seqs` holds  fewer than
    MSASIZE sequences, which can never succeed)."""

    # Fewer inputs than required can never fill a set; skip the futile
    # (and alignment-heavy) rounds.
    if len(seqs) < MSASIZE:
        return None

    for attempt in range(MSASIZE):
        chosen_seqs = {}
        unchosen_seqs = seqs.copy()
        unconsidered = unchosen_seqs.copy()
        while (len(chosen_seqs) < MSASIZE) and unconsidered:
            candidate = random.choice(list(unconsidered))
            del unconsidered[candidate]
            if chosen_seqs:
                simmin, simmax = simextremes(candidate, chosen_seqs, seqs)
                if not (.5 <= simmax <= .95):
                    continue
            # The first pick is unconditional; later ones passed the test.
            chosen_seqs[candidate] = seqs[candidate]
            del unchosen_seqs[candidate]
            # Rejected candidates get another chance against the new set.
            unconsidered = unchosen_seqs.copy()
        if len(chosen_seqs) == MSASIZE:
            return chosen_seqs
    return None

def process_sequences(seqs, output):

    """Score one set of sequences and record the results.

    `seqs` maps names  to sequences; `output` is an  open, writable file
    object.   Three  lines  are  written  to `output`  (each  followed  by
    a flush): the sequence names, the  content() statistics of the
    cleaned input, and the scores from makeMSAs.  The real and control
    alignments are  also dumped to a  file in `msadir` whose  name is
    derived from the list of sequence names."""

    names = seqs.keys()
    print >> output, names
    output.flush()
    print >> output, content(cleanMSA(seqs.values()))
    output.flush()
    scores, raw_msa, control_msas = makeMSAs(seqs)
    print >> output, scores
    output.flush()

    msapath = msadir + str(names).replace(' ', '_')
    # Close the dump file explicitly instead of relying on garbage
    # collection.
    msafile = open(msapath, 'w')
    try:
        msafile.write(str(raw_msa) + '\n' + str(control_msas))
    finally:
        msafile.close()

def main():

    """Run 1000 rounds of sequence selection and scoring.

    In each round, one set of sequences is chosen from the eukaryotic
    SRP collection and one from the RNase collection, and each set is
    passed to process_sequences.  Results are appended to `output_path`.
    A  round in  which choose_sequences  finds no  suitable set  (it
    returns None) is reported on stderr and skipped."""

    output = open(output_path, 'a')
    try:
        for i in range(1000):
            for _seqs in [srp.sequences['Eukar.'], rnase.sequences, ]:
                seqs = choose_sequences(_seqs)
                if seqs is None:
                    sys.stderr.write('no suitable sequence set found; '
                                     'skipping\n')
                    continue
                process_sequences(seqs, output)
    finally:
        # Make sure the results file is closed even if a round fails.
        output.close()

# Script entry point: import RNA.MSARi.tests.general_test, reload it, and
# run its main().  Note that this does not call the main() defined above
# directly.  NOTE(review): general_test is presumably this very file as
# found on sys.path -- confirm.
if __name__ == '__main__':
    from RNA.MSARi.tests import general_test
    reload(general_test)
    general_test.main()

