-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathSelection.java
154 lines (136 loc) · 5.85 KB
/
Selection.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
package zad2;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import zad1.Factory;
import zad1.InstanceEnum;
import java.util.*;
/**
 * Selection operators (roulette wheel and tournament) over a generation of {@code int[]}
 * individuals, each scored by {@code factory.evaluateGrid}. Lower scores are treated as better:
 * tournaments keep the minimum score and the roulette wheel gives lower scores larger slices.
 */
public class Selection {
    private static final Random random = new Random();
    private final Factory factory;

    /**
     * Creates a selection operator backed by the given scorer.
     *
     * @param factory scorer used for every evaluation; must not be {@code null}
     * @throws NullPointerException if {@code factory} is {@code null}
     */
    public Selection(Factory factory) {
        this.factory = Objects.requireNonNull(factory, "factory");
    }

    /**
     * Converts each individual's score into a positive weight in which lower scores get larger
     * weights. Each weight is {@code max - score + 1}, so the worst-scoring individual gets 1.
     *
     * @param generation individuals to score
     * @return one weight per individual, in generation order; empty for an empty generation
     */
    public double[] reverseEvaluation(ArrayList<int[]> generation) {
        double[] weights = Arrays.stream(getEvaluationForEach(generation)).asDoubleStream().toArray();
        if (weights.length == 0) {
            return weights; // max() would be empty; nothing to invert
        }
        double maxVal = Arrays.stream(weights).max().getAsDouble();
        for (int i = 0; i < weights.length; i++) {
            weights[i] = maxVal - weights[i] + 1;
        }
        return weights;
    }

    /**
     * Builds a roulette wheel as a cumulative distribution over {@link #reverseEvaluation}.
     * Each key is the running sum of normalised weights (cast to {@code float}). Its value is the
     * index of the individual whose slice ends at that key. Draw with
     * {@code ceilingEntry(r)} for {@code r} in [0, 1).
     *
     * <p>If two consecutive running sums round to the same {@code float}, the later index
     * replaces the earlier one in the map.
     *
     * @param generation individuals to build the wheel for
     * @return the cumulative distribution; empty for an empty generation
     */
    public TreeMap<Float, Integer> getDistributionTreeMap(ArrayList<int[]> generation) {
        TreeMap<Float, Integer> treeMap = new TreeMap<>();
        double[] weights = reverseEvaluation(generation);
        if (weights.length == 0) {
            return treeMap;
        }
        double sumAll = Arrays.stream(weights).sum(); // > 0: every weight is at least 1
        double cumulative = 0.0;
        for (int i = 0; i < weights.length; i++) {
            cumulative += weights[i] / sumAll;
            treeMap.put((float) cumulative, i);
        }
        return treeMap;
    }

    /**
     * Scores every individual with {@code factory.evaluateGrid}.
     *
     * @param generation individuals to score
     * @return scores in generation order
     */
    public int[] getEvaluationForEach(ArrayList<int[]> generation) {
        int[] evaluationForEach = new int[generation.size()];
        for (int i = 0; i < evaluationForEach.length; i++) {
            evaluationForEach[i] = factory.evaluateGrid(generation.get(i));
        }
        return evaluationForEach;
    }

    /**
     * Draws one index from a wheel built by {@link #getDistributionTreeMap}.
     * Falls back to the last entry when float rounding left the final key below the draw.
     *
     * @throws IllegalArgumentException if the wheel is empty
     */
    private static int drawIndex(TreeMap<Float, Integer> treeMap) {
        if (treeMap.isEmpty()) {
            throw new IllegalArgumentException("distribution must not be empty");
        }
        float randFloat = random.nextFloat();
        Map.Entry<Float, Integer> entry = treeMap.ceilingEntry(randFloat);
        return (entry != null ? entry : treeMap.lastEntry()).getValue();
    }

    /**
     * Draws distinct indices from the wheel until {@code numberOfChosenInts} are collected.
     * The count is capped at the number of wheel entries. Each entry holds a distinct index,
     * so asking for more would never terminate. Collecting nearly every entry can take many
     * draws when some slices are thin.
     */
    private static HashSet<Integer> getRandomIntsDistributionTreeMap(TreeMap<Float, Integer> treeMap, int numberOfChosenInts) {
        int target = Math.min(numberOfChosenInts, treeMap.size());
        HashSet<Integer> selectedInts = new HashSet<>();
        while (selectedInts.size() < target) {
            selectedInts.add(drawIndex(treeMap));
        }
        return selectedInts;
    }

    /**
     * Draws distinct uniform indices in [0, maxIntRange) until {@code numberOfChosenInts} are
     * collected. The count is capped at {@code maxIntRange}, because no more distinct values
     * exist and the loop would otherwise never finish.
     */
    private static HashSet<Integer> getRandomInts(int maxIntRange, int numberOfChosenInts) {
        int target = Math.min(numberOfChosenInts, maxIntRange);
        HashSet<Integer> selectedInts = new HashSet<>();
        while (selectedInts.size() < target) {
            selectedInts.add(random.nextInt(maxIntRange));
        }
        return selectedInts;
    }

