package refreshtokensdb import ( "context" "errors" "time" "github.com/tech/sendico/pkg/db/repository" "github.com/tech/sendico/pkg/merrors" "github.com/tech/sendico/pkg/model" "github.com/tech/sendico/pkg/mservice" "github.com/tech/sendico/pkg/mutil/mzap" "go.mongodb.org/mongo-driver/bson/primitive" "go.uber.org/zap" ) func (db *RefreshTokenDB) Create(ctx context.Context, rt *model.RefreshToken) error { // First, try to find an existing token for this account/client/device combination var existing model.RefreshToken if rt.AccountRef == nil { return merrors.InvalidArgument("Account reference must have a vaild value") } if err := db.FindOne(ctx, filterByAccount(*rt.AccountRef, &rt.SessionIdentifier), &existing); err != nil { if errors.Is(err, merrors.ErrNoData) { // No existing token, create a new one db.Logger.Info("Registering refresh token", zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID)) return db.DBImp.Create(ctx, rt) } db.Logger.Warn("Something went wrong when checking existing sessions", zap.Error(err), zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID)) return err } // Token already exists, update it with new values db.Logger.Info("Updating existing refresh token", zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID)) patch := repository.Patch(). Set(repository.Field(TokenField), rt.RefreshToken). Set(repository.Field(ExpiresAtField), rt.ExpiresAt). Set(repository.Field(UserAgentField), rt.UserAgent). Set(repository.Field(IPAddressField), rt.IPAddress). Set(repository.Field(LastUsedAtField), rt.LastUsedAt). Set(repository.Field(IsRevokedField), rt.IsRevoked) if err := db.Patch(ctx, *existing.GetID(), patch); err != nil { db.Logger.Warn("Failed to patch refresh token", zap.Error(err), zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID)) return err } // Update the ID of the input token to match the existing one rt.SetID(*existing.GetID()) return nil } func (db *RefreshTokenDB) Update(ctx context.Context, rt *model.RefreshToken) error { rt.LastUsedAt = time.Now() // Use Patch instead of Update to avoid race conditions patch := repository.Patch(). Set(repository.Field(TokenField), rt.RefreshToken). Set(repository.Field(ExpiresAtField), rt.ExpiresAt). Set(repository.Field(UserAgentField), rt.UserAgent). Set(repository.Field(IPAddressField), rt.IPAddress). Set(repository.Field(LastUsedAtField), rt.LastUsedAt). Set(repository.Field(IsRevokedField), rt.IsRevoked) return db.Patch(ctx, *rt.GetID(), patch) } func (db *RefreshTokenDB) Delete(ctx context.Context, tokenRef primitive.ObjectID) error { db.Logger.Info("Deleting refresh token", mzap.ObjRef("refresh_token_ref", tokenRef)) return db.DBImp.Delete(ctx, tokenRef) } func (db *RefreshTokenDB) Revoke(ctx context.Context, accountRef primitive.ObjectID, session *model.SessionIdentifier) error { var rt model.RefreshToken f := filterByAccount(accountRef, session) if err := db.Repository.FindOneByFilter(ctx, f, &rt); err != nil { if errors.Is(err, merrors.ErrNoData) { db.Logger.Warn("Failed to find refresh token", zap.Error(err), mzap.ObjRef("account_ref", accountRef), zap.String("client_id", session.ClientID), zap.String("device_id", session.DeviceID)) return nil } return err } // Use Patch to update the revocation status patch := repository.Patch(). Set(repository.Field(IsRevokedField), true). Set(repository.Field(LastUsedAtField), time.Now()) return db.Patch(ctx, *rt.GetID(), patch) } func (db *RefreshTokenDB) GetByCRT(ctx context.Context, t *model.ClientRefreshToken) (*model.RefreshToken, error) { var rt model.RefreshToken f := filter(&t.SessionIdentifier).And(repository.Query().Filter(repository.Field("token"), t.RefreshToken)) if err := db.Repository.FindOneByFilter(ctx, f, &rt); err != nil { if !errors.Is(err, merrors.ErrNoData) { db.Logger.Warn("Failed to fetch refresh token", zap.Error(err), zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID)) } return nil, err } // Check if token is expired if rt.ExpiresAt.Before(time.Now()) { db.Logger.Warn("Refresh token expired", mzap.StorableRef(&rt), zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID), zap.Time("expires_at", rt.ExpiresAt)) return nil, merrors.AccessDenied(mservice.RefreshTokens, string(model.ActionRead), *rt.GetID()) } // Check if token is revoked if rt.IsRevoked { db.Logger.Warn("Refresh token is revoked", mzap.StorableRef(&rt), zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID)) return nil, merrors.ErrNoData } return &rt, nil }