Logo ROOT   6.08/07
Reference Guide
ResultsMulticlass.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Jan Therhaag
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : ResultsMulticlass *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation (see header for description) *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland *
16  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
17  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19  * *
20  * Copyright (c) 2006: *
21  * CERN, Switzerland *
22  * MPI-K Heidelberg, Germany *
23  * U. of Bonn, Germany *
24  * *
25  * Redistribution and use in source and binary forms, with or without *
26  * modification, are permitted according to the terms listed in LICENSE *
27  * (http://tmva.sourceforge.net/LICENSE) *
28  **********************************************************************************/
29 
30 #include "TMVA/ResultsMulticlass.h"
31 
32 #include "TMVA/DataSet.h"
33 #include "TMVA/DataSetInfo.h"
34 #include "TMVA/GeneticAlgorithm.h"
35 #include "TMVA/GeneticFitter.h"
36 #include "TMVA/MsgLogger.h"
37 #include "TMVA/Results.h"
38 #include "TMVA/Tools.h"
39 #include "TMVA/Types.h"
40 
41 #include "TH1F.h"
42 
43 #include <limits>
44 #include <vector>
45 
46 
47 ////////////////////////////////////////////////////////////////////////////////
48 /// constructor
49 
51  : Results( dsi, resultsName ),
52  IFitterTarget(),
53  fLogger( new MsgLogger(Form("ResultsMultiClass%s",resultsName.Data()) , kINFO) ),
54  fClassToOptimize(0),
55  fAchievableEff(dsi->GetNClasses()),
56  fAchievablePur(dsi->GetNClasses()),
57  fBestCuts(dsi->GetNClasses(),std::vector<Double_t>(dsi->GetNClasses()))
58 {
59 }
60 
61 ////////////////////////////////////////////////////////////////////////////////
62 /// destructor
63 
65 {
66  delete fLogger;
67 }
68 
69 ////////////////////////////////////////////////////////////////////////////////
70 
71 void TMVA::ResultsMulticlass::SetValue( std::vector<Float_t>& value, Int_t ievt )
72 {
73  if (ievt >= (Int_t)fMultiClassValues.size()) fMultiClassValues.resize( ievt+1 );
74  fMultiClassValues[ievt] = value;
75 }
76 
77 //_______________________________________________________________________
78 
79 Double_t TMVA::ResultsMulticlass::EstimatorFunction( std::vector<Double_t> & cutvalues ){
80 
81  DataSet* ds = GetDataSet();
82  ds->SetCurrentType( GetTreeType() );
83  Float_t truePositive = 0;
84  Float_t falsePositive = 0;
85  Float_t sumWeights = 0;
86 
87  for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
88  const Event* ev = ds->GetEvent(ievt);
89  Float_t w = ev->GetWeight();
90  if(ev->GetClass()==fClassToOptimize)
91  sumWeights += w;
92  bool passed = true;
93  for(UInt_t icls = 0; icls<cutvalues.size(); ++icls){
94  if(cutvalues.at(icls)<0. ? -fMultiClassValues[ievt][icls]<cutvalues.at(icls) : fMultiClassValues[ievt][icls]<=cutvalues.at(icls)){
95  passed = false;
96  break;
97  }
98  }
99  if(!passed)
100  continue;
101  if(ev->GetClass()==fClassToOptimize)
102  truePositive += w;
103  else
104  falsePositive += w;
105  }
106 
107  Float_t eff = truePositive/sumWeights;
108  Float_t pur = truePositive/(truePositive+falsePositive);
109  Float_t effTimesPur = eff*pur;
110 
111  Float_t toMinimize = std::numeric_limits<float>::max();
112  if( effTimesPur > 0 )
113  toMinimize = 1./(effTimesPur); // we want to minimize 1/efficiency*purity
114 
117 
118  return toMinimize;
119 }
120 
121 //_______________________________________________________________________
122 
123 std::vector<Double_t> TMVA::ResultsMulticlass::GetBestMultiClassCuts(UInt_t targetClass){
124 
125  //calculate the best working point (optimal cut values)
126  //for the multiclass classifier
127  const DataSetInfo* dsi = GetDataSetInfo();
128  Log() << kINFO << "Calculating best set of cuts for class "
129  << dsi->GetClassInfo( targetClass )->GetName() << Endl;
130 
131  fClassToOptimize = targetClass;
132  std::vector<Interval*> ranges(dsi->GetNClasses(), new Interval(-1,1));
133 
134  const TString name( "MulticlassGA" );
135  const TString opts( "PopSize=100:Steps=30" );
136  GeneticFitter mg( *this, name, ranges, opts);
137 
138  std::vector<Double_t> result;
139  mg.Run(result);
140 
141  fBestCuts.at(targetClass) = result;
142 
143  UInt_t n = 0;
144  for( std::vector<Double_t>::iterator it = result.begin(); it<result.end(); it++ ){
145  Log() << kINFO << " cutValue[" <<dsi->GetClassInfo( n )->GetName() << "] = " << (*it) << ";"<< Endl;
146  n++;
147  }
148 
149  return result;
150 }
151 
152 //_______________________________________________________________________
153 
155 {
156  //this function fills the mva response histos for multiclass classification
157  Log() << kINFO << "Creating multiclass response histograms..." << Endl;
158 
159  DataSet* ds = GetDataSet();
160  ds->SetCurrentType( GetTreeType() );
161  const DataSetInfo* dsi = GetDataSetInfo();
162 
163  std::vector<std::vector<TH1F*> > histos;
164  Float_t xmin = 0.-0.0002;
165  Float_t xmax = 1.+0.0002;
166  for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
167  histos.push_back(std::vector<TH1F*>(0));
168  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
169  TString name(Form("%s_%s_prob_for_%s",prefix.Data(),
170  dsi->GetClassInfo( jCls )->GetName(),
171  dsi->GetClassInfo( iCls )->GetName()));
172  histos.at(iCls).push_back(new TH1F(name,name,nbins,xmin,xmax));
173  }
174  }
175 
176  for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
177  const Event* ev = ds->GetEvent(ievt);
178  Int_t cls = ev->GetClass();
179  Float_t w = ev->GetWeight();
180  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
181  histos.at(cls).at(jCls)->Fill(fMultiClassValues[ievt][jCls],w);
182  }
183  }
184  for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
185  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
186  gTools().NormHist( histos.at(iCls).at(jCls) );
187  Store(histos.at(iCls).at(jCls));
188  }
189  }
190 
191  /*
192  //fill fine binned histos for testing
193  if(prefix.Contains("Test")){
194  std::vector<std::vector<TH1F*> > histos_highbin;
195  for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
196  histos_highbin.push_back(std::vector<TH1F*>(0));
197  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
198  TString name(Form("%s_%s_prob_for_%s_HIGHBIN",prefix.Data(),
199  dsi->GetClassInfo( jCls )->GetName().Data(),
200  dsi->GetClassInfo( iCls )->GetName().Data()));
201  histos_highbin.at(iCls).push_back(new TH1F(name,name,nbins_high,xmin,xmax));
202  }
203  }
204 
205  for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
206  const Event* ev = ds->GetEvent(ievt);
207  Int_t cls = ev->GetClass();
208  Float_t w = ev->GetWeight();
209  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
210  histos_highbin.at(cls).at(jCls)->Fill(fMultiClassValues[ievt][jCls],w);
211  }
212  }
213  for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
214  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
215  gTools().NormHist( histos_highbin.at(iCls).at(jCls) );
216  Store(histos_highbin.at(iCls).at(jCls));
217  }
218  }
219  }
220  */
221 }
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:51
float xmin
Definition: THbookFile.cxx:93
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
float Float_t
Definition: RtypesCore.h:53
std::vector< Float_t > fAchievablePur
Double_t Run(std::vector< Double_t > &pars)
Execute fitting.
THist< 1, float, THistStatContent, THistStatUncertainty > TH1F
Definition: THist.hxx:302
const DataSetInfo * GetDataSetInfo() const
Definition: Results.h:77
DataSet * GetDataSet() const
Definition: Results.h:78
Basic string class.
Definition: TString.h:137
int Int_t
Definition: RtypesCore.h:41
UInt_t GetNClasses() const
Definition: DataSetInfo.h:154
int nbins[3]
std::vector< std::vector< Double_t > > fBestCuts
STL namespace.
Tools & gTools()
Definition: Tools.cxx:79
UInt_t GetClass() const
Definition: Event.h:89
std::vector< std::vector< double > > Data
Double_t NormHist(TH1 *theHist, Double_t norm=1.0)
normalises histogram
Definition: Tools.cxx:395
MsgLogger & Log() const
message logger
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not...
Definition: Event.cxx:378
ResultsMulticlass(const DataSetInfo *dsi, TString resultsName)
constructor
ClassInfo * GetClassInfo(Int_t clNum) const
std::vector< std::vector< Float_t > > fMultiClassValues
void SetValue(std::vector< Float_t > &value, Int_t ievt)
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
Double_t EstimatorFunction(std::vector< Double_t > &)
float xmax
Definition: THbookFile.cxx:93
void CreateMulticlassHistos(TString prefix, Int_t nbins, Int_t nbins_high)
std::vector< Float_t > fAchievableEff
double Double_t
Definition: RtypesCore.h:55
std::vector< Double_t > GetBestMultiClassCuts(UInt_t targetClass)
void SetCurrentType(Types::ETreeType type) const
Definition: DataSet.h:114
Types::ETreeType GetTreeType() const
Definition: Results.h:76
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:229
double result[121]
void Store(TObject *obj, const char *alias=0)
Definition: Results.cxx:83
const Int_t n
Definition: legend1.C:16
const Event * GetEvent() const
Definition: DataSet.cxx:211
char name[80]
Definition: TGX11.cxx:109
const char * Data() const
Definition: TString.h:349