From c4235c3860f8e3d299114b2123a03bfcfd77b73f Mon Sep 17 00:00:00 2001 From: Sergey Chubaryan Date: Wed, 21 Aug 2024 19:58:14 +0300 Subject: [PATCH] added contextual logger --- src/logger/logger.go | 63 +++++++++++-- src/server/middleware/recovery.go | 150 ++++++++++++++++++++++++++++++ src/server/server.go | 2 +- 3 files changed, 206 insertions(+), 9 deletions(-) create mode 100644 src/server/middleware/recovery.go diff --git a/src/logger/logger.go b/src/logger/logger.go index 633f6e8..97e9c5a 100644 --- a/src/logger/logger.go +++ b/src/logger/logger.go @@ -1,30 +1,77 @@ package logger -import "github.com/rs/zerolog" +import ( + "backend/src/request_context" + "context" + + "github.com/rs/zerolog" +) type Logger interface { Log() Event Warning() Event Error() Event Fatal() Event + + Printf(format string, v ...any) } type logger struct { + prefix string + requestCtx request_context.RequestContext zeroLogger *zerolog.Logger } func (l logger) Log() Event { - return event{l.zeroLogger.Log()} + return l.wrapEvent(l.zeroLogger.Log()) } -func (l *logger) Warning() Event { - return event{l.zeroLogger.Warn()} +func (l logger) Warning() Event { + return l.wrapEvent(l.zeroLogger.Warn()) } -func (l *logger) Error() Event { - return event{l.zeroLogger.Error()} +func (l logger) Error() Event { + return l.wrapEvent(l.zeroLogger.Error()) } -func (l *logger) Fatal() Event { - return event{l.zeroLogger.Fatal()} +func (l logger) Fatal() Event { + return l.wrapEvent(l.zeroLogger.Fatal()) +} + +func (l logger) Printf(format string, v ...any) { + l.zeroLogger.Printf(format, v...) +} + +func (l logger) wrapEvent(zerologEvent *zerolog.Event) Event { + var e Event = event{zerologEvent} + + if l.requestCtx != nil { + e = e.Str("requestId", l.requestCtx.RequestId()) + e = e.Str("userId", l.requestCtx.UserId()) + if l.prefix != "" { + e = e.Str("prefix", l.prefix) + } + } + + return e +} + +func (l logger) WithContext(ctx context.Context) Logger { + if rctx, ok := ctx.(request_context.RequestContext); ok { + return logger{ + prefix: l.prefix, + requestCtx: rctx, + zeroLogger: l.zeroLogger, + } + } + + return l +} + +func (l logger) WithPrefix(prefix string) Logger { + return logger{ + prefix: prefix, + requestCtx: l.requestCtx, + zeroLogger: l.zeroLogger, + } } diff --git a/src/server/middleware/recovery.go b/src/server/middleware/recovery.go new file mode 100644 index 0000000..5957e94 --- /dev/null +++ b/src/server/middleware/recovery.go @@ -0,0 +1,150 @@ +package middleware + +import ( + "backend/src/logger" + "bytes" + "errors" + "fmt" + "net" + "net/http" + "net/http/httputil" + "os" + "runtime" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +const ( + reset = "\033[0m" +) + +var ( + dunno = []byte("???") + centerDot = []byte("·") + dot = []byte(".") + slash = []byte("/") +) + +func NewRecoveryMiddleware(logger logger.Logger, debugMode bool) gin.HandlerFunc { + handle := defaultHandleRecovery + return func(c *gin.Context) { + defer func() { + if err := recover(); err != nil { + // Check for a broken connection, as it is not really a + // condition that warrants a panic stack trace. + var brokenPipe bool + if ne, ok := err.(*net.OpError); ok { + var se *os.SyscallError + if errors.As(ne, &se) { + seStr := strings.ToLower(se.Error()) + if strings.Contains(seStr, "broken pipe") || + strings.Contains(seStr, "connection reset by peer") { + brokenPipe = true + } + } + } + if logger != nil { + stack := stack(3) + httpRequest, _ := httputil.DumpRequest(c.Request, false) + headers := strings.Split(string(httpRequest), "\r\n") + for idx, header := range headers { + current := strings.Split(header, ":") + if current[0] == "Authorization" { + headers[idx] = current[0] + ": *" + } + } + headersToStr := strings.Join(headers, "\r\n") + if brokenPipe { + logger.Printf("%s\n%s%s", err, headersToStr, reset) + } else if debugMode { + logger.Printf("[Recovery] %s panic recovered:\n%s\n%s\n%s%s", + timeFormat(time.Now()), headersToStr, err, stack, reset) + } else { + logger.Printf("[Recovery] %s panic recovered:\n%s\n%s%s", + timeFormat(time.Now()), err, stack, reset) + } + } + if brokenPipe { + // If the connection is dead, we can't write a status to it. + c.Error(err.(error)) //nolint: errcheck + c.Abort() + } else { + handle(c, err) + } + } + }() + c.Next() + } +} + +func defaultHandleRecovery(c *gin.Context, _ any) { + c.AbortWithStatus(http.StatusInternalServerError) +} + +// stack returns a nicely formatted stack frame, skipping skip frames. +func stack(skip int) []byte { + buf := new(bytes.Buffer) // the returned data + // As we loop, we open files and read them. These variables record the currently + // loaded file. + var lines [][]byte + var lastFile string + for i := skip; ; i++ { // Skip the expected number of frames + pc, file, line, ok := runtime.Caller(i) + if !ok { + break + } + // Print this much at least. If we can't find the source, it won't show. + fmt.Fprintf(buf, "%s:%d (0x%x)\n", file, line, pc) + if file != lastFile { + data, err := os.ReadFile(file) + if err != nil { + continue + } + lines = bytes.Split(data, []byte{'\n'}) + lastFile = file + } + fmt.Fprintf(buf, "\t%s: %s\n", function(pc), source(lines, line)) + } + return buf.Bytes() +} + +// source returns a space-trimmed slice of the n'th line. +func source(lines [][]byte, n int) []byte { + n-- // in stack trace, lines are 1-indexed but our array is 0-indexed + if n < 0 || n >= len(lines) { + return dunno + } + return bytes.TrimSpace(lines[n]) +} + +// function returns, if possible, the name of the function containing the PC. +func function(pc uintptr) []byte { + fn := runtime.FuncForPC(pc) + if fn == nil { + return dunno + } + name := []byte(fn.Name()) + // The name includes the path name to the package, which is unnecessary + // since the file name is already included. Plus, it has center dots. + // That is, we see + // runtime/debug.*T·ptrmethod + // and want + // *T.ptrmethod + // Also the package path might contain dot (e.g. code.google.com/...), + // so first eliminate the path prefix + if lastSlash := bytes.LastIndex(name, slash); lastSlash >= 0 { + name = name[lastSlash+1:] + } + if period := bytes.Index(name, dot); period >= 0 { + name = name[period+1:] + } + name = bytes.ReplaceAll(name, centerDot, dot) + return name +} + +// timeFormat returns a customized time string for logger. +func timeFormat(t time.Time) string { + return t.Format("2006/01/02 - 15:04:05") +} diff --git a/src/server/server.go b/src/server/server.go index 189a2fc..cfd5eaf 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -33,7 +33,7 @@ func New(opts NewServerOpts) *Server { r := gin.New() r.Use(middleware.NewRequestLogMiddleware(opts.Logger)) - r.Use(gin.Recovery()) + r.Use(middleware.NewRecoveryMiddleware(opts.Logger, opts.DebugMode)) r.Static("/webapp", "./webapp")