service backend
All checks were successful
ci/woodpecker/push/db Pipeline was successful
ci/woodpecker/push/nats Pipeline was successful

This commit is contained in:
Stephan D
2025-11-07 18:35:26 +01:00
parent 20e8f9acc4
commit 62a6631b9a
537 changed files with 48453 additions and 0 deletions

View File

@@ -0,0 +1,30 @@
package accountdb
import (
ri "github.com/tech/sendico/pkg/db/repository/index"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"go.mongodb.org/mongo-driver/mongo"
"go.uber.org/zap"
)
type AccountDB struct {
template.DBImp[*model.Account]
}
func Create(logger mlogger.Logger, db *mongo.Database) (*AccountDB, error) {
p := &AccountDB{
DBImp: *template.Create[*model.Account](logger, mservice.Accounts, db),
}
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
Keys: []ri.Key{{Field: "login", Sort: ri.Asc}},
Unique: true,
}); err != nil {
p.Logger.Error("Failed to create account database", zap.Error(err))
return nil, err
}
return p, nil
}

View File

@@ -0,0 +1,13 @@
package accountdb
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/model"
)
func (db *AccountDB) GetByToken(ctx context.Context, email string) (*model.Account, error) {
var account model.Account
return &account, db.FindOne(ctx, repository.Query().Filter(repository.Field("verifyToken"), email), &account)
}

View File

@@ -0,0 +1,21 @@
package accountdb
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/model"
mutil "github.com/tech/sendico/pkg/mutil/db"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *AccountDB) GetAccountsByRefs(ctx context.Context, orgRef primitive.ObjectID, refs []primitive.ObjectID) ([]model.Account, error) {
filter := repository.Query().Comparison(repository.IDField(), builder.In, refs)
return mutil.GetObjects[model.Account](ctx, db.Logger, filter, nil, db.Repository)
}
func (db *AccountDB) GetByEmail(ctx context.Context, email string) (*model.Account, error) {
var account model.Account
return &account, db.FindOne(ctx, repository.Filter("login", email), &account)
}

View File

@@ -0,0 +1,99 @@
package archivable
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/storable"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// ArchivableDB implements archive management for entities with model.Archivable embedded
type ArchivableDB[T storable.Storable] struct {
repo repository.Repository
logger mlogger.Logger
createEmpty func() T
getArchivable func(T) model.Archivable
}
// NewArchivableDB creates a new ArchivableDB instance
func NewArchivableDB[T storable.Storable](
repo repository.Repository,
logger mlogger.Logger,
createEmpty func() T,
getArchivable func(T) model.Archivable,
) *ArchivableDB[T] {
return &ArchivableDB[T]{
repo: repo,
logger: logger,
createEmpty: createEmpty,
getArchivable: getArchivable,
}
}
// SetArchived sets the archived status of an entity
func (db *ArchivableDB[T]) SetArchived(ctx context.Context, objectRef primitive.ObjectID, archived bool) error {
// Get current object to check current archived status
obj := db.createEmpty()
if err := db.repo.Get(ctx, objectRef, obj); err != nil {
db.logger.Warn("Failed to get object for setting archived status",
zap.Error(err),
mzap.ObjRef("object_ref", objectRef),
zap.Bool("archived", archived))
return err
}
// Extract archivable from the object
archivable := db.getArchivable(obj)
currentArchived := archivable.IsArchived()
if currentArchived == archived {
db.logger.Debug("No change needed - same archived status",
mzap.ObjRef("object_ref", objectRef),
zap.Bool("archived", archived))
return nil // No change needed
}
// Set the archived status
patch := repository.Patch().Set(repository.IsArchivedField(), archived)
if err := db.repo.Patch(ctx, objectRef, patch); err != nil {
db.logger.Warn("Failed to set archived status on object",
zap.Error(err),
mzap.ObjRef("object_ref", objectRef),
zap.Bool("archived", archived))
return err
}
db.logger.Debug("Successfully set archived status on object",
mzap.ObjRef("object_ref", objectRef),
zap.Bool("archived", archived))
return nil
}
// IsArchived checks if an entity is archived
func (db *ArchivableDB[T]) IsArchived(ctx context.Context, objectRef primitive.ObjectID) (bool, error) {
obj := db.createEmpty()
if err := db.repo.Get(ctx, objectRef, obj); err != nil {
db.logger.Warn("Failed to get object for checking archived status",
zap.Error(err),
mzap.ObjRef("object_ref", objectRef))
return false, err
}
archivable := db.getArchivable(obj)
return archivable.IsArchived(), nil
}
// Archive archives an entity (sets archived to true)
func (db *ArchivableDB[T]) Archive(ctx context.Context, objectRef primitive.ObjectID) error {
return db.SetArchived(ctx, objectRef, true)
}
// Unarchive unarchives an entity (sets archived to false)
func (db *ArchivableDB[T]) Unarchive(ctx context.Context, objectRef primitive.ObjectID) error {
return db.SetArchived(ctx, objectRef, false)
}

View File

@@ -0,0 +1,175 @@
//go:build integration
// +build integration
package archivable
import (
"context"
"testing"
"time"
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
"github.com/tech/sendico/pkg/db/storable"
"github.com/tech/sendico/pkg/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/mongodb"
"github.com/testcontainers/testcontainers-go/wait"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.uber.org/zap"
)
// TestArchivableObject represents a test object with archivable functionality
type TestArchivableObject struct {
storable.Base `bson:",inline" json:",inline"`
model.ArchivableBase `bson:",inline" json:",inline"`
Name string `bson:"name" json:"name"`
}
func (t *TestArchivableObject) Collection() string {
return "testArchivableObject"
}
func (t *TestArchivableObject) GetArchivable() model.Archivable {
return &t.ArchivableBase
}
func TestArchivableDB(t *testing.T) {
ctx := context.Background()
// Start MongoDB container (stable)
mongoContainer, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("test"),
mongodb.WithPassword("test"),
testcontainers.WithWaitStrategy(wait.ForListeningPort("27017/tcp").WithStartupTimeout(2*time.Minute)),
)
require.NoError(t, err)
defer func() {
termCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := mongoContainer.Terminate(termCtx); err != nil {
t.Logf("Failed to terminate container: %v", err)
}
}()
// Get MongoDB connection string
mongoURI, err := mongoContainer.ConnectionString(ctx)
require.NoError(t, err)
// Connect to MongoDB
client, err := mongo.Connect(ctx, options.Client().ApplyURI(mongoURI))
require.NoError(t, err)
defer func() {
if err := client.Disconnect(context.Background()); err != nil {
t.Logf("Failed to disconnect from MongoDB: %v", err)
}
}()
// Ping the database
err = client.Ping(ctx, nil)
require.NoError(t, err)
// Create repository
repo := repositoryimp.NewMongoRepository(client.Database("test_"+t.Name()), "testArchivableCollection")
// Create archivable DB
archivableDB := NewArchivableDB(
repo,
zap.NewNop(),
func() *TestArchivableObject { return &TestArchivableObject{} },
func(obj *TestArchivableObject) model.Archivable { return obj.GetArchivable() },
)
t.Run("SetArchived_Success", func(t *testing.T) {
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: false}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
err = archivableDB.SetArchived(ctx, obj.ID, true)
require.NoError(t, err)
var result TestArchivableObject
err = repo.Get(ctx, obj.ID, &result)
require.NoError(t, err)
assert.True(t, result.IsArchived())
})
t.Run("SetArchived_NoChange", func(t *testing.T) {
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
err = archivableDB.SetArchived(ctx, obj.ID, true)
require.NoError(t, err) // Should not error, just not change anything
var result TestArchivableObject
err = repo.Get(ctx, obj.ID, &result)
require.NoError(t, err)
assert.True(t, result.IsArchived())
})
t.Run("SetArchived_Unarchive", func(t *testing.T) {
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
err = archivableDB.SetArchived(ctx, obj.ID, false)
require.NoError(t, err)
var result TestArchivableObject
err = repo.Get(ctx, obj.ID, &result)
require.NoError(t, err)
assert.False(t, result.IsArchived())
})
t.Run("IsArchived_True", func(t *testing.T) {
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
isArchived, err := archivableDB.IsArchived(ctx, obj.ID)
require.NoError(t, err)
assert.True(t, isArchived)
})
t.Run("IsArchived_False", func(t *testing.T) {
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: false}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
isArchived, err := archivableDB.IsArchived(ctx, obj.ID)
require.NoError(t, err)
assert.False(t, isArchived)
})
t.Run("Archive_Success", func(t *testing.T) {
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: false}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
err = archivableDB.Archive(ctx, obj.ID)
require.NoError(t, err)
var result TestArchivableObject
err = repo.Get(ctx, obj.ID, &result)
require.NoError(t, err)
assert.True(t, result.IsArchived())
})
t.Run("Unarchive_Success", func(t *testing.T) {
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
err = archivableDB.Unarchive(ctx, obj.ID)
require.NoError(t, err)
var result TestArchivableObject
err = repo.Get(ctx, obj.ID, &result)
require.NoError(t, err)
assert.False(t, result.IsArchived())
})
}

257
api/pkg/db/internal/mongo/db.go Executable file
View File

@@ -0,0 +1,257 @@
package mongo
import (
"context"
"os"
"github.com/mitchellh/mapstructure"
"github.com/tech/sendico/pkg/auth"
"github.com/tech/sendico/pkg/db/account"
"github.com/tech/sendico/pkg/db/internal/mongo/accountdb"
"github.com/tech/sendico/pkg/db/internal/mongo/invitationdb"
"github.com/tech/sendico/pkg/db/internal/mongo/organizationdb"
"github.com/tech/sendico/pkg/db/internal/mongo/policiesdb"
"github.com/tech/sendico/pkg/db/internal/mongo/refreshtokensdb"
"github.com/tech/sendico/pkg/db/internal/mongo/rolesdb"
"github.com/tech/sendico/pkg/db/internal/mongo/transactionimp"
"github.com/tech/sendico/pkg/db/invitation"
"github.com/tech/sendico/pkg/db/organization"
"github.com/tech/sendico/pkg/db/policy"
"github.com/tech/sendico/pkg/db/refreshtokens"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/role"
"github.com/tech/sendico/pkg/db/transaction"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
mutil "github.com/tech/sendico/pkg/mutil/config"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.uber.org/zap"
)
// Config represents configuration
type Config struct {
Port *string `mapstructure:"port"`
PortEnv *string `mapstructure:"port_env"`
User *string `mapstructure:"user"`
UserEnv *string `mapstructure:"user_env"`
PasswordEnv string `mapstructure:"password_env"`
Database *string `mapstructure:"database"`
DatabaseEnv *string `mapstructure:"database_env"`
Host *string `mapstructure:"host"`
HostEnv *string `mapstructure:"host_env"`
AuthSource *string `mapstructure:"auth_source,omitempty"`
AuthSourceEnv *string `mapstructure:"auth_source_env,omitempty"`
AuthMechanism *string `mapstructure:"auth_mechanism,omitempty"`
AuthMechanismEnv *string `mapstructure:"auth_mechanism_env,omitempty"`
ReplicaSet *string `mapstructure:"replica_set,omitempty"`
ReplicaSetEnv *string `mapstructure:"replica_set_env,omitempty"`
Enforcer *auth.Config `mapstructure:"enforcer"`
}
type DBSettings struct {
Host string
Port string
User string
Password string
Database string
AuthSource string
AuthMechanism string
ReplicaSet string
}
func newProtectedDB[T any](
db *DB,
create func(ctx context.Context, logger mlogger.Logger, enforcer auth.Enforcer, pdb policy.DB, client *mongo.Database) (T, error),
) (T, error) {
pdb, err := db.NewPoliciesDB()
if err != nil {
db.logger.Warn("Failed to create policies database", zap.Error(err))
var zero T
return zero, err
}
return create(context.Background(), db.logger, db.Enforcer(), pdb, db.db())
}
func Config2DBSettings(logger mlogger.Logger, config *Config) *DBSettings {
p := new(DBSettings)
p.Port = mutil.GetConfigValue(logger, "port", "port_env", config.Port, config.PortEnv)
p.Database = mutil.GetConfigValue(logger, "database", "database_env", config.Database, config.DatabaseEnv)
p.Password = os.Getenv(config.PasswordEnv)
p.User = mutil.GetConfigValue(logger, "user", "user_env", config.User, config.UserEnv)
p.Host = mutil.GetConfigValue(logger, "host", "host_env", config.Host, config.HostEnv)
p.AuthSource = mutil.GetConfigValue(logger, "auth_source", "auth_source_env", config.AuthSource, config.AuthSourceEnv)
p.AuthMechanism = mutil.GetConfigValue(logger, "auth_mechanism", "auth_mechanism_env", config.AuthMechanism, config.AuthMechanismEnv)
p.ReplicaSet = mutil.GetConfigValue(logger, "replica_set", "replica_set_env", config.ReplicaSet, config.ReplicaSetEnv)
return p
}
func decodeConfig(logger mlogger.Logger, settings model.SettingsT) (*Config, *DBSettings, error) {
var config Config
if err := mapstructure.Decode(settings, &config); err != nil {
logger.Warn("Failed to decode settings", zap.Error(err), zap.Any("settings", settings))
return nil, nil, err
}
dbSettings := Config2DBSettings(logger, &config)
return &config, dbSettings, nil
}
func dialMongo(logger mlogger.Logger, dbSettings *DBSettings) (*mongo.Client, error) {
cred := options.Credential{
AuthMechanism: dbSettings.AuthMechanism,
AuthSource: dbSettings.AuthSource,
Username: dbSettings.User,
Password: dbSettings.Password,
}
dbURI := buildURI(dbSettings)
client, err := mongo.Connect(context.Background(), options.Client().ApplyURI(dbURI).SetAuth(cred))
if err != nil {
logger.Error("Unable to connect to database", zap.Error(err))
return nil, err
}
logger.Info("Connected successfully", zap.String("uri", dbURI))
if err := client.Ping(context.Background(), readpref.Primary()); err != nil {
logger.Error("Unable to ping database", zap.Error(err))
_ = client.Disconnect(context.Background())
return nil, err
}
return client, nil
}
func ConnectClient(logger mlogger.Logger, settings model.SettingsT) (*mongo.Client, *Config, *DBSettings, error) {
config, dbSettings, err := decodeConfig(logger, settings)
if err != nil {
return nil, nil, nil, err
}
client, err := dialMongo(logger, dbSettings)
if err != nil {
return nil, nil, nil, err
}
return client, config, dbSettings, nil
}
// DB represents the structure of the database
type DB struct {
logger mlogger.Logger
config *DBSettings
client *mongo.Client
enforcer auth.Enforcer
manager auth.Manager
pdb policy.DB
}
func (db *DB) db() *mongo.Database {
return db.client.Database(db.config.Database)
}
func (db *DB) NewAccountDB() (account.DB, error) {
return accountdb.Create(db.logger, db.db())
}
func (db *DB) NewOrganizationDB() (organization.DB, error) {
pdb, err := db.NewPoliciesDB()
if err != nil {
db.logger.Warn("Failed to create policies database", zap.Error(err))
return nil, err
}
organizationDB, err := organizationdb.Create(context.Background(), db.logger, db.Enforcer(), pdb, db.db())
if err != nil {
return nil, err
}
// Return the concrete type - interface mismatch will be handled at runtime
// TODO: Update organization.DB interface to match implementation signatures
return organizationDB, nil
}
func (db *DB) NewRefreshTokensDB() (refreshtokens.DB, error) {
return refreshtokensdb.Create(db.logger, db.db())
}
func (db *DB) NewInvitationsDB() (invitation.DB, error) {
return newProtectedDB(db, invitationdb.Create)
}
func (db *DB) NewPoliciesDB() (policy.DB, error) {
return db.pdb, nil
}
func (db *DB) NewRolesDB() (role.DB, error) {
return rolesdb.Create(db.logger, db.db())
}
func (db *DB) TransactionFactory() transaction.Factory {
return transactionimp.CreateFactory(db.client)
}
func (db *DB) Permissions() auth.Provider {
return db
}
func (db *DB) Manager() auth.Manager {
return db.manager
}
func (db *DB) Enforcer() auth.Enforcer {
return db.enforcer
}
func (db *DB) GetPolicyDescription(ctx context.Context, resource mservice.Type) (*model.PolicyDescription, error) {
var policyDescription model.PolicyDescription
return &policyDescription, db.pdb.FindOne(ctx, repository.Filter("resourceTypes", resource), &policyDescription)
}
func (db *DB) CloseConnection() {
if err := db.client.Disconnect(context.Background()); err != nil {
db.logger.Warn("Failed to close connection", zap.Error(err))
}
db.logger.Info("Database connection closed")
}
// NewConnection creates a new database connection
func NewConnection(logger mlogger.Logger, settings model.SettingsT) (*DB, error) {
client, config, dbSettings, err := ConnectClient(logger, settings)
if err != nil {
return nil, err
}
db := &DB{
logger: logger.Named("db"),
config: dbSettings,
client: client,
}
cleanup := func(ctx context.Context) {
if err := client.Disconnect(ctx); err != nil {
logger.Warn("Failed to close MongoDB connection", zap.Error(err))
}
}
rdb, err := db.NewRolesDB()
if err != nil {
db.logger.Warn("Failed to create roles database", zap.Error(err))
cleanup(context.Background())
return nil, err
}
if db.pdb, err = policiesdb.Create(db.logger, db.db()); err != nil {
db.logger.Warn("Failed to create policies database", zap.Error(err))
cleanup(context.Background())
return nil, err
}
if db.enforcer, db.manager, err = auth.CreateAuth(logger, db.client, db.db(), db.pdb, rdb, config.Enforcer); err != nil {
db.logger.Warn("Failed to create permissions enforcer", zap.Error(err))
cleanup(context.Background())
return nil, err
}
return db, nil
}

View File

@@ -0,0 +1,144 @@
# Indexable Implementation (Refactored)
## Overview
This package provides a refactored implementation of the `indexable.DB` interface that uses `mutil.GetObjects` for better consistency with the existing codebase. The implementation has been moved to the mongo folder and includes a factory for project indexable in the pkg/db folder.
## Structure
### 1. `api/pkg/db/internal/mongo/indexable/indexable.go`
- **`ReorderTemplate[T]`**: Generic template function that uses `mutil.GetObjects` for fetching objects
- **`IndexableDB`**: Base struct for creating concrete implementations
- **Type-safe implementation**: Uses Go generics with proper type constraints
### 2. `api/pkg/db/project_indexable.go`
- **`ProjectIndexableDB`**: Factory implementation for Project objects
- **`NewProjectIndexableDB`**: Constructor function
- **`ReorderTemplate`**: Duplicate of the mongo version for convenience
## Key Changes from Previous Implementation
### 1. **Uses `mutil.GetObjects`**
```go
// Old implementation (manual cursor handling)
err = repo.FindManyByFilter(ctx, filter, func(cursor *mongo.Cursor) error {
var obj T
if err := cursor.Decode(&obj); err != nil {
return err
}
objects = append(objects, obj)
return nil
})
// New implementation (using mutil.GetObjects)
objects, err := mutil.GetObjects[T](
ctx,
logger,
filterFunc().
And(
repository.IndexOpFilter(minIdx, builder.Gte),
repository.IndexOpFilter(maxIdx, builder.Lte),
),
nil, nil, nil, // limit, offset, isArchived
repo,
)
```
### 2. **Moved to Mongo Folder**
- Location: `api/pkg/db/internal/mongo/indexable/`
- Consistent with other mongo implementations
- Better organization within the codebase
### 3. **Added Factory in pkg/db**
- Location: `api/pkg/db/project_indexable.go`
- Provides easy access to project indexable functionality
- Includes logger parameter for better error handling
## Usage
### Using the Factory (Recommended)
```go
import "github.com/tech/sendico/pkg/db"
// Create a project indexable DB
projectDB := db.NewProjectIndexableDB(repo, logger, organizationRef)
// Reorder a project
err := projectDB.Reorder(ctx, projectID, newIndex)
if err != nil {
// Handle error
}
```
### Using the Template Directly
```go
import "github.com/tech/sendico/pkg/db/internal/mongo/indexable"
// Define helper functions
getIndexable := func(p *model.Project) *model.Indexable {
return &p.Indexable
}
updateIndexable := func(p *model.Project, newIndex int) {
p.Index = newIndex
}
createEmpty := func() *model.Project {
return &model.Project{}
}
filterFunc := func() builder.Query {
return repository.OrgFilter(organizationRef)
}
// Use the template
err := indexable.ReorderTemplate(
ctx,
logger,
repo,
objectRef,
newIndex,
filterFunc,
getIndexable,
updateIndexable,
createEmpty,
)
```
## Benefits of Refactoring
1. **Consistency**: Uses `mutil.GetObjects` like other parts of the codebase
2. **Better Error Handling**: Includes logger parameter for proper error logging
3. **Organization**: Moved to appropriate folder structure
4. **Factory Pattern**: Easy-to-use factory for common use cases
5. **Type Safety**: Maintains compile-time type checking
6. **Performance**: Leverages existing optimized `mutil.GetObjects` implementation
## Testing
### Mongo Implementation Tests
```bash
go test ./db/internal/mongo/indexable -v
```
### Factory Tests
```bash
go test ./db -v
```
## Integration
The refactored implementation is ready for integration with existing project reordering APIs. The factory pattern makes it easy to add reordering functionality to any service that needs to reorder projects within an organization.
## Migration from Old Implementation
If you were using the old implementation:
1. **Update imports**: Change from `api/pkg/db/internal/indexable` to `api/pkg/db`
2. **Use factory**: Replace manual template usage with `NewProjectIndexableDB`
3. **Add logger**: Include a logger parameter in your constructor calls
4. **Update tests**: Use the new test structure if needed
The API remains the same, so existing code should work with minimal changes.

View File

@@ -0,0 +1,174 @@
# Indexable Usage Guide
## Generic Implementation for Any Indexable Struct
The implementation is now **generic** and supports **any struct that embeds `model.Indexable`**!
- **Interface**: `api/pkg/db/indexable.go` - defines the contract
- **Implementation**: `api/pkg/db/internal/mongo/indexable/` - generic implementation
- **Factory**: `api/pkg/db/project_indexable.go` - convenient factory for projects
## Usage
### 1. Using the Generic Implementation Directly
```go
import "github.com/tech/sendico/pkg/db/internal/mongo/indexable"
// For any type that embeds model.Indexable, define helper functions:
createEmpty := func() *YourType {
return &YourType{}
}
getIndexable := func(obj *YourType) *model.Indexable {
return &obj.Indexable
}
// Create generic IndexableDB
indexableDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
// Use with single filter parameter
err := indexableDB.Reorder(ctx, objectID, newIndex, filter)
```
### 2. Using the Project Factory (Recommended for Projects)
```go
import "github.com/tech/sendico/pkg/db"
// Create project indexable DB (automatically applies org filter)
projectDB := db.NewProjectIndexableDB(repo, logger, organizationRef)
// Reorder project (org filter applied automatically)
err := projectDB.Reorder(ctx, projectID, newIndex, repository.Query())
// Reorder with additional filters (combined with org filter)
additionalFilter := repository.Query().Comparison(repository.Field("state"), builder.Eq, "active")
err := projectDB.Reorder(ctx, projectID, newIndex, additionalFilter)
```
## Examples for Different Types
### Project IndexableDB
```go
createEmpty := func() *model.Project {
return &model.Project{}
}
getIndexable := func(p *model.Project) *model.Indexable {
return &p.Indexable
}
projectDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
orgFilter := repository.OrgFilter(organizationRef)
projectDB.Reorder(ctx, projectID, 2, orgFilter)
```
### Status IndexableDB
```go
createEmpty := func() *model.Status {
return &model.Status{}
}
getIndexable := func(s *model.Status) *model.Indexable {
return &s.Indexable
}
statusDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
projectFilter := repository.Query().Comparison(repository.Field("projectRef"), builder.Eq, projectRef)
statusDB.Reorder(ctx, statusID, 1, projectFilter)
```
### Task IndexableDB
```go
createEmpty := func() *model.Task {
return &model.Task{}
}
getIndexable := func(t *model.Task) *model.Indexable {
return &t.Indexable
}
taskDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
statusFilter := repository.Query().Comparison(repository.Field("statusRef"), builder.Eq, statusRef)
taskDB.Reorder(ctx, taskID, 3, statusFilter)
```
### Priority IndexableDB
```go
createEmpty := func() *model.Priority {
return &model.Priority{}
}
getIndexable := func(p *model.Priority) *model.Indexable {
return &p.Indexable
}
priorityDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
orgFilter := repository.OrgFilter(organizationRef)
priorityDB.Reorder(ctx, priorityID, 0, orgFilter)
```
### Global Reordering (No Filter)
```go
createEmpty := func() *model.Project {
return &model.Project{}
}
getIndexable := func(p *model.Project) *model.Indexable {
return &p.Indexable
}
globalDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
// Reorders all items globally (empty filter)
globalDB.Reorder(ctx, objectID, 5, repository.Query())
```
## Key Features
### ✅ **Generic Support**
- Works with **any struct** that embeds `model.Indexable`
- Type-safe with compile-time checking
- No hardcoded types
### ✅ **Single Filter Parameter**
- **Simple**: Single `builder.Query` parameter instead of variadic `interface{}`
- **Flexible**: Can incorporate any combination of filters
- **Type-safe**: No runtime type assertions needed
### ✅ **Clean Architecture**
- Interface separated from implementation
- Generic implementation in internal package
- Easy-to-use factories for common types
## How It Works
### Generic Algorithm
1. **Get current index** using type-specific helper function
2. **If no change needed** → return early
3. **Apply filter** to scope affected items
4. **Shift affected items** using `PatchMany` with `$inc`
5. **Update target object** using `Patch` with `$set`
### Type-Safe Implementation
```go
type IndexableDB[T storable.Storable] struct {
repo repository.Repository
logger mlogger.Logger
createEmpty func() T
getIndexable func(T) *model.Indexable
}
// Single filter parameter - clean and simple
func (db *IndexableDB[T]) Reorder(ctx context.Context, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error
```
## Benefits
**Generic** - Works with any Indexable struct
**Type Safe** - Compile-time type checking
**Simple** - Single filter parameter instead of variadic interface{}
**Efficient** - Uses patches, not full updates
**Clean** - Interface separated from implementation
That's it! **Generic, type-safe, and simple** reordering for any Indexable struct with a single filter parameter.

