#!/usr/bin/env perl
############################################################
#
# $Id: contingency-stats,v 1.14 2011/02/17 04:54:49 rsat Exp $
#
# Time-stamp: <2003-07-04 12:48:55 jvanheld>
#
############################################################

## use strict;

=pod

=head1 NAME

contingency-stats

=head1 DESCRIPTION

Takes as input a contingency table, and calculates various matching
statistics between the rows and columns (Sn, PPV, Acc, ...). The
description of these statistics can be found in Brohee and van Helden,
2006.

=head1 AUTHORS

Jacques van Helden <jvhelden@ulb.ac.be>

=head1 CATEGORY

util

=head1 USAGE

contingency-stats [-i inputfile] [-o outputfile] [-v]

=head1 INPUT FORMAT

A contingency table is a N*M table used to compare the contents of two
classifications. Rows represent the clusters of the first
classification (considered as reference), and columns the clusters of
the second classification (query).

Contingency tables can be generated with the program
contingency-table, or with compare-classes (option -matrix QR).

=head1 OUTPUT FORMAT

A tab-delimited text file with one row per statistic.

=head1 STATISTICS

=over 4

=item B<Sn>

Sensitivity. This parameter indicates the fraction of each reference
cluster (row) covered by its best matching query cluster (column).

Sensitivity is calculated at the level of each cell (cell-wise Sn), of
each row (row-wise Sn) and of the whole contingency table
(table-wise Sn).

=over

=item I<Cell-wise sensitivity>

Sn_{i,j} = X_{i,j}/SUM_j(X_{i,j})

=item I<row-wise sensitivity> 

Sn_{i.} = MAX_j(Sn_{i,j})

The row-wise sensitivity of a row is the maximal value of
sensitivity for all the cells of this row.

=item I<table-wise sensitivity> 

Sn = SUM_i(Sn_{i.})/N

The table-wise sensitivity is the average of the row-wise
sensitivity over all the N rows of the contingency table.

=back

=item B<PPV>

Positive Predictive Value. This parameter indicates the fraction of
each query cluster (column) covered by its best matching reference
cluster (row).

PPV is calculated at the level of each cell (cell-wise PPV), of each
column (column-wise PPV) and of the whole contingency table
(table-wise PPV).

=over

=item I<Cell-wise PPV>

PPV_{i,j} = X_{i,j}/SUM_i(X_{i,j})

=item I<column-wise PPV> 

PPV_{.j} = MAX_i(PPV_{i,j})

The column-wise PPV of a column is the maximal value of PPV for all
the cells of this column.

=item I<table-wise PPV> 

PPV = SUM_j(PPV_{.j})/M

The table-wise PPV is the average of the column-wise PPV over
all the M columns of the contingency table.

=back

=item B<Acc.geom>

Geometric accuracy. This reflects the tradeoff between sensitivity and
positive predictive value, by computing the geometric mean of Sn and
PPV.

Acc.geom = sqrt(Sn*PPV)

=item B<Sep>

Separation.

The separation is defined, at the level of each cell (cell-wise
separation) as the product between Sn and PPV.

=over

=item I<Cell-wise separation>

sep_{i,j}=Sn_{i,j}*PPV_{i,j}

=item I<Column-wise separation>

Column-wise separation is defined at the level of each column, as the
sum of separation value for all the cells of this column.

sep_{.j} = SUM_i( sep_{i,j})

=item I<Row-wise separation>

Row-wise separation is defined at the level of each row, as the
sum of separation value for all the cells of this row.

sep_{i.} = SUM_j(sep_{i,j})

=item I<Table-wise separation>

Three table-wise statistics are computed for separation.

=over

=item average column-wise separation

sep_c = AVG_j(sep_{.j})

=item average row-wise separation

sep_r = AVG_i(sep_{i.})

=item table-wise separation

sep = sqrt(sep_r*sep_c)

=back

=back

=back

=head1 SEE ALSO

=over

=item contingency-table

=item compare-classes

=back

=head1 REFERENCES

=over

