174 lines
4.6 KiB
Go
174 lines
4.6 KiB
Go
package verificationimp
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"time"
|
|
|
|
"github.com/tech/sendico/pkg/db/repository"
|
|
"github.com/tech/sendico/pkg/db/repository/builder"
|
|
"github.com/tech/sendico/pkg/db/verification"
|
|
"github.com/tech/sendico/pkg/merrors"
|
|
"github.com/tech/sendico/pkg/model"
|
|
mutil "github.com/tech/sendico/pkg/mutil/db"
|
|
"go.mongodb.org/mongo-driver/v2/bson"
|
|
)
|
|
|
|
func (db *verificationDB) Consume(
|
|
ct context.Context,
|
|
accountRef bson.ObjectID,
|
|
purpose model.VerificationPurpose,
|
|
rawToken string,
|
|
) (*model.VerificationToken, error) {
|
|
|
|
now := time.Now().UTC()
|
|
accountScoped := accountRef != bson.NilObjectID
|
|
|
|
t, e := db.tf.CreateTransaction().Execute(
|
|
ct,
|
|
func(ctx context.Context) (any, error) {
|
|
|
|
scopeFilter := repository.Query().And(
|
|
repository.Filter("purpose", purpose),
|
|
)
|
|
if accountScoped {
|
|
scopeFilter = scopeFilter.And(repository.Filter("accountRef", accountRef))
|
|
}
|
|
|
|
// 1) Fast path for magic-link tokens: hash is deterministic and globally unique.
|
|
var token *model.VerificationToken
|
|
magicFilter := scopeFilter.And(
|
|
repository.Filter("verifyTokenHash", tokenHash(rawToken)),
|
|
)
|
|
var direct model.VerificationToken
|
|
err := db.DBImp.FindOne(ctx, magicFilter, &direct)
|
|
switch {
|
|
case err == nil:
|
|
token = &direct
|
|
case errors.Is(err, merrors.ErrNoData):
|
|
default:
|
|
return nil, err
|
|
}
|
|
|
|
// If account is unknown, do not scan OTP candidates globally.
|
|
if token == nil && !accountScoped {
|
|
return nil, verification.ErorrTokenNotFound()
|
|
}
|
|
|
|
// 2) OTP path (and fallback): load purpose/account scoped tokens and compare hash with per-token salt.
|
|
if token == nil {
|
|
tokens, err := mutil.GetObjects[model.VerificationToken](
|
|
ctx, db.Logger, scopeFilter, nil, db.DBImp.Repository,
|
|
)
|
|
if err != nil {
|
|
if errors.Is(err, merrors.ErrNoData) {
|
|
return nil, verification.ErorrTokenNotFound()
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
for i := range tokens {
|
|
t := &tokens[i]
|
|
hash := hasherFor(t).Hash(rawToken, t)
|
|
if hash == t.VerifyTokenHash {
|
|
token = t
|
|
break
|
|
}
|
|
}
|
|
|
|
if token == nil {
|
|
// wrong code/token → increment attempts for active (not used, not expired) scoped tokens
|
|
activeFilter := scopeFilter.And(
|
|
repository.Filter("usedAt", nil),
|
|
repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now),
|
|
)
|
|
|
|
_, _ = db.DBImp.PatchMany(
|
|
ctx,
|
|
activeFilter,
|
|
repository.Patch().Inc(repository.Field("attempts"), 1),
|
|
)
|
|
return nil, verification.ErorrTokenNotFound()
|
|
}
|
|
}
|
|
|
|
// 3) Static checks
|
|
if token.UsedAt != nil {
|
|
return nil, verification.ErorrTokenAlreadyUsed()
|
|
}
|
|
if !token.ExpiresAt.After(now) {
|
|
return nil, verification.ErorrTokenExpired()
|
|
}
|
|
if token.MaxRetries != nil && token.Attempts >= *token.MaxRetries {
|
|
return nil, verification.ErrorTokenAttemptsExceeded()
|
|
}
|
|
|
|
// 4) Atomic consume
|
|
consumeFilter := repository.Query().And(
|
|
repository.IDFilter(token.ID),
|
|
repository.Filter("purpose", purpose),
|
|
repository.Filter("usedAt", nil),
|
|
repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now),
|
|
)
|
|
if accountScoped {
|
|
consumeFilter = consumeFilter.And(repository.Filter("accountRef", accountRef))
|
|
}
|
|
|
|
if token.MaxRetries != nil {
|
|
consumeFilter = consumeFilter.And(
|
|
repository.Query().Comparison(repository.Field("attempts"), builder.Lt, *token.MaxRetries),
|
|
)
|
|
}
|
|
|
|
updated, err := db.DBImp.PatchMany(
|
|
ctx,
|
|
consumeFilter,
|
|
repository.Patch().Set(repository.Field("usedAt"), now),
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if updated == 1 {
|
|
token.UsedAt = &now
|
|
return token, nil
|
|
}
|
|
|
|
// 5) Consume failed → increment attempts
|
|
_, _ = db.DBImp.PatchMany(
|
|
ctx,
|
|
repository.IDFilter(token.ID),
|
|
repository.Patch().Inc(repository.Field("attempts"), 1),
|
|
)
|
|
|
|
// 6) Re-check state
|
|
var fresh model.VerificationToken
|
|
if err := db.DBImp.FindOne(ctx, repository.IDFilter(token.ID), &fresh); err != nil {
|
|
return nil, merrors.Internal("failed to re-check token state")
|
|
}
|
|
|
|
if fresh.UsedAt != nil {
|
|
return nil, verification.ErorrTokenAlreadyUsed()
|
|
}
|
|
if !fresh.ExpiresAt.After(now) {
|
|
return nil, verification.ErorrTokenExpired()
|
|
}
|
|
if fresh.MaxRetries != nil && fresh.Attempts >= *fresh.MaxRetries {
|
|
return nil, verification.ErrorTokenAttemptsExceeded()
|
|
}
|
|
|
|
return nil, verification.ErorrTokenNotFound()
|
|
},
|
|
)
|
|
|
|
if e != nil {
|
|
return nil, e
|
|
}
|
|
|
|
res, ok := t.(*model.VerificationToken)
|
|
if !ok {
|
|
return nil, merrors.Internal("unexpected token type")
|
|
}
|
|
return res, nil
|
|
}
|