Files
goyco/internal/middleware/compression.go

175 lines
3.7 KiB
Go

package middleware
import (
"bytes"
"compress/gzip"
"io"
"net/http"
"slices"
"strings"
)
func CompressionMiddleware() func(http.Handler) http.Handler {
return CompressionMiddlewareWithConfig(nil)
}
// CompressionMiddlewareWithConfig returns middleware that buffers each
// response in memory and sends it gzip-compressed when the client accepts gzip
// and the body is eligible. A nil config falls back to DefaultCompressionConfig.
//
// The response is sent unmodified when:
//   - Accept-Encoding does not list gzip/x-gzip, or rejects it with q=0;
//   - shouldCompress rejects the request;
//   - the handler already set a Content-Encoding on the response;
//   - the buffered body is empty or shorter than config.MinSize;
//   - shouldCompressResponse rejects the response's Content-Type.
//
// Because the whole body is buffered before anything is sent, this middleware
// is not suitable for streaming responses.
func CompressionMiddlewareWithConfig(config *CompressionConfig) func(http.Handler) http.Handler {
	if config == nil {
		config = DefaultCompressionConfig()
	}
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !acceptsGzip(r.Header.Get("Accept-Encoding")) || !shouldCompress(r, config) {
				next.ServeHTTP(w, r)
				return
			}

			var buf bytes.Buffer
			bufferedWriter := &bufferedResponseWriter{
				ResponseWriter: w,
				buffer:         &buf,
			}
			next.ServeHTTP(bufferedWriter, r)

			header := w.Header()
			body := buf.Bytes()
			sendAsIs := func() {
				bufferedWriter.flush()
				w.Write(body)
			}

			// Empty bodies (204/304/HEAD responses) must not carry a gzip
			// stream; tiny bodies are not worth compressing; and a body the
			// handler already encoded must not be compressed a second time.
			if len(body) == 0 || len(body) < config.MinSize || header.Get("Content-Encoding") != "" {
				sendAsIs()
				return
			}

			// net/http does not sniff Content-Type once Content-Encoding is
			// set, so an untyped body would otherwise go out gzip-encoded with
			// no type at all. Sniff it here from the uncompressed bytes. A key
			// that is present but empty is respected: it means the handler
			// deliberately suppressed sniffing.
			contentType := header.Get("Content-Type")
			_, hasType := header["Content-Type"]
			if !hasType {
				contentType = http.DetectContentType(body)
			}
			if !shouldCompressResponse(contentType, config) {
				sendAsIs()
				return
			}
			if !hasType {
				header.Set("Content-Type", contentType)
			}

			header.Set("Content-Encoding", "gzip")
			// Add rather than Set so a Vary set by the handler (e.g. Origin
			// for CORS) is preserved.
			header.Add("Vary", "Accept-Encoding")
			// A handler-supplied length describes the uncompressed body.
			header.Del("Content-Length")
			// The gzip bytes differ from the identity representation, so a
			// strong validator no longer matches them; downgrade it to weak.
			if etag := header.Get("ETag"); etag != "" && !strings.HasPrefix(etag, "W/") {
				header.Set("ETag", "W/"+etag)
			}
			bufferedWriter.flush()

			gz, err := gzip.NewWriterLevel(w, config.Level)
			if err != nil {
				// Invalid configured level: fall back to gzip's default level.
				gz = gzip.NewWriter(w)
			}
			defer gz.Close()
			// Headers are already sent, so a write error (typically a client
			// that went away) cannot be reported to the client; it is dropped.
			_, _ = gz.Write(body)
		})
	}
}

// acceptsGzip reports whether an Accept-Encoding header value lists gzip (or
// its legacy alias x-gzip) without rejecting it through a zero q-value such as
// "gzip;q=0". Coding names are matched case-insensitively.
func acceptsGzip(acceptEncoding string) bool {
	for _, item := range strings.Split(acceptEncoding, ",") {
		coding, params, _ := strings.Cut(item, ";")
		coding = strings.TrimSpace(coding)
		if !strings.EqualFold(coding, "gzip") && !strings.EqualFold(coding, "x-gzip") {
			continue
		}
		if !isZeroQValue(params) {
			return true
		}
	}
	return false
}

// isZeroQValue reports whether the ";"-separated parameters of one
// Accept-Encoding item contain a q-value equal to zero ("q=0", "q=0.000").
func isZeroQValue(params string) bool {
	for _, param := range strings.Split(params, ";") {
		key, value, ok := strings.Cut(param, "=")
		if !ok || !strings.EqualFold(strings.TrimSpace(key), "q") {
			continue
		}
		fraction, found := strings.CutPrefix(strings.TrimSpace(value), "0")
		if !found {
			return false
		}
		// q is zero when every digit after the optional "." is a zero.
		return strings.Trim(strings.TrimPrefix(fraction, "."), "0") == ""
	}
	return false
}
// bufferedResponseWriter captures a handler's response body in memory and
// holds back the status line until flush is called, so the middleware can
// inspect the complete body and adjust headers before anything is sent.
// Header() is promoted from the embedded ResponseWriter, so header edits go
// straight to the real response's header map.
type bufferedResponseWriter struct {
	http.ResponseWriter
	buffer *bytes.Buffer
	// statusCode is the status flush will send. Values below 200 (the zero
	// value, or an informational 1xx code) mean no final status is set yet.
	statusCode int
	// headerWritten is set once flush has sent the status downstream.
	headerWritten bool
}

// Write appends b to the in-memory buffer. As in net/http, writing before a
// final status has been chosen implies 200 OK. A status already set through
// WriteHeader is kept; previously it was reset to 200 here.
func (brw *bufferedResponseWriter) Write(b []byte) (int, error) {
	if brw.statusCode < http.StatusOK {
		brw.statusCode = http.StatusOK
	}
	return brw.buffer.Write(b)
}

