Logo ROOT   6.16/01
Reference Guide
VariableTransformBase.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : VariableTransformBase *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation (see header for description) *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland *
16 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
17 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
18 * *
19 * Copyright (c) 2005: *
20 * CERN, Switzerland *
21 * MPI-K Heidelberg, Germany *
22 * *
23 * Redistribution and use in source and binary forms, with or without *
24 * modification, are permitted according to the terms listed in LICENSE *
25 * (http://tmva.sourceforge.net/LICENSE) *
26 **********************************************************************************/
27
28/*! \class TMVA::VariableTransformBase
29\ingroup TMVA
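30Base class for the TMVA variable transformations: it selects the variables, targets and spectators to be transformed, reads them from and writes them back to events, and stores the selection in XML.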
31*/
32
34
35#include "TMVA/Config.h"
36#include "TMVA/DataSetInfo.h"
37#include "TMVA/MsgLogger.h"
38#include "TMVA/Ranking.h"
39#include "TMVA/Tools.h"
40#include "TMVA/Types.h"
41#include "TMVA/VariableInfo.h"
42#include "TMVA/Version.h"
43
44#include "TH1.h"
45#include "TH2.h"
46#include "THashTable.h"
47#include "TList.h"
48#include "TMath.h"
49#include "TProfile.h"
50#include "TVectorD.h"
51
52#include <algorithm>
53#include <cassert>
54#include <exception>
55#include <iomanip>
56#include <stdexcept>
57#include <set>
58
60
62
63////////////////////////////////////////////////////////////////////////////////
64/// standard constructor
65
68 const TString& trfName )
69: TObject(),
70 fDsi(dsi),
71 fDsiOutput(NULL),
   // no separate output DataSetInfo by default: fDsi is used for the output labels unless one is set
72 fTransformedEvent(0),
73 fBackTransformedEvent(0),
74 fVariableTransform(tf),
75 fEnabled( kTRUE ),
76 fCreated( kFALSE ),
77 fNormalise( kFALSE ),
78 fTransformName(trfName),
79 fVariableTypesAreCounted(false),
80 fNVariables(0),
81 fNTargets(0),
82 fNSpectators(0),
83 fSortGet(kTRUE),
   // fSortGet: SelectInput replaces fGet by the sorted output list while this is set
84 fTMVAVersion(TMVA_VERSION_CODE),
85 fLogger( 0 )
86{
   // Standard constructor (the head of its signature is not part of this view).
   // Creates the logger (deleted again in the destructor) and stores copies of the
   // variable, target and spectator descriptions of the data set, so that later
   // modifications of fVariables/fTargets/fSpectators do not touch fDsi.
87 fLogger = new MsgLogger(this, kINFO);
88 for (UInt_t ivar = 0; ivar < fDsi.GetNVariables(); ivar++) {
89 fVariables.push_back( VariableInfo( fDsi.GetVariableInfo(ivar) ) );
90 }
91 for (UInt_t itgt = 0; itgt < fDsi.GetNTargets(); itgt++) {
92 fTargets.push_back( VariableInfo( fDsi.GetTargetInfo(itgt) ) );
93 }
94 for (UInt_t ispct = 0; ispct < fDsi.GetNSpectators(); ispct++) {
95 fSpectators.push_back( VariableInfo( fDsi.GetSpectatorInfo(ispct) ) );
96 }
97}
98
99////////////////////////////////////////////////////////////////////////////////
100
102{
   // Destructor: releases fTransformedEvent, fBackTransformedEvent and the logger,
   // all of which are owned by this object. (delete on a null pointer is a no-op,
   // so the != 0 checks are not strictly needed.)
103 if (fTransformedEvent!=0) delete fTransformedEvent;
104 if (fBackTransformedEvent!=0) delete fBackTransformedEvent;
105 // logger allocated with new in the constructor
106 delete fLogger;
107}
108
109////////////////////////////////////////////////////////////////////////////////
110/// select the variables/targets/spectators which serve as input to the transformation
111
112void TMVA::VariableTransformBase::SelectInput( const TString& _inputVariables, Bool_t putIntoVariables )
113{
114 TString inputVariables = _inputVariables;
115
116 // unselect all variables first
117 fGet.clear();
118
119 UInt_t nvars = GetNVariables();
120 UInt_t ntgts = GetNTargets();
121 UInt_t nspcts = GetNSpectators();
122
123 typedef std::set<Int_t> SelectedIndices;
124
125 SelectedIndices varIndices;
126 SelectedIndices tgtIndices;
127 SelectedIndices spctIndices;
128
129 if (inputVariables == "") // default is all variables and all targets
130 { // (the default can be changed by decorating this member function in the implementations)
131 inputVariables = "_V_,_T_";
132 }
133
134 TList* inList = gTools().ParseFormatLine( inputVariables, "," );
135 TListIter inIt(inList);
136 while (TObjString* os = (TObjString*)inIt()) {
137
138 TString variables = os->GetString();
139
140 if( variables.BeginsWith("_") && variables.EndsWith("_") ) { // special symbol (keyword)
141 variables.Remove( 0,1); // remove first "_"
142 variables.Remove( variables.Length()-1,1 ); // remove last "_"
143
144 if( variables.BeginsWith("V") ) { // variables
145 variables.Remove(0,1); // remove "V"
146 if( variables.Length() == 0 ){
147 for( UInt_t ivar = 0; ivar < nvars; ++ivar ) {
148 fGet.push_back( std::pair<Char_t,UInt_t>('v',ivar) );
149 varIndices.insert( ivar );
150 }
151 } else {
152 UInt_t idx = variables.Atoi();
153 if( idx >= nvars )
154 Log() << kFATAL << "You selected variable with index : " << idx << " of only " << nvars << " variables." << Endl;
155 fGet.push_back( std::pair<Char_t,UInt_t>('v',idx) );
156 varIndices.insert( idx );
157 }
158 }else if( variables.BeginsWith("T") ) { // targets
159 variables.Remove(0,1); // remove "T"
160 if( variables.Length() == 0 ){
161 for( UInt_t itgt = 0; itgt < ntgts; ++itgt ) {
162 fGet.push_back( std::pair<Char_t,UInt_t>('t',itgt) );
163 tgtIndices.insert( itgt );
164 }
165 } else {
166 UInt_t idx = variables.Atoi();
167 if( idx >= ntgts )
168 Log() << kFATAL << "You selected target with index : " << idx << " of only " << ntgts << " targets." << Endl;
169 fGet.push_back( std::pair<Char_t,UInt_t>('t',idx) );
170 tgtIndices.insert( idx );
171 }
172 }else if( variables.BeginsWith("S") ) { // spectators
173 variables.Remove(0,1); // remove "S"
174 if( variables.Length() == 0 ){
175 for( UInt_t ispct = 0; ispct < nspcts; ++ispct ) {
176 fGet.push_back( std::pair<Char_t,UInt_t>('s',ispct) );
177 spctIndices.insert( ispct );
178 }
179 } else {
180 UInt_t idx = variables.Atoi();
181 if( idx >= nspcts )
182 Log() << kFATAL << "You selected spectator with index : " << idx << " of only " << nspcts << " spectators." << Endl;
183 fGet.push_back( std::pair<Char_t,UInt_t>('s',idx) );
184 spctIndices.insert( idx );
185 }
186 }else if( TString("REARRANGE").BeginsWith(variables) ) { // toggle rearrange sorting (take sort order given in the options)
187 ToggleInputSortOrder( kFALSE );
188 if( !fSortGet )
189 Log() << kINFO << "Variable rearrangement set true: Variable order given in transformation option is used for input to transformation!" << Endl;
190
191 }
192 }else{ // no keyword, ... user provided variable labels
193 Int_t numIndices = varIndices.size()+tgtIndices.size()+spctIndices.size();
194 for( UInt_t ivar = 0; ivar < nvars; ++ivar ) { // search all variables
195 if( fDsi.GetVariableInfo( ivar ).GetLabel() == variables ) {
196 fGet.push_back( std::pair<Char_t,UInt_t>('v',ivar) );
197 varIndices.insert( ivar );
198 break;
199 }
200 }
201 for( UInt_t itgt = 0; itgt < ntgts; ++itgt ) { // search all targets
202 if( fDsi.GetTargetInfo( itgt ).GetLabel() == variables ) {
203 fGet.push_back( std::pair<Char_t,UInt_t>('t',itgt) );
204 tgtIndices.insert( itgt );
205 break;
206 }
207 }
208 for( UInt_t ispct = 0; ispct < nspcts; ++ispct ) { // search all spectators
209 if( fDsi.GetSpectatorInfo( ispct ).GetLabel() == variables ) {
210 fGet.push_back( std::pair<Char_t,UInt_t>('s',ispct) );
211 spctIndices.insert( ispct );
212 break;
213 }
214 }
215 Int_t numIndicesEndOfLoop = varIndices.size()+tgtIndices.size()+spctIndices.size();
216 if( numIndicesEndOfLoop == numIndices )
217 Log() << kWARNING << "Error at parsing the options for the variable transformations: Variable/Target/Spectator '" << variables.Data() << "' not found." << Endl;
218 numIndices = numIndicesEndOfLoop;
219 }
220 }
221
222
223 if( putIntoVariables ) {
224 Int_t idx = 0;
225 for( SelectedIndices::iterator it = varIndices.begin(), itEnd = varIndices.end(); it != itEnd; ++it ) {
226 fPut.push_back( std::pair<Char_t,UInt_t>('v',idx) );
227 ++idx;
228 }
229 for( SelectedIndices::iterator it = tgtIndices.begin(), itEnd = tgtIndices.end(); it != itEnd; ++it ) {
230 fPut.push_back( std::pair<Char_t,UInt_t>('t',idx) );
231 ++idx;
232 }
233 for( SelectedIndices::iterator it = spctIndices.begin(), itEnd = spctIndices.end(); it != itEnd; ++it ) {
234 fPut.push_back( std::pair<Char_t,UInt_t>('s',idx) );
235 ++idx;
236 }
237 }else {
238 for( SelectedIndices::iterator it = varIndices.begin(), itEnd = varIndices.end(); it != itEnd; ++it ) {
239 Int_t idx = (*it);
240 fPut.push_back( std::pair<Char_t,UInt_t>('v',idx) );
241 }
242 for( SelectedIndices::iterator it = tgtIndices.begin(), itEnd = tgtIndices.end(); it != itEnd; ++it ) {
243 Int_t idx = (*it);
244 fPut.push_back( std::pair<Char_t,UInt_t>('t',idx) );
245 }
246 for( SelectedIndices::iterator it = spctIndices.begin(), itEnd = spctIndices.end(); it != itEnd; ++it ) {
247 Int_t idx = (*it);
248 fPut.push_back( std::pair<Char_t,UInt_t>('s',idx) );
249 }
250
251 // if sorting is turned on, fGet should have the indices sorted as fPut has them.
252 if( fSortGet ) {
253 fGet.clear();
254 fGet.assign( fPut.begin(), fPut.end() );
255 }
256 }
257
258 Log() << kHEADER << "Transformation, Variable selection : " << Endl;
259
260 // choose the new dsi for output if present, if not, take the common one
261 const DataSetInfo* outputDsiPtr = (fDsiOutput? &(*fDsiOutput) : &fDsi );
262
263
264
265 ItVarTypeIdx itGet = fGet.begin(), itGetEnd = fGet.end();
266 ItVarTypeIdx itPut = fPut.begin(); // , itPutEnd = fPut.end();
267 for( ; itGet != itGetEnd; ++itGet ) {
268 TString inputTypeString = "?";
269
270 Char_t inputType = (*itGet).first;
271 Int_t inputIdx = (*itGet).second;
272
273 TString inputLabel = "NOT FOND";
274 if( inputType == 'v' ) {
275 inputLabel = fDsi.GetVariableInfo( inputIdx ).GetLabel();
276 inputTypeString = "variable";
277 }
278 else if( inputType == 't' ){
279 inputLabel = fDsi.GetTargetInfo( inputIdx ).GetLabel();
280 inputTypeString = "target";
281 }
282 else if( inputType == 's' ){
283 inputLabel = fDsi.GetSpectatorInfo( inputIdx ).GetLabel();
284 inputTypeString = "spectator";
285 }
286
287 TString outputTypeString = "?";
288
289 Char_t outputType = (*itPut).first;
290 Int_t outputIdx = (*itPut).second;
291
292 TString outputLabel = "NOT FOUND";
293 if( outputType == 'v' ) {
294 outputLabel = outputDsiPtr->GetVariableInfo( outputIdx ).GetLabel();
295 outputTypeString = "variable";
296 }
297 else if( outputType == 't' ){
298 outputLabel = outputDsiPtr->GetTargetInfo( outputIdx ).GetLabel();
299 outputTypeString = "target";
300 }
301 else if( outputType == 's' ){
302 outputLabel = outputDsiPtr->GetSpectatorInfo( outputIdx ).GetLabel();
303 outputTypeString = "spectator";
304 }
305 Log() << kINFO << "Input : " << inputTypeString.Data() << " '" << inputLabel.Data() << "'" << " <---> " << "Output : " << outputTypeString.Data() << " '" << outputLabel.Data() << "'" << Endl;
306 Log() << kDEBUG << "\t(index=" << inputIdx << ")." << "\t(index=" << outputIdx << ")." << Endl;
307
308 ++itPut;
309 }
310 // Log() << kINFO << Endl;
311}
312
313
314////////////////////////////////////////////////////////////////////////////////
315/// select the values from the event
316
317Bool_t TMVA::VariableTransformBase::GetInput( const Event* event, std::vector<Float_t>& input, std::vector<Char_t>& mask, Bool_t backTransformation ) const
318{
319 ItVarTypeIdxConst itEntry;
320 ItVarTypeIdxConst itEntryEnd;
321
322 input.clear();
323 mask.clear();
324
325 if( backTransformation && !fPut.empty() ){
326 itEntry = fPut.begin();
327 itEntryEnd = fPut.end();
328 input.reserve(fPut.size());
329 }
330 else {
331 itEntry = fGet.begin();
332 itEntryEnd = fGet.end();
333 input.reserve(fGet.size() );
334 }
335
336 Bool_t hasMaskedEntries = kFALSE;
337 // event->Print(std::cout);
338 for( ; itEntry != itEntryEnd; ++itEntry ) {
339 Char_t type = (*itEntry).first;
340 Int_t idx = (*itEntry).second;
341
342 try{
343 switch( type ) {
344 case 'v':
345 input.push_back( event->GetValue(idx) );
346 break;
347 case 't':
348 input.push_back( event->GetTarget(idx) );
349 break;
350 case 's':
351 input.push_back( event->GetSpectator(idx) );
352 break;
353 default:
354 Log() << kFATAL << "VariableTransformBase/GetInput : unknown type '" << type << "'." << Endl;
355 }
356 mask.push_back(kFALSE);
357 }
358 catch(std::out_of_range& /* excpt */ ){ // happens when an event is transformed which does not yet have the targets calculated (in the application phase)
359 input.push_back(0.f);
360 mask.push_back(kTRUE);
361 hasMaskedEntries = kTRUE;
362 }
363 }
364 return hasMaskedEntries;
365}
366
367////////////////////////////////////////////////////////////////////////////////
368/// select the values from the event
369
370void TMVA::VariableTransformBase::SetOutput( Event* event, std::vector<Float_t>& output, std::vector<Char_t>& mask, const Event* oldEvent, Bool_t backTransformation ) const
371{
372 std::vector<Float_t>::iterator itOutput = output.begin();
373 std::vector<Char_t>::iterator itMask = mask.begin();
374
375 if( oldEvent )
376 event->CopyVarValues( *oldEvent );
377
378 try {
379
380 ItVarTypeIdxConst itEntry;
381 ItVarTypeIdxConst itEntryEnd;
382
383 if( backTransformation || fPut.empty() ){ // as in GetInput, but the other way round (from fPut for transformation, from fGet for backTransformation)
384 itEntry = fGet.begin();
385 itEntryEnd = fGet.end();
386 }
387 else {
388 itEntry = fPut.begin();
389 itEntryEnd = fPut.end();
390 }
391
392
393 for( ; itEntry != itEntryEnd; ++itEntry ) {
394
395 if( (*itMask) ){ // if the value is masked
396 continue;
397 }
398
399 Char_t type = (*itEntry).first;
400 Int_t idx = (*itEntry).second;
401 if (itOutput == output.end()) Log() << kFATAL << "Read beyond array boundaries in VariableTransformBase::SetOutput"<<Endl;
402 Float_t value = (*itOutput);
403
404 switch( type ) {
405 case 'v':
406 event->SetVal( idx, value );
407 break;
408 case 't':
409 event->SetTarget( idx, value );
410 break;
411 case 's':
412 event->SetSpectator( idx, value );
413 break;
414 default:
415 Log() << kFATAL << "VariableTransformBase/GetInput : unknown type '" << type << "'." << Endl;
416 }
417 if( !(*itMask) ) ++itOutput;
418 ++itMask;
419
420 }
421 }catch( std::exception& except ){
422 Log() << kFATAL << "VariableTransformBase/SetOutput : exception/" << except.what() << Endl;
423 throw;
424 }
425}
426
427
428////////////////////////////////////////////////////////////////////////////////
429/// count variables, targets and spectators
430
432{
   // Count how many entries of fGet are variables ('v'), targets ('t') and
   // spectators ('s'). The result is cached in fNVariables/fNTargets/fNSpectators
   // and returned directly while fVariableTypesAreCounted is set; code that changes
   // fGet has to clear that flag to get a fresh count.
   // NOTE(review): the declaration of this member is const, yet it assigns the
   // cache members -- presumably they are declared mutable; confirm in the header.
433 if( fVariableTypesAreCounted ){
434 nvars = fNVariables;
435 ntgts = fNTargets;
436 nspcts = fNSpectators;
437 return;
438 }
439
440 nvars = ntgts = nspcts = 0;
441
442 for( ItVarTypeIdxConst itEntry = fGet.begin(), itEntryEnd = fGet.end(); itEntry != itEntryEnd; ++itEntry ) {
443 Char_t type = (*itEntry).first;
444
445 switch( type ) {
446 case 'v':
447 nvars++;
448 break;
449 case 't':
450 ntgts++;
451 break;
452 case 's':
453 nspcts++;
454 break;
455 default:
456 Log() << kFATAL << "VariableTransformBase/GetVariableTypeNumbers : unknown type '" << type << "'." << Endl;
457 }
458 }
459
   // remember the counts for subsequent calls
460 fNVariables = nvars;
461 fNTargets = ntgts;
462 fNSpectators = nspcts;
463
464 fVariableTypesAreCounted = true;
465}
466
467////////////////////////////////////////////////////////////////////////////////
468/// TODO --> adapt to variable,target,spectator selection
469/// method to calculate minimum, maximum, mean, and RMS for all
470/// variables used in the MVA
471
472void TMVA::VariableTransformBase::CalcNorm( const std::vector<const Event*>& events )
473{
474 if (!IsCreated()) return;
475
476 const UInt_t nvars = GetNVariables();
477 const UInt_t ntgts = GetNTargets();
478
479 UInt_t nevts = events.size();
480
481 TVectorD x2( nvars+ntgts ); x2 *= 0;
482 TVectorD x0( nvars+ntgts ); x0 *= 0;
483 TVectorD v0( nvars+ntgts ); v0 *= 0;
484
485 Double_t sumOfWeights = 0;
486 for (UInt_t ievt=0; ievt<nevts; ievt++) {
487 const Event* ev = events[ievt];
488
489 Double_t weight = ev->GetWeight();
490 sumOfWeights += weight;
491 for (UInt_t ivar=0; ivar<nvars; ivar++) {
492 Double_t x = ev->GetValue(ivar);
493 if (ievt==0) {
494 Variables().at(ivar).SetMin(x);
495 Variables().at(ivar).SetMax(x);
496 }
497 else {
498 UpdateNorm( ivar, x );
499 }
500 x0(ivar) += x*weight;
501 x2(ivar) += x*x*weight;
502 }
503 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
504 Double_t x = ev->GetTarget(itgt);
505 if (ievt==0) {
506 Targets().at(itgt).SetMin(x);
507 Targets().at(itgt).SetMax(x);
508 }
509 else {
510 UpdateNorm( nvars+itgt, x );
511 }
512 x0(nvars+itgt) += x*weight;
513 x2(nvars+itgt) += x*x*weight;
514 }
515 }
516
517 if (sumOfWeights <= 0) {
518 Log() << kFATAL << " the sum of event weights calculated for your input is == 0"
519 << " or exactly: " << sumOfWeights << " there is obviously some problem..."<< Endl;
520 }
521
522 // set Mean and RMS
523 for (UInt_t ivar=0; ivar<nvars; ivar++) {
524 Double_t mean = x0(ivar)/sumOfWeights;
525
526 Variables().at(ivar).SetMean( mean );
527 if (x2(ivar)/sumOfWeights - mean*mean < 0) {
528 Log() << kFATAL << " the RMS of your input variable " << ivar
529 << " evaluates to an imaginary number: sqrt("<< x2(ivar)/sumOfWeights - mean*mean
530 <<") .. sometimes related to a problem with outliers and negative event weights"
531 << Endl;
532 }
533 Variables().at(ivar).SetRMS( TMath::Sqrt( x2(ivar)/sumOfWeights - mean*mean) );
534 }
535 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
536 Double_t mean = x0(nvars+itgt)/sumOfWeights;
537 Targets().at(itgt).SetMean( mean );
538 if (x2(nvars+itgt)/sumOfWeights - mean*mean < 0) {
539 Log() << kFATAL << " the RMS of your target variable " << itgt
540 << " evaluates to an imaginary number: sqrt(" << x2(nvars+itgt)/sumOfWeights - mean*mean
541 <<") .. sometimes related to a problem with outliers and negative event weights"
542 << Endl;
543 }
544 Targets().at(itgt).SetRMS( TMath::Sqrt( x2(nvars+itgt)/sumOfWeights - mean*mean) );
545 }
546 // calculate variance
547 for (UInt_t ievt=0; ievt<nevts; ievt++) {
548 const Event* ev = events[ievt];
549 Double_t weight = ev->GetWeight();
550 for (UInt_t ivar=0; ivar<nvars; ivar++) {
551 Double_t x = ev->GetValue(ivar);
552 Double_t mean = Variables().at(ivar).GetMean();
553 v0(ivar) += weight*(x-mean)*(x-mean);
554 }
555 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
556 Double_t x = ev->GetTarget(itgt);
557 Double_t mean = Targets().at(itgt).GetMean();
558 v0(nvars+itgt) += weight*(x-mean)*(x-mean);
559 }
560
561 }
562
563 // set variance
564 for (UInt_t ivar=0; ivar<nvars; ivar++) {
565 Double_t variance = v0(ivar)/sumOfWeights;
566 Variables().at(ivar).SetVariance( variance );
567 Log() << kINFO << "Variable " << Variables().at(ivar).GetExpression() <<" variance = " << variance << Endl;
568 }
569 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
570 Double_t variance = v0(nvars+itgt)/sumOfWeights;
571 Targets().at(itgt).SetVariance( variance );
572 Log() << kINFO << "Target " << Targets().at(itgt).GetExpression() <<" variance = " << variance << Endl;
573 }
574
575 Log() << kVERBOSE << "Set minNorm/maxNorm for variables to: " << Endl;
576 Log() << std::setprecision(3);
577 for (UInt_t ivar=0; ivar<GetNVariables(); ivar++)
578 Log() << " " << Variables().at(ivar).GetInternalName()
579 << "\t: [" << Variables().at(ivar).GetMin() << "\t, " << Variables().at(ivar).GetMax() << "\t] " << Endl;
580 Log() << kVERBOSE << "Set minNorm/maxNorm for targets to: " << Endl;
581 Log() << std::setprecision(3);
582 for (UInt_t itgt=0; itgt<GetNTargets(); itgt++)
583 Log() << " " << Targets().at(itgt).GetInternalName()
584 << "\t: [" << Targets().at(itgt).GetMin() << "\t, " << Targets().at(itgt).GetMax() << "\t] " << Endl;
585 Log() << std::setprecision(5); // reset to better value
586}
587
588////////////////////////////////////////////////////////////////////////////////
589/// TODO --> adapt to variable,target,spectator selection
590/// default transformation output
591/// --> only indicate that transformation occurred
592
594{
   // Default description of the transformation: one string per variable, its label
   // with "_[transformed]" appended (targets and spectators are not listed).
   // The vector is allocated with new and returned; presumably the caller takes
   // ownership and must delete it -- confirm with the callers.
595 std::vector<TString>* strVec = new std::vector<TString>;
596 for (UInt_t ivar=0; ivar<GetNVariables(); ivar++) {
597 strVec->push_back( Variables()[ivar].GetLabel() + "_[transformed]");
598 }
599
600 return strVec;
601}
602
603////////////////////////////////////////////////////////////////////////////////
604/// TODO --> adapt to variable,target,spectator selection
605/// update min and max of a given variable (target) and a given transformation method
606
608{
   // Widen the stored [min, max] range with the value x.
   // ivar below the number of variables of fDsi addresses variable ivar; any larger
   // ivar addresses target (ivar - number of variables). at() throws
   // std::out_of_range for an index beyond Variables()/Targets().
609 Int_t nvars = fDsi.GetNVariables();
610 if( ivar < nvars ){
611 if (x < Variables().at(ivar).GetMin()) Variables().at(ivar).SetMin(x);
612 if (x > Variables().at(ivar).GetMax()) Variables().at(ivar).SetMax(x);
613 }else{
614 if (x < Targets().at(ivar-nvars).GetMin()) Targets().at(ivar-nvars).SetMin(x);
615 if (x > Targets().at(ivar-nvars).GetMax()) Targets().at(ivar-nvars).SetMax(x);
616 }
617}
618
619////////////////////////////////////////////////////////////////////////////////
620/// create XML description the transformation (write out info of selected variables)
621
623{
   // Write the selection under a new <Selection> child of parent:
   //   <Input NInputs=N>   one <Input Type Label Expression> per fGet entry, resolved in fDsi
   //   <Output NOutputs=M> one <Output Type Label Expression> per fPut entry, resolved in
   //                       fDsiOutput if it is set, otherwise in fDsi
   // Indices are not stored; entries are identified by label/expression when read back.
624 void* selxml = gTools().AddChild(parent, "Selection");
625
626 void* inpxml = gTools().AddChild(selxml, "Input");
627 gTools().AddAttr(inpxml, "NInputs", fGet.size() );
628
629 // choose the output DataSetInfo (fDsiOutput) if present, otherwise the common one (fDsi)
630 const DataSetInfo* outputDsiPtr = (fDsiOutput? fDsiOutput : &fDsi );
631
   // inputs: label and expression are taken from the input data set fDsi
632 for( ItVarTypeIdx itGet = fGet.begin(), itGetEnd = fGet.end(); itGet != itGetEnd; ++itGet ) {
633 UInt_t idx = (*itGet).second;
634 Char_t type = (*itGet).first;
635
636 TString label = "";
637 TString expression = "";
638 TString typeString = "";
639 switch( type ){
640 case 'v':
641 typeString = "Variable";
642 label = fDsi.GetVariableInfo( idx ).GetLabel();
643 expression = fDsi.GetVariableInfo( idx ).GetExpression();
644 break;
645 case 't':
646 typeString = "Target";
647 label = fDsi.GetTargetInfo( idx ).GetLabel();
648 expression = fDsi.GetTargetInfo( idx ).GetExpression();
649 break;
650 case 's':
651 typeString = "Spectator";
652 label = fDsi.GetSpectatorInfo( idx ).GetLabel();
653 expression = fDsi.GetSpectatorInfo( idx ).GetExpression();
654 break;
655 default:
656 Log() << kFATAL << "VariableTransformBase/AttachXMLTo unknown variable type '" << type << "'." << Endl;
657 }
658
659 void* idxxml = gTools().AddChild(inpxml, "Input");
660 // gTools().AddAttr(idxxml, "Index", idx);
661 gTools().AddAttr(idxxml, "Type", typeString);
662 gTools().AddAttr(idxxml, "Label", label);
663 gTools().AddAttr(idxxml, "Expression", expression);
664 }
665
666
   // outputs: label and expression are taken from the output data set (see outputDsiPtr)
667 void* outxml = gTools().AddChild(selxml, "Output");
668 gTools().AddAttr(outxml, "NOutputs", fPut.size() );
669
670 for( ItVarTypeIdx itPut = fPut.begin(), itPutEnd = fPut.end(); itPut != itPutEnd; ++itPut ) {
671 UInt_t idx = (*itPut).second;
672 Char_t type = (*itPut).first;
673
674 TString label = "";
675 TString expression = "";
676 TString typeString = "";
677 switch( type ){
678 case 'v':
679 typeString = "Variable";
680 label = outputDsiPtr->GetVariableInfo( idx ).GetLabel();
681 expression = outputDsiPtr->GetVariableInfo( idx ).GetExpression();
682 break;
683 case 't':
684 typeString = "Target";
685 label = outputDsiPtr->GetTargetInfo( idx ).GetLabel();
686 expression = outputDsiPtr->GetTargetInfo( idx ).GetExpression();
687 break;
688 case 's':
689 typeString = "Spectator";
690 label = outputDsiPtr->GetSpectatorInfo( idx ).GetLabel();
691 expression = outputDsiPtr->GetSpectatorInfo( idx ).GetExpression();
692 break;
693 default:
694 Log() << kFATAL << "VariableTransformBase/AttachXMLTo unknown variable type '" << type << "'." << Endl;
695 }
696
697 void* idxxml = gTools().AddChild(outxml, "Output");
698 // gTools().AddAttr(idxxml, "Index", idx);
699 gTools().AddAttr(idxxml, "Type", typeString);
700 gTools().AddAttr(idxxml, "Label", label);
701 gTools().AddAttr(idxxml, "Expression", expression);
702 }
703
704
705}
706
707////////////////////////////////////////////////////////////////////////////////
708/// Read the input variables from the XML node
709
711{
   // Restore fGet and fPut from a <Selection> node. The first child of selnode is
   // taken as the input list and its next sibling as the output list.
   // Each stored entry is matched against fDsi by its type; an entry matches when
   // its label OR its expression equals the stored one, and the first match wins.
   // Entries without a match are dropped silently -- only the asserts at the end
   // of each list notice, and those are compiled out when NDEBUG is defined.
   // NOTE(review): outputs are looked up in fDsi, whereas AttachXMLTo writes them
   // from fDsiOutput when that is set -- confirm this is intended.
712 void* inpnode = gTools().GetChild( selnode );
713 void* outnode = gTools().GetNextChild( inpnode );
714
715 UInt_t nvars = GetNVariables();
716 UInt_t ntgts = GetNTargets();
717 UInt_t nspcts = GetNSpectators();
718
719 // read inputs
720 fGet.clear();
721
722 UInt_t nInputs = 0;
723 gTools().ReadAttr(inpnode, "NInputs", nInputs);
724
725 void* ch = gTools().GetChild( inpnode );
726 while(ch) {
727 TString typeString = "";
728 TString label = "";
729 TString expression = "";
730
731 gTools().ReadAttr(ch, "Type", typeString);
732 gTools().ReadAttr(ch, "Label", label);
733 gTools().ReadAttr(ch, "Expression", expression);
734
735 if( typeString == "Variable" ){
736 for( UInt_t ivar = 0; ivar < nvars; ++ivar ) { // search all variables
737 if( fDsi.GetVariableInfo( ivar ).GetLabel() == label ||
738 fDsi.GetVariableInfo( ivar ).GetExpression() == expression) {
739 fGet.push_back( std::pair<Char_t,UInt_t>('v',ivar) );
740 break;
741 }
742 }
743 }else if( typeString == "Target" ){
744 for( UInt_t itgt = 0; itgt < ntgts; ++itgt ) { // search all targets
745 if( fDsi.GetTargetInfo( itgt ).GetLabel() == label ||
746 fDsi.GetTargetInfo( itgt ).GetExpression() == expression ) {
747 fGet.push_back( std::pair<Char_t,UInt_t>('t',itgt) );
748 break;
749 }
750 }
751 }else if( typeString == "Spectator" ){
752 for( UInt_t ispct = 0; ispct < nspcts; ++ispct ) { // search all spectators
753 if( fDsi.GetSpectatorInfo( ispct ).GetLabel() == label ||
754 fDsi.GetSpectatorInfo( ispct ).GetExpression() == expression ) {
755 fGet.push_back( std::pair<Char_t,UInt_t>('s',ispct) );
756 break;
757 }
758 }
759 }else{
760 Log() << kFATAL << "VariableTransformationBase/ReadFromXML : unknown type '" << typeString << "'." << Endl;
761 }
762 ch = gTools().GetNextChild( ch );
763 }
764
   // every stored input must have been found again (debug builds only)
765 assert( nInputs == fGet.size() );
766
767 // read outputs
768 fPut.clear();
769
770 UInt_t nOutputs = 0;
771 gTools().ReadAttr(outnode, "NOutputs", nOutputs);
772
773 void* chOut = gTools().GetChild( outnode );
774 while(chOut) {
775 TString typeString = "";
776 TString label = "";
777 TString expression = "";
778
779 gTools().ReadAttr(chOut, "Type", typeString);
780 gTools().ReadAttr(chOut, "Label", label);
781 gTools().ReadAttr(chOut, "Expression", expression);
782
783 if( typeString == "Variable" ){
784 for( UInt_t ivar = 0; ivar < nvars; ++ivar ) { // search all variables
785 if( fDsi.GetVariableInfo( ivar ).GetLabel() == label ||
786 fDsi.GetVariableInfo( ivar ).GetExpression() == expression ) {
787 fPut.push_back( std::pair<Char_t,UInt_t>('v',ivar) );
788 break;
789 }
790 }
791 }else if( typeString == "Target" ){
792 for( UInt_t itgt = 0; itgt < ntgts; ++itgt ) { // search all targets
793 if( fDsi.GetTargetInfo( itgt ).GetLabel() == label ||
794 fDsi.GetTargetInfo( itgt ).GetExpression() == expression ) {
795 fPut.push_back( std::pair<Char_t,UInt_t>('t',itgt) );
796 break;
797 }
798 }
799 }else if( typeString == "Spectator" ){
800 for( UInt_t ispct = 0; ispct < nspcts; ++ispct ) { // search all spectators
801 if( fDsi.GetSpectatorInfo( ispct ).GetLabel() == label ||
802 fDsi.GetSpectatorInfo( ispct ).GetExpression() == expression ) {
803 fPut.push_back( std::pair<Char_t,UInt_t>('s',ispct) );
804 break;
805 }
806 }
807 }else{
808 Log() << kFATAL << "VariableTransformationBase/ReadFromXML : unknown type '" << typeString << "'." << Endl;
809 }
810 chOut = gTools().GetNextChild( chOut );
811 }
812
   // every stored output must have been found again (debug builds only)
813 assert( nOutputs == fPut.size() );
814}
815
816////////////////////////////////////////////////////////////////////////////////
817/// getinput and setoutput equivalent
818
819void TMVA::VariableTransformBase::MakeFunction( std::ostream& fout, const TString& /*fncName*/, Int_t part,
820 UInt_t /*trCounter*/, Int_t /*cls*/ )
821{
822 if( part == 0 ){ // definitions
823 fout << std::endl;
824 fout << " // define the indices of the variables which are transformed by this transformation" << std::endl;
825 fout << " static std::vector<int> indicesGet;" << std::endl;
826 fout << " static std::vector<int> indicesPut;" << std::endl << std::endl;
827 fout << " if ( indicesGet.empty() ) {" << std::endl;
828 fout << " indicesGet.reserve(fNvars);" << std::endl;
829
830 for( ItVarTypeIdxConst itEntry = fGet.begin(), itEntryEnd = fGet.end(); itEntry != itEntryEnd; ++itEntry ) {
831 Char_t type = (*itEntry).first;
832 Int_t idx = (*itEntry).second;
833
834 switch( type ) {
835 case 'v':
836 fout << " indicesGet.push_back( " << idx << ");" << std::endl;
837 break;
838 case 't':
839 Log() << kWARNING << "MakeClass doesn't work with transformation of targets. The results will be wrong!" << Endl;
840 break;
841 case 's':
842 Log() << kWARNING << "MakeClass doesn't work with transformation of spectators. The results will be wrong!" << Endl;
843 break;
844 default:
845 Log() << kFATAL << "VariableTransformBase/GetInput : unknown type '" << type << "'." << Endl;
846 }
847 }
848 fout << " }" << std::endl;
849 fout << " if ( indicesPut.empty() ) {" << std::endl;
850 fout << " indicesPut.reserve(fNvars);" << std::endl;
851
852 for( ItVarTypeIdxConst itEntry = fPut.begin(), itEntryEnd = fPut.end(); itEntry != itEntryEnd; ++itEntry ) {
853 Char_t type = (*itEntry).first;
854 Int_t idx = (*itEntry).second;
855
856 switch( type ) {
857 case 'v':
858 fout << " indicesPut.push_back( " << idx << ");" << std::endl;
859 break;
860 case 't':
861 Log() << kWARNING << "MakeClass doesn't work with transformation of targets. The results will be wrong!" << Endl;
862 break;
863 case 's':
864 Log() << kWARNING << "MakeClass doesn't work with transformation of spectators. The results will be wrong!" << Endl;
865 break;
866 default:
867 Log() << kFATAL << "VariableTransformBase/PutInput : unknown type '" << type << "'." << Endl;
868 }
869 }
870
871 fout << " }" << std::endl;
872 fout << std::endl;
873
874 }else if( part == 1){
875 }
876}
static const double x2[5]
int Int_t
Definition: RtypesCore.h:41
char Char_t
Definition: RtypesCore.h:29
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
float Float_t
Definition: RtypesCore.h:53
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassImp(name)
Definition: Rtypes.h:363
int type
Definition: TGX11.cxx:120
bool advanced
#define TMVA_VERSION_CODE
Definition: Version.h:47
Iterator of linked list.
Definition: TList.h:200
A doubly linked list.
Definition: TList.h:44
Class that contains all the data information.
Definition: DataSetInfo.h:60
UInt_t GetNVariables() const
Definition: DataSetInfo.h:110
UInt_t GetNSpectators(bool all=kTRUE) const
UInt_t GetNTargets() const
Definition: DataSetInfo.h:111
VariableInfo & GetVariableInfo(Int_t i)
Definition: DataSetInfo.h:96
VariableInfo & GetTargetInfo(Int_t i)
Definition: DataSetInfo.h:101
VariableInfo & GetSpectatorInfo(Int_t i)
Definition: DataSetInfo.h:106
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:237
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not.
Definition: Event.cxx:382
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition: Event.cxx:262
Float_t GetTarget(UInt_t itgt) const
Definition: Event.h:97
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
TList * ParseFormatLine(TString theString, const char *sep=":")
Parse the string and cut into labels separated by ":".
Definition: Tools.cxx:413
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1174
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1136
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1162
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:337
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:355
EVariableTransform
Definition: Types.h:115
Class for type info of MVA input variable.
Definition: VariableInfo.h:47
const TString & GetLabel() const
Definition: VariableInfo.h:59
const TString & GetExpression() const
Definition: VariableInfo.h:57
Base class for the TMVA variable transformations.
virtual void MakeFunction(std::ostream &fout, const TString &fncName, Int_t part, UInt_t trCounter, Int_t cls)=0
getinput and setoutput equivalent
virtual Bool_t GetInput(const Event *event, std::vector< Float_t > &input, std::vector< Char_t > &mask, Bool_t backTransform=kFALSE) const
select the values from the event
void CalcNorm(const std::vector< const Event * > &)
TODO --> adapt to variable,target,spectator selection method to calculate minimum,...
virtual void ReadFromXML(void *trfnode)=0
Read the input variables from the XML node.
virtual void AttachXMLTo(void *parent)=0
create XML description the transformation (write out info of selected variables)
virtual void SetOutput(Event *event, std::vector< Float_t > &output, std::vector< Char_t > &mask, const Event *oldEvent=0, Bool_t backTransform=kFALSE) const
select the values from the event
std::vector< TMVA::VariableInfo > fVariables
VariableTransformBase(DataSetInfo &dsi, Types::EVariableTransform tf, const TString &trfName)
standard constructor
void UpdateNorm(Int_t ivar, Double_t x)
TODO --> adapt to variable,target,spectator selection update min and max of a given variable (target)...
virtual void CountVariableTypes(UInt_t &nvars, UInt_t &ntgts, UInt_t &nspcts) const
count variables, targets and spectators
virtual std::vector< TString > * GetTransformationStrings(Int_t cls) const
TODO --> adapt to variable,target,spectator selection default transformation output --> only indicate...
virtual void SelectInput(const TString &inputVariables, Bool_t putIntoVariables=kFALSE)
select the variables/targets/spectators which serve as input to the transformation
VectorOfCharAndInt::iterator ItVarTypeIdx
std::vector< TMVA::VariableInfo > fSpectators
std::vector< TMVA::VariableInfo > fTargets
VectorOfCharAndInt::const_iterator ItVarTypeIdxConst
Collectable string class.
Definition: TObjString.h:28
Mother of all ROOT objects.
Definition: TObject.h:37
Basic string class.
Definition: TString.h:131
const char * Data() const
Definition: TString.h:364
Double_t x[n]
Definition: legend1.C:17
Tools & gTools()
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:748
Double_t Sqrt(Double_t x)
Definition: TMath.h:679