package verificationimp import ( "context" "errors" "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tech/sendico/pkg/db/repository/builder" rd "github.com/tech/sendico/pkg/db/repository/decoder" ri "github.com/tech/sendico/pkg/db/repository/index" "github.com/tech/sendico/pkg/db/storable" "github.com/tech/sendico/pkg/db/template" "github.com/tech/sendico/pkg/db/transaction" "github.com/tech/sendico/pkg/db/verification" "github.com/tech/sendico/pkg/merrors" "github.com/tech/sendico/pkg/model" "github.com/tech/sendico/pkg/mservice" "go.mongodb.org/mongo-driver/v2/bson" "go.uber.org/zap" ) // --------------------------------------------------------------------------- // helpers // --------------------------------------------------------------------------- func newTestVerificationDB(t *testing.T) *verificationDB { t.Helper() repo := newMemoryTokenRepository() logger := zap.NewNop() return &verificationDB{ DBImp: template.DBImp[*model.VerificationToken]{ Logger: logger, Repository: repo, }, tf: &passthroughTxFactory{}, } } // passthroughTxFactory executes callbacks directly without a real transaction. type passthroughTxFactory struct{} func (*passthroughTxFactory) CreateTransaction() transaction.Transaction { return &passthroughTx{} } type passthroughTx struct{} func (*passthroughTx) Execute(ctx context.Context, cb transaction.Callback) (any, error) { return cb(ctx) } // --------------------------------------------------------------------------- // in-memory repository for VerificationToken // --------------------------------------------------------------------------- type memoryTokenRepository struct { mu sync.Mutex data map[bson.ObjectID]*model.VerificationToken order []bson.ObjectID seq int } func newMemoryTokenRepository() *memoryTokenRepository { return &memoryTokenRepository{data: make(map[bson.ObjectID]*model.VerificationToken)} } func (m *memoryTokenRepository) Insert(_ context.Context, obj storable.Storable, _ builder.Query) error { m.mu.Lock() defer m.mu.Unlock() tok, ok := obj.(*model.VerificationToken) if !ok { return merrors.InvalidDataType("expected VerificationToken") } id := tok.GetID() if id == nil || *id == bson.NilObjectID { m.seq++ tok.SetID(bson.NewObjectID()) id = tok.GetID() } if _, exists := m.data[*id]; exists { return merrors.DataConflict("token already exists") } m.data[*id] = cloneToken(tok) m.order = append(m.order, *id) return nil } func (m *memoryTokenRepository) Get(_ context.Context, id bson.ObjectID, result storable.Storable) error { m.mu.Lock() defer m.mu.Unlock() tok, ok := m.data[id] if !ok { return merrors.ErrNoData } dst := result.(*model.VerificationToken) *dst = *cloneToken(tok) return nil } func (m *memoryTokenRepository) FindOneByFilter(_ context.Context, query builder.Query, result storable.Storable) error { m.mu.Lock() defer m.mu.Unlock() for _, id := range m.order { tok := m.data[id] if tok != nil && matchToken(query, tok) { dst := result.(*model.VerificationToken) *dst = *cloneToken(tok) return nil } } return merrors.ErrNoData } func (m *memoryTokenRepository) Update(_ context.Context, obj storable.Storable) error { m.mu.Lock() defer m.mu.Unlock() tok := obj.(*model.VerificationToken) id := tok.GetID() if id == nil { return merrors.InvalidArgument("id required") } if _, exists := m.data[*id]; !exists { return merrors.ErrNoData } m.data[*id] = cloneToken(tok) return nil } func (m *memoryTokenRepository) PatchMany(_ context.Context, filter builder.Query, patch builder.Patch) (int, error) { m.mu.Lock() defer m.mu.Unlock() patchDoc := patch.Build() count := 0 for _, id := range m.order { tok := m.data[id] if tok != nil && matchToken(filter, tok) { applyPatch(tok, patchDoc) count++ } } return count, nil } // stubs — not exercised by verification DB but required by the interface func (m *memoryTokenRepository) Aggregate(context.Context, builder.Pipeline, rd.DecodingFunc) error { return merrors.NotImplemented("not needed") } func (m *memoryTokenRepository) InsertMany(ctx context.Context, objs []storable.Storable) error { for _, o := range objs { if err := m.Insert(ctx, o, nil); err != nil { return err } } return nil } func (m *memoryTokenRepository) FindManyByFilter(context.Context, builder.Query, rd.DecodingFunc) error { return merrors.NotImplemented("not needed") } func (m *memoryTokenRepository) Patch(context.Context, bson.ObjectID, builder.Patch) error { return merrors.NotImplemented("not needed") } func (m *memoryTokenRepository) Delete(_ context.Context, id bson.ObjectID) error { m.mu.Lock() defer m.mu.Unlock() delete(m.data, id) return nil } func (m *memoryTokenRepository) DeleteMany(context.Context, builder.Query) error { return merrors.NotImplemented("not needed") } func (m *memoryTokenRepository) CreateIndex(*ri.Definition) error { return nil } func (m *memoryTokenRepository) ListIDs(context.Context, builder.Query) ([]bson.ObjectID, error) { return nil, merrors.NotImplemented("not needed") } func (m *memoryTokenRepository) ListPermissionBound(context.Context, builder.Query) ([]model.PermissionBoundStorable, error) { return nil, merrors.NotImplemented("not needed") } func (m *memoryTokenRepository) ListAccountBound(context.Context, builder.Query) ([]model.AccountBoundStorable, error) { return nil, merrors.NotImplemented("not needed") } func (m *memoryTokenRepository) Collection() string { return mservice.VerificationTokens } // --------------------------------------------------------------------------- // bson.D query evaluation for VerificationToken // --------------------------------------------------------------------------- // tokenFieldValue returns the stored value for a given BSON field name. func tokenFieldValue(tok *model.VerificationToken, field string) any { switch field { case "verifyTokenHash": return tok.VerifyTokenHash case "usedAt": return tok.UsedAt case "expiresAt": return tok.ExpiresAt case "accountRef": return tok.AccountRef case "purpose": return tok.Purpose case "target": return tok.Target default: return nil } } // matchToken evaluates a bson.D filter against a token. func matchToken(query builder.Query, tok *model.VerificationToken) bool { if query == nil { return true } return matchBsonD(query.BuildQuery(), tok) } func matchBsonD(filter bson.D, tok *model.VerificationToken) bool { for _, elem := range filter { if !matchElem(elem, tok) { return false } } return true } func matchElem(elem bson.E, tok *model.VerificationToken) bool { switch elem.Key { case "$and": arr, ok := elem.Value.(bson.A) if !ok { return false } for _, sub := range arr { d, ok := sub.(bson.D) if !ok { return false } if !matchBsonD(d, tok) { return false } } return true default: // Either a direct field match or a comparison operator doc. stored := tokenFieldValue(tok, elem.Key) // Check for operator document like {$gt: value} if opDoc, ok := elem.Value.(bson.M); ok { return matchOperator(stored, opDoc) } // Direct equality (including nil check). return valuesEqual(stored, elem.Value) } } func matchOperator(stored any, ops bson.M) bool { for op, cmpVal := range ops { switch op { case "$gt": if !timeGt(stored, cmpVal) { return false } case "$lt": if !timeLt(stored, cmpVal) { return false } } } return true } func valuesEqual(a, b any) bool { // nil checks: usedAt == nil if b == nil { return a == nil || a == (*time.Time)(nil) } switch av := a.(type) { case *time.Time: if av == nil { return b == nil } if bv, ok := b.(*time.Time); ok { return av.Equal(*bv) } return false case bson.ObjectID: if bv, ok := b.(bson.ObjectID); ok { return av == bv } return false case model.VerificationPurpose: if bv, ok := b.(model.VerificationPurpose); ok { return av == bv } return false case string: if bv, ok := b.(string); ok { return av == bv } return false } return false } func timeGt(stored, cmpVal any) bool { st, ok := toTime(stored) if !ok { return false } ct, ok := toTime(cmpVal) if !ok { return false } return st.After(ct) } func timeLt(stored, cmpVal any) bool { st, ok := toTime(stored) if !ok { return false } ct, ok := toTime(cmpVal) if !ok { return false } return st.Before(ct) } func toTime(v any) (time.Time, bool) { switch tv := v.(type) { case time.Time: return tv, true case *time.Time: if tv == nil { return time.Time{}, false } return *tv, true } return time.Time{}, false } // applyPatch applies $set operations from a patch bson.D to a token. func applyPatch(tok *model.VerificationToken, patchDoc bson.D) { for _, op := range patchDoc { if op.Key != "$set" { continue } fields, ok := op.Value.(bson.D) if !ok { continue } for _, f := range fields { switch f.Key { case "usedAt": if t, ok := f.Value.(time.Time); ok { tok.UsedAt = &t } } } } } func cloneToken(src *model.VerificationToken) *model.VerificationToken { dst := *src if src.UsedAt != nil { t := *src.UsedAt dst.UsedAt = &t } return &dst } // allTokens returns every stored token for inspection in tests. func (m *memoryTokenRepository) allTokens() []*model.VerificationToken { m.mu.Lock() defer m.mu.Unlock() out := make([]*model.VerificationToken, 0, len(m.data)) for _, id := range m.order { if tok, ok := m.data[id]; ok { out = append(out, cloneToken(tok)) } } return out } // --------------------------------------------------------------------------- // tests // --------------------------------------------------------------------------- func TestCreate_ReturnsRawToken(t *testing.T) { db := newTestVerificationDB(t) ctx := context.Background() accountRef := bson.NewObjectID() raw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour) require.NoError(t, err) assert.NotEmpty(t, raw) } func TestCreate_TokenCanBeConsumed(t *testing.T) { db := newTestVerificationDB(t) ctx := context.Background() accountRef := bson.NewObjectID() raw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour) require.NoError(t, err) tok, err := db.Consume(ctx, raw) require.NoError(t, err) assert.Equal(t, accountRef, tok.AccountRef) assert.Equal(t, model.PurposePasswordReset, tok.Purpose) assert.NotNil(t, tok.UsedAt) } func TestConsume_ReturnsCorrectFields(t *testing.T) { db := newTestVerificationDB(t) ctx := context.Background() accountRef := bson.NewObjectID() raw, err := db.Create(ctx, accountRef, model.PurposeEmailChange, "new@example.com", time.Hour) require.NoError(t, err) tok, err := db.Consume(ctx, raw) require.NoError(t, err) assert.Equal(t, accountRef, tok.AccountRef) assert.Equal(t, model.PurposeEmailChange, tok.Purpose) assert.Equal(t, "new@example.com", tok.Target) } func TestConsume_SecondConsumeFailsAlreadyUsed(t *testing.T) { db := newTestVerificationDB(t) ctx := context.Background() accountRef := bson.NewObjectID() raw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour) require.NoError(t, err) _, err = db.Consume(ctx, raw) require.NoError(t, err) _, err = db.Consume(ctx, raw) require.Error(t, err) assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed), "second consume should fail because usedAt is set") } func TestConsume_ExpiredTokenFails(t *testing.T) { db := newTestVerificationDB(t) ctx := context.Background() accountRef := bson.NewObjectID() // Create with a TTL that is already in the past. raw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", -time.Hour) require.NoError(t, err) _, err = db.Consume(ctx, raw) require.Error(t, err) assert.True(t, errors.Is(err, verification.ErrTokenExpired), "expired token should not be consumable") } func TestConsume_UnknownTokenFails(t *testing.T) { db := newTestVerificationDB(t) ctx := context.Background() _, err := db.Consume(ctx, "nonexistent-token-value") require.Error(t, err) assert.True(t, errors.Is(err, verification.ErrTokenNotFound)) } func TestCreate_InvalidatesPreviousToken(t *testing.T) { db := newTestVerificationDB(t) ctx := context.Background() accountRef := bson.NewObjectID() oldRaw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour) require.NoError(t, err) newRaw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour) require.NoError(t, err) assert.NotEqual(t, oldRaw, newRaw, "new token should differ from old one") // Old token is no longer consumable. _, err = db.Consume(ctx, oldRaw) require.Error(t, err) assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed), "old token should be invalidated (usedAt set) after new token creation") // New token works fine. tok, err := db.Consume(ctx, newRaw) require.NoError(t, err) assert.Equal(t, accountRef, tok.AccountRef) } func TestCreate_InvalidatesMultiplePreviousTokens(t *testing.T) { db := newTestVerificationDB(t) ctx := context.Background() accountRef := bson.NewObjectID() first, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour) require.NoError(t, err) second, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour) require.NoError(t, err) third, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour) require.NoError(t, err) _, err = db.Consume(ctx, first) assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed), "first should be invalidated") _, err = db.Consume(ctx, second) assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed), "second should be invalidated") tok, err := db.Consume(ctx, third) require.NoError(t, err) assert.Equal(t, accountRef, tok.AccountRef) } func TestCreate_DifferentPurposeNotInvalidated(t *testing.T) { db := newTestVerificationDB(t) ctx := context.Background() accountRef := bson.NewObjectID() resetRaw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour) require.NoError(t, err) // Creating an activation token should NOT invalidate the password-reset token. _, err = db.Create(ctx, accountRef, model.PurposeAccountActivation, "", time.Hour) require.NoError(t, err) tok, err := db.Consume(ctx, resetRaw) require.NoError(t, err) assert.Equal(t, model.PurposePasswordReset, tok.Purpose) } func TestCreate_DifferentTargetNotInvalidated(t *testing.T) { db := newTestVerificationDB(t) ctx := context.Background() accountRef := bson.NewObjectID() firstRaw, err := db.Create(ctx, accountRef, model.PurposeEmailChange, "a@example.com", time.Hour) require.NoError(t, err) // Creating a token for a different target email should NOT invalidate the first. _, err = db.Create(ctx, accountRef, model.PurposeEmailChange, "b@example.com", time.Hour) require.NoError(t, err) tok, err := db.Consume(ctx, firstRaw) require.NoError(t, err) assert.Equal(t, "a@example.com", tok.Target) } func TestCreate_DifferentAccountNotInvalidated(t *testing.T) { db := newTestVerificationDB(t) ctx := context.Background() account1 := bson.NewObjectID() account2 := bson.NewObjectID() raw1, err := db.Create(ctx, account1, model.PurposePasswordReset, "", time.Hour) require.NoError(t, err) _, err = db.Create(ctx, account2, model.PurposePasswordReset, "", time.Hour) require.NoError(t, err) tok, err := db.Consume(ctx, raw1) require.NoError(t, err) assert.Equal(t, account1, tok.AccountRef) } func TestCreate_AlreadyUsedTokenNotInvalidatedAgain(t *testing.T) { db := newTestVerificationDB(t) ctx := context.Background() accountRef := bson.NewObjectID() // Create and consume first token. raw1, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour) require.NoError(t, err) _, err = db.Consume(ctx, raw1) require.NoError(t, err) // Create second — the already-consumed token should have usedAt set, // so the invalidation query (usedAt == nil) should skip it. // This tests that the PatchMany filter correctly excludes already-used tokens. raw2, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour) require.NoError(t, err) tok, err := db.Consume(ctx, raw2) require.NoError(t, err) assert.Equal(t, accountRef, tok.AccountRef) } func TestCreate_ExpiredTokenNotInvalidated(t *testing.T) { db := newTestVerificationDB(t) ctx := context.Background() accountRef := bson.NewObjectID() // Create a token that is already expired. _, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", -time.Hour) require.NoError(t, err) // Create a fresh one — invalidation should skip the expired token (expiresAt > now filter). raw2, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour) require.NoError(t, err) tok, err := db.Consume(ctx, raw2) require.NoError(t, err) assert.Equal(t, accountRef, tok.AccountRef) } func TestTokenHash_Deterministic(t *testing.T) { h1 := tokenHash("same-input") h2 := tokenHash("same-input") assert.Equal(t, h1, h2) } func TestTokenHash_DifferentInputs(t *testing.T) { h1 := tokenHash("input-a") h2 := tokenHash("input-b") assert.NotEqual(t, h1, h2) }