1061 lines
31 KiB
Go
1061 lines
31 KiB
Go
package verificationimp
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"strings"
|
||
"sync"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/stretchr/testify/assert"
|
||
"github.com/stretchr/testify/require"
|
||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||
rd "github.com/tech/sendico/pkg/db/repository/decoder"
|
||
ri "github.com/tech/sendico/pkg/db/repository/index"
|
||
"github.com/tech/sendico/pkg/db/storable"
|
||
"github.com/tech/sendico/pkg/db/template"
|
||
"github.com/tech/sendico/pkg/db/transaction"
|
||
"github.com/tech/sendico/pkg/db/verification"
|
||
"github.com/tech/sendico/pkg/merrors"
|
||
"github.com/tech/sendico/pkg/model"
|
||
"github.com/tech/sendico/pkg/mservice"
|
||
"go.mongodb.org/mongo-driver/v2/bson"
|
||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// helpers
|
||
// ---------------------------------------------------------------------------
|
||
|
||
func newTestVerificationDB(t *testing.T) *verificationDB {
|
||
return newTestVerificationDBWithFactory(t, &passthroughTxFactory{})
|
||
}
|
||
|
||
func newTestVerificationDBWithFactory(t *testing.T, tf transaction.Factory) *verificationDB {
|
||
t.Helper()
|
||
repo := newMemoryTokenRepository()
|
||
logger := zap.NewNop()
|
||
return &verificationDB{
|
||
DBImp: template.DBImp[*model.VerificationToken]{
|
||
Logger: logger,
|
||
Repository: repo,
|
||
},
|
||
tf: tf,
|
||
}
|
||
}
|
||
|
||
// passthroughTxFactory executes callbacks directly without a real transaction.
|
||
type passthroughTxFactory struct{}
|
||
|
||
func (*passthroughTxFactory) CreateTransaction() transaction.Transaction { return &passthroughTx{} }
|
||
|
||
type passthroughTx struct{}
|
||
|
||
func (*passthroughTx) Execute(ctx context.Context, cb transaction.Callback) (any, error) {
|
||
return cb(ctx)
|
||
}
|
||
|
||
// retryingTxFactory simulates transaction callbacks being executed more than once.
|
||
type retryingTxFactory struct{}
|
||
|
||
func (*retryingTxFactory) CreateTransaction() transaction.Transaction { return &retryingTx{} }
|
||
|
||
type retryingTx struct{}
|
||
|
||
func (*retryingTx) Execute(ctx context.Context, cb transaction.Callback) (any, error) {
|
||
if _, err := cb(ctx); err != nil {
|
||
return nil, err
|
||
}
|
||
return cb(ctx)
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// in-memory repository for VerificationToken
|
||
// ---------------------------------------------------------------------------
|
||
|
||
type memoryTokenRepository struct {
|
||
mu sync.Mutex
|
||
data map[bson.ObjectID]*model.VerificationToken
|
||
order []bson.ObjectID
|
||
seq int
|
||
}
|
||
|
||
func newMemoryTokenRepository() *memoryTokenRepository {
|
||
return &memoryTokenRepository{data: make(map[bson.ObjectID]*model.VerificationToken)}
|
||
}
|
||
|
||
func (m *memoryTokenRepository) Insert(_ context.Context, obj storable.Storable, _ builder.Query) error {
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
tok, ok := obj.(*model.VerificationToken)
|
||
if !ok {
|
||
return merrors.InvalidDataType("expected VerificationToken")
|
||
}
|
||
id := tok.GetID()
|
||
if id == nil || *id == bson.NilObjectID {
|
||
m.seq++
|
||
tok.SetID(bson.NewObjectID())
|
||
id = tok.GetID()
|
||
}
|
||
if _, exists := m.data[*id]; exists {
|
||
return merrors.DataConflict("token already exists")
|
||
}
|
||
for _, existing := range m.data {
|
||
if existing.VerifyTokenHash == tok.VerifyTokenHash {
|
||
return merrors.DataConflict("duplicate verifyTokenHash")
|
||
}
|
||
if existing.AccountRef != tok.AccountRef {
|
||
continue
|
||
}
|
||
if existing.Purpose != tok.Purpose {
|
||
continue
|
||
}
|
||
if existing.Target != tok.Target {
|
||
continue
|
||
}
|
||
|
||
switch {
|
||
case existing.IdempotencyKey == nil && tok.IdempotencyKey == nil:
|
||
return merrors.DataConflict("duplicate verification context idempotency")
|
||
case existing.IdempotencyKey != nil && tok.IdempotencyKey != nil && *existing.IdempotencyKey == *tok.IdempotencyKey:
|
||
return merrors.DataConflict("duplicate verification context idempotency")
|
||
}
|
||
}
|
||
m.data[*id] = cloneToken(tok)
|
||
m.order = append(m.order, *id)
|
||
return nil
|
||
}
|
||
|
||
func (m *memoryTokenRepository) Get(_ context.Context, id bson.ObjectID, result storable.Storable) error {
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
tok, ok := m.data[id]
|
||
if !ok {
|
||
return merrors.ErrNoData
|
||
}
|
||
dst := result.(*model.VerificationToken)
|
||
*dst = *cloneToken(tok)
|
||
return nil
|
||
}
|
||
|
||
func (m *memoryTokenRepository) FindOneByFilter(_ context.Context, query builder.Query, result storable.Storable) error {
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
for _, id := range m.order {
|
||
tok := m.data[id]
|
||
if tok != nil && matchToken(query, tok) {
|
||
dst := result.(*model.VerificationToken)
|
||
*dst = *cloneToken(tok)
|
||
return nil
|
||
}
|
||
}
|
||
return merrors.ErrNoData
|
||
}
|
||
|
||
func (m *memoryTokenRepository) Update(_ context.Context, obj storable.Storable) error {
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
tok := obj.(*model.VerificationToken)
|
||
id := tok.GetID()
|
||
if id == nil {
|
||
return merrors.InvalidArgument("id required")
|
||
}
|
||
if _, exists := m.data[*id]; !exists {
|
||
return merrors.ErrNoData
|
||
}
|
||
m.data[*id] = cloneToken(tok)
|
||
return nil
|
||
}
|
||
|
||
func (m *memoryTokenRepository) Upsert(ctx context.Context, obj storable.Storable) error {
|
||
id := obj.GetID()
|
||
if id == nil || *id == bson.NilObjectID {
|
||
return m.Insert(ctx, obj, nil)
|
||
}
|
||
if err := m.Update(ctx, obj); err != nil {
|
||
if errors.Is(err, merrors.ErrNoData) {
|
||
return m.Insert(ctx, obj, nil)
|
||
}
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (m *memoryTokenRepository) PatchMany(_ context.Context, filter builder.Query, patch builder.Patch) (int, error) {
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
patchDoc := patch.Build()
|
||
count := 0
|
||
for _, id := range m.order {
|
||
tok := m.data[id]
|
||
if tok != nil && matchToken(filter, tok) {
|
||
applyPatch(tok, patchDoc)
|
||
count++
|
||
}
|
||
}
|
||
return count, nil
|
||
}
|
||
|
||
// stubs — not exercised by verification DB but required by the interface
|
||
|
||
func (m *memoryTokenRepository) Aggregate(context.Context, builder.Pipeline, rd.DecodingFunc) error {
|
||
return merrors.NotImplemented("not needed")
|
||
}
|
||
func (m *memoryTokenRepository) InsertMany(ctx context.Context, objs []storable.Storable) error {
|
||
for _, o := range objs {
|
||
if err := m.Insert(ctx, o, nil); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
func (m *memoryTokenRepository) FindManyByFilter(_ context.Context, query builder.Query, decoder rd.DecodingFunc) error {
|
||
m.mu.Lock()
|
||
var matches []interface{}
|
||
for _, id := range m.order {
|
||
tok := m.data[id]
|
||
if tok != nil && matchToken(query, tok) {
|
||
raw, err := bson.Marshal(cloneToken(tok))
|
||
if err != nil {
|
||
m.mu.Unlock()
|
||
return err
|
||
}
|
||
matches = append(matches, bson.Raw(raw))
|
||
}
|
||
}
|
||
m.mu.Unlock()
|
||
|
||
cur, err := mongo.NewCursorFromDocuments(matches, nil, nil)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer cur.Close(context.Background())
|
||
|
||
for cur.Next(context.Background()) {
|
||
if err := decoder(cur); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
func (m *memoryTokenRepository) Patch(context.Context, bson.ObjectID, builder.Patch) error {
|
||
return merrors.NotImplemented("not needed")
|
||
}
|
||
func (m *memoryTokenRepository) Delete(_ context.Context, id bson.ObjectID) error {
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
delete(m.data, id)
|
||
return nil
|
||
}
|
||
func (m *memoryTokenRepository) DeleteMany(context.Context, builder.Query) error {
|
||
return merrors.NotImplemented("not needed")
|
||
}
|
||
// CreateIndex is a no-op. The fake enforces its uniqueness constraints
// directly in Insert.
func (m *memoryTokenRepository) CreateIndex(_ *ri.Definition) error {
	return nil
}
|
||
func (m *memoryTokenRepository) ListIDs(context.Context, builder.Query) ([]bson.ObjectID, error) {
|
||
return nil, merrors.NotImplemented("not needed")
|
||
}
|
||
func (m *memoryTokenRepository) ListPermissionBound(context.Context, builder.Query) ([]model.PermissionBoundStorable, error) {
|
||
return nil, merrors.NotImplemented("not needed")
|
||
}
|
||
func (m *memoryTokenRepository) ListAccountBound(context.Context, builder.Query) ([]model.AccountBoundStorable, error) {
|
||
return nil, merrors.NotImplemented("not needed")
|
||
}
|
||
// Collection reports the name of the collection this fake stands in for.
func (m *memoryTokenRepository) Collection() string {
	return mservice.VerificationTokens
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// bson.D query evaluation for VerificationToken
|
||
// ---------------------------------------------------------------------------
|
||
|
||
// tokenFieldValue returns the stored value for a given BSON field name.
|
||
func tokenFieldValue(tok *model.VerificationToken, field string) any {
|
||
switch field {
|
||
case "_id":
|
||
return tok.ID
|
||
case "createdAt":
|
||
return tok.CreatedAt
|
||
case "verifyTokenHash":
|
||
return tok.VerifyTokenHash
|
||
case "salt":
|
||
return tok.Salt
|
||
case "usedAt":
|
||
return tok.UsedAt
|
||
case "expiresAt":
|
||
return tok.ExpiresAt
|
||
case "accountRef":
|
||
return tok.AccountRef
|
||
case "purpose":
|
||
return tok.Purpose
|
||
case "target":
|
||
return tok.Target
|
||
case "idempotencyKey":
|
||
if tok.IdempotencyKey == nil {
|
||
return nil
|
||
}
|
||
return *tok.IdempotencyKey
|
||
case "maxRetries":
|
||
return tok.MaxRetries
|
||
case "attempts":
|
||
return tok.Attempts
|
||
default:
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// matchToken evaluates a bson.D filter against a token.
|
||
func matchToken(query builder.Query, tok *model.VerificationToken) bool {
|
||
if query == nil {
|
||
return true
|
||
}
|
||
return matchBsonD(query.BuildQuery(), tok)
|
||
}
|
||
|
||
func matchBsonD(filter bson.D, tok *model.VerificationToken) bool {
|
||
for _, elem := range filter {
|
||
if !matchElem(elem, tok) {
|
||
return false
|
||
}
|
||
}
|
||
return true
|
||
}
|
||
|
||
func matchElem(elem bson.E, tok *model.VerificationToken) bool {
|
||
switch elem.Key {
|
||
|
||
case "$and":
|
||
arr, ok := elem.Value.(bson.A)
|
||
if !ok {
|
||
return false
|
||
}
|
||
for _, sub := range arr {
|
||
d, ok := sub.(bson.D)
|
||
if !ok {
|
||
return false
|
||
}
|
||
if !matchBsonD(d, tok) {
|
||
return false
|
||
}
|
||
}
|
||
return true
|
||
|
||
default:
|
||
// Either a direct field match or a comparison operator doc.
|
||
stored := tokenFieldValue(tok, elem.Key)
|
||
|
||
// Check for operator document like {$gt: value}
|
||
if opDoc, ok := elem.Value.(bson.M); ok {
|
||
return matchOperator(stored, opDoc)
|
||
}
|
||
|
||
// Direct equality (including nil check).
|
||
return valuesEqual(stored, elem.Value)
|
||
}
|
||
}
|
||
|
||
func matchOperator(stored any, ops bson.M) bool {
|
||
for op, cmpVal := range ops {
|
||
switch op {
|
||
case "$gt":
|
||
if !cmpGt(stored, cmpVal) {
|
||
return false
|
||
}
|
||
case "$lt":
|
||
if !cmpLt(stored, cmpVal) {
|
||
return false
|
||
}
|
||
}
|
||
}
|
||
return true
|
||
}
|
||
|
||
func cmpGt(stored, cmpVal any) bool {
|
||
if si, ok := toInt(stored); ok {
|
||
if ci, ok := toInt(cmpVal); ok {
|
||
return si > ci
|
||
}
|
||
}
|
||
return timeGt(stored, cmpVal)
|
||
}
|
||
|
||
func cmpLt(stored, cmpVal any) bool {
|
||
if si, ok := toInt(stored); ok {
|
||
if ci, ok := toInt(cmpVal); ok {
|
||
return si < ci
|
||
}
|
||
}
|
||
return timeLt(stored, cmpVal)
|
||
}
|
||
|
||
// toInt converts v to an int when it is a signed integer (int, int8, int16,
// int32, int64) or a non-nil pointer to int, int32 or int64. For anything
// else, including nil pointers, it returns (0, false).
//
// Pointer support lets pointer-typed token fields (e.g. maxRetries) take part
// in numeric comparisons instead of silently failing them.
// int64 values are narrowed with int(...), as before.
func toInt(v any) (int, bool) {
	switch iv := v.(type) {
	case int:
		return iv, true
	case int8:
		return int(iv), true
	case int16:
		return int(iv), true
	case int32:
		return int(iv), true
	case int64:
		return int(iv), true
	case *int:
		if iv != nil {
			return *iv, true
		}
	case *int32:
		if iv != nil {
			return int(*iv), true
		}
	case *int64:
		if iv != nil {
			return int(*iv), true
		}
	}
	return 0, false
}
|
||
|
||
func valuesEqual(a, b any) bool {
|
||
// nil checks: usedAt == nil
|
||
if b == nil {
|
||
return a == nil || a == (*time.Time)(nil)
|
||
}
|
||
switch av := a.(type) {
|
||
case *time.Time:
|
||
if av == nil {
|
||
return b == nil
|
||
}
|
||
if bv, ok := b.(*time.Time); ok {
|
||
return av.Equal(*bv)
|
||
}
|
||
return false
|
||
case bson.ObjectID:
|
||
if bv, ok := b.(bson.ObjectID); ok {
|
||
return av == bv
|
||
}
|
||
return false
|
||
case model.VerificationPurpose:
|
||
if bv, ok := b.(model.VerificationPurpose); ok {
|
||
return av == bv
|
||
}
|
||
return false
|
||
case string:
|
||
if bv, ok := b.(string); ok {
|
||
return av == bv
|
||
}
|
||
return false
|
||
}
|
||
return false
|
||
}
|
||
|
||
func timeGt(stored, cmpVal any) bool {
|
||
st, ok := toTime(stored)
|
||
if !ok {
|
||
return false
|
||
}
|
||
ct, ok := toTime(cmpVal)
|
||
if !ok {
|
||
return false
|
||
}
|
||
return st.After(ct)
|
||
}
|
||
|
||
func timeLt(stored, cmpVal any) bool {
|
||
st, ok := toTime(stored)
|
||
if !ok {
|
||
return false
|
||
}
|
||
ct, ok := toTime(cmpVal)
|
||
if !ok {
|
||
return false
|
||
}
|
||
return st.Before(ct)
|
||
}
|
||
|
||
// toTime unwraps v when it is a time.Time or a non-nil *time.Time. Any other
// value, including a nil *time.Time, yields the zero time and false.
func toTime(v any) (time.Time, bool) {
	if ptr, isPtr := v.(*time.Time); isPtr {
		if ptr == nil {
			return time.Time{}, false
		}
		return *ptr, true
	}

	val, isVal := v.(time.Time)
	return val, isVal
}
|
||
|
||
// applyPatch applies $set and $inc operations from a patch bson.D to a token.
|
||
func applyPatch(tok *model.VerificationToken, patchDoc bson.D) {
|
||
for _, op := range patchDoc {
|
||
switch op.Key {
|
||
case "$set":
|
||
fields, ok := op.Value.(bson.D)
|
||
if !ok {
|
||
continue
|
||
}
|
||
for _, f := range fields {
|
||
switch f.Key {
|
||
case "usedAt":
|
||
if t, ok := f.Value.(time.Time); ok {
|
||
tok.UsedAt = &t
|
||
}
|
||
}
|
||
}
|
||
case "$inc":
|
||
fields, ok := op.Value.(bson.D)
|
||
if !ok {
|
||
continue
|
||
}
|
||
for _, f := range fields {
|
||
switch f.Key {
|
||
case "attempts":
|
||
if v, ok := f.Value.(int); ok {
|
||
tok.Attempts += v
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
func cloneToken(src *model.VerificationToken) *model.VerificationToken {
|
||
dst := *src
|
||
if src.UsedAt != nil {
|
||
t := *src.UsedAt
|
||
dst.UsedAt = &t
|
||
}
|
||
if src.MaxRetries != nil {
|
||
v := *src.MaxRetries
|
||
dst.MaxRetries = &v
|
||
}
|
||
if src.Salt != nil {
|
||
s := *src.Salt
|
||
dst.Salt = &s
|
||
}
|
||
if src.IdempotencyKey != nil {
|
||
k := *src.IdempotencyKey
|
||
dst.IdempotencyKey = &k
|
||
}
|
||
return &dst
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// helpers – request builder
|
||
// ---------------------------------------------------------------------------
|
||
|
||
func req(accountRef bson.ObjectID, purpose model.VerificationPurpose, target string, ttl time.Duration) *verification.Request {
|
||
return verification.NewLinkRequest(accountRef, purpose, target).WithTTL(ttl)
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// tests
|
||
// ---------------------------------------------------------------------------
|
||
|
||
func TestCreate_ReturnsRawToken(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
assert.NotEmpty(t, raw)
|
||
}
|
||
|
||
func TestCreate_TokenCanBeConsumed(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, accountRef, tok.AccountRef)
|
||
assert.Equal(t, model.PurposePasswordReset, tok.Purpose)
|
||
assert.NotNil(t, tok.UsedAt)
|
||
}
|
||
|
||
func TestConsume_ReturnsCorrectFields(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
raw, err := db.Create(ctx, req(accountRef, model.PurposeEmailChange, "new@example.com", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
tok, err := db.Consume(ctx, accountRef, model.PurposeEmailChange, raw)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, accountRef, tok.AccountRef)
|
||
assert.Equal(t, model.PurposeEmailChange, tok.Purpose)
|
||
assert.Equal(t, "new@example.com", tok.Target)
|
||
}
|
||
|
||
func TestConsume_SecondConsumeFailsAlreadyUsed(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
||
require.NoError(t, err)
|
||
|
||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
||
require.Error(t, err)
|
||
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed),
|
||
"second consume should fail with already-used after usedAt is set")
|
||
}
|
||
|
||
func TestConsume_ExpiredTokenFails(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
// Create with a TTL that is already in the past.
|
||
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", -time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
||
require.Error(t, err)
|
||
assert.True(t, errors.Is(err, verification.ErrTokenExpired),
|
||
"expired token should return explicit expiry error")
|
||
}
|
||
|
||
func TestConsume_UnknownTokenFails(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
|
||
_, err := db.Consume(ctx, bson.NilObjectID, "", "nonexistent-token-value")
|
||
require.Error(t, err)
|
||
assert.True(t, errors.Is(err, verification.ErrTokenNotFound))
|
||
}
|
||
|
||
func TestConsume_AccountActivationWithoutAccountRef(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
raw, err := db.Create(ctx, req(accountRef, model.PurposeAccountActivation, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
tok, err := db.Consume(ctx, bson.NilObjectID, model.PurposeAccountActivation, raw)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, accountRef, tok.AccountRef)
|
||
assert.Equal(t, model.PurposeAccountActivation, tok.Purpose)
|
||
}
|
||
|
||
func TestCreate_InvalidatesPreviousToken(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
oldRaw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
newRaw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
assert.NotEqual(t, oldRaw, newRaw, "new token should differ from old one")
|
||
|
||
// Old token is no longer consumable — invalidated (usedAt set) by the second Create.
|
||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, oldRaw)
|
||
require.Error(t, err)
|
||
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed),
|
||
"old token should return already-used after invalidation")
|
||
|
||
// New token works fine.
|
||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, newRaw)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, accountRef, tok.AccountRef)
|
||
}
|
||
|
||
func TestCreate_AccountActivationResendWithoutIdempotency_ReissuesToken(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
// First issue during signup.
|
||
firstRaw, err := db.Create(ctx, req(accountRef, model.PurposeAccountActivation, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
// Second issue during resend should rotate token instead of failing with duplicate key.
|
||
secondRaw, err := db.Create(ctx, req(accountRef, model.PurposeAccountActivation, "", time.Hour))
|
||
require.NoError(t, err)
|
||
assert.NotEqual(t, firstRaw, secondRaw)
|
||
|
||
// Old token becomes unusable after reissue.
|
||
_, err = db.Consume(ctx, bson.NilObjectID, model.PurposeAccountActivation, firstRaw)
|
||
require.Error(t, err)
|
||
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed))
|
||
|
||
// New token is valid.
|
||
tok, err := db.Consume(ctx, bson.NilObjectID, model.PurposeAccountActivation, secondRaw)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, accountRef, tok.AccountRef)
|
||
assert.Equal(t, model.PurposeAccountActivation, tok.Purpose)
|
||
|
||
// Non-idempotent requests should still persist unique internal keys,
|
||
// preventing uniqueness collisions on (accountRef, purpose, target, idempotencyKey).
|
||
repo := db.Repository.(*memoryTokenRepository)
|
||
repo.mu.Lock()
|
||
defer repo.mu.Unlock()
|
||
|
||
keys := map[string]struct{}{}
|
||
for _, stored := range repo.data {
|
||
if stored.AccountRef != accountRef || stored.Purpose != model.PurposeAccountActivation {
|
||
continue
|
||
}
|
||
require.NotNil(t, stored.IdempotencyKey)
|
||
assert.True(t, strings.HasPrefix(*stored.IdempotencyKey, "auto:"))
|
||
keys[*stored.IdempotencyKey] = struct{}{}
|
||
}
|
||
assert.Len(t, keys, 2)
|
||
}
|
||
|
||
func TestCreate_InvalidatesMultiplePreviousTokens(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
first, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
second, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
third, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, first)
|
||
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed), "first should be invalidated/used")
|
||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, second)
|
||
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed), "second should be invalidated/used")
|
||
|
||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, third)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, accountRef, tok.AccountRef)
|
||
}
|
||
|
||
func TestCreate_DifferentPurposeNotInvalidated(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
resetRaw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
// Creating an activation token should NOT invalidate the password-reset token.
|
||
_, err = db.Create(ctx, req(accountRef, model.PurposeAccountActivation, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, resetRaw)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, model.PurposePasswordReset, tok.Purpose)
|
||
}
|
||
|
||
func TestCreate_DifferentTargetNotInvalidated(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
firstRaw, err := db.Create(ctx, req(accountRef, model.PurposeEmailChange, "a@example.com", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
// Creating a token for a different target email should NOT invalidate the first.
|
||
_, err = db.Create(ctx, req(accountRef, model.PurposeEmailChange, "b@example.com", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
tok, err := db.Consume(ctx, accountRef, model.PurposeEmailChange, firstRaw)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, "a@example.com", tok.Target)
|
||
}
|
||
|
||
func TestCreate_DifferentAccountNotInvalidated(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
account1 := bson.NewObjectID()
|
||
account2 := bson.NewObjectID()
|
||
|
||
raw1, err := db.Create(ctx, req(account1, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
_, err = db.Create(ctx, req(account2, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
tok, err := db.Consume(ctx, account1, model.PurposePasswordReset, raw1)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, account1, tok.AccountRef)
|
||
}
|
||
|
||
func TestCreate_AlreadyUsedTokenNotInvalidatedAgain(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
// Create and consume first token.
|
||
raw1, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw1)
|
||
require.NoError(t, err)
|
||
|
||
// Create second — the already-consumed token should have usedAt set,
|
||
// so the invalidation query (usedAt == nil) should skip it.
|
||
// This tests that the PatchMany filter correctly excludes already-used tokens.
|
||
raw2, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, raw2)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, accountRef, tok.AccountRef)
|
||
}
|
||
|
||
func TestCreate_ExpiredTokenNotInvalidated(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
// Create a token that is already expired.
|
||
_, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", -time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
// Create a fresh one — invalidation should skip the expired token (expiresAt > now filter).
|
||
raw2, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, raw2)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, accountRef, tok.AccountRef)
|
||
}
|
||
|
||
func TestTokenHash_Deterministic(t *testing.T) {
|
||
h1 := tokenHash("same-input")
|
||
h2 := tokenHash("same-input")
|
||
assert.Equal(t, h1, h2)
|
||
}
|
||
|
||
func TestTokenHash_DifferentInputs(t *testing.T) {
|
||
h1 := tokenHash("input-a")
|
||
h2 := tokenHash("input-b")
|
||
assert.NotEqual(t, h1, h2)
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// cooldown tests
|
||
// ---------------------------------------------------------------------------
|
||
|
||
func TestCreate_CooldownBlocksCreation(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
// First creation without cooldown.
|
||
_, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
// Immediate re-create with cooldown should be blocked — token is too recent to invalidate.
|
||
r2 := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithCooldown(time.Minute)
|
||
_, err = db.Create(ctx, r2)
|
||
require.Error(t, err)
|
||
assert.True(t, errors.Is(err, verification.ErrCooldownActive))
|
||
}
|
||
|
||
func TestCreate_CooldownExpiresAllowsCreation(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
// First creation without cooldown.
|
||
firstRaw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
time.Sleep(2 * time.Millisecond)
|
||
|
||
// Re-create with short cooldown — the prior token is old enough to be invalidated.
|
||
r2 := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithCooldown(time.Millisecond)
|
||
secondRaw, err := db.Create(ctx, r2)
|
||
require.NoError(t, err)
|
||
assert.NotEqual(t, firstRaw, secondRaw)
|
||
|
||
// Old token should be rotated out after successful re-issue.
|
||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, firstRaw)
|
||
require.Error(t, err)
|
||
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed))
|
||
|
||
// New token remains valid.
|
||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, secondRaw)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, accountRef, tok.AccountRef)
|
||
}
|
||
|
||
func TestCreate_CooldownNilIgnored(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
_, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
// No cooldown set — immediate re-create should succeed.
|
||
_, err = db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
}
|
||
|
||
func TestCreate_IdempotencyKeyReplayReturnsSameToken(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
firstReq := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithIdempotencyKey("same-key")
|
||
firstRaw, err := db.Create(ctx, firstReq)
|
||
require.NoError(t, err)
|
||
require.NotEmpty(t, firstRaw)
|
||
|
||
// Replay with the same idempotency key should return success and same token.
|
||
secondReq := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithIdempotencyKey("same-key")
|
||
secondRaw, err := db.Create(ctx, secondReq)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, firstRaw, secondRaw)
|
||
|
||
repo := db.Repository.(*memoryTokenRepository)
|
||
repo.mu.Lock()
|
||
assert.Len(t, repo.data, 1)
|
||
repo.mu.Unlock()
|
||
}
|
||
|
||
func TestCreate_IdempotencyScopeIncludesTarget(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
r1 := req(accountRef, model.PurposeEmailChange, "a@example.com", time.Hour).WithIdempotencyKey("same-key")
|
||
raw1, err := db.Create(ctx, r1)
|
||
require.NoError(t, err)
|
||
require.NotEmpty(t, raw1)
|
||
|
||
// Same account/purpose/key but different target should be treated as a different idempotency scope.
|
||
r2 := req(accountRef, model.PurposeEmailChange, "b@example.com", time.Hour).WithIdempotencyKey("same-key")
|
||
raw2, err := db.Create(ctx, r2)
|
||
require.NoError(t, err)
|
||
require.NotEmpty(t, raw2)
|
||
assert.NotEqual(t, raw1, raw2)
|
||
|
||
t1, err := db.Consume(ctx, accountRef, model.PurposeEmailChange, raw1)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, "a@example.com", t1.Target)
|
||
|
||
t2, err := db.Consume(ctx, accountRef, model.PurposeEmailChange, raw2)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, "b@example.com", t2.Target)
|
||
}
|
||
|
||
func TestCreate_IdempotencySurvivesCallbackRetry(t *testing.T) {
|
||
db := newTestVerificationDBWithFactory(t, &retryingTxFactory{})
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
// Cooldown would block the second callback execution if idempotency wasn't handled.
|
||
r := req(accountRef, model.PurposePasswordReset, "", time.Hour).
|
||
WithCooldown(time.Minute).
|
||
WithIdempotencyKey("retry-safe")
|
||
|
||
raw, err := db.Create(ctx, r)
|
||
require.NoError(t, err)
|
||
require.NotEmpty(t, raw)
|
||
|
||
repo := db.Repository.(*memoryTokenRepository)
|
||
repo.mu.Lock()
|
||
require.Len(t, repo.data, 1)
|
||
for _, tok := range repo.data {
|
||
require.NotNil(t, tok.IdempotencyKey)
|
||
assert.Equal(t, "retry-safe", *tok.IdempotencyKey)
|
||
assert.Nil(t, tok.UsedAt)
|
||
assert.Equal(t, tok.VerifyTokenHash, tokenHash(raw))
|
||
}
|
||
repo.mu.Unlock()
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// max retries / attempts tests
|
||
// ---------------------------------------------------------------------------
|
||
|
||
func TestConsume_MaxRetriesExceeded(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
r := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithMaxRetries(2)
|
||
raw, err := db.Create(ctx, r)
|
||
require.NoError(t, err)
|
||
|
||
// Simulate 2 prior failed attempts by setting Attempts directly.
|
||
repo := db.Repository.(*memoryTokenRepository)
|
||
repo.mu.Lock()
|
||
for _, tok := range repo.data {
|
||
tok.Attempts = 2
|
||
}
|
||
repo.mu.Unlock()
|
||
|
||
// Consume with correct token should fail — attempts already at max.
|
||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
||
require.Error(t, err)
|
||
assert.True(t, errors.Is(err, verification.ErrTokenAttemptsExceeded))
|
||
}
|
||
|
||
func TestConsume_UnderMaxRetriesSucceeds(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
r := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithMaxRetries(3)
|
||
raw, err := db.Create(ctx, r)
|
||
require.NoError(t, err)
|
||
|
||
// Simulate 2 prior failed attempts (under maxRetries=3).
|
||
repo := db.Repository.(*memoryTokenRepository)
|
||
repo.mu.Lock()
|
||
for _, tok := range repo.data {
|
||
tok.Attempts = 2
|
||
}
|
||
repo.mu.Unlock()
|
||
|
||
// Consume with correct token should succeed.
|
||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, accountRef, tok.AccountRef)
|
||
}
|
||
|
||
func TestConsume_NoMaxRetriesIgnoresAttempts(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
// Create without MaxRetries.
|
||
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
// Simulate high attempt count — should be ignored since MaxRetries is nil.
|
||
repo := db.Repository.(*memoryTokenRepository)
|
||
repo.mu.Lock()
|
||
for _, tok := range repo.data {
|
||
tok.Attempts = 100
|
||
}
|
||
repo.mu.Unlock()
|
||
|
||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, accountRef, tok.AccountRef)
|
||
}
|
||
|
||
func TestConsume_WrongHashReturnsNotFound(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
|
||
_, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
// Wrong code — hash won't match any token.
|
||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, "wrong-code")
|
||
require.Error(t, err)
|
||
assert.True(t, errors.Is(err, verification.ErrTokenNotFound))
|
||
}
|
||
|
||
func TestConsume_ContextMismatchReturnsNotFound(t *testing.T) {
|
||
db := newTestVerificationDB(t)
|
||
ctx := context.Background()
|
||
accountRef := bson.NewObjectID()
|
||
otherAccount := bson.NewObjectID()
|
||
|
||
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||
require.NoError(t, err)
|
||
|
||
// Correct token but wrong accountRef — context mismatch.
|
||
_, err = db.Consume(ctx, otherAccount, model.PurposePasswordReset, raw)
|
||
require.Error(t, err)
|
||
assert.True(t, errors.Is(err, verification.ErrTokenNotFound))
|
||
}
|