diff --git a/CaptionsDataset.py b/CaptionsDataset.py new file mode 100644 index 0000000..af9fbf7 --- /dev/null +++ b/CaptionsDataset.py @@ -0,0 +1,145 @@ +import clip +import torch +import pickle +import random +from typing import Tuple +from torch.utils.data import Dataset +from transformers import AutoTokenizer +from utils import parse_entities, padding_captions +from load_annotations import load_entities_text, load_stopwords + +class CaptionsDataset(Dataset): + + def __init__( + self, + language_model: str = 'gpt2', + max_num_of_entities: int = 5, + using_clip_features: bool = False, + path_of_datasets: str = './annotations/coco/coco_with_entities.pickle', + debug: bool = False, + args = None, + ) -> None: + """ + Args: + language_model: the used tokenizer + max_num_of_entities: the maximum number of entities (nouns) detected in a single sentence + using_clip_features: loading pre-extracted clip text embeddings + path_of_datasets: the path of training datasets, i.e., ./annotations/***/*** + """ + + # initializing + tokenizer = AutoTokenizer.from_pretrained(language_model) + self.using_clip_features = using_clip_features + + # the format of dataset (List[List[List, str]]): + # [[['baby', 'giraffe', 'wall', 'zoo', 'environment'], 'A baby giraffe standing against a wall in a zoo like environment.'], ...] + # or (using_clip_features = True): + # [[['baby', 'giraffe', 'wall', 'zoo', 'environment'], + # A baby giraffe standing against a wall in a zoo like environment.', + # torch.tensor (size = (clip_hidden_size, ))], ...] + with open(path_of_datasets, 'rb') as infile: # loading datasets + captions_with_entities = pickle.load(infile) + + # low-data settings + if args.few_shot_ratio < 1.0: + random.shuffle(captions_with_entities) + N = len(captions_with_entities) * args.few_shot_ratio + captions_with_entities = captions_with_entities[: int(N)] + + if debug: # debug + captions_with_entities = captions_with_entities[:500] + + captions_lm_lengths = [] + self.detected_entities = [] + self.captions = [] + self.captions_lm_tokens = [] + if self.using_clip_features: + self.captions_clip_features = [] + else: + self.captions_clip_tokens = [] + + for caption_with_entities in captions_with_entities: + if self.using_clip_features: + temp_detected_entities, temp_caption, temp_clip_features = caption_with_entities + self.captions_clip_features.append(temp_clip_features) # dtype = float16, size = (clip_hidden_size, ) + else: + temp_detected_entities, temp_caption = caption_with_entities + self.captions_clip_tokens.append(clip.tokenize(temp_caption, truncate = True).squeeze(dim = 0)) # dtype = int32, size = (77, ) + self.captions.append(temp_caption) + self.detected_entities.append(temp_detected_entities[:max_num_of_entities]) + + # captions_lm_tokens are used for auto-regressive training, while captions_clip_tokens are accounted as image features during text-only training + self.captions_lm_tokens.append(torch.tensor(tokenizer.encode(temp_caption), dtype = torch.int64)) # dtype = int64, size = (n_seq,) + captions_lm_lengths.append(len(self.captions_lm_tokens[-1])) + + self.captions_lm_lengths = torch.tensor(captions_lm_lengths, dtype = torch.float32) + self.max_length_per_caption = min(int(self.captions_lm_lengths.mean() + 10 * self.captions_lm_lengths.std()), int(self.captions_lm_lengths.max())) + self.args = args + self.tokenizer = tokenizer + self.stopwords = load_stopwords() + + self.people_vocabs = ['people', 'person', 'man', 'men', 'woman', 'women', 'adult','boy', 'girl', 'kid', 'children', 'child', 
'baby', 'guy', 'player', 'male', 'female', 'worker'] + self.objects_vocabs = load_entities_text(args.name_of_objects_vocabs, args.path_of_objects_vocabs, all_entities = False) + print('Dataset Loading: {} successful. Max sentence length: {}'.format(path_of_datasets, self.max_length_per_caption)) + + def __len__(self) -> int: + # return the size of this dataset + return len(self.captions) + + def pad_tokens(self, item: int) -> Tuple[torch.Tensor, ...]: + """ + Return: + tokens: tensor with a shape of (n_seq, ), padding 0 or truncating caption tokens to n_seq + mask: tensor with a shape of (n_seq, ), valid texts for attention computing + """ + tokens = self.captions_lm_tokens[item] # caption tokens + padding = self.max_length_per_caption - len(tokens) + tokens = tokens[:self.max_length_per_caption] # truncating tokens to max_seq_len + if padding > 0: # padding 0 to max_seq_len + tokens = torch.cat((tokens, torch.zeros(padding, dtype = torch.int64) - 1)) + + mask = tokens.ge(0) + tokens[~mask] = 0 # when calculating loss, the position where idx = 0 should be ignored + mask = mask.float() + + return tokens, mask + + def __getitem__(self, item: int) -> Tuple[torch.Tensor, ...]: + """ + Return: + captions_clip: tensor with a shape of (clip_hidden_size, ) for clip features, or (77) clip tokens + captions_gpt_tokens: tensor with a shape of (n_seq, ), the caption tokens encoded by language model + masks: tensor with a shape of (n_seq, ), valid texts for attention computing + discrete_tokens: tensor with a shape of (len_discrete_tokens + len_prompt_templates, ) + """ + caption_lm_tokens, mask = self.pad_tokens(item) + + if self.using_clip_features: + captions_clip = self.captions_clip_features[item] + else: + captions_clip = self.captions_clip_tokens[item] + + detected_entities = self.detected_entities[item] + masks = mask + captions_gpt_tokens = caption_lm_tokens + + discrete_tokens = None + if self.args.using_hard_prompt: + discrete_tokens = parse_entities(self.args, self.tokenizer, [detected_entities], self.stopwords, self.people_vocabs, self.objects_vocabs)[0] + return self.args, captions_clip, captions_gpt_tokens, masks, discrete_tokens + + +def collate(batch): + batch_size = len(batch) + args = batch[0][0] + _, captions_clip, captions_gpt_tokens, masks, discrete_tokens = zip(*batch) + captions_clip = torch.stack(captions_clip) + captions_gpt_tokens = torch.stack(captions_gpt_tokens, dim=0) + masks = torch.stack(masks) + + hard_prompts_length = None + if args.using_hard_prompt: + captions_gpt_tokens, captions_tokens_for_loss, masks, hard_prompts_length = padding_captions(args, captions_gpt_tokens, masks, discrete_tokens) + else: + captions_gpt_tokens, captions_tokens_for_loss, masks = padding_captions(args, captions_gpt_tokens, masks) + return captions_clip, captions_gpt_tokens, captions_tokens_for_loss, masks, hard_prompts_length \ No newline at end of file diff --git a/ClipCap.py b/ClipCap.py new file mode 100644 index 0000000..32b1520 --- /dev/null +++ b/ClipCap.py @@ -0,0 +1,251 @@ +import torch +import torch.nn as nn +import torch.nn.functional as nnf +from typing import Tuple, Optional, List +from transformers import GPT2LMHeadModel + +class MlpTransformer(nn.Module): + + def __init__( + self, + input_size: int, # the input size of mlp + hidden_size: int, # the hidden layer size of mlp + output_size: Optional[int] = None, # the output size of mlp + act = nnf.relu, + dropout: float = 0.0 + ) -> None: + super().__init__() + output_size = output_size if output_size is not None else 
input_size + self.fc1 = nn.Linear(input_size, hidden_size) + self.act = act + self.fc2 = nn.Linear(hidden_size, output_size) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + +class MultiHeadAttention(nn.Module): + + def __init__( + self, + query_size: int, + key_value_size: int, + num_heads: int, + bias = True, + dropout: float = 0.0 + ) -> None: + super(MultiHeadAttention, self).__init__() + self.num_heads = num_heads + self.head_size = query_size // num_heads # the size of each head + self.scale = self.head_size ** -0.5 # normalization factor for each head + self.to_queries = nn.Linear(query_size, query_size, bias = bias) + # projecting key and value together and spliting them for computing efficiently + self.to_keys_values = nn.Linear(key_value_size, 2 * query_size, bias = bias) + self.project = nn.Linear(query_size, query_size) + self.dropout = nn.Dropout(dropout) + + def forward(self, query: torch.Tensor, key_value: torch.Tensor = None, mask: torch.Tensor = None) -> torch.Tensor: + key_value = key_value if key_value is not None else query + b, n, d_query = query.shape + _, m, _ = key_value.shape + queries = self.to_queries(query).reshape(b, n, self.num_heads, self.head_size) # (batch_size, n_seq, num_heads, head_size) + keys_values = self.to_keys_values(key_value).reshape(b, m, 2, self.num_heads, self.head_size) # (batch_size, m_seq, 2, num_heads, head_size) + keys, values = keys_values[:, :, 0], keys_values[:, :, 1] # (batch_size, m_seq, num_heads, head_size), (batch_size, m_seq, num_heads, head_size) + attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale # (batch_size, n_seq, m_seq, num_heads) + + if mask is not None: + if mask.dim() == 2: + mask = mask.unsqueeze(dim = 1) # expending dimension, shape: (batch_size, 1, m_seq) + attention = attention.masked_fill(mask.unsqueeze(dim = 3), float("-inf")) # expending dimension n_seq head and fill -inf according to mask + + attention = attention.softmax(dim = 2) # softmax alongside the dimension of key_value pairs + outputs = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, d_query) # (batch_size, n_seq, d_query) + outputs = self.project(outputs) + return outputs, attention + +class TransformerLayer(nn.Module): + + def __init__( + self, + query_size: int, + key_value_size: int, + num_heads: int, + mlp_ratio = 4.0, + bias = False, + dropout: float = 0.0, + act = nnf.relu, + norm_layer: nn.Module = nn.LayerNorm + ) -> None: + super(TransformerLayer, self).__init__() + self.norm1 = norm_layer(query_size) + self.attn = MultiHeadAttention(query_size, key_value_size, num_heads, bias = bias, dropout = dropout) + self.norm2 = norm_layer(query_size) + self.mlp = MlpTransformer(query_size, int(query_size * mlp_ratio), act = act, dropout = dropout) + + def forward(self, query: torch.Tensor, key_value: torch.Tensor = None, mask: torch.Tensor = None) -> torch.Tensor: + query_, self.attention = self.attn(self.norm1(query), key_value, mask) + query = query + query_ + query = query + self.mlp(self.norm2(query)) + return query + +class Transformer(nn.Module): + + def __init__( + self, + query_size: int, # query size + num_layers: int, # number of layer + num_heads: int, # number of head + key_value_size: Optional[int] = None, # key/value size + mlp_ratio: float = 2.0, # ratio for hidden size in mlp + act = nnf.relu, # activation + norm_layer: nn.Module = nn.LayerNorm # 
normalization + ) -> None: + super(Transformer, self).__init__() + key_value_size = key_value_size if key_value_size is not None else query_size + layers = [] + for _ in range(num_layers): + layers.append(TransformerLayer(query_size, key_value_size, num_heads, mlp_ratio = mlp_ratio, act = act, norm_layer = norm_layer)) + self.layers = nn.Sequential(*layers) + + def forward(self, query: torch.Tensor, key_value: torch.Tensor = None, mask: torch.Tensor = None) -> torch.Tensor: + self.attentions = [] + for layer in self.layers: + query = layer(query, key_value, mask) + self.attentions.append(layer.attention) + return query + +class MappingNetwork(nn.Module): + + def __init__( + self, + clip_project_length: int, + clip_hidden_size: int, + prefix_length: int, + d_model: int, # the hidden size of language model + num_layers: int = 8, + num_heads: int = 8 + ) -> None: + super(MappingNetwork, self).__init__() + self.clip_project_length = clip_project_length + # projector for input + self.linear = nn.Linear(clip_hidden_size, clip_project_length * d_model) + # learnable prefix embeddings + self.prefix_const = nn.Parameter(torch.randn(prefix_length, d_model), requires_grad = True) + self.transformer = Transformer(d_model, num_layers, num_heads) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: clip cls feature with a shape of (batch_size, clip_hidden_size) + Return: + the embeddings of prefix with the shape of (batch_size, prefix_length, d_model) + """ + x = self.linear(x).view(x.shape[0], self.clip_project_length, -1) # (b, clip_project_length, d_model) + prefix = self.prefix_const.unsqueeze(dim = 0).expand(x.shape[0], *self.prefix_const.shape) # (b, prefix_length, d_model) + inputs = torch.cat((x, prefix), dim = 1) # (b, clip_project_length + prefix_length, d_model) + outputs = self.transformer(inputs)[:,self.clip_project_length:,:] # (b, prefix_length, d_model) + + return outputs + +def get_language_mode(lm_type): + if 'gpt' in lm_type: + model = GPT2LMHeadModel.from_pretrained(lm_type) + hidden_size = model.config.hidden_size + elif 'opt' in lm_type: + from modeling_opt import OPTForCausalLM + model = OPTForCausalLM.from_pretrained(lm_type, torch_dtype = torch.float16) + hidden_size = model.config.word_embed_proj_dim + return model, hidden_size + +class ClipCaptionModel(nn.Module): + + def __init__( + self, + continuous_length: int = 10, + clip_project_length: int = 10, + clip_hidden_size: int = 512, + num_layers: int = 8, + num_heads: int = 8, + gpt_type: str = 'gpt2', + soft_prompt_first: bool = False, + only_hard_prompt: bool = False + ) -> None: + """ + Args: + continuous_length: the length of soft prompts which will be fed into language model as continuous part + clip_project_length: clip cls features (b, 1, d) -> (b, n, d) + clip_hidden_size: the dimensions of CLIP features + num_layers: the number of layer in projector + num_heads: the number of heads each layer + gpt_type: the language model + soft_prompt_first: False -> hard prompt + soft prompt; True -> soft prompt + hard prompt + only_hard_prompt: using the hard prompts only + """ + super(ClipCaptionModel, self).__init__() + self.soft_prompt_first = soft_prompt_first + self.only_hard_prompt = only_hard_prompt + self.continuous_length = continuous_length + self.gpt, self.gpt_hidden_size = get_language_mode(gpt_type) + self.mapping_network = MappingNetwork(clip_project_length, clip_hidden_size, continuous_length, self.gpt_hidden_size, num_layers, num_heads) + self.gpt_type = gpt_type + + def word_embed(self, 
caption_tokens): + if 'gpt' in self.gpt_type: + caption_embeddings = self.gpt.transformer.wte(caption_tokens) # (b, caption_length, gpt_hidden_size) + elif 'opt' in self.gpt_type: + caption_embeddings = self.gpt.model.decoder.embed_tokens(caption_tokens) + return caption_embeddings + + def forward( + self, + continuous_prompt: torch.Tensor, + caption_tokens: torch.Tensor, + hard_prompts_length: Optional[List] = None, + mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, ...]: + """ + Args: + continuous_prompt: tensor with a shape of (b, clip_hidden_size), in text-only training, the caption features are eaxtracted from CLIP and used as image features + caption_tokens: caption tokens with a shape of (b, max_length_per_caption) + hard_prompts_length: list with len = batch size, the length of hard prompts constructed for each caption + mask: tensor with a shape of (b, discrete_length + continuous_length + max_length_per_caption), valid texts for attention computing + Return: + the output of language model + """ + caption_embeddings = self.word_embed(caption_tokens) + continuous_embeddings = self.mapping_network(continuous_prompt).view(-1, self.continuous_length, self.gpt_hidden_size) # (b, continuous_length, gpt_hidden_size) + if hard_prompts_length is not None: # with hard prompts + if self.only_hard_prompt: + embeddings = caption_embeddings + elif self.soft_prompt_first: # soft prompts + hard prompts + embeddings = torch.cat((continuous_embeddings, caption_embeddings), dim = 1) + else: # hard prompts + soft prompts + embeddings = None + for i in range(len(hard_prompts_length)): + length = hard_prompts_length[i] + temp_embeddings = torch.cat((caption_embeddings[i][:length], continuous_embeddings[i], caption_embeddings[i][length:]), dim = 0).unsqueeze(dim = 0) + if embeddings is None: + embeddings = temp_embeddings + else: + embeddings = torch.cat((embeddings, temp_embeddings), dim = 0) + else: # without hard prompts + embeddings = torch.cat((continuous_embeddings, caption_embeddings), dim = 1) # (b, continuous_length + caption_length, gpt_hidden_size) + + out = self.gpt(inputs_embeds = embeddings.type(self.gpt.dtype), attention_mask = mask) + + return out + +class ClipCaptionPrefix(ClipCaptionModel): + + def parameters(self, recurse: bool = True): + return self.mapping_network.parameters() + + def train(self, mode: bool = True): + super(ClipCaptionPrefix, self).train(mode) + self.gpt.eval() + return self \ No newline at end of file diff --git a/README.md b/README.md index 28df2a1..fedf6e3 100644 --- a/README.md +++ b/README.md @@ -1,3 +1 @@ -# ViECap - -the code will be released in two weeks +The instruction of this repo will be released as soon as possible. 
\ No newline at end of file diff --git a/images/honkai1.jpg b/images/honkai1.jpg new file mode 100644 index 0000000..feef6d4 Binary files /dev/null and b/images/honkai1.jpg differ diff --git a/images/honkai2.jpg b/images/honkai2.jpg new file mode 100644 index 0000000..cd6af13 Binary files /dev/null and b/images/honkai2.jpg differ diff --git a/images/honkai3.jpg b/images/honkai3.jpg new file mode 100644 index 0000000..dad8828 Binary files /dev/null and b/images/honkai3.jpg differ diff --git a/images/honkai4.jpg b/images/honkai4.jpg new file mode 100644 index 0000000..53f0d8b Binary files /dev/null and b/images/honkai4.jpg differ diff --git a/images/nocaps1.jpg b/images/nocaps1.jpg new file mode 100644 index 0000000..3f67452 Binary files /dev/null and b/images/nocaps1.jpg differ diff --git a/images/nocaps2.jpg b/images/nocaps2.jpg new file mode 100644 index 0000000..4ec7bbe Binary files /dev/null and b/images/nocaps2.jpg differ diff --git a/images/nocaps3.jpg b/images/nocaps3.jpg new file mode 100644 index 0000000..e2b8f9d Binary files /dev/null and b/images/nocaps3.jpg differ diff --git a/images/nocaps4.jpg b/images/nocaps4.jpg new file mode 100644 index 0000000..9a95719 Binary files /dev/null and b/images/nocaps4.jpg differ diff --git a/images/nocaps5.jpg b/images/nocaps5.jpg new file mode 100644 index 0000000..773b4a3 Binary files /dev/null and b/images/nocaps5.jpg differ diff --git a/images/nocaps6.jpg b/images/nocaps6.jpg new file mode 100644 index 0000000..1ae003d Binary files /dev/null and b/images/nocaps6.jpg differ diff --git a/images/predictions.json b/images/predictions.json new file mode 100644 index 0000000..f1da8fb --- /dev/null +++ b/images/predictions.json @@ -0,0 +1,42 @@ +[ + { + "image_name": "honkai1.jpg", + "prediction": "A group of people standing around a table with a cake." + }, + { + "image_name": "honkai2.jpg", + "prediction": "A group of people that are posing for a picture." + }, + { + "image_name": "honkai3.jpg", + "prediction": "A woman that is standing in the dark." + }, + { + "image_name": "honkai4.jpg", + "prediction": "A beautiful woman sitting on top of a wooden bench." + }, + { + "image_name": "nocaps1.jpg", + "prediction": "A blue jay is perched on a tree branch." + }, + { + "image_name": "nocaps2.jpg", + "prediction": "A centipede is eating from a tree branch." + }, + { + "image_name": "nocaps3.jpg", + "prediction": "A group of dolphins are swimming in the water." + }, + { + "image_name": "nocaps4.jpg", + "prediction": "A gray and brown raccoon is sitting on a rock." + }, + { + "image_name": "nocaps5.jpg", + "prediction": "A white mechanical fan hanging from a ceiling." + }, + { + "image_name": "nocaps6.jpg", + "prediction": "A close up of a pair of sunglasses on a table." 
+ } +] \ No newline at end of file diff --git a/infer_by_batch.py b/infer_by_batch.py new file mode 100644 index 0000000..a22d375 --- /dev/null +++ b/infer_by_batch.py @@ -0,0 +1,137 @@ +import os +import json +import clip +import torch +import argparse +from tqdm import tqdm +from PIL import Image +from ClipCap import ClipCaptionModel +from transformers import AutoTokenizer +from utils import compose_discrete_prompts +from load_annotations import load_entities_text +from search import greedy_search, beam_search, opt_search +from retrieval_categories import clip_texts_embeddings, image_text_simiarlity, top_k_categories + +@torch.no_grad() +def main(args) -> None: + # initializing + device = args.device + clip_name = args.clip_model.replace('/', '') + clip_hidden_size = 640 if 'RN' in args.clip_model else 512 + + # loading categories vocabulary for objects + if args.name_of_entities_text == 'visual_genome_entities': + entities_text = load_entities_text(args.name_of_entities_text, './annotations/vocabulary/all_objects_attributes_relationships.pickle', not args.disable_all_entities) + if args.prompt_ensemble: # loading ensemble embeddings + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/visual_genome_embedding_{clip_name}_with_ensemble.pickle') + else: + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/visual_genome_embedding_{clip_name}.pickle') + elif args.name_of_entities_text == 'coco_entities': + entities_text = load_entities_text(args.name_of_entities_text, './annotations/vocabulary/coco_categories.json', not args.disable_all_entities) + if args.prompt_ensemble: + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/coco_embeddings_{clip_name}_with_ensemble.pickle') + else: + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/coco_embeddings_{clip_name}.pickle') + elif args.name_of_entities_text == 'open_image_entities': + entities_text = load_entities_text(args.name_of_entities_text, './annotations/vocabulary/oidv7-class-descriptions-boxable.csv', not args.disable_all_entities) + if args.prompt_ensemble: + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/open_image_embeddings_{clip_name}_with_ensemble.pickle') + else: + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/open_image_embeddings_{clip_name}.pickle') + elif args.name_of_entities_text == 'vinvl_vg_entities': + entities_text = load_entities_text(args.name_of_entities_text, './annotations/vocabulary/VG-SGG-dicts-vgoi6-clipped.json', not args.disable_all_entities) + if args.prompt_ensemble: + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/vg_embeddings_{clip_name}_with_ensemble.pickle') + else: + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/vg_embeddings_{clip_name}.pickle') + elif args.name_of_entities_text == 'vinvl_vgoi_entities': + entities_text = load_entities_text(args.name_of_entities_text, './annotations/vocabulary/vgcocooiobjects_v1_class2ind.json', not args.disable_all_entities) + if args.prompt_ensemble: + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/vgoi_embeddings_{clip_name}_with_ensemble.pickle') + else: + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/vgoi_embeddings_{clip_name}.pickle') + else: + print('The entities text should be input correctly!') + return + + # loading 
model + tokenizer = AutoTokenizer.from_pretrained(args.language_model) + model = ClipCaptionModel(args.continuous_prompt_length, args.clip_project_length, clip_hidden_size, gpt_type = args.language_model) + model.load_state_dict(torch.load(args.weight_path, map_location = device)) + model.to(device) + encoder, preprocess = clip.load(args.clip_model, device = device) + + images_list = os.listdir(args.image_path) + images_list.sort() + images_list = [os.path.join(args.image_path, image) for image in images_list] + + predicts = [] + for _, im_path in tqdm(enumerate(images_list)): + try: + image = preprocess(Image.open(im_path)).unsqueeze(dim = 0).to(device) + except: + continue + image_features = encoder.encode_image(image).float() # (1, clip_hidden_size) + image_features /= image_features.norm(2, dim = -1, keepdim = True) + continuous_embeddings = model.mapping_network(image_features).view(-1, args.continuous_prompt_length, model.gpt_hidden_size) + if args.using_hard_prompt: + logits = image_text_simiarlity(texts_embeddings, temperature = args.temperature, images_features = image_features) + detected_objects, _ = top_k_categories(entities_text, logits, args.top_k, args.threshold) # List[List[]], [[category1, category2, ...], [], ...] + detected_objects = detected_objects[0] # infering single image -> List[category1, category2, ...] + discrete_tokens = compose_discrete_prompts(tokenizer, detected_objects).unsqueeze(dim = 0).to(args.device) + discrete_embeddings = model.word_embed(discrete_tokens) + if args.only_hard_prompt: + embeddings = discrete_embeddings + elif args.soft_prompt_first: + embeddings = torch.cat((continuous_embeddings, discrete_embeddings), dim = 1) + else: + embeddings = torch.cat((discrete_embeddings, continuous_embeddings), dim = 1) + else: + embeddings = continuous_embeddings + + if 'gpt' in args.language_model: + if not args.using_greedy_search: + sentence = beam_search(embeddings = embeddings, tokenizer = tokenizer, beam_width = args.beam_width, model = model.gpt) # List[str] + sentence = sentence[0] # selected top 1 + else: + sentence = greedy_search(embeddings = embeddings, tokenizer = tokenizer, model = model.gpt) + else: + sentence = opt_search(prompts=args.text_prompt, embeddings = embeddings, tokenizer = tokenizer, beam_width = args.beam_width, model = model.gpt) + sentence=sentence[0] + + _, im_path = os.path.split(im_path) + predict = {} + predict["image_name"] = im_path + predict["prediction"] = sentence + predicts.append(predict) + + outpath = os.path.join(args.image_path, 'predictions.json') + with open(outpath, 'w') as outfile: + json.dump(predicts, outfile, indent = 4) + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--device', default = 'cuda:0') + parser.add_argument('--clip_model', default = 'ViT-B/32') + parser.add_argument('--language_model', default = 'gpt2') + parser.add_argument('--continuous_prompt_length', type = int, default = 10) + parser.add_argument('--clip_project_length', type = int, default = 10) + parser.add_argument('--temperature', type = float, default = 0.01) + parser.add_argument('--top_k', type = int, default = 3) + parser.add_argument('--threshold', type = float, default = 0.2) + parser.add_argument('--disable_all_entities', action = 'store_true', default = False, help = 'whether to use entities with a single word only') + parser.add_argument('--name_of_entities_text', default = 'vinvl_vgoi_entities', choices = ('visual_genome_entities', 'coco_entities', 'open_image_entities', 
'vinvl_vg_entities', 'vinvl_vgoi_entities')) + parser.add_argument('--prompt_ensemble', action = 'store_true', default = False) + parser.add_argument('--weight_path', default = './checkpoints/train_coco/coco_prefix-0014.pt') + parser.add_argument('--image_path', default = './images') + parser.add_argument('--using_hard_prompt', action = 'store_true', default = False) + parser.add_argument('--soft_prompt_first', action = 'store_true', default = False) + parser.add_argument('--only_hard_prompt', action = 'store_true', default = False) + parser.add_argument('--using_greedy_search', action = 'store_true', default = False, help = 'greedy search or beam search') + parser.add_argument('--beam_width', type = int, default = 5, help = 'width of beam') + parser.add_argument('--text_prompt', type = str, default = None) + args = parser.parse_args() + print('args: {}\n'.format(vars(args))) + + main(args) \ No newline at end of file diff --git a/infer_by_instance.py b/infer_by_instance.py new file mode 100644 index 0000000..5e5a9bc --- /dev/null +++ b/infer_by_instance.py @@ -0,0 +1,118 @@ +import clip +import torch +import argparse +from PIL import Image +from ClipCap import ClipCaptionModel +from transformers import AutoTokenizer +from utils import compose_discrete_prompts +from load_annotations import load_entities_text +from search import greedy_search, beam_search, opt_search +from retrieval_categories import clip_texts_embeddings, image_text_simiarlity, top_k_categories + +@torch.no_grad() +def main(args) -> None: + # initializing + device = args.device + clip_name = args.clip_model.replace('/', '') + clip_hidden_size = 640 if 'RN' in args.clip_model else 512 + + # loading categories vocabulary for objects + if args.name_of_entities_text == 'visual_genome_entities': + entities_text = load_entities_text(args.name_of_entities_text, './annotations/vocabulary/all_objects_attributes_relationships.pickle', not args.disable_all_entities) + if args.prompt_ensemble: # loading ensemble embeddings + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/visual_genome_embedding_{clip_name}_with_ensemble.pickle') + else: + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/visual_genome_embedding_{clip_name}.pickle') + elif args.name_of_entities_text == 'coco_entities': + entities_text = load_entities_text(args.name_of_entities_text, './annotations/vocabulary/coco_categories.json', not args.disable_all_entities) + if args.prompt_ensemble: + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/coco_embeddings_{clip_name}_with_ensemble.pickle') + else: + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/coco_embeddings_{clip_name}.pickle') + elif args.name_of_entities_text == 'open_image_entities': + entities_text = load_entities_text(args.name_of_entities_text, './annotations/vocabulary/oidv7-class-descriptions-boxable.csv', not args.disable_all_entities) + if args.prompt_ensemble: + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/open_image_embeddings_{clip_name}_with_ensemble.pickle') + else: + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/open_image_embeddings_{clip_name}.pickle') + elif args.name_of_entities_text == 'vinvl_vg_entities': + entities_text = load_entities_text(args.name_of_entities_text, './annotations/vocabulary/VG-SGG-dicts-vgoi6-clipped.json', not args.disable_all_entities) + if args.prompt_ensemble: 
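+ # NOTE: clip_texts_embeddings is expected to load (or build and cache) one CLIP text embedding per
+ # entity name; the "*_with_ensemble" pickles presumably average the embeddings of several prompt
+ # templates (e.g. "a photo of a {}.") per entity, as in standard CLIP zero-shot practice. The exact
+ # template set is defined in retrieval_categories, not here.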
+ texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/vg_embeddings_{clip_name}_with_ensemble.pickle') + else: + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/vg_embeddings_{clip_name}.pickle') + elif args.name_of_entities_text == 'vinvl_vgoi_entities': + entities_text = load_entities_text(args.name_of_entities_text, './annotations/vocabulary/vgcocooiobjects_v1_class2ind.json', not args.disable_all_entities) + if args.prompt_ensemble: + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/vgoi_embeddings_{clip_name}_with_ensemble.pickle') + else: + texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/vgoi_embeddings_{clip_name}.pickle') + else: + print('The entities text should be input correctly!') + return + + # loading model + tokenizer = AutoTokenizer.from_pretrained(args.language_model) + model = ClipCaptionModel(args.continuous_prompt_length, args.clip_project_length, clip_hidden_size, gpt_type = args.language_model) + model.load_state_dict(torch.load(args.weight_path, map_location = device)) + model.to(device) + encoder, preprocess = clip.load(args.clip_model, device = device) + + image = preprocess(Image.open(args.image_path)).unsqueeze(dim = 0).to(device) + image_features = encoder.encode_image(image).float() + image_features /= image_features.norm(2, dim = -1, keepdim = True) + continuous_embeddings = model.mapping_network(image_features).view(-1, args.continuous_prompt_length, model.gpt_hidden_size) + if args.using_hard_prompt: + logits = image_text_simiarlity(texts_embeddings, temperature = args.temperature, images_features = image_features) + detected_objects, _ = top_k_categories(entities_text, logits, args.top_k, args.threshold) # List[List[]], [[category1, category2, ...], [], ...] + detected_objects = detected_objects[0] # infering single image -> List[category1, category2, ...] 
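+ # NOTE: compose_discrete_prompts is assumed to fold the retrieved categories into a fixed template
+ # (something like "There are <category1>, <category2> in the image.") and to tokenize it with the
+ # language-model tokenizer; after unsqueeze the tensor has shape (1, len_discrete_tokens) and is
+ # embedded with model.word_embed below.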
+ discrete_tokens = compose_discrete_prompts(tokenizer, detected_objects).unsqueeze(dim = 0).to(args.device) + + discrete_embeddings = model.word_embed(discrete_tokens) + if args.only_hard_prompt: + embeddings = discrete_embeddings + elif args.soft_prompt_first: + embeddings = torch.cat((continuous_embeddings, discrete_embeddings), dim = 1) + else: + embeddings = torch.cat((discrete_embeddings, continuous_embeddings), dim = 1) + else: + embeddings = continuous_embeddings + + if 'gpt' in args.language_model: + if not args.using_greedy_search: + sentence = beam_search(embeddings = embeddings, tokenizer = tokenizer, beam_width = args.beam_width, model = model.gpt) # List[str] + sentence = sentence[0] # selected top 1 + else: + sentence = greedy_search(embeddings = embeddings, tokenizer = tokenizer, model = model.gpt) + else: + sentence = opt_search(prompts=args.text_prompt, embeddings = embeddings, tokenizer = tokenizer, beam_width = args.beam_width, model = model.gpt) + sentence=sentence[0] + + print(f'the generated caption: {sentence}') + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--device', default = 'cuda:0') + parser.add_argument('--clip_model', default = 'ViT-B/32') + parser.add_argument('--language_model', default = 'gpt2') + parser.add_argument('--continuous_prompt_length', type = int, default = 10) + parser.add_argument('--clip_project_length', type = int, default = 10) + parser.add_argument('--temperature', type = float, default = 0.01) + parser.add_argument('--top_k', type = int, default = 3) + parser.add_argument('--threshold', type = float, default = 0.2) + parser.add_argument('--disable_all_entities', action = 'store_true', default = False, help = 'whether to use entities with a single word only') + parser.add_argument('--name_of_entities_text', default = 'vinvl_vgoi_entities', choices = ('visual_genome_entities', 'coco_entities', 'open_image_entities', 'vinvl_vg_entities', 'vinvl_vgoi_entities')) + parser.add_argument('--prompt_ensemble', action = 'store_true', default = False) + parser.add_argument('--weight_path', default = './checkpoints/train_coco/coco_prefix-0014.pt') + parser.add_argument('--image_path', default = './images/') + parser.add_argument('--using_hard_prompt', action = 'store_true', default = False) + parser.add_argument('--soft_prompt_first', action = 'store_true', default = False) + parser.add_argument('--only_hard_prompt', action = 'store_true', default = False) + parser.add_argument('--using_greedy_search', action = 'store_true', default = False, help = 'greedy search or beam search') + parser.add_argument('--beam_width', type = int, default = 5, help = 'width of beam') + parser.add_argument('--text_prompt', type = str, default = None) + args = parser.parse_args() + print('args: {}\n'.format(vars(args))) + + main(args) \ No newline at end of file diff --git a/load_annotations.py b/load_annotations.py new file mode 100644 index 0000000..7eb5296 --- /dev/null +++ b/load_annotations.py @@ -0,0 +1,213 @@ +import json +import pickle +import pandas as pd +from typing import List + +def load_coco_captions(path: str) -> List[str]: + + with open(path, 'r') as infile: + annotations = json.load(infile) # dictionary -> {image_path: List[caption1, caption2, ...]} + punctuations = ['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', ' ', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~'] + + captions = [] + for image_path in annotations: + temp_captions = 
annotations[image_path] # List: [caption1, caption2, ...], captions for the ith image + for caption in temp_captions: # caption + caption = caption.strip() # removing space at the end of the caption + if caption.isupper(): # processing the special caption in the COCO Caption, e.g., 'A BOY IS PLAYING BASEBALL.' + caption = caption.lower() + caption = caption[0].upper() + caption[1:] # capitalizing the first letter in the caption + if caption[-1] not in punctuations: # adding a '.' at the end of the caption if there are no punctuations. + caption += '.' + captions.append(caption) # final versin: A boy is playing baseball. + + return captions + +def load_flickr30k_captions(path: str) -> List[str]: + + with open(path, 'r') as infile: + annotations = json.load(infile) # dictionary -> {image_path: List[caption1, caption2, ...]} + punctuations = ['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', ' ', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~'] + + captions = [] + for image_path in annotations: + temp_captions = annotations[image_path] + for caption in temp_captions: + caption = caption.strip() + if caption.isupper(): + caption = caption.lower() + caption = caption[0].upper() + caption[1:] + if caption[-1] not in punctuations: + caption += '.' + captions.append(caption) + + return captions + +def load_captions(name_of_datasets: str, path_of_datasets: str) -> List[str]: + """ + Args: + name_of_datasets: specifying the name of datasets + path_of_datasets: specifying the path of datasets + Return: + [caption1, caption2, ...] + """ + if name_of_datasets == 'coco_captions': + return load_coco_captions(path_of_datasets) + + if name_of_datasets == 'flickr30k_captions': + return load_flickr30k_captions(path_of_datasets) + + print('The datasets for training fail to load!') + +def load_stopwords() -> List[str]: + # Return: stopwords and punctuations + + stopwords = {'per', '’ll', 'could', 'fifteen', 'been', "isn't", 'whoever', 'any', 'whole', 'front', "won't", 'upon', 'there', 's', 'am', 'via', 'the', 'as', "haven't", 'on', 'km', 'further', 'their', 'quite', 'have', 'twenty', 'during', 'full', 'it', 'thin', 'so', 'what', 'an', 't', 'less', 'if', 'sixty', 'everyone', 'us', 'were', 'side', 'she', 'cannot', 'thereby', '‘ve', 'amount', 'n’t', 'be', 'nine', 'isn', 'wouldn', 'by', 'along', "'ll", 'themselves', 'forty', 'everywhere', "'d", 'thru', 'sometimes', 'hasnt', 'seeming', 'own', 'that', "'ve", 'least', 'with', 'inc', 'really', 'afterwards', 'due', 'for', 'sometime', 'last', 'find', 'therein', 'all', 'thick', 'detail', 'few', 'hundred', 'some', 'even', 'off', '’m', 'ain', '’re', 'hence', 'etc', 'into', 'rather', 'where', 'm', 'its', 'onto', '’s', 'get', 'other', 'moreover', 'noone', 'being', 'must', 'bill', "wasn't", 'system', 'neither', "you'll", 'third', 'whereby', 'nobody', 'among', 'throughout', 'except', 'beforehand', "didn't", 'was', 'without', 'whose', 'hasn', '‘d', 'or', 'theirs', 'various', 'name', 'twelve', 'myself', 'former', 'though', 'we', 'ours', 'many', 'sincere', 'regarding', 'had', 'before', 'mustn', 'either', 'doing', 'why', 'fill', 'eight', 'won', 'anything', 'hereupon', 'this', 'amoungst', '‘s', 'of', 'yourselves', 'beside', 'within', 'ourselves', '‘re', 'about', 'elsewhere', 'latter', 'through', 'll', 'i', 'wasn', 'anywhere', 'weren', 'just', 'itself', "you're", 'wherein', 'four', 'keep', 'whether', 'nothing', 'found', 'back', 'needn', "aren't", 'has', 'one', 'wherever', 'serious', 'everything', 'hadn', 'first', 
'anyway', 'co', 'still', 'five', 'becomes', "don't", 'formerly', 'ever', 'part', 'nowhere', 'made', 'himself', "couldn't", 'none', 'others', 'now', 'doesn', 'at', 'another', 'does', 'kg', 'see', 'often', 'them', 'shan', 'fifty', 'ltd', 'namely', 'they', 'somewhere', 'haven', 'take', 'latterly', 'well', 'whatever', 'nor', 'whereafter', 'might', 'only', 'de', 'our', 'hers', "mustn't", 'aren', 'you', 'his', "wouldn't", 'please', 'empty', 'but', 'mightn', 'then', 'should', 'and', 'each', 'such', 'a', 'yet', 'y', 'enough', 'someone', 'would', 'since', 'however', 'make', 'alone', 'anyone', 'amongst', 'these', 'whereupon', 'fire', "hasn't", 'shouldn', 'didn', 'do', 'me', 'becoming', 'after', 'several', 'seem', 'her', 'three', 'out', 'ten', 'whence', 'eg', 'couldn', 'un', 'did', "she's", 'whither', 'toward', 'once', "should've", 'call', "weren't", 'again', 'more', 'show', 'seems', "needn't", 'thereupon', 'used', 'most', 'hereby', 'put', 'ie', 've', 'my', 'your', 'thence', 'already', 'always', 'having', 'much', 'move', 'eleven', "'re", 'here', 'yours', 'con', 'done', 'up', 'over', 'yourself', "it's", 'o', 'six', 'can', 'how', "hadn't", 'anyhow', 'below', 'also', 'say', 'together', 'down', 'using', 'while', 'almost', 'cry', "you've", '’ve', 'two', 'towards', 'meanwhile', 'perhaps', 'when', 'ma', "shouldn't", 'both', 'hereafter', 'he', 'describe', 'ca', 'which', 'every', 'between', 'give', 'go', 'very', '’d', 'nevertheless', 'is', 'n‘t', 'therefore', '‘ll', 'unless', 'next', 'who', 'became', 'mill', 'him', 'don', 'same', "'s", 'seemed', 'mostly', 'will', 're', "you'd", 'no', 'in', 'too', "mightn't", 'besides', 'are', 'because', 'couldnt', 'd', 'against', "doesn't", 'cant', 'whenever', 'somehow', 'thereafter', 'although', 'beyond', 'from', 'whereas', 'thus', 'than', "shan't", 'to', 'top', 'until', 'those', 'whom', 'bottom', 'else', 'herein', 'something', '‘m', 'may', 'not', "that'll", "'m", 'indeed', 'never', 'herself', 'interest', "n't", 'become', 'mine', 'otherwise'} + punctuations = ['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', ' ', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~'] + other_words = {'photo', 'image', 'picture', 'pic', 'side', 'part', 'background'} + stopwords_and_punctuations = stopwords.union(punctuations) + stopwords_and_punctuations = stopwords_and_punctuations.union(other_words) + stopwords_and_punctuations = [stopword.lower() for stopword in stopwords_and_punctuations] + stopwords_and_punctuations.sort() + + return stopwords_and_punctuations + +def load_visual_genome_entities(path: str, all_entities: bool = True) -> List[str]: + # Visual Genome Vocabulary + + with open(path, 'rb') as infile: + all_objects_attributes_relationships = pickle.load(infile) # dictionary {'relationships': dict, 'attributes': dict, 'objects': dict} + entities = all_objects_attributes_relationships['objects'] # dictionary {'gqa': set, 'vg': set, 'joint': set}, joint = gqa + vg + entities = entities['joint'] # set + + if all_entities: + entities = [entity.lower().strip() for entity in entities] + else: + entities = [entity.lower().strip() for entity in entities if len(entity.split()) == 1] + entities.sort() # sort + + return entities + +def load_coco_entities(path: str, all_entities: bool = True) -> List[str]: + # COCO Vocabulary + + with open(path, 'r') as infile: + entities = json.load(infile) # List [category1, category2, ...] 
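+ # NOTE: coco_categories.json is assumed to be a flat JSON array of category names, e.g.
+ # ["person", "bicycle", "car", ...]; each entry is lower-cased and stripped below, and multi-word
+ # categories are dropped when all_entities is False.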
+ + if all_entities: + entities = [entity.lower().strip() for entity in entities] + else: + entities = [entity.lower().strip() for entity in entities if len(entity.split()) == 1] + entities.sort() # sort + + return entities + +def load_open_image_entities(path: str, all_entities: bool = True) -> List[str]: + # Open Image Vocabulary + + open_images = pd.read_csv(path) # 601x2, i.e., [LabelName, DisplayName] + open_image_entities = list(open_images.DisplayName) # list + + for i in range(len(open_image_entities)): + entity = open_image_entities[i].lower().strip() + if entity[-1] == ')': + entity = entity[:entity.find('(')].strip() + open_image_entities[i] = entity + + if all_entities: + entities = [entity for entity in open_image_entities] + else: + entities = [entity for entity in open_image_entities if len(entity.split()) == 1] + entities.sort() # sort + + return entities + +def load_vinvl_vg_entities(path: str, all_entities: bool = True) -> List[str]: + # VG Vocabulary + + with open(path, 'r') as infile: + annotations = json.load(infile) # dictionary = {'label_to_idx':dict,'idx_to_label':dict,'attribute_to_idx':dict,'idx_to_attribute':dict,'predicate_to_idx':dict,'idx_to_predicate':dict,'object_count':dict,'attribute_count':dict,'predicate_count':dict,} + vinvl_entities = annotations['object_count'] # dictionary = {str: int, str: int, ...} + + if all_entities: + entities = [entity.lower().strip() for entity in vinvl_entities] + else: + entities = [entity.lower().strip() for entity in vinvl_entities if len(entity.split()) == 1] + entities.sort() # sort + + return entities + +def load_vinvl_vgoi_entities(path: str, all_entities: bool = True) -> List[str]: + + with open(path, 'r') as infile: + vgoi_entities = json.load(infile) # dictionary = {str: int} + + if all_entities: + entities = [entity.lower().strip() for entity in vgoi_entities] + else: + entities = [entity.lower().strip() for entity in vgoi_entities if len(entity.split()) == 1] + entities.sort() # sort + + return entities + +def load_entities_text(name_of_entities: str, path_of_entities: str, all_entities: bool = True) -> List[str]: + """ + Args: + name_of_entities: specifying the name of entities text + path_of_entities: specifying the path of entities text + all_entities: whether to apply all entities text. True denotes using entities including len(entitites.split()) > 1 + Return: + [entity1, entity2, ...] 
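+ Example (mirroring the demo in __main__ below):
+ entities = load_entities_text('coco_entities', './annotations/vocabulary/coco_categories.json', all_entities = False)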
+ """ + if name_of_entities == 'visual_genome_entities': + return load_visual_genome_entities(path_of_entities, all_entities) + + if name_of_entities == 'coco_entities': + return load_coco_entities(path_of_entities, all_entities) + + if name_of_entities == 'open_image_entities': + return load_open_image_entities(path_of_entities, all_entities) + + if name_of_entities == 'vinvl_vg_entities': + return load_vinvl_vg_entities(path_of_entities, all_entities) + + if name_of_entities == 'vinvl_vgoi_entities': + return load_vinvl_vgoi_entities(path_of_entities, all_entities) + + print('The entities text fails to load!') + +if __name__ == '__main__': + + # loading captions + datasets = ['coco_captions', 'flickr30k_captions'] + captions_path = [ + './annotations/coco/train_captions.json', + './annotations/flickr30k/train_captions.json', + ] + captions_idx = 1 + captions = load_captions(datasets[captions_idx], captions_path[captions_idx]) + for caption in captions[:20]: + print(caption) + print(len(captions), type(captions)) + + # loading stopwords + stopwords = load_stopwords() + print('stopwords: ', stopwords[:10], type(stopwords), len(stopwords)) + + # loading entities text + entities_text = ['visual_genome_entities', 'coco_entities', 'open_image_entities', 'vinvl_vg_entities', 'vinvl_vgoi_entities'] + entities_path = [ + './annotations/vocabulary/all_objects_attributes_relationships.pickle', + './annotations/vocabulary/coco_categories.json', + './annotations/vocabulary/oidv7-class-descriptions-boxable.csv', + './annotations/vocabulary/VG-SGG-dicts-vgoi6-clipped.json', + './annotations/vocabulary/vgcocooiobjects_v1_class2ind.json' + ] + # using all entities text + entities_idx = 4 + entities = load_entities_text(entities_text[entities_idx], entities_path[entities_idx]) + print('entities text: ', entities[:10], type(entities), len(entities)) + # using entities text with a single word + entities_idx = 4 + entities = load_entities_text(entities_text[entities_idx], entities_path[entities_idx], all_entities = False) + print('entities text: ', entities[:10], type(entities), len(entities)) \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..1f5a4d3 --- /dev/null +++ b/main.py @@ -0,0 +1,168 @@ +import os +import sys +import clip +import torch +import random +import argparse +import numpy as np +from tqdm import tqdm +import torch.nn.functional as nnf +from utils import noise_injection +from CaptionsDataset import collate +from torch.utils.data import DataLoader +from CaptionsDataset import CaptionsDataset +from ClipCap import ClipCaptionModel, ClipCaptionPrefix +from transformers import AdamW, get_linear_schedule_with_warmup + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +def train( + args, # parameters used for training + datasets: CaptionsDataset, # datasets used for training + model: ClipCaptionModel, # captioning model used for training + warmup_steps: int = 5000, # warming up steps used for traing + output_dir: str = '.', # output path of the wights + output_prefix: str = '' # file prefix name of saved weights +): + device = args.device + batch_size = args.bs + epochs = args.epochs + + + # if the path of outputs does not exist, create it according to the output_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # loading model + 
model = model.to(device) + model.train() + if not args.using_clip_features: + encoder, _ = clip.load(args.clip_model, device = device) + encoder.eval() + + # method of optimization + optimizer = AdamW(model.parameters(), lr = args.lr) + dataloader = DataLoader(datasets, batch_size = batch_size, shuffle = True, drop_last = True, num_workers=args.num_workers, collate_fn=collate) + tokenizer = dataloader.dataset.tokenizer + schedular = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = warmup_steps, num_training_steps = epochs * len(dataloader)) + scaler = torch.cuda.amp.GradScaler(enabled = args.use_amp) + for epoch in range(epochs): + # visualization + sys.stdout.flush() + print(f">>> Training epoch {epoch}") + progress = tqdm(total = len(dataloader), desc = output_prefix) + train_loss_sum = 0 + # training + for idx, (captions_clip, captions_gpt_tokens, captions_tokens_for_loss, masks, hard_prompts_length) in enumerate(dataloader): + model.zero_grad() + if not args.using_clip_features: + with torch.no_grad(): + captions_clip_tokens = captions_clip.to(device) # caption_clip -> tokens, (b, 77) + continuous_prefix = encoder.encode_text(captions_clip_tokens).float() # (b, clip_hidden_size) + else: + continuous_prefix = captions_clip.to(device).float() # caption_clip -> embeddings, (b, clip_hidden_size) + + if args.normalize_prefix: + continuous_prefix /= continuous_prefix.norm(2, dim = -1, keepdim = True) + continuous_prefix = noise_injection(continuous_prefix, variance = args.noise_variance, device = args.device) + captions_gpt_tokens, captions_tokens_for_loss, masks = captions_gpt_tokens.to(device), captions_tokens_for_loss.to(device), masks.to(device) + + with torch.cuda.amp.autocast(enabled = args.use_amp): + if args.using_hard_prompt: + outputs = model(continuous_prefix, captions_gpt_tokens, hard_prompts_length, masks) + logits = outputs.logits # (batch_size, max_length, vocab_size) + else: + outputs = model(continuous_prefix, captions_gpt_tokens, mask = masks) + logits = outputs.logits # (batch_size, max_length, vocab_size) + captions_tokens_for_loss = captions_tokens_for_loss.masked_fill(captions_tokens_for_loss == tokenizer.eos_token_id, 0) + + # ignore_index = target, value: specifying a target value that is ignored and does not contribute to the input gradient + loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), captions_tokens_for_loss.flatten(), ignore_index = 0) + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + schedular.step() + optimizer.zero_grad() + progress.set_postfix({"loss": loss.item()}) + progress.update() + train_loss_sum += loss.item() + log_iters = len(dataloader)//5 if len(dataloader) > 5 else len(dataloader) + if (idx + 1) % (log_iters) == 0: + print('epoch {}, iter {}, average train loss: {}'.format(epoch, idx, train_loss_sum / log_iters)) + train_loss_sum = 0 + torch.save(model.state_dict(), os.path.join(output_dir, f"{output_prefix}_latest.pt")) + progress.close() + if (epoch+1) % args.save_every == 0 or epoch == epochs - 1: + ckpt_path = os.path.join(output_dir, f"{output_prefix}-00{epoch}.pt") + torch.save(model.state_dict(), ckpt_path) + print(f'saving checkpoint to {ckpt_path}') + +def main(): + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + parser = argparse.ArgumentParser() + parser.add_argument('--bs', type = int, default = 80, help = 'batch size') + parser.add_argument('--lr', type = float, default = 2e-5, help = 'learning rate for training') + parser.add_argument('--device', default = 'cuda:0', help = 
'gpu for training') + parser.add_argument('--epochs', type = int, default = 15, help = 'number of epochs') + parser.add_argument('--random_mask', action = 'store_true', default = False, help = 'entity masking strategy') + parser.add_argument('--prob_of_random_mask', type = float, default = 0.4, help = 'masking rate') + parser.add_argument('--clip_project_length', type = int, default = 10, help = 'clip projecting length') + parser.add_argument('--continuous_prompt_length', type = int, default = 10, help = 'soft prompts length') + parser.add_argument('--max_num_of_entities', type = int, default = 10, help = 'maximum number of detected entities') + parser.add_argument('--prompt_template_length', type = int, default = 5, help = 'maximum number of hard prompt entities') + parser.add_argument('--num_layers', type = int, default = 8, help = 'number of layer in Transformer-based projector') + parser.add_argument('--noise_variance', type = float, default = 0.016, help = 'noise variance') + parser.add_argument('--clip_model', default = 'ViT-B/32', help = "'RN50', 'RN101', 'RN50x4', 'ViT-B/32'") + parser.add_argument('--using_clip_features', action = 'store_true', default = False, help = 'whether to use the pre-extracted features') + parser.add_argument('--is_rn', dest = 'is_rn', action = 'store_true', default = False, help = 'CLIP backbone: True -> ResNet, False -> ViT') + parser.add_argument('--language_model', default = 'gpt2', help = 'gpt2, facebook/opt-350m') + parser.add_argument('--using_hard_prompt', action = 'store_true', default = False, help = 'whether to entity-aware hard prompts') + parser.add_argument('--soft_prompt_first', action = 'store_true', default = False, help = 'True -> soft prompt first, i.e., soft prompt + hard prompt') + parser.add_argument('--only_hard_prompt', action = 'store_true', default = False, help = 'True -> do not use soft prompts in this case') + parser.add_argument('--debug', action = 'store_true', default = False, help = 'debug = True means using a smaller dataloader') + parser.add_argument('--few_shot_ratio', type = float, default = 1.0, help = 'measuring the low-data setting') + parser.add_argument('--save_every', type = int, default = 1, help = 'save weights every n epochs') + parser.add_argument('--prefix', default = 'coco_prefix', help = 'prefix name for saved weights') + parser.add_argument('--path_of_datasets', default = './annotations/coco/coco_with_entities.pickle') + parser.add_argument('--out_dir', default = './checkpoints', help = 'the path of output') + parser.add_argument('--normalize_prefix', dest = 'normalize_prefix', type = int, default = True, help = 'normalizing prefix') + parser.add_argument('--name_of_objects_vocabs', default = 'visual_genome_entities') + parser.add_argument('--path_of_objects_vocabs', default = './annotations/vocabulary/all_objects_attributes_relationships.pickle') + parser.add_argument('--frozen_gpt', action = 'store_true', default = False, help = 'freezing language models during training') + parser.add_argument('--num_workers', type = int, default = 0) + parser.add_argument('--use_amp', action = 'store_true', default = False, help = "whether to use torch.amp to acclerate training") + parser.add_argument('--disable_random_seed', action = 'store_true', default = False, help = 'set random seed for reproducing') + parser.add_argument('--random_seed', type = int, default = 30, help = 'set random seed for reproducing') + + args = parser.parse_args() + print(f'args: {vars(args)}') + if not args.disable_random_seed: + 
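+ # NOTE: set_seed (defined above) seeds Python's random module, numpy and torch (CPU and CUDA) and
+ # sets cudnn.deterministic = True / benchmark = False, trading some throughput for reproducible runs.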
set_seed(args.random_seed) + + clip_hidden_size = 640 if args.is_rn else 512 + + datasets = CaptionsDataset( + language_model = args.language_model, + max_num_of_entities = args.max_num_of_entities, + using_clip_features = args.using_clip_features, + path_of_datasets = args.path_of_datasets, + debug = args.debug, + args = args + ) + if args.frozen_gpt: + model = ClipCaptionPrefix(args.continuous_prompt_length, args.clip_project_length, clip_hidden_size, args.num_layers, gpt_type = args.language_model, soft_prompt_first = args.soft_prompt_first, only_hard_prompt = args.only_hard_prompt) + else: + model = ClipCaptionModel(args.continuous_prompt_length, args.clip_project_length, clip_hidden_size, args.num_layers, gpt_type = args.language_model, soft_prompt_first = args.soft_prompt_first, only_hard_prompt = args.only_hard_prompt) + + train(args, datasets, model, output_dir = args.out_dir, output_prefix = args.prefix) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/modeling_opt.py b/modeling_opt.py new file mode 100644 index 0000000..2aca5b3 --- /dev/null +++ b/modeling_opt.py @@ -0,0 +1,1113 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch OPT model.""" +import random +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.models.opt.configuration_opt import OPTConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/opt-350m" +_CONFIG_FOR_DOC = "OPTConfig" +_TOKENIZER_FOR_DOC = "GPT2Tokenizer" + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc" +_SEQ_CLASS_EXPECTED_LOSS = 1.71 +_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" + +# QuestionAnswering docstring +_QA_EXPECTED_OUTPUT = "'a nice puppet'" +_QA_EXPECTED_LOSS = 7.41 +_QA_TARGET_START_INDEX = 14 +_QA_TARGET_END_INDEX = 15 + +OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/opt-125m", + "facebook/opt-350m", + "facebook/opt-1.3b", + "facebook/opt-2.7b", + "facebook/opt-6.7b", + "facebook/opt-13b", + "facebook/opt-30b", + # See all OPT models at https://huggingface.co/models?filter=opt +] + + +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. 
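+ Note: despite the wording above, this builds the standard causal (uni-directional) decoder mask;
+ the returned tensor has shape (bsz, 1, tgt_len, tgt_len + past_key_values_length), with 0 for every
+ key position a query may attend to (all past keys plus current and earlier positions in the window)
+ and dtype-min for future positions.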
+ """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1 + ) + return mask[None, None, :, :].expand( + bsz, 1, tgt_len, tgt_len + past_key_values_length + ) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, attention_mask: torch.LongTensor, past_key_values_length: int = 0 + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + attention_mask = attention_mask.long() + + # create positions depending on attention_mask + positions = ( + torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask + ).long() - 1 + + # cut positions if `past_key_values_length` is > 0 + positions = positions[:, past_key_values_length:] + + return super().forward(positions + self.offset) + + +class OPTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + attention_mask + ) + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 + if attn_weights.dtype == torch.float16: + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(torch.float16) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights_reshaped.view( + bsz * self.num_heads, tgt_len, src_len + ) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. 
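+        # (bsz, tgt_len, num_heads, head_dim) -> (bsz, tgt_len, embed_dim): merge the attention heads back before the output projection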
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class OPTDecoderLayer(nn.Module): + def __init__(self, config: OPTConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = OPTAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.do_layer_norm_before = config.do_layer_norm_before + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim) + self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +OPT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`OPTConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTPreTrainedModel(PreTrainedModel): + + config_class = OPTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["OPTDecoderLayer"] + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (OPTDecoder)): + module.gradient_checkpointing = value + + +OPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class OPTDecoder(OPTPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`] + + Args: + config: OPTConfig + """ + + def __init__(self, config: OPTConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.word_embed_proj_dim, self.padding_idx + ) + self.embed_positions = OPTLearnedPositionalEmbedding( + config.max_position_embeddings, config.hidden_size + ) + + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = nn.Linear( + config.hidden_size, config.word_embed_proj_dim, bias=False + ) + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = nn.Linear( + config.word_embed_proj_dim, config.hidden_size, bias=False + ) + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm(config.hidden_size) + else: + self.final_layer_norm = None + + self.layers = nn.ModuleList( + [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + past_key_values_length=past_key_values_length, + ).to(inputs_embeds.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + 
).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
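+            query_embeds (`torch.FloatTensor`, *optional*): prefix embeddings specific to this modified copy; when given they are concatenated in front of `inputs_embeds` along the sequence dimension before decoding.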
+ """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if query_embeds is not None: + inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1) + input_shape = inputs_embeds.size()[:-1] + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device + ) + pos_embeds = self.embed_positions(attention_mask, past_key_values_length) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + ) + else: + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTModel(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.decoder = OPTDecoder(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + 
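+            # query_embeds are forwarded to the decoder, which prepends them to the token embeddings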
inputs_embeds=inputs_embeds, + query_embeds=query_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPast( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + ) + + +class OPTForCausalLM(OPTPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = OPTModel(config) + + # the lm_head weight is automatically tied to the embed tokens weight + self.lm_head = nn.Linear( + config.word_embed_proj_dim, config.vocab_size, bias=False + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + reduction: Optional[str] = "mean", + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import GPT2Tokenizer, OPTForCausalLM + + >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") + >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m") + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." 
+ ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + query_embeds=query_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + if labels is not None: + logits = logits[:, -labels.size(1) :, :] + + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction=reduction) + loss = loss_fct( + shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1) + ) + if reduction == "none": + loss = loss.view(shift_logits.size(0), -1).sum(1) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids=None, + query_embeds=None, + past=None, + attention_mask=None, + use_cache=None, + **kwargs, + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + if input_ids is not None: + attention_mask = input_ids.new_ones(input_ids.shape) + if past: + input_ids = input_ids[:, -1:] + query_embeds = None + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past \ No newline at end of file diff --git a/retrieval_categories.py b/retrieval_categories.py new file mode 100644 index 0000000..98359eb --- /dev/null +++ b/retrieval_categories.py @@ -0,0 +1,117 @@ +import os +import clip +import torch +import pickle +from PIL import Image +from typing import List, Optional, Tuple + +@torch.no_grad() +def clip_texts_embeddings( + texts: List[str], + outpath = '', + device: Optional[str] = None, + batch_size: Optional[int] = 32, + clip_type: Optional[str] = None +) -> torch.Tensor: + """ + Args: + texts: name of categories, i.e., ['category1', 'category2', ...] + outpath: saving embeddings of category texts to outpath. 
reading it directly if the file already exists +        device: the device used for the CLIP text encoder +        batch_size: the number of category texts encoded per batch +        clip_type: specifying the clip backbone used +    Return: +        tensor with a shape of (num_categories, clip_hidden_size), float32 +    """ +    if os.path.exists(outpath): +        with open(outpath, 'rb') as infile: +            texts_embeddings = pickle.load(infile) # (num_categories, clip_hidden_size) +        return texts_embeddings + +    # adding a prompt for each category text, e.g., A photo of an airplane. / A photo of a bicycle. +    vowel = ['a', 'e', 'i', 'o', 'u', 'A', 'E', 'I', 'O', 'U'] +    prompt_texts = [] +    for text in texts: +        if text[0] in vowel: +            prompt_texts.append(f'A photo of an {text}.') +        else: +            prompt_texts.append(f'A photo of a {text}.') + +    clip_texts_tokens = clip.tokenize(prompt_texts) # (num_categories, 77) +    model, _ = clip.load(clip_type, device = device) # loading clip encoder +    model.eval() +    num_categories = len(texts) +    texts_embeddings = None +    epochs = int(num_categories / batch_size) if num_categories % batch_size == 0 else 1 + int(num_categories // batch_size) # number of batches (ceiling division) +    for epoch in range(epochs): +        temp_texts_tokens = clip_texts_tokens[batch_size * epoch : batch_size * (epoch + 1)] # (batch_size or num_categories % batch_size, 77) +        temp_texts_tokens = temp_texts_tokens.to(device) +        with torch.no_grad(): +            temp_texts_embeddings = model.encode_text(temp_texts_tokens).float().to('cpu') # (batch_size or num_categories % batch_size, clip_hidden_size) +        if texts_embeddings is None: +            texts_embeddings = temp_texts_embeddings +        else: +            texts_embeddings = torch.cat((texts_embeddings, temp_texts_embeddings), dim = 0) + +    with open(outpath, 'wb') as outfile: +        pickle.dump(texts_embeddings, outfile) + +    return texts_embeddings + +def image_text_simiarlity( +    texts_embeddings: torch.Tensor, +    temperature: float = 0.01, +    image_path: Optional[str] = None, +    images_features: Optional[torch.Tensor] = None, +    clip_type: Optional[str] = None, +    device: Optional[str] = None +) -> torch.Tensor: +    """ +    Args: +        texts_embeddings: (num_categories, clip_hidden_size), float32, the embeddings of categories +        temperature: temperature hyperparameter for computing similarity +        image_path: Optional, the path of a single image +        images_features: (num_images, clip_hidden_size), float32, Optional, pre-extracted image features +        clip_type: the clip backbone, used when an image path is given instead of features +        device: the device to use when an image path is given instead of features +    Return: +        logits with a shape of (num_images, num_categories) +    """ +    if images_features is None: +        encoder, preprocess = clip.load(clip_type, device) +        assert image_path is not None, 'Either image path or images feature should be given!'
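+        # no pre-extracted image features were given: load and encode the single image with CLIP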
+ image = preprocess(Image.open(image_path)).unsqueeze(dim = 0).to(device) # (1, 3, 224, 224) + with torch.no_grad(): + images_features = encoder.encode_image(image) # (1, clip_hidden_size) + + # computing on cpu to avoid out of memory + images_features = images_features.float().to('cpu') # (num_images, clip_hidden_size) + texts_embeddings = texts_embeddings.float().to('cpu') # (num_categories, clip_hidden_size) + images_features /= images_features.norm(dim = -1, keepdim = True) # (num_images, clip_hidden_size) + texts_embeddings /= texts_embeddings.norm(dim = -1, keepdim = True) # (num_categories, clip_hidden_size) + + image_to_text_similarity = torch.matmul(images_features, texts_embeddings.transpose(1, 0)) / temperature # (num_imegs, num_categories) + image_to_text_logits = torch.nn.functional.softmax(image_to_text_similarity, dim = -1) # (num_imegs, num_categories) + + return image_to_text_logits + +def top_k_categories( + texts: List[str], # ['category1', 'category2', ...], len = num_categories + logits: torch.Tensor, # (num_images, num_categories) + top_k: Optional[int] = 5, # choosing top k categories as retrieved category + threshold: Optional[float] = 0.0 # probability which is less than threshold will be filtered +) -> Tuple: + + top_k_probs, top_k_indices = torch.topk(logits, k = top_k, dim = -1) # (num_images, top_k) + top_k_texts = [] + for i in range(len(top_k_probs)): + per_image_top_k_probs = top_k_probs[i] # the ith image top k probability + per_image_top_k_indices = top_k_indices[i] # the ith image top k indices + temp_texts = [] + for j in range(top_k): + if per_image_top_k_probs[j] < threshold: + break + temp_texts.append(texts[per_image_top_k_indices[j]]) + top_k_texts.append(temp_texts) + + return top_k_texts, top_k_probs \ No newline at end of file diff --git a/scripts/eval_coco.sh b/scripts/eval_coco.sh new file mode 100644 index 0000000..92e07bd --- /dev/null +++ b/scripts/eval_coco.sh @@ -0,0 +1,39 @@ +SHELL_FOLDER=$(cd "$(dirname "$0")";pwd) +cd $SHELL_FOLDER/.. + +EXP_NAME=$1 +DEVICE=$2 +OTHER_ARGS=$3 +EPOCH=$4 +WEIGHT_PATH=checkpoints/$EXP_NAME/coco_prefix-00${EPOCH}.pt +COCO_OUT_PATH=checkpoints/$EXP_NAME + +TIME_START=$(date "+%Y-%m-%d-%H-%M-%S") +LOG_FOLDER=logs/${EXP_NAME}_EVAL +mkdir -p $LOG_FOLDER + +COCO_LOG_FILE="$LOG_FOLDER/COCO_${TIME_START}.log" + +python validation.py \ +--device cuda:$DEVICE \ +--clip_model ViT-B/32 \ +--language_model gpt2 \ +--continuous_prompt_length 10 \ +--clip_project_length 10 \ +--top_k 3 \ +--threshold 0.4 \ +--using_image_features \ +--name_of_datasets coco \ +--path_of_val_datasets ./annotations/coco/test_captions.json \ +--name_of_entities_text coco_entities \ +--image_folder ./annotations/coco/val2014/ \ +--prompt_ensemble \ +--weight_path=$WEIGHT_PATH \ +--out_path=$COCO_OUT_PATH \ +--using_hard_prompt \ +--soft_prompt_first \ +$OTHER_ARGS \ +|& tee -a ${COCO_LOG_FILE} + +echo "==========================COCO EVAL================================" +python evaluation/cocoeval.py --result_file_path $COCO_OUT_PATH/coco*.json |& tee -a ${COCO_LOG_FILE} \ No newline at end of file diff --git a/scripts/eval_flickr30k.sh b/scripts/eval_flickr30k.sh new file mode 100644 index 0000000..53e6dbc --- /dev/null +++ b/scripts/eval_flickr30k.sh @@ -0,0 +1,40 @@ +SHELL_FOLDER=$(cd "$(dirname "$0")";pwd) +cd $SHELL_FOLDER/.. 
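+# usage: bash scripts/eval_flickr30k.sh <EXP_NAME> <GPU_ID> "<OTHER_ARGS>" <EPOCH>; the four positional arguments below select the checkpoint, device, extra flags, and epoch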
+ +EXP_NAME=$1 +DEVICE=$2 +OTHER_ARGS=$3 +EPOCH=$4 +WEIGHT_PATH=checkpoints/$EXP_NAME/coco_prefix-00${EPOCH}.pt +FLICKR_OUT_PATH=checkpoints/$EXP_NAME + +TIME_START=$(date "+%Y-%m-%d-%H-%M-%S") +LOG_FOLDER=logs/${EXP_NAME}_EVAL +mkdir -p $LOG_FOLDER + +FLICKR_LOG_FILE="$LOG_FOLDER/FLICKR_${TIME_START}.log" + +python validation.py \ +--device cuda:$DEVICE \ +--clip_model ViT-B/32 \ +--language_model gpt2 \ +--continuous_prompt_length 10 \ +--clip_project_length 10 \ +--top_k 3 \ +--threshold 0.3 \ +--using_image_features \ +--name_of_datasets flickr30k \ +--path_of_val_datasets ./annotations/flickr30k/test_captions.json \ +--name_of_entities_text vinvl_vgoi_entities \ +--image_folder ./annotations/flickr30k/flickr30k-images/ \ +--prompt_ensemble \ +--weight_path=$WEIGHT_PATH \ +--out_path=$FLICKR_OUT_PATH \ +--using_hard_prompt \ +--soft_prompt_first \ +--using_greedy_search \ +$OTHER_ARGS \ +|& tee -a ${FLICKR_LOG_FILE} + +echo "==========================FLICKR EVAL================================" +python evaluation/cocoeval.py --result_file_path $FLICKR_OUT_PATH/flickr30k*.json |& tee -a ${FLICKR_LOG_FILE} \ No newline at end of file diff --git a/scripts/eval_nocaps.sh b/scripts/eval_nocaps.sh new file mode 100644 index 0000000..fb96a95 --- /dev/null +++ b/scripts/eval_nocaps.sh @@ -0,0 +1,46 @@ +SHELL_FOLDER=$(cd "$(dirname "$0")";pwd) +cd $SHELL_FOLDER/.. + +EXP_NAME=$1 +DEVICE=$2 +OTHER_ARGS=$3 +EPOCH=$4 +WEIGHT_PATH=checkpoints/$EXP_NAME/coco_prefix-00${EPOCH}.pt +NOCAPS_OUT_PATH=checkpoints/$EXP_NAME + +TIME_START=$(date "+%Y-%m-%d-%H-%M-%S") +LOG_FOLDER=logs/${EXP_NAME}_EVAL +mkdir -p $LOG_FOLDER + +NOCAPS_LOG_FILE="$LOG_FOLDER/NOCAPS_${TIME_START}.log" + +python validation.py \ +--device cuda:$DEVICE \ +--clip_model ViT-B/32 \ +--language_model gpt2 \ +--continuous_prompt_length 10 \ +--clip_project_length 10 \ +--top_k 3 \ +--threshold 0.2 \ +--using_image_features \ +--name_of_datasets nocaps \ +--path_of_val_datasets ./annotations/nocaps/nocaps_corpus.json \ +--name_of_entities_text vinvl_vgoi_entities \ +--image_folder ./annotations/nocaps/ \ +--prompt_ensemble \ +--weight_path=$WEIGHT_PATH \ +--out_path=$NOCAPS_OUT_PATH \ +--using_hard_prompt \ +--soft_prompt_first \ +$OTHER_ARGS \ +|& tee -a ${NOCAPS_LOG_FILE} + +echo "==========================NOCAPS IN-DOMAIN================================" +python evaluation/cocoeval.py --result_file_path ${NOCAPS_OUT_PATH}/indomain*.json |& tee -a ${NOCAPS_LOG_FILE} +echo "==========================NOCAPS NEAR-DOMAIN================================" +python evaluation/cocoeval.py --result_file_path ${NOCAPS_OUT_PATH}/neardomain*.json |& tee -a ${NOCAPS_LOG_FILE} +echo "==========================NOCAPS OUT-DOMAIN================================" +python evaluation/cocoeval.py --result_file_path ${NOCAPS_OUT_PATH}/outdomain*.json |& tee -a ${NOCAPS_LOG_FILE} +echo "==========================NOCAPS ALL-DOMAIN================================" +python evaluation/cocoeval.py --result_file_path ${NOCAPS_OUT_PATH}/overall*.json |& tee -a ${NOCAPS_LOG_FILE} + diff --git a/scripts/language_eval.sh b/scripts/language_eval.sh new file mode 100644 index 0000000..d4b72b8 --- /dev/null +++ b/scripts/language_eval.sh @@ -0,0 +1,4 @@ +SHELL_FOLDER=$(cd "$(dirname "$0")";pwd) +cd $SHELL_FOLDER/../evaluation + +python cocoeval.py --result_file_path $1 \ No newline at end of file diff --git a/scripts/train_coco.sh b/scripts/train_coco.sh new file mode 100644 index 0000000..8d13113 --- /dev/null +++ b/scripts/train_coco.sh @@ -0,0 +1,32 @@ 
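+# usage: bash scripts/train_coco.sh <GPU_ID>; $1 selects the CUDA device, and training uses pre-extracted ViT-B/32 text features with entity-aware hard prompts (soft prompt first)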
+SHELL_FOLDER=$(cd "$(dirname "$0")";pwd) +cd $SHELL_FOLDER/.. + +DEVICE=$1 +EXP_NAME=`echo "$(basename $0)" | cut -d'.' -f1` +LOG_FILE=logs/$EXP_NAME + +TIME_START=$(date "+%Y-%m-%d-%H-%M-%S") +LOG_FOLDER=logs/${EXP_NAME} +LOG_FILE="$LOG_FOLDER/${TIME_START}.log" +mkdir -p $LOG_FOLDER + +echo "==========================================================" +echo "RUNNING EXPERIMENTS: $EXP_NAME, saving in checkpoints/$EXP_NAME" +echo "==========================================================" + +python main.py \ +--bs 80 \ +--lr 0.00002 \ +--epochs 15 \ +--device cuda:$DEVICE \ +--random_mask \ +--prob_of_random_mask 0.4 \ +--clip_model ViT-B/32 \ +--using_clip_features \ +--language_model gpt2 \ +--using_hard_prompt \ +--soft_prompt_first \ +--path_of_datasets ./annotations/coco/coco_texts_features_ViT-B32.pickle \ +--out_dir checkpoints/$EXP_NAME \ +--use_amp \ +|& tee -a ${LOG_FILE} \ No newline at end of file diff --git a/scripts/train_flickr30k.sh b/scripts/train_flickr30k.sh new file mode 100644 index 0000000..c4e19b6 --- /dev/null +++ b/scripts/train_flickr30k.sh @@ -0,0 +1,32 @@ +SHELL_FOLDER=$(cd "$(dirname "$0")";pwd) +cd $SHELL_FOLDER/.. + +DEVICE=$1 +EXP_NAME=`echo "$(basename $0)" | cut -d'.' -f1` +LOG_FILE=logs/$EXP_NAME + +TIME_START=$(date "+%Y-%m-%d-%H-%M-%S") +LOG_FOLDER=logs/${EXP_NAME} +LOG_FILE="$LOG_FOLDER/${TIME_START}.log" +mkdir -p $LOG_FOLDER + +echo "==========================================================" +echo "RUNNING EXPERIMENTS: $EXP_NAME, saving in checkpoints/$EXP_NAME" +echo "==========================================================" + +python main.py \ +--bs 80 \ +--lr 0.00002 \ +--epochs 30 \ +--device cuda:$DEVICE \ +--random_mask \ +--prob_of_random_mask 0.4 \ +--clip_model ViT-B/32 \ +--using_clip_features \ +--language_model gpt2 \ +--using_hard_prompt \ +--soft_prompt_first \ +--path_of_datasets ./annotations/flickr30k/flickr30k_texts_features_ViT-B32.pickle \ +--out_dir checkpoints/$EXP_NAME \ +--use_amp \ +|& tee -a ${LOG_FILE} \ No newline at end of file diff --git a/search.py b/search.py new file mode 100644 index 0000000..8f2199e --- /dev/null +++ b/search.py @@ -0,0 +1,682 @@ +import clip +import torch +import numpy as np +from PIL import Image +import torch.nn.functional as F +from typing import Optional, Tuple, List +from transformers import GPT2Tokenizer, GPT2LMHeadModel + + +@torch.no_grad() +def opt_search( + prompts: Optional[str] = None, + tokens: Optional[torch.Tensor] = None, + embeddings: Optional[torch.Tensor] = None, + max_len: int = 64, + beam_width: int = 5, + end_of_sentence: str = ".", + tokenizer: GPT2Tokenizer = None, + model: GPT2LMHeadModel = None, +) -> List[str]: + """ + Sentence generation through choosing token guided by model confidence. + Taking text input as prompts, tokens or embeddings, if more than one input a time, priority should follow: embeddings > tokens > prompts. 
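+    This OPT variant delegates decoding to `model.generate` (beam search with `beam_width` beams by default), passing the prefix embeddings as `query_embeds`.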
+ Args: + prompts: str, prompts for generated sentence + tokens: tensor with shape of (b, n_seq), device = model.device, dtype = int64 + embeddings: tensor with shape of (b, n_seq, lm_hidden_size), device = model.device, dtype = float16/float32 (from clip encoder/gpt2 encoder) + max_len: int, the maximum length of generated sentence (without considering the length of prompts/tokens/embeddings) + end_of_sentence: str, early stopping once generated word is equal to end_of_sentence + tokenizer: transforming word/sentence to indice/list and vice versa, i.e., str -> List[int64] or List[int64] -> str + model: language model (taking input as either tokens or embeddings) + Return: + list[str] for generated sentence when batch size is greater than 1 (i.e., len(list) = batch_size), and string when batch size is equal to 1 + """ + model.eval() + device = model.device + + # tokenizing end of sentence, when the length of eos tokens is greater than 1, setting the first token of eos tokens as eos token + eos = tokenizer.encode(end_of_sentence)[-1] + + # prefix should transform into word embeddings so that sentence generation is capable of processing input of prompts, tokens or embeddings unifiedly + # priority: embeddings > tokens > prompts + if embeddings is not None: + generating = embeddings # (b, n_seq, lm_hidden_size) + else: + if tokens is None: + tokens = torch.tensor(tokenizer.encode(prompts)) # (n_seq), tokenizing prompts + tokens = tokens.unsqueeze(dim = 0).to(device) # (b(=1), n_seq), adding batch dimension + generating = word_embed(model, tokens) + # generating = model.transformer.wte(tokens) # (b, n_seq, lm_hidden_size), transforming to word embeddings + generating = generating.float() # (b, n_seq, lm_hidden_size) + assert generating.dim() == 3, 'The dimension of prompts should equal to 3!' 
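+    # generating: (b, n_seq, lm_hidden_size) prefix embeddings; below they are repeated per beam and handed to generate() as query_embeds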
+ + b = generating.shape[0] + # past_key_values = None + inputs_opt = generating + + use_nucleus_sampling = False + num_beams=beam_width + max_length=max_len + min_length=1 + top_p=0.9 + repetition_penalty=1.0 + length_penalty=1.0 + num_captions=1 + temperature=1 + + if use_nucleus_sampling: + query_embeds = inputs_opt.repeat_interleave(num_captions, dim=0) + num_beams = 1 + else: + query_embeds = inputs_opt.repeat_interleave(num_beams, dim=0) + + atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(inputs_opt.device) + + prompt = tokenizer.eos_token + prompts if prompts else tokenizer.eos_token + prompt = [prompt] * b + opt_tokens = tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to(embeddings.device) + input_ids = opt_tokens.input_ids + attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1) + + # import pdb + # pdb.set_trace() + + outputs = model.generate( + input_ids=input_ids, + query_embeds=query_embeds.type(model.dtype), + attention_mask=attention_mask, + do_sample=use_nucleus_sampling, + top_p=top_p, + temperature=temperature, + num_beams=num_beams, + max_new_tokens=max_length, + min_length=min_length, + eos_token_id= eos, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + num_return_sequences=num_captions, + ) + + output_text = tokenizer.batch_decode(outputs[:, :], skip_special_tokens=True) + output_text = [text.strip() for text in output_text] + print(output_text) + return output_text + + +@torch.no_grad() +def greedy_search( + prompts: Optional[str] = None, + tokens: Optional[torch.Tensor] = None, + embeddings: Optional[torch.Tensor] = None, + max_len: int = 64, + end_of_sentences: List = [".", " ."], + tokenizer: GPT2Tokenizer = None, + model: GPT2LMHeadModel = None +) -> List[str]: + """ + Sentence generation through choosing token guided by model confidence. + Taking text input as prompts, tokens or embeddings, if more than one input a time, priority should follow: embeddings > tokens > prompts. 
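+    Decoding is greedy: at each step the argmax token is taken and the key/value cache (`past_key_values`) is reused.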
+ Args: + prompts: str, prompts for generated sentence + tokens: tensor with shape of (b, n_seq), device = model.device, dtype = int64 + embeddings: tensor with shape of (b, n_seq, lm_hidden_size), device = model.device, dtype = float16/float32 (from clip encoder/gpt2 encoder) + max_len: int, the maximum length of generated sentence (without considering the length of prompts/tokens/embeddings) + end_of_sentence: str, early stopping once generated word is equal to end_of_sentence + tokenizer: transforming word/sentence to indice/list and vice versa, i.e., str -> List[int64] or List[int64] -> str + model: language model (taking input as either tokens or embeddings) + Return: + list[str] for generated sentence when batch size is greater than 1 (i.e., len(list) = batch_size), and string when batch size is equal to 1 + """ + model.eval() + device = model.device + + # tokenizing end of sentence, when the length of eos tokens is greater than 1, setting the first token of eos tokens as eos token + eos = [tokenizer.encode(end_of_sentence)[-1] for end_of_sentence in end_of_sentences] + + # prefix should transform into word embeddings so that sentence generation is capable of processing input of prompts, tokens or embeddings unifiedly + # priority: embeddings > tokens > prompts + if embeddings is not None: + generating = embeddings # (b, n_seq, lm_hidden_size) + else: + if tokens is None: + tokens = torch.tensor(tokenizer.encode(prompts)) # (n_seq), tokenizing prompts + tokens = tokens.unsqueeze(dim = 0).to(device) # (b(=1), n_seq), adding batch dimension + generating = word_embed(model, tokens) + # generating = model.transformer.wte(tokens) # (b, n_seq, lm_hidden_size), transforming to word embeddings + generating = generating.float() # (b, n_seq, lm_hidden_size) + assert generating.dim() == 3, 'The dimension of prompts should equal to 3!' 
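+    # autoregressive loop: only the newest token's embedding is fed at each step, while past_key_values carries the earlier context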
+
+    b = generating.shape[0]
+    past_key_values = None
+    for step in range(max_len):
+        # generating the initial states of the language model
+        if step == 0:
+            outputs = model(inputs_embeds = generating.type(model.dtype), past_key_values = past_key_values, use_cache = True)
+            next_token_logits = outputs.logits[:, -1, :] # (b, n_seq, vocab_size) -> (b, vocab_size), logits of the last token
+            past_key_values = outputs.past_key_values # Tuple[Tuple[(b, h, n_seq, lm_hidden_size/h)]], layers -> (key, value) -> torch.tensor
+
+        next_token = torch.argmax(next_token_logits, dim = -1, keepdim = True) # (b, 1)
+        next_embedding = word_embed(model, next_token) # (b, 1, lm_hidden_size)
+        outputs = model(inputs_embeds = next_embedding.type(model.dtype), past_key_values = past_key_values, use_cache = True)
+        next_token_logits = outputs.logits[:, -1, :] # (b, 1, vocab_size) -> (b, vocab_size)
+        past_key_values = outputs.past_key_values # Tuple[Tuple[(b, h, n_seq + 1, lm_hidden_size/h)]]
+
+        # updating tokens
+        if tokens is None:
+            tokens = next_token
+        else:
+            tokens = torch.cat((tokens, next_token), dim = 1) # (b, n_seq + 1)
+
+        # early stopping according to the end of sentence, only active when batch size equals 1
+        if b == 1 and next_token.item() in eos:
+            new_tokens = tokens.squeeze(dim = 0).tolist()
+            sentence = tokenizer.decode(new_tokens)
+            return sentence
+
+    # tokens: (1/b, n_seq + max_len) where n_seq refers to the length of the input tokens or prompts
+    # torch.tensor(1/b, n_seq + max_len) -> str/list[str]
+    sentence = []
+    if b == 1:
+        new_tokens = tokens.squeeze(dim = 0).tolist()
+        sentence = tokenizer.decode(new_tokens)
+    else:
+        for temp_tokens in tokens:
+            for i in range(len(temp_tokens)):
+                if temp_tokens[i].item() in eos:
+                    break
+            new_tokens = temp_tokens[:i + 1].tolist()
+            sentence.append(tokenizer.decode(new_tokens))
+    return sentence
+
+@torch.no_grad()
+def beam_search(
+    prompts: Optional[str] = None,
+    tokens: Optional[torch.Tensor] = None,
+    embeddings: Optional[torch.Tensor] = None,
+    temperature = 1.0,
+    max_len: int = 64,
+    beam_width: int = 5,
+    end_of_sentences: List = [".", " ."],
+    tokenizer: GPT2Tokenizer = None,
+    model: GPT2LMHeadModel = None
+) -> List[str]:
+    """
+    Sentence generation through beam search, keeping the beam_width highest-scoring partial sentences (by length-normalized log-probability) at each step. This implementation assumes a batch size of 1.
+    Text input can be given as prompts, tokens or embeddings; if more than one is provided, the priority is: embeddings > tokens > prompts.
+    Args:
+        prompts: str, prompts for the generated sentence
+        tokens: tensor with shape of (b, n_seq), device = model.device, dtype = int64
+        embeddings: tensor with shape of (b, n_seq, lm_hidden_size), device = model.device, dtype = float16/float32 (from clip encoder/gpt2 encoder)
+        temperature: float, softmax temperature applied to the logits before computing log-probabilities
+        max_len: int, the maximum length of the generated sentence (excluding the length of prompts/tokens/embeddings)
+        beam_width: int, the number of beams kept at each step
+        end_of_sentences: List[str], a beam stops once its newest token matches any of these end-of-sentence strings
+        tokenizer: transforms a sentence into token indices and vice versa, i.e., str -> List[int64] or List[int64] -> str
+        model: language model (taking input as either tokens or embeddings)
+    Return:
+        List[str] of beam_width candidate sentences, sorted by length-normalized score in descending order
+    """
+    model.eval()
+    device = model.device
+
+    # tokenizing the end-of-sentence strings; when an encoded string is longer than one token, its last token is used as the eos token id
+    eos = [tokenizer.encode(end_of_sentence)[-1] for end_of_sentence in end_of_sentences]
+    scores = None
+    seq_lengths = torch.ones(beam_width, device = device)
+    is_stopped = torch.zeros(beam_width, device = device, dtype=torch.bool)
+    # the prefix is transformed into word embeddings so that generation handles prompts, tokens or embeddings in a unified way
+    # priority: embeddings > tokens > prompts
+    if embeddings is not None:
+        generated = embeddings # (b, n_seq, lm_hidden_size)
+    else:
+        if tokens is None:
+            tokens = torch.tensor(tokenizer.encode(prompts)) # (n_seq), tokenizing prompts
+            tokens = tokens.unsqueeze(dim = 0).to(device) # (b(=1), n_seq), adding batch dimension
+        generated = word_embed(model, tokens) # (b, n_seq, lm_hidden_size), transforming to word embeddings
+    generated = generated.float() # (b, n_seq, lm_hidden_size)
+    assert generated.dim() == 3, 'The dimension of the prefix embeddings should equal 3!'
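+
+    # Beam scoring sketch: each beam keeps a running sum of token log-probabilities (scores) and its
+    # length (seq_lengths); candidate expansions are ranked by the length-normalized score
+    # scores / seq_lengths. Once a beam has emitted an end-of-sentence token it is frozen: all of its
+    # expansion logits are set to -inf except index 0, which is set to 0, so the finished beam survives
+    # as a single candidate whose score no longer changes.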
+
+    for i in range(max_len):
+        outputs = model(inputs_embeds=generated.type(model.dtype))
+        logits = outputs.logits
+        logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+        logits = logits.log_softmax(-1) # token log-probabilities
+        if scores is None:
+            scores, next_tokens = logits.topk(beam_width, -1)
+            generated = generated.expand(beam_width, *generated.shape[1:])
+            next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+            if tokens is None:
+                tokens = next_tokens
+            else:
+                tokens = tokens.expand(beam_width, *tokens.shape[1:])
+                tokens = torch.cat((tokens, next_tokens), dim=1)
+        else:
+            logits[is_stopped] = -float('inf')
+            logits[is_stopped, 0] = 0
+            scores_sum = scores[:, None] + logits
+            seq_lengths[~is_stopped] += 1
+            scores_sum_average = scores_sum / seq_lengths[:, None]
+            scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_width, -1)
+            next_tokens_source = torch.div(next_tokens, scores_sum.shape[1], rounding_mode = 'trunc')
+            seq_lengths = seq_lengths[next_tokens_source]
+            next_tokens = next_tokens % scores_sum.shape[1]
+            next_tokens = next_tokens.unsqueeze(1)
+            tokens = tokens[next_tokens_source]
+            tokens = torch.cat((tokens, next_tokens), dim=1)
+            generated = generated[next_tokens_source]
+            scores = scores_sum_average * seq_lengths
+            is_stopped = is_stopped[next_tokens_source]
+        next_token_embed = word_embed(model, next_tokens.squeeze()).view(generated.shape[0], 1, -1)
+        generated = torch.cat((generated, next_token_embed), dim=1)
+        # stopping a beam once its newest token is any of the end-of-sentence tokens
+        is_eos = torch.zeros_like(is_stopped)
+        for eos_id in eos:
+            is_eos = is_eos | next_tokens.squeeze(dim = 1).eq(eos_id)
+        is_stopped = is_stopped | is_eos
+        if is_stopped.all():
+            break
+    scores = scores / seq_lengths
+    output_list = tokens.cpu().numpy()
+    output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
+    order = scores.argsort(descending=True)
+    output_texts = [output_texts[i] for i in order]
+
+    return output_texts
+
+def word_embed(gpt, caption_tokens):
+    # looking up the input word embeddings for either a GPT-2-style model (transformer.wte)
+    # or an OPT-style model (model.decoder.embed_tokens)
+    if hasattr(gpt, 'transformer'):
+        embedding_text = gpt.transformer.wte(caption_tokens)
+    elif hasattr(gpt, 'model'):
+        embedding_text = gpt.model.decoder.embed_tokens(caption_tokens)
+    return embedding_text
+
+@torch.no_grad()
+def contrastive_search(
+    prompts: Optional[str] = None,
+    tokens: Optional[torch.Tensor] = None,
+    embeddings: Optional[torch.Tensor] = None,
+    alpha: float = 0.1,
+    top_k: int = 48,
+    max_len: int = 64,
+    end_of_sentence: str = '.',
+    tokenizer: GPT2Tokenizer = None,
+    model: GPT2LMHeadModel = None
+) -> List[str]:
+    """
+    Sentence generation through contrastive search, choosing each token by model confidence combined with a degeneration penalty.
+    Text input can be given as prompts, tokens or embeddings; if more than one is provided, the priority is: embeddings > tokens > prompts.
+    Args:
+        prompts: str, prompts for the generated sentence
+        tokens: tensor with shape of (b, n_seq), device = model.device, dtype = int64
+        embeddings: tensor with shape of (b, n_seq, lm_hidden_size), device = model.device, dtype = float16/float32 (from clip encoder/gpt2 encoder)
+        alpha: float from 0.0 to 1.0, controlling the strength of the degeneration penalty (i.e., avoiding repetition)
+        top_k: int, the number of candidate tokens considered at each time step of next-token prediction (i.e., the next token is selected from the top k candidates)
+        max_len: int, the maximum length of the generated sentence (excluding the length of prompts/tokens/embeddings)
+        end_of_sentence: str, generation stops early once the generated token equals end_of_sentence
+        tokenizer: transforms a sentence into token indices and vice versa, i.e., str -> List[int64] or List[int64] -> str
+        model: language model (taking input as either tokens or embeddings)
+    Return:
+        List[str] of generated sentences when batch size is greater than 1 (i.e., len(list) = batch_size), and a single str when batch size is 1
+    """
+    model.eval()
+    device = model.device
+
+    # tokenizing the end-of-sentence string; when it encodes to more than one token, its first token is used as the eos token id
+    eos = tokenizer.encode(end_of_sentence)[0]
+
+    # the prefix is transformed into word embeddings so that generation handles prompts, tokens or embeddings in a unified way
+    # priority: embeddings > tokens > prompts
+    if embeddings is not None:
+        generating = embeddings # (b, n_seq, lm_hidden_size)
+    else:
+        if tokens is None:
+            tokens = torch.tensor(tokenizer.encode(prompts)) # (n_seq), tokenizing prompts
+            tokens = tokens.unsqueeze(dim = 0).to(device) # (b(=1), n_seq), adding batch dimension
+        generating = word_embed(model, tokens) # (b, n_seq, lm_hidden_size), transforming to word embeddings
+    generating = generating.float() # (b, n_seq, lm_hidden_size)
+    assert generating.dim() == 3, 'The dimension of the prefix embeddings should equal 3!'
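+
+    # Selection rule sketch (the standard contrastive search objective; the exact weighting lives in
+    # ranking_and_selecting): at each step the next token is chosen from the top-k candidates as
+    #     x_t = argmax_v [ (1 - alpha) * p(v | x_<t) - alpha * max_j cos(h_v, h_j) ],
+    # where h_v is the candidate's hidden state and h_j ranges over the hidden states of the context
+    # generated so far, so larger alpha penalizes candidates that would repeat the existing context.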
+ + past_key_values = None + for step in range(max_len): + # generating the initial states of model + if step == 0: + outputs = model(inputs_embeds = generating, past_key_values = past_key_values, use_cache = True, output_hidden_states = True) + next_token_logits = outputs.logits[:, -1, :] # (b, n_seq, vocal_size) -> (b, vocal_size), logits of the last token + past_key_values = outputs.past_key_values # Tuple[Tuple[(b, h, n_seq, lm_hidden_size/h)]], layers -> (key, value) -> torch.tensor + past_hidden_states = outputs.hidden_states[-1] # Tuple[(b, n_seq, lm_hidden_size)] -> (b, n_seq, lm_hidden_size) (i.e., hidden state of last layer) + + # selecting top k candidates and their probability from next_tokens_logits + b, n_seq, lm_hidden_size = past_hidden_states.size() + next_token_probs = F.softmax(next_token_logits, dim = -1) # (b, vocal_size) + _, top_k_indices = torch.topk(next_token_logits, dim = -1, k = top_k) # (b, k), the indices for top k candidates (i.e., tokens) + top_k_probs = torch.gather(next_token_probs, dim = 1, index = top_k_indices) # (b, k), the probability for top k candidates + + # transformering b*k tokens to embeddings and processing past_key_values to compute simultaneously for k tokens + top_k_embeddings = model.transformer.wte(top_k_indices.view(-1, 1)) # (b*k, 1, lm_hidden_size) + past_key_values = reshape_from_past_key_values(past_key_values, top_k) # Tuple[Tuple[(b*k, h, n_seq, lm_hidden_size/h)]] + # computing hidden state of next token (b * top_k in total) + outputs = model(inputs_embeds = top_k_embeddings, past_key_values = past_key_values, use_cache = True, output_hidden_states = True) + logits = outputs.logits[:, -1, :] # (b*k, 1, vocal_size) -> (b*k, vocal_size) + past_key_values = outputs.past_key_values # Tuple[Tuple[(b*k, h, n_seq + 1, lm_hidden_size/h)]] + next_hidden_state = outputs.hidden_states[-1] # Tuple[(b*k, 1, lm_hidden_size)] -> (b*k, 1, lm_hidden_size) + context_hidden_states = past_hidden_states.unsqueeze(dim = 1).expand(-1, top_k, -1, -1).reshape(b*top_k, n_seq, lm_hidden_size) # (b*k, n_seq, lm_hidden_size) + + # selecting next token within top k candidates for each sentence + selected_max_prob_indices = ranking_and_selecting(context_hidden_states, next_hidden_state, top_k_probs, alpha, top_k) # (b) + + # updating next_token_logits, past key-values and last hidden state + logits = torch.stack(torch.split(logits, top_k), dim = 0) # (b, k, vocal_size) + next_token_logits = logits[range(b), selected_max_prob_indices, :] # (b, vocal_size) + past_key_values = reshape_to_past_key_values(past_key_values, selected_max_prob_indices, top_k) # (b, h, n_seq + 1, lm_hidden_size/h) + next_hidden_state = torch.stack(torch.split(next_hidden_state.squeeze(dim = 1), top_k), dim = 0) # (b, k, lm_hidden_size) + next_hidden_state = next_hidden_state[range(b), selected_max_prob_indices, :] # (b, lm_hidden_size) + past_hidden_states = torch.cat([past_hidden_states, next_hidden_state.unsqueeze(dim = 1)], dim=1) # [b, n_seq + 1, lm_hidden_size] + + # computing next token and saving it + next_token = top_k_indices[range(b), selected_max_prob_indices].unsqueeze(dim = -1) # (b, 1) + if tokens is None: + tokens = next_token + else: + tokens = torch.cat((tokens, next_token), dim = 1) # (b, n_seq + 1) + + # whether to stop early according to the end of sentence, only working when batch size is equal to 1 + if b == 1 and next_token.item() == eos: + new_tokens = tokens.squeeze(dim = 0).tolist() + sentence = tokenizer.decode(new_tokens) + return sentence + + # tokens: (1/b, 
n_seq + max_len) where n_seq refers to the length of the input tokens or prompts
+    # torch.tensor(1/b, n_seq + max_len) -> str/list[str]
+    sentence = []
+    if b == 1:
+        new_tokens = tokens.squeeze(dim = 0).tolist()
+        sentence = tokenizer.decode(new_tokens)
+    else:
+        for temp_tokens in tokens:
+            for i in range(len(temp_tokens)):
+                if temp_tokens[i].item() == eos:
+                    break
+            new_tokens = temp_tokens[:i + 1].tolist()
+            sentence.append(tokenizer.decode(new_tokens))
+    return sentence
+
+@torch.no_grad()
+def magic_search(
+    prompts: Optional[str] = None,
+    tokens: Optional[torch.Tensor] = None,
+    embeddings: Optional[torch.Tensor] = None,
+    image_path: Optional[str] = None,
+    images_feature: Optional[torch.Tensor] = None,
+    alpha: float = 0.1,
+    beta: float = 2.0,
+    top_k: int = 48,
+    max_len: int = 64,
+    clip_text_max_len: int = 60,
+    end_of_sentence: str = '.',
+    tokenizer: GPT2Tokenizer = None,
+    model: GPT2LMHeadModel = None
+) -> List[str]:
+    """
+    Sentence generation through choosing each token guided by model confidence, a degeneration penalty and image relevance at each time step.
+    Text input can be given as prompts, tokens or embeddings; if more than one is provided, the priority is: embeddings > tokens > prompts.
+    Image input can be given as image_path or images_feature; if both are provided, the priority is: images_feature > image_path.
+    Args:
+        prompts: str, prompts for the generated sentence
+        tokens: tensor with shape of (b, n_seq), device = model.device, dtype = int64
+        embeddings: tensor with shape of (b, n_seq, lm_hidden_size), device = model.device, dtype = float16/float32 (from clip encoder/gpt2 encoder)
+        image_path: str, the path of a single image
+        images_feature: tensor with shape of (b, clip_hidden_size), device = model.device, dtype = float32
+        alpha: float from 0.0 to 1.0, controlling the strength of the degeneration penalty (i.e., avoiding repetition)
+        beta: float, controlling the strength of image guidance
+        top_k: int, the number of candidate tokens considered at each time step of next-token prediction (i.e., the next token is selected from the top k candidates)
+        max_len: int, the maximum length of the generated sentence (excluding the length of prompts/tokens/embeddings)
+        clip_text_max_len: int, the maximum number of tokens fed to the clip text encoder
+        end_of_sentence: str, generation stops early once the generated token equals end_of_sentence
+        tokenizer: transforms a sentence into token indices and vice versa, i.e., str -> List[int64] or List[int64] -> str
+        model: language model (taking input as either tokens or embeddings)
+    Return:
+        List[str] of generated sentences when batch size is greater than 1 (i.e., len(list) = batch_size), and a single str when batch size is 1
+    """
+    model.eval()
+    device = model.device
+
+    # tokenizing the end-of-sentence string; when it encodes to more than one token, its first token is used as the eos token id
+    eos = tokenizer.encode(end_of_sentence)[0]
+
+    # the prefix is transformed into word embeddings so that generation handles prompts, tokens or embeddings in a unified way
+    # priority: embeddings > tokens > prompts
+    if embeddings is not None:
+        generating = embeddings # (b, n_seq, lm_hidden_size)
+    else:
+        if tokens is None:
+            tokens = torch.tensor(tokenizer.encode(prompts)) # (n_seq), tokenizing prompts
+            tokens = tokens.unsqueeze(dim = 0).to(device) # (b(=1), n_seq), adding batch dimension
+        generating = word_embed(model, tokens) # (b, n_seq, lm_hidden_size), transforming to word embeddings
+    generating = generating.float() # (b, 
n_seq, lm_hidden_size) + assert generating.dim() == 3, 'The dimension of prompts should equal to 3!' + + # generating image feature using clip visual encoder + # note that the dtype of feature from clip visual encoder is equal to float16, transforming it into float32 + # priority: images_feature > image_path + clip_model, preprocess = clip.load('ViT-B/32', device = device) + clip_model.eval() + if images_feature is None: + image = preprocess(Image.open(image_path)).unsqueeze(dim = 0).to(device) # (b(=1), 3, 224, 224) + images_feature = clip_model.encode_image(image) # (b, clip_hidden_size) + images_feature = images_feature.float() # (b, clip_hidden_size) + assert images_feature.dim() == 2, 'The dimension of images feature should equal to 2!' + assert images_feature.shape[0] == generating.shape[0], 'The number of images should be equal to the number of prompts/tokens/embeddings!' + + past_key_values = None + tokens_generated = None + for step in range(max_len): + # generating the initial states of model + if step == 0: + outputs = model(inputs_embeds = generating, past_key_values = past_key_values, use_cache = True, output_hidden_states = True) + next_token_logits = outputs.logits[:, -1, :] # (b, n_seq, vocal_size) -> (b, vocal_size), logits of the last token + past_key_values = outputs.past_key_values # Tuple[Tuple[(b, h, n_seq, lm_hidden_size/h)]], layers -> (key, value) -> torch.tensor + past_hidden_states = outputs.hidden_states[-1] # Tuple[(b, n_seq, lm_hidden_size)] -> (b, n_seq, lm_hidden_size) (i.e., hidden state of last layer) + + # selecting top k candidates and their probability from next_tokens_logits + b, n_seq, lm_hidden_size = past_hidden_states.size() + next_token_probs = F.softmax(next_token_logits, dim = -1) # (b, vocal_size) + _, top_k_indices = torch.topk(next_token_logits, dim = -1, k = top_k) # (b, k), the indices for top k candidates (i.e., tokens) + top_k_probs = torch.gather(next_token_probs, dim = 1, index = top_k_indices) # (b, k), the probability for top k candidates + + # computing similarity between image and sentence (b * k in total) + image_sentence_score = image_sentence_similarity(tokens_generated, top_k_indices, images_feature, top_k, clip_text_max_len, tokenizer, clip_model) # (b, k) + + # transformering b*k tokens to embeddings and processing past_key_values to compute simultaneously for k tokens + top_k_embeddings = model.transformer.wte(top_k_indices.view(-1, 1)) # (b*k, 1, lm_hidden_size) + past_key_values = reshape_from_past_key_values(past_key_values, top_k) # Tuple[Tuple[(b*k, h, n_seq, lm_hidden_size/h)]] + # computing hidden state of next token (b * top_k in total) + outputs = model(inputs_embeds = top_k_embeddings, past_key_values = past_key_values, use_cache = True, output_hidden_states = True) + logits = outputs.logits[:, -1, :] # (b*k, 1, vocal_size) -> (b*k, vocal_size) + past_key_values = outputs.past_key_values # Tuple[Tuple[(b*k, h, n_seq + 1, lm_hidden_size/h)]] + next_hidden_state = outputs.hidden_states[-1] # Tuple[(b*k, 1, lm_hidden_size)] -> (b*k, 1, lm_hidden_size) + context_hidden_states = past_hidden_states.unsqueeze(dim = 1).expand(-1, top_k, -1, -1).reshape(b*top_k, n_seq, lm_hidden_size) # (b*k, n_seq, lm_hidden_size) + + # selecting next token within top k candidates for each sentence + selected_max_prob_indices = ranking_and_selecting(context_hidden_states, next_hidden_state, top_k_probs, alpha, top_k, beta, image_sentence_score) # (b) + + # updating next_token_logits, past key-values and last hidden state + logits = 
torch.stack(torch.split(logits, top_k), dim = 0) # (b, k, vocal_size) + next_token_logits = logits[range(b), selected_max_prob_indices, :] # (b, vocal_size) + past_key_values = reshape_to_past_key_values(past_key_values, selected_max_prob_indices, top_k) # (b, h, n_seq + 1, lm_hidden_size/h) + next_hidden_state = torch.stack(torch.split(next_hidden_state.squeeze(dim = 1), top_k), dim = 0) # (b, k, lm_hidden_size) + next_hidden_state = next_hidden_state[range(b), selected_max_prob_indices, :] # (b, lm_hidden_size) + past_hidden_states = torch.cat([past_hidden_states, next_hidden_state.unsqueeze(dim = 1)], dim=1) # [b, n_seq + 1, lm_hidden_size] + + # computing next token and saving it + next_token = top_k_indices[range(b), selected_max_prob_indices].unsqueeze(dim = -1) # (b, 1) + if tokens is None: + tokens = next_token + tokens_generated = next_token + else: + if tokens_generated is None: + tokens_generated = next_token + else: + tokens_generated = torch.cat((tokens_generated, next_token), dim = 1) + tokens = torch.cat((tokens, next_token), dim = 1) # (b, n_seq + 1) + + # whether to stop early according to the end of sentence, only working when batch size is equal to 1 + if b == 1 and next_token.item() == eos: + new_tokens = tokens.squeeze(dim = 0).tolist() + sentence = tokenizer.decode(new_tokens) + return sentence + + # tokens: (1/b, n_seq + max_len) where n_seq refers to the length of inputs tokens or prompts + # torch.tensor(1/b, n_seq + max_Len) -> str/list[str] + sentence = [] + if b == 1: + new_tokens = tokens.squeeze(dim = 0).tolist() + sentence = tokenizer.decode(new_tokens) + else: + for temp_tokens in tokens: + for i in range(len(temp_tokens)): + if temp_tokens[i].item() == eos: + break + new_tokens = temp_tokens[:i + 1].tolist() + sentence.append(tokenizer.decode(new_tokens)) + return sentence + +def image_sentence_similarity( + tokens_generated: torch.Tensor, + top_k_indices: torch.Tensor, + images_feature: torch.Tensor, + top_k: int, + clip_text_max_len: int, + tokenizer: GPT2Tokenizer, + clip_model: clip +) -> torch.Tensor: + """ + Args: + tokens_generated: tensor with shape of (b, n_seq), the sentence generated (without considering the prompts) + top_k_indices: tensor with shape of (b, top_k), the top k candidates for each sentence + images_feature: tensor with shape of (b, clip_hidden_size), image feature encoded by clip + top_k: int, k candidates + clip_text_max_len: int, the maximum length of clip textual encoder + tokenizer: transforming word/sentence to indice/list and vice versa + clip_model: pre-trained clip model which encodes image or image to embeddings with dtype of float16 (transforming to float32) + + Return: + image-sentence similarity score with shape of (b, k), i.e., for each sentence (b in total), returning top k tokens similarity with image + """ + device = top_k_indices.device + + # obtaining tokens of generated (b sentences and k tokens for each sentence, i.e., b * k sentences in total) + if tokens_generated is None: + temp_tokens = top_k_indices.view(-1).unsqueeze(dim = 1) # (b*k, n_seq + 1), where n_seq = 0 + else: + b, n = tokens_generated.size() + tokens_generated = tokens_generated.unsqueeze(dim = 1).expand(-1, top_k, -1).reshape(b*top_k, n) # (b*k, n_seq) + top_k_indices = top_k_indices.view(-1).unsqueeze(dim = 1) # (b*k, 1) + temp_tokens = torch.cat([tokens_generated, top_k_indices], dim = 1) # (b*k, n_seq + 1) + + # converting to sentence + sentences = [] + for temp_token in temp_tokens: + # taking the latest clip_text_max_len tokens when tokens 
length is greater than clip_text_max_len + sentence = tokenizer.decode(temp_token[-clip_text_max_len:].to('cpu').tolist()) + sentences.append(sentence) # len(sentences) = b*k + + # converting to text tokens and embeddings of clip + clip_tokens = clip.tokenize(sentences).to(device) # (b*k, n_seq) + clip_embeddings = clip_model.encode_text(clip_tokens) # (b*k, clip_hidden_size) + clip_embeddings = torch.stack(torch.split(clip_embeddings, top_k), dim = 0).float() # (b, k, clip_hidden_size) + + # computing similarity score + images_feature = images_feature.unsqueeze(dim = 1) # (b, 1, clip_hidden_size) + clip_embeddings = clip_embeddings / clip_embeddings.norm(dim = -1, keepdim = True) # (b, k, clip_hidden_size) + images_feature = images_feature / images_feature.norm(dim = -1, keepdim = True) # (b, 1, clip_hidden_size) + scaling = clip_model.logit_scale.exp() + score = torch.matmul(clip_embeddings, images_feature.transpose(1, 2)).squeeze(dim = 2) * scaling # (b, k) + + return F.softmax(score, dim = -1) + +def reshape_from_past_key_values(past_key_values: Tuple[Tuple[torch.Tensor]], top_k: int) -> Tuple[Tuple[torch.Tensor]]: + """ + To compute top k candidates simultaneously for each sentence in a batch, duplicating k times for each sentence. + Args: + past_key_values: Tuple[Tuple[(b, h, n_seq, lm_hidden_size/h)]], the first tuple refers to layers and the second tuple refers to key-value pair + top_k: int, k candidates + Return: + Tuple[Tuple[(b*k, h, n_seq, lm_hidden_size/h)]] + """ + new_key_values = [] + for layer in past_key_values: + items = [] + for item in layer: + b, h, n, d = item.size() # d = lm_hidden_size/h + # duplicating k times for each sentence in a batch, the only difference between each k repeated sample is the candidate waiting to concatenate + item = item.unsqueeze(dim = 1).expand(-1, top_k, -1, -1, -1).reshape(b*top_k, h, n, d) # (b*k, h, n_seq, lm_hidden_size/h) + items.append(item) + new_key_values.append(items) + return new_key_values + +def reshape_to_past_key_values(past_key_values: Tuple[Tuple[torch.Tensor]], selected_max_prob_indices: torch.Tensor, top_k: int) -> Tuple[Tuple[torch.Tensor]]: + """ + Args: + past_key_values: Tuple[Tuple[(b*k, h, n_seq + 1, lm_hidden_size/h)]] + selected_max_prob_indices: tensor with shape of (b), indices of maximum probability in k candidates + top_k: int, k candidates + Return: + Tuple[Tuple[(b, h, n_seq + 1, lm_hidden_size/h)]] + """ + new_key_values = [] + for layer in past_key_values: + items = [] + for item in layer: + bk = item.shape[0] + b = int(bk//top_k) + item = torch.stack(torch.split(item, top_k), dim = 0) # (b, k, h, n_seq + 1, lm_hidden_size/h) + item = item[range(b), selected_max_prob_indices, :, :, :] # (b, h, n_seq + 1, lm_hidden_size/h) + items.append(item) + new_key_values.append(items) + return new_key_values + +def ranking_and_selecting( + context_hidden_states: torch.Tensor, + next_hidden_state: torch.Tensor, + top_k_probs: torch.Tensor, + alpha: float, + top_k: int, + beta: Optional[float] = None, + image_sentence_score: Optional[torch.Tensor] = None +) -> torch.Tensor: + """ + Args: + context_hidden_states: tensor with shape of (b*k, n_seq, lm_hidden_size), the hidden state of each token in sentence before candidates (i.e.