View File

@@ -0,0 +1,69 @@
package indexable
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// Example usage of the generic IndexableDB with different types
// Example 1: Using with Project
func ExampleProjectIndexableDB(repo repository.Repository, logger mlogger.Logger, organizationRef primitive.ObjectID) {
// Define helper functions for Project
createEmpty := func() *model.Project {
return &model.Project{}
}
getIndexable := func(p *model.Project) *model.Indexable {
return &p.Indexable
}
// Create generic IndexableDB for Project
projectDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
// Use with organization filter
orgFilter := repository.OrgFilter(organizationRef)
projectDB.Reorder(context.Background(), primitive.NewObjectID(), 2, orgFilter)
}
// Example 3: Using with Task
func ExampleTaskIndexableDB(repo repository.Repository, logger mlogger.Logger, statusRef primitive.ObjectID) {
// Define helper functions for Task
createEmpty := func() *model.Task {
return &model.Task{}
}
getIndexable := func(t *model.Task) *model.Indexable {
return &t.Indexable
}
// Create generic IndexableDB for Task
taskDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
// Use with status filter
statusFilter := repository.Query().Comparison(repository.Field("statusRef"), builder.Eq, statusRef)
taskDB.Reorder(context.Background(), primitive.NewObjectID(), 3, statusFilter)
}
// Example 5: Using without any filter (global reordering)
func ExampleGlobalIndexableDB(repo repository.Repository, logger mlogger.Logger) {
// Define helper functions for any Indexable type
createEmpty := func() *model.Project {
return &model.Project{}
}
getIndexable := func(p *model.Project) *model.Indexable {
return &p.Indexable
}
// Create generic IndexableDB without filters
globalDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
// Use without any filter - reorders all items globally
globalDB.Reorder(context.Background(), primitive.NewObjectID(), 5, repository.Query())
}

View File

@@ -0,0 +1,122 @@
package indexable
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/db/storable"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// IndexableDB implements db.IndexableDB interface with generic support
type IndexableDB[T storable.Storable] struct {
repo repository.Repository
logger mlogger.Logger
createEmpty func() T
getIndexable func(T) *model.Indexable
}
// NewIndexableDB creates a new IndexableDB instance
func NewIndexableDB[T storable.Storable](
repo repository.Repository,
logger mlogger.Logger,
createEmpty func() T,
getIndexable func(T) *model.Indexable,
) *IndexableDB[T] {
return &IndexableDB[T]{
repo: repo,
logger: logger,
createEmpty: createEmpty,
getIndexable: getIndexable,
}
}
// Reorder implements the db.IndexableDB interface with single filter parameter
func (db *IndexableDB[T]) Reorder(ctx context.Context, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error {
// Get current object to find its index
obj := db.createEmpty()
err := db.repo.Get(ctx, objectRef, obj)
if err != nil {
db.logger.Error("Failed to get object for reordering",
zap.Error(err),
zap.String("object_ref", objectRef.Hex()),
zap.Int("new_index", newIndex))
return err
}
// Extract index from the object
indexable := db.getIndexable(obj)
currentIndex := indexable.Index
if currentIndex == newIndex {
db.logger.Debug("No reordering needed - same index",
zap.String("object_ref", objectRef.Hex()),
zap.Int("current_index", currentIndex),
zap.Int("new_index", newIndex))
return nil // No change needed
}
// Simple reordering logic
if currentIndex < newIndex {
// Moving down: shift items between currentIndex+1 and newIndex up by -1
patch := repository.Patch().Inc(repository.IndexField(), -1)
reorderFilter := filter.
And(repository.IndexOpFilter(currentIndex+1, builder.Gte)).
And(repository.IndexOpFilter(newIndex, builder.Lte))
updatedCount, err := db.repo.PatchMany(ctx, reorderFilter, patch)
if err != nil {
db.logger.Error("Failed to shift objects during reordering (moving down)",
zap.Error(err),
zap.String("object_ref", objectRef.Hex()),
zap.Int("current_index", currentIndex),
zap.Int("new_index", newIndex),
zap.Int("updated_count", updatedCount))
return err
}
db.logger.Debug("Successfully shifted objects (moving down)",
zap.String("object_ref", objectRef.Hex()),
zap.Int("updated_count", updatedCount))
} else {
// Moving up: shift items between newIndex and currentIndex-1 down by +1
patch := repository.Patch().Inc(repository.IndexField(), 1)
reorderFilter := filter.
And(repository.IndexOpFilter(newIndex, builder.Gte)).
And(repository.IndexOpFilter(currentIndex-1, builder.Lte))
updatedCount, err := db.repo.PatchMany(ctx, reorderFilter, patch)
if err != nil {
db.logger.Error("Failed to shift objects during reordering (moving up)",
zap.Error(err),
zap.String("object_ref", objectRef.Hex()),
zap.Int("current_index", currentIndex),
zap.Int("new_index", newIndex),
zap.Int("updated_count", updatedCount))
return err
}
db.logger.Debug("Successfully shifted objects (moving up)",
zap.String("object_ref", objectRef.Hex()),
zap.Int("updated_count", updatedCount))
}
// Update the target object to new index
patch := repository.Patch().Set(repository.IndexField(), newIndex)
err = db.repo.Patch(ctx, objectRef, patch)
if err != nil {
db.logger.Error("Failed to update target object index",
zap.Error(err),
zap.String("object_ref", objectRef.Hex()),
zap.Int("current_index", currentIndex),
zap.Int("new_index", newIndex))
return err
}
db.logger.Info("Successfully reordered object",
zap.String("object_ref", objectRef.Hex()),
zap.Int("old_index", currentIndex),
zap.Int("new_index", newIndex))
return nil
}

View File

@@ -0,0 +1,314 @@
//go:build integration
// +build integration
package indexable
import (
"context"
"testing"
"time"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/mongodb"
"github.com/testcontainers/testcontainers-go/wait"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.uber.org/zap"
)
func setupTestDB(t *testing.T) (repository.Repository, func()) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
mongoContainer, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("root"),
mongodb.WithPassword("password"),
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
)
require.NoError(t, err, "failed to start MongoDB container")
mongoURI, err := mongoContainer.ConnectionString(ctx)
require.NoError(t, err, "failed to get MongoDB connection string")
clientOptions := options.Client().ApplyURI(mongoURI)
client, err := mongo.Connect(ctx, clientOptions)
require.NoError(t, err, "failed to connect to MongoDB")
db := client.Database("testdb")
repo := repository.CreateMongoRepository(db, "projects")
cleanup := func() {
disconnect(ctx, t, client)
terminate(ctx, t, mongoContainer)
}
return repo, cleanup
}
func disconnect(ctx context.Context, t *testing.T, client *mongo.Client) {
if err := client.Disconnect(ctx); err != nil {
t.Logf("failed to disconnect from MongoDB: %v", err)
}
}
func terminate(ctx context.Context, t *testing.T, container testcontainers.Container) {
if err := container.Terminate(ctx); err != nil {
t.Logf("failed to terminate MongoDB container: %v", err)
}
}
func TestIndexableDB_Reorder(t *testing.T) {
repo, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
organizationRef := primitive.NewObjectID()
logger := zap.NewNop()
// Create test projects with different indices
projects := []*model.Project{
{
ProjectBase: model.ProjectBase{
PermissionBound: model.PermissionBound{
OrganizationBoundBase: model.OrganizationBoundBase{
OrganizationRef: organizationRef,
},
},
Describable: model.Describable{Name: "Project A"},
Indexable: model.Indexable{Index: 0},
Mnemonic: "A",
State: model.ProjectStateActive,
},
},
{
ProjectBase: model.ProjectBase{
PermissionBound: model.PermissionBound{
OrganizationBoundBase: model.OrganizationBoundBase{
OrganizationRef: organizationRef,
},
},
Describable: model.Describable{Name: "Project B"},
Indexable: model.Indexable{Index: 1},
Mnemonic: "B",
State: model.ProjectStateActive,
},
},
{
ProjectBase: model.ProjectBase{
PermissionBound: model.PermissionBound{
OrganizationBoundBase: model.OrganizationBoundBase{
OrganizationRef: organizationRef,
},
},
Describable: model.Describable{Name: "Project C"},
Indexable: model.Indexable{Index: 2},
Mnemonic: "C",
State: model.ProjectStateActive,
},
},
{
ProjectBase: model.ProjectBase{
PermissionBound: model.PermissionBound{
OrganizationBoundBase: model.OrganizationBoundBase{
OrganizationRef: organizationRef,
},
},
Describable: model.Describable{Name: "Project D"},
Indexable: model.Indexable{Index: 3},
Mnemonic: "D",
State: model.ProjectStateActive,
},
},
}
// Insert projects into database
for _, project := range projects {
project.ID = primitive.NewObjectID()
err := repo.Insert(ctx, project, nil)
require.NoError(t, err)
}
// Create helper functions for Project type
createEmpty := func() *model.Project {
return &model.Project{}
}
getIndexable := func(p *model.Project) *model.Indexable {
return &p.Indexable
}
indexableDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
t.Run("Reorder_NoChange", func(t *testing.T) {
// Test reordering to the same position (should be no-op)
err := indexableDB.Reorder(ctx, projects[1].ID, 1, repository.Query())
require.NoError(t, err)
// Verify indices haven't changed
var result model.Project
err = repo.Get(ctx, projects[0].ID, &result)
require.NoError(t, err)
assert.Equal(t, 0, result.Index)
err = repo.Get(ctx, projects[1].ID, &result)
require.NoError(t, err)
assert.Equal(t, 1, result.Index)
})
t.Run("Reorder_MoveDown", func(t *testing.T) {
// Move Project A (index 0) to index 2
err := indexableDB.Reorder(ctx, projects[0].ID, 2, repository.Query())
require.NoError(t, err)
// Verify the reordering:
// Project A should now be at index 2
// Project B should be at index 0
// Project C should be at index 1
// Project D should remain at index 3
var result model.Project
// Check Project A (moved to index 2)
err = repo.Get(ctx, projects[0].ID, &result)
require.NoError(t, err)
assert.Equal(t, 2, result.Index)
// Check Project B (shifted to index 0)
err = repo.Get(ctx, projects[1].ID, &result)
require.NoError(t, err)
assert.Equal(t, 0, result.Index)
// Check Project C (shifted to index 1)
err = repo.Get(ctx, projects[2].ID, &result)
require.NoError(t, err)
assert.Equal(t, 1, result.Index)
// Check Project D (unchanged)
err = repo.Get(ctx, projects[3].ID, &result)
require.NoError(t, err)
assert.Equal(t, 3, result.Index)
})
t.Run("Reorder_MoveUp", func(t *testing.T) {
// Reset indices for this test
for i, project := range projects {
project.Index = i
err := repo.Update(ctx, project)
require.NoError(t, err)
}
// Move Project C (index 2) to index 0
err := indexableDB.Reorder(ctx, projects[2].ID, 0, repository.Query())
require.NoError(t, err)
// Verify the reordering:
// Project C should now be at index 0
// Project A should be at index 1
// Project B should be at index 2
// Project D should remain at index 3
var result model.Project
// Check Project C (moved to index 0)
err = repo.Get(ctx, projects[2].ID, &result)
require.NoError(t, err)
assert.Equal(t, 0, result.Index)
// Check Project A (shifted to index 1)
err = repo.Get(ctx, projects[0].ID, &result)
require.NoError(t, err)
assert.Equal(t, 1, result.Index)
// Check Project B (shifted to index 2)
err = repo.Get(ctx, projects[1].ID, &result)
require.NoError(t, err)
assert.Equal(t, 2, result.Index)
// Check Project D (unchanged)
err = repo.Get(ctx, projects[3].ID, &result)
require.NoError(t, err)
assert.Equal(t, 3, result.Index)
})
t.Run("Reorder_WithFilter", func(t *testing.T) {
// Reset indices for this test
for i, project := range projects {
project.Index = i
err := repo.Update(ctx, project)
require.NoError(t, err)
}
// Test reordering with organization filter
orgFilter := repository.OrgFilter(organizationRef)
err := indexableDB.Reorder(ctx, projects[0].ID, 2, orgFilter)
require.NoError(t, err)
// Verify the reordering worked with filter
var result model.Project
err = repo.Get(ctx, projects[0].ID, &result)
require.NoError(t, err)
assert.Equal(t, 2, result.Index)
})
}
func TestIndexableDB_EdgeCases(t *testing.T) {
repo, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
organizationRef := primitive.NewObjectID()
logger := zap.NewNop()
// Create a single project for edge case testing
project := &model.Project{
ProjectBase: model.ProjectBase{
PermissionBound: model.PermissionBound{
OrganizationBoundBase: model.OrganizationBoundBase{
OrganizationRef: organizationRef,
},
},
Describable: model.Describable{Name: "Test Project"},
Indexable: model.Indexable{Index: 0},
Mnemonic: "TEST",
State: model.ProjectStateActive,
},
}
project.ID = primitive.NewObjectID()
err := repo.Insert(ctx, project, nil)
require.NoError(t, err)
// Create helper functions for Project type
createEmpty := func() *model.Project {
return &model.Project{}
}
getIndexable := func(p *model.Project) *model.Indexable {
return &p.Indexable
}
indexableDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
t.Run("Reorder_SingleItem", func(t *testing.T) {
// Test reordering a single item (should work but have no effect)
err := indexableDB.Reorder(ctx, project.ID, 0, repository.Query())
require.NoError(t, err)
var result model.Project
err = repo.Get(ctx, project.ID, &result)
require.NoError(t, err)
assert.Equal(t, 0, result.Index)
})
t.Run("Reorder_InvalidObjectID", func(t *testing.T) {
// Test reordering with an invalid object ID
invalidID := primitive.NewObjectID()
err := indexableDB.Reorder(ctx, invalidID, 1, repository.Query())
require.Error(t, err) // Should fail because object doesn't exist
})
}

View File

@@ -0,0 +1,12 @@
package invitationdb
import (
"context"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *InvitationDB) Accept(ctx context.Context, invitationRef primitive.ObjectID) error {
return db.updateStatus(ctx, invitationRef, model.InvitationAccepted)
}

View File

