ROOT  6.06/09
Reference Guide
DataSetInfo.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Joerg Stelzer, Peter Speckmayer
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : DataSetInfo *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation (see header for description) *
12  * *
13  * Authors (alphabetical): *
14  * Peter Speckmayer <speckmay@mail.cern.ch> - CERN, Switzerland *
15  * Joerg Stelzer <Joerg.Stelzer@cern.ch> - DESY, Germany *
16  * *
17  * Copyright (c) 2008: *
18  * CERN, Switzerland *
19  * MPI-K Heidelberg, Germany *
20  * DESY Hamburg, Germany *
21  * *
22  * Redistribution and use in source and binary forms, with or without *
23  * modification, are permitted according to the terms listed in LICENSE *
24  * (http://tmva.sourceforge.net/LICENSE) *
25  **********************************************************************************/
26 
27 #include <vector>
28 
29 #include "TEventList.h"
30 #include "TFile.h"
31 #include "TH1.h"
32 #include "TH2.h"
33 #include "TProfile.h"
34 #include "TRandom3.h"
35 #include "TMatrixF.h"
36 #include "TVectorF.h"
37 #include "TMath.h"
38 #include "TROOT.h"
39 #include "TObjString.h"
40 
41 #ifndef ROOT_TMVA_MsgLogger
42 #include "TMVA/MsgLogger.h"
43 #endif
44 #ifndef ROOT_TMVA_Tools
45 #include "TMVA/Tools.h"
46 #endif
47 #ifndef ROOT_TMVA_DataSet
48 #include "TMVA/DataSet.h"
49 #endif
50 #ifndef ROOT_TMVA_DataSetInfo
51 #include "TMVA/DataSetInfo.h"
52 #endif
53 #ifndef ROOT_TMVA_DataSetManager
54 #include "TMVA/DataSetManager.h"
55 #endif
56 #ifndef ROOT_TMVA_Event
57 #include "TMVA/Event.h"
58 #endif
59 
60 ////////////////////////////////////////////////////////////////////////////////
61 /// constructor
62 
64  : TObject(),
65  fDataSetManager(NULL),
66  fName(name),
67  fDataSet( 0 ),
68  fNeedsRebuilding( kTRUE ),
69  fVariables(),
70  fTargets(),
71  fSpectators(),
72  fClasses( 0 ),
73  fNormalization( "NONE" ),
74  fSplitOptions(""),
75  fTrainingSumSignalWeights(-1),
76  fTrainingSumBackgrWeights(-1),
77  fTestingSumSignalWeights (-1),
78  fTestingSumBackgrWeights (-1),
79  fOwnRootDir(0),
80  fVerbose( kFALSE ),
81  fSignalClass(0),
82  fTargetsForMulticlass(0),
83  fLogger( new MsgLogger("DataSetInfo", kINFO) )
84 {
85 }
86 
87 ////////////////////////////////////////////////////////////////////////////////
88 /// destructor
89 
91 {
92  ClearDataSet();
93 
94  for(UInt_t i=0, iEnd = fClasses.size(); i<iEnd; ++i) {
95  delete fClasses[i];
96  }
97 
98  delete fTargetsForMulticlass;
99 
100  delete fLogger;
101 }
102 
103 ////////////////////////////////////////////////////////////////////////////////
104 
106 {
107  if(fDataSet!=0) { delete fDataSet; fDataSet=0; }
108 }
109 
110 void
112 {
113  fLogger->SetMinType(t);
114 }
115 
116 ////////////////////////////////////////////////////////////////////////////////
117 
119 {
120  ClassInfo* theClass = GetClassInfo(className);
121  if (theClass) return theClass;
122 
123  fClasses.push_back( new ClassInfo(className) );
124  fClasses.back()->SetNumber(fClasses.size()-1);
125 
126  Log() << kINFO << "Added class \"" << className << "\"\t with internal class number "
127  << fClasses.back()->GetNumber() << Endl;
128 
129  if (className == "Signal") fSignalClass = fClasses.size()-1; // store the signal class index ( for comparison reasons )
130 
131  return fClasses.back();
132 }
133 
134 ////////////////////////////////////////////////////////////////////////////////
135 
137 {
138  for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); it++) {
139  if ((*it)->GetName() == name) return (*it);
140  }
141  return 0;
142 }
143 
144 ////////////////////////////////////////////////////////////////////////////////
145 
147 {
148  try {
149  return fClasses.at(cls);
150  }
151  catch(...) {
152  return 0;
153  }
154 }
155 
156 ////////////////////////////////////////////////////////////////////////////////
157 
159 {
160  for (UInt_t cls = 0; cls < GetNClasses() ; cls++) {
161  Log() << kINFO << "Class index : " << cls << " name : " << GetClassInfo(cls)->GetName() << Endl;
162  }
163 }
164 
165 ////////////////////////////////////////////////////////////////////////////////
166 
168 {
169  return (ev->GetClass() == fSignalClass);
170 }
171 
172 ////////////////////////////////////////////////////////////////////////////////
173 
174 std::vector<Float_t>* TMVA::DataSetInfo::GetTargetsForMulticlass( const TMVA::Event* ev )
175 {
176  if( !fTargetsForMulticlass ) fTargetsForMulticlass = new std::vector<Float_t>( GetNClasses() );
177 // fTargetsForMulticlass->resize( GetNClasses() );
178  fTargetsForMulticlass->assign( GetNClasses(), 0.0 );
179  fTargetsForMulticlass->at( ev->GetClass() ) = 1.0;
180  return fTargetsForMulticlass;
181 }
182 
183 
184 ////////////////////////////////////////////////////////////////////////////////
185 
187 {
188  Bool_t hasCuts = kFALSE;
189  for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); it++) {
190  if( TString((*it)->GetCut()) != TString("") ) hasCuts = kTRUE;
191  }
192  return hasCuts;
193 }
194 
195 ////////////////////////////////////////////////////////////////////////////////
196 
197 const TMatrixD* TMVA::DataSetInfo::CorrelationMatrix( const TString& className ) const
198 {
199  ClassInfo* ptr = GetClassInfo(className);
200  return ptr?ptr->GetCorrelationMatrix():0;
201 }
202 
203 ////////////////////////////////////////////////////////////////////////////////
204 /// add a variable (can be a complex expression) to the set of
205 /// variables used in the MV analysis
206 
208  const TString& title,
209  const TString& unit,
211  char varType,
212  Bool_t normalized,
213  void* external )
214 {
215  TString regexpr = expression; // remove possible blanks
216  regexpr.ReplaceAll(" ", "" );
217  fVariables.push_back(VariableInfo( regexpr, title, unit,
218  fVariables.size()+1, varType, external, min, max, normalized ));
219  fNeedsRebuilding = kTRUE;
220  return fVariables.back();
221 }
222 
223 ////////////////////////////////////////////////////////////////////////////////
224 /// add variable with given VariableInfo
225 
227  fVariables.push_back(VariableInfo( varInfo ));
228  fNeedsRebuilding = kTRUE;
229  return fVariables.back();
230 }
231 
232 ////////////////////////////////////////////////////////////////////////////////
233 /// add a target (can be a complex expression) to the set of
234 /// targets used in the MV analysis
235 
237  const TString& title,
238  const TString& unit,
240  Bool_t normalized,
241  void* external )
242 {
243  TString regexpr = expression; // remove possible blanks
244  regexpr.ReplaceAll(" ", "" );
245  char type='F';
246  fTargets.push_back(VariableInfo( regexpr, title, unit,
247  fTargets.size()+1, type, external, min,
248  max, normalized ));
249  fNeedsRebuilding = kTRUE;
250  return fTargets.back();
251 }
252 
253 ////////////////////////////////////////////////////////////////////////////////
254 /// add target with given VariableInfo
255 
257  fTargets.push_back(VariableInfo( varInfo ));
258  fNeedsRebuilding = kTRUE;
259  return fTargets.back();
260 }
261 
262 ////////////////////////////////////////////////////////////////////////////////
263 /// add a spectator (can be a complex expression) to the set of spectator variables used in
264 /// the MV analysis
265 
267  const TString& title,
268  const TString& unit,
269  Double_t min, Double_t max, char type,
270  Bool_t normalized, void* external )
271 {
272  TString regexpr = expression; // remove possible blanks
273  regexpr.ReplaceAll(" ", "" );
274  fSpectators.push_back(VariableInfo( regexpr, title, unit,
275  fSpectators.size()+1, type, external, min, max, normalized ));
276  fNeedsRebuilding = kTRUE;
277  return fSpectators.back();
278 }
279 
280 ////////////////////////////////////////////////////////////////////////////////
281 /// add spectator with given VariableInfo
282 
284  fSpectators.push_back(VariableInfo( varInfo ));
285  fNeedsRebuilding = kTRUE;
286  return fSpectators.back();
287 }
288 
289 ////////////////////////////////////////////////////////////////////////////////
290 /// find variable by name
291 
293 {
294  for (UInt_t ivar=0; ivar<GetNVariables(); ivar++)
295  if (var == GetVariableInfo(ivar).GetInternalName()) return ivar;
296 
297  for (UInt_t ivar=0; ivar<GetNVariables(); ivar++)
298  Log() << kINFO << GetVariableInfo(ivar).GetInternalName() << Endl;
299 
300  Log() << kFATAL << "<FindVarIndex> Variable \'" << var << "\' not found." << Endl;
301 
302  return -1;
303 }
304 
305 ////////////////////////////////////////////////////////////////////////////////
306 /// set the weight expressions for the classes
307 /// if class name is specified, set only for this class
308 /// if class name is unknown, register new class with this name
309 
310 void TMVA::DataSetInfo::SetWeightExpression( const TString& expr, const TString& className )
311 {
312  if (className != "") {
313  TMVA::ClassInfo* ci = AddClass(className);
314  ci->SetWeight( expr );
315  }
316  else {
317  // no class name specified, set weight for all classes
318  if (fClasses.empty()) {
319  Log() << kWARNING << "No classes registered yet, cannot specify weight expression!" << Endl;
320  }
321  for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); it++) {
322  (*it)->SetWeight( expr );
323  }
324  }
325 }
326 
327 ////////////////////////////////////////////////////////////////////////////////
328 
329 void TMVA::DataSetInfo::SetCorrelationMatrix( const TString& className, TMatrixD* matrix )
330 {
331  GetClassInfo(className)->SetCorrelationMatrix(matrix);
332 }
333 
334 ////////////////////////////////////////////////////////////////////////////////
335 /// set the cut for the classes
336 
337 void TMVA::DataSetInfo::SetCut( const TCut& cut, const TString& className )
338 {
339  if (className == "") { // if no className has been given set the cut for all the classes
340  for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); it++) {
341  (*it)->SetCut( cut );
342  }
343  }
344  else {
345  TMVA::ClassInfo* ci = AddClass(className);
346  ci->SetCut( cut );
347  }
348 }
349 
350 ////////////////////////////////////////////////////////////////////////////////
351 /// add a cut to (combine it with) the existing cut of the classes
352 
353 void TMVA::DataSetInfo::AddCut( const TCut& cut, const TString& className )
354 {
355  if (className == "") { // if no className has been given set the cut for all the classes
356  for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); it++) {
357  const TCut& oldCut = (*it)->GetCut();
358  (*it)->SetCut( oldCut+cut );
359  }
360  }
361  else {
362  TMVA::ClassInfo* ci = AddClass(className);
363  ci->SetCut( ci->GetCut()+cut );
364  }
365 }
366 
367 ////////////////////////////////////////////////////////////////////////////////
368 /// returns list of variables
369 
370 std::vector<TString> TMVA::DataSetInfo::GetListOfVariables() const
371 {
372  std::vector<TString> vNames;
373  std::vector<TMVA::VariableInfo>::const_iterator viIt = GetVariableInfos().begin();
374  for(;viIt != GetVariableInfos().end(); viIt++) vNames.push_back( (*viIt).GetExpression() );
375 
376  return vNames;
377 }
378 
379 ////////////////////////////////////////////////////////////////////////////////
380 /// calculates the correlation matrices for signal and background,
381 /// prints them to standard output, and fills 2D histograms
382 
384 {
385  Log() << kINFO << "Correlation matrix (" << className << "):" << Endl;
386  gTools().FormattedOutput( *CorrelationMatrix( className ), GetListOfVariables(), Log() );
387 }
388 
389 ////////////////////////////////////////////////////////////////////////////////
390 
392  const TString& hName,
393  const TString& hTitle ) const
394 {
395  if (m==0) return 0;
396 
397  const UInt_t nvar = GetNVariables();
398 
399  // workaround till the TMatrix templates are comonly used
400  // this keeps backward compatibility
401  TMatrixF* tm = new TMatrixF( nvar, nvar );
402  for (UInt_t ivar=0; ivar<nvar; ivar++) {
403  for (UInt_t jvar=0; jvar<nvar; jvar++) {
404  (*tm)(ivar, jvar) = (*m)(ivar,jvar);
405  }
406  }
407 
408  TH2F* h2 = new TH2F( *tm );
409  h2->SetNameTitle( hName, hTitle );
410 
411  for (UInt_t ivar=0; ivar<nvar; ivar++) {
412  h2->GetXaxis()->SetBinLabel( ivar+1, GetVariableInfo(ivar).GetTitle() );
413  h2->GetYaxis()->SetBinLabel( ivar+1, GetVariableInfo(ivar).GetTitle() );
414  }
415 
416  // present in percent, and round off digits
417  // also, use absolute value of correlation coefficient (ignore sign)
418  h2->Scale( 100.0 );
419  for (UInt_t ibin=1; ibin<=nvar; ibin++) {
420  for (UInt_t jbin=1; jbin<=nvar; jbin++) {
421  h2->SetBinContent( ibin, jbin, Int_t(h2->GetBinContent( ibin, jbin )) );
422  }
423  }
424 
425  // style settings
426  const Float_t labelSize = 0.055;
427  h2->SetStats( 0 );
428  h2->GetXaxis()->SetLabelSize( labelSize );
429  h2->GetYaxis()->SetLabelSize( labelSize );
430  h2->SetMarkerSize( 1.5 );
431  h2->SetMarkerColor( 0 );
432  h2->LabelsOption( "d" ); // diagonal labels on x axis
433  h2->SetLabelOffset( 0.011 );// label offset on x axis
434  h2->SetMinimum( -100.0 );
435  h2->SetMaximum( +100.0 );
436 
437  // -------------------------------------------------------------------------------------
438  // just in case one wants to change the position of the color palette axis
439  // -------------------------------------------------------------------------------------
440  // gROOT->SetStyle("Plain");
441  // TStyle* gStyle = gROOT->GetStyle( "Plain" );
442  // gStyle->SetPalette( 1, 0 );
443  // TPaletteAxis* paletteAxis
444  // = (TPaletteAxis*)h2->GetListOfFunctions()->FindObject( "palette" );
445  // -------------------------------------------------------------------------------------
446 
447  Log() << kDEBUG << "Created correlation matrix as 2D histogram: " << h2->GetName() << Endl;
448 
449  return h2;
450 }
451 
452 ////////////////////////////////////////////////////////////////////////////////
453 /// returns data set
454 
456 {
457  if (fDataSet==0 || fNeedsRebuilding) {
458  if(fDataSet!=0) ClearDataSet();
459 // fDataSet = DataSetManager::Instance().CreateDataSet(GetName()); //DSMTEST replaced by following lines
460  if( !fDataSetManager )
461  Log() << kFATAL << "DataSetManager has not been set in DataSetInfo (GetDataSet() )." << Endl;
462  fDataSet = fDataSetManager->CreateDataSet(GetName());
463 
464  fNeedsRebuilding = kFALSE;
465  }
466  return fDataSet;
467 }
468 
469 ////////////////////////////////////////////////////////////////////////////////
470 
472 {
473  if(all)
474  return fSpectators.size();
475  UInt_t nsp(0);
476  for(std::vector<VariableInfo>::const_iterator spit=fSpectators.begin(); spit!=fSpectators.end(); ++spit) {
477  if(spit->GetVarType()!='C') nsp++;
478  }
479  return nsp;
480 }
481 
482 ////////////////////////////////////////////////////////////////////////////////
483 
485 {
486  Int_t maxL = 0;
487  for (UInt_t cl = 0; cl < GetNClasses(); cl++) {
488  if (TString(GetClassInfo(cl)->GetName()).Length() > maxL) maxL = TString(GetClassInfo(cl)->GetName()).Length();
489  }
490 
491  return maxL;
492 }
493 
494 
496  if (fTrainingSumSignalWeights<0) Log() << kFATAL << " asking for the sum of training signal event weights which is not initicalised yet" << Endl;
497  return fTrainingSumSignalWeights;
498 }
500  if (fTrainingSumBackgrWeights<0) Log() << kFATAL << " asking for the sum of training backgr event weights which is not initicalised yet" << Endl;
501  return fTrainingSumBackgrWeights;
502 }
504  if (fTestingSumSignalWeights<0) Log() << kFATAL << " asking for the sum of testing signal event weights which is not initicalised yet" << Endl;
505  return fTestingSumSignalWeights ;
506 }
508  if (fTestingSumBackgrWeights<0) Log() << kFATAL << " asking for the sum of testing backgr event weights which is not initicalised yet" << Endl;
509  return fTestingSumBackgrWeights ;
510 }
511 
virtual void SetNameTitle(const char *name, const char *title)
Change the name and title of this histogram.
Definition: TH1.cxx:8303
virtual void Scale(Double_t c1=1, Option_t *option="")
Multiply this histogram by a constant c1.
Definition: TH1.cxx:6174
static Vc_ALWAYS_INLINE int_v min(const int_v &x, const int_v &y)
Definition: vector.h:433
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
VariableInfo & AddTarget(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, Bool_t normalized=kTRUE, void *external=0)
add a target (can be a complex expression) to the set of targets used in the MV analysis
virtual void SetMaximum(Double_t maximum=-1111)
Definition: TH1.h:394
virtual void LabelsOption(Option_t *option="h", Option_t *axis="X")
Set option(s) to draw axis with labels.
Definition: TH1.cxx:4901
Ssiz_t Length() const
Definition: TString.h:390
float Float_t
Definition: RtypesCore.h:53
THist< 2, float > TH2F
Definition: THist.h:321
void SetCut(const TCut &cut, const TString &className)
set the cut for the classes
TMatrixT< Float_t > TMatrixF
Definition: TMatrixFfwd.h:24
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition: TString.h:635
Int_t GetClassNameMaxLength() const
virtual void SetMinimum(Double_t minimum=-1111)
Definition: TH1.h:395
Basic string class.
Definition: TString.h:137
void AddClass(const char *cname, Version_t id, const type_info &info, DictFuncPtr_t dict, Int_t pragmabits)
Global function called by the ctor of a class's init class (see the ClassImp macro).
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
void AddCut(const TCut &cut, const TString &className)
add a cut to (combine it with) the existing cut of the classes
Double_t GetTrainingSumSignalWeights()
DataSet * GetDataSet() const
returns data set
Tools & gTools()
Definition: Tools.cxx:79
void SetWeight(const TString &weight)
Definition: ClassInfo.h:66
Bool_t IsSignal(const Event *ev) const
virtual ~DataSetInfo()
destructor
Definition: DataSetInfo.cxx:90
virtual void SetMarkerColor(Color_t mcolor=1)
Definition: TAttMarker.h:51
void PrintCorrelationMatrix(const TString &className)
prints the correlation matrix of the given class to the logger, labelled with the list of input variables
virtual Double_t GetBinContent(Int_t bin) const
Return content of bin number bin.
Definition: TH2.h:90
DataSetInfo(const TString &name="Default")
constructor
Definition: DataSetInfo.cxx:63
A specialized string object used for TTree selections.
Definition: TCut.h:27
void SetMsgType(EMsgType t) const
void SetCorrelationMatrix(const TString &className, TMatrixD *matrix)
Double_t GetTestingSumSignalWeights()
Service class for 2-Dim histogram classes.
Definition: TH2.h:36
2-D histogram with a float per channel (see TH1 documentation)
Definition: TH2.h:256
ClassInfo * GetClassInfo(Int_t clNum) const
EMsgType
Definition: Types.h:61
const TMatrixD * GetCorrelationMatrix() const
Definition: ClassInfo.h:76
const TMatrixD * CorrelationMatrix(const TString &className) const
void SetWeightExpression(const TString &exp, const TString &className="")
set the weight expressions for the classes: if a class name is specified, set it only for this class; if the class name is unknown, register a new class with this name
unsigned int UInt_t
Definition: RtypesCore.h:42
TMarker * m
Definition: textangle.C:8
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:51
TAxis * GetYaxis()
Definition: TH1.h:320
const TCut & GetCut() const
Definition: ClassInfo.h:74
virtual void SetLabelSize(Float_t size=0.04)
Set size of axis labels. The size is expressed in percent of the pad width.
Definition: TAttAxis.cxx:186
virtual void SetMarkerSize(Size_t msize=1)
Definition: TAttMarker.h:54
void PrintClasses() const
TH2 * CreateCorrelationMatrixHist(const TMatrixD *m, const TString &hName, const TString &hTitle) const
double Double_t
Definition: RtypesCore.h:55
Double_t GetTrainingSumBackgrWeights()
int type
Definition: TGX11.cxx:120
void SetCut(const TCut &cut)
Definition: ClassInfo.h:67
virtual void SetBinLabel(Int_t bin, const char *label)
Set label for bin.
Definition: TAxis.cxx:793
void ClearDataSet() const
ClassInfo * AddClass(const TString &className)
UInt_t GetClass() const
Definition: Event.h:86
static Vc_ALWAYS_INLINE int_v max(const int_v &x, const int_v &y)
Definition: vector.h:440
#define name(a, b)
Definition: linkTestLib0.cpp:5
void FormattedOutput(const std::vector< Double_t > &, const std::vector< TString > &, const TString titleVars, const TString titleValues, MsgLogger &logger, TString format="%+1.3f")
formatted output of simple table
Definition: Tools.cxx:896
Mother of all ROOT objects.
Definition: TObject.h:58
VariableInfo & AddSpectator(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, char type= 'F', Bool_t normalized=kTRUE, void *external=0)
add a spectator (can be a complex expression) to the set of spectator variables used in the MV analysis
UInt_t GetNSpectators(bool all=kTRUE) const
void SetNumber(const UInt_t index)
Definition: ClassInfo.h:68
Int_t FindVarIndex(const TString &) const
find variable by name
#define NULL
Definition: Rtypes.h:82
VariableInfo & AddVariable(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0, char varType='F', Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis
std::vector< TString > GetListOfVariables() const
returns list of variables
virtual void SetBinContent(Int_t bin, Double_t content)
Set bin content.
Definition: TH2.cxx:2689
const Bool_t kTRUE
Definition: Rtypes.h:91
Bool_t HasCuts() const
virtual void SetStats(Bool_t stats=kTRUE)
Set statistics option on/off.
Definition: TH1.cxx:8320
Double_t GetTestingSumBackgrWeights()
Definition: math.cpp:60
TAxis * GetXaxis()
Definition: TH1.h:319
std::vector< Float_t > * GetTargetsForMulticlass(const Event *ev)
virtual void SetLabelOffset(Float_t offset=0.005, Option_t *axis="X")
Set offset between axis and axis' labels.
Definition: Haxis.cxx:261