188 lines
4.5 KiB
Go
188 lines
4.5 KiB
Go
package confirmationimp
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"crypto/subtle"
|
|
"errors"
|
|
"time"
|
|
|
|
"github.com/tech/sendico/pkg/db/confirmation"
|
|
"github.com/tech/sendico/pkg/merrors"
|
|
"github.com/tech/sendico/pkg/model"
|
|
"go.mongodb.org/mongo-driver/bson/primitive"
|
|
)
|
|
|
|
// Sentinel errors returned by ConfirmationStore. Verify returns
// errConfirmationNotFound, errConfirmationUsed, errConfirmationMismatch and
// errConfirmationAttemptsExceeded; Resend returns errConfirmationCooldown and
// errConfirmationResendLimit. Because confirmationError is string-backed they
// are declared as constants, so nothing can reassign them at runtime and they
// compare by value with == and errors.Is.
const (
	errConfirmationNotFound         confirmationError = "confirmation not found or expired"
	errConfirmationUsed             confirmationError = "confirmation already used"
	errConfirmationMismatch         confirmationError = "confirmation code mismatch"
	errConfirmationAttemptsExceeded confirmationError = "confirmation attempts exceeded"
	errConfirmationCooldown         confirmationError = "confirmation cooldown active"
	errConfirmationResendLimit      confirmationError = "confirmation resend limit reached"
)

// confirmationError is a string-backed error type used for this package's
// sentinel errors; its message is the string value itself.
type confirmationError string

