Neural Style Transfer

Econ 425T / Biostat 203B

Author

Dr. Hua Zhou @ UCLA

Published

March 6, 2023

Source: https://keras.io/examples/generative/neural_style_transfer/

Display system information for reproducibility.

import IPython

# Summary of the IPython/Python installation and platform, printed so the
# run can be reproduced.
env_report = IPython.sys_info()
print(env_report)
{'commit_hash': 'add5877a4',
 'commit_source': 'installation',
 'default_encoding': 'utf-8',
 'ipython_path': '/Users/huazhou/opt/anaconda3/lib/python3.9/site-packages/IPython',
 'ipython_version': '8.8.0',
 'os_name': 'posix',
 'platform': 'macOS-10.16-x86_64-i386-64bit',
 'sys_executable': '/Users/huazhou/opt/anaconda3/bin/python3',
 'sys_platform': 'darwin',
 'sys_version': '3.9.12 (main, Apr  5 2022, 01:56:13) \n[Clang 12.0.0 ]'}

Imports and function definitions:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications import vgg19

# Fetch the content image (Paris) and the style reference image (Starry Night).
base_image_path = keras.utils.get_file(
    "paris.jpg", "https://i.imgur.com/F28w3Ac.jpg"
)
style_reference_image_path = keras.utils.get_file("starry_night.jpg", "https://i.imgur.com/9ooB60I.jpg")
# Filename prefix for the snapshots written during optimization.
result_prefix = "paris_generated"

# Relative weights of the three loss terms.
total_variation_weight, style_weight, content_weight = 1e-6, 1e-6, 2.5e-8

# Output size: height fixed at 400 rows, width scaled so the generated
# picture keeps the base image's aspect ratio.
width, height = keras.preprocessing.image.load_img(base_image_path).size
img_nrows = 400
img_ncols = int(width * img_nrows / height)

1 Display base (content) image and the style reference image

from tensorflow_docs.vis import embed

# Display the base (content) image followed by the style reference image.
# NOTE(review): by default Jupyter auto-displays only the value of a cell's
# last expression; if the first image does not render, wrap that call in
# IPython.display.display(...) — confirm in the target environment.
embed.embed_file(base_image_path)
embed.embed_file(style_reference_image_path)

2 Image preprocessing / deprocessing utilities

def preprocess_image(image_path):
    """Open an image file and turn it into a VGG19-ready input tensor.

    Steps: load and resize to (img_nrows, img_ncols), convert to an array,
    prepend a batch axis of size 1, then apply ``vgg19.preprocess_input``.

    Returns:
        A tf.Tensor whose leading (batch) dimension has size 1.
    """
    size = (img_nrows, img_ncols)
    pixels = keras.preprocessing.image.img_to_array(
        keras.preprocessing.image.load_img(image_path, target_size = size)
    )
    batch = np.expand_dims(pixels, axis = 0)
    return tf.convert_to_tensor(vgg19.preprocess_input(batch))


def deprocess_image(x, nrows=None, ncols=None):
    """Convert a VGG19-preprocessed image array back to a displayable RGB image.

    Undoes the preprocessing: adds back the per-channel mean pixel values,
    flips channel order from BGR to RGB, clips to [0, 255] and casts to uint8.

    Args:
        x: Array-like holding nrows * ncols * 3 values, e.g. the
            (1, rows, cols, 3) batch produced by ``preprocess_image``.
            It is copied, so the caller's array is never modified.
        nrows: Output height; defaults to the module-level ``img_nrows``.
        ncols: Output width; defaults to the module-level ``img_ncols``.

    Returns:
        A (nrows, ncols, 3) uint8 RGB array.
    """
    if nrows is None:
        nrows = img_nrows
    if ncols is None:
        ncols = img_ncols
    # np.array copies; a bare x.reshape usually returns a view, and the
    # in-place additions below would then corrupt the caller's array.
    x = np.array(x).reshape((nrows, ncols, 3))
    # Remove zero-center by mean pixel (channel order here is still BGR).
    for channel, mean in enumerate((103.939, 116.779, 123.68)):
        x[:, :, channel] += mean
    # 'BGR'->'RGB'
    x = x[:, :, ::-1]
    return np.clip(x, 0, 255).astype("uint8")

3 Compute the style transfer loss

First, we need to define 4 utility functions:

  • gram_matrix (used to compute the style loss).

  • The style_loss function, which keeps the generated image close to the local textures of the style reference image.

  • The content_loss function, which keeps the high-level representation of the generated image close to that of the base image.

  • The total_variation_loss function, a regularization loss which keeps the generated image locally coherent.

def gram_matrix(x):
    """Return the Gram matrix (feature-wise outer product) of one feature map.

    Axis 2 of ``x`` is treated as the channel axis (channels-last, one image
    as returned per batch entry by ``feature_extractor`` — assumed rank 3).
    The result has shape (channels, channels); entry (i, j) is the inner
    product of the flattened activations of channels i and j.
    """
    channels_first = tf.transpose(x, perm=(2, 0, 1))
    n_channels = tf.shape(channels_first)[0]
    flat = tf.reshape(channels_first, (n_channels, -1))
    return tf.matmul(flat, tf.transpose(flat))

def style_loss(style, combination):
    """Style loss between a style-reference feature map and a generated one.

    Keeps the style of the reference image in the generated image by
    comparing the Gram matrices (which capture style) of the two feature
    maps: the summed squared difference is divided by
    4 * channels**2 * size**2.

    NOTE(review): ``channels`` is fixed at 3 and ``size`` at
    img_nrows * img_ncols for every layer, rather than each layer's own
    filter count and feature-map area (presumably the intent, cf. Gatys et
    al.). It only rescales the term, but changing it would require retuning
    ``style_weight``.
    """
    gram_style = gram_matrix(style)
    gram_combo = gram_matrix(combination)
    channels = 3
    size = img_nrows * img_ncols
    denom = 4.0 * (channels**2) * (size**2)
    return tf.reduce_sum(tf.square(gram_style - gram_combo)) / denom

def content_loss(base, combination):
    """Content loss: sum of squared differences between two feature maps.

    An auxiliary term that keeps the high-level "content" of the base image
    present in the generated image.
    """
    residual = combination - base
    return tf.reduce_sum(tf.square(residual))

def total_variation_loss(x):
    """Total variation loss: the third loss term, keeping the image locally coherent.

    For every pixel except those in the last row and last column, take the
    squared difference to the pixel below (``a``) and to the pixel on the
    right (``b``). Raise ``a + b`` to the power 1.25 and sum over everything.

    Args:
        x: Image batch indexed as (batch, rows, cols, channels); the slicing
            below fixes rank 4.

    Returns:
        Scalar tensor.

    The spatial extent comes from ``x`` itself, so any image size works. For a
    (1, img_nrows, img_ncols, 3) input this matches slicing by the
    module-level img_nrows / img_ncols exactly.
    """
    # Top-left (rows-1) x (cols-1) window, compared with the same window
    # shifted down one row and shifted right one column.
    corner = x[:, :-1, :-1, :]
    a = tf.square(corner - x[:, 1:, :-1, :])
    b = tf.square(corner - x[:, :-1, 1:, :])
    return tf.reduce_sum(tf.pow(a + b, 1.25))

Next, let’s create a feature extraction model that retrieves the intermediate activations of VGG19 (as a dict, by name).

# Build a VGG19 model loaded with pre-trained ImageNet weights; include_top=False
# drops the fully connected classifier, keeping only the convolutional layers.
model = vgg19.VGG19(weights = "imagenet", include_top = False)

# Map each layer's name (e.g. "block1_conv1") to its symbolic output tensor.
# The names are assigned by Keras' VGG19 definition and are unique within the
# model, so every layer gets its own key.
outputs_dict = {layer.name: layer.output for layer in model.layers}

# Set up a model that returns the activation values for every layer in
# VGG19 (as a dict keyed by layer name).
feature_extractor = keras.Model(inputs = model.inputs, outputs = outputs_dict)

Finally, here’s the code that computes the style transfer loss.

# The layer whose activations define "content" (used by the content loss).
content_layer_name = "block5_conv2"
# The layers whose Gram matrices define "style" (used by the style loss):
# the first convolution of each of the five VGG19 blocks.
style_layer_names = ["block1_conv1", "block2_conv1", "block3_conv1",
                     "block4_conv1", "block5_conv1"]

def compute_loss(combination_image, base_image, style_reference_image):
    """Total style-transfer loss for the current combination image.

    The loss is the sum of three terms:
      * content_weight * content loss at ``content_layer_name``;
      * style loss at each of ``style_layer_names``, each weighted
        style_weight / len(style_layer_names);
      * total_variation_weight * total variation of ``combination_image``.
    """
    # One forward pass over a batch of three images:
    # index 0 = base, 1 = style reference, 2 = combination.
    batch = tf.concat(
        [base_image, style_reference_image, combination_image], axis=0
    )
    activations = feature_extractor(batch)
    base_idx, style_idx, combo_idx = 0, 1, 2

    total = tf.zeros(shape=())

    # Content term.
    content_act = activations[content_layer_name]
    total = total + content_weight * content_loss(
        content_act[base_idx, :, :, :], content_act[combo_idx, :, :, :]
    )

    # Style terms, equally weighted so they sum to style_weight.
    per_layer_weight = style_weight / len(style_layer_names)
    for name in style_layer_names:
        act = activations[name]
        total += per_layer_weight * style_loss(
            act[style_idx, :, :, :], act[combo_idx, :, :, :]
        )

    # Regularizer on the image being optimized.
    total += total_variation_weight * total_variation_loss(combination_image)
    return total

4 Add a tf.function decorator to loss & gradient computation

Decorating the function with tf.function compiles it into a TensorFlow graph, which makes repeated calls fast.

@tf.function
def compute_loss_and_grads(combination_image, base_image, style_reference_image):
    """Return ``(loss, gradient of loss w.r.t. combination_image)``.

    Decorated with tf.function so TensorFlow traces it into a graph and
    reuses that graph for calls with matching input signatures.
    """
    # A GradientTape watches trainable tf.Variables automatically, so no
    # explicit tape.watch() is needed when combination_image is one.
    with tf.GradientTape() as tape:
        loss = compute_loss(combination_image, base_image, style_reference_image)
    return loss, tape.gradient(loss, combination_image)

5 The training loop

Repeatedly run vanilla gradient descent steps to minimize the loss, printing the loss and saving the current image every 10 iterations.

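The learning rate decays exponentially by a factor of 0.96 every 100 steps. The decay is applied smoothly at every step (lr = 100 × 0.96^(step/100)), not in discrete jumps.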

# Plain SGD with an exponentially decaying learning rate:
# lr(step) = 100 * 0.96 ** (step / 100), decaying smoothly at every step.
optimizer = keras.optimizers.SGD(
    keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate = 100.0, decay_steps = 100, decay_rate = 0.96
    )
)

base_image = preprocess_image(base_image_path)
style_reference_image = preprocess_image(style_reference_image_path)
# Start optimizing from the content image. Reuse the already-preprocessed
# tensor instead of loading the file a third time; tensors are immutable and
# the variable holds its own state, so updates never touch base_image.
combination_image = tf.Variable(base_image)

iterations = 100
save_every = 10  # print the loss and save a snapshot this often
for i in range(1, iterations + 1):
    loss, grads = compute_loss_and_grads(
        combination_image, base_image, style_reference_image
    )
    # apply_gradients returns the optimizer's iteration counter. Binding it
    # to `_` stops the notebook from echoing a tf.Variable on every step.
    _ = optimizer.apply_gradients([(grads, combination_image)])
    if i % save_every == 0:
        print("Iteration %d: loss=%.2f" % (i, loss))
        img = deprocess_image(combination_image.numpy())
        fname = result_prefix + "_at_iteration_%d.png" % i
        keras.preprocessing.image.save_img(fname, img)
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=1>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=2>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=3>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=4>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=5>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=6>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=7>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=8>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=9>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=10>
Iteration 10: loss=28361.54
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=11>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=12>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=13>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=14>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=15>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=16>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=17>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=18>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=19>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=20>
Iteration 20: loss=21659.45
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=21>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=22>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=23>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=24>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=25>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=26>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=27>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=28>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=29>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=30>
Iteration 30: loss=18439.31
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=31>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=32>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=33>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=34>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=35>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=36>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=37>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=38>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=39>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=40>
Iteration 40: loss=16359.35
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=41>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=42>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=43>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=44>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=45>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=46>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=47>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=48>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=49>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=50>
Iteration 50: loss=14863.73
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=51>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=52>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=53>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=54>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=55>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=56>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=57>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=58>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=59>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=60>
Iteration 60: loss=13727.03
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=61>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=62>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=63>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=64>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=65>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=66>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=67>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=68>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=69>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=70>
Iteration 70: loss=12832.42
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=71>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=72>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=73>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=74>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=75>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=76>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=77>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=78>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=79>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=80>
Iteration 80: loss=12110.48
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=81>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=82>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=83>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=84>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=85>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=86>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=87>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=88>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=89>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=90>
Iteration 90: loss=11515.97
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=91>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=92>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=93>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=94>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=95>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=96>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=97>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=98>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=99>
<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=100>
Iteration 100: loss=11018.29

After 100 iterations, we get the following result:

# Display the snapshot saved at the final (100th) iteration.
final_image_path = result_prefix + "_at_iteration_100.png"
embed.embed_file(final_image_path)