#include "train_eval.h"
#include "train_labelling.h"
#include "train_types.h"
#include <graph_types.h>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <vector>
#include <dai/factorgraph.h>


// Debug logging helpers.
//
// With DEBUG_TRAIN_LABELLING defined, every LOG_TRAIN_LABELLING* macro formats
// its arguments with sprintf into the file-local buffer log_txt and, if
// Classifier::Verbose() returns true, hands that buffer to the externally
// defined train_labelling_log(). Without DEBUG_TRAIN_LABELLING the macros
// expand to nothing, so their arguments are not evaluated, and log_txt is not
// declared by this file.
//
// NOTE(review): each macro expands to two statements that are not wrapped in a
// block, and the call sites in this file omit the trailing semicolon; do not
// use a macro as the sole body of an unbraced if/else.
// NOTE(review): sprintf does not check the size of log_txt (4 * 4096 bytes);
// the formatted text has to fit. Passing log_txt itself as an argument makes
// source and destination overlap.
#ifdef DEBUG_TRAIN_LABELLING
// Sink for the formatted message; defined in another translation unit.
extern
void	train_labelling_log	(char*);
// Scratch buffer shared by all LOG_TRAIN_LABELLING* macros in this file.
static
char	log_txt	[4 * 4096];
// s: format string without further arguments.
#define	LOG_TRAIN_LABELLING(s) sprintf(log_txt,s); \
							if(Classifier::Verbose()) train_labelling_log(log_txt);
// s: format string, t: one format argument.
#define	LOG_TRAIN_LABELLING_2(s,t) sprintf(log_txt,s,t); \
							if(Classifier::Verbose()) train_labelling_log(log_txt);
// s: format string, t/u: two format arguments.
#define	LOG_TRAIN_LABELLING_3(s,t,u) sprintf(log_txt,s,t,u); \
							if(Classifier::Verbose()) train_labelling_log(log_txt);
// s: format string, t/u/v: three format arguments.
#define	LOG_TRAIN_LABELLING_4(s,t,u,v) sprintf(log_txt,s,t,u,v); \
							if(Classifier::Verbose()) train_labelling_log(log_txt);
#else

// Logging disabled: all macros compile away.
#define	LOG_TRAIN_LABELLING(s)
#define	LOG_TRAIN_LABELLING_2(s,t)
#define	LOG_TRAIN_LABELLING_3(s,t,u)
#define	LOG_TRAIN_LABELLING_4(s,t,u,v)

#endif // DEBUG_TRAIN_LABELLING

