from os.path import isfile

# Keep torch imported before ROOT: importing ROOT first has been reported to
# crash because of clashing symbols exported by cppyy (cppyy issue #227).
import torch
from torch import nn

from ROOT import TMVA, TFile, TTree, TCut, gROOT
 
   30    '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=multiclass')
 
   34if not isfile(
'tmva_example_multiple_background.root'):
 
   35    createDataMacro = str(gROOT.GetTutorialDir()) + 
'/tmva/createData.C' 
   36    print(createDataMacro)
 
   37    gROOT.ProcessLine(
'.L {}'.
format(createDataMacro))
 
   38    gROOT.ProcessLine(
'create_MultipleBackground(4000)')
 
   40data = 
TFile.Open(
'tmva_example_multiple_background.root')
 
   41signal = data.Get(
'TreeS')
 
   42background0 = data.Get(
'TreeB0')
 
   43background1 = data.Get(
'TreeB1')
 
   44background2 = data.Get(
'TreeB2')
 
   47for branch 
in signal.GetListOfBranches():
 
   48    dataloader.AddVariable(branch.GetName())
 
   50dataloader.AddTree(signal, 
'Signal')
 
   51dataloader.AddTree(background0, 
'Background_0')
 
   52dataloader.AddTree(background1, 
'Background_1')
 
   53dataloader.AddTree(background2, 
'Background_2')
 
   54dataloader.PrepareTrainingAndTestTree(
TCut(
''),
 
   55        'SplitMode=Random:NormMode=NumEvents:!V')
 
   60model = nn.Sequential()
 
   61model.add_module(
'linear_1', nn.Linear(in_features=4, out_features=32))
 
   62model.add_module(
'relu', nn.ReLU())
 
   63model.add_module(
'linear_2', nn.Linear(in_features=32, out_features=4))
 
   64model.add_module(
'softmax', nn.Softmax(dim=1))
 
   68loss = nn.CrossEntropyLoss()
 
   69optimizer = torch.optim.SGD
 
   73def train(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):
 
   74    trainer = optimizer(model.parameters(), lr=0.01)
 
   75    schedule, schedulerSteps = scheduler
 
   78    for epoch 
in range(num_epochs):
 
   82        running_train_loss = 0.0
 
   83        running_val_loss = 0.0
 
   84        for i, (X, y) 
in enumerate(train_loader):
 
   87            target = torch.max(y, 1)[1]
 
   88            train_loss = criterion(output, target)
 
   93            running_train_loss += train_loss.item()
 
   95                print(
"[{}, {}] train loss: {:.3f}".
format(epoch+1, i+1, running_train_loss / 32))
 
   96                running_train_loss = 0.0
 
   99            schedule(optimizer, epoch, schedulerSteps)
 
  104        with torch.no_grad():
 
  105            for i, (X, y) 
in enumerate(val_loader):
 
  107                target = torch.max(y, 1)[1]
 
  108                val_loss = criterion(output, target)
 
  109                running_val_loss += val_loss.item()
 
  111            curr_val = running_val_loss / 
len(val_loader)
 
  115               best_val = save_best(model, curr_val, best_val)
 
  118            print(
"[{}] val loss: {:.3f}".
format(epoch+1, curr_val))
 
  119            running_val_loss = 0.0
 
  121    print(
"Finished Training on {} Epochs!".
format(epoch+1))
 
  127def predict(model, test_X, batch_size=32):
 
  131    test_dataset = torch.utils.data.TensorDataset(torch.Tensor(test_X))
 
  132    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=
False)
 
  135    with torch.no_grad():
 
  136        for i, data 
in enumerate(test_loader):
 
  139            predictions.append(outputs)
 
  140        preds = torch.cat(predictions)
 
  145load_model_custom_objects = {
"optimizer": optimizer, 
"criterion": loss, 
"train_func": train, 
"predict_func": predict}
 
  150m = torch.jit.script(model)
 
  151torch.jit.save(m, 
"modelMultiClass.pt")
 
  156factory.BookMethod(dataloader, TMVA.Types.kFisher, 
'Fisher',
 
  157        '!H:!V:Fisher:VarTransform=D,G')
 
  158factory.BookMethod(dataloader, TMVA.Types.kPyTorch, 
"PyTorch",
 
  159        'H:!V:VarTransform=D,G:FilenameModel=modelMultiClass.pt:FilenameTrainedModel=trainedModelMultiClass.pt:NumEpochs=20:BatchSize=32')
 
  163factory.TrainAllMethods()
 
  164factory.TestAllMethods()
 
  165factory.EvaluateAllMethods()
 
  168roc = factory.GetROCCurve(dataloader)
 
  169roc.SaveAs(
'ROC_MulticlassPyTorch.png')
 
# ROOT API notes for names used in this script (from the ROOT reference docs):
#   TCut -- a specialized string object used for TTree selections.
#   TFile.Open(name, option="", ftitle="",
#              compress=ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault,
#              netopt=0) -- static; creates or opens a file.
#   TMVA.Factory -- the main MVA steering class.
#   TMVA.PyMethodBase.PyInitialize() -- static; initializes the Python interpreter.