diff --git a/registration-api/main.go b/registration-api/main.go index cc18d88b..c79996f3 100644 --- a/registration-api/main.go +++ b/registration-api/main.go @@ -8,9 +8,11 @@ import ( "net" "net/http" "os" + "os/signal" "strconv" "strings" "sync" + "syscall" "github.com/BurntSushi/toml" "github.com/golang/protobuf/proto" @@ -191,51 +193,108 @@ func parseIP(addrPort string) *net.IP { } -func main() { - var s server - s.logger = log.New(os.Stdout, "[API] ", log.Ldate|log.Lmicroseconds) - s.messageAccepter = s.sendToZMQ +// setupReloadHandler spawns a lightweight thread to listen for reload signals +// and loads updated configurations for the registration-api process everytime +// the reload signal is received +func (s *server) setupReloadHandler() { + signalChan := make(chan os.Signal, 1) + + signal.Notify( + signalChan, + syscall.SIGHUP, // listen for SIGHUP as reload signal + ) + + // spawn a goroutine to handle os signals + go func() { + for { + <-signalChan + s.loadNewConfig() + } + }() +} - _, err := toml.DecodeFile(os.Getenv("CJ_API_CONFIG"), &s) +// loadNewConfig reads configuration for registration-api, and updates the in-memory config +// all configs other than the ports can be reloaded. Updating the port of ZMQ socket and/or +// the application should require a restart. +func (s *server) loadNewConfig() { + var newConfig *config + + _, err := toml.DecodeFile(os.Getenv("CJ_API_CONFIG"), &newConfig) if err != nil { s.logger.Fatalln("failed to load config:", err) } - // Should we log client IP addresses - s.logClientIP, err = strconv.ParseBool(os.Getenv("LOG_CLIENT_IP")) - if err != nil { - s.logger.Printf("failed parse client ip logging setting: %v\n", err) - s.logClientIP = false - } + privkey := s.getPrivKey(newConfig) - sock, err := zmq.NewSocket(zmq.PUB) + s.Lock() + defer s.Unlock() + + // AuthStart() is not idempotent, must explictly stop auth before updating curve + zmq.AuthStop() + // update the auth curve of the ZMQ socket without creating a new one + s.setupAuth(newConfig, privkey, s.sock) +} + +// getPrivKey reads the content of the private key on disk, truncates the key content +// and returns the truncated key as a string in Z85 format +func (s *server) getPrivKey(conf *config) string { + // always read from key path everytime this function is called because + // even if the key path stays the same, the key content may have changed + privkeyBytes, err := ioutil.ReadFile(conf.PrivateKeyPath) if err != nil { - s.logger.Fatalln("failed to create zmq socket:", err) + s.logger.Fatalln("failed to get private key:", err) } - if s.AuthType == "CURVE" { - privkeyBytes, err := ioutil.ReadFile(s.PrivateKeyPath) - if err != nil { - s.logger.Fatalln("failed to get private key:", err) - } + privkey := zmq.Z85encode(string(privkeyBytes[:32])) - privkey := zmq.Z85encode(string(privkeyBytes[:32])) + return privkey +} + +// setupAuth resets the auth settings based on the configuration +func (s *server) setupAuth(conf *config, privkey string, sock *zmq.Socket) { + if conf.AuthType == "CURVE" { + zmq.AuthSetVerbose(conf.AuthVerbose) - zmq.AuthSetVerbose(s.AuthVerbose) - err = zmq.AuthStart() + err := zmq.AuthStart() if err != nil { s.logger.Fatalln("failed to start zmq auth:", err) } - s.logger.Println(s.StationPublicKeys) + s.logger.Println(conf.StationPublicKeys) zmq.AuthAllow("*") - zmq.AuthCurveAdd("*", s.StationPublicKeys...) + zmq.AuthCurveAdd("*", conf.StationPublicKeys...) err = sock.ServerAuthCurve("*", privkey) if err != nil { s.logger.Fatalln("failed to set up auth on zmq socket:", err) } } +} + +func main() { + var s server + s.logger = log.New(os.Stdout, "[API] ", log.Ldate|log.Lmicroseconds) + s.messageAccepter = s.sendToZMQ + + _, err := toml.DecodeFile(os.Getenv("CJ_API_CONFIG"), &s) + if err != nil { + s.logger.Fatalln("failed to load config:", err) + } + + // Should we log client IP addresses + s.logClientIP, err = strconv.ParseBool(os.Getenv("LOG_CLIENT_IP")) + if err != nil { + s.logger.Printf("failed parse client ip logging setting: %v\n", err) + s.logClientIP = false + } + + sock, err := zmq.NewSocket(zmq.PUB) + if err != nil { + s.logger.Fatalln("failed to create zmq socket:", err) + } + + privkey := s.getPrivKey(&s.config) + s.setupAuth(&s.config, privkey, sock) err = sock.Bind(fmt.Sprintf("tcp://*:%d", s.ZMQPort)) if err != nil { @@ -251,5 +310,7 @@ func main() { r.HandleFunc("/register", s.register) http.Handle("/", r) + s.setupReloadHandler() + s.logger.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", s.APIPort), nil)) } diff --git a/sysconfig/conjure-registration-api.service b/sysconfig/conjure-registration-api.service index 4ac3978d..151aa942 100644 --- a/sysconfig/conjure-registration-api.service +++ b/sysconfig/conjure-registration-api.service @@ -11,6 +11,9 @@ EnvironmentFile=/opt/conjure/sysconfig/conjure.conf ExecStart=/opt/conjure/registration-api/registration-api +# send SIGHUP to the process running API service on systemd reload +ExecReload=/bin/kill -HUP $MAINPID + # on stop processes will get SIGTERM, and after 10 secs - SIGKILL (default 90) TimeoutStopSec=10