Logo ROOT  
Reference Guide
SVWorkingSet.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andrzej Zemla
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : SVWorkingSet *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation *
12 * *
13 * Authors (alphabetical): *
14 * Marcin Wolter <Marcin.Wolter@cern.ch> - IFJ PAN, Krakow, Poland *
15 * Andrzej Zemla <azemla@cern.ch> - IFJ PAN, Krakow, Poland *
16 * (IFJ PAN: Henryk Niewodniczanski Inst. Nucl. Physics, Krakow, Poland) *
17 * *
18 * Copyright (c) 2005: *
19 * CERN, Switzerland *
20 * MPI-K Heidelberg, Germany *
21 * PAN, Krakow, Poland *
22 * *
23 * Redistribution and use in source and binary forms, with or without *
24 * modification, are permitted according to the terms listed in LICENSE *
25 * (http://tmva.sourceforge.net/LICENSE) *
26 **********************************************************************************/
27
28/*! \class TMVA::SVWorkingSet
29\ingroup TMVA
30Working class for Support Vector Machine
31*/
32
33#include "TMVA/SVWorkingSet.h"
34
35#include "TMVA/MsgLogger.h"
36#include "TMVA/SVEvent.h"
38#include "TMVA/SVKernelMatrix.h"
39#include "TMVA/Types.h"
40
41
42#include "TMath.h"
43#include "TRandom3.h"
44
45#include <iostream>
46#include <vector>
47
48////////////////////////////////////////////////////////////////////////////////
49/// constructor
50
52 : fdoRegression(kFALSE),
53 fInputData(0),
54 fSupVec(0),
55 fKFunction(0),
56 fKMatrix(0),
57 fTEventUp(0),
58 fTEventLow(0),
59 fB_low(1.),
60 fB_up(-1.),
61 fTolerance(0.01),
62 fLogger( new MsgLogger( "SVWorkingSet", kINFO ) )
63{
64}
65
66////////////////////////////////////////////////////////////////////////////////
67/// constructor
68
// Builds the working set for a list of training events and a kernel function.
// Allocates the kernel matrix, hands every event its kernel-matrix row and its
// position in the input list, and selects the initial fTEventLow / fTEventUp events.
//   inputVectors   - training events; stored as fInputData
//   kernelFunction - kernel function; stored as fKFunction and passed to the kernel matrix
//   tol            - stored as fTolerance
//   doreg          - stored as fdoRegression; kTRUE selects the regression code paths
// NOTE(review): the listing numbers embedded in this block jump from 92 to 96 and from
// 113 to 116, so some statements of this constructor are not visible here.
69TMVA::SVWorkingSet::SVWorkingSet(std::vector<TMVA::SVEvent*>*inputVectors, SVKernelFunction* kernelFunction,
70 Float_t tol, Bool_t doreg)
71 : fdoRegression(doreg),
72 fInputData(inputVectors),
73 fSupVec(0),
74 fKFunction(kernelFunction),
75 fTEventUp(0),
76 fTEventLow(0),
77 fB_low(1.),
78 fB_up(-1.),
79 fTolerance(tol),
80 fLogger( new MsgLogger( "SVWorkingSet", kINFO ) )
81{
// the kernel matrix is allocated here; the destructor deletes it
82 fKMatrix = new TMVA::SVKernelMatrix(inputVectors, kernelFunction);
83 Float_t *pt;
84 for( UInt_t i = 0; i < fInputData->size(); i++){
// give each event its row of the kernel matrix and its index in fInputData
85 pt = fKMatrix->GetLine(i);
86 fInputData->at(i)->SetLine(pt);
87 fInputData->at(i)->SetNs(i);
// regression: the error cache starts out as the event's target value
88 if(fdoRegression) fInputData->at(i)->SetErrorCache(fInputData->at(i)->GetTarget());
89 }
90 TRandom3 rand;
// random starting index into fInputData
91 UInt_t kk = rand.Integer(fInputData->size());
92 if(fdoRegression) {
// NOTE(review): the body of the regression branch (listing lines 93-95) is missing from this view
96 }
97 else{
// classification: keep drawing random events until one with type flag -1 is found;
// it becomes the initial fTEventLow.
// NOTE(review): neither search loop terminates if fInputData contains no event with the
// required type flag -- confirm that callers guarantee both classes are present.
98 while(1){
99 if(fInputData->at(kk)->GetTypeFlag()==-1){
100 fTEventLow = fInputData->at(kk);
101 break;
102 }
103 kk = rand.Integer(fInputData->size());
104 }
105
// same search for an event with type flag +1 (initial fTEventUp); it starts from the
// index selected above, which fails the test once and is then redrawn
106 while (1){
107 if (fInputData->at(kk)->GetTypeFlag()==1) {
108 fTEventUp = fInputData->at(kk);
109 break;
110 }
111 kk = rand.Integer(fInputData->size());
112 }
113 }
// NOTE(review): listing lines 114-115 are missing from this view
116}
117
118////////////////////////////////////////////////////////////////////////////////
119/// destructor
120
122{
123 if (fKMatrix != 0) {delete fKMatrix; fKMatrix = 0;}
124 delete fLogger;
125}
126
127////////////////////////////////////////////////////////////////////////////////
128
130{
131 SVEvent* ievt=0;
132 Float_t fErrorC_J = 0.;
133 if( jevt->GetIdx()==0) fErrorC_J = jevt->GetErrorCache();
134 else{
135 Float_t *fKVals = jevt->GetLine();
136 fErrorC_J = 0.;
137 std::vector<TMVA::SVEvent*>::iterator idIter;
138
139 UInt_t k=0;
140 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
141 if((*idIter)->GetAlpha()>0)
142 fErrorC_J += (*idIter)->GetAlpha()*(*idIter)->GetTypeFlag()*fKVals[k];
143 k++;
144 }
145
146
147 fErrorC_J -= jevt->GetTypeFlag();
148 jevt->SetErrorCache(fErrorC_J);
149
150 if((jevt->GetIdx() == 1) && (fErrorC_J < fB_up )){
151 fB_up = fErrorC_J;
152 fTEventUp = jevt;
153 }
154 else if ((jevt->GetIdx() == -1)&&(fErrorC_J > fB_low)) {
155 fB_low = fErrorC_J;
156 fTEventLow = jevt;
157 }
158 }
159 Bool_t converged = kTRUE;
160
161 if((jevt->GetIdx()>=0) && (fB_low - fErrorC_J > 2*fTolerance)) {
162 converged = kFALSE;
163 ievt = fTEventLow;
164 }
165
166 if((jevt->GetIdx()<=0) && (fErrorC_J - fB_up > 2*fTolerance)) {
167 converged = kFALSE;
168 ievt = fTEventUp;
169 }
170
171 if (converged) return kFALSE;
172
173 if(jevt->GetIdx()==0){
174 if(fB_low - fErrorC_J > fErrorC_J - fB_up) ievt = fTEventLow;
175 else ievt = fTEventUp;
176 }
177
178 if (TakeStep(ievt, jevt)) return kTRUE;
179 else return kFALSE;
180}
181
182
183////////////////////////////////////////////////////////////////////////////////
184
186{
187 if (ievt == jevt) return kFALSE;
188 std::vector<TMVA::SVEvent*>::iterator idIter;
189 const Float_t epsilon = 1e-8; //make it 1-e6 or 1-e5 to make it faster
190
191 Float_t type_I, type_J;
192 Float_t errorC_I, errorC_J;
193 Float_t alpha_I, alpha_J;
194
195 Float_t newAlpha_I, newAlpha_J;
196 Int_t s;
197
198 Float_t l, h, lobj = 0, hobj = 0;
199 Float_t eta;
200
201 type_I = ievt->GetTypeFlag();
202 alpha_I = ievt->GetAlpha();
203 errorC_I = ievt->GetErrorCache();
204
205 type_J = jevt->GetTypeFlag();
206 alpha_J = jevt->GetAlpha();
207 errorC_J = jevt->GetErrorCache();
208
209 s = Int_t( type_I * type_J );
210
211 Float_t c_i = ievt->GetCweight();
212
213 Float_t c_j = jevt->GetCweight();
214
215 // compute l, h objective function
216
217 if (type_I == type_J) {
218 Float_t gamma = alpha_I + alpha_J;
219
220 if ( c_i > c_j ) {
221 if ( gamma < c_j ) {
222 l = 0;
223 h = gamma;
224 }
225 else{
226 h = c_j;
227 if ( gamma < c_i )
228 l = 0;
229 else
230 l = gamma - c_i;
231 }
232 }
233 else {
234 if ( gamma < c_i ){
235 l = 0;
236 h = gamma;
237 }
238 else {
239 l = gamma - c_i;
240 if ( gamma < c_j )
241 h = gamma;
242 else
243 h = c_j;
244 }
245 }
246 }
247 else {
248 Float_t gamma = alpha_I - alpha_J;
249 if (gamma > 0) {
250 l = 0;
251 if ( gamma >= (c_i - c_j) )
252 h = c_i - gamma;
253 else
254 h = c_j;
255 }
256 else {
257 l = -gamma;
258 if ( (c_i - c_j) >= gamma)
259 h = c_j;
260 else
261 h = c_i - gamma;
262 }
263 }
264
265 if (l == h) return kFALSE;
266 Float_t kernel_II, kernel_IJ, kernel_JJ;
267
268 kernel_II = fKMatrix->GetElement(ievt->GetNs(),ievt->GetNs());
269 kernel_IJ = fKMatrix->GetElement(ievt->GetNs(), jevt->GetNs());
270 kernel_JJ = fKMatrix->GetElement(jevt->GetNs(),jevt->GetNs());
271
272 eta = 2*kernel_IJ - kernel_II - kernel_JJ;
273 if (eta < 0) {
274 newAlpha_J = alpha_J + (type_J*( errorC_J - errorC_I ))/eta;
275 if (newAlpha_J < l) newAlpha_J = l;
276 else if (newAlpha_J > h) newAlpha_J = h;
277
278 }
279
280 else {
281
282 Float_t c_I = eta/2;
283 Float_t c_J = type_J*( errorC_I - errorC_J ) - eta * alpha_J;
284 lobj = c_I * l * l + c_J * l;
285 hobj = c_I * h * h + c_J * h;
286
287 if (lobj > hobj + epsilon) newAlpha_J = l;
288 else if (lobj < hobj - epsilon) newAlpha_J = h;
289 else newAlpha_J = alpha_J;
290 }
291
292 if (TMath::Abs( newAlpha_J - alpha_J ) < ( epsilon * ( newAlpha_J + alpha_J+ epsilon ))){
293 return kFALSE;
294 //it spends here to much time... it is stupido
295 }
296 newAlpha_I = alpha_I - s*( newAlpha_J - alpha_J );
297
298 if (newAlpha_I < 0) {
299 newAlpha_J += s* newAlpha_I;
300 newAlpha_I = 0;
301 }
302 else if (newAlpha_I > c_i) {
303 Float_t temp = newAlpha_I - c_i;
304 newAlpha_J += s * temp;
305 newAlpha_I = c_i;
306 }
307
308 Float_t dL_I = type_I * ( newAlpha_I - alpha_I );
309 Float_t dL_J = type_J * ( newAlpha_J - alpha_J );
310
311 Int_t k = 0;
312 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
313 k++;
314 if((*idIter)->GetIdx()==0){
315 Float_t ii = fKMatrix->GetElement(ievt->GetNs(), (*idIter)->GetNs());
316 Float_t jj = fKMatrix->GetElement(jevt->GetNs(), (*idIter)->GetNs());
317
318 (*idIter)->UpdateErrorCache(dL_I * ii + dL_J * jj);
319 }
320 }
321 ievt->SetAlpha(newAlpha_I);
322 jevt->SetAlpha(newAlpha_J);
323 // set new indexes
324 SetIndex(ievt);
325 SetIndex(jevt);
326
327 // update error cache
328 ievt->SetErrorCache(errorC_I + dL_I*kernel_II + dL_J*kernel_IJ);
329 jevt->SetErrorCache(errorC_J + dL_I*kernel_IJ + dL_J*kernel_JJ);
330
331 // compute fI_low, fB_low
332
333 fB_low = -1*1e30;
334 fB_up = 1e30;
335
336 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
337 if((*idIter)->GetIdx()==0){
338 if((*idIter)->GetErrorCache()> fB_low){
339 fB_low = (*idIter)->GetErrorCache();
340 fTEventLow = (*idIter);
341 }
342 if( (*idIter)->GetErrorCache()< fB_up){
343 fB_up =(*idIter)->GetErrorCache();
344 fTEventUp = (*idIter);
345 }
346 }
347 }
348
349 // for optimized alfa's
350 if (fB_low < TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
351 if (ievt->GetErrorCache() > fB_low) {
352 fB_low = ievt->GetErrorCache();
353 fTEventLow = ievt;
354 }
355 else {
356 fB_low = jevt->GetErrorCache();
357 fTEventLow = jevt;
358 }
359 }
360
361 if (fB_up > TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
362 if (ievt->GetErrorCache()< fB_low) {
363 fB_up =ievt->GetErrorCache();
364 fTEventUp = ievt;
365 }
366 else {
367 fB_up =jevt->GetErrorCache() ;
368 fTEventUp = jevt;
369 }
370 }
371 return kTRUE;
372}
373
374////////////////////////////////////////////////////////////////////////////////
375
// Convergence test on the two thresholds: returns kTRUE once fB_up is larger than
// fB_low - 2*fTolerance, kFALSE otherwise.
// NOTE(review): the declarator line is missing from this listing; this is presumably the
// Terminated() that the training loop calls -- confirm against SVWorkingSet.h.
377{
378 if((fB_up > fB_low - 2*fTolerance)) return kTRUE;
379 return kFALSE;
380}
381
382////////////////////////////////////////////////////////////////////////////////
383/// train the SVM
384
386{
387
388 Int_t numChanged = 0;
389 Int_t examineAll = 1;
390
391 Float_t numChangedOld = 0;
392 Int_t deltaChanges = 0;
393 UInt_t numit = 0;
394
395 std::vector<TMVA::SVEvent*>::iterator idIter;
396
397 while ((numChanged > 0) || (examineAll > 0)) {
398 if (fIPyCurrentIter) *fIPyCurrentIter = numit;
399 if (fExitFromTraining && *fExitFromTraining) break;
400 numChanged = 0;
401 if (examineAll) {
402 for (idIter = fInputData->begin(); idIter!=fInputData->end(); ++idIter){
403 if(!fdoRegression) numChanged += (UInt_t)ExamineExample(*idIter);
404 else numChanged += (UInt_t)ExamineExampleReg(*idIter);
405 }
406 }
407 else {
408 for (idIter = fInputData->begin(); idIter!=fInputData->end(); ++idIter) {
409 if ((*idIter)->IsInI0()) {
410 if(!fdoRegression) numChanged += (UInt_t)ExamineExample(*idIter);
411 else numChanged += (UInt_t)ExamineExampleReg(*idIter);
412 if (Terminated()) {
413 numChanged = 0;
414 break;
415 }
416 }
417 }
418 }
419
420 if (examineAll == 1) examineAll = 0;
421 else if (numChanged == 0 || numChanged < 10 || deltaChanges > 3 ) examineAll = 1;
422
423 if (numChanged == numChangedOld) deltaChanges++;
424 else deltaChanges = 0;
425 numChangedOld = numChanged;
426 ++numit;
427
428 if (numit >= nMaxIter) {
429 *fLogger << kWARNING
430 << "Max number of iterations exceeded. "
431 << "Training may not be completed. Try use less Cost parameter" << Endl;
432 break;
433 }
434 }
435}
436
437////////////////////////////////////////////////////////////////////////////////
438
440{
441 if( (0< event->GetAlpha()) && (event->GetAlpha()< event->GetCweight()))
442 event->SetIdx(0);
443
444 if( event->GetTypeFlag() == 1){
445 if( event->GetAlpha() == 0)
446 event->SetIdx(1);
447 else if( event->GetAlpha() == event->GetCweight() )
448 event->SetIdx(-1);
449 }
450 if( event->GetTypeFlag() == -1){
451 if( event->GetAlpha() == 0)
452 event->SetIdx(-1);
453 else if( event->GetAlpha() == event->GetCweight() )
454 event->SetIdx(1);
455 }
456}
457
458////////////////////////////////////////////////////////////////////////////////
459
// Counts the events in fInputData whose alpha is non-zero.
// NOTE(review): the declarator line is missing from this listing, so the function's name
// and signature cannot be confirmed from here. The count is never printed, stored or
// returned, so as written the loop appears to have no observable effect.
461{
462 std::vector<TMVA::SVEvent*>::iterator idIter;
463 UInt_t counter = 0;
464 for( idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter)
465 if((*idIter)->GetAlpha() !=0) counter++;
466}
467
468////////////////////////////////////////////////////////////////////////////////
469
470std::vector<TMVA::SVEvent*>* TMVA::SVWorkingSet::GetSupportVectors()
471{
472 std::vector<TMVA::SVEvent*>::iterator idIter;
473 if( fSupVec != 0) {delete fSupVec; fSupVec = 0; }
474 fSupVec = new std::vector<TMVA::SVEvent*>(0);
475
476 for( idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
477 if((*idIter)->GetDeltaAlpha() !=0){
478 fSupVec->push_back((*idIter));
479 }
480 }
481 return fSupVec;
482}
483
484//for regression
485
// Regression version of the two-event optimisation step. Works on four multipliers
// at once: alpha and alpha_p of ievt and of jevt. Up to four update cases (A-D) are
// tried, each at most once; if any multiplier changed significantly the new values
// are stored in the events, the error caches of index-0 events are updated and
// fB_low / fB_up (with fTEventLow / fTEventUp) are recomputed.
// Returns kTRUE if the multipliers were updated, kFALSE otherwise (also when ievt
// and jevt are the same event).
// NOTE(review): the declarator line is missing from this listing; the member index shows
// `Bool_t TakeStepReg(SVEvent *, SVEvent *)` and the caller passes (ievt, jevt).
487{
488 if (ievt == jevt) return kFALSE;
489 std::vector<TMVA::SVEvent*>::iterator idIter;
// threshold used by IsDiffSignificant for every "did it change" test below
490 const Float_t epsilon = 0.001*fTolerance;//TODO
491
492 const Float_t kernel_II = fKMatrix->GetElement(ievt->GetNs(),ievt->GetNs());
493 const Float_t kernel_IJ = fKMatrix->GetElement(ievt->GetNs(),jevt->GetNs());
494 const Float_t kernel_JJ = fKMatrix->GetElement(jevt->GetNs(),jevt->GetNs());
495
496 //compute eta & gamma
// eta = K_ii + K_jj - 2*K_ij; gamma is the sum of both events' delta-alpha and is
// not recomputed while the multipliers change inside the loop
497 const Float_t eta = -2*kernel_IJ + kernel_II + kernel_JJ;
498 const Float_t gamma = ievt->GetDeltaAlpha() + jevt->GetDeltaAlpha();
499
500 //TODO CHECK WHAT IF ETA <0
501 //w.r.t Mercer's conditions it should never happen, but what if?
// NOTE(review): eta is used as a divisor in all four cases below and there is no
// guard against eta == 0 either.
502
// one flag per case so that every case is attempted at most once
503 Bool_t caseA, caseB, caseC, caseD, terminated;
504 caseA = caseB = caseC = caseD = terminated = kFALSE;
505 Float_t b_alpha_i, b_alpha_j, b_alpha_i_p, b_alpha_j_p; //temporary Lagrange multipliers
506 const Float_t b_cost_i = ievt->GetCweight();
507 const Float_t b_cost_j = jevt->GetCweight();
508
// working copies; the events themselves are only modified after the loop
509 b_alpha_i = ievt->GetAlpha();
510 b_alpha_j = jevt->GetAlpha();
511 b_alpha_i_p = ievt->GetAlpha_p();
512 b_alpha_j_p = jevt->GetAlpha_p();
513
514 //calculate deltafi
// difference of the two cached errors; it stays constant throughout the loop
515 Float_t deltafi = ievt->GetErrorCache()-jevt->GetErrorCache();
516
517 // main loop
// Each pass runs the first case whose flag is still clear and whose condition holds.
// The loop ends when no case applies or when a case finds an empty interval (low >= high).
518 while(!terminated) {
519 const Float_t null = 0.; //!!! dummy float null declaration because of problems with TMath::Max/Min(Float_t, Float_t) function
520 Float_t low, high;
521 Float_t tmp_alpha_i, tmp_alpha_j;
522 tmp_alpha_i = tmp_alpha_j = 0.;
523
524 //TODO check this conditions, are they proper
// case A: update alpha_i and alpha_j
525 if((caseA == kFALSE) && (b_alpha_i > 0 || (b_alpha_i_p == 0 && deltafi > 0)) && (b_alpha_j > 0 || (b_alpha_j_p == 0 && deltafi < 0)))
526 {
527 //compute low, high w.r.t a_i, a_j
528 low = TMath::Max( null, gamma - b_cost_j );
529 high = TMath::Min( b_cost_i , gamma);
530
531 if(low<high){
// proposed alpha_j, clipped to [low, high]; alpha_i absorbs the opposite change
532 tmp_alpha_j = b_alpha_j - (deltafi/eta);
533 tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
534 tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
535 tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j);
536
537 //update Li & Lj if change is significant (??)
538 if( IsDiffSignificant(b_alpha_j,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i,tmp_alpha_i, epsilon)){
539 b_alpha_j = tmp_alpha_j;
540 b_alpha_i = tmp_alpha_i;
541 }
542
543 }
544 else
545 terminated = kTRUE;
546
547 caseA = kTRUE;
548 }
// case B: update alpha_i and alpha_j_p
549 else if((caseB==kFALSE) && (b_alpha_i>0 || (b_alpha_i_p==0 && deltafi >2*epsilon )) && (b_alpha_j_p>0 || (b_alpha_j==0 && deltafi>2*epsilon)))
550 {
551 //compute LH w.r.t. a_i, a_j*
552 low = TMath::Max( null, gamma ); //TODO
553 high = TMath::Min( b_cost_i , b_cost_j + gamma);
554
555
556 if(low<high){
557 tmp_alpha_j = b_alpha_j_p - ((deltafi-2*epsilon)/eta);
558 tmp_alpha_j = TMath::Min(tmp_alpha_j,high);
559 tmp_alpha_j = TMath::Max(low,tmp_alpha_j);
560 tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j_p);
561
562 //update alphai alphaj_p
563 if( IsDiffSignificant(b_alpha_j_p,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i,tmp_alpha_i, epsilon)){
564 b_alpha_j_p = tmp_alpha_j;
565 b_alpha_i = tmp_alpha_i;
566 }
567 }
568 else
569 terminated = kTRUE;
570
571 caseB = kTRUE;
572 }
// case C: update alpha_i_p and alpha_j
573 else if((caseC==kFALSE) && (b_alpha_i_p>0 || (b_alpha_i==0 && deltafi < -2*epsilon )) && (b_alpha_j>0 || (b_alpha_j_p==0 && deltafi< -2*epsilon)))
574 {
575 //compute LH w.r.t. alphai_p alphaj
576 low = TMath::Max(null, -gamma );
577 high = TMath::Min(b_cost_i, -gamma+b_cost_j);
578
579 if(low<high){
580 tmp_alpha_j = b_alpha_j - ((deltafi+2*epsilon)/eta);
581 tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
582 tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
583 tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j);
584
585 //update alphai_p alphaj
586 if( IsDiffSignificant(b_alpha_j,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i_p,tmp_alpha_i, epsilon)){
587 b_alpha_j = tmp_alpha_j;
588 b_alpha_i_p = tmp_alpha_i;
589 }
590 }
591 else
592 terminated = kTRUE;
593
594 caseC = kTRUE;
595 }
// case D: update alpha_i_p and alpha_j_p
596 else if((caseD == kFALSE) &&
597 (b_alpha_i_p>0 || (b_alpha_i==0 && deltafi <0 )) &&
598 (b_alpha_j_p>0 || (b_alpha_j==0 && deltafi >0 )))
599 {
600 //compute LH w.r.t. alphai_p alphaj_p
601 low = TMath::Max(null,-gamma - b_cost_j);
602 high = TMath::Min(b_cost_i, -gamma);
603
604 if(low<high){
605 tmp_alpha_j = b_alpha_j_p + (deltafi/eta);
606 tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
607 tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
608 tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j_p);
609
610 if( IsDiffSignificant(b_alpha_j_p,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i_p,tmp_alpha_i, epsilon)){
611 b_alpha_j_p = tmp_alpha_j;
612 b_alpha_i_p = tmp_alpha_i;
613 }
614 }
615 else
616 terminated = kTRUE;
617
618 caseD = kTRUE;
619 }
620 else
621 terminated = kTRUE;
622 }
623 // TODO add comment how it was calculated
// NOTE(review): deltafi is not read again after this statement, so the update below
// currently has no effect on the result.
624 deltafi += ievt->GetDeltaAlpha()*(kernel_II - kernel_IJ) + jevt->GetDeltaAlpha()*(kernel_IJ - kernel_JJ);
625
// commit only if at least one of the four multipliers moved significantly
626 if( IsDiffSignificant(b_alpha_i, ievt->GetAlpha(), epsilon) ||
627 IsDiffSignificant(b_alpha_j, jevt->GetAlpha(), epsilon) ||
628 IsDiffSignificant(b_alpha_i_p, ievt->GetAlpha_p(), epsilon) ||
629 IsDiffSignificant(b_alpha_j_p, jevt->GetAlpha_p(), epsilon) ){
630
631 //TODO check if these conditions might be easier
632 //TODO write documentation for this
// NOTE(review): both expressions read the events' old values (the new alphas are stored
// further down) and use the new alpha_p but not the new alpha working copy -- confirm
// this is the intended change in delta-alpha.
633 const Float_t diff_alpha_i = ievt->GetDeltaAlpha()+b_alpha_i_p - ievt->GetAlpha();
634 const Float_t diff_alpha_j = jevt->GetDeltaAlpha()+b_alpha_j_p - jevt->GetAlpha();
635
636 //update error cache
// k is incremented but never read
637 Int_t k = 0;
638 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
639 k++;
640 //there will be some changes in Idx notation
// only events with index 0 have their cached error adjusted
641 if((*idIter)->GetIdx()==0){
642 Float_t k_ii = fKMatrix->GetElement(ievt->GetNs(), (*idIter)->GetNs());
643 Float_t k_jj = fKMatrix->GetElement(jevt->GetNs(), (*idIter)->GetNs());
644
645 (*idIter)->UpdateErrorCache(diff_alpha_i * k_ii + diff_alpha_j * k_jj);
646 }
647 }
648
649 //store new alphas in SVevents
650 ievt->SetAlpha(b_alpha_i);
651 jevt->SetAlpha(b_alpha_j);
652 ievt->SetAlpha_p(b_alpha_i_p);
653 jevt->SetAlpha_p(b_alpha_j_p);
654
655 //TODO update indexes
656
657 // compute fI_low, fB_low
// fB_low: largest cached error among events not in I3;
// fB_up: smallest cached error among events not in I2
658
659 fB_low = -1*1e30;
660 fB_up =1e30;
661
662 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
663 if((!(*idIter)->IsInI3()) && ((*idIter)->GetErrorCache()> fB_low)){
664 fB_low = (*idIter)->GetErrorCache();
665 fTEventLow = (*idIter);
666
667 }
668 if((!(*idIter)->IsInI2()) && ((*idIter)->GetErrorCache()< fB_up)){
669 fB_up =(*idIter)->GetErrorCache();
670 fTEventUp = (*idIter);
671 }
672 }
673 return kTRUE;
674 } else return kFALSE;
675}
676
677
678////////////////////////////////////////////////////////////////////////////////
679
// Regression version of the per-event examination. Obtains the error value of jevt
// (from the cache for events with IsInI0(), recomputed otherwise), checks it against
// fB_low / fB_up with a margin of feps according to the set the event belongs to
// (I0a, I0b, I1, I2, I3), and if a threshold is violated by more than 2*fTolerance
// runs TakeStepReg with fTEventLow or fTEventUp as partner.
// Returns kTRUE if TakeStepReg changed the multipliers, kFALSE otherwise (including
// the case where the event is in none of the five sets).
// NOTE(review): the declarator line is missing from this listing; the member index shows
// `Bool_t ExamineExampleReg(SVEvent *)`.
681{
682 Float_t feps = 1e-7;// TODO check which value is the best
683 SVEvent* ievt=0;
684 Float_t fErrorC_J = 0.;
685 if( jevt->IsInI0()) {
686 fErrorC_J = jevt->GetErrorCache();
687 }
688 else{
// recompute the error: target minus the delta-alpha weighted kernel row of this event
689 Float_t *fKVals = jevt->GetLine();
690 fErrorC_J = 0.;
691 std::vector<TMVA::SVEvent*>::iterator idIter;
692
693 UInt_t k=0;
694 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
695 fErrorC_J -= (*idIter)->GetDeltaAlpha()*fKVals[k];
696 k++;
697 }
698
699 fErrorC_J += jevt->GetTarget();
700 jevt->SetErrorCache(fErrorC_J);
701
// the fresh value may tighten one of the thresholds; which one depends on the set.
// For I1 the fB_low test is only reached when the fB_up test did not fire.
702 if(jevt->IsInI1()){
703 if(fErrorC_J + feps < fB_up ){
704 fB_up = fErrorC_J + feps;
705 fTEventUp = jevt;
706 }
707 else if(fErrorC_J -feps > fB_low) {
708 fB_low = fErrorC_J - feps;
709 fTEventLow = jevt;
710 }
711 }else if((jevt->IsInI2()) && (fErrorC_J + feps > fB_low)){
712 fB_low = fErrorC_J + feps;
713 fTEventLow = jevt;
714 }else if((jevt->IsInI3()) && (fErrorC_J - feps < fB_up)){
715 fB_up = fErrorC_J - feps;
716 fTEventUp = jevt;
717 }
718 }
719
// Pick the partner event. In cases 1-3 the side with the larger violation wins;
// cases 4 and 5 can only be paired with one side.
720 Bool_t converged = kTRUE;
721 //case 1
// event in I0a: both tests use fErrorC_J - feps
722 if(jevt->IsInI0a()){
723 if( fB_low -fErrorC_J + feps > 2*fTolerance){
724 converged = kFALSE;
725 ievt = fTEventLow;
726 if(fErrorC_J-feps-fB_up > fB_low-fErrorC_J+feps){
727 ievt = fTEventUp;
728 }
729 }else if(fErrorC_J -feps - fB_up > 2*fTolerance){
730 converged = kFALSE;
731 ievt = fTEventUp;
732 if(fB_low - fErrorC_J+feps > fErrorC_J-feps -fB_up){
733 ievt = fTEventLow;
734 }
735 }
736 }
737
738 //case 2
// event in I0b: both tests use fErrorC_J + feps
739 if(jevt->IsInI0b()){
740 if( fB_low -fErrorC_J - feps > 2*fTolerance){
741 converged = kFALSE;
742 ievt = fTEventLow;
743 if(fErrorC_J+feps-fB_up > fB_low-fErrorC_J-feps){
744 ievt = fTEventUp;
745 }
746 }else if(fErrorC_J + feps - fB_up > 2*fTolerance){
747 converged = kFALSE;
748 ievt = fTEventUp;
749 if(fB_low - fErrorC_J-feps > fErrorC_J+feps -fB_up){
750 ievt = fTEventLow;
751 }
752 }
753 }
754
755 //case 3
// event in I1: the fB_low test uses fErrorC_J + feps, the fB_up test fErrorC_J - feps
756 if(jevt->IsInI1()){
757 if( fB_low -fErrorC_J - feps > 2*fTolerance){
758 converged = kFALSE;
759 ievt = fTEventLow;
760 if(fErrorC_J+feps-fB_up > fB_low-fErrorC_J-feps){
761 ievt = fTEventUp;
762 }
763 }else if(fErrorC_J - feps - fB_up > 2*fTolerance){
764 converged = kFALSE;
765 ievt = fTEventUp;
766 if(fB_low - fErrorC_J+feps > fErrorC_J-feps -fB_up){
767 ievt = fTEventLow;
768 }
769 }
770 }
771
772 //case 4
// event in I2: only checked against fB_up
773 if(jevt->IsInI2()){
774 if( fErrorC_J + feps -fB_up > 2*fTolerance){
775 converged = kFALSE;
776 ievt = fTEventUp;
777 }
778 }
779
780 //case 5
// event in I3: only checked against fB_low
781 if(jevt->IsInI3()){
782 if(fB_low -fErrorC_J +feps > 2*fTolerance){
783 converged = kFALSE;
784 ievt = fTEventLow;
785 }
786 }
787
// nothing violated: leave the event alone
788 if(converged) return kFALSE;
789 if (TakeStepReg(ievt, jevt)) return kTRUE;
790 else return kFALSE;
791}
792
794{
795 if( TMath::Abs(a_i - a_j) > eps*(a_i + a_j + eps)) return kTRUE;
796 else return kFALSE;
797}
798
#define h(i)
Definition: RSha256.hxx:106
#define e(i)
Definition: RSha256.hxx:103
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
float Float_t
Definition: RtypesCore.h:53
const Bool_t kTRUE
Definition: RtypesCore.h:87
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
Event class for Support Vector Machine.
Definition: SVEvent.h:40
Bool_t IsInI0a() const
Definition: SVEvent.h:74
Float_t GetTarget() const
Definition: SVEvent.h:72
Int_t GetIdx() const
Definition: SVEvent.h:68
Bool_t IsInI2() const
Definition: SVEvent.h:78
Float_t GetErrorCache() const
Definition: SVEvent.h:65
Float_t GetCweight() const
Definition: SVEvent.h:71
Float_t * GetLine() const
Definition: SVEvent.h:69
void SetAlpha_p(Float_t alpha)
Definition: SVEvent.h:52
Bool_t IsInI3() const
Definition: SVEvent.h:79
Float_t GetAlpha() const
Definition: SVEvent.h:61
Bool_t IsInI0() const
Definition: SVEvent.h:76
UInt_t GetNs() const
Definition: SVEvent.h:70
Bool_t IsInI0b() const
Definition: SVEvent.h:75
Float_t GetAlpha_p() const
Definition: SVEvent.h:62
void SetAlpha(Float_t alpha)
Definition: SVEvent.h:51
Bool_t IsInI1() const
Definition: SVEvent.h:77
Int_t GetTypeFlag() const
Definition: SVEvent.h:66
void SetErrorCache(Float_t err_cache)
Definition: SVEvent.h:53
Float_t GetDeltaAlpha() const
Definition: SVEvent.h:63
Kernel for Support Vector Machine.
Kernel matrix for Support Vector Machine.
Float_t * GetLine(UInt_t)
returns a row of the kernel matrix
SVEvent * fTEventUp
Definition: SVWorkingSet.h:79
Bool_t TakeStep(SVEvent *, SVEvent *)
void Train(UInt_t nIter=1000)
train the SVM
Bool_t IsDiffSignificant(Float_t, Float_t, Float_t)
Bool_t ExamineExample(SVEvent *)
Bool_t ExamineExampleReg(SVEvent *)
Bool_t TakeStepReg(SVEvent *, SVEvent *)
SVEvent * fTEventLow
Definition: SVWorkingSet.h:80
void SetIndex(TMVA::SVEvent *)
~SVWorkingSet()
destructor
SVWorkingSet()
constructor
std::vector< TMVA::SVEvent * > * fInputData
Definition: SVWorkingSet.h:74
SVKernelMatrix * fKMatrix
Definition: SVWorkingSet.h:77
std::vector< TMVA::SVEvent * > * GetSupportVectors()
Random number generator class based on M.
Definition: TRandom3.h:27
virtual UInt_t Integer(UInt_t imax)
Returns a random integer uniformly distributed on the interval [ 0, imax-1 ].
Definition: TRandom.cxx:349
TPaveText * pt
double gamma(double x)
static constexpr double s
null_t< F > null()
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Short_t Max(Short_t a, Short_t b)
Definition: TMathBase.h:212
Short_t Min(Short_t a, Short_t b)
Definition: TMathBase.h:180
Short_t Abs(Short_t d)
Definition: TMathBase.h:120
auto * l
Definition: textangle.C:4
REAL epsilon
Definition: triangle.c:617