61 fLogger( new
MsgLogger(
Form(
"ResultsMultiClass%s",resultsName.Data()) , kINFO) ),
63 fAchievableEff(dsi->GetNClasses()),
64 fAchievablePur(dsi->GetNClasses()),
65 fBestCuts(dsi->GetNClasses(),
std::vector<
Double_t>(dsi->GetNClasses()))
81 if (ievt >= (
Int_t)fMultiClassValues.size()) fMultiClassValues.resize( ievt+1 );
82 fMultiClassValues[ievt] = value;
91 const DataSet *ds = GetDataSet();
96 TMatrixD mat(numClasses, numClasses);
99 for (
UInt_t iRow = 0; iRow < numClasses; ++iRow) {
100 for (
UInt_t iCol = 0; iCol < numClasses; ++iCol) {
104 mat(iRow, iCol) = std::numeric_limits<double>::quiet_NaN();
107 std::vector<Float_t> valueVector;
108 std::vector<Bool_t> classVector;
109 std::vector<Float_t> weightVector;
115 const Float_t mvaValue = fMultiClassValues[iEvt][iRow];
117 if (cls != iRow and cls != iCol) {
121 classVector.push_back(cls == iRow);
122 weightVector.push_back(weight);
123 valueVector.push_back(mvaValue);
126 ROCCurve roc(valueVector, classVector, weightVector);
146 UInt_t evClass = fEventClasses[ievt];
147 Float_t w = fEventWeights[ievt];
149 Bool_t break_outer_loop =
false;
150 for (
UInt_t icls = 0; icls < cutvalues.size(); ++icls) {
151 auto value = fMultiClassValues[ievt][icls];
152 auto cutvalue = cutvalues.at(icls);
153 if (cutvalue < 0. ? (-value < cutvalue) : (+value <= cutvalue)) {
154 break_outer_loop =
true;
159 if (break_outer_loop) {
163 Bool_t isEvCurrClass = (evClass == fClassToOptimize);
164 positives[isEvCurrClass] += w;
167 const Float_t truePositive = positives[1];
168 const Float_t falsePositive = positives[0];
170 Float_t eff = truePositive / fClassSumWeights[fClassToOptimize];
171 Float_t pur = truePositive / (truePositive + falsePositive);
174 Float_t toMinimize = std::numeric_limits<float>::max();
175 if (effTimesPur > std::numeric_limits<float>::min())
176 toMinimize = 1./(effTimesPur);
178 fAchievableEff.at(fClassToOptimize) = eff;
179 fAchievablePur.at(fClassToOptimize) = pur;
191 Log() << kINFO <<
"Calculating best set of cuts for class "
194 fClassToOptimize = targetClass;
197 fClassSumWeights.clear();
198 fEventWeights.clear();
199 fEventClasses.clear();
202 fClassSumWeights.push_back(0);
209 fEventWeights.push_back(ev->
GetWeight());
210 fEventClasses.push_back(ev->
GetClass());
214 const TString opts(
"PopSize=100:Steps=30" );
217 std::vector<Double_t> result;
220 fBestCuts.at(targetClass) = result;
223 for( std::vector<Double_t>::iterator it = result.begin(); it<result.end(); ++it ){
243 Log() << kINFO <<
"Creating multiclass performance histograms..." <<
Endl;
251 std::vector<std::vector<Float_t>> *rawMvaRes = GetValueVector();
256 for (
size_t iClass = 0; iClass < numClasses; ++iClass) {
263 if ( DoesExist(
name) ) {
268 std::vector<Float_t> mvaRes;
269 std::vector<Bool_t> mvaResTypes;
270 std::vector<Float_t> mvaResWeights;
275 mvaRes.reserve(rawMvaRes->size());
276 for (
auto item : *rawMvaRes) {
277 mvaRes.push_back(item[iClass]);
281 mvaResTypes.reserve(eventCollection.size());
282 mvaResWeights.reserve(eventCollection.size());
283 for (
auto ev : eventCollection) {
284 mvaResTypes.push_back(ev->GetClass() == iClass);
285 mvaResWeights.push_back(ev->GetWeight());
304 for (
size_t iClass = 0; iClass < numClasses; ++iClass) {
305 for (
size_t jClass = 0; jClass < numClasses; ++jClass) {
306 if (iClass == jClass) {
313 std::vector<Float_t> mvaRes;
314 std::vector<Bool_t> mvaResTypes;
315 std::vector<Float_t> mvaResWeights;
317 mvaRes.reserve(rawMvaRes->size());
318 mvaResTypes.reserve(eventCollection.size());
319 mvaResWeights.reserve(eventCollection.size());
321 for (
size_t iEvent = 0; iEvent < eventCollection.size(); ++iEvent) {
322 Event *ev = eventCollection[iEvent];
325 Float_t output_value = (*rawMvaRes)[iEvent][iClass];
326 mvaRes.push_back(output_value);
327 mvaResTypes.push_back(ev->
GetClass() == iClass);
328 mvaResWeights.push_back(ev->
GetWeight());
356 Log() << kINFO <<
"Creating multiclass response histograms..." <<
Endl;
362 std::vector<std::vector<TH1F*> > histos;
366 histos.push_back(std::vector<TH1F*>(0));
373 if ( DoesExist(
name) ) {
386 histos.at(cls).at(jCls)->Fill(fMultiClassValues[ievt][jCls],w);
392 Store(histos.at(iCls).at(jCls));
char * Form(const char *fmt,...)
A Graph is a graphics object made of two arrays X and Y with npoints each.
virtual void SetName(const char *name="")
Set graph name.
virtual void SetTitle(const char *title="")
Set graph title.
1-D histogram with a float per channel (see TH1 documentation)
Class that contains all the data information.
UInt_t GetNClasses() const
ClassInfo * GetClassInfo(Int_t clNum) const
Class that contains all the data information.
const Event * GetEvent() const
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
void SetCurrentType(Types::ETreeType type) const
const std::vector< Event * > & GetEventCollection(Types::ETreeType type=Types::kMaxTreeType) const
Double_t GetWeight() const
Return the event weight, depending on whether the flag IgnoreNegWeightsInTraining is set or not.
Fitter using a Genetic Algorithm.
Interface for a fitter 'target'.
The TMVA::Interval Class.
ostringstream derivative to redirect and format output
Double_t GetEffSForEffB(Double_t effB, const UInt_t num_points=41)
Calculate the signal efficiency (sensitivity) for a given background efficiency (sensitivity).
TGraph * GetROCCurve(const UInt_t points=100)
Returns a new TGraph containing the ROC curve.
TMatrixD GetConfusionMatrix(Double_t effB)
Returns a confusion matrix where each class is pitted against each other.
std::vector< Double_t > GetBestMultiClassCuts(UInt_t targetClass)
calculate the best working point (optimal cut values) for the multiclass classifier
Double_t EstimatorFunction(std::vector< Double_t > &)
ResultsMulticlass(const DataSetInfo *dsi, TString resultsName)
constructor
void CreateMulticlassHistos(TString prefix, Int_t nbins, Int_t nbins_high)
this function fills the mva response histos for multiclass classification
~ResultsMulticlass()
destructor
void CreateMulticlassPerformanceHistos(TString prefix)
Create performance graphs for this classifier in a multiclass setting.
void SetValue(std::vector< Float_t > &value, Int_t ievt)
Class that is the base class for a vector of results.
virtual const char * GetName() const
Returns name of object.
const char * Data() const
static constexpr double mg
MsgLogger & Endl(MsgLogger &ml)