// Begin_Html
/*
Fisher and Mahalanobis Discriminants (Linear Discriminant Analysis)
<p>
In the method of Fisher discriminants event selection is performed
in a transformed variable space with zero linear correlations, by
distinguishing the mean values of the signal and background
distributions.<br></p>
<p>
The linear discriminant analysis determines an axis in the (correlated)
hyperspace of the input variables
such that, when projecting the output classes (signal and background)
upon this axis, they are pushed as far as possible away from each other,
while events of the same class are confined to a close vicinity.
The linearity property of this method is reflected in the metric with
which "far apart" and "close vicinity" are determined: the covariance
matrix of the discriminant variable space.
</p>
<p>
The classification of the events in signal and background classes
relies on the following characteristics (only): overall sample means,
<i><my:o>x</my:o><sub>i</sub></i>, for each input variable, <i>i</i>,
class-specific sample means, <i><my:o>x</my:o><sub>S(B),i</sub></i>,
and total covariance matrix <i>T<sub>ij</sub></i>. The covariance matrix
can be decomposed into the sum of a <i>within-</i> (<i>W<sub>ij</sub></i>)
and a <i>between-class</i> (<i>B<sub>ij</sub></i>) class matrix. They describe
the dispersion of events relative to the means of their own class (within-class
matrix), and relative to the overall sample means (between-class matrix).
The Fisher coefficients, <i>F<sub>i</sub></i>, are then given by <br>
<center>
<img vspace=6 src="gif/tmva_fisherC.gif" align="bottom" >
</center>
where TMVA sets <i>N<sub>S</sub>=N<sub>B</sub></i>, so that the factor
in front of the sum simplifies to ½.
The Fisher discriminant then reads<br>
<center>
<img vspace=6 src="gif/tmva_fisherD.gif" align="bottom" >
</center>
The offset <i>F</i><sub>0</sub> centers the sample mean of <i>x</i><sub>Fi</sub>
at zero. Instead of using the within-class matrix, the Mahalanobis variant
determines the Fisher coefficients as follows:<br>
<center>
<img vspace=6 src="gif/tmva_mahaC.gif" align="bottom" >
</center>
with resulting <i>x</i><sub>Ma</sub> that are very similar to the
<i>x</i><sub>Fi</sub>. <br></p>
TMVA provides two outputs for the ranking of the input variables:<br><p></p>
<ul>
<li> <u>Fisher test:</u> the Fisher analysis aims at simultaneously maximising
the between-class separation, while minimising the within-class dispersion.
A useful measure of the discrimination power of a variable is hence given
by the diagonal quantity: <i>B<sub>ii</sub>/W<sub>ii</sub></i>.
</li>
<li> <u>Discrimination power:</u> the value of the Fisher coefficient is a
measure of the discriminating power of a variable. The discrimination power
of a set of input variables can therefore be measured by the scalar
<center>
<img vspace=6 src="gif/tmva_discpower.gif" align="bottom" >
</center>
</li>
</ul>
The corresponding numbers are printed on standard output.
*/
// End_Html
#include "Riostream.h"
#include <algorithm>
#include "TMVA/MethodFisher.h"
#include "TMVA/Tools.h"
#include "TMatrix.h"
#include "TMVA/Ranking.h"
ClassImp(TMVA::MethodFisher)
;
// Training constructor: forwards job name, title, data set, option string
// and target directory to MethodBase, applies the Fisher defaults, runs the
// option machinery and - if a training tree is available - reports the
// sample sizes and allocates the working matrices.
TMVA::MethodFisher::MethodFisher( TString jobName, TString methodTitle, DataSet& theData,
                                  TString theOption, TDirectory* theTargetDir )
   : TMVA::MethodBase( jobName, methodTitle, theData, theOption, theTargetDir )
   , fTheMethod("Fisher")
{
   // defaults first, then declare / parse / interpret the user options
   InitFisher();
   DeclareOptions();
   ParseOptions();
   ProcessOptions();

   // without a training tree there is nothing left to prepare
   if (!HasTrainingTree()) return;

   const Int_t numSignal     = Data().GetNEvtSigTrain();
   const Int_t numBackground = Data().GetNEvtBkgdTrain();

   fLogger << kVERBOSE << "num of events for training (signal, background): "
           << " (" << numSignal << ", " << numBackground << ")" << Endl;

   // unequal class sizes only trigger a warning; training still proceeds
   if (numSignal != numBackground) {
      const char* const ruler = "\t--------------------------------------------------";
      fLogger << kWARNING << ruler << Endl;
      fLogger << kWARNING << "\tWarning: different number of signal and background\n"
              << "--- " << GetName() << " \tevents: Fisher training will not be optimal :-("
              << Endl;
      fLogger << kWARNING << ruler << Endl;
   }

   // allocate mean / covariance matrices used by Train()
   InitMatrices();
}
// Constructor taking a weight file instead of job name, title and options.
// Only InitFisher() and DeclareOptions() are run here: options are neither
// parsed nor processed, and InitMatrices() is not called, so the training
// matrices stay unallocated for an object built this way.
TMVA::MethodFisher::MethodFisher( DataSet& theData,
TString theWeightFile,
TDirectory* theTargetDir )
: TMVA::MethodBase( theData, theWeightFile, theTargetDir )
, fTheMethod("Fisher")
{
// set method name/type and default member values
InitFisher();
// register the "Method" option and its allowed values
DeclareOptions();
}
void TMVA::MethodFisher::InitFisher( void )
{
SetMethodName( "Fisher" );
SetMethodType( TMVA::Types::kFisher );
SetTestvarName();
fMeanMatx = 0;
fBetw = 0;
fWith = 0;
fCov = 0;
fF0 = 0;
fFisherCoeff = new vector<Double_t>( GetNvar() );
SetSignalReferenceCut( 0.0 );
}
void TMVA::MethodFisher::DeclareOptions()
{
DeclareOptionRef(fTheMethod,"Method","discrimination method");
AddPreDefVal(TString("Fisher"));
AddPreDefVal(TString("Mahalanobis"));
}
void TMVA::MethodFisher::ProcessOptions()
{
MethodBase::ProcessOptions();
if (fTheMethod == "Fisher" ) fFisherMethod = kFisher;
else fFisherMethod = kMahalanobis;
}
// Destructor: releases every heap object owned by this method.
// All pointers are either null (set in InitFisher()) or were allocated in
// InitMatrices() / InitFisher(); deleting a null pointer is a no-op.
TMVA::MethodFisher::~MethodFisher( void )
{
   // fMeanMatx is allocated in InitMatrices() alongside the covariance
   // matrices but was never released here, leaking one matrix per instance.
   delete fMeanMatx;
   delete fBetw;
   delete fWith;
   delete fCov;
   delete fDiscrimPow;
   delete fFisherCoeff;
}
// Runs the complete Fisher training sequence. The steps depend on each
// other and must run in this order: the class means feed both covariance
// matrices, their sum is the full covariance, and the coefficients and
// discrimination powers are derived from those matrices.
void TMVA::MethodFisher::Train( void )
{
// a failed sanity check is reported at kFATAL level
if (!CheckSanity()) fLogger << kFATAL << "<Train> sanity check failed" << Endl;
// per-variable means for signal, background and the combined sample
GetMean();
// within-class matrix (dispersion around the own class mean)
GetCov_WithinClass();
// between-class matrix (class means relative to the overall mean)
GetCov_BetweenClass();
// full covariance = within + between
GetCov_Full();
// Fisher (or Mahalanobis) coefficients and the offset fF0
GetFisherCoeff();
// per-variable discrimination power used for the ranking
GetDiscrimPower();
// print the resulting coefficients
PrintCoefficients();
}
// Evaluates the linear discriminant for the current event:
// offset fF0 plus the sum over all input variables of
// coefficient * normalised event value.
Double_t TMVA::MethodFisher::GetMvaValue()
{
   const Int_t nvar = GetNvar();

   Double_t discriminant = fF0;
   for (Int_t i=0; i<nvar; ++i) {
      const Double_t coefficient = (*fFisherCoeff)[i];
      discriminant += coefficient*GetEventValNormalized( i );
   }

   return discriminant;
}
void TMVA::MethodFisher::InitMatrices( void )
{
if (!HasTrainingTree()) {
fLogger << kFATAL << "<InitMatrices> fatal error: Data().TrainingTree() is zero pointer"
<< " --> exit(1)" << Endl;
}
fMeanMatx = new TMatrixD( GetNvar(), 3 );
fBetw = new TMatrixD( GetNvar(), GetNvar() );
fWith = new TMatrixD( GetNvar(), GetNvar() );
fCov = new TMatrixD( GetNvar(), GetNvar() );
fDiscrimPow = new vector<Double_t>( GetNvar() );
}
void TMVA::MethodFisher::GetMean( void )
{
for (Int_t ivar=0; ivar<GetNvar(); ivar++) {
Double_t sumS = 0;
Double_t sumB = 0;
for (Int_t ievt=0; ievt<Data().GetNEvtTrain(); ievt++) {
ReadTrainingEvent(ievt);
Double_t value = GetEventValNormalized( ivar );
if (Data().Event().IsSignal()) sumS += value;
else sumB += value;
}
(*fMeanMatx)( ivar, 2 ) = sumS;
(*fMeanMatx)( ivar, 0 ) = sumS/Data().GetNEvtSigTrain();
(*fMeanMatx)( ivar, 2 ) += sumB;
(*fMeanMatx)( ivar, 1 ) = sumB/Data().GetNEvtBkgdTrain();
(*fMeanMatx)( ivar, 2 ) /= Data().GetNEvtTrain();
}
}
void TMVA::MethodFisher::GetCov_WithinClass( void )
{
for (Int_t x=0; x<GetNvar(); x++) {
for (Int_t y=0; y<GetNvar(); y++) {
Double_t sumSig = 0;
Double_t sumBgd = 0;
for (Int_t ievt=0; ievt<Data().GetNEvtTrain(); ievt++) {
ReadTrainingEvent(ievt);
if (Data().Event().IsSignal()) {
sumSig += ( (GetEventValNormalized( x ) - (*fMeanMatx)(x, 0))*
(GetEventValNormalized( y ) - (*fMeanMatx)(y, 0)) );
}
else {
sumBgd += ( (GetEventValNormalized( x ) - (*fMeanMatx)(x, 1))*
(GetEventValNormalized( y ) - (*fMeanMatx)(y, 1)) );
}
}
(*fWith)(x, y) = (sumSig + sumBgd)/Data().GetNEvtTrain();
}
}
}
void TMVA::MethodFisher::GetCov_BetweenClass( void )
{
Double_t prodSig, prodBgd;
for (Int_t x=0; x<GetNvar(); x++) {
for (Int_t y=0; y<GetNvar(); y++) {
prodSig = ( ((*fMeanMatx)(x, 0) - (*fMeanMatx)(x, 2))*
((*fMeanMatx)(y, 0) - (*fMeanMatx)(y, 2)) );
prodBgd = ( ((*fMeanMatx)(x, 1) - (*fMeanMatx)(x, 2))*
((*fMeanMatx)(y, 1) - (*fMeanMatx)(y, 2)) );
(*fBetw)(x, y) = ( (Data().GetNEvtSigTrain()*prodSig + Data().GetNEvtBkgdTrain()*prodBgd)
/ Double_t(Data().GetNEvtTrain()) );
}
}
}
void TMVA::MethodFisher::GetCov_Full( void )
{
for (Int_t x=0; x<GetNvar(); x++)
for (Int_t y=0; y<GetNvar(); y++)
(*fCov)(x, y) = (*fWith)(x, y) + (*fBetw)(x, y);
}
void TMVA::MethodFisher::GetFisherCoeff( void )
{
TMatrixD* theMat = 0;
switch (GetFisherMethod()) {
case kFisher:
theMat = fWith;
break;
case kMahalanobis:
theMat = fCov;
break;
default:
fLogger << kFATAL << "<GetFisherCoeff> undefined method" << GetFisherMethod() << Endl;
}
TMatrixD invCov( *theMat );
if ( TMath::Abs(invCov.Determinant()) < 10E-24 ) {
fLogger << kWARNING << "<GetFisherCoeff> matrix is almost singular with deterninant="
<< TMath::Abs(invCov.Determinant())
<< " did you use the variables that are linear combinations or highly correlated ???"
<< Endl;
}
if ( TMath::Abs(invCov.Determinant()) < 10E-120 ) {
fLogger << kFATAL << "<GetFisherCoeff> matrix is singular with determinant="
<< TMath::Abs(invCov.Determinant())
<< " did you use the variables that are linear combinations ???"
<< Endl;
}
invCov.Invert();
Double_t xfact = ( sqrt( Double_t(Data().GetNEvtSigTrain()*Data().GetNEvtBkgdTrain()) )
/ Double_t(Data().GetNEvtTrain()) );
vector<Double_t> diffMeans( GetNvar() );
Int_t ivar, jvar;
for (ivar=0; ivar<GetNvar(); ivar++) {
(*fFisherCoeff)[ivar] = 0;
for(jvar=0; jvar<GetNvar(); jvar++) {
Double_t d = (*fMeanMatx)(jvar, 0) - (*fMeanMatx)(jvar, 1);
(*fFisherCoeff)[ivar] += invCov(ivar, jvar)*d;
}
(*fFisherCoeff)[ivar] *= xfact;
}
fF0 = 0.0;
for(ivar=0; ivar<GetNvar(); ivar++){
fF0 += (*fFisherCoeff)[ivar]*((*fMeanMatx)(ivar, 0) + (*fMeanMatx)(ivar, 1));
}
fF0 /= -2.0;
}
void TMVA::MethodFisher::GetDiscrimPower( void )
{
for (Int_t ivar=0; ivar<GetNvar(); ivar++) {
if ((*fCov)(ivar, ivar) != 0)
(*fDiscrimPow)[ivar] = (*fBetw)(ivar, ivar)/(*fCov)(ivar, ivar);
else
(*fDiscrimPow)[ivar] = 0;
}
}
// Creates a new Ranking object titled "Discr. power" and adds one Rank per
// input variable, built from the variable expression and its
// discrimination power. The result is stored in fRanking and returned.
const TMVA::Ranking* TMVA::MethodFisher::CreateRanking()
{
   fRanking = new Ranking( GetName(), "Discr. power" );

   const Int_t nvar = GetNvar();
   for (Int_t i=0; i<nvar; ++i) {
      // NOTE(review): the heap-allocated Rank is handed over by reference;
      // presumably the Ranking takes ownership - confirm against Ranking::AddRank
      Rank* entry = new Rank( GetInputExp(i), (*fDiscrimPow)[i] );
      fRanking->AddRank( *entry );
   }

   return fRanking;
}
void TMVA::MethodFisher::PrintCoefficients( void )
{
Int_t maxL = 0;
for (Int_t ivar=0; ivar<GetNvar(); ivar++) {
if ((*fInputVars)[ivar].Length() > maxL) maxL = (*fInputVars)[ivar].Length();
}
fLogger << kINFO << "results" << Endl;
fLogger << kINFO << "-------------------------------" << Endl;
fLogger << kINFO << setiosflags(ios::left) << setw(TMath::Max(maxL,10)) << "Variable \t:"
<< " Coefficient:"
<< resetiosflags(ios::right) << Endl;
fLogger << kINFO << "-------------------------------" << Endl;
for (Int_t ivar=0; ivar<GetNvar(); ivar++) {
fLogger << kINFO << Form( "%-11s\t: %+.3f", GetInputExp(ivar).Data(), (*fFisherCoeff)[ivar] )
<< Endl;
}
fLogger << kINFO << Form( "%-11s\t: %+.3f", "(offset)", fF0 ) << Endl;
fLogger << kINFO << "-------------------------------" << Endl;
}
// Writes the trained weights to the stream: first the offset fF0, then one
// coefficient per input variable, each on its own line.
//
// Values are written with 12 significant digits instead of the stream
// default of 6, so a classifier reloaded through ReadWeightsFromStream()
// stays much closer to the trained one. The text layout is unchanged and
// the caller's precision setting is restored afterwards.
void TMVA::MethodFisher::WriteWeightsToStream( ostream& o ) const
{
   // precision(n) installs the new value and returns the previous one
   const std::streamsize savedPrecision = o.precision( 12 );

   o << fF0 << endl;
   for (Int_t ivar=0; ivar<GetNvar(); ivar++) o << (*fFisherCoeff)[ivar] << endl;

   o.precision( savedPrecision );
}
// Reads the weights back in the order they were written: the offset fF0
// first, then one coefficient per input variable.
void TMVA::MethodFisher::ReadWeightsFromStream( istream& istr )
{
   istr >> fF0;

   const Int_t nvar = GetNvar();
   for (Int_t i=0; i<nvar; ++i) {
      Double_t& coefficient = (*fFisherCoeff)[i];
      istr >> coefficient;
   }
}
ROOT page - Class index - Class Hierarchy - Top of the page
This page has been automatically generated. If you have any comments or suggestions about the page layout send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.