Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RStandardScaler.hxx
Go to the documentation of this file.
1#ifndef TMVA_RSTANDARDSCALER
2#define TMVA_RSTANDARDSCALER
3
#include <TFile.h>

#include <TMVA/RTensor.hxx>

#include <cmath>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
11
12namespace TMVA {
13namespace Experimental {
14
15template <typename T>
17private:
18 std::vector<T> fMeans;
19 std::vector<T> fStds;
20
21public:
22 RStandardScaler() = default;
23 RStandardScaler(std::string_view title, std::string_view filename);
24 void Fit(const RTensor<T>& x);
25 std::vector<T> Compute(const std::vector<T>& x);
27 std::vector<T> GetMeans() const { return fMeans; }
28 std::vector<T> GetStds() const { return fStds; }
29 void Save(std::string_view title, std::string_view filename);
30};
31
32template <typename T>
33inline RStandardScaler<T>::RStandardScaler(std::string_view title, std::string_view filename) {
34 auto file = TFile::Open(filename.data(), "READ");
36 file->GetObject(title.data(), obj);
37 fMeans = obj->GetMeans();
38 fStds = obj->GetStds();
39 delete obj;
40 file->Close();
41}
42
43template <typename T>
44inline void RStandardScaler<T>::Save(std::string_view title, std::string_view filename) {
45 auto file = TFile::Open(filename.data(), "UPDATE");
46 file->WriteObject<RStandardScaler<T>>(this, title.data());
47 file->Write();
48 file->Close();
49}
50
51template <typename T>
53 const auto shape = x.GetShape();
54 if (shape.size() != 2)
55 throw std::runtime_error("Can only fit to input tensor of rank 2.");
56 fMeans.clear();
57 fMeans.resize(shape[1]);
58 fStds.clear();
59 fStds.resize(shape[1]);
60
61 // Compute means
62 for (std::size_t i = 0; i < shape[0]; i++) {
63 for (std::size_t j = 0; j < shape[1]; j++) {
64 fMeans[j] += x(i, j);
65 }
66 }
67 for (std::size_t i = 0; i < shape[1]; i++) {
68 fMeans[i] /= shape[0];
69 }
70
71 // Compute standard deviations using unbiased estimator
72 for (std::size_t i = 0; i < shape[0]; i++) {
73 for (std::size_t j = 0; j < shape[1]; j++) {
74 fStds[j] += (x(i, j) - fMeans[j]) * (x(i, j) - fMeans[j]);
75 }
76 }
77 for (std::size_t i = 0; i < shape[1]; i++) {
78 fStds[i] = std::sqrt(fStds[i] / (shape[0] - 1));
79 }
80}
81
82template <typename T>
83inline std::vector<T> RStandardScaler<T>::Compute(const std::vector<T>& x) {
84 const auto size = x.size();
85 if (size != fMeans.size())
86 throw std::runtime_error("Size of input vector is not equal to number of fitted variables.");
87
88 std::vector<T> y(size);
89 for (std::size_t i = 0; i < size; i++) {
90 y[i] = (x[i] - fMeans[i]) / fStds[i];
91 }
92
93 return y;
94}
95
96template <typename T>
98 const auto shape = x.GetShape();
99 if (shape.size() != 2)
100 throw std::runtime_error("Can only compute output for input tensor of rank 2.");
101 if (shape[1] != fMeans.size())
102 throw std::runtime_error("Second dimension of input tensor is not equal to number of fitted variables.");
103
104 RTensor<T> y(shape);
105 for (std::size_t i = 0; i < shape[0]; i++) {
106 for (std::size_t j = 0; j < shape[1]; j++) {
107 y(i, j) = (x(i, j) - fMeans[j]) / fStds[j];
108 }
109 }
110
111 return y;
112}
113
114} // namespace Experimental
115} // namespace TMVA
116
117#endif // TMVA_RSTANDARDSCALER
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
Definition TFile.cxx:4089
void Save(std::string_view title, std::string_view filename)
void Fit(const RTensor< T > &x)
std::vector< T > Compute(const std::vector< T > &x)
RTensor is a container with contiguous memory and shape information.
Definition RTensor.hxx:162
const Shape_t & GetShape() const
Definition RTensor.hxx:242
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
create variable transformations