@@ -0,0 +1,49 @@
package invitationdb
import (
"context"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// SetArchived sets the archived status of an invitation
// Invitation supports archiving through PermissionBound embedding ArchivableBase
func (db *InvitationDB) SetArchived(ctx context.Context, accountRef, organizationRef, invitationRef primitive.ObjectID, archived, cascade bool) error {
db.DBImp.Logger.Debug("Setting invitation archived status", mzap.ObjRef("invitation_ref", invitationRef), zap.Bool("archived", archived), zap.Bool("cascade", cascade))
res, err := db.Enforcer.Enforce(ctx, db.PermissionRef, accountRef, organizationRef, invitationRef, model.ActionUpdate)
if err != nil {
db.DBImp.Logger.Warn("Failed to enforce archivation permission", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
return err
}
if !res {
db.DBImp.Logger.Debug("Permission denied for archivation", mzap.ObjRef("invitation_ref", invitationRef))
return merrors.AccessDenied(db.Collection, string(model.ActionUpdate), invitationRef)
}
// Get the invitation first
var invitation model.Invitation
if err := db.Get(ctx, accountRef, invitationRef, &invitation); err != nil {
db.DBImp.Logger.Warn("Error retrieving invitation for archival", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
return err
}
// Update the invitation's archived status
invitation.SetArchived(archived)
if err := db.Update(ctx, accountRef, &invitation); err != nil {
db.DBImp.Logger.Warn("Error updating invitation archived status", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
return err
}
// Note: Currently no cascade dependencies for invitations
// If cascade is enabled, we could add logic here for any future dependencies
if cascade {
db.DBImp.Logger.Debug("Cascade archiving requested but no dependencies to archive for invitation", mzap.ObjRef("invitation_ref", invitationRef))
}
db.DBImp.Logger.Debug("Successfully set invitation archived status", mzap.ObjRef("invitation_ref", invitationRef), zap.Bool("archived", archived))
return nil
}

View File

@@ -0,0 +1,24 @@
package invitationdb
import (
"context"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// DeleteCascade deletes an invitation
// Invitations don't have cascade dependencies, so this is a simple deletion
func (db *InvitationDB) DeleteCascade(ctx context.Context, accountRef, invitationRef primitive.ObjectID) error {
db.DBImp.Logger.Debug("Starting invitation cascade deletion", mzap.ObjRef("invitation_ref", invitationRef))
// Delete the invitation itself (no dependencies to cascade delete)
if err := db.Delete(ctx, accountRef, invitationRef); err != nil {
db.DBImp.Logger.Error("Error deleting invitation", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
return err
}
db.DBImp.Logger.Debug("Successfully deleted invitation", mzap.ObjRef("invitation_ref", invitationRef))
return nil
}

View File

@@ -0,0 +1,53 @@
package invitationdb
import (
"context"
"github.com/tech/sendico/pkg/auth"
"github.com/tech/sendico/pkg/db/policy"
"github.com/tech/sendico/pkg/db/repository"
ri "github.com/tech/sendico/pkg/db/repository/index"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"go.mongodb.org/mongo-driver/mongo"
"go.uber.org/zap"
)
type InvitationDB struct {
auth.ProtectedDBImp[*model.Invitation]
}
func Create(
ctx context.Context,
logger mlogger.Logger,
enforcer auth.Enforcer,
pdb policy.DB,
db *mongo.Database,
) (*InvitationDB, error) {
p, err := auth.CreateDBImp[*model.Invitation](ctx, logger, pdb, enforcer, mservice.Invitations, db)
if err != nil {
return nil, err
}
// unique email per organization
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
Keys: []ri.Key{{Field: repository.OrgField().Build(), Sort: ri.Asc}, {Field: "description.email", Sort: ri.Asc}},
Unique: true,
}); err != nil {
p.DBImp.Logger.Error("Failed to create unique mnemonic index", zap.Error(err))
return nil, err
}
// ttl index
ttl := int32(0) // zero ttl means expiration on date preset when inserting data
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
Keys: []ri.Key{{Field: "expiresAt", Sort: ri.Asc}},
TTL: &ttl,
}); err != nil {
p.DBImp.Logger.Warn("Failed to create ttl index in the invitations", zap.Error(err))
return nil, err
}
return &InvitationDB{ProtectedDBImp: *p}, nil
}

View File

@@ -0,0 +1,12 @@
package invitationdb
import (
"context"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *InvitationDB) Decline(ctx context.Context, invitationRef primitive.ObjectID) error {
return db.updateStatus(ctx, invitationRef, model.InvitationDeclined)
}

View File

@@ -0,0 +1,121 @@
package invitationdb
import (
"context"
"fmt"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.uber.org/zap"
)
func (db *InvitationDB) GetPublic(ctx context.Context, invitationRef primitive.ObjectID) (*model.PublicInvitation, error) {
roleField := repository.Field("role")
orgField := repository.Field("organization")
accField := repository.Field("account")
empField := repository.Field("employee")
regField := repository.Field("registrationAcc")
descEmailField := repository.Field("description").Dot("email")
pipeline := repository.Pipeline().
// 0) Filter to exactly the invitation(s) you want
Match(repository.IDFilter(invitationRef).And(repository.Filter("status", model.InvitationCreated))).
// 1) Lookup the role document
Lookup(
mservice.Roles,
repository.Field("roleRef"),
repository.IDField(),
roleField,
).
Unwind(repository.Ref(roleField)).
// 2) Lookup the organization document
Lookup(
mservice.Organizations,
repository.Field("organizationRef"),
repository.IDField(),
orgField,
).
Unwind(repository.Ref(orgField)).
// 3) Lookup the account document
Lookup(
mservice.Accounts,
repository.Field("inviterRef"),
repository.IDField(),
accField,
).
Unwind(repository.Ref(accField)).
/* 4) do we already have an account whose login == invitation.description ? */
Lookup(
mservice.Accounts,
descEmailField, // local field (invitation.description.email)
repository.Field("login"), // foreign field (account.login)
regField, // array: 0-length or ≥1
).
// 5) Projection
Project(
repository.SimpleAlias(
empField.Dot("description"),
repository.Ref(accField),
),
repository.SimpleAlias(
empField.Dot("avatarUrl"),
repository.Ref(accField.Dot("avatarUrl")),
),
repository.SimpleAlias(
orgField.Dot("description"),
repository.Ref(orgField),
),
repository.SimpleAlias(
orgField.Dot("logoUrl"),
repository.Ref(orgField.Dot("logoUrl")),
),
repository.SimpleAlias(
roleField,
repository.Ref(roleField),
),
repository.SimpleAlias(
repository.Field("invitation"), // ← left-hand side
repository.Ref(repository.Field("description")), // ← right-hand side (“$description”)
),
repository.SimpleAlias(
repository.Field("storable"), // ← left-hand side
repository.RootRef(), // ← right-hand side (“$description”)
),
repository.ProjectionExpr(
repository.Field("registrationRequired"),
repository.Eq(
repository.Size(repository.Value(repository.Ref(regField).Build())),
repository.Literal(0),
),
),
)
var res model.PublicInvitation
haveResult := false
decoder := func(cur *mongo.Cursor) error {
if haveResult {
// should never get here
db.DBImp.Logger.Warn("Unexpected extra invitation", mzap.ObjRef("invitation_ref", invitationRef))
return merrors.Internal("Unexpected extra invitation found by reference")
}
if e := cur.Decode(&res); e != nil {
db.DBImp.Logger.Warn("Failed to decode entity", zap.Error(e), zap.Any("data", cur.Current.String()))
return e
}
haveResult = true
return nil
}
if err := db.DBImp.Repository.Aggregate(ctx, pipeline, decoder); err != nil {
db.DBImp.Logger.Warn("Failed to execute aggregation pipeline", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
return nil, err
}
if !haveResult {
db.DBImp.Logger.Warn("No results fetched", mzap.ObjRef("invitation_ref", invitationRef))
return nil, merrors.NoData(fmt.Sprintf("Invitation %s not found", invitationRef.Hex()))
}
return &res, nil
}

View File

@@ -0,0 +1,28 @@
package invitationdb
import (
"context"
"errors"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
mauth "github.com/tech/sendico/pkg/mutil/db/auth"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *InvitationDB) List(ctx context.Context, accountRef, organizationRef, _ primitive.ObjectID, cursor *model.ViewCursor) ([]model.Invitation, error) {
res, err := mauth.GetProtectedObjects[model.Invitation](
ctx,
db.DBImp.Logger,
accountRef, organizationRef, model.ActionRead,
repository.OrgFilter(organizationRef),
cursor,
db.Enforcer,
db.DBImp.Repository,
)
if errors.Is(err, merrors.ErrNoData) {
return []model.Invitation{}, nil
}
return res, err
}

View File

@@ -0,0 +1,26 @@
package invitationdb
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
func (db *InvitationDB) updateStatus(ctx context.Context, invitationRef primitive.ObjectID, newStatus model.InvitationStatus) error {
// db.DBImp.Up
var inv model.Invitation
if err := db.DBImp.FindOne(ctx, repository.IDFilter(invitationRef), &inv); err != nil {
db.DBImp.Logger.Warn("Failed to fetch invitation", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef), zap.String("new_status", string(newStatus)))
return err
}
inv.Status = newStatus
if err := db.DBImp.Update(ctx, &inv); err != nil {
db.DBImp.Logger.Warn("Failed to update invitation", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef), zap.String("new_status", string(newStatus)))
return err
}
return nil
}

View File

@@ -0,0 +1,22 @@
package mongo
import (
"net/url"
)
func buildURI(s *DBSettings) string {
u := &url.URL{
Scheme: "mongodb",
Host: s.Host,
Path: "/" + url.PathEscape(s.Database), // /my%20db
}
q := url.Values{}
if s.ReplicaSet != "" {
q.Set("replicaSet", s.ReplicaSet)
}
u.RawQuery = q.Encode()
return u.String()
}

View File

@@ -0,0 +1,32 @@
package organizationdb
import (
"context"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// SetArchived sets the archived status of an organization and optionally cascades to projects, tasks, comments, and reactions
func (db *OrganizationDB) SetArchived(ctx context.Context, accountRef, organizationRef primitive.ObjectID, archived, cascade bool) error {
db.DBImp.Logger.Debug("Setting organization archived status", mzap.ObjRef("organization_ref", organizationRef), zap.Bool("archived", archived), zap.Bool("cascade", cascade))
// Get the organization first
var organization model.Organization
if err := db.Get(ctx, accountRef, organizationRef, &organization); err != nil {
db.DBImp.Logger.Warn("Error retrieving organization for archival", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
return err
}
// Update the organization's archived status
organization.SetArchived(archived)
if err := db.Update(ctx, accountRef, &organization); err != nil {
db.DBImp.Logger.Warn("Error updating organization archived status", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
return err
}
db.DBImp.Logger.Debug("Successfully set organization archived status", mzap.ObjRef("organization_ref", organizationRef), zap.Bool("archived", archived))
return nil
}

View File

@@ -0,0 +1,23 @@
package organizationdb
import (
"context"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
// DeleteCascade deletes an organization and all its related data (projects, tasks, comments, reactions, statuses)
func (db *OrganizationDB) DeleteCascade(ctx context.Context, organizationRef primitive.ObjectID) error {
db.DBImp.Logger.Debug("Starting organization deletion with projects", mzap.ObjRef("organization_ref", organizationRef))
// Delete the organization itself
if err := db.Unprotected().Delete(ctx, organizationRef); err != nil {
db.DBImp.Logger.Warn("Error deleting organization", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
return err
}
db.DBImp.Logger.Debug("Successfully deleted organization with projects", mzap.ObjRef("organization_ref", organizationRef))
return nil
}

View File

@@ -0,0 +1,19 @@
package organizationdb
import (
"context"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *OrganizationDB) Create(ctx context.Context, _, _ primitive.ObjectID, org *model.Organization) error {
if org == nil {
return merrors.InvalidArgument("Organization object is nil")
}
org.SetID(primitive.NewObjectID())
// Organizaiton reference must be set to the same value as own organization reference
org.SetOrganizationRef(*org.GetID())
return db.DBImp.Create(ctx, org)
}

View File

@@ -0,0 +1,34 @@
package organizationdb
import (
"context"
"github.com/tech/sendico/pkg/auth"
"github.com/tech/sendico/pkg/db/policy"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"go.mongodb.org/mongo-driver/mongo"
)
type OrganizationDB struct {
auth.ProtectedDBImp[*model.Organization]
}
func Create(ctx context.Context,
logger mlogger.Logger,
enforcer auth.Enforcer,
pdb policy.DB,
db *mongo.Database,
) (*OrganizationDB, error) {
p, err := auth.CreateDBImp[*model.Organization](ctx, logger, pdb, enforcer, mservice.Organizations, db)
if err != nil {
return nil, err
}
res := &OrganizationDB{
ProtectedDBImp: *p,
}
p.DBImp.SetDeleter(res.DeleteCascade)
return res, nil
}

View File

@@ -0,0 +1,12 @@
package organizationdb
import (
"context"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *OrganizationDB) GetByRef(ctx context.Context, organizationRef primitive.ObjectID, org *model.Organization) error {
return db.Unprotected().Get(ctx, organizationRef, org)
}

View File

@@ -0,0 +1,16 @@
package organizationdb
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/model"
mutil "github.com/tech/sendico/pkg/mutil/db"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *OrganizationDB) List(ctx context.Context, accountRef primitive.ObjectID, cursor *model.ViewCursor) ([]model.Organization, error) {
filter := repository.Query().Comparison(repository.Field("members"), builder.Eq, accountRef)
return mutil.GetObjects[model.Organization](ctx, db.DBImp.Logger, filter, cursor, db.DBImp.Repository)
}

View File

@@ -0,0 +1,14 @@
package organizationdb
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/model"
mutil "github.com/tech/sendico/pkg/mutil/db"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *OrganizationDB) ListOwned(ctx context.Context, accountRef primitive.ObjectID) ([]model.Organization, error) {
return mutil.GetObjects[model.Organization](ctx, db.DBImp.Logger, repository.Filter("ownerRef", accountRef), nil, db.DBImp.Repository)
}

View File

@@ -0,0 +1,562 @@
//go:build integration
// +build integration
package organizationdb
import (
"context"
"errors"
"testing"
"github.com/tech/sendico/pkg/db/internal/mongo/commentdb"
"github.com/tech/sendico/pkg/db/internal/mongo/projectdb"
"github.com/tech/sendico/pkg/db/internal/mongo/reactiondb"
"github.com/tech/sendico/pkg/db/internal/mongo/statusdb"
"github.com/tech/sendico/pkg/db/internal/mongo/taskdb"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go/modules/mongodb"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.uber.org/zap"
)
func setupSetArchivedTestDB(t *testing.T) (*OrganizationDB, *projectDBAdapter, *taskdb.TaskDB, *commentdb.CommentDB, *reactiondb.ReactionDB, func()) {
ctx := context.Background()
// Start MongoDB container
mongodbContainer, err := mongodb.Run(ctx, "mongo:latest")
require.NoError(t, err)
// Get connection string
endpoint, err := mongodbContainer.Endpoint(ctx, "")
require.NoError(t, err)
// Connect to MongoDB
client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://"+endpoint))
require.NoError(t, err)
db := client.Database("test_organization_setarchived")
logger := zap.NewNop()
// Create mock enforcer and policy DB
mockEnforcer := &mockSetArchivedEnforcer{}
mockPolicyDB := &mockSetArchivedPolicyDB{}
mockPGroupDB := &mockSetArchivedPGroupDB{}
// Create databases
// We need to create a projectDB first, but we'll create a temporary one for organizationDB creation
// Create temporary taskDB and statusDB for the temporary projectDB
// Create temporary reactionDB and commentDB for the temporary taskDB
tempReactionDB, err := reactiondb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db)
require.NoError(t, err)
tempCommentDB, err := commentdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db, tempReactionDB)
require.NoError(t, err)
tempTaskDB, err := taskdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db, tempCommentDB, tempReactionDB)
require.NoError(t, err)
tempStatusDB, err := statusdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db)
require.NoError(t, err)
tempProjectDB, err := projectdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, tempTaskDB, tempStatusDB, db)
require.NoError(t, err)
// Create adapter for organizationDB creation
tempProjectDBAdapter := &projectDBAdapter{
ProjectDB: tempProjectDB,
taskDB: tempTaskDB,
commentDB: tempCommentDB,
reactionDB: tempReactionDB,
statusDB: tempStatusDB,
}
organizationDB, err := Create(ctx, logger, mockEnforcer, mockPolicyDB, tempProjectDBAdapter, mockPGroupDB, db)
require.NoError(t, err)
var projectDB *projectdb.ProjectDB
var taskDB *taskdb.TaskDB
var commentDB *commentdb.CommentDB
var reactionDB *reactiondb.ReactionDB
var statusDB *statusdb.StatusDB
// Create databases in dependency order
reactionDB, err = reactiondb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db)
require.NoError(t, err)
commentDB, err = commentdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db, reactionDB)
require.NoError(t, err)
taskDB, err = taskdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db, commentDB, reactionDB)
require.NoError(t, err)
statusDB, err = statusdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db)
require.NoError(t, err)
projectDB, err = projectdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, taskDB, statusDB, db)
require.NoError(t, err)
// Create adapter for the actual projectDB
projectDBAdapter := &projectDBAdapter{
ProjectDB: projectDB,
taskDB: taskDB,
commentDB: commentDB,
reactionDB: reactionDB,
statusDB: statusDB,
}
cleanup := func() {
client.Disconnect(context.Background())
mongodbContainer.Terminate(ctx)
}
return organizationDB, projectDBAdapter, taskDB, commentDB, reactionDB, cleanup
}
// projectDBAdapter adapts projectdb.ProjectDB to project.DB interface for testing
type projectDBAdapter struct {
*projectdb.ProjectDB
taskDB *taskdb.TaskDB
commentDB *commentdb.CommentDB
reactionDB *reactiondb.ReactionDB
statusDB *statusdb.StatusDB
}
// DeleteCascade implements the project.DB interface
func (a *projectDBAdapter) DeleteCascade(ctx context.Context, projectRef primitive.ObjectID) error {
// Call the concrete implementation
return a.ProjectDB.DeleteCascade(ctx, projectRef)
}
// SetArchived implements the project.DB interface
func (a *projectDBAdapter) SetArchived(ctx context.Context, accountRef, organizationRef, projectRef primitive.ObjectID, archived, cascade bool) error {
// Use the stored dependencies for the concrete implementation
return a.ProjectDB.SetArchived(ctx, accountRef, organizationRef, projectRef, archived, cascade)
}
// List implements the project.DB interface
func (a *projectDBAdapter) List(ctx context.Context, accountRef, organizationRef, _ primitive.ObjectID, cursor *model.ViewCursor) ([]model.Project, error) {
return a.ProjectDB.List(ctx, accountRef, organizationRef, primitive.NilObjectID, cursor)
}
// Previews implements the project.DB interface
func (a *projectDBAdapter) Previews(ctx context.Context, accountRef, organizationRef primitive.ObjectID, projectRefs []primitive.ObjectID, cursor *model.ViewCursor, assigneeRefs, reporterRefs []primitive.ObjectID) ([]model.ProjectPreview, error) {
return a.ProjectDB.Previews(ctx, accountRef, organizationRef, projectRefs, cursor, assigneeRefs, reporterRefs)
}
// DeleteProject implements the project.DB interface
func (a *projectDBAdapter) DeleteProject(ctx context.Context, accountRef, organizationRef, projectRef primitive.ObjectID, migrateToRef *primitive.ObjectID) error {
// Call the concrete implementation with the organizationRef
return a.ProjectDB.DeleteProject(ctx, accountRef, organizationRef, projectRef, migrateToRef)
}
// RemoveTagFromProjects implements the project.DB interface
func (a *projectDBAdapter) RemoveTagFromProjects(ctx context.Context, accountRef, organizationRef, tagRef primitive.ObjectID) error {
// Call the concrete implementation
return a.ProjectDB.RemoveTagFromProjects(ctx, accountRef, organizationRef, tagRef)
}
// Mock implementations for SetArchived testing
type mockSetArchivedEnforcer struct{}
func (m *mockSetArchivedEnforcer) Enforce(ctx context.Context, permissionRef, accountRef, orgRef, objectRef primitive.ObjectID, action model.Action) (bool, error) {
return true, nil
}
func (m *mockSetArchivedEnforcer) EnforceBatch(ctx context.Context, objectRefs []model.PermissionBoundStorable, accountRef primitive.ObjectID, action model.Action) (map[primitive.ObjectID]bool, error) {
// Allow all objects for testing
result := make(map[primitive.ObjectID]bool)
for _, obj := range objectRefs {
result[*obj.GetID()] = true
}
return result, nil
}
func (m *mockSetArchivedEnforcer) GetRoles(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]model.Role, error) {
return nil, nil
}
func (m *mockSetArchivedEnforcer) GetPermissions(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]model.Role, []model.Permission, error) {
return nil, nil, nil
}
type mockSetArchivedPolicyDB struct{}
func (m *mockSetArchivedPolicyDB) Create(ctx context.Context, policy *model.PolicyDescription) error {
return nil
}
func (m *mockSetArchivedPolicyDB) Get(ctx context.Context, policyRef primitive.ObjectID, result *model.PolicyDescription) error {
return merrors.ErrNoData
}
func (m *mockSetArchivedPolicyDB) InsertMany(ctx context.Context, objects []*model.PolicyDescription) error { return nil }
func (m *mockSetArchivedPolicyDB) Update(ctx context.Context, policy *model.PolicyDescription) error {
return nil
}
func (m *mockSetArchivedPolicyDB) Patch(ctx context.Context, objectRef primitive.ObjectID, patch builder.Patch) error {
return nil
}
func (m *mockSetArchivedPolicyDB) Delete(ctx context.Context, policyRef primitive.ObjectID) error {
return nil
}
func (m *mockSetArchivedPolicyDB) DeleteMany(ctx context.Context, filter builder.Query) error {
return nil
}
func (m *mockSetArchivedPolicyDB) FindOne(ctx context.Context, filter builder.Query, result *model.PolicyDescription) error {
return merrors.ErrNoData
}
func (m *mockSetArchivedPolicyDB) ListIDs(ctx context.Context, query builder.Query) ([]primitive.ObjectID, error) {
return nil, nil
}
func (m *mockSetArchivedPolicyDB) ListPermissionBound(ctx context.Context, query builder.Query) ([]model.PermissionBoundStorable, error) {
return nil, nil
}
func (m *mockSetArchivedPolicyDB) Collection() string { return "" }
func (m *mockSetArchivedPolicyDB) All(ctx context.Context, organizationRef primitive.ObjectID) ([]model.PolicyDescription, error) {
return nil, nil
}
func (m *mockSetArchivedPolicyDB) Policies(ctx context.Context, refs []primitive.ObjectID) ([]model.PolicyDescription, error) {
return nil, nil
}
func (m *mockSetArchivedPolicyDB) GetBuiltInPolicy(ctx context.Context, resourceType mservice.Type, policy *model.PolicyDescription) error {
return nil
}
func (m *mockSetArchivedPolicyDB) DeleteCascade(ctx context.Context, policyRef primitive.ObjectID) error {
return nil
}
type mockSetArchivedPGroupDB struct{}
func (m *mockSetArchivedPGroupDB) Create(ctx context.Context, accountRef, organizationRef primitive.ObjectID, pgroup *model.PriorityGroup) error {
return nil
}
func (m *mockSetArchivedPGroupDB) InsertMany(ctx context.Context, accountRef, organizationRef primitive.ObjectID, objects []*model.PriorityGroup) error { return nil }
func (m *mockSetArchivedPGroupDB) Get(ctx context.Context, accountRef, pgroupRef primitive.ObjectID, result *model.PriorityGroup) error {
return merrors.ErrNoData
}
func (m *mockSetArchivedPGroupDB) Update(ctx context.Context, accountRef primitive.ObjectID, pgroup *model.PriorityGroup) error {
return nil
}
func (m *mockSetArchivedPGroupDB) Delete(ctx context.Context, accountRef, pgroupRef primitive.ObjectID) error {
return nil
}
func (m *mockSetArchivedPGroupDB) DeleteCascadeAuth(ctx context.Context, accountRef, pgroupRef primitive.ObjectID) error {
return nil
}
func (m *mockSetArchivedPGroupDB) Patch(ctx context.Context, accountRef, pgroupRef primitive.ObjectID, patch builder.Patch) error {
return nil
}
func (m *mockSetArchivedPGroupDB) PatchMany(ctx context.Context, accountRef primitive.ObjectID, query builder.Query, patch builder.Patch) (int, error) {
return 0, nil
}
func (m *mockSetArchivedPGroupDB) Unprotected() template.DB[*model.PriorityGroup] {
return nil
}
func (m *mockSetArchivedPGroupDB) ListIDs(ctx context.Context, action model.Action, accountRef primitive.ObjectID, query builder.Query) ([]primitive.ObjectID, error) {
return nil, nil
}
func (m *mockSetArchivedPGroupDB) All(ctx context.Context, organizationRef primitive.ObjectID, limit, offset *int64) ([]model.PriorityGroup, error) {
return nil, nil
}
func (m *mockSetArchivedPGroupDB) List(ctx context.Context, accountRef, organizationRef, _ primitive.ObjectID, cursor *model.ViewCursor) ([]model.PriorityGroup, error) {
return nil, nil
}
func (m *mockSetArchivedPGroupDB) DeleteCascade(ctx context.Context, statusRef primitive.ObjectID) error {
return nil
}
func (m *mockSetArchivedPGroupDB) SetArchived(ctx context.Context, accountRef, organizationRef, statusRef primitive.ObjectID, archived, cascade bool) error {
return nil
}
func (m *mockSetArchivedPGroupDB) Reorder(ctx context.Context, accountRef, priorityGroupRef primitive.ObjectID, oldIndex, newIndex int) error {
return nil
}
// Mock project DB for statusdb creation
type mockSetArchivedProjectDB struct{}
func (m *mockSetArchivedProjectDB) Create(ctx context.Context, accountRef, organizationRef primitive.ObjectID, project *model.Project) error {
return nil
}
func (m *mockSetArchivedProjectDB) Get(ctx context.Context, accountRef, projectRef primitive.ObjectID, result *model.Project) error {
return merrors.ErrNoData
}
func (m *mockSetArchivedProjectDB) Update(ctx context.Context, accountRef primitive.ObjectID, project *model.Project) error {
return nil
}
func (m *mockSetArchivedProjectDB) Delete(ctx context.Context, accountRef, projectRef primitive.ObjectID) error {
return nil
}
func (m *mockSetArchivedProjectDB) DeleteCascadeAuth(ctx context.Context, accountRef, projectRef primitive.ObjectID) error {
return nil
}
func (m *mockSetArchivedProjectDB) Patch(ctx context.Context, accountRef, objectRef primitive.ObjectID, patch builder.Patch) error {
return nil
}
func (m *mockSetArchivedProjectDB) PatchMany(ctx context.Context, accountRef primitive.ObjectID, query builder.Query, patch builder.Patch) (int, error) {
return 0, nil
}
func (m *mockSetArchivedProjectDB) Unprotected() template.DB[*model.Project] { return nil }
func (m *mockSetArchivedProjectDB) ListIDs(ctx context.Context, action model.Action, accountRef primitive.ObjectID, query builder.Query) ([]primitive.ObjectID, error) {
return nil, nil
}
func (m *mockSetArchivedProjectDB) List(ctx context.Context, accountRef, organizationRef, _ primitive.ObjectID, cursor *model.ViewCursor) ([]model.Project, error) {
return nil, nil
}
func (m *mockSetArchivedProjectDB) Previews(ctx context.Context, accountRef, organizationRef primitive.ObjectID, projectRefs []primitive.ObjectID, cursor *model.ViewCursor, assigneeRefs, reporterRefs []primitive.ObjectID) ([]model.ProjectPreview, error) {
return nil, nil
}
func (m *mockSetArchivedProjectDB) DeleteProject(ctx context.Context, accountRef, organizationRef, projectRef primitive.ObjectID, migrateToRef *primitive.ObjectID) error {
return nil
}
func (m *mockSetArchivedProjectDB) DeleteCascade(ctx context.Context, projectRef primitive.ObjectID) error {
return nil
}
func (m *mockSetArchivedProjectDB) SetArchived(ctx context.Context, accountRef, organizationRef, projectRef primitive.ObjectID, archived, cascade bool) error {
return nil
}
func (m *mockSetArchivedProjectDB) All(ctx context.Context, organizationRef primitive.ObjectID, limit, offset *int64) ([]model.Project, error) {
return nil, nil
}
func (m *mockSetArchivedProjectDB) Reorder(ctx context.Context, accountRef, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error {
return nil
}
func (m *mockSetArchivedProjectDB) AddTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) error {
return nil
}
func (m *mockSetArchivedProjectDB) RemoveTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) error {
return nil
}
func (m *mockSetArchivedProjectDB) RemoveTags(ctx context.Context, accountRef, organizationRef, tagRef primitive.ObjectID) error {
return nil
}
func (m *mockSetArchivedProjectDB) AddTags(ctx context.Context, accountRef, objectRef primitive.ObjectID, tagRefs []primitive.ObjectID) error {
return nil
}
func (m *mockSetArchivedProjectDB) SetTags(ctx context.Context, accountRef, objectRef primitive.ObjectID, tagRefs []primitive.ObjectID) error {
return nil
}
func (m *mockSetArchivedProjectDB) RemoveAllTags(ctx context.Context, accountRef, objectRef primitive.ObjectID) error {
return nil
}
func (m *mockSetArchivedProjectDB) GetTags(ctx context.Context, accountRef, objectRef primitive.ObjectID) ([]primitive.ObjectID, error) {
return nil, nil
}
func (m *mockSetArchivedProjectDB) HasTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) (bool, error) {
return false, nil
}
func (m *mockSetArchivedProjectDB) FindByTag(ctx context.Context, accountRef, tagRef primitive.ObjectID) ([]*model.Project, error) {
return nil, nil
}
func (m *mockSetArchivedProjectDB) FindByTags(ctx context.Context, accountRef primitive.ObjectID, tagRefs []primitive.ObjectID) ([]*model.Project, error) {
return nil, nil
}
func TestOrganizationDB_SetArchived(t *testing.T) {
organizationDB, projectDBAdapter, taskDB, commentDB, reactionDB, cleanup := setupSetArchivedTestDB(t)
defer cleanup()
ctx := context.Background()
accountRef := primitive.NewObjectID()
t.Run("SetArchived_OrganizationWithProjectsTasksCommentsAndReactions_Cascade", func(t *testing.T) {
// Create an organization using unprotected DB
organization := &model.Organization{
OrganizationBase: model.OrganizationBase{
Describable: model.Describable{Name: "Test Organization for Archive"},
TimeZone: "UTC",
},
}
organization.ID = primitive.NewObjectID()
err := organizationDB.Create(ctx, accountRef, organization.ID, organization)
require.NoError(t, err)
// Create a project for the organization using unprotected DB
project := &model.Project{
ProjectBase: model.ProjectBase{
PermissionBound: model.PermissionBound{
OrganizationBoundBase: model.OrganizationBoundBase{
OrganizationRef: organization.ID,
},
},
Describable: model.Describable{Name: "Test Project"},
Indexable: model.Indexable{Index: 0},
Mnemonic: "TEST",
State: model.ProjectStateActive,
},
}
project.ID = primitive.NewObjectID()
err = projectDBAdapter.Unprotected().Create(ctx, project)
require.NoError(t, err)
// Create a task for the project using unprotected DB
task := &model.Task{
PermissionBound: model.PermissionBound{
OrganizationBoundBase: model.OrganizationBoundBase{
OrganizationRef: organization.ID,
},
},
Describable: model.Describable{Name: "Test Task for Archive"},
ProjectRef: project.ID,
}
task.ID = primitive.NewObjectID()
err = taskDB.Unprotected().Create(ctx, task)
require.NoError(t, err)
// Create comments for the task using unprotected DB
comment := &model.Comment{
CommentBase: model.CommentBase{
PermissionBound: model.PermissionBound{
OrganizationBoundBase: model.OrganizationBoundBase{
OrganizationRef: organization.ID,
},
},
AuthorRef: accountRef,
TaskRef: task.ID,
Content: "Test Comment for Archive",
},
}
comment.ID = primitive.NewObjectID()
err = commentDB.Unprotected().Create(ctx, comment)
require.NoError(t, err)
// Create reaction for the comment using unprotected DB
reaction := &model.Reaction{
PermissionBound: model.PermissionBound{
OrganizationBoundBase: model.OrganizationBoundBase{
OrganizationRef: organization.ID,
},
},
Type: "like",
AuthorRef: accountRef,
CommentRef: comment.ID,
}
reaction.ID = primitive.NewObjectID()
err = reactionDB.Unprotected().Create(ctx, reaction)
require.NoError(t, err)
// Verify all entities are not archived initially
var retrievedOrganization model.Organization
err = organizationDB.Get(ctx, accountRef, organization.ID, &retrievedOrganization)
require.NoError(t, err)
assert.False(t, retrievedOrganization.IsArchived())
var retrievedProject model.Project
err = projectDBAdapter.Unprotected().Get(ctx, project.ID, &retrievedProject)
require.NoError(t, err)
assert.False(t, retrievedProject.IsArchived())
var retrievedTask model.Task
err = taskDB.Unprotected().Get(ctx, task.ID, &retrievedTask)
require.NoError(t, err)
assert.False(t, retrievedTask.IsArchived())
var retrievedComment model.Comment
err = commentDB.Unprotected().Get(ctx, comment.ID, &retrievedComment)
require.NoError(t, err)
assert.False(t, retrievedComment.IsArchived())
// Archive organization with cascade
err = organizationDB.SetArchived(ctx, accountRef, organization.ID, true, true)
require.NoError(t, err)
// Verify all entities are archived due to cascade
err = organizationDB.Get(ctx, accountRef, organization.ID, &retrievedOrganization)
require.NoError(t, err)
assert.True(t, retrievedOrganization.IsArchived())
err = projectDBAdapter.Unprotected().Get(ctx, project.ID, &retrievedProject)
require.NoError(t, err)
assert.True(t, retrievedProject.IsArchived())
err = taskDB.Unprotected().Get(ctx, task.ID, &retrievedTask)
require.NoError(t, err)
assert.True(t, retrievedTask.IsArchived())
err = commentDB.Unprotected().Get(ctx, comment.ID, &retrievedComment)
require.NoError(t, err)
assert.True(t, retrievedComment.IsArchived())
// Verify reaction still exists (reactions don't support archiving)
var retrievedReaction model.Reaction
err = reactionDB.Unprotected().Get(ctx, reaction.ID, &retrievedReaction)
require.NoError(t, err)
// Unarchive organization with cascade
err = organizationDB.SetArchived(ctx, accountRef, organization.ID, false, true)
require.NoError(t, err)
// Verify all entities are unarchived
err = organizationDB.Get(ctx, accountRef, organization.ID, &retrievedOrganization)
require.NoError(t, err)
assert.False(t, retrievedOrganization.IsArchived())
err = projectDBAdapter.Unprotected().Get(ctx, project.ID, &retrievedProject)
require.NoError(t, err)
assert.False(t, retrievedProject.IsArchived())
err = taskDB.Unprotected().Get(ctx, task.ID, &retrievedTask)
require.NoError(t, err)
assert.False(t, retrievedTask.IsArchived())
err = commentDB.Unprotected().Get(ctx, comment.ID, &retrievedComment)
require.NoError(t, err)
assert.False(t, retrievedComment.IsArchived())
// Clean up
err = reactionDB.Unprotected().Delete(ctx, reaction.ID)
require.NoError(t, err)
err = commentDB.Unprotected().Delete(ctx, comment.ID)
require.NoError(t, err)
err = taskDB.Unprotected().Delete(ctx, task.ID)
require.NoError(t, err)
err = projectDBAdapter.Unprotected().Delete(ctx, project.ID)
require.NoError(t, err)
err = organizationDB.Delete(ctx, accountRef, organization.ID)
require.NoError(t, err)
})
t.Run("SetArchived_NonExistentOrganization", func(t *testing.T) {
// Try to archive non-existent organization
nonExistentID := primitive.NewObjectID()
err := organizationDB.SetArchived(ctx, accountRef, nonExistentID, true, true)
assert.Error(t, err)
// Could be either no data or access denied error depending on the permission system
assert.True(t, errors.Is(err, merrors.ErrNoData) || errors.Is(err, merrors.ErrAccessDenied))
})
}

View File

@@ -0,0 +1,20 @@
package policiesdb
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/storable"
"github.com/tech/sendico/pkg/model"
mutil "github.com/tech/sendico/pkg/mutil/db"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *PoliciesDB) All(ctx context.Context, organizationRef primitive.ObjectID) ([]model.PolicyDescription, error) {
// all documents
filter := repository.Query().Or(
repository.Filter(storable.OrganizationRefField, nil),
repository.OrgFilter(organizationRef),
)
return mutil.GetObjects[model.PolicyDescription](ctx, db.Logger, filter, nil, db.Repository)
}

View File

@@ -0,0 +1,13 @@
package policiesdb
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
)
func (db *PoliciesDB) GetBuiltInPolicy(ctx context.Context, resourceType mservice.Type, policy *model.PolicyDescription) error {
return db.FindOne(ctx, repository.Filter("resourceTypes", resourceType), policy)
}

View File

@@ -0,0 +1,21 @@
package policiesdb
import (
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"go.mongodb.org/mongo-driver/mongo"
)
type PoliciesDB struct {
template.DBImp[*model.PolicyDescription]
}
func Create(logger mlogger.Logger, db *mongo.Database) (*PoliciesDB, error) {
p := &PoliciesDB{
DBImp: *template.Create[*model.PolicyDescription](logger, mservice.Policies, db),
}
return p, nil
}

View File

