session.go 3.26 KiB
Newer Older
package sessionmdl

import (
	"errors"
	"time"

	"coresls/servers/coresls/app/models"

	"corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/cachemdl"
	"corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/utiliymdl/guidmdl"

	"github.com/dgrijalva/jwt-go"
)

// Session is one session entry for a user. A user can hold several; the
// package stores them together as a []Session keyed by the user ID (see Set
// and Get), which is why the struct itself carries no user ID.
type Session struct {
	// UserID     string
	// SessionFor labels what the session was opened for; DeleteSession and
	// CheckForSessionAvailability select sessions by this value.
	SessionFor string
	// SessionID identifies the session; ValidateSessionFromToken compares it
	// against the token's "sessionId" claim.
	SessionID  string
}

// InstanceHeader is the header name "session-instance"; presumably the header
// that carries SessionInstance between services — confirm against callers.
const InstanceHeader = "session-instance"

var (
	// ValidateSession is an exported on/off switch. NOTE(review): presumably
	// consulted by callers to decide whether session validation is enabled —
	// confirm.
	ValidateSession bool

	// SessionInstance is a GUID generated once, when the package is
	// initialized.
	SessionInstance = guidmdl.GetGUID()

	// store is the cache backing every session operation. It stays nil until
	// InitSessionManagerCache assigns it, and the session functions call its
	// methods without a nil check.
	store cachemdl.Cacher
)

// Sentinel errors returned by this package. Compare against them with
// errors.Is rather than by message text.
var (
	// ErrSessionNotFound is returned when no matching session is stored.
	ErrSessionNotFound = errors.New("session not found")

	// ErrInvalidSessionInstance reports a session instance ID that was
	// rejected; presumably compared against SessionInstance by callers.
	ErrInvalidSessionInstance = errors.New("got invalid session instance id")

	// ErrSessionValidationFailed reports a generic session validation
	// failure. Its message is distinct from ErrInvalidSessionInstance's so the
	// two can be told apart in logs.
	ErrSessionValidationFailed = errors.New("session validation failed")
)

// InitSessionManagerCache creates the cache that backs the session store and
// assigns it to the package-level store. It must run before any other session
// function, since those call the store without checking it for nil.
//
// cacheType is passed through unchanged as cachemdl.CacheConfig.Type to pick
// the backend. Configurations for both backends are always supplied:
//   - Redis: empty Addr (falls back to the default Redis address, port 6379),
//     DB models.RedisDBSessionManager, key prefix models.ProjectID.
//   - FastCache: CleanupTime of 60 minutes, MaxEntries 1000, Expiration -1
//     (presumably "never expire" — confirm against cachemdl).
func InitSessionManagerCache(cacheType int) {
	redisCfg := &cachemdl.RedisCache{
		Addr:   "", // empty selects the default Redis address (port 6379)
		DB:     models.RedisDBSessionManager,
		Prefix: models.ProjectID,
	}

	fastCfg := &cachemdl.FastCacheHelper{
		CleanupTime: 60 * time.Minute,
		MaxEntries:  1000,
		Expiration:  -1,
	}

	store = cachemdl.GetCacheInstance(&cachemdl.CacheConfig{
		Type:       cacheType,
		RedisCache: redisCfg,
		FastCache:  fastCfg,
	})
}

// Set appends the given sessions to those already stored for userID.
//
// If the store has no entry for userID, or the entry does not hold a
// []Session, the entry is replaced by the given sessions alone.
//
// The value written back is always a freshly allocated slice. It therefore
// never shares a backing array with the caller's argument slice (Set(id,
// list...)) or with the slice previously held by the cache, which Get may
// have handed out to other callers.
//
// NOTE(review): this is an unlocked read-modify-write on the store; two
// concurrent calls for the same userID can drop one of the updates.
func Set(userID string, s ...Session) {
	// A missing entry yields a nil interface, which fails the type assertion
	// exactly like a foreign value does, so both cases fall through to
	// "start from nothing" without inspecting the found flag.
	cached, _ := store.Get(userID)
	previous, _ := cached.([]Session)

	merged := make([]Session, 0, len(previous)+len(s))
	merged = append(merged, previous...)
	merged = append(merged, s...)

	store.Set(userID, merged)
}

