fixed token errors
This commit is contained in:
@@ -11,9 +11,7 @@ import (
|
|||||||
"github.com/tech/sendico/pkg/merrors"
|
"github.com/tech/sendico/pkg/merrors"
|
||||||
"github.com/tech/sendico/pkg/model"
|
"github.com/tech/sendico/pkg/model"
|
||||||
mutil "github.com/tech/sendico/pkg/mutil/db"
|
mutil "github.com/tech/sendico/pkg/mutil/db"
|
||||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
|
||||||
"go.mongodb.org/mongo-driver/v2/bson"
|
"go.mongodb.org/mongo-driver/v2/bson"
|
||||||
"go.uber.org/zap"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (db *verificationDB) Consume(
|
func (db *verificationDB) Consume(
|
||||||
@@ -24,59 +22,74 @@ func (db *verificationDB) Consume(
|
|||||||
) (*model.VerificationToken, error) {
|
) (*model.VerificationToken, error) {
|
||||||
|
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
accountScoped := accountRef != bson.NilObjectID
|
||||||
|
|
||||||
t, e := db.tf.CreateTransaction().Execute(
|
t, e := db.tf.CreateTransaction().Execute(
|
||||||
ct,
|
ct,
|
||||||
func(ctx context.Context) (any, error) {
|
func(ctx context.Context) (any, error) {
|
||||||
|
|
||||||
// 1) Load active tokens for this context
|
scopeFilter := repository.Query().And(
|
||||||
activeFilter := repository.Query().And(
|
|
||||||
repository.Filter("accountRef", accountRef),
|
|
||||||
repository.Filter("purpose", purpose),
|
repository.Filter("purpose", purpose),
|
||||||
repository.Filter("usedAt", nil),
|
|
||||||
repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now),
|
|
||||||
)
|
)
|
||||||
|
if accountScoped {
|
||||||
|
scopeFilter = scopeFilter.And(repository.Filter("accountRef", accountRef))
|
||||||
|
}
|
||||||
|
|
||||||
tokens, err := mutil.GetObjects[model.VerificationToken](
|
// 1) Fast path for magic-link tokens: hash is deterministic and globally unique.
|
||||||
ctx, db.Logger, activeFilter, nil, db.DBImp.Repository,
|
var token *model.VerificationToken
|
||||||
|
magicFilter := scopeFilter.And(
|
||||||
|
repository.Filter("verifyTokenHash", tokenHash(rawToken)),
|
||||||
)
|
)
|
||||||
if err != nil {
|
var direct model.VerificationToken
|
||||||
if errors.Is(err, merrors.ErrNoData) {
|
err := db.DBImp.FindOne(ctx, magicFilter, &direct)
|
||||||
db.Logger.Debug("No tokens found", zap.Error(err), mzap.AccRef(accountRef), zap.String("purpose", string(purpose)))
|
switch {
|
||||||
return nil, verification.ErorrTokenNotFound()
|
case err == nil:
|
||||||
}
|
token = &direct
|
||||||
db.Logger.Warn("Failed to load active tokens", zap.Error(err), mzap.AccRef(accountRef), zap.String("purpose", string(purpose)))
|
case errors.Is(err, merrors.ErrNoData):
|
||||||
|
default:
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(tokens) == 0 {
|
// If account is unknown, do not scan OTP candidates globally.
|
||||||
db.Logger.Debug("No tokens found", zap.Error(err), mzap.AccRef(accountRef), zap.String("purpose", string(purpose)))
|
if token == nil && !accountScoped {
|
||||||
return nil, verification.ErorrTokenNotFound()
|
return nil, verification.ErorrTokenNotFound()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2) Find matching token via hasher (OTP or Magic — doesn't matter)
|
// 2) OTP path (and fallback): load purpose/account scoped tokens and compare hash with per-token salt.
|
||||||
var token *model.VerificationToken
|
|
||||||
|
|
||||||
for i := range tokens {
|
|
||||||
t := &tokens[i]
|
|
||||||
hash := hasherFor(t).Hash(rawToken, t)
|
|
||||||
|
|
||||||
if hash == t.VerifyTokenHash {
|
|
||||||
token = t
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if token == nil {
|
if token == nil {
|
||||||
// wrong code/token → increment attempts
|
tokens, err := mutil.GetObjects[model.VerificationToken](
|
||||||
for _, t := range tokens {
|
ctx, db.Logger, scopeFilter, nil, db.DBImp.Repository,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, merrors.ErrNoData) {
|
||||||
|
return nil, verification.ErorrTokenNotFound()
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range tokens {
|
||||||
|
t := &tokens[i]
|
||||||
|
hash := hasherFor(t).Hash(rawToken, t)
|
||||||
|
if hash == t.VerifyTokenHash {
|
||||||
|
token = t
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if token == nil {
|
||||||
|
// wrong code/token → increment attempts for active (not used, not expired) scoped tokens
|
||||||
|
activeFilter := scopeFilter.And(
|
||||||
|
repository.Filter("usedAt", nil),
|
||||||
|
repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now),
|
||||||
|
)
|
||||||
|
|
||||||
_, _ = db.DBImp.PatchMany(
|
_, _ = db.DBImp.PatchMany(
|
||||||
ctx,
|
ctx,
|
||||||
repository.IDFilter(t.ID),
|
activeFilter,
|
||||||
repository.Patch().Inc(repository.Field("attempts"), 1),
|
repository.Patch().Inc(repository.Field("attempts"), 1),
|
||||||
)
|
)
|
||||||
|
return nil, verification.ErorrTokenNotFound()
|
||||||
}
|
}
|
||||||
return nil, verification.ErorrTokenNotFound()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3) Static checks
|
// 3) Static checks
|
||||||
@@ -93,11 +106,13 @@ func (db *verificationDB) Consume(
|
|||||||
// 4) Atomic consume
|
// 4) Atomic consume
|
||||||
consumeFilter := repository.Query().And(
|
consumeFilter := repository.Query().And(
|
||||||
repository.IDFilter(token.ID),
|
repository.IDFilter(token.ID),
|
||||||
repository.Filter("accountRef", accountRef),
|
|
||||||
repository.Filter("purpose", purpose),
|
repository.Filter("purpose", purpose),
|
||||||
repository.Filter("usedAt", nil),
|
repository.Filter("usedAt", nil),
|
||||||
repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now),
|
repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now),
|
||||||
)
|
)
|
||||||
|
if accountScoped {
|
||||||
|
consumeFilter = consumeFilter.And(repository.Filter("accountRef", accountRef))
|
||||||
|
}
|
||||||
|
|
||||||
if token.MaxRetries != nil {
|
if token.MaxRetries != nil {
|
||||||
consumeFilter = consumeFilter.And(
|
consumeFilter = consumeFilter.And(
|
||||||
|
|||||||
@@ -553,8 +553,8 @@ func TestConsume_SecondConsumeFailsAlreadyUsed(t *testing.T) {
|
|||||||
|
|
||||||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.True(t, errors.Is(err, verification.ErrTokenNotFound),
|
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed),
|
||||||
"second consume should fail — used tokens are excluded from active filter")
|
"second consume should fail with already-used after usedAt is set")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConsume_ExpiredTokenFails(t *testing.T) {
|
func TestConsume_ExpiredTokenFails(t *testing.T) {
|
||||||
@@ -568,8 +568,8 @@ func TestConsume_ExpiredTokenFails(t *testing.T) {
|
|||||||
|
|
||||||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.True(t, errors.Is(err, verification.ErrTokenNotFound),
|
assert.True(t, errors.Is(err, verification.ErrTokenExpired),
|
||||||
"expired token is excluded from active filter")
|
"expired token should return explicit expiry error")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConsume_UnknownTokenFails(t *testing.T) {
|
func TestConsume_UnknownTokenFails(t *testing.T) {
|
||||||
@@ -581,6 +581,20 @@ func TestConsume_UnknownTokenFails(t *testing.T) {
|
|||||||
assert.True(t, errors.Is(err, verification.ErrTokenNotFound))
|
assert.True(t, errors.Is(err, verification.ErrTokenNotFound))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConsume_AccountActivationWithoutAccountRef(t *testing.T) {
|
||||||
|
db := newTestVerificationDB(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
accountRef := bson.NewObjectID()
|
||||||
|
|
||||||
|
raw, err := db.Create(ctx, req(accountRef, model.PurposeAccountActivation, "", time.Hour))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tok, err := db.Consume(ctx, bson.NilObjectID, model.PurposeAccountActivation, raw)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, accountRef, tok.AccountRef)
|
||||||
|
assert.Equal(t, model.PurposeAccountActivation, tok.Purpose)
|
||||||
|
}
|
||||||
|
|
||||||
func TestCreate_InvalidatesPreviousToken(t *testing.T) {
|
func TestCreate_InvalidatesPreviousToken(t *testing.T) {
|
||||||
db := newTestVerificationDB(t)
|
db := newTestVerificationDB(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
@@ -596,8 +610,8 @@ func TestCreate_InvalidatesPreviousToken(t *testing.T) {
|
|||||||
// Old token is no longer consumable — invalidated (usedAt set) by the second Create.
|
// Old token is no longer consumable — invalidated (usedAt set) by the second Create.
|
||||||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, oldRaw)
|
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, oldRaw)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.True(t, errors.Is(err, verification.ErrTokenNotFound),
|
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed),
|
||||||
"old token should be invalidated after new token creation")
|
"old token should return already-used after invalidation")
|
||||||
|
|
||||||
// New token works fine.
|
// New token works fine.
|
||||||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, newRaw)
|
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, newRaw)
|
||||||
@@ -618,9 +632,9 @@ func TestCreate_InvalidatesMultiplePreviousTokens(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, first)
|
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, first)
|
||||||
assert.True(t, errors.Is(err, verification.ErrTokenNotFound), "first should be invalidated")
|
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed), "first should be invalidated/used")
|
||||||
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, second)
|
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, second)
|
||||||
assert.True(t, errors.Is(err, verification.ErrTokenNotFound), "second should be invalidated")
|
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed), "second should be invalidated/used")
|
||||||
|
|
||||||
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, third)
|
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, third)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
Reference in New Issue
Block a user