package verificationimp

import (
	"context"
	"errors"
	"strings"
	"time"

	"github.com/tech/sendico/pkg/db/repository"
	"github.com/tech/sendico/pkg/db/repository/builder"
	"github.com/tech/sendico/pkg/db/verification"
	"github.com/tech/sendico/pkg/merrors"
	"github.com/tech/sendico/pkg/model"
	"go.mongodb.org/mongo-driver/v2/bson"
)

// normalizedIdempotencyKey returns the whitespace-trimmed idempotency key and
// true, or "" and false when value is nil or blank after trimming.
func normalizedIdempotencyKey(value *string) (string, bool) {
	if value == nil {
		return "", false
	}
	key := strings.TrimSpace(*value)
	if key == "" {
		return "", false
	}
	return key, true
}

// syntheticIdempotencyKey returns a fresh internal key of the form
// "auto:<ObjectID hex>" for requests that did not supply one.
func syntheticIdempotencyKey() string {
	return "auto:" + bson.NewObjectID().Hex()
}

// verificationContextFilter matches every token that belongs to the request's
// (accountRef, purpose, target) context.
func verificationContextFilter(request *verification.Request) builder.Query {
	return repository.Query().And(
		repository.Filter("accountRef", request.AccountRef),
		repository.Filter("purpose", request.Purpose),
		repository.Filter("target", request.Target),
	)
}

// activeContextFilter matches tokens in the request's context that are still
// active at now: usedAt is nil and expiresAt is later than now.
func activeContextFilter(request *verification.Request, now time.Time) builder.Query {
	return repository.Query().And(
		repository.Filter("accountRef", request.AccountRef),
		repository.Filter("purpose", request.Purpose),
		repository.Filter("target", request.Target),
		repository.Filter("usedAt", nil),
		repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now),
	)
}

// cooldownActiveContextFilter matches active tokens (see activeContextFilter)
// in the request's context whose createdAt is later than cutoff.
func cooldownActiveContextFilter(request *verification.Request, now, cutoff time.Time) builder.Query {
	return repository.Query().And(
		repository.Filter("accountRef", request.AccountRef),
		repository.Filter("purpose", request.Purpose),
		repository.Filter("target", request.Target),
		repository.Filter("usedAt", nil),
		repository.Query().Comparison(repository.Field("expiresAt"), builder.Gt, now),
		repository.Query().Comparison(repository.Field("createdAt"), builder.Gt, cutoff),
	)
}

// idempotencyFilter matches tokens in the request's context that carry the
// given idempotency key, regardless of whether they are used or expired.
func idempotencyFilter(
	request *verification.Request,
	idempotencyKey string,
) builder.Query {
	return repository.Query().And(
		repository.Filter("accountRef", request.AccountRef),
		repository.Filter("purpose", request.Purpose),
		repository.Filter("target", request.Target),
		repository.Filter("idempotencyKey", idempotencyKey),
	)
}

// hashFilter matches the token stored with the given verify-token hash.
func hashFilter(hash string) builder.Query {
	return repository.Filter("verifyTokenHash", hash)
}

// idempotencySeed builds the "|"-joined seed passed to the deterministic token
// generators. Unlike idempotencyFilter, the seed also includes the token kind.
func idempotencySeed(request *verification.Request, idempotencyKey string) string {
	return strings.Join([]string{
		request.AccountRef.Hex(),
		string(request.Purpose),
		request.Target,
		request.Kind,
		idempotencyKey,
	}, "|")
}

// newVerificationToken builds the token document for request and returns it
// together with the raw secret to hand back to the caller.
//
// When hasIdempotency is true the secret is derived deterministically from
// idempotencySeed and the key is stored on the token; otherwise a fresh secret
// is generated and IdempotencyKey stays nil. OTP tokens carry a salt, every
// other kind is treated as a magic token and has a nil salt. The error, if
// any, comes from the non-deterministic generators.
func newVerificationToken(
	request *verification.Request,
	idempotencyKey string,
	hasIdempotency bool,
) (*model.VerificationToken, string, error) {
	now := time.Now().UTC()

	var (
		raw  string
		hash string
		salt *string
		err  error
	)

	switch request.Kind {
	case verification.TokenKindOTP:
		var saltValue string
		if hasIdempotency {
			raw, saltValue, hash = generateDeterministicOTP(idempotencySeed(request, idempotencyKey))
		} else {
			raw, saltValue, hash, err = generateOTP()
			if err != nil {
				return nil, "", err
			}
		}
		salt = &saltValue
	default:
		// Magic token
		if hasIdempotency {
			raw, hash = generateDeterministicMagic(idempotencySeed(request, idempotencyKey))
		} else {
			raw, hash, err = generateMagic()
			if err != nil {
				return nil, "", err
			}
		}
	}

	token := &model.VerificationToken{
		AccountRef:      request.AccountRef,
		Purpose:         request.Purpose,
		Target:          request.Target,
		IdempotencyKey:  nil,
		VerifyTokenHash: hash,
		Salt:            salt,
		UsedAt:          nil,
		ExpiresAt:       now.Add(request.Ttl),
		MaxRetries:      request.MaxRetries,
	}
	if hasIdempotency {
		token.IdempotencyKey = &idempotencyKey
	}

	return token, raw, nil
}

