
Experimenting with Style Transfer using PyTorch

Yijiao Wang
January 2nd 2020
I love painting, and painting is fun. What's also extremely fun is using a fascinating Deep Learning technique known as Style Transfer to make an ordinary picture resemble the styles of some of the most famous artists and create your very own piece of art. Hurrah! Such fun!
In today's post I will go over some basic terms and concepts behind generating a style-transfer image using a pre-trained neural network, along with some code. I ran the examples on the Google Colab platform, which provides free GPUs and TPUs, and experimented with several of my own paintings and landscape photos as content images, obtaining some interesting results.

What is style transfer?

Style transfer is the technique of applying the artistic style of one image to another image while keeping the second image's semantic content. It borrows texture cues from the style image, including color palettes and brush strokes, and adjusts the input so that the result resembles the content of the content image rendered in the style of the style image.
Interesting, eh? Let's see how to accomplish this.
We import VGG-19 - a convolutional neural network (CNN) pre-trained on more than a million images - to help us generate the target image, for two reasons:
  1. A CNN dramatically reduces the number of parameters in the model. This matters for high-dimensional image data: if every pixel were a separate feature, a 100 x 100 pixel image would already have 10,000 features, and optimizing such a big model would be very computationally intensive. We don't care about every individual pixel anyway; what matters are the pixels on edges and boundaries between color regions. A CNN instead slides small filters over the image, each one reusing the same weights to detect the same pattern (e.g. grass, leaves, sky) anywhere in the input.
  2. Using a pre-trained model, in the spirit of Transfer Learning, massively reduces the time spent on dataset labeling and training.
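To make the parameter-count argument from point 1 concrete, here is a quick back-of-the-envelope comparison (the layer sizes are hypothetical, chosen only to illustrate the arithmetic):

```python
# Fully connected layer: a 100 x 100 RGB image flattened to 30,000 inputs,
# feeding a layer of 100 units -> one weight per (input, unit) pair.
dense_params = (100 * 100 * 3) * 100
print(dense_params)  # 3000000

# Convolutional layer: 100 filters, each 3 x 3 x 3 (height x width x channels),
# with the same weights shared across every position in the image.
conv_params = 100 * (3 * 3 * 3)
print(conv_params)  # 2700
```

Over a thousand times fewer weights, and the count no longer grows with the image size.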

Load pre-trained model

VGG networks are trained on images with each channel normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]. We will use the same statistics to normalize our images before sending them into the network, for best results.
import torch
from torchvision import models

vgg = models.vgg19(pretrained=True).features

## We only optimize the target image, never the network, so freeze VGG's weights
for param in vgg.parameters():
    param.requires_grad_(False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)
One thing to pay attention to here is the .to(device) call, which moves the model onto a GPU (when one is available) to speed up the neural network's computations.

Load, transform images

from PIL import Image
from torchvision import transforms

def image_loader(image_path, max_size=400):
    image = Image.open(image_path).convert('RGB')
    if max(image.size) > max_size:
        size = max_size
    else:
        size = max(image.size)

    in_transform = transforms.Compose([
                        transforms.Resize(size),
                        transforms.ToTensor(),
                        transforms.Normalize((0.485, 0.456, 0.406),
                                             (0.229, 0.224, 0.225))])

    image = in_transform(image).unsqueeze(0)

    return image

style_img = image_loader(STYLE_IMAGE_PATH).to(device)
content_img = image_loader(CONTENT_IMAGE_PATH).to(device)
In the code above, we chain a list of transformations with transforms.Compose: we cap the image size at a default maximum of 400px for faster computation, convert the image into a PyTorch tensor so it can be fed to our model, and finally normalize it.
In the following code block, we do the opposite - convert a tensor back to a NumPy array so it can be displayed with the matplotlib package.
import numpy as np

def image_converter(tensor):
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1,2,0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    image = image.clip(0, 1)

    return image
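As a quick sanity check (a toy example with made-up pixel values, not part of the post's pipeline), normalizing and then denormalizing with the same mean and std should round-trip back to the original values - which is exactly the relationship between transforms.Normalize and image_converter:

```python
import numpy as np

mean = np.array((0.485, 0.456, 0.406))
std = np.array((0.229, 0.224, 0.225))

# A tiny fake 2 x 2 RGB "image" with channel values in [0, 1]
pixels = np.array([[[0.1, 0.5, 0.9],
                    [0.3, 0.3, 0.3]],
                   [[0.0, 1.0, 0.5],
                    [0.7, 0.2, 0.6]]])

normalized = (pixels - mean) / std      # what transforms.Normalize does
recovered = normalized * std + mean     # what image_converter undoes
print(np.allclose(recovered, pixels))   # True
```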

Extract features

We first extract features from the style and content images. We then extract features from the target image on each iteration and compare each feature map with the corresponding one from the style and content images to calculate the losses.
VGG-19 is a 19-layer network built from 3 x 3 convolutions. We use only one layer (conv4_2) for content extraction, as a single layer is sufficient for capturing content, and 5 layers for style feature extraction.
def get_features(image, model, layers=None):
  if layers is None:
    layers = {'0': 'conv1_1',
              '5': 'conv2_1',
              '10': 'conv3_1',
              '19': 'conv4_1',
              '21': 'conv4_2',  ## content layer
              '28': 'conv5_1'}
  features = {}
  x = image
  for name, layer in enumerate(model):
    x = layer(x)
    if str(name) in layers:
      features[layers[str(name)]] = x

  return features

target = content_img.clone().requires_grad_(True).to(device) ## We want target image to be adjusted along with optimization process thus gradients are required

content_features = get_features(content_img, vgg)
style_features = get_features(style_img, vgg)

## We will be running target features extraction in the loop and extract features from target on each iteration

Calculate content loss

The idea behind content transfer is to compute a Mean Squared Error (MSE) loss - or another loss function - between the content image's features and the target image's features, and to minimize that loss. This helps preserve the original content in the generated image.
content_loss = torch.mean((target_features['conv4_2'] -
                             content_features['conv4_2']) ** 2)
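On toy feature maps (made-up numbers, purely for illustration), you can see how this MSE behaves: it shrinks as the target's features approach the content's, and is zero when they match exactly:

```python
import numpy as np

content = np.array([[1.0, 2.0], [3.0, 4.0]])       # stand-in for conv4_2 features
target_far = np.array([[0.0, 0.0], [0.0, 0.0]])
target_close = np.array([[1.0, 2.0], [3.0, 4.5]])

mse = lambda a, b: np.mean((a - b) ** 2)

print(mse(content, target_far))    # 7.5
print(mse(content, target_close))  # 0.0625
print(mse(content, content))       # 0.0
```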

Calculate style loss

For calculating the style loss, we consider feature maps from several convolution layers and weight each layer's contribution.
The Gram matrix is the result of multiplying a matrix by its own transpose. We take a feature tensor as input and reshape it so that this multiplication can be applied. The Gram matrix measures the correlations between the features that a convolution layer's filters are supposed to detect.
def gram_matrix(tensor):
  _, d, h, w = tensor.size() ## batch size, depth, height, width
  tensor = tensor.view(d, h*w)
  gram = torch.mm(tensor, tensor.t())
  return gram
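A tiny NumPy sketch (illustrative numbers only) shows the key properties of this computation: for d feature maps each flattened to h*w values, the Gram matrix is d x d and symmetric - every entry is the correlation between a pair of feature maps, independent of where in the image the features occur:

```python
import numpy as np

d, h, w = 3, 2, 2                  # 3 feature maps of size 2 x 2
features = np.arange(d * h * w, dtype=float).reshape(d, h * w)

gram = features @ features.T       # same as torch.mm(tensor, tensor.t())
print(gram.shape)                  # (3, 3)
print(np.allclose(gram, gram.T))   # True: correlations are symmetric
```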
We can prioritize certain style layers over others by associating a different weight with each layer. Earlier layers are usually more effective at recreating style features, so we assign heavier weights to the first and second layers.
style_weights = {'conv1_1': 0.75,
                 'conv2_1': 0.5,
                 'conv3_1': 0.2,
                 'conv4_1': 0.2,
                 'conv5_1': 0.2}
Another weight parameter balances the content image against the style image. This ratio is referred to in style transfer papers as alpha / beta, where alpha is the content weight and beta is the style weight. I kept the content weight (alpha) at 1 and played around with the style weight (beta) instead of adjusting the full ratio. Experimenting with different ratios produced visibly, though not dramatically, different results.
content_weight = 1
style_weight = 1e6
Calculating the style loss is similar to calculating the content loss. We pre-compute the Gram matrix of every style layer once, outside the loop:
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}
Then, on each iteration:
  style_loss = 0
  for layer in style_weights:
    target_feature = target_features[layer]
    target_gram = gram_matrix(target_feature)
    _, d, h, w = target_feature.shape

    style_gram = style_grams[layer]

    layer_style_loss = style_weights[layer] * torch.mean(
      (target_gram - style_gram) ** 2)

    style_loss += layer_style_loss / (d * h * w)
At last, we combine the content loss and style loss and optimize the target image. Note that the optimizer is defined once, before the loop, and the gradients are zeroed on each iteration before backpropagating:
optimizer = optim.Adam([target], lr=0.01) ## Define optimizer and learning rate once, before the loop

    total_loss = content_weight * content_loss + style_weight * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
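The overall optimization pattern - treating the image itself as the learnable parameter - can be sketched on a toy tensor (made-up shapes and a made-up "goal", no VGG involved):

```python
import torch
from torch import optim

# Stand-in for the target image: a small tensor we optimize directly
target = torch.zeros(1, 3, 4, 4, requires_grad=True)
goal = torch.ones(1, 3, 4, 4)          # stand-in for the ideal features

optimizer = optim.Adam([target], lr=0.1)

for step in range(200):
    loss = torch.mean((target - goal) ** 2)  # same MSE shape as the real losses
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

final_loss = torch.mean((target - goal) ** 2).item()
print(final_loss < 0.05)  # True: the tensor has been pulled toward the goal
```

The real loop is the same, except the loss is recomputed each iteration from fresh VGG features of the evolving target image.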
So that sums it up! Unlike image classification, where there is a definite right or wrong answer, style transfer produces very subjective results. Generally we care less about what is depicted in a style image than about its colors, textures, and repetitive patterns, so an abstract, pattern-heavy style image usually produces a more "pleasing" end result.