@@ -0,0 +1,353 @@
//go:build integration
// +build integration
package policiesdb_test
import (
"context"
"errors"
"testing"
"time"
// Your internal packages
"github.com/tech/sendico/pkg/db/internal/mongo/policiesdb"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/merrors"
// Model package (contains PolicyDescription + Describable)
"github.com/tech/sendico/pkg/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
// Testcontainers
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/mongodb"
"github.com/testcontainers/testcontainers-go/wait"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.uber.org/zap"
)
// Helper to terminate container
func terminate(t *testing.T, ctx context.Context, container *mongodb.MongoDBContainer) {
err := container.Terminate(ctx)
require.NoError(t, err, "failed to terminate MongoDB container")
}
// Helper to disconnect client
func disconnect(t *testing.T, ctx context.Context, client *mongo.Client) {
err := client.Disconnect(context.Background())
require.NoError(t, err, "failed to disconnect from MongoDB")
}
// Helper to drop the Policies collection
func cleanupCollection(t *testing.T, ctx context.Context, db *mongo.Database) {
// The actual collection name is typically the value returned by
// (&model.PolicyDescription{}).Collection(), or something similar.
// Make sure it matches what your code uses (often "policies" or "policyDescription").
err := db.Collection((&model.PolicyDescription{}).Collection()).Drop(ctx)
require.NoError(t, err, "failed to drop collection between sub-tests")
}
func TestPoliciesDB(t *testing.T) {
// Create context with reasonable timeout
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
// Start MongoDB test container
mongoC, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("root"),
mongodb.WithPassword("password"),
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
)
require.NoError(t, err, "failed to start MongoDB container")
defer terminate(t, ctx, mongoC)
// Get connection URI
mongoURI, err := mongoC.ConnectionString(ctx)
require.NoError(t, err, "failed to get connection string")
// Connect client
clientOpts := options.Client().ApplyURI(mongoURI)
client, err := mongo.Connect(ctx, clientOpts)
require.NoError(t, err, "failed to connect to MongoDB")
defer disconnect(t, ctx, client)
// Create test DB
db := client.Database("testdb")
// Use a no-op logger (or real logger if you prefer)
logger := zap.NewNop()
// Create an instance of PoliciesDB
pdb, err := policiesdb.Create(logger, db)
require.NoError(t, err, "unexpected error creating PoliciesDB")
// ---------------------------------------------------------
// Each sub-test below starts by dropping the collection.
// ---------------------------------------------------------
t.Run("CreateAndGet", func(t *testing.T) {
cleanupCollection(t, ctx, db) // ensure no leftover data
desc := "Test policy description"
policy := &model.PolicyDescription{
Describable: model.Describable{
Name: "TestPolicy",
Description: &desc,
},
}
require.NoError(t, pdb.Create(ctx, policy))
result := &model.PolicyDescription{}
err := pdb.Get(ctx, policy.ID, result)
require.NoError(t, err)
assert.Equal(t, policy.ID, result.ID)
assert.Equal(t, "TestPolicy", result.Name)
assert.NotNil(t, result.Description)
assert.Equal(t, "Test policy description", *result.Description)
})
t.Run("Get_NotFound", func(t *testing.T) {
cleanupCollection(t, ctx, db)
// Attempt to get a non-existent ID
nonExistentID := primitive.NewObjectID()
result := &model.PolicyDescription{}
err := pdb.Get(ctx, nonExistentID, result)
assert.Error(t, err)
assert.True(t, errors.Is(err, merrors.ErrNoData))
})
t.Run("Update", func(t *testing.T) {
cleanupCollection(t, ctx, db)
originalDesc := "Original description"
policy := &model.PolicyDescription{
Describable: model.Describable{
Name: "OriginalName",
Description: &originalDesc,
},
}
require.NoError(t, pdb.Create(ctx, policy))
newDesc := "Updated description"
policy.Name = "UpdatedName"
policy.Description = &newDesc
err := pdb.Update(ctx, policy)
require.NoError(t, err)
updated := &model.PolicyDescription{}
err = pdb.Get(ctx, policy.ID, updated)
require.NoError(t, err)
assert.Equal(t, "UpdatedName", updated.Name)
assert.NotNil(t, updated.Description)
assert.Equal(t, "Updated description", *updated.Description)
})
t.Run("Delete", func(t *testing.T) {
cleanupCollection(t, ctx, db)
desc := "To be deleted"
policy := &model.PolicyDescription{
Describable: model.Describable{
Name: "WillDelete",
Description: &desc,
},
}
require.NoError(t, pdb.Create(ctx, policy))
err := pdb.Delete(ctx, policy.ID)
require.NoError(t, err)
deleted := &model.PolicyDescription{}
err = pdb.Get(ctx, policy.ID, deleted)
assert.Error(t, err)
assert.True(t, errors.Is(err, merrors.ErrNoData))
})
t.Run("DeleteMany", func(t *testing.T) {
cleanupCollection(t, ctx, db)
desc1 := "Will be deleted 1"
desc2 := "Will be deleted 2"
pol1 := &model.PolicyDescription{
Describable: model.Describable{
Name: "BatchDelete1",
Description: &desc1,
},
}
pol2 := &model.PolicyDescription{
Describable: model.Describable{
Name: "BatchDelete2",
Description: &desc2,
},
}
require.NoError(t, pdb.Create(ctx, pol1))
require.NoError(t, pdb.Create(ctx, pol2))
q := repository.Query().RegEx(repository.Field("description"), "^Will be deleted", "")
err := pdb.DeleteMany(ctx, q)
require.NoError(t, err)
res1 := &model.PolicyDescription{}
err1 := pdb.Get(ctx, pol1.ID, res1)
assert.Error(t, err1)
assert.True(t, errors.Is(err1, merrors.ErrNoData))
res2 := &model.PolicyDescription{}
err2 := pdb.Get(ctx, pol2.ID, res2)
assert.Error(t, err2)
assert.True(t, errors.Is(err2, merrors.ErrNoData))
})
t.Run("FindOne", func(t *testing.T) {
cleanupCollection(t, ctx, db)
desc := "Unique find test"
policy := &model.PolicyDescription{
Describable: model.Describable{
Name: "FindOneTest",
Description: &desc,
},
}
require.NoError(t, pdb.Create(ctx, policy))
// Match by name == "FindOneTest"
q := repository.Query().Comparison(repository.Field("name"), builder.Eq, "FindOneTest")
found := &model.PolicyDescription{}
err := pdb.FindOne(ctx, q, found)
require.NoError(t, err)
assert.Equal(t, policy.ID, found.ID)
assert.Equal(t, "FindOneTest", found.Name)
assert.NotNil(t, found.Description)
assert.Equal(t, "Unique find test", *found.Description)
})
t.Run("All", func(t *testing.T) {
cleanupCollection(t, ctx, db)
// Insert some policies (orgA, orgB, nil org)
orgA := primitive.NewObjectID()
orgB := primitive.NewObjectID()
descA := "Org A policy"
policyA := &model.PolicyDescription{
Describable: model.Describable{
Name: "PolicyA",
Description: &descA,
},
OrganizationRef: &orgA, // belongs to orgA
}
descB := "Org B policy"
policyB := &model.PolicyDescription{
Describable: model.Describable{
Name: "PolicyB",
Description: &descB,
},
OrganizationRef: &orgB, // belongs to orgB
}
descNil := "No org policy"
policyNil := &model.PolicyDescription{
Describable: model.Describable{
Name: "PolicyNil",
Description: &descNil,
},
// nil => built-in
}
require.NoError(t, pdb.Create(ctx, policyA))
require.NoError(t, pdb.Create(ctx, policyB))
require.NoError(t, pdb.Create(ctx, policyNil))
// Suppose the requirement is: "All" returns
// - policies for the requested org
// - plus built-in (nil) ones
resultsA, err := pdb.All(ctx, orgA)
require.NoError(t, err)
require.Len(t, resultsA, 2) // orgA + built-in
var idsA []primitive.ObjectID
for _, r := range resultsA {
idsA = append(idsA, r.ID)
}
assert.Contains(t, idsA, policyA.ID)
assert.Contains(t, idsA, policyNil.ID)
assert.NotContains(t, idsA, policyB.ID)
resultsB, err := pdb.All(ctx, orgB)
require.NoError(t, err)
require.Len(t, resultsB, 2) // orgB + built-in
var idsB []primitive.ObjectID
for _, r := range resultsB {
idsB = append(idsB, r.ID)
}
assert.Contains(t, idsB, policyB.ID)
assert.Contains(t, idsB, policyNil.ID)
assert.NotContains(t, idsB, policyA.ID)
})
t.Run("Policies", func(t *testing.T) {
cleanupCollection(t, ctx, db)
desc1 := "PolicyOne"
pol1 := &model.PolicyDescription{
Describable: model.Describable{
Name: "PolicyOne",
Description: &desc1,
},
}
desc2 := "PolicyTwo"
pol2 := &model.PolicyDescription{
Describable: model.Describable{
Name: "PolicyTwo",
Description: &desc2,
},
}
desc3 := "PolicyThree"
pol3 := &model.PolicyDescription{
Describable: model.Describable{
Name: "PolicyThree",
Description: &desc3,
},
}
require.NoError(t, pdb.Create(ctx, pol1))
require.NoError(t, pdb.Create(ctx, pol2))
require.NoError(t, pdb.Create(ctx, pol3))
// 1) Request pol1, pol2
results12, err := pdb.Policies(ctx, []primitive.ObjectID{pol1.ID, pol2.ID})
require.NoError(t, err)
require.Len(t, results12, 2)
// IDs might be out of order, so we do a set-like check
var set12 []primitive.ObjectID
for _, r := range results12 {
set12 = append(set12, r.ID)
}
assert.Contains(t, set12, pol1.ID)
assert.Contains(t, set12, pol2.ID)
// 2) Request pol1, pol3, plus a random ID
fakeID := primitive.NewObjectID()
results13Fake, err := pdb.Policies(ctx, []primitive.ObjectID{pol1.ID, pol3.ID, fakeID})
require.NoError(t, err)
require.Len(t, results13Fake, 2) // pol1 + pol3 only
var set13Fake []primitive.ObjectID
for _, r := range results13Fake {
set13Fake = append(set13Fake, r.ID)
}
assert.Contains(t, set13Fake, pol1.ID)
assert.Contains(t, set13Fake, pol3.ID)
// 3) Request with empty slice => expect no results
resultsEmpty, err := pdb.Policies(ctx, []primitive.ObjectID{})
require.NoError(t, err)
assert.Len(t, resultsEmpty, 0)
})
}

View File

@@ -0,0 +1,18 @@
package policiesdb
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/model"
mutil "github.com/tech/sendico/pkg/mutil/db"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *PoliciesDB) Policies(ctx context.Context, refs []primitive.ObjectID) ([]model.PolicyDescription, error) {
if len(refs) == 0 {
return []model.PolicyDescription{}, nil
}
filter := repository.Query().In(repository.IDField(), refs)
return mutil.GetObjects[model.PolicyDescription](ctx, db.Logger, filter, nil, db.Repository)
}

View File

@@ -0,0 +1,12 @@
package refreshtokensdb
import (
"context"
"github.com/tech/sendico/pkg/model"
)
func (db *RefreshTokenDB) GetClient(ctx context.Context, clientID string) (*model.Client, error) {
var client model.Client
return &client, db.clients.FindOneByFilter(ctx, filterByClientId(clientID), &client)
}

View File