// WriteHeader records the status code for flush to send. Matching net/http,
// only the first final (>= 200) status takes effect. Later calls, or calls
// after a Write (which implies 200), are ignored. Informational 1xx codes
// cannot be relayed through a buffer, so a later final code replaces them.
func (brw *bufferedResponseWriter) WriteHeader(code int) {
	if brw.headerWritten || brw.statusCode >= http.StatusOK {
		return
	}
	brw.statusCode = code
}

// flush sends the recorded status to the underlying ResponseWriter at most
// once. It defaults to 200 when the handler never wrote or chose a final
// status; sending 0 would make net/http panic. The buffered body is not
// written here; the caller writes it afterwards.
func (brw *bufferedResponseWriter) flush() {
	if brw.headerWritten {
		return
	}
	if brw.statusCode < http.StatusOK {
		brw.statusCode = http.StatusOK
	}
	brw.ResponseWriter.WriteHeader(brw.statusCode)
	brw.headerWritten = true
}
// shouldCompress reports whether a request is eligible for response
// compression: requests that carry their own Content-Encoding header value are
// skipped. config is accepted for future use and currently ignored.
func shouldCompress(r *http.Request, config *CompressionConfig) bool {
	if encoding := r.Header.Get("Content-Encoding"); encoding != "" {
		return false
	}
	return true
}
// shouldCompressResponse reports whether a response with the given
// Content-Type should be gzip-compressed.
//
// An empty Content-Type is treated as compressible. Otherwise the type,
// compared case-insensitively, must start with one of
// config.CompressibleTypes. It also must not be an image/, video/ or audio/
// type or a known already-compressed archive format, for which gzip would
// only cost CPU. The archive check ignores parameters such as
// "; charset=utf-8".
func shouldCompressResponse(contentType string, config *CompressionConfig) bool {
	if contentType == "" {
		return true
	}
	// Media types are case-insensitive, so compare lowercased forms.
	normalized := strings.ToLower(strings.TrimSpace(contentType))
	// Exact matches must not be defeated by trailing parameters.
	mediaType, _, _ := strings.Cut(normalized, ";")
	mediaType = strings.TrimSpace(mediaType)

	listed := slices.ContainsFunc(config.CompressibleTypes, func(prefix string) bool {
		return strings.HasPrefix(normalized, strings.ToLower(prefix))
	})
	if !listed {
		return false
	}

	for _, prefix := range []string{"image/", "video/", "audio/"} {
		if strings.HasPrefix(normalized, prefix) {
			return false
		}
	}

	alreadyCompressed := []string{
		"application/zip",
		"application/gzip",
		"application/x-gzip", // what http.DetectContentType reports for gzip data
		"application/x-bzip2",
		"application/x-xz",
		"application/x-7z-compressed",
		"application/x-rar-compressed",
		"application/zstd",
	}
	return !slices.Contains(alreadyCompressed, mediaType)
}
// DecompressionMiddleware returns middleware that transparently decodes
// gzip-encoded request bodies.
//
// When the request's Content-Encoding is gzip or x-gzip (case-insensitive),
// r.Body is replaced with a decompressing reader. The Content-Encoding and
// Content-Length headers are removed because they describe the compressed
// payload, and r.ContentLength is set to -1 (unknown). A body that does not
// start with a valid gzip header is rejected with 400 Bad Request. All other
// requests pass through unchanged.
//
// NOTE: the decompressed size is not bounded here. Handlers that accept
// untrusted input should wrap r.Body in http.MaxBytesReader. Because that
// wraps the decompressed stream, it also limits how far a gzip bomb can
// expand.
func DecompressionMiddleware() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			encoding := strings.TrimSpace(r.Header.Get("Content-Encoding"))
			if strings.EqualFold(encoding, "gzip") || strings.EqualFold(encoding, "x-gzip") {
				gz, err := gzip.NewReader(r.Body)
				if err != nil {
					http.Error(w, "Invalid gzip encoding", http.StatusBadRequest)
					return
				}
				defer gz.Close()
				r.Body = io.NopCloser(gz)
				r.Header.Del("Content-Encoding")
				// The advertised length was that of the compressed bytes.
				r.Header.Del("Content-Length")
				r.ContentLength = -1
			}
			next.ServeHTTP(w, r)
		})
	}
}
// CompressionConfig holds the settings used by CompressionMiddlewareWithConfig.
type CompressionConfig struct {
// Level is the gzip level passed to gzip.NewWriterLevel (for example
// gzip.DefaultCompression). If NewWriterLevel rejects it, the middleware
// falls back to gzip.NewWriter's default level.
Level int
// MinSize is the smallest buffered response body, in bytes, that is
// compressed; shorter bodies are sent as-is.
MinSize int
// CompressibleTypes lists Content-Type prefixes (e.g. "text/") eligible for
// compression. image/, video/ and audio/ types are excluded regardless.
CompressibleTypes []string
}
// DefaultCompressionConfig returns a newly allocated CompressionConfig with
// gzip's default level, no minimum body size, and prefixes covering text/ and
// application/ media types. The specific application/ entries are subsumed by
// the trailing "application/" prefix but are listed explicitly.
func DefaultCompressionConfig() *CompressionConfig {
	prefixes := []string{
		"text/",
		"application/json",
		"application/xml",
		"application/javascript",
		"application/css",
		"application/",
	}
	cfg := new(CompressionConfig)
	cfg.Level = gzip.DefaultCompression
	cfg.MinSize = 0
	cfg.CompressibleTypes = prefixes
	return cfg
}