Compare commits

..

19 Commits

Author SHA1 Message Date
e58ba1b8d1 chore: add title 2026-03-18 18:07:15 +01:00
4ffc601723 fix: avoid mangle backslash 2026-03-11 07:22:51 +01:00
d6321e775a test(integration): update DB monitoring health assertion to match nested services payload 2026-03-06 15:37:53 +01:00
de9b544afb refactor(cors): deduplicate origin validation and header logic without behavior change 2026-03-06 15:37:44 +01:00
19291b7f61 feat: update swagger 2026-03-05 11:39:24 +01:00
c31eb2f3df test(e2e): make middleware tests assertion-driven and deterministic 2026-02-23 07:11:22 +01:00
de08878de7 test(e2e): add middleware-enabled test context and server config toggles 2026-02-23 07:11:17 +01:00
f0e8da51d0 feat(server): allow cacheable paths to be configured in router 2026-02-23 07:11:14 +01:00
85882bae14 refactor: go fix ftw 2026-02-19 17:37:42 +01:00
9185ffa6b5 test(server): mock title fetcher in router tests to remove network dependency 2026-02-19 17:37:31 +01:00
986b4e9388 refactor: modernize code using go fix 2026-02-19 17:31:06 +01:00
ac6e1ba80b refactor: modern code using go fix 2026-02-19 17:30:12 +01:00
14da02bc3f refactor: use go fix 2026-02-19 17:29:44 +01:00
31ef30c941 test(health): expect unhealthy for SMTP connection failures 2026-02-16 08:43:46 +01:00
d4a89325e0 fix(health): mark SMTP connection/bootstrap failures as unhealthy 2026-02-16 08:43:33 +01:00
4eb0a6360f test(health): cover SMTP unhealthy aggregation behavior 2026-02-16 08:43:14 +01:00
040b9148de fix(health): treat SMTP unhealthy as degraded at app level 2026-02-16 08:43:01 +01:00
6e0dfabcff feat: health check now return json, definitely 2026-02-16 08:33:51 +01:00
9e81ddfdfa fix: don't reinvent the wheel 2026-02-15 12:05:25 +01:00
53 changed files with 453 additions and 556 deletions

View File

@@ -18,7 +18,7 @@ import (
var ( var (
helpPrinterOnce sync.Once helpPrinterOnce sync.Once
defaultHelpPrinter func(io.Writer, string, interface{}) defaultHelpPrinter func(io.Writer, string, any)
) )
func loadDotEnv() { func loadDotEnv() {
@@ -71,7 +71,7 @@ func buildRootCommand(cfg *config.Config) *cli.Command {
helpPrinterOnce.Do(func() { helpPrinterOnce.Do(func() {
defaultHelpPrinter = cli.HelpPrinter defaultHelpPrinter = cli.HelpPrinter
}) })
cli.HelpPrinter = func(w io.Writer, templ string, data interface{}) { cli.HelpPrinter = func(w io.Writer, templ string, data any) {
if cmd, ok := data.(*cli.Command); ok && cmd.Root() == cmd { if cmd, ok := data.(*cli.Command); ok && cmd.Root() == cmd {
printRootUsage() printRootUsage()
return return

View File

@@ -113,15 +113,15 @@ func truncate(in string, max int) string {
return in[:max-3] + "..." return in[:max-3] + "..."
} }
func outputJSON(v interface{}) error { func outputJSON(v any) error {
encoder := json.NewEncoder(os.Stdout) encoder := json.NewEncoder(os.Stdout)
encoder.SetIndent("", " ") encoder.SetIndent("", " ")
return encoder.Encode(v) return encoder.Encode(v)
} }
func outputWarning(message string, args ...interface{}) { func outputWarning(message string, args ...any) {
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"warning": fmt.Sprintf(message, args...), "warning": fmt.Sprintf(message, args...),
}) })
} else { } else {

View File

@@ -236,7 +236,7 @@ func TestSetJSONOutput(t *testing.T) {
done := make(chan bool) done := make(chan bool)
go func() { go func() {
for i := 0; i < 100; i++ { for range 100 {
SetJSONOutput(true) SetJSONOutput(true)
_ = IsJSONOutput() _ = IsJSONOutput()
SetJSONOutput(false) SetJSONOutput(false)
@@ -245,7 +245,7 @@ func TestSetJSONOutput(t *testing.T) {
}() }()
go func() { go func() {
for i := 0; i < 100; i++ { for range 100 {
_ = IsJSONOutput() _ = IsJSONOutput()
} }
done <- true done <- true

View File

@@ -87,7 +87,7 @@ func runStatusCommand(cfg *config.Config) error {
if !isDaemonRunning(pidFile) { if !isDaemonRunning(pidFile) {
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"status": "not_running", "status": "not_running",
}) })
} else { } else {
@@ -99,7 +99,7 @@ func runStatusCommand(cfg *config.Config) error {
data, err := os.ReadFile(pidFile) data, err := os.ReadFile(pidFile)
if err != nil { if err != nil {
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"status": "running", "status": "running",
"error": fmt.Sprintf("PID file exists but cannot be read: %v", err), "error": fmt.Sprintf("PID file exists but cannot be read: %v", err),
}) })
@@ -112,7 +112,7 @@ func runStatusCommand(cfg *config.Config) error {
pid, err := strconv.Atoi(string(data)) pid, err := strconv.Atoi(string(data))
if err != nil { if err != nil {
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"status": "running", "status": "running",
"error": fmt.Sprintf("PID file exists but contains invalid PID: %v", err), "error": fmt.Sprintf("PID file exists but contains invalid PID: %v", err),
}) })
@@ -123,7 +123,7 @@ func runStatusCommand(cfg *config.Config) error {
} }
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"status": "running", "status": "running",
"pid": pid, "pid": pid,
}) })
@@ -171,7 +171,7 @@ func stopDaemon(cfg *config.Config) error {
_ = os.Remove(pidFile) _ = os.Remove(pidFile)
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "stopped", "action": "stopped",
"pid": pid, "pid": pid,
}) })
@@ -219,7 +219,7 @@ func runDaemon(cfg *config.Config) error {
return fmt.Errorf("cannot write PID file: %w", err) return fmt.Errorf("cannot write PID file: %w", err)
} }
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "started", "action": "started",
"pid": pid, "pid": pid,
"pid_file": pidFile, "pid_file": pidFile,

View File

@@ -55,11 +55,7 @@ func runHealthCheck(cfg *config.Config) error {
ctx := context.Background() ctx := context.Background()
result := compositeChecker.CheckWithVersion(ctx, version.GetVersion()) result := compositeChecker.CheckWithVersion(ctx, version.GetVersion())
if IsJSONOutput() { return outputJSON(result)
return outputJSON(result)
}
return printHealthResult(result)
} }
func createDatabaseChecker(cfg *config.Config) (health.Checker, error) { func createDatabaseChecker(cfg *config.Config) (health.Checker, error) {
@@ -96,34 +92,3 @@ func createSMTPChecker(cfg *config.Config) health.Checker {
return health.NewSMTPChecker(smtpConfig) return health.NewSMTPChecker(smtpConfig)
} }
func printHealthResult(result health.OverallResult) error {
fmt.Printf("Health Status: %s\n", result.Status)
fmt.Printf("Version: %s\n", result.Version)
fmt.Printf("Timestamp: %s\n", result.Timestamp.Format(time.RFC3339))
fmt.Println()
if len(result.Services) == 0 {
fmt.Println("No services configured.")
return nil
}
fmt.Println("Services:")
for name, service := range result.Services {
fmt.Printf(" %s:\n", name)
fmt.Printf(" Status: %s\n", service.Status)
fmt.Printf(" Latency: %s\n", service.Latency)
if service.Message != "" {
fmt.Printf(" Message: %s\n", service.Message)
}
if len(service.Details) > 0 {
fmt.Printf(" Details:\n")
for key, value := range service.Details {
fmt.Printf(" %s: %v\n", key, value)
}
}
fmt.Println()
}
return nil
}

View File

@@ -4,10 +4,8 @@ import (
"os" "os"
"strings" "strings"
"testing" "testing"
"time"
"goyco/internal/config" "goyco/internal/config"
"goyco/internal/health"
"goyco/internal/testutils" "goyco/internal/testutils"
) )
@@ -89,117 +87,3 @@ func TestCreateSMTPChecker(t *testing.T) {
} }
}) })
} }
func TestPrintHealthResult(t *testing.T) {
t.Run("healthy result", func(t *testing.T) {
result := health.OverallResult{
Status: health.StatusHealthy,
Version: "v1.0.0",
Timestamp: time.Now().UTC(),
Services: map[string]health.Result{
"database": {
Status: health.StatusHealthy,
Latency: 2 * time.Millisecond,
Details: map[string]any{
"ping_time": "2ms",
},
},
},
}
oldStdout := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
err := printHealthResult(result)
w.Close()
os.Stdout = oldStdout
if err != nil {
t.Errorf("unexpected error: %v", err)
}
buf := make([]byte, 1024)
n, _ := r.Read(buf)
output := string(buf[:n])
if !strings.Contains(output, "Health Status: healthy") {
t.Errorf("expected 'Health Status: healthy', got %q", output)
}
if !strings.Contains(output, "Version: v1.0.0") {
t.Errorf("expected 'Version: v1.0.0', got %q", output)
}
})
t.Run("degraded result with message", func(t *testing.T) {
result := health.OverallResult{
Status: health.StatusDegraded,
Version: "v1.0.0",
Timestamp: time.Now().UTC(),
Services: map[string]health.Result{
"smtp": {
Status: health.StatusDegraded,
Message: "Connection failed",
Latency: 5 * time.Millisecond,
},
},
}
oldStdout := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
err := printHealthResult(result)
w.Close()
os.Stdout = oldStdout
if err != nil {
t.Errorf("unexpected error: %v", err)
}
buf := make([]byte, 1024)
n, _ := r.Read(buf)
output := string(buf[:n])
if !strings.Contains(output, "Health Status: degraded") {
t.Errorf("expected 'Health Status: degraded', got %q", output)
}
if !strings.Contains(output, "Connection failed") {
t.Errorf("expected error message in output, got %q", output)
}
})
t.Run("empty services", func(t *testing.T) {
result := health.OverallResult{
Status: health.StatusHealthy,
Version: "v1.0.0",
Timestamp: time.Now().UTC(),
Services: map[string]health.Result{},
}
oldStdout := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
err := printHealthResult(result)
w.Close()
os.Stdout = oldStdout
if err != nil {
t.Errorf("unexpected error: %v", err)
}
buf := make([]byte, 1024)
n, _ := r.Read(buf)
output := string(buf[:n])
if !strings.Contains(output, "No services configured") {
t.Errorf("expected 'No services configured', got %q", output)
}
})
}

View File

@@ -91,7 +91,7 @@ func postDelete(repo repositories.PostRepository, args []string) error {
} }
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "post_deleted", "action": "post_deleted",
"id": id, "id": id,
}) })
@@ -158,7 +158,7 @@ func postList(postQueries *services.PostQueries, args []string) error {
CreatedAt: p.CreatedAt.Format("2006-01-02 15:04:05"), CreatedAt: p.CreatedAt.Format("2006-01-02 15:04:05"),
} }
} }
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"posts": postsJSON, "posts": postsJSON,
"count": len(postsJSON), "count": len(postsJSON),
}) })
@@ -298,7 +298,7 @@ func postSearch(postQueries *services.PostQueries, args []string) error {
CreatedAt: p.CreatedAt.Format("2006-01-02 15:04:05"), CreatedAt: p.CreatedAt.Format("2006-01-02 15:04:05"),
} }
} }
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"search_term": sanitizedTerm, "search_term": sanitizedTerm,
"posts": postsJSON, "posts": postsJSON,
"count": len(postsJSON), "count": len(postsJSON),

View File

