# NOTE(review): this listing was extracted from a rendered page. The leading
# digits on each line are the page's original line numbers fused into the
# code, and gaps in that numbering (e.g. 16-24, 26-32) are lines missing from
# this extract. Restore from the upstream ROOT tutorial before running.
# Path to the Higgs_data.root sample under ROOT's tutorial directory.
15file_name = str(ROOT.gROOT.GetTutorialDir()) +
"/machine_learning/data/Higgs_data.root"
# Data loader for the training data; its constructor arguments (original
# lines 26-32) are not present in this extract.
25dl = ROOT.Experimental.ML.RDataLoader(
# Split the loader into training and validation parts; test_size=0.3
# presumably reserves 30% of the entries for validation — confirm upstream.
33gen_train, gen_validation = dl.train_test_split(test_size=0.3)
# Feature columns used for training; their count is the input width of the
# first Linear layer of the model.
36input_columns = gen_train.train_columns
37num_features = len(input_columns)
def calc_accuracy(targets, pred):
    """Return the fraction of rounded predictions that equal the targets.

    ``pred`` is rounded to the nearest integer with ``torch.round`` (which
    rounds halves to even, so a prediction of exactly 0.5 becomes 0) and
    compared element-wise with ``targets``.

    Both tensors are flattened before the comparison. Comparing them as-is
    lets a ``(N,)`` target tensor broadcast against ``(N, 1)`` predictions
    into an ``(N, N)`` matrix, which produces "accuracies" greater than 1.

    Args:
        targets: tensor of expected labels.
        pred: tensor of model outputs with the same number of elements as
            ``targets``; its shape may differ (e.g. ``(N, 1)`` vs ``(N,)``).

    Returns:
        A 0-d floating-point tensor holding the accuracy in ``[0, 1]``
        (NaN if both tensors are empty, since the count is divided by zero).

    Raises:
        ValueError: if ``targets`` and ``pred`` differ in element count.
    """
    n_elements = pred.numel()
    if targets.numel() != n_elements:
        raise ValueError(
            f"targets has {targets.numel()} elements but pred has {n_elements}"
        )
    # Compare flat views so broadcasting can never inflate the match count.
    matches = targets.reshape(-1) == pred.round().reshape(-1)
    # Integer tensor / int uses true division, yielding a float tensor.
    return torch.sum(matches) / n_elements
# Fully connected network: num_features -> 300 -> 300 -> 300 -> 1.
# NOTE(review): the page lines between the layers (47, 49, 51, 53) and the
# closing parenthesis (54) are missing from this extract; they presumably
# hold activation layers — confirm against the upstream tutorial.
45model = torch.nn.Sequential(
 46 torch.nn.Linear(num_features, 300),
 48 torch.nn.Linear(300, 300),
 50 torch.nn.Linear(300, 300),
 52 torch.nn.Linear(300, 1),
# Mean-squared-error loss, averaged over the elements of each batch.
55loss_fn = torch.nn.MSELoss(reduction=
"mean")
# Stochastic gradient descent over all model parameters (lr 0.01, momentum 0.9).
56optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Epoch loop: one training pass and one validation pass per epoch.
# NOTE(review): `number_of_epochs` is not assigned in any line of this
# extract (page lines 57-59 are missing) — restore it from upstream.
60for i
in range(number_of_epochs):
 # Training pass over batches from the training loader.
 # NOTE(review): this loop rebinds `i`, shadowing the epoch counter; it does
 # not break the outer `range` iteration but is confusing — consider
 # renaming it (e.g. `batch_idx`) when restoring the full script.
 64 for i, (x_train, y_train)
in enumerate(gen_train.as_torch()):
 # NOTE(review): `pred` is used here but its assignment (presumably
 # `pred = model(x_train)` on missing page lines 65-66) is not in this extract.
 67 loss = loss_fn(pred, y_train)
 # Page lines 68-74 (presumably the backward pass / optimizer step) are missing.
 75 accuracy = calc_accuracy(y_train, pred)
 77 print(f
"Training => accuracy: {accuracy}")
 # Validation pass over batches from the validation loader; again `pred` is
 # assigned on lines missing from this extract (86-87).
 85 for i, (x_val, y_val)
in enumerate(gen_validation.as_torch()):
 88 accuracy = calc_accuracy(y_val, pred)
 90 print(f
"Validation => accuracy: {accuracy}")
ROOT's RDataFrame offers a modern, high-level interface for the analysis of data stored in TTree, ...