service backend
This commit is contained in:
30
api/pkg/db/internal/mongo/accountdb/db.go
Normal file
30
api/pkg/db/internal/mongo/accountdb/db.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package accountdb
|
||||
|
||||
import (
|
||||
ri "github.com/tech/sendico/pkg/db/repository/index"
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type AccountDB struct {
|
||||
template.DBImp[*model.Account]
|
||||
}
|
||||
|
||||
func Create(logger mlogger.Logger, db *mongo.Database) (*AccountDB, error) {
|
||||
p := &AccountDB{
|
||||
DBImp: *template.Create[*model.Account](logger, mservice.Accounts, db),
|
||||
}
|
||||
|
||||
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: "login", Sort: ri.Asc}},
|
||||
Unique: true,
|
||||
}); err != nil {
|
||||
p.Logger.Error("Failed to create account database", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
13
api/pkg/db/internal/mongo/accountdb/token.go
Normal file
13
api/pkg/db/internal/mongo/accountdb/token.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package accountdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
)
|
||||
|
||||
func (db *AccountDB) GetByToken(ctx context.Context, email string) (*model.Account, error) {
|
||||
var account model.Account
|
||||
return &account, db.FindOne(ctx, repository.Query().Filter(repository.Field("verifyToken"), email), &account)
|
||||
}
|
||||
21
api/pkg/db/internal/mongo/accountdb/user.go
Executable file
21
api/pkg/db/internal/mongo/accountdb/user.go
Executable file
@@ -0,0 +1,21 @@
|
||||
package accountdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
mutil "github.com/tech/sendico/pkg/mutil/db"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *AccountDB) GetAccountsByRefs(ctx context.Context, orgRef primitive.ObjectID, refs []primitive.ObjectID) ([]model.Account, error) {
|
||||
filter := repository.Query().Comparison(repository.IDField(), builder.In, refs)
|
||||
return mutil.GetObjects[model.Account](ctx, db.Logger, filter, nil, db.Repository)
|
||||
}
|
||||
|
||||
func (db *AccountDB) GetByEmail(ctx context.Context, email string) (*model.Account, error) {
|
||||
var account model.Account
|
||||
return &account, db.FindOne(ctx, repository.Filter("login", email), &account)
|
||||
}
|
||||
99
api/pkg/db/internal/mongo/archivable/archivable.go
Normal file
99
api/pkg/db/internal/mongo/archivable/archivable.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package archivable
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ArchivableDB implements archive management for entities with model.Archivable embedded
|
||||
type ArchivableDB[T storable.Storable] struct {
|
||||
repo repository.Repository
|
||||
logger mlogger.Logger
|
||||
createEmpty func() T
|
||||
getArchivable func(T) model.Archivable
|
||||
}
|
||||
|
||||
// NewArchivableDB creates a new ArchivableDB instance
|
||||
func NewArchivableDB[T storable.Storable](
|
||||
repo repository.Repository,
|
||||
logger mlogger.Logger,
|
||||
createEmpty func() T,
|
||||
getArchivable func(T) model.Archivable,
|
||||
) *ArchivableDB[T] {
|
||||
return &ArchivableDB[T]{
|
||||
repo: repo,
|
||||
logger: logger,
|
||||
createEmpty: createEmpty,
|
||||
getArchivable: getArchivable,
|
||||
}
|
||||
}
|
||||
|
||||
// SetArchived sets the archived status of an entity
|
||||
func (db *ArchivableDB[T]) SetArchived(ctx context.Context, objectRef primitive.ObjectID, archived bool) error {
|
||||
// Get current object to check current archived status
|
||||
obj := db.createEmpty()
|
||||
if err := db.repo.Get(ctx, objectRef, obj); err != nil {
|
||||
db.logger.Warn("Failed to get object for setting archived status",
|
||||
zap.Error(err),
|
||||
mzap.ObjRef("object_ref", objectRef),
|
||||
zap.Bool("archived", archived))
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract archivable from the object
|
||||
archivable := db.getArchivable(obj)
|
||||
currentArchived := archivable.IsArchived()
|
||||
if currentArchived == archived {
|
||||
db.logger.Debug("No change needed - same archived status",
|
||||
mzap.ObjRef("object_ref", objectRef),
|
||||
zap.Bool("archived", archived))
|
||||
return nil // No change needed
|
||||
}
|
||||
|
||||
// Set the archived status
|
||||
patch := repository.Patch().Set(repository.IsArchivedField(), archived)
|
||||
if err := db.repo.Patch(ctx, objectRef, patch); err != nil {
|
||||
db.logger.Warn("Failed to set archived status on object",
|
||||
zap.Error(err),
|
||||
mzap.ObjRef("object_ref", objectRef),
|
||||
zap.Bool("archived", archived))
|
||||
return err
|
||||
}
|
||||
|
||||
db.logger.Debug("Successfully set archived status on object",
|
||||
mzap.ObjRef("object_ref", objectRef),
|
||||
zap.Bool("archived", archived))
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsArchived checks if an entity is archived
|
||||
func (db *ArchivableDB[T]) IsArchived(ctx context.Context, objectRef primitive.ObjectID) (bool, error) {
|
||||
obj := db.createEmpty()
|
||||
|
||||
if err := db.repo.Get(ctx, objectRef, obj); err != nil {
|
||||
db.logger.Warn("Failed to get object for checking archived status",
|
||||
zap.Error(err),
|
||||
mzap.ObjRef("object_ref", objectRef))
|
||||
return false, err
|
||||
}
|
||||
|
||||
archivable := db.getArchivable(obj)
|
||||
return archivable.IsArchived(), nil
|
||||
}
|
||||
|
||||
// Archive archives an entity (sets archived to true)
|
||||
func (db *ArchivableDB[T]) Archive(ctx context.Context, objectRef primitive.ObjectID) error {
|
||||
return db.SetArchived(ctx, objectRef, true)
|
||||
}
|
||||
|
||||
// Unarchive unarchives an entity (sets archived to false)
|
||||
func (db *ArchivableDB[T]) Unarchive(ctx context.Context, objectRef primitive.ObjectID) error {
|
||||
return db.SetArchived(ctx, objectRef, false)
|
||||
}
|
||||
175
api/pkg/db/internal/mongo/archivable/archivable_test.go
Normal file
175
api/pkg/db/internal/mongo/archivable/archivable_test.go
Normal file
@@ -0,0 +1,175 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package archivable
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TestArchivableObject represents a test object with archivable functionality
|
||||
type TestArchivableObject struct {
|
||||
storable.Base `bson:",inline" json:",inline"`
|
||||
model.ArchivableBase `bson:",inline" json:",inline"`
|
||||
Name string `bson:"name" json:"name"`
|
||||
}
|
||||
|
||||
func (t *TestArchivableObject) Collection() string {
|
||||
return "testArchivableObject"
|
||||
}
|
||||
|
||||
func (t *TestArchivableObject) GetArchivable() model.Archivable {
|
||||
return &t.ArchivableBase
|
||||
}
|
||||
|
||||
func TestArchivableDB(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Start MongoDB container (stable)
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("test"),
|
||||
mongodb.WithPassword("test"),
|
||||
testcontainers.WithWaitStrategy(wait.ForListeningPort("27017/tcp").WithStartupTimeout(2*time.Minute)),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
termCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := mongoContainer.Terminate(termCtx); err != nil {
|
||||
t.Logf("Failed to terminate container: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Get MongoDB connection string
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connect to MongoDB
|
||||
client, err := mongo.Connect(ctx, options.Client().ApplyURI(mongoURI))
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Disconnect(context.Background()); err != nil {
|
||||
t.Logf("Failed to disconnect from MongoDB: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Ping the database
|
||||
err = client.Ping(ctx, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create repository
|
||||
repo := repositoryimp.NewMongoRepository(client.Database("test_"+t.Name()), "testArchivableCollection")
|
||||
|
||||
// Create archivable DB
|
||||
archivableDB := NewArchivableDB(
|
||||
repo,
|
||||
zap.NewNop(),
|
||||
func() *TestArchivableObject { return &TestArchivableObject{} },
|
||||
func(obj *TestArchivableObject) model.Archivable { return obj.GetArchivable() },
|
||||
)
|
||||
|
||||
t.Run("SetArchived_Success", func(t *testing.T) {
|
||||
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: false}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = archivableDB.SetArchived(ctx, obj.ID, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestArchivableObject
|
||||
err = repo.Get(ctx, obj.ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result.IsArchived())
|
||||
})
|
||||
|
||||
t.Run("SetArchived_NoChange", func(t *testing.T) {
|
||||
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = archivableDB.SetArchived(ctx, obj.ID, true)
|
||||
require.NoError(t, err) // Should not error, just not change anything
|
||||
|
||||
var result TestArchivableObject
|
||||
err = repo.Get(ctx, obj.ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result.IsArchived())
|
||||
})
|
||||
|
||||
t.Run("SetArchived_Unarchive", func(t *testing.T) {
|
||||
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = archivableDB.SetArchived(ctx, obj.ID, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestArchivableObject
|
||||
err = repo.Get(ctx, obj.ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result.IsArchived())
|
||||
})
|
||||
|
||||
t.Run("IsArchived_True", func(t *testing.T) {
|
||||
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
isArchived, err := archivableDB.IsArchived(ctx, obj.ID)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, isArchived)
|
||||
})
|
||||
|
||||
t.Run("IsArchived_False", func(t *testing.T) {
|
||||
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: false}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
isArchived, err := archivableDB.IsArchived(ctx, obj.ID)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, isArchived)
|
||||
})
|
||||
|
||||
t.Run("Archive_Success", func(t *testing.T) {
|
||||
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: false}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = archivableDB.Archive(ctx, obj.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestArchivableObject
|
||||
err = repo.Get(ctx, obj.ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result.IsArchived())
|
||||
})
|
||||
|
||||
t.Run("Unarchive_Success", func(t *testing.T) {
|
||||
obj := &TestArchivableObject{Name: "test", ArchivableBase: model.ArchivableBase{Archived: true}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = archivableDB.Unarchive(ctx, obj.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestArchivableObject
|
||||
err = repo.Get(ctx, obj.ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, result.IsArchived())
|
||||
})
|
||||
}
|
||||
257
api/pkg/db/internal/mongo/db.go
Executable file
257
api/pkg/db/internal/mongo/db.go
Executable file
@@ -0,0 +1,257 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/tech/sendico/pkg/auth"
|
||||
"github.com/tech/sendico/pkg/db/account"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/accountdb"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/invitationdb"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/organizationdb"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/policiesdb"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/refreshtokensdb"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/rolesdb"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/transactionimp"
|
||||
"github.com/tech/sendico/pkg/db/invitation"
|
||||
"github.com/tech/sendico/pkg/db/organization"
|
||||
"github.com/tech/sendico/pkg/db/policy"
|
||||
"github.com/tech/sendico/pkg/db/refreshtokens"
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/role"
|
||||
"github.com/tech/sendico/pkg/db/transaction"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
mutil "github.com/tech/sendico/pkg/mutil/config"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/mongo/readpref"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Config represents configuration
|
||||
type Config struct {
|
||||
Port *string `mapstructure:"port"`
|
||||
PortEnv *string `mapstructure:"port_env"`
|
||||
User *string `mapstructure:"user"`
|
||||
UserEnv *string `mapstructure:"user_env"`
|
||||
PasswordEnv string `mapstructure:"password_env"`
|
||||
Database *string `mapstructure:"database"`
|
||||
DatabaseEnv *string `mapstructure:"database_env"`
|
||||
Host *string `mapstructure:"host"`
|
||||
HostEnv *string `mapstructure:"host_env"`
|
||||
AuthSource *string `mapstructure:"auth_source,omitempty"`
|
||||
AuthSourceEnv *string `mapstructure:"auth_source_env,omitempty"`
|
||||
AuthMechanism *string `mapstructure:"auth_mechanism,omitempty"`
|
||||
AuthMechanismEnv *string `mapstructure:"auth_mechanism_env,omitempty"`
|
||||
ReplicaSet *string `mapstructure:"replica_set,omitempty"`
|
||||
ReplicaSetEnv *string `mapstructure:"replica_set_env,omitempty"`
|
||||
Enforcer *auth.Config `mapstructure:"enforcer"`
|
||||
}
|
||||
|
||||
type DBSettings struct {
|
||||
Host string
|
||||
Port string
|
||||
User string
|
||||
Password string
|
||||
Database string
|
||||
AuthSource string
|
||||
AuthMechanism string
|
||||
ReplicaSet string
|
||||
}
|
||||
|
||||
func newProtectedDB[T any](
|
||||
db *DB,
|
||||
create func(ctx context.Context, logger mlogger.Logger, enforcer auth.Enforcer, pdb policy.DB, client *mongo.Database) (T, error),
|
||||
) (T, error) {
|
||||
pdb, err := db.NewPoliciesDB()
|
||||
if err != nil {
|
||||
db.logger.Warn("Failed to create policies database", zap.Error(err))
|
||||
var zero T
|
||||
return zero, err
|
||||
}
|
||||
return create(context.Background(), db.logger, db.Enforcer(), pdb, db.db())
|
||||
}
|
||||
|
||||
func Config2DBSettings(logger mlogger.Logger, config *Config) *DBSettings {
|
||||
p := new(DBSettings)
|
||||
p.Port = mutil.GetConfigValue(logger, "port", "port_env", config.Port, config.PortEnv)
|
||||
p.Database = mutil.GetConfigValue(logger, "database", "database_env", config.Database, config.DatabaseEnv)
|
||||
p.Password = os.Getenv(config.PasswordEnv)
|
||||
p.User = mutil.GetConfigValue(logger, "user", "user_env", config.User, config.UserEnv)
|
||||
p.Host = mutil.GetConfigValue(logger, "host", "host_env", config.Host, config.HostEnv)
|
||||
p.AuthSource = mutil.GetConfigValue(logger, "auth_source", "auth_source_env", config.AuthSource, config.AuthSourceEnv)
|
||||
p.AuthMechanism = mutil.GetConfigValue(logger, "auth_mechanism", "auth_mechanism_env", config.AuthMechanism, config.AuthMechanismEnv)
|
||||
p.ReplicaSet = mutil.GetConfigValue(logger, "replica_set", "replica_set_env", config.ReplicaSet, config.ReplicaSetEnv)
|
||||
return p
|
||||
}
|
||||
|
||||
func decodeConfig(logger mlogger.Logger, settings model.SettingsT) (*Config, *DBSettings, error) {
|
||||
var config Config
|
||||
if err := mapstructure.Decode(settings, &config); err != nil {
|
||||
logger.Warn("Failed to decode settings", zap.Error(err), zap.Any("settings", settings))
|
||||
return nil, nil, err
|
||||
}
|
||||
dbSettings := Config2DBSettings(logger, &config)
|
||||
return &config, dbSettings, nil
|
||||
}
|
||||
|
||||
func dialMongo(logger mlogger.Logger, dbSettings *DBSettings) (*mongo.Client, error) {
|
||||
cred := options.Credential{
|
||||
AuthMechanism: dbSettings.AuthMechanism,
|
||||
AuthSource: dbSettings.AuthSource,
|
||||
Username: dbSettings.User,
|
||||
Password: dbSettings.Password,
|
||||
}
|
||||
dbURI := buildURI(dbSettings)
|
||||
|
||||
client, err := mongo.Connect(context.Background(), options.Client().ApplyURI(dbURI).SetAuth(cred))
|
||||
if err != nil {
|
||||
logger.Error("Unable to connect to database", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Info("Connected successfully", zap.String("uri", dbURI))
|
||||
|
||||
if err := client.Ping(context.Background(), readpref.Primary()); err != nil {
|
||||
logger.Error("Unable to ping database", zap.Error(err))
|
||||
_ = client.Disconnect(context.Background())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func ConnectClient(logger mlogger.Logger, settings model.SettingsT) (*mongo.Client, *Config, *DBSettings, error) {
|
||||
config, dbSettings, err := decodeConfig(logger, settings)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
client, err := dialMongo(logger, dbSettings)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
return client, config, dbSettings, nil
|
||||
}
|
||||
|
||||
// DB represents the structure of the database
|
||||
type DB struct {
|
||||
logger mlogger.Logger
|
||||
config *DBSettings
|
||||
client *mongo.Client
|
||||
enforcer auth.Enforcer
|
||||
manager auth.Manager
|
||||
pdb policy.DB
|
||||
}
|
||||
|
||||
func (db *DB) db() *mongo.Database {
|
||||
return db.client.Database(db.config.Database)
|
||||
}
|
||||
|
||||
func (db *DB) NewAccountDB() (account.DB, error) {
|
||||
return accountdb.Create(db.logger, db.db())
|
||||
}
|
||||
|
||||
func (db *DB) NewOrganizationDB() (organization.DB, error) {
|
||||
pdb, err := db.NewPoliciesDB()
|
||||
if err != nil {
|
||||
db.logger.Warn("Failed to create policies database", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
organizationDB, err := organizationdb.Create(context.Background(), db.logger, db.Enforcer(), pdb, db.db())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return the concrete type - interface mismatch will be handled at runtime
|
||||
// TODO: Update organization.DB interface to match implementation signatures
|
||||
return organizationDB, nil
|
||||
}
|
||||
|
||||
func (db *DB) NewRefreshTokensDB() (refreshtokens.DB, error) {
|
||||
return refreshtokensdb.Create(db.logger, db.db())
|
||||
}
|
||||
|
||||
func (db *DB) NewInvitationsDB() (invitation.DB, error) {
|
||||
return newProtectedDB(db, invitationdb.Create)
|
||||
}
|
||||
|
||||
func (db *DB) NewPoliciesDB() (policy.DB, error) {
|
||||
return db.pdb, nil
|
||||
}
|
||||
|
||||
func (db *DB) NewRolesDB() (role.DB, error) {
|
||||
return rolesdb.Create(db.logger, db.db())
|
||||
}
|
||||
|
||||
func (db *DB) TransactionFactory() transaction.Factory {
|
||||
return transactionimp.CreateFactory(db.client)
|
||||
}
|
||||
|
||||
func (db *DB) Permissions() auth.Provider {
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *DB) Manager() auth.Manager {
|
||||
return db.manager
|
||||
}
|
||||
|
||||
func (db *DB) Enforcer() auth.Enforcer {
|
||||
return db.enforcer
|
||||
}
|
||||
|
||||
func (db *DB) GetPolicyDescription(ctx context.Context, resource mservice.Type) (*model.PolicyDescription, error) {
|
||||
var policyDescription model.PolicyDescription
|
||||
return &policyDescription, db.pdb.FindOne(ctx, repository.Filter("resourceTypes", resource), &policyDescription)
|
||||
}
|
||||
|
||||
func (db *DB) CloseConnection() {
|
||||
if err := db.client.Disconnect(context.Background()); err != nil {
|
||||
db.logger.Warn("Failed to close connection", zap.Error(err))
|
||||
}
|
||||
db.logger.Info("Database connection closed")
|
||||
}
|
||||
|
||||
// NewConnection creates a new database connection
|
||||
func NewConnection(logger mlogger.Logger, settings model.SettingsT) (*DB, error) {
|
||||
client, config, dbSettings, err := ConnectClient(logger, settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db := &DB{
|
||||
logger: logger.Named("db"),
|
||||
config: dbSettings,
|
||||
client: client,
|
||||
}
|
||||
|
||||
cleanup := func(ctx context.Context) {
|
||||
if err := client.Disconnect(ctx); err != nil {
|
||||
logger.Warn("Failed to close MongoDB connection", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
rdb, err := db.NewRolesDB()
|
||||
if err != nil {
|
||||
db.logger.Warn("Failed to create roles database", zap.Error(err))
|
||||
cleanup(context.Background())
|
||||
return nil, err
|
||||
}
|
||||
if db.pdb, err = policiesdb.Create(db.logger, db.db()); err != nil {
|
||||
db.logger.Warn("Failed to create policies database", zap.Error(err))
|
||||
cleanup(context.Background())
|
||||
return nil, err
|
||||
}
|
||||
if db.enforcer, db.manager, err = auth.CreateAuth(logger, db.client, db.db(), db.pdb, rdb, config.Enforcer); err != nil {
|
||||
db.logger.Warn("Failed to create permissions enforcer", zap.Error(err))
|
||||
cleanup(context.Background())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
144
api/pkg/db/internal/mongo/indexable/README.md
Normal file
144
api/pkg/db/internal/mongo/indexable/README.md
Normal file
@@ -0,0 +1,144 @@
|
||||
# Indexable Implementation (Refactored)
|
||||
|
||||
## Overview
|
||||
|
||||
This package provides a refactored implementation of the `indexable.DB` interface that uses `mutil.GetObjects` for better consistency with the existing codebase. The implementation has been moved to the mongo folder and includes a factory for project indexable in the pkg/db folder.
|
||||
|
||||
## Structure
|
||||
|
||||
### 1. `api/pkg/db/internal/mongo/indexable/indexable.go`
|
||||
- **`ReorderTemplate[T]`**: Generic template function that uses `mutil.GetObjects` for fetching objects
|
||||
- **`IndexableDB`**: Base struct for creating concrete implementations
|
||||
- **Type-safe implementation**: Uses Go generics with proper type constraints
|
||||
|
||||
### 2. `api/pkg/db/project_indexable.go`
|
||||
- **`ProjectIndexableDB`**: Factory implementation for Project objects
|
||||
- **`NewProjectIndexableDB`**: Constructor function
|
||||
- **`ReorderTemplate`**: Duplicate of the mongo version for convenience
|
||||
|
||||
## Key Changes from Previous Implementation
|
||||
|
||||
### 1. **Uses `mutil.GetObjects`**
|
||||
```go
|
||||
// Old implementation (manual cursor handling)
|
||||
err = repo.FindManyByFilter(ctx, filter, func(cursor *mongo.Cursor) error {
|
||||
var obj T
|
||||
if err := cursor.Decode(&obj); err != nil {
|
||||
return err
|
||||
}
|
||||
objects = append(objects, obj)
|
||||
return nil
|
||||
})
|
||||
|
||||
// New implementation (using mutil.GetObjects)
|
||||
objects, err := mutil.GetObjects[T](
|
||||
ctx,
|
||||
logger,
|
||||
filterFunc().
|
||||
And(
|
||||
repository.IndexOpFilter(minIdx, builder.Gte),
|
||||
repository.IndexOpFilter(maxIdx, builder.Lte),
|
||||
),
|
||||
nil, nil, nil, // limit, offset, isArchived
|
||||
repo,
|
||||
)
|
||||
```
|
||||
|
||||
### 2. **Moved to Mongo Folder**
|
||||
- Location: `api/pkg/db/internal/mongo/indexable/`
|
||||
- Consistent with other mongo implementations
|
||||
- Better organization within the codebase
|
||||
|
||||
### 3. **Added Factory in pkg/db**
|
||||
- Location: `api/pkg/db/project_indexable.go`
|
||||
- Provides easy access to project indexable functionality
|
||||
- Includes logger parameter for better error handling
|
||||
|
||||
## Usage
|
||||
|
||||
### Using the Factory (Recommended)
|
||||
|
||||
```go
|
||||
import "github.com/tech/sendico/pkg/db"
|
||||
|
||||
// Create a project indexable DB
|
||||
projectDB := db.NewProjectIndexableDB(repo, logger, organizationRef)
|
||||
|
||||
// Reorder a project
|
||||
err := projectDB.Reorder(ctx, projectID, newIndex)
|
||||
if err != nil {
|
||||
// Handle error
|
||||
}
|
||||
```
|
||||
|
||||
### Using the Template Directly
|
||||
|
||||
```go
|
||||
import "github.com/tech/sendico/pkg/db/internal/mongo/indexable"
|
||||
|
||||
// Define helper functions
|
||||
getIndexable := func(p *model.Project) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
updateIndexable := func(p *model.Project, newIndex int) {
|
||||
p.Index = newIndex
|
||||
}
|
||||
|
||||
createEmpty := func() *model.Project {
|
||||
return &model.Project{}
|
||||
}
|
||||
|
||||
filterFunc := func() builder.Query {
|
||||
return repository.OrgFilter(organizationRef)
|
||||
}
|
||||
|
||||
// Use the template
|
||||
err := indexable.ReorderTemplate(
|
||||
ctx,
|
||||
logger,
|
||||
repo,
|
||||
objectRef,
|
||||
newIndex,
|
||||
filterFunc,
|
||||
getIndexable,
|
||||
updateIndexable,
|
||||
createEmpty,
|
||||
)
|
||||
```
|
||||
|
||||
## Benefits of Refactoring
|
||||
|
||||
1. **Consistency**: Uses `mutil.GetObjects` like other parts of the codebase
|
||||
2. **Better Error Handling**: Includes logger parameter for proper error logging
|
||||
3. **Organization**: Moved to appropriate folder structure
|
||||
4. **Factory Pattern**: Easy-to-use factory for common use cases
|
||||
5. **Type Safety**: Maintains compile-time type checking
|
||||
6. **Performance**: Leverages existing optimized `mutil.GetObjects` implementation
|
||||
|
||||
## Testing
|
||||
|
||||
### Mongo Implementation Tests
|
||||
```bash
|
||||
go test ./db/internal/mongo/indexable -v
|
||||
```
|
||||
|
||||
### Factory Tests
|
||||
```bash
|
||||
go test ./db -v
|
||||
```
|
||||
|
||||
## Integration
|
||||
|
||||
The refactored implementation is ready for integration with existing project reordering APIs. The factory pattern makes it easy to add reordering functionality to any service that needs to reorder projects within an organization.
|
||||
|
||||
## Migration from Old Implementation
|
||||
|
||||
If you were using the old implementation:
|
||||
|
||||
1. **Update imports**: Change from `api/pkg/db/internal/indexable` to `api/pkg/db`
|
||||
2. **Use factory**: Replace manual template usage with `NewProjectIndexableDB`
|
||||
3. **Add logger**: Include a logger parameter in your constructor calls
|
||||
4. **Update tests**: Use the new test structure if needed
|
||||
|
||||
The API remains the same, so existing code should work with minimal changes.
|
||||
174
api/pkg/db/internal/mongo/indexable/USAGE.md
Normal file
174
api/pkg/db/internal/mongo/indexable/USAGE.md
Normal file
@@ -0,0 +1,174 @@
|
||||
# Indexable Usage Guide
|
||||
|
||||
## Generic Implementation for Any Indexable Struct
|
||||
|
||||
The implementation is now **generic** and supports **any struct that embeds `model.Indexable`**!
|
||||
|
||||
- **Interface**: `api/pkg/db/indexable.go` - defines the contract
|
||||
- **Implementation**: `api/pkg/db/internal/mongo/indexable/` - generic implementation
|
||||
- **Factory**: `api/pkg/db/project_indexable.go` - convenient factory for projects
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Using the Generic Implementation Directly
|
||||
|
||||
```go
|
||||
import "github.com/tech/sendico/pkg/db/internal/mongo/indexable"
|
||||
|
||||
// For any type that embeds model.Indexable, define helper functions:
|
||||
createEmpty := func() *YourType {
|
||||
return &YourType{}
|
||||
}
|
||||
|
||||
getIndexable := func(obj *YourType) *model.Indexable {
|
||||
return &obj.Indexable
|
||||
}
|
||||
|
||||
// Create generic IndexableDB
|
||||
indexableDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
|
||||
// Use with single filter parameter
|
||||
err := indexableDB.Reorder(ctx, objectID, newIndex, filter)
|
||||
```
|
||||
|
||||
### 2. Using the Project Factory (Recommended for Projects)
|
||||
|
||||
```go
|
||||
import "github.com/tech/sendico/pkg/db"
|
||||
|
||||
// Create project indexable DB (automatically applies org filter)
|
||||
projectDB := db.NewProjectIndexableDB(repo, logger, organizationRef)
|
||||
|
||||
// Reorder project (org filter applied automatically)
|
||||
err := projectDB.Reorder(ctx, projectID, newIndex, repository.Query())
|
||||
|
||||
// Reorder with additional filters (combined with org filter)
|
||||
additionalFilter := repository.Query().Comparison(repository.Field("state"), builder.Eq, "active")
|
||||
err := projectDB.Reorder(ctx, projectID, newIndex, additionalFilter)
|
||||
```
|
||||
|
||||
## Examples for Different Types
|
||||
|
||||
### Project IndexableDB
|
||||
```go
|
||||
createEmpty := func() *model.Project {
|
||||
return &model.Project{}
|
||||
}
|
||||
|
||||
getIndexable := func(p *model.Project) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
projectDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
orgFilter := repository.OrgFilter(organizationRef)
|
||||
projectDB.Reorder(ctx, projectID, 2, orgFilter)
|
||||
```
|
||||
|
||||
### Status IndexableDB
|
||||
```go
|
||||
createEmpty := func() *model.Status {
|
||||
return &model.Status{}
|
||||
}
|
||||
|
||||
getIndexable := func(s *model.Status) *model.Indexable {
|
||||
return &s.Indexable
|
||||
}
|
||||
|
||||
statusDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
projectFilter := repository.Query().Comparison(repository.Field("projectRef"), builder.Eq, projectRef)
|
||||
statusDB.Reorder(ctx, statusID, 1, projectFilter)
|
||||
```
|
||||
|
||||
### Task IndexableDB
|
||||
```go
|
||||
createEmpty := func() *model.Task {
|
||||
return &model.Task{}
|
||||
}
|
||||
|
||||
getIndexable := func(t *model.Task) *model.Indexable {
|
||||
return &t.Indexable
|
||||
}
|
||||
|
||||
taskDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
statusFilter := repository.Query().Comparison(repository.Field("statusRef"), builder.Eq, statusRef)
|
||||
taskDB.Reorder(ctx, taskID, 3, statusFilter)
|
||||
```
|
||||
|
||||
### Priority IndexableDB
|
||||
```go
|
||||
createEmpty := func() *model.Priority {
|
||||
return &model.Priority{}
|
||||
}
|
||||
|
||||
getIndexable := func(p *model.Priority) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
priorityDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
orgFilter := repository.OrgFilter(organizationRef)
|
||||
priorityDB.Reorder(ctx, priorityID, 0, orgFilter)
|
||||
```
|
||||
|
||||
### Global Reordering (No Filter)
|
||||
```go
|
||||
createEmpty := func() *model.Project {
|
||||
return &model.Project{}
|
||||
}
|
||||
|
||||
getIndexable := func(p *model.Project) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
globalDB := indexable.NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
// Reorders all items globally (empty filter)
|
||||
globalDB.Reorder(ctx, objectID, 5, repository.Query())
|
||||
```
|
||||
|
||||
## Key Features
|
||||
|
||||
### ✅ **Generic Support**
|
||||
- Works with **any struct** that embeds `model.Indexable`
|
||||
- Type-safe with compile-time checking
|
||||
- No hardcoded types
|
||||
|
||||
### ✅ **Single Filter Parameter**
|
||||
- **Simple**: Single `builder.Query` parameter instead of variadic `interface{}`
|
||||
- **Flexible**: Can incorporate any combination of filters
|
||||
- **Type-safe**: No runtime type assertions needed
|
||||
|
||||
### ✅ **Clean Architecture**
|
||||
- Interface separated from implementation
|
||||
- Generic implementation in internal package
|
||||
- Easy-to-use factories for common types
|
||||
|
||||
## How It Works
|
||||
|
||||
### Generic Algorithm
|
||||
1. **Get current index** using type-specific helper function
|
||||
2. **If no change needed** → return early
|
||||
3. **Apply filter** to scope affected items
|
||||
4. **Shift affected items** using `PatchMany` with `$inc`
|
||||
5. **Update target object** using `Patch` with `$set`
|
||||
|
||||
### Type-Safe Implementation
|
||||
```go
|
||||
type IndexableDB[T storable.Storable] struct {
|
||||
repo repository.Repository
|
||||
logger mlogger.Logger
|
||||
createEmpty func() T
|
||||
getIndexable func(T) *model.Indexable
|
||||
}
|
||||
|
||||
// Single filter parameter - clean and simple
|
||||
func (db *IndexableDB[T]) Reorder(ctx context.Context, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error
|
||||
```
|
||||
|
||||
## Benefits
|
||||
|
||||
✅ **Generic** - Works with any Indexable struct
|
||||
✅ **Type Safe** - Compile-time type checking
|
||||
✅ **Simple** - Single filter parameter instead of variadic interface{}
|
||||
✅ **Efficient** - Uses patches, not full updates
|
||||
✅ **Clean** - Interface separated from implementation
|
||||
|
||||
That's it! **Generic, type-safe, and simple** reordering for any Indexable struct with a single filter parameter.
|
||||
69
api/pkg/db/internal/mongo/indexable/examples.go
Normal file
69
api/pkg/db/internal/mongo/indexable/examples.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package indexable
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
// Example usage of the generic IndexableDB with different types
|
||||
|
||||
// Example 1: Using with Project
|
||||
func ExampleProjectIndexableDB(repo repository.Repository, logger mlogger.Logger, organizationRef primitive.ObjectID) {
|
||||
// Define helper functions for Project
|
||||
createEmpty := func() *model.Project {
|
||||
return &model.Project{}
|
||||
}
|
||||
|
||||
getIndexable := func(p *model.Project) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
// Create generic IndexableDB for Project
|
||||
projectDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
|
||||
// Use with organization filter
|
||||
orgFilter := repository.OrgFilter(organizationRef)
|
||||
projectDB.Reorder(context.Background(), primitive.NewObjectID(), 2, orgFilter)
|
||||
}
|
||||
|
||||
// Example 3: Using with Task
|
||||
func ExampleTaskIndexableDB(repo repository.Repository, logger mlogger.Logger, statusRef primitive.ObjectID) {
|
||||
// Define helper functions for Task
|
||||
createEmpty := func() *model.Task {
|
||||
return &model.Task{}
|
||||
}
|
||||
|
||||
getIndexable := func(t *model.Task) *model.Indexable {
|
||||
return &t.Indexable
|
||||
}
|
||||
|
||||
// Create generic IndexableDB for Task
|
||||
taskDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
|
||||
// Use with status filter
|
||||
statusFilter := repository.Query().Comparison(repository.Field("statusRef"), builder.Eq, statusRef)
|
||||
taskDB.Reorder(context.Background(), primitive.NewObjectID(), 3, statusFilter)
|
||||
}
|
||||
|
||||
// Example 5: Using without any filter (global reordering)
|
||||
func ExampleGlobalIndexableDB(repo repository.Repository, logger mlogger.Logger) {
|
||||
// Define helper functions for any Indexable type
|
||||
createEmpty := func() *model.Project {
|
||||
return &model.Project{}
|
||||
}
|
||||
|
||||
getIndexable := func(p *model.Project) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
// Create generic IndexableDB without filters
|
||||
globalDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
|
||||
// Use without any filter - reorders all items globally
|
||||
globalDB.Reorder(context.Background(), primitive.NewObjectID(), 5, repository.Query())
|
||||
}
|
||||
122
api/pkg/db/internal/mongo/indexable/indexable.go
Normal file
122
api/pkg/db/internal/mongo/indexable/indexable.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package indexable
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// IndexableDB implements db.IndexableDB interface with generic support
|
||||
type IndexableDB[T storable.Storable] struct {
|
||||
repo repository.Repository
|
||||
logger mlogger.Logger
|
||||
createEmpty func() T
|
||||
getIndexable func(T) *model.Indexable
|
||||
}
|
||||
|
||||
// NewIndexableDB creates a new IndexableDB instance
|
||||
func NewIndexableDB[T storable.Storable](
|
||||
repo repository.Repository,
|
||||
logger mlogger.Logger,
|
||||
createEmpty func() T,
|
||||
getIndexable func(T) *model.Indexable,
|
||||
) *IndexableDB[T] {
|
||||
return &IndexableDB[T]{
|
||||
repo: repo,
|
||||
logger: logger,
|
||||
createEmpty: createEmpty,
|
||||
getIndexable: getIndexable,
|
||||
}
|
||||
}
|
||||
|
||||
// Reorder implements the db.IndexableDB interface with single filter parameter
|
||||
func (db *IndexableDB[T]) Reorder(ctx context.Context, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error {
|
||||
// Get current object to find its index
|
||||
obj := db.createEmpty()
|
||||
err := db.repo.Get(ctx, objectRef, obj)
|
||||
if err != nil {
|
||||
db.logger.Error("Failed to get object for reordering",
|
||||
zap.Error(err),
|
||||
zap.String("object_ref", objectRef.Hex()),
|
||||
zap.Int("new_index", newIndex))
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract index from the object
|
||||
indexable := db.getIndexable(obj)
|
||||
currentIndex := indexable.Index
|
||||
if currentIndex == newIndex {
|
||||
db.logger.Debug("No reordering needed - same index",
|
||||
zap.String("object_ref", objectRef.Hex()),
|
||||
zap.Int("current_index", currentIndex),
|
||||
zap.Int("new_index", newIndex))
|
||||
return nil // No change needed
|
||||
}
|
||||
|
||||
// Simple reordering logic
|
||||
if currentIndex < newIndex {
|
||||
// Moving down: shift items between currentIndex+1 and newIndex up by -1
|
||||
patch := repository.Patch().Inc(repository.IndexField(), -1)
|
||||
reorderFilter := filter.
|
||||
And(repository.IndexOpFilter(currentIndex+1, builder.Gte)).
|
||||
And(repository.IndexOpFilter(newIndex, builder.Lte))
|
||||
|
||||
updatedCount, err := db.repo.PatchMany(ctx, reorderFilter, patch)
|
||||
if err != nil {
|
||||
db.logger.Error("Failed to shift objects during reordering (moving down)",
|
||||
zap.Error(err),
|
||||
zap.String("object_ref", objectRef.Hex()),
|
||||
zap.Int("current_index", currentIndex),
|
||||
zap.Int("new_index", newIndex),
|
||||
zap.Int("updated_count", updatedCount))
|
||||
return err
|
||||
}
|
||||
db.logger.Debug("Successfully shifted objects (moving down)",
|
||||
zap.String("object_ref", objectRef.Hex()),
|
||||
zap.Int("updated_count", updatedCount))
|
||||
} else {
|
||||
// Moving up: shift items between newIndex and currentIndex-1 down by +1
|
||||
patch := repository.Patch().Inc(repository.IndexField(), 1)
|
||||
reorderFilter := filter.
|
||||
And(repository.IndexOpFilter(newIndex, builder.Gte)).
|
||||
And(repository.IndexOpFilter(currentIndex-1, builder.Lte))
|
||||
|
||||
updatedCount, err := db.repo.PatchMany(ctx, reorderFilter, patch)
|
||||
if err != nil {
|
||||
db.logger.Error("Failed to shift objects during reordering (moving up)",
|
||||
zap.Error(err),
|
||||
zap.String("object_ref", objectRef.Hex()),
|
||||
zap.Int("current_index", currentIndex),
|
||||
zap.Int("new_index", newIndex),
|
||||
zap.Int("updated_count", updatedCount))
|
||||
return err
|
||||
}
|
||||
db.logger.Debug("Successfully shifted objects (moving up)",
|
||||
zap.String("object_ref", objectRef.Hex()),
|
||||
zap.Int("updated_count", updatedCount))
|
||||
}
|
||||
|
||||
// Update the target object to new index
|
||||
patch := repository.Patch().Set(repository.IndexField(), newIndex)
|
||||
err = db.repo.Patch(ctx, objectRef, patch)
|
||||
if err != nil {
|
||||
db.logger.Error("Failed to update target object index",
|
||||
zap.Error(err),
|
||||
zap.String("object_ref", objectRef.Hex()),
|
||||
zap.Int("current_index", currentIndex),
|
||||
zap.Int("new_index", newIndex))
|
||||
return err
|
||||
}
|
||||
|
||||
db.logger.Info("Successfully reordered object",
|
||||
zap.String("object_ref", objectRef.Hex()),
|
||||
zap.Int("old_index", currentIndex),
|
||||
zap.Int("new_index", newIndex))
|
||||
return nil
|
||||
}
|
||||
314
api/pkg/db/internal/mongo/indexable/indexable_test.go
Normal file
314
api/pkg/db/internal/mongo/indexable/indexable_test.go
Normal file
@@ -0,0 +1,314 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package indexable
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) (repository.Repository, func()) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
|
||||
db := client.Database("testdb")
|
||||
repo := repository.CreateMongoRepository(db, "projects")
|
||||
|
||||
cleanup := func() {
|
||||
disconnect(ctx, t, client)
|
||||
terminate(ctx, t, mongoContainer)
|
||||
}
|
||||
|
||||
return repo, cleanup
|
||||
}
|
||||
|
||||
func disconnect(ctx context.Context, t *testing.T, client *mongo.Client) {
|
||||
if err := client.Disconnect(ctx); err != nil {
|
||||
t.Logf("failed to disconnect from MongoDB: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func terminate(ctx context.Context, t *testing.T, container testcontainers.Container) {
|
||||
if err := container.Terminate(ctx); err != nil {
|
||||
t.Logf("failed to terminate MongoDB container: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndexableDB_Reorder(t *testing.T) {
|
||||
repo, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
organizationRef := primitive.NewObjectID()
|
||||
logger := zap.NewNop()
|
||||
|
||||
// Create test projects with different indices
|
||||
projects := []*model.Project{
|
||||
{
|
||||
ProjectBase: model.ProjectBase{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organizationRef,
|
||||
},
|
||||
},
|
||||
Describable: model.Describable{Name: "Project A"},
|
||||
Indexable: model.Indexable{Index: 0},
|
||||
Mnemonic: "A",
|
||||
State: model.ProjectStateActive,
|
||||
},
|
||||
},
|
||||
{
|
||||
ProjectBase: model.ProjectBase{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organizationRef,
|
||||
},
|
||||
},
|
||||
Describable: model.Describable{Name: "Project B"},
|
||||
Indexable: model.Indexable{Index: 1},
|
||||
Mnemonic: "B",
|
||||
State: model.ProjectStateActive,
|
||||
},
|
||||
},
|
||||
{
|
||||
ProjectBase: model.ProjectBase{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organizationRef,
|
||||
},
|
||||
},
|
||||
Describable: model.Describable{Name: "Project C"},
|
||||
Indexable: model.Indexable{Index: 2},
|
||||
Mnemonic: "C",
|
||||
State: model.ProjectStateActive,
|
||||
},
|
||||
},
|
||||
{
|
||||
ProjectBase: model.ProjectBase{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organizationRef,
|
||||
},
|
||||
},
|
||||
Describable: model.Describable{Name: "Project D"},
|
||||
Indexable: model.Indexable{Index: 3},
|
||||
Mnemonic: "D",
|
||||
State: model.ProjectStateActive,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Insert projects into database
|
||||
for _, project := range projects {
|
||||
project.ID = primitive.NewObjectID()
|
||||
err := repo.Insert(ctx, project, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create helper functions for Project type
|
||||
createEmpty := func() *model.Project {
|
||||
return &model.Project{}
|
||||
}
|
||||
|
||||
getIndexable := func(p *model.Project) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
indexableDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
|
||||
t.Run("Reorder_NoChange", func(t *testing.T) {
|
||||
// Test reordering to the same position (should be no-op)
|
||||
err := indexableDB.Reorder(ctx, projects[1].ID, 1, repository.Query())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify indices haven't changed
|
||||
var result model.Project
|
||||
err = repo.Get(ctx, projects[0].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, result.Index)
|
||||
|
||||
err = repo.Get(ctx, projects[1].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, result.Index)
|
||||
})
|
||||
|
||||
t.Run("Reorder_MoveDown", func(t *testing.T) {
|
||||
// Move Project A (index 0) to index 2
|
||||
err := indexableDB.Reorder(ctx, projects[0].ID, 2, repository.Query())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the reordering:
|
||||
// Project A should now be at index 2
|
||||
// Project B should be at index 0
|
||||
// Project C should be at index 1
|
||||
// Project D should remain at index 3
|
||||
|
||||
var result model.Project
|
||||
|
||||
// Check Project A (moved to index 2)
|
||||
err = repo.Get(ctx, projects[0].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, result.Index)
|
||||
|
||||
// Check Project B (shifted to index 0)
|
||||
err = repo.Get(ctx, projects[1].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, result.Index)
|
||||
|
||||
// Check Project C (shifted to index 1)
|
||||
err = repo.Get(ctx, projects[2].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, result.Index)
|
||||
|
||||
// Check Project D (unchanged)
|
||||
err = repo.Get(ctx, projects[3].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, result.Index)
|
||||
})
|
||||
|
||||
t.Run("Reorder_MoveUp", func(t *testing.T) {
|
||||
// Reset indices for this test
|
||||
for i, project := range projects {
|
||||
project.Index = i
|
||||
err := repo.Update(ctx, project)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Move Project C (index 2) to index 0
|
||||
err := indexableDB.Reorder(ctx, projects[2].ID, 0, repository.Query())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the reordering:
|
||||
// Project C should now be at index 0
|
||||
// Project A should be at index 1
|
||||
// Project B should be at index 2
|
||||
// Project D should remain at index 3
|
||||
|
||||
var result model.Project
|
||||
|
||||
// Check Project C (moved to index 0)
|
||||
err = repo.Get(ctx, projects[2].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, result.Index)
|
||||
|
||||
// Check Project A (shifted to index 1)
|
||||
err = repo.Get(ctx, projects[0].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, result.Index)
|
||||
|
||||
// Check Project B (shifted to index 2)
|
||||
err = repo.Get(ctx, projects[1].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, result.Index)
|
||||
|
||||
// Check Project D (unchanged)
|
||||
err = repo.Get(ctx, projects[3].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, result.Index)
|
||||
})
|
||||
|
||||
t.Run("Reorder_WithFilter", func(t *testing.T) {
|
||||
// Reset indices for this test
|
||||
for i, project := range projects {
|
||||
project.Index = i
|
||||
err := repo.Update(ctx, project)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Test reordering with organization filter
|
||||
orgFilter := repository.OrgFilter(organizationRef)
|
||||
err := indexableDB.Reorder(ctx, projects[0].ID, 2, orgFilter)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the reordering worked with filter
|
||||
var result model.Project
|
||||
err = repo.Get(ctx, projects[0].ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, result.Index)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIndexableDB_EdgeCases(t *testing.T) {
|
||||
repo, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
organizationRef := primitive.NewObjectID()
|
||||
logger := zap.NewNop()
|
||||
|
||||
// Create a single project for edge case testing
|
||||
project := &model.Project{
|
||||
ProjectBase: model.ProjectBase{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organizationRef,
|
||||
},
|
||||
},
|
||||
Describable: model.Describable{Name: "Test Project"},
|
||||
Indexable: model.Indexable{Index: 0},
|
||||
Mnemonic: "TEST",
|
||||
State: model.ProjectStateActive,
|
||||
},
|
||||
}
|
||||
project.ID = primitive.NewObjectID()
|
||||
err := repo.Insert(ctx, project, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create helper functions for Project type
|
||||
createEmpty := func() *model.Project {
|
||||
return &model.Project{}
|
||||
}
|
||||
|
||||
getIndexable := func(p *model.Project) *model.Indexable {
|
||||
return &p.Indexable
|
||||
}
|
||||
|
||||
indexableDB := NewIndexableDB(repo, logger, createEmpty, getIndexable)
|
||||
|
||||
t.Run("Reorder_SingleItem", func(t *testing.T) {
|
||||
// Test reordering a single item (should work but have no effect)
|
||||
err := indexableDB.Reorder(ctx, project.ID, 0, repository.Query())
|
||||
require.NoError(t, err)
|
||||
|
||||
var result model.Project
|
||||
err = repo.Get(ctx, project.ID, &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, result.Index)
|
||||
})
|
||||
|
||||
t.Run("Reorder_InvalidObjectID", func(t *testing.T) {
|
||||
// Test reordering with an invalid object ID
|
||||
invalidID := primitive.NewObjectID()
|
||||
err := indexableDB.Reorder(ctx, invalidID, 1, repository.Query())
|
||||
require.Error(t, err) // Should fail because object doesn't exist
|
||||
})
|
||||
}
|
||||
12
api/pkg/db/internal/mongo/invitationdb/accept.go
Normal file
12
api/pkg/db/internal/mongo/invitationdb/accept.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package invitationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *InvitationDB) Accept(ctx context.Context, invitationRef primitive.ObjectID) error {
|
||||
return db.updateStatus(ctx, invitationRef, model.InvitationAccepted)
|
||||
}
|
||||
49
api/pkg/db/internal/mongo/invitationdb/archived.go
Normal file
49
api/pkg/db/internal/mongo/invitationdb/archived.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package invitationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SetArchived sets the archived status of an invitation
|
||||
// Invitation supports archiving through PermissionBound embedding ArchivableBase
|
||||
func (db *InvitationDB) SetArchived(ctx context.Context, accountRef, organizationRef, invitationRef primitive.ObjectID, archived, cascade bool) error {
|
||||
db.DBImp.Logger.Debug("Setting invitation archived status", mzap.ObjRef("invitation_ref", invitationRef), zap.Bool("archived", archived), zap.Bool("cascade", cascade))
|
||||
res, err := db.Enforcer.Enforce(ctx, db.PermissionRef, accountRef, organizationRef, invitationRef, model.ActionUpdate)
|
||||
if err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to enforce archivation permission", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return err
|
||||
}
|
||||
if !res {
|
||||
db.DBImp.Logger.Debug("Permission denied for archivation", mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return merrors.AccessDenied(db.Collection, string(model.ActionUpdate), invitationRef)
|
||||
}
|
||||
|
||||
// Get the invitation first
|
||||
var invitation model.Invitation
|
||||
if err := db.Get(ctx, accountRef, invitationRef, &invitation); err != nil {
|
||||
db.DBImp.Logger.Warn("Error retrieving invitation for archival", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the invitation's archived status
|
||||
invitation.SetArchived(archived)
|
||||
if err := db.Update(ctx, accountRef, &invitation); err != nil {
|
||||
db.DBImp.Logger.Warn("Error updating invitation archived status", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return err
|
||||
}
|
||||
|
||||
// Note: Currently no cascade dependencies for invitations
|
||||
// If cascade is enabled, we could add logic here for any future dependencies
|
||||
if cascade {
|
||||
db.DBImp.Logger.Debug("Cascade archiving requested but no dependencies to archive for invitation", mzap.ObjRef("invitation_ref", invitationRef))
|
||||
}
|
||||
|
||||
db.DBImp.Logger.Debug("Successfully set invitation archived status", mzap.ObjRef("invitation_ref", invitationRef), zap.Bool("archived", archived))
|
||||
return nil
|
||||
}
|
||||
24
api/pkg/db/internal/mongo/invitationdb/cascade.go
Normal file
24
api/pkg/db/internal/mongo/invitationdb/cascade.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package invitationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// DeleteCascade deletes an invitation
|
||||
// Invitations don't have cascade dependencies, so this is a simple deletion
|
||||
func (db *InvitationDB) DeleteCascade(ctx context.Context, accountRef, invitationRef primitive.ObjectID) error {
|
||||
db.DBImp.Logger.Debug("Starting invitation cascade deletion", mzap.ObjRef("invitation_ref", invitationRef))
|
||||
|
||||
// Delete the invitation itself (no dependencies to cascade delete)
|
||||
if err := db.Delete(ctx, accountRef, invitationRef); err != nil {
|
||||
db.DBImp.Logger.Error("Error deleting invitation", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return err
|
||||
}
|
||||
|
||||
db.DBImp.Logger.Debug("Successfully deleted invitation", mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return nil
|
||||
}
|
||||
53
api/pkg/db/internal/mongo/invitationdb/db.go
Normal file
53
api/pkg/db/internal/mongo/invitationdb/db.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package invitationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/auth"
|
||||
"github.com/tech/sendico/pkg/db/policy"
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
ri "github.com/tech/sendico/pkg/db/repository/index"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type InvitationDB struct {
|
||||
auth.ProtectedDBImp[*model.Invitation]
|
||||
}
|
||||
|
||||
func Create(
|
||||
ctx context.Context,
|
||||
logger mlogger.Logger,
|
||||
enforcer auth.Enforcer,
|
||||
pdb policy.DB,
|
||||
db *mongo.Database,
|
||||
) (*InvitationDB, error) {
|
||||
p, err := auth.CreateDBImp[*model.Invitation](ctx, logger, pdb, enforcer, mservice.Invitations, db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// unique email per organization
|
||||
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: repository.OrgField().Build(), Sort: ri.Asc}, {Field: "description.email", Sort: ri.Asc}},
|
||||
Unique: true,
|
||||
}); err != nil {
|
||||
p.DBImp.Logger.Error("Failed to create unique mnemonic index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// ttl index
|
||||
ttl := int32(0) // zero ttl means expiration on date preset when inserting data
|
||||
if err := p.DBImp.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: "expiresAt", Sort: ri.Asc}},
|
||||
TTL: &ttl,
|
||||
}); err != nil {
|
||||
p.DBImp.Logger.Warn("Failed to create ttl index in the invitations", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &InvitationDB{ProtectedDBImp: *p}, nil
|
||||
}
|
||||
12
api/pkg/db/internal/mongo/invitationdb/decline.go
Normal file
12
api/pkg/db/internal/mongo/invitationdb/decline.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package invitationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *InvitationDB) Decline(ctx context.Context, invitationRef primitive.ObjectID) error {
|
||||
return db.updateStatus(ctx, invitationRef, model.InvitationDeclined)
|
||||
}
|
||||
121
api/pkg/db/internal/mongo/invitationdb/getpublic.go
Normal file
121
api/pkg/db/internal/mongo/invitationdb/getpublic.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package invitationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (db *InvitationDB) GetPublic(ctx context.Context, invitationRef primitive.ObjectID) (*model.PublicInvitation, error) {
|
||||
roleField := repository.Field("role")
|
||||
orgField := repository.Field("organization")
|
||||
accField := repository.Field("account")
|
||||
empField := repository.Field("employee")
|
||||
regField := repository.Field("registrationAcc")
|
||||
descEmailField := repository.Field("description").Dot("email")
|
||||
pipeline := repository.Pipeline().
|
||||
// 0) Filter to exactly the invitation(s) you want
|
||||
Match(repository.IDFilter(invitationRef).And(repository.Filter("status", model.InvitationCreated))).
|
||||
// 1) Lookup the role document
|
||||
Lookup(
|
||||
mservice.Roles,
|
||||
repository.Field("roleRef"),
|
||||
repository.IDField(),
|
||||
roleField,
|
||||
).
|
||||
Unwind(repository.Ref(roleField)).
|
||||
// 2) Lookup the organization document
|
||||
Lookup(
|
||||
mservice.Organizations,
|
||||
repository.Field("organizationRef"),
|
||||
repository.IDField(),
|
||||
orgField,
|
||||
).
|
||||
Unwind(repository.Ref(orgField)).
|
||||
// 3) Lookup the account document
|
||||
Lookup(
|
||||
mservice.Accounts,
|
||||
repository.Field("inviterRef"),
|
||||
repository.IDField(),
|
||||
accField,
|
||||
).
|
||||
Unwind(repository.Ref(accField)).
|
||||
/* 4) do we already have an account whose login == invitation.description ? */
|
||||
Lookup(
|
||||
mservice.Accounts,
|
||||
descEmailField, // local field (invitation.description.email)
|
||||
repository.Field("login"), // foreign field (account.login)
|
||||
regField, // array: 0-length or ≥1
|
||||
).
|
||||
// 5) Projection
|
||||
Project(
|
||||
repository.SimpleAlias(
|
||||
empField.Dot("description"),
|
||||
repository.Ref(accField),
|
||||
),
|
||||
repository.SimpleAlias(
|
||||
empField.Dot("avatarUrl"),
|
||||
repository.Ref(accField.Dot("avatarUrl")),
|
||||
),
|
||||
repository.SimpleAlias(
|
||||
orgField.Dot("description"),
|
||||
repository.Ref(orgField),
|
||||
),
|
||||
repository.SimpleAlias(
|
||||
orgField.Dot("logoUrl"),
|
||||
repository.Ref(orgField.Dot("logoUrl")),
|
||||
),
|
||||
repository.SimpleAlias(
|
||||
roleField,
|
||||
repository.Ref(roleField),
|
||||
),
|
||||
repository.SimpleAlias(
|
||||
repository.Field("invitation"), // ← left-hand side
|
||||
repository.Ref(repository.Field("description")), // ← right-hand side (“$description”)
|
||||
),
|
||||
repository.SimpleAlias(
|
||||
repository.Field("storable"), // ← left-hand side
|
||||
repository.RootRef(), // ← right-hand side (“$description”)
|
||||
),
|
||||
repository.ProjectionExpr(
|
||||
repository.Field("registrationRequired"),
|
||||
repository.Eq(
|
||||
repository.Size(repository.Value(repository.Ref(regField).Build())),
|
||||
repository.Literal(0),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
var res model.PublicInvitation
|
||||
haveResult := false
|
||||
decoder := func(cur *mongo.Cursor) error {
|
||||
if haveResult {
|
||||
// should never get here
|
||||
db.DBImp.Logger.Warn("Unexpected extra invitation", mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return merrors.Internal("Unexpected extra invitation found by reference")
|
||||
}
|
||||
if e := cur.Decode(&res); e != nil {
|
||||
db.DBImp.Logger.Warn("Failed to decode entity", zap.Error(e), zap.Any("data", cur.Current.String()))
|
||||
return e
|
||||
}
|
||||
haveResult = true
|
||||
return nil
|
||||
}
|
||||
if err := db.DBImp.Repository.Aggregate(ctx, pipeline, decoder); err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to execute aggregation pipeline", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return nil, err
|
||||
}
|
||||
if !haveResult {
|
||||
db.DBImp.Logger.Warn("No results fetched", mzap.ObjRef("invitation_ref", invitationRef))
|
||||
return nil, merrors.NoData(fmt.Sprintf("Invitation %s not found", invitationRef.Hex()))
|
||||
}
|
||||
return &res, nil
|
||||
}
|
||||
28
api/pkg/db/internal/mongo/invitationdb/list.go
Normal file
28
api/pkg/db/internal/mongo/invitationdb/list.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package invitationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
mauth "github.com/tech/sendico/pkg/mutil/db/auth"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *InvitationDB) List(ctx context.Context, accountRef, organizationRef, _ primitive.ObjectID, cursor *model.ViewCursor) ([]model.Invitation, error) {
|
||||
res, err := mauth.GetProtectedObjects[model.Invitation](
|
||||
ctx,
|
||||
db.DBImp.Logger,
|
||||
accountRef, organizationRef, model.ActionRead,
|
||||
repository.OrgFilter(organizationRef),
|
||||
cursor,
|
||||
db.Enforcer,
|
||||
db.DBImp.Repository,
|
||||
)
|
||||
if errors.Is(err, merrors.ErrNoData) {
|
||||
return []model.Invitation{}, nil
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
26
api/pkg/db/internal/mongo/invitationdb/updatestatus.go
Normal file
26
api/pkg/db/internal/mongo/invitationdb/updatestatus.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package invitationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (db *InvitationDB) updateStatus(ctx context.Context, invitationRef primitive.ObjectID, newStatus model.InvitationStatus) error {
|
||||
// db.DBImp.Up
|
||||
var inv model.Invitation
|
||||
if err := db.DBImp.FindOne(ctx, repository.IDFilter(invitationRef), &inv); err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to fetch invitation", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef), zap.String("new_status", string(newStatus)))
|
||||
return err
|
||||
}
|
||||
inv.Status = newStatus
|
||||
if err := db.DBImp.Update(ctx, &inv); err != nil {
|
||||
db.DBImp.Logger.Warn("Failed to update invitation", zap.Error(err), mzap.ObjRef("invitation_ref", invitationRef), zap.String("new_status", string(newStatus)))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
22
api/pkg/db/internal/mongo/mongo.go
Normal file
22
api/pkg/db/internal/mongo/mongo.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func buildURI(s *DBSettings) string {
|
||||
u := &url.URL{
|
||||
Scheme: "mongodb",
|
||||
Host: s.Host,
|
||||
Path: "/" + url.PathEscape(s.Database), // /my%20db
|
||||
}
|
||||
|
||||
q := url.Values{}
|
||||
if s.ReplicaSet != "" {
|
||||
q.Set("replicaSet", s.ReplicaSet)
|
||||
}
|
||||
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String()
|
||||
}
|
||||
32
api/pkg/db/internal/mongo/organizationdb/archived.go
Normal file
32
api/pkg/db/internal/mongo/organizationdb/archived.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package organizationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SetArchived sets the archived status of an organization and optionally cascades to projects, tasks, comments, and reactions
|
||||
func (db *OrganizationDB) SetArchived(ctx context.Context, accountRef, organizationRef primitive.ObjectID, archived, cascade bool) error {
|
||||
db.DBImp.Logger.Debug("Setting organization archived status", mzap.ObjRef("organization_ref", organizationRef), zap.Bool("archived", archived), zap.Bool("cascade", cascade))
|
||||
|
||||
// Get the organization first
|
||||
var organization model.Organization
|
||||
if err := db.Get(ctx, accountRef, organizationRef, &organization); err != nil {
|
||||
db.DBImp.Logger.Warn("Error retrieving organization for archival", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the organization's archived status
|
||||
organization.SetArchived(archived)
|
||||
if err := db.Update(ctx, accountRef, &organization); err != nil {
|
||||
db.DBImp.Logger.Warn("Error updating organization archived status", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
|
||||
return err
|
||||
}
|
||||
|
||||
db.DBImp.Logger.Debug("Successfully set organization archived status", mzap.ObjRef("organization_ref", organizationRef), zap.Bool("archived", archived))
|
||||
return nil
|
||||
}
|
||||
23
api/pkg/db/internal/mongo/organizationdb/cascade.go
Normal file
23
api/pkg/db/internal/mongo/organizationdb/cascade.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package organizationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// DeleteCascade deletes an organization and all its related data (projects, tasks, comments, reactions, statuses)
|
||||
func (db *OrganizationDB) DeleteCascade(ctx context.Context, organizationRef primitive.ObjectID) error {
|
||||
db.DBImp.Logger.Debug("Starting organization deletion with projects", mzap.ObjRef("organization_ref", organizationRef))
|
||||
|
||||
// Delete the organization itself
|
||||
if err := db.Unprotected().Delete(ctx, organizationRef); err != nil {
|
||||
db.DBImp.Logger.Warn("Error deleting organization", zap.Error(err), mzap.ObjRef("organization_ref", organizationRef))
|
||||
return err
|
||||
}
|
||||
|
||||
db.DBImp.Logger.Debug("Successfully deleted organization with projects", mzap.ObjRef("organization_ref", organizationRef))
|
||||
return nil
|
||||
}
|
||||
19
api/pkg/db/internal/mongo/organizationdb/create.go
Normal file
19
api/pkg/db/internal/mongo/organizationdb/create.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package organizationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *OrganizationDB) Create(ctx context.Context, _, _ primitive.ObjectID, org *model.Organization) error {
|
||||
if org == nil {
|
||||
return merrors.InvalidArgument("Organization object is nil")
|
||||
}
|
||||
org.SetID(primitive.NewObjectID())
|
||||
// Organizaiton reference must be set to the same value as own organization reference
|
||||
org.SetOrganizationRef(*org.GetID())
|
||||
return db.DBImp.Create(ctx, org)
|
||||
}
|
||||
34
api/pkg/db/internal/mongo/organizationdb/db.go
Normal file
34
api/pkg/db/internal/mongo/organizationdb/db.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package organizationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/auth"
|
||||
"github.com/tech/sendico/pkg/db/policy"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
type OrganizationDB struct {
|
||||
auth.ProtectedDBImp[*model.Organization]
|
||||
}
|
||||
|
||||
func Create(ctx context.Context,
|
||||
logger mlogger.Logger,
|
||||
enforcer auth.Enforcer,
|
||||
pdb policy.DB,
|
||||
db *mongo.Database,
|
||||
) (*OrganizationDB, error) {
|
||||
p, err := auth.CreateDBImp[*model.Organization](ctx, logger, pdb, enforcer, mservice.Organizations, db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := &OrganizationDB{
|
||||
ProtectedDBImp: *p,
|
||||
}
|
||||
p.DBImp.SetDeleter(res.DeleteCascade)
|
||||
return res, nil
|
||||
}
|
||||
12
api/pkg/db/internal/mongo/organizationdb/get.go
Normal file
12
api/pkg/db/internal/mongo/organizationdb/get.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package organizationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *OrganizationDB) GetByRef(ctx context.Context, organizationRef primitive.ObjectID, org *model.Organization) error {
|
||||
return db.Unprotected().Get(ctx, organizationRef, org)
|
||||
}
|
||||
16
api/pkg/db/internal/mongo/organizationdb/list.go
Normal file
16
api/pkg/db/internal/mongo/organizationdb/list.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package organizationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
mutil "github.com/tech/sendico/pkg/mutil/db"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *OrganizationDB) List(ctx context.Context, accountRef primitive.ObjectID, cursor *model.ViewCursor) ([]model.Organization, error) {
|
||||
filter := repository.Query().Comparison(repository.Field("members"), builder.Eq, accountRef)
|
||||
return mutil.GetObjects[model.Organization](ctx, db.DBImp.Logger, filter, cursor, db.DBImp.Repository)
|
||||
}
|
||||
14
api/pkg/db/internal/mongo/organizationdb/owned.go
Normal file
14
api/pkg/db/internal/mongo/organizationdb/owned.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package organizationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
mutil "github.com/tech/sendico/pkg/mutil/db"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *OrganizationDB) ListOwned(ctx context.Context, accountRef primitive.ObjectID) ([]model.Organization, error) {
|
||||
return mutil.GetObjects[model.Organization](ctx, db.DBImp.Logger, repository.Filter("ownerRef", accountRef), nil, db.DBImp.Repository)
|
||||
}
|
||||
562
api/pkg/db/internal/mongo/organizationdb/setarchived_test.go
Normal file
562
api/pkg/db/internal/mongo/organizationdb/setarchived_test.go
Normal file
@@ -0,0 +1,562 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package organizationdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/commentdb"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/projectdb"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/reactiondb"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/statusdb"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/taskdb"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func setupSetArchivedTestDB(t *testing.T) (*OrganizationDB, *projectDBAdapter, *taskdb.TaskDB, *commentdb.CommentDB, *reactiondb.ReactionDB, func()) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Start MongoDB container
|
||||
mongodbContainer, err := mongodb.Run(ctx, "mongo:latest")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get connection string
|
||||
endpoint, err := mongodbContainer.Endpoint(ctx, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connect to MongoDB
|
||||
client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://"+endpoint))
|
||||
require.NoError(t, err)
|
||||
|
||||
db := client.Database("test_organization_setarchived")
|
||||
logger := zap.NewNop()
|
||||
|
||||
// Create mock enforcer and policy DB
|
||||
mockEnforcer := &mockSetArchivedEnforcer{}
|
||||
mockPolicyDB := &mockSetArchivedPolicyDB{}
|
||||
mockPGroupDB := &mockSetArchivedPGroupDB{}
|
||||
|
||||
// Create databases
|
||||
// We need to create a projectDB first, but we'll create a temporary one for organizationDB creation
|
||||
// Create temporary taskDB and statusDB for the temporary projectDB
|
||||
// Create temporary reactionDB and commentDB for the temporary taskDB
|
||||
tempReactionDB, err := reactiondb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
tempCommentDB, err := commentdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db, tempReactionDB)
|
||||
require.NoError(t, err)
|
||||
|
||||
tempTaskDB, err := taskdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db, tempCommentDB, tempReactionDB)
|
||||
require.NoError(t, err)
|
||||
|
||||
tempStatusDB, err := statusdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
tempProjectDB, err := projectdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, tempTaskDB, tempStatusDB, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create adapter for organizationDB creation
|
||||
tempProjectDBAdapter := &projectDBAdapter{
|
||||
ProjectDB: tempProjectDB,
|
||||
taskDB: tempTaskDB,
|
||||
commentDB: tempCommentDB,
|
||||
reactionDB: tempReactionDB,
|
||||
statusDB: tempStatusDB,
|
||||
}
|
||||
|
||||
organizationDB, err := Create(ctx, logger, mockEnforcer, mockPolicyDB, tempProjectDBAdapter, mockPGroupDB, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
var projectDB *projectdb.ProjectDB
|
||||
var taskDB *taskdb.TaskDB
|
||||
var commentDB *commentdb.CommentDB
|
||||
var reactionDB *reactiondb.ReactionDB
|
||||
var statusDB *statusdb.StatusDB
|
||||
|
||||
// Create databases in dependency order
|
||||
reactionDB, err = reactiondb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
commentDB, err = commentdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db, reactionDB)
|
||||
require.NoError(t, err)
|
||||
|
||||
taskDB, err = taskdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db, commentDB, reactionDB)
|
||||
require.NoError(t, err)
|
||||
|
||||
statusDB, err = statusdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
projectDB, err = projectdb.Create(ctx, logger, mockEnforcer, mockPolicyDB, taskDB, statusDB, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create adapter for the actual projectDB
|
||||
projectDBAdapter := &projectDBAdapter{
|
||||
ProjectDB: projectDB,
|
||||
taskDB: taskDB,
|
||||
commentDB: commentDB,
|
||||
reactionDB: reactionDB,
|
||||
statusDB: statusDB,
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
client.Disconnect(context.Background())
|
||||
mongodbContainer.Terminate(ctx)
|
||||
}
|
||||
|
||||
return organizationDB, projectDBAdapter, taskDB, commentDB, reactionDB, cleanup
|
||||
}
|
||||
|
||||
// projectDBAdapter adapts projectdb.ProjectDB to project.DB interface for testing
|
||||
type projectDBAdapter struct {
|
||||
*projectdb.ProjectDB
|
||||
taskDB *taskdb.TaskDB
|
||||
commentDB *commentdb.CommentDB
|
||||
reactionDB *reactiondb.ReactionDB
|
||||
statusDB *statusdb.StatusDB
|
||||
}
|
||||
|
||||
// DeleteCascade implements the project.DB interface
|
||||
func (a *projectDBAdapter) DeleteCascade(ctx context.Context, projectRef primitive.ObjectID) error {
|
||||
// Call the concrete implementation
|
||||
return a.ProjectDB.DeleteCascade(ctx, projectRef)
|
||||
}
|
||||
|
||||
// SetArchived implements the project.DB interface
|
||||
func (a *projectDBAdapter) SetArchived(ctx context.Context, accountRef, organizationRef, projectRef primitive.ObjectID, archived, cascade bool) error {
|
||||
// Use the stored dependencies for the concrete implementation
|
||||
return a.ProjectDB.SetArchived(ctx, accountRef, organizationRef, projectRef, archived, cascade)
|
||||
}
|
||||
|
||||
// List implements the project.DB interface
|
||||
func (a *projectDBAdapter) List(ctx context.Context, accountRef, organizationRef, _ primitive.ObjectID, cursor *model.ViewCursor) ([]model.Project, error) {
|
||||
return a.ProjectDB.List(ctx, accountRef, organizationRef, primitive.NilObjectID, cursor)
|
||||
}
|
||||
|
||||
// Previews implements the project.DB interface
|
||||
func (a *projectDBAdapter) Previews(ctx context.Context, accountRef, organizationRef primitive.ObjectID, projectRefs []primitive.ObjectID, cursor *model.ViewCursor, assigneeRefs, reporterRefs []primitive.ObjectID) ([]model.ProjectPreview, error) {
|
||||
return a.ProjectDB.Previews(ctx, accountRef, organizationRef, projectRefs, cursor, assigneeRefs, reporterRefs)
|
||||
}
|
||||
|
||||
// DeleteProject implements the project.DB interface
|
||||
func (a *projectDBAdapter) DeleteProject(ctx context.Context, accountRef, organizationRef, projectRef primitive.ObjectID, migrateToRef *primitive.ObjectID) error {
|
||||
// Call the concrete implementation with the organizationRef
|
||||
return a.ProjectDB.DeleteProject(ctx, accountRef, organizationRef, projectRef, migrateToRef)
|
||||
}
|
||||
|
||||
// RemoveTagFromProjects implements the project.DB interface
|
||||
func (a *projectDBAdapter) RemoveTagFromProjects(ctx context.Context, accountRef, organizationRef, tagRef primitive.ObjectID) error {
|
||||
// Call the concrete implementation
|
||||
return a.ProjectDB.RemoveTagFromProjects(ctx, accountRef, organizationRef, tagRef)
|
||||
}
|
||||
|
||||
// Mock implementations for SetArchived testing
|
||||
type mockSetArchivedEnforcer struct{}
|
||||
|
||||
func (m *mockSetArchivedEnforcer) Enforce(ctx context.Context, permissionRef, accountRef, orgRef, objectRef primitive.ObjectID, action model.Action) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedEnforcer) EnforceBatch(ctx context.Context, objectRefs []model.PermissionBoundStorable, accountRef primitive.ObjectID, action model.Action) (map[primitive.ObjectID]bool, error) {
|
||||
// Allow all objects for testing
|
||||
result := make(map[primitive.ObjectID]bool)
|
||||
for _, obj := range objectRefs {
|
||||
result[*obj.GetID()] = true
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedEnforcer) GetRoles(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]model.Role, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedEnforcer) GetPermissions(ctx context.Context, accountRef, organizationRef primitive.ObjectID) ([]model.Role, []model.Permission, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
type mockSetArchivedPolicyDB struct{}
|
||||
|
||||
func (m *mockSetArchivedPolicyDB) Create(ctx context.Context, policy *model.PolicyDescription) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPolicyDB) Get(ctx context.Context, policyRef primitive.ObjectID, result *model.PolicyDescription) error {
|
||||
return merrors.ErrNoData
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPolicyDB) InsertMany(ctx context.Context, objects []*model.PolicyDescription) error { return nil }
|
||||
|
||||
func (m *mockSetArchivedPolicyDB) Update(ctx context.Context, policy *model.PolicyDescription) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPolicyDB) Patch(ctx context.Context, objectRef primitive.ObjectID, patch builder.Patch) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPolicyDB) Delete(ctx context.Context, policyRef primitive.ObjectID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPolicyDB) DeleteMany(ctx context.Context, filter builder.Query) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPolicyDB) FindOne(ctx context.Context, filter builder.Query, result *model.PolicyDescription) error {
|
||||
return merrors.ErrNoData
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPolicyDB) ListIDs(ctx context.Context, query builder.Query) ([]primitive.ObjectID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPolicyDB) ListPermissionBound(ctx context.Context, query builder.Query) ([]model.PermissionBoundStorable, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockSetArchivedPolicyDB) Collection() string { return "" }
|
||||
func (m *mockSetArchivedPolicyDB) All(ctx context.Context, organizationRef primitive.ObjectID) ([]model.PolicyDescription, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPolicyDB) Policies(ctx context.Context, refs []primitive.ObjectID) ([]model.PolicyDescription, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPolicyDB) GetBuiltInPolicy(ctx context.Context, resourceType mservice.Type, policy *model.PolicyDescription) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPolicyDB) DeleteCascade(ctx context.Context, policyRef primitive.ObjectID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockSetArchivedPGroupDB struct{}
|
||||
|
||||
func (m *mockSetArchivedPGroupDB) Create(ctx context.Context, accountRef, organizationRef primitive.ObjectID, pgroup *model.PriorityGroup) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockSetArchivedPGroupDB) InsertMany(ctx context.Context, accountRef, organizationRef primitive.ObjectID, objects []*model.PriorityGroup) error { return nil }
|
||||
|
||||
func (m *mockSetArchivedPGroupDB) Get(ctx context.Context, accountRef, pgroupRef primitive.ObjectID, result *model.PriorityGroup) error {
|
||||
return merrors.ErrNoData
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPGroupDB) Update(ctx context.Context, accountRef primitive.ObjectID, pgroup *model.PriorityGroup) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPGroupDB) Delete(ctx context.Context, accountRef, pgroupRef primitive.ObjectID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPGroupDB) DeleteCascadeAuth(ctx context.Context, accountRef, pgroupRef primitive.ObjectID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPGroupDB) Patch(ctx context.Context, accountRef, pgroupRef primitive.ObjectID, patch builder.Patch) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPGroupDB) PatchMany(ctx context.Context, accountRef primitive.ObjectID, query builder.Query, patch builder.Patch) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPGroupDB) Unprotected() template.DB[*model.PriorityGroup] {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPGroupDB) ListIDs(ctx context.Context, action model.Action, accountRef primitive.ObjectID, query builder.Query) ([]primitive.ObjectID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPGroupDB) All(ctx context.Context, organizationRef primitive.ObjectID, limit, offset *int64) ([]model.PriorityGroup, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPGroupDB) List(ctx context.Context, accountRef, organizationRef, _ primitive.ObjectID, cursor *model.ViewCursor) ([]model.PriorityGroup, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPGroupDB) DeleteCascade(ctx context.Context, statusRef primitive.ObjectID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPGroupDB) SetArchived(ctx context.Context, accountRef, organizationRef, statusRef primitive.ObjectID, archived, cascade bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSetArchivedPGroupDB) Reorder(ctx context.Context, accountRef, priorityGroupRef primitive.ObjectID, oldIndex, newIndex int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Mock project DB for statusdb creation
|
||||
type mockSetArchivedProjectDB struct{}
|
||||
|
||||
func (m *mockSetArchivedProjectDB) Create(ctx context.Context, accountRef, organizationRef primitive.ObjectID, project *model.Project) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) Get(ctx context.Context, accountRef, projectRef primitive.ObjectID, result *model.Project) error {
|
||||
return merrors.ErrNoData
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) Update(ctx context.Context, accountRef primitive.ObjectID, project *model.Project) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) Delete(ctx context.Context, accountRef, projectRef primitive.ObjectID) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) DeleteCascadeAuth(ctx context.Context, accountRef, projectRef primitive.ObjectID) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) Patch(ctx context.Context, accountRef, objectRef primitive.ObjectID, patch builder.Patch) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) PatchMany(ctx context.Context, accountRef primitive.ObjectID, query builder.Query, patch builder.Patch) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) Unprotected() template.DB[*model.Project] { return nil }
|
||||
func (m *mockSetArchivedProjectDB) ListIDs(ctx context.Context, action model.Action, accountRef primitive.ObjectID, query builder.Query) ([]primitive.ObjectID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) List(ctx context.Context, accountRef, organizationRef, _ primitive.ObjectID, cursor *model.ViewCursor) ([]model.Project, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) Previews(ctx context.Context, accountRef, organizationRef primitive.ObjectID, projectRefs []primitive.ObjectID, cursor *model.ViewCursor, assigneeRefs, reporterRefs []primitive.ObjectID) ([]model.ProjectPreview, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) DeleteProject(ctx context.Context, accountRef, organizationRef, projectRef primitive.ObjectID, migrateToRef *primitive.ObjectID) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) DeleteCascade(ctx context.Context, projectRef primitive.ObjectID) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) SetArchived(ctx context.Context, accountRef, organizationRef, projectRef primitive.ObjectID, archived, cascade bool) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) All(ctx context.Context, organizationRef primitive.ObjectID, limit, offset *int64) ([]model.Project, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) Reorder(ctx context.Context, accountRef, objectRef primitive.ObjectID, newIndex int, filter builder.Query) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) AddTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) RemoveTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) RemoveTags(ctx context.Context, accountRef, organizationRef, tagRef primitive.ObjectID) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) AddTags(ctx context.Context, accountRef, objectRef primitive.ObjectID, tagRefs []primitive.ObjectID) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) SetTags(ctx context.Context, accountRef, objectRef primitive.ObjectID, tagRefs []primitive.ObjectID) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) RemoveAllTags(ctx context.Context, accountRef, objectRef primitive.ObjectID) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) GetTags(ctx context.Context, accountRef, objectRef primitive.ObjectID) ([]primitive.ObjectID, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) HasTag(ctx context.Context, accountRef, objectRef, tagRef primitive.ObjectID) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) FindByTag(ctx context.Context, accountRef, tagRef primitive.ObjectID) ([]*model.Project, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockSetArchivedProjectDB) FindByTags(ctx context.Context, accountRef primitive.ObjectID, tagRefs []primitive.ObjectID) ([]*model.Project, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestOrganizationDB_SetArchived(t *testing.T) {
|
||||
organizationDB, projectDBAdapter, taskDB, commentDB, reactionDB, cleanup := setupSetArchivedTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
accountRef := primitive.NewObjectID()
|
||||
|
||||
t.Run("SetArchived_OrganizationWithProjectsTasksCommentsAndReactions_Cascade", func(t *testing.T) {
|
||||
// Create an organization using unprotected DB
|
||||
organization := &model.Organization{
|
||||
OrganizationBase: model.OrganizationBase{
|
||||
Describable: model.Describable{Name: "Test Organization for Archive"},
|
||||
TimeZone: "UTC",
|
||||
},
|
||||
}
|
||||
organization.ID = primitive.NewObjectID()
|
||||
|
||||
err := organizationDB.Create(ctx, accountRef, organization.ID, organization)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a project for the organization using unprotected DB
|
||||
project := &model.Project{
|
||||
ProjectBase: model.ProjectBase{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organization.ID,
|
||||
},
|
||||
},
|
||||
Describable: model.Describable{Name: "Test Project"},
|
||||
Indexable: model.Indexable{Index: 0},
|
||||
Mnemonic: "TEST",
|
||||
State: model.ProjectStateActive,
|
||||
},
|
||||
}
|
||||
project.ID = primitive.NewObjectID()
|
||||
|
||||
err = projectDBAdapter.Unprotected().Create(ctx, project)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create a task for the project using unprotected DB
|
||||
task := &model.Task{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organization.ID,
|
||||
},
|
||||
},
|
||||
Describable: model.Describable{Name: "Test Task for Archive"},
|
||||
ProjectRef: project.ID,
|
||||
}
|
||||
task.ID = primitive.NewObjectID()
|
||||
|
||||
err = taskDB.Unprotected().Create(ctx, task)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create comments for the task using unprotected DB
|
||||
comment := &model.Comment{
|
||||
CommentBase: model.CommentBase{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organization.ID,
|
||||
},
|
||||
},
|
||||
AuthorRef: accountRef,
|
||||
TaskRef: task.ID,
|
||||
Content: "Test Comment for Archive",
|
||||
},
|
||||
}
|
||||
comment.ID = primitive.NewObjectID()
|
||||
|
||||
err = commentDB.Unprotected().Create(ctx, comment)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create reaction for the comment using unprotected DB
|
||||
reaction := &model.Reaction{
|
||||
PermissionBound: model.PermissionBound{
|
||||
OrganizationBoundBase: model.OrganizationBoundBase{
|
||||
OrganizationRef: organization.ID,
|
||||
},
|
||||
},
|
||||
Type: "like",
|
||||
AuthorRef: accountRef,
|
||||
CommentRef: comment.ID,
|
||||
}
|
||||
reaction.ID = primitive.NewObjectID()
|
||||
|
||||
err = reactionDB.Unprotected().Create(ctx, reaction)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all entities are not archived initially
|
||||
var retrievedOrganization model.Organization
|
||||
err = organizationDB.Get(ctx, accountRef, organization.ID, &retrievedOrganization)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, retrievedOrganization.IsArchived())
|
||||
|
||||
var retrievedProject model.Project
|
||||
err = projectDBAdapter.Unprotected().Get(ctx, project.ID, &retrievedProject)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, retrievedProject.IsArchived())
|
||||
|
||||
var retrievedTask model.Task
|
||||
err = taskDB.Unprotected().Get(ctx, task.ID, &retrievedTask)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, retrievedTask.IsArchived())
|
||||
|
||||
var retrievedComment model.Comment
|
||||
err = commentDB.Unprotected().Get(ctx, comment.ID, &retrievedComment)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, retrievedComment.IsArchived())
|
||||
|
||||
// Archive organization with cascade
|
||||
err = organizationDB.SetArchived(ctx, accountRef, organization.ID, true, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all entities are archived due to cascade
|
||||
err = organizationDB.Get(ctx, accountRef, organization.ID, &retrievedOrganization)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, retrievedOrganization.IsArchived())
|
||||
|
||||
err = projectDBAdapter.Unprotected().Get(ctx, project.ID, &retrievedProject)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, retrievedProject.IsArchived())
|
||||
|
||||
err = taskDB.Unprotected().Get(ctx, task.ID, &retrievedTask)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, retrievedTask.IsArchived())
|
||||
|
||||
err = commentDB.Unprotected().Get(ctx, comment.ID, &retrievedComment)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, retrievedComment.IsArchived())
|
||||
|
||||
// Verify reaction still exists (reactions don't support archiving)
|
||||
var retrievedReaction model.Reaction
|
||||
err = reactionDB.Unprotected().Get(ctx, reaction.ID, &retrievedReaction)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Unarchive organization with cascade
|
||||
err = organizationDB.SetArchived(ctx, accountRef, organization.ID, false, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all entities are unarchived
|
||||
err = organizationDB.Get(ctx, accountRef, organization.ID, &retrievedOrganization)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, retrievedOrganization.IsArchived())
|
||||
|
||||
err = projectDBAdapter.Unprotected().Get(ctx, project.ID, &retrievedProject)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, retrievedProject.IsArchived())
|
||||
|
||||
err = taskDB.Unprotected().Get(ctx, task.ID, &retrievedTask)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, retrievedTask.IsArchived())
|
||||
|
||||
err = commentDB.Unprotected().Get(ctx, comment.ID, &retrievedComment)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, retrievedComment.IsArchived())
|
||||
|
||||
// Clean up
|
||||
err = reactionDB.Unprotected().Delete(ctx, reaction.ID)
|
||||
require.NoError(t, err)
|
||||
err = commentDB.Unprotected().Delete(ctx, comment.ID)
|
||||
require.NoError(t, err)
|
||||
err = taskDB.Unprotected().Delete(ctx, task.ID)
|
||||
require.NoError(t, err)
|
||||
err = projectDBAdapter.Unprotected().Delete(ctx, project.ID)
|
||||
require.NoError(t, err)
|
||||
err = organizationDB.Delete(ctx, accountRef, organization.ID)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("SetArchived_NonExistentOrganization", func(t *testing.T) {
|
||||
// Try to archive non-existent organization
|
||||
nonExistentID := primitive.NewObjectID()
|
||||
err := organizationDB.SetArchived(ctx, accountRef, nonExistentID, true, true)
|
||||
assert.Error(t, err)
|
||||
// Could be either no data or access denied error depending on the permission system
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData) || errors.Is(err, merrors.ErrAccessDenied))
|
||||
})
|
||||
}
|
||||
20
api/pkg/db/internal/mongo/policiesdb/all.go
Normal file
20
api/pkg/db/internal/mongo/policiesdb/all.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package policiesdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
mutil "github.com/tech/sendico/pkg/mutil/db"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *PoliciesDB) All(ctx context.Context, organizationRef primitive.ObjectID) ([]model.PolicyDescription, error) {
|
||||
// all documents
|
||||
filter := repository.Query().Or(
|
||||
repository.Filter(storable.OrganizationRefField, nil),
|
||||
repository.OrgFilter(organizationRef),
|
||||
)
|
||||
return mutil.GetObjects[model.PolicyDescription](ctx, db.Logger, filter, nil, db.Repository)
|
||||
}
|
||||
13
api/pkg/db/internal/mongo/policiesdb/builtin.go
Normal file
13
api/pkg/db/internal/mongo/policiesdb/builtin.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package policiesdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
)
|
||||
|
||||
func (db *PoliciesDB) GetBuiltInPolicy(ctx context.Context, resourceType mservice.Type, policy *model.PolicyDescription) error {
|
||||
return db.FindOne(ctx, repository.Filter("resourceTypes", resourceType), policy)
|
||||
}
|
||||
21
api/pkg/db/internal/mongo/policiesdb/db.go
Normal file
21
api/pkg/db/internal/mongo/policiesdb/db.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package policiesdb
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
type PoliciesDB struct {
|
||||
template.DBImp[*model.PolicyDescription]
|
||||
}
|
||||
|
||||
func Create(logger mlogger.Logger, db *mongo.Database) (*PoliciesDB, error) {
|
||||
p := &PoliciesDB{
|
||||
DBImp: *template.Create[*model.PolicyDescription](logger, mservice.Policies, db),
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
353
api/pkg/db/internal/mongo/policiesdb/db_test.go
Normal file
353
api/pkg/db/internal/mongo/policiesdb/db_test.go
Normal file
@@ -0,0 +1,353 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package policiesdb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
// Your internal packages
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/policiesdb"
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
// Model package (contains PolicyDescription + Describable)
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
// Testcontainers
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Helper to terminate container
|
||||
func terminate(t *testing.T, ctx context.Context, container *mongodb.MongoDBContainer) {
|
||||
err := container.Terminate(ctx)
|
||||
require.NoError(t, err, "failed to terminate MongoDB container")
|
||||
}
|
||||
|
||||
// Helper to disconnect client
|
||||
func disconnect(t *testing.T, ctx context.Context, client *mongo.Client) {
|
||||
err := client.Disconnect(context.Background())
|
||||
require.NoError(t, err, "failed to disconnect from MongoDB")
|
||||
}
|
||||
|
||||
// Helper to drop the Policies collection
|
||||
func cleanupCollection(t *testing.T, ctx context.Context, db *mongo.Database) {
|
||||
// The actual collection name is typically the value returned by
|
||||
// (&model.PolicyDescription{}).Collection(), or something similar.
|
||||
// Make sure it matches what your code uses (often "policies" or "policyDescription").
|
||||
err := db.Collection((&model.PolicyDescription{}).Collection()).Drop(ctx)
|
||||
require.NoError(t, err, "failed to drop collection between sub-tests")
|
||||
}
|
||||
|
||||
func TestPoliciesDB(t *testing.T) {
|
||||
// Create context with reasonable timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Start MongoDB test container
|
||||
mongoC, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(t, ctx, mongoC)
|
||||
|
||||
// Get connection URI
|
||||
mongoURI, err := mongoC.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get connection string")
|
||||
|
||||
// Connect client
|
||||
clientOpts := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOpts)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(t, ctx, client)
|
||||
|
||||
// Create test DB
|
||||
db := client.Database("testdb")
|
||||
|
||||
// Use a no-op logger (or real logger if you prefer)
|
||||
logger := zap.NewNop()
|
||||
|
||||
// Create an instance of PoliciesDB
|
||||
pdb, err := policiesdb.Create(logger, db)
|
||||
require.NoError(t, err, "unexpected error creating PoliciesDB")
|
||||
|
||||
// ---------------------------------------------------------
|
||||
// Each sub-test below starts by dropping the collection.
|
||||
// ---------------------------------------------------------
|
||||
|
||||
t.Run("CreateAndGet", func(t *testing.T) {
|
||||
cleanupCollection(t, ctx, db) // ensure no leftover data
|
||||
|
||||
desc := "Test policy description"
|
||||
policy := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "TestPolicy",
|
||||
Description: &desc,
|
||||
},
|
||||
}
|
||||
require.NoError(t, pdb.Create(ctx, policy))
|
||||
|
||||
result := &model.PolicyDescription{}
|
||||
err := pdb.Get(ctx, policy.ID, result)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, policy.ID, result.ID)
|
||||
assert.Equal(t, "TestPolicy", result.Name)
|
||||
assert.NotNil(t, result.Description)
|
||||
assert.Equal(t, "Test policy description", *result.Description)
|
||||
})
|
||||
|
||||
t.Run("Get_NotFound", func(t *testing.T) {
|
||||
cleanupCollection(t, ctx, db)
|
||||
|
||||
// Attempt to get a non-existent ID
|
||||
nonExistentID := primitive.NewObjectID()
|
||||
result := &model.PolicyDescription{}
|
||||
err := pdb.Get(ctx, nonExistentID, result)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("Update", func(t *testing.T) {
|
||||
cleanupCollection(t, ctx, db)
|
||||
|
||||
originalDesc := "Original description"
|
||||
policy := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "OriginalName",
|
||||
Description: &originalDesc,
|
||||
},
|
||||
}
|
||||
require.NoError(t, pdb.Create(ctx, policy))
|
||||
|
||||
newDesc := "Updated description"
|
||||
policy.Name = "UpdatedName"
|
||||
policy.Description = &newDesc
|
||||
|
||||
err := pdb.Update(ctx, policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
updated := &model.PolicyDescription{}
|
||||
err = pdb.Get(ctx, policy.ID, updated)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "UpdatedName", updated.Name)
|
||||
assert.NotNil(t, updated.Description)
|
||||
assert.Equal(t, "Updated description", *updated.Description)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
cleanupCollection(t, ctx, db)
|
||||
|
||||
desc := "To be deleted"
|
||||
policy := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "WillDelete",
|
||||
Description: &desc,
|
||||
},
|
||||
}
|
||||
require.NoError(t, pdb.Create(ctx, policy))
|
||||
|
||||
err := pdb.Delete(ctx, policy.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted := &model.PolicyDescription{}
|
||||
err = pdb.Get(ctx, policy.ID, deleted)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("DeleteMany", func(t *testing.T) {
|
||||
cleanupCollection(t, ctx, db)
|
||||
|
||||
desc1 := "Will be deleted 1"
|
||||
desc2 := "Will be deleted 2"
|
||||
pol1 := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "BatchDelete1",
|
||||
Description: &desc1,
|
||||
},
|
||||
}
|
||||
pol2 := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "BatchDelete2",
|
||||
Description: &desc2,
|
||||
},
|
||||
}
|
||||
require.NoError(t, pdb.Create(ctx, pol1))
|
||||
require.NoError(t, pdb.Create(ctx, pol2))
|
||||
|
||||
q := repository.Query().RegEx(repository.Field("description"), "^Will be deleted", "")
|
||||
err := pdb.DeleteMany(ctx, q)
|
||||
require.NoError(t, err)
|
||||
|
||||
res1 := &model.PolicyDescription{}
|
||||
err1 := pdb.Get(ctx, pol1.ID, res1)
|
||||
assert.Error(t, err1)
|
||||
assert.True(t, errors.Is(err1, merrors.ErrNoData))
|
||||
|
||||
res2 := &model.PolicyDescription{}
|
||||
err2 := pdb.Get(ctx, pol2.ID, res2)
|
||||
assert.Error(t, err2)
|
||||
assert.True(t, errors.Is(err2, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("FindOne", func(t *testing.T) {
|
||||
cleanupCollection(t, ctx, db)
|
||||
|
||||
desc := "Unique find test"
|
||||
policy := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "FindOneTest",
|
||||
Description: &desc,
|
||||
},
|
||||
}
|
||||
require.NoError(t, pdb.Create(ctx, policy))
|
||||
|
||||
// Match by name == "FindOneTest"
|
||||
q := repository.Query().Comparison(repository.Field("name"), builder.Eq, "FindOneTest")
|
||||
|
||||
found := &model.PolicyDescription{}
|
||||
err := pdb.FindOne(ctx, q, found)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, policy.ID, found.ID)
|
||||
assert.Equal(t, "FindOneTest", found.Name)
|
||||
assert.NotNil(t, found.Description)
|
||||
assert.Equal(t, "Unique find test", *found.Description)
|
||||
})
|
||||
|
||||
t.Run("All", func(t *testing.T) {
|
||||
cleanupCollection(t, ctx, db)
|
||||
|
||||
// Insert some policies (orgA, orgB, nil org)
|
||||
orgA := primitive.NewObjectID()
|
||||
orgB := primitive.NewObjectID()
|
||||
|
||||
descA := "Org A policy"
|
||||
policyA := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "PolicyA",
|
||||
Description: &descA,
|
||||
},
|
||||
OrganizationRef: &orgA, // belongs to orgA
|
||||
}
|
||||
descB := "Org B policy"
|
||||
policyB := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "PolicyB",
|
||||
Description: &descB,
|
||||
},
|
||||
OrganizationRef: &orgB, // belongs to orgB
|
||||
}
|
||||
descNil := "No org policy"
|
||||
policyNil := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "PolicyNil",
|
||||
Description: &descNil,
|
||||
},
|
||||
// nil => built-in
|
||||
}
|
||||
require.NoError(t, pdb.Create(ctx, policyA))
|
||||
require.NoError(t, pdb.Create(ctx, policyB))
|
||||
require.NoError(t, pdb.Create(ctx, policyNil))
|
||||
|
||||
// Suppose the requirement is: "All" returns
|
||||
// - policies for the requested org
|
||||
// - plus built-in (nil) ones
|
||||
resultsA, err := pdb.All(ctx, orgA)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resultsA, 2) // orgA + built-in
|
||||
|
||||
var idsA []primitive.ObjectID
|
||||
for _, r := range resultsA {
|
||||
idsA = append(idsA, r.ID)
|
||||
}
|
||||
assert.Contains(t, idsA, policyA.ID)
|
||||
assert.Contains(t, idsA, policyNil.ID)
|
||||
assert.NotContains(t, idsA, policyB.ID)
|
||||
|
||||
resultsB, err := pdb.All(ctx, orgB)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, resultsB, 2) // orgB + built-in
|
||||
|
||||
var idsB []primitive.ObjectID
|
||||
for _, r := range resultsB {
|
||||
idsB = append(idsB, r.ID)
|
||||
}
|
||||
assert.Contains(t, idsB, policyB.ID)
|
||||
assert.Contains(t, idsB, policyNil.ID)
|
||||
assert.NotContains(t, idsB, policyA.ID)
|
||||
})
|
||||
|
||||
t.Run("Policies", func(t *testing.T) {
|
||||
cleanupCollection(t, ctx, db)
|
||||
|
||||
desc1 := "PolicyOne"
|
||||
pol1 := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "PolicyOne",
|
||||
Description: &desc1,
|
||||
},
|
||||
}
|
||||
desc2 := "PolicyTwo"
|
||||
pol2 := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "PolicyTwo",
|
||||
Description: &desc2,
|
||||
},
|
||||
}
|
||||
desc3 := "PolicyThree"
|
||||
pol3 := &model.PolicyDescription{
|
||||
Describable: model.Describable{
|
||||
Name: "PolicyThree",
|
||||
Description: &desc3,
|
||||
},
|
||||
}
|
||||
require.NoError(t, pdb.Create(ctx, pol1))
|
||||
require.NoError(t, pdb.Create(ctx, pol2))
|
||||
require.NoError(t, pdb.Create(ctx, pol3))
|
||||
|
||||
// 1) Request pol1, pol2
|
||||
results12, err := pdb.Policies(ctx, []primitive.ObjectID{pol1.ID, pol2.ID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results12, 2)
|
||||
// IDs might be out of order, so we do a set-like check
|
||||
var set12 []primitive.ObjectID
|
||||
for _, r := range results12 {
|
||||
set12 = append(set12, r.ID)
|
||||
}
|
||||
assert.Contains(t, set12, pol1.ID)
|
||||
assert.Contains(t, set12, pol2.ID)
|
||||
|
||||
// 2) Request pol1, pol3, plus a random ID
|
||||
fakeID := primitive.NewObjectID()
|
||||
results13Fake, err := pdb.Policies(ctx, []primitive.ObjectID{pol1.ID, pol3.ID, fakeID})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, results13Fake, 2) // pol1 + pol3 only
|
||||
var set13Fake []primitive.ObjectID
|
||||
for _, r := range results13Fake {
|
||||
set13Fake = append(set13Fake, r.ID)
|
||||
}
|
||||
assert.Contains(t, set13Fake, pol1.ID)
|
||||
assert.Contains(t, set13Fake, pol3.ID)
|
||||
|
||||
// 3) Request with empty slice => expect no results
|
||||
resultsEmpty, err := pdb.Policies(ctx, []primitive.ObjectID{})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, resultsEmpty, 0)
|
||||
})
|
||||
}
|
||||
18
api/pkg/db/internal/mongo/policiesdb/policies.go
Normal file
18
api/pkg/db/internal/mongo/policiesdb/policies.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package policiesdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
mutil "github.com/tech/sendico/pkg/mutil/db"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *PoliciesDB) Policies(ctx context.Context, refs []primitive.ObjectID) ([]model.PolicyDescription, error) {
|
||||
if len(refs) == 0 {
|
||||
return []model.PolicyDescription{}, nil
|
||||
}
|
||||
filter := repository.Query().In(repository.IDField(), refs)
|
||||
return mutil.GetObjects[model.PolicyDescription](ctx, db.Logger, filter, nil, db.Repository)
|
||||
}
|
||||
12
api/pkg/db/internal/mongo/refreshtokensdb/client.go
Normal file
12
api/pkg/db/internal/mongo/refreshtokensdb/client.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package refreshtokensdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
)
|
||||
|
||||
func (db *RefreshTokenDB) GetClient(ctx context.Context, clientID string) (*model.Client, error) {
|
||||
var client model.Client
|
||||
return &client, db.clients.FindOneByFilter(ctx, filterByClientId(clientID), &client)
|
||||
}
|
||||
122
api/pkg/db/internal/mongo/refreshtokensdb/crud.go
Normal file
122
api/pkg/db/internal/mongo/refreshtokensdb/crud.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package refreshtokensdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"github.com/tech/sendico/pkg/mutil/mzap"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (db *RefreshTokenDB) Create(ctx context.Context, rt *model.RefreshToken) error {
|
||||
// First, try to find an existing token for this account/client/device combination
|
||||
var existing model.RefreshToken
|
||||
if rt.AccountRef == nil {
|
||||
return merrors.InvalidArgument("Account reference must have a vaild value")
|
||||
}
|
||||
if err := db.FindOne(ctx, filterByAccount(*rt.AccountRef, &rt.SessionIdentifier), &existing); err != nil {
|
||||
if errors.Is(err, merrors.ErrNoData) {
|
||||
// No existing token, create a new one
|
||||
db.Logger.Info("Registering refresh token", zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
|
||||
return db.DBImp.Create(ctx, rt)
|
||||
}
|
||||
db.Logger.Warn("Something went wrong when checking existing sessions", zap.Error(err),
|
||||
zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
|
||||
return err
|
||||
}
|
||||
|
||||
// Token already exists, update it with new values
|
||||
db.Logger.Info("Updating existing refresh token", zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
|
||||
|
||||
patch := repository.Patch().
|
||||
Set(repository.Field(TokenField), rt.RefreshToken).
|
||||
Set(repository.Field(ExpiresAtField), rt.ExpiresAt).
|
||||
Set(repository.Field(UserAgentField), rt.UserAgent).
|
||||
Set(repository.Field(IPAddressField), rt.IPAddress).
|
||||
Set(repository.Field(LastUsedAtField), rt.LastUsedAt).
|
||||
Set(repository.Field(IsRevokedField), rt.IsRevoked)
|
||||
|
||||
if err := db.Patch(ctx, *existing.GetID(), patch); err != nil {
|
||||
db.Logger.Warn("Failed to patch refresh token", zap.Error(err), zap.String("client_id", rt.ClientID), zap.String("device_id", rt.DeviceID))
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the ID of the input token to match the existing one
|
||||
rt.SetID(*existing.GetID())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *RefreshTokenDB) Update(ctx context.Context, rt *model.RefreshToken) error {
|
||||
rt.LastUsedAt = time.Now()
|
||||
|
||||
// Use Patch instead of Update to avoid race conditions
|
||||
patch := repository.Patch().
|
||||
Set(repository.Field(TokenField), rt.RefreshToken).
|
||||
Set(repository.Field(ExpiresAtField), rt.ExpiresAt).
|
||||
Set(repository.Field(UserAgentField), rt.UserAgent).
|
||||
Set(repository.Field(IPAddressField), rt.IPAddress).
|
||||
Set(repository.Field(LastUsedAtField), rt.LastUsedAt).
|
||||
Set(repository.Field(IsRevokedField), rt.IsRevoked)
|
||||
|
||||
return db.Patch(ctx, *rt.GetID(), patch)
|
||||
}
|
||||
|
||||
func (db *RefreshTokenDB) Delete(ctx context.Context, tokenRef primitive.ObjectID) error {
|
||||
db.Logger.Info("Deleting refresh token", mzap.ObjRef("refresh_token_ref", tokenRef))
|
||||
return db.DBImp.Delete(ctx, tokenRef)
|
||||
}
|
||||
|
||||
func (db *RefreshTokenDB) Revoke(ctx context.Context, accountRef primitive.ObjectID, session *model.SessionIdentifier) error {
|
||||
var rt model.RefreshToken
|
||||
f := filterByAccount(accountRef, session)
|
||||
if err := db.Repository.FindOneByFilter(ctx, f, &rt); err != nil {
|
||||
if errors.Is(err, merrors.ErrNoData) {
|
||||
db.Logger.Warn("Failed to find refresh token", zap.Error(err),
|
||||
mzap.ObjRef("account_ref", accountRef), zap.String("client_id", session.ClientID), zap.String("device_id", session.DeviceID))
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Use Patch to update the revocation status
|
||||
patch := repository.Patch().
|
||||
Set(repository.Field(IsRevokedField), true).
|
||||
Set(repository.Field(LastUsedAtField), time.Now())
|
||||
|
||||
return db.Patch(ctx, *rt.GetID(), patch)
|
||||
}
|
||||
|
||||
func (db *RefreshTokenDB) GetByCRT(ctx context.Context, t *model.ClientRefreshToken) (*model.RefreshToken, error) {
|
||||
var rt model.RefreshToken
|
||||
f := filter(&t.SessionIdentifier).And(repository.Query().Filter(repository.Field("token"), t.RefreshToken))
|
||||
if err := db.Repository.FindOneByFilter(ctx, f, &rt); err != nil {
|
||||
if !errors.Is(err, merrors.ErrNoData) {
|
||||
db.Logger.Warn("Failed to fetch refresh token", zap.Error(err),
|
||||
zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if token is expired
|
||||
if rt.ExpiresAt.Before(time.Now()) {
|
||||
db.Logger.Warn("Refresh token expired", mzap.StorableRef(&rt),
|
||||
zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID),
|
||||
zap.Time("expires_at", rt.ExpiresAt))
|
||||
return nil, merrors.AccessDenied(mservice.RefreshTokens, string(model.ActionRead), *rt.GetID())
|
||||
}
|
||||
|
||||
// Check if token is revoked
|
||||
if rt.IsRevoked {
|
||||
db.Logger.Warn("Refresh token is revoked", mzap.StorableRef(&rt),
|
||||
zap.String("client_id", t.ClientID), zap.String("device_id", t.DeviceID))
|
||||
return nil, merrors.ErrNoData
|
||||
}
|
||||
|
||||
return &rt, nil
|
||||
}
|
||||
62
api/pkg/db/internal/mongo/refreshtokensdb/db.go
Normal file
62
api/pkg/db/internal/mongo/refreshtokensdb/db.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package refreshtokensdb
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
ri "github.com/tech/sendico/pkg/db/repository/index"
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type RefreshTokenDB struct {
|
||||
template.DBImp[*model.RefreshToken]
|
||||
clients repository.Repository
|
||||
}
|
||||
|
||||
func Create(logger mlogger.Logger, db *mongo.Database) (*RefreshTokenDB, error) {
|
||||
p := &RefreshTokenDB{
|
||||
DBImp: *template.Create[*model.RefreshToken](logger, mservice.RefreshTokens, db),
|
||||
clients: repository.CreateMongoRepository(db, mservice.Clients),
|
||||
}
|
||||
|
||||
if err := p.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: "token", Sort: ri.Asc}},
|
||||
Unique: true,
|
||||
}); err != nil {
|
||||
p.Logger.Error("Failed to create unique token index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add unique constraint on account/client/device combination
|
||||
if err := p.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{
|
||||
{Field: "accountRef", Sort: ri.Asc},
|
||||
{Field: "clientId", Sort: ri.Asc},
|
||||
{Field: "deviceId", Sort: ri.Asc},
|
||||
},
|
||||
Unique: true,
|
||||
}); err != nil {
|
||||
p.Logger.Error("Failed to create unique account/client/device index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.Repository.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: IsRevokedField, Sort: ri.Asc}},
|
||||
}); err != nil {
|
||||
p.Logger.Error("Failed to create unique token revokation status index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.clients.CreateIndex(&ri.Definition{
|
||||
Keys: []ri.Key{{Field: "clientId", Sort: ri.Asc}},
|
||||
Unique: true,
|
||||
}); err != nil {
|
||||
p.Logger.Error("Failed to create unique client identifier index", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
10
api/pkg/db/internal/mongo/refreshtokensdb/fields.go
Normal file
10
api/pkg/db/internal/mongo/refreshtokensdb/fields.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package refreshtokensdb
|
||||
|
||||
const (
|
||||
ExpiresAtField = "expiresAt"
|
||||
IsRevokedField = "isRevoked"
|
||||
TokenField = "token"
|
||||
UserAgentField = "userAgent"
|
||||
IPAddressField = "ipAddress"
|
||||
LastUsedAtField = "lastUsedAt"
|
||||
)
|
||||
25
api/pkg/db/internal/mongo/refreshtokensdb/filters.go
Normal file
25
api/pkg/db/internal/mongo/refreshtokensdb/filters.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package refreshtokensdb
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func filterByClientId(clientID string) builder.Query {
|
||||
return repository.Query().Comparison(repository.Field("clientId"), builder.Eq, clientID)
|
||||
}
|
||||
|
||||
func filter(session *model.SessionIdentifier) builder.Query {
|
||||
filter := filterByClientId(session.ClientID)
|
||||
filter.And(
|
||||
repository.Query().Comparison(repository.Field("deviceId"), builder.Eq, session.DeviceID),
|
||||
repository.Query().Comparison(repository.Field(IsRevokedField), builder.Eq, false),
|
||||
)
|
||||
return filter
|
||||
}
|
||||
|
||||
func filterByAccount(accountRef primitive.ObjectID, session *model.SessionIdentifier) builder.Query {
|
||||
return filter(session).And(repository.Query().Comparison(repository.AccountField(), builder.Eq, accountRef))
|
||||
}
|
||||
@@ -0,0 +1,639 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package refreshtokensdb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/refreshtokensdb"
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
factory "github.com/tech/sendico/pkg/mlogger/factory"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) (*refreshtokensdb.RefreshTokenDB, func()) {
|
||||
// mark as helper for better test failure reporting
|
||||
t.Helper()
|
||||
|
||||
startCtx, startCancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer startCancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(startCtx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForListeningPort("27017/tcp").WithStartupTimeout(2*time.Minute)),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(startCtx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(startCtx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
|
||||
database := client.Database("test_refresh_tokens_" + t.Name())
|
||||
logger := factory.NewLogger(true)
|
||||
|
||||
db, err := refreshtokensdb.Create(logger, database)
|
||||
require.NoError(t, err, "failed to create refresh tokens db")
|
||||
|
||||
cleanup := func() {
|
||||
_ = database.Drop(context.Background())
|
||||
_ = client.Disconnect(context.Background())
|
||||
termCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
_ = mongoContainer.Terminate(termCtx)
|
||||
}
|
||||
|
||||
return db, cleanup
|
||||
}
|
||||
|
||||
func createTestRefreshToken(accountRef primitive.ObjectID, clientID, deviceID, token string) *model.RefreshToken {
|
||||
return &model.RefreshToken{
|
||||
ClientRefreshToken: model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
},
|
||||
AccountBoundBase: model.AccountBoundBase{
|
||||
AccountRef: &accountRef,
|
||||
},
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
IsRevoked: false,
|
||||
UserAgent: "TestUserAgent/1.0",
|
||||
IPAddress: "192.168.1.1",
|
||||
LastUsedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_AuthenticationFlow(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Complete_User_Authentication_Flow", func(t *testing.T) {
|
||||
// Setup: Create user and client
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-desktop-chrome"
|
||||
token := "refresh_token_12345"
|
||||
|
||||
// Step 1: User logs in - create initial refresh token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 2: User uses refresh token to get new access token
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
}
|
||||
|
||||
retrievedToken, err := db.GetByCRT(ctx, crt)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, retrievedToken.AccountRef)
|
||||
assert.Equal(t, userID, *retrievedToken.AccountRef)
|
||||
assert.Equal(t, token, retrievedToken.RefreshToken)
|
||||
assert.False(t, retrievedToken.IsRevoked)
|
||||
|
||||
// Step 3: User logs out - revoke the token
|
||||
session := &model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
}
|
||||
err = db.Revoke(ctx, userID, session)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 4: Try to use revoked token - should fail
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("Manual_Token_Revocation_Workaround", func(t *testing.T) {
|
||||
// Test manual revocation by directly updating the token
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-desktop-chrome"
|
||||
token := "manual_revoke_token_123"
|
||||
|
||||
// Step 1: Create token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 2: Manually revoke token by updating it directly
|
||||
refreshToken.IsRevoked = true
|
||||
err = db.Update(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 3: Try to use revoked token - should fail
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
}
|
||||
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_MultiDeviceManagement(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("User_With_Multiple_Devices", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "mobile-app"
|
||||
|
||||
// User logs in from phone
|
||||
phoneToken := createTestRefreshToken(userID, clientID, "phone-ios", "phone_token_123")
|
||||
err := db.Create(ctx, phoneToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// User logs in from tablet
|
||||
tabletToken := createTestRefreshToken(userID, clientID, "tablet-android", "tablet_token_456")
|
||||
err = db.Create(ctx, tabletToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// User logs in from desktop
|
||||
desktopToken := createTestRefreshToken(userID, clientID, "desktop-windows", "desktop_token_789")
|
||||
err = db.Create(ctx, desktopToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// User wants to logout from all devices except current (phone)
|
||||
err = db.RevokeAll(ctx, userID, "phone-ios")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Phone should still work
|
||||
phoneCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: "phone-ios",
|
||||
},
|
||||
RefreshToken: "phone_token_123",
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, phoneCRT)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Tablet and desktop should be revoked
|
||||
tabletCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: "tablet-android",
|
||||
},
|
||||
RefreshToken: "tablet_token_456",
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, tabletCRT)
|
||||
assert.Error(t, err)
|
||||
|
||||
desktopCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: "desktop-windows",
|
||||
},
|
||||
RefreshToken: "desktop_token_789",
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, desktopCRT)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_TokenRotation(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Token_Rotation_On_Use", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-browser"
|
||||
initialToken := "initial_token_123"
|
||||
|
||||
// Create initial token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, initialToken)
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate small delay
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Use token - should update LastUsedAt
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: initialToken,
|
||||
}
|
||||
|
||||
retrievedToken, err := db.GetByCRT(ctx, crt)
|
||||
require.NoError(t, err)
|
||||
// LastUsedAt is not updated by GetByCRT; validate token data instead
|
||||
assert.Equal(t, initialToken, retrievedToken.RefreshToken)
|
||||
|
||||
// Create new token with rotated value (simulating token rotation)
|
||||
newToken := "rotated_token_456"
|
||||
retrievedToken.RefreshToken = newToken
|
||||
err = db.Update(ctx, retrievedToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Old token should no longer work
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
|
||||
// New token should work
|
||||
newCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: newToken,
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, newCRT)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_SessionReplacement(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("User_Login_From_Same_Device_Twice", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-laptop"
|
||||
|
||||
// First login
|
||||
firstToken := createTestRefreshToken(userID, clientID, deviceID, "first_token_123")
|
||||
err := db.Create(ctx, firstToken)
|
||||
require.NoError(t, err)
|
||||
firstTokenID := *firstToken.GetID()
|
||||
|
||||
// Second login from same device - should replace existing token
|
||||
secondToken := createTestRefreshToken(userID, clientID, deviceID, "second_token_456")
|
||||
err = db.Create(ctx, secondToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should reuse the same database record
|
||||
assert.Equal(t, firstTokenID, *secondToken.GetID())
|
||||
|
||||
// First token should no longer work
|
||||
firstCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: "first_token_123",
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, firstCRT)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Second token should work
|
||||
secondCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: "second_token_456",
|
||||
}
|
||||
_, err = db.GetByCRT(ctx, secondCRT)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_ClientManagement(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Client_CRUD_Operations", func(t *testing.T) {
|
||||
// Note: Client management is handled by a separate client database
|
||||
// This test verifies that refresh tokens work with different client IDs
|
||||
|
||||
userID := primitive.NewObjectID()
|
||||
|
||||
// Create refresh tokens for different clients
|
||||
webToken := createTestRefreshToken(userID, "web-app", "device1", "token1")
|
||||
err := db.Create(ctx, webToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
mobileToken := createTestRefreshToken(userID, "mobile-app", "device2", "token2")
|
||||
err = db.Create(ctx, mobileToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify tokens can be retrieved by client ID
|
||||
webCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: "web-app",
|
||||
DeviceID: "device1",
|
||||
},
|
||||
RefreshToken: "token1",
|
||||
}
|
||||
|
||||
retrievedToken, err := db.GetByCRT(ctx, webCRT)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "web-app", retrievedToken.ClientID)
|
||||
assert.Equal(t, "device1", retrievedToken.DeviceID)
|
||||
|
||||
mobileCRT := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: "mobile-app",
|
||||
DeviceID: "device2",
|
||||
},
|
||||
RefreshToken: "token2",
|
||||
}
|
||||
|
||||
retrievedToken, err = db.GetByCRT(ctx, mobileCRT)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "mobile-app", retrievedToken.ClientID)
|
||||
assert.Equal(t, "device2", retrievedToken.DeviceID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_SecurityScenarios(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Token_Hijacking_Prevention", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-browser"
|
||||
token := "hijacked_token_123"
|
||||
|
||||
// Create legitimate token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate security concern - revoke token
|
||||
session := &model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
}
|
||||
err = db.Revoke(ctx, userID, session)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attacker tries to use hijacked token
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
}
|
||||
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("Invalid_Token_Attempts", func(t *testing.T) {
|
||||
// Try to use completely invalid token
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: "invalid-client",
|
||||
DeviceID: "invalid-device",
|
||||
},
|
||||
RefreshToken: "invalid_token_123",
|
||||
}
|
||||
|
||||
_, err := db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_ExpiredTokenHandling(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Expired_Token_Cleanup", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-device"
|
||||
token := "expired_token_123"
|
||||
|
||||
// Create token that expires in the past
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
|
||||
refreshToken.ExpiresAt = time.Now().Add(-1 * time.Hour) // Expired 1 hour ago
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The token exists in database but is expired
|
||||
var storedToken model.RefreshToken
|
||||
err = db.Get(ctx, *refreshToken.GetID(), &storedToken)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, storedToken.ExpiresAt.Before(time.Now()))
|
||||
|
||||
// Application should reject expired tokens
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
}
|
||||
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrAccessDenied))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_ConcurrentAccess(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Concurrent_Token_Usage", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "user-device"
|
||||
token := "concurrent_token_123"
|
||||
|
||||
// Create token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, token)
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: token,
|
||||
}
|
||||
|
||||
// Simulate concurrent access
|
||||
done := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
_, err := db.GetByCRT(ctx, crt)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := db.GetByCRT(ctx, crt)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
// Both operations should succeed
|
||||
for i := 0; i < 2; i++ {
|
||||
err := <-done
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_EdgeCases(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Delete_Token_By_ID", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
refreshToken := createTestRefreshToken(userID, "web-app", "device-1", "token_123")
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenID := *refreshToken.GetID()
|
||||
|
||||
// Delete token
|
||||
err = db.Delete(ctx, tokenID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Token should no longer exist
|
||||
var result model.RefreshToken
|
||||
err = db.Get(ctx, tokenID, &result)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("Revoke_Non_Existent_Token", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
session := &model.SessionIdentifier{
|
||||
ClientID: "non-existent-client",
|
||||
DeviceID: "non-existent-device",
|
||||
}
|
||||
|
||||
err := db.Revoke(ctx, userID, session)
|
||||
// Should handle gracefully for non-existent tokens
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("RevokeAll_No_Other_Devices", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
deviceID := "only-device"
|
||||
|
||||
// Create single token
|
||||
refreshToken := createTestRefreshToken(userID, clientID, deviceID, "token_123")
|
||||
err := db.Create(ctx, refreshToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// RevokeAll should not affect current device
|
||||
err = db.RevokeAll(ctx, userID, deviceID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Token should still work
|
||||
crt := &model.ClientRefreshToken{
|
||||
SessionIdentifier: model.SessionIdentifier{
|
||||
ClientID: clientID,
|
||||
DeviceID: deviceID,
|
||||
},
|
||||
RefreshToken: "token_123",
|
||||
}
|
||||
|
||||
_, err = db.GetByCRT(ctx, crt)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenDB_DatabaseIndexes(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Unique_Token_Constraint", func(t *testing.T) {
|
||||
userID1 := primitive.NewObjectID()
|
||||
userID2 := primitive.NewObjectID()
|
||||
token := "duplicate_token_123"
|
||||
|
||||
// Create first token
|
||||
refreshToken1 := createTestRefreshToken(userID1, "client1", "device1", token)
|
||||
err := db.Create(ctx, refreshToken1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to create second token with same token value - should fail due to unique index
|
||||
refreshToken2 := createTestRefreshToken(userID2, "client2", "device2", token)
|
||||
err = db.Create(ctx, refreshToken2)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "duplicate")
|
||||
})
|
||||
|
||||
t.Run("Query_Performance_By_Revocation_Status", func(t *testing.T) {
|
||||
userID := primitive.NewObjectID()
|
||||
clientID := "web-app"
|
||||
|
||||
// Create multiple tokens
|
||||
for i := 0; i < 10; i++ {
|
||||
token := createTestRefreshToken(userID, clientID,
|
||||
fmt.Sprintf("device_%d", i), fmt.Sprintf("token_%d", i))
|
||||
if i%2 == 0 {
|
||||
token.IsRevoked = true
|
||||
}
|
||||
err := db.Create(ctx, token)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Query should efficiently filter by revocation status
|
||||
query := repository.Query().
|
||||
Filter(repository.AccountField(), userID).
|
||||
And(repository.Query().Comparison(repository.Field(refreshtokensdb.IsRevokedField), builder.Eq, false))
|
||||
|
||||
ids, err := db.ListIDs(ctx, query)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ids, 5) // Should find 5 non-revoked tokens
|
||||
})
|
||||
}
|
||||
24
api/pkg/db/internal/mongo/refreshtokensdb/revoke.go
Normal file
24
api/pkg/db/internal/mongo/refreshtokensdb/revoke.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package refreshtokensdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *RefreshTokenDB) RevokeAll(ctx context.Context, accountRef primitive.ObjectID, deviceID string) error {
|
||||
query := repository.Query().
|
||||
Filter(repository.AccountField(), accountRef).
|
||||
And(repository.Query().Comparison(repository.Field("deviceId"), builder.Ne, deviceID)).
|
||||
And(repository.Query().Comparison(repository.Field(IsRevokedField), builder.Eq, false))
|
||||
|
||||
patch := repository.Patch().
|
||||
Set(repository.Field(ExpiresAtField), time.Now()).
|
||||
Set(repository.Field(IsRevokedField), true)
|
||||
|
||||
_, err := db.Repository.PatchMany(ctx, query, patch)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
type literalAccumulatorImp struct {
|
||||
op builder.MongoOperation
|
||||
value any
|
||||
}
|
||||
|
||||
func (a *literalAccumulatorImp) Build() bson.D {
|
||||
return bson.D{{Key: string(a.op), Value: a.value}}
|
||||
}
|
||||
|
||||
func NewAccumulator(op builder.MongoOperation, value any) builder.Accumulator {
|
||||
return &literalAccumulatorImp{op: op, value: value}
|
||||
}
|
||||
|
||||
func AddToSet(value builder.Expression) builder.Expression {
|
||||
return newUnaryExpression(builder.AddToSet, value)
|
||||
}
|
||||
|
||||
func Size(value builder.Expression) builder.Expression {
|
||||
return newUnaryExpression(builder.Size, value)
|
||||
}
|
||||
|
||||
func Ne(left, right builder.Expression) builder.Expression {
|
||||
return newBinaryExpression(builder.Ne, left, right)
|
||||
}
|
||||
|
||||
func Sum(value any) builder.Accumulator {
|
||||
return NewAccumulator(builder.Sum, value)
|
||||
}
|
||||
|
||||
func Avg(value any) builder.Accumulator {
|
||||
return NewAccumulator(builder.Avg, value)
|
||||
}
|
||||
|
||||
func Min(value any) builder.Accumulator {
|
||||
return NewAccumulator(builder.Min, value)
|
||||
}
|
||||
|
||||
func Max(value any) builder.Accumulator {
|
||||
return NewAccumulator(builder.Max, value)
|
||||
}
|
||||
|
||||
func Eq(left, right builder.Expression) builder.Expression {
|
||||
return newBinaryExpression(builder.Eq, left, right)
|
||||
}
|
||||
|
||||
func Gt(left, right builder.Expression) builder.Expression {
|
||||
return newBinaryExpression(builder.Gt, left, right)
|
||||
}
|
||||
|
||||
func Add(left, right builder.Accumulator) builder.Accumulator {
|
||||
return newBinaryAccumulator(builder.Add, left, right)
|
||||
}
|
||||
|
||||
func Subtract(left, right builder.Accumulator) builder.Accumulator {
|
||||
return newBinaryAccumulator(builder.Subtract, left, right)
|
||||
}
|
||||
|
||||
func Multiply(left, right builder.Accumulator) builder.Accumulator {
|
||||
return newBinaryAccumulator(builder.Multiply, left, right)
|
||||
}
|
||||
|
||||
func Divide(left, right builder.Accumulator) builder.Accumulator {
|
||||
return newBinaryAccumulator(builder.Divide, left, right)
|
||||
}
|
||||
|
||||
type binaryAccumulator struct {
|
||||
op builder.MongoOperation
|
||||
left builder.Accumulator
|
||||
right builder.Accumulator
|
||||
}
|
||||
|
||||
func newBinaryAccumulator(op builder.MongoOperation, left, right builder.Accumulator) builder.Accumulator {
|
||||
return &binaryAccumulator{
|
||||
op: op,
|
||||
left: left,
|
||||
right: right,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *binaryAccumulator) Build() bson.D {
|
||||
args := []any{b.left.Build(), b.right.Build()}
|
||||
return bson.D{{Key: string(b.op), Value: args}}
|
||||
}
|
||||
102
api/pkg/db/internal/mongo/repositoryimp/builderimp/alias.go
Normal file
102
api/pkg/db/internal/mongo/repositoryimp/builderimp/alias.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
type aliasImp struct {
|
||||
lhs builder.Field
|
||||
rhs any
|
||||
}
|
||||
|
||||
func (a *aliasImp) Field() builder.Field {
|
||||
return a.lhs
|
||||
}
|
||||
|
||||
func (a *aliasImp) Build() bson.D {
|
||||
return bson.D{{Key: a.lhs.Build(), Value: a.rhs}}
|
||||
}
|
||||
|
||||
// 1. Null alias (_id: null)
|
||||
func NewNullAlias(lhs builder.Field) builder.Alias {
|
||||
return &aliasImp{lhs: lhs, rhs: nil}
|
||||
}
|
||||
|
||||
func NewAlias(lhs builder.Field, rhs any) builder.Alias {
|
||||
return &aliasImp{lhs: lhs, rhs: rhs}
|
||||
}
|
||||
|
||||
// 2. Simple alias (_id: "$taskRef")
|
||||
func NewSimpleAlias(lhs, rhs builder.Field) builder.Alias {
|
||||
return &aliasImp{lhs: lhs, rhs: rhs.Build()}
|
||||
}
|
||||
|
||||
// 3. Complex alias (_id: { aliasName: "$originalField", ... })
|
||||
type ComplexAlias struct {
|
||||
lhs builder.Field
|
||||
rhs []builder.Alias // Correcting handling of slice of aliases
|
||||
}
|
||||
|
||||
func (a *ComplexAlias) Field() builder.Field {
|
||||
return a.lhs
|
||||
}
|
||||
|
||||
func (a *ComplexAlias) Build() bson.D {
|
||||
fieldMap := bson.M{}
|
||||
|
||||
for _, alias := range a.rhs {
|
||||
// Each alias.Build() still returns a bson.D
|
||||
aliasDoc := alias.Build()
|
||||
|
||||
// 1. Marshal the ordered D into raw BSON bytes
|
||||
raw, err := bson.Marshal(aliasDoc)
|
||||
if err != nil {
|
||||
panic("Failed to marshal alias document: " + err.Error())
|
||||
}
|
||||
|
||||
// 2. Unmarshal those bytes into an unordered M
|
||||
var docM bson.M
|
||||
if err := bson.Unmarshal(raw, &docM); err != nil {
|
||||
panic("Failed to unmarshal alias document: " + err.Error())
|
||||
}
|
||||
|
||||
// Merge into our accumulator
|
||||
for k, v := range docM {
|
||||
fieldMap[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return bson.D{{Key: a.lhs.Build(), Value: fieldMap}}
|
||||
}
|
||||
|
||||
func NewComplexAlias(lhs builder.Field, rhs []builder.Alias) builder.Alias {
|
||||
return &ComplexAlias{lhs: lhs, rhs: rhs}
|
||||
}
|
||||
|
||||
type aliasesImp struct {
|
||||
aliases []builder.Alias
|
||||
}
|
||||
|
||||
func (a *aliasesImp) Field() builder.Field {
|
||||
if len(a.aliases) > 0 {
|
||||
return a.aliases[0].Field()
|
||||
}
|
||||
return NewFieldImp("")
|
||||
}
|
||||
|
||||
func (a *aliasesImp) Build() bson.D {
|
||||
results := make([]bson.D, 0)
|
||||
for _, alias := range a.aliases {
|
||||
results = append(results, alias.Build())
|
||||
}
|
||||
aliases := bson.D{}
|
||||
for _, r := range results {
|
||||
aliases = append(aliases, r...)
|
||||
}
|
||||
return aliases
|
||||
}
|
||||
|
||||
func NewAliases(aliases ...builder.Alias) builder.Alias {
|
||||
return &aliasesImp{aliases: aliases}
|
||||
}
|
||||
27
api/pkg/db/internal/mongo/repositoryimp/builderimp/array.go
Normal file
27
api/pkg/db/internal/mongo/repositoryimp/builderimp/array.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
type arrayImp struct {
|
||||
elements []builder.Expression
|
||||
}
|
||||
|
||||
// Build renders the literal array:
|
||||
//
|
||||
// [ <expr1>, <expr2>, … ]
|
||||
func (b *arrayImp) Build() bson.A {
|
||||
arr := make(bson.A, len(b.elements))
|
||||
for i, expr := range b.elements {
|
||||
// each expr.Build() returns the raw value or sub‐expression
|
||||
arr[i] = expr.Build()
|
||||
}
|
||||
return arr
|
||||
}
|
||||
|
||||
// NewArray constructs a new array expression from the given sub‐expressions.
|
||||
func NewArray(exprs ...builder.Expression) *arrayImp {
|
||||
return &arrayImp{elements: exprs}
|
||||
}
|
||||
108
api/pkg/db/internal/mongo/repositoryimp/builderimp/expression.go
Normal file
108
api/pkg/db/internal/mongo/repositoryimp/builderimp/expression.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
type literalExpression struct {
|
||||
value any
|
||||
}
|
||||
|
||||
func NewLiteralExpression(value any) builder.Expression {
|
||||
return &literalExpression{value: value}
|
||||
}
|
||||
|
||||
func (e *literalExpression) Build() any {
|
||||
return bson.D{{Key: string(builder.Literal), Value: e.value}}
|
||||
}
|
||||
|
||||
type variadicExpression struct {
|
||||
op builder.MongoOperation
|
||||
parts []builder.Expression
|
||||
}
|
||||
|
||||
func (e *variadicExpression) Build() any {
|
||||
args := make([]any, 0, len(e.parts))
|
||||
for _, p := range e.parts {
|
||||
args = append(args, p.Build())
|
||||
}
|
||||
return bson.D{{Key: string(e.op), Value: args}}
|
||||
}
|
||||
|
||||
func newVariadicExpression(op builder.MongoOperation, exprs ...builder.Expression) builder.Expression {
|
||||
return &variadicExpression{
|
||||
op: op,
|
||||
parts: exprs,
|
||||
}
|
||||
}
|
||||
|
||||
func newBinaryExpression(op builder.MongoOperation, left, right builder.Expression) builder.Expression {
|
||||
return &variadicExpression{
|
||||
op: op,
|
||||
parts: []builder.Expression{left, right},
|
||||
}
|
||||
}
|
||||
|
||||
type unaryExpression struct {
|
||||
op builder.MongoOperation
|
||||
rhs builder.Expression
|
||||
}
|
||||
|
||||
func (e *unaryExpression) Build() any {
|
||||
return bson.D{{Key: string(e.op), Value: e.rhs.Build()}}
|
||||
}
|
||||
|
||||
func newUnaryExpression(op builder.MongoOperation, right builder.Expression) builder.Expression {
|
||||
return &unaryExpression{
|
||||
op: op,
|
||||
rhs: right,
|
||||
}
|
||||
}
|
||||
|
||||
type matchExpression struct {
|
||||
op builder.MongoOperation
|
||||
rhs builder.Expression
|
||||
}
|
||||
|
||||
func (e *matchExpression) Build() any {
|
||||
return bson.E{Key: string(e.op), Value: e.rhs.Build()}
|
||||
}
|
||||
|
||||
func newMatchExpression(op builder.MongoOperation, right builder.Expression) builder.Expression {
|
||||
return &matchExpression{
|
||||
op: op,
|
||||
rhs: right,
|
||||
}
|
||||
}
|
||||
|
||||
func InRef(value builder.Field) builder.Expression {
|
||||
return newMatchExpression(builder.In, NewValue(NewRefFieldImp(value).Build()))
|
||||
}
|
||||
|
||||
type inImpl struct {
|
||||
values []any
|
||||
}
|
||||
|
||||
func (e *inImpl) Build() any {
|
||||
return bson.D{{Key: string(builder.In), Value: e.values}}
|
||||
}
|
||||
|
||||
func In(values ...any) builder.Expression {
|
||||
var flattenedValues []any
|
||||
|
||||
for _, v := range values {
|
||||
switch reflect.TypeOf(v).Kind() {
|
||||
case reflect.Slice:
|
||||
slice := reflect.ValueOf(v)
|
||||
for i := range slice.Len() {
|
||||
flattenedValues = append(flattenedValues, slice.Index(i).Interface())
|
||||
}
|
||||
default:
|
||||
flattenedValues = append(flattenedValues, v)
|
||||
}
|
||||
}
|
||||
return &inImpl{values: flattenedValues}
|
||||
}
|
||||
71
api/pkg/db/internal/mongo/repositoryimp/builderimp/field.go
Normal file
71
api/pkg/db/internal/mongo/repositoryimp/builderimp/field.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
)
|
||||
|
||||
type FieldImp struct {
|
||||
fields []string
|
||||
}
|
||||
|
||||
func (b *FieldImp) Dot(field string) builder.Field {
|
||||
newFields := make([]string, len(b.fields), len(b.fields)+1)
|
||||
copy(newFields, b.fields)
|
||||
newFields = append(newFields, field)
|
||||
return &FieldImp{fields: newFields}
|
||||
}
|
||||
|
||||
func (b *FieldImp) CopyWith(field string) builder.Field {
|
||||
copiedFields := make([]string, 0, len(b.fields)+1)
|
||||
copiedFields = append(copiedFields, b.fields...)
|
||||
copiedFields = append(copiedFields, field)
|
||||
return &FieldImp{
|
||||
fields: copiedFields,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *FieldImp) Build() string {
|
||||
return strings.Join(b.fields, ".")
|
||||
}
|
||||
|
||||
func NewFieldImp(baseName string) builder.Field {
|
||||
return &FieldImp{
|
||||
fields: []string{baseName},
|
||||
}
|
||||
}
|
||||
|
||||
type RefField struct {
|
||||
imp builder.Field
|
||||
}
|
||||
|
||||
func (b *RefField) Build() string {
|
||||
return "$" + b.imp.Build()
|
||||
}
|
||||
|
||||
func (b *RefField) CopyWith(field string) builder.Field {
|
||||
return &RefField{
|
||||
imp: b.imp.CopyWith(field),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *RefField) Dot(field string) builder.Field {
|
||||
return &RefField{
|
||||
imp: b.imp.Dot(field),
|
||||
}
|
||||
}
|
||||
|
||||
func NewRefFieldImp(field builder.Field) builder.Field {
|
||||
return &RefField{
|
||||
imp: field,
|
||||
}
|
||||
}
|
||||
|
||||
func NewRootRef() builder.Field {
|
||||
return NewFieldImp("$$ROOT")
|
||||
}
|
||||
|
||||
func NewRemoveRef() builder.Field {
|
||||
return NewFieldImp("$$REMOVE")
|
||||
}
|
||||
137
api/pkg/db/internal/mongo/repositoryimp/builderimp/func.go
Normal file
137
api/pkg/db/internal/mongo/repositoryimp/builderimp/func.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
type condImp struct {
|
||||
condition builder.Expression
|
||||
ifTrue any
|
||||
ifFalse any
|
||||
}
|
||||
|
||||
func (c *condImp) Build() any {
|
||||
return bson.D{
|
||||
{Key: string(builder.Cond), Value: bson.D{
|
||||
{Key: "if", Value: c.condition.Build()},
|
||||
{Key: "then", Value: c.ifTrue},
|
||||
{Key: "else", Value: c.ifFalse},
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func NewCond(condition builder.Expression, ifTrue, ifFalse any) builder.Expression {
|
||||
return &condImp{
|
||||
condition: condition,
|
||||
ifTrue: ifTrue,
|
||||
ifFalse: ifFalse,
|
||||
}
|
||||
}
|
||||
|
||||
// setUnionImp implements builder.Expression but takes only builder.Array inputs.
|
||||
type setUnionImp struct {
|
||||
inputs []builder.Expression
|
||||
}
|
||||
|
||||
// Build renders the $setUnion stage:
|
||||
//
|
||||
// { $setUnion: [ <array1>, <array2>, … ] }
|
||||
func (s *setUnionImp) Build() any {
|
||||
arr := make(bson.A, len(s.inputs))
|
||||
for i, arrayExpr := range s.inputs {
|
||||
arr[i] = arrayExpr.Build()
|
||||
}
|
||||
return bson.D{
|
||||
{Key: string(builder.SetUnion), Value: arr},
|
||||
}
|
||||
}
|
||||
|
||||
// NewSetUnion constructs a new $setUnion expression from the given Arrays.
|
||||
func NewSetUnion(arrays ...builder.Expression) builder.Expression {
|
||||
return &setUnionImp{inputs: arrays}
|
||||
}
|
||||
|
||||
type assignmentImp struct {
|
||||
field builder.Field
|
||||
expression builder.Expression
|
||||
}
|
||||
|
||||
func (a *assignmentImp) Build() bson.D {
|
||||
// Assign it to the given field name
|
||||
return bson.D{
|
||||
{Key: a.field.Build(), Value: a.expression.Build()},
|
||||
}
|
||||
}
|
||||
|
||||
// NewAssignment creates a projection assignment of the form:
|
||||
//
|
||||
// <field>: <expression>
|
||||
func NewAssignment(field builder.Field, expression builder.Expression) builder.Projection {
|
||||
return &assignmentImp{
|
||||
field: field,
|
||||
expression: expression,
|
||||
}
|
||||
}
|
||||
|
||||
type computeImp struct {
|
||||
field builder.Field
|
||||
expression builder.Expression
|
||||
}
|
||||
|
||||
func (a *computeImp) Build() any {
|
||||
return bson.D{
|
||||
{Key: string(a.field.Build()), Value: a.expression.Build()},
|
||||
}
|
||||
}
|
||||
|
||||
func NewCompute(field builder.Field, expression builder.Expression) builder.Expression {
|
||||
return &computeImp{
|
||||
field: field,
|
||||
expression: expression,
|
||||
}
|
||||
}
|
||||
|
||||
func NewIfNull(expression, replacement builder.Expression) builder.Expression {
|
||||
return newBinaryExpression(builder.IfNull, expression, replacement)
|
||||
}
|
||||
|
||||
func NewPush(expression builder.Expression) builder.Expression {
|
||||
return newUnaryExpression(builder.Push, expression)
|
||||
}
|
||||
|
||||
func NewAnd(exprs ...builder.Expression) builder.Expression {
|
||||
return newVariadicExpression(builder.And, exprs...)
|
||||
}
|
||||
|
||||
func NewOr(exprs ...builder.Expression) builder.Expression {
|
||||
return newVariadicExpression(builder.Or, exprs...)
|
||||
}
|
||||
|
||||
func NewEach(exprs ...builder.Expression) builder.Expression {
|
||||
return newVariadicExpression(builder.Each, exprs...)
|
||||
}
|
||||
|
||||
func NewLt(left, right builder.Expression) builder.Expression {
|
||||
return newBinaryExpression(builder.Lt, left, right)
|
||||
}
|
||||
|
||||
func NewNot(expression builder.Expression) builder.Expression {
|
||||
return newUnaryExpression(builder.Not, expression)
|
||||
}
|
||||
|
||||
func NewSum(expression builder.Expression) builder.Expression {
|
||||
return newUnaryExpression(builder.Sum, expression)
|
||||
}
|
||||
|
||||
func NewMin(expression builder.Expression) builder.Expression {
|
||||
return newUnaryExpression(builder.Min, expression)
|
||||
}
|
||||
|
||||
func First(expr builder.Expression) builder.Expression {
|
||||
return newUnaryExpression(builder.First, expr)
|
||||
}
|
||||
|
||||
func NewType(expr builder.Expression) builder.Expression {
|
||||
return newUnaryExpression(builder.Type, expr)
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
type groupAccumulatorImp struct {
|
||||
field builder.Field
|
||||
acc builder.Accumulator
|
||||
}
|
||||
|
||||
// NewGroupAccumulator creates a new GroupAccumulator for the given field using the specified operator and value.
|
||||
func NewGroupAccumulator(field builder.Field, acc builder.Accumulator) builder.GroupAccumulator {
|
||||
return &groupAccumulatorImp{
|
||||
field: field,
|
||||
acc: acc,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *groupAccumulatorImp) Field() builder.Field {
|
||||
return g.field
|
||||
}
|
||||
|
||||
func (g *groupAccumulatorImp) Accumulator() builder.Accumulator {
|
||||
return g.acc
|
||||
}
|
||||
|
||||
// Build returns a bson.E element for this group accumulator.
|
||||
func (g *groupAccumulatorImp) Build() bson.D {
|
||||
return bson.D{{
|
||||
Key: g.field.Build(),
|
||||
Value: g.acc.Build(),
|
||||
}}
|
||||
}
|
||||
60
api/pkg/db/internal/mongo/repositoryimp/builderimp/patch.go
Normal file
60
api/pkg/db/internal/mongo/repositoryimp/builderimp/patch.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
type patchBuilder struct {
|
||||
updates bson.D
|
||||
}
|
||||
|
||||
func set(field builder.Field, value any) bson.E {
|
||||
return bson.E{Key: string(builder.Set), Value: bson.D{{Key: field.Build(), Value: value}}}
|
||||
}
|
||||
|
||||
func (u *patchBuilder) Set(field builder.Field, value any) builder.Patch {
|
||||
u.updates = append(u.updates, set(field, value))
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *patchBuilder) Inc(field builder.Field, value any) builder.Patch {
|
||||
u.updates = append(u.updates, bson.E{Key: string(builder.Inc), Value: bson.D{{Key: field.Build(), Value: value}}})
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *patchBuilder) Unset(field builder.Field) builder.Patch {
|
||||
u.updates = append(u.updates, bson.E{Key: string(builder.Unset), Value: bson.D{{Key: field.Build(), Value: ""}}})
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *patchBuilder) Rename(field builder.Field, newName string) builder.Patch {
|
||||
u.updates = append(u.updates, bson.E{Key: string(builder.Rename), Value: bson.D{{Key: field.Build(), Value: newName}}})
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *patchBuilder) Push(field builder.Field, value any) builder.Patch {
|
||||
u.updates = append(u.updates, bson.E{Key: string(builder.Push), Value: bson.D{{Key: field.Build(), Value: value}}})
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *patchBuilder) Pull(field builder.Field, value any) builder.Patch {
|
||||
u.updates = append(u.updates, bson.E{Key: string(builder.Pull), Value: bson.D{{Key: field.Build(), Value: value}}})
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *patchBuilder) AddToSet(field builder.Field, value any) builder.Patch {
|
||||
u.updates = append(u.updates, bson.E{Key: string(builder.AddToSet), Value: bson.D{{Key: field.Build(), Value: value}}})
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *patchBuilder) Build() bson.D {
|
||||
return append(u.updates, set(NewFieldImp(storable.UpdatedAtField), time.Now()))
|
||||
}
|
||||
|
||||
func NewPatchImp() builder.Patch {
|
||||
return &patchBuilder{updates: bson.D{}}
|
||||
}
|
||||
131
api/pkg/db/internal/mongo/repositoryimp/builderimp/pipeline.go
Normal file
131
api/pkg/db/internal/mongo/repositoryimp/builderimp/pipeline.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
type unwindOpts = builder.UnwindOpts
|
||||
|
||||
// UnwindOption is the same type defined in the builder package.
|
||||
type UnwindOption = builder.UnwindOption
|
||||
|
||||
// NewUnwindOpts applies all UnwindOption's to a fresh unwindOpts.
|
||||
func NewUnwindOpts(opts ...UnwindOption) *unwindOpts {
|
||||
cfg := &unwindOpts{}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
type PipelineImp struct {
|
||||
pipeline mongo.Pipeline
|
||||
}
|
||||
|
||||
func (b *PipelineImp) Match(filter builder.Query) builder.Pipeline {
|
||||
b.pipeline = append(b.pipeline, filter.BuildPipeline())
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PipelineImp) Lookup(from mservice.Type, localField, foreignField, as builder.Field) builder.Pipeline {
|
||||
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.Lookup), Value: bson.D{
|
||||
{Key: string(builder.MKFrom), Value: from},
|
||||
{Key: string(builder.MKLocalField), Value: localField.Build()},
|
||||
{Key: string(builder.MKForeignField), Value: foreignField.Build()},
|
||||
{Key: string(builder.MKAs), Value: as.Build()},
|
||||
}}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PipelineImp) LookupWithPipeline(
|
||||
from mservice.Type,
|
||||
nested builder.Pipeline,
|
||||
as builder.Field,
|
||||
let *map[string]builder.Field,
|
||||
) builder.Pipeline {
|
||||
lookupStage := bson.D{
|
||||
{Key: string(builder.MKFrom), Value: from},
|
||||
{Key: string(builder.MKPipeline), Value: nested.Build()},
|
||||
{Key: string(builder.MKAs), Value: as.Build()},
|
||||
}
|
||||
|
||||
// only add "let" if provided and not empty
|
||||
if let != nil && len(*let) > 0 {
|
||||
letDoc := bson.D{}
|
||||
for varName, fld := range *let {
|
||||
letDoc = append(letDoc, bson.E{Key: varName, Value: fld.Build()})
|
||||
}
|
||||
lookupStage = append(lookupStage, bson.E{Key: string(builder.MKLet), Value: letDoc})
|
||||
}
|
||||
|
||||
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.Lookup), Value: lookupStage}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PipelineImp) Unwind(path builder.Field, opts ...UnwindOption) builder.Pipeline {
|
||||
cfg := NewUnwindOpts(opts...)
|
||||
|
||||
var stageValue interface{}
|
||||
// if no options, shorthand
|
||||
if !cfg.PreserveNullAndEmptyArrays && cfg.IncludeArrayIndex == "" {
|
||||
stageValue = path.Build()
|
||||
} else {
|
||||
d := bson.D{{Key: string(builder.MKPath), Value: path.Build()}}
|
||||
if cfg.PreserveNullAndEmptyArrays {
|
||||
d = append(d, bson.E{Key: string(builder.MKPreserveNullAndEmptyArrays), Value: true})
|
||||
}
|
||||
if cfg.IncludeArrayIndex != "" {
|
||||
d = append(d, bson.E{Key: string(builder.MKIncludeArrayIndex), Value: cfg.IncludeArrayIndex})
|
||||
}
|
||||
stageValue = d
|
||||
}
|
||||
|
||||
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.Unwind), Value: stageValue}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PipelineImp) Count(field builder.Field) builder.Pipeline {
|
||||
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.Count), Value: field.Build()}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PipelineImp) Group(groupBy builder.Alias, accumulators ...builder.GroupAccumulator) builder.Pipeline {
|
||||
groupDoc := groupBy.Build()
|
||||
for _, acc := range accumulators {
|
||||
groupDoc = append(groupDoc, acc.Build()...)
|
||||
}
|
||||
|
||||
b.pipeline = append(b.pipeline, bson.D{
|
||||
{Key: string(builder.Group), Value: groupDoc},
|
||||
})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PipelineImp) Project(projections ...builder.Projection) builder.Pipeline {
|
||||
projDoc := bson.D{}
|
||||
for _, pr := range projections {
|
||||
projDoc = append(projDoc, pr.Build()...)
|
||||
}
|
||||
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.Project), Value: projDoc}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PipelineImp) ReplaceRoot(newRoot builder.Expression) builder.Pipeline {
|
||||
b.pipeline = append(b.pipeline, bson.D{{Key: string(builder.ReplaceRoot), Value: bson.D{
|
||||
{Key: string(builder.MKNewRoot), Value: newRoot.Build()},
|
||||
}}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *PipelineImp) Build() mongo.Pipeline {
|
||||
return b.pipeline
|
||||
}
|
||||
|
||||
func NewPipelineImp() builder.Pipeline {
|
||||
return &PipelineImp{
|
||||
pipeline: mongo.Pipeline{},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,563 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
func TestNewPipelineImp(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
|
||||
assert.NotNil(t, pipeline)
|
||||
assert.IsType(t, &PipelineImp{}, pipeline)
|
||||
|
||||
// Build should return empty pipeline initially
|
||||
built := pipeline.Build()
|
||||
assert.NotNil(t, built)
|
||||
assert.Len(t, built, 0)
|
||||
}
|
||||
|
||||
func TestPipelineImp_Match(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockQuery := &MockQuery{
|
||||
buildPipeline: bson.D{{Key: "$match", Value: bson.D{{Key: "field", Value: "value"}}}},
|
||||
}
|
||||
|
||||
result := pipeline.Match(mockQuery)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
assert.Equal(t, bson.D{{Key: "$match", Value: bson.D{{Key: "field", Value: "value"}}}}, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Lookup(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockLocalField := &MockField{build: "localField"}
|
||||
mockForeignField := &MockField{build: "foreignField"}
|
||||
mockAsField := &MockField{build: "asField"}
|
||||
|
||||
result := pipeline.Lookup(mservice.Projects, mockLocalField, mockForeignField, mockAsField)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Lookup), Value: bson.D{
|
||||
{Key: string(builder.MKFrom), Value: mservice.Projects},
|
||||
{Key: string(builder.MKLocalField), Value: "localField"},
|
||||
{Key: string(builder.MKForeignField), Value: "foreignField"},
|
||||
{Key: string(builder.MKAs), Value: "asField"},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_LookupWithPipeline_WithoutLet(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockNestedPipeline := &MockPipeline{
|
||||
build: mongo.Pipeline{bson.D{{Key: "$match", Value: bson.D{{Key: "nested", Value: true}}}}},
|
||||
}
|
||||
mockAsField := &MockField{build: "asField"}
|
||||
|
||||
result := pipeline.LookupWithPipeline(mservice.Tasks, mockNestedPipeline, mockAsField, nil)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Lookup), Value: bson.D{
|
||||
{Key: string(builder.MKFrom), Value: mservice.Tasks},
|
||||
{Key: string(builder.MKPipeline), Value: mockNestedPipeline.build},
|
||||
{Key: string(builder.MKAs), Value: "asField"},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_LookupWithPipeline_WithLet(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockNestedPipeline := &MockPipeline{
|
||||
build: mongo.Pipeline{bson.D{{Key: "$match", Value: bson.D{{Key: "nested", Value: true}}}}},
|
||||
}
|
||||
mockAsField := &MockField{build: "asField"}
|
||||
mockLetField := &MockField{build: "$_id"}
|
||||
|
||||
letVars := map[string]builder.Field{
|
||||
"projRef": mockLetField,
|
||||
}
|
||||
|
||||
result := pipeline.LookupWithPipeline(mservice.Tasks, mockNestedPipeline, mockAsField, &letVars)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Lookup), Value: bson.D{
|
||||
{Key: string(builder.MKFrom), Value: mservice.Tasks},
|
||||
{Key: string(builder.MKPipeline), Value: mockNestedPipeline.build},
|
||||
{Key: string(builder.MKAs), Value: "asField"},
|
||||
{Key: string(builder.MKLet), Value: bson.D{{Key: "projRef", Value: "$_id"}}},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_LookupWithPipeline_WithEmptyLet(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockNestedPipeline := &MockPipeline{
|
||||
build: mongo.Pipeline{bson.D{{Key: "$match", Value: bson.D{{Key: "nested", Value: true}}}}},
|
||||
}
|
||||
mockAsField := &MockField{build: "asField"}
|
||||
|
||||
emptyLetVars := map[string]builder.Field{}
|
||||
|
||||
pipeline.LookupWithPipeline(mservice.Tasks, mockNestedPipeline, mockAsField, &emptyLetVars)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
// Should not include let field when empty
|
||||
expected := bson.D{{Key: string(builder.Lookup), Value: bson.D{
|
||||
{Key: string(builder.MKFrom), Value: mservice.Tasks},
|
||||
{Key: string(builder.MKPipeline), Value: mockNestedPipeline.build},
|
||||
{Key: string(builder.MKAs), Value: "asField"},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Unwind_Simple(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockField := &MockField{build: "$array"}
|
||||
|
||||
result := pipeline.Unwind(mockField)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Unwind), Value: "$array"}}
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Unwind_WithPreserveNullAndEmptyArrays(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockField := &MockField{build: "$array"}
|
||||
|
||||
// Mock the UnwindOption function
|
||||
preserveOpt := func(opts *builder.UnwindOpts) {
|
||||
opts.PreserveNullAndEmptyArrays = true
|
||||
}
|
||||
|
||||
pipeline.Unwind(mockField, preserveOpt)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Unwind), Value: bson.D{
|
||||
{Key: string(builder.MKPath), Value: "$array"},
|
||||
{Key: string(builder.MKPreserveNullAndEmptyArrays), Value: true},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Unwind_WithIncludeArrayIndex(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockField := &MockField{build: "$array"}
|
||||
|
||||
// Mock the UnwindOption function
|
||||
indexOpt := func(opts *builder.UnwindOpts) {
|
||||
opts.IncludeArrayIndex = "arrayIndex"
|
||||
}
|
||||
|
||||
pipeline.Unwind(mockField, indexOpt)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Unwind), Value: bson.D{
|
||||
{Key: string(builder.MKPath), Value: "$array"},
|
||||
{Key: string(builder.MKIncludeArrayIndex), Value: "arrayIndex"},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Unwind_WithBothOptions(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockField := &MockField{build: "$array"}
|
||||
|
||||
// Mock the UnwindOption functions
|
||||
preserveOpt := func(opts *builder.UnwindOpts) {
|
||||
opts.PreserveNullAndEmptyArrays = true
|
||||
}
|
||||
indexOpt := func(opts *builder.UnwindOpts) {
|
||||
opts.IncludeArrayIndex = "arrayIndex"
|
||||
}
|
||||
|
||||
pipeline.Unwind(mockField, preserveOpt, indexOpt)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Unwind), Value: bson.D{
|
||||
{Key: string(builder.MKPath), Value: "$array"},
|
||||
{Key: string(builder.MKPreserveNullAndEmptyArrays), Value: true},
|
||||
{Key: string(builder.MKIncludeArrayIndex), Value: "arrayIndex"},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Count(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockField := &MockField{build: "totalCount"}
|
||||
|
||||
result := pipeline.Count(mockField)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Count), Value: "totalCount"}}
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Group(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockAlias := &MockAlias{
|
||||
build: bson.D{{Key: "_id", Value: "$field"}},
|
||||
field: &MockField{build: "_id"},
|
||||
}
|
||||
mockAccumulator := &MockGroupAccumulator{
|
||||
build: bson.D{{Key: "count", Value: bson.D{{Key: "$sum", Value: 1}}}},
|
||||
}
|
||||
|
||||
result := pipeline.Group(mockAlias, mockAccumulator)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Group), Value: bson.D{
|
||||
{Key: "_id", Value: "$field"},
|
||||
{Key: "count", Value: bson.D{{Key: "$sum", Value: 1}}},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Group_MultipleAccumulators(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockAlias := &MockAlias{
|
||||
build: bson.D{{Key: "_id", Value: "$field"}},
|
||||
field: &MockField{build: "_id"},
|
||||
}
|
||||
mockAccumulator1 := &MockGroupAccumulator{
|
||||
build: bson.D{{Key: "count", Value: bson.D{{Key: "$sum", Value: 1}}}},
|
||||
}
|
||||
mockAccumulator2 := &MockGroupAccumulator{
|
||||
build: bson.D{{Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}}},
|
||||
}
|
||||
|
||||
pipeline.Group(mockAlias, mockAccumulator1, mockAccumulator2)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Group), Value: bson.D{
|
||||
{Key: "_id", Value: "$field"},
|
||||
{Key: "count", Value: bson.D{{Key: "$sum", Value: 1}}},
|
||||
{Key: "total", Value: bson.D{{Key: "$sum", Value: "$amount"}}},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Project(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockProjection := &MockProjection{
|
||||
build: bson.D{{Key: "field1", Value: 1}},
|
||||
}
|
||||
|
||||
result := pipeline.Project(mockProjection)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Project), Value: bson.D{
|
||||
{Key: "field1", Value: 1},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_Project_MultipleProjections(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockProjection1 := &MockProjection{
|
||||
build: bson.D{{Key: "field1", Value: 1}},
|
||||
}
|
||||
mockProjection2 := &MockProjection{
|
||||
build: bson.D{{Key: "field2", Value: 0}},
|
||||
}
|
||||
|
||||
pipeline.Project(mockProjection1, mockProjection2)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.Project), Value: bson.D{
|
||||
{Key: "field1", Value: 1},
|
||||
{Key: "field2", Value: 0},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_ChainedOperations(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
|
||||
// Create mocks
|
||||
mockQuery := &MockQuery{
|
||||
buildPipeline: bson.D{{Key: "$match", Value: bson.D{{Key: "status", Value: "active"}}}},
|
||||
}
|
||||
mockLocalField := &MockField{build: "userId"}
|
||||
mockForeignField := &MockField{build: "_id"}
|
||||
mockAsField := &MockField{build: "user"}
|
||||
mockUnwindField := &MockField{build: "$user"}
|
||||
mockProjection := &MockProjection{
|
||||
build: bson.D{{Key: "name", Value: "$user.name"}},
|
||||
}
|
||||
|
||||
// Chain operations
|
||||
result := pipeline.
|
||||
Match(mockQuery).
|
||||
Lookup(mservice.Accounts, mockLocalField, mockForeignField, mockAsField).
|
||||
Unwind(mockUnwindField).
|
||||
Project(mockProjection)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 4)
|
||||
|
||||
// Verify each stage
|
||||
assert.Equal(t, bson.D{{Key: "$match", Value: bson.D{{Key: "status", Value: "active"}}}}, built[0])
|
||||
|
||||
expectedLookup := bson.D{{Key: string(builder.Lookup), Value: bson.D{
|
||||
{Key: string(builder.MKFrom), Value: mservice.Accounts},
|
||||
{Key: string(builder.MKLocalField), Value: "userId"},
|
||||
{Key: string(builder.MKForeignField), Value: "_id"},
|
||||
{Key: string(builder.MKAs), Value: "user"},
|
||||
}}}
|
||||
assert.Equal(t, expectedLookup, built[1])
|
||||
|
||||
assert.Equal(t, bson.D{{Key: string(builder.Unwind), Value: "$user"}}, built[2])
|
||||
|
||||
expectedProject := bson.D{{Key: string(builder.Project), Value: bson.D{
|
||||
{Key: "name", Value: "$user.name"},
|
||||
}}}
|
||||
assert.Equal(t, expectedProject, built[3])
|
||||
}
|
||||
|
||||
func TestNewUnwindOpts(t *testing.T) {
|
||||
t.Run("NoOptions", func(t *testing.T) {
|
||||
opts := NewUnwindOpts()
|
||||
|
||||
assert.NotNil(t, opts)
|
||||
assert.False(t, opts.PreserveNullAndEmptyArrays)
|
||||
assert.Empty(t, opts.IncludeArrayIndex)
|
||||
})
|
||||
|
||||
t.Run("WithPreserveOption", func(t *testing.T) {
|
||||
preserveOpt := func(opts *builder.UnwindOpts) {
|
||||
opts.PreserveNullAndEmptyArrays = true
|
||||
}
|
||||
|
||||
opts := NewUnwindOpts(preserveOpt)
|
||||
|
||||
assert.True(t, opts.PreserveNullAndEmptyArrays)
|
||||
assert.Empty(t, opts.IncludeArrayIndex)
|
||||
})
|
||||
|
||||
t.Run("WithIndexOption", func(t *testing.T) {
|
||||
indexOpt := func(opts *builder.UnwindOpts) {
|
||||
opts.IncludeArrayIndex = "index"
|
||||
}
|
||||
|
||||
opts := NewUnwindOpts(indexOpt)
|
||||
|
||||
assert.False(t, opts.PreserveNullAndEmptyArrays)
|
||||
assert.Equal(t, "index", opts.IncludeArrayIndex)
|
||||
})
|
||||
|
||||
t.Run("WithBothOptions", func(t *testing.T) {
|
||||
preserveOpt := func(opts *builder.UnwindOpts) {
|
||||
opts.PreserveNullAndEmptyArrays = true
|
||||
}
|
||||
indexOpt := func(opts *builder.UnwindOpts) {
|
||||
opts.IncludeArrayIndex = "index"
|
||||
}
|
||||
|
||||
opts := NewUnwindOpts(preserveOpt, indexOpt)
|
||||
|
||||
assert.True(t, opts.PreserveNullAndEmptyArrays)
|
||||
assert.Equal(t, "index", opts.IncludeArrayIndex)
|
||||
})
|
||||
}
|
||||
|
||||
// Mock implementations for testing
|
||||
|
||||
type MockQuery struct {
|
||||
buildPipeline bson.D
|
||||
}
|
||||
|
||||
func (m *MockQuery) And(filters ...builder.Query) builder.Query { return m }
|
||||
func (m *MockQuery) Or(filters ...builder.Query) builder.Query { return m }
|
||||
func (m *MockQuery) Filter(field builder.Field, value any) builder.Query { return m }
|
||||
func (m *MockQuery) Expression(value builder.Expression) builder.Query { return m }
|
||||
func (m *MockQuery) Comparison(field builder.Field, operator builder.MongoOperation, value any) builder.Query {
|
||||
return m
|
||||
}
|
||||
func (m *MockQuery) RegEx(field builder.Field, pattern, options string) builder.Query { return m }
|
||||
func (m *MockQuery) In(field builder.Field, values ...any) builder.Query { return m }
|
||||
func (m *MockQuery) NotIn(field builder.Field, values ...any) builder.Query { return m }
|
||||
func (m *MockQuery) Sort(field builder.Field, ascending bool) builder.Query { return m }
|
||||
func (m *MockQuery) Limit(limit *int64) builder.Query { return m }
|
||||
func (m *MockQuery) Offset(offset *int64) builder.Query { return m }
|
||||
func (m *MockQuery) Archived(isArchived *bool) builder.Query { return m }
|
||||
func (m *MockQuery) BuildPipeline() bson.D { return m.buildPipeline }
|
||||
func (m *MockQuery) BuildQuery() bson.D { return bson.D{} }
|
||||
func (m *MockQuery) BuildOptions() *options.FindOptions { return &options.FindOptions{} }
|
||||
|
||||
type MockField struct {
|
||||
build string
|
||||
}
|
||||
|
||||
func (m *MockField) Dot(field string) builder.Field { return &MockField{build: m.build + "." + field} }
|
||||
func (m *MockField) CopyWith(field string) builder.Field { return &MockField{build: field} }
|
||||
func (m *MockField) Build() string { return m.build }
|
||||
|
||||
type MockPipeline struct {
|
||||
build mongo.Pipeline
|
||||
}
|
||||
|
||||
func (m *MockPipeline) Match(filter builder.Query) builder.Pipeline { return m }
|
||||
func (m *MockPipeline) Lookup(from mservice.Type, localField, foreignField, as builder.Field) builder.Pipeline {
|
||||
return m
|
||||
}
|
||||
func (m *MockPipeline) LookupWithPipeline(from mservice.Type, pipeline builder.Pipeline, as builder.Field, let *map[string]builder.Field) builder.Pipeline {
|
||||
return m
|
||||
}
|
||||
func (m *MockPipeline) Unwind(path builder.Field, opts ...UnwindOption) builder.Pipeline { return m }
|
||||
func (m *MockPipeline) Count(field builder.Field) builder.Pipeline { return m }
|
||||
func (m *MockPipeline) Group(groupBy builder.Alias, accumulators ...builder.GroupAccumulator) builder.Pipeline {
|
||||
return m
|
||||
}
|
||||
func (m *MockPipeline) Project(projections ...builder.Projection) builder.Pipeline { return m }
|
||||
func (m *MockPipeline) ReplaceRoot(newRoot builder.Expression) builder.Pipeline { return m }
|
||||
func (m *MockPipeline) Build() mongo.Pipeline { return m.build }
|
||||
|
||||
type MockAlias struct {
|
||||
build bson.D
|
||||
field builder.Field
|
||||
}
|
||||
|
||||
func (m *MockAlias) Field() builder.Field { return m.field }
|
||||
func (m *MockAlias) Build() bson.D { return m.build }
|
||||
|
||||
type MockGroupAccumulator struct {
|
||||
build bson.D
|
||||
}
|
||||
|
||||
func (m *MockGroupAccumulator) Build() bson.D { return m.build }
|
||||
|
||||
type MockProjection struct {
|
||||
build bson.D
|
||||
}
|
||||
|
||||
func (m *MockProjection) Build() bson.D { return m.build }
|
||||
|
||||
func TestPipelineImp_ReplaceRoot(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockExpr := &MockExpression{build: "$newRoot"}
|
||||
|
||||
result := pipeline.ReplaceRoot(mockExpr)
|
||||
|
||||
// Should return self for chaining
|
||||
assert.Same(t, pipeline, result)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.ReplaceRoot), Value: bson.D{
|
||||
{Key: string(builder.MKNewRoot), Value: "$newRoot"},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_ReplaceRoot_WithNestedField(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
mockExpr := &MockExpression{build: "$document.data"}
|
||||
|
||||
pipeline.ReplaceRoot(mockExpr)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.ReplaceRoot), Value: bson.D{
|
||||
{Key: string(builder.MKNewRoot), Value: "$document.data"},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
func TestPipelineImp_ReplaceRoot_WithExpression(t *testing.T) {
|
||||
pipeline := NewPipelineImp()
|
||||
// Mock a complex expression like { $mergeObjects: [...] }
|
||||
mockExpr := &MockExpression{build: bson.D{{Key: "$mergeObjects", Value: bson.A{"$field1", "$field2"}}}}
|
||||
|
||||
pipeline.ReplaceRoot(mockExpr)
|
||||
|
||||
built := pipeline.Build()
|
||||
assert.Len(t, built, 1)
|
||||
|
||||
expected := bson.D{{Key: string(builder.ReplaceRoot), Value: bson.D{
|
||||
{Key: string(builder.MKNewRoot), Value: bson.D{{Key: "$mergeObjects", Value: bson.A{"$field1", "$field2"}}}},
|
||||
}}}
|
||||
|
||||
assert.Equal(t, expected, built[0])
|
||||
}
|
||||
|
||||
type MockExpression struct {
|
||||
build any
|
||||
}
|
||||
|
||||
func (m *MockExpression) Build() any { return m.build }
|
||||
@@ -0,0 +1,97 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
)
|
||||
|
||||
// projectionExprImp is a concrete implementation of builder.Projection
|
||||
// that projects a field using a custom expression.
|
||||
type projectionExprImp struct {
|
||||
expr builder.Expression // The expression for this projection.
|
||||
field builder.Field // The field name for the projected field.
|
||||
}
|
||||
|
||||
// Field returns the field being projected.
|
||||
func (p *projectionExprImp) Field() builder.Field {
|
||||
return p.field
|
||||
}
|
||||
|
||||
// Expression returns the expression for the projection.
|
||||
func (p *projectionExprImp) Expression() builder.Expression {
|
||||
return p.expr
|
||||
}
|
||||
|
||||
// Build returns the built expression. If no expression is provided, returns 1.
|
||||
func (p *projectionExprImp) Build() bson.D {
|
||||
if p.expr == nil {
|
||||
return bson.D{{Key: p.field.Build(), Value: 1}}
|
||||
}
|
||||
return bson.D{{Key: p.field.Build(), Value: p.expr.Build()}}
|
||||
}
|
||||
|
||||
// NewProjectionExpr creates a new Projection for a given field and expression.
|
||||
func NewProjectionExpr(field builder.Field, expr builder.Expression) builder.Projection {
|
||||
return &projectionExprImp{field: field, expr: expr}
|
||||
}
|
||||
|
||||
// aliasProjectionImp is a concrete implementation of builder.Projection
|
||||
// that projects an alias (renaming a field or expression).
|
||||
type aliasProjectionImp struct {
|
||||
alias builder.Alias // The alias for this projection.
|
||||
}
|
||||
|
||||
// Field returns the field being projected (via the alias).
|
||||
func (p *aliasProjectionImp) Field() builder.Field {
|
||||
return p.alias.Field()
|
||||
}
|
||||
|
||||
// Expression returns no additional expression for an alias projection.
|
||||
func (p *aliasProjectionImp) Expression() builder.Expression {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build returns the built alias expression.
|
||||
func (p *aliasProjectionImp) Build() bson.D {
|
||||
return p.alias.Build()
|
||||
}
|
||||
|
||||
// NewAliasProjection creates a new Projection that renames or wraps an existing field or expression.
|
||||
func NewAliasProjection(alias builder.Alias) builder.Projection {
|
||||
return &aliasProjectionImp{alias: alias}
|
||||
}
|
||||
|
||||
// sinkProjectionImp is a simple include/exclude projection (0 or 1).
|
||||
type sinkProjectionImp struct {
|
||||
field builder.Field // The field name for the projected field.
|
||||
val int // 1 to include, 0 to exclude.
|
||||
}
|
||||
|
||||
// Expression returns no expression for a sink projection.
|
||||
func (p *sinkProjectionImp) Expression() builder.Expression {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build returns the include/exclude projection.
|
||||
func (p *sinkProjectionImp) Build() bson.D {
|
||||
return bson.D{{Key: p.field.Build(), Value: p.val}}
|
||||
}
|
||||
|
||||
// NewSinkProjection creates a new Projection that includes (true) or excludes (false) a field.
|
||||
func NewSinkProjection(field builder.Field, include bool) builder.Projection {
|
||||
val := 0
|
||||
if include {
|
||||
val = 1
|
||||
}
|
||||
return &sinkProjectionImp{field: field, val: val}
|
||||
}
|
||||
|
||||
// IncludeField returns a projection including the given field.
|
||||
func IncludeField(field builder.Field) builder.Projection {
|
||||
return NewSinkProjection(field, true)
|
||||
}
|
||||
|
||||
// ExcludeField returns a projection excluding the given field.
|
||||
func ExcludeField(field builder.Field) builder.Projection {
|
||||
return NewSinkProjection(field, false)
|
||||
}
|
||||
156
api/pkg/db/internal/mongo/repositoryimp/builderimp/query.go
Normal file
156
api/pkg/db/internal/mongo/repositoryimp/builderimp/query.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type QueryImp struct {
|
||||
filter bson.D
|
||||
sort bson.D
|
||||
limit *int64
|
||||
offset *int64
|
||||
}
|
||||
|
||||
func (b *QueryImp) Filter(field builder.Field, value any) builder.Query {
|
||||
b.filter = append(b.filter, bson.E{Key: field.Build(), Value: value})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) And(filters ...builder.Query) builder.Query {
|
||||
andFilters := bson.A{}
|
||||
for _, f := range filters {
|
||||
andFilters = append(andFilters, f.BuildQuery())
|
||||
}
|
||||
b.filter = append(b.filter, bson.E{Key: string(builder.And), Value: andFilters})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) Or(filters ...builder.Query) builder.Query {
|
||||
orFilters := bson.A{}
|
||||
for _, f := range filters {
|
||||
orFilters = append(orFilters, f.BuildQuery())
|
||||
}
|
||||
b.filter = append(b.filter, bson.E{Key: string(builder.Or), Value: orFilters})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) Comparison(field builder.Field, operator builder.MongoOperation, value any) builder.Query {
|
||||
b.filter = append(b.filter, bson.E{Key: field.Build(), Value: bson.M{string(operator): value}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) Expression(value builder.Expression) builder.Query {
|
||||
b.filter = append(b.filter, bson.E{Key: string(builder.Expr), Value: value.Build()})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) RegEx(field builder.Field, pattern, options string) builder.Query {
|
||||
b.filter = append(b.filter, bson.E{Key: field.Build(), Value: primitive.Regex{Pattern: pattern, Options: options}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) opIn(field builder.Field, op builder.MongoOperation, values ...any) builder.Query {
|
||||
var flattenedValues []any
|
||||
|
||||
for _, v := range values {
|
||||
switch reflect.TypeOf(v).Kind() {
|
||||
case reflect.Slice:
|
||||
slice := reflect.ValueOf(v)
|
||||
for i := range slice.Len() {
|
||||
flattenedValues = append(flattenedValues, slice.Index(i).Interface())
|
||||
}
|
||||
default:
|
||||
flattenedValues = append(flattenedValues, v)
|
||||
}
|
||||
}
|
||||
|
||||
b.filter = append(b.filter, bson.E{Key: field.Build(), Value: bson.M{string(op): flattenedValues}})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) NotIn(field builder.Field, values ...any) builder.Query {
|
||||
return b.opIn(field, builder.NotIn, values...)
|
||||
}
|
||||
|
||||
func (b *QueryImp) In(field builder.Field, values ...any) builder.Query {
|
||||
return b.opIn(field, builder.In, values...)
|
||||
}
|
||||
|
||||
func (b *QueryImp) Archived(isArchived *bool) builder.Query {
|
||||
if isArchived == nil {
|
||||
return b
|
||||
}
|
||||
return b.And(NewQueryImp().Filter(NewFieldImp(storable.IsArchivedField), *isArchived))
|
||||
}
|
||||
|
||||
func (b *QueryImp) Sort(field builder.Field, ascending bool) builder.Query {
|
||||
order := 1
|
||||
if !ascending {
|
||||
order = -1
|
||||
}
|
||||
b.sort = append(b.sort, bson.E{Key: field.Build(), Value: order})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) BuildPipeline() bson.D {
|
||||
query := bson.D{}
|
||||
|
||||
if len(b.filter) > 0 {
|
||||
query = append(query, bson.E{Key: string(builder.Match), Value: b.filter})
|
||||
}
|
||||
|
||||
if len(b.sort) > 0 {
|
||||
query = append(query, bson.E{Key: string(builder.Sort), Value: b.sort})
|
||||
}
|
||||
|
||||
if b.limit != nil {
|
||||
query = append(query, bson.E{Key: string(builder.Limit), Value: *b.limit})
|
||||
}
|
||||
|
||||
if b.offset != nil {
|
||||
query = append(query, bson.E{Key: string(builder.Skip), Value: *b.offset})
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func (b *QueryImp) BuildQuery() bson.D {
|
||||
return b.filter
|
||||
}
|
||||
|
||||
func (b *QueryImp) Limit(limit *int64) builder.Query {
|
||||
b.limit = limit
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) Offset(offset *int64) builder.Query {
|
||||
b.offset = offset
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *QueryImp) BuildOptions() *options.FindOptions {
|
||||
opts := options.Find()
|
||||
if b.limit != nil {
|
||||
opts.SetLimit(*b.limit)
|
||||
}
|
||||
if b.offset != nil {
|
||||
opts.SetSkip(*b.offset)
|
||||
}
|
||||
if len(b.sort) > 0 {
|
||||
opts.SetSort(b.sort)
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func NewQueryImp() builder.Query {
|
||||
return &QueryImp{
|
||||
filter: bson.D{},
|
||||
sort: bson.D{},
|
||||
}
|
||||
}
|
||||
17
api/pkg/db/internal/mongo/repositoryimp/builderimp/value.go
Normal file
17
api/pkg/db/internal/mongo/repositoryimp/builderimp/value.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package builderimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
)
|
||||
|
||||
type valueImp struct {
|
||||
value any
|
||||
}
|
||||
|
||||
func (v *valueImp) Build() any {
|
||||
return v.value
|
||||
}
|
||||
|
||||
func NewValue(value any) builder.Value {
|
||||
return &valueImp{value: value}
|
||||
}
|
||||
50
api/pkg/db/internal/mongo/repositoryimp/index.go
Normal file
50
api/pkg/db/internal/mongo/repositoryimp/index.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package repositoryimp
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
ri "github.com/tech/sendico/pkg/db/repository/index"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
func (r *MongoRepository) CreateIndex(def *ri.Definition) error {
|
||||
if r.collection == nil {
|
||||
return merrors.NoData("data collection is not set")
|
||||
}
|
||||
if len(def.Keys) == 0 {
|
||||
return merrors.InvalidArgument("Index definition has no keys")
|
||||
}
|
||||
|
||||
// ----- build BSON keys --------------------------------------------------
|
||||
keys := bson.D{}
|
||||
for _, k := range def.Keys {
|
||||
var value any
|
||||
switch {
|
||||
case k.Type != "":
|
||||
value = k.Type // text, 2dsphere, …
|
||||
case k.Sort == ri.Desc:
|
||||
value = int8(-1)
|
||||
default:
|
||||
value = int8(1) // default to Asc
|
||||
}
|
||||
keys = append(keys, bson.E{Key: k.Field, Value: value})
|
||||
}
|
||||
|
||||
opts := options.Index().
|
||||
SetUnique(def.Unique)
|
||||
if def.TTL != nil {
|
||||
opts.SetExpireAfterSeconds(*def.TTL)
|
||||
}
|
||||
if def.Name != "" {
|
||||
opts.SetName(def.Name)
|
||||
}
|
||||
|
||||
_, err := r.collection.Indexes().CreateOne(
|
||||
context.Background(),
|
||||
mongo.IndexModel{Keys: keys, Options: opts},
|
||||
)
|
||||
return err
|
||||
}
|
||||
250
api/pkg/db/internal/mongo/repositoryimp/repository.go
Normal file
250
api/pkg/db/internal/mongo/repositoryimp/repository.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package repositoryimp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
rd "github.com/tech/sendico/pkg/db/repository/decoder"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type MongoRepository struct {
|
||||
collectionName string
|
||||
collection *mongo.Collection
|
||||
}
|
||||
|
||||
func idFilter(id primitive.ObjectID) bson.D {
|
||||
return bson.D{
|
||||
{Key: storable.IDField, Value: id},
|
||||
}
|
||||
}
|
||||
|
||||
func NewMongoRepository(db *mongo.Database, collection string) *MongoRepository {
|
||||
return &MongoRepository{
|
||||
collectionName: collection,
|
||||
collection: db.Collection(collection),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MongoRepository) Collection() string {
|
||||
return r.collectionName
|
||||
}
|
||||
|
||||
func (r *MongoRepository) Insert(ctx context.Context, obj storable.Storable, getFilter builder.Query) error {
|
||||
if (obj.GetID() == nil) || (obj.GetID().IsZero()) {
|
||||
obj.SetID(primitive.NewObjectID())
|
||||
}
|
||||
obj.Update()
|
||||
_, err := r.collection.InsertOne(ctx, obj)
|
||||
if mongo.IsDuplicateKeyError(err) {
|
||||
if getFilter != nil {
|
||||
if err = r.FindOneByFilter(ctx, getFilter, obj); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return merrors.DataConflict("duplicate_key")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MongoRepository) InsertMany(ctx context.Context, objects []storable.Storable) error {
|
||||
if len(objects) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
docs := make([]interface{}, len(objects))
|
||||
for i, obj := range objects {
|
||||
if (obj.GetID() == nil) || (obj.GetID().IsZero()) {
|
||||
obj.SetID(primitive.NewObjectID())
|
||||
}
|
||||
obj.Update()
|
||||
docs[i] = obj
|
||||
}
|
||||
|
||||
_, err := r.collection.InsertMany(ctx, docs)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MongoRepository) findOneByFilterImp(ctx context.Context, filter bson.D, errMessage string, result storable.Storable) error {
|
||||
err := r.collection.FindOne(ctx, filter).Decode(result)
|
||||
if errors.Is(err, mongo.ErrNoDocuments) {
|
||||
return merrors.NoData(errMessage)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MongoRepository) Get(ctx context.Context, id primitive.ObjectID, result storable.Storable) error {
|
||||
if id.IsZero() {
|
||||
return merrors.InvalidArgument("zero id provided while fetching " + result.Collection())
|
||||
}
|
||||
return r.findOneByFilterImp(ctx, idFilter(id), fmt.Sprintf("%s with ID = %s not found", result.Collection(), id.Hex()), result)
|
||||
}
|
||||
|
||||
type QueryFunc func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error)
|
||||
|
||||
func (r *MongoRepository) executeQuery(ctx context.Context, queryFunc QueryFunc, decoder rd.DecodingFunc) error {
|
||||
cursor, err := queryFunc(ctx, r.collection)
|
||||
if errors.Is(err, mongo.ErrNoDocuments) {
|
||||
return merrors.NoData("no_items_in_array")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
for cursor.Next(ctx) {
|
||||
if err = decoder(cursor); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MongoRepository) Aggregate(ctx context.Context, pipeline builder.Pipeline, decoder rd.DecodingFunc) error {
|
||||
queryFunc := func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error) {
|
||||
return collection.Aggregate(ctx, pipeline.Build())
|
||||
}
|
||||
return r.executeQuery(ctx, queryFunc, decoder)
|
||||
}
|
||||
|
||||
func (r *MongoRepository) FindManyByFilter(ctx context.Context, query builder.Query, decoder rd.DecodingFunc) error {
|
||||
queryFunc := func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error) {
|
||||
return collection.Find(ctx, query.BuildQuery(), query.BuildOptions())
|
||||
}
|
||||
return r.executeQuery(ctx, queryFunc, decoder)
|
||||
}
|
||||
|
||||
func (r *MongoRepository) FindOneByFilter(ctx context.Context, query builder.Query, result storable.Storable) error {
|
||||
return r.findOneByFilterImp(ctx, query.BuildQuery(), result.Collection()+" not found by filter", result)
|
||||
}
|
||||
|
||||
func (r *MongoRepository) Update(ctx context.Context, obj storable.Storable) error {
|
||||
obj.Update()
|
||||
return r.collection.FindOneAndReplace(ctx, idFilter(*obj.GetID()), obj).Err()
|
||||
}
|
||||
|
||||
func (r *MongoRepository) Patch(ctx context.Context, id primitive.ObjectID, patch builder.Patch) error {
|
||||
if id.IsZero() {
|
||||
return merrors.InvalidArgument("zero id provided while patching")
|
||||
}
|
||||
_, err := r.collection.UpdateByID(ctx, id, patch.Build())
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MongoRepository) PatchMany(ctx context.Context, query builder.Query, patch builder.Patch) (int, error) {
|
||||
result, err := r.collection.UpdateMany(ctx, query.BuildQuery(), patch.Build())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(result.ModifiedCount), nil
|
||||
}
|
||||
|
||||
func (r *MongoRepository) ListIDs(ctx context.Context, query builder.Query) ([]primitive.ObjectID, error) {
|
||||
filter := query.BuildQuery()
|
||||
findOptions := options.Find().SetProjection(bson.M{storable.IDField: 1})
|
||||
|
||||
cursor, err := r.collection.Find(ctx, filter, findOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
var ids []primitive.ObjectID
|
||||
for cursor.Next(ctx) {
|
||||
var doc struct {
|
||||
ID primitive.ObjectID `bson:"_id"`
|
||||
}
|
||||
if err := cursor.Decode(&doc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, doc.ID)
|
||||
}
|
||||
if err := cursor.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *MongoRepository) ListPermissionBound(ctx context.Context, query builder.Query) ([]model.PermissionBoundStorable, error) {
|
||||
filter := query.BuildQuery()
|
||||
findOptions := options.Find().SetProjection(bson.M{
|
||||
storable.IDField: 1,
|
||||
storable.PermissionRefField: 1,
|
||||
storable.OrganizationRefField: 1,
|
||||
})
|
||||
|
||||
cursor, err := r.collection.Find(ctx, filter, findOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
result := make([]model.PermissionBoundStorable, 0)
|
||||
|
||||
for cursor.Next(ctx) {
|
||||
var doc model.PermissionBound
|
||||
if err := cursor.Decode(&doc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, &doc)
|
||||
}
|
||||
if err := cursor.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *MongoRepository) ListAccountBound(ctx context.Context, query builder.Query) ([]model.AccountBoundStorable, error) {
|
||||
filter := query.BuildQuery()
|
||||
findOptions := options.Find().SetProjection(bson.M{
|
||||
storable.IDField: 1,
|
||||
model.AccountRefField: 1,
|
||||
model.OrganizationRefField: 1,
|
||||
})
|
||||
|
||||
cursor, err := r.collection.Find(ctx, filter, findOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
result := make([]model.AccountBoundStorable, 0)
|
||||
|
||||
for cursor.Next(ctx) {
|
||||
var doc model.AccountBoundBase
|
||||
if err := cursor.Decode(&doc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, &doc)
|
||||
}
|
||||
if err := cursor.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *MongoRepository) Delete(ctx context.Context, id primitive.ObjectID) error {
|
||||
_, err := r.collection.DeleteOne(ctx, idFilter(id))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MongoRepository) DeleteMany(ctx context.Context, query builder.Query) error {
|
||||
_, err := r.collection.DeleteMany(ctx, query.BuildQuery())
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *MongoRepository) Name() string {
|
||||
return r.collection.Name()
|
||||
}
|
||||
@@ -0,0 +1,577 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package repositoryimp_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp/builderimp"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
func TestMongoRepository_Insert(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("Insert_WithoutID", func(t *testing.T) {
|
||||
testObj := &TestObject{Name: "testInsert"}
|
||||
// ID should be nil/zero initially
|
||||
assert.True(t, testObj.GetID().IsZero())
|
||||
|
||||
err := repository.Insert(ctx, testObj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// ID should be assigned after insert
|
||||
assert.False(t, testObj.GetID().IsZero())
|
||||
assert.NotEmpty(t, testObj.CreatedAt)
|
||||
assert.NotEmpty(t, testObj.UpdatedAt)
|
||||
})
|
||||
|
||||
t.Run("Insert_WithExistingID", func(t *testing.T) {
|
||||
existingID := primitive.NewObjectID()
|
||||
testObj := &TestObject{Name: "testInsertWithID"}
|
||||
testObj.SetID(existingID)
|
||||
|
||||
err := repository.Insert(ctx, testObj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// ID should remain the same
|
||||
assert.Equal(t, existingID, *testObj.GetID())
|
||||
})
|
||||
|
||||
t.Run("Insert_DuplicateKey", func(t *testing.T) {
|
||||
// Insert first object
|
||||
testObj1 := &TestObject{Name: "duplicate"}
|
||||
err := repository.Insert(ctx, testObj1, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to insert object with same ID
|
||||
testObj2 := &TestObject{Name: "duplicate2"}
|
||||
testObj2.SetID(*testObj1.GetID())
|
||||
|
||||
err = repository.Insert(ctx, testObj2, nil)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrDataConflict))
|
||||
})
|
||||
|
||||
t.Run("Insert_DuplicateKeyWithGetFilter", func(t *testing.T) {
|
||||
// Insert first object
|
||||
testObj1 := &TestObject{Name: "duplicateWithFilter"}
|
||||
err := repository.Insert(ctx, testObj1, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to insert object with same ID, but with getFilter
|
||||
testObj2 := &TestObject{Name: "duplicateWithFilter2"}
|
||||
testObj2.SetID(*testObj1.GetID())
|
||||
|
||||
getFilter := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("_id"), builder.Eq, *testObj1.GetID())
|
||||
|
||||
err = repository.Insert(ctx, testObj2, getFilter)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrDataConflict))
|
||||
|
||||
// But testObj2 should be populated with the existing object data
|
||||
assert.Equal(t, testObj1.Name, testObj2.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_Update(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("Update_ExistingObject", func(t *testing.T) {
|
||||
// Insert object first
|
||||
testObj := &TestObject{Name: "originalName"}
|
||||
err := repository.Insert(ctx, testObj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
originalUpdatedAt := testObj.UpdatedAt
|
||||
|
||||
// Update the object
|
||||
testObj.Name = "updatedName"
|
||||
time.Sleep(10 * time.Millisecond) // Ensure time difference
|
||||
|
||||
err = repository.Update(ctx, testObj)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the object was updated
|
||||
result := &TestObject{}
|
||||
err = repository.Get(ctx, *testObj.GetID(), result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "updatedName", result.Name)
|
||||
assert.True(t, result.UpdatedAt.After(originalUpdatedAt))
|
||||
})
|
||||
|
||||
t.Run("Update_NonExistentObject", func(t *testing.T) {
|
||||
nonExistentID := primitive.NewObjectID()
|
||||
testObj := &TestObject{Name: "nonExistent"}
|
||||
testObj.SetID(nonExistentID)
|
||||
|
||||
err := repository.Update(ctx, testObj)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, mongo.ErrNoDocuments))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_Delete(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("Delete_ExistingObject", func(t *testing.T) {
|
||||
// Insert object first
|
||||
testObj := &TestObject{Name: "toDelete"}
|
||||
err := repository.Insert(ctx, testObj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete the object
|
||||
err = repository.Delete(ctx, *testObj.GetID())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the object was deleted
|
||||
result := &TestObject{}
|
||||
err = repository.Get(ctx, *testObj.GetID(), result)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("Delete_NonExistentObject", func(t *testing.T) {
|
||||
nonExistentID := primitive.NewObjectID()
|
||||
|
||||
err := repository.Delete(ctx, nonExistentID)
|
||||
// Delete should not return error even if object doesn't exist
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_FindOneByFilter(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("FindOneByFilter_MatchingFilter", func(t *testing.T) {
|
||||
// Insert test objects
|
||||
testObjs := []*TestObject{
|
||||
{Name: "findMe"},
|
||||
{Name: "dontFindMe"},
|
||||
{Name: "findMeToo"},
|
||||
}
|
||||
|
||||
for _, obj := range testObjs {
|
||||
err := repository.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Find by filter
|
||||
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("name"), builder.Eq, "findMe")
|
||||
result := &TestObject{}
|
||||
|
||||
err := repository.FindOneByFilter(ctx, query, result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "findMe", result.Name)
|
||||
})
|
||||
|
||||
t.Run("FindOneByFilter_NoMatch", func(t *testing.T) {
|
||||
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("name"), builder.Eq, "nonExistentName")
|
||||
result := &TestObject{}
|
||||
|
||||
err := repository.FindOneByFilter(ctx, query, result)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_FindManyByFilter(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("FindManyByFilter_MultipleResults", func(t *testing.T) {
|
||||
// Insert test objects
|
||||
testObjs := []*TestObject{
|
||||
{Name: "findMany1"},
|
||||
{Name: "findMany2"},
|
||||
{Name: "dontFind"},
|
||||
}
|
||||
|
||||
for _, obj := range testObjs {
|
||||
err := repository.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Find objects with names starting with "findMany"
|
||||
query := builderimp.NewQueryImp().RegEx(builderimp.NewFieldImp("name"), "^findMany", "")
|
||||
|
||||
var results []*TestObject
|
||||
decoder := func(cursor *mongo.Cursor) error {
|
||||
var obj TestObject
|
||||
if err := cursor.Decode(&obj); err != nil {
|
||||
return err
|
||||
}
|
||||
results = append(results, &obj)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := repository.FindManyByFilter(ctx, query, decoder)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
|
||||
names := make([]string, len(results))
|
||||
for i, obj := range results {
|
||||
names[i] = obj.Name
|
||||
}
|
||||
assert.Contains(t, names, "findMany1")
|
||||
assert.Contains(t, names, "findMany2")
|
||||
})
|
||||
|
||||
t.Run("FindManyByFilter_NoResults", func(t *testing.T) {
|
||||
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("name"), builder.Eq, "nonExistentPattern")
|
||||
|
||||
var results []*TestObject
|
||||
decoder := func(cursor *mongo.Cursor) error {
|
||||
var obj TestObject
|
||||
if err := cursor.Decode(&obj); err != nil {
|
||||
return err
|
||||
}
|
||||
results = append(results, &obj)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := repository.FindManyByFilter(ctx, query, decoder)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, results)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_DeleteMany(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("DeleteMany_MultipleDocuments", func(t *testing.T) {
|
||||
// Insert test objects
|
||||
testObjs := []*TestObject{
|
||||
{Name: "deleteMany1"},
|
||||
{Name: "deleteMany2"},
|
||||
{Name: "keepMe"},
|
||||
}
|
||||
|
||||
for _, obj := range testObjs {
|
||||
err := repository.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Delete objects with names starting with "deleteMany"
|
||||
query := builderimp.NewQueryImp().RegEx(builderimp.NewFieldImp("name"), "^deleteMany", "")
|
||||
|
||||
err := repository.DeleteMany(ctx, query)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify deletions
|
||||
queryAll := builderimp.NewQueryImp()
|
||||
var results []*TestObject
|
||||
decoder := func(cursor *mongo.Cursor) error {
|
||||
var obj TestObject
|
||||
if err := cursor.Decode(&obj); err != nil {
|
||||
return err
|
||||
}
|
||||
results = append(results, &obj)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = repository.FindManyByFilter(ctx, queryAll, decoder)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
assert.Equal(t, "keepMe", results[0].Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_Name(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "mycollection")
|
||||
|
||||
t.Run("Name_ReturnsCollectionName", func(t *testing.T) {
|
||||
name := repository.Name()
|
||||
assert.Equal(t, "mycollection", name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_ListPermissionBound(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("ListPermissionBound_WithData", func(t *testing.T) {
|
||||
// Insert test objects with permission bound data
|
||||
orgID := primitive.NewObjectID()
|
||||
|
||||
// Insert documents directly with permission bound fields
|
||||
_, err := db.Collection("testcollection").InsertMany(ctx, []interface{}{
|
||||
bson.M{
|
||||
"_id": primitive.NewObjectID(),
|
||||
"organizationRef": orgID,
|
||||
"permissionRef": primitive.NewObjectID(),
|
||||
},
|
||||
bson.M{
|
||||
"_id": primitive.NewObjectID(),
|
||||
"organizationRef": orgID,
|
||||
"permissionRef": primitive.NewObjectID(),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Query for permission bound objects
|
||||
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("organizationRef"), builder.Eq, orgID)
|
||||
|
||||
results, err := repository.ListPermissionBound(ctx, query)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
|
||||
for _, result := range results {
|
||||
assert.Equal(t, orgID, result.GetOrganizationRef())
|
||||
assert.NotNil(t, result.GetPermissionRef())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ListPermissionBound_EmptyResult", func(t *testing.T) {
|
||||
nonExistentOrgID := primitive.NewObjectID()
|
||||
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("organizationRef"), builder.Eq, nonExistentOrgID)
|
||||
|
||||
results, err := repository.ListPermissionBound(ctx, query)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, results)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_UpdateTimestamp(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("Update_Should_Update_Timestamp", func(t *testing.T) {
|
||||
// Create test object
|
||||
obj := &TestObject{
|
||||
Name: "Test Object",
|
||||
}
|
||||
|
||||
// Set ID and initial timestamps
|
||||
obj.SetID(primitive.NewObjectID())
|
||||
originalCreatedAt := obj.CreatedAt
|
||||
originalUpdatedAt := obj.UpdatedAt
|
||||
|
||||
// Insert the object
|
||||
err := repository.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait a moment to ensure timestamp difference
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Update the object
|
||||
obj.Name = "Updated Object"
|
||||
err = repository.Update(ctx, obj)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify timestamps
|
||||
assert.Equal(t, originalCreatedAt, obj.CreatedAt, "CreatedAt should not change")
|
||||
assert.True(t, obj.UpdatedAt.After(originalUpdatedAt), "UpdatedAt should be updated")
|
||||
|
||||
// Verify the object was actually updated in the database
|
||||
var retrieved TestObject
|
||||
err = repository.Get(ctx, *obj.GetID(), &retrieved)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "Updated Object", retrieved.Name, "Name should be updated")
|
||||
assert.WithinDuration(t, originalCreatedAt, retrieved.CreatedAt, time.Second, "CreatedAt should not change in DB")
|
||||
assert.True(t, retrieved.UpdatedAt.After(originalUpdatedAt), "UpdatedAt should be updated in DB")
|
||||
assert.WithinDuration(t, obj.UpdatedAt, retrieved.UpdatedAt, time.Second, "UpdatedAt should match between object and DB")
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package repositoryimp_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
func TestMongoRepository_InsertMany(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("InsertMany_Success", func(t *testing.T) {
|
||||
objects := []storable.Storable{
|
||||
&TestObject{Name: "test1"},
|
||||
&TestObject{Name: "test2"},
|
||||
&TestObject{Name: "test3"},
|
||||
}
|
||||
|
||||
err := repository.InsertMany(ctx, objects)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all objects were inserted and have IDs
|
||||
for _, obj := range objects {
|
||||
assert.NotNil(t, obj.GetID())
|
||||
assert.False(t, obj.GetID().IsZero())
|
||||
|
||||
// Verify we can retrieve each object
|
||||
result := &TestObject{}
|
||||
err := repository.Get(ctx, *obj.GetID(), result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, obj.(*TestObject).Name, result.Name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InsertMany_EmptySlice", func(t *testing.T) {
|
||||
objects := []storable.Storable{}
|
||||
|
||||
err := repository.InsertMany(ctx, objects)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("InsertMany_WithExistingIDs", func(t *testing.T) {
|
||||
id1 := primitive.NewObjectID()
|
||||
id2 := primitive.NewObjectID()
|
||||
|
||||
objects := []storable.Storable{
|
||||
&TestObject{Base: storable.Base{ID: id1}, Name: "preassigned1"},
|
||||
&TestObject{Base: storable.Base{ID: id2}, Name: "preassigned2"},
|
||||
}
|
||||
|
||||
err := repository.InsertMany(ctx, objects)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify objects were inserted with pre-assigned IDs
|
||||
result1 := &TestObject{}
|
||||
err = repository.Get(ctx, id1, result1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "preassigned1", result1.Name)
|
||||
|
||||
result2 := &TestObject{}
|
||||
err = repository.Get(ctx, id2, result2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "preassigned2", result2.Name)
|
||||
})
|
||||
|
||||
t.Run("InsertMany_MixedTypes", func(t *testing.T) {
|
||||
objects := []storable.Storable{
|
||||
&TestObject{Name: "test1"},
|
||||
&AnotherObject{Description: "desc1"},
|
||||
&TestObject{Name: "test2"},
|
||||
}
|
||||
|
||||
err := repository.InsertMany(ctx, objects)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all objects were inserted
|
||||
for _, obj := range objects {
|
||||
assert.NotNil(t, obj.GetID())
|
||||
assert.False(t, obj.GetID().IsZero())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InsertMany_DuplicateKey", func(t *testing.T) {
|
||||
id := primitive.NewObjectID()
|
||||
|
||||
// Insert first object
|
||||
obj1 := &TestObject{Base: storable.Base{ID: id}, Name: "original"}
|
||||
err := repository.Insert(ctx, obj1, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to insert multiple objects including one with duplicate ID
|
||||
objects := []storable.Storable{
|
||||
&TestObject{Name: "test1"},
|
||||
&TestObject{Base: storable.Base{ID: id}, Name: "duplicate"},
|
||||
}
|
||||
|
||||
err = repository.InsertMany(ctx, objects)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, mongo.IsDuplicateKeyError(err))
|
||||
})
|
||||
|
||||
t.Run("InsertMany_UpdateTimestamps", func(t *testing.T) {
|
||||
objects := []storable.Storable{
|
||||
&TestObject{Name: "test1"},
|
||||
&TestObject{Name: "test2"},
|
||||
}
|
||||
|
||||
err := repository.InsertMany(ctx, objects)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify timestamps were set
|
||||
for _, obj := range objects {
|
||||
testObj := obj.(*TestObject)
|
||||
assert.NotZero(t, testObj.CreatedAt)
|
||||
assert.NotZero(t, testObj.UpdatedAt)
|
||||
}
|
||||
})
|
||||
}
|
||||
233
api/pkg/db/internal/mongo/repositoryimp/repository_patch_test.go
Normal file
233
api/pkg/db/internal/mongo/repositoryimp/repository_patch_test.go
Normal file
@@ -0,0 +1,233 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package repositoryimp_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
func TestMongoRepository_PatchOperations(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repo := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("Patch_SingleDocument", func(t *testing.T) {
|
||||
obj := &TestObject{Name: "old"}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
original := obj.UpdatedAt
|
||||
|
||||
patch := repository.Patch().Set(repository.Field("name"), "new")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestObject
|
||||
err = repo.Get(ctx, *obj.GetID(), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "new", result.Name)
|
||||
assert.True(t, result.UpdatedAt.After(original))
|
||||
})
|
||||
|
||||
t.Run("PatchMany_MultipleDocuments", func(t *testing.T) {
|
||||
objs := []*TestObject{{Name: "match"}, {Name: "match"}, {Name: "other"}}
|
||||
for _, o := range objs {
|
||||
err := repo.Insert(ctx, o, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
query := repository.Query().Comparison(repository.Field("name"), builder.Eq, "match")
|
||||
patch := repository.Patch().Set(repository.Field("name"), "patched")
|
||||
modified, err := repo.PatchMany(ctx, query, patch)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, modified)
|
||||
|
||||
verify := repository.Query().Comparison(repository.Field("name"), builder.Eq, "patched")
|
||||
var results []TestObject
|
||||
decoder := func(cursor *mongo.Cursor) error {
|
||||
var obj TestObject
|
||||
if err := cursor.Decode(&obj); err != nil {
|
||||
return err
|
||||
}
|
||||
results = append(results, obj)
|
||||
return nil
|
||||
}
|
||||
err = repo.FindManyByFilter(ctx, verify, decoder)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
})
|
||||
|
||||
t.Run("Patch_PushArray", func(t *testing.T) {
|
||||
obj := &TestObject{Name: "test", Tags: []string{"tag1"}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
patch := repository.Patch().Push(repository.Field("tags"), "tag2")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestObject
|
||||
err = repo.Get(ctx, *obj.GetID(), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"tag1", "tag2"}, result.Tags)
|
||||
})
|
||||
|
||||
t.Run("Patch_PullArray", func(t *testing.T) {
|
||||
obj := &TestObject{Name: "test", Tags: []string{"tag1", "tag2", "tag3"}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
patch := repository.Patch().Pull(repository.Field("tags"), "tag2")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestObject
|
||||
err = repo.Get(ctx, *obj.GetID(), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"tag1", "tag3"}, result.Tags)
|
||||
})
|
||||
|
||||
t.Run("Patch_AddToSetArray", func(t *testing.T) {
|
||||
obj := &TestObject{Name: "test", Tags: []string{"tag1"}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add new tag
|
||||
patch := repository.Patch().AddToSet(repository.Field("tags"), "tag2")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestObject
|
||||
err = repo.Get(ctx, *obj.GetID(), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"tag1", "tag2"}, result.Tags)
|
||||
|
||||
// Try to add duplicate tag - should not add
|
||||
patch = repository.Patch().AddToSet(repository.Field("tags"), "tag1")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.Get(ctx, *obj.GetID(), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"tag1", "tag2"}, result.Tags)
|
||||
})
|
||||
|
||||
t.Run("Patch_PushToEmptyArray", func(t *testing.T) {
|
||||
obj := &TestObject{Name: "test", Tags: []string{}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
patch := repository.Patch().Push(repository.Field("tags"), "tag1")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestObject
|
||||
err = repo.Get(ctx, *obj.GetID(), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"tag1"}, result.Tags)
|
||||
})
|
||||
|
||||
t.Run("Patch_PullFromEmptyArray", func(t *testing.T) {
|
||||
obj := &TestObject{Name: "test", Tags: []string{}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
patch := repository.Patch().Pull(repository.Field("tags"), "nonexistent")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestObject
|
||||
err = repo.Get(ctx, *obj.GetID(), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{}, result.Tags)
|
||||
})
|
||||
|
||||
t.Run("Patch_PullNonExistentElement", func(t *testing.T) {
|
||||
obj := &TestObject{Name: "test", Tags: []string{"tag1", "tag2"}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
patch := repository.Patch().Pull(repository.Field("tags"), "nonexistent")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result TestObject
|
||||
err = repo.Get(ctx, *obj.GetID(), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"tag1", "tag2"}, result.Tags)
|
||||
})
|
||||
|
||||
t.Run("Patch_ChainedArrayOperations", func(t *testing.T) {
|
||||
obj := &TestObject{Name: "test", Tags: []string{"tag1"}}
|
||||
err := repo.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Note: MongoDB doesn't allow multiple operations on the same array field in a single update
|
||||
// This test demonstrates that chained array operations on the same field will fail
|
||||
patch := repository.Patch().
|
||||
Push(repository.Field("tags"), "tag2").
|
||||
AddToSet(repository.Field("tags"), "tag3").
|
||||
Pull(repository.Field("tags"), "tag1")
|
||||
err = repo.Patch(ctx, *obj.GetID(), patch)
|
||||
require.Error(t, err) // This should fail due to MongoDB's limitation
|
||||
assert.Contains(t, err.Error(), "conflict")
|
||||
})
|
||||
|
||||
t.Run("PatchMany_ArrayOperations", func(t *testing.T) {
|
||||
objs := []*TestObject{
|
||||
{Name: "obj1", Tags: []string{"tag1"}},
|
||||
{Name: "obj2", Tags: []string{"tag2"}},
|
||||
{Name: "obj3", Tags: []string{"tag3"}},
|
||||
}
|
||||
for _, o := range objs {
|
||||
err := repo.Insert(ctx, o, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
query := repository.Query().Comparison(repository.Field("name"), builder.In, []string{"obj1", "obj2"})
|
||||
patch := repository.Patch().Push(repository.Field("tags"), "common")
|
||||
modified, err := repo.PatchMany(ctx, query, patch)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, modified)
|
||||
|
||||
// Verify the changes
|
||||
for _, name := range []string{"obj1", "obj2"} {
|
||||
var result TestObject
|
||||
err = repo.FindOneByFilter(ctx, repository.Query().Comparison(repository.Field("name"), builder.Eq, name), &result)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, result.Tags, "common")
|
||||
}
|
||||
})
|
||||
}
|
||||
188
api/pkg/db/internal/mongo/repositoryimp/repository_test.go
Normal file
188
api/pkg/db/internal/mongo/repositoryimp/repository_test.go
Normal file
@@ -0,0 +1,188 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package repositoryimp_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp"
|
||||
"github.com/tech/sendico/pkg/db/internal/mongo/repositoryimp/builderimp"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
"github.com/tech/sendico/pkg/db/storable"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/modules/mongodb"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type TestObject struct {
|
||||
storable.Base `bson:",inline" json:",inline"`
|
||||
Name string `bson:"name"`
|
||||
Tags []string `bson:"tags"`
|
||||
}
|
||||
|
||||
func (t *TestObject) Collection() string {
|
||||
return "testObject"
|
||||
}
|
||||
|
||||
type AnotherObject struct {
|
||||
storable.Base `bson:",inline" json:",inline"`
|
||||
Description string `bson:"description"`
|
||||
}
|
||||
|
||||
func (a *AnotherObject) Collection() string {
|
||||
return "anotherObject"
|
||||
}
|
||||
|
||||
func terminate(ctx context.Context, t *testing.T, container *mongodb.MongoDBContainer) {
|
||||
err := container.Terminate(ctx)
|
||||
require.NoError(t, err, "failed to terminate MongoDB container")
|
||||
}
|
||||
|
||||
func disconnect(ctx context.Context, t *testing.T, client *mongo.Client) {
|
||||
err := client.Disconnect(ctx)
|
||||
require.NoError(t, err, "failed to disconnect from MongoDB")
|
||||
}
|
||||
|
||||
func TestMongoRepository_Get(t *testing.T) {
|
||||
// Use a context with timeout, so if container spinning or DB ops hang,
|
||||
// the test won't run indefinitely.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("Get_Success", func(t *testing.T) {
|
||||
testObj := &TestObject{Name: "testName"}
|
||||
err := repository.Insert(ctx, testObj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := &TestObject{}
|
||||
err = repository.Get(ctx, testObj.ID, result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testObj.Name, result.Name)
|
||||
assert.Equal(t, testObj.ID, result.ID)
|
||||
})
|
||||
|
||||
t.Run("Get_NotFound", func(t *testing.T) {
|
||||
nonExistentID := primitive.NewObjectID()
|
||||
result := &TestObject{}
|
||||
|
||||
err := repository.Get(ctx, nonExistentID, result)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrNoData))
|
||||
})
|
||||
|
||||
t.Run("Get_InvalidID", func(t *testing.T) {
|
||||
invalidID := primitive.ObjectID{} // zero value
|
||||
result := &TestObject{}
|
||||
|
||||
err := repository.Get(ctx, invalidID, result)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, errors.Is(err, merrors.ErrInvalidArg))
|
||||
})
|
||||
|
||||
t.Run("Get_DifferentTypes", func(t *testing.T) {
|
||||
anotherObj := &AnotherObject{Description: "testDescription"}
|
||||
err := repository.Insert(ctx, anotherObj, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := &AnotherObject{}
|
||||
err = repository.Get(ctx, anotherObj.ID, result)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, anotherObj.Description, result.Description)
|
||||
assert.Equal(t, anotherObj.ID, result.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMongoRepository_ListIDs(t *testing.T) {
|
||||
// Use a context with timeout, so if container spinning or DB ops hang,
|
||||
// the test won't run indefinitely.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mongoContainer, err := mongodb.Run(ctx,
|
||||
"mongo:latest",
|
||||
mongodb.WithUsername("root"),
|
||||
mongodb.WithPassword("password"),
|
||||
testcontainers.WithWaitStrategy(wait.ForLog("Waiting for connections")),
|
||||
)
|
||||
require.NoError(t, err, "failed to start MongoDB container")
|
||||
defer terminate(ctx, t, mongoContainer)
|
||||
|
||||
mongoURI, err := mongoContainer.ConnectionString(ctx)
|
||||
require.NoError(t, err, "failed to get MongoDB connection string")
|
||||
|
||||
clientOptions := options.Client().ApplyURI(mongoURI)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
require.NoError(t, err, "failed to connect to MongoDB")
|
||||
defer disconnect(ctx, t, client)
|
||||
|
||||
db := client.Database("testdb")
|
||||
repository := repositoryimp.NewMongoRepository(db, "testcollection")
|
||||
|
||||
t.Run("ListIDs_Success", func(t *testing.T) {
|
||||
// Insert test data
|
||||
testObjs := []*TestObject{
|
||||
{Name: "testName1"},
|
||||
{Name: "testName2"},
|
||||
{Name: "testName3"},
|
||||
}
|
||||
for _, obj := range testObjs {
|
||||
err := repository.Insert(ctx, obj, nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Define a query to match all objects
|
||||
query := builderimp.NewQueryImp()
|
||||
|
||||
// Call ListIDs
|
||||
ids, err := repository.ListIDs(ctx, query)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert the IDs are correct
|
||||
require.Len(t, ids, len(testObjs))
|
||||
for _, obj := range testObjs {
|
||||
assert.Contains(t, ids, obj.ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ListIDs_EmptyResult", func(t *testing.T) {
|
||||
// Define a query that matches no objects
|
||||
query := builderimp.NewQueryImp().Comparison(builderimp.NewFieldImp("name"), builder.Eq, "nonExistentName")
|
||||
|
||||
// Call ListIDs
|
||||
ids, err := repository.ListIDs(ctx, query)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert no IDs are returned
|
||||
assert.Empty(t, ids)
|
||||
})
|
||||
}
|
||||
21
api/pkg/db/internal/mongo/rolesdb/db.go
Normal file
21
api/pkg/db/internal/mongo/rolesdb/db.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package rolesdb
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/template"
|
||||
"github.com/tech/sendico/pkg/mlogger"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
"github.com/tech/sendico/pkg/mservice"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
type RolesDB struct {
|
||||
template.DBImp[*model.RoleDescription]
|
||||
}
|
||||
|
||||
func Create(logger mlogger.Logger, db *mongo.Database) (*RolesDB, error) {
|
||||
p := &RolesDB{
|
||||
DBImp: *template.Create[*model.RoleDescription](logger, mservice.Roles, db),
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
15
api/pkg/db/internal/mongo/rolesdb/list.go
Normal file
15
api/pkg/db/internal/mongo/rolesdb/list.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package rolesdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
mutil "github.com/tech/sendico/pkg/mutil/db"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *RolesDB) List(ctx context.Context, organizationRef primitive.ObjectID, cursor *model.ViewCursor) ([]model.RoleDescription, error) {
|
||||
filter := repository.OrgFilter(organizationRef)
|
||||
return mutil.GetObjects[model.RoleDescription](ctx, db.Logger, filter, cursor, db.Repository)
|
||||
}
|
||||
15
api/pkg/db/internal/mongo/rolesdb/roles.go
Normal file
15
api/pkg/db/internal/mongo/rolesdb/roles.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package rolesdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/model"
|
||||
mutil "github.com/tech/sendico/pkg/mutil/db"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func (db *RolesDB) Roles(ctx context.Context, refs []primitive.ObjectID) ([]model.RoleDescription, error) {
|
||||
filter := repository.Query().In(repository.IDField(), refs)
|
||||
return mutil.GetObjects[model.RoleDescription](ctx, db.Logger, filter, nil, db.Repository)
|
||||
}
|
||||
18
api/pkg/db/internal/mongo/transactionimp/factory.go
Normal file
18
api/pkg/db/internal/mongo/transactionimp/factory.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package transactionimp
|
||||
|
||||
import (
|
||||
"github.com/tech/sendico/pkg/db/transaction"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
type MongoTransactionFactory struct {
|
||||
client *mongo.Client
|
||||
}
|
||||
|
||||
func (mtf *MongoTransactionFactory) CreateTransaction() transaction.Transaction {
|
||||
return Create(mtf.client)
|
||||
}
|
||||
|
||||
func CreateFactory(client *mongo.Client) transaction.Factory {
|
||||
return &MongoTransactionFactory{client: client}
|
||||
}
|
||||
30
api/pkg/db/internal/mongo/transactionimp/transaction.go
Normal file
30
api/pkg/db/internal/mongo/transactionimp/transaction.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package transactionimp
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/transaction"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
type MongoTransaction struct {
|
||||
client *mongo.Client
|
||||
}
|
||||
|
||||
func (mt *MongoTransaction) Execute(ctx context.Context, cb transaction.Callback) (any, error) {
|
||||
session, err := mt.client.StartSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer session.EndSession(ctx)
|
||||
|
||||
callback := func(sessCtx mongo.SessionContext) (any, error) {
|
||||
return cb(sessCtx)
|
||||
}
|
||||
|
||||
return session.WithTransaction(ctx, callback)
|
||||
}
|
||||
|
||||
func Create(client *mongo.Client) *MongoTransaction {
|
||||
return &MongoTransaction{client: client}
|
||||
}
|
||||
118
api/pkg/db/internal/mongo/tseriesimp/tseries.go
Normal file
118
api/pkg/db/internal/mongo/tseriesimp/tseries.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package tseriesimp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/tech/sendico/pkg/db/repository"
|
||||
"github.com/tech/sendico/pkg/db/repository/builder"
|
||||
rdecoder "github.com/tech/sendico/pkg/db/repository/decoder"
|
||||
tsoptions "github.com/tech/sendico/pkg/db/tseries/options"
|
||||
tspoint "github.com/tech/sendico/pkg/db/tseries/point"
|
||||
"github.com/tech/sendico/pkg/merrors"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type TimeSeries struct {
|
||||
options tsoptions.Options
|
||||
collection *mongo.Collection
|
||||
}
|
||||
|
||||
func NewMongoTimeSeriesCollection(ctx context.Context, db *mongo.Database, tsOpts *tsoptions.Options) (*TimeSeries, error) {
|
||||
if tsOpts == nil {
|
||||
return nil, merrors.InvalidArgument("nil time-series options provided")
|
||||
}
|
||||
// Configure time-series options
|
||||
granularity := tsOpts.Granularity.String()
|
||||
ts := &options.TimeSeriesOptions{
|
||||
TimeField: tsOpts.TimeField,
|
||||
Granularity: &granularity,
|
||||
}
|
||||
if tsOpts.MetaField != "" {
|
||||
ts.MetaField = &tsOpts.MetaField
|
||||
}
|
||||
|
||||
// Collection options
|
||||
collOpts := options.CreateCollection().SetTimeSeriesOptions(ts)
|
||||
|
||||
// Set TTL if requested
|
||||
if tsOpts.ExpireAfter > 0 {
|
||||
secs := int64(tsOpts.ExpireAfter / time.Second)
|
||||
collOpts.SetExpireAfterSeconds(secs)
|
||||
}
|
||||
|
||||
if err := db.CreateCollection(ctx, tsOpts.Collection, collOpts); err != nil {
|
||||
if cmdErr, ok := err.(mongo.CommandError); !ok || cmdErr.Code != 48 {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &TimeSeries{collection: db.Collection(tsOpts.Collection), options: *tsOpts}, nil
|
||||
}
|
||||
|
||||
func (ts *TimeSeries) Aggregate(ctx context.Context, pipeline builder.Pipeline, decoder rdecoder.DecodingFunc) error {
|
||||
queryFunc := func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error) {
|
||||
return collection.Aggregate(ctx, pipeline.Build())
|
||||
}
|
||||
return ts.executeQuery(ctx, decoder, queryFunc)
|
||||
}
|
||||
|
||||
func (ts *TimeSeries) Insert(ctx context.Context, timePoint tspoint.TimePoint) error {
|
||||
_, err := ts.collection.InsertOne(ctx, timePoint)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ts *TimeSeries) InsertMany(ctx context.Context, timePoints []tspoint.TimePoint) error {
|
||||
docs := make([]any, len(timePoints))
|
||||
for i, p := range timePoints {
|
||||
docs[i] = p
|
||||
}
|
||||
|
||||
// ignore the result if you like, or capture it
|
||||
_, err := ts.collection.InsertMany(ctx, docs)
|
||||
return err
|
||||
}
|
||||
|
||||
type QueryFunc func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error)
|
||||
|
||||
func (ts *TimeSeries) executeQuery(ctx context.Context, decoder rdecoder.DecodingFunc, queryFunc QueryFunc) error {
|
||||
cursor, err := queryFunc(ctx, ts.collection)
|
||||
if errors.Is(err, mongo.ErrNoDocuments) {
|
||||
return merrors.NoData("no_items_in_array")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cursor.Close(ctx)
|
||||
|
||||
for cursor.Next(ctx) {
|
||||
if err := cursor.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = decoder(cursor); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts *TimeSeries) Query(ctx context.Context, decoder rdecoder.DecodingFunc, query builder.Query, from, to *time.Time) error {
|
||||
timeLimitedQuery := query
|
||||
if from != nil {
|
||||
timeLimitedQuery = timeLimitedQuery.And(repository.Query().Comparison(repository.Field(ts.options.TimeField), builder.Gte, *from))
|
||||
}
|
||||
if to != nil {
|
||||
timeLimitedQuery = timeLimitedQuery.And(repository.Query().Comparison(repository.Field(ts.options.TimeField), builder.Lte, *to))
|
||||
}
|
||||
queryFunc := func(ctx context.Context, collection *mongo.Collection) (*mongo.Cursor, error) {
|
||||
return collection.Find(ctx, timeLimitedQuery.BuildQuery(), timeLimitedQuery.BuildOptions())
|
||||
}
|
||||
return ts.executeQuery(ctx, decoder, queryFunc)
|
||||
}
|
||||
|
||||
func (ts *TimeSeries) Name() string {
|
||||
return ts.collection.Name()
|
||||
}
|
||||
Reference in New Issue
Block a user