9 return x.view(-1,1,16,16)
# Sequential CNN. Only some of its entries are visible in this excerpt; the
# layers between them (and the closing parenthesis) are on lines not shown.
# Shape bookkeeping below assumes inputs arrive as (N, 1, 16, 16), which is what
# the `view(-1,1,16,16)` at the top of this excerpt produces — confirm that
# reshape is applied before the first conv.
12net = torch.nn.Sequential(
# 1 -> 10 channels; kernel 3 with padding 1 (default stride 1) keeps H x W at 16x16.
14 nn.Conv2d(1, 10, kernel_size=3, padding=1),
# 10 -> 10 channels; same kernel/padding, so spatial size stays 16x16.
17 nn.Conv2d(10, 10, kernel_size=3, padding=1),
# MaxPool2d's stride defaults to kernel_size, so this halves H and W: 16x16 -> 8x8.
19 nn.MaxPool2d(kernel_size=2),
# Expects 10 channels * 8 * 8 = 640 features, so the feature map must be
# flattened before this layer (presumably by an entry not visible here).
21 nn.Linear(10*8*8, 256),
# Optimizer *factory*: the class itself, not an instance. fit() instantiates it
# with the model's parameters.
optimizer = torch.optim.Adam

# Binary cross-entropy loss.
# NOTE(review): nn.BCELoss expects probabilities in [0, 1]; confirm the
# network's final layer (not visible in this excerpt) applies a sigmoid.
criterion = nn.BCELoss()
32def fit(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):
33 trainer =
optimizer(model.parameters(), lr=0.01)
34 schedule, schedulerSteps = scheduler
38 device = torch.device(
'cuda' if torch.cuda.is_available()
else 'cpu')
39 model = model.to(device)
41 for epoch
in range(num_epochs):
45 running_train_loss = 0.0
46 running_val_loss = 0.0
47 for i, (X, y)
in enumerate(train_loader):
49 X, y = X.to(device), y.to(device)
57 running_train_loss += train_loss.item()
59 print(f
"[{epoch+1}, {i+1}] train loss: {running_train_loss / 4 :.3f}")
60 running_train_loss = 0.0
63 schedule(optimizer, epoch, schedulerSteps)
69 for i, (X, y)
in enumerate(val_loader):
70 X, y = X.to(device), y.to(device)
74 running_val_loss += val_loss.item()
76 curr_val = running_val_loss / len(val_loader)
80 best_val = save_best(model, curr_val, best_val)
83 print(f
"[{epoch+1}] val loss: {curr_val :.3f}")
84 running_val_loss = 0.0
86 print(f
"Finished Training on {epoch+1} Epochs!")
# Run `model` over `test_X` in batches and return the concatenated outputs as
# a NumPy array.
#   model      -- torch module; moved to CUDA when available, else CPU.
#   test_X     -- array-like accepted by torch.Tensor(...); presumably shaped
#                 (n_samples, ...) — confirm against callers.
#   batch_size -- rows per forward pass (default 100).
# Returns: the per-batch outputs concatenated along dim 0. Row order follows
# test_X because the DataLoader is built with shuffle=False.
# NOTE(review): the lines that create `predictions` and compute `outputs` are
# not visible in this excerpt, and no model.eval() call is visible — confirm.
# NOTE(review): if test_X is empty, no batches are produced; assuming
# `predictions` starts as an empty list, torch.cat then raises.
91def predict(model, test_X, batch_size=100):
94 device = torch.device(
'cuda' if torch.cuda.is_available()
else 'cpu')
95 model = model.to(device)
# torch.Tensor(...) copies the data into a tensor of the default float dtype.
100 test_dataset = torch.utils.data.TensorDataset(torch.Tensor(test_X))
101 test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=
False)
# Inference only: no autograd graph is recorded.
104 with torch.no_grad():
105 for i, data
in enumerate(test_loader):
# TensorDataset over a single tensor yields 1-tuples, hence data[0].
106 X = data[0].to(device)
108 predictions.append(outputs)
109 preds = torch.cat(predictions)
# .numpy() requires a CPU tensor, so move back from the device first.
111 return preds.cpu().numpy()
# Name -> object mapping bundling the optimizer factory, the loss, and this
# module's train/predict entry points. Presumably consumed by whatever reloads
# the exported model (the consumer is not visible here).
load_model_custom_objects = dict(
    optimizer=optimizer,
    criterion=criterion,
    train_func=fit,
    predict_func=predict,
)
# Compile the network to TorchScript and serialise it to disk.
# NOTE(review): this runs at import time and precedes the module-level
# fit(...) call later in this file, so the saved weights are whatever `net`
# holds at this point — confirm that is intended.
m = torch.jit.script(net)
torch.jit.save(m, f="PyTorchModelCNN.pt")
# Module-level invocations of the helpers above; both return values are discarded.
# NOTE(review): model, test_X, train_loader, val_loader, num_epochs, batch_size,
# save_best and scheduler are not defined anywhere in the visible part of this
# file; confirm they exist before this point, or these calls raise NameError.
predict(model=model, test_X=test_X, batch_size=100)
fit(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=num_epochs,
    batch_size=batch_size,
    optimizer=optimizer,
    criterion=criterion,
    save_best=save_best,
    scheduler=scheduler,
)