Files
sendico/api/pkg/db/internal/mongo/verificationimp/verification_test.go
2026-02-12 20:26:10 +01:00

1050 lines
31 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package verificationimp
import (
"context"
"errors"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tech/sendico/pkg/db/repository/builder"
rd "github.com/tech/sendico/pkg/db/repository/decoder"
ri "github.com/tech/sendico/pkg/db/repository/index"
"github.com/tech/sendico/pkg/db/storable"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/db/transaction"
"github.com/tech/sendico/pkg/db/verification"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.uber.org/zap"
)
// ---------------------------------------------------------------------------
// helpers
// ---------------------------------------------------------------------------
// newTestVerificationDB returns a verificationDB over a fresh in-memory token
// repository whose transactions run each callback exactly once.
func newTestVerificationDB(t *testing.T) *verificationDB {
	factory := &passthroughTxFactory{}
	return newTestVerificationDBWithFactory(t, factory)
}
// newTestVerificationDBWithFactory returns a verificationDB over a fresh
// in-memory token repository and a no-op logger, creating transactions with tf.
// Supplying a custom factory lets a test simulate e.g. callback retries.
func newTestVerificationDBWithFactory(t *testing.T, tf transaction.Factory) *verificationDB {
	t.Helper()
	db := &verificationDB{tf: tf}
	db.DBImp = template.DBImp[*model.VerificationToken]{
		Logger:     zap.NewNop(),
		Repository: newMemoryTokenRepository(),
	}
	return db
}
// passthroughTxFactory produces transactions that simply invoke the callback
// once on the caller's context, with no real transactional semantics.
type passthroughTxFactory struct{}

// CreateTransaction returns a new passthrough transaction.
func (*passthroughTxFactory) CreateTransaction() transaction.Transaction {
	return &passthroughTx{}
}

// passthroughTx runs its callback a single time and relays the outcome.
type passthroughTx struct{}

// Execute calls cb once with ctx and returns its result and error unchanged.
func (*passthroughTx) Execute(ctx context.Context, cb transaction.Callback) (any, error) {
	result, err := cb(ctx)
	return result, err
}
// retryingTxFactory produces transactions that run their callback twice,
// simulating a transaction body being executed more than once.
type retryingTxFactory struct{}

// CreateTransaction returns a new retrying transaction.
func (*retryingTxFactory) CreateTransaction() transaction.Transaction {
	return &retryingTx{}
}

// retryingTx executes the callback, discards the first successful result, and
// executes it again, returning the second run's outcome. If the first run
// fails, its error is returned and there is no second run.
type retryingTx struct{}

// Execute runs cb up to twice as described on retryingTx.
func (*retryingTx) Execute(ctx context.Context, cb transaction.Callback) (any, error) {
	_, firstErr := cb(ctx)
	if firstErr != nil {
		return nil, firstErr
	}
	return cb(ctx)
}
// ---------------------------------------------------------------------------
// in-memory repository for VerificationToken
// ---------------------------------------------------------------------------

// memoryTokenRepository is a mutex-guarded, in-memory stand-in for the token
// repository. It stores deep copies of tokens and records insertion order so
// that filter scans visit tokens deterministically.
type memoryTokenRepository struct {
	mu sync.Mutex
	// data holds stored token copies keyed by token ID.
	data map[bson.ObjectID]*model.VerificationToken
	// order lists token IDs in the order they were inserted.
	order []bson.ObjectID
	// seq is incremented whenever Insert generates an ID; it is not read
	// anywhere in this file.
	seq int
}

