Computer Vision Part 2: Load Tagged Data With PyTorch
From Images and Masks to Loading Data
In this part we'll take a set of images and target labels (black-and-white mask images), transform them into PyTorch tensors, and set them up in a DataLoader to be fed into a neural net model.
We'll be using PIL / Pillow to load images, as torchvision has built-in support for it in its transforms.
Since the masks are images too, we won't need to load them with numpy.
0. Python Setup
Imports
1. Image and Target Loading
We also want to make sure we're loading our images and targets as we intend, so let's double-check that first.
Our example will be to segment what is UI and what is not UI, so we can get away with just one color to signal UI.
We're treating the target as binary, so 'L' (8-bit grayscale) or '1' (1-bit) mode in PIL should work. What we care about most is that the target contains only 0s and 1s after it is converted to a torch tensor.
SIDE-NOTE In a multi-class segmentation problem (e.g. self-driving cars segmenting 'road', 'sky', 'people', 'cars', etc.) we often use different colors to represent different classes in our target data, which requires an additional step in our data-loading pipeline. The colors used don't matter, so long as they are unique and consistent; since the color values are known, we make a PyTorch tensor of all 0's for each class (with the same width and height as our target segmentation image), then fill each with a 1 only where the target segmentation pixels are equal to that specific class color value.
One caveat to this is that you cannot have overlapping segmentations. You'd need to either save a segmentation image for each class, save the segmentations as multi-dimensional (as many dimensions as you have classes) tensors / arrays, or perform some kind of bit shifting or mod operations on pixel values to allow multiple class-colors to overlap and add together.
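The color-to-class step from the side note can be sketched as follows. The helper name `colors_to_class_masks` and the three-color palette are made up for illustration; any unique, consistent colors would do.

```python
import torch


def colors_to_class_masks(target_rgb, class_colors):
    """Turn a color-coded target [3, H, W] (uint8) into per-class binary masks [C, H, W]."""
    _, height, width = target_rgb.shape
    masks = torch.zeros(len(class_colors), height, width)
    for class_index, color in enumerate(class_colors):
        color_t = torch.tensor(color, dtype=target_rgb.dtype).view(3, 1, 1)
        # A pixel belongs to the class only if all three channels match its color.
        matches = (target_rgb == color_t).all(dim=0)
        masks[class_index][matches] = 1.0
    return masks


# Hypothetical palette: background, road, sky.
CLASS_COLORS = [(0, 0, 0), (255, 0, 0), (0, 0, 255)]

# Tiny 2x3 color-coded target: top row is road (red), bottom-left pixel is sky (blue).
target_rgb = torch.zeros(3, 2, 3, dtype=torch.uint8)
target_rgb[0, 0, :] = 255
target_rgb[2, 1, 0] = 255

class_masks = colors_to_class_masks(target_rgb, CLASS_COLORS)
print(class_masks.shape)
print(class_masks.sum(dim=(1, 2)))  # pixels per class
```

Because every pixel has exactly one color, every pixel ends up in exactly one class mask, which is the no-overlap caveat in practice.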
2. PyTorch Dataset
PyTorch has 2 tools that make feeding data into your deep learning model easier and rather generalizable to different tasks, data, and domains. These are the `Dataset` and `DataLoader` classes in the `torch.utils.data` module.
The Dataset class handles the length of our data and how we receive each sample item. Note that the type of a sample `dataset[x_index]` depends on the `__getitem__` function, but also on the `transform` that is passed in. Also note that a sample from the dataset is a single sample and has no batch dimension, even if it is a tensor (this matters at the DataLoader step).
We'll start with the dataset and provide plenty of comments.
NOTE If you're just loading images (and not targets, e.g. in an auto-encoding task), or are loading images by class (e.g. predicting animal or pokemon type), then torchvision's `ImageFolder` is probably the easiest thing to use; it returns an (image, class index) tuple. (Side note: it expects a root directory containing at least one subdirectory of images, which is how it separates by class.)
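Here is a minimal sketch of an image-and-mask dataset along these lines. The class name `ImageMaskDataset`, the folder layout (an images folder and a masks folder with matching PNG file names), and the two-argument `transform` signature are all assumptions for illustration.

```python
import tempfile
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset


class ImageMaskDataset(Dataset):
    """Pairs each image in `image_dir` with the same-named mask in `mask_dir`."""

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_paths = sorted(Path(image_dir).glob("*.png"))
        self.mask_dir = Path(mask_dir)
        # Optional callable taking (image, target) and returning (image, target).
        self.transform = transform

    def __len__(self):
        # Number of samples; the DataLoader uses this to plan its batches.
        return len(self.image_paths)

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        image = Image.open(image_path).convert("RGB")
        target = Image.open(self.mask_dir / image_path.name).convert("L")
        if self.transform is not None:
            image, target = self.transform(image, target)
        # A single sample as a plain tuple, with no batch dimension.
        return image, target


# Tiny throwaway folders so the sketch runs end to end.
root = Path(tempfile.mkdtemp())
(root / "images").mkdir()
(root / "masks").mkdir()
for name in ("a.png", "b.png"):
    Image.new("RGB", (16, 12), (30, 60, 90)).save(root / "images" / name)
    Image.new("L", (16, 12), 255).save(root / "masks" / name)

dataset = ImageMaskDataset(root / "images", root / "masks")
image, target = dataset[0]
print(len(dataset), image.size, image.mode, target.mode)
```

Without a transform the samples are still PIL images, so they would need converting to tensors before a DataLoader could batch them.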
3. Load the folder into a Dataset
Depending on your task you may not need to transform the target at all (e.g. a class name like 'cheese').
In our case we need to make sure whatever we do to the image we also do to the target mask. Most importantly, we need to convert both from PIL images to PyTorch tensors so that they can be ingested by the model / network.
We'll also perform a resize, both because reducing image size reduces the amount of data passing through the network and to demonstrate composing a chain of transforms.
We use the `ToTensor()` transform from torchvision transforms because it's the simplest conversion from PIL to tensor. You usually want to do this last, depending on the other transforms of course.
Always remember your target image!
But `ToTensor()` only takes one input... So we can either use the 'functional' version or write our own class. We'll take the class route, which lets you chain more easily.
If using random transforms on images and image targets, make sure the same randomness is applied to both (you'll probably need to write your own transforms).
In a more advanced case we can chain together more transforms like crops, flips, and rotations to vary the training data. Typically we don't test / validate on data with as many transforms though.
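To make the "same randomness" point concrete, here is a sketch of a paired random flip. The class name is hypothetical, and it works on tensors, so it would sit after the tensor conversion in a chain.

```python
import torch


class PairedRandomHorizontalFlip:
    """Flip image and target together with probability `p`."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image, target):
        # Draw once, then apply the same decision to both tensors.
        if torch.rand(1).item() < self.p:
            image = torch.flip(image, dims=[-1])
            target = torch.flip(target, dims=[-1])
        return image, target


# Toy 1-channel "image" and a mask derived from it, so alignment is easy to check.
image = torch.arange(12, dtype=torch.float32).reshape(1, 3, 4)
target = (image > 5).float()

flipped_image, flipped_target = PairedRandomHorizontalFlip(p=1.0)(image, target)
print(flipped_image[0, 0], flipped_target[0, 1])
```

The key design choice is the single random draw: using two independent random transforms would flip the image and mask separately and silently misalign them.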
One rather important transform we're ignoring for now is image normalization, which brings all of your dataset images within the same range using the entire dataset mean and standard deviation.
Torchvision's `ToTensor()` scales 8-bit pixel values to a 0-1 range, which helps keep our model's gradients in check. To go further we'd compute the mean and std of our overfit image or dataset and normalize with those.
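A rough sketch of that further step: compute one mean and std per channel over the whole dataset, then normalize. Random tensors stand in for real images here, and loading everything into one tensor is only practical for small datasets.

```python
import torch

torch.manual_seed(0)
# Stand-in for a dataset already converted to tensors: [N, C, H, W] in the 0-1 range.
images = torch.rand(10, 3, 8, 8)

# One mean and one std per channel, computed over every pixel of every image.
mean = images.mean(dim=(0, 2, 3))
std = images.std(dim=(0, 2, 3))

# Reshape to [1, C, 1, 1] so the stats broadcast across batch, height and width.
normalized = (images - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)

print(mean, std)
print(normalized.mean(dim=(0, 2, 3)), normalized.std(dim=(0, 2, 3)))
```

In a transform chain, torchvision's `Normalize(mean, std)` applies the same per-channel subtract-and-divide to a single image tensor. The target mask is normally left un-normalized.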
NOTE Many people import torchvision.transforms as `T` or similar. Feel free to do the same. I usually write my own transforms (e.g. `input_target_transforms.py` for auto-encoders, imported as `TT` or `Transforms` for clarity).
4. Load into DataLoader
This will be the last part we cover; what comes after this is running training samples through a model!
The PyTorch DataLoader works pretty efficiently and takes a lot of effort out of batching your samples when looping through your dataset.
In most cases you can use `shuffle=True` on training and `shuffle=False` on validation / test and not worry about a sampler / batch sampler.
`drop_last=True` makes the loader stop after the last full batch of data (e.g. if you have 999 samples and a batch size of 4, the last batch would only have 3 samples; this flag drops that partial batch of 3).
`batch_size` tends to depend on the memory of your GPU(s); `num_workers`, on the other hand, scales with the number of CPU cores you have (probably 4, 2, maybe 1, maybe 8; if you have 16 you probably know; it depends on your machine).
The DataLoader gives results batched along a new dimension, meaning each tensor representing a whole batch of images has dimensions like this: [ Batch_Size, Channels, Height, Width ]. It does this using the collate function, which I generally leave at the default, and that is why I return tuples in Dataset implementations.
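The flags and batch shapes described here can be checked with a small sketch. Random tensors in a `TensorDataset` stand in for the image / mask dataset, and `num_workers=0` is used only to keep the example simple.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset: 999 (image, target) pairs, matching the example above.
images = torch.rand(999, 3, 8, 8)
targets = torch.randint(0, 2, (999, 1, 8, 8)).float()
dataset = TensorDataset(images, targets)

train_loader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True, num_workers=0)
val_loader = DataLoader(dataset, batch_size=4, shuffle=False, drop_last=False, num_workers=0)

# 249 full batches with drop_last=True, 250 when the partial batch is kept.
print(len(train_loader), len(val_loader))

# The default collate function stacks samples along a new first (batch) dimension.
image_batch, target_batch = next(iter(train_loader))
print(image_batch.shape, target_batch.shape)

# Without drop_last, the final batch holds the 3 leftover samples.
last_image_batch, _ = list(val_loader)[-1]
print(last_image_batch.shape)
```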
NOTE You generally only want to touch your dataloaders (i.e. get a batch from them) when you are looping through training. A DataLoader is an iterable rather than an indexable container, so you can't ask it for a specific index (you always can from the underlying dataset though). An iterator obtained from it keeps moving forward every time you call `next()` on it, so if you pull a batch from an iterator and then keep training from that same iterator, training starts from the second batch. Starting a new `for` loop or calling `iter()` again creates a fresh iterator from the beginning.
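A small sketch of how iteration over a DataLoader behaves, using a toy dataset of the numbers 0 to 9 with shuffling off so the order is easy to follow:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10).float().unsqueeze(1))
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# Indexing the dataset works; indexing the loader does not.
first_sample = dataset[0]
try:
    loader[0]
    loader_is_indexable = True
except TypeError:
    loader_is_indexable = False

# A single iterator moves forward on every next() call ...
iterator = iter(loader)
batch_a = next(iterator)[0]
batch_b = next(iterator)[0]

# ... while a new iter() call (or a new for-loop) starts from the beginning.
fresh_first = next(iter(loader))[0]

print(loader_is_indexable, batch_a.flatten(), batch_b.flatten(), fresh_first.flatten())
```

With `shuffle=True`, each fresh iterator also draws a new random order, so peeking at a batch can still shift the random state of a later run.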
5. BONUS: Overfit Dataset
We've seen how to load and transform images, and this will almost definitely work with a larger folder of images and a larger batch size, but let's make sure.
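One way to read "overfit dataset" is a dataset that serves the same image / mask pair over and over, so the loader can be exercised with a larger batch size. The sketch below follows that reading; the class name, the `length` parameter, and the random stand-in tensors are all assumptions.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class OverfitDataset(Dataset):
    """Serves one fixed (image, target) pair over and over."""

    def __init__(self, image, target, length=64):
        self.image = image
        self.target = target
        self.length = length  # pretend size, so the DataLoader can form full batches

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        if not 0 <= index < self.length:
            raise IndexError(index)
        return self.image, self.target


# Stand-ins for one transformed image and its binary mask.
image = torch.rand(3, 32, 48)
target = (torch.rand(1, 32, 48) > 0.5).float()

overfit_loader = DataLoader(OverfitDataset(image, target), batch_size=16, shuffle=True)
for image_batch, target_batch in overfit_loader:
    print(image_batch.shape, target_batch.shape)
```

Every item in every batch is identical here, which is the point: a model that cannot drive its loss down on this data likely has a bug somewhere in the pipeline.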
Created: June 7, 2023