unified code verification service
This commit is contained in:
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@@ -28,6 +29,10 @@ import (
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func newTestVerificationDB(t *testing.T) *verificationDB {
|
||||
return newTestVerificationDBWithFactory(t, &passthroughTxFactory{})
|
||||
}
|
||||
|
||||
func newTestVerificationDBWithFactory(t *testing.T, tf transaction.Factory) *verificationDB {
|
||||
t.Helper()
|
||||
repo := newMemoryTokenRepository()
|
||||
logger := zap.NewNop()
|
||||
@@ -36,7 +41,7 @@ func newTestVerificationDB(t *testing.T) *verificationDB {
|
||||
Logger: logger,
|
||||
Repository: repo,
|
||||
},
|
||||
tf: &passthroughTxFactory{},
|
||||
tf: tf,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,6 +56,20 @@ func (*passthroughTx) Execute(ctx context.Context, cb transaction.Callback) (any
|
||||
return cb(ctx)
|
||||
}
|
||||
|
||||
// retryingTxFactory simulates transaction callbacks being executed more than once.
|
||||
type retryingTxFactory struct{}
|
||||
|
||||
func (*retryingTxFactory) CreateTransaction() transaction.Transaction { return &retryingTx{} }
|
||||
|
||||
type retryingTx struct{}
|
||||
|
||||
func (*retryingTx) Execute(ctx context.Context, cb transaction.Callback) (any, error) {
|
||||
if _, err := cb(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cb(ctx)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// in-memory repository for VerificationToken
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -156,8 +175,34 @@ func (m *memoryTokenRepository) InsertMany(ctx context.Context, objs []storable.
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *memoryTokenRepository) FindManyByFilter(context.Context, builder.Query, rd.DecodingFunc) error {
|
||||
return merrors.NotImplemented("not needed")
|
||||
func (m *memoryTokenRepository) FindManyByFilter(_ context.Context, query builder.Query, decoder rd.DecodingFunc) error {
|
||||
m.mu.Lock()
|
||||
var matches []interface{}
|
||||
for _, id := range m.order {
|
||||
tok := m.data[id]
|
||||
if tok != nil && matchToken(query, tok) {
|
||||
raw, err := bson.Marshal(cloneToken(tok))
|
||||
if err != nil {
|
||||
m.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
matches = append(matches, bson.Raw(raw))
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
cur, err := mongo.NewCursorFromDocuments(matches, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cur.Close(context.Background())
|
||||
|
||||
for cur.Next(context.Background()) {
|
||||
if err := decoder(cur); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *memoryTokenRepository) Patch(context.Context, bson.ObjectID, builder.Patch) error {
|
||||
return merrors.NotImplemented("not needed")
|
||||
@@ -190,8 +235,14 @@ func (m *memoryTokenRepository) Collection() string { return mservice.Verificati
|
||||
// tokenFieldValue returns the stored value for a given BSON field name.
|
||||
func tokenFieldValue(tok *model.VerificationToken, field string) any {
|
||||
switch field {
|
||||
case "_id":
|
||||
return tok.ID
|
||||
case "createdAt":
|
||||
return tok.CreatedAt
|
||||
case "verifyTokenHash":
|
||||
return tok.VerifyTokenHash
|
||||
case "salt":
|
||||
return tok.Salt
|
||||
case "usedAt":
|
||||
return tok.UsedAt
|
||||
case "expiresAt":
|
||||
@@ -202,6 +253,15 @@ func tokenFieldValue(tok *model.VerificationToken, field string) any {
|
||||
return tok.Purpose
|
||||
case "target":
|
||||
return tok.Target
|
||||
case "idempotencyKey":
|
||||
if tok.IdempotencyKey == nil {
|
||||
return nil
|
||||
}
|
||||
return *tok.IdempotencyKey
|
||||
case "maxRetries":
|
||||
return tok.MaxRetries
|
||||
case "attempts":
|
||||
return tok.Attempts
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@@ -261,11 +321,11 @@ func matchOperator(stored any, ops bson.M) bool {
|
||||
for op, cmpVal := range ops {
|
||||
switch op {
|
||||
case "$gt":
|
||||
if !timeGt(stored, cmpVal) {
|
||||
if !cmpGt(stored, cmpVal) {
|
||||
return false
|
||||
}
|
||||
case "$lt":
|
||||
if !timeLt(stored, cmpVal) {
|
||||
if !cmpLt(stored, cmpVal) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -273,6 +333,36 @@ func matchOperator(stored any, ops bson.M) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func cmpGt(stored, cmpVal any) bool {
|
||||
if si, ok := toInt(stored); ok {
|
||||
if ci, ok := toInt(cmpVal); ok {
|
||||
return si > ci
|
||||
}
|
||||
}
|
||||
return timeGt(stored, cmpVal)
|
||||
}
|
||||
|
||||
func cmpLt(stored, cmpVal any) bool {
|
||||
if si, ok := toInt(stored); ok {
|
||||
if ci, ok := toInt(cmpVal); ok {
|
||||
return si < ci
|
||||
}
|
||||
}
|
||||
return timeLt(stored, cmpVal)
|
||||
}
|
||||
|
||||
func toInt(v any) (int, bool) {
|
||||
switch iv := v.(type) {
|
||||
case int:
|
||||
return iv, true
|
||||
case int64:
|
||||
return int(iv), true
|
||||
case int32:
|
||||
return int(iv), true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func valuesEqual(a, b any) bool {
|
||||
// nil checks: usedAt == nil
|
||||
if b == nil {
|
||||
@@ -343,21 +433,34 @@ func toTime(v any) (time.Time, bool) {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
// applyPatch applies $set operations from a patch bson.D to a token.
|
||||
// applyPatch applies $set and $inc operations from a patch bson.D to a token.
|
||||
func applyPatch(tok *model.VerificationToken, patchDoc bson.D) {
|
||||
for _, op := range patchDoc {
|
||||
if op.Key != "$set" {
|
||||
continue
|
||||
}
|
||||
fields, ok := op.Value.(bson.D)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, f := range fields {
|
||||
switch f.Key {
|
||||
case "usedAt":
|
||||
if t, ok := f.Value.(time.Time); ok {
|
||||
tok.UsedAt = &t
|
||||
switch op.Key {
|
||||
case "$set":
|
||||
fields, ok := op.Value.(bson.D)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, f := range fields {
|
||||
switch f.Key {
|
||||
case "usedAt":
|
||||
if t, ok := f.Value.(time.Time); ok {
|
||||
tok.UsedAt = &t
|
||||
}
|
||||
}
|
||||
}
|
||||
case "$inc":
|
||||
fields, ok := op.Value.(bson.D)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, f := range fields {
|
||||
switch f.Key {
|
||||
case "attempts":
|
||||
if v, ok := f.Value.(int); ok {
|
||||
tok.Attempts += v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -370,20 +473,27 @@ func cloneToken(src *model.VerificationToken) *model.VerificationToken {
|
||||
t := *src.UsedAt
|
||||
dst.UsedAt = &t
|
||||
}
|
||||
if src.MaxRetries != nil {
|
||||
v := *src.MaxRetries
|
||||
dst.MaxRetries = &v
|
||||
}
|
||||
if src.Salt != nil {
|
||||
s := *src.Salt
|
||||
dst.Salt = &s
|
||||
}
|
||||
if src.IdempotencyKey != nil {
|
||||
k := *src.IdempotencyKey
|
||||
dst.IdempotencyKey = &k
|
||||
}
|
||||
return &dst
|
||||
}
|
||||
|
||||
// allTokens returns every stored token for inspection in tests.
|
||||
func (m *memoryTokenRepository) allTokens() []*model.VerificationToken {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
out := make([]*model.VerificationToken, 0, len(m.data))
|
||||
for _, id := range m.order {
|
||||
if tok, ok := m.data[id]; ok {
|
||||
out = append(out, cloneToken(tok))
|
||||
}
|
||||
}
|
||||
return out
|
||||
// ---------------------------------------------------------------------------
|
||||
// helpers – request builder
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func req(accountRef bson.ObjectID, purpose model.VerificationPurpose, target string, ttl time.Duration) *verification.Request {
|
||||
return verification.NewLinkRequest(accountRef, purpose, target).WithTTL(ttl)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -395,7 +505,7 @@ func TestCreate_ReturnsRawToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
raw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
||||
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, raw)
|
||||
}
|
||||
@@ -405,10 +515,10 @@ func TestCreate_TokenCanBeConsumed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
raw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
||||
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
tok, err := db.Consume(ctx, raw)
|
||||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accountRef, tok.AccountRef)
|
||||
assert.Equal(t, model.PurposePasswordReset, tok.Purpose)
|
||||
@@ -420,10 +530,10 @@ func TestConsume_ReturnsCorrectFields(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
raw, err := db.Create(ctx, accountRef, model.PurposeEmailChange, "new@example.com", time.Hour)
|
||||
raw, err := db.Create(ctx, req(accountRef, model.PurposeEmailChange, "new@example.com", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
tok, err := db.Consume(ctx, raw)
|
||||
tok, err := db.Consume(ctx, accountRef, model.PurposeEmailChange, raw)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accountRef, tok.AccountRef)
|
||||
assert.Equal(t, model.PurposeEmailChange, tok.Purpose)
|
||||
@@ -435,16 +545,16 @@ func TestConsume_SecondConsumeFailsAlreadyUsed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
raw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
||||
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Consume(ctx, raw)
|
||||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Consume(ctx, raw)
|
||||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed),
|
||||
"second consume should fail because usedAt is set")
|
||||
assert.True(t, errors.Is(err, verification.ErrTokenNotFound),
|
||||
"second consume should fail — used tokens are excluded from active filter")
|
||||
}
|
||||
|
||||
func TestConsume_ExpiredTokenFails(t *testing.T) {
|
||||
@@ -453,20 +563,20 @@ func TestConsume_ExpiredTokenFails(t *testing.T) {
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
// Create with a TTL that is already in the past.
|
||||
raw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", -time.Hour)
|
||||
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", -time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Consume(ctx, raw)
|
||||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, verification.ErrTokenExpired),
|
||||
"expired token should not be consumable")
|
||||
assert.True(t, errors.Is(err, verification.ErrTokenNotFound),
|
||||
"expired token is excluded from active filter")
|
||||
}
|
||||
|
||||
func TestConsume_UnknownTokenFails(t *testing.T) {
|
||||
db := newTestVerificationDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := db.Consume(ctx, "nonexistent-token-value")
|
||||
_, err := db.Consume(ctx, bson.NilObjectID, "", "nonexistent-token-value")
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, verification.ErrTokenNotFound))
|
||||
}
|
||||
@@ -476,21 +586,21 @@ func TestCreate_InvalidatesPreviousToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
oldRaw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
||||
oldRaw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
newRaw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
||||
newRaw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, oldRaw, newRaw, "new token should differ from old one")
|
||||
|
||||
// Old token is no longer consumable.
|
||||
_, err = db.Consume(ctx, oldRaw)
|
||||
// Old token is no longer consumable — invalidated (usedAt set) by the second Create.
|
||||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, oldRaw)
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed),
|
||||
"old token should be invalidated (usedAt set) after new token creation")
|
||||
assert.True(t, errors.Is(err, verification.ErrTokenNotFound),
|
||||
"old token should be invalidated after new token creation")
|
||||
|
||||
// New token works fine.
|
||||
tok, err := db.Consume(ctx, newRaw)
|
||||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, newRaw)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accountRef, tok.AccountRef)
|
||||
}
|
||||
@@ -500,19 +610,19 @@ func TestCreate_InvalidatesMultiplePreviousTokens(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
first, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
||||
first, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
second, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
||||
second, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
third, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
||||
third, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Consume(ctx, first)
|
||||
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed), "first should be invalidated")
|
||||
_, err = db.Consume(ctx, second)
|
||||
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed), "second should be invalidated")
|
||||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, first)
|
||||
assert.True(t, errors.Is(err, verification.ErrTokenNotFound), "first should be invalidated")
|
||||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, second)
|
||||
assert.True(t, errors.Is(err, verification.ErrTokenNotFound), "second should be invalidated")
|
||||
|
||||
tok, err := db.Consume(ctx, third)
|
||||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, third)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accountRef, tok.AccountRef)
|
||||
}
|
||||
@@ -522,14 +632,14 @@ func TestCreate_DifferentPurposeNotInvalidated(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
resetRaw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
||||
resetRaw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Creating an activation token should NOT invalidate the password-reset token.
|
||||
_, err = db.Create(ctx, accountRef, model.PurposeAccountActivation, "", time.Hour)
|
||||
_, err = db.Create(ctx, req(accountRef, model.PurposeAccountActivation, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
tok, err := db.Consume(ctx, resetRaw)
|
||||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, resetRaw)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, model.PurposePasswordReset, tok.Purpose)
|
||||
}
|
||||
@@ -539,14 +649,14 @@ func TestCreate_DifferentTargetNotInvalidated(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
firstRaw, err := db.Create(ctx, accountRef, model.PurposeEmailChange, "a@example.com", time.Hour)
|
||||
firstRaw, err := db.Create(ctx, req(accountRef, model.PurposeEmailChange, "a@example.com", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Creating a token for a different target email should NOT invalidate the first.
|
||||
_, err = db.Create(ctx, accountRef, model.PurposeEmailChange, "b@example.com", time.Hour)
|
||||
_, err = db.Create(ctx, req(accountRef, model.PurposeEmailChange, "b@example.com", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
tok, err := db.Consume(ctx, firstRaw)
|
||||
tok, err := db.Consume(ctx, accountRef, model.PurposeEmailChange, firstRaw)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "a@example.com", tok.Target)
|
||||
}
|
||||
@@ -557,13 +667,13 @@ func TestCreate_DifferentAccountNotInvalidated(t *testing.T) {
|
||||
account1 := bson.NewObjectID()
|
||||
account2 := bson.NewObjectID()
|
||||
|
||||
raw1, err := db.Create(ctx, account1, model.PurposePasswordReset, "", time.Hour)
|
||||
raw1, err := db.Create(ctx, req(account1, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Create(ctx, account2, model.PurposePasswordReset, "", time.Hour)
|
||||
_, err = db.Create(ctx, req(account2, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
tok, err := db.Consume(ctx, raw1)
|
||||
tok, err := db.Consume(ctx, account1, model.PurposePasswordReset, raw1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, account1, tok.AccountRef)
|
||||
}
|
||||
@@ -574,18 +684,18 @@ func TestCreate_AlreadyUsedTokenNotInvalidatedAgain(t *testing.T) {
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
// Create and consume first token.
|
||||
raw1, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
||||
raw1, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
_, err = db.Consume(ctx, raw1)
|
||||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create second — the already-consumed token should have usedAt set,
|
||||
// so the invalidation query (usedAt == nil) should skip it.
|
||||
// This tests that the PatchMany filter correctly excludes already-used tokens.
|
||||
raw2, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
||||
raw2, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
tok, err := db.Consume(ctx, raw2)
|
||||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, raw2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accountRef, tok.AccountRef)
|
||||
}
|
||||
@@ -596,14 +706,14 @@ func TestCreate_ExpiredTokenNotInvalidated(t *testing.T) {
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
// Create a token that is already expired.
|
||||
_, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", -time.Hour)
|
||||
_, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", -time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a fresh one — invalidation should skip the expired token (expiresAt > now filter).
|
||||
raw2, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
||||
raw2, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
tok, err := db.Consume(ctx, raw2)
|
||||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, raw2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accountRef, tok.AccountRef)
|
||||
}
|
||||
@@ -619,3 +729,228 @@ func TestTokenHash_DifferentInputs(t *testing.T) {
|
||||
h2 := tokenHash("input-b")
|
||||
assert.NotEqual(t, h1, h2)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// cooldown tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestCreate_CooldownBlocksCreation(t *testing.T) {
|
||||
db := newTestVerificationDB(t)
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
// First creation without cooldown.
|
||||
_, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Immediate re-create with cooldown should be blocked — token is too recent to invalidate.
|
||||
r2 := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithCooldown(time.Minute)
|
||||
_, err = db.Create(ctx, r2)
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, verification.ErrCooldownActive))
|
||||
}
|
||||
|
||||
func TestCreate_CooldownExpiresAllowsCreation(t *testing.T) {
|
||||
db := newTestVerificationDB(t)
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
// First creation without cooldown.
|
||||
_, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
|
||||
// Re-create with short cooldown — the prior token is old enough to be invalidated.
|
||||
r2 := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithCooldown(time.Millisecond)
|
||||
_, err = db.Create(ctx, r2)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCreate_CooldownNilIgnored(t *testing.T) {
|
||||
db := newTestVerificationDB(t)
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
_, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
// No cooldown set — immediate re-create should succeed.
|
||||
_, err = db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCreate_IdempotencyKeyReplayReturnsSameToken(t *testing.T) {
|
||||
db := newTestVerificationDB(t)
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
firstReq := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithIdempotencyKey("same-key")
|
||||
firstRaw, err := db.Create(ctx, firstReq)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, firstRaw)
|
||||
|
||||
// Replay with the same idempotency key should return success and same token.
|
||||
secondReq := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithIdempotencyKey("same-key")
|
||||
secondRaw, err := db.Create(ctx, secondReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, firstRaw, secondRaw)
|
||||
|
||||
repo := db.Repository.(*memoryTokenRepository)
|
||||
repo.mu.Lock()
|
||||
assert.Len(t, repo.data, 1)
|
||||
repo.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestCreate_IdempotencyScopeIncludesTarget(t *testing.T) {
|
||||
db := newTestVerificationDB(t)
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
r1 := req(accountRef, model.PurposeEmailChange, "a@example.com", time.Hour).WithIdempotencyKey("same-key")
|
||||
raw1, err := db.Create(ctx, r1)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, raw1)
|
||||
|
||||
// Same account/purpose/key but different target should be treated as a different idempotency scope.
|
||||
r2 := req(accountRef, model.PurposeEmailChange, "b@example.com", time.Hour).WithIdempotencyKey("same-key")
|
||||
raw2, err := db.Create(ctx, r2)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, raw2)
|
||||
assert.NotEqual(t, raw1, raw2)
|
||||
|
||||
t1, err := db.Consume(ctx, accountRef, model.PurposeEmailChange, raw1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "a@example.com", t1.Target)
|
||||
|
||||
t2, err := db.Consume(ctx, accountRef, model.PurposeEmailChange, raw2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "b@example.com", t2.Target)
|
||||
}
|
||||
|
||||
func TestCreate_IdempotencySurvivesCallbackRetry(t *testing.T) {
|
||||
db := newTestVerificationDBWithFactory(t, &retryingTxFactory{})
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
// Cooldown would block the second callback execution if idempotency wasn't handled.
|
||||
r := req(accountRef, model.PurposePasswordReset, "", time.Hour).
|
||||
WithCooldown(time.Minute).
|
||||
WithIdempotencyKey("retry-safe")
|
||||
|
||||
raw, err := db.Create(ctx, r)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, raw)
|
||||
|
||||
repo := db.Repository.(*memoryTokenRepository)
|
||||
repo.mu.Lock()
|
||||
require.Len(t, repo.data, 1)
|
||||
for _, tok := range repo.data {
|
||||
require.NotNil(t, tok.IdempotencyKey)
|
||||
assert.Equal(t, "retry-safe", *tok.IdempotencyKey)
|
||||
assert.Nil(t, tok.UsedAt)
|
||||
assert.Equal(t, tok.VerifyTokenHash, tokenHash(raw))
|
||||
}
|
||||
repo.mu.Unlock()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// max retries / attempts tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestConsume_MaxRetriesExceeded(t *testing.T) {
|
||||
db := newTestVerificationDB(t)
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
r := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithMaxRetries(2)
|
||||
raw, err := db.Create(ctx, r)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate 2 prior failed attempts by setting Attempts directly.
|
||||
repo := db.Repository.(*memoryTokenRepository)
|
||||
repo.mu.Lock()
|
||||
for _, tok := range repo.data {
|
||||
tok.Attempts = 2
|
||||
}
|
||||
repo.mu.Unlock()
|
||||
|
||||
// Consume with correct token should fail — attempts already at max.
|
||||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, verification.ErrTokenAttemptsExceeded))
|
||||
}
|
||||
|
||||
func TestConsume_UnderMaxRetriesSucceeds(t *testing.T) {
|
||||
db := newTestVerificationDB(t)
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
r := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithMaxRetries(3)
|
||||
raw, err := db.Create(ctx, r)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate 2 prior failed attempts (under maxRetries=3).
|
||||
repo := db.Repository.(*memoryTokenRepository)
|
||||
repo.mu.Lock()
|
||||
for _, tok := range repo.data {
|
||||
tok.Attempts = 2
|
||||
}
|
||||
repo.mu.Unlock()
|
||||
|
||||
// Consume with correct token should succeed.
|
||||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accountRef, tok.AccountRef)
|
||||
}
|
||||
|
||||
func TestConsume_NoMaxRetriesIgnoresAttempts(t *testing.T) {
|
||||
db := newTestVerificationDB(t)
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
// Create without MaxRetries.
|
||||
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate high attempt count — should be ignored since MaxRetries is nil.
|
||||
repo := db.Repository.(*memoryTokenRepository)
|
||||
repo.mu.Lock()
|
||||
for _, tok := range repo.data {
|
||||
tok.Attempts = 100
|
||||
}
|
||||
repo.mu.Unlock()
|
||||
|
||||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accountRef, tok.AccountRef)
|
||||
}
|
||||
|
||||
func TestConsume_WrongHashReturnsNotFound(t *testing.T) {
|
||||
db := newTestVerificationDB(t)
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
|
||||
_, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wrong code — hash won't match any token.
|
||||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, "wrong-code")
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, verification.ErrTokenNotFound))
|
||||
}
|
||||
|
||||
func TestConsume_ContextMismatchReturnsNotFound(t *testing.T) {
|
||||
db := newTestVerificationDB(t)
|
||||
ctx := context.Background()
|
||||
accountRef := bson.NewObjectID()
|
||||
otherAccount := bson.NewObjectID()
|
||||
|
||||
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Correct token but wrong accountRef — context mismatch.
|
||||
_, err = db.Consume(ctx, otherAccount, model.PurposePasswordReset, raw)
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, verification.ErrTokenNotFound))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user