Merge pull request #4 from bigdata-ustc/DKT+
[FEATURE] Add DKT+
Showing 26 changed files with 1,269 additions and 3 deletions.
@@ -109,4 +109,5 @@ venv.bak/
# Pyre type checker
.pyre/

# User Definition
data/
@@ -1,3 +1,7 @@
v0.0.5:
  * add DKT+
  * add some util functions

v0.0.4:
  * fix potential ModuleNotFoundError
@@ -0,0 +1,119 @@
# coding: utf-8
# 2021/5/25 @ tongshiwei

import logging
import torch
from EduKTM import KTM
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
from EduKTM.utils import sequence_mask, SLMLoss, tensor2list, pick
from sklearn.metrics import roc_auc_score, accuracy_score
import numpy as np


class DKTNet(nn.Module):
    def __init__(self, ku_num, hidden_num, add_embedding_layer=False, embedding_dim=None, dropout=0.0, **kwargs):
        super(DKTNet, self).__init__()
        self.ku_num = ku_num
        self.hidden_dim = hidden_num
        self.output_dim = ku_num
        if add_embedding_layer is True:
            embedding_dim = self.hidden_dim if embedding_dim is None else embedding_dim
            self.embeddings = nn.Sequential(
                nn.Embedding(ku_num * 2, embedding_dim),
                nn.Dropout(kwargs.get("embedding_dropout", 0.2))
            )
            rnn_input_dim = embedding_dim
        else:
            self.embeddings = lambda x: F.one_hot(x, num_classes=self.output_dim * 2).float()
            rnn_input_dim = ku_num * 2

        self.rnn = nn.RNN(rnn_input_dim, hidden_num, 1, batch_first=True, nonlinearity='tanh')
        self.fc = nn.Linear(self.hidden_dim, self.output_dim)
        self.dropout = nn.Dropout(dropout)
        self.sig = nn.Sigmoid()

    def forward(self, responses, mask=None, begin_state=None):
        responses = self.embeddings(responses)
        # begin_state (None for a zero initial hidden state) is forwarded to the RNN
        output, hn = self.rnn(responses, begin_state)
        output = self.sig(self.fc(self.dropout(output)))
        if mask is not None:
            output = sequence_mask(output, mask)
        return output, hn


class DKTPlus(KTM):
    def __init__(self, ku_num, hidden_num, net_params: dict = None, loss_params=None):
        super(DKTPlus, self).__init__()
        self.dkt_net = DKTNet(
            ku_num,
            hidden_num,
            **(net_params if net_params is not None else {})
        )
        self.loss_params = loss_params if loss_params is not None else {}

    def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
        loss_function = SLMLoss(**self.loss_params)

        trainer = torch.optim.Adam(self.dkt_net.parameters(), lr)

        for e in range(epoch):
            losses = []
            for (data, data_mask, label, pick_index, label_mask) in tqdm(train_data, "Epoch %s" % e):
                # convert to device
                data: torch.Tensor = data.to(device)
                data_mask: torch.Tensor = data_mask.to(device)
                label: torch.Tensor = label.to(device)
                pick_index: torch.Tensor = pick_index.to(device)
                label_mask: torch.Tensor = label_mask.to(device)

                # real training
                predicted_response, _ = self.dkt_net(data, data_mask)
                loss = loss_function(predicted_response, pick_index, label, label_mask)

                # back propagation
                trainer.zero_grad()
                loss.backward()
                trainer.step()

                losses.append(loss.mean().item())
print("[Epoch %d] SLMoss: %.6f" % (e, float(np.mean(losses)))) | ||

            if test_data is not None:
                auc, accuracy = self.eval(test_data)
                print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy))

    def eval(self, test_data, device="cpu") -> tuple:
        self.dkt_net.eval()
        y_true = []
        y_pred = []

        for (data, data_mask, label, pick_index, label_mask) in tqdm(test_data, "evaluating"):
            # convert to device
            data: torch.Tensor = data.to(device)
            data_mask: torch.Tensor = data_mask.to(device)
            label: torch.Tensor = label.to(device)
            pick_index: torch.Tensor = pick_index.to(device)
            label_mask: torch.Tensor = label_mask.to(device)

            # real evaluating
            output, _ = self.dkt_net(data, data_mask)
            output = output[:, :-1]
            output = pick(output, pick_index.to(output.device))
            pred = tensor2list(output)
            label = tensor2list(label)
            for i, length in enumerate(label_mask.numpy().tolist()):
                length = int(length)
                y_true.extend(label[i][:length])
                y_pred.extend(pred[i][:length])
        self.dkt_net.train()
        return roc_auc_score(y_true, y_pred), accuracy_score(y_true, np.array(y_pred) >= 0.5)

    def save(self, filepath) -> ...:
        torch.save(self.dkt_net.state_dict(), filepath)
        logging.info("save parameters to %s" % filepath)

    def load(self, filepath):
        self.dkt_net.load_state_dict(torch.load(filepath))
        logging.info("load parameters from %s" % filepath)
@@ -0,0 +1,5 @@
# coding: utf-8
# 2021/5/25 @ tongshiwei

from .DKTPlus import DKTPlus
from .etl import etl
@@ -0,0 +1,69 @@
# coding: utf-8
# 2021/5/25 @ tongshiwei

import torch
import json
from tqdm import tqdm
from EduKTM.utils.torch_utils import PadSequence, FixedBucketSampler


def extract(data_src):  # pragma: no cover
    responses = []
    step = 200
    with open(data_src) as f:
        for line in tqdm(f, "reading data from %s" % data_src):
            data = json.loads(line)
            for i in range(0, len(data), step):
                if len(data[i: i + step]) < 2:
                    continue
                responses.append(data[i: i + step])

    return responses


def transform(raw_data, batch_size, num_buckets=100):
    # define the data transformation interface
    # raw_data --> batch_data

    responses = raw_data

    batch_idxes = FixedBucketSampler([len(rs) for rs in responses], batch_size, num_buckets=num_buckets)
    batch = []

    def index(r):
        correct = 0 if r[1] <= 0 else 1
        return r[0] * 2 + correct

    for batch_idx in tqdm(batch_idxes, "batchify"):
        batch_rs = []
        batch_pick_index = []
        batch_labels = []
        for idx in batch_idx:
            batch_rs.append([index(r) for r in responses[idx]])
            if len(responses[idx]) <= 1:  # pragma: no cover
                pick_index, labels = [], []
            else:
                pick_index, labels = zip(*[(r[0], 0 if r[1] <= 0 else 1) for r in responses[idx][1:]])
            batch_pick_index.append(list(pick_index))
            batch_labels.append(list(labels))

        max_len = max([len(rs) for rs in batch_rs])
        padder = PadSequence(max_len, pad_val=0)
        batch_rs, data_mask = zip(*[(padder(rs), len(rs)) for rs in batch_rs])

        max_len = max([len(rs) for rs in batch_labels])
        padder = PadSequence(max_len, pad_val=0)
        batch_labels, label_mask = zip(*[(padder(labels), len(labels)) for labels in batch_labels])
        batch_pick_index = [padder(pick_index) for pick_index in batch_pick_index]
        # Load
        batch.append(
            [torch.tensor(batch_rs), torch.tensor(data_mask), torch.tensor(batch_labels),
             torch.tensor(batch_pick_index),
             torch.tensor(label_mask)])

    return batch


def etl(data_src, batch_size, **kwargs):  # pragma: no cover
    raw_data = extract(data_src)
    return transform(raw_data, batch_size, **kwargs)
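The extract/transform pipeline above only implies its input format: each line of the data file is a JSON list of (knowledge_unit_id, response) pairs, and index() encodes each pair as ku_id * 2 + correct. A small, hypothetical illustration (the file name and records are made up):

```python
# Hypothetical illustration of the input consumed by etl()/extract():
# one JSON-encoded learner record per line, each record a list of
# [knowledge_unit_id, response] pairs, with response > 0 meaning "correct".
import json

records = [
    [[0, 1], [1, 0], [2, 1]],   # ku 0 correct, ku 1 incorrect, ku 2 correct
    [[3, 1], [0, -1]],          # responses <= 0 are treated as incorrect
]
with open("toy_train.json", "w") as f:   # placeholder file name
    for record in records:
        f.write(json.dumps(record) + "\n")

# batches = etl("toy_train.json", batch_size=2)
# Each batch is [data, data_mask, label, pick_index, label_mask], where
# data[i][t] = ku_id * 2 + correct, and pick_index/label describe the *next* step.
```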
@@ -4,3 +4,4 @@
from .meta import KTM
from .KPT import KPT
from .DKT import DKT
from .DKTPlus import DKTPlus
@@ -0,0 +1,6 @@
# coding: utf-8
# 2021/5/24 @ tongshiwei

from .utils import *
from .loss import SequenceLogisticMaskLoss as SLMLoss
from .torch_utils import *
@@ -0,0 +1,79 @@
# coding: utf-8
# 2021/5/24 @ tongshiwei
__all__ = ["SequenceLogisticMaskLoss", "LogisticMaskLoss"]

import torch
from torch import nn

from .torch_utils import pick, sequence_mask


class SequenceLogisticMaskLoss(nn.Module):
""" | ||
Notes | ||
----- | ||
The loss has been average, so when call the step method of trainer, batch_size should be 1 | ||
""" | ||
|
||
def __init__(self, lr=0.0, lw1=0.0, lw2=0.0): | ||
""" | ||
Parameters | ||
---------- | ||
lr: reconstruction | ||
lw1 | ||
lw2 | ||
""" | ||
        super(SequenceLogisticMaskLoss, self).__init__()
        self.lr = lr
        self.lw1 = lw1
        self.lw2 = lw2
        self.loss = torch.nn.BCELoss(reduction='none')

    def forward(self, pred_rs, pick_index, label, label_mask):
        if self.lw1 > 0.0 or self.lw2 > 0.0:
            post_pred_rs = pred_rs[:, 1:]
            pre_pred_rs = pred_rs[:, :-1]
            diff = post_pred_rs - pre_pred_rs
            diff = sequence_mask(diff, label_mask)
            w1 = torch.mean(torch.norm(diff, 1, -1)) / diff.shape[-1]
            w2 = torch.mean(torch.norm(diff, 2, -1)) / diff.shape[-1]
            # w2 = F.mean(F.sqrt(diff ** 2))
            w1 = w1 * self.lw1 if self.lw1 > 0.0 else 0.0
            w2 = w2 * self.lw2 if self.lw2 > 0.0 else 0.0
        else:
            w1 = 0.0
            w2 = 0.0

        if self.lr > 0.0:
            re_pred_rs = pred_rs[:, 1:]
            re_pred_rs = pick(re_pred_rs, pick_index)
            wr = sequence_mask(self.loss(re_pred_rs, label.float()), label_mask)
            wr = torch.mean(wr) * self.lr
        else:
            wr = 0.0

        pred_rs = pred_rs[:, 1:]
        pred_rs = pick(pred_rs, pick_index)
        loss = sequence_mask(self.loss(pred_rs, label.float()), label_mask)
        # loss = F.sum(loss, axis=-1)
        loss = torch.mean(loss) + w1 + w2 + wr
        return loss

class LogisticMaskLoss(nn.Module):  # pragma: no cover
    """
    Notes
    -----
    The loss is averaged internally, so when calling the trainer's step method
    the batch size should be treated as 1.
    """

    def __init__(self):
        super(LogisticMaskLoss, self).__init__()

        self.loss = torch.nn.BCELoss()

    def forward(self, pred_rs, label, label_mask, *args, **kwargs):
        loss = sequence_mask(self.loss(pred_rs, label), label_mask)
        loss = torch.mean(loss)
        return loss
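As a reading note rather than part of the diff: writing λr, λw1, λw2 for the constructor arguments lr, lw1 and lw2, M for the number of knowledge units, and ŷ for the predicted mastery probabilities, SequenceLogisticMaskLoss.forward appears to compute roughly the prediction-consistent regularized objective of DKT+ (Yeung & Yeung, 2018):

```latex
% Hedged summary of SequenceLogisticMaskLoss.forward.
% The BCE terms use the picked next-step predictions; the waviness terms use
% consecutive differences of the full predicted mastery vectors. In this
% implementation the reconstruction term reuses the same picked next-step
% predictions as the main BCE term.
\mathcal{L}
  = \operatorname{mean}\bigl(\mathrm{BCE}(\hat{y}, y)\bigr)
  + \lambda_r \,\operatorname{mean}\bigl(\mathrm{BCE}(\hat{y}, y)\bigr)
  + \frac{\lambda_{w_1}}{M}\,\operatorname{mean}\bigl(\lVert \hat{y}_{t+1} - \hat{y}_{t} \rVert_1\bigr)
  + \frac{\lambda_{w_2}}{M}\,\operatorname{mean}\bigl(\lVert \hat{y}_{t+1} - \hat{y}_{t} \rVert_2\bigr)
```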
@@ -0,0 +1,16 @@
# coding: utf-8
# 2021/5/26 @ tongshiwei

def pseudo_data_generation(ku_num, record_num=10, max_length=20):
    # generate a pseudo data stream for testing
    import random
    random.seed(10)

    raw_data = [
        [
            (random.randint(0, ku_num - 1), random.randint(-1, 1))
            for _ in range(random.randint(2, max_length))
        ] for _ in range(record_num)
    ]

    return raw_data
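For what it's worth, the generator above produces records in the same (knowledge_unit_id, response) pair format that transform() consumes, with responses drawn from {-1, 0, 1} so that both -1 and 0 count as incorrect. A quick, hypothetical sanity check (the arguments are arbitrary):

```python
# Sanity check of the pseudo data shape; ku_num, record_num and max_length are arbitrary.
raw_data = pseudo_data_generation(ku_num=5, record_num=2, max_length=4)
# e.g. something like [[(3, 1), (0, -1)], [(2, 0), (4, 1), (1, 1)]] -- values depend on the seed
for record in raw_data:
    assert 2 <= len(record) <= 4
    for ku, resp in record:
        assert 0 <= ku < 5 and resp in (-1, 0, 1)
```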
@@ -0,0 +1,5 @@
# coding: utf-8
# 2021/5/25 @ tongshiwei

from .extlib import *
from .functional import *
@@ -0,0 +1,5 @@
# coding: utf-8
# 2021/5/26 @ tongshiwei

from .data import *
from .sampler import *
@@ -0,0 +1,48 @@
# coding: utf-8
# 2021/5/25 @ tongshiwei
# These codes are modified from gluonnlp

__all__ = ["PadSequence"]


class PadSequence:
    """Pad the sequence.

    Pad the sequence to the given `length` by inserting `pad_val`. If `clip` is set,
    a sequence longer than `length` will be clipped.

    Parameters
    ----------
    length : int
        The maximum length to pad/clip the sequence to
    pad_val : number
        The pad value. Default 0
    clip : bool
        Whether to clip sequences longer than `length`. Default True
    """

    def __init__(self, length, pad_val=0, clip=True):
        self._length = length
        self._pad_val = pad_val
        self._clip = clip

    def __call__(self, sample):
        """
        Parameters
        ----------
        sample : list of number or mx.nd.NDArray or np.ndarray

        Returns
        -------
        ret : list of number or mx.nd.NDArray or np.ndarray
        """
        sample_length = len(sample)
        if sample_length >= self._length:
            if self._clip and sample_length > self._length:
                return sample[:self._length]
            else:
                return sample
        else:
            return sample + [
                self._pad_val for _ in range(self._length - sample_length)
            ]
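A short usage sketch of PadSequence, following directly from __call__ above (the values are illustrative):

```python
# PadSequence pads short sequences and, by default, clips long ones.
padder = PadSequence(length=5, pad_val=0)
assert padder([1, 2, 3]) == [1, 2, 3, 0, 0]                # padded up to length 5
assert padder([1, 2, 3, 4, 5, 6, 7]) == [1, 2, 3, 4, 5]    # clipped to length 5
assert PadSequence(5, clip=False)([1, 2, 3, 4, 5, 6]) == [1, 2, 3, 4, 5, 6]  # left as-is when clip=False
```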