Merge pull request #26 from austinbrown5/main
add ReCaLL attack
iamgroot42 authored Sep 16, 2024
2 parents 2772f72 + bdb3c06 commit 99b67d2
Showing 7 changed files with 227 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -55,6 +55,7 @@ We include and implement the following attacks, as described in our paper.
- [Min-K% Prob](https://swj0419.github.io/detect-pretrain.github.io/) (`min_k`). Uses k% of tokens with minimum likelihood for score computation.
- [Min-K%++](https://zjysteven.github.io/mink-plus-plus/) (`min_k++`). Uses k% of tokens with minimum *normalized* likelihood for score computation.
- [Gradient Norm](https://arxiv.org/abs/2402.17012) (`gradnorm`). Uses gradient norm of the target datapoint as score.
- [ReCaLL](https://royxie.com/recall-project-page/) (`recall`). Compares the target datapoint's conditional log-likelihood (given a non-member prefix) with its unconditional log-likelihood; a minimal sketch follows below.
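For quick orientation, here is a minimal, self-contained sketch of the ReCaLL score under a HuggingFace causal LM (hedged: `avg_ll`, the model choice, and the sample strings are illustrative assumptions, not part of this PR):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    def avg_ll(model, tokenizer, text, prefix=""):
        # Average log-likelihood per token of `text`, optionally conditioned on `prefix`
        prefix_ids = tokenizer(prefix, return_tensors="pt").input_ids if prefix else None
        text_ids = tokenizer(text, return_tensors="pt").input_ids
        input_ids = torch.cat([prefix_ids, text_ids], dim=1) if prefix else text_ids
        labels = input_ids.clone()
        if prefix:
            labels[:, : prefix_ids.size(1)] = -100  # mask the prefix: score only `text`
        with torch.no_grad():
            loss = model(input_ids, labels=labels).loss  # mean NLL over unmasked tokens
        return -loss.item()

    model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
    text = "candidate document to score"
    score = avg_ll(model, tokenizer, text, prefix="one non-member shot\n") / avg_ll(model, tokenizer, text)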

## Adding your own dataset

41 changes: 41 additions & 0 deletions configs/recall.json
@@ -0,0 +1,41 @@
{
"experiment_name": "recall",
"base_model": "EleutherAI/pythia-1.4b",
"dataset_member": "the_pile",
"dataset_nonmember": "the_pile",
"min_words": 100,
"max_words": 200,
"max_tokens": 512,
"max_data": 100000,
"output_name": "unified_mia",
"specific_source": "Github_ngram_13_<0.8_truncated",
"n_samples": 1000,
"blackbox_attacks": ["loss", "ref", "zlib", "min_k", "min_k++", "recall"],
"env_config": {
"results": "results_new",
"device": "cuda:0",
"device_aux": "cuda:0"
},
"ref_config": {
"models": [
"EleutherAI/pythia-160m"
]
},
"recall_config":{
"num_shots": 1
},
"neighborhood_config": {
"model": "bert",
"n_perturbation_list": [
25
],
"pct_words_masked": 0.3,
"span_length": 2,
"dump_cache": false,
"load_from_cache": true,
"neighbor_strategy": "random"
},
"dump_cache": false,
"load_from_cache": false,
"load_from_hf": true
}
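Assuming the repository's standard entry point (run.py, whose changes appear below) and the usual invocation pattern from the project README rather than anything in this diff, a run with this config would look like:

    python run.py --config configs/recall.json

The `recall_config.num_shots` value controls how many non-member examples are concatenated into the ReCaLL prefix.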
1 change: 1 addition & 0 deletions mimir/attacks/all_attacks.py
@@ -15,6 +15,7 @@ class AllAttacks(str, Enum):
MIN_K_PLUS_PLUS = "min_k++" # Done
NEIGHBOR = "ne" # Done
GRADNORM = "gradnorm" # Done
RECALL = "recall"
# QUANTILE = "quantile" # Uncomment when tested implementation is available


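Since AllAttacks subclasses str, the new member compares equal to its literal value, which is what lets run.py (below) test membership against the plain strings in blackbox_attacks; a hedged sketch:

    from mimir.attacks.all_attacks import AllAttacks

    assert AllAttacks.RECALL == "recall"
    assert AllAttacks.RECALL in ["loss", "recall"]  # mirrors the config's blackbox_attacks list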
132 changes: 132 additions & 0 deletions mimir/attacks/recall.py
@@ -0,0 +1,132 @@
"""
ReCaLL Attack: https://github.com/ruoyuxie/recall/
"""
import torch
import numpy as np
from mimir.attacks.all_attacks import Attack
from mimir.models import Model
from mimir.config import ExperimentConfig

class ReCaLLAttack(Attack):
    # NOTE: this is a suboptimal implementation of the ReCaLL attack, due to
    # changes needed to integrate it alongside the other attacks.
    # For a better-performing version, please refer to: https://github.com/ruoyuxie/recall

def __init__(self, config: ExperimentConfig, target_model: Model):
        super().__init__(config, target_model, ref_model=None)
self.prefix = None

@torch.no_grad()
def _attack(self, document, probs, tokens = None, **kwargs):
        recall_dict: dict = kwargs.get("recall_dict", None)
        assert recall_dict is not None, "recall_dict must be provided for the ReCaLL attack"

        nonmember_prefix = recall_dict.get("prefix")
        num_shots = recall_dict.get("num_shots")
        avg_length = recall_dict.get("avg_length")

assert nonmember_prefix, "nonmember_prefix should not be None or empty"
assert num_shots, "num_shots should not be None or empty"
assert avg_length, "avg_length should not be None or empty"

        # Unconditional (average) log-likelihood of the document under the target model
        lls = self.target_model.get_ll(document, probs=probs, tokens=tokens)
        # Average log-likelihood of the document conditioned on the non-member prefix
        ll_nonmember = self.get_conditional_ll(nonmember_prefix=nonmember_prefix, text=document,
                                               num_shots=num_shots, avg_length=avg_length,
                                               tokens=tokens)
        # ReCaLL score: ratio of conditional to unconditional log-likelihood
        recall = ll_nonmember / lls

        assert not np.isnan(recall)
        return recall

def process_prefix(self, prefix, avg_length, total_shots):
model = self.target_model
tokenizer = self.target_model.tokenizer

if self.prefix is not None:
# We only need to process the prefix once, after that we can just return
return self.prefix

max_length = model.max_length
token_counts = [len(tokenizer.encode(shot)) for shot in prefix]

target_token_count = avg_length
total_tokens = sum(token_counts) + target_token_count
        if total_tokens <= max_length:
self.prefix = prefix
return self.prefix
# Determine the maximum number of shots that can fit within the max_length
max_shots = 0
cumulative_tokens = target_token_count
for count in token_counts:
if cumulative_tokens + count <= max_length:
max_shots += 1
cumulative_tokens += count
else:
break
# Truncate the prefix to include only the maximum number of shots
truncated_prefix = prefix[-max_shots:]
print(f"""\nToo many shots used. Initial ReCaLL number of shots was {total_shots}. Maximum number of shots is {max_shots}. Defaulting to maximum number of shots.""")
self.prefix = truncated_prefix
return self.prefix
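
    # Worked example (illustrative numbers, not from this PR): with
    # max_length=2048, avg_length=300, and shot token counts [900, 700, 600],
    # the loop above admits two shots (300+900=1200; +700=1900 <= 2048; adding
    # 600 would exceed), so max_shots=2 and the *last* two shots are kept.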

    def get_conditional_ll(self, nonmember_prefix, text, num_shots, avg_length, tokens=None):
        # Average log-likelihood per token of `text`, conditioned on the
        # (possibly truncated) non-member prefix; computed in max_length chunks.
        assert nonmember_prefix, "nonmember_prefix should not be None or empty"

model = self.target_model
tokenizer = self.target_model.tokenizer

if tokens is None:
target_encodings = tokenizer(text=text, return_tensors="pt")
else:
target_encodings = tokens

processed_prefix = self.process_prefix(nonmember_prefix, avg_length, total_shots=num_shots)
input_encodings = tokenizer(text="".join(processed_prefix), return_tensors="pt")

prefix_ids = input_encodings.input_ids.to(model.device)
text_ids = target_encodings.input_ids.to(model.device)

max_length = model.max_length

if prefix_ids.size(1) >= max_length:
raise ValueError("Prefix length exceeds or equals the model's maximum context window.")

labels = torch.cat((prefix_ids, text_ids), dim=1)
total_length = labels.size(1)

total_loss = 0
total_tokens = 0
with torch.no_grad():
for i in range(0, total_length, max_length):
begin_loc = i
end_loc = min(i + max_length, total_length)
trg_len = end_loc - begin_loc

input_ids = labels[:, begin_loc:end_loc].to(model.device)
target_ids = input_ids.clone()

                if begin_loc < prefix_ids.size(1):
                    # Mask any prefix tokens in this chunk so only the target text is scored
                    prefix_overlap = min(prefix_ids.size(1) - begin_loc, max_length)
                    target_ids[:, :prefix_overlap] = -100

if end_loc > total_length - text_ids.size(1):
target_overlap = min(end_loc - (total_length - text_ids.size(1)), max_length)
target_ids[:, -target_overlap:] = input_ids[:, -target_overlap:]

if torch.all(target_ids == -100):
continue

outputs = model.model(input_ids, labels=target_ids)
loss = outputs.loss
if torch.isnan(loss):
print(f"NaN detected in loss at iteration {i}. Non masked target_ids size is {(target_ids != -100).sum().item()}")
continue
non_masked_tokens = (target_ids != -100).sum().item()
total_loss += loss.item() * non_masked_tokens
total_tokens += non_masked_tokens

        average_loss = total_loss / total_tokens if total_tokens > 0 else 0
        # Negate the average loss to return the average log-likelihood per token
        return -average_loss
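
# For orientation, a hedged sketch of how this class is driven end-to-end
# (mirroring the run.py changes below; `target_model`, `nonmember_texts`, and
# the literal values are illustrative assumptions, not part of this diff):
#
#     attacker = ReCaLLAttack(config, target_model)
#     recall_dict = {
#         "prefix": nonmember_texts[:config.recall_config.num_shots],
#         "num_shots": config.recall_config.num_shots,
#         "avg_length": 128,  # mean tokenized length of the evaluated records
#     }
#     score = attacker.attack(document, probs=tk_probs, recall_dict=recall_dict)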



2 changes: 2 additions & 0 deletions mimir/attacks/utils.py
@@ -7,6 +7,7 @@
from mimir.attacks.min_k_plus_plus import MinKPlusPlusAttack
from mimir.attacks.neighborhood import NeighborhoodAttack
from mimir.attacks.gradnorm import GradNormAttack
from mimir.attacks.recall import ReCaLLAttack


# TODO Use decorators to link attack implementations with enum above
@@ -19,6 +20,7 @@ def get_attacker(attack: str):
AllAttacks.MIN_K_PLUS_PLUS: MinKPlusPlusAttack,
AllAttacks.NEIGHBOR: NeighborhoodAttack,
AllAttacks.GRADNORM: GradNormAttack,
        AllAttacks.RECALL: ReCaLLAttack,
}
attack_cls = mapping.get(attack, None)
if attack_cls is None:
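With the new mapping entry, the factory resolves the attack name directly; a minimal hedged usage sketch (names come from this diff, the surrounding setup is assumed to be in scope):

    from mimir.attacks.utils import get_attacker
    from mimir.attacks.all_attacks import AllAttacks

    attacker_cls = get_attacker(AllAttacks.RECALL)  # -> ReCaLLAttack
    attacker = attacker_cls(config, target_model)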
9 changes: 9 additions & 0 deletions mimir/config.py
@@ -59,6 +59,13 @@ def __post_init__(self):
if self.dump_cache and self.load_from_cache:
raise ValueError("Cannot dump and load cache at the same time")

@dataclass
class ReCaLLConfig(Serializable):
"""
Config for ReCaLL attack
"""
num_shots: Optional[int] = 1
"""Number of shots for ReCaLL Attacks"""

@dataclass
class EnvironmentConfig(Serializable):
@@ -194,6 +201,8 @@ class ExperimentConfig(Serializable):
"""Random seed"""
ref_config: Optional[ReferenceConfig] = None
"""Reference model config"""
recall_config: Optional[ReCaLLConfig] = None
"""ReCaLL attack config"""
neighborhood_config: Optional[NeighborhoodConfig] = None
"""Neighborhood attack config"""
env_config: Optional[EnvironmentConfig] = None
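Because ExperimentConfig is a Serializable dataclass, the "recall_config" block in configs/recall.json above deserializes straight into the new ReCaLLConfig; a hedged sketch of the equivalence (the constructor call is illustrative):

    from mimir.config import ReCaLLConfig

    cfg = ReCaLLConfig(num_shots=1)  # what {"recall_config": {"num_shots": 1}} resolves to
    assert cfg.num_shots == 1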
44 changes: 41 additions & 3 deletions run.py
Expand Up @@ -19,7 +19,8 @@
EnvironmentConfig,
NeighborhoodConfig,
ReferenceConfig,
OpenAIConfig
OpenAIConfig,
ReCaLLConfig
)
import mimir.data_utils as data_utils
import mimir.plot_utils as plot_utils
@@ -77,6 +78,7 @@ def get_mia_scores(
is_train: bool,
n_samples: int = None,
batch_size: int = 50,
**kwargs
):
# Fix randomness
fix_seed(config.random_seed)
@@ -100,6 +102,13 @@
n_perturbation: [] for n_perturbation in n_perturbation_list
}

    recall_config = config.recall_config
    if recall_config:
        # Assemble the per-run inputs for ReCaLL: the non-member prefix shots,
        # the shot count, and the mean tokenized length of the evaluated records
        nonmember_prefix = kwargs.get("nonmember_prefix", None)
        num_shots = recall_config.num_shots
        avg_length = int(np.mean([len(target_model.tokenizer.encode(ex)) for ex in data["records"]]))
        recall_dict = {"prefix": nonmember_prefix, "num_shots": num_shots, "avg_length": avg_length}

# For each batch of data
# TODO: Batch-size isn't really "batching" data - change later
for batch in tqdm(range(math.ceil(n_samples / batch_size)), desc=f"Computing criterion"):
@@ -149,7 +158,23 @@
if attack.startswith(AllAttacks.REFERENCE_BASED) or attack == AllAttacks.LOSS:
continue

if attack != AllAttacks.NEIGHBOR:
            if attack == AllAttacks.RECALL:
                score = attacker.attack(
                    substr,
                    probs=s_tk_probs,
                    detokenized_sample=(
                        detokenized_sample[i]
                        if config.pretokenized
                        else None
                    ),
                    loss=loss,
                    all_probs=s_all_probs,
                    recall_dict=recall_dict,
                )
                sample_information[attack].append(score)

elif attack != AllAttacks.NEIGHBOR:
score = attacker.attack(
substr,
probs=s_tk_probs,
@@ -162,6 +187,7 @@
all_probs=s_all_probs,
)
sample_information[attack].append(score)

else:
# For each 'number of neighbors'
for n_perturbation in n_perturbation_list:
@@ -416,6 +442,7 @@ def main(config: ExperimentConfig):
neigh_config: NeighborhoodConfig = config.neighborhood_config
ref_config: ReferenceConfig = config.ref_config
openai_config: OpenAIConfig = config.openai_config
recall_config: ReCaLLConfig = config.recall_config

if openai_config:
openAI_model = OpenAI_APIModel(config)
@@ -515,6 +542,15 @@
mask_model_tokenizer=mask_model.tokenizer if mask_model else None,
)

    # ReCaLL-specific: use the first num_shots non-member records as the shared non-member prefix
    if AllAttacks.RECALL in config.blackbox_attacks:
        assert recall_config, "Must provide a recall_config"
        num_shots = recall_config.num_shots
        nonmember_prefix = data_nonmember[:num_shots]
    else:
        nonmember_prefix = None

other_objs, other_nonmembers = None, None
if config.dataset_nonmember_other_sources is not None:
other_objs, other_nonmembers = [], []
@@ -628,7 +664,8 @@
ref_models=ref_models,
config=config,
is_train=True,
n_samples=n_samples
n_samples=n_samples,
        nonmember_prefix=nonmember_prefix
)
# Collect scores for non-members
nonmember_preds, nonmember_samples = get_mia_scores(
Expand All @@ -640,6 +677,7 @@ def main(config: ExperimentConfig):
config=config,
is_train=False,
n_samples=n_samples,
        nonmember_prefix=nonmember_prefix
)
blackbox_outputs = compute_metrics_from_scores(
member_preds,
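Taken together, the ReCaLL wiring in run.py is: main() slices the first recall_config.num_shots non-member records into nonmember_prefix; both get_mia_scores calls forward that prefix as a keyword argument; get_mia_scores folds it, the shot count, and the mean tokenized record length into recall_dict; and each per-sample attack call hands recall_dict to ReCaLLAttack. One observation (not a change in this diff): the prefix is drawn from the same non-member pool that is subsequently scored, so the first num_shots non-member records double as attack inputs.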
