Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
Objectives.hxx
Go to the documentation of this file.
1/**********************************************************************************
2 * Project: ROOT - a Root-integrated toolkit for multivariate data analysis *
3 * Package: TMVA *
4 * Web : http://tmva.sourceforge.net *
5 * *
6 * Description: *
7 * *
8 * Authors: *
9 * Stefan Wunsch (stefan.wunsch@cern.ch) *
10 * Luca Zampieri (luca.zampieri@alumni.epfl.ch) *
11 * *
12 * Copyright (c) 2019: *
13 * CERN, Switzerland *
14 * *
15 * Redistribution and use in source and binary forms, with or without *
16 * modification, are permitted according to the terms listed in LICENSE *
17 * (http://tmva.sourceforge.net/LICENSE) *
18 **********************************************************************************/
19
20#ifndef TMVA_TREEINFERENCE_OBJECTIVES
21#define TMVA_TREEINFERENCE_OBJECTIVES
22
23#include <string>
24#include <stdexcept>
25#include <cmath> // std::exp
26#include <functional> // std::function
27
namespace TMVA {
namespace Experimental {
namespace Objectives {

/// Logistic function f(x) = 1 / (1 + exp(-x))
///
/// The constants are double literals, so for T = float the expression is
/// evaluated in double precision. The result is converted back to T on return.
template <typename T>
inline T Logistic(T value)
{
   return 1.0 / (1.0 + std::exp(-1.0 * value));
}

/// Identity function f(x) = x
///
/// Returns its argument unchanged.
template <typename T>
inline T Identity(T value)
{
   return value;
}

/// Natural exponential function f(x) = exp(x)
///
/// This objective is used for the softmax objective in the multiclass
/// case with the formula exp(x)/sum(exp(x)) and the vector x. Only the
/// element-wise exp(x) is computed here. The normalization by the sum is
/// not part of this function (presumably applied by the caller; confirm).
template <typename T>
inline T Exponential(T value)
{
   return std::exp(value);
}

/// Look up the objective function implementation matching a name.
///
/// \tparam T   Argument and result type of the returned function.
/// \param name One of "identity", "logistic" or "softmax". The name "softmax"
///             maps to Exponential<T>, i.e. only the element-wise exp(x)
///             part of the softmax.
/// \return A std::function wrapping the selected implementation.
/// \throws std::runtime_error if no objective with the given name exists.
template <typename T>
std::function<T(T)> GetFunction(const std::string &name)
{
   if (name == "identity")
      return std::function<T(T)>(Identity<T>);
   if (name == "logistic")
      return std::function<T(T)>(Logistic<T>);
   if (name == "softmax")
      return std::function<T(T)>(Exponential<T>);
   throw std::runtime_error("Objective function with name \"" + name + "\" is not implemented.");
}

} // namespace Objectives
} // namespace Experimental
} // namespace TMVA
73
74#endif // TMVA_TREEINFERENCE_OBJECTIVES
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
char name[80]
Definition TGX11.cxx:110
T Identity(T value)
Identity function f(x) = x.
T Exponential(T value)
Natural exponential function f(x) = exp(x)
std::function< T(T)> GetFunction(const std::string &name)
Get function pointer to implementation from name given as string.
T Logistic(T value)
Logistic function f(x) = 1 / (1 + exp(-x))
create variable transformations