-
Notifications
You must be signed in to change notification settings - Fork 80
/
Copy pathdl_clustering.yaml
55 lines (48 loc) · 1.98 KB
/
dl_clustering.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
---
name: "dl_clustering"
feature_definition:
  use_norm_features: true
data_loader:
  batch_size: 512
  cluster_batch_size: 256
  # Define how each data batch is sampled from multiple datasets and/or people
  # option: "across_dataset", "within_dataset", "across_person", "within_person"
  generate_by: "across_dataset"
  mixup: "across"  # option: null, "across", "within"
  mixup_alpha: 0.2
model_params:
  arch: "1dCNN"  # model architecture
  # input dimension
  # 2: a multi-channel time series data (for models like 1dCNN)
  # 3: a single-channel image data (for models like 2dCNN)
  input_dim: 2
  conv_shapes: [8, 8, 8]  # a list of convolutional layers shapes for the siamese network
  embedding_size: 16  # vector length of the feature embedding network output
  flag_y_vector: false  # false: use sparse label for negative mining during contrastive learning
  triplet_loss_margin: 0.2  # triplet loss margin
  clustering:  # parameter for the unsupervised clustering model
    conv_shapes: [64, 32]
    embedding_size: 10
    n_clusters: 60
    loss_weight: 0.1
    tdistribution_shape_param: 1
    error_tolerance: 0.001
training_params:
  optimizer: "Adam"  # option: SGD or Adam
  learning_rate: 0.001
  epochs: 1000
  steps_per_epoch: 10  # number of batches per epoch
  cos_annealing_step: 100
  cos_annealing_decay: 0.95
  best_epoch_strategy: "direct"  # option: direct or on_test
  verbose: 0
  clustering_verbose: 2
  skip_training: false
  # only if train models for particular clusters, add the clusters idx.
  # This can be used on the cluster server (e.g. SLURM-based) to accelerate the training
  focus_cluster_idx: []
# Parameters used based on the embeddings directions
metrics_params:
  n_neighbor: 3  # how many top close/far datapoints used for classification
  positive_weight_ratio: 1.5
  direction: "close"  # option: both, close, far. Which side of datapoints to be used for classification
  flag_distance_weight: false  # whether use distance as a weight (closer/further datapoints have higher weights)