    /**
     * Returns the individual with the lowest score among the chosen indexes.
     * Ties keep whichever candidate the set iterates first. Returns {@code null} for an empty set.
     */
    private int[] getBestForIndexes(ArrayList<int[]> generation, Set<Integer> chosenIndexes) {
        int[] best = null;
        int bestEval = 0;
        for (int index : chosenIndexes) {
            int[] candidate = generation.get(index);
            int candidateEval = factory.evaluateGrid(candidate);
            if (best == null || candidateEval < bestEval) {
                best = candidate;
                bestEval = candidateEval;
            }
        }
        return best;
    }

    /** Rejects inputs for which no individual can be selected. */
    private static void requireSelectable(ArrayList<int[]> generation, int n) {
        if (generation.isEmpty()) {
            throw new IllegalArgumentException("generation must not be empty");
        }
        if (n < 1) {
            throw new IllegalArgumentException("N must be at least 1, got " + n);
        }
    }

    /** Uniform tournament: best of N distinct random individuals (N capped at generation size). */
    private int[] selectionTournament(ArrayList<int[]> generation, int N) {
        requireSelectable(generation, N);
        HashSet<Integer> chosenInts = getRandomInts(generation.size(), N);
        return getBestForIndexes(generation, chosenInts);
    }

    /** Tournament over N distinct roulette-drawn individuals (N capped at wheel size). */
    private int[] selectionRouletteAndTournament(ArrayList<int[]> generation, int N) {
        requireSelectable(generation, N);
        TreeMap<Float, Integer> distributionTreeMap = getDistributionTreeMap(generation);
        HashSet<Integer> chosenInts = getRandomIntsDistributionTreeMap(distributionTreeMap, N);
        return getBestForIndexes(generation, chosenInts);
    }

    /** Single roulette-wheel draw. */
    private int[] selectionRoulette(ArrayList<int[]> generation) {
        requireSelectable(generation, 1);
        TreeMap<Float, Integer> distributionTreeMap = getDistributionTreeMap(generation);
        return generation.get(drawIndex(distributionTreeMap));
    }

    /**
     * Selects one individual from the generation.
     *
     * @param selectionEnum selection strategy; only ROULETTE and TOURNAMENT are handled
     * @param generation individuals to choose from; must not be empty
     * @param N tournament size for TOURNAMENT (capped at generation size); ignored by ROULETTE
     * @return the selected individual (the same array instance stored in {@code generation})
     * @throws IllegalArgumentException for an unhandled strategy, an empty generation,
     *     or a tournament size below 1
     */
    public int[] selection(SelectionEnum selectionEnum, ArrayList<int[]> generation, int N) {
        if (selectionEnum == SelectionEnum.ROULETTE) {
            // selectionRouletteAndTournament(generation, N) is an alternative that runs a
            // tournament over N distinct roulette-drawn candidates instead of a single draw.
            return selectionRoulette(generation);
        }
        if (selectionEnum == SelectionEnum.TOURNAMENT) {
            return selectionTournament(generation, N);
        }
        throw new IllegalArgumentException("zad2.SelectionEnum: " + selectionEnum + " is incorrect.");
    }
}
/**
 * Tests for {@link Selection}, built on a random generation from the HARD instance.
 * Requires the instance data folder to be present.
 */
class SelectionTest {
    // Defaults to the original author's path; override with -Dselection.dataFolder=<path>.
    static String folderPath = System.getProperty("selection.dataFolder", "F:\\sztuczna_inteligencja\\flo_dane_v1.2");
    static Factory factory;
    static Selection selection;
    static ArrayList<int[]> generation;

    @BeforeAll
    public static void beforeAll() {
        factory = new Factory(InstanceEnum.HARD, folderPath); // 5x6
        selection = new Selection(factory);
        generation = factory.getRandomGeneration(1000);
    }

    /** Fails the test with {@code message} unless {@code condition} holds (no extra test-library dependency). */
    private static void check(boolean condition, String message) {
        if (!condition) {
            throw new AssertionError(message);
        }
    }

    /**
     * The roulette wheel must be a proper cumulative distribution. Keys lie in (0, 1] and end
     * at about 1. Every value is a distinct, valid generation index.
     */
    @Test
    public void testDistribution() {
        var tree = selection.getDistributionTreeMap(generation);
        check(!tree.isEmpty(), "distribution must not be empty for a non-empty generation");

        for (float key : tree.keySet()) { // TreeMap iterates keys in ascending order
            check(key > 0f && key <= 1.0f + 1e-6f, "cumulative key out of (0, 1]: " + key);
        }
        check(Math.abs(tree.lastKey() - 1.0f) < 1e-4f, "last cumulative key should be ~1, got " + tree.lastKey());

        for (int index : tree.values()) {
            check(index >= 0 && index < generation.size(), "index out of range: " + index);
        }
        check(new HashSet<>(tree.values()).size() == tree.size(), "each index must appear at most once");
    }

//    @Test
//    public void testTournament() {
//        for (var N : new int[]{10, 50, 100, 500, 1000}){
//            var selected = selection.selection(SelectionEnum.TOURNAMENT, generation, N);
//            System.out.println(factory.evaluateGrid(selected));
////            System.out.println(Arrays.toString(selected));
//        }
//    }
//
//    @Test
//    public void testRoulette() {
//        for (var N : new int[]{10, 50, 100, 500, 1000}){
//            var selected = selection.selection(SelectionEnum.ROULETTE, generation, N);
//            System.out.println(factory.evaluateGrid(selected));
////            System.out.println(Arrays.toString(selected));
//        }
//    }
}