46 for(
auto &roc:
fROCs) avg+=roc.second;
47 return avg/fROCs.size();
67 fLogger << kHEADER <<
" ==== Results ====" <<
Endl;
69 fLogger << kINFO <<
Form(
"Fold %i ROC-Int : %.4f",item.first,item.second) << std::endl;
71 fLogger << kINFO <<
"------------------------" <<
Endl;
83 fROCCurves->GetXaxis()->SetTitle(
" Signal Efficiency ");
84 fROCCurves->GetYaxis()->SetTitle(
" Background Rejection ");
87 c->
SetTitle(
"Cross Validation ROC Curves");
94 fNumFolds(5),fClassifier(new
TMVA::
Factory(
"CrossValidation",
"!V:!ROC:Silent:!ModelPersistence:!Color:!DrawProgressBar:AnalysisType=Classification"))
123 if (methodName ==
"")
124 Log() << kFATAL <<
"No method booked for cross-validation" <<
Endl;
128 Log() << kINFO <<
"Evaluate method: " << methodTitle <<
Endl;
139 Log() << kDEBUG <<
"Fold (" << methodTitle <<
"): " << i <<
Endl;
141 TString foldTitle = methodTitle;
142 foldTitle +=
"_fold";
186 Log() << kINFO <<
"Evaluation done." <<
Endl;
194 Log() << kFATAL <<
"No cross-validation results available" <<
Endl;
Float_t GetROCAverage() const
virtual void SetLineWidth(Width_t lwidth)
Set the line width.
MsgLogger & Endl(MsgLogger &ml)
void AddOutput(Types::ETreeType type, Types::EAnalysisType analysisType)
void SetTitle(const char *title="")
Set canvas title.
A TMultiGraph is a collection of TGraph (or derived) objects.
Virtual base class for all MVA methods.
Class to save the results of cross validation, the metric for the classification is ROC and you can ...
std::unique_ptr< Factory > fClassifier
virtual void SetTitle(const char *title="")
Set graph title.
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
static void SetIsTraining(Bool_t)
when this static function is called, it sets the flag whether events with negative event weight shoul...
std::vector< CrossValidationResult > fResults
void SetNumFolds(UInt_t i)
Abstract base class for all high level ml algorithms, you can book ml methods like BDT...
virtual Double_t GetEfficiency(const TString &, Types::ETreeType, Double_t &err)
fill background efficiency (resp.
virtual void SetLineColor(Color_t lcolor)
Set the line color.
virtual void ParseOptions()
Method to parse the internal option string.
void DeleteResults(const TString &, Types::ETreeType type, Types::EAnalysisType analysistype)
delete the results stored for this particular Method instance.
char * Form(const char *fmt,...)
Float_t GetROCStandardDeviation() const
const TString & GetMethodName() const
const std::vector< CrossValidationResult > & GetResults() const
This is the main MVA steering class.
virtual Double_t GetSignificance() const
compute significance of mean difference
CrossValidation(DataLoader *loader)
virtual void Evaluate()
Virtual method to be implemented with your algorithm.
virtual Double_t GetSeparation(TH1 *, TH1 *) const
compute "separation" defined as
std::shared_ptr< DataLoader > fDataLoader
Booked method information.
TMultiGraph * GetROCCurves(Bool_t fLegend=kTRUE)
ostringstream derivative to redirect and format output
virtual void Draw(Option_t *option="")
Draw a canvas.
Abstract ClassifierFactory template that handles arbitrary types.
virtual TLegend * BuildLegend(Double_t x1=0.3, Double_t y1=0.21, Double_t x2=0.3, Double_t y2=0.21, const char *title="", Option_t *option="")
Build a legend from the graphical objects in the pad.
std::map< UInt_t, Float_t > fROCs
A Graph is a graphics object made of two arrays X and Y with npoints each.
TCanvas * Draw(const TString name="CrossValidation") const
virtual Double_t GetTrainingEfficiency(const TString &)
Types::EAnalysisType GetAnalysisType() const
Double_t Sqrt(Double_t x)
static void EnableOutput()
virtual void TestClassification()
initialization
std::shared_ptr< TMultiGraph > fROCCurves
std::vector< OptionMap > fMethods
const char * Data() const