// Error implements the error interface by returning the message verbatim.
func (e confirmationError) Error() string {
	return string(e)
}
|
|
|
|
type ConfirmationStore struct {
|
|
db confirmation.DB
|
|
}
|
|
|
|
func NewStore(db confirmation.DB) *ConfirmationStore {
|
|
return &ConfirmationStore{db: db}
|
|
}
|
|
|
|
// Create issues a new confirmation for the (accountRef, destination, target)
// tuple and returns the plaintext code together with the stored record.
//
// The record (code, salt, salted hash, expiry, cooldown, limits from cfg) is
// built first. Only after that succeeds is DeleteTuple called to remove
// whatever is stored for the tuple, and the new record is then persisted.
// merrors.ErrNoData from DeleteTuple (presumably "nothing to delete") is not
// treated as an error. Any other error from generator, salt generation,
// DeleteTuple or Create is returned unchanged with a nil record.
//
// NOTE(review): DeleteTuple and Create are two separate calls, so the
// replacement is not atomic at the storage level.
func (s *ConfirmationStore) Create(
	ctx context.Context,
	accountRef primitive.ObjectID,
	destination string,
	target model.ConfirmationTarget,
	cfg Config,
	generator func() (string, error),
) (string, *model.ConfirmationCode, error) {
	// Generate first: if code or salt generation fails, any confirmation that
	// is already pending for this tuple is left intact instead of being
	// deleted with nothing to replace it.
	code, _, rec, err := s.buildRecord(accountRef, destination, target, cfg, generator)
	if err != nil {
		return "", nil, err
	}

	if err := s.db.DeleteTuple(ctx, accountRef, destination, target); err != nil && !errors.Is(err, merrors.ErrNoData) {
		return "", nil, err
	}

	if err := s.db.Create(ctx, rec); err != nil {
		return "", nil, err
	}

	return code, rec, nil
}
|
|
|
|
// Resend replaces the code of the active confirmation for the
// (accountRef, destination, target) tuple and returns the new plaintext code
// together with the updated record.
//
// If FindActive reports merrors.ErrNoData, Resend falls back to Create.
// Otherwise it returns errConfirmationResendLimit once ResendCount has
// reached ResendLimit, and errConfirmationCooldown while CooldownUntil is
// still in the future. On success the record keeps its ID, CreatedAt and
// Attempts and its ResendCount is incremented. The code hash, salt, expiry,
// cooldown and limits are regenerated from cfg.
func (s *ConfirmationStore) Resend(
	ctx context.Context,
	accountRef primitive.ObjectID,
	destination string,
	target model.ConfirmationTarget,
	cfg Config,
	generator func() (string, error),
) (string, *model.ConfirmationCode, error) {
	now := time.Now().UTC()
	current, err := s.db.FindActive(ctx, accountRef, destination, target, now.Unix())

	switch {
	case errors.Is(err, merrors.ErrNoData):
		// Nothing pending for this tuple: issue a brand-new confirmation.
		return s.Create(ctx, accountRef, destination, target, cfg, generator)
	case err != nil:
		return "", nil, err
	case current.ResendCount >= current.ResendLimit:
		return "", nil, errConfirmationResendLimit
	case now.Before(current.CooldownUntil):
		return "", nil, errConfirmationCooldown
	}

	plain, salt, replacement, err := s.buildRecord(accountRef, destination, target, cfg, generator)
	if err != nil {
		return "", nil, err
	}

	// Carry over the identity and the attempt counter of the record being
	// replaced; only the resend counter moves forward.
	replacement.ID = current.ID
	replacement.CreatedAt = current.CreatedAt
	replacement.Attempts = current.Attempts
	replacement.ResendCount = current.ResendCount + 1
	replacement.Salt = salt

	if err := s.db.Update(ctx, replacement); err != nil {
		return "", nil, err
	}
	return plain, replacement, nil
}
|
|
|
|
// Verify checks code against the active confirmation for the
// (accountRef, destination, target) tuple.
//
// It returns errConfirmationNotFound when FindActive reports
// merrors.ErrNoData, and errConfirmationUsed when the record is already
// marked used. Otherwise the call counts as one attempt:
//   - if the incremented attempt count exceeds MaxAttempts, the record is
//     marked used and errConfirmationAttemptsExceeded is returned;
//   - if the code matches the stored salted hash, the record is marked used
//     and nil is returned;
//   - otherwise errConfirmationMismatch is returned.
//
// The updated record is persisted before a match or mismatch is reported.
// If that write fails, the storage error is returned instead, so a wrong
// guess can never go uncounted silently.
//
// NOTE(review): the read-increment-write sequence is not atomic. Concurrent
// calls for the same tuple can observe the same attempt count, and two
// correct submissions can both succeed. Closing that gap needs an atomic
// update in confirmation.DB.
func (s *ConfirmationStore) Verify(
	ctx context.Context,
	accountRef primitive.ObjectID,
	destination string,
	target model.ConfirmationTarget,
	code string,
) error {
	now := time.Now().UTC()
	rec, err := s.db.FindActive(ctx, accountRef, destination, target, now.Unix())
	if errors.Is(err, merrors.ErrNoData) {
		return errConfirmationNotFound
	}
	if err != nil {
		return err
	}
	if rec.Used {
		return errConfirmationUsed
	}

	rec.Attempts++
	if rec.Attempts > rec.MaxAttempts {
		rec.Used = true
		// Best effort: the stored counter is already >= MaxAttempts. Even if
		// this write fails, the next call takes this branch again (assuming
		// the stored record is otherwise unchanged).
		_ = s.db.Update(ctx, rec)
		return errConfirmationAttemptsExceeded
	}

	// Compare hashes in constant time to avoid leaking how much of the code matched.
	matched := subtle.ConstantTimeCompare(rec.CodeHash, hashCode(rec.Salt, code)) == 1
	if matched {
		rec.Used = true
	}

	// Persist the consumed attempt (and, on success, the used flag) before
	// reporting the outcome. Swallowing this error on a mismatch would let the
	// attempt go uncounted, so it is surfaced instead.
	if err := s.db.Update(ctx, rec); err != nil {
		return err
	}
	if !matched {
		return errConfirmationMismatch
	}
	return nil
}
|
|
|
|
// buildRecord produces a new, unsaved confirmation record.
//
// It calls generator for the plaintext code and then draws a fresh random
// salt via newSalt. It returns the plaintext code, the salt, and a
// ConfirmationCode that stores only hashCode(salt, code), never the plaintext.
// ExpiresAt and CooldownUntil are the current UTC time plus cfg.TTL and
// cfg.Cooldown; MaxAttempts and ResendLimit are copied from cfg. Any error
// from generator or newSalt is returned with empty results.
func (s *ConfirmationStore) buildRecord(
	accountRef primitive.ObjectID,
	destination string,
	target model.ConfirmationTarget,
	cfg Config,
	generator func() (string, error),
) (string, []byte, *model.ConfirmationCode, error) {
	plain, err := generator()
	if err != nil {
		return "", nil, nil, err
	}

	salt, err := newSalt()
	if err != nil {
		return "", nil, nil, err
	}

	issuedAt := time.Now().UTC()
	record := &model.ConfirmationCode{
		AccountRef:    accountRef,
		Destination:   destination,
		Target:        target,
		Salt:          salt,
		CodeHash:      hashCode(salt, plain),
		MaxAttempts:   cfg.MaxAttempts,
		ResendLimit:   cfg.ResendLimit,
		ExpiresAt:     issuedAt.Add(cfg.TTL),
		CooldownUntil: issuedAt.Add(cfg.Cooldown),
		// Used, Attempts and ResendCount keep their zero values (false, 0, 0):
		// a freshly built record has not been tried or resent.
	}

	return plain, salt, record, nil
}
|
|
|
|
// hashCode returns the 32-byte SHA-256 digest of the salt bytes immediately
// followed by the raw bytes of code, digested as a single message. The salt
// slice is only read, never modified.
func hashCode(salt []byte, code string) []byte {
	msg := make([]byte, 0, len(salt)+len(code))
	msg = append(msg, salt...)
	msg = append(msg, code...)
	digest := sha256.Sum256(msg)
	return digest[:]
}
|
|
|
|
func newSalt() ([]byte, error) {
|
|
buf := make([]byte, 16)
|
|
if _, err := rand.Read(buf); err != nil {
|
|
return nil, err
|
|
}
|
|
return buf, nil
|
|
}
|