from copy import deepcopy
import MSA
try:
    from MSA import MAX_TUPLE_LENGTH
except ImportError:
    import sys
    print >> sys.stderr, 'Failed to import MSA!'

def get_column_matches(seqs, colidx):

    """Return the pairwise match flags for one column of the alignment.

    The result is a triangular list of lists with one row per sequence
    in seqs.  Row i holds one integer for every earlier sequence j < i:
    1 if sequence i and sequence j have the same character at column
    colidx, and 0 otherwise.  Row 0 is therefore always empty.  A gap
    ('-') never matches anything, not even another gap."""

    matches = []
    for seqidx, seq in enumerate(seqs):
        nuke = seq[colidx]
        if nuke == '-':

            # Nothing can match a dash, so the whole row is zeros.
            matches.append([0] * seqidx)
        else:

            # nuke is not a dash here, so a dash in an earlier sequence
            # can never compare equal to it.  int() keeps the flags as
            # plain integers so they can be summed across columns.
            matches.append([int(nuke == earlier[colidx])
                            for earlier in seqs[:seqidx]])
    return matches

def add_matches(total, matches):

    """Add the counts in matches to the counts in total, in place.

    matches should be a structure returned by get_column_matches, and
    total should be a list of lists with the same triangular structure
    (row i has at least i entries).  Only the first i entries of row i
    are touched.  total is modified; nothing is returned."""

    for seqidx, (total_row, match_row) in enumerate(zip(total, matches)):
        for seqidx2 in range(seqidx):
            total_row[seqidx2] += match_row[seqidx2]

def subtract_matches(total, matches):

    """Subtract the counts in matches from the counts in total, in place.

    This is the inverse of add_matches: both arguments are triangular
    lists of lists (row i has at least i entries) and only the first i
    entries of row i are touched.  Every value involved must be a
    plain int; anything else trips an AssertionError.  total is
    modified; nothing is returned."""

    for seqidx, (total_row, match_row) in enumerate(zip(total, matches)):
        for seqidx2 in range(seqidx):

            # Guard against a malformed structure (e.g. a nested list
            # where a count should be) before doing arithmetic on it.
            assert type(total_row[seqidx2]) is int \
                   and type(match_row[seqidx2]) is int, \
                   'Match counts must be plain integers.'
            total_row[seqidx2] -= match_row[seqidx2]

def get_match_counts(seqs):

    """Return the pairwise match counts for every window of the MSA.

    The result is a list with one element per MAX_TUPLE_LENGTH-long
    window, in order of the window's starting column (0 up to and
    including len(seqs[0]) - MAX_TUPLE_LENGTH).  Each element is a
    triangular list of lists as built by get_column_matches: entry
    [i][j] (j < i) is the number of columns inside the window at which
    sequence i and sequence j match.

    All sequences must have the same length, otherwise an
    AssertionError is raised.  If the alignment is shorter than one
    window, an empty list is returned."""

    assert len(set(map(len, seqs))) == 1, 'Seqs are different lengths.'

    ncols = len(seqs[0])
    if ncols == 0 or ncols < MAX_TUPLE_LENGTH:

        # Not even a single full window fits into the alignment.
        return []

    # Per-column matches
    column_sims = [get_column_matches(seqs, colidx)
                   for colidx in range(ncols)]

    # Start the running count with the matches in the first
    # MAX_TUPLE_LENGTH columns, i.e. the window starting at column 0.
    firstsim = deepcopy(column_sims[0])
    for cidx in range(1, MAX_TUPLE_LENGTH):
        add_matches(firstsim, column_sims[cidx])
    match_counts = [firstsim]

    # Slide the window one column at a time: drop the column that just
    # left on the left and add the one that entered on the right.  The
    # last valid start is ncols - MAX_TUPLE_LENGTH, hence the "+ 1";
    # without it the final window would be silently omitted.
    for cidx in range(1, ncols - MAX_TUPLE_LENGTH + 1):
        currentsim = deepcopy(match_counts[-1])
        subtract_matches(currentsim, column_sims[cidx - 1])
        add_matches(currentsim, column_sims[cidx + MAX_TUPLE_LENGTH - 1])
        match_counts.append(currentsim)

    return match_counts

