We’re ready to code! In Part 1 we looked at how GANs work, and Part 2 showed how to get the data ready. In this part, we will begin creating the functions that handle the image data, including some pre-processing and data normalisation.
- Image Functions
In the previous post we downloaded and pre-processed our training data. There were also links to the skeleton code we will be using in the remainder of the tutorial; here they are again:
- `gantut_imgfuncs.py`: holds the image-related functions
- `gantut_datafuncs.py`: contains the data-related functions
- `gantut_gan.py`: is where we define the GAN
- `gantut_trainer.py`: is the script that we will call in order to train the GAN
Now, if your folder structure looks something like this, then we’re ready to go:
```
~/GAN
|- raw
|-- 00001.jpg
|-- ...
|- aligned
|-- 00001.jpg
|-- ...
|- gantut_imgfuncs.py
|- gantut_datafuncs.py
|- gantut_gan.py
|- gantut_trainer.py
```
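If you want to double-check the layout before going any further, a short script like the one below can help. It is a minimal sketch that assumes the `~/GAN` root and the entries listed above; the helper name `missing_files` is hypothetical and not part of the tutorial's skeleton code.

```python
from pathlib import Path

# Entries we expect in the project root (taken from the layout above).
EXPECTED = [
    "raw",
    "aligned",
    "gantut_imgfuncs.py",
    "gantut_datafuncs.py",
    "gantut_gan.py",
    "gantut_trainer.py",
]


def missing_files(root):
    """Hypothetical helper: return the expected entries not found in `root`."""
    root = Path(root).expanduser()
    return [name for name in EXPECTED if not (root / name).exists()]


if __name__ == "__main__":
    missing = missing_files("~/GAN")
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("Folder structure looks good.")
```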
We’re going to want to be able to read in a set of images. We will also want to be able to output some generated images. We will also add in a fail-safe cropping/transformation procedure in case we want to make sure we have the right input format. The skeleton code `gantut_imgfuncs.py` contains the definition headers for these functions; we will fill them in as we go along.
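These are the functions needed to get the data from the hard disk into our network. They are called like this: `get_image` calls `imread` to load the file and then `transform`, which in turn calls `center_crop` if cropping is requested.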
We are dealing with standard image files, and our GAN will support `.png` as well as the `.jpg` files in our training data as input. For these kinds of files, Python already has well-developed tools: specifically, we can use the `imread` function from the `scipy.misc` library. This is a one-liner and is already written in the skeleton code.
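Note that `scipy.misc.imread` has been removed from recent SciPy releases, so on a current install the skeleton's one-liner will fail. If you run into that, a replacement could look like the sketch below. It assumes Pillow is installed (`pip install pillow`); `imread_pil` is our own hypothetical name, not a library function.

```python
import numpy as np
from PIL import Image


def imread_pil(path):
    """Read an image from `path` as a float64 RGB array of shape [h, w, 3].

    A sketch of a replacement for scipy.misc.imread(path, mode='RGB'):
    grayscale or RGBA files are converted to RGB, and intensities stay
    in the original [0, 255] range.
    """
    with Image.open(path) as img:
        return np.asarray(img.convert("RGB"), dtype=np.float64)
```

Swapping it in would just mean calling `imread_pil(path)` wherever the skeleton calls `imread(path)`.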
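*Inputs*
* `path`: location of the image

*Returns*
* the image

```python
def imread(path):
    """ Reads in the image (part of get_image function) """
    # np.float has been removed from NumPy; np.float64 is the equivalent type
    return scipy.misc.imread(path, mode='RGB').astype(np.float64)
```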
transform()

We will have to write this function into the skeleton ourselves. We are including it to make sure that the image data all have the same dimensions, so this function needs to take in the image, the desired width (the output will be square) and whether or not to perform the cropping. We may have already cropped our images (as we have) because we've done some registration/alignment, etc. We check whether we want to crop the image: if we do, we call the `center_crop` function; otherwise, we just take the `image` as it is.

Before returning our cropped (or uncropped) image, we are going to perform normalisation. Currently the pixels have intensity values in the range $[0 \ 255]$ for each channel (red, green, blue). It is better for the data to be centred on zero rather than all positive, so we will normalise our images to have intensity values in the range $[-1 \ 1]$ by dividing by 127.5 (half of 255) and subtracting 1, i.e. `image/127.5 - 1`. We will define the cropping function next, but note that the returned image is simply a `numpy` array.

*Inputs*
* `image`: the image data to be transformed
* `npx`: the size of the transformed image [`npx` x `npx`]
* `is_crop`: whether to perform cropping too [`True` or `False`]

*Returns*
* the cropped, normalised image

```python
def transform(image, npx=64, is_crop=True):
    """ Transforms the image by cropping and resizing and normalises intensity values between -1 and 1 """
    if is_crop:
        # take a centred npx x npx crop, then resize it to npx x npx
        cropped_image = center_crop(image, npx, resize_w=npx)
    else:
        cropped_image = image
    # map intensities from [0, 255] to [-1, 1]
    return np.array(cropped_image)/127.5 - 1.
```
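Written out, with $x$ an input intensity and $x'$ its normalised value, the mapping and its endpoints are:

$$
x' = \frac{x}{127.5} - 1, \qquad x = 0 \mapsto x' = -1, \quad x = 127.5 \mapsto x' = 0, \quad x = 255 \mapsto x' = 1.
$$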
Let’s perform the cropping of the images (if requested). Usually we deal with square images, say $[64 \times 64]$, but we can add a quick option to change that with a short `if` statement on the `crop_w` argument to this function: if `crop_w` is `None`, the crop width is set equal to the crop height. We take the current height and width (`h` and `w`) from the `shape` of the image `x`.
To place the square crop around the centre of the image, we find its top-left corner by taking half of `h - crop_h` and `w - crop_w`, rounding both to get a whole pixel index. However, depending on the image dimensions (for example, if the image is smaller than the crop region, or if `crop_w` differs from `crop_h`), it’s not guaranteed that we will end up with a nice $[64 \times 64]$ image. Let’s fix that at the end.
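To make the offset arithmetic concrete, here is a small standalone sketch using plain NumPy slicing (no resizing). The helper name `crop_offsets` and the example image sizes are made up for illustration. The second case shows why the result is not guaranteed to be $[64 \times 64]$: when the image is smaller than the crop, the offsets go negative, NumPy treats a negative start as counting from the end, and the slice picks out a small patch near the bottom-right corner.

```python
import numpy as np


def crop_offsets(h, w, crop_h, crop_w=None):
    """Return (row, column) of the top-left corner of a centred crop."""
    if crop_w is None:
        crop_w = crop_h
    j = int(round((h - crop_h) / 2.))  # rows to skip at the top
    i = int(round((w - crop_w) / 2.))  # columns to skip on the left
    return j, i


# A hypothetical 100 x 80 RGB image and a 64 x 64 crop.
x = np.zeros((100, 80, 3))
j, i = crop_offsets(100, 80, 64)
print(j, i)                      # 18 8
print(x[j:j+64, i:i+64].shape)   # (64, 64, 3)

# An image smaller than the crop gives negative offsets and a tiny slice.
small = np.zeros((50, 50, 3))
j, i = crop_offsets(50, 50, 64)
print(j, i)                          # -7 -7
print(small[j:j+64, i:i+64].shape)   # (7, 7, 3)
```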
`scipy` has some efficient functions that we may as well use. `scipy.misc.imresize` takes in an image array and the desired size and outputs a resized image. We can give it our cropped array, which may not be a nice square image due to the initial image dimensions, and `imresize` will perform interpolation (bilinear by default) so that we get a nice square image at the end.
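`scipy.misc.imresize` is no longer available in recent SciPy releases. If that affects you, the sketch below keeps the same cropping logic but does the resize with Pillow's bilinear filter, assuming Pillow is installed. One assumption to flag: Pillow needs 8-bit data for an RGB image, so the sketch clips the crop to $[0 \ 255]$ and truncates it to `uint8` before resizing. `center_crop_pil` is a hypothetical name.

```python
import numpy as np
from PIL import Image


def center_crop_pil(x, crop_h, crop_w=None, resize_w=64):
    """Centre-crop `x` ([h, w, 3]) and resize it to [resize_w, resize_w].

    Sketch of center_crop using Pillow instead of scipy.misc.imresize.
    Returns a uint8 array of shape [resize_w, resize_w, 3].
    """
    if crop_w is None:
        crop_w = crop_h
    h, w = x.shape[:2]
    j = int(round((h - crop_h) / 2.))
    i = int(round((w - crop_w) / 2.))
    crop = x[j:j+crop_h, i:i+crop_w]
    # Pillow expects uint8 RGB data: clip to the valid range, then convert.
    crop = np.clip(crop, 0, 255).astype(np.uint8)
    # Pillow's resize takes (width, height), not (height, width).
    resized = Image.fromarray(crop).resize((resize_w, resize_w), Image.BILINEAR)
    return np.asarray(resized)
```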
*Inputs*
* `x`: the input image
* `crop_h`: the height of the crop region
* `crop_w`: the width of the crop region; if `None`, crop width = crop height
* `resize_w`: the width (and height) of the resized image

*Returns*
* the cropped image

```python
def center_crop(x, crop_h, crop_w=None, resize_w=64):
    """ Crops the input image at the centre pixel """
    if crop_w is None:
        crop_w = crop_h
    h, w = x.shape[:2]
    # top-left corner of a crop centred on the image
    j = int(round((h - crop_h)/2.))
    i = int(round((w - crop_w)/2.))
    # crop, then resize (bilinear by default) to a square resize_w x resize_w image
    return scipy.misc.imresize(x[j:j+crop_h, i:i+crop_w], [resize_w, resize_w])
```
The `get_image` function is a wrapper that calls the `imread` and `transform` functions. It is the function that we’ll call to get the data, rather than making two separate function calls in the main GAN class. This is a one-liner and is already written in the skeleton code.
*Inputs*
* `image_path`: location of the image
* `image_size`: width (in pixels) of the output image
* `is_crop`: whether to crop the image or not [`True` or `False`]

*Returns*
* the cropped, normalised image

```python
def get_image(image_path, image_size, is_crop=True):
    """ Loads the image and crops it to 'image_size' """
    return transform(imread(image_path), image_size, is_crop)
```
When we’re training our network, we will want to see some of the results. The previous functions all deal with getting images from storage into the network; we now want to take some images out. The functions are called like this: `save_images` calls `inverse_transform` and then `imsave`, which in turn calls `merge`.
Firstly, let’s put the intensities back into a non-negative range. We’ll just go from $[-1 \ 1]$ to $[0 \ 1]$ here; the scaling up to $[0 \ 255]$ happens when we save the image.
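A quick way to convince yourself that the two mappings undo each other is to push a few 8-bit intensities through the normalisation at the end of `transform`, then through the $[-1 \ 1] \to [0 \ 1]$ step used by `inverse_transform`, and finally through the `255*` scaling applied when saving. The arithmetic is copied from this post; `normalise` is a hypothetical stand-in for the last line of `transform`, and the sample values are arbitrary. The sketch rounds before casting to `uint8`, an assumption on our part, which guards against a tiny floating-point error (a value like 63.99999) being truncated down by one.

```python
import numpy as np


def normalise(image):
    # Same arithmetic as the end of transform(): [0, 255] -> [-1, 1]
    return np.array(image) / 127.5 - 1.


def inverse_transform(images):
    # [-1, 1] -> [0, 1]
    return (images + 1.) / 2.


pixels = np.array([0, 64, 127, 128, 255], dtype=np.float64)
normed = normalise(pixels)
restored = inverse_transform(normed)
as_uint8 = np.round(255 * restored).astype(np.uint8)
print(as_uint8.tolist())  # [0, 64, 127, 128, 255]
```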
*Inputs*
* `images`: the images to be transformed

*Returns*
* the transformed images

```python
def inverse_transform(images):
    """ This turns the intensities back to a normal range """
    # map [-1, 1] to [0, 1]
    return (images+1.)/2.
```
We will create an array of several example images from the network, which we can output every now and again to see how things are progressing. We need some `images` to go in and a `size` which says how many images high and wide the array should be.
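First, get the height `h` and width `w` of the `images` from their `shape` (we assume they’re all the same size because we will have already used our previous functions to make this happen). Note that `images` is a collection of images where each `image` has the same height, width and number of channels, so `images.shape` is `[number of images, h, w, 3]`. We then define `img` to be the final image array and initialise it to all zeros. Notice that there is a ‘3’ on the end to denote the number of channels, as these are RGB images. This also works for grayscale images, provided they keep a single channel dimension (shape `[h, w, 1]`), because NumPy broadcasts that channel across all three.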
Next, we iterate through each `image` in `images` and put it into place. The `%` operator is the modulo, which returns the remainder of the division between two numbers. `//` is the floor division operator, which returns the result of the division rounded down to an integer. With `i = idx % size[1]` as the column and `j = idx // size[1]` as the row, this fills the top row of the array from left to right (remembering that Python indexing starts at 0) and then moves down to the next row, placing one image at each iteration.
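To see the order in which images are placed, here is the same `%` and `//` arithmetic for six images in a grid 2 high and 3 wide (an arbitrary example). With `size = [rows, cols]`, the column is `idx % cols` and the row is `idx // cols`:

```python
size = [2, 3]  # [height, width] of the grid: 2 rows, 3 columns

positions = []
for idx in range(size[0] * size[1]):
    i = idx % size[1]   # column index
    j = idx // size[1]  # row index
    positions.append((j, i))
    print(f"image {idx} -> row {j}, column {i}")
```

Images 0 to 2 fill the top row and images 3 to 5 fill the second row.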
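*Inputs*
* `images`: the set of input images
* `size`: [height, width] of the array, counted in images

*Returns*
* an array of images as a single image

```python
def merge(images, size):
    """ Takes a set of 'images' and creates an array from them. """
    # height and width of a single image; images.shape is [n, h, w, channels]
    h, w = images.shape[1], images.shape[2]
    # blank canvas: size[0] images high, size[1] images wide, 3 channels
    img = np.zeros((int(h * size[0]), int(w * size[1]), 3))
    for idx, image in enumerate(images):
        i = idx % size[1]   # column
        j = idx // size[1]  # row
        img[j*h:j*h+h, i*w:i*w+w, :] = image
    return img
```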
Our image array `img` now has intensity values in $[0 \ 1]$, so let’s scale it to the proper image range $[0 \ 255]$ before converting to integer values with `.astype(np.uint8)` and saving the result to `path` with `scipy.misc.imsave`.
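`scipy.misc.imsave` has also been removed from recent SciPy releases. A Pillow-based sketch of this saving step is below, assuming Pillow is installed. It takes an already-merged array with values in $[0 \ 1]$ and, as our own assumption rather than the tutorial's behaviour, rounds and clips before converting so that out-of-range values cannot overflow `uint8`. `save_array_pil` is a hypothetical name.

```python
import numpy as np
from PIL import Image


def save_array_pil(img, path):
    """Save an [h, w, 3] array with intensities in [0, 1] as an 8-bit image."""
    # Scale to [0, 255], round, and clip so the cast to uint8 cannot overflow.
    data = np.clip(np.round(255 * img), 0, 255).astype(np.uint8)
    Image.fromarray(data).save(path)
```

Saving as `.png` keeps the grid lossless; Pillow picks the format from the file extension.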
*Inputs*
* `images`: the set of input images
* `size`: [height, width] of the array
* `path`: the save location

*Returns*
* an image saved to disk

```python
def imsave(images, size, path):
    """ Takes a set of `images` and calls the merge function. Converts the array to image data and saves to disk. """
    img = merge(images, size)
    return scipy.misc.imsave(path, (255*img).astype(np.uint8))
```
Finally, let’s create the wrapper to pull this together:
*Inputs*
* `images`: the images to be saved
* `size`: the size of the `img` array [height, width]
* `image_path`: where the array is to be stored on disk

```python
def save_images(images, size, image_path):
    """ Takes a set of images and saves them to disk as a single array. Rescales intensity values from [-1 1] to [0 255] """
    return imsave(inverse_transform(images), size, image_path)
```
In this post, we’ve dealt with all of the functions that are needed to import image data into our network, and also some that will create outputs so we can see what’s going on. We’ve made sure that, when cropping is switched on, images of different sizes are cropped and resized to the same square dimensions before they reach the network.
Make sure that we’ve imported `numpy` and `scipy.misc` at the top of this script:

```python
import numpy as np
import scipy.misc
```
The complete script can be found here. In the next post, we will be working on the GAN itself and building the `gantut_datafuncs.py` functions as we go.