#include "gibbs.h"

int VERBOSITY = 0;

/*
 * Logging helpers. Each expands to a single statement (do { ... } while (0))
 * so it composes safely with if/else and a trailing semicolon.
 * VERBOSEn prints to stderr when VERBOSITY >= n; DEBUG prints to stdout and
 * flushes when VERBOSITY >= 3.
 */
#define VERBOSE1(...)         do { if (VERBOSITY >= 1) {fprintf(stderr, __VA_ARGS__);} } while (0)
#define VERBOSE2(...)         do { if (VERBOSITY >= 2) {fprintf(stderr, __VA_ARGS__);} } while (0)
#define VERBOSE3(...)         do { if (VERBOSITY >= 3) {fprintf(stderr, __VA_ARGS__);} } while (0)

/*
 * Per-run trace files named "<i>.ic", only used when VERBOSITY >= 3.
 * `trace_file` must be an lvalue FILE* initialised to NULL. A failed fopen
 * leaves it NULL, and TRACE / CLOSE_TRACE then do nothing.
 */
#define NEW_TRACE(trace_file, i)  do { if (VERBOSITY >= 3) {char name[256]; sprintf(name, "%d.ic", (int) (i)); (trace_file) = fopen(name, "w");} } while (0)
#define CLOSE_TRACE(trace_file)   do { if (VERBOSITY >= 3 && (trace_file) != NULL) {fclose(trace_file); (trace_file) = NULL;} } while (0)
#define TRACE(trace_file, ...)    do { if (VERBOSITY >= 3 && (trace_file) != NULL) {fprintf((trace_file), __VA_ARGS__);} } while (0)
#define DEBUG(...)            do { if (VERBOSITY >= 3) {fprintf(stdout, __VA_ARGS__); fflush(stdout);} } while (0)
/* Print "ERROR: <message>" to stderr and terminate the process with status 1. */
#define ERROR(...) do {fprintf(stderr, "ERROR: "); fprintf(stderr, __VA_ARGS__); fprintf(stderr, "\n"); exit(1);} while (0)

#define MAX_DOUBLE            1.7e308
#define SITES vector<Site>
/*
 * Uniform double in [0, 1). Dividing by RAND_MAX + 1 keeps the result
 * strictly below 1, so (int) (RAND * size) is always a valid index.
 */
#define RAND (rand() / ((double) RAND_MAX + 1.0))

/***************************************************************************
 *                                                                         *
 *  LLR (log likelihood ratio) & Iseq (information content)
 *                                                                         *
 ***************************************************************************/
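/*
 *   llr = Sum_k [ log P(w^k | M) - log P(w^k | B) ]
 *   llr = Sum_i Sum_k log P(w_i^k | M) - Sum_k log P(w^k | B)
 *         ----------------------------   --------------------
 *                   alpha                        beta
 *
 *   where w^k is the k-th site, w_i^k its letter at column i, M the motif
 *   profile and B the background model. With C_ij the count of letter j in
 *   column i over n sites, and P(j | M, column i) ~= C_ij / n:
 *
 *   llr ~= Sum_i Sum_j C_ij log(C_ij / n) - Sum_k log P(w^k | B)
 *
 *   (divide by n for the per-site value)
 */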
/*
 * Scoring table for one Gibbs resampling step (built by sample_update()).
 *
 * The constructor counts the letters of every motif site except motif[r]
 * and precomputes T[i][j]; llr(site) then scores a candidate site in O(l).
 * The score is the pseudo-count-smoothed log-likelihood of the kept sites
 * plus the candidate under the profile updated with the candidate (alpha),
 * minus the background log-probabilities of the kept sites (beta) and of
 * the candidate.
 */
struct LLR{
	int l;               // motif length (number of columns scored)
	int n;               // motif.size(); includes the excluded site r
	double *p;           // letter priors; points into the member copy `markov`
	double **T;	// T[i][j]: alpha contribution of letter j at column i (rows of Tab)
	double beta;         // sum of logp over the kept (k != r) motif sites
	double alpha;        // scratch accumulator
	double P;            // not used in the code visible in this file
	int i,j,k;           // loop counters shared by the constructor and llr()
	Array Tab;           // storage behind T
	Markov markov;       // copy of the background model passed in
	Sequence *sequences; // avoid sequences copy
	double **logp;       // logp[s][p]: per-position log-probabilities (sample_update passes SamplingData::logp)

	/*
	 * motif        -- current motif; its sites index into seqs
	 * r            -- index of the site being resampled; it is left out of
	 *                 the counts and of beta
	 * matrix       -- overwritten: rows 0..length-1 receive the raw letter
	 *                 counts of the kept sites (needs at least `length` rows)
	 * length       -- motif length l
	 * markov_model -- background model; copied, its priori[] weights the
	 *                 pseudo-counts
	 * seqs         -- sequences; only seqs.data is kept (no copy)
	 * logprob      -- per-position background log-probabilities; the pointer
	 *                 is stored, not copied
	 * pseudo       -- total pseudo-count weight
	 */
	LLR(SITES &motif, int r, Array &matrix, int length, Markov &markov_model, Sequences &seqs, double **logprob, double pseudo=PSEUDO){
		l = length;
		n = motif.size();
		markov = markov_model;
		p = markov.priori;
		sequences = seqs.data; // sequences store only data !
		logp = logprob;
		
		// compute the matrix
		// matrix[i][c] = number of kept sites (k != r) with letter c at offset i
		for(i=0; i<l; i++){
			for(j=0; j<matrix.J; j++)
				matrix[i][j] = 0;
			for(k=0; k<n; k++){
				if (k == r)
					continue;
				j = sequences[motif[k].s][motif[k].p+i];
				matrix[i][j] += 1;
			}
		}

		// compute alpha
		// Tab[i][j] = sum over the kept letters at column i, plus a candidate
		// letter j, of log((count + p*pseudo) / (n + pseudo)), where count
		// includes the candidate (hence the "+ 1" when a kept letter equals j).
		// The divisor n + pseudo assumes exactly one site (r) was excluded, so
		// kept sites + candidate = n words.
		// Cost: O(l * matrix.J * n).
		// NOTE(review): whether Array's destructor frees the storage from
		// Tab.alloc is not visible from this file.
		int J;
		Tab.alloc(l, matrix.J);

		for(i=0; i<l; i++)
		{
			for(j=0; j<matrix.J; j++)
			{
				alpha = 0.0;
				for(k=0; k<n; k++)
				{
					if (k == r)
						continue;
					J = sequences[motif[k].s][motif[k].p+i];
					// the candidate letter j adds one to column i's count of J
					if (j == J)
						alpha += log ( (matrix[i][J] + 1 + p[J] * pseudo) / (n + pseudo) );
					else
						alpha += log ( (matrix[i][J] + p[J] * pseudo) / (n + pseudo) );					
				}
				// the new letter contribution
				alpha += log ( (matrix[i][j] + 1 + p[j] * pseudo) / (n + pseudo) );
				Tab.data[i][j] = alpha;
			}			
		}
		T = Tab.get_data();

        // compute beta
		// beta = background log-probability of the kept (k != r) words
		beta = 0.0;
		for(k=0; k<n; k++){
			if (k == r)
				continue;
			beta += logp[motif[k].s][motif[k].p];
		}
	}

	/*
	 * Score a candidate site: sum of T over its letters (alpha) minus beta
	 * minus the candidate's own background log-probability logp[s][p].
	 * Positions site.p .. site.p+l-1 are not bounds-checked here.
	 * Overwrites the members alpha, i and j.
	 */
	double llr(Site &site){
		alpha = 0.0;
		for(i=0; i<l; i++){
			j = sequences[site.s][site.p+i];
			alpha += T[i][j];
		}
		return alpha - beta - logp[site.s][site.p];
	}
};


/*
 * Fill `matrix` with raw letter counts:
 *   matrix[row][b] = number of motif sites whose letter at offset `row` is b.
 * Every row 0..matrix.I-1 is reset to zero before counting.
 */
void count_matrix(Array &matrix, SITES &motif, Sequences &sequences){
	const int rows = matrix.I;
	const int cols = matrix.J;
	for(int row = 0; row < rows; row++){
		for(int b = 0; b < cols; b++)
			matrix[row][b] = 0;

		for(SITES::iterator site = motif.begin(); site != motif.end(); ++site){
			int letter = (int) sequences[site->s][site->p + row];
			matrix[row][letter] += 1;
		}
	}
}

/**
 *  Convert a list of words to a frequency matrix smoothed towards the prior
 *  probabilities with pseudo-counts [Hertz 1999]:
 *
 *      m_{i,b} = ( c_{i,b} + p_b * pseudo ) / ( N + pseudo )
 *
 *  where c_{i,b} is the number of sites with letter b at column i, N the
 *  number of sites and p_b = markov.priori[b]. With pseudo = 1 this is the
 *  ( f_{b,i} + p_b ) / ( N + 1 ) form given by Hertz.
 *
 *  matrix     -- output; rows 0..matrix.I-1 are overwritten (one per column)
 *  motif      -- sites to count
 *  sequences  -- letter-coded sequences the sites index into
 *  markov     -- background model providing priori[]
 *  pseudo     -- total pseudo-count weight
 */
inline void freq_matrix(Array &matrix, SITES &motif, Sequences &sequences, Markov &markov, double pseudo=PSEUDO){
	// Raw letter counts first, then smooth every cell towards the prior.
	count_matrix(matrix, motif, sequences);
	int n = (int) motif.size();
	for(int i=0; i<matrix.I; i++){
		for(int j=0; j<matrix.J; j++)
			matrix[i][j] = (matrix[i][j] + markov.priori[j] * pseudo) / (n+pseudo);
	}
}


/*
 * Relative entropy (information content) of a motif [Hertz 1999]:
 *
 *     Iseq = sum_{i=1}^{w} sum_b f_{b,i} ln( f_{b,i} / p_b )
 *
 * f is the smoothed frequency matrix built by freq_matrix() and written
 * into `matrix`. p_b is markov.priori[b]. Zero frequencies contribute
 * nothing (0 * ln 0 is taken as 0).
 */
double Iseq(SITES &motif, Sequences &sequences, Array &matrix, Markov &markov, double pseudo=PSEUDO){
	freq_matrix(matrix, motif, sequences, markov, pseudo);

	double total = 0.0;
	const int rows = matrix.I;
	const int cols = matrix.J;
	for(int col = 0; col < rows; ++col){
		for(int b = 0; b < cols; ++b){
			double f = matrix[col][b];
			if (f == 0.0)
				continue;
			total += f * log(f / markov.priori[b]);
		}
	}
	return total;
}


/*
 * Average per-site log-likelihood ratio of `motif` against the priors:
 *
 *     (1/n) * sum_k sum_i log( m_{i, w_i^k} / p_{w_i^k} )
 *
 * m is the smoothed frequency matrix from freq_matrix(), which is left in
 * `matrix`. p = markov.priori, and n is the number of sites.
 * Returns 0.0 for an empty motif, instead of the 0/0 NaN the final division
 * would otherwise produce.
 */
double llr(SITES &motif, Sequences &sequences, Array &matrix, Markov &markov, double pseudo=PSEUDO){
	// matrix (filled even for an empty motif, so callers see consistent state)
	freq_matrix(matrix, motif, sequences, markov, pseudo);

	int n = (int) motif.size();
	if (n == 0)
		return 0.0;

	double s = 0.0;
	for(int k=0; k<n; k++){
		for(int i=0; i<matrix.I; i++){
			int j = (int) sequences[motif[k].s][motif[k].p+i];
			s += log(matrix[i][j] / markov.priori[j]);
		}
	}
	return s / n;
}


/***************************************************************************
 *                                                                         *
 *  SITES = all available motif positions
 *                                                                         *
 ***************************************************************************/
/* Print every site as "<sequence index> <position>", one per line, on stdout. */
void print_sites(SITES &sites){
	for(SITES::iterator it = sites.begin(); it != sites.end(); ++it)
		printf("%d %d\n", it->s, it->p);
}

/*
 * Return a copy of `sites` without every entry that also occurs in `motif`
 * (same sequence index and same position). Surviving sites keep their order.
 */
SITES mask_motif(SITES &sites, SITES &motif){
	SITES kept;
	const int n_sites = (int) sites.size();
	const int n_motif = (int) motif.size();
	for (int a = 0; a < n_sites; a++){
		bool in_motif = false;
		for (int b = 0; b < n_motif && !in_motif; b++)
			in_motif = (sites[a].s == motif[b].s && sites[a].p == motif[b].p);
		if (!in_motif)
			kept.push_back(sites[a]);
	}
	return kept;
}

/*
 * Enumerate every window of length l in every sequence that contains no
 * masked letter (-1). Sites are produced in increasing (sequence, position)
 * order.
 */
SITES all_sites(Sequences &sequences, int l){
	SITES result;
	int nseq = sequences.size();
	for (int s = 0; s < nseq; s++){
		int last = (int) sequences[s].size() - l;
		for (int p = 0; p <= last; p++){
			// advance x over unmasked letters; the window is valid if x reaches l
			int x = 0;
			while (x < l && sequences[s][p+x] != -1)
				x++;
			if (x >= l)
				result.push_back(Site(s,p));
		}
	}
	return result;
}


#define INVALID -1
/*
 * Return `allsites` without every site whose position lies strictly within
 * dmin of some motif site in the same sequence (|p - motif.p| < dmin). The
 * motif sites themselves are therefore removed too. motif[r] is ignored, so
 * its neighbourhood stays available. With dmin == 0 the input is returned
 * unchanged.
 *
 * The early `break`s assume allsites is sorted by sequence index, then
 * position, which is the order all_sites() produces. Removed entries are
 * marked by setting s = INVALID (-1). Later motif sites then take the
 * `continue` branch for them, and they are filtered out at the end.
 * Cost: each motif site rescans from the start, O(|motif| * |allsites|) in
 * the worst case.
 */
SITES remove_neighbours(SITES &allsites, SITES &motif, int dmin=0, int r=-1){
 	if (dmin == 0)
 		return allsites;
	SITES sites = allsites;
	for(int k=0; k< (int) motif.size(); k++){
		if (k == r)
			continue;
		for(int i=0; i< (int) sites.size(); i++){
			 // past motif[k]'s sequence: nothing further can be in range
			 if (sites[i].s > motif[k].s)
			 	break;
			 else if (sites[i].s < motif[k].s)
			 	continue;
			 // past motif[k]'s window within the same sequence
			 if (sites[i].p >= motif[k].p + dmin)
			 	break;
			if ( (sites[i].p > motif[k].p - dmin) && (sites[i].p < motif[k].p + dmin) )
				sites[i].s = INVALID; //invalidate site
		}
	}
	SITES new_sites;
	for(int i=0; i< (int) sites.size(); i++){
		if (sites[i].s != INVALID)
			new_sites.push_back(sites[i]);
	}
	return new_sites;
}

/***************************************************************************
 *                                                                         *
 *  MOTIF
 *                                                                         *
 ***************************************************************************/
/*
 * Print a motif in the info-gibbs text format on stdout:
 *   - a header with total and per-column information content and site count,
 *   - one line per site: sequence number (1-based), strand, position, word,
 *   - the raw count matrix (one row per letter of ALPHABET), then "//".
 *
 * When rc is true, raw_sequences is assumed to hold the forward sequences
 * followed by their reverse complements (so nseq is even). Sequence indices
 * s >= nseq/2 are reported as '-' strand of sequence s - nseq/2 + 1.
 * NOTE(review): for '-' sites, `pos` is the position in the stored
 * reverse-complement sequence, not a forward-strand coordinate.
 */
void print_motif(SITES &motif, vector<string> &raw_sequences, Sequences &sequences, int l, double ic, bool rc=false){
	int n = motif.size();

	printf("; total.information             %.3f\n", ic);
	printf("; information.per.column       	%.3f\n", ic / l);
	printf("; sites                        	%d\n", (int) motif.size());
	printf("; seq\tstrand\tpos\tword\n");

	//sites
	int nseq = raw_sequences.size();
	int half = nseq / 2;

	for(int i=0; i<n; i++){
		int s = motif[i].s;
		int p = motif[i].p;
		int seq;
		char strand_label;
		// Index `half` is the first reverse-complement sequence, hence >=.
		if (rc && s >= half){
			strand_label = '-';
			seq = s - half + 1;
		}else{
			strand_label = '+';
			seq = s + 1;
		}
		// Keep the substring in a named object: taking c_str() of the
		// temporary returned by substr() leaves a dangling pointer.
		string word = raw_sequences[s].substr(p,l);
		printf("; %d\t%c\t%d\t%s\n", seq, strand_label, p, word.c_str());
	}

	//matrix
	Array matrix = Array(l, ALPHABET_SIZE);
	count_matrix(matrix, motif, sequences);

	for(int j=0; j<ALPHABET_SIZE; j++){
		printf("%c | ", ALPHABET[j]);
		for(int i=0; i<l; i++){
			printf("\t%d", (int) matrix[i][j]);
		}
		printf("\n");
	}
	printf("//\n");
}


/* True when `sites` holds an entry with the same sequence index and position as `site`. */
bool inline is_in_sites(Site &site, SITES &sites){
	for(SITES::iterator it = sites.begin(); it != sites.end(); ++it){
		if (it->s == site.s && it->p == site.p)
			return true;
	}
	return false;
}


/*
 * Draw up to n distinct random sites from `allsites`.
 *
 * After each accepted pick, sites within dmin positions of any chosen site
 * (same sequence) are removed from the candidate pool via
 * remove_neighbours(). At most 2*n draws are attempted, so fewer than n
 * sites may be returned. For example, with dmin == 0 a repeated draw is
 * simply discarded, and the loop also stops early if the pool runs out.
 */
SITES random_motif(SITES &allsites, int n, int dmin=0){
	SITES motif;
	int attempts = 0;
	SITES sites = allsites;
	VERBOSE2("generating random motif\n");
	while ((int)motif.size() < n && attempts++ < n*2){
		// The neighbour rule can exhaust the pool; indexing an empty vector
		// would be undefined behaviour.
		if (sites.empty())
			break;
		int i = (int) (RAND * sites.size());
		// Guard against RAND returning exactly 1.0 (possible when it is
		// defined as rand() / RAND_MAX).
		if (i >= (int) sites.size())
			i = (int) sites.size() - 1;
		if (! is_in_sites(sites[i], motif)){
			motif.push_back(sites[i]);
			sites = remove_neighbours(sites, motif, dmin);
		}
	}
	return motif;
}

/***************************************************************************
 *                                                                         *
 *  SPEEDUP STRUCTURES
 *                                                                         *
 ***************************************************************************/
/*
 * Constant-time membership test for candidate sites.
 *
 * cache[s][p] is true iff (s, p) was in `sites` at construction time.
 * shifted() uses it to check that every moved motif site is still allowed.
 * NOTE(review): the cache arrays are allocated with new[] and never freed.
 * No destructor is added because the struct has no copy control and is
 * copy-initialised from a temporary in find_one_motif().
 */
struct Is_a_site {
	bool **cache;    // cache[s] has one flag per position of sequence s
	Sequences *seqs; // non-owning; the sequences must outlive this object

	Is_a_site(Sequences &sequences, SITES &sites){
		int S = sequences.size();
		seqs   = &sequences;
		cache = new bool*[S];
		// alloc & set to false
		for(int s=0; s<S; s++){
			int len = sequences[s].size();
			cache[s] = new bool[len];
			for(int p=0; p<len; p++)
				cache[s][p] = false;
		}
		// init with sites
		for(int k=0; k< (int) sites.size(); k++)
			cache[sites[k].s][sites[k].p] = true;
	}

	/* True iff `site` lies inside the sequences and is flagged in the cache. */
	bool is_valid(Site &site){
		if (site.s < 0 || site.s >= (int) seqs->size())
			return false;
		// Bound the position by the length of sequence s so that
		// cache[s][p] stays inside the row allocated for it.
		if (site.p < 0 || site.p >= (int) seqs->data[site.s].size())
			return false;
		return cache[site.s][site.p];
	}
};

 
/*
 * Buffers reused by every sample_update() call of one motif search.
 *
 * cdf and sampled_sites have room for T = sites.size() entries.
 * logp[s][p] holds markov.logP of the length-l word at (s, p) for every site
 * in `sites`. All other positions are zero-initialised placeholders.
 * NOTE(review): the arrays are allocated with new[] and never freed.
 */
struct SamplingData {
	int l; // word length the logp cache was computed for
	int T; // max number of sites
	int S; // number of sequences
	double **logp;       // log P(word) cache for each p
	double *cdf;         // cumulative dist function
	Site *sampled_sites; // sites (corresponds to cdf)

	SamplingData(Sequences &sequences, SITES &sites, Markov &markov, int l)
	{
		// The parameter shadows the member; record it explicitly so the
		// member is never left uninitialised.
		this->l = l;
		S = (int) sequences.size();
		T = (int) sites.size();
		cdf = new double[T];
		sampled_sites = new Site[T];
		logp = new double*[S];

		// alloc logP (value-initialised to 0.0)
		for(int s=0; s<S; s++){
			int len = sequences[s].size();
			logp[s] = new double[len]();
		}

		for(int i=0; i<T; i++){
			int s = sites[i].s;
			int p = sites[i].p;
			logp[s][p] = markov.logP(&sequences[s][p], l);
		}
	}
};

/***************************************************************************
 *                                                                         *
 *  SHIFTING
 *                                                                         *
 ***************************************************************************/
/*
 * Return a copy of `motif` with every site moved by `delta` positions.
 * The unshifted `motif` is returned instead when delta is 0 or when any
 * moved site is rejected by `cache`. `sites` is not used.
 */
SITES shifted(SITES &motif, SITES &sites, int delta, Is_a_site &cache){
	if (delta == 0)
		return motif;

	SITES moved;
	for(SITES::iterator it = motif.begin(); it != motif.end(); ++it){
		Site candidate(it->s, it->p + delta);
		if (! cache.is_valid(candidate))
			return motif;
		moved.push_back(candidate);
	}
	return moved;
}


/*
 * Try shifting the whole motif by -1, 0 and +1 positions and return the
 * variant with the highest Iseq; ties go to the smaller shift. A shift that
 * would leave the allowed sites falls back to the unshifted motif (see
 * shifted()). `matrix` is overwritten as scratch space.
 *
 * The search starts from the unshifted motif with score -inf, so the result
 * is never empty, even when every candidate's Iseq is <= 0.
 */
SITES shift(SITES &motif, Sequences &sequences, SITES &sites, Array &matrix, Markov &markov, Is_a_site &cache){
	SITES best_motif = motif;
	double best_ic = -HUGE_VAL;

	for(int delta=-1; delta<=1; delta++){
		SITES current_motif = shifted(motif, sites, delta, cache);
		double current_ic = Iseq(current_motif, sequences, matrix, markov);
		if (current_ic > best_ic){
			best_motif = current_motif;
			best_ic = current_ic;
		}
	}
	return best_motif;
}

/***************************************************************************
 *                                                                         *
 *  SAMPLING
 *                                                                         *
 ***************************************************************************/
/*
 * Number of draws made per sample_update() call. -1 means "not set yet";
 * the first call with a non-empty motif sets it to the motif size.
 */
int UPDATE = -1;

/*
 * One stochastic update of `motif`, in place.
 *
 * 1. A word index r is drawn at random.
 * 2. An LLR table is built from the other n-1 sites.
 * 3. Every candidate in `allsites` is weighted by exp(llr / temperature).
 *    When dmin > 0, candidates near the kept sites are removed first.
 * 4. UPDATE draws are taken from that distribution. Each draw replaces
 *    motif[r], unless the drawn site is already in the motif, and then
 *    picks a new random r. The table is not rebuilt between draws.
 *
 * allsites    -- allowed sites; presumably sorted by (s, p), as the early
 *                exits in remove_neighbours() require
 * matrix      -- scratch space, overwritten by the LLR table
 * data        -- preallocated cdf / sampled_sites buffers (room for at least
 *                allsites.size() entries) and the logp cache
 * The motif is left unchanged if no candidate has positive weight, or if
 * the total weight overflows.
 */
void sample_update(SITES &allsites, Sequences &sequences, SITES &motif, Array &matrix, Markov &markov,\
                    int l, double temperature, SamplingData &data, int dmin=0){
	int n = motif.size();
	if (n <= 0)
		return; // nothing to resample

	// set update
	if (UPDATE == -1)
		UPDATE = n;

	// Choose the word to resample BEFORE building the table, so that the
	// table scores candidates against the other n-1 sites.
	int r = (int) (RAND * n);
	if (r >= n)
		r = n - 1; // RAND may return 1.0

	double beta = 1.0 / temperature;
	double *cdf = data.cdf;
	Site *sampled_sites = data.sampled_sites;
	LLR llr_table = LLR(motif, r, matrix, l, markov, sequences, data.logp, PSEUDO);

	SITES sites = remove_neighbours(allsites, motif, dmin, r);

	// cumulative (unnormalised) weights over the candidates
	int T = sites.size();
	double S = 0.0;
	for(int i=0; i<T; i++){
		double s = exp(beta * llr_table.llr(sites[i]));
		if (s >= 1e300){
			s = 1e300;
		}
		S += s;
		sampled_sites[i] = sites[i];
		cdf[i] = S;
	}
	if (S >= MAX_DOUBLE){
		DEBUG("WARNING math overflow error\n");
		return;
	}
	// With no candidate, or all weights underflowing to 0, the search below
	// would read past the filled part of sampled_sites.
	if (T == 0 || !(S > 0.0))
		return;

	// choose new sites
	for(int j=0; j<UPDATE; j++){
		// first index whose cumulative weight exceeds val; clamp to the last
		// candidate in case rounding leaves val >= cdf[T-1]
		double val = RAND * S;
		int i = 0;
		while (i < T - 1 && !(cdf[i] > val))
			i++;

		// update motif
		Site new_site = sampled_sites[i];
		if (! is_in_sites(new_site, motif)){
			motif[r] = new_site;
		}
		r = (int) (RAND * n);
		if (r >= n)
			r = n - 1;
	}
}


/***************************************************************************
 *                                                                         *
 *  FIND ONE MOTIF
 *                                                                         *
 ***************************************************************************/
/* Outcome of one find_one_motif() search. */
struct Result {
    SITES motif;   // best motif found; empty if no iteration scored Iseq above 0
    double ic;     // its information content (Iseq); 0.0 if motif is empty
    int l;         // motif length used for the search (params.l)
};

/*
 * Run params.nrun independent Gibbs runs of params.iter iterations each and
 * return the motif with the highest information content seen in any run.
 *
 * raw_sequences -- not used here (kept for interface compatibility)
 * sequences     -- letter-coded sequences
 * sites         -- allowed motif positions; must contain at least params.n
 * markov        -- background model
 * params        -- l, n, iter, temperature, nrun and dmin are read
 *
 * Exits via ERROR when fewer than params.n sites are available. If no
 * iteration reaches an information content above 0, the returned motif is
 * empty and ic is 0.
 */
Result find_one_motif(vector<string> &raw_sequences, Sequences & sequences, SITES &sites, Markov &markov, Parameters &params){
	int l = params.l;
	int n = params.n;
	int max_iter = params.iter;
	double temperature = params.temperature;
	int n_run = params.nrun;
	int dmin = params.dmin;

	// Fail before building the sampling caches.
	if ((int) sites.size() < n){
		ERROR("too few allowed sites");
	}

	double current_llr    = 0.0;
	double current_ic     = 0.0;
	double best_ic        = 0.0;
	double best_ic_in_run = 0.0;

	Array matrix = Array(l, markov.alphabet_size);
	SITES motif;
	SITES best_motif;

	// init efficiency only dedicated data
	SamplingData data = SamplingData(sequences, sites, markov, l);

	Is_a_site sites_cache = Is_a_site(sequences, sites);
	FILE *trace = NULL;

	for(int run=0; run<n_run; run++){
		VERBOSE1("run %d\n", run+1);
		NEW_TRACE(trace, run+1);
		motif = random_motif(sites, n, dmin);
		current_ic = 0.0;
		best_ic_in_run = 0.0;
		int iter=-1;

		while (++iter < max_iter){
			sample_update(sites, sequences, motif, matrix, markov, l, temperature, data, dmin);
			motif = shift(motif, sequences, sites, matrix, markov, sites_cache);
			// llr() is only reported at VERBOSITY >= 3, so skip its cost
			// otherwise. Iseq() rebuilds `matrix` from scratch, so skipping
			// llr() does not change anything the loop depends on.
			if (VERBOSITY >= 3)
				current_llr = llr(motif, sequences, matrix, markov);
			current_ic = Iseq(motif, sequences, matrix, markov);
			best_ic_in_run = max(current_ic, best_ic_in_run);

			if (current_ic > best_ic){
				best_ic = current_ic;
				best_motif = motif;
			}
			VERBOSE3("[%d] llr=%.2f  best=%.2f iseq=%.2f\n", iter, current_llr, best_ic_in_run, current_ic);
			TRACE(trace, "%i\t%.3f\t%.3f\n", iter, current_ic, best_ic_in_run);
		}
		CLOSE_TRACE(trace);
	}
	Result result;
	result.ic = best_ic;
	result.motif = best_motif;
	result.l = l;
	return result;
}


/***************************************************************************
 *                                                                         *
 *  MAIN GIBBS
 *                                                                         *
 ***************************************************************************/
int SEED = time(NULL);

/*
 * Top-level driver. Seeds the RNG with SEED and searches params.motifs
 * motifs one after another. After each search, the found sites are masked
 * out of the allowed sites. Then a header and every motif are printed on
 * stdout. The prior line prints markov.priori[0..3], so a four-letter
 * alphabet is expected.
 */
void gibbs(vector<string> &raw_sequences, Sequences & sequences, Markov &markov, Parameters &params){
	// init random number generator
	srand(SEED);

	vector<Result> found;
	SITES available = all_sites(sequences, params.l);
	for(int m=0; m<params.motifs; m++){
		Result one = find_one_motif(raw_sequences, sequences, available, markov, params);
		found.push_back(one);
		// the next search must not reuse this motif's sites
		available = mask_motif(available, one.motif);
	}

	// header
	printf("; info-gibbs %d\n", VERSION);
	printf("; %s\n", COMMAND_LINE);
	printf("; random seed                   %d\n", SEED);
	printf("; number of runs                %d\n", params.nrun);
	printf("; sequences (including rc)      %d\n", (int) raw_sequences.size());
	printf("; expected motif occurrences    %d\n", params.n);
	printf("; prior                         a:%.3f|c:%.3f|g:%.3f|t:%.3f\n", markov.priori[0], markov.priori[1], markov.priori[2], markov.priori[3]);
	printf("; number of motifs              %d\n", params.motifs);
	printf(";\n");

	// motifs
	const int total = (int) found.size();
	for(int m=0; m<total; m++){
		Result &res = found[m];
		printf("; motif                         %d\n", m);
		print_motif(res.motif, raw_sequences, sequences, res.l, res.ic, params.rc);
	}
}

