Skip to content

Commit

Permalink
feat: implement dgram for unix socket
Browse files Browse the repository at this point in the history
  • Loading branch information
HeartLinked committed Jan 18, 2025
1 parent 06589bf commit be727e9
Show file tree
Hide file tree
Showing 2 changed files with 383 additions and 92 deletions.
139 changes: 110 additions & 29 deletions api/ruxos_posix_api/src/imp/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,24 @@ fn addrun_convert(addr: *const ctypes::sockaddr_un) -> SocketAddrUnix {
}
}

fn addrun_convert_withlen(addr: *const ctypes::sockaddr_un, addrlen: usize) -> SocketAddrUnix {
unsafe {
let sun_family = *(addr as *const u16);
let mut sun_path_array = [0i8; 108];
if addrlen > 2 {
let len = (addrlen - 2).min(108);
let src = (addr as *const u8).add(2) as *const i8;
let dst = sun_path_array.as_mut_ptr();
core::ptr::copy_nonoverlapping(src, dst, len);
}
SocketAddrUnix {
sun_family,
sun_path: sun_path_array,
}
}
}

#[derive(Debug)]
pub enum UnifiedSocketAddress {
Net(SocketAddr),
Unix(SocketAddrUnix),
Expand Down Expand Up @@ -109,16 +127,14 @@ impl Socket {
let addr = from_sockaddr(socket_addr, addrlen)?;
Ok(tcpsocket.lock().bind(addr)?)
}
Socket::Unix(socket) => {
Socket::Unix(unixsocket) => {
if socket_addr.is_null() {
return Err(LinuxError::EFAULT);
}
if addrlen != size_of::<ctypes::sockaddr_un>() as _ {
return Err(LinuxError::EINVAL);
}
Ok(socket
.lock()
.bind(addrun_convert(socket_addr as *const ctypes::sockaddr_un))?)
Ok(unixsocket.lock().bind(addrun_convert_withlen(
socket_addr as *const ctypes::sockaddr_un,
addrlen.try_into().unwrap(),
))?)
}
}
}
Expand All @@ -141,34 +157,68 @@ impl Socket {
if socket_addr.is_null() {
return Err(LinuxError::EFAULT);
}
if addrlen != size_of::<ctypes::sockaddr_un>() as _ {
return Err(LinuxError::EINVAL);
}
Ok(socket
.lock()
.connect(addrun_convert(socket_addr as *const ctypes::sockaddr_un))?)
Ok(socket.lock().connect(addrun_convert_withlen(
socket_addr as *const ctypes::sockaddr_un,
addrlen.try_into().unwrap(),
))?)
}
}
}

fn sendto(&self, buf: &[u8], addr: SocketAddr) -> LinuxResult<usize> {
fn sendto(
&self,
buf: &[u8],
socket_addr: *const ctypes::sockaddr,
addrlen: ctypes::socklen_t,
) -> LinuxResult<usize> {
match self {
// diff: must bind before sendto
Socket::Udp(udpsocket) => Ok(udpsocket.lock().send_to(buf, addr)?),
Socket::Udp(udpsocket) => {
let addr = from_sockaddr(socket_addr, addrlen)?;
Ok(udpsocket.lock().send_to(buf, addr)?)
}
Socket::Tcp(_) => Err(LinuxError::EISCONN),
Socket::Unix(_) => Err(LinuxError::EISCONN),
Socket::Unix(unixsocket) => {
if socket_addr.is_null() {
return Err(LinuxError::EFAULT);
}
Ok(unixsocket.lock().sendto(
buf,
addrun_convert_withlen(
socket_addr as *const ctypes::sockaddr_un,
addrlen.try_into().unwrap(),
),
)?)
}
}
}

fn recvfrom(&self, buf: &mut [u8]) -> LinuxResult<(usize, Option<SocketAddr>)> {
fn recvfrom(&self, buf: &mut [u8]) -> LinuxResult<(usize, Option<UnifiedSocketAddress>)> {
match self {
// diff: must bind before recvfrom
Socket::Udp(udpsocket) => Ok(udpsocket
.lock()
.recv_from(buf)
.map(|res| (res.0, Some(res.1)))?),
Socket::Tcp(tcpsocket) => Ok(tcpsocket.lock().recv(buf, 0).map(|res| (res, None))?),
Socket::Unix(socket) => Ok(socket.lock().recv(buf, 0).map(|res| (res, None))?),
Socket::Udp(udpsocket) => {
let (size, addr) = udpsocket.lock().recv_from(buf)?;
Ok((size, Some(UnifiedSocketAddress::Net(addr))))
}
Socket::Tcp(tcpsocket) => {
let size = tcpsocket.lock().recv(buf, 0)?;
Ok((size, None))
}
Socket::Unix(unixsocket) => {
let guard = unixsocket.lock();
match guard.get_sockettype() {
// diff: must bind before recvfrom
UnixSocketType::SockDgram => {
let (size, addr) = guard.recvfrom(buf)?;
Ok((size, addr.map(UnifiedSocketAddress::Unix)))
}
UnixSocketType::SockStream => {
let size = guard.recv(buf, 0)?;
Ok((size, None))
}
_ => Err(LinuxError::EOPNOTSUPP),
}
}
}
}

Expand Down Expand Up @@ -358,6 +408,10 @@ pub fn sys_socket(domain: c_int, socktype: c_int, protocol: c_int) -> c_int {
Socket::Unix(Mutex::new(UnixSocket::new(UnixSocketType::SockStream)))
.add_to_fd_table()
}
(ctypes::AF_UNIX, ctypes::SOCK_DGRAM, 0) => {
Socket::Unix(Mutex::new(UnixSocket::new(UnixSocketType::SockDgram)))
.add_to_fd_table()
}
_ => Err(LinuxError::EINVAL),
}
})
Expand Down Expand Up @@ -432,16 +486,15 @@ pub fn sys_sendto(
socket_fd, buf_ptr as usize, len, flag, socket_addr as usize, addrlen
);
if socket_addr.is_null() {
debug!("sendto without address, use send instead");
return sys_send(socket_fd, buf_ptr, len, flag);
}

