otp-450 #451

Merged
tech merged 2 commits from otp-450 into main 2026-02-10 01:12:05 +00:00
2 changed files with 72 additions and 43 deletions
Showing only changes of commit 7c182afd23 - Show all commits

View File

@@ -11,9 +11,7 @@ import (
"github.com/tech/sendico/pkg/merrors" "github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model" "github.com/tech/sendico/pkg/model"
mutil "github.com/tech/sendico/pkg/mutil/db" mutil "github.com/tech/sendico/pkg/mutil/db"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/bson"
"go.uber.org/zap"
) )
func (db *verificationDB) Consume( func (db *verificationDB) Consume(
@@ -24,59 +22,74 @@ func (db *verificationDB) Consume(
) (*model.VerificationToken, error) { ) (*model.VerificationToken, error) {
now := time.Now().UTC() now := time.Now().UTC()
accountScoped := accountRef != bson.NilObjectID
t, e := db.tf.CreateTransaction().Execute( t, e := db.tf.CreateTransaction().Execute(
ct, ct,
func(ctx context.Context) (any, error) { func(ctx context.Context) (any, error) {
// 1) Load active tokens for this context scopeFilter := repository.Query().And(
activeFilter := repository.Query().And(
repository.Filter("accountRef", accountRef),
repository.Filter("purpose", purpose), repository.Filter("purpose", purpose),
repository.Filter("usedAt", nil),
repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now),
) )
if accountScoped {
scopeFilter = scopeFilter.And(repository.Filter("accountRef", accountRef))
}
tokens, err := mutil.GetObjects[model.VerificationToken]( // 1) Fast path for magic-link tokens: hash is deterministic and globally unique.
ctx, db.Logger, activeFilter, nil, db.DBImp.Repository, var token *model.VerificationToken
magicFilter := scopeFilter.And(
repository.Filter("verifyTokenHash", tokenHash(rawToken)),
) )
if err != nil { var direct model.VerificationToken
if errors.Is(err, merrors.ErrNoData) { err := db.DBImp.FindOne(ctx, magicFilter, &direct)
db.Logger.Debug("No tokens found", zap.Error(err), mzap.AccRef(accountRef), zap.String("purpose", string(purpose))) switch {
return nil, verification.ErorrTokenNotFound() case err == nil:
} token = &direct
db.Logger.Warn("Failed to load active tokens", zap.Error(err), mzap.AccRef(accountRef), zap.String("purpose", string(purpose))) case errors.Is(err, merrors.ErrNoData):
default:
return nil, err return nil, err
} }
if len(tokens) == 0 { // If account is unknown, do not scan OTP candidates globally.
db.Logger.Debug("No tokens found", zap.Error(err), mzap.AccRef(accountRef), zap.String("purpose", string(purpose))) if token == nil && !accountScoped {
return nil, verification.ErorrTokenNotFound() return nil, verification.ErorrTokenNotFound()
} }
// 2) Find matching token via hasher (OTP or Magic — doesn't matter) // 2) OTP path (and fallback): load purpose/account scoped tokens and compare hash with per-token salt.
var token *model.VerificationToken
for i := range tokens {
t := &tokens[i]
hash := hasherFor(t).Hash(rawToken, t)
if hash == t.VerifyTokenHash {
token = t
break
}
}
if token == nil { if token == nil {
// wrong code/token → increment attempts tokens, err := mutil.GetObjects[model.VerificationToken](
for _, t := range tokens { ctx, db.Logger, scopeFilter, nil, db.DBImp.Repository,
)
if err != nil {
if errors.Is(err, merrors.ErrNoData) {
return nil, verification.ErorrTokenNotFound()
}
return nil, err
}
for i := range tokens {
t := &tokens[i]
hash := hasherFor(t).Hash(rawToken, t)
if hash == t.VerifyTokenHash {
token = t
break
}
}
if token == nil {
// wrong code/token → increment attempts for active (not used, not expired) scoped tokens
activeFilter := scopeFilter.And(
repository.Filter("usedAt", nil),
repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now),
)
_, _ = db.DBImp.PatchMany( _, _ = db.DBImp.PatchMany(
ctx, ctx,
repository.IDFilter(t.ID), activeFilter,
repository.Patch().Inc(repository.Field("attempts"), 1), repository.Patch().Inc(repository.Field("attempts"), 1),
) )
return nil, verification.ErorrTokenNotFound()
} }
return nil, verification.ErorrTokenNotFound()
} }
// 3) Static checks // 3) Static checks
@@ -93,11 +106,13 @@ func (db *verificationDB) Consume(
// 4) Atomic consume // 4) Atomic consume
consumeFilter := repository.Query().And( consumeFilter := repository.Query().And(
repository.IDFilter(token.ID), repository.IDFilter(token.ID),
repository.Filter("accountRef", accountRef),
repository.Filter("purpose", purpose), repository.Filter("purpose", purpose),
repository.Filter("usedAt", nil), repository.Filter("usedAt", nil),
repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now), repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now),
) )
if accountScoped {
consumeFilter = consumeFilter.And(repository.Filter("accountRef", accountRef))
}
if token.MaxRetries != nil { if token.MaxRetries != nil {
consumeFilter = consumeFilter.And( consumeFilter = consumeFilter.And(

View File

@@ -553,8 +553,8 @@ func TestConsume_SecondConsumeFailsAlreadyUsed(t *testing.T) {
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw) _, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, verification.ErrTokenNotFound), assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed),
"second consume should fail — used tokens are excluded from active filter") "second consume should fail with already-used after usedAt is set")
} }
func TestConsume_ExpiredTokenFails(t *testing.T) { func TestConsume_ExpiredTokenFails(t *testing.T) {
@@ -568,8 +568,8 @@ func TestConsume_ExpiredTokenFails(t *testing.T) {
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw) _, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, verification.ErrTokenNotFound), assert.True(t, errors.Is(err, verification.ErrTokenExpired),
"expired token is excluded from active filter") "expired token should return explicit expiry error")
} }
func TestConsume_UnknownTokenFails(t *testing.T) { func TestConsume_UnknownTokenFails(t *testing.T) {
@@ -581,6 +581,20 @@ func TestConsume_UnknownTokenFails(t *testing.T) {
assert.True(t, errors.Is(err, verification.ErrTokenNotFound)) assert.True(t, errors.Is(err, verification.ErrTokenNotFound))
} }
func TestConsume_AccountActivationWithoutAccountRef(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
accountRef := bson.NewObjectID()
raw, err := db.Create(ctx, req(accountRef, model.PurposeAccountActivation, "", time.Hour))
require.NoError(t, err)
tok, err := db.Consume(ctx, bson.NilObjectID, model.PurposeAccountActivation, raw)
require.NoError(t, err)
assert.Equal(t, accountRef, tok.AccountRef)
assert.Equal(t, model.PurposeAccountActivation, tok.Purpose)
}
func TestCreate_InvalidatesPreviousToken(t *testing.T) { func TestCreate_InvalidatesPreviousToken(t *testing.T) {
db := newTestVerificationDB(t) db := newTestVerificationDB(t)
ctx := context.Background() ctx := context.Background()
@@ -596,8 +610,8 @@ func TestCreate_InvalidatesPreviousToken(t *testing.T) {
// Old token is no longer consumable — invalidated (usedAt set) by the second Create. // Old token is no longer consumable — invalidated (usedAt set) by the second Create.
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, oldRaw) _, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, oldRaw)
require.Error(t, err) require.Error(t, err)
assert.True(t, errors.Is(err, verification.ErrTokenNotFound), assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed),
"old token should be invalidated after new token creation") "old token should return already-used after invalidation")
// New token works fine. // New token works fine.
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, newRaw) tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, newRaw)
@@ -618,9 +632,9 @@ func TestCreate_InvalidatesMultiplePreviousTokens(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, first) _, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, first)
assert.True(t, errors.Is(err, verification.ErrTokenNotFound), "first should be invalidated") assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed), "first should be invalidated/used")
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, second) _, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, second)
assert.True(t, errors.Is(err, verification.ErrTokenNotFound), "second should be invalidated") assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed), "second should be invalidated/used")
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, third) tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, third)
require.NoError(t, err) require.NoError(t, err)