// idempotentTokenExists reports whether a token produced by the same Create
// call is already stored. It first looks for a token with the identical
// verify-token hash and, if none is found, for any token in the same
// (accountRef, purpose, target, idempotencyKey) scope. A merrors.ErrNoData
// result from a lookup means "not found"; any other lookup error is returned.
func (db *verificationDB) idempotentTokenExists(
	ctx context.Context,
	request *verification.Request,
	tokenHash string,
	idempotencyKey string,
) (bool, error) {
	lookup := func(filter builder.Query) (bool, error) {
		var found model.VerificationToken
		err := db.DBImp.FindOne(ctx, filter, &found)
		switch {
		case err == nil:
			return true, nil
		case errors.Is(err, merrors.ErrNoData):
			return false, nil
		default:
			return false, err
		}
	}

	if found, err := lookup(hashFilter(tokenHash)); found || err != nil {
		return found, err
	}
	return lookup(idempotencyFilter(request, idempotencyKey))
}

// Create issues a new verification token for request and returns its raw
// secret.
//
// Inside a single transaction it:
//   - returns early (success) if a token for the same idempotency scope or
//     with the same hash already exists;
//   - fails with verification.ErrorCooldownActive when request.Cooldown is set
//     and an active token newer than the cooldown cutoff exists;
//   - marks all active tokens of the same context as used;
//   - stores the new token, treating a data conflict as success when the
//     conflicting token turns out to be an idempotent replay.
//
// A nil request yields merrors.Internal("nil request").
func (db *verificationDB) Create(
	ctx context.Context,
	request *verification.Request,
) (string, error) {
	if request == nil {
		return "", merrors.Internal("nil request")
	}

	idempotencyKey, provided := normalizedIdempotencyKey(request.IdempotencyKey)
	if !provided {
		// Legacy deployments may still enforce uniqueness on
		// (accountRef, purpose, target, idempotencyKey), where a missing
		// idempotency key behaves like a shared null key. Assign an internal
		// per-request key so token reissue works even when callers do not
		// provide idempotency explicitly.
		idempotencyKey = syntheticIdempotencyKey()
	}

	// From here on every request has an idempotency key, so the token is always
	// built through the deterministic generators.
	// NOTE(review): for synthetic keys the seed is made of request fields plus
	// an ObjectID; confirm the deterministic generators mix in a server-side
	// secret, otherwise the resulting token is only as unpredictable as the seed.
	token, raw, err := newVerificationToken(request, idempotencyKey, true)
	if err != nil {
		return "", err
	}

	_, err = db.tf.CreateTransaction().Execute(ctx, func(tx context.Context) (any, error) {
		now := time.Now().UTC()

		// Idempotent replay: the same Create operation already succeeded.
		replayed, err := db.idempotentTokenExists(tx, request, token.VerifyTokenHash, idempotencyKey)
		if err != nil {
			return nil, err
		}
		if replayed {
			return nil, nil
		}

		// 1) Cooldown: if there exists ANY active token created after cutoff → block
		if request.Cooldown != nil {
			cutoff := now.Add(-*request.Cooldown)

			var recent model.VerificationToken
			err := db.DBImp.FindOne(tx, cooldownActiveContextFilter(request, now, cutoff), &recent)
			switch {
			case err == nil:
				return nil, verification.ErrorCooldownActive()
			case errors.Is(err, merrors.ErrNoData):
				// No recent active token; proceed.
			default:
				return nil, err
			}
		}

		// 2) Invalidate active tokens for this context
		if _, err := db.DBImp.PatchMany(
			tx,
			activeContextFilter(request, now),
			repository.Patch().Set(repository.Field("usedAt"), now),
		); err != nil {
			return nil, err
		}

		// 3) Create new token only after cooldown/idempotency checks pass.
		if err := db.DBImp.Create(tx, token); err != nil {
			if errors.Is(err, merrors.ErrDataConflict) {
				// A conflicting token that matches this request is treated as an
				// idempotent replay rather than a failure.
				// NOTE(review): confirm the transaction implementation still
				// allows reads after a failed write in the same transaction.
				exists, findErr := db.idempotentTokenExists(tx, request, token.VerifyTokenHash, idempotencyKey)
				if findErr != nil {
					return nil, findErr
				}
				if exists {
					return nil, nil
				}
			}
			return nil, err
		}

		return nil, nil
	})
	if err != nil {
		return "", err
	}

	return raw, nil
}