14from ROOT
import TMVA, TFile, TTree, TCut, gROOT
15from os.path
import isfile
27 '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=multiclass')
# Generate the toy dataset only on the first run: when the ROOT file is not
# present yet, load ROOT's createData.C tutorial macro and let it write
# tmva_example_multiple_background.root with 4000 events.
if not isfile('tmva_example_multiple_background.root'):
    createDataMacro = str(gROOT.GetTutorialDir()) + '/tmva/createData.C'
    print(createDataMacro)
    # '.L' loads the macro into the interpreter; the second call runs it.
    gROOT.ProcessLine('.L {}'.format(createDataMacro))
    gROOT.ProcessLine('create_MultipleBackground(4000)')
# Open the dataset file and fetch the signal tree plus the three
# background trees (one per background class).
data = TFile.Open('tmva_example_multiple_background.root')
signal = data.Get('TreeS')
background0, background1, background2 = (
    data.Get(treeName) for treeName in ('TreeB0', 'TreeB1', 'TreeB2')
)
# Use every branch of the signal tree as an input variable.
# NOTE(review): `dataloader` is created earlier in the script (not shown here).
for branch in signal.GetListOfBranches():
    dataloader.AddVariable(branch.GetName())

# Register one tree per class: the signal and the three background classes.
dataloader.AddTree(signal, 'Signal')
dataloader.AddTree(background0, 'Background_0')
dataloader.AddTree(background1, 'Background_1')
dataloader.AddTree(background2, 'Background_2')

# Random train/test split; the empty TCut applies no preselection.
dataloader.PrepareTrainingAndTestTree(
    TCut(''),
    'SplitMode=Random:NormMode=NumEvents:!V',
)
# Small fully connected classifier:
#   4 inputs -> 32 hidden units (ReLU) -> 4 outputs,
# with a softmax over dim 1 so each row of the output sums to one.
model = nn.Sequential()
model.add_module('linear_1', nn.Linear(in_features=4, out_features=32))
model.add_module('relu', nn.ReLU())
model.add_module('linear_2', nn.Linear(in_features=32, out_features=4))
model.add_module('softmax', nn.Softmax(dim=1))
# Loss and optimizer handed to TMVA through load_model_custom_objects.
#
# nn.CrossEntropyLoss expects *unnormalized logits* and applies log-softmax
# internally. The model above already ends in nn.Softmax, so passing its
# output to CrossEntropyLoss would apply softmax twice. That flattens the loss
# surface and shrinks the gradients. The model keeps its softmax, so its output
# stays class probabilities. The loss therefore takes the log of those
# probabilities and applies the negative log-likelihood. This is the same
# cross-entropy, computed once.
class _ProbabilityCrossEntropyLoss(nn.Module):
    """Cross-entropy loss for inputs that are already class probabilities.

    ``forward(probs, target)``:
        probs: softmax output of shape (batch, classes).
        target: integer class index per row (e.g. ``torch.max(y, 1)[1]``).
        Returns the mean negative log-likelihood of the target classes.
    """

    def __init__(self, eps=1e-12):
        super().__init__()
        # Lower clamp so a probability that underflows to 0 gives a large
        # finite loss instead of inf/NaN.
        self.eps = eps

    def forward(self, probs, target):
        return nn.functional.nll_loss(torch.log(probs.clamp_min(self.eps)), target)


