Logo ROOT   6.07/09
Reference Guide
DataLoader.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata
3 // Mentors: Lorenzo Moneta, Sergei Gleyzer
4 //NOTE: Based on TMVA::Factory
5 
6 /**********************************************************************************
7  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
8  * Package: TMVA *
9  * Class : DataLoader *
10  * Web : http://tmva.sourceforge.net *
11  * *
12  * Description: *
13  * This is a class to load datasets into every booked method *
14  * *
15  * Authors (alphabetical): *
16  * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
17  * Omar Zapata <Omar.Zapata@cern.ch> - ITM/UdeA, Colombia *
18  * Sergei Gleyzer<sergei.gleyzer@cern.ch> - CERN, Switzerland *
19  * *
20  * Copyright (c) 2005-2015: *
21  * CERN, Switzerland *
22  * ITM/UdeA, Colombia *
23  * *
24  * Redistribution and use in source and binary forms, with or without *
25  * modification, are permitted according to the terms listed in LICENSE *
26  * (http://tmva.sourceforge.net/LICENSE) *
27  **********************************************************************************/
28 
29 
30 #include "TROOT.h"
31 #include "TFile.h"
32 #include "TTree.h"
33 #include "TLeaf.h"
34 #include "TEventList.h"
35 #include "TH2.h"
36 #include "TText.h"
37 #include "TStyle.h"
38 #include "TMatrixF.h"
39 #include "TMatrixDSym.h"
40 #include "TPaletteAxis.h"
41 #include "TPrincipal.h"
42 #include "TMath.h"
43 #include "TObjString.h"
44 #include "TRandom3.h"
45 
46 #include <string.h>
47 
48 #include "TMVA/Configurable.h"
49 #include "TMVA/DataLoader.h"
50 #include "TMVA/Config.h"
51 #include "TMVA/Tools.h"
52 #include "TMVA/Ranking.h"
53 #include "TMVA/DataSet.h"
54 #include "TMVA/IMethod.h"
55 #include "TMVA/MethodBase.h"
56 #include "TMVA/DataInputHandler.h"
57 #include "TMVA/DataSetManager.h"
58 #include "TMVA/DataSetInfo.h"
59 #include "TMVA/MethodBoost.h"
60 #include "TMVA/MethodCategory.h"
61 
62 #include "TMVA/VariableInfo.h"
69 
70 
72 #include "TMVA/ResultsRegression.h"
73 #include "TMVA/ResultsMulticlass.h"
74 #include "TMVA/Types.h"
75 
76 
78 
79 
80 //_______________________________________________________________________
82 : Configurable( ),
83  fDataSetManager ( NULL ), //DSMTEST
84  fDataInputHandler ( new DataInputHandler ),
85  fTransformations ( "I" ),
86  fVerbose ( kFALSE ),
87  fDataAssignType ( kAssignEvents ),
88  fATreeEvent (0)
89 {
90  fDataSetManager = new DataSetManager( *fDataInputHandler ); // DSMTEST
91  SetName(thedlName.Data());
92  fLogger->SetSource("DataLoader");
93 }
94 
95 
96 //_______________________________________________________________________
98 {
99  // destructor
100 
101  std::vector<TMVA::VariableTransformBase*>::iterator trfIt = fDefaultTrfs.begin();
102  for (;trfIt != fDefaultTrfs.end(); trfIt++) delete (*trfIt);
103 
104  delete fDataInputHandler;
105 
106  // destroy singletons
107  // DataSetManager::DestroyInstance(); // DSMTEST replaced by following line
108  delete fDataSetManager; // DSMTEST
109 
110  // problem with call of REGISTER_METHOD macro ...
111  // ClassifierDataLoader::DestroyInstance();
112  // Types::DestroyInstance();
115 }
116 
117 
118 //_______________________________________________________________________
120 {
121  return fDataSetManager->AddDataSetInfo(dsi); // DSMTEST
122 }
123 
124 //_______________________________________________________________________
126 {
127  DataSetInfo* dsi = fDataSetManager->GetDataSetInfo(dsiName); // DSMTEST
128 
129  if (dsi!=0) return *dsi;
130 
131  return fDataSetManager->AddDataSetInfo(*(new DataSetInfo(dsiName))); // DSMTEST
132 }
133 
134 //_______________________________________________________________________
136 {
137  return DefaultDataSetInfo(); // DSMTEST
138 }
139 
140 ////////////////////////////////////////////////////////////////////////////////
141 /// Transforms the variables and returns a new DataLoader with the transformed
142 /// variables.
143 
145 {
146  TString trOptions = "0";
147  TString trName = "None";
148  if (trafoDefinition.Contains("(")) {
149 
150  // contains transformation parameters
151  Ssiz_t parStart = trafoDefinition.Index( "(" );
152  Ssiz_t parLen = trafoDefinition.Index( ")", parStart )-parStart+1;
153 
154  trName = trafoDefinition(0,parStart);
155  trOptions = trafoDefinition(parStart,parLen);
156  trOptions.Remove(parLen-1,1);
157  trOptions.Remove(0,1);
158  }
159  else
160  trName = trafoDefinition;
161 
162  VarTransformHandler* handler = new VarTransformHandler(this);
163  // variance threshold variable transformation
164  if (trName == "VT") {
165 
166  // find threshold value from given input
167  Double_t threshold = 0.0;
168  if (!trOptions.IsFloat()){
169  Log() << kFATAL << " VT transformation must be passed a floating threshold value" << Endl;
170  return this;
171  }
172  else
173  threshold = trOptions.Atof();
174  TMVA::DataLoader *transformedLoader = handler->VarianceThreshold(threshold);
175  return transformedLoader;
176  }
177  else {
178  Log() << kFATAL << "Incorrect transformation string provided, please check" << Endl;
179  }
180  Log() << kINFO << "No transformation applied, returning original loader" << Endl;
181  return this;
182 }
183 
184 // ________________________________________________
185 // the next functions are to assign events directly
186 
187 //_______________________________________________________________________
189 {
190  // create the data assignment tree (for event-wise data assignment by user)
191  TTree * assignTree = new TTree( name, name );
192  assignTree->SetDirectory(0);
193  assignTree->Branch( "type", &fATreeType, "ATreeType/I" );
194  assignTree->Branch( "weight", &fATreeWeight, "ATreeWeight/F" );
195 
196  std::vector<VariableInfo>& vars = DefaultDataSetInfo().GetVariableInfos();
197  std::vector<VariableInfo>& tgts = DefaultDataSetInfo().GetTargetInfos();
198  std::vector<VariableInfo>& spec = DefaultDataSetInfo().GetSpectatorInfos();
199 
200  if (fATreeEvent.size()==0) fATreeEvent.resize(vars.size()+tgts.size()+spec.size());
201  // add variables
202  for (UInt_t ivar=0; ivar<vars.size(); ivar++) {
203  TString vname = vars[ivar].GetExpression();
204  assignTree->Branch( vname, &fATreeEvent[ivar], vname + "/F" );
205  }
206  // add targets
207  for (UInt_t itgt=0; itgt<tgts.size(); itgt++) {
208  TString vname = tgts[itgt].GetExpression();
209  assignTree->Branch( vname, &fATreeEvent[vars.size()+itgt], vname + "/F" );
210  }
211  // add spectators
212  for (UInt_t ispc=0; ispc<spec.size(); ispc++) {
213  TString vname = spec[ispc].GetExpression();
214  assignTree->Branch( vname, &fATreeEvent[vars.size()+tgts.size()+ispc], vname + "/F" );
215  }
216  return assignTree;
217 }
218 
219 //_______________________________________________________________________
220 void TMVA::DataLoader::AddSignalTrainingEvent( const std::vector<Double_t>& event, Double_t weight )
221 {
222  // add signal training event
223  AddEvent( "Signal", Types::kTraining, event, weight );
224 }
225 
226 //_______________________________________________________________________
227 void TMVA::DataLoader::AddSignalTestEvent( const std::vector<Double_t>& event, Double_t weight )
228 {
229  // add signal testing event
230  AddEvent( "Signal", Types::kTesting, event, weight );
231 }
232 
233 //_______________________________________________________________________
234 void TMVA::DataLoader::AddBackgroundTrainingEvent( const std::vector<Double_t>& event, Double_t weight )
235 {
236  // add signal training event
237  AddEvent( "Background", Types::kTraining, event, weight );
238 }
239 
240 //_______________________________________________________________________
241 void TMVA::DataLoader::AddBackgroundTestEvent( const std::vector<Double_t>& event, Double_t weight )
242 {
243  // add signal training event
244  AddEvent( "Background", Types::kTesting, event, weight );
245 }
246 
247 //_______________________________________________________________________
248 void TMVA::DataLoader::AddTrainingEvent( const TString& className, const std::vector<Double_t>& event, Double_t weight )
249 {
250  // add signal training event
251  AddEvent( className, Types::kTraining, event, weight );
252 }
253 
254 //_______________________________________________________________________
255 void TMVA::DataLoader::AddTestEvent( const TString& className, const std::vector<Double_t>& event, Double_t weight )
256 {
257  // add signal test event
258  AddEvent( className, Types::kTesting, event, weight );
259 }
260 
261 //_______________________________________________________________________
263  const std::vector<Double_t>& event, Double_t weight )
264 {
265  // add event
266  // vector event : the order of values is: variables + targets + spectators
267  ClassInfo* theClass = DefaultDataSetInfo().AddClass(className); // returns class (creates it if necessary)
268  UInt_t clIndex = theClass->GetNumber();
269 
270 
271  // set analysistype to "kMulticlass" if more than two classes and analysistype == kNoAnalysisType
272  if( fAnalysisType == Types::kNoAnalysisType && DefaultDataSetInfo().GetNClasses() > 2 )
274 
275 
276  if (clIndex>=fTrainAssignTree.size()) {
277  fTrainAssignTree.resize(clIndex+1, 0);
278  fTestAssignTree.resize(clIndex+1, 0);
279  }
280 
281  if (fTrainAssignTree[clIndex]==0) { // does not exist yet
282  fTrainAssignTree[clIndex] = CreateEventAssignTrees( Form("TrainAssignTree_%s", className.Data()) );
283  fTestAssignTree[clIndex] = CreateEventAssignTrees( Form("TestAssignTree_%s", className.Data()) );
284  }
285 
286  fATreeType = clIndex;
287  fATreeWeight = weight;
288  for (UInt_t ivar=0; ivar<event.size(); ivar++) fATreeEvent[ivar] = event[ivar];
289 
290  if(tt==Types::kTraining) fTrainAssignTree[clIndex]->Fill();
291  else fTestAssignTree[clIndex]->Fill();
292 
293 }
294 
295 //_______________________________________________________________________
297 {
298  //
299  return fTrainAssignTree[clIndex]!=0;
300 }
301 
302 //_______________________________________________________________________
304 {
305  // assign event-wise local trees to data set
306  UInt_t size = fTrainAssignTree.size();
307  for(UInt_t i=0; i<size; i++) {
308  if(!UserAssignEvents(i)) continue;
309  const TString& className = DefaultDataSetInfo().GetClassInfo(i)->GetName();
310  SetWeightExpression( "weight", className );
311  AddTree(fTrainAssignTree[i], className, 1.0, TCut(""), Types::kTraining );
312  AddTree(fTestAssignTree[i], className, 1.0, TCut(""), Types::kTesting );
313  }
314 }
315 
316 //_______________________________________________________________________
317 void TMVA::DataLoader::AddTree( TTree* tree, const TString& className, Double_t weight,
318  const TCut& cut, const TString& treetype )
319 {
320  // number of signal events (used to compute significance)
322  TString tmpTreeType = treetype; tmpTreeType.ToLower();
323  if (tmpTreeType.Contains( "train" ) && tmpTreeType.Contains( "test" )) tt = Types::kMaxTreeType;
324  else if (tmpTreeType.Contains( "train" )) tt = Types::kTraining;
325  else if (tmpTreeType.Contains( "test" )) tt = Types::kTesting;
326  else {
327  Log() << kFATAL << "<AddTree> cannot interpret tree type: \"" << treetype
328  << "\" should be \"Training\" or \"Test\" or \"Training and Testing\"" << Endl;
329  }
330  AddTree( tree, className, weight, cut, tt );
331 }
332 
333 //_______________________________________________________________________
334 void TMVA::DataLoader::AddTree( TTree* tree, const TString& className, Double_t weight,
335  const TCut& cut, Types::ETreeType tt )
336 {
337  if(!tree)
338  Log() << kFATAL << "Tree does not exist (empty pointer)." << Endl;
339 
340  DefaultDataSetInfo().AddClass( className );
341 
342  // set analysistype to "kMulticlass" if more than two classes and analysistype == kNoAnalysisType
343  if( fAnalysisType == Types::kNoAnalysisType && DefaultDataSetInfo().GetNClasses() > 2 )
345 
346  Log() << kINFO<< "Add Tree " << tree->GetName() << " of type " << className
347  << " with " << tree->GetEntries() << " events" << Endl;
348  DataInput().AddTree( tree, className, weight, cut, tt );
349 }
350 
351 //_______________________________________________________________________
353 {
354  // number of signal events (used to compute significance)
355  AddTree( signal, "Signal", weight, TCut(""), treetype );
356 }
357 
358 //_______________________________________________________________________
360 {
361  // add signal tree from text file
362 
363  // create trees from these ascii files
364  TTree* signalTree = new TTree( "TreeS", "Tree (S)" );
365  signalTree->ReadFile( datFileS );
366 
367  Log() << kINFO << "Create TTree objects from ASCII input files ... \n- Signal file : \""
368  << datFileS << Endl;
369 
370  // number of signal events (used to compute significance)
371  AddTree( signalTree, "Signal", weight, TCut(""), treetype );
372 }
373 
374 //_______________________________________________________________________
375 void TMVA::DataLoader::AddSignalTree( TTree* signal, Double_t weight, const TString& treetype )
376 {
377  AddTree( signal, "Signal", weight, TCut(""), treetype );
378 }
379 
380 //_______________________________________________________________________
382 {
383  // number of signal events (used to compute significance)
384  AddTree( signal, "Background", weight, TCut(""), treetype );
385 }
386 //_______________________________________________________________________
388 {
389  // add background tree from text file
390 
391  // create trees from these ascii files
392  TTree* bkgTree = new TTree( "TreeB", "Tree (B)" );
393  bkgTree->ReadFile( datFileB );
394 
395  Log() << kINFO << "Create TTree objects from ASCII input files ... \n- Background file : \""
396  << datFileB << Endl;
397 
398  // number of signal events (used to compute significance)
399  AddTree( bkgTree, "Background", weight, TCut(""), treetype );
400 }
401 
402 //_______________________________________________________________________
403 void TMVA::DataLoader::AddBackgroundTree( TTree* signal, Double_t weight, const TString& treetype )
404 {
405  AddTree( signal, "Background", weight, TCut(""), treetype );
406 }
407 
408 //_______________________________________________________________________
410 {
411  AddTree( tree, "Signal", weight );
412 }
413 
414 //_______________________________________________________________________
416 {
417  AddTree( tree, "Background", weight );
418 }
419 
420 //_______________________________________________________________________
421 void TMVA::DataLoader::SetTree( TTree* tree, const TString& className, Double_t weight )
422 {
423  // set background tree
424  AddTree( tree, className, weight, TCut(""), Types::kMaxTreeType );
425 }
426 
427 //_______________________________________________________________________
429  Double_t signalWeight, Double_t backgroundWeight )
430 {
431  // define the input trees for signal and background; no cuts are applied
432  AddTree( signal, "Signal", signalWeight, TCut(""), Types::kMaxTreeType );
433  AddTree( background, "Background", backgroundWeight, TCut(""), Types::kMaxTreeType );
434 }
435 
436 //_______________________________________________________________________
437 void TMVA::DataLoader::SetInputTrees( const TString& datFileS, const TString& datFileB,
438  Double_t signalWeight, Double_t backgroundWeight )
439 {
440  DataInput().AddTree( datFileS, "Signal", signalWeight );
441  DataInput().AddTree( datFileB, "Background", backgroundWeight );
442 }
443 
444 //_______________________________________________________________________
445 void TMVA::DataLoader::SetInputTrees( TTree* inputTree, const TCut& SigCut, const TCut& BgCut )
446 {
447  // define the input trees for signal and background from single input tree,
448  // containing both signal and background events distinguished by the type
449  // identifiers: SigCut and BgCut
450  AddTree( inputTree, "Signal", 1.0, SigCut, Types::kMaxTreeType );
451  AddTree( inputTree, "Background", 1.0, BgCut , Types::kMaxTreeType );
452 }
453 
454 //_______________________________________________________________________
455 void TMVA::DataLoader::AddVariable( const TString& expression, const TString& title, const TString& unit,
456  char type, Double_t min, Double_t max )
457 {
458  // user inserts discriminating variable in data set info
459  DefaultDataSetInfo().AddVariable( expression, title, unit, min, max, type );
460 }
461 
462 //_______________________________________________________________________
463 void TMVA::DataLoader::AddVariable( const TString& expression, char type,
464  Double_t min, Double_t max )
465 {
466  // user inserts discriminating variable in data set info
467  DefaultDataSetInfo().AddVariable( expression, "", "", min, max, type );
468 }
469 
470 //_______________________________________________________________________
471 void TMVA::DataLoader::AddTarget( const TString& expression, const TString& title, const TString& unit,
472  Double_t min, Double_t max )
473 {
474  // user inserts target in data set info
475 
478 
479  DefaultDataSetInfo().AddTarget( expression, title, unit, min, max );
480 }
481 
482 //_______________________________________________________________________
483 void TMVA::DataLoader::AddSpectator( const TString& expression, const TString& title, const TString& unit,
484  Double_t min, Double_t max )
485 {
486  // user inserts target in data set info
487  DefaultDataSetInfo().AddSpectator( expression, title, unit, min, max );
488 }
489 
490 //_______________________________________________________________________
492 {
493  // default creation
494  return AddDataSet( fName );
495 }
496 
497 //_______________________________________________________________________
498 void TMVA::DataLoader::SetInputVariables( std::vector<TString>* theVariables )
499 {
500  // fill input variables in data set
501  for (std::vector<TString>::iterator it=theVariables->begin();
502  it!=theVariables->end(); it++) AddVariable(*it);
503 }
504 
505 //_______________________________________________________________________
507 {
508  DefaultDataSetInfo().SetWeightExpression(variable, "Signal");
509 }
510 
511 //_______________________________________________________________________
513 {
514  DefaultDataSetInfo().SetWeightExpression(variable, "Background");
515 }
516 
517 //_______________________________________________________________________
518 void TMVA::DataLoader::SetWeightExpression( const TString& variable, const TString& className )
519 {
520  //Log() << kWarning << DefaultDataSetInfo().GetNClasses() /*fClasses.size()*/ << Endl;
521  if (className=="") {
522  SetSignalWeightExpression(variable);
524  }
525  else DefaultDataSetInfo().SetWeightExpression( variable, className );
526 }
527 
528 //_______________________________________________________________________
529 void TMVA::DataLoader::SetCut( const TString& cut, const TString& className ) {
530  SetCut( TCut(cut), className );
531 }
532 
533 //_______________________________________________________________________
534 void TMVA::DataLoader::SetCut( const TCut& cut, const TString& className )
535 {
536  DefaultDataSetInfo().SetCut( cut, className );
537 }
538 
539 //_______________________________________________________________________
540 void TMVA::DataLoader::AddCut( const TString& cut, const TString& className )
541 {
542  AddCut( TCut(cut), className );
543 }
544 
545 //_______________________________________________________________________
546 void TMVA::DataLoader::AddCut( const TCut& cut, const TString& className )
547 {
548  DefaultDataSetInfo().AddCut( cut, className );
549 }
550 
551 //_______________________________________________________________________
553  Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest,
554  const TString& otherOpt )
555 {
556  // prepare the training and test trees
558 
559  AddCut( cut );
560 
561  DefaultDataSetInfo().SetSplitOptions( Form("nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:%s",
562  NsigTrain, NbkgTrain, NsigTest, NbkgTest, otherOpt.Data()) );
563 }
564 
565 //_______________________________________________________________________
567 {
568  // prepare the training and test trees
569  // kept for backward compatibility
571 
572  AddCut( cut );
573 
574  DefaultDataSetInfo().SetSplitOptions( Form("nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:SplitMode=Random:EqualTrainSample:!V",
575  Ntrain, Ntrain, Ntest, Ntest) );
576 }
577 
578 //_______________________________________________________________________
580 {
581  // prepare the training and test trees
582  // -> same cuts for signal and background
584 
586  AddCut( cut );
588 }
589 
590 //_______________________________________________________________________
591 void TMVA::DataLoader::PrepareTrainingAndTestTree( TCut sigcut, TCut bkgcut, const TString& splitOpt )
592 {
593  // prepare the training and test trees
594 
595  // if event-wise data assignment, add local trees to dataset first
597 
598  //Log() << kINFO <<"Preparing trees for training and testing..."<< Endl;
599  AddCut( sigcut, "Signal" );
600  AddCut( bkgcut, "Background" );
601 
602  DefaultDataSetInfo().SetSplitOptions( splitOpt );
603 }
604 
605 //______________________________________________________________________
606 // Function required to split the training and testing datasets into a
607 // number of folds. Required by the CrossValidation and HyperParameterOptimisation
608 // classes. The option to split the training dataset into a training set and
609 // a validation set is implemented but not currently used.
610 void TMVA::DataLoader::MakeKFoldDataSet(UInt_t numberFolds, bool validationSet){
611 
612  if(!fMakeFoldDataSet){ return; } // No need to do it again if the sets have already been split.
613 
614  // Get the original event vectors for testing and training from the dataset.
615  const std::vector<Event*> TrainingData = DefaultDataSetInfo().GetDataSet()->GetEventCollection(Types::kTraining);
616  const std::vector<Event*> TestingData = DefaultDataSetInfo().GetDataSet()->GetEventCollection(Types::kTesting);
617 
618  std::vector<Event*> TrainSigData;
619  std::vector<Event*> TrainBkgData;
620  std::vector<Event*> TestSigData;
621  std::vector<Event*> TestBkgData;
622 
623  // Split the testing and training sets into signal and background classes.
624  for(UInt_t i=0; i<TrainingData.size(); ++i){
625  if( strncmp( DefaultDataSetInfo().GetClassInfo( TrainingData.at(i)->GetClass() )->GetName(), "Signal", 6)){ TrainSigData.push_back(TrainingData.at(i)); }
626  else if( strncmp( DefaultDataSetInfo().GetClassInfo( TrainingData.at(i)->GetClass() )->GetName(), "Background", 10)){ TrainBkgData.push_back(TrainingData.at(i)); }
627  else{
628  Log() << kFATAL << "DataSets should only contain Signal and Background classes for classification, " << DefaultDataSetInfo().GetClassInfo( TrainingData.at(i)->GetClass() )->GetName() << " is not a recognised class" << Endl;
629  }
630  }
631 
632  for(UInt_t i=0; i<TestingData.size(); ++i){
633  if( strncmp( DefaultDataSetInfo().GetClassInfo( TestingData.at(i)->GetClass() )->GetName(), "Signal", 6)){ TestSigData.push_back(TestingData.at(i)); }
634  else if( strncmp( DefaultDataSetInfo().GetClassInfo( TestingData.at(i)->GetClass() )->GetName(), "Background", 10)){ TestBkgData.push_back(TestingData.at(i)); }
635  else{
636  Log() << kFATAL << "DataSets should only contain Signal and Background classes for classification, " << DefaultDataSetInfo().GetClassInfo( TrainingData.at(i)->GetClass() )->GetName() << " is not a recognised class" << Endl;
637  }
638  }
639 
640 
641  // Split the sets into the number of folds.
642  if(validationSet){
643  std::vector<std::vector<Event*>> tempSigEvents = SplitSets(TrainSigData,0,2);
644  std::vector<std::vector<Event*>> tempBkgEvents = SplitSets(TrainBkgData,0,2);
645  fTrainSigEvents = SplitSets(tempSigEvents.at(0),0,numberFolds);
646  fTrainBkgEvents = SplitSets(tempBkgEvents.at(0),0,numberFolds);
647  fValidSigEvents = SplitSets(tempSigEvents.at(1),0,numberFolds);
648  fValidBkgEvents = SplitSets(tempBkgEvents.at(1),0,numberFolds);
649  }
650  else{
651  fTrainSigEvents = SplitSets(TrainSigData,0,numberFolds);
652  fTrainBkgEvents = SplitSets(TrainBkgData,0,numberFolds);
653  }
654 
655  fTestSigEvents = SplitSets(TestSigData,0,numberFolds);
656  fTestBkgEvents = SplitSets(TestBkgData,0,numberFolds);
657 }
658 
659 //______________________________________________________________________
660 // Function for assigning the correct folds to the testing or training set.
662 
663  UInt_t numFolds = fTrainSigEvents.size();
664 
665  std::vector<Event*>* tempTrain = new std::vector<Event*>;
666  std::vector<Event*>* tempTest = new std::vector<Event*>;
667 
668  UInt_t nTrain = 0;
669  UInt_t nTest = 0;
670 
671  // Get the number of events so the memory can be reserved.
672  for(UInt_t i=0; i<numFolds; ++i){
673  if(tt == Types::kTraining){
674  if(i!=foldNumber){
675  nTrain += fTrainSigEvents.at(i).size();
676  nTrain += fTrainBkgEvents.at(i).size();
677  }
678  else{
679  nTest += fTrainSigEvents.at(i).size();
680  nTest += fTrainSigEvents.at(i).size();
681  }
682  }
683  else if(tt == Types::kValidation){
684  if(i!=foldNumber){
685  nTrain += fValidSigEvents.at(i).size();
686  nTrain += fValidBkgEvents.at(i).size();
687  }
688  else{
689  nTest += fValidSigEvents.at(i).size();
690  nTest += fValidSigEvents.at(i).size();
691  }
692  }
693  else if(tt == Types::kTesting){
694  if(i!=foldNumber){
695  nTrain += fTestSigEvents.at(i).size();
696  nTrain += fTestBkgEvents.at(i).size();
697  }
698  else{
699  nTest += fTestSigEvents.at(i).size();
700  nTest += fTestSigEvents.at(i).size();
701  }
702  }
703  }
704 
705  // Reserve memory before filling vectors
706  tempTrain->reserve(nTrain);
707  tempTest->reserve(nTest);
708 
709  // Fill vectors with correct folds for testing and training.
710  for(UInt_t j=0; j<numFolds; ++j){
711  if(tt == Types::kTraining){
712  if(j!=foldNumber){
713  tempTrain->insert(tempTrain->end(), fTrainSigEvents.at(j).begin(), fTrainSigEvents.at(j).end());
714  tempTrain->insert(tempTrain->end(), fTrainBkgEvents.at(j).begin(), fTrainBkgEvents.at(j).end());
715  }
716  else{
717  tempTest->insert(tempTest->end(), fTrainSigEvents.at(j).begin(), fTrainSigEvents.at(j).end());
718  tempTest->insert(tempTest->end(), fTrainBkgEvents.at(j).begin(), fTrainBkgEvents.at(j).end());
719  }
720  }
721  else if(tt == Types::kValidation){
722  if(j!=foldNumber){
723  tempTrain->insert(tempTrain->end(), fValidSigEvents.at(j).begin(), fValidSigEvents.at(j).end());
724  tempTrain->insert(tempTrain->end(), fValidBkgEvents.at(j).begin(), fValidBkgEvents.at(j).end());
725  }
726  else{
727  tempTest->insert(tempTest->end(), fValidSigEvents.at(j).begin(), fValidSigEvents.at(j).end());
728  tempTest->insert(tempTest->end(), fValidBkgEvents.at(j).begin(), fValidBkgEvents.at(j).end());
729  }
730  }
731  else if(tt == Types::kTesting){
732  if(j!=foldNumber){
733  tempTrain->insert(tempTrain->end(), fTestSigEvents.at(j).begin(), fTestSigEvents.at(j).end());
734  tempTrain->insert(tempTrain->end(), fTestBkgEvents.at(j).begin(), fTestBkgEvents.at(j).end());
735  }
736  else{
737  tempTest->insert(tempTest->end(), fTestSigEvents.at(j).begin(), fTestSigEvents.at(j).end());
738  tempTest->insert(tempTest->end(), fTestBkgEvents.at(j).begin(), fTestBkgEvents.at(j).end());
739  }
740  }
741  }
742 
743  // Assign the vectors of the events to rebuild the dataset
746 
747 }
748 
749 //______________________________________________________________________
750 // Splits the input vector in to equally sized randomly sampled folds.
751 std::vector<std::vector<TMVA::Event*>> TMVA::DataLoader::SplitSets(std::vector<TMVA::Event*>& oldSet, int seedNum, int numFolds){
752 
753  ULong64_t nEntries = oldSet.size();
754  ULong64_t foldSize = nEntries/numFolds;
755 
756  std::vector<std::vector<Event*>> tempSets;
757  tempSets.resize(numFolds);
758 
759  TRandom3 r(seedNum);
760 
761  ULong64_t inSet = 0;
762 
763  for(ULong64_t i=0; i<nEntries; i++){
764  bool inTree = false;
765  if(inSet == foldSize*numFolds){
766  break;
767  }
768  else{
769  while(!inTree){
770  int s = r.Integer(numFolds);
771  if(tempSets.at(s).size()<foldSize){
772  tempSets.at(s).push_back(oldSet.at(i));
773  inSet++;
774  inTree=true;
775  }
776  }
777  }
778  }
779 
780  return tempSets;
781 
782 }
783 
784 //_______________________________________________________________________
785 // Copy method used in VI and CV
787 {
788  TMVA::DataLoader* des=new TMVA::DataLoader(name);
789  DataLoaderCopy(des,this);
790  return des;
791 }
792 
793 //_______________________________________________________________________
795 {
796  //Loading Dataset from DataInputHandler for subseed
797  for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Sbegin();treeinfo!=src->DataInput().Send();treeinfo++)
798  {
799  des->AddSignalTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
800  }
801 
802  for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Bbegin();treeinfo!=src->DataInput().Bend();treeinfo++)
803  {
804  des->AddBackgroundTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
805  }
806 }
807 
808 //_______________________________________________________________________
810 {
811  //returns the correlation matrix of datasets
812  const TMatrixD * m = DefaultDataSetInfo().CorrelationMatrix(className);
814  "CorrelationMatrix"+className, "Correlation Matrix ("+className+")");
815 }
816 
817 
DataSetInfo * GetDataSetInfo(const TString &dsiName)
returns datasetinfo object for given name
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Definition: DataLoader.cxx:381
DataSetManager * fDataSetManager
Definition: DataLoader.h:197
virtual ~DataLoader()
Definition: DataLoader.cxx:97
Random number generator class based on the M. Matsumoto and T. Nishimura Mersenne Twister generator.
Definition: TRandom3.h:29
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
VariableInfo & AddTarget(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis ...
void AddTrainingEvent(const TString &className, const std::vector< Double_t > &event, Double_t weight)
Definition: DataLoader.cxx:248
std::vector< VariableInfo > & GetSpectatorInfos()
Definition: DataSetInfo.h:122
std::vector< TMVA::VariableTransformBase * > fDefaultTrfs
Definition: DataLoader.h:202
void SetCut(const TCut &cut, const TString &className)
set the cut for the classes
TMVA::DataLoader * VarianceThreshold(Double_t threshold)
Computes variance of all the variables and returns a new DataLoader with the selected variables whose...
std::vector< std::vector< TMVA::Event * > > fTrainBkgEvents
Definition: DataLoader.h:218
DataSetInfo & GetDataSetInfo()
Definition: DataLoader.cxx:135
Double_t Atof() const
Return floating-point value contained in string.
Definition: TString.cxx:2031
Double_t background(Double_t *x, Double_t *par)
TTree * CreateEventAssignTrees(const TString &name)
Definition: DataLoader.cxx:188
DataSetInfo & DefaultDataSetInfo()
Definition: DataLoader.cxx:491
DataLoader * VarTransform(TString trafoDefinition)
Transforms the variables and return a new DataLoader with the transformed variables.
Definition: DataLoader.cxx:144
Basic string class.
Definition: TString.h:137
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1089
int Int_t
Definition: RtypesCore.h:41
void MakeKFoldDataSet(UInt_t numberFolds, bool validationSet=false)
Definition: DataLoader.cxx:610
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
void DataLoaderCopy(TMVA::DataLoader *des, TMVA::DataLoader *src)
Definition: DataLoader.cxx:794
void AddCut(const TCut &cut, const TString &className)
set the cut for the classes
void SetBackgroundTree(TTree *background, Double_t weight=1.0)
Definition: DataLoader.cxx:415
DataInputHandler * fDataInputHandler
Definition: DataLoader.h:200
Types::EAnalysisType fAnalysisType
Definition: DataLoader.h:228
void AddBackgroundTestEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
Definition: DataLoader.cxx:241
TH2 * GetCorrelationMatrix(const TString &className)
Definition: DataLoader.cxx:809
std::vector< std::vector< TMVA::Event * > > fTestBkgEvents
Definition: DataLoader.h:222
DataSet * GetDataSet() const
returns data set
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
Definition: DataLoader.cxx:455
const char * Data() const
Definition: TString.h:349
std::vector< TreeInfo >::const_iterator Bbegin() const
TText * tt
Definition: textangle.C:16
void AddTestEvent(const TString &className, const std::vector< Double_t > &event, Double_t weight)
Definition: DataLoader.cxx:255
void SetInputTrees(const TString &signalFileName, const TString &backgroundFileName, Double_t signalWeight=1.0, Double_t backgroundWeight=1.0)
Definition: DataLoader.cxx:437
virtual UInt_t Integer(UInt_t imax)
Returns a random integer on [ 0, imax-1 ].
Definition: TRandom.cxx:320
void SetTree(TTree *tree, const TString &className, Double_t weight)
Definition: DataLoader.cxx:421
void PrepareFoldDataSet(UInt_t foldNumber, Types::ETreeType tt)
Definition: DataLoader.cxx:661
void SetInputVariables(std::vector< TString > *theVariables)
Definition: DataLoader.cxx:498
DataSetInfo & AddDataSet(DataSetInfo &)
Definition: DataLoader.cxx:119
std::vector< VariableInfo > & GetTargetInfos()
Definition: DataSetInfo.h:117
void AddCut(const TString &cut, const TString &className="")
Definition: DataLoader.cxx:540
A specialized string object used for TTree selections.
Definition: TCut.h:27
static void DestroyInstance()
Definition: Tools.cxx:95
void SetInputTreesFromEventAssignTrees()
Definition: DataLoader.cxx:303
Float_t fATreeWeight
Definition: DataLoader.h:225
void SetSplitOptions(const TString &so)
Definition: DataSetInfo.h:184
Bool_t fMakeFoldDataSet
Definition: DataLoader.h:230
DataInputHandler & DataInput()
Definition: DataLoader.h:183
TRandom2 r(17)
Service class for 2-Dim histogram classes.
Definition: TH2.h:36
ClassInfo * GetClassInfo(Int_t clNum) const
std::vector< TreeInfo >::const_iterator Send() const
void AddTree(TTree *tree, const TString &className, Double_t weight=1.0, const TCut &cut="", Types::ETreeType tt=Types::kMaxTreeType)
add tree of className events for tt (Training;Testing..) type as input ..
const TMatrixD * CorrelationMatrix(const TString &className) const
void SetWeightExpression(const TString &exp, const TString &className="")
set the weight expressions for the classes if class name is specified, set only for this class if cla...
DataSetInfo & AddDataSetInfo(DataSetInfo &dsi)
stores a copy of the dataset info object
unsigned int UInt_t
Definition: RtypesCore.h:42
TMarker * m
Definition: textangle.C:8
char * Form(const char *fmt,...)
std::vector< TTree * > fTestAssignTree
Definition: DataLoader.h:215
Bool_t UserAssignEvents(UInt_t clIndex)
Definition: DataLoader.cxx:296
std::vector< Float_t > fATreeEvent
Definition: DataLoader.h:226
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:51
std::vector< std::vector< TMVA::Event * > > fTestSigEvents
Definition: DataLoader.h:221
void SetEventCollection(std::vector< Event * > *, Types::ETreeType, Bool_t deleteEvents=true)
Sets the event collection (by DataSetFactory)
Definition: DataSet.cxx:259
TString fName
Definition: TNamed.h:36
void PrintClasses() const
TH2 * CreateCorrelationMatrixHist(const TMatrixD *m, const TString &hName, const TString &hTitle) const
DataLoader * MakeCopy(TString name)
Definition: DataLoader.cxx:786
TString & Remove(Ssiz_t pos)
Definition: TString.h:616
int Ssiz_t
Definition: RtypesCore.h:63
void AddTree(TTree *tree, const TString &className, Double_t weight=1.0, const TCut &cut="", Types::ETreeType tt=Types::kMaxTreeType)
Definition: DataLoader.cxx:334
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
Definition: DataLoader.cxx:579
virtual void SetDirectory(TDirectory *dir)
Change the tree's directory.
Definition: TTree.cxx:8285
#define ClassImp(name)
Definition: Rtypes.h:279
double Double_t
Definition: RtypesCore.h:55
void AddEvent(const TString &className, Types::ETreeType tt, const std::vector< Double_t > &event, Double_t weight)
Definition: DataLoader.cxx:262
std::vector< TreeInfo >::const_iterator Bend() const
void SetBackgroundWeightExpression(const TString &variable)
Definition: DataLoader.cxx:512
int type
Definition: TGX11.cxx:120
unsigned long long ULong64_t
Definition: RtypesCore.h:70
Bool_t IsFloat() const
Returns kTRUE if string contains a floating point or integer number.
Definition: TString.cxx:1835
static void DestroyInstance()
static function: destroy TMVA instance
Definition: Config.cxx:81
MsgLogger & Log() const
Definition: Configurable.h:128
void AddTarget(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
Definition: DataLoader.cxx:471
void SetWeightExpression(const TString &variable, const TString &className="")
Definition: DataLoader.cxx:518
void AddBackgroundTrainingEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
Definition: DataLoader.cxx:234
ClassInfo * AddClass(const TString &className)
void SetSignalWeightExpression(const TString &variable)
Definition: DataLoader.cxx:506
std::vector< std::vector< TMVA::Event * > > fTrainSigEvents
Definition: DataLoader.h:217
virtual Long64_t ReadFile(const char *filename, const char *branchDescriptor="", char delimiter= ' ')
Create or simply read branches from filename.
Definition: TTree.cxx:6995
virtual Int_t Branch(TCollection *list, Int_t bufsize=32000, Int_t splitlevel=99, const char *name="")
Create one branch for each element in the collection.
Definition: TTree.cxx:1651
Abstract ClassifierFactory template that handles arbitrary types.
std::vector< TreeInfo >::const_iterator Sbegin() const
std::vector< TTree * > fTrainAssignTree
Definition: DataLoader.h:214
VariableInfo & AddSpectator(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, char type= 'F', Bool_t normalized=kTRUE, void *external=0)
add a spectator (can be a complex expression) to the set of spectator variables used in the MV analysis.
void AddSignalTestEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
Definition: DataLoader.cxx:227
std::vector< std::vector< TMVA::Event * > > fValidBkgEvents
Definition: DataLoader.h:220
void AddSignalTrainingEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
Definition: DataLoader.cxx:220
friend void DataLoaderCopy(TMVA::DataLoader *des, TMVA::DataLoader *src)
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:567
void SetSignalTree(TTree *signal, Double_t weight=1.0)
Definition: DataLoader.cxx:409
#define NULL
Definition: Rtypes.h:82
VariableInfo & AddVariable(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0, char varType='F', Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis ...
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:239
Definition: tree.py:1
virtual Long64_t GetEntries() const
Definition: TTree.h:392
A TTree object has a header with a name and a title.
Definition: TTree.h:98
std::vector< std::vector< TMVA::Event * > > SplitSets(std::vector< TMVA::Event * > &oldSet, int seedNum, int numFolds)
Definition: DataLoader.cxx:751
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:582
#define I(x, y, z)
UInt_t GetNumber() const
Definition: ClassInfo.h:73
std::vector< std::vector< TMVA::Event * > > fValidSigEvents
Definition: DataLoader.h:219
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
Definition: DataLoader.cxx:352
std::vector< VariableInfo > & GetVariableInfos()
Definition: DataSetInfo.h:112
gr SetName("gr")
void SetCut(const TString &cut, const TString &className="")
Definition: DataLoader.cxx:529
char name[80]
Definition: TGX11.cxx:109
void AddSpectator(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
Definition: DataLoader.cxx:483