// newMemoryTokenRepository returns an empty repository ready for use.
func newMemoryTokenRepository() *memoryTokenRepository {
	repo := new(memoryTokenRepository)
	repo.data = map[bson.ObjectID]*model.VerificationToken{}
	return repo
}
// Insert stores a deep copy of obj, which must be a *model.VerificationToken.
// A token with no ID (or a zero ID) is assigned a fresh one, written back onto
// obj. To emulate unique indexes, it returns a DataConflict error when:
//   - the ID is already stored,
//   - another stored token has the same verifyTokenHash, or
//   - another stored token shares (accountRef, purpose, target) and either
//     both idempotency keys are nil or both are set and equal.
func (m *memoryTokenRepository) Insert(_ context.Context, obj storable.Storable, _ builder.Query) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	tok, isToken := obj.(*model.VerificationToken)
	if !isToken {
		return merrors.InvalidDataType("expected VerificationToken")
	}

	if current := tok.GetID(); current == nil || *current == bson.NilObjectID {
		m.seq++
		tok.SetID(bson.NewObjectID())
	}
	id := *tok.GetID()

	if _, taken := m.data[id]; taken {
		return merrors.DataConflict("token already exists")
	}

	for _, other := range m.data {
		if other.VerifyTokenHash == tok.VerifyTokenHash {
			return merrors.DataConflict("duplicate verifyTokenHash")
		}
		sameContext := other.AccountRef == tok.AccountRef &&
			other.Purpose == tok.Purpose &&
			other.Target == tok.Target
		if !sameContext {
			continue
		}
		bothUnkeyed := other.IdempotencyKey == nil && tok.IdempotencyKey == nil
		sameKey := other.IdempotencyKey != nil && tok.IdempotencyKey != nil &&
			*other.IdempotencyKey == *tok.IdempotencyKey
		if bothUnkeyed || sameKey {
			return merrors.DataConflict("duplicate verification context idempotency")
		}
	}

	m.data[id] = cloneToken(tok)
	m.order = append(m.order, id)
	return nil
}
// Get copies the token stored under id into result, which must be a non-nil
// *model.VerificationToken.
//
// Returns merrors.ErrNoData when id is unknown. Returns
// merrors.InvalidDataType when result has the wrong type. This matches Insert's
// type check and avoids panicking on a failed type assertion.
func (m *memoryTokenRepository) Get(_ context.Context, id bson.ObjectID, result storable.Storable) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	tok, ok := m.data[id]
	if !ok {
		return merrors.ErrNoData
	}
	dst, ok := result.(*model.VerificationToken)
	if !ok || dst == nil {
		return merrors.InvalidDataType("expected VerificationToken")
	}
	// Hand out a copy so callers cannot mutate the stored token.
	*dst = *cloneToken(tok)
	return nil
}
// FindOneByFilter copies the first token (in insertion order) that matches
// query into result, which must be a non-nil *model.VerificationToken.
//
// Returns merrors.ErrNoData when nothing matches. Returns
// merrors.InvalidDataType for a wrong result type instead of panicking.
func (m *memoryTokenRepository) FindOneByFilter(_ context.Context, query builder.Query, result storable.Storable) error {
	dst, ok := result.(*model.VerificationToken)
	if !ok || dst == nil {
		return merrors.InvalidDataType("expected VerificationToken")
	}
	m.mu.Lock()
	defer m.mu.Unlock()
	for _, id := range m.order {
		tok := m.data[id]
		if tok != nil && matchToken(query, tok) {
			*dst = *cloneToken(tok)
			return nil
		}
	}
	return merrors.ErrNoData
}
// Update replaces the stored token that has obj's ID with a deep copy of obj.
// obj must be a non-nil *model.VerificationToken.
//
// Returns:
//   - merrors.InvalidDataType for any other type, consistent with Insert and
//     instead of panicking;
//   - merrors.InvalidArgument when the ID is nil;
//   - merrors.ErrNoData when no token with that ID is stored.
func (m *memoryTokenRepository) Update(_ context.Context, obj storable.Storable) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	tok, ok := obj.(*model.VerificationToken)
	if !ok || tok == nil {
		return merrors.InvalidDataType("expected VerificationToken")
	}
	id := tok.GetID()
	if id == nil {
		return merrors.InvalidArgument("id required")
	}
	if _, exists := m.data[*id]; !exists {
		return merrors.ErrNoData
	}
	m.data[*id] = cloneToken(tok)
	return nil
}
// Upsert inserts obj when it has no ID or a zero ID. Otherwise it updates the
// stored token, falling back to an insert if no token with that ID exists.
// Update errors other than merrors.ErrNoData are returned unchanged.
func (m *memoryTokenRepository) Upsert(ctx context.Context, obj storable.Storable) error {
	if id := obj.GetID(); id != nil && *id != bson.NilObjectID {
		err := m.Update(ctx, obj)
		if err == nil || !errors.Is(err, merrors.ErrNoData) {
			return err
		}
	}
	return m.Insert(ctx, obj, nil)
}
// PatchMany applies patch in place to every stored token that matches filter,
// visiting tokens in insertion order, and returns how many were patched.
func (m *memoryTokenRepository) PatchMany(_ context.Context, filter builder.Query, patch builder.Patch) (int, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	update := patch.Build()
	patched := 0
	for i := range m.order {
		tok := m.data[m.order[i]]
		if tok == nil || !matchToken(filter, tok) {
			continue
		}
		applyPatch(tok, update)
		patched++
	}
	return patched, nil
}
// The methods below exist to satisfy the repository interface. The
// verification DB does not rely on the ones that return NotImplemented.

// Aggregate is not supported by the in-memory repository.
func (*memoryTokenRepository) Aggregate(context.Context, builder.Pipeline, rd.DecodingFunc) error {
	return merrors.NotImplemented("not needed")
}

