unified code verification service

This commit is contained in:
Stephan D
2026-02-10 01:55:33 +01:00
parent 76c3bfdea9
commit 7f540671c1
120 changed files with 1863 additions and 1394 deletions

View File

@@ -1,16 +0,0 @@
package confirmation
import (
"context"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/v2/bson"
)
type DB interface {
template.DB[*model.ConfirmationCode]
FindActive(ctx context.Context, accountRef bson.ObjectID, destination string, target model.ConfirmationTarget, now int64) (*model.ConfirmationCode, error)
DeleteTuple(ctx context.Context, accountRef bson.ObjectID, destination string, target model.ConfirmationTarget) error
}

View File

@@ -4,7 +4,6 @@ import (
"github.com/tech/sendico/pkg/auth"
"github.com/tech/sendico/pkg/db/account"
"github.com/tech/sendico/pkg/db/chainassets"
"github.com/tech/sendico/pkg/db/confirmation"
mongoimpl "github.com/tech/sendico/pkg/db/internal/mongo"
"github.com/tech/sendico/pkg/db/invitation"
"github.com/tech/sendico/pkg/db/organization"
@@ -22,7 +21,6 @@ import (
// Factory exposes high-level repositories used by application services.
type Factory interface {
NewRefreshTokensDB() (refreshtokens.DB, error)
NewConfirmationsDB() (confirmation.DB, error)
NewChainAsstesDB() (chainassets.DB, error)

View File

@@ -1,67 +0,0 @@
package confirmationdb
import (
"github.com/tech/sendico/pkg/db/confirmation"
ri "github.com/tech/sendico/pkg/db/repository/index"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.uber.org/zap"
)
const (
fieldAccountRef = "accountRef"
fieldDestination = "destination"
fieldTarget = "target"
fieldExpiresAt = "expiresAt"
fieldUsed = "used"
)
type ConfirmationDB struct {
template.DBImp[*model.ConfirmationCode]
}
func Create(logger mlogger.Logger, db *mongo.Database) (confirmation.DB, error) {
p := &ConfirmationDB{
DBImp: *template.Create[*model.ConfirmationCode](logger, mservice.Confirmations, db),
}
// Ensure one active code per account/destination/target.
if err := p.Repository.CreateIndex(&ri.Definition{
Keys: []ri.Key{
{Field: fieldAccountRef, Sort: ri.Asc},
{Field: fieldDestination, Sort: ri.Asc},
{Field: fieldTarget, Sort: ri.Asc},
},
Unique: true,
}); err != nil {
p.Logger.Error("Failed to create confirmation unique index", zap.Error(err))
return nil, err
}
// TTL on expiry.
ttl := int32(0)
if err := p.Repository.CreateIndex(&ri.Definition{
Keys: []ri.Key{
{Field: fieldExpiresAt, Sort: ri.Asc},
},
TTL: &ttl,
}); err != nil {
p.Logger.Error("Failed to create confirmation TTL index", zap.Error(err))
return nil, err
}
// Query helper indexes.
if err := p.Repository.CreateIndex(&ri.Definition{
Keys: []ri.Key{
{Field: fieldUsed, Sort: ri.Asc},
},
}); err != nil {
p.Logger.Error("Failed to create confirmation used index", zap.Error(err))
return nil, err
}
return p, nil
}

View File

@@ -1,17 +0,0 @@
package confirmationdb
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/v2/bson"
)
func (db *ConfirmationDB) DeleteTuple(ctx context.Context, accountRef bson.ObjectID, destination string, target model.ConfirmationTarget) error {
query := repository.Query().
Filter(repository.Field(fieldAccountRef), accountRef).
Filter(repository.Field(fieldDestination), destination).
Filter(repository.Field(fieldTarget), target)
return db.DeleteMany(ctx, query)
}

View File

@@ -1,26 +0,0 @@
package confirmationdb
import (
"context"
"time"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/v2/bson"
)
func (db *ConfirmationDB) FindActive(ctx context.Context, accountRef bson.ObjectID, destination string, target model.ConfirmationTarget, now int64) (*model.ConfirmationCode, error) {
var res model.ConfirmationCode
query := repository.Query().
Filter(repository.Field(fieldAccountRef), accountRef).
Filter(repository.Field(fieldDestination), destination).
Filter(repository.Field(fieldTarget), target).
Filter(repository.Field(fieldUsed), false).
Comparison(repository.Field(fieldExpiresAt), builder.Gt, time.Unix(now, 0))
if err := db.FindOne(ctx, query, &res); err != nil {
return nil, err
}
return &res, nil
}

View File

@@ -11,10 +11,8 @@ import (
"github.com/tech/sendico/pkg/auth"
"github.com/tech/sendico/pkg/db/account"
"github.com/tech/sendico/pkg/db/chainassets"
"github.com/tech/sendico/pkg/db/confirmation"
"github.com/tech/sendico/pkg/db/internal/mongo/accountdb"
"github.com/tech/sendico/pkg/db/internal/mongo/chainassetsdb"
"github.com/tech/sendico/pkg/db/internal/mongo/confirmationdb"
"github.com/tech/sendico/pkg/db/internal/mongo/invitationdb"
"github.com/tech/sendico/pkg/db/internal/mongo/organizationdb"
"github.com/tech/sendico/pkg/db/internal/mongo/paymethoddb"
@@ -188,10 +186,6 @@ func (db *DB) NewAccountDB() (account.DB, error) {
return accountdb.Create(db.logger, db.db())
}
func (db *DB) NewConfirmationsDB() (confirmation.DB, error) {
return confirmationdb.Create(db.logger, db.db())
}
func (db *DB) NewOrganizationDB() (organization.DB, error) {
pdb, err := db.NewPoliciesDB()
if err != nil {

View File

@@ -12,7 +12,7 @@ func (db *PaymentMethodsDB) SetArchived(ctx context.Context, accountRef, organiz
// Use the ArchivableDB for the main archiving logic
if err := db.ArchivableDB.SetArchived(ctx, accountRef, objectRef, isArchived); err != nil {
db.DBImp.Logger.Warn("Failed to chnage object archive status", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef),
mzap.AccRef(accountRef), mzap.ObjRef("organization_ref", organizationRef),
mzap.ObjRef("object_ref", objectRef), zap.Bool("archived", isArchived), zap.Bool("cascade", cascade))
return err
}

View File

@@ -14,7 +14,7 @@ func (db *RecipientDB) SetArchived(ctx context.Context, accountRef, organization
// Use the ArchivableDB for the main archiving logic
if err := db.ArchivableDB.SetArchived(ctx, accountRef, objectRef, isArchived); err != nil {
db.DBImp.Logger.Warn("Failed to change recipient archive status", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef),
mzap.AccRef(accountRef), mzap.ObjRef("organization_ref", organizationRef),
mzap.ObjRef("recipient_ref", objectRef), zap.Bool("archived", isArchived), zap.Bool("cascade", cascade))
return err
}
@@ -22,7 +22,7 @@ func (db *RecipientDB) SetArchived(ctx context.Context, accountRef, organization
if cascade {
if err := db.setArchivedPaymentMethods(ctx, accountRef, organizationRef, objectRef, isArchived); err != nil {
db.DBImp.Logger.Warn("Failed to update payment methods archive status", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), mzap.ObjRef("organization_ref", organizationRef),
mzap.AccRef(accountRef), mzap.ObjRef("organization_ref", organizationRef),
mzap.ObjRef("recipient_ref", objectRef), zap.Bool("archived", isArchived), zap.Bool("cascade", cascade))
return err

View File

@@ -78,7 +78,7 @@ func (db *RefreshTokenDB) Revoke(ctx context.Context, accountRef bson.ObjectID,
if err := db.Repository.FindOneByFilter(ctx, f, &rt); err != nil {
if errors.Is(err, merrors.ErrNoData) {
db.Logger.Warn("Failed to find refresh token", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), zap.String("client_id", session.ClientID), zap.String("device_id", session.DeviceID))
mzap.AccRef(accountRef), zap.String("client_id", session.ClientID), zap.String("device_id", session.DeviceID))
return nil
}
return err

View File

@@ -6,68 +6,146 @@ import (
"time"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/db/verification"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
mutil "github.com/tech/sendico/pkg/mutil/db"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/v2/bson"
"go.uber.org/zap"
)
func (db *verificationDB) Consume(
ct context.Context,
accountRef bson.ObjectID,
purpose model.VerificationPurpose,
rawToken string,
) (*model.VerificationToken, error) {
hash := tokenHash(rawToken)
now := time.Now().UTC()
// 1) Find token by hash (do NOT filter by usedAt/expiresAt here),
// otherwise you can't distinguish "used/expired" from "not found".
filter := repository.Query().And(
repository.Filter("verifyTokenHash", hash),
)
t, e := db.tf.CreateTransaction().Execute(
ct,
func(ctx context.Context) (any, error) {
var existing model.VerificationToken
if err := db.DBImp.FindOne(ctx, filter, &existing); err != nil {
// 1) Load active tokens for this context
activeFilter := repository.Query().And(
repository.Filter("accountRef", accountRef),
repository.Filter("purpose", purpose),
repository.Filter("usedAt", nil),
repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now),
)
tokens, err := mutil.GetObjects[model.VerificationToken](
ctx, db.Logger, activeFilter, nil, db.DBImp.Repository,
)
if err != nil {
if errors.Is(err, merrors.ErrNoData) {
db.Logger.Debug("Token hash not found", zap.Error(err), zap.String("hash", hash))
db.Logger.Debug("No tokens found", zap.Error(err), mzap.AccRef(accountRef), zap.String("purpose", string(purpose)))
return nil, verification.ErorrTokenNotFound()
}
db.Logger.Warn("Failed to check token", zap.Error(err), zap.String("hash", hash))
db.Logger.Warn("Failed to load active tokens", zap.Error(err), mzap.AccRef(accountRef), zap.String("purpose", string(purpose)))
return nil, err
}
// 2) Semantic checks
if existing.UsedAt != nil {
db.Logger.Debug(
"Token has already been used",
zap.String("hash", hash),
zap.Time("used_at", *existing.UsedAt),
)
if len(tokens) == 0 {
db.Logger.Debug("No tokens found", zap.Error(err), mzap.AccRef(accountRef), zap.String("purpose", string(purpose)))
return nil, verification.ErorrTokenNotFound()
}
// 2) Find matching token via hasher (OTP or Magic — doesn't matter)
var token *model.VerificationToken
for i := range tokens {
t := &tokens[i]
hash := hasherFor(t).Hash(rawToken, t)
if hash == t.VerifyTokenHash {
token = t
break
}
}
if token == nil {
// wrong code/token → increment attempts
for _, t := range tokens {
_, _ = db.DBImp.PatchMany(
ctx,
repository.IDFilter(t.ID),
repository.Patch().Inc(repository.Field("attempts"), 1),
)
}
return nil, verification.ErorrTokenNotFound()
}
// 3) Static checks
if token.UsedAt != nil {
return nil, verification.ErorrTokenAlreadyUsed()
}
if !existing.ExpiresAt.After(now) { // includes equal time edge-case
db.Logger.Debug(
"Token has already expired",
zap.String("hash", hash),
zap.Time("expired_at", existing.ExpiresAt),
)
if !token.ExpiresAt.After(now) {
return nil, verification.ErorrTokenExpired()
}
if token.MaxRetries != nil && token.Attempts >= *token.MaxRetries {
return nil, verification.ErrorTokenAttemptsExceeded()
}
// 3) Mark as used
existing.UsedAt = &now
if err := db.DBImp.Update(ctx, &existing); err != nil {
db.Logger.Warn("Failed to consume token", zap.Error(err), zap.String("hash", hash))
// 4) Atomic consume
consumeFilter := repository.Query().And(
repository.IDFilter(token.ID),
repository.Filter("accountRef", accountRef),
repository.Filter("purpose", purpose),
repository.Filter("usedAt", nil),
repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now),
)
if token.MaxRetries != nil {
consumeFilter = consumeFilter.And(
repository.Query().Comparison(repository.Field("attempts"), builder.Lt, *token.MaxRetries),
)
}
updated, err := db.DBImp.PatchMany(
ctx,
consumeFilter,
repository.Patch().Set(repository.Field("usedAt"), now),
)
if err != nil {
return nil, err
}
return &existing, nil
if updated == 1 {
token.UsedAt = &now
return token, nil
}
// 5) Consume failed → increment attempts
_, _ = db.DBImp.PatchMany(
ctx,
repository.IDFilter(token.ID),
repository.Patch().Inc(repository.Field("attempts"), 1),
)
// 6) Re-check state
var fresh model.VerificationToken
if err := db.DBImp.FindOne(ctx, repository.IDFilter(token.ID), &fresh); err != nil {
return nil, merrors.Internal("failed to re-check token state")
}
if fresh.UsedAt != nil {
return nil, verification.ErorrTokenAlreadyUsed()
}
if !fresh.ExpiresAt.After(now) {
return nil, verification.ErorrTokenExpired()
}
if fresh.MaxRetries != nil && fresh.Attempts >= *fresh.MaxRetries {
return nil, verification.ErrorTokenAttemptsExceeded()
}
return nil, verification.ErorrTokenNotFound()
},
)
if e != nil {
return nil, e
}
@@ -76,6 +154,5 @@ func (db *verificationDB) Consume(
if !ok {
return nil, merrors.Internal("unexpected token type")
}
return res, nil
}

View File

@@ -2,95 +2,224 @@ package verificationimp
import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"strings"
"time"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/db/verification"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/v2/bson"
"go.uber.org/zap"
)
const verificationTokenBytes = 32
func normalizedIdempotencyKey(value *string) (string, bool) {
if value == nil {
return "", false
}
key := strings.TrimSpace(*value)
if key == "" {
return "", false
}
return key, true
}
func idempotencyFilter(
request *verification.Request,
idempotencyKey string,
) builder.Query {
return repository.Query().And(
repository.Filter("accountRef", request.AccountRef),
repository.Filter("purpose", request.Purpose),
repository.Filter("target", request.Target),
repository.Filter("idempotencyKey", idempotencyKey),
)
}
func hashFilter(hash string) builder.Query {
return repository.Filter("verifyTokenHash", hash)
}
func idempotencySeed(request *verification.Request, idempotencyKey string) string {
return strings.Join([]string{
request.AccountRef.Hex(),
string(request.Purpose),
request.Target,
request.Kind,
idempotencyKey,
}, "|")
}
func newVerificationToken(
accountRef bson.ObjectID,
purpose model.VerificationPurpose,
target string,
ttl time.Duration,
request *verification.Request,
idempotencyKey string,
hasIdempotency bool,
) (*model.VerificationToken, string, error) {
raw := make([]byte, verificationTokenBytes)
if _, err := rand.Read(raw); err != nil {
return nil, "", err
}
rawToken := base64.RawURLEncoding.EncodeToString(raw)
hashStr := tokenHash(rawToken)
now := time.Now().UTC()
token := &model.VerificationToken{
AccountRef: accountRef,
Purpose: purpose,
Target: target,
VerifyTokenHash: hashStr,
UsedAt: nil,
ExpiresAt: now.Add(ttl),
var (
raw string
hash string
salt *string
err error
)
switch request.Kind {
case verification.TokenKindOTP:
if hasIdempotency {
var saltValue string
raw, saltValue, hash = generateDeterministicOTP(idempotencySeed(request, idempotencyKey))
salt = &saltValue
} else {
var s string
raw, s, hash, err = generateOTP()
if err != nil {
return nil, "", err
}
salt = &s
}
default: // Magic token
if hasIdempotency {
raw, hash = generateDeterministicMagic(idempotencySeed(request, idempotencyKey))
} else {
raw, hash, err = generateMagic()
if err != nil {
return nil, "", err
}
}
}
return token, rawToken, nil
token := &model.VerificationToken{
AccountRef: request.AccountRef,
Purpose: request.Purpose,
Target: request.Target,
IdempotencyKey: nil,
VerifyTokenHash: hash,
Salt: salt,
UsedAt: nil,
ExpiresAt: now.Add(request.Ttl),
MaxRetries: request.MaxRetries,
}
if hasIdempotency {
token.IdempotencyKey = &idempotencyKey
}
return token, raw, nil
}
func (db *verificationDB) Create(
ctx context.Context,
accountRef bson.ObjectID,
purpose model.VerificationPurpose,
target string,
ttl time.Duration,
request *verification.Request,
) (string, error) {
logFields := []zap.Field{
zap.String("purpose", string(purpose)), zap.Duration("ttl", ttl),
mzap.AccRef(accountRef), zap.String("target", target),
if request == nil {
return "", merrors.Internal("nil request")
}
token, raw, err := newVerificationToken(accountRef, purpose, target, ttl)
idempotencyKey, hasIdempotency := normalizedIdempotencyKey(request.IdempotencyKey)
token, raw, err := newVerificationToken(request, idempotencyKey, hasIdempotency)
if err != nil {
db.Logger.Warn("Failed to generate verification token", append(logFields, zap.Error(err))...)
return "", err
}
// Invalidate any active tokens for the same (accountRef, purpose, target).
now := time.Now().UTC()
invalidated, err := db.DBImp.PatchMany(ctx,
repository.Query().And(
repository.Filter("accountRef", accountRef),
repository.Filter("purpose", purpose),
repository.Filter("target", target),
_, err = db.tf.CreateTransaction().Execute(ctx, func(tx context.Context) (any, error) {
now := time.Now().UTC()
baseFilter := repository.Query().And(
repository.Filter("accountRef", request.AccountRef),
repository.Filter("purpose", request.Purpose),
repository.Filter("target", request.Target),
repository.Filter("usedAt", nil),
repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now),
),
repository.Patch().Set(repository.Field("usedAt"), now),
)
)
// Optional idempotency key support for safe retries.
if hasIdempotency {
var sameToken model.VerificationToken
err := db.DBImp.FindOne(tx, hashFilter(token.VerifyTokenHash), &sameToken)
switch {
case err == nil:
// Same hash means the same Create operation already succeeded.
return nil, nil
case errors.Is(err, merrors.ErrNoData):
default:
return nil, err
}
var existing model.VerificationToken
err = db.DBImp.FindOne(tx, idempotencyFilter(request, idempotencyKey), &existing)
switch {
case err == nil:
// Existing request with the same idempotency scope has already succeeded.
return nil, nil
case errors.Is(err, merrors.ErrNoData):
default:
return nil, err
}
}
// 1) Cooldown: if there exists ANY active token created after cutoff → block
if request.Cooldown != nil {
cutoff := now.Add(-*request.Cooldown)
cooldownFilter := baseFilter.And(
repository.Query().Comparison(repository.Field("createdAt"), builder.Gt, cutoff),
)
var recent model.VerificationToken
err := db.DBImp.FindOne(tx, cooldownFilter, &recent)
switch {
case err == nil:
return nil, verification.ErrorCooldownActive()
case errors.Is(err, merrors.ErrNoData):
default:
return nil, err
}
}
// 2) Invalidate active tokens for this context
if _, err := db.DBImp.PatchMany(
tx,
baseFilter,
repository.Patch().Set(repository.Field("usedAt"), now),
); err != nil {
return nil, err
}
// 3) Create new token only after cooldown/idempotency checks pass.
if err := db.DBImp.Create(tx, token); err != nil {
if hasIdempotency && errors.Is(err, merrors.ErrDataConflict) {
var sameToken model.VerificationToken
findErr := db.DBImp.FindOne(tx, hashFilter(token.VerifyTokenHash), &sameToken)
switch {
case findErr == nil:
return nil, nil
case errors.Is(findErr, merrors.ErrNoData):
default:
return nil, findErr
}
var existing model.VerificationToken
findErr = db.DBImp.FindOne(tx, idempotencyFilter(request, idempotencyKey), &existing)
switch {
case findErr == nil:
return nil, nil
case errors.Is(findErr, merrors.ErrNoData):
default:
return nil, findErr
}
}
return nil, err
}
return nil, nil
})
if err != nil {
db.Logger.Warn("Failed to invalidate previous tokens", append(logFields, zap.Error(err))...)
return "", err
}
if invalidated > 0 {
db.Logger.Debug("Invalidated previous tokens", append(logFields, zap.Int("count", invalidated))...)
}
if err := db.DBImp.Create(ctx, token); err != nil {
db.Logger.Warn("Failed to persist verification token", append(logFields, zap.Error(err))...)
return "", err
}
db.Logger.Debug("Verification token created", append(logFields, zap.String("hash", token.VerifyTokenHash))...)
return raw, nil
}

View File

@@ -35,6 +35,21 @@ func Create(
return nil, err
}
if err := p.Repository.CreateIndex(&ri.Definition{
Keys: []ri.Key{
{Field: "accountRef", Sort: ri.Asc},
{Field: "purpose", Sort: ri.Asc},
{Field: "target", Sort: ri.Asc},
{Field: "idempotencyKey", Sort: ri.Asc},
},
Unique: true,
Sparse: true,
Name: "uniq_verification_context_idempotency",
}); err != nil {
p.Logger.Error("Failed to create unique idempotency index on verification context", zap.Error(err))
return nil, err
}
ttl := int32(2678400) // 30 days
if err := p.Repository.CreateIndex(&ri.Definition{
Keys: []ri.Key{{Field: "expiresAt", Sort: ri.Asc}},

View File

@@ -1,10 +1,105 @@
package verificationimp
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"github.com/tech/sendico/pkg/model"
)
type TokenHasher interface {
Hash(raw string, token *model.VerificationToken) string
}
type magicHasher struct{}
func (magicHasher) Hash(raw string, _ *model.VerificationToken) string {
return tokenHash(raw)
}
type otpHasher struct{}
func (otpHasher) Hash(raw string, t *model.VerificationToken) string {
return otpHash(raw, *t.Salt)
}
func hasherFor(t *model.VerificationToken) TokenHasher {
if t.Salt != nil {
return otpHasher{}
}
return magicHasher{}
}
const verificationTokenBytes = 32
const otpDigits = 6
func generateMagic() (raw, hash string, err error) {
rawBytes := make([]byte, verificationTokenBytes)
if _, err = rand.Read(rawBytes); err != nil {
return
}
raw = base64.RawURLEncoding.EncodeToString(rawBytes)
hash = tokenHash(raw)
return
}
func generateDeterministicMagic(seed string) (raw, hash string) {
sum := sha256.Sum256([]byte("magic:" + seed))
raw = base64.RawURLEncoding.EncodeToString(sum[:])
hash = tokenHash(raw)
return
}
func generateOTP() (code, salt, hash string, err error) {
// otpDigits-digit code
n := make([]byte, 4)
if _, err = rand.Read(n); err != nil {
return
}
num := int(n[0])<<24 | int(n[1])<<16 | int(n[2])<<8 | int(n[3])
mod := 1
for i := 0; i < otpDigits; i++ {
mod *= 10
}
code = fmt.Sprintf("%0*d", otpDigits, num%mod)
// per-token salt
saltBytes := make([]byte, 16)
if _, err = rand.Read(saltBytes); err != nil {
return
}
salt = base64.RawURLEncoding.EncodeToString(saltBytes)
hash = otpHash(code, salt)
return
}
func generateDeterministicOTP(seed string) (code, salt, hash string) {
sum := sha256.Sum256([]byte("otp:" + seed))
num := int(sum[0])<<24 | int(sum[1])<<16 | int(sum[2])<<8 | int(sum[3])
mod := 1
for i := 0; i < otpDigits; i++ {
mod *= 10
}
code = fmt.Sprintf("%0*d", otpDigits, num%mod)
salt = base64.RawURLEncoding.EncodeToString(sum[4:20])
hash = otpHash(code, salt)
return
}
// We store only the resulting hash (+salt) in DB, never the OTP itself.
func otpHash(code, salt string) string {
sum := sha256.Sum256([]byte(salt + ":" + code))
return hex.EncodeToString(sum[:])
}
func tokenHash(rawToken string) string {
hash := sha256.Sum256([]byte(rawToken))
return base64.RawURLEncoding.EncodeToString(hash[:])

View File

@@ -20,6 +20,7 @@ import (
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.uber.org/zap"
)
@@ -28,6 +29,10 @@ import (
// ---------------------------------------------------------------------------
func newTestVerificationDB(t *testing.T) *verificationDB {
return newTestVerificationDBWithFactory(t, &passthroughTxFactory{})
}
func newTestVerificationDBWithFactory(t *testing.T, tf transaction.Factory) *verificationDB {
t.Helper()
repo := newMemoryTokenRepository()
logger := zap.NewNop()
@@ -36,7 +41,7 @@ func newTestVerificationDB(t *testing.T) *verificationDB {
Logger: logger,
Repository: repo,
},
tf: &passthroughTxFactory{},
tf: tf,
}
}
@@ -51,6 +56,20 @@ func (*passthroughTx) Execute(ctx context.Context, cb transaction.Callback) (any
return cb(ctx)
}
// retryingTxFactory simulates transaction callbacks being executed more than once.
type retryingTxFactory struct{}
func (*retryingTxFactory) CreateTransaction() transaction.Transaction { return &retryingTx{} }
type retryingTx struct{}
func (*retryingTx) Execute(ctx context.Context, cb transaction.Callback) (any, error) {
if _, err := cb(ctx); err != nil {
return nil, err
}
return cb(ctx)
}
// ---------------------------------------------------------------------------
// in-memory repository for VerificationToken
// ---------------------------------------------------------------------------
@@ -156,8 +175,34 @@ func (m *memoryTokenRepository) InsertMany(ctx context.Context, objs []storable.
}
return nil
}
func (m *memoryTokenRepository) FindManyByFilter(context.Context, builder.Query, rd.DecodingFunc) error {
return merrors.NotImplemented("not needed")
func (m *memoryTokenRepository) FindManyByFilter(_ context.Context, query builder.Query, decoder rd.DecodingFunc) error {
m.mu.Lock()
var matches []interface{}
for _, id := range m.order {
tok := m.data[id]
if tok != nil && matchToken(query, tok) {
raw, err := bson.Marshal(cloneToken(tok))
if err != nil {
m.mu.Unlock()
return err
}
matches = append(matches, bson.Raw(raw))
}
}
m.mu.Unlock()
cur, err := mongo.NewCursorFromDocuments(matches, nil, nil)
if err != nil {
return err
}
defer cur.Close(context.Background())
for cur.Next(context.Background()) {
if err := decoder(cur); err != nil {
return err
}
}
return nil
}
func (m *memoryTokenRepository) Patch(context.Context, bson.ObjectID, builder.Patch) error {
return merrors.NotImplemented("not needed")
@@ -190,8 +235,14 @@ func (m *memoryTokenRepository) Collection() string { return mservice.Verificati
// tokenFieldValue returns the stored value for a given BSON field name.
func tokenFieldValue(tok *model.VerificationToken, field string) any {
switch field {
case "_id":
return tok.ID
case "createdAt":
return tok.CreatedAt
case "verifyTokenHash":
return tok.VerifyTokenHash
case "salt":
return tok.Salt
case "usedAt":
return tok.UsedAt
case "expiresAt":
@@ -202,6 +253,15 @@ func tokenFieldValue(tok *model.VerificationToken, field string) any {
return tok.Purpose
case "target":
return tok.Target
case "idempotencyKey":
if tok.IdempotencyKey == nil {
return nil
}
return *tok.IdempotencyKey
case "maxRetries":
return tok.MaxRetries
case "attempts":
return tok.Attempts
default:
return nil
}
@@ -261,11 +321,11 @@ func matchOperator(stored any, ops bson.M) bool {
for op, cmpVal := range ops {
switch op {
case "$gt":
if !timeGt(stored, cmpVal) {
if !cmpGt(stored, cmpVal) {
return false
}
case "$lt":
if !timeLt(stored, cmpVal) {
if !cmpLt(stored, cmpVal) {
return false
}
}
@@ -273,6 +333,36 @@ func matchOperator(stored any, ops bson.M) bool {
return true
}
func cmpGt(stored, cmpVal any) bool {
if si, ok := toInt(stored); ok {
if ci, ok := toInt(cmpVal); ok {
return si > ci
}
}
return timeGt(stored, cmpVal)
}
func cmpLt(stored, cmpVal any) bool {
if si, ok := toInt(stored); ok {
if ci, ok := toInt(cmpVal); ok {
return si < ci
}
}
return timeLt(stored, cmpVal)
}
func toInt(v any) (int, bool) {
switch iv := v.(type) {
case int:
return iv, true
case int64:
return int(iv), true
case int32:
return int(iv), true
}
return 0, false
}
func valuesEqual(a, b any) bool {
// nil checks: usedAt == nil
if b == nil {
@@ -343,21 +433,34 @@ func toTime(v any) (time.Time, bool) {
return time.Time{}, false
}
// applyPatch applies $set operations from a patch bson.D to a token.
// applyPatch applies $set and $inc operations from a patch bson.D to a token.
func applyPatch(tok *model.VerificationToken, patchDoc bson.D) {
for _, op := range patchDoc {
if op.Key != "$set" {
continue
}
fields, ok := op.Value.(bson.D)
if !ok {
continue
}
for _, f := range fields {
switch f.Key {
case "usedAt":
if t, ok := f.Value.(time.Time); ok {
tok.UsedAt = &t
switch op.Key {
case "$set":
fields, ok := op.Value.(bson.D)
if !ok {
continue
}
for _, f := range fields {
switch f.Key {
case "usedAt":
if t, ok := f.Value.(time.Time); ok {
tok.UsedAt = &t
}
}
}
case "$inc":
fields, ok := op.Value.(bson.D)
if !ok {
continue
}
for _, f := range fields {
switch f.Key {
case "attempts":
if v, ok := f.Value.(int); ok {
tok.Attempts += v
}
}
}
}
@@ -370,20 +473,27 @@ func cloneToken(src *model.VerificationToken) *model.VerificationToken {
t := *src.UsedAt
dst.UsedAt = &t
}
if src.MaxRetries != nil {
v := *src.MaxRetries
dst.MaxRetries = &v
}
if src.Salt != nil {
s := *src.Salt
dst.Salt = &s
}
if src.IdempotencyKey != nil {
k := *src.IdempotencyKey
dst.IdempotencyKey = &k
}
return &dst
}
// allTokens returns every stored token for inspection in tests.
func (m *memoryTokenRepository) allTokens() []*model.VerificationToken {
m.mu.Lock()
defer m.mu.Unlock()
out := make([]*model.VerificationToken, 0, len(m.data))
for _, id := range m.order {
if tok, ok := m.data[id]; ok {
out = append(out, cloneToken(tok))
}
}
return out
// ---------------------------------------------------------------------------
// helpers request builder
// ---------------------------------------------------------------------------
func req(accountRef bson.ObjectID, purpose model.VerificationPurpose, target string, ttl time.Duration) *verification.Request {
return verification.NewLinkRequest(accountRef, purpose, target).WithTTL(ttl)
}
// ---------------------------------------------------------------------------
@@ -395,7 +505,7 @@ func TestCreate_ReturnsRawToken(t *testing.T) {
ctx := context.Background()
accountRef := bson.NewObjectID()
raw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
assert.NotEmpty(t, raw)
}
@@ -405,10 +515,10 @@ func TestCreate_TokenCanBeConsumed(t *testing.T) {
ctx := context.Background()
accountRef := bson.NewObjectID()
raw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
tok, err := db.Consume(ctx, raw)
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
require.NoError(t, err)
assert.Equal(t, accountRef, tok.AccountRef)
assert.Equal(t, model.PurposePasswordReset, tok.Purpose)
@@ -420,10 +530,10 @@ func TestConsume_ReturnsCorrectFields(t *testing.T) {
ctx := context.Background()
accountRef := bson.NewObjectID()
raw, err := db.Create(ctx, accountRef, model.PurposeEmailChange, "new@example.com", time.Hour)
raw, err := db.Create(ctx, req(accountRef, model.PurposeEmailChange, "new@example.com", time.Hour))
require.NoError(t, err)
tok, err := db.Consume(ctx, raw)
tok, err := db.Consume(ctx, accountRef, model.PurposeEmailChange, raw)
require.NoError(t, err)
assert.Equal(t, accountRef, tok.AccountRef)
assert.Equal(t, model.PurposeEmailChange, tok.Purpose)
@@ -435,16 +545,16 @@ func TestConsume_SecondConsumeFailsAlreadyUsed(t *testing.T) {
ctx := context.Background()
accountRef := bson.NewObjectID()
raw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
_, err = db.Consume(ctx, raw)
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
require.NoError(t, err)
_, err = db.Consume(ctx, raw)
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
require.Error(t, err)
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed),
"second consume should fail because usedAt is set")
assert.True(t, errors.Is(err, verification.ErrTokenNotFound),
"second consume should fail — used tokens are excluded from active filter")
}
func TestConsume_ExpiredTokenFails(t *testing.T) {
@@ -453,20 +563,20 @@ func TestConsume_ExpiredTokenFails(t *testing.T) {
accountRef := bson.NewObjectID()
// Create with a TTL that is already in the past.
raw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", -time.Hour)
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", -time.Hour))
require.NoError(t, err)
_, err = db.Consume(ctx, raw)
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
require.Error(t, err)
assert.True(t, errors.Is(err, verification.ErrTokenExpired),
"expired token should not be consumable")
assert.True(t, errors.Is(err, verification.ErrTokenNotFound),
"expired token is excluded from active filter")
}
func TestConsume_UnknownTokenFails(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
_, err := db.Consume(ctx, "nonexistent-token-value")
_, err := db.Consume(ctx, bson.NilObjectID, "", "nonexistent-token-value")
require.Error(t, err)
assert.True(t, errors.Is(err, verification.ErrTokenNotFound))
}
@@ -476,21 +586,21 @@ func TestCreate_InvalidatesPreviousToken(t *testing.T) {
ctx := context.Background()
accountRef := bson.NewObjectID()
oldRaw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
oldRaw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
newRaw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
newRaw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
assert.NotEqual(t, oldRaw, newRaw, "new token should differ from old one")
// Old token is no longer consumable.
_, err = db.Consume(ctx, oldRaw)
// Old token is no longer consumable — invalidated (usedAt set) by the second Create.
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, oldRaw)
require.Error(t, err)
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed),
"old token should be invalidated (usedAt set) after new token creation")
assert.True(t, errors.Is(err, verification.ErrTokenNotFound),
"old token should be invalidated after new token creation")
// New token works fine.
tok, err := db.Consume(ctx, newRaw)
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, newRaw)
require.NoError(t, err)
assert.Equal(t, accountRef, tok.AccountRef)
}
@@ -500,19 +610,19 @@ func TestCreate_InvalidatesMultiplePreviousTokens(t *testing.T) {
ctx := context.Background()
accountRef := bson.NewObjectID()
first, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
first, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
second, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
second, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
third, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
third, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
_, err = db.Consume(ctx, first)
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed), "first should be invalidated")
_, err = db.Consume(ctx, second)
assert.True(t, errors.Is(err, verification.ErrTokenAlreadyUsed), "second should be invalidated")
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, first)
assert.True(t, errors.Is(err, verification.ErrTokenNotFound), "first should be invalidated")
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, second)
assert.True(t, errors.Is(err, verification.ErrTokenNotFound), "second should be invalidated")
tok, err := db.Consume(ctx, third)
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, third)
require.NoError(t, err)
assert.Equal(t, accountRef, tok.AccountRef)
}
@@ -522,14 +632,14 @@ func TestCreate_DifferentPurposeNotInvalidated(t *testing.T) {
ctx := context.Background()
accountRef := bson.NewObjectID()
resetRaw, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
resetRaw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
// Creating an activation token should NOT invalidate the password-reset token.
_, err = db.Create(ctx, accountRef, model.PurposeAccountActivation, "", time.Hour)
_, err = db.Create(ctx, req(accountRef, model.PurposeAccountActivation, "", time.Hour))
require.NoError(t, err)
tok, err := db.Consume(ctx, resetRaw)
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, resetRaw)
require.NoError(t, err)
assert.Equal(t, model.PurposePasswordReset, tok.Purpose)
}
@@ -539,14 +649,14 @@ func TestCreate_DifferentTargetNotInvalidated(t *testing.T) {
ctx := context.Background()
accountRef := bson.NewObjectID()
firstRaw, err := db.Create(ctx, accountRef, model.PurposeEmailChange, "a@example.com", time.Hour)
firstRaw, err := db.Create(ctx, req(accountRef, model.PurposeEmailChange, "a@example.com", time.Hour))
require.NoError(t, err)
// Creating a token for a different target email should NOT invalidate the first.
_, err = db.Create(ctx, accountRef, model.PurposeEmailChange, "b@example.com", time.Hour)
_, err = db.Create(ctx, req(accountRef, model.PurposeEmailChange, "b@example.com", time.Hour))
require.NoError(t, err)
tok, err := db.Consume(ctx, firstRaw)
tok, err := db.Consume(ctx, accountRef, model.PurposeEmailChange, firstRaw)
require.NoError(t, err)
assert.Equal(t, "a@example.com", tok.Target)
}
@@ -557,13 +667,13 @@ func TestCreate_DifferentAccountNotInvalidated(t *testing.T) {
account1 := bson.NewObjectID()
account2 := bson.NewObjectID()
raw1, err := db.Create(ctx, account1, model.PurposePasswordReset, "", time.Hour)
raw1, err := db.Create(ctx, req(account1, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
_, err = db.Create(ctx, account2, model.PurposePasswordReset, "", time.Hour)
_, err = db.Create(ctx, req(account2, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
tok, err := db.Consume(ctx, raw1)
tok, err := db.Consume(ctx, account1, model.PurposePasswordReset, raw1)
require.NoError(t, err)
assert.Equal(t, account1, tok.AccountRef)
}
@@ -574,18 +684,18 @@ func TestCreate_AlreadyUsedTokenNotInvalidatedAgain(t *testing.T) {
accountRef := bson.NewObjectID()
// Create and consume first token.
raw1, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
raw1, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
_, err = db.Consume(ctx, raw1)
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw1)
require.NoError(t, err)
// Create second — the already-consumed token should have usedAt set,
// so the invalidation query (usedAt == nil) should skip it.
// This tests that the PatchMany filter correctly excludes already-used tokens.
raw2, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
raw2, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
tok, err := db.Consume(ctx, raw2)
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, raw2)
require.NoError(t, err)
assert.Equal(t, accountRef, tok.AccountRef)
}
@@ -596,14 +706,14 @@ func TestCreate_ExpiredTokenNotInvalidated(t *testing.T) {
accountRef := bson.NewObjectID()
// Create a token that is already expired.
_, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", -time.Hour)
_, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", -time.Hour))
require.NoError(t, err)
// Create a fresh one — invalidation should skip the expired token (expiresAt > now filter).
raw2, err := db.Create(ctx, accountRef, model.PurposePasswordReset, "", time.Hour)
raw2, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
tok, err := db.Consume(ctx, raw2)
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, raw2)
require.NoError(t, err)
assert.Equal(t, accountRef, tok.AccountRef)
}
@@ -619,3 +729,228 @@ func TestTokenHash_DifferentInputs(t *testing.T) {
h2 := tokenHash("input-b")
assert.NotEqual(t, h1, h2)
}
// ---------------------------------------------------------------------------
// cooldown tests
// ---------------------------------------------------------------------------
func TestCreate_CooldownBlocksCreation(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
accountRef := bson.NewObjectID()
// First creation without cooldown.
_, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
// Immediate re-create with cooldown should be blocked — token is too recent to invalidate.
r2 := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithCooldown(time.Minute)
_, err = db.Create(ctx, r2)
require.Error(t, err)
assert.True(t, errors.Is(err, verification.ErrCooldownActive))
}
func TestCreate_CooldownExpiresAllowsCreation(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
accountRef := bson.NewObjectID()
// First creation without cooldown.
_, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
time.Sleep(2 * time.Millisecond)
// Re-create with short cooldown — the prior token is old enough to be invalidated.
r2 := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithCooldown(time.Millisecond)
_, err = db.Create(ctx, r2)
require.NoError(t, err)
}
func TestCreate_CooldownNilIgnored(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
accountRef := bson.NewObjectID()
_, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
// No cooldown set — immediate re-create should succeed.
_, err = db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
}
func TestCreate_IdempotencyKeyReplayReturnsSameToken(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
accountRef := bson.NewObjectID()
firstReq := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithIdempotencyKey("same-key")
firstRaw, err := db.Create(ctx, firstReq)
require.NoError(t, err)
require.NotEmpty(t, firstRaw)
// Replay with the same idempotency key should return success and same token.
secondReq := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithIdempotencyKey("same-key")
secondRaw, err := db.Create(ctx, secondReq)
require.NoError(t, err)
assert.Equal(t, firstRaw, secondRaw)
repo := db.Repository.(*memoryTokenRepository)
repo.mu.Lock()
assert.Len(t, repo.data, 1)
repo.mu.Unlock()
}
func TestCreate_IdempotencyScopeIncludesTarget(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
accountRef := bson.NewObjectID()
r1 := req(accountRef, model.PurposeEmailChange, "a@example.com", time.Hour).WithIdempotencyKey("same-key")
raw1, err := db.Create(ctx, r1)
require.NoError(t, err)
require.NotEmpty(t, raw1)
// Same account/purpose/key but different target should be treated as a different idempotency scope.
r2 := req(accountRef, model.PurposeEmailChange, "b@example.com", time.Hour).WithIdempotencyKey("same-key")
raw2, err := db.Create(ctx, r2)
require.NoError(t, err)
require.NotEmpty(t, raw2)
assert.NotEqual(t, raw1, raw2)
t1, err := db.Consume(ctx, accountRef, model.PurposeEmailChange, raw1)
require.NoError(t, err)
assert.Equal(t, "a@example.com", t1.Target)
t2, err := db.Consume(ctx, accountRef, model.PurposeEmailChange, raw2)
require.NoError(t, err)
assert.Equal(t, "b@example.com", t2.Target)
}
func TestCreate_IdempotencySurvivesCallbackRetry(t *testing.T) {
db := newTestVerificationDBWithFactory(t, &retryingTxFactory{})
ctx := context.Background()
accountRef := bson.NewObjectID()
// Cooldown would block the second callback execution if idempotency wasn't handled.
r := req(accountRef, model.PurposePasswordReset, "", time.Hour).
WithCooldown(time.Minute).
WithIdempotencyKey("retry-safe")
raw, err := db.Create(ctx, r)
require.NoError(t, err)
require.NotEmpty(t, raw)
repo := db.Repository.(*memoryTokenRepository)
repo.mu.Lock()
require.Len(t, repo.data, 1)
for _, tok := range repo.data {
require.NotNil(t, tok.IdempotencyKey)
assert.Equal(t, "retry-safe", *tok.IdempotencyKey)
assert.Nil(t, tok.UsedAt)
assert.Equal(t, tok.VerifyTokenHash, tokenHash(raw))
}
repo.mu.Unlock()
}
// ---------------------------------------------------------------------------
// max retries / attempts tests
// ---------------------------------------------------------------------------
func TestConsume_MaxRetriesExceeded(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
accountRef := bson.NewObjectID()
r := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithMaxRetries(2)
raw, err := db.Create(ctx, r)
require.NoError(t, err)
// Simulate 2 prior failed attempts by setting Attempts directly.
repo := db.Repository.(*memoryTokenRepository)
repo.mu.Lock()
for _, tok := range repo.data {
tok.Attempts = 2
}
repo.mu.Unlock()
// Consume with correct token should fail — attempts already at max.
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
require.Error(t, err)
assert.True(t, errors.Is(err, verification.ErrTokenAttemptsExceeded))
}
func TestConsume_UnderMaxRetriesSucceeds(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
accountRef := bson.NewObjectID()
r := req(accountRef, model.PurposePasswordReset, "", time.Hour).WithMaxRetries(3)
raw, err := db.Create(ctx, r)
require.NoError(t, err)
// Simulate 2 prior failed attempts (under maxRetries=3).
repo := db.Repository.(*memoryTokenRepository)
repo.mu.Lock()
for _, tok := range repo.data {
tok.Attempts = 2
}
repo.mu.Unlock()
// Consume with correct token should succeed.
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
require.NoError(t, err)
assert.Equal(t, accountRef, tok.AccountRef)
}
func TestConsume_NoMaxRetriesIgnoresAttempts(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
accountRef := bson.NewObjectID()
// Create without MaxRetries.
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
// Simulate high attempt count — should be ignored since MaxRetries is nil.
repo := db.Repository.(*memoryTokenRepository)
repo.mu.Lock()
for _, tok := range repo.data {
tok.Attempts = 100
}
repo.mu.Unlock()
tok, err := db.Consume(ctx, accountRef, model.PurposePasswordReset, raw)
require.NoError(t, err)
assert.Equal(t, accountRef, tok.AccountRef)
}
func TestConsume_WrongHashReturnsNotFound(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
accountRef := bson.NewObjectID()
_, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
// Wrong code — hash won't match any token.
_, err = db.Consume(ctx, accountRef, model.PurposePasswordReset, "wrong-code")
require.Error(t, err)
assert.True(t, errors.Is(err, verification.ErrTokenNotFound))
}
func TestConsume_ContextMismatchReturnsNotFound(t *testing.T) {
db := newTestVerificationDB(t)
ctx := context.Background()
accountRef := bson.NewObjectID()
otherAccount := bson.NewObjectID()
raw, err := db.Create(ctx, req(accountRef, model.PurposePasswordReset, "", time.Hour))
require.NoError(t, err)
// Correct token but wrong accountRef — context mismatch.
_, err = db.Consume(ctx, otherAccount, model.PurposePasswordReset, raw)
require.Error(t, err)
assert.True(t, errors.Is(err, verification.ErrTokenNotFound))
}

View File

@@ -6,9 +6,12 @@ import (
)
var (
ErrTokenNotFound = errors.New("vtNotFound")
ErrTokenAlreadyUsed = errors.New("vtAlreadyUsed")
ErrTokenExpired = errors.New("vtExpired")
ErrTokenNotFound = errors.New("vtNotFound")
ErrTokenAlreadyUsed = errors.New("vtAlreadyUsed")
ErrTokenExpired = errors.New("vtExpired")
ErrTokenAttemptsExceeded = errors.New("vtAttemptsExceeded")
ErrCooldownActive = errors.New("vtCooldownActive")
ErrIdempotencyConflict = errors.New("vtIdempotencyConflict")
)
func wrap(err error, msg string) error {
@@ -26,3 +29,15 @@ func ErorrTokenAlreadyUsed() error {
func ErorrTokenExpired() error {
return wrap(ErrTokenExpired, "verification token expired")
}
func ErrorCooldownActive() error {
return wrap(ErrCooldownActive, "token creation cooldown is active")
}
func ErrorTokenAttemptsExceeded() error {
return wrap(ErrTokenAttemptsExceeded, "verification token max attempts exceeded")
}
func ErrorIdempotencyConflict() error {
return wrap(ErrIdempotencyConflict, "verification token request idempotency key has already been used")
}

View File

@@ -0,0 +1,70 @@
package verification
import (
"strings"
"time"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/v2/bson"
)
type TokenKind = string
const (
TokenKindOTP TokenKind = "otp"
TokenKindLink TokenKind = "link"
)
type Request struct {
AccountRef bson.ObjectID
Purpose model.VerificationPurpose
Target string
Ttl time.Duration
Kind TokenKind
MaxRetries *int
Cooldown *time.Duration
IdempotencyKey *string // Optional key to make Create idempotent for retries.
}
func newRequest(accountRef bson.ObjectID, purpose model.VerificationPurpose, target string, kind TokenKind) *Request {
return &Request{
AccountRef: accountRef,
Purpose: purpose,
Target: target,
Kind: kind,
Ttl: 15 * time.Minute, // default TTL for verification tokens
}
}
func NewLinkRequest(accountRef bson.ObjectID, purpose model.VerificationPurpose, target string) *Request {
return newRequest(accountRef, purpose, target, TokenKindLink)
}
func NewOTPRequest(accountRef bson.ObjectID, purpose model.VerificationPurpose, target string) *Request {
return newRequest(accountRef, purpose, target, TokenKindOTP)
}
func (r *Request) WithTTL(ttl time.Duration) *Request {
r.Ttl = ttl
return r
}
func (r *Request) WithMaxRetries(maxRetries int) *Request {
r.MaxRetries = &maxRetries
return r
}
func (r *Request) WithCooldown(cooldown time.Duration) *Request {
r.Cooldown = &cooldown
return r
}
func (r *Request) WithIdempotencyKey(key string) *Request {
normalized := strings.TrimSpace(key)
if normalized == "" {
r.IdempotencyKey = nil
return r
}
r.IdempotencyKey = &normalized
return r
}

View File

@@ -2,7 +2,6 @@ package verification
import (
"context"
"time"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/v2/bson"
@@ -10,12 +9,6 @@ import (
type DB interface {
// template.DB[*model.VerificationToken]
Create(
ctx context.Context,
accountRef bson.ObjectID,
purpose model.VerificationPurpose,
target string,
ttl time.Duration,
) (rawToken string, err error)
Consume(ctx context.Context, rawToken string) (*model.VerificationToken, error)
Create(ctx context.Context, request *Request) (verificationToken string, err error)
Consume(ctx context.Context, accountRef bson.ObjectID, purpose model.VerificationPurpose, verificationToken string) (*model.VerificationToken, error)
}