For the purpose of learning I created a minimal DDIM for the MNIST dataset. Everything besides the math of diffusions I consider "extras."
Here is my list:
- U-Net (removed - replaced with something simpler)
- Positional embeddings (removed - part of unet)
- Diffusion Schedule (added it back in case it helps)
- Normalization of the dataset (left it in there for now)
The reason for a minimal example is because I do not understand the contribution of these other tricks. Therefore, If I start with something simpler - I can see the contribution of additional optimizations.
I expected to see some pictures that resemble a number but I do not. Loss goes down very slowly but not good enough.
I may have a bug or "what I have" it s just not enough. What would it take to make this minimal example to barely work? Any help would be greatly appreciated
The code is borrowed from this great Keras example: https://keras.io/examples/generative/ddim/
Here is my working code:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os
print("tf version: ", tf.__version__)
# data
diffusion_steps = 20
image_size = 28
# sampling
min_signal_rate = 0.02
max_signal_rate = 0.95
# optimization
batch_size = 64
num_epochs = 1000
learning_rate = 1e-3
x0 = tf.keras.Input(shape=(28, 28, 1))
t0 = tf.keras.Input(shape=(1, 1, 1))
combined = tf.keras.layers.Add()([x0, t0])
x = tf.keras.layers.Flatten()(combined)
x = tf.keras.layers.Dense(7 * 7 * 64, activation="relu")(x)
x = tf.keras.layers.Reshape((7, 7, 64))(x)
x = tf.keras.layers.Conv2DTranspose(
64, 3, activation="relu", strides=2, padding="same"
)(x)
x = tf.keras.layers.Conv2DTranspose(
32, 3, activation="relu", strides=2, padding="same"
)(x)
output = tf.keras.layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
network = tf.keras.Model(inputs=[x0, t0], outputs=output)
# print(network.summary())
class DiffusionModel(tf.keras.Model):
def __init__(self, network):
super().__init__()
self.normalizer = tf.keras.layers.Normalization()
self.network = network
def compile(self, **kwargs):
super().compile(**kwargs)
self.noise_loss_tracker = tf.keras.metrics.Mean(name="n_loss")
self.image_loss_tracker = tf.keras.metrics.Mean(name="i_loss")
@property
def metrics(self):
return [self.noise_loss_tracker, self.image_loss_tracker]
def denormalize(self, images):
return tf.clip_by_value(images, 0.0, 1.0)
def diffusion_schedule(self, diffusion_times):
# diffusion times -> angles
start_angle = tf.acos(max_signal_rate)
end_angle = tf.acos(min_signal_rate)
diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)
# angles -> signal and noise rates
signal_rates = tf.cos(diffusion_angles)
noise_rates = tf.sin(diffusion_angles)
# note that their squared sum is always: sin^2(x) + cos^2(x) = 1
return noise_rates, signal_rates
# predictive stage
def denoise(self, noisy_images, noise_rates, signal_rates, training):
# predict noise component and calculate the image component using it
pred_noises = self.network([noisy_images, noise_rates**2], training=training)
pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
return pred_noises, pred_images
def reverse_diffusion(self, initial_noise, steps):
# reverse diffusion = sampling
batch = initial_noise.shape[0]
step_size = 1.0 / steps
# important line:
# at the first sampling step, the "noisy image" is pure noise
# but its signal rate is assumed to be nonzero (min_signal_rate)
next_noisy_images = initial_noise
for step in range(diffusion_steps):
noisy_images = next_noisy_images
diffusion_times = tf.ones((batch, 1, 1, 1)) - step * step_size
noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
pred_noises, pred_images = self.denoise(
noisy_images, noise_rates, signal_rates, training=False
)
# this new noisy image will be used in the next step
next_diffusion_times = diffusion_times - step_size
next_noise_rates, next_signal_rates = self.diffusion_schedule(
next_diffusion_times
)
next_noisy_images = (
next_signal_rates * pred_images + next_noise_rates * pred_noises
)
return pred_images
def generate(self, num_images, steps):
# noise -> images -> denormalized images
initial_noise = tf.random.normal(shape=(num_images, image_size, image_size, 1))
generated_images = self.reverse_diffusion(initial_noise, steps)
generated_images = self.denormalize(generated_images)
return generated_images
def train_step(self, images):
# normalize images to have standard deviation of 1, like the noises
images = self.normalizer(images, training=True)
noises = tf.random.normal(shape=(batch_size, image_size, image_size, 1))
diffusion_times = tf.random.uniform(
shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
)
noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
# mix the images with noises accordingly
noisy_images = signal_rates * images + noise_rates * noises
with tf.GradientTape() as tape:
# train the network to separate noisy images to their components
pred_noises, pred_images = self.denoise(
noisy_images, noise_rates, signal_rates, training=True
)
noise_loss = self.loss(noises, pred_noises) # used for training
image_loss = self.loss(images, pred_images) # only used as metric
gradients = tape.gradient(noise_loss, self.network.trainable_weights)
self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))
self.noise_loss_tracker.update_state(noise_loss)
self.image_loss_tracker.update_state(image_loss)
return {m.name: m.result() for m in self.metrics}
def plot_images(
self,
epoch=None,
logs=None,
num_rows=3,
num_cols=6,
write_to_file=True,
output_dir="output",
):
# plot random generated images for visual evaluation of generation quality
generated_images = self.generate(
num_images=num_rows * num_cols,
steps=diffusion_steps,
)
plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
for row in range(num_rows):
for col in range(num_cols):
index = row * num_cols + col
plt.subplot(num_rows, num_cols, index + 1)
plt.imshow(generated_images[index])
plt.axis("off")
plt.tight_layout()
if write_to_file:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
if epoch is not None:
filename = os.path.join(
output_dir, "image_epoch_{:04d}.png".format(epoch)
)
else:
import time
timestr = time.strftime("%Y%m%d-%H%M%S")
filename = os.path.join(output_dir, "image_{}.png".format(timestr))
plt.savefig(filename)
else:
plt.show()
plt.close()
# create and compile the model
model = DiffusionModel(network)
# below tensorflow 2.9:
# pip install tensorflow_addons
# import tensorflow_addons as tfa
# optimizer=tfa.optimizers.AdamW
model.compile(
optimizer=tf.keras.optimizers.experimental.AdamW(learning_rate=learning_rate),
loss=tf.keras.losses.mean_absolute_error,
)
# pixelwise mean absolute error is used as loss
# save the best model based on the noise loss
checkpoint_path = "checkpoints/diffusion_model"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
save_weights_only=True,
monitor="i_loss",
mode="min",
save_best_only=True,
)
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255
dataset = tf.data.Dataset.from_tensor_slices(mnist_digits)
dataset = dataset.batch(batch_size, drop_remainder=True)
# calculate mean and variance of training dataset for normalization
model.normalizer.adapt(mnist_digits)
# run training and plot generated images periodically
model.fit(
dataset,
epochs=num_epochs,
batch_size=batch_size,
callbacks=[
tf.keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
checkpoint_callback,
],
)
# load the best model and generate images
model.load_weights(checkpoint_path)
model.plot_images(write_to_file=False)
Edit
- Removed commented out block as pointed out by @xdurch0
Unfortunately, still no luck. Just to clarify, to answer this question one can either provide a network that is simpler than u-net that we can recognize some digits, or explain why we need u-net.