diff --git a/pkg/station/lib/registration_ingest.go b/pkg/station/lib/registration_ingest.go index 35845ca1..09cbab0c 100644 --- a/pkg/station/lib/registration_ingest.go +++ b/pkg/station/lib/registration_ingest.go @@ -1,6 +1,7 @@ package lib import ( + "encoding/binary" "bytes" "context" "errors" @@ -418,6 +419,10 @@ func (rm *RegistrationManager) NewRegistrationC2SWrapper(c2sw *pb.C2SWrapper, in // If a C2SWrapper has a registration response at this stage EITHER auth was disabled OR it was // signed by a registration server and has overrides that should be applied var dstPort = -1 + + // Used to apply phantom IP overrides from the registration response + var ipOverride net.IP + if rr := c2sw.GetRegistrationResponse(); rr != nil { if rr.DstPort != nil { dstPort = int(rr.GetDstPort()) @@ -428,8 +433,21 @@ func (rm *RegistrationManager) NewRegistrationC2SWrapper(c2sw *pb.C2SWrapper, in if rr.GetTransportParams() != nil && !c2s.GetDisableRegistrarOverrides() { c2s.TransportParams = rr.GetTransportParams() } + if !includeV6 { + // apply the ipv4 address from the registration response, if rr.Ipv4Addr is not empty + if rr.Ipv4Addr != nil && *rr.Ipv4Addr != 0 { + ipv4Bytes := make([]byte, 4) + binary.BigEndian.PutUint32(ipv4Bytes, *rr.Ipv4Addr) + ipOverride = net.IP(ipv4Bytes) + } + } else { + // apply the ipv6 address from the registration response, if rr.Ipv6Addr is not empty + if rr.Ipv6Addr != nil { + ipOverride = net.IP(rr.Ipv6Addr) + } + + } - // TODO: future, apply the ip addresses from the Registration response (rr.IPv4Addr, rr.IPv6Addr) } reg, err := rm.NewRegistration(c2s, &conjureKeys, includeV6, ®Src) @@ -437,6 +455,12 @@ func (rm *RegistrationManager) NewRegistrationC2SWrapper(c2sw *pb.C2SWrapper, in return nil, fmt.Errorf("failed to build registration: %w", err) } + if ipOverride != nil { + // If the ipOverride (which is populated by Ipv4Addr or Ipv6Addr from the registration response) + // is not empty, use it to override the phantom IP that the station derived + reg.PhantomIp = ipOverride + } + clientAddr := net.IP(c2sw.GetRegistrationAddress()) if reg.PhantomIp.To4() != nil && clientAddr.To4() == nil {