57 TMVA::VariableTransformBase::VariableTransformBase( DataSetInfo& dsi,
58 Types::EVariableTransform tf,
64 fBackTransformedEvent(0),
65 fVariableTransform(tf),
69 fTransformName(trfName),
70 fVariableTypesAreCounted(
false),
78 fLogger =
new MsgLogger(
this,
kINFO);
79 for (
UInt_t ivar = 0; ivar < fDsi.GetNVariables(); ivar++) {
80 fVariables.push_back( VariableInfo( fDsi.GetVariableInfo(ivar) ) );
82 for (
UInt_t itgt = 0; itgt < fDsi.GetNTargets(); itgt++) {
83 fTargets.push_back( VariableInfo( fDsi.GetTargetInfo(itgt) ) );
85 for (
UInt_t ispct = 0; ispct < fDsi.GetNSpectators(); ispct++) {
86 fTargets.push_back( VariableInfo( fDsi.GetSpectatorInfo(ispct) ) );
105 TString inputVariables = _inputVariables;
110 UInt_t nvars = GetNVariables();
111 UInt_t ntgts = GetNTargets();
112 UInt_t nspcts = GetNSpectators();
114 typedef std::set<Int_t> SelectedIndices;
116 SelectedIndices varIndices;
117 SelectedIndices tgtIndices;
118 SelectedIndices spctIndices;
120 if (inputVariables ==
"")
122 inputVariables =
"_V_,_T_";
137 if( variables.
Length() == 0 ){
138 for(
UInt_t ivar = 0; ivar < nvars; ++ivar ) {
139 fGet.push_back( std::pair<Char_t,UInt_t>(
'v',ivar) );
140 varIndices.insert( ivar );
145 Log() <<
kFATAL <<
"You selected variable with index : " << idx <<
" of only " << nvars <<
" variables." <<
Endl;
146 fGet.push_back( std::pair<Char_t,UInt_t>(
'v',idx) );
147 varIndices.insert( idx );
151 if( variables.
Length() == 0 ){
152 for(
UInt_t itgt = 0; itgt < ntgts; ++itgt ) {
153 fGet.push_back( std::pair<Char_t,UInt_t>(
't',itgt) );
154 tgtIndices.insert( itgt );
159 Log() <<
kFATAL <<
"You selected target with index : " << idx <<
" of only " << ntgts <<
" targets." <<
Endl;
160 fGet.push_back( std::pair<Char_t,UInt_t>(
't',idx) );
161 tgtIndices.insert( idx );
165 if( variables.
Length() == 0 ){
166 for(
UInt_t ispct = 0; ispct < nspcts; ++ispct ) {
167 fGet.push_back( std::pair<Char_t,UInt_t>(
's',ispct) );
168 spctIndices.insert( ispct );
173 Log() <<
kFATAL <<
"You selected spectator with index : " << idx <<
" of only " << nspcts <<
" spectators." <<
Endl;
174 fGet.push_back( std::pair<Char_t,UInt_t>(
's',idx) );
175 spctIndices.insert( idx );
178 ToggleInputSortOrder(
kFALSE );
180 Log() <<
kINFO <<
"Variable rearrangement set true: Variable order given in transformation option is used for input to transformation!" <<
Endl;
184 Int_t numIndices = varIndices.size()+tgtIndices.size()+spctIndices.size();
185 for(
UInt_t ivar = 0; ivar < nvars; ++ivar ) {
186 if( fDsi.GetVariableInfo( ivar ).GetLabel() ==
variables ) {
187 fGet.push_back( std::pair<Char_t,UInt_t>(
'v',ivar) );
188 varIndices.insert( ivar );
192 for(
UInt_t itgt = 0; itgt < ntgts; ++itgt ) {
193 if( fDsi.GetTargetInfo( itgt ).GetLabel() ==
variables ) {
194 fGet.push_back( std::pair<Char_t,UInt_t>(
't',itgt) );
195 tgtIndices.insert( itgt );
199 for(
UInt_t ispct = 0; ispct < nspcts; ++ispct ) {
200 if( fDsi.GetSpectatorInfo( ispct ).GetLabel() ==
variables ) {
201 fGet.push_back( std::pair<Char_t,UInt_t>(
's',ispct) );
202 spctIndices.insert( ispct );
206 Int_t numIndicesEndOfLoop = varIndices.size()+tgtIndices.size()+spctIndices.size();
207 if( numIndicesEndOfLoop == numIndices )
208 Log() <<
kWARNING <<
"Error at parsing the options for the variable transformations: Variable/Target/Spectator '" << variables.
Data() <<
"' not found." <<
Endl;
209 numIndices = numIndicesEndOfLoop;
214 if( putIntoVariables ) {
216 for( SelectedIndices::iterator it = varIndices.begin(), itEnd = varIndices.end(); it != itEnd; ++it ) {
217 fPut.push_back( std::pair<Char_t,UInt_t>(
'v',idx) );
220 for( SelectedIndices::iterator it = tgtIndices.begin(), itEnd = tgtIndices.end(); it != itEnd; ++it ) {
221 fPut.push_back( std::pair<Char_t,UInt_t>(
't',idx) );
224 for( SelectedIndices::iterator it = spctIndices.begin(), itEnd = spctIndices.end(); it != itEnd; ++it ) {
225 fPut.push_back( std::pair<Char_t,UInt_t>(
's',idx) );
229 for( SelectedIndices::iterator it = varIndices.begin(), itEnd = varIndices.end(); it != itEnd; ++it ) {
231 fPut.push_back( std::pair<Char_t,UInt_t>(
'v',idx) );
233 for( SelectedIndices::iterator it = tgtIndices.begin(), itEnd = tgtIndices.end(); it != itEnd; ++it ) {
235 fPut.push_back( std::pair<Char_t,UInt_t>(
't',idx) );
237 for( SelectedIndices::iterator it = spctIndices.begin(), itEnd = spctIndices.end(); it != itEnd; ++it ) {
239 fPut.push_back( std::pair<Char_t,UInt_t>(
's',idx) );
245 fGet.assign( fPut.begin(), fPut.end() );
250 Log() <<
kINFO <<
"Transformation, Variable selection : " <<
Endl;
253 const DataSetInfo* outputDsiPtr = (fDsiOutput? &(*fDsiOutput) : &fDsi );
257 ItVarTypeIdx itGet = fGet.begin(), itGetEnd = fGet.end();
259 for( ; itGet != itGetEnd; ++itGet ) {
262 Char_t inputType = (*itGet).first;
263 Int_t inputIdx = (*itGet).second;
265 TString inputLabel =
"NOT FOND";
266 if( inputType ==
'v' ) {
267 inputLabel = fDsi.GetVariableInfo( inputIdx ).GetLabel();
268 inputTypeString =
"variable";
270 else if( inputType ==
't' ){
271 inputLabel = fDsi.GetTargetInfo( inputIdx ).GetLabel();
272 inputTypeString =
"target";
274 else if( inputType ==
's' ){
275 inputLabel = fDsi.GetSpectatorInfo( inputIdx ).GetLabel();
276 inputTypeString =
"spectator";
279 TString outputTypeString =
"?";
281 Char_t outputType = (*itPut).first;
282 Int_t outputIdx = (*itPut).second;
284 TString outputLabel =
"NOT FOUND";
285 if( outputType ==
'v' ) {
287 outputTypeString =
"variable";
289 else if( outputType ==
't' ){
291 outputTypeString =
"target";
293 else if( outputType ==
's' ){
295 outputTypeString =
"spectator";
299 Log() <<
kINFO <<
"Input : " << inputTypeString.
Data() <<
" '" << inputLabel.
Data() <<
"' (index=" << inputIdx <<
"). <---> "
300 <<
"Output : " << outputTypeString.
Data() <<
" '" << outputLabel.
Data() <<
"' (index=" << outputIdx <<
")." <<
Endl;
318 if( backTransformation && !fPut.empty() ){
319 itEntry = fPut.begin();
320 itEntryEnd = fPut.end();
321 input.reserve(fPut.size());
324 itEntry = fGet.begin();
325 itEntryEnd = fGet.end();
326 input.reserve(fGet.size() );
331 for( ; itEntry != itEntryEnd; ++itEntry ) {
333 Int_t idx = (*itEntry).second;
338 input.push_back( event->
GetValue(idx) );
341 input.push_back( event->
GetTarget(idx) );
347 Log() <<
kFATAL <<
"VariableTransformBase/GetInput : unknown type '" << type <<
"'." <<
Endl;
351 catch(std::out_of_range& ){
352 input.push_back(0.
f);
353 mask.push_back(
kTRUE);
354 hasMaskedEntries =
kTRUE;
357 return hasMaskedEntries;
365 std::vector<Float_t>::iterator itOutput = output.begin();
366 std::vector<Char_t>::iterator itMask = mask.begin();
369 event->CopyVarValues( *oldEvent );
376 if( backTransformation || fPut.empty() ){
377 itEntry = fGet.begin();
378 itEntryEnd = fGet.end();
381 itEntry = fPut.begin();
382 itEntryEnd = fPut.end();
386 for( ; itEntry != itEntryEnd; ++itEntry ) {
393 Int_t idx = (*itEntry).second;
394 if (itOutput == output.end())
Log() <<
kFATAL <<
"Read beyond array boundaries in VariableTransformBase::SetOutput"<<
Endl;
399 event->SetVal( idx, value );
402 event->SetTarget( idx, value );
405 event->SetSpectator( idx, value );
408 Log() <<
kFATAL <<
"VariableTransformBase/GetInput : unknown type '" << type <<
"'." <<
Endl;
410 if( !(*itMask) ) ++itOutput;
414 }
catch( std::exception&
except ){
415 Log() <<
kFATAL <<
"VariableTransformBase/SetOutput : exception/" << except.what() <<
Endl;
426 if( fVariableTypesAreCounted ){
429 nspcts = fNSpectators;
433 nvars = ntgts = nspcts = 0;
435 for(
ItVarTypeIdxConst itEntry = fGet.begin(), itEntryEnd = fGet.end(); itEntry != itEntryEnd; ++itEntry ) {
449 Log() <<
kFATAL <<
"VariableTransformBase/GetVariableTypeNumbers : unknown type '" << type <<
"'." <<
Endl;
455 fNSpectators = nspcts;
457 fVariableTypesAreCounted =
true;
468 if (!IsCreated())
return;
470 const UInt_t nvars = GetNVariables();
471 const UInt_t ntgts = GetNTargets();
473 UInt_t nevts = events.size();
476 TVectorD x0( nvars+ntgts ); x0 *= 0;
479 for (
UInt_t ievt=0; ievt<nevts; ievt++) {
480 const Event* ev = events[ievt];
483 sumOfWeights += weight;
484 for (
UInt_t ivar=0; ivar<nvars; ivar++) {
487 Variables().at(ivar).SetMin(x);
488 Variables().at(ivar).SetMax(x);
491 UpdateNorm( ivar, x );
493 x0(ivar) += x*weight;
494 x2(ivar) += x*x*weight;
496 for (
UInt_t itgt=0; itgt<ntgts; itgt++) {
499 Targets().at(itgt).SetMin(x);
500 Targets().at(itgt).SetMax(x);
503 UpdateNorm( nvars+itgt, x );
505 x0(nvars+itgt) += x*weight;
506 x2(nvars+itgt) += x*x*weight;
510 if (sumOfWeights <= 0) {
511 Log() <<
kFATAL <<
" the sum of event weights calcualted for your input is == 0"
512 <<
" or exactly: " << sumOfWeights <<
" there is obviously some problem..."<<
Endl;
516 for (
UInt_t ivar=0; ivar<nvars; ivar++) {
517 Double_t mean = x0(ivar)/sumOfWeights;
519 Variables().at(ivar).SetMean( mean );
520 if (
x2(ivar)/sumOfWeights - mean*mean < 0) {
521 Log() <<
kFATAL <<
" the RMS of your input variable " << ivar
522 <<
" evaluates to an imaginary number: sqrt("<<
x2(ivar)/sumOfWeights - mean*mean
523 <<
") .. sometimes related to a problem with outliers and negative event weights"
526 Variables().at(ivar).SetRMS(
TMath::Sqrt(
x2(ivar)/sumOfWeights - mean*mean) );
528 for (
UInt_t itgt=0; itgt<ntgts; itgt++) {
529 Double_t mean = x0(nvars+itgt)/sumOfWeights;
530 Targets().at(itgt).SetMean( mean );
531 if (
x2(nvars+itgt)/sumOfWeights - mean*mean < 0) {
532 Log() <<
kFATAL <<
" the RMS of your target variable " << itgt
533 <<
" evaluates to an imaginary number: sqrt(" <<
x2(nvars+itgt)/sumOfWeights - mean*mean
534 <<
") .. sometimes related to a problem with outliers and negative event weights"
537 Targets().at(itgt).SetRMS(
TMath::Sqrt(
x2(nvars+itgt)/sumOfWeights - mean*mean) );
541 Log() << std::setprecision(3);
542 for (
UInt_t ivar=0; ivar<GetNVariables(); ivar++)
543 Log() <<
" " << Variables().at(ivar).GetInternalName()
544 <<
"\t: [" << Variables().at(ivar).GetMin() <<
"\t, " << Variables().at(ivar).GetMax() <<
"\t] " <<
Endl;
546 Log() << std::setprecision(3);
547 for (
UInt_t itgt=0; itgt<GetNTargets(); itgt++)
548 Log() <<
" " << Targets().at(itgt).GetInternalName()
549 <<
"\t: [" << Targets().at(itgt).GetMin() <<
"\t, " << Targets().at(itgt).GetMax() <<
"\t] " <<
Endl;
550 Log() << std::setprecision(5);
560 std::vector<TString>* strVec =
new std::vector<TString>;
561 for (
UInt_t ivar=0; ivar<GetNVariables(); ivar++) {
562 strVec->push_back( Variables()[ivar].GetLabel() +
"_[transformed]");
574 Int_t nvars = fDsi.GetNVariables();
576 if (x < Variables().
at(ivar).GetMin()) Variables().at(ivar).SetMin(x);
577 if (x > Variables().
at(ivar).GetMax()) Variables().at(ivar).SetMax(x);
579 if (x < Targets().
at(ivar-nvars).GetMin()) Targets().at(ivar-nvars).SetMin(x);
580 if (x > Targets().
at(ivar-nvars).GetMax()) Targets().at(ivar-nvars).SetMax(x);
595 const DataSetInfo* outputDsiPtr = (fDsiOutput? fDsiOutput : &fDsi );
597 for(
ItVarTypeIdx itGet = fGet.begin(), itGetEnd = fGet.end(); itGet != itGetEnd; ++itGet ) {
598 UInt_t idx = (*itGet).second;
606 typeString =
"Variable";
607 label = fDsi.GetVariableInfo( idx ).GetLabel();
608 expression = fDsi.GetVariableInfo( idx ).GetExpression();
611 typeString =
"Target";
612 label = fDsi.GetTargetInfo( idx ).GetLabel();
613 expression = fDsi.GetTargetInfo( idx ).GetExpression();
616 typeString =
"Spectator";
617 label = fDsi.GetSpectatorInfo( idx ).GetLabel();
618 expression = fDsi.GetSpectatorInfo( idx ).GetExpression();
621 Log() <<
kFATAL <<
"VariableTransformBase/AttachXMLTo unknown variable type '" << type <<
"'." <<
Endl;
635 for(
ItVarTypeIdx itPut = fPut.begin(), itPutEnd = fPut.end(); itPut != itPutEnd; ++itPut ) {
636 UInt_t idx = (*itPut).second;
644 typeString =
"Variable";
649 typeString =
"Target";
654 typeString =
"Spectator";
659 Log() <<
kFATAL <<
"VariableTransformBase/AttachXMLTo unknown variable type '" << type <<
"'." <<
Endl;
680 UInt_t nvars = GetNVariables();
681 UInt_t ntgts = GetNTargets();
682 UInt_t nspcts = GetNSpectators();
702 if( typeString ==
"Variable" ){
703 for(
UInt_t ivar = 0; ivar < nvars; ++ivar ) {
704 if( fDsi.GetVariableInfo( ivar ).GetLabel() == label ||
705 fDsi.GetVariableInfo( ivar ).GetExpression() == expression) {
706 fGet.push_back( std::pair<Char_t,UInt_t>(
'v',ivar) );
710 }
else if( typeString ==
"Target" ){
711 for(
UInt_t itgt = 0; itgt < ntgts; ++itgt ) {
712 if( fDsi.GetTargetInfo( itgt ).GetLabel() == label ||
713 fDsi.GetTargetInfo( itgt ).GetExpression() == expression ) {
714 fGet.push_back( std::pair<Char_t,UInt_t>(
't',itgt) );
718 }
else if( typeString ==
"Spectator" ){
719 for(
UInt_t ispct = 0; ispct < nspcts; ++ispct ) {
720 if( fDsi.GetSpectatorInfo( ispct ).GetLabel() == label ||
721 fDsi.GetSpectatorInfo( ispct ).GetExpression() == expression ) {
722 fGet.push_back( std::pair<Char_t,UInt_t>(
's',ispct) );
727 Log() <<
kFATAL <<
"VariableTransformationBase/ReadFromXML : unknown type '" << typeString <<
"'." <<
Endl;
732 assert( nInputs == fGet.size() );
750 if( typeString ==
"Variable" ){
751 for(
UInt_t ivar = 0; ivar < nvars; ++ivar ) {
752 if( fDsi.GetVariableInfo( ivar ).GetLabel() == label ||
753 fDsi.GetVariableInfo( ivar ).GetExpression() == expression ) {
754 fPut.push_back( std::pair<Char_t,UInt_t>(
'v',ivar) );
758 }
else if( typeString ==
"Target" ){
759 for(
UInt_t itgt = 0; itgt < ntgts; ++itgt ) {
760 if( fDsi.GetTargetInfo( itgt ).GetLabel() == label ||
761 fDsi.GetTargetInfo( itgt ).GetExpression() == expression ) {
762 fPut.push_back( std::pair<Char_t,UInt_t>(
't',itgt) );
766 }
else if( typeString ==
"Spectator" ){
767 for(
UInt_t ispct = 0; ispct < nspcts; ++ispct ) {
768 if( fDsi.GetSpectatorInfo( ispct ).GetLabel() == label ||
769 fDsi.GetSpectatorInfo( ispct ).GetExpression() == expression ) {
770 fPut.push_back( std::pair<Char_t,UInt_t>(
's',ispct) );
775 Log() <<
kFATAL <<
"VariableTransformationBase/ReadFromXML : unknown type '" << typeString <<
"'." <<
Endl;
780 assert( nOutputs == fPut.size() );
794 fout <<
" // define the indices of the variables which are transformed by this transformation" << std::endl;
795 fout <<
" static std::vector<int> indicesGet;" << std::endl;
796 fout <<
" static std::vector<int> indicesPut;" << std::endl << std::endl;
797 fout <<
" if ( indicesGet.empty() ) { " << std::endl;
798 fout <<
" indicesGet.reserve(fNvars);" << std::endl;
800 for(
ItVarTypeIdxConst itEntry = fGet.begin(), itEntryEnd = fGet.end(); itEntry != itEntryEnd; ++itEntry ) {
802 Int_t idx = (*itEntry).second;
806 fout <<
" indicesGet.push_back( " << idx <<
");" << std::endl;
809 Log() <<
kWARNING <<
"MakeClass doesn't work with transformation of targets. The results will be wrong!" <<
Endl;
812 Log() <<
kWARNING <<
"MakeClass doesn't work with transformation of spectators. The results will be wrong!" <<
Endl;
815 Log() <<
kFATAL <<
"VariableTransformBase/GetInput : unknown type '" << type <<
"'." <<
Endl;
818 fout <<
" } " << std::endl;
819 fout <<
" if ( indicesPut.empty() ) { " << std::endl;
820 fout <<
" indicesPut.reserve(fNvars);" << std::endl;
822 for(
ItVarTypeIdxConst itEntry = fPut.begin(), itEntryEnd = fPut.end(); itEntry != itEntryEnd; ++itEntry ) {
824 Int_t idx = (*itEntry).second;
828 fout <<
" indicesPut.push_back( " << idx <<
");" << std::endl;
831 Log() <<
kWARNING <<
"MakeClass doesn't work with transformation of targets. The results will be wrong!" <<
Endl;
834 Log() <<
kWARNING <<
"MakeClass doesn't work with transformation of spectators. The results will be wrong!" <<
Endl;
837 Log() <<
kFATAL <<
"VariableTransformBase/PutInput : unknown type '" << type <<
"'." <<
Endl;
841 fout <<
" } " << std::endl;
844 }
else if( part == 1){
#define TMVA_VERSION_CODE
MsgLogger & Endl(MsgLogger &ml)
void variables(TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
Collectable string class.
const TString & GetExpression() const
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not...
Float_t GetSpectator(UInt_t ivar) const
return spectator content
ClassImp(TIterator) Bool_t TIterator return false
Compare two iterator objects.
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
const char * Data() const
static const double x2[5]
Int_t Atoi() const
Return integer value of string.
Bool_t EndsWith(const char *pat, ECaseCompare cmp=kExact) const
Return true if string ends with the specified string.
VariableInfo & GetTargetInfo(Int_t i)
TString & Remove(Ssiz_t pos)
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
VariableInfo & GetSpectatorInfo(Int_t i)
VariableInfo & GetVariableInfo(Int_t i)
Mother of all ROOT objects.
Float_t GetTarget(UInt_t itgt) const
Double_t Sqrt(Double_t x)
const TString & GetLabel() const
static void output(int code)