ROOT  6.06/09
Reference Guide
ResultsMulticlass.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Jan Therhaag
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : ResultsMulticlass *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation (see header for description) *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland *
16  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
17  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19  * *
20  * Copyright (c) 2006: *
21  * CERN, Switzerland *
22  * MPI-K Heidelberg, Germany *
23  * U. of Bonn, Germany *
24  * *
25  * Redistribution and use in source and binary forms, with or without *
26  * modification, are permitted according to the terms listed in LICENSE *
27  * (http://tmva.sourceforge.net/LICENSE) *
28  **********************************************************************************/
29 
30 #include <vector>
31 #include <limits>
32 
33 #include "TMVA/ResultsMulticlass.h"
34 #include "TMVA/MsgLogger.h"
35 #include "TMVA/DataSet.h"
36 #include "TMVA/Tools.h"
37 #include "TMVA/GeneticAlgorithm.h"
38 #include "TMVA/GeneticFitter.h"
39 
40 ////////////////////////////////////////////////////////////////////////////////
41 /// constructor
42 
44  : Results( dsi, resultsName ),
45  IFitterTarget(),
46  fLogger( new MsgLogger(Form("ResultsMultiClass%s",resultsName.Data()) , kINFO) ),
47  fClassToOptimize(0),
48  fAchievableEff(dsi->GetNClasses()),
49  fAchievablePur(dsi->GetNClasses()),
50  fBestCuts(dsi->GetNClasses(),std::vector<Double_t>(dsi->GetNClasses()))
51 {
52 }
53 
54 ////////////////////////////////////////////////////////////////////////////////
55 /// destructor
56 
58 {
59  delete fLogger;
60 }
61 
62 ////////////////////////////////////////////////////////////////////////////////
63 
64 void TMVA::ResultsMulticlass::SetValue( std::vector<Float_t>& value, Int_t ievt )
65 {
66  if (ievt >= (Int_t)fMultiClassValues.size()) fMultiClassValues.resize( ievt+1 );
67  fMultiClassValues[ievt] = value;
68 }
69 
70 //_______________________________________________________________________
71 
72 Double_t TMVA::ResultsMulticlass::EstimatorFunction( std::vector<Double_t> & cutvalues ){
73 
74  DataSet* ds = GetDataSet();
75  ds->SetCurrentType( GetTreeType() );
76  Float_t truePositive = 0;
77  Float_t falsePositive = 0;
78  Float_t sumWeights = 0;
79 
80  for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
81  const Event* ev = ds->GetEvent(ievt);
82  Float_t w = ev->GetWeight();
83  if(ev->GetClass()==fClassToOptimize)
84  sumWeights += w;
85  bool passed = true;
86  for(UInt_t icls = 0; icls<cutvalues.size(); ++icls){
87  if(cutvalues.at(icls)<0. ? -fMultiClassValues[ievt][icls]<cutvalues.at(icls) : fMultiClassValues[ievt][icls]<=cutvalues.at(icls)){
88  passed = false;
89  break;
90  }
91  }
92  if(!passed)
93  continue;
94  if(ev->GetClass()==fClassToOptimize)
95  truePositive += w;
96  else
97  falsePositive += w;
98  }
99 
100  Float_t eff = truePositive/sumWeights;
101  Float_t pur = truePositive/(truePositive+falsePositive);
102  Float_t effTimesPur = eff*pur;
103 
105  if( effTimesPur > 0 )
106  toMinimize = 1./(effTimesPur); // we want to minimize 1/efficiency*purity
107 
108  fAchievableEff.at(fClassToOptimize) = eff;
109  fAchievablePur.at(fClassToOptimize) = pur;
110 
111  return toMinimize;
112 }
113 
114 //_______________________________________________________________________
115 
116 std::vector<Double_t> TMVA::ResultsMulticlass::GetBestMultiClassCuts(UInt_t targetClass){
117 
118  //calculate the best working point (optimal cut values)
119  //for the multiclass classifier
120  const DataSetInfo* dsi = GetDataSetInfo();
121  Log() << kINFO << "Calculating best set of cuts for class "
122  << dsi->GetClassInfo( targetClass )->GetName() << Endl;
123 
124  fClassToOptimize = targetClass;
125  std::vector<Interval*> ranges(dsi->GetNClasses(), new Interval(-1,1));
126 
127  const TString name( "MulticlassGA" );
128  const TString opts( "PopSize=100:Steps=30" );
129  GeneticFitter mg( *this, name, ranges, opts);
130 
131  std::vector<Double_t> result;
132  mg.Run(result);
133 
134  fBestCuts.at(targetClass) = result;
135 
136  UInt_t n = 0;
137  for( std::vector<Double_t>::iterator it = result.begin(); it<result.end(); it++ ){
138  Log() << kINFO << " cutValue[" <<dsi->GetClassInfo( n )->GetName() << "] = " << (*it) << ";"<< Endl;
139  n++;
140  }
141 
142  return result;
143 }
144 
145 //_______________________________________________________________________
146 
148 {
149  //this function fills the mva response histos for multiclass classification
150  Log() << kINFO << "Creating multiclass response histograms..." << Endl;
151 
152  DataSet* ds = GetDataSet();
153  ds->SetCurrentType( GetTreeType() );
154  const DataSetInfo* dsi = GetDataSetInfo();
155 
156  std::vector<std::vector<TH1F*> > histos;
157  Float_t xmin = 0.-0.0002;
158  Float_t xmax = 1.+0.0002;
159  for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
160  histos.push_back(std::vector<TH1F*>(0));
161  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
162  TString name(Form("%s_%s_prob_for_%s",prefix.Data(),
163  dsi->GetClassInfo( jCls )->GetName().Data(),
164  dsi->GetClassInfo( iCls )->GetName().Data()));
165  histos.at(iCls).push_back(new TH1F(name,name,nbins,xmin,xmax));
166  }
167  }
168 
169  for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
170  const Event* ev = ds->GetEvent(ievt);
171  Int_t cls = ev->GetClass();
172  Float_t w = ev->GetWeight();
173  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
174  histos.at(cls).at(jCls)->Fill(fMultiClassValues[ievt][jCls],w);
175  }
176  }
177  for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
178  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
179  gTools().NormHist( histos.at(iCls).at(jCls) );
180  Store(histos.at(iCls).at(jCls));
181  }
182  }
183 
184  /*
185  //fill fine binned histos for testing
186  if(prefix.Contains("Test")){
187  std::vector<std::vector<TH1F*> > histos_highbin;
188  for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
189  histos_highbin.push_back(std::vector<TH1F*>(0));
190  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
191  TString name(Form("%s_%s_prob_for_%s_HIGHBIN",prefix.Data(),
192  dsi->GetClassInfo( jCls )->GetName().Data(),
193  dsi->GetClassInfo( iCls )->GetName().Data()));
194  histos_highbin.at(iCls).push_back(new TH1F(name,name,nbins_high,xmin,xmax));
195  }
196  }
197 
198  for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
199  const Event* ev = ds->GetEvent(ievt);
200  Int_t cls = ev->GetClass();
201  Float_t w = ev->GetWeight();
202  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
203  histos_highbin.at(cls).at(jCls)->Fill(fMultiClassValues[ievt][jCls],w);
204  }
205  }
206  for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
207  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
208  gTools().NormHist( histos_highbin.at(iCls).at(jCls) );
209  Store(histos_highbin.at(iCls).at(jCls));
210  }
211  }
212  }
213  */
214 }
float xmin
Definition: THbookFile.cxx:93
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
float Float_t
Definition: RtypesCore.h:53
Double_t Run(std::vector< Double_t > &pars)
Execute fitting.
UInt_t GetNClasses() const
Definition: DataSetInfo.h:152
THist< 1, float > TH1F
Definition: THist.h:315
Basic string class.
Definition: TString.h:137
int Int_t
Definition: RtypesCore.h:41
int nbins[3]
STL namespace.
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not...
Definition: Event.cxx:376
const char * Data() const
Definition: TString.h:349
Tools & gTools()
Definition: Tools.cxx:79
std::vector< std::vector< double > > Data
Double_t NormHist(TH1 *theHist, Double_t norm=1.0)
normalises histogram
Definition: Tools.cxx:395
ResultsMulticlass(const DataSetInfo *dsi, TString resultsName)
constructor
TFileCollection * GetDataSet(const char *ds, const char *server="")
GetDataSet wrapper.
Definition: pq2wrappers.cxx:87
void SetValue(std::vector< Float_t > &value, Int_t ievt)
ClassInfo * GetClassInfo(Int_t clNum) const
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
Double_t EstimatorFunction(std::vector< Double_t > &)
const TString & GetName() const
Definition: ClassInfo.h:72
float xmax
Definition: THbookFile.cxx:93
void CreateMulticlassHistos(TString prefix, Int_t nbins, Int_t nbins_high)
const Event * GetEvent() const
Definition: DataSet.cxx:180
void SetCurrentType(Types::ETreeType type) const
Definition: DataSet.h:111
double Double_t
Definition: RtypesCore.h:55
std::vector< Double_t > GetBestMultiClassCuts(UInt_t targetClass)
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:225
UInt_t GetClass() const
Definition: Event.h:86
static Vc_ALWAYS_INLINE int_v max(const int_v &x, const int_v &y)
Definition: vector.h:440
#define name(a, b)
Definition: linkTestLib0.cpp:5
double result[121]
float value
Definition: math.cpp:443
const Int_t n
Definition: legend1.C:16
Definition: math.cpp:60