Skip to content

Commit

Permalink
process: Merge pull request #1261 from svinota/1213-ns-init
Browse files Browse the repository at this point in the history
Bug-Url: #1261
Bug-Url: #1213
  • Loading branch information
svinota authored Feb 1, 2025
2 parents 3d0d36f + d657963 commit 0006ffe
Show file tree
Hide file tree
Showing 9 changed files with 296 additions and 67 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/pull_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ jobs:
- run: sudo chown -R $USER:$USER $GITHUB_WORKSPACE
- uses: actions/checkout@v4
- run: sudo make nox session=integration
process:
runs-on: Linux
steps:
- run: sudo chown -R $USER:$USER $GITHUB_WORKSPACE
- uses: actions/checkout@v4
- run: make nox session=process
minimal:
runs-on: Linux
steps:
Expand Down
9 changes: 9 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
'repo',
'unit',
'neutron',
'process',
'integration',
'linux-python3.9',
'linux-python3.10',
Expand Down Expand Up @@ -293,6 +294,14 @@ def core(session, config):
)


@nox.session
@add_session_config
def process(session, config):
'''Test child process module.'''
setup_venv_dev(session)
session.run(*options('test_process', config))


@nox.session
@add_session_config
def minimal(session, config):
Expand Down
1 change: 1 addition & 0 deletions pyroute2/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def parse_kernel_version(kernel_name):
db_transaction_limit = 1
cache_expire = 60

child_process_mode = 'fork'
signal_stop_remote = None
if hasattr(signal, 'SIGUSR1'):
signal_stop_remote = signal.SIGUSR1
Expand Down
43 changes: 8 additions & 35 deletions pyroute2/netlink/core.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import asyncio
import builtins
import collections
import errno
import json
import logging
import multiprocessing
import os
import socket
import struct
import threading
from urllib import parse

from pyroute2 import config
from pyroute2 import config, netns
from pyroute2.common import AddrPool
from pyroute2.netlink import NLM_F_MULTI
from pyroute2.netns import setns
Expand Down Expand Up @@ -247,9 +245,6 @@ async def ensure_socket(self):
self.local.fileno = None
self.local.msg_queue = CoreMessageQueue()
# 8<-----------------------------------------
# Setup netns
self.local.fileno = self.setup_netns()
# 8<-----------------------------------------
self.local.socket = self.setup_socket()
if self.spec['netns'] is not None and config.mock_netlink:
self.local.socket.netns = self.spec['netns']
Expand Down Expand Up @@ -318,38 +313,16 @@ def setup_socket(self, sock=None):
sock = self.socket if sock is None else sock
if sock is not None:
sock.close()
sock = config.SocketBase(socket.AF_INET, socket.SOCK_STREAM)
sock = netns.create_socket(
self.spec['netns'],
socket.AF_INET,
socket.SOCK_STREAM,
flags=self.spec['flags'],
libc=self.libc,
)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return sock

def setup_netns(self):
if self.spec['netns'] is not None and not config.mock_netlink:
# inspect self.__init__ argument names
ctrl = socket.socketpair()
nsproc = multiprocessing.Process(
target=netns_init,
args=(
ctrl[0],
self.spec['netns'],
self.spec['flags'],
self.libc,
type(self),
),
)
nsproc.start()
(data, fds, _, _) = socket.recv_fds(ctrl[1], 1024, 1)
# load the feedback
payload = json.loads(data.decode('utf-8'))
if payload:
if set(payload.keys()) != set(('name', 'args')):
raise TypeError('error loading netns feedback')
error_class = getattr(builtins, payload['name'])
if not issubclass(error_class, Exception):
raise TypeError('error loading netns error')
raise error_class(*payload['args'])
nsproc.join()
return fds[0]

