diff --git a/container.go b/container.go new file mode 100644 index 0000000..6a6d2f6 --- /dev/null +++ b/container.go @@ -0,0 +1,91 @@ +package di + +import ( + "fmt" + "reflect" + "strings" + "sync" +) + +type container struct { + mu sync.RWMutex + creators map[string]func() any + instances map[string]any +} + +var globalContainer *container +var once sync.Once + +func getContainer() *container { + once.Do(func() { + globalContainer = &container{ + mu: sync.RWMutex{}, + creators: make(map[string]func() any), + instances: make(map[string]any), + } + }) + return globalContainer +} + +func (c *container) injectable(typ reflect.Type, creator func() any) { + c.mu.Lock() + defer c.mu.Unlock() + + selector := c.getSelector(typ) + c.creators[selector] = creator +} + +func (c *container) replace(typ reflect.Type, creator func() any, identifier ...string) { + c.mu.Lock() + defer c.mu.Unlock() + + selector := c.getSelector(typ) + instanceSelector := c.getSelector(typ, identifier...) + c.creators[selector] = creator + createdInstance := creator() + c.instances[instanceSelector] = createdInstance +} + +func (c *container) inject(typ reflect.Type, identifier ...string) (any, bool) { + instanceSelector := c.getSelector(typ, identifier...) + + c.mu.RLock() + if instance, ok := c.instances[instanceSelector]; ok { + c.mu.RUnlock() + return instance, true + } + c.mu.RUnlock() + + c.mu.RLock() + if instance, ok := c.instances[instanceSelector]; ok { + c.mu.RUnlock() + return instance, true + } + c.mu.RUnlock() + + selector := c.getSelector(typ) + creator, creatorExists := c.creators[selector] + if !creatorExists { + return nil, false + } + + createdInstance := creator() + c.mu.Lock() + c.instances[instanceSelector] = createdInstance + c.mu.Unlock() + return createdInstance, true +} + +func (c *container) destroy(typ reflect.Type, identifier ...string) { + c.mu.Lock() + defer c.mu.Unlock() + + instanceSelector := c.getSelector(typ, identifier...) + delete(c.instances, instanceSelector) +} + +func (c *container) getSelector(typ reflect.Type, identifier ...string) string { + typeName := typ.String() + additionalKey := strings.Join(identifier, "_") + return fmt.Sprintf("%s_%s", typeName, additionalKey) +} diff --git a/injector.go b/injector.go index feaaabb..2bd8fe5 100644 --- a/injector.go +++ b/injector.go @@ -1,83 +1,37 @@ package di import ( - "fmt" "reflect" - "strings" - "sync" ) -var creatorMutex = sync.Mutex{} -var instanceMutex = sync.Mutex{} -var creators = make(map[string]func() any) -var instances = make(map[string]any) - // Injectable marks a constructor Function of a Struct for DI func Injectable[T any](creator func() T) { - creatorMutex.Lock() - defer creatorMutex.Unlock() - creators[getSelector[T]()] = func() any { - return creator() - } + typ := reflect.TypeOf((*T)(nil)).Elem() + getContainer().injectable(typ, func() any { return creator() }) } func Replace[T any](creator func() T, identifier ...string) { - Injectable(creator) - selector := getSelector[T]() - instanceSelector := getSelector[T](identifier...) - cre, creatorExists := creators[selector] - if !creatorExists { - return - } - createdInstance, instanceCreated := cre().(T) - if instanceCreated { - instanceMutex.Lock() - defer instanceMutex.Unlock() - instances[instanceSelector] = createdInstance - } + typ := reflect.TypeOf((*T)(nil)).Elem() + getContainer().replace(typ, func() any { return creator() }) +} + +func ReplaceInstance[T any](instance T, identifier ...string) { + Replace(func() T { return instance }, identifier...) } // Inject gets or create a Instance of the Struct used the Injectable constructor Function func Inject[T any](identifier ...string) T { - var nilResult T - selector := getSelector[T]() - instanceSelector := getSelector[T](identifier...) - _, instanceExists := instances[instanceSelector].(T) - if !instanceExists { - creator, creatorExists := creators[selector] - if !creatorExists { - return nilResult - } - createdInstance, instanceCreated := creator().(T) - if instanceCreated { - instanceMutex.Lock() - defer instanceMutex.Unlock() - instance, instanceExists := instances[instanceSelector].(T) - if instanceExists { - return instance - } - instances[instanceSelector] = createdInstance + var result T + typ := reflect.TypeOf((*T)(nil)).Elem() + if instance, ok := getContainer().inject(typ, identifier...); ok { + if result, ok = instance.(T); ok { + return result } } - return instances[instanceSelector].(T) + return result } func Destroy[T any](identifier ...string) { - instanceMutex.Lock() - defer instanceMutex.Unlock() - instanceSelector := getSelector[T](identifier...) - delete(instances, instanceSelector) -} - -func getSelector[T any](identifier ...string) string { - var def T - typeName := "" - typeOf := reflect.TypeOf(def) - if typeOf != nil { - typeName = typeOf.String() - } else { - typeName = reflect.TypeOf((*T)(nil)).Elem().String() - } - additionalKey := strings.Join(identifier, "_") - return fmt.Sprintf("%s_%s", typeName, additionalKey) + typ := reflect.TypeOf((*T)(nil)).Elem() + getContainer().destroy(typ, identifier...) } diff --git a/injector_test.go b/injector_test.go index dc8710c..40f4959 100644 --- a/injector_test.go +++ b/injector_test.go @@ -116,6 +116,9 @@ func (ctx *messageService) GetTextServiceID() string { } func TestInject(t *testing.T) { + // testMutex.Lock() + // defer testMutex.Unlock() + msg := newMessageService() println(msg.texts.Greeting()) if msg.texts.Greeting() != "Hello Markus" { @@ -124,6 +127,9 @@ func TestInject(t *testing.T) { } func TestInject_Duplicate(t *testing.T) { + // testMutex.Lock() + // defer testMutex.Unlock() + msg1 := newMessageService() msg2 := newMessageService() println(msg1.texts.Greeting()) @@ -192,6 +198,19 @@ func TestOverwriteInjectable(t *testing.T) { } } +func TestOverwriteInjectableInstance(t *testing.T) { + basic := di.Inject[overridableService]() + basicID := basic.GetInstanceID() + di.ReplaceInstance(newBasicOverridableServiceMock()) + basic = di.Inject[overridableService]() + if basic.GetInstanceID() == basicID { + t.Errorf("basic and newOne are the same instance") + } + if basic.GetValue() != "i am mock" { + t.Errorf("service not overwritten") + } +} + func TestDestroy(t *testing.T) { _ = di.Inject[textService]("a") di.Destroy[textService]("a") diff --git a/makefile b/makefile new file mode 100644 index 0000000..135ef0a --- /dev/null +++ b/makefile @@ -0,0 +1,2 @@ +test: + - go test ./... \ No newline at end of file