cell 2 code:
--------------------------------------------
import math

import clip
import numpy as np
import torch
from matplotlib import pyplot as plt

from dalle.models import Dalle
from dalle.utils.utils import clip_score


# Text prompt used both for image sampling and for CLIP reranking.
# Adjacent string literals inside parentheses are joined at compile time, so
# the prompt is a single line of text while the source stays readable; a plain
# "..." literal cannot be continued across physical lines.
prompt = ("3 armchairs in the shape of an avocado lined up on the road. "
          "3 armchairs imitating avocados lined up on the road. "
          "Pikachu is sitting on the chair and holding an umbrella.")
print(prompt)

# Sample `num_candidates` images for `prompt`, `batch_size` at a time, and
# collect them on the host as one NumPy array.
# NOTE(review): `model` and `device` are not defined in this cell; presumably
# they are created in an earlier notebook cell -- confirm before reuse.
num_candidates = 256
batch_size = 32  # candidates requested per model.sampling call
images = []

torch.cuda.empty_cache()

with torch.no_grad():  # inference only; no autograd graph is needed
    # Step by batch_size and shrink the final batch so every requested
    # candidate is generated. A plain int(num_candidates / batch_size) batch
    # count would silently drop the remainder, and would produce zero batches
    # (and a failing np.concatenate) when num_candidates < batch_size.
    for start in range(0, num_candidates, batch_size):
        count = min(batch_size, num_candidates - start)
        images.append(model.sampling(prompt=prompt,
                                     top_k=128,
                                     top_p=None,
                                     softmax_temperature=0.7,
                                     num_candidates=count,
                                     device=device).cpu().numpy())
        # Release cached, currently unused GPU memory between batches.
        torch.cuda.empty_cache()

images = np.concatenate(images)
# Move axis 1 to the end -- presumably (N, C, H, W) -> (N, H, W, C) so each
# image can be passed straight to plt.imshow; confirm against model.sampling.
images = np.transpose(images, (0, 2, 3, 1))

# Score the sampled candidates against the prompt with CLIP and reorder
# `images` by the result of clip_score.
# NOTE(review): model_clip, preprocess_clip and device are not defined in this
# cell (presumably loaded in an earlier cell). `rank` is presumably an index
# array whose ordering (e.g. best-first) is defined by clip_score -- confirm.
with torch.no_grad():
    rank = clip_score(
        images=images,
        prompt=prompt,
        device=device,
        model_clip=model_clip,
        preprocess_clip=preprocess_clip,
    )
images = images[rank]

# Drop GPU memory left cached by the scoring pass.
torch.cuda.empty_cache()

# Show the images (in their current, CLIP-reranked order) on a near-square
# grid, one tile per image. Never index past the images actually produced.
n = min(num_candidates, len(images))

# Round the grid side up instead of truncating it: int(math.sqrt(n)) yields a
# grid with fewer than n cells whenever n is not a perfect square, and
# add_subplot then raises for the overflow indices. For n = 256 both give 16x16.
cols = max(1, math.ceil(math.sqrt(n)))
rows = max(1, math.ceil(n / cols))

tile_inches = 6  # figure size per tile, in inches (matplotlib figsize unit)
fig = plt.figure(figsize=(tile_inches * cols, tile_inches * rows))
for i in range(n):
    ax = fig.add_subplot(rows, cols, i + 1)
    ax.imshow(images[i])
    ax.set_axis_off()

plt.tight_layout()
plt.show()
------------------------------------------
Artificial General Intelligence List: AGI
Permalink: https://agi.topicbox.com/groups/agi/T4ae89526c13077bd-M2e348435a58056df2eb279c3
Delivery options: https://agi.topicbox.com/groups/agi/subscription

Reply via email to