// InsertMany inserts objs one by one through Insert, stopping at the first
// error. Objects inserted before the failure stay stored.
func (m *memoryTokenRepository) InsertMany(ctx context.Context, objs []storable.Storable) error {
	for i := range objs {
		if err := m.Insert(ctx, objs[i], nil); err != nil {
			return err
		}
	}
	return nil
}
// FindManyByFilter passes every token matching query to decoder, in insertion
// order, via a mongo cursor built from BSON-marshalled copies.
//
// Matches are snapshotted under the lock. decoder therefore runs without the
// lock held and may call back into the repository. The caller's ctx drives
// cursor iteration. Any cursor error is reported after the loop instead of
// being silently dropped.
func (m *memoryTokenRepository) FindManyByFilter(ctx context.Context, query builder.Query, decoder rd.DecodingFunc) error {
	matches, err := m.snapshotMatches(query)
	if err != nil {
		return err
	}
	cur, err := mongo.NewCursorFromDocuments(matches, nil, nil)
	if err != nil {
		return err
	}
	defer cur.Close(ctx)
	for cur.Next(ctx) {
		if err := decoder(cur); err != nil {
			return err
		}
	}
	return cur.Err()
}

// snapshotMatches marshals copies of all tokens matching query while holding
// the lock, returning them as bson.Raw documents in insertion order.
func (m *memoryTokenRepository) snapshotMatches(query builder.Query) ([]any, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	var matches []any
	for _, id := range m.order {
		tok := m.data[id]
		if tok == nil || !matchToken(query, tok) {
			continue
		}
		raw, err := bson.Marshal(cloneToken(tok))
		if err != nil {
			return nil, err
		}
		matches = append(matches, bson.Raw(raw))
	}
	return matches, nil
}
// Patch is not supported by the in-memory repository.
func (*memoryTokenRepository) Patch(context.Context, bson.ObjectID, builder.Patch) error {
	return merrors.NotImplemented("not needed")
}
// Delete removes the token stored under id. Deleting an unknown id is a no-op.
//
// The id is also dropped from the insertion-order index. Otherwise a later
// re-insert of the same ID would appear twice in scans: PatchMany would patch
// it twice and FindManyByFilter would yield it twice.
func (m *memoryTokenRepository) Delete(_ context.Context, id bson.ObjectID) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	if _, ok := m.data[id]; !ok {
		return nil
	}
	delete(m.data, id)
	for i, oid := range m.order {
		if oid == id {
			m.order = append(m.order[:i], m.order[i+1:]...)
			break
		}
	}
	return nil
}
// DeleteMany is not supported by the in-memory repository.
func (*memoryTokenRepository) DeleteMany(context.Context, builder.Query) error {
	return merrors.NotImplemented("not needed")
}

// CreateIndex accepts and ignores index definitions; uniqueness is emulated
// directly by Insert.
func (*memoryTokenRepository) CreateIndex(*ri.Definition) error { return nil }

// ListIDs is not supported by the in-memory repository.
func (*memoryTokenRepository) ListIDs(context.Context, builder.Query) ([]bson.ObjectID, error) {
	return nil, merrors.NotImplemented("not needed")
}

// ListPermissionBound is not supported by the in-memory repository.
func (*memoryTokenRepository) ListPermissionBound(context.Context, builder.Query) ([]model.PermissionBoundStorable, error) {
	return nil, merrors.NotImplemented("not needed")
}

// ListAccountBound is not supported by the in-memory repository.
func (*memoryTokenRepository) ListAccountBound(context.Context, builder.Query) ([]model.AccountBoundStorable, error) {
	return nil, merrors.NotImplemented("not needed")
}

// Collection reports the name of the emulated collection.
func (*memoryTokenRepository) Collection() string { return mservice.VerificationTokens }
// ---------------------------------------------------------------------------
// bson.D query evaluation for VerificationToken
// ---------------------------------------------------------------------------

