Skip to content

Commit

Permalink
Add CSRF protection to the Hypervisor API (#1604)
Browse files Browse the repository at this point in the history
* Add CSRF protection to the Hypervisor API

* Param for disabling CSRF protection
  • Loading branch information
Senyoret1 authored Dec 10, 2023
1 parent 2c99655 commit c4ddd24
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 4 deletions.
2 changes: 2 additions & 0 deletions pkg/visor/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ var (
logTag string
hiddenflags []string
all bool
useCsrf bool
pkg bool
usr bool
localIPs []net.IP // nolint:unused
Expand Down Expand Up @@ -131,6 +132,7 @@ func init() {
RootCmd.Flags().BoolVar(&isForceColor, "forcecolor", false, "force color logging when out is not STDOUT")
hiddenflags = append(hiddenflags, "forcecolor")
RootCmd.Flags().BoolVar(&all, "all", false, "show all flags")
RootCmd.Flags().BoolVar(&useCsrf, "csrf", true, "Request a CSRF token for sensitive hypervisor API requests")
for _, j := range hiddenflags {
RootCmd.Flags().MarkHidden(j) //nolint
}
Expand Down
108 changes: 108 additions & 0 deletions pkg/visor/csrf.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Package visor pkg/visor/hypervisor.go
package visor

import (
"time"

"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"strings"

"github.com/skycoin/skycoin/src/cipher"
)

const (
// CSRFHeaderName is the name of the CSRF header
CSRFHeaderName = "X-CSRF-Token"

// CSRFMaxAge is the lifetime of a CSRF token in seconds
CSRFMaxAge = time.Second * 30

csrfSecretLength = 64

csrfNonceLength = 64
)

var (
// ErrCSRFInvalid is returned when the the CSRF token is in invalid format
ErrCSRFInvalid = errors.New("invalid CSRF token")
// ErrCSRFExpired is returned when the csrf token has expired
ErrCSRFExpired = errors.New("csrf token expired")
)

var csrfSecretKey []byte

func init() {
csrfSecretKey = cipher.RandByte(csrfSecretLength)
}

// CSRFToken csrf token
type CSRFToken struct {
Nonce []byte
ExpiresAt time.Time
}

// newCSRFToken generates a new CSRF Token
func newCSRFToken() (string, error) {
token := &CSRFToken{
Nonce: cipher.RandByte(csrfNonceLength),
ExpiresAt: time.Now().Add(CSRFMaxAge),
}

tokenJSON, err := json.Marshal(token)
if err != nil {
return "", err
}

h := hmac.New(sha256.New, csrfSecretKey)
_, err = h.Write([]byte(tokenJSON))
if err != nil {
return "", err
}

sig := base64.RawURLEncoding.EncodeToString(h.Sum(nil))

signingString := base64.RawURLEncoding.EncodeToString(tokenJSON)

return strings.Join([]string{signingString, sig}, "."), nil
}

// verifyCSRFToken checks validity of the given token
func verifyCSRFToken(headerToken string) error {
tokenParts := strings.Split(headerToken, ".")
if len(tokenParts) != 2 {
return ErrCSRFInvalid
}

signingString, err := base64.RawURLEncoding.DecodeString(tokenParts[0])
if err != nil {
return err
}

h := hmac.New(sha256.New, csrfSecretKey)
_, err = h.Write([]byte(signingString))
if err != nil {
return err
}

sig := base64.RawURLEncoding.EncodeToString(h.Sum(nil))

if sig != tokenParts[1] {
return ErrCSRFInvalid
}

var csrfToken CSRFToken
err = json.Unmarshal(signingString, &csrfToken)
if err != nil {
return err
}

if time.Now().After(csrfToken.ExpiresAt) {
return ErrCSRFExpired
}

return nil
}
40 changes: 40 additions & 0 deletions pkg/visor/hypervisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ func (hv *Hypervisor) makeMux() chi.Router {

r.Get("/ping", hv.getPong())

r.Get("/csrf", hv.getCsrf())

if hv.c.EnableAuth {
r.Group(func(r chi.Router) {
r.Post("/create-account", hv.users.CreateAccount())
Expand Down Expand Up @@ -299,6 +301,29 @@ func (hv *Hypervisor) getPong() http.HandlerFunc {
}
}

// Csrf provides a temporal security token.
type Csrf struct {
Token string `json:"csrf_token"`
}

func (hv *Hypervisor) getCsrf() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if useCsrf {
token, err := newCSRFToken()
if err != nil {
httputil.WriteJSON(w, r, http.StatusInternalServerError, err)
return
}

httputil.WriteJSON(w, r, http.StatusOK, Csrf{
Token: token,
})
} else {
httputil.WriteJSON(w, r, http.StatusOK, Csrf{Token: ""})
}
}
}

// About provides info about the hypervisor.
type About struct {
PubKey cipher.PubKey `json:"public_key"` // The hypervisor's public key.
Expand Down Expand Up @@ -1352,6 +1377,21 @@ func (hv *Hypervisor) visorCtx(w http.ResponseWriter, r *http.Request) (*httpCtx
return nil, false
}

if useCsrf && (r.Method == "POST" || r.Method == "PUT" || r.Method == "DELETE") {
csrfToken := r.Header.Get(CSRFHeaderName)
if csrfToken == "" {
errMsg := fmt.Errorf("no csrf token for %s request", r.Method)
httputil.WriteJSON(w, r, http.StatusForbidden, errMsg)
return nil, false
}

err = verifyCSRFToken(csrfToken)
if err != nil {
httputil.WriteJSON(w, r, http.StatusForbidden, err)
return nil, false
}
}

if pk != hv.c.PK {
v, ok := hv.visorConn(pk)

Expand Down
35 changes: 31 additions & 4 deletions static/skywire-manager-src/src/app/services/api.service.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Injectable, NgZone } from '@angular/core';
import { HttpClient, HttpErrorResponse, HttpHeaders } from '@angular/common/http';
import { Observable, throwError } from 'rxjs';
import { catchError, map } from 'rxjs/operators';
import { catchError, first, map, mergeMap } from 'rxjs/operators';
import { webSocket } from 'rxjs/webSocket';
import { Router } from '@angular/router';

Expand All @@ -22,6 +22,7 @@ export class RequestOptions {
requestType = RequestTypes.Json;
ignoreAuth = false;
vpnKeyForAuth: string;
csrfToken: string;

public constructor(init?: Partial<RequestOptions>) {
Object.assign(this, init);
Expand Down Expand Up @@ -69,23 +70,45 @@ export class ApiService {
* @param url Endpoint URL, after the "/api/" part.
*/
post(url: string, body: any = {}, options: RequestOptions = null): Observable<any> {
return this.request('POST', url, body, options);
return this.getCsrf().pipe(first(), mergeMap(csrf => {
options = options ? options : new RequestOptions();
options.csrfToken = csrf;

return this.request('POST', url, body, options);
}));
}

/**
* Makes a request to a PUT endpoint.
* @param url Endpoint URL, after the "/api/" part.
*/
put(url: string, body: any = {}, options: RequestOptions = null): Observable<any> {
return this.request('PUT', url, body, options);
return this.getCsrf().pipe(first(), mergeMap(csrf => {
options = options ? options : new RequestOptions();
options.csrfToken = csrf;

return this.request('PUT', url, body, options);
}));
}

/**
* Makes a request to a DELETE endpoint.
* @param url Endpoint URL, after the "/api/" part.
*/
delete(url: string, options: RequestOptions = null): Observable<any> {
return this.request('DELETE', url, {}, options);
return this.getCsrf().pipe(first(), mergeMap(csrf => {
options = options ? options : new RequestOptions();
options.csrfToken = csrf;

return this.request('DELETE', url, {}, options);
}));
}

/**
* Gets a csrf token from the node, to be able to make protected requests.
*/
private getCsrf(): Observable<string> {
return this.get('csrf').pipe(map(response => response.csrf_token));
}

/**
Expand Down Expand Up @@ -138,6 +161,10 @@ export class ApiService {
requestOptions.headers = requestOptions.headers.append('Content-Type', 'application/json');
}

if (options.csrfToken) {
requestOptions.headers = requestOptions.headers.append('X-CSRF-Token', options.csrfToken);
}

return requestOptions;
}

Expand Down

0 comments on commit c4ddd24

Please sign in to comment.