class UISegmentationSystem(pytorch_lightning.LightningModule):
    """
    Predict UI segments of videogame screenshots using a Convolutional Neural Network.

    This class can be used for training, validation, and prediction. The network
    produces a single output channel that the loss treats as raw logits; apply
    ``torch.sigmoid`` to turn them into per-pixel foreground probabilities.
    """

    def __init__(self, hparams):
        """
        Build the segmentation network and create the output folders.

        :param hparams: namespace of run settings. Attributes read by this class:
            ``output_dir``, ``learning_rate``, ``image_size``, ``dataset``,
            ``overfit_num`` and ``batch_size``.
        """
        super().__init__()
        log.info("Running __init__")
        # NOTE(review): direct assignment matches the older Lightning API this module
        # targets (dict-returning training_step); newer Lightning releases reject it
        # and expect self.save_hyperparameters(hparams) instead.
        self.hparams = hparams
        # Datasets are built in prepare_data(); None means "not loaded yet" and can
        # be used as a boolean condition elsewhere.
        self.dataset = None
        self.val_dataset = None
        # Any architecture with 3 input channels and 1 output (logit) channel fits here;
        # it could also be chosen via an arg flag.
        self.network = UI_SEG.UNet(num_in_channels=3, num_out_channels=1, max_features=256)
        self.models_dir = f"{hparams.output_dir}/model_checkpoints"
        self.images_dir = f"{hparams.output_dir}/step_images"
        os.makedirs(self.models_dir, exist_ok=True)
        os.makedirs(self.images_dir, exist_ok=True)

    def forward(self, input):
        """
        Send a batch of images through the prediction network.

        This is the same as the forward method of a PyTorch nn.Module; the network
        itself lives in a separate file to keep this one short.

        :param input: image tensor with a leading batch dimension
        :return: raw logits (tensor with the same batch dimension)
        """
        return self.network(input)

    # ---------------------
    # TRAINING SETUP
    # ---------------------
    def configure_optimizers(self):
        """
        REQUIRED
        Create the Adam optimizer over the network weights.

        Lightning handles the zero_grad() and step() calls.

        :return: optimizer to use in training
            (optionally a list of optimizers and a list of learning rate
            schedulers in a tuple)
        """
        log.info("Configuring optimizer")
        optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=self.hparams.learning_rate,
        )
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
        #                                                        factor=0.1)
        # return [optimizer], [scheduler]
        return optimizer

    def prepare_data(self):
        """
        REQUIRED
        Build the training and validation datasets from ``hparams.dataset``.

        Samples are resized to ``image_size`` x ``image_size`` and converted to
        tensors (the ImageTarget* transforms presumably apply to both the image
        and its mask; confirm in UI_SEG).

        NOTE: the validation dataset reads the same folder with the same transform
        as the training dataset, so validation loss measures fit on the training
        data rather than generalization. See the TODO below for a held-out split.

        NOTE(review): Lightning calls prepare_data on a single process in
        distributed runs, so attributes set here may be missing on other ranks.
        Newer Lightning versions expect per-process state to be built in setup().
        """
        # Resize to square and turn the PIL image into a torch Tensor
        training_transform = UI_SEG.ImageTargetCompose([
            UI_SEG.ImageTargetResize((self.hparams.image_size, self.hparams.image_size)),
            UI_SEG.ImageTargetToTensor(),
        ])
        log.info('Prepped training transform')
        self.dataset = UI_SEG.UIImageMaskDataset(data_dir=self.hparams.dataset, transform=training_transform, overfit_num=self.hparams.overfit_num)
        self.val_dataset = UI_SEG.UIImageMaskDataset(data_dir=self.hparams.dataset, transform=training_transform, overfit_num=self.hparams.overfit_num)
        log.info(
            f'Train ({len(self.dataset)} samples) and Val ({len(self.val_dataset)} samples) datasets loaded')
        # TODO: hold out a validation set (e.g. 20% of the data, or k-fold):
        # num_samples = len(self.dataset)
        # num_train = int(0.8 * num_samples)
        # indices = torch.randperm(num_samples).tolist()
        # self.dataset = torch.utils.data.Subset(
        #     self.dataset, indices[0:num_train])
        # self.val_dataset = torch.utils.data.Subset(
        #     self.val_dataset, indices[num_train:])
        # log.info(f'Full Dataset loaded: len: {num_samples}')

    def train_dataloader(self):
        """
        REQUIRED. Return the PyTorch DataLoader used in the training loop.

        Lightning passes each batch to `training_step()` below, and builds a
        distributed sampler when running on multiple GPUs.
        """
        log.info("Fetching training dataloader")
        # NOTE: with drop_last=True, a dataset smaller than batch_size (e.g. a
        # small overfit_num) yields zero training batches.
        return torch.utils.data.DataLoader(self.dataset, batch_size=self.hparams.batch_size, drop_last=True, shuffle=True)

    def val_dataloader(self):
        """
        REQUIRED if doing a val loop. Return the PyTorch DataLoader used in validation.

        Lightning passes each batch to `validation_step()` below.
        """
        log.info("Fetching validation dataloader")
        # Increase batch_size if using a full val set
        return torch.utils.data.DataLoader(self.val_dataset, batch_size=1, shuffle=False)

    def loss_function(self, targets, predictions):
        """
        Compute the loss used for backpropagation (training) and for assessing the
        model (validation).

        Uses binary cross entropy with logits, which fuses the sigmoid with BCE for
        numerical stability. A pos_weight could be added to increase recall of
        underrepresented classes.

        :param targets: ground truth segmentation mask
        :param predictions: raw logits output by the network
        :return: scalar loss tensor
        """
        bce_loss = F.binary_cross_entropy_with_logits(predictions, targets)
        return bce_loss

    # -----------------------------
    # TRAINING LOOP
    # -----------------------------
    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop.

        :param batch: tuple of (images, targets) from the dataloader, already
            batched (tensors of dimensions [Batch, Channels, Height, Width])
        :param batch_idx: batch number in the current epoch
        :return: dict
            - loss -> tensor scalar [REQUIRED]
            - progress_bar -> dict for progress bar display. Must contain only tensors
            - log -> dict of metrics to add to the logger. Must contain only tensors (no images, etc)
        """
        # forward pass
        images, targets = batch
        predictions = self.forward(images)
        # calculate loss
        loss_val = self.loss_function(targets, predictions)
        to_log = {'training_loss': loss_val}
        output = {
            'loss': loss_val,  # required
            'progress_bar': to_log,
            'log': to_log,
        }
        return output

    # ---------------------
    # VALIDATION LOOP
    # ---------------------
    def validation_step(self, batch, batch_number):
        """
        Called in the validation loop with the model in eval mode.

        The first batch of every validation run is also saved as an image grid.

        :param batch: tuple of (images, targets) from the dataloader, already
            batched (tensors of dimensions [Batch, Channels, Height, Width])
        :param batch_number: index of the current validation batch; batch 0 is
            the one whose images are logged
        :return: dict
            - val_loss -> tensor scalar [useful for early stopping]
            - progress_bar -> dict for progress bar display. Must contain only tensors
            - log -> dict of metrics to add to the logger. Must contain only tensors (no images, etc)
        """
        images, targets = batch
        predictions = self.forward(images)
        loss_val = self.loss_function(targets, predictions)
        if batch_number == 0:
            self.log_validation_images(images, predictions, targets, step=self.global_step)
        to_log = {'val_loss': loss_val}
        output = {
            'val_loss': loss_val,
            'progress_bar': to_log,
            'log': to_log,
        }
        return output

    def log_validation_images(self, inputs, predictions, targets, step=0):
        """
        Save an image grid comparing inputs, predicted masks and target masks.

        The grid has one column per sample. Row 1 holds the inputs, row 2 the
        predicted foreground probabilities (sigmoid of the logits) and row 3 the
        targets. All values are clipped to [0, 1]. The grid is written to
        ``images_dir/step_{step}_val_batch.png`` and also drawn on the current
        matplotlib axes.

        :param inputs: batch of input images
        :param predictions: raw network logits for the batch
        :param targets: ground-truth masks for the batch
        :param step: number embedded in the output file name
        """
        # The network emits logits (the loss uses BCE-with-logits). Convert them to
        # probabilities first; otherwise the [0, 1] clip below reduces them to a
        # near-binary image.
        probabilities = torch.sigmoid(predictions.detach())
        # Turn 1-channel masks into 3-channel 'rgb' so they stack with the inputs
        viz_predictions = probabilities.repeat(1, 3, 1, 1)
        viz_targets = targets.detach().repeat(1, 3, 1, 1)
        # Clamping reproduces make_grid(normalize=True, range=(0, 1)) without the
        # `range` keyword, which newer torchvision releases deprecated.
        stacked = torch.cat([inputs.detach(), viz_predictions, viz_targets]).clamp(0, 1)
        img_grid = torchvision.utils.make_grid(stacked, nrow=inputs.shape[0], padding=20)
        # NOTE: each call adds another image artist to the current matplotlib axes.
        plt.imshow(np.asarray(img_grid.permute((1, 2, 0)).cpu()))
        torchvision.utils.save_image(img_grid, os.path.join(self.images_dir, f"step_{step}_val_batch.png"))