-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathshow_MLR_graph.py
34 lines (28 loc) · 1.08 KB
/
show_MLR_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import numpy as np
from ml import MLR, dataset
import matplotlib.pyplot as plt
# Select a font with Hangul glyphs so the Korean legend/axis text below renders.
plt.rc("font", family="Malgun Gothic")

# --- Load data sets ---------------------------------------------------------
TRAIN = 1800  # samples [0, TRAIN) are passed to MLR.train
TEST = 2350   # samples [TRAIN, TEST) are only predicted, never trained on (max)

# Set the dataset module's MIN/MAX before any series is requested.
# NOTE(review): presumably these bound the rows the loaders return -- confirm
# against ml.dataset.
dataset.MIN = 0
dataset.MAX = 2350

date = dataset.DateTerm(1).values
d1 = dataset.Scale(dataset.KOSPI200FIDX())
d2 = dataset.Scale(dataset.NASDAQ100FIDX())
d3 = dataset.Scale(dataset.SAMSUNGELECSTKP())

# --- Train / test split -----------------------------------------------------
train_span = slice(0, TRAIN)
test_span = slice(TRAIN, TEST)
train_d1, train_d2, train_d3 = (series[train_span] for series in (d1, d2, d3))
test_d1, test_d2, test_d3 = (series[test_span] for series in (d1, d2, d3))

# --- Fit the model ----------------------------------------------------------
# MLR.train receives the three training series and returns two coefficients,
# an intercept and an `mse` sequence; only its last entry is reported.
# NOTE(review): presumably d3 is the regression target and d1/d2 the
# predictors, given how a1/a2/b are applied below -- confirm in ml.MLR.
a1, a2, b, mse = MLR.train(train_d1, train_d2, train_d3)
print(mse[-1])


def _predict(x1, x2):
    """Evaluate the fitted linear model a1*x1 + a2*x2 + b."""
    return a1 * x1 + a2 * x2 + b


train_pred = _predict(train_d1, train_d2)
test_pred = _predict(test_d1, test_d2)

# --- Plot -------------------------------------------------------------------
# The three scaled input series over the whole date range, in legend order.
for series in (d1, d2, d3):
    plt.plot(date, series)

# Model output: dotted over the training window, dashed over the held-out one.
plt.plot(date[train_span], train_pred, linestyle=":")
plt.plot(date[test_span], test_pred, linestyle="--")

# Labels are matched to the lines in the order they were plotted above.
legend_labels = [
    "KOSPI200FIDX",
    "NASDAQ100FIDX",
    "SAMSUNGELECSTKP",
    "학습 주가 예측",
    "비학습 주가 예측",
]
plt.legend(legend_labels)

plt.ylabel("정규화된 수치")

# The x axis carries no ticks; the first and last dates are shown in the
# axis label instead.
first_date, last_date = date[0], date[-1]
plt.xlabel("{0}--> 날짜 진행 -->{1}".format(first_date, last_date))
plt.xticks([], color="w")

plt.show()