Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
VariableDecorrTransform.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Eckhard von Toerne
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : VariableDecorrTransform *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation (see header for description) *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland *
16 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
17 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
18 * *
19 * Copyright (c) 2005-2011: *
20 * CERN, Switzerland *
21 * MPI-K Heidelberg, Germany *
22 * U. of Bonn, Germany *
23 * *
24 * Redistribution and use in source and binary forms, with or without *
25 * modification, are permitted according to the terms listed in LICENSE *
26 * (http://tmva.sourceforge.net/LICENSE) *
27 **********************************************************************************/
28
29/*! \class TMVA::VariableDecorrTransform
30\ingroup TMVA
31Linear decorrelation transformation of the input variables
32*/
33
35
36#include "TMVA/DataSet.h"
37#include "TMVA/Event.h"
38#include "TMVA/MsgLogger.h"
39#include "TMVA/Tools.h"
40#include "TMVA/Types.h"
41#include "TMVA/VariableInfo.h"
42
43#include "TVectorF.h"
44#include "TVectorD.h"
45#include "TMatrixD.h"
46#include "TMatrixDBase.h"
47
48#include <iostream>
49#include <iomanip>
50#include <algorithm>
51
53
54////////////////////////////////////////////////////////////////////////////////
55/// constructor
56
58: VariableTransformBase( dsi, Types::kDecorrelated, "Deco" )
59{
   // All configuration is handed to the base class: the DataSetInfo, the
   // transformation type Types::kDecorrelated and the string "Deco"
   // (presumably this transformation's short name -- confirm in VariableTransformBase).
60}
61
62////////////////////////////////////////////////////////////////////////////////
63/// destructor
64
66{
67 for (std::vector<TMatrixD*>::iterator it = fDecorrMatrices.begin(); it != fDecorrMatrices.end(); ++it) {
68 if ((*it) != 0) delete (*it);
69 }
70}
71
72////////////////////////////////////////////////////////////////////////////////
73/// initialisation
74
76{
   // Nothing to initialise for the decorrelation transformation.
77}
78
79////////////////////////////////////////////////////////////////////////////////
80/// calculate the decorrelation matrices (one per class, plus one for all classes when there is more than one class)
81
83{
84 Initialize();
85
86 if (!IsEnabled() || IsCreated()) return kTRUE;
87
88 Log() << kINFO << "Preparing the Decorrelation transformation..." << Endl;
89
90 Int_t inputSize = fGet.size();
91 SetNVariables(inputSize);
92
93 if (inputSize > 200) {
94 Log() << kINFO << "----------------------------------------------------------------------------"
95 << Endl;
96 Log() << kINFO
97 << ": More than 200 variables, will not calculate decorrelation matrix "
98 << "!" << Endl;
99 Log() << kINFO << "----------------------------------------------------------------------------"
100 << Endl;
101 return kFALSE;
102 }
103
104 CalcSQRMats( events, GetNClasses() );
105
106 SetCreated( kTRUE );
107
108 return kTRUE;
109}
110
111////////////////////////////////////////////////////////////////////////////////
112/// creates string with variable transformations applied
113
115{
116 Int_t whichMatrix = cls;
117 // if cls (the class chosen by the user) not existing, assume that user wants to
118 // have the matrix for all classes together.
119
120 if (cls < 0 || cls > GetNClasses()) whichMatrix = GetNClasses();
121
122 TMatrixD* m = fDecorrMatrices.at(whichMatrix);
123 if (m == 0) {
124 if (whichMatrix == GetNClasses() )
125 Log() << kFATAL << "Transformation matrix all classes is not defined"
126 << Endl;
127 else
128 Log() << kFATAL << "Transformation matrix for class " << whichMatrix << " is not defined"
129 << Endl;
130 }
131
132 const Int_t nvar = fGet.size();
133 std::vector<TString>* strVec = new std::vector<TString>;
134
135 // fill vector
136 for (Int_t ivar=0; ivar<nvar; ivar++) {
137 TString str( "" );
138 for (Int_t jvar=0; jvar<nvar; jvar++) {
139 str += ((*m)(ivar,jvar) > 0) ? " + " : " - ";
140
141 Char_t type = fGet.at(jvar).first;
142 Int_t idx = fGet.at(jvar).second;
143
144 switch( type ) {
145 case 'v':
146 str += Form( "%10.5g*[%s]", TMath::Abs((*m)(ivar,jvar)), Variables()[idx].GetLabel().Data() );
147 break;
148 case 't':
149 str += Form( "%10.5g*[%s]", TMath::Abs((*m)(ivar,jvar)), Targets()[idx].GetLabel().Data() );
150 break;
151 case 's':
152 str += Form( "%10.5g*[%s]", TMath::Abs((*m)(ivar,jvar)), Spectators()[idx].GetLabel().Data() );
153 break;
154 default:
155 Log() << kFATAL << "VariableDecorrTransform::GetTransformationStrings : unknown type '" << type << "'." << Endl;
156 }
157 }
158 strVec->push_back( str );
159 }
160
161 return strVec;
162}
163
164////////////////////////////////////////////////////////////////////////////////
165/// apply the decorrelation transformation
166
168{
169 if (!IsCreated())
170 Log() << kFATAL << "Transformation matrix not yet created"
171 << Endl;
172
173 Int_t whichMatrix = cls;
174 // if cls (the class chosen by the user) not existing, assume that he wants to have the matrix for all classes together.
175 // EVT this is a workaround to address the reader problem with transforma and EvaluateMVA(std::vector<float/double> ,...)
176 if (cls < 0 || cls >= (int) fDecorrMatrices.size()) whichMatrix = fDecorrMatrices.size()-1;
177 //EVT workaround end
178 //if (cls < 0 || cls > GetNClasses()) {
179 // whichMatrix = GetNClasses();
180 // if (GetNClasses() == 1 ) whichMatrix = (fDecorrMatrices.size()==1?0:2);
181 //}
182
183 TMatrixD* m = fDecorrMatrices.at(whichMatrix);
184 if (m == 0) {
185 if (whichMatrix == GetNClasses() )
186 Log() << kFATAL << "Transformation matrix all classes is not defined"
187 << Endl;
188 else
189 Log() << kFATAL << "Transformation matrix for class " << whichMatrix << " is not defined"
190 << Endl;
191 }
192
193 if (fTransformedEvent==0 || fTransformedEvent->GetNVariables()!=ev->GetNVariables()) {
194 if (fTransformedEvent!=0) { delete fTransformedEvent; fTransformedEvent = 0; }
195 fTransformedEvent = new Event();
196 }
197
198 // transformation to decorrelate the variables
199 const Int_t nvar = fGet.size();
200
201 std::vector<Float_t> input;
202 std::vector<Char_t> mask; // entries with kTRUE must not be transformed
203 Bool_t hasMaskedEntries = GetInput( ev, input, mask );
204
205 if( hasMaskedEntries ){ // targets might be masked (for events where the targets have not been computed yet)
206 UInt_t numMasked = std::count(mask.begin(), mask.end(), (Char_t)kTRUE);
207 UInt_t numOK = std::count(mask.begin(), mask.end(), (Char_t)kFALSE);
208 if( numMasked>0 && numOK>0 ){
209 Log() << kFATAL << "You mixed variables and targets in the decorrelation transformation. This is not possible." << Endl;
210 }
211 SetOutput( fTransformedEvent, input, mask, ev );
212 return fTransformedEvent;
213 }
214
215 TVectorD vec( nvar );
216 for (Int_t ivar=0; ivar<nvar; ivar++) vec(ivar) = input.at(ivar);
217
218 // diagonalise variable vectors
219 vec *= *m;
220
221 input.clear();
222 for (Int_t ivar=0; ivar<nvar; ivar++) input.push_back( vec(ivar) );
223
224 SetOutput( fTransformedEvent, input, mask, ev );
225
226 return fTransformedEvent;
227}
228
229////////////////////////////////////////////////////////////////////////////////
230/// apply the inverse decorrelation transformation ...
231/// TODO : ... build the inverse transformation
232
234{
235 Log() << kFATAL << "Inverse transformation for decorrelation transformation not yet implemented. Hence, this transformation cannot be applied together with regression if targets should be transformed. Please contact the authors if necessary." << Endl;
236
237
238 return fBackTransformedEvent;
239}
240
241////////////////////////////////////////////////////////////////////////////////
242/// compute square-root matrices for signal and background
243
244void TMVA::VariableDecorrTransform::CalcSQRMats( const std::vector< Event*>& events, Int_t maxCls )
245{
246 // delete old matrices if any
247 for (std::vector<TMatrixD*>::iterator it = fDecorrMatrices.begin();
248 it != fDecorrMatrices.end(); ++it)
249 if (0 != (*it) ) { delete (*it); *it=0; }
250
251
252 // if more than one classes, then produce one matrix for all events as well (beside the matrices for each class)
253 const UInt_t matNum = (maxCls<=1)?maxCls:maxCls+1;
254 fDecorrMatrices.resize( matNum, (TMatrixD*) 0 );
255
256 std::vector<TMatrixDSym*>* covMat = gTools().CalcCovarianceMatrices( events, maxCls, this );
257
258
259 for (UInt_t cls=0; cls<matNum; cls++) {
260 TMatrixD* sqrMat = gTools().GetSQRootMatrix( covMat->at(cls) );
261 if ( sqrMat==0 )
262 Log() << kFATAL << "<GetSQRMats> Zero pointer returned for SQR matrix" << Endl;
263 fDecorrMatrices[cls] = sqrMat;
264 delete (*covMat)[cls];
265 }
266 delete covMat;
267}
268
269////////////////////////////////////////////////////////////////////////////////
270/// write the decorrelation matrix to the stream
271
273{
274 Int_t cls = 0;
275 Int_t dp = o.precision();
276 for (std::vector<TMatrixD*>::const_iterator itm = fDecorrMatrices.begin(); itm != fDecorrMatrices.end(); ++itm) {
277 o << "# correlation matrix " << std::endl;
278 TMatrixD* mat = (*itm);
279 o << cls << " " << mat->GetNrows() << " x " << mat->GetNcols() << std::endl;
280 for (Int_t row = 0; row<mat->GetNrows(); row++) {
281 for (Int_t col = 0; col<mat->GetNcols(); col++) {
282 o << std::setprecision(12) << std::setw(20) << (*mat)[row][col] << " ";
283 }
284 o << std::endl;
285 }
286 cls++;
287 }
288 o << "##" << std::endl;
289 o << std::setprecision(dp);
290}
291
292////////////////////////////////////////////////////////////////////////////////
293/// node attachment to parent
294
296{
   // Create a "Transform" child node of parent and tag it Name="Decorrelation".
297 void* trf = gTools().AddChild(parent, "Transform");
298 gTools().AddAttr(trf,"Name", "Decorrelation");
299
   // NOTE(review): the listing jumps from source line 299 to 301; a statement at
   // line 300 seems to be missing from this copy (possibly base-class output such
   // as the "Selection" node that ReadFromXML looks for) -- confirm against upstream.
301
   // Write every stored decorrelation matrix as a "Matrix" child of the Transform node.
302 for (std::vector<TMatrixD*>::const_iterator itm = fDecorrMatrices.begin(); itm != fDecorrMatrices.end(); ++itm) {
303 TMatrixD* mat = (*itm);
   // Disabled hand-written XML writer, kept inside a block comment; the active
   // call is WriteTMatrixDToXML below.
304 /*void* decmat = gTools().xmlengine().NewChild(trf, 0, "Matrix");
305 gTools().xmlengine().NewAttr(decmat,0,"Rows", gTools().StringFromInt(mat->GetNrows()) );
306 gTools().xmlengine().NewAttr(decmat,0,"Columns", gTools().StringFromInt(mat->GetNcols()) );
307
308 std::stringstream s;
309 for (Int_t row = 0; row<mat->GetNrows(); row++) {
310 for (Int_t col = 0; col<mat->GetNcols(); col++) {
311 s << (*mat)[row][col] << " ";
312 }
313 }
314 gTools().xmlengine().AddRawLine( decmat, s.str().c_str() );*/
315 gTools().WriteTMatrixDToXML(trf,"Matrix",mat);
316 }
317}
318
319////////////////////////////////////////////////////////////////////////////////
320/// Read the transformation matrices from the xml node
321
323{
   // Replace any previously stored matrices with the matrix nodes found under trfnode.
324 // first delete the old matrices
325 for( std::vector<TMatrixD*>::iterator it = fDecorrMatrices.begin(); it != fDecorrMatrices.end(); ++it )
326 if( (*it) != 0 ) delete (*it);
327 fDecorrMatrices.clear();
328
   // The presence of a "Selection" child distinguishes the new from the old XML layout.
329 Bool_t newFormat = kFALSE;
330
331 void* inpnode = NULL;
332
333 inpnode = gTools().GetChild(trfnode, "Selection"); // new xml format
334 if( inpnode!=NULL )
335 newFormat = kTRUE; // new xml format
336
   // New format: the matrices are the siblings following the "Selection" node.
   // Old format: every child of trfnode is read as a matrix.
337 void* ch = NULL;
338 if( newFormat ){
339 // ------------- new format --------------------
340 // read input
   // NOTE(review): the listing jumps from source line 340 to 342; a statement at
   // line 341 (presumably reading the "Selection" input node) appears to be
   // missing from this copy -- confirm against upstream.
342
343 ch = gTools().GetNextChild(inpnode);
344 }else
345 ch = gTools().GetChild(trfnode);
346
   // Each matrix node carries "Rows"/"Columns" attributes; its text content holds
   // the elements row by row, separated by whitespace. Matrices are appended in
   // document order (assumed to be class0, class1, ..., all classes -- confirm).
347 // Read the transformation matrices from the xml node
348 while(ch!=0) {
349 Int_t nrows, ncols;
350 gTools().ReadAttr(ch, "Rows", nrows);
351 gTools().ReadAttr(ch, "Columns", ncols);
352 TMatrixD* mat = new TMatrixD(nrows,ncols);
353 const char* content = gTools().GetContent(ch);
354 std::stringstream s(content);
355 for (Int_t row = 0; row<nrows; row++) {
356 for (Int_t col = 0; col<ncols; col++) {
357 s >> (*mat)[row][col];
358 }
359 }
360 fDecorrMatrices.push_back(mat);
361 ch = gTools().GetNextChild(ch);
362 }
   // Mark the transformation as created so that Transform() may be used.
363 SetCreated();
364}
365
366////////////////////////////////////////////////////////////////////////////////
367/// Read the decorrelation matrices from an input stream
368
370{
371 char buf[512];
372 istr.getline(buf,512);
373 TString strvar, dummy;
374 Int_t nrows(0), ncols(0);
375 UInt_t classIdx=0;
376 while (!(buf[0]=='#'&& buf[1]=='#')) { // if line starts with ## return
377 char* p = buf;
378 while (*p==' ' || *p=='\t') p++; // 'remove' leading whitespace
379 if (*p=='#' || *p=='\0') {
380 istr.getline(buf,512);
381 continue; // if comment or empty line, read the next line
382 }
383 std::stringstream sstr(buf);
384
385 sstr >> strvar;
386 if (strvar=="signal" || strvar=="background") {
387 UInt_t cls=0;
388 if(strvar=="background") cls=1;
389 if(strvar==classname) classIdx = cls;
390 // coverity[tainted_data_argument]
391 sstr >> nrows >> dummy >> ncols;
392 if (fDecorrMatrices.size() <= cls ) fDecorrMatrices.resize(cls+1);
393 if (fDecorrMatrices.at(cls) != 0) delete fDecorrMatrices.at(cls);
394 TMatrixD* mat = fDecorrMatrices.at(cls) = new TMatrixD(nrows,ncols);
395 // now read all matrix parameters
396 for (Int_t row = 0; row<mat->GetNrows(); row++) {
397 for (Int_t col = 0; col<mat->GetNcols(); col++) {
398 istr >> (*mat)[row][col];
399 }
400 }
401 } // done reading a matrix
402 istr.getline(buf,512); // reading the next line
403 }
404
405 fDecorrMatrices.push_back( new TMatrixD(*fDecorrMatrices[classIdx]) );
406
407 SetCreated();
408}
409
410////////////////////////////////////////////////////////////////////////////////
411/// prints the transformation matrix
412
414{
415 Int_t cls = 0;
416 for (std::vector<TMatrixD*>::iterator itm = fDecorrMatrices.begin(); itm != fDecorrMatrices.end(); ++itm) {
417 Log() << kINFO << "Transformation matrix "<< cls <<":" << Endl;
418 (*itm)->Print();
419 }
420}
421
422////////////////////////////////////////////////////////////////////////////////
423/// creates C++ code fragment of the decorrelation transform for inclusion in standalone C++ class
424
425void TMVA::VariableDecorrTransform::MakeFunction( std::ostream& fout, const TString& fcncName, Int_t part, UInt_t trCounter, Int_t )
426{
427 Int_t dp = fout.precision();
428
429 UInt_t numC = fDecorrMatrices.size();
430 // creates a decorrelation function
431 if (part==1) {
432 TMatrixD* mat = fDecorrMatrices.at(0); // ToDo check if all Decorr matrices have identical dimensions
433 fout << std::endl;
434 fout << " double fDecTF_"<<trCounter<<"["<<numC<<"]["<<mat->GetNrows()<<"]["<<mat->GetNcols()<<"];" << std::endl;
435 }
436
437 if (part==2) {
438 fout << std::endl;
439 fout << "//_______________________________________________________________________" << std::endl;
440 fout << "inline void " << fcncName << "::InitTransform_"<<trCounter<<"()" << std::endl;
441 fout << "{" << std::endl;
442 fout << " // Decorrelation transformation, initialisation" << std::endl;
443 for (UInt_t icls = 0; icls < numC; icls++){
444 TMatrixD* matx = fDecorrMatrices.at(icls);
445 for (int i=0; i<matx->GetNrows(); i++) {
446 for (int j=0; j<matx->GetNcols(); j++) {
447 fout << " fDecTF_"<<trCounter<<"["<<icls<<"]["<<i<<"]["<<j<<"] = " << std::setprecision(12) << (*matx)[i][j] << ";" << std::endl;
448 }
449 }
450 }
451 fout << "}" << std::endl;
452 fout << std::endl;
453 TMatrixD* matx = fDecorrMatrices.at(0); // ToDo check if all Decorr matrices have identical dimensions
454 fout << "//_______________________________________________________________________" << std::endl;
455 fout << "inline void " << fcncName << "::Transform_"<<trCounter<<"( std::vector<double>& iv, int cls) const" << std::endl;
456 fout << "{" << std::endl;
457 fout << " // Decorrelation transformation" << std::endl;
458 fout << " if (cls < 0 || cls > "<<GetNClasses()<<") {"<< std::endl;
459 fout << " if ("<<GetNClasses()<<" > 1 ) cls = "<<GetNClasses()<<";"<< std::endl;
460 fout << " else cls = "<<(fDecorrMatrices.size()==1?0:2)<<";"<< std::endl;
461 fout << " }"<< std::endl;
462
463 VariableTransformBase::MakeFunction(fout, fcncName, 0, trCounter, 0 );
464
465 fout << " std::vector<double> tv;" << std::endl;
466 fout << " for (int i=0; i<"<<matx->GetNrows()<<";i++) {" << std::endl;
467 fout << " double v = 0;" << std::endl;
468 fout << " for (int j=0; j<"<<matx->GetNcols()<<"; j++)" << std::endl;
469 fout << " v += iv[indicesGet.at(j)] * fDecTF_"<<trCounter<<"[cls][i][j];" << std::endl;
470 fout << " tv.push_back(v);" << std::endl;
471 fout << " }" << std::endl;
472 fout << " for (int i=0; i<"<<matx->GetNrows()<<";i++) iv[indicesPut.at(i)] = tv[i];" << std::endl;
473 fout << "}" << std::endl;
474 }
475
476 fout << std::setprecision(dp);
477}
char Char_t
Definition RtypesCore.h:37
const Bool_t kFALSE
Definition RtypesCore.h:92
const Bool_t kTRUE
Definition RtypesCore.h:91
#define ClassImp(name)
Definition Rtypes.h:364
int type
Definition TGX11.cxx:121
TMatrixT< Double_t > TMatrixD
Definition TMatrixDfwd.h:23
char * Form(const char *fmt,...)
Class that contains all the data information.
Definition DataSetInfo.h:62
UInt_t GetNVariables() const
accessor to the number of variables
Definition Event.cxx:308
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition Tools.cxx:1174
TMatrixD * GetSQRootMatrix(TMatrixDSym *symMat)
square-root of symmetric matrix of course the resulting sqrtMat is also symmetric,...
Definition Tools.cxx:283
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition Tools.cxx:1136
const char * GetContent(void *node)
XML helpers.
Definition Tools.cxx:1186
void * GetChild(void *parent, const char *childname=0)
get child node
Definition Tools.cxx:1162
void WriteTMatrixDToXML(void *node, const char *name, TMatrixD *mat)
XML helpers.
Definition Tools.cxx:1255
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:335
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition Tools.h:353
std::vector< TMatrixDSym * > * CalcCovarianceMatrices(const std::vector< Event * > &events, Int_t maxCls, VariableTransformBase *transformBase=0)
compute covariance matrices
Definition Tools.cxx:1526
Singleton class for Global types used by TMVA.
Definition Types.h:73
Linear decorrelation transformation of the input variables.
virtual void PrintTransformation(std::ostream &o)
prints the transformation matrix
virtual void MakeFunction(std::ostream &fout, const TString &fncName, Int_t part, UInt_t trCounter, Int_t cls)
creates C++ code fragment of the decorrelation transform for inclusion in standalone C++ class
virtual ~VariableDecorrTransform(void)
destructor
Bool_t PrepareTransformation(const std::vector< Event * > &)
calculate the decorrelation matrix and the normalization
VariableDecorrTransform(DataSetInfo &dsi)
constructor
void CalcSQRMats(const std::vector< Event * > &, Int_t maxCls)
Decorrelation matrix [class0/class1/.../all classes].
virtual void ReadFromXML(void *trfnode)
Read the transformation matrices from the xml node.
std::vector< TString > * GetTransformationStrings(Int_t cls) const
creates string with variable transformations applied
void WriteTransformationToStream(std::ostream &) const
write the decorrelation matrix to the stream
virtual void AttachXMLTo(void *parent)
node attachment to parent
virtual const Event * Transform(const Event *const, Int_t cls) const
apply the decorrelation transformation
void ReadTransformationFromStream(std::istream &, const TString &)
Read the decorrelation matrices from an input stream.
virtual const Event * InverseTransform(const Event *const, Int_t cls) const
apply the inverse decorrelation transformation ... TODO : ... build the inverse transformation
Linear interpolation class.
virtual void MakeFunction(std::ostream &fout, const TString &fncName, Int_t part, UInt_t trCounter, Int_t cls)=0
getinput and setoutput equivalent
virtual void ReadFromXML(void *trfnode)=0
Read the input variables from the XML node.
virtual void AttachXMLTo(void *parent)=0
create XML description the transformation (write out info of selected variables)
Int_t GetNrows() const
Int_t GetNcols() const
virtual void Print(Option_t *option="") const
This method must be overridden when a class wants to print itself.
Definition TObject.cxx:552
Basic string class.
Definition TString.h:136
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:158
Short_t Abs(Short_t d)
Definition TMathBase.h:120
auto * m
Definition textangle.C:8