service backend
This commit is contained in:
12
api/pkg/db/internal/mongo/refreshtokensdb/client.go
Normal file
12
api/pkg/db/internal/mongo/refreshtokensdb/client.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package refreshtokensdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
)
|
||||
|
||||
func (db *RefreshTokenDB) GetClient(ctx context.Context, clientID string) (*model.Client, error) {
|
||||
var client model.Client
|
||||
return &client, db.clients.FindOneByFilter(ctx, filterByClientId(clientID), &client)
|
||||
}
|
||||
122
api/pkg/db/internal/mongo/refreshtokensdb/crud.go
Normal file
122
api/pkg/db/internal/mongo/refreshtokensdb/crud.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package refreshtokensdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (db *RefreshTokenDB) Create(ctx context.Context, rt *model.RefreshToken) error {
|
||||
// First, try to find an existing token for this account/client/device combination
|
||||
var existing model.RefreshToken
|
||||
if rt.AccountRef == nil {
|
||||
return merrors.InvalidArgument("Account reference must have a vaild value")
|
||||
}
|
||||
if err := db.FindOne(ctx, filterByAccount(*rt.AccountRef, &rt.SessionIdentifier), &existing); err != nil {
|
||||
if errors.Is(err, merrors.ErrNoData) {
|
||||
// No existing token, create a new one
|
||||
db.Logger.Info("Registering refresh token", zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
|
||||
return db.DBImp.Create(ctx, rt)
|
||||
}
|
||||
db.Logger.Warn("Something went wrong when checking existing sessions", zap.Error(err),
|
||||
zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
|
||||
return err
|
||||
}
|
||||
|
||||
// Token already exists, update it with new values
|
||||
db.Logger.Info("Updating existing refresh token", zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
|
||||
|
||||
patch := repository.Patch().
|
||||
Set(repository.Field(TokenField), rt.RefreshToken).
|
||||
Set(repository.Field(ExpiresAtField), rt.ExpiresAt).
|
||||
Set(repository.Field(UserAgentField), rt.UserAgent).
|
||||
Set(repository.Field(IPAddressField), rt.IPAddress).
|
||||
Set(repository.Field(LastUsedAtField), rt.LastUsedAt).
|
||||
Set(repository.Field(IsRevokedField), rt.IsRevoked)
|
||||
|
||||
if err := db.Patch(ctx, *existing.GetID(), patch); err != nil {
|
||||
db.Logger.Warn("Failed to patch refresh token", zap.Error(err), zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the ID of the input token to match the existing one
|
||||
rt.SetID(*existing.GetID())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *RefreshTokenDB) Update(ctx context.Context, rt *model.RefreshToken) error {
|
||||
rt.LastUsedAt = time.Now()
|
||||
|
||||
// Use Patch instead of Update to avoid race conditions
|
||||
patch := repository.Patch().
|
||||
Set(repository.Field(TokenField), rt.RefreshToken).
|
||||
Set(repository.Field(ExpiresAtField), rt.ExpiresAt).
|
||||
Set(repository.Field(UserAgentField), rt.UserAgent).
|
||||
Set(repository.Field(IPAddressField), rt.IPAddress).
|
||||
Set(repository.Field(LastUsedAtField), rt.LastUsedAt).
|
||||
Set(repository.Field(IsRevokedField), rt.IsRevoked)
|
||||
|
||||
return db.Patch(ctx, *rt.GetID(), patch)
|
||||
}
|
||||
|
||||
func (db *RefreshTokenDB) Delete(ctx context.Context, tokenRef primitive.ObjectID) error {
|
||||
db.Logger.Info("Deleting refresh token", mzap.ObjRef("refresh_token_ref", tokenRef))
|
||||
return db.DBImp.Delete(ctx, tokenRef)
|
||||
}
|
||||
|
||||
func (db *RefreshTokenDB) Revoke(ctx context.Context, accountRef primitive.ObjectID, session *model.SessionIdentifier) error {
|
||||
var rt model.RefreshToken
|
||||
f := filterByAccount(accountRef, session)
|
||||
if err := db.Repository.FindOneByFilter(ctx, f, &rt); err != nil {
|
||||
if errors.Is(err, merrors.ErrNoData) {
|
||||
db.Logger.Warn("Failed to find refresh token", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), zap.String("client_id", session.ClientID), zap.String("device_id", session.DeviceID))
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Use Patch to update the revocation status
|
||||
patch := repository.Patch().
|
||||
Set(repository.Field(IsRevokedField), true).
|
||||
Set(repository.Field(LastUsedAtField), time.Now())
|
||||
|
||||
return db.Patch(ctx, *rt.GetID(), patch)
|
||||
}
|
||||
|
||||
func (db *RefreshTokenDB) GetByCRT(ctx context.Context, t *model.ClientRefreshToken) (*model.RefreshToken, error) {
|
||||
var rt model.RefreshToken
|
||||
f := filter(&t.SessionIdentifier).And(repository.Query().Filter(repository.Field("token"), t.RefreshToken))
|
||||
if err := db.Repository.FindOneByFilter(ctx, f, &rt); err != nil {
|
||||
if !errors.Is(err, merrors.ErrNoData) {
|
||||
db.Logger.Warn("Failed to fetch refresh token", zap.Error(err),
|
||||
zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if token is expired
|
||||
if rt.ExpiresAt.Before(time.Now()) {
|
||||
db.Logger.Warn("Refresh token expired", mzap.StorableRef(&rt),
|
||||
zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID),
|
||||
zap.Time("expires_at", rt.ExpiresAt))
|
||||
return nil, merrors.AccessDenied(mservice.RefreshTokens, string(model.ActionRead), *rt.GetID())
|
||||
}
|
||||
|
||||
// Check if token is revoked
|
||||
if rt.IsRevoked {
|
||||
db.Logger.Warn("Refresh token is revoked", mzap.StorableRef(&rt),
|
||||
zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID))
|
||||
return nil, merrors.ErrNoData
|
||||
}
|
||||
|
||||
return &rt, nil
|
||||
}
|
||||
62
api/pkg/db/internal/mongo/refreshtokensdb/db.go
Normal file
62
api/pkg/db/internal/mongo/refreshtokensdb/db.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package refreshtokensdb
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
ri "github.com/tech/sendico/pkg/db/repository/index"
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type RefreshTokenDB struct {
|
||||
template.DBImp[*model.RefreshToken]
|
||||
clients repository.Repository
|
||||
}
|
||||
|
||||
func Create(logger mlogger.Logger, db *mongo.Database) (*RefreshTokenDB, error) {
|
||||
p := &RefreshTokenDB{
|
||||
DBImp: *template.Create[*model.RefreshToken](logger, mservice.RefreshTokens, db),
|
||||
clients: repository.CreateMongoRepository(db, mservice.Clients),
|
||||
}
|
||||
|
||||
if err := p.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: "token", Sort: ri.Asc}},
|
||||
Unique: true,
|
||||
}); err != nil {
|
||||
p.Logger.Error("Failed to create unique token index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add unique constraint on account/client/device combination
|
||||
if err := p.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{
|
||||
{Field: "accountRef", Sort: ri.Asc},
|
||||
{Field: "clientId", Sort: ri.Asc},
|
||||
{Field: "deviceId", Sort: ri.Asc},
|
||||
},
|
||||
Unique: true,
|
||||
}); err != nil {
|
||||
p.Logger.Error("Failed to create unique account/client/device index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: IsRevokedField, Sort: ri.Asc}},
|
||||
}); err != nil {
|
||||
p.Logger.Error("Failed to create unique token revokation status index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.clients.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: "clientId", Sort: ri.Asc}},
|
||||
Unique: true,
|
||||
}); err != nil {
|
||||
p.Logger.Error("Failed to create unique client identifier index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
10
api/pkg/db/internal/mongo/refreshtokensdb/fields.go
Normal file
10
api/pkg/db/internal/mongo/refreshtokensdb/fields.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package refreshtokensdb
|
||||
|
||||
const (
|
||||
ExpiresAtField = "expiresAt"
|
||||
IsRevokedField = "isRevoked"
|
||||
TokenField = "token"
|
||||
UserAgentField = "userAgent"
|
||||
IPAddressField = "ipAddress"
|
||||
LastUsedAtField = "lastUsedAt"
|
||||
)
|
||||
25
api/pkg/db/internal/mongo/refreshtokensdb/filters.go
Normal file
25
api/pkg/db/internal/mongo/refreshtokensdb/filters.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package refreshtokensdb
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func filterByClientId(clientID string) builder.Query {
|
||||
return repository.Query().Comparison(repository.Field("clientId"), builder.Eq, clientID)
|
||||
}
|
||||
|
||||
func filter(session *model.SessionIdentifier) builder.Query {
|
||||
filter := filterByClientId(session.ClientID)
|
||||
filter.And(
|
||||
repository.Query().Comparison(repository.Field("deviceId"), builder.Eq, session.DeviceID),
|
||||
repository.Query().Comparison(repository.Field(IsRevokedField), builder.Eq, false),
|
||||
)
|
||||
return filter
|
||||
}
|
||||
|
||||
func filterByAccount(accountRef primitive.ObjectID, session *model.SessionIdentifier) builder.Query {
|
||||
return filter(session).And(repository.Query().Comparison(repository.AccountField(), builder.Eq, accountRef))
|
||||
}
|
||||
@@ -0,0 +1,639 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package refreshtokensdb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/refreshtokensdb"
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
factory "github.com/tech/sendico/pkg/mlogger/factory"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) (*refreshtokensdb.RefreshTokenDB, func()) {
|
||||
// mark as helper for better test failure reporting
|
||||
t.Helper()
|
||||
|
||||
startCtx, startCancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer startCancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(startCtx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForListeningPort("27017/tcp").WithStartupTimeout(2*time.Minute)),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(startCtx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(startCtx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
|
||||
database := client.Database("test_refresh_tokens_" + t.Name())
|
||||
logger := factory.NewLogger(true)
|
||||
|
||||
db, err := refreshtokensdb.Create(logger, database)
|
||||
require.NoError(t, err, "failed to create refresh tokens db")
|
||||
|
||||
cleanup := func() {
|
||||
_ = database.Drop(context.Background())
|
||||
_ = client.Disconnect(context.Background())
|
||||
termCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
_ = mongoContainer.Terminate(termCtx)
|
||||
}
|
||||
|
||||
return db, cleanup
|
||||
}
|
||||
|
||||
func createTestRefreshToken(accountRef primitive.ObjectID, clientID, deviceID, token string) *model.RefreshToken {
|
||||
return &model.RefreshToken{
|
||||
ClientRefreshToken: model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
},
|
||||
AccountBoundBase: model.AccountBoundBase{
|
||||
AccountRef: &accountRef,
|
||||
},
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
IsRevoked: false,
|
||||
UserAgent: "TestUserAgent/1.0",
|
||||
IPAddress: "192.168.1.1",
|
||||
LastUsedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_AuthenticationFlow(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Complete_User_Authentication_Flow", func(t *testing.T) {
|
||||
// Setup: Create user and client
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-desktop-chrome"
|
||||
token := "refresh_token_12345"
|
||||
|
||||
// Step 1: User logs in - create initial refresh token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 2: User uses refresh token to get new access token
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
}
|
||||
|
||||
retrievedToken, err := db.GetByCRT(ctx, crt)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, retrievedToken.AccountRef)
|
||||
assert.Equal(t, userID, *retrievedToken.AccountRef)
|
||||
assert.Equal(t, token, retrievedToken.RefreshToken)
|
||||
assert.False(t, retrievedToken.IsRevoked)
|
||||
|
||||
// Step 3: User logs out - revoke the token
|
||||
session := &model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
}
|
||||
err = db.Revoke(ctx, userID, session)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 4: Try to use revoked token - should fail
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("Manual_Token_Revocation_Workaround", func(t *testing.T) {
|
||||
// Test manual revocation by directly updating the token
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-desktop-chrome"
|
||||
token := "manual_revoke_token_123"
|
||||
|
||||
// Step 1: Create token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 2: Manually revoke token by updating it directly
|
||||
refreshToken.IsRevoked = true
|
||||
err = db.Update(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 3: Try to use revoked token - should fail
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
}
|
||||
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_MultiDeviceManagement(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("User_With_Multiple_Devices", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "mobile-app"
|
||||
|
||||
// User logs in from phone
|
||||
phoneToken := createTestRefreshToken(userID, clientID, "phone-ios", "phone_token_123")
|
||||
err := db.Create(ctx, phoneToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// User logs in from tablet
|
||||
tabletToken := createTestRefreshToken(userID, clientID, "tablet-android", "tablet_token_456")
|
||||
err = db.Create(ctx, tabletToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// User logs in from desktop
|
||||
desktopToken := createTestRefreshToken(userID, clientID, "desktop-windows", "desktop_token_789")
|
||||
err = db.Create(ctx, desktopToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// User wants to logout from all devices except current (phone)
|
||||
err = db.RevokeAll(ctx, userID, "phone-ios")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Phone should still work
|
||||
phoneCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: "phone-ios",
|
||||
},
|
||||
RefreshToken: "phone_token_123",
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, phoneCRT)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Tablet and desktop should be revoked
|
||||
tabletCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: "tablet-android",
|
||||
},
|
||||
RefreshToken: "tablet_token_456",
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, tabletCRT)
|
||||
assert.Error(t, err)
|
||||
|
||||
desktopCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: "desktop-windows",
|
||||
},
|
||||
RefreshToken: "desktop_token_789",
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, desktopCRT)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_TokenRotation(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Token_Rotation_On_Use", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-browser"
|
||||
initialToken := "initial_token_123"
|
||||
|
||||
// Create initial token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, initialToken)
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate small delay
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Use token - should update LastUsedAt
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: initialToken,
|
||||
}
|
||||
|
||||
retrievedToken, err := db.GetByCRT(ctx, crt)
|
||||
require.NoError(t, err)
|
||||
// LastUsedAt is not updated by GetByCRT; validate token data instead
|
||||
assert.Equal(t, initialToken, retrievedToken.RefreshToken)
|
||||
|
||||
// Create new token with rotated value (simulating token rotation)
|
||||
newToken := "rotated_token_456"
|
||||
retrievedToken.RefreshToken = newToken
|
||||
err = db.Update(ctx, retrievedToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Old token should no longer work
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
|
||||
// New token should work
|
||||
newCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: newToken,
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, newCRT)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_SessionReplacement(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("User_Login_From_Same_Device_Twice", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-laptop"
|
||||
|
||||
// First login
|
||||
firstToken := createTestRefreshToken(userID, clientID, deviceID, "first_token_123")
|
||||
err := db.Create(ctx, firstToken)
|
||||
require.NoError(t, err)
|
||||
firstTokenID := *firstToken.GetID()
|
||||
|
||||
// Second login from same device - should replace existing token
|
||||
secondToken := createTestRefreshToken(userID, clientID, deviceID, "second_token_456")
|
||||
err = db.Create(ctx, secondToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should reuse the same database record
|
||||
assert.Equal(t, firstTokenID, *secondToken.GetID())
|
||||
|
||||
// First token should no longer work
|
||||
firstCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: "first_token_123",
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, firstCRT)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Second token should work
|
||||
secondCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: "second_token_456",
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, secondCRT)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_ClientManagement(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Client_CRUD_Operations", func(t *testing.T) {
|
||||
// Note: Client management is handled by a separate client database
|
||||
// This test verifies that refresh tokens work with different client IDs
|
||||
|
||||
userID := primitive.NewObjectID()
|
||||
|
||||
// Create refresh tokens for different clients
|
||||
webToken := createTestRefreshToken(userID, "web-app", "device1", "token1")
|
||||
err := db.Create(ctx, webToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
mobileToken := createTestRefreshToken(userID, "mobile-app", "device2", "token2")
|
||||
err = db.Create(ctx, mobileToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify tokens can be retrieved by client ID
|
||||
webCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: "web-app",
|
||||
DeviceID: "device1",
|
||||
},
|
||||
RefreshToken: "token1",
|
||||
}
|
||||
|
||||
retrievedToken, err := db.GetByCRT(ctx, webCRT)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "web-app", retrievedToken.ClientID)
|
||||
assert.Equal(t, "device1", retrievedToken.DeviceID)
|
||||
|
||||
mobileCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: "mobile-app",
|
||||
DeviceID: "device2",
|
||||
},
|
||||
RefreshToken: "token2",
|
||||
}
|
||||
|
||||
retrievedToken, err = db.GetByCRT(ctx, mobileCRT)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "mobile-app", retrievedToken.ClientID)
|
||||
assert.Equal(t, "device2", retrievedToken.DeviceID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_SecurityScenarios(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Token_Hijacking_Prevention", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-browser"
|
||||
token := "hijacked_token_123"
|
||||
|
||||
// Create legitimate token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate security concern - revoke token
|
||||
session := &model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
}
|
||||
err = db.Revoke(ctx, userID, session)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attacker tries to use hijacked token
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
}
|
||||
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("Invalid_Token_Attempts", func(t *testing.T) {
|
||||
// Try to use completely invalid token
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: "invalid-client",
|
||||
DeviceID: "invalid-device",
|
||||
},
|
||||
RefreshToken: "invalid_token_123",
|
||||
}
|
||||
|
||||
_, err := db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_ExpiredTokenHandling(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Expired_Token_Cleanup", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-device"
|
||||
token := "expired_token_123"
|
||||
|
||||
// Create token that expires in the past
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
|
||||
refreshToken.ExpiresAt = time.Now().Add(-1 * time.Hour) // Expired 1 hour ago
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The token exists in database but is expired
|
||||
var storedToken model.RefreshToken
|
||||
err = db.Get(ctx, *refreshToken.GetID(), &storedToken)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, storedToken.ExpiresAt.Before(time.Now()))
|
||||
|
||||
// Application should reject expired tokens
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
}
|
||||
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrAccessDenied))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_ConcurrentAccess(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Concurrent_Token_Usage", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-device"
|
||||
token := "concurrent_token_123"
|
||||
|
||||
// Create token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
}
|
||||
|
||||
// Simulate concurrent access
|
||||
done := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
_, err := db.GetByCRT(ctx, crt)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := db.GetByCRT(ctx, crt)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
// Both operations should succeed
|
||||
for i := 0; i < 2; i++ {
|
||||
err := <-done
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_EdgeCases(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Delete_Token_By_ID", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
refreshToken := createTestRefreshToken(userID, "web-app", "device-1", "token_123")
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenID := *refreshToken.GetID()
|
||||
|
||||
// Delete token
|
||||
err = db.Delete(ctx, tokenID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Token should no longer exist
|
||||
var result model.RefreshToken
|
||||
err = db.Get(ctx, tokenID, &result)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("Revoke_Non_Existent_Token", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
session := &model.SessionIdentifier{
|
||||
ClientID: "non-existent-client",
|
||||
DeviceID: "non-existent-device",
|
||||
}
|
||||
|
||||
err := db.Revoke(ctx, userID, session)
|
||||
// Should handle gracefully for non-existent tokens
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("RevokeAll_No_Other_Devices", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "only-device"
|
||||
|
||||
// Create single token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, "token_123")
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// RevokeAll should not affect current device
|
||||
err = db.RevokeAll(ctx, userID, deviceID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Token should still work
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: "token_123",
|
||||
}
|
||||
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_DatabaseIndexes(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Unique_Token_Constraint", func(t *testing.T) {
|
||||
userID1 := primitive.NewObjectID()
|
||||
userID2 := primitive.NewObjectID()
|
||||
token := "duplicate_token_123"
|
||||
|
||||
// Create first token
|
||||
refreshToken1 := createTestRefreshToken(userID1, "client1", "device1", token)
|
||||
err := db.Create(ctx, refreshToken1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to create second token with same token value - should fail due to unique index
|
||||
refreshToken2 := createTestRefreshToken(userID2, "client2", "device2", token)
|
||||
err = db.Create(ctx, refreshToken2)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "duplicate")
|
||||
})
|
||||
|
||||
t.Run("Query_Performance_By_Revocation_Status", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
|
||||
// Create multiple tokens
|
||||
for i := 0; i < 10; i++ {
|
||||
token := createTestRefreshToken(userID, clientID,
|
||||
fmt.Sprintf("device_%d", i), fmt.Sprintf("token_%d", i))
|
||||
if i%2 == 0 {
|
||||
token.IsRevoked = true
|
||||
}
|
||||
err := db.Create(ctx, token)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Query should efficiently filter by revocation status
|
||||
query := repository.Query().
|
||||
Filter(repository.AccountField(), userID).
|
||||
And(repository.Query().Comparison(repository.Field(refreshtokensdb.IsRevokedField), builder.Eq, false))
|
||||
|
||||
ids, err := db.ListIDs(ctx, query)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ids, 5) // Should find 5 non-revoked tokens
|
||||
})
|
||||
}
|
||||
24
api/pkg/db/internal/mongo/refreshtokensdb/revoke.go
Normal file
24
api/pkg/db/internal/mongo/refreshtokensdb/revoke.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package refreshtokensdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *RefreshTokenDB) RevokeAll(ctx context.Context, accountRef primitive.ObjectID, deviceID string) error {
|
||||
query := repository.Query().
|
||||
Filter(repository.AccountField(), accountRef).
|
||||
And(repository.Query().Comparison(repository.Field("deviceId"), builder.Ne, deviceID)).
|
||||
And(repository.Query().Comparison(repository.Field(IsRevokedField), builder.Eq, false))
|
||||
|
||||
patch := repository.Patch().
|
||||
Set(repository.Field(ExpiresAtField), time.Now()).
|
||||
Set(repository.Field(IsRevokedField), true)
|
||||
|
||||
_, err := db.Repository.PatchMany(ctx, query, patch)
|
||||
return err
|
||||
}
|
||||
Reference in New Issue
Block a user