loss = _ProbabilityCrossEntropyLoss()
# Optimizer *class*; train() calls optimizer(model.parameters(), lr=0.01).
optimizer = torch.optim.SGD
# train(): training callback exposed to TMVA as load_model_custom_objects["train_func"].
# Builds the optimizer instance via optimizer(model.parameters(), lr=0.01), then
# for each epoch runs a pass over train_loader and a no-grad pass over val_loader,
# printing running train loss and per-epoch mean validation loss.
# NOTE(review): the leading numbers on these lines appear to be line numbers
# fused in by extraction, and their gaps (72->75, 81->84, 85->90, ...) indicate
# missing source lines. In particular `output` and `best_val` are never assigned
# in what is shown, and no backward/step call is visible. Confirm against the
# full tutorial before changing any logic here.
70def train(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):
71 trainer = optimizer(model.parameters(), lr=0.01)
# `scheduler` is unpacked as a (callable, steps) pair.
72 schedule, schedulerSteps = scheduler
75 for epoch
in range(num_epochs):
79 running_train_loss = 0.0
80 running_val_loss = 0.0
81 for i, (X, y)
in enumerate(train_loader):
# argmax over dim 1 turns y into class indices for the criterion
# (presumably y is one-hot encoded per class — confirm).
84 target = torch.max(y, 1)[1]
85 train_loss = criterion(output, target)
90 running_train_loss += train_loss.item()
# Reported loss is the running sum divided by 32; the condition guarding this
# print is not shown (presumably "every 32 mini-batches" — confirm).
92 print(
"[{}, {}] train loss: {:.3f}".format(epoch+1, i+1, running_train_loss / 32))
93 running_train_loss = 0.0
# NOTE(review): this passes `optimizer` (the optimizer factory/class argument)
# rather than the `trainer` instance built above. A schedule that adjusts the
# learning rate of the live optimizer would need `trainer` — confirm what
# `schedule` expects.
96 schedule(optimizer, epoch, schedulerSteps)
# Validation pass without gradient tracking.
101 with torch.no_grad():
102 for i, (X, y)
in enumerate(val_loader):
104 target = torch.max(y, 1)[1]
105 val_loss = criterion(output, target)
106 running_val_loss += val_loss.item()
# Mean of the per-batch validation losses for this epoch.
108 curr_val = running_val_loss / len(val_loader)
# save_best gets the current and previous best validation loss; its return
# value replaces best_val.
112 best_val = save_best(model, curr_val, best_val)
115 print(
"[{}] val loss: {:.3f}".format(epoch+1, curr_val))
116 running_val_loss = 0.0
# NOTE(review): `epoch` is unbound here when num_epochs == 0 (NameError).
118 print(
"Finished Training on {} Epochs!".format(epoch+1))
# predict(): inference callback exposed to TMVA as load_model_custom_objects["predict_func"].
# Wraps test_X in a TensorDataset, iterates it in batches of batch_size without
# gradient tracking, and concatenates the per-batch outputs with torch.cat.
# NOTE(review): the gaps in the fused line numbers (124->128, 133->136, ...)
# indicate missing source lines. The assignments of `predictions` and
# `outputs` and the return statement are not shown — confirm against the full
# tutorial before changing any logic here.
124def predict(model, test_X, batch_size=32):
128 test_dataset = torch.utils.data.TensorDataset(torch.Tensor(test_X))
# shuffle=False keeps the output rows in the same order as test_X.
129 test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=
False)
132 with torch.no_grad():
133 for i, data
in enumerate(test_loader):
136 predictions.append(outputs)
137 preds = torch.cat(predictions)
# Named objects for TMVA's PyTorch method: the optimizer, the loss, and the
# training / prediction callbacks defined above.
# (The lookup by these key names happens outside this script — presumably in
# TMVA's PyTorch interface; confirm there before renaming any key.)
load_model_custom_objects = dict(
    optimizer=optimizer,
    criterion=loss,
    train_func=train,
    predict_func=predict,
)
# Compile the model to TorchScript and write it to the file that the PyTorch
# booking below references via FilenameModel=modelMultiClass.pt.
m = torch.jit.script(model)
torch.jit.save(m, "modelMultiClass.pt")
# Book two classifiers on the same dataloader: a Fisher discriminant as a
# baseline, and the PyTorch model saved above (trained for 20 epochs with
# batch size 32, result written to trainedModelMultiClass.pt).
factory.BookMethod(
    dataloader,
    TMVA.Types.kFisher,
    'Fisher',
    '!H:!V:Fisher:VarTransform=D,G',
)
factory.BookMethod(
    dataloader,
    TMVA.Types.kPyTorch,
    "PyTorch",
    'H:!V:VarTransform=D,G:FilenameModel=modelMultiClass.pt:FilenameTrainedModel=trainedModelMultiClass.pt:NumEpochs=20:BatchSize=32',
)
# Train, test and evaluate every booked method, in that order.
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()

# Save the ROC curves of all methods for this dataloader to a PNG.
roc = factory.GetROCCurve(dataloader)
roc.SaveAs('ROC_MulticlassPyTorch.png')
TCut: a specialized string object used for TTree selections.
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault, Int_t netopt=0)
Create / open a file.
TMVA::Factory: the main MVA steering class.
static void PyInitialize()
Initialize the Python interpreter.