fixed token errors
This commit is contained in:
@@ -11,9 +11,7 @@ import (
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
mutil "github.com/tech/sendico/pkg/mutil/db"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (db *verificationDB) Consume(
|
||||
@@ -24,59 +22,74 @@ func (db *verificationDB) Consume(
|
||||
) (*model.VerificationToken, error) {
|
||||
|
||||
now := time.Now().UTC()
|
||||
accountScoped := accountRef != bson.NilObjectID
|
||||
|
||||
t, e := db.tf.CreateTransaction().Execute(
|
||||
ct,
|
||||
func(ctx context.Context) (any, error) {
|
||||
|
||||
// 1) Load active tokens for this context
|
||||
activeFilter := repository.Query().And(
|
||||
repository.Filter("accountRef", accountRef),
|
||||
scopeFilter := repository.Query().And(
|
||||
repository.Filter("purpose", purpose),
|
||||
repository.Filter("usedAt", nil),
|
||||
repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now),
|
||||
)
|
||||
if accountScoped {
|
||||
scopeFilter = scopeFilter.And(repository.Filter("accountRef", accountRef))
|
||||
}
|
||||
|
||||
tokens, err := mutil.GetObjects[model.VerificationToken](
|
||||
ctx, db.Logger, activeFilter, nil, db.DBImp.Repository,
|
||||
// 1) Fast path for magic-link tokens: hash is deterministic and globally unique.
|
||||
var token *model.VerificationToken
|
||||
magicFilter := scopeFilter.And(
|
||||
repository.Filter("verifyTokenHash", tokenHash(rawToken)),
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, merrors.ErrNoData) {
|
||||
db.Logger.Debug("No tokens found", zap.Error(err), mzap.AccRef(accountRef), zap.String("purpose", string(purpose)))
|
||||
return nil, verification.ErorrTokenNotFound()
|
||||
}
|
||||
db.Logger.Warn("Failed to load active tokens", zap.Error(err), mzap.AccRef(accountRef), zap.String("purpose", string(purpose)))
|
||||
var direct model.VerificationToken
|
||||
err := db.DBImp.FindOne(ctx, magicFilter, &direct)
|
||||
switch {
|
||||
case err == nil:
|
||||
token = &direct
|
||||
case errors.Is(err, merrors.ErrNoData):
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(tokens) == 0 {
|
||||
db.Logger.Debug("No tokens found", zap.Error(err), mzap.AccRef(accountRef), zap.String("purpose", string(purpose)))
|
||||
// If account is unknown, do not scan OTP candidates globally.
|
||||
if token == nil && !accountScoped {
|
||||
return nil, verification.ErorrTokenNotFound()
|
||||
}
|
||||
|
||||
// 2) Find matching token via hasher (OTP or Magic — doesn't matter)
|
||||
var token *model.VerificationToken
|
||||
|
||||
for i := range tokens {
|
||||
t := &tokens[i]
|
||||
hash := hasherFor(t).Hash(rawToken, t)
|
||||
|
||||
if hash == t.VerifyTokenHash {
|
||||
token = t
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 2) OTP path (and fallback): load purpose/account scoped tokens and compare hash with per-token salt.
|
||||
if token == nil {
|
||||
// wrong code/token → increment attempts
|
||||
for _, t := range tokens {
|
||||
tokens, err := mutil.GetObjects[model.VerificationToken](
|
||||
ctx, db.Logger, scopeFilter, nil, db.DBImp.Repository,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, merrors.ErrNoData) {
|
||||
return nil, verification.ErorrTokenNotFound()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range tokens {
|
||||
t := &tokens[i]
|
||||
hash := hasherFor(t).Hash(rawToken, t)
|
||||
if hash == t.VerifyTokenHash {
|
||||
token = t
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if token == nil {
|
||||
// wrong code/token → increment attempts for active (not used, not expired) scoped tokens
|
||||
activeFilter := scopeFilter.And(
|
||||
repository.Filter("usedAt", nil),
|
||||
repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now),
|
||||
)
|
||||
|
||||
_, _ = db.DBImp.PatchMany(
|
||||
ctx,
|
||||
repository.IDFilter(t.ID),
|
||||
activeFilter,
|
||||
repository.Patch().Inc(repository.Field("attempts"), 1),
|
||||
)
|
||||
return nil, verification.ErorrTokenNotFound()
|
||||
}
|
||||
return nil, verification.ErorrTokenNotFound()
|
||||
}
|
||||
|
||||
// 3) Static checks
|
||||
@@ -93,11 +106,13 @@ func (db *verificationDB) Consume(
|
||||
// 4) Atomic consume
|
||||
consumeFilter := repository.Query().And(
|
||||
repository.IDFilter(token.ID),
|
||||
repository.Filter("accountRef", accountRef),
|
||||
repository.Filter("purpose", purpose),
|
||||
repository.Filter("usedAt", nil),
|
||||
repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now),
|
||||
)
|
||||
if accountScoped {
|
||||
consumeFilter = consumeFilter.And(repository.Filter("accountRef", accountRef))
|
||||
}
|
||||
|
||||
if token.MaxRetries != nil {
|
||||
consumeFilter = consumeFilter.And(
|
||||
|
||||
Reference in New Issue
Block a user