diff --git a/authmdl/jwtmdl/jwtmdl.go b/authmdl/jwtmdl/jwtmdl.go index 1b55c6d38e8d38793842e6a7f0c4f65003faca7e..240bc5a8d7edaf80c07d10318d4fdaf48981234e 100644 --- a/authmdl/jwtmdl/jwtmdl.go +++ b/authmdl/jwtmdl/jwtmdl.go @@ -1,12 +1,11 @@ package jwtmdl import ( - "fmt" "strings" "time" + "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/authmdl/sessionmdl" "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/errormdl" - "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/loggermdl" jwt "github.com/dgrijalva/jwt-go" "github.com/tidwall/gjson" ) @@ -14,46 +13,15 @@ import ( // GlobalJWTKey - key to decode and encode token var GlobalJWTKey string -// // DecodeTokenWithJWTKey decode token -// func DecodeTokenWithJWTKey(req *http.Request, jwtKey string) (jwt.MapClaims, error) { -// token, err := request.ParseFromRequest(req, request.OAuth2Extractor, func(token *jwt.Token) (interface{}, error) { -// b := ([]byte(jwtKey)) -// return b, nil -// }) -// if errormdl.CheckErr(err) != nil { -// loggermdl.LogError("Error while parsing JWT Token: ", errormdl.CheckErr(err)) -// return nil, errormdl.CheckErr(err) -// } - -// claims, ok := token.Claims.(jwt.MapClaims) -// if !errormdl.CheckBool1(ok) { -// loggermdl.LogError("Error while parsing claims to MapClaims") -// return nil, errormdl.Wrap("Error while getting claims") -// } - -// return claims, nil -// } - -// // DecodeToken decode token -// func DecodeToken(req *http.Request) (jwt.MapClaims, error) { -// token, err := request.ParseFromRequest(req, request.OAuth2Extractor, func(token *jwt.Token) (interface{}, error) { -// b := ([]byte(GlobalJWTKey)) -// return b, nil -// }) -// if errormdl.CheckErr(err) != nil { -// loggermdl.LogError("Error while parsing JWT Token: ", errormdl.CheckErr(err)) -// return nil, errormdl.CheckErr(err) -// } -// claims, ok := token.Claims.(jwt.MapClaims) -// if !errormdl.CheckBool1(ok) { -// loggermdl.LogError("Error while parsing claims to MapClaims") -// return nil, errormdl.Wrap("Error while getting claims") -// } -// return claims, nil -// } +var keyFunc = func(key string) jwt.Keyfunc { + return func(*jwt.Token) (interface{}, error) { + return []byte(key), nil + } +} type jwtCustomClaim struct { UserID string `json:"userId"` + SessionId string `json:"sessionId,omitempty"` Groups []string `json:"groups"` ClientIP string `json:"clientIP"` HitsCount int `json:"hitsCount"` @@ -62,28 +30,79 @@ type jwtCustomClaim struct { jwt.StandardClaims } -// GenerateToken generates JWT token from Login object -func GenerateToken(loginID string, groups []string, clientIP string, metadata gjson.Result, expirationTime time.Duration) (string, error) { +func generate(claims jwtCustomClaim, key string) (string, error) { + return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(key)) +} + +// extract return token from header string +func extract(tokenReq string) (string, error) { + tokenArray := strings.Split(tokenReq, "Bearer") + if len(tokenArray) <= 1 { + return "", errormdl.Wrap("Provided JWT token is nil or invalid ") + } + + return strings.Trim(tokenArray[1], " "), nil +} + +// decode accepts a parsed token and error from parse operation. +func decode(token *jwt.Token, err error) (jwt.MapClaims, error) { + if err != nil { + // loggermdl.LogError("Error while parsing JWT Token: ", err) + return nil, err + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + // loggermdl.LogError("Error while parsing claims to MapClaims") + return nil, errormdl.Wrap("Error while getting claims") + } + + // validate user session from session id present in token + if err := sessionmdl.ValidateSessionFromToken(claims); err != nil { + // loggermdl.LogError("session validation failed with err:", err) + return nil, sessionmdl.ErrSessionValidationFailed + } + + return claims, nil +} + +func GenerateTokenWithOptions(args ...Option) (string, error) { + options := new(Options) + + options.Key = GlobalJWTKey + + for i := range args { + args[i](options) + } + claims := jwtCustomClaim{ - UserID: loginID, - Groups: groups, - ClientIP: clientIP, - Metadata: metadata.String(), + ClientIP: options.ClientIP, + Groups: options.Groups, + Metadata: options.Metadata, + SessionId: options.Session.SessionId, + UserID: options.UserID, StandardClaims: jwt.StandardClaims{ - ExpiresAt: time.Now().Add(expirationTime).Unix(), + ExpiresAt: options.ExpiresAt, }, } - // Create token with claims - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - // Generate encoded token and send it as response. - t, err := token.SignedString([]byte(GlobalJWTKey)) - if errormdl.CheckErr(err) != nil { - loggermdl.LogError(err) - return t, errormdl.CheckErr(err) + + t, err := generate(claims, options.Key) + if err != nil { + return "", err + } + + if len(options.Session.SessionId) > 0 { + sessionmdl.Set(options.UserID, options.Session) } + return t, nil } +// GenerateToken generates JWT token from Login object +func GenerateToken(loginID string, groups []string, clientIP string, metadata gjson.Result, expirationTime time.Duration) (string, error) { + return GenerateTokenWithJWTKey(loginID, groups, clientIP, metadata, expirationTime, GlobalJWTKey) +} + // GenerateTokenWithJWTKey generates JWT token from Login object func GenerateTokenWithJWTKey(loginID string, groups []string, clientIP string, metadata gjson.Result, expirationTime time.Duration, JWTKey string) (string, error) { claims := jwtCustomClaim{ @@ -95,43 +114,17 @@ func GenerateTokenWithJWTKey(loginID string, groups []string, clientIP string, m ExpiresAt: time.Now().Add(expirationTime).Unix(), }, } - // Create token with claims - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - // Generate encoded token and send it as response. - t, err := token.SignedString([]byte(JWTKey)) - if errormdl.CheckErr(err) != nil { - loggermdl.LogError(err) - return t, errormdl.CheckErr(err) - } - return t, nil + + return generate(claims, JWTKey) } //GeneratePricipleObjUsingToken GeneratePricipleObjUsingToken func GeneratePricipleObjUsingToken(tokenReq string, jwtKey string) (jwt.MapClaims, error) { - tokenArray := strings.Split(tokenReq, "Bearer") - if len(tokenArray) <= 1 { - return nil, errormdl.Wrap("Provided JWT token is nil or invalid ") - } - tokenFromRequest := strings.Trim(tokenArray[1], " ") - // get data i.e.Claims from token - token, err := jwt.Parse(tokenFromRequest, func(token *jwt.Token) (interface{}, error) { - // Don't forget to validate the alg is what you expect: - _, ok := token.Method.(*jwt.SigningMethodHMAC) - if !ok { - return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) - } - return []byte(jwtKey), nil - }) + token, err := extract(tokenReq) if err != nil { - loggermdl.LogError("Error while parsing JWT Token: ", err) return nil, err } - claims, ok := token.Claims.(jwt.MapClaims) - if !errormdl.CheckBool1(ok) { - loggermdl.LogError("Error while parsing claims to MapClaims") - return nil, errormdl.Wrap("Error while getting claims") - } - return claims, nil + return decode(jwt.Parse(token, keyFunc(jwtKey))) } diff --git a/authmdl/jwtmdl/jwtmdl_fasthttp.go b/authmdl/jwtmdl/jwtmdl_fasthttp.go index 044027a34ddb1dd56c8d66d17df9e6270e3487e7..10ff39b107c1bb3c888f3110f6607ced4988d84d 100644 --- a/authmdl/jwtmdl/jwtmdl_fasthttp.go +++ b/authmdl/jwtmdl/jwtmdl_fasthttp.go @@ -3,75 +3,22 @@ package jwtmdl import ( - "fmt" - "strings" - - "github.com/valyala/fasthttp" - - "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/errormdl" - "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/loggermdl" jwt "github.com/dgrijalva/jwt-go" + "github.com/valyala/fasthttp" ) // DecodeTokenWithJWTKey decode token func DecodeTokenWithJWTKey(req *fasthttp.Request, jwtKey string) (jwt.MapClaims, error) { - tokenFromRequest := string(req.Header.Peek("Authorization")) - tokenArray := strings.Split(tokenFromRequest, "Bearer") - if len(tokenArray) <= 1 { - return nil, errormdl.Wrap("Provided JWT token is nil or invalid ") - } - tokenFromRequest = strings.Trim(tokenArray[1], " ") - // get data i.e.Claims from token - token, err := jwt.Parse(tokenFromRequest, func(token *jwt.Token) (interface{}, error) { - // Don't forget to validate the alg is what you expect: - _, ok := token.Method.(*jwt.SigningMethodHMAC) - if !ok { - return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) - } - return []byte(jwtKey), nil - }) + tokenFromRequest, err := extract(string(req.Header.Peek("Authorization"))) if err != nil { - loggermdl.LogError("Error while parsing JWT Token: ", err) return nil, err } - claims, ok := token.Claims.(jwt.MapClaims) - if !errormdl.CheckBool1(ok) { - loggermdl.LogError("Error while parsing claims to MapClaims") - return nil, errormdl.Wrap("Error while getting claims") - } - - return claims, nil + return decode(jwt.Parse(tokenFromRequest, keyFunc(jwtKey))) } // DecodeToken decode token func DecodeToken(req *fasthttp.Request) (jwt.MapClaims, error) { - tokenFromRequest := string(req.Header.Peek("Authorization")) - tokenArray := strings.Split(tokenFromRequest, "Bearer") - if len(tokenArray) <= 1 { - return nil, errormdl.Wrap("Provided JWT token is nil or invalid ") - } - tokenFromRequest = strings.Trim(tokenArray[1], " ") - // get data i.e.Claims from token - token, err := jwt.Parse(tokenFromRequest, func(token *jwt.Token) (interface{}, error) { - // Don't forget to validate the alg is what you expect: - _, ok := token.Method.(*jwt.SigningMethodHMAC) - if !ok { - return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) - } - return []byte(GlobalJWTKey), nil - }) - - if err != nil { - loggermdl.LogError("Error while parsing JWT Token: ", err) - return nil, err - } - - claims, ok := token.Claims.(jwt.MapClaims) - if !errormdl.CheckBool1(ok) { - loggermdl.LogError("Error while parsing claims to MapClaims") - return nil, errormdl.Wrap("Error while getting claims") - } - return claims, nil + return DecodeTokenWithJWTKey(req, GlobalJWTKey) } diff --git a/authmdl/jwtmdl/jwtmdl_http.go b/authmdl/jwtmdl/jwtmdl_http.go index 2209d3d810c7952c4cb6b7bfdd28f5e772baf308..0e2da98a9e793b89da5a550da8cb1a7274154eb5 100644 --- a/authmdl/jwtmdl/jwtmdl_http.go +++ b/authmdl/jwtmdl/jwtmdl_http.go @@ -7,45 +7,15 @@ import ( "github.com/dgrijalva/jwt-go/request" - "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/errormdl" - "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/loggermdl" jwt "github.com/dgrijalva/jwt-go" ) // DecodeTokenWithJWTKey decode token func DecodeTokenWithJWTKey(req *http.Request, jwtKey string) (jwt.MapClaims, error) { - token, err := request.ParseFromRequest(req, request.OAuth2Extractor, func(token *jwt.Token) (interface{}, error) { - b := ([]byte(jwtKey)) - return b, nil - }) - if errormdl.CheckErr(err) != nil { - loggermdl.LogError("Error while parsing JWT Token: ", errormdl.CheckErr(err)) - return nil, errormdl.CheckErr(err) - } - - claims, ok := token.Claims.(jwt.MapClaims) - if !errormdl.CheckBool1(ok) { - loggermdl.LogError("Error while parsing claims to MapClaims") - return nil, errormdl.Wrap("Error while getting claims") - } - - return claims, nil + return decode(request.ParseFromRequest(req, request.OAuth2Extractor, keyFunc(jwtKey))) } // DecodeToken decode token func DecodeToken(req *http.Request) (jwt.MapClaims, error) { - token, err := request.ParseFromRequest(req, request.OAuth2Extractor, func(token *jwt.Token) (interface{}, error) { - b := ([]byte(GlobalJWTKey)) - return b, nil - }) - if errormdl.CheckErr(err) != nil { - loggermdl.LogError("Error while parsing JWT Token: ", errormdl.CheckErr(err)) - return nil, errormdl.CheckErr(err) - } - claims, ok := token.Claims.(jwt.MapClaims) - if !errormdl.CheckBool1(ok) { - loggermdl.LogError("Error while parsing claims to MapClaims") - return nil, errormdl.Wrap("Error while getting claims") - } - return claims, nil + return DecodeTokenWithJWTKey(req, GlobalJWTKey) } diff --git a/authmdl/jwtmdl/jwtmdl_test.go b/authmdl/jwtmdl/jwtmdl_test.go index 82e1921943f49dad642a7ba957d6faed8cb48468..62565948974effb67b15fef92ea1fca193a33767 100644 --- a/authmdl/jwtmdl/jwtmdl_test.go +++ b/authmdl/jwtmdl/jwtmdl_test.go @@ -4,11 +4,30 @@ import ( "fmt" "net/http" "testing" + "time" + "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/authmdl/sessionmdl" + "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/cachemdl" + "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/utiliymdl/guidmdl" + jwt "github.com/dgrijalva/jwt-go" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" + "github.com/tidwall/gjson" ) +const ( + TestKey = "vJUufKHyu2xiMYmDj1TmojHR11ciUaq3" +) + +func checkToken(token string) error { + claims, err := decode(jwt.Parse(token, keyFunc(TestKey))) + if err != nil { + return err + } + + return sessionmdl.ValidateSessionFromToken(claims) +} + func server() { g := gin.Default() g.GET("/status", func(c *gin.Context) { @@ -71,3 +90,102 @@ func TestDecodeTokenvalid(t *testing.T) { // assert.Error(t, derror, "error occured") // errormdl.IsTestingNegetiveCaseOnCheckBool1 = false // } + +func TestGenerateTokenWithOptions(t *testing.T) { + sessionmdl.InitUserSessionCache(cachemdl.TypeFastCache) + + type args struct { + args []Option + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "Token without session", + args: args{ + args: []Option{ + WithUserID("tom@company.org"), + WithExpiration(0), + WithKey(TestKey), + WithMetaData(`{"name":"tom"}`), + }, + }, + wantErr: false, + }, + { + name: "Token with session", + args: args{ + args: []Option{ + WithUserID("tom@company.org"), + WithExpiration(0), + WithKey(TestKey), + WithMetaData(`{"name":"tom"}`), + WithSession(guidmdl.GetGUID(), "me"), + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := GenerateTokenWithOptions(tt.args.args...) + if (err != nil) != tt.wantErr { + t.Errorf("GenerateTokenWithOptions() error = %v, wantErr %v", err, tt.wantErr) + return + } + + err = checkToken(got) + if (err != nil) != tt.wantErr { + t.Errorf("GenerateTokenWithOptions() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestGenerateTokenWithJWTKey(t *testing.T) { + type args struct { + loginID string + groups []string + clientIP string + metadata gjson.Result + expirationTime time.Duration + JWTKey string + } + tests := []struct { + name string + args args + // want string + wantErr bool + }{ + { + name: "Test genrate token", + args: args{ + JWTKey: TestKey, + expirationTime: time.Minute * 5, + groups: []string{"admin"}, + loginID: "tom@company.org", + metadata: gjson.Parse(`{"name":"tom"}`), + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := GenerateTokenWithJWTKey(tt.args.loginID, tt.args.groups, tt.args.clientIP, tt.args.metadata, tt.args.expirationTime, tt.args.JWTKey) + if (err != nil) != tt.wantErr { + t.Errorf("GenerateTokenWithJWTKey() error = %v, wantErr %v", err, tt.wantErr) + return + } + + err = checkToken(got) + if (err != nil) != tt.wantErr { + t.Errorf("GenerateTokenWithJWTKey() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} diff --git a/authmdl/jwtmdl/options.go b/authmdl/jwtmdl/options.go new file mode 100644 index 0000000000000000000000000000000000000000..e0db9986c0e24d749df21fb361c35250c77a0c21 --- /dev/null +++ b/authmdl/jwtmdl/options.go @@ -0,0 +1,73 @@ +package jwtmdl + +import ( + "time" + + "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/authmdl/sessionmdl" +) + +type Options struct { + Key string + UserID string + ClientIP string + Metadata string + Groups []string + ExpiresAt int64 + Session sessionmdl.Session +} + +type Option func(*Options) + +// WithKey uses provided jwt key for token generation +func WithKey(k string) Option { + return func(args *Options) { + args.Key = k + } +} + +func WithUserID(uid string) Option { + return func(args *Options) { + args.UserID = uid + } +} + +// WithSession enables session validation on jwt decode. Required fields must not be empty. +func WithSession(sid, sessionFor string) Option { + return func(args *Options) { + args.Session = sessionmdl.Session{ + SessionId: sid, + SessionFor: sessionFor, + } + } +} + +func WithClientIP(ip string) Option { + return func(args *Options) { + args.ClientIP = ip + } +} + +// WithMetaData embeds provided data in token. It is available againt `metadata` key. **It must be a valid json** +func WithMetaData(data string) Option { + return func(args *Options) { + args.Metadata = data + } +} + +func WithGroups(gs []string) Option { + return func(args *Options) { + args.Groups = gs + } +} + +// WithExpiration adds provided expiration to jwt token. Use `0` or ignore this option to generate a token witout expiry. +func WithExpiration(e time.Duration) Option { + return func(args *Options) { + if e == 0 { + args.ExpiresAt = 0 + return + } + + args.ExpiresAt = time.Now().Add(e).Unix() + } +} diff --git a/authmdl/sessionmdl/session.go b/authmdl/sessionmdl/session.go new file mode 100644 index 0000000000000000000000000000000000000000..10bfc22fa525274b9ccbdaded87c32d55e405cb8 --- /dev/null +++ b/authmdl/sessionmdl/session.go @@ -0,0 +1,165 @@ +// package sessionmdl provides APIs to Add, Validate and Delete user sessions. These APIs must be used along with JWT Auth. +// If you want to use this functionality, a jwt token must contain `userId` and `sessionId` fields. +// A user can have multiple active sessions for different use cases. To check if user has an active session for particular usecase use `CheckForSessionAvailability()`. +// And to check user session on each request use `ValidateSessionFromToken()`. +// +// An in memory cache is used to store sessions. It automatically falls back to redis cache if -gridmode=1 is set. +// +// The expiraion of session must be same as of token expiration. + +package sessionmdl + +import ( + "errors" + + "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/cachemdl" +) + +type Session struct { + SessionFor string + SessionId string +} + +// store is used to store sessions in memory, falls back to redis cache on grid mode. +var store cachemdl.Cacher + +var ( + ErrUserNotFound = errors.New("user not found") + ErrSessionNotFound = errors.New("session not found") + ErrInvalidSessionInstance = errors.New("got invalid session instance id") + ErrSessionValidationFailed = errors.New("session validation failed") +) + +// Init initializes sessions with provided cache. Subsequent calls will not have any effect after first initialization. +func Init(cache cachemdl.Cacher) { + if store != nil { + return + } + + store = cache +} + +// Set stores the sessions for provided userId. Session is appended to the list. It does not check if the same session exists or not. +func Set(userId string, s ...Session) { + i, ok := store.Get(userId) + if !ok || i == nil { + set(userId, s) + return + } + + sessions, ok := i.([]Session) + if !ok { + set(userId, s) + return + } + + set(userId, append(sessions, s...)) +} + +func set(key string, val interface{}) { + // Set the user sessions with no expiry as each session can have different expiry depending on the JWT token expiry. + store.SetNoExpiration(key, val) +} + +// Get returns all the available sessions for the user. This may contain expired but not deleted sessions. +func Get(userId string) ([]Session, error) { + var ( + s []Session + i interface{} + ok bool + ) + + i, ok = store.Get(userId) + if !ok { + return s, ErrUserNotFound + } + + s, _ = i.([]Session) + // if !ok { + // return s, errors.New("failed to retrieve previous sessions") + // } + + return s, nil +} + +// Delete removes all the sessions associated with the user. +func Delete(userId string) { + store.Delete(userId) +} + +// DeleteSession removes a particular session for user, if present. +func DeleteSession(userId, sessionFor string) { + + sessions, err := Get(userId) + if err != nil { + return + } + + for i := 0; i < len(sessions); i++ { + if sessions[i].SessionFor == sessionFor { + sessions[i] = sessions[len(sessions)-1] + sessions = sessions[:len(sessions)-1] + } + } + + if len(sessions) == 0 { + store.Delete(userId) + return + } + + set(userId, sessions) +} + +// ValidateSessionFromToken checks for session id in claims against available sessions. +// Validate only if a nonempty `sessionId` is present. The claims must contain `userId` field if session is present. +func ValidateSessionFromToken(claims map[string]interface{}) error { + + // check for sessionId field, if not present then it is ignored at the time of token generation. + // This means user doesn't want to validate session. + i, ok := claims["sessionId"] + if !ok || i == nil { + return nil + } + + sessionId, _ := i.(string) + if len(sessionId) == 0 { + return errors.New("\"sessionId\" field is empty") + } + + i, ok = claims["userId"] + if !ok { + return errors.New("\"userId\" field not found in token") + } + + userId, _ := i.(string) + if len(userId) == 0 { + return errors.New("\"userId\" field is empty") + } + + sessions, err := Get(userId) + if err != nil { + return err + } + + for i := range sessions { + if sessions[i].SessionId == sessionId { + return nil + } + } + + return ErrSessionNotFound +} + +// CheckForSessionAvailability checks if the user has active session for provided `sessionFor`. Returns true if session is available. +func CheckForSessionAvailability(userId, sessionFor string) bool { + + sessions, _ := Get(userId) + + for i := range sessions { + if sessions[i].SessionFor == sessionFor { + return true + } + } + + return false +} diff --git a/authmdl/sessionmdl/session_test.go b/authmdl/sessionmdl/session_test.go new file mode 100644 index 0000000000000000000000000000000000000000..fb27c67a2d8f2bd801f772bfcc0c651379121e53 --- /dev/null +++ b/authmdl/sessionmdl/session_test.go @@ -0,0 +1,147 @@ +package sessionmdl + +import ( + "testing" + + "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/cachemdl" +) + +func init() { + Init(cachemdl.SetupFastCache(cachemdl.FCWithMaxEntries(10000))) +} + +func TestSet(t *testing.T) { + Set("tom@company.org", Session{SessionFor: "xyz", SessionId: "789"}) + + type args struct { + userId string + s []Session + } + tests := []struct { + name string + args args + }{ + // TODO: Add test cases. + { + name: "User present", + args: args{ + s: []Session{Session{SessionFor: "abc", SessionId: "123"}}, + userId: "tom@company.org", + }, + }, + { + name: "User not present", + args: args{ + s: []Session{Session{SessionFor: "abc", SessionId: "123"}}, + userId: "ronny@company.org", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Set(tt.args.userId, tt.args.s...) + }) + } + + all := store.GetItemsCount() + if all != len(tests) { + t.Errorf("expected %d users got %d", len(tests), all) + } +} + +func TestValidateSessionFromToken(t *testing.T) { + Set("tom@company.org", Session{SessionFor: "xyz", SessionId: "789"}) + type args struct { + claims map[string]interface{} + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "Session present", + args: args{claims: map[string]interface{}{"userId": "tom@company.org", "sessionId": "789"}}, + wantErr: false, + }, + { + name: "Session not present", + args: args{claims: map[string]interface{}{"userId": "invalid@company.org", "sessionId": "123"}}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := ValidateSessionFromToken(tt.args.claims); (err != nil) != tt.wantErr { + t.Errorf("ValidateSessionFromToken() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestCheckForSessionAvailability(t *testing.T) { + Set("tom@company.org", Session{SessionFor: "xyz", SessionId: "789"}) + type args struct { + userId string + sessionFor string + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "Session present", + args: args{userId: "tom@company.org", sessionFor: "xyz"}, + want: true, + }, + { + name: "Session not present", + args: args{userId: "tom@company.org", sessionFor: "someRandomID"}, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if res := CheckForSessionAvailability(tt.args.userId, tt.args.sessionFor); res != tt.want { + t.Errorf("CheckForSessionAvailability() got = %v, want %v", res, tt.want) + } + }) + } +} + +func TestDeleteSession(t *testing.T) { + Set("TestDeleteSession", Session{SessionFor: "xyz", SessionId: "789"}) + Set("TestDeleteSession", Session{SessionFor: "abc", SessionId: "123"}) + type args struct { + userId string + sessionFor string + } + tests := []struct { + name string + args args + }{ + { + name: "Delete existing session", + args: args{userId: "TestDeleteSession", sessionFor: "xyz"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + DeleteSession(tt.args.userId, tt.args.sessionFor) + i, _ := store.Get(tt.args.userId) + // if !ok { + // t.Error("expected", tt.args.userId, "not to be deleted") + // } + + sessions := i.([]Session) + + for _, s := range sessions { + if s.SessionFor == tt.args.sessionFor { + t.Error("expected", tt.args.sessionFor, "to be deleted") + } + } + }) + } + +} diff --git a/cachemdl/cache.go b/cachemdl/cache.go index fff9c6acee17530ccd1a49611dd54ed6acdc6092..01452d9ec7266b7a53349a6430c784d6b44958f8 100644 --- a/cachemdl/cache.go +++ b/cachemdl/cache.go @@ -1,7 +1,6 @@ package cachemdl import ( - "errors" "time" ) @@ -12,11 +11,6 @@ const ( TypeRedisCache ) -var ( - // ErrInvalidCacheType indicates provided cache type is not supported - ErrInvalidCacheType = errors.New("invalid cache type provided") -) - // Cacher provides access to underlying cache, make sure all caches implement these methods. // // The return types of data can be different. Ex. In case of redis cache it is `string`. The caller needs to handle this with the help of Type() method. @@ -31,7 +25,7 @@ type Cacher interface { // GET Get(key string) (interface{}, bool) - // GetAll() map[string]interface{} + GetAll() map[string]interface{} // DELETE Delete(key string) @@ -42,26 +36,3 @@ type Cacher interface { Type() int } - -// GetCacheInstance returns a cache instance, panics if invalid cache type is provided -func GetCacheInstance(cfg *CacheConfig) Cacher { - switch cfg.Type { - case TypeFastCache: - cfg.FastCache.Setup(cfg.FastCache.MaxEntries, cfg.FastCache.Expiration, cfg.FastCache.CleanupTime) - return cfg.FastCache - - case TypeRedisCache: - cfg.RedisCache.Setup(cfg.RedisCache.Addr, cfg.RedisCache.Password, cfg.RedisCache.Prefix, cfg.RedisCache.DB, cfg.RedisCache.Expiration) - return cfg.RedisCache - - default: - panic(ErrInvalidCacheType) - } -} - -// CacheConfig - -type CacheConfig struct { - Type int - FastCache *FastCacheHelper - RedisCache *RedisCache -} diff --git a/cachemdl/cache_redis.go b/cachemdl/cache_redis.go index a021c28a6a8b43400d46fbfa25aa26a207744855..a516b36c371ae9658ecc54be65706b0235d04e05 100644 --- a/cachemdl/cache_redis.go +++ b/cachemdl/cache_redis.go @@ -16,6 +16,7 @@ Note - corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/loggermdl must import ( "encoding/json" + "errors" "log" "strings" "time" @@ -28,11 +29,6 @@ import ( const ( noExp time.Duration = 0 - // NotOK refers to a unsuccessfull operations - NotOK int64 = 0 - // OK refers to a successfull operations - OK int64 = 1 - keySplitter = ":" ) @@ -51,6 +47,42 @@ type RedisCache struct { Prefix string // this will be used for storing keys for provided project } +type configRedis struct { + addr string // redis server address, default "127.0.0.1:6379" + db int // redis DB on provided server, default 0 + password string // + expiration time.Duration // this duration will be used for Set() method + prefix string // this will be used for storing keys for provided project +} + +type redisOption func(*configRedis) + +func RedisWithAddr(addr string) redisOption { + return func(cfg *configRedis) { + cfg.addr = addr + } +} +func RedisWithDB(db int) redisOption { + return func(cfg *configRedis) { + cfg.db = db + } +} +func RedisWithPrefix(pfx string) redisOption { + return func(cfg *configRedis) { + cfg.prefix = pfx + } +} +func RedisWithPassword(p string) redisOption { + return func(cfg *configRedis) { + cfg.password = p + } +} +func RedisWithExpiration(exp time.Duration) redisOption { + return func(cfg *configRedis) { + cfg.expiration = exp + } +} + // Setup initializes redis cache for application. Must be called only once. func (rc *RedisCache) Setup(addr, password, prefix string, db int, exp time.Duration) { @@ -87,6 +119,46 @@ func (rc *RedisCache) Setup(addr, password, prefix string, db int, exp time.Dura } +// SetupRedisCache initializes redis cache for application and returns it. Must be called only once. +func SetupRedisCache(opts ...redisOption) (*RedisCache, error) { + + rc := new(RedisCache) + + cfg := new(configRedis) + + for i := range opts { + opts[i](cfg) + } + + rc.Addr = cfg.addr + rc.Password = cfg.password + rc.DB = cfg.db + rc.Expiration = cfg.expiration + rc.Prefix = cfg.prefix + + rc.opt = &redis.Options{ + Addr: cfg.addr, + Password: cfg.password, + DB: cfg.db, + } + + rc.cli = redis.NewClient(rc.opt) + + if _, err := rc.cli.Ping().Result(); err != nil { + + return nil, errors.New("connection to redis server failed: " + err.Error()) + } + + rc.connected = true + + if cfg.prefix != "" { + rc.keyStr = contcat(rc.Prefix, keySplitter) + rc.addPrefix = true + } + + return rc, nil +} + // Set marshalls provided value and stores against provided key. Errors will be logged to initialized logger. func (rc *RedisCache) Set(key string, val interface{}) { ba, err := marshalWithTypeCheck(val) @@ -121,17 +193,8 @@ func (rc *RedisCache) SetNoExpiration(key string, val interface{}) { rc.cli.Set(rc.key(key), ba, noExp) } -// Get returns data against provided key. The final result is parsed with gjson. Returns false if not present. +// Get returns data against provided key. Returns false if not present. func (rc *RedisCache) Get(key string) (interface{}, bool) { - // exists, err := rc.cli.Exists(key).Result() - // if err != nil { - // loggermdl.LogError("error checking key ", key, " error: ", err) - // return nil, false - // } - - // if exists == NotOK { - // return nil, false - // } // Get returns error if key is not present. val, err := rc.cli.Get(rc.key(key)).Result() @@ -150,13 +213,13 @@ func (rc *RedisCache) Delete(key string) { // GetItemsCount - func (rc *RedisCache) GetItemsCount() int { - pattern := "*" - keys, err := rc.cli.Keys(pattern).Result() - if err != nil { - loggermdl.LogError("error getting item count for ", pattern, " error: ", err) - return 0 - } - return len(keys) + // pattern := rc.Prefix + "*" + // keys, err := rc.cli.Keys(pattern).Result() + // if err != nil { + // loggermdl.LogError("error getting item count for ", pattern, " error: ", err) + // return 0 + // } + return len(rc.keys()) } func (rc *RedisCache) flushDB() (string, error) { @@ -204,6 +267,47 @@ func (rc *RedisCache) key(key string) string { return key } +func (rc *RedisCache) actualKey(key string) string { + if rc.addPrefix { + return strings.TrimPrefix(key, rc.keyStr) + } + return key +} + func (rc *RedisCache) Type() int { return TypeRedisCache } + +// GetAll returns all keys with values present in redis server. Excludes the keys which does not have specified prefix. If prefix is empty, then returns all keys. +// +// **This is not intended for production use. May hamper performance** +func (rc *RedisCache) GetAll() map[string]interface{} { + keys := rc.keys() + + result := make(map[string]interface{}, len(keys)) + + for i := range keys { + ba, err := rc.cli.Get(keys[i]).Bytes() + if err != nil { + loggermdl.LogError("error getting key", keys[i], "from redis cache with error:", err) + continue + } + + var val interface{} + _ = json.Unmarshal(ba, &val) + + result[rc.actualKey(keys[i])] = val + } + + return result +} + +// GetItemsCount - +func (rc *RedisCache) keys() []string { + pattern := rc.Prefix + "*" + keys, err := rc.cli.Keys(pattern).Result() + if err != nil { + loggermdl.LogError("error getting item count for ", pattern, " error: ", err) + } + return keys +} diff --git a/cachemdl/cache_redis_test.go b/cachemdl/cache_redis_test.go index 7f6c80136c3cf56e3def86f4f57033359ebc5e38..5be7d5bb37f0b54c6d953adf28e518fd84fd8ebf 100644 --- a/cachemdl/cache_redis_test.go +++ b/cachemdl/cache_redis_test.go @@ -405,3 +405,51 @@ func BenchmarkMarshalWithTypeCheckStruct(b *testing.B) { _, _ = marshalWithTypeCheck(s) } } + +func TestRedisCache_GetAll(t *testing.T) { + tests := []struct { + name string + rc *RedisCache + want map[string]interface{} + init func(rc *RedisCache) + }{ + { + name: "Get All Items", + rc: &RedisCache{}, + want: map[string]interface{}{ + "a": 1.24, + "b": 1.25, + }, + init: func(rc *RedisCache) { + rc.Setup("127.0.0.1:6379", "", "tests", 0, time.Second*60) + rc.flushDB() + + rc.Set("a", 1.24) + rc.Set("b", 1.25) + }, + }, + { + name: "Get All Items without prefix", + rc: &RedisCache{}, + want: map[string]interface{}{ + "a": 5.24, + "b": 5.25, + }, + init: func(rc *RedisCache) { + rc.Setup("127.0.0.1:6379", "", "", 0, time.Second*60) + rc.flushDB() + + rc.Set("a", 5.24) + rc.Set("b", 5.25) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.init(tt.rc) + if got := tt.rc.GetAll(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("RedisCache.GetAll() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cachemdl/cachemdl.go b/cachemdl/cachemdl.go index 5d15698eb9f11e1918fb6291148f3d123d562b11..6d7b7e5e65c36bbd066e876969d7fabf2f776840 100755 --- a/cachemdl/cachemdl.go +++ b/cachemdl/cachemdl.go @@ -19,6 +19,26 @@ type FastCacheHelper struct { MaxEntries int } +type fastCacheOption func(*FastCacheHelper) + +func FCWithMaxEntries(i int) fastCacheOption { + return func(cfg *FastCacheHelper) { + cfg.MaxEntries = i + } +} + +func FCWithExpiration(exp time.Duration) fastCacheOption { + return func(cfg *FastCacheHelper) { + cfg.Expiration = exp + } +} + +func FCWithCleanupInterval(ivl time.Duration) fastCacheOption { + return func(cfg *FastCacheHelper) { + cfg.CleanupTime = ivl + } +} + // Setup initializes fastcache cache for application. Must be called only once. func (fastCacheHelper *FastCacheHelper) Setup(maxEntries int, expiration time.Duration, cleanupTime time.Duration) { @@ -28,6 +48,18 @@ func (fastCacheHelper *FastCacheHelper) Setup(maxEntries int, expiration time.Du } +// SetupFastCache initializes fastcache cache for application and returns its instance. +func SetupFastCache(opts ...fastCacheOption) *FastCacheHelper { + fc := new(FastCacheHelper) + + for i := range opts { + opts[i](fc) + } + + fc.FastCache = cache.New(fc.Expiration, fc.CleanupTime) + return fc +} + // Get - func (fastCacheHelper *FastCacheHelper) Get(key string) (interface{}, bool) { return fastCacheHelper.FastCache.Get(key) @@ -71,3 +103,15 @@ func (fastCacheHelper *FastCacheHelper) GetItemsCount() int { func (fh *FastCacheHelper) Type() int { return TypeFastCache } + +// GetAll returns all keys with values present in memory. **This is not intended for production use. May hamper performance** +func (fastCacheHelper *FastCacheHelper) GetAll() map[string]interface{} { + items := fastCacheHelper.FastCache.Items() + + result := make(map[string]interface{}, len(items)) + for k, v := range items { + result[k] = v.Object + } + + return result +} diff --git a/cachemdl/cachemdl_test.go b/cachemdl/cachemdl_test.go index 5a815eba95c699ef130136dc653f4498495e0225..de9d36173da7e3e663106b93e2edd0b9942a892e 100755 --- a/cachemdl/cachemdl_test.go +++ b/cachemdl/cachemdl_test.go @@ -1,7 +1,14 @@ //@author Ajit Jagtap + //@version Thu Jul 05 2018 06:11:54 GMT+0530 (IST) + package cachemdl +import ( + "reflect" + "testing" +) + // import ( // "testing" // "time" @@ -39,3 +46,42 @@ package cachemdl // cnt = ch.Count() // assert.Zero(t, cnt, "After Purge Count should be zero") // } + +func TestFastCacheHelper_GetAll(t *testing.T) { + tests := []struct { + name string + fastCacheHelper *FastCacheHelper + want map[string]interface{} + init func(fs *FastCacheHelper) + }{ + { + name: "Get all items Success", + fastCacheHelper: &FastCacheHelper{}, + want: map[string]interface{}{ + "a": 1, + "b": 2, + }, + init: func(fs *FastCacheHelper) { + fs.Setup(2, 0, 0) + fs.Set("a", 1) + fs.Set("b", 2) + }, + }, + { + name: "Get all items Empty", + fastCacheHelper: &FastCacheHelper{}, + want: map[string]interface{}{}, + init: func(fs *FastCacheHelper) { + fs.Setup(2, 0, 0) + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.init(tt.fastCacheHelper) + if got := tt.fastCacheHelper.GetAll(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("FastCacheHelper.GetAll() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/dalmdl/corefdb/bucket/packbucket.go b/dalmdl/corefdb/bucket/packbucket.go index 6fb1d1400ad2a48902338b39ac2629559bee1080..3832d03eca11a39c3810f3fb5aa9d27c8abcda81 100644 --- a/dalmdl/corefdb/bucket/packbucket.go +++ b/dalmdl/corefdb/bucket/packbucket.go @@ -288,5 +288,5 @@ func (pb *PackBucket) Reorg(filePaths []string) (errList []error) { continue } } - return nil + return } diff --git a/dalmdl/corefdb/corefdb.go b/dalmdl/corefdb/corefdb.go index e4baed7e7a21ed534c452c958cccf72e0650a3f5..eca4468b96942f9e023d88f2b5524a747e59c169 100644 --- a/dalmdl/corefdb/corefdb.go +++ b/dalmdl/corefdb/corefdb.go @@ -186,6 +186,12 @@ func (f *FDB) GetFDBIndex(indexID string) (*index.Index, bool) { return index, ok } +// GetFDBBucketStore - returns bucket +func (f *FDB) GetFDBBucketStore(bucketID string) (bucket.Store, bool) { + store, ok := f.buckets[bucketID] + return store, ok +} + func SaveDataInFDB(dbName string, indexID string, data *gjson.Result) error { fdb, err := GetFDBInstance(dbName) if err != nil { @@ -286,6 +292,38 @@ func SaveDataInFDB(dbName string, indexID string, data *gjson.Result) error { return nil } +func GetFilePaths(dbName, indexID string, queries []string) ([]string, error) { + filePaths := make([]string, 0) + fdb, err := GetFDBInstance(dbName) + if err != nil { + loggermdl.LogError("fdb instance not found: ", dbName) + return filePaths, errormdl.Wrap("fdb instance not found " + dbName) + } + // get index Id from index map + index, ok := fdb.GetFDBIndex(indexID) + if !ok { + loggermdl.LogError("INDEX not found: " + indexID) + return filePaths, errormdl.Wrap("INDEX not found: " + indexID) + } + + indexKeyValueMap, err := index.GetEntriesByQueries(queries) + if err != nil { + loggermdl.LogError(err) + return filePaths, err + } + if len(indexKeyValueMap) == 0 { + return filePaths, nil + } + for filePath := range indexKeyValueMap { + filePath, err = filepath.Abs(filepath.Join(fdb.DBPath, filePath)) + if err != nil { + loggermdl.LogError(err) + return filePaths, err + } + filePaths = append(filePaths, filePath) + } + return filePaths, nil +} func ReadDataFromFDB(dbName, indexID string, data *gjson.Result, queries []string, infileIndexQueries []string) (*gjson.Result, error) { fdb, err := GetFDBInstance(dbName) if err != nil { @@ -316,7 +354,6 @@ func ReadDataFromFDB(dbName, indexID string, data *gjson.Result, queries []strin } resultToReturn := gjson.Parse("[]") if len(indexKeyValueMap) == 0 { - loggermdl.LogError("files not found") return &resultToReturn, nil } filePaths := make([]string, 0) @@ -328,7 +365,6 @@ func ReadDataFromFDB(dbName, indexID string, data *gjson.Result, queries []strin } filePaths = append(filePaths, filePath) } - loggermdl.LogError("filePaths", filePaths) resultArray, err := bucket.Find(filePaths, infileIndexQueries, data) if err != nil { loggermdl.LogError(err) diff --git a/dalmdl/corefdb/filetype/pack.go b/dalmdl/corefdb/filetype/pack.go index 36efd81ef7dc24e4073465332356d49c93922363..119e0041f824372f167827bdf9d4b514657fbf0b 100644 --- a/dalmdl/corefdb/filetype/pack.go +++ b/dalmdl/corefdb/filetype/pack.go @@ -578,6 +578,49 @@ func (p *PackFile) ReadMedia(recordID string) ([]byte, *gjson.Result, error) { return dataByte, &metaDataObj, nil } +func (p *PackFile) ReadMediaByQuery(inFileIndexQueries []string) (map[string][]byte, map[string]gjson.Result, error) { + f := p.Fp + indexData, err := getInFileIndexData(f) + if err != nil { + loggermdl.LogError("index data not found: ", f.Name(), err) + return nil, nil, err + } + indexRows := indexData + // indexRows := indexData.Get(`#[fileType==` + requestedFileType + `]#`) + for i := 0; i < len(inFileIndexQueries); i++ { + indexRows = indexRows.Get(inFileIndexQueries[i] + "#") + } + if indexRows.String() == "" { + loggermdl.LogError("data not found") + return nil, nil, errormdl.Wrap("data not found") + } + dataMap := make(map[string][]byte, 0) + metaDataMap := make(map[string]gjson.Result, 0) + for _, indexRow := range indexRows.Array() { + startOffSet := indexRow.Get("startOffset").Int() + dataSize := indexRow.Get("dataSize").Int() + if startOffSet == 0 || dataSize == 0 { + return nil, nil, errormdl.Wrap("data not found") + } + dataByte, err := getFileDataFromPack(f, startOffSet, dataSize, nil, nil) + if err != nil { + loggermdl.LogError(err) + return nil, nil, err + } + recordID := indexRow.Get("recordID").String() + if recordID == "" { + return nil, nil, errormdl.Wrap("record id not found") + } + data, _ := sjson.Set("", "requiredData", indexRow.Get("requiredData").String()) + // data, _ = sjson.Set(data, "infileIndex", indexData.String()) + metaDataObj := gjson.Parse(data) + dataMap[recordID] = dataByte + metaDataMap[recordID] = metaDataObj + } + + return dataMap, metaDataMap, nil +} + func (p *PackFile) RemoveMedia(recordID string) error { queries := []string{`#[recordID=` + recordID + `]`} _, err := p.Remove(queries) diff --git a/dalmdl/coremongo/coremongo.go b/dalmdl/coremongo/coremongo.go index f775d6f3743f93a76d63eb07ca9a17ec444536bf..8163d69b266d9e9ea8c7984ab0bfc06530ec87fa 100644 --- a/dalmdl/coremongo/coremongo.go +++ b/dalmdl/coremongo/coremongo.go @@ -222,7 +222,7 @@ func (mg *MongoDAO) SaveData(data interface{}) (string, error) { if errormdl.CheckErr1(insertError) != nil { return "", errormdl.CheckErr1(insertError) } - return opts.InsertedID.(primitive.ObjectID).Hex(), nil + return getInsertedId(opts.InsertedID), nil } // UpdateAll update all @@ -686,3 +686,14 @@ func bindMongoServerWithPort(server string, port int) string { } return serverURI } + +func getInsertedId(id interface{}) string { + switch v := id.(type) { + case string: + return v + case primitive.ObjectID: + return v.Hex() + default: + return "" + } +} diff --git a/dalmdl/dalmdl.go b/dalmdl/dalmdl.go index 0b8917a5d3c46ac78f8529164146177640baac1e..a1416c0b0d73e556c8817ceeecb6d22bb37b51a8 100755 --- a/dalmdl/dalmdl.go +++ b/dalmdl/dalmdl.go @@ -1 +1,9 @@ package dalmdl + +const ( + MONGODB = "MONGO" + MYSQL = "MYSQL" + FDB = "FDB" + SQLSERVER = "SQLSERVER" + GraphDB = "GRAPHDB" +) diff --git a/routebuildermdl/routebuilder_fasthttp.go b/routebuildermdl/routebuilder_fasthttp.go index 2ab4ce36ee1a6d4437f3e834f950306f0dbb4a05..2a49e1e3d1a7b12d6686039239fbbe6467df584c 100644 --- a/routebuildermdl/routebuilder_fasthttp.go +++ b/routebuildermdl/routebuilder_fasthttp.go @@ -4,21 +4,20 @@ package routebuildermdl import ( "context" - "coresls/servers/coresls/app/modules/constantmdl" + "net" "strings" - "github.com/pquerna/ffjson/ffjson" - - routing "github.com/qiangxue/fasthttp-routing" - - "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/statemdl" - "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/authmdl/jwtmdl" "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/authmdl/roleenforcemdl" + "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/dalmdl" "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/errormdl" "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/loggermdl" "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/servicebuildermdl" + "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/statemdl" + version "github.com/hashicorp/go-version" + "github.com/pquerna/ffjson/ffjson" + routing "github.com/qiangxue/fasthttp-routing" "github.com/tidwall/gjson" ) @@ -72,7 +71,7 @@ func commonHandler(c *routing.Context, isRestricted, isRoleBased, heavyDataActiv var err error // database transaction rollback if transaction is enabled switch ab.DatabaseType { - case constantmdl.MYSQL: + case dalmdl.MYSQL: if ab.TXN != nil { loggermdl.LogError("MYSQL Transaction Rollbacked") err = ab.TXN.Rollback() @@ -82,7 +81,7 @@ func commonHandler(c *routing.Context, isRestricted, isRoleBased, heavyDataActiv } } - case constantmdl.SQLSERVER: + case dalmdl.SQLSERVER: if ab.SQLServerTXN != nil { loggermdl.LogError("SQLSERVER Transaction Rollbacked") err = ab.SQLServerTXN.Rollback() @@ -92,7 +91,7 @@ func commonHandler(c *routing.Context, isRestricted, isRoleBased, heavyDataActiv } } - case constantmdl.GraphDB: + case dalmdl.GraphDB: if ab.GraphDbTXN != nil { loggermdl.LogError("GRAPHDB Transaction Rollbacked") err = ab.GraphDbTXN.Discard(context.TODO()) @@ -116,7 +115,7 @@ func commonHandler(c *routing.Context, isRestricted, isRoleBased, heavyDataActiv var err error switch ab.DatabaseType { - case constantmdl.MYSQL: + case dalmdl.MYSQL: if ab.TXN != nil { loggermdl.LogError("MYSQL Transaction Commit") err = ab.TXN.Commit() @@ -129,7 +128,7 @@ func commonHandler(c *routing.Context, isRestricted, isRoleBased, heavyDataActiv } } - case constantmdl.SQLSERVER: + case dalmdl.SQLSERVER: if ab.SQLServerTXN != nil { loggermdl.LogError("SQLSERVER Transaction Commit") err = ab.SQLServerTXN.Commit() @@ -142,7 +141,7 @@ func commonHandler(c *routing.Context, isRestricted, isRoleBased, heavyDataActiv } } - case constantmdl.GraphDB: + case dalmdl.GraphDB: if ab.SQLServerTXN != nil { loggermdl.LogError("GRAPHDB Transaction Commit") err = ab.GraphDbTXN.Commit(context.TODO()) @@ -205,8 +204,7 @@ func commonHandler(c *routing.Context, isRestricted, isRoleBased, heavyDataActiv func OpenHandler(c *routing.Context) error { c.Response.Header.Set("content-type", "application/json") principal := servicebuildermdl.Principal{} - - principal.ClientIP = c.RemoteIP().String() + principal.ClientIP = getClientIP(c) commonHandler(c, false, false, false, principal) return nil } @@ -221,7 +219,6 @@ func RestrictedHandler(c *routing.Context) error { c.SetStatusCode(412) return err } - pricipalObj.ClientIP = c.RemoteIP().String() commonHandler(c, true, false, false, pricipalObj) return nil } @@ -236,7 +233,6 @@ func RoleBasedHandler(c *routing.Context) error { c.SetStatusCode(412) return err } - pricipalObj.ClientIP = c.RemoteIP().String() commonHandler(c, true, true, false, pricipalObj) return nil } @@ -245,8 +241,7 @@ func RoleBasedHandler(c *routing.Context) error { func HeavyOpenHandler(c *routing.Context) error { c.Response.Header.Set("content-type", "application/json") principal := servicebuildermdl.Principal{} - - principal.ClientIP = c.RemoteIP().String() + principal.ClientIP = getClientIP(c) commonHandler(c, false, false, true, principal) return nil } @@ -261,7 +256,6 @@ func HeavyRestrictedHandler(c *routing.Context) error { c.SetStatusCode(412) return err } - pricipalObj.ClientIP = c.RemoteIP().String() commonHandler(c, true, false, true, pricipalObj) return nil } @@ -276,7 +270,6 @@ func HeavyRoleBasedHandler(c *routing.Context) error { c.SetStatusCode(412) return err } - pricipalObj.ClientIP = c.RemoteIP().String() commonHandler(c, true, true, true, pricipalObj) return nil } @@ -315,7 +308,7 @@ func extractPricipalObject(c *routing.Context) (servicebuildermdl.Principal, err claim, decodeError := jwtmdl.DecodeToken(&c.Request) if errormdl.CheckErr(decodeError) != nil { - loggermdl.LogError(decodeError) + // loggermdl.LogError(decodeError) return principal, errormdl.CheckErr(decodeError) } @@ -324,12 +317,7 @@ func extractPricipalObject(c *routing.Context) (servicebuildermdl.Principal, err loggermdl.LogError(grperr) return principal, errormdl.CheckErr(grperr) } - userID, ok := claim["userId"].(string) - if !ok { - loggermdl.LogError("Unable to parse UserID from JWT Token") - return principal, errormdl.Wrap("Unable to parse UserID from JWT Token") - } - + userID, _ := claim["userId"].(string) if len(userID) < 2 { loggermdl.LogError("UserID length is less than 2") return principal, errormdl.Wrap("UserID length is less than 2") @@ -347,5 +335,31 @@ func extractPricipalObject(c *routing.Context) (servicebuildermdl.Principal, err principal.Groups = groups principal.UserID = userID principal.Token = string(c.Request.Header.Peek("Authorization")) + // set client ip + principal.ClientIP = getClientIP(c) return principal, nil } + +// getClientIP - returns respected header value from request header +func getClientIP(c *routing.Context) string { + clientIP := string(c.Request.Header.Peek("X-Real-Ip")) + if clientIP == "" { + clientIP = string(c.Request.Header.Peek("X-Forwarded-For")) + } + if clientIP == "" { + clientIP, _, splitHostPortError := net.SplitHostPort(c.RemoteIP().String()) + if splitHostPortError == nil && isCorrectIP(clientIP) { + return clientIP + } + return "" + } + if isCorrectIP(clientIP) { + return clientIP + } + return "" +} + +// isCorrectIP - return true if ip string is valid textual representation of an IP address, else returns false +func isCorrectIP(ip string) bool { + return net.ParseIP(ip) != nil +} diff --git a/routebuildermdl/routebuilder_gin.go b/routebuildermdl/routebuilder_gin.go index 8a91f0fdbfc14d989f044b8404cb036bd83546d0..528db5083682e7a51ca4359cfc4b692a17f7e29b 100644 --- a/routebuildermdl/routebuilder_gin.go +++ b/routebuildermdl/routebuilder_gin.go @@ -4,19 +4,19 @@ package routebuildermdl import ( "context" - "coresls/servers/coresls/app/modules/constantmdl" "io/ioutil" + "net" "net/http" "strings" - "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/authmdl/roleenforcemdl" - - "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/statemdl" - "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/authmdl/jwtmdl" + "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/authmdl/roleenforcemdl" + "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/dalmdl" "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/errormdl" "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/loggermdl" "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/servicebuildermdl" + "corelab.mkcl.org/MKCLOS/coredevelopmentplatform/corepkgv2/statemdl" + "github.com/gin-gonic/gin" version "github.com/hashicorp/go-version" "github.com/tidwall/gjson" @@ -81,7 +81,7 @@ func commonHandler(c *gin.Context, isRestricted, isRoleBased, heavyDataActivity var err error // database transaction rollback if transaction is enabled switch ab.DatabaseType { - case constantmdl.MYSQL: + case dalmdl.MYSQL: if ab.TXN != nil { loggermdl.LogError("MYSQL Transaction Rollbacked") err = ab.TXN.Rollback() @@ -91,7 +91,7 @@ func commonHandler(c *gin.Context, isRestricted, isRoleBased, heavyDataActivity } } - case constantmdl.SQLSERVER: + case dalmdl.SQLSERVER: if ab.SQLServerTXN != nil { loggermdl.LogError("SQLSERVER Transaction Rollbacked") err = ab.SQLServerTXN.Rollback() @@ -101,7 +101,7 @@ func commonHandler(c *gin.Context, isRestricted, isRoleBased, heavyDataActivity } } - case constantmdl.GraphDB: + case dalmdl.GraphDB: if ab.GraphDbTXN != nil { loggermdl.LogError("GRAPHDB Transaction Rollbacked") err = ab.GraphDbTXN.Discard(context.TODO()) @@ -125,7 +125,7 @@ func commonHandler(c *gin.Context, isRestricted, isRoleBased, heavyDataActivity var err error switch ab.DatabaseType { - case constantmdl.MYSQL: + case dalmdl.MYSQL: if ab.TXN != nil { loggermdl.LogError("MYSQL Transaction Commit") err = ab.TXN.Commit() @@ -138,7 +138,7 @@ func commonHandler(c *gin.Context, isRestricted, isRoleBased, heavyDataActivity } } - case constantmdl.SQLSERVER: + case dalmdl.SQLSERVER: if ab.SQLServerTXN != nil { loggermdl.LogError("SQLSERVER Transaction Commit") err = ab.SQLServerTXN.Commit() @@ -151,7 +151,7 @@ func commonHandler(c *gin.Context, isRestricted, isRoleBased, heavyDataActivity } } - case constantmdl.GraphDB: + case dalmdl.GraphDB: if ab.SQLServerTXN != nil { loggermdl.LogError("GRAPHDB Transaction Commit") err = ab.GraphDbTXN.Commit(context.TODO()) @@ -193,8 +193,7 @@ func commonHandler(c *gin.Context, isRestricted, isRoleBased, heavyDataActivity // OpenHandler for /o func OpenHandler(c *gin.Context) { principal := servicebuildermdl.Principal{} - - // principal.ClientIP = c.Request.RemoteAddr + principal.ClientIP = getClientIP(c) commonHandler(c, false, false, false, principal) } @@ -206,7 +205,6 @@ func RestrictedHandler(c *gin.Context) { c.JSON(http.StatusExpectationFailed, extractError.Error()) return } - pricipalObj.ClientIP = c.Request.RemoteAddr commonHandler(c, true, false, false, pricipalObj) } @@ -218,15 +216,13 @@ func RoleBasedHandler(c *gin.Context) { c.JSON(http.StatusExpectationFailed, extractError.Error()) return } - pricipalObj.ClientIP = c.Request.RemoteAddr commonHandler(c, true, true, false, pricipalObj) } // HeavyOpenHandler for /o func HeavyOpenHandler(c *gin.Context) { principal := servicebuildermdl.Principal{} - - // principal.ClientIP = c.Request.RemoteAddr + principal.ClientIP = getClientIP(c) commonHandler(c, false, false, true, principal) } @@ -238,7 +234,6 @@ func HeavyRestrictedHandler(c *gin.Context) { c.JSON(http.StatusExpectationFailed, extractError.Error()) return } - pricipalObj.ClientIP = c.Request.RemoteAddr commonHandler(c, true, false, true, pricipalObj) } @@ -250,7 +245,6 @@ func HeavyRoleBasedHandler(c *gin.Context) { c.JSON(http.StatusExpectationFailed, extractError.Error()) return } - pricipalObj.ClientIP = c.Request.RemoteAddr commonHandler(c, true, true, true, pricipalObj) } @@ -287,7 +281,6 @@ func extractPricipalObject(c *gin.Context) (servicebuildermdl.Principal, error) } claim, decodeError := jwtmdl.DecodeToken(c.Request) if errormdl.CheckErr(decodeError) != nil { - loggermdl.LogError(decodeError) return principal, errormdl.CheckErr(decodeError) } @@ -296,12 +289,7 @@ func extractPricipalObject(c *gin.Context) (servicebuildermdl.Principal, error) loggermdl.LogError(grperr) return principal, errormdl.CheckErr(grperr) } - userID, ok := claim["userId"].(string) - if !ok { - loggermdl.LogError("Unable to parse UserID from JWT Token") - return principal, errormdl.Wrap("Unable to parse UserID from JWT Token") - } - + userID, _ := claim["userId"].(string) if len(userID) < 2 { loggermdl.LogError("UserID length is less than 2") return principal, errormdl.Wrap("UserID length is less than 2") @@ -319,5 +307,30 @@ func extractPricipalObject(c *gin.Context) (servicebuildermdl.Principal, error) principal.Groups = groups principal.UserID = userID principal.Token = c.Request.Header.Get("Authorization") + principal.ClientIP = getClientIP(c) return principal, nil } + +// getClientIP - returns respected header value from request header +func getClientIP(c *gin.Context) string { + clientIP := c.Request.Header.Get("X-Real-Ip") + if clientIP == "" { + clientIP = c.Request.Header.Get("X-Forwarded-For") + } + if clientIP == "" { + clientIP, _, splitHostPortError := net.SplitHostPort(c.Request.RemoteAddr) + if splitHostPortError == nil && isCorrectIP(clientIP) { + return clientIP + } + return "" + } + if isCorrectIP(clientIP) { + return clientIP + } + return "" +} + +// isCorrectIP - return true if ip string is valid textual representation of an IP address, else returns false +func isCorrectIP(ip string) bool { + return net.ParseIP(ip) != nil +} diff --git a/servicebuildermdl/servicebuildermdl.go b/servicebuildermdl/servicebuildermdl.go index 163879e9cda9c092cbd3fbfe21761d202e65a8b6..a774c7ffff35b70dcbc247835009e03b8003a985 100644 --- a/servicebuildermdl/servicebuildermdl.go +++ b/servicebuildermdl/servicebuildermdl.go @@ -6,6 +6,7 @@ package servicebuildermdl import ( "database/sql" + "net" "strings" "sync" "time" @@ -43,6 +44,45 @@ var once sync.Once var ruleCache map[string]conditions.Expr var mutex = &sync.Mutex{} +// get server ip address +var ( + serverIP = func() string { + ifaces, err := net.Interfaces() + if err != nil { + return "" + } + for _, iface := range ifaces { + if iface.Flags&net.FlagUp == 0 { + continue // interface down + } + if iface.Flags&net.FlagLoopback != 0 { + continue // loopback interface + } + addrs, err := iface.Addrs() + if err != nil { + return "" + } + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + if ip == nil || ip.IsLoopback() { + continue + } + if ip = ip.To4(); ip == nil { + continue // not an ipv4 address + } + return ip.String() + } + } + return "" + }() +) + func init() { ruleCache = make(map[string]conditions.Expr) globalConfig = make(map[string]GlobalConfigModel) @@ -144,6 +184,7 @@ func (ab *AbstractBusinessLogicHolder) GetDataResultset(key string) (*gjson.Resu return value, true } +// GetMQLRequestData - returns MQLRequestData func (ab *AbstractBusinessLogicHolder) GetMQLRequestData() (*gjson.Result, bool) { //check in map temp, found := ab.localServiceData[constantmdl.MQLRequestData] @@ -254,7 +295,7 @@ func (ab *AbstractBusinessLogicHolder) SetCustomData(key string, data interface{ ab.localServiceData[key] = data } -// SetMQLToken will set token in header +// SetErrorData will set error func (ab *AbstractBusinessLogicHolder) SetErrorData(data interface{}) { ab.ServiceError = data } @@ -270,6 +311,16 @@ func (ab *AbstractBusinessLogicHolder) SetFinalData(data interface{}) { ab.localServiceData["finaldata"] = data } +// GetClientIP will returns client ip address +func (ab *AbstractBusinessLogicHolder) GetClientIP() string { + return ab.pricipalObject.ClientIP +} + +// GetServerIP will returns server ip address +func (ab *AbstractBusinessLogicHolder) GetServerIP() string { + return serverIP +} + // SetErrorCode - SetErrorCode in service context func (ab *AbstractBusinessLogicHolder) SetErrorCode(code int) { ab.GlobalErrorCode = code diff --git a/sessionmanagermdl/sessionmanager.go b/sessionmanagermdl/sessionmanager.go index bcde051d05354001266697e755abac7ca43940e7..f023f97a69486f4c58c893e8d88803cbbb6889ab 100644 --- a/sessionmanagermdl/sessionmanager.go +++ b/sessionmanagermdl/sessionmanager.go @@ -1,7 +1,6 @@ package sessionmanagermdl import ( - "coresls/servers/coresls/app/models" "errors" "time" @@ -30,23 +29,13 @@ var store cachemdl.Cacher var ErrSessionNotFound = errors.New("SESSION_NOT_FOUND") var ErrInvalidDataType = errors.New("INVALID_DATA_Type") -// InitSessionManagerCache initializes the cache with provided configuration. Need to provide a cache type to use. -func InitSessionManagerCache(chacheType int) { - cacheConfig := cachemdl.CacheConfig{ - Type: chacheType, - RedisCache: &cachemdl.RedisCache{ - Addr: "", // empty takes default redis address 6379 - DB: models.RedisDBSessionManager, - Prefix: models.ProjectID, - }, - FastCache: &cachemdl.FastCacheHelper{ - CleanupTime: 60 * time.Minute, - MaxEntries: 1000, - Expiration: -1, - }, +// Init initializes session manager with provided cache. Subsequent calls will not have any effect after first initialization. +func Init(cache cachemdl.Cacher) { + if store != nil { + return } - store = cachemdl.GetCacheInstance(&cacheConfig) + store = cache } // NewEntry prepares the object required to store data in session. @@ -122,6 +111,11 @@ func Retrieve(key string) (Entry, error) { } } +// RetrieveAll returns all entries present in memory. **Not for production use. May add performance costs** +func RetrieveAll() map[string]interface{} { + return store.GetAll() +} + // RetrieveAndExtend returns the entry and extends the entry expiration by provided `SECONDS`, only if remaining time < extendBy. // If extendBy < 0, it is same as Retrieve function. func RetrieveAndExtend(key string, extendBy int64) (Entry, error) {