-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathTrainBR.java
116 lines (75 loc) · 2.82 KB
/
TrainBR.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.PrintWriter;
import java.util.Iterator;
import meka.classifiers.multilabel.BR;
import weka.classifiers.functions.LibLINEAR;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
/**
* @author Felipe Bravo
* usage: java -cp CLASSPATH TrainBR trainFile.arff testFile.arff testOriginal.csv predictions.csv
* (CLASSPATH must contain this class plus the Meka, Weka and LibLINEAR wrapper jars)
*/
public class TrainBR {

    /**
     * Number of label attributes in the ARFF files. The value is used both as
     * the class index handed to Weka/Meka and as the number of prediction
     * columns written per output row.
     * NOTE(review): Meka reads the class index as "number of leading label
     * attributes" - confirm against the Meka data-format documentation.
     */
    private static final int NUM_LABELS = 11;

    /** Options passed verbatim to the LibLINEAR base classifier. */
    private static final String LIBLINEAR_OPTIONS =
            "-S 1 -C 1.0 -E 0.001 -B 1.0 -L 0.1 -I 1000";

    /** Message reported when the command line is incomplete. */
    private static final String EXPECTED_ARGS =
            "expected arguments: trainFile.arff testFile.arff testOriginal.csv predictions.csv";

    /**
     * Trains a binary-relevance model on the training ARFF file, predicts the
     * labels of every instance of the test ARFF file and writes one
     * tab-separated row per instance to the output file.
     *
     * <p>The output starts with the header line of the original test file.
     * Each following row holds the first two tab-separated columns of the
     * matching line of the original test file, followed by the
     * {@link #NUM_LABELS} predictions truncated to integers.
     *
     * @param args {@code args[0]} training ARFF, {@code args[1]} test ARFF,
     *             {@code args[2]} original tab-separated test file (with a
     *             header line), {@code args[3]} output file
     * @throws IllegalArgumentException if fewer than four arguments are given
     * @throws IllegalStateException if the original test file has no header,
     *             has fewer data rows than the test ARFF has instances, has a
     *             row with fewer than two columns, or if writing the output fails
     * @throws Exception on any I/O, parsing, training or prediction failure
     */
    public static void main(String[] args) throws Exception {
        if (args == null || args.length < 4) {
            throw new IllegalArgumentException(EXPECTED_ARGS);
        }
        String trainFile = args[0];
        String testFile = args[1];
        String testOriginal = args[2];
        String output = args[3];

        Instances train = loadArff(trainFile);
        Instances test = loadArff(testFile);

        // creates multi-label Meka model with a LibLINEAR base classifier
        BR model = new BR();
        LibLINEAR baseClassifier = new LibLINEAR();
        baseClassifier.setOptions(Utils.splitOptions(LIBLINEAR_OPTIONS));
        model.setClassifier(baseClassifier);

        // trains model
        model.buildClassifier(train);

        // try-with-resources closes both files even when prediction fails midway
        try (BufferedReader original = new BufferedReader(new FileReader(testOriginal));
             PrintWriter out = new PrintWriter(output)) {

            // copies header
            String header = original.readLine();
            if (header == null) {
                throw new IllegalStateException(
                        "original test file is empty (no header line): " + testOriginal);
            }
            out.println(header);

            // the i-th test instance is paired with the i-th data line of the original file
            int row = 0;
            for (Instance testInst : test) {
                row++;
                String line = original.readLine();
                if (line == null) {
                    throw new IllegalStateException(
                            "original test file " + testOriginal + " has only " + (row - 1)
                            + " data rows but the test ARFF has " + test.numInstances()
                            + " instances");
                }
                String[] parts = line.split("\t");
                if (parts.length < 2) {
                    throw new IllegalStateException(
                            "data row " + row + " of " + testOriginal
                            + " has fewer than two tab-separated columns");
                }
                double[] predictions = model.distributionForInstance(testInst);
                out.println(formatRow(parts[0], parts[1], predictions));
            }

            // PrintWriter never throws on write errors; it only records them
            if (out.checkError()) {
                throw new IllegalStateException("failed writing predictions to " + output);
            }
        }
    }

    /**
     * Reads an ARFF file and sets its class index to {@link #NUM_LABELS}.
     *
     * @param path path of the ARFF file
     * @return the loaded data set
     * @throws Exception if the file cannot be read or parsed
     */
    private static Instances loadArff(String path) throws Exception {
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            Instances data = new Instances(reader);
            data.setClassIndex(NUM_LABELS);
            return data;
        }
    }

    /**
     * Builds one output row: the two leading columns followed by the first
     * {@link #NUM_LABELS} predictions, all separated by tabs.
     *
     * @param first first column copied from the original test file
     * @param second second column copied from the original test file
     * @param predictions per-label outputs of the model; each value is
     *                    truncated toward zero by the {@code int} cast
     * @return the tab-separated row without a line terminator
     */
    private static String formatRow(String first, String second, double[] predictions) {
        StringBuilder row = new StringBuilder();
        row.append(first).append('\t').append(second);
        for (int i = 0; i < NUM_LABELS; i++) {
            row.append('\t').append((int) predictions[i]);
        }
        return row.toString();
    }
}