49#define MinNoTrainingEvents 10
80 auto inte =
roc->GetROCIntegral();
107 fMethod =
cr.fMethod;
108 fDataLoaderName =
cr.fDataLoaderName;
109 fMvaTrain =
cr.fMvaTrain;
110 fMvaTest =
cr.fMvaTest;
111 fIsCuts =
cr.fIsCuts;
112 fROCIntegral =
cr.fROCIntegral;
126 TString hLine =
"--------------------------------------------------- :";
129 fLogger << kINFO <<
"DataSet MVA :" <<
Endl;
130 fLogger << kINFO <<
"Name: Method/Title: ROC-integ :" <<
Endl;
132 fLogger << kINFO <<
TString::Format(
"%-20s %-15s %#1.3f :", fDataLoaderName.Data(),
134 fMethod.GetValue<
TString>(
"MethodTitle").
Data()).Data(),
152 roc->SetName(
TString::Format(
"%s/%s", GetMethodName().Data(), GetMethodTitle().Data()).Data());
153 roc->SetTitle(
TString::Format(
"%s/%s", GetMethodName().Data(), GetMethodTitle().Data()).Data());
154 roc->GetXaxis()->SetTitle(
" Signal Efficiency ");
155 roc->GetYaxis()->SetTitle(
" Background Rejection ");
221 for (
auto m : fIMethods) {
236 for (
auto &
meth : fMethods) {
259 for (
auto &
meth : fMethods) {
263 fWorkers.SetNWorkers(fJobs);
273 if (!IsSilentFile()) {
276 f->mkdir(fDataLoader->GetName());
282 if (!IsSilentFile()) {
289 fResults = fWorkers.Map(executor,
ROOT::TSeqI(fMethods.size()));
298 TString hLine =
"--------------------------------------------------- :";
300 Log() << kINFO <<
"DataSet MVA :" <<
Endl;
301 Log() << kINFO <<
"Name: Method/Title: ROC-integ :" <<
Endl;
303 for (
auto &
r : fResults) {
305 Log() << kINFO <<
TString::Format(
"%-20s %-15s %#1.3f :",
r.GetDataLoaderName().Data(),
306 TString::Format(
"%s/%s",
r.GetMethodName().Data(),
r.GetMethodTitle().Data()).Data(),
312 Log() << kINFO <<
"-----------------------------------------------------" <<
Endl;
313 Log() << kHEADER <<
"Evaluation done." <<
Endl <<
Endl;
314 Log() << kINFO <<
TString::Format(
"Jobs = %d Real Time = %lf ", fJobs, fTimer.RealTime()) <<
Endl;
315 Log() << kINFO <<
"-----------------------------------------------------" <<
Endl;
316 Log() << kINFO <<
"Evaluation done." <<
Endl;
326 for (
auto &
meth : fMethods) {
350 method->DataInfo().GetNClasses() < 2)
351 Log() << kFATAL <<
"You want to do classification training, but specified less than two classes." <<
Endl;
357 Log() << kWARNING <<
"Method " <<
method->GetMethodName() <<
" not trained (training tree has less entries ["
362 Log() << kHEADER <<
"Train method: " <<
method->GetMethodName() <<
" for Classification" <<
Endl <<
Endl;
364 Log() << kHEADER <<
"Training finished" <<
Endl <<
Endl;
391 Log() << kERROR <<
"Trying to get method not booked." <<
Endl;
399 if (GetDataLoaderDataInput().GetEntries() <=
401 Log() << kFATAL <<
"No input data for the training provided!" <<
Endl;
411 conf->DeclareOptionRef(
boostNum = 0,
"Boost_num",
"Number of times the classifier will be boosted");
412 conf->ParseOptions();
416 if (fModelPersistence) {
417 fFileDir = fDataLoader->GetName();
426 GetDataLoaderDataSetInfo(),
moptions);
429 Log() << kDEBUG <<
"Boost Number is " <<
boostNum <<
" > 0: train boosted classifier" <<
Endl;
434 Log() << kFATAL <<
"Method with type kBoost cannot be casted to MethodCategory. /Classification" <<
Endl;
436 if (fModelPersistence)
438 methBoost->SetModelPersistence(fModelPersistence);
440 methBoost->fDataSetManager = GetDataLoaderDataSetManager();
442 methBoost->SetSilentFile(IsSilentFile());
453 Log() << kFATAL <<
"Method with type kCategory cannot be casted to MethodCategory. /Classification" <<
Endl;
455 if (fModelPersistence)
456 methCat->SetWeightFileDir(fFileDir);
457 methCat->SetModelPersistence(fModelPersistence);
458 methCat->fDataSetManager = GetDataLoaderDataSetManager();
460 methCat->SetSilentFile(IsSilentFile());
463 if (!
method->HasAnalysisType(fAnalysisType, GetDataLoaderDataSetInfo().GetNClasses(),
464 GetDataLoaderDataSetInfo().GetNTargets())) {
465 Log() << kWARNING <<
"Method " <<
method->GetMethodTypeName() <<
" is not capable of handling ";
466 Log() <<
"classification with " << GetDataLoaderDataSetInfo().GetNClasses() <<
" classes." <<
Endl;
470 if (fModelPersistence)
471 method->SetWeightFileDir(fFileDir);
472 method->SetModelPersistence(fModelPersistence);
473 method->SetAnalysisType(fAnalysisType);
477 method->SetFile(fFile.get());
478 method->SetSilentFile(IsSilentFile());
482 fIMethods.push_back(
method);
496 if (fIMethods.empty())
498 for (
UInt_t i = 0; i < fIMethods.size(); i++) {
515 for (
auto &
meth : fMethods) {
539 Log() << kHEADER <<
"Test method: " <<
method->GetMethodName() <<
" for Classification"
553 std::vector<std::vector<TString>>
mname(2);
554 std::vector<std::vector<Double_t>> sig(2), sep(2),
roc(2);
559 method->SetFile(fFile.get());
560 method->SetSilentFile(IsSilentFile());
563 if (!IsCutsMethod(
method))
566 Log() << kHEADER <<
"Evaluate classifier: " <<
method->GetMethodName() <<
Endl <<
Endl;
567 isel = (
method->GetMethodTypeName().Contains(
"Variable")) ? 1 : 0;
570 method->TestClassification();
574 sig[
isel].push_back(
method->GetSignificance());
593 if (!IsSilentFile()) {
594 Log() << kDEBUG <<
"\tWrite evaluation histograms to file" <<
Endl;
600 for (
Int_t k = 0; k < 2; k++) {
601 std::vector<std::vector<Double_t>>
vtemp;
612 vtemp.push_back(sig[k]);
613 vtemp.push_back(sep[k]);
641 const Int_t nvar =
method->fDataSetInfo.GetNVariables();
646 std::vector<Double_t>
rvec;
653 std::vector<TString> *
theVars =
new std::vector<TString>;
654 std::vector<ResultsClassification *>
mvaRes;
657 theVars->back().ReplaceAll(
"MVA_",
"");
679 Log() << kWARNING <<
"Found NaN return value in event: " <<
ievt <<
" for method \""
687 if (
method->fDataSetInfo.IsSignal(
ev)) {
708 (*overlapS) *= (1.0 /
defDs->GetNEvtSigTest());
709 (*overlapB) *= (1.0 /
defDs->GetNEvtBkgdTest());
711 tpSig->MakePrincipals();
712 tpBkg->MakePrincipals();
746 Log() << kINFO <<
Endl;
748 <<
"Inter-MVA correlation matrix (signal):" <<
Endl;
750 Log() << kINFO <<
Endl;
753 <<
"Inter-MVA correlation matrix (background):" <<
Endl;
755 Log() << kINFO <<
Endl;
759 <<
"Correlations between input variables and MVA response (signal):" <<
Endl;
761 Log() << kINFO <<
Endl;
764 <<
"Correlations between input variables and MVA response (background):" <<
Endl;
766 Log() << kINFO <<
Endl;
769 <<
"<TestAllMethods> cannot compute correlation matrices" <<
Endl;
773 <<
"The following \"overlap\" matrices contain the fraction of events for which " <<
Endl;
775 <<
"the MVAs 'i' and 'j' have returned conform answers about \"signal-likeness\"" <<
Endl;
777 <<
"An event is signal-like, if its MVA output exceeds the following value:" <<
Endl;
780 <<
"which correspond to the working point: eff(signal) = 1 - eff(background)" <<
Endl;
785 <<
"Note: no correlations and overlap with cut method are provided at present" <<
Endl;
788 Log() << kINFO <<
Endl;
790 <<
"Inter-MVA overlap matrix (signal):" <<
Endl;
792 Log() << kINFO <<
Endl;
795 <<
"Inter-MVA overlap matrix (background):" <<
Endl;
820 Log().EnableOutput();
823 TString hLine =
"------------------------------------------------------------------------------------------"
824 "-------------------------";
825 Log() << kINFO <<
"Evaluation results ranked by best signal efficiency and purity (area)" <<
Endl;
827 Log() << kINFO <<
"DataSet MVA " <<
Endl;
828 Log() << kINFO <<
"Name: Method: ROC-integ" <<
Endl;
831 for (
Int_t k = 0; k < 2; k++) {
834 Log() << kINFO <<
"Input Variables: " <<
Endl <<
hLine <<
Endl;
853 if (sep[k][i] < 0 || sig[k][i] < 0) {
855 fResult.fROCIntegral =
effArea[k][i];
857 <<
TString::Format(
"%-13s %-15s: %#1.3f", fDataLoader->GetName(), methodName.
Data(), fResult.fROCIntegral)
867 Log() << kINFO <<
Endl;
868 Log() << kINFO <<
"Testing efficiency compared to training efficiency (overtraining check)" <<
Endl;
871 <<
"DataSet MVA Signal efficiency: from test sample (from training sample) "
873 Log() << kINFO <<
"Name: Method: @B=0.01 @B=0.10 @B=0.30 "
876 for (
Int_t k = 0; k < 2; k++) {
879 Log() << kINFO <<
"Input Variables: " <<
Endl <<
hLine <<
Endl;
883 mname[k][i].ReplaceAll(
"Variable_",
"");
885 Log() << kINFO <<
TString::Format(
"%-20s %-15s: %#1.3f (%#1.3f) %#1.3f (%#1.3f) %#1.3f (%#1.3f)",
892 Log() << kINFO <<
Endl;
894 if (
gTools().CheckForSilentOption(GetOptions()))
895 Log().InhibitOutput();
896 }
else if (IsCutsMethod(
method)) {
897 for (
Int_t k = 0; k < 2; k++) {
900 if (sep[k][i] < 0 || sig[k][i] < 0) {
902 fResult.fROCIntegral =
effArea[k][i];
911 if (IsCutsMethod(
method)) {
912 fResult.fIsCuts =
kTRUE;
918 TString className =
method->DataInfo().GetClassInfo(0)->GetName();
919 fResult.fClassNames.push_back(className);
921 if (!IsSilentFile()) {
923 RootBaseDir()->cd(
method->fDataSetInfo.GetName());
947 if (fResults.size() == 0)
948 Log() << kFATAL <<
"No Classification results available" <<
Endl;
972 for (
auto &
result : fResults) {
979 result.fDataLoaderName = fDataLoader->GetName();
980 fResults.push_back(
result);
981 return fResults.back();
996 dataset->SetCurrentType(
type);
1001 Log() << kERROR <<
TString::Format(
"Given class number (iClass = %i) does not exist. There are %i classes in dataset.",
1023 std::vector<Float_t>
mvaRes;
1079 <<
TString::Format(
"ROCCurve object was not created in MethodName = %s MethodTitle = %s not found with Dataset = %s ",
1133 auto dsdir = fFile->mkdir(fDataLoader->GetName());
1138 for (
UInt_t i = 0; i < fMethods.size(); i++) {
1153 fFile->cd(fDataLoader->GetName());
1170 auto entries =
tmptrain->GetEntries();
1177 entries =
tmptest->GetEntries();
1189 for (
UInt_t i = 0; i < fMethods.size(); i++) {
#define MinNoTrainingEvents
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t index
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t src
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
TMatrixT< Double_t > TMatrixD
R__EXTERN TSystem * gSystem
TClass instances represent classes, structs and namespaces in the ROOT type system.
Bool_t InheritsFrom(const char *cl) const override
Return kTRUE if this class inherits from a class with name "classname".
static TClass * GetClass(const char *name, Bool_t load=kTRUE, Bool_t silent=kFALSE)
Static method returning pointer to TClass of the specified class name.
A ROOT file is structured in Directories (like a file system).
Describe directory structure in memory.
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
A TGraph is an object made of two arrays X and Y with npoints each.
Book space in a file, create I/O buffers, to fill them, (un)compress them.
virtual const char * GetClassName() const
virtual TObject * ReadObj()
To read a TObject* from the file.
static ClassifierFactory & Instance()
Access to the ClassifierFactory singleton; creates the instance if needed.
void SetDrawProgressBar(Bool_t d)
void SetUseColor(Bool_t uc)
class TMVA::Config::VariablePlotting fVariablePlotting
void SetConfigDescription(const char *d)
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
void SetConfigName(const char *n)
void CheckForUnusedOptions() const
checks for unused options in option string
Class that contains all the data information.
Abstract base class for all high level ml algorithms, you can book ml methods like BDT,...
Bool_t fModelPersistence
! flag to save the trained model
std::shared_ptr< DataLoader > fDataLoader
! data
virtual void ParseOptions()
Method to parse the internal option string.
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
Double_t GetROCIntegral(UInt_t iClass=0, TMVA::Types::ETreeType type=TMVA::Types::kTesting)
Method to get ROC-Integral value from mvas.
TGraph * GetROCGraph(UInt_t iClass=0, TMVA::Types::ETreeType type=TMVA::Types::kTesting)
Method to get TGraph object with the ROC curve.
void Show()
Method to print the results in stdout.
Bool_t IsMethod(TString methodname, TString methodtitle)
Method to check if method was booked.
std::map< UInt_t, std::vector< std::tuple< Float_t, Float_t, Bool_t > > > fMvaTest
Mvas for two-class and multiclass classification.
ROCCurve * GetROC(UInt_t iClass=0, TMVA::Types::ETreeType type=TMVA::Types::kTesting)
Method to get TMVA::ROCCurve Object.
Bool_t fIsCuts
If the method is a Cuts method, it needs special output.
ClassificationResult & operator=(const ClassificationResult &r)
std::map< UInt_t, std::vector< std::tuple< Float_t, Float_t, Bool_t > > > fMvaTrain
Mvas for two-class classification.
Classification(DataLoader *loader, TFile *file, TString options)
Constructor to create a two-class classifier.
Double_t GetROCIntegral(TString methodname, TString methodtitle, UInt_t iClass=0)
Method to get ROC-Integral value from mvas.
virtual void Test()
Perform test evaluation in all booked methods.
TString GetMethodOptions(TString methodname, TString methodtitle)
Returns the options for the booked method.
MethodBase * GetMethod(TString methodname, TString methodtitle)
Return a TMVA::MethodBase object.
virtual void TrainMethod(TString methodname, TString methodtitle)
Trains a specific ML method.
Bool_t HasMethodObject(TString methodname, TString methodtitle, Int_t &index)
Checks whether the TMVA::MethodBase was created and returns its index in the vector.
std::vector< ClassificationResult > & GetResults()
Return the vector of TMVA::Experimental::ClassificationResult objects.
virtual void Train()
Method to train all booked ml methods.
virtual void Evaluate()
Method to perform Train/Test over all booked ML methods.
Types::EAnalysisType fAnalysisType
!
TMVA::ROCCurve * GetROC(TMVA::MethodBase *method, UInt_t iClass=0, TMVA::Types::ETreeType type=TMVA::Types::kTesting)
Method to get TMVA::ROCCurve Object.
Bool_t IsCutsMethod(TMVA::MethodBase *method)
Checks whether the ML method is a Cuts method.
void CopyFrom(TDirectory *src, TFile *file)
virtual void TestMethod(TString methodname, TString methodtitle)
Tests a specific ML method.
Interface for all concrete MVA method implementations.
Virtual base Class for all MVA method.
Class for boosting a TMVA method.
Class for categorizing the phase space.
ostringstream derivative to redirect and format output
static void InhibitOutput()
static void EnableOutput()
Base class for a vector of results.
Class which takes the results of a multiclass classification.
Base class for a vector of results.
Singleton class for Global types used by TMVA.
static Types & Instance()
Returns the single instance of "Types", creating it if it does not exist yet (Singleton).
const char * GetName() const override
Returns name of object.
Mother of all ROOT objects.
@ kOverwrite
overwrite existing object with same name
virtual Int_t Write(const char *name=nullptr, Int_t option=0, Int_t bufsize=0)
Write this object to the current directory.
Principal Components Analysis (PCA)
const char * Data() const
TString & ReplaceAll(const TString &s1, const TString &s2)
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
virtual int MakeDirectory(const char *name)
Make a directory.
virtual int Unlink(const char *name)
Unlink, i.e. remove, a file.
A TTree represents a columnar dataset.
create variable transformations
MsgLogger & Endl(MsgLogger &ml)