namespace	Training	{

/* Confusion counts accumulated over all EvalMAPLabelling() calls. */
size_t Eval::m_False_Negative = 0;
size_t Eval::m_False_Positive = 0;
size_t Eval::m_True_Negative = 0;
size_t Eval::m_True_Positive = 0;

/* Confusion counts of the labelling evaluated last (reset on every EvalMAPLabelling() call). */
size_t Eval::m_FN_obj = 0;
size_t Eval::m_FP_obj = 0;
size_t Eval::m_TN_obj = 0;
size_t Eval::m_TP_obj = 0;

/* Confusion counts stored by SaveBestParams() / Init_ForObject(). */
size_t Eval::m_FN_opt = 0;
size_t Eval::m_FP_opt = 0;
size_t Eval::m_TN_opt = 0;
size_t Eval::m_TP_opt = 0;

/* Quality measures stored by SaveBestParams() / Init_ForObject(). */
double Eval::m_Prec_opt = 0.;
double Eval::m_Recall_opt = 0.;
double Eval::m_F1_opt = 0.;
double Eval::m_Prec_0_opt = 0.;
double Eval::m_Recall_0_opt = 0.;
double Eval::m_F1_0_opt = 0.;
double Eval::m_Acc_opt = 0.;
double Eval::m_TPrate_opt = 0.;
double Eval::m_FPrate_opt = 0.;
double Eval::m_CC_opt = 0.;

/* Quality measures of the labelling evaluated last, filled by StatsPerObject(). */
double Eval::m_Prec_obj = 0.;
double Eval::m_Recall_obj = 0.;
double Eval::m_F1_obj = 0.;
double Eval::m_Prec_0_obj = 0.;
double Eval::m_Recall_0_obj = 0.;
double Eval::m_F1_0_obj = 0.;
double Eval::m_Acc_obj = 0.;
double Eval::m_CC_obj = 0.;

/* Lower bound on (TP rate - FP rate) used by ComputeFunctional_ROCdiff() / _ROCcurve(). */
const double Eval::m_MinResults_Diff = .25;
/* Lower bound on (TP rate / FP rate) used by ComputeFunctional_ROCcurve(). */
const double Eval::m_MinResults_Slope = 2.8;

/* Centre of the FP-rate window tested by ComputeFunctional_ROCset_2(). */
double Eval::m_ROCset_FPR = .3;
/* Centre of the TP-rate window tested by ComputeFunctional_ROCset_1(). */
double Eval::m_ROCset_TPR = .4;

/*
 * Resets the stored best-result ("_opt") values by running
 * Init_ForObject() with a node count of zero.
 */
void
Eval::FullInit() {
	const size_t no_nodes = 0;
	Init_ForObject(no_nodes);
}

/*
 * Puts the stored best-result ("_opt") values into their start state.
 *
 * nodes_size is used as the initial false-negative and false-positive count;
 * the true counts and all derived measures start at zero, except the
 * FP rate, which starts at 1 (so TP rate - FP rate starts at -1).
 */
void
Eval::Init_ForObject(size_t nodes_size) {
	/* confusion counts */
	m_TP_opt = 0;
	m_TN_opt = 0;
	m_FP_opt = nodes_size;
	m_FN_opt = nodes_size;

	/* measures for label 1 */
	m_Prec_opt = 0.;
	m_Recall_opt = 0.;
	m_F1_opt = 0.;

	/* measures for label 0 */
	m_Prec_0_opt = 0.;
	m_Recall_0_opt = 0.;
	m_F1_0_opt = 0.;

	/* general measures */
	m_Acc_opt = 0.;
	m_CC_opt = 0.;
	m_TPrate_opt = 0.;
	m_FPrate_opt = 1.;
}

/*
 * Compares the current MAP labelling (Classifier::Get_MAP_Nodes()) with the
 * reference labels of the given nodes.
 *
 * For every factor-graph variable j the node whose id equals the variable's
 * label is looked up in `nodes`. The per-object confusion counts
 * (m_TP_obj, m_TN_obj, m_FP_obj, m_FN_obj) are rebuilt from scratch and then
 * added to the running totals. Variables without a matching node are skipped
 * (as in Make_Reference_Labelling()) instead of dereferencing end().
 *
 * If per-object output is enabled, one "<name>\t<MAP label>" line per node is
 * written to Classifier::Get_Output_Obj() and StatsPerObject() is run.
 * StatsValidation() is run on every GetOpt_Imm_Rslts()-th call when that
 * option is non-zero.
 *
 * nodes: nodes carrying the reference label (m_Nd.m_RefLabel) and the
 *        name (m_Nd.m_Aa_Name).
 */
void
Eval::EvalMAPLabelling(const std::vector<GraphGen::Node<Types::Node> >& nodes) {
	typedef std::vector<GraphGen::Node<Types::Node> >::const_iterator NodeIt;

	/* "<ref><map>," per node, used for debug logging only */
	static char trace [4 * 4096];
	static size_t no_evals(0);
	char* p(trace);
	char* const trace_end(trace + sizeof(trace));
	bool trace_full(false);
	const bool per_object(Classifier::GetOpt_Output_Per_Object());
	std::ofstream o;

	memset((void*)trace,0,sizeof(trace));
	m_FN_obj = m_FP_obj = m_TP_obj = m_TN_obj = 0;
	if(per_object)
		o.open(Classifier::Get_Output_Obj().c_str());

	for(size_t j(0); j < Classifier::Get_MAP_Nodes().size(); j++) {
		const NodeIt n = std::find_if(nodes.begin(),nodes.end(),Types::FindNodeId(Classifier::Get_Fg()->var(j).label()));
		if(n == nodes.end())
			continue;	/* no node for this variable: nothing to compare against */

		if(per_object)
			o << n->m_Nd.m_Aa_Name.c_str() << "\t" << Classifier::Get_MAP_Nodes()[j] << '\n';

		const bool ref_positive(state(n->m_Nd.m_RefLabel) != 0);
		bool map_positive(false);
		if(ref_positive) {
			map_positive = (1 == Classifier::Get_MAP_Nodes()[j]);
			if(map_positive)
				m_TP_obj++;
			else
				m_FN_obj++;
		} else {
			map_positive = (0 != Classifier::Get_MAP_Nodes()[j]);
			if(map_positive)
				m_FP_obj++;
			else
				m_TN_obj++;
		}

		/* Append to the trace only while it fits; the buffer was zeroed above,
		 * so it stays NUL-terminated and is simply truncated for large graphs. */
		const bool line_break(j > 0 && (j % 10) == 0);
		const size_t needed(line_break ? 4 : 3);
		if(!trace_full && static_cast<size_t>(trace_end - p) > needed) {
			*p++ = ref_positive ? '1' : '0';
			*p++ = map_positive ? '1' : '0';
			*p++ = ',';
			if(line_break)
				*p++ = '\n';
		} else
			trace_full = true;
	}
	LOG_TRAIN_LABELLING_2("%s",trace)
	if(per_object) {
		o.close();
		StatsPerObject();
	}
	m_False_Negative += m_FN_obj;
	m_True_Positive += m_TP_obj;
	m_True_Negative += m_TN_obj;
	m_False_Positive += m_FP_obj;

	if(Classifier::GetOpt_Imm_Rslts() != 0 && (++no_evals % Classifier::GetOpt_Imm_Rslts()) == 0)
		StatsValidation();
}

/*
 * Derives the per-object quality measures from the confusion counts
 * m_TP_obj / m_TN_obj / m_FP_obj / m_FN_obj and reports them.
 *
 * Sets m_Prec_obj, m_Recall_obj, m_F1_obj (label 1), m_Prec_0_obj,
 * m_Recall_0_obj, m_F1_0_obj (label 0), m_Acc_obj and m_CC_obj; a measure
 * whose denominator is zero stays 0. If at least one node was counted, a
 * report with the counts, TP rate, FP rate, their difference and their ratio
 * is printed to stdout (only when neither Verbose() nor MoreVerbose() is set)
 * and passed to the debug log.
 *
 * All arithmetic is done in double so that the product of the four marginal
 * sums cannot wrap around in size_t. The report is formatted into a local,
 * size-checked buffer, so the function does not depend on the debug-only
 * log_txt buffer and the log macro never copies a buffer onto itself.
 */
void
Eval::StatsPerObject() {
	const double tp(static_cast<double>(m_TP_obj));
	const double tn(static_cast<double>(m_TN_obj));
	const double fp(static_cast<double>(m_FP_obj));
	const double fn(static_cast<double>(m_FN_obj));
	double tprate(0.), fprate(0.), slope(0.);

	m_Prec_0_obj = m_Recall_0_obj = m_F1_0_obj = m_Prec_obj = m_Recall_obj = m_F1_obj = m_Acc_obj = m_CC_obj = 0.;

	/* quality measures label 1	*/
	if(tp + fp > 0.)
		m_Prec_obj = tp / (tp + fp);
	if(tp + fn > 0.)
		m_Recall_obj = tp / (tp + fn);
	if(m_Prec_obj + m_Recall_obj != 0.)
		m_F1_obj = (2. * m_Prec_obj * m_Recall_obj) / (m_Prec_obj + m_Recall_obj);

	/* quality measures label 0	*/
	if(tn + fn > 0.)
		m_Prec_0_obj = tn / (tn + fn);
	if(tn + fp > 0.)
		m_Recall_0_obj = tn / (tn + fp);
	if(m_Prec_0_obj + m_Recall_0_obj != 0.)
		m_F1_0_obj = (2. * m_Prec_0_obj * m_Recall_0_obj) / (m_Prec_0_obj + m_Recall_0_obj);

	/* general quality measures	*/
	if(tp + tn + fp + fn > 0.) {
		m_Acc_obj = (tp + tn) / (tp + tn + fp + fn);

		/* correlation coefficient; undefined (left at 0) if any marginal sum is 0 */
		const double prod((tp + fn) * (tp + fp) * (tn + fp) * (tn + fn));
		if(prod != 0.)
			m_CC_obj = (tp * tn - fp * fn) / std::sqrt(prod);

		if(tp + fn > 0.)
			tprate = tp / (tp + fn);
		if(fp + tn > 0.)
			fprate = fp / (fp + tn);
		if(fprate > 0.)
			slope = (tprate/fprate);

		char report [4096];
		snprintf(report, sizeof(report), "\n------------------------------------------------------------------------------------\n"
				"\tResults for %s\n"
				"\tTP\t%lu\n"
				"\tTN\t%lu\n"
				"\tFP\t%lu\n"
				"\tFN\t%lu\n"
				"\tTPR\t%.4f\n"
				"\tFPR\t%.4f\n"
				"\tDff\t%.4f\n"
				"\tTan\t%.4f\n",
				Classifier::Get_Output_Obj().c_str(),
				static_cast<unsigned long>(m_TP_obj),
				static_cast<unsigned long>(m_TN_obj),
				static_cast<unsigned long>(m_FP_obj),
				static_cast<unsigned long>(m_FN_obj),
				tprate,fprate,(tprate-fprate),slope
				);
		if(!Classifier::Verbose() && !Classifier::MoreVerbose())
			printf("%s",report);
		LOG_TRAIN_LABELLING_2("%s",report)
	}
}

/*
 * Reports the confusion counts accumulated over all evaluated labellings
 * (m_True_Positive, m_True_Negative, m_False_Positive, m_False_Negative)
 * together with the overall TP rate, FP rate and their difference.
 *
 * Nothing is printed while all four counts are zero. The report always goes
 * to stdout and additionally to the debug log.
 *
 * Each rate is guarded by its own denominator, as in StatsPerObject(), rather
 * than by the correlation-coefficient product: otherwise a run in which e.g.
 * nothing was predicted negative would report both rates as 0. Precision,
 * recall, F1, accuracy and the correlation coefficient were computed here but
 * never used, so they are no longer computed. The report is formatted into a
 * local, size-checked buffer, so the function does not depend on the
 * debug-only log_txt buffer.
 */
void
Eval::StatsValidation() {
	const double tp(static_cast<double>(m_True_Positive));
	const double tn(static_cast<double>(m_True_Negative));
	const double fp(static_cast<double>(m_False_Positive));
	const double fn(static_cast<double>(m_False_Negative));
	double tp_rate(0.), fp_rate(0.);

	if(tp + tn + fp + fn > 0.) {
		if(tp + fn > 0.)
			tp_rate = tp / (tp + fn);
		if(fp + tn > 0.)
			fp_rate = fp / (fp + tn);

		char report [1024];
		snprintf(report, sizeof(report), "\n\n------------------------------------------------------------------------------------\n"
				"\tResults overall.\n"
				"\tTP\t%lu\n"
				"\tTN\t%lu\n"
				"\tFP\t%lu\n"
				"\tFN\t%lu\n"
				"\tTPR\t%.4f\n"
				"\tFPR\t%.4f\n"
				"\tDrt\t%.3f\n"
				"----------------------------------------------------------------------------------------\n",
				static_cast<unsigned long>(m_True_Positive),
				static_cast<unsigned long>(m_True_Negative),
				static_cast<unsigned long>(m_False_Positive),
				static_cast<unsigned long>(m_False_Negative),
				tp_rate,fp_rate,(tp_rate-fp_rate)
				);
		printf("%s",report);
		LOG_TRAIN_LABELLING_2("%s",report)
	}
}

/*
 * Stores the results of the labelling evaluated last ("_obj" values) as the
 * best results so far ("_opt" values) and derives the best TP / FP rate from
 * the per-object confusion counts.
 *
 * A rate whose denominator is zero is stored as 0, matching the
 * ComputeFunctional_*() functions; dividing 0 by 0 would store NaN, after
 * which every later comparison against the stored rates is false.
 *
 * at_learn_step: not used by this function.
 */
void
Eval::SaveBestParams(size_t at_learn_step) {
	(void)at_learn_step;

	m_FN_opt = m_FN_obj;
	m_FP_opt = m_FP_obj;
	m_TP_opt = m_TP_obj;
	m_TN_opt = m_TN_obj;
	m_Prec_opt = m_Prec_obj;
	m_F1_opt = m_F1_obj;
	m_Recall_opt = m_Recall_obj;
	m_Prec_0_opt = m_Prec_0_obj;
	m_Recall_0_opt = m_Recall_0_obj;
	m_F1_0_opt = m_F1_0_obj;
	m_CC_opt = m_CC_obj;
	m_Acc_opt = m_Acc_obj;

	const double positives(static_cast<double>(m_TP_obj) + static_cast<double>(m_FN_obj));
	const double negatives(static_cast<double>(m_FP_obj) + static_cast<double>(m_TN_obj));
	m_TPrate_opt = (positives > 0.) ? static_cast<double>(m_TP_obj) / positives : 0.;
	m_FPrate_opt = (negatives > 0.) ? static_cast<double>(m_FP_obj) / negatives : 0.;
}

/*
 * Maps a reference label to a variable state:
 * 1 for Types::Node::RefLabel_If, 0 for every other label.
 */
size_t
Eval::state(Types::Node::RefLabel label) {
	if(label == Types::Node::RefLabel_If)
		return 1;
	return 0;
}

/*
 * Maps a variable state to a reference label:
 * Types::Node::RefLabel_If for 1, Types::Node::RefLabel_NoIf for any other value.
 */
Types::Node::RefLabel
Eval::state(size_t label) {
	if(label != 1)
		return Types::Node::RefLabel_NoIf;
	return Types::Node::RefLabel_If;
}

/*
 * Writes the classifier's factor graph to
 * "log/fg.<overall learn step>.<learn step>".
 *
 * The step suffix is formatted with a size-checked snprintf into a buffer
 * large enough for two full-width unsigned long values; the former 16-byte
 * buffer could be overrun by sprintf.
 */
void
Eval::Save_FactorGraph() {
	char suffix [64];
	snprintf(suffix,sizeof(suffix),"%lu.%lu",
			static_cast<unsigned long>(Classifier::Get_Overall_LearnStep()),
			static_cast<unsigned long>(Classifier::Get_LearnStep()));

	std::string fn("log/fg.");
	fn += suffix;
	Classifier::Get_Fg()->WriteToFile(fn.c_str());
}

/*
 * Fills Classifier::Get_Ref_Labelling() with the reference state of every
 * factor-graph variable.
 *
 * The labelling is resized to Classifier::Get_Nr_Nodes(). For each variable
 * the node whose id equals the variable's label is searched in `nodes`; when
 * found, its reference label is converted with state() and stored at the
 * variable's index. Entries of variables without a matching node are not
 * written.
 *
 * NOTE(review): the loop runs over Get_Fg()->vars().size() while the vector
 * is sized by Get_Nr_Nodes() - presumably both are equal; confirm.
 */
void
Eval::Make_Reference_Labelling(const std::vector<GraphGen::Node<Types::Node> >& nodes) {
	typedef std::vector<GraphGen::Node<Types::Node> >::const_iterator NodeIter;

	Classifier::Get_Ref_Labelling().resize(Classifier::Get_Nr_Nodes());
	for(size_t var_idx = 0; var_idx < Classifier::Get_Fg()->vars().size(); ++var_idx) {
		const NodeIter match = std::find_if(nodes.begin(), nodes.end(),
				Types::FindNodeId(Classifier::Get_Fg()->var(var_idx).label()));
		if(match == nodes.end())
			continue;
		Classifier::Get_Ref_Labelling()[var_idx] = state(match->m_Nd.m_RefLabel);
	}
}

/*
 * Acceptance test based on the difference TP rate - FP rate of the labelling
 * evaluated last (per-object counts).
 *
 * Returns true if that difference exceeds both m_MinResults_Diff and the
 * difference of the stored best rates (m_TPrate_opt - m_FPrate_opt).
 * A rate whose denominator is zero counts as 0.
 */
bool
Eval::ComputeFunctional_ROCdiff() {
	const double positives((double)m_TP_obj + m_FN_obj);
	const double negatives((double)m_FP_obj + m_TN_obj);
	const double tpr(positives != 0. ? (double)m_TP_obj / positives : 0.);
	const double fpr(negatives != 0. ? (double)m_FP_obj / negatives : 0.);
	const double gain(tpr - fpr);

	return gain > m_MinResults_Diff && gain > m_TPrate_opt - m_FPrate_opt;
}

/*
 * Acceptance test based on the ratio TP rate / FP rate of the labelling
 * evaluated last (per-object counts).
 *
 * With a positive FP rate the result is true if the ratio exceeds both
 * m_MinResults_Slope and the ratio of the stored best rates. With an FP rate
 * of zero the test falls back to the difference criterion of
 * ComputeFunctional_ROCdiff(). A rate whose denominator is zero counts as 0.
 *
 * NOTE(review): m_TPrate_opt/m_FPrate_opt is not guarded against
 * m_FPrate_opt == 0 - confirm that an infinite / NaN bound is intended.
 */
bool
Eval::ComputeFunctional_ROCcurve() {
	const double positives((double)m_TP_obj + m_FN_obj);
	const double negatives((double)m_FP_obj + m_TN_obj);
	const double tpr(positives != 0. ? (double)m_TP_obj / positives : 0.);
	const double fpr(negatives != 0. ? (double)m_FP_obj / negatives : 0.);

	if(fpr > 0.) {
		const double slope(tpr / fpr);
		return slope > m_MinResults_Slope && slope > m_TPrate_opt/m_FPrate_opt;
	}

	const double gain(tpr - fpr);
	return gain > m_MinResults_Diff && gain > m_TPrate_opt - m_FPrate_opt;
}

/*
 * Acceptance test on the TP rate of the labelling evaluated last.
 *
 * Returns true if both current rates differ from the stored best rates and
 * the TP rate lies within m_ROCset_TPR +/- range (bounds included).
 * A rate whose denominator is zero counts as 0.
 *
 * NOTE(review): range is a function-local static, so it is computed from
 * m_ROCset_TPR once, on the first call; later changes of m_ROCset_TPR do not
 * update it - confirm this is intended.
 */
bool
Eval::ComputeFunctional_ROCset_1() {
	static double range(m_ROCset_TPR * .1);

	const double positives((double)m_TP_obj + m_FN_obj);
	const double negatives((double)m_FP_obj + m_TN_obj);
	const double tpr(positives != 0. ? (double)m_TP_obj / positives : 0.);
	const double fpr(negatives != 0. ? (double)m_FP_obj / negatives : 0.);

	/* unchanged TP or FP rate: reject */
	if(m_TPrate_opt == tpr || m_FPrate_opt == fpr)
		return false;

	return tpr >= m_ROCset_TPR - range && tpr <= m_ROCset_TPR + range;
}

/*
 * Acceptance test on the FP rate of the labelling evaluated last.
 *
 * Returns true if both current rates differ from the stored best rates and
 * the FP rate lies within m_ROCset_FPR +/- range (bounds included).
 * A rate whose denominator is zero counts as 0.
 *
 * NOTE(review): range is a function-local static, so it is computed from
 * m_ROCset_FPR once, on the first call; later changes of m_ROCset_FPR do not
 * update it - confirm this is intended.
 */
bool
Eval::ComputeFunctional_ROCset_2() {
	static double range(m_ROCset_FPR * .1);

	const double positives((double)m_TP_obj + m_FN_obj);
	const double negatives((double)m_FP_obj + m_TN_obj);
	const double tpr(positives != 0. ? (double)m_TP_obj / positives : 0.);
	const double fpr(negatives != 0. ? (double)m_FP_obj / negatives : 0.);

	/* unchanged TP or FP rate: reject */
	if(m_TPrate_opt == tpr || m_FPrate_opt == fpr)
		return false;

	return fpr >= m_ROCset_FPR - range && fpr <= m_ROCset_FPR + range;
}

/*
 * Function object for for_each: writes every element it is called with to
 * the given file stream, one element per line.
 *
 * The argument_type / result_type typedefs are declared directly instead of
 * being inherited from std::unary_function, which was removed in C++17.
 * Lines end with '\n' rather than std::endl so the stream is not flushed
 * after every element.
 *
 * os:    stream written to; only a reference is held, so the stream must
 *        outlive the functor.
 * count: initialised to 0; not modified by operator().
 */
template<class T> struct PrintVec {
	typedef T		argument_type;
	typedef void	result_type;

	PrintVec(std::ofstream& out) : os(out), count(0) { }
	void operator() (T x) {
		os << x << '\n';
	}
	std::ofstream& os;
	size_t count;
};

/*
 * Writes the two weight vectors to a text file in Classifier::Get_Output_Dir().
 *
 * File name: "learn.vecs" for the final result, otherwise
 * "learn.vecs.<overall learn step>.<learn step>", with a trailing 'c' when
 * new_chunk is set. Content: for non-final files a "# <output object>"
 * header line, then the size of w_q followed by its elements and the size of
 * w_g followed by its elements, one value per line.
 *
 * The step suffix is formatted with a size-checked snprintf (the former
 * 16-byte buffer could be overrun), and a file that cannot be opened is
 * reported on stderr instead of being ignored silently.
 *
 * final:     write the final file (no step suffix, no header line).
 * new_chunk: only used when final is false; selects the 'c' suffix.
 * w_q, w_g:  weight vectors to store.
 */
void
Eval::Save_WeightVectors(bool final, bool new_chunk, const std::vector<double>& w_q, const std::vector<double>& w_g) {
	std::string fn(Classifier::Get_Output_Dir());
	fn += "/learn.vecs";
	if(!final) {
		char suffix [64];
		const unsigned long overall_step(static_cast<unsigned long>(Classifier::Get_Overall_LearnStep()));
		const unsigned long step(static_cast<unsigned long>(Classifier::Get_LearnStep()));
		if(!new_chunk)
			snprintf(suffix,sizeof(suffix),".%lu.%lu",overall_step,step);
		else
			snprintf(suffix,sizeof(suffix),".%lu.%luc",overall_step,step);
		fn += suffix;
	}

	std::ofstream outfile(fn.c_str());
	if(!outfile) {
		fprintf(stderr,"Eval::Save_WeightVectors: cannot open '%s' for writing\n",fn.c_str());
		return;
	}
	if(!final)
		outfile << "# " << Classifier::Get_Output_Obj() << std::endl;
	outfile << w_q.size() << std::endl;
	std::for_each(w_q.begin(),w_q.end(), PrintVec<double>(outfile));
	outfile << w_g.size() << std::endl;
	std::for_each(w_g.begin(),w_g.end(), PrintVec<double>(outfile));
	outfile.close();
}

/*
 * Writes the two weight vectors to
 * "<output dir>/learn.vecs.<overall learn step>.<learn step>-runopt".
 *
 * Content: a "# <output object>" header line, then the size of w_q followed
 * by its elements and the size of w_g followed by its elements, one value
 * per line.
 *
 * The suffix is formatted with a size-checked snprintf: in the former 16-byte
 * buffer the fixed part ".", ".", "-runopt" plus the terminator left room for
 * only six digits in total. A file that cannot be opened is reported on
 * stderr instead of being ignored silently.
 *
 * w_q, w_g: weight vectors to store.
 */
void
Eval::Save_Runs_BestWeightVectors(const std::vector<double>& w_q, const std::vector<double>& w_g) {
	char suffix [64];
	snprintf(suffix,sizeof(suffix),".%lu.%lu-runopt",
			static_cast<unsigned long>(Classifier::Get_Overall_LearnStep()),
			static_cast<unsigned long>(Classifier::Get_LearnStep()));

	std::string fn(Classifier::Get_Output_Dir());
	fn += "/learn.vecs";
	fn += suffix;

	std::ofstream outfile(fn.c_str());
	if(!outfile) {
		fprintf(stderr,"Eval::Save_Runs_BestWeightVectors: cannot open '%s' for writing\n",fn.c_str());
		return;
	}
	outfile << "# " << Classifier::Get_Output_Obj() << std::endl;
	outfile << w_q.size() << std::endl;
	std::for_each(w_q.begin(),w_q.end(), PrintVec<double>(outfile));
	outfile << w_g.size() << std::endl;
	std::for_each(w_g.begin(),w_g.end(), PrintVec<double>(outfile));
	outfile.close();
}

/*
 * Function object for for_each: renders a sequence as a string of '0' / '1'
 * characters in the static buffer b ('0' for elements equal to 0, '1'
 * otherwise).
 *
 * Constructing a Pr clears b and rewinds the write position p; copies made
 * by for_each share both because they are static. At most sizeof(b) - 1
 * characters are stored, so b always stays NUL-terminated; further elements
 * are dropped instead of being written past the end of the buffer.
 *
 * The argument_type / result_type typedefs are declared directly instead of
 * being inherited from std::unary_function, which was removed in C++17.
 */
template<class T>
struct Pr {
	typedef T		argument_type;
	typedef void	result_type;

