116 lines
3.9 KiB
Go
116 lines
3.9 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
|
|
"github.com/tech/sendico/ledger/storage"
|
|
"github.com/tech/sendico/ledger/storage/model"
|
|
"github.com/tech/sendico/pkg/db/repository"
|
|
ri "github.com/tech/sendico/pkg/db/repository/index"
|
|
"github.com/tech/sendico/pkg/merrors"
|
|
"github.com/tech/sendico/pkg/mlogger"
|
|
"go.mongodb.org/mongo-driver/bson/primitive"
|
|
"go.mongodb.org/mongo-driver/mongo"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
type balancesStore struct {
|
|
logger mlogger.Logger
|
|
repo repository.Repository
|
|
}
|
|
|
|
func NewBalances(logger mlogger.Logger, db *mongo.Database) (storage.BalancesStore, error) {
|
|
repo := repository.CreateMongoRepository(db, model.AccountBalancesCollection)
|
|
|
|
// Create unique index on accountRef (one balance per account)
|
|
uniqueIndex := &ri.Definition{
|
|
Keys: []ri.Key{
|
|
{Field: "accountRef", Sort: ri.Asc},
|
|
},
|
|
Unique: true,
|
|
}
|
|
if err := repo.CreateIndex(uniqueIndex); err != nil {
|
|
logger.Error("failed to ensure balances unique index", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
|
|
childLogger := logger.Named(model.AccountBalancesCollection)
|
|
childLogger.Debug("balances store initialised", zap.String("collection", model.AccountBalancesCollection))
|
|
|
|
return &balancesStore{
|
|
logger: childLogger,
|
|
repo: repo,
|
|
}, nil
|
|
}
|
|
|
|
func (b *balancesStore) Get(ctx context.Context, accountRef primitive.ObjectID) (*model.AccountBalance, error) {
|
|
if accountRef.IsZero() {
|
|
b.logger.Warn("attempt to get balance with zero account ID")
|
|
return nil, merrors.InvalidArgument("balancesStore: zero account ID")
|
|
}
|
|
|
|
query := repository.Filter("accountRef", accountRef)
|
|
|
|
result := &model.AccountBalance{}
|
|
if err := b.repo.FindOneByFilter(ctx, query, result); err != nil {
|
|
if errors.Is(err, merrors.ErrNoData) {
|
|
b.logger.Debug("balance not found", zap.String("accountRef", accountRef.Hex()))
|
|
return nil, storage.ErrBalanceNotFound
|
|
}
|
|
b.logger.Warn("failed to get balance", zap.Error(err), zap.String("accountRef", accountRef.Hex()))
|
|
return nil, err
|
|
}
|
|
|
|
b.logger.Debug("balance loaded", zap.String("accountRef", accountRef.Hex()),
|
|
zap.String("balance", result.Balance))
|
|
return result, nil
|
|
}
|
|
|
|
func (b *balancesStore) Upsert(ctx context.Context, balance *model.AccountBalance) error {
|
|
if balance == nil {
|
|
b.logger.Warn("attempt to upsert nil balance")
|
|
return merrors.InvalidArgument("balancesStore: nil balance")
|
|
}
|
|
if balance.AccountRef.IsZero() {
|
|
b.logger.Warn("attempt to upsert balance with zero account ID")
|
|
return merrors.InvalidArgument("balancesStore: zero account ID")
|
|
}
|
|
|
|
existing := &model.AccountBalance{}
|
|
filter := repository.Filter("accountRef", balance.AccountRef)
|
|
|
|
if err := b.repo.FindOneByFilter(ctx, filter, existing); err != nil {
|
|
if errors.Is(err, merrors.ErrNoData) {
|
|
b.logger.Debug("inserting new balance", zap.String("accountRef", balance.AccountRef.Hex()))
|
|
return b.repo.Insert(ctx, balance, filter)
|
|
}
|
|
b.logger.Warn("failed to fetch balance", zap.Error(err), zap.String("accountRef", balance.AccountRef.Hex()))
|
|
return err
|
|
}
|
|
|
|
if existing.GetID() != nil {
|
|
balance.SetID(*existing.GetID())
|
|
}
|
|
b.logger.Debug("updating balance", zap.String("accountRef", balance.AccountRef.Hex()),
|
|
zap.String("balance", balance.Balance))
|
|
return b.repo.Update(ctx, balance)
|
|
}
|
|
|
|
func (b *balancesStore) IncrementBalance(ctx context.Context, accountRef primitive.ObjectID, amount string) error {
|
|
if accountRef.IsZero() {
|
|
b.logger.Warn("attempt to increment balance with zero account ID")
|
|
return merrors.InvalidArgument("balancesStore: zero account ID")
|
|
}
|
|
|
|
// Note: This implementation uses $inc on a string field, which won't work.
|
|
// In a real implementation, you'd need to:
|
|
// 1. Fetch the balance
|
|
// 2. Parse amount strings to decimal
|
|
// 3. Add them
|
|
// 4. Update with optimistic locking via version field
|
|
// For now, return not implemented to indicate this needs proper decimal handling
|
|
b.logger.Warn("IncrementBalance not fully implemented - requires decimal arithmetic")
|
|
return merrors.NotImplemented("IncrementBalance requires proper decimal handling")
|
|
}
|