def __getattr__(self, attr):
if attr in (
'getsockname',
Expand Down
7 changes: 5 additions & 2 deletions pyroute2/netlink/nlsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
import struct
from socket import SO_RCVBUF, SO_SNDBUF, SOCK_DGRAM, SOL_SOCKET

from pyroute2 import config
from pyroute2 import config, netns
from pyroute2.common import AddrPool, basestring, msg_done
from pyroute2.config import AF_NETLINK
from pyroute2.netlink import (
Expand Down Expand Up @@ -288,11 +288,14 @@ def setup_socket(self, sock=None):
sock = self.socket if sock is None else sock
if sock is not None:
sock.close()
sock = config.SocketBase(
sock = netns.create_socket(
self.spec['netns'],
AF_NETLINK,
SOCK_DGRAM,
self.spec['family'],
self.spec['fileno'] or self.local.fileno,
self.spec['flags'],
self.libc,
)
sock.setsockopt(SOL_SOCKET, SO_SNDBUF, self.status['sndbuf'])
sock.setsockopt(SOL_SOCKET, SO_RCVBUF, self.status['rcvbuf'])
Expand Down
2 changes: 0 additions & 2 deletions pyroute2/netlink/rtnl/iprsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,6 @@ def __init__(
)
self.asyncore.local = NotLocal()
self.asyncore.local.event_loop = self.asyncore.setup_event_loop()
if netns is not None:
self.asyncore.local.fileno = self.asyncore.setup_netns()
self.asyncore.local.socket = self.asyncore.setup_socket()

@property
Expand Down
80 changes: 52 additions & 28 deletions pyroute2/netns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,17 @@
import ctypes.util
import errno
import io
import logging
import os
import os.path
import pickle
import struct
import traceback
import socket
import time

from pyroute2 import config
from pyroute2.common import basestring
from pyroute2.process import ChildProcess, ChildProcessReturnValue

log = logging.getLogger(__name__)

try:
file = file
Expand Down Expand Up @@ -274,29 +277,11 @@ def create(netns, libc=None):
'''
Create a network namespace.
'''
rctl, wctl = os.pipe()
pid = os.fork()
if pid == 0:
# child
error = None
try:
_create(netns, libc)
except Exception as e:
error = e
error.tb = traceback.format_exc()
msg = pickle.dumps(error)
os.write(wctl, struct.pack('I', len(msg)))
os.write(wctl, msg)
os._exit(0)
else:
# parent
msglen = struct.unpack('I', os.read(rctl, 4))[0]
error = pickle.loads(os.read(rctl, msglen))
os.close(rctl)
os.close(wctl)
os.waitpid(pid, 0)
if error is not None:
raise error
proc = ChildProcess(target=_create, args=[netns, libc])
proc.run()
proc.communicate()
proc.stop(kill=True)
proc.close()


def attach(netns, pid, libc=None):
Expand All @@ -317,7 +302,7 @@ def remove(netns, libc=None):
os.unlink(netnspath)


def setns(netns, flags=os.O_CREAT, libc=None):
def setns(netns, flags=os.O_CREAT, libc=None, fork=True):
'''
Set netns for the current process.
Expand Down Expand Up @@ -346,7 +331,10 @@ def setns(netns, flags=os.O_CREAT, libc=None):
raise OSError(errno.EEXIST, 'netns exists', netns)
else:
if flags & os.O_CREAT:
create(netns, libc=libc)
if fork:
create(netns, libc=libc)
else:
_create(netns, libc=libc)
nsfd = os.open(netnspath, os.O_RDONLY)
newfd = True
elif isinstance(netns, file):
Expand Down Expand Up @@ -403,3 +391,39 @@ def dropns(libc=None):
os.close(fd)
except Exception:
pass


def _create_socket_child(nsname, flags, family, socket_type, proto, libc=None):
setns(nsname, flags=flags, libc=libc, fork=False)
sock = socket.socket(family, socket_type, proto)
return ChildProcessReturnValue(None, [sock])


def create_socket(
netns=None,
family=socket.AF_INET,
socket_type=socket.SOCK_STREAM,
proto=0,
fileno=None,
flags=os.O_CREAT,
libc=None,
timeout=5,
):
if fileno is not None and netns is not None:
raise TypeError('you can not specify both fileno and netns')
if fileno is not None:
return socket.socket(fileno=fileno)
if netns is None:
return socket.socket(family, socket_type, proto)

start_time = time.time()
while time.time() - start_time < 5:
with ChildProcess(
target=_create_socket_child,
args=[netns, flags, family, socket_type, proto, libc],
) as proc:
if (fds := proc.communicate(timeout=1)) is None:
continue
return socket.socket(fileno=fds[0])

raise TimeoutError('could not start netns socket within timeout')
Loading

0 comments on commit 0006ffe

Please sign in to comment.