-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpy-recur-helpers.h
152 lines (128 loc) · 3.74 KB
/
py-recur-helpers.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#ifndef HAVE_PY_RECUR_HELPERS_H
#define HAVE_PY_RECUR_HELPERS_H
#include <Python.h>
#include "structmember.h"
#include "structseq.h"
#include "recur-nn.h"
#define BaseNet_HEAD \
PyObject_HEAD \
RecurNN *net; \
rnn_learning_method learning_method; \
float momentum; \
int batch_size; \
const char *filename; \
typedef struct {
BaseNet_HEAD
} BaseNet;
#define RNNPY_BASE_NET(x) ((BaseNet *)(x))
/* BaseNet.save(filename=None, backup=1) -> None

   Write the wrapped net to disk with rnn_save_net().

   filename: path to write.  When omitted or None ("z" format), the path
             stored in self->filename is used instead.
   backup:   integer passed straight through to rnn_save_net().

   Returns None on success.  Raises ValueError when no filename is given
   and none is stored on the object, and IOError when rnn_save_net()
   returns non-zero. */
static PyObject *
BaseNet_save(BaseNet *self, PyObject *args, PyObject *kwds)
{
    RecurNN *net = self->net;
    const char *filename = NULL;
    int backup = 1;
    static char *kwlist[] = {"filename", /* z */
                             "backup",   /* i */
                             NULL};
    if (!PyArg_ParseTupleAndKeywords(args, kwds, "|zi", kwlist,
            &filename,
            &backup
        )){
        return NULL;
    }
    if (filename == NULL){
        filename = self->filename;
    }
    /* Both sources may be NULL.  Without this check a NULL pointer would
       reach rnn_save_net() and the "%s" conversion in the error below. */
    if (filename == NULL){
        PyErr_SetString(PyExc_ValueError,
            "no filename given and none is set");
        return NULL;
    }
    int r = rnn_save_net(net, filename, backup);
    if (r){
        return PyErr_Format(PyExc_IOError, "could not save to %s",
            filename);
    }
    Py_RETURN_NONE;
}
/* Store a heap-allocated filename in self->filename.

   If filename is non-NULL it is copied verbatim.  Otherwise a name is
   generated from basename, a 32-bit hash of metadata (rnn_hash32) and the
   net's input/hidden/output sizes, in the form
   "<basename>-<hex hash>i<input>-h<hidden>-o<output>.net".

   Returns 0 on success.  Returns -1 with a Python exception set on failure:
   ValueError if the generated name does not fit the 1000-byte buffer, and
   MemoryError if the copy cannot be allocated.

   NOTE(review): a previous self->filename is overwritten without being
   freed; confirm this is only called once per object. */
static int
set_net_filename(BaseNet *self, const char *filename, const char *basename,
    char *metadata)
{
    char s[1000];
    RecurNN *net = self->net;
    char *copy;
    if (filename){
        copy = strdup(filename);
    }
    else {
        u32 sig = rnn_hash32(metadata);
        /* The "%0" PRIx32 conversion has a 0 flag but no field width, so
           it does not pad.  It is kept unchanged so that generated names
           keep matching files that were already saved. */
        int wrote = snprintf(s, sizeof(s), "%s-%0" PRIx32 "i%d-h%d-o%d.net",
            basename, sig, net->input_size, net->hidden_size, net->output_size);
        /* A negative result is an encoding error, and a result >= the
           buffer size means the output was truncated.  Negatives are
           tested first so the size comparison is purely unsigned. */
        if (wrote < 0 || (size_t)wrote >= sizeof(s)){
            PyErr_Format(PyExc_ValueError,
                "filename is trying to be too long!");
            return -1;
        }
        copy = strdup(s);
    }
    if (copy == NULL){
        PyErr_NoMemory();
        return -1;
    }
    self->filename = copy;
    return 0;
}
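/* Net_{get,set}float_{rnn,bptt}: generic accessors for float members that
   live at a byte offset, given by *offset, within the RecurNN struct (the
   _rnn variants) or within its bptt sub-struct (the _bptt variants).
   Without these, a separate access function would be needed for each
   attribute.
*/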
/* Return the float stored *offset bytes into self->net as a Python float
   (a new reference), or NULL if allocation fails.  The offset is applied
   to a char pointer because arithmetic on void * is a GNU extension and
   not standard C. */
static UNUSED PyObject *
Net_getfloat_rnn(BaseNet *self, int *offset)
{
    char *addr = (char *)self->net + *offset;
    float f = *(float *)addr;
    return PyFloat_FromDouble((double)f);
}
/* Convert value with PyNumber_Float() and store it, narrowed to float, at
   *offset bytes into self->net.
   Returns 0 on success, or -1 with the conversion error set.
   The temporary float object is released after use; PyNumber_Float()
   returns a new reference. */
static UNUSED int
Net_setfloat_rnn(BaseNet *self, PyObject *value, int *offset)
{
    PyObject *pyfloat = PyNumber_Float(value);
    if (pyfloat == NULL){
        return -1;
    }
    /* char * arithmetic: void * arithmetic is not standard C. */
    char *addr = (char *)self->net + *offset;
    *(float *)addr = (float)PyFloat_AS_DOUBLE(pyfloat);
    Py_DECREF(pyfloat);
    return 0;
}
/* Return the float stored *offset bytes into self->net->bptt as a Python
   float (a new reference).
   Raises AttributeError instead of dereferencing a NULL bptt pointer.
   NOTE(review): presumably bptt is NULL for nets created without training
   state -- confirm against recur-nn. */
static UNUSED PyObject *
Net_getfloat_bptt(BaseNet *self, int *offset)
{
    if (self->net->bptt == NULL){
        PyErr_SetString(PyExc_AttributeError, "net->bptt is NULL");
        return NULL;
    }
    /* char * arithmetic: void * arithmetic is not standard C. */
    char *addr = (char *)self->net->bptt + *offset;
    float f = *(float *)addr;
    return PyFloat_FromDouble((double)f);
}
/* Convert value with PyNumber_Float() and store it, narrowed to float, at
   *offset bytes into self->net->bptt.
   Returns 0 on success, or -1 with an exception set: AttributeError if
   bptt is NULL, or the conversion error.  The temporary float object, a
   new reference, is released after use. */
static UNUSED int
Net_setfloat_bptt(BaseNet *self, PyObject *value, int *offset)
{
    /* Checked before converting so this error path holds no references. */
    if (self->net->bptt == NULL){
        PyErr_SetString(PyExc_AttributeError, "net->bptt is NULL");
        return -1;
    }
    PyObject *pyfloat = PyNumber_Float(value);
    if (pyfloat == NULL){
        return -1;
    }
    /* char * arithmetic: void * arithmetic is not standard C. */
    char *addr = (char *)self->net->bptt + *offset;
    *(float *)addr = (float)PyFloat_AS_DOUBLE(pyfloat);
    Py_DECREF(pyfloat);
    return 0;
}
/* Export the RNN_* learning-method, activation and initialisation
   constants as module-level ints, named without the "RNN_" prefix.
   Stops at the first PyModule_AddIntConstant() failure and returns 1 in
   that case; returns 0 when all constants were added. */
static int add_module_constants(PyObject* m)
{
#define RNNPY_CONST_ENTRY(x) {QUOTE(x), (RNN_ ##x)}
    const struct {
        const char *name;
        long value;
    } table[] = {
        RNNPY_CONST_ENTRY(MOMENTUM_WEIGHTED),
        RNNPY_CONST_ENTRY(MOMENTUM_NESTEROV),
        RNNPY_CONST_ENTRY(MOMENTUM_SIMPLIFIED_NESTEROV),
        RNNPY_CONST_ENTRY(MOMENTUM_CLASSICAL),
        RNNPY_CONST_ENTRY(ADAGRAD),
        RNNPY_CONST_ENTRY(ADADELTA),
        RNNPY_CONST_ENTRY(RPROP),
        RNNPY_CONST_ENTRY(RELU),
        RNNPY_CONST_ENTRY(RESQRT),
        RNNPY_CONST_ENTRY(RECLIP20),
        RNNPY_CONST_ENTRY(INIT_ZERO),
        RNNPY_CONST_ENTRY(INIT_FLAT),
        RNNPY_CONST_ENTRY(INIT_FAN_IN),
        RNNPY_CONST_ENTRY(INIT_RUNS),
    };
#undef RNNPY_CONST_ENTRY
    size_t count = sizeof(table) / sizeof(table[0]);
    for (size_t k = 0; k < count; k++){
        if (PyModule_AddIntConstant(m, table[k].name, table[k].value) != 0){
            return 1;
        }
    }
    return 0;
}
#endif