#!/usr/bin/env python
'''NAME
        %(progname)s

VERSION
        %(version)s

AUTHOR
        Matthieu Defrance <defrance@bigre.ulb.ac.be>

DESCRIPTION
        implants sites in DNA sequences

CATEGORY
        motifs
        sequences

USAGE        
        %(usage)s

ARGUMENTS
    --version             show program's version number and exit
    -h, --help            show this help message and exit
    -v #, --verbosity=#   set verbosity to level #
                              0 no verbosity
                              1 max verbosity
    -i #, --input=#       read sequence from # (must be in FASTA format)
                          if not specified, the standard input is used
    -o #, --output=#      output results to #
                          if not specified, the standard output is used
    -s #, --sites=#       read sites from # (must be in FASTA format)
    -f #, --features=#    store site positions in #
    --noov                do not allow overlapping sites
    
SEE ALSO
        random-motif
        random-sites
'''
# ===========================================================================
# =                            imports
# ===========================================================================
import sys
import os
from math import *
import random
import optparse
from pydoc import pager
from random import choice, randint, shuffle

sys.path.insert(1, os.path.join(sys.path[0], 'lib'))
from lib import dna

# ===========================================================================
# =                            functions
# ===========================================================================
def read_sites(options):
    """Read the sites to implant from the source named by ``options.sites``.

    When ``options.sites`` is a string it is treated as a path: the file is
    opened here and closed again once it has been parsed.  Any other value
    is handed to ``dna.read_fasta_full`` unchanged.

    Returns a list of ``(name, site)`` pairs.  On any failure an error
    message is written to stderr and the program exits with status 2.
    """
    try:
        opened_here = isinstance(options.sites, str)
        # The parser only defines ``sites``; there is no ``motif`` option.
        f = open(options.sites) if opened_here else options.sites
        try:
            (names, sites) = dna.read_fasta_full(f)
        finally:
            # Only close what was opened here, never a caller-owned object.
            if opened_here:
                f.close()
        # list() keeps the pairs indexable on Python 3, where zip is lazy.
        return list(zip(names, sites))
    except Exception:
        sys.stderr.write('Error: Can not read sites\n')
        sys.exit(2)

def find_allowed(sequences, site, forbidden):
    """Return every position at which *site* can be implanted.

    sequences -- list of sequence strings
    site      -- the site to place; only its length is used
    forbidden -- container of ``(sequence index, offset)`` keys that are
                 already occupied

    Returns a list of ``(sequence index, start offset)`` tuples such that
    the site fits entirely inside the sequence and covers no forbidden
    position.
    """
    width = len(site)
    allowed = []
    for s, sequence in enumerate(sequences):
        # "+ 1": the last valid start offset is len(sequence) - width, so a
        # site as long as the sequence can still be placed at offset 0.
        for i in range(len(sequence) - width + 1):
            if not any((s, i + j) in forbidden for j in range(width)):
                allowed.append((s, i))
    return allowed

def add_forbidden(forbidden, site, pos):
    """Mark every position covered by *site* placed at *pos* as occupied.

    forbidden -- dict keyed by ``(sequence index, offset)``; updated in place
    site      -- the implanted site; only its length is used
    pos       -- ``(sequence index, start offset)`` of the implantation

    Returns the same ``forbidden`` dict that was passed in.
    """
    seq_index, start = pos
    for covered in range(start, start + len(site)):
        forbidden[(seq_index, covered)] = True
    return forbidden

