Files
sendico/api/server/internal/server/confirmationimp/store.go
2026-01-06 17:51:35 +01:00

188 lines
4.5 KiB
Go

package confirmationimp
import (
"context"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"errors"
"time"
"github.com/tech/sendico/pkg/db/confirmation"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// confirmationError is a string-backed error type; the string value itself
// is the error message.
type confirmationError string

// Error implements the error interface by returning the underlying string.
func (e confirmationError) Error() string { return string(e) }

// Sentinel errors produced by ConfirmationStore.
var (
	// errConfirmationNotFound is returned by Verify when the lookup for an
	// active record reports merrors.ErrNoData.
	errConfirmationNotFound = confirmationError("confirmation not found or expired")

	// errConfirmationUsed is returned by Verify when the record is already
	// flagged as used.
	errConfirmationUsed = confirmationError("confirmation already used")

	// errConfirmationMismatch is returned by Verify when the submitted code
	// does not hash to the stored value.
	errConfirmationMismatch = confirmationError("confirmation code mismatch")

	// errConfirmationAttemptsExceeded is returned by Verify once the attempt
	// counter goes past the record's MaxAttempts.
	errConfirmationAttemptsExceeded = confirmationError("confirmation attempts exceeded")

	// errConfirmationCooldown is returned by Resend while the record's
	// CooldownUntil is still in the future.
	errConfirmationCooldown = confirmationError("confirmation cooldown active")

	// errConfirmationResendLimit is returned by Resend when ResendCount has
	// reached ResendLimit.
	errConfirmationResendLimit = confirmationError("confirmation resend limit reached")
)
// ConfirmationStore implements the confirmation-code workflow (Create,
// Resend, Verify) on top of a confirmation.DB.
type ConfirmationStore struct {
	// db is the persistence backend every operation goes through.
	db confirmation.DB
}

// NewStore returns a ConfirmationStore that reads and writes through db.
func NewStore(db confirmation.DB) *ConfirmationStore {
	store := new(ConfirmationStore)
	store.db = db
	return store
}
// Create issues a new confirmation code for the (accountRef, destination,
// target) tuple.
//
// The record is built before anything is deleted: the generator and the salt
// source are the only steps that can fail without touching the database, so
// running them first means a generation failure leaves whatever confirmation
// is currently stored for the tuple untouched. DeleteTuple is then called for
// the tuple (merrors.ErrNoData from it is not treated as a failure) and the
// new record is written with db.Create.
//
// NOTE(review): delete and create are two separate calls; if db.Create fails
// after DeleteTuple succeeded, no record is left for the tuple.
//
// It returns the plaintext code produced by generator, the record that was
// stored, and the first error encountered.
func (s *ConfirmationStore) Create(
	ctx context.Context,
	accountRef primitive.ObjectID,
	destination string,
	target model.ConfirmationTarget,
	cfg Config,
	generator func() (string, error),
) (string, *model.ConfirmationCode, error) {
	// Fallible, side-effect-free work first.
	code, _, rec, err := s.buildRecord(accountRef, destination, target, cfg, generator)
	if err != nil {
		return "", nil, err
	}

	// Only now drop the previous record, if any.
	if err := s.db.DeleteTuple(ctx, accountRef, destination, target); err != nil && !errors.Is(err, merrors.ErrNoData) {
		return "", nil, err
	}

	if err := s.db.Create(ctx, rec); err != nil {
		return "", nil, err
	}
	return code, rec, nil
}
// Resend re-issues the confirmation code for the (accountRef, destination,
// target) tuple.
//
// Behaviour, in order:
//   - If FindActive reports merrors.ErrNoData, the call is delegated to Create.
//   - Any other lookup error is returned as-is.
//   - errConfirmationResendLimit is returned when ResendCount has reached
//     ResendLimit on the active record.
//   - errConfirmationCooldown is returned while the active record's
//     CooldownUntil is still in the future.
//   - errConfirmationAttemptsExceeded is returned when the attempts already
//     recorded leave no room for a further verification (see below).
//   - Otherwise a new record is built, the active record's ID, CreatedAt and
//     Attempts are carried over, ResendCount is incremented, and the result is
//     written with db.Update.
//
// It returns the new plaintext code, the updated record, and any error.
func (s *ConfirmationStore) Resend(
	ctx context.Context,
	accountRef primitive.ObjectID,
	destination string,
	target model.ConfirmationTarget,
	cfg Config,
	generator func() (string, error),
) (string, *model.ConfirmationCode, error) {
	now := time.Now().UTC()
	active, err := s.db.FindActive(ctx, accountRef, destination, target, now.Unix())
	if errors.Is(err, merrors.ErrNoData) {
		return s.Create(ctx, accountRef, destination, target, cfg, generator)
	}
	if err != nil {
		return "", nil, err
	}
	if active.ResendCount >= active.ResendLimit {
		return "", nil, errConfirmationResendLimit
	}
	if now.Before(active.CooldownUntil) {
		return "", nil, errConfirmationCooldown
	}
	// Attempts are carried over to the re-issued record, and Verify increments
	// Attempts before comparing, rejecting once Attempts > MaxAttempts. If the
	// carried-over count already equals the limit the new record will get
	// (cfg.MaxAttempts), the re-issued code could never be verified, so refuse
	// instead of handing out a dead code and consuming a resend slot.
	if active.Attempts >= cfg.MaxAttempts {
		return "", nil, errConfirmationAttemptsExceeded
	}

	// buildRecord already stores the fresh salt and hash on the record.
	code, _, updated, err := s.buildRecord(accountRef, destination, target, cfg, generator)
	if err != nil {
		return "", nil, err
	}

	// Preserve identity and attempt counters but bump resend count.
	updated.ID = active.ID
	updated.CreatedAt = active.CreatedAt
	updated.Attempts = active.Attempts
	updated.ResendCount = active.ResendCount + 1

	if err := s.db.Update(ctx, updated); err != nil {
		return "", nil, err
	}
	return code, updated, nil
}
// Verify checks code against the active confirmation record for the
// (accountRef, destination, target) tuple.
//
// Outcomes:
//   - errConfirmationNotFound when FindActive reports merrors.ErrNoData.
//   - any other lookup error, unchanged.
//   - errConfirmationUsed when the record is already flagged as used.
//   - errConfirmationAttemptsExceeded when this attempt pushes Attempts past
//     MaxAttempts; the record is also flagged as used.
//   - errConfirmationMismatch when the salted hash of code differs from the
//     stored CodeHash; the incremented attempt counter is written back.
//   - on a match, the record is flagged as used and the result of db.Update
//     is returned (nil on success).
func (s *ConfirmationStore) Verify(
	ctx context.Context,
	accountRef primitive.ObjectID,
	destination string,
	target model.ConfirmationTarget,
	code string,
) error {
	checkedAt := time.Now().UTC()
	rec, err := s.db.FindActive(ctx, accountRef, destination, target, checkedAt.Unix())
	switch {
	case errors.Is(err, merrors.ErrNoData):
		return errConfirmationNotFound
	case err != nil:
		return err
	case rec.Used:
		return errConfirmationUsed
	}

	// Every call counts as an attempt, whether or not the code turns out right.
	rec.Attempts++
	if rec.Attempts > rec.MaxAttempts {
		rec.Used = true
		// NOTE(review): the update error is discarded; if the write fails the
		// record is not persisted as used.
		_ = s.db.Update(ctx, rec)
		return errConfirmationAttemptsExceeded
	}

	// Constant-time comparison of the stored hash with the candidate's hash.
	candidate := hashCode(rec.Salt, code)
	if subtle.ConstantTimeCompare(rec.CodeHash, candidate) == 1 {
		rec.Used = true
		return s.db.Update(ctx, rec)
	}

	// NOTE(review): the update error is discarded; if the write fails this
	// attempt is not counted.
	_ = s.db.Update(ctx, rec)
	return errConfirmationMismatch
}
// buildRecord assembles an unsaved confirmation record for the given tuple.
//
// It calls generator for the plaintext code, draws a fresh salt, and stores
// the salted hash of the code on the record. ExpiresAt and CooldownUntil are
// the current UTC time plus cfg.TTL and cfg.Cooldown respectively;
// MaxAttempts and ResendLimit are copied from cfg. Used, Attempts and
// ResendCount are left at their zero values.
//
// It returns the plaintext code, the salt, and the record. On a generator or
// salt error the other results are empty/nil. Nothing is written to the
// database here.
func (s *ConfirmationStore) buildRecord(
	accountRef primitive.ObjectID,
	destination string,
	target model.ConfirmationTarget,
	cfg Config,
	generator func() (string, error),
) (string, []byte, *model.ConfirmationCode, error) {
	plain, genErr := generator()
	if genErr != nil {
		return "", nil, nil, genErr
	}

	saltBytes, saltErr := newSalt()
	if saltErr != nil {
		return "", nil, nil, saltErr
	}

	issuedAt := time.Now().UTC()

	rec := new(model.ConfirmationCode)
	rec.AccountRef = accountRef
	rec.Destination = destination
	rec.Target = target
	rec.CodeHash = hashCode(saltBytes, plain)
	rec.Salt = saltBytes
	rec.ExpiresAt = issuedAt.Add(cfg.TTL)
	rec.MaxAttempts = cfg.MaxAttempts
	rec.ResendLimit = cfg.ResendLimit
	rec.CooldownUntil = issuedAt.Add(cfg.Cooldown)

	return plain, saltBytes, rec, nil
}
// hashCode returns the SHA-256 digest of salt followed by the bytes of code.
// The result is always 32 bytes long; salt is not modified.
func hashCode(salt []byte, code string) []byte {
	// Concatenate into a fresh buffer so the caller's salt slice is never
	// appended to in place.
	payload := make([]byte, 0, len(salt)+len(code))
	payload = append(payload, salt...)
	payload = append(payload, code...)

	digest := sha256.Sum256(payload)
	return digest[:]
}
// newSalt returns 16 bytes read from crypto/rand. If the random source
// fails, it returns a nil slice together with that error.
func newSalt() ([]byte, error) {
	const saltLen = 16

	salt := make([]byte, saltLen)
	_, err := rand.Read(salt)
	if err != nil {
		return nil, err
	}
	return salt, nil
}