# Progress banner printed before the model definition below is executed.
print("running Torch code defining the model....")
class Reshape(torch.nn.Module):
    """Layer that views its input as a batch of single-channel 16x16 maps."""

    def forward(self, x):
        """Return ``x`` viewed with shape ``(-1, 1, 16, 16)``.

        The leading dimension is inferred, so the number of elements in
        ``x`` must be a multiple of 256 (1 * 16 * 16). ``view`` does not
        copy, so ``x`` must be laid out compatibly (e.g. contiguous).
        """
        return x.view(-1, 1, 16, 16)
# Sequential network definition. The entries visible in this listing are two
# 3x3 convolutions with padding=1 (1 -> 10 and 10 -> 10 channels), one 2x2
# max-pool and a Linear layer taking 10*8*8 inputs to 256 outputs.
# NOTE(review): the listing numbers embedded in these lines jump
# (24 -> 26 -> 29 -> 31 -> 33), so the entries on the skipped lines and the
# closing parenthesis of this call are not shown here — restore them from the
# original file before running.
24net = torch.nn.Sequential(
26 nn.Conv2d(1, 10, kernel_size=3, padding=1),
29 nn.Conv2d(10, 10, kernel_size=3, padding=1),
31 nn.MaxPool2d(kernel_size=2),
# 10*8*8: presumably 10 channels of 8x8 maps (16x16 halved once by the pool) — TODO confirm.
33 nn.Linear(10*8*8, 256),
# Loss object and optimizer factory used for training. `optimizer` is the
# Adam *class*, not an instance: it is instantiated later with the model's
# parameters (see the `optimizer(model.parameters(), lr=0.01)` call in `fit`).
criterion = nn.BCELoss()
optimizer = torch.optim.Adam
# Train `model` for `num_epochs` epochs over `train_loader`, then accumulate a
# validation loss over `val_loader` in each epoch.
#
# What the visible lines establish:
#   - `optimizer` is called as a factory: optimizer(model.parameters(), lr=0.01).
#   - `scheduler` is unpacked into a callable and a second value passed back to it.
#   - `save_best(model, curr_val, best_val)` is called and its result becomes `best_val`.
#   - `batch_size` is not used in any visible line.
#
# NOTE(review): this listing is incomplete. The numbers embedded at the start of
# each statement skip many lines (e.g. 47-49, 62-63, 65-68, 70, 76-80, 89-91,
# 99-102), and indentation has been lost, so loop/branch nesting cannot be read
# from this text. Restore the missing lines from the original file.
44def fit(model, train_loader, val_loader, num_epochs, batch_size, optimizer, criterion, save_best, scheduler):
45 trainer = optimizer(model.parameters(), lr=0.01)
46 schedule, schedulerSteps = scheduler
# Use the GPU when one is available, otherwise the CPU.
50 device = torch.device(
'cuda' if torch.cuda.is_available()
else 'cpu')
51 model = model.to(device)
53 for epoch
in range(num_epochs):
57 running_train_loss = 0.0
58 running_val_loss = 0.0
59 for i, (X, y)
in enumerate(train_loader):
61 X, y = X.to(device), y.to(device)
# NOTE(review): `output` and `target` are not assigned on any visible line
# (listing lines 62-63 are missing) — confirm against the original file.
64 train_loss = criterion(output, target)
69 running_train_loss += train_loss.item()
# NOTE(review): the running loss is divided by 4 and reset right after printing;
# listing line 70 is missing — presumably a periodic-print guard, confirm.
71 print(f
"[{epoch+1}, {i+1}] train loss: {running_train_loss / 4 :.3f}")
72 running_train_loss = 0.0
# NOTE(review): this passes `optimizer` (the factory argument), not the `trainer`
# instance built above — confirm that is what the scheduler callable expects.
75 schedule(optimizer, epoch, schedulerSteps)
81 for i, (X, y)
in enumerate(val_loader):
82 X, y = X.to(device), y.to(device)
# NOTE(review): as in the training loop, `output`/`target` come from lines not shown (83-84).
85 val_loss = criterion(output, target)
86 running_val_loss += val_loss.item()
# Accumulated validation loss divided by len(val_loader).
88 curr_val = running_val_loss / len(val_loader)
# NOTE(review): `best_val` has no visible initial assignment — presumably set on
# one of the missing lines; verify.
92 best_val = save_best(model, curr_val, best_val)
95 print(f
"[{epoch+1}] val loss: {curr_val :.3f}")
96 running_val_loss = 0.0
98 print(f
"Finished Training on {epoch+1} Epochs!")
# Run `model` over `test_X` in batches of `batch_size` (default 100) and return
# the concatenated per-batch results as a NumPy array.
#
# What the visible lines establish: `test_X` is converted with torch.Tensor and
# wrapped in a TensorDataset, iterated through a DataLoader with shuffle=False
# (so batch order follows the input order), under torch.no_grad(); per-batch
# results are appended to `predictions`, concatenated with torch.cat, moved to
# the CPU and converted with .numpy().
#
# NOTE(review): this listing is incomplete — the embedded numbers skip lines
# (e.g. 108-111, 114-115, 119, 122) and indentation has been lost. Restore the
# missing lines from the original file.
103def predict(model, test_X, batch_size=100):
# Use the GPU when one is available, otherwise the CPU.
106 device = torch.device(
'cuda' if torch.cuda.is_available()
else 'cpu')
107 model = model.to(device)
112 test_dataset = torch.utils.data.TensorDataset(torch.Tensor(test_X))
113 test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=
False)
116 with torch.no_grad():
117 for i, data
in enumerate(test_loader):
# TensorDataset yields tuples; element 0 is the input batch.
118 X = data[0].to(device)
# NOTE(review): neither `predictions` nor `outputs` is assigned on a visible line
# (listing line 119 is missing before this append) — confirm against the original file.
120 predictions.append(outputs)
121 preds = torch.cat(predictions)
123 return preds.cpu().numpy()
# Bundle of the training pieces defined in this file, keyed by role: the
# optimizer factory, the loss object, and the train/predict functions.
load_model_custom_objects = {
    "optimizer": optimizer,
    "criterion": criterion,
    "train_func": fit,
    "predict_func": predict,
}
# Compile the network to TorchScript and write it to disk so it can be loaded
# later without this Python source.
m = torch.jit.script(net)
torch.jit.save(m, "PyTorchModelCNN.pt")
print("The PyTorch CNN model is created and saved as PyTorchModelCNN.pt")