Skip to content

Commit

Permalink
Use different flags for IPV6_RECVERR and IP_RECVERR
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 446361103
  • Loading branch information
ghananigans authored and gvisor-bot committed May 4, 2022
1 parent 13cc10b commit da0c67b
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 20 deletions.
8 changes: 4 additions & 4 deletions pkg/sentry/socket/netstack/netstack.go
Original file line number Diff line number Diff line change
Expand Up @@ -1447,7 +1447,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrInvalidArgument
}

v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError()))
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetIPv6RecvError()))
return &v, nil

case linux.IPV6_RECVORIGDSTADDR:
Expand Down Expand Up @@ -1647,7 +1647,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrInvalidArgument
}

v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError()))
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetIPv4RecvError()))
return &v, nil

case linux.IP_PKTINFO:
Expand Down Expand Up @@ -2338,7 +2338,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
if err != nil {
return err
}
ep.SocketOptions().SetRecvError(v != 0)
ep.SocketOptions().SetIPv6RecvError(v != 0)
return nil

case linux.IP6T_SO_SET_REPLACE:
Expand Down Expand Up @@ -2555,7 +2555,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
ep.SocketOptions().SetRecvError(v != 0)
ep.SocketOptions().SetIPv4RecvError(v != 0)
return nil

case linux.IP_PKTINFO:
Expand Down
37 changes: 27 additions & 10 deletions pkg/tcpip/socketops.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,13 @@ type SocketOptions struct {
// the incoming packet should be returned as an ancillary message.
receiveOriginalDstAddress atomicbitops.Uint32

// recvErrEnabled determines whether extended reliable error message passing
// is enabled.
recvErrEnabled atomicbitops.Uint32
// ipv4RecvErrEnabled determines whether extended reliable error message
// passing is enabled for IPv4.
ipv4RecvErrEnabled atomicbitops.Uint32

// ipv6RecvErrEnabled determines whether extended reliable error message
// passing is enabled for IPv6.
ipv6RecvErrEnabled atomicbitops.Uint32

// errQueue is the per-socket error queue. It is protected by errQueueMu.
errQueueMu sync.Mutex `state:"nosave"`
Expand Down Expand Up @@ -470,14 +474,27 @@ func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) {
storeAtomicBool(&so.receiveOriginalDstAddress, v)
}

