diff options
| -rw-r--r-- | nonsense-time.go | 49 |
1 files changed, 35 insertions, 14 deletions
diff --git a/nonsense-time.go b/nonsense-time.go index 1d4e218..4bb1067 100644 --- a/nonsense-time.go +++ b/nonsense-time.go @@ -9,6 +9,7 @@ import ( "net/http" "nonsense-time/api" "os" + "strings" "time" ) @@ -25,30 +26,49 @@ func timeoutMiddleware(next http.Handler, duration time.Duration) http.Handler { } type gzipResponseWriter struct { - gzw gzip.Writer - hrw http.ResponseWriter + gzip.Writer + http.ResponseWriter } -func (gzrw gzipResponseWriter) Header() http.Header { - headers := gzrw.hrw.Header() - // TODO: add gzip header - return headers +func newGzipResponseWriter(w http.ResponseWriter) *gzipResponseWriter { + gzrw := new(gzipResponseWriter) + gzrw.Writer = *gzip.NewWriter(w) + gzrw.ResponseWriter = w + + return gzrw +} + +func (gzrw *gzipResponseWriter) Close() error { + return gzrw.Writer.Close() } -func (gzrw gzipResponseWriter) Write(p []byte) (int, error) { - // TODO: write data using gzrw.gzw.Write - return gzrw.hrw.Write(p) +func (gzrw *gzipResponseWriter) Header() http.Header { + return gzrw.ResponseWriter.Header() } -func (gzrw gzipResponseWriter) WriteHeader(statuscode int) { - gzrw.hrw.WriteHeader(statuscode) +func (gzrw *gzipResponseWriter) Write(p []byte) (int, error) { + n, err := gzrw.Writer.Write(p) + return n, err +} + +func (gzrw *gzipResponseWriter) WriteHeader(statuscode int) { + gzrw.ResponseWriter.WriteHeader(statuscode) } // Middleware to conditionally gzip a response func gzipMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // TODO: check for support gzip header - gzrw := gzipResponseWriter{hrw: w} + acceptedEncodings := r.Header.Get("Accept-Encoding") + + if !strings.Contains(acceptedEncodings, "gzip") { + logger.DebugContext(r.Context(), "Using gzip middleware but client does not support gzip encoding") + next.ServeHTTP(w, r) + return + } + + w.Header().Add("Content-Encoding", "gzip") + gzrw := newGzipResponseWriter(w) + defer gzrw.Close() next.ServeHTTP(gzrw, r) }) @@ -92,10 +112,11 @@ func main() { api.Logger = logger vtt := timeoutMiddleware(http.HandlerFunc(api.VttOnline), *waitTime) + vttLogs := gzipMiddleware(http.HandlerFunc(api.VttLogs)) mux.Handle("GET /vtt/status", vtt) mux.HandleFunc("GET /vtt", api.VttRedirect) - mux.HandleFunc("GET /vtt/logs", api.VttLogs) + mux.Handle("GET /vtt/logs", vttLogs) logger.Info(fmt.Sprint("Listening on ", addr)) logger.Info(http.ListenAndServe(addr, mux).Error()) |