@@ -0,0 +1,122 @@
package refreshtokensdb
import (
"context"
"errors"
"time"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"github.com/tech/sendico/pkg/mutil/mzap"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.uber.org/zap"
)
func (db *RefreshTokenDB) Create(ctx context.Context, rt *model.RefreshToken) error {
// First, try to find an existing token for this account/client/device combination
var existing model.RefreshToken
if rt.AccountRef == nil {
return merrors.InvalidArgument("Account reference must have a vaild value")
}
if err := db.FindOne(ctx, filterByAccount(*rt.AccountRef, &rt.SessionIdentifier), &existing); err != nil {
if errors.Is(err, merrors.ErrNoData) {
// No existing token, create a new one
db.Logger.Info("Registering refresh token", zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
return db.DBImp.Create(ctx, rt)
}
db.Logger.Warn("Something went wrong when checking existing sessions", zap.Error(err),
zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
return err
}
// Token already exists, update it with new values
db.Logger.Info("Updating existing refresh token", zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
patch := repository.Patch().
Set(repository.Field(TokenField), rt.RefreshToken).
Set(repository.Field(ExpiresAtField), rt.ExpiresAt).
Set(repository.Field(UserAgentField), rt.UserAgent).
Set(repository.Field(IPAddressField), rt.IPAddress).
Set(repository.Field(LastUsedAtField), rt.LastUsedAt).
Set(repository.Field(IsRevokedField), rt.IsRevoked)
if err := db.Patch(ctx, *existing.GetID(), patch); err != nil {
db.Logger.Warn("Failed to patch refresh token", zap.Error(err), zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
return err
}
// Update the ID of the input token to match the existing one
rt.SetID(*existing.GetID())
return nil
}
func (db *RefreshTokenDB) Update(ctx context.Context, rt *model.RefreshToken) error {
rt.LastUsedAt = time.Now()
// Use Patch instead of Update to avoid race conditions
patch := repository.Patch().
Set(repository.Field(TokenField), rt.RefreshToken).
Set(repository.Field(ExpiresAtField), rt.ExpiresAt).
Set(repository.Field(UserAgentField), rt.UserAgent).
Set(repository.Field(IPAddressField), rt.IPAddress).
Set(repository.Field(LastUsedAtField), rt.LastUsedAt).
Set(repository.Field(IsRevokedField), rt.IsRevoked)
return db.Patch(ctx, *rt.GetID(), patch)
}
func (db *RefreshTokenDB) Delete(ctx context.Context, tokenRef primitive.ObjectID) error {
db.Logger.Info("Deleting refresh token", mzap.ObjRef("refresh_token_ref", tokenRef))
return db.DBImp.Delete(ctx, tokenRef)
}
func (db *RefreshTokenDB) Revoke(ctx context.Context, accountRef primitive.ObjectID, session *model.SessionIdentifier) error {
var rt model.RefreshToken
f := filterByAccount(accountRef, session)
if err := db.Repository.FindOneByFilter(ctx, f, &rt); err != nil {
if errors.Is(err, merrors.ErrNoData) {
db.Logger.Warn("Failed to find refresh token", zap.Error(err),
mzap.ObjRef("account_ref", accountRef), zap.String("client_id", session.ClientID), zap.String("device_id", session.DeviceID))
return nil
}
return err
}
// Use Patch to update the revocation status
patch := repository.Patch().
Set(repository.Field(IsRevokedField), true).
Set(repository.Field(LastUsedAtField), time.Now())
return db.Patch(ctx, *rt.GetID(), patch)
}
func (db *RefreshTokenDB) GetByCRT(ctx context.Context, t *model.ClientRefreshToken) (*model.RefreshToken, error) {
var rt model.RefreshToken
f := filter(&t.SessionIdentifier).And(repository.Query().Filter(repository.Field("token"), t.RefreshToken))
if err := db.Repository.FindOneByFilter(ctx, f, &rt); err != nil {
if !errors.Is(err, merrors.ErrNoData) {
db.Logger.Warn("Failed to fetch refresh token", zap.Error(err),
zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID))
}
return nil, err
}
// Check if token is expired
if rt.ExpiresAt.Before(time.Now()) {
db.Logger.Warn("Refresh token expired", mzap.StorableRef(&rt),
zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID),
zap.Time("expires_at", rt.ExpiresAt))
return nil, merrors.AccessDenied(mservice.RefreshTokens, string(model.ActionRead), *rt.GetID())
}
// Check if token is revoked
if rt.IsRevoked {
db.Logger.Warn("Refresh token is revoked", mzap.StorableRef(&rt),
zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID))
return nil, merrors.ErrNoData
}
return &rt, nil
}

View File

@@ -0,0 +1,62 @@
package refreshtokensdb
import (
"github.com/tech/sendico/pkg/db/repository"
ri "github.com/tech/sendico/pkg/db/repository/index"
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"go.mongodb.org/mongo-driver/mongo"
"go.uber.org/zap"
)
type RefreshTokenDB struct {
template.DBImp[*model.RefreshToken]
clients repository.Repository
}
func Create(logger mlogger.Logger, db *mongo.Database) (*RefreshTokenDB, error) {
p := &RefreshTokenDB{
DBImp: *template.Create[*model.RefreshToken](logger, mservice.RefreshTokens, db),
clients: repository.CreateMongoRepository(db, mservice.Clients),
}
if err := p.Repository.CreateIndex(&ri.Definition{
Keys: []ri.Key{{Field: "token", Sort: ri.Asc}},
Unique: true,
}); err != nil {
p.Logger.Error("Failed to create unique token index", zap.Error(err))
return nil, err
}
// Add unique constraint on account/client/device combination
if err := p.Repository.CreateIndex(&ri.Definition{
Keys: []ri.Key{
{Field: "accountRef", Sort: ri.Asc},
{Field: "clientId", Sort: ri.Asc},
{Field: "deviceId", Sort: ri.Asc},
},
Unique: true,
}); err != nil {
p.Logger.Error("Failed to create unique account/client/device index", zap.Error(err))
return nil, err
}
if err := p.Repository.CreateIndex(&ri.Definition{
Keys: []ri.Key{{Field: IsRevokedField, Sort: ri.Asc}},
}); err != nil {
p.Logger.Error("Failed to create unique token revokation status index", zap.Error(err))
return nil, err
}
if err := p.clients.CreateIndex(&ri.Definition{
Keys: []ri.Key{{Field: "clientId", Sort: ri.Asc}},
Unique: true,
}); err != nil {
p.Logger.Error("Failed to create unique client identifier index", zap.Error(err))
return nil, err
}
return p, nil
}

View File

@@ -0,0 +1,10 @@
package refreshtokensdb
const (
ExpiresAtField = "expiresAt"
IsRevokedField = "isRevoked"
TokenField = "token"
UserAgentField = "userAgent"
IPAddressField = "ipAddress"
LastUsedAtField = "lastUsedAt"
)

View File

@@ -0,0 +1,25 @@
package refreshtokensdb
import (
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func filterByClientId(clientID string) builder.Query {
return repository.Query().Comparison(repository.Field("clientId"), builder.Eq, clientID)
}
func filter(session *model.SessionIdentifier) builder.Query {
filter := filterByClientId(session.ClientID)
filter.And(
repository.Query().Comparison(repository.Field("deviceId"), builder.Eq, session.DeviceID),
repository.Query().Comparison(repository.Field(IsRevokedField), builder.Eq, false),
)
return filter
}
func filterByAccount(accountRef primitive.ObjectID, session *model.SessionIdentifier) builder.Query {
return filter(session).And(repository.Query().Comparison(repository.AccountField(), builder.Eq, accountRef))
}

View File

@@ -0,0 +1,639 @@
//go:build integration
// +build integration
package refreshtokensdb_test
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/tech/sendico/pkg/db/internal/mongo/refreshtokensdb"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/merrors"
factory "github.com/tech/sendico/pkg/mlogger/factory"
"github.com/tech/sendico/pkg/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/mongodb"
"github.com/testcontainers/testcontainers-go/wait"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func setupTestDB(t *testing.T) (*refreshtokensdb.RefreshTokenDB, func()) {
// mark as helper for better test failure reporting
t.Helper()
startCtx, startCancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer startCancel()
mongoContainer, err := mongodb.Run(startCtx,
"mongo:latest",
mongodb.WithUsername("root"),
mongodb.WithPassword("password"),
testcontainers.WithWaitStrategy(wait.ForListeningPort("27017/tcp").WithStartupTimeout(2*time.Minute)),
)
require.NoError(t, err, "failed to start MongoDB container")
mongoURI, err := mongoContainer.ConnectionString(startCtx)
require.NoError(t, err, "failed to get MongoDB connection string")
clientOptions := options.Client().ApplyURI(mongoURI)
client, err := mongo.Connect(startCtx, clientOptions)
require.NoError(t, err, "failed to connect to MongoDB")
database := client.Database("test_refresh_tokens_" + t.Name())
logger := factory.NewLogger(true)
db, err := refreshtokensdb.Create(logger, database)
require.NoError(t, err, "failed to create refresh tokens db")
cleanup := func() {
_ = database.Drop(context.Background())
_ = client.Disconnect(context.Background())
termCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
_ = mongoContainer.Terminate(termCtx)
}
return db, cleanup
}
func createTestRefreshToken(accountRef primitive.ObjectID, clientID, deviceID, token string) *model.RefreshToken {
return &model.RefreshToken{
ClientRefreshToken: model.ClientRefreshToken{
SessionIdentifier: model.SessionIdentifier{
ClientID: clientID,
DeviceID: deviceID,
},
RefreshToken: token,
},
AccountBoundBase: model.AccountBoundBase{
AccountRef: &accountRef,
},
ExpiresAt: time.Now().Add(24 * time.Hour),
IsRevoked: false,
UserAgent: "TestUserAgent/1.0",
IPAddress: "192.168.1.1",
LastUsedAt: time.Now(),
}
}
func TestRefreshTokenDB_AuthenticationFlow(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
t.Run("Complete_User_Authentication_Flow", func(t *testing.T) {
// Setup: Create user and client
userID := primitive.NewObjectID()
clientID := "web-app"
deviceID := "user-desktop-chrome"
token := "refresh_token_12345"
// Step 1: User logs in - create initial refresh token
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
err := db.Create(ctx, refreshToken)
require.NoError(t, err)
// Step 2: User uses refresh token to get new access token
crt := &model.ClientRefreshToken{
SessionIdentifier: model.SessionIdentifier{
ClientID: clientID,
DeviceID: deviceID,
},
RefreshToken: token,
}
retrievedToken, err := db.GetByCRT(ctx, crt)
require.NoError(t, err)
require.NotNil(t, retrievedToken.AccountRef)
assert.Equal(t, userID, *retrievedToken.AccountRef)
assert.Equal(t, token, retrievedToken.RefreshToken)
assert.False(t, retrievedToken.IsRevoked)
// Step 3: User logs out - revoke the token
session := &model.SessionIdentifier{
ClientID: clientID,
DeviceID: deviceID,
}
err = db.Revoke(ctx, userID, session)
require.NoError(t, err)
// Step 4: Try to use revoked token - should fail
_, err = db.GetByCRT(ctx, crt)
assert.Error(t, err)
assert.True(t, errors.Is(err, merrors.ErrNoData))
})
t.Run("Manual_Token_Revocation_Workaround", func(t *testing.T) {
// Test manual revocation by directly updating the token
userID := primitive.NewObjectID()
clientID := "web-app"
deviceID := "user-desktop-chrome"
token := "manual_revoke_token_123"
// Step 1: Create token
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
err := db.Create(ctx, refreshToken)
require.NoError(t, err)
// Step 2: Manually revoke token by updating it directly
refreshToken.IsRevoked = true
err = db.Update(ctx, refreshToken)
require.NoError(t, err)
// Step 3: Try to use revoked token - should fail
crt := &model.ClientRefreshToken{
SessionIdentifier: model.SessionIdentifier{
ClientID: clientID,
DeviceID: deviceID,
},
RefreshToken: token,
}
_, err = db.GetByCRT(ctx, crt)
assert.Error(t, err)
assert.True(t, errors.Is(err, merrors.ErrNoData))
})
}
func TestRefreshTokenDB_MultiDeviceManagement(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
t.Run("User_With_Multiple_Devices", func(t *testing.T) {
userID := primitive.NewObjectID()
clientID := "mobile-app"
// User logs in from phone
phoneToken := createTestRefreshToken(userID, clientID, "phone-ios", "phone_token_123")
err := db.Create(ctx, phoneToken)
require.NoError(t, err)
// User logs in from tablet
tabletToken := createTestRefreshToken(userID, clientID, "tablet-android", "tablet_token_456")
err = db.Create(ctx, tabletToken)
require.NoError(t, err)
// User logs in from desktop
desktopToken := createTestRefreshToken(userID, clientID, "desktop-windows", "desktop_token_789")
err = db.Create(ctx, desktopToken)
require.NoError(t, err)
// User wants to logout from all devices except current (phone)
err = db.RevokeAll(ctx, userID, "phone-ios")
require.NoError(t, err)
// Phone should still work
phoneCRT := &model.ClientRefreshToken{
SessionIdentifier: model.SessionIdentifier{
ClientID: clientID,
DeviceID: "phone-ios",
},
RefreshToken: "phone_token_123",
}
_, err = db.GetByCRT(ctx, phoneCRT)
require.NoError(t, err)
// Tablet and desktop should be revoked
tabletCRT := &model.ClientRefreshToken{
SessionIdentifier: model.SessionIdentifier{
ClientID: clientID,
DeviceID: "tablet-android",
},
RefreshToken: "tablet_token_456",
}
_, err = db.GetByCRT(ctx, tabletCRT)
assert.Error(t, err)
desktopCRT := &model.ClientRefreshToken{
SessionIdentifier: model.SessionIdentifier{
ClientID: clientID,
DeviceID: "desktop-windows",
},
RefreshToken: "desktop_token_789",
}
_, err = db.GetByCRT(ctx, desktopCRT)
assert.Error(t, err)
})
}
func TestRefreshTokenDB_TokenRotation(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
t.Run("Token_Rotation_On_Use", func(t *testing.T) {
userID := primitive.NewObjectID()
clientID := "web-app"
deviceID := "user-browser"
initialToken := "initial_token_123"
// Create initial token
refreshToken := createTestRefreshToken(userID, clientID, deviceID, initialToken)
err := db.Create(ctx, refreshToken)
require.NoError(t, err)
// Simulate small delay
time.Sleep(10 * time.Millisecond)
// Use token - should update LastUsedAt
crt := &model.ClientRefreshToken{
SessionIdentifier: model.SessionIdentifier{
ClientID: clientID,
DeviceID: deviceID,
},
RefreshToken: initialToken,
}
retrievedToken, err := db.GetByCRT(ctx, crt)
require.NoError(t, err)
// LastUsedAt is not updated by GetByCRT; validate token data instead
assert.Equal(t, initialToken, retrievedToken.RefreshToken)
// Create new token with rotated value (simulating token rotation)
newToken := "rotated_token_456"
retrievedToken.RefreshToken = newToken
err = db.Update(ctx, retrievedToken)
require.NoError(t, err)
// Old token should no longer work
_, err = db.GetByCRT(ctx, crt)
assert.Error(t, err)
// New token should work
newCRT := &model.ClientRefreshToken{
SessionIdentifier: model.SessionIdentifier{
ClientID: clientID,
DeviceID: deviceID,
},
RefreshToken: newToken,
}
_, err = db.GetByCRT(ctx, newCRT)
require.NoError(t, err)
})
}
func TestRefreshTokenDB_SessionReplacement(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
t.Run("User_Login_From_Same_Device_Twice", func(t *testing.T) {
userID := primitive.NewObjectID()
clientID := "web-app"
deviceID := "user-laptop"
// First login
firstToken := createTestRefreshToken(userID, clientID, deviceID, "first_token_123")
err := db.Create(ctx, firstToken)
require.NoError(t, err)
firstTokenID := *firstToken.GetID()
// Second login from same device - should replace existing token
secondToken := createTestRefreshToken(userID, clientID, deviceID, "second_token_456")
err = db.Create(ctx, secondToken)
require.NoError(t, err)
// Should reuse the same database record
assert.Equal(t, firstTokenID, *secondToken.GetID())
// First token should no longer work
firstCRT := &model.ClientRefreshToken{
SessionIdentifier: model.SessionIdentifier{
ClientID: clientID,
DeviceID: deviceID,
},
RefreshToken: "first_token_123",
}
_, err = db.GetByCRT(ctx, firstCRT)
assert.Error(t, err)
// Second token should work
secondCRT := &model.ClientRefreshToken{
SessionIdentifier: model.SessionIdentifier{
ClientID: clientID,
DeviceID: deviceID,
},
RefreshToken: "second_token_456",
}
_, err = db.GetByCRT(ctx, secondCRT)
require.NoError(t, err)
})
}
func TestRefreshTokenDB_ClientManagement(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
t.Run("Client_CRUD_Operations", func(t *testing.T) {
// Note: Client management is handled by a separate client database
// This test verifies that refresh tokens work with different client IDs
userID := primitive.NewObjectID()
// Create refresh tokens for different clients
webToken := createTestRefreshToken(userID, "web-app", "device1", "token1")
err := db.Create(ctx, webToken)
require.NoError(t, err)
mobileToken := createTestRefreshToken(userID, "mobile-app", "device2", "token2")
err = db.Create(ctx, mobileToken)
require.NoError(t, err)
// Verify tokens can be retrieved by client ID
webCRT := &model.ClientRefreshToken{
SessionIdentifier: model.SessionIdentifier{
ClientID: "web-app",
DeviceID: "device1",
},
RefreshToken: "token1",
}
retrievedToken, err := db.GetByCRT(ctx, webCRT)
require.NoError(t, err)
assert.Equal(t, "web-app", retrievedToken.ClientID)
assert.Equal(t, "device1", retrievedToken.DeviceID)
mobileCRT := &model.ClientRefreshToken{
SessionIdentifier: model.SessionIdentifier{
ClientID: "mobile-app",
DeviceID: "device2",
},
RefreshToken: "token2",
}
retrievedToken, err = db.GetByCRT(ctx, mobileCRT)
require.NoError(t, err)
assert.Equal(t, "mobile-app", retrievedToken.ClientID)
assert.Equal(t, "device2", retrievedToken.DeviceID)
})
}
func TestRefreshTokenDB_SecurityScenarios(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
t.Run("Token_Hijacking_Prevention", func(t *testing.T) {
userID := primitive.NewObjectID()
clientID := "web-app"
deviceID := "user-browser"
token := "hijacked_token_123"
// Create legitimate token
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
err := db.Create(ctx, refreshToken)
require.NoError(t, err)
// Simulate security concern - revoke token
session := &model.SessionIdentifier{
ClientID: clientID,
DeviceID: deviceID,
}
err = db.Revoke(ctx, userID, session)
require.NoError(t, err)
// Attacker tries to use hijacked token
crt := &model.ClientRefreshToken{
SessionIdentifier: model.SessionIdentifier{
ClientID: clientID,
DeviceID: deviceID,
},
RefreshToken: token,
}
_, err = db.GetByCRT(ctx, crt)
assert.Error(t, err)
assert.True(t, errors.Is(err, merrors.ErrNoData))
})
t.Run("Invalid_Token_Attempts", func(t *testing.T) {
// Try to use completely invalid token
crt := &model.ClientRefreshToken{
SessionIdentifier: model.SessionIdentifier{
ClientID: "invalid-client",
DeviceID: "invalid-device",
},
RefreshToken: "invalid_token_123",
}
_, err := db.GetByCRT(ctx, crt)
assert.Error(t, err)
assert.True(t, errors.Is(err, merrors.ErrNoData))
})
}
func TestRefreshTokenDB_ExpiredTokenHandling(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
t.Run("Expired_Token_Cleanup", func(t *testing.T) {
userID := primitive.NewObjectID()
clientID := "web-app"
deviceID := "user-device"
token := "expired_token_123"
// Create token that expires in the past
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
refreshToken.ExpiresAt = time.Now().Add(-1 * time.Hour) // Expired 1 hour ago
err := db.Create(ctx, refreshToken)
require.NoError(t, err)
// The token exists in database but is expired
var storedToken model.RefreshToken
err = db.Get(ctx, *refreshToken.GetID(), &storedToken)
require.NoError(t, err)
assert.True(t, storedToken.ExpiresAt.Before(time.Now()))
// Application should reject expired tokens
crt := &model.ClientRefreshToken{
SessionIdentifier: model.SessionIdentifier{
ClientID: clientID,
DeviceID: deviceID,
},
RefreshToken: token,
}
_, err = db.GetByCRT(ctx, crt)
assert.Error(t, err)
assert.True(t, errors.Is(err, merrors.ErrAccessDenied))
})
}
func TestRefreshTokenDB_ConcurrentAccess(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
t.Run("Concurrent_Token_Usage", func(t *testing.T) {
userID := primitive.NewObjectID()
clientID := "web-app"
deviceID := "user-device"
token := "concurrent_token_123"
// Create token
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
err := db.Create(ctx, refreshToken)
require.NoError(t, err)
crt := &model.ClientRefreshToken{
SessionIdentifier: model.SessionIdentifier{
ClientID: clientID,
DeviceID: deviceID,
},
RefreshToken: token,
}
// Simulate concurrent access
done := make(chan error, 2)
go func() {
_, err := db.GetByCRT(ctx, crt)
done <- err
}()
go func() {
_, err := db.GetByCRT(ctx, crt)
done <- err
}()
// Both operations should succeed
for i := 0; i < 2; i++ {
err := <-done
require.NoError(t, err)
}
})
}
func TestRefreshTokenDB_EdgeCases(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
t.Run("Delete_Token_By_ID", func(t *testing.T) {
userID := primitive.NewObjectID()
refreshToken := createTestRefreshToken(userID, "web-app", "device-1", "token_123")
err := db.Create(ctx, refreshToken)
require.NoError(t, err)
tokenID := *refreshToken.GetID()
// Delete token
err = db.Delete(ctx, tokenID)
require.NoError(t, err)
// Token should no longer exist
var result model.RefreshToken
err = db.Get(ctx, tokenID, &result)
assert.Error(t, err)
assert.True(t, errors.Is(err, merrors.ErrNoData))
})
t.Run("Revoke_Non_Existent_Token", func(t *testing.T) {
userID := primitive.NewObjectID()
session := &model.SessionIdentifier{
ClientID: "non-existent-client",
DeviceID: "non-existent-device",
}
err := db.Revoke(ctx, userID, session)
// Should handle gracefully for non-existent tokens
assert.NoError(t, err)
})
t.Run("RevokeAll_No_Other_Devices", func(t *testing.T) {
userID := primitive.NewObjectID()
clientID := "web-app"
deviceID := "only-device"
// Create single token
refreshToken := createTestRefreshToken(userID, clientID, deviceID, "token_123")
err := db.Create(ctx, refreshToken)
require.NoError(t, err)
// RevokeAll should not affect current device
err = db.RevokeAll(ctx, userID, deviceID)
require.NoError(t, err)
// Token should still work
crt := &model.ClientRefreshToken{
SessionIdentifier: model.SessionIdentifier{
ClientID: clientID,
DeviceID: deviceID,
},
RefreshToken: "token_123",
}
_, err = db.GetByCRT(ctx, crt)
require.NoError(t, err)
})
}
func TestRefreshTokenDB_DatabaseIndexes(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
ctx := context.Background()
t.Run("Unique_Token_Constraint", func(t *testing.T) {
userID1 := primitive.NewObjectID()
userID2 := primitive.NewObjectID()
token := "duplicate_token_123"
// Create first token
refreshToken1 := createTestRefreshToken(userID1, "client1", "device1", token)
err := db.Create(ctx, refreshToken1)
require.NoError(t, err)
// Try to create second token with same token value - should fail due to unique index
refreshToken2 := createTestRefreshToken(userID2, "client2", "device2", token)
err = db.Create(ctx, refreshToken2)
assert.Error(t, err)
assert.Contains(t, err.Error(), "duplicate")
})
t.Run("Query_Performance_By_Revocation_Status", func(t *testing.T) {
userID := primitive.NewObjectID()
clientID := "web-app"
// Create multiple tokens
for i := 0; i < 10; i++ {
token := createTestRefreshToken(userID, clientID,
fmt.Sprintf("device_%d", i), fmt.Sprintf("token_%d", i))
if i%2 == 0 {
token.IsRevoked = true
}
err := db.Create(ctx, token)
require.NoError(t, err)
}
// Query should efficiently filter by revocation status
query := repository.Query().
Filter(repository.AccountField(), userID).
And(repository.Query().Comparison(repository.Field(refreshtokensdb.IsRevokedField), builder.Eq, false))
ids, err := db.ListIDs(ctx, query)
require.NoError(t, err)
assert.Len(t, ids, 5) // Should find 5 non-revoked tokens
})
}

View File

@@ -0,0 +1,24 @@
package refreshtokensdb
import (
"context"
"time"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *RefreshTokenDB) RevokeAll(ctx context.Context, accountRef primitive.ObjectID, deviceID string) error {
query := repository.Query().
Filter(repository.AccountField(), accountRef).
And(repository.Query().Comparison(repository.Field("deviceId"), builder.Ne, deviceID)).
And(repository.Query().Comparison(repository.Field(IsRevokedField), builder.Eq, false))
patch := repository.Patch().
Set(repository.Field(ExpiresAtField), time.Now()).
Set(repository.Field(IsRevokedField), true)
_, err := db.Repository.PatchMany(ctx, query, patch)
return err
}

View File

@@ -0,0 +1,90 @@
package builderimp
import (
"github.com/tech/sendico/pkg/db/repository/builder"
"go.mongodb.org/mongo-driver/bson"
)
type literalAccumulatorImp struct {
op builder.MongoOperation
value any
}
func (a *literalAccumulatorImp) Build() bson.D {
return bson.D{{Key: string(a.op), Value: a.value}}
}
func NewAccumulator(op builder.MongoOperation, value any) builder.Accumulator {
return &literalAccumulatorImp{op: op, value: value}
}
func AddToSet(value builder.Expression) builder.Expression {
return newUnaryExpression(builder.AddToSet, value)
}
func Size(value builder.Expression) builder.Expression {
return newUnaryExpression(builder.Size, value)
}
func Ne(left, right builder.Expression) builder.Expression {
return newBinaryExpression(builder.Ne, left, right)
}
func Sum(value any) builder.Accumulator {
return NewAccumulator(builder.Sum, value)
}
func Avg(value any) builder.Accumulator {
return NewAccumulator(builder.Avg, value)
}
func Min(value any) builder.Accumulator {
return NewAccumulator(builder.Min, value)
}
func Max(value any) builder.Accumulator {
return NewAccumulator(builder.Max, value)
}
func Eq(left, right builder.Expression) builder.Expression {
return newBinaryExpression(builder.Eq, left, right)
}
func Gt(left, right builder.Expression) builder.Expression {
return newBinaryExpression(builder.Gt, left, right)
}
func Add(left, right builder.Accumulator) builder.Accumulator {
return newBinaryAccumulator(builder.Add, left, right)
}
func Subtract(left, right builder.Accumulator) builder.Accumulator {
return newBinaryAccumulator(builder.Subtract, left, right)
}
func Multiply(left, right builder.Accumulator) builder.Accumulator {
return newBinaryAccumulator(builder.Multiply, left, right)
}
func Divide(left, right builder.Accumulator) builder.Accumulator {
return newBinaryAccumulator(builder.Divide, left, right)
}
type binaryAccumulator struct {
op builder.MongoOperation
left builder.Accumulator
right builder.Accumulator
}
func newBinaryAccumulator(op builder.MongoOperation, left, right builder.Accumulator) builder.Accumulator {
return &binaryAccumulator{
op: op,
left: left,
right: right,
}
}
func (b *binaryAccumulator) Build() bson.D {
args := []any{b.left.Build(), b.right.Build()}
return bson.D{{Key: string(b.op), Value: args}}
}

View File

@@ -0,0 +1,102 @@
package builderimp
import (
"github.com/tech/sendico/pkg/db/repository/builder"
"go.mongodb.org/mongo-driver/bson"
)
type aliasImp struct {
lhs builder.Field
rhs any
}
func (a *aliasImp) Field() builder.Field {
return a.lhs
}
func (a *aliasImp) Build() bson.D {
return bson.D{{Key: a.lhs.Build(), Value: a.rhs}}
}
// 1. Null alias (_id: null)
func NewNullAlias(lhs builder.Field) builder.Alias {
return &aliasImp{lhs: lhs, rhs: nil}
}
func NewAlias(lhs builder.Field, rhs any) builder.Alias {
return &aliasImp{lhs: lhs, rhs: rhs}
}
// 2. Simple alias (_id: "$taskRef")
func NewSimpleAlias(lhs, rhs builder.Field) builder.Alias {
return &aliasImp{lhs: lhs, rhs: rhs.Build()}
}
// 3. Complex alias (_id: { aliasName: "$originalField", ... })
type ComplexAlias struct {
lhs builder.Field
rhs []builder.Alias // Correcting handling of slice of aliases
}
func (a *ComplexAlias) Field() builder.Field {
return a.lhs
}
func (a *ComplexAlias) Build() bson.D {
fieldMap := bson.M{}
for _, alias := range a.rhs {
// Each alias.Build() still returns a bson.D
aliasDoc := alias.Build()
// 1. Marshal the ordered D into raw BSON bytes
raw, err := bson.Marshal(aliasDoc)
if err != nil {
panic("Failed to marshal alias document: " + err.Error())
}
// 2. Unmarshal those bytes into an unordered M
var docM bson.M
if err := bson.Unmarshal(raw, &docM); err != nil {
panic("Failed to unmarshal alias document: " + err.Error())
}
// Merge into our accumulator
for k, v := range docM {
fieldMap[k] = v
}
}
return bson.D{{Key: a.lhs.Build(), Value: fieldMap}}
}
func NewComplexAlias(lhs builder.Field, rhs []builder.Alias) builder.Alias {
return &ComplexAlias{lhs: lhs, rhs: rhs}
}
type aliasesImp struct {
aliases []builder.Alias
}
func (a *aliasesImp) Field() builder.Field {
if len(a.aliases) > 0 {
return a.aliases[0].Field()
}
return NewFieldImp("")
}
func (a *aliasesImp) Build() bson.D {
results := make([]bson.D, 0)
for _, alias := range a.aliases {
results = append(results, alias.Build())
}
aliases := bson.D{}
for _, r := range results {
aliases = append(aliases, r...)
}
return aliases
}
func NewAliases(aliases ...builder.Alias) builder.Alias {
return &aliasesImp{aliases: aliases}
}

View File

@@ -0,0 +1,27 @@
package builderimp
import (
"github.com/tech/sendico/pkg/db/repository/builder"
"go.mongodb.org/mongo-driver/bson"
)
type arrayImp struct {
elements []builder.Expression
}
// Build renders the literal array:
//
// [ <expr1>, <expr2>, … ]
func (b *arrayImp) Build() bson.A {
arr := make(bson.A, len(b.elements))
for i, expr := range b.elements {
// each expr.Build() returns the raw value or subexpression
arr[i] = expr.Build()
}
return arr
}
// NewArray constructs a new array expression from the given subexpressions.
func NewArray(exprs ...builder.Expression) *arrayImp {
return &arrayImp{elements: exprs}
}

View File

@@ -0,0 +1,108 @@
package builderimp
import (
"reflect"
"github.com/tech/sendico/pkg/db/repository/builder"
"go.mongodb.org/mongo-driver/bson"
)
type literalExpression struct {
value any
}
func NewLiteralExpression(value any) builder.Expression {
return &literalExpression{value: value}
}
func (e *literalExpression) Build() any {
return bson.D{{Key: string(builder.Literal), Value: e.value}}
}
type variadicExpression struct {
op builder.MongoOperation
parts []builder.Expression
}
func (e *variadicExpression) Build() any {
args := make([]any, 0, len(e.parts))
for _, p := range e.parts {
args = append(args, p.Build())
}
return bson.D{{Key: string(e.op), Value: args}}
}
func newVariadicExpression(op builder.MongoOperation, exprs ...builder.Expression) builder.Expression {
return &variadicExpression{
op: op,
parts: exprs,
}
}
func newBinaryExpression(op builder.MongoOperation, left, right builder.Expression) builder.Expression {
return &variadicExpression{
op: op,
parts: []builder.Expression{left, right},
}
}
type unaryExpression struct {
op builder.MongoOperation
rhs builder.Expression
}
func (e *unaryExpression) Build() any {
return bson.D{{Key: string(e.op), Value: e.rhs.Build()}}
}
func newUnaryExpression(op builder.MongoOperation, right builder.Expression) builder.Expression {
return &unaryExpression{
op: op,
rhs: right,
}
}
type matchExpression struct {
op builder.MongoOperation
rhs builder.Expression
}
func (e *matchExpression) Build() any {
return bson.E{Key: string(e.op), Value: e.rhs.Build()}
}
func newMatchExpression(op builder.MongoOperation, right builder.Expression) builder.Expression {
return &matchExpression{
op: op,
rhs: right,
}
}
func InRef(value builder.Field) builder.Expression {
return newMatchExpression(builder.In, NewValue(NewRefFieldImp(value).Build()))
}
type inImpl struct {
values []any
}
func (e *inImpl) Build() any {
return bson.D{{Key: string(builder.In), Value: e.values}}
}
func In(values ...any) builder.Expression {
var flattenedValues []any
for _, v := range values {
switch reflect.TypeOf(v).Kind() {
case reflect.Slice:
slice := reflect.ValueOf(v)
for i := range slice.Len() {
flattenedValues = append(flattenedValues, slice.Index(i).Interface())
}
default:
flattenedValues = append(flattenedValues, v)
}
}
return &inImpl{values: flattenedValues}
}

View File

@@ -0,0 +1,71 @@
package builderimp
import (
"strings"
"github.com/tech/sendico/pkg/db/repository/builder"
)
type FieldImp struct {
fields []string
}
func (b *FieldImp) Dot(field string) builder.Field {
newFields := make([]string, len(b.fields), len(b.fields)+1)
copy(newFields, b.fields)
newFields = append(newFields, field)
return &FieldImp{fields: newFields}
}
func (b *FieldImp) CopyWith(field string) builder.Field {
copiedFields := make([]string, 0, len(b.fields)+1)
copiedFields = append(copiedFields, b.fields...)
copiedFields = append(copiedFields, field)
return &FieldImp{
fields: copiedFields,
}
}
func (b *FieldImp) Build() string {
return strings.Join(b.fields, ".")
}
func NewFieldImp(baseName string) builder.Field {
return &FieldImp{
fields: []string{baseName},
}
}
type RefField struct {
imp builder.Field
}
func (b *RefField) Build() string {
return "$" + b.imp.Build()
}
func (b *RefField) CopyWith(field string) builder.Field {
return &RefField{
imp: b.imp.CopyWith(field),
}
}
func (b *RefField) Dot(field string) builder.Field {
return &RefField{
imp: b.imp.Dot(field),
}
}
func NewRefFieldImp(field builder.Field) builder.Field {
return &RefField{
imp: field,
}
}
func NewRootRef() builder.Field {
return NewFieldImp("$$ROOT")
}
func NewRemoveRef() builder.Field {
return NewFieldImp("$$REMOVE")
}

View File

@@ -0,0 +1,137 @@
package builderimp
import (
"github.com/tech/sendico/pkg/db/repository/builder"
"go.mongodb.org/mongo-driver/bson"
)
type condImp struct {
condition builder.Expression
ifTrue any
ifFalse any
}
func (c *condImp) Build() any {
return bson.D{
{Key: string(builder.Cond), Value: bson.D{
{Key: "if", Value: c.condition.Build()},
{Key: "then", Value: c.ifTrue},
{Key: "else", Value: c.ifFalse},
}},
}
}
func NewCond(condition builder.Expression, ifTrue, ifFalse any) builder.Expression {
return &condImp{
condition: condition,
ifTrue: ifTrue,
ifFalse: ifFalse,
}
}
// setUnionImp implements builder.Expression but takes only builder.Array inputs.
type setUnionImp struct {
inputs []builder.Expression
}
// Build renders the $setUnion stage:
//
// { $setUnion: [ <array1>, <array2>, … ] }
func (s *setUnionImp) Build() any {
arr := make(bson.A, len(s.inputs))
for i, arrayExpr := range s.inputs {
arr[i] = arrayExpr.Build()
}
return bson.D{
{Key: string(builder.SetUnion), Value: arr},
}
}
// NewSetUnion constructs a new $setUnion expression from the given Arrays.
func NewSetUnion(arrays ...builder.Expression) builder.Expression {
return &setUnionImp{inputs: arrays}
}
type assignmentImp struct {
field builder.Field
expression builder.Expression
}
func (a *assignmentImp) Build() bson.D {
// Assign it to the given field name
return bson.D{
{Key: a.field.Build(), Value: a.expression.Build()},
}
}
// NewAssignment creates a projection assignment of the form:
//
// <field>: <expression>
func NewAssignment(field builder.Field, expression builder.Expression) builder.Projection {
return &assignmentImp{
field: field,
expression: expression,
}
}
type computeImp struct {
field builder.Field
expression builder.Expression
}
func (a *computeImp) Build() any {
return bson.D{
{Key: string(a.field.Build()), Value: a.expression.Build()},
}
}
func NewCompute(field builder.Field, expression builder.Expression) builder.Expression {
return &computeImp{
field: field,
expression: expression,
}
}
func NewIfNull(expression, replacement builder.Expression) builder.Expression {
return newBinaryExpression(builder.IfNull, expression, replacement)
}
func NewPush(expression builder.Expression) builder.Expression {
return newUnaryExpression(builder.Push, expression)
}
func NewAnd(exprs ...builder.Expression) builder.Expression {
return newVariadicExpression(builder.And, exprs...)
}
func NewOr(exprs ...builder.Expression) builder.Expression {
return newVariadicExpression(builder.Or, exprs...)
}
func NewEach(exprs ...builder.Expression) builder.Expression {
return newVariadicExpression(builder.Each, exprs...)
}
func NewLt(left, right builder.Expression) builder.Expression {
return newBinaryExpression(builder.Lt, left, right)
}
func NewNot(expression builder.Expression) builder.Expression {
return newUnaryExpression(builder.Not, expression)
}
func NewSum(expression builder.Expression) builder.Expression {
return newUnaryExpression(builder.Sum, expression)
}
func NewMin(expression builder.Expression) builder.Expression {
return newUnaryExpression(builder.Min, expression)
}
func First(expr builder.Expression) builder.Expression {
return newUnaryExpression(builder.First, expr)
}
func NewType(expr builder.Expression) builder.Expression {
return newUnaryExpression(builder.Type, expr)
}

View File

@@ -0,0 +1,35 @@
package builderimp
import (
"github.com/tech/sendico/pkg/db/repository/builder"
"go.mongodb.org/mongo-driver/bson"
)
type groupAccumulatorImp struct {
field builder.Field
acc builder.Accumulator
}
// NewGroupAccumulator creates a new GroupAccumulator for the given field using the specified operator and value.
func NewGroupAccumulator(field builder.Field, acc builder.Accumulator) builder.GroupAccumulator {
return &groupAccumulatorImp{
field: field,
acc: acc,
}
}
func (g *groupAccumulatorImp) Field() builder.Field {
return g.field
}
func (g *groupAccumulatorImp) Accumulator() builder.Accumulator {
return g.acc
}
// Build returns a bson.E element for this group accumulator.
func (g *groupAccumulatorImp) Build() bson.D {
return bson.D{{
Key: g.field.Build(),
Value: g.acc.Build(),
}}
}

View File

@@ -0,0 +1,60 @@
package builderimp
import (
"time"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/db/storable"
"go.mongodb.org/mongo-driver/bson"
)
type patchBuilder struct {
updates bson.D
}
func set(field builder.Field, value any) bson.E {
return bson.E{Key: string(builder.Set), Value: bson.D{{Key: field.Build(), Value: value}}}
}
func (u *patchBuilder) Set(field builder.Field, value any) builder.Patch {
u.updates = append(u.updates, set(field, value))
return u
}
func (u *patchBuilder) Inc(field builder.Field, value any) builder.Patch {
u.updates = append(u.updates, bson.E{Key: string(builder.Inc), Value: bson.D{{Key: field.Build(), Value: value}}})
return u
}
func (u *patchBuilder) Unset(field builder.Field) builder.Patch {
u.updates = append(u.updates, bson.E{Key: string(builder.Unset), Value: bson.D{{Key: field.Build(), Value: ""}}})
return u
}
func (u *patchBuilder) Rename(field builder.Field, newName string) builder.Patch {
u.updates = append(u.updates, bson.E{Key: string(builder.Rename), Value: bson.D{{Key: field.Build(), Value: newName}}})
return u
}
func (u *patchBuilder) Push(field builder.Field, value any) builder.Patch {
u.updates = append(u.updates, bson.E{Key: string(builder.Push), Value: bson.D{{Key: field.Build(), Value: value}}})
return u
}
func (u *patchBuilder) Pull(field builder.Field, value any) builder.Patch {
u.updates = append(u.updates, bson.E{Key: string(builder.Pull), Value: bson.D{{Key: field.Build(), Value: value}}})
return u
}
func (u *patchBuilder) AddToSet(field builder.Field, value any) builder.Patch {
u.updates = append(u.updates, bson.E{Key: string(builder.AddToSet), Value: bson.D{{Key: field.Build(), Value: value}}})
return u
}
func (u *patchBuilder) Build() bson.D {
return append(u.updates, set(NewFieldImp(storable.UpdatedAtField), time.Now()))
}
func NewPatchImp() builder.Patch {
return &patchBuilder{updates: bson.D{}}
}

View File

@@ -0,0 +1,131 @@
package builderimp
import (
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/mservice"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
)
type unwindOpts = builder.UnwindOpts
// UnwindOption is the same type defined in the builder package.
type UnwindOption = builder.UnwindOption
// NewUnwindOpts applies all UnwindOption's to a fresh unwindOpts.
func NewUnwindOpts(opts ...UnwindOption) *unwindOpts {
cfg := &unwindOpts{}
for _, opt := range opts {
opt(cfg)
}
return cfg
}
type PipelineImp struct {
pipeline mongo.Pipeline
}
func (b *PipelineImp) Match(filter builder.Query) builder.Pipeline {
b.pipeline = append(b.pipeline, filter.BuildPipeline())
return b
}
func (b *PipelineImp) Lookup(from mservice.Type, localField, foreignField, as builder.Field) builder.Pipeline {
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.Lookup), Value: bson.D{
{Key: string(builder.MKFrom), Value: from},
{Key: string(builder.MKLocalField), Value: localField.Build()},
{Key: string(builder.MKForeignField), Value: foreignField.Build()},
{Key: string(builder.MKAs), Value: as.Build()},
}}})
return b
}
func (b *PipelineImp) LookupWithPipeline(
from mservice.Type,
nested builder.Pipeline,
as builder.Field,
let *map[string]builder.Field,
) builder.Pipeline {
lookupStage := bson.D{
{Key: string(builder.MKFrom), Value: from},
{Key: string(builder.MKPipeline), Value: nested.Build()},
{Key: string(builder.MKAs), Value: as.Build()},
}
// only add "let" if provided and not empty
if let != nil && len(*let) > 0 {
letDoc := bson.D{}
for varName, fld := range *let {
letDoc = append(letDoc, bson.E{Key: varName, Value: fld.Build()})
}
lookupStage = append(lookupStage, bson.E{Key: string(builder.MKLet), Value: letDoc})
}
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.Lookup), Value: lookupStage}})
return b
}
func (b *PipelineImp) Unwind(path builder.Field, opts ...UnwindOption) builder.Pipeline {
cfg := NewUnwindOpts(opts...)
var stageValue interface{}
// if no options, shorthand
if !cfg.PreserveNullAndEmptyArrays && cfg.IncludeArrayIndex == "" {
stageValue = path.Build()
} else {
d := bson.D{{Key: string(builder.MKPath), Value: path.Build()}}
if cfg.PreserveNullAndEmptyArrays {
d = append(d, bson.E{Key: string(builder.MKPreserveNullAndEmptyArrays), Value: true})
}
if cfg.IncludeArrayIndex != "" {
d = append(d, bson.E{Key: string(builder.MKIncludeArrayIndex), Value: cfg.IncludeArrayIndex})
}
stageValue = d
}
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.Unwind), Value: stageValue}})
return b
}
func (b *PipelineImp) Count(field builder.Field) builder.Pipeline {
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.Count), Value: field.Build()}})
return b
}
func (b *PipelineImp) Group(groupBy builder.Alias, accumulators ...builder.GroupAccumulator) builder.Pipeline {
groupDoc := groupBy.Build()
for _, acc := range accumulators {
groupDoc = append(groupDoc, acc.Build()...)
}
b.pipeline = append(b.pipeline, bson.D{
{Key: string(builder.Group), Value: groupDoc},
})
return b
}
func (b *PipelineImp) Project(projections ...builder.Projection) builder.Pipeline {
projDoc := bson.D{}
for _, pr := range projections {
projDoc = append(projDoc, pr.Build()...)
}
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.Project), Value: projDoc}})
return b
}
func (b *PipelineImp) ReplaceRoot(newRoot builder.Expression) builder.Pipeline {
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.ReplaceRoot), Value: bson.D{
{Key: string(builder.MKNewRoot), Value: newRoot.Build()},
}}})
return b
}
func (b *PipelineImp) Build() mongo.Pipeline {
return b.pipeline
}
func NewPipelineImp() builder.Pipeline {
return &PipelineImp{
pipeline: mongo.Pipeline{},
}
}