def make_clusters(match_count):

    """Group the sequences of one window into clusters.

    match_count is one window's triangular match-count structure (row
    i holds the number of matching columns between sequence i and each
    earlier sequence).  Return a list with one element per sequence;
    element i is the list of indices of all sequences in the same
    cluster as sequence i.  Sequences in the same cluster share the
    very same list object, so the elements are aliases, not copies.

    Sequence 0 always starts the first cluster (the result is [[0]]
    even for an empty match_count).  Each later sequence joins the
    cluster of the earliest prior sequence with which it has the
    highest match count, unless that count is below half of
    MAX_TUPLE_LENGTH, in which case it starts a cluster of its own."""

    threshold = 0.5 * MAX_TUPLE_LENGTH
    clusters = [[0]]
    for seqidx in range(1, len(match_count)):
        current_count = match_count[seqidx]
        highest_similarity = max(current_count)
        if highest_similarity < threshold:

            # This sequence didn't  have significant similarity to any
            # prior sequence.  Make a new cluster for it.
            clusters.append([seqidx])
        else:

            # index() returns the first (lowest-numbered) sequence that
            # reaches the highest similarity.
            first_match = current_count.index(highest_similarity)
            best_cluster = clusters[first_match]

            # Add this sequence to that cluster, and put the same list
            # object in the position corresponding to this sequence.
            best_cluster.append(seqidx)
            clusters.append(best_cluster)

    return clusters

def get_clusters(seqs):

    """Return the clustering of seqs for every window of the MSA.

    The result is a list with one element per window reported by
    get_match_counts, each element being the structure returned by
    make_clusters for that window.  A real list is built (rather than
    relying on what map() returns) so callers can index the result."""

    return [make_clusters(match_count)
            for match_count in get_match_counts(seqs)]

def refine_clusters(cluster1, cluster2):

    """Return the common refinement of two clusterings.

    cluster1 and cluster2 are structures as returned by make_clusters:
    element i lists the indices of the sequences clustered together
    with sequence i.  Two sequences end up in the same refined cluster
    only if they share a cluster in both inputs.

    The result is a list of distinct clusters, each a sorted list of
    sequence indices.  Clusters appear in order of first discovery
    while walking cluster1 from the start."""

    rv = []
    emitted = set()     # refined clusters already placed in rv
    visited = set()     # clusters of cluster1 already processed
    seqcount = 0
    for cluster in cluster1:

        # Every member of a cluster carries the same member list, so
        # each distinct cluster only needs to be processed once.
        members = tuple(cluster)
        if members in visited:
            continue
        visited.add(members)

        sc1 = set(members)
        for seqidx in members:
            seqcount += 1
            new_cluster = frozenset(sc1.intersection(cluster2[seqidx]))

            # Compare as sets so that element order cannot cause the
            # same cluster to be reported twice.
            if new_cluster not in emitted:
                emitted.add(new_cluster)
                rv.append(sorted(new_cluster))

        # Once every sequence has been accounted for, the remaining
        # entries of cluster1 can only repeat clusters already seen.
        if seqcount == len(cluster2):
            break
    return rv

# Manual smoke test: cluster one ClustalW alignment and print the
# refinement of the clusterings of two of its windows.  It depends on
# the project's mouse package and on a hard-coded local file.
if __name__ == '__main__':

    from mouse.tools import clustalw
    test_path = '/scratch2/data2/bacteria/sequences/msas/++53GK3C89qjP77Z1XbAVg.aln'
    # NOTE(review): the file handle is never closed explicitly.
    seqs = clustalw.parse_clustalw(open(test_path))
    # Presumably this imports this very module under its package name
    # and reloads it to pick up recent edits -- confirm.
    from mouse.rna.msa.tuples3 import clusters
    reload(clusters)
    # seqs is expected to offer .values() yielding the sequences.
    t = clusters.get_clusters(seqs.values())
    # Windows at indices 70 and 80 are compared.
    # NOTE(review): pp is not defined or imported in this file;
    # presumably a pretty-printer made available elsewhere -- confirm.
    pp(clusters.refine_clusters(t[70], t[80]))