// tokenFieldValue maps a BSON field name to the matching value on tok, for use
// by the in-memory filter matcher. A set idempotency key is returned as its
// string value. An unset key, and any unknown field name, yields an untyped nil.
func tokenFieldValue(tok *model.VerificationToken, field string) any {
	var value any
	switch field {
	case "_id":
		value = tok.ID
	case "createdAt":
		value = tok.CreatedAt
	case "verifyTokenHash":
		value = tok.VerifyTokenHash
	case "salt":
		value = tok.Salt
	case "usedAt":
		value = tok.UsedAt
	case "expiresAt":
		value = tok.ExpiresAt
	case "accountRef":
		value = tok.AccountRef
	case "purpose":
		value = tok.Purpose
	case "target":
		value = tok.Target
	case "idempotencyKey":
		if key := tok.IdempotencyKey; key != nil {
			value = *key
		}
	case "maxRetries":
		value = tok.MaxRetries
	case "attempts":
		value = tok.Attempts
	}
	return value
}
// matchToken reports whether tok satisfies query's bson.D filter. A nil query
// matches every token.
func matchToken(query builder.Query, tok *model.VerificationToken) bool {
	return query == nil || matchBsonD(query.BuildQuery(), tok)
}
// matchBsonD reports whether tok satisfies every element of filter (implicit
// AND). An empty filter matches.
func matchBsonD(filter bson.D, tok *model.VerificationToken) bool {
	for i := range filter {
		if !matchElem(filter[i], tok) {
			return false
		}
	}
	return true
}
// matchElem evaluates one filter element against tok.
//
// "$and" must hold a bson.A of bson.D sub-filters, all of which must match;
// any other shape fails. Every other key is treated as a field name. A bson.M
// value is an operator document (see matchOperator); anything else is compared
// with valuesEqual.
func matchElem(elem bson.E, tok *model.VerificationToken) bool {
	if elem.Key != "$and" {
		stored := tokenFieldValue(tok, elem.Key)
		if ops, isOpDoc := elem.Value.(bson.M); isOpDoc {
			return matchOperator(stored, ops)
		}
		return valuesEqual(stored, elem.Value)
	}

	clauses, ok := elem.Value.(bson.A)
	if !ok {
		return false
	}
	for _, clause := range clauses {
		sub, ok := clause.(bson.D)
		if !ok || !matchBsonD(sub, tok) {
			return false
		}
	}
	return true
}
// matchOperator reports whether stored satisfies every operator in ops. Only
// "$gt" and "$lt" are evaluated.
//
// NOTE(review): any other operator (e.g. "$gte", "$ne") is treated as
// satisfied, so a filter relying on one will over-match in this fake.
func matchOperator(stored any, ops bson.M) bool {
	for op, operand := range ops {
		satisfied := true
		switch op {
		case "$gt":
			satisfied = cmpGt(stored, operand)
		case "$lt":
			satisfied = cmpLt(stored, operand)
		}
		if !satisfied {
			return false
		}
	}
	return true
}
// cmpGt reports whether stored > cmpVal. It compares numerically when both
// sides are integers accepted by toInt. Otherwise it compares them as times
// via timeGt, which yields false for non-time values.
func cmpGt(stored, cmpVal any) bool {
	lhs, lhsIsInt := toInt(stored)
	rhs, rhsIsInt := toInt(cmpVal)
	if lhsIsInt && rhsIsInt {
		return lhs > rhs
	}
	return timeGt(stored, cmpVal)
}

// cmpLt reports whether stored < cmpVal, using the same integer-then-time
// rules as cmpGt.
func cmpLt(stored, cmpVal any) bool {
	lhs, lhsIsInt := toInt(stored)
	rhs, rhsIsInt := toInt(cmpVal)
	if lhsIsInt && rhsIsInt {
		return lhs < rhs
	}
	return timeLt(stored, cmpVal)
}
// toInt converts v to an int when its dynamic type is int, int32 or int64, and
// reports whether it did. Any other type, including other integer widths,
// yields (0, false).
func toInt(v any) (int, bool) {
	switch n := v.(type) {
	case int64:
		return int(n), true
	case int32:
		return int(n), true
	case int:
		return n, true
	default:
		return 0, false
	}
}
// valuesEqual reports whether the stored field value a equals the filter value
// b, for the value kinds this fake's filters use.
//
// A nil b matches only an absent value or a nil *time.Time (e.g. an unset
// usedAt). Here "nil b" means an untyped nil or a typed nil *time.Time. An
// absent value is the untyped nil that tokenFieldValue returns for a missing
// key. Other supported kinds compare as follows:
//   - *time.Time against a time.Time or *time.Time, using Time.Equal;
//   - bson.ObjectID, model.VerificationPurpose and string by ==.
//
// Anything else compares unequal.
func valuesEqual(a, b any) bool {
	// An interface holding a typed nil pointer is not == nil. Normalize
	// (*time.Time)(nil) to an untyped nil; otherwise it would reach the
	// *time.Time branch below and be dereferenced.
	if bt, ok := b.(*time.Time); ok && bt == nil {
		b = nil
	}
	if b == nil {
		return a == nil || a == (*time.Time)(nil)
	}
	switch av := a.(type) {
	case *time.Time:
		if av == nil {
			// b is known to be non-nil here.
			return false
		}
		bt, ok := toTime(b)
		return ok && av.Equal(bt)
	case bson.ObjectID:
		bv, ok := b.(bson.ObjectID)
		return ok && av == bv
	case model.VerificationPurpose:
		bv, ok := b.(model.VerificationPurpose)
		return ok && av == bv
	case string:
		bv, ok := b.(string)
		return ok && av == bv
	}
	return false
}
// timeGt reports whether stored is strictly after cmpVal. Each side must be a
// time.Time or a non-nil *time.Time; otherwise the result is false.
func timeGt(stored, cmpVal any) bool {
	lhs, lhsOK := toTime(stored)
	rhs, rhsOK := toTime(cmpVal)
	return lhsOK && rhsOK && lhs.After(rhs)
}

