Compare commits
29 Commits
034bd8669e
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| e58ba1b8d1 | |||
| 4ffc601723 | |||
| d6321e775a | |||
| de9b544afb | |||
| 19291b7f61 | |||
| c31eb2f3df | |||
| de08878de7 | |||
| f0e8da51d0 | |||
| 85882bae14 | |||
| 9185ffa6b5 | |||
| 986b4e9388 | |||
| ac6e1ba80b | |||
| 14da02bc3f | |||
| 31ef30c941 | |||
| d4a89325e0 | |||
| 4eb0a6360f | |||
| 040b9148de | |||
| 6e0dfabcff | |||
| 9e81ddfdfa | |||
| b3b7c1d527 | |||
| 4c1caa44dd | |||
| 52c964abd2 | |||
| a854138eac | |||
| 70bfb54acf | |||
| a3ed6685de | |||
| 8f30fe7412 | |||
| 1a051b594c | |||
| 9718bcc79b | |||
| b1146b241c |
@@ -1,4 +1,4 @@
|
|||||||
ARG GO_VERSION=1.25.4
|
ARG GO_VERSION=1.26.0
|
||||||
|
|
||||||
# Building the binary using a golang alpine image
|
# Building the binary using a golang alpine image
|
||||||
FROM golang:${GO_VERSION}-alpine AS go-builder
|
FROM golang:${GO_VERSION}-alpine AS go-builder
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ It's basically pure Go (using Chi router), raw CSS and PostgreSQL 18.
|
|||||||
|
|
||||||
### Prerequisites
|
### Prerequisites
|
||||||
|
|
||||||
- Go 1.25.0 or later
|
- Go 1.26.0 or later
|
||||||
- PostgreSQL 18 or later
|
- PostgreSQL 18 or later
|
||||||
- SMTP server for email functionality
|
- SMTP server for email functionality
|
||||||
|
|
||||||
@@ -252,10 +252,12 @@ Goyco includes a comprehensive CLI for administration:
|
|||||||
./bin/goyco start # Start server as daemon
|
./bin/goyco start # Start server as daemon
|
||||||
./bin/goyco stop # Stop daemon
|
./bin/goyco stop # Stop daemon
|
||||||
./bin/goyco status # Check server status
|
./bin/goyco status # Check server status
|
||||||
|
./bin/goyco health # Check application and dependencies health
|
||||||
|
|
||||||
# Database management
|
# Database management
|
||||||
./bin/goyco migrate # Run database migrations
|
./bin/goyco migrate # Run database migrations
|
||||||
./bin/goyco seed database # Seed database with sample data
|
./bin/goyco seed database # Seed database with sample data
|
||||||
|
./bin/goyco seed database --posts 100 --users 10 --upvote-ratio 0.5 # Customize seeding
|
||||||
|
|
||||||
# User management
|
# User management
|
||||||
./bin/goyco user create # Create new user
|
./bin/goyco user create # Create new user
|
||||||
@@ -285,6 +287,7 @@ All CLI commands support JSON output for easier parsing and integration with scr
|
|||||||
./bin/goyco --json user list
|
./bin/goyco --json user list
|
||||||
./bin/goyco --json post list
|
./bin/goyco --json post list
|
||||||
./bin/goyco --json status
|
./bin/goyco --json status
|
||||||
|
./bin/goyco --json health # Check health with JSON output
|
||||||
|
|
||||||
# Example: Parse JSON output with jq
|
# Example: Parse JSON output with jq
|
||||||
./bin/goyco --json user list | jq '.users[0].username'
|
./bin/goyco --json user list | jq '.users[0].username'
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
helpPrinterOnce sync.Once
|
helpPrinterOnce sync.Once
|
||||||
defaultHelpPrinter func(io.Writer, string, interface{})
|
defaultHelpPrinter func(io.Writer, string, any)
|
||||||
)
|
)
|
||||||
|
|
||||||
func loadDotEnv() {
|
func loadDotEnv() {
|
||||||
@@ -59,6 +59,7 @@ func printRootUsage() {
|
|||||||
fmt.Fprintln(os.Stderr, " post manage posts (delete, list, search)")
|
fmt.Fprintln(os.Stderr, " post manage posts (delete, list, search)")
|
||||||
fmt.Fprintln(os.Stderr, " prune hard delete users and posts (posts, all)")
|
fmt.Fprintln(os.Stderr, " prune hard delete users and posts (posts, all)")
|
||||||
fmt.Fprintln(os.Stderr, " seed seed database with random data")
|
fmt.Fprintln(os.Stderr, " seed seed database with random data")
|
||||||
|
fmt.Fprintln(os.Stderr, " health check the health of the application and its dependencies")
|
||||||
}
|
}
|
||||||
|
|
||||||
func printRunUsage() {
|
func printRunUsage() {
|
||||||
@@ -70,7 +71,7 @@ func buildRootCommand(cfg *config.Config) *cli.Command {
|
|||||||
helpPrinterOnce.Do(func() {
|
helpPrinterOnce.Do(func() {
|
||||||
defaultHelpPrinter = cli.HelpPrinter
|
defaultHelpPrinter = cli.HelpPrinter
|
||||||
})
|
})
|
||||||
cli.HelpPrinter = func(w io.Writer, templ string, data interface{}) {
|
cli.HelpPrinter = func(w io.Writer, templ string, data any) {
|
||||||
if cmd, ok := data.(*cli.Command); ok && cmd.Root() == cmd {
|
if cmd, ok := data.(*cli.Command); ok && cmd.Root() == cmd {
|
||||||
printRootUsage()
|
printRootUsage()
|
||||||
return
|
return
|
||||||
@@ -180,6 +181,14 @@ func buildRootCommand(cfg *config.Config) *cli.Command {
|
|||||||
return commands.HandleSeedCommand(cfg, cmd.Name, cmd.Args().Slice())
|
return commands.HandleSeedCommand(cfg, cmd.Name, cmd.Args().Slice())
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "health",
|
||||||
|
Usage: "check the health of the application and its dependencies",
|
||||||
|
SkipFlagParsing: true,
|
||||||
|
Action: func(_ context.Context, cmd *cli.Command) error {
|
||||||
|
return commands.HandleHealthCommand(cfg, cmd.Name, cmd.Args().Slice())
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Writer: os.Stdout,
|
Writer: os.Stdout,
|
||||||
ErrWriter: os.Stderr,
|
ErrWriter: os.Stderr,
|
||||||
|
|||||||
@@ -113,15 +113,15 @@ func truncate(in string, max int) string {
|
|||||||
return in[:max-3] + "..."
|
return in[:max-3] + "..."
|
||||||
}
|
}
|
||||||
|
|
||||||
func outputJSON(v interface{}) error {
|
func outputJSON(v any) error {
|
||||||
encoder := json.NewEncoder(os.Stdout)
|
encoder := json.NewEncoder(os.Stdout)
|
||||||
encoder.SetIndent("", " ")
|
encoder.SetIndent("", " ")
|
||||||
return encoder.Encode(v)
|
return encoder.Encode(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func outputWarning(message string, args ...interface{}) {
|
func outputWarning(message string, args ...any) {
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"warning": fmt.Sprintf(message, args...),
|
"warning": fmt.Sprintf(message, args...),
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -236,7 +236,7 @@ func TestSetJSONOutput(t *testing.T) {
|
|||||||
done := make(chan bool)
|
done := make(chan bool)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for i := 0; i < 100; i++ {
|
for range 100 {
|
||||||
SetJSONOutput(true)
|
SetJSONOutput(true)
|
||||||
_ = IsJSONOutput()
|
_ = IsJSONOutput()
|
||||||
SetJSONOutput(false)
|
SetJSONOutput(false)
|
||||||
@@ -245,7 +245,7 @@ func TestSetJSONOutput(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for i := 0; i < 100; i++ {
|
for range 100 {
|
||||||
_ = IsJSONOutput()
|
_ = IsJSONOutput()
|
||||||
}
|
}
|
||||||
done <- true
|
done <- true
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ func runStatusCommand(cfg *config.Config) error {
|
|||||||
|
|
||||||
if !isDaemonRunning(pidFile) {
|
if !isDaemonRunning(pidFile) {
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"status": "not_running",
|
"status": "not_running",
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
@@ -99,7 +99,7 @@ func runStatusCommand(cfg *config.Config) error {
|
|||||||
data, err := os.ReadFile(pidFile)
|
data, err := os.ReadFile(pidFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"status": "running",
|
"status": "running",
|
||||||
"error": fmt.Sprintf("PID file exists but cannot be read: %v", err),
|
"error": fmt.Sprintf("PID file exists but cannot be read: %v", err),
|
||||||
})
|
})
|
||||||
@@ -112,7 +112,7 @@ func runStatusCommand(cfg *config.Config) error {
|
|||||||
pid, err := strconv.Atoi(string(data))
|
pid, err := strconv.Atoi(string(data))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"status": "running",
|
"status": "running",
|
||||||
"error": fmt.Sprintf("PID file exists but contains invalid PID: %v", err),
|
"error": fmt.Sprintf("PID file exists but contains invalid PID: %v", err),
|
||||||
})
|
})
|
||||||
@@ -123,7 +123,7 @@ func runStatusCommand(cfg *config.Config) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"status": "running",
|
"status": "running",
|
||||||
"pid": pid,
|
"pid": pid,
|
||||||
})
|
})
|
||||||
@@ -171,7 +171,7 @@ func stopDaemon(cfg *config.Config) error {
|
|||||||
_ = os.Remove(pidFile)
|
_ = os.Remove(pidFile)
|
||||||
|
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "stopped",
|
"action": "stopped",
|
||||||
"pid": pid,
|
"pid": pid,
|
||||||
})
|
})
|
||||||
@@ -219,7 +219,7 @@ func runDaemon(cfg *config.Config) error {
|
|||||||
return fmt.Errorf("cannot write PID file: %w", err)
|
return fmt.Errorf("cannot write PID file: %w", err)
|
||||||
}
|
}
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "started",
|
"action": "started",
|
||||||
"pid": pid,
|
"pid": pid,
|
||||||
"pid_file": pidFile,
|
"pid_file": pidFile,
|
||||||
|
|||||||
94
cmd/goyco/commands/health.go
Normal file
94
cmd/goyco/commands/health.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package commands
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"goyco/internal/config"
|
||||||
|
"goyco/internal/database"
|
||||||
|
"goyco/internal/health"
|
||||||
|
"goyco/internal/middleware"
|
||||||
|
"goyco/internal/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
func HandleHealthCommand(cfg *config.Config, name string, args []string) error {
|
||||||
|
fs := newFlagSet(name, printHealthUsage)
|
||||||
|
if err := parseCommand(fs, args, name); err != nil {
|
||||||
|
if errors.Is(err, ErrHelpRequested) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if fs.NArg() > 0 {
|
||||||
|
printHealthUsage()
|
||||||
|
return errors.New("unexpected arguments for health command")
|
||||||
|
}
|
||||||
|
|
||||||
|
return runHealthCheck(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func printHealthUsage() {
|
||||||
|
fmt.Fprintln(os.Stderr, "Usage: goyco health")
|
||||||
|
fmt.Fprintln(os.Stderr, "\nCheck the health status of the application and its dependencies.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func runHealthCheck(cfg *config.Config) error {
|
||||||
|
compositeChecker := health.NewCompositeChecker()
|
||||||
|
|
||||||
|
dbChecker, err := createDatabaseChecker(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create database checker: %w", err)
|
||||||
|
}
|
||||||
|
if dbChecker != nil {
|
||||||
|
compositeChecker.AddChecker(dbChecker)
|
||||||
|
}
|
||||||
|
|
||||||
|
smtpChecker := createSMTPChecker(cfg)
|
||||||
|
if smtpChecker != nil {
|
||||||
|
compositeChecker.AddChecker(smtpChecker)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
result := compositeChecker.CheckWithVersion(ctx, version.GetVersion())
|
||||||
|
|
||||||
|
return outputJSON(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createDatabaseChecker(cfg *config.Config) (health.Checker, error) {
|
||||||
|
db, err := database.Connect(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlDB, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get sql.DB: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlDB.SetConnMaxLifetime(5 * time.Second)
|
||||||
|
sqlDB.SetMaxOpenConns(1)
|
||||||
|
sqlDB.SetMaxIdleConns(1)
|
||||||
|
|
||||||
|
monitor := middleware.NewInMemoryDBMonitor()
|
||||||
|
return health.NewDatabaseChecker(sqlDB, monitor), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createSMTPChecker(cfg *config.Config) health.Checker {
|
||||||
|
if cfg.SMTP.Host == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
smtpConfig := health.SMTPConfig{
|
||||||
|
Host: cfg.SMTP.Host,
|
||||||
|
Port: cfg.SMTP.Port,
|
||||||
|
Username: cfg.SMTP.Username,
|
||||||
|
Password: cfg.SMTP.Password,
|
||||||
|
From: cfg.SMTP.From,
|
||||||
|
}
|
||||||
|
|
||||||
|
return health.NewSMTPChecker(smtpConfig)
|
||||||
|
}
|
||||||
89
cmd/goyco/commands/health_test.go
Normal file
89
cmd/goyco/commands/health_test.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
package commands
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"goyco/internal/config"
|
||||||
|
"goyco/internal/testutils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHandleHealthCommand(t *testing.T) {
|
||||||
|
cfg := testutils.NewTestConfig()
|
||||||
|
|
||||||
|
t.Run("help requested", func(t *testing.T) {
|
||||||
|
err := HandleHealthCommand(cfg, "health", []string{"--help"})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error for help: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unexpected arguments", func(t *testing.T) {
|
||||||
|
err := HandleHealthCommand(cfg, "health", []string{"extra", "args"})
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for unexpected arguments")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "unexpected arguments") {
|
||||||
|
t.Errorf("expected error containing 'unexpected arguments', got %q", err.Error())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrintHealthUsage(t *testing.T) {
|
||||||
|
oldStderr := os.Stderr
|
||||||
|
r, w, _ := os.Pipe()
|
||||||
|
os.Stderr = w
|
||||||
|
|
||||||
|
printHealthUsage()
|
||||||
|
|
||||||
|
w.Close()
|
||||||
|
os.Stderr = oldStderr
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, _ := r.Read(buf)
|
||||||
|
output := string(buf[:n])
|
||||||
|
|
||||||
|
if !strings.Contains(output, "Usage: goyco health") {
|
||||||
|
t.Errorf("expected usage to contain 'Usage: goyco health', got %q", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateSMTPChecker(t *testing.T) {
|
||||||
|
t.Run("with valid smtp config", func(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
SMTP: config.SMTPConfig{
|
||||||
|
Host: "smtp.example.com",
|
||||||
|
Port: 587,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
checker := createSMTPChecker(cfg)
|
||||||
|
|
||||||
|
if checker == nil {
|
||||||
|
t.Error("expected checker to be created")
|
||||||
|
}
|
||||||
|
|
||||||
|
if checker.Name() != "smtp" {
|
||||||
|
t.Errorf("expected name 'smtp', got %s", checker.Name())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with empty host", func(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
SMTP: config.SMTPConfig{
|
||||||
|
Host: "",
|
||||||
|
Port: 587,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
checker := createSMTPChecker(cfg)
|
||||||
|
|
||||||
|
if checker != nil {
|
||||||
|
t.Error("expected checker to be nil when host is empty")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -91,7 +91,7 @@ func postDelete(repo repositories.PostRepository, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "post_deleted",
|
"action": "post_deleted",
|
||||||
"id": id,
|
"id": id,
|
||||||
})
|
})
|
||||||
@@ -158,7 +158,7 @@ func postList(postQueries *services.PostQueries, args []string) error {
|
|||||||
CreatedAt: p.CreatedAt.Format("2006-01-02 15:04:05"),
|
CreatedAt: p.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"posts": postsJSON,
|
"posts": postsJSON,
|
||||||
"count": len(postsJSON),
|
"count": len(postsJSON),
|
||||||
})
|
})
|
||||||
@@ -298,7 +298,7 @@ func postSearch(postQueries *services.PostQueries, args []string) error {
|
|||||||
CreatedAt: p.CreatedAt.Format("2006-01-02 15:04:05"),
|
CreatedAt: p.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"search_term": sanitizedTerm,
|
"search_term": sanitizedTerm,
|
||||||
"posts": postsJSON,
|
"posts": postsJSON,
|
||||||
"count": len(postsJSON),
|
"count": len(postsJSON),
|
||||||
|
|||||||
@@ -508,9 +508,9 @@ func TestProgressIndicator_Concurrency(t *testing.T) {
|
|||||||
pi := NewProgressIndicator(100, "Concurrent test")
|
pi := NewProgressIndicator(100, "Concurrent test")
|
||||||
done := make(chan bool)
|
done := make(chan bool)
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
for range 10 {
|
||||||
go func() {
|
go func() {
|
||||||
for j := 0; j < 10; j++ {
|
for range 10 {
|
||||||
pi.Increment()
|
pi.Increment()
|
||||||
time.Sleep(1 * time.Millisecond)
|
time.Sleep(1 * time.Millisecond)
|
||||||
}
|
}
|
||||||
@@ -518,7 +518,7 @@ func TestProgressIndicator_Concurrency(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
for range 10 {
|
||||||
<-done
|
<-done
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ func prunePosts(postRepo repositories.PostRepository, args []string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if *dryRun {
|
if *dryRun {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "prune_posts",
|
"action": "prune_posts",
|
||||||
"dry_run": true,
|
"dry_run": true,
|
||||||
"posts": postsJSON,
|
"posts": postsJSON,
|
||||||
@@ -110,7 +110,7 @@ func prunePosts(postRepo repositories.PostRepository, args []string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("hard delete posts: %w", err)
|
return fmt.Errorf("hard delete posts: %w", err)
|
||||||
}
|
}
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "prune_posts",
|
"action": "prune_posts",
|
||||||
"deleted_count": deletedCount,
|
"deleted_count": deletedCount,
|
||||||
})
|
})
|
||||||
@@ -178,7 +178,7 @@ func pruneUsers(userRepo repositories.UserRepository, postRepo repositories.Post
|
|||||||
userCount := len(users)
|
userCount := len(users)
|
||||||
if userCount == 0 {
|
if userCount == 0 {
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "prune_users",
|
"action": "prune_users",
|
||||||
"count": 0,
|
"count": 0,
|
||||||
})
|
})
|
||||||
@@ -211,7 +211,7 @@ func pruneUsers(userRepo repositories.UserRepository, postRepo repositories.Post
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if *dryRun {
|
if *dryRun {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "prune_users",
|
"action": "prune_users",
|
||||||
"dry_run": true,
|
"dry_run": true,
|
||||||
"users": usersJSON,
|
"users": usersJSON,
|
||||||
@@ -229,7 +229,7 @@ func pruneUsers(userRepo repositories.UserRepository, postRepo repositories.Post
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("hard delete all users and posts: %w", err)
|
return fmt.Errorf("hard delete all users and posts: %w", err)
|
||||||
}
|
}
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "prune_users",
|
"action": "prune_users",
|
||||||
"deleted_count": totalDeleted,
|
"deleted_count": totalDeleted,
|
||||||
"with_posts": true,
|
"with_posts": true,
|
||||||
@@ -242,7 +242,7 @@ func pruneUsers(userRepo repositories.UserRepository, postRepo repositories.Post
|
|||||||
}
|
}
|
||||||
deletedCount++
|
deletedCount++
|
||||||
}
|
}
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "prune_users",
|
"action": "prune_users",
|
||||||
"deleted_count": deletedCount,
|
"deleted_count": deletedCount,
|
||||||
"with_posts": false,
|
"with_posts": false,
|
||||||
@@ -328,7 +328,7 @@ func pruneAll(userRepo repositories.UserRepository, postRepo repositories.PostRe
|
|||||||
|
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
if *dryRun {
|
if *dryRun {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "prune_all",
|
"action": "prune_all",
|
||||||
"dry_run": true,
|
"dry_run": true,
|
||||||
"user_count": len(userCount),
|
"user_count": len(userCount),
|
||||||
@@ -343,7 +343,7 @@ func pruneAll(userRepo repositories.UserRepository, postRepo repositories.PostRe
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("hard delete all: %w", err)
|
return fmt.Errorf("hard delete all: %w", err)
|
||||||
}
|
}
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "prune_all",
|
"action": "prune_all",
|
||||||
"deleted_count": totalDeleted,
|
"deleted_count": totalDeleted,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -56,10 +56,11 @@ func runSeedCommand(userRepo repositories.UserRepository, postRepo repositories.
|
|||||||
|
|
||||||
func printSeedUsage() {
|
func printSeedUsage() {
|
||||||
fmt.Fprintln(os.Stderr, "Seed subcommands:")
|
fmt.Fprintln(os.Stderr, "Seed subcommands:")
|
||||||
fmt.Fprintln(os.Stderr, " database [--posts <n>] [--users <n>] [--votes-per-post <n>]")
|
fmt.Fprintln(os.Stderr, " database [--posts <n>] [--users <n>] [--votes-per-post <n>] [--upvote-ratio <r>]")
|
||||||
fmt.Fprintln(os.Stderr, " --posts: number of posts to create (default: 40)")
|
fmt.Fprintln(os.Stderr, " --posts: number of posts to create (default: 40)")
|
||||||
fmt.Fprintln(os.Stderr, " --users: number of additional users to create (default: 5)")
|
fmt.Fprintln(os.Stderr, " --users: number of additional users to create (default: 5)")
|
||||||
fmt.Fprintln(os.Stderr, " --votes-per-post: average votes per post (default: 15)")
|
fmt.Fprintln(os.Stderr, " --votes-per-post: average votes per post (default: 15)")
|
||||||
|
fmt.Fprintln(os.Stderr, " --upvote-ratio: percentage of upvotes vs downvotes, 0.0-1.0 (default: 0.7)")
|
||||||
}
|
}
|
||||||
|
|
||||||
func clampFlagValue(value *int, min int, name string) {
|
func clampFlagValue(value *int, min int, name string) {
|
||||||
@@ -76,9 +77,10 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
|
|||||||
numPosts := fs.Int("posts", 40, "number of posts to create")
|
numPosts := fs.Int("posts", 40, "number of posts to create")
|
||||||
numUsers := fs.Int("users", 5, "number of additional users to create")
|
numUsers := fs.Int("users", 5, "number of additional users to create")
|
||||||
votesPerPost := fs.Int("votes-per-post", 15, "average votes per post")
|
votesPerPost := fs.Int("votes-per-post", 15, "average votes per post")
|
||||||
|
upvoteRatio := fs.Float64("upvote-ratio", 0.7, "percentage of upvotes vs downvotes, 0.0-1.0")
|
||||||
fs.SetOutput(os.Stderr)
|
fs.SetOutput(os.Stderr)
|
||||||
fs.Usage = func() {
|
fs.Usage = func() {
|
||||||
fmt.Fprintln(os.Stderr, "Usage: goyco seed database [--posts <n>] [--users <n>] [--votes-per-post <n>]")
|
fmt.Fprintln(os.Stderr, "Usage: goyco seed database [--posts <n>] [--users <n>] [--votes-per-post <n>] [--upvote-ratio <r>]")
|
||||||
fmt.Fprintln(os.Stderr, "\nOptions:")
|
fmt.Fprintln(os.Stderr, "\nOptions:")
|
||||||
fs.PrintDefaults()
|
fs.PrintDefaults()
|
||||||
}
|
}
|
||||||
@@ -93,6 +95,11 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
|
|||||||
clampFlagValue(numUsers, 0, "users")
|
clampFlagValue(numUsers, 0, "users")
|
||||||
clampFlagValue(numPosts, 1, "posts")
|
clampFlagValue(numPosts, 1, "posts")
|
||||||
clampFlagValue(votesPerPost, 0, "votes-per-post")
|
clampFlagValue(votesPerPost, 0, "votes-per-post")
|
||||||
|
if *upvoteRatio < 0 {
|
||||||
|
*upvoteRatio = 0
|
||||||
|
} else if *upvoteRatio > 1 {
|
||||||
|
*upvoteRatio = 1
|
||||||
|
}
|
||||||
|
|
||||||
if !IsJSONOutput() {
|
if !IsJSONOutput() {
|
||||||
fmt.Println("Starting database seeding...")
|
fmt.Println("Starting database seeding...")
|
||||||
@@ -119,7 +126,7 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
|
|||||||
fmt.Printf("Seed user ready: ID=%d Username=%s\n", seedUser.ID, seedUser.Username)
|
fmt.Printf("Seed user ready: ID=%d Username=%s\n", seedUser.ID, seedUser.Username)
|
||||||
}
|
}
|
||||||
|
|
||||||
generator := newSeedGenerator(string(userPasswordHash))
|
generator := newSeedGenerator(string(userPasswordHash), *upvoteRatio)
|
||||||
allUsers := []database.User{*seedUser}
|
allUsers := []database.User{*seedUser}
|
||||||
|
|
||||||
users, err := createUsers(generator, userRepo, *numUsers, "Creating users")
|
users, err := createUsers(generator, userRepo, *numUsers, "Creating users")
|
||||||
@@ -258,13 +265,15 @@ func validateSeedConsistency(voteRepo repositories.VoteRepository, users []datab
|
|||||||
type seedGenerator struct {
|
type seedGenerator struct {
|
||||||
passwordHash string
|
passwordHash string
|
||||||
randSource *rand.Rand
|
randSource *rand.Rand
|
||||||
|
upvoteRatio float64
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSeedGenerator(passwordHash string) *seedGenerator {
|
func newSeedGenerator(passwordHash string, upvoteRatio float64) *seedGenerator {
|
||||||
seed := time.Now().UnixNano()
|
seed := time.Now().UnixNano()
|
||||||
return &seedGenerator{
|
return &seedGenerator{
|
||||||
passwordHash: passwordHash,
|
passwordHash: passwordHash,
|
||||||
randSource: rand.New(rand.NewSource(seed)),
|
randSource: rand.New(rand.NewSource(seed)),
|
||||||
|
upvoteRatio: upvoteRatio,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -312,7 +321,7 @@ func createUsers(g *seedGenerator, userRepo repositories.UserRepository, count i
|
|||||||
}
|
}
|
||||||
progress := maybeProgress(count, desc)
|
progress := maybeProgress(count, desc)
|
||||||
users := make([]database.User, 0, count)
|
users := make([]database.User, 0, count)
|
||||||
for i := 0; i < count; i++ {
|
for i := range count {
|
||||||
user, err := g.createSingleUser(userRepo, i+1)
|
user, err := g.createSingleUser(userRepo, i+1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create random user: %w", err)
|
return nil, fmt.Errorf("create random user: %w", err)
|
||||||
@@ -334,7 +343,7 @@ func createPosts(g *seedGenerator, postRepo repositories.PostRepository, authorI
|
|||||||
}
|
}
|
||||||
progress := maybeProgress(count, desc)
|
progress := maybeProgress(count, desc)
|
||||||
posts := make([]database.Post, 0, count)
|
posts := make([]database.Post, 0, count)
|
||||||
for i := 0; i < count; i++ {
|
for i := range count {
|
||||||
post, err := g.createSinglePost(postRepo, authorID, i+1)
|
post, err := g.createSinglePost(postRepo, authorID, i+1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create random post: %w", err)
|
return nil, fmt.Errorf("create random post: %w", err)
|
||||||
@@ -490,9 +499,8 @@ func (g *seedGenerator) createVotesForPost(voteRepo repositories.VoteRepository,
|
|||||||
}
|
}
|
||||||
usedUsers[user.ID] = true
|
usedUsers[user.ID] = true
|
||||||
|
|
||||||
voteTypeInt := g.randSource.Intn(10)
|
|
||||||
var voteType database.VoteType
|
var voteType database.VoteType
|
||||||
if voteTypeInt < 7 {
|
if g.randSource.Float64() < g.upvoteRatio {
|
||||||
voteType = database.VoteUp
|
voteType = database.VoteUp
|
||||||
} else {
|
} else {
|
||||||
voteType = database.VoteDown
|
voteType = database.VoteDown
|
||||||
|
|||||||
@@ -169,7 +169,7 @@ func userCreate(cfg *config.Config, repo repositories.UserRepository, args []str
|
|||||||
}
|
}
|
||||||
|
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "user_created",
|
"action": "user_created",
|
||||||
"id": user.ID,
|
"id": user.ID,
|
||||||
"username": user.Username,
|
"username": user.Username,
|
||||||
@@ -296,7 +296,7 @@ func userUpdate(cfg *config.Config, repo repositories.UserRepository, refreshTok
|
|||||||
}
|
}
|
||||||
|
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "user_updated",
|
"action": "user_updated",
|
||||||
"id": user.ID,
|
"id": user.ID,
|
||||||
"username": user.Username,
|
"username": user.Username,
|
||||||
@@ -424,7 +424,7 @@ func userDelete(cfg *config.Config, repo repositories.UserRepository, args []str
|
|||||||
}
|
}
|
||||||
|
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "user_deleted",
|
"action": "user_deleted",
|
||||||
"id": id,
|
"id": id,
|
||||||
"username": user.Username,
|
"username": user.Username,
|
||||||
@@ -479,7 +479,7 @@ func userList(repo repositories.UserRepository, args []string) error {
|
|||||||
CreatedAt: u.CreatedAt.Format("2006-01-02 15:04:05"),
|
CreatedAt: u.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"users": usersJSON,
|
"users": usersJSON,
|
||||||
"count": len(usersJSON),
|
"count": len(usersJSON),
|
||||||
})
|
})
|
||||||
@@ -578,7 +578,7 @@ func userLock(cfg *config.Config, repo repositories.UserRepository, args []strin
|
|||||||
|
|
||||||
if user.Locked {
|
if user.Locked {
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "user_lock",
|
"action": "user_lock",
|
||||||
"id": id,
|
"id": id,
|
||||||
"username": user.Username,
|
"username": user.Username,
|
||||||
@@ -604,7 +604,7 @@ func userLock(cfg *config.Config, repo repositories.UserRepository, args []strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "user_locked",
|
"action": "user_locked",
|
||||||
"id": id,
|
"id": id,
|
||||||
"username": user.Username,
|
"username": user.Username,
|
||||||
@@ -653,7 +653,7 @@ func userUnlock(cfg *config.Config, repo repositories.UserRepository, args []str
|
|||||||
|
|
||||||
if !user.Locked {
|
if !user.Locked {
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "user_unlock",
|
"action": "user_unlock",
|
||||||
"id": id,
|
"id": id,
|
||||||
"username": user.Username,
|
"username": user.Username,
|
||||||
@@ -679,7 +679,7 @@ func userUnlock(cfg *config.Config, repo repositories.UserRepository, args []str
|
|||||||
}
|
}
|
||||||
|
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "user_unlocked",
|
"action": "user_unlocked",
|
||||||
"id": id,
|
"id": id,
|
||||||
"username": user.Username,
|
"username": user.Username,
|
||||||
@@ -732,7 +732,7 @@ func resetUserPassword(cfg *config.Config, repo repositories.UserRepository, ses
|
|||||||
}
|
}
|
||||||
|
|
||||||
if IsJSONOutput() {
|
if IsJSONOutput() {
|
||||||
outputJSON(map[string]interface{}{
|
outputJSON(map[string]any{
|
||||||
"action": "password_reset",
|
"action": "password_reset",
|
||||||
"id": userID,
|
"id": userID,
|
||||||
"username": user.Username,
|
"username": user.Username,
|
||||||
|
|||||||
@@ -1771,7 +1771,7 @@ const docTemplate = `{
|
|||||||
},
|
},
|
||||||
"/health": {
|
"/health": {
|
||||||
"get": {
|
"get": {
|
||||||
"description": "Check the API health status along with database connectivity details",
|
"description": "Check the API health status along with database connectivity and SMTP service details",
|
||||||
"consumes": [
|
"consumes": [
|
||||||
"application/json"
|
"application/json"
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -1768,7 +1768,7 @@
|
|||||||
},
|
},
|
||||||
"/health": {
|
"/health": {
|
||||||
"get": {
|
"get": {
|
||||||
"description": "Check the API health status along with database connectivity details",
|
"description": "Check the API health status along with database connectivity and SMTP service details",
|
||||||
"consumes": [
|
"consumes": [
|
||||||
"application/json"
|
"application/json"
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -1387,7 +1387,8 @@ paths:
|
|||||||
get:
|
get:
|
||||||
consumes:
|
consumes:
|
||||||
- application/json
|
- application/json
|
||||||
description: Check the API health status along with database connectivity details
|
description: Check the API health status along with database connectivity and
|
||||||
|
SMTP service details
|
||||||
produces:
|
produces:
|
||||||
- application/json
|
- application/json
|
||||||
responses:
|
responses:
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -1,6 +1,6 @@
|
|||||||
module goyco
|
module goyco
|
||||||
|
|
||||||
go 1.25.4
|
go 1.26
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/go-chi/chi/v5 v5.2.3
|
github.com/go-chi/chi/v5 v5.2.3
|
||||||
|
|||||||
@@ -289,6 +289,19 @@ func setupTestContext(t *testing.T) *testContext {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setupTestContextWithMiddleware(t *testing.T) *testContext {
|
||||||
|
t.Helper()
|
||||||
|
server := setupIntegrationTestServerWithMiddlewareEnabled(t)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
server.Cleanup()
|
||||||
|
})
|
||||||
|
return &testContext{
|
||||||
|
server: server,
|
||||||
|
client: server.NewHTTPClient(),
|
||||||
|
baseURL: server.BaseURL(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func setupTestContextWithAuthRateLimit(t *testing.T, authLimit int) *testContext {
|
func setupTestContextWithAuthRateLimit(t *testing.T, authLimit int) *testContext {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
server := setupIntegrationTestServerWithAuthRateLimit(t, authLimit)
|
server := setupIntegrationTestServerWithAuthRateLimit(t, authLimit)
|
||||||
@@ -603,14 +616,34 @@ func generateTokenWithExpiration(t *testing.T, user *database.User, cfg *config.
|
|||||||
|
|
||||||
type serverConfig struct {
|
type serverConfig struct {
|
||||||
authLimit int
|
authLimit int
|
||||||
|
disableCache bool
|
||||||
|
disableCompression bool
|
||||||
|
cacheablePaths []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupIntegrationTestServer(t *testing.T) *IntegrationTestServer {
|
func setupIntegrationTestServer(t *testing.T) *IntegrationTestServer {
|
||||||
return setupIntegrationTestServerWithConfig(t, serverConfig{authLimit: 50000})
|
return setupIntegrationTestServerWithConfig(t, serverConfig{
|
||||||
|
authLimit: 50000,
|
||||||
|
disableCache: true,
|
||||||
|
disableCompression: true,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupIntegrationTestServerWithAuthRateLimit(t *testing.T, authLimit int) *IntegrationTestServer {
|
func setupIntegrationTestServerWithAuthRateLimit(t *testing.T, authLimit int) *IntegrationTestServer {
|
||||||
return setupIntegrationTestServerWithConfig(t, serverConfig{authLimit: authLimit})
|
return setupIntegrationTestServerWithConfig(t, serverConfig{
|
||||||
|
authLimit: authLimit,
|
||||||
|
disableCache: true,
|
||||||
|
disableCompression: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupIntegrationTestServerWithMiddlewareEnabled(t *testing.T) *IntegrationTestServer {
|
||||||
|
return setupIntegrationTestServerWithConfig(t, serverConfig{
|
||||||
|
authLimit: 50000,
|
||||||
|
disableCache: false,
|
||||||
|
disableCompression: false,
|
||||||
|
cacheablePaths: []string{"/api/posts"},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupDatabase(t *testing.T) *gorm.DB {
|
func setupDatabase(t *testing.T) *gorm.DB {
|
||||||
@@ -678,7 +711,7 @@ func setupHandlers(authService handlers.AuthServiceInterface, userRepo repositor
|
|||||||
handlers.NewAPIHandler(cfg, postRepo, userRepo, voteService)
|
handlers.NewAPIHandler(cfg, postRepo, userRepo, voteService)
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupRouter(authHandler *handlers.AuthHandler, postHandler *handlers.PostHandler, voteHandler *handlers.VoteHandler, userHandler *handlers.UserHandler, apiHandler *handlers.APIHandler, authService handlers.AuthServiceInterface, cfg *config.Config) http.Handler {
|
func setupRouter(authHandler *handlers.AuthHandler, postHandler *handlers.PostHandler, voteHandler *handlers.VoteHandler, userHandler *handlers.UserHandler, apiHandler *handlers.APIHandler, authService handlers.AuthServiceInterface, cfg *config.Config, serverCfg serverConfig) http.Handler {
|
||||||
return server.NewRouter(server.RouterConfig{
|
return server.NewRouter(server.RouterConfig{
|
||||||
AuthHandler: authHandler,
|
AuthHandler: authHandler,
|
||||||
PostHandler: postHandler,
|
PostHandler: postHandler,
|
||||||
@@ -689,8 +722,9 @@ func setupRouter(authHandler *handlers.AuthHandler, postHandler *handlers.PostHa
|
|||||||
PageHandler: nil,
|
PageHandler: nil,
|
||||||
StaticDir: findWorkspaceRoot() + "/internal/static/",
|
StaticDir: findWorkspaceRoot() + "/internal/static/",
|
||||||
Debug: false,
|
Debug: false,
|
||||||
DisableCache: true,
|
DisableCache: serverCfg.disableCache,
|
||||||
DisableCompression: true,
|
DisableCompression: serverCfg.disableCompression,
|
||||||
|
CacheablePaths: serverCfg.cacheablePaths,
|
||||||
RateLimitConfig: cfg.RateLimit,
|
RateLimitConfig: cfg.RateLimit,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -735,7 +769,7 @@ func setupIntegrationTestServerWithConfig(t *testing.T, serverCfg serverConfig)
|
|||||||
}
|
}
|
||||||
|
|
||||||
authHandler, postHandler, voteHandler, userHandler, apiHandler := setupHandlers(authService, userRepo, postRepo, voteService, cfg)
|
authHandler, postHandler, voteHandler, userHandler, apiHandler := setupHandlers(authService, userRepo, postRepo, voteService, cfg)
|
||||||
router := setupRouter(authHandler, postHandler, voteHandler, userHandler, apiHandler, authService, cfg)
|
router := setupRouter(authHandler, postHandler, voteHandler, userHandler, apiHandler, authService, cfg, serverCfg)
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
|
||||||
|
|||||||
@@ -422,9 +422,7 @@ func TestE2E_ConcurrentPostCreation(t *testing.T) {
|
|||||||
|
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
u := user
|
u := user
|
||||||
wg.Add(1)
|
wg.Go(func() {
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
client, err := ctx.loginUserSafe(t, u.Username, u.Password)
|
client, err := ctx.loginUserSafe(t, u.Username, u.Password)
|
||||||
if err != nil || client == nil {
|
if err != nil || client == nil {
|
||||||
results <- nil
|
results <- nil
|
||||||
@@ -443,7 +441,7 @@ func TestE2E_ConcurrentPostCreation(t *testing.T) {
|
|||||||
|
|
||||||
post, err := client.CreatePostSafe("Concurrent Post", url, "Content")
|
post, err := client.CreatePostSafe("Concurrent Post", url, "Content")
|
||||||
results <- post
|
results <- post
|
||||||
}()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|||||||
@@ -56,10 +56,8 @@ func TestE2E_DatabaseFailureRecovery(t *testing.T) {
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
errors := make(chan error, 10)
|
errors := make(chan error, 10)
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
for range 5 {
|
||||||
wg.Add(1)
|
wg.Go(func() {
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
conn, err := sqlDB.Conn(context.Background())
|
conn, err := sqlDB.Conn(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors <- err
|
errors <- err
|
||||||
@@ -68,7 +66,7 @@ func TestE2E_DatabaseFailureRecovery(t *testing.T) {
|
|||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
}()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
@@ -297,7 +295,7 @@ func TestE2E_DatabaseConnectionPool(t *testing.T) {
|
|||||||
|
|
||||||
initialStats := sqlDB.Stats()
|
initialStats := sqlDB.Stats()
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
for range 5 {
|
||||||
req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
|
req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -7,14 +7,15 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"goyco/internal/testutils"
|
"goyco/internal/testutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestE2E_CompressionMiddleware(t *testing.T) {
|
func TestE2E_CompressionMiddleware(t *testing.T) {
|
||||||
ctx := setupTestContext(t)
|
ctx := setupTestContextWithMiddleware(t)
|
||||||
|
|
||||||
t.Run("compression_enabled_with_accept_encoding", func(t *testing.T) {
|
t.Run("compresses_response_when_accept_encoding_is_gzip", func(t *testing.T) {
|
||||||
request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
|
request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create request: %v", err)
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
@@ -27,15 +28,22 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
|
|||||||
t.Fatalf("Request failed: %v", err)
|
t.Fatalf("Request failed: %v", err)
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
failIfRateLimited(t, response.StatusCode, "compression enabled GET /api/posts")
|
||||||
|
|
||||||
contentEncoding := response.Header.Get("Content-Encoding")
|
contentEncoding := response.Header.Get("Content-Encoding")
|
||||||
if contentEncoding == "gzip" {
|
if contentEncoding != "gzip" {
|
||||||
|
t.Fatalf("Expected gzip compression, got Content-Encoding=%q", contentEncoding)
|
||||||
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(response.Body)
|
body, err := io.ReadAll(response.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to read response body: %v", err)
|
t.Fatalf("Failed to read response body: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if isGzipCompressed(body) {
|
if !isGzipCompressed(body) {
|
||||||
|
t.Fatalf("Expected gzip-compressed body bytes")
|
||||||
|
}
|
||||||
|
|
||||||
reader, err := gzip.NewReader(bytes.NewReader(body))
|
reader, err := gzip.NewReader(bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create gzip reader: %v", err)
|
t.Fatalf("Failed to create gzip reader: %v", err)
|
||||||
@@ -48,42 +56,44 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(decompressed) == 0 {
|
if len(decompressed) == 0 {
|
||||||
t.Error("Decompressed body is empty")
|
t.Fatal("Decompressed body is empty")
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
t.Logf("Compression not applied (Content-Encoding: %s)", contentEncoding)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("no_compression_without_accept_encoding", func(t *testing.T) {
|
t.Run("does_not_compress_without_accept_encoding", func(t *testing.T) {
|
||||||
request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
|
request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create request: %v", err)
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
}
|
}
|
||||||
testutils.WithStandardHeaders(request)
|
testutils.WithStandardHeaders(request)
|
||||||
|
request.Header.Del("Accept-Encoding")
|
||||||
|
|
||||||
response, err := ctx.client.Do(request)
|
response, err := ctx.client.Do(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Request failed: %v", err)
|
t.Fatalf("Request failed: %v", err)
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
failIfRateLimited(t, response.StatusCode, "compression disabled GET /api/posts")
|
||||||
|
|
||||||
contentEncoding := response.Header.Get("Content-Encoding")
|
contentEncoding := response.Header.Get("Content-Encoding")
|
||||||
if contentEncoding == "gzip" {
|
if contentEncoding == "gzip" {
|
||||||
t.Error("Expected no compression without Accept-Encoding header")
|
t.Fatal("Expected no compression without Accept-Encoding header")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("decompression_handles_gzip_request", func(t *testing.T) {
|
t.Run("accepts_valid_gzip_request_body", func(t *testing.T) {
|
||||||
testUser := ctx.createUserWithCleanup(t, "compressionuser", "StrongPass123!")
|
testUser := ctx.createUserWithCleanup(t, "compressionuser", "StrongPass123!")
|
||||||
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
|
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
|
||||||
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
gz := gzip.NewWriter(&buf)
|
gz := gzip.NewWriter(&buf)
|
||||||
postData := `{"title":"Compressed Post","url":"https://example.com/compressed","content":"Test content"}`
|
postData := `{"title":"Compressed Post","url":"https://example.com/compressed","content":"Test content"}`
|
||||||
gz.Write([]byte(postData))
|
if _, err := gz.Write([]byte(postData)); err != nil {
|
||||||
gz.Close()
|
t.Fatalf("Failed to gzip request body: %v", err)
|
||||||
|
}
|
||||||
|
if err := gz.Close(); err != nil {
|
||||||
|
t.Fatalf("Failed to finalize gzip body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", &buf)
|
request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", &buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -99,20 +109,22 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
|
|||||||
t.Fatalf("Request failed: %v", err)
|
t.Fatalf("Request failed: %v", err)
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
failIfRateLimited(t, response.StatusCode, "decompression POST /api/posts")
|
||||||
|
|
||||||
switch response.StatusCode {
|
switch response.StatusCode {
|
||||||
case http.StatusBadRequest:
|
|
||||||
t.Log("Decompression middleware rejected invalid gzip")
|
|
||||||
case http.StatusCreated, http.StatusOK:
|
case http.StatusCreated, http.StatusOK:
|
||||||
t.Log("Decompression middleware handled gzip request successfully")
|
return
|
||||||
|
default:
|
||||||
|
body, _ := io.ReadAll(response.Body)
|
||||||
|
t.Fatalf("Expected status %d or %d for valid gzip request, got %d. Body: %s", http.StatusCreated, http.StatusOK, response.StatusCode, string(body))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestE2E_CacheMiddleware(t *testing.T) {
|
func TestE2E_CacheMiddleware(t *testing.T) {
|
||||||
ctx := setupTestContext(t)
|
ctx := setupTestContextWithMiddleware(t)
|
||||||
|
|
||||||
t.Run("cache_miss_then_hit", func(t *testing.T) {
|
t.Run("returns_hit_after_repeated_get", func(t *testing.T) {
|
||||||
firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
|
firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create request: %v", err)
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
@@ -124,12 +136,14 @@ func TestE2E_CacheMiddleware(t *testing.T) {
|
|||||||
t.Fatalf("Request failed: %v", err)
|
t.Fatalf("Request failed: %v", err)
|
||||||
}
|
}
|
||||||
firstResponse.Body.Close()
|
firstResponse.Body.Close()
|
||||||
|
failIfRateLimited(t, firstResponse.StatusCode, "first cache GET /api/posts")
|
||||||
|
|
||||||
firstCacheStatus := firstResponse.Header.Get("X-Cache")
|
firstCacheStatus := firstResponse.Header.Get("X-Cache")
|
||||||
if firstCacheStatus == "HIT" {
|
if firstCacheStatus == "HIT" {
|
||||||
t.Log("First request was cached (unexpected but acceptable)")
|
t.Fatalf("Expected first request to be a cache miss, got X-Cache=%q", firstCacheStatus)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for range 8 {
|
||||||
secondRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
|
secondRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create request: %v", err)
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
@@ -140,18 +154,25 @@ func TestE2E_CacheMiddleware(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Request failed: %v", err)
|
t.Fatalf("Request failed: %v", err)
|
||||||
}
|
}
|
||||||
defer secondResponse.Body.Close()
|
failIfRateLimited(t, secondResponse.StatusCode, "cache warmup GET /api/posts")
|
||||||
|
|
||||||
secondCacheStatus := secondResponse.Header.Get("X-Cache")
|
secondCacheStatus := secondResponse.Header.Get("X-Cache")
|
||||||
|
secondResponse.Body.Close()
|
||||||
|
|
||||||
if secondCacheStatus == "HIT" {
|
if secondCacheStatus == "HIT" {
|
||||||
t.Log("Second request was served from cache")
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
time.Sleep(25 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Fatal("Expected a cache HIT on repeated requests, but none observed")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("cache_invalidation_on_post", func(t *testing.T) {
|
t.Run("invalidates_cached_get_after_post", func(t *testing.T) {
|
||||||
testUser := ctx.createUserWithCleanup(t, "cacheuser", "StrongPass123!")
|
testUser := ctx.createUserWithCleanup(t, "cacheuser", "StrongPass123!")
|
||||||
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
|
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
|
||||||
|
|
||||||
|
for attempt := range 8 {
|
||||||
firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
|
firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create request: %v", err)
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
@@ -163,8 +184,21 @@ func TestE2E_CacheMiddleware(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Request failed: %v", err)
|
t.Fatalf("Request failed: %v", err)
|
||||||
}
|
}
|
||||||
|
failIfRateLimited(t, firstResponse.StatusCode, "cache priming GET /api/posts")
|
||||||
|
cacheStatus := firstResponse.Header.Get("X-Cache")
|
||||||
firstResponse.Body.Close()
|
firstResponse.Body.Close()
|
||||||
|
|
||||||
|
if cacheStatus == "HIT" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if attempt == 7 {
|
||||||
|
t.Fatal("Failed to prime cache: repeated GET requests never produced X-Cache=HIT")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(25 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
postData := `{"title":"Cache Invalidation Test","url":"https://example.com/cache","content":"Test"}`
|
postData := `{"title":"Cache Invalidation Test","url":"https://example.com/cache","content":"Test"}`
|
||||||
secondRequest, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
|
secondRequest, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -178,8 +212,15 @@ func TestE2E_CacheMiddleware(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Request failed: %v", err)
|
t.Fatalf("Request failed: %v", err)
|
||||||
}
|
}
|
||||||
|
failIfRateLimited(t, secondResponse.StatusCode, "cache invalidation POST /api/posts")
|
||||||
|
if secondResponse.StatusCode != http.StatusCreated && secondResponse.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(secondResponse.Body)
|
||||||
|
secondResponse.Body.Close()
|
||||||
|
t.Fatalf("Expected post creation status %d or %d, got %d. Body: %s", http.StatusCreated, http.StatusOK, secondResponse.StatusCode, string(body))
|
||||||
|
}
|
||||||
secondResponse.Body.Close()
|
secondResponse.Body.Close()
|
||||||
|
|
||||||
|
for range 8 {
|
||||||
thirdRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
|
thirdRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create request: %v", err)
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
@@ -191,19 +232,25 @@ func TestE2E_CacheMiddleware(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Request failed: %v", err)
|
t.Fatalf("Request failed: %v", err)
|
||||||
}
|
}
|
||||||
defer thirdResponse.Body.Close()
|
failIfRateLimited(t, thirdResponse.StatusCode, "post-invalidation GET /api/posts")
|
||||||
|
|
||||||
cacheStatus := thirdResponse.Header.Get("X-Cache")
|
cacheStatus := thirdResponse.Header.Get("X-Cache")
|
||||||
if cacheStatus == "HIT" {
|
thirdResponse.Body.Close()
|
||||||
t.Log("Cache was invalidated after POST")
|
|
||||||
|
if cacheStatus != "HIT" {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
time.Sleep(25 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Fatal("Expected cache to be invalidated after POST, but X-Cache stayed HIT")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestE2E_CSRFProtection(t *testing.T) {
|
func TestE2E_CSRFProtection(t *testing.T) {
|
||||||
ctx := setupTestContext(t)
|
ctx := setupTestContext(t)
|
||||||
|
|
||||||
t.Run("csrf_protection_for_non_api_routes", func(t *testing.T) {
|
t.Run("non_api_post_without_csrf_is_forbidden_or_unmounted", func(t *testing.T) {
|
||||||
request, err := http.NewRequest("POST", ctx.baseURL+"/auth/login", strings.NewReader(`{"username":"test","password":"test"}`))
|
request, err := http.NewRequest("POST", ctx.baseURL+"/auth/login", strings.NewReader(`{"username":"test","password":"test"}`))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create request: %v", err)
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
@@ -216,15 +263,15 @@ func TestE2E_CSRFProtection(t *testing.T) {
|
|||||||
t.Fatalf("Request failed: %v", err)
|
t.Fatalf("Request failed: %v", err)
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
failIfRateLimited(t, response.StatusCode, "CSRF non-API POST /auth/login")
|
||||||
|
|
||||||
if response.StatusCode == http.StatusForbidden {
|
if response.StatusCode != http.StatusForbidden && response.StatusCode != http.StatusNotFound {
|
||||||
t.Log("CSRF protection active for non-API routes")
|
body, _ := io.ReadAll(response.Body)
|
||||||
} else {
|
t.Fatalf("Expected status %d (CSRF protected) or %d (route unavailable in test setup), got %d. Body: %s", http.StatusForbidden, http.StatusNotFound, response.StatusCode, string(body))
|
||||||
t.Logf("CSRF check result: status %d", response.StatusCode)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("csrf_bypass_for_api_routes", func(t *testing.T) {
|
t.Run("api_post_without_csrf_is_not_forbidden", func(t *testing.T) {
|
||||||
testUser := ctx.createUserWithCleanup(t, "csrfuser", "StrongPass123!")
|
testUser := ctx.createUserWithCleanup(t, "csrfuser", "StrongPass123!")
|
||||||
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
|
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
|
||||||
|
|
||||||
@@ -242,13 +289,15 @@ func TestE2E_CSRFProtection(t *testing.T) {
|
|||||||
t.Fatalf("Request failed: %v", err)
|
t.Fatalf("Request failed: %v", err)
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
failIfRateLimited(t, response.StatusCode, "CSRF bypass POST /api/posts")
|
||||||
|
|
||||||
if response.StatusCode == http.StatusForbidden {
|
if response.StatusCode == http.StatusForbidden {
|
||||||
t.Error("API routes should bypass CSRF protection")
|
body, _ := io.ReadAll(response.Body)
|
||||||
|
t.Fatalf("API routes should bypass CSRF protection, got 403. Body: %s", string(body))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("csrf_allows_get_requests", func(t *testing.T) {
|
t.Run("get_request_without_csrf_is_not_forbidden", func(t *testing.T) {
|
||||||
request, err := http.NewRequest("GET", ctx.baseURL+"/auth/login", nil)
|
request, err := http.NewRequest("GET", ctx.baseURL+"/auth/login", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create request: %v", err)
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
@@ -260,9 +309,10 @@ func TestE2E_CSRFProtection(t *testing.T) {
|
|||||||
t.Fatalf("Request failed: %v", err)
|
t.Fatalf("Request failed: %v", err)
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
failIfRateLimited(t, response.StatusCode, "CSRF GET /auth/login")
|
||||||
|
|
||||||
if response.StatusCode == http.StatusForbidden {
|
if response.StatusCode == http.StatusForbidden {
|
||||||
t.Error("GET requests should not require CSRF token")
|
t.Fatal("GET requests should not require CSRF token")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -270,7 +320,7 @@ func TestE2E_CSRFProtection(t *testing.T) {
|
|||||||
func TestE2E_RequestSizeLimit(t *testing.T) {
|
func TestE2E_RequestSizeLimit(t *testing.T) {
|
||||||
ctx := setupTestContext(t)
|
ctx := setupTestContext(t)
|
||||||
|
|
||||||
t.Run("request_within_size_limit", func(t *testing.T) {
|
t.Run("accepts_request_within_size_limit", func(t *testing.T) {
|
||||||
testUser := ctx.createUserWithCleanup(t, "sizelimituser", "StrongPass123!")
|
testUser := ctx.createUserWithCleanup(t, "sizelimituser", "StrongPass123!")
|
||||||
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
|
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
|
||||||
|
|
||||||
@@ -289,13 +339,15 @@ func TestE2E_RequestSizeLimit(t *testing.T) {
|
|||||||
t.Fatalf("Request failed: %v", err)
|
t.Fatalf("Request failed: %v", err)
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
failIfRateLimited(t, response.StatusCode, "request within size limit POST /api/posts")
|
||||||
|
|
||||||
if response.StatusCode == http.StatusRequestEntityTooLarge {
|
if response.StatusCode == http.StatusRequestEntityTooLarge {
|
||||||
t.Error("Small request should not exceed size limit")
|
body, _ := io.ReadAll(response.Body)
|
||||||
|
t.Fatalf("Small request should not exceed size limit. Body: %s", string(body))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("request_exceeds_size_limit", func(t *testing.T) {
|
t.Run("rejects_or_fails_oversized_request", func(t *testing.T) {
|
||||||
testUser := ctx.createUserWithCleanup(t, "sizelimituser2", "StrongPass123!")
|
testUser := ctx.createUserWithCleanup(t, "sizelimituser2", "StrongPass123!")
|
||||||
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
|
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
|
||||||
|
|
||||||
@@ -311,14 +363,14 @@ func TestE2E_RequestSizeLimit(t *testing.T) {
|
|||||||
|
|
||||||
response, err := ctx.client.Do(request)
|
response, err := ctx.client.Do(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
t.Fatalf("Request failed: %v", err)
|
||||||
}
|
}
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
|
failIfRateLimited(t, response.StatusCode, "request exceeds size limit POST /api/posts")
|
||||||
|
|
||||||
if response.StatusCode == http.StatusRequestEntityTooLarge {
|
if response.StatusCode != http.StatusRequestEntityTooLarge && response.StatusCode != http.StatusBadRequest {
|
||||||
t.Log("Request size limit enforced correctly")
|
body, _ := io.ReadAll(response.Body)
|
||||||
} else {
|
t.Fatalf("Expected status %d or %d for oversized request, got %d. Body: %s", http.StatusRequestEntityTooLarge, http.StatusBadRequest, response.StatusCode, string(body))
|
||||||
t.Logf("Request size limit check result: status %d", response.StatusCode)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ func TestE2E_Performance(t *testing.T) {
|
|||||||
var totalTime time.Duration
|
var totalTime time.Duration
|
||||||
iterations := 10
|
iterations := 10
|
||||||
|
|
||||||
for i := 0; i < iterations; i++ {
|
for range iterations {
|
||||||
req, err := endpoint.req()
|
req, err := endpoint.req()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create request: %v", err)
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
@@ -98,11 +98,9 @@ func TestE2E_Performance(t *testing.T) {
|
|||||||
var errorCount int64
|
var errorCount int64
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for i := 0; i < concurrency; i++ {
|
for range concurrency {
|
||||||
wg.Add(1)
|
wg.Go(func() {
|
||||||
go func() {
|
for range requestsPerGoroutine {
|
||||||
defer wg.Done()
|
|
||||||
for j := 0; j < requestsPerGoroutine; j++ {
|
|
||||||
req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
|
req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
atomic.AddInt64(&errorCount, 1)
|
atomic.AddInt64(&errorCount, 1)
|
||||||
@@ -123,7 +121,7 @@ func TestE2E_Performance(t *testing.T) {
|
|||||||
atomic.AddInt64(&errorCount, 1)
|
atomic.AddInt64(&errorCount, 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
@@ -138,7 +136,7 @@ func TestE2E_Performance(t *testing.T) {
|
|||||||
createdUser := ctx.createUserWithCleanup(t, "dbperf", "StrongPass123!")
|
createdUser := ctx.createUserWithCleanup(t, "dbperf", "StrongPass123!")
|
||||||
authClient := ctx.loginUser(t, createdUser.Username, createdUser.Password)
|
authClient := ctx.loginUser(t, createdUser.Username, createdUser.Password)
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
for i := range 10 {
|
||||||
authClient.CreatePost(t, fmt.Sprintf("Post %d", i), fmt.Sprintf("https://example.com/%d", i), "Content")
|
authClient.CreatePost(t, fmt.Sprintf("Post %d", i), fmt.Sprintf("https://example.com/%d", i), "Content")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -160,7 +158,7 @@ func TestE2E_Performance(t *testing.T) {
|
|||||||
authClient := ctx.loginUser(t, createdUser.Username, createdUser.Password)
|
authClient := ctx.loginUser(t, createdUser.Username, createdUser.Password)
|
||||||
|
|
||||||
initialPosts := 50
|
initialPosts := 50
|
||||||
for i := 0; i < initialPosts; i++ {
|
for i := range initialPosts {
|
||||||
authClient.CreatePost(t, fmt.Sprintf("Memory Test Post %d", i), fmt.Sprintf("https://example.com/mem%d", i), "Content")
|
authClient.CreatePost(t, fmt.Sprintf("Memory Test Post %d", i), fmt.Sprintf("https://example.com/mem%d", i), "Content")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -271,16 +269,14 @@ func TestE2E_ConcurrentWrites(t *testing.T) {
|
|||||||
|
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
u := user
|
u := user
|
||||||
wg.Add(1)
|
wg.Go(func() {
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
authClient, err := ctx.loginUserSafe(t, u.Username, u.Password)
|
authClient, err := ctx.loginUserSafe(t, u.Username, u.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
atomic.AddInt64(&errorCount, 1)
|
atomic.AddInt64(&errorCount, 1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
for i := range 5 {
|
||||||
post, err := authClient.CreatePostSafe(
|
post, err := authClient.CreatePostSafe(
|
||||||
fmt.Sprintf("Concurrent Post %d", i),
|
fmt.Sprintf("Concurrent Post %d", i),
|
||||||
fmt.Sprintf("https://example.com/concurrent%d-%d", u.ID, i),
|
fmt.Sprintf("https://example.com/concurrent%d-%d", u.ID, i),
|
||||||
@@ -292,7 +288,7 @@ func TestE2E_ConcurrentWrites(t *testing.T) {
|
|||||||
atomic.AddInt64(&errorCount, 1)
|
atomic.AddInt64(&errorCount, 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
@@ -311,7 +307,7 @@ func TestE2E_ResponseSize(t *testing.T) {
|
|||||||
createdUser := ctx.createUserWithCleanup(t, "sizetest", "StrongPass123!")
|
createdUser := ctx.createUserWithCleanup(t, "sizetest", "StrongPass123!")
|
||||||
authClient := ctx.loginUser(t, createdUser.Username, createdUser.Password)
|
authClient := ctx.loginUser(t, createdUser.Username, createdUser.Password)
|
||||||
|
|
||||||
for i := 0; i < 100; i++ {
|
for i := range 100 {
|
||||||
authClient.CreatePost(t, fmt.Sprintf("Post %d", i), fmt.Sprintf("https://example.com/%d", i), "Content")
|
authClient.CreatePost(t, fmt.Sprintf("Post %d", i), fmt.Sprintf("https://example.com/%d", i), "Content")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func TestE2E_RateLimitingHeaders(t *testing.T) {
|
|||||||
t.Error("Expected Retry-After header when rate limited")
|
t.Error("Expected Retry-After header when rate limited")
|
||||||
}
|
}
|
||||||
|
|
||||||
var jsonResponse map[string]interface{}
|
var jsonResponse map[string]any
|
||||||
body, _ := json.Marshal(map[string]string{})
|
body, _ := json.Marshal(map[string]string{})
|
||||||
_ = json.Unmarshal(body, &jsonResponse)
|
_ = json.Unmarshal(body, &jsonResponse)
|
||||||
|
|
||||||
@@ -72,7 +72,7 @@ func TestE2E_RateLimitingHeaders(t *testing.T) {
|
|||||||
if resp.StatusCode != http.StatusTooManyRequests {
|
if resp.StatusCode != http.StatusTooManyRequests {
|
||||||
t.Errorf("Expected status 429 on request %d, got %d", i+1, resp.StatusCode)
|
t.Errorf("Expected status 429 on request %d, got %d", i+1, resp.StatusCode)
|
||||||
} else {
|
} else {
|
||||||
var errorResponse map[string]interface{}
|
var errorResponse map[string]any
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
if err := json.Unmarshal(body, &errorResponse); err == nil {
|
if err := json.Unmarshal(body, &errorResponse); err == nil {
|
||||||
if errorResponse["error"] == nil {
|
if errorResponse["error"] == nil {
|
||||||
|
|||||||
@@ -30,12 +30,12 @@ func TestE2E_VersionEndpoint(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var apiInfo map[string]interface{}
|
var apiInfo map[string]any
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&apiInfo); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&apiInfo); err != nil {
|
||||||
t.Fatalf("Failed to decode API info response: %v", err)
|
t.Fatalf("Failed to decode API info response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
data, ok := apiInfo["data"].(map[string]interface{})
|
data, ok := apiInfo["data"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("API info data is not a map")
|
t.Fatalf("API info data is not a map")
|
||||||
}
|
}
|
||||||
@@ -74,12 +74,12 @@ func TestE2E_VersionEndpoint(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var healthInfo map[string]interface{}
|
var healthInfo map[string]any
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&healthInfo); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&healthInfo); err != nil {
|
||||||
t.Fatalf("Failed to decode health response: %v", err)
|
t.Fatalf("Failed to decode health response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
data, ok := healthInfo["data"].(map[string]interface{})
|
data, ok := healthInfo["data"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("Health data is not a map")
|
t.Fatalf("Health data is not a map")
|
||||||
}
|
}
|
||||||
@@ -130,18 +130,18 @@ func TestE2E_VersionEndpoint(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var apiInfo map[string]interface{}
|
var apiInfo map[string]any
|
||||||
if err := json.NewDecoder(apiResp.Body).Decode(&apiInfo); err != nil {
|
if err := json.NewDecoder(apiResp.Body).Decode(&apiInfo); err != nil {
|
||||||
t.Fatalf("Failed to decode API info: %v", err)
|
t.Fatalf("Failed to decode API info: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var healthInfo map[string]interface{}
|
var healthInfo map[string]any
|
||||||
if err := json.NewDecoder(healthResp.Body).Decode(&healthInfo); err != nil {
|
if err := json.NewDecoder(healthResp.Body).Decode(&healthInfo); err != nil {
|
||||||
t.Fatalf("Failed to decode health info: %v", err)
|
t.Fatalf("Failed to decode health info: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
apiData, _ := apiInfo["data"].(map[string]interface{})
|
apiData, _ := apiInfo["data"].(map[string]any)
|
||||||
healthData, _ := healthInfo["data"].(map[string]interface{})
|
healthData, _ := healthInfo["data"].(map[string]any)
|
||||||
|
|
||||||
apiVersion, apiOk := apiData["version"].(string)
|
apiVersion, apiOk := apiData["version"].(string)
|
||||||
healthVersion, healthOk := healthData["version"].(string)
|
healthVersion, healthOk := healthData["version"].(string)
|
||||||
@@ -169,12 +169,12 @@ func TestE2E_VersionEndpoint(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var apiInfo map[string]interface{}
|
var apiInfo map[string]any
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&apiInfo); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&apiInfo); err != nil {
|
||||||
t.Fatalf("Failed to decode API info: %v", err)
|
t.Fatalf("Failed to decode API info: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
data, ok := apiInfo["data"].(map[string]interface{})
|
data, ok := apiInfo["data"].(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -455,7 +455,7 @@ func TestE2E_ConcurrentRequestsWithSameSession(t *testing.T) {
|
|||||||
authClient := ctx.loginUser(t, createdUser.Username, createdUser.Password)
|
authClient := ctx.loginUser(t, createdUser.Username, createdUser.Password)
|
||||||
|
|
||||||
results := make(chan bool, 5)
|
results := make(chan bool, 5)
|
||||||
for i := 0; i < 5; i++ {
|
for range 5 {
|
||||||
go func() {
|
go func() {
|
||||||
profile := authClient.GetProfile(t)
|
profile := authClient.GetProfile(t)
|
||||||
results <- (profile != nil && profile.Data.Username == createdUser.Username)
|
results <- (profile != nil && profile.Data.Username == createdUser.Username)
|
||||||
@@ -463,7 +463,7 @@ func TestE2E_ConcurrentRequestsWithSameSession(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
successCount := 0
|
successCount := 0
|
||||||
for i := 0; i < 5; i++ {
|
for range 5 {
|
||||||
if <-results {
|
if <-results {
|
||||||
successCount++
|
successCount++
|
||||||
}
|
}
|
||||||
@@ -562,7 +562,7 @@ func TestE2E_RapidSuccessiveActions(t *testing.T) {
|
|||||||
|
|
||||||
post := authClient.CreatePost(t, "Rapid Vote Test", "https://example.com/rapid", "Content")
|
post := authClient.CreatePost(t, "Rapid Vote Test", "https://example.com/rapid", "Content")
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
for i := range 10 {
|
||||||
voteType := "up"
|
voteType := "up"
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
voteType = "down"
|
voteType = "down"
|
||||||
|
|||||||
@@ -134,9 +134,7 @@ func TestE2E_ConcurrentUserWorkflows(t *testing.T) {
|
|||||||
|
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
u := user
|
u := user
|
||||||
wg.Add(1)
|
wg.Go(func() {
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
var err error
|
var err error
|
||||||
authClient, loginErr := ctx.loginUserSafe(t, u.Username, u.Password)
|
authClient, loginErr := ctx.loginUserSafe(t, u.Username, u.Password)
|
||||||
if loginErr != nil || authClient == nil || authClient.Token == "" {
|
if loginErr != nil || authClient == nil || authClient.Token == "" {
|
||||||
@@ -157,7 +155,7 @@ func TestE2E_ConcurrentUserWorkflows(t *testing.T) {
|
|||||||
case results <- result{userID: u.ID, err: err}:
|
case results <- result{userID: u.ID, err: err}:
|
||||||
case <-done:
|
case <-done:
|
||||||
}
|
}
|
||||||
}()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
|||||||
@@ -53,13 +53,13 @@ func TestRunSanitizationFuzzTest(t *testing.T) {
|
|||||||
|
|
||||||
sanitizeFunc := func(input string) string {
|
sanitizeFunc := func(input string) string {
|
||||||
|
|
||||||
result := ""
|
var result strings.Builder
|
||||||
for _, char := range input {
|
for _, char := range input {
|
||||||
if char != ' ' {
|
if char != ' ' {
|
||||||
result += string(char)
|
result.WriteString(string(char))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result
|
return result.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
result := sanitizeFunc("hello world")
|
result := sanitizeFunc("hello world")
|
||||||
@@ -76,13 +76,13 @@ func TestRunSanitizationFuzzTestWithValidation(t *testing.T) {
|
|||||||
|
|
||||||
sanitizeFunc := func(input string) string {
|
sanitizeFunc := func(input string) string {
|
||||||
|
|
||||||
result := ""
|
var result strings.Builder
|
||||||
for _, char := range input {
|
for _, char := range input {
|
||||||
if char != ' ' {
|
if char != ' ' {
|
||||||
result += string(char)
|
result.WriteString(string(char))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result
|
return result.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
validateFunc := func(input string) bool {
|
validateFunc := func(input string) bool {
|
||||||
@@ -1673,7 +1673,7 @@ func TestValidateHTTPRequestWithManyHeaders(t *testing.T) {
|
|||||||
helper := NewFuzzTestHelper()
|
helper := NewFuzzTestHelper()
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||||
for i := 0; i < 20; i++ {
|
for i := range 20 {
|
||||||
req.Header.Set(fmt.Sprintf("Header-%d", i), fmt.Sprintf("Value-%d", i))
|
req.Header.Set(fmt.Sprintf("Header-%d", i), fmt.Sprintf("Value-%d", i))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -158,10 +158,3 @@ func FuzzPostRepository(f *testing.F) {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func min(a, b int) int {
|
|
||||||
if a < b {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"goyco/internal/config"
|
"goyco/internal/config"
|
||||||
|
"goyco/internal/health"
|
||||||
"goyco/internal/middleware"
|
"goyco/internal/middleware"
|
||||||
"goyco/internal/repositories"
|
"goyco/internal/repositories"
|
||||||
"goyco/internal/services"
|
"goyco/internal/services"
|
||||||
@@ -21,7 +22,7 @@ type APIHandler struct {
|
|||||||
userRepo repositories.UserRepository
|
userRepo repositories.UserRepository
|
||||||
voteService *services.VoteService
|
voteService *services.VoteService
|
||||||
dbMonitor middleware.DBMonitor
|
dbMonitor middleware.DBMonitor
|
||||||
healthChecker *middleware.DatabaseHealthChecker
|
healthChecker *health.CompositeChecker
|
||||||
metricsCollector *middleware.MetricsCollector
|
metricsCollector *middleware.MetricsCollector
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,7 +45,21 @@ func NewAPIHandlerWithMonitoring(config *config.Config, postRepo repositories.Po
|
|||||||
return NewAPIHandler(config, postRepo, userRepo, voteService)
|
return NewAPIHandler(config, postRepo, userRepo, voteService)
|
||||||
}
|
}
|
||||||
|
|
||||||
healthChecker := middleware.NewDatabaseHealthChecker(sqlDB, dbMonitor)
|
compositeChecker := health.NewCompositeChecker()
|
||||||
|
|
||||||
|
dbChecker := health.NewDatabaseChecker(sqlDB, dbMonitor)
|
||||||
|
compositeChecker.AddChecker(dbChecker)
|
||||||
|
|
||||||
|
smtpConfig := health.SMTPConfig{
|
||||||
|
Host: config.SMTP.Host,
|
||||||
|
Port: config.SMTP.Port,
|
||||||
|
Username: config.SMTP.Username,
|
||||||
|
Password: config.SMTP.Password,
|
||||||
|
From: config.SMTP.From,
|
||||||
|
}
|
||||||
|
smtpChecker := health.NewSMTPChecker(smtpConfig)
|
||||||
|
compositeChecker.AddChecker(smtpChecker)
|
||||||
|
|
||||||
metricsCollector := middleware.NewMetricsCollector(dbMonitor)
|
metricsCollector := middleware.NewMetricsCollector(dbMonitor)
|
||||||
|
|
||||||
return &APIHandler{
|
return &APIHandler{
|
||||||
@@ -53,7 +68,7 @@ func NewAPIHandlerWithMonitoring(config *config.Config, postRepo repositories.Po
|
|||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
voteService: voteService,
|
voteService: voteService,
|
||||||
dbMonitor: dbMonitor,
|
dbMonitor: dbMonitor,
|
||||||
healthChecker: healthChecker,
|
healthChecker: compositeChecker,
|
||||||
metricsCollector: metricsCollector,
|
metricsCollector: metricsCollector,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -135,17 +150,17 @@ func (h *APIHandler) GetAPIInfo(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Health check
|
// @Summary Health check
|
||||||
// @Description Check the API health status along with database connectivity details
|
// @Description Check the API health status along with database connectivity and SMTP service details
|
||||||
// @Tags api
|
// @Tags api
|
||||||
// @Accept json
|
// @Accept json
|
||||||
// @Produce json
|
// @Produce json
|
||||||
// @Success 200 {object} CommonResponse "Health check successful"
|
// @Success 200 {object} CommonResponse "Health check successful"
|
||||||
// @Router /health [get]
|
// @Router /health [get]
|
||||||
func (h *APIHandler) GetHealth(w http.ResponseWriter, r *http.Request) {
|
func (h *APIHandler) GetHealth(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
if h.healthChecker != nil {
|
if h.healthChecker != nil {
|
||||||
health := h.healthChecker.CheckHealth()
|
health := h.healthChecker.CheckWithVersion(ctx, version.GetVersion())
|
||||||
health["version"] = version.GetVersion()
|
|
||||||
SendSuccessResponse(w, "Health check successful", health)
|
SendSuccessResponse(w, "Health check successful", health)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1364,7 +1364,7 @@ func TestAuthHandler_ConcurrentAccess(t *testing.T) {
|
|||||||
concurrency := 10
|
concurrency := 10
|
||||||
done := make(chan bool, concurrency)
|
done := make(chan bool, concurrency)
|
||||||
|
|
||||||
for i := 0; i < concurrency; i++ {
|
for range concurrency {
|
||||||
go func() {
|
go func() {
|
||||||
req := createLoginRequest(`{"username":"testuser","password":"Password123!"}`)
|
req := createLoginRequest(`{"username":"testuser","password":"Password123!"}`)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
@@ -1376,7 +1376,7 @@ func TestAuthHandler_ConcurrentAccess(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < concurrency; i++ {
|
for range concurrency {
|
||||||
<-done
|
<-done
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -115,10 +115,7 @@ func NewPageHandler(templatesDir string, authService AuthServiceInterface, postR
|
|||||||
if start >= len(s) {
|
if start >= len(s) {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
end := start + length
|
end := min(start+length, len(s))
|
||||||
if end > len(s) {
|
|
||||||
end = len(s)
|
|
||||||
}
|
|
||||||
return s[start:end]
|
return s[start:end]
|
||||||
},
|
},
|
||||||
"upper": strings.ToUpper,
|
"upper": strings.ToUpper,
|
||||||
|
|||||||
54
internal/health/composite.go
Normal file
54
internal/health/composite.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package health
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CompositeChecker struct {
|
||||||
|
checkers []Checker
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCompositeChecker(checkers ...Checker) *CompositeChecker {
|
||||||
|
return &CompositeChecker{
|
||||||
|
checkers: checkers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CompositeChecker) AddChecker(checker Checker) {
|
||||||
|
c.checkers = append(c.checkers, checker)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CompositeChecker) Check(ctx context.Context) OverallResult {
|
||||||
|
results := make(map[string]Result)
|
||||||
|
var mu sync.Mutex
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
for _, checker := range c.checkers {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(ch Checker) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
result := ch.Check(ctx)
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
results[ch.Name()] = result
|
||||||
|
mu.Unlock()
|
||||||
|
}(checker)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
return OverallResult{
|
||||||
|
Status: determineOverallStatus(results),
|
||||||
|
Timestamp: time.Now().UTC(),
|
||||||
|
Services: results,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CompositeChecker) CheckWithVersion(ctx context.Context, version string) OverallResult {
|
||||||
|
result := c.Check(ctx)
|
||||||
|
result.Version = version
|
||||||
|
return result
|
||||||
|
}
|
||||||
61
internal/health/database.go
Normal file
61
internal/health/database.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package health
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"goyco/internal/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DatabaseChecker struct {
|
||||||
|
db *sql.DB
|
||||||
|
monitor middleware.DBMonitor
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDatabaseChecker(db *sql.DB, monitor middleware.DBMonitor) *DatabaseChecker {
|
||||||
|
return &DatabaseChecker{
|
||||||
|
db: db,
|
||||||
|
monitor: monitor,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DatabaseChecker) Name() string {
|
||||||
|
return "database"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DatabaseChecker) Check(ctx context.Context) Result {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
err := c.db.Ping()
|
||||||
|
latency := time.Since(start)
|
||||||
|
|
||||||
|
result := Result{
|
||||||
|
Status: StatusHealthy,
|
||||||
|
Latency: latency,
|
||||||
|
Timestamp: time.Now().UTC(),
|
||||||
|
Details: map[string]any{
|
||||||
|
"ping_time": latency.String(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
result.Status = StatusUnhealthy
|
||||||
|
result.Message = err.Error()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.monitor != nil {
|
||||||
|
stats := c.monitor.GetStats()
|
||||||
|
result.Details["stats"] = map[string]any{
|
||||||
|
"total_queries": stats.TotalQueries,
|
||||||
|
"slow_queries": stats.SlowQueries,
|
||||||
|
"average_duration": stats.AverageDuration.String(),
|
||||||
|
"max_duration": stats.MaxDuration.String(),
|
||||||
|
"error_count": stats.ErrorCount,
|
||||||
|
"last_query_time": stats.LastQueryTime.Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
61
internal/health/health.go
Normal file
61
internal/health/health.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package health
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Status string
|
||||||
|
|
||||||
|
const (
|
||||||
|
StatusHealthy Status = "healthy"
|
||||||
|
StatusDegraded Status = "degraded"
|
||||||
|
StatusUnhealthy Status = "unhealthy"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
Status Status `json:"status"`
|
||||||
|
Message string `json:"message,omitempty"`
|
||||||
|
Latency time.Duration `json:"latency"`
|
||||||
|
Timestamp time.Time `json:"timestamp"`
|
||||||
|
Details map[string]any `json:"details,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Checker interface {
|
||||||
|
Name() string
|
||||||
|
Check(ctx context.Context) Result
|
||||||
|
}
|
||||||
|
|
||||||
|
type OverallResult struct {
|
||||||
|
Status Status `json:"status"`
|
||||||
|
Timestamp time.Time `json:"timestamp"`
|
||||||
|
Version string `json:"version"`
|
||||||
|
Services map[string]Result `json:"services"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func determineOverallStatus(results map[string]Result) Status {
|
||||||
|
hasUnhealthy := false
|
||||||
|
hasSMTPUnhealthy := false
|
||||||
|
hasDegraded := false
|
||||||
|
|
||||||
|
for name, result := range results {
|
||||||
|
switch result.Status {
|
||||||
|
case StatusUnhealthy:
|
||||||
|
if name == "smtp" {
|
||||||
|
hasSMTPUnhealthy = true
|
||||||
|
} else {
|
||||||
|
hasUnhealthy = true
|
||||||
|
}
|
||||||
|
case StatusDegraded:
|
||||||
|
hasDegraded = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasUnhealthy {
|
||||||
|
return StatusUnhealthy
|
||||||
|
}
|
||||||
|
if hasDegraded || hasSMTPUnhealthy {
|
||||||
|
return StatusDegraded
|
||||||
|
}
|
||||||
|
return StatusHealthy
|
||||||
|
}
|
||||||
236
internal/health/health_test.go
Normal file
236
internal/health/health_test.go
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
package health
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"goyco/internal/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockDBMonitor struct {
|
||||||
|
stats middleware.DBStats
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDBMonitor) GetStats() middleware.DBStats {
|
||||||
|
return m.stats
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDBMonitor) LogQuery(query string, duration time.Duration, err error) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDBMonitor) LogSlowQuery(query string, duration time.Duration, threshold time.Duration) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStatusConstants(t *testing.T) {
|
||||||
|
if StatusHealthy != "healthy" {
|
||||||
|
t.Errorf("Expected StatusHealthy to be 'healthy', got %s", StatusHealthy)
|
||||||
|
}
|
||||||
|
if StatusDegraded != "degraded" {
|
||||||
|
t.Errorf("Expected StatusDegraded to be 'degraded', got %s", StatusDegraded)
|
||||||
|
}
|
||||||
|
if StatusUnhealthy != "unhealthy" {
|
||||||
|
t.Errorf("Expected StatusUnhealthy to be 'unhealthy', got %s", StatusUnhealthy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDetermineOverallStatus(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
results map[string]Result
|
||||||
|
expected Status
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "all healthy",
|
||||||
|
results: map[string]Result{"db": {Status: StatusHealthy}, "smtp": {Status: StatusHealthy}},
|
||||||
|
expected: StatusHealthy,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "one degraded",
|
||||||
|
results: map[string]Result{"db": {Status: StatusHealthy}, "smtp": {Status: StatusDegraded}},
|
||||||
|
expected: StatusDegraded,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "one unhealthy",
|
||||||
|
results: map[string]Result{"db": {Status: StatusUnhealthy}, "smtp": {Status: StatusHealthy}},
|
||||||
|
expected: StatusUnhealthy,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "smtp unhealthy downgrades overall to degraded",
|
||||||
|
results: map[string]Result{"db": {Status: StatusHealthy}, "smtp": {Status: StatusUnhealthy}},
|
||||||
|
expected: StatusDegraded,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed degraded and unhealthy",
|
||||||
|
results: map[string]Result{"db": {Status: StatusDegraded}, "smtp": {Status: StatusUnhealthy}},
|
||||||
|
expected: StatusDegraded,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty results",
|
||||||
|
results: map[string]Result{},
|
||||||
|
expected: StatusHealthy,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := determineOverallStatus(tt.results)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDatabaseChecker_Name(t *testing.T) {
|
||||||
|
checker := &DatabaseChecker{}
|
||||||
|
if checker.Name() != "database" {
|
||||||
|
t.Errorf("Expected name 'database', got %s", checker.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDatabaseChecker_Check(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("Could not open test database: %v", err)
|
||||||
|
t.Skip("Skipping database-dependent test")
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
monitor := &MockDBMonitor{
|
||||||
|
stats: middleware.DBStats{
|
||||||
|
TotalQueries: 10,
|
||||||
|
SlowQueries: 1,
|
||||||
|
AverageDuration: 5 * time.Millisecond,
|
||||||
|
MaxDuration: 20 * time.Millisecond,
|
||||||
|
ErrorCount: 0,
|
||||||
|
LastQueryTime: time.Now(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
checker := NewDatabaseChecker(db, monitor)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
result := checker.Check(ctx)
|
||||||
|
|
||||||
|
if result.Timestamp.IsZero() {
|
||||||
|
t.Error("Expected non-zero timestamp")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := result.Details["ping_time"]; !ok {
|
||||||
|
t.Error("Expected ping_time in details")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Status == StatusHealthy {
|
||||||
|
stats, ok := result.Details["stats"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Log("Stats not available in details (may be expected)")
|
||||||
|
} else {
|
||||||
|
if stats["total_queries"] != int64(10) {
|
||||||
|
t.Errorf("Expected total_queries to be 10, got %v", stats["total_queries"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDatabaseChecker_Check_Unhealthy(t *testing.T) {
|
||||||
|
checker := NewDatabaseChecker(nil, nil)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Logf("Got expected panic with nil db: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
result := checker.Check(ctx)
|
||||||
|
|
||||||
|
if result.Status != StatusUnhealthy {
|
||||||
|
t.Logf("Expected unhealthy status for nil db, got %s", result.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompositeChecker(t *testing.T) {
|
||||||
|
checker1 := &mockChecker{
|
||||||
|
name: "service1",
|
||||||
|
status: StatusHealthy,
|
||||||
|
}
|
||||||
|
checker2 := &mockChecker{
|
||||||
|
name: "service2",
|
||||||
|
status: StatusHealthy,
|
||||||
|
}
|
||||||
|
|
||||||
|
composite := NewCompositeChecker(checker1, checker2)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
result := composite.Check(ctx)
|
||||||
|
|
||||||
|
if result.Status != StatusHealthy {
|
||||||
|
t.Errorf("Expected overall healthy status, got %s", result.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Services) != 2 {
|
||||||
|
t.Errorf("Expected 2 service results, got %d", len(result.Services))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := result.Services["service1"]; !ok {
|
||||||
|
t.Error("Expected service1 in results")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := result.Services["service2"]; !ok {
|
||||||
|
t.Error("Expected service2 in results")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompositeChecker_AddChecker(t *testing.T) {
|
||||||
|
composite := NewCompositeChecker()
|
||||||
|
|
||||||
|
checker := &mockChecker{
|
||||||
|
name: "test-service",
|
||||||
|
status: StatusHealthy,
|
||||||
|
}
|
||||||
|
|
||||||
|
composite.AddChecker(checker)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
result := composite.Check(ctx)
|
||||||
|
|
||||||
|
if len(result.Services) != 1 {
|
||||||
|
t.Errorf("Expected 1 service result, got %d", len(result.Services))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompositeChecker_CheckWithVersion(t *testing.T) {
|
||||||
|
checker := &mockChecker{
|
||||||
|
name: "test",
|
||||||
|
status: StatusHealthy,
|
||||||
|
}
|
||||||
|
|
||||||
|
composite := NewCompositeChecker(checker)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
result := composite.CheckWithVersion(ctx, "v1.2.3")
|
||||||
|
|
||||||
|
if result.Version != "v1.2.3" {
|
||||||
|
t.Errorf("Expected version 'v1.2.3', got %s", result.Version)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockChecker struct {
|
||||||
|
name string
|
||||||
|
status Status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockChecker) Name() string {
|
||||||
|
return m.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockChecker) Check(ctx context.Context) Result {
|
||||||
|
return Result{
|
||||||
|
Status: m.status,
|
||||||
|
Timestamp: time.Now().UTC(),
|
||||||
|
Latency: 1 * time.Millisecond,
|
||||||
|
}
|
||||||
|
}
|
||||||
106
internal/health/smtp.go
Normal file
106
internal/health/smtp.go
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
package health
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/smtp"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SMTPConfig struct {
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
From string
|
||||||
|
}
|
||||||
|
|
||||||
|
type SMTPChecker struct {
|
||||||
|
config SMTPConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSMTPChecker(config SMTPConfig) *SMTPChecker {
|
||||||
|
return &SMTPChecker{
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SMTPChecker) Name() string {
|
||||||
|
return "smtp"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SMTPChecker) Check(ctx context.Context) Result {
|
||||||
|
start := time.Now()
|
||||||
|
address := net.JoinHostPort(c.config.Host, fmt.Sprintf("%d", c.config.Port))
|
||||||
|
|
||||||
|
result := Result{
|
||||||
|
Status: StatusHealthy,
|
||||||
|
Timestamp: time.Now().UTC(),
|
||||||
|
Details: map[string]any{
|
||||||
|
"host": address,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.Dial("tcp", address)
|
||||||
|
if err != nil {
|
||||||
|
result.Status = StatusUnhealthy
|
||||||
|
result.Message = fmt.Sprintf("Failed to connect to SMTP server: %v", err)
|
||||||
|
result.Latency = time.Since(start)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client, err := smtp.NewClient(conn, c.config.Host)
|
||||||
|
if err != nil {
|
||||||
|
result.Status = StatusUnhealthy
|
||||||
|
result.Message = fmt.Sprintf("Failed to create SMTP client: %v", err)
|
||||||
|
result.Latency = time.Since(start)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
err = client.Hello("goyco-health-check")
|
||||||
|
if err != nil {
|
||||||
|
result.Status = StatusDegraded
|
||||||
|
result.Message = fmt.Sprintf("EHLO failed: %v", err)
|
||||||
|
result.Latency = time.Since(start)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok, _ := client.Extension("STARTTLS"); ok {
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
ServerName: c.config.Host,
|
||||||
|
}
|
||||||
|
err = client.StartTLS(tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
result.Details["starttls"] = "failed"
|
||||||
|
result.Details["starttls_error"] = err.Error()
|
||||||
|
} else {
|
||||||
|
result.Details["starttls"] = "enabled"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
result.Details["starttls"] = "not supported"
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.Username != "" && c.config.Password != "" {
|
||||||
|
auth := smtp.PlainAuth("", c.config.Username, c.config.Password, c.config.Host)
|
||||||
|
err = client.Auth(auth)
|
||||||
|
if err != nil {
|
||||||
|
result.Details["auth"] = "failed"
|
||||||
|
result.Details["auth_error"] = err.Error()
|
||||||
|
} else {
|
||||||
|
result.Details["auth"] = "success"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
result.Details["auth"] = "not configured"
|
||||||
|
}
|
||||||
|
|
||||||
|
client.Quit()
|
||||||
|
|
||||||
|
result.Latency = time.Since(start)
|
||||||
|
result.Details["handshake"] = "completed"
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
309
internal/health/smtp_test.go
Normal file
309
internal/health/smtp_test.go
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
package health
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSMTPChecker_Check_NoAuth(t *testing.T) {
|
||||||
|
config := SMTPConfig{
|
||||||
|
Host: "smtp.example.com",
|
||||||
|
Port: 25,
|
||||||
|
}
|
||||||
|
|
||||||
|
checker := NewSMTPChecker(config)
|
||||||
|
|
||||||
|
if checker.config.Username != "" {
|
||||||
|
t.Error("expected empty username")
|
||||||
|
}
|
||||||
|
if checker.config.Password != "" {
|
||||||
|
t.Error("expected empty password")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSMTPChecker_Check_InvalidPort(t *testing.T) {
|
||||||
|
config := SMTPConfig{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 99999,
|
||||||
|
}
|
||||||
|
|
||||||
|
checker := NewSMTPChecker(config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
result := checker.Check(ctx)
|
||||||
|
|
||||||
|
if result.Status != StatusUnhealthy {
|
||||||
|
t.Errorf("expected unhealthy status for invalid port, got %s", result.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSMTPChecker_Check_ConnectionRefused(t *testing.T) {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create listener: %v", err)
|
||||||
|
}
|
||||||
|
port := listener.Addr().(*net.TCPAddr).Port
|
||||||
|
listener.Close()
|
||||||
|
|
||||||
|
config := SMTPConfig{
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
Port: port,
|
||||||
|
}
|
||||||
|
|
||||||
|
checker := NewSMTPChecker(config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
result := checker.Check(ctx)
|
||||||
|
|
||||||
|
if result.Status != StatusUnhealthy {
|
||||||
|
t.Errorf("expected unhealthy status for connection refused, got %s", result.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSMTPChecker_Check_WithMockServer(t *testing.T) {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create listener: %v", err)
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
port := listener.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
serverDone := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
defer close(serverDone)
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
conn.Write([]byte("220 test.example.com ESMTP ready\r\n"))
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, _ := conn.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
conn.Write([]byte("250-test.example.com\r\n250-STARTTLS\r\n250 AUTH PLAIN\r\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
n, _ = conn.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
conn.Write([]byte("220 2.0.0 Ready to start TLS\r\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
n, _ = conn.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
conn.Write([]byte("235 2.7.0 Authentication successful\r\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
n, _ = conn.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
conn.Write([]byte("221 2.0.0 Bye\r\n"))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
config := SMTPConfig{
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
Port: port,
|
||||||
|
Username: "test@example.com",
|
||||||
|
Password: "password",
|
||||||
|
}
|
||||||
|
|
||||||
|
checker := NewSMTPChecker(config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
result := checker.Check(ctx)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-serverDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Error("server timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Status != StatusHealthy && result.Status != StatusDegraded {
|
||||||
|
t.Errorf("unexpected status: %s", result.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Details["host"] == nil {
|
||||||
|
t.Errorf("expected host in details, got %v", result.Details["host"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSMTPChecker_Check_EHLOFailure(t *testing.T) {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create listener: %v", err)
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
port := listener.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
serverDone := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
defer close(serverDone)
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
conn.Write([]byte("220 test.example.com ESMTP ready\r\n"))
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, _ := conn.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
conn.Write([]byte("500 Syntax error, command unrecognized\r\n"))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
config := SMTPConfig{
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
Port: port,
|
||||||
|
}
|
||||||
|
|
||||||
|
checker := NewSMTPChecker(config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
result := checker.Check(ctx)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-serverDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Error("server timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Status != StatusDegraded {
|
||||||
|
t.Errorf("expected degraded status for EHLO failure, got %s", result.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(result.Message, "EHLO") {
|
||||||
|
t.Errorf("expected EHLO error in message, got: %s", result.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSMTPChecker_Check_STARTTLSNotSupported(t *testing.T) {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create listener: %v", err)
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
port := listener.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
serverDone := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
defer close(serverDone)
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
conn.Write([]byte("220 test.example.com ESMTP ready\r\n"))
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, _ := conn.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
conn.Write([]byte("250-test.example.com\r\n250 AUTH PLAIN\r\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
n, _ = conn.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
conn.Write([]byte("221 2.0.0 Bye\r\n"))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
config := SMTPConfig{
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
Port: port,
|
||||||
|
}
|
||||||
|
|
||||||
|
checker := NewSMTPChecker(config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
result := checker.Check(ctx)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-serverDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Error("server timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Details["starttls"] != "not supported" && result.Details["starttls"] != "failed" {
|
||||||
|
t.Errorf("expected starttls not supported or failed, got: %v", result.Details["starttls"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSMTPChecker_Check_AuthNotConfigured(t *testing.T) {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create listener: %v", err)
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
port := listener.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
serverDone := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
defer close(serverDone)
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
conn.Write([]byte("220 test.example.com ESMTP ready\r\n"))
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, _ := conn.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
conn.Write([]byte("250-test.example.com\r\n250-STARTTLS\r\n250 AUTH PLAIN\r\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
n, _ = conn.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
conn.Write([]byte("221 2.0.0 Bye\r\n"))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
config := SMTPConfig{
|
||||||
|
Host: "127.0.0.1",
|
||||||
|
Port: port,
|
||||||
|
}
|
||||||
|
|
||||||
|
checker := NewSMTPChecker(config)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
result := checker.Check(ctx)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-serverDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Error("server timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Details["auth"] != "not configured" {
|
||||||
|
t.Errorf("expected auth not configured, got: %v", result.Details["auth"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSMTPConfig_GetAddress(t *testing.T) {
|
||||||
|
config := SMTPConfig{
|
||||||
|
Host: "smtp.example.com",
|
||||||
|
Port: 587,
|
||||||
|
}
|
||||||
|
|
||||||
|
address := getSMTPAddress(config)
|
||||||
|
expected := "smtp.example.com:587"
|
||||||
|
if address != expected {
|
||||||
|
t.Errorf("expected address %s, got %s", expected, address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSMTPAddress(config SMTPConfig) string {
|
||||||
|
return config.Host + ":" + strconv.Itoa(config.Port)
|
||||||
|
}
|
||||||
@@ -635,7 +635,7 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
|
|||||||
ctx.Suite.EmailSender.Reset()
|
ctx.Suite.EmailSender.Reset()
|
||||||
paginationUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "pagination_edge"), uniqueTestEmail(t, "pagination_edge"))
|
paginationUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "pagination_edge"), uniqueTestEmail(t, "pagination_edge"))
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
for i := range 5 {
|
||||||
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, paginationUser.User.ID, fmt.Sprintf("Pagination Post %d", i), fmt.Sprintf("https://example.com/pag%d", i))
|
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, paginationUser.User.ID, fmt.Sprintf("Pagination Post %d", i), fmt.Sprintf("https://example.com/pag%d", i))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ func TestIntegration_RateLimiting(t *testing.T) {
|
|||||||
rateLimitConfig.AuthLimit = 2
|
rateLimitConfig.AuthLimit = 2
|
||||||
router, _ := setupRateLimitRouter(t, rateLimitConfig)
|
router, _ := setupRateLimitRouter(t, rateLimitConfig)
|
||||||
|
|
||||||
for i := 0; i < 2; i++ {
|
for range 2 {
|
||||||
request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
|
request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
|
||||||
request.Header.Set("Content-Type", "application/json")
|
request.Header.Set("Content-Type", "application/json")
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -80,7 +80,7 @@ func TestIntegration_RateLimiting(t *testing.T) {
|
|||||||
rateLimitConfig.GeneralLimit = 5
|
rateLimitConfig.GeneralLimit = 5
|
||||||
router, _ := setupRateLimitRouter(t, rateLimitConfig)
|
router, _ := setupRateLimitRouter(t, rateLimitConfig)
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
for range 5 {
|
||||||
request := httptest.NewRequest("GET", "/api/posts", nil)
|
request := httptest.NewRequest("GET", "/api/posts", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
router.ServeHTTP(recorder, request)
|
router.ServeHTTP(recorder, request)
|
||||||
@@ -99,7 +99,7 @@ func TestIntegration_RateLimiting(t *testing.T) {
|
|||||||
rateLimitConfig.HealthLimit = 3
|
rateLimitConfig.HealthLimit = 3
|
||||||
router, _ := setupRateLimitRouter(t, rateLimitConfig)
|
router, _ := setupRateLimitRouter(t, rateLimitConfig)
|
||||||
|
|
||||||
for i := 0; i < 3; i++ {
|
for range 3 {
|
||||||
request := httptest.NewRequest("GET", "/health", nil)
|
request := httptest.NewRequest("GET", "/health", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
router.ServeHTTP(recorder, request)
|
router.ServeHTTP(recorder, request)
|
||||||
@@ -118,7 +118,7 @@ func TestIntegration_RateLimiting(t *testing.T) {
|
|||||||
rateLimitConfig.MetricsLimit = 2
|
rateLimitConfig.MetricsLimit = 2
|
||||||
router, _ := setupRateLimitRouter(t, rateLimitConfig)
|
router, _ := setupRateLimitRouter(t, rateLimitConfig)
|
||||||
|
|
||||||
for i := 0; i < 2; i++ {
|
for range 2 {
|
||||||
request := httptest.NewRequest("GET", "/metrics", nil)
|
request := httptest.NewRequest("GET", "/metrics", nil)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
router.ServeHTTP(recorder, request)
|
router.ServeHTTP(recorder, request)
|
||||||
@@ -138,7 +138,7 @@ func TestIntegration_RateLimiting(t *testing.T) {
|
|||||||
rateLimitConfig.GeneralLimit = 10
|
rateLimitConfig.GeneralLimit = 10
|
||||||
router, _ := setupRateLimitRouter(t, rateLimitConfig)
|
router, _ := setupRateLimitRouter(t, rateLimitConfig)
|
||||||
|
|
||||||
for i := 0; i < 2; i++ {
|
for range 2 {
|
||||||
request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
|
request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
|
||||||
request.Header.Set("Content-Type", "application/json")
|
request.Header.Set("Content-Type", "application/json")
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -165,7 +165,7 @@ func TestIntegration_RateLimiting(t *testing.T) {
|
|||||||
suite.EmailSender.Reset()
|
suite.EmailSender.Reset()
|
||||||
user := createAuthenticatedUser(t, authService, suite.UserRepo, uniqueTestUsername(t, "ratelimit_auth"), uniqueTestEmail(t, "ratelimit_auth"))
|
user := createAuthenticatedUser(t, authService, suite.UserRepo, uniqueTestUsername(t, "ratelimit_auth"), uniqueTestEmail(t, "ratelimit_auth"))
|
||||||
|
|
||||||
for i := 0; i < 3; i++ {
|
for range 3 {
|
||||||
request := httptest.NewRequest("GET", "/api/auth/me", nil)
|
request := httptest.NewRequest("GET", "/api/auth/me", nil)
|
||||||
request.Header.Set("Authorization", "Bearer "+user.Token)
|
request.Header.Set("Authorization", "Bearer "+user.Token)
|
||||||
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
|
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
|
||||||
|
|||||||
@@ -283,7 +283,7 @@ func TestIntegration_Repositories(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
voters := make([]*database.User, 5)
|
voters := make([]*database.User, 5)
|
||||||
for i := 0; i < 5; i++ {
|
for i := range 5 {
|
||||||
voter := &database.User{
|
voter := &database.User{
|
||||||
Username: fmt.Sprintf("voter_%d", i),
|
Username: fmt.Sprintf("voter_%d", i),
|
||||||
Email: fmt.Sprintf("voter%d@example.com", i),
|
Email: fmt.Sprintf("voter%d@example.com", i),
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package integration
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -63,13 +62,33 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
|
|||||||
t.Run("DBMonitoring_Active", func(t *testing.T) {
|
t.Run("DBMonitoring_Active", func(t *testing.T) {
|
||||||
request := makeGetRequest(t, router, "/health")
|
request := makeGetRequest(t, router, "/health")
|
||||||
|
|
||||||
var response map[string]any
|
response := assertJSONResponse(t, request, http.StatusOK)
|
||||||
if err := json.NewDecoder(request.Body).Decode(&response); err == nil {
|
if response == nil {
|
||||||
if data, ok := response["data"].(map[string]any); ok {
|
return
|
||||||
if _, exists := data["database_stats"]; !exists {
|
|
||||||
t.Error("Expected database_stats in health response")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data, ok := getDataFromResponse(response)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Expected data to be a map")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
services, ok := data["services"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Expected services in health response")
|
||||||
|
}
|
||||||
|
|
||||||
|
databaseService, ok := services["database"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Expected database service in health response")
|
||||||
|
}
|
||||||
|
|
||||||
|
details, ok := databaseService["details"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Expected database details in health response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := details["stats"]; !exists {
|
||||||
|
t.Error("Expected database stats in health response")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -283,16 +283,16 @@ func TestInMemoryCacheConcurrent(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("Concurrent writes", func(t *testing.T) {
|
t.Run("Concurrent writes", func(t *testing.T) {
|
||||||
done := make(chan bool, numGoroutines)
|
done := make(chan bool, numGoroutines)
|
||||||
for i := 0; i < numGoroutines; i++ {
|
for i := range numGoroutines {
|
||||||
go func(id int) {
|
go func(id int) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
t.Errorf("Goroutine %d panicked: %v", id, r)
|
t.Errorf("Goroutine %d panicked: %v", id, r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
for j := 0; j < numOps; j++ {
|
for j := range numOps {
|
||||||
entry := &CacheEntry{
|
entry := &CacheEntry{
|
||||||
Data: []byte(fmt.Sprintf("data-%d-%d", id, j)),
|
Data: fmt.Appendf(nil, "data-%d-%d", id, j),
|
||||||
Headers: make(http.Header),
|
Headers: make(http.Header),
|
||||||
Timestamp: time.Now(),
|
Timestamp: time.Now(),
|
||||||
TTL: 5 * time.Minute,
|
TTL: 5 * time.Minute,
|
||||||
@@ -306,16 +306,16 @@ func TestInMemoryCacheConcurrent(t *testing.T) {
|
|||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < numGoroutines; i++ {
|
for range numGoroutines {
|
||||||
<-done
|
<-done
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Concurrent reads and writes", func(t *testing.T) {
|
t.Run("Concurrent reads and writes", func(t *testing.T) {
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
for i := range 10 {
|
||||||
entry := &CacheEntry{
|
entry := &CacheEntry{
|
||||||
Data: []byte(fmt.Sprintf("data-%d", i)),
|
Data: fmt.Appendf(nil, "data-%d", i),
|
||||||
Headers: make(http.Header),
|
Headers: make(http.Header),
|
||||||
Timestamp: time.Now(),
|
Timestamp: time.Now(),
|
||||||
TTL: 5 * time.Minute,
|
TTL: 5 * time.Minute,
|
||||||
@@ -325,16 +325,16 @@ func TestInMemoryCacheConcurrent(t *testing.T) {
|
|||||||
|
|
||||||
done := make(chan bool, numGoroutines*2)
|
done := make(chan bool, numGoroutines*2)
|
||||||
|
|
||||||
for i := 0; i < numGoroutines; i++ {
|
for i := range numGoroutines {
|
||||||
go func(id int) {
|
go func(id int) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
t.Errorf("Writer goroutine %d panicked: %v", id, r)
|
t.Errorf("Writer goroutine %d panicked: %v", id, r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
for j := 0; j < numOps; j++ {
|
for j := range numOps {
|
||||||
entry := &CacheEntry{
|
entry := &CacheEntry{
|
||||||
Data: []byte(fmt.Sprintf("write-%d-%d", id, j)),
|
Data: fmt.Appendf(nil, "write-%d-%d", id, j),
|
||||||
Headers: make(http.Header),
|
Headers: make(http.Header),
|
||||||
Timestamp: time.Now(),
|
Timestamp: time.Now(),
|
||||||
TTL: 5 * time.Minute,
|
TTL: 5 * time.Minute,
|
||||||
@@ -346,14 +346,14 @@ func TestInMemoryCacheConcurrent(t *testing.T) {
|
|||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < numGoroutines; i++ {
|
for i := range numGoroutines {
|
||||||
go func(id int) {
|
go func(id int) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
t.Errorf("Reader goroutine %d panicked: %v", id, r)
|
t.Errorf("Reader goroutine %d panicked: %v", id, r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
for j := 0; j < numOps; j++ {
|
for j := range numOps {
|
||||||
key := fmt.Sprintf("key-%d", j%10)
|
key := fmt.Sprintf("key-%d", j%10)
|
||||||
cache.Get(key)
|
cache.Get(key)
|
||||||
}
|
}
|
||||||
@@ -368,9 +368,9 @@ func TestInMemoryCacheConcurrent(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("Concurrent deletes", func(t *testing.T) {
|
t.Run("Concurrent deletes", func(t *testing.T) {
|
||||||
|
|
||||||
for i := 0; i < numGoroutines; i++ {
|
for i := range numGoroutines {
|
||||||
entry := &CacheEntry{
|
entry := &CacheEntry{
|
||||||
Data: []byte(fmt.Sprintf("data-%d", i)),
|
Data: fmt.Appendf(nil, "data-%d", i),
|
||||||
Headers: make(http.Header),
|
Headers: make(http.Header),
|
||||||
Timestamp: time.Now(),
|
Timestamp: time.Now(),
|
||||||
TTL: 5 * time.Minute,
|
TTL: 5 * time.Minute,
|
||||||
@@ -379,7 +379,7 @@ func TestInMemoryCacheConcurrent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
done := make(chan bool, numGoroutines)
|
done := make(chan bool, numGoroutines)
|
||||||
for i := 0; i < numGoroutines; i++ {
|
for i := range numGoroutines {
|
||||||
go func(id int) {
|
go func(id int) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
@@ -391,7 +391,7 @@ func TestInMemoryCacheConcurrent(t *testing.T) {
|
|||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < numGoroutines; i++ {
|
for range numGoroutines {
|
||||||
<-done
|
<-done
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -26,12 +26,7 @@ func NewCORSConfig() *CORSConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch env {
|
switch env {
|
||||||
case "production":
|
case "production", "staging":
|
||||||
if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins == "" {
|
|
||||||
config.AllowedOrigins = []string{}
|
|
||||||
}
|
|
||||||
config.AllowCredentials = true
|
|
||||||
case "staging":
|
|
||||||
if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins == "" {
|
if origins := os.Getenv("CORS_ALLOWED_ORIGINS"); origins == "" {
|
||||||
config.AllowedOrigins = []string{}
|
config.AllowedOrigins = []string{}
|
||||||
}
|
}
|
||||||
@@ -53,82 +48,66 @@ func NewCORSConfig() *CORSConfig {
|
|||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkOrigin(origin string, allowedOrigins []string) (allowed bool, hasWildcard bool) {
|
||||||
|
for _, allowedOrigin := range allowedOrigins {
|
||||||
|
if allowedOrigin == "*" {
|
||||||
|
hasWildcard = true
|
||||||
|
allowed = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if allowedOrigin == origin {
|
||||||
|
allowed = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func setCORSHeaders(w http.ResponseWriter, origin string, hasWildcard bool, config *CORSConfig) {
|
||||||
|
if hasWildcard && !config.AllowCredentials {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
|
} else {
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.AllowCredentials && !hasWildcard {
|
||||||
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func CORSWithConfig(config *CORSConfig) func(http.Handler) http.Handler {
|
func CORSWithConfig(config *CORSConfig) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
origin := r.Header.Get("Origin")
|
origin := r.Header.Get("Origin")
|
||||||
|
|
||||||
|
if origin == "" {
|
||||||
if r.Method == "OPTIONS" {
|
if r.Method == "OPTIONS" {
|
||||||
if origin != "" {
|
w.WriteHeader(http.StatusOK)
|
||||||
allowed := false
|
return
|
||||||
hasWildcard := false
|
|
||||||
for _, allowedOrigin := range config.AllowedOrigins {
|
|
||||||
if allowedOrigin == "*" {
|
|
||||||
hasWildcard = true
|
|
||||||
allowed = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if allowedOrigin == origin {
|
|
||||||
allowed = true
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
allowed, hasWildcard := checkOrigin(origin, config.AllowedOrigins)
|
||||||
|
|
||||||
if !allowed {
|
if !allowed {
|
||||||
http.Error(w, "Origin not allowed", http.StatusForbidden)
|
http.Error(w, "Origin not allowed", http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasWildcard && !config.AllowCredentials {
|
if r.Method == "OPTIONS" {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
setCORSHeaders(w, origin, hasWildcard, config)
|
||||||
} else {
|
|
||||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", "))
|
w.Header().Set("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", "))
|
||||||
w.Header().Set("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
|
w.Header().Set("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
|
||||||
w.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", config.MaxAge))
|
w.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", config.MaxAge))
|
||||||
|
|
||||||
if config.AllowCredentials && !hasWildcard {
|
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if origin != "" {
|
setCORSHeaders(w, origin, hasWildcard, config)
|
||||||
allowed := false
|
|
||||||
hasWildcard := false
|
|
||||||
for _, allowedOrigin := range config.AllowedOrigins {
|
|
||||||
if allowedOrigin == "*" {
|
|
||||||
hasWildcard = true
|
|
||||||
allowed = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if allowedOrigin == origin {
|
|
||||||
allowed = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !allowed {
|
|
||||||
http.Error(w, "Origin not allowed", http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if hasWildcard && !config.AllowCredentials {
|
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
||||||
} else {
|
|
||||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.AllowCredentials && !hasWildcard {
|
|
||||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -318,7 +318,7 @@ func TestConcurrentAccess(t *testing.T) {
|
|||||||
collector := NewMetricsCollector(monitor)
|
collector := NewMetricsCollector(monitor)
|
||||||
|
|
||||||
done := make(chan bool, 10)
|
done := make(chan bool, 10)
|
||||||
for i := 0; i < 10; i++ {
|
for range 10 {
|
||||||
go func() {
|
go func() {
|
||||||
monitor.LogQuery("SELECT * FROM users", 50*time.Millisecond, nil)
|
monitor.LogQuery("SELECT * FROM users", 50*time.Millisecond, nil)
|
||||||
collector.RecordRequest(100*time.Millisecond, false)
|
collector.RecordRequest(100*time.Millisecond, false)
|
||||||
@@ -326,7 +326,7 @@ func TestConcurrentAccess(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
for range 10 {
|
||||||
<-done
|
<-done
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -384,7 +384,7 @@ func TestThreadSafety(t *testing.T) {
|
|||||||
numGoroutines := 100
|
numGoroutines := 100
|
||||||
done := make(chan bool, numGoroutines)
|
done := make(chan bool, numGoroutines)
|
||||||
|
|
||||||
for i := 0; i < numGoroutines; i++ {
|
for i := range numGoroutines {
|
||||||
go func(id int) {
|
go func(id int) {
|
||||||
|
|
||||||
if id%2 == 0 {
|
if id%2 == 0 {
|
||||||
@@ -398,7 +398,7 @@ func TestThreadSafety(t *testing.T) {
|
|||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < numGoroutines; i++ {
|
for range numGoroutines {
|
||||||
<-done
|
<-done
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -388,7 +388,7 @@ func TestRateLimiterMaxKeys(t *testing.T) {
|
|||||||
limiter := NewRateLimiterWithConfig(1*time.Minute, 10, 5, 1*time.Minute, 2*time.Minute)
|
limiter := NewRateLimiterWithConfig(1*time.Minute, 10, 5, 1*time.Minute, 2*time.Minute)
|
||||||
defer limiter.StopCleanup()
|
defer limiter.StopCleanup()
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
for i := range 5 {
|
||||||
key := fmt.Sprintf("key-%d", i)
|
key := fmt.Sprintf("key-%d", i)
|
||||||
if !limiter.Allow(key) {
|
if !limiter.Allow(key) {
|
||||||
t.Errorf("Key %s should be allowed", key)
|
t.Errorf("Key %s should be allowed", key)
|
||||||
@@ -435,7 +435,7 @@ func TestRateLimiterRegistry(t *testing.T) {
|
|||||||
request := httptest.NewRequest("GET", "/test", nil)
|
request := httptest.NewRequest("GET", "/test", nil)
|
||||||
request.RemoteAddr = "127.0.0.1:12345"
|
request.RemoteAddr = "127.0.0.1:12345"
|
||||||
|
|
||||||
for i := 0; i < 50; i++ {
|
for i := range 50 {
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
server1.ServeHTTP(recorder, request)
|
server1.ServeHTTP(recorder, request)
|
||||||
if recorder.Code != http.StatusOK {
|
if recorder.Code != http.StatusOK {
|
||||||
@@ -443,7 +443,7 @@ func TestRateLimiterRegistry(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < 50; i++ {
|
for i := range 50 {
|
||||||
recorder2 := httptest.NewRecorder()
|
recorder2 := httptest.NewRecorder()
|
||||||
server2.ServeHTTP(recorder2, request)
|
server2.ServeHTTP(recorder2, request)
|
||||||
if recorder2.Code != http.StatusOK {
|
if recorder2.Code != http.StatusOK {
|
||||||
@@ -463,7 +463,7 @@ func TestRateLimiterRegistry(t *testing.T) {
|
|||||||
t.Error("101st request to server2 should be rejected (shared limiter reached limit)")
|
t.Error("101st request to server2 should be rejected (shared limiter reached limit)")
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < 50; i++ {
|
for i := range 50 {
|
||||||
recorder3 := httptest.NewRecorder()
|
recorder3 := httptest.NewRecorder()
|
||||||
server3.ServeHTTP(recorder3, request)
|
server3.ServeHTTP(recorder3, request)
|
||||||
if recorder3.Code != http.StatusOK {
|
if recorder3.Code != http.StatusOK {
|
||||||
|
|||||||
@@ -458,13 +458,13 @@ func TestIsRapidRequest(t *testing.T) {
|
|||||||
|
|
||||||
ip := "192.168.1.1"
|
ip := "192.168.1.1"
|
||||||
|
|
||||||
for i := 0; i < 50; i++ {
|
for i := range 50 {
|
||||||
if isRapidRequest(ip) {
|
if isRapidRequest(ip) {
|
||||||
t.Errorf("Request %d should not be considered rapid", i+1)
|
t.Errorf("Request %d should not be considered rapid", i+1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < 110; i++ {
|
for i := range 110 {
|
||||||
result := isRapidRequest(ip)
|
result := isRapidRequest(ip)
|
||||||
if i < 50 {
|
if i < 50 {
|
||||||
if result {
|
if result {
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ func TestValidationMiddleware(t *testing.T) {
|
|||||||
body, _ := json.Marshal(user)
|
body, _ := json.Marshal(user)
|
||||||
request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body))
|
request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body))
|
||||||
request.Header.Set("Content-Type", "application/json")
|
request.Header.Set("Content-Type", "application/json")
|
||||||
ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeOf(TestUser{}))
|
ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeFor[TestUser]())
|
||||||
request = request.WithContext(ctx)
|
request = request.WithContext(ctx)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
@@ -70,7 +70,7 @@ func TestValidationMiddleware(t *testing.T) {
|
|||||||
body, _ := json.Marshal(user)
|
body, _ := json.Marshal(user)
|
||||||
request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body))
|
request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body))
|
||||||
request.Header.Set("Content-Type", "application/json")
|
request.Header.Set("Content-Type", "application/json")
|
||||||
ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeOf(TestUser{}))
|
ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeFor[TestUser]())
|
||||||
request = request.WithContext(ctx)
|
request = request.WithContext(ctx)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
@@ -117,7 +117,7 @@ func TestValidationMiddleware(t *testing.T) {
|
|||||||
body, _ := json.Marshal(user)
|
body, _ := json.Marshal(user)
|
||||||
request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body))
|
request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body))
|
||||||
request.Header.Set("Content-Type", "application/json")
|
request.Header.Set("Content-Type", "application/json")
|
||||||
ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeOf(TestUser{}))
|
ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeFor[TestUser]())
|
||||||
request = request.WithContext(ctx)
|
request = request.WithContext(ctx)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
@@ -144,7 +144,7 @@ func TestValidationMiddleware(t *testing.T) {
|
|||||||
body, _ := json.Marshal(user)
|
body, _ := json.Marshal(user)
|
||||||
request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body))
|
request := httptest.NewRequest("POST", "/users", bytes.NewBuffer(body))
|
||||||
request.Header.Set("Content-Type", "application/json")
|
request.Header.Set("Content-Type", "application/json")
|
||||||
ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeOf(TestUser{}))
|
ctx := context.WithValue(request.Context(), DTOTypeKey, reflect.TypeFor[TestUser]())
|
||||||
request = request.WithContext(ctx)
|
request = request.WithContext(ctx)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
|||||||
@@ -156,7 +156,7 @@ func TestPostRepository_GetAll(t *testing.T) {
|
|||||||
|
|
||||||
user := suite.CreateTestUser("testuser2", "test2@example.com", "password123")
|
user := suite.CreateTestUser("testuser2", "test2@example.com", "password123")
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
for i := range 5 {
|
||||||
suite.CreateTestPost(user.ID,
|
suite.CreateTestPost(user.ID,
|
||||||
"Post "+strconv.Itoa(i),
|
"Post "+strconv.Itoa(i),
|
||||||
"https://example.com/"+strconv.Itoa(i),
|
"https://example.com/"+strconv.Itoa(i),
|
||||||
@@ -178,7 +178,7 @@ func TestPostRepository_GetAll(t *testing.T) {
|
|||||||
|
|
||||||
user := suite.CreateTestUser("testuser3", "test3@example.com", "password123")
|
user := suite.CreateTestUser("testuser3", "test3@example.com", "password123")
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
for i := range 5 {
|
||||||
suite.CreateTestPost(user.ID,
|
suite.CreateTestPost(user.ID,
|
||||||
"Post "+strconv.Itoa(i),
|
"Post "+strconv.Itoa(i),
|
||||||
"https://example.com/"+strconv.Itoa(i),
|
"https://example.com/"+strconv.Itoa(i),
|
||||||
@@ -328,7 +328,7 @@ func TestPostRepository_Count(t *testing.T) {
|
|||||||
|
|
||||||
user := suite.CreateTestUser("testuser", "test@example.com", "password123")
|
user := suite.CreateTestUser("testuser", "test@example.com", "password123")
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
for i := range 5 {
|
||||||
suite.CreateTestPost(user.ID,
|
suite.CreateTestPost(user.ID,
|
||||||
"Post "+strconv.Itoa(i),
|
"Post "+strconv.Itoa(i),
|
||||||
"https://example.com/"+strconv.Itoa(i),
|
"https://example.com/"+strconv.Itoa(i),
|
||||||
@@ -506,7 +506,7 @@ func TestPostRepository_Search(t *testing.T) {
|
|||||||
|
|
||||||
user := suite.CreateTestUser("testuser2", "test2@example.com", "password123")
|
user := suite.CreateTestUser("testuser2", "test2@example.com", "password123")
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
for i := range 5 {
|
||||||
suite.CreateTestPost(user.ID,
|
suite.CreateTestPost(user.ID,
|
||||||
"Go Post "+strconv.Itoa(i),
|
"Go Post "+strconv.Itoa(i),
|
||||||
"https://example.com/go"+strconv.Itoa(i),
|
"https://example.com/go"+strconv.Itoa(i),
|
||||||
|
|||||||
@@ -621,7 +621,7 @@ func TestUserRepository_GetPosts(t *testing.T) {
|
|||||||
|
|
||||||
user := suite.CreateTestUser("pagination", "pagination@example.com", "password123")
|
user := suite.CreateTestUser("pagination", "pagination@example.com", "password123")
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
for i := range 5 {
|
||||||
suite.CreateTestPost(user.ID,
|
suite.CreateTestPost(user.ID,
|
||||||
"Post "+strconv.Itoa(i),
|
"Post "+strconv.Itoa(i),
|
||||||
"https://example.com/"+strconv.Itoa(i),
|
"https://example.com/"+strconv.Itoa(i),
|
||||||
@@ -1089,7 +1089,7 @@ func TestUserRepository_ConcurrentAccess(t *testing.T) {
|
|||||||
done := make(chan bool, 10)
|
done := make(chan bool, 10)
|
||||||
errors := make(chan error, 10)
|
errors := make(chan error, 10)
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
for i := range 10 {
|
||||||
go func(id int) {
|
go func(id int) {
|
||||||
defer func() { done <- true }()
|
defer func() { done <- true }()
|
||||||
user := &database.User{
|
user := &database.User{
|
||||||
@@ -1099,7 +1099,7 @@ func TestUserRepository_ConcurrentAccess(t *testing.T) {
|
|||||||
EmailVerified: true,
|
EmailVerified: true,
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
for retries := 0; retries < 5; retries++ {
|
for retries := range 5 {
|
||||||
err = suite.UserRepo.Create(user)
|
err = suite.UserRepo.Create(user)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
@@ -1116,7 +1116,7 @@ func TestUserRepository_ConcurrentAccess(t *testing.T) {
|
|||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
for range 10 {
|
||||||
<-done
|
<-done
|
||||||
}
|
}
|
||||||
close(errors)
|
close(errors)
|
||||||
@@ -1142,7 +1142,7 @@ func TestUserRepository_ConcurrentAccess(t *testing.T) {
|
|||||||
user := suite.CreateTestUser("concurrent_update", "update@example.com", "password123")
|
user := suite.CreateTestUser("concurrent_update", "update@example.com", "password123")
|
||||||
|
|
||||||
done := make(chan bool, 5)
|
done := make(chan bool, 5)
|
||||||
for i := 0; i < 5; i++ {
|
for i := range 5 {
|
||||||
go func(id int) {
|
go func(id int) {
|
||||||
defer func() { done <- true }()
|
defer func() { done <- true }()
|
||||||
user.Username = fmt.Sprintf("updated%d", id)
|
user.Username = fmt.Sprintf("updated%d", id)
|
||||||
@@ -1150,7 +1150,7 @@ func TestUserRepository_ConcurrentAccess(t *testing.T) {
|
|||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < 5; i++ {
|
for range 5 {
|
||||||
<-done
|
<-done
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1172,7 +1172,7 @@ func TestUserRepository_ConcurrentAccess(t *testing.T) {
|
|||||||
done := make(chan bool, 2)
|
done := make(chan bool, 2)
|
||||||
go func() {
|
go func() {
|
||||||
var err error
|
var err error
|
||||||
for retries := 0; retries < 5; retries++ {
|
for retries := range 5 {
|
||||||
err = suite.UserRepo.Delete(user1.ID)
|
err = suite.UserRepo.Delete(user1.ID)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
@@ -1187,7 +1187,7 @@ func TestUserRepository_ConcurrentAccess(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
var err error
|
var err error
|
||||||
for retries := 0; retries < 5; retries++ {
|
for retries := range 5 {
|
||||||
err = suite.UserRepo.Delete(user2.ID)
|
err = suite.UserRepo.Delete(user2.ID)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
@@ -1210,7 +1210,7 @@ func TestUserRepository_ConcurrentAccess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
if count > 0 {
|
if count > 0 {
|
||||||
var err error
|
var err error
|
||||||
for retries := 0; retries < 5; retries++ {
|
for retries := range 5 {
|
||||||
count, err = suite.UserRepo.Count()
|
count, err = suite.UserRepo.Count()
|
||||||
if err == nil && count == 0 {
|
if err == nil && count == 0 {
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ type RouterConfig struct {
|
|||||||
Debug bool
|
Debug bool
|
||||||
DisableCache bool
|
DisableCache bool
|
||||||
DisableCompression bool
|
DisableCompression bool
|
||||||
|
CacheablePaths []string
|
||||||
DBMonitor middleware.DBMonitor
|
DBMonitor middleware.DBMonitor
|
||||||
RateLimitConfig config.RateLimitConfig
|
RateLimitConfig config.RateLimitConfig
|
||||||
}
|
}
|
||||||
@@ -49,6 +50,9 @@ func NewRouter(cfg RouterConfig) http.Handler {
|
|||||||
if !cfg.DisableCache {
|
if !cfg.DisableCache {
|
||||||
cache := middleware.NewInMemoryCache()
|
cache := middleware.NewInMemoryCache()
|
||||||
cacheConfig := middleware.DefaultCacheConfig()
|
cacheConfig := middleware.DefaultCacheConfig()
|
||||||
|
if len(cfg.CacheablePaths) > 0 {
|
||||||
|
cacheConfig.CacheablePaths = append([]string{}, cfg.CacheablePaths...)
|
||||||
|
}
|
||||||
router.Use(middleware.CacheMiddleware(cache, cacheConfig))
|
router.Use(middleware.CacheMiddleware(cache, cacheConfig))
|
||||||
router.Use(middleware.CacheInvalidationMiddleware(cache))
|
router.Use(middleware.CacheInvalidationMiddleware(cache))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -77,7 +77,8 @@ func setupTestHandlers() (*handlers.AuthHandler, *handlers.PostHandler, *handler
|
|||||||
emailSender := &testutils.MockEmailSender{}
|
emailSender := &testutils.MockEmailSender{}
|
||||||
|
|
||||||
voteService := services.NewVoteService(voteRepo, postRepo, nil)
|
voteService := services.NewVoteService(voteRepo, postRepo, nil)
|
||||||
metadataService := services.NewURLMetadataService()
|
titleFetcher := &testutils.MockTitleFetcher{}
|
||||||
|
titleFetcher.SetTitle("Example Domain")
|
||||||
|
|
||||||
mockRefreshRepo := &mockRefreshTokenRepository{}
|
mockRefreshRepo := &mockRefreshTokenRepository{}
|
||||||
mockDeletionRepo := &mockAccountDeletionRepository{}
|
mockDeletionRepo := &mockAccountDeletionRepository{}
|
||||||
@@ -94,7 +95,7 @@ func setupTestHandlers() (*handlers.AuthHandler, *handlers.PostHandler, *handler
|
|||||||
}
|
}
|
||||||
|
|
||||||
authHandler := handlers.NewAuthHandler(authFacade, userRepo)
|
authHandler := handlers.NewAuthHandler(authFacade, userRepo)
|
||||||
postHandler := handlers.NewPostHandler(postRepo, metadataService, voteService)
|
postHandler := handlers.NewPostHandler(postRepo, titleFetcher, voteService)
|
||||||
voteHandler := handlers.NewVoteHandler(voteService)
|
voteHandler := handlers.NewVoteHandler(voteService)
|
||||||
userHandler := handlers.NewUserHandler(userRepo, authFacade)
|
userHandler := handlers.NewUserHandler(userRepo, authFacade)
|
||||||
apiHandler := handlers.NewAPIHandler(testutils.AppTestConfig, postRepo, userRepo, voteService)
|
apiHandler := handlers.NewAPIHandler(testutils.AppTestConfig, postRepo, userRepo, voteService)
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ func TestEmailService_Performance(t *testing.T) {
|
|||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
iterations := 1000
|
iterations := 1000
|
||||||
for i := 0; i < iterations; i++ {
|
for range iterations {
|
||||||
service.GenerateVerificationEmailBody(user.Username, "https://example.com/confirm?token=test")
|
service.GenerateVerificationEmailBody(user.Username, "https://example.com/confirm?token=test")
|
||||||
}
|
}
|
||||||
duration := time.Since(start)
|
duration := time.Since(start)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -339,11 +340,9 @@ func (s *URLMetadataService) validateURLForSSRF(u *url.URL) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrSSRFBlocked
|
return ErrSSRFBlocked
|
||||||
}
|
}
|
||||||
for _, ip := range ips {
|
if slices.ContainsFunc(ips, isPrivateOrReservedIP) {
|
||||||
if isPrivateOrReservedIP(ip) {
|
|
||||||
return ErrSSRFBlocked
|
return ErrSSRFBlocked
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -359,13 +358,7 @@ func isLocalhost(hostname string) bool {
|
|||||||
"0:0:0:0:0:0:0:0",
|
"0:0:0:0:0:0:0:0",
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, name := range localhostNames {
|
return slices.Contains(localhostNames, hostname)
|
||||||
if hostname == name {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func isPrivateOrReservedIP(ip net.IP) bool {
|
func isPrivateOrReservedIP(ip net.IP) bool {
|
||||||
|
|||||||
@@ -285,7 +285,7 @@ func (b *VoteRequestBuilder) Build() VoteRequest {
|
|||||||
|
|
||||||
func (f *TestDataFactory) CreateTestUsers(count int) []*database.User {
|
func (f *TestDataFactory) CreateTestUsers(count int) []*database.User {
|
||||||
users := make([]*database.User, count)
|
users := make([]*database.User, count)
|
||||||
for i := 0; i < count; i++ {
|
for i := range count {
|
||||||
users[i] = f.NewUserBuilder().
|
users[i] = f.NewUserBuilder().
|
||||||
WithID(uint(i + 1)).
|
WithID(uint(i + 1)).
|
||||||
WithUsername(fmt.Sprintf("user%d", i+1)).
|
WithUsername(fmt.Sprintf("user%d", i+1)).
|
||||||
@@ -297,7 +297,7 @@ func (f *TestDataFactory) CreateTestUsers(count int) []*database.User {
|
|||||||
|
|
||||||
func (f *TestDataFactory) CreateTestPosts(count int) []*database.Post {
|
func (f *TestDataFactory) CreateTestPosts(count int) []*database.Post {
|
||||||
posts := make([]*database.Post, count)
|
posts := make([]*database.Post, count)
|
||||||
for i := 0; i < count; i++ {
|
for i := range count {
|
||||||
posts[i] = f.NewPostBuilder().
|
posts[i] = f.NewPostBuilder().
|
||||||
WithID(uint(i+1)).
|
WithID(uint(i+1)).
|
||||||
WithTitle(fmt.Sprintf("Post %d", i+1)).
|
WithTitle(fmt.Sprintf("Post %d", i+1)).
|
||||||
@@ -332,7 +332,7 @@ func (f *TestDataFactory) CreateTestVotes(count int) []*database.Vote {
|
|||||||
|
|
||||||
func (f *TestDataFactory) CreateTestAuthResults(count int) []*AuthResult {
|
func (f *TestDataFactory) CreateTestAuthResults(count int) []*AuthResult {
|
||||||
results := make([]*AuthResult, count)
|
results := make([]*AuthResult, count)
|
||||||
for i := 0; i < count; i++ {
|
for i := range count {
|
||||||
results[i] = f.NewAuthResultBuilder().
|
results[i] = f.NewAuthResultBuilder().
|
||||||
WithUser(f.NewUserBuilder().
|
WithUser(f.NewUserBuilder().
|
||||||
WithID(uint(i + 1)).
|
WithID(uint(i + 1)).
|
||||||
@@ -346,7 +346,7 @@ func (f *TestDataFactory) CreateTestAuthResults(count int) []*AuthResult {
|
|||||||
|
|
||||||
func (f *TestDataFactory) CreateTestVoteRequests(count int) []VoteRequest {
|
func (f *TestDataFactory) CreateTestVoteRequests(count int) []VoteRequest {
|
||||||
requests := make([]VoteRequest, count)
|
requests := make([]VoteRequest, count)
|
||||||
for i := 0; i < count; i++ {
|
for i := range count {
|
||||||
voteType := database.VoteUp
|
voteType := database.VoteUp
|
||||||
if i%3 == 0 {
|
if i%3 == 0 {
|
||||||
voteType = database.VoteDown
|
voteType = database.VoteDown
|
||||||
@@ -365,8 +365,9 @@ func (f *TestDataFactory) CreateTestVoteRequests(count int) []VoteRequest {
|
|||||||
return requests
|
return requests
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//go:fix inline
|
||||||
func uintPtr(u uint) *uint {
|
func uintPtr(u uint) *uint {
|
||||||
return &u
|
return new(u)
|
||||||
}
|
}
|
||||||
|
|
||||||
type E2ETestDataFactory struct {
|
type E2ETestDataFactory struct {
|
||||||
@@ -450,7 +451,7 @@ func (f *E2ETestDataFactory) CreateMultipleUsers(t *testing.T, count int, userna
|
|||||||
var users []*TestUser
|
var users []*TestUser
|
||||||
timestamp := time.Now().UnixNano()
|
timestamp := time.Now().UnixNano()
|
||||||
|
|
||||||
for i := 0; i < count; i++ {
|
for i := range count {
|
||||||
uniqueID := timestamp + int64(i)
|
uniqueID := timestamp + int64(i)
|
||||||
username := fmt.Sprintf("%s%d", usernamePrefix, uniqueID)
|
username := fmt.Sprintf("%s%d", usernamePrefix, uniqueID)
|
||||||
email := fmt.Sprintf("%s%d@example.com", emailPrefix, uniqueID)
|
email := fmt.Sprintf("%s%d@example.com", emailPrefix, uniqueID)
|
||||||
|
|||||||
@@ -123,12 +123,12 @@ func defaultIfEmpty(value, fallback string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func extractTokenFromBody(body string) string {
|
func extractTokenFromBody(body string) string {
|
||||||
index := strings.Index(body, "token=")
|
_, after, ok := strings.Cut(body, "token=")
|
||||||
if index == -1 {
|
if !ok {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenPart := body[index+len("token="):]
|
tokenPart := after
|
||||||
|
|
||||||
if delimIdx := strings.IndexAny(tokenPart, "&\"'\\\r\n <>"); delimIdx != -1 {
|
if delimIdx := strings.IndexAny(tokenPart, "&\"'\\\r\n <>"); delimIdx != -1 {
|
||||||
tokenPart = tokenPart[:delimIdx]
|
tokenPart = tokenPart[:delimIdx]
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ func getErrorField(resp *APIResponse) (string, bool) {
|
|||||||
if resp == nil {
|
if resp == nil {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
if dataMap, ok := resp.Data.(map[string]interface{}); ok {
|
if dataMap, ok := resp.Data.(map[string]any); ok {
|
||||||
if errorVal, ok := dataMap["error"].(string); ok {
|
if errorVal, ok := dataMap["error"].(string); ok {
|
||||||
return errorVal, true
|
return errorVal, true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ func camelCaseToWords(s string) string {
|
|||||||
return result.String()
|
return result.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidateStruct(s interface{}) error {
|
func ValidateStruct(s any) error {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -210,7 +210,7 @@ func ValidateStruct(s interface{}) error {
|
|||||||
val := reflect.ValueOf(s)
|
val := reflect.ValueOf(s)
|
||||||
typ := reflect.TypeOf(s)
|
typ := reflect.TypeOf(s)
|
||||||
|
|
||||||
if val.Kind() == reflect.Ptr {
|
if val.Kind() == reflect.Pointer {
|
||||||
if val.IsNil() {
|
if val.IsNil() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -252,9 +252,9 @@ func ValidateStruct(s interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var tagName, param string
|
var tagName, param string
|
||||||
if idx := strings.Index(tag, "="); idx != -1 {
|
if before, after, ok := strings.Cut(tag, "="); ok {
|
||||||
tagName = tag[:idx]
|
tagName = before
|
||||||
param = tag[idx+1:]
|
param = after
|
||||||
} else {
|
} else {
|
||||||
tagName = tag
|
tagName = tag
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
# screenshot
|
||||||
|
|
||||||
In this folder, you will find screenshots of the app.
|
In this folder, you will find screenshots of the app.
|
||||||
|
|
||||||
Two kinds of screenshot here:
|
Two kinds of screenshot here:
|
||||||
|
|||||||
@@ -6,13 +6,13 @@ if [ "$EUID" -ne 0 ]; then
|
|||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
read -s "Do you want to install PostgreSQL 18? [y/N] " INSTALL_PG
|
read -rp "Do you want to install PostgreSQL 18? [y/N] " INSTALL_PG
|
||||||
if [ "$INSTALL_PG" != "y" ]; then
|
if [ "$INSTALL_PG" != "y" ]; then
|
||||||
echo "PostgreSQL 18 will not be installed"
|
echo "PostgreSQL 18 will not be installed"
|
||||||
exit 0
|
exit 0
|
||||||
fi
|
fi
|
||||||
|
|
||||||
read -s -p "Enter password for PostgreSQL user 'goyco': " GOYCO_PWD
|
read -rsp "Enter password for PostgreSQL user 'goyco': " GOYCO_PWD
|
||||||
echo
|
echo
|
||||||
|
|
||||||
apt-get update
|
apt-get update
|
||||||
|
|||||||
Reference in New Issue
Block a user