View File

@@ -0,0 +1,563 @@
package builderimp
import (
"testing"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/mservice"
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func TestNewPipelineImp(t *testing.T) {
pipeline := NewPipelineImp()
assert.NotNil(t, pipeline)
assert.IsType(t, &PipelineImp{}, pipeline)
// Build should return empty pipeline initially
built := pipeline.Build()
assert.NotNil(t, built)
assert.Len(t, built, 0)
}
func TestPipelineImp_Match(t *testing.T) {
pipeline := NewPipelineImp()
mockQuery := &MockQuery{
buildPipeline: bson.D{{Key: "$match", Value: bson.D{{Key: "field", Value: "value"}}}},
}
result := pipeline.Match(mockQuery)
// Should return self for chaining
assert.Same(t, pipeline, result)
built := pipeline.Build()
assert.Len(t, built, 1)
assert.Equal(t, bson.D{{Key: "$match", Value: bson.D{{Key: "field", Value: "value"}}}}, built[0])
}
func TestPipelineImp_Lookup(t *testing.T) {
pipeline := NewPipelineImp()
mockLocalField := &MockField{build: "localField"}
mockForeignField := &MockField{build: "foreignField"}
mockAsField := &MockField{build: "asField"}
result := pipeline.Lookup(mservice.Projects, mockLocalField, mockForeignField, mockAsField)
// Should return self for chaining
assert.Same(t, pipeline, result)
built := pipeline.Build()
assert.Len(t, built, 1)
expected := bson.D{{Key: string(builder.Lookup), Value: bson.D{
{Key: string(builder.MKFrom), Value: mservice.Projects},
{Key: string(builder.MKLocalField), Value: "localField"},
{Key: string(builder.MKForeignField), Value: "foreignField"},
{Key: string(builder.MKAs), Value: "asField"},
}}}
assert.Equal(t, expected, built[0])
}
func TestPipelineImp_LookupWithPipeline_WithoutLet(t *testing.T) {
pipeline := NewPipelineImp()
mockNestedPipeline := &MockPipeline{
build: mongo.Pipeline{bson.D{{Key: "$match", Value: bson.D{{Key: "nested", Value: true}}}}},
}
mockAsField := &MockField{build: "asField"}
result := pipeline.LookupWithPipeline(mservice.Tasks, mockNestedPipeline, mockAsField, nil)
// Should return self for chaining
assert.Same(t, pipeline, result)
built := pipeline.Build()
assert.Len(t, built, 1)
expected := bson.D{{Key: string(builder.Lookup), Value: bson.D{
{Key: string(builder.MKFrom), Value: mservice.Tasks},
{Key: string(builder.MKPipeline), Value: mockNestedPipeline.build},
{Key: string(builder.MKAs), Value: "asField"},
}}}
assert.Equal(t, expected, built[0])
}
func TestPipelineImp_LookupWithPipeline_WithLet(t *testing.T) {
pipeline := NewPipelineImp()
mockNestedPipeline := &MockPipeline{
build: mongo.Pipeline{bson.D{{Key: "$match", Value: bson.D{{Key: "nested", Value: true}}}}},
}
mockAsField := &MockField{build: "asField"}
mockLetField := &MockField{build: "$_id"}
letVars := map[string]builder.Field{
"projRef": mockLetField,
}
result := pipeline.LookupWithPipeline(mservice.Tasks, mockNestedPipeline, mockAsField, &letVars)
// Should return self for chaining
assert.Same(t, pipeline, result)
built := pipeline.Build()
assert.Len(t, built, 1)
expected := bson.D{{Key: string(builder.Lookup), Value: bson.D{
{Key: string(builder.MKFrom), Value: mservice.Tasks},
{Key: string(builder.MKPipeline), Value: mockNestedPipeline.build},
{Key: string(builder.MKAs), Value: "asField"},
{Key: string(builder.MKLet), Value: bson.D{{Key: "projRef", Value: "$_id"}}},
}}}
assert.Equal(t, expected, built[0])
}
func TestPipelineImp_LookupWithPipeline_WithEmptyLet(t *testing.T) {
pipeline := NewPipelineImp()
mockNestedPipeline := &MockPipeline{
build: mongo.Pipeline{bson.D{{Key: "$match", Value: bson.D{{Key: "nested", Value: true}}}}},
}
mockAsField := &MockField{build: "asField"}
emptyLetVars := map[string]builder.Field{}
pipeline.LookupWithPipeline(mservice.Tasks, mockNestedPipeline, mockAsField, &emptyLetVars)
built := pipeline.Build()
assert.Len(t, built, 1)
// Should not include let field when empty
expected := bson.D{{Key: string(builder.Lookup), Value: bson.D{
{Key: string(builder.MKFrom), Value: mservice.Tasks},
{Key: string(builder.MKPipeline), Value: mockNestedPipeline.build},
{Key: string(builder.MKAs), Value: "asField"},
}}}
assert.Equal(t, expected, built[0])
}
func TestPipelineImp_Unwind_Simple(t *testing.T) {
pipeline := NewPipelineImp()
mockField := &MockField{build: "$array"}
result := pipeline.Unwind(mockField)
// Should return self for chaining
assert.Same(t, pipeline, result)
built := pipeline.Build()
assert.Len(t, built, 1)
expected := bson.D{{Key: string(builder.Unwind), Value: "$array"}}
assert.Equal(t, expected, built[0])
}
func TestPipelineImp_Unwind_WithPreserveNullAndEmptyArrays(t *testing.T) {
pipeline := NewPipelineImp()
mockField := &MockField{build: "$array"}
// Mock the UnwindOption function
preserveOpt := func(opts *builder.UnwindOpts) {
opts.PreserveNullAndEmptyArrays = true
}
pipeline.Unwind(mockField, preserveOpt)
built := pipeline.Build()
assert.Len(t, built, 1)
expected := bson.D{{Key: string(builder.Unwind), Value: bson.D{
{Key: string(builder.MKPath), Value: "$array"},
{Key: string(builder.MKPreserveNullAndEmptyArrays), Value: true},
}}}
assert.Equal(t, expected, built[0])
}
func TestPipelineImp_Unwind_WithIncludeArrayIndex(t *testing.T) {
pipeline := NewPipelineImp()
mockField := &MockField{build: "$array"}
// Mock the UnwindOption function
indexOpt := func(opts *builder.UnwindOpts) {
opts.IncludeArrayIndex = "arrayIndex"
}
pipeline.Unwind(mockField, indexOpt)
built := pipeline.Build()
assert.Len(t, built, 1)
expected := bson.D{{Key: string(builder.Unwind), Value: bson.D{
{Key: string(builder.MKPath), Value: "$array"},
{Key: string(builder.MKIncludeArrayIndex), Value: "arrayIndex"},
}}}
assert.Equal(t, expected, built[0])
}
func TestPipelineImp_Unwind_WithBothOptions(t *testing.T) {
pipeline := NewPipelineImp()
mockField := &MockField{build: "$array"}
// Mock the UnwindOption functions
preserveOpt := func(opts *builder.UnwindOpts) {
opts.PreserveNullAndEmptyArrays = true
}
indexOpt := func(opts *builder.UnwindOpts) {
opts.IncludeArrayIndex = "arrayIndex"
}
pipeline.Unwind(mockField, preserveOpt, indexOpt)
built := pipeline.Build()
assert.Len(t, built, 1)
expected := bson.D{{Key: string(builder.Unwind), Value: bson.D{
{Key: string(builder.MKPath), Value: "$array"},
{Key: string(builder.MKPreserveNullAndEmptyArrays), Value: true},
{Key: string(builder.MKIncludeArrayIndex), Value: "arrayIndex"},
}}}
assert.Equal(t, expected, built[0])
}
func TestPipelineImp_Count(t *testing.T) {
pipeline := NewPipelineImp()
mockField := &MockField{build: "totalCount"}
result := pipeline.Count(mockField)
// Should return self for chaining
assert.Same(t, pipeline, result)
built := pipeline.Build()
assert.Len(t, built, 1)
expected := bson.D{{Key: string(builder.Count), Value: "totalCount"}}
assert.Equal(t, expected, built[0])
}
func TestPipelineImp_Group(t *testing.T) {
pipeline := NewPipelineImp()
mockAlias := &MockAlias{
build: bson.D{{Key: "_id", Value: "$field"}},
field: &MockField{build: "_id"},
}
mockAccumulator := &MockGroupAccumulator{
build: bson.D{{Key: "count", Value: bson.D{{Key: "$sum", Value: 1}}}},
}
result := pipeline.Group(mockAlias, mockAccumulator)
// Should return self for chaining
assert.Same(t, pipeline, result)
built := pipeline.Build()
assert.Len(t, built, 1)
expected := bson.D{{Key: string(builder.Group), Value: bson.D{
{Key: "_id", Value: "$field"},
{Key: "count", Value: bson.D{{Key: "$sum", Value: 1}}},
}}}
assert.Equal(t, expected, built[0])
}
func TestPipelineImp_Group_MultipleAccumulators(t *testing.T) {
pipeline := NewPipelineImp()
mockAlias := &MockAlias{
build: bson.D{{Key: "_id", Value: "$field"}},
field: &MockField{build: "_id"},
}
mockAccumulator1 := &MockGroupAccumulator{
build: bson.D{{Key: "count", Value: bson.D{{Key: "$sum", Value: 1}}}},
}
mockAccumulator2 := &MockGroupAccumulator{
build: bson.D{{Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}}},
}
pipeline.Group(mockAlias, mockAccumulator1, mockAccumulator2)
built := pipeline.Build()
assert.Len(t, built, 1)
expected := bson.D{{Key: string(builder.Group), Value: bson.D{
{Key: "_id", Value: "$field"},
{Key: "count", Value: bson.D{{Key: "$sum", Value: 1}}},
{Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}},
}}}
assert.Equal(t, expected, built[0])
}
func TestPipelineImp_Project(t *testing.T) {
pipeline := NewPipelineImp()
mockProjection := &MockProjection{
build: bson.D{{Key: "field1", Value: 1}},
}
result := pipeline.Project(mockProjection)
// Should return self for chaining
assert.Same(t, pipeline, result)
built := pipeline.Build()
assert.Len(t, built, 1)
expected := bson.D{{Key: string(builder.Project), Value: bson.D{
{Key: "field1", Value: 1},
}}}
assert.Equal(t, expected, built[0])
}
func TestPipelineImp_Project_MultipleProjections(t *testing.T) {
pipeline := NewPipelineImp()
mockProjection1 := &MockProjection{
build: bson.D{{Key: "field1", Value: 1}},
}
mockProjection2 := &MockProjection{
build: bson.D{{Key: "field2", Value: 0}},
}
pipeline.Project(mockProjection1, mockProjection2)
built := pipeline.Build()
assert.Len(t, built, 1)
expected := bson.D{{Key: string(builder.Project), Value: bson.D{
{Key: "field1", Value: 1},
{Key: "field2", Value: 0},
}}}
assert.Equal(t, expected, built[0])
}
func TestPipelineImp_ChainedOperations(t *testing.T) {
pipeline := NewPipelineImp()
// Create mocks
mockQuery := &MockQuery{
buildPipeline: bson.D{{Key: "$match", Value: bson.D{{Key: "status", Value: "active"}}}},
}
mockLocalField := &MockField{build: "userId"}
mockForeignField := &MockField{build: "_id"}
mockAsField := &MockField{build: "user"}
mockUnwindField := &MockField{build: "$user"}
mockProjection := &MockProjection{
build: bson.D{{Key: "name", Value: "$user.name"}},
}
// Chain operations
result := pipeline.
Match(mockQuery).
Lookup(mservice.Accounts, mockLocalField, mockForeignField, mockAsField).
Unwind(mockUnwindField).
Project(mockProjection)
// Should return self for chaining
assert.Same(t, pipeline, result)
built := pipeline.Build()
assert.Len(t, built, 4)
// Verify each stage
assert.Equal(t, bson.D{{Key: "$match", Value: bson.D{{Key: "status", Value: "active"}}}}, built[0])
expectedLookup := bson.D{{Key: string(builder.Lookup), Value: bson.D{
{Key: string(builder.MKFrom), Value: mservice.Accounts},
{Key: string(builder.MKLocalField), Value: "userId"},
{Key: string(builder.MKForeignField), Value: "_id"},
{Key: string(builder.MKAs), Value: "user"},
}}}
assert.Equal(t, expectedLookup, built[1])
assert.Equal(t, bson.D{{Key: string(builder.Unwind), Value: "$user"}}, built[2])
expectedProject := bson.D{{Key: string(builder.Project), Value: bson.D{
{Key: "name", Value: "$user.name"},
}}}
assert.Equal(t, expectedProject, built[3])
}
func TestNewUnwindOpts(t *testing.T) {
t.Run("NoOptions", func(t *testing.T) {
opts := NewUnwindOpts()
assert.NotNil(t, opts)
assert.False(t, opts.PreserveNullAndEmptyArrays)
assert.Empty(t, opts.IncludeArrayIndex)
})
t.Run("WithPreserveOption", func(t *testing.T) {
preserveOpt := func(opts *builder.UnwindOpts) {
opts.PreserveNullAndEmptyArrays = true
}
opts := NewUnwindOpts(preserveOpt)
assert.True(t, opts.PreserveNullAndEmptyArrays)
assert.Empty(t, opts.IncludeArrayIndex)
})
t.Run("WithIndexOption", func(t *testing.T) {
indexOpt := func(opts *builder.UnwindOpts) {
opts.IncludeArrayIndex = "index"
}
opts := NewUnwindOpts(indexOpt)
assert.False(t, opts.PreserveNullAndEmptyArrays)
assert.Equal(t, "index", opts.IncludeArrayIndex)
})
t.Run("WithBothOptions", func(t *testing.T) {
preserveOpt := func(opts *builder.UnwindOpts) {
opts.PreserveNullAndEmptyArrays = true
}
indexOpt := func(opts *builder.UnwindOpts) {
opts.IncludeArrayIndex = "index"
}
opts := NewUnwindOpts(preserveOpt, indexOpt)
assert.True(t, opts.PreserveNullAndEmptyArrays)
assert.Equal(t, "index", opts.IncludeArrayIndex)
})
}
// Mock implementations for testing
type MockQuery struct {
buildPipeline bson.D
}
func (m *MockQuery) And(filters ...builder.Query) builder.Query { return m }
func (m *MockQuery) Or(filters ...builder.Query) builder.Query { return m }
func (m *MockQuery) Filter(field builder.Field, value any) builder.Query { return m }
func (m *MockQuery) Expression(value builder.Expression) builder.Query { return m }
func (m *MockQuery) Comparison(field builder.Field, operator builder.MongoOperation, value any) builder.Query {
return m
}
func (m *MockQuery) RegEx(field builder.Field, pattern, options string) builder.Query { return m }
func (m *MockQuery) In(field builder.Field, values ...any) builder.Query { return m }
func (m *MockQuery) NotIn(field builder.Field, values ...any) builder.Query { return m }
func (m *MockQuery) Sort(field builder.Field, ascending bool) builder.Query { return m }
func (m *MockQuery) Limit(limit *int64) builder.Query { return m }
func (m *MockQuery) Offset(offset *int64) builder.Query { return m }
func (m *MockQuery) Archived(isArchived *bool) builder.Query { return m }
func (m *MockQuery) BuildPipeline() bson.D { return m.buildPipeline }
func (m *MockQuery) BuildQuery() bson.D { return bson.D{} }
func (m *MockQuery) BuildOptions() *options.FindOptions { return &options.FindOptions{} }
type MockField struct {
build string
}
func (m *MockField) Dot(field string) builder.Field { return &MockField{build: m.build + "." + field} }
func (m *MockField) CopyWith(field string) builder.Field { return &MockField{build: field} }
func (m *MockField) Build() string { return m.build }
type MockPipeline struct {
build mongo.Pipeline
}
func (m *MockPipeline) Match(filter builder.Query) builder.Pipeline { return m }
func (m *MockPipeline) Lookup(from mservice.Type, localField, foreignField, as builder.Field) builder.Pipeline {
return m
}
func (m *MockPipeline) LookupWithPipeline(from mservice.Type, pipeline builder.Pipeline, as builder.Field, let *map[string]builder.Field) builder.Pipeline {
return m
}
func (m *MockPipeline) Unwind(path builder.Field, opts ...UnwindOption) builder.Pipeline { return m }
func (m *MockPipeline) Count(field builder.Field) builder.Pipeline { return m }
func (m *MockPipeline) Group(groupBy builder.Alias, accumulators ...builder.GroupAccumulator) builder.Pipeline {
return m
}
func (m *MockPipeline) Project(projections ...builder.Projection) builder.Pipeline { return m }
func (m *MockPipeline) ReplaceRoot(newRoot builder.Expression) builder.Pipeline { return m }
func (m *MockPipeline) Build() mongo.Pipeline { return m.build }
type MockAlias struct {
build bson.D
field builder.Field
}
func (m *MockAlias) Field() builder.Field { return m.field }
func (m *MockAlias) Build() bson.D { return m.build }
type MockGroupAccumulator struct {
build bson.D
}
func (m *MockGroupAccumulator) Build() bson.D { return m.build }
type MockProjection struct {
build bson.D
}
func (m *MockProjection) Build() bson.D { return m.build }
func TestPipelineImp_ReplaceRoot(t *testing.T) {
pipeline := NewPipelineImp()
mockExpr := &MockExpression{build: "$newRoot"}
result := pipeline.ReplaceRoot(mockExpr)
// Should return self for chaining
assert.Same(t, pipeline, result)
built := pipeline.Build()
assert.Len(t, built, 1)
expected := bson.D{{Key: string(builder.ReplaceRoot), Value: bson.D{
{Key: string(builder.MKNewRoot), Value: "$newRoot"},
}}}
assert.Equal(t, expected, built[0])
}
func TestPipelineImp_ReplaceRoot_WithNestedField(t *testing.T) {
pipeline := NewPipelineImp()
mockExpr := &MockExpression{build: "$document.data"}
pipeline.ReplaceRoot(mockExpr)
built := pipeline.Build()
assert.Len(t, built, 1)
expected := bson.D{{Key: string(builder.ReplaceRoot), Value: bson.D{
{Key: string(builder.MKNewRoot), Value: "$document.data"},
}}}
assert.Equal(t, expected, built[0])
}
func TestPipelineImp_ReplaceRoot_WithExpression(t *testing.T) {
pipeline := NewPipelineImp()
// Mock a complex expression like { $mergeObjects: [...] }
mockExpr := &MockExpression{build: bson.D{{Key: "$mergeObjects", Value: bson.A{"$field1", "$field2"}}}}
pipeline.ReplaceRoot(mockExpr)
built := pipeline.Build()
assert.Len(t, built, 1)
expected := bson.D{{Key: string(builder.ReplaceRoot), Value: bson.D{
{Key: string(builder.MKNewRoot), Value: bson.D{{Key: "$mergeObjects", Value: bson.A{"$field1", "$field2"}}}},
}}}
assert.Equal(t, expected, built[0])
}
type MockExpression struct {
build any
}
func (m *MockExpression) Build() any { return m.build }

View File

@@ -0,0 +1,97 @@
package builderimp
import (
"github.com/tech/sendico/pkg/db/repository/builder"
"go.mongodb.org/mongo-driver/bson"
)
// projectionExprImp is a concrete implementation of builder.Projection
// that projects a field using a custom expression.
type projectionExprImp struct {
expr builder.Expression // The expression for this projection.
field builder.Field // The field name for the projected field.
}
// Field returns the field being projected.
func (p *projectionExprImp) Field() builder.Field {
return p.field
}
// Expression returns the expression for the projection.
func (p *projectionExprImp) Expression() builder.Expression {
return p.expr
}
// Build returns the built expression. If no expression is provided, returns 1.
func (p *projectionExprImp) Build() bson.D {
if p.expr == nil {
return bson.D{{Key: p.field.Build(), Value: 1}}
}
return bson.D{{Key: p.field.Build(), Value: p.expr.Build()}}
}
// NewProjectionExpr creates a new Projection for a given field and expression.
func NewProjectionExpr(field builder.Field, expr builder.Expression) builder.Projection {
return &projectionExprImp{field: field, expr: expr}
}
// aliasProjectionImp is a concrete implementation of builder.Projection
// that projects an alias (renaming a field or expression).
type aliasProjectionImp struct {
alias builder.Alias // The alias for this projection.
}
// Field returns the field being projected (via the alias).
func (p *aliasProjectionImp) Field() builder.Field {
return p.alias.Field()
}
// Expression returns no additional expression for an alias projection.
func (p *aliasProjectionImp) Expression() builder.Expression {
return nil
}
// Build returns the built alias expression.
func (p *aliasProjectionImp) Build() bson.D {
return p.alias.Build()
}
// NewAliasProjection creates a new Projection that renames or wraps an existing field or expression.
func NewAliasProjection(alias builder.Alias) builder.Projection {
return &aliasProjectionImp{alias: alias}
}
// sinkProjectionImp is a simple include/exclude projection (0 or 1).
type sinkProjectionImp struct {
field builder.Field // The field name for the projected field.
val int // 1 to include, 0 to exclude.
}
// Expression returns no expression for a sink projection.
func (p *sinkProjectionImp) Expression() builder.Expression {
return nil
}
// Build returns the include/exclude projection.
func (p *sinkProjectionImp) Build() bson.D {
return bson.D{{Key: p.field.Build(), Value: p.val}}
}
// NewSinkProjection creates a new Projection that includes (true) or excludes (false) a field.
func NewSinkProjection(field builder.Field, include bool) builder.Projection {
val := 0
if include {
val = 1
}
return &sinkProjectionImp{field: field, val: val}
}
// IncludeField returns a projection including the given field.
func IncludeField(field builder.Field) builder.Projection {
return NewSinkProjection(field, true)
}
// ExcludeField returns a projection excluding the given field.
func ExcludeField(field builder.Field) builder.Projection {
return NewSinkProjection(field, false)
}

