from os.path import isfile
from subprocess import call

import torch
from ROOT import TMVA, TFile, TTree, TCut
from torch import nn
28 '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Regression')
# Fetch the example input file once; later runs reuse the local copy.
# --fail makes curl exit non-zero on HTTP errors instead of saving the error
# page under the .root name (which isfile() would then accept on every rerun).
if not isfile('tmva_reg_example.root'):
    status = call(['curl', '--fail', '-L', '-O', 'https://root.cern.ch/files/tmva_reg_example.root'])
    if status != 0:
        raise RuntimeError(
            "failed to download tmva_reg_example.root (curl exit code {})".format(status))
# Open the example file and pull out the regression tree, failing loudly if
# either is missing rather than passing a null object on to TMVA.
data = TFile.Open('tmva_reg_example.root')
if not data or data.IsZombie():
    raise RuntimeError("could not open 'tmva_reg_example.root'")
tree = data.Get('TreeR')
if not tree:
    raise RuntimeError("tree 'TreeR' not found in 'tmva_reg_example.root'")

# Every branch except the regression target becomes an input variable.
# Registering 'fvalue' as both input and target would leak the answer into
# the inputs (and the network defined below takes only two input features).
for branch in tree.GetListOfBranches():
    name = branch.GetName()
    if name != 'fvalue':
        dataloader.AddVariable(name)
dataloader.AddTarget('fvalue')

dataloader.AddRegressionTree(tree, 1.0)
dataloader.PrepareTrainingAndTestTree(
    TCut(''),
    'nTrain_Regression=4000:SplitMode=Random:NormMode=NumEvents:!V')
# Small fully connected regression network: 2 inputs -> 64 hidden units with a
# tanh activation -> 1 output. The activation is registered as 'tanh' so the
# module name matches the layer actually used (nn.Tanh).
# NOTE(review): in_features=2 presumably equals the number of input variables
# booked on the dataloader (all branches except the target) — confirm.
model = nn.Sequential()
model.add_module('linear_1', nn.Linear(in_features=2, out_features=64))
model.add_module('tanh', nn.Tanh())
model.add_module('linear_2', nn.Linear(in_features=64, out_features=1))
# Mean-squared-error loss instance, plus the optimizer *class*; the class is
# instantiated with the model parameters inside the training function.
loss, optimizer = torch.nn.MSELoss(), torch.optim.SGD
def train(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):
    """Train ``model`` with mini-batch gradient descent, validating once per epoch.

    Args:
        model: ``torch.nn.Module`` optimized in place.
        train_loader: iterable of ``(X, y)`` mini-batches used for training.
        val_loader: sized iterable of ``(X, y)`` mini-batches used for validation.
        num_epochs: number of passes over ``train_loader``.
        batch_size: not used here; batching is done by the loaders.
        optimizer: optimizer *class* (e.g. ``torch.optim.SGD``), instantiated
            here with ``lr=0.01``.
        criterion: loss callable, ``criterion(output, y)``.
        save_best: optional callable ``save_best(model, curr_val, best_val)``
            returning the new best validation loss; falsy to disable.
        scheduler: ``(schedule, schedulerSteps)`` tuple; a falsy ``schedule``
            disables learning-rate scheduling. ``schedule`` is called as
            ``schedule(optimizer_instance, epoch, schedulerSteps)``.

    Returns:
        The trained ``model``.
    """
    trainer = optimizer(model.parameters(), lr=0.01)
    schedule, schedulerSteps = scheduler
    best_val = None
    log_every = 32  # report the running training loss every 32 mini-batches

    for epoch in range(num_epochs):
        # Training pass.
        model.train()
        running_train_loss = 0.0
        for i, (X, y) in enumerate(train_loader):
            trainer.zero_grad()
            output = model(X)
            train_loss = criterion(output, y)
            train_loss.backward()
            trainer.step()

            running_train_loss += train_loss.item()
            # The printed value is a mean over `log_every` batches, so only
            # print (and reset) once that many batches have accumulated.
            if i % log_every == log_every - 1:
                print("[{}, {}] train loss: {:.3f}".format(epoch+1, i+1, running_train_loss / log_every))
                running_train_loss = 0.0

        if schedule:
            # Pass the optimizer *instance*: a schedule adjusts param_groups,
            # which only exist on the instance, not on the optimizer class.
            schedule(trainer, epoch, schedulerSteps)

        # Validation pass: evaluation mode, no gradient tracking.
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for X, y in val_loader:
                running_val_loss += criterion(model(X), y).item()
        # Guard against an empty validation loader.
        curr_val = running_val_loss / max(len(val_loader), 1)

        if save_best:
            if best_val is None:
                best_val = curr_val
            best_val = save_best(model, curr_val, best_val)

        print("[{}] val loss: {:.3f}".format(epoch+1, curr_val))

    print("Finished Training on {} Epochs!".format(num_epochs))
    return model
def predict(model, test_X, batch_size=32):
    """Run ``model`` over ``test_X`` in mini-batches and return the stacked outputs.

    Args:
        model: ``torch.nn.Module``; this call switches it to evaluation mode.
        test_X: array-like of input rows accepted by ``torch.Tensor``.
        batch_size: number of rows per forward pass.

    Returns:
        NumPy array of the model outputs, one row per input row, in input order.
        NOTE(review): presumably the format TMVA's PyTorch method expects from
        a custom ``predict_func`` — confirm.
    """
    model.eval()

    test_dataset = torch.utils.data.TensorDataset(torch.Tensor(test_X))
    # shuffle=False keeps predictions aligned with the rows of test_X.
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    predictions = []
    with torch.no_grad():
        for (X,) in test_loader:
            predictions.append(model(X))
    preds = torch.cat(predictions)
    return preds.numpy()
# Named objects for the PyTorch method: the optimizer class, the loss
# instance, and the custom training / prediction functions.
# NOTE(review): presumably looked up by TMVA's PyTorch method under exactly
# these keys — confirm before renaming any of them.
load_model_custom_objects = dict(
    optimizer=optimizer,
    criterion=loss,
    train_func=train,
    predict_func=predict,
)
# Compile the network to TorchScript and write it to the file named by the
# FilenameModel option of the PyTorch booking.
m = torch.jit.script(model)
m.save("modelRegression.pt")
# Book the two regressors compared in this example: the PyTorch network saved
# above and a gradient-boosted decision-tree ensemble.
for method_type, method_name, method_options in (
    (TMVA.Types.kPyTorch, 'PyTorch',
     'H:!V:VarTransform=D,G:FilenameModel=modelRegression.pt:FilenameTrainedModel=trainedModelRegression.pt:NumEpochs=20:BatchSize=32'),
    (TMVA.Types.kBDT, 'BDTG',
     '!H:!V:VarTransform=D,G:NTrees=1000:BoostType=Grad:Shrinkage=0.1:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=20:MaxDepth=4'),
):
    factory.BookMethod(dataloader, method_type, method_name, method_options)
# Train, then test, then evaluate every booked method, in that order.
for run_stage in (factory.TrainAllMethods, factory.TestAllMethods, factory.EvaluateAllMethods):
    run_stage()
A specialized string object used for TTree selections.
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
This is the main MVA steering class.
static void PyInitialize()
Initialize Python interpreter.