# cell 2 code:
# --------------------------------------------
# Sample many DALL-E image candidates for one prompt, rerank them with CLIP,
# and display them in a square grid (best-scoring first).
import math

import clip
import numpy as np
import torch
from matplotlib import pyplot as plt

from dalle.models import Dalle
from dalle.utils.utils import clip_score

# NOTE(review): `model`, `device`, `model_clip` and `preprocess_clip` are not
# defined in this cell; they are presumably created by an earlier notebook
# cell — confirm before running this cell on its own.

prompt = "3 armchairs in the shape of an avocado lined up on the road. 3 armchairs imitating avocados lined up on the road. Pikachu is sitting on the chair and holding an umbrella."
print(prompt)

num_candidates = 256
# Candidates requested per model.sampling call; sampling in chunks (and
# emptying the CUDA cache between them) keeps GPU memory use bounded.
batch_size = 32

images = []
torch.cuda.empty_cache()
# Step through the candidates in chunks, including a final partial chunk when
# num_candidates is not a multiple of batch_size (the previous
# int(num_candidates / 32) loop silently dropped that remainder).
for start in range(0, num_candidates, batch_size):
    count = min(batch_size, num_candidates - start)
    with torch.no_grad():
        images.append(
            model.sampling(
                prompt=prompt,
                top_k=128,
                top_p=None,
                softmax_temperature=0.7,
                num_candidates=count,
                device=device,
            ).cpu().numpy()
        )
    torch.cuda.empty_cache()

images = np.concatenate(images)
# Move axis 1 (presumably channels) to the end so each image is laid out as
# (H, W, C), which is what plt.imshow expects.
images = np.transpose(images, (0, 2, 3, 1))

with torch.no_grad():
    rank = clip_score(
        prompt=prompt,
        images=images,
        model_clip=model_clip,
        preprocess_clip=preprocess_clip,
        device=device,
    )
torch.cuda.empty_cache()
# Reorder the candidates by the indices clip_score returned.
images = images[rank]

n = num_candidates
# Smallest square grid that holds all n images. The former int(math.sqrt(n))
# made add_subplot raise for non-square n; for n = 256 both give 16.
grid = math.ceil(math.sqrt(n))
fig = plt.figure(figsize=(6 * grid, 6 * grid))
for i in range(n):
    ax = fig.add_subplot(grid, grid, i + 1)
    ax.imshow(images[i])
    ax.set_axis_off()
plt.tight_layout()
plt.show()
# ------------------------------------------
# Artificial General Intelligence List: AGI
# Permalink: https://agi.topicbox.com/groups/agi/T4ae89526c13077bd-M2e348435a58056df2eb279c3
# Delivery options: https://agi.topicbox.com/groups/agi/subscription
