import ROOT
import torch
# Input data: tree "sig_tree" in machine_learning/data/Higgs_data.root under
# ROOT's tutorial directory.
tree_name = "sig_tree"
file_name = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"
batch_size = 128
# Column used as the label; the other columns are presumably the model inputs
# (reported by ``train_columns`` below) — confirm against the RDataLoader docs.
target = "Type"
# RDataLoader consumes an RDataFrame, so build one over the tree/file above.
rdataframe = ROOT.RDataFrame(tree_name, file_name)
dl = ROOT.Experimental.ML.RDataLoader(
    rdataframe,
    batch_size,
    target=target,
    shuffle=True,
    # presumably discards a final batch smaller than batch_size — TODO confirm
    drop_remainder=True,
)
# test_size=0.3 presumably reserves ~30% of the entries for validation.
gen_train, gen_validation = dl.train_test_split(test_size=0.3)
input_columns = gen_train.train_columns
# Width of the network's input layer.
num_features = len(input_columns)
def calc_accuracy(targets, pred):
    """Return the fraction of rounded predictions that equal the targets.

    ``pred`` is rounded to the nearest integer and compared element-wise with
    ``targets``. The number of matches is divided by the length of ``pred``'s
    first dimension. The result is returned as a tensor.
    """
    n_rows = pred.size(0)
    matches = (targets == pred.round()).sum()
    return matches / n_rows
# Fully connected network: three tanh hidden layers of equal width, then a
# single sigmoid output so every prediction lies in (0, 1).
hidden_units = 300
model = torch.nn.Sequential(
    torch.nn.Linear(num_features, hidden_units), torch.nn.Tanh(),
    torch.nn.Linear(hidden_units, hidden_units), torch.nn.Tanh(),
    torch.nn.Linear(hidden_units, hidden_units), torch.nn.Tanh(),
    torch.nn.Linear(hidden_units, 1), torch.nn.Sigmoid(),
)
# Mean squared error between predictions and targets, averaged over the batch.
loss_fn = torch.nn.MSELoss(reduction="mean")
# Stochastic gradient descent with momentum over every model parameter.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01, momentum=0.9)
number_of_epochs = 2
for epoch in range(number_of_epochs):
    print("Epoch ", epoch)

    # Training pass: one optimizer step per batch, reporting batch accuracy.
    model.train()
    for x_train, y_train in gen_train.as_torch():
        optimizer.zero_grad()
        pred = model(x_train)
        loss = loss_fn(pred, y_train)
        loss.backward()
        optimizer.step()
        # .item() turns the 0-d accuracy tensor into a plain float for printing.
        accuracy = calc_accuracy(y_train, pred).item()
        print(f"Training => accuracy: {accuracy}")

    # Validation pass: evaluation only, so skip building autograd graphs.
    model.eval()
    with torch.no_grad():
        for x_val, y_val in gen_validation.as_torch():
            pred = model(x_val)
            accuracy = calc_accuracy(y_val, pred).item()
            print(f"Validation => accuracy: {accuracy}")
# ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree, ...