-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconstraints_code.py
51 lines (43 loc) · 1.62 KB
/
constraints_code.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import glob
import librosa
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as transforms
from torch.utils.data import Dataset, DataLoader, ConcatDataset, random_split
from librosa import feature
import torch.nn as nn
from torchsummary import summary
from tqdm import tqdm
import torchvision.transforms as vtransforms
import dcase_util
from sed_eval.audio_tag import AudioTaggingMetrics as get_metrics_sedeval
from collections import OrderedDict
from torchlibrosa.augmentation import SpecAugmentation
import torch.nn.functional as Ft
from implication_code import Implication
class Constraints(nn.Module):
    """Lagrangian constraint-penalty module built from a label ontology.

    Wraps an ``Implication`` constraint set derived from a parent->children
    ontology and holds one learnable Lagrange multiplier per constraint
    (``lambda_param``). ``forward`` returns the weighted penalty term to be
    added to the training loss (dual-ascent style constrained training).
    """

    def __init__(self, ontology=None):
        """
        Args:
            ontology: dict mapping each parent label to the list of its child
                labels, e.g. ``{'parent': ['child1', ..., 'childn']}``.
                Defaults to the original placeholder ontology when omitted.
        """
        super(Constraints, self).__init__()
        # FIX: the original used a mutable dict literal as the default
        # argument (shared across calls); a None sentinel reproduces the
        # same default value without that hazard.
        if ontology is None:
            ontology = {'parent': ['child1', '...childn']}
        self.ontology = ontology
        self.constraints = Implication(ontology=self.ontology)
        # One Lagrange multiplier per constraint, initialised to zero.
        self.lambda_param = nn.Parameter(torch.zeros(self.constraints.num_constraints()))

    def get_optim_params(self, ddlr=0.06, use_wt_as_lr_factor=True):
        """Build an optimizer param-group dict for the Lagrange multipliers.

        Args:
            ddlr: base dual-descent learning rate for ``lambda_param``.
            use_wt_as_lr_factor: if True, fold the constraint weight into the
                learning rate and reset ``self.constraints.weight`` to 1 so
                that ``forward`` does not apply the weight a second time.

        Returns:
            dict with ``'params'`` and ``'lr'`` keys suitable for a torch
            optimizer param group; both values are None when ``lambda_param``
            has ``requires_grad=False``.
        """
        params = {'params': None, 'lr': None}
        factor = 1
        if use_wt_as_lr_factor:
            factor = self.constraints.weight
            # BUG FIX: the original line was ``self.constraint_dict[k].weight = 1``,
            # which references a nonexistent attribute ``constraint_dict`` and an
            # undefined name ``k`` -> NameError at runtime. The intent (consistent
            # with ``forward`` multiplying by ``self.constraints.weight``) is to
            # absorb the weight into the lr and neutralise it on the constraint
            # object so the penalty is not scaled twice.
            self.constraints.weight = 1
        if self.lambda_param.requires_grad:
            params['params'] = self.lambda_param
            params['lr'] = ddlr * factor
        return params

    def forward(self, scores):
        """Compute the Lagrangian penalty for the given model scores.

        Args:
            scores: model outputs forwarded to ``Implication.get_penalty``;
                their expected shape/semantics are defined by ``Implication``
                (not visible in this file).

        Returns:
            (loss, h_k): scalar weighted Lagrangian penalty
            ``weight * sum_k(lambda_k * h_k)`` and the raw per-constraint
            penalty tensor ``h_k``.
        """
        h_k = self.constraints.get_penalty(scores)
        penalty = (self.lambda_param * h_k).sum()
        loss = self.constraints.weight * penalty
        return loss, h_k