Initial commit: Open sourcing all of the Maple Open Technologies code.
This commit is contained in:
commit
755d54a99d
2010 changed files with 448675 additions and 0 deletions
281
native/desktop/maplefile/internal/service/auth/service.go
Normal file
281
native/desktop/maplefile/internal/service/auth/service.go
Normal file
|
|
@ -0,0 +1,281 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/maplefile/client"
|
||||
domainSession "codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/domain/session"
|
||||
"codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/usecase/session"
|
||||
"codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/utils"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
apiClient *client.Client
|
||||
createSessionUC *session.CreateUseCase
|
||||
getSessionUC *session.GetByIdUseCase
|
||||
deleteSessionUC *session.DeleteUseCase
|
||||
saveSessionUC *session.SaveUseCase
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// ProvideService creates the auth service for Wire
|
||||
func ProvideService(
|
||||
apiClient *client.Client,
|
||||
createSessionUC *session.CreateUseCase,
|
||||
getSessionUC *session.GetByIdUseCase,
|
||||
deleteSessionUC *session.DeleteUseCase,
|
||||
saveSessionUC *session.SaveUseCase,
|
||||
logger *zap.Logger,
|
||||
) *Service {
|
||||
svc := &Service{
|
||||
apiClient: apiClient,
|
||||
createSessionUC: createSessionUC,
|
||||
getSessionUC: getSessionUC,
|
||||
deleteSessionUC: deleteSessionUC,
|
||||
saveSessionUC: saveSessionUC,
|
||||
logger: logger.Named("auth-service"),
|
||||
}
|
||||
|
||||
// Set up token refresh callback to persist new tokens to session
|
||||
apiClient.OnTokenRefresh(func(accessToken, refreshToken, accessTokenExpiryDate string) {
|
||||
svc.handleTokenRefresh(accessToken, refreshToken, accessTokenExpiryDate)
|
||||
})
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
// handleTokenRefresh is called when the API client automatically refreshes the access token
|
||||
func (s *Service) handleTokenRefresh(accessToken, refreshToken, accessTokenExpiryDate string) {
|
||||
// Get the current session
|
||||
existingSession, err := s.getSessionUC.Execute()
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get session during token refresh callback", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
if existingSession == nil {
|
||||
s.logger.Warn("No session found during token refresh callback")
|
||||
return
|
||||
}
|
||||
|
||||
// Update the session with new tokens
|
||||
existingSession.AccessToken = accessToken
|
||||
existingSession.RefreshToken = refreshToken
|
||||
|
||||
// Parse the actual expiry date from the response instead of using hardcoded value
|
||||
if accessTokenExpiryDate != "" {
|
||||
expiryTime, parseErr := time.Parse(time.RFC3339, accessTokenExpiryDate)
|
||||
if parseErr != nil {
|
||||
s.logger.Warn("Failed to parse access token expiry date, using default 15m",
|
||||
zap.String("expiry_date", accessTokenExpiryDate),
|
||||
zap.Error(parseErr))
|
||||
existingSession.ExpiresAt = time.Now().Add(15 * time.Minute)
|
||||
} else {
|
||||
existingSession.ExpiresAt = expiryTime
|
||||
s.logger.Debug("Using actual token expiry from response",
|
||||
zap.Time("expiry_time", expiryTime))
|
||||
}
|
||||
} else {
|
||||
s.logger.Warn("No access token expiry date in refresh response, using default 15m")
|
||||
existingSession.ExpiresAt = time.Now().Add(15 * time.Minute)
|
||||
}
|
||||
|
||||
// Save updated session
|
||||
if err := s.saveSessionUC.Execute(existingSession); err != nil {
|
||||
s.logger.Error("Failed to save session after token refresh", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Info("Session updated with refreshed tokens", zap.String("email", utils.MaskEmail(existingSession.Email)))
|
||||
}
|
||||
|
||||
// RequestOTT requests a one-time token for login
|
||||
func (s *Service) RequestOTT(ctx context.Context, email string) error {
|
||||
_, err := s.apiClient.RequestOTT(ctx, email)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to request OTT", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
s.logger.Info("OTT requested successfully", zap.String("email", utils.MaskEmail(email)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyOTT verifies the one-time token and returns the encrypted challenge
|
||||
func (s *Service) VerifyOTT(ctx context.Context, email, ott string) (*client.VerifyOTTResponse, error) {
|
||||
resp, err := s.apiClient.VerifyOTT(ctx, email, ott)
|
||||
if err != nil {
|
||||
s.logger.Error("OTT verification failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
s.logger.Info("OTT verified successfully", zap.String("email", utils.MaskEmail(email)))
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// CompleteLogin completes the login process with OTT and challenge
|
||||
func (s *Service) CompleteLogin(ctx context.Context, input *client.CompleteLoginInput) (*client.LoginResponse, error) {
|
||||
// Complete login via API
|
||||
resp, err := s.apiClient.CompleteLogin(ctx, input)
|
||||
if err != nil {
|
||||
s.logger.Error("Login failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse expiration time from response
|
||||
var expiresIn time.Duration
|
||||
if resp.AccessTokenExpiryDate != "" {
|
||||
expiryTime, parseErr := time.Parse(time.RFC3339, resp.AccessTokenExpiryDate)
|
||||
if parseErr != nil {
|
||||
s.logger.Warn("Failed to parse access token expiry date, using default 15m",
|
||||
zap.String("expiry_date", resp.AccessTokenExpiryDate),
|
||||
zap.Error(parseErr))
|
||||
expiresIn = 15 * time.Minute // Default to 15 minutes (backend default)
|
||||
} else {
|
||||
expiresIn = time.Until(expiryTime)
|
||||
s.logger.Info("Parsed access token expiry",
|
||||
zap.Time("expiry_time", expiryTime),
|
||||
zap.Duration("expires_in", expiresIn))
|
||||
}
|
||||
} else {
|
||||
s.logger.Warn("No access token expiry date in response, using default 15m")
|
||||
expiresIn = 15 * time.Minute // Default to 15 minutes (backend default)
|
||||
}
|
||||
|
||||
// Use email as userID for now (can be improved later)
|
||||
userID := input.Email
|
||||
|
||||
// Save session locally via use case
|
||||
err = s.createSessionUC.Execute(
|
||||
userID,
|
||||
input.Email,
|
||||
resp.AccessToken,
|
||||
resp.RefreshToken,
|
||||
expiresIn,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to save session", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.logger.Info("User logged in successfully", zap.String("email", utils.MaskEmail(input.Email)))
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Logout removes the local session
|
||||
func (s *Service) Logout(ctx context.Context) error {
|
||||
// Delete local session
|
||||
err := s.deleteSessionUC.Execute()
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to delete session", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
s.logger.Info("User logged out successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCurrentSession retrieves the current user session
|
||||
func (s *Service) GetCurrentSession(ctx context.Context) (*domainSession.Session, error) {
|
||||
sess, err := s.getSessionUC.Execute()
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get session", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// UpdateSession updates the current session
|
||||
func (s *Service) UpdateSession(ctx context.Context, sess *domainSession.Session) error {
|
||||
return s.saveSessionUC.Execute(sess)
|
||||
}
|
||||
|
||||
// IsLoggedIn checks if a user is currently logged in
|
||||
func (s *Service) IsLoggedIn(ctx context.Context) (bool, error) {
|
||||
sess, err := s.getSessionUC.Execute()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if sess == nil {
|
||||
return false, nil
|
||||
}
|
||||
return sess.IsValid(), nil
|
||||
}
|
||||
|
||||
// RestoreSession restores tokens to the API client from a persisted session
|
||||
// This is used on app startup to resume a session from a previous run
|
||||
func (s *Service) RestoreSession(ctx context.Context, sess *domainSession.Session) error {
|
||||
if sess == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Restore tokens to API client
|
||||
s.apiClient.SetTokens(sess.AccessToken, sess.RefreshToken)
|
||||
s.logger.Info("Session restored to API client",
|
||||
zap.String("user_id", sess.UserID),
|
||||
zap.String("email", utils.MaskEmail(sess.Email)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Register creates a new user account
|
||||
func (s *Service) Register(ctx context.Context, input *client.RegisterInput) error {
|
||||
_, err := s.apiClient.Register(ctx, input)
|
||||
if err != nil {
|
||||
s.logger.Error("Registration failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
s.logger.Info("User registered successfully", zap.String("email", utils.MaskEmail(input.Email)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyEmail verifies the email with the verification code
|
||||
func (s *Service) VerifyEmail(ctx context.Context, input *client.VerifyEmailInput) error {
|
||||
_, err := s.apiClient.VerifyEmailCode(ctx, input)
|
||||
if err != nil {
|
||||
s.logger.Error("Email verification failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
s.logger.Info("Email verified successfully", zap.String("email", utils.MaskEmail(input.Email)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAPIClient returns the API client instance
|
||||
// This allows other parts of the application to make authenticated API calls
|
||||
func (s *Service) GetAPIClient() *client.Client {
|
||||
return s.apiClient
|
||||
}
|
||||
|
||||
// InitiateRecovery initiates the account recovery process
|
||||
func (s *Service) InitiateRecovery(ctx context.Context, email, method string) (*client.RecoveryInitiateResponse, error) {
|
||||
resp, err := s.apiClient.RecoveryInitiate(ctx, email, method)
|
||||
if err != nil {
|
||||
s.logger.Error("Recovery initiation failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
s.logger.Info("Recovery initiated successfully", zap.String("email", utils.MaskEmail(email)))
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// VerifyRecovery verifies the recovery challenge
|
||||
func (s *Service) VerifyRecovery(ctx context.Context, input *client.RecoveryVerifyInput) (*client.RecoveryVerifyResponse, error) {
|
||||
resp, err := s.apiClient.RecoveryVerify(ctx, input)
|
||||
if err != nil {
|
||||
s.logger.Error("Recovery verification failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
s.logger.Info("Recovery verification successful")
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// CompleteRecovery completes the account recovery and resets credentials
|
||||
func (s *Service) CompleteRecovery(ctx context.Context, input *client.RecoveryCompleteInput) (*client.RecoveryCompleteResponse, error) {
|
||||
resp, err := s.apiClient.RecoveryComplete(ctx, input)
|
||||
if err != nil {
|
||||
s.logger.Error("Recovery completion failed", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
s.logger.Info("Recovery completed successfully")
|
||||
return resp, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,199 @@
|
|||
package httpclient
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Service provides an HTTP client with proper timeouts.
|
||||
// This addresses OWASP security concern B1: using http.DefaultClient which has
|
||||
// no timeouts and can be vulnerable to slowloris attacks and resource exhaustion.
|
||||
//
|
||||
// Note: TLS/SSL is handled by Caddy reverse proxy in production (see OWASP report
|
||||
// A04-4.1 "Certificate Pinning Not Required" - BY DESIGN). This service focuses
|
||||
// on adding timeouts, not TLS configuration.
|
||||
//
|
||||
// For large file downloads, use DoDownloadNoTimeout() which relies on the request's
|
||||
// context for cancellation instead of a fixed timeout. This allows multi-gigabyte
|
||||
// files to download without timeout issues while still being cancellable.
|
||||
type Service struct {
|
||||
// client is the configured HTTP client for API requests
|
||||
client *http.Client
|
||||
|
||||
// downloadClient is a separate client for file downloads with longer timeouts
|
||||
downloadClient *http.Client
|
||||
|
||||
// noTimeoutClient is for large file downloads where context controls cancellation
|
||||
noTimeoutClient *http.Client
|
||||
}
|
||||
|
||||
// Config holds configuration options for the HTTP client service
|
||||
type Config struct {
|
||||
// RequestTimeout is the overall timeout for API requests (default: 30s)
|
||||
RequestTimeout time.Duration
|
||||
|
||||
// DownloadTimeout is the overall timeout for file downloads (default: 10m)
|
||||
DownloadTimeout time.Duration
|
||||
|
||||
// ConnectTimeout is the timeout for establishing connections (default: 10s)
|
||||
ConnectTimeout time.Duration
|
||||
|
||||
// TLSHandshakeTimeout is the timeout for TLS handshake (default: 10s)
|
||||
TLSHandshakeTimeout time.Duration
|
||||
|
||||
// IdleConnTimeout is how long idle connections stay in the pool (default: 90s)
|
||||
IdleConnTimeout time.Duration
|
||||
|
||||
// MaxIdleConns is the max number of idle connections (default: 100)
|
||||
MaxIdleConns int
|
||||
|
||||
// MaxIdleConnsPerHost is the max idle connections per host (default: 10)
|
||||
MaxIdleConnsPerHost int
|
||||
}
|
||||
|
||||
// DefaultConfig returns sensible default configuration values
|
||||
func DefaultConfig() Config {
|
||||
return Config{
|
||||
RequestTimeout: 30 * time.Second,
|
||||
DownloadTimeout: 10 * time.Minute,
|
||||
ConnectTimeout: 10 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
}
|
||||
}
|
||||
|
||||
// ProvideService creates a new HTTP client service with secure defaults
|
||||
func ProvideService() *Service {
|
||||
return NewService(DefaultConfig())
|
||||
}
|
||||
|
||||
// NewService creates a new HTTP client service with the given configuration
|
||||
func NewService(cfg Config) *Service {
|
||||
// Create transport with timeouts and connection pooling
|
||||
// Note: We don't set TLSClientConfig - Go's defaults are secure and
|
||||
// production uses Caddy for TLS termination anyway
|
||||
transport := &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: cfg.ConnectTimeout,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: cfg.TLSHandshakeTimeout,
|
||||
IdleConnTimeout: cfg.IdleConnTimeout,
|
||||
MaxIdleConns: cfg.MaxIdleConns,
|
||||
MaxIdleConnsPerHost: cfg.MaxIdleConnsPerHost,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
}
|
||||
|
||||
// Create the main client for API requests
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: cfg.RequestTimeout,
|
||||
}
|
||||
|
||||
// Create a separate transport for downloads with longer timeouts
|
||||
downloadTransport := &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: cfg.ConnectTimeout,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: cfg.TLSHandshakeTimeout,
|
||||
IdleConnTimeout: cfg.IdleConnTimeout,
|
||||
MaxIdleConns: cfg.MaxIdleConns,
|
||||
MaxIdleConnsPerHost: cfg.MaxIdleConnsPerHost,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
// Disable compression for downloads to avoid decompression overhead
|
||||
DisableCompression: true,
|
||||
}
|
||||
|
||||
// Create the download client with longer timeout
|
||||
downloadClient := &http.Client{
|
||||
Transport: downloadTransport,
|
||||
Timeout: cfg.DownloadTimeout,
|
||||
}
|
||||
|
||||
// Create a no-timeout transport for large file downloads
|
||||
// This client has no overall timeout - cancellation is controlled via request context
|
||||
// Connection and TLS handshake still have timeouts to prevent hanging on initial connect
|
||||
noTimeoutTransport := &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: cfg.ConnectTimeout,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: cfg.TLSHandshakeTimeout,
|
||||
IdleConnTimeout: cfg.IdleConnTimeout,
|
||||
MaxIdleConns: cfg.MaxIdleConns,
|
||||
MaxIdleConnsPerHost: cfg.MaxIdleConnsPerHost,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ForceAttemptHTTP2: true,
|
||||
DisableCompression: true,
|
||||
}
|
||||
|
||||
// No timeout - relies on context cancellation for large file downloads
|
||||
noTimeoutClient := &http.Client{
|
||||
Transport: noTimeoutTransport,
|
||||
Timeout: 0, // No timeout
|
||||
}
|
||||
|
||||
return &Service{
|
||||
client: client,
|
||||
downloadClient: downloadClient,
|
||||
noTimeoutClient: noTimeoutClient,
|
||||
}
|
||||
}
|
||||
|
||||
// Client returns the HTTP client for API requests (30s timeout)
|
||||
func (s *Service) Client() *http.Client {
|
||||
return s.client
|
||||
}
|
||||
|
||||
// DownloadClient returns the HTTP client for file downloads (10m timeout)
|
||||
func (s *Service) DownloadClient() *http.Client {
|
||||
return s.downloadClient
|
||||
}
|
||||
|
||||
// Do executes an HTTP request using the API client
|
||||
func (s *Service) Do(req *http.Request) (*http.Response, error) {
|
||||
return s.client.Do(req)
|
||||
}
|
||||
|
||||
// DoDownload executes an HTTP request using the download client (longer timeout)
|
||||
func (s *Service) DoDownload(req *http.Request) (*http.Response, error) {
|
||||
return s.downloadClient.Do(req)
|
||||
}
|
||||
|
||||
// Get performs an HTTP GET request using the API client
|
||||
func (s *Service) Get(url string) (*http.Response, error) {
|
||||
return s.client.Get(url)
|
||||
}
|
||||
|
||||
// GetDownload performs an HTTP GET request using the download client (longer timeout)
|
||||
func (s *Service) GetDownload(url string) (*http.Response, error) {
|
||||
return s.downloadClient.Get(url)
|
||||
}
|
||||
|
||||
// DoLargeDownload executes an HTTP request for large file downloads.
|
||||
// This client has NO overall timeout - cancellation must be handled via the request's context.
|
||||
// Use this for multi-gigabyte files that may take hours to download.
|
||||
// The connection establishment and TLS handshake still have timeouts.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// ctx, cancel := context.WithCancel(context.Background())
|
||||
// defer cancel() // Call cancel() to abort the download
|
||||
// req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
// resp, err := httpClient.DoLargeDownload(req)
|
||||
func (s *Service) DoLargeDownload(req *http.Request) (*http.Response, error) {
|
||||
return s.noTimeoutClient.Do(req)
|
||||
}
|
||||
|
||||
// GetLargeDownload performs an HTTP GET request for large file downloads.
|
||||
// This client has NO overall timeout - the download can run indefinitely.
|
||||
// Use this for multi-gigabyte files. To cancel, use DoLargeDownload with a context.
|
||||
func (s *Service) GetLargeDownload(url string) (*http.Response, error) {
|
||||
return s.noTimeoutClient.Get(url)
|
||||
}
|
||||
|
|
@ -0,0 +1,263 @@
|
|||
package inputvalidation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Validation limits for input fields
|
||||
const (
|
||||
// Email limits
|
||||
MaxEmailLength = 254 // RFC 5321
|
||||
|
||||
// Name limits (collection names, file names, user names)
|
||||
MinNameLength = 1
|
||||
MaxNameLength = 255
|
||||
|
||||
// Display name limits
|
||||
MaxDisplayNameLength = 100
|
||||
|
||||
// Description limits
|
||||
MaxDescriptionLength = 1000
|
||||
|
||||
// UUID format (standard UUID v4)
|
||||
UUIDLength = 36
|
||||
|
||||
// OTT (One-Time Token) limits
|
||||
OTTLength = 8 // 8-digit code
|
||||
|
||||
// Password limits
|
||||
MinPasswordLength = 8
|
||||
MaxPasswordLength = 128
|
||||
)
|
||||
|
||||
// uuidRegex matches standard UUID format (8-4-4-4-12)
|
||||
var uuidRegex = regexp.MustCompile(`^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$`)
|
||||
|
||||
// ottRegex matches 8-digit OTT codes
|
||||
var ottRegex = regexp.MustCompile(`^[0-9]{8}$`)
|
||||
|
||||
// ValidateEmail validates an email address
|
||||
func ValidateEmail(email string) error {
|
||||
if email == "" {
|
||||
return fmt.Errorf("email is required")
|
||||
}
|
||||
|
||||
// Check length
|
||||
if len(email) > MaxEmailLength {
|
||||
return fmt.Errorf("email exceeds maximum length of %d characters", MaxEmailLength)
|
||||
}
|
||||
|
||||
// Use Go's mail package for RFC 5322 validation
|
||||
_, err := mail.ParseAddress(email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid email format")
|
||||
}
|
||||
|
||||
// Additional checks for security
|
||||
if strings.ContainsAny(email, "\x00\n\r") {
|
||||
return fmt.Errorf("email contains invalid characters")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateUUID validates a UUID string
|
||||
func ValidateUUID(id, fieldName string) error {
|
||||
if id == "" {
|
||||
return fmt.Errorf("%s is required", fieldName)
|
||||
}
|
||||
|
||||
if len(id) != UUIDLength {
|
||||
return fmt.Errorf("%s must be a valid UUID", fieldName)
|
||||
}
|
||||
|
||||
if !uuidRegex.MatchString(id) {
|
||||
return fmt.Errorf("%s must be a valid UUID format", fieldName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateName validates a name field (collection name, filename, etc.)
|
||||
func ValidateName(name, fieldName string) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("%s is required", fieldName)
|
||||
}
|
||||
|
||||
// Check length
|
||||
if len(name) > MaxNameLength {
|
||||
return fmt.Errorf("%s exceeds maximum length of %d characters", fieldName, MaxNameLength)
|
||||
}
|
||||
|
||||
// Check for valid UTF-8
|
||||
if !utf8.ValidString(name) {
|
||||
return fmt.Errorf("%s contains invalid characters", fieldName)
|
||||
}
|
||||
|
||||
// Check for control characters (except tab and newline which might be valid in descriptions)
|
||||
for _, r := range name {
|
||||
if r < 32 && r != '\t' && r != '\n' && r != '\r' {
|
||||
return fmt.Errorf("%s contains invalid control characters", fieldName)
|
||||
}
|
||||
// Also check for null byte and other dangerous characters
|
||||
if r == 0 {
|
||||
return fmt.Errorf("%s contains null characters", fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
// Check that it's not all whitespace
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return fmt.Errorf("%s cannot be empty or whitespace only", fieldName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateDisplayName validates a display name (first name, last name, etc.)
|
||||
func ValidateDisplayName(name, fieldName string) error {
|
||||
// Display names can be empty (optional fields)
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check length
|
||||
if len(name) > MaxDisplayNameLength {
|
||||
return fmt.Errorf("%s exceeds maximum length of %d characters", fieldName, MaxDisplayNameLength)
|
||||
}
|
||||
|
||||
// Check for valid UTF-8
|
||||
if !utf8.ValidString(name) {
|
||||
return fmt.Errorf("%s contains invalid characters", fieldName)
|
||||
}
|
||||
|
||||
// Check for control characters
|
||||
for _, r := range name {
|
||||
if r < 32 || !unicode.IsPrint(r) {
|
||||
if r != ' ' { // Allow spaces
|
||||
return fmt.Errorf("%s contains invalid characters", fieldName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateDescription validates a description field
|
||||
func ValidateDescription(desc string) error {
|
||||
// Descriptions can be empty (optional)
|
||||
if desc == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check length
|
||||
if len(desc) > MaxDescriptionLength {
|
||||
return fmt.Errorf("description exceeds maximum length of %d characters", MaxDescriptionLength)
|
||||
}
|
||||
|
||||
// Check for valid UTF-8
|
||||
if !utf8.ValidString(desc) {
|
||||
return fmt.Errorf("description contains invalid characters")
|
||||
}
|
||||
|
||||
// Check for null bytes
|
||||
if strings.ContainsRune(desc, 0) {
|
||||
return fmt.Errorf("description contains null characters")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateOTT validates a one-time token (8-digit code)
|
||||
func ValidateOTT(ott string) error {
|
||||
if ott == "" {
|
||||
return fmt.Errorf("verification code is required")
|
||||
}
|
||||
|
||||
// Trim whitespace (users might copy-paste with spaces)
|
||||
ott = strings.TrimSpace(ott)
|
||||
|
||||
if !ottRegex.MatchString(ott) {
|
||||
return fmt.Errorf("verification code must be an 8-digit number")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidatePassword validates a password
|
||||
func ValidatePassword(password string) error {
|
||||
if password == "" {
|
||||
return fmt.Errorf("password is required")
|
||||
}
|
||||
|
||||
if len(password) < MinPasswordLength {
|
||||
return fmt.Errorf("password must be at least %d characters", MinPasswordLength)
|
||||
}
|
||||
|
||||
if len(password) > MaxPasswordLength {
|
||||
return fmt.Errorf("password exceeds maximum length of %d characters", MaxPasswordLength)
|
||||
}
|
||||
|
||||
// Check for null bytes (could indicate injection attempt)
|
||||
if strings.ContainsRune(password, 0) {
|
||||
return fmt.Errorf("password contains invalid characters")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateCollectionID is a convenience function for collection ID validation
|
||||
func ValidateCollectionID(id string) error {
|
||||
return ValidateUUID(id, "collection ID")
|
||||
}
|
||||
|
||||
// ValidateFileID is a convenience function for file ID validation
|
||||
func ValidateFileID(id string) error {
|
||||
return ValidateUUID(id, "file ID")
|
||||
}
|
||||
|
||||
// ValidateTagID is a convenience function for tag ID validation
|
||||
func ValidateTagID(id string) error {
|
||||
return ValidateUUID(id, "tag ID")
|
||||
}
|
||||
|
||||
// ValidateCollectionName validates a collection name
|
||||
func ValidateCollectionName(name string) error {
|
||||
return ValidateName(name, "collection name")
|
||||
}
|
||||
|
||||
// ValidateFileName validates a file name
|
||||
func ValidateFileName(name string) error {
|
||||
if err := ValidateName(name, "filename"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Additional file-specific validations
|
||||
// Check for path traversal attempts
|
||||
if strings.Contains(name, "..") {
|
||||
return fmt.Errorf("filename cannot contain path traversal sequences")
|
||||
}
|
||||
|
||||
// Check for path separators
|
||||
if strings.ContainsAny(name, "/\\") {
|
||||
return fmt.Errorf("filename cannot contain path separators")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SanitizeString removes or replaces potentially dangerous characters
|
||||
// This is a defense-in-depth measure - validation should be done first
|
||||
func SanitizeString(s string) string {
|
||||
// Remove null bytes
|
||||
s = strings.ReplaceAll(s, "\x00", "")
|
||||
|
||||
// Trim excessive whitespace
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
return s
|
||||
}
|
||||
|
|
@ -0,0 +1,167 @@
|
|||
package inputvalidation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AllowedDownloadHosts lists the allowed hosts for presigned download URLs.
|
||||
// These are the only hosts from which the application will download files.
|
||||
var AllowedDownloadHosts = []string{
|
||||
// Production S3-compatible storage (Digital Ocean Spaces)
|
||||
".digitaloceanspaces.com",
|
||||
// AWS S3 (if used in future)
|
||||
".s3.amazonaws.com",
|
||||
".s3.us-east-1.amazonaws.com",
|
||||
".s3.us-west-2.amazonaws.com",
|
||||
".s3.eu-west-1.amazonaws.com",
|
||||
// MapleFile domains (if serving files directly)
|
||||
".maplefile.ca",
|
||||
// Local development
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
}
|
||||
|
||||
// ValidateDownloadURL validates a presigned download URL before use.
|
||||
// This prevents SSRF attacks by ensuring downloads only happen from trusted hosts.
|
||||
func ValidateDownloadURL(rawURL string) error {
|
||||
if rawURL == "" {
|
||||
return fmt.Errorf("download URL is required")
|
||||
}
|
||||
|
||||
// Parse the URL
|
||||
parsedURL, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL format: %w", err)
|
||||
}
|
||||
|
||||
// Validate scheme - must be HTTPS (except localhost for development)
|
||||
if parsedURL.Scheme != "https" && parsedURL.Scheme != "http" {
|
||||
return fmt.Errorf("URL must use HTTP or HTTPS scheme")
|
||||
}
|
||||
|
||||
// Get host without port
|
||||
host := parsedURL.Hostname()
|
||||
if host == "" {
|
||||
return fmt.Errorf("URL must have a valid host")
|
||||
}
|
||||
|
||||
// For HTTPS requirement - only allow HTTP for localhost/local IPs
|
||||
if parsedURL.Scheme == "http" {
|
||||
if !isLocalHost(host) {
|
||||
return fmt.Errorf("non-local URLs must use HTTPS")
|
||||
}
|
||||
}
|
||||
|
||||
// Check if host is in allowed list
|
||||
if !isAllowedHost(host) {
|
||||
return fmt.Errorf("download from host %q is not allowed", host)
|
||||
}
|
||||
|
||||
// Check for credentials in URL (security risk)
|
||||
if parsedURL.User != nil {
|
||||
return fmt.Errorf("URL must not contain credentials")
|
||||
}
|
||||
|
||||
// Check for suspicious path traversal in URL path
|
||||
if strings.Contains(parsedURL.Path, "..") {
|
||||
return fmt.Errorf("URL path contains invalid sequences")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isAllowedHost checks if a host is in the allowed download hosts list
|
||||
func isAllowedHost(host string) bool {
|
||||
host = strings.ToLower(host)
|
||||
|
||||
for _, allowed := range AllowedDownloadHosts {
|
||||
allowed = strings.ToLower(allowed)
|
||||
|
||||
// Exact match
|
||||
if host == allowed {
|
||||
return true
|
||||
}
|
||||
|
||||
// Suffix match for wildcard domains (e.g., ".digitaloceanspaces.com")
|
||||
if strings.HasPrefix(allowed, ".") && strings.HasSuffix(host, allowed) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle subdomains for non-wildcard entries
|
||||
if !strings.HasPrefix(allowed, ".") {
|
||||
if host == allowed || strings.HasSuffix(host, "."+allowed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isLocalHost checks if a host is localhost or a local IP address
|
||||
func isLocalHost(host string) bool {
|
||||
host = strings.ToLower(host)
|
||||
|
||||
// Check common localhost names
|
||||
if host == "localhost" || host == "127.0.0.1" || host == "::1" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if it's a local network IP
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for loopback
|
||||
if ip.IsLoopback() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for private network ranges (10.x.x.x, 192.168.x.x, 172.16-31.x.x)
|
||||
if ip.IsPrivate() {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ValidateAPIBaseURL validates a base URL for API requests
|
||||
func ValidateAPIBaseURL(rawURL string) error {
|
||||
if rawURL == "" {
|
||||
return fmt.Errorf("API URL is required")
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL format: %w", err)
|
||||
}
|
||||
|
||||
// Validate scheme
|
||||
if parsedURL.Scheme != "https" && parsedURL.Scheme != "http" {
|
||||
return fmt.Errorf("URL must use HTTP or HTTPS scheme")
|
||||
}
|
||||
|
||||
// Get host
|
||||
host := parsedURL.Hostname()
|
||||
if host == "" {
|
||||
return fmt.Errorf("URL must have a valid host")
|
||||
}
|
||||
|
||||
// For HTTPS requirement - only allow HTTP for localhost/local IPs
|
||||
if parsedURL.Scheme == "http" {
|
||||
if !isLocalHost(host) {
|
||||
return fmt.Errorf("non-local URLs must use HTTPS")
|
||||
}
|
||||
}
|
||||
|
||||
// Check for credentials in URL
|
||||
if parsedURL.User != nil {
|
||||
return fmt.Errorf("URL must not contain credentials")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
181
native/desktop/maplefile/internal/service/keycache/keycache.go
Normal file
181
native/desktop/maplefile/internal/service/keycache/keycache.go
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
// Package keycache provides secure in-memory caching of cryptographic keys during a session.
|
||||
// Keys are stored in memguard Enclaves (encrypted at rest in memory) and automatically
|
||||
// cleared when the application shuts down or the user logs out.
|
||||
package keycache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/utils"
|
||||
)
|
||||
|
||||
// Service manages cached cryptographic keys in secure memory
|
||||
type Service struct {
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
// Map of email -> Enclave containing master key
|
||||
// Enclave stores data encrypted in memory, must be opened to access
|
||||
masterKeys map[string]*memguard.Enclave
|
||||
}
|
||||
|
||||
// ProvideService creates a new key cache service (for Wire)
|
||||
func ProvideService(logger *zap.Logger) *Service {
|
||||
return &Service{
|
||||
logger: logger.Named("keycache"),
|
||||
masterKeys: make(map[string]*memguard.Enclave),
|
||||
}
|
||||
}
|
||||
|
||||
// StoreMasterKey stores a user's master key in an encrypted memory Enclave
|
||||
// The key will remain cached until cleared or the app exits
|
||||
func (s *Service) StoreMasterKey(email string, masterKey []byte) error {
|
||||
if email == "" {
|
||||
return fmt.Errorf("email is required")
|
||||
}
|
||||
|
||||
if len(masterKey) == 0 {
|
||||
return fmt.Errorf("master key is empty")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// If there's already a cached key, remove it first
|
||||
if existing, exists := s.masterKeys[email]; exists {
|
||||
// Enclaves are garbage collected when removed from map
|
||||
delete(s.masterKeys, email)
|
||||
s.logger.Debug("Replaced existing cached master key", zap.String("email", utils.MaskEmail(email)))
|
||||
_ = existing // Prevent unused variable warning
|
||||
}
|
||||
|
||||
// Create a LockedBuffer from the master key bytes first
|
||||
// This locks the memory pages and prevents swapping
|
||||
lockedBuf := memguard.NewBufferFromBytes(masterKey)
|
||||
|
||||
// Create an Enclave from the LockedBuffer
|
||||
// Enclave stores the data encrypted at rest in memory
|
||||
enclave := lockedBuf.Seal()
|
||||
|
||||
// The LockedBuffer is consumed by Seal(), so we don't need to Destroy() it
|
||||
|
||||
// Store the enclave
|
||||
s.masterKeys[email] = enclave
|
||||
|
||||
s.logger.Info("Master key cached securely in memory",
|
||||
zap.String("email", utils.MaskEmail(email)),
|
||||
zap.Int("size", len(masterKey)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMasterKey retrieves a cached master key for the given email
|
||||
// Returns the key bytes and a cleanup function that MUST be called when done
|
||||
// The cleanup function destroys the LockedBuffer to prevent memory leaks
|
||||
func (s *Service) GetMasterKey(email string) ([]byte, func(), error) {
|
||||
if email == "" {
|
||||
return nil, nil, fmt.Errorf("email is required")
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
enclave, exists := s.masterKeys[email]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, nil, fmt.Errorf("no cached master key found for email: %s", email)
|
||||
}
|
||||
|
||||
// Open the enclave to access the master key
|
||||
lockedBuf, err := enclave.Open()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to open enclave for reading: %w", err)
|
||||
}
|
||||
|
||||
// Get the bytes (caller will use these)
|
||||
masterKey := lockedBuf.Bytes()
|
||||
|
||||
// Return cleanup function that destroys the LockedBuffer
|
||||
cleanup := func() {
|
||||
lockedBuf.Destroy()
|
||||
}
|
||||
|
||||
s.logger.Debug("Retrieved cached master key from secure memory",
|
||||
zap.String("email", utils.MaskEmail(email)))
|
||||
|
||||
return masterKey, cleanup, nil
|
||||
}
|
||||
|
||||
// WithMasterKey provides a callback pattern for using a cached master key
|
||||
// The key is automatically cleaned up after the callback returns
|
||||
func (s *Service) WithMasterKey(email string, fn func([]byte) error) error {
|
||||
masterKey, cleanup, err := s.GetMasterKey(email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
return fn(masterKey)
|
||||
}
|
||||
|
||||
// HasMasterKey checks if a master key is cached for the given email
|
||||
func (s *Service) HasMasterKey(email string) bool {
|
||||
if email == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
_, exists := s.masterKeys[email]
|
||||
return exists
|
||||
}
|
||||
|
||||
// ClearMasterKey removes a cached master key for a specific user
|
||||
func (s *Service) ClearMasterKey(email string) error {
|
||||
if email == "" {
|
||||
return fmt.Errorf("email is required")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if enclave, exists := s.masterKeys[email]; exists {
|
||||
delete(s.masterKeys, email)
|
||||
s.logger.Info("Cleared cached master key from secure memory",
|
||||
zap.String("email", utils.MaskEmail(email)))
|
||||
_ = enclave // Enclave will be garbage collected
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("no cached master key found for email: %s", email)
|
||||
}
|
||||
|
||||
// ClearAll removes all cached master keys
|
||||
// This should be called on logout or application shutdown
|
||||
func (s *Service) ClearAll() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
count := len(s.masterKeys)
|
||||
if count == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Clear all enclaves
|
||||
for email := range s.masterKeys {
|
||||
delete(s.masterKeys, email)
|
||||
}
|
||||
|
||||
s.logger.Info("Cleared all cached master keys from secure memory",
|
||||
zap.Int("count", count))
|
||||
}
|
||||
|
||||
// Cleanup performs cleanup operations when the service is shutting down
|
||||
// This is called by the application shutdown handler
|
||||
func (s *Service) Cleanup() {
|
||||
s.logger.Info("Cleaning up key cache service")
|
||||
s.ClearAll()
|
||||
}
|
||||
|
|
@ -0,0 +1,180 @@
|
|||
package passwordstore
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/utils"
|
||||
)
|
||||
|
||||
// Service manages password storage in secure RAM
|
||||
type Service struct {
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
|
||||
// RAM storage (memguard enclaves) - email -> encrypted password
|
||||
memoryStore map[string]*memguard.Enclave
|
||||
}
|
||||
|
||||
// New creates a new password storage service
|
||||
func New(logger *zap.Logger) *Service {
|
||||
// Initialize memguard
|
||||
memguard.CatchInterrupt()
|
||||
|
||||
return &Service{
|
||||
logger: logger,
|
||||
memoryStore: make(map[string]*memguard.Enclave),
|
||||
}
|
||||
}
|
||||
|
||||
// StorePassword stores password in secure RAM (memguard).
|
||||
// SECURITY NOTE: This method accepts a string for API compatibility with JSON inputs.
|
||||
// The string is immediately converted to []byte and the byte slice is zeroed after
|
||||
// creating the secure enclave. For maximum security, use StorePasswordBytes when
|
||||
// you have direct access to []byte data.
|
||||
func (s *Service) StorePassword(email, password string) error {
|
||||
// Convert string to byte slice for secure handling
|
||||
passwordBytes := []byte(password)
|
||||
|
||||
// Store using the secure byte-based method
|
||||
err := s.StorePasswordBytes(email, passwordBytes)
|
||||
|
||||
// Zero the byte slice after use (defense in depth)
|
||||
// Note: The original string cannot be zeroed in Go, but we minimize exposure
|
||||
// by zeroing the byte slice copy as soon as possible
|
||||
zeroBytes(passwordBytes)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// StorePasswordBytes stores password from []byte in secure RAM (memguard).
|
||||
// This is the preferred method when you have direct access to password bytes,
|
||||
// as it allows the caller to zero the source bytes after this call returns.
|
||||
// The provided byte slice is copied into secure memory and can be safely zeroed
|
||||
// by the caller after this method returns.
|
||||
func (s *Service) StorePasswordBytes(email string, password []byte) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Remove existing enclave if any (will be garbage collected)
|
||||
if _, exists := s.memoryStore[email]; exists {
|
||||
delete(s.memoryStore, email)
|
||||
s.logger.Debug("Replaced existing password enclave",
|
||||
zap.String("email", utils.MaskEmail(email)))
|
||||
}
|
||||
|
||||
// Create new secure enclave (memguard copies the data into protected memory)
|
||||
enclave := memguard.NewEnclave(password)
|
||||
s.memoryStore[email] = enclave
|
||||
|
||||
s.logger.Debug("Password stored in secure RAM",
|
||||
zap.String("email", utils.MaskEmail(email)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPassword retrieves password from secure RAM as a string.
|
||||
// SECURITY NOTE: The returned string cannot be zeroed in Go. For operations
|
||||
// that can work with []byte, use GetPasswordBytes instead.
|
||||
func (s *Service) GetPassword(email string) (string, error) {
|
||||
passwordBytes, err := s.GetPasswordBytes(email)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Convert to string (unfortunately creates a copy that can't be zeroed)
|
||||
password := string(passwordBytes)
|
||||
|
||||
// Zero the byte slice
|
||||
zeroBytes(passwordBytes)
|
||||
|
||||
return password, nil
|
||||
}
|
||||
|
||||
// GetPasswordBytes retrieves password from secure RAM as []byte.
|
||||
// The caller SHOULD zero the returned byte slice after use by calling
|
||||
// zeroBytes or similar. This is the preferred method for security-sensitive
|
||||
// operations.
|
||||
func (s *Service) GetPasswordBytes(email string) ([]byte, error) {
|
||||
s.mu.RLock()
|
||||
enclave, exists := s.memoryStore[email]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("no password stored for %s", email)
|
||||
}
|
||||
|
||||
// Open enclave to read password
|
||||
lockedBuffer, err := enclave.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open password enclave: %w", err)
|
||||
}
|
||||
defer lockedBuffer.Destroy()
|
||||
|
||||
// Copy the password bytes (memguard buffer will be destroyed after defer)
|
||||
passwordBytes := make([]byte, len(lockedBuffer.Bytes()))
|
||||
copy(passwordBytes, lockedBuffer.Bytes())
|
||||
|
||||
return passwordBytes, nil
|
||||
}
|
||||
|
||||
// HasPassword checks if password is stored for given email
|
||||
func (s *Service) HasPassword(email string) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
_, exists := s.memoryStore[email]
|
||||
return exists
|
||||
}
|
||||
|
||||
// ClearPassword removes password from RAM (logout)
|
||||
func (s *Service) ClearPassword(email string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.memoryStore[email]; exists {
|
||||
delete(s.memoryStore, email)
|
||||
s.logger.Debug("Password cleared from RAM",
|
||||
zap.String("email", utils.MaskEmail(email)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cleanup destroys all secure memory (called on shutdown)
|
||||
func (s *Service) Cleanup() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// Log cleanup for each stored password
|
||||
for email := range s.memoryStore {
|
||||
s.logger.Debug("Clearing password enclave on shutdown",
|
||||
zap.String("email", utils.MaskEmail(email)))
|
||||
}
|
||||
|
||||
// Clear the map (enclaves will be garbage collected)
|
||||
s.memoryStore = make(map[string]*memguard.Enclave)
|
||||
|
||||
// Purge all memguard secure memory
|
||||
memguard.Purge()
|
||||
|
||||
s.logger.Debug("Password store cleanup complete - all secure memory purged")
|
||||
}
|
||||
|
||||
// zeroBytes overwrites a byte slice with zeros to clear sensitive data from memory.
|
||||
// This is a defense-in-depth measure - while Go's GC may still have copies,
|
||||
// this reduces the window of exposure.
|
||||
func zeroBytes(b []byte) {
|
||||
for i := range b {
|
||||
b[i] = 0
|
||||
}
|
||||
}
|
||||
|
||||
// ZeroBytes is exported for callers who receive password bytes and need to clear them.
|
||||
// Use this after you're done with password bytes returned by GetPasswordBytes.
|
||||
func ZeroBytes(b []byte) {
|
||||
zeroBytes(b)
|
||||
}
|
||||
|
|
@ -0,0 +1,10 @@
|
|||
package passwordstore
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ProvideService creates the password storage service
|
||||
func ProvideService(logger *zap.Logger) *Service {
|
||||
return New(logger.Named("password-store"))
|
||||
}
|
||||
|
|
@ -0,0 +1,260 @@
|
|||
// Package ratelimiter provides client-side rate limiting for sensitive operations.
|
||||
//
|
||||
// Security Note: This is a defense-in-depth measure. The backend MUST also implement
|
||||
// rate limiting as the authoritative control. Client-side rate limiting provides:
|
||||
// - Protection against accidental rapid requests (e.g., user double-clicking)
|
||||
// - Reduced load on backend during legitimate high-frequency usage
|
||||
// - Better UX by failing fast with clear error messages
|
||||
// - Deterrent against simple automated attacks (though not a security boundary)
|
||||
//
|
||||
// This does NOT replace server-side rate limiting, which remains the security control.
|
||||
package ratelimiter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Operation represents a rate-limited operation type
|
||||
type Operation string
|
||||
|
||||
const (
|
||||
// OpRequestOTT is the operation for requesting a one-time token
|
||||
OpRequestOTT Operation = "request_ott"
|
||||
// OpVerifyOTT is the operation for verifying a one-time token
|
||||
OpVerifyOTT Operation = "verify_ott"
|
||||
// OpCompleteLogin is the operation for completing login
|
||||
OpCompleteLogin Operation = "complete_login"
|
||||
// OpRegister is the operation for user registration
|
||||
OpRegister Operation = "register"
|
||||
// OpVerifyEmail is the operation for email verification
|
||||
OpVerifyEmail Operation = "verify_email"
|
||||
)
|
||||
|
||||
// RateLimitError is returned when an operation is rate limited
|
||||
type RateLimitError struct {
|
||||
Operation Operation
|
||||
RetryAfter time.Duration
|
||||
AttemptsMade int
|
||||
MaxAttempts int
|
||||
}
|
||||
|
||||
func (e *RateLimitError) Error() string {
|
||||
return fmt.Sprintf(
|
||||
"rate limited: %s operation exceeded %d attempts, retry after %v",
|
||||
e.Operation, e.MaxAttempts, e.RetryAfter.Round(time.Second),
|
||||
)
|
||||
}
|
||||
|
||||
// operationLimit defines the rate limit configuration for an operation
|
||||
type operationLimit struct {
|
||||
maxAttempts int // Maximum attempts allowed in the window
|
||||
window time.Duration // Time window for the limit
|
||||
cooldown time.Duration // Cooldown period after hitting the limit
|
||||
}
|
||||
|
||||
// operationState tracks the current state of rate limiting for an operation
|
||||
type operationState struct {
|
||||
attempts int // Current attempt count
|
||||
windowStart time.Time // When the current window started
|
||||
lockedUntil time.Time // If rate limited, when the cooldown ends
|
||||
}
|
||||
|
||||
// Service provides rate limiting functionality
|
||||
type Service struct {
|
||||
mu sync.Mutex
|
||||
limits map[Operation]operationLimit
|
||||
state map[string]*operationState // key: operation + identifier (e.g., email)
|
||||
}
|
||||
|
||||
// New creates a new rate limiter service with default limits.
|
||||
//
|
||||
// Default limits are designed to:
|
||||
// - Allow normal user behavior (typos, retries)
|
||||
// - Prevent rapid automated attempts
|
||||
// - Provide reasonable cooldown periods
|
||||
func New() *Service {
|
||||
return &Service{
|
||||
limits: map[Operation]operationLimit{
|
||||
// OTT request: 3 attempts per 60 seconds, 2 minute cooldown
|
||||
// Rationale: Users might request OTT multiple times if email is slow
|
||||
OpRequestOTT: {
|
||||
maxAttempts: 3,
|
||||
window: 60 * time.Second,
|
||||
cooldown: 2 * time.Minute,
|
||||
},
|
||||
// OTT verification: 5 attempts per 60 seconds, 1 minute cooldown
|
||||
// Rationale: Users might mistype the 8-digit code
|
||||
OpVerifyOTT: {
|
||||
maxAttempts: 5,
|
||||
window: 60 * time.Second,
|
||||
cooldown: 1 * time.Minute,
|
||||
},
|
||||
// Complete login: 5 attempts per 60 seconds, 1 minute cooldown
|
||||
// Rationale: Password decryption might fail due to typos
|
||||
OpCompleteLogin: {
|
||||
maxAttempts: 5,
|
||||
window: 60 * time.Second,
|
||||
cooldown: 1 * time.Minute,
|
||||
},
|
||||
// Registration: 3 attempts per 5 minutes, 5 minute cooldown
|
||||
// Rationale: Registration is a one-time operation, limit abuse
|
||||
OpRegister: {
|
||||
maxAttempts: 3,
|
||||
window: 5 * time.Minute,
|
||||
cooldown: 5 * time.Minute,
|
||||
},
|
||||
// Email verification: 5 attempts per 60 seconds, 1 minute cooldown
|
||||
// Rationale: Users might mistype the verification code
|
||||
OpVerifyEmail: {
|
||||
maxAttempts: 5,
|
||||
window: 60 * time.Second,
|
||||
cooldown: 1 * time.Minute,
|
||||
},
|
||||
},
|
||||
state: make(map[string]*operationState),
|
||||
}
|
||||
}
|
||||
|
||||
// ProvideService creates the rate limiter service for Wire dependency injection
|
||||
func ProvideService() *Service {
|
||||
return New()
|
||||
}
|
||||
|
||||
// Check verifies if an operation is allowed and records the attempt.
|
||||
// The identifier is typically the user's email address.
|
||||
// Returns nil if the operation is allowed, or a RateLimitError if rate limited.
|
||||
func (s *Service) Check(op Operation, identifier string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
limit, ok := s.limits[op]
|
||||
if !ok {
|
||||
// Unknown operation, allow by default (fail open for usability)
|
||||
return nil
|
||||
}
|
||||
|
||||
key := string(op) + ":" + identifier
|
||||
now := time.Now()
|
||||
|
||||
state, exists := s.state[key]
|
||||
if !exists {
|
||||
// First attempt for this operation+identifier
|
||||
s.state[key] = &operationState{
|
||||
attempts: 1,
|
||||
windowStart: now,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if currently in cooldown
|
||||
if now.Before(state.lockedUntil) {
|
||||
return &RateLimitError{
|
||||
Operation: op,
|
||||
RetryAfter: state.lockedUntil.Sub(now),
|
||||
AttemptsMade: state.attempts,
|
||||
MaxAttempts: limit.maxAttempts,
|
||||
}
|
||||
}
|
||||
|
||||
// Check if window has expired (reset if so)
|
||||
if now.Sub(state.windowStart) > limit.window {
|
||||
state.attempts = 1
|
||||
state.windowStart = now
|
||||
state.lockedUntil = time.Time{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Increment attempt count
|
||||
state.attempts++
|
||||
|
||||
// Check if limit exceeded
|
||||
if state.attempts > limit.maxAttempts {
|
||||
state.lockedUntil = now.Add(limit.cooldown)
|
||||
return &RateLimitError{
|
||||
Operation: op,
|
||||
RetryAfter: limit.cooldown,
|
||||
AttemptsMade: state.attempts,
|
||||
MaxAttempts: limit.maxAttempts,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset clears the rate limit state for a specific operation and identifier.
|
||||
// Call this after a successful operation where you want to allow fresh attempts
|
||||
// (e.g., after successful OTT verification). Do NOT call this for operations
|
||||
// where success shouldn't reset the limit (e.g., OTT request, registration).
|
||||
func (s *Service) Reset(op Operation, identifier string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
key := string(op) + ":" + identifier
|
||||
delete(s.state, key)
|
||||
}
|
||||
|
||||
// ResetAll clears all rate limit state for an identifier (e.g., after successful login).
|
||||
func (s *Service) ResetAll(identifier string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for op := range s.limits {
|
||||
key := string(op) + ":" + identifier
|
||||
delete(s.state, key)
|
||||
}
|
||||
}
|
||||
|
||||
// GetRemainingAttempts returns the number of remaining attempts for an operation.
|
||||
// Returns -1 if the operation is currently rate limited.
|
||||
func (s *Service) GetRemainingAttempts(op Operation, identifier string) int {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
limit, ok := s.limits[op]
|
||||
if !ok {
|
||||
return -1
|
||||
}
|
||||
|
||||
key := string(op) + ":" + identifier
|
||||
state, exists := s.state[key]
|
||||
if !exists {
|
||||
return limit.maxAttempts
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// In cooldown
|
||||
if now.Before(state.lockedUntil) {
|
||||
return -1
|
||||
}
|
||||
|
||||
// Window expired
|
||||
if now.Sub(state.windowStart) > limit.window {
|
||||
return limit.maxAttempts
|
||||
}
|
||||
|
||||
remaining := limit.maxAttempts - state.attempts
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
|
||||
// Cleanup removes expired state entries to prevent memory growth.
|
||||
// This should be called periodically (e.g., every hour).
|
||||
func (s *Service) Cleanup() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
maxAge := 24 * time.Hour // Remove entries older than 24 hours
|
||||
|
||||
for key, state := range s.state {
|
||||
// Remove if window started more than maxAge ago and not in cooldown
|
||||
if now.Sub(state.windowStart) > maxAge && now.After(state.lockedUntil) {
|
||||
delete(s.state, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
512
native/desktop/maplefile/internal/service/search/search.go
Normal file
512
native/desktop/maplefile/internal/service/search/search.go
Normal file
|
|
@ -0,0 +1,512 @@
|
|||
// Package search provides full-text search functionality using Bleve.
|
||||
//
|
||||
// This package implements a local full-text search index for files and collections
|
||||
// using the Bleve search library (https://blevesearch.com/). The search index is
|
||||
// stored per-user in their local application data directory.
|
||||
//
|
||||
// Key features:
|
||||
// - Case-insensitive substring matching (e.g., "mesh" matches "meshtastic")
|
||||
// - Support for Bleve query syntax (+, -, "", *, ?)
|
||||
// - Deduplication of search results by document ID
|
||||
// - Batch indexing for efficient rebuilds
|
||||
// - User-isolated indexes (each user has their own search index)
|
||||
//
|
||||
// Location: monorepo/native/desktop/maplefile/internal/service/search/search.go
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/blevesearch/bleve/v2"
|
||||
"github.com/blevesearch/bleve/v2/mapping"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/config"
|
||||
)
|
||||
|
||||
// SearchService provides full-text search capabilities
|
||||
type SearchService interface {
|
||||
// Initialize opens or creates the search index for the specified user email
|
||||
Initialize(ctx context.Context, userEmail string) error
|
||||
|
||||
// Close closes the search index
|
||||
Close() error
|
||||
|
||||
// IndexFile adds or updates a file in the search index
|
||||
IndexFile(file *FileDocument) error
|
||||
|
||||
// IndexCollection adds or updates a collection in the search index
|
||||
IndexCollection(collection *CollectionDocument) error
|
||||
|
||||
// DeleteFile removes a file from the search index
|
||||
DeleteFile(fileID string) error
|
||||
|
||||
// DeleteCollection removes a collection from the search index
|
||||
DeleteCollection(collectionID string) error
|
||||
|
||||
// Search performs a full-text search
|
||||
Search(query string, limit int) (*SearchResult, error)
|
||||
|
||||
// RebuildIndex rebuilds the entire search index from scratch
|
||||
RebuildIndex(userEmail string, files []*FileDocument, collections []*CollectionDocument) error
|
||||
|
||||
// GetIndexSize returns the size of the search index in bytes
|
||||
GetIndexSize() (int64, error)
|
||||
|
||||
// GetDocumentCount returns the number of documents in the index
|
||||
GetDocumentCount() (uint64, error)
|
||||
}
|
||||
|
||||
// FileDocument represents a file document in the search index
|
||||
type FileDocument struct {
|
||||
ID string `json:"id"`
|
||||
Filename string `json:"filename"`
|
||||
Description string `json:"description"`
|
||||
CollectionID string `json:"collection_id"`
|
||||
CollectionName string `json:"collection_name"` // Denormalized for search
|
||||
Tags []string `json:"tags"`
|
||||
Size int64 `json:"size"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Type string `json:"type"` // "file"
|
||||
}
|
||||
|
||||
// CollectionDocument represents a collection document in the search index
|
||||
type CollectionDocument struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Tags []string `json:"tags"`
|
||||
FileCount int `json:"file_count"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Type string `json:"type"` // "collection"
|
||||
}
|
||||
|
||||
// SearchResult contains the search results
|
||||
type SearchResult struct {
|
||||
Files []*FileDocument `json:"files"`
|
||||
Collections []*CollectionDocument `json:"collections"`
|
||||
TotalFiles int `json:"total_files"`
|
||||
TotalCollections int `json:"total_collections"`
|
||||
TotalHits uint64 `json:"total_hits"`
|
||||
MaxScore float64 `json:"max_score"`
|
||||
Took time.Duration `json:"took"`
|
||||
Query string `json:"query"`
|
||||
}
|
||||
|
||||
// searchService implements SearchService
|
||||
type searchService struct {
|
||||
index bleve.Index
|
||||
configService config.ConfigService
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// New creates a new search service
|
||||
func New(configService config.ConfigService, logger *zap.Logger) SearchService {
|
||||
return &searchService{
|
||||
configService: configService,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize opens or creates the search index for the specified user
|
||||
func (s *searchService) Initialize(ctx context.Context, userEmail string) error {
|
||||
if userEmail == "" {
|
||||
return fmt.Errorf("user email is required")
|
||||
}
|
||||
|
||||
// Get search index path
|
||||
indexPath, err := s.configService.GetUserSearchIndexDir(ctx, userEmail)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get search index path: %w", err)
|
||||
}
|
||||
|
||||
if indexPath == "" {
|
||||
return fmt.Errorf("search index path is empty")
|
||||
}
|
||||
|
||||
s.logger.Info("Initializing search index", zap.String("path", indexPath))
|
||||
|
||||
// Try to open existing index
|
||||
index, err := bleve.Open(indexPath)
|
||||
if err == bleve.ErrorIndexPathDoesNotExist {
|
||||
// Create new index
|
||||
s.logger.Info("Creating new search index")
|
||||
indexMapping := buildIndexMapping()
|
||||
index, err = bleve.New(indexPath, indexMapping)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create search index: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to open search index: %w", err)
|
||||
}
|
||||
|
||||
s.index = index
|
||||
s.logger.Info("Search index initialized successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the search index
|
||||
func (s *searchService) Close() error {
|
||||
if s.index != nil {
|
||||
err := s.index.Close()
|
||||
s.index = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IndexFile adds or updates a file in the search index
|
||||
func (s *searchService) IndexFile(file *FileDocument) error {
|
||||
if s.index == nil {
|
||||
return fmt.Errorf("search index not initialized")
|
||||
}
|
||||
|
||||
file.Type = "file"
|
||||
return s.index.Index(file.ID, file)
|
||||
}
|
||||
|
||||
// IndexCollection adds or updates a collection in the search index
|
||||
func (s *searchService) IndexCollection(collection *CollectionDocument) error {
|
||||
if s.index == nil {
|
||||
return fmt.Errorf("search index not initialized")
|
||||
}
|
||||
|
||||
collection.Type = "collection"
|
||||
return s.index.Index(collection.ID, collection)
|
||||
}
|
||||
|
||||
// DeleteFile removes a file from the search index
|
||||
func (s *searchService) DeleteFile(fileID string) error {
|
||||
if s.index == nil {
|
||||
return fmt.Errorf("search index not initialized")
|
||||
}
|
||||
|
||||
return s.index.Delete(fileID)
|
||||
}
|
||||
|
||||
// DeleteCollection removes a collection from the search index
|
||||
func (s *searchService) DeleteCollection(collectionID string) error {
|
||||
if s.index == nil {
|
||||
return fmt.Errorf("search index not initialized")
|
||||
}
|
||||
|
||||
return s.index.Delete(collectionID)
|
||||
}
|
||||
|
||||
// Search performs a full-text search across files and collections.
|
||||
//
|
||||
// The search supports:
|
||||
// - Simple queries: automatically wrapped with wildcards for substring matching
|
||||
// - Advanced queries: use Bleve query syntax directly (+, -, "", *, ?)
|
||||
//
|
||||
// Examples:
|
||||
// - "mesh" → matches "meshtastic", "mesh_config", etc.
|
||||
// - "\"exact phrase\"" → matches exact phrase only
|
||||
// - "+required -excluded" → must contain "required", must not contain "excluded"
|
||||
func (s *searchService) Search(query string, limit int) (*SearchResult, error) {
|
||||
if s.index == nil {
|
||||
return nil, fmt.Errorf("search index not initialized")
|
||||
}
|
||||
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 50
|
||||
}
|
||||
|
||||
// Convert to lowercase for case-insensitive search
|
||||
searchQueryStr := strings.ToLower(query)
|
||||
|
||||
// For simple queries (no operators), wrap with wildcards to enable substring matching.
|
||||
// This allows "mesh" to match "meshtastic_antenna.png".
|
||||
// If the user provides operators or wildcards, use their query as-is.
|
||||
if !strings.Contains(searchQueryStr, "*") && !strings.Contains(searchQueryStr, "?") &&
|
||||
!strings.Contains(searchQueryStr, "+") && !strings.Contains(searchQueryStr, "-") &&
|
||||
!strings.Contains(searchQueryStr, "\"") {
|
||||
searchQueryStr = "*" + searchQueryStr + "*"
|
||||
}
|
||||
|
||||
searchQuery := bleve.NewQueryStringQuery(searchQueryStr)
|
||||
searchRequest := bleve.NewSearchRequest(searchQuery)
|
||||
searchRequest.Size = limit
|
||||
searchRequest.Fields = []string{"*"}
|
||||
searchRequest.Highlight = bleve.NewHighlight()
|
||||
|
||||
// Execute search
|
||||
searchResults, err := s.index.Search(searchRequest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse results with deduplication
|
||||
result := &SearchResult{
|
||||
Files: make([]*FileDocument, 0),
|
||||
Collections: make([]*CollectionDocument, 0),
|
||||
TotalHits: searchResults.Total,
|
||||
MaxScore: searchResults.MaxScore,
|
||||
Took: searchResults.Took,
|
||||
Query: query,
|
||||
}
|
||||
|
||||
// Use maps to deduplicate by ID
|
||||
seenFileIDs := make(map[string]bool)
|
||||
seenCollectionIDs := make(map[string]bool)
|
||||
|
||||
for _, hit := range searchResults.Hits {
|
||||
docType, ok := hit.Fields["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if docType == "file" {
|
||||
// Skip if we've already seen this file ID
|
||||
if seenFileIDs[hit.ID] {
|
||||
s.logger.Warn("Duplicate file in search results", zap.String("id", hit.ID))
|
||||
continue
|
||||
}
|
||||
seenFileIDs[hit.ID] = true
|
||||
|
||||
file := &FileDocument{
|
||||
ID: hit.ID,
|
||||
Filename: getStringField(hit.Fields, "filename"),
|
||||
Description: getStringField(hit.Fields, "description"),
|
||||
CollectionID: getStringField(hit.Fields, "collection_id"),
|
||||
CollectionName: getStringField(hit.Fields, "collection_name"),
|
||||
Tags: getStringArrayField(hit.Fields, "tags"),
|
||||
Size: getInt64Field(hit.Fields, "size"),
|
||||
}
|
||||
if createdAt, ok := hit.Fields["created_at"].(string); ok {
|
||||
file.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
result.Files = append(result.Files, file)
|
||||
} else if docType == "collection" {
|
||||
// Skip if we've already seen this collection ID
|
||||
if seenCollectionIDs[hit.ID] {
|
||||
s.logger.Warn("Duplicate collection in search results", zap.String("id", hit.ID))
|
||||
continue
|
||||
}
|
||||
seenCollectionIDs[hit.ID] = true
|
||||
|
||||
collection := &CollectionDocument{
|
||||
ID: hit.ID,
|
||||
Name: getStringField(hit.Fields, "name"),
|
||||
Description: getStringField(hit.Fields, "description"),
|
||||
Tags: getStringArrayField(hit.Fields, "tags"),
|
||||
FileCount: getIntField(hit.Fields, "file_count"),
|
||||
}
|
||||
if createdAt, ok := hit.Fields["created_at"].(string); ok {
|
||||
collection.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
result.Collections = append(result.Collections, collection)
|
||||
}
|
||||
}
|
||||
|
||||
result.TotalFiles = len(result.Files)
|
||||
result.TotalCollections = len(result.Collections)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// RebuildIndex rebuilds the entire search index from scratch.
|
||||
//
|
||||
// This method:
|
||||
// 1. Closes the existing index (if any)
|
||||
// 2. Deletes the index directory completely
|
||||
// 3. Creates a fresh new index
|
||||
// 4. Batch-indexes all provided files and collections
|
||||
//
|
||||
// This approach ensures no stale or duplicate documents remain in the index.
|
||||
// The userEmail is required to locate the user-specific index directory.
|
||||
func (s *searchService) RebuildIndex(userEmail string, files []*FileDocument, collections []*CollectionDocument) error {
|
||||
s.logger.Info("Rebuilding search index from scratch",
|
||||
zap.Int("files", len(files)),
|
||||
zap.Int("collections", len(collections)))
|
||||
|
||||
if userEmail == "" {
|
||||
return fmt.Errorf("user email is required for rebuild")
|
||||
}
|
||||
|
||||
// Close the current index
|
||||
if s.index != nil {
|
||||
s.logger.Info("Closing current index before rebuild")
|
||||
if err := s.index.Close(); err != nil {
|
||||
s.logger.Warn("Error closing index before rebuild", zap.Error(err))
|
||||
}
|
||||
s.index = nil
|
||||
}
|
||||
|
||||
// Get the index path from config
|
||||
ctx := context.Background()
|
||||
indexPath, err := s.configService.GetUserSearchIndexDir(ctx, userEmail)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get search index path: %w", err)
|
||||
}
|
||||
|
||||
// Delete the existing index directory
|
||||
s.logger.Info("Deleting existing index", zap.String("path", indexPath))
|
||||
// We don't check for error here because the directory might not exist
|
||||
// and that's okay - we're about to create it
|
||||
os.RemoveAll(indexPath)
|
||||
|
||||
// Create a fresh index
|
||||
s.logger.Info("Creating fresh index", zap.String("path", indexPath))
|
||||
indexMapping := buildIndexMapping()
|
||||
index, err := bleve.New(indexPath, indexMapping)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create fresh index: %w", err)
|
||||
}
|
||||
|
||||
s.index = index
|
||||
|
||||
// Now index all files and collections in a batch
|
||||
batch := s.index.NewBatch()
|
||||
|
||||
// Index all files
|
||||
for _, file := range files {
|
||||
file.Type = "file"
|
||||
if err := batch.Index(file.ID, file); err != nil {
|
||||
s.logger.Error("Failed to batch index file", zap.String("id", file.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Index all collections
|
||||
for _, collection := range collections {
|
||||
collection.Type = "collection"
|
||||
if err := batch.Index(collection.ID, collection); err != nil {
|
||||
s.logger.Error("Failed to batch index collection", zap.String("id", collection.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Execute batch
|
||||
if err := s.index.Batch(batch); err != nil {
|
||||
return fmt.Errorf("failed to execute batch index: %w", err)
|
||||
}
|
||||
|
||||
finalCount, _ := s.index.DocCount()
|
||||
s.logger.Info("Search index rebuilt successfully",
|
||||
zap.Uint64("documents", finalCount),
|
||||
zap.Int("files_indexed", len(files)),
|
||||
zap.Int("collections_indexed", len(collections)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetIndexSize returns the size of the search index in bytes
|
||||
func (s *searchService) GetIndexSize() (int64, error) {
|
||||
if s.index == nil {
|
||||
return 0, fmt.Errorf("search index not initialized")
|
||||
}
|
||||
|
||||
// Note: Bleve doesn't provide a direct way to get index size
|
||||
// We return the document count as a proxy for size
|
||||
// For actual disk usage, you would need to walk the index directory
|
||||
count, err := s.index.DocCount()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(count), nil
|
||||
}
|
||||
|
||||
// GetDocumentCount returns the number of documents in the index
|
||||
func (s *searchService) GetDocumentCount() (uint64, error) {
|
||||
if s.index == nil {
|
||||
return 0, fmt.Errorf("search index not initialized")
|
||||
}
|
||||
|
||||
count, err := s.index.DocCount()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// buildIndexMapping creates the Bleve index mapping for files and collections.
|
||||
//
|
||||
// Field types:
|
||||
// - Text fields (filename, description, name, tags): Analyzed with "standard" analyzer
|
||||
// for good tokenization without stemming (better for substring matching)
|
||||
// - Keyword fields (collection_id, type): Exact match only, no analysis
|
||||
// - Numeric fields (size, file_count): Stored as numbers for range queries
|
||||
// - Date fields (created_at): Stored as datetime for date-based queries
|
||||
func buildIndexMapping() mapping.IndexMapping {
|
||||
indexMapping := bleve.NewIndexMapping()
|
||||
|
||||
// Use standard analyzer (not English) for better substring matching.
|
||||
// The English analyzer applies stemming which can interfere with partial matches.
|
||||
textFieldMapping := bleve.NewTextFieldMapping()
|
||||
textFieldMapping.Analyzer = "standard"
|
||||
|
||||
// Create keyword field mapping (no analysis)
|
||||
keywordFieldMapping := bleve.NewKeywordFieldMapping()
|
||||
|
||||
// Create numeric field mapping
|
||||
numericFieldMapping := bleve.NewNumericFieldMapping()
|
||||
|
||||
// Create datetime field mapping
|
||||
dateFieldMapping := bleve.NewDateTimeFieldMapping()
|
||||
|
||||
// File document mapping
|
||||
fileMapping := bleve.NewDocumentMapping()
|
||||
fileMapping.AddFieldMappingsAt("filename", textFieldMapping)
|
||||
fileMapping.AddFieldMappingsAt("description", textFieldMapping)
|
||||
fileMapping.AddFieldMappingsAt("collection_name", textFieldMapping)
|
||||
fileMapping.AddFieldMappingsAt("tags", textFieldMapping)
|
||||
fileMapping.AddFieldMappingsAt("collection_id", keywordFieldMapping)
|
||||
fileMapping.AddFieldMappingsAt("size", numericFieldMapping)
|
||||
fileMapping.AddFieldMappingsAt("created_at", dateFieldMapping)
|
||||
fileMapping.AddFieldMappingsAt("type", keywordFieldMapping)
|
||||
|
||||
// Collection document mapping
|
||||
collectionMapping := bleve.NewDocumentMapping()
|
||||
collectionMapping.AddFieldMappingsAt("name", textFieldMapping)
|
||||
collectionMapping.AddFieldMappingsAt("description", textFieldMapping)
|
||||
collectionMapping.AddFieldMappingsAt("tags", textFieldMapping)
|
||||
collectionMapping.AddFieldMappingsAt("file_count", numericFieldMapping)
|
||||
collectionMapping.AddFieldMappingsAt("created_at", dateFieldMapping)
|
||||
collectionMapping.AddFieldMappingsAt("type", keywordFieldMapping)
|
||||
|
||||
indexMapping.AddDocumentMapping("file", fileMapping)
|
||||
indexMapping.AddDocumentMapping("collection", collectionMapping)
|
||||
|
||||
return indexMapping
|
||||
}
|
||||
|
||||
// Helper functions to extract fields from search results
|
||||
|
||||
func getStringField(fields map[string]interface{}, key string) string {
|
||||
if val, ok := fields[key].(string); ok {
|
||||
return val
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getStringArrayField(fields map[string]interface{}, key string) []string {
|
||||
if val, ok := fields[key].([]interface{}); ok {
|
||||
result := make([]string, 0, len(val))
|
||||
for _, v := range val {
|
||||
if str, ok := v.(string); ok {
|
||||
result = append(result, str)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
return []string{}
|
||||
}
|
||||
|
||||
func getIntField(fields map[string]interface{}, key string) int {
|
||||
if val, ok := fields[key].(float64); ok {
|
||||
return int(val)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func getInt64Field(fields map[string]interface{}, key string) int64 {
|
||||
if val, ok := fields[key].(float64); ok {
|
||||
return int64(val)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
|
@ -0,0 +1,276 @@
|
|||
// Package securitylog provides security event logging for audit purposes.
|
||||
// This captures security-relevant events for monitoring and forensics.
|
||||
package securitylog
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// EventType defines the type of security event
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
// Authentication events
|
||||
EventLoginAttempt EventType = "LOGIN_ATTEMPT"
|
||||
EventLoginSuccess EventType = "LOGIN_SUCCESS"
|
||||
EventLoginFailure EventType = "LOGIN_FAILURE"
|
||||
EventLogout EventType = "LOGOUT"
|
||||
EventSessionRestored EventType = "SESSION_RESTORED"
|
||||
EventSessionExpired EventType = "SESSION_EXPIRED"
|
||||
EventSessionRevoked EventType = "SESSION_REVOKED"
|
||||
EventTokenRefresh EventType = "TOKEN_REFRESH"
|
||||
EventTokenRefreshFail EventType = "TOKEN_REFRESH_FAILURE"
|
||||
|
||||
// Registration events
|
||||
EventRegistration EventType = "REGISTRATION"
|
||||
EventEmailVerification EventType = "EMAIL_VERIFICATION"
|
||||
EventOTTRequest EventType = "OTT_REQUEST"
|
||||
EventOTTVerify EventType = "OTT_VERIFY"
|
||||
EventPasswordChallenge EventType = "PASSWORD_CHALLENGE"
|
||||
|
||||
// Rate limiting events
|
||||
EventRateLimitExceeded EventType = "RATE_LIMIT_EXCEEDED"
|
||||
|
||||
// Data access events
|
||||
EventCollectionCreate EventType = "COLLECTION_CREATE"
|
||||
EventCollectionUpdate EventType = "COLLECTION_UPDATE"
|
||||
EventCollectionDelete EventType = "COLLECTION_DELETE"
|
||||
EventCollectionAccess EventType = "COLLECTION_ACCESS"
|
||||
|
||||
EventFileUpload EventType = "FILE_UPLOAD"
|
||||
EventFileDownload EventType = "FILE_DOWNLOAD"
|
||||
EventFileDelete EventType = "FILE_DELETE"
|
||||
EventFileAccess EventType = "FILE_ACCESS"
|
||||
EventFileOpen EventType = "FILE_OPEN"
|
||||
|
||||
// Export events
|
||||
EventExportStart EventType = "EXPORT_START"
|
||||
EventExportComplete EventType = "EXPORT_COMPLETE"
|
||||
EventExportFailure EventType = "EXPORT_FAILURE"
|
||||
|
||||
// Configuration events
|
||||
EventConfigChange EventType = "CONFIG_CHANGE"
|
||||
EventConfigIntegrityFail EventType = "CONFIG_INTEGRITY_FAILURE"
|
||||
EventCloudProviderChange EventType = "CLOUD_PROVIDER_CHANGE"
|
||||
|
||||
// Security events
|
||||
EventSecurityValidationFail EventType = "SECURITY_VALIDATION_FAILURE"
|
||||
EventURLValidationFail EventType = "URL_VALIDATION_FAILURE"
|
||||
EventInputValidationFail EventType = "INPUT_VALIDATION_FAILURE"
|
||||
EventPathTraversalAttempt EventType = "PATH_TRAVERSAL_ATTEMPT"
|
||||
|
||||
// Key management events
|
||||
EventMasterKeyDerived EventType = "MASTER_KEY_DERIVED"
|
||||
EventMasterKeyCleared EventType = "MASTER_KEY_CLEARED"
|
||||
EventPasswordCleared EventType = "PASSWORD_CLEARED"
|
||||
|
||||
// Application lifecycle events
|
||||
EventAppStart EventType = "APP_START"
|
||||
EventAppShutdown EventType = "APP_SHUTDOWN"
|
||||
)
|
||||
|
||||
// EventOutcome indicates the result of an event
|
||||
type EventOutcome string
|
||||
|
||||
const (
|
||||
OutcomeSuccess EventOutcome = "SUCCESS"
|
||||
OutcomeFailure EventOutcome = "FAILURE"
|
||||
OutcomeBlocked EventOutcome = "BLOCKED"
|
||||
)
|
||||
|
||||
// SecurityEvent represents a security-relevant event
|
||||
type SecurityEvent struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
EventType EventType `json:"event_type"`
|
||||
Outcome EventOutcome `json:"outcome"`
|
||||
UserEmail string `json:"user_email,omitempty"` // Masked email
|
||||
ResourceID string `json:"resource_id,omitempty"`
|
||||
ResourceType string `json:"resource_type,omitempty"`
|
||||
Details map[string]string `json:"details,omitempty"`
|
||||
ErrorMsg string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// Service provides security event logging
|
||||
type Service struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// New creates a new security logging service
|
||||
func New(logger *zap.Logger) *Service {
|
||||
return &Service{
|
||||
logger: logger.Named("security"),
|
||||
}
|
||||
}
|
||||
|
||||
// ProvideService is the Wire provider for the security log service
|
||||
func ProvideService(logger *zap.Logger) *Service {
|
||||
return New(logger)
|
||||
}
|
||||
|
||||
// LogEvent logs a security event
|
||||
func (s *Service) LogEvent(event *SecurityEvent) {
|
||||
event.Timestamp = time.Now().UTC()
|
||||
|
||||
fields := []zap.Field{
|
||||
zap.String("event_type", string(event.EventType)),
|
||||
zap.String("outcome", string(event.Outcome)),
|
||||
zap.Time("timestamp", event.Timestamp),
|
||||
}
|
||||
|
||||
if event.UserEmail != "" {
|
||||
fields = append(fields, zap.String("user_email", event.UserEmail))
|
||||
}
|
||||
if event.ResourceID != "" {
|
||||
fields = append(fields, zap.String("resource_id", event.ResourceID))
|
||||
}
|
||||
if event.ResourceType != "" {
|
||||
fields = append(fields, zap.String("resource_type", event.ResourceType))
|
||||
}
|
||||
if event.ErrorMsg != "" {
|
||||
fields = append(fields, zap.String("error", event.ErrorMsg))
|
||||
}
|
||||
for k, v := range event.Details {
|
||||
fields = append(fields, zap.String("detail_"+k, v))
|
||||
}
|
||||
|
||||
switch event.Outcome {
|
||||
case OutcomeSuccess:
|
||||
s.logger.Info("Security event", fields...)
|
||||
case OutcomeFailure:
|
||||
s.logger.Warn("Security event", fields...)
|
||||
case OutcomeBlocked:
|
||||
s.logger.Warn("Security event (blocked)", fields...)
|
||||
default:
|
||||
s.logger.Info("Security event", fields...)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper methods for common events
|
||||
|
||||
// LogLoginAttempt logs a login attempt
|
||||
func (s *Service) LogLoginAttempt(maskedEmail string) {
|
||||
s.LogEvent(&SecurityEvent{
|
||||
EventType: EventLoginAttempt,
|
||||
Outcome: OutcomeSuccess,
|
||||
UserEmail: maskedEmail,
|
||||
})
|
||||
}
|
||||
|
||||
// LogLoginSuccess logs a successful login
|
||||
func (s *Service) LogLoginSuccess(maskedEmail string) {
|
||||
s.LogEvent(&SecurityEvent{
|
||||
EventType: EventLoginSuccess,
|
||||
Outcome: OutcomeSuccess,
|
||||
UserEmail: maskedEmail,
|
||||
})
|
||||
}
|
||||
|
||||
// LogLoginFailure logs a failed login
|
||||
func (s *Service) LogLoginFailure(maskedEmail string, reason string) {
|
||||
s.LogEvent(&SecurityEvent{
|
||||
EventType: EventLoginFailure,
|
||||
Outcome: OutcomeFailure,
|
||||
UserEmail: maskedEmail,
|
||||
ErrorMsg: reason,
|
||||
})
|
||||
}
|
||||
|
||||
// LogLogout logs a logout event
|
||||
func (s *Service) LogLogout(maskedEmail string) {
|
||||
s.LogEvent(&SecurityEvent{
|
||||
EventType: EventLogout,
|
||||
Outcome: OutcomeSuccess,
|
||||
UserEmail: maskedEmail,
|
||||
})
|
||||
}
|
||||
|
||||
// LogRateLimitExceeded logs a rate limit exceeded event
|
||||
func (s *Service) LogRateLimitExceeded(maskedEmail string, operation string) {
|
||||
s.LogEvent(&SecurityEvent{
|
||||
EventType: EventRateLimitExceeded,
|
||||
Outcome: OutcomeBlocked,
|
||||
UserEmail: maskedEmail,
|
||||
Details: map[string]string{
|
||||
"operation": operation,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// LogFileAccess logs a file access event
|
||||
func (s *Service) LogFileAccess(maskedEmail string, fileID string, operation string, outcome EventOutcome) {
|
||||
s.LogEvent(&SecurityEvent{
|
||||
EventType: EventFileAccess,
|
||||
Outcome: outcome,
|
||||
UserEmail: maskedEmail,
|
||||
ResourceID: fileID,
|
||||
ResourceType: "file",
|
||||
Details: map[string]string{
|
||||
"operation": operation,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// LogCollectionAccess logs a collection access event
|
||||
func (s *Service) LogCollectionAccess(maskedEmail string, collectionID string, operation string, outcome EventOutcome) {
|
||||
s.LogEvent(&SecurityEvent{
|
||||
EventType: EventCollectionAccess,
|
||||
Outcome: outcome,
|
||||
UserEmail: maskedEmail,
|
||||
ResourceID: collectionID,
|
||||
ResourceType: "collection",
|
||||
Details: map[string]string{
|
||||
"operation": operation,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// LogSecurityValidationFailure logs a security validation failure
|
||||
func (s *Service) LogSecurityValidationFailure(eventType EventType, details map[string]string, errorMsg string) {
|
||||
s.LogEvent(&SecurityEvent{
|
||||
EventType: eventType,
|
||||
Outcome: OutcomeBlocked,
|
||||
Details: details,
|
||||
ErrorMsg: errorMsg,
|
||||
})
|
||||
}
|
||||
|
||||
// LogExport logs an export operation
|
||||
func (s *Service) LogExport(maskedEmail string, eventType EventType, outcome EventOutcome, details map[string]string) {
|
||||
s.LogEvent(&SecurityEvent{
|
||||
EventType: eventType,
|
||||
Outcome: outcome,
|
||||
UserEmail: maskedEmail,
|
||||
Details: details,
|
||||
})
|
||||
}
|
||||
|
||||
// LogConfigChange logs a configuration change
|
||||
func (s *Service) LogConfigChange(setting string, maskedEmail string) {
|
||||
s.LogEvent(&SecurityEvent{
|
||||
EventType: EventConfigChange,
|
||||
Outcome: OutcomeSuccess,
|
||||
UserEmail: maskedEmail,
|
||||
Details: map[string]string{
|
||||
"setting": setting,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// LogAppLifecycle logs application start/shutdown
|
||||
func (s *Service) LogAppLifecycle(eventType EventType) {
|
||||
s.LogEvent(&SecurityEvent{
|
||||
EventType: eventType,
|
||||
Outcome: OutcomeSuccess,
|
||||
})
|
||||
}
|
||||
|
||||
// LogKeyManagement logs key management operations
|
||||
func (s *Service) LogKeyManagement(eventType EventType, maskedEmail string, outcome EventOutcome) {
|
||||
s.LogEvent(&SecurityEvent{
|
||||
EventType: eventType,
|
||||
Outcome: outcome,
|
||||
UserEmail: maskedEmail,
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,284 @@
|
|||
// Package storagemanager provides a service for managing user-specific storage.
|
||||
// It handles the lifecycle of storage instances, creating new storage when a user
|
||||
// logs in and cleaning up when they log out.
|
||||
//
|
||||
// Storage is organized as follows:
|
||||
// - Global storage (session): {appDir}/session/ - stores current login session
|
||||
// - User-specific storage: {appDir}/users/{emailHash}/ - stores user data
|
||||
//
|
||||
// This ensures:
|
||||
// 1. Different users have completely isolated data
|
||||
// 2. Dev and production modes have separate directories
|
||||
// 3. Email addresses are not exposed in directory names (hashed)
|
||||
package storagemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/config"
|
||||
collectionDomain "codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/domain/collection"
|
||||
fileDomain "codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/domain/file"
|
||||
syncstateDomain "codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/domain/syncstate"
|
||||
collectionRepo "codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/repo/collection"
|
||||
fileRepo "codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/repo/file"
|
||||
syncstateRepo "codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/repo/syncstate"
|
||||
"codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/pkg/storage"
|
||||
"codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/pkg/storage/leveldb"
|
||||
)
|
||||
|
||||
// Manager manages user-specific storage instances.
|
||||
// It creates storage when a user logs in and cleans up when they log out.
|
||||
type Manager struct {
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
|
||||
// Current user's email (empty if no user is logged in)
|
||||
currentUserEmail string
|
||||
|
||||
// User-specific storage instances
|
||||
localFilesStorage storage.Storage
|
||||
syncStateStorage storage.Storage
|
||||
|
||||
// User-specific repositories (built on top of storage)
|
||||
fileRepo fileDomain.Repository
|
||||
collectionRepo collectionDomain.Repository
|
||||
syncStateRepo syncstateDomain.Repository
|
||||
}
|
||||
|
||||
// ProvideManager creates a new storage manager.
|
||||
func ProvideManager(logger *zap.Logger) *Manager {
|
||||
return &Manager{
|
||||
logger: logger.Named("storage-manager"),
|
||||
}
|
||||
}
|
||||
|
||||
// InitializeForUser initializes user-specific storage for the given user.
|
||||
// This should be called after a user successfully logs in.
|
||||
// If storage is already initialized for a different user, it will be cleaned up first.
|
||||
func (m *Manager) InitializeForUser(userEmail string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if userEmail == "" {
|
||||
return fmt.Errorf("user email is required")
|
||||
}
|
||||
|
||||
// If same user, no need to reinitialize
|
||||
if m.currentUserEmail == userEmail && m.localFilesStorage != nil {
|
||||
m.logger.Debug("Storage already initialized for user",
|
||||
zap.String("email_hash", config.GetEmailHashForPath(userEmail)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clean up existing storage if different user
|
||||
if m.currentUserEmail != "" && m.currentUserEmail != userEmail {
|
||||
m.logger.Info("Switching user storage",
|
||||
zap.String("old_user_hash", config.GetEmailHashForPath(m.currentUserEmail)),
|
||||
zap.String("new_user_hash", config.GetEmailHashForPath(userEmail)))
|
||||
m.cleanupStorageUnsafe()
|
||||
}
|
||||
|
||||
m.logger.Info("Initializing storage for user",
|
||||
zap.String("email_hash", config.GetEmailHashForPath(userEmail)))
|
||||
|
||||
// Initialize local files storage
|
||||
localFilesProvider, err := config.NewLevelDBConfigurationProviderForLocalFilesWithUser(userEmail)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create local files storage provider: %w", err)
|
||||
}
|
||||
m.localFilesStorage = leveldb.NewDiskStorage(localFilesProvider, m.logger.Named("local-files"))
|
||||
|
||||
// Initialize sync state storage
|
||||
syncStateProvider, err := config.NewLevelDBConfigurationProviderForSyncStateWithUser(userEmail)
|
||||
if err != nil {
|
||||
m.cleanupStorageUnsafe()
|
||||
return fmt.Errorf("failed to create sync state storage provider: %w", err)
|
||||
}
|
||||
m.syncStateStorage = leveldb.NewDiskStorage(syncStateProvider, m.logger.Named("sync-state"))
|
||||
|
||||
// Create repositories
|
||||
m.fileRepo = fileRepo.ProvideRepository(m.localFilesStorage)
|
||||
m.collectionRepo = collectionRepo.ProvideRepository(m.localFilesStorage)
|
||||
m.syncStateRepo = syncstateRepo.ProvideRepository(m.syncStateStorage)
|
||||
|
||||
m.currentUserEmail = userEmail
|
||||
|
||||
m.logger.Info("Storage initialized successfully",
|
||||
zap.String("email_hash", config.GetEmailHashForPath(userEmail)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cleanup cleans up all user-specific storage.
|
||||
// This should be called when a user logs out.
|
||||
func (m *Manager) Cleanup() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.cleanupStorageUnsafe()
|
||||
}
|
||||
|
||||
// cleanupStorageUnsafe cleans up storage without acquiring the lock.
|
||||
// Caller must hold the lock.
|
||||
func (m *Manager) cleanupStorageUnsafe() {
|
||||
if m.localFilesStorage != nil {
|
||||
if closer, ok := m.localFilesStorage.(interface{ Close() error }); ok {
|
||||
if err := closer.Close(); err != nil {
|
||||
m.logger.Warn("Failed to close local files storage", zap.Error(err))
|
||||
}
|
||||
}
|
||||
m.localFilesStorage = nil
|
||||
}
|
||||
|
||||
if m.syncStateStorage != nil {
|
||||
if closer, ok := m.syncStateStorage.(interface{ Close() error }); ok {
|
||||
if err := closer.Close(); err != nil {
|
||||
m.logger.Warn("Failed to close sync state storage", zap.Error(err))
|
||||
}
|
||||
}
|
||||
m.syncStateStorage = nil
|
||||
}
|
||||
|
||||
m.fileRepo = nil
|
||||
m.collectionRepo = nil
|
||||
m.syncStateRepo = nil
|
||||
m.currentUserEmail = ""
|
||||
|
||||
m.logger.Debug("Storage cleaned up")
|
||||
}
|
||||
|
||||
// IsInitialized returns true if storage has been initialized for a user.
|
||||
func (m *Manager) IsInitialized() bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.localFilesStorage != nil
|
||||
}
|
||||
|
||||
// GetCurrentUserEmail returns the email of the user for whom storage is initialized.
|
||||
// Returns empty string if no user storage is initialized.
|
||||
func (m *Manager) GetCurrentUserEmail() string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.currentUserEmail
|
||||
}
|
||||
|
||||
// GetFileRepository returns the file repository for the current user.
|
||||
// Returns nil if storage is not initialized.
|
||||
func (m *Manager) GetFileRepository() fileDomain.Repository {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.fileRepo
|
||||
}
|
||||
|
||||
// GetCollectionRepository returns the collection repository for the current user.
|
||||
// Returns nil if storage is not initialized.
|
||||
func (m *Manager) GetCollectionRepository() collectionDomain.Repository {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.collectionRepo
|
||||
}
|
||||
|
||||
// GetSyncStateRepository returns the sync state repository for the current user.
|
||||
// Returns nil if storage is not initialized.
|
||||
func (m *Manager) GetSyncStateRepository() syncstateDomain.Repository {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.syncStateRepo
|
||||
}
|
||||
|
||||
// GetLocalFilesStorage returns the raw local files storage for the current user.
|
||||
// Returns nil if storage is not initialized.
|
||||
func (m *Manager) GetLocalFilesStorage() storage.Storage {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.localFilesStorage
|
||||
}
|
||||
|
||||
// DeleteUserData permanently deletes all local data for the specified user.
|
||||
// This includes all files, metadata, and sync state stored on this device.
|
||||
// IMPORTANT: This is a destructive operation and cannot be undone.
|
||||
// The user will need to re-download all files from the cloud after this operation.
|
||||
func (m *Manager) DeleteUserData(userEmail string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if userEmail == "" {
|
||||
return fmt.Errorf("user email is required")
|
||||
}
|
||||
|
||||
emailHash := config.GetEmailHashForPath(userEmail)
|
||||
m.logger.Info("Deleting all local data for user",
|
||||
zap.String("email_hash", emailHash))
|
||||
|
||||
// If this is the current user, clean up storage first
|
||||
if m.currentUserEmail == userEmail {
|
||||
m.cleanupStorageUnsafe()
|
||||
}
|
||||
|
||||
// Get the user's data directory
|
||||
userDir, err := config.GetUserSpecificDataDir("maplefile", userEmail)
|
||||
if err != nil {
|
||||
m.logger.Error("Failed to get user data directory", zap.Error(err))
|
||||
return fmt.Errorf("failed to get user data directory: %w", err)
|
||||
}
|
||||
|
||||
// Check if the directory exists
|
||||
if _, err := os.Stat(userDir); os.IsNotExist(err) {
|
||||
m.logger.Debug("User data directory does not exist, nothing to delete",
|
||||
zap.String("email_hash", emailHash))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove the entire user directory and all its contents
|
||||
if err := os.RemoveAll(userDir); err != nil {
|
||||
m.logger.Error("Failed to delete user data directory",
|
||||
zap.Error(err),
|
||||
zap.String("path", userDir))
|
||||
return fmt.Errorf("failed to delete user data: %w", err)
|
||||
}
|
||||
|
||||
m.logger.Info("Successfully deleted all local data for user",
|
||||
zap.String("email_hash", emailHash))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataSize returns the total size of local data stored for the specified user in bytes.
|
||||
// Returns 0 if no data exists or if there's an error calculating the size.
|
||||
func (m *Manager) GetUserDataSize(userEmail string) (int64, error) {
|
||||
if userEmail == "" {
|
||||
return 0, fmt.Errorf("user email is required")
|
||||
}
|
||||
|
||||
userDir, err := config.GetUserSpecificDataDir("maplefile", userEmail)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get user data directory: %w", err)
|
||||
}
|
||||
|
||||
// Check if the directory exists
|
||||
if _, err := os.Stat(userDir); os.IsNotExist(err) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var totalSize int64
|
||||
err = filepath.Walk(userDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil // Ignore errors and continue
|
||||
}
|
||||
if !info.IsDir() {
|
||||
totalSize += info.Size()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
m.logger.Warn("Error calculating user data size", zap.Error(err))
|
||||
return totalSize, nil // Return what we have
|
||||
}
|
||||
|
||||
return totalSize, nil
|
||||
}
|
||||
225
native/desktop/maplefile/internal/service/sync/collection.go
Normal file
225
native/desktop/maplefile/internal/service/sync/collection.go
Normal file
|
|
@ -0,0 +1,225 @@
|
|||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/maplefile/client"
|
||||
collectionDomain "codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/domain/collection"
|
||||
"codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/domain/syncstate"
|
||||
)
|
||||
|
||||
// CollectionSyncService defines the interface for collection synchronization
|
||||
type CollectionSyncService interface {
|
||||
Execute(ctx context.Context, input *SyncInput) (*SyncResult, error)
|
||||
}
|
||||
|
||||
type collectionSyncService struct {
|
||||
logger *zap.Logger
|
||||
apiClient *client.Client
|
||||
repoProvider RepositoryProvider
|
||||
}
|
||||
|
||||
// ProvideCollectionSyncService creates a new collection sync service for Wire
|
||||
func ProvideCollectionSyncService(
|
||||
logger *zap.Logger,
|
||||
apiClient *client.Client,
|
||||
repoProvider RepositoryProvider,
|
||||
) CollectionSyncService {
|
||||
return &collectionSyncService{
|
||||
logger: logger.Named("CollectionSyncService"),
|
||||
apiClient: apiClient,
|
||||
repoProvider: repoProvider,
|
||||
}
|
||||
}
|
||||
|
||||
// getCollectionRepo returns the collection repository, or an error if not initialized
|
||||
func (s *collectionSyncService) getCollectionRepo() (collectionDomain.Repository, error) {
|
||||
if !s.repoProvider.IsInitialized() {
|
||||
return nil, fmt.Errorf("storage not initialized - user must be logged in")
|
||||
}
|
||||
repo := s.repoProvider.GetCollectionRepository()
|
||||
if repo == nil {
|
||||
return nil, fmt.Errorf("collection repository not available")
|
||||
}
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
// getSyncStateRepo returns the sync state repository, or an error if not initialized
|
||||
func (s *collectionSyncService) getSyncStateRepo() (syncstate.Repository, error) {
|
||||
if !s.repoProvider.IsInitialized() {
|
||||
return nil, fmt.Errorf("storage not initialized - user must be logged in")
|
||||
}
|
||||
repo := s.repoProvider.GetSyncStateRepository()
|
||||
if repo == nil {
|
||||
return nil, fmt.Errorf("sync state repository not available")
|
||||
}
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
// Execute synchronizes collections from the cloud to local storage
|
||||
func (s *collectionSyncService) Execute(ctx context.Context, input *SyncInput) (*SyncResult, error) {
|
||||
s.logger.Info("Starting collection synchronization")
|
||||
|
||||
// Get repositories (will fail if user not logged in)
|
||||
syncStateRepo, err := s.getSyncStateRepo()
|
||||
if err != nil {
|
||||
s.logger.Error("Cannot sync - storage not initialized", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
if input == nil {
|
||||
input = &SyncInput{}
|
||||
}
|
||||
if input.BatchSize <= 0 {
|
||||
input.BatchSize = DefaultBatchSize
|
||||
}
|
||||
if input.MaxBatches <= 0 {
|
||||
input.MaxBatches = DefaultMaxBatches
|
||||
}
|
||||
|
||||
// Get current sync state
|
||||
state, err := syncStateRepo.Get()
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get sync state", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &SyncResult{}
|
||||
batchCount := 0
|
||||
|
||||
// Sync loop - fetch and process batches until done or max reached
|
||||
for batchCount < input.MaxBatches {
|
||||
// Prepare API request
|
||||
syncInput := &client.SyncInput{
|
||||
Cursor: state.CollectionCursor,
|
||||
Limit: input.BatchSize,
|
||||
}
|
||||
|
||||
// Fetch batch from cloud
|
||||
resp, err := s.apiClient.SyncCollections(ctx, syncInput)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to fetch collections from cloud", zap.Error(err))
|
||||
result.Errors = append(result.Errors, "failed to fetch collections: "+err.Error())
|
||||
break
|
||||
}
|
||||
|
||||
// Process each collection in the batch
|
||||
for _, cloudCol := range resp.Collections {
|
||||
if err := s.processCollection(ctx, cloudCol, input.Password, result); err != nil {
|
||||
s.logger.Error("Failed to process collection",
|
||||
zap.String("id", cloudCol.ID),
|
||||
zap.Error(err))
|
||||
result.Errors = append(result.Errors, "failed to process collection "+cloudCol.ID+": "+err.Error())
|
||||
}
|
||||
result.CollectionsProcessed++
|
||||
}
|
||||
|
||||
// Update sync state with new cursor
|
||||
state.UpdateCollectionSync(resp.NextCursor, resp.HasMore)
|
||||
if err := syncStateRepo.Save(state); err != nil {
|
||||
s.logger.Error("Failed to save sync state", zap.Error(err))
|
||||
result.Errors = append(result.Errors, "failed to save sync state: "+err.Error())
|
||||
}
|
||||
|
||||
batchCount++
|
||||
|
||||
// Check if we're done
|
||||
if !resp.HasMore {
|
||||
s.logger.Info("Collection sync completed - no more items")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info("Collection sync finished",
|
||||
zap.Int("processed", result.CollectionsProcessed),
|
||||
zap.Int("added", result.CollectionsAdded),
|
||||
zap.Int("updated", result.CollectionsUpdated),
|
||||
zap.Int("deleted", result.CollectionsDeleted),
|
||||
zap.Int("errors", len(result.Errors)))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// processCollection handles a single collection from the cloud
|
||||
// Note: ctx and password are reserved for future use (on-demand content decryption)
|
||||
func (s *collectionSyncService) processCollection(_ context.Context, cloudCol *client.Collection, _ string, result *SyncResult) error {
|
||||
// Get collection repository
|
||||
collectionRepo, err := s.getCollectionRepo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if collection exists locally
|
||||
localCol, err := collectionRepo.Get(cloudCol.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle deleted collections
|
||||
if cloudCol.State == collectionDomain.StateDeleted {
|
||||
if localCol != nil {
|
||||
if err := collectionRepo.Delete(cloudCol.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
result.CollectionsDeleted++
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// The collection name comes from the API already decrypted for owned collections.
|
||||
// For shared collections, it would need decryption using the key chain.
|
||||
// For now, we use the name as-is from the API response.
|
||||
collectionName := cloudCol.Name
|
||||
|
||||
// Create or update local collection
|
||||
if localCol == nil {
|
||||
// Create new local collection
|
||||
newCol := s.mapCloudToLocal(cloudCol, collectionName)
|
||||
if err := collectionRepo.Create(newCol); err != nil {
|
||||
return err
|
||||
}
|
||||
result.CollectionsAdded++
|
||||
} else {
|
||||
// Update existing collection
|
||||
updatedCol := s.mapCloudToLocal(cloudCol, collectionName)
|
||||
updatedCol.SyncStatus = localCol.SyncStatus // Preserve local sync status
|
||||
updatedCol.LastSyncedAt = time.Now()
|
||||
if err := collectionRepo.Update(updatedCol); err != nil {
|
||||
return err
|
||||
}
|
||||
result.CollectionsUpdated++
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// mapCloudToLocal converts a cloud collection to local domain model
|
||||
func (s *collectionSyncService) mapCloudToLocal(cloudCol *client.Collection, decryptedName string) *collectionDomain.Collection {
|
||||
return &collectionDomain.Collection{
|
||||
ID: cloudCol.ID,
|
||||
ParentID: cloudCol.ParentID,
|
||||
OwnerID: cloudCol.UserID,
|
||||
EncryptedCollectionKey: cloudCol.EncryptedCollectionKey.Ciphertext,
|
||||
Nonce: cloudCol.EncryptedCollectionKey.Nonce,
|
||||
Name: decryptedName,
|
||||
Description: cloudCol.Description,
|
||||
CustomIcon: cloudCol.CustomIcon, // Custom icon (emoji or "icon:<id>")
|
||||
TotalFiles: cloudCol.TotalFiles,
|
||||
TotalSizeInBytes: cloudCol.TotalSizeInBytes,
|
||||
PermissionLevel: cloudCol.PermissionLevel,
|
||||
IsOwner: cloudCol.IsOwner,
|
||||
OwnerName: cloudCol.OwnerName,
|
||||
OwnerEmail: cloudCol.OwnerEmail,
|
||||
SyncStatus: collectionDomain.SyncStatusCloudOnly,
|
||||
LastSyncedAt: time.Now(),
|
||||
State: cloudCol.State,
|
||||
CreatedAt: cloudCol.CreatedAt,
|
||||
ModifiedAt: cloudCol.ModifiedAt,
|
||||
}
|
||||
}
|
||||
|
||||
254
native/desktop/maplefile/internal/service/sync/file.go
Normal file
254
native/desktop/maplefile/internal/service/sync/file.go
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/maplefile/client"
|
||||
collectionDomain "codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/domain/collection"
|
||||
fileDomain "codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/domain/file"
|
||||
"codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/domain/syncstate"
|
||||
)
|
||||
|
||||
// FileSyncService defines the interface for file synchronization
|
||||
type FileSyncService interface {
|
||||
Execute(ctx context.Context, input *SyncInput) (*SyncResult, error)
|
||||
}
|
||||
|
||||
// RepositoryProvider provides access to user-specific repositories.
|
||||
// This interface allows sync services to work with dynamically initialized storage.
|
||||
// The storagemanager.Manager implements this interface.
|
||||
type RepositoryProvider interface {
|
||||
GetFileRepository() fileDomain.Repository
|
||||
GetCollectionRepository() collectionDomain.Repository
|
||||
GetSyncStateRepository() syncstate.Repository
|
||||
IsInitialized() bool
|
||||
}
|
||||
|
||||
type fileSyncService struct {
|
||||
logger *zap.Logger
|
||||
apiClient *client.Client
|
||||
repoProvider RepositoryProvider
|
||||
}
|
||||
|
||||
// ProvideFileSyncService creates a new file sync service for Wire
|
||||
func ProvideFileSyncService(
|
||||
logger *zap.Logger,
|
||||
apiClient *client.Client,
|
||||
repoProvider RepositoryProvider,
|
||||
) FileSyncService {
|
||||
return &fileSyncService{
|
||||
logger: logger.Named("FileSyncService"),
|
||||
apiClient: apiClient,
|
||||
repoProvider: repoProvider,
|
||||
}
|
||||
}
|
||||
|
||||
// getFileRepo returns the file repository, or an error if not initialized
|
||||
func (s *fileSyncService) getFileRepo() (fileDomain.Repository, error) {
|
||||
if !s.repoProvider.IsInitialized() {
|
||||
return nil, fmt.Errorf("storage not initialized - user must be logged in")
|
||||
}
|
||||
repo := s.repoProvider.GetFileRepository()
|
||||
if repo == nil {
|
||||
return nil, fmt.Errorf("file repository not available")
|
||||
}
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
// getSyncStateRepo returns the sync state repository, or an error if not initialized
|
||||
func (s *fileSyncService) getSyncStateRepo() (syncstate.Repository, error) {
|
||||
if !s.repoProvider.IsInitialized() {
|
||||
return nil, fmt.Errorf("storage not initialized - user must be logged in")
|
||||
}
|
||||
repo := s.repoProvider.GetSyncStateRepository()
|
||||
if repo == nil {
|
||||
return nil, fmt.Errorf("sync state repository not available")
|
||||
}
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
// Execute synchronizes files from the cloud to local storage (metadata only)
|
||||
func (s *fileSyncService) Execute(ctx context.Context, input *SyncInput) (*SyncResult, error) {
|
||||
s.logger.Info("Starting file synchronization")
|
||||
|
||||
// Get repositories (will fail if user not logged in)
|
||||
syncStateRepo, err := s.getSyncStateRepo()
|
||||
if err != nil {
|
||||
s.logger.Error("Cannot sync - storage not initialized", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set defaults
|
||||
if input == nil {
|
||||
input = &SyncInput{}
|
||||
}
|
||||
if input.BatchSize <= 0 {
|
||||
input.BatchSize = DefaultBatchSize
|
||||
}
|
||||
if input.MaxBatches <= 0 {
|
||||
input.MaxBatches = DefaultMaxBatches
|
||||
}
|
||||
|
||||
// Get current sync state
|
||||
state, err := syncStateRepo.Get()
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get sync state", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &SyncResult{}
|
||||
batchCount := 0
|
||||
|
||||
// Sync loop - fetch and process batches until done or max reached
|
||||
for batchCount < input.MaxBatches {
|
||||
// Prepare API request
|
||||
syncInput := &client.SyncInput{
|
||||
Cursor: state.FileCursor,
|
||||
Limit: input.BatchSize,
|
||||
}
|
||||
|
||||
// Fetch batch from cloud
|
||||
resp, err := s.apiClient.SyncFiles(ctx, syncInput)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to fetch files from cloud", zap.Error(err))
|
||||
result.Errors = append(result.Errors, "failed to fetch files: "+err.Error())
|
||||
break
|
||||
}
|
||||
|
||||
// Process each file in the batch
|
||||
for _, cloudFile := range resp.Files {
|
||||
if err := s.processFile(ctx, cloudFile, input.Password, result); err != nil {
|
||||
s.logger.Error("Failed to process file",
|
||||
zap.String("id", cloudFile.ID),
|
||||
zap.Error(err))
|
||||
result.Errors = append(result.Errors, "failed to process file "+cloudFile.ID+": "+err.Error())
|
||||
}
|
||||
result.FilesProcessed++
|
||||
}
|
||||
|
||||
// Update sync state with new cursor
|
||||
state.UpdateFileSync(resp.NextCursor, resp.HasMore)
|
||||
if err := syncStateRepo.Save(state); err != nil {
|
||||
s.logger.Error("Failed to save sync state", zap.Error(err))
|
||||
result.Errors = append(result.Errors, "failed to save sync state: "+err.Error())
|
||||
}
|
||||
|
||||
batchCount++
|
||||
|
||||
// Check if we're done
|
||||
if !resp.HasMore {
|
||||
s.logger.Info("File sync completed - no more items")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info("File sync finished",
|
||||
zap.Int("processed", result.FilesProcessed),
|
||||
zap.Int("added", result.FilesAdded),
|
||||
zap.Int("updated", result.FilesUpdated),
|
||||
zap.Int("deleted", result.FilesDeleted),
|
||||
zap.Int("errors", len(result.Errors)))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// processFile handles a single file from the cloud
|
||||
// Note: ctx and password are reserved for future use (on-demand content decryption)
|
||||
func (s *fileSyncService) processFile(_ context.Context, cloudFile *client.File, _ string, result *SyncResult) error {
|
||||
// Get file repository
|
||||
fileRepo, err := s.getFileRepo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if file exists locally
|
||||
localFile, err := fileRepo.Get(cloudFile.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle deleted files
|
||||
if cloudFile.State == fileDomain.StateDeleted {
|
||||
if localFile != nil {
|
||||
if err := fileRepo.Delete(cloudFile.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
result.FilesDeleted++
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create or update local file (metadata only - no content download)
|
||||
if localFile == nil {
|
||||
// Create new local file record
|
||||
newFile := s.mapCloudToLocal(cloudFile)
|
||||
if err := fileRepo.Create(newFile); err != nil {
|
||||
return err
|
||||
}
|
||||
result.FilesAdded++
|
||||
} else {
|
||||
// Update existing file metadata
|
||||
updatedFile := s.mapCloudToLocal(cloudFile)
|
||||
// Preserve local-only fields
|
||||
updatedFile.FilePath = localFile.FilePath
|
||||
updatedFile.EncryptedFilePath = localFile.EncryptedFilePath
|
||||
updatedFile.ThumbnailPath = localFile.ThumbnailPath
|
||||
updatedFile.Name = localFile.Name
|
||||
updatedFile.MimeType = localFile.MimeType
|
||||
updatedFile.Metadata = localFile.Metadata
|
||||
|
||||
// If file has local content, it's synced; otherwise it's cloud-only
|
||||
if localFile.HasLocalContent() {
|
||||
updatedFile.SyncStatus = fileDomain.SyncStatusSynced
|
||||
} else {
|
||||
updatedFile.SyncStatus = fileDomain.SyncStatusCloudOnly
|
||||
}
|
||||
updatedFile.LastSyncedAt = time.Now()
|
||||
|
||||
if err := fileRepo.Update(updatedFile); err != nil {
|
||||
return err
|
||||
}
|
||||
result.FilesUpdated++
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// mapCloudToLocal converts a cloud file to local domain model
|
||||
func (s *fileSyncService) mapCloudToLocal(cloudFile *client.File) *fileDomain.File {
|
||||
return &fileDomain.File{
|
||||
ID: cloudFile.ID,
|
||||
CollectionID: cloudFile.CollectionID,
|
||||
OwnerID: cloudFile.UserID,
|
||||
EncryptedFileKey: fileDomain.EncryptedFileKeyData{
|
||||
Ciphertext: cloudFile.EncryptedFileKey.Ciphertext,
|
||||
Nonce: cloudFile.EncryptedFileKey.Nonce,
|
||||
},
|
||||
FileKeyNonce: cloudFile.FileKeyNonce,
|
||||
EncryptedMetadata: cloudFile.EncryptedMetadata,
|
||||
MetadataNonce: cloudFile.MetadataNonce,
|
||||
FileNonce: cloudFile.FileNonce,
|
||||
EncryptedSizeInBytes: cloudFile.EncryptedSizeInBytes,
|
||||
DecryptedSizeInBytes: cloudFile.DecryptedSizeInBytes,
|
||||
// Local paths are empty until file is downloaded (onloaded)
|
||||
EncryptedFilePath: "",
|
||||
FilePath: "",
|
||||
ThumbnailPath: "",
|
||||
// Metadata will be decrypted when file is onloaded
|
||||
Name: "",
|
||||
MimeType: "",
|
||||
Metadata: nil,
|
||||
SyncStatus: fileDomain.SyncStatusCloudOnly, // Files start as cloud-only
|
||||
LastSyncedAt: time.Now(),
|
||||
State: cloudFile.State,
|
||||
StorageMode: cloudFile.StorageMode,
|
||||
Version: cloudFile.Version,
|
||||
CreatedAt: cloudFile.CreatedAt,
|
||||
ModifiedAt: cloudFile.ModifiedAt,
|
||||
ThumbnailURL: cloudFile.ThumbnailURL,
|
||||
}
|
||||
}
|
||||
149
native/desktop/maplefile/internal/service/sync/service.go
Normal file
149
native/desktop/maplefile/internal/service/sync/service.go
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
package sync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/domain/syncstate"
|
||||
)
|
||||
|
||||
// Service provides unified sync operations
|
||||
type Service interface {
|
||||
// SyncAll synchronizes both collections and files
|
||||
SyncAll(ctx context.Context, input *SyncInput) (*SyncResult, error)
|
||||
|
||||
// SyncCollections synchronizes collections only
|
||||
SyncCollections(ctx context.Context, input *SyncInput) (*SyncResult, error)
|
||||
|
||||
// SyncFiles synchronizes files only
|
||||
SyncFiles(ctx context.Context, input *SyncInput) (*SyncResult, error)
|
||||
|
||||
// GetSyncStatus returns the current sync status
|
||||
GetSyncStatus(ctx context.Context) (*SyncStatus, error)
|
||||
|
||||
// ResetSync resets all sync state for a fresh sync
|
||||
ResetSync(ctx context.Context) error
|
||||
}
|
||||
|
||||
type service struct {
|
||||
logger *zap.Logger
|
||||
collectionSync CollectionSyncService
|
||||
fileSync FileSyncService
|
||||
repoProvider RepositoryProvider
|
||||
}
|
||||
|
||||
// ProvideService creates a new unified sync service for Wire
|
||||
func ProvideService(
|
||||
logger *zap.Logger,
|
||||
collectionSync CollectionSyncService,
|
||||
fileSync FileSyncService,
|
||||
repoProvider RepositoryProvider,
|
||||
) Service {
|
||||
return &service{
|
||||
logger: logger.Named("SyncService"),
|
||||
collectionSync: collectionSync,
|
||||
fileSync: fileSync,
|
||||
repoProvider: repoProvider,
|
||||
}
|
||||
}
|
||||
|
||||
// getSyncStateRepo returns the sync state repository, or an error if not initialized
|
||||
func (s *service) getSyncStateRepo() (syncstate.Repository, error) {
|
||||
if !s.repoProvider.IsInitialized() {
|
||||
return nil, fmt.Errorf("storage not initialized - user must be logged in")
|
||||
}
|
||||
repo := s.repoProvider.GetSyncStateRepository()
|
||||
if repo == nil {
|
||||
return nil, fmt.Errorf("sync state repository not available")
|
||||
}
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
// SyncAll synchronizes both collections and files
|
||||
func (s *service) SyncAll(ctx context.Context, input *SyncInput) (*SyncResult, error) {
|
||||
s.logger.Info("Starting full sync (collections + files)")
|
||||
|
||||
// Sync collections first
|
||||
colResult, err := s.collectionSync.Execute(ctx, input)
|
||||
if err != nil {
|
||||
s.logger.Error("Collection sync failed during full sync", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Sync files
|
||||
fileResult, err := s.fileSync.Execute(ctx, input)
|
||||
if err != nil {
|
||||
s.logger.Error("File sync failed during full sync", zap.Error(err))
|
||||
// Return partial result with collection data
|
||||
return &SyncResult{
|
||||
CollectionsProcessed: colResult.CollectionsProcessed,
|
||||
CollectionsAdded: colResult.CollectionsAdded,
|
||||
CollectionsUpdated: colResult.CollectionsUpdated,
|
||||
CollectionsDeleted: colResult.CollectionsDeleted,
|
||||
Errors: append(colResult.Errors, "file sync failed: "+err.Error()),
|
||||
}, err
|
||||
}
|
||||
|
||||
// Merge results
|
||||
result := &SyncResult{
|
||||
CollectionsProcessed: colResult.CollectionsProcessed,
|
||||
CollectionsAdded: colResult.CollectionsAdded,
|
||||
CollectionsUpdated: colResult.CollectionsUpdated,
|
||||
CollectionsDeleted: colResult.CollectionsDeleted,
|
||||
FilesProcessed: fileResult.FilesProcessed,
|
||||
FilesAdded: fileResult.FilesAdded,
|
||||
FilesUpdated: fileResult.FilesUpdated,
|
||||
FilesDeleted: fileResult.FilesDeleted,
|
||||
Errors: append(colResult.Errors, fileResult.Errors...),
|
||||
}
|
||||
|
||||
s.logger.Info("Full sync completed",
|
||||
zap.Int("collections_processed", result.CollectionsProcessed),
|
||||
zap.Int("files_processed", result.FilesProcessed),
|
||||
zap.Int("errors", len(result.Errors)))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SyncCollections synchronizes collections only
|
||||
func (s *service) SyncCollections(ctx context.Context, input *SyncInput) (*SyncResult, error) {
|
||||
return s.collectionSync.Execute(ctx, input)
|
||||
}
|
||||
|
||||
// SyncFiles synchronizes files only
|
||||
func (s *service) SyncFiles(ctx context.Context, input *SyncInput) (*SyncResult, error) {
|
||||
return s.fileSync.Execute(ctx, input)
|
||||
}
|
||||
|
||||
// GetSyncStatus returns the current sync status
|
||||
func (s *service) GetSyncStatus(ctx context.Context) (*SyncStatus, error) {
|
||||
syncStateRepo, err := s.getSyncStateRepo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state, err := syncStateRepo.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SyncStatus{
|
||||
CollectionsSynced: state.IsCollectionSyncComplete(),
|
||||
FilesSynced: state.IsFileSyncComplete(),
|
||||
FullySynced: state.IsFullySynced(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ResetSync resets all sync state for a fresh sync
|
||||
func (s *service) ResetSync(ctx context.Context) error {
|
||||
s.logger.Info("Resetting sync state")
|
||||
|
||||
syncStateRepo, err := s.getSyncStateRepo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return syncStateRepo.Reset()
|
||||
}
|
||||
39
native/desktop/maplefile/internal/service/sync/types.go
Normal file
39
native/desktop/maplefile/internal/service/sync/types.go
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
package sync
|
||||
|
||||
// SyncResult represents the result of a sync operation
|
||||
type SyncResult struct {
|
||||
// Collection sync statistics
|
||||
CollectionsProcessed int `json:"collections_processed"`
|
||||
CollectionsAdded int `json:"collections_added"`
|
||||
CollectionsUpdated int `json:"collections_updated"`
|
||||
CollectionsDeleted int `json:"collections_deleted"`
|
||||
|
||||
// File sync statistics
|
||||
FilesProcessed int `json:"files_processed"`
|
||||
FilesAdded int `json:"files_added"`
|
||||
FilesUpdated int `json:"files_updated"`
|
||||
FilesDeleted int `json:"files_deleted"`
|
||||
|
||||
// Errors encountered during sync
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
}
|
||||
|
||||
// SyncInput represents input parameters for sync operations
|
||||
type SyncInput struct {
|
||||
BatchSize int64 `json:"batch_size,omitempty"` // Number of items per batch (default: 50)
|
||||
MaxBatches int `json:"max_batches,omitempty"` // Maximum batches to process (default: 100)
|
||||
Password string `json:"password"` // Required for E2EE decryption
|
||||
}
|
||||
|
||||
// SyncStatus represents the current sync status
|
||||
type SyncStatus struct {
|
||||
CollectionsSynced bool `json:"collections_synced"`
|
||||
FilesSynced bool `json:"files_synced"`
|
||||
FullySynced bool `json:"fully_synced"`
|
||||
}
|
||||
|
||||
// DefaultBatchSize is the default number of items to fetch per API call
|
||||
const DefaultBatchSize = 50
|
||||
|
||||
// DefaultMaxBatches is the default maximum number of batches to process
|
||||
const DefaultMaxBatches = 100
|
||||
929
native/desktop/maplefile/internal/service/tokenmanager/README.md
Normal file
929
native/desktop/maplefile/internal/service/tokenmanager/README.md
Normal file
|
|
@ -0,0 +1,929 @@
|
|||
# Token Manager Service
|
||||
|
||||
## Table of Contents
|
||||
1. [Overview](#overview)
|
||||
2. [Why Do We Need This?](#why-do-we-need-this)
|
||||
3. [How It Works](#how-it-works)
|
||||
4. [Architecture](#architecture)
|
||||
5. [Configuration](#configuration)
|
||||
6. [Lifecycle Management](#lifecycle-management)
|
||||
7. [Error Handling](#error-handling)
|
||||
8. [Testing](#testing)
|
||||
9. [Troubleshooting](#troubleshooting)
|
||||
10. [Examples](#examples)
|
||||
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
The Token Manager is a background service that automatically refreshes authentication tokens before they expire. This ensures users stay logged in without interruption and don't experience failed API requests due to expired tokens.
|
||||
|
||||
**Key Benefits:**
|
||||
- ✅ Seamless user experience (no sudden logouts)
|
||||
- ✅ No failed API requests due to expired tokens
|
||||
- ✅ Automatic cleanup on app shutdown
|
||||
- ✅ Graceful handling of refresh failures
|
||||
|
||||
---
|
||||
|
||||
## Why Do We Need This?
|
||||
|
||||
### The Problem
|
||||
|
||||
When you log into MapleFile, the backend gives you two tokens:
|
||||
|
||||
1. **Access Token** - Used for API requests (expires quickly, e.g., 1 hour)
|
||||
2. **Refresh Token** - Used to get new access tokens (lasts longer, e.g., 30 days)
|
||||
|
||||
**Without Token Manager:**
|
||||
```
|
||||
User logs in → Gets tokens (expires in 1 hour)
|
||||
User works for 61 minutes
|
||||
User tries to upload file → ❌ 401 Unauthorized!
|
||||
User gets logged out → 😞 Lost work, has to login again
|
||||
```
|
||||
|
||||
**With Token Manager:**
|
||||
```
|
||||
User logs in → Gets tokens (expires in 1 hour)
|
||||
Token Manager checks every 30 seconds
|
||||
At 59 minutes → Token Manager refreshes tokens automatically
|
||||
User works for hours → ✅ Everything just works!
|
||||
```
|
||||
|
||||
### The Solution
|
||||
|
||||
The Token Manager runs in the background and:
|
||||
1. **Checks** token expiration every 30 seconds
|
||||
2. **Refreshes** tokens when < 1 minute remains
|
||||
3. **Handles failures** gracefully (3 strikes = logout)
|
||||
4. **Shuts down cleanly** when app closes
|
||||
|
||||
---
|
||||
|
||||
## How It Works
|
||||
|
||||
### High-Level Flow
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Application Lifecycle │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────────┐
|
||||
│ App Starts / User Logs In │
|
||||
└──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────────┐
|
||||
│ Token Manager Starts │
|
||||
│ (background goroutine) │
|
||||
└──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────────┐
|
||||
│ Every 30 seconds: │
|
||||
│ 1. Check session │
|
||||
│ 2. Calculate time until expiry │
|
||||
│ 3. Refresh if < 1 minute │
|
||||
└──────────────────────────────────────┘
|
||||
│
|
||||
┌─────────┴─────────┐
|
||||
▼ ▼
|
||||
┌───────────────────┐ ┌──────────────────┐
|
||||
│ Refresh Success │ │ Refresh Failed │
|
||||
│ (reset counter) │ │ (increment) │
|
||||
└───────────────────┘ └──────────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────┐
|
||||
│ 3 failures? │
|
||||
└──────────────────┘
|
||||
│
|
||||
Yes │ No
|
||||
┌──────────┴──────┐
|
||||
▼ ▼
|
||||
┌─────────────────┐ ┌──────────┐
|
||||
│ Force Logout │ │ Continue │
|
||||
└─────────────────┘ └──────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────────┐
|
||||
│ App Shuts Down / User Logs Out │
|
||||
└──────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────────┐
|
||||
│ Token Manager Stops Gracefully │
|
||||
│ (goroutine cleanup) │
|
||||
└──────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Detailed Process
|
||||
|
||||
#### 1. **Starting the Token Manager**
|
||||
|
||||
When a user logs in OR when the app restarts with a valid session:
|
||||
|
||||
```go
|
||||
// In CompleteLogin or Startup
|
||||
tokenManager.Start()
|
||||
```
|
||||
|
||||
This creates a background goroutine that runs continuously.
|
||||
|
||||
#### 2. **Background Refresh Loop**
|
||||
|
||||
The goroutine runs this logic every 30 seconds:
|
||||
|
||||
```go
|
||||
1. Get current session from LevelDB
|
||||
2. Check if session exists and is valid
|
||||
3. Calculate: timeUntilExpiry = session.ExpiresAt - time.Now()
|
||||
4. If timeUntilExpiry < 1 minute:
|
||||
a. Call API to refresh tokens
|
||||
b. API returns new access + refresh tokens
|
||||
c. Tokens automatically saved to session
|
||||
5. If refresh fails:
|
||||
a. Increment failure counter
|
||||
b. If counter >= 3: Force logout
|
||||
6. If refresh succeeds:
|
||||
a. Reset failure counter to 0
|
||||
```
|
||||
|
||||
#### 3. **Stopping the Token Manager**
|
||||
|
||||
When user logs out OR app shuts down:
|
||||
|
||||
```go
|
||||
// Create a timeout context (max 3 seconds)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Stop gracefully
|
||||
tokenManager.Stop(ctx)
|
||||
```
|
||||
|
||||
This signals the goroutine to stop and waits for confirmation.
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
### Component Structure
|
||||
|
||||
```
|
||||
internal/service/tokenmanager/
|
||||
├── config.go # Configuration settings
|
||||
├── manager.go # Main token manager logic
|
||||
├── provider.go # Wire dependency injection
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
### Key Components
|
||||
|
||||
#### 1. **Manager Struct**
|
||||
|
||||
```go
|
||||
type Manager struct {
|
||||
// Dependencies
|
||||
config Config // Settings (intervals, thresholds)
|
||||
client *client.Client // API client for token refresh
|
||||
authService *auth.Service // Auth service for logout
|
||||
getSession *session.GetByIdUseCase // Get current session
|
||||
logger *zap.Logger // Structured logging
|
||||
|
||||
// Lifecycle management
|
||||
ctx context.Context // Manager's context
|
||||
cancel context.CancelFunc // Cancel function
|
||||
stopCh chan struct{} // Signal to stop
|
||||
stoppedCh chan struct{} // Confirmation of stopped
|
||||
running atomic.Bool // Is manager running?
|
||||
|
||||
// Refresh state
|
||||
mu sync.Mutex // Protects failure counter
|
||||
consecutiveFailures int // Track failures
|
||||
}
|
||||
```
|
||||
|
||||
#### 2. **Config Struct**
|
||||
|
||||
```go
|
||||
type Config struct {
|
||||
RefreshBeforeExpiry time.Duration // How early to refresh (default: 1 min)
|
||||
CheckInterval time.Duration // How often to check (default: 30 sec)
|
||||
MaxConsecutiveFailures int // Failures before logout (default: 3)
|
||||
}
|
||||
```
|
||||
|
||||
### Goroutine Management
|
||||
|
||||
#### Why Use Goroutines?
|
||||
|
||||
A **goroutine** is Go's way of running code in the background (like a separate thread). We need this because:
|
||||
|
||||
- Main app needs to respond to UI events
|
||||
- Token checking can happen in the background
|
||||
- No blocking of user actions
|
||||
|
||||
#### The Double-Channel Pattern
|
||||
|
||||
We use **two channels** for clean shutdown:
|
||||
|
||||
```go
|
||||
stopCh chan struct{} // We close this to signal "please stop"
|
||||
stoppedCh chan struct{} // Goroutine closes this to say "I stopped"
|
||||
```
|
||||
|
||||
**Why two channels?**
|
||||
|
||||
```go
|
||||
// Without confirmation:
|
||||
close(stopCh) // Signal stop
|
||||
// Goroutine might still be running! ⚠️
|
||||
// App shuts down → goroutine orphaned → potential crash
|
||||
|
||||
// With confirmation:
|
||||
close(stopCh) // Signal stop
|
||||
<-stoppedCh // Wait for confirmation
|
||||
// Now we KNOW goroutine is done ✅
|
||||
```
|
||||
|
||||
#### Thread Safety
|
||||
|
||||
**Problem:** Multiple parts of the app might access the token manager at once.
|
||||
|
||||
**Solution:** Use synchronization primitives:
|
||||
|
||||
1. **`atomic.Bool` for running flag**
|
||||
```go
|
||||
// Atomic operations are thread-safe (no mutex needed)
|
||||
if !tm.running.CompareAndSwap(false, true) {
|
||||
return // Already running, don't start again
|
||||
}
|
||||
```
|
||||
|
||||
2. **`sync.Mutex` for failure counter**
|
||||
```go
|
||||
// Lock before accessing shared data
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
tm.consecutiveFailures++
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
### Default Settings
|
||||
|
||||
```go
|
||||
Config{
|
||||
RefreshBeforeExpiry: 1 * time.Minute, // Refresh with 1 min remaining
|
||||
CheckInterval: 30 * time.Second, // Check every 30 seconds
|
||||
MaxConsecutiveFailures: 3, // 3 failures = logout
|
||||
}
|
||||
```
|
||||
|
||||
### Why These Values?
|
||||
|
||||
| Setting | Value | Reasoning |
|
||||
|---------|-------|-----------|
|
||||
| **RefreshBeforeExpiry** | 1 minute | Conservative buffer. Even if one check fails, we have time for next attempt |
|
||||
| **CheckInterval** | 30 seconds | Frequent enough to catch the 1-minute window, not too aggressive on resources |
|
||||
| **MaxConsecutiveFailures** | 3 failures | Balances between transient network issues and genuine auth problems |
|
||||
|
||||
### Customizing Configuration
|
||||
|
||||
To change settings, modify `provider.go`:
|
||||
|
||||
```go
|
||||
func ProvideManager(...) *Manager {
|
||||
config := Config{
|
||||
RefreshBeforeExpiry: 2 * time.Minute, // More conservative
|
||||
CheckInterval: 1 * time.Minute, // Less frequent checks
|
||||
MaxConsecutiveFailures: 5, // More tolerant
|
||||
}
|
||||
return New(config, client, authService, getSession, logger)
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Lifecycle Management
|
||||
|
||||
### 1. **Starting the Token Manager**
|
||||
|
||||
**Called from:**
|
||||
- `Application.Startup()` - If valid session exists from previous run
|
||||
- `Application.CompleteLogin()` - After successful login
|
||||
|
||||
**What happens:**
|
||||
|
||||
```go
|
||||
func (m *Manager) Start() {
|
||||
// 1. Check if already running (thread-safe)
|
||||
if !m.running.CompareAndSwap(false, true) {
|
||||
return // Already running, do nothing
|
||||
}
|
||||
|
||||
// 2. Create context for goroutine
|
||||
m.ctx, m.cancel = context.WithCancel(context.Background())
|
||||
|
||||
// 3. Create channels for communication
|
||||
m.stopCh = make(chan struct{})
|
||||
m.stoppedCh = make(chan struct{})
|
||||
|
||||
// 4. Reset failure counter
|
||||
m.consecutiveFailures = 0
|
||||
|
||||
// 5. Launch background goroutine
|
||||
go m.refreshLoop()
|
||||
}
|
||||
```
|
||||
|
||||
**Why it's safe to call multiple times:**
|
||||
|
||||
The `CompareAndSwap` operation ensures only ONE goroutine starts, even if `Start()` is called many times.
|
||||
|
||||
### 2. **Running the Refresh Loop**
|
||||
|
||||
**The goroutine does this forever (until stopped):**
|
||||
|
||||
```go
|
||||
func (m *Manager) refreshLoop() {
|
||||
// Ensure we always mark as stopped when exiting
|
||||
defer close(m.stoppedCh)
|
||||
defer m.running.Store(false)
|
||||
|
||||
// Create ticker (fires every 30 seconds)
|
||||
ticker := time.NewTicker(m.config.CheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Do initial check immediately
|
||||
m.checkAndRefresh()
|
||||
|
||||
// Loop forever
|
||||
for {
|
||||
select {
|
||||
case <-m.stopCh:
|
||||
// Stop signal received
|
||||
return
|
||||
|
||||
case <-m.ctx.Done():
|
||||
// Context cancelled
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
// 30 seconds elapsed, check again
|
||||
m.checkAndRefresh()
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**The `select` statement explained:**
|
||||
|
||||
Think of `select` like a switch statement for channels. It waits for one of these events:
|
||||
- `stopCh` closed → Time to stop
|
||||
- `ctx.Done()` → Forced cancellation
|
||||
- `ticker.C` → 30 seconds passed, do work
|
||||
|
||||
### 3. **Stopping the Token Manager**
|
||||
|
||||
**Called from:**
|
||||
- `Application.Shutdown()` - App closing
|
||||
- `Application.Logout()` - User logging out
|
||||
|
||||
**What happens:**
|
||||
|
||||
```go
|
||||
func (m *Manager) Stop(ctx context.Context) error {
|
||||
// 1. Check if running
|
||||
if !m.running.Load() {
|
||||
return nil // Not running, nothing to do
|
||||
}
|
||||
|
||||
// 2. Signal stop (close the channel)
|
||||
close(m.stopCh)
|
||||
|
||||
// 3. Wait for confirmation OR timeout
|
||||
select {
|
||||
case <-m.stoppedCh:
|
||||
// Goroutine confirmed it stopped
|
||||
return nil
|
||||
|
||||
case <-ctx.Done():
|
||||
// Timeout! Force cancel
|
||||
m.cancel()
|
||||
|
||||
// Give it 100ms more
|
||||
select {
|
||||
case <-m.stoppedCh:
|
||||
return nil
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return ctx.Err() // Failed to stop cleanly
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Why the timeout?**
|
||||
|
||||
If the goroutine is stuck (e.g., in a long API call), we can't wait forever. The app needs to shut down!
|
||||
|
||||
---
|
||||
|
||||
## Error Handling
|
||||
|
||||
### 1. **Refresh Failures**
|
||||
|
||||
**Types of failures:**
|
||||
|
||||
| Failure Type | Cause | Handling |
|
||||
|--------------|-------|----------|
|
||||
| **Network Error** | No internet connection | Increment counter, retry next check |
|
||||
| **401 Unauthorized** | Refresh token expired | Increment counter, likely force logout |
|
||||
| **500 Server Error** | Backend issue | Increment counter, retry next check |
|
||||
| **Timeout** | Slow network | Increment counter, retry next check |
|
||||
|
||||
**Failure tracking:**
|
||||
|
||||
```go
|
||||
func (m *Manager) checkAndRefresh() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// ... check if refresh needed ...
|
||||
|
||||
// Attempt refresh
|
||||
if err := m.client.RefreshToken(ctx); err != nil {
|
||||
m.consecutiveFailures++
|
||||
|
||||
if m.consecutiveFailures >= m.config.MaxConsecutiveFailures {
|
||||
// Too many failures! Force logout
|
||||
return m.forceLogout()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Success! Reset counter
|
||||
m.consecutiveFailures = 0
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
### 2. **Force Logout**
|
||||
|
||||
**When it happens:**
|
||||
- 3 consecutive refresh failures
|
||||
- Session expired on startup
|
||||
|
||||
**What it does:**
|
||||
|
||||
```go
|
||||
func (m *Manager) forceLogout() error {
|
||||
m.logger.Warn("Forcing logout due to token refresh issues")
|
||||
|
||||
// Use background context (not manager's context which might be cancelled)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Clear session from LevelDB
|
||||
if err := m.authService.Logout(ctx); err != nil {
|
||||
m.logger.Error("Failed to force logout", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// User will see login screen on next UI interaction
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
**User experience:**
|
||||
|
||||
When force logout happens, the user will see the login screen the next time they interact with the app. Their work is NOT lost (local files remain), they just need to log in again.
|
||||
|
||||
### 3. **Session Not Found**
|
||||
|
||||
**Scenario:** User manually deleted session file, or session expired.
|
||||
|
||||
**Handling:**
|
||||
|
||||
```go
|
||||
// Get current session
|
||||
sess, err := m.getSession.Execute()
|
||||
if err != nil || sess == nil {
|
||||
// No session = user not logged in
|
||||
// This is normal, not an error
|
||||
return nil // Do nothing
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Testing
|
||||
|
||||
### Manual Testing
|
||||
|
||||
#### Test 1: Normal Refresh
|
||||
|
||||
1. Log in to the app
|
||||
2. Watch logs for token manager start
|
||||
3. Wait ~30 seconds
|
||||
4. Check logs for "Token refresh not needed yet"
|
||||
5. Verify `time_until_expiry` is decreasing
|
||||
|
||||
**Expected logs:**
|
||||
```
|
||||
INFO Token manager starting
|
||||
INFO Token refresh loop started
|
||||
DEBUG Token refresh not needed yet {"time_until_expiry": "59m30s"}
|
||||
... wait 30 seconds ...
|
||||
DEBUG Token refresh not needed yet {"time_until_expiry": "59m0s"}
|
||||
```
|
||||
|
||||
#### Test 2: Automatic Refresh
|
||||
|
||||
1. Log in and get tokens with short expiry (if possible)
|
||||
2. Wait until < 1 minute remaining
|
||||
3. Watch logs for automatic refresh
|
||||
|
||||
**Expected logs:**
|
||||
```
|
||||
INFO Token refresh needed {"time_until_expiry": "45s"}
|
||||
INFO Token refreshed successfully
|
||||
DEBUG Token refresh not needed yet {"time_until_expiry": "59m30s"}
|
||||
```
|
||||
|
||||
#### Test 3: Graceful Shutdown
|
||||
|
||||
1. Log in (token manager running)
|
||||
2. Close the app (Cmd+Q on Mac, Alt+F4 on Windows)
|
||||
3. Check logs for clean shutdown
|
||||
|
||||
**Expected logs:**
|
||||
```
|
||||
INFO MapleFile desktop application shutting down
|
||||
INFO Token manager stopping...
|
||||
INFO Token refresh loop received stop signal
|
||||
INFO Token refresh loop exited
|
||||
INFO Token manager stopped gracefully
|
||||
```
|
||||
|
||||
#### Test 4: Logout
|
||||
|
||||
1. Log in (token manager running)
|
||||
2. Click logout button
|
||||
3. Verify token manager stops
|
||||
|
||||
**Expected logs:**
|
||||
```
|
||||
INFO Token manager stopping...
|
||||
INFO Token manager stopped gracefully
|
||||
INFO User logged out successfully
|
||||
```
|
||||
|
||||
#### Test 5: Session Resume on Restart
|
||||
|
||||
1. Log in
|
||||
2. Close app
|
||||
3. Restart app
|
||||
4. Check logs for session resume
|
||||
|
||||
**Expected logs:**
|
||||
```
|
||||
INFO MapleFile desktop application started
|
||||
INFO Resuming valid session from previous run
|
||||
INFO Session restored to API client
|
||||
INFO Token manager starting
|
||||
INFO Token manager started for resumed session
|
||||
```
|
||||
|
||||
### Unit Testing (TODO)
|
||||
|
||||
```go
|
||||
// Example test structure (to be implemented)
|
||||
|
||||
func TestTokenManager_Start(t *testing.T) {
|
||||
// Test that Start() can be called multiple times safely
|
||||
// Test that goroutine actually starts
|
||||
}
|
||||
|
||||
func TestTokenManager_Stop(t *testing.T) {
|
||||
// Test graceful shutdown
|
||||
// Test timeout handling
|
||||
}
|
||||
|
||||
func TestTokenManager_RefreshLogic(t *testing.T) {
|
||||
// Test refresh when < 1 minute
|
||||
// Test no refresh when > 1 minute
|
||||
}
|
||||
|
||||
func TestTokenManager_FailureHandling(t *testing.T) {
|
||||
// Test failure counter increment
|
||||
// Test force logout after 3 failures
|
||||
// Test counter reset on success
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Problem: Token manager not starting
|
||||
|
||||
**Symptoms:**
|
||||
- No "Token manager starting" log
|
||||
- App works but might get logged out after token expires
|
||||
|
||||
**Possible causes:**
|
||||
|
||||
1. **No session on startup**
|
||||
```
|
||||
Check logs for: "No session found on startup"
|
||||
Solution: This is normal if user hasn't logged in yet
|
||||
```
|
||||
|
||||
2. **Session expired**
|
||||
```
|
||||
Check logs for: "Session expired on startup"
|
||||
Solution: User needs to log in again
|
||||
```
|
||||
|
||||
3. **Token manager already running**
|
||||
```
|
||||
Check logs for: "Token manager already running"
|
||||
Solution: This is expected behavior (prevents duplicate goroutines)
|
||||
```
|
||||
|
||||
### Problem: "Token manager stop timeout"
|
||||
|
||||
**Symptoms:**
|
||||
- App takes long time to close
|
||||
- Warning in logs: "Token manager stop timeout, forcing cancellation"
|
||||
|
||||
**Possible causes:**
|
||||
|
||||
1. **Refresh in progress during shutdown**
|
||||
```
|
||||
Goroutine might be in the middle of API call
|
||||
Solution: Wait for current API call to timeout (max 30s)
|
||||
```
|
||||
|
||||
2. **Network issue**
|
||||
```
|
||||
API call hanging due to network problems
|
||||
Solution: Force cancellation (already handled automatically)
|
||||
```
|
||||
|
||||
### Problem: Getting logged out unexpectedly
|
||||
|
||||
**Symptoms:**
|
||||
- User sees login screen randomly
|
||||
- Logs show "Forcing logout due to token refresh issues"
|
||||
|
||||
**Possible causes:**
|
||||
|
||||
1. **Network connectivity issues**
|
||||
```
|
||||
Check logs for repeated: "Token refresh failed"
|
||||
Solution: Check internet connection, backend availability
|
||||
```
|
||||
|
||||
2. **Backend API down**
|
||||
```
|
||||
All refresh attempts failing
|
||||
Solution: Check backend service status
|
||||
```
|
||||
|
||||
3. **Refresh token expired**
|
||||
```
|
||||
Backend returns 401 on refresh
|
||||
Solution: User needs to log in again (this is expected)
|
||||
```
|
||||
|
||||
### Problem: High CPU/memory usage
|
||||
|
||||
**Symptoms:**
|
||||
- App using lots of resources
|
||||
- Multiple token managers running
|
||||
|
||||
**Diagnosis:**
|
||||
|
||||
```bash
|
||||
# Check goroutines
|
||||
curl http://localhost:34115/debug/pprof/goroutine?debug=1
|
||||
|
||||
# Look for multiple "refreshLoop" goroutines
|
||||
```
|
||||
|
||||
**Possible causes:**
|
||||
|
||||
1. **Token manager not stopping on logout**
|
||||
```
|
||||
Check logs for missing: "Token manager stopped gracefully"
|
||||
Solution: Bug in stop logic (report issue)
|
||||
```
|
||||
|
||||
2. **Multiple Start() calls**
|
||||
```
|
||||
Should not happen (atomic bool prevents this)
|
||||
Solution: Report issue with reproduction steps
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Examples
|
||||
|
||||
### Example 1: Adding Custom Logging
|
||||
|
||||
Want to know exactly when refresh happens?
|
||||
|
||||
```go
|
||||
// In tokenmanager/manager.go, modify checkAndRefresh():
|
||||
|
||||
func (m *Manager) checkAndRefresh() error {
|
||||
// ... existing code ...
|
||||
|
||||
// Before refresh
|
||||
m.logger.Info("REFRESH STARTING",
|
||||
zap.Time("now", time.Now()),
|
||||
zap.Time("token_expires_at", sess.ExpiresAt))
|
||||
|
||||
if err := m.client.RefreshToken(ctx); err != nil {
|
||||
// Log failure details
|
||||
m.logger.Error("REFRESH FAILED",
|
||||
zap.Error(err),
|
||||
zap.String("error_type", fmt.Sprintf("%T", err)))
|
||||
return err
|
||||
}
|
||||
|
||||
// After refresh
|
||||
m.logger.Info("REFRESH COMPLETED",
|
||||
zap.Time("completion_time", time.Now()))
|
||||
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
### Example 2: Custom Failure Callback
|
||||
|
||||
Want to notify UI when logout happens?
|
||||
|
||||
```go
|
||||
// Add callback to Manager struct:
|
||||
|
||||
type Manager struct {
|
||||
// ... existing fields ...
|
||||
onForceLogout func(reason string) // NEW
|
||||
}
|
||||
|
||||
// In checkAndRefresh():
|
||||
if m.consecutiveFailures >= m.config.MaxConsecutiveFailures {
|
||||
reason := fmt.Sprintf("%d consecutive refresh failures", m.consecutiveFailures)
|
||||
|
||||
if m.onForceLogout != nil {
|
||||
m.onForceLogout(reason) // Notify callback
|
||||
}
|
||||
|
||||
return m.forceLogout()
|
||||
}
|
||||
|
||||
// In Application, set callback:
|
||||
func (a *Application) Startup(ctx context.Context) {
|
||||
// ... existing code ...
|
||||
|
||||
// Set callback to emit Wails event
|
||||
a.tokenManager.onForceLogout = func(reason string) {
|
||||
runtime.EventsEmit(a.ctx, "auth:logged-out", reason)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Example 3: Metrics Collection
|
||||
|
||||
Want to track refresh statistics?
|
||||
|
||||
```go
|
||||
type RefreshMetrics struct {
|
||||
TotalRefreshes int64
|
||||
SuccessfulRefreshes int64
|
||||
FailedRefreshes int64
|
||||
LastRefreshTime time.Time
|
||||
}
|
||||
|
||||
// Add to Manager:
|
||||
type Manager struct {
|
||||
// ... existing fields ...
|
||||
metrics RefreshMetrics
|
||||
metricsMu sync.Mutex
|
||||
}
|
||||
|
||||
// In checkAndRefresh():
|
||||
if err := m.client.RefreshToken(ctx); err != nil {
|
||||
m.metricsMu.Lock()
|
||||
m.metrics.TotalRefreshes++
|
||||
m.metrics.FailedRefreshes++
|
||||
m.metricsMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
m.metricsMu.Lock()
|
||||
m.metrics.TotalRefreshes++
|
||||
m.metrics.SuccessfulRefreshes++
|
||||
m.metrics.LastRefreshTime = time.Now()
|
||||
m.metricsMu.Unlock()
|
||||
|
||||
// Export metrics via Wails:
|
||||
func (a *Application) GetRefreshMetrics() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"total": a.tokenManager.metrics.TotalRefreshes,
|
||||
"successful": a.tokenManager.metrics.SuccessfulRefreshes,
|
||||
"failed": a.tokenManager.metrics.FailedRefreshes,
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Summary for Junior Developers
|
||||
|
||||
### Key Concepts to Remember
|
||||
|
||||
1. **Goroutines are background threads**
|
||||
- They run concurrently with your main app
|
||||
- Need careful management (start/stop)
|
||||
|
||||
2. **Channels are for communication**
|
||||
- `close(stopCh)` = "Please stop"
|
||||
- `<-stoppedCh` = "I confirm I stopped"
|
||||
|
||||
3. **Mutexes prevent race conditions**
|
||||
- Lock before accessing shared data
|
||||
- Always defer unlock
|
||||
|
||||
4. **Atomic operations are thread-safe**
|
||||
- Use for simple flags
|
||||
- No mutex needed
|
||||
|
||||
5. **Context carries deadlines**
|
||||
- Respect timeouts
|
||||
- Use for cancellation
|
||||
|
||||
### What NOT to Do
|
||||
|
||||
❌ **Don't call Start() in a loop**
|
||||
```go
|
||||
// Bad!
|
||||
for {
|
||||
tokenManager.Start() // Creates goroutine leak!
|
||||
}
|
||||
```
|
||||
|
||||
❌ **Don't forget to Stop()**
|
||||
```go
|
||||
// Bad!
|
||||
func Logout() {
|
||||
authService.Logout() // Token manager still running!
|
||||
}
|
||||
```
|
||||
|
||||
❌ **Don't block on Stop() without timeout**
|
||||
```go
|
||||
// Bad!
|
||||
tokenManager.Stop(context.Background()) // Could hang forever!
|
||||
|
||||
// Good!
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
tokenManager.Stop(ctx)
|
||||
```
|
||||
|
||||
### Learning Resources
|
||||
|
||||
- **Go Concurrency Patterns**: https://go.dev/blog/pipelines
|
||||
- **Context Package**: https://go.dev/blog/context
|
||||
- **Sync Package**: https://pkg.go.dev/sync
|
||||
|
||||
### Getting Help
|
||||
|
||||
If you're stuck:
|
||||
1. Check the logs (they're very detailed)
|
||||
2. Look at the troubleshooting section above
|
||||
3. Ask senior developers for code review
|
||||
4. File an issue with reproduction steps
|
||||
|
||||
---
|
||||
|
||||
## Changelog
|
||||
|
||||
### v1.0.0 (2025-11-21)
|
||||
- Initial implementation
|
||||
- Background refresh every 30 seconds
|
||||
- Refresh when < 1 minute before expiry
|
||||
- Graceful shutdown with timeout handling
|
||||
- Automatic logout after 3 consecutive failures
|
||||
- Session resume on app restart
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
package tokenmanager
|
||||
|
||||
import "time"
|
||||
|
||||
// Config holds configuration for the token manager
|
||||
type Config struct {
|
||||
// RefreshBeforeExpiry is how long before expiry to refresh the token
|
||||
// Default: 1 minute
|
||||
RefreshBeforeExpiry time.Duration
|
||||
|
||||
// CheckInterval is how often to check if refresh is needed
|
||||
// Default: 30 seconds
|
||||
CheckInterval time.Duration
|
||||
|
||||
// MaxConsecutiveFailures is how many consecutive refresh failures before forcing logout
|
||||
// Default: 3
|
||||
MaxConsecutiveFailures int
|
||||
}
|
||||
|
||||
// DefaultConfig returns the default configuration
|
||||
func DefaultConfig() Config {
|
||||
return Config{
|
||||
RefreshBeforeExpiry: 1 * time.Minute,
|
||||
CheckInterval: 30 * time.Second,
|
||||
MaxConsecutiveFailures: 3,
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,228 @@
|
|||
package tokenmanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/maplefile/client"
|
||||
"codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/service/auth"
|
||||
"codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/usecase/session"
|
||||
)
|
||||
|
||||
// Manager handles automatic token refresh with graceful shutdown
|
||||
type Manager struct {
|
||||
config Config
|
||||
client *client.Client
|
||||
authService *auth.Service
|
||||
getSession *session.GetByIdUseCase
|
||||
logger *zap.Logger
|
||||
|
||||
// Lifecycle management
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
stopCh chan struct{} // Signal to stop
|
||||
stoppedCh chan struct{} // Confirmation of stopped
|
||||
running atomic.Bool // Thread-safe running flag
|
||||
|
||||
// Refresh state management
|
||||
mu sync.Mutex
|
||||
consecutiveFailures int
|
||||
}
|
||||
|
||||
// New creates a new token manager
|
||||
func New(
|
||||
config Config,
|
||||
client *client.Client,
|
||||
authService *auth.Service,
|
||||
getSession *session.GetByIdUseCase,
|
||||
logger *zap.Logger,
|
||||
) *Manager {
|
||||
return &Manager{
|
||||
config: config,
|
||||
client: client,
|
||||
authService: authService,
|
||||
getSession: getSession,
|
||||
logger: logger.Named("token-manager"),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the token refresh background process
|
||||
// Safe to call multiple times - will only start once
|
||||
func (m *Manager) Start() {
|
||||
// Only start if not already running
|
||||
if !m.running.CompareAndSwap(false, true) {
|
||||
m.logger.Debug("Token manager already running, skipping start")
|
||||
return
|
||||
}
|
||||
|
||||
m.ctx, m.cancel = context.WithCancel(context.Background())
|
||||
m.stopCh = make(chan struct{})
|
||||
m.stoppedCh = make(chan struct{})
|
||||
m.consecutiveFailures = 0
|
||||
|
||||
m.logger.Info("Token manager starting")
|
||||
go m.refreshLoop()
|
||||
}
|
||||
|
||||
// Stop gracefully stops the token refresh background process
|
||||
// Blocks until stopped or context deadline exceeded
|
||||
func (m *Manager) Stop(ctx context.Context) error {
|
||||
if !m.running.Load() {
|
||||
m.logger.Debug("Token manager not running, nothing to stop")
|
||||
return nil
|
||||
}
|
||||
|
||||
m.logger.Info("Token manager stopping...")
|
||||
|
||||
// Signal stop
|
||||
close(m.stopCh)
|
||||
|
||||
// Wait for goroutine to finish or timeout
|
||||
select {
|
||||
case <-m.stoppedCh:
|
||||
m.logger.Info("Token manager stopped gracefully")
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
m.logger.Warn("Token manager stop timeout, forcing cancellation")
|
||||
m.cancel()
|
||||
// Wait a bit more for cancellation to take effect
|
||||
select {
|
||||
case <-m.stoppedCh:
|
||||
m.logger.Info("Token manager stopped after forced cancellation")
|
||||
return nil
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
m.logger.Error("Token manager failed to stop cleanly")
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsRunning returns true if the token manager is currently running
|
||||
func (m *Manager) IsRunning() bool {
|
||||
return m.running.Load()
|
||||
}
|
||||
|
||||
// refreshLoop is the background goroutine that checks and refreshes tokens
|
||||
func (m *Manager) refreshLoop() {
|
||||
defer close(m.stoppedCh)
|
||||
defer m.running.Store(false)
|
||||
defer m.logger.Info("Token refresh loop exited")
|
||||
|
||||
ticker := time.NewTicker(m.config.CheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
m.logger.Info("Token refresh loop started",
|
||||
zap.Duration("check_interval", m.config.CheckInterval),
|
||||
zap.Duration("refresh_before_expiry", m.config.RefreshBeforeExpiry))
|
||||
|
||||
// Do initial check immediately
|
||||
if err := m.checkAndRefresh(); err != nil {
|
||||
m.logger.Error("Initial token refresh check failed", zap.Error(err))
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.stopCh:
|
||||
m.logger.Info("Token refresh loop received stop signal")
|
||||
return
|
||||
|
||||
case <-m.ctx.Done():
|
||||
m.logger.Info("Token refresh loop context cancelled")
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
if err := m.checkAndRefresh(); err != nil {
|
||||
m.logger.Error("Token refresh check failed", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkAndRefresh checks if token refresh is needed and performs it
|
||||
func (m *Manager) checkAndRefresh() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Get current session
|
||||
sess, err := m.getSession.Execute()
|
||||
if err != nil {
|
||||
m.logger.Debug("No session found, skipping refresh check", zap.Error(err))
|
||||
return nil // Not an error - user might not be logged in
|
||||
}
|
||||
|
||||
if sess == nil {
|
||||
m.logger.Debug("Session is nil, skipping refresh check")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if session is still valid
|
||||
if sess.IsExpired() {
|
||||
m.logger.Warn("Session has expired, forcing logout")
|
||||
return m.forceLogout()
|
||||
}
|
||||
|
||||
// Check if refresh is needed
|
||||
timeUntilExpiry := time.Until(sess.ExpiresAt)
|
||||
if timeUntilExpiry > m.config.RefreshBeforeExpiry {
|
||||
// No refresh needed yet
|
||||
if m.consecutiveFailures > 0 {
|
||||
// Reset failure counter on successful check
|
||||
m.logger.Info("Session valid, resetting failure counter")
|
||||
m.consecutiveFailures = 0
|
||||
}
|
||||
m.logger.Debug("Token refresh not needed yet",
|
||||
zap.Duration("time_until_expiry", timeUntilExpiry))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Refresh needed
|
||||
m.logger.Info("Token refresh needed",
|
||||
zap.Duration("time_until_expiry", timeUntilExpiry))
|
||||
|
||||
// Attempt refresh (with background context, not the manager's context)
|
||||
refreshCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := m.client.RefreshToken(refreshCtx); err != nil {
|
||||
m.consecutiveFailures++
|
||||
m.logger.Error("Token refresh failed",
|
||||
zap.Error(err),
|
||||
zap.Int("consecutive_failures", m.consecutiveFailures),
|
||||
zap.Int("max_failures", m.config.MaxConsecutiveFailures))
|
||||
|
||||
if m.consecutiveFailures >= m.config.MaxConsecutiveFailures {
|
||||
m.logger.Error("Max consecutive refresh failures reached, forcing logout")
|
||||
return m.forceLogout()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Success - reset failure counter
|
||||
m.consecutiveFailures = 0
|
||||
m.logger.Info("Token refreshed successfully",
|
||||
zap.Duration("time_until_old_expiry", timeUntilExpiry))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// forceLogout forces a logout due to refresh failures
|
||||
func (m *Manager) forceLogout() error {
|
||||
m.logger.Warn("Forcing logout due to token refresh issues")
|
||||
|
||||
// Use background context since manager might be shutting down
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := m.authService.Logout(ctx); err != nil {
|
||||
m.logger.Error("Failed to force logout", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
m.logger.Info("Force logout completed successfully")
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
package tokenmanager
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"codeberg.org/mapleopentech/monorepo/cloud/maplefile-backend/pkg/maplefile/client"
|
||||
"codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/service/auth"
|
||||
"codeberg.org/mapleopentech/monorepo/native/desktop/maplefile/internal/usecase/session"
|
||||
)
|
||||
|
||||
// ProvideManager creates the token manager for Wire
|
||||
func ProvideManager(
|
||||
client *client.Client,
|
||||
authService *auth.Service,
|
||||
getSession *session.GetByIdUseCase,
|
||||
logger *zap.Logger,
|
||||
) *Manager {
|
||||
config := DefaultConfig()
|
||||
return New(config, client, authService, getSession, logger)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue