From ed1f8b8c3f042053f7711924f81fccbd0f0f9613 Mon Sep 17 00:00:00 2001 From: Sergey Chubaryan Date: Wed, 28 Aug 2024 00:58:19 +0300 Subject: [PATCH] cache sharding, valid jwt caching, add metrics --- src/app.go | 19 +++++++++- src/core/repos/cache_inmem_shard.go | 55 ++++++++++++++++++++++++++++ src/core/services/user_service.go | 7 ++++ src/core/utils/jwt.go | 18 ++++----- src/core/utils/password.go | 2 +- src/integrations/prometheus.go | 48 ++++++++++++++++-------- src/logger/bufio_wrapper.go | 6 +-- src/server/middleware/request_log.go | 16 ++++++-- 8 files changed, 137 insertions(+), 34 deletions(-) create mode 100644 src/core/repos/cache_inmem_shard.go diff --git a/src/app.go b/src/app.go index cda51d0..2a98a01 100644 --- a/src/app.go +++ b/src/app.go @@ -114,8 +114,21 @@ func (a *App) Run(p RunParams) { userRepo = repos.NewUserRepo(sqlDb) emailRepo = repos.NewEmailRepo() actionTokenRepo = repos.NewActionTokenRepo(sqlDb) - linksCache = repos.NewCacheInmem[string, string](7 * 24 * 60 * 60) - userCache = repos.NewCacheInmem[string, models.UserDTO](60 * 60) + jwtCache = repos.NewCacheInmemSharded[string, string](60, 36, func(key string) int { + char := int(key[len(key)-1]) + if char >= 0x30 && char <= 0x39 { + return char - 0x30 + } + if char >= 0x41 && char <= 0x5A { + return char - 0x41 + } + return char - 0x61 + }) //repos.NewCacheInmem[string, string](60) + linksCache = repos.NewCacheInmem[string, string](7 * 24 * 60 * 60) + userCache = repos.NewCacheInmemSharded[string, models.UserDTO](60*60, 10, func(key string) int { + char := int(key[len(key)-1]) + return char - 0x30 + }) //repos.NewCacheInmem[string, models.UserDTO](60 * 60) ) // Periodically trigger cache cleanup @@ -128,6 +141,7 @@ func (a *App) Run(p RunParams) { case <-ctx.Done(): return case <-tmr.C: + jwtCache.CheckExpired() userCache.CheckExpired() linksCache.CheckExpired() } @@ -140,6 +154,7 @@ func (a *App) Run(p RunParams) { Password: passwordUtil, UserRepo: userRepo, UserCache: userCache, + JwtCache: jwtCache, EmailRepo: emailRepo, ActionTokenRepo: actionTokenRepo, }, diff --git a/src/core/repos/cache_inmem_shard.go b/src/core/repos/cache_inmem_shard.go new file mode 100644 index 0000000..457f864 --- /dev/null +++ b/src/core/repos/cache_inmem_shard.go @@ -0,0 +1,55 @@ +package repos + +import ( + "sync" +) + +func NewCacheInmemSharded[K comparable, V any]( + ttlSeconds, shards int, + hashFunc func(key K) int, +) Cache[K, V] { + inmems := []*cacheInmem[K, V]{} + for i := 0; i < shards; i++ { + inmems = append( + inmems, + &cacheInmem[K, V]{ + m: &sync.Mutex{}, + data: map[K]*cacheInmemItem[V]{}, + ttlSeconds: ttlSeconds, + }, + ) + } + + return &cacheInmemSharded[K, V]{ + shards: inmems, + hashFunc: hashFunc, + } +} + +type cacheInmemSharded[K comparable, V any] struct { + hashFunc func(key K) int + shards []*cacheInmem[K, V] +} + +func (c *cacheInmemSharded[K, V]) Get(key K) (V, bool) { + return c.getShard(key).Get(key) +} + +func (c *cacheInmemSharded[K, V]) Set(key K, value V, ttlSeconds int) { + c.getShard(key).Set(key, value, ttlSeconds) +} + +func (c *cacheInmemSharded[K, V]) Del(key K) { + c.getShard(key).Del(key) +} + +func (c *cacheInmemSharded[K, V]) CheckExpired() { + for _, shard := range c.shards { + shard.CheckExpired() + } +} + +func (c *cacheInmemSharded[K, V]) getShard(key K) *cacheInmem[K, V] { + index := c.hashFunc(key) + return c.shards[index] +} diff --git a/src/core/services/user_service.go b/src/core/services/user_service.go index 8d8a3da..d841959 100644 --- a/src/core/services/user_service.go +++ b/src/core/services/user_service.go @@ -34,6 +34,7 @@ type UserServiceDeps struct { Password utils.PasswordUtil UserRepo repos.UserRepo UserCache repos.Cache[string, models.UserDTO] + JwtCache repos.Cache[string, string] EmailRepo repos.EmailRepo ActionTokenRepo repos.ActionTokenRepo } @@ -195,6 +196,10 @@ func (u *userService) getUserById(ctx context.Context, userId string) (*models.U } func (u *userService) ValidateToken(ctx context.Context, tokenStr string) (*models.UserDTO, error) { + if userId, ok := u.deps.JwtCache.Get(tokenStr); ok { + return u.getUserById(ctx, userId) + } + payload, err := u.deps.Jwt.Parse(tokenStr) if err != nil { return nil, ErrUserWrongToken @@ -205,5 +210,7 @@ func (u *userService) ValidateToken(ctx context.Context, tokenStr string) (*mode return nil, err } + u.deps.JwtCache.Set(tokenStr, payload.UserId, -1) + return user, nil } diff --git a/src/core/utils/jwt.go b/src/core/utils/jwt.go index 1cafcbb..89d9bed 100644 --- a/src/core/utils/jwt.go +++ b/src/core/utils/jwt.go @@ -11,14 +11,14 @@ type JwtPayload struct { UserId string `json:"userId"` } -type jwtClaims struct { +type JwtClaims struct { jwt.RegisteredClaims JwtPayload } type JwtUtil interface { Create(payload JwtPayload) (string, error) - Parse(tokenStr string) (JwtPayload, error) + Parse(tokenStr string) (JwtClaims, error) } func NewJwtUtil(privateKey *rsa.PrivateKey) JwtUtil { @@ -32,7 +32,7 @@ type jwtUtil struct { } func (j *jwtUtil) Create(payload JwtPayload) (string, error) { - claims := &jwtClaims{JwtPayload: payload} + claims := &JwtClaims{JwtPayload: payload} token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) tokenStr, err := token.SignedString(j.privateKey) if err != nil { @@ -41,17 +41,17 @@ func (j *jwtUtil) Create(payload JwtPayload) (string, error) { return tokenStr, nil } -func (j *jwtUtil) Parse(tokenStr string) (JwtPayload, error) { - token, err := jwt.ParseWithClaims(tokenStr, &jwtClaims{}, func(t *jwt.Token) (interface{}, error) { +func (j *jwtUtil) Parse(tokenStr string) (JwtClaims, error) { + token, err := jwt.ParseWithClaims(tokenStr, &JwtClaims{}, func(t *jwt.Token) (interface{}, error) { return &j.privateKey.PublicKey, nil }) if err != nil { - return JwtPayload{}, err + return JwtClaims{}, err } - if claims, ok := token.Claims.(*jwtClaims); ok { - return claims.JwtPayload, nil + if claims, ok := token.Claims.(*JwtClaims); ok { + return *claims, nil } - return JwtPayload{}, fmt.Errorf("cant get payload") + return JwtClaims{}, fmt.Errorf("cant get payload") } diff --git a/src/core/utils/password.go b/src/core/utils/password.go index c237312..aa26fd4 100644 --- a/src/core/utils/password.go +++ b/src/core/utils/password.go @@ -25,7 +25,7 @@ type passwordUtil struct { } func (b *passwordUtil) Hash(password string) (string, error) { - bytes, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + bytes, _ := bcrypt.GenerateFromPassword([]byte(password), 8) //bcrypt.DefaultCost) return string(bytes), nil } diff --git a/src/integrations/prometheus.go b/src/integrations/prometheus.go index 5703d2d..9fecfba 100644 --- a/src/integrations/prometheus.go +++ b/src/integrations/prometheus.go @@ -9,10 +9,12 @@ import ( ) type Prometheus struct { - reg *prometheus.Registry - rpsCounter prometheus.Counter - avgReqTimeHist prometheus.Histogram - panicsHist prometheus.Histogram + reg *prometheus.Registry + rpsCounter prometheus.Counter + avgReqTimeHist prometheus.Histogram + panicsHist prometheus.Histogram + errors4xxCounter prometheus.Counter + errors5xxCounter prometheus.Counter } func NewPrometheus() *Prometheus { @@ -24,12 +26,18 @@ func NewPrometheus() *Prometheus { collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}), ) - // errorsCounter := prometheus.NewCounter( - // prometheus.CounterOpts{ - // Name: "backend_errors_count", - // Help: "Summary errors count", - // }, - // ) + errors5xxCounter := prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "backend_errors_count_5xx", + Help: "5xx errors count", + }, + ) + errors4xxCounter := prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "backend_errors_count_4xx", + Help: "4xx errors count", + }, + ) rpsCounter := prometheus.NewCounter( prometheus.CounterOpts{ Name: "backend_requests_per_second", @@ -48,13 +56,15 @@ func NewPrometheus() *Prometheus { Help: "Panics histogram metric", }, ) - reg.MustRegister(rpsCounter, avgReqTimeHist, panicsHist) + reg.MustRegister(rpsCounter, avgReqTimeHist, panicsHist, errors4xxCounter, errors5xxCounter) return &Prometheus{ - panicsHist: panicsHist, - avgReqTimeHist: avgReqTimeHist, - rpsCounter: rpsCounter, - reg: reg, + panicsHist: panicsHist, + avgReqTimeHist: avgReqTimeHist, + rpsCounter: rpsCounter, + errors4xxCounter: errors4xxCounter, + errors5xxCounter: errors5xxCounter, + reg: reg, } } @@ -77,3 +87,11 @@ func (p *Prometheus) AddRequestTime(reqTime float64) { func (p *Prometheus) AddPanic() { p.panicsHist.Observe(1) } + +func (p *Prometheus) Add4xxError() { + p.errors4xxCounter.Inc() +} + +func (p *Prometheus) Add5xxError() { + p.errors5xxCounter.Inc() +} diff --git a/src/logger/bufio_wrapper.go b/src/logger/bufio_wrapper.go index feca55d..118bfd6 100644 --- a/src/logger/bufio_wrapper.go +++ b/src/logger/bufio_wrapper.go @@ -21,7 +21,7 @@ func newWrapper(writer io.Writer) *bufioWrapper { ticker.Stop() return &bufioWrapper{ - writer: bufio.NewWriterSize(writer, 128*1024), + writer: bufio.NewWriterSize(writer, 512*1024), mutex: &sync.RWMutex{}, ticker: ticker, } @@ -47,8 +47,8 @@ func (b *bufioWrapper) FlushRoutine(ctx context.Context) { func (b *bufioWrapper) Write(p []byte) (nn int, err error) { // TODO: try replace mutex, improve logging perfomance - b.mutex.RLock() - defer b.mutex.RUnlock() + b.mutex.Lock() + defer b.mutex.Unlock() if len(p) > b.writer.Available() { b.ticker.Reset(FlushInterval) diff --git a/src/server/middleware/request_log.go b/src/server/middleware/request_log.go index bb5f1dd..d06d9e8 100644 --- a/src/server/middleware/request_log.go +++ b/src/server/middleware/request_log.go @@ -34,12 +34,20 @@ func NewRequestLogMiddleware(logger log.Logger, prometheus *integrations.Prometh method := c.Request.Method statusCode := c.Writer.Status() - clientIP := c.ClientIP() + + if statusCode >= 200 && statusCode < 400 { + return + } ctxLogger := logger.WithContext(c) - e := ctxLogger.Log() - e.Str("ip", clientIP) - e.Msgf("Request %s %s %d %v", method, path, statusCode, latency) + if statusCode >= 400 && statusCode < 500 { + prometheus.Add4xxError() + ctxLogger.Warning().Msgf("Request %s %s %d %v", method, path, statusCode, latency) + return + } + + prometheus.Add5xxError() + ctxLogger.Error().Msgf("Request %s %s %d %v", method, path, statusCode, latency) } }