// timeLt reports whether stored is strictly before cmpVal, with the same
// operand requirements as timeGt.
func timeLt(stored, cmpVal any) bool {
	lhs, lhsOK := toTime(stored)
	rhs, rhsOK := toTime(cmpVal)
	return lhsOK && rhsOK && lhs.Before(rhs)
}
func toTime(v any) (time.Time, bool) {
switch tv := v.(type) {
case time.Time:
return tv, true
case *time.Time:
if tv == nil {
return time.Time{}, false
}
return *tv, true
}
return time.Time{}, false
}
// applyPatch applies $set and $inc operations from a patch bson.D to a token.
func applyPatch(tok *model.VerificationToken, patchDoc bson.D) {
for _, op := range patchDoc {
switch op.Key {
case "$set":
fields, ok := op.Value.(bson.D)
if !ok {
continue
}
for _, f := range fields {
switch f.Key {
case "usedAt":
if t, ok := f.Value.(time.Time); ok {
tok.UsedAt = &t
}
}
}
case "$inc":
fields, ok := op.Value.(bson.D)
if !ok {
continue
}
for _, f := range fields {
switch f.Key {
case "attempts":
if v, ok := f.Value.(int); ok {
tok.Attempts += v
}
}
}
}
}
}
// cloneToken returns a shallow copy of src in which the pointer fields UsedAt,
// MaxRetries, Salt and IdempotencyKey point at fresh copies of their values.
// Mutating those fields through the clone therefore never affects src.
func cloneToken(src *model.VerificationToken) *model.VerificationToken {
	clone := new(model.VerificationToken)
	*clone = *src
	if p := src.UsedAt; p != nil {
		usedAt := *p
		clone.UsedAt = &usedAt
	}
	if p := src.MaxRetries; p != nil {
		maxRetries := *p
		clone.MaxRetries = &maxRetries
	}
	if p := src.Salt; p != nil {
		salt := *p
		clone.Salt = &salt
	}
	if p := src.IdempotencyKey; p != nil {
		key := *p
		clone.IdempotencyKey = &key
	}
	return clone
}
// ---------------------------------------------------------------------------
// request builder helpers
// ---------------------------------------------------------------------------

// req builds a link verification request for accountRef/purpose/target with
// the given ttl. Tests pass a negative ttl to obtain an already-expired token.
func req(accountRef bson.ObjectID, purpose model.VerificationPurpose, target string, ttl time.Duration) *verification.Request {
	request := verification.NewLinkRequest(accountRef, purpose, target)
	return request.WithTTL(ttl)
}
// ---------------------------------------------------------------------------
// tests
// ---------------------------------------------------------------------------
func TestCreate_ReturnsRawToken(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
accountRef := bson.NewObjectID()
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
assert.NotEmpty(t, raw)
}
func TestCreate_TokenCanBeConsumed(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
accountRef := bson.NewObjectID()
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
require.NoError(t, err)
assert.Equal(t, accountRef, tok.AccountRef)
assert.Equal(t, model.PurposePasswordReset, tok.Purpose)
assert.NotNil(t, tok.UsedAt)
}
func TestConsume_ReturnsCorrectFields(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
accountRef := bson.NewObjectID()
raw, err := db.Create(ctx, req(accountRef, model.PurposeEmailChange, "new@example.com", time.Hour))
require.NoError(t, err)
tok, err := db.Consume(ctx, accountRef, model.PurposeEmailChange, raw)
require.NoError(t, err)
assert.Equal(t, accountRef, tok.AccountRef)
assert.Equal(t, model.PurposeEmailChange, tok.Purpose)
assert.Equal(t, "new@example.com", tok.Target)
}
func TestConsume_SecondConsumeFailsAlreadyUsed(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
accountRef := bson.NewObjectID()
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
require.NoError(t, err)
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
require.Error(t, err)
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed),
"second consume should fail with already-used after usedAt is set")
}
func TestConsume_ExpiredTokenFails(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
accountRef := bson.NewObjectID()
// Create with a TTL that is already in the past.
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", -time.Hour))
require.NoError(t, err)
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
require.Error(t, err)
assert.True(t, errors.Is(err, verification.ErrTokenExpired),
"expired token should return explicit expiry error")
}
func TestConsume_UnknownTokenFails(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
_, err := db.Consume(ctx, bson.NilObjectID, "", "nonexistent-token-value")
require.Error(t, err)
assert.True(t, errors.Is(err, verification.ErrTokenNotFound))
}
func TestConsume_AccountActivationWithoutAccountRef(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	account := bson.NewObjectID()

	token, err := db.Create(ctx, req(account, model.PurposeAccountActivation, "", time.Hour))
	require.NoError(t, err)

	// Activation is consumed without knowing the account up front.
	consumed, err := db.Consume(ctx, bson.NilObjectID, model.PurposeAccountActivation, token)
	require.NoError(t, err)
	assert.Equal(t, account, consumed.AccountRef)
	assert.Equal(t, model.PurposeAccountActivation, consumed.Purpose)
}