// GetRecvError gets value for IP*_RECVERR option.
func (so *SocketOptions) GetRecvError() bool {
return so.recvErrEnabled.Load() != 0
// GetIPv4RecvError gets value for IP_RECVERR option.
func (so *SocketOptions) GetIPv4RecvError() bool {
return so.ipv4RecvErrEnabled.Load() != 0
}

// SetIPv4RecvError sets value for IP_RECVERR option.
func (so *SocketOptions) SetIPv4RecvError(v bool) {
storeAtomicBool(&so.ipv4RecvErrEnabled, v)
if !v {
so.pruneErrQueue()
}
}

// GetIPv6RecvError gets value for IPV6_RECVERR option.
func (so *SocketOptions) GetIPv6RecvError() bool {
return so.ipv6RecvErrEnabled.Load() != 0
}

// SetRecvError sets value for IP*_RECVERR option.
func (so *SocketOptions) SetRecvError(v bool) {
storeAtomicBool(&so.recvErrEnabled, v)
// SetIPv6RecvError sets value for IPV6_RECVERR option.
func (so *SocketOptions) SetIPv6RecvError(v bool) {
storeAtomicBool(&so.ipv6RecvErrEnabled, v)
if !v {
so.pruneErrQueue()
}
Expand Down Expand Up @@ -627,7 +644,7 @@ func (so *SocketOptions) PeekErr() *SockError {

// QueueErr inserts the error at the back of the error queue.
//
// Preconditions: so.GetRecvError() == true.
// Preconditions: so.GetIPv4RecvError() or so.GetIPv6RecvError() is true.
func (so *SocketOptions) QueueErr(err *SockError) {
so.errQueueMu.Lock()
defer so.errQueueMu.Unlock()
Expand Down
3 changes: 2 additions & 1 deletion pkg/tcpip/tests/integration/link_resolution_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,8 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
defer clientEP.Close()

sockOpts := clientEP.SocketOptions()
sockOpts.SetRecvError(true)
sockOpts.SetIPv4RecvError(true)
sockOpts.SetIPv6RecvError(true)

remoteAddr := listenerAddr
remoteAddr.Addr = test.remoteAddr
Expand Down
13 changes: 11 additions & 2 deletions pkg/tcpip/transport/tcp/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -2840,8 +2840,17 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p
e.lastError = err
e.lastErrorMu.Unlock()

// Update the error queue if IP_RECVERR is enabled.
if e.SocketOptions().GetRecvError() {
var recvErr bool
switch pkt.NetworkProtocolNumber {
case header.IPv4ProtocolNumber:
recvErr = e.SocketOptions().GetIPv4RecvError()
case header.IPv6ProtocolNumber:
recvErr = e.SocketOptions().GetIPv6RecvError()
default:
panic(fmt.Sprintf("unhandled network protocol number = %d", pkt.NetworkProtocolNumber))
}

if recvErr {
e.SocketOptions().QueueErr(&tcpip.SockError{
Err: err,
Cause: transErr,
Expand Down
15 changes: 12 additions & 3 deletions pkg/tcpip/transport/udp/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ func (e *endpoint) prepareForWrite(p tcpip.Payloader, opts tcpip.WriteOptions) (
// errors aren't report to the error queue at all.
if ctx.PacketInfo().NetProto == header.IPv6ProtocolNumber {
so := e.SocketOptions()
if so.GetRecvError() {
if so.GetIPv6RecvError() {
so.QueueLocalErr(
&tcpip.ErrMessageTooLong{},
e.net.NetProto(),
Expand Down Expand Up @@ -996,8 +996,17 @@ func (e *endpoint) onICMPError(err tcpip.Error, transErr stack.TransportError, p
e.lastError = err
e.lastErrorMu.Unlock()

// Update the error queue if IP_RECVERR is enabled.
if e.SocketOptions().GetRecvError() {
var recvErr bool
switch pkt.NetworkProtocolNumber {
case header.IPv4ProtocolNumber:
recvErr = e.SocketOptions().GetIPv4RecvError()
case header.IPv6ProtocolNumber:
recvErr = e.SocketOptions().GetIPv6RecvError()
default:
panic(fmt.Sprintf("unhandled network protocol number = %d", pkt.NetworkProtocolNumber))
}

if recvErr {
// Linux passes the payload without the UDP header.
var payload []byte
udp := header.UDP(pkt.Data().AsRange().ToOwnedView())
Expand Down
50 changes: 50 additions & 0 deletions test/syscalls/linux/udp_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,48 @@ TEST_P(UdpSocketTest, ConnectAndSendNoReceiver) {
}

#ifdef __linux__
TEST_P(UdpSocketTest, RecvErrorConnRefusedOtherAFSockOpt) {
int got;
socklen_t got_len = sizeof(got);
if (GetParam() == AF_INET) {
EXPECT_THAT(setsockopt(sock_.get(), SOL_IPV6, IPV6_RECVERR, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallFailsWithErrno(ENOPROTOOPT));
EXPECT_THAT(getsockopt(sock_.get(), SOL_IPV6, IPV6_RECVERR, &got, &got_len),
SyscallFailsWithErrno(ENOTSUP));
ASSERT_THAT(got_len, sizeof(got));
return;
}
ASSERT_THAT(setsockopt(sock_.get(), SOL_IP, IP_RECVERR, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceeds());
{
EXPECT_THAT(getsockopt(sock_.get(), SOL_IP, IP_RECVERR, &got, &got_len),
SyscallSucceeds());
ASSERT_THAT(got_len, sizeof(got));
EXPECT_THAT(got, kSockOptOn);
}

// We will simulate an ICMP error and verify that we don't receive that error
// via recvmsg(MSG_ERRQUEUE) since we set another address family's RECVERR
// flag.
ASSERT_NO_ERRNO(BindLoopback());
ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
// Close the bind socket to release the port so that we get an ICMP error
// when sending packets to it.
ASSERT_THAT(close(bind_.release()), SyscallSucceeds());

// Send to an unbound port which should trigger a port unreachable error.
char buf[1];
EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0),
SyscallSucceedsWithValue(sizeof(buf)));

// Should not have the error since we did not set the right socket option.
msghdr msg = {};
EXPECT_THAT(recvmsg(sock_.get(), &msg, MSG_ERRQUEUE),
SyscallFailsWithErrno(EAGAIN));
}

TEST_P(UdpSocketTest, RecvErrorConnRefused) {
// We will simulate an ICMP error and verify that we do receive that error via
// recvmsg(MSG_ERRQUEUE).
Expand All @@ -829,6 +871,14 @@ TEST_P(UdpSocketTest, RecvErrorConnRefused) {
}
ASSERT_THAT(setsockopt(sock_.get(), opt_level, opt_type, &v, optlen),
SyscallSucceeds());
{
int got;
socklen_t got_len = sizeof(got);
EXPECT_THAT(getsockopt(sock_.get(), opt_level, opt_type, &got, &got_len),
SyscallSucceeds());
ASSERT_THAT(got_len, sizeof(got));
EXPECT_THAT(got, kSockOptOn);
}

// Connect to loopback:bind_addr_ which should *hopefully* not be bound by an
// UDP socket. There is no easy way to ensure that the UDP port is not bound
Expand Down

0 comments on commit da0c67b

Please sign in to comment.