def implant_all_sites(names, sequences, sites, noov, separate_ft):
    """Implant each site at a random position of a random sequence.

    names       -- sequence names; the first whitespace-separated word is
                   used as sequence identifier in the feature lines
    sequences   -- list of sequence strings; modified in place
    sites       -- iterable of ``(name, site)`` pairs
    noov        -- when true, implanted sites never overlap each other
    separate_ft -- when not None, site locations are returned as feature
                   lines instead of being appended to the sequence labels

    Returns ``(labels, sequences, features)``:
    labels   -- one label per sequence ("1", "2", ...), followed by
                ``[name:position:site]`` tags when ``separate_ft`` is None
    features -- tab-separated feature lines (empty when ``separate_ft`` is
                None); positions are 1-based and inclusive

    A warning is written to stderr when a site cannot be implanted.  Such
    a site is never written past the end of a sequence.
    """
    # --- choose where to implant -----------------------------------------
    forbidden = {}
    placements = []  # (name, site, (sequence index, offset)), in site order
    if not noov:
        # Overlaps are allowed, so the candidate starts never change:
        # compute them once with a one-character placeholder site.
        allowed = find_allowed(sequences, '*', forbidden)

    for (name, site) in sites:
        if noov:
            allowed = find_allowed(sequences, site, forbidden)
        if len(allowed) == 0:
            sys.stderr.write('Warning: not all sites implanted\n')
            break
        # Keep only the starts where the whole site fits.  Picking from
        # this list replaces random retries, which could give up although
        # a valid position existed and then implanted past the sequence
        # end anyway.
        candidates = [(s, i) for (s, i) in allowed
                      if i + len(site) <= len(sequences[s])]
        if not candidates:
            sys.stderr.write('Warning: not all sites implanted\n')
            continue
        pos = random.choice(candidates)
        placements.append((name, site, pos))
        forbidden = add_forbidden(forbidden, site, pos)

    # --- do the implantation ---------------------------------------------
    labels = [str(i + 1) for i in range(len(sequences))]
    features = []
    for (name, site, (s, p)) in placements:
        if separate_ft is not None:
            seqname = names[s].split()[0]
            features.append("%s\t%s\tsite\t%s\t%d\t%d"
                            % (seqname, name.strip(), 'D', p + 1, p + len(site)))
        else:
            labels[s] += '[%s:%d:%s]' % (name.strip(), p + 1, site)
        # Overwrite the covered stretch; with overlaps allowed, a later
        # site replaces whatever an earlier one wrote there.
        sequence = sequences[s]
        sequences[s] = sequence[:p] + site + sequence[p + len(site):]
    return (labels, sequences, features)

def main(options, args):
    """Run the whole pipeline: read, implant, write.

    options -- parsed command-line options; ``input``, ``output``,
               ``sites``, ``noov`` and ``ft`` are used
    args    -- positional command-line arguments (unused)

    Sequences are read from ``options.input`` and lower-cased, the sites
    are implanted, and the result is written to ``options.output`` in
    FASTA format.  When ``options.ft`` is set, the site positions are also
    written to that file, one feature per line.
    """
    (names, sequences) = dna.read_fasta_full(options.input)
    sequences = [s.lower() for s in sequences]
    sites = read_sites(options)
    (labels, sequences, features) = implant_all_sites(names, sequences, sites, options.noov, options.ft)
    dna.write_fasta(options.output, sequences, labels)
    if options.ft is not None:
        # "with" guarantees the feature file is closed even if a write fails.
        with open(options.ft, 'w') as f:
            for ft in features:
                f.write(ft)
                f.write('\n')

# ===========================================================================
# =                            main
# ===========================================================================
if __name__ == '__main__':
    USAGE = '''%s -i sequences -s sites [-h]'''
    VERSION = '2011.06'
    PROG_NAME = os.path.basename(sys.argv[0])

    # add_help_option=0: -h/--help is handled below so that the module
    # docstring is printed instead of optparse's generated help.
    parser = optparse.OptionParser(usage=USAGE % PROG_NAME, add_help_option=0, version=VERSION)
    parser.add_option("-o", "--output", action="store", dest="output", default=sys.stdout)
    parser.add_option("-v", "--verbosity", action="store", dest="verbosity", default=0)
    parser.add_option("-i", "--input", action="store", dest="input", metavar="#", help="sequences")
    parser.add_option("-s", "--sites", action="store", dest="sites", metavar="#", help="sites")
    parser.add_option("-f", "--features", action="store", dest="ft", metavar="#", default=None)
    parser.add_option("-h", "--help", action="store_true", dest="help")
    parser.add_option("--debug", action="store_true", dest="debug")
    parser.add_option("--noov", action="store_true", dest="noov")

    (options, args) = parser.parse_args()

    if options.help:
        doc = globals()['__doc__'] % {'usage': USAGE % PROG_NAME, 'version': VERSION, 'progname': PROG_NAME}
        # Single-argument call form works as a statement on Python 2 and
        # as a function on Python 3.
        print(doc)
        sys.exit(0)
    # A bare invocation leaves options.input unset, so this check also
    # covers the "no arguments at all" case.
    if not options.input:
        parser.print_usage()
        sys.exit()
    try:
        main(options, args)
    except Exception:
        # "except Exception" lets SystemExit (e.g. the exit status 2 raised
        # by read_sites) and KeyboardInterrupt propagate untouched.
        if options.debug:
            raise
        sys.stderr.write('Error while running\n')
        # Report the failure to the calling shell.
        sys.exit(1)