=item Brohee, S. & van Helden, J. (2006). Evaluation of clustering
algorithms for protein-protein interaction networks. BMC
Bioinformatics 7, 488.

=back

=cut


BEGIN {
    ## Add the "lib/" directory located next to this script to the
    ## module search path. The pattern matches the last path component
    ## of $0 (the script name); everything before the match is the
    ## directory prefix, which is empty when the script is called
    ## without any path.
    if ($0 =~ /([^(\/)]+)$/) {
        my $script_dir = substr($0, 0, $-[0]);
        push(@INC, $script_dir."lib/");
    }
}
require "RSA.lib";
use RSAT::stats;
use RSAT::table;


################################################################
## Main package
package main;
{

    ################################################################
    ## Main program: read a contingency table, derive the cell-wise
    ## PPV, Sn and separation tables from it, then print the requested
    ## table-wise, row-wise and column-wise statistics.
    ##
    ## Variables declared with "local" (or assigned without
    ## declaration) are package variables: the subroutines defined
    ## below (ReadArguments, read_row_sizes, read_col_sizes) access
    ## them by name.

    ################################################################
    #### initialise parameters
    local $start_time = &RSAT::util::StartScript();
    my $header=1;
    local $decimals = 3;
    local @param_keys = ();
    local @param_values = ();

    %main::infile = ();
    %main::outfile = ();

    $main::verbose = 0;
#    $main::in = STDIN;
    $main::out = STDOUT;

    ## Output fields
    %supported_return_fields = (
                                stats=>1,
                                rowstats=>1,
                                colstats=>1,
                                tables=>1,
                                margins=>1,
                                all=>1,
                               );
    $supported_return_fields = join (",", sort(keys( %supported_return_fields)));

    local %return_fields = ();

    &ReadArguments();

    ################################################################
    #### check argument values

    ## Return fields: table-wise statistics are returned by default
    unless (scalar(keys(%return_fields)) > 0) {
      $return_fields{stats} = 1;
    }
    ## "all" is a shortcut that activates every supported field
    if ($return_fields{all}) {
      foreach my $field (keys %supported_return_fields) {
        $return_fields{$field} = 1;
      }
    }


    ################################################################
    ### open output stream
    $main::out = &OpenOutputFile($main::outfile{output});

    ################################################################
    ### Read input table
    local $counts = RSAT::table->new();
    $counts->readFromFile($infile{input}, "tab",header=>$header);
    my $ncol = $counts->ncol();
    my $nrow = $counts->nrow();
    $counts->set_attribute("type", "counts");
    $counts->force_attribute("margins", $return_fields{margins});

    ## Read group sizes if specified
    if ($infile{rsizes}) {
      &read_row_sizes();
    }
    if ($infile{csizes}) {
      &read_col_sizes();
    }

    ## PPV: column-wise frequencies of the counts; the column-wise PPV
    ## is the maximum of each column.
    my $PPV = RSAT::table->new();
    $PPV->setAlphabet($counts->getAlphabet);
    $PPV->push_attribute("header", $counts->get_attribute("header"));
    $PPV->setTable($nrow, $ncol, $counts->col_freq());
    $PPV->set_attribute("type", "freq");
    $PPV->force_attribute("margins", $return_fields{margins});
    my @col_wise_PPV = $PPV->col_max();

    ## Sn: row-wise frequencies of the counts; the row-wise Sn is the
    ## maximum of each row.
    my $Sn = RSAT::table->new();
    $Sn->setAlphabet($counts->getAlphabet);
    $Sn->push_attribute("header", $counts->get_attribute("header"));
    $Sn->setTable($nrow, $ncol, $counts->row_freq());
    $Sn->set_attribute("type", "freq");
    $Sn->force_attribute("margins", $return_fields{margins});
    my @row_wise_Sn = $Sn->row_max();

    ## Separation: cell-wise product of Sn and PPV. The tables are
    ## indexed as [column][row].
    my $sep = RSAT::table->new();
    $sep->setAlphabet($counts->getAlphabet);
    $sep->push_attribute("header", $counts->get_attribute("header"));
    my @Sn_table = $Sn->getTable();
    my @PPV_table = $PPV->getTable();
    my @sep_table = ();
    foreach my $r (0..($nrow-1)) {
      for my $c (0..($ncol-1)) {
        $sep_table[$c][$r]=$Sn_table[$c][$r]*$PPV_table[$c][$r];
      }
    }
    $sep->setTable($nrow, $ncol, @sep_table);
    $sep->set_attribute("type", "freq");
    $sep->force_attribute("margins", $return_fields{margins});

    ## Calculate basic table-wise statistics.
    ## The statistics are stored in a lexical hash reference (the
    ## previous code declared an unused hash %table_wise_stats and
    ## stored the values in an undeclared global of the same name).
    my $table_wise_stats = {};
    $table_wise_stats->{ncol} = $counts->ncol();
    $table_wise_stats->{nrow} = $counts->nrow();
    $table_wise_stats->{sum} = &RSAT::stats::sum($counts->row_sum());
    $table_wise_stats->{min} = &RSAT::stats::min($counts->row_min());
    $table_wise_stats->{max} = &RSAT::stats::max($counts->row_max());
    $table_wise_stats->{mean} = $table_wise_stats->{sum}/($ncol*$nrow);

    ## Calculate table-wise Sn, PPV, and geometric accuracy
    $table_wise_stats->{Sn} = &RSAT::stats::mean(@row_wise_Sn);
    $table_wise_stats->{PPV} = &RSAT::stats::mean(@col_wise_PPV);
    $table_wise_stats->{acc} = ($table_wise_stats->{Sn} + $table_wise_stats->{PPV})/2;
    $table_wise_stats->{acc_g} = sqrt($table_wise_stats->{Sn} * $table_wise_stats->{PPV});

    ## Calculate weighted Sn, PPV and geometric accuracy.
    ## Each row-wise Sn is weighted by its row sum, each column-wise PPV
    ## by its column sum.
    my @row_weights = $counts->row_sum();
    $table_wise_stats->{Sn_w} = 0;
    for my $r (0..($nrow-1)) {
      $table_wise_stats->{Sn_w} += $row_weights[$r]*$row_wise_Sn[$r];
    }
    $table_wise_stats->{Sn_w} /= &RSAT::stats::sum(@row_weights);
    my @col_weights = $counts->col_sum();
    $table_wise_stats->{PPV_w} = 0;
    for my $c (0..($ncol-1)) {
      $table_wise_stats->{PPV_w} += $col_weights[$c]*$col_wise_PPV[$c];
    }
    ## Normalise by the total of the weights actually used (column
    ## sums). The row total is not guaranteed to be the same value,
    ## presumably e.g. when marginal sums are overridden with
    ## -rsizes/-csizes (to be confirmed against RSAT::table).
    $table_wise_stats->{PPV_w} /= &RSAT::stats::sum(@col_weights);
    $table_wise_stats->{acc_w} = ($table_wise_stats->{Sn_w} + $table_wise_stats->{PPV_w})/2;
    $table_wise_stats->{acc_g_w} = sqrt($table_wise_stats->{Sn_w} * $table_wise_stats->{PPV_w});

    ## Calculate table-wise separation
    $table_wise_stats->{sep_r} = &RSAT::stats::mean($sep->row_sum());
    $table_wise_stats->{sep_c} = &RSAT::stats::mean($sep->col_sum());
    $table_wise_stats->{sep} = sqrt($table_wise_stats->{sep_r} * $table_wise_stats->{sep_c});

    ################################################################
    #### print verbose
    &Verbose() if ($main::verbose);

    ################################################################
    ###### print output

    ## Table-wise statistics, preceded by the user-specified parameters
    if ($return_fields{stats}) {
      for my $p  (0..$#param_keys) {
        printf $out "%s\t%s\n", $param_keys[$p], $param_values[$p];
      }

      for my $stat  (qw(ncol nrow sum min max)) {
        printf $out "%s\t%d\n", $stat, $table_wise_stats->{$stat};
      }
      for my $stat  (qw(mean Sn PPV acc acc_g Sn_w PPV_w acc_w acc_g_w sep_r sep_c sep)) {
        printf $out "%s\t%.${decimals}g\n", $stat, $table_wise_stats->{$stat};
      }
    }

    ################################################################
    ## Print row-wise statistics
    if ($return_fields{rowstats}) {
      print $out ";\n";
      print $out "; Row statistics\n";
      print $out join("\t",
                      "; mean",
                      "sum",
                      "min",
                      "max",
                      "Sn",
                      "sep_r",
                      "row_name",
                    ), "\n";
      my @row_names = $counts->getAlphabet();
      my @row_sum = $counts->row_sum();
      my @row_min = $counts->row_min();
      my @row_max = $counts->row_max();
      my @row_mean = $counts->row_mean();
      my @row_Sn = $Sn->row_max();
      my @row_sep = $sep->row_sum();
      for my $r (0..($nrow-1)) {
        print $out join("\t",
                        sprintf ("%.${decimals}g", $row_mean[$r]),
                        $row_sum[$r],
                        $row_min[$r],
                        $row_max[$r],
                        sprintf ("%.${decimals}g", $row_Sn[$r]),
                        sprintf ("%.${decimals}g", $row_sep[$r]),
                        $row_names[$r],
                       ), "\n";
      }
    }


    ################################################################
    ## Print column-wise statistics
    if ($return_fields{colstats}) {
      print $out ";\n";
      print $out "; Column statistics\n";
      print $out join("\t",
                      "; mean",
                      "sum",
                      "min",
                      "max",
                      "PPV",
                      "sep_c",
                      "column_name",
                     ), "\n";
      my @col_names = $counts->get_attribute("header");
      my @col_sum = $counts->col_sum();
      my @col_min = $counts->col_min();
      my @col_max = $counts->col_max();
      my @col_mean = $counts->col_mean();
      my @col_PPV = $PPV->col_max();
      my @col_sep = $sep->col_sum();
      for my $c (0..($ncol-1)) {
        print $out join("\t",
                        sprintf ("%.${decimals}g", $col_mean[$c]),
                        $col_sum[$c],
                        $col_min[$c],
                        $col_max[$c],
                        sprintf ("%.${decimals}g", $col_PPV[$c]),
                        sprintf ("%.${decimals}g", $col_sep[$c]),
                        ## Unnamed columns get a 1-based default label.
                        ## The arithmetic must stay outside the string:
                        ## "col $c+1" would print e.g. "col 0+1".
                        $col_names[$c] || ("col ".($c+1)),
                       ), "\n";
      }
    }

    ################################################################
    ## Full tables
    if ($return_fields{tables}) {
      print $out $counts->toString(corner=>"Count");
      print $out $PPV->toString(decimals=>$decimals, corner=>"PPV");
      print $out $Sn->toString(decimals=>$decimals, corner=>"Sn");
      print $out $sep->toString(decimals=>$decimals, corner=>"Sep");
    }


    ################################################################
    ## Report execution time and close output stream
    my $exec_time = &RSAT::util::ReportExecutionTime($start_time); ## This has to be executed by all scripts
    print $main::out $exec_time if ($main::verbose >= 1); ## only report exec time if verbosity is specified
    close $main::out if ($main::outfile{output});


    exit(0);
}

