Files
sendico/api/pkg/db/internal/mongo/refreshtokensdb/refreshtokensdb_test.go

729 lines
20 KiB
Go

//go:build integration
// +build integration
package refreshtokensdb_test
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tech/sendico/pkg/db/internal/mongo/refreshtokensdb"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/merrors"
factory "github.com/tech/sendico/pkg/mlogger/factory"
"github.com/tech/sendico/pkg/model"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/mongodb"
"github.com/testcontainers/testcontainers-go/wait"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// setupTestDB starts a dedicated MongoDB container and returns a RefreshTokenDB
// backed by it, together with a cleanup func that tears everything down.
// Tests that also need the underlying *mongo.Database should call
// setupTestDBWithMongo instead.
func setupTestDB(t *testing.T) (*refreshtokensdb.RefreshTokenDB, func()) {
	// Report setup failures at the calling test's line, not inside this wrapper.
	t.Helper()
	db, _, cleanup := setupTestDBWithMongo(t)
	return db, cleanup
}
// setupTestDBWithMongo starts a dedicated MongoDB container, connects to it and
// creates a RefreshTokenDB on a fresh database inside it.
//
// It returns the store, the underlying *mongo.Database (for tests that inspect
// the collection directly, e.g. its indexes) and a cleanup func. The cleanup
// func drops the database, disconnects the client and terminates the container.
// If a setup step fails, the resources acquired so far are released before the
// test is aborted, so a failing setup does not leak a running container.
func setupTestDBWithMongo(t *testing.T) (*refreshtokensdb.RefreshTokenDB, *mongo.Database, func()) {
	// mark as helper for better test failure reporting
	t.Helper()

	// Each call runs against its own container, so a fixed database name cannot
	// collide between tests. Deriving the name from t.Name() exceeded MongoDB's
	// limit (database names must be shorter than 64 characters) for long test
	// names, and subtest names contain "/", which MongoDB rejects.
	const dbName = "test_refresh_tokens"

	startCtx, startCancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer startCancel()

	// NOTE(review): "mongo:latest" makes runs non-reproducible; consider pinning a tag.
	mongoContainer, err := mongodb.Run(startCtx,
		"mongo:latest",
		mongodb.WithUsername("root"),
		mongodb.WithPassword("password"),
		testcontainers.WithWaitStrategy(wait.ForListeningPort("27017/tcp").WithStartupTimeout(2*time.Minute)),
	)
	// terminate is best-effort: a failed terminate must not mask the test outcome.
	terminate := func() {
		if mongoContainer == nil {
			return
		}
		termCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		_ = mongoContainer.Terminate(termCtx)
	}
	if err != nil {
		// Run may hand back a non-nil container alongside the error; release it too.
		terminate()
		require.NoError(t, err, "failed to start MongoDB container")
	}

	mongoURI, err := mongoContainer.ConnectionString(startCtx)
	if err != nil {
		terminate()
		require.NoError(t, err, "failed to get MongoDB connection string")
	}

	client, err := mongo.Connect(startCtx, options.Client().ApplyURI(mongoURI))
	if err != nil {
		terminate()
		require.NoError(t, err, "failed to connect to MongoDB")
	}
	database := client.Database(dbName)

	// dropAndDisconnect is best-effort and time-bounded so teardown cannot hang.
	dropAndDisconnect := func() {
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		_ = database.Drop(ctx)
		_ = client.Disconnect(ctx)
	}

	logger := factory.NewLogger(true)
	db, err := refreshtokensdb.Create(logger, database)
	if err != nil {
		dropAndDisconnect()
		terminate()
		require.NoError(t, err, "failed to create refresh tokens db")
	}

	cleanup := func() {
		dropAndDisconnect()
		terminate()
	}
	return db, database, cleanup
}
// createTestRefreshToken builds an unsaved, non-revoked refresh token for the
// given account and session (client + device). The token expires 24h from now.
// Its user agent, IP address and LastUsedAt carry fixed test values.
// The returned value is not persisted; callers pass it to db.Create themselves.
func createTestRefreshToken(accountRef primitive.ObjectID, clientID, deviceID, token string) *model.RefreshToken {
	rt := &model.RefreshToken{
		ExpiresAt:  time.Now().Add(24 * time.Hour),
		IsRevoked:  false,
		UserAgent:  "TestUserAgent/1.0",
		IPAddress:  "192.168.1.1",
		LastUsedAt: time.Now(),
	}
	// Session identity and credential come from the embedded ClientRefreshToken.
	rt.ClientID = clientID
	rt.DeviceID = deviceID
	rt.RefreshToken = token
	// Owner comes from the embedded AccountBoundBase; points at this call's copy.
	rt.AccountRef = &accountRef
	return rt
}
// TestRefreshTokenDB_AuthenticationFlow walks a refresh token through its
// normal lifecycle: it is issued on login, exchanged on refresh, and revoked
// on logout. Once revoked, whether through Revoke or by persisting
// IsRevoked=true with Update, it must no longer resolve via GetByCRT.
func TestRefreshTokenDB_AuthenticationFlow(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()

	// lookupKey builds the credential a client presents to exchange a token.
	lookupKey := func(clientID, deviceID, token string) *model.ClientRefreshToken {
		return &model.ClientRefreshToken{
			SessionIdentifier: model.SessionIdentifier{ClientID: clientID, DeviceID: deviceID},
			RefreshToken:      token,
		}
	}

	t.Run("Complete_User_Authentication_Flow", func(t *testing.T) {
		const (
			clientID = "web-app"
			deviceID = "user-desktop-chrome"
			token    = "refresh_token_12345"
		)
		accountID := primitive.NewObjectID()

		// Login: issue the initial refresh token.
		require.NoError(t, db.Create(ctx, createTestRefreshToken(accountID, clientID, deviceID, token)))

		// Refresh: the presented credential resolves to the stored token.
		key := lookupKey(clientID, deviceID, token)
		got, err := db.GetByCRT(ctx, key)
		require.NoError(t, err)
		require.NotNil(t, got.AccountRef)
		assert.Equal(t, accountID, *got.AccountRef)
		assert.Equal(t, token, got.RefreshToken)
		assert.False(t, got.IsRevoked)

		// Logout: revoke the session.
		require.NoError(t, db.Revoke(ctx, accountID, &model.SessionIdentifier{ClientID: clientID, DeviceID: deviceID}))

		// The revoked credential must no longer resolve.
		_, err = db.GetByCRT(ctx, key)
		assert.Error(t, err)
		assert.True(t, errors.Is(err, merrors.ErrNoData))
	})

	t.Run("Manual_Token_Revocation_Workaround", func(t *testing.T) {
		// Revoke by flipping IsRevoked and persisting with Update instead of Revoke.
		const (
			clientID = "web-app"
			deviceID = "user-desktop-chrome"
			token    = "manual_revoke_token_123"
		)
		rt := createTestRefreshToken(primitive.NewObjectID(), clientID, deviceID, token)
		require.NoError(t, db.Create(ctx, rt))

		rt.IsRevoked = true
		require.NoError(t, db.Update(ctx, rt))

		_, err := db.GetByCRT(ctx, lookupKey(clientID, deviceID, token))
		assert.Error(t, err)
		assert.True(t, errors.Is(err, merrors.ErrNoData))
	})
}
// TestRefreshTokenDB_MultiDeviceManagement checks that RevokeAll revokes every
// session of an account except the one on the exempted device. Revoked
// sessions must fail GetByCRT with merrors.ErrNoData, the same error other
// tests in this file expect for revoked tokens.
func TestRefreshTokenDB_MultiDeviceManagement(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	t.Run("User_With_Multiple_Devices", func(t *testing.T) {
		userID := primitive.NewObjectID()
		clientID := "mobile-app"
		// One login per device, in this order: phone, tablet, desktop.
		sessions := []struct{ deviceID, token string }{
			{"phone-ios", "phone_token_123"},
			{"tablet-android", "tablet_token_456"},
			{"desktop-windows", "desktop_token_789"},
		}
		for _, s := range sessions {
			err := db.Create(ctx, createTestRefreshToken(userID, clientID, s.deviceID, s.token))
			require.NoError(t, err)
		}

		// User logs out from all devices except the current one (phone).
		const currentDevice = "phone-ios"
		err := db.RevokeAll(ctx, userID, currentDevice)
		require.NoError(t, err)

		for _, s := range sessions {
			crt := &model.ClientRefreshToken{
				SessionIdentifier: model.SessionIdentifier{
					ClientID: clientID,
					DeviceID: s.deviceID,
				},
				RefreshToken: s.token,
			}
			_, err := db.GetByCRT(ctx, crt)
			if s.deviceID == currentDevice {
				require.NoError(t, err, "current device %s must stay active", s.deviceID)
				continue
			}
			// Require the specific "not found" error: a bare assert.Error would
			// also pass on unrelated failures (e.g. a lost connection).
			assert.Error(t, err, "device %s should be revoked", s.deviceID)
			assert.True(t, errors.Is(err, merrors.ErrNoData), "device %s: unexpected error %v", s.deviceID, err)
		}
	})
}
// TestRefreshTokenDB_TokenRotation checks token rotation on an existing
// session. It stores a new refresh-token value on the session record via
// Update, then expects the old value to stop resolving and the new one to
// resolve.
func TestRefreshTokenDB_TokenRotation(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	t.Run("Token_Rotation_On_Use", func(t *testing.T) {
		userID := primitive.NewObjectID()
		clientID := "web-app"
		deviceID := "user-browser"
		initialToken := "initial_token_123"

		// Create initial token
		refreshToken := createTestRefreshToken(userID, clientID, deviceID, initialToken)
		err := db.Create(ctx, refreshToken)
		require.NoError(t, err)

		// Use the token. GetByCRT does not update LastUsedAt, so there is no
		// timestamp to compare and no reason to wait between create and use.
		crt := &model.ClientRefreshToken{
			SessionIdentifier: model.SessionIdentifier{
				ClientID: clientID,
				DeviceID: deviceID,
			},
			RefreshToken: initialToken,
		}
		retrievedToken, err := db.GetByCRT(ctx, crt)
		require.NoError(t, err)
		assert.Equal(t, initialToken, retrievedToken.RefreshToken)

		// Rotate: persist a new token value on the same session record.
		newToken := "rotated_token_456"
		retrievedToken.RefreshToken = newToken
		err = db.Update(ctx, retrievedToken)
		require.NoError(t, err)

		// Old token should no longer work
		_, err = db.GetByCRT(ctx, crt)
		assert.Error(t, err)

		// New token should work
		newCRT := &model.ClientRefreshToken{
			SessionIdentifier: model.SessionIdentifier{
				ClientID: clientID,
				DeviceID: deviceID,
			},
			RefreshToken: newToken,
		}
		_, err = db.GetByCRT(ctx, newCRT)
		require.NoError(t, err)
	})
}
// TestRefreshTokenDB_SessionReplacement covers two re-login scenarios for the
// same account/client/device:
//   - While the session is active, a second Create reuses the existing record
//     and only the newest token value resolves.
//   - After a global RevokeAll, a new Create succeeds with a fresh active
//     record, and the revoked record is kept alongside it.
func TestRefreshTokenDB_SessionReplacement(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	t.Run("User_Login_From_Same_Device_Twice", func(t *testing.T) {
		userID := primitive.NewObjectID()
		clientID := "web-app"
		deviceID := "user-laptop"

		// First login
		firstToken := createTestRefreshToken(userID, clientID, deviceID, "first_token_123")
		err := db.Create(ctx, firstToken)
		require.NoError(t, err)
		// Guard the dereference: a nil ID would panic and abort the whole run.
		require.NotNil(t, firstToken.GetID())
		firstTokenID := *firstToken.GetID()

		// Second login from same device - should replace existing token
		secondToken := createTestRefreshToken(userID, clientID, deviceID, "second_token_456")
		err = db.Create(ctx, secondToken)
		require.NoError(t, err)
		require.NotNil(t, secondToken.GetID())

		// Should reuse the same database record
		assert.Equal(t, firstTokenID, *secondToken.GetID())

		// First token should no longer work
		firstCRT := &model.ClientRefreshToken{
			SessionIdentifier: model.SessionIdentifier{
				ClientID: clientID,
				DeviceID: deviceID,
			},
			RefreshToken: "first_token_123",
		}
		_, err = db.GetByCRT(ctx, firstCRT)
		assert.Error(t, err)

		// Second token should work
		secondCRT := &model.ClientRefreshToken{
			SessionIdentifier: model.SessionIdentifier{
				ClientID: clientID,
				DeviceID: deviceID,
			},
			RefreshToken: "second_token_456",
		}
		_, err = db.GetByCRT(ctx, secondCRT)
		require.NoError(t, err)
	})
	t.Run("Create_After_GlobalRevocation_AllowsNewActive", func(t *testing.T) {
		userID := primitive.NewObjectID()
		clientID := "web-app"
		deviceID := "user-laptop"

		firstToken := createTestRefreshToken(userID, clientID, deviceID, "revoked_token_123")
		err := db.Create(ctx, firstToken)
		require.NoError(t, err)
		require.NotNil(t, firstToken.GetID())

		// Global revoke (deviceID empty) — all tokens should be revoked
		err = db.RevokeAll(ctx, userID, "")
		require.NoError(t, err)
		var revoked model.RefreshToken
		err = db.Get(ctx, *firstToken.GetID(), &revoked)
		require.NoError(t, err)
		assert.True(t, revoked.IsRevoked)

		// Creating a new token for the same account/client/device must succeed and produce an active token
		reissueToken := createTestRefreshToken(userID, clientID, deviceID, "new_token_after_revocation")
		err = db.Create(ctx, reissueToken)
		require.NoError(t, err)
		newCRT := &model.ClientRefreshToken{
			SessionIdentifier: model.SessionIdentifier{
				ClientID: clientID,
				DeviceID: deviceID,
			},
			RefreshToken: "new_token_after_revocation",
		}
		_, err = db.GetByCRT(ctx, newCRT)
		require.NoError(t, err)

		// Old token must remain unusable
		oldCRT := &model.ClientRefreshToken{
			SessionIdentifier: model.SessionIdentifier{
				ClientID: clientID,
				DeviceID: deviceID,
			},
			RefreshToken: "revoked_token_123",
		}
		_, err = db.GetByCRT(ctx, oldCRT)
		assert.Error(t, err)

		// Both records exist: revoked + new active
		query := repository.Query().
			Filter(repository.AccountField(), userID).
			And(
				repository.Query().Comparison(repository.Field("clientId"), builder.Eq, clientID),
				repository.Query().Comparison(repository.Field("deviceId"), builder.Eq, deviceID),
			)
		ids, err := db.Repository.ListIDs(ctx, query)
		require.NoError(t, err)
		assert.Len(t, ids, 2)
	})
}
// TestRefreshTokenDB_ClientManagement checks that one account can hold tokens
// for several client IDs. Each token is looked up by its own
// (client, device, token) triple and comes back with that client and device.
func TestRefreshTokenDB_ClientManagement(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	t.Run("Client_CRUD_Operations", func(t *testing.T) {
		// Client records themselves live in a separate client database; this
		// only exercises refresh tokens keyed by different client IDs.
		owner := primitive.NewObjectID()
		sessions := []struct{ clientID, deviceID, token string }{
			{"web-app", "device1", "token1"},
			{"mobile-app", "device2", "token2"},
		}

		// Store every token first ...
		for _, s := range sessions {
			require.NoError(t, db.Create(ctx, createTestRefreshToken(owner, s.clientID, s.deviceID, s.token)))
		}

		// ... then resolve each one and check it maps back to its own session.
		for _, s := range sessions {
			key := &model.ClientRefreshToken{
				SessionIdentifier: model.SessionIdentifier{ClientID: s.clientID, DeviceID: s.deviceID},
				RefreshToken:      s.token,
			}
			got, err := db.GetByCRT(ctx, key)
			require.NoError(t, err)
			assert.Equal(t, s.clientID, got.ClientID)
			assert.Equal(t, s.deviceID, got.DeviceID)
		}
	})
}
// TestRefreshTokenDB_SecurityScenarios checks that neither a revoked token nor
// a completely unknown credential can be exchanged. Both must fail GetByCRT
// with merrors.ErrNoData.
func TestRefreshTokenDB_SecurityScenarios(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()

	// expectNoData asserts that the credential does not resolve to an active token.
	expectNoData := func(t *testing.T, key *model.ClientRefreshToken) {
		t.Helper()
		_, err := db.GetByCRT(ctx, key)
		assert.Error(t, err)
		assert.True(t, errors.Is(err, merrors.ErrNoData))
	}

	t.Run("Token_Hijacking_Prevention", func(t *testing.T) {
		const (
			clientID = "web-app"
			deviceID = "user-browser"
			token    = "hijacked_token_123"
		)
		owner := primitive.NewObjectID()
		session := model.SessionIdentifier{ClientID: clientID, DeviceID: deviceID}

		// Legitimate login.
		require.NoError(t, db.Create(ctx, createTestRefreshToken(owner, clientID, deviceID, token)))

		// Suspected compromise: revoke the session.
		require.NoError(t, db.Revoke(ctx, owner, &model.SessionIdentifier{ClientID: clientID, DeviceID: deviceID}))

		// An attacker replaying the stolen token gets nothing.
		expectNoData(t, &model.ClientRefreshToken{SessionIdentifier: session, RefreshToken: token})
	})

	t.Run("Invalid_Token_Attempts", func(t *testing.T) {
		// A credential that was never issued.
		expectNoData(t, &model.ClientRefreshToken{
			SessionIdentifier: model.SessionIdentifier{
				ClientID: "invalid-client",
				DeviceID: "invalid-device",
			},
			RefreshToken: "invalid_token_123",
		})
	})
}
// TestRefreshTokenDB_ExpiredTokenHandling checks how a token whose ExpiresAt is
// already in the past is handled. Create still stores it unchanged, but
// GetByCRT refuses it with merrors.ErrAccessDenied.
func TestRefreshTokenDB_ExpiredTokenHandling(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	t.Run("Expired_Token_Cleanup", func(t *testing.T) {
		const (
			clientID = "web-app"
			deviceID = "user-device"
			token    = "expired_token_123"
		)
		stale := createTestRefreshToken(primitive.NewObjectID(), clientID, deviceID, token)
		stale.ExpiresAt = time.Now().Add(-time.Hour) // expired one hour ago
		require.NoError(t, db.Create(ctx, stale))

		// The record is persisted as-is, expiry included.
		var stored model.RefreshToken
		require.NoError(t, db.Get(ctx, *stale.GetID(), &stored))
		assert.True(t, stored.ExpiresAt.Before(time.Now()))

		// Exchanging it must be refused.
		_, err := db.GetByCRT(ctx, &model.ClientRefreshToken{
			SessionIdentifier: model.SessionIdentifier{ClientID: clientID, DeviceID: deviceID},
			RefreshToken:      token,
		})
		assert.Error(t, err)
		assert.True(t, errors.Is(err, merrors.ErrAccessDenied))
	})
}
// TestRefreshTokenDB_ConcurrentAccess checks that the same credential can be
// resolved by two concurrent GetByCRT calls and that both succeed.
func TestRefreshTokenDB_ConcurrentAccess(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	t.Run("Concurrent_Token_Usage", func(t *testing.T) {
		const (
			clientID = "web-app"
			deviceID = "user-device"
			token    = "concurrent_token_123"
			readers  = 2
		)
		require.NoError(t, db.Create(ctx, createTestRefreshToken(primitive.NewObjectID(), clientID, deviceID, token)))

		key := &model.ClientRefreshToken{
			SessionIdentifier: model.SessionIdentifier{ClientID: clientID, DeviceID: deviceID},
			RefreshToken:      token,
		}

		// Buffered to the reader count so no goroutine blocks on send, even if
		// the collecting loop below stops early on a failure.
		results := make(chan error, readers)
		for n := 0; n < readers; n++ {
			go func() {
				_, lookupErr := db.GetByCRT(ctx, key)
				results <- lookupErr
			}()
		}

		// Every concurrent lookup must succeed.
		for n := 0; n < readers; n++ {
			require.NoError(t, <-results)
		}
	})
}
// TestRefreshTokenDB_EdgeCases covers three edge cases:
//   - Delete removes a token so Get reports merrors.ErrNoData.
//   - Revoke on an unknown session is a no-op rather than an error.
//   - RevokeAll leaves the exempted device's token usable when it is the only one.
//
// Each subtest uses its own token value, because token values are
// unique-indexed and subtests share one database.
func TestRefreshTokenDB_EdgeCases(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	t.Run("Delete_Token_By_ID", func(t *testing.T) {
		userID := primitive.NewObjectID()
		refreshToken := createTestRefreshToken(userID, "web-app", "device-1", "token_123")
		err := db.Create(ctx, refreshToken)
		require.NoError(t, err)
		tokenID := *refreshToken.GetID()

		// Delete token
		err = db.Delete(ctx, tokenID)
		require.NoError(t, err)

		// Token should no longer exist
		var result model.RefreshToken
		err = db.Get(ctx, tokenID, &result)
		assert.Error(t, err)
		assert.True(t, errors.Is(err, merrors.ErrNoData))
	})
	t.Run("Revoke_Non_Existent_Token", func(t *testing.T) {
		userID := primitive.NewObjectID()
		session := &model.SessionIdentifier{
			ClientID: "non-existent-client",
			DeviceID: "non-existent-device",
		}
		err := db.Revoke(ctx, userID, session)
		// Should handle gracefully for non-existent tokens
		assert.NoError(t, err)
	})
	t.Run("RevokeAll_No_Other_Devices", func(t *testing.T) {
		userID := primitive.NewObjectID()
		clientID := "web-app"
		deviceID := "only-device"
		// Distinct from the value used in Delete_Token_By_ID: if that subtest
		// failed before deleting, a shared value would collide on the unique
		// token index and fail this subtest for an unrelated reason.
		const token = "only_device_token_123"

		// Create single token
		refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
		err := db.Create(ctx, refreshToken)
		require.NoError(t, err)

		// RevokeAll should not affect current device
		err = db.RevokeAll(ctx, userID, deviceID)
		require.NoError(t, err)

		// Token should still work
		crt := &model.ClientRefreshToken{
			SessionIdentifier: model.SessionIdentifier{
				ClientID: clientID,
				DeviceID: deviceID,
			},
			RefreshToken: token,
		}
		_, err = db.GetByCRT(ctx, crt)
		require.NoError(t, err)
	})
}
// TestRefreshTokenDB_DatabaseIndexes checks index-backed behavior of the store.
// A refresh-token value must be unique across the collection, so a second
// Create with the same value fails with a "duplicate" error. Filtering an
// account's tokens by revocation status returns only the non-revoked ones.
func TestRefreshTokenDB_DatabaseIndexes(t *testing.T) {
	db, cleanup := setupTestDB(t)
	defer cleanup()
	ctx := context.Background()
	t.Run("Unique_Token_Constraint", func(t *testing.T) {
		userID1 := primitive.NewObjectID()
		userID2 := primitive.NewObjectID()
		token := "duplicate_token_123"

		// Create first token
		refreshToken1 := createTestRefreshToken(userID1, "client1", "device1", token)
		err := db.Create(ctx, refreshToken1)
		require.NoError(t, err)

		// Try to create second token with same token value - should fail due to unique index
		refreshToken2 := createTestRefreshToken(userID2, "client2", "device2", token)
		err = db.Create(ctx, refreshToken2)
		// require, not assert: err.Error() below would panic on a nil error.
		require.Error(t, err)
		assert.Contains(t, err.Error(), "duplicate")
	})
	t.Run("Query_Performance_By_Revocation_Status", func(t *testing.T) {
		// Despite the name, this checks filtering correctness, not query speed.
		userID := primitive.NewObjectID()
		clientID := "web-app"

		// Create 10 tokens on distinct devices; even-indexed ones are stored revoked.
		for i := 0; i < 10; i++ {
			rt := createTestRefreshToken(userID, clientID,
				fmt.Sprintf("device_%d", i), fmt.Sprintf("token_%d", i))
			rt.IsRevoked = i%2 == 0
			err := db.Create(ctx, rt)
			require.NoError(t, err)
		}

		// Only the account's non-revoked tokens should match.
		query := repository.Query().
			Filter(repository.AccountField(), userID).
			And(repository.Query().Comparison(repository.Field(refreshtokensdb.IsRevokedField), builder.Eq, false))
		ids, err := db.ListIDs(ctx, query)
		require.NoError(t, err)
		assert.Len(t, ids, 5) // the odd-indexed tokens
	})
}
// TestRefreshTokenDB_IndexPartialUniqueActiveSession checks the indexes on the
// refresh-token collection. It expects one named "unique_active_session" that
// is unique and whose partial filter restricts it to {isRevoked: false}, so
// revoked records do not count toward uniqueness.
func TestRefreshTokenDB_IndexPartialUniqueActiveSession(t *testing.T) {
	db, database, cleanup := setupTestDBWithMongo(t)
	defer cleanup()
	ctx := context.Background()

	cursor, err := database.Collection(db.Repository.Collection()).Indexes().List(ctx)
	require.NoError(t, err)
	defer cursor.Close(ctx)

	found := false
	// Stop as soon as the index has been found and checked.
	for !found && cursor.Next(ctx) {
		var idx bson.M
		require.NoError(t, cursor.Decode(&idx))
		if idx["name"] != "unique_active_session" {
			continue
		}
		found = true
		assert.Equal(t, true, idx["unique"])
		partial, ok := idx["partialFilterExpression"].(bson.M)
		require.True(t, ok, "partialFilterExpression missing or not a document: %#v", idx["partialFilterExpression"])
		assert.Equal(t, bson.M{"isRevoked": false}, partial)
	}
	// Distinguish an iteration failure from a genuinely missing index.
	require.NoError(t, cursor.Err())
	assert.True(t, found, "unique_active_session index not found")
}