Logo ROOT   6.14/05
Reference Guide
SVWorkingSet.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andrzej Zemla
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : SVWorkingSet *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation *
12  * *
13  * Authors (alphabetical): *
14  * Marcin Wolter <Marcin.Wolter@cern.ch> - IFJ PAN, Krakow, Poland *
15  * Andrzej Zemla <azemla@cern.ch> - IFJ PAN, Krakow, Poland *
16  * (IFJ PAN: Henryk Niewodniczanski Inst. Nucl. Physics, Krakow, Poland) *
17  * *
18  * Copyright (c) 2005: *
19  * CERN, Switzerland *
20  * MPI-K Heidelberg, Germany *
21  * PAN, Krakow, Poland *
22  * *
23  * Redistribution and use in source and binary forms, with or without *
24  * modification, are permitted according to the terms listed in LICENSE *
25  * (http://tmva.sourceforge.net/LICENSE) *
26  **********************************************************************************/
27 
28 /*! \class TMVA::SVWorkingSet
29 \ingroup TMVA
30 Working class for Support Vector Machine
31 */
32 
33 #include "TMVA/SVWorkingSet.h"
34 
35 #include "TMVA/MsgLogger.h"
36 #include "TMVA/SVEvent.h"
37 #include "TMVA/SVKernelFunction.h"
38 #include "TMVA/SVKernelMatrix.h"
39 #include "TMVA/Types.h"
40 
41 
42 #include "TMath.h"
43 #include "TRandom3.h"
44 
45 #include <iostream>
46 #include <vector>
47 
48 ////////////////////////////////////////////////////////////////////////////////
49 /// constructor
50 
52  : fdoRegression(kFALSE),
53  fInputData(0),
54  fSupVec(0),
55  fKFunction(0),
56  fKMatrix(0),
57  fTEventUp(0),
58  fTEventLow(0),
59  fB_low(1.),
60  fB_up(-1.),
61  fTolerance(0.01),
62  fLogger( new MsgLogger( "SVWorkingSet", kINFO ) )
63 {
64 }
65 
66 ////////////////////////////////////////////////////////////////////////////////
67 /// constructor
/// Builds the working set over `inputVectors`: computes the kernel matrix with
/// `kernelFunction`, attaches kernel-matrix row i and index i (SetLine/SetNs)
/// to event i and, in regression mode, seeds each event's error cache with its
/// target. `tol` is stored as fTolerance and `doreg` as fdoRegression.
/// For classification, fTEventLow / fTEventUp are seeded with randomly drawn
/// events whose type flag is -1 / +1 respectively.
/// NOTE(review): the two while(1) loops never terminate if the input holds no
/// event with type flag -1 (or +1); with empty input the at(kk) calls throw
/// std::out_of_range -- confirm callers guarantee both classes are present.
/// NOTE(review): fIPyCurrentIter / fExitFromTraining, read by Train(), are not
/// initialized in this constructor -- confirm the header gives them defaults.
68 
69 TMVA::SVWorkingSet::SVWorkingSet(std::vector<TMVA::SVEvent*>*inputVectors, SVKernelFunction* kernelFunction,
70  Float_t tol, Bool_t doreg)
71  : fdoRegression(doreg),
72  fInputData(inputVectors),
73  fSupVec(0),
74  fKFunction(kernelFunction),
75  fTEventUp(0),
76  fTEventLow(0),
77  fB_low(1.),
78  fB_up(-1.),
79  fTolerance(tol),
80  fLogger( new MsgLogger( "SVWorkingSet", kINFO ) )
81 {
82  fKMatrix = new TMVA::SVKernelMatrix(inputVectors, kernelFunction);
83  Float_t *pt;
84  for( UInt_t i = 0; i < fInputData->size(); i++){
85  pt = fKMatrix->GetLine(i);
86  fInputData->at(i)->SetLine(pt);
87  fInputData->at(i)->SetNs(i);
88  if(fdoRegression) fInputData->at(i)->SetErrorCache(fInputData->at(i)->GetTarget());
89  }
90  TRandom3 rand;
91  UInt_t kk = rand.Integer(fInputData->size());
// NOTE(review): the body of the regression branch (original lines 93-95) is
// missing from this listing.
92  if(fdoRegression) {
96  }
97  else{
98  while(1){
99  if(fInputData->at(kk)->GetTypeFlag()==-1){
100  fTEventLow = fInputData->at(kk);
101  break;
102  }
103  kk = rand.Integer(fInputData->size());
104  }
105 
106  while (1){
107  if (fInputData->at(kk)->GetTypeFlag()==1) {
108  fTEventUp = fInputData->at(kk);
109  break;
110  }
111  kk = rand.Integer(fInputData->size());
112  }
113  }
// NOTE(review): original lines 114-115 are missing from this listing.
116 }
117 
118 ////////////////////////////////////////////////////////////////////////////////
119 /// destructor
120 
122 {
123  if (fKMatrix != 0) {delete fKMatrix; fKMatrix = 0;}
124  delete fLogger;
125 }
126 
127 ////////////////////////////////////////////////////////////////////////////////
128 
130 {
131  SVEvent* ievt=0;
132  Float_t fErrorC_J = 0.;
133  if( jevt->GetIdx()==0) fErrorC_J = jevt->GetErrorCache();
134  else{
135  Float_t *fKVals = jevt->GetLine();
136  fErrorC_J = 0.;
137  std::vector<TMVA::SVEvent*>::iterator idIter;
138 
139  UInt_t k=0;
140  for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
141  if((*idIter)->GetAlpha()>0)
142  fErrorC_J += (*idIter)->GetAlpha()*(*idIter)->GetTypeFlag()*fKVals[k];
143  k++;
144  }
145 
146 
147  fErrorC_J -= jevt->GetTypeFlag();
148  jevt->SetErrorCache(fErrorC_J);
149 
150  if((jevt->GetIdx() == 1) && (fErrorC_J < fB_up )){
151  fB_up = fErrorC_J;
152  fTEventUp = jevt;
153  }
154  else if ((jevt->GetIdx() == -1)&&(fErrorC_J > fB_low)) {
155  fB_low = fErrorC_J;
156  fTEventLow = jevt;
157  }
158  }
159  Bool_t converged = kTRUE;
160 
161  if((jevt->GetIdx()>=0) && (fB_low - fErrorC_J > 2*fTolerance)) {
162  converged = kFALSE;
163  ievt = fTEventLow;
164  }
165 
166  if((jevt->GetIdx()<=0) && (fErrorC_J - fB_up > 2*fTolerance)) {
167  converged = kFALSE;
168  ievt = fTEventUp;
169  }
170 
171  if (converged) return kFALSE;
172 
173  if(jevt->GetIdx()==0){
174  if(fB_low - fErrorC_J > fErrorC_J - fB_up) ievt = fTEventLow;
175  else ievt = fTEventUp;
176  }
177 
178  if (TakeStep(ievt, jevt)) return kTRUE;
179  else return kFALSE;
180 }
181 
182 
183 ////////////////////////////////////////////////////////////////////////////////
184 
186 {
187  if (ievt == jevt) return kFALSE;
188  std::vector<TMVA::SVEvent*>::iterator idIter;
189  const Float_t epsilon = 1e-8; //make it 1-e6 or 1-e5 to make it faster
190 
191  Float_t type_I, type_J;
192  Float_t errorC_I, errorC_J;
193  Float_t alpha_I, alpha_J;
194 
195  Float_t newAlpha_I, newAlpha_J;
196  Int_t s;
197 
198  Float_t l, h, lobj = 0, hobj = 0;
199  Float_t eta;
200 
201  type_I = ievt->GetTypeFlag();
202  alpha_I = ievt->GetAlpha();
203  errorC_I = ievt->GetErrorCache();
204 
205  type_J = jevt->GetTypeFlag();
206  alpha_J = jevt->GetAlpha();
207  errorC_J = jevt->GetErrorCache();
208 
209  s = Int_t( type_I * type_J );
210 
211  Float_t c_i = ievt->GetCweight();
212 
213  Float_t c_j = jevt->GetCweight();
214 
215  // compute l, h objective function
216 
217  if (type_I == type_J) {
218  Float_t gamma = alpha_I + alpha_J;
219 
220  if ( c_i > c_j ) {
221  if ( gamma < c_j ) {
222  l = 0;
223  h = gamma;
224  }
225  else{
226  h = c_j;
227  if ( gamma < c_i )
228  l = 0;
229  else
230  l = gamma - c_i;
231  }
232  }
233  else {
234  if ( gamma < c_i ){
235  l = 0;
236  h = gamma;
237  }
238  else {
239  l = gamma - c_i;
240  if ( gamma < c_j )
241  h = gamma;
242  else
243  h = c_j;
244  }
245  }
246  }
247  else {
248  Float_t gamma = alpha_I - alpha_J;
249  if (gamma > 0) {
250  l = 0;
251  if ( gamma >= (c_i - c_j) )
252  h = c_i - gamma;
253  else
254  h = c_j;
255  }
256  else {
257  l = -gamma;
258  if ( (c_i - c_j) >= gamma)
259  h = c_j;
260  else
261  h = c_i - gamma;
262  }
263  }
264 
265  if (l == h) return kFALSE;
266  Float_t kernel_II, kernel_IJ, kernel_JJ;
267 
268  kernel_II = fKMatrix->GetElement(ievt->GetNs(),ievt->GetNs());
269  kernel_IJ = fKMatrix->GetElement(ievt->GetNs(), jevt->GetNs());
270  kernel_JJ = fKMatrix->GetElement(jevt->GetNs(),jevt->GetNs());
271 
272  eta = 2*kernel_IJ - kernel_II - kernel_JJ;
273  if (eta < 0) {
274  newAlpha_J = alpha_J + (type_J*( errorC_J - errorC_I ))/eta;
275  if (newAlpha_J < l) newAlpha_J = l;
276  else if (newAlpha_J > h) newAlpha_J = h;
277 
278  }
279 
280  else {
281 
282  Float_t c_I = eta/2;
283  Float_t c_J = type_J*( errorC_I - errorC_J ) - eta * alpha_J;
284  lobj = c_I * l * l + c_J * l;
285  hobj = c_I * h * h + c_J * h;
286 
287  if (lobj > hobj + epsilon) newAlpha_J = l;
288  else if (lobj < hobj - epsilon) newAlpha_J = h;
289  else newAlpha_J = alpha_J;
290  }
291 
292  if (TMath::Abs( newAlpha_J - alpha_J ) < ( epsilon * ( newAlpha_J + alpha_J+ epsilon ))){
293  return kFALSE;
294  //it spends here to much time... it is stupido
295  }
296  newAlpha_I = alpha_I - s*( newAlpha_J - alpha_J );
297 
298  if (newAlpha_I < 0) {
299  newAlpha_J += s* newAlpha_I;
300  newAlpha_I = 0;
301  }
302  else if (newAlpha_I > c_i) {
303  Float_t temp = newAlpha_I - c_i;
304  newAlpha_J += s * temp;
305  newAlpha_I = c_i;
306  }
307 
308  Float_t dL_I = type_I * ( newAlpha_I - alpha_I );
309  Float_t dL_J = type_J * ( newAlpha_J - alpha_J );
310 
311  Int_t k = 0;
312  for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
313  k++;
314  if((*idIter)->GetIdx()==0){
315  Float_t ii = fKMatrix->GetElement(ievt->GetNs(), (*idIter)->GetNs());
316  Float_t jj = fKMatrix->GetElement(jevt->GetNs(), (*idIter)->GetNs());
317 
318  (*idIter)->UpdateErrorCache(dL_I * ii + dL_J * jj);
319  }
320  }
321  ievt->SetAlpha(newAlpha_I);
322  jevt->SetAlpha(newAlpha_J);
323  // set new indexes
324  SetIndex(ievt);
325  SetIndex(jevt);
326 
327  // update error cache
328  ievt->SetErrorCache(errorC_I + dL_I*kernel_II + dL_J*kernel_IJ);
329  jevt->SetErrorCache(errorC_J + dL_I*kernel_IJ + dL_J*kernel_JJ);
330 
331  // compute fI_low, fB_low
332 
333  fB_low = -1*1e30;
334  fB_up = 1e30;
335 
336  for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
337  if((*idIter)->GetIdx()==0){
338  if((*idIter)->GetErrorCache()> fB_low){
339  fB_low = (*idIter)->GetErrorCache();
340  fTEventLow = (*idIter);
341  }
342  if( (*idIter)->GetErrorCache()< fB_up){
343  fB_up =(*idIter)->GetErrorCache();
344  fTEventUp = (*idIter);
345  }
346  }
347  }
348 
349  // for optimized alfa's
350  if (fB_low < TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
351  if (ievt->GetErrorCache() > fB_low) {
352  fB_low = ievt->GetErrorCache();
353  fTEventLow = ievt;
354  }
355  else {
356  fB_low = jevt->GetErrorCache();
357  fTEventLow = jevt;
358  }
359  }
360 
361  if (fB_up > TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
362  if (ievt->GetErrorCache()< fB_low) {
363  fB_up =ievt->GetErrorCache();
364  fTEventUp = ievt;
365  }
366  else {
367  fB_up =jevt->GetErrorCache() ;
368  fTEventUp = jevt;
369  }
370  }
371  return kTRUE;
372 }
373 
374 ////////////////////////////////////////////////////////////////////////////////
375 
377 {
378  if((fB_up > fB_low - 2*fTolerance)) return kTRUE;
379  return kFALSE;
380 }
381 
382 ////////////////////////////////////////////////////////////////////////////////
383 /// train the SVM
384 
386 {
387 
388  Int_t numChanged = 0;
389  Int_t examineAll = 1;
390 
391  Float_t numChangedOld = 0;
392  Int_t deltaChanges = 0;
393  UInt_t numit = 0;
394 
395  std::vector<TMVA::SVEvent*>::iterator idIter;
396 
397  while ((numChanged > 0) || (examineAll > 0)) {
398  if (fIPyCurrentIter) *fIPyCurrentIter = numit;
399  if (fExitFromTraining && *fExitFromTraining) break;
400  numChanged = 0;
401  if (examineAll) {
402  for (idIter = fInputData->begin(); idIter!=fInputData->end(); ++idIter){
403  if(!fdoRegression) numChanged += (UInt_t)ExamineExample(*idIter);
404  else numChanged += (UInt_t)ExamineExampleReg(*idIter);
405  }
406  }
407  else {
408  for (idIter = fInputData->begin(); idIter!=fInputData->end(); ++idIter) {
409  if ((*idIter)->IsInI0()) {
410  if(!fdoRegression) numChanged += (UInt_t)ExamineExample(*idIter);
411  else numChanged += (UInt_t)ExamineExampleReg(*idIter);
412  if (Terminated()) {
413  numChanged = 0;
414  break;
415  }
416  }
417  }
418  }
419 
420  if (examineAll == 1) examineAll = 0;
421  else if (numChanged == 0 || numChanged < 10 || deltaChanges > 3 ) examineAll = 1;
422 
423  if (numChanged == numChangedOld) deltaChanges++;
424  else deltaChanges = 0;
425  numChangedOld = numChanged;
426  ++numit;
427 
428  if (numit >= nMaxIter) {
429  *fLogger << kWARNING
430  << "Max number of iterations exceeded. "
431  << "Training may not be completed. Try use less Cost parameter" << Endl;
432  break;
433  }
434  }
435 }
436 
437 ////////////////////////////////////////////////////////////////////////////////
438 
440 {
441  if( (0< event->GetAlpha()) && (event->GetAlpha()< event->GetCweight()))
442  event->SetIdx(0);
443 
444  if( event->GetTypeFlag() == 1){
445  if( event->GetAlpha() == 0)
446  event->SetIdx(1);
447  else if( event->GetAlpha() == event->GetCweight() )
448  event->SetIdx(-1);
449  }
450  if( event->GetTypeFlag() == -1){
451  if( event->GetAlpha() == 0)
452  event->SetIdx(-1);
453  else if( event->GetAlpha() == event->GetCweight() )
454  event->SetIdx(1);
455  }
456 }
457 
458 ////////////////////////////////////////////////////////////////////////////////
459 
461 {
462  std::vector<TMVA::SVEvent*>::iterator idIter;
463  UInt_t counter = 0;
464  for( idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter)
465  if((*idIter)->GetAlpha() !=0) counter++;
466 }
467 
468 ////////////////////////////////////////////////////////////////////////////////
469 
470 std::vector<TMVA::SVEvent*>* TMVA::SVWorkingSet::GetSupportVectors()
471 {
472  std::vector<TMVA::SVEvent*>::iterator idIter;
473  if( fSupVec != 0) {delete fSupVec; fSupVec = 0; }
474  fSupVec = new std::vector<TMVA::SVEvent*>(0);
475 
476  for( idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
477  if((*idIter)->GetDeltaAlpha() !=0){
478  fSupVec->push_back((*idIter));
479  }
480  }
481  return fSupVec;
482 }
483 
484 //for regression
485 
////////////////////////////////////////////////////////////////////////////////
/// Regression step: jointly update the multipliers (alpha, alpha_p) of the
/// event pair (ievt, jevt).
///
/// The while loop visits up to four cases, each at most once:
///   A: alpha_i / alpha_j,  B: alpha_i / alpha_j_p,
///   C: alpha_i_p / alpha_j, D: alpha_i_p / alpha_j_p.
/// In each case a candidate for the second multiplier is computed from deltafi
/// and eta, clipped to [low, high], and accepted together with the matching
/// first multiplier only if IsDiffSignificant() reports a relevant change.
/// The loop stops when a case has an empty segment (low >= high) or when no
/// remaining case applies.
///
/// \return kTRUE if any of the four multipliers changed significantly; then
///         the error caches of index-0 events are updated, the new multipliers
///         are stored and fB_low / fB_up are recomputed. kFALSE otherwise
///         (including ievt == jevt).
/// NOTE(review): eta is used as a divisor without a check; eta == 0 would give
/// inf/NaN candidates (see the TODO below).
487 {
488  if (ievt == jevt) return kFALSE;
489  std::vector<TMVA::SVEvent*>::iterator idIter;
490  const Float_t epsilon = 0.001*fTolerance;//TODO
491 
492  const Float_t kernel_II = fKMatrix->GetElement(ievt->GetNs(),ievt->GetNs());
493  const Float_t kernel_IJ = fKMatrix->GetElement(ievt->GetNs(),jevt->GetNs());
494  const Float_t kernel_JJ = fKMatrix->GetElement(jevt->GetNs(),jevt->GetNs());
495 
496  //compute eta & gamma
497  const Float_t eta = -2*kernel_IJ + kernel_II + kernel_JJ;
498  const Float_t gamma = ievt->GetDeltaAlpha() + jevt->GetDeltaAlpha();
499 
500  //TODO CHECK WHAT IF ETA <0
501  //w.r.t Mercer's conditions it should never happen, but what if?
502 
503  Bool_t caseA, caseB, caseC, caseD, terminated;
504  caseA = caseB = caseC = caseD = terminated = kFALSE;
505  Float_t b_alpha_i, b_alpha_j, b_alpha_i_p, b_alpha_j_p; //temporary Lagrange multipliers
506  const Float_t b_cost_i = ievt->GetCweight();
507  const Float_t b_cost_j = jevt->GetCweight();
508 
509  b_alpha_i = ievt->GetAlpha();
510  b_alpha_j = jevt->GetAlpha();
511  b_alpha_i_p = ievt->GetAlpha_p();
512  b_alpha_j_p = jevt->GetAlpha_p();
513 
514  //calculate deltafi
515  Float_t deltafi = ievt->GetErrorCache()-jevt->GetErrorCache();
516 
517  // main loop
518  while(!terminated) {
519  const Float_t null = 0.; //!!! dummy float null declaration because of problems with TMath::Max/Min(Float_t, Float_t) function
520  Float_t low, high;
521  Float_t tmp_alpha_i, tmp_alpha_j;
522  tmp_alpha_i = tmp_alpha_j = 0.;
523 
524  //TODO check this conditions, are they proper
525  if((caseA == kFALSE) && (b_alpha_i > 0 || (b_alpha_i_p == 0 && deltafi > 0)) && (b_alpha_j > 0 || (b_alpha_j_p == 0 && deltafi < 0)))
526  {
527  //compute low, high w.r.t a_i, a_j
528  low = TMath::Max( null, gamma - b_cost_j );
529  high = TMath::Min( b_cost_i , gamma);
530 
531  if(low<high){
532  tmp_alpha_j = b_alpha_j - (deltafi/eta);
533  tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
534  tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
535  tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j);
536 
537  //update Li & Lj if change is significant (??)
538  if( IsDiffSignificant(b_alpha_j,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i,tmp_alpha_i, epsilon)){
539  b_alpha_j = tmp_alpha_j;
540  b_alpha_i = tmp_alpha_i;
541  }
542 
543  }
544  else
545  terminated = kTRUE;
546 
547  caseA = kTRUE;
548  }
549  else if((caseB==kFALSE) && (b_alpha_i>0 || (b_alpha_i_p==0 && deltafi >2*epsilon )) && (b_alpha_j_p>0 || (b_alpha_j==0 && deltafi>2*epsilon)))
550  {
551  //compute LH w.r.t. a_i, a_j*
552  low = TMath::Max( null, gamma ); //TODO
553  high = TMath::Min( b_cost_i , b_cost_j + gamma);
554 
555 
556  if(low<high){
557  tmp_alpha_j = b_alpha_j_p - ((deltafi-2*epsilon)/eta);
558  tmp_alpha_j = TMath::Min(tmp_alpha_j,high);
559  tmp_alpha_j = TMath::Max(low,tmp_alpha_j);
560  tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j_p);
561 
562  //update alphai alphaj_p
563  if( IsDiffSignificant(b_alpha_j_p,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i,tmp_alpha_i, epsilon)){
564  b_alpha_j_p = tmp_alpha_j;
565  b_alpha_i = tmp_alpha_i;
566  }
567  }
568  else
569  terminated = kTRUE;
570 
571  caseB = kTRUE;
572  }
573  else if((caseC==kFALSE) && (b_alpha_i_p>0 || (b_alpha_i==0 && deltafi < -2*epsilon )) && (b_alpha_j>0 || (b_alpha_j_p==0 && deltafi< -2*epsilon)))
574  {
575  //compute LH w.r.t. alphai_p alphaj
576  low = TMath::Max(null, -gamma );
577  high = TMath::Min(b_cost_i, -gamma+b_cost_j);
578 
579  if(low<high){
580  tmp_alpha_j = b_alpha_j - ((deltafi+2*epsilon)/eta);
581  tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
582  tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
583  tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j);
584 
585  //update alphai_p alphaj
586  if( IsDiffSignificant(b_alpha_j,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i_p,tmp_alpha_i, epsilon)){
587  b_alpha_j = tmp_alpha_j;
588  b_alpha_i_p = tmp_alpha_i;
589  }
590  }
591  else
592  terminated = kTRUE;
593 
594  caseC = kTRUE;
595  }
596  else if((caseD == kFALSE) &&
597  (b_alpha_i_p>0 || (b_alpha_i==0 && deltafi <0 )) &&
598  (b_alpha_j_p>0 || (b_alpha_j==0 && deltafi >0 )))
599  {
600  //compute LH w.r.t. alphai_p alphaj_p
601  low = TMath::Max(null,-gamma - b_cost_j);
602  high = TMath::Min(b_cost_i, -gamma);
603 
604  if(low<high){
605  tmp_alpha_j = b_alpha_j_p + (deltafi/eta);
606  tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
607  tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
608  tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j_p);
609 
610  if( IsDiffSignificant(b_alpha_j_p,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i_p,tmp_alpha_i, epsilon)){
611  b_alpha_j_p = tmp_alpha_j;
612  b_alpha_i_p = tmp_alpha_i;
613  }
614  }
615  else
616  terminated = kTRUE;
617 
618  caseD = kTRUE;
619  }
620  else
621  terminated = kTRUE;
622  }
623  // TODO ad commment how it was calculated
// NOTE(review): deltafi is never read after this statement, so this update has
// no effect; it was presumably meant to run inside the loop after each case
// (the cases above all use deltafi) -- confirm before changing.
624  deltafi += ievt->GetDeltaAlpha()*(kernel_II - kernel_IJ) + jevt->GetDeltaAlpha()*(kernel_IJ - kernel_JJ);
625 
626  if( IsDiffSignificant(b_alpha_i, ievt->GetAlpha(), epsilon) ||
627  IsDiffSignificant(b_alpha_j, jevt->GetAlpha(), epsilon) ||
628  IsDiffSignificant(b_alpha_i_p, ievt->GetAlpha_p(), epsilon) ||
629  IsDiffSignificant(b_alpha_j_p, jevt->GetAlpha_p(), epsilon) ){
630 
631  //TODO check if these conditions might be easier
632  //TODO write documentation for this
// NOTE(review): b_alpha_i / b_alpha_j do not enter diff_alpha_i / diff_alpha_j.
// If GetDeltaAlpha() is alpha - alpha_p, the expressions reduce to
// b_alpha_*_p - GetAlpha_p(), so a change of alpha alone leaves the error caches
// untouched; the intended expression is presumably
// GetDeltaAlpha() + b_alpha_*_p - b_alpha_* -- verify against SVEvent.
633  const Float_t diff_alpha_i = ievt->GetDeltaAlpha()+b_alpha_i_p - ievt->GetAlpha();
634  const Float_t diff_alpha_j = jevt->GetDeltaAlpha()+b_alpha_j_p - jevt->GetAlpha();
635 
636  //update error cache
// NOTE(review): k is incremented but never used. Only events with GetIdx()==0
// are updated, and this function never calls SetIndex() (see "TODO update
// Idexes" below) -- confirm the indices of regression events are maintained.
637  Int_t k = 0;
638  for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
639  k++;
640  //there will be some changes in Idx notation
641  if((*idIter)->GetIdx()==0){
642  Float_t k_ii = fKMatrix->GetElement(ievt->GetNs(), (*idIter)->GetNs());
643  Float_t k_jj = fKMatrix->GetElement(jevt->GetNs(), (*idIter)->GetNs());
644 
645  (*idIter)->UpdateErrorCache(diff_alpha_i * k_ii + diff_alpha_j * k_jj);
646  }
647  }
648 
649  //store new alphas in SVevents
650  ievt->SetAlpha(b_alpha_i);
651  jevt->SetAlpha(b_alpha_j);
652  ievt->SetAlpha_p(b_alpha_i_p);
653  jevt->SetAlpha_p(b_alpha_j_p);
654 
655  //TODO update Idexes
656 
657  // compute fI_low, fB_low
// fB_low: largest error over events not in I3; fB_up: smallest over events not in I2
658 
659  fB_low = -1*1e30;
660  fB_up =1e30;
661 
662  for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
663  if((!(*idIter)->IsInI3()) && ((*idIter)->GetErrorCache()> fB_low)){
664  fB_low = (*idIter)->GetErrorCache();
665  fTEventLow = (*idIter);
666 
667  }
668  if((!(*idIter)->IsInI2()) && ((*idIter)->GetErrorCache()< fB_up)){
669  fB_up =(*idIter)->GetErrorCache();
670  fTEventUp = (*idIter);
671  }
672  }
673  return kTRUE;
674  } else return kFALSE;
675 }
676 
677 
678 ////////////////////////////////////////////////////////////////////////////////
/// Regression counterpart of ExamineExample(): check whether `jevt` violates
/// the optimality conditions and, if so, call TakeStepReg() with a partner.
///
/// If jevt->IsInI0() the cached error is used. Otherwise the error is
/// recomputed as target - sum_k GetDeltaAlpha()_k * K(k, j), stored in the
/// event, and used to update fB_up / fB_low (shifted by +/- feps) according to
/// whether the event is in I1, I2 or I3. Cases 1-5 then compare the (shifted)
/// error with fB_low / fB_up against 2*fTolerance to pick the partner event.
///
/// \return kFALSE if no violation is found or TakeStepReg() makes no progress,
///         kTRUE if TakeStepReg() changed the multipliers.
679 
681 {
682  Float_t feps = 1e-7;// TODO check which value is the best
683  SVEvent* ievt=0;
684  Float_t fErrorC_J = 0.;
685  if( jevt->IsInI0()) {
686  fErrorC_J = jevt->GetErrorCache();
687  }
688  else{
689  Float_t *fKVals = jevt->GetLine();
690  fErrorC_J = 0.;
691  std::vector<TMVA::SVEvent*>::iterator idIter;
692 
693  UInt_t k=0;
694  for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
695  fErrorC_J -= (*idIter)->GetDeltaAlpha()*fKVals[k];
696  k++;
697  }
698 
699  fErrorC_J += jevt->GetTarget();
700  jevt->SetErrorCache(fErrorC_J);
701 
// NOTE(review): because of the else-if below, an I1 event that lowers fB_up is
// never considered for fB_low in the same call -- confirm this is intended.
702  if(jevt->IsInI1()){
703  if(fErrorC_J + feps < fB_up ){
704  fB_up = fErrorC_J + feps;
705  fTEventUp = jevt;
706  }
707  else if(fErrorC_J -feps > fB_low) {
708  fB_low = fErrorC_J - feps;
709  fTEventLow = jevt;
710  }
711  }else if((jevt->IsInI2()) && (fErrorC_J + feps > fB_low)){
712  fB_low = fErrorC_J + feps;
713  fTEventLow = jevt;
714  }else if((jevt->IsInI3()) && (fErrorC_J - feps < fB_up)){
715  fB_up = fErrorC_J - feps;
716  fTEventUp = jevt;
717  }
718  }
719 
720  Bool_t converged = kTRUE;
721  //case 1
722  if(jevt->IsInI0a()){
723  if( fB_low -fErrorC_J + feps > 2*fTolerance){
724  converged = kFALSE;
725  ievt = fTEventLow;
726  if(fErrorC_J-feps-fB_up > fB_low-fErrorC_J+feps){
727  ievt = fTEventUp;
728  }
729  }else if(fErrorC_J -feps - fB_up > 2*fTolerance){
730  converged = kFALSE;
731  ievt = fTEventUp;
732  if(fB_low - fErrorC_J+feps > fErrorC_J-feps -fB_up){
733  ievt = fTEventLow;
734  }
735  }
736  }
737 
738  //case 2
739  if(jevt->IsInI0b()){
740  if( fB_low -fErrorC_J - feps > 2*fTolerance){
741  converged = kFALSE;
742  ievt = fTEventLow;
743  if(fErrorC_J+feps-fB_up > fB_low-fErrorC_J-feps){
744  ievt = fTEventUp;
745  }
746  }else if(fErrorC_J + feps - fB_up > 2*fTolerance){
747  converged = kFALSE;
748  ievt = fTEventUp;
749  if(fB_low - fErrorC_J-feps > fErrorC_J+feps -fB_up){
750  ievt = fTEventLow;
751  }
752  }
753  }
754 
755  //case 3
// NOTE(review): unlike cases 1 and 2, case 3 mixes the feps shifts: the fB_low
// test uses fErrorC_J + feps, the fB_up test uses fErrorC_J - feps, and the
// inner tie-break comparisons do not always use the same shift as the outer
// test -- confirm against the reference algorithm.
756  if(jevt->IsInI1()){
757  if( fB_low -fErrorC_J - feps > 2*fTolerance){
758  converged = kFALSE;
759  ievt = fTEventLow;
760  if(fErrorC_J+feps-fB_up > fB_low-fErrorC_J-feps){
761  ievt = fTEventUp;
762  }
763  }else if(fErrorC_J - feps - fB_up > 2*fTolerance){
764  converged = kFALSE;
765  ievt = fTEventUp;
766  if(fB_low - fErrorC_J+feps > fErrorC_J-feps -fB_up){
767  ievt = fTEventLow;
768  }
769  }
770  }
771 
772  //case 4
773  if(jevt->IsInI2()){
774  if( fErrorC_J + feps -fB_up > 2*fTolerance){
775  converged = kFALSE;
776  ievt = fTEventUp;
777  }
778  }
779 
780  //case 5
781  if(jevt->IsInI3()){
782  if(fB_low -fErrorC_J +feps > 2*fTolerance){
783  converged = kFALSE;
784  ievt = fTEventLow;
785  }
786  }
787 
788  if(converged) return kFALSE;
789  if (TakeStepReg(ievt, jevt)) return kTRUE;
790  else return kFALSE;
791 }
792 
794 {
795  if( TMath::Abs(a_i - a_j) > eps*(a_i + a_j + eps)) return kTRUE;
796  else return kFALSE;
797 }
798 
Random number generator class based on M.
Definition: TRandom3.h:27
Bool_t IsInI1() const
Definition: SVEvent.h:77
Kernel for Support Vector Machine.
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Float_t * GetLine(UInt_t)
returns a row of the kernel matrix
void SetAlpha(Float_t alpha)
Definition: SVEvent.h:51
Int_t GetIdx() const
Definition: SVEvent.h:68
~SVWorkingSet()
destructor
float Float_t
Definition: RtypesCore.h:53
SVWorkingSet()
constructor
Bool_t IsInI0() const
Definition: SVEvent.h:76
std::vector< TMVA::SVEvent * > * fSupVec
Definition: SVWorkingSet.h:75
void SetIndex(TMVA::SVEvent *)
void Train(UInt_t nIter=1000)
train the SVM
Bool_t IsInI0b() const
Definition: SVEvent.h:75
Short_t Min(Short_t a, Short_t b)
Definition: TMathBase.h:168
int Int_t
Definition: RtypesCore.h:41
Bool_t TakeStep(SVEvent *, SVEvent *)
bool Bool_t
Definition: RtypesCore.h:59
Float_t GetCweight() const
Definition: SVEvent.h:71
Short_t Abs(Short_t d)
Definition: TMathBase.h:108
Kernel matrix for Support Vector Machine.
UInt_t GetNs() const
Definition: SVEvent.h:70
MsgLogger * fLogger
Definition: SVWorkingSet.h:86
null_t< F > null()
void SetErrorCache(Float_t err_cache)
Definition: SVEvent.h:53
Float_t * GetLine() const
Definition: SVEvent.h:69
Bool_t IsInI3() const
Definition: SVEvent.h:79
Float_t GetElement(UInt_t i, UInt_t j)
returns an element of the kernel matrix
SVKernelFunction * fKFunction
Definition: SVWorkingSet.h:76
virtual UInt_t Integer(UInt_t imax)
Returns a random integer on [ 0, imax-1 ].
Definition: TRandom.cxx:341
Float_t GetAlpha_p() const
Definition: SVEvent.h:62
Bool_t ExamineExampleReg(SVEvent *)
bool * fExitFromTraining
Definition: SVWorkingSet.h:90
Float_t GetErrorCache() const
Definition: SVEvent.h:65
double gamma(double x)
TPaveText * pt
Int_t GetTypeFlag() const
Definition: SVEvent.h:66
unsigned int UInt_t
Definition: RtypesCore.h:42
Float_t GetDeltaAlpha() const
Definition: SVEvent.h:63
Bool_t IsInI0a() const
Definition: SVEvent.h:74
Event class for Support Vector Machine.
Definition: SVEvent.h:40
Bool_t ExamineExample(SVEvent *)
Float_t GetAlpha() const
Definition: SVEvent.h:61
REAL epsilon
Definition: triangle.c:617
#define h(i)
Definition: RSha256.hxx:106
const Bool_t kFALSE
Definition: RtypesCore.h:88
void SetIdx(Int_t idx)
Definition: SVEvent.h:56
Bool_t TakeStepReg(SVEvent *, SVEvent *)
std::vector< TMVA::SVEvent * > * GetSupportVectors()
Float_t GetTarget() const
Definition: SVEvent.h:72
static constexpr double s
you should not use this method at all Int_t Int_t Double_t Double_t Double_t e
Definition: TRolke.cxx:630
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
Bool_t IsInI2() const
Definition: SVEvent.h:78
auto * l
Definition: textangle.C:4
SVKernelMatrix * fKMatrix
Definition: SVWorkingSet.h:77
Bool_t IsDiffSignificant(Float_t, Float_t, Float_t)
Short_t Max(Short_t a, Short_t b)
Definition: TMathBase.h:200
void SetAlpha_p(Float_t alpha)
Definition: SVEvent.h:52
std::vector< TMVA::SVEvent * > * fInputData
Definition: SVWorkingSet.h:74
UInt_t * fIPyCurrentIter
message logger
Definition: SVWorkingSet.h:89
const Bool_t kTRUE
Definition: RtypesCore.h:87
SVEvent * fTEventUp
Definition: SVWorkingSet.h:79
SVEvent * fTEventLow
Definition: SVWorkingSet.h:80