Logo ROOT   6.12/07
Reference Guide
DataLoader.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Omar Zapata
3 // Mentors: Lorenzo Moneta, Sergei Gleyzer
4 //NOTE: Based on TMVA::Factory
5 
6 /**********************************************************************************
7  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
8  * Package: TMVA *
9  * Class : DataLoader *
10  * Web : http://tmva.sourceforge.net *
11  * *
12  * Description: *
13  * This is a class to load datasets into every booked method *
14  * *
15  * Authors (alphabetical): *
16  * Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland *
17  * Omar Zapata <Omar.Zapata@cern.ch> - ITM/UdeA, Colombia *
18  * Sergei Gleyzer<sergei.gleyzer@cern.ch> - CERN, Switzerland *
19  * *
20  * Copyright (c) 2005-2015: *
21  * CERN, Switzerland *
22  * ITM/UdeA, Colombia *
23  * *
24  * Redistribution and use in source and binary forms, with or without *
25  * modification, are permitted according to the terms listed in LICENSE *
26  * (http://tmva.sourceforge.net/LICENSE) *
27  **********************************************************************************/
28 
29 
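/*! \class TMVA::DataLoader
\ingroup TMVA

Loads the input data (trees or event-wise input, variables, targets,
spectators, cuts and weights) into a DataSetInfo that is then used by the
booked methods.
*/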
34 
35 #include "TROOT.h"
36 #include "TFile.h"
37 #include "TTree.h"
38 #include "TLeaf.h"
39 #include "TEventList.h"
40 #include "TH2.h"
41 #include "TText.h"
42 #include "TStyle.h"
43 #include "TMatrixF.h"
44 #include "TMatrixDSym.h"
45 #include "TPaletteAxis.h"
46 #include "TPrincipal.h"
47 #include "TMath.h"
48 #include "TObjString.h"
49 #include "TRandom3.h"
50 
51 #include <string.h>
52 
53 #include "TMVA/Configurable.h"
54 #include "TMVA/DataLoader.h"
55 #include "TMVA/Config.h"
56 #include "TMVA/Tools.h"
57 #include "TMVA/Ranking.h"
58 #include "TMVA/DataSet.h"
59 #include "TMVA/IMethod.h"
60 #include "TMVA/MethodBase.h"
61 #include "TMVA/DataInputHandler.h"
62 #include "TMVA/DataSetManager.h"
63 #include "TMVA/DataSetInfo.h"
64 #include "TMVA/MethodBoost.h"
65 #include "TMVA/MethodCategory.h"
66 
67 #include "TMVA/VariableInfo.h"
74 
75 
77 #include "TMVA/ResultsRegression.h"
78 #include "TMVA/ResultsMulticlass.h"
79 #include "TMVA/Types.h"
80 
82 
83 
84 ////////////////////////////////////////////////////////////////////////////////
85 
// Constructor: initialise members and name the loader thedlName.
// NOTE(review): as shown in this listing fDataSetManager is initialised to NULL
// and not reassigned in the body, yet AddDataSet()/GetDataSetInfo() dereference
// it; a statement creating the manager appears to be missing from the listing —
// confirm against the full source.
: Configurable( ),
   fDataSetManager ( NULL ), //DSMTEST
   fDataInputHandler ( new DataInputHandler ),
   fTransformations ( "I" ),
   fVerbose ( kFALSE ),
   fDataAssignType ( kAssignEvents ),
   fATreeEvent (0),
   fMakeFoldDataSet ( kFALSE )
{
   // the loader name (fName) is the key DefaultDataSetInfo() uses for its DataSetInfo
   SetName(thedlName.Data());
   // tag all log messages of this object with "DataLoader"
   fLogger->SetSource("DataLoader");
}
100 
101 ////////////////////////////////////////////////////////////////////////////////
102 
{
   // destructor
   // Deletes the default variable transformations held in fDefaultTrfs, the
   // input handler and the data-set manager created for this loader.

   std::vector<TMVA::VariableTransformBase*>::iterator trfIt = fDefaultTrfs.begin();
   for (;trfIt != fDefaultTrfs.end(); trfIt++) delete (*trfIt);

   delete fDataInputHandler;

   // destroy singletons
   // DataSetManager::DestroyInstance(); // DSMTEST replaced by following line
   delete fDataSetManager; // DSMTEST

   // problem with call of REGISTER_METHOD macro ...
   // ClassifierDataLoader::DestroyInstance();
   // Types::DestroyInstance();
   //Tools::DestroyInstance();
   //Config::DestroyInstance();
}
122 
123 
124 ////////////////////////////////////////////////////////////////////////////////
125 
{
   // Register an existing DataSetInfo with the data-set manager and return
   // whatever DataSetManager::AddDataSetInfo hands back.
   return fDataSetManager->AddDataSetInfo(dsi); // DSMTEST
}
130 
131 ////////////////////////////////////////////////////////////////////////////////
132 
{
   // Return the DataSetInfo named dsiName, creating and registering a new one
   // if the data-set manager does not know it yet.
   DataSetInfo* dsi = fDataSetManager->GetDataSetInfo(dsiName); // DSMTEST

   if (dsi!=0) return *dsi;

   // NOTE(review): the heap-allocated DataSetInfo is not deleted here; presumably
   // the DataSetManager takes ownership of it — confirm.
   return fDataSetManager->AddDataSetInfo(*(new DataSetInfo(dsiName))); // DSMTEST
}
141 
142 ////////////////////////////////////////////////////////////////////////////////
143 
{
   // Convenience accessor: identical to DefaultDataSetInfo().
   return DefaultDataSetInfo(); // DSMTEST
}
148 
149 ////////////////////////////////////////////////////////////////////////////////
150 /// Transforms the variables and return a new DataLoader with the transformed
151 /// variables
152 
154 {
155  TString trOptions = "0";
156  TString trName = "None";
157  if (trafoDefinition.Contains("(")) {
158 
159  // contains transformation parameters
160  Ssiz_t parStart = trafoDefinition.Index( "(" );
161  Ssiz_t parLen = trafoDefinition.Index( ")", parStart )-parStart+1;
162 
163  trName = trafoDefinition(0,parStart);
164  trOptions = trafoDefinition(parStart,parLen);
165  trOptions.Remove(parLen-1,1);
166  trOptions.Remove(0,1);
167  }
168  else
169  trName = trafoDefinition;
170 
171  VarTransformHandler* handler = new VarTransformHandler(this);
172  // variance threshold variable transformation
173  if (trName == "VT") {
174 
175  // find threshold value from given input
176  Double_t threshold = 0.0;
177  if (!trOptions.IsFloat()){
178  Log() << kFATAL << " VT transformation must be passed a floating threshold value" << Endl;
179  delete handler;
180  return this;
181  }
182  else
183  threshold = trOptions.Atof();
184  TMVA::DataLoader *transformedLoader = handler->VarianceThreshold(threshold);
185  delete handler;
186  return transformedLoader;
187  }
188  else {
189  Log() << kFATAL << "Incorrect transformation string provided, please check" << Endl;
190  }
191  Log() << kINFO << "No transformation applied, returning original loader" << Endl;
192  return this;
193 }
194 
195 ////////////////////////////////////////////////////////////////////////////////
196 // the next functions are to assign events directly
197 
198 ////////////////////////////////////////////////////////////////////////////////
199 /// create the data assignment tree (for event-wise data assignment by user)
200 
{
   // Create a tree named `name`, not attached to any directory, with an integer
   // "type" branch (bound to fATreeType), a float "weight" branch (bound to
   // fATreeWeight), and one float branch per variable, target and spectator of
   // the default DataSetInfo, bound to consecutive slots of fATreeEvent.
   // The caller receives the new tree.
   TTree * assignTree = new TTree( name, name );
   assignTree->SetDirectory(0);
   assignTree->Branch( "type", &fATreeType, "ATreeType/I" );
   assignTree->Branch( "weight", &fATreeWeight, "ATreeWeight/F" );

   std::vector<VariableInfo>& vars = DefaultDataSetInfo().GetVariableInfos();
   std::vector<VariableInfo>& tgts = DefaultDataSetInfo().GetTargetInfos();
   std::vector<VariableInfo>& spec = DefaultDataSetInfo().GetSpectatorInfos();

   // The buffer is sized only on first use. Branch addresses point into it, so
   // resizing it later would invalidate the trees created earlier.
   // NOTE(review): variables added after the first call are not accounted for
   // in its size — confirm callers cannot do that.
   if (fATreeEvent.size()==0) fATreeEvent.resize(vars.size()+tgts.size()+spec.size());
   // add variables (slots 0 .. nvars-1)
   for (UInt_t ivar=0; ivar<vars.size(); ivar++) {
      TString vname = vars[ivar].GetExpression();
      assignTree->Branch( vname, &fATreeEvent[ivar], vname + "/F" );
   }
   // add targets (slots following the variables)
   for (UInt_t itgt=0; itgt<tgts.size(); itgt++) {
      TString vname = tgts[itgt].GetExpression();
      assignTree->Branch( vname, &fATreeEvent[vars.size()+itgt], vname + "/F" );
   }
   // add spectators (slots following the targets)
   for (UInt_t ispc=0; ispc<spec.size(); ispc++) {
      TString vname = spec[ispc].GetExpression();
      assignTree->Branch( vname, &fATreeEvent[vars.size()+tgts.size()+ispc], vname + "/F" );
   }
   return assignTree;
}
230 
231 ////////////////////////////////////////////////////////////////////////////////
232 /// add signal training event
233 
234 void TMVA::DataLoader::AddSignalTrainingEvent( const std::vector<Double_t>& event, Double_t weight )
235 {
236  AddEvent( "Signal", Types::kTraining, event, weight );
237 }
238 
239 ////////////////////////////////////////////////////////////////////////////////
240 /// add signal testing event
241 
242 void TMVA::DataLoader::AddSignalTestEvent( const std::vector<Double_t>& event, Double_t weight )
243 {
244  AddEvent( "Signal", Types::kTesting, event, weight );
245 }
246 
247 ////////////////////////////////////////////////////////////////////////////////
/// add background training event
249 
250 void TMVA::DataLoader::AddBackgroundTrainingEvent( const std::vector<Double_t>& event, Double_t weight )
251 {
252  AddEvent( "Background", Types::kTraining, event, weight );
253 }
254 
255 ////////////////////////////////////////////////////////////////////////////////
/// add background test event
257 
258 void TMVA::DataLoader::AddBackgroundTestEvent( const std::vector<Double_t>& event, Double_t weight )
259 {
260  AddEvent( "Background", Types::kTesting, event, weight );
261 }
262 
263 ////////////////////////////////////////////////////////////////////////////////
/// add training event for class className
265 
266 void TMVA::DataLoader::AddTrainingEvent( const TString& className, const std::vector<Double_t>& event, Double_t weight )
267 {
268  AddEvent( className, Types::kTraining, event, weight );
269 }
270 
271 ////////////////////////////////////////////////////////////////////////////////
/// add test event for class className
273 
274 void TMVA::DataLoader::AddTestEvent( const TString& className, const std::vector<Double_t>& event, Double_t weight )
275 {
276  AddEvent( className, Types::kTesting, event, weight );
277 }
278 
279 ////////////////////////////////////////////////////////////////////////////////
280 /// add event
281 /// vector event : the order of values is: variables + targets + spectators
282 
284  const std::vector<Double_t>& event, Double_t weight )
285 {
286  ClassInfo* theClass = DefaultDataSetInfo().AddClass(className); // returns class (creates it if necessary)
287  UInt_t clIndex = theClass->GetNumber();
288 
289 
290  // set analysistype to "kMulticlass" if more than two classes and analysistype == kNoAnalysisType
291  if( fAnalysisType == Types::kNoAnalysisType && DefaultDataSetInfo().GetNClasses() > 2 )
293 
294 
295  if (clIndex>=fTrainAssignTree.size()) {
296  fTrainAssignTree.resize(clIndex+1, 0);
297  fTestAssignTree.resize(clIndex+1, 0);
298  }
299 
300  if (fTrainAssignTree[clIndex]==0) { // does not exist yet
301  fTrainAssignTree[clIndex] = CreateEventAssignTrees( Form("TrainAssignTree_%s", className.Data()) );
302  fTestAssignTree[clIndex] = CreateEventAssignTrees( Form("TestAssignTree_%s", className.Data()) );
303  }
304 
305  fATreeType = clIndex;
306  fATreeWeight = weight;
307  for (UInt_t ivar=0; ivar<event.size(); ivar++) fATreeEvent[ivar] = event[ivar];
308 
309  if(tt==Types::kTraining) fTrainAssignTree[clIndex]->Fill();
310  else fTestAssignTree[clIndex]->Fill();
311 
312 }
313 
314 ////////////////////////////////////////////////////////////////////////////////
/// whether events of class index clIndex were added event-wise (an assignment tree exists for it)
316 
{
   // kTRUE if a training assignment tree exists for class index clIndex.
   // No bounds check: clIndex must be < fTrainAssignTree.size() (the caller in
   // this file loops only over that range).
   return fTrainAssignTree[clIndex]!=0;
}
321 
322 ////////////////////////////////////////////////////////////////////////////////
323 /// assign event-wise local trees to data set
324 
{
   // For every class that received event-wise input, use the "weight" branch as
   // that class's weight expression and register its training and test
   // assignment trees as ordinary input trees (global weight 1, no cut).
   UInt_t size = fTrainAssignTree.size();
   for(UInt_t i=0; i<size; i++) {
      if(!UserAssignEvents(i)) continue;
      const TString& className = DefaultDataSetInfo().GetClassInfo(i)->GetName();
      SetWeightExpression( "weight", className );
      AddTree(fTrainAssignTree[i], className, 1.0, TCut(""), Types::kTraining );
      AddTree(fTestAssignTree[i], className, 1.0, TCut(""), Types::kTesting );
   }
}
336 
337 ////////////////////////////////////////////////////////////////////////////////
/// add a tree for class className; treetype is parsed case-insensitively ("Training", "Test" or both)
339 
340 void TMVA::DataLoader::AddTree( TTree* tree, const TString& className, Double_t weight,
341  const TCut& cut, const TString& treetype )
342 {
344  TString tmpTreeType = treetype; tmpTreeType.ToLower();
345  if (tmpTreeType.Contains( "train" ) && tmpTreeType.Contains( "test" )) tt = Types::kMaxTreeType;
346  else if (tmpTreeType.Contains( "train" )) tt = Types::kTraining;
347  else if (tmpTreeType.Contains( "test" )) tt = Types::kTesting;
348  else {
349  Log() << kFATAL << "<AddTree> cannot interpret tree type: \"" << treetype
350  << "\" should be \"Training\" or \"Test\" or \"Training and Testing\"" << Endl;
351  }
352  AddTree( tree, className, weight, cut, tt );
353 }
354 
355 ////////////////////////////////////////////////////////////////////////////////
356 
357 void TMVA::DataLoader::AddTree( TTree* tree, const TString& className, Double_t weight,
358  const TCut& cut, Types::ETreeType tt )
359 {
360  if(!tree)
361  Log() << kFATAL << "Tree does not exist (empty pointer)." << Endl;
362 
363  DefaultDataSetInfo().AddClass( className );
364 
365  // set analysistype to "kMulticlass" if more than two classes and analysistype == kNoAnalysisType
366  if( fAnalysisType == Types::kNoAnalysisType && DefaultDataSetInfo().GetNClasses() > 2 )
368 
369  Log() << kINFO<< "Add Tree " << tree->GetName() << " of type " << className
370  << " with " << tree->GetEntries() << " events" << Endl;
371  DataInput().AddTree( tree, className, weight, cut, tt );
372 }
373 
374 ////////////////////////////////////////////////////////////////////////////////
/// add signal tree
376 
{
   // Add `signal` as a "Signal" tree without cut for the given sample type.
   AddTree( signal, "Signal", weight, TCut(""), treetype );
}
381 
382 ////////////////////////////////////////////////////////////////////////////////
383 /// add signal tree from text file
384 
{
   // create trees from these ascii files
   // The tree is filled from the text file datFileS via TTree::ReadFile.
   // NOTE(review): the tree is heap-allocated and handed to AddTree; its owner
   // and cleanup are not visible here — confirm.
   TTree* signalTree = new TTree( "TreeS", "Tree (S)" );
   signalTree->ReadFile( datFileS );

   Log() << kINFO << "Create TTree objects from ASCII input files ... \n- Signal file : \""
         << datFileS << Endl;

   // register the freshly read tree as "Signal" (no cut)
   AddTree( signalTree, "Signal", weight, TCut(""), treetype );
}
397 
398 ////////////////////////////////////////////////////////////////////////////////
399 
{
   // Add `signal` as a "Signal" tree without cut; treetype is parsed by the
   // string overload of AddTree.
   AddTree( signal, "Signal", weight, TCut(""), treetype );
}
404 
405 ////////////////////////////////////////////////////////////////////////////////
/// add background tree
407 
{
   // Add the given tree as a "Background" tree without cut for the given sample
   // type. (The parameter is named `signal` although it holds background events.)
   AddTree( signal, "Background", weight, TCut(""), treetype );
}
412 
413 ////////////////////////////////////////////////////////////////////////////////
414 /// add background tree from text file
415 
{
   // create trees from these ascii files
   // The tree is filled from the text file datFileB via TTree::ReadFile.
   // NOTE(review): the tree is heap-allocated and handed to AddTree; its owner
   // and cleanup are not visible here — confirm.
   TTree* bkgTree = new TTree( "TreeB", "Tree (B)" );
   bkgTree->ReadFile( datFileB );

   Log() << kINFO << "Create TTree objects from ASCII input files ... \n- Background file : \""
         << datFileB << Endl;

   // register the freshly read tree as "Background" (no cut)
   AddTree( bkgTree, "Background", weight, TCut(""), treetype );
}
428 
429 ////////////////////////////////////////////////////////////////////////////////
430 
{
   // Add the given tree as a "Background" tree without cut; treetype is parsed
   // by the string overload of AddTree. (The parameter is named `signal`.)
   AddTree( signal, "Background", weight, TCut(""), treetype );
}
435 
436 ////////////////////////////////////////////////////////////////////////////////
437 
{
   // Add `tree` as "Signal"; the remaining AddTree arguments take their declared
   // defaults (not visible in this file).
   AddTree( tree, "Signal", weight );
}
442 
443 ////////////////////////////////////////////////////////////////////////////////
444 
{
   // Add `tree` as "Background"; the remaining AddTree arguments take their
   // declared defaults (not visible in this file).
   AddTree( tree, "Background", weight );
}
449 
450 ////////////////////////////////////////////////////////////////////////////////
/// set tree for class className (no cut, used for both training and testing)
452 
453 void TMVA::DataLoader::SetTree( TTree* tree, const TString& className, Double_t weight )
454 {
455  AddTree( tree, className, weight, TCut(""), Types::kMaxTreeType );
456 }
457 
458 ////////////////////////////////////////////////////////////////////////////////
459 /// define the input trees for signal and background; no cuts are applied
460 
   Double_t signalWeight, Double_t backgroundWeight )
{
   // Register `signal` and `background` as "Signal"/"Background" trees without
   // cuts, each used for both training and testing (kMaxTreeType).
   AddTree( signal, "Signal", signalWeight, TCut(""), Types::kMaxTreeType );
   AddTree( background, "Background", backgroundWeight, TCut(""), Types::kMaxTreeType );
}
467 
468 ////////////////////////////////////////////////////////////////////////////////
469 
470 void TMVA::DataLoader::SetInputTrees( const TString& datFileS, const TString& datFileB,
471  Double_t signalWeight, Double_t backgroundWeight )
472 {
473  DataInput().AddTree( datFileS, "Signal", signalWeight );
474  DataInput().AddTree( datFileB, "Background", backgroundWeight );
475 }
476 
477 ////////////////////////////////////////////////////////////////////////////////
478 /// define the input trees for signal and background from single input tree,
479 /// containing both signal and background events distinguished by the type
480 /// identifiers: SigCut and BgCut
481 
482 void TMVA::DataLoader::SetInputTrees( TTree* inputTree, const TCut& SigCut, const TCut& BgCut )
483 {
484  AddTree( inputTree, "Signal", 1.0, SigCut, Types::kMaxTreeType );
485  AddTree( inputTree, "Background", 1.0, BgCut , Types::kMaxTreeType );
486 }
487 
488 ////////////////////////////////////////////////////////////////////////////////
489 /// user inserts discriminating variable in data set info
490 
491 void TMVA::DataLoader::AddVariable( const TString& expression, const TString& title, const TString& unit,
492  char type, Double_t min, Double_t max )
493 {
494  DefaultDataSetInfo().AddVariable( expression, title, unit, min, max, type );
495 }
496 
497 ////////////////////////////////////////////////////////////////////////////////
498 /// user inserts discriminating variable in data set info
499 
500 void TMVA::DataLoader::AddVariable( const TString& expression, char type,
501  Double_t min, Double_t max )
502 {
503  DefaultDataSetInfo().AddVariable( expression, "", "", min, max, type );
504 }
505 
506 ////////////////////////////////////////////////////////////////////////////////
507 /// user inserts target in data set info
508 
void TMVA::DataLoader::AddTarget( const TString& expression, const TString& title, const TString& unit,
                                  Double_t min, Double_t max )
{
   // User inserts a target (expression, title, unit, min/max) into the default
   // DataSetInfo.
   // NOTE(review): this listing shows a gap at the start of the body; statements
   // may be missing here — confirm against the full source.

   DefaultDataSetInfo().AddTarget( expression, title, unit, min, max );
}
517 
518 ////////////////////////////////////////////////////////////////////////////////
/// user inserts spectator in data set info
520 
521 void TMVA::DataLoader::AddSpectator( const TString& expression, const TString& title, const TString& unit,
522  Double_t min, Double_t max )
523 {
524  DefaultDataSetInfo().AddSpectator( expression, title, unit, min, max );
525 }
526 
527 ////////////////////////////////////////////////////////////////////////////////
/// return the default DataSetInfo (named after this loader), creating it on first use
529 
{
   // The default DataSetInfo carries the loader's own name (fName); AddDataSet
   // returns the existing one or creates it.
   return AddDataSet( fName );
}
534 
535 ////////////////////////////////////////////////////////////////////////////////
536 /// fill input variables in data set
537 
538 void TMVA::DataLoader::SetInputVariables( std::vector<TString>* theVariables )
539 {
540  for (std::vector<TString>::iterator it=theVariables->begin();
541  it!=theVariables->end(); it++) AddVariable(*it);
542 }
543 
544 ////////////////////////////////////////////////////////////////////////////////
545 
{
   // Use `variable` as the per-event weight expression of class "Signal".
   DefaultDataSetInfo().SetWeightExpression(variable, "Signal");
}
550 
551 ////////////////////////////////////////////////////////////////////////////////
552 
{
   // Use `variable` as the per-event weight expression of class "Background".
   DefaultDataSetInfo().SetWeightExpression(variable, "Background");
}
557 
558 ////////////////////////////////////////////////////////////////////////////////
559 
560 void TMVA::DataLoader::SetWeightExpression( const TString& variable, const TString& className )
561 {
562  //Log() << kWarning << DefaultDataSetInfo().GetNClasses() /*fClasses.size()*/ << Endl;
563  if (className=="") {
564  SetSignalWeightExpression(variable);
566  }
567  else DefaultDataSetInfo().SetWeightExpression( variable, className );
568 }
569 
570 ////////////////////////////////////////////////////////////////////////////////
571 
572 void TMVA::DataLoader::SetCut( const TString& cut, const TString& className ) {
573  SetCut( TCut(cut), className );
574 }
575 
576 ////////////////////////////////////////////////////////////////////////////////
577 
578 void TMVA::DataLoader::SetCut( const TCut& cut, const TString& className )
579 {
580  DefaultDataSetInfo().SetCut( cut, className );
581 }
582 
583 ////////////////////////////////////////////////////////////////////////////////
584 
585 void TMVA::DataLoader::AddCut( const TString& cut, const TString& className )
586 {
587  AddCut( TCut(cut), className );
588 }
589 
590 ////////////////////////////////////////////////////////////////////////////////
591 void TMVA::DataLoader::AddCut( const TCut& cut, const TString& className )
592 {
593  DefaultDataSetInfo().AddCut( cut, className );
594 }
595 
596 ////////////////////////////////////////////////////////////////////////////////
597 /// prepare the training and test trees
598 
   Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest,
   const TString& otherOpt )
{
   // Apply `cut` (with AddCut's default class argument) and encode the requested
   // per-class training/test event counts plus otherOpt as the split options
   // of the default DataSetInfo.
   // NOTE(review): this listing shows a gap before AddCut; a statement may be
   // missing there — confirm against the full source.

   AddCut( cut );

   DefaultDataSetInfo().SetSplitOptions( Form("nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:%s",
                                              NsigTrain, NbkgTrain, NsigTest, NbkgTest, otherOpt.Data()) );
}
610 
611 ////////////////////////////////////////////////////////////////////////////////
612 /// prepare the training and test trees
613 /// kept for backward compatibility
614 
{
   // Backward-compatible variant: the same Ntrain/Ntest is requested for signal
   // and background, with a random split and an equal training sample.
   // NOTE(review): this listing shows a gap before AddCut; a statement may be
   // missing there — confirm against the full source.

   AddCut( cut );

   DefaultDataSetInfo().SetSplitOptions( Form("nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:SplitMode=Random:EqualTrainSample:!V",
                                              Ntrain, Ntrain, Ntest, Ntest) );
}
624 
625 ////////////////////////////////////////////////////////////////////////////////
626 /// prepare the training and test trees
627 /// -> same cuts for signal and background
628 
{
   // Same cut for signal and background (AddCut with its default class argument).
   // NOTE(review): only the AddCut call is visible in this listing; other
   // statements of this overload (e.g. handling of the split options) appear to
   // be missing — confirm against the full source.
   AddCut( cut );
}
637 
638 ////////////////////////////////////////////////////////////////////////////////
639 /// prepare the training and test trees
640 
void TMVA::DataLoader::PrepareTrainingAndTestTree( TCut sigcut, TCut bkgcut, const TString& splitOpt )
{
   // Apply separate cuts to "Signal" and "Background" and store splitOpt as the
   // split options of the default DataSetInfo.
   // if event-wise data assignment, add local trees to dataset first
   // NOTE(review): the statement announced by the comment above is not visible
   // in this listing — confirm it exists in the full source.

   //Log() << kINFO <<"Preparing trees for training and testing..."<< Endl;
   AddCut( sigcut, "Signal" );
   AddCut( bkgcut, "Background" );

   DefaultDataSetInfo().SetSplitOptions( splitOpt );
}
652 
653 ////////////////////////////////////////////////////////////////////////////////
654 /// Function required to split the training and testing datasets into a
655 /// number of folds. Required by the CrossValidation and HyperParameterOptimisation
656 /// classes. The option to split the training dataset into a training set and
657 /// a validation set is implemented but not currently used.
658 
659 void TMVA::DataLoader::MakeKFoldDataSet(UInt_t numberFolds, bool validationSet){
660  // No need to do it again if the sets have already been split.
661  if(fMakeFoldDataSet){
662  Log() << kInfo << "Splitting in k-folds has been already done" << Endl;
663  return;
664  }
665 
667 
668  // Get the original event vectors for testing and training from the dataset.
669  const std::vector<Event*> TrainingData = DefaultDataSetInfo().GetDataSet()->GetEventCollection(Types::kTraining);
670  const std::vector<Event*> TestingData = DefaultDataSetInfo().GetDataSet()->GetEventCollection(Types::kTesting);
671 
672  std::vector<Event*> TrainSigData;
673  std::vector<Event*> TrainBkgData;
674  std::vector<Event*> TestSigData;
675  std::vector<Event*> TestBkgData;
676 
677  // Split the testing and training sets into signal and background classes.
678  for(UInt_t i=0; i<TrainingData.size(); ++i){
679  if( strncmp( DefaultDataSetInfo().GetClassInfo( TrainingData.at(i)->GetClass() )->GetName(), "Signal", 6) == 0){ TrainSigData.push_back(TrainingData.at(i)); }
680  else if( strncmp( DefaultDataSetInfo().GetClassInfo( TrainingData.at(i)->GetClass() )->GetName(), "Background", 10) == 0){ TrainBkgData.push_back(TrainingData.at(i)); }
681  else{
682  Log() << kFATAL << "DataSets should only contain Signal and Background classes for classification, " << DefaultDataSetInfo().GetClassInfo( TrainingData.at(i)->GetClass() )->GetName() << " is not a recognised class" << Endl;
683  }
684  }
685 
686  for(UInt_t i=0; i<TestingData.size(); ++i){
687  if( strncmp( DefaultDataSetInfo().GetClassInfo( TestingData.at(i)->GetClass() )->GetName(), "Signal", 6) == 0){ TestSigData.push_back(TestingData.at(i)); }
688  else if( strncmp( DefaultDataSetInfo().GetClassInfo( TestingData.at(i)->GetClass() )->GetName(), "Background", 10) == 0){ TestBkgData.push_back(TestingData.at(i)); }
689  else{
690  Log() << kFATAL << "DataSets should only contain Signal and Background classes for classification, " << DefaultDataSetInfo().GetClassInfo( TestingData.at(i)->GetClass() )->GetName() << " is not a recognised class" << Endl;
691  }
692  }
693 
694 
695  // Split the sets into the number of folds.
696  if(validationSet){
697  std::vector<std::vector<Event*>> tempSigEvents = SplitSets(TrainSigData,0,2);
698  std::vector<std::vector<Event*>> tempBkgEvents = SplitSets(TrainBkgData,0,2);
699  fTrainSigEvents = SplitSets(tempSigEvents.at(0),0,numberFolds);
700  fTrainBkgEvents = SplitSets(tempBkgEvents.at(0),0,numberFolds);
701  fValidSigEvents = SplitSets(tempSigEvents.at(1),0,numberFolds);
702  fValidBkgEvents = SplitSets(tempBkgEvents.at(1),0,numberFolds);
703  }
704  else{
705  fTrainSigEvents = SplitSets(TrainSigData,0,numberFolds);
706  fTrainBkgEvents = SplitSets(TrainBkgData,0,numberFolds);
707  }
708 
709  fTestSigEvents = SplitSets(TestSigData,0,numberFolds);
710  fTestBkgEvents = SplitSets(TestBkgData,0,numberFolds);
711 }
712 
713 ////////////////////////////////////////////////////////////////////////////////
714 /// Function for assigning the correct folds to the testing or training set.
715 
717 
718  UInt_t numFolds = fTrainSigEvents.size();
719 
720  std::vector<Event*>* tempTrain = new std::vector<Event*>;
721  std::vector<Event*>* tempTest = new std::vector<Event*>;
722 
723  UInt_t nTrain = 0;
724  UInt_t nTest = 0;
725 
726  // Get the number of events so the memory can be reserved.
727  for(UInt_t i=0; i<numFolds; ++i){
728  if(tt == Types::kTraining){
729  if(i!=foldNumber){
730  nTrain += fTrainSigEvents.at(i).size();
731  nTrain += fTrainBkgEvents.at(i).size();
732  }
733  else{
734  nTest += fTrainSigEvents.at(i).size();
735  nTest += fTrainSigEvents.at(i).size();
736  }
737  }
738  else if(tt == Types::kValidation){
739  if(i!=foldNumber){
740  nTrain += fValidSigEvents.at(i).size();
741  nTrain += fValidBkgEvents.at(i).size();
742  }
743  else{
744  nTest += fValidSigEvents.at(i).size();
745  nTest += fValidSigEvents.at(i).size();
746  }
747  }
748  else if(tt == Types::kTesting){
749  if(i!=foldNumber){
750  nTrain += fTestSigEvents.at(i).size();
751  nTrain += fTestBkgEvents.at(i).size();
752  }
753  else{
754  nTest += fTestSigEvents.at(i).size();
755  nTest += fTestSigEvents.at(i).size();
756  }
757  }
758  }
759 
760  // Reserve memory before filling vectors
761  tempTrain->reserve(nTrain);
762  tempTest->reserve(nTest);
763 
764  // Fill vectors with correct folds for testing and training.
765  for(UInt_t j=0; j<numFolds; ++j){
766  if(tt == Types::kTraining){
767  if(j!=foldNumber){
768  tempTrain->insert(tempTrain->end(), fTrainSigEvents.at(j).begin(), fTrainSigEvents.at(j).end());
769  tempTrain->insert(tempTrain->end(), fTrainBkgEvents.at(j).begin(), fTrainBkgEvents.at(j).end());
770  }
771  else{
772  tempTest->insert(tempTest->end(), fTrainSigEvents.at(j).begin(), fTrainSigEvents.at(j).end());
773  tempTest->insert(tempTest->end(), fTrainBkgEvents.at(j).begin(), fTrainBkgEvents.at(j).end());
774  }
775  }
776  else if(tt == Types::kValidation){
777  if(j!=foldNumber){
778  tempTrain->insert(tempTrain->end(), fValidSigEvents.at(j).begin(), fValidSigEvents.at(j).end());
779  tempTrain->insert(tempTrain->end(), fValidBkgEvents.at(j).begin(), fValidBkgEvents.at(j).end());
780  }
781  else{
782  tempTest->insert(tempTest->end(), fValidSigEvents.at(j).begin(), fValidSigEvents.at(j).end());
783  tempTest->insert(tempTest->end(), fValidBkgEvents.at(j).begin(), fValidBkgEvents.at(j).end());
784  }
785  }
786  else if(tt == Types::kTesting){
787  if(j!=foldNumber){
788  tempTrain->insert(tempTrain->end(), fTestSigEvents.at(j).begin(), fTestSigEvents.at(j).end());
789  tempTrain->insert(tempTrain->end(), fTestBkgEvents.at(j).begin(), fTestBkgEvents.at(j).end());
790  }
791  else{
792  tempTest->insert(tempTest->end(), fTestSigEvents.at(j).begin(), fTestSigEvents.at(j).end());
793  tempTest->insert(tempTest->end(), fTestBkgEvents.at(j).begin(), fTestBkgEvents.at(j).end());
794  }
795  }
796  }
797 
798  // Assign the vectors of the events to rebuild the dataset
801  delete tempTest;
802  delete tempTrain;
803 }
804 
805 ////////////////////////////////////////////////////////////////////////////////
/// Splits the input vector into equally sized, randomly sampled folds (leftover events are dropped).
807 
808 std::vector<std::vector<TMVA::Event*>> TMVA::DataLoader::SplitSets(std::vector<TMVA::Event*>& oldSet, int seedNum, int numFolds){
809 
810  ULong64_t nEntries = oldSet.size();
811  ULong64_t foldSize = nEntries/numFolds;
812 
813  std::vector<std::vector<Event*>> tempSets;
814  tempSets.resize(numFolds);
815 
816  TRandom3 r(seedNum);
817 
818  ULong64_t inSet = 0;
819 
820  for(ULong64_t i=0; i<nEntries; i++){
821  bool inTree = false;
822  if(inSet == foldSize*numFolds){
823  break;
824  }
825  else{
826  while(!inTree){
827  int s = r.Integer(numFolds);
828  if(tempSets.at(s).size()<foldSize){
829  tempSets.at(s).push_back(oldSet.at(i));
830  inSet++;
831  inTree=true;
832  }
833  }
834  }
835  }
836 
837  return tempSets;
838 
839 }
840 
841 ////////////////////////////////////////////////////////////////////////////////
/// Copy method used in VI and CV
843 
{
   // Create a new loader called `name` and copy this loader's signal and
   // background trees into it via DataLoaderCopy.
   // NOTE(review): the loader is heap-allocated and returned; presumably the
   // caller takes ownership — confirm.
   TMVA::DataLoader* des=new TMVA::DataLoader(name);
   DataLoaderCopy(des,this);
   return des;
}
850 
851 ////////////////////////////////////////////////////////////////////////////////
852 ///Loading Dataset from DataInputHandler for subseed
853 
{
   // Re-register each signal tree of src in des, with its weight and tree type.
   // Only trees are copied here; variables, cuts and weight expressions are not.
   for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Sbegin();treeinfo!=src->DataInput().Send();treeinfo++)
   {
      des->AddSignalTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
   }

   // ... and likewise each background tree.
   for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Bbegin();treeinfo!=src->DataInput().Bend();treeinfo++)
   {
      des->AddBackgroundTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
   }
}
866 
867 ////////////////////////////////////////////////////////////////////////////////
868 /// returns the correlation matrix of datasets
869 
{
   // Fetch the correlation matrix of class className from the default
   // DataSetInfo and turn it into the returned object.
   // NOTE(review): the statement consuming `m` is only partially visible in this
   // listing (only its trailing name/title arguments are shown) — confirm
   // against the full source.
   const TMatrixD * m = DefaultDataSetInfo().CorrelationMatrix(className);
   "CorrelationMatrix"+className, "Correlation Matrix ("+className+")");
}
DataSetInfo * GetDataSetInfo(const TString &dsiName)
returns datasetinfo object for given name
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
Definition: DataLoader.cxx:408
virtual const char * GetName() const
Returns name of object.
Definition: TNamed.h:47
DataSetManager * fDataSetManager
Definition: DataLoader.h:190
virtual ~DataLoader()
Definition: DataLoader.cxx:103
Random number generator class based on M.
Definition: TRandom3.h:27
std::vector< TreeInfo >::const_iterator Bend() const
auto * tt
Definition: textangle.C:16
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
VariableInfo & AddTarget(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis ...
auto * m
Definition: textangle.C:8
void AddTrainingEvent(const TString &className, const std::vector< Double_t > &event, Double_t weight)
add signal training event
Definition: DataLoader.cxx:266
std::vector< VariableInfo > & GetSpectatorInfos()
Definition: DataSetInfo.h:104
std::vector< TreeInfo >::const_iterator Bbegin() const
std::vector< TMVA::VariableTransformBase * > fDefaultTrfs
Definition: DataLoader.h:195
DataLoader(TString thedlName="default")
Definition: DataLoader.cxx:86
void SetCut(const TCut &cut, const TString &className)
set the cut for the classes
TMVA::DataLoader * VarianceThreshold(Double_t threshold)
Computes variance of all the variables and returns a new DataLoader with the selected variables whose...
std::vector< std::vector< TMVA::Event * > > fTrainBkgEvents
Definition: DataLoader.h:211
virtual void SetName(const char *name)
Set the name of the TNamed.
Definition: TNamed.cxx:140
MsgLogger & Log() const
Definition: Configurable.h:122
DataSetInfo & GetDataSetInfo()
Definition: DataLoader.cxx:144
Bool_t IsFloat() const
Returns kTRUE if string contains a floating point or integer number.
Definition: TString.cxx:1845
TTree * CreateEventAssignTrees(const TString &name)
create the data assignment tree (for event-wise data assignment by user)
Definition: DataLoader.cxx:201
DataSetInfo & DefaultDataSetInfo()
default creation
Definition: DataLoader.cxx:530
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
Definition: TString.h:585
DataLoader * VarTransform(TString trafoDefinition)
Transforms the variables and return a new DataLoader with the transformed variables.
Definition: DataLoader.cxx:153
Basic string class.
Definition: TString.h:125
void ToLower()
Change string to lower-case.
Definition: TString.cxx:1099
int Int_t
Definition: RtypesCore.h:41
void MakeKFoldDataSet(UInt_t numberFolds, bool validationSet=false)
Function required to split the training and testing datasets into a number of folds.
Definition: DataLoader.cxx:659
bool Bool_t
Definition: RtypesCore.h:59
void DataLoaderCopy(TMVA::DataLoader *des, TMVA::DataLoader *src)
std::vector< TreeInfo >::const_iterator Send() const
void AddCut(const TCut &cut, const TString &className)
set the cut for the classes
void SetBackgroundTree(TTree *background, Double_t weight=1.0)
Definition: DataLoader.cxx:445
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Definition: DataSet.h:225
DataInputHandler * fDataInputHandler
Definition: DataLoader.h:193
Types::EAnalysisType fAnalysisType
Definition: DataLoader.h:221
void AddBackgroundTestEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
add signal training event
Definition: DataLoader.cxx:258
TH2 * GetCorrelationMatrix(const TString &className)
returns the correlation matrix of datasets
Definition: DataLoader.cxx:870
std::vector< std::vector< TMVA::Event * > > fTestBkgEvents
Definition: DataLoader.h:215
MsgLogger * fLogger
Definition: Configurable.h:128
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Definition: DataLoader.cxx:491
std::vector< TreeInfo >::const_iterator Sbegin() const
VariableInfo & AddSpectator(const TString &expression, const TString &title, const TString &unit, Double_t min, Double_t max, char type='F', Bool_t normalized=kTRUE, void *external=0)
add a spectator (can be a complex expression) to the set of spectator variables used in the MV analys...
Class that contains all the information of a class.
Definition: ClassInfo.h:49
TH2 * CreateCorrelationMatrixHist(const TMatrixD *m, const TString &hName, const TString &hTitle) const
void AddTestEvent(const TString &className, const std::vector< Double_t > &event, Double_t weight)
add signal test event
Definition: DataLoader.cxx:274
void SetInputTrees(const TString &signalFileName, const TString &backgroundFileName, Double_t signalWeight=1.0, Double_t backgroundWeight=1.0)
Definition: DataLoader.cxx:470
virtual UInt_t Integer(UInt_t imax)
Returns a random integer on [ 0, imax-1 ].
Definition: TRandom.cxx:341
void SetTree(TTree *tree, const TString &className, Double_t weight)
set background tree
Definition: DataLoader.cxx:453
void PrepareFoldDataSet(UInt_t foldNumber, Types::ETreeType tt)
Function for assigning the correct folds to the testing or training set.
Definition: DataLoader.cxx:716
Class that contains all the data information.
Definition: DataSetInfo.h:60
void SetInputVariables(std::vector< TString > *theVariables)
fill input variables in data set
Definition: DataLoader.cxx:538
DataSetInfo & AddDataSet(DataSetInfo &)
Definition: DataLoader.cxx:126
std::vector< VariableInfo > & GetTargetInfos()
Definition: DataSetInfo.h:99
void AddCut(const TString &cut, const TString &className="")
Definition: DataLoader.cxx:585
A specialized string object used for TTree selections.
Definition: TCut.h:25
void SetInputTreesFromEventAssignTrees()
assign event-wise local trees to data set
Definition: DataLoader.cxx:325
Float_t fATreeWeight
Definition: DataLoader.h:218
void SetSplitOptions(const TString &so)
Definition: DataSetInfo.h:166
Bool_t fMakeFoldDataSet
Definition: DataLoader.h:223
DataInputHandler & DataInput()
Definition: DataLoader.h:176
ROOT::R::TRInterface & r
Definition: Object.C:4
Service class for 2-Dim histogram classes.
Definition: TH2.h:30
const Int_t kInfo
Definition: TError.h:37
Class that contains all the data information.
ClassInfo * GetClassInfo(Int_t clNum) const
const TMatrixD * CorrelationMatrix(const TString &className) const
void AddTree(TTree *tree, const TString &className, Double_t weight=1.0, const TCut &cut="", Types::ETreeType tt=Types::kMaxTreeType)
add tree of className events for tt (Training;Testing..) type as input ..
void SetWeightExpression(const TString &exp, const TString &className="")
set the weight expressions for the classes if class name is specified, set only for this class if cla...
DataSetInfo & AddDataSetInfo(DataSetInfo &dsi)
stores a copy of the dataset info object
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
std::vector< TTree * > fTestAssignTree
Definition: DataLoader.h:208
Bool_t UserAssignEvents(UInt_t clIndex)
Definition: DataLoader.cxx:317
std::vector< Float_t > fATreeEvent
Definition: DataLoader.h:219
std::vector< std::vector< TMVA::Event * > > fTestSigEvents
Definition: DataLoader.h:214
void PrintClasses() const
TString fName
Definition: TNamed.h:32
const Bool_t kFALSE
Definition: RtypesCore.h:88
DataLoader * MakeCopy(TString name)
Copy method use in VI and CV.
Definition: DataLoader.cxx:844
TString & Remove(Ssiz_t pos)
Definition: TString.h:619
int Ssiz_t
Definition: RtypesCore.h:63
void AddTree(TTree *tree, const TString &className, Double_t weight=1.0, const TCut &cut="", Types::ETreeType tt=Types::kMaxTreeType)
Definition: DataLoader.cxx:357
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
Definition: DataLoader.cxx:629
virtual void SetDirectory(TDirectory *dir)
Change the tree&#39;s directory.
Definition: TTree.cxx:8464
#define ClassImp(name)
Definition: Rtypes.h:359
double Double_t
Definition: RtypesCore.h:55
void AddEvent(const TString &className, Types::ETreeType tt, const std::vector< Double_t > &event, Double_t weight)
add event vector event : the order of values is: variables + targets + spectators ...
Definition: DataLoader.cxx:283
Class that contains all the data information.
void SetBackgroundWeightExpression(const TString &variable)
Definition: DataLoader.cxx:553
int type
Definition: TGX11.cxx:120
unsigned long long ULong64_t
Definition: RtypesCore.h:70
virtual Long64_t ReadFile(const char *filename, const char *branchDescriptor="", char delimiter=' ')
Create or simply read branches from filename.
Definition: TTree.cxx:7172
void AddTarget(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts target in data set info
Definition: DataLoader.cxx:509
void SetWeightExpression(const TString &variable, const TString &className="")
Definition: DataLoader.cxx:560
void AddBackgroundTrainingEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
add signal training event
Definition: DataLoader.cxx:250
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition: TString.h:570
static constexpr double s
void SetEventCollection(std::vector< Event *> *, Types::ETreeType, Bool_t deleteEvents=true)
Sets the event collection (by DataSetFactory)
Definition: DataSet.cxx:250
ClassInfo * AddClass(const TString &className)
UInt_t GetNumber() const
Definition: ClassInfo.h:65
void SetSignalWeightExpression(const TString &variable)
Definition: DataLoader.cxx:546
virtual Long64_t GetEntries() const
Definition: TTree.h:382
std::vector< std::vector< TMVA::Event * > > fTrainSigEvents
Definition: DataLoader.h:210
virtual Int_t Branch(TCollection *list, Int_t bufsize=32000, Int_t splitlevel=99, const char *name="")
Create one branch for each element in the collection.
Definition: TTree.cxx:1701
void SetSource(const std::string &source)
Definition: MsgLogger.h:70
std::vector< TTree * > fTrainAssignTree
Definition: DataLoader.h:207
void AddSignalTestEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
add signal testing event
Definition: DataLoader.cxx:242
std::vector< std::vector< TMVA::Event * > > fValidBkgEvents
Definition: DataLoader.h:213
void AddSignalTrainingEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
add signal training event
Definition: DataLoader.cxx:234
friend void DataLoaderCopy(TMVA::DataLoader *des, TMVA::DataLoader *src)
void SetSignalTree(TTree *signal, Double_t weight=1.0)
Definition: DataLoader.cxx:438
VariableInfo & AddVariable(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0, char varType='F', Bool_t normalized=kTRUE, void *external=0)
add a variable (can be a complex expression) to the set of variables used in the MV analysis ...
Definition: tree.py:1
Double_t Atof() const
Return floating-point value contained in string.
Definition: TString.cxx:2041
A TTree object has a header with a name and a title.
Definition: TTree.h:70
std::vector< std::vector< TMVA::Event * > > SplitSets(std::vector< TMVA::Event *> &oldSet, int seedNum, int numFolds)
Splits the input vector in to equally sized randomly sampled folds.
Definition: DataLoader.cxx:808
std::vector< std::vector< TMVA::Event * > > fValidSigEvents
Definition: DataLoader.h:212
const Bool_t kTRUE
Definition: RtypesCore.h:87
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
Definition: DataLoader.cxx:377
DataSet * GetDataSet() const
returns data set
std::vector< VariableInfo > & GetVariableInfos()
Definition: DataSetInfo.h:94
void SetCut(const TString &cut, const TString &className="")
Definition: DataLoader.cxx:572
char name[80]
Definition: TGX11.cxx:109
void AddSpectator(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts target in data set info
Definition: DataLoader.cxx:521
const char * Data() const
Definition: TString.h:345