Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
SVWorkingSet.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andrzej Zemla
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : SVWorkingSet *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation *
12 * *
13 * Authors (alphabetical): *
14 * Marcin Wolter <Marcin.Wolter@cern.ch> - IFJ PAN, Krakow, Poland *
15 * Andrzej Zemla <azemla@cern.ch> - IFJ PAN, Krakow, Poland *
16 * (IFJ PAN: Henryk Niewodniczanski Inst. Nucl. Physics, Krakow, Poland) *
17 * *
18 * Copyright (c) 2005: *
19 * CERN, Switzerland *
20 * MPI-K Heidelberg, Germany *
21 * PAN, Krakow, Poland *
22 * *
23 * Redistribution and use in source and binary forms, with or without *
24 * modification, are permitted according to the terms listed in LICENSE *
25 * (http://tmva.sourceforge.net/LICENSE) *
26 **********************************************************************************/
27
28/*! \class TMVA::SVWorkingSet
29\ingroup TMVA
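30Working set for Support Vector Machine training: holds the training events and their kernel matrix and optimises the Lagrange multipliers pairwise (SMO-style ExamineExample / TakeStep).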
31*/
32
33#include "TMVA/SVWorkingSet.h"
34
35#include "TMVA/MsgLogger.h"
36#include "TMVA/SVEvent.h"
38#include "TMVA/SVKernelMatrix.h"
39#include "TMVA/Types.h"
40
41
42#include "TMath.h"
43#include "TRandom3.h"
44
45#include <vector>
46
47////////////////////////////////////////////////////////////////////////////////
48/// constructor
49
51 : fdoRegression(kFALSE),
52 fInputData(0),
53 fSupVec(0),
54 fKFunction(0),
55 fKMatrix(0),
56 fTEventUp(0),
57 fTEventLow(0),
58 fB_low(1.),
59 fB_up(-1.),
60 fTolerance(0.01),
61 fLogger( new MsgLogger( "SVWorkingSet", kINFO ) )
62{
63}
64
65////////////////////////////////////////////////////////////////////////////////
66/// constructor
/// Builds the kernel matrix of inputVectors with kernelFunction. Each event
/// receives its row of that matrix (SetLine) and its position in inputVectors
/// (SetNs). For regression, each event's error cache is seeded with its target.
/// For classification, fTEventLow / fTEventUp are initialised with randomly
/// drawn events whose type flag is -1 / +1 respectively.
///
/// \param inputVectors   training events; stored as fInputData and not deleted
///                       by the destructor
/// \param kernelFunction kernel used to build fKMatrix (stored as fKFunction)
/// \param tol            tolerance used in the optimality checks (fTolerance)
/// \param doreg          kTRUE for regression, kFALSE for classification
///
/// NOTE(review): fIPyCurrentIter and fExitFromTraining, which Train() reads,
/// are not set here; presumably they have in-class initialisers in the
/// header -- confirm.
67
68TMVA::SVWorkingSet::SVWorkingSet(std::vector<TMVA::SVEvent*>*inputVectors, SVKernelFunction* kernelFunction,
69 Float_t tol, Bool_t doreg)
70 : fdoRegression(doreg),
71 fInputData(inputVectors),
72 fSupVec(0),
73 fKFunction(kernelFunction),
74 fTEventUp(0),
75 fTEventLow(0),
76 fB_low(1.),
77 fB_up(-1.),
78 fTolerance(tol),
79 fLogger( new MsgLogger( "SVWorkingSet", kINFO ) )
80{
// the kernel matrix is owned by this object (deleted in the destructor)
81 fKMatrix = new TMVA::SVKernelMatrix(inputVectors, kernelFunction);
82 Float_t *pt;
// give every event its kernel row and its index (used later as GetNs())
83 for( UInt_t i = 0; i < fInputData->size(); i++){
84 pt = fKMatrix->GetLine(i);
85 fInputData->at(i)->SetLine(pt);
86 fInputData->at(i)->SetNs(i);
87 if(fdoRegression) fInputData->at(i)->SetErrorCache(fInputData->at(i)->GetTarget());
88 }
// NOTE(review): default-constructed generator, so presumably the same initial
// events are drawn on every run (fixed default seed) -- confirm.
89 TRandom3 rand;
90 UInt_t kk = rand.Integer(fInputData->size());
91 if(fdoRegression) {
95 }
96 else{
// NOTE(review): this loop (and the next one) never terminates if no event has
// type flag -1 (resp. +1), e.g. for single-class input; an empty fInputData is
// not guarded against before the at(kk) calls either.
97 while(1){
98 if(fInputData->at(kk)->GetTypeFlag()==-1){
99 fTEventLow = fInputData->at(kk);
100 break;
101 }
102 kk = rand.Integer(fInputData->size());
103 }
104
// starts from the kk of the -1 event found above, so at least one more draw is made
105 while (1){
106 if (fInputData->at(kk)->GetTypeFlag()==1) {
107 fTEventUp = fInputData->at(kk);
108 break;
109 }
110 kk = rand.Integer(fInputData->size());
111 }
112 }
115}
116
117////////////////////////////////////////////////////////////////////////////////
118/// destructor
119
121{
122 if (fKMatrix != 0) {delete fKMatrix; fKMatrix = 0;}
123 delete fLogger;
124}
125
126////////////////////////////////////////////////////////////////////////////////
127
129{
130 SVEvent* ievt=0;
131 Float_t fErrorC_J = 0.;
132 if( jevt->GetIdx()==0) fErrorC_J = jevt->GetErrorCache();
133 else{
134 Float_t *fKVals = jevt->GetLine();
135 fErrorC_J = 0.;
136 std::vector<TMVA::SVEvent*>::iterator idIter;
137
138 UInt_t k=0;
139 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
140 if((*idIter)->GetAlpha()>0)
141 fErrorC_J += (*idIter)->GetAlpha()*(*idIter)->GetTypeFlag()*fKVals[k];
142 k++;
143 }
144
145
146 fErrorC_J -= jevt->GetTypeFlag();
147 jevt->SetErrorCache(fErrorC_J);
148
149 if((jevt->GetIdx() == 1) && (fErrorC_J < fB_up )){
150 fB_up = fErrorC_J;
151 fTEventUp = jevt;
152 }
153 else if ((jevt->GetIdx() == -1)&&(fErrorC_J > fB_low)) {
154 fB_low = fErrorC_J;
155 fTEventLow = jevt;
156 }
157 }
158 Bool_t converged = kTRUE;
159
160 if((jevt->GetIdx()>=0) && (fB_low - fErrorC_J > 2*fTolerance)) {
161 converged = kFALSE;
162 ievt = fTEventLow;
163 }
164
165 if((jevt->GetIdx()<=0) && (fErrorC_J - fB_up > 2*fTolerance)) {
166 converged = kFALSE;
167 ievt = fTEventUp;
168 }
169
170 if (converged) return kFALSE;
171
172 if(jevt->GetIdx()==0){
173 if(fB_low - fErrorC_J > fErrorC_J - fB_up) ievt = fTEventLow;
174 else ievt = fTEventUp;
175 }
176
177 if (TakeStep(ievt, jevt)) return kTRUE;
178 else return kFALSE;
179}
180
181
182////////////////////////////////////////////////////////////////////////////////
183
185{
186 if (ievt == jevt) return kFALSE;
187 std::vector<TMVA::SVEvent*>::iterator idIter;
188 const Float_t epsilon = 1e-8; //make it 1-e6 or 1-e5 to make it faster
189
190 Float_t type_I, type_J;
191 Float_t errorC_I, errorC_J;
192 Float_t alpha_I, alpha_J;
193
194 Float_t newAlpha_I, newAlpha_J;
195 Int_t s;
196
197 Float_t l, h, lobj = 0, hobj = 0;
198 Float_t eta;
199
200 type_I = ievt->GetTypeFlag();
201 alpha_I = ievt->GetAlpha();
202 errorC_I = ievt->GetErrorCache();
203
204 type_J = jevt->GetTypeFlag();
205 alpha_J = jevt->GetAlpha();
206 errorC_J = jevt->GetErrorCache();
207
208 s = Int_t( type_I * type_J );
209
210 Float_t c_i = ievt->GetCweight();
211
212 Float_t c_j = jevt->GetCweight();
213
214 // compute l, h objective function
215
216 if (type_I == type_J) {
217 Float_t gamma = alpha_I + alpha_J;
218
219 if ( c_i > c_j ) {
220 if ( gamma < c_j ) {
221 l = 0;
222 h = gamma;
223 }
224 else{
225 h = c_j;
226 if ( gamma < c_i )
227 l = 0;
228 else
229 l = gamma - c_i;
230 }
231 }
232 else {
233 if ( gamma < c_i ){
234 l = 0;
235 h = gamma;
236 }
237 else {
238 l = gamma - c_i;
239 if ( gamma < c_j )
240 h = gamma;
241 else
242 h = c_j;
243 }
244 }
245 }
246 else {
247 Float_t gamma = alpha_I - alpha_J;
248 if (gamma > 0) {
249 l = 0;
250 if ( gamma >= (c_i - c_j) )
251 h = c_i - gamma;
252 else
253 h = c_j;
254 }
255 else {
256 l = -gamma;
257 if ( (c_i - c_j) >= gamma)
258 h = c_j;
259 else
260 h = c_i - gamma;
261 }
262 }
263
264 if (l == h) return kFALSE;
265 Float_t kernel_II, kernel_IJ, kernel_JJ;
266
267 kernel_II = fKMatrix->GetElement(ievt->GetNs(),ievt->GetNs());
268 kernel_IJ = fKMatrix->GetElement(ievt->GetNs(), jevt->GetNs());
269 kernel_JJ = fKMatrix->GetElement(jevt->GetNs(),jevt->GetNs());
270
271 eta = 2*kernel_IJ - kernel_II - kernel_JJ;
272 if (eta < 0) {
273 newAlpha_J = alpha_J + (type_J*( errorC_J - errorC_I ))/eta;
274 if (newAlpha_J < l) newAlpha_J = l;
275 else if (newAlpha_J > h) newAlpha_J = h;
276
277 }
278
279 else {
280
281 Float_t c_I = eta/2;
282 Float_t c_J = type_J*( errorC_I - errorC_J ) - eta * alpha_J;
283 lobj = c_I * l * l + c_J * l;
284 hobj = c_I * h * h + c_J * h;
285
286 if (lobj > hobj + epsilon) newAlpha_J = l;
287 else if (lobj < hobj - epsilon) newAlpha_J = h;
288 else newAlpha_J = alpha_J;
289 }
290
291 if (TMath::Abs( newAlpha_J - alpha_J ) < ( epsilon * ( newAlpha_J + alpha_J+ epsilon ))){
292 return kFALSE;
293 //it spends here to much time... it is stupido
294 }
295 newAlpha_I = alpha_I - s*( newAlpha_J - alpha_J );
296
297 if (newAlpha_I < 0) {
298 newAlpha_J += s* newAlpha_I;
299 newAlpha_I = 0;
300 }
301 else if (newAlpha_I > c_i) {
302 Float_t temp = newAlpha_I - c_i;
303 newAlpha_J += s * temp;
304 newAlpha_I = c_i;
305 }
306
307 Float_t dL_I = type_I * ( newAlpha_I - alpha_I );
308 Float_t dL_J = type_J * ( newAlpha_J - alpha_J );
309
310 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
311 if((*idIter)->GetIdx()==0){
312 Float_t ii = fKMatrix->GetElement(ievt->GetNs(), (*idIter)->GetNs());
313 Float_t jj = fKMatrix->GetElement(jevt->GetNs(), (*idIter)->GetNs());
314
315 (*idIter)->UpdateErrorCache(dL_I * ii + dL_J * jj);
316 }
317 }
318 ievt->SetAlpha(newAlpha_I);
319 jevt->SetAlpha(newAlpha_J);
320 // set new indexes
321 SetIndex(ievt);
322 SetIndex(jevt);
323
324 // update error cache
325 ievt->SetErrorCache(errorC_I + dL_I*kernel_II + dL_J*kernel_IJ);
326 jevt->SetErrorCache(errorC_J + dL_I*kernel_IJ + dL_J*kernel_JJ);
327
328 // compute fI_low, fB_low
329
330 fB_low = -1*1e30;
331 fB_up = 1e30;
332
333 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
334 if((*idIter)->GetIdx()==0){
335 if((*idIter)->GetErrorCache()> fB_low){
336 fB_low = (*idIter)->GetErrorCache();
337 fTEventLow = (*idIter);
338 }
339 if( (*idIter)->GetErrorCache()< fB_up){
340 fB_up =(*idIter)->GetErrorCache();
341 fTEventUp = (*idIter);
342 }
343 }
344 }
345
346 // for optimized alfa's
347 if (fB_low < TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
348 if (ievt->GetErrorCache() > fB_low) {
349 fB_low = ievt->GetErrorCache();
350 fTEventLow = ievt;
351 }
352 else {
353 fB_low = jevt->GetErrorCache();
354 fTEventLow = jevt;
355 }
356 }
357
358 if (fB_up > TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
359 if (ievt->GetErrorCache()< fB_low) {
360 fB_up =ievt->GetErrorCache();
361 fTEventUp = ievt;
362 }
363 else {
364 fB_up =jevt->GetErrorCache() ;
365 fTEventUp = jevt;
366 }
367 }
368 return kTRUE;
369}
370
371////////////////////////////////////////////////////////////////////////////////
372
374{
375 if((fB_up > fB_low - 2*fTolerance)) return kTRUE;
376 return kFALSE;
377}
378
379////////////////////////////////////////////////////////////////////////////////
380/// train the SVM
381
383{
384
385 Int_t numChanged = 0;
386 Int_t examineAll = 1;
387
388 Float_t numChangedOld = 0;
389 Int_t deltaChanges = 0;
390 UInt_t numit = 0;
391
392 std::vector<TMVA::SVEvent*>::iterator idIter;
393
394 while ((numChanged > 0) || (examineAll > 0)) {
395 if (fIPyCurrentIter) *fIPyCurrentIter = numit;
396 if (fExitFromTraining && *fExitFromTraining) break;
397 numChanged = 0;
398 if (examineAll) {
399 for (idIter = fInputData->begin(); idIter!=fInputData->end(); ++idIter){
400 if(!fdoRegression) numChanged += (UInt_t)ExamineExample(*idIter);
401 else numChanged += (UInt_t)ExamineExampleReg(*idIter);
402 }
403 }
404 else {
405 for (idIter = fInputData->begin(); idIter!=fInputData->end(); ++idIter) {
406 if ((*idIter)->IsInI0()) {
407 if(!fdoRegression) numChanged += (UInt_t)ExamineExample(*idIter);
408 else numChanged += (UInt_t)ExamineExampleReg(*idIter);
409 if (Terminated()) {
410 numChanged = 0;
411 break;
412 }
413 }
414 }
415 }
416
417 if (examineAll == 1) examineAll = 0;
418 else if (numChanged == 0 || numChanged < 10 || deltaChanges > 3 ) examineAll = 1;
419
420 if (numChanged == numChangedOld) deltaChanges++;
421 else deltaChanges = 0;
422 numChangedOld = numChanged;
423 ++numit;
424
425 if (numit >= nMaxIter) {
426 *fLogger << kWARNING
427 << "Max number of iterations exceeded. "
428 << "Training may not be completed. Try use less Cost parameter" << Endl;
429 break;
430 }
431 }
432}
433
434////////////////////////////////////////////////////////////////////////////////
435
437{
438 if( (0< event->GetAlpha()) && (event->GetAlpha()< event->GetCweight()))
439 event->SetIdx(0);
440
441 if( event->GetTypeFlag() == 1){
442 if( event->GetAlpha() == 0)
443 event->SetIdx(1);
444 else if( event->GetAlpha() == event->GetCweight() )
445 event->SetIdx(-1);
446 }
447 if( event->GetTypeFlag() == -1){
448 if( event->GetAlpha() == 0)
449 event->SetIdx(-1);
450 else if( event->GetAlpha() == event->GetCweight() )
451 event->SetIdx(1);
452 }
453}
454
455////////////////////////////////////////////////////////////////////////////////
456
457std::vector<TMVA::SVEvent*>* TMVA::SVWorkingSet::GetSupportVectors()
458{
459 std::vector<TMVA::SVEvent*>::iterator idIter;
460 if( fSupVec != 0) {delete fSupVec; fSupVec = 0; }
461 fSupVec = new std::vector<TMVA::SVEvent*>(0);
462
463 for( idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
464 if((*idIter)->GetDeltaAlpha() !=0){
465 fSupVec->push_back((*idIter));
466 }
467 }
468 return fSupVec;
469}
470
471//for regression
472
/// Regression counterpart of TakeStep(): jointly optimises the multipliers
/// (alpha, alpha_p) of the pair (ievt, jevt).
///
/// Up to four sub-problems are tried in turn, each at most once:
/// - A: (alpha_i, alpha_j)
/// - B: (alpha_i, alpha_j_p)
/// - C: (alpha_i_p, alpha_j)
/// - D: (alpha_i_p, alpha_j_p)
///
/// Each case clips the new value to [low, high]. An empty segment, or no
/// applicable case, ends the loop.
///
/// \return kTRUE if any multiplier changed significantly; the error caches,
///         alphas and fB_low / fB_up are then updated. kFALSE otherwise.
474{
475 if (ievt == jevt) return kFALSE;
476 std::vector<TMVA::SVEvent*>::iterator idIter;
477 const Float_t epsilon = 0.001*fTolerance;//TODO
478
479 const Float_t kernel_II = fKMatrix->GetElement(ievt->GetNs(),ievt->GetNs());
480 const Float_t kernel_IJ = fKMatrix->GetElement(ievt->GetNs(),jevt->GetNs());
481 const Float_t kernel_JJ = fKMatrix->GetElement(jevt->GetNs(),jevt->GetNs());
482
483 //compute eta & gamma
// NOTE(review): eta is a divisor in every case below; eta == 0 is not guarded
// and would produce inf/NaN multipliers.
484 const Float_t eta = -2*kernel_IJ + kernel_II + kernel_JJ;
485 const Float_t gamma = ievt->GetDeltaAlpha() + jevt->GetDeltaAlpha();
486
487 //TODO CHECK WHAT IF ETA <0
488 //w.r.t Mercer's conditions it should never happen, but what if?
489
490 Bool_t caseA, caseB, caseC, caseD, terminated;
491 caseA = caseB = caseC = caseD = terminated = kFALSE;
492 Float_t b_alpha_i, b_alpha_j, b_alpha_i_p, b_alpha_j_p; //temporary Lagrange multipliers
493 const Float_t b_cost_i = ievt->GetCweight();
494 const Float_t b_cost_j = jevt->GetCweight();
495
496 b_alpha_i = ievt->GetAlpha();
497 b_alpha_j = jevt->GetAlpha();
498 b_alpha_i_p = ievt->GetAlpha_p();
499 b_alpha_j_p = jevt->GetAlpha_p();
500
501 //calculate deltafi
502 Float_t deltafi = ievt->GetErrorCache()-jevt->GetErrorCache();
503
504 // main loop
// Each case flag makes its branch run at most once, so the loop ends after at
// most five passes.
// NOTE(review): deltafi is not updated inside this loop after a case changes
// the multipliers, so later cases still use the initial value -- verify.
505 while(!terminated) {
506 const Float_t null = 0.; //!!! dummy float null declaration because of problems with TMath::Max/Min(Float_t, Float_t) function
507 Float_t low, high;
508 Float_t tmp_alpha_i, tmp_alpha_j;
509 tmp_alpha_i = tmp_alpha_j = 0.;
510
511 //TODO check this conditions, are they proper
512 if((caseA == kFALSE) && (b_alpha_i > 0 || (b_alpha_i_p == 0 && deltafi > 0)) && (b_alpha_j > 0 || (b_alpha_j_p == 0 && deltafi < 0)))
513 {
514 //compute low, high w.r.t a_i, a_j
515 low = TMath::Max( null, gamma - b_cost_j );
516 high = TMath::Min( b_cost_i , gamma);
517
518 if(low<high){
519 tmp_alpha_j = b_alpha_j - (deltafi/eta);
520 tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
521 tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
522 tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j);
523
524 //update Li & Lj if change is significant (??)
525 if( IsDiffSignificant(b_alpha_j,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i,tmp_alpha_i, epsilon)){
526 b_alpha_j = tmp_alpha_j;
527 b_alpha_i = tmp_alpha_i;
528 }
529
530 }
531 else
532 terminated = kTRUE;
533
534 caseA = kTRUE;
535 }
536 else if((caseB==kFALSE) && (b_alpha_i>0 || (b_alpha_i_p==0 && deltafi >2*epsilon )) && (b_alpha_j_p>0 || (b_alpha_j==0 && deltafi>2*epsilon)))
537 {
538 //compute LH w.r.t. a_i, a_j*
539 low = TMath::Max( null, gamma ); //TODO
540 high = TMath::Min( b_cost_i , b_cost_j + gamma);
541
542
543 if(low<high){
544 tmp_alpha_j = b_alpha_j_p - ((deltafi-2*epsilon)/eta);
545 tmp_alpha_j = TMath::Min(tmp_alpha_j,high);
546 tmp_alpha_j = TMath::Max(low,tmp_alpha_j);
547 tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j_p);
548
549 //update alphai alphaj_p
550 if( IsDiffSignificant(b_alpha_j_p,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i,tmp_alpha_i, epsilon)){
551 b_alpha_j_p = tmp_alpha_j;
552 b_alpha_i = tmp_alpha_i;
553 }
554 }
555 else
556 terminated = kTRUE;
557
558 caseB = kTRUE;
559 }
560 else if((caseC==kFALSE) && (b_alpha_i_p>0 || (b_alpha_i==0 && deltafi < -2*epsilon )) && (b_alpha_j>0 || (b_alpha_j_p==0 && deltafi< -2*epsilon)))
561 {
562 //compute LH w.r.t. alphai_p alphaj
563 low = TMath::Max(null, -gamma );
564 high = TMath::Min(b_cost_i, -gamma+b_cost_j);
565
566 if(low<high){
567 tmp_alpha_j = b_alpha_j - ((deltafi+2*epsilon)/eta);
568 tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
569 tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
570 tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j);
571
572 //update alphai_p alphaj
573 if( IsDiffSignificant(b_alpha_j,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i_p,tmp_alpha_i, epsilon)){
574 b_alpha_j = tmp_alpha_j;
575 b_alpha_i_p = tmp_alpha_i;
576 }
577 }
578 else
579 terminated = kTRUE;
580
581 caseC = kTRUE;
582 }
583 else if((caseD == kFALSE) &&
584 (b_alpha_i_p>0 || (b_alpha_i==0 && deltafi <0 )) &&
585 (b_alpha_j_p>0 || (b_alpha_j==0 && deltafi >0 )))
586 {
587 //compute LH w.r.t. alphai_p alphaj_p
588 low = TMath::Max(null,-gamma - b_cost_j);
589 high = TMath::Min(b_cost_i, -gamma);
590
591 if(low<high){
592 tmp_alpha_j = b_alpha_j_p + (deltafi/eta);
593 tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
594 tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
595 tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j_p);
596
597 if( IsDiffSignificant(b_alpha_j_p,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i_p,tmp_alpha_i, epsilon)){
598 b_alpha_j_p = tmp_alpha_j;
599 b_alpha_i_p = tmp_alpha_i;
600 }
601 }
602 else
603 terminated = kTRUE;
604
605 caseD = kTRUE;
606 }
607 else
608 terminated = kTRUE;
609 }
610 // TODO ad commment how it was calculated
// NOTE(review): deltafi is not read after this statement, so this update has
// no effect at present.
611 deltafi += ievt->GetDeltaAlpha()*(kernel_II - kernel_IJ) + jevt->GetDeltaAlpha()*(kernel_IJ - kernel_JJ);
612
// commit only if at least one of the four multipliers moved significantly
613 if( IsDiffSignificant(b_alpha_i, ievt->GetAlpha(), epsilon) ||
614 IsDiffSignificant(b_alpha_j, jevt->GetAlpha(), epsilon) ||
615 IsDiffSignificant(b_alpha_i_p, ievt->GetAlpha_p(), epsilon) ||
616 IsDiffSignificant(b_alpha_j_p, jevt->GetAlpha_p(), epsilon) ){
617
618 //TODO check if these conditions might be easier
619 //TODO write documentation for this
// NOTE(review): if GetDeltaAlpha() is alpha - alpha_p, diff_alpha_i reduces to
// b_alpha_i_p - (old alpha_p), i.e. the change of alpha drops out -- verify
// against SVEvent.
620 const Float_t diff_alpha_i = ievt->GetDeltaAlpha()+b_alpha_i_p - ievt->GetAlpha();
621 const Float_t diff_alpha_j = jevt->GetDeltaAlpha()+b_alpha_j_p - jevt->GetAlpha();
622
623 //update error cache
624 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
625 //there will be some changes in Idx notation
626 if((*idIter)->GetIdx()==0){
627 Float_t k_ii = fKMatrix->GetElement(ievt->GetNs(), (*idIter)->GetNs());
628 Float_t k_jj = fKMatrix->GetElement(jevt->GetNs(), (*idIter)->GetNs());
629
630 (*idIter)->UpdateErrorCache(diff_alpha_i * k_ii + diff_alpha_j * k_jj);
631 }
632 }
633
634 //store new alphas in SVevents
635 ievt->SetAlpha(b_alpha_i);
636 jevt->SetAlpha(b_alpha_j);
637 ievt->SetAlpha_p(b_alpha_i_p);
638 jevt->SetAlpha_p(b_alpha_j_p);
639
// NOTE(review): Idx is not refreshed on this path (no SetIndex call), although
// the error-cache loop above selects events by GetIdx()==0.
640 //TODO update Idexes
641
642 // compute fI_low, fB_low
// Scans all events: I3 events are excluded from fB_low, I2 events from fB_up.
643
644 fB_low = -1*1e30;
645 fB_up =1e30;
646
647 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
648 if((!(*idIter)->IsInI3()) && ((*idIter)->GetErrorCache()> fB_low)){
649 fB_low = (*idIter)->GetErrorCache();
650 fTEventLow = (*idIter);
651
652 }
653 if((!(*idIter)->IsInI2()) && ((*idIter)->GetErrorCache()< fB_up)){
654 fB_up =(*idIter)->GetErrorCache();
655 fTEventUp = (*idIter);
656 }
657 }
658 return kTRUE;
659 } else return kFALSE;
660}
661
662
663////////////////////////////////////////////////////////////////////////////////
664
/// Regression counterpart of ExamineExample(). Checks whether jevt violates the
/// optimality conditions and, if so, calls TakeStepReg() with a partner taken
/// from fTEventLow / fTEventUp.
///
/// For events outside I0 the error
///   F_j = target_j - sum_k DeltaAlpha_k * K(k, j)
/// is recomputed from the kernel row and cached. It is then used, shifted by
/// +-feps, to tighten fB_up / fB_low according to the I1 / I2 / I3 membership
/// of jevt.
///
/// \return the result of TakeStepReg(), or kFALSE if jevt is within tolerance.
666{
// NOTE(review): feps is a hard-coded offset applied to the error in every
// comparison below.
667 Float_t feps = 1e-7;// TODO check which value is the best
668 SVEvent* ievt=0;
// NOTE(review): fErrorC_J / fKVals are locals despite the member-style "f" prefix.
669 Float_t fErrorC_J = 0.;
670 if( jevt->IsInI0()) {
671 fErrorC_J = jevt->GetErrorCache();
672 }
673 else{
674 Float_t *fKVals = jevt->GetLine();
675 fErrorC_J = 0.;
676 std::vector<TMVA::SVEvent*>::iterator idIter;
677
678 UInt_t k=0;
679 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
680 fErrorC_J -= (*idIter)->GetDeltaAlpha()*fKVals[k];
681 k++;
682 }
683
684 fErrorC_J += jevt->GetTarget();
685 jevt->SetErrorCache(fErrorC_J);
686
687 if(jevt->IsInI1()){
688 if(fErrorC_J + feps < fB_up ){
689 fB_up = fErrorC_J + feps;
690 fTEventUp = jevt;
691 }
// NOTE(review): because of this else-if, an I1 event never updates fB_low in
// the same call in which it updated fB_up; verify whether both bounds should
// be tested independently.
692 else if(fErrorC_J -feps > fB_low) {
693 fB_low = fErrorC_J - feps;
694 fTEventLow = jevt;
695 }
696 }else if((jevt->IsInI2()) && (fErrorC_J + feps > fB_low)){
697 fB_low = fErrorC_J + feps;
698 fTEventLow = jevt;
699 }else if((jevt->IsInI3()) && (fErrorC_J - feps < fB_up)){
700 fB_up = fErrorC_J - feps;
701 fTEventUp = jevt;
702 }
703 }
704
// Cases 1-5 are independent ifs. If jevt matches several IsInI* predicates, the
// last matching case that finds a violation decides ievt. ievt is only assigned
// together with converged = kFALSE, so it is never 0 when TakeStepReg is called.
705 Bool_t converged = kTRUE;
706 //case 1
707 if(jevt->IsInI0a()){
708 if( fB_low -fErrorC_J + feps > 2*fTolerance){
709 converged = kFALSE;
710 ievt = fTEventLow;
711 if(fErrorC_J-feps-fB_up > fB_low-fErrorC_J+feps){
712 ievt = fTEventUp;
713 }
714 }else if(fErrorC_J -feps - fB_up > 2*fTolerance){
715 converged = kFALSE;
716 ievt = fTEventUp;
717 if(fB_low - fErrorC_J+feps > fErrorC_J-feps -fB_up){
718 ievt = fTEventLow;
719 }
720 }
721 }
722
723 //case 2
724 if(jevt->IsInI0b()){
725 if( fB_low -fErrorC_J - feps > 2*fTolerance){
726 converged = kFALSE;
727 ievt = fTEventLow;
728 if(fErrorC_J+feps-fB_up > fB_low-fErrorC_J-feps){
729 ievt = fTEventUp;
730 }
731 }else if(fErrorC_J + feps - fB_up > 2*fTolerance){
732 converged = kFALSE;
733 ievt = fTEventUp;
734 if(fB_low - fErrorC_J-feps > fErrorC_J+feps -fB_up){
735 ievt = fTEventLow;
736 }
737 }
738 }
739
740 //case 3
741 if(jevt->IsInI1()){
742 if( fB_low -fErrorC_J - feps > 2*fTolerance){
743 converged = kFALSE;
744 ievt = fTEventLow;
745 if(fErrorC_J+feps-fB_up > fB_low-fErrorC_J-feps){
746 ievt = fTEventUp;
747 }
748 }else if(fErrorC_J - feps - fB_up > 2*fTolerance){
749 converged = kFALSE;
750 ievt = fTEventUp;
751 if(fB_low - fErrorC_J+feps > fErrorC_J-feps -fB_up){
752 ievt = fTEventLow;
753 }
754 }
755 }
756
757 //case 4
758 if(jevt->IsInI2()){
759 if( fErrorC_J + feps -fB_up > 2*fTolerance){
760 converged = kFALSE;
761 ievt = fTEventUp;
762 }
763 }
764
765 //case 5
766 if(jevt->IsInI3()){
767 if(fB_low -fErrorC_J +feps > 2*fTolerance){
768 converged = kFALSE;
769 ievt = fTEventLow;
770 }
771 }
772
773 if(converged) return kFALSE;
774 if (TakeStepReg(ievt, jevt)) return kTRUE;
775 else return kFALSE;
776}
777
779{
780 if( TMath::Abs(a_i - a_j) > eps*(a_i + a_j + eps)) return kTRUE;
781 else return kFALSE;
782}
783
#define h(i)
Definition RSha256.hxx:106
#define e(i)
Definition RSha256.hxx:103
int Int_t
Definition RtypesCore.h:45
unsigned int UInt_t
Definition RtypesCore.h:46
float Float_t
Definition RtypesCore.h:57
constexpr Bool_t kFALSE
Definition RtypesCore.h:101
constexpr Bool_t kTRUE
Definition RtypesCore.h:100
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
Event class for Support Vector Machine.
Definition SVEvent.h:40
Bool_t IsInI0a() const
Definition SVEvent.h:74
Float_t GetTarget() const
Definition SVEvent.h:72
Int_t GetIdx() const
Definition SVEvent.h:68
Bool_t IsInI2() const
Definition SVEvent.h:78
Float_t GetErrorCache() const
Definition SVEvent.h:65
Float_t GetCweight() const
Definition SVEvent.h:71
Float_t * GetLine() const
Definition SVEvent.h:69
void SetAlpha_p(Float_t alpha)
Definition SVEvent.h:52
Bool_t IsInI3() const
Definition SVEvent.h:79
Float_t GetAlpha() const
Definition SVEvent.h:61
Bool_t IsInI0() const
Definition SVEvent.h:76
UInt_t GetNs() const
Definition SVEvent.h:70
Bool_t IsInI0b() const
Definition SVEvent.h:75
Float_t GetAlpha_p() const
Definition SVEvent.h:62
void SetAlpha(Float_t alpha)
Definition SVEvent.h:51
Bool_t IsInI1() const
Definition SVEvent.h:77
Int_t GetTypeFlag() const
Definition SVEvent.h:66
void SetErrorCache(Float_t err_cache)
Definition SVEvent.h:53
Float_t GetDeltaAlpha() const
Definition SVEvent.h:63
Kernel for Support Vector Machine.
Kernel matrix for Support Vector Machine.
Float_t * GetLine(UInt_t)
returns a row of the kernel matrix
SVEvent * fTEventUp
last optimized event
Bool_t TakeStep(SVEvent *, SVEvent *)
void Train(UInt_t nIter=1000)
train the SVM
Bool_t IsDiffSignificant(Float_t, Float_t, Float_t)
Float_t fTolerance
documentation
Bool_t ExamineExample(SVEvent *)
Bool_t ExamineExampleReg(SVEvent *)
Bool_t fdoRegression
TODO temporary, find nicer solution.
Bool_t TakeStepReg(SVEvent *, SVEvent *)
Float_t fB_low
documentation
Float_t fB_up
documentation
SVEvent * fTEventLow
last optimized event
void SetIndex(TMVA::SVEvent *)
~SVWorkingSet()
destructor
SVWorkingSet()
constructor
std::vector< TMVA::SVEvent * > * fInputData
input events
SVKernelMatrix * fKMatrix
kernel matrix
std::vector< TMVA::SVEvent * > * GetSupportVectors()
Random number generator class based on M.
Definition TRandom3.h:27
virtual UInt_t Integer(UInt_t imax)
Returns a random integer uniformly distributed on the interval [ 0, imax-1 ].
Definition TRandom.cxx:360
TPaveText * pt
null_t< F > null()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
Short_t Max(Short_t a, Short_t b)
Returns the largest of a and b.
Definition TMathBase.h:250
Short_t Min(Short_t a, Short_t b)
Returns the smallest of a and b.
Definition TMathBase.h:198
Short_t Abs(Short_t d)
Returns the absolute value of parameter Short_t d.
Definition TMathBase.h:123
TLine l
Definition textangle.C:4
double epsilon
Definition triangle.c:618