func TestCreate_InvalidatesPreviousToken(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	account := bson.NewObjectID()

	stale, err := db.Create(ctx, req(account, model.PurposePasswordReset, "", time.Hour))
	require.NoError(t, err)
	fresh, err := db.Create(ctx, req(account, model.PurposePasswordReset, "", time.Hour))
	require.NoError(t, err)
	assert.NotEqual(t, stale, fresh, "new token should differ from old one")

	// The second Create marks the earlier token as used.
	_, err = db.Consume(ctx, account, model.PurposePasswordReset, stale)
	require.Error(t, err)
	assert.ErrorIs(t, err, verification.ErrTokenAlreadyUsed,
		"old token should return already-used after invalidation")

	consumed, err := db.Consume(ctx, account, model.PurposePasswordReset, fresh)
	require.NoError(t, err)
	assert.Equal(t, account, consumed.AccountRef)
}

func TestCreate_AccountActivationResendWithoutIdempotency_ReissuesToken(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	accountRef := bson.NewObjectID()

	// Signup issues the first token; a resend must rotate it rather than hit a
	// duplicate-key conflict.
	signupToken, err := db.Create(ctx, req(accountRef, model.PurposeAccountActivation, "", time.Hour))
	require.NoError(t, err)
	resendToken, err := db.Create(ctx, req(accountRef, model.PurposeAccountActivation, "", time.Hour))
	require.NoError(t, err)
	assert.NotEqual(t, signupToken, resendToken)

	_, err = db.Consume(ctx, bson.NilObjectID, model.PurposeAccountActivation, signupToken)
	require.Error(t, err)
	assert.ErrorIs(t, err, verification.ErrTokenAlreadyUsed)

	consumed, err := db.Consume(ctx, bson.NilObjectID, model.PurposeAccountActivation, resendToken)
	require.NoError(t, err)
	assert.Equal(t, accountRef, consumed.AccountRef)
	assert.Equal(t, model.PurposeAccountActivation, consumed.Purpose)

	// Requests without an idempotency key still get distinct internal keys,
	// avoiding uniqueness collisions on (accountRef, purpose, target, idempotencyKey).
	repo := db.Repository.(*memoryTokenRepository)
	repo.mu.Lock()
	defer repo.mu.Unlock()
	seen := make(map[string]struct{})
	for _, stored := range repo.data {
		if stored.AccountRef == accountRef && stored.Purpose == model.PurposeAccountActivation {
			require.NotNil(t, stored.IdempotencyKey)
			assert.True(t, strings.HasPrefix(*stored.IdempotencyKey, "auto:"))
			seen[*stored.IdempotencyKey] = struct{}{}
		}
	}
	assert.Len(t, seen, 2)
}
func TestCreate_InvalidatesMultiplePreviousTokens(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	account := bson.NewObjectID()

	tokens := make([]string, 0, 3)
	for i := 0; i < 3; i++ {
		token, err := db.Create(ctx, req(account, model.PurposePasswordReset, "", time.Hour))
		require.NoError(t, err)
		tokens = append(tokens, token)
	}

	_, err := db.Consume(ctx, account, model.PurposePasswordReset, tokens[0])
	assert.ErrorIs(t, err, verification.ErrTokenAlreadyUsed, "first should be invalidated/used")
	_, err = db.Consume(ctx, account, model.PurposePasswordReset, tokens[1])
	assert.ErrorIs(t, err, verification.ErrTokenAlreadyUsed, "second should be invalidated/used")

	consumed, err := db.Consume(ctx, account, model.PurposePasswordReset, tokens[2])
	require.NoError(t, err)
	assert.Equal(t, account, consumed.AccountRef)
}

func TestCreate_DifferentPurposeNotInvalidated(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	account := bson.NewObjectID()

	resetToken, err := db.Create(ctx, req(account, model.PurposePasswordReset, "", time.Hour))
	require.NoError(t, err)

	// An activation token must leave the password-reset token untouched.
	_, err = db.Create(ctx, req(account, model.PurposeAccountActivation, "", time.Hour))
	require.NoError(t, err)

	consumed, err := db.Consume(ctx, account, model.PurposePasswordReset, resetToken)
	require.NoError(t, err)
	assert.Equal(t, model.PurposePasswordReset, consumed.Purpose)
}

func TestCreate_DifferentTargetNotInvalidated(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	account := bson.NewObjectID()

	firstToken, err := db.Create(ctx, req(account, model.PurposeEmailChange, "a@example.com", time.Hour))
	require.NoError(t, err)

	// A token for another target address must not invalidate the first.
	_, err = db.Create(ctx, req(account, model.PurposeEmailChange, "b@example.com", time.Hour))
	require.NoError(t, err)

	consumed, err := db.Consume(ctx, account, model.PurposeEmailChange, firstToken)
	require.NoError(t, err)
	assert.Equal(t, "a@example.com", consumed.Target)
}

func TestCreate_DifferentAccountNotInvalidated(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	first, second := bson.NewObjectID(), bson.NewObjectID()

	firstToken, err := db.Create(ctx, req(first, model.PurposePasswordReset, "", time.Hour))
	require.NoError(t, err)
	_, err = db.Create(ctx, req(second, model.PurposePasswordReset, "", time.Hour))
	require.NoError(t, err)

	consumed, err := db.Consume(ctx, first, model.PurposePasswordReset, firstToken)
	require.NoError(t, err)
	assert.Equal(t, first, consumed.AccountRef)
}