################################################################
################### subroutine definition ######################
################################################################


################################################################
#### display full help message 
sub PrintHelp {
    ## Format the POD embedded in this script with pod2text, then exit.
    ## The command is passed to system() as a list, so that the script
    ## path ($0) is handed to pod2text verbatim instead of being
    ## interpreted by a shell (a path containing spaces or shell
    ## metacharacters would otherwise break the command).
    my $status = system("pod2text", "-c", $0);
    if ($status != 0) {
        warn "Could not display the help message with pod2text (status $status)\n";
    }
    exit();
}

################################################################
#### display short help message
sub PrintOptions {
    ## There is no separate short help for this script: the request is
    ## delegated to PrintHelp.
    PrintHelp();
}

################################################################
#### Read arguments 
sub ReadArguments {
    ## Parse the command-line options (@ARGV) and store their values in
    ## the package variables read by the main program: $main::verbose,
    ## %main::infile, %main::outfile, $main::decimals, @param_keys,
    ## @param_values and %return_fields. An unknown option or an
    ## unsupported return field is reported as a fatal error.
    my $arg = "";

    my @arguments = @ARGV; ## create a copy to shift, because we need ARGV to report command line in &Verbose()

    ## Test definedness rather than truth: a false-valued argument such
    ## as "0" must be reported as an invalid option instead of silently
    ## ending the parsing of the remaining arguments.
    while (defined($arg = shift (@arguments))) {

        ## Verbosity
=pod


=head1 OPTIONS

=over 4

=item B<-v #>

Level of verbosity (detail in the warning messages during execution)

=cut
        if ($arg eq "-v") {
            ## The level is optional: "-v" alone means level 1
            if (&IsNatural($arguments[0])) {
                $main::verbose = shift(@arguments);
            } else {
                $main::verbose = 1;
            }

            ## Help message
=pod

=item B<-h>

Display full help message

=cut
        } elsif ($arg eq "-h") {
            &PrintHelp();

            ## List of options
=pod

=item B<-help>

Same as -h

=cut
        } elsif ($arg eq "-help") {
            &PrintOptions();


            ## Input file
=pod

=item B<-i inputfile>

If no input file is specified, the standard input is used.  This
allows to use the command within a pipe.

=cut
        } elsif ($arg eq "-i") {
            $main::infile{input} = shift(@arguments);

            ## Output file
=pod

=item B<-o outputfile>

If no output file is specified, the standard output is used.  This
allows to use the command within a pipe.

=cut
        } elsif ($arg eq "-o") {
            $main::outfile{output} = shift(@arguments);

            ## Number of decimals
=pod

=item B<-decimals #>

Number of decimals to display for the computed statistics.

=cut
        } elsif ($arg eq "-decimals") {
            $main::decimals = shift(@arguments);

            ## User-specified parameters (name and value)
=pod

=item B<-param param_name param_value>

Include a user-specified parameter (name and value) in the table-wise
statistics. Can be used iteratively on the same command line.

This is useful when the script is used iteratively, to calculate
statistics under various conditions (for example various parameter
values for a clustering algorithm). The table-wise statistics can then
be integrated in a single table with the program compare-scores.

=cut
        } elsif ($arg eq "-param") {
            push @param_keys, shift(@arguments);
            push @param_values, shift(@arguments);

            ## Return fields
=pod

=item B<-return return_fields>

List of fields to return.

Supported fields: stats, rowstats, colstats, tables, margins, all

=over

=item I<stats>

table-wise statistics

=item I<rowstats>

row-wise statistics (one line per row of the contingency table)

=item I<colstats>

column-wise statistics (one line per column of the contingency table)

=item I<tables>

full tables for each statistics (counts, Sn, PPV, separation).

=item I<margins>

marginal statistics besides the tables (requires to return tables).

=item I<all>

return all the fields listed above.

=back

=cut
          } elsif ($arg eq "-return") {
            ## Comma-separated list of fields, each checked against the
            ## supported fields declared by the main program
            my $to_return = shift @arguments;
            my @fields_to_return = split ",", $to_return;
            foreach my $field (@fields_to_return) {
              if ($supported_return_fields{$field}) {
                $return_fields{$field} = 1;
              } else {
                &RSAT::error::FatalError(join("\t", $field, "Invalid return field. Supported:", $supported_return_fields));
              }
            }

            ## Row sizes
=pod

=item B<-rsizes row_size_file>

Specify row group sizes in a separate file. This option can be used in
particular cases where the marginal sum of the contingency table does
not correspond to the group sizes (for example if a classification
supports the same elements assigned to multiple groups, or on the
contrary if some elements can be unassigned).

The row size file must contain one row per row of the contingency
table, and two columns. The first column indicates the name of the row
(the same name as in the contingency table), and the second the size
of the corresponding group.

=cut
          } elsif ($arg eq "-rsizes") {
            $main::infile{rsizes} = shift(@arguments);

            ## Column sizes
=pod

=item B<-csizes column_size_file>

Specify column group sizes in a separate file.

Same description as for -rsizes.

=cut
          } elsif ($arg eq "-csizes") {
            $main::infile{csizes} = shift(@arguments);

          } else {
            &FatalError(join("\t", "Invalid option", $arg));
          }

      }

=pod

=back

=cut

}