syscall_body!(sys_sendto, {
if buf_ptr.is_null() {
return Err(LinuxError::EFAULT);
}
let addr = from_sockaddr(socket_addr, addrlen)?;
let buf = unsafe { core::slice::from_raw_parts(buf_ptr as *const u8, len) };
Socket::from_fd(socket_fd)?.sendto(buf, addr)
Socket::from_fd(socket_fd)?.sendto(buf, socket_addr, addrlen)
})
}

Expand All @@ -455,7 +508,7 @@ pub fn sys_send(
flag: c_int, // currently not used
) -> ctypes::ssize_t {
debug!(
"sys_sendto <= {} {:#x} {} {}",
"sys_send <= {} {:#x} {} {}",
socket_fd, buf_ptr as usize, len, flag
);
syscall_body!(sys_send, {
Expand Down Expand Up @@ -483,20 +536,47 @@ pub unsafe fn sys_recvfrom(
socket_fd, buf_ptr as usize, len, flag, socket_addr as usize, addrlen as usize
);
if socket_addr.is_null() {
debug!("recvfrom without address, use recv instead");
return sys_recv(socket_fd, buf_ptr, len, flag);
}

syscall_body!(sys_recvfrom, {
if buf_ptr.is_null() || addrlen.is_null() {
warn!("recvfrom with null buffer or addrlen");
return Err(LinuxError::EFAULT);
}
let socket = Socket::from_fd(socket_fd)?;
let buf = unsafe { core::slice::from_raw_parts_mut(buf_ptr as *mut u8, len) };

let res = socket.recvfrom(buf)?;
if let Some(addr) = res.1 {
unsafe {
(*socket_addr, *addrlen) = into_sockaddr(addr);
match addr {
UnifiedSocketAddress::Net(addr) => unsafe {
(*socket_addr, *addrlen) = into_sockaddr(addr);
},
UnifiedSocketAddress::Unix(addr) => unsafe {
let sockaddr_un_size = addr.get_addr_len();
let sockaddr_un = SocketAddrUnix {
sun_family: 1 as u16, // AF_UNIX
sun_path: addr.sun_path,
};
let original_addrlen = *addrlen as usize;
*addrlen = sockaddr_un_size as ctypes::socklen_t;
if original_addrlen < sockaddr_un_size {
warn!("Provided addr buf is too small, returned address will be truncated");
core::ptr::copy_nonoverlapping(
&sockaddr_un as *const SocketAddrUnix as *const u8,
socket_addr as *mut u8,
original_addrlen,
);
} else {
core::ptr::copy_nonoverlapping(
&sockaddr_un as *const SocketAddrUnix as *const u8,
socket_addr as *mut u8,
sockaddr_un_size,
);
}
},
}
}
Ok(res.0)
Expand Down Expand Up @@ -721,6 +801,7 @@ pub fn sys_getsockopt(
);
}
syscall_body!(sys_getsockopt, {
return Ok(0);
if optval.is_null() {
return Err(LinuxError::EFAULT);
}
Expand Down
Loading

0 comments on commit be727e9

Please sign in to comment.