-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmain_ood.py
43 lines (37 loc) · 1.28 KB
/
main_ood.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from ood.criteria.aupr import Aupr
from ood.criteria.auroc import Auroc
from ood.criteria.fpr import Fpr
from ood.datasets.cifar100_dataset import Cifar100Dataset
from ood.datasets.cifar10_dataset import Cifar10Dataset
from ood.evaluator import Evaluator
from ood.ood_methods.mlv import Mlv
from ood.ood_methods.mlv_oe import MlvOE
from ood.ood_methods.msp import Msp
from ood.ood_methods.msp_oe import MspOE
import torch.backends.cudnn as cudnn
# Turn on cuDNN benchmark mode for the whole run.
cudnn.benchmark = True

# Shared dataset wrappers and scoring criteria; the same objects are reused by
# every evaluator built below.
cifar10 = Cifar10Dataset()
cifar100 = Cifar100Dataset()
criteria = [Auroc(), Aupr(), Fpr()]

# One evaluator per (dataset pairing, method) combination, in this order:
#   Cifar10 paired with Cifar100:  Msp, MspOE, Mlv, MlvOE
#   Cifar100 paired with Cifar10:  Msp, MspOE, Mlv, MlvOE
# Each method is constructed on the first dataset of the pairing, and the
# evaluator receives both datasets (presumably in-distribution first and
# out-of-distribution second -- confirm against Evaluator).
benchmarks = [
    Evaluator(method_cls(primary), primary, other, criteria)
    for primary, other in ((cifar10, cifar100), (cifar100, cifar10))
    for method_cls in (Msp, MspOE, Mlv, MlvOE)
]
def main():
    """Run every evaluator in ``benchmarks``, in list order.

    For each entry this prints a numbered header line, prints the evaluator
    itself, and then calls its ``evaluate()`` method. Whatever ``evaluate()``
    returns is not used here.
    """
    for position, bench in enumerate(benchmarks):
        print('\nBenchmark', position, ':')
        print(bench)
        bench.evaluate()


# Run the benchmarks only when this file is executed as a script.
if __name__ == '__main__':
    main()