from os.path import isfile
from subprocess import call

from ROOT import TMVA, TFile, TCut, gROOT
import torch
from torch import nn
30 '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Classification')
# Open the example ROOT file from the tutorial directory and fetch the two
# input trees by name.
data = TFile.Open(
    str(gROOT.GetTutorialDir())
    + '/machine_learning/data/tmva_class_example.root')
signal = data.Get('TreeS')
background = data.Get('TreeB')

# NOTE(review): `dataloader` is used below but is not created anywhere in the
# visible part of this file -- confirm it is defined before this point.

# Register every branch of the signal tree as an input variable.
for branch in signal.GetListOfBranches():
    dataloader.AddVariable(branch.GetName())

# Both trees are added with a weight of 1.0.
dataloader.AddSignalTree(signal, 1.0)
dataloader.AddBackgroundTree(background, 1.0)

# Empty cut: no preselection is applied before the train/test split.
dataloader.PrepareTrainingAndTestTree(
    TCut(''),
    'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V')
# Network definition: 4 inputs -> 64 hidden units (ReLU) -> 2 outputs,
# with a softmax over the output dimension.
model = nn.Sequential()
model.add_module('linear_1', nn.Linear(in_features=4, out_features=64))
model.add_module('relu', nn.ReLU())
model.add_module('linear_2', nn.Linear(in_features=64, out_features=2))
model.add_module('softmax', nn.Softmax(dim=1))

# Loss instance and optimizer *class*; the optimizer is instantiated later
# inside `train` with the model parameters.
loss = torch.nn.MSELoss()
optimizer = torch.optim.SGD
def train(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):
    """Train ``model`` and print running train / validation losses.

    Args:
        model: Module to optimise; its parameters are handed to ``optimizer``.
        train_loader: Iterable of ``(X, y)`` training batches.
        val_loader: Sized iterable of ``(X, y)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.
        batch_size: Not used in this function; kept so the signature is
            unchanged for callers.
        optimizer: Optimizer factory, called as ``optimizer(params, lr=0.01)``.
        criterion: Callable ``criterion(output, y)`` returning a scalar loss.
        save_best: Optional callable ``save_best(model, curr_val, best_val)``
            whose return value becomes the new best validation loss. Skipped
            when falsy.
        scheduler: Pair ``(schedule, schedulerSteps)``. ``schedule`` is called
            once per epoch when truthy.

    Returns:
        The trained ``model``.
    """
    trainer = optimizer(model.parameters(), lr=0.01)
    schedule, schedulerSteps = scheduler
    best_val = None
    log_every = 32  # report the running train loss every 32 mini-batches

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        running_train_loss = 0.0
        running_val_loss = 0.0
        for i, (X, y) in enumerate(train_loader):
            trainer.zero_grad()
            output = model(X)
            train_loss = criterion(output, y)
            train_loss.backward()
            trainer.step()

            running_train_loss += train_loss.item()
            if i % log_every == log_every - 1:
                print("[{}, {}] train loss: {:.3f}".format(epoch+1, i+1, running_train_loss / 32))
                running_train_loss = 0.0

        if schedule:
            # NOTE(review): this passes the optimizer factory, not the
            # `trainer` instance built above -- confirm that is what the
            # schedule callable expects.
            schedule(optimizer, epoch, schedulerSteps)

        # ---- validation pass (no gradients needed) ----
        model.eval()
        with torch.no_grad():
            for i, (X, y) in enumerate(val_loader):
                output = model(X)
                val_loss = criterion(output, y)
                running_val_loss += val_loss.item()

            curr_val = running_val_loss / len(val_loader)
            if save_best:
                # Seed the best value with the first epoch's loss so that
                # save_best always receives a number to compare against.
                if best_val is None:
                    best_val = curr_val
                best_val = save_best(model, curr_val, best_val)

            print("[{}] val loss: {:.3f}".format(epoch+1, curr_val))
            running_val_loss = 0.0

    # Use num_epochs rather than the loop variable so that zero epochs does
    # not raise NameError; the printed value is the same after a full run.
    print("Finished Training on {} Epochs!".format(num_epochs))

    return model
def predict(model, test_X, batch_size=32):
    """Run ``model`` over ``test_X`` in batches and return the stacked outputs.

    Args:
        model: Callable module applied to each batch of inputs.
        test_X: Input samples; converted with ``torch.Tensor(test_X)``.
        batch_size: Number of samples per forward pass.

    Returns:
        The concatenated model outputs for all samples, in input order
        (``shuffle=False``), converted to a NumPy array.
        NOTE(review): the NumPy conversion is assumed to be what the caller
        expects -- confirm against the consumer of ``predict_func``.
    """
    model.eval()

    test_dataset = torch.utils.data.TensorDataset(torch.Tensor(test_X))
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    predictions = []
    with torch.no_grad():
        for data in test_loader:
            # TensorDataset yields one-element batches: (inputs,)
            X = data[0]
            outputs = model(X)
            predictions.append(outputs)
        preds = torch.cat(predictions)

    return preds.numpy()
# Objects bundled under the fixed keys "optimizer", "criterion", "train_func"
# and "predict_func".
# NOTE(review): presumably picked up by name by the PyTorch method booked
# below -- confirm how this dict is consumed.
load_model_custom_objects = {
    "optimizer": optimizer,
    "criterion": loss,
    "train_func": train,
    "predict_func": predict,
}

# Serialise the untrained network as TorchScript; the file name must match
# the FilenameModel option passed to the PyTorch method below.
m = torch.jit.script(model)
torch.jit.save(m, "modelClassification.pt")

# Book a Fisher discriminant and the PyTorch network on the same dataloader.
factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                   '!H:!V:Fisher:VarTransform=D,G')
factory.BookMethod(dataloader, TMVA.Types.kPyTorch, 'PyTorch',
                   'H:!V:VarTransform=D,G:FilenameModel=modelClassification.pt:FilenameTrainedModel=trainedModelClassification.pt:NumEpochs=20:BatchSize=32')

# Run training, testing and evaluation for all booked methods.
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Save the ROC curve to an image file.
roc = factory.GetROCCurve(dataloader)
roc.SaveAs('ROC_ClassificationPyTorch.png')
# Reference notes for the ROOT classes used above:
# - TCut: A specialized string object used for TTree selections.
# - TFile::Open: static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
#   Create / open a file.
# - TMVA::Factory: This is the main MVA steering class.
# - TMVA::PyMethodBase::PyInitialize: static void PyInitialize()
#   Initialize Python interpreter.