Skip to content

Commit

Permalink
Update caddywaf.go
Browse files Browse the repository at this point in the history
body flush and headers fix
  • Loading branch information
fabriziosalmi authored Jan 9, 2025
1 parent 928c976 commit c4fbf8c
Showing 1 changed file with 32 additions and 96 deletions.
128 changes: 32 additions & 96 deletions caddywaf.go
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,6 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
ctx := context.WithValue(r.Context(), "logID", logID)
r = r.WithContext(ctx)

// Example within your ServeHTTP method
m.logRequest(zapcore.DebugLevel, "Entering ServeHTTP", zap.String("path", r.URL.Path))

state := &WAFState{
Expand Down Expand Up @@ -769,137 +768,74 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd
return nil
}

// Whitelist Check
// Phase 1: Country Whitelist and Blacklist Check
if m.CountryWhitelist.Enabled {
whitelisted, err := m.isCountryInList(r.RemoteAddr, m.CountryWhitelist.CountryList, m.CountryWhitelist.geoIP)
if err != nil {
m.logRequest(zapcore.ErrorLevel, "Failed to check whitelist",
zap.String("log_id", logID),
zap.String("ip", r.RemoteAddr),
zap.Error(err),
)
// Consider blocking or allowing based on your policy if whitelist check fails
// For now, proceeding to blacklist as if not whitelisted
m.logRequest(zapcore.ErrorLevel, "Failed to check whitelist", zap.String("log_id", logID), zap.Error(err))
} else if whitelisted {
m.logRequest(zapcore.InfoLevel, "Request allowed - country whitelisted",
zap.String("log_id", logID),
zap.String("country", m.getCountryCode(r.RemoteAddr, m.CountryWhitelist.geoIP)),
)
return next.ServeHTTP(w, r) // Allow immediately
m.logRequest(zapcore.InfoLevel, "Request allowed - country whitelisted", zap.String("log_id", logID))
return next.ServeHTTP(w, r)
}
}

// Blacklist Check (only if not whitelisted or whitelist disabled)
if m.CountryBlock.Enabled {
blacklisted, err := m.isCountryInList(r.RemoteAddr, m.CountryBlock.CountryList, m.CountryBlock.geoIP)
if err != nil {
m.logRequest(zapcore.ErrorLevel, "Failed to check blacklist",
zap.String("log_id", logID),
zap.String("ip", r.RemoteAddr),
zap.Error(err),
)
return block(http.StatusInternalServerError, "blacklist_check_error", zap.String("message", "Internal error during blacklist check"))
m.logRequest(zapcore.ErrorLevel, "Failed to check blacklist", zap.String("log_id", logID), zap.Error(err))
return block(http.StatusInternalServerError, "blacklist_check_error")
} else if blacklisted {
m.logRequest(zapcore.WarnLevel, "Request blocked - country blacklisted",
zap.String("log_id", logID),
zap.String("country", m.getCountryCode(r.RemoteAddr, m.CountryBlock.geoIP)),
)
return block(http.StatusForbidden, "country_blacklist", zap.String("message", "Request blocked by country blacklist"))
m.logRequest(zapcore.WarnLevel, "Request blocked - country blacklisted", zap.String("log_id", logID))
return block(http.StatusForbidden, "country_blacklist")
}
}

