From e7cbe89970fb80cfaa959df4983e13ae431635ba Mon Sep 17 00:00:00 2001 From: Austin Brown Date: Mon, 5 Aug 2024 21:18:22 -0400 Subject: [PATCH 01/12] added recall class, to-do: implement --- mimir/attacks/all_attacks.py | 1 + mimir/attacks/recall.py | 18 ++++++++++++++++++ mimir/attacks/utils.py | 2 ++ 3 files changed, 21 insertions(+) create mode 100644 mimir/attacks/recall.py diff --git a/mimir/attacks/all_attacks.py b/mimir/attacks/all_attacks.py index 21fc613..4692695 100644 --- a/mimir/attacks/all_attacks.py +++ b/mimir/attacks/all_attacks.py @@ -15,6 +15,7 @@ class AllAttacks(str, Enum): MIN_K_PLUS_PLUS = "min_k++" # Done NEIGHBOR = "ne" # Done GRADNORM = "gradnorm" # Done + RECALL = "recall" # QUANTILE = "quantile" # Uncomment when tested implementation is available diff --git a/mimir/attacks/recall.py b/mimir/attacks/recall.py new file mode 100644 index 0000000..46ef82c --- /dev/null +++ b/mimir/attacks/recall.py @@ -0,0 +1,18 @@ +""" + ReCaLL Attack: https://github.com/ruoyuxie/recall/ +""" +import torch +import numpy as np +from mimir.attacks.all_attacks import Attack +from mimir.models import Model +from mimir.config import ExperimentConfig + +class ReCaLLAttack(Attack): + + def __init__(self, config: ExperimentConfig, target_model: Model): + super().__init__(config, target_model, ref_model = None) + + @torch.no_grad() + def _attack(self, document, probs, tokens = None, **kwargs): + # TODO implement ReCaLL Attack + raise NotImplementedError("Need to do") diff --git a/mimir/attacks/utils.py b/mimir/attacks/utils.py index 05b66cc..e360abb 100644 --- a/mimir/attacks/utils.py +++ b/mimir/attacks/utils.py @@ -7,6 +7,7 @@ from mimir.attacks.min_k_plus_plus import MinKPlusPlusAttack from mimir.attacks.neighborhood import NeighborhoodAttack from mimir.attacks.gradnorm import GradNormAttack +from mimir.attacks.recall import ReCallAttack # TODO Use decorators to link attack implementations with enum above @@ -19,6 +20,7 @@ def get_attacker(attack: str): AllAttacks.MIN_K_PLUS_PLUS: MinKPlusPlusAttack, AllAttacks.NEIGHBOR: NeighborhoodAttack, AllAttacks.GRADNORM: GradNormAttack, + AllAttacks.RECALL: ReCaLLAttack } attack_cls = mapping.get(attack, None) if attack_cls is None: From dc6260fdee8a5e08a06c2a48ef90eccafcc54f08 Mon Sep 17 00:00:00 2001 From: Austin Brown Date: Mon, 5 Aug 2024 21:19:10 -0400 Subject: [PATCH 02/12] typo --- mimir/attacks/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mimir/attacks/utils.py b/mimir/attacks/utils.py index e360abb..766e22b 100644 --- a/mimir/attacks/utils.py +++ b/mimir/attacks/utils.py @@ -7,7 +7,7 @@ from mimir.attacks.min_k_plus_plus import MinKPlusPlusAttack from mimir.attacks.neighborhood import NeighborhoodAttack from mimir.attacks.gradnorm import GradNormAttack -from mimir.attacks.recall import ReCallAttack +from mimir.attacks.recall import ReCaLLAttack # TODO Use decorators to link attack implementations with enum above From 84752ec89723ed4f81c06c01f298cb9d9b4015cc Mon Sep 17 00:00:00 2001 From: Austin Brown Date: Wed, 14 Aug 2024 00:27:27 -0400 Subject: [PATCH 03/12] new json files for testing --- configs/new_mi.json | 24 +++++++++++++++--------- configs/recall.json | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 9 deletions(-) create mode 100644 configs/recall.json diff --git a/configs/new_mi.json b/configs/new_mi.json index 280a80e..37ee81b 100644 --- a/configs/new_mi.json +++ b/configs/new_mi.json @@ -8,19 +8,25 @@ "max_tokens": 512, "max_data": 100000, "output_name": "unified_mia", - "specific_source": "Github", + "specific_source": "Github_ngram_13_<0.8_truncated", "n_samples": 1000, - "blackbox_attacks": ["loss", "ref", "min_k", "zlib"], - "ref_config": { - "models": [ - "stabilityai/stablelm-base-alpha-3b-v2" - ] - }, + "blackbox_attacks": ["loss", "min_k", "zlib"], "env_config": { "results": "results_new", "device_map": "balanced_low_0" }, + "neighborhood_config": { + "model": "bert", + "n_perturbation_list": [ + 25 + ], + "pct_words_masked": 0.3, + "span_length": 2, + "dump_cache": false, + "load_from_cache": true, + "neighbor_strategy": "random" + }, "dump_cache": false, - "load_from_cache": true, - "load_from_hf": false + "load_from_cache": false, + "load_from_hf": true } \ No newline at end of file diff --git a/configs/recall.json b/configs/recall.json new file mode 100644 index 0000000..4120fe1 --- /dev/null +++ b/configs/recall.json @@ -0,0 +1,32 @@ +{ + "experiment_name": "edited_members_exp", + "base_model": "EleutherAI/gpt-neo-125m", + "dataset_member": "the_pile", + "dataset_nonmember": "the_pile", + "min_words": 100, + "max_words": 200, + "max_tokens": 512, + "max_data": 100000, + "output_name": "unified_mia", + "specific_source": "Github_ngram_13_<0.8_truncated", + "n_samples": 1000, + "blackbox_attacks": ["loss", "min_k", "zlib", "recall"], + "env_config": { + "results": "results_new", + "device_map": "balanced_low_0" + }, + "neighborhood_config": { + "model": "bert", + "n_perturbation_list": [ + 25 + ], + "pct_words_masked": 0.3, + "span_length": 2, + "dump_cache": false, + "load_from_cache": true, + "neighbor_strategy": "random" + }, + "dump_cache": false, + "load_from_cache": false, + "load_from_hf": true +} \ No newline at end of file From 4510d96fd0d23bcba428dc9858af40e1904ea3cf Mon Sep 17 00:00:00 2001 From: Austin Brown Date: Wed, 21 Aug 2024 02:41:36 -0400 Subject: [PATCH 04/12] working recall along with config file --- configs/recall.json | 3 ++- mimir/attacks/recall.py | 44 ++++++++++++++++++++++++++++++++++++++--- mimir/config.py | 2 ++ run.py | 27 ++++++++++++++++++++++++- 4 files changed, 71 insertions(+), 5 deletions(-) diff --git a/configs/recall.json b/configs/recall.json index 4120fe1..1fb6939 100644 --- a/configs/recall.json +++ b/configs/recall.json @@ -10,7 +10,8 @@ "output_name": "unified_mia", "specific_source": "Github_ngram_13_<0.8_truncated", "n_samples": 1000, - "blackbox_attacks": ["loss", "min_k", "zlib", "recall"], + "recall_num_shots": 1, + "blackbox_attacks": ["recall"], "env_config": { "results": "results_new", "device_map": "balanced_low_0" diff --git a/mimir/attacks/recall.py b/mimir/attacks/recall.py index 46ef82c..003611f 100644 --- a/mimir/attacks/recall.py +++ b/mimir/attacks/recall.py @@ -13,6 +13,44 @@ def __init__(self, config: ExperimentConfig, target_model: Model): super().__init__(config, target_model, ref_model = None) @torch.no_grad() - def _attack(self, document, probs, tokens = None, **kwargs): - # TODO implement ReCaLL Attack - raise NotImplementedError("Need to do") + def _attack(self, document, probs, tokens = None, **kwargs): + nonmember_prefix = kwargs.get("nonmember_prefix", None) + assert nonmember_prefix, "nonmember_prefix should not be None or empty" + + lls = self.target_model.get_ll(document, probs = probs, tokens = tokens) + ll_nonmember = self.get_conditional_ll(nonmember_prefix = nonmember_prefix, text = document, + model = self.target_model, tokenizer=self.target_model.tokenizer, tokens = tokens) + recall = ll_nonmember / lls + + return recall + + def get_conditional_ll(self, nonmember_prefix, text, model, tokenizer, tokens = None): + assert nonmember_prefix, "nonmember_prefix should not be None or empty" + + input_encodings = tokenizer(text = nonmember_prefix, return_tensors="pt") + if tokens is None: + target_encodings = tokenizer(text = text, return_tensors="pt") + else: + target_encodings = tokens + + max_length = model.max_length + input_ids = input_encodings.input_ids.to(model.device) + target_ids = target_encodings.input_ids.to(model.device) + + total_length = input_ids.size(1) + target_ids.size(1) + + if total_length > max_length: + excess_length = total_length - max_length + target_ids = target_ids[:, :-excess_length] + concat_ids = torch.cat((input_ids, target_ids), dim=1) + labels = concat_ids.clone() + labels[:, :input_ids.size(1)] = -100 + + with torch.no_grad(): + outputs = model.model(concat_ids, labels=labels) + + loss, logits = outputs[:2] + ll = -loss.item() + return ll + + diff --git a/mimir/config.py b/mimir/config.py index 64c911e..9b2615c 100644 --- a/mimir/config.py +++ b/mimir/config.py @@ -174,6 +174,8 @@ class ExperimentConfig(Serializable): """Chunk size""" scoring_model_name: Optional[str] = None """Scoring model (if different from base model)""" + recall_num_shots: Optional[int] = 1 + """Number of shots for ReCaLL Attacks""" top_k: Optional[int] = 40 """Consider only top-k tokens""" do_top_k: Optional[bool] = False diff --git a/run.py b/run.py index de8032a..6becfaa 100644 --- a/run.py +++ b/run.py @@ -77,6 +77,7 @@ def get_mia_scores( is_train: bool, n_samples: int = None, batch_size: int = 50, + **kwargs ): # Fix randomness fix_seed(config.random_seed) @@ -100,6 +101,11 @@ def get_mia_scores( n_perturbation: [] for n_perturbation in n_perturbation_list } + nonmember_prefix = kwargs.get("nonmember_prefix", None) + if AllAttacks.RECALL in attackers_dict.keys(): + if nonmember_prefix is None: + raise ValueError("Must include a prefix for ReCaLL attack") + # For each batch of data # TODO: Batch-size isn't really "batching" data - change later for batch in tqdm(range(math.ceil(n_samples / batch_size)), desc=f"Computing criterion"): @@ -160,8 +166,10 @@ def get_mia_scores( ), loss=loss, all_probs=s_all_probs, + nonmember_prefix = nonmember_prefix ) sample_information[attack].append(score) + else: # For each 'number of neighbors' for n_perturbation in n_perturbation_list: @@ -515,6 +523,21 @@ def main(config: ExperimentConfig): mask_model_tokenizer=mask_model.tokenizer if mask_model else None, ) + #* ReCaLL Specific + if AllAttacks.RECALL in config.blackbox_attacks: + num_shots = config.recall_num_shots + nonmember_prefix = data_nonmember[:num_shots] + nonmember_data = data_nonmember[num_shots:] + + member_prefix = data_member[:num_shots] + member_data = data_member[num_shots:] + + data_nonmember = nonmember_data + data_member = member_data + else: + nonmember_prefix = None + + other_objs, other_nonmembers = None, None if config.dataset_nonmember_other_sources is not None: other_objs, other_nonmembers = [], [] @@ -628,7 +651,8 @@ def main(config: ExperimentConfig): ref_models=ref_models, config=config, is_train=True, - n_samples=n_samples + n_samples=n_samples, + nonmember_prefix = nonmember_prefix ) # Collect scores for non-members nonmember_preds, nonmember_samples = get_mia_scores( @@ -640,6 +664,7 @@ def main(config: ExperimentConfig): config=config, is_train=False, n_samples=n_samples, + nonmember_prefix = nonmember_prefix ) blackbox_outputs = compute_metrics_from_scores( member_preds, From 4472fab41a0b59ed5597f05e8ac9858b37dbee19 Mon Sep 17 00:00:00 2001 From: Austin Brown Date: Wed, 28 Aug 2024 22:11:07 -0400 Subject: [PATCH 05/12] added prefix processing --- configs/recall.json | 12 +++-- mimir/attacks/recall.py | 104 +++++++++++++++++++++++++++++++--------- run.py | 7 ++- 3 files changed, 96 insertions(+), 27 deletions(-) diff --git a/configs/recall.json b/configs/recall.json index 1fb6939..2775eea 100644 --- a/configs/recall.json +++ b/configs/recall.json @@ -1,6 +1,6 @@ { "experiment_name": "edited_members_exp", - "base_model": "EleutherAI/gpt-neo-125m", + "base_model": "EleutherAI/pythia-1.4b", "dataset_member": "the_pile", "dataset_nonmember": "the_pile", "min_words": 100, @@ -11,10 +11,16 @@ "specific_source": "Github_ngram_13_<0.8_truncated", "n_samples": 1000, "recall_num_shots": 1, - "blackbox_attacks": ["recall"], + "blackbox_attacks": ["loss", "ref", "zlib", "min_k", "min_k++", "recall"], "env_config": { "results": "results_new", - "device_map": "balanced_low_0" + "device": "cuda:0", + "device_aux": "cuda:0" + }, + "ref_config": { + "models": [ + "EleutherAI/pythia-160m" + ] }, "neighborhood_config": { "model": "bert", diff --git a/mimir/attacks/recall.py b/mimir/attacks/recall.py index 003611f..c7ca097 100644 --- a/mimir/attacks/recall.py +++ b/mimir/attacks/recall.py @@ -14,43 +14,103 @@ def __init__(self, config: ExperimentConfig, target_model: Model): @torch.no_grad() def _attack(self, document, probs, tokens = None, **kwargs): - nonmember_prefix = kwargs.get("nonmember_prefix", None) + recall_dict: dict = kwargs.get("recall_dict", None) + + nonmember_prefix = recall_dict.get("prefix") + num_shots = recall_dict.get("num_shots") + avg_length = recall_dict.get("avg_length") + assert nonmember_prefix, "nonmember_prefix should not be None or empty" lls = self.target_model.get_ll(document, probs = probs, tokens = tokens) ll_nonmember = self.get_conditional_ll(nonmember_prefix = nonmember_prefix, text = document, - model = self.target_model, tokenizer=self.target_model.tokenizer, tokens = tokens) + num_shots = num_shots, avg_length = avg_length, + tokens = tokens) recall = ll_nonmember / lls return recall - def get_conditional_ll(self, nonmember_prefix, text, model, tokenizer, tokens = None): + def process_prefix(self, prefix, avg_length, total_shots): + model = self.target_model + tokenizer = self.target_model.tokenizer + + max_length = model.max_length + token_counts = [len(tokenizer.encode(shot)) for shot in prefix] + + target_token_count = avg_length + total_tokens = sum(token_counts) + target_token_count + if total_tokens<=max_length: + return prefix + # Determine the maximum number of shots that can fit within the max_length + max_shots = 0 + cumulative_tokens = target_token_count + for count in token_counts: + if cumulative_tokens + count <= max_length: + max_shots += 1 + cumulative_tokens += count + else: + break + # Truncate the prefix to include only the maximum number of shots + truncated_prefix = prefix[-max_shots:] + print(f"""Too many shots used. Initial ReCaLL number of shots was {total_shots}. + Maximum number of shots is {max_shots}. Defaulting to maximum number of shots.""") + return truncated_prefix + + def get_conditional_ll(self, nonmember_prefix, text, num_shots, avg_length, tokens=None): assert nonmember_prefix, "nonmember_prefix should not be None or empty" - - input_encodings = tokenizer(text = nonmember_prefix, return_tensors="pt") + + model = self.target_model + tokenizer = self.target_model.tokenizer + if tokens is None: - target_encodings = tokenizer(text = text, return_tensors="pt") + target_encodings = tokenizer(text=text, return_tensors="pt") else: target_encodings = tokens + processed_prefix = self.process_prefix(nonmember_prefix, avg_length, total_shots=num_shots) + input_encodings = tokenizer(text=processed_prefix, return_tensors="pt") + + prefix_ids = input_encodings.input_ids.to(model.device) + text_ids = target_encodings.input_ids.to(model.device) + max_length = model.max_length - input_ids = input_encodings.input_ids.to(model.device) - target_ids = target_encodings.input_ids.to(model.device) - - total_length = input_ids.size(1) + target_ids.size(1) - - if total_length > max_length: - excess_length = total_length - max_length - target_ids = target_ids[:, :-excess_length] - concat_ids = torch.cat((input_ids, target_ids), dim=1) - labels = concat_ids.clone() - labels[:, :input_ids.size(1)] = -100 + total_length = prefix_ids.size(1) + text_ids.size(1) + + if prefix_ids.size(1) >= max_length: + raise ValueError("Prefix length exceeds or equals the model's maximum context window.") + + log_likelihoods = [] + stride = model.stride + labels = torch.cat((prefix_ids, text_ids), dim=1) with torch.no_grad(): - outputs = model.model(concat_ids, labels=labels) - - loss, logits = outputs[:2] - ll = -loss.item() - return ll + for i in range(0, labels.size(1), stride): + + begin_loc = max(i + stride - max_length, 0) + end_loc = min(i + stride, labels.size(1)) + trg_len = end_loc - i # This may be different from stride on the last loop + + # Extract the input_ids for the current window + input_ids = labels[:, begin_loc:end_loc].to(model.device) + + # Clone input_ids to create target_ids, masking out the prefix and the initial part of the text + target_ids = input_ids.clone() + + # Masking: prefix part + initial part of the text in the sliding window + if begin_loc < prefix_ids.size(1): + prefix_mask_length = prefix_ids.size(1) - begin_loc + target_ids[:, :prefix_mask_length] = -100 + + # Mask the initial part of the text according to trg_len + target_ids[:, :-trg_len] = -100 + + outputs = model.model(input_ids, labels=target_ids) + loss = outputs.loss + + log_likelihoods.append(-loss.item()) + + total_log_likelihood = sum(log_likelihoods) + return total_log_likelihood + diff --git a/run.py b/run.py index 6becfaa..d320729 100644 --- a/run.py +++ b/run.py @@ -105,7 +105,10 @@ def get_mia_scores( if AllAttacks.RECALL in attackers_dict.keys(): if nonmember_prefix is None: raise ValueError("Must include a prefix for ReCaLL attack") - + num_shots = config.recall_num_shots + avg_length = int(np.mean([len(target_model.tokenizer.encode(ex)) for ex in data["records"]])) + recall_dict = {"prefix":nonmember_prefix, "num_shots":num_shots, "avg_length":avg_length} + # For each batch of data # TODO: Batch-size isn't really "batching" data - change later for batch in tqdm(range(math.ceil(n_samples / batch_size)), desc=f"Computing criterion"): @@ -166,7 +169,7 @@ def get_mia_scores( ), loss=loss, all_probs=s_all_probs, - nonmember_prefix = nonmember_prefix + recall_dict = recall_dict ) sample_information[attack].append(score) From 3cc6c9bffd9d74979a29700384654ce796eb9542 Mon Sep 17 00:00:00 2001 From: Austin Brown Date: Wed, 28 Aug 2024 22:15:45 -0400 Subject: [PATCH 06/12] reset new_mi.json --- configs/new_mi.json | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/configs/new_mi.json b/configs/new_mi.json index 37ee81b..280a80e 100644 --- a/configs/new_mi.json +++ b/configs/new_mi.json @@ -8,25 +8,19 @@ "max_tokens": 512, "max_data": 100000, "output_name": "unified_mia", - "specific_source": "Github_ngram_13_<0.8_truncated", + "specific_source": "Github", "n_samples": 1000, - "blackbox_attacks": ["loss", "min_k", "zlib"], + "blackbox_attacks": ["loss", "ref", "min_k", "zlib"], + "ref_config": { + "models": [ + "stabilityai/stablelm-base-alpha-3b-v2" + ] + }, "env_config": { "results": "results_new", "device_map": "balanced_low_0" }, - "neighborhood_config": { - "model": "bert", - "n_perturbation_list": [ - 25 - ], - "pct_words_masked": 0.3, - "span_length": 2, - "dump_cache": false, - "load_from_cache": true, - "neighbor_strategy": "random" - }, "dump_cache": false, - "load_from_cache": false, - "load_from_hf": true + "load_from_cache": true, + "load_from_hf": false } \ No newline at end of file From 3f19c7ed09eae0e2bd3bc9b2658c292da35958fa Mon Sep 17 00:00:00 2001 From: Austin Brown Date: Mon, 2 Sep 2024 14:54:32 -0400 Subject: [PATCH 07/12] fix nan errors --- mimir/attacks/recall.py | 44 ++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/mimir/attacks/recall.py b/mimir/attacks/recall.py index c7ca097..be890d0 100644 --- a/mimir/attacks/recall.py +++ b/mimir/attacks/recall.py @@ -11,6 +11,7 @@ class ReCaLLAttack(Attack): def __init__(self, config: ExperimentConfig, target_model: Model): super().__init__(config, target_model, ref_model = None) + self.prefix = None @torch.no_grad() def _attack(self, document, probs, tokens = None, **kwargs): @@ -28,19 +29,25 @@ def _attack(self, document, probs, tokens = None, **kwargs): tokens = tokens) recall = ll_nonmember / lls + assert not np.isnan(recall) return recall def process_prefix(self, prefix, avg_length, total_shots): model = self.target_model tokenizer = self.target_model.tokenizer + if self.prefix is not None: + # We only need to process the prefix once, after that we can just return + return self.prefix + max_length = model.max_length token_counts = [len(tokenizer.encode(shot)) for shot in prefix] target_token_count = avg_length total_tokens = sum(token_counts) + target_token_count if total_tokens<=max_length: - return prefix + self.prefix = prefix + return self.prefix # Determine the maximum number of shots that can fit within the max_length max_shots = 0 cumulative_tokens = target_token_count @@ -52,9 +59,9 @@ def process_prefix(self, prefix, avg_length, total_shots): break # Truncate the prefix to include only the maximum number of shots truncated_prefix = prefix[-max_shots:] - print(f"""Too many shots used. Initial ReCaLL number of shots was {total_shots}. - Maximum number of shots is {max_shots}. Defaulting to maximum number of shots.""") - return truncated_prefix + print(f"""\nToo many shots used. Initial ReCaLL number of shots was {total_shots}. Maximum number of shots is {max_shots}. Defaulting to maximum number of shots.""") + self.prefix = truncated_prefix + return self.prefix def get_conditional_ll(self, nonmember_prefix, text, num_shots, avg_length, tokens=None): assert nonmember_prefix, "nonmember_prefix should not be None or empty" @@ -68,7 +75,7 @@ def get_conditional_ll(self, nonmember_prefix, text, num_shots, avg_length, toke target_encodings = tokens processed_prefix = self.process_prefix(nonmember_prefix, avg_length, total_shots=num_shots) - input_encodings = tokenizer(text=processed_prefix, return_tensors="pt") + input_encodings = tokenizer(text="".join(processed_prefix), return_tensors="pt") prefix_ids = input_encodings.input_ids.to(model.device) text_ids = target_encodings.input_ids.to(model.device) @@ -78,39 +85,30 @@ def get_conditional_ll(self, nonmember_prefix, text, num_shots, avg_length, toke if prefix_ids.size(1) >= max_length: raise ValueError("Prefix length exceeds or equals the model's maximum context window.") - - log_likelihoods = [] stride = model.stride labels = torch.cat((prefix_ids, text_ids), dim=1) + + total_loss = 0 with torch.no_grad(): for i in range(0, labels.size(1), stride): - begin_loc = max(i + stride - max_length, 0) end_loc = min(i + stride, labels.size(1)) - trg_len = end_loc - i # This may be different from stride on the last loop - - # Extract the input_ids for the current window + trg_len = end_loc - i input_ids = labels[:, begin_loc:end_loc].to(model.device) - - # Clone input_ids to create target_ids, masking out the prefix and the initial part of the text target_ids = input_ids.clone() - # Masking: prefix part + initial part of the text in the sliding window - if begin_loc < prefix_ids.size(1): - prefix_mask_length = prefix_ids.size(1) - begin_loc - target_ids[:, :prefix_mask_length] = -100 - - # Mask the initial part of the text according to trg_len + target_ids[:, :max(0, prefix_ids.size(1) - begin_loc)] = -100 target_ids[:, :-trg_len] = -100 + if torch.all(target_ids == -100): + continue #this prevents the model from outputting nan as loss value + outputs = model.model(input_ids, labels=target_ids) loss = outputs.loss + total_loss += -loss.item() - log_likelihoods.append(-loss.item()) - - total_log_likelihood = sum(log_likelihoods) - return total_log_likelihood + return total_loss From 0249ab991ef26902fd88729a957d2b0274babb9a Mon Sep 17 00:00:00 2001 From: Austin Brown Date: Tue, 3 Sep 2024 18:23:29 -0400 Subject: [PATCH 08/12] better sliding window --- mimir/attacks/recall.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/mimir/attacks/recall.py b/mimir/attacks/recall.py index be890d0..38f9fa4 100644 --- a/mimir/attacks/recall.py +++ b/mimir/attacks/recall.py @@ -81,19 +81,17 @@ def get_conditional_ll(self, nonmember_prefix, text, num_shots, avg_length, toke text_ids = target_encodings.input_ids.to(model.device) max_length = model.max_length - total_length = prefix_ids.size(1) + text_ids.size(1) if prefix_ids.size(1) >= max_length: raise ValueError("Prefix length exceeds or equals the model's maximum context window.") - stride = model.stride labels = torch.cat((prefix_ids, text_ids), dim=1) total_loss = 0 with torch.no_grad(): - for i in range(0, labels.size(1), stride): - begin_loc = max(i + stride - max_length, 0) - end_loc = min(i + stride, labels.size(1)) + for i in range(0, labels.size(1), max_length): + begin_loc = max(i - max_length, 0) + end_loc = min(i, labels.size(1)) trg_len = end_loc - i input_ids = labels[:, begin_loc:end_loc].to(model.device) target_ids = input_ids.clone() From 793bdd7a559a9dd04a27c2f54374fb1dd55cdc72 Mon Sep 17 00:00:00 2001 From: Austin Brown Date: Tue, 3 Sep 2024 18:23:49 -0400 Subject: [PATCH 09/12] json update --- configs/recall.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/recall.json b/configs/recall.json index 2775eea..4d54fda 100644 --- a/configs/recall.json +++ b/configs/recall.json @@ -1,5 +1,5 @@ { - "experiment_name": "edited_members_exp", + "experiment_name": "recall", "base_model": "EleutherAI/pythia-1.4b", "dataset_member": "the_pile", "dataset_nonmember": "the_pile", From 475e66b675e5e17bfbfb7966c9e76115b941556c Mon Sep 17 00:00:00 2001 From: Austin Brown Date: Wed, 4 Sep 2024 20:05:34 -0400 Subject: [PATCH 10/12] normalized losses --- mimir/attacks/recall.py | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/mimir/attacks/recall.py b/mimir/attacks/recall.py index 38f9fa4..412ff18 100644 --- a/mimir/attacks/recall.py +++ b/mimir/attacks/recall.py @@ -29,6 +29,7 @@ def _attack(self, document, probs, tokens = None, **kwargs): tokens = tokens) recall = ll_nonmember / lls + assert not np.isnan(recall) return recall @@ -86,27 +87,41 @@ def get_conditional_ll(self, nonmember_prefix, text, num_shots, avg_length, toke raise ValueError("Prefix length exceeds or equals the model's maximum context window.") labels = torch.cat((prefix_ids, text_ids), dim=1) + total_length = labels.size(1) total_loss = 0 + total_tokens = 0 with torch.no_grad(): - for i in range(0, labels.size(1), max_length): - begin_loc = max(i - max_length, 0) - end_loc = min(i, labels.size(1)) - trg_len = end_loc - i + for i in range(0, total_length, max_length): + begin_loc = i + end_loc = min(i + max_length, total_length) + trg_len = end_loc - begin_loc + input_ids = labels[:, begin_loc:end_loc].to(model.device) target_ids = input_ids.clone() - target_ids[:, :max(0, prefix_ids.size(1) - begin_loc)] = -100 - target_ids[:, :-trg_len] = -100 - + if begin_loc < prefix_ids.size(1): + prefix_overlap = min(prefix_ids.size(1) - begin_loc, max_length) + target_ids[:, :prefix_overlap] = -100 + + if end_loc > total_length - text_ids.size(1): + target_overlap = min(end_loc - (total_length - text_ids.size(1)), max_length) + target_ids[:, -target_overlap:] = input_ids[:, -target_overlap:] + if torch.all(target_ids == -100): - continue #this prevents the model from outputting nan as loss value + continue outputs = model.model(input_ids, labels=target_ids) loss = outputs.loss - total_loss += -loss.item() - - return total_loss + if torch.isnan(loss): + print(f"NaN detected in loss at iteration {i}. Non masked target_ids size is {(target_ids != -100).sum().item()}") + continue + non_masked_tokens = (target_ids != -100).sum().item() + total_loss += loss.item() * non_masked_tokens + total_tokens += non_masked_tokens + + average_loss = total_loss / total_tokens if total_tokens > 0 else 0 + return -average_loss From c701cc045dabad60747fdab07e2d5cce1a668809 Mon Sep 17 00:00:00 2001 From: Austin Brown Date: Wed, 4 Sep 2024 20:25:40 -0400 Subject: [PATCH 11/12] update readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 1c2b213..d6307ea 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,7 @@ We include and implement the following attacks, as described in our paper. - [Min-K% Prob](https://swj0419.github.io/detect-pretrain.github.io/) (`min_k`). Uses k% of tokens with minimum likelihood for score computation. - [Min-K%++](https://zjysteven.github.io/mink-plus-plus/) (`min_k++`). Uses k% of tokens with minimum *normalized* likelihood for score computation. - [Gradient Norm](https://arxiv.org/abs/2402.17012) (`gradnorm`). Uses gradient norm of the target datapoint as score. +- [ReCaLL](https://royxie.com/recall-project-page/)(`recall`). Operates by comparing the unconditional and conditional log-likelihoods. ## Adding your own dataset From bdb3c0673dbc4562bc42cd5dfe661acb39aad763 Mon Sep 17 00:00:00 2001 From: Austin Brown Date: Sun, 15 Sep 2024 15:33:26 -0400 Subject: [PATCH 12/12] made requested changes --- configs/recall.json | 4 +++- mimir/attacks/recall.py | 5 +++++ mimir/config.py | 11 +++++++++-- run.py | 42 +++++++++++++++++++++++++---------------- 4 files changed, 43 insertions(+), 19 deletions(-) diff --git a/configs/recall.json b/configs/recall.json index 4d54fda..c6753d5 100644 --- a/configs/recall.json +++ b/configs/recall.json @@ -10,7 +10,6 @@ "output_name": "unified_mia", "specific_source": "Github_ngram_13_<0.8_truncated", "n_samples": 1000, - "recall_num_shots": 1, "blackbox_attacks": ["loss", "ref", "zlib", "min_k", "min_k++", "recall"], "env_config": { "results": "results_new", @@ -22,6 +21,9 @@ "EleutherAI/pythia-160m" ] }, + "recall_config":{ + "num_shots": 1 + }, "neighborhood_config": { "model": "bert", "n_perturbation_list": [ diff --git a/mimir/attacks/recall.py b/mimir/attacks/recall.py index 412ff18..537c32e 100644 --- a/mimir/attacks/recall.py +++ b/mimir/attacks/recall.py @@ -9,6 +9,9 @@ class ReCaLLAttack(Attack): + #** Note: this is a suboptimal implementation of the ReCaLL attack due to necessary changes made to integrate it alongside the other attacks + #** for a better performing version, please refer to: https://github.com/ruoyuxie/recall + def __init__(self, config: ExperimentConfig, target_model: Model): super().__init__(config, target_model, ref_model = None) self.prefix = None @@ -22,6 +25,8 @@ def _attack(self, document, probs, tokens = None, **kwargs): avg_length = recall_dict.get("avg_length") assert nonmember_prefix, "nonmember_prefix should not be None or empty" + assert num_shots, "num_shots should not be None or empty" + assert avg_length, "avg_length should not be None or empty" lls = self.target_model.get_ll(document, probs = probs, tokens = tokens) ll_nonmember = self.get_conditional_ll(nonmember_prefix = nonmember_prefix, text = document, diff --git a/mimir/config.py b/mimir/config.py index 9b2615c..2d9a05b 100644 --- a/mimir/config.py +++ b/mimir/config.py @@ -59,6 +59,13 @@ def __post_init__(self): if self.dump_cache and self.load_from_cache: raise ValueError("Cannot dump and load cache at the same time") +@dataclass +class ReCaLLConfig(Serializable): + """ + Config for ReCaLL attack + """ + num_shots: Optional[int] = 1 + """Number of shots for ReCaLL Attacks""" @dataclass class EnvironmentConfig(Serializable): @@ -174,8 +181,6 @@ class ExperimentConfig(Serializable): """Chunk size""" scoring_model_name: Optional[str] = None """Scoring model (if different from base model)""" - recall_num_shots: Optional[int] = 1 - """Number of shots for ReCaLL Attacks""" top_k: Optional[int] = 40 """Consider only top-k tokens""" do_top_k: Optional[bool] = False @@ -196,6 +201,8 @@ class ExperimentConfig(Serializable): """Random seed""" ref_config: Optional[ReferenceConfig] = None """Reference model config""" + recall_config: Optional[ReCaLLConfig] = None + """ReCaLL attack config""" neighborhood_config: Optional[NeighborhoodConfig] = None """Neighborhood attack config""" env_config: Optional[EnvironmentConfig] = None diff --git a/run.py b/run.py index d320729..b1f571e 100644 --- a/run.py +++ b/run.py @@ -19,7 +19,8 @@ EnvironmentConfig, NeighborhoodConfig, ReferenceConfig, - OpenAIConfig + OpenAIConfig, + ReCaLLConfig ) import mimir.data_utils as data_utils import mimir.plot_utils as plot_utils @@ -101,11 +102,10 @@ def get_mia_scores( n_perturbation: [] for n_perturbation in n_perturbation_list } - nonmember_prefix = kwargs.get("nonmember_prefix", None) - if AllAttacks.RECALL in attackers_dict.keys(): - if nonmember_prefix is None: - raise ValueError("Must include a prefix for ReCaLL attack") - num_shots = config.recall_num_shots + recall_config = config.recall_config + if recall_config: + nonmember_prefix = kwargs.get("nonmember_prefix", None) + num_shots = recall_config.num_shots avg_length = int(np.mean([len(target_model.tokenizer.encode(ex)) for ex in data["records"]])) recall_dict = {"prefix":nonmember_prefix, "num_shots":num_shots, "avg_length":avg_length} @@ -158,10 +158,10 @@ def get_mia_scores( if attack.startswith(AllAttacks.REFERENCE_BASED) or attack == AllAttacks.LOSS: continue - if attack != AllAttacks.NEIGHBOR: + if attack == AllAttacks.RECALL: score = attacker.attack( substr, - probs=s_tk_probs, + probs = s_tk_probs, detokenized_sample=( detokenized_sample[i] if config.pretokenized @@ -172,6 +172,21 @@ def get_mia_scores( recall_dict = recall_dict ) sample_information[attack].append(score) + + + elif attack != AllAttacks.NEIGHBOR: + score = attacker.attack( + substr, + probs=s_tk_probs, + detokenized_sample=( + detokenized_sample[i] + if config.pretokenized + else None + ), + loss=loss, + all_probs=s_all_probs, + ) + sample_information[attack].append(score) else: # For each 'number of neighbors' @@ -427,6 +442,7 @@ def main(config: ExperimentConfig): neigh_config: NeighborhoodConfig = config.neighborhood_config ref_config: ReferenceConfig = config.ref_config openai_config: OpenAIConfig = config.openai_config + recall_config: ReCaLLConfig = config.recall_config if openai_config: openAI_model = OpenAI_APIModel(config) @@ -528,15 +544,9 @@ def main(config: ExperimentConfig): #* ReCaLL Specific if AllAttacks.RECALL in config.blackbox_attacks: - num_shots = config.recall_num_shots + assert recall_config, "Must provide a recall_config" + num_shots = recall_config.num_shots nonmember_prefix = data_nonmember[:num_shots] - nonmember_data = data_nonmember[num_shots:] - - member_prefix = data_member[:num_shots] - member_data = data_member[num_shots:] - - data_nonmember = nonmember_data - data_member = member_data else: nonmember_prefix = None