-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapi_linux.go
280 lines (233 loc) · 7.32 KB
/
api_linux.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
//go:build linux
package tunnel
import (
"errors"
"fmt"
"os"
"sync"
"syscall"
"golang.org/x/sys/unix"
)
const (
	// tunModuleCharPath is the character device opened by openDevice to
	// obtain a TUN file descriptor.
	tunModuleCharPath = "/dev/net/tun"

	// tcpOffloads is the TUNSETOFFLOAD mask that setTunnelvnetHdr requires
	// whenever the interface reports IFF_VNET_HDR.
	tcpOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6

	// udpOffloads is requested in addition to tcpOffloads; a failure to
	// enable it is tolerated and only recorded in tunDevice.udpGsoEnabled.
	udpOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6

	// maxPacketLen is the largest packet length accounted for (65535 bytes).
	maxPacketLen = 1<<16 - 1

	// virtioBuffLen is the size of the raw read buffer: one maximum-size
	// packet plus virtioNetHdrLen bytes.
	virtioBuffLen = maxPacketLen + virtioNetHdrLen

	// rcveBuffLen is the nominal length of the read-side copy buffer
	// (16 KiB); vnetHdrRead caps the buffer back to it after it has grown.
	rcveBuffLen = 1 << 14

	// sendBuffLen is the nominal length of the write-side scratch buffer
	// (16 KiB); vnetHdrWrite trims the buffer back to it after each call.
	sendBuffLen = 1 << 14

	// pktBuffArrs is the capacity (64 entries) given to the write-side
	// packet list and packet index list.
	pktBuffArrs = 1 << 6
)
// readWriteHandler holds the two I/O strategies of a tunDevice: element 0 is
// the read handler and element 1 is the write handler. Each handler receives
// the device it operates on together with the caller's buffer and returns a
// byte count and an error.
type readWriteHandler [2]func(*tunDevice, []byte) (int, error)
// readBuffers is the read-side scratch state used by vnetHdrRead.
type readBuffers struct {
	// virtbuff receives the raw bytes of a single read from the device file.
	virtbuff []byte
	// copybuff is the buffer whose leading bytes vnetHdrRead hands to the
	// caller after virtioRead has processed a frame.
	copybuff []byte
	// unreadPtr is the part of copybuff that has not been delivered to a
	// caller yet; it is served first by the next read.
	unreadPtr []byte
}
// writeBuffers is the write-side scratch state used by vnetHdrWrite. All of
// it is restored to its idle state by a deferred step at the end of every
// vnetHdrWrite call.
type writeBuffers struct {
// virtbuff is allocated with sendBuffLen bytes and trimmed back to that
// length after a write if it ended up longer.
virtbuff []byte
// pktsbuff is the list of packets prepared for the device; it is emptied
// after every write.
pktsbuff [][]byte
// pktIndex lists the positions in pktsbuff that vnetHdrWrite actually sends
// to the device, in order; it is emptied after every write.
pktIndex []int
// buff_pos is reset to 0 after every write.
// NOTE(review): presumably the fill position inside virtbuff — confirm
// against slicePackets/virtioMakeGro.
buff_pos int
// tcpGroTable and udpGroTable are reset after every write.
tcpGroTable *tcpGROTable
udpGroTable *udpGROTable
}
// platformMethods carries the platform-specific part of an interface; on
// Linux that is the interface name reported through Name.
type platformMethods struct {
	// name is the interface name.
	name string
}
// tunDevice is the object installed as the ReadWriteCloser of an Iface on
// Linux. It wraps the TUN device file and dispatches Read and Write through
// rwHandler.
type tunDevice struct {
// rwHandler selects the I/O strategy: index 0 serves Read, index 1 serves
// Write. It starts as genericRead/genericWrite and is switched to
// vnetHdrRead/vnetHdrWrite by setTunnelvnetHdr.
rwHandler readWriteHandler
// tunReadMux is held for the whole duration of vnetHdrRead.
tunReadMux sync.Mutex
// tunWriteMux is held for the whole duration of vnetHdrWrite.
tunWriteMux sync.Mutex
// file wraps the TUN file descriptor.
file *os.File
// r_buff and w_buff are allocated by setTunnelvnetHdr only when the
// virtio-net header handlers are installed; they stay nil otherwise.
r_buff *readBuffers
w_buff *writeBuffers
// udpGsoEnabled records whether enabling udpOffloads on top of tcpOffloads
// succeeded.
udpGsoEnabled bool
}
// Name returns the name of the tun interface.
func (p platformMethods) Name() string {
	return p.name
}
// Read reads packets from the tunnel into b by delegating to the read
// handler stored in rwHandler[0].
func (dev *tunDevice) Read(b []byte) (int, error) {
	read := dev.rwHandler[0]
	return read(dev, b)
}
// Write writes packets to the tunnel by delegating to the write handler
// stored in rwHandler[1].
//
// When GSO/GRO is enabled it is recommended to pass the data as a sequence of
// packets in a single buffer. The first IP packet in that buffer must be
// valid, otherwise an error is returned. If a later packet in the sequence is
// malformed or corrupted, Write returns a nil error together with the number
// of bytes written up to the corrupted packet. When the first packet fails
// with ErrShortPacket or ErrFragmentedPacket, the usual cause is that the
// packet is incomplete in the buffer; the caller may then retry the write
// with the complete packet.
func (dev *tunDevice) Write(b []byte) (int, error) {
	write := dev.rwHandler[1]
	return write(dev, b)
}
// Close closes the tunnel by closing the underlying device file and reports
// the error of that operation.
func (dev *tunDevice) Close() error {
	err := dev.file.Close()
	return err
}
// genericRead is the plain read handler: it reads from the device file
// straight into buff.
func genericRead(dev *tunDevice, buff []byte) (int, error) {
	n, err := dev.file.Read(buff)
	return n, err
}
// genericWrite is the plain write handler: it writes buff straight to the
// device file.
func genericWrite(dev *tunDevice, buff []byte) (int, error) {
	n, err := dev.file.Write(buff)
	return n, err
}
// openDevice opens the TUN clone device at tunModuleCharPath and turns the
// resulting descriptor into an *Iface configured according to config.
//
// On failure the descriptor is closed exactly once: directly while it is
// still a raw fd, or through the wrapping *os.File once one exists. Closing
// the raw fd after os.NewFile has taken it over would leave a file whose
// finalizer closes the same descriptor number again later on, at which point
// the number may already belong to an unrelated file.
func openDevice(config Config) (*Iface, error) {
	nfd, err := unix.Open(tunModuleCharPath, unix.O_CLOEXEC|unix.O_RDWR, 0)
	if err != nil {
		return nil, fmt.Errorf("tunnel: open %s: %w", tunModuleCharPath, err)
	}
	iface, wrapped, err := buildTunIface(nfd, config)
	if err != nil {
		if !wrapped {
			// Best-effort cleanup; the setup error is the one worth reporting.
			_ = unix.Close(nfd)
		}
		return nil, err
	}
	return iface, nil
}

