// @(#)root/tmva $Id$
// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Eckhard von Toerne, Helge Voss

/*****************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
 * Package: TMVA                                                             *
 * Class  : DataSetFactory                                                   *
 * Web    : http://tmva.sourceforge.net                                     *
 *                                                                           *
 * Description:                                                              *
 *      Implementation (see header for description)                         *
 *                                                                           *
 * Authors (alphabetical):                                                   *
 *      Andreas Hoecker  <Andreas.Hocker@cern.ch>   - CERN, Switzerland     *
 *      Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland     *
 *      Joerg Stelzer    <Joerg.Stelzer@cern.ch>    - MSU, USA              *
 *      Eckhard von Toerne <evt@physik.uni-bonn.de> - U. of Bonn, Germany   *
 *      Helge Voss       <Helge.Voss@cern.ch>  - MPI-K Heidelberg, Germany  *
 *                                                                           *
 * Copyright (c) 2009:                                                       *
 *      CERN, Switzerland                                                    *
 *      MPI-K Heidelberg, Germany                                            *
 *      U. of Bonn, Germany                                                  *
 * Redistribution and use in source and binary forms, with or without       *
 * modification, are permitted according to the terms listed in LICENSE     *
 * (http://tmva.sourceforge.net/LICENSE)                                    *
 *****************************************************************************/

/*! \class TMVA::DataSetFactory
\ingroup TMVA

Class that contains all the data information.

It builds the DataSet from the user input: it evaluates the variable, target,
spectator, cut and weight expressions on the input trees, splits the events
into training and test samples, and computes basic statistics such as
variable ranges and per-class correlation matrices.

*/

#include <assert.h>

#include <map>
#include <vector>
#include <iomanip>
#include <iostream>

#include <algorithm>
#include <functional>
#include <numeric>
#include <random>

#include "TMVA/DataSetFactory.h"

#include "TEventList.h"
#include "TFile.h"
#include "TH1.h"
#include "TH2.h"
#include "TProfile.h"
#include "TRandom3.h"
#include "TMatrixF.h"
#include "TVectorF.h"
#include "TMath.h"
#include "TROOT.h"

#include "TMVA/MsgLogger.h"
#include "TMVA/Configurable.h"
#include "TMVA/DataInputHandler.h"
#include "TMVA/DataSet.h"
#include "TMVA/DataSetInfo.h"
#include "TMVA/Event.h"

#include "TMVA/Tools.h"
#include "TMVA/Types.h"
#include "TMVA/VariableInfo.h"

using namespace std;

//TMVA::DataSetFactory* TMVA::DataSetFactory::fgInstance = 0;

namespace TMVA {
   // calculate the largest common divider (greatest common divisor);
   // note: this function does not handle negative numbers!
   Int_t LargestCommonDivider(Int_t a, Int_t b)
   {
      if (a<b) {Int_t tmp = a; a=b; b=tmp; } // achieve a>=b
      if (b==0) return a;
      Int_t fullFits = a/b;
      return LargestCommonDivider(b,a-b*fullFits);
   }
}
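
// Worked example (illustrative, not part of the original file): the recursion
// above is the Euclidean algorithm written via the remainder a - b*(a/b):
//   LargestCommonDivider(36, 24) -> LargestCommonDivider(24, 12)
//                                -> LargestCommonDivider(12, 0)  -> 12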


////////////////////////////////////////////////////////////////////////////////
/// constructor

TMVA::DataSetFactory::DataSetFactory() :
   fVerbose(kFALSE),
   fVerboseLevel(TString("Info")),
   fScaleWithPreselEff(0),
   fCurrentTree(0),
   fCurrentEvtIdx(0),
   fInputFormulas(0),
   fLogger( new MsgLogger("DataSetFactory", kINFO) )
{
}

////////////////////////////////////////////////////////////////////////////////
/// destructor

TMVA::DataSetFactory::~DataSetFactory()
{
   std::vector<TTreeFormula*>::const_iterator formIt;

   for (formIt = fInputFormulas.begin()    ; formIt != fInputFormulas.end()    ; ++formIt) if (*formIt) delete *formIt;
   for (formIt = fTargetFormulas.begin()   ; formIt != fTargetFormulas.end()   ; ++formIt) if (*formIt) delete *formIt;
   for (formIt = fCutFormulas.begin()      ; formIt != fCutFormulas.end()      ; ++formIt) if (*formIt) delete *formIt;
   for (formIt = fWeightFormula.begin()    ; formIt != fWeightFormula.end()    ; ++formIt) if (*formIt) delete *formIt;
   for (formIt = fSpectatorFormulas.begin(); formIt != fSpectatorFormulas.end(); ++formIt) if (*formIt) delete *formIt;

   delete fLogger;
}

////////////////////////////////////////////////////////////////////////////////
/// steering the creation of a new dataset

TMVA::DataSet* TMVA::DataSetFactory::CreateDataSet( TMVA::DataSetInfo& dsi,
                                                    TMVA::DataInputHandler& dataInput )
{
   // build the first dataset from the data input
   DataSet * ds = BuildInitialDataSet( dsi, dataInput );

   if (ds->GetNEvents() > 1 && fComputeCorrelations ) {
      CalcMinMax(ds,dsi);

      // from the final dataset build the correlation matrix
      for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
         const TString className = dsi.GetClassInfo(cl)->GetName();
         dsi.SetCorrelationMatrix( className, CalcCorrelationMatrix( ds, cl ) );
         if (fCorrelations) {
            dsi.PrintCorrelationMatrix(className);
         }
      }
      //Log() << kHEADER << Endl;
      Log() << kHEADER << Form("[%s] : ",dsi.GetName()) << " " << Endl << Endl;
   }

   return ds;
}
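
// Illustrative usage sketch (an assumption -- in practice this factory is
// driven by the TMVA::DataLoader/Factory machinery rather than called
// directly; `dsi` and `dataInput` stand for whatever the caller has set up):
//   TMVA::DataSetFactory dsf;
//   TMVA::DataSet* ds = dsf.CreateDataSet(dsi, dataInput);
//   std::cout << "collected " << ds->GetNEvents() << " events" << std::endl;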

////////////////////////////////////////////////////////////////////////////////
/// Build a DataSet consisting of a single Event which uses dynamically
/// changing variables (pointers to external variables).

TMVA::DataSet* TMVA::DataSetFactory::BuildDynamicDataSet( TMVA::DataSetInfo& dsi )
{
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "Build DataSet consisting of one Event with dynamically changing variables" << Endl;
   DataSet* ds = new DataSet(dsi);

   // create a DataSet with one Event which uses dynamic variables
   // (pointers to variables)
   if(dsi.GetNClasses()==0){
      dsi.AddClass( "data" );
      dsi.GetClassInfo( "data" )->SetNumber(0);
   }

   std::vector<Float_t*>* evdyn = new std::vector<Float_t*>(0);

   std::vector<VariableInfo>& varinfos = dsi.GetVariableInfos();

   if (varinfos.empty())
      Log() << kFATAL << Form("Dataset[%s] : ",dsi.GetName()) << "Dynamic data set cannot be built, since no variable information is present. Apparently no variables have been set. This should not happen, please contact the TMVA authors." << Endl;

   std::vector<VariableInfo>::iterator it = varinfos.begin(), itEnd=varinfos.end();
   for (;it!=itEnd;++it) {
      Float_t* external=(Float_t*)(*it).GetExternalLink();
      if (external==0)
         Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "The link to the external variable is NULL while I am trying to build a dynamic data set. In this case fTmpEvent from MethodBase HAS TO BE USED in the method to get useful values in variables." << Endl;
      else evdyn->push_back (external);
   }

   std::vector<VariableInfo>& spectatorinfos = dsi.GetSpectatorInfos();
   it = spectatorinfos.begin();
   for (;it!=spectatorinfos.end();++it) evdyn->push_back( (Float_t*)(*it).GetExternalLink() );

   TMVA::Event * ev = new Event((const std::vector<Float_t*>*&)evdyn, varinfos.size());
   std::vector<Event*>* newEventVector = new std::vector<Event*>;
   newEventVector->push_back(ev);

   ds->SetEventCollection(newEventVector, Types::kTraining);
   ds->SetCurrentEvent( 0 );

   delete newEventVector;
   return ds;
}
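
// Illustrative sketch (assumption): "dynamic" variables are read through
// external pointers registered on the VariableInfo objects, so user memory is
// viewed directly instead of copying event data:
//   Float_t x = 0.f;
//   dsi.GetVariableInfo(0).SetExternalLink(&x); // hypothetical setup call
//   DataSet* ds = BuildDynamicDataSet(dsi);     // one Event, linked to &x
//   x = 3.14f;  // the Event now "sees" 3.14 without any copy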

////////////////////////////////////////////////////////////////////////////////
/// if no entries, then create a DataSet with one Event which uses
/// dynamic variables (pointers to variables)

TMVA::DataSet* TMVA::DataSetFactory::BuildInitialDataSet( DataSetInfo& dsi,
                                                          DataInputHandler& dataInput )
{
   if (dataInput.GetEntries()==0) return BuildDynamicDataSet( dsi );
   // -------------------------------------------------------------------------

   // register the classes in the datasetinfo-object
   // information comes from the trees in the dataInputHandler-object
   std::vector< TString >* classList = dataInput.GetClassList();
   for (std::vector<TString>::iterator it = classList->begin(); it != classList->end(); ++it) {
      dsi.AddClass( (*it) );
   }
   delete classList;

   EvtStatsPerClass eventCounts(dsi.GetNClasses());
   TString normMode;
   TString splitMode;
   TString mixMode;
   UInt_t  splitSeed;

   InitOptions( dsi, eventCounts, normMode, splitSeed, splitMode, mixMode );
   // ======= build event-vector from input, apply preselection ===============
   EventVectorOfClassesOfTreeType tmpEventVector;
   BuildEventVector( dsi, dataInput, tmpEventVector, eventCounts );

   DataSet* ds = MixEvents( dsi, tmpEventVector, eventCounts,
                            splitMode, mixMode, normMode, splitSeed );

   const Bool_t showCollectedOutput = kFALSE;
   if (showCollectedOutput) {
      Int_t maxL = dsi.GetClassNameMaxLength();
      Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "Collected:" << Endl;
      for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
         Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "    "
               << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
               << " training entries: " << ds->GetNClassEvents( 0, cl ) << Endl;
         Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "    "
               << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
               << " testing entries: " << ds->GetNClassEvents( 1, cl ) << Endl;
      }
      Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << " " << Endl;
   }

   return ds;
}
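
// Illustrative example (assumption): the split/mix/norm options consumed by
// InitOptions() above typically originate from a user call such as
//   dataloader->PrepareTrainingAndTestTree("",
//       "nTrain_Signal=1000:nTrain_Background=1000:"
//       "SplitMode=Random:NormMode=NumEvents:!V");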

////////////////////////////////////////////////////////////////////////////////
/// checks a TTreeFormula for problems

Bool_t TMVA::DataSetFactory::CheckTTreeFormula( TTreeFormula* ttf,
                                                const TString& expression,
                                                Bool_t& hasDollar )
{
   Bool_t worked = kTRUE;

   if( ttf->GetNdim() <= 0 )
      Log() << kFATAL << "Expression " << expression.Data()
            << " could not be resolved to a valid formula. " << Endl;
   if( ttf->GetNdata() == 0 ){
      Log() << kWARNING << "Expression: " << expression.Data()
            << " does not provide data for this event. "
            << "This event is not taken into account. --> please check if the variable you use is "
            << "an entry of an array which is not filled for some events "
            << "(e.g. arr[4] when arr has only 3 elements)." << Endl;
      Log() << kWARNING << "If you want to take the event into account you can do something like: "
            << "\"Alt$(arr[4],0)\", where 0 is taken as an alternative whenever arr "
            << "doesn't have a 4th element." << Endl;
      worked = kFALSE;
   }
   if( expression.Contains("$") )
      hasDollar = kTRUE;
   else
   {
      for (int i = 0, iEnd = ttf->GetNcodes (); i < iEnd; ++i)
      {
         TLeaf* leaf = ttf->GetLeaf (i);
         if (leaf && !leaf->IsOnTerminalBranch())
            hasDollar = kTRUE;
      }
   }
   return worked;
}
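
// Illustrative example: the Alt$ guard recommended in the warning above.
// Declaring the variable as
//   dataloader->AddVariable("Alt$(arr[4],0)", 'F');
// yields arr[4] where it exists and 0 otherwise, so such events are kept.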


////////////////////////////////////////////////////////////////////////////////
/// While the data gets copied into the local training and testing
/// trees, the input tree can change (for instance when changing from
/// signal to background tree, or when using TChains as input). The
/// TTreeFormulas that hold the input expressions need to be
/// re-associated with the new tree, which is done here.

void TMVA::DataSetFactory::ChangeToNewTree( TreeInfo& tinfo, const DataSetInfo & dsi )
{
   TTree *tr = tinfo.GetTree()->GetTree();

   tr->SetBranchStatus("*",1);

   Bool_t hasDollar = kFALSE;

   // 1) the input variable formulas
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "transform input variables" << Endl;
   std::vector<TTreeFormula*>::const_iterator formIt, formItEnd;
   for (formIt = fInputFormulas.begin(), formItEnd=fInputFormulas.end(); formIt!=formItEnd; ++formIt) if (*formIt) delete *formIt;
   fInputFormulas.clear();
   TTreeFormula* ttf = 0;

   for (UInt_t i=0; i<dsi.GetNVariables(); i++) {
      ttf = new TTreeFormula( Form( "Formula%s", dsi.GetVariableInfo(i).GetInternalName().Data() ),
                              dsi.GetVariableInfo(i).GetExpression().Data(), tr );
      CheckTTreeFormula( ttf, dsi.GetVariableInfo(i).GetExpression(), hasDollar );
      fInputFormulas.push_back( ttf );
   }

   //
   // targets
   //
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "transform regression targets" << Endl;
   for (formIt = fTargetFormulas.begin(), formItEnd = fTargetFormulas.end(); formIt!=formItEnd; ++formIt) if (*formIt) delete *formIt;
   fTargetFormulas.clear();
   for (UInt_t i=0; i<dsi.GetNTargets(); i++) {
      ttf = new TTreeFormula( Form( "Formula%s", dsi.GetTargetInfo(i).GetInternalName().Data() ),
                              dsi.GetTargetInfo(i).GetExpression().Data(), tr );
      CheckTTreeFormula( ttf, dsi.GetTargetInfo(i).GetExpression(), hasDollar );
      fTargetFormulas.push_back( ttf );
   }

   //
   // spectators
   //
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "transform spectator variables" << Endl;
   for (formIt = fSpectatorFormulas.begin(), formItEnd = fSpectatorFormulas.end(); formIt!=formItEnd; ++formIt) if (*formIt) delete *formIt;
   fSpectatorFormulas.clear();
   for (UInt_t i=0; i<dsi.GetNSpectators(); i++) {
      ttf = new TTreeFormula( Form( "Formula%s", dsi.GetSpectatorInfo(i).GetInternalName().Data() ),
                              dsi.GetSpectatorInfo(i).GetExpression().Data(), tr );
      CheckTTreeFormula( ttf, dsi.GetSpectatorInfo(i).GetExpression(), hasDollar );
      fSpectatorFormulas.push_back( ttf );
   }

   //
   // the cuts (one per class, if non-existent: formula pointer = 0)
   //
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "transform cuts" << Endl;
   for (formIt = fCutFormulas.begin(), formItEnd = fCutFormulas.end(); formIt!=formItEnd; ++formIt) if (*formIt) delete *formIt;
   fCutFormulas.clear();
   for (UInt_t clIdx=0; clIdx<dsi.GetNClasses(); clIdx++) {
      const TCut& tmpCut = dsi.GetClassInfo(clIdx)->GetCut();
      const TString tmpCutExp(tmpCut.GetTitle());
      ttf = 0;
      if (tmpCutExp!="") {
         ttf = new TTreeFormula( Form("CutClass%i",clIdx), tmpCutExp, tr );
         Bool_t worked = CheckTTreeFormula( ttf, tmpCutExp, hasDollar );
         if( !worked ){
            Log() << kWARNING << "Please check class \"" << dsi.GetClassInfo(clIdx)->GetName()
                  << "\" cut \"" << dsi.GetClassInfo(clIdx)->GetCut() << "\"" << Endl;
         }
      }
      fCutFormulas.push_back( ttf );
   }

   //
   // the weights (one per class, if non-existent: formula pointer = 0)
   //
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "transform weights" << Endl;
   for (formIt = fWeightFormula.begin(), formItEnd = fWeightFormula.end(); formIt!=formItEnd; ++formIt) if (*formIt) delete *formIt;
   fWeightFormula.clear();
   for (UInt_t clIdx=0; clIdx<dsi.GetNClasses(); clIdx++) {
      const TString tmpWeight = dsi.GetClassInfo(clIdx)->GetWeight();

      if (dsi.GetClassInfo(clIdx)->GetName() != tinfo.GetClassName() ) { // if the tree is of another class
         fWeightFormula.push_back( 0 );
         continue;
      }

      ttf = 0;
      if (tmpWeight!="") {
         ttf = new TTreeFormula( "FormulaWeight", tmpWeight, tr );
         Bool_t worked = CheckTTreeFormula( ttf, tmpWeight, hasDollar );
         if( !worked ){
            Log() << kWARNING << Form("Dataset[%s] : ",dsi.GetName()) << "Please check class \"" << dsi.GetClassInfo(clIdx)->GetName()
                  << "\" weight \"" << dsi.GetClassInfo(clIdx)->GetWeight() << "\"" << Endl;
         }
      }
      else {
         ttf = 0;
      }
      fWeightFormula.push_back( ttf );
   }
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "enable branches" << Endl;
   // now enable only branches that are needed in any input formula, target, cut, weight

   if (!hasDollar) {
      tr->SetBranchStatus("*",0);
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "enable branches: input variables" << Endl;
      // input vars
      for (formIt = fInputFormulas.begin(); formIt!=fInputFormulas.end(); ++formIt) {
         ttf = *formIt;
         for (Int_t bi = 0; bi<ttf->GetNcodes(); bi++) {
            tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
         }
      }
      // targets
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "enable branches: targets" << Endl;
      for (formIt = fTargetFormulas.begin(); formIt!=fTargetFormulas.end(); ++formIt) {
         ttf = *formIt;
         for (Int_t bi = 0; bi<ttf->GetNcodes(); bi++)
            tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
      }
      // spectators
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "enable branches: spectators" << Endl;
      for (formIt = fSpectatorFormulas.begin(); formIt!=fSpectatorFormulas.end(); ++formIt) {
         ttf = *formIt;
         for (Int_t bi = 0; bi<ttf->GetNcodes(); bi++)
            tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
      }
      // cuts
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "enable branches: cuts" << Endl;
      for (formIt = fCutFormulas.begin(); formIt!=fCutFormulas.end(); ++formIt) {
         ttf = *formIt;
         if (!ttf) continue;
         for (Int_t bi = 0; bi<ttf->GetNcodes(); bi++)
            tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
      }
      // weights
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "enable branches: weights" << Endl;
      for (formIt = fWeightFormula.begin(); formIt!=fWeightFormula.end(); ++formIt) {
         ttf = *formIt;
         if (!ttf) continue;
         for (Int_t bi = 0; bi<ttf->GetNcodes(); bi++)
            tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
      }
   }
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName()) << "tree initialized" << Endl;
   return;
}
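
// Background sketch (illustrative): a TTreeFormula is bound to one concrete
// TTree, so after a TChain crosses a file boundary the formula's leaves must
// be re-attached, roughly:
//   TChain ch("tree"); ch.Add("a.root"); ch.Add("b.root");
//   TTreeFormula f("f", "x+y", &ch);
//   ch.LoadTree(entryInSecondFile); // switches the underlying tree
//   f.UpdateFormulaLeaves();        // re-associate leaves with the new tree
// ChangeToNewTree() achieves the same effect by rebuilding the formulas.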

////////////////////////////////////////////////////////////////////////////////
/// determine the minimum and maximum values of all input variables,
/// regression targets and spectators

void TMVA::DataSetFactory::CalcMinMax( DataSet* ds, TMVA::DataSetInfo& dsi )
{
   const UInt_t nvar  = ds->GetNVariables();
   const UInt_t ntgts = ds->GetNTargets();
   const UInt_t nvis  = ds->GetNSpectators();

   Float_t *min = new Float_t[nvar];
   Float_t *max = new Float_t[nvar];
   Float_t *tgmin = new Float_t[ntgts];
   Float_t *tgmax = new Float_t[ntgts];
   Float_t *vmin  = new Float_t[nvis];
   Float_t *vmax  = new Float_t[nvis];

   for (UInt_t ivar=0; ivar<nvar ; ivar++) {   min[ivar] = FLT_MAX;   max[ivar] = -FLT_MAX; }
   for (UInt_t ivar=0; ivar<ntgts; ivar++) { tgmin[ivar] = FLT_MAX; tgmax[ivar] = -FLT_MAX; }
   for (UInt_t ivar=0; ivar<nvis;  ivar++) {  vmin[ivar] = FLT_MAX;  vmax[ivar] = -FLT_MAX; }

   // perform event loop

   for (Int_t i=0; i<ds->GetNEvents(); i++) {
      const Event * ev = ds->GetEvent(i);
      for (UInt_t ivar=0; ivar<nvar; ivar++) {
         Double_t v = ev->GetValue(ivar);
         if (v<min[ivar]) min[ivar] = v;
         if (v>max[ivar]) max[ivar] = v;
      }
      for (UInt_t itgt=0; itgt<ntgts; itgt++) {
         Double_t v = ev->GetTarget(itgt);
         if (v<tgmin[itgt]) tgmin[itgt] = v;
         if (v>tgmax[itgt]) tgmax[itgt] = v;
      }
      for (UInt_t ivis=0; ivis<nvis; ivis++) {
         Double_t v = ev->GetSpectator(ivis);
         if (v<vmin[ivis]) vmin[ivis] = v;
         if (v>vmax[ivis]) vmax[ivis] = v;
      }
   }

   for (UInt_t ivar=0; ivar<nvar; ivar++) {
      dsi.GetVariableInfo(ivar).SetMin(min[ivar]);
      dsi.GetVariableInfo(ivar).SetMax(max[ivar]);
      if( TMath::Abs(max[ivar]-min[ivar]) <= FLT_MIN )
         Log() << kWARNING << Form("Dataset[%s] : ",dsi.GetName()) << "Variable " << dsi.GetVariableInfo(ivar).GetExpression().Data() << " is constant. Please remove the variable." << Endl;
   }
   for (UInt_t ivar=0; ivar<ntgts; ivar++) {
      dsi.GetTargetInfo(ivar).SetMin(tgmin[ivar]);
      dsi.GetTargetInfo(ivar).SetMax(tgmax[ivar]);
      if( TMath::Abs(tgmax[ivar]-tgmin[ivar]) <= FLT_MIN )
         Log() << kFATAL << Form("Dataset[%s] : ",dsi.GetName()) << "Target " << dsi.GetTargetInfo(ivar).GetExpression().Data() << " is constant. Please remove the variable." << Endl;
   }
   for (UInt_t ivar=0; ivar<nvis; ivar++) {
      dsi.GetSpectatorInfo(ivar).SetMin(vmin[ivar]);
      dsi.GetSpectatorInfo(ivar).SetMax(vmax[ivar]);
      // if( TMath::Abs(vmax[ivar]-vmin[ivar]) <= FLT_MIN )
      //    Log() << kWARNING << "Spectator variable " << dsi.GetSpectatorInfo(ivar).GetExpression().Data() << " is constant." << Endl;
   }
   delete [] min;
   delete [] max;
   delete [] tgmin;
   delete [] tgmax;
   delete [] vmin;
   delete [] vmax;
}

////////////////////////////////////////////////////////////////////////////////
/// computes the correlation matrix of the input variables in dataset `ds`,
/// using only the events of class `classNumber`

TMatrixD* TMVA::DataSetFactory::CalcCorrelationMatrix( DataSet* ds, const UInt_t classNumber )
{
   // first compute variance-covariance
   TMatrixD* mat = CalcCovarianceMatrix( ds, classNumber );

   // now the correlation
   UInt_t nvar = ds->GetNVariables(), ivar, jvar;

   for (ivar=0; ivar<nvar; ivar++) {
      for (jvar=0; jvar<nvar; jvar++) {
         if (ivar != jvar) {
            Double_t d = (*mat)(ivar, ivar)*(*mat)(jvar, jvar);
            if (d > 0) (*mat)(ivar, jvar) /= sqrt(d);
            else {
               Log() << kWARNING << Form("Dataset[%s] : ",DataSetInfo().GetName())<< "<GetCorrelationMatrix> Zero variances for variables "
                     << "(" << ivar << ", " << jvar << ") = " << d
                     << Endl;
               (*mat)(ivar, jvar) = 0;
            }
         }
      }
   }

   for (ivar=0; ivar<nvar; ivar++) (*mat)(ivar, ivar) = 1.0;

   return mat;
}
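
// The normalisation above is the standard covariance-to-correlation
// conversion
//   rho(i,j) = C(i,j) / sqrt( C(i,i) * C(j,j) ),   rho(i,i) = 1,
// with rho(i,j) forced to 0 whenever one of the variances vanishes.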

////////////////////////////////////////////////////////////////////////////////
/// compute covariance matrix

TMatrixD* TMVA::DataSetFactory::CalcCovarianceMatrix( DataSet* ds, const UInt_t classNumber )
{
   UInt_t nvar = ds->GetNVariables();
   UInt_t ivar = 0, jvar = 0;

   TMatrixD* mat = new TMatrixD( nvar, nvar );

   // init matrices
   TVectorD vec(nvar);
   TMatrixD mat2(nvar, nvar);
   for (ivar=0; ivar<nvar; ivar++) {
      vec(ivar) = 0;
      for (jvar=0; jvar<nvar; jvar++) mat2(ivar, jvar) = 0;
   }

   // perform event loop
   Double_t ic = 0;
   for (Int_t i=0; i<ds->GetNEvents(); i++) {

      const Event * ev = ds->GetEvent(i);
      if (ev->GetClass() != classNumber ) continue;

      Double_t weight = ev->GetWeight();
      ic += weight; // count used events

      for (ivar=0; ivar<nvar; ivar++) {

         Double_t xi = ev->GetValue(ivar);
         vec(ivar) += xi*weight;
         mat2(ivar, ivar) += (xi*xi*weight);

         for (jvar=ivar+1; jvar<nvar; jvar++) {
            Double_t xj = ev->GetValue(jvar);
            mat2(ivar, jvar) += (xi*xj*weight);
         }
      }
   }

   for (ivar=0; ivar<nvar; ivar++)
      for (jvar=ivar+1; jvar<nvar; jvar++)
         mat2(jvar, ivar) = mat2(ivar, jvar); // symmetric matrix


   // variance-covariance
   for (ivar=0; ivar<nvar; ivar++) {
      for (jvar=0; jvar<nvar; jvar++) {
         (*mat)(ivar, jvar) = mat2(ivar, jvar)/ic - vec(ivar)*vec(jvar)/(ic*ic);
      }
   }

   return mat;
}
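
// What the loops above compute, with ic = sum_k w_k the sum of event weights:
//   <x_i>  = (1/ic) * sum_k w_k * x_k(i)
//   C(i,j) = (1/ic) * sum_k w_k * x_k(i) * x_k(j)  -  <x_i>*<x_j>
// i.e. the weighted covariance <x_i x_j> - <x_i><x_j> of each variable pair.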

// --------------------------------------- new versions

////////////////////////////////////////////////////////////////////////////////
/// define and parse the user options steering the dataset splitting

void
TMVA::DataSetFactory::InitOptions( TMVA::DataSetInfo& dsi,
                                   EvtStatsPerClass& nEventRequests,
                                   TString& normMode,
                                   UInt_t& splitSeed,
                                   TString& splitMode,
                                   TString& mixMode)
{
   Configurable splitSpecs( dsi.GetSplitOptions() );
   splitSpecs.SetConfigName("DataSetFactory");
   splitSpecs.SetConfigDescription( "Configuration options given in the \"PrepareForTrainingAndTesting\" call; these options define the creation of the data sets used for training and expert validation by TMVA" );

   splitMode = "Random";    // the splitting mode
   splitSpecs.DeclareOptionRef( splitMode, "SplitMode",
                                "Method of picking training and testing events (default: random)" );
   splitSpecs.AddPreDefVal(TString("Random"));
   splitSpecs.AddPreDefVal(TString("Alternate"));
   splitSpecs.AddPreDefVal(TString("Block"));

   mixMode = "SameAsSplitMode";    // the mixing mode
   splitSpecs.DeclareOptionRef( mixMode, "MixMode",
                                "Method of mixing events of different classes into one dataset (default: SameAsSplitMode)" );
   splitSpecs.AddPreDefVal(TString("SameAsSplitMode"));
   splitSpecs.AddPreDefVal(TString("Random"));
   splitSpecs.AddPreDefVal(TString("Alternate"));
   splitSpecs.AddPreDefVal(TString("Block"));

   splitSeed = 100;
   splitSpecs.DeclareOptionRef( splitSeed, "SplitSeed",
                                "Seed for random event shuffling" );

   normMode = "EqualNumEvents";  // the weight normalisation modes
   splitSpecs.DeclareOptionRef( normMode, "NormMode",
                                "Overall renormalisation of event-by-event weights used in the training (NumEvents: average weight of 1 per event, independently for signal and background; EqualNumEvents: average weight of 1 per event for signal, and sum of weights for background equal to sum of weights for signal)" );
   splitSpecs.AddPreDefVal(TString("None"));
   splitSpecs.AddPreDefVal(TString("NumEvents"));
   splitSpecs.AddPreDefVal(TString("EqualNumEvents"));

   splitSpecs.DeclareOptionRef(fScaleWithPreselEff=kFALSE,"ScaleWithPreselEff","Scale the number of requested events by the eff. of the preselection cuts (or not)" );

   // the number of events

   // fill in the numbers
   for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
      TString clName = dsi.GetClassInfo(cl)->GetName();
      TString titleTrain =  TString().Format("Number of training events of class %s (default: 0 = all)",clName.Data()).Data();
      TString titleTest  =  TString().Format("Number of test events of class %s (default: 0 = all)",clName.Data()).Data();
      TString titleSplit =  TString().Format("Split in training and test events of class %s (default: 0 = deactivated)",clName.Data()).Data();

      splitSpecs.DeclareOptionRef( nEventRequests.at(cl).nTrainingEventsRequested, TString("nTrain_")+clName, titleTrain );
      splitSpecs.DeclareOptionRef( nEventRequests.at(cl).nTestingEventsRequested , TString("nTest_")+clName , titleTest  );
      splitSpecs.DeclareOptionRef( nEventRequests.at(cl).TrainTestSplitRequested , TString("TrainTestSplit_")+clName , titleSplit );
   }

   splitSpecs.DeclareOptionRef( fVerbose, "V", "Verbosity (default: false)" );

   splitSpecs.DeclareOptionRef( fVerboseLevel=TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)" );
   splitSpecs.AddPreDefVal(TString("Debug"));
   splitSpecs.AddPreDefVal(TString("Verbose"));
   splitSpecs.AddPreDefVal(TString("Info"));

   fCorrelations = kTRUE;
   splitSpecs.DeclareOptionRef(fCorrelations, "Correlations", "Boolean to show correlation output (Default: true)");
   fComputeCorrelations = kTRUE;
   splitSpecs.DeclareOptionRef(fComputeCorrelations, "CalcCorrelations", "Compute correlations and also some variable statistics, e.g. min/max (Default: true )");

   splitSpecs.ParseOptions();
   splitSpecs.CheckForUnusedOptions();

   // output logging verbosity
   if (Verbose()) fLogger->SetMinType( kVERBOSE );
   if (fVerboseLevel.CompareTo("Debug")   == 0) fLogger->SetMinType( kDEBUG );
   if (fVerboseLevel.CompareTo("Verbose") == 0) fLogger->SetMinType( kVERBOSE );
   if (fVerboseLevel.CompareTo("Info")    == 0) fLogger->SetMinType( kINFO );

   // put all to upper case
   splitMode.ToUpper(); mixMode.ToUpper(); normMode.ToUpper();
   // adjust mixMode if "SameAsSplitMode" has been requested
   Log() << kDEBUG //<< Form("Dataset[%s] : ",dsi.GetName())
         << "\tSplitmode is: \"" << splitMode << "\" the mixmode is: \"" << mixMode << "\"" << Endl;
   if (mixMode=="SAMEASSPLITMODE") mixMode = splitMode;
   else if (mixMode!=splitMode)
      Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "DataSet splitmode="<<splitMode
            <<" differs from mixmode="<<mixMode<<Endl;
}
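
// Illustrative example: the per-class options declared above combine as, e.g.,
//   "nTrain_Signal=5000:nTest_Signal=5000:TrainTestSplit_Background=0.8"
// where TrainTestSplit_<class>, a fraction between 0 and 1, overrides the
// absolute nTrain_/nTest_ requests for that class (handled in MixEvents).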

////////////////////////////////////////////////////////////////////////////////
/// build empty event vectors and distribute the input events
/// among kTraining/kTesting/kMaxTreeType

void
TMVA::DataSetFactory::BuildEventVector( TMVA::DataSetInfo& dsi,
                                        TMVA::DataInputHandler& dataInput,
                                        EventVectorOfClassesOfTreeType& eventsmap,
                                        EvtStatsPerClass& eventCounts)
{
   const UInt_t nclasses = dsi.GetNClasses();

   eventsmap[ Types::kTraining ]    = EventVectorOfClasses(nclasses);
   eventsmap[ Types::kTesting ]     = EventVectorOfClasses(nclasses);
   eventsmap[ Types::kMaxTreeType ] = EventVectorOfClasses(nclasses);

   // create the type, weight and boostweight branches
   const UInt_t nvars = dsi.GetNVariables();
   const UInt_t ntgts = dsi.GetNTargets();
   const UInt_t nvis  = dsi.GetNSpectators();

   for (size_t i=0; i<nclasses; i++) {
      eventCounts[i].varAvLength = new Float_t[nvars];
      for (UInt_t ivar=0; ivar<nvars; ivar++)
         eventCounts[i].varAvLength[ivar] = 0;
   }

   // Bool_t haveArrayVariable = kFALSE;
   Bool_t *varIsArray = new Bool_t[nvars];

   // If there are NaNs in the tree:
   // => warn if used variables/cuts/weights contain nan (no problem if event is cut out)
   // => fatal if cut value is nan or (event not cut out and nans somewhere)
   // Count & collect all these warnings/errors and output them at the end.
   std::map<TString, int> nanInfWarnings;
   std::map<TString, int> nanInfErrors;

   // if we work with chains we need to remember the current tree; if
   // the chain jumps to a new tree we have to reset the formulas
   for (UInt_t cl=0; cl<nclasses; cl++) {

      //Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "Create training and testing trees -- looping over class \"" << dsi.GetClassInfo(cl)->GetName() << "\" ..." << Endl;

      EventStats& classEventCounts = eventCounts[cl];

      // info output for weights
      Log() << kDEBUG //<< Form("Dataset[%s] : ",dsi.GetName())
            << "\tWeight expression for class \'" << dsi.GetClassInfo(cl)->GetName() << "\': \""
            << dsi.GetClassInfo(cl)->GetWeight() << "\"" << Endl;

      // used for chains only
      TString currentFileName("");

      std::vector<TreeInfo>::const_iterator treeIt(dataInput.begin(dsi.GetClassInfo(cl)->GetName()));
      for (;treeIt!=dataInput.end(dsi.GetClassInfo(cl)->GetName()); ++treeIt) {

         // read first the variables
         std::vector<Float_t> vars(nvars);
         std::vector<Float_t> tgts(ntgts);
         std::vector<Float_t> vis(nvis);
         TreeInfo currentInfo = *treeIt;

         Log() << kDEBUG << "Building event vectors " << currentInfo.GetTreeType() << Endl;

         EventVector& event_v = eventsmap[currentInfo.GetTreeType()].at(cl);

         Bool_t isChain = (TString("TChain") == currentInfo.GetTree()->ClassName());
         currentInfo.GetTree()->LoadTree(0);
         ChangeToNewTree( currentInfo, dsi );

         // count number of events in tree before cut
         classEventCounts.nInitialEvents += currentInfo.GetTree()->GetEntries();

         // loop over events in ntuple
         const Long64_t nEvts = currentInfo.GetTree()->GetEntries();
         for (Long64_t evtIdx = 0; evtIdx < nEvts; evtIdx++) {
            currentInfo.GetTree()->LoadTree(evtIdx);

            // may need to reload tree in case of chains
            if (isChain) {
               if (currentInfo.GetTree()->GetTree()->GetDirectory()->GetFile()->GetName() != currentFileName) {
                  currentFileName = currentInfo.GetTree()->GetTree()->GetDirectory()->GetFile()->GetName();
                  ChangeToNewTree( currentInfo, dsi );
               }
            }
            currentInfo.GetTree()->GetEntry(evtIdx);
            Int_t sizeOfArrays = 1;
            Int_t prevArrExpr = 0;

            // ======= evaluate all formulas =================

            // first we check if some of the formulas are arrays
            for (UInt_t ivar=0; ivar<nvars; ivar++) {
               Int_t ndata = fInputFormulas[ivar]->GetNdata();
               classEventCounts.varAvLength[ivar] += ndata;
               if (ndata == 1) continue;
               // haveArrayVariable = kTRUE;
               varIsArray[ivar] = kTRUE;
               if (sizeOfArrays == 1) {
                  sizeOfArrays = ndata;
                  prevArrExpr = ivar;
               }
               else if (sizeOfArrays!=ndata) {
                  Log() << kERROR << Form("Dataset[%s] : ",dsi.GetName())<< "ERROR while preparing training and testing trees:" << Endl;
                  Log() << Form("Dataset[%s] : ",dsi.GetName())<< "   multiple array-type expressions of different length were encountered" << Endl;
                  Log() << Form("Dataset[%s] : ",dsi.GetName())<< "   location of error: event " << evtIdx
                        << " in tree " << currentInfo.GetTree()->GetName()
                        << " of file " << currentInfo.GetTree()->GetCurrentFile()->GetName() << Endl;
                  Log() << Form("Dataset[%s] : ",dsi.GetName())<< "   expression " << fInputFormulas[ivar]->GetTitle() << " has "
                        << ndata << " entries, while" << Endl;
                  Log() << Form("Dataset[%s] : ",dsi.GetName())<< "   expression " << fInputFormulas[prevArrExpr]->GetTitle() << " has "
                        << fInputFormulas[prevArrExpr]->GetNdata() << " entries" << Endl;
                  Log() << kFATAL << Form("Dataset[%s] : ",dsi.GetName())<< "Need to abort" << Endl;
               }
            }

            // now we read the information
            for (Int_t idata = 0;  idata<sizeOfArrays; idata++) {
               Bool_t contains_NaN_or_inf = kFALSE;

               auto checkNanInf = [&](std::map<TString, int> &msgMap, Float_t value, const char *what, const char *formulaTitle) {
                  if (TMath::IsNaN(value)) {
                     contains_NaN_or_inf = kTRUE;
                     ++msgMap[TString::Format("Dataset[%s] : %s expression resolves to indeterminate value (NaN): %s", dsi.GetName(), what, formulaTitle)];
                  } else if (!TMath::Finite(value)) {
                     contains_NaN_or_inf = kTRUE;
                     ++msgMap[TString::Format("Dataset[%s] : %s expression resolves to infinite value (+inf or -inf): %s", dsi.GetName(), what, formulaTitle)];
                  }
               };

               TTreeFormula* formula = 0;

               // the cut expression
               Double_t cutVal = 1.;
               formula = fCutFormulas[cl];
               if (formula) {
                  Int_t ndata = formula->GetNdata();
                  cutVal = (ndata==1 ?
                            formula->EvalInstance(0) :
                            formula->EvalInstance(idata));
                  checkNanInf(nanInfErrors, cutVal, "Cut", formula->GetTitle());
               }

               // if event is cut out, add to warnings, else add to errors.
               auto &nanMessages = cutVal < 0.5 ? nanInfWarnings : nanInfErrors;

               // the input variable
               for (UInt_t ivar=0; ivar<nvars; ivar++) {
                  formula = fInputFormulas[ivar];
                  formula->SetQuickLoad(true);
                  Int_t ndata = formula->GetNdata();
                  vars[ivar] = (ndata == 1 ?
                                formula->EvalInstance(0) :
                                formula->EvalInstance(idata));
                  checkNanInf(nanMessages, vars[ivar], "Input", formula->GetTitle());
               }

               // the targets
               for (UInt_t itrgt=0; itrgt<ntgts; itrgt++) {
                  formula = fTargetFormulas[itrgt];
                  Int_t ndata = formula->GetNdata();
                  tgts[itrgt] = (ndata == 1 ?
                                 formula->EvalInstance(0) :
                                 formula->EvalInstance(idata));
                  checkNanInf(nanMessages, tgts[itrgt], "Target", formula->GetTitle());
               }

               // the spectators
               for (UInt_t itVis=0; itVis<nvis; itVis++) {
                  formula = fSpectatorFormulas[itVis];
                  Int_t ndata = formula->GetNdata();
                  vis[itVis] = (ndata == 1 ?
                                formula->EvalInstance(0) :
                                formula->EvalInstance(idata));
                  checkNanInf(nanMessages, vis[itVis], "Spectator", formula->GetTitle());
               }


               // the weight
               Float_t weight = currentInfo.GetWeight(); // multiply by tree weight
               formula = fWeightFormula[cl];
               if (formula!=0) {
                  Int_t ndata = formula->GetNdata();
                  weight *= (ndata == 1 ?
                             formula->EvalInstance() :
                             formula->EvalInstance(idata));
                  checkNanInf(nanMessages, weight, "Weight", formula->GetTitle());
               }

               // Count the events before rejection due to cut or NaN
               // value (weighted and unweighted)
               classEventCounts.nEvBeforeCut++;
               if (!TMath::IsNaN(weight))
                  classEventCounts.nWeEvBeforeCut += weight;

               // apply the cut, skip rest if cut is not fulfilled
               if (cutVal<0.5) continue;

               // global flag if negative weights exist -> can be used
               // by classifiers who may require special data
               // treatment (also print warning)
               if (weight < 0) classEventCounts.nNegWeights++;

               // now read the event-values (variables and regression targets)

               if (contains_NaN_or_inf) {
                  Log() << kWARNING << Form("Dataset[%s] : ",dsi.GetName())<< "NaN or +-inf in Event " << evtIdx << Endl;
                  if (sizeOfArrays>1) Log() << kWARNING << Form("Dataset[%s] : ",dsi.GetName())<< " rejected" << Endl;
                  continue;
               }

               // Count the events after rejection due to cut or NaN value
               // (weighted and unweighted)
               classEventCounts.nEvAfterCut++;
               classEventCounts.nWeEvAfterCut += weight;

               // event accepted, fill temporary ntuple
               event_v.push_back(new Event(vars, tgts , vis, cl , weight));
            }
         }
         currentInfo.GetTree()->ResetBranchAddresses();
      }
   }

   if (!nanInfWarnings.empty()) {
      Log() << kWARNING << "Found events with NaN and/or +-inf values" << Endl;
      for (const auto &warning : nanInfWarnings) {
         auto &log = Log() << kWARNING << warning.first;
         if (warning.second > 1) log << " (" << warning.second << " times)";
         log << Endl;
      }
      Log() << kWARNING << "These NaN and/or +-infs were all removed by the specified cut, continuing." << Endl;
      Log() << Endl;
   }

   if (!nanInfErrors.empty()) {
      Log() << kWARNING << "Found events with NaN and/or +-inf values (not removed by cut)" << Endl;
      for (const auto &error : nanInfErrors) {
         auto &log = Log() << kWARNING << error.first;
         if (error.second > 1) log << " (" << error.second << " times)";
         log << Endl;
      }
      Log() << kFATAL << "How am I supposed to train a NaN or +-inf?!" << Endl;
   }

   // for output format, get the maximum class name length
   Int_t maxL = dsi.GetClassNameMaxLength();

   Log() << kHEADER << Form("[%s] : ",dsi.GetName()) << "Number of events in input trees" << Endl;
   Log() << kDEBUG << "(after possible flattening of arrays):" << Endl;


   for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
      Log() << kDEBUG //<< Form("[%s] : ",dsi.GetName())
            << "    "
            << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
            << "      -- number of events       : "
            << std::setw(5) << eventCounts[cl].nEvBeforeCut
            << "  / sum of weights: " << std::setw(5) << eventCounts[cl].nWeEvBeforeCut << Endl;
   }

   for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
      Log() << kDEBUG //<< Form("Dataset[%s] : ",dsi.GetName())
            << "    " << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
            <<" tree -- total number of entries: "
            << std::setw(5) << dataInput.GetEntries(dsi.GetClassInfo(cl)->GetName()) << Endl;
   }

   if (fScaleWithPreselEff)
      Log() << kDEBUG //<< Form("Dataset[%s] : ",dsi.GetName())
            << "\tPreselection: (will affect number of requested training and testing events)" << Endl;
   else
      Log() << kDEBUG //<< Form("Dataset[%s] : ",dsi.GetName())
            << "\tPreselection: (will NOT affect number of requested training and testing events)" << Endl;

   if (dsi.HasCuts()) {
      for (UInt_t cl = 0; cl< dsi.GetNClasses(); cl++) {
         Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "    " << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
               << " requirement: \"" << dsi.GetClassInfo(cl)->GetCut() << "\"" << Endl;
         Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "    "
               << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
               << "      -- number of events passed: "
               << std::setw(5) << eventCounts[cl].nEvAfterCut
               << "  / sum of weights: " << std::setw(5) << eventCounts[cl].nWeEvAfterCut << Endl;
         Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "    "
               << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
               << "      -- efficiency             : "
               << std::setw(6) << eventCounts[cl].nWeEvAfterCut/eventCounts[cl].nWeEvBeforeCut << Endl;
      }
   }
   else Log() << kDEBUG //<< Form("Dataset[%s] : ",dsi.GetName())
              << "    No preselection cuts applied on event classes" << Endl;

   delete[] varIsArray;

}
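
// Note (illustrative): the cut formula is evaluated per event and interpreted
// as a boolean via "cutVal < 0.5", so a preselection like
//   TCut cut("pt > 20 && abs(eta) < 2.5");
// evaluates to 0 or 1, and every event with result 0 is rejected above.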

////////////////////////////////////////////////////////////////////////////////
/// Select and distribute unassigned events to kTraining and kTesting

TMVA::DataSet*
TMVA::DataSetFactory::MixEvents( DataSetInfo& dsi,
                                 EventVectorOfClassesOfTreeType& tmpEventVector,
                                 EvtStatsPerClass& eventCounts,
                                 const TString& splitMode,
                                 const TString& mixMode,
                                 const TString& normMode,
                                 UInt_t splitSeed)
{
   TMVA::RandomGenerator<TRandom3> rndm(splitSeed);

   // ==== splitting of undefined events to kTraining and kTesting

   // if splitMode contains "RANDOM", then shuffle the undefined events
   if (splitMode.Contains( "RANDOM" ) /*&& !emptyUndefined*/ ) {
      // random shuffle the undefined events of each class
      for( UInt_t cls = 0; cls < dsi.GetNClasses(); ++cls ){
         EventVector& unspecifiedEvents = tmpEventVector[Types::kMaxTreeType].at(cls);
         if( ! unspecifiedEvents.empty() ) {
            Log() << kDEBUG << "randomly shuffling "
                  << unspecifiedEvents.size()
                  << " events of class " << cls
                  << " which are not yet associated to testing or training" << Endl;
            std::shuffle(unspecifiedEvents.begin(), unspecifiedEvents.end(), rndm);
         }
      }
   }

   // check for each class the number of training and testing events, the requested number and the available number
   Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "SPLITTING ========" << Endl;
   for( UInt_t cls = 0; cls < dsi.GetNClasses(); ++cls ){
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "---- class " << cls << Endl;
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "check number of training/testing events, requested and available number of events for class " << cls << Endl;

      // check if enough or too many events are already in the training/testing eventvectors of the class cls
      EventVector& eventVectorTraining  = tmpEventVector[ Types::kTraining ].at(cls);
      EventVector& eventVectorTesting   = tmpEventVector[ Types::kTesting ].at(cls);
      EventVector& eventVectorUndefined = tmpEventVector[ Types::kMaxTreeType ].at(cls);

      Int_t availableTraining  = eventVectorTraining.size();
      Int_t availableTesting   = eventVectorTesting.size();
      Int_t availableUndefined = eventVectorUndefined.size();

      Float_t presel_scale;
      if (fScaleWithPreselEff) {
         presel_scale = eventCounts[cls].cutScaling();
         if (presel_scale < 1)
            Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << " you have opted to scale the number of requested training/testing events\n by the preselection efficiency"<< Endl;
      }else{
         presel_scale = 1.; // this scaling was too confusing to most people, including me! Sorry... (Helge)
         if (eventCounts[cls].cutScaling() < 1)
            Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << " you have opted for interpreting the requested number of training/testing events\n to be the number of events AFTER your preselection cuts" << Endl;

      }

      // If TrainTestSplit_<class> is set, set the number of requested training events to split*num_all_events.
      // The requested number of testing events is set to zero and therefore takes all other events.
      // The option TrainTestSplit_<class> overrides nTrain_<class> or nTest_<class>.
      if(eventCounts[cls].TrainTestSplitRequested < 1.0 && eventCounts[cls].TrainTestSplitRequested > 0.0){
         eventCounts[cls].nTrainingEventsRequested = Int_t(eventCounts[cls].TrainTestSplitRequested*(availableTraining+availableTesting+availableUndefined));
         eventCounts[cls].nTestingEventsRequested  = Int_t(0);
      }
      else if(eventCounts[cls].TrainTestSplitRequested != 0.0) Log() << kFATAL << Form("The option TrainTestSplit_<class> has to be in range (0, 1] but is set to %f.",eventCounts[cls].TrainTestSplitRequested) << Endl;
      Int_t requestedTraining = Int_t(eventCounts[cls].nTrainingEventsRequested * presel_scale);
      Int_t requestedTesting  = Int_t(eventCounts[cls].nTestingEventsRequested  * presel_scale);

      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "events in training trees    : " << availableTraining  << Endl;
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "events in testing trees     : " << availableTesting   << Endl;
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "events in unspecified trees : " << availableUndefined << Endl;
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "requested for training      : " << requestedTraining << Endl;

      if(presel_scale<1)
         Log() << "  ( " << eventCounts[cls].nTrainingEventsRequested
               << " * " << presel_scale << " preselection efficiency)" << Endl;
      else
         Log() << Endl;
      Log() << kDEBUG << "requested for testing       : " << requestedTesting;
      if(presel_scale<1)
         Log() << "  ( " << eventCounts[cls].nTestingEventsRequested
               << " * " << presel_scale << " preselection efficiency)" << Endl;
      else
         Log() << Endl;

      // nomenclature r  = available training
      //              s  = available testing
      //              u  = available undefined
      //              R  = requested training
      //              S  = requested testing
      //              nR = to be used to select training events
      //              nS = to be used to select test events
      // we have the constraint: nR + nS < r+s+u,
      // since we can not use more events than we have
      // free events: Nfree = u - Thet(R-r) - Thet(S-s)
      // nomenclature: Thet(x) = x, if x>0 else 0
      // nR = max(R,r) + 0.5 * Nfree
      // nS = max(S,s) + 0.5 * Nfree
      // nR + nS = R+S + u-R+r-S+s = u+r+s  ==> ok, for R>r
      // nR + nS = r+S + u-S+s     = u+r+s  ==> ok, for r>R

      // three different cases might occur here
      //
      // Case a
      // requestedTraining and requestedTesting >0
      // free events: Nfree = u - Thet(R-r) - Thet(S-s)
      // nR = Max(R,r) + 0.5 * Nfree
      // nS = Max(S,s) + 0.5 * Nfree
      //
      // Case b
      // exactly one of requestedTraining or requestedTesting >0
      // assume training R >0
      // nR = max(R,r)
      // nS = s+u+r-nR
      // and s=nS
      //
      // Case c
      // requestedTraining=0, requestedTesting=0
      // Nfree = u - |r-s|
      // if NFree >=0
      //    R = Max(r,s) + 0.5 * Nfree = S
      // else if r>s
      //    R = r; S = s+u
      // else
      //    R = r+u; S = s
      //
      // Next steps:
      // Determination of event numbers R, S, nR, nS
      // distribute undefined events according to nR, nS
      // finally determine actual sub samples from nR and nS to be used in training / testing
      //

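      // Worked example (illustrative) for Case a with r=100, s=0, u=500,
      // R=300, S=200:
      //   Nfree = u - (R-r) - (S-s) = 500 - 200 - 200 = 100
      //   nR    = max(R,r) + Nfree/2 = 300 + 50       = 350
      //   nS    = (r+s+u) - nR       = 600 - 350      = 250  ( = max(S,s) + Nfree/2 )
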
      Int_t useForTesting(0),useForTraining(0);
      Int_t allAvailable(availableUndefined + availableTraining + availableTesting);

      if( (requestedTraining == 0) && (requestedTesting == 0)){

         // Case C: balance the number of training and testing events

         if ( availableUndefined >= TMath::Abs(availableTraining - availableTesting) ) {
            // enough unspecified are available to equal training and testing
            useForTraining = useForTesting = allAvailable/2;
         } else {
            // all unspecified are assigned to the smaller of training / testing
            useForTraining = availableTraining;
            useForTesting  = availableTesting;
            if (availableTraining < availableTesting)
               useForTraining += availableUndefined;
            else
               useForTesting += availableUndefined;
         }
         requestedTraining = useForTraining;
         requestedTesting  = useForTesting;
      }

      else if (requestedTesting == 0){
         // case B
         useForTraining = TMath::Max(requestedTraining,availableTraining);
         if (allAvailable < useForTraining) {
            Log() << kFATAL << Form("Dataset[%s] : ",dsi.GetName())<< "More events requested for training ("
                  << requestedTraining << ") than available ("
                  << allAvailable << ")!" << Endl;
         }
         useForTesting    = allAvailable - useForTraining; // the rest
         requestedTesting = useForTesting;
      }

      else if (requestedTraining == 0){ // case B
         useForTesting = TMath::Max(requestedTesting,availableTesting);
         if (allAvailable < useForTesting) {
            Log() << kFATAL << Form("Dataset[%s] : ",dsi.GetName())<< "More events requested for testing ("
                  << requestedTesting << ") than available ("
                  << allAvailable << ")!" << Endl;
         }
         useForTraining    = allAvailable - useForTesting; // the rest
         requestedTraining = useForTraining;
      }

      else {
         // Case A
         // requestedTraining R and requestedTesting S >0
         // free events: Nfree = u - Thet(R-r) - Thet(S-s)
         // nR = Max(R,r) + 0.5 * Nfree
         // nS = Max(S,s) + 0.5 * Nfree
         Int_t stillNeedForTraining = TMath::Max(requestedTraining-availableTraining,0);
         Int_t stillNeedForTesting  = TMath::Max(requestedTesting-availableTesting,0);

         int NFree = availableUndefined - stillNeedForTraining - stillNeedForTesting;
         if (NFree < 0) NFree = 0;
         useForTraining = TMath::Max(requestedTraining,availableTraining) + NFree/2;
         useForTesting  = allAvailable - useForTraining; // the rest
      }

      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "determined event sample size to select training sample from="<<useForTraining<<Endl;
      Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "determined event sample size to select test sample from="<<useForTesting<<Endl;



      // associate undefined events
      if( splitMode == "ALTERNATE" ){
         Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "split 'ALTERNATE'" << Endl;
         Int_t nTraining = availableTraining;
         Int_t nTesting  = availableTesting;
         for( EventVector::iterator it = eventVectorUndefined.begin(), itEnd = eventVectorUndefined.end(); it != itEnd; ){
            ++nTraining;
            if( nTraining <= requestedTraining ){
               eventVectorTraining.insert( eventVectorTraining.end(), (*it) );
               ++it;
            }
            if( it != itEnd ){
               ++nTesting;
               eventVectorTesting.insert( eventVectorTesting.end(), (*it) );
               ++it;
            }
         }
      } else {
         Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "split '" << splitMode << "'" << Endl;

         // test if enough events are available
         Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "availableundefined : " << availableUndefined << Endl;
         Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "useForTraining     : " << useForTraining << Endl;
         Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "useForTesting      : " << useForTesting  << Endl;
         Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "availableTraining  : " << availableTraining << Endl;
         Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "availableTesting   : " << availableTesting << Endl;

         if( availableUndefined<(useForTraining-availableTraining) ||
             availableUndefined<(useForTesting -availableTesting ) ||
             availableUndefined<(useForTraining+useForTesting-availableTraining-availableTesting ) ){
            Log() << kFATAL << Form("Dataset[%s] : ",dsi.GetName())<< "More events requested than available!" << Endl;
         }

         // select the events
         if (useForTraining>availableTraining){
            eventVectorTraining.insert( eventVectorTraining.end() , eventVectorUndefined.begin(), eventVectorUndefined.begin()+ useForTraining- availableTraining );
            eventVectorUndefined.erase( eventVectorUndefined.begin(), eventVectorUndefined.begin() + useForTraining- availableTraining);
         }
         if (useForTesting>availableTesting){
            eventVectorTesting.insert( eventVectorTesting.end() , eventVectorUndefined.begin(), eventVectorUndefined.begin()+ useForTesting- availableTesting );
         }
      }
      eventVectorUndefined.clear();

      // finally shorten the event vectors to the requested size by removing random events
      if (splitMode.Contains( "RANDOM" )){
         UInt_t sizeTraining  = eventVectorTraining.size();
         if( sizeTraining > UInt_t(requestedTraining) ){
            std::vector<UInt_t> indicesTraining( sizeTraining );
            // make indices
            std::generate( indicesTraining.begin(), indicesTraining.end(), TMVA::Increment<UInt_t>(0) );
            // shuffle indices
            std::shuffle(indicesTraining.begin(), indicesTraining.end(), rndm);
            // erase indices of not needed events
            indicesTraining.erase( indicesTraining.begin()+sizeTraining-UInt_t(requestedTraining), indicesTraining.end() );
            // delete all events with the given indices
            for( std::vector<UInt_t>::iterator it = indicesTraining.begin(), itEnd = indicesTraining.end(); it != itEnd; ++it ){
               delete eventVectorTraining.at( (*it) ); // delete event
               eventVectorTraining.at( (*it) ) = NULL; // set pointer to NULL
            }
            // now remove and erase all events with pointer==NULL
            eventVectorTraining.erase( std::remove( eventVectorTraining.begin(), eventVectorTraining.end(), (void*)NULL ), eventVectorTraining.end() );
         }

         UInt_t sizeTesting   = eventVectorTesting.size();
         if( sizeTesting > UInt_t(requestedTesting) ){
            std::vector<UInt_t> indicesTesting( sizeTesting );
            // make indices
            std::generate( indicesTesting.begin(), indicesTesting.end(), TMVA::Increment<UInt_t>(0) );
            // shuffle indices
            std::shuffle(indicesTesting.begin(), indicesTesting.end(), rndm);
            // erase indices of not needed events
            indicesTesting.erase( indicesTesting.begin()+sizeTesting-UInt_t(requestedTesting), indicesTesting.end() );
            // delete all events with the given indices
            for( std::vector<UInt_t>::iterator it = indicesTesting.begin(), itEnd = indicesTesting.end(); it != itEnd; ++it ){
               delete eventVectorTesting.at( (*it) ); // delete event
               eventVectorTesting.at( (*it) ) = NULL; // set pointer to NULL
            }
            // now remove and erase all events with pointer==NULL
            eventVectorTesting.erase( std::remove( eventVectorTesting.begin(), eventVectorTesting.end(), (void*)NULL ), eventVectorTesting.end() );
         }
      }
      else { // erase at end
         if( eventVectorTraining.size() < UInt_t(requestedTraining) )
            Log() << kWARNING << Form("Dataset[%s] : ",dsi.GetName())<< "DataSetFactory/requested number of training samples larger than size of eventVectorTraining.\n"
                  << "There is probably an issue. Please contact the TMVA developers." << Endl;
         std::for_each( eventVectorTraining.begin()+requestedTraining, eventVectorTraining.end(), DeleteFunctor<Event>() );
         eventVectorTraining.erase(eventVectorTraining.begin()+requestedTraining,eventVectorTraining.end());

         if( eventVectorTesting.size() < UInt_t(requestedTesting) )
            Log() << kWARNING << Form("Dataset[%s] : ",dsi.GetName())<< "DataSetFactory/requested number of testing samples larger than size of eventVectorTesting.\n"
                  << "There is probably an issue. Please contact the TMVA developers." << Endl;
         std::for_each( eventVectorTesting.begin()+requestedTesting, eventVectorTesting.end(), DeleteFunctor<Event>() );
         eventVectorTesting.erase(eventVectorTesting.begin()+requestedTesting,eventVectorTesting.end());
      }
   }
1275
1276 TMVA::DataSetFactory::RenormEvents( dsi, tmpEventVector, eventCounts, normMode );
1277
1278 Int_t trainingSize = 0;
1279 Int_t testingSize = 0;
1280
1281 // sum up number of training and testing events
1282 for( UInt_t cls = 0; cls < dsi.GetNClasses(); ++cls ){
1283 trainingSize += tmpEventVector[Types::kTraining].at(cls).size();
1284 testingSize += tmpEventVector[Types::kTesting].at(cls).size();
1285 }
1286
1287 // --- collect all training (testing) events into the training (testing) eventvector
1288
1289 // create event vectors reserve enough space
1290 EventVector* trainingEventVector = new EventVector();
1291 EventVector* testingEventVector = new EventVector();
1292
1293 trainingEventVector->reserve( trainingSize );
1294 testingEventVector->reserve( testingSize );
1295
1296
1297 // collect the events
1298
1299 // mixing of kTraining and kTesting data sets
1300 Log() << kDEBUG << " MIXING ============= " << Endl;
1301
1302 if( mixMode == "ALTERNATE" ){
1303 // Inform user if he tries to use alternate mixmode for
1304 // event classes with different number of events, this works but the alternation stops at the last event of the smaller class
1305 for( UInt_t cls = 1; cls < dsi.GetNClasses(); ++cls ){
1306 if (tmpEventVector[Types::kTraining].at(cls).size() != tmpEventVector[Types::kTraining].at(0).size()){
1307 Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "Training sample: You are trying to mix events in alternate mode although the classes have different event numbers. This works but the alternation stops at the last event of the smaller class."<<Endl;
1308 }
1309 if (tmpEventVector[Types::kTesting].at(cls).size() != tmpEventVector[Types::kTesting].at(0).size()){
1310 Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "Testing sample: You are trying to mix events in alternate mode although the classes have different event numbers. This works but the alternation stops at the last event of the smaller class."<<Endl;
1311 }
1312 }
1313 typedef EventVector::iterator EvtVecIt;
1314 EvtVecIt itEvent, itEventEnd;
1315
1316 // insert first class
1317 Log() << kDEBUG << "insert class 0 into training and test vector" << Endl;
1318 trainingEventVector->insert( trainingEventVector->end(), tmpEventVector[Types::kTraining].at(0).begin(), tmpEventVector[Types::kTraining].at(0).end() );
1319 testingEventVector->insert( testingEventVector->end(), tmpEventVector[Types::kTesting].at(0).begin(), tmpEventVector[Types::kTesting].at(0).end() );
1320
1321 // insert other classes
1322 EvtVecIt itTarget;
1323 for( UInt_t cls = 1; cls < dsi.GetNClasses(); ++cls ){
1324 Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "insert class " << cls << Endl;
1325 // training vector
1326 pos = -1; // start one position before the front
1327 // loop over source
1328 for( itEvent = tmpEventVector[Types::kTraining].at(cls).begin(), itEventEnd = tmpEventVector[Types::kTraining].at(cls).end(); itEvent != itEventEnd; ++itEvent ){
1329 // not enough room left for another full stride --> append the remainder unmixed
1330 if( Long64_t(trainingEventVector->size()) - pos < Long64_t(cls+1) ) {
1331 trainingEventVector->insert( trainingEventVector->end(), itEvent, itEventEnd ); // fill in the rest without mixing
1332 break;
1333 }else{
1334 pos += cls+1;
1335 trainingEventVector->insert( trainingEventVector->begin()+pos, (*itEvent) ); // fill event
1336 }
1337 }
1339 // testing vector
1340 pos = -1; // start one position before the front
1341 // loop over source
1342 for( itEvent = tmpEventVector[Types::kTesting].at(cls).begin(), itEventEnd = tmpEventVector[Types::kTesting].at(cls).end(); itEvent != itEventEnd; ++itEvent ){
1343 // not enough room left for another full stride --> append the remainder unmixed
1344 if( Long64_t(testingEventVector->size()) - pos < Long64_t(cls+1) ) {
1345 testingEventVector->insert( testingEventVector->end(), itEvent, itEventEnd ); // fill in the rest without mixing
1346 break;
1347 }else{
1348 pos += cls+1;
1349 testingEventVector->insert( testingEventVector->begin()+pos, (*itEvent) ); // fill event
1350 }
1351 }
1353 }
1354 }else{
1355 for( UInt_t cls = 0; cls < dsi.GetNClasses(); ++cls ){
1356 trainingEventVector->insert( trainingEventVector->end(), tmpEventVector[Types::kTraining].at(cls).begin(), tmpEventVector[Types::kTraining].at(cls).end() );
1357 testingEventVector->insert ( testingEventVector->end(), tmpEventVector[Types::kTesting].at(cls).begin(), tmpEventVector[Types::kTesting].at(cls).end() );
1358 }
1359 }
1360 // clear the tmpEventVector (the events themselves stay alive in the new vectors)
1361 tmpEventVector[Types::kTraining].clear();
1362 tmpEventVector[Types::kTesting].clear();
1363
1364 tmpEventVector[Types::kMaxTreeType].clear();
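// (the kMaxTreeType slot held the events that the user did not explicitly
// assign to training or testing; by this point they have all been distributed)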
1365
1366 if (mixMode == "RANDOM") {
1367 Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "shuffling events"<<Endl;
1368
1369 std::shuffle(trainingEventVector->begin(), trainingEventVector->end(), rndm);
1370 std::shuffle(testingEventVector->begin(), testingEventVector->end(), rndm);
1371 }
1372
1373 Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "trainingEventVector " << trainingEventVector->size() << Endl;
1374 Log() << kDEBUG << Form("Dataset[%s] : ",dsi.GetName())<< "testingEventVector " << testingEventVector->size() << Endl;
1375
1376 // create dataset
1377 DataSet* ds = new DataSet(dsi);
1378
1379 // Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "Create internal training tree" << Endl;
1380 ds->SetEventCollection(trainingEventVector, Types::kTraining );
1381 // Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "Create internal testing tree" << Endl;
1382 ds->SetEventCollection(testingEventVector, Types::kTesting );
1383
1384
1385 if (ds->GetNTrainingEvents() < 1){
1386 Log() << kFATAL << "Dataset " << std::string(dsi.GetName()) << " does not contain any training events; stopping here so that this can be fixed first " << Endl;
1387 }
1388
1389 if (ds->GetNTestEvents() < 1) {
1390 Log() << kERROR << "Dataset " << std::string(dsi.GetName()) << " does not contain any testing events; this is likely to cause problems later, but continuing for now " << Endl;
1391 }
1392
1393 delete trainingEventVector;
1394 delete testingEventVector;
1395 return ds;
1396
1397}
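The mixing and splitting behaviour implemented above is steered entirely through the option string of PrepareTrainingAndTestTree. Below is a minimal, hypothetical sketch of how a user would select the SplitMode, MixMode, SplitSeed and NormMode handled by this factory; the file, tree and variable names are illustrative only, and error handling is omitted.

    #include "TFile.h"
    #include "TTree.h"
    #include "TCut.h"
    #include "TMVA/DataLoader.h"

    void prepare_dataset()
    {
       TFile* f   = TFile::Open("events.root");      // hypothetical input file
       TTree* sig = (TTree*)f->Get("SignalTree");    // hypothetical tree names
       TTree* bkg = (TTree*)f->Get("BackgroundTree");

       TMVA::DataLoader loader("dataset");
       loader.AddVariable("var1", 'F');              // hypothetical variable
       loader.AddSignalTree(sig, 1.0);
       loader.AddBackgroundTree(bkg, 1.0);

       // these options are parsed by DataSetFactory::InitOptions and steer
       // MixEvents (SplitMode/MixMode/SplitSeed) and RenormEvents (NormMode)
       loader.PrepareTrainingAndTestTree(TCut(""),
          "nTrain_Signal=1000:nTrain_Background=1000:"
          "SplitMode=Random:MixMode=Alternate:NormMode=NumEvents:SplitSeed=100:!V");
    }

Note that MixMode defaults to the chosen SplitMode; setting it explicitly, as here, lets the training/testing split and the in-memory ordering of events be controlled independently.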
1398
1399////////////////////////////////////////////////////////////////////////////////
1400/// renormalisation of the TRAINING event weights
1401/// - none: use the weights as supplied by the user
1402/// (the relative weight is nevertheless stored for later use)
1403/// - numEvents: make the sum of weights in each class equal the requested number of training events
1404/// - equalNumEvents: reweight the training events such that the sum of all
1405/// backgr. (class > 0) weights equals that of the signal (class 0)
1406
1407void
1409 EventVectorOfClassesOfTreeType& tmpEventVector,
1410 const EvtStatsPerClass& eventCounts,
1411 const TString& normMode )
1412{
1413
1414
1415 // print rescaling info
1416 // ---------------------------------
1417 // compute sizes and sums of weights
1418 Int_t trainingSize = 0;
1419 Int_t testingSize = 0;
1420
1421 ValuePerClass trainingSumWeightsPerClass( dsi.GetNClasses() );
1422 ValuePerClass testingSumWeightsPerClass( dsi.GetNClasses() );
1423
1424 NumberPerClass trainingSizePerClass( dsi.GetNClasses() );
1425 NumberPerClass testingSizePerClass( dsi.GetNClasses() );
1426
1427 Double_t trainingSumSignalWeights = 0;
1428 Double_t trainingSumBackgrWeights = 0; // Backgr. includes all classes that are not signal
1429 Double_t testingSumSignalWeights = 0;
1430 Double_t testingSumBackgrWeights = 0; // Backgr. includes all classes that are not signal
1431
1432
1433
1434 for( UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ){
1435 trainingSizePerClass.at(cls) = tmpEventVector[Types::kTraining].at(cls).size();
1436 testingSizePerClass.at(cls) = tmpEventVector[Types::kTesting].at(cls).size();
1437
1438 trainingSize += trainingSizePerClass.at(cls);
1439 testingSize += testingSizePerClass.at(cls);
1440
1441 // the functional solution
1442 // sum up the weights in Double_t, although the individual weights are Float_t, to prevent rounding issues when adding many floating point numbers
1443 //
1444 // std::accumulate --> does what the name says
1445 // begin() and end() denote the range of the vector to be accumulated
1446 // Double_t(0) provides the type and the starting value of the running sum
1447 // the lambda adds each event's original weight to the running sum
1448 //
1449 // all together this sums up the event weights of all events in the vector and returns the total
1452 trainingSumWeightsPerClass.at(cls) =
1453 std::accumulate(tmpEventVector[Types::kTraining].at(cls).begin(),
1454 tmpEventVector[Types::kTraining].at(cls).end(),
1455 Double_t(0), [](Double_t w, const TMVA::Event *E) { return w + E->GetOriginalWeight(); });
1456
1457 testingSumWeightsPerClass.at(cls) =
1458 std::accumulate(tmpEventVector[Types::kTesting].at(cls).begin(),
1459 tmpEventVector[Types::kTesting].at(cls).end(),
1460 Double_t(0), [](Double_t w, const TMVA::Event *E) { return w + E->GetOriginalWeight(); });
1461
1462 if ( cls == dsi.GetSignalClassIndex()){
1463 trainingSumSignalWeights += trainingSumWeightsPerClass.at(cls);
1464 testingSumSignalWeights += testingSumWeightsPerClass.at(cls);
1465 }else{
1466 trainingSumBackgrWeights += trainingSumWeightsPerClass.at(cls);
1467 testingSumBackgrWeights += testingSumWeightsPerClass.at(cls);
1468 }
1469 }
1470
1471 // ---------------------------------
1472 // compute renormalization factors
1473
1474 ValuePerClass renormFactor( dsi.GetNClasses() );
1475
1476
1477 // for information purposes
1478 dsi.SetNormalization( normMode );
1479 // note: these sums are overwritten further below by the rescaled ones
1480 // whenever NormMode != None
1481 dsi.SetTrainingSumSignalWeights(trainingSumSignalWeights);
1482 dsi.SetTrainingSumBackgrWeights(trainingSumBackgrWeights);
1483 dsi.SetTestingSumSignalWeights(testingSumSignalWeights);
1484 dsi.SetTestingSumBackgrWeights(testingSumBackgrWeights);
1485
1486
1487 if (normMode == "NONE") {
1488 Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "No weight renormalisation applied: use original global and event weights" << Endl;
1489 return;
1490 }
1491 // changed 27.5.2013: the intent of this mode is to provide the SAME amount of
1492 // effective TRAINING data for signal and background;
1493 // testing events are irrelevant for this and including them would skew the whole normalisation
1494 else if (normMode == "NUMEVENTS") {
1495 Log() << kDEBUG //<< Form("Dataset[%s] : ",dsi.GetName())
1496 << "\tWeight renormalisation mode: \"NumEvents\": renormalises all event classes " << Endl;
1497 Log() << kDEBUG //<< Form("Dataset[%s] : ",dsi.GetName())
1498 << " such that the effective (weighted) number of events in each class equals the respective " << Endl;
1499 Log() << kDEBUG //<< Form("Dataset[%s] : ",dsi.GetName())
1500 << " number of events (entries) that you demanded in PrepareTrainingAndTestTree(\"\",\"nTrain_Signal=.. )" << Endl;
1501 Log() << kDEBUG //<< Form("Dataset[%s] : ",dsi.GetName())
1502 << " ... i.e. such that Sum[i=1..N_j]{w_i} = N_j, j=0,1,2..." << Endl;
1503 Log() << kDEBUG //<< Form("Dataset[%s] : ",dsi.GetName())
1504 << " ... (note that N_j is the sum of TRAINING events (nTrain_j...with j=Signal,Background.." << Endl;
1505 Log() << kDEBUG //<< Form("Dataset[%s] : ",dsi.GetName())
1506 << " ..... Testing events are not renormalised nor included in the renormalisation factor! )"<< Endl;
1507
1508 for( UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ){
1509 // changed 27.5.2013: the factor is computed from TRAINING events only;
1510 // testing events must not enter the renormalisation
1512 renormFactor.at(cls) = ((Float_t)trainingSizePerClass.at(cls) )/
1513 (trainingSumWeightsPerClass.at(cls)) ;
1514 }
1515 }
1516 else if (normMode == "EQUALNUMEVENTS") {
1517 // changed 27.5.2013: the intent of this mode is to provide the SAME amount of
1518 // effective TRAINING data for signal and background; previously each data
1519 // source was normalised to its own number of entries, with training and
1520 // testing events combined, which defeated that purpose
1521
1522 Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "Weight renormalisation mode: \"EqualNumEvents\": renormalises all event classes ..." << Endl;
1523 Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << " such that the effective (weighted) number of events in each class is the same " << Endl;
1524 Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << " (and equals the number of events (entries) given for class=0 )" << Endl;
1525 Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "... i.e. such that Sum[i=1..N_j]{w_i} = N_classA, j=classA, classB, ..." << Endl;
1526 Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << "... (note that N_j is the sum of TRAINING events" << Endl;
1527 Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << " ..... Testing events are not renormalised nor included in the renormalisation factor!)" << Endl;
1528
1529 // normalize to size of first class
1530 UInt_t referenceClass = 0;
1531 for (UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ) {
1532 renormFactor.at(cls) = Float_t(trainingSizePerClass.at(referenceClass))/
1533 (trainingSumWeightsPerClass.at(cls));
1534 }
1535 }
1536 else {
1537 Log() << kFATAL << Form("Dataset[%s] : ",dsi.GetName())<< "<PrepareForTrainingAndTesting> Unknown NormMode: " << normMode << Endl;
1538 }
1539
1540 // ---------------------------------
1541 // now apply the normalization factors
1542 Int_t maxL = dsi.GetClassNameMaxLength();
1543 for (UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls<clsEnd; ++cls) {
1544 Log() << kDEBUG //<< Form("Dataset[%s] : ",dsi.GetName())
1545 << "--> Rescale " << setiosflags(ios::left) << std::setw(maxL)
1546 << dsi.GetClassInfo(cls)->GetName() << " event weights by factor: " << renormFactor.at(cls) << Endl;
1547 for (EventVector::iterator it = tmpEventVector[Types::kTraining].at(cls).begin(),
1548 itEnd = tmpEventVector[Types::kTraining].at(cls).end(); it != itEnd; ++it){
1549 (*it)->SetWeight ((*it)->GetWeight() * renormFactor.at(cls));
1550 }
1551
1552 }
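// (only the TRAINING event weights are rescaled above; testing event weights
// are deliberately left untouched, as stated in the log messages)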
1553
1554
1555 // print out the result
1556 // (this repeats the weight summation from above and could be factored out into a helper)
1557 //
1558
1559 Log() << kINFO //<< Form("Dataset[%s] : ",dsi.GetName())
1560 << "Number of training and testing events" << Endl;
1561 Log() << kDEBUG << "\tafter rescaling:" << Endl;
1562 Log() << kINFO //<< Form("Dataset[%s] : ",dsi.GetName())
1563 << "---------------------------------------------------------------------------" << Endl;
1564
1565 trainingSumSignalWeights = 0;
1566 trainingSumBackgrWeights = 0; // Backgr. includes all classes that are not signal
1567 testingSumSignalWeights = 0;
1568 testingSumBackgrWeights = 0; // Backgr. includes all classes that are not signal
1569
1570 for( UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ){
1571 trainingSumWeightsPerClass.at(cls) =
1572 std::accumulate(tmpEventVector[Types::kTraining].at(cls).begin(),
1573 tmpEventVector[Types::kTraining].at(cls).end(),
1574 Double_t(0), [](Double_t w, const TMVA::Event *E) { return w + E->GetOriginalWeight(); });
1575
1576 testingSumWeightsPerClass.at(cls) =
1577 std::accumulate(tmpEventVector[Types::kTesting].at(cls).begin(),
1578 tmpEventVector[Types::kTesting].at(cls).end(),
1579 Double_t(0), [](Double_t w, const TMVA::Event *E) { return w + E->GetOriginalWeight(); });
1580
1581 if ( cls == dsi.GetSignalClassIndex()){
1582 trainingSumSignalWeights += trainingSumWeightsPerClass.at(cls);
1583 testingSumSignalWeights += testingSumWeightsPerClass.at(cls);
1584 }else{
1585 trainingSumBackgrWeights += trainingSumWeightsPerClass.at(cls);
1586 testingSumBackgrWeights += testingSumWeightsPerClass.at(cls);
1587 }
1588
1589 // output statistics
1590
1591 Log() << kINFO //<< Form("Dataset[%s] : ",dsi.GetName())
1592 << setiosflags(ios::left) << std::setw(maxL)
1593 << dsi.GetClassInfo(cls)->GetName() << " -- "
1594 << "training events : " << trainingSizePerClass.at(cls) << Endl;
1595 Log() << kDEBUG << "\t(sum of weights: " << trainingSumWeightsPerClass.at(cls) << ")"
1596 << " - requested were " << eventCounts[cls].nTrainingEventsRequested << " events" << Endl;
1597 Log() << kINFO //<< Form("Dataset[%s] : ",dsi.GetName())
1598 << setiosflags(ios::left) << std::setw(maxL)
1599 << dsi.GetClassInfo(cls)->GetName() << " -- "
1600 << "testing events : " << testingSizePerClass.at(cls) << Endl;
1601 Log() << kDEBUG << "\t(sum of weights: " << testingSumWeightsPerClass.at(cls) << ")"
1602 << " - requested were " << eventCounts[cls].nTestingEventsRequested << " events" << Endl;
1603 Log() << kINFO //<< Form("Dataset[%s] : ",dsi.GetName())
1604 << setiosflags(ios::left) << std::setw(maxL)
1605 << dsi.GetClassInfo(cls)->GetName() << " -- "
1606 << "training and testing events: "
1607 << (trainingSizePerClass.at(cls)+testingSizePerClass.at(cls)) << Endl;
1608 Log() << kDEBUG << "\t(sum of weights: "
1609 << (trainingSumWeightsPerClass.at(cls)+testingSumWeightsPerClass.at(cls)) << ")" << Endl;
1610 if(eventCounts[cls].nEvAfterCut<eventCounts[cls].nEvBeforeCut) {
1611 Log() << kINFO << Form("Dataset[%s] : ",dsi.GetName()) << setiosflags(ios::left) << std::setw(maxL)
1612 << dsi.GetClassInfo(cls)->GetName() << " -- "
1613 << "due to the preselection a scaling factor has been applied to the numbers of requested events: "
1614 << eventCounts[cls].cutScaling() << Endl;
1615 }
1616 }
1617 Log() << kINFO << Endl;
1618
1619 // for information purposes
1620 dsi.SetTrainingSumSignalWeights(trainingSumSignalWeights);
1621 dsi.SetTrainingSumBackgrWeights(trainingSumBackgrWeights);
1622 dsi.SetTestingSumSignalWeights(testingSumSignalWeights);
1623 dsi.SetTestingSumBackgrWeights(testingSumBackgrWeights);
1624
1625
1626}
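To make the two renormalisation modes concrete, here is a minimal standalone sketch of the factor arithmetic implemented above, using two hypothetical classes (the sizes and weight sums are invented for illustration):

    #include <cstdio>
    #include <vector>

    int main()
    {
       // hypothetical per-class training statistics
       std::vector<double> trainingSize = {1000., 2000.}; // number of training events
       std::vector<double> sumWeights   = { 500., 4000.}; // sum of original event weights

       // NormMode=NumEvents: each class' weight sum becomes its own event count
       for (std::size_t cls = 0; cls < trainingSize.size(); ++cls)
          std::printf("NumEvents      factor[%zu] = %g\n", cls, trainingSize[cls] / sumWeights[cls]);

       // NormMode=EqualNumEvents: each class' weight sum becomes the event count of class 0
       for (std::size_t cls = 0; cls < trainingSize.size(); ++cls)
          std::printf("EqualNumEvents factor[%zu] = %g\n", cls, trainingSize[0] / sumWeights[cls]);

       return 0;
    }

With EqualNumEvents the effective (weighted) number of training events is the same for every class, which is exactly the property the comments above describe.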
1627