Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
SVWorkingSet.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andrzej Zemla
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : SVWorkingSet *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation *
12 * *
13 * Authors (alphabetical): *
14 * Marcin Wolter <Marcin.Wolter@cern.ch> - IFJ PAN, Krakow, Poland *
15 * Andrzej Zemla <azemla@cern.ch> - IFJ PAN, Krakow, Poland *
16 * (IFJ PAN: Henryk Niewodniczanski Inst. Nucl. Physics, Krakow, Poland) *
17 * *
18 * Copyright (c) 2005: *
19 * CERN, Switzerland *
20 * MPI-K Heidelberg, Germany *
21 * PAN, Krakow, Poland *
22 * *
23 * Redistribution and use in source and binary forms, with or without *
24 * modification, are permitted according to the terms listed in LICENSE *
25 * (http://tmva.sourceforge.net/LICENSE) *
26 **********************************************************************************/
27
28/*! \class TMVA::SVWorkingSet
29\ingroup TMVA
30Working class for Support Vector Machine
31*/
32
33#include "TMVA/SVWorkingSet.h"
34
35#include "TMVA/MsgLogger.h"
36#include "TMVA/SVEvent.h"
38#include "TMVA/SVKernelMatrix.h"
39#include "TMVA/Types.h"
40
41
42#include "TMath.h"
43#include "TRandom3.h"
44
45#include <vector>
46
47////////////////////////////////////////////////////////////////////////////////
48/// constructor
49
// Member initializer list and empty body of a constructor. The signature line
// is not visible in this listing; given the "/// constructor" comment above and
// the absence of any parameter use, this is presumably the default constructor
// TMVA::SVWorkingSet::SVWorkingSet() -- TODO confirm.
// Regression mode is off and every pointer member except the logger is null.
51 : fdoRegression(kFALSE),
52 fInputData(0),
53 fSupVec(0),
54 fKFunction(0),
55 fKMatrix(0),
56 fTEventUp(0),
57 fTEventLow(0),
// Initial thresholds: fB_low = 1 and fB_up = -1.
58 fB_low(1.),
59 fB_up(-1.),
60 fTolerance(0.01),
// The logger is heap-allocated here; a matching "delete fLogger" appears in the
// destructor body further down in this file.
61 fLogger( new MsgLogger( "SVWorkingSet", kINFO ) )
62{
63}
64
65////////////////////////////////////////////////////////////////////////////////
66/// constructor
67
/// Construct a working set over the given input vectors.
///
/// \param inputVectors    events to work on; the pointer is stored in fInputData
///                        (the vector is not copied)
/// \param kernelFunction  kernel, stored in fKFunction and used to build fKMatrix
/// \param tol             stored in fTolerance
/// \param doreg           stored in fdoRegression; selects the regression branch
///
/// NOTE(review): this listing is missing original lines 92-94 (the body of the
/// regression branch) and 113-114 (just before the closing brace), so the
/// comments below cover only the statements that are visible.
68TMVA::SVWorkingSet::SVWorkingSet(std::vector<TMVA::SVEvent*>*inputVectors, SVKernelFunction* kernelFunction,
69 Float_t tol, Bool_t doreg)
70 : fdoRegression(doreg),
71 fInputData(inputVectors),
72 fSupVec(0),
73 fKFunction(kernelFunction),
74 fTEventUp(0),
75 fTEventLow(0),
76 fB_low(1.),
77 fB_up(-1.),
78 fTolerance(tol),
79 fLogger( new MsgLogger( "SVWorkingSet", kINFO ) )
80{
// fKMatrix is not in the initializer list; it is allocated here and released in
// the destructor body further down in this file.
81 fKMatrix = new TMVA::SVKernelMatrix(inputVectors, kernelFunction);
82 Float_t *pt;
// Give every event its kernel-matrix row (GetLine(i)) and its position i in the
// input vector (SetNs). In regression mode the error cache is seeded with the
// event's target.
83 for( UInt_t i = 0; i < fInputData->size(); i++){
84 pt = fKMatrix->GetLine(i);
85 fInputData->at(i)->SetLine(pt);
86 fInputData->at(i)->SetNs(i);
87 if(fdoRegression) fInputData->at(i)->SetErrorCache(fInputData->at(i)->GetTarget());
88 }
// Draw a random starting position in the input vector.
89 TRandom3 rand;
90 UInt_t kk = rand.Integer(fInputData->size());
91 if(fdoRegression) {
// (original lines 92-94 are not visible in this listing)
95 }
96 else{
// Classification: redraw until an event with type flag -1 is hit; it becomes
// fTEventLow.
// NOTE(review): the only exit is the break, so this loop does not finish if no
// event has type flag -1 -- confirm callers guarantee both classes are present.
97 while(1){
98 if(fInputData->at(kk)->GetTypeFlag()==-1){
99 fTEventLow = fInputData->at(kk);
100 break;
101 }
102 kk = rand.Integer(fInputData->size());
103 }
104
// Same search for an event with type flag +1, which becomes fTEventUp. The first
// probe reuses the index kk left over from the loop above.
105 while (1){
106 if (fInputData->at(kk)->GetTypeFlag()==1) {
107 fTEventUp = fInputData->at(kk);
108 break;
109 }
110 kk = rand.Integer(fInputData->size());
111 }
112 }
// (original lines 113-114 are not visible in this listing)
115}
116
117////////////////////////////////////////////////////////////////////////////////
118/// destructor
119
// Destructor body (the signature line is not visible in this listing).
// Releases the kernel matrix, if one was allocated, and the logger.
// NOTE(review): fSupVec, which GetSupportVectors() allocates with new, is not
// deleted here -- confirm who owns the returned vector.
121{
122 if (fKMatrix != 0) {delete fKMatrix; fKMatrix = 0;}
123 delete fLogger;
124}
125
126////////////////////////////////////////////////////////////////////////////////
127
// Body of the classification "examine" step. The signature line is not visible
// in this listing; from the TakeStep() call below and the call sites in the
// training loop this is presumably
// Bool_t TMVA::SVWorkingSet::ExamineExample(SVEvent* jevt) -- TODO confirm.
// Obtains the error value of jevt, checks it against the thresholds
// fB_low / fB_up, and if the 2*fTolerance margin is violated picks a partner
// event and calls TakeStep(). Returns kTRUE only when TakeStep() returns true.
129{
130 SVEvent* ievt=0;
131 Float_t fErrorC_J = 0.;
// Events with index 0 use the cached error value as is.
132 if( jevt->GetIdx()==0) fErrorC_J = jevt->GetErrorCache();
133 else{
// Otherwise recompute it from scratch: sum alpha * typeFlag * K(j,k) over all
// events with alpha > 0, using the kernel row attached to jevt, then subtract
// jevt's own type flag.
134 Float_t *fKVals = jevt->GetLine();
135 fErrorC_J = 0.;
136 std::vector<TMVA::SVEvent*>::iterator idIter;
137
// k runs in step with the iterator and indexes the kernel row.
138 UInt_t k=0;
139 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
140 if((*idIter)->GetAlpha()>0)
141 fErrorC_J += (*idIter)->GetAlpha()*(*idIter)->GetTypeFlag()*fKVals[k];
142 k++;
143 }
144
145
146 fErrorC_J -= jevt->GetTypeFlag();
147 jevt->SetErrorCache(fErrorC_J);
148
// The freshly computed value may tighten a threshold: index +1 events can lower
// fB_up, index -1 events can raise fB_low. The event that sets a threshold is
// remembered in fTEventUp / fTEventLow.
149 if((jevt->GetIdx() == 1) && (fErrorC_J < fB_up )){
150 fB_up = fErrorC_J;
151 fTEventUp = jevt;
152 }
153 else if ((jevt->GetIdx() == -1)&&(fErrorC_J > fB_low)) {
154 fB_low = fErrorC_J;
155 fTEventLow = jevt;
156 }
157 }
158 Bool_t converged = kTRUE;
159
// Violation against the low threshold (checked for index 0 and +1).
160 if((jevt->GetIdx()>=0) && (fB_low - fErrorC_J > 2*fTolerance)) {
161 converged = kFALSE;
162 ievt = fTEventLow;
163 }
164
// Violation against the up threshold (checked for index 0 and -1); when both
// tests fire this assignment overrides the one above.
165 if((jevt->GetIdx()<=0) && (fErrorC_J - fB_up > 2*fTolerance)) {
166 converged = kFALSE;
167 ievt = fTEventUp;
168 }
169
// Nothing violated: no step is attempted.
170 if (converged) return kFALSE;
171
// For index 0 events the partner is the threshold event on the side with the
// larger gap.
172 if(jevt->GetIdx()==0){
173 if(fB_low - fErrorC_J > fErrorC_J - fB_up) ievt = fTEventLow;
174 else ievt = fTEventUp;
175 }
176
177 if (TakeStep(ievt, jevt)) return kTRUE;
178 else return kFALSE;
179}
180
181
182////////////////////////////////////////////////////////////////////////////////
183
// Body of the classification pair update. The signature line is not visible in
// this listing; it is called above as TakeStep(ievt, jevt) and the index lists
// Bool_t TakeStep(SVEvent*, SVEvent*) -- TODO confirm parameter names.
// Tries to change the alphas of ievt and jevt together. Returns kFALSE without
// modifying anything when both arguments are the same event, when the feasible
// interval is empty (l == h), or when the change in alpha_J is negligible.
// Otherwise it stores the new alphas, refreshes indexes, error caches and the
// thresholds fB_low / fB_up, and returns kTRUE.
185{
186 if (ievt == jevt) return kFALSE;
187 std::vector<TMVA::SVEvent*>::iterator idIter;
188 const Float_t epsilon = 1e-8; // could be raised to 1e-6 or 1e-5 to make it faster
189
190 Float_t type_I, type_J;
191 Float_t errorC_I, errorC_J;
192 Float_t alpha_I, alpha_J;
193
194 Float_t newAlpha_I, newAlpha_J;
195 Int_t s;
196
197 Float_t l, h, lobj = 0, hobj = 0;
198 Float_t eta;
199
// Snapshot of both events before the update.
200 type_I = ievt->GetTypeFlag();
201 alpha_I = ievt->GetAlpha();
202 errorC_I = ievt->GetErrorCache();
203
204 type_J = jevt->GetTypeFlag();
205 alpha_J = jevt->GetAlpha();
206 errorC_J = jevt->GetErrorCache();
207
// s is the product of the two type flags.
208 s = Int_t( type_I * type_J );
209
// Per-event upper bounds on alpha.
210 Float_t c_i = ievt->GetCweight();
211
212 Float_t c_j = jevt->GetCweight();
213
214 // compute the bounds l and h between which the new alpha_J is clipped below
215
216 if (type_I == type_J) {
// Same type flag: the bounds are derived from the sum alpha_I + alpha_J.
217 Float_t gamma = alpha_I + alpha_J;
218
219 if ( c_i > c_j ) {
220 if ( gamma < c_j ) {
221 l = 0;
222 h = gamma;
223 }
224 else{
225 h = c_j;
226 if ( gamma < c_i )
227 l = 0;
228 else
229 l = gamma - c_i;
230 }
231 }
232 else {
233 if ( gamma < c_i ){
234 l = 0;
235 h = gamma;
236 }
237 else {
238 l = gamma - c_i;
239 if ( gamma < c_j )
240 h = gamma;
241 else
242 h = c_j;
243 }
244 }
245 }
246 else {
// Different type flags: the bounds are derived from the difference
// alpha_I - alpha_J.
247 Float_t gamma = alpha_I - alpha_J;
248 if (gamma > 0) {
249 l = 0;
250 if ( gamma >= (c_i - c_j) )
251 h = c_i - gamma;
252 else
253 h = c_j;
254 }
255 else {
256 l = -gamma;
257 if ( (c_i - c_j) >= gamma)
258 h = c_j;
259 else
260 h = c_i - gamma;
261 }
262 }
263
// Empty interval: nothing can move.
264 if (l == h) return kFALSE;
265 Float_t kernel_II, kernel_IJ, kernel_JJ;
266
267 kernel_II = fKMatrix->GetElement(ievt->GetNs(),ievt->GetNs());
268 kernel_IJ = fKMatrix->GetElement(ievt->GetNs(), jevt->GetNs());
269 kernel_JJ = fKMatrix->GetElement(jevt->GetNs(),jevt->GetNs());
270
// Sign convention here: eta = 2*K_IJ - K_II - K_JJ (TakeStepReg uses the
// opposite sign).
271 eta = 2*kernel_IJ - kernel_II - kernel_JJ;
272 if (eta < 0) {
// Closed-form update of alpha_J, clipped to [l, h].
273 newAlpha_J = alpha_J + (type_J*( errorC_J - errorC_I ))/eta;
274 if (newAlpha_J < l) newAlpha_J = l;
275 else if (newAlpha_J > h) newAlpha_J = h;
276
277 }
278
279 else {
280
// eta >= 0: evaluate c_I*a*a + c_J*a at a = l and a = h. Take l when lobj
// exceeds hobj by more than epsilon, h in the opposite case, and keep alpha_J
// when the two values are within epsilon of each other.
281 Float_t c_I = eta/2;
282 Float_t c_J = type_J*( errorC_I - errorC_J ) - eta * alpha_J;
283 lobj = c_I * l * l + c_J * l;
284 hobj = c_I * h * h + c_J * h;
285
286 if (lobj > hobj + epsilon) newAlpha_J = l;
287 else if (lobj < hobj - epsilon) newAlpha_J = h;
288 else newAlpha_J = alpha_J;
289 }
290
// Reject the step when alpha_J changes by less than a relative epsilon.
291 if (TMath::Abs( newAlpha_J - alpha_J ) < ( epsilon * ( newAlpha_J + alpha_J+ epsilon ))){
292 return kFALSE;
293 // original author's note: too much time is spent reaching this point only to reject the step
294 }
// alpha_I moves by the same amount as alpha_J, with its sign set by s.
295 newAlpha_I = alpha_I - s*( newAlpha_J - alpha_J );
296
// If alpha_I left [0, c_i], clamp it and pass the overshoot on to alpha_J.
297 if (newAlpha_I < 0) {
298 newAlpha_J += s* newAlpha_I;
299 newAlpha_I = 0;
300 }
301 else if (newAlpha_I > c_i) {
302 Float_t temp = newAlpha_I - c_i;
303 newAlpha_J += s * temp;
304 newAlpha_I = c_i;
305 }
306
// Signed alpha changes used for the error-cache updates below.
307 Float_t dL_I = type_I * ( newAlpha_I - alpha_I );
308 Float_t dL_J = type_J * ( newAlpha_J - alpha_J );
309
// Update the cached error of every event whose index is 0 (indexes as they were
// before the SetIndex() calls below). k is incremented but never read.
310 Int_t k = 0;
311 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
312 k++;
313 if((*idIter)->GetIdx()==0){
314 Float_t ii = fKMatrix->GetElement(ievt->GetNs(), (*idIter)->GetNs());
315 Float_t jj = fKMatrix->GetElement(jevt->GetNs(), (*idIter)->GetNs());
316
317 (*idIter)->UpdateErrorCache(dL_I * ii + dL_J * jj);
318 }
319 }
320 ievt->SetAlpha(newAlpha_I);
321 jevt->SetAlpha(newAlpha_J);
322 // set new indexes
323 SetIndex(ievt);
324 SetIndex(jevt);
325
326 // update error cache of the two stepped events; these are computed from the
326 // values read at the top of the function and overwrite whatever the loop
326 // above may have written for them
327 ievt->SetErrorCache(errorC_I + dL_I*kernel_II + dL_J*kernel_IJ);
328 jevt->SetErrorCache(errorC_J + dL_I*kernel_IJ + dL_J*kernel_JJ);
329
330 // recompute fB_low / fB_up and the events that define them
331
332 fB_low = -1*1e30;
333 fB_up = 1e30;
334
// Only events with index 0 take part in this scan: fB_low becomes their largest
// cached error, fB_up their smallest.
335 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
336 if((*idIter)->GetIdx()==0){
337 if((*idIter)->GetErrorCache()> fB_low){
338 fB_low = (*idIter)->GetErrorCache();
339 fTEventLow = (*idIter);
340 }
341 if( (*idIter)->GetErrorCache()< fB_up){
342 fB_up =(*idIter)->GetErrorCache();
343 fTEventUp = (*idIter);
344 }
345 }
346 }
347
348 // also take the two just-optimized events into account, whatever their index
349 if (fB_low < TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
350 if (ievt->GetErrorCache() > fB_low) {
351 fB_low = ievt->GetErrorCache();
352 fTEventLow = ievt;
353 }
354 else {
355 fB_low = jevt->GetErrorCache();
356 fTEventLow = jevt;
357 }
358 }
359
// NOTE(review): this block refreshes fB_up, yet the outer test uses TMath::Max
// and the inner test compares against fB_low (possibly just modified above).
// By symmetry with the fB_low block one might expect TMath::Min and fB_up here
// -- confirm against the algorithm reference before changing.
360 if (fB_up > TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
361 if (ievt->GetErrorCache()< fB_low) {
362 fB_up =ievt->GetErrorCache();
363 fTEventUp = ievt;
364 }
365 else {
366 fB_up =jevt->GetErrorCache() ;
367 fTEventUp = jevt;
368 }
369 }
370 return kTRUE;
371}
372
373////////////////////////////////////////////////////////////////////////////////
374
// Body of the termination test. The signature line is not visible in this
// listing; the training loop calls Terminated() with no arguments, which is
// presumably this function -- TODO confirm.
// Returns kTRUE when fB_up is greater than fB_low - 2*fTolerance, i.e. when the
// two thresholds are within the tolerance margin; kFALSE otherwise.
376{
377 if((fB_up > fB_low - 2*fTolerance)) return kTRUE;
378 return kFALSE;
379}
380
381////////////////////////////////////////////////////////////////////////////////
382/// train the SVM
383
// Body of the training loop. The signature line is not visible in this listing;
// the index lists void Train(UInt_t nIter=1000), and the body reads an
// iteration limit named nMaxIter, presumably that parameter -- TODO confirm.
// Alternates between a sweep over all events and sweeps over the IsInI0()
// events only, calling ExamineExample() or ExamineExampleReg() depending on
// fdoRegression, until a sweep over all events changes nothing, the external
// exit flag is set, or the iteration limit is reached.
385{
386
387 Int_t numChanged = 0;
388 Int_t examineAll = 1;
389
// NOTE(review): numChangedOld is a Float_t but only ever holds and is compared
// with the Int_t numChanged.
390 Float_t numChangedOld = 0;
391 Int_t deltaChanges = 0;
392 UInt_t numit = 0;
393
394 std::vector<TMVA::SVEvent*>::iterator idIter;
395
396 while ((numChanged > 0) || (examineAll > 0)) {
// Optional hooks, used only when the pointers are non-null: publish the current
// iteration number and honour an external stop request.
397 if (fIPyCurrentIter) *fIPyCurrentIter = numit;
398 if (fExitFromTraining && *fExitFromTraining) break;
399 numChanged = 0;
400 if (examineAll) {
// Full sweep: examine every event and count the successful steps.
401 for (idIter = fInputData->begin(); idIter!=fInputData->end(); ++idIter){
402 if(!fdoRegression) numChanged += (UInt_t)ExamineExample(*idIter);
403 else numChanged += (UInt_t)ExamineExampleReg(*idIter);
404 }
405 }
406 else {
// Restricted sweep: only events for which IsInI0() is true. The sweep is cut
// short, with numChanged reset to 0, as soon as Terminated() holds.
407 for (idIter = fInputData->begin(); idIter!=fInputData->end(); ++idIter) {
408 if ((*idIter)->IsInI0()) {
409 if(!fdoRegression) numChanged += (UInt_t)ExamineExample(*idIter);
410 else numChanged += (UInt_t)ExamineExampleReg(*idIter);
411 if (Terminated()) {
412 numChanged = 0;
413 break;
414 }
415 }
416 }
417 }
418
// After a full sweep switch to restricted sweeps. Go back to a full sweep when
// fewer than 10 events changed or the change count has stalled for more than 3
// iterations. (The "numChanged == 0" clause is already covered by
// "numChanged < 10".)
419 if (examineAll == 1) examineAll = 0;
420 else if (numChanged == 0 || numChanged < 10 || deltaChanges > 3 ) examineAll = 1;
421
// deltaChanges counts consecutive iterations with an unchanged numChanged.
422 if (numChanged == numChangedOld) deltaChanges++;
423 else deltaChanges = 0;
424 numChangedOld = numChanged;
425 ++numit;
426
// Hard stop on the iteration limit, with a warning through the logger.
427 if (numit >= nMaxIter) {
428 *fLogger << kWARNING
429 << "Max number of iterations exceeded. "
430 << "Training may not be completed. Try use less Cost parameter" << Endl;
431 break;
432 }
433 }
434}
435
436////////////////////////////////////////////////////////////////////////////////
437
// Body of the index assignment. The signature line is not visible in this
// listing; TakeStep() calls SetIndex(ievt) and the index lists
// void SetIndex(TMVA::SVEvent*), with the parameter evidently named "event".
// Sets the event's index from its alpha, its Cweight bound and its type flag:
//    0 < alpha < Cweight           -> index  0
//    type +1 and alpha == 0        -> index +1
//    type +1 and alpha == Cweight  -> index -1
//    type -1 and alpha == 0        -> index -1
//    type -1 and alpha == Cweight  -> index +1
// If none of these conditions holds the index is left unchanged. The boundary
// tests use exact floating-point equality.
439{
440 if( (0< event->GetAlpha()) && (event->GetAlpha()< event->GetCweight()))
441 event->SetIdx(0);
442
443 if( event->GetTypeFlag() == 1){
444 if( event->GetAlpha() == 0)
445 event->SetIdx(1);
446 else if( event->GetAlpha() == event->GetCweight() )
447 event->SetIdx(-1);
448 }
449 if( event->GetTypeFlag() == -1){
450 if( event->GetAlpha() == 0)
451 event->SetIdx(-1);
452 else if( event->GetAlpha() == event->GetCweight() )
453 event->SetIdx(1);
454 }
455}
456
457////////////////////////////////////////////////////////////////////////////////
458
// Body of a function whose signature line is not visible in this listing
// (presumably a statistics/print routine -- TODO confirm its name).
// Counts the events whose alpha is non-zero.
// NOTE(review): the count is neither returned, stored nor printed in the code
// visible here, so the function has no observable effect.
460{
461 std::vector<TMVA::SVEvent*>::iterator idIter;
462 UInt_t counter = 0;
463 for( idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter)
464 if((*idIter)->GetAlpha() !=0) counter++;
465}
466
467////////////////////////////////////////////////////////////////////////////////
468
469std::vector<TMVA::SVEvent*>* TMVA::SVWorkingSet::GetSupportVectors()
470{
471 std::vector<TMVA::SVEvent*>::iterator idIter;
472 if( fSupVec != 0) {delete fSupVec; fSupVec = 0; }
473 fSupVec = new std::vector<TMVA::SVEvent*>(0);
474
475 for( idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
476 if((*idIter)->GetDeltaAlpha() !=0){
477 fSupVec->push_back((*idIter));
478 }
479 }
480 return fSupVec;
481}
482
483//for regression
484
// Body of the regression pair update. The signature line is not visible in this
// listing; ExamineExampleReg() calls TakeStepReg(ievt, jevt) and the index lists
// Bool_t TakeStepReg(SVEvent*, SVEvent*) -- TODO confirm parameter names.
// Works on four multipliers, alpha and alpha_p of each of the two events. It
// tries up to four pairings of them (cases A-D, each at most once), then, if any
// of the four values changed significantly, updates the error caches, stores the
// new values, refreshes fB_low / fB_up and returns kTRUE. Returns kFALSE when
// both arguments are the same event or nothing changed significantly.
486{
487 if (ievt == jevt) return kFALSE;
488 std::vector<TMVA::SVEvent*>::iterator idIter;
489 const Float_t epsilon = 0.001*fTolerance;//TODO
490
491 const Float_t kernel_II = fKMatrix->GetElement(ievt->GetNs(),ievt->GetNs());
492 const Float_t kernel_IJ = fKMatrix->GetElement(ievt->GetNs(),jevt->GetNs());
493 const Float_t kernel_JJ = fKMatrix->GetElement(jevt->GetNs(),jevt->GetNs());
494
495 //compute eta & gamma
// Sign convention here: eta = K_II + K_JJ - 2*K_IJ (opposite to TakeStep).
// gamma is the sum of the two events' delta-alphas on entry; it is const and so
// stays fixed through the loop below.
496 const Float_t eta = -2*kernel_IJ + kernel_II + kernel_JJ;
497 const Float_t gamma = ievt->GetDeltaAlpha() + jevt->GetDeltaAlpha();
498
499 //TODO CHECK WHAT IF ETA <0
500 //w.r.t Mercer's conditions it should never happen, but what if?
// NOTE(review): eta is used as a divisor below without any check for zero.
501
502 Bool_t caseA, caseB, caseC, caseD, terminated;
503 caseA = caseB = caseC = caseD = terminated = kFALSE;
504 Float_t b_alpha_i, b_alpha_j, b_alpha_i_p, b_alpha_j_p; //temporary Lagrange multipliers
505 const Float_t b_cost_i = ievt->GetCweight();
506 const Float_t b_cost_j = jevt->GetCweight();
507
// Working copies of the four multipliers; the events themselves are written
// only after the loop.
508 b_alpha_i = ievt->GetAlpha();
509 b_alpha_j = jevt->GetAlpha();
510 b_alpha_i_p = ievt->GetAlpha_p();
511 b_alpha_j_p = jevt->GetAlpha_p();
512
513 //calculate deltafi
// Difference of the two cached errors; it is not modified inside the loop.
514 Float_t deltafi = ievt->GetErrorCache()-jevt->GetErrorCache();
515
516 // main loop
// Each pass enters the first case whose flag is still unset and whose
// conditions hold, and then sets that flag. The loop ends when an entered case
// finds low >= high or when no case applies.
517 while(!terminated) {
518 const Float_t null = 0.; //!!! dummy float null declaration because of problems with TMath::Max/Min(Float_t, Float_t) function
519 Float_t low, high;
520 Float_t tmp_alpha_i, tmp_alpha_j;
521 tmp_alpha_i = tmp_alpha_j = 0.;
522
523 //TODO check whether these conditions are correct
524 if((caseA == kFALSE) && (b_alpha_i > 0 || (b_alpha_i_p == 0 && deltafi > 0)) && (b_alpha_j > 0 || (b_alpha_j_p == 0 && deltafi < 0)))
525 {
// Case A: adjust alpha_i and alpha_j.
526 //compute low, high w.r.t a_i, a_j
527 low = TMath::Max( null, gamma - b_cost_j );
528 high = TMath::Min( b_cost_i , gamma);
529
530 if(low<high){
// Step alpha_j, clip it to [low, high], and move alpha_i by the opposite amount.
531 tmp_alpha_j = b_alpha_j - (deltafi/eta);
532 tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
533 tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
534 tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j);
535
536 //update Li & Lj if change is significant (??)
537 if( IsDiffSignificant(b_alpha_j,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i,tmp_alpha_i, epsilon)){
538 b_alpha_j = tmp_alpha_j;
539 b_alpha_i = tmp_alpha_i;
540 }
541
542 }
543 else
544 terminated = kTRUE;
545
546 caseA = kTRUE;
547 }
548 else if((caseB==kFALSE) && (b_alpha_i>0 || (b_alpha_i_p==0 && deltafi >2*epsilon )) && (b_alpha_j_p>0 || (b_alpha_j==0 && deltafi>2*epsilon)))
549 {
// Case B: adjust alpha_i and alpha_j_p.
550 //compute LH w.r.t. a_i, a_j*
551 low = TMath::Max( null, gamma ); //TODO
552 high = TMath::Min( b_cost_i , b_cost_j + gamma);
553
554
555 if(low<high){
556 tmp_alpha_j = b_alpha_j_p - ((deltafi-2*epsilon)/eta);
557 tmp_alpha_j = TMath::Min(tmp_alpha_j,high);
558 tmp_alpha_j = TMath::Max(low,tmp_alpha_j);
559 tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j_p);
560
561 //update alphai alphaj_p
562 if( IsDiffSignificant(b_alpha_j_p,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i,tmp_alpha_i, epsilon)){
563 b_alpha_j_p = tmp_alpha_j;
564 b_alpha_i = tmp_alpha_i;
565 }
566 }
567 else
568 terminated = kTRUE;
569
570 caseB = kTRUE;
571 }
572 else if((caseC==kFALSE) && (b_alpha_i_p>0 || (b_alpha_i==0 && deltafi < -2*epsilon )) && (b_alpha_j>0 || (b_alpha_j_p==0 && deltafi< -2*epsilon)))
573 {
// Case C: adjust alpha_i_p and alpha_j.
574 //compute LH w.r.t. alphai_p alphaj
575 low = TMath::Max(null, -gamma );
576 high = TMath::Min(b_cost_i, -gamma+b_cost_j);
577
578 if(low<high){
579 tmp_alpha_j = b_alpha_j - ((deltafi+2*epsilon)/eta);
580 tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
581 tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
582 tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j);
583
584 //update alphai_p alphaj
585 if( IsDiffSignificant(b_alpha_j,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i_p,tmp_alpha_i, epsilon)){
586 b_alpha_j = tmp_alpha_j;
587 b_alpha_i_p = tmp_alpha_i;
588 }
589 }
590 else
591 terminated = kTRUE;
592
593 caseC = kTRUE;
594 }
595 else if((caseD == kFALSE) &&
596 (b_alpha_i_p>0 || (b_alpha_i==0 && deltafi <0 )) &&
597 (b_alpha_j_p>0 || (b_alpha_j==0 && deltafi >0 )))
598 {
// Case D: adjust alpha_i_p and alpha_j_p.
599 //compute LH w.r.t. alphai_p alphaj_p
600 low = TMath::Max(null,-gamma - b_cost_j);
601 high = TMath::Min(b_cost_i, -gamma);
602
603 if(low<high){
604 tmp_alpha_j = b_alpha_j_p + (deltafi/eta);
605 tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
606 tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
607 tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j_p);
608
609 if( IsDiffSignificant(b_alpha_j_p,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i_p,tmp_alpha_i, epsilon)){
610 b_alpha_j_p = tmp_alpha_j;
611 b_alpha_i_p = tmp_alpha_i;
612 }
613 }
614 else
615 terminated = kTRUE;
616
617 caseD = kTRUE;
618 }
619 else
620 terminated = kTRUE;
621 }
622 // TODO add a comment explaining how this was calculated
// NOTE(review): deltafi is not read again after this assignment in the code
// visible here, so the statement currently has no effect.
623 deltafi += ievt->GetDeltaAlpha()*(kernel_II - kernel_IJ) + jevt->GetDeltaAlpha()*(kernel_IJ - kernel_JJ);
624
// Commit only if at least one of the four multipliers changed significantly.
625 if( IsDiffSignificant(b_alpha_i, ievt->GetAlpha(), epsilon) ||
626 IsDiffSignificant(b_alpha_j, jevt->GetAlpha(), epsilon) ||
627 IsDiffSignificant(b_alpha_i_p, ievt->GetAlpha_p(), epsilon) ||
628 IsDiffSignificant(b_alpha_j_p, jevt->GetAlpha_p(), epsilon) ){
629
630 //TODO check if these conditions might be easier
631 //TODO write documentation for this
// NOTE(review): these expressions combine the stored delta-alpha, the new
// alpha_p and the stored alpha; the new alpha (b_alpha_i / b_alpha_j) does not
// appear in them -- confirm this is intended.
632 const Float_t diff_alpha_i = ievt->GetDeltaAlpha()+b_alpha_i_p - ievt->GetAlpha();
633 const Float_t diff_alpha_j = jevt->GetDeltaAlpha()+b_alpha_j_p - jevt->GetAlpha();
634
635 //update error cache
// Only events with index 0 are updated. k is incremented but never read.
636 Int_t k = 0;
637 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
638 k++;
639 //there will be some changes in Idx notation
640 if((*idIter)->GetIdx()==0){
641 Float_t k_ii = fKMatrix->GetElement(ievt->GetNs(), (*idIter)->GetNs());
642 Float_t k_jj = fKMatrix->GetElement(jevt->GetNs(), (*idIter)->GetNs());
643
644 (*idIter)->UpdateErrorCache(diff_alpha_i * k_ii + diff_alpha_j * k_jj);
645 }
646 }
647
648 //store new alphas in SVevents
649 ievt->SetAlpha(b_alpha_i);
650 jevt->SetAlpha(b_alpha_j);
651 ievt->SetAlpha_p(b_alpha_i_p);
652 jevt->SetAlpha_p(b_alpha_j_p);
653
654 //TODO update indexes
655
656 // recompute fB_low / fB_up and the events that define them
657
658 fB_low = -1*1e30;
659 fB_up =1e30;
660
// fB_low is the largest cached error among events not in I3; fB_up is the
// smallest cached error among events not in I2.
661 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
662 if((!(*idIter)->IsInI3()) && ((*idIter)->GetErrorCache()> fB_low)){
663 fB_low = (*idIter)->GetErrorCache();
664 fTEventLow = (*idIter);
665
666 }
667 if((!(*idIter)->IsInI2()) && ((*idIter)->GetErrorCache()< fB_up)){
668 fB_up =(*idIter)->GetErrorCache();
669 fTEventUp = (*idIter);
670 }
671 }
672 return kTRUE;
673 } else return kFALSE;
674}
675
676
677////////////////////////////////////////////////////////////////////////////////
678
// Body of the regression "examine" step. The signature line is not visible in
// this listing; from the TakeStepReg() call below and the call sites in the
// training loop this is presumably
// Bool_t TMVA::SVWorkingSet::ExamineExampleReg(SVEvent* jevt) -- TODO confirm.
// Obtains the error value of jevt, checks it (shifted by +/- feps depending on
// which of the sets I0a, I0b, I1, I2, I3 the event reports) against fB_low and
// fB_up, and if the 2*fTolerance margin is violated picks a partner event and
// calls TakeStepReg(). Returns kTRUE only when TakeStepReg() returns true.
680{
681 Float_t feps = 1e-7;// TODO check which value is the best
682 SVEvent* ievt=0;
683 Float_t fErrorC_J = 0.;
// Events in I0 use the cached error value as is.
684 if( jevt->IsInI0()) {
685 fErrorC_J = jevt->GetErrorCache();
686 }
687 else{
// Otherwise recompute it: target minus the sum of deltaAlpha * K(j,k) over all
// events, using the kernel row attached to jevt, and store it in the cache.
688 Float_t *fKVals = jevt->GetLine();
689 fErrorC_J = 0.;
690 std::vector<TMVA::SVEvent*>::iterator idIter;
691
// k runs in step with the iterator and indexes the kernel row.
692 UInt_t k=0;
693 for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
694 fErrorC_J -= (*idIter)->GetDeltaAlpha()*fKVals[k];
695 k++;
696 }
697
698 fErrorC_J += jevt->GetTarget();
699 jevt->SetErrorCache(fErrorC_J);
700
// The fresh value may tighten a threshold. Which threshold, and with which
// feps shift, depends on the set the event is in.
701 if(jevt->IsInI1()){
702 if(fErrorC_J + feps < fB_up ){
703 fB_up = fErrorC_J + feps;
704 fTEventUp = jevt;
705 }
706 else if(fErrorC_J -feps > fB_low) {
707 fB_low = fErrorC_J - feps;
708 fTEventLow = jevt;
709 }
710 }else if((jevt->IsInI2()) && (fErrorC_J + feps > fB_low)){
711 fB_low = fErrorC_J + feps;
712 fTEventLow = jevt;
713 }else if((jevt->IsInI3()) && (fErrorC_J - feps < fB_up)){
714 fB_up = fErrorC_J - feps;
715 fTEventUp = jevt;
716 }
717 }
718
// In cases 1-3 below, the side that fails the 2*fTolerance test supplies the
// default partner, and the partner is switched to the other threshold event if
// the gap on that side is larger.
719 Bool_t converged = kTRUE;
720 //case 1
721 if(jevt->IsInI0a()){
722 if( fB_low -fErrorC_J + feps > 2*fTolerance){
723 converged = kFALSE;
724 ievt = fTEventLow;
725 if(fErrorC_J-feps-fB_up > fB_low-fErrorC_J+feps){
726 ievt = fTEventUp;
727 }
728 }else if(fErrorC_J -feps - fB_up > 2*fTolerance){
729 converged = kFALSE;
730 ievt = fTEventUp;
731 if(fB_low - fErrorC_J+feps > fErrorC_J-feps -fB_up){
732 ievt = fTEventLow;
733 }
734 }
735 }
736
737 //case 2
738 if(jevt->IsInI0b()){
739 if( fB_low -fErrorC_J - feps > 2*fTolerance){
740 converged = kFALSE;
741 ievt = fTEventLow;
742 if(fErrorC_J+feps-fB_up > fB_low-fErrorC_J-feps){
743 ievt = fTEventUp;
744 }
745 }else if(fErrorC_J + feps - fB_up > 2*fTolerance){
746 converged = kFALSE;
747 ievt = fTEventUp;
748 if(fB_low - fErrorC_J-feps > fErrorC_J+feps -fB_up){
749 ievt = fTEventLow;
750 }
751 }
752 }
753
754 //case 3
755 if(jevt->IsInI1()){
756 if( fB_low -fErrorC_J - feps > 2*fTolerance){
757 converged = kFALSE;
758 ievt = fTEventLow;
759 if(fErrorC_J+feps-fB_up > fB_low-fErrorC_J-feps){
760 ievt = fTEventUp;
761 }
762 }else if(fErrorC_J - feps - fB_up > 2*fTolerance){
763 converged = kFALSE;
764 ievt = fTEventUp;
765 if(fB_low - fErrorC_J+feps > fErrorC_J-feps -fB_up){
766 ievt = fTEventLow;
767 }
768 }
769 }
770
// Cases 4 and 5 test one side only: I2 against fB_up, I3 against fB_low.
771 //case 4
772 if(jevt->IsInI2()){
773 if( fErrorC_J + feps -fB_up > 2*fTolerance){
774 converged = kFALSE;
775 ievt = fTEventUp;
776 }
777 }
778
779 //case 5
780 if(jevt->IsInI3()){
781 if(fB_low -fErrorC_J +feps > 2*fTolerance){
782 converged = kFALSE;
783 ievt = fTEventLow;
784 }
785 }
786
// Nothing violated: no step is attempted.
787 if(converged) return kFALSE;
788 if (TakeStepReg(ievt, jevt)) return kTRUE;
789 else return kFALSE;
790}
791
// Body of the significance test. The signature line is not visible in this
// listing; the index lists Bool_t IsDiffSignificant(Float_t, Float_t, Float_t)
// and the body names the arguments a_i, a_j and eps.
// Returns kTRUE when |a_i - a_j| exceeds eps * (a_i + a_j + eps), i.e. the
// difference is judged relative to the sum of the two values; kFALSE otherwise.
793{
794 if( TMath::Abs(a_i - a_j) > eps*(a_i + a_j + eps)) return kTRUE;
795 else return kFALSE;
796}
797
#define h(i)
Definition RSha256.hxx:106
#define e(i)
Definition RSha256.hxx:103
int Int_t
Definition RtypesCore.h:45
unsigned int UInt_t
Definition RtypesCore.h:46
const Bool_t kFALSE
Definition RtypesCore.h:92
float Float_t
Definition RtypesCore.h:57
const Bool_t kTRUE
Definition RtypesCore.h:91
ostringstream derivative to redirect and format output
Definition MsgLogger.h:59
Event class for Support Vector Machine.
Definition SVEvent.h:40
Bool_t IsInI0a() const
Definition SVEvent.h:74
Float_t GetTarget() const
Definition SVEvent.h:72
Int_t GetIdx() const
Definition SVEvent.h:68
Bool_t IsInI2() const
Definition SVEvent.h:78
Float_t GetErrorCache() const
Definition SVEvent.h:65
Float_t GetCweight() const
Definition SVEvent.h:71
Float_t * GetLine() const
Definition SVEvent.h:69
void SetAlpha_p(Float_t alpha)
Definition SVEvent.h:52
Bool_t IsInI3() const
Definition SVEvent.h:79
Float_t GetAlpha() const
Definition SVEvent.h:61
Bool_t IsInI0() const
Definition SVEvent.h:76
UInt_t GetNs() const
Definition SVEvent.h:70
Bool_t IsInI0b() const
Definition SVEvent.h:75
Float_t GetAlpha_p() const
Definition SVEvent.h:62
void SetAlpha(Float_t alpha)
Definition SVEvent.h:51
Bool_t IsInI1() const
Definition SVEvent.h:77
Int_t GetTypeFlag() const
Definition SVEvent.h:66
void SetErrorCache(Float_t err_cache)
Definition SVEvent.h:53
Float_t GetDeltaAlpha() const
Definition SVEvent.h:63
Kernel for Support Vector Machine.
Kernel matrix for Support Vector Machine.
Float_t * GetLine(UInt_t)
returns a row of the kernel matrix
Bool_t TakeStep(SVEvent *, SVEvent *)
void Train(UInt_t nIter=1000)
train the SVM
Bool_t IsDiffSignificant(Float_t, Float_t, Float_t)
Bool_t ExamineExample(SVEvent *)
Bool_t ExamineExampleReg(SVEvent *)
Bool_t TakeStepReg(SVEvent *, SVEvent *)
void SetIndex(TMVA::SVEvent *)
~SVWorkingSet()
destructor
SVWorkingSet()
constructor
std::vector< TMVA::SVEvent * > * fInputData
SVKernelMatrix * fKMatrix
std::vector< TMVA::SVEvent * > * GetSupportVectors()
Random number generator class based on M.
Definition TRandom3.h:27
virtual UInt_t Integer(UInt_t imax)
Returns a random integer uniformly distributed on the interval [ 0, imax-1 ].
Definition TRandom.cxx:360
TPaveText * pt
null_t< F > null()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:158
Short_t Max(Short_t a, Short_t b)
Definition TMathBase.h:212
Short_t Min(Short_t a, Short_t b)
Definition TMathBase.h:180
Short_t Abs(Short_t d)
Definition TMathBase.h:120
auto * l
Definition textangle.C:4
REAL epsilon
Definition triangle.c:617