package middleware

import (
	"errors"
	"fmt"
	"golangtemplate/servers/app/models"
	"net/http"
	"runtime/debug"
	"time"

	"corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/loggermdl"
	"github.com/gin-contrib/cors"
	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"github.com/golang-jwt/jwt/v5/request"
)

// Init registers this package's global middleware on the engine: the default
// CORS handler first, followed by the panic-recovery handler.
func Init(g *gin.Engine) {
	g.Use(
		cors.Default(),
		Recovery(),
	)
}

// Principal is the JWT claims type that RolebasedAuth decodes tokens into.
// It embeds jwt.RegisteredClaims so it can be passed to jwt.ParseWithClaims.
//
// NOTE(review): within this file only Groups is read; the meaning of the
// remaining fields is inferred from their names and should be confirmed
// against the code that issues the tokens.
type Principal struct {
	UserID            string    `json:"userId"`
	// Groups is compared against the package-level group list by RolebasedAuth.
	Groups            []string  `json:"groups"`
	SessionExpiration time.Time `json:"sessionExpiration"`
	ClientIP          string    `json:"clientIP"`
	HitsCount         int       `json:"hitsCount"`
	Token             string    `json:"token"`
	Metadata          string    `json:"metadata"`
	jwt.RegisteredClaims
}

// group lists the group names RolebasedAuth accepts: a request passes when the
// token's Groups claim contains at least one of these entries.
var group = []string{"delete"}

// Recovery returns a gin middleware that recovers from panics raised by later
// handlers in the chain. On a panic it logs the recovered value and the stack
// trace through loggermdl, writes the stack trace to gin.DefaultErrorWriter,
// and aborts the request with 500 Internal Server Error so the client does not
// receive gin's default 200 status with an empty body.
func Recovery() gin.HandlerFunc {
	return func(c *gin.Context) {
		defer func() {
			if err := recover(); err != nil {
				// Capture the stack once so the log entry and the error
				// writer receive the same trace.
				stack := debug.Stack()
				loggermdl.LogError("recovered:", err)
				loggermdl.LogError(string(stack))
				// write to access log; a failure to write the trace is
				// deliberately ignored, there is nowhere left to report it.
				_, _ = gin.DefaultErrorWriter.Write(stack)
				// Stop the remaining handlers and report the failure to the
				// client. If headers were already sent gin keeps the original
				// status and only the abort takes effect.
				c.AbortWithStatus(http.StatusInternalServerError)
			}
		}()
		c.Next()
	}
}

// CustomLogger returns a gin logging middleware that emits one line per
// request containing the client IP, HTTP method, request path, protocol and
// response body size.
//
// The previous implementation passed an empty format string to fmt.Sprintf,
// so every line was rendered as "%!(EXTRA ...)" noise.
func CustomLogger() gin.HandlerFunc {
	return gin.LoggerWithFormatter(func(params gin.LogFormatterParams) string {
		// Log only the protocol rather than dumping the whole *http.Request;
		// guard against a nil request in case the formatter is called directly.
		proto := ""
		if params.Request != nil {
			proto = params.Request.Proto
		}
		return fmt.Sprintf("%s | %s %s %s | %d bytes\n",
			params.ClientIP, params.Method, params.Path, proto, params.BodySize)
	})
}

// Auth returns a gin middleware that parses a JWT from the request using
// request.OAuth2Extractor and verifies it with models.JWTKey. When parsing
// fails the request is aborted with 401 and the error is recorded on the
// context; otherwise the middleware returns and the chain continues.
func Auth() gin.HandlerFunc {
	// keyFunc supplies the verification key; models.JWTKey is read on every
	// call, exactly as before.
	keyFunc := func(_ *jwt.Token) (interface{}, error) {
		return []byte(models.JWTKey), nil
	}

	return func(c *gin.Context) {
		if _, parseErr := request.ParseFromRequest(c.Request, request.OAuth2Extractor, keyFunc); parseErr != nil {
			c.AbortWithError(http.StatusUnauthorized, parseErr)
		}
	}
}

// RolebasedAuth returns a gin middleware that authenticates the request with
// a JWT taken from the Authorization header and then authorizes it by group.
//
// The header may carry the token either bare or with the standard "Bearer "
// prefix. The token is verified with models.JWTKey and decoded into a
// Principal. The request continues only when the token's Groups claim shares
// at least one entry with the package-level group list.
//
// Responses on failure:
//   - 401 with no body when the Authorization header is missing;
//   - 401 with a JSON message when the token cannot be parsed, is invalid, or
//     carries unexpected claims;
//   - 403 with a JSON message when none of the token's groups is allowed.
//
// In every failure case the error is also recorded on the gin context.
func RolebasedAuth() gin.HandlerFunc {
	// reject stops the handler chain, records err on the context and sends
	// msg as the JSON body. The body is rendered before any header is flushed
	// so the JSON Content-Type actually reaches the client (AbortWithError
	// followed by IndentedJSON flushed the headers too early).
	reject := func(c *gin.Context, status int, err error, msg string) {
		c.Abort()
		_ = c.Error(err)
		c.IndentedJSON(status, msg)
	}

	return func(c *gin.Context) {
		// Read the Authorization header, stripping an optional "Bearer "
		// prefix; the extractor fails when the header is absent or empty.
		tokenString, err := request.AuthorizationHeaderExtractor.ExtractToken(c.Request)
		if err != nil || tokenString == "" {
			c.AbortWithError(http.StatusUnauthorized, errors.New("missing Authorization header"))
			return
		}

		// Decode and verify the JWT token.
		token, err := jwt.ParseWithClaims(tokenString, &Principal{}, func(token *jwt.Token) (interface{}, error) {
			return []byte(models.JWTKey), nil
		})
		if err != nil {
			reject(c, http.StatusUnauthorized, err, "Error parsing token")
			return
		}
		if !token.Valid {
			reject(c, http.StatusUnauthorized, errors.New("invalid token"), "Token is invalid")
			return
		}

		claims, ok := token.Claims.(*Principal)
		if !ok || claims == nil {
			reject(c, http.StatusUnauthorized, errors.New("invalid claims"), "Invalid claims")
			return
		}

		// Look for any overlap between the allowed groups and the token's
		// groups, leaving both loops as soon as one match is found.
		isPresent := false
	search:
		for _, allowed := range group {
			for _, claimed := range claims.Groups {
				if allowed == claimed {
					isPresent = true
					break search
				}
			}
		}
		if !isPresent {
			reject(c, http.StatusForbidden, errors.New("Access forbidden."), "No roles present.")
			return
		}

		c.Next()
	}
}