|
|
|
|
@@ -0,0 +1,621 @@
|
|
|
|
|
package verificationimp
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"context"
|
|
|
|
|
"errors"
|
|
|
|
|
"sync"
|
|
|
|
|
"testing"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
|
"github.com/tech/sendico/pkg/db/repository/builder"
|
|
|
|
|
rd "github.com/tech/sendico/pkg/db/repository/decoder"
|
|
|
|
|
ri "github.com/tech/sendico/pkg/db/repository/index"
|
|
|
|
|
"github.com/tech/sendico/pkg/db/storable"
|
|
|
|
|
"github.com/tech/sendico/pkg/db/template"
|
|
|
|
|
"github.com/tech/sendico/pkg/db/transaction"
|
|
|
|
|
"github.com/tech/sendico/pkg/db/verification"
|
|
|
|
|
"github.com/tech/sendico/pkg/merrors"
|
|
|
|
|
"github.com/tech/sendico/pkg/model"
|
|
|
|
|
"github.com/tech/sendico/pkg/mservice"
|
|
|
|
|
"go.mongodb.org/mongo-driver/v2/bson"
|
|
|
|
|
"go.uber.org/zap"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
|
// helpers
|
|
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
func newTestVerificationDB(t *testing.T) *verificationDB {
|
|
|
|
|
t.Helper()
|
|
|
|
|
repo := newMemoryTokenRepository()
|
|
|
|
|
logger := zap.NewNop()
|
|
|
|
|
return &verificationDB{
|
|
|
|
|
DBImp: template.DBImp[*model.VerificationToken]{
|
|
|
|
|
Logger: logger,
|
|
|
|
|
Repository: repo,
|
|
|
|
|
},
|
|
|
|
|
tf: &passthroughTxFactory{},
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// passthroughTxFactory executes callbacks directly without a real transaction.
|
|
|
|
|
type passthroughTxFactory struct{}
|
|
|
|
|
|
|
|
|
|
func (*passthroughTxFactory) CreateTransaction() transaction.Transaction { return &passthroughTx{} }
|
|
|
|
|
|
|
|
|
|
type passthroughTx struct{}
|
|
|
|
|
|
|
|
|
|
func (*passthroughTx) Execute(ctx context.Context, cb transaction.Callback) (any, error) {
|
|
|
|
|
return cb(ctx)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
|
// in-memory repository for VerificationToken
|
|
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
type memoryTokenRepository struct {
|
|
|
|
|
mu sync.Mutex
|
|
|
|
|
data map[bson.ObjectID]*model.VerificationToken
|
|
|
|
|
order []bson.ObjectID
|
|
|
|
|
seq int
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func newMemoryTokenRepository() *memoryTokenRepository {
|
|
|
|
|
return &memoryTokenRepository{data: make(map[bson.ObjectID]*model.VerificationToken)}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (m *memoryTokenRepository) Insert(_ context.Context, obj storable.Storable, _ builder.Query) error {
|
|
|
|
|
m.mu.Lock()
|
|
|
|
|
defer m.mu.Unlock()
|
|
|
|
|
tok, ok := obj.(*model.VerificationToken)
|
|
|
|
|
if !ok {
|
|
|
|
|
return merrors.InvalidDataType("expected VerificationToken")
|
|
|
|
|
}
|
|
|
|
|
id := tok.GetID()
|
|
|
|
|
if id == nil || *id == bson.NilObjectID {
|
|
|
|
|
m.seq++
|
|
|
|
|
tok.SetID(bson.NewObjectID())
|
|
|
|
|
id = tok.GetID()
|
|
|
|
|
}
|
|
|
|
|
if _, exists := m.data[*id]; exists {
|
|
|
|
|
return merrors.DataConflict("token already exists")
|
|
|
|
|
}
|
|
|
|
|
m.data[*id] = cloneToken(tok)
|
|
|
|
|
m.order = append(m.order, *id)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (m *memoryTokenRepository) Get(_ context.Context, id bson.ObjectID, result storable.Storable) error {
|
|
|
|
|
m.mu.Lock()
|
|
|
|
|
defer m.mu.Unlock()
|
|
|
|
|
tok, ok := m.data[id]
|
|
|
|
|
if !ok {
|
|
|
|
|
return merrors.ErrNoData
|
|
|
|
|
}
|
|
|
|
|
dst := result.(*model.VerificationToken)
|
|
|
|
|
*dst = *cloneToken(tok)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (m *memoryTokenRepository) FindOneByFilter(_ context.Context, query builder.Query, result storable.Storable) error {
|
|
|
|
|
m.mu.Lock()
|
|
|
|
|
defer m.mu.Unlock()
|
|
|
|
|
for _, id := range m.order {
|
|
|
|
|
tok := m.data[id]
|
|
|
|
|
if tok != nil && matchToken(query, tok) {
|
|
|
|
|
dst := result.(*model.VerificationToken)
|
|
|
|
|
*dst = *cloneToken(tok)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return merrors.ErrNoData
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (m *memoryTokenRepository) Update(_ context.Context, obj storable.Storable) error {
|
|
|
|
|
m.mu.Lock()
|
|
|
|
|
defer m.mu.Unlock()
|
|
|
|
|
tok := obj.(*model.VerificationToken)
|
|
|
|
|
id := tok.GetID()
|
|
|
|
|
if id == nil {
|
|
|
|
|
return merrors.InvalidArgument("id required")
|
|
|
|
|
}
|
|
|
|
|
if _, exists := m.data[*id]; !exists {
|
|
|
|
|
return merrors.ErrNoData
|
|
|
|
|
}
|
|
|
|
|
m.data[*id] = cloneToken(tok)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (m *memoryTokenRepository) PatchMany(_ context.Context, filter builder.Query, patch builder.Patch) (int, error) {
|
|
|
|
|
m.mu.Lock()
|
|
|
|
|
defer m.mu.Unlock()
|
|
|
|
|
patchDoc := patch.Build()
|
|
|
|
|
count := 0
|
|
|
|
|
for _, id := range m.order {
|
|
|
|
|
tok := m.data[id]
|
|
|
|
|
if tok != nil && matchToken(filter, tok) {
|
|
|
|
|
applyPatch(tok, patchDoc)
|
|
|
|
|
count++
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return count, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// stubs — not exercised by verification DB but required by the interface
|
|
|
|
|
|
|
|
|
|
func (m *memoryTokenRepository) Aggregate(context.Context, builder.Pipeline, rd.DecodingFunc) error {
|
|
|
|
|
return merrors.NotImplemented("not needed")
|
|
|
|
|
}
|
|
|
|
|
func (m *memoryTokenRepository) InsertMany(ctx context.Context, objs []storable.Storable) error {
|
|
|
|
|
for _, o := range objs {
|
|
|
|
|
if err := m.Insert(ctx, o, nil); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
func (m *memoryTokenRepository) FindManyByFilter(context.Context, builder.Query, rd.DecodingFunc) error {
|
|
|
|
|
return merrors.NotImplemented("not needed")
|
|
|
|
|
}
|
|
|
|
|
func (m *memoryTokenRepository) Patch(context.Context, bson.ObjectID, builder.Patch) error {
|
|
|
|
|
return merrors.NotImplemented("not needed")
|
|
|
|
|
}
|
|
|
|
|
func (m *memoryTokenRepository) Delete(_ context.Context, id bson.ObjectID) error {
|
|
|
|
|
m.mu.Lock()
|
|
|
|
|
defer m.mu.Unlock()
|
|
|
|
|
delete(m.data, id)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
func (m *memoryTokenRepository) DeleteMany(context.Context, builder.Query) error {
|
|
|
|
|
return merrors.NotImplemented("not needed")
|
|
|
|
|
}
|
|
|
|
|
func (m *memoryTokenRepository) CreateIndex(*ri.Definition) error { return nil }
|
|
|
|
|
func (m *memoryTokenRepository) ListIDs(context.Context, builder.Query) ([]bson.ObjectID, error) {
|
|
|
|
|
return nil, merrors.NotImplemented("not needed")
|
|
|
|
|
}
|
|
|
|
|
func (m *memoryTokenRepository) ListPermissionBound(context.Context, builder.Query) ([]model.PermissionBoundStorable, error) {
|
|
|
|
|
return nil, merrors.NotImplemented("not needed")
|
|
|
|
|
}
|
|
|
|
|
func (m *memoryTokenRepository) ListAccountBound(context.Context, builder.Query) ([]model.AccountBoundStorable, error) {
|
|
|
|
|
return nil, merrors.NotImplemented("not needed")
|
|
|
|
|
}
|
|
|
|
|
func (m *memoryTokenRepository) Collection() string { return mservice.VerificationTokens }
|
|
|
|
|
|
|
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
|
// bson.D query evaluation for VerificationToken
|
|
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
// tokenFieldValue returns the stored value for a given BSON field name.
|
|
|
|
|
func tokenFieldValue(tok *model.VerificationToken, field string) any {
|
|
|
|
|
switch field {
|
|
|
|
|
case "verifyTokenHash":
|
|
|
|
|
return tok.VerifyTokenHash
|
|
|
|
|
case "usedAt":
|
|
|
|
|
return tok.UsedAt
|
|
|
|
|
case "expiresAt":
|
|
|
|
|
return tok.ExpiresAt
|
|
|
|
|
case "accountRef":
|
|
|
|
|
return tok.AccountRef
|
|
|
|
|
case "purpose":
|
|
|
|
|
return tok.Purpose
|
|
|
|
|
case "target":
|
|
|
|
|
return tok.Target
|
|
|
|
|
default:
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// matchToken evaluates a bson.D filter against a token.
|
|
|
|
|
func matchToken(query builder.Query, tok *model.VerificationToken) bool {
|
|
|
|
|
if query == nil {
|
|
|
|
|
return true
|
|
|
|
|
}
|
|
|
|
|
return matchBsonD(query.BuildQuery(), tok)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func matchBsonD(filter bson.D, tok *model.VerificationToken) bool {
|
|
|
|
|
for _, elem := range filter {
|
|
|
|
|
if !matchElem(elem, tok) {
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func matchElem(elem bson.E, tok *model.VerificationToken) bool {
|
|
|
|
|
switch elem.Key {
|
|
|
|
|
|
|
|
|
|
case "$and":
|
|
|
|
|
arr, ok := elem.Value.(bson.A)
|
|
|
|
|
if !ok {
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
for _, sub := range arr {
|
|
|
|
|
d, ok := sub.(bson.D)
|
|
|
|
|
if !ok {
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
if !matchBsonD(d, tok) {
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true
|
|
|
|
|
|
|
|
|
|
default:
|
|
|
|
|
// Either a direct field match or a comparison operator doc.
|
|
|
|
|
stored := tokenFieldValue(tok, elem.Key)
|
|
|
|
|
|
|
|
|
|
// Check for operator document like {$gt: value}
|
|
|
|
|
if opDoc, ok := elem.Value.(bson.M); ok {
|
|
|
|
|
return matchOperator(stored, opDoc)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Direct equality (including nil check).
|
|
|
|
|
return valuesEqual(stored, elem.Value)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func matchOperator(stored any, ops bson.M) bool {
|
|
|
|
|
for op, cmpVal := range ops {
|
|
|
|
|
switch op {
|
|
|
|
|
case "$gt":
|
|
|
|
|
if !timeGt(stored, cmpVal) {
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
case "$lt":
|
|
|
|
|
if !timeLt(stored, cmpVal) {
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func valuesEqual(a, b any) bool {
|
|
|
|
|
// nil checks: usedAt == nil
|
|
|
|
|
if b == nil {
|
|
|
|
|
return a == nil || a == (*time.Time)(nil)
|
|
|
|
|
}
|
|
|
|
|
switch av := a.(type) {
|
|
|
|
|
case *time.Time:
|
|
|
|
|
if av == nil {
|
|
|
|
|
return b == nil
|
|
|
|
|
}
|
|
|
|
|
if bv, ok := b.(*time.Time); ok {
|
|
|
|
|
return av.Equal(*bv)
|
|
|
|
|
}
|
|
|
|
|
return false
|
|
|
|
|
case bson.ObjectID:
|
|
|
|
|
if bv, ok := b.(bson.ObjectID); ok {
|
|
|
|
|
return av == bv
|
|
|
|
|
}
|
|
|
|
|
return false
|
|
|
|
|
case model.VerificationPurpose:
|
|
|
|
|
if bv, ok := b.(model.VerificationPurpose); ok {
|
|
|
|
|
return av == bv
|
|
|
|
|
}
|
|
|
|
|
return false
|
|
|
|
|
case string:
|
|
|
|
|
if bv, ok := b.(string); ok {
|
|
|
|
|
return av == bv
|
|
|
|
|
}
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func timeGt(stored, cmpVal any) bool {
|
|
|
|
|
st, ok := toTime(stored)
|
|
|
|
|
if !ok {
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
ct, ok := toTime(cmpVal)
|
|
|
|
|
if !ok {
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
return st.After(ct)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func timeLt(stored, cmpVal any) bool {
|
|
|
|
|
st, ok := toTime(stored)
|
|
|
|
|
if !ok {
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
ct, ok := toTime(cmpVal)
|
|
|
|
|
if !ok {
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
return st.Before(ct)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func toTime(v any) (time.Time, bool) {
|
|
|
|
|
switch tv := v.(type) {
|
|
|
|
|
case time.Time:
|
|
|
|
|
return tv, true
|
|
|
|
|
case *time.Time:
|
|
|
|
|
if tv == nil {
|
|
|
|
|
return time.Time{}, false
|
|
|
|
|
}
|
|
|
|
|
return *tv, true
|
|
|
|
|
}
|
|
|
|
|
return time.Time{}, false
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// applyPatch applies $set operations from a patch bson.D to a token.
|
|
|
|
|
func applyPatch(tok *model.VerificationToken, patchDoc bson.D) {
|
|
|
|
|
for _, op := range patchDoc {
|
|
|
|
|
if op.Key != "$set" {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
fields, ok := op.Value.(bson.D)
|
|
|
|
|
if !ok {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
for _, f := range fields {
|
|
|
|
|
switch f.Key {
|
|
|
|
|
case "usedAt":
|
|
|
|
|
if t, ok := f.Value.(time.Time); ok {
|
|
|
|
|
tok.UsedAt = &t
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func cloneToken(src *model.VerificationToken) *model.VerificationToken {
|
|
|
|
|
dst := *src
|
|
|
|
|
if src.UsedAt != nil {
|
|
|
|
|
t := *src.UsedAt
|
|
|
|
|
dst.UsedAt = &t
|
|
|
|
|
}
|
|
|
|
|
return &dst
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// allTokens returns every stored token for inspection in tests.
|
|
|
|
|
func (m *memoryTokenRepository) allTokens() []*model.VerificationToken {
|
|
|
|
|
m.mu.Lock()
|
|
|
|
|
defer m.mu.Unlock()
|
|
|
|
|
out := make([]*model.VerificationToken, 0, len(m.data))
|
|
|
|
|
for _, id := range m.order {
|
|
|
|
|
if tok, ok := m.data[id]; ok {
|
|
|
|
|
out = append(out, cloneToken(tok))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
|
// tests
|
|
|
|
|
// ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
func TestCreate_ReturnsRawToken(t *testing.T) {
|
|
|
|
|
db := newTestVerificationDB(t)
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
accountRef := bson.NewObjectID()
|
|
|
|
|
|
|
|
|
|
raw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
assert.NotEmpty(t, raw)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestCreate_TokenCanBeConsumed(t *testing.T) {
|
|
|
|
|
db := newTestVerificationDB(t)
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
accountRef := bson.NewObjectID()
|
|
|
|
|
|
|
|
|
|
raw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
tok, err := db.Consume(ctx, raw)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
assert.Equal(t, accountRef, tok.AccountRef)
|
|
|
|
|
assert.Equal(t, model.PurposePasswordReset, tok.Purpose)
|
|
|
|
|
assert.NotNil(t, tok.UsedAt)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestConsume_ReturnsCorrectFields(t *testing.T) {
|
|
|
|
|
db := newTestVerificationDB(t)
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
accountRef := bson.NewObjectID()
|
|
|
|
|
|
|
|
|
|
raw, err := db.Create(ctx, accountRef, model.PurposeEmailChange, "new@example.com", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
tok, err := db.Consume(ctx, raw)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
assert.Equal(t, accountRef, tok.AccountRef)
|
|
|
|
|
assert.Equal(t, model.PurposeEmailChange, tok.Purpose)
|
|
|
|
|
assert.Equal(t, "new@example.com", tok.Target)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestConsume_SecondConsumeFailsAlreadyUsed(t *testing.T) {
|
|
|
|
|
db := newTestVerificationDB(t)
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
accountRef := bson.NewObjectID()
|
|
|
|
|
|
|
|
|
|
raw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
_, err = db.Consume(ctx, raw)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
_, err = db.Consume(ctx, raw)
|
|
|
|
|
require.Error(t, err)
|
|
|
|
|
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed),
|
|
|
|
|
"second consume should fail because usedAt is set")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestConsume_ExpiredTokenFails(t *testing.T) {
|
|
|
|
|
db := newTestVerificationDB(t)
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
accountRef := bson.NewObjectID()
|
|
|
|
|
|
|
|
|
|
// Create with a TTL that is already in the past.
|
|
|
|
|
raw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", -time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
_, err = db.Consume(ctx, raw)
|
|
|
|
|
require.Error(t, err)
|
|
|
|
|
assert.True(t, errors.Is(err, verification.ErrTokenExpired),
|
|
|
|
|
"expired token should not be consumable")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestConsume_UnknownTokenFails(t *testing.T) {
|
|
|
|
|
db := newTestVerificationDB(t)
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
|
|
|
|
|
_, err := db.Consume(ctx, "nonexistent-token-value")
|
|
|
|
|
require.Error(t, err)
|
|
|
|
|
assert.True(t, errors.Is(err, verification.ErrTokenNotFound))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestCreate_InvalidatesPreviousToken(t *testing.T) {
|
|
|
|
|
db := newTestVerificationDB(t)
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
accountRef := bson.NewObjectID()
|
|
|
|
|
|
|
|
|
|
oldRaw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
newRaw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
assert.NotEqual(t, oldRaw, newRaw, "new token should differ from old one")
|
|
|
|
|
|
|
|
|
|
// Old token is no longer consumable.
|
|
|
|
|
_, err = db.Consume(ctx, oldRaw)
|
|
|
|
|
require.Error(t, err)
|
|
|
|
|
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed),
|
|
|
|
|
"old token should be invalidated (usedAt set) after new token creation")
|
|
|
|
|
|
|
|
|
|
// New token works fine.
|
|
|
|
|
tok, err := db.Consume(ctx, newRaw)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
assert.Equal(t, accountRef, tok.AccountRef)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestCreate_InvalidatesMultiplePreviousTokens(t *testing.T) {
|
|
|
|
|
db := newTestVerificationDB(t)
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
accountRef := bson.NewObjectID()
|
|
|
|
|
|
|
|
|
|
first, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
second, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
third, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
_, err = db.Consume(ctx, first)
|
|
|
|
|
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed), "first should be invalidated")
|
|
|
|
|
_, err = db.Consume(ctx, second)
|
|
|
|
|
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed), "second should be invalidated")
|
|
|
|
|
|
|
|
|
|
tok, err := db.Consume(ctx, third)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
assert.Equal(t, accountRef, tok.AccountRef)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestCreate_DifferentPurposeNotInvalidated(t *testing.T) {
|
|
|
|
|
db := newTestVerificationDB(t)
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
accountRef := bson.NewObjectID()
|
|
|
|
|
|
|
|
|
|
resetRaw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
// Creating an activation token should NOT invalidate the password-reset token.
|
|
|
|
|
_, err = db.Create(ctx, accountRef, model.PurposeAccountActivation, "", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
tok, err := db.Consume(ctx, resetRaw)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
assert.Equal(t, model.PurposePasswordReset, tok.Purpose)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestCreate_DifferentTargetNotInvalidated(t *testing.T) {
|
|
|
|
|
db := newTestVerificationDB(t)
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
accountRef := bson.NewObjectID()
|
|
|
|
|
|
|
|
|
|
firstRaw, err := db.Create(ctx, accountRef, model.PurposeEmailChange, "a@example.com", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
// Creating a token for a different target email should NOT invalidate the first.
|
|
|
|
|
_, err = db.Create(ctx, accountRef, model.PurposeEmailChange, "b@example.com", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
tok, err := db.Consume(ctx, firstRaw)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
assert.Equal(t, "a@example.com", tok.Target)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestCreate_DifferentAccountNotInvalidated(t *testing.T) {
|
|
|
|
|
db := newTestVerificationDB(t)
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
account1 := bson.NewObjectID()
|
|
|
|
|
account2 := bson.NewObjectID()
|
|
|
|
|
|
|
|
|
|
raw1, err := db.Create(ctx, account1, model.PurposePasswordReset, "", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
_, err = db.Create(ctx, account2, model.PurposePasswordReset, "", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
tok, err := db.Consume(ctx, raw1)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
assert.Equal(t, account1, tok.AccountRef)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestCreate_AlreadyUsedTokenNotInvalidatedAgain(t *testing.T) {
|
|
|
|
|
db := newTestVerificationDB(t)
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
accountRef := bson.NewObjectID()
|
|
|
|
|
|
|
|
|
|
// Create and consume first token.
|
|
|
|
|
raw1, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
_, err = db.Consume(ctx, raw1)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
// Create second — the already-consumed token should have usedAt set,
|
|
|
|
|
// so the invalidation query (usedAt == nil) should skip it.
|
|
|
|
|
// This tests that the PatchMany filter correctly excludes already-used tokens.
|
|
|
|
|
raw2, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
tok, err := db.Consume(ctx, raw2)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
assert.Equal(t, accountRef, tok.AccountRef)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestCreate_ExpiredTokenNotInvalidated(t *testing.T) {
|
|
|
|
|
db := newTestVerificationDB(t)
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
accountRef := bson.NewObjectID()
|
|
|
|
|
|
|
|
|
|
// Create a token that is already expired.
|
|
|
|
|
_, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", -time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
// Create a fresh one — invalidation should skip the expired token (expiresAt > now filter).
|
|
|
|
|
raw2, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
tok, err := db.Consume(ctx, raw2)
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
assert.Equal(t, accountRef, tok.AccountRef)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestTokenHash_Deterministic(t *testing.T) {
|
|
|
|
|
h1 := tokenHash("same-input")
|
|
|
|
|
h2 := tokenHash("same-input")
|
|
|
|
|
assert.Equal(t, h1, h2)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestTokenHash_DifferentInputs(t *testing.T) {
|
|
|
|
|
h1 := tokenHash("input-a")
|
|
|
|
|
h2 := tokenHash("input-b")
|
|
|
|
|
assert.NotEqual(t, h1, h2)
|
|
|
|
|
}
|