func TestCreate_AlreadyUsedTokenNotInvalidatedAgain(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	account := bson.NewObjectID()

	// Issue and consume a first token.
	used, err := db.Create(ctx, req(account, model.PurposePasswordReset, "", time.Hour))
	require.NoError(t, err)
	_, err = db.Consume(ctx, account, model.PurposePasswordReset, used)
	require.NoError(t, err)

	// The consumed token already has usedAt set, so the invalidation filter
	// (usedAt == nil) in PatchMany must skip it; the new token stays usable.
	fresh, err := db.Create(ctx, req(account, model.PurposePasswordReset, "", time.Hour))
	require.NoError(t, err)

	consumed, err := db.Consume(ctx, account, model.PurposePasswordReset, fresh)
	require.NoError(t, err)
	assert.Equal(t, account, consumed.AccountRef)
}

func TestCreate_ExpiredTokenNotInvalidated(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	account := bson.NewObjectID()

	// Seed an already-expired token.
	_, err := db.Create(ctx, req(account, model.PurposePasswordReset, "", -time.Hour))
	require.NoError(t, err)

	// Invalidation is expected to skip the expired token (expiresAt > now filter).
	fresh, err := db.Create(ctx, req(account, model.PurposePasswordReset, "", time.Hour))
	require.NoError(t, err)

	consumed, err := db.Consume(ctx, account, model.PurposePasswordReset, fresh)
	require.NoError(t, err)
	assert.Equal(t, account, consumed.AccountRef)
}
func TestTokenHash_Deterministic(t *testing.T) {
h1 := tokenHash("same-input")
h2 := tokenHash("same-input")
assert.Equal(t, h1, h2)
}
func TestTokenHash_DifferentInputs(t *testing.T) {
h1 := tokenHash("input-a")
h2 := tokenHash("input-b")
assert.NotEqual(t, h1, h2)
}
// ---------------------------------------------------------------------------
// cooldown tests
// ---------------------------------------------------------------------------

func TestCreate_CooldownBlocksCreation(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	account := bson.NewObjectID()

	// Seed a token without any cooldown.
	_, err := db.Create(ctx, req(account, model.PurposePasswordReset, "", time.Hour))
	require.NoError(t, err)

	// An immediate retry with a one-minute cooldown finds the prior token too
	// recent to invalidate and must be rejected.
	retry := req(account, model.PurposePasswordReset, "", time.Hour).WithCooldown(time.Minute)
	_, err = db.Create(ctx, retry)
	require.Error(t, err)
	assert.ErrorIs(t, err, verification.ErrCooldownActive)
}

func TestCreate_CooldownExpiresAllowsCreation(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	account := bson.NewObjectID()

	_, err := db.Create(ctx, req(account, model.PurposePasswordReset, "", time.Hour))
	require.NoError(t, err)

	// Outlast a 1ms cooldown so the prior token is old enough to invalidate.
	time.Sleep(2 * time.Millisecond)
	retry := req(account, model.PurposePasswordReset, "", time.Hour).WithCooldown(time.Millisecond)
	_, err = db.Create(ctx, retry)
	require.NoError(t, err)
}

func TestCreate_CooldownNilIgnored(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	account := bson.NewObjectID()

	// Without a cooldown, back-to-back creation must succeed.
	for i := 0; i < 2; i++ {
		_, err := db.Create(ctx, req(account, model.PurposePasswordReset, "", time.Hour))
		require.NoError(t, err)
	}
}
func TestCreate_IdempotencyKeyReplayReturnsSameToken(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	accountRef := bson.NewObjectID()

	firstReq := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithIdempotencyKey("same-key")
	firstRaw, err := db.Create(ctx, firstReq)
	require.NoError(t, err)
	require.NotEmpty(t, firstRaw)

	// Replaying with the same idempotency key succeeds and returns the same token.
	secondReq := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithIdempotencyKey("same-key")
	secondRaw, err := db.Create(ctx, secondReq)
	require.NoError(t, err)
	assert.Equal(t, firstRaw, secondRaw)

	repo := db.Repository.(*memoryTokenRepository)
	repo.mu.Lock()
	defer repo.mu.Unlock()
	assert.Len(t, repo.data, 1)
}

func TestCreate_IdempotencyScopeIncludesTarget(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	accountRef := bson.NewObjectID()

	r1 := req(accountRef, model.PurposeEmailChange, "a@example.com", time.Hour).WithIdempotencyKey("same-key")
	raw1, err := db.Create(ctx, r1)
	require.NoError(t, err)
	require.NotEmpty(t, raw1)

	// Same account/purpose/key with a different target is a separate idempotency scope.
	r2 := req(accountRef, model.PurposeEmailChange, "b@example.com", time.Hour).WithIdempotencyKey("same-key")
	raw2, err := db.Create(ctx, r2)
	require.NoError(t, err)
	require.NotEmpty(t, raw2)
	assert.NotEqual(t, raw1, raw2)

	t1, err := db.Consume(ctx, accountRef, model.PurposeEmailChange, raw1)
	require.NoError(t, err)
	assert.Equal(t, "a@example.com", t1.Target)

	t2, err := db.Consume(ctx, accountRef, model.PurposeEmailChange, raw2)
	require.NoError(t, err)
	assert.Equal(t, "b@example.com", t2.Target)
}

