-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathBayes.jl
131 lines (108 loc) · 2.54 KB
/
Bayes.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
using PyPlot
"""
    muvar(M)

Fit the per-class Gaussian parameters used by `Classifier`.

`M` is a matrix whose column 1 is the price, column 2 the volume and column 3
the class label. Rows whose label equals `1` form the "buy" class
(superscript ʸ); every other row forms the "don't buy" class (superscript ⁿ).

Returns the 8-element vector

    [μₚʸ, μₚⁿ, μᵥʸ, μᵥⁿ, σ²ₚʸ, σ²ₚⁿ, σ²ᵥʸ, σ²ᵥⁿ]

holding the class means and unbiased sample variances (denominator `n - 1`)
of price (ₚ) and volume (ᵥ).

Throws `ArgumentError` if either class has fewer than two rows, because the
sample variance is undefined in that case.
"""
function muvar(M)
    isbuy = M[:, 3] .== 1
    nbuy = count(isbuy)
    nnot = length(isbuy) - nbuy
    # With 0 rows the mean is 0/0 and with 1 row the variance divides by 0;
    # either way the result would be NaN and silently poison Classifier.
    if nbuy < 2 || nnot < 2
        throw(ArgumentError("muvar needs at least two rows per class to estimate " *
                            "a variance; got $nbuy buy rows and $nnot other rows"))
    end

    # Two-pass estimate: the mean first, then squared deviations from it.
    classmean(v) = sum(v) / length(v)
    classvar(v, μ) = sum(x -> (x - μ)^2, v) / (length(v) - 1)

    pʸ, pⁿ = M[isbuy, 1], M[.!isbuy, 1]
    vʸ, vⁿ = M[isbuy, 2], M[.!isbuy, 2]

    μₚʸ, μₚⁿ = classmean(pʸ), classmean(pⁿ)
    μᵥʸ, μᵥⁿ = classmean(vʸ), classmean(vⁿ)
    σ²ₚʸ, σ²ₚⁿ = classvar(pʸ, μₚʸ), classvar(pⁿ, μₚⁿ)
    σ²ᵥʸ, σ²ᵥⁿ = classvar(vʸ, μᵥʸ), classvar(vⁿ, μᵥⁿ)
    return [μₚʸ, μₚⁿ, μᵥʸ, μᵥⁿ, σ²ₚʸ, σ²ₚⁿ, σ²ᵥʸ, σ²ᵥⁿ]
end
"""
    Gaussian(x, μ, σ²)

Evaluate the normal probability density with mean `μ` and variance `σ²` at `x`.
"""
Gaussian(x, μ, σ²) = exp(-abs2(x - μ) / (2 * σ²)) / sqrt(2 * π * σ²)
"""
    Classifier(price, volume, p)

Classify a `(price, volume)` observation with a Gaussian naive Bayes rule.

`p` is the parameter vector produced by `muvar`:
`[μₚʸ, μₚⁿ, μᵥʸ, μᵥⁿ, σ²ₚʸ, σ²ₚⁿ, σ²ᵥʸ, σ²ᵥⁿ]`. Within each class, price and
volume are treated as independent normal variables. Both classes get equal
prior weight: no prior term enters the comparison.

Returns `1` ("buy") when the buy-class likelihood is strictly larger, and `0`
otherwise (ties fall to `0`).
"""
function Classifier(price, volume, p)
    μₚʸ, μₚⁿ, μᵥʸ, μᵥⁿ, σ²ₚʸ, σ²ₚⁿ, σ²ᵥʸ, σ²ᵥⁿ = p
    # Compare log-likelihoods rather than products of densities. Far from both
    # class means each density underflows to 0.0, which would make the two
    # products equal and silently force the answer to 0.
    loggauss(x, μ, σ²) = -(x - μ)^2 / (2σ²) - log(2π * σ²) / 2
    logp_buy = loggauss(price, μₚʸ, σ²ₚʸ) + loggauss(volume, μᵥʸ, σ²ᵥʸ)
    logp_not = loggauss(price, μₚⁿ, σ²ₚⁿ) + loggauss(volume, μᵥⁿ, σ²ᵥⁿ)
    return logp_buy > logp_not ? 1 : 0
end
# Training data: one row per observation, columns (price, volume, label) with
# label 1 = "buy" and 0 = "don't buy".
M = [10.3 100 0;
     9.5 50 1;
     11.5 90 0;
     5.5 30 1;
     7.5 120 1;
     12.2 40 0;
     7.1 80 1;
     10.5 65 0]

# Split the observations by class for the scatter plot:
# (x1, x2) = price/volume of the "buy" rows, (y1, y2) = the remaining rows.
isbuy = M[:, 3] .== 1
x1, x2 = M[isbuy, 1], M[isbuy, 2]
y1, y2 = M[.!isbuy, 1], M[.!isbuy, 2]

par = muvar(M)

# Decision regions: evaluate the classifier on a 100×100 grid that slightly
# pads the observed price (5.5–12.2) and volume (30–120) ranges.
p_sp = LinRange(5.2, 12.5, 100)
v_sp = LinRange(25.6, 124.4, 100)
# Rows index volume (y axis) and columns index price (x axis), i.e. the
# Z[y, x] layout contourf expects.
K = [Classifier(p, v, par) for v in v_sp, p in p_sp]

# plotting
contourf(p_sp, v_sp, K, alpha=0.5)
plot(x1, x2, "s", color="blue")
plot(y1, y2, "o", color="orange")
xlabel("Price")
ylabel("Volume")
show()