@@ -508,9 +508,9 @@ func TestProgressIndicator_Concurrency(t *testing.T) {
pi := NewProgressIndicator(100, "Concurrent test") pi := NewProgressIndicator(100, "Concurrent test")
done := make(chan bool) done := make(chan bool)
for i := 0; i < 10; i++ { for range 10 {
go func() { go func() {
for j := 0; j < 10; j++ { for range 10 {
pi.Increment() pi.Increment()
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
} }
@@ -518,7 +518,7 @@ func TestProgressIndicator_Concurrency(t *testing.T) {
}() }()
} }
for i := 0; i < 10; i++ { for range 10 {
<-done <-done
} }

View File

@@ -95,7 +95,7 @@ func prunePosts(postRepo repositories.PostRepository, args []string) error {
} }
} }
if *dryRun { if *dryRun {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "prune_posts", "action": "prune_posts",
"dry_run": true, "dry_run": true,
"posts": postsJSON, "posts": postsJSON,
@@ -110,7 +110,7 @@ func prunePosts(postRepo repositories.PostRepository, args []string) error {
if err != nil { if err != nil {
return fmt.Errorf("hard delete posts: %w", err) return fmt.Errorf("hard delete posts: %w", err)
} }
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "prune_posts", "action": "prune_posts",
"deleted_count": deletedCount, "deleted_count": deletedCount,
}) })
@@ -178,7 +178,7 @@ func pruneUsers(userRepo repositories.UserRepository, postRepo repositories.Post
userCount := len(users) userCount := len(users)
if userCount == 0 { if userCount == 0 {
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "prune_users", "action": "prune_users",
"count": 0, "count": 0,
}) })
@@ -211,7 +211,7 @@ func pruneUsers(userRepo repositories.UserRepository, postRepo repositories.Post
} }
} }
if *dryRun { if *dryRun {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "prune_users", "action": "prune_users",
"dry_run": true, "dry_run": true,
"users": usersJSON, "users": usersJSON,
@@ -229,7 +229,7 @@ func pruneUsers(userRepo repositories.UserRepository, postRepo repositories.Post
if err != nil { if err != nil {
return fmt.Errorf("hard delete all users and posts: %w", err) return fmt.Errorf("hard delete all users and posts: %w", err)
} }
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "prune_users", "action": "prune_users",
"deleted_count": totalDeleted, "deleted_count": totalDeleted,
"with_posts": true, "with_posts": true,
@@ -242,7 +242,7 @@ func pruneUsers(userRepo repositories.UserRepository, postRepo repositories.Post
} }
deletedCount++ deletedCount++
} }
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "prune_users", "action": "prune_users",
"deleted_count": deletedCount, "deleted_count": deletedCount,
"with_posts": false, "with_posts": false,
@@ -328,7 +328,7 @@ func pruneAll(userRepo repositories.UserRepository, postRepo repositories.PostRe
if IsJSONOutput() { if IsJSONOutput() {
if *dryRun { if *dryRun {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "prune_all", "action": "prune_all",
"dry_run": true, "dry_run": true,
"user_count": len(userCount), "user_count": len(userCount),
@@ -343,7 +343,7 @@ func pruneAll(userRepo repositories.UserRepository, postRepo repositories.PostRe
if err != nil { if err != nil {
return fmt.Errorf("hard delete all: %w", err) return fmt.Errorf("hard delete all: %w", err)
} }
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "prune_all", "action": "prune_all",
"deleted_count": totalDeleted, "deleted_count": totalDeleted,
}) })

View File

@@ -321,7 +321,7 @@ func createUsers(g *seedGenerator, userRepo repositories.UserRepository, count i
} }
progress := maybeProgress(count, desc) progress := maybeProgress(count, desc)
users := make([]database.User, 0, count) users := make([]database.User, 0, count)
for i := 0; i < count; i++ { for i := range count {
user, err := g.createSingleUser(userRepo, i+1) user, err := g.createSingleUser(userRepo, i+1)
if err != nil { if err != nil {
return nil, fmt.Errorf("create random user: %w", err) return nil, fmt.Errorf("create random user: %w", err)
@@ -343,7 +343,7 @@ func createPosts(g *seedGenerator, postRepo repositories.PostRepository, authorI
} }
progress := maybeProgress(count, desc) progress := maybeProgress(count, desc)
posts := make([]database.Post, 0, count) posts := make([]database.Post, 0, count)
for i := 0; i < count; i++ { for i := range count {
post, err := g.createSinglePost(postRepo, authorID, i+1) post, err := g.createSinglePost(postRepo, authorID, i+1)
if err != nil { if err != nil {
return nil, fmt.Errorf("create random post: %w", err) return nil, fmt.Errorf("create random post: %w", err)

View File

@@ -169,7 +169,7 @@ func userCreate(cfg *config.Config, repo repositories.UserRepository, args []str
} }
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "user_created", "action": "user_created",
"id": user.ID, "id": user.ID,
"username": user.Username, "username": user.Username,
@@ -296,7 +296,7 @@ func userUpdate(cfg *config.Config, repo repositories.UserRepository, refreshTok
} }
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "user_updated", "action": "user_updated",
"id": user.ID, "id": user.ID,
"username": user.Username, "username": user.Username,
@@ -424,7 +424,7 @@ func userDelete(cfg *config.Config, repo repositories.UserRepository, args []str
} }
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "user_deleted", "action": "user_deleted",
"id": id, "id": id,
"username": user.Username, "username": user.Username,
@@ -479,7 +479,7 @@ func userList(repo repositories.UserRepository, args []string) error {
CreatedAt: u.CreatedAt.Format("2006-01-02 15:04:05"), CreatedAt: u.CreatedAt.Format("2006-01-02 15:04:05"),
} }
} }
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"users": usersJSON, "users": usersJSON,
"count": len(usersJSON), "count": len(usersJSON),
}) })
@@ -578,7 +578,7 @@ func userLock(cfg *config.Config, repo repositories.UserRepository, args []strin
if user.Locked { if user.Locked {
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "user_lock", "action": "user_lock",
"id": id, "id": id,
"username": user.Username, "username": user.Username,
@@ -604,7 +604,7 @@ func userLock(cfg *config.Config, repo repositories.UserRepository, args []strin
} }
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "user_locked", "action": "user_locked",
"id": id, "id": id,
"username": user.Username, "username": user.Username,
@@ -653,7 +653,7 @@ func userUnlock(cfg *config.Config, repo repositories.UserRepository, args []str
if !user.Locked { if !user.Locked {
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "user_unlock", "action": "user_unlock",
"id": id, "id": id,
"username": user.Username, "username": user.Username,
@@ -679,7 +679,7 @@ func userUnlock(cfg *config.Config, repo repositories.UserRepository, args []str
} }
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "user_unlocked", "action": "user_unlocked",
"id": id, "id": id,
"username": user.Username, "username": user.Username,
@@ -732,7 +732,7 @@ func resetUserPassword(cfg *config.Config, repo repositories.UserRepository, ses
} }
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "password_reset", "action": "password_reset",
"id": userID, "id": userID,
"username": user.Username, "username": user.Username,

View File

@@ -1771,7 +1771,7 @@ const docTemplate = `{
}, },
"/health": { "/health": {
"get": { "get": {
"description": "Check the API health status along with database connectivity details", "description": "Check the API health status along with database connectivity and SMTP service details",
"consumes": [ "consumes": [
"application/json" "application/json"
], ],

View File

@@ -1768,7 +1768,7 @@
}, },
"/health": { "/health": {
"get": { "get": {
"description": "Check the API health status along with database connectivity details", "description": "Check the API health status along with database connectivity and SMTP service details",
"consumes": [ "consumes": [
"application/json" "application/json"
], ],

View File

@@ -1387,7 +1387,8 @@ paths:
get: get:
consumes: consumes:
- application/json - application/json
description: Check the API health status along with database connectivity details description: Check the API health status along with database connectivity and
SMTP service details
produces: produces:
- application/json - application/json
responses: responses:

View File

@@ -289,6 +289,19 @@ func setupTestContext(t *testing.T) *testContext {
} }
} }
func setupTestContextWithMiddleware(t *testing.T) *testContext {
t.Helper()
server := setupIntegrationTestServerWithMiddlewareEnabled(t)
t.Cleanup(func() {
server.Cleanup()
})
return &testContext{
server: server,
client: server.NewHTTPClient(),
baseURL: server.BaseURL(),
}
}
func setupTestContextWithAuthRateLimit(t *testing.T, authLimit int) *testContext { func setupTestContextWithAuthRateLimit(t *testing.T, authLimit int) *testContext {
t.Helper() t.Helper()
server := setupIntegrationTestServerWithAuthRateLimit(t, authLimit) server := setupIntegrationTestServerWithAuthRateLimit(t, authLimit)
@@ -602,15 +615,35 @@ func generateTokenWithExpiration(t *testing.T, user *database.User, cfg *config.
} }
type serverConfig struct { type serverConfig struct {
authLimit int authLimit int
disableCache bool
disableCompression bool
cacheablePaths []string
} }
func setupIntegrationTestServer(t *testing.T) *IntegrationTestServer { func setupIntegrationTestServer(t *testing.T) *IntegrationTestServer {
return setupIntegrationTestServerWithConfig(t, serverConfig{authLimit: 50000}) return setupIntegrationTestServerWithConfig(t, serverConfig{
authLimit: 50000,
disableCache: true,
disableCompression: true,
})
} }
func setupIntegrationTestServerWithAuthRateLimit(t *testing.T, authLimit int) *IntegrationTestServer { func setupIntegrationTestServerWithAuthRateLimit(t *testing.T, authLimit int) *IntegrationTestServer {
return setupIntegrationTestServerWithConfig(t, serverConfig{authLimit: authLimit}) return setupIntegrationTestServerWithConfig(t, serverConfig{
authLimit: authLimit,
disableCache: true,
disableCompression: true,
})
}
func setupIntegrationTestServerWithMiddlewareEnabled(t *testing.T) *IntegrationTestServer {
return setupIntegrationTestServerWithConfig(t, serverConfig{
authLimit: 50000,
disableCache: false,
disableCompression: false,
cacheablePaths: []string{"/api/posts"},
})
} }
func setupDatabase(t *testing.T) *gorm.DB { func setupDatabase(t *testing.T) *gorm.DB {
@@ -678,7 +711,7 @@ func setupHandlers(authService handlers.AuthServiceInterface, userRepo repositor
handlers.NewAPIHandler(cfg, postRepo, userRepo, voteService) handlers.NewAPIHandler(cfg, postRepo, userRepo, voteService)
} }
func setupRouter(authHandler *handlers.AuthHandler, postHandler *handlers.PostHandler, voteHandler *handlers.VoteHandler, userHandler *handlers.UserHandler, apiHandler *handlers.APIHandler, authService handlers.AuthServiceInterface, cfg *config.Config) http.Handler { func setupRouter(authHandler *handlers.AuthHandler, postHandler *handlers.PostHandler, voteHandler *handlers.VoteHandler, userHandler *handlers.UserHandler, apiHandler *handlers.APIHandler, authService handlers.AuthServiceInterface, cfg *config.Config, serverCfg serverConfig) http.Handler {
return server.NewRouter(server.RouterConfig{ return server.NewRouter(server.RouterConfig{
AuthHandler: authHandler, AuthHandler: authHandler,
PostHandler: postHandler, PostHandler: postHandler,
@@ -689,8 +722,9 @@ func setupRouter(authHandler *handlers.AuthHandler, postHandler *handlers.PostHa
PageHandler: nil, PageHandler: nil,
StaticDir: findWorkspaceRoot() + "/internal/static/", StaticDir: findWorkspaceRoot() + "/internal/static/",
Debug: false, Debug: false,
DisableCache: true, DisableCache: serverCfg.disableCache,
DisableCompression: true, DisableCompression: serverCfg.disableCompression,
CacheablePaths: serverCfg.cacheablePaths,
RateLimitConfig: cfg.RateLimit, RateLimitConfig: cfg.RateLimit,
}) })
} }
@@ -735,7 +769,7 @@ func setupIntegrationTestServerWithConfig(t *testing.T, serverCfg serverConfig)
} }
authHandler, postHandler, voteHandler, userHandler, apiHandler := setupHandlers(authService, userRepo, postRepo, voteService, cfg) authHandler, postHandler, voteHandler, userHandler, apiHandler := setupHandlers(authService, userRepo, postRepo, voteService, cfg)
router := setupRouter(authHandler, postHandler, voteHandler, userHandler, apiHandler, authService, cfg) router := setupRouter(authHandler, postHandler, voteHandler, userHandler, apiHandler, authService, cfg, serverCfg)
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")

View File

@@ -422,9 +422,7 @@ func TestE2E_ConcurrentPostCreation(t *testing.T) {
for _, user := range users { for _, user := range users {
u := user u := user
wg.Add(1) wg.Go(func() {
go func() {
defer wg.Done()
client, err := ctx.loginUserSafe(t, u.Username, u.Password) client, err := ctx.loginUserSafe(t, u.Username, u.Password)
if err != nil || client == nil { if err != nil || client == nil {
results <- nil results <- nil
@@ -443,7 +441,7 @@ func TestE2E_ConcurrentPostCreation(t *testing.T) {
post, err := client.CreatePostSafe("Concurrent Post", url, "Content") post, err := client.CreatePostSafe("Concurrent Post", url, "Content")
results <- post results <- post
}() })
} }
wg.Wait() wg.Wait()

View File

@@ -56,10 +56,8 @@ func TestE2E_DatabaseFailureRecovery(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
errors := make(chan error, 10) errors := make(chan error, 10)
for i := 0; i < 5; i++ { for range 5 {
wg.Add(1) wg.Go(func() {
go func() {
defer wg.Done()
conn, err := sqlDB.Conn(context.Background()) conn, err := sqlDB.Conn(context.Background())
if err != nil { if err != nil {
errors <- err errors <- err
@@ -68,7 +66,7 @@ func TestE2E_DatabaseFailureRecovery(t *testing.T) {
defer conn.Close() defer conn.Close()
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
}() })
} }
wg.Wait() wg.Wait()
@@ -297,7 +295,7 @@ func TestE2E_DatabaseConnectionPool(t *testing.T) {
initialStats := sqlDB.Stats() initialStats := sqlDB.Stats()
for i := 0; i < 5; i++ { for range 5 {
req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil { if err != nil {
continue continue

View File

@@ -7,14 +7,15 @@ import (
"net/http" "net/http"
"strings" "strings"
"testing" "testing"
"time"
"goyco/internal/testutils" "goyco/internal/testutils"
) )
func TestE2E_CompressionMiddleware(t *testing.T) { func TestE2E_CompressionMiddleware(t *testing.T) {
ctx := setupTestContext(t) ctx := setupTestContextWithMiddleware(t)
t.Run("compression_enabled_with_accept_encoding", func(t *testing.T) { t.Run("compresses_response_when_accept_encoding_is_gzip", func(t *testing.T) {
request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
@@ -27,63 +28,72 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer response.Body.Close() defer response.Body.Close()
failIfRateLimited(t, response.StatusCode, "compression enabled GET /api/posts")
contentEncoding := response.Header.Get("Content-Encoding") contentEncoding := response.Header.Get("Content-Encoding")
if contentEncoding == "gzip" { if contentEncoding != "gzip" {
body, err := io.ReadAll(response.Body) t.Fatalf("Expected gzip compression, got Content-Encoding=%q", contentEncoding)
if err != nil { }
t.Fatalf("Failed to read response body: %v", err)
}
if isGzipCompressed(body) { body, err := io.ReadAll(response.Body)
reader, err := gzip.NewReader(bytes.NewReader(body)) if err != nil {
if err != nil { t.Fatalf("Failed to read response body: %v", err)
t.Fatalf("Failed to create gzip reader: %v", err) }
}
defer reader.Close()
decompressed, err := io.ReadAll(reader) if !isGzipCompressed(body) {
if err != nil { t.Fatalf("Expected gzip-compressed body bytes")
t.Fatalf("Failed to decompress: %v", err) }
}
if len(decompressed) == 0 { reader, err := gzip.NewReader(bytes.NewReader(body))
t.Error("Decompressed body is empty") if err != nil {
} t.Fatalf("Failed to create gzip reader: %v", err)
} }
} else { defer reader.Close()
t.Logf("Compression not applied (Content-Encoding: %s)", contentEncoding)
decompressed, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("Failed to decompress: %v", err)
}
if len(decompressed) == 0 {
t.Fatal("Decompressed body is empty")
} }
}) })
t.Run("no_compression_without_accept_encoding", func(t *testing.T) { t.Run("does_not_compress_without_accept_encoding", func(t *testing.T) {
request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
testutils.WithStandardHeaders(request) testutils.WithStandardHeaders(request)
request.Header.Del("Accept-Encoding")
response, err := ctx.client.Do(request) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer response.Body.Close() defer response.Body.Close()
failIfRateLimited(t, response.StatusCode, "compression disabled GET /api/posts")
contentEncoding := response.Header.Get("Content-Encoding") contentEncoding := response.Header.Get("Content-Encoding")
if contentEncoding == "gzip" { if contentEncoding == "gzip" {
t.Error("Expected no compression without Accept-Encoding header") t.Fatal("Expected no compression without Accept-Encoding header")
} }
}) })
t.Run("decompression_handles_gzip_request", func(t *testing.T) { t.Run("accepts_valid_gzip_request_body", func(t *testing.T) {
testUser := ctx.createUserWithCleanup(t, "compressionuser", "StrongPass123!") testUser := ctx.createUserWithCleanup(t, "compressionuser", "StrongPass123!")
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!") authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
var buf bytes.Buffer var buf bytes.Buffer
gz := gzip.NewWriter(&buf) gz := gzip.NewWriter(&buf)
postData := `{"title":"Compressed Post","url":"https://example.com/compressed","content":"Test content"}` postData := `{"title":"Compressed Post","url":"https://example.com/compressed","content":"Test content"}`
gz.Write([]byte(postData)) if _, err := gz.Write([]byte(postData)); err != nil {
gz.Close() t.Fatalf("Failed to gzip request body: %v", err)
}
if err := gz.Close(); err != nil {
t.Fatalf("Failed to finalize gzip body: %v", err)
}
request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", &buf) request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", &buf)
if err != nil { if err != nil {
@@ -99,20 +109,22 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer response.Body.Close() defer response.Body.Close()
failIfRateLimited(t, response.StatusCode, "decompression POST /api/posts")
switch response.StatusCode { switch response.StatusCode {
case http.StatusBadRequest:
t.Log("Decompression middleware rejected invalid gzip")
case http.StatusCreated, http.StatusOK: case http.StatusCreated, http.StatusOK:
t.Log("Decompression middleware handled gzip request successfully") return
default:
body, _ := io.ReadAll(response.Body)
t.Fatalf("Expected status %d or %d for valid gzip request, got %d. Body: %s", http.StatusCreated, http.StatusOK, response.StatusCode, string(body))
} }
}) })
} }
func TestE2E_CacheMiddleware(t *testing.T) { func TestE2E_CacheMiddleware(t *testing.T) {
ctx := setupTestContext(t) ctx := setupTestContextWithMiddleware(t)
t.Run("cache_miss_then_hit", func(t *testing.T) { t.Run("returns_hit_after_repeated_get", func(t *testing.T) {
firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
@@ -124,46 +136,68 @@ func TestE2E_CacheMiddleware(t *testing.T) {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
firstResponse.Body.Close() firstResponse.Body.Close()
failIfRateLimited(t, firstResponse.StatusCode, "first cache GET /api/posts")
firstCacheStatus := firstResponse.Header.Get("X-Cache") firstCacheStatus := firstResponse.Header.Get("X-Cache")
if firstCacheStatus == "HIT" { if firstCacheStatus == "HIT" {
t.Log("First request was cached (unexpected but acceptable)") t.Fatalf("Expected first request to be a cache miss, got X-Cache=%q", firstCacheStatus)
} }
secondRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) for range 8 {
if err != nil { secondRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
t.Fatalf("Failed to create request: %v", err) if err != nil {
} t.Fatalf("Failed to create request: %v", err)
testutils.WithStandardHeaders(secondRequest) }
testutils.WithStandardHeaders(secondRequest)
secondResponse, err := ctx.client.Do(secondRequest) secondResponse, err := ctx.client.Do(secondRequest)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer secondResponse.Body.Close() failIfRateLimited(t, secondResponse.StatusCode, "cache warmup GET /api/posts")
secondCacheStatus := secondResponse.Header.Get("X-Cache")
secondResponse.Body.Close()
secondCacheStatus := secondResponse.Header.Get("X-Cache") if secondCacheStatus == "HIT" {
if secondCacheStatus == "HIT" { return
t.Log("Second request was served from cache") }
time.Sleep(25 * time.Millisecond)
} }
t.Fatal("Expected a cache HIT on repeated requests, but none observed")
}) })
t.Run("cache_invalidation_on_post", func(t *testing.T) { t.Run("invalidates_cached_get_after_post", func(t *testing.T) {
testUser := ctx.createUserWithCleanup(t, "cacheuser", "StrongPass123!") testUser := ctx.createUserWithCleanup(t, "cacheuser", "StrongPass123!")
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!") authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) for attempt := range 8 {
if err != nil { firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
t.Fatalf("Failed to create request: %v", err) if err != nil {
} t.Fatalf("Failed to create request: %v", err)
testutils.WithStandardHeaders(firstRequest) }
firstRequest.Header.Set("Authorization", "Bearer "+authClient.Token) testutils.WithStandardHeaders(firstRequest)
firstRequest.Header.Set("Authorization", "Bearer "+authClient.Token)
firstResponse, err := ctx.client.Do(firstRequest) firstResponse, err := ctx.client.Do(firstRequest)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
}
failIfRateLimited(t, firstResponse.StatusCode, "cache priming GET /api/posts")
cacheStatus := firstResponse.Header.Get("X-Cache")
firstResponse.Body.Close()
if cacheStatus == "HIT" {
break
}
if attempt == 7 {
t.Fatal("Failed to prime cache: repeated GET requests never produced X-Cache=HIT")
}
time.Sleep(25 * time.Millisecond)
} }
firstResponse.Body.Close()
postData := `{"title":"Cache Invalidation Test","url":"https://example.com/cache","content":"Test"}` postData := `{"title":"Cache Invalidation Test","url":"https://example.com/cache","content":"Test"}`
secondRequest, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData)) secondRequest, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
@@ -178,32 +212,45 @@ func TestE2E_CacheMiddleware(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
failIfRateLimited(t, secondResponse.StatusCode, "cache invalidation POST /api/posts")
if secondResponse.StatusCode != http.StatusCreated && secondResponse.StatusCode != http.StatusOK {
body, _ := io.ReadAll(secondResponse.Body)
secondResponse.Body.Close()
t.Fatalf("Expected post creation status %d or %d, got %d. Body: %s", http.StatusCreated, http.StatusOK, secondResponse.StatusCode, string(body))
}
secondResponse.Body.Close() secondResponse.Body.Close()
thirdRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) for range 8 {
if err != nil { thirdRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
t.Fatalf("Failed to create request: %v", err) if err != nil {
} t.Fatalf("Failed to create request: %v", err)
testutils.WithStandardHeaders(thirdRequest) }
thirdRequest.Header.Set("Authorization", "Bearer "+authClient.Token) testutils.WithStandardHeaders(thirdRequest)
thirdRequest.Header.Set("Authorization", "Bearer "+authClient.Token)
thirdResponse, err := ctx.client.Do(thirdRequest) thirdResponse, err := ctx.client.Do(thirdRequest)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer thirdResponse.Body.Close() failIfRateLimited(t, thirdResponse.StatusCode, "post-invalidation GET /api/posts")
cacheStatus := thirdResponse.Header.Get("X-Cache")
thirdResponse.Body.Close()
cacheStatus := thirdResponse.Header.Get("X-Cache") if cacheStatus != "HIT" {
if cacheStatus == "HIT" { return
t.Log("Cache was invalidated after POST") }
time.Sleep(25 * time.Millisecond)
} }
t.Fatal("Expected cache to be invalidated after POST, but X-Cache stayed HIT")
}) })
} }
func TestE2E_CSRFProtection(t *testing.T) { func TestE2E_CSRFProtection(t *testing.T) {
ctx := setupTestContext(t) ctx := setupTestContext(t)
t.Run("csrf_protection_for_non_api_routes", func(t *testing.T) { t.Run("non_api_post_without_csrf_is_forbidden_or_unmounted", func(t *testing.T) {
request, err := http.NewRequest("POST", ctx.baseURL+"/auth/login", strings.NewReader(`{"username":"test","password":"test"}`)) request, err := http.NewRequest("POST", ctx.baseURL+"/auth/login", strings.NewReader(`{"username":"test","password":"test"}`))
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
@@ -216,15 +263,15 @@ func TestE2E_CSRFProtection(t *testing.T) {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer response.Body.Close() defer response.Body.Close()
failIfRateLimited(t, response.StatusCode, "CSRF non-API POST /auth/login")
if response.StatusCode == http.StatusForbidden { if response.StatusCode != http.StatusForbidden && response.StatusCode != http.StatusNotFound {
t.Log("CSRF protection active for non-API routes") body, _ := io.ReadAll(response.Body)
} else { t.Fatalf("Expected status %d (CSRF protected) or %d (route unavailable in test setup), got %d. Body: %s", http.StatusForbidden, http.StatusNotFound, response.StatusCode, string(body))
t.Logf("CSRF check result: status %d", response.StatusCode)
} }
}) })
t.Run("csrf_bypass_for_api_routes", func(t *testing.T) { t.Run("api_post_without_csrf_is_not_forbidden", func(t *testing.T) {
testUser := ctx.createUserWithCleanup(t, "csrfuser", "StrongPass123!") testUser := ctx.createUserWithCleanup(t, "csrfuser", "StrongPass123!")
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!") authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
@@ -242,13 +289,15 @@ func TestE2E_CSRFProtection(t *testing.T) {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer response.Body.Close() defer response.Body.Close()
failIfRateLimited(t, response.StatusCode, "CSRF bypass POST /api/posts")
if response.StatusCode == http.StatusForbidden { if response.StatusCode == http.StatusForbidden {
t.Error("API routes should bypass CSRF protection") body, _ := io.ReadAll(response.Body)
t.Fatalf("API routes should bypass CSRF protection, got 403. Body: %s", string(body))
} }
}) })
t.Run("csrf_allows_get_requests", func(t *testing.T) { t.Run("get_request_without_csrf_is_not_forbidden", func(t *testing.T) {
request, err := http.NewRequest("GET", ctx.baseURL+"/auth/login", nil) request, err := http.NewRequest("GET", ctx.baseURL+"/auth/login", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
@@ -260,9 +309,10 @@ func TestE2E_CSRFProtection(t *testing.T) {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer response.Body.Close() defer response.Body.Close()
failIfRateLimited(t, response.StatusCode, "CSRF GET /auth/login")
if response.StatusCode == http.StatusForbidden { if response.StatusCode == http.StatusForbidden {
t.Error("GET requests should not require CSRF token") t.Fatal("GET requests should not require CSRF token")
} }
}) })
} }
@@ -270,7 +320,7 @@ func TestE2E_CSRFProtection(t *testing.T) {
func TestE2E_RequestSizeLimit(t *testing.T) { func TestE2E_RequestSizeLimit(t *testing.T) {
ctx := setupTestContext(t) ctx := setupTestContext(t)
t.Run("request_within_size_limit", func(t *testing.T) { t.Run("accepts_request_within_size_limit", func(t *testing.T) {
testUser := ctx.createUserWithCleanup(t, "sizelimituser", "StrongPass123!") testUser := ctx.createUserWithCleanup(t, "sizelimituser", "StrongPass123!")
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!") authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
@@ -289,13 +339,15 @@ func TestE2E_RequestSizeLimit(t *testing.T) {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer response.Body.Close() defer response.Body.Close()
failIfRateLimited(t, response.StatusCode, "request within size limit POST /api/posts")
if response.StatusCode == http.StatusRequestEntityTooLarge { if response.StatusCode == http.StatusRequestEntityTooLarge {
t.Error("Small request should not exceed size limit") body, _ := io.ReadAll(response.Body)
t.Fatalf("Small request should not exceed size limit. Body: %s", string(body))
} }
}) })
t.Run("request_exceeds_size_limit", func(t *testing.T) { t.Run("rejects_or_fails_oversized_request", func(t *testing.T) {
testUser := ctx.createUserWithCleanup(t, "sizelimituser2", "StrongPass123!") testUser := ctx.createUserWithCleanup(t, "sizelimituser2", "StrongPass123!")
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!") authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
@@ -311,14 +363,14 @@ func TestE2E_RequestSizeLimit(t *testing.T) {
response, err := ctx.client.Do(request) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
return t.Fatalf("Request failed: %v", err)
} }
defer response.Body.Close() defer response.Body.Close()
failIfRateLimited(t, response.StatusCode, "request exceeds size limit POST /api/posts")
if response.StatusCode == http.StatusRequestEntityTooLarge { if response.StatusCode != http.StatusRequestEntityTooLarge && response.StatusCode != http.StatusBadRequest {
t.Log("Request size limit enforced correctly") body, _ := io.ReadAll(response.Body)
} else { t.Fatalf("Expected status %d or %d for oversized request, got %d. Body: %s", http.StatusRequestEntityTooLarge, http.StatusBadRequest, response.StatusCode, string(body))
t.Logf("Request size limit check result: status %d", response.StatusCode)
} }
}) })
} }

View File

@@ -59,7 +59,7 @@ func TestE2E_Performance(t *testing.T) {
var totalTime time.Duration var totalTime time.Duration
iterations := 10 iterations := 10
for i := 0; i < iterations; i++ { for range iterations {
req, err := endpoint.req() req, err := endpoint.req()
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
@@ -98,11 +98,9 @@ func TestE2E_Performance(t *testing.T) {
var errorCount int64 var errorCount int64
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < concurrency; i++ { for range concurrency {
wg.Add(1) wg.Go(func() {
go func() { for range requestsPerGoroutine {
defer wg.Done()
for j := 0; j < requestsPerGoroutine; j++ {
req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil { if err != nil {
atomic.AddInt64(&errorCount, 1) atomic.AddInt64(&errorCount, 1)
@@ -123,7 +121,7 @@ func TestE2E_Performance(t *testing.T) {
atomic.AddInt64(&errorCount, 1) atomic.AddInt64(&errorCount, 1)
} }
} }
}() })
} }
wg.Wait() wg.Wait()
@@ -138,7 +136,7 @@ func TestE2E_Performance(t *testing.T) {
createdUser := ctx.createUserWithCleanup(t, "dbperf", "StrongPass123!") createdUser := ctx.createUserWithCleanup(t, "dbperf", "StrongPass123!")
authClient := ctx.loginUser(t, createdUser.Username, createdUser.Password) authClient := ctx.loginUser(t, createdUser.Username, createdUser.Password)
for i := 0; i < 10; i++ { for i := range 10 {
authClient.CreatePost(t, fmt.Sprintf("Post %d", i), fmt.Sprintf("https://example.com/%d", i), "Content") authClient.CreatePost(t, fmt.Sprintf("Post %d", i), fmt.Sprintf("https://example.com/%d", i), "Content")
} }
@@ -160,7 +158,7 @@ func TestE2E_Performance(t *testing.T) {
authClient := ctx.loginUser(t, createdUser.Username, createdUser.Password) authClient := ctx.loginUser(t, createdUser.Username, createdUser.Password)
initialPosts := 50 initialPosts := 50
for i := 0; i < initialPosts; i++ { for i := range initialPosts {
authClient.CreatePost(t, fmt.Sprintf("Memory Test Post %d", i), fmt.Sprintf("https://example.com/mem%d", i), "Content") authClient.CreatePost(t, fmt.Sprintf("Memory Test Post %d", i), fmt.Sprintf("https://example.com/mem%d", i), "Content")
} }
@@ -271,16 +269,14 @@ func TestE2E_ConcurrentWrites(t *testing.T) {
for _, user := range users { for _, user := range users {
u := user u := user
wg.Add(1) wg.Go(func() {
go func() {
defer wg.Done()
authClient, err := ctx.loginUserSafe(t, u.Username, u.Password) authClient, err := ctx.loginUserSafe(t, u.Username, u.Password)
if err != nil { if err != nil {
atomic.AddInt64(&errorCount, 1) atomic.AddInt64(&errorCount, 1)
return return
} }
for i := 0; i < 5; i++ { for i := range 5 {
post, err := authClient.CreatePostSafe( post, err := authClient.CreatePostSafe(
fmt.Sprintf("Concurrent Post %d", i), fmt.Sprintf("Concurrent Post %d", i),
fmt.Sprintf("https://example.com/concurrent%d-%d", u.ID, i), fmt.Sprintf("https://example.com/concurrent%d-%d", u.ID, i),
@@ -292,7 +288,7 @@ func TestE2E_ConcurrentWrites(t *testing.T) {
atomic.AddInt64(&errorCount, 1) atomic.AddInt64(&errorCount, 1)
} }
} }
}() })
} }
wg.Wait() wg.Wait()
@@ -311,7 +307,7 @@ func TestE2E_ResponseSize(t *testing.T) {
createdUser := ctx.createUserWithCleanup(t, "sizetest", "StrongPass123!") createdUser := ctx.createUserWithCleanup(t, "sizetest", "StrongPass123!")
authClient := ctx.loginUser(t, createdUser.Username, createdUser.Password) authClient := ctx.loginUser(t, createdUser.Username, createdUser.Password)
for i := 0; i < 100; i++ { for i := range 100 {
authClient.CreatePost(t, fmt.Sprintf("Post %d", i), fmt.Sprintf("https://example.com/%d", i), "Content") authClient.CreatePost(t, fmt.Sprintf("Post %d", i), fmt.Sprintf("https://example.com/%d", i), "Content")
} }

View File

@@ -38,7 +38,7 @@ func TestE2E_RateLimitingHeaders(t *testing.T) {
t.Error("Expected Retry-After header when rate limited") t.Error("Expected Retry-After header when rate limited")
} }
var jsonResponse map[string]interface{} var jsonResponse map[string]any
body, _ := json.Marshal(map[string]string{}) body, _ := json.Marshal(map[string]string{})
_ = json.Unmarshal(body, &jsonResponse) _ = json.Unmarshal(body, &jsonResponse)
@@ -72,7 +72,7 @@ func TestE2E_RateLimitingHeaders(t *testing.T) {
if resp.StatusCode != http.StatusTooManyRequests { if resp.StatusCode != http.StatusTooManyRequests {
t.Errorf("Expected status 429 on request %d, got %d", i+1, resp.StatusCode) t.Errorf("Expected status 429 on request %d, got %d", i+1, resp.StatusCode)
} else { } else {
var errorResponse map[string]interface{} var errorResponse map[string]any
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
if err := json.Unmarshal(body, &errorResponse); err == nil { if err := json.Unmarshal(body, &errorResponse); err == nil {
if errorResponse["error"] == nil { if errorResponse["error"] == nil {

View File

@@ -30,12 +30,12 @@ func TestE2E_VersionEndpoint(t *testing.T) {
return return
} }
var apiInfo map[string]interface{} var apiInfo map[string]any
if err := json.NewDecoder(resp.Body).Decode(&apiInfo); err != nil { if err := json.NewDecoder(resp.Body).Decode(&apiInfo); err != nil {
t.Fatalf("Failed to decode API info response: %v", err) t.Fatalf("Failed to decode API info response: %v", err)
} }
data, ok := apiInfo["data"].(map[string]interface{}) data, ok := apiInfo["data"].(map[string]any)
if !ok { if !ok {
t.Fatalf("API info data is not a map") t.Fatalf("API info data is not a map")
} }
@@ -74,12 +74,12 @@ func TestE2E_VersionEndpoint(t *testing.T) {
return return
} }
var healthInfo map[string]interface{} var healthInfo map[string]any
if err := json.NewDecoder(resp.Body).Decode(&healthInfo); err != nil { if err := json.NewDecoder(resp.Body).Decode(&healthInfo); err != nil {
t.Fatalf("Failed to decode health response: %v", err) t.Fatalf("Failed to decode health response: %v", err)
} }
data, ok := healthInfo["data"].(map[string]interface{}) data, ok := healthInfo["data"].(map[string]any)
if !ok { if !ok {
t.Fatalf("Health data is not a map") t.Fatalf("Health data is not a map")
} }
@@ -130,18 +130,18 @@ func TestE2E_VersionEndpoint(t *testing.T) {
return return
} }
var apiInfo map[string]interface{} var apiInfo map[string]any
if err := json.NewDecoder(apiResp.Body).Decode(&apiInfo); err != nil { if err := json.NewDecoder(apiResp.Body).Decode(&apiInfo); err != nil {
t.Fatalf("Failed to decode API info: %v", err) t.Fatalf("Failed to decode API info: %v", err)
} }
var healthInfo map[string]interface{} var healthInfo map[string]any
if err := json.NewDecoder(healthResp.Body).Decode(&healthInfo); err != nil { if err := json.NewDecoder(healthResp.Body).Decode(&healthInfo); err != nil {
t.Fatalf("Failed to decode health info: %v", err) t.Fatalf("Failed to decode health info: %v", err)
} }
apiData, _ := apiInfo["data"].(map[string]interface{}) apiData, _ := apiInfo["data"].(map[string]any)
healthData, _ := healthInfo["data"].(map[string]interface{}) healthData, _ := healthInfo["data"].(map[string]any)
apiVersion, apiOk := apiData["version"].(string) apiVersion, apiOk := apiData["version"].(string)
healthVersion, healthOk := healthData["version"].(string) healthVersion, healthOk := healthData["version"].(string)
@@ -169,12 +169,12 @@ func TestE2E_VersionEndpoint(t *testing.T) {
return return
} }
var apiInfo map[string]interface{} var apiInfo map[string]any
if err := json.NewDecoder(resp.Body).Decode(&apiInfo); err != nil { if err := json.NewDecoder(resp.Body).Decode(&apiInfo); err != nil {
t.Fatalf("Failed to decode API info: %v", err) t.Fatalf("Failed to decode API info: %v", err)
} }
data, ok := apiInfo["data"].(map[string]interface{}) data, ok := apiInfo["data"].(map[string]any)
if !ok { if !ok {
return return
} }

View File

@@ -455,7 +455,7 @@ func TestE2E_ConcurrentRequestsWithSameSession(t *testing.T) {
authClient := ctx.loginUser(t, createdUser.Username, createdUser.Password) authClient := ctx.loginUser(t, createdUser.Username, createdUser.Password)
results := make(chan bool, 5) results := make(chan bool, 5)
for i := 0; i < 5; i++ { for range 5 {
go func() { go func() {
profile := authClient.GetProfile(t) profile := authClient.GetProfile(t)
results <- (profile != nil && profile.Data.Username == createdUser.Username) results <- (profile != nil && profile.Data.Username == createdUser.Username)
@@ -463,7 +463,7 @@ func TestE2E_ConcurrentRequestsWithSameSession(t *testing.T) {
} }
successCount := 0 successCount := 0
for i := 0; i < 5; i++ { for range 5 {
if <-results { if <-results {
successCount++ successCount++
} }
@@ -562,7 +562,7 @@ func TestE2E_RapidSuccessiveActions(t *testing.T) {
post := authClient.CreatePost(t, "Rapid Vote Test", "https://example.com/rapid", "Content") post := authClient.CreatePost(t, "Rapid Vote Test", "https://example.com/rapid", "Content")
for i := 0; i < 10; i++ { for i := range 10 {
voteType := "up" voteType := "up"
if i%2 == 0 { if i%2 == 0 {
voteType = "down" voteType = "down"

View File

@@ -134,9 +134,7 @@ func TestE2E_ConcurrentUserWorkflows(t *testing.T) {
for _, user := range users { for _, user := range users {
u := user u := user
wg.Add(1) wg.Go(func() {
go func() {
defer wg.Done()
var err error var err error
authClient, loginErr := ctx.loginUserSafe(t, u.Username, u.Password) authClient, loginErr := ctx.loginUserSafe(t, u.Username, u.Password)
if loginErr != nil || authClient == nil || authClient.Token == "" { if loginErr != nil || authClient == nil || authClient.Token == "" {
@@ -157,7 +155,7 @@ func TestE2E_ConcurrentUserWorkflows(t *testing.T) {
case results <- result{userID: u.ID, err: err}: case results <- result{userID: u.ID, err: err}:
case <-done: case <-done:
} }
}() })
} }
go func() { go func() {

View File

@@ -53,13 +53,13 @@ func TestRunSanitizationFuzzTest(t *testing.T) {
sanitizeFunc := func(input string) string { sanitizeFunc := func(input string) string {
result := "" var result strings.Builder
for _, char := range input { for _, char := range input {
if char != ' ' { if char != ' ' {
result += string(char) result.WriteString(string(char))
} }
} }
return result return result.String()
} }
result := sanitizeFunc("hello world") result := sanitizeFunc("hello world")
@@ -76,13 +76,13 @@ func TestRunSanitizationFuzzTestWithValidation(t *testing.T) {
sanitizeFunc := func(input string) string { sanitizeFunc := func(input string) string {
result := "" var result strings.Builder
for _, char := range input { for _, char := range input {
if char != ' ' { if char != ' ' {
result += string(char) result.WriteString(string(char))
} }
} }
return result return result.String()
} }
validateFunc := func(input string) bool { validateFunc := func(input string) bool {
@@ -1673,7 +1673,7 @@ func TestValidateHTTPRequestWithManyHeaders(t *testing.T) {
helper := NewFuzzTestHelper() helper := NewFuzzTestHelper()
req := httptest.NewRequest("GET", "/api/test", nil) req := httptest.NewRequest("GET", "/api/test", nil)
for i := 0; i < 20; i++ { for i := range 20 {
req.Header.Set(fmt.Sprintf("Header-%d", i), fmt.Sprintf("Value-%d", i)) req.Header.Set(fmt.Sprintf("Header-%d", i), fmt.Sprintf("Value-%d", i))
} }

View File

@@ -158,10 +158,3 @@ func FuzzPostRepository(f *testing.F) {
}) })
}) })
} }
func min(a, b int) int {
if a < b {
return a
}
return b
}

View File

@@ -1364,7 +1364,7 @@ func TestAuthHandler_ConcurrentAccess(t *testing.T) {
concurrency := 10 concurrency := 10
done := make(chan bool, concurrency) done := make(chan bool, concurrency)
for i := 0; i < concurrency; i++ { for range concurrency {
go func() { go func() {
req := createLoginRequest(`{"username":"testuser","password":"Password123!"}`) req := createLoginRequest(`{"username":"testuser","password":"Password123!"}`)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1376,7 +1376,7 @@ func TestAuthHandler_ConcurrentAccess(t *testing.T) {
}() }()
} }
for i := 0; i < concurrency; i++ { for range concurrency {
<-done <-done
} }
}) })

View File

@@ -115,10 +115,7 @@ func NewPageHandler(templatesDir string, authService AuthServiceInterface, postR
if start >= len(s) { if start >= len(s) {
return "" return ""
} }
end := start + length end := min(start+length, len(s))
if end > len(s) {
end = len(s)
}
return s[start:end] return s[start:end]
}, },
"upper": strings.ToUpper, "upper": strings.ToUpper,

View File

@@ -35,12 +35,17 @@ type OverallResult struct {
func determineOverallStatus(results map[string]Result) Status { func determineOverallStatus(results map[string]Result) Status {
hasUnhealthy := false hasUnhealthy := false
hasSMTPUnhealthy := false
hasDegraded := false hasDegraded := false
for _, result := range results { for name, result := range results {
switch result.Status { switch result.Status {
case StatusUnhealthy: case StatusUnhealthy:
hasUnhealthy = true if name == "smtp" {
hasSMTPUnhealthy = true
} else {
hasUnhealthy = true
}
case StatusDegraded: case StatusDegraded:
hasDegraded = true hasDegraded = true
} }
@@ -49,7 +54,7 @@ func determineOverallStatus(results map[string]Result) Status {
if hasUnhealthy { if hasUnhealthy {
return StatusUnhealthy return StatusUnhealthy
} }
if hasDegraded { if hasDegraded || hasSMTPUnhealthy {
return StatusDegraded return StatusDegraded
} }
return StatusHealthy return StatusHealthy

View File

@@ -57,10 +57,15 @@ func TestDetermineOverallStatus(t *testing.T) {
results: map[string]Result{"db": {Status: StatusUnhealthy}, "smtp": {Status: StatusHealthy}}, results: map[string]Result{"db": {Status: StatusUnhealthy}, "smtp": {Status: StatusHealthy}},
expected: StatusUnhealthy, expected: StatusUnhealthy,
}, },
{
name: "smtp unhealthy downgrades overall to degraded",
results: map[string]Result{"db": {Status: StatusHealthy}, "smtp": {Status: StatusUnhealthy}},
expected: StatusDegraded,
},
{ {
name: "mixed degraded and unhealthy", name: "mixed degraded and unhealthy",
results: map[string]Result{"db": {Status: StatusDegraded}, "smtp": {Status: StatusUnhealthy}}, results: map[string]Result{"db": {Status: StatusDegraded}, "smtp": {Status: StatusUnhealthy}},
expected: StatusUnhealthy, expected: StatusDegraded,
}, },
{ {
name: "empty results", name: "empty results",

View File

@@ -33,7 +33,7 @@ func (c *SMTPChecker) Name() string {
func (c *SMTPChecker) Check(ctx context.Context) Result { func (c *SMTPChecker) Check(ctx context.Context) Result {
start := time.Now() start := time.Now()
address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port) address := net.JoinHostPort(c.config.Host, fmt.Sprintf("%d", c.config.Port))
result := Result{ result := Result{
Status: StatusHealthy, Status: StatusHealthy,
@@ -45,7 +45,7 @@ func (c *SMTPChecker) Check(ctx context.Context) Result {
conn, err := net.Dial("tcp", address) conn, err := net.Dial("tcp", address)
if err != nil { if err != nil {
result.Status = StatusDegraded result.Status = StatusUnhealthy
result.Message = fmt.Sprintf("Failed to connect to SMTP server: %v", err) result.Message = fmt.Sprintf("Failed to connect to SMTP server: %v", err)
result.Latency = time.Since(start) result.Latency = time.Since(start)
return result return result
@@ -54,7 +54,7 @@ func (c *SMTPChecker) Check(ctx context.Context) Result {
client, err := smtp.NewClient(conn, c.config.Host) client, err := smtp.NewClient(conn, c.config.Host)
if err != nil { if err != nil {
result.Status = StatusDegraded result.Status = StatusUnhealthy
result.Message = fmt.Sprintf("Failed to create SMTP client: %v", err) result.Message = fmt.Sprintf("Failed to create SMTP client: %v", err)
result.Latency = time.Since(start) result.Latency = time.Since(start)
return result return result

View File

@@ -3,6 +3,8 @@ package health
import ( import (
"context" "context"
"net" "net"
"strconv"
"strings"
"testing" "testing"
"time" "time"
) )
@@ -34,8 +36,8 @@ func TestSMTPChecker_Check_InvalidPort(t *testing.T) {
result := checker.Check(ctx) result := checker.Check(ctx)
if result.Status != StatusDegraded { if result.Status != StatusUnhealthy {
t.Errorf("expected degraded status for invalid port, got %s", result.Status) t.Errorf("expected unhealthy status for invalid port, got %s", result.Status)
} }
} }
@@ -57,8 +59,8 @@ func TestSMTPChecker_Check_ConnectionRefused(t *testing.T) {
result := checker.Check(ctx) result := checker.Check(ctx)
if result.Status != StatusDegraded { if result.Status != StatusUnhealthy {
t.Errorf("expected degraded status for connection refused, got %s", result.Status) t.Errorf("expected unhealthy status for connection refused, got %s", result.Status)
} }
} }
@@ -178,7 +180,7 @@ func TestSMTPChecker_Check_EHLOFailure(t *testing.T) {
t.Errorf("expected degraded status for EHLO failure, got %s", result.Status) t.Errorf("expected degraded status for EHLO failure, got %s", result.Status)
} }
if !contains(result.Message, "EHLO") { if !strings.Contains(result.Message, "EHLO") {
t.Errorf("expected EHLO error in message, got: %s", result.Message) t.Errorf("expected EHLO error in message, got: %s", result.Message)
} }
} }
@@ -303,35 +305,5 @@ func TestSMTPConfig_GetAddress(t *testing.T) {
} }
func getSMTPAddress(config SMTPConfig) string { func getSMTPAddress(config SMTPConfig) string {
return config.Host + ":" + itoa(config.Port) return config.Host + ":" + strconv.Itoa(config.Port)
}
func itoa(n int) string {
if n == 0 {
return "0"
}
if n < 0 {
return "-" + itoa(-n)
}
var result []byte
for n > 0 {
result = append([]byte{byte('0' + n%10)}, result...)
n /= 10
}
return string(result)
}
func contains(s, substr string) bool {
if len(substr) == 0 {
return true
}
if len(s) < len(substr) {
return false
}
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
} }

View File

@@ -635,7 +635,7 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
paginationUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "pagination_edge"), uniqueTestEmail(t, "pagination_edge")) paginationUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "pagination_edge"), uniqueTestEmail(t, "pagination_edge"))
for i := 0; i < 5; i++ { for i := range 5 {
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, paginationUser.User.ID, fmt.Sprintf("Pagination Post %d", i), fmt.Sprintf("https://example.com/pag%d", i)) testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, paginationUser.User.ID, fmt.Sprintf("Pagination Post %d", i), fmt.Sprintf("https://example.com/pag%d", i))
} }

View File

@@ -50,7 +50,7 @@ func TestIntegration_RateLimiting(t *testing.T) {
rateLimitConfig.AuthLimit = 2 rateLimitConfig.AuthLimit = 2
router, _ := setupRateLimitRouter(t, rateLimitConfig) router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ { for range 2 {
request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`)) request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -80,7 +80,7 @@ func TestIntegration_RateLimiting(t *testing.T) {
rateLimitConfig.GeneralLimit = 5 rateLimitConfig.GeneralLimit = 5
router, _ := setupRateLimitRouter(t, rateLimitConfig) router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 5; i++ { for range 5 {
request := httptest.NewRequest("GET", "/api/posts", nil) request := httptest.NewRequest("GET", "/api/posts", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request) router.ServeHTTP(recorder, request)
@@ -99,7 +99,7 @@ func TestIntegration_RateLimiting(t *testing.T) {
rateLimitConfig.HealthLimit = 3 rateLimitConfig.HealthLimit = 3
router, _ := setupRateLimitRouter(t, rateLimitConfig) router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 3; i++ { for range 3 {
request := httptest.NewRequest("GET", "/health", nil) request := httptest.NewRequest("GET", "/health", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request) router.ServeHTTP(recorder, request)
@@ -118,7 +118,7 @@ func TestIntegration_RateLimiting(t *testing.T) {
rateLimitConfig.MetricsLimit = 2 rateLimitConfig.MetricsLimit = 2
router, _ := setupRateLimitRouter(t, rateLimitConfig) router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ { for range 2 {
request := httptest.NewRequest("GET", "/metrics", nil) request := httptest.NewRequest("GET", "/metrics", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request) router.ServeHTTP(recorder, request)
@@ -138,7 +138,7 @@ func TestIntegration_RateLimiting(t *testing.T) {
rateLimitConfig.GeneralLimit = 10 rateLimitConfig.GeneralLimit = 10
router, _ := setupRateLimitRouter(t, rateLimitConfig) router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ { for range 2 {
request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`)) request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -165,7 +165,7 @@ func TestIntegration_RateLimiting(t *testing.T) {
suite.EmailSender.Reset() suite.EmailSender.Reset()
user := createAuthenticatedUser(t, authService, suite.UserRepo, uniqueTestUsername(t, "ratelimit_auth"), uniqueTestEmail(t, "ratelimit_auth")) user := createAuthenticatedUser(t, authService, suite.UserRepo, uniqueTestUsername(t, "ratelimit_auth"), uniqueTestEmail(t, "ratelimit_auth"))
for i := 0; i < 3; i++ { for range 3 {
request := httptest.NewRequest("GET", "/api/auth/me", nil) request := httptest.NewRequest("GET", "/api/auth/me", nil)
request.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)

View File

@@ -283,7 +283,7 @@ func TestIntegration_Repositories(t *testing.T) {
} }
voters := make([]*database.User, 5) voters := make([]*database.User, 5)
for i := 0; i < 5; i++ { for i := range 5 {
voter := &database.User{ voter := &database.User{
Username: fmt.Sprintf("voter_%d", i), Username: fmt.Sprintf("voter_%d", i),
Email: fmt.Sprintf("voter%d@example.com", i), Email: fmt.Sprintf("voter%d@example.com", i),

View File

@@ -2,7 +2,6 @@ package integration
import ( import (
"bytes" "bytes"
"encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@@ -63,13 +62,33 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
t.Run("DBMonitoring_Active", func(t *testing.T) { t.Run("DBMonitoring_Active", func(t *testing.T) {
request := makeGetRequest(t, router, "/health") request := makeGetRequest(t, router, "/health")
var response map[string]any response := assertJSONResponse(t, request, http.StatusOK)
if err := json.NewDecoder(request.Body).Decode(&response); err == nil { if response == nil {
if data, ok := response["data"].(map[string]any); ok { return
if _, exists := data["database_stats"]; !exists { }
t.Error("Expected database_stats in health response")
} data, ok := getDataFromResponse(response)
} if !ok {
t.Fatal("Expected data to be a map")
}
services, ok := data["services"].(map[string]any)
if !ok {
t.Fatal("Expected services in health response")
}
databaseService, ok := services["database"].(map[string]any)
if !ok {
t.Fatal("Expected database service in health response")
}
details, ok := databaseService["details"].(map[string]any)
if !ok {
t.Fatal("Expected database details in health response")
}
if _, exists := details["stats"]; !exists {
t.Error("Expected database stats in health response")
} }
}) })

View File

@@ -283,16 +283,16 @@ func TestInMemoryCacheConcurrent(t *testing.T) {
t.Run("Concurrent writes", func(t *testing.T) { t.Run("Concurrent writes", func(t *testing.T) {
done := make(chan bool, numGoroutines) done := make(chan bool, numGoroutines)
for i := 0; i < numGoroutines; i++ { for i := range numGoroutines {
go func(id int) { go func(id int) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
t.Errorf("Goroutine %d panicked: %v", id, r) t.Errorf("Goroutine %d panicked: %v", id, r)
} }
}() }()
for j := 0; j < numOps; j++ { for j := range numOps {
entry := &CacheEntry{ entry := &CacheEntry{
Data: []byte(fmt.Sprintf("data-%d-%d", id, j)), Data: fmt.Appendf(nil, "data-%d-%d", id, j),
Headers: make(http.Header), Headers: make(http.Header),
Timestamp: time.Now(), Timestamp: time.Now(),
TTL: 5 * time.Minute, TTL: 5 * time.Minute,
@@ -306,16 +306,16 @@ func TestInMemoryCacheConcurrent(t *testing.T) {
}(i) }(i)
} }
for i := 0; i < numGoroutines; i++ { for range numGoroutines {
<-done <-done
} }
}) })
t.Run("Concurrent reads and writes", func(t *testing.T) { t.Run("Concurrent reads and writes", func(t *testing.T) {
for i := 0; i < 10; i++ { for i := range 10 {
entry := &CacheEntry{ entry := &CacheEntry{
Data: []byte(fmt.Sprintf("data-%d", i)), Data: fmt.Appendf(nil, "data-%d", i),
Headers: make(http.Header), Headers: make(http.Header),
Timestamp: time.Now(), Timestamp: time.Now(),
TTL: 5 * time.Minute, TTL: 5 * time.Minute,
@@ -325,16 +325,16 @@ func TestInMemoryCacheConcurrent(t *testing.T) {
done := make(chan bool, numGoroutines*2) done := make(chan bool, numGoroutines*2)
for i := 0; i < numGoroutines; i++ { for i := range numGoroutines {
go func(id int) { go func(id int) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
t.Errorf("Writer goroutine %d panicked: %v", id, r) t.Errorf("Writer goroutine %d panicked: %v", id, r)
} }
}() }()
for j := 0; j < numOps; j++ { for j := range numOps {
entry := &CacheEntry{ entry := &CacheEntry{
Data: []byte(fmt.Sprintf("write-%d-%d", id, j)), Data: fmt.Appendf(nil, "write-%d-%d", id, j),
Headers: make(http.Header), Headers: make(http.Header),
Timestamp: time.Now(), Timestamp: time.Now(),
TTL: 5 * time.Minute, TTL: 5 * time.Minute,
@@ -346,14 +346,14 @@ func TestInMemoryCacheConcurrent(t *testing.T) {
}(i) }(i)
} }
for i := 0; i < numGoroutines; i++ { for i := range numGoroutines {
go func(id int) { go func(id int) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
t.Errorf("Reader goroutine %d panicked: %v", id, r) t.Errorf("Reader goroutine %d panicked: %v", id, r)
} }
}() }()
for j := 0; j < numOps; j++ { for j := range numOps {
key := fmt.Sprintf("key-%d", j%10) key := fmt.Sprintf("key-%d", j%10)
cache.Get(key) cache.Get(key)
} }
@@ -368,9 +368,9 @@ func TestInMemoryCacheConcurrent(t *testing.T) {
t.Run("Concurrent deletes", func(t *testing.T) { t.Run("Concurrent deletes", func(t *testing.T) {
for i := 0; i < numGoroutines; i++ { for i := range numGoroutines {
entry := &CacheEntry{ entry := &CacheEntry{
Data: []byte(fmt.Sprintf("data-%d", i)), Data: fmt.Appendf(nil, "data-%d", i),
Headers: make(http.Header), Headers: make(http.Header),
Timestamp: time.Now(), Timestamp: time.Now(),
TTL: 5 * time.Minute, TTL: 5 * time.Minute,
@@ -379,7 +379,7 @@ func TestInMemoryCacheConcurrent(t *testing.T) {
} }
done := make(chan bool, numGoroutines) done := make(chan bool, numGoroutines)
for i := 0; i < numGoroutines; i++ { for i := range numGoroutines {
go func(id int) { go func(id int) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@@ -391,7 +391,7 @@ func TestInMemoryCacheConcurrent(t *testing.T) {
}(i) }(i)
} }
for i := 0; i < numGoroutines; i++ { for range numGoroutines {
<-done <-done
} }
}) })

View File

@@ -26,12 +26,7 @@ func NewCORSConfig() *CORSConfig {
} }
switch env { switch env {
case "production": case "production", "staging":
if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins == "" {
config.AllowedOrigins = []string{}
}
config.AllowCredentials = true
case "staging":
if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins == "" { if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins == "" {
config.AllowedOrigins = []string{} config.AllowedOrigins = []string{}
} }
@@ -53,82 +48,66 @@ func NewCORSConfig() *CORSConfig {
return config return config
} }
func checkOrigin(origin string, allowedOrigins []string) (allowed bool, hasWildcard bool) {
for _, allowedOrigin := range allowedOrigins {
if allowedOrigin == "*" {
hasWildcard = true
allowed = true
break
}
if allowedOrigin == origin {
allowed = true
break
}
}
return
}
func setCORSHeaders(w http.ResponseWriter, origin string, hasWildcard bool, config *CORSConfig) {
if hasWildcard && !config.AllowCredentials {
w.Header().Set("Access-Control-Allow-Origin", "*")
} else {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
if config.AllowCredentials && !hasWildcard {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
}
func CORSWithConfig(config *CORSConfig) func(http.Handler) http.Handler { func CORSWithConfig(config *CORSConfig) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin") origin := r.Header.Get("Origin")
if r.Method == "OPTIONS" { if origin == "" {
if origin != "" { if r.Method == "OPTIONS" {
allowed := false w.WriteHeader(http.StatusOK)
hasWildcard := false return
for _, allowedOrigin := range config.AllowedOrigins {
if allowedOrigin == "*" {
hasWildcard = true
allowed = true
break
}
if allowedOrigin == origin {
allowed = true
break
}
}
if !allowed {
http.Error(w, "Origin not allowed", http.StatusForbidden)
return
}
if hasWildcard && !config.AllowCredentials {
w.Header().Set("Access-Control-Allow-Origin", "*")
} else {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
w.Header().Set("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", "))
w.Header().Set("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
w.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", config.MaxAge))
if config.AllowCredentials && !hasWildcard {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
} }
next.ServeHTTP(w, r)
return
}
allowed, hasWildcard := checkOrigin(origin, config.AllowedOrigins)
if !allowed {
http.Error(w, "Origin not allowed", http.StatusForbidden)
return
}
if r.Method == "OPTIONS" {
setCORSHeaders(w, origin, hasWildcard, config)
w.Header().Set("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", "))
w.Header().Set("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
w.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", config.MaxAge))
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
return return
} }
if origin != "" { setCORSHeaders(w, origin, hasWildcard, config)
allowed := false
hasWildcard := false
for _, allowedOrigin := range config.AllowedOrigins {
if allowedOrigin == "*" {
hasWildcard = true
allowed = true
break
}
if allowedOrigin == origin {
allowed = true
break
}
}
if !allowed {
http.Error(w, "Origin not allowed", http.StatusForbidden)
return
}
if hasWildcard && !config.AllowCredentials {
w.Header().Set("Access-Control-Allow-Origin", "*")
} else {
w.Header().Set("Access-Control-Allow-Origin", origin)
}
if config.AllowCredentials && !hasWildcard {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
}
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })
} }

View File

@@ -318,7 +318,7 @@ func TestConcurrentAccess(t *testing.T) {
collector := NewMetricsCollector(monitor) collector := NewMetricsCollector(monitor)
done := make(chan bool, 10) done := make(chan bool, 10)
for i := 0; i < 10; i++ { for range 10 {
go func() { go func() {
monitor.LogQuery("SELECT * FROM users", 50*time.Millisecond, nil) monitor.LogQuery("SELECT * FROM users", 50*time.Millisecond, nil)
collector.RecordRequest(100*time.Millisecond, false) collector.RecordRequest(100*time.Millisecond, false)
@@ -326,7 +326,7 @@ func TestConcurrentAccess(t *testing.T) {
}() }()
} }
for i := 0; i < 10; i++ { for range 10 {
<-done <-done
} }
@@ -384,7 +384,7 @@ func TestThreadSafety(t *testing.T) {
numGoroutines := 100 numGoroutines := 100
done := make(chan bool, numGoroutines) done := make(chan bool, numGoroutines)
for i := 0; i < numGoroutines; i++ { for i := range numGoroutines {
go func(id int) { go func(id int) {
if id%2 == 0 { if id%2 == 0 {
@@ -398,7 +398,7 @@ func TestThreadSafety(t *testing.T) {
}(i) }(i)
} }
for i := 0; i < numGoroutines; i++ { for range numGoroutines {
<-done <-done
} }

View File

@@ -388,7 +388,7 @@ func TestRateLimiterMaxKeys(t *testing.T) {
limiter := NewRateLimiterWithConfig(1*time.Minute, 10, 5, 1*time.Minute, 2*time.Minute) limiter := NewRateLimiterWithConfig(1*time.Minute, 10, 5, 1*time.Minute, 2*time.Minute)
defer limiter.StopCleanup() defer limiter.StopCleanup()
for i := 0; i < 5; i++ { for i := range 5 {
key := fmt.Sprintf("key-%d", i) key := fmt.Sprintf("key-%d", i)
if !limiter.Allow(key) { if !limiter.Allow(key) {
t.Errorf("Key %s should be allowed", key) t.Errorf("Key %s should be allowed", key)
@@ -435,7 +435,7 @@ func TestRateLimiterRegistry(t *testing.T) {
request := httptest.NewRequest("GET", "/test", nil) request := httptest.NewRequest("GET", "/test", nil)
request.RemoteAddr = "127.0.0.1:12345" request.RemoteAddr = "127.0.0.1:12345"
for i := 0; i < 50; i++ { for i := range 50 {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
server1.ServeHTTP(recorder, request) server1.ServeHTTP(recorder, request)
if recorder.Code != http.StatusOK { if recorder.Code != http.StatusOK {
@@ -443,7 +443,7 @@ func TestRateLimiterRegistry(t *testing.T) {
} }
} }
for i := 0; i < 50; i++ { for i := range 50 {
recorder2 := httptest.NewRecorder() recorder2 := httptest.NewRecorder()
server2.ServeHTTP(recorder2, request) server2.ServeHTTP(recorder2, request)
if recorder2.Code != http.StatusOK { if recorder2.Code != http.StatusOK {
@@ -463,7 +463,7 @@ func TestRateLimiterRegistry(t *testing.T) {
t.Error("101st request to server2 should be rejected (shared limiter reached limit)") t.Error("101st request to server2 should be rejected (shared limiter reached limit)")
} }
for i := 0; i < 50; i++ { for i := range 50 {
recorder3 := httptest.NewRecorder() recorder3 := httptest.NewRecorder()
server3.ServeHTTP(recorder3, request) server3.ServeHTTP(recorder3, request)
if recorder3.Code != http.StatusOK { if recorder3.Code != http.StatusOK {

View File

@@ -458,13 +458,13 @@ func TestIsRapidRequest(t *testing.T) {
ip := "192.168.1.1" ip := "192.168.1.1"
for i := 0; i < 50; i++ { for i := range 50 {
if isRapidRequest(ip) { if isRapidRequest(ip) {
t.Errorf("Request %d should not be considered rapid", i+1) t.Errorf("Request %d should not be considered rapid", i+1)
} }
} }
for i := 0; i < 110; i++ { for i := range 110 {
result := isRapidRequest(ip) result := isRapidRequest(ip)
if i < 50 { if i < 50 {
if result { if result {

View File

@@ -39,7 +39,7 @@ func TestValidationMiddleware(t *testing.T) {
body, _ := json.Marshal(user) body, _ := json.Marshal(user)
request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body)) request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeOf(TestUser{})) ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeFor[TestUser]())
request = request.WithContext(ctx) request = request.WithContext(ctx)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -70,7 +70,7 @@ func TestValidationMiddleware(t *testing.T) {
body, _ := json.Marshal(user) body, _ := json.Marshal(user)
request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body)) request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeOf(TestUser{})) ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeFor[TestUser]())
request = request.WithContext(ctx) request = request.WithContext(ctx)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -117,7 +117,7 @@ func TestValidationMiddleware(t *testing.T) {
body, _ := json.Marshal(user) body, _ := json.Marshal(user)
request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body)) request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeOf(TestUser{})) ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeFor[TestUser]())
request = request.WithContext(ctx) request = request.WithContext(ctx)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -144,7 +144,7 @@ func TestValidationMiddleware(t *testing.T) {
body, _ := json.Marshal(user) body, _ := json.Marshal(user)
request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body)) request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeOf(TestUser{})) ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeFor[TestUser]())
request = request.WithContext(ctx) request = request.WithContext(ctx)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()

View File

@@ -156,7 +156,7 @@ func TestPostRepository_GetAll(t *testing.T) {
user := suite.CreateTestUser("testuser2", "test2@example.com", "password123") user := suite.CreateTestUser("testuser2", "test2@example.com", "password123")
for i := 0; i < 5; i++ { for i := range 5 {
suite.CreateTestPost(user.ID, suite.CreateTestPost(user.ID,
"Post "+strconv.Itoa(i), "Post "+strconv.Itoa(i),
"https://example.com/"+strconv.Itoa(i), "https://example.com/"+strconv.Itoa(i),
@@ -178,7 +178,7 @@ func TestPostRepository_GetAll(t *testing.T) {
user := suite.CreateTestUser("testuser3", "test3@example.com", "password123") user := suite.CreateTestUser("testuser3", "test3@example.com", "password123")
for i := 0; i < 5; i++ { for i := range 5 {
suite.CreateTestPost(user.ID, suite.CreateTestPost(user.ID,
"Post "+strconv.Itoa(i), "Post "+strconv.Itoa(i),
"https://example.com/"+strconv.Itoa(i), "https://example.com/"+strconv.Itoa(i),
@@ -328,7 +328,7 @@ func TestPostRepository_Count(t *testing.T) {
user := suite.CreateTestUser("testuser", "test@example.com", "password123") user := suite.CreateTestUser("testuser", "test@example.com", "password123")
for i := 0; i < 5; i++ { for i := range 5 {
suite.CreateTestPost(user.ID, suite.CreateTestPost(user.ID,
"Post "+strconv.Itoa(i), "Post "+strconv.Itoa(i),
"https://example.com/"+strconv.Itoa(i), "https://example.com/"+strconv.Itoa(i),
@@ -506,7 +506,7 @@ func TestPostRepository_Search(t *testing.T) {
user := suite.CreateTestUser("testuser2", "test2@example.com", "password123") user := suite.CreateTestUser("testuser2", "test2@example.com", "password123")
for i := 0; i < 5; i++ { for i := range 5 {
suite.CreateTestPost(user.ID, suite.CreateTestPost(user.ID,
"Go Post "+strconv.Itoa(i), "Go Post "+strconv.Itoa(i),
"https://example.com/go"+strconv.Itoa(i), "https://example.com/go"+strconv.Itoa(i),

View File

@@ -621,7 +621,7 @@ func TestUserRepository_GetPosts(t *testing.T) {
user := suite.CreateTestUser("pagination", "pagination@example.com", "password123") user := suite.CreateTestUser("pagination", "pagination@example.com", "password123")
for i := 0; i < 5; i++ { for i := range 5 {
suite.CreateTestPost(user.ID, suite.CreateTestPost(user.ID,
"Post "+strconv.Itoa(i), "Post "+strconv.Itoa(i),
"https://example.com/"+strconv.Itoa(i), "https://example.com/"+strconv.Itoa(i),
@@ -1089,7 +1089,7 @@ func TestUserRepository_ConcurrentAccess(t *testing.T) {
done := make(chan bool, 10) done := make(chan bool, 10)
errors := make(chan error, 10) errors := make(chan error, 10)
for i := 0; i < 10; i++ { for i := range 10 {
go func(id int) { go func(id int) {
defer func() { done <- true }() defer func() { done <- true }()
user := &database.User{ user := &database.User{
@@ -1099,7 +1099,7 @@ func TestUserRepository_ConcurrentAccess(t *testing.T) {
EmailVerified: true, EmailVerified: true,
} }
var err error var err error
for retries := 0; retries < 5; retries++ { for retries := range 5 {
err = suite.UserRepo.Create(user) err = suite.UserRepo.Create(user)
if err == nil { if err == nil {
break break
@@ -1116,7 +1116,7 @@ func TestUserRepository_ConcurrentAccess(t *testing.T) {
}(i) }(i)
} }
for i := 0; i < 10; i++ { for range 10 {
<-done <-done
} }
close(errors) close(errors)
@@ -1142,7 +1142,7 @@ func TestUserRepository_ConcurrentAccess(t *testing.T) {
user := suite.CreateTestUser("concurrent_update", "update@example.com", "password123") user := suite.CreateTestUser("concurrent_update", "update@example.com", "password123")
done := make(chan bool, 5) done := make(chan bool, 5)
for i := 0; i < 5; i++ { for i := range 5 {
go func(id int) { go func(id int) {
defer func() { done <- true }() defer func() { done <- true }()
user.Username = fmt.Sprintf("updated%d", id) user.Username = fmt.Sprintf("updated%d", id)
@@ -1150,7 +1150,7 @@ func TestUserRepository_ConcurrentAccess(t *testing.T) {
}(i) }(i)
} }
for i := 0; i < 5; i++ { for range 5 {
<-done <-done
} }
@@ -1172,7 +1172,7 @@ func TestUserRepository_ConcurrentAccess(t *testing.T) {
done := make(chan bool, 2) done := make(chan bool, 2)
go func() { go func() {
var err error var err error
for retries := 0; retries < 5; retries++ { for retries := range 5 {
err = suite.UserRepo.Delete(user1.ID) err = suite.UserRepo.Delete(user1.ID)
if err == nil { if err == nil {
break break
@@ -1187,7 +1187,7 @@ func TestUserRepository_ConcurrentAccess(t *testing.T) {
}() }()
go func() { go func() {
var err error var err error
for retries := 0; retries < 5; retries++ { for retries := range 5 {
err = suite.UserRepo.Delete(user2.ID) err = suite.UserRepo.Delete(user2.ID)
if err == nil { if err == nil {
break break
@@ -1210,7 +1210,7 @@ func TestUserRepository_ConcurrentAccess(t *testing.T) {
} }
if count > 0 { if count > 0 {
var err error var err error
for retries := 0; retries < 5; retries++ { for retries := range 5 {
count, err = suite.UserRepo.Count() count, err = suite.UserRepo.Count()
if err == nil && count == 0 { if err == nil && count == 0 {
break break

View File

@@ -27,6 +27,7 @@ type RouterConfig struct {
Debug bool Debug bool
DisableCache bool DisableCache bool
DisableCompression bool DisableCompression bool
CacheablePaths []string
DBMonitor middleware.DBMonitor DBMonitor middleware.DBMonitor
RateLimitConfig config.RateLimitConfig RateLimitConfig config.RateLimitConfig
} }
@@ -49,6 +50,9 @@ func NewRouter(cfg RouterConfig) http.Handler {
if !cfg.DisableCache { if !cfg.DisableCache {
cache := middleware.NewInMemoryCache() cache := middleware.NewInMemoryCache()
cacheConfig := middleware.DefaultCacheConfig() cacheConfig := middleware.DefaultCacheConfig()
if len(cfg.CacheablePaths) > 0 {
cacheConfig.CacheablePaths = append([]string{}, cfg.CacheablePaths...)
}
router.Use(middleware.CacheMiddleware(cache, cacheConfig)) router.Use(middleware.CacheMiddleware(cache, cacheConfig))
router.Use(middleware.CacheInvalidationMiddleware(cache)) router.Use(middleware.CacheInvalidationMiddleware(cache))
} }

View File

@@ -77,7 +77,8 @@ func setupTestHandlers() (*handlers.AuthHandler, *handlers.PostHandler, *handler
emailSender := &testutils.MockEmailSender{} emailSender := &testutils.MockEmailSender{}
voteService := services.NewVoteService(voteRepo, postRepo, nil) voteService := services.NewVoteService(voteRepo, postRepo, nil)
metadataService := services.NewURLMetadataService() titleFetcher := &testutils.MockTitleFetcher{}
titleFetcher.SetTitle("Example Domain")
mockRefreshRepo := &mockRefreshTokenRepository{} mockRefreshRepo := &mockRefreshTokenRepository{}
mockDeletionRepo := &mockAccountDeletionRepository{} mockDeletionRepo := &mockAccountDeletionRepository{}
@@ -94,7 +95,7 @@ func setupTestHandlers() (*handlers.AuthHandler, *handlers.PostHandler, *handler
} }
authHandler := handlers.NewAuthHandler(authFacade, userRepo) authHandler := handlers.NewAuthHandler(authFacade, userRepo)
postHandler := handlers.NewPostHandler(postRepo, metadataService, voteService) postHandler := handlers.NewPostHandler(postRepo, titleFetcher, voteService)
voteHandler := handlers.NewVoteHandler(voteService) voteHandler := handlers.NewVoteHandler(voteService)
userHandler := handlers.NewUserHandler(userRepo, authFacade) userHandler := handlers.NewUserHandler(userRepo, authFacade)
apiHandler := handlers.NewAPIHandler(testutils.AppTestConfig, postRepo, userRepo, voteService) apiHandler := handlers.NewAPIHandler(testutils.AppTestConfig, postRepo, userRepo, voteService)

View File

@@ -90,7 +90,7 @@ func TestEmailService_Performance(t *testing.T) {
start := time.Now() start := time.Now()
iterations := 1000 iterations := 1000
for i := 0; i < iterations; i++ { for range iterations {
service.GenerateVerificationEmailBody(user.Username, "https://example.com/confirm?token=test") service.GenerateVerificationEmailBody(user.Username, "https://example.com/confirm?token=test")
} }
duration := time.Since(start) duration := time.Since(start)

View File

@@ -8,6 +8,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"slices"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -339,10 +340,8 @@ func (s *URLMetadataService) validateURLForSSRF(u *url.URL) error {
if err != nil { if err != nil {
return ErrSSRFBlocked return ErrSSRFBlocked
} }
for _, ip := range ips { if slices.ContainsFunc(ips, isPrivateOrReservedIP) {
if isPrivateOrReservedIP(ip) { return ErrSSRFBlocked
return ErrSSRFBlocked
}
} }
return nil return nil
} }
@@ -359,13 +358,7 @@ func isLocalhost(hostname string) bool {
"0:0:0:0:0:0:0:0", "0:0:0:0:0:0:0:0",
} }
for _, name := range localhostNames { return slices.Contains(localhostNames, hostname)
if hostname == name {
return true
}
}
return false
} }
func isPrivateOrReservedIP(ip net.IP) bool { func isPrivateOrReservedIP(ip net.IP) bool {

View File

@@ -285,7 +285,7 @@ func (b *VoteRequestBuilder) Build() VoteRequest {
func (f *TestDataFactory) CreateTestUsers(count int) []*database.User { func (f *TestDataFactory) CreateTestUsers(count int) []*database.User {
users := make([]*database.User, count) users := make([]*database.User, count)
for i := 0; i < count; i++ { for i := range count {
users[i] = f.NewUserBuilder(). users[i] = f.NewUserBuilder().
WithID(uint(i + 1)). WithID(uint(i + 1)).
WithUsername(fmt.Sprintf("user%d", i+1)). WithUsername(fmt.Sprintf("user%d", i+1)).
@@ -297,7 +297,7 @@ func (f *TestDataFactory) CreateTestUsers(count int) []*database.User {
func (f *TestDataFactory) CreateTestPosts(count int) []*database.Post { func (f *TestDataFactory) CreateTestPosts(count int) []*database.Post {
posts := make([]*database.Post, count) posts := make([]*database.Post, count)
for i := 0; i < count; i++ { for i := range count {
posts[i] = f.NewPostBuilder(). posts[i] = f.NewPostBuilder().
WithID(uint(i+1)). WithID(uint(i+1)).
WithTitle(fmt.Sprintf("Post %d", i+1)). WithTitle(fmt.Sprintf("Post %d", i+1)).
@@ -332,7 +332,7 @@ func (f *TestDataFactory) CreateTestVotes(count int) []*database.Vote {
func (f *TestDataFactory) CreateTestAuthResults(count int) []*AuthResult { func (f *TestDataFactory) CreateTestAuthResults(count int) []*AuthResult {
results := make([]*AuthResult, count) results := make([]*AuthResult, count)
for i := 0; i < count; i++ { for i := range count {
results[i] = f.NewAuthResultBuilder(). results[i] = f.NewAuthResultBuilder().
WithUser(f.NewUserBuilder(). WithUser(f.NewUserBuilder().
WithID(uint(i + 1)). WithID(uint(i + 1)).
@@ -346,7 +346,7 @@ func (f *TestDataFactory) CreateTestAuthResults(count int) []*AuthResult {
func (f *TestDataFactory) CreateTestVoteRequests(count int) []VoteRequest { func (f *TestDataFactory) CreateTestVoteRequests(count int) []VoteRequest {
requests := make([]VoteRequest, count) requests := make([]VoteRequest, count)
for i := 0; i < count; i++ { for i := range count {
voteType := database.VoteUp voteType := database.VoteUp
if i%3 == 0 { if i%3 == 0 {
voteType = database.VoteDown voteType = database.VoteDown
@@ -365,8 +365,9 @@ func (f *TestDataFactory) CreateTestVoteRequests(count int) []VoteRequest {
return requests return requests
} }
//go:fix inline
func uintPtr(u uint) *uint { func uintPtr(u uint) *uint {
return &u return new(u)
} }
type E2ETestDataFactory struct { type E2ETestDataFactory struct {
@@ -450,7 +451,7 @@ func (f *E2ETestDataFactory) CreateMultipleUsers(t *testing.T, count int, userna
var users []*TestUser var users []*TestUser
timestamp := time.Now().UnixNano() timestamp := time.Now().UnixNano()
for i := 0; i < count; i++ { for i := range count {
uniqueID := timestamp + int64(i) uniqueID := timestamp + int64(i)
username := fmt.Sprintf("%s%d", usernamePrefix, uniqueID) username := fmt.Sprintf("%s%d", usernamePrefix, uniqueID)
email := fmt.Sprintf("%s%d@example.com", emailPrefix, uniqueID) email := fmt.Sprintf("%s%d@example.com", emailPrefix, uniqueID)

View File

@@ -123,12 +123,12 @@ func defaultIfEmpty(value, fallback string) string {
} }
func extractTokenFromBody(body string) string { func extractTokenFromBody(body string) string {
index := strings.Index(body, "token=") _, after, ok := strings.Cut(body, "token=")
if index == -1 { if !ok {
return "" return ""
} }
tokenPart := body[index+len("token="):] tokenPart := after
if delimIdx := strings.IndexAny(tokenPart, "&\"'\\\r\n <>"); delimIdx != -1 { if delimIdx := strings.IndexAny(tokenPart, "&\"'\\\r\n <>"); delimIdx != -1 {
tokenPart = tokenPart[:delimIdx] tokenPart = tokenPart[:delimIdx]

View File

@@ -164,7 +164,7 @@ func getErrorField(resp *APIResponse) (string, bool) {
if resp == nil { if resp == nil {
return "", false return "", false
} }
if dataMap, ok := resp.Data.(map[string]interface{}); ok { if dataMap, ok := resp.Data.(map[string]any); ok {
if errorVal, ok := dataMap["error"].(string); ok { if errorVal, ok := dataMap["error"].(string); ok {
return errorVal, true return errorVal, true
} }

View File

@@ -202,7 +202,7 @@ func camelCaseToWords(s string) string {
return result.String() return result.String()
} }
func ValidateStruct(s interface{}) error { func ValidateStruct(s any) error {
if s == nil { if s == nil {
return nil return nil
} }
@@ -210,7 +210,7 @@ func ValidateStruct(s interface{}) error {
val := reflect.ValueOf(s) val := reflect.ValueOf(s)
typ := reflect.TypeOf(s) typ := reflect.TypeOf(s)
if val.Kind() == reflect.Ptr { if val.Kind() == reflect.Pointer {
if val.IsNil() { if val.IsNil() {
return nil return nil
} }
@@ -252,9 +252,9 @@ func ValidateStruct(s interface{}) error {
} }
var tagName, param string var tagName, param string
if idx := strings.Index(tag, "="); idx != -1 { if before, after, ok := strings.Cut(tag, "="); ok {
tagName = tag[:idx] tagName = before
param = tag[idx+1:] param = after
} else { } else {
tagName = tag tagName = tag
} }

View File

@@ -1,3 +1,5 @@
# screenshot
In this folder, you will find screenshots of the app. In this folder, you will find screenshots of the app.
Two kinds of screenshot here: Two kinds of screenshot here:

View File

@@ -6,13 +6,13 @@ if [ "$EUID" -ne 0 ]; then
exit 1 exit 1
fi fi
read -s "Do you want to install PostgreSQL 18? [y/N] " INSTALL_PG read -rp "Do you want to install PostgreSQL 18? [y/N] " INSTALL_PG
if [ "$INSTALL_PG" != "y" ]; then if [ "$INSTALL_PG" != "y" ]; then
echo "PostgreSQL 18 will not be installed" echo "PostgreSQL 18 will not be installed"
exit 0 exit 0
fi fi
read -s -p "Enter password for PostgreSQL user 'goyco': " GOYCO_PWD read -rsp "Enter password for PostgreSQL user 'goyco': " GOYCO_PWD
echo echo
apt-get update apt-get update