diff --git a/src/app.go b/src/app.go index 2a98a01..269755b 100644 --- a/src/app.go +++ b/src/app.go @@ -114,21 +114,9 @@ func (a *App) Run(p RunParams) { userRepo = repos.NewUserRepo(sqlDb) emailRepo = repos.NewEmailRepo() actionTokenRepo = repos.NewActionTokenRepo(sqlDb) - jwtCache = repos.NewCacheInmemSharded[string, string](60, 36, func(key string) int { - char := int(key[len(key)-1]) - if char >= 0x30 && char <= 0x39 { - return char - 0x30 - } - if char >= 0x41 && char <= 0x5A { - return char - 0x41 - } - return char - 0x61 - }) //repos.NewCacheInmem[string, string](60) - linksCache = repos.NewCacheInmem[string, string](7 * 24 * 60 * 60) - userCache = repos.NewCacheInmemSharded[string, models.UserDTO](60*60, 10, func(key string) int { - char := int(key[len(key)-1]) - return char - 0x30 - }) //repos.NewCacheInmem[string, models.UserDTO](60 * 60) + userCache = repos.NewCacheInmemSharded[models.UserDTO](60*60, repos.ShardingTypeInteger) + jwtCache = repos.NewCacheInmemSharded[string](60, repos.ShardingTypeJWT) + linksCache = repos.NewCacheInmem[string, string](7 * 24 * 60 * 60) ) // Periodically trigger cache cleanup @@ -141,8 +129,8 @@ func (a *App) Run(p RunParams) { case <-ctx.Done(): return case <-tmr.C: - jwtCache.CheckExpired() userCache.CheckExpired() + jwtCache.CheckExpired() linksCache.CheckExpired() } } diff --git a/src/core/repos/cache_inmem_shard.go b/src/core/repos/cache_inmem_shard.go index 457f864..3d17ea2 100644 --- a/src/core/repos/cache_inmem_shard.go +++ b/src/core/repos/cache_inmem_shard.go @@ -4,52 +4,85 @@ import ( "sync" ) -func NewCacheInmemSharded[K comparable, V any]( - ttlSeconds, shards int, - hashFunc func(key K) int, -) Cache[K, V] { - inmems := []*cacheInmem[K, V]{} +type ShardingType int + +const ( + ShardingTypeJWT ShardingType = iota + ShardingTypeInteger +) + +type shardingHashFunc func(key string) int + +func getShardingInfo(shardingType ShardingType) (int, shardingHashFunc) { + switch shardingType { + case ShardingTypeInteger: + return 10, func(key string) int { + char := int(key[len(key)-1]) + return char - 0x30 + } + case ShardingTypeJWT: + return 36, func(key string) int { + char := int(key[len(key)-1]) + if char >= 0x30 && char <= 0x39 { + return char - 0x30 + } + if char >= 0x41 && char <= 0x5A { + return char - 0x41 + } + return char - 0x61 + } + } + + return 1, func(key string) int { + return 0 + } +} + +func NewCacheInmemSharded[V any](defaultTtlSeconds int, shardingType ShardingType) Cache[string, V] { + shards, hashFunc := getShardingInfo(shardingType) + + inmems := []*cacheInmem[string, V]{} for i := 0; i < shards; i++ { inmems = append( inmems, - &cacheInmem[K, V]{ + &cacheInmem[string, V]{ m: &sync.Mutex{}, - data: map[K]*cacheInmemItem[V]{}, - ttlSeconds: ttlSeconds, + data: map[string]*cacheInmemItem[V]{}, + ttlSeconds: defaultTtlSeconds, }, ) } - return &cacheInmemSharded[K, V]{ + return &cacheInmemSharded[V]{ shards: inmems, hashFunc: hashFunc, } } -type cacheInmemSharded[K comparable, V any] struct { - hashFunc func(key K) int - shards []*cacheInmem[K, V] +type cacheInmemSharded[V any] struct { + hashFunc shardingHashFunc + shards []*cacheInmem[string, V] } -func (c *cacheInmemSharded[K, V]) Get(key K) (V, bool) { +func (c *cacheInmemSharded[V]) Get(key string) (V, bool) { return c.getShard(key).Get(key) } -func (c *cacheInmemSharded[K, V]) Set(key K, value V, ttlSeconds int) { +func (c *cacheInmemSharded[V]) Set(key string, value V, ttlSeconds int) { c.getShard(key).Set(key, value, ttlSeconds) } -func (c *cacheInmemSharded[K, V]) Del(key K) { +func (c *cacheInmemSharded[V]) Del(key string) { c.getShard(key).Del(key) } -func (c *cacheInmemSharded[K, V]) CheckExpired() { +func (c *cacheInmemSharded[V]) CheckExpired() { for _, shard := range c.shards { shard.CheckExpired() } } -func (c *cacheInmemSharded[K, V]) getShard(key K) *cacheInmem[K, V] { +func (c *cacheInmemSharded[V]) getShard(key string) *cacheInmem[string, V] { index := c.hashFunc(key) return c.shards[index] }