95 lines
3.7 KiB
Go
95 lines
3.7 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/gocql/gocql"
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/config/constants"
|
|
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/httperror"
|
|
)
|
|
|
|
func (mid *middleware) PostJWTProcessorMiddleware(fn http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
// Get our authorization information.
|
|
isAuthorized, ok := ctx.Value(constants.SessionIsAuthorized).(bool)
|
|
if ok && isAuthorized {
|
|
// CWE-391: Safe type assertion to prevent panic-based DoS
|
|
// OWASP A09:2021: Security Logging and Monitoring - Prevents service crashes
|
|
sessionID, ok := ctx.Value(constants.SessionID).(string)
|
|
if !ok {
|
|
mid.logger.Error("Invalid session ID type in context")
|
|
problem := httperror.NewInternalServerError("Invalid session context")
|
|
problem.WithInstance(r.URL.Path).
|
|
WithTraceID(httperror.ExtractRequestID(r))
|
|
httperror.RespondWithProblem(w, problem)
|
|
return
|
|
}
|
|
|
|
// Parse the user ID from the session ID (which is actually the user ID string from JWT)
|
|
userID, err := gocql.ParseUUID(sessionID)
|
|
if err != nil {
|
|
problem := httperror.NewUnauthorizedError("Invalid user ID in token")
|
|
problem.WithInstance(r.URL.Path).
|
|
WithTraceID(httperror.ExtractRequestID(r))
|
|
httperror.RespondWithProblem(w, problem)
|
|
return
|
|
}
|
|
|
|
// Lookup our user profile by ID or return 500 error.
|
|
user, err := mid.userGetByIDUseCase.Execute(ctx, userID)
|
|
if err != nil {
|
|
// Log the actual error for debugging but return generic message to client
|
|
mid.logger.Error("Failed to get user by ID",
|
|
zap.Error(err),
|
|
zap.String("user_id", userID.String()))
|
|
problem := httperror.NewInternalServerError("Unable to verify session")
|
|
problem.WithInstance(r.URL.Path).
|
|
WithTraceID(httperror.ExtractRequestID(r))
|
|
httperror.RespondWithProblem(w, problem)
|
|
return
|
|
}
|
|
|
|
// If no user was found then that means our session expired and the
|
|
// user needs to login or use the refresh token.
|
|
if user == nil {
|
|
problem := httperror.NewUnauthorizedError("Session expired")
|
|
problem.WithInstance(r.URL.Path).
|
|
WithTraceID(httperror.ExtractRequestID(r))
|
|
httperror.RespondWithProblem(w, problem)
|
|
return
|
|
}
|
|
|
|
// // If system administrator disabled the user account then we need
|
|
// // to generate a 403 error letting the user know their account has
|
|
// // been disabled and you cannot access the protected API endpoint.
|
|
// if user.State == 0 {
|
|
// http.Error(w, "Account disabled - please contact admin", http.StatusForbidden)
|
|
// return
|
|
// }
|
|
|
|
// Save our user information to the context.
|
|
// Save our user.
|
|
ctx = context.WithValue(ctx, constants.SessionUser, user)
|
|
|
|
// Save individual pieces of the user profile.
|
|
ctx = context.WithValue(ctx, constants.SessionID, sessionID)
|
|
ctx = context.WithValue(ctx, constants.SessionUserID, user.ID)
|
|
ctx = context.WithValue(ctx, constants.SessionUserRole, user.Role)
|
|
ctx = context.WithValue(ctx, constants.SessionUserName, user.Name)
|
|
ctx = context.WithValue(ctx, constants.SessionUserFirstName, user.FirstName)
|
|
ctx = context.WithValue(ctx, constants.SessionUserLastName, user.LastName)
|
|
ctx = context.WithValue(ctx, constants.SessionUserTimezone, user.Timezone)
|
|
// ctx = context.WithValue(ctx, constants.SessionUserStoreID, user.StoreID)
|
|
// ctx = context.WithValue(ctx, constants.SessionUserStoreName, user.StoreName)
|
|
// ctx = context.WithValue(ctx, constants.SessionUserStoreLevel, user.StoreLevel)
|
|
// ctx = context.WithValue(ctx, constants.SessionUserStoreTimezone, user.StoreTimezone)
|
|
}
|
|
|
|
fn(w, r.WithContext(ctx))
|
|
}
|
|
}
|