Files
sendico/api/pkg/db/internal/mongo/refreshtokensdb/crud.go
Stephan D 717dafc673
Some checks failed
ci/woodpecker/push/billing_fees Pipeline was successful
ci/woodpecker/push/bff Pipeline was successful
ci/woodpecker/push/db Pipeline was successful
ci/woodpecker/push/chain_gateway Pipeline was successful
ci/woodpecker/push/fx_ingestor Pipeline was successful
ci/woodpecker/push/fx_oracle Pipeline was successful
ci/woodpecker/push/frontend Pipeline was successful
ci/woodpecker/push/payments_orchestrator Pipeline was successful
ci/woodpecker/push/bump_version Pipeline failed
ci/woodpecker/push/nats Pipeline was successful
ci/woodpecker/push/ledger Pipeline was successful
ci/woodpecker/push/notification Pipeline was successful
better message formatting
2025-11-19 13:54:25 +01:00

123 lines
4.7 KiB
Go

package refreshtokensdb
import (
"context"
"errors"
"time"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// Create stores rt as the refresh token for its account and session.
//
// If a token already exists for the same account/client/device combination
// (as selected by filterByAccount), that document is patched in place with
// rt's token, expiry, user agent, IP address, last-used time and revocation
// flag. rt's ID is then set to the existing document's ID. Otherwise rt is
// inserted as a new document.
//
// It returns an invalid-argument error when rt or rt.AccountRef is nil. It
// propagates any lookup, insert or patch error.
func (db *RefreshTokenDB) Create(ctx context.Context, rt *model.RefreshToken) error {
	if rt == nil {
		return merrors.InvalidArgument("Refresh token must have a valid value", "refreshToken")
	}
	if rt.AccountRef == nil {
		return merrors.InvalidArgument("Account reference must have a valid value", "refreshToken.accountRef")
	}

	// Look for a token already registered for this account/client/device
	// combination, so it is updated rather than a second document inserted.
	var existing model.RefreshToken
	if err := db.FindOne(ctx, filterByAccount(*rt.AccountRef, &rt.SessionIdentifier), &existing); err != nil {
		if !errors.Is(err, merrors.ErrNoData) {
			db.Logger.Warn("Something went wrong when checking existing sessions", zap.Error(err),
				zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
			return err
		}
		// No existing token: insert a new one.
		db.Logger.Info("Registering refresh token", zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
		return db.DBImp.Create(ctx, rt)
	}

	// A token already exists: overwrite its mutable fields with the new values.
	db.Logger.Info("Updating existing refresh token", zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
	patch := repository.Patch().
		Set(repository.Field(TokenField), rt.RefreshToken).
		Set(repository.Field(ExpiresAtField), rt.ExpiresAt).
		Set(repository.Field(UserAgentField), rt.UserAgent).
		Set(repository.Field(IPAddressField), rt.IPAddress).
		Set(repository.Field(LastUsedAtField), rt.LastUsedAt).
		Set(repository.Field(IsRevokedField), rt.IsRevoked)
	existingID := *existing.GetID()
	if err := db.Patch(ctx, existingID, patch); err != nil {
		db.Logger.Warn("Failed to patch refresh token", zap.Error(err), zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
		return err
	}
	// Point the caller's token at the document that was actually updated.
	rt.SetID(existingID)
	return nil
}
// Update writes rt's token, expiry, user agent, IP address, last-used time and
// revocation flag onto the stored document identified by rt's ID.
//
// Side effect: rt.LastUsedAt is overwritten with the current time before the
// write. Only the listed fields are patched. A field-level patch is used instead
// of a whole-document update to avoid race conditions with other writers.
func (db *RefreshTokenDB) Update(ctx context.Context, rt *model.RefreshToken) error {
	now := time.Now()
	rt.LastUsedAt = now

	changes := repository.Patch().
		Set(repository.Field(TokenField), rt.RefreshToken).
		Set(repository.Field(ExpiresAtField), rt.ExpiresAt).
		Set(repository.Field(UserAgentField), rt.UserAgent).
		Set(repository.Field(IPAddressField), rt.IPAddress).
		Set(repository.Field(LastUsedAtField), now).
		Set(repository.Field(IsRevokedField), rt.IsRevoked)

	tokenID := *rt.GetID()
	return db.Patch(ctx, tokenID, changes)
}
// Delete logs and then removes the refresh token document whose ID is tokenRef,
// returning whatever error the underlying store reports.
func (db *RefreshTokenDB) Delete(ctx context.Context, tokenRef primitive.ObjectID) error {
	refField := mzap.ObjRef("refresh_token_ref", tokenRef)
	db.Logger.Info("Deleting refresh token", refField)

	return db.DBImp.Delete(ctx, tokenRef)
}
// Revoke marks the refresh token matching accountRef and session as revoked
// and sets its last-used time to now.
//
// A missing token is not an error: it is logged and nil is returned. Any
// other lookup error is logged and returned, as is any patch error.
//
// session is a pointer and is read nil-safely for logging. NOTE(review):
// whether filterByAccount itself accepts a nil session is not visible here —
// confirm before relying on it.
func (db *RefreshTokenDB) Revoke(ctx context.Context, accountRef primitive.ObjectID, session *model.SessionIdentifier) error {
	var rt model.RefreshToken
	if err := db.Repository.FindOneByFilter(ctx, filterByAccount(accountRef, session), &rt); err != nil {
		// Read session fields defensively so logging can never panic on a nil session.
		var clientID, deviceID string
		if session != nil {
			clientID, deviceID = session.ClientID, session.DeviceID
		}
		if errors.Is(err, merrors.ErrNoData) {
			db.Logger.Warn("Failed to find refresh token", zap.Error(err),
				mzap.ObjRef("account_ref", accountRef), zap.String("client_id", clientID), zap.String("device_id", deviceID))
			return nil
		}
		db.Logger.Warn("Failed to look up refresh token for revocation", zap.Error(err),
			mzap.ObjRef("account_ref", accountRef), zap.String("client_id", clientID), zap.String("device_id", deviceID))
		return err
	}

	// Patch only the revocation status and the last-used timestamp.
	patch := repository.Patch().
		Set(repository.Field(IsRevokedField), true).
		Set(repository.Field(LastUsedAtField), time.Now())
	return db.Patch(ctx, *rt.GetID(), patch)
}
// GetByCRT finds the stored refresh token that matches the client-presented
// token t. The match uses t's session identifier and raw token value. The
// stored token is returned only if it is neither expired nor revoked.
//
// Errors:
//   - an invalid-argument error when t is nil;
//   - the lookup error unchanged (merrors.ErrNoData when nothing matches);
//     lookup errors other than ErrNoData are also logged;
//   - an access-denied error when the stored token has expired;
//   - merrors.ErrNoData when the stored token is revoked (the same error as
//     a missing token).
func (db *RefreshTokenDB) GetByCRT(ctx context.Context, t *model.ClientRefreshToken) (*model.RefreshToken, error) {
	if t == nil {
		return nil, merrors.InvalidArgument("Client refresh token must have a valid value", "clientRefreshToken")
	}

	var rt model.RefreshToken
	// NOTE(review): the token value is matched on the literal "token" field,
	// while Create/Update write it through TokenField — confirm both name the
	// same document field.
	f := filter(&t.SessionIdentifier).And(repository.Query().Filter(repository.Field("token"), t.RefreshToken))
	if err := db.Repository.FindOneByFilter(ctx, f, &rt); err != nil {
		if !errors.Is(err, merrors.ErrNoData) {
			db.Logger.Warn("Failed to fetch refresh token", zap.Error(err),
				zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID))
		}
		return nil, err
	}

	// Reject tokens whose expiry time is already in the past.
	if rt.ExpiresAt.Before(time.Now()) {
		db.Logger.Warn("Refresh token expired", mzap.StorableRef(&rt),
			zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID),
			zap.Time("expires_at", rt.ExpiresAt))
		return nil, merrors.AccessDenied(mservice.RefreshTokens, string(model.ActionRead), *rt.GetID())
	}

	// Revoked tokens are reported with the same error as a missing token.
	if rt.IsRevoked {
		db.Logger.Warn("Refresh token is revoked", mzap.StorableRef(&rt),
			zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID))
		return nil, merrors.ErrNoData
	}

	return &rt, nil
}