[Half precision] Make sure half-precision is correct (huggingface#182)
* [Half precision] Make sure half-precision is correct

* Update src/diffusers/models/unet_2d.py

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

* correct some tests

* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>

* finalize

* finish

Co-authored-by: Suraj Patil <[email protected]>
patrickvonplaten and patil-suraj authored Aug 16, 2022
1 parent 8b71fd3 commit 0020569
Showing 6 changed files with 51 additions and 19 deletions.
6 changes: 3 additions & 3 deletions models/embeddings.py
@@ -32,10 +32,10 @@ def get_timestep_embedding(
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
exponent = exponent / (half_dim - downscale_freq_shift)

emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift)
emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
emb = torch.exp(emb * emb_coeff)
emb = torch.exp(exponent).to(device=timesteps.device)
emb = timesteps[:, None].float() * emb[None, :]

# scale embeddings
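
A note on the embedding change above: the sinusoidal frequencies are now built from a float32 `exponent` tensor and only moved to the timesteps' device afterwards, so they are never computed in half precision. A minimal sketch of the same idea (a simplified standalone function for illustration, not the library's exact `get_timestep_embedding` signature):

import math
import torch

def sinusoidal_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor:
    # frequencies are computed in float32 even if the surrounding model runs in fp16
    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(half_dim, dtype=torch.float32)
    emb = torch.exp(exponent / half_dim).to(device=timesteps.device)
    emb = timesteps[:, None].float() * emb[None, :]
    return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

# e.g. four integer timesteps mapped to 128-dim embeddings
print(sinusoidal_timestep_embedding(torch.tensor([1, 10, 100, 999]), 128).shape)  # torch.Size([4, 128])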
8 changes: 6 additions & 2 deletions models/resnet.py
@@ -331,7 +331,9 @@ def __init__(
def forward(self, x, temb, hey=False):
h = x

h = self.norm1(h)
# make sure hidden states is in float32
# when running in half-precision
h = self.norm1(h.float()).type(h.dtype)
h = self.nonlinearity(h)

if self.upsample is not None:
@@ -347,7 +349,9 @@ def forward(self, x, temb, hey=False):
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
h = h + temb

h = self.norm2(h)
# make sure hidden states is in float32
# when running in half-precision
h = self.norm2(h.float()).type(h.dtype)
h = self.nonlinearity(h)

h = self.dropout(h)
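
The `h.float()` / `.type(h.dtype)` pattern above exists because GroupNorm statistics computed directly in fp16 can lose precision, so the normalization is evaluated in float32 and the result is cast back to the activation dtype. A small sketch of the pattern in isolation (with the norm module's own parameters kept in float32 for this example):

import torch
import torch.nn as nn

norm = nn.GroupNorm(num_groups=32, num_channels=64)  # parameters stay in float32 in this sketch

def norm_in_fp32(norm: nn.GroupNorm, hidden: torch.Tensor) -> torch.Tensor:
    # upcast, normalize, then cast back so the rest of the block keeps running in fp16
    return norm(hidden.float()).type(hidden.dtype)

hidden = torch.randn(2, 64, 8, 8).half()
print(norm_in_fp32(norm, hidden).dtype)  # torch.float16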
7 changes: 6 additions & 1 deletion models/unet_2d.py
@@ -132,6 +132,9 @@ def forward(
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)

# broadcast to batch dimension
timesteps = timesteps.broadcast_to(sample.shape[0])

t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb)

@@ -166,7 +169,9 @@ def forward(
sample = upsample_block(sample, res_samples, emb)

# 6. post-process
sample = self.conv_norm_out(sample)
# make sure hidden states is in float32
# when running in half-precision
sample = self.conv_norm_out(sample.float()).type(sample.dtype)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
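
The new `broadcast_to` call ensures that a scalar (0-d) timestep is expanded to one entry per sample before the time projection runs, so the resulting embedding always has a batch dimension matching `sample`. Roughly:

import torch

sample = torch.randn(4, 3, 32, 32)  # a batch of 4 noisy samples
timesteps = torch.tensor(10)        # a single 0-d timestep

timesteps = timesteps[None].to(sample.device)        # shape (1,)
timesteps = timesteps.broadcast_to(sample.shape[0])  # shape (4,): tensor([10, 10, 10, 10])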

8 changes: 6 additions & 2 deletions models/unet_2d_condition.py
@@ -133,6 +133,9 @@ def forward(
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)

# broadcast to batch dimension
timesteps = timesteps.broadcast_to(sample.shape[0])

t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb)

@@ -172,8 +175,9 @@ def forward(
sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)

# 6. post-process

sample = self.conv_norm_out(sample)
# make sure hidden states is in float32
# when running in half-precision
sample = self.conv_norm_out(sample.float()).type(sample.dtype)
sample = self.conv_act(sample)
sample = self.conv_out(sample)

26 changes: 19 additions & 7 deletions pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -55,7 +55,13 @@ def __call__(
self.text_encoder.to(torch_device)

# get prompt text embeddings
text_input = self.tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
@@ -79,19 +85,25 @@ def __call__(
latents = torch.randn(
(batch_size, self.unet.in_channels, height // 8, width // 8),
generator=generator,
device=torch_device,
)
latents = latents.to(torch_device)

# set timesteps
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
extra_set_kwargs = {}
if accepts_offset:
extra_set_kwargs["offset"] = 1

self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_kwargs = {}
extra_step_kwargs = {}
if accepts_eta:
extra_kwargs["eta"] = eta

self.scheduler.set_timesteps(num_inference_steps)
extra_step_kwargs["eta"] = eta

for t in tqdm(self.scheduler.timesteps):
# expand the latents if we are doing classifier free guidance
@@ -106,7 +118,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
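
Two things happen in the pipeline changes above: the prompt is now always padded to `tokenizer.model_max_length`, so the conditional and unconditional embeddings used for classifier-free guidance share the same sequence length, and optional scheduler keyword arguments (`offset` for `set_timesteps`, `eta` for `step`) are probed with `inspect.signature` before being passed, because not every scheduler accepts them. A rough sketch of that probing pattern, with a hypothetical `DummyScheduler` standing in for a real one:

import inspect

class DummyScheduler:
    def set_timesteps(self, num_inference_steps, offset=0):
        self.timesteps = list(range(offset, num_inference_steps + offset))[::-1]

scheduler = DummyScheduler()

# only forward `offset` if this scheduler's set_timesteps actually accepts it
extra_set_kwargs = {}
if "offset" in inspect.signature(scheduler.set_timesteps).parameters:
    extra_set_kwargs["offset"] = 1

scheduler.set_timesteps(50, **extra_set_kwargs)
print(scheduler.timesteps[:3])  # [50, 49, 48]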
15 changes: 11 additions & 4 deletions schedulers/scheduling_ddim.py
@@ -59,6 +59,7 @@ def __init__(
trained_betas=None,
timestep_values=None,
clip_sample=True,
clip_alpha_at_one=True,
tensor_format="pt",
):

@@ -75,7 +76,12 @@ def __init__(

self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.one = np.array(1.0)

# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `clip_alpha_at_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = np.array(1.0) if clip_alpha_at_one else self.alphas_cumprod[0]

# setable values
self.num_inference_steps = None
@@ -86,19 +92,20 @@ def __init__(

def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

return variance

def set_timesteps(self, num_inference_steps):
def set_timesteps(self, num_inference_steps, offset=0):
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
)[::-1].copy()
self.timesteps += offset
self.set_format(tensor_format=self.tensor_format)

def step(
@@ -126,7 +133,7 @@ def step(

# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t

# 3. compute predicted original sample from predicted noise also called
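
The scheduler changes add two knobs: `clip_alpha_at_one` picks which cumulative alpha stands in for the "previous" value at the very last denoising step (where `prev_timestep < 0` and there is no real previous alphas_cumprod), and `set_timesteps` gains an `offset` so the inference schedule can be shifted, e.g. by 1 for Stable Diffusion. A quick numeric illustration of the new timestep schedule (plain NumPy mirroring the logic, not the scheduler class itself):

import numpy as np

num_train_timesteps = 1000
num_inference_steps = 50
offset = 1

timesteps = np.arange(0, num_train_timesteps, num_train_timesteps // num_inference_steps)[::-1].copy()
timesteps += offset
print(timesteps[:3], timesteps[-3:])  # [981 961 941] [41 21 1]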