diff --git a/auth/host_session.go b/auth/host_session.go index 62e9d4387..7a3929240 100644 --- a/auth/host_session.go +++ b/auth/host_session.go @@ -222,7 +222,7 @@ func SessionHandler(conn *websocket.Conn) { if err = conn.WriteMessage(messageType, reponseData); err != nil { logger.Log(0, "error during message writing:", err.Error()) } - go CheckNetRegAndHostUpdate(netsToAdd[:], &result.Host, uuid.Nil) + go CheckNetRegAndHostUpdate(netsToAdd[:], &result.Host, uuid.Nil, []models.TagID{}) case <-timeout: // the read from req.answerCh has timed out logger.Log(0, "timeout signal recv,exiting oauth socket conn") break @@ -236,7 +236,7 @@ func SessionHandler(conn *websocket.Conn) { } // CheckNetRegAndHostUpdate - run through networks and send a host update -func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uuid.UUID) { +func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uuid.UUID, tags []models.TagID) { // publish host update through MQ for i := range networks { network := networks[i] @@ -246,6 +246,14 @@ func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uui logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, network, err.Error()) continue } + if len(tags) > 0 { + newNode.Tags = make(map[models.TagID]struct{}) + for _, tagI := range tags { + newNode.Tags[tagI] = struct{}{} + } + logic.UpsertNode(newNode) + } + if relayNodeId != uuid.Nil && !newNode.IsRelayed { // check if relay node exists and acting as relay relaynode, err := logic.GetNodeByID(relayNodeId.String()) diff --git a/controllers/controller.go b/controllers/controller.go index 75423b6e1..5a1396d3d 100644 --- a/controllers/controller.go +++ b/controllers/controller.go @@ -34,6 +34,7 @@ var HttpHandlers = []interface{}{ loggerHandlers, hostHandlers, enrollmentKeyHandlers, + tagHandlers, legacyHandlers, } diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index dc6669bd6..9d7fbe432 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -156,6 +156,7 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) { newTime, enrollmentKeyBody.Networks, enrollmentKeyBody.Tags, + enrollmentKeyBody.Groups, enrollmentKeyBody.Unlimited, relayId, ) @@ -206,7 +207,7 @@ func updateEnrollmentKey(w http.ResponseWriter, r *http.Request) { } } - newEnrollmentKey, err := logic.UpdateEnrollmentKey(keyId, relayId) + newEnrollmentKey, err := logic.UpdateEnrollmentKey(keyId, relayId, enrollmentKeyBody.Groups) if err != nil { slog.Error("failed to update enrollment key", "error", err) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) @@ -307,6 +308,7 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { return } } + if err = logic.CreateHost(&newHost); err != nil { logger.Log( 0, @@ -355,5 +357,5 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(&response) // notify host of changes, peer and node updates - go auth.CheckNetRegAndHostUpdate(enrollmentKey.Networks, &newHost, enrollmentKey.Relay) + go auth.CheckNetRegAndHostUpdate(enrollmentKey.Networks, &newHost, enrollmentKey.Relay, enrollmentKey.Groups) } diff --git a/controllers/network.go b/controllers/network.go index 03f40cb9e..d6ad45e6e 100644 --- a/controllers/network.go +++ b/controllers/network.go @@ -530,8 +530,9 @@ func createNetwork(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - logic.CreateDefaultNetworkRolesAndGroups(models.NetworkID(network.NetID)) + logic.CreateDefaultNetworkRolesAndGroups(models.NetworkID(network.NetID)) + logic.CreateDefaultTags(models.NetworkID(network.NetID)) //add new network to allocated ip map go logic.AddNetworkToAllocatedIpMap(network.NetID) diff --git a/controllers/node.go b/controllers/node.go index fd6f5d902..ab28ba96d 100644 --- a/controllers/node.go +++ b/controllers/node.go @@ -326,6 +326,7 @@ func getNetworkNodes(w http.ResponseWriter, r *http.Request) { if len(filteredNodes) > 0 { nodes = filteredNodes } + nodes = logic.AddStaticNodestoList(nodes) // returns all the nodes in JSON/API format apiNodes := logic.GetAllNodesAPI(nodes[:]) @@ -363,7 +364,9 @@ func getAllNodes(w http.ResponseWriter, r *http.Request) { if !userPlatformRole.FullAccess { nodes = logic.GetFilteredNodesByUserAccess(*user, nodes) } + } + nodes = logic.AddStaticNodestoList(nodes) // return all the nodes in JSON/API format apiNodes := logic.GetAllNodesAPI(nodes[:]) logger.Log(3, r.Header.Get("user"), "fetched all nodes they have access to") diff --git a/controllers/tags.go b/controllers/tags.go new file mode 100644 index 000000000..2def88f68 --- /dev/null +++ b/controllers/tags.go @@ -0,0 +1,199 @@ +package controller + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/gorilla/mux" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/models" +) + +func tagHandlers(r *mux.Router) { + r.HandleFunc("/api/v1/tags", logic.SecurityCheck(true, http.HandlerFunc(getTags))). + Methods(http.MethodGet) + r.HandleFunc("/api/v1/tags", logic.SecurityCheck(true, http.HandlerFunc(createTag))). + Methods(http.MethodPost) + r.HandleFunc("/api/v1/tags", logic.SecurityCheck(true, http.HandlerFunc(updateTag))). + Methods(http.MethodPut) + r.HandleFunc("/api/v1/tags", logic.SecurityCheck(true, http.HandlerFunc(deleteTag))). + Methods(http.MethodDelete) + +} + +// @Summary List Tags in a network +// @Router /api/v1/tags [get] +// @Tags TAG +// @Accept json +// @Success 200 {array} models.SuccessResponse +// @Failure 500 {object} models.ErrorResponse +func getTags(w http.ResponseWriter, r *http.Request) { + netID, _ := url.QueryUnescape(r.URL.Query().Get("network")) + if netID == "" { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network id param is missing"), "badrequest")) + return + } + // check if network exists + _, err := logic.GetNetwork(netID) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + tags, err := logic.ListTagsWithNodes(models.NetworkID(netID)) + if err != nil { + logger.Log(0, r.Header.Get("user"), "failed to get all network tag entries: ", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + logic.SortTagEntrys(tags[:]) + logic.ReturnSuccessResponseWithJson(w, r, tags, "fetched all tags in the network "+netID) +} + +// @Summary Create Tag +// @Router /api/v1/tags [post] +// @Tags TAG +// @Accept json +// @Success 200 {array} models.SuccessResponse +// @Failure 500 {object} models.ErrorResponse +func createTag(w http.ResponseWriter, r *http.Request) { + var req models.CreateTagReq + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + logger.Log(0, "error decoding request body: ", + err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + user, err := logic.GetUser(r.Header.Get("user")) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + // check if tag network exists + _, err = logic.GetNetwork(req.Network.String()) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("failed to get network details for "+req.Network.String()), "badrequest")) + return + } + // check if tag exists + tag := models.Tag{ + ID: models.TagID(fmt.Sprintf("%s.%s", req.Network, req.TagName)), + TagName: req.TagName, + Network: req.Network, + CreatedBy: user.UserName, + CreatedAt: time.Now(), + } + _, err = logic.GetTag(tag.ID) + if err == nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("tag with id %s exists already", tag.TagName), "badrequest")) + return + } + // validate name + err = logic.CheckIDSyntax(tag.TagName) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + err = logic.InsertTag(tag) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + go func() { + for _, node := range req.TaggedNodes { + if node.IsStatic { + extclient, err := logic.GetExtClient(node.StaticNode.ClientID, node.StaticNode.Network) + if err == nil && extclient.RemoteAccessClientID == "" { + if extclient.Tags == nil { + extclient.Tags = make(map[models.TagID]struct{}) + } + extclient.Tags[tag.ID] = struct{}{} + logic.SaveExtClient(&extclient) + } + continue + } + node, err := logic.GetNodeByID(node.ID) + if err != nil { + continue + } + if node.Tags == nil { + node.Tags = make(map[models.TagID]struct{}) + } + node.Tags[tag.ID] = struct{}{} + logic.UpsertNode(&node) + } + }() + + logic.ReturnSuccessResponseWithJson(w, r, req, "created tag successfully") +} + +// @Summary Update Tag +// @Router /api/v1/tags [put] +// @Tags TAG +// @Accept json +// @Success 200 {array} models.SuccessResponse +// @Failure 500 {object} models.ErrorResponse +func updateTag(w http.ResponseWriter, r *http.Request) { + var updateTag models.UpdateTagReq + err := json.NewDecoder(r.Body).Decode(&updateTag) + if err != nil { + logger.Log(0, "error decoding request body: ", + err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + + tag, err := logic.GetTag(updateTag.ID) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + updateTag.NewName = strings.TrimSpace(updateTag.NewName) + var newID models.TagID + if updateTag.NewName != "" { + // validate name + err = logic.CheckIDSyntax(updateTag.NewName) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + newID = models.TagID(fmt.Sprintf("%s.%s", tag.Network, updateTag.NewName)) + tag.ID = newID + tag.TagName = updateTag.NewName + err = logic.InsertTag(tag) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + // delete old Tag entry + logic.DeleteTag(updateTag.ID) + } + go logic.UpdateTag(updateTag, newID) + logic.ReturnSuccessResponse(w, r, "updating tags") +} + +// @Summary Delete Tag +// @Router /api/v1/tags [delete] +// @Tags TAG +// @Accept json +// @Success 200 {array} models.SuccessResponse +// @Failure 500 {object} models.ErrorResponse +func deleteTag(w http.ResponseWriter, r *http.Request) { + tagID, _ := url.QueryUnescape(r.URL.Query().Get("tag_id")) + if tagID == "" { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("role is required"), "badrequest")) + return + } + err := logic.DeleteTag(models.TagID(tagID)) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + logic.ReturnSuccessResponse(w, r, "deleted tag "+tagID) +} diff --git a/database/database.go b/database/database.go index f8508b3f0..2a950b6c3 100644 --- a/database/database.go +++ b/database/database.go @@ -67,6 +67,8 @@ const ( PENDING_USERS_TABLE_NAME = "pending_users" // USER_INVITES - table for user invites USER_INVITES_TABLE_NAME = "user_invites" + // TAG_TABLE_NAME - table for tags + TAG_TABLE_NAME = "tags" // == ERROR CONSTS == // NO_RECORD - no singular result found NO_RECORD = "no result found" @@ -152,6 +154,7 @@ func createTables() { CreateTable(PENDING_USERS_TABLE_NAME) CreateTable(USER_PERMISSIONS_TABLE_NAME) CreateTable(USER_INVITES_TABLE_NAME) + CreateTable(TAG_TABLE_NAME) } func CreateTable(tableName string) error { diff --git a/logic/enrollmentkey.go b/logic/enrollmentkey.go index d3c48a011..bf811a1a8 100644 --- a/logic/enrollmentkey.go +++ b/logic/enrollmentkey.go @@ -37,7 +37,7 @@ var ( ) // CreateEnrollmentKey - creates a new enrollment key in db -func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool, relay uuid.UUID) (*models.EnrollmentKey, error) { +func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID) (*models.EnrollmentKey, error) { newKeyID, err := getUniqueEnrollmentID() if err != nil { return nil, err @@ -51,6 +51,7 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string Tags: []string{}, Type: models.Undefined, Relay: relay, + Groups: groups, } if uses > 0 { k.UsesRemaining = uses @@ -89,7 +90,7 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string } // UpdateEnrollmentKey - updates an existing enrollment key's associated relay -func UpdateEnrollmentKey(keyId string, relayId uuid.UUID) (*models.EnrollmentKey, error) { +func UpdateEnrollmentKey(keyId string, relayId uuid.UUID, groups []models.TagID) (*models.EnrollmentKey, error) { key, err := GetEnrollmentKey(keyId) if err != nil { return nil, err @@ -109,7 +110,7 @@ func UpdateEnrollmentKey(keyId string, relayId uuid.UUID) (*models.EnrollmentKey } key.Relay = relayId - + key.Groups = groups if err = upsertEnrollmentKey(&key); err != nil { return nil, err } diff --git a/logic/enrollmentkey_test.go b/logic/enrollmentkey_test.go index 677c47141..5e63df167 100644 --- a/logic/enrollmentkey_test.go +++ b/logic/enrollmentkey_test.go @@ -14,35 +14,35 @@ func TestCreateEnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() t.Run("Can_Not_Create_Key", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false, uuid.Nil) + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, false, uuid.Nil) assert.Nil(t, newKey) assert.NotNil(t, err) assert.ErrorIs(t, err, models.ErrInvalidEnrollmentKey) }) t.Run("Can_Create_Key_Uses", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil) + newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil) assert.Nil(t, err) assert.Equal(t, 1, newKey.UsesRemaining) assert.True(t, newKey.IsValid()) }) t.Run("Can_Create_Key_Time", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, false, uuid.Nil) + newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, nil, false, uuid.Nil) assert.Nil(t, err) assert.True(t, newKey.IsValid()) }) t.Run("Can_Create_Key_Unlimited", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, true, uuid.Nil) + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil) assert.Nil(t, err) assert.True(t, newKey.IsValid()) }) t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil) + newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil) assert.Nil(t, err) assert.True(t, newKey.IsValid()) assert.True(t, len(newKey.Networks) == 2) }) t.Run("Can_Create_Key_WithTags", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, true, uuid.Nil) + newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, nil, true, uuid.Nil) assert.Nil(t, err) assert.True(t, newKey.IsValid()) assert.True(t, len(newKey.Tags) == 2) @@ -62,7 +62,7 @@ func TestCreateEnrollmentKey(t *testing.T) { func TestDelete_EnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil) + newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil) t.Run("Can_Delete_Key", func(t *testing.T) { assert.True(t, newKey.IsValid()) err := DeleteEnrollmentKey(newKey.Value) @@ -83,7 +83,7 @@ func TestDelete_EnrollmentKey(t *testing.T) { func TestDecrement_EnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil) + newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil) t.Run("Check_initial_uses", func(t *testing.T) { assert.True(t, newKey.IsValid()) assert.Equal(t, newKey.UsesRemaining, 1) @@ -107,9 +107,9 @@ func TestDecrement_EnrollmentKey(t *testing.T) { func TestUsability_EnrollmentKey(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil) - key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false, uuid.Nil) - key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true, uuid.Nil) + key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil) + key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, nil, false, uuid.Nil) + key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil) t.Run("Check if valid use key can be used", func(t *testing.T) { assert.Equal(t, key1.UsesRemaining, 1) ok := TryToUseEnrollmentKey(key1) @@ -145,7 +145,7 @@ func removeAllEnrollments() { func TestTokenize_EnrollmentKeys(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil) + newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil) const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5" const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9" const serverAddr = "api.myserver.com" @@ -178,7 +178,7 @@ func TestTokenize_EnrollmentKeys(t *testing.T) { func TestDeTokenize_EnrollmentKeys(t *testing.T) { database.InitializeDatabase() defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil) + newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil) const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9" const serverAddr = "api.myserver.com" diff --git a/logic/extpeers.go b/logic/extpeers.go index c619dde94..542f98b56 100644 --- a/logic/extpeers.go +++ b/logic/extpeers.go @@ -329,6 +329,7 @@ func UpdateExtClient(old *models.ExtClient, update *models.CustomExtClient) mode // replace any \r\n with \n in postup and postdown from HTTP request new.PostUp = strings.Replace(update.PostUp, "\r\n", "\n", -1) new.PostDown = strings.Replace(update.PostDown, "\r\n", "\n", -1) + new.Tags = update.Tags return new } @@ -528,3 +529,40 @@ func GetExtclientAllowedIPs(client models.ExtClient) (allowedIPs []string) { } return } + +func GetStaticNodesByNetwork(network models.NetworkID) (staticNode []models.Node) { + extClients, err := GetAllExtClients() + if err != nil { + return + } + for _, extI := range extClients { + if extI.Network == network.String() { + n := models.Node{ + IsStatic: true, + StaticNode: extI, + IsUserNode: extI.RemoteAccessClientID != "", + } + staticNode = append(staticNode, n) + } + } + + return +} + +func GetStaticNodesByGw(gwNode models.Node) (staticNode []models.Node) { + extClients, err := GetAllExtClients() + if err != nil { + return + } + for _, extI := range extClients { + if extI.IngressGatewayID == gwNode.ID.String() { + n := models.Node{ + IsStatic: true, + StaticNode: extI, + IsUserNode: extI.RemoteAccessClientID != "", + } + staticNode = append(staticNode, n) + } + } + return +} diff --git a/logic/gateway.go b/logic/gateway.go index b25d3e1d0..ba7521976 100644 --- a/logic/gateway.go +++ b/logic/gateway.go @@ -2,6 +2,7 @@ package logic import ( "errors" + "fmt" "time" "github.com/gravitl/netmaker/database" @@ -182,6 +183,7 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq if node.Metadata == "" { node.Metadata = "This host can be used for remote access" } + node.Tags[models.TagID(fmt.Sprintf("%s.%s", netid, models.RemoteAccessTagName))] = struct{}{} err = UpsertNode(&node) if err != nil { return models.Node{}, err @@ -257,6 +259,7 @@ func DeleteIngressGateway(nodeid string) (models.Node, []models.ExtClient, error if !servercfg.IsPro { node.IsInternetGateway = false } + delete(node.Tags, models.TagID(fmt.Sprintf("%s.%s", node.Network, models.RemoteAccessTagName))) node.IngressGatewayRange = "" node.Metadata = "" err = UpsertNode(&node) diff --git a/logic/nodes.go b/logic/nodes.go index e241c6ccc..2560b1ea6 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -378,6 +378,20 @@ func GetAllNodes() ([]models.Node, error) { return nodes, nil } +func AddStaticNodestoList(nodes []models.Node) []models.Node { + netMap := make(map[string]struct{}) + for _, node := range nodes { + if _, ok := netMap[node.Network]; ok { + continue + } + if node.IsIngressGateway { + nodes = append(nodes, GetStaticNodesByNetwork(models.NetworkID(node.Network))...) + netMap[node.Network] = struct{}{} + } + } + return nodes +} + // GetNetworkByNode - gets the network model from a node func GetNetworkByNode(node *models.Node) (models.Network, error) { @@ -420,6 +434,9 @@ func SetNodeDefaults(node *models.Node, resetConnected bool) { node.SetDefaultConnected() } node.SetExpirationDateTime() + if node.Tags == nil { + node.Tags = make(map[models.TagID]struct{}) + } } // GetRecordKey - get record key @@ -698,3 +715,95 @@ func GetAllFailOvers() ([]models.Node, error) { } return igs, nil } + +func GetTagMapWithNodes(netID models.NetworkID) (tagNodesMap map[models.TagID][]models.Node) { + tagNodesMap = make(map[models.TagID][]models.Node) + nodes, _ := GetNetworkNodes(netID.String()) + for _, nodeI := range nodes { + if nodeI.Tags == nil { + continue + } + for nodeTagID := range nodeI.Tags { + tagNodesMap[nodeTagID] = append(tagNodesMap[nodeTagID], nodeI) + } + } + return AddTagMapWithStaticNodes(netID, tagNodesMap) +} + +func AddTagMapWithStaticNodes(netID models.NetworkID, + tagNodesMap map[models.TagID][]models.Node) map[models.TagID][]models.Node { + extclients, err := GetNetworkExtClients(netID.String()) + if err != nil { + return tagNodesMap + } + for _, extclient := range extclients { + if extclient.Tags == nil || extclient.RemoteAccessClientID != "" { + continue + } + for tagID := range extclient.Tags { + tagNodesMap[tagID] = append(tagNodesMap[tagID], models.Node{ + IsStatic: true, + StaticNode: extclient, + }) + } + + } + return tagNodesMap +} + +func GetNodesWithTag(tagID models.TagID) map[string]models.Node { + nMap := make(map[string]models.Node) + tag, err := GetTag(tagID) + if err != nil { + return nMap + } + nodes, _ := GetNetworkNodes(tag.Network.String()) + for _, nodeI := range nodes { + if nodeI.Tags == nil { + continue + } + if _, ok := nodeI.Tags[tagID]; ok { + nMap[nodeI.ID.String()] = nodeI + } + } + return AddStaticNodesWithTag(tag, nMap) +} + +func AddStaticNodesWithTag(tag models.Tag, nMap map[string]models.Node) map[string]models.Node { + extclients, err := GetNetworkExtClients(tag.Network.String()) + if err != nil { + return nMap + } + for _, extclient := range extclients { + if extclient.RemoteAccessClientID != "" { + continue + } + if _, ok := extclient.Tags[tag.ID]; ok { + nMap[extclient.ClientID] = models.Node{ + IsStatic: true, + StaticNode: extclient, + } + } + + } + return nMap +} + +func GetStaticNodeWithTag(tagID models.TagID) map[string]models.Node { + nMap := make(map[string]models.Node) + tag, err := GetTag(tagID) + if err != nil { + return nMap + } + extclients, err := GetNetworkExtClients(tag.Network.String()) + if err != nil { + return nMap + } + for _, extclient := range extclients { + nMap[extclient.ClientID] = models.Node{ + IsStatic: true, + StaticNode: extclient, + } + } + return nMap +} diff --git a/logic/tags.go b/logic/tags.go new file mode 100644 index 000000000..cddcb7002 --- /dev/null +++ b/logic/tags.go @@ -0,0 +1,287 @@ +package logic + +import ( + "encoding/json" + "errors" + "fmt" + "regexp" + "sort" + "sync" + "time" + + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/models" + "golang.org/x/exp/slog" +) + +var tagMutex = &sync.RWMutex{} + +// GetTag - fetches tag info +func GetTag(tagID models.TagID) (models.Tag, error) { + data, err := database.FetchRecord(database.TAG_TABLE_NAME, tagID.String()) + if err != nil { + return models.Tag{}, err + } + tag := models.Tag{} + err = json.Unmarshal([]byte(data), &tag) + if err != nil { + return tag, err + } + return tag, nil +} + +// InsertTag - creates new tag +func InsertTag(tag models.Tag) error { + tagMutex.Lock() + defer tagMutex.Unlock() + _, err := database.FetchRecord(database.TAG_TABLE_NAME, tag.ID.String()) + if err == nil { + return fmt.Errorf("tag `%s` exists already", tag.ID) + } + d, err := json.Marshal(tag) + if err != nil { + return err + } + return database.Insert(tag.ID.String(), string(d), database.TAG_TABLE_NAME) +} + +// DeleteTag - delete tag, will also untag hosts +func DeleteTag(tagID models.TagID) error { + tagMutex.Lock() + defer tagMutex.Unlock() + // cleanUp tags on hosts + tag, err := GetTag(tagID) + if err != nil { + return err + } + nodes, err := GetNetworkNodes(tag.Network.String()) + if err != nil { + return err + } + for _, nodeI := range nodes { + nodeI := nodeI + if _, ok := nodeI.Tags[tagID]; ok { + delete(nodeI.Tags, tagID) + UpsertNode(&nodeI) + } + } + + extclients, _ := GetNetworkExtClients(tag.Network.String()) + for _, extclient := range extclients { + if _, ok := extclient.Tags[tagID]; ok { + delete(extclient.Tags, tagID) + SaveExtClient(&extclient) + } + } + return database.DeleteRecord(database.TAG_TABLE_NAME, tagID.String()) +} + +// ListTagsWithHosts - lists all tags with tagged hosts +func ListTagsWithNodes(netID models.NetworkID) ([]models.TagListResp, error) { + tags, err := ListNetworkTags(netID) + if err != nil { + return []models.TagListResp{}, err + } + tagsNodeMap := GetTagMapWithNodes(netID) + resp := []models.TagListResp{} + for _, tagI := range tags { + tagRespI := models.TagListResp{ + Tag: tagI, + UsedByCnt: len(tagsNodeMap[tagI.ID]), + TaggedNodes: GetAllNodesAPI(tagsNodeMap[tagI.ID]), + } + resp = append(resp, tagRespI) + } + return resp, nil +} + +// ListTags - lists all tags from DB +func ListTags() ([]models.Tag, error) { + tagMutex.RLock() + defer tagMutex.RUnlock() + data, err := database.FetchRecords(database.TAG_TABLE_NAME) + if err != nil && !database.IsEmptyRecord(err) { + return []models.Tag{}, err + } + tags := []models.Tag{} + for _, dataI := range data { + tag := models.Tag{} + err := json.Unmarshal([]byte(dataI), &tag) + if err != nil { + continue + } + tags = append(tags, tag) + } + return tags, nil +} + +// ListTags - lists all tags from DB +func ListNetworkTags(netID models.NetworkID) ([]models.Tag, error) { + tagMutex.RLock() + defer tagMutex.RUnlock() + data, err := database.FetchRecords(database.TAG_TABLE_NAME) + if err != nil && !database.IsEmptyRecord(err) { + return []models.Tag{}, err + } + tags := []models.Tag{} + for _, dataI := range data { + tag := models.Tag{} + err := json.Unmarshal([]byte(dataI), &tag) + if err != nil { + continue + } + if tag.Network == netID { + tags = append(tags, tag) + } + + } + return tags, nil +} + +// UpdateTag - updates and syncs hosts with tag update +func UpdateTag(req models.UpdateTagReq, newID models.TagID) { + tagMutex.Lock() + defer tagMutex.Unlock() + var err error + tagNodesMap := GetNodesWithTag(req.ID) + for _, apiNode := range req.TaggedNodes { + node := models.Node{} + var nodeID string + if apiNode.IsStatic { + if apiNode.StaticNode.RemoteAccessClientID != "" { + continue + } + extclient, err := GetExtClient(apiNode.StaticNode.ClientID, apiNode.StaticNode.Network) + if err != nil { + continue + } + node.IsStatic = true + nodeID = extclient.ClientID + node.StaticNode = extclient + } else { + node, err = GetNodeByID(apiNode.ID) + if err != nil { + continue + } + nodeID = node.ID.String() + } + + if _, ok := tagNodesMap[nodeID]; !ok { + if node.StaticNode.Tags == nil { + node.StaticNode.Tags = make(map[models.TagID]struct{}) + } + if node.Tags == nil { + node.Tags = make(map[models.TagID]struct{}) + } + if newID != "" { + if node.IsStatic { + node.StaticNode.Tags[newID] = struct{}{} + SaveExtClient(&node.StaticNode) + } else { + node.Tags[newID] = struct{}{} + UpsertNode(&node) + } + + } else { + if node.IsStatic { + node.StaticNode.Tags[req.ID] = struct{}{} + SaveExtClient(&node.StaticNode) + } else { + node.Tags[req.ID] = struct{}{} + UpsertNode(&node) + } + } + } else { + if newID != "" { + delete(node.Tags, req.ID) + delete(node.StaticNode.Tags, req.ID) + if node.IsStatic { + node.StaticNode.Tags[newID] = struct{}{} + SaveExtClient(&node.StaticNode) + } else { + node.Tags[newID] = struct{}{} + UpsertNode(&node) + } + } + delete(tagNodesMap, nodeID) + } + + } + for _, deletedTaggedNode := range tagNodesMap { + delete(deletedTaggedNode.Tags, req.ID) + delete(deletedTaggedNode.StaticNode.Tags, req.ID) + if deletedTaggedNode.IsStatic { + SaveExtClient(&deletedTaggedNode.StaticNode) + } else { + UpsertNode(&deletedTaggedNode) + } + } + go func(req models.UpdateTagReq) { + if newID != "" { + tagNodesMap = GetNodesWithTag(req.ID) + for _, nodeI := range tagNodesMap { + nodeI := nodeI + if nodeI.StaticNode.Tags == nil { + nodeI.StaticNode.Tags = make(map[models.TagID]struct{}) + } + if nodeI.Tags == nil { + nodeI.Tags = make(map[models.TagID]struct{}) + } + delete(nodeI.Tags, req.ID) + delete(nodeI.StaticNode.Tags, req.ID) + nodeI.Tags[newID] = struct{}{} + nodeI.StaticNode.Tags[newID] = struct{}{} + if nodeI.IsStatic { + SaveExtClient(&nodeI.StaticNode) + } else { + UpsertNode(&nodeI) + } + } + } + }(req) + +} + +// SortTagEntrys - Sorts slice of Tag entries by their id +func SortTagEntrys(tags []models.TagListResp) { + sort.Slice(tags, func(i, j int) bool { + return tags[i].ID < tags[j].ID + }) +} + +func CheckIDSyntax(id string) error { + if id == "" { + return errors.New("name is required") + } + if len(id) < 3 { + return errors.New("name should have min 3 characters") + } + reg, err := regexp.Compile("^[a-zA-Z-]+$") + if err != nil { + return err + } + if !reg.MatchString(id) { + return errors.New("invalid name. allowed characters are [a-zA-Z-]") + } + return nil +} + +func CreateDefaultTags(netID models.NetworkID) { + // create tag for remote access gws in the network + tag := models.Tag{ + ID: models.TagID(fmt.Sprintf("%s.%s", netID.String(), models.RemoteAccessTagName)), + TagName: models.RemoteAccessTagName, + Network: netID, + CreatedBy: "auto", + CreatedAt: time.Now(), + } + _, err := GetTag(tag.ID) + if err == nil { + return + } + err = InsertTag(tag) + if err != nil { + slog.Error("failed to create remote access gw tag", "error", err.Error()) + return + } +} diff --git a/migrate/migrate.go b/migrate/migrate.go index 6612d4ddc..26bbf434d 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -21,6 +21,7 @@ import ( func Run() { updateEnrollmentKeys() assignSuperAdmin() + createDefaultTags() removeOldUserGrps() syncUsers() updateHosts() @@ -166,6 +167,19 @@ func updateNodes() { return } for _, node := range nodes { + node := node + if node.IsIngressGateway { + tagID := models.TagID(fmt.Sprintf("%s.%s", node.Network, + models.RemoteAccessTagName)) + if node.Tags == nil { + node.Tags = make(map[models.TagID]struct{}) + } + if _, ok := node.Tags[tagID]; !ok { + node.Tags[tagID] = struct{}{} + logic.UpsertNode(&node) + } + + } if node.IsEgressGateway { egressRanges, update := removeInterGw(node.EgressGatewayRanges) if update { @@ -175,6 +189,18 @@ func updateNodes() { } } } + extclients, _ := logic.GetAllExtClients() + for _, extclient := range extclients { + tagID := models.TagID(fmt.Sprintf("%s.%s", extclient.Network, + models.RemoteAccessTagName)) + if extclient.Tags == nil { + extclient.Tags = make(map[models.TagID]struct{}) + } + if _, ok := extclient.Tags[tagID]; !ok { + extclient.Tags[tagID] = struct{}{} + logic.SaveExtClient(&extclient) + } + } } func removeInterGw(egressRanges []string) ([]string, bool) { @@ -432,3 +458,13 @@ func syncUsers() { } } } + +func createDefaultTags() { + networks, err := logic.GetNetworks() + if err != nil { + return + } + for _, network := range networks { + logic.CreateDefaultTags(models.NetworkID(network.NetID)) + } +} diff --git a/models/api_node.go b/models/api_node.go index 0c91bbeca..30e08c639 100644 --- a/models/api_node.go +++ b/models/api_node.go @@ -48,6 +48,10 @@ type ApiNode struct { InetNodeReq InetNodeReq `json:"inet_node_req" yaml:"inet_node_req"` InternetGwID string `json:"internetgw_node_id" yaml:"internetgw_node_id"` AdditionalRagIps []string `json:"additional_rag_ips" yaml:"additional_rag_ips"` + Tags map[TagID]struct{} `json:"tags" yaml:"tags"` + IsStatic bool `json:"is_static"` + IsUserNode bool `json:"is_user_node"` + StaticNode ExtClient `json:"static_node"` } // ApiNode.ConvertToServerNode - converts an api node to a server node @@ -123,6 +127,7 @@ func (a *ApiNode) ConvertToServerNode(currentNode *Node) *Node { } convertedNode.AdditionalRagIps = append(convertedNode.AdditionalRagIps, ragIp) } + convertedNode.Tags = a.Tags return &convertedNode } @@ -180,9 +185,13 @@ func (nm *Node) ConvertToAPINode() *ApiNode { apiNode.FailedOverBy = nm.FailedOverBy apiNode.Metadata = nm.Metadata apiNode.AdditionalRagIps = []string{} + apiNode.Tags = nm.Tags for _, ip := range nm.AdditionalRagIps { apiNode.AdditionalRagIps = append(apiNode.AdditionalRagIps, ip.String()) } + apiNode.IsStatic = nm.IsStatic + apiNode.IsUserNode = nm.IsUserNode + apiNode.StaticNode = nm.StaticNode return &apiNode } diff --git a/models/enrollment_key.go b/models/enrollment_key.go index e775344df..5aa89c8a7 100644 --- a/models/enrollment_key.go +++ b/models/enrollment_key.go @@ -52,6 +52,7 @@ type EnrollmentKey struct { Token string `json:"token,omitempty"` // B64 value of EnrollmentToken Type KeyType `json:"type"` Relay uuid.UUID `json:"relay"` + Groups []TagID `json:"groups"` } // APIEnrollmentKey - used to create enrollment keys via API @@ -63,6 +64,7 @@ type APIEnrollmentKey struct { Tags []string `json:"tags" validate:"required,dive,min=3,max=32"` Type KeyType `json:"type"` Relay string `json:"relay"` + Groups []TagID `json:"groups"` } // RegisterResponse - the response to a successful enrollment register diff --git a/models/extclient.go b/models/extclient.go index 9d67207d3..b84a7c8d9 100644 --- a/models/extclient.go +++ b/models/extclient.go @@ -20,6 +20,7 @@ type ExtClient struct { RemoteAccessClientID string `json:"remote_access_client_id"` // unique ID (MAC address) of RAC machine PostUp string `json:"postup" bson:"postup"` PostDown string `json:"postdown" bson:"postdown"` + Tags map[TagID]struct{} `json:"tags"` } // CustomExtClient - struct for CustomExtClient params @@ -33,4 +34,5 @@ type CustomExtClient struct { RemoteAccessClientID string `json:"remote_access_client_id"` // unique ID (MAC address) of RAC machine PostUp string `json:"postup" bson:"postup" validate:"max=1024"` PostDown string `json:"postdown" bson:"postdown" validate:"max=1024"` + Tags map[TagID]struct{} `json:"tags"` } diff --git a/models/node.go b/models/node.go index e5ea2cfc4..0ca699dbb 100644 --- a/models/node.go +++ b/models/node.go @@ -99,6 +99,10 @@ type Node struct { InetNodeReq InetNodeReq `json:"inet_node_req" yaml:"inet_node_req"` InternetGwID string `json:"internetgw_node_id" yaml:"internetgw_node_id"` AdditionalRagIps []net.IP `json:"additional_rag_ips" yaml:"additional_rag_ips" swaggertype:"array,number"` + Tags map[TagID]struct{} `json:"tags" yaml:"tags"` + IsStatic bool `json:"is_static"` + IsUserNode bool `json:"is_user_node"` + StaticNode ExtClient `json:"static_node"` } // LegacyNode - legacy struct for node model diff --git a/models/tags.go b/models/tags.go new file mode 100644 index 000000000..9fcb449da --- /dev/null +++ b/models/tags.go @@ -0,0 +1,52 @@ +package models + +import ( + "fmt" + "time" +) + +type TagID string + +const ( + RemoteAccessTagName = "remote-access-gws" +) + +func (id TagID) String() string { + return string(id) +} + +func (t Tag) GetIDFromName() string { + return fmt.Sprintf("%s.%s", t.Network, t.TagName) +} + +type Tag struct { + ID TagID `json:"id"` + TagName string `json:"tag_name"` + Network NetworkID `json:"network"` + CreatedBy string `json:"created_by"` + CreatedAt time.Time `json:"created_at"` +} + +type CreateTagReq struct { + TagName string `json:"tag_name"` + Network NetworkID `json:"network"` + TaggedNodes []ApiNode `json:"tagged_nodes"` +} + +type TagListResp struct { + Tag + UsedByCnt int `json:"used_by_count"` + TaggedNodes []ApiNode `json:"tagged_nodes"` +} + +type TagListRespNodes struct { + Tag + UsedByCnt int `json:"used_by_count"` + TaggedNodes []ApiNode `json:"tagged_nodes"` +} + +type UpdateTagReq struct { + Tag + NewName string `json:"new_name"` + TaggedNodes []ApiNode `json:"tagged_nodes"` +} diff --git a/pro/logic/user_mgmt.go b/pro/logic/user_mgmt.go index 0f3db1055..cb9da06bf 100644 --- a/pro/logic/user_mgmt.go +++ b/pro/logic/user_mgmt.go @@ -687,7 +687,9 @@ func GetFilteredNodesByUserAccess(user models.User, nodes []models.Node) (filter nodesMap := make(map[string]struct{}) allNetworkRoles := make(map[models.UserRoleID]struct{}) - + defer func() { + filteredNodes = logic.AddStaticNodestoList(filteredNodes) + }() if len(user.NetworkRoles) > 0 { for _, netRoles := range user.NetworkRoles { for netRoleI := range netRoles { @@ -696,7 +698,8 @@ func GetFilteredNodesByUserAccess(user models.User, nodes []models.Node) (filter } } if _, ok := user.NetworkRoles[models.AllNetworks]; ok { - return nodes + filteredNodes = nodes + return } if len(user.UserGroups) > 0 { for userGID := range user.UserGroups { @@ -704,7 +707,8 @@ func GetFilteredNodesByUserAccess(user models.User, nodes []models.Node) (filter if err == nil { if len(userG.NetworkRoles) > 0 { if _, ok := userG.NetworkRoles[models.AllNetworks]; ok { - return nodes + filteredNodes = nodes + return } for _, netRoles := range userG.NetworkRoles { for netRoleI := range netRoles {