diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 6e917696dd..daac7bc9cc 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1447,7 +1447,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrInvalidArgument } - v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError())) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetIPv6RecvError())) return &v, nil case linux.IPV6_RECVORIGDSTADDR: @@ -1647,7 +1647,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrInvalidArgument } - v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError())) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetIPv4RecvError())) return &v, nil case linux.IP_PKTINFO: @@ -2338,7 +2338,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name if err != nil { return err } - ep.SocketOptions().SetRecvError(v != 0) + ep.SocketOptions().SetIPv6RecvError(v != 0) return nil case linux.IP6T_SO_SET_REPLACE: @@ -2555,7 +2555,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - ep.SocketOptions().SetRecvError(v != 0) + ep.SocketOptions().SetIPv4RecvError(v != 0) return nil case linux.IP_PKTINFO: diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index b57c0fb38b..1bc3723b10 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -212,9 +212,13 @@ type SocketOptions struct { // the incoming packet should be returned as an ancillary message. receiveOriginalDstAddress atomicbitops.Uint32 - // recvErrEnabled determines whether extended reliable error message passing - // is enabled. - recvErrEnabled atomicbitops.Uint32 + // ipv4RecvErrEnabled determines whether extended reliable error message + // passing is enabled for IPv4. + ipv4RecvErrEnabled atomicbitops.Uint32 + + // ipv6RecvErrEnabled determines whether extended reliable error message + // passing is enabled for IPv6. + ipv6RecvErrEnabled atomicbitops.Uint32 // errQueue is the per-socket error queue. It is protected by errQueueMu. errQueueMu sync.Mutex `state:"nosave"` @@ -470,14 +474,27 @@ func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) { storeAtomicBool(&so.receiveOriginalDstAddress, v) } -// GetRecvError gets value for IP*_RECVERR option. -func (so *SocketOptions) GetRecvError() bool { - return so.recvErrEnabled.Load() != 0 +// GetIPv4RecvError gets value for IP_RECVERR option. +func (so *SocketOptions) GetIPv4RecvError() bool { + return so.ipv4RecvErrEnabled.Load() != 0 +} + +// SetIPv4RecvError sets value for IP_RECVERR option. +func (so *SocketOptions) SetIPv4RecvError(v bool) { + storeAtomicBool(&so.ipv4RecvErrEnabled, v) + if !v { + so.pruneErrQueue() + } +} + +// GetIPv6RecvError gets value for IPV6_RECVERR option. +func (so *SocketOptions) GetIPv6RecvError() bool { + return so.ipv6RecvErrEnabled.Load() != 0 } -// SetRecvError sets value for IP*_RECVERR option. -func (so *SocketOptions) SetRecvError(v bool) { - storeAtomicBool(&so.recvErrEnabled, v) +// SetIPv6RecvError sets value for IPV6_RECVERR option. +func (so *SocketOptions) SetIPv6RecvError(v bool) { + storeAtomicBool(&so.ipv6RecvErrEnabled, v) if !v { so.pruneErrQueue() } @@ -627,7 +644,7 @@ func (so *SocketOptions) PeekErr() *SockError { // QueueErr inserts the error at the back of the error queue. // -// Preconditions: so.GetRecvError() == true. +// Preconditions: so.GetIPv4RecvError() or so.GetIPv6RecvError() is true. func (so *SocketOptions) QueueErr(err *SockError) { so.errQueueMu.Lock() defer so.errQueueMu.Unlock() diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 974bee0a62..a6f4b4b578 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -324,7 +324,8 @@ func TestTCPLinkResolutionFailure(t *testing.T) { defer clientEP.Close() sockOpts := clientEP.SocketOptions() - sockOpts.SetRecvError(true) + sockOpts.SetIPv4RecvError(true) + sockOpts.SetIPv6RecvError(true) remoteAddr := listenerAddr remoteAddr.Addr = test.remoteAddr diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index c3d73be89e..5e9f9adb9e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2840,8 +2840,17 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p e.lastError = err e.lastErrorMu.Unlock() - // Update the error queue if IP_RECVERR is enabled. - if e.SocketOptions().GetRecvError() { + var recvErr bool + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + recvErr = e.SocketOptions().GetIPv4RecvError() + case header.IPv6ProtocolNumber: + recvErr = e.SocketOptions().GetIPv6RecvError() + default: + panic(fmt.Sprintf("unhandled network protocol number = %d", pkt.NetworkProtocolNumber)) + } + + if recvErr { e.SocketOptions().QueueErr(&tcpip.SockError{ Err: err, Cause: transErr, diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 17fd766cdd..3b5f4e5a6a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -401,7 +401,7 @@ func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) ( // errors aren't report to the error queue at all. if ctx.PacketInfo().NetProto == header.IPv6ProtocolNumber { so := e.SocketOptions() - if so.GetRecvError() { + if so.GetIPv6RecvError() { so.QueueLocalErr( &tcpip.ErrMessageTooLong{}, e.net.NetProto(), @@ -996,8 +996,17 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p e.lastError = err e.lastErrorMu.Unlock() - // Update the error queue if IP_RECVERR is enabled. - if e.SocketOptions().GetRecvError() { + var recvErr bool + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + recvErr = e.SocketOptions().GetIPv4RecvError() + case header.IPv6ProtocolNumber: + recvErr = e.SocketOptions().GetIPv6RecvError() + default: + panic(fmt.Sprintf("unhandled network protocol number = %d", pkt.NetworkProtocolNumber)) + } + + if recvErr { // Linux passes the payload without the UDP header. var payload []byte udp := header.UDP(pkt.Data().AsRange().ToOwnedView()) diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 7fd237eb22..d16a4ba840 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -811,6 +811,48 @@ TEST_P(UdpSocketTest, ConnectAndSendNoReceiver) { } #ifdef __linux__ +TEST_P(UdpSocketTest, RecvErrorConnRefusedOtherAFSockOpt) { + int got; + socklen_t got_len = sizeof(got); + if (GetParam() == AF_INET) { + EXPECT_THAT(setsockopt(sock_.get(), SOL_IPV6, IPV6_RECVERR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallFailsWithErrno(ENOPROTOOPT)); + EXPECT_THAT(getsockopt(sock_.get(), SOL_IPV6, IPV6_RECVERR, &got, &got_len), + SyscallFailsWithErrno(ENOTSUP)); + ASSERT_THAT(got_len, sizeof(got)); + return; + } + ASSERT_THAT(setsockopt(sock_.get(), SOL_IP, IP_RECVERR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + { + EXPECT_THAT(getsockopt(sock_.get(), SOL_IP, IP_RECVERR, &got, &got_len), + SyscallSucceeds()); + ASSERT_THAT(got_len, sizeof(got)); + EXPECT_THAT(got, kSockOptOn); + } + + // We will simulate an ICMP error and verify that we don't receive that error + // via recvmsg(MSG_ERRQUEUE) since we set another address family's RECVERR + // flag. + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + // Close the bind socket to release the port so that we get an ICMP error + // when sending packets to it. + ASSERT_THAT(close(bind_.release()), SyscallSucceeds()); + + // Send to an unbound port which should trigger a port unreachable error. + char buf[1]; + EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + + // Should not have the error since we did not set the right socket option. + msghdr msg = {}; + EXPECT_THAT(recvmsg(sock_.get(), &msg, MSG_ERRQUEUE), + SyscallFailsWithErrno(EAGAIN)); +} + TEST_P(UdpSocketTest, RecvErrorConnRefused) { // We will simulate an ICMP error and verify that we do receive that error via // recvmsg(MSG_ERRQUEUE). @@ -829,6 +871,14 @@ TEST_P(UdpSocketTest, RecvErrorConnRefused) { } ASSERT_THAT(setsockopt(sock_.get(), opt_level, opt_type, &v, optlen), SyscallSucceeds()); + { + int got; + socklen_t got_len = sizeof(got); + EXPECT_THAT(getsockopt(sock_.get(), opt_level, opt_type, &got, &got_len), + SyscallSucceeds()); + ASSERT_THAT(got_len, sizeof(got)); + EXPECT_THAT(got, kSockOptOn); + } // Connect to loopback:bind_addr_ which should *hopefully* not be bound by an // UDP socket. There is no easy way to ensure that the UDP port is not bound