// Get returns the sessions stored for userID.
//
// It returns ErrSessionNotFound when the store has no entry for userID, and a
// generic error when the entry exists but does not hold a []Session. On
// success the slice comes straight from the store and may share memory with
// the cached value, so callers should treat it as read-only.
func Get(userID string) ([]Session, error) {
	cached, found := store.Get(userID)
	if !found {
		return nil, ErrSessionNotFound
	}

	if sessions, isSessions := cached.([]Session); isSessions {
		return sessions, nil
	}

	return nil, errors.New("failed to retrieve previous sessions")
}

// Delete drops userID's entry from the store, removing all of that user's
// sessions at once.
func Delete(userID string) {
	store.Delete(userID)
}

// ValidateSessionFromToken checks that the token's "sessionId" claim names one
// of the sessions currently stored for its "userId" claim.
//
// Both claims are validated before the store is consulted, so a malformed
// token is rejected with a precise error and without a store lookup.
//
// Returns:
//   - an error naming the claim when "userId" or "sessionId" is missing or is
//     not a string;
//   - the error from Get when the user's sessions cannot be loaded
//     (ErrSessionNotFound when the store has no entry);
//   - ErrSessionNotFound when no stored session has that session ID;
//   - nil when a matching session is found.
func ValidateSessionFromToken(claims jwt.MapClaims) error {
	rawUserID, ok := claims["userId"]
	if !ok {
		return errors.New("\"userId\" field not found in token")
	}

	userID, ok := rawUserID.(string)
	if !ok {
		return errors.New("\"userId\" field is not string")
	}

	rawSessionID, ok := claims["sessionId"]
	if !ok {
		return errors.New("\"sessionId\" field not found in token")
	}

	sessionID, ok := rawSessionID.(string)
	if !ok {
		return errors.New("\"sessionId\" field is not string")
	}

	sessions, err := Get(userID)
	if err != nil {
		return err
	}

	for idx := range sessions {
		if sessions[idx].SessionID == sessionID {
			return nil
		}
	}

	return ErrSessionNotFound
}

// CheckForSessionAvailability reports whether userID has at least one stored
// session whose SessionFor equals sessionFor.
//
// It returns nil on a match, the error from Get when the user's sessions
// cannot be loaded, and ErrSessionNotFound when none of them match.
func CheckForSessionAvailability(userID, sessionFor string) error {
	sessions, err := Get(userID)
	if err != nil {
		return err
	}

	for _, candidate := range sessions {
		if candidate.SessionFor == sessionFor {
			return nil
		}
	}

	return ErrSessionNotFound
}

// DeleteSession removes every session of userID whose SessionFor equals
// sessionFor. It deletes the user's entry when no sessions remain; otherwise
// it writes the remaining sessions back.
//
// It is a deliberate no-op when Get fails for userID (no entry, or an entry
// that does not hold a []Session).
//
// NOTE(review): like Set, this is an unlocked read-modify-write on the store.
func DeleteSession(userID, sessionFor string) {
	sessions, err := Get(userID)
	if err != nil {
		return
	}

	// Filter into a new slice rather than swap-removing in place:
	//   - in-place swap-remove must re-check the element moved into slot i,
	//     or consecutive matches survive;
	//   - the slice from Get may share its backing array with the cached
	//     value, so writing into it would mutate the cache before store.Set.
	remaining := make([]Session, 0, len(sessions))
	for _, sess := range sessions {
		if sess.SessionFor != sessionFor {
			remaining = append(remaining, sess)
		}
	}

	if len(remaining) == 0 {
		store.Delete(userID)
		return
	}

	store.Set(userID, remaining)
}