################################################################
#### verbose message
sub Verbose {
    ## Report the command line, then the input and output files, as
    ## comment lines (prefixed with ";") on the output stream.
    my $out = $main::out;
    print $out "; contingency-stats ";
    &PrintArguments($out);

    ## One section per non-empty file table, input files first
    foreach my $section (["Input files", \%main::infile],
                         ["Output files", \%main::outfile]) {
        my ($title, $files) = @{$section};
        next unless (%{$files});
        print $out "; ".$title."\n";
        foreach my $key (keys(%{$files})) {
            print $out ";\t".$key."\t".$files->{$key}."\n";
        }
    }
}


################################################################
## Read row sizes
sub read_row_sizes {
  ## Read the row group sizes from the file $infile{rsizes} (two
  ## tab-separated columns: group name, group size) and pass them to
  ## the counts table, in the order of its rows, with force_row_sum().
  ## A size that is not a natural number, or a row without any
  ## specified size, is reported as a fatal error.
  my ($size_handle) = &OpenInputFile($infile{rsizes});

  ## Index the sizes by group name
  my %size_of_group = ();
  while (my $line = <$size_handle>) {
    chomp($line);
    next if ($line =~ /^(?:;|#|--)/); ## skip comment lines
    next unless ($line =~ /\S/);      ## skip empty lines
    my @columns = split(/\t/, $line);
    my ($group, $size) = @columns[0, 1];
    unless (&RSAT::util::IsNatural($size)) {
      &RSAT::error::FatalError(join("\t", $size, "Invalid size specification for group", $group));
    }
    $size_of_group{$group} = $size;
  }
  close $size_handle;

  ## Collect the sizes in the same order as the rows of the table
  my @ordered_sizes = ();
  foreach my $row_label ($counts->getAlphabet()) {
    unless (defined($size_of_group{$row_label})) {
      &RSAT::error::FatalError(join("\t", "Invalid row size file: no specified value for row", $row_label));
      next;
    }
    push @ordered_sizes, $size_of_group{$row_label};
  }
  $counts->force_row_sum(@ordered_sizes);
}

################################################################
## Read col sizes
sub read_col_sizes {
  ## Read the column group sizes from the file $infile{csizes} (two
  ## tab-separated columns: group name, group size) and pass them to
  ## the counts table, in the order of its header, with
  ## force_col_sum(). A size that is not a natural number, or a column
  ## without any specified size, is reported as a fatal error.
  my ($size_handle) = &OpenInputFile($infile{csizes});

  ## Index the sizes by group name
  my %size_of_group = ();
  while (my $line = <$size_handle>) {
    chomp($line);
    next if ($line =~ /^(?:;|#|--)/); ## skip comment lines
    next unless ($line =~ /\S/);      ## skip empty lines
    my @columns = split(/\t/, $line);
    my ($group, $size) = @columns[0, 1];
    unless (&RSAT::util::IsNatural($size)) {
      &RSAT::error::FatalError(join("\t", $size, "Invalid size specification for group", $group));
    }
    $size_of_group{$group} = $size;
  }
  close $size_handle;

  ## Collect the sizes in the same order as the columns of the table
  my @ordered_sizes = ();
  foreach my $col_label ($counts->get_attribute("header")) {
    unless (defined($size_of_group{$col_label})) {
      &RSAT::error::FatalError(join("\t", "Invalid col size file: no specified value for col", $col_label));
      next;
    }
    push @ordered_sizes, $size_of_group{$col_label};
  }
  $counts->force_col_sum(@ordered_sizes);
}

__END__