// setupTunDevice configures the already opened TUN descriptor nfd and wraps
// it in an *Iface.
//
// When an error is returned the caller keeps ownership of nfd, with one
// exception: if the failure happened after the descriptor was handed to
// os.NewFile, that file (and with it nfd) has already been closed here so
// that no finalizer is left behind to close the descriptor number a second
// time.
func setupTunDevice(nfd int, config Config) (*Iface, error) {
	iface, _, err := buildTunIface(nfd, config)
	return iface, err
}

// buildTunIface does the work of setupTunDevice and additionally reports
// whether nfd was wrapped in an *os.File. If it returns an error with
// wrapped == true, the file has been closed and nfd must not be closed again.
func buildTunIface(nfd int, config Config) (*Iface, bool, error) {
	ifreq, err := unix.NewIfreq(config.Name)
	if err != nil {
		return nil, false, fmt.Errorf("tunnel: create iface: %w", err)
	}

	// Always a TUN device without the packet-information prefix; multi-queue
	// and the virtio-net header are opt-in/opt-out through config.
	var flags uint16 = unix.IFF_NO_PI | unix.IFF_TUN
	if config.MultiQueue {
		flags |= unix.IFF_MULTI_QUEUE
	}
	if !config.DisableGsoGro {
		flags |= unix.IFF_VNET_HDR
	}
	ifreq.SetUint16(flags)

	if err := unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifreq); err != nil {
		return nil, false, fmt.Errorf("tunnel: ioctl ifreq: %w", err)
	}
	// Non-blocking mode is set before the descriptor is wrapped by os.NewFile.
	if err := unix.SetNonblock(nfd, true); err != nil {
		return nil, false, fmt.Errorf("tunnel: set nonblock: %w", err)
	}
	if err := setDeviceOptions(nfd, config); err != nil {
		return nil, false, fmt.Errorf("tunnel: device option: %w", err)
	}

	// ifreq.Name() is read back after TUNSETIFF and used as the interface name.
	iface := &Iface{platformMethods: platformMethods{name: ifreq.Name()}}
	dev := new(tunDevice)
	dev.file = os.NewFile(uintptr(nfd), ifreq.Name())
	dev.rwHandler[0], dev.rwHandler[1] = genericRead, genericWrite
	if flags&unix.IFF_VNET_HDR != 0 {
		if err := dev.setTunnelvnetHdr(iface.Name()); err != nil {
			// The file owns nfd now: release it through the file so the
			// descriptor is closed once and no finalizer remains armed.
			_ = dev.file.Close()
			return nil, true, err
		}
	}
	iface.ReadWriteCloser = dev
	return iface, true, nil
}
// setDeviceOptions applies ownership and persistence settings to the TUN
// descriptor fd.
//
// When config.Permissions is set, TUNSETOWNER and TUNSETGROUP are both issued
// with the configured owner and group. TUNSETPERSIST is always issued: with 1
// when config.Persist is true, with 0 otherwise.
//
// The error of a failing ioctl is wrapped with %w, so callers can still match
// the underlying errno with errors.Is.
func setDeviceOptions(fd int, config Config) error {
	if perms := config.Permissions; perms != nil {
		if err := unix.IoctlSetInt(fd, unix.TUNSETOWNER, int(perms.Owner)); err != nil {
			return fmt.Errorf("set tunnel owner: %w", err)
		}
		if err := unix.IoctlSetInt(fd, unix.TUNSETGROUP, int(perms.Group)); err != nil {
			return fmt.Errorf("set tunnel group: %w", err)
		}
	}

	persist := 0
	if config.Persist {
		persist = 1
	}
	if err := unix.IoctlSetInt(fd, unix.TUNSETPERSIST, persist); err != nil {
		return fmt.Errorf("set persist flag: %w", err)
	}
	return nil
}
// setTunnelvnetHdr checks whether the interface called name really carries a
// virtio-net header and, if so, enables offloads and switches the device to
// the vnetHdrRead/vnetHdrWrite handlers.
//
// The flags are read back with TUNGETIFF on the device's own descriptor. If
// IFF_VNET_HDR is set, tcpOffloads must be accepted by TUNSETOFFLOAD or an
// error is returned. Enabling udpOffloads on top is optional: the outcome is
// only recorded in udpGsoEnabled. If IFF_VNET_HDR is not set the device is
// left on its current handlers and nil is returned.
//
// Errors from the underlying calls are wrapped with %w.
func (dev *tunDevice) setTunnelvnetHdr(name string) error {
	sysconn, err := dev.file.SyscallConn()
	if err != nil {
		return fmt.Errorf("tunnel: syscall conn: %w", err)
	}

	var (
		vnethdr bool  // set once tcpOffloads has been accepted
		ioErr   error // failure of one of the calls made inside Control
	)
	ctrlErr := sysconn.Control(func(fd uintptr) {
		ifReq, err := unix.NewIfreq(name)
		if err != nil {
			ioErr = fmt.Errorf("tunnel: iface request: %w", err)
			return
		}
		if err := unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifReq); err != nil {
			ioErr = fmt.Errorf("tunnel: ioctl get iface: %w", err)
			return
		}
		if ifReq.Uint16()&unix.IFF_VNET_HDR == 0 {
			return
		}
		if err := unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tcpOffloads); err != nil {
			ioErr = fmt.Errorf("tunnel: set offload: %w", err)
			return
		}
		vnethdr = true
		// Second attempt with the UDP flags added; a refusal is not an error.
		// NOTE(review): presumably kernels without USO support reject this
		// mask and keep the TCP offloads set above — confirm.
		dev.udpGsoEnabled = unix.IoctlSetInt(int(fd),
			unix.TUNSETOFFLOAD, tcpOffloads|udpOffloads) == nil
	})
	if ctrlErr != nil {
		return fmt.Errorf("tunnel: set ctrl func: %w", ctrlErr)
	}
	if ioErr != nil {
		return ioErr
	}
	if !vnethdr {
		return nil
	}

	dev.rwHandler[0], dev.rwHandler[1] = vnetHdrRead, vnetHdrWrite
	dev.r_buff, dev.w_buff = new(readBuffers), new(writeBuffers)
	dev.r_buff.virtbuff = make([]byte, virtioBuffLen)
	dev.r_buff.copybuff = make([]byte, rcveBuffLen)
	dev.w_buff.virtbuff = make([]byte, sendBuffLen)
	dev.w_buff.pktsbuff = make([][]byte, 0, pktBuffArrs)
	dev.w_buff.pktIndex = make([]int, 0, pktBuffArrs)
	dev.w_buff.tcpGroTable = newTCPGROTable()
	dev.w_buff.udpGroTable = newUDPGROTable()
	return nil
}
// vnetHdrRead is the read handler used when the interface carries a
// virtio-net header. It holds tunReadMux for the whole call.
//
// Bytes left over from an earlier call are served first. If any were copied,
// the call returns immediately instead of blocking on the device for another
// frame, following the io.Reader convention of returning what is available
// rather than waiting for more. Only when nothing is pending is one read
// issued on the device into r_buff.virtbuff and handed to dev.virtioRead; the
// first nb bytes of r_buff.copybuff are then copied to buff and whatever does
// not fit stays pending for the next call.
//
// A device error matching EBADFD is reported as os.ErrClosed. The error
// returned by dev.virtioRead is passed through together with the bytes that
// were copied.
func vnetHdrRead(dev *tunDevice, buff []byte) (int, error) {
	dev.tunReadMux.Lock()
	defer dev.tunReadMux.Unlock()
	defer func() {
		// Once all pending data has been delivered, drop the reference and
		// cap copybuff back to its nominal size if it had been grown.
		if len(dev.r_buff.unreadPtr) == 0 && cap(dev.r_buff.copybuff) > rcveBuffLen {
			dev.r_buff.unreadPtr = nil
			dev.r_buff.copybuff = dev.r_buff.copybuff[:rcveBuffLen:rcveBuffLen]
		}
	}()

	pending := copy(buff, dev.r_buff.unreadPtr)
	dev.r_buff.unreadPtr = dev.r_buff.unreadPtr[pending:]
	if pending > 0 || len(buff) == 0 {
		// Data is already available (or nothing was asked for): do not block
		// on the device.
		return pending, nil
	}

	n, err := dev.file.Read(dev.r_buff.virtbuff)
	if errors.Is(err, unix.EBADFD) {
		err = os.ErrClosed
	}
	if err != nil {
		return 0, err
	}

	nb, err := dev.virtioRead(n)
	dev.r_buff.unreadPtr = dev.r_buff.copybuff[:nb]
	copied := copy(buff, dev.r_buff.unreadPtr)
	dev.r_buff.unreadPtr = dev.r_buff.unreadPtr[copied:]
	return copied, err
}
// vnetHdrWrite is the write handler used when the interface carries a
// virtio-net header. It holds tunWriteMux for the whole call.
//
// buff is first passed to dev.slicePackets; if that fails its count and error
// are returned unchanged. dev.virtioMakeGro runs next, and then every entry of
// w_buff.pktsbuff selected by w_buff.pktIndex is written to the device in
// order. A failure in either of those two steps yields a count of 0; a device
// error matching EBADFD is reported as os.ErrClosed. On success the count
// obtained from dev.slicePackets is returned.
func vnetHdrWrite(dev *tunDevice, buff []byte) (int, error) {
	dev.tunWriteMux.Lock()
	defer dev.tunWriteMux.Unlock()
	// Registered after the unlock, so it runs before it: the per-call scratch
	// state is restored while the write lock is still held.
	defer func() {
		wb := dev.w_buff
		wb.tcpGroTable.reset()
		wb.udpGroTable.reset()
		wb.buff_pos = 0
		if sendBuffLen < len(wb.virtbuff) {
			wb.virtbuff = wb.virtbuff[:sendBuffLen:sendBuffLen]
		}
		wb.pktsbuff = wb.pktsbuff[:0:pktBuffArrs]
		wb.pktIndex = wb.pktIndex[:0:pktBuffArrs]
	}()

	accepted, err := dev.slicePackets(buff)
	if err != nil {
		return accepted, err
	}
	if groErr := dev.virtioMakeGro(); groErr != nil {
		return 0, groErr
	}

	queue := dev.w_buff.pktIndex
	for i := range queue {
		_, werr := dev.file.Write(dev.w_buff.pktsbuff[queue[i]])
		switch {
		case errors.Is(werr, syscall.EBADFD):
			return 0, os.ErrClosed
		case werr != nil:
			return 0, werr
		}
	}
	return accepted, nil
}