// Phase 1 - Request Headers
m.logger.Debug("Executing Phase 1 (Request Headers)",
zap.String("log_id", logID),
)
// Phase 2: Request Body Handling
m.handlePhase(w, r, 1, state)
if state.Blocked && !state.ResponseWritten {
m.logRequest(zapcore.WarnLevel, "Request blocked in Phase 1",
zap.String("log_id", logID),
zap.Int("status_code", state.StatusCode),
zap.String("reason", "phase_1_block"),
)
if state.Blocked {
w.WriteHeader(state.StatusCode)
return nil
}

// Phase 2 - Request Body
if !state.ResponseWritten {
m.logger.Debug("Executing Phase 2 (Request Body)",
zap.String("log_id", logID),
)
m.handlePhase(w, r, 2, state)
if state.Blocked && !state.ResponseWritten {
m.logRequest(zapcore.WarnLevel, "Request blocked in Phase 2",
zap.String("log_id", logID),
zap.Int("status_code", state.StatusCode),
zap.String("reason", "phase_2_block"),
)
w.WriteHeader(state.StatusCode)
return nil
}
m.handlePhase(w, r, 2, state)
if state.Blocked {
w.WriteHeader(state.StatusCode)
return nil
}

// Response Recorder for Phase 3 and 4
// Set up response recorder for Phase 3 and Phase 4
recorder := &responseRecorder{ResponseWriter: w, body: new(bytes.Buffer)}
w = recorder
err := next.ServeHTTP(recorder, r)

// Pass request to upstream handler
err := next.ServeHTTP(w, r)

// Phase 3 - Response Headers
m.logger.Debug("Executing Phase 3 (Response Headers)",
zap.String("log_id", logID),
)
// Phase 3: Response Header Rules
m.handlePhase(recorder, r, 3, state)
if state.Blocked && !state.ResponseWritten {
m.logRequest(zapcore.WarnLevel, "Request blocked in Phase 3",
zap.String("log_id", logID),
zap.Int("status_code", state.StatusCode),
zap.String("reason", "phase_3_block"),
)
w.WriteHeader(state.StatusCode)
if state.Blocked {
recorder.WriteHeader(state.StatusCode)
return nil
}

// Phase 4 - Response Body (after response is written)
m.logger.Debug("Executing Phase 4 (Response Body)",
zap.String("log_id", logID),
zap.Int("response_length", recorder.body.Len()), // Use recorder.body.Len()
)

// Check if recorder.body is nil before accessing it
// Phase 4: Response Body Rules
if recorder.body != nil {
body := recorder.body.String()
m.logger.Debug("Phase 4 Response Body", zap.String("response_body", body))
m.logger.Debug("Response body captured", zap.String("body", body))

for _, rule := range m.Rules[4] {
m.logger.Debug("Checking rule", zap.String("rule_id", rule.ID), zap.String("pattern", rule.Pattern), zap.String("description", rule.Description))
if rule.regex.MatchString(body) {
m.processRuleMatch(recorder, r, &rule, body, state)
if state.Blocked && !state.ResponseWritten {
m.logRequest(zapcore.WarnLevel, "Request blocked in Phase 4 (Response Body)",
zap.String("log_id", logID),
zap.String("rule_id", rule.ID),
zap.String("description", rule.Description),
zap.Int("status_code", state.StatusCode),
zap.String("reason", "phase_4_block"),
)
recorder.WriteHeader(state.StatusCode) // Use recorder to ensureWriteHeader is called only once
if state.Blocked {
recorder.WriteHeader(state.StatusCode)
return nil
}
}
}

// If not blocked, flush the recorded body to the client
if !state.Blocked && !state.ResponseWritten {
_, err = w.Write(recorder.body.Bytes())
if err != nil {
m.logger.Error("Failed to flush response body", zap.Error(err))
// If not blocked, write the body to the client
if !state.ResponseWritten {
_, writeErr := w.Write(recorder.body.Bytes())
if writeErr != nil {
m.logger.Error("Failed to write response body", zap.Error(writeErr))
}
}
} else {
m.logger.Debug("Phase 4: Response body is nil, skipping rule evaluation")
}

m.logger.Info("WAF evaluation complete",
Expand Down Expand Up @@ -1053,11 +989,11 @@ func (r *responseRecorder) StatusCode() int {
// Write captures the response body and writes to the buffer only
func (r *responseRecorder) Write(b []byte) (int, error) {
if r.statusCode == 0 {
// Ensure status code is set if WriteHeader wasn't called
r.WriteHeader(http.StatusOK)
r.WriteHeader(http.StatusOK) // Default to 200 if not set
}
// Only write to the buffer, not the underlying ResponseWriter
return r.body.Write(b)
n, err := r.body.Write(b)
log.Printf("[DEBUG] Recorder Body Written: %d bytes, Error: %v", n, err)
return n, err
}

func (m *Middleware) processRuleMatch(w http.ResponseWriter, r *http.Request, rule *Rule, value string, state *WAFState) {
Expand Down

0 comments on commit c4fbf8c

Please sign in to comment.