	Pr		() {
		std::memset((void*)b,0,sizeof(b));
		p = b;
	}

	void
	operator() (T x) {
		/* keep the last byte as terminator */
		if(p >= b + sizeof(b) - 1)
			return;
		*p++ = (x != 0) ? '1' : '0';
	}

	static
	char	b	[4096];

	static
	char*	p;
};

template<class T>
char
Pr<T>::b [4096];

template<class T>
char*
Pr<T>::p;

void
Eval::Log_ClassifierResults() {
	for_each(Classifier::Get_MAP_Nodes().begin(),Classifier::Get_MAP_Nodes().end(), Pr<size_t>());
	double map_score = Score(Classifier::Get_MAP_Nodes());
    LOG_TRAIN_LABELLING_4("x*%lu,%lu>%s",Classifier::Get_LearnStep(),Classifier::Get_MAP_Nodes().size(),Pr<size_t>::b)
	LOG_TRAIN_LABELLING_3("MAP*%lu> %.5f",Classifier::Get_LearnStep(),map_score)
    for_each(Classifier::Get_Ref_Labelling().begin(),Classifier::Get_Ref_Labelling().end(), Pr<size_t>());
    LOG_TRAIN_LABELLING_4("rx%lu,%lu>%s",Classifier::Get_LearnStep(),Classifier::Get_Ref_Labelling().size(),Pr<size_t>::b)
    LOG_TRAIN_LABELLING_3("rMAP%lu> %.5f",Classifier::Get_LearnStep(),Score(Classifier::Get_Ref_Labelling()))
}

double
Eval::Score(const std::vector<size_t>& statevec) {
    std::map<dai::Var, size_t> statemap;
    double lS(Classifier::Run_InLogDomain() ? 0. : 1.);
    for(size_t i(0); i < statevec.size(); i++)
    	statemap[Classifier::Get_Fg()->var(i)] = statevec[i];

    dai::State S(statemap);
    double val(.0);
    for( size_t I(0); I < Classifier::Get_Fg()->nrFactors(); I++ ) {
    	if(Classifier::Run_InLogDomain()) {
    		val = Classifier::Get_Fg()->factor(I)[dai::BigInt_size_t(S(Classifier::Get_Fg()->factor(I).vars()))];
    		lS += val;
    	} else {
    		val = Classifier::Get_Fg()->factor(I)[dai::BigInt_size_t(S(Classifier::Get_Fg()->factor(I).vars()))];
    		lS *= val;
    	}
    }
    return lS;
}

}