func TestCreate_IdempotencySurvivesCallbackRetry(t *testing.T) {
	db := newTestVerificationDBWithFactory(t, &retryingTxFactory{})
	ctx := context.Background()
	accountRef := bson.NewObjectID()

	// The cooldown would block the retried callback if idempotency weren't handled.
	r := req(accountRef, model.PurposePasswordReset, "", time.Hour).
		WithCooldown(time.Minute).
		WithIdempotencyKey("retry-safe")
	raw, err := db.Create(ctx, r)
	require.NoError(t, err)
	require.NotEmpty(t, raw)

	repo := db.Repository.(*memoryTokenRepository)
	repo.mu.Lock()
	// Unlock via defer. A failing require exits through runtime.Goexit and
	// would otherwise leave the repository mutex held.
	defer repo.mu.Unlock()
	require.Len(t, repo.data, 1)
	for _, tok := range repo.data {
		require.NotNil(t, tok.IdempotencyKey)
		assert.Equal(t, "retry-safe", *tok.IdempotencyKey)
		assert.Nil(t, tok.UsedAt)
		assert.Equal(t, tokenHash(raw), tok.VerifyTokenHash)
	}
}
// ---------------------------------------------------------------------------
// max retries / attempts tests
// ---------------------------------------------------------------------------

func TestConsume_MaxRetriesExceeded(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	account := bson.NewObjectID()

	token, err := db.Create(ctx, req(account, model.PurposePasswordReset, "", time.Hour).WithMaxRetries(2))
	require.NoError(t, err)

	// Pretend two failed attempts already happened.
	repo := db.Repository.(*memoryTokenRepository)
	repo.mu.Lock()
	for id := range repo.data {
		repo.data[id].Attempts = 2
	}
	repo.mu.Unlock()

	// Even the correct token is refused once attempts reach the limit.
	_, err = db.Consume(ctx, account, model.PurposePasswordReset, token)
	require.Error(t, err)
	assert.ErrorIs(t, err, verification.ErrTokenAttemptsExceeded)
}

func TestConsume_UnderMaxRetriesSucceeds(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	account := bson.NewObjectID()

	token, err := db.Create(ctx, req(account, model.PurposePasswordReset, "", time.Hour).WithMaxRetries(3))
	require.NoError(t, err)

	// Two prior failures stay under maxRetries=3.
	repo := db.Repository.(*memoryTokenRepository)
	repo.mu.Lock()
	for id := range repo.data {
		repo.data[id].Attempts = 2
	}
	repo.mu.Unlock()

	consumed, err := db.Consume(ctx, account, model.PurposePasswordReset, token)
	require.NoError(t, err)
	assert.Equal(t, account, consumed.AccountRef)
}

func TestConsume_NoMaxRetriesIgnoresAttempts(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	account := bson.NewObjectID()

	// No MaxRetries configured.
	token, err := db.Create(ctx, req(account, model.PurposePasswordReset, "", time.Hour))
	require.NoError(t, err)

	// A large attempt count must not matter when MaxRetries is nil.
	repo := db.Repository.(*memoryTokenRepository)
	repo.mu.Lock()
	for id := range repo.data {
		repo.data[id].Attempts = 100
	}
	repo.mu.Unlock()

	consumed, err := db.Consume(ctx, account, model.PurposePasswordReset, token)
	require.NoError(t, err)
	assert.Equal(t, account, consumed.AccountRef)
}

func TestConsume_WrongHashReturnsNotFound(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	account := bson.NewObjectID()

	_, err := db.Create(ctx, req(account, model.PurposePasswordReset, "", time.Hour))
	require.NoError(t, err)

	// A wrong code hashes to a value no stored token has.
	_, err = db.Consume(ctx, account, model.PurposePasswordReset, "wrong-code")
	require.Error(t, err)
	assert.ErrorIs(t, err, verification.ErrTokenNotFound)
}

func TestConsume_ContextMismatchReturnsNotFound(t *testing.T) {
	db := newTestVerificationDB(t)
	ctx := context.Background()
	owner, stranger := bson.NewObjectID(), bson.NewObjectID()

	token, err := db.Create(ctx, req(owner, model.PurposePasswordReset, "", time.Hour))
	require.NoError(t, err)

	// The right token presented for another account is a context mismatch.
	_, err = db.Consume(ctx, stranger, model.PurposePasswordReset, token)
	require.Error(t, err)
	assert.ErrorIs(t, err, verification.ErrTokenNotFound)
}