View File

@@ -0,0 +1,156 @@
package builderimp
import (
"reflect"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/db/storable"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo/options"
)
type QueryImp struct {
filter bson.D
sort bson.D
limit *int64
offset *int64
}
func (b *QueryImp) Filter(field builder.Field, value any) builder.Query {
b.filter = append(b.filter, bson.E{Key: field.Build(), Value: value})
return b
}
func (b *QueryImp) And(filters ...builder.Query) builder.Query {
andFilters := bson.A{}
for _, f := range filters {
andFilters = append(andFilters, f.BuildQuery())
}
b.filter = append(b.filter, bson.E{Key: string(builder.And), Value: andFilters})
return b
}
func (b *QueryImp) Or(filters ...builder.Query) builder.Query {
orFilters := bson.A{}
for _, f := range filters {
orFilters = append(orFilters, f.BuildQuery())
}
b.filter = append(b.filter, bson.E{Key: string(builder.Or), Value: orFilters})
return b
}
func (b *QueryImp) Comparison(field builder.Field, operator builder.MongoOperation, value any) builder.Query {
b.filter = append(b.filter, bson.E{Key: field.Build(), Value: bson.M{string(operator): value}})
return b
}
func (b *QueryImp) Expression(value builder.Expression) builder.Query {
b.filter = append(b.filter, bson.E{Key: string(builder.Expr), Value: value.Build()})
return b
}
func (b *QueryImp) RegEx(field builder.Field, pattern, options string) builder.Query {
b.filter = append(b.filter, bson.E{Key: field.Build(), Value: primitive.Regex{Pattern: pattern, Options: options}})
return b
}
func (b *QueryImp) opIn(field builder.Field, op builder.MongoOperation, values ...any) builder.Query {
var flattenedValues []any
for _, v := range values {
switch reflect.TypeOf(v).Kind() {
case reflect.Slice:
slice := reflect.ValueOf(v)
for i := range slice.Len() {
flattenedValues = append(flattenedValues, slice.Index(i).Interface())
}
default:
flattenedValues = append(flattenedValues, v)
}
}
b.filter = append(b.filter, bson.E{Key: field.Build(), Value: bson.M{string(op): flattenedValues}})
return b
}
func (b *QueryImp) NotIn(field builder.Field, values ...any) builder.Query {
return b.opIn(field, builder.NotIn, values...)
}
func (b *QueryImp) In(field builder.Field, values ...any) builder.Query {
return b.opIn(field, builder.In, values...)
}
func (b *QueryImp) Archived(isArchived *bool) builder.Query {
if isArchived == nil {
return b
}
return b.And(NewQueryImp().Filter(NewFieldImp(storable.IsArchivedField), *isArchived))
}
func (b *QueryImp) Sort(field builder.Field, ascending bool) builder.Query {
order := 1
if !ascending {
order = -1
}
b.sort = append(b.sort, bson.E{Key: field.Build(), Value: order})
return b
}
func (b *QueryImp) BuildPipeline() bson.D {
query := bson.D{}
if len(b.filter) > 0 {
query = append(query, bson.E{Key: string(builder.Match), Value: b.filter})
}
if len(b.sort) > 0 {
query = append(query, bson.E{Key: string(builder.Sort), Value: b.sort})
}
if b.limit != nil {
query = append(query, bson.E{Key: string(builder.Limit), Value: *b.limit})
}
if b.offset != nil {
query = append(query, bson.E{Key: string(builder.Skip), Value: *b.offset})
}
return query
}
func (b *QueryImp) BuildQuery() bson.D {
return b.filter
}
func (b *QueryImp) Limit(limit *int64) builder.Query {
b.limit = limit
return b
}
func (b *QueryImp) Offset(offset *int64) builder.Query {
b.offset = offset
return b
}
func (b *QueryImp) BuildOptions() *options.FindOptions {
opts := options.Find()
if b.limit != nil {
opts.SetLimit(*b.limit)
}
if b.offset != nil {
opts.SetSkip(*b.offset)
}
if len(b.sort) > 0 {
opts.SetSort(b.sort)
}
return opts
}
func NewQueryImp() builder.Query {
return &QueryImp{
filter: bson.D{},
sort: bson.D{},
}
}

View File

@@ -0,0 +1,17 @@
package builderimp
import (
"github.com/tech/sendico/pkg/db/repository/builder"
)
type valueImp struct {
value any
}
func (v *valueImp) Build() any {
return v.value
}
func NewValue(value any) builder.Value {
return &valueImp{value: value}
}

View File

@@ -0,0 +1,50 @@
package repositoryimp
import (
"context"
ri "github.com/tech/sendico/pkg/db/repository/index"
"github.com/tech/sendico/pkg/merrors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func (r *MongoRepository) CreateIndex(def *ri.Definition) error {
if r.collection == nil {
return merrors.NoData("data collection is not set")
}
if len(def.Keys) == 0 {
return merrors.InvalidArgument("Index definition has no keys")
}
// ----- build BSON keys --------------------------------------------------
keys := bson.D{}
for _, k := range def.Keys {
var value any
switch {
case k.Type != "":
value = k.Type // text, 2dsphere, …
case k.Sort == ri.Desc:
value = int8(-1)
default:
value = int8(1) // default to Asc
}
keys = append(keys, bson.E{Key: k.Field, Value: value})
}
opts := options.Index().
SetUnique(def.Unique)
if def.TTL != nil {
opts.SetExpireAfterSeconds(*def.TTL)
}
if def.Name != "" {
opts.SetName(def.Name)
}
_, err := r.collection.Indexes().CreateOne(
context.Background(),
mongo.IndexModel{Keys: keys, Options: opts},
)
return err
}

View File

@@ -0,0 +1,250 @@
package repositoryimp
import (
"context"
"errors"
"fmt"
"github.com/tech/sendico/pkg/db/repository/builder"
rd "github.com/tech/sendico/pkg/db/repository/decoder"
"github.com/tech/sendico/pkg/db/storable"
"github.com/tech/sendico/pkg/merrors"
"github.com/tech/sendico/pkg/model"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type MongoRepository struct {
collectionName string
collection *mongo.Collection
}
func idFilter(id primitive.ObjectID) bson.D {
return bson.D{
{Key: storable.IDField, Value: id},
}
}
func NewMongoRepository(db *mongo.Database, collection string) *MongoRepository {
return &MongoRepository{
collectionName: collection,
collection: db.Collection(collection),
}
}
func (r *MongoRepository) Collection() string {
return r.collectionName
}
func (r *MongoRepository) Insert(ctx context.Context, obj storable.Storable, getFilter builder.Query) error {
if (obj.GetID() == nil) || (obj.GetID().IsZero()) {
obj.SetID(primitive.NewObjectID())
}
obj.Update()
_, err := r.collection.InsertOne(ctx, obj)
if mongo.IsDuplicateKeyError(err) {
if getFilter != nil {
if err = r.FindOneByFilter(ctx, getFilter, obj); err != nil {
return err
}
}
return merrors.DataConflict("duplicate_key")
}
return err
}
func (r *MongoRepository) InsertMany(ctx context.Context, objects []storable.Storable) error {
if len(objects) == 0 {
return nil
}
docs := make([]interface{}, len(objects))
for i, obj := range objects {
if (obj.GetID() == nil) || (obj.GetID().IsZero()) {
obj.SetID(primitive.NewObjectID())
}
obj.Update()
docs[i] = obj
}
_, err := r.collection.InsertMany(ctx, docs)
return err
}
func (r *MongoRepository) findOneByFilterImp(ctx context.Context, filter bson.D, errMessage string, result storable.Storable) error {
err := r.collection.FindOne(ctx, filter).Decode(result)
if errors.Is(err, mongo.ErrNoDocuments) {
return merrors.NoData(errMessage)
}
return err
}
func (r *MongoRepository) Get(ctx context.Context, id primitive.ObjectID, result storable.Storable) error {
if id.IsZero() {
return merrors.InvalidArgument("zero id provided while fetching " + result.Collection())
}
return r.findOneByFilterImp(ctx, idFilter(id), fmt.Sprintf("%s with ID = %s not found", result.Collection(), id.Hex()), result)
}
type QueryFunc func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error)
func (r *MongoRepository) executeQuery(ctx context.Context, queryFunc QueryFunc, decoder rd.DecodingFunc) error {
cursor, err := queryFunc(ctx, r.collection)
if errors.Is(err, mongo.ErrNoDocuments) {
return merrors.NoData("no_items_in_array")
}
if err != nil {
return err
}
defer cursor.Close(ctx)
for cursor.Next(ctx) {
if err = decoder(cursor); err != nil {
return err
}
}
return nil
}
func (r *MongoRepository) Aggregate(ctx context.Context, pipeline builder.Pipeline, decoder rd.DecodingFunc) error {
queryFunc := func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error) {
return collection.Aggregate(ctx, pipeline.Build())
}
return r.executeQuery(ctx, queryFunc, decoder)
}
func (r *MongoRepository) FindManyByFilter(ctx context.Context, query builder.Query, decoder rd.DecodingFunc) error {
queryFunc := func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error) {
return collection.Find(ctx, query.BuildQuery(), query.BuildOptions())
}
return r.executeQuery(ctx, queryFunc, decoder)
}
func (r *MongoRepository) FindOneByFilter(ctx context.Context, query builder.Query, result storable.Storable) error {
return r.findOneByFilterImp(ctx, query.BuildQuery(), result.Collection()+" not found by filter", result)
}
func (r *MongoRepository) Update(ctx context.Context, obj storable.Storable) error {
obj.Update()
return r.collection.FindOneAndReplace(ctx, idFilter(*obj.GetID()), obj).Err()
}
func (r *MongoRepository) Patch(ctx context.Context, id primitive.ObjectID, patch builder.Patch) error {
if id.IsZero() {
return merrors.InvalidArgument("zero id provided while patching")
}
_, err := r.collection.UpdateByID(ctx, id, patch.Build())
return err
}
func (r *MongoRepository) PatchMany(ctx context.Context, query builder.Query, patch builder.Patch) (int, error) {
result, err := r.collection.UpdateMany(ctx, query.BuildQuery(), patch.Build())
if err != nil {
return 0, err
}
return int(result.ModifiedCount), nil
}
func (r *MongoRepository) ListIDs(ctx context.Context, query builder.Query) ([]primitive.ObjectID, error) {
filter := query.BuildQuery()
findOptions := options.Find().SetProjection(bson.M{storable.IDField: 1})
cursor, err := r.collection.Find(ctx, filter, findOptions)
if err != nil {
return nil, err
}
defer cursor.Close(ctx)
var ids []primitive.ObjectID
for cursor.Next(ctx) {
var doc struct {
ID primitive.ObjectID `bson:"_id"`
}
if err := cursor.Decode(&doc); err != nil {
return nil, err
}
ids = append(ids, doc.ID)
}
if err := cursor.Err(); err != nil {
return nil, err
}
return ids, nil
}
func (r *MongoRepository) ListPermissionBound(ctx context.Context, query builder.Query) ([]model.PermissionBoundStorable, error) {
filter := query.BuildQuery()
findOptions := options.Find().SetProjection(bson.M{
storable.IDField: 1,
storable.PermissionRefField: 1,
storable.OrganizationRefField: 1,
})
cursor, err := r.collection.Find(ctx, filter, findOptions)
if err != nil {
return nil, err
}
defer cursor.Close(ctx)
result := make([]model.PermissionBoundStorable, 0)
for cursor.Next(ctx) {
var doc model.PermissionBound
if err := cursor.Decode(&doc); err != nil {
return nil, err
}
result = append(result, &doc)
}
if err := cursor.Err(); err != nil {
return nil, err
}
return result, nil
}
func (r *MongoRepository) ListAccountBound(ctx context.Context, query builder.Query) ([]model.AccountBoundStorable, error) {
filter := query.BuildQuery()
findOptions := options.Find().SetProjection(bson.M{
storable.IDField: 1,
model.AccountRefField: 1,
model.OrganizationRefField: 1,
})
cursor, err := r.collection.Find(ctx, filter, findOptions)
if err != nil {
return nil, err
}
defer cursor.Close(ctx)
result := make([]model.AccountBoundStorable, 0)
for cursor.Next(ctx) {
var doc model.AccountBoundBase
if err := cursor.Decode(&doc); err != nil {
return nil, err
}
result = append(result, &doc)
}
if err := cursor.Err(); err != nil {
return nil, err
}
return result, nil
}
func (r *MongoRepository) Delete(ctx context.Context, id primitive.ObjectID) error {
_, err := r.collection.DeleteOne(ctx, idFilter(id))
return err
}
func (r *MongoRepository) DeleteMany(ctx context.Context, query builder.Query) error {
_, err := r.collection.DeleteMany(ctx, query.BuildQuery())
return err
}
func (r *MongoRepository) Name() string {
return r.collection.Name()
}

View File

@@ -0,0 +1,577 @@
//go:build integration
// +build integration
package repositoryimp_test
import (
"context"
"errors"
"testing"
"time"
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp/builderimp"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/merrors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/mongodb"
"github.com/testcontainers/testcontainers-go/wait"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func TestMongoRepository_Insert(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
mongoContainer, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("root"),
mongodb.WithPassword("password"),
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
)
require.NoError(t, err, "failed to start MongoDB container")
defer terminate(ctx, t, mongoContainer)
mongoURI, err := mongoContainer.ConnectionString(ctx)
require.NoError(t, err, "failed to get MongoDB connection string")
clientOptions := options.Client().ApplyURI(mongoURI)
client, err := mongo.Connect(ctx, clientOptions)
require.NoError(t, err, "failed to connect to MongoDB")
defer disconnect(ctx, t, client)
db := client.Database("testdb")
repository := repositoryimp.NewMongoRepository(db, "testcollection")
t.Run("Insert_WithoutID", func(t *testing.T) {
testObj := &TestObject{Name: "testInsert"}
// ID should be nil/zero initially
assert.True(t, testObj.GetID().IsZero())
err := repository.Insert(ctx, testObj, nil)
require.NoError(t, err)
// ID should be assigned after insert
assert.False(t, testObj.GetID().IsZero())
assert.NotEmpty(t, testObj.CreatedAt)
assert.NotEmpty(t, testObj.UpdatedAt)
})
t.Run("Insert_WithExistingID", func(t *testing.T) {
existingID := primitive.NewObjectID()
testObj := &TestObject{Name: "testInsertWithID"}
testObj.SetID(existingID)
err := repository.Insert(ctx, testObj, nil)
require.NoError(t, err)
// ID should remain the same
assert.Equal(t, existingID, *testObj.GetID())
})
t.Run("Insert_DuplicateKey", func(t *testing.T) {
// Insert first object
testObj1 := &TestObject{Name: "duplicate"}
err := repository.Insert(ctx, testObj1, nil)
require.NoError(t, err)
// Try to insert object with same ID
testObj2 := &TestObject{Name: "duplicate2"}
testObj2.SetID(*testObj1.GetID())
err = repository.Insert(ctx, testObj2, nil)
assert.Error(t, err)
assert.True(t, errors.Is(err, merrors.ErrDataConflict))
})
t.Run("Insert_DuplicateKeyWithGetFilter", func(t *testing.T) {
// Insert first object
testObj1 := &TestObject{Name: "duplicateWithFilter"}
err := repository.Insert(ctx, testObj1, nil)
require.NoError(t, err)
// Try to insert object with same ID, but with getFilter
testObj2 := &TestObject{Name: "duplicateWithFilter2"}
testObj2.SetID(*testObj1.GetID())
getFilter := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("_id"), builder.Eq, *testObj1.GetID())
err = repository.Insert(ctx, testObj2, getFilter)
assert.Error(t, err)
assert.True(t, errors.Is(err, merrors.ErrDataConflict))
// But testObj2 should be populated with the existing object data
assert.Equal(t, testObj1.Name, testObj2.Name)
})
}
func TestMongoRepository_Update(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
mongoContainer, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("root"),
mongodb.WithPassword("password"),
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
)
require.NoError(t, err, "failed to start MongoDB container")
defer terminate(ctx, t, mongoContainer)
mongoURI, err := mongoContainer.ConnectionString(ctx)
require.NoError(t, err, "failed to get MongoDB connection string")
clientOptions := options.Client().ApplyURI(mongoURI)
client, err := mongo.Connect(ctx, clientOptions)
require.NoError(t, err, "failed to connect to MongoDB")
defer disconnect(ctx, t, client)
db := client.Database("testdb")
repository := repositoryimp.NewMongoRepository(db, "testcollection")
t.Run("Update_ExistingObject", func(t *testing.T) {
// Insert object first
testObj := &TestObject{Name: "originalName"}
err := repository.Insert(ctx, testObj, nil)
require.NoError(t, err)
originalUpdatedAt := testObj.UpdatedAt
// Update the object
testObj.Name = "updatedName"
time.Sleep(10 * time.Millisecond) // Ensure time difference
err = repository.Update(ctx, testObj)
require.NoError(t, err)
// Verify the object was updated
result := &TestObject{}
err = repository.Get(ctx, *testObj.GetID(), result)
require.NoError(t, err)
assert.Equal(t, "updatedName", result.Name)
assert.True(t, result.UpdatedAt.After(originalUpdatedAt))
})
t.Run("Update_NonExistentObject", func(t *testing.T) {
nonExistentID := primitive.NewObjectID()
testObj := &TestObject{Name: "nonExistent"}
testObj.SetID(nonExistentID)
err := repository.Update(ctx, testObj)
assert.Error(t, err)
assert.True(t, errors.Is(err, mongo.ErrNoDocuments))
})
}
func TestMongoRepository_Delete(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
mongoContainer, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("root"),
mongodb.WithPassword("password"),
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
)
require.NoError(t, err, "failed to start MongoDB container")
defer terminate(ctx, t, mongoContainer)
mongoURI, err := mongoContainer.ConnectionString(ctx)
require.NoError(t, err, "failed to get MongoDB connection string")
clientOptions := options.Client().ApplyURI(mongoURI)
client, err := mongo.Connect(ctx, clientOptions)
require.NoError(t, err, "failed to connect to MongoDB")
defer disconnect(ctx, t, client)
db := client.Database("testdb")
repository := repositoryimp.NewMongoRepository(db, "testcollection")
t.Run("Delete_ExistingObject", func(t *testing.T) {
// Insert object first
testObj := &TestObject{Name: "toDelete"}
err := repository.Insert(ctx, testObj, nil)
require.NoError(t, err)
// Delete the object
err = repository.Delete(ctx, *testObj.GetID())
require.NoError(t, err)
// Verify the object was deleted
result := &TestObject{}
err = repository.Get(ctx, *testObj.GetID(), result)
assert.Error(t, err)
assert.True(t, errors.Is(err, merrors.ErrNoData))
})
t.Run("Delete_NonExistentObject", func(t *testing.T) {
nonExistentID := primitive.NewObjectID()
err := repository.Delete(ctx, nonExistentID)
// Delete should not return error even if object doesn't exist
assert.NoError(t, err)
})
}
func TestMongoRepository_FindOneByFilter(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
mongoContainer, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("root"),
mongodb.WithPassword("password"),
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
)
require.NoError(t, err, "failed to start MongoDB container")
defer terminate(ctx, t, mongoContainer)
mongoURI, err := mongoContainer.ConnectionString(ctx)
require.NoError(t, err, "failed to get MongoDB connection string")
clientOptions := options.Client().ApplyURI(mongoURI)
client, err := mongo.Connect(ctx, clientOptions)
require.NoError(t, err, "failed to connect to MongoDB")
defer disconnect(ctx, t, client)
db := client.Database("testdb")
repository := repositoryimp.NewMongoRepository(db, "testcollection")
t.Run("FindOneByFilter_MatchingFilter", func(t *testing.T) {
// Insert test objects
testObjs := []*TestObject{
{Name: "findMe"},
{Name: "dontFindMe"},
{Name: "findMeToo"},
}
for _, obj := range testObjs {
err := repository.Insert(ctx, obj, nil)
require.NoError(t, err)
}
// Find by filter
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("name"), builder.Eq, "findMe")
result := &TestObject{}
err := repository.FindOneByFilter(ctx, query, result)
require.NoError(t, err)
assert.Equal(t, "findMe", result.Name)
})
t.Run("FindOneByFilter_NoMatch", func(t *testing.T) {
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("name"), builder.Eq, "nonExistentName")
result := &TestObject{}
err := repository.FindOneByFilter(ctx, query, result)
assert.Error(t, err)
assert.True(t, errors.Is(err, merrors.ErrNoData))
})
}
func TestMongoRepository_FindManyByFilter(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
mongoContainer, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("root"),
mongodb.WithPassword("password"),
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
)
require.NoError(t, err, "failed to start MongoDB container")
defer terminate(ctx, t, mongoContainer)
mongoURI, err := mongoContainer.ConnectionString(ctx)
require.NoError(t, err, "failed to get MongoDB connection string")
clientOptions := options.Client().ApplyURI(mongoURI)
client, err := mongo.Connect(ctx, clientOptions)
require.NoError(t, err, "failed to connect to MongoDB")
defer disconnect(ctx, t, client)
db := client.Database("testdb")
repository := repositoryimp.NewMongoRepository(db, "testcollection")
t.Run("FindManyByFilter_MultipleResults", func(t *testing.T) {
// Insert test objects
testObjs := []*TestObject{
{Name: "findMany1"},
{Name: "findMany2"},
{Name: "dontFind"},
}
for _, obj := range testObjs {
err := repository.Insert(ctx, obj, nil)
require.NoError(t, err)
}
// Find objects with names starting with "findMany"
query := builderimp.NewQueryImp().RegEx(builderimp.NewFieldImp("name"), "^findMany", "")
var results []*TestObject
decoder := func(cursor *mongo.Cursor) error {
var obj TestObject
if err := cursor.Decode(&obj); err != nil {
return err
}
results = append(results, &obj)
return nil
}
err := repository.FindManyByFilter(ctx, query, decoder)
require.NoError(t, err)
assert.Len(t, results, 2)
names := make([]string, len(results))
for i, obj := range results {
names[i] = obj.Name
}
assert.Contains(t, names, "findMany1")
assert.Contains(t, names, "findMany2")
})
t.Run("FindManyByFilter_NoResults", func(t *testing.T) {
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("name"), builder.Eq, "nonExistentPattern")
var results []*TestObject
decoder := func(cursor *mongo.Cursor) error {
var obj TestObject
if err := cursor.Decode(&obj); err != nil {
return err
}
results = append(results, &obj)
return nil
}
err := repository.FindManyByFilter(ctx, query, decoder)
require.NoError(t, err)
assert.Empty(t, results)
})
}
func TestMongoRepository_DeleteMany(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
mongoContainer, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("root"),
mongodb.WithPassword("password"),
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
)
require.NoError(t, err, "failed to start MongoDB container")
defer terminate(ctx, t, mongoContainer)
mongoURI, err := mongoContainer.ConnectionString(ctx)
require.NoError(t, err, "failed to get MongoDB connection string")
clientOptions := options.Client().ApplyURI(mongoURI)
client, err := mongo.Connect(ctx, clientOptions)
require.NoError(t, err, "failed to connect to MongoDB")
defer disconnect(ctx, t, client)
db := client.Database("testdb")
repository := repositoryimp.NewMongoRepository(db, "testcollection")
t.Run("DeleteMany_MultipleDocuments", func(t *testing.T) {
// Insert test objects
testObjs := []*TestObject{
{Name: "deleteMany1"},
{Name: "deleteMany2"},
{Name: "keepMe"},
}
for _, obj := range testObjs {
err := repository.Insert(ctx, obj, nil)
require.NoError(t, err)
}
// Delete objects with names starting with "deleteMany"
query := builderimp.NewQueryImp().RegEx(builderimp.NewFieldImp("name"), "^deleteMany", "")
err := repository.DeleteMany(ctx, query)
require.NoError(t, err)
// Verify deletions
queryAll := builderimp.NewQueryImp()
var results []*TestObject
decoder := func(cursor *mongo.Cursor) error {
var obj TestObject
if err := cursor.Decode(&obj); err != nil {
return err
}
results = append(results, &obj)
return nil
}
err = repository.FindManyByFilter(ctx, queryAll, decoder)
require.NoError(t, err)
assert.Len(t, results, 1)
assert.Equal(t, "keepMe", results[0].Name)
})
}
func TestMongoRepository_Name(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
mongoContainer, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("root"),
mongodb.WithPassword("password"),
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
)
require.NoError(t, err, "failed to start MongoDB container")
defer terminate(ctx, t, mongoContainer)
mongoURI, err := mongoContainer.ConnectionString(ctx)
require.NoError(t, err, "failed to get MongoDB connection string")
clientOptions := options.Client().ApplyURI(mongoURI)
client, err := mongo.Connect(ctx, clientOptions)
require.NoError(t, err, "failed to connect to MongoDB")
defer disconnect(ctx, t, client)
db := client.Database("testdb")
repository := repositoryimp.NewMongoRepository(db, "mycollection")
t.Run("Name_ReturnsCollectionName", func(t *testing.T) {
name := repository.Name()
assert.Equal(t, "mycollection", name)
})
}
func TestMongoRepository_ListPermissionBound(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
mongoContainer, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("root"),
mongodb.WithPassword("password"),
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
)
require.NoError(t, err, "failed to start MongoDB container")
defer terminate(ctx, t, mongoContainer)
mongoURI, err := mongoContainer.ConnectionString(ctx)
require.NoError(t, err, "failed to get MongoDB connection string")
clientOptions := options.Client().ApplyURI(mongoURI)
client, err := mongo.Connect(ctx, clientOptions)
require.NoError(t, err, "failed to connect to MongoDB")
defer disconnect(ctx, t, client)
db := client.Database("testdb")
repository := repositoryimp.NewMongoRepository(db, "testcollection")
t.Run("ListPermissionBound_WithData", func(t *testing.T) {
// Insert test objects with permission bound data
orgID := primitive.NewObjectID()
// Insert documents directly with permission bound fields
_, err := db.Collection("testcollection").InsertMany(ctx, []interface{}{
bson.M{
"_id": primitive.NewObjectID(),
"organizationRef": orgID,
"permissionRef": primitive.NewObjectID(),
},
bson.M{
"_id": primitive.NewObjectID(),
"organizationRef": orgID,
"permissionRef": primitive.NewObjectID(),
},
})
require.NoError(t, err)
// Query for permission bound objects
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("organizationRef"), builder.Eq, orgID)
results, err := repository.ListPermissionBound(ctx, query)
require.NoError(t, err)
assert.Len(t, results, 2)
for _, result := range results {
assert.Equal(t, orgID, result.GetOrganizationRef())
assert.NotNil(t, result.GetPermissionRef())
}
})
t.Run("ListPermissionBound_EmptyResult", func(t *testing.T) {
nonExistentOrgID := primitive.NewObjectID()
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("organizationRef"), builder.Eq, nonExistentOrgID)
results, err := repository.ListPermissionBound(ctx, query)
require.NoError(t, err)
assert.Empty(t, results)
})
}
func TestMongoRepository_UpdateTimestamp(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
mongoContainer, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("root"),
mongodb.WithPassword("password"),
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
)
require.NoError(t, err, "failed to start MongoDB container")
defer terminate(ctx, t, mongoContainer)
mongoURI, err := mongoContainer.ConnectionString(ctx)
require.NoError(t, err, "failed to get MongoDB connection string")
clientOptions := options.Client().ApplyURI(mongoURI)
client, err := mongo.Connect(ctx, clientOptions)
require.NoError(t, err, "failed to connect to MongoDB")
defer disconnect(ctx, t, client)
db := client.Database("testdb")
repository := repositoryimp.NewMongoRepository(db, "testcollection")
t.Run("Update_Should_Update_Timestamp", func(t *testing.T) {
// Create test object
obj := &TestObject{
Name: "Test Object",
}
// Set ID and initial timestamps
obj.SetID(primitive.NewObjectID())
originalCreatedAt := obj.CreatedAt
originalUpdatedAt := obj.UpdatedAt
// Insert the object
err := repository.Insert(ctx, obj, nil)
require.NoError(t, err)
// Wait a moment to ensure timestamp difference
time.Sleep(10 * time.Millisecond)
// Update the object
obj.Name = "Updated Object"
err = repository.Update(ctx, obj)
require.NoError(t, err)
// Verify timestamps
assert.Equal(t, originalCreatedAt, obj.CreatedAt, "CreatedAt should not change")
assert.True(t, obj.UpdatedAt.After(originalUpdatedAt), "UpdatedAt should be updated")
// Verify the object was actually updated in the database
var retrieved TestObject
err = repository.Get(ctx, *obj.GetID(), &retrieved)
require.NoError(t, err)
assert.Equal(t, "Updated Object", retrieved.Name, "Name should be updated")
assert.WithinDuration(t, originalCreatedAt, retrieved.CreatedAt, time.Second, "CreatedAt should not change in DB")
assert.True(t, retrieved.UpdatedAt.After(originalUpdatedAt), "UpdatedAt should be updated in DB")
assert.WithinDuration(t, obj.UpdatedAt, retrieved.UpdatedAt, time.Second, "UpdatedAt should match between object and DB")
})
}

View File

@@ -0,0 +1,153 @@
//go:build integration
// +build integration
package repositoryimp_test
import (
"context"
"testing"
"time"
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
"github.com/tech/sendico/pkg/db/storable"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/mongodb"
"github.com/testcontainers/testcontainers-go/wait"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func TestMongoRepository_InsertMany(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
mongoContainer, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("root"),
mongodb.WithPassword("password"),
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
)
require.NoError(t, err, "failed to start MongoDB container")
defer terminate(ctx, t, mongoContainer)
mongoURI, err := mongoContainer.ConnectionString(ctx)
require.NoError(t, err, "failed to get MongoDB connection string")
clientOptions := options.Client().ApplyURI(mongoURI)
client, err := mongo.Connect(ctx, clientOptions)
require.NoError(t, err, "failed to connect to MongoDB")
defer disconnect(ctx, t, client)
db := client.Database("testdb")
repository := repositoryimp.NewMongoRepository(db, "testcollection")
t.Run("InsertMany_Success", func(t *testing.T) {
objects := []storable.Storable{
&TestObject{Name: "test1"},
&TestObject{Name: "test2"},
&TestObject{Name: "test3"},
}
err := repository.InsertMany(ctx, objects)
require.NoError(t, err)
// Verify all objects were inserted and have IDs
for _, obj := range objects {
assert.NotNil(t, obj.GetID())
assert.False(t, obj.GetID().IsZero())
// Verify we can retrieve each object
result := &TestObject{}
err := repository.Get(ctx, *obj.GetID(), result)
require.NoError(t, err)
assert.Equal(t, obj.(*TestObject).Name, result.Name)
}
})
t.Run("InsertMany_EmptySlice", func(t *testing.T) {
objects := []storable.Storable{}
err := repository.InsertMany(ctx, objects)
require.NoError(t, err)
})
t.Run("InsertMany_WithExistingIDs", func(t *testing.T) {
id1 := primitive.NewObjectID()
id2 := primitive.NewObjectID()
objects := []storable.Storable{
&TestObject{Base: storable.Base{ID: id1}, Name: "preassigned1"},
&TestObject{Base: storable.Base{ID: id2}, Name: "preassigned2"},
}
err := repository.InsertMany(ctx, objects)
require.NoError(t, err)
// Verify objects were inserted with pre-assigned IDs
result1 := &TestObject{}
err = repository.Get(ctx, id1, result1)
require.NoError(t, err)
assert.Equal(t, "preassigned1", result1.Name)
result2 := &TestObject{}
err = repository.Get(ctx, id2, result2)
require.NoError(t, err)
assert.Equal(t, "preassigned2", result2.Name)
})
t.Run("InsertMany_MixedTypes", func(t *testing.T) {
objects := []storable.Storable{
&TestObject{Name: "test1"},
&AnotherObject{Description: "desc1"},
&TestObject{Name: "test2"},
}
err := repository.InsertMany(ctx, objects)
require.NoError(t, err)
// Verify all objects were inserted
for _, obj := range objects {
assert.NotNil(t, obj.GetID())
assert.False(t, obj.GetID().IsZero())
}
})
t.Run("InsertMany_DuplicateKey", func(t *testing.T) {
id := primitive.NewObjectID()
// Insert first object
obj1 := &TestObject{Base: storable.Base{ID: id}, Name: "original"}
err := repository.Insert(ctx, obj1, nil)
require.NoError(t, err)
// Try to insert multiple objects including one with duplicate ID
objects := []storable.Storable{
&TestObject{Name: "test1"},
&TestObject{Base: storable.Base{ID: id}, Name: "duplicate"},
}
err = repository.InsertMany(ctx, objects)
assert.Error(t, err)
assert.True(t, mongo.IsDuplicateKeyError(err))
})
t.Run("InsertMany_UpdateTimestamps", func(t *testing.T) {
objects := []storable.Storable{
&TestObject{Name: "test1"},
&TestObject{Name: "test2"},
}
err := repository.InsertMany(ctx, objects)
require.NoError(t, err)
// Verify timestamps were set
for _, obj := range objects {
testObj := obj.(*TestObject)
assert.NotZero(t, testObj.CreatedAt)
assert.NotZero(t, testObj.UpdatedAt)
}
})
}

View File

@@ -0,0 +1,233 @@
//go:build integration
// +build integration
package repositoryimp_test
import (
"context"
"testing"
"time"
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/mongodb"
"github.com/testcontainers/testcontainers-go/wait"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
func TestMongoRepository_PatchOperations(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
mongoContainer, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("root"),
mongodb.WithPassword("password"),
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
)
require.NoError(t, err, "failed to start MongoDB container")
defer terminate(ctx, t, mongoContainer)
mongoURI, err := mongoContainer.ConnectionString(ctx)
require.NoError(t, err, "failed to get MongoDB connection string")
clientOptions := options.Client().ApplyURI(mongoURI)
client, err := mongo.Connect(ctx, clientOptions)
require.NoError(t, err, "failed to connect to MongoDB")
defer disconnect(ctx, t, client)
db := client.Database("testdb")
repo := repositoryimp.NewMongoRepository(db, "testcollection")
t.Run("Patch_SingleDocument", func(t *testing.T) {
obj := &TestObject{Name: "old"}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
original := obj.UpdatedAt
patch := repository.Patch().Set(repository.Field("name"), "new")
err = repo.Patch(ctx, *obj.GetID(), patch)
require.NoError(t, err)
var result TestObject
err = repo.Get(ctx, *obj.GetID(), &result)
require.NoError(t, err)
assert.Equal(t, "new", result.Name)
assert.True(t, result.UpdatedAt.After(original))
})
t.Run("PatchMany_MultipleDocuments", func(t *testing.T) {
objs := []*TestObject{{Name: "match"}, {Name: "match"}, {Name: "other"}}
for _, o := range objs {
err := repo.Insert(ctx, o, nil)
require.NoError(t, err)
}
query := repository.Query().Comparison(repository.Field("name"), builder.Eq, "match")
patch := repository.Patch().Set(repository.Field("name"), "patched")
modified, err := repo.PatchMany(ctx, query, patch)
require.NoError(t, err)
assert.Equal(t, 2, modified)
verify := repository.Query().Comparison(repository.Field("name"), builder.Eq, "patched")
var results []TestObject
decoder := func(cursor *mongo.Cursor) error {
var obj TestObject
if err := cursor.Decode(&obj); err != nil {
return err
}
results = append(results, obj)
return nil
}
err = repo.FindManyByFilter(ctx, verify, decoder)
require.NoError(t, err)
assert.Len(t, results, 2)
})
t.Run("Patch_PushArray", func(t *testing.T) {
obj := &TestObject{Name: "test", Tags: []string{"tag1"}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
patch := repository.Patch().Push(repository.Field("tags"), "tag2")
err = repo.Patch(ctx, *obj.GetID(), patch)
require.NoError(t, err)
var result TestObject
err = repo.Get(ctx, *obj.GetID(), &result)
require.NoError(t, err)
assert.Equal(t, []string{"tag1", "tag2"}, result.Tags)
})
t.Run("Patch_PullArray", func(t *testing.T) {
obj := &TestObject{Name: "test", Tags: []string{"tag1", "tag2", "tag3"}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
patch := repository.Patch().Pull(repository.Field("tags"), "tag2")
err = repo.Patch(ctx, *obj.GetID(), patch)
require.NoError(t, err)
var result TestObject
err = repo.Get(ctx, *obj.GetID(), &result)
require.NoError(t, err)
assert.Equal(t, []string{"tag1", "tag3"}, result.Tags)
})
t.Run("Patch_AddToSetArray", func(t *testing.T) {
obj := &TestObject{Name: "test", Tags: []string{"tag1"}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
// Add new tag
patch := repository.Patch().AddToSet(repository.Field("tags"), "tag2")
err = repo.Patch(ctx, *obj.GetID(), patch)
require.NoError(t, err)
var result TestObject
err = repo.Get(ctx, *obj.GetID(), &result)
require.NoError(t, err)
assert.Equal(t, []string{"tag1", "tag2"}, result.Tags)
// Try to add duplicate tag - should not add
patch = repository.Patch().AddToSet(repository.Field("tags"), "tag1")
err = repo.Patch(ctx, *obj.GetID(), patch)
require.NoError(t, err)
err = repo.Get(ctx, *obj.GetID(), &result)
require.NoError(t, err)
assert.Equal(t, []string{"tag1", "tag2"}, result.Tags)
})
t.Run("Patch_PushToEmptyArray", func(t *testing.T) {
obj := &TestObject{Name: "test", Tags: []string{}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
patch := repository.Patch().Push(repository.Field("tags"), "tag1")
err = repo.Patch(ctx, *obj.GetID(), patch)
require.NoError(t, err)
var result TestObject
err = repo.Get(ctx, *obj.GetID(), &result)
require.NoError(t, err)
assert.Equal(t, []string{"tag1"}, result.Tags)
})
t.Run("Patch_PullFromEmptyArray", func(t *testing.T) {
obj := &TestObject{Name: "test", Tags: []string{}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
patch := repository.Patch().Pull(repository.Field("tags"), "nonexistent")
err = repo.Patch(ctx, *obj.GetID(), patch)
require.NoError(t, err)
var result TestObject
err = repo.Get(ctx, *obj.GetID(), &result)
require.NoError(t, err)
assert.Equal(t, []string{}, result.Tags)
})
t.Run("Patch_PullNonExistentElement", func(t *testing.T) {
obj := &TestObject{Name: "test", Tags: []string{"tag1", "tag2"}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
patch := repository.Patch().Pull(repository.Field("tags"), "nonexistent")
err = repo.Patch(ctx, *obj.GetID(), patch)
require.NoError(t, err)
var result TestObject
err = repo.Get(ctx, *obj.GetID(), &result)
require.NoError(t, err)
assert.Equal(t, []string{"tag1", "tag2"}, result.Tags)
})
t.Run("Patch_ChainedArrayOperations", func(t *testing.T) {
obj := &TestObject{Name: "test", Tags: []string{"tag1"}}
err := repo.Insert(ctx, obj, nil)
require.NoError(t, err)
// Note: MongoDB doesn't allow multiple operations on the same array field in a single update
// This test demonstrates that chained array operations on the same field will fail
patch := repository.Patch().
Push(repository.Field("tags"), "tag2").
AddToSet(repository.Field("tags"), "tag3").
Pull(repository.Field("tags"), "tag1")
err = repo.Patch(ctx, *obj.GetID(), patch)
require.Error(t, err) // This should fail due to MongoDB's limitation
assert.Contains(t, err.Error(), "conflict")
})
t.Run("PatchMany_ArrayOperations", func(t *testing.T) {
objs := []*TestObject{
{Name: "obj1", Tags: []string{"tag1"}},
{Name: "obj2", Tags: []string{"tag2"}},
{Name: "obj3", Tags: []string{"tag3"}},
}
for _, o := range objs {
err := repo.Insert(ctx, o, nil)
require.NoError(t, err)
}
query := repository.Query().Comparison(repository.Field("name"), builder.In, []string{"obj1", "obj2"})
patch := repository.Patch().Push(repository.Field("tags"), "common")
modified, err := repo.PatchMany(ctx, query, patch)
require.NoError(t, err)
assert.Equal(t, 2, modified)
// Verify the changes
for _, name := range []string{"obj1", "obj2"} {
var result TestObject
err = repo.FindOneByFilter(ctx, repository.Query().Comparison(repository.Field("name"), builder.Eq, name), &result)
require.NoError(t, err)
assert.Contains(t, result.Tags, "common")
}
})
}

View File

@@ -0,0 +1,188 @@
//go:build integration
// +build integration
package repositoryimp_test
import (
"context"
"errors"
"testing"
"time"
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp/builderimp"
"github.com/tech/sendico/pkg/db/repository/builder"
"github.com/tech/sendico/pkg/db/storable"
"github.com/tech/sendico/pkg/merrors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/mongodb"
"github.com/testcontainers/testcontainers-go/wait"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type TestObject struct {
storable.Base `bson:",inline" json:",inline"`
Name string `bson:"name"`
Tags []string `bson:"tags"`
}
func (t *TestObject) Collection() string {
return "testObject"
}
type AnotherObject struct {
storable.Base `bson:",inline" json:",inline"`
Description string `bson:"description"`
}
func (a *AnotherObject) Collection() string {
return "anotherObject"
}
func terminate(ctx context.Context, t *testing.T, container *mongodb.MongoDBContainer) {
err := container.Terminate(ctx)
require.NoError(t, err, "failed to terminate MongoDB container")
}
func disconnect(ctx context.Context, t *testing.T, client *mongo.Client) {
err := client.Disconnect(ctx)
require.NoError(t, err, "failed to disconnect from MongoDB")
}
func TestMongoRepository_Get(t *testing.T) {
// Use a context with timeout, so if container spinning or DB ops hang,
// the test won't run indefinitely.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
mongoContainer, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("root"),
mongodb.WithPassword("password"),
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
)
require.NoError(t, err, "failed to start MongoDB container")
defer terminate(ctx, t, mongoContainer)
mongoURI, err := mongoContainer.ConnectionString(ctx)
require.NoError(t, err, "failed to get MongoDB connection string")
clientOptions := options.Client().ApplyURI(mongoURI)
client, err := mongo.Connect(ctx, clientOptions)
require.NoError(t, err, "failed to connect to MongoDB")
defer disconnect(ctx, t, client)
db := client.Database("testdb")
repository := repositoryimp.NewMongoRepository(db, "testcollection")
t.Run("Get_Success", func(t *testing.T) {
testObj := &TestObject{Name: "testName"}
err := repository.Insert(ctx, testObj, nil)
require.NoError(t, err)
result := &TestObject{}
err = repository.Get(ctx, testObj.ID, result)
require.NoError(t, err)
assert.Equal(t, testObj.Name, result.Name)
assert.Equal(t, testObj.ID, result.ID)
})
t.Run("Get_NotFound", func(t *testing.T) {
nonExistentID := primitive.NewObjectID()
result := &TestObject{}
err := repository.Get(ctx, nonExistentID, result)
assert.Error(t, err)
assert.True(t, errors.Is(err, merrors.ErrNoData))
})
t.Run("Get_InvalidID", func(t *testing.T) {
invalidID := primitive.ObjectID{} // zero value
result := &TestObject{}
err := repository.Get(ctx, invalidID, result)
assert.Error(t, err)
assert.True(t, errors.Is(err, merrors.ErrInvalidArg))
})
t.Run("Get_DifferentTypes", func(t *testing.T) {
anotherObj := &AnotherObject{Description: "testDescription"}
err := repository.Insert(ctx, anotherObj, nil)
require.NoError(t, err)
result := &AnotherObject{}
err = repository.Get(ctx, anotherObj.ID, result)
require.NoError(t, err)
assert.Equal(t, anotherObj.Description, result.Description)
assert.Equal(t, anotherObj.ID, result.ID)
})
}
func TestMongoRepository_ListIDs(t *testing.T) {
// Use a context with timeout, so if container spinning or DB ops hang,
// the test won't run indefinitely.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
mongoContainer, err := mongodb.Run(ctx,
"mongo:latest",
mongodb.WithUsername("root"),
mongodb.WithPassword("password"),
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
)
require.NoError(t, err, "failed to start MongoDB container")
defer terminate(ctx, t, mongoContainer)
mongoURI, err := mongoContainer.ConnectionString(ctx)
require.NoError(t, err, "failed to get MongoDB connection string")
clientOptions := options.Client().ApplyURI(mongoURI)
client, err := mongo.Connect(ctx, clientOptions)
require.NoError(t, err, "failed to connect to MongoDB")
defer disconnect(ctx, t, client)
db := client.Database("testdb")
repository := repositoryimp.NewMongoRepository(db, "testcollection")
t.Run("ListIDs_Success", func(t *testing.T) {
// Insert test data
testObjs := []*TestObject{
{Name: "testName1"},
{Name: "testName2"},
{Name: "testName3"},
}
for _, obj := range testObjs {
err := repository.Insert(ctx, obj, nil)
require.NoError(t, err)
}
// Define a query to match all objects
query := builderimp.NewQueryImp()
// Call ListIDs
ids, err := repository.ListIDs(ctx, query)
require.NoError(t, err)
// Assert the IDs are correct
require.Len(t, ids, len(testObjs))
for _, obj := range testObjs {
assert.Contains(t, ids, obj.ID)
}
})
t.Run("ListIDs_EmptyResult", func(t *testing.T) {
// Define a query that matches no objects
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("name"), builder.Eq, "nonExistentName")
// Call ListIDs
ids, err := repository.ListIDs(ctx, query)
require.NoError(t, err)
// Assert no IDs are returned
assert.Empty(t, ids)
})
}

View File

@@ -0,0 +1,21 @@
package rolesdb
import (
"github.com/tech/sendico/pkg/db/template"
"github.com/tech/sendico/pkg/mlogger"
"github.com/tech/sendico/pkg/model"
"github.com/tech/sendico/pkg/mservice"
"go.mongodb.org/mongo-driver/mongo"
)
type RolesDB struct {
template.DBImp[*model.RoleDescription]
}
func Create(logger mlogger.Logger, db *mongo.Database) (*RolesDB, error) {
p := &RolesDB{
DBImp: *template.Create[*model.RoleDescription](logger, mservice.Roles, db),
}
return p, nil
}

View File

@@ -0,0 +1,15 @@
package rolesdb
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/model"
mutil "github.com/tech/sendico/pkg/mutil/db"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *RolesDB) List(ctx context.Context, organizationRef primitive.ObjectID, cursor *model.ViewCursor) ([]model.RoleDescription, error) {
filter := repository.OrgFilter(organizationRef)
return mutil.GetObjects[model.RoleDescription](ctx, db.Logger, filter, cursor, db.Repository)
}

View File

@@ -0,0 +1,15 @@
package rolesdb
import (
"context"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/model"
mutil "github.com/tech/sendico/pkg/mutil/db"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func (db *RolesDB) Roles(ctx context.Context, refs []primitive.ObjectID) ([]model.RoleDescription, error) {
filter := repository.Query().In(repository.IDField(), refs)
return mutil.GetObjects[model.RoleDescription](ctx, db.Logger, filter, nil, db.Repository)
}

View File

@@ -0,0 +1,18 @@
package transactionimp
import (
"github.com/tech/sendico/pkg/db/transaction"
"go.mongodb.org/mongo-driver/mongo"
)
type MongoTransactionFactory struct {
client *mongo.Client
}
func (mtf *MongoTransactionFactory) CreateTransaction() transaction.Transaction {
return Create(mtf.client)
}
func CreateFactory(client *mongo.Client) transaction.Factory {
return &MongoTransactionFactory{client: client}
}

View File

@@ -0,0 +1,30 @@
package transactionimp
import (
"context"
"github.com/tech/sendico/pkg/db/transaction"
"go.mongodb.org/mongo-driver/mongo"
)
type MongoTransaction struct {
client *mongo.Client
}
func (mt *MongoTransaction) Execute(ctx context.Context, cb transaction.Callback) (any, error) {
session, err := mt.client.StartSession()
if err != nil {
return nil, err
}
defer session.EndSession(ctx)
callback := func(sessCtx mongo.SessionContext) (any, error) {
return cb(sessCtx)
}
return session.WithTransaction(ctx, callback)
}
func Create(client *mongo.Client) *MongoTransaction {
return &MongoTransaction{client: client}
}

View File

@@ -0,0 +1,118 @@
package tseriesimp
import (
"context"
"errors"
"time"
"github.com/tech/sendico/pkg/db/repository"
"github.com/tech/sendico/pkg/db/repository/builder"
rdecoder "github.com/tech/sendico/pkg/db/repository/decoder"
tsoptions "github.com/tech/sendico/pkg/db/tseries/options"
tspoint "github.com/tech/sendico/pkg/db/tseries/point"
"github.com/tech/sendico/pkg/merrors"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type TimeSeries struct {
options tsoptions.Options
collection *mongo.Collection
}
func NewMongoTimeSeriesCollection(ctx context.Context, db *mongo.Database, tsOpts *tsoptions.Options) (*TimeSeries, error) {
if tsOpts == nil {
return nil, merrors.InvalidArgument("nil time-series options provided")
}
// Configure time-series options
granularity := tsOpts.Granularity.String()
ts := &options.TimeSeriesOptions{
TimeField: tsOpts.TimeField,
Granularity: &granularity,
}
if tsOpts.MetaField != "" {
ts.MetaField = &tsOpts.MetaField
}
// Collection options
collOpts := options.CreateCollection().SetTimeSeriesOptions(ts)
// Set TTL if requested
if tsOpts.ExpireAfter > 0 {
secs := int64(tsOpts.ExpireAfter / time.Second)
collOpts.SetExpireAfterSeconds(secs)
}
if err := db.CreateCollection(ctx, tsOpts.Collection, collOpts); err != nil {
if cmdErr, ok := err.(mongo.CommandError); !ok || cmdErr.Code != 48 {
return nil, err
}
}
return &TimeSeries{collection: db.Collection(tsOpts.Collection), options: *tsOpts}, nil
}
func (ts *TimeSeries) Aggregate(ctx context.Context, pipeline builder.Pipeline, decoder rdecoder.DecodingFunc) error {
queryFunc := func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error) {
return collection.Aggregate(ctx, pipeline.Build())
}
return ts.executeQuery(ctx, decoder, queryFunc)
}
func (ts *TimeSeries) Insert(ctx context.Context, timePoint tspoint.TimePoint) error {
_, err := ts.collection.InsertOne(ctx, timePoint)
return err
}
func (ts *TimeSeries) InsertMany(ctx context.Context, timePoints []tspoint.TimePoint) error {
docs := make([]any, len(timePoints))
for i, p := range timePoints {
docs[i] = p
}
// ignore the result if you like, or capture it
_, err := ts.collection.InsertMany(ctx, docs)
return err
}
type QueryFunc func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error)
func (ts *TimeSeries) executeQuery(ctx context.Context, decoder rdecoder.DecodingFunc, queryFunc QueryFunc) error {
cursor, err := queryFunc(ctx, ts.collection)
if errors.Is(err, mongo.ErrNoDocuments) {
return merrors.NoData("no_items_in_array")
}
if err != nil {
return err
}
defer cursor.Close(ctx)
for cursor.Next(ctx) {
if err := cursor.Err(); err != nil {
return err
}
if err = decoder(cursor); err != nil {
return err
}
}
return nil
}
func (ts *TimeSeries) Query(ctx context.Context, decoder rdecoder.DecodingFunc, query builder.Query, from, to *time.Time) error {
timeLimitedQuery := query
if from != nil {
timeLimitedQuery = timeLimitedQuery.And(repository.Query().Comparison(repository.Field(ts.options.TimeField), builder.Gte, *from))
}
if to != nil {
timeLimitedQuery = timeLimitedQuery.And(repository.Query().Comparison(repository.Field(ts.options.TimeField), builder.Lte, *to))
}
queryFunc := func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error) {
return collection.Find(ctx, timeLimitedQuery.BuildQuery(), timeLimitedQuery.BuildOptions())
}
return ts.executeQuery(ctx, decoder, queryFunc)
}
func (ts *TimeSeries) Name() string {
return ts.collection.Name()
}