Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
tmva101_Training.py
Go to the documentation of this file.
1## \file
2## \ingroup tutorial_ml
3## \notebook -nodraw
4## This tutorial shows how you can train a machine learning model with any package,
5## reading the training data directly from ROOT files. Using XGBoost, we illustrate
6## how you can convert an externally trained model into a format that can be serialized and read
7## by the fast tree inference engine offered by TMVA.
8##
9## \macro_code
10## \macro_output
11##
12## \date August 2019
13## \author Stefan Wunsch
14
15import ROOT
16import numpy as np
17
18from tmva100_DataPreparation import variables
19
20
def load_data(signal_filename, background_filename, *, tree_name="Events", columns=None):
    """Load signal and background events as arrays ready for training.

    Parameters
    ----------
    signal_filename, background_filename : str
        ROOT files holding the signal and background events.
    tree_name : str, optional
        Name of the dataset read from both files (default ``"Events"``).
    columns : iterable of str, optional
        Input variables used as features, in this order. Defaults to
        ``variables`` imported from ``tmva100_DataPreparation``.

    Returns
    -------
    x : numpy.ndarray, shape (num_sig + num_bkg, len(columns))
        Feature matrix: signal rows first, then background rows.
    y : numpy.ndarray, shape (num_sig + num_bkg,)
        Labels: 1 for signal rows, 0 for background rows.
    w : numpy.ndarray, shape (num_sig + num_bkg,)
        Per-event weights ``num_all / num_sig`` (signal) and
        ``num_all / num_bkg`` (background). Both classes therefore carry the
        same total weight ``num_all``.
    """
    if columns is None:
        columns = variables
    columns = list(columns)

    def _to_matrix(filename):
        # Read only the requested columns instead of materializing every
        # branch of the file, then stack them into an (events, features) matrix.
        data = ROOT.RDataFrame(tree_name, filename).AsNumpy(columns=columns)
        return np.vstack([data[var] for var in columns]).T

    # Read data from ROOT files and convert to a format readable by ML tools
    x_sig = _to_matrix(signal_filename)
    x_bkg = _to_matrix(background_filename)
    x = np.vstack([x_sig, x_bkg])

    # Create labels
    num_sig = x_sig.shape[0]
    num_bkg = x_bkg.shape[0]
    y = np.hstack([np.ones(num_sig), np.zeros(num_bkg)])

    # Compute weights balancing both classes. Multiply before dividing so an
    # empty class yields an empty array rather than a ZeroDivisionError.
    num_all = num_sig + num_bkg
    w = np.hstack([np.ones(num_sig) * num_all / num_sig, np.ones(num_bkg) * num_all / num_bkg])

    return x, y, w
41
42
if __name__ == "__main__":
    # xgboost is imported inside the guard, so importing this module
    # (e.g. to reuse load_data) does not require xgboost to be installed.
    from xgboost import XGBClassifier

    # Build the class-balanced training set from the signal/background files.
    features, labels, weights = load_data("train_signal.root", "train_background.root")
    num_events, num_inputs = features.shape

    # Boosted decision trees: 500 boosting rounds, each tree at most depth 3.
    bdt = XGBClassifier(n_estimators=500, max_depth=3)
    bdt.fit(features, labels, sample_weight=weights)

    # Convert the trained xgboost model into TMVA's tree-inference format.
    print("Training done on ", num_events, "events. Saving model in tmva101.root")
    ROOT.TMVA.Experimental.SaveXGBoost(bdt, "myBDT", "tmva101.root", num_inputs=num_inputs)
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree ,...