ROOT  6.07/01
Reference Guide
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
ResultsMulticlass.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Jan Therhaag
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : ResultsMulticlass *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation (see header for description) *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland *
16  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
17  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
18  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
19  * *
20  * Copyright (c) 2006: *
21  * CERN, Switzerland *
22  * MPI-K Heidelberg, Germany *
23  * U. of Bonn, Germany *
24  * *
25  * Redistribution and use in source and binary forms, with or without *
26  * modification, are permitted according to the terms listed in LICENSE *
27  * (http://tmva.sourceforge.net/LICENSE) *
28  **********************************************************************************/
29 
30 #include "TMVA/ResultsMulticlass.h"
31 
32 #include "TMVA/DataSet.h"
33 #include "TMVA/DataSetInfo.h"
34 #include "TMVA/GeneticAlgorithm.h"
35 #include "TMVA/GeneticFitter.h"
36 #include "TMVA/MsgLogger.h"
37 #include "TMVA/Tools.h"
38 #include "TMVA/Types.h"
39 
40 #include "TH1F.h"
41 
42 #include <limits>
43 #include <vector>
44 
45 ////////////////////////////////////////////////////////////////////////////////
46 /// constructor
47 
49  : Results( dsi, resultsName ),
50  IFitterTarget(),
51  fLogger( new MsgLogger(Form("ResultsMultiClass%s",resultsName.Data()) , kINFO) ),
52  fClassToOptimize(0),
53  fAchievableEff(dsi->GetNClasses()),
54  fAchievablePur(dsi->GetNClasses()),
55  fBestCuts(dsi->GetNClasses(),std::vector<Double_t>(dsi->GetNClasses()))
56 {
57 }
58 
59 ////////////////////////////////////////////////////////////////////////////////
60 /// destructor
61 
63 {
64  delete fLogger;
65 }
66 
67 ////////////////////////////////////////////////////////////////////////////////
68 
69 void TMVA::ResultsMulticlass::SetValue( std::vector<Float_t>& value, Int_t ievt )
70 {
71  if (ievt >= (Int_t)fMultiClassValues.size()) fMultiClassValues.resize( ievt+1 );
72  fMultiClassValues[ievt] = value;
73 }
74 
75 //_______________________________________________________________________
76 
77 Double_t TMVA::ResultsMulticlass::EstimatorFunction( std::vector<Double_t> & cutvalues ){
78 
79  DataSet* ds = GetDataSet();
80  ds->SetCurrentType( GetTreeType() );
81  Float_t truePositive = 0;
82  Float_t falsePositive = 0;
83  Float_t sumWeights = 0;
84 
85  for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
86  const Event* ev = ds->GetEvent(ievt);
87  Float_t w = ev->GetWeight();
88  if(ev->GetClass()==fClassToOptimize)
89  sumWeights += w;
90  bool passed = true;
91  for(UInt_t icls = 0; icls<cutvalues.size(); ++icls){
92  if(cutvalues.at(icls)<0. ? -fMultiClassValues[ievt][icls]<cutvalues.at(icls) : fMultiClassValues[ievt][icls]<=cutvalues.at(icls)){
93  passed = false;
94  break;
95  }
96  }
97  if(!passed)
98  continue;
99  if(ev->GetClass()==fClassToOptimize)
100  truePositive += w;
101  else
102  falsePositive += w;
103  }
104 
105  Float_t eff = truePositive/sumWeights;
106  Float_t pur = truePositive/(truePositive+falsePositive);
107  Float_t effTimesPur = eff*pur;
108 
110  if( effTimesPur > 0 )
111  toMinimize = 1./(effTimesPur); // we want to minimize 1/efficiency*purity
112 
113  fAchievableEff.at(fClassToOptimize) = eff;
114  fAchievablePur.at(fClassToOptimize) = pur;
115 
116  return toMinimize;
117 }
118 
119 //_______________________________________________________________________
120 
121 std::vector<Double_t> TMVA::ResultsMulticlass::GetBestMultiClassCuts(UInt_t targetClass){
122 
123  //calculate the best working point (optimal cut values)
124  //for the multiclass classifier
125  const DataSetInfo* dsi = GetDataSetInfo();
126  Log() << kINFO << "Calculating best set of cuts for class "
127  << dsi->GetClassInfo( targetClass )->GetName() << Endl;
128 
129  fClassToOptimize = targetClass;
130  std::vector<Interval*> ranges(dsi->GetNClasses(), new Interval(-1,1));
131 
132  const TString name( "MulticlassGA" );
133  const TString opts( "PopSize=100:Steps=30" );
134  GeneticFitter mg( *this, name, ranges, opts);
135 
136  std::vector<Double_t> result;
137  mg.Run(result);
138 
139  fBestCuts.at(targetClass) = result;
140 
141  UInt_t n = 0;
142  for( std::vector<Double_t>::iterator it = result.begin(); it<result.end(); it++ ){
143  Log() << kINFO << " cutValue[" <<dsi->GetClassInfo( n )->GetName() << "] = " << (*it) << ";"<< Endl;
144  n++;
145  }
146 
147  return result;
148 }
149 
150 //_______________________________________________________________________
151 
153 {
154  //this function fills the mva response histos for multiclass classification
155  Log() << kINFO << "Creating multiclass response histograms..." << Endl;
156 
157  DataSet* ds = GetDataSet();
158  ds->SetCurrentType( GetTreeType() );
159  const DataSetInfo* dsi = GetDataSetInfo();
160 
161  std::vector<std::vector<TH1F*> > histos;
162  Float_t xmin = 0.-0.0002;
163  Float_t xmax = 1.+0.0002;
164  for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
165  histos.push_back(std::vector<TH1F*>(0));
166  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
167  TString name(Form("%s_%s_prob_for_%s",prefix.Data(),
168  dsi->GetClassInfo( jCls )->GetName().Data(),
169  dsi->GetClassInfo( iCls )->GetName().Data()));
170  histos.at(iCls).push_back(new TH1F(name,name,nbins,xmin,xmax));
171  }
172  }
173 
174  for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
175  const Event* ev = ds->GetEvent(ievt);
176  Int_t cls = ev->GetClass();
177  Float_t w = ev->GetWeight();
178  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
179  histos.at(cls).at(jCls)->Fill(fMultiClassValues[ievt][jCls],w);
180  }
181  }
182  for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
183  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
184  gTools().NormHist( histos.at(iCls).at(jCls) );
185  Store(histos.at(iCls).at(jCls));
186  }
187  }
188 
189  /*
190  //fill fine binned histos for testing
191  if(prefix.Contains("Test")){
192  std::vector<std::vector<TH1F*> > histos_highbin;
193  for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
194  histos_highbin.push_back(std::vector<TH1F*>(0));
195  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
196  TString name(Form("%s_%s_prob_for_%s_HIGHBIN",prefix.Data(),
197  dsi->GetClassInfo( jCls )->GetName().Data(),
198  dsi->GetClassInfo( iCls )->GetName().Data()));
199  histos_highbin.at(iCls).push_back(new TH1F(name,name,nbins_high,xmin,xmax));
200  }
201  }
202 
203  for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
204  const Event* ev = ds->GetEvent(ievt);
205  Int_t cls = ev->GetClass();
206  Float_t w = ev->GetWeight();
207  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
208  histos_highbin.at(cls).at(jCls)->Fill(fMultiClassValues[ievt][jCls],w);
209  }
210  }
211  for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
212  for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
213  gTools().NormHist( histos_highbin.at(iCls).at(jCls) );
214  Store(histos_highbin.at(iCls).at(jCls));
215  }
216  }
217  }
218  */
219 }
float xmin
Definition: THbookFile.cxx:93
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
float Float_t
Definition: RtypesCore.h:53
Double_t Run(std::vector< Double_t > &pars)
Execute fitting.
UInt_t GetNClasses() const
Definition: DataSetInfo.h:152
Basic string class.
Definition: TString.h:137
1-D histogram with a float per channel (see TH1 documentation)
Definition: TH1.h:570
int Int_t
Definition: RtypesCore.h:41
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not...
Definition: Event.cxx:376
int nbins[3]
const char * Data() const
Definition: TString.h:349
Tools & gTools()
Definition: Tools.cxx:79
std::vector< std::vector< double > > Data
Double_t NormHist(TH1 *theHist, Double_t norm=1.0)
normalises histogram
Definition: Tools.cxx:395
ResultsMulticlass(const DataSetInfo *dsi, TString resultsName)
constructor
TFileCollection * GetDataSet(const char *ds, const char *server="")
GetDataSet wrapper.
Definition: pq2wrappers.cxx:87
void SetValue(std::vector< Float_t > &value, Int_t ievt)
ClassInfo * GetClassInfo(Int_t clNum) const
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
Double_t EstimatorFunction(std::vector< Double_t > &)
tuple w
Definition: qtexample.py:51
const TString & GetName() const
Definition: ClassInfo.h:72
float xmax
Definition: THbookFile.cxx:93
void CreateMulticlassHistos(TString prefix, Int_t nbins, Int_t nbins_high)
const Event * GetEvent() const
Definition: DataSet.cxx:186
void SetCurrentType(Types::ETreeType type) const
Definition: DataSet.h:111
double Double_t
Definition: RtypesCore.h:55
std::vector< Double_t > GetBestMultiClassCuts(UInt_t targetClass)
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:225
UInt_t GetClass() const
Definition: Event.h:86
static Vc_ALWAYS_INLINE int_v max(const int_v &x, const int_v &y)
Definition: vector.h:440
list histos
Definition: hsimple.py:51
#define name(a, b)
Definition: linkTestLib0.cpp:5
double result[121]
float value
Definition: math.cpp:443
const Int_t n
Definition: legend1.C:16
Definition: math.cpp:60