Compare commits

..

59 Commits

Author SHA1 Message Date
07c6b89525 docs: remove project structure, boring and hard to maintain 2025-12-20 17:35:16 +01:00
817205d42f refactor: modernize using min() 2025-12-16 15:45:51 +01:00
199ac143a4 refactor: replace interface{} by any 2025-12-16 15:05:23 +01:00
aa7e259ed0 format: shfmt 2025-12-16 15:02:42 +01:00
4587609e17 refactor: create createTestRouter and test edge cases 2025-12-14 21:14:42 +01:00
33da6503e3 test: also test put/delete routes 2025-12-14 21:06:15 +01:00
cafc44ed77 test: add a test for route parameters 2025-12-14 21:04:36 +01:00
1480135e75 test: verified all routes to exist 2025-12-14 21:02:25 +01:00
02a764c736 clean: remove merged files 2025-12-14 20:52:14 +01:00
6834ad7764 refactor: merge facade, types and utils into one auth_service.go 2025-12-14 20:52:03 +01:00
dcf054046f test: fix parallel processor test expectations and setup 2025-12-10 07:30:27 +01:00
d2a788933d fix: track completed items in the main loop instead of using the index 2025-12-10 07:29:55 +01:00
18be3950dc clean: obsolete function 2025-12-09 22:07:30 +01:00
f9cb140e95 clean: removed the obsolete functions outputMessage and outputError 2025-12-09 22:06:12 +01:00
86d4835ccf feat: seed user is now uniq 2025-12-09 22:03:26 +01:00
feddb2ed43 test: new unit test for EnsureSeedUser 2025-12-09 22:03:16 +01:00
457b5c88e2 refactor: improve seed consistency validation 2025-12-09 21:53:03 +01:00
a8d363b2bf fix: templates now parse with the same func map as the page handler 2025-12-09 21:37:21 +01:00
0cd68e847c refactor: add a helper to centralize CSRF token retrieval 2025-12-09 15:58:28 +01:00
df6aeed713 docs: unocss it is 2025-12-04 20:43:30 +01:00
785faeb60c feat: update alpine to 3.23 2025-12-04 10:01:19 +01:00
0623c027ba docs: prepare CONTRIBUTING.md 2025-12-03 20:57:32 +01:00
d4e91b6034 refactor: complete refactor and better helpers use 2025-11-29 15:19:41 +01:00
7d46d3e81b clean: remove the unused expectedValue in assertHeader (always set to "") 2025-11-29 15:19:28 +01:00
216aaf3117 refactor: clean code and use new request helpers 2025-11-29 14:58:52 +01:00
435047ad0c refactor: clean code 2025-11-29 14:58:37 +01:00
b7ee8bd11d refactor: clean variable names and use new request helpers 2025-11-29 14:58:20 +01:00
040cd48be8 refactor: clean variables 2025-11-29 14:58:07 +01:00
2dd16e0e00 refactor: complete 2025-11-29 14:56:18 +01:00
d6db70cc79 refactor: clean code and variables, use new request helpers 2025-11-29 14:55:47 +01:00
58e10ade7d refactor: clean variable names and modernize code 2025-11-29 14:50:35 +01:00
7403a75d8e refactor: clean variable naming 2025-11-29 14:46:26 +01:00
b429bc11af refactor: clean code and use new request helpers 2025-11-29 14:41:38 +01:00
2ec5c28fb5 refactor: rename variables and clean code 2025-11-29 14:37:18 +01:00
3743a99e40 refactor: req -> request, rec -> recorder, reqBody -> requestBody... 2025-11-29 14:21:07 +01:00
5710921b87 refactor: use new request helpers 2025-11-29 14:17:25 +01:00
84d9c81484 refactor: rec -> recorder, req -> request and modernize loop 2025-11-29 14:15:07 +01:00
b0c2038927 feat: add new helpers to make requests properly in integration tests 2025-11-29 14:11:32 +01:00
fd88931146 refactor: variable names and modernize loop 2025-11-25 14:37:59 +01:00
6ce0f4dfad refactor: name variables 2025-11-25 10:13:10 +01:00
68e3dceefc refactor: name variables 2025-11-25 10:08:48 +01:00
cded14c526 fix: force correct mime types for static files after modifying compression middleware's buffering 2025-11-24 07:53:08 +01:00
cfca668ca6 docs: clean readme 2025-11-23 21:56:53 +01:00
279255b587 fix: don't let rate limiting fails the test 2025-11-23 21:48:11 +01:00
b83f8c2228 fix: update ValidationMiddleware to return a JSON error response when JSON decoding fails 2025-11-23 21:42:27 +01:00
aabc48128c fix: use router in handlers integration tests (for dto validation) 2025-11-23 15:10:55 +01:00
68716b977b fix: verify XSS sanitization in handler response instead of repository stub 2025-11-23 15:01:54 +01:00
dbe1600632 fix: indentation 2025-11-23 14:49:31 +01:00
458e25cf79 fix: modify compression middleware to pass through redirects immediately without buffering 2025-11-23 14:48:59 +01:00
d4595d8dbf fix: properly encoding the flash message in the redirect URL 2025-11-23 14:48:39 +01:00
c5418f4e4c docs: update swagger 2025-11-23 14:26:52 +01:00
db0369225e refactor: update references to VoteRequest 2025-11-23 14:26:45 +01:00
07ac965b3d refactor: use consistent naming (VoteRequest -> CastVoteRequest) 2025-11-23 14:26:19 +01:00
e2e5d42035 feat: add SetValidatedDTOInContext to support test helper functions 2025-11-23 14:22:59 +01:00
6e4b41894f fix: update test cases to use createCreatePostRequests 2025-11-23 14:22:35 +01:00
0a8ed2e27c fix: add explicite validation check for empty url, title and content length 2025-11-23 14:21:30 +01:00
216e8657f6 feat: add generic createRequestWithDTO along with helpers functions 2025-11-23 14:20:59 +01:00
fb7206c0a2 fix: test context handling 2025-11-23 14:20:09 +01:00
c25926514b fix: add explicit empty-field validation check in handlers 2025-11-23 14:19:54 +01:00
52 changed files with 2962 additions and 2694 deletions

View File

@@ -11,7 +11,7 @@ ARG TARGETARCH=amd64
RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -ldflags="-s -w" -o /out/goyco ./cmd/goyco RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -ldflags="-s -w" -o /out/goyco ./cmd/goyco
# building the application image # building the application image
FROM alpine:3.22 FROM alpine:3.23
RUN addgroup -S goyco && adduser -S -G goyco goyco \ RUN addgroup -S goyco && adduser -S -G goyco goyco \
&& apk add --no-cache ca-certificates tzdata && apk add --no-cache ca-certificates tzdata
WORKDIR /app WORKDIR /app

View File

@@ -354,44 +354,6 @@ It will start the application in development mode. You can also run it as a daem
Then, use `./bin/goyco` to manage the application and notably to seed the database with sample data. Then, use `./bin/goyco` to manage the application and notably to seed the database with sample data.
### Project Structure
```bash
goyco/
├── bin/ # Compiled binaries (created after build)
├── cmd/
│ └── goyco/ # Main CLI application entrypoint
├── docker/ # Docker Compose & related files
├── docs/ # Documentation and API specs
├── internal/
│ ├── config/ # Configuration management
│ ├── database/ # Database models and access
│ ├── dto/ # Data Transfer Objects (DTOs)
│ ├── e2e/ # End-to-end tests
│ ├── fuzz/ # Fuzz tests
│ ├── handlers/ # HTTP handlers
│ ├── integration/ # Integration tests
│ ├── middleware/ # HTTP middleware
│ ├── repositories/ # Data access layer
│ ├── security/ # Security and auth logic
│ ├── server/ # HTTP server implementation
│ ├── services/ # Business logic
│ ├── static/ # Static web assets
│ ├── templates/ # HTML templates
│ ├── testutils/ # Test helpers/utilities
│ ├── validation/ # Input validation
│ └── version/ # Version information
├── scripts/ # Utility/maintenance scripts
├── services/
│ └── goyco.service # Systemd service unit example
├── .env.example # Environment variable example
├── AUTHORS # Authors file
├── Dockerfile # Docker build file
├── LICENSE # License file
├── Makefile # Project build/test targets
└── README.md # This file
```
### Testing ### Testing
```bash ```bash
@@ -437,31 +399,10 @@ This will regenerate the swagger documentation and update the `docs/swagger.json
- [ ] add right management within the app - [ ] add right management within the app
- [ ] add an admin backoffice to manage rights, users, content and settings - [ ] add an admin backoffice to manage rights, users, content and settings
- [ ] add a way to run read-only communities - [ ] add a way to run read-only communities
- [ ] maybe use a css framework instead of raw css - [ ] migrate raw CSS to UnoCSS
- [ ] kubernetes deployment - [ ] kubernetes deployment
- [ ] store configuration in the database - [ ] store configuration in the database
## Contributing
Feedbacks are welcome!
But as it's a personal gitea and you cannot create accounts, feel free to contact me at <sandro@cazzaniga.fr> to get one.
Once you have it, follow the usual workflow:
1. Fork the repository
2. Create a feature branch
3. Make your changes
4. Add tests for new functionality
5. Ensure all tests pass
6. Submit a pull request
Then, I'll review your changes and merge them if they are good.
## License ## License
This project is licensed under the GNU General Public License v3.0 or later (GPLv3+). See the [LICENSE](LICENSE) file for details. This project is licensed under the GNU General Public License v3.0 or later (GPLv3+). See the [LICENSE](LICENSE) file for details.
---
**Goyco** - A modern news aggregation platform built with Go, PostgreSQL and most importantly, love.

View File

@@ -8,9 +8,10 @@ import (
"os" "os"
"sync" "sync"
"gorm.io/gorm"
"goyco/internal/config" "goyco/internal/config"
"goyco/internal/database" "goyco/internal/database"
"gorm.io/gorm"
) )
var ErrHelpRequested = errors.New("help requested") var ErrHelpRequested = errors.New("help requested")
@@ -118,26 +119,6 @@ func outputJSON(v interface{}) error {
return encoder.Encode(v) return encoder.Encode(v)
} }
func outputMessage(message string, args ...interface{}) {
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"message": fmt.Sprintf(message, args...),
})
} else {
fmt.Printf(message+"\n", args...)
}
}
func outputError(message string, args ...interface{}) {
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"error": fmt.Sprintf(message, args...),
})
} else {
fmt.Fprintf(os.Stderr, message+"\n", args...)
}
}
func outputWarning(message string, args ...interface{}) { func outputWarning(message string, args ...interface{}) {
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]interface{}{

View File

@@ -5,9 +5,10 @@ import (
"fmt" "fmt"
"os" "os"
"gorm.io/gorm"
"goyco/internal/config" "goyco/internal/config"
"goyco/internal/database" "goyco/internal/database"
"gorm.io/gorm"
) )
func HandleMigrateCommand(cfg *config.Config, name string, args []string) error { func HandleMigrateCommand(cfg *config.Config, name string, args []string) error {
@@ -37,7 +38,7 @@ func runMigrateCommand(db *gorm.DB) error {
return fmt.Errorf("run migrations: %w", err) return fmt.Errorf("run migrations: %w", err)
} }
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "migrations_applied", "action": "migrations_applied",
"status": "success", "status": "success",
}) })

View File

@@ -261,6 +261,7 @@ func processItemsInParallelNoResult[T any](
) error { ) error {
count := len(items) count := len(items)
errors := make(chan error, count) errors := make(chan error, count)
completions := make(chan struct{}, count)
semaphore := make(chan struct{}, maxWorkers) semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -288,20 +289,45 @@ func processItemsInParallelNoResult[T any](
return return
} }
if progress != nil { completions <- struct{}{}
progress.Update(index + 1)
}
}(i, item) }(i, item)
} }
go func() { go func() {
wg.Wait() wg.Wait()
close(errors) close(errors)
close(completions)
}() }()
completed := 0
firstError := make(chan error, 1)
go func() {
for err := range errors { for err := range errors {
if err != nil { if err != nil {
select {
case firstError <- err:
default:
}
return
}
}
}()
for completed < count {
select {
case _, ok := <-completions:
if !ok {
return nil
}
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-firstError:
return err return err
case <-ctx.Done():
return fmt.Errorf("timeout: %w", ctx.Err())
} }
} }

View File

@@ -2,15 +2,15 @@ package commands_test
import ( import (
"errors" "errors"
"fmt"
"sync" "sync"
"testing" "testing"
"golang.org/x/crypto/bcrypt"
"goyco/cmd/goyco/commands" "goyco/cmd/goyco/commands"
"goyco/internal/database" "goyco/internal/database"
"goyco/internal/repositories" "goyco/internal/repositories"
"goyco/internal/testutils" "goyco/internal/testutils"
"golang.org/x/crypto/bcrypt"
) )
func TestParallelProcessor_CreateUsersInParallel(t *testing.T) { func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
@@ -25,7 +25,7 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{ {
name: "creates users with deterministic fields", name: "creates users with required fields",
count: successCount, count: successCount,
repoFactory: func() repositories.UserRepository { repoFactory: func() repositories.UserRepository {
base := testutils.NewMockUserRepository() base := testutils.NewMockUserRepository()
@@ -37,14 +37,24 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
if len(got) != successCount { if len(got) != successCount {
t.Fatalf("expected %d users, got %d", successCount, len(got)) t.Fatalf("expected %d users, got %d", successCount, len(got))
} }
usernames := make(map[string]bool)
for i, user := range got { for i, user := range got {
expectedUsername := fmt.Sprintf("user_%d", i+1) if user.Username == "" {
expectedEmail := fmt.Sprintf("user_%d@goyco.local", i+1) t.Errorf("user %d expected non-empty username", i)
if user.Username != expectedUsername {
t.Errorf("user %d username mismatch: got %q want %q", i, user.Username, expectedUsername)
} }
if user.Email != expectedEmail { if len(user.Username) < 6 || user.Username[:5] != "user_" {
t.Errorf("user %d email mismatch: got %q want %q", i, user.Email, expectedEmail) t.Errorf("user %d username should start with 'user_', got %q", i, user.Username)
}
if usernames[user.Username] {
t.Errorf("user %d duplicate username: %q", i, user.Username)
}
usernames[user.Username] = true
if user.Email == "" {
t.Errorf("user %d expected non-empty email", i)
}
if len(user.Email) < 20 || user.Email[:5] != "user_" || user.Email[len(user.Email)-12:] != "@goyco.local" {
t.Errorf("user %d email should match pattern 'user_*@goyco.local', got %q", i, user.Email)
} }
if !user.EmailVerified { if !user.EmailVerified {
t.Errorf("user %d expected EmailVerified to be true", i) t.Errorf("user %d expected EmailVerified to be true", i)
@@ -83,6 +93,11 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
t.Parallel() t.Parallel()
repo := tt.repoFactory() repo := tt.repoFactory()
p := commands.NewParallelProcessor() p := commands.NewParallelProcessor()
passwordHash, err := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
if err != nil {
t.Fatalf("failed to generate password hash: %v", err)
}
p.SetPasswordHash(string(passwordHash))
got, gotErr := p.CreateUsersInParallel(repo, tt.count, tt.progress) got, gotErr := p.CreateUsersInParallel(repo, tt.count, tt.progress)
if gotErr != nil { if gotErr != nil {
if !tt.wantErr { if !tt.wantErr {

View File

@@ -35,17 +35,6 @@ func initSeedRand() {
}) })
} }
func generateRandomIdentifier() string {
initSeedRand()
const length = 12
const chars = "abcdefghijklmnopqrstuvwxyz0123456789"
identifier := make([]byte, length)
for i := range identifier {
identifier[i] = chars[seedRandSource.Intn(len(chars))]
}
return string(identifier)
}
func HandleSeedCommand(cfg *config.Config, name string, args []string) error { func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
fs := newFlagSet(name, printSeedUsage) fs := newFlagSet(name, printSeedUsage)
if err := parseCommand(fs, args, name); err != nil { if err := parseCommand(fs, args, name); err != nil {
@@ -213,7 +202,7 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
progress.Complete() progress.Complete()
} }
if err := validateSeedConsistency(postRepo, voteRepo, allUsers, posts); err != nil { if err := validateSeedConsistency(voteRepo, allUsers, posts); err != nil {
return fmt.Errorf("seed consistency validation failed: %w", err) return fmt.Errorf("seed consistency validation failed: %w", err)
} }
@@ -236,26 +225,16 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
return nil return nil
} }
func findExistingSeedUser(userRepo repositories.UserRepository) (*database.User, error) { const (
user, err := userRepo.GetByUsernamePrefix("seed_admin_") seedUsername = "seed_admin"
if err != nil { seedEmail = "seed_admin@goyco.local"
return nil, fmt.Errorf("no existing seed user found") )
}
return user, nil
}
func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (*database.User, error) { func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (*database.User, error) {
existingUser, err := findExistingSeedUser(userRepo) if user, err := userRepo.GetByUsername(seedUsername); err == nil {
if err == nil && existingUser != nil { return user, nil
return existingUser, nil
} }
randomID := generateRandomIdentifier()
seedUsername := fmt.Sprintf("seed_admin_%s", randomID)
seedEmail := fmt.Sprintf("seed_admin_%s@goyco.local", randomID)
const maxRetries = 10
for range maxRetries {
user := &database.User{ user := &database.User{
Username: seedUsername, Username: seedUsername,
Email: seedEmail, Email: seedEmail,
@@ -264,48 +243,75 @@ func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (
} }
if err := userRepo.Create(user); err != nil { if err := userRepo.Create(user); err != nil {
randomID = generateRandomIdentifier() return nil, fmt.Errorf("failed to create seed user: %w", err)
seedUsername = fmt.Sprintf("seed_admin_%s", randomID)
seedEmail = fmt.Sprintf("seed_admin_%s@goyco.local", randomID)
continue
} }
return user, nil return user, nil
}
return nil, fmt.Errorf("failed to create seed user after %d attempts", maxRetries)
} }
func getVoteCounts(voteRepo repositories.VoteRepository, postID uint) (int, int, error) { func getVoteCounts(voteRepo repositories.VoteRepository, postID uint) (int, int, error) {
return voteRepo.GetVoteCountsByPostID(postID) return voteRepo.GetVoteCountsByPostID(postID)
} }
func validateSeedConsistency(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, users []database.User, posts []database.Post) error { func validateSeedConsistency(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post) error {
userIDs := make(map[uint]bool) userIDSet := make(map[uint]struct{}, len(users))
for _, user := range users { for _, user := range users {
userIDs[user.ID] = true userIDSet[user.ID] = struct{}{}
}
postIDSet := make(map[uint]struct{}, len(posts))
for _, post := range posts {
postIDSet[post.ID] = struct{}{}
} }
for _, post := range posts { for _, post := range posts {
if post.AuthorID == nil { if err := validatePost(post, userIDSet); err != nil {
return fmt.Errorf("post %d has no author", post.ID) return err
}
if !userIDs[*post.AuthorID] {
return fmt.Errorf("post %d has invalid author ID %d", post.ID, *post.AuthorID)
} }
votes, err := voteRepo.GetByPostID(post.ID) votes, err := voteRepo.GetByPostID(post.ID)
if err != nil { if err != nil {
return fmt.Errorf("get votes for post %d: %w", post.ID, err) return fmt.Errorf("failed to retrieve votes for post %d: %w", post.ID, err)
} }
for _, vote := range votes { if err := validateVotesForPost(post.ID, votes, userIDSet, postIDSet); err != nil {
if vote.UserID != nil && !userIDs[*vote.UserID] { return err
return fmt.Errorf("vote %d has invalid user ID %d", vote.ID, *vote.UserID) }
} }
if vote.PostID != post.ID {
return fmt.Errorf("vote %d has invalid post ID %d (expected %d)", vote.ID, vote.PostID, post.ID) return nil
} }
func validatePost(post database.Post, userIDSet map[uint]struct{}) error {
if post.AuthorID == nil {
return fmt.Errorf("post %d has no author ID", post.ID)
}
if _, exists := userIDSet[*post.AuthorID]; !exists {
return fmt.Errorf("post %d references non-existent author ID %d", post.ID, *post.AuthorID)
}
return nil
}
func validateVotesForPost(postID uint, votes []database.Vote, userIDSet map[uint]struct{}, postIDSet map[uint]struct{}) error {
for _, vote := range votes {
if vote.PostID != postID {
return fmt.Errorf("vote %d references post ID %d but was retrieved for post %d", vote.ID, vote.PostID, postID)
}
if _, exists := postIDSet[vote.PostID]; !exists {
return fmt.Errorf("vote %d references non-existent post ID %d", vote.ID, vote.PostID)
}
if vote.UserID != nil {
if _, exists := userIDSet[*vote.UserID]; !exists {
return fmt.Errorf("vote %d references non-existent user ID %d", vote.ID, *vote.UserID)
}
}
if vote.Type != database.VoteUp && vote.Type != database.VoteDown {
return fmt.Errorf("vote %d has invalid type %q", vote.ID, vote.Type)
} }
} }

View File

@@ -47,7 +47,7 @@ func TestSeedCommand(t *testing.T) {
var seedUser *database.User var seedUser *database.User
regularUserCount := 0 regularUserCount := 0
for i := range users { for i := range users {
if strings.HasPrefix(users[i].Username, "seed_admin_") { if users[i].Username == "seed_admin" {
seedUserCount++ seedUserCount++
seedUser = &users[i] seedUser = &users[i]
} else if strings.HasPrefix(users[i].Username, "user_") { } else if strings.HasPrefix(users[i].Username, "user_") {
@@ -63,12 +63,12 @@ func TestSeedCommand(t *testing.T) {
t.Fatal("Expected seed user to be created") t.Fatal("Expected seed user to be created")
} }
if !strings.HasPrefix(seedUser.Username, "seed_admin_") { if seedUser.Username != "seed_admin" {
t.Errorf("Expected username to start with 'seed_admin_', got '%s'", seedUser.Username) t.Errorf("Expected username to be 'seed_admin', got '%s'", seedUser.Username)
} }
if !strings.HasPrefix(seedUser.Email, "seed_admin_") || !strings.HasSuffix(seedUser.Email, "@goyco.local") { if seedUser.Email != "seed_admin@goyco.local" {
t.Errorf("Expected email to start with 'seed_admin_' and end with '@goyco.local', got '%s'", seedUser.Email) t.Errorf("Expected email to be 'seed_admin@goyco.local', got '%s'", seedUser.Email)
} }
if !seedUser.EmailVerified { if !seedUser.EmailVerified {
@@ -302,13 +302,13 @@ func TestSeedCommandIdempotency(t *testing.T) {
seedUserCount := 0 seedUserCount := 0
for _, user := range users { for _, user := range users {
if strings.HasPrefix(user.Username, "seed_admin_") { if user.Username == "seed_admin" {
seedUserCount++ seedUserCount++
} }
} }
if seedUserCount < 1 { if seedUserCount != 1 {
t.Errorf("Expected at least 1 seed user, got %d", seedUserCount) t.Errorf("Expected exactly 1 seed user, got %d", seedUserCount)
} }
}) })
@@ -387,7 +387,7 @@ func TestSeedCommandIdempotency(t *testing.T) {
func findSeedUser(users []database.User) *database.User { func findSeedUser(users []database.User) *database.User {
for i := range users { for i := range users {
if strings.HasPrefix(users[i].Username, "seed_admin_") { if users[i].Username == "seed_admin" {
return &users[i] return &users[i]
} }
} }
@@ -476,3 +476,58 @@ func TestSeedCommandTransactionRollback(t *testing.T) {
} }
}) })
} }
func TestEnsureSeedUser(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to connect to database: %v", err)
}
if err := db.AutoMigrate(&database.User{}); err != nil {
t.Fatalf("Failed to migrate database: %v", err)
}
userRepo := repositories.NewUserRepository(db)
passwordHash := "test_password_hash"
firstUser, err := ensureSeedUser(userRepo, passwordHash)
if err != nil {
t.Fatalf("Failed to create seed user: %v", err)
}
if firstUser.Username != "seed_admin" || firstUser.Email != "seed_admin@goyco.local" || firstUser.Password != passwordHash || !firstUser.EmailVerified {
t.Errorf("Invalid seed user: username=%s, email=%s, password matches=%v, emailVerified=%v",
firstUser.Username, firstUser.Email, firstUser.Password == passwordHash, firstUser.EmailVerified)
}
secondUser, err := ensureSeedUser(userRepo, "different_password_hash")
if err != nil {
t.Fatalf("Failed to reuse seed user: %v", err)
}
if firstUser.ID != secondUser.ID {
t.Errorf("Expected same user to be reused (ID %d), got different user (ID %d)", firstUser.ID, secondUser.ID)
}
for i := 0; i < 3; i++ {
if _, err := ensureSeedUser(userRepo, passwordHash); err != nil {
t.Fatalf("Call %d failed: %v", i+1, err)
}
}
users, err := userRepo.GetAll(100, 0)
if err != nil {
t.Fatalf("Failed to get users: %v", err)
}
seedUserCount := 0
for _, user := range users {
if user.Username == "seed_admin" {
seedUserCount++
}
}
if seedUserCount != 1 {
t.Errorf("Expected exactly 1 seed user, got %d", seedUserCount)
}
}

View File

@@ -96,21 +96,21 @@ func TestServerConfigurationFromConfig(t *testing.T) {
testServer := httptest.NewServer(srv.Handler) testServer := httptest.NewServer(srv.Handler)
defer testServer.Close() defer testServer.Close()
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil) request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
resp, err := http.DefaultClient.Do(req) response, err := http.DefaultClient.Do(request)
if err != nil { if err != nil {
t.Fatalf("Failed to make request: %v", err) t.Fatalf("Failed to make request: %v", err)
} }
defer func() { defer func() {
_ = resp.Body.Close() _ = response.Body.Close()
}() }()
if resp.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode) t.Errorf("Expected status 200, got %d", response.StatusCode)
} }
} }
@@ -208,27 +208,27 @@ func TestTLSWiringFromConfig(t *testing.T) {
}, },
} }
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil) request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
resp, err := client.Do(req) response, err := client.Do(request)
if err != nil { if err != nil {
t.Fatalf("Failed to make TLS request: %v", err) t.Fatalf("Failed to make TLS request: %v", err)
} }
defer func() { defer func() {
_ = resp.Body.Close() _ = response.Body.Close()
}() }()
if resp.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 over TLS, got %d", resp.StatusCode) t.Errorf("Expected status 200 over TLS, got %d", response.StatusCode)
} }
if resp.TLS == nil { if response.TLS == nil {
t.Error("Expected TLS connection info to be present in response") t.Error("Expected TLS connection info to be present in response")
} else if resp.TLS.Version < tls.VersionTLS12 { } else if response.TLS.Version < tls.VersionTLS12 {
t.Errorf("Expected TLS version 1.2 or higher, got %x", resp.TLS.Version) t.Errorf("Expected TLS version 1.2 or higher, got %x", response.TLS.Version)
} }
} }
} }
@@ -368,38 +368,38 @@ func TestServerInitializationFlow(t *testing.T) {
testServer := httptest.NewServer(srv.Handler) testServer := httptest.NewServer(srv.Handler)
defer testServer.Close() defer testServer.Close()
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil) request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
resp, err := http.DefaultClient.Do(req) response, err := http.DefaultClient.Do(request)
if err != nil { if err != nil {
t.Fatalf("Failed to make request: %v", err) t.Fatalf("Failed to make request: %v", err)
} }
defer func() { defer func() {
_ = resp.Body.Close() _ = response.Body.Close()
}() }()
if resp.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode) t.Errorf("Expected status 200, got %d", response.StatusCode)
} }
req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/api", nil) request, err = http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/api", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
resp, err = http.DefaultClient.Do(req) response, err = http.DefaultClient.Do(request)
if err != nil { if err != nil {
t.Fatalf("Failed to make request: %v", err) t.Fatalf("Failed to make request: %v", err)
} }
defer func() { defer func() {
_ = resp.Body.Close() _ = response.Body.Close()
}() }()
if resp.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for API endpoint, got %d", resp.StatusCode) t.Errorf("Expected status 200 for API endpoint, got %d", response.StatusCode)
} }
} }

View File

@@ -1370,7 +1370,7 @@ const docTemplate = `{
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/dto.VoteRequest" "$ref": "#/definitions/dto.CastVoteRequest"
} }
} }
], ],
@@ -1817,6 +1817,22 @@ const docTemplate = `{
} }
}, },
"definitions": { "definitions": {
"dto.CastVoteRequest": {
"type": "object",
"required": [
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
]
}
}
},
"dto.ConfirmAccountDeletionRequest": { "dto.ConfirmAccountDeletionRequest": {
"type": "object", "type": "object",
"required": [ "required": [
@@ -2018,22 +2034,6 @@ const docTemplate = `{
} }
} }
}, },
"dto.VoteRequest": {
"type": "object",
"required": [
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
]
}
}
},
"handlers.APIInfo": { "handlers.APIInfo": {
"type": "object", "type": "object",
"properties": { "properties": {

View File

@@ -1367,7 +1367,7 @@
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/dto.VoteRequest" "$ref": "#/definitions/dto.CastVoteRequest"
} }
} }
], ],
@@ -1814,6 +1814,22 @@
} }
}, },
"definitions": { "definitions": {
"dto.CastVoteRequest": {
"type": "object",
"required": [
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
]
}
}
},
"dto.ConfirmAccountDeletionRequest": { "dto.ConfirmAccountDeletionRequest": {
"type": "object", "type": "object",
"required": [ "required": [
@@ -2015,22 +2031,6 @@
} }
} }
}, },
"dto.VoteRequest": {
"type": "object",
"required": [
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
]
}
}
},
"handlers.APIInfo": { "handlers.APIInfo": {
"type": "object", "type": "object",
"properties": { "properties": {

View File

@@ -1,5 +1,16 @@
basePath: /api basePath: /api
definitions: definitions:
dto.CastVoteRequest:
properties:
type:
enum:
- up
- down
- none
type: string
required:
- type
type: object
dto.ConfirmAccountDeletionRequest: dto.ConfirmAccountDeletionRequest:
properties: properties:
delete_posts: delete_posts:
@@ -140,17 +151,6 @@ definitions:
required: required:
- username - username
type: object type: object
dto.VoteRequest:
properties:
type:
enum:
- up
- down
- none
type: string
required:
- type
type: object
handlers.APIInfo: handlers.APIInfo:
properties: properties:
data: {} data: {}
@@ -1121,7 +1121,7 @@ paths:
name: request name: request
required: true required: true
schema: schema:
$ref: '#/definitions/dto.VoteRequest' $ref: '#/definitions/dto.CastVoteRequest'
produces: produces:
- application/json - application/json
responses: responses:

View File

@@ -6,7 +6,7 @@ import (
"goyco/internal/database" "goyco/internal/database"
) )
type VoteRequest struct { type CastVoteRequest struct {
Type string `json:"type" validate:"required,oneof=up down none"` Type string `json:"type" validate:"required,oneof=up down none"`
} }

View File

@@ -112,37 +112,37 @@ func newInMemoryRoundTripper(handler http.Handler) http.RoundTripper {
return &inMemoryRoundTripper{handler: handler} return &inMemoryRoundTripper{handler: handler}
} }
func (rt *inMemoryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func (rt *inMemoryRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
if rt == nil || rt.handler == nil { if rt == nil || rt.handler == nil {
return nil, fmt.Errorf("in-memory round tripper not initialized") return nil, fmt.Errorf("in-memory round tripper not initialized")
} }
var bodyBytes []byte var bodyBytes []byte
if req.Body != nil && req.Body != http.NoBody { if request.Body != nil && request.Body != http.NoBody {
defer req.Body.Close() defer request.Body.Close()
var err error var err error
bodyBytes, err = io.ReadAll(req.Body) bodyBytes, err = io.ReadAll(request.Body)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read request body: %w", err) return nil, fmt.Errorf("failed to read request body: %w", err)
} }
} }
if len(bodyBytes) > 0 { if len(bodyBytes) > 0 {
req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
} else { } else {
req.Body = http.NoBody request.Body = http.NoBody
} }
clonedReq := req.Clone(req.Context()) clonedRequest := request.Clone(request.Context())
if len(bodyBytes) > 0 { if len(bodyBytes) > 0 {
clonedReq.Body = io.NopCloser(bytes.NewReader(bodyBytes)) clonedRequest.Body = io.NopCloser(bytes.NewReader(bodyBytes))
} else { } else {
clonedReq.Body = http.NoBody clonedRequest.Body = http.NoBody
} }
clonedReq.RequestURI = clonedReq.URL.RequestURI() clonedRequest.RequestURI = clonedRequest.URL.RequestURI()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
rt.handler.ServeHTTP(recorder, clonedReq) rt.handler.ServeHTTP(recorder, clonedRequest)
resp := recorder.Result() resp := recorder.Result()
return resp, nil return resp, nil
} }
@@ -250,7 +250,7 @@ func tokenHash(token string) string {
func retryOnRateLimit(t *testing.T, maxRetries int, operation func() int) int { func retryOnRateLimit(t *testing.T, maxRetries int, operation func() int) int {
t.Helper() t.Helper()
for attempt := 0; attempt < maxRetries; attempt++ { for attempt := range maxRetries {
statusCode := operation() statusCode := operation()
if statusCode != http.StatusTooManyRequests { if statusCode != http.StatusTooManyRequests {
return statusCode return statusCode

View File

@@ -15,22 +15,22 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
ctx := setupTestContext(t) ctx := setupTestContext(t)
t.Run("compression_enabled_with_accept_encoding", func(t *testing.T) { t.Run("compression_enabled_with_accept_encoding", func(t *testing.T) {
req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
req.Header.Set("Accept-Encoding", "gzip") request.Header.Set("Accept-Encoding", "gzip")
testutils.WithStandardHeaders(req) testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp.Body.Close() defer response.Body.Close()
contentEncoding := resp.Header.Get("Content-Encoding") contentEncoding := response.Header.Get("Content-Encoding")
if contentEncoding == "gzip" { if contentEncoding == "gzip" {
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(response.Body)
if err != nil { if err != nil {
t.Fatalf("Failed to read response body: %v", err) t.Fatalf("Failed to read response body: %v", err)
} }
@@ -57,19 +57,19 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
}) })
t.Run("no_compression_without_accept_encoding", func(t *testing.T) { t.Run("no_compression_without_accept_encoding", func(t *testing.T) {
req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
testutils.WithStandardHeaders(req) testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp.Body.Close() defer response.Body.Close()
contentEncoding := resp.Header.Get("Content-Encoding") contentEncoding := response.Header.Get("Content-Encoding")
if contentEncoding == "gzip" { if contentEncoding == "gzip" {
t.Error("Expected no compression without Accept-Encoding header") t.Error("Expected no compression without Accept-Encoding header")
} }
@@ -85,22 +85,22 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
gz.Write([]byte(postData)) gz.Write([]byte(postData))
gz.Close() gz.Close()
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", &buf) request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", &buf)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Encoding", "gzip") request.Header.Set("Content-Encoding", "gzip")
testutils.WithStandardHeaders(req) testutils.WithStandardHeaders(request)
req.Header.Set("Authorization", "Bearer "+authClient.Token) request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp.Body.Close() defer response.Body.Close()
switch resp.StatusCode { switch response.StatusCode {
case http.StatusBadRequest: case http.StatusBadRequest:
t.Log("Decompression middleware rejected invalid gzip") t.Log("Decompression middleware rejected invalid gzip")
case http.StatusCreated, http.StatusOK: case http.StatusCreated, http.StatusOK:
@@ -113,37 +113,37 @@ func TestE2E_CacheMiddleware(t *testing.T) {
ctx := setupTestContext(t) ctx := setupTestContext(t)
t.Run("cache_miss_then_hit", func(t *testing.T) { t.Run("cache_miss_then_hit", func(t *testing.T) {
req1, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
testutils.WithStandardHeaders(req1) testutils.WithStandardHeaders(firstRequest)
resp1, err := ctx.client.Do(req1) firstResponse, err := ctx.client.Do(firstRequest)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
resp1.Body.Close() firstResponse.Body.Close()
cacheStatus1 := resp1.Header.Get("X-Cache") firstCacheStatus := firstResponse.Header.Get("X-Cache")
if cacheStatus1 == "HIT" { if firstCacheStatus == "HIT" {
t.Log("First request was cached (unexpected but acceptable)") t.Log("First request was cached (unexpected but acceptable)")
} }
req2, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) secondRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
testutils.WithStandardHeaders(req2) testutils.WithStandardHeaders(secondRequest)
resp2, err := ctx.client.Do(req2) secondResponse, err := ctx.client.Do(secondRequest)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp2.Body.Close() defer secondResponse.Body.Close()
cacheStatus2 := resp2.Header.Get("X-Cache") secondCacheStatus := secondResponse.Header.Get("X-Cache")
if cacheStatus2 == "HIT" { if secondCacheStatus == "HIT" {
t.Log("Second request was served from cache") t.Log("Second request was served from cache")
} }
}) })
@@ -152,48 +152,48 @@ func TestE2E_CacheMiddleware(t *testing.T) {
testUser := ctx.createUserWithCleanup(t, "cacheuser", "StrongPass123!") testUser := ctx.createUserWithCleanup(t, "cacheuser", "StrongPass123!")
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!") authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
req1, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
testutils.WithStandardHeaders(req1) testutils.WithStandardHeaders(firstRequest)
req1.Header.Set("Authorization", "Bearer "+authClient.Token) firstRequest.Header.Set("Authorization", "Bearer "+authClient.Token)
resp1, err := ctx.client.Do(req1) firstResponse, err := ctx.client.Do(firstRequest)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
resp1.Body.Close() firstResponse.Body.Close()
postData := `{"title":"Cache Invalidation Test","url":"https://example.com/cache","content":"Test"}` postData := `{"title":"Cache Invalidation Test","url":"https://example.com/cache","content":"Test"}`
req2, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData)) secondRequest, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
req2.Header.Set("Content-Type", "application/json") secondRequest.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req2) testutils.WithStandardHeaders(secondRequest)
req2.Header.Set("Authorization", "Bearer "+authClient.Token) secondRequest.Header.Set("Authorization", "Bearer "+authClient.Token)
resp2, err := ctx.client.Do(req2) secondResponse, err := ctx.client.Do(secondRequest)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
resp2.Body.Close() secondResponse.Body.Close()
req3, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) thirdRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
testutils.WithStandardHeaders(req3) testutils.WithStandardHeaders(thirdRequest)
req3.Header.Set("Authorization", "Bearer "+authClient.Token) thirdRequest.Header.Set("Authorization", "Bearer "+authClient.Token)
resp3, err := ctx.client.Do(req3) thirdResponse, err := ctx.client.Do(thirdRequest)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp3.Body.Close() defer thirdResponse.Body.Close()
cacheStatus := resp3.Header.Get("X-Cache") cacheStatus := thirdResponse.Header.Get("X-Cache")
if cacheStatus == "HIT" { if cacheStatus == "HIT" {
t.Log("Cache was invalidated after POST") t.Log("Cache was invalidated after POST")
} }
@@ -204,23 +204,23 @@ func TestE2E_CSRFProtection(t *testing.T) {
ctx := setupTestContext(t) ctx := setupTestContext(t)
t.Run("csrf_protection_for_non_api_routes", func(t *testing.T) { t.Run("csrf_protection_for_non_api_routes", func(t *testing.T) {
req, err := http.NewRequest("POST", ctx.baseURL+"/auth/login", strings.NewReader(`{"username":"test","password":"test"}`)) request, err := http.NewRequest("POST", ctx.baseURL+"/auth/login", strings.NewReader(`{"username":"test","password":"test"}`))
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req) testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp.Body.Close() defer response.Body.Close()
if resp.StatusCode == http.StatusForbidden { if response.StatusCode == http.StatusForbidden {
t.Log("CSRF protection active for non-API routes") t.Log("CSRF protection active for non-API routes")
} else { } else {
t.Logf("CSRF check result: status %d", resp.StatusCode) t.Logf("CSRF check result: status %d", response.StatusCode)
} }
}) })
@@ -229,39 +229,39 @@ func TestE2E_CSRFProtection(t *testing.T) {
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!") authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
postData := `{"title":"CSRF Test","url":"https://example.com/csrf","content":"Test"}` postData := `{"title":"CSRF Test","url":"https://example.com/csrf","content":"Test"}`
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData)) request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req) testutils.WithStandardHeaders(request)
req.Header.Set("Authorization", "Bearer "+authClient.Token) request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp.Body.Close() defer response.Body.Close()
if resp.StatusCode == http.StatusForbidden { if response.StatusCode == http.StatusForbidden {
t.Error("API routes should bypass CSRF protection") t.Error("API routes should bypass CSRF protection")
} }
}) })
t.Run("csrf_allows_get_requests", func(t *testing.T) { t.Run("csrf_allows_get_requests", func(t *testing.T) {
req, err := http.NewRequest("GET", ctx.baseURL+"/auth/login", nil) request, err := http.NewRequest("GET", ctx.baseURL+"/auth/login", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
testutils.WithStandardHeaders(req) testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp.Body.Close() defer response.Body.Close()
if resp.StatusCode == http.StatusForbidden { if response.StatusCode == http.StatusForbidden {
t.Error("GET requests should not require CSRF token") t.Error("GET requests should not require CSRF token")
} }
}) })
@@ -276,21 +276,21 @@ func TestE2E_RequestSizeLimit(t *testing.T) {
smallData := strings.Repeat("a", 100) smallData := strings.Repeat("a", 100)
postData := `{"title":"` + smallData + `","url":"https://example.com","content":"test"}` postData := `{"title":"` + smallData + `","url":"https://example.com","content":"test"}`
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData)) request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req) testutils.WithStandardHeaders(request)
req.Header.Set("Authorization", "Bearer "+authClient.Token) request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp.Body.Close() defer response.Body.Close()
if resp.StatusCode == http.StatusRequestEntityTooLarge { if response.StatusCode == http.StatusRequestEntityTooLarge {
t.Error("Small request should not exceed size limit") t.Error("Small request should not exceed size limit")
} }
}) })
@@ -301,24 +301,24 @@ func TestE2E_RequestSizeLimit(t *testing.T) {
largeData := strings.Repeat("a", 2*1024*1024) largeData := strings.Repeat("a", 2*1024*1024)
postData := `{"title":"test","url":"https://example.com","content":"` + largeData + `"}` postData := `{"title":"test","url":"https://example.com","content":"` + largeData + `"}`
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData)) request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req) testutils.WithStandardHeaders(request)
req.Header.Set("Authorization", "Bearer "+authClient.Token) request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
return return
} }
defer resp.Body.Close() defer response.Body.Close()
if resp.StatusCode == http.StatusRequestEntityTooLarge { if response.StatusCode == http.StatusRequestEntityTooLarge {
t.Log("Request size limit enforced correctly") t.Log("Request size limit enforced correctly")
} else { } else {
t.Logf("Request size limit check result: status %d", resp.StatusCode) t.Logf("Request size limit check result: status %d", response.StatusCode)
} }
}) })
} }

View File

@@ -225,6 +225,11 @@ func (h *AuthHandler) ResendVerificationEmail(w http.ResponseWriter, r *http.Req
email := strings.TrimSpace(req.Email) email := strings.TrimSpace(req.Email)
if email == "" {
SendErrorResponse(w, "Email address is required", http.StatusBadRequest)
return
}
err := h.authService.ResendVerificationEmail(email) err := h.authService.ResendVerificationEmail(email)
if err != nil { if err != nil {
switch { switch {
@@ -293,6 +298,11 @@ func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Reques
usernameOrEmail := strings.TrimSpace(req.UsernameOrEmail) usernameOrEmail := strings.TrimSpace(req.UsernameOrEmail)
if usernameOrEmail == "" {
SendErrorResponse(w, "Username or email is required", http.StatusBadRequest)
return
}
if err := h.authService.RequestPasswordReset(usernameOrEmail); err != nil { if err := h.authService.RequestPasswordReset(usernameOrEmail); err != nil {
} }
@@ -319,6 +329,11 @@ func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
token := strings.TrimSpace(req.Token) token := strings.TrimSpace(req.Token)
newPassword := strings.TrimSpace(req.NewPassword) newPassword := strings.TrimSpace(req.NewPassword)
if token == "" {
SendErrorResponse(w, "Token is required", http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(newPassword); err != nil { if err := validation.ValidatePassword(newPassword); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest) SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return return
@@ -467,6 +482,11 @@ func (h *AuthHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
currentPassword := strings.TrimSpace(req.CurrentPassword) currentPassword := strings.TrimSpace(req.CurrentPassword)
newPassword := strings.TrimSpace(req.NewPassword) newPassword := strings.TrimSpace(req.NewPassword)
if currentPassword == "" {
SendErrorResponse(w, "Current password is required", http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(newPassword); err != nil { if err := validation.ValidatePassword(newPassword); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest) SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return return
@@ -538,6 +558,11 @@ func (h *AuthHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Requ
token := strings.TrimSpace(req.Token) token := strings.TrimSpace(req.Token)
if token == "" {
SendErrorResponse(w, "Deletion token is required", http.StatusBadRequest)
return
}
if err := h.authService.ConfirmAccountDeletionWithPosts(token, req.DeletePosts); err != nil { if err := h.authService.ConfirmAccountDeletionWithPosts(token, req.DeletePosts); err != nil {
switch { switch {
case errors.Is(err, services.ErrInvalidDeletionToken): case errors.Is(err, services.ErrInvalidDeletionToken):
@@ -591,6 +616,11 @@ func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
return return
} }
if req.RefreshToken == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return
}
result, err := h.authService.RefreshAccessToken(req.RefreshToken) result, err := h.authService.RefreshAccessToken(req.RefreshToken)
if !HandleServiceError(w, err, "Token refresh failed", http.StatusInternalServerError) { if !HandleServiceError(w, err, "Token refresh failed", http.StatusInternalServerError) {
return return
@@ -618,6 +648,11 @@ func (h *AuthHandler) RevokeToken(w http.ResponseWriter, r *http.Request) {
return return
} }
if req.RefreshToken == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return
}
err := h.authService.RevokeRefreshToken(req.RefreshToken) err := h.authService.RevokeRefreshToken(req.RefreshToken)
if err != nil { if err != nil {
SendErrorResponse(w, "Failed to revoke token", http.StatusInternalServerError) SendErrorResponse(w, "Failed to revoke token", http.StatusInternalServerError)

View File

@@ -252,8 +252,8 @@ func TestAuthHandlerLoginSuccess(t *testing.T) {
} }
handler := newAuthHandler(repo) handler := newAuthHandler(repo)
body := bytes.NewBufferString(`{"username":"user","password":"Password123!"}`) bodyStr := `{"username":"user","password":"Password123!"}`
request := httptest.NewRequest(http.MethodPost, "/api/auth/login", body) request := createLoginRequest(bodyStr)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.Login(recorder, request) handler.Login(recorder, request)
@@ -274,17 +274,17 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
handler := newAuthHandler(&testutils.UserRepositoryStub{}) handler := newAuthHandler(&testutils.UserRepositoryStub{})
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString("invalid")) request := createLoginRequest("invalid")
handler.Login(recorder, request) handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":" ","password":""}`)) request = createLoginRequest(`{"username":" ","password":""}`)
handler.Login(recorder, request) handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"WrongPass123!"}`)) request = createLoginRequest(`{"username":"user","password":"WrongPass123!"}`)
handler.Login(recorder, request) handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized) testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
@@ -294,7 +294,7 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
}} }}
handler = newAuthHandler(repo) handler = newAuthHandler(repo)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"Password123!"}`)) request = createLoginRequest(`{"username":"user","password":"Password123!"}`)
handler.Login(recorder, request) handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusForbidden) testutils.AssertHTTPStatus(t, recorder, http.StatusForbidden)
@@ -304,7 +304,7 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
handler = newAuthHandler(repo) handler = newAuthHandler(repo)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"Password123!"}`)) request = createLoginRequest(`{"username":"user","password":"Password123!"}`)
handler.Login(recorder, request) handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusInternalServerError) testutils.AssertHTTPStatus(t, recorder, http.StatusInternalServerError)
@@ -330,8 +330,7 @@ func TestAuthHandlerRegisterSuccess(t *testing.T) {
return nil return nil
}}) }})
body := bytes.NewBufferString(`{"username":"newuser","email":"new@example.com","password":"Password123!"}`) request := createRegisterRequest(`{"username":"newuser","email":"new@example.com","password":"Password123!"}`)
request := httptest.NewRequest(http.MethodPost, "/api/auth/register", body)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.Register(recorder, request) handler.Register(recorder, request)
@@ -354,12 +353,12 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
handler := newAuthHandler(repo) handler := newAuthHandler(repo)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString("invalid")) request := createRegisterRequest("invalid")
handler.Register(recorder, request) handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"","email":"","password":""}`)) request = createRegisterRequest(`{"username":"","email":"","password":""}`)
handler.Register(recorder, request) handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -368,7 +367,7 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
}} }}
handler = newAuthHandler(repo) handler = newAuthHandler(repo)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"new","email":"taken@example.com","password":"Password123!"}`)) request = createRegisterRequest(`{"username":"new","email":"taken@example.com","password":"Password123!"}`)
handler.Register(recorder, request) handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict) testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
@@ -382,7 +381,7 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
} }
handler = newAuthHandler(repo) handler = newAuthHandler(repo)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"another","email":"taken@example.com","password":"Password123!"}`)) request = createRegisterRequest(`{"username":"another","email":"taken@example.com","password":"Password123!"}`)
handler.Register(recorder, request) handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict) testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
} }
@@ -477,7 +476,7 @@ func TestAuthHandlerRequestPasswordReset(t *testing.T) {
}}) }})
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":"user@example.com"}`)) request := createForgotPasswordRequest(`{"username_or_email":"user@example.com"}`)
handler.RequestPasswordReset(recorder, request) handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK) testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
@@ -495,19 +494,19 @@ func TestAuthHandlerRequestPasswordReset(t *testing.T) {
}}) }})
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":"user"}`)) request = createForgotPasswordRequest(`{"username_or_email":"user"}`)
handler.RequestPasswordReset(recorder, request) handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK) testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":""}`)) request = createForgotPasswordRequest(`{"username_or_email":""}`)
handler.RequestPasswordReset(recorder, request) handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`invalid json`)) request = createForgotPasswordRequest(`invalid json`)
handler.RequestPasswordReset(recorder, request) handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -518,25 +517,25 @@ func TestAuthHandlerResetPassword(t *testing.T) {
handler := newAuthHandler(repo) handler := newAuthHandler(repo)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"new_password":"NewPassword123!"}`)) request := createResetPasswordRequest(`{"new_password":"NewPassword123!"}`)
handler.ResetPassword(recorder, request) handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"valid_token"}`)) request = createResetPasswordRequest(`{"token":"valid_token"}`)
handler.ResetPassword(recorder, request) handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"valid_token","new_password":"short"}`)) request = createResetPasswordRequest(`{"token":"valid_token","new_password":"short"}`)
handler.ResetPassword(recorder, request) handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`invalid json`)) request = createResetPasswordRequest(`invalid json`)
handler.ResetPassword(recorder, request) handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -602,7 +601,7 @@ func TestAuthHandlerResetPasswordServiceOutcomes(t *testing.T) {
handler := newMockAuthHandler(repo, mockService) handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"abc","new_password":"Password123!"}`)) request := createResetPasswordRequest(`{"token":"abc","new_password":"Password123!"}`)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.ResetPassword(recorder, request) handler.ResetPassword(recorder, request)
@@ -664,7 +663,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
userID: 1, userID: 1,
mockSetup: func(repo *testutils.UserRepositoryStub) {}, mockSetup: func(repo *testutils.UserRepositoryStub) {},
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body", expectedError: "Invalid request",
}, },
{ {
name: "empty email", name: "empty email",
@@ -702,7 +701,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
handler := newMockAuthHandler(repo, mockService) handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPut, "/api/auth/email", bytes.NewBufferString(tt.requestBody)) request := createUpdateEmailRequest(tt.requestBody)
if tt.userID > 0 { if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID) ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -789,7 +788,7 @@ func TestAuthHandlerUpdateUsername(t *testing.T) {
handler := newMockAuthHandler(repo, mockService) handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPut, "/api/auth/username", bytes.NewBufferString(tt.requestBody)) request := createUpdateUsernameRequest(tt.requestBody)
if tt.userID > 0 { if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID) ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -886,7 +885,7 @@ func TestAuthHandlerUpdatePassword(t *testing.T) {
tt.mockSetup(repo) tt.mockSetup(repo)
handler := newAuthHandler(repo) handler := newAuthHandler(repo)
request := httptest.NewRequest(http.MethodPut, "/api/auth/password", bytes.NewBufferString(tt.requestBody)) request := createUpdatePasswordRequest(tt.requestBody)
if tt.userID > 0 { if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID) ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -984,8 +983,7 @@ func TestAuthHandlerDeleteAccount(t *testing.T) {
func TestAuthHandlerResendVerificationEmail(t *testing.T) { func TestAuthHandlerResendVerificationEmail(t *testing.T) {
makeRequest := func(body string, setup func(*mockAuthService)) (*httptest.ResponseRecorder, AuthResponse) { makeRequest := func(body string, setup func(*mockAuthService)) (*httptest.ResponseRecorder, AuthResponse) {
request := httptest.NewRequest(http.MethodPost, "/api/auth/resend-verification", bytes.NewBufferString(body)) request := createResendVerificationRequest(body)
request = request.WithContext(context.Background())
repo := &testutils.UserRepositoryStub{} repo := &testutils.UserRepositoryStub{}
mockService := &mockAuthService{} mockService := &mockAuthService{}
@@ -1014,7 +1012,7 @@ func TestAuthHandlerResendVerificationEmail(t *testing.T) {
name: "invalid json", name: "invalid json",
body: "not-json", body: "not-json",
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body", expectedError: "Invalid request",
}, },
{ {
name: "missing email", name: "missing email",
@@ -1139,7 +1137,7 @@ func TestAuthHandlerConfirmAccountDeletion(t *testing.T) {
name: "invalid json", name: "invalid json",
body: "not-json", body: "not-json",
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body", expectedError: "Invalid request",
}, },
{ {
name: "missing token", name: "missing token",
@@ -1209,7 +1207,7 @@ func TestAuthHandlerConfirmAccountDeletion(t *testing.T) {
handler := newMockAuthHandler(repo, mockService) handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPost, "/api/auth/account/confirm", bytes.NewBufferString(tt.body)) request := createConfirmAccountDeletionRequest(tt.body)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.ConfirmAccountDeletion(recorder, request) handler.ConfirmAccountDeletion(recorder, request)
@@ -1338,9 +1336,7 @@ func TestAuthHandler_ConcurrentAccess(t *testing.T) {
for i := 0; i < concurrency; i++ { for i := 0; i < concurrency; i++ {
go func() { go func() {
body := bytes.NewBufferString(`{"username":"testuser","password":"Password123!"}`) req := createLoginRequest(`{"username":"testuser","password":"Password123!"}`)
req := httptest.NewRequest("POST", "/api/auth/login", body)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.Login(w, req) handler.Login(w, req)
@@ -1370,8 +1366,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
}, nil }, nil
} }
body := bytes.NewBufferString(`{"refresh_token":"valid_refresh_token"}`) req := createRefreshTokenRequest(`{"refresh_token":"valid_refresh_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1381,8 +1376,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
}) })
t.Run("Invalid_Request_Body", func(t *testing.T) { t.Run("Invalid_Request_Body", func(t *testing.T) {
body := bytes.NewBufferString(`invalid json`) req := createRefreshTokenRequest(`invalid json`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1392,8 +1386,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
}) })
t.Run("Missing_Refresh_Token", func(t *testing.T) { t.Run("Missing_Refresh_Token", func(t *testing.T) {
body := bytes.NewBufferString(`{"refresh_token":""}`) req := createRefreshTokenRequest(`{"refresh_token":""}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1407,8 +1400,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrRefreshTokenExpired return nil, services.ErrRefreshTokenExpired
} }
body := bytes.NewBufferString(`{"refresh_token":"expired_token"}`) req := createRefreshTokenRequest(`{"refresh_token":"expired_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1422,8 +1414,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrRefreshTokenInvalid return nil, services.ErrRefreshTokenInvalid
} }
body := bytes.NewBufferString(`{"refresh_token":"invalid_token"}`) req := createRefreshTokenRequest(`{"refresh_token":"invalid_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1437,8 +1428,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrAccountLocked return nil, services.ErrAccountLocked
} }
body := bytes.NewBufferString(`{"refresh_token":"locked_token"}`) req := createRefreshTokenRequest(`{"refresh_token":"locked_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1452,8 +1442,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, fmt.Errorf("internal error") return nil, fmt.Errorf("internal error")
} }
body := bytes.NewBufferString(`{"refresh_token":"error_token"}`) req := createRefreshTokenRequest(`{"refresh_token":"error_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1473,8 +1462,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
return nil return nil
} }
body := bytes.NewBufferString(`{"refresh_token":"token_to_revoke"}`) req := createRevokeTokenRequest(`{"refresh_token":"token_to_revoke"}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1484,8 +1472,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
}) })
t.Run("Invalid_Request_Body", func(t *testing.T) { t.Run("Invalid_Request_Body", func(t *testing.T) {
body := bytes.NewBufferString(`invalid json`) req := createRevokeTokenRequest(`invalid json`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1495,8 +1482,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
}) })
t.Run("Missing_Refresh_Token", func(t *testing.T) { t.Run("Missing_Refresh_Token", func(t *testing.T) {
body := bytes.NewBufferString(`{"refresh_token":""}`) req := createRevokeTokenRequest(`{"refresh_token":""}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1510,8 +1496,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
return fmt.Errorf("revoke failed") return fmt.Errorf("revoke failed")
} }
body := bytes.NewBufferString(`{"refresh_token":"token"}`) req := createRevokeTokenRequest(`{"refresh_token":"token"}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()

View File

@@ -1,6 +1,7 @@
package handlers package handlers
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
@@ -11,6 +12,7 @@ import (
"testing" "testing"
"goyco/internal/database" "goyco/internal/database"
"goyco/internal/dto"
"goyco/internal/middleware" "goyco/internal/middleware"
"goyco/internal/services" "goyco/internal/services"
"goyco/internal/testutils" "goyco/internal/testutils"
@@ -721,6 +723,74 @@ func TestDecodeJSONRequest(t *testing.T) {
} }
} }
func createRequestWithDTO[T any](method, url string, body []byte) *http.Request {
r := httptest.NewRequest(method, url, bytes.NewReader(body))
var dto T
if len(body) > 0 {
if err := json.NewDecoder(bytes.NewReader(body)).Decode(&dto); err != nil {
return r
}
}
ctx := middleware.SetValidatedDTOInContext(r.Context(), &dto)
return r.WithContext(ctx)
}
func createLoginRequest(body string) *http.Request {
return createRequestWithDTO[dto.LoginRequest](http.MethodPost, "/api/auth/login", []byte(body))
}
func createRegisterRequest(body string) *http.Request {
return createRequestWithDTO[dto.RegisterRequest](http.MethodPost, "/api/auth/register", []byte(body))
}
func createResendVerificationRequest(body string) *http.Request {
return createRequestWithDTO[dto.ResendVerificationRequest](http.MethodPost, "/api/auth/resend-verification", []byte(body))
}
func createForgotPasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.ForgotPasswordRequest](http.MethodPost, "/api/auth/forgot-password", []byte(body))
}
func createResetPasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.ResetPasswordRequest](http.MethodPost, "/api/auth/reset-password", []byte(body))
}
func createUpdateEmailRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdateEmailRequest](http.MethodPut, "/api/auth/email", []byte(body))
}
func createUpdateUsernameRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdateUsernameRequest](http.MethodPut, "/api/auth/username", []byte(body))
}
func createUpdatePasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdatePasswordRequest](http.MethodPut, "/api/auth/password", []byte(body))
}
func createConfirmAccountDeletionRequest(body string) *http.Request {
return createRequestWithDTO[dto.ConfirmAccountDeletionRequest](http.MethodPost, "/api/auth/account/confirm", []byte(body))
}
func createRefreshTokenRequest(body string) *http.Request {
return createRequestWithDTO[dto.RefreshTokenRequest](http.MethodPost, "/api/auth/refresh", []byte(body))
}
func createRevokeTokenRequest(body string) *http.Request {
return createRequestWithDTO[dto.RevokeTokenRequest](http.MethodPost, "/api/auth/revoke", []byte(body))
}
func createCreatePostRequest(body string) *http.Request {
return createRequestWithDTO[dto.CreatePostRequest](http.MethodPost, "/api/posts", []byte(body))
}
func createUpdatePostRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdatePostRequest](http.MethodPut, "/api/posts/1", []byte(body))
}
func createVoteRequest(body string) *http.Request {
return createRequestWithDTO[dto.CastVoteRequest](http.MethodPost, "/api/posts/1/vote", []byte(body))
}
func TestParsePagination(t *testing.T) { func TestParsePagination(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"html/template" "html/template"
"net/http" "net/http"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
@@ -877,7 +878,8 @@ func (h *PageHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
func (h *PageHandler) Settings(w http.ResponseWriter, r *http.Request) { func (h *PageHandler) Settings(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r) user := h.currentUserWithLockCheck(w, r)
if user == nil { if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther) redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return return
} }
@@ -897,7 +899,8 @@ func (h *PageHandler) Settings(w http.ResponseWriter, r *http.Request) {
func (h *PageHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) { func (h *PageHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r) user := h.currentUserWithLockCheck(w, r)
if user == nil { if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther) redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return return
} }
@@ -960,13 +963,15 @@ func (h *PageHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
} }
h.clearAuthCookie(w, r) h.clearAuthCookie(w, r)
http.Redirect(w, r, "/login?flash=Email updated. Check your inbox to confirm the new address. You will need to sign in again after verification.", http.StatusSeeOther) redirectURL := "/login?flash=" + url.QueryEscape("Email updated. Check your inbox to confirm the new address. You will need to sign in again after verification.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
} }
func (h *PageHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) { func (h *PageHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r) user := h.currentUserWithLockCheck(w, r)
if user == nil { if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther) redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return return
} }
@@ -1022,13 +1027,15 @@ func (h *PageHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
return return
} }
http.Redirect(w, r, "/settings?flash=Username updated successfully.", http.StatusSeeOther) redirectURL := "/settings?flash=" + url.QueryEscape("Username updated successfully.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
} }
func (h *PageHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) { func (h *PageHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r) user := h.currentUserWithLockCheck(w, r)
if user == nil { if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther) redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return return
} }
@@ -1140,13 +1147,15 @@ func (h *PageHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
return return
} }
http.Redirect(w, r, "/settings?flash=Password updated successfully.", http.StatusSeeOther) redirectURL := "/settings?flash=" + url.QueryEscape("Password updated successfully.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
} }
func (h *PageHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) { func (h *PageHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r) user := h.currentUserWithLockCheck(w, r)
if user == nil { if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther) redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return return
} }
@@ -1204,7 +1213,8 @@ func (h *PageHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
return return
} }
http.Redirect(w, r, "/settings?flash=Check your inbox for a confirmation link to finish deleting your account.", http.StatusSeeOther) redirectURL := "/settings?flash=" + url.QueryEscape("Check your inbox for a confirmation link to finish deleting your account.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
} }
func (h *PageHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Request) { func (h *PageHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Request) {
@@ -1328,7 +1338,8 @@ func (h *PageHandler) clearAuthCookie(w http.ResponseWriter, r *http.Request) {
func (h *PageHandler) Vote(w http.ResponseWriter, r *http.Request) { func (h *PageHandler) Vote(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r) user := h.currentUserWithLockCheck(w, r)
if user == nil { if user == nil {
http.Redirect(w, r, "/login?flash=Please sign in to vote", http.StatusSeeOther) redirectURL := "/login?flash=" + url.QueryEscape("Please sign in to vote")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return return
} }

View File

@@ -130,6 +130,11 @@ func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
url := security.SanitizeURL(req.URL) url := security.SanitizeURL(req.URL)
content := security.SanitizePostContent(req.Content) content := security.SanitizePostContent(req.Content)
if url == "" {
SendErrorResponse(w, "Invalid URL", http.StatusBadRequest)
return
}
if title == "" && h.titleFetcher != nil { if title == "" && h.titleFetcher != nil {
titleCtx, cancel := context.WithTimeout(r.Context(), 7*time.Second) titleCtx, cancel := context.WithTimeout(r.Context(), 7*time.Second)
defer cancel() defer cancel()
@@ -160,6 +165,16 @@ func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
return return
} }
if len(title) > 200 {
SendErrorResponse(w, "Title must be at most 200 characters", http.StatusBadRequest)
return
}
if len(content) > 10000 {
SendErrorResponse(w, "Content must be at most 10000 characters", http.StatusBadRequest)
return
}
post := &database.Post{ post := &database.Post{
Title: title, Title: title,
URL: url, URL: url,

View File

@@ -69,9 +69,8 @@ func TestPostHandlerCreatePostWithTitleFetcher(t *testing.T) {
handler := NewPostHandler(repo, titleFetcher, nil) handler := NewPostHandler(repo, titleFetcher, nil)
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"url":"https://example.com","content":"Test content"}`)) request := createCreatePostRequest(`{"url":"https://example.com","content":"Test content"}`)
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1)) request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request) handler.CreatePost(recorder, request)
@@ -171,7 +170,7 @@ func TestPostHandlerUpdatePostUnauthorized(t *testing.T) {
handler := NewPostHandler(repo, nil, nil) handler := NewPostHandler(repo, nil, nil)
request := httptest.NewRequest(http.MethodPut, "/api/posts/1", bytes.NewBufferString(`{"title":"Updated Title","content":"Updated content"}`)) request := createUpdatePostRequest(`{"title":"Updated Title","content":"Updated content"}`)
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1)) request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
@@ -278,8 +277,7 @@ func TestPostHandlerCreatePostSuccess(t *testing.T) {
handler := NewPostHandler(repo, fetcher, nil) handler := NewPostHandler(repo, fetcher, nil)
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com","content":"Go"}`) request := createCreatePostRequest(`{"title":" ","url":"https://example.com","content":"Go"}`)
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(42)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(42))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -297,7 +295,7 @@ func TestPostHandlerCreatePostValidation(t *testing.T) {
handler := NewPostHandler(testutils.NewPostRepositoryStub(), &testutils.TitleFetcherStub{}, nil) handler := NewPostHandler(testutils.NewPostRepositoryStub(), &testutils.TitleFetcherStub{}, nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"","url":"","content":""}`)) request := createCreatePostRequest(`{"title":"","url":"","content":""}`)
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1))) request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
handler.CreatePost(recorder, request) handler.CreatePost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest { if recorder.Result().StatusCode != http.StatusBadRequest {
@@ -305,14 +303,14 @@ func TestPostHandlerCreatePostValidation(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`invalid json`)) request = createCreatePostRequest(`invalid json`)
handler.CreatePost(recorder, request) handler.CreatePost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest { if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid JSON, got %d", recorder.Result().StatusCode) t.Fatalf("expected 400 for invalid JSON, got %d", recorder.Result().StatusCode)
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"ok","url":"https://example.com"}`)) request = createCreatePostRequest(`{"title":"ok","url":"https://example.com"}`)
handler.CreatePost(recorder, request) handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized) testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
} }
@@ -336,8 +334,7 @@ func TestPostHandlerCreatePostTitleFetcherErrors(t *testing.T) {
return "", tc.err return "", tc.err
}} }}
handler := NewPostHandler(repo, fetcher, nil) handler := NewPostHandler(repo, fetcher, nil)
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com"}`) request := createCreatePostRequest(`{"title":" ","url":"https://example.com"}`)
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1))) request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -495,7 +492,7 @@ func TestPostHandlerUpdatePost(t *testing.T) {
} }
handler := NewPostHandler(repo, &testutils.TitleFetcherStub{}, nil) handler := NewPostHandler(repo, &testutils.TitleFetcherStub{}, nil)
request := httptest.NewRequest(http.MethodPut, "/api/posts/"+tt.postID, bytes.NewBufferString(tt.requestBody)) request := createUpdatePostRequest(tt.requestBody)
if tt.userID > 0 { if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID) ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)

View File

@@ -24,10 +24,6 @@ func TestPostHandler_XSSProtection_Comprehensive(t *testing.T) {
t.Run("XSS_"+payload[:minLen(20, len(payload))], func(t *testing.T) { t.Run("XSS_"+payload[:minLen(20, len(payload))], func(t *testing.T) {
repo := &testutils.PostRepositoryStub{ repo := &testutils.PostRepositoryStub{
CreateFn: func(post *database.Post) error { CreateFn: func(post *database.Post) error {
sanitizedTitle := security.SanitizeInput(payload)
if post.Title != sanitizedTitle {
t.Errorf("Expected sanitized title, got %q", post.Title)
}
return nil return nil
}, },
} }
@@ -41,14 +37,46 @@ func TestPostHandler_XSSProtection_Comprehensive(t *testing.T) {
} }
body, _ := json.Marshal(postData) body, _ := json.Marshal(postData)
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) request := createCreatePostRequest(string(body))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1))) request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request) handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated) if recorder.Code != http.StatusCreated {
t.Errorf("Expected status %d, got %d. Body: %s", http.StatusCreated, recorder.Code, recorder.Body.String())
return
}
var response CommonResponse
if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}
if !response.Success {
t.Errorf("Expected successful response, got error: %s", response.Error)
return
}
dataMap, ok := response.Data.(map[string]any)
if !ok {
t.Fatalf("Expected data to be a map, got %T", response.Data)
}
title, ok := dataMap["title"].(string)
if !ok {
t.Fatalf("Expected title to be a string, got %T", dataMap["title"])
}
expectedSanitized := security.SanitizeInput(payload)
if title != expectedSanitized {
t.Errorf("Expected sanitized title %q, got %q", expectedSanitized, title)
}
if title == payload {
t.Errorf("Title was not sanitized - original payload %q matches response %q", payload, title)
}
}) })
} }
} }
@@ -123,7 +151,7 @@ func TestPostHandler_InputValidation(t *testing.T) {
} }
body, _ := json.Marshal(postData) body, _ := json.Marshal(postData)
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) request := createCreatePostRequest(string(body))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1))) request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -230,7 +258,7 @@ func TestAuthHandler_PasswordValidation(t *testing.T) {
} }
body, _ := json.Marshal(registerData) body, _ := json.Marshal(registerData)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body)) request := createRegisterRequest(string(body))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -290,7 +318,7 @@ func TestAuthHandler_UsernameSanitization(t *testing.T) {
} }
body, _ := json.Marshal(registerData) body, _ := json.Marshal(registerData)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body)) request := createRegisterRequest(string(body))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()

View File

@@ -1,7 +1,6 @@
package handlers package handlers
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -103,7 +102,7 @@ func TestUserHandlerCreateUser(t *testing.T) {
return nil return nil
}}) }})
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`)) request := createRegisterRequest(`{"username":"user","email":"user@example.com","password":"Password123!"}`)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.CreateUser(recorder, request) handler.CreateUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated) testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
@@ -126,14 +125,14 @@ func TestUserHandlerCreateUser(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString("invalid")) request = createRegisterRequest("invalid")
handler.CreateUser(recorder, request) handler.CreateUser(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest { if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid json, got %d", recorder.Result().StatusCode) t.Fatalf("expected 400 for invalid json, got %d", recorder.Result().StatusCode)
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"","email":"","password":""}`)) request = createRegisterRequest(`{"username":"","email":"","password":""}`)
handler.CreateUser(recorder, request) handler.CreateUser(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest { if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for missing fields, got %d", recorder.Result().StatusCode) t.Fatalf("expected 400 for missing fields, got %d", recorder.Result().StatusCode)
@@ -144,7 +143,7 @@ func TestUserHandlerCreateUser(t *testing.T) {
} }
handler = newUserHandler(repo) handler = newUserHandler(repo)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`)) request = createRegisterRequest(`{"username":"user","email":"user@example.com","password":"Password123!"}`)
handler.CreateUser(recorder, request) handler.CreateUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict) testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
} }
@@ -350,7 +349,7 @@ func TestUserHandler_PasswordValidation(t *testing.T) {
handler := NewUserHandler(repo, authService) handler := NewUserHandler(repo, authService)
requestBody := fmt.Sprintf(`{"username":"testuser","email":"test@example.com","password":"%s"}`, tt.password) requestBody := fmt.Sprintf(`{"username":"testuser","email":"test@example.com","password":"%s"}`, tt.password)
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(requestBody)) request := createRegisterRequest(requestBody)
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()

View File

@@ -58,7 +58,7 @@ type VoteResponse = CommonResponse
// @Produce json // @Produce json
// @Security BearerAuth // @Security BearerAuth
// @Param id path int true "Post ID" // @Param id path int true "Post ID"
// @Param request body dto.VoteRequest true "Vote data (type: 'up', 'down', or 'none' to remove)" // @Param request body dto.CastVoteRequest true "Vote data (type: 'up', 'down', or 'none' to remove)"
// @Success 200 {object} VoteResponse "Vote cast successfully with updated post statistics" // @Success 200 {object} VoteResponse "Vote cast successfully with updated post statistics"
// @Failure 401 {object} VoteResponse "Authentication required" // @Failure 401 {object} VoteResponse "Authentication required"
// @Failure 400 {object} VoteResponse "Invalid request data or vote type" // @Failure 400 {object} VoteResponse "Invalid request data or vote type"
@@ -78,7 +78,7 @@ func (h *VoteHandler) CastVote(w http.ResponseWriter, r *http.Request) {
return return
} }
req, ok := GetValidatedDTO[dto.VoteRequest](r) req, ok := GetValidatedDTO[dto.CastVoteRequest](r)
if !ok { if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest) SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return return
@@ -283,7 +283,7 @@ func (h *VoteHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
protected = config.GeneralRateLimit(protected) protected = config.GeneralRateLimit(protected)
} }
protected.Post("/posts/{id}/vote", WithValidation[dto.VoteRequest](config.ValidationMiddleware, h.CastVote)) protected.Post("/posts/{id}/vote", WithValidation[dto.CastVoteRequest](config.ValidationMiddleware, h.CastVote))
protected.Delete("/posts/{id}/vote", h.RemoveVote) protected.Delete("/posts/{id}/vote", h.RemoveVote)
protected.Get("/posts/{id}/vote", h.GetUserVote) protected.Get("/posts/{id}/vote", h.GetUserVote)
protected.Get("/posts/{id}/votes", h.GetPostVotes) protected.Get("/posts/{id}/votes", h.GetPostVotes)

View File

@@ -1,7 +1,6 @@
package handlers package handlers
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@@ -59,13 +58,13 @@ func TestVoteHandlerCastVote(t *testing.T) {
handler := newVoteHandlerWithRepos() handler := newVoteHandlerWithRepos()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
handler.CastVote(recorder, request) handler.CastVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized) testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/abc/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "abc"}) request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -73,7 +72,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`invalid`)) request = createVoteRequest(`invalid`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -83,7 +82,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"maybe"}`)) request = createVoteRequest(`{"type":"maybe"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -93,7 +92,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -101,7 +100,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusOK) testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`)) request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -111,7 +110,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`)) request = createVoteRequest(`{"type":"none"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -125,7 +124,7 @@ func TestVoteHandlerCastVotePostNotFound(t *testing.T) {
handler, _, posts := newVoteHandlerWithReposRefs() handler, _, posts := newVoteHandlerWithReposRefs()
delete(posts, 1) delete(posts, 1)
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -164,7 +163,7 @@ func TestVoteHandlerRemoveVote(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -202,7 +201,7 @@ func TestVoteHandlerRemoveVotePostNotFound(t *testing.T) {
func TestVoteHandlerRemoveVoteUnexpectedError(t *testing.T) { func TestVoteHandlerRemoveVoteUnexpectedError(t *testing.T) {
handler, voteRepo, _ := newVoteHandlerWithReposRefs() handler, voteRepo, _ := newVoteHandlerWithReposRefs()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -257,7 +256,7 @@ func TestVoteHandlerGetUserVote(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -301,7 +300,7 @@ func TestVoteHandlerGetPostVotes(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -311,7 +310,7 @@ func TestVoteHandlerGetPostVotes(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`)) request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -345,7 +344,7 @@ func TestVoteFlowRegression(t *testing.T) {
postID := "1" postID := "1"
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, userID) ctx := context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -363,7 +362,7 @@ func TestVoteFlowRegression(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`)) request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID) ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -373,7 +372,7 @@ func TestVoteFlowRegression(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`)) request = createVoteRequest(`{"type":"none"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID) ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -404,7 +403,7 @@ func TestVoteFlowRegression(t *testing.T) {
postID := "1" postID := "1"
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -414,7 +413,7 @@ func TestVoteFlowRegression(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`)) request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -424,7 +423,7 @@ func TestVoteFlowRegression(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -452,7 +451,7 @@ func TestVoteFlowRegression(t *testing.T) {
t.Run("ErrorHandlingEdgeCases", func(t *testing.T) { t.Run("ErrorHandlingEdgeCases", func(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(``)) request := createVoteRequest(``)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -460,7 +459,7 @@ func TestVoteFlowRegression(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{}`)) request = createVoteRequest(`{}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -470,7 +469,7 @@ func TestVoteFlowRegression(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"invalid"}`)) request = createVoteRequest(`{"type":"invalid"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)

View File

@@ -53,25 +53,25 @@ func TestIntegration_Caching(t *testing.T) {
router := ctx.Router router := ctx.Router
t.Run("Cache_Hit_On_Repeated_Requests", func(t *testing.T) { t.Run("Cache_Hit_On_Repeated_Requests", func(t *testing.T) {
req1 := httptest.NewRequest("GET", "/api/posts", nil) firstRequest := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder() firstRecorder := httptest.NewRecorder()
router.ServeHTTP(rec1, req1) router.ServeHTTP(firstRecorder, firstRequest)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
req2 := httptest.NewRequest("GET", "/api/posts", nil) secondRequest := httptest.NewRequest("GET", "/api/posts", nil)
rec2 := httptest.NewRecorder() secondRecorder := httptest.NewRecorder()
router.ServeHTTP(rec2, req2) router.ServeHTTP(secondRecorder, secondRequest)
if rec1.Code != rec2.Code { if firstRecorder.Code != secondRecorder.Code {
t.Error("Cached responses should have same status code") t.Error("Cached responses should have same status code")
} }
if rec1.Body.String() != rec2.Body.String() { if firstRecorder.Body.String() != secondRecorder.Body.String() {
t.Error("Cached responses should have same body") t.Error("Cached responses should have same body")
} }
if rec2.Header().Get("X-Cache") != "HIT" { if secondRecorder.Header().Get("X-Cache") != "HIT" {
t.Log("Cache may not be enabled for this path or response may not be cacheable") t.Log("Cache may not be enabled for this path or response may not be cacheable")
} }
}) })
@@ -80,9 +80,9 @@ func TestIntegration_Caching(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "cache_post_user", "cache_post@example.com") user := createUserWithCleanup(t, ctx, "cache_post_user", "cache_post@example.com")
req1 := httptest.NewRequest("GET", "/api/posts", nil) firstRequest := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder() firstRecorder := httptest.NewRecorder()
router.ServeHTTP(rec1, req1) router.ServeHTTP(firstRecorder, firstRequest)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
@@ -92,12 +92,12 @@ func TestIntegration_Caching(t *testing.T) {
"content": "Test content", "content": "Test content",
} }
body, _ := json.Marshal(postBody) body, _ := json.Marshal(postBody)
req2 := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) secondRequest := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req2.Header.Set("Content-Type", "application/json") secondRequest.Header.Set("Content-Type", "application/json")
req2.Header.Set("Authorization", "Bearer "+user.Token) secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID) secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder() secondRecorder := httptest.NewRecorder()
router.ServeHTTP(rec2, req2) router.ServeHTTP(secondRecorder, secondRequest)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
@@ -105,17 +105,17 @@ func TestIntegration_Caching(t *testing.T) {
rec3 := httptest.NewRecorder() rec3 := httptest.NewRecorder()
router.ServeHTTP(rec3, req3) router.ServeHTTP(rec3, req3)
if rec1.Body.String() == rec3.Body.String() && rec1.Code == http.StatusOK && rec3.Code == http.StatusOK { if firstRecorder.Body.String() == rec3.Body.String() && firstRecorder.Code == http.StatusOK && rec3.Code == http.StatusOK {
t.Log("Cache invalidation may not be working or cache may not be enabled") t.Log("Cache invalidation may not be working or cache may not be enabled")
} }
}) })
t.Run("Cache_Headers_Present", func(t *testing.T) { t.Run("Cache_Headers_Present", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil) request := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Header().Get("Cache-Control") == "" && rec.Header().Get("X-Cache") == "" { if recorder.Header().Get("Cache-Control") == "" && recorder.Header().Get("X-Cache") == "" {
t.Log("Cache headers may not be present for all responses") t.Log("Cache headers may not be present for all responses")
} }
}) })
@@ -126,18 +126,18 @@ func TestIntegration_Caching(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Cache Delete Post", "https://example.com/cache-delete") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Cache Delete Post", "https://example.com/cache-delete")
req1 := httptest.NewRequest("GET", "/api/posts", nil) firstRequest := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder() firstRecorder := httptest.NewRecorder()
router.ServeHTTP(rec1, req1) router.ServeHTTP(firstRecorder, firstRequest)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
req2 := httptest.NewRequest("DELETE", "/api/posts/"+fmt.Sprintf("%d", post.ID), nil) secondRequest := httptest.NewRequest("DELETE", "/api/posts/"+fmt.Sprintf("%d", post.ID), nil)
req2.Header.Set("Authorization", "Bearer "+user.Token) secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID) secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
req2 = testutils.WithURLParams(req2, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) secondRequest = testutils.WithURLParams(secondRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec2 := httptest.NewRecorder() secondRecorder := httptest.NewRecorder()
router.ServeHTTP(rec2, req2) router.ServeHTTP(secondRecorder, secondRequest)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
@@ -145,7 +145,7 @@ func TestIntegration_Caching(t *testing.T) {
rec3 := httptest.NewRecorder() rec3 := httptest.NewRecorder()
router.ServeHTTP(rec3, req3) router.ServeHTTP(rec3, req3)
if rec1.Body.String() == rec3.Body.String() && rec1.Code == http.StatusOK && rec3.Code == http.StatusOK { if firstRecorder.Body.String() == rec3.Body.String() && firstRecorder.Code == http.StatusOK && rec3.Code == http.StatusOK {
t.Log("Cache invalidation may not be working or cache may not be enabled") t.Log("Cache invalidation may not be working or cache may not be enabled")
} }
}) })

View File

@@ -1,15 +1,14 @@
package integration package integration
import ( import (
"bytes"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"strings"
"testing" "testing"
"time"
"goyco/internal/middleware" "goyco/internal/services"
"goyco/internal/testutils" "goyco/internal/testutils"
) )
@@ -20,17 +19,8 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "logout_user", "logout@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "logout_user", "logout@example.com")
reqBody := map[string]string{} request := makePostRequest(t, ctx.Router, "/api/auth/logout", map[string]any{}, user, nil)
body, _ := json.Marshal(reqBody) assertStatus(t, request, http.StatusOK)
req := httptest.NewRequest("POST", "/api/auth/logout", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("Auth_Revoke_Token_Endpoint", func(t *testing.T) { t.Run("Auth_Revoke_Token_Endpoint", func(t *testing.T) {
@@ -42,52 +32,23 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
t.Fatalf("Failed to login: %v", err) t.Fatalf("Failed to login: %v", err)
} }
reqBody := map[string]string{ request := makePostRequest(t, ctx.Router, "/api/auth/revoke", map[string]any{"refresh_token": loginResult.RefreshToken}, user, nil)
"refresh_token": loginResult.RefreshToken, assertStatus(t, request, http.StatusOK)
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/revoke", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("Auth_Revoke_All_Tokens_Endpoint", func(t *testing.T) { t.Run("Auth_Revoke_All_Tokens_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "revoke_all_user", "revoke_all@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "revoke_all_user", "revoke_all@example.com")
reqBody := map[string]string{} request := makePostRequest(t, ctx.Router, "/api/auth/revoke-all", map[string]any{}, user, nil)
body, _ := json.Marshal(reqBody) assertStatus(t, request, http.StatusOK)
req := httptest.NewRequest("POST", "/api/auth/revoke-all", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("Auth_Resend_Verification_Endpoint", func(t *testing.T) { t.Run("Auth_Resend_Verification_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{ request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/resend-verification", map[string]any{"email": "resend@example.com"})
"email": "resend@example.com", assertStatusRange(t, request, http.StatusOK, http.StatusNotFound)
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/resend-verification", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatusRange(t, rec, http.StatusOK, http.StatusNotFound)
}) })
t.Run("Auth_Confirm_Email_Endpoint", func(t *testing.T) { t.Run("Auth_Confirm_Email_Endpoint", func(t *testing.T) {
@@ -99,109 +60,66 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
token = "test-token" token = "test-token"
} }
req := httptest.NewRequest("GET", "/api/auth/confirm?token="+url.QueryEscape(token), nil) request := makeGetRequest(t, ctx.Router, "/api/auth/confirm?token="+url.QueryEscape(token))
rec := httptest.NewRecorder() assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
ctx.Router.ServeHTTP(rec, req)
assertStatusRange(t, rec, http.StatusOK, http.StatusBadRequest)
}) })
t.Run("Auth_Update_Email_Endpoint", func(t *testing.T) { t.Run("Auth_Update_Email_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "update_email_api_user", "update_email_api@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "update_email_api_user", "update_email_api@example.com")
reqBody := map[string]string{ request := makePutRequest(t, ctx.Router, "/api/auth/email", map[string]any{"email": "newemail@example.com"}, user, nil)
"email": "newemail@example.com",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if email, ok := data["email"].(string); ok && email != "newemail@example.com" { if email, ok := data["email"].(string); ok && email != "newemail@example.com" {
t.Errorf("Expected email to be updated, got %s", email) t.Errorf("Expected email to be updated, got %s", email)
} }
} }
}
}) })
t.Run("Auth_Update_Username_Endpoint", func(t *testing.T) { t.Run("Auth_Update_Username_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "update_username_api_user", "update_username_api@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "update_username_api_user", "update_username_api@example.com")
reqBody := map[string]string{ request := makePutRequest(t, ctx.Router, "/api/auth/username", map[string]any{"username": "new_username"}, user, nil)
"username": "new_username",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/auth/username", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if username, ok := data["username"].(string); ok && username != "new_username" { if username, ok := data["username"].(string); ok && username != "new_username" {
t.Errorf("Expected username to be updated, got %s", username) t.Errorf("Expected username to be updated, got %s", username)
} }
} }
}
}) })
t.Run("Users_List_Endpoint", func(t *testing.T) { t.Run("Users_List_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_list_user", "users_list@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_list_user", "users_list@example.com")
req := httptest.NewRequest("GET", "/api/users", nil) request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/users", user, nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["users"]; !exists { if _, exists := data["users"]; !exists {
t.Error("Expected users in response") t.Error("Expected users in response")
} }
} }
}
}) })
t.Run("Users_Get_By_ID_Endpoint", func(t *testing.T) { t.Run("Users_Get_By_ID_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_get_user", "users_get@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_get_user", "users_get@example.com")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d", user.User.ID), nil) request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if userData, ok := data["user"].(map[string]any); ok { if userData, ok := data["user"].(map[string]any); ok {
if id, ok := userData["id"].(float64); ok && uint(id) != user.User.ID { if id, ok := userData["id"].(float64); ok && uint(id) != user.User.ID {
t.Errorf("Expected user ID %d, got %.0f", user.User.ID, id) t.Errorf("Expected user ID %d, got %.0f", user.User.ID, id)
} }
} }
} }
}
}) })
t.Run("Users_Get_Posts_Endpoint", func(t *testing.T) { t.Run("Users_Get_Posts_Endpoint", func(t *testing.T) {
@@ -210,17 +128,10 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "User Posts Test", "https://example.com/user-posts") testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "User Posts Test", "https://example.com/user-posts")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d/posts", user.User.ID), nil) request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d/posts", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if posts, ok := data["posts"].([]any); ok { if posts, ok := data["posts"].([]any); ok {
if len(posts) == 0 { if len(posts) == 0 {
t.Error("Expected at least one post in response") t.Error("Expected at least one post in response")
@@ -229,35 +140,24 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
t.Error("Expected posts array in response") t.Error("Expected posts array in response")
} }
} }
}
}) })
t.Run("Users_Create_Endpoint", func(t *testing.T) { t.Run("Users_Create_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_create_admin", "users_create_admin@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_create_admin", "users_create_admin@example.com")
reqBody := map[string]string{ request := makePostRequest(t, ctx.Router, "/api/users", map[string]any{
"username": "created_user", "username": "created_user",
"email": "created@example.com", "email": "created@example.com",
"password": "SecurePass123!", "password": "SecurePass123!",
} }, user, nil)
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/users", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusCreated)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusCreated)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["user"]; !exists { if _, exists := data["user"]; !exists {
t.Error("Expected user in response") t.Error("Expected user in response")
} }
} }
}
}) })
t.Run("Posts_Update_Endpoint", func(t *testing.T) { t.Run("Posts_Update_Endpoint", func(t *testing.T) {
@@ -266,30 +166,19 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Update Test Post", "https://example.com/update-test") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Update Test Post", "https://example.com/update-test")
reqBody := map[string]string{ request := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title", "title": "Updated Title",
"content": "Updated content", "content": "Updated content",
} }, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if postData, ok := data["post"].(map[string]any); ok { if postData, ok := data["post"].(map[string]any); ok {
if title, ok := postData["title"].(string); ok && title != "Updated Title" { if title, ok := postData["title"].(string); ok && title != "Updated Title" {
t.Errorf("Expected title 'Updated Title', got '%s'", title) t.Errorf("Expected title 'Updated Title', got '%s'", title)
} }
} }
} }
}
}) })
t.Run("Posts_Delete_Endpoint", func(t *testing.T) { t.Run("Posts_Delete_Endpoint", func(t *testing.T) {
@@ -298,20 +187,11 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Delete Test Post", "https://example.com/delete-test") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Delete Test Post", "https://example.com/delete-test")
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil) request := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req.Header.Set("Authorization", "Bearer "+user.Token) assertStatus(t, request, http.StatusOK)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) getRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID))
assertStatus(t, getRequest, http.StatusNotFound)
assertStatus(t, rec, http.StatusOK)
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
assertStatus(t, getRec, http.StatusNotFound)
}) })
t.Run("Votes_Get_All_Endpoint", func(t *testing.T) { t.Run("Votes_Get_All_Endpoint", func(t *testing.T) {
@@ -319,28 +199,11 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "votes_get_all_user", "votes_get_all@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "votes_get_all_user", "votes_get_all@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Votes Test Post", "https://example.com/votes-test") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Votes Test Post", "https://example.com/votes-test")
makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteBody := map[string]string{"type": "up"} response := assertJSONResponse(t, request, http.StatusOK)
voteBodyBytes, _ := json.Marshal(voteBody) if data, ok := getDataFromResponse(response); ok {
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBodyBytes))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
req := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if votes, ok := data["votes"].([]any); ok { if votes, ok := data["votes"].([]any); ok {
if len(votes) == 0 { if len(votes) == 0 {
t.Error("Expected at least one vote in response") t.Error("Expected at least one vote in response")
@@ -349,7 +212,6 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
t.Error("Expected votes array in response") t.Error("Expected votes array in response")
} }
} }
}
}) })
t.Run("Votes_Remove_Endpoint", func(t *testing.T) { t.Run("Votes_Remove_Endpoint", func(t *testing.T) {
@@ -358,49 +220,461 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Remove Test", "https://example.com/vote-remove") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Remove Test", "https://example.com/vote-remove")
voteBody := map[string]string{"type": "up"} makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteBodyBytes, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBodyBytes))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil) request := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req.Header.Set("Authorization", "Bearer "+user.Token) assertStatus(t, request, http.StatusOK)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("API_Info_Endpoint", func(t *testing.T) { t.Run("API_Info_Endpoint", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api", nil) request := makeGetRequest(t, ctx.Router, "/api")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["endpoints"]; !exists { if _, exists := data["endpoints"]; !exists {
t.Error("Expected endpoints in API info") t.Error("Expected endpoints in API info")
} }
} }
}
}) })
t.Run("Swagger_Documentation_Endpoint", func(t *testing.T) { t.Run("Swagger_Documentation_Endpoint", func(t *testing.T) {
req := httptest.NewRequest("GET", "/swagger/index.html", nil) request := makeGetRequest(t, ctx.Router, "/swagger/index.html")
rec := httptest.NewRecorder() assertStatusRange(t, request, http.StatusOK, http.StatusNotFound)
})
ctx.Router.ServeHTTP(rec, req) t.Run("Search_Endpoint_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "search_edge"), uniqueTestEmail(t, "search_edge"))
assertStatusRange(t, rec, http.StatusOK, http.StatusNotFound) testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Searchable Post One", "https://example.com/one")
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Searchable Post Two", "https://example.com/two")
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Different Content", "https://example.com/three")
t.Run("Empty_Search_Results", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=nonexistentterm12345")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) != 0 {
t.Errorf("Expected empty search results, got %d posts", len(posts))
}
}
if count, ok := data["count"].(float64); ok && count != 0 {
t.Errorf("Expected count 0, got %.0f", count)
}
}
})
t.Run("Search_With_Pagination", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=Searchable&limit=1&offset=0")
response := assertJSONResponse(t, request, http.StatusOK)
var firstPostID any
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) > 1 {
t.Errorf("Expected at most 1 post with limit=1, got %d", len(posts))
}
if len(posts) > 0 {
if post, ok := posts[0].(map[string]any); ok {
firstPostID = post["id"]
}
}
}
if limit, ok := data["limit"].(float64); ok && limit != 1 {
t.Errorf("Expected limit 1 in response, got %.0f", limit)
}
if offset, ok := data["offset"].(float64); ok && offset != 0 {
t.Errorf("Expected offset 0 in response, got %.0f", offset)
}
}
secondRequest := makeGetRequest(t, ctx.Router, "/api/posts/search?q=Searchable&limit=1&offset=1")
secondResponse := assertJSONResponse(t, secondRequest, http.StatusOK)
if data, ok := getDataFromResponse(secondResponse); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) > 1 {
t.Errorf("Expected at most 1 post with limit=1 and offset=1, got %d", len(posts))
}
if len(posts) > 0 && firstPostID != nil {
if post, ok := posts[0].(map[string]any); ok {
if post["id"] == firstPostID {
t.Error("Expected different post with offset=1, got same post as offset=0")
}
}
}
}
}
})
t.Run("Search_With_Special_Characters", func(t *testing.T) {
specialQueries := []string{
"Searchable%20Post",
"Searchable'Post",
"Searchable\"Post",
"Searchable;Post",
"Searchable--Post",
}
for _, query := range specialQueries {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q="+url.QueryEscape(query))
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
}
})
t.Run("Search_Empty_Query", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) != 0 {
t.Errorf("Expected empty results for empty query, got %d posts", len(posts))
}
}
if count, ok := data["count"].(float64); ok && count != 0 {
t.Errorf("Expected count 0 for empty query, got %.0f", count)
}
}
})
t.Run("Search_With_Very_Long_Query", func(t *testing.T) {
longQuery := strings.Repeat("a", 1000)
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q="+url.QueryEscape(longQuery))
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Search_Case_Insensitive", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=SEARCHABLE")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) == 0 {
t.Error("Expected case-insensitive search to find posts")
}
}
}
})
})
t.Run("Title_Fetch_Endpoint_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
t.Run("Missing_URL_Parameter", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/title")
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Empty_URL_Parameter", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url=")
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Invalid_URL_Format", func(t *testing.T) {
invalidURLs := []string{
"not-a-url",
"://invalid",
"http://",
"https://",
}
for _, invalidURL := range invalidURLs {
ctx.Suite.TitleFetcher.SetError(services.ErrUnsupportedScheme)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(invalidURL))
assertErrorResponse(t, request, http.StatusBadRequest)
}
})
t.Run("Unsupported_URL_Schemes", func(t *testing.T) {
unsupportedSchemes := []string{
"ftp://example.com",
"file:///etc/passwd",
"javascript:alert(1)",
"data:text/html,<script>alert(1)</script>",
}
for _, schemeURL := range unsupportedSchemes {
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(schemeURL))
assertErrorResponse(t, request, http.StatusBadRequest)
}
})
t.Run("SSRF_Protection_Localhost", func(t *testing.T) {
ssrfURLs := []string{
"http://localhost",
"http://127.0.0.1",
"http://127.0.0.1:8080",
"http://[::1]",
"http://0.0.0.0",
}
for _, ssrfURL := range ssrfURLs {
ctx.Suite.TitleFetcher.SetError(services.ErrSSRFBlocked)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(ssrfURL))
assertStatusRange(t, request, http.StatusBadRequest, http.StatusBadGateway)
}
})
t.Run("SSRF_Protection_Private_IPs", func(t *testing.T) {
privateIPs := []string{
"http://192.168.1.1",
"http://10.0.0.1",
"http://172.16.0.1",
}
for _, privateIP := range privateIPs {
ctx.Suite.TitleFetcher.SetError(services.ErrSSRFBlocked)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(privateIP))
assertStatusRange(t, request, http.StatusBadRequest, http.StatusBadGateway)
}
})
t.Run("Title_Fetch_Error_Handling", func(t *testing.T) {
ctx.Suite.TitleFetcher.SetError(services.ErrTitleNotFound)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url=https://example.com/notitle")
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Valid_URL_Success", func(t *testing.T) {
ctx.Suite.TitleFetcher.SetTitle("Valid Title")
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url=https://example.com/valid")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if title, ok := data["title"].(string); ok {
if title != "Valid Title" {
t.Errorf("Expected title 'Valid Title', got '%s'", title)
}
} else {
t.Error("Expected title in response data")
}
}
})
})
t.Run("Get_User_Vote_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "vote_edge"), uniqueTestEmail(t, "vote_edge"))
secondUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "vote_edge2"), uniqueTestEmail(t, "vote_edge2"))
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Edge Test Post", "https://example.com/vote-edge")
t.Run("Get_Vote_When_User_Has_Voted", func(t *testing.T) {
voteRequest := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, voteRequest, http.StatusOK)
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if hasVote, ok := data["has_vote"].(bool); !ok || !hasVote {
t.Error("Expected has_vote to be true when user has voted")
}
if vote, ok := data["vote"]; !ok || vote == nil {
t.Error("Expected vote object when user has voted")
}
}
})
t.Run("Get_Vote_When_User_Has_Not_Voted", func(t *testing.T) {
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if hasVote, ok := data["has_vote"].(bool); ok {
if hasVote {
t.Error("Expected has_vote to be false when user has not voted")
}
} else {
t.Error("Expected has_vote field in response")
}
}
})
t.Run("Get_Vote_Invalid_Post_ID", func(t *testing.T) {
request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/posts/999999/vote", user, map[string]string{"id": "999999"})
if request.Code != http.StatusOK && request.Code != http.StatusNotFound {
t.Errorf("Expected status 200 or 404 for invalid post ID, got %d", request.Code)
}
})
t.Run("Get_Vote_Unauthenticated", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID))
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Get_Vote_Response_Structure", func(t *testing.T) {
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
response := assertJSONResponse(t, request, http.StatusOK)
if success, ok := response["success"].(bool); !ok || !success {
t.Error("Expected success field to be true")
}
if data, ok := getDataFromResponse(response); ok {
if _, exists := data["has_vote"]; !exists {
t.Error("Expected has_vote field in response data")
}
if _, exists := data["is_anonymous"]; !exists {
t.Error("Expected is_anonymous field in response data")
}
} else {
t.Error("Expected data field in response")
}
})
})
t.Run("Refresh_Token_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
refreshUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "refresh_edge"), uniqueTestEmail(t, "refresh_edge"))
t.Run("Refresh_With_Expired_Token", func(t *testing.T) {
loginResult, err := ctx.AuthService.Login(refreshUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
refreshToken, err := ctx.Suite.RefreshTokenRepo.GetByTokenHash(testutils.HashVerificationToken(loginResult.RefreshToken))
if err != nil {
t.Fatalf("Failed to get refresh token: %v", err)
}
refreshToken.ExpiresAt = time.Now().Add(-1 * time.Hour)
if err := ctx.Suite.DB.Model(refreshToken).Update("expires_at", refreshToken.ExpiresAt).Error; err != nil {
t.Fatalf("Failed to expire refresh token: %v", err)
}
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Refresh_With_Revoked_Token", func(t *testing.T) {
loginResult, err := ctx.AuthService.Login(refreshUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
if err := ctx.AuthService.RevokeRefreshToken(loginResult.RefreshToken); err != nil {
t.Fatalf("Failed to revoke refresh token: %v", err)
}
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Refresh_With_Empty_Token", func(t *testing.T) {
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": ""})
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Refresh_With_Missing_Token_Field", func(t *testing.T) {
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{})
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Refresh_Token_Rotation", func(t *testing.T) {
loginResult, err := ctx.AuthService.Login(refreshUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
originalRefreshToken := loginResult.RefreshToken
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": originalRefreshToken})
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if newAccessToken, ok := data["access_token"].(string); ok {
if newAccessToken == "" {
t.Error("Expected new access token in refresh response")
}
if newRefreshToken, ok := data["refresh_token"].(string); ok {
if newRefreshToken != "" && newRefreshToken == originalRefreshToken {
t.Log("Refresh token rotation may not be implemented (same token returned)")
}
}
}
}
})
t.Run("Refresh_After_Account_Lock", func(t *testing.T) {
lockedUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "refresh_lock"), uniqueTestEmail(t, "refresh_lock"))
loginResult, err := ctx.AuthService.Login(lockedUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
lockedUser.User.Locked = true
if err := ctx.Suite.UserRepo.Update(lockedUser.User); err != nil {
t.Fatalf("Failed to lock user: %v", err)
}
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
assertStatusRange(t, request, http.StatusUnauthorized, http.StatusForbidden)
})
t.Run("Refresh_With_Invalid_Token_Format", func(t *testing.T) {
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": "invalid-token-format-12345"})
assertErrorResponse(t, request, http.StatusUnauthorized)
})
})
t.Run("Pagination_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
paginationUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "pagination_edge"), uniqueTestEmail(t, "pagination_edge"))
for i := 0; i < 5; i++ {
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, paginationUser.User.ID, fmt.Sprintf("Pagination Post %d", i), fmt.Sprintf("https://example.com/pag%d", i))
}
t.Run("Negative_Limit", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?limit=-1")
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Negative_Offset", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?offset=-1")
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Very_Large_Limit", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?limit=10000")
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Very_Large_Offset", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?offset=10000")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) > 0 {
t.Logf("Large offset returned %d posts (may be expected)", len(posts))
}
}
}
})
t.Run("Invalid_Pagination_Parameters", func(t *testing.T) {
invalidParams := []string{
"limit=abc",
"offset=xyz",
"limit=",
"offset=",
}
for _, param := range invalidParams {
request := makeGetRequest(t, ctx.Router, "/api/posts?"+param)
assertStatus(t, request, http.StatusOK)
}
})
}) })
} }

View File

@@ -1,17 +1,12 @@
package integration package integration
import ( import (
"bytes"
"compress/gzip" "compress/gzip"
"encoding/json"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
"goyco/internal/middleware"
"goyco/internal/testutils"
) )
func TestIntegration_Compression(t *testing.T) { func TestIntegration_Compression(t *testing.T) {
@@ -19,16 +14,16 @@ func TestIntegration_Compression(t *testing.T) {
router := ctx.Router router := ctx.Router
t.Run("Response_Compression_Gzip", func(t *testing.T) { t.Run("Response_Compression_Gzip", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil) request := httptest.NewRequest("GET", "/api/posts", nil)
req.Header.Set("Accept-Encoding", "gzip") request.Header.Set("Accept-Encoding", "gzip")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
contentEncoding := rec.Header().Get("Content-Encoding") contentEncoding := recorder.Header().Get("Content-Encoding")
if contentEncoding != "" && strings.Contains(contentEncoding, "gzip") { if contentEncoding != "" && strings.Contains(contentEncoding, "gzip") {
assertHeaderContains(t, rec, "Content-Encoding", "gzip") assertHeaderContains(t, recorder, "Content-Encoding", "gzip")
reader, err := gzip.NewReader(rec.Body) reader, err := gzip.NewReader(recorder.Body)
if err != nil { if err != nil {
t.Fatalf("Failed to create gzip reader: %v", err) t.Fatalf("Failed to create gzip reader: %v", err)
} }
@@ -48,14 +43,14 @@ func TestIntegration_Compression(t *testing.T) {
}) })
t.Run("Compression_Headers_Present", func(t *testing.T) { t.Run("Compression_Headers_Present", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil) request := httptest.NewRequest("GET", "/api/posts", nil)
req.Header.Set("Accept-Encoding", "gzip, deflate") request.Header.Set("Accept-Encoding", "gzip, deflate")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Header().Get("Vary") != "" { if recorder.Header().Get("Vary") != "" {
assertHeaderContains(t, rec, "Vary", "Accept-Encoding") assertHeaderContains(t, recorder, "Vary", "Accept-Encoding")
} else { } else {
t.Log("Vary header may not always be present") t.Log("Vary header may not always be present")
} }
@@ -67,25 +62,19 @@ func TestIntegration_StaticFiles(t *testing.T) {
router := ctx.Router router := ctx.Router
t.Run("Robots_Txt_Served", func(t *testing.T) { t.Run("Robots_Txt_Served", func(t *testing.T) {
req := httptest.NewRequest("GET", "/robots.txt", nil) request := makeGetRequest(t, router, "/robots.txt")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK) if !strings.Contains(request.Body.String(), "User-agent") {
if !strings.Contains(rec.Body.String(), "User-agent") {
t.Error("Expected robots.txt content") t.Error("Expected robots.txt content")
} }
}) })
t.Run("Static_Files_Security_Headers", func(t *testing.T) { t.Run("Static_Files_Security_Headers", func(t *testing.T) {
req := httptest.NewRequest("GET", "/robots.txt", nil) request := makeGetRequest(t, router, "/robots.txt")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) if request.Header().Get("X-Content-Type-Options") == "" {
if rec.Header().Get("X-Content-Type-Options") == "" {
t.Log("Security headers may not be applied to all static files") t.Log("Security headers may not be applied to all static files")
} }
}) })
@@ -101,32 +90,22 @@ func TestIntegration_URLMetadata(t *testing.T) {
ctx.Suite.TitleFetcher.SetTitle("Fetched Title") ctx.Suite.TitleFetcher.SetTitle("Fetched Title")
postBody := map[string]string{ postBody := map[string]any{
"title": "Test Post", "title": "Test Post",
"url": "https://example.com/metadata-test", "url": "https://example.com/metadata-test",
"content": "Test content", "content": "Test content",
} }
body, _ := json.Marshal(postBody) request := makePostRequest(t, router, "/api/posts", postBody, user, nil)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusCreated)
assertStatus(t, rec, http.StatusCreated)
}) })
t.Run("URL_Metadata_Endpoint", func(t *testing.T) { t.Run("URL_Metadata_Endpoint", func(t *testing.T) {
ctx.Suite.TitleFetcher.SetTitle("Endpoint Title") ctx.Suite.TitleFetcher.SetTitle("Endpoint Title")
req := httptest.NewRequest("GET", "/api/posts/title?url=https://example.com/test", nil) request := makeGetRequest(t, router, "/api/posts/title?url=https://example.com/test")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil { if response != nil {
if data, ok := response["data"].(map[string]any); ok { if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["title"]; !exists { if _, exists := data["title"]; !exists {

View File

@@ -1,14 +1,10 @@
package integration package integration
import ( import (
"bytes"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"goyco/internal/middleware"
"goyco/internal/testutils" "goyco/internal/testutils"
) )
@@ -22,33 +18,19 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Owner Post", "https://example.com/owner") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Owner Post", "https://example.com/owner")
updateBody := map[string]string{ request := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title", "title": "Updated Title",
"content": "Updated content", "content": "Updated content",
} }, otherUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
body, _ := json.Marshal(updateBody)
req := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body)) assertErrorResponse(t, request, http.StatusForbidden)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) request = makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title",
"content": "Updated content",
}, owner, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertErrorResponse(t, rec, http.StatusForbidden) assertStatus(t, request, http.StatusOK)
req = httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+owner.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, owner.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("Post_Delete_Authorization", func(t *testing.T) { t.Run("Post_Delete_Authorization", func(t *testing.T) {
@@ -58,47 +40,27 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Delete Post", "https://example.com/delete") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Delete Post", "https://example.com/delete")
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil) request := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), otherUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusForbidden)
assertErrorResponse(t, rec, http.StatusForbidden) request = makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), owner, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req = httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil) assertStatus(t, request, http.StatusOK)
req.Header.Set("Authorization", "Bearer "+owner.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, owner.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("User_Profile_Access_Authorization", func(t *testing.T) { t.Run("User_Profile_Access_Authorization", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user1 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user1", "profile_user1@example.com") firstUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user1", "profile_user1@example.com")
user2 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user2", "profile_user2@example.com") secondUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user2", "profile_user2@example.com")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d", user1.User.ID), nil) request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d", firstUser.User.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", firstUser.User.ID)})
req.Header.Set("Authorization", "Bearer "+user2.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user2.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user1.User.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if userData, ok := data["user"].(map[string]any); ok { if userData, ok := data["user"].(map[string]any); ok {
if id, ok := userData["id"].(float64); ok && uint(id) != user1.User.ID { if id, ok := userData["id"].(float64); ok && uint(id) != firstUser.User.ID {
t.Errorf("Expected user ID %d, got %.0f", user1.User.ID, id) t.Errorf("Expected user ID %d, got %.0f", firstUser.User.ID, id)
}
} }
} }
} }
@@ -109,24 +71,13 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "settings_auth_user", "settings_auth@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "settings_auth_user", "settings_auth@example.com")
otherUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "settings_auth_other", "settings_auth_other@example.com") otherUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "settings_auth_other", "settings_auth_other@example.com")
updateBody := map[string]string{ request := makePutRequest(t, ctx.Router, "/api/auth/email", map[string]any{"email": "newemail@example.com"}, otherUser, nil)
"email": "newemail@example.com",
}
body, _ := json.Marshal(updateBody)
req := httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body)) response := assertJSONResponse(t, request, http.StatusOK)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response == nil { if response == nil {
return return
} }
if data, ok := response["data"].(map[string]any); ok { if data, ok := getDataFromResponse(response); ok {
if userData, ok := data["user"].(map[string]any); ok { if userData, ok := data["user"].(map[string]any); ok {
if email, ok := userData["email"].(string); ok && email == "newemail@example.com" { if email, ok := userData["email"].(string); ok && email == "newemail@example.com" {
if id, ok := userData["id"].(float64); ok && uint(id) != otherUser.User.ID { if id, ok := userData["id"].(float64); ok && uint(id) != otherUser.User.ID {
@@ -136,20 +87,9 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
} }
} }
updateBody2 := map[string]string{ request = makePutRequest(t, ctx.Router, "/api/auth/email", map[string]any{"email": "anothernewemail@example.com"}, user, nil)
"email": "anothernewemail@example.com",
}
body2, _ := json.Marshal(updateBody2)
req = httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body2)) assertStatus(t, request, http.StatusOK)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("Vote_Authorization", func(t *testing.T) { t.Run("Vote_Authorization", func(t *testing.T) {
@@ -159,73 +99,42 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, postOwner.User.ID, "Vote Auth Post", "https://example.com/vote-auth") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, postOwner.User.ID, "Vote Auth Post", "https://example.com/vote-auth")
voteBody := map[string]string{"type": "up"} request := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
body, _ := json.Marshal(voteBody)
req := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body)) assertStatus(t, request, http.StatusOK)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) request = makePostRequestWithJSON(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"})
assertStatus(t, rec, http.StatusOK) assertErrorResponse(t, request, http.StatusUnauthorized)
req = httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
}) })
t.Run("Protected_Endpoint_Without_Auth", func(t *testing.T) { t.Run("Protected_Endpoint_Without_Auth", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}"))) request := makePostRequestWithJSON(t, ctx.Router, "/api/posts", map[string]any{})
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusUnauthorized)
assertErrorResponse(t, rec, http.StatusUnauthorized)
}) })
t.Run("Protected_Endpoint_With_Invalid_Token", func(t *testing.T) { t.Run("Protected_Endpoint_With_Invalid_Token", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}"))) request := makeRequest(t, ctx.Router, "POST", "/api/posts", []byte("{}"), map[string]string{"Content-Type": "application/json", "Authorization": "Bearer invalid-token"})
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer invalid-token")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusUnauthorized)
assertErrorResponse(t, rec, http.StatusUnauthorized)
}) })
t.Run("User_List_Authorization", func(t *testing.T) { t.Run("User_List_Authorization", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "list_auth_user", "list_auth@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "list_auth_user", "list_auth@example.com")
req := httptest.NewRequest("GET", "/api/users", nil) request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/users", user, nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK) request = makeGetRequest(t, ctx.Router, "/api/users")
req = httptest.NewRequest("GET", "/api/users", nil) assertErrorResponse(t, request, http.StatusUnauthorized)
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
}) })
t.Run("Refresh_Token_Authorization", func(t *testing.T) { t.Run("Refresh_Token_Authorization", func(t *testing.T) {
@@ -237,18 +146,9 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
t.Fatalf("Failed to login: %v", err) t.Fatalf("Failed to login: %v", err)
} }
refreshBody := map[string]string{ request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
"refresh_token": loginResult.RefreshToken,
}
body, _ := json.Marshal(refreshBody)
req := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body)) response := assertJSONResponse(t, request, http.StatusOK)
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response == nil { if response == nil {
return return
} }
@@ -260,17 +160,8 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
t.Error("Expected data field in refresh response") t.Error("Expected data field in refresh response")
} }
refreshBody = map[string]string{ request = makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": "invalid-refresh-token"})
"refresh_token": "invalid-refresh-token",
}
body, _ = json.Marshal(refreshBody)
req = httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body)) assertErrorResponse(t, request, http.StatusUnauthorized)
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
}) })
} }

View File

@@ -14,166 +14,137 @@ func TestIntegration_CSRF_Protection(t *testing.T) {
ctx := setupPageHandlerTestContext(t) ctx := setupPageHandlerTestContext(t)
router := ctx.Router router := ctx.Router
t.Run("CSRF_Blocks_Form_Without_Token", func(t *testing.T) { getCSRFToken := func(t *testing.T, path string, cookies ...*http.Cookie) *http.Cookie {
reqBody := url.Values{} t.Helper()
reqBody.Set("username", "testuser")
reqBody.Set("email", "test@example.com")
reqBody.Set("password", "SecurePass123!")
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("GET", path, nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") for _, c := range cookies {
rec := httptest.NewRecorder() request.AddCookie(c)
router.ServeHTTP(rec, req)
if rec.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", rec.Code, rec.Body.String())
} }
if !strings.Contains(rec.Body.String(), "Invalid CSRF token") { recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
for _, cookie := range recorder.Result().Cookies() {
if cookie.Name == "csrf_token" {
return cookie
}
}
t.Fatalf("Expected CSRF cookie to be set for %s", path)
return nil
}
t.Run("CSRF_Blocks_Form_Without_Token", func(t *testing.T) {
requestBody := url.Values{}
requestBody.Set("username", "testuser")
requestBody.Set("email", "test@example.com")
requestBody.Set("password", "SecurePass123!")
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", recorder.Code, recorder.Body.String())
}
if !strings.Contains(recorder.Body.String(), "Invalid CSRF token") {
t.Error("Expected CSRF error message") t.Error("Expected CSRF error message")
} }
}) })
t.Run("CSRF_Allows_Form_With_Valid_Token", func(t *testing.T) { t.Run("CSRF_Allows_Form_With_Valid_Token", func(t *testing.T) {
getReq := httptest.NewRequest("GET", "/register", nil) csrfCookie := getCSRFToken(t, "/register")
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
cookies := getRec.Result().Cookies() requestBody := url.Values{}
var csrfCookie *http.Cookie requestBody.Set("username", "csrf_user")
for _, cookie := range cookies { requestBody.Set("email", "csrf@example.com")
if cookie.Name == "csrf_token" { requestBody.Set("password", "SecurePass123!")
csrfCookie = cookie requestBody.Set("csrf_token", csrfCookie.Value)
break
}
}
if csrfCookie == nil { request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
t.Fatal("Expected CSRF cookie to be set") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
} request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
csrfToken := csrfCookie.Value router.ServeHTTP(recorder, request)
reqBody := url.Values{} if recorder.Code == http.StatusForbidden {
reqBody.Set("username", "csrf_user")
reqBody.Set("email", "csrf@example.com")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code == http.StatusForbidden {
t.Error("Expected form submission with valid CSRF token to succeed") t.Error("Expected form submission with valid CSRF token to succeed")
} }
}) })
t.Run("CSRF_Allows_API_Requests", func(t *testing.T) { t.Run("CSRF_Allows_API_Requests", func(t *testing.T) {
reqBody := map[string]string{ requestBody := map[string]string{
"username": "api_user", "username": "api_user",
"email": "api@example.com", "email": "api@example.com",
"password": "SecurePass123!", "password": "SecurePass123!",
} }
body, _ := json.Marshal(reqBody) body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body)) request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Code == http.StatusForbidden { if recorder.Code == http.StatusForbidden {
t.Error("Expected API requests to bypass CSRF protection") t.Error("Expected API requests to bypass CSRF protection")
} }
}) })
t.Run("CSRF_Blocks_Mismatched_Token", func(t *testing.T) { t.Run("CSRF_Blocks_Mismatched_Token", func(t *testing.T) {
getReq := httptest.NewRequest("GET", "/register", nil) csrfCookie := getCSRFToken(t, "/register")
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
cookies := getRec.Result().Cookies() requestBody := url.Values{}
var csrfCookie *http.Cookie requestBody.Set("username", "mismatch_user")
for _, cookie := range cookies { requestBody.Set("email", "mismatch@example.com")
if cookie.Name == "csrf_token" { requestBody.Set("password", "SecurePass123!")
csrfCookie = cookie requestBody.Set("csrf_token", "wrong-token")
break
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", recorder.Code, recorder.Body.String())
} }
} if !strings.Contains(recorder.Body.String(), "Invalid CSRF token") {
if csrfCookie == nil {
t.Fatal("Expected CSRF cookie to be set")
}
reqBody := url.Values{}
reqBody.Set("username", "mismatch_user")
reqBody.Set("email", "mismatch@example.com")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", "wrong-token")
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", rec.Code, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "Invalid CSRF token") {
t.Error("Expected CSRF error message") t.Error("Expected CSRF error message")
} }
}) })
t.Run("CSRF_Allows_GET_Requests", func(t *testing.T) { t.Run("CSRF_Allows_GET_Requests", func(t *testing.T) {
req := httptest.NewRequest("GET", "/register", nil) request := httptest.NewRequest("GET", "/register", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Code == http.StatusForbidden { if recorder.Code == http.StatusForbidden {
t.Error("Expected GET requests to bypass CSRF protection") t.Error("Expected GET requests to bypass CSRF protection")
} }
}) })
t.Run("CSRF_Token_In_Header", func(t *testing.T) { t.Run("CSRF_Token_In_Header", func(t *testing.T) {
getReq := httptest.NewRequest("GET", "/register", nil) csrfCookie := getCSRFToken(t, "/register")
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
cookies := getRec.Result().Cookies() requestBody := url.Values{}
var csrfCookie *http.Cookie requestBody.Set("username", "header_user")
for _, cookie := range cookies { requestBody.Set("email", "header@example.com")
if cookie.Name == "csrf_token" { requestBody.Set("password", "SecurePass123!")
csrfCookie = cookie
break
}
}
if csrfCookie == nil { request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
t.Fatal("Expected CSRF cookie to be set") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
} request.Header.Set("X-CSRF-Token", csrfCookie.Value)
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
csrfToken := csrfCookie.Value router.ServeHTTP(recorder, request)
reqBody := url.Values{} if recorder.Code == http.StatusForbidden {
reqBody.Set("username", "header_user")
reqBody.Set("email", "header@example.com")
reqBody.Set("password", "SecurePass123!")
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("X-CSRF-Token", csrfToken)
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code == http.StatusForbidden {
t.Error("Expected CSRF token in header to be accepted") t.Error("Expected CSRF token in header to be accepted")
} }
}) })
@@ -182,41 +153,24 @@ func TestIntegration_CSRF_Protection(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "csrf_form_user", "csrf_form@example.com") user := createUserWithCleanup(t, ctx, "csrf_form_user", "csrf_form@example.com")
getReq := httptest.NewRequest("GET", "/posts/new", nil) authCookie := &http.Cookie{Name: "auth_token", Value: user.Token}
getReq.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) csrfCookie := getCSRFToken(t, "/posts/new", authCookie)
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
cookies := getRec.Result().Cookies() requestBody := url.Values{}
var csrfCookie *http.Cookie requestBody.Set("title", "CSRF Test Post")
for _, cookie := range cookies { requestBody.Set("url", "https://example.com/csrf-test")
if cookie.Name == "csrf_token" { requestBody.Set("content", "Test content")
csrfCookie = cookie requestBody.Set("csrf_token", csrfCookie.Value)
break
}
}
if csrfCookie == nil { request := httptest.NewRequest("POST", "/posts", strings.NewReader(requestBody.Encode()))
t.Fatal("Expected CSRF cookie to be set") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
} request.AddCookie(authCookie)
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
csrfToken := csrfCookie.Value router.ServeHTTP(recorder, request)
reqBody := url.Values{} if recorder.Code == http.StatusForbidden {
reqBody.Set("title", "CSRF Test Post")
reqBody.Set("url", "https://example.com/csrf-test")
reqBody.Set("content", "Test content")
reqBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/posts", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code == http.StatusForbidden {
t.Error("Expected post creation with valid CSRF token to succeed") t.Error("Expected post creation with valid CSRF token to succeed")
} }
}) })

View File

@@ -19,27 +19,18 @@ func TestIntegration_DataConsistency(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "consistency_user", "consistency@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "consistency_user", "consistency@example.com")
postBody := map[string]string{ request := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "Consistency Test Post", "title": "Consistency Test Post",
"url": "https://example.com/consistency", "url": "https://example.com/consistency",
"content": "Test content", "content": "Test content",
} }, user, nil)
body, _ := json.Marshal(postBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) createResponse := assertJSONResponse(t, request, http.StatusCreated)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
createResponse := assertJSONResponse(t, rec, http.StatusCreated)
if createResponse == nil { if createResponse == nil {
return return
} }
postData, ok := createResponse["data"].(map[string]any) postData, ok := getDataFromResponse(createResponse)
if !ok { if !ok {
t.Fatal("Response missing data") t.Fatal("Response missing data")
} }
@@ -53,16 +44,14 @@ func TestIntegration_DataConsistency(t *testing.T) {
createdURL := postData["url"] createdURL := postData["url"]
createdContent := postData["content"] createdContent := postData["content"]
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%.0f", postID), nil) getRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%.0f", postID))
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
getResponse := assertJSONResponse(t, getRec, http.StatusOK) getResponse := assertJSONResponse(t, getRequest, http.StatusOK)
if getResponse == nil { if getResponse == nil {
return return
} }
getPostData, ok := getResponse["data"].(map[string]any) getPostData, ok := getDataFromResponse(getResponse)
if !ok { if !ok {
t.Fatal("Get response missing data") t.Fatal("Get response missing data")
} }
@@ -96,32 +85,17 @@ func TestIntegration_DataConsistency(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Consistency Post", "https://example.com/vote-consistency") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Consistency Post", "https://example.com/vote-consistency")
voteBody := map[string]string{"type": "up"} voteRequest := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
body, _ := json.Marshal(voteBody) assertStatus(t, voteRequest, http.StatusOK)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body)) getVotesRequest := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
assertStatus(t, voteRec, http.StatusOK) votesResponse := assertJSONResponse(t, getVotesRequest, http.StatusOK)
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK)
if votesResponse == nil { if votesResponse == nil {
return return
} }
votesData, ok := votesResponse["data"].(map[string]any) votesData, ok := getDataFromResponse(votesResponse)
if !ok { if !ok {
t.Fatal("Votes response missing data") t.Fatal("Votes response missing data")
} }
@@ -172,32 +146,21 @@ func TestIntegration_DataConsistency(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Original Title", "https://example.com/original") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Original Title", "https://example.com/original")
updateBody := map[string]string{ updateRequest := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title", "title": "Updated Title",
"content": "Updated content", "content": "Updated content",
} }, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
body, _ := json.Marshal(updateBody)
updateReq := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body)) assertStatus(t, updateRequest, http.StatusOK)
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+user.Token)
updateReq = testutils.WithUserContext(updateReq, middleware.UserIDKey, user.User.ID)
updateReq = testutils.WithURLParams(updateReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
updateRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(updateRec, updateReq)
assertStatus(t, updateRec, http.StatusOK) getRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID))
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil) getResponse := assertJSONResponse(t, getRequest, http.StatusOK)
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
getResponse := assertJSONResponse(t, getRec, http.StatusOK)
if getResponse == nil { if getResponse == nil {
return return
} }
getPostData, ok := getResponse["data"].(map[string]any) getPostData, ok := getDataFromResponse(getResponse)
if !ok { if !ok {
t.Fatal("Get response missing data") t.Fatal("Get response missing data")
} }
@@ -215,18 +178,12 @@ func TestIntegration_DataConsistency(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "user_posts_consistency", "user_posts_consistency@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "user_posts_consistency", "user_posts_consistency@example.com")
post1 := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 1", "https://example.com/post1") firstPost := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 1", "https://example.com/post1")
post2 := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 2", "https://example.com/post2") secondPost := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 2", "https://example.com/post2")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d/posts", user.User.ID), nil) request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d/posts", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
response := assertJSONResponse(t, rec, http.StatusOK)
if response == nil { if response == nil {
return return
} }
@@ -245,26 +202,26 @@ func TestIntegration_DataConsistency(t *testing.T) {
t.Errorf("Expected at least 2 posts, got %d", len(posts)) t.Errorf("Expected at least 2 posts, got %d", len(posts))
} }
foundPost1 := false foundFirstPost := false
foundPost2 := false foundSecondPost := false
for _, post := range posts { for _, post := range posts {
if postMap, ok := post.(map[string]any); ok { if postMap, ok := post.(map[string]any); ok {
if postID, ok := postMap["id"].(float64); ok { if postID, ok := postMap["id"].(float64); ok {
if uint(postID) == post1.ID { if uint(postID) == firstPost.ID {
foundPost1 = true foundFirstPost = true
} }
if uint(postID) == post2.ID { if uint(postID) == secondPost.ID {
foundPost2 = true foundSecondPost = true
} }
} }
} }
} }
if !foundPost1 { if !foundFirstPost {
t.Error("Post 1 not found in user posts") t.Error("Post 1 not found in user posts")
} }
if !foundPost2 { if !foundSecondPost {
t.Error("Post 2 not found in user posts") t.Error("Post 2 not found in user posts")
} }
}) })
@@ -275,20 +232,20 @@ func TestIntegration_DataConsistency(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Delete Consistency Post", "https://example.com/delete-consistency") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Delete Consistency Post", "https://example.com/delete-consistency")
deleteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil) deleteRequest := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
deleteReq.Header.Set("Authorization", "Bearer "+user.Token) deleteRequest.Header.Set("Authorization", "Bearer "+user.Token)
deleteReq = testutils.WithUserContext(deleteReq, middleware.UserIDKey, user.User.ID) deleteRequest = testutils.WithUserContext(deleteRequest, middleware.UserIDKey, user.User.ID)
deleteReq = testutils.WithURLParams(deleteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) deleteRequest = testutils.WithURLParams(deleteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
deleteRec := httptest.NewRecorder() deleteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(deleteRec, deleteReq) ctx.Router.ServeHTTP(deleteRecorder, deleteRequest)
assertStatus(t, deleteRec, http.StatusOK) assertStatus(t, deleteRecorder, http.StatusOK)
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil) getRequest := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder() getRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq) ctx.Router.ServeHTTP(getRecorder, getRequest)
assertStatus(t, getRec, http.StatusNotFound) assertStatus(t, getRecorder, http.StatusNotFound)
}) })
t.Run("Vote_Removal_Consistency", func(t *testing.T) { t.Run("Vote_Removal_Consistency", func(t *testing.T) {
@@ -300,33 +257,33 @@ func TestIntegration_DataConsistency(t *testing.T) {
voteBody := map[string]string{"type": "up"} voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody) body, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body)) voteRequest := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteReq.Header.Set("Content-Type", "application/json") voteRequest.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token) voteRequest.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID) voteRequest = testutils.WithUserContext(voteRequest, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) voteRequest = testutils.WithURLParams(voteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder() voteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq) ctx.Router.ServeHTTP(voteRecorder, voteRequest)
assertStatus(t, voteRec, http.StatusOK) assertStatus(t, voteRecorder, http.StatusOK)
removeVoteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil) removeVoteRequest := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
removeVoteReq.Header.Set("Authorization", "Bearer "+user.Token) removeVoteRequest.Header.Set("Authorization", "Bearer "+user.Token)
removeVoteReq = testutils.WithUserContext(removeVoteReq, middleware.UserIDKey, user.User.ID) removeVoteRequest = testutils.WithUserContext(removeVoteRequest, middleware.UserIDKey, user.User.ID)
removeVoteReq = testutils.WithURLParams(removeVoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) removeVoteRequest = testutils.WithURLParams(removeVoteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
removeVoteRec := httptest.NewRecorder() removeVoteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(removeVoteRec, removeVoteReq) ctx.Router.ServeHTTP(removeVoteRecorder, removeVoteRequest)
assertStatus(t, removeVoteRec, http.StatusOK) assertStatus(t, removeVoteRecorder, http.StatusOK)
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil) getVotesRequest := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token) getVotesRequest.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID) getVotesRequest = testutils.WithUserContext(getVotesRequest, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) getVotesRequest = testutils.WithURLParams(getVotesRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder() getVotesRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq) ctx.Router.ServeHTTP(getVotesRecorder, getVotesRequest)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK) votesResponse := assertJSONResponse(t, getVotesRecorder, http.StatusOK)
if votesResponse == nil { if votesResponse == nil {
return return
} }

View File

@@ -22,41 +22,39 @@ func TestIntegration_EdgeCases(t *testing.T) {
expiredToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2MDAwMDAwMDB9.expired" expiredToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2MDAwMDAwMDB9.expired"
req := httptest.NewRequest("GET", "/api/auth/me", nil) request := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+expiredToken) request.Header.Set("Authorization", "Bearer "+expiredToken)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusUnauthorized) assertErrorResponse(t, recorder, http.StatusUnauthorized)
}) })
t.Run("Concurrent_Vote_Operations", func(t *testing.T) { t.Run("Concurrent_Vote_Operations", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user1 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "vote_user1", "vote1@example.com") firstUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "vote_user1", "vote1@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user1.User.ID, "Concurrent Vote Post", "https://example.com/concurrent") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, firstUser.User.ID, "Concurrent Vote Post", "https://example.com/concurrent")
var wg sync.WaitGroup var wg sync.WaitGroup
errors := make(chan error, 10) errors := make(chan error, 10)
for i := 0; i < 5; i++ { for range 5 {
wg.Add(1) wg.Go(func() {
go func() {
defer wg.Done()
voteBody := map[string]string{"type": "up"} voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody) body, _ := json.Marshal(voteBody)
req := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body)) request := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user1.Token) request.Header.Set("Authorization", "Bearer "+firstUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user1.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, firstUser.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) ctx.Router.ServeHTTP(recorder, request)
if rec.Code != http.StatusOK { if recorder.Code != http.StatusOK {
errors <- fmt.Errorf("unexpected status: %d", rec.Code) errors <- fmt.Errorf("unexpected status: %d", recorder.Code)
} }
}() })
} }
wg.Wait() wg.Wait()
@@ -72,8 +70,8 @@ func TestIntegration_EdgeCases(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "large_user", "large@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "large_user", "large@example.com")
largeContent := make([]byte, 10001) largeContent := make([]byte, 10001)
for i := range largeContent { for idx := range largeContent {
largeContent[i] = 'a' largeContent[idx] = 'a'
} }
postBody := map[string]string{ postBody := map[string]string{
@@ -82,36 +80,36 @@ func TestIntegration_EdgeCases(t *testing.T) {
"content": string(largeContent), "content": string(largeContent),
} }
body, _ := json.Marshal(postBody) body, _ := json.Marshal(postBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusBadRequest) assertErrorResponse(t, recorder, http.StatusBadRequest)
smallContent := make([]byte, 1000) smallContent := make([]byte, 1000)
for i := range smallContent { for idx := range smallContent {
smallContent[i] = 'a' smallContent[idx] = 'a'
} }
postBody2 := map[string]string{ secondPostBody := map[string]string{
"title": "Small Post", "title": "Small Post",
"url": "https://example.com/small", "url": "https://example.com/small",
"content": string(smallContent), "content": string(smallContent),
} }
body2, _ := json.Marshal(postBody2) secondBody, _ := json.Marshal(secondPostBody)
req2 := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body2)) secondRequest := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(secondBody))
req2.Header.Set("Content-Type", "application/json") secondRequest.Header.Set("Content-Type", "application/json")
req2.Header.Set("Authorization", "Bearer "+user.Token) secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID) secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder() secondRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec2, req2) ctx.Router.ServeHTTP(secondRecorder, secondRequest)
assertStatus(t, rec2, http.StatusCreated) assertStatus(t, secondRecorder, http.StatusCreated)
}) })
t.Run("Malformed_JSON_Payloads", func(t *testing.T) { t.Run("Malformed_JSON_Payloads", func(t *testing.T) {
@@ -127,15 +125,15 @@ func TestIntegration_EdgeCases(t *testing.T) {
} }
for _, payload := range malformedPayloads { for _, payload := range malformedPayloads {
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(payload)) request := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(payload))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusBadRequest) assertErrorResponse(t, recorder, http.StatusBadRequest)
} }
}) })
@@ -148,38 +146,36 @@ func TestIntegration_EdgeCases(t *testing.T) {
voteBody := map[string]string{"type": "up"} voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody) body, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body)) voteRequest := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteReq.Header.Set("Content-Type", "application/json") voteRequest.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token) voteRequest.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID) voteRequest = testutils.WithUserContext(voteRequest, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) voteRequest = testutils.WithURLParams(voteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder() voteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq) ctx.Router.ServeHTTP(voteRecorder, voteRequest)
assertStatus(t, voteRec, http.StatusOK) assertStatus(t, voteRecorder, http.StatusOK)
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < 3; i++ { for range 3 {
wg.Add(1) wg.Go(func() {
go func() { request := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
defer wg.Done() request.Header.Set("Authorization", "Bearer "+user.Token)
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
req.Header.Set("Authorization", "Bearer "+user.Token) request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) recorder := httptest.NewRecorder()
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) ctx.Router.ServeHTTP(recorder, request)
rec := httptest.NewRecorder() })
ctx.Router.ServeHTTP(rec, req)
}()
} }
wg.Wait() wg.Wait()
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil) getVotesRequest := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token) getVotesRequest.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID) getVotesRequest = testutils.WithUserContext(getVotesRequest, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) getVotesRequest = testutils.WithURLParams(getVotesRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder() getVotesRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq) ctx.Router.ServeHTTP(getVotesRecorder, getVotesRequest)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK) votesResponse := assertJSONResponse(t, getVotesRecorder, http.StatusOK)
if votesResponse != nil { if votesResponse != nil {
if data, ok := votesResponse["data"].(map[string]any); ok { if data, ok := votesResponse["data"].(map[string]any); ok {
if votes, ok := data["votes"].([]any); ok { if votes, ok := data["votes"].([]any); ok {

View File

@@ -1,14 +1,10 @@
package integration package integration
import ( import (
"bytes"
"encoding/json"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"goyco/internal/database" "goyco/internal/database"
"goyco/internal/middleware"
"goyco/internal/testutils" "goyco/internal/testutils"
) )
@@ -19,19 +15,14 @@ func TestIntegration_EmailService(t *testing.T) {
t.Run("Registration_Email_Sent", func(t *testing.T) { t.Run("Registration_Email_Sent", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{ reqBody := map[string]any{
"username": "email_reg_user", "username": "email_reg_user",
"email": "email_reg@example.com", "email": "email_reg@example.com",
"password": "SecurePass123!", "password": "SecurePass123!",
} }
body, _ := json.Marshal(reqBody) request := makePostRequestWithJSON(t, router, "/api/auth/register", reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusCreated)
assertStatus(t, rec, http.StatusCreated)
token := ctx.Suite.EmailSender.VerificationToken() token := ctx.Suite.EmailSender.VerificationToken()
if token == "" { if token == "" {
@@ -52,15 +43,10 @@ func TestIntegration_EmailService(t *testing.T) {
t.Fatalf("Failed to create user: %v", err) t.Fatalf("Failed to create user: %v", err)
} }
reqBody := map[string]string{ reqBody := map[string]any{
"username_or_email": "email_reset_user", "username_or_email": "email_reset_user",
} }
body, _ := json.Marshal(reqBody) makePostRequestWithJSON(t, router, "/api/auth/forgot-password", reqBody)
req := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
token := ctx.Suite.EmailSender.PasswordResetToken() token := ctx.Suite.EmailSender.PasswordResetToken()
if token == "" { if token == "" {
@@ -72,17 +58,9 @@ func TestIntegration_EmailService(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "email_del_user", "email_del@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "email_del_user", "email_del@example.com")
reqBody := map[string]string{} request := makeDeleteRequest(t, router, "/api/auth/account", user, nil)
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
token := ctx.Suite.EmailSender.DeletionToken() token := ctx.Suite.EmailSender.DeletionToken()
if token == "" { if token == "" {
@@ -94,17 +72,10 @@ func TestIntegration_EmailService(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "email_change_user", "email_change@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "email_change_user", "email_change@example.com")
reqBody := map[string]string{ reqBody := map[string]any{
"email": "newemail@example.com", "email": "newemail@example.com",
} }
body, _ := json.Marshal(reqBody) makePutRequest(t, router, "/api/auth/email", reqBody, user, nil)
req := httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
token := ctx.Suite.EmailSender.VerificationToken() token := ctx.Suite.EmailSender.VerificationToken()
if token == "" { if token == "" {
@@ -115,17 +86,12 @@ func TestIntegration_EmailService(t *testing.T) {
t.Run("Email_Template_Content", func(t *testing.T) { t.Run("Email_Template_Content", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{ reqBody := map[string]any{
"username": "template_user", "username": "template_user",
"email": "template@example.com", "email": "template@example.com",
"password": "SecurePass123!", "password": "SecurePass123!",
} }
body, _ := json.Marshal(reqBody) makePostRequestWithJSON(t, router, "/api/auth/register", reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
token := ctx.Suite.EmailSender.VerificationToken() token := ctx.Suite.EmailSender.VerificationToken()
if token == "" { if token == "" {

View File

@@ -1,7 +1,6 @@
package integration package integration
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -10,7 +9,7 @@ import (
"strings" "strings"
"testing" "testing"
"goyco/internal/middleware" "goyco/internal/database"
"goyco/internal/testutils" "goyco/internal/testutils"
) )
@@ -20,46 +19,34 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Run("Complete_Registration_To_Post_Creation_Journey", func(t *testing.T) { t.Run("Complete_Registration_To_Post_Creation_Journey", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
registerBody := map[string]string{ registerRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "journey_user", "username": "journey_user",
"email": "journey@example.com", "email": "journey@example.com",
"password": "SecurePass123!", "password": "SecurePass123!",
} })
body, _ := json.Marshal(registerBody)
registerReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
registerReq.Header.Set("Content-Type", "application/json")
registerRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(registerRec, registerReq)
assertStatus(t, registerRec, http.StatusCreated) assertStatus(t, registerRequest, http.StatusCreated)
verificationToken := ctx.Suite.EmailSender.VerificationToken() verificationToken := ctx.Suite.EmailSender.VerificationToken()
if verificationToken == "" { if verificationToken == "" {
t.Fatal("Verification token not sent") t.Fatal("Verification token not sent")
} }
confirmReq := httptest.NewRequest("GET", "/api/auth/confirm?token="+url.QueryEscape(verificationToken), nil) confirmRequest := makeGetRequest(t, ctx.Router, "/api/auth/confirm?token="+url.QueryEscape(verificationToken))
confirmRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(confirmRec, confirmReq)
assertStatus(t, confirmRec, http.StatusOK) assertStatus(t, confirmRequest, http.StatusOK)
loginBody := map[string]string{ loginRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/login", map[string]any{
"username": "journey_user", "username": "journey_user",
"password": "SecurePass123!", "password": "SecurePass123!",
} })
loginBodyBytes, _ := json.Marshal(loginBody)
loginReq := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBuffer(loginBodyBytes))
loginReq.Header.Set("Content-Type", "application/json")
loginRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(loginRec, loginReq)
loginResponse := assertJSONResponse(t, loginRec, http.StatusOK) loginResponse := assertJSONResponse(t, loginRequest, http.StatusOK)
if loginResponse == nil { if loginResponse == nil {
return return
} }
data, ok := loginResponse["data"].(map[string]any) data, ok := getDataFromResponse(loginResponse)
if !ok { if !ok {
t.Fatal("Login response missing data") t.Fatal("Login response missing data")
} }
@@ -67,8 +54,8 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
var token string var token string
if accessToken, ok := data["access_token"].(string); ok && accessToken != "" { if accessToken, ok := data["access_token"].(string); ok && accessToken != "" {
token = accessToken token = accessToken
} else if tokenVal, ok := data["token"].(string); ok && tokenVal != "" { } else if tokenValue, ok := data["token"].(string); ok && tokenValue != "" {
token = tokenVal token = tokenValue
} else { } else {
t.Fatal("Login response missing access_token or token") t.Fatal("Login response missing access_token or token")
} }
@@ -90,25 +77,19 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatalf("Login response missing user.id. Data: %+v", data) t.Fatalf("Login response missing user.id. Data: %+v", data)
} }
postBody := map[string]string{ postBodyBytes, _ := json.Marshal(map[string]any{
"title": "Journey Test Post", "title": "Journey Test Post",
"url": "https://example.com/journey", "url": "https://example.com/journey",
"content": "Test content", "content": "Test content",
} })
postBodyBytes, _ := json.Marshal(postBody) postRequest := makeAuthenticatedRequest(t, ctx.Router, "POST", "/api/posts", postBodyBytes, &authenticatedUser{User: &database.User{ID: uint(userID)}, Token: token}, nil)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBodyBytes))
postReq.Header.Set("Content-Type", "application/json")
postReq.Header.Set("Authorization", "Bearer "+token)
postReq = testutils.WithUserContext(postReq, middleware.UserIDKey, uint(userID))
postRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(postRec, postReq)
postResponse := assertJSONResponse(t, postRec, http.StatusCreated) postResponse := assertJSONResponse(t, postRequest, http.StatusCreated)
if postResponse == nil { if postResponse == nil {
return return
} }
postData, ok := postResponse["data"].(map[string]any) postData, ok := getDataFromResponse(postResponse)
if !ok { if !ok {
t.Fatal("Post response missing data") t.Fatal("Post response missing data")
} }
@@ -118,56 +99,39 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatal("Post response missing id") t.Fatal("Post response missing id")
} }
getPostReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%.0f", postID), nil) getPostRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%.0f", postID))
getPostRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getPostRec, getPostReq)
assertStatus(t, getPostRec, http.StatusOK) assertStatus(t, getPostRequest, http.StatusOK)
}) })
t.Run("Complete_Password_Reset_Journey", func(t *testing.T) { t.Run("Complete_Password_Reset_Journey", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "reset_journey_user", "reset_journey@example.com") createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "reset_journey_user", "reset_journey@example.com")
resetBody := map[string]string{ resetRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/forgot-password", map[string]any{
"username_or_email": "reset_journey@example.com", "username_or_email": "reset_journey@example.com",
} })
resetBodyBytes, _ := json.Marshal(resetBody)
resetReq := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(resetBodyBytes))
resetReq.Header.Set("Content-Type", "application/json")
resetRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(resetRec, resetReq)
assertStatus(t, resetRec, http.StatusOK) assertStatus(t, resetRequest, http.StatusOK)
resetToken := ctx.Suite.EmailSender.GetLastPasswordResetToken() resetToken := ctx.Suite.EmailSender.GetLastPasswordResetToken()
if resetToken == "" { if resetToken == "" {
t.Fatal("Password reset token not sent") t.Fatal("Password reset token not sent")
} }
newPasswordBody := map[string]string{ newPasswordRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/reset-password", map[string]any{
"token": resetToken, "token": resetToken,
"new_password": "NewSecurePass123!", "new_password": "NewSecurePass123!",
} })
newPasswordBodyBytes, _ := json.Marshal(newPasswordBody)
newPasswordReq := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(newPasswordBodyBytes))
newPasswordReq.Header.Set("Content-Type", "application/json")
newPasswordRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(newPasswordRec, newPasswordReq)
assertStatus(t, newPasswordRec, http.StatusOK) assertStatus(t, newPasswordRequest, http.StatusOK)
loginBody := map[string]string{ loginRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/login", map[string]any{
"username": "reset_journey_user", "username": "reset_journey_user",
"password": "NewSecurePass123!", "password": "NewSecurePass123!",
} })
loginBodyBytes, _ := json.Marshal(loginBody)
loginReq := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBuffer(loginBodyBytes))
loginReq.Header.Set("Content-Type", "application/json")
loginRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(loginRec, loginReq)
assertStatus(t, loginRec, http.StatusOK) assertStatus(t, loginRequest, http.StatusOK)
}) })
t.Run("Complete_Vote_And_Unvote_Journey", func(t *testing.T) { t.Run("Complete_Vote_And_Unvote_Journey", func(t *testing.T) {
@@ -176,40 +140,21 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Journey Post", "https://example.com/vote-journey") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Journey Post", "https://example.com/vote-journey")
voteBody := map[string]string{"type": "up"} voteRequest := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteBodyBytes, _ := json.Marshal(voteBody) assertStatus(t, voteRequest, http.StatusOK)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBodyBytes))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
assertStatus(t, voteRec, http.StatusOK) getVotesRequest := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil) votesResponse := assertJSONResponse(t, getVotesRequest, http.StatusOK)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK)
if votesResponse == nil { if votesResponse == nil {
return return
} }
if data, ok := votesResponse["data"].(map[string]any); ok { if data, ok := getDataFromResponse(votesResponse); ok {
if votes, ok := data["votes"].([]any); ok && len(votes) > 0 { if votes, ok := data["votes"].([]any); ok && len(votes) > 0 {
unvoteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil) unvoteRequest := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
unvoteReq.Header.Set("Authorization", "Bearer "+user.Token)
unvoteReq = testutils.WithUserContext(unvoteReq, middleware.UserIDKey, user.User.ID)
unvoteReq = testutils.WithURLParams(unvoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
unvoteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(unvoteRec, unvoteReq)
assertStatus(t, unvoteRec, http.StatusOK) assertStatus(t, unvoteRequest, http.StatusOK)
} }
} }
}) })
@@ -221,31 +166,31 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
csrfToken := getCSRFToken(t, pageRouter, "/register") csrfToken := getCSRFToken(t, pageRouter, "/register")
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("username", "page_journey_user") requestBody.Set("username", "page_journey_user")
reqBody.Set("email", "page_journey@example.com") requestBody.Set("email", "page_journey@example.com")
reqBody.Set("password", "SecurePass123!") requestBody.Set("password", "SecurePass123!")
reqBody.Set("password_confirm", "SecurePass123!") requestBody.Set("password_confirm", "SecurePass123!")
reqBody.Set("csrf_token", csrfToken) requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
pageRouter.ServeHTTP(rec, req) pageRouter.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
verificationToken := pageCtx.Suite.EmailSender.VerificationToken() verificationToken := pageCtx.Suite.EmailSender.VerificationToken()
if verificationToken == "" { if verificationToken == "" {
t.Fatal("Verification token not sent") t.Fatal("Verification token not sent")
} }
confirmReq := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(verificationToken), nil) confirmRequest := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(verificationToken), nil)
confirmRec := httptest.NewRecorder() confirmRecorder := httptest.NewRecorder()
pageRouter.ServeHTTP(confirmRec, confirmReq) pageRouter.ServeHTTP(confirmRecorder, confirmRequest)
assertStatusRange(t, confirmRec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, confirmRecorder, http.StatusOK, http.StatusSeeOther)
loginCSRFToken := getCSRFToken(t, pageRouter, "/login") loginCSRFToken := getCSRFToken(t, pageRouter, "/login")
@@ -254,15 +199,15 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
loginBody.Set("password", "SecurePass123!") loginBody.Set("password", "SecurePass123!")
loginBody.Set("csrf_token", loginCSRFToken) loginBody.Set("csrf_token", loginCSRFToken)
loginReq := httptest.NewRequest("POST", "/login", strings.NewReader(loginBody.Encode())) loginRequest := httptest.NewRequest("POST", "/login", strings.NewReader(loginBody.Encode()))
loginReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") loginRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded")
loginReq.AddCookie(&http.Cookie{Name: "csrf_token", Value: loginCSRFToken}) loginRequest.AddCookie(&http.Cookie{Name: "csrf_token", Value: loginCSRFToken})
loginRec := httptest.NewRecorder() loginRecorder := httptest.NewRecorder()
pageRouter.ServeHTTP(loginRec, loginReq) pageRouter.ServeHTTP(loginRecorder, loginRequest)
assertStatus(t, loginRec, http.StatusSeeOther) assertStatus(t, loginRecorder, http.StatusSeeOther)
loginCookies := loginRec.Result().Cookies() loginCookies := loginRecorder.Result().Cookies()
var authToken string var authToken string
for _, cookie := range loginCookies { for _, cookie := range loginCookies {
if cookie.Name == "auth_token" { if cookie.Name == "auth_token" {
@@ -275,37 +220,31 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatal("Auth token not set after login") t.Fatal("Auth token not set after login")
} }
homeReq := httptest.NewRequest("GET", "/", nil) homeRequest := httptest.NewRequest("GET", "/", nil)
homeReq.AddCookie(&http.Cookie{Name: "auth_token", Value: authToken}) homeRequest.AddCookie(&http.Cookie{Name: "auth_token", Value: authToken})
homeRec := httptest.NewRecorder() homeRecorder := httptest.NewRecorder()
pageRouter.ServeHTTP(homeRec, homeReq) pageRouter.ServeHTTP(homeRecorder, homeRequest)
assertStatus(t, homeRec, http.StatusOK) assertStatus(t, homeRecorder, http.StatusOK)
}) })
t.Run("Complete_Post_Creation_And_Update_Journey", func(t *testing.T) { t.Run("Complete_Post_Creation_And_Update_Journey", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "post_update_journey_user", "post_update_journey@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "post_update_journey_user", "post_update_journey@example.com")
postBody := map[string]string{ postBodyBytes, _ := json.Marshal(map[string]any{
"title": "Original Title", "title": "Original Title",
"url": "https://example.com/original", "url": "https://example.com/original",
"content": "Original content", "content": "Original content",
} })
postBodyBytes, _ := json.Marshal(postBody) postRequest := makeAuthenticatedRequest(t, ctx.Router, "POST", "/api/posts", postBodyBytes, user, nil)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBodyBytes))
postReq.Header.Set("Content-Type", "application/json")
postReq.Header.Set("Authorization", "Bearer "+user.Token)
postReq = testutils.WithUserContext(postReq, middleware.UserIDKey, user.User.ID)
postRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(postRec, postReq)
postResponse := assertJSONResponse(t, postRec, http.StatusCreated) postResponse := assertJSONResponse(t, postRequest, http.StatusCreated)
if postResponse == nil { if postResponse == nil {
return return
} }
postData, ok := postResponse["data"].(map[string]any) postData, ok := getDataFromResponse(postResponse)
if !ok { if !ok {
t.Fatal("Post response missing data") t.Fatal("Post response missing data")
} }
@@ -315,34 +254,25 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatal("Post response missing id") t.Fatal("Post response missing id")
} }
updateBody := map[string]string{ updateBodyBytes, _ := json.Marshal(map[string]any{
"title": "Updated Title", "title": "Updated Title",
"content": "Updated content", "content": "Updated content",
} })
updateBodyBytes, _ := json.Marshal(updateBody) updateRequest := makeAuthenticatedRequest(t, ctx.Router, "PUT", fmt.Sprintf("/api/posts/%.0f", postID), updateBodyBytes, user, map[string]string{"id": fmt.Sprintf("%.0f", postID)})
updateReq := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%.0f", postID), bytes.NewBuffer(updateBodyBytes))
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+user.Token)
updateReq = testutils.WithUserContext(updateReq, middleware.UserIDKey, user.User.ID)
updateReq = testutils.WithURLParams(updateReq, map[string]string{"id": fmt.Sprintf("%.0f", postID)})
updateRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(updateRec, updateReq)
updateResponse := assertJSONResponse(t, updateRec, http.StatusOK) updateResponse := assertJSONResponse(t, updateRequest, http.StatusOK)
if updateResponse == nil { if updateResponse == nil {
return return
} }
getPostReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%.0f", postID), nil) getPostRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%.0f", postID))
getPostRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getPostRec, getPostReq)
getPostResponse := assertJSONResponse(t, getPostRec, http.StatusOK) getPostResponse := assertJSONResponse(t, getPostRequest, http.StatusOK)
if getPostResponse == nil { if getPostResponse == nil {
return return
} }
if data, ok := getPostResponse["data"].(map[string]any); ok { if data, ok := getDataFromResponse(getPostResponse); ok {
if post, ok := data["post"].(map[string]any); ok { if post, ok := data["post"].(map[string]any); ok {
if title, ok := post["title"].(string); ok && title != "Updated Title" { if title, ok := post["title"].(string); ok && title != "Updated Title" {
t.Errorf("Post title not updated: expected 'Updated Title', got '%s'", title) t.Errorf("Post title not updated: expected 'Updated Title', got '%s'", title)

View File

@@ -2,7 +2,6 @@ package integration
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -19,58 +18,46 @@ func TestIntegration_ErrorPropagation(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "json_error_user", "json_error@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "json_error_user", "json_error@example.com")
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("invalid json{"))) request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("invalid json{")))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusBadRequest) assertErrorResponse(t, recorder, http.StatusBadRequest)
}) })
t.Run("Validation_Error_Propagation", func(t *testing.T) { t.Run("Validation_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{ reqBody := map[string]any{
"username": "", "username": "",
"email": "invalid-email", "email": "invalid-email",
"password": "weak", "password": "weak",
} }
body, _ := json.Marshal(reqBody) request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusBadRequest)
assertErrorResponse(t, rec, http.StatusBadRequest)
}) })
t.Run("Database_Error_Propagation", func(t *testing.T) { t.Run("Database_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "db_error_user", "db_error@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "db_error_user", "db_error@example.com")
reqBody := map[string]string{ reqBody := map[string]any{
"title": "Test Post", "title": "Test Post",
"url": "https://example.com/test", "url": "https://example.com/test",
"content": "Test content", "content": "Test content",
} }
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) request := makePostRequest(t, ctx.Router, "/api/posts", reqBody, user, nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) if request.Code == http.StatusInternalServerError {
assertErrorResponse(t, request, http.StatusInternalServerError)
if rec.Code == http.StatusInternalServerError {
assertErrorResponse(t, rec, http.StatusInternalServerError)
} else { } else {
assertStatus(t, rec, http.StatusCreated) assertStatus(t, request, http.StatusCreated)
} }
}) })
@@ -78,34 +65,23 @@ func TestIntegration_ErrorPropagation(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "notfound_error_user", "notfound_error@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "notfound_error_user", "notfound_error@example.com")
req := httptest.NewRequest("GET", "/api/posts/999999", nil) request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/posts/999999", user, map[string]string{"id": "999999"})
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": "999999"})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusNotFound)
assertErrorResponse(t, rec, http.StatusNotFound)
}) })
t.Run("Unauthorized_Error_Propagation", func(t *testing.T) { t.Run("Unauthorized_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{ reqBody := map[string]any{
"title": "Test Post", "title": "Test Post",
"url": "https://example.com/test", "url": "https://example.com/test",
"content": "Test content", "content": "Test content",
} }
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) request := makePostRequestWithJSON(t, ctx.Router, "/api/posts", reqBody)
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusUnauthorized)
assertErrorResponse(t, rec, http.StatusUnauthorized)
}) })
t.Run("Forbidden_Error_Propagation", func(t *testing.T) { t.Run("Forbidden_Error_Propagation", func(t *testing.T) {
@@ -115,79 +91,59 @@ func TestIntegration_ErrorPropagation(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Forbidden Post", "https://example.com/forbidden") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Forbidden Post", "https://example.com/forbidden")
updateBody := map[string]string{ updateBody := map[string]any{
"title": "Updated Title", "title": "Updated Title",
"content": "Updated content", "content": "Updated content",
} }
body, _ := json.Marshal(updateBody)
req := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body)) request := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), updateBody, otherUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusForbidden)
assertErrorResponse(t, rec, http.StatusForbidden)
}) })
t.Run("Service_Error_Propagation", func(t *testing.T) { t.Run("Service_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{ reqBody := map[string]any{
"username": "existing_user", "username": "existing_user",
"email": "existing@example.com", "email": "existing@example.com",
"password": "SecurePass123!", "password": "SecurePass123!",
} }
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body)) request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", reqBody)
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusCreated) assertStatus(t, request, http.StatusCreated)
req = httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body)) request = makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", reqBody)
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatusRange(t, rec, http.StatusBadRequest, http.StatusConflict) assertStatusRange(t, request, http.StatusBadRequest, http.StatusConflict)
assertErrorResponse(t, rec, rec.Code) assertErrorResponse(t, request, request.Code)
}) })
t.Run("Middleware_Error_Propagation", func(t *testing.T) { t.Run("Middleware_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}"))) request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}")))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer expired.invalid.token") request.Header.Set("Authorization", "Bearer expired.invalid.token")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusUnauthorized) assertErrorResponse(t, recorder, http.StatusUnauthorized)
}) })
t.Run("Handler_Error_Response_Format", func(t *testing.T) { t.Run("Handler_Error_Response_Format", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("GET", "/api/nonexistent", nil) request := makeGetRequest(t, ctx.Router, "/api/nonexistent")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) if request.Code == http.StatusNotFound {
if request.Header().Get("Content-Type") == "application/json" {
if rec.Code == http.StatusNotFound { assertErrorResponse(t, request, http.StatusNotFound)
if rec.Header().Get("Content-Type") == "application/json" { } else if request.Body.Len() == 0 {
assertErrorResponse(t, rec, http.StatusNotFound)
} else {
if rec.Body.Len() == 0 {
t.Error("Expected error response body") t.Error("Expected error response body")
} }
} }
}
}) })
} }

View File

@@ -11,55 +11,35 @@ import (
"testing" "testing"
"time" "time"
"github.com/golang-jwt/jwt/v5"
"goyco/internal/database" "goyco/internal/database"
"goyco/internal/handlers" "goyco/internal/handlers"
"goyco/internal/middleware" "goyco/internal/middleware"
"goyco/internal/repositories" "goyco/internal/repositories"
"goyco/internal/services" "goyco/internal/services"
"goyco/internal/testutils" "goyco/internal/testutils"
"github.com/golang-jwt/jwt/v5"
) )
func TestIntegration_Handlers(t *testing.T) { func TestIntegration_Handlers(t *testing.T) {
suite := testutils.NewServiceSuite(t) ctx := setupTestContext(t)
authService := ctx.AuthService
authService, err := services.NewAuthFacadeForTest(testutils.AppTestConfig, suite.UserRepo, suite.PostRepo, suite.DeletionRepo, suite.RefreshTokenRepo, suite.EmailSender) emailSender := ctx.Suite.EmailSender
if err != nil { userRepo := ctx.Suite.UserRepo
t.Fatalf("Failed to create auth service: %v", err) postRepo := ctx.Suite.PostRepo
}
voteService := services.NewVoteService(suite.VoteRepo, suite.PostRepo, suite.DB)
emailSender := suite.EmailSender
userRepo := suite.UserRepo
postRepo := suite.PostRepo
titleFetcher := suite.TitleFetcher
authHandler := handlers.NewAuthHandler(authService, userRepo)
postHandler := handlers.NewPostHandler(postRepo, titleFetcher, voteService)
voteHandler := handlers.NewVoteHandler(voteService)
userHandler := handlers.NewUserHandler(userRepo, authService)
t.Run("Auth_Handler_Complete_Workflow", func(t *testing.T) { t.Run("Auth_Handler_Complete_Workflow", func(t *testing.T) {
emailSender.Reset() emailSender.Reset()
registerData := map[string]string{ registerResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "handler_user", "username": "handler_user",
"email": "handler@example.com", "email": "handler@example.com",
"password": "SecurePass123!", "password": "SecurePass123!",
} })
registerBody, _ := json.Marshal(registerData) if registerResponse.Code != http.StatusCreated {
registerReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(registerBody)) t.Errorf("Expected status 201, got %d", registerResponse.Code)
registerReq.Header.Set("Content-Type", "application/json")
registerResp := httptest.NewRecorder()
authHandler.Register(registerResp, registerReq)
if registerResp.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", registerResp.Code)
} }
var registerPayload map[string]any registerPayload := assertJSONResponse(t, registerResponse, http.StatusCreated)
if err := json.Unmarshal(registerResp.Body.Bytes(), &registerPayload); err != nil {
t.Fatalf("Failed to decode register response: %v", err)
}
if success, _ := registerPayload["success"].(bool); !success { if success, _ := registerPayload["success"].(bool); !success {
t.Fatalf("Expected register response success, got %v", registerPayload) t.Fatalf("Expected register response success, got %v", registerPayload)
} }
@@ -78,11 +58,9 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Failed to update user with mock token: %v", err) t.Fatalf("Failed to update user with mock token: %v", err)
} }
confirmReq := httptest.NewRequest(http.MethodGet, "/api/auth/confirm?token="+url.QueryEscape(mockToken), nil) confirmResponse := makeGetRequest(t, ctx.Router, "/api/auth/confirm?token="+url.QueryEscape(mockToken))
confirmResp := httptest.NewRecorder() if confirmResponse.Code != http.StatusOK {
authHandler.ConfirmEmail(confirmResp, confirmReq) t.Fatalf("Expected 200 when confirming email via handler, got %d", confirmResponse.Code)
if confirmResp.Code != http.StatusOK {
t.Fatalf("Expected 200 when confirming email via handler, got %d", confirmResp.Code)
} }
loginSeed := createAuthenticatedUser(t, authService, userRepo, "auth_handler_login", "auth_handler_login@example.com") loginSeed := createAuthenticatedUser(t, authService, userRepo, "auth_handler_login", "auth_handler_login@example.com")
@@ -92,92 +70,60 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Service login failed for seeded user: %v", err) t.Fatalf("Service login failed for seeded user: %v", err)
} }
meReq := httptest.NewRequest("GET", "/api/auth/me", nil) meResponse := makeAuthenticatedGetRequest(t, ctx.Router, "/api/auth/me", &authenticatedUser{User: loginSeed.User, Token: loginAuth.AccessToken}, nil)
meReq.Header.Set("Authorization", "Bearer "+loginAuth.AccessToken) if meResponse.Code != http.StatusOK {
meReq = testutils.WithUserContext(meReq, middleware.UserIDKey, loginSeed.User.ID) t.Errorf("Expected status 200, got %d", meResponse.Code)
meResp := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
if meResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", meResp.Code)
} }
}) })
t.Run("Auth_Handler_Security_Validation", func(t *testing.T) { t.Run("Auth_Handler_Security_Validation", func(t *testing.T) {
emailSender.Reset() emailSender.Reset()
weakData := map[string]string{ weakResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "weak_user", "username": "weak_user",
"email": "weak@example.com", "email": "weak@example.com",
"password": "123", "password": "123",
} })
weakBody, _ := json.Marshal(weakData) if weakResponse.Code != http.StatusBadRequest {
weakReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(weakBody)) t.Errorf("Expected status 400 for weak password, got %d", weakResponse.Code)
weakReq.Header.Set("Content-Type", "application/json")
weakResp := httptest.NewRecorder()
authHandler.Register(weakResp, weakReq)
if weakResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for weak password, got %d", weakResp.Code)
} }
var weakErrorResp map[string]any weakErrorResponse := assertJSONResponse(t, weakResponse, http.StatusBadRequest)
if err := json.Unmarshal(weakResp.Body.Bytes(), &weakErrorResp); err != nil { if success, _ := weakErrorResponse["success"].(bool); success {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := weakErrorResp["success"].(bool); success {
t.Error("Expected error response to have success=false") t.Error("Expected error response to have success=false")
} }
if errorMsg, ok := weakErrorResp["error"].(string); !ok || errorMsg == "" { if errorMsg, ok := weakErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message") t.Error("Expected error response to contain validation error message")
} }
invalidData := map[string]string{ invalidResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "invalid_user", "username": "invalid_user",
"email": "not-an-email", "email": "not-an-email",
"password": "SecurePass123!", "password": "SecurePass123!",
} })
invalidBody, _ := json.Marshal(invalidData) if invalidResponse.Code != http.StatusBadRequest {
invalidReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(invalidBody)) t.Errorf("Expected status 400 for invalid email, got %d", invalidResponse.Code)
invalidReq.Header.Set("Content-Type", "application/json")
invalidResp := httptest.NewRecorder()
authHandler.Register(invalidResp, invalidReq)
if invalidResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid email, got %d", invalidResp.Code)
} }
var invalidEmailErrorResp map[string]any invalidEmailErrorResponse := assertJSONResponse(t, invalidResponse, http.StatusBadRequest)
if err := json.Unmarshal(invalidResp.Body.Bytes(), &invalidEmailErrorResp); err != nil { if success, _ := invalidEmailErrorResponse["success"].(bool); success {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := invalidEmailErrorResp["success"].(bool); success {
t.Error("Expected error response to have success=false") t.Error("Expected error response to have success=false")
} }
if errorMsg, ok := invalidEmailErrorResp["error"].(string); !ok || errorMsg == "" { if errorMsg, ok := invalidEmailErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message") t.Error("Expected error response to contain validation error message")
} }
incompleteData := map[string]string{ incompleteResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "incomplete_user", "username": "incomplete_user",
} })
incompleteBody, _ := json.Marshal(incompleteData) if incompleteResponse.Code != http.StatusBadRequest {
incompleteReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(incompleteBody)) t.Errorf("Expected status 400 for missing fields, got %d", incompleteResponse.Code)
incompleteReq.Header.Set("Content-Type", "application/json")
incompleteResp := httptest.NewRecorder()
authHandler.Register(incompleteResp, incompleteReq)
if incompleteResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for missing fields, got %d", incompleteResp.Code)
} }
var incompleteErrorResp map[string]any incompleteErrorResponse := assertJSONResponse(t, incompleteResponse, http.StatusBadRequest)
if err := json.Unmarshal(incompleteResp.Body.Bytes(), &incompleteErrorResp); err != nil { if success, _ := incompleteErrorResponse["success"].(bool); success {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := incompleteErrorResp["success"].(bool); success {
t.Error("Expected error response to have success=false") t.Error("Expected error response to have success=false")
} }
if errorMsg, ok := incompleteErrorResp["error"].(string); !ok || errorMsg == "" { if errorMsg, ok := incompleteErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message") t.Error("Expected error response to contain validation error message")
} }
}) })
@@ -186,28 +132,17 @@ func TestIntegration_Handlers(t *testing.T) {
emailSender.Reset() emailSender.Reset()
user := createAuthenticatedUser(t, authService, userRepo, "post_user", "post@example.com") user := createAuthenticatedUser(t, authService, userRepo, "post_user", "post@example.com")
postData := map[string]string{ postResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "Handler Test Post", "title": "Handler Test Post",
"url": "https://example.com/handler-test", "url": "https://example.com/handler-test",
"content": "This is a handler test post", "content": "This is a handler test post",
} }, user, nil)
postBody, _ := json.Marshal(postData) if postResponse.Code != http.StatusCreated {
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBody)) t.Errorf("Expected status 201, got %d", postResponse.Code)
postReq.Header.Set("Content-Type", "application/json")
postReq.Header.Set("Authorization", "Bearer "+user.Token)
postReq = testutils.WithUserContext(postReq, middleware.UserIDKey, user.User.ID)
postResp := httptest.NewRecorder()
postHandler.CreatePost(postResp, postReq)
if postResp.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", postResp.Code)
} }
var postResult map[string]any postResult := assertJSONResponse(t, postResponse, http.StatusCreated)
if err := json.Unmarshal(postResp.Body.Bytes(), &postResult); err != nil { postDetails, ok := getDataFromResponse(postResult)
t.Fatalf("Failed to decode post response: %v", err)
}
postDetails, ok := postResult["data"].(map[string]any)
if !ok { if !ok {
t.Fatalf("Expected data object in post response, got %v", postResult) t.Fatalf("Expected data object in post response, got %v", postResult)
} }
@@ -216,87 +151,49 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatal("Expected post ID in response") t.Fatal("Expected post ID in response")
} }
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", int(postID)), nil) getResponse := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", int(postID)))
getReq = testutils.WithURLParams(getReq, map[string]string{"id": fmt.Sprintf("%d", int(postID))}) if getResponse.Code != http.StatusOK {
getResp := httptest.NewRecorder() t.Errorf("Expected status 200, got %d", getResponse.Code)
postHandler.GetPost(getResp, getReq)
if getResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getResp.Code)
} }
postsReq := httptest.NewRequest("GET", "/api/posts", nil) postsResponse := makeGetRequest(t, ctx.Router, "/api/posts")
postsResp := httptest.NewRecorder() if postsResponse.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", postsResponse.Code)
postHandler.GetPosts(postsResp, postsReq)
if postsResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", postsResp.Code)
} }
searchReq := httptest.NewRequest("GET", "/api/posts/search?q=handler", nil) searchResponse := makeGetRequest(t, ctx.Router, "/api/posts/search?q=handler")
searchResp := httptest.NewRecorder() if searchResponse.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", searchResponse.Code)
postHandler.SearchPosts(searchResp, searchReq)
if searchResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", searchResp.Code)
} }
}) })
t.Run("Post_Handler_Security_Validation", func(t *testing.T) { t.Run("Post_Handler_Security_Validation", func(t *testing.T) {
emailSender.Reset() emailSender.Reset()
postData := map[string]string{ postResponse := makePostRequestWithJSON(t, ctx.Router, "/api/posts", map[string]any{
"title": "Unauthorized Post", "title": "Unauthorized Post",
"url": "https://example.com/unauthorized", "url": "https://example.com/unauthorized",
"content": "This should fail", "content": "This should fail",
} })
postBody, _ := json.Marshal(postData) authErrorResponse := assertJSONResponse(t, postResponse, http.StatusUnauthorized)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBody)) if success, _ := authErrorResponse["success"].(bool); success {
postReq.Header.Set("Content-Type", "application/json")
postResp := httptest.NewRecorder()
postHandler.CreatePost(postResp, postReq)
if postResp.Code != http.StatusUnauthorized {
t.Errorf("Expected status 401 for unauthenticated post creation, got %d", postResp.Code)
}
var authErrorResp map[string]any
if err := json.Unmarshal(postResp.Body.Bytes(), &authErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := authErrorResp["success"].(bool); success {
t.Error("Expected error response to have success=false") t.Error("Expected error response to have success=false")
} }
if errorMsg, ok := authErrorResp["error"].(string); !ok || errorMsg == "" { if errorMsg, ok := authErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain authentication error message") t.Error("Expected error response to contain authentication error message")
} }
user := createAuthenticatedUser(t, authService, userRepo, "security_user", "security@example.com") user := createAuthenticatedUser(t, authService, userRepo, "security_user", "security@example.com")
invalidData := map[string]string{ invalidResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "", "title": "",
"url": "not-a-url", "url": "not-a-url",
"content": "Invalid post", "content": "Invalid post",
} }, user, nil)
invalidBody, _ := json.Marshal(invalidData) postValidationErrorResponse := assertJSONResponse(t, invalidResponse, http.StatusBadRequest)
invalidReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(invalidBody)) if success, _ := postValidationErrorResponse["success"].(bool); success {
invalidReq.Header.Set("Content-Type", "application/json")
invalidReq.Header.Set("Authorization", "Bearer "+user.Token)
invalidReq = testutils.WithUserContext(invalidReq, middleware.UserIDKey, user.User.ID)
invalidResp := httptest.NewRecorder()
postHandler.CreatePost(invalidResp, invalidReq)
if invalidResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid post data, got %d", invalidResp.Code)
}
var postValidationErrorResp map[string]any
if err := json.Unmarshal(invalidResp.Body.Bytes(), &postValidationErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := postValidationErrorResp["success"].(bool); success {
t.Error("Expected error response to have success=false") t.Error("Expected error response to have success=false")
} }
if errorMsg, ok := postValidationErrorResp["error"].(string); !ok || errorMsg == "" { if errorMsg, ok := postValidationErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message") t.Error("Expected error response to contain validation error message")
} }
}) })
@@ -306,161 +203,100 @@ func TestIntegration_Handlers(t *testing.T) {
user := createAuthenticatedUser(t, authService, userRepo, "vote_handler_user", "vote_handler@example.com") user := createAuthenticatedUser(t, authService, userRepo, "vote_handler_user", "vote_handler@example.com")
post := testutils.CreatePostWithRepo(t, postRepo, user.User.ID, "Vote Handler Test Post", "https://example.com/vote-handler") post := testutils.CreatePostWithRepo(t, postRepo, user.User.ID, "Vote Handler Test Post", "https://example.com/vote-handler")
voteData := map[string]string{ voteResponse := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
"type": "up", assertStatus(t, voteResponse, http.StatusOK)
}
voteBody, _ := json.Marshal(voteData)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBody))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteResp := httptest.NewRecorder()
voteHandler.CastVote(voteResp, voteReq) getVoteResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
if voteResp.Code != http.StatusOK { assertStatus(t, getVoteResponse, http.StatusOK)
t.Errorf("Expected status 200, got %d", voteResp.Code)
}
getVoteReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil) getPostVotesResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVoteReq.Header.Set("Authorization", "Bearer "+user.Token) assertStatus(t, getPostVotesResponse, http.StatusOK)
getVoteReq = testutils.WithUserContext(getVoteReq, middleware.UserIDKey, user.User.ID)
getVoteReq = testutils.WithURLParams(getVoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVoteResp := httptest.NewRecorder()
voteHandler.GetUserVote(getVoteResp, getVoteReq) removeVoteResponse := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
if getVoteResp.Code != http.StatusOK { assertStatus(t, removeVoteResponse, http.StatusOK)
t.Errorf("Expected status 200, got %d", getVoteResp.Code)
}
getPostVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getPostVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getPostVotesReq = testutils.WithUserContext(getPostVotesReq, middleware.UserIDKey, user.User.ID)
getPostVotesReq = testutils.WithURLParams(getPostVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getPostVotesResp := httptest.NewRecorder()
voteHandler.GetPostVotes(getPostVotesResp, getPostVotesReq)
if getPostVotesResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getPostVotesResp.Code)
}
removeVoteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
removeVoteReq.Header.Set("Authorization", "Bearer "+user.Token)
removeVoteReq = testutils.WithUserContext(removeVoteReq, middleware.UserIDKey, user.User.ID)
removeVoteReq = testutils.WithURLParams(removeVoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
removeVoteResp := httptest.NewRecorder()
voteHandler.RemoveVote(removeVoteResp, removeVoteReq)
if removeVoteResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", removeVoteResp.Code)
}
}) })
t.Run("User_Handler_Complete_Workflow", func(t *testing.T) { t.Run("User_Handler_Complete_Workflow", func(t *testing.T) {
emailSender.Reset() emailSender.Reset()
user := createAuthenticatedUser(t, authService, userRepo, "user_handler_user", "user_handler@example.com") user := createAuthenticatedUser(t, authService, userRepo, "user_handler_user", "user_handler@example.com")
usersReq := httptest.NewRequest("GET", "/api/users", nil) usersResponse := makeAuthenticatedGetRequest(t, ctx.Router, "/api/users", user, nil)
usersReq.Header.Set("Authorization", "Bearer "+user.Token) assertStatus(t, usersResponse, http.StatusOK)
usersReq = testutils.WithUserContext(usersReq, middleware.UserIDKey, user.User.ID)
usersResp := httptest.NewRecorder()
userHandler.GetUsers(usersResp, usersReq) getUserResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
if usersResp.Code != http.StatusOK { assertStatus(t, getUserResponse, http.StatusOK)
t.Errorf("Expected status 200, got %d", usersResp.Code)
}
getUserReq := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d", user.User.ID), nil) getUserPostsResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d/posts", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
getUserReq.Header.Set("Authorization", "Bearer "+user.Token) assertStatus(t, getUserPostsResponse, http.StatusOK)
getUserReq = testutils.WithUserContext(getUserReq, middleware.UserIDKey, user.User.ID)
getUserReq = testutils.WithURLParams(getUserReq, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
getUserResp := httptest.NewRecorder()
userHandler.GetUser(getUserResp, getUserReq)
if getUserResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getUserResp.Code)
}
getUserPostsReq := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d/posts", user.User.ID), nil)
getUserPostsReq.Header.Set("Authorization", "Bearer "+user.Token)
getUserPostsReq = testutils.WithUserContext(getUserPostsReq, middleware.UserIDKey, user.User.ID)
getUserPostsReq = testutils.WithURLParams(getUserPostsReq, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
getUserPostsResp := httptest.NewRecorder()
userHandler.GetUserPosts(getUserPostsResp, getUserPostsReq)
if getUserPostsResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getUserPostsResp.Code)
}
}) })
t.Run("Error_Handling_Invalid_Requests", func(t *testing.T) { t.Run("Error_Handling_Invalid_Requests", func(t *testing.T) {
invalidJSONReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer([]byte("invalid json"))) middleware.StopAllRateLimiters()
invalidJSONReq.Header.Set("Content-Type", "application/json") ctx.Suite.EmailSender.Reset()
invalidJSONResp := httptest.NewRecorder()
authHandler.Register(invalidJSONResp, invalidJSONReq) invalidJSONResponse := makeRequest(t, ctx.Router, "POST", "/api/auth/register", []byte("invalid json"), map[string]string{"Content-Type": "application/json"})
if invalidJSONResp.Code != http.StatusBadRequest { jsonErrorResponse := assertJSONResponse(t, invalidJSONResponse, http.StatusBadRequest)
t.Errorf("Expected status 400 for invalid JSON, got %d", invalidJSONResp.Code) if success, _ := jsonErrorResponse["success"].(bool); success {
}
var jsonErrorResp map[string]any
if err := json.Unmarshal(invalidJSONResp.Body.Bytes(), &jsonErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := jsonErrorResp["success"].(bool); success {
t.Error("Expected error response to have success=false") t.Error("Expected error response to have success=false")
} }
if errorMsg, ok := jsonErrorResp["error"].(string); !ok || errorMsg == "" { if errorMsg, ok := jsonErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain JSON parsing error message") t.Error("Expected error response to contain JSON parsing error message")
} }
missingCTData := map[string]string{ missingCTData := map[string]string{
"username": "missing_ct_user", "username": uniqueTestUsername(t, "missing_ct"),
"email": "missing_ct@example.com", "email": uniqueTestEmail(t, "missing_ct"),
"password": "SecurePass123!", "password": "SecurePass123!",
} }
missingCTBody, _ := json.Marshal(missingCTData) missingCTBody, _ := json.Marshal(missingCTData)
missingCTReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(missingCTBody)) missingCTRequest := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(missingCTBody))
missingCTResp := httptest.NewRecorder() missingCTResponse := httptest.NewRecorder()
authHandler.Register(missingCTResp, missingCTReq) ctx.Router.ServeHTTP(missingCTResponse, missingCTRequest)
if missingCTResp.Code != http.StatusCreated { if missingCTResponse.Code == http.StatusTooManyRequests {
t.Errorf("Expected status 201, got %d", missingCTResp.Code) var rateLimitResponse map[string]any
if err := json.Unmarshal(missingCTResponse.Body.Bytes(), &rateLimitResponse); err != nil {
t.Errorf("Rate limited but response is not valid JSON: %v", err)
} else {
t.Logf("Rate limit hit (expected in full test suite run), but request was processed correctly (not rejected as invalid JSON)")
}
} else if missingCTResponse.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", missingCTResponse.Code)
} }
invalidEndpointReq := httptest.NewRequest("GET", "/api/invalid/endpoint", nil) invalidEndpointRequest := httptest.NewRequest("GET", "/api/invalid/endpoint", nil)
invalidEndpointResp := httptest.NewRecorder() invalidEndpointResponse := httptest.NewRecorder()
authHandler.Me(invalidEndpointResp, invalidEndpointReq) ctx.Router.ServeHTTP(invalidEndpointResponse, invalidEndpointRequest)
if invalidEndpointResp.Code == http.StatusOK { if invalidEndpointResponse.Code == http.StatusOK {
t.Error("Expected error for invalid endpoint") t.Error("Expected error for invalid endpoint")
} }
}) })
t.Run("Security_Authentication_Bypass", func(t *testing.T) { t.Run("Security_Authentication_Bypass", func(t *testing.T) {
meReq := httptest.NewRequest("GET", "/api/auth/me", nil) meRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
meResp := httptest.NewRecorder() meResponse := httptest.NewRecorder()
authHandler.Me(meResp, meReq) ctx.Router.ServeHTTP(meResponse, meRequest)
if meResp.Code == http.StatusOK { if meResponse.Code == http.StatusOK {
t.Error("Expected error for unauthenticated request") t.Error("Expected error for unauthenticated request")
} }
invalidTokenReq := httptest.NewRequest("GET", "/api/auth/me", nil) invalidTokenRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
invalidTokenReq.Header.Set("Authorization", "Bearer invalid-token") invalidTokenRequest.Header.Set("Authorization", "Bearer invalid-token")
invalidTokenResp := httptest.NewRecorder() invalidTokenResponse := httptest.NewRecorder()
authHandler.Me(invalidTokenResp, invalidTokenReq) ctx.Router.ServeHTTP(invalidTokenResponse, invalidTokenRequest)
if invalidTokenResp.Code == http.StatusOK { if invalidTokenResponse.Code == http.StatusOK {
t.Error("Expected error for invalid token") t.Error("Expected error for invalid token")
} }
malformedTokenReq := httptest.NewRequest("GET", "/api/auth/me", nil) malformedTokenRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
malformedTokenReq.Header.Set("Authorization", "InvalidFormat token") malformedTokenRequest.Header.Set("Authorization", "InvalidFormat token")
malformedTokenResp := httptest.NewRecorder() malformedTokenResponse := httptest.NewRecorder()
authHandler.Me(malformedTokenResp, malformedTokenReq) ctx.Router.ServeHTTP(malformedTokenResponse, malformedTokenRequest)
if malformedTokenResp.Code == http.StatusOK { if malformedTokenResponse.Code == http.StatusOK {
t.Error("Expected error for malformed token") t.Error("Expected error for malformed token")
} }
}) })
@@ -468,32 +304,21 @@ func TestIntegration_Handlers(t *testing.T) {
t.Run("Security_Input_Sanitization", func(t *testing.T) { t.Run("Security_Input_Sanitization", func(t *testing.T) {
user := createAuthenticatedUser(t, authService, userRepo, "xss_user", "xss@example.com") user := createAuthenticatedUser(t, authService, userRepo, "xss_user", "xss@example.com")
xssData := map[string]string{ xssResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "<script>alert('xss')</script>", "title": "<script>alert('xss')</script>",
"url": "https://example.com/xss", "url": "https://example.com/xss",
"content": "XSS test content", "content": "<script>alert('xss')</script>",
} }, user, nil)
xssBody, _ := json.Marshal(xssData) if xssResponse.Code != http.StatusCreated {
xssReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(xssBody)) t.Errorf("Expected status 201 for XSS sanitization, got %d", xssResponse.Code)
xssReq.Header.Set("Content-Type", "application/json")
xssReq.Header.Set("Authorization", "Bearer "+user.Token)
xssReq = testutils.WithUserContext(xssReq, middleware.UserIDKey, user.User.ID)
xssResp := httptest.NewRecorder()
postHandler.CreatePost(xssResp, xssReq)
if xssResp.Code != http.StatusCreated {
t.Errorf("Expected status 201 for XSS sanitization, got %d", xssResp.Code)
} }
var xssResult map[string]any xssResult := assertJSONResponse(t, xssResponse, http.StatusCreated)
if err := json.Unmarshal(xssResp.Body.Bytes(), &xssResult); err != nil {
t.Fatalf("Failed to decode XSS response: %v", err)
}
if success, _ := xssResult["success"].(bool); !success { if success, _ := xssResult["success"].(bool); !success {
t.Error("Expected XSS response to have success=true") t.Error("Expected XSS response to have success=true")
} }
data, ok := xssResult["data"].(map[string]any) data, ok := getDataFromResponse(xssResult)
if !ok { if !ok {
t.Fatalf("Expected data object in XSS response, got %T", xssResult["data"]) t.Fatalf("Expected data object in XSS response, got %T", xssResult["data"])
} }
@@ -522,32 +347,21 @@ func TestIntegration_Handlers(t *testing.T) {
t.Errorf("Expected script tags to be HTML-escaped in content, got: %s", content) t.Errorf("Expected script tags to be HTML-escaped in content, got: %s", content)
} }
sqlData := map[string]string{ sqlResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "'; DROP TABLE posts; --", "title": "'; DROP TABLE posts; --",
"url": "https://example.com/sql", "url": "https://example.com/sql",
"content": "SQL injection test", "content": "SQL injection test",
} }, user, nil)
sqlBody, _ := json.Marshal(sqlData) if sqlResponse.Code != http.StatusCreated {
sqlReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(sqlBody)) t.Errorf("Expected status 201 for SQL injection sanitization, got %d", sqlResponse.Code)
sqlReq.Header.Set("Content-Type", "application/json")
sqlReq.Header.Set("Authorization", "Bearer "+user.Token)
sqlReq = testutils.WithUserContext(sqlReq, middleware.UserIDKey, user.User.ID)
sqlResp := httptest.NewRecorder()
postHandler.CreatePost(sqlResp, sqlReq)
if sqlResp.Code != http.StatusCreated {
t.Errorf("Expected status 201 for SQL injection sanitization, got %d", sqlResp.Code)
} }
var sqlResult map[string]any sqlResult := assertJSONResponse(t, sqlResponse, http.StatusCreated)
if err := json.Unmarshal(sqlResp.Body.Bytes(), &sqlResult); err != nil {
t.Fatalf("Failed to decode SQL response: %v", err)
}
if success, _ := sqlResult["success"].(bool); !success { if success, _ := sqlResult["success"].(bool); !success {
t.Error("Expected SQL response to have success=true") t.Error("Expected SQL response to have success=true")
} }
sqlResponseData, ok := sqlResult["data"].(map[string]any) sqlResponseData, ok := getDataFromResponse(sqlResult)
if !ok { if !ok {
t.Fatalf("Expected data object in SQL response, got %T", sqlResult["data"]) t.Fatalf("Expected data object in SQL response, got %T", sqlResult["data"])
} }
@@ -579,63 +393,31 @@ func TestIntegration_Handlers(t *testing.T) {
t.Run("Authorization_User_Access_Control", func(t *testing.T) { t.Run("Authorization_User_Access_Control", func(t *testing.T) {
emailSender.Reset() emailSender.Reset()
user1 := createAuthenticatedUser(t, authService, userRepo, "auth_user1", "auth1@example.com") firstUser := createAuthenticatedUser(t, authService, userRepo, "auth_user1", "auth1@example.com")
user2 := createAuthenticatedUser(t, authService, userRepo, "auth_user2", "auth2@example.com") secondUser := createAuthenticatedUser(t, authService, userRepo, "auth_user2", "auth2@example.com")
post := testutils.CreatePostWithRepo(t, postRepo, user1.User.ID, "Private Post", "https://example.com/private") post := testutils.CreatePostWithRepo(t, postRepo, firstUser.User.ID, "Private Post", "https://example.com/private")
getPostReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil) getPostResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getPostReq.Header.Set("Authorization", "Bearer "+user2.Token) testutils.AssertHTTPStatus(t, getPostResponse, http.StatusOK)
getPostReq = testutils.WithUserContext(getPostReq, middleware.UserIDKey, user2.User.ID)
getPostReq = testutils.WithURLParams(getPostReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getPostResp := httptest.NewRecorder()
postHandler.GetPost(getPostResp, getPostReq) updateResponse := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{"title": "Updated Title"}, secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
testutils.AssertHTTPStatus(t, getPostResp, http.StatusOK) testutils.AssertHTTPStatus(t, updateResponse, http.StatusForbidden)
updateData := map[string]string{ deleteResponse := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
"title": "Updated Title", testutils.AssertHTTPStatus(t, deleteResponse, http.StatusForbidden)
}
updateBody, _ := json.Marshal(updateData)
updateReq := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(updateBody))
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+user2.Token)
updateReq = testutils.WithUserContext(updateReq, middleware.UserIDKey, user2.User.ID)
updateReq = testutils.WithURLParams(updateReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
updateResp := httptest.NewRecorder()
postHandler.UpdatePost(updateResp, updateReq)
testutils.AssertHTTPStatus(t, updateResp, http.StatusForbidden)
deleteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
deleteReq.Header.Set("Authorization", "Bearer "+user2.Token)
deleteReq = testutils.WithUserContext(deleteReq, middleware.UserIDKey, user2.User.ID)
deleteReq = testutils.WithURLParams(deleteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
deleteResp := httptest.NewRecorder()
postHandler.DeletePost(deleteResp, deleteReq)
testutils.AssertHTTPStatus(t, deleteResp, http.StatusForbidden)
}) })
t.Run("Authorization_Vote_Access_Control", func(t *testing.T) { t.Run("Authorization_Vote_Access_Control", func(t *testing.T) {
emailSender.Reset() emailSender.Reset()
user1 := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user1", "vote_auth1@example.com") firstUser := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user1", "vote_auth1@example.com")
user2 := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user2", "vote_auth2@example.com") secondUser := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user2", "vote_auth2@example.com")
post := testutils.CreatePostWithRepo(t, postRepo, user1.User.ID, "Vote Auth Post", "https://example.com/vote-auth") post := testutils.CreatePostWithRepo(t, postRepo, firstUser.User.ID, "Vote Auth Post", "https://example.com/vote-auth")
voteData := map[string]string{"type": "up"} voteResponse := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteBody, _ := json.Marshal(voteData) if voteResponse.Code != http.StatusOK {
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBody)) t.Errorf("Users should be able to vote on any post, got %d", voteResponse.Code)
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user2.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user2.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteResp := httptest.NewRecorder()
voteHandler.CastVote(voteResp, voteReq)
if voteResp.Code != http.StatusOK {
t.Errorf("Users should be able to vote on any post, got %d", voteResp.Code)
} }
}) })
@@ -664,12 +446,8 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Failed to generate expired token: %v", err) t.Fatalf("Failed to generate expired token: %v", err)
} }
meReq := httptest.NewRequest("GET", "/api/auth/me", nil) meResponse := makeRequest(t, ctx.Router, "GET", "/api/auth/me", nil, map[string]string{"Authorization": "Bearer " + expiredToken})
meReq.Header.Set("Authorization", "Bearer "+expiredToken) testutils.AssertHTTPStatus(t, meResponse, http.StatusUnauthorized)
meResp := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
testutils.AssertHTTPStatus(t, meResp, http.StatusUnauthorized)
}) })
t.Run("Authorization_Token_Tampering", func(t *testing.T) { t.Run("Authorization_Token_Tampering", func(t *testing.T) {
@@ -678,12 +456,12 @@ func TestIntegration_Handlers(t *testing.T) {
tamperedToken := user.Token[:len(user.Token)-5] + "XXXXX" tamperedToken := user.Token[:len(user.Token)-5] + "XXXXX"
meReq := httptest.NewRequest("GET", "/api/auth/me", nil) meRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
meReq.Header.Set("Authorization", "Bearer "+tamperedToken) meRequest.Header.Set("Authorization", "Bearer "+tamperedToken)
meResp := httptest.NewRecorder() meResponse := httptest.NewRecorder()
authHandler.Me(meResp, meReq) ctx.Router.ServeHTTP(meResponse, meRequest)
testutils.AssertHTTPStatus(t, meResp, http.StatusUnauthorized) testutils.AssertHTTPStatus(t, meResponse, http.StatusUnauthorized)
}) })
t.Run("Authorization_Session_Version_Mismatch", func(t *testing.T) { t.Run("Authorization_Session_Version_Mismatch", func(t *testing.T) {
@@ -711,12 +489,12 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Failed to generate invalid token: %v", err) t.Fatalf("Failed to generate invalid token: %v", err)
} }
meReq := httptest.NewRequest("GET", "/api/auth/me", nil) meRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
meReq.Header.Set("Authorization", "Bearer "+invalidToken) meRequest.Header.Set("Authorization", "Bearer "+invalidToken)
meResp := httptest.NewRecorder() meResponse := httptest.NewRecorder()
authHandler.Me(meResp, meReq) ctx.Router.ServeHTTP(meResponse, meRequest)
testutils.AssertHTTPStatus(t, meResp, http.StatusUnauthorized) testutils.AssertHTTPStatus(t, meResponse, http.StatusUnauthorized)
}) })
} }
@@ -748,11 +526,9 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
} }
voteService := services.NewVoteService(voteRepo, postRepo, db) voteService := services.NewVoteService(voteRepo, postRepo, db)
apiHandler := handlers.NewAPIHandlerWithMonitoring(testutils.AppTestConfig, postRepo, userRepo, voteService, db, monitor) apiHandler := handlers.NewAPIHandlerWithMonitoring(testutils.AppTestConfig, postRepo, userRepo, voteService, db, monitor)
t.Run("Health endpoint includes database monitoring", func(t *testing.T) { t.Run("Health endpoint includes database monitoring", func(t *testing.T) {
user := &database.User{ user := &database.User{
Username: "monitoring_user", Username: "monitoring_user",
Email: "monitoring@example.com", Email: "monitoring@example.com",
@@ -765,7 +541,6 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
apiHandler.GetHealth(recorder, request) apiHandler.GetHealth(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK) testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
var response map[string]any var response map[string]any
@@ -777,7 +552,7 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
t.Error("Expected success to be true") t.Error("Expected success to be true")
} }
data, ok := response["data"].(map[string]any) data, ok := getDataFromResponse(response)
if !ok { if !ok {
t.Fatal("Expected data to be a map") t.Fatal("Expected data to be a map")
} }
@@ -813,7 +588,6 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
apiHandler.GetMetrics(recorder, request) apiHandler.GetMetrics(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK) testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
var response map[string]any var response map[string]any
@@ -825,7 +599,7 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
t.Error("Expected success to be true") t.Error("Expected success to be true")
} }
data, ok := response["data"].(map[string]any) data, ok := getDataFromResponse(response)
if !ok { if !ok {
t.Fatal("Expected data to be a map") t.Fatal("Expected data to be a map")
} }

View File

@@ -1,6 +1,7 @@
package integration package integration
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -282,14 +283,14 @@ func setupPageHandlerTestContext(t *testing.T) *testContext {
func getCSRFToken(t *testing.T, router http.Handler, path string, cookies ...*http.Cookie) string { func getCSRFToken(t *testing.T, router http.Handler, path string, cookies ...*http.Cookie) string {
t.Helper() t.Helper()
req := httptest.NewRequest("GET", path, nil) request := httptest.NewRequest("GET", path, nil)
for _, cookie := range cookies { for _, cookie := range cookies {
req.AddCookie(cookie) request.AddCookie(cookie)
} }
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
cookieList := rec.Result().Cookies() cookieList := recorder.Result().Cookies()
for _, cookie := range cookieList { for _, cookie := range cookieList {
if cookie.Name == "csrf_token" { if cookie.Name == "csrf_token" {
return cookie.Value return cookie.Value
@@ -299,32 +300,32 @@ func getCSRFToken(t *testing.T, router http.Handler, path string, cookies ...*ht
return "" return ""
} }
func assertJSONResponse(t *testing.T, rec *httptest.ResponseRecorder, expectedStatus int) map[string]any { func assertJSONResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int) map[string]any {
t.Helper() t.Helper()
if rec.Code != expectedStatus { if recorder.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, rec.Code, rec.Body.String()) t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, recorder.Code, recorder.Body.String())
return nil return nil
} }
var response map[string]any var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err != nil { if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode response: %v. Body: %s", err, rec.Body.String()) t.Fatalf("Failed to decode response: %v. Body: %s", err, recorder.Body.String())
return nil return nil
} }
return response return response
} }
func assertErrorResponse(t *testing.T, rec *httptest.ResponseRecorder, expectedStatus int) { func assertErrorResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int) {
t.Helper() t.Helper()
if rec.Code != expectedStatus { if recorder.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, rec.Code, rec.Body.String()) t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, recorder.Code, recorder.Body.String())
return return
} }
var response map[string]any var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err != nil { if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode error response: %v. Body: %s", err, rec.Body.String()) t.Fatalf("Failed to decode error response: %v. Body: %s", err, recorder.Body.String())
return return
} }
@@ -335,23 +336,23 @@ func assertErrorResponse(t *testing.T, rec *httptest.ResponseRecorder, expectedS
} }
} }
func assertStatus(t *testing.T, rec *httptest.ResponseRecorder, expectedStatus int) { func assertStatus(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int) {
t.Helper() t.Helper()
if rec.Code != expectedStatus { if recorder.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, rec.Code, rec.Body.String()) t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, recorder.Code, recorder.Body.String())
} }
} }
func assertStatusRange(t *testing.T, rec *httptest.ResponseRecorder, minStatus, maxStatus int) { func assertStatusRange(t *testing.T, recorder *httptest.ResponseRecorder, minStatus, maxStatus int) {
t.Helper() t.Helper()
if rec.Code < minStatus || rec.Code > maxStatus { if recorder.Code < minStatus || recorder.Code > maxStatus {
t.Errorf("Expected status between %d and %d, got %d. Body: %s", minStatus, maxStatus, rec.Code, rec.Body.String()) t.Errorf("Expected status between %d and %d, got %d. Body: %s", minStatus, maxStatus, recorder.Code, recorder.Body.String())
} }
} }
func assertCookie(t *testing.T, rec *httptest.ResponseRecorder, name, expectedValue string) { func assertCookie(t *testing.T, recorder *httptest.ResponseRecorder, name, expectedValue string) {
t.Helper() t.Helper()
cookies := rec.Result().Cookies() cookies := recorder.Result().Cookies()
for _, cookie := range cookies { for _, cookie := range cookies {
if cookie.Name == name { if cookie.Name == name {
if expectedValue != "" && cookie.Value != expectedValue { if expectedValue != "" && cookie.Value != expectedValue {
@@ -363,9 +364,9 @@ func assertCookie(t *testing.T, rec *httptest.ResponseRecorder, name, expectedVa
t.Errorf("Expected cookie %s not found", name) t.Errorf("Expected cookie %s not found", name)
} }
func assertCookieCleared(t *testing.T, rec *httptest.ResponseRecorder, name string) { func assertCookieCleared(t *testing.T, recorder *httptest.ResponseRecorder, name string) {
t.Helper() t.Helper()
cookies := rec.Result().Cookies() cookies := recorder.Result().Cookies()
for _, cookie := range cookies { for _, cookie := range cookies {
if cookie.Name == name { if cookie.Name == name {
if cookie.Value != "" { if cookie.Value != "" {
@@ -376,21 +377,17 @@ func assertCookieCleared(t *testing.T, rec *httptest.ResponseRecorder, name stri
} }
} }
func assertHeader(t *testing.T, rec *httptest.ResponseRecorder, name, expectedValue string) { func assertHeader(t *testing.T, recorder *httptest.ResponseRecorder, name string) {
t.Helper() t.Helper()
actualValue := rec.Header().Get(name) actualValue := recorder.Header().Get(name)
if expectedValue == "" {
if actualValue == "" { if actualValue == "" {
t.Errorf("Expected header %s to be present", name) t.Errorf("Expected header %s to be present", name)
} }
} else if actualValue != expectedValue {
t.Errorf("Expected header %s=%s, got %s", name, expectedValue, actualValue)
}
} }
func assertHeaderContains(t *testing.T, rec *httptest.ResponseRecorder, name, substring string) { func assertHeaderContains(t *testing.T, recorder *httptest.ResponseRecorder, name, substring string) {
t.Helper() t.Helper()
actualValue := rec.Header().Get(name) actualValue := recorder.Header().Get(name)
if !strings.Contains(actualValue, substring) { if !strings.Contains(actualValue, substring) {
t.Errorf("Expected header %s to contain %s, got %s", name, substring, actualValue) t.Errorf("Expected header %s to contain %s, got %s", name, substring, actualValue)
} }
@@ -450,3 +447,83 @@ func createUserWithCleanup(t *testing.T, ctx *testContext, username, email strin
}) })
return user return user
} }
func makeRequest(t *testing.T, router http.Handler, method, path string, body []byte, headers map[string]string) *httptest.ResponseRecorder {
t.Helper()
var requestBody *bytes.Buffer
if body != nil {
requestBody = bytes.NewBuffer(body)
} else {
requestBody = bytes.NewBuffer(nil)
}
request := httptest.NewRequest(method, path, requestBody)
for key, value := range headers {
request.Header.Set(key, value)
}
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
return recorder
}
func makeAuthenticatedRequest(t *testing.T, router http.Handler, method, path string, body []byte, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
var requestBody *bytes.Buffer
if body != nil {
requestBody = bytes.NewBuffer(body)
} else {
requestBody = bytes.NewBuffer(nil)
}
request := httptest.NewRequest(method, path, requestBody)
request.Header.Set("Authorization", "Bearer "+user.Token)
if body != nil {
request.Header.Set("Content-Type", "application/json")
}
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
if urlParams != nil {
request = testutils.WithURLParams(request, urlParams)
}
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
return recorder
}
func makeGetRequest(t *testing.T, router http.Handler, path string) *httptest.ResponseRecorder {
t.Helper()
return makeRequest(t, router, "GET", path, nil, nil)
}
func makeAuthenticatedGetRequest(t *testing.T, router http.Handler, path string, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
return makeAuthenticatedRequest(t, router, "GET", path, nil, user, urlParams)
}
func makePostRequest(t *testing.T, router http.Handler, path string, body map[string]any, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
bodyBytes, _ := json.Marshal(body)
return makeAuthenticatedRequest(t, router, "POST", path, bodyBytes, user, urlParams)
}
func makePutRequest(t *testing.T, router http.Handler, path string, body map[string]any, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
bodyBytes, _ := json.Marshal(body)
return makeAuthenticatedRequest(t, router, "PUT", path, bodyBytes, user, urlParams)
}
func makeDeleteRequest(t *testing.T, router http.Handler, path string, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
return makeAuthenticatedRequest(t, router, "DELETE", path, nil, user, urlParams)
}
func makePostRequestWithJSON(t *testing.T, router http.Handler, path string, body map[string]any) *httptest.ResponseRecorder {
t.Helper()
bodyBytes, _ := json.Marshal(body)
return makeRequest(t, router, "POST", path, bodyBytes, map[string]string{"Content-Type": "application/json"})
}
func getDataFromResponse(response map[string]any) (map[string]any, bool) {
if response == nil {
return nil, false
}
data, ok := response["data"].(map[string]any)
return data, ok
}

View File

@@ -22,26 +22,26 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, authService, ctx.Suite.UserRepo, "settings_email_user", "settings_email@example.com") user := createAuthenticatedUser(t, authService, ctx.Suite.UserRepo, "settings_email_user", "settings_email@example.com")
getReq := httptest.NewRequest("GET", "/settings", nil) getRequest := httptest.NewRequest("GET", "/settings", nil)
getReq.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) getRequest.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
getRec := httptest.NewRecorder() getRecorder := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq) router.ServeHTTP(getRecorder, getRequest)
csrfToken := getCSRFToken(t, router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token}) csrfToken := getCSRFToken(t, router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("email", "newemail@example.com") requestBody.Set("email", "newemail@example.com")
reqBody.Set("csrf_token", csrfToken) requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/settings/email", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/settings/email", strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
}) })
t.Run("Settings_Username_Update_Form", func(t *testing.T) { t.Run("Settings_Username_Update_Form", func(t *testing.T) {
@@ -51,19 +51,19 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token}) csrfToken := getCSRFToken(t, router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("username", "new_username") requestBody.Set("username", "new_username")
reqBody.Set("csrf_token", csrfToken) requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/settings/username", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/settings/username", strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
}) })
t.Run("Settings_Password_Update_Form", func(t *testing.T) { t.Run("Settings_Password_Update_Form", func(t *testing.T) {
@@ -74,20 +74,20 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token}) csrfToken := getCSRFToken(t, freshCtx.Router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("current_password", "SecurePass123!") requestBody.Set("current_password", "SecurePass123!")
reqBody.Set("new_password", "NewSecurePass123!") requestBody.Set("new_password", "NewSecurePass123!")
reqBody.Set("csrf_token", csrfToken) requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/settings/password", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/settings/password", strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req) freshCtx.Router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
}) })
t.Run("Logout_Page_Handler", func(t *testing.T) { t.Run("Logout_Page_Handler", func(t *testing.T) {
@@ -98,19 +98,19 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token}) csrfToken := getCSRFToken(t, freshCtx.Router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("csrf_token", csrfToken) requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/logout", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/logout", strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req) freshCtx.Router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusSeeOther) assertStatus(t, recorder, http.StatusSeeOther)
assertCookieCleared(t, rec, "auth_token") assertCookieCleared(t, recorder, "auth_token")
}) })
t.Run("Resend_Verification_Page_Handler", func(t *testing.T) { t.Run("Resend_Verification_Page_Handler", func(t *testing.T) {
@@ -120,18 +120,18 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/resend-verification") csrfToken := getCSRFToken(t, freshCtx.Router, "/resend-verification")
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("email", "resend_page@example.com") requestBody.Set("email", "resend_page@example.com")
reqBody.Set("csrf_token", csrfToken) requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/resend-verification", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/resend-verification", strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req) freshCtx.Router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
}) })
t.Run("Post_Vote_Page_Handler", func(t *testing.T) { t.Run("Post_Vote_Page_Handler", func(t *testing.T) {
@@ -142,26 +142,26 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
post := testutils.CreatePostWithRepo(t, freshCtx.Suite.PostRepo, user.User.ID, "Vote Page Test", "https://example.com/vote-page") post := testutils.CreatePostWithRepo(t, freshCtx.Suite.PostRepo, user.User.ID, "Vote Page Test", "https://example.com/vote-page")
getReq := httptest.NewRequest("GET", fmt.Sprintf("/posts/%d", post.ID), nil) getRequest := httptest.NewRequest("GET", fmt.Sprintf("/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder() getRecorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(getRec, getReq) freshCtx.Router.ServeHTTP(getRecorder, getRequest)
csrfToken := getCSRFToken(t, freshCtx.Router, fmt.Sprintf("/posts/%d", post.ID)) csrfToken := getCSRFToken(t, freshCtx.Router, fmt.Sprintf("/posts/%d", post.ID))
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("action", "up") requestBody.Set("action", "up")
reqBody.Set("csrf_token", csrfToken) requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", fmt.Sprintf("/posts/%d/vote", post.ID), strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", fmt.Sprintf("/posts/%d/vote", post.ID), strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req) freshCtx.Router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
}) })
t.Run("Login_Page_Handler_Workflow", func(t *testing.T) { t.Run("Login_Page_Handler_Workflow", func(t *testing.T) {
@@ -172,20 +172,20 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/login") csrfToken := getCSRFToken(t, freshCtx.Router, "/login")
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("username", "login_page_user") requestBody.Set("username", "login_page_user")
reqBody.Set("password", "SecurePass123!") requestBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", csrfToken) requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/login", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/login", strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req) freshCtx.Router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusSeeOther) assertStatus(t, recorder, http.StatusSeeOther)
assertCookie(t, rec, "auth_token", "") assertCookie(t, recorder, "auth_token", "")
}) })
t.Run("Email_Confirmation_Page_Handler", func(t *testing.T) { t.Run("Email_Confirmation_Page_Handler", func(t *testing.T) {
@@ -198,11 +198,11 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
token = "test-token" token = "test-token"
} }
req := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(token), nil) request := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(token), nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
}) })
} }

View File

@@ -16,63 +16,62 @@ func TestIntegration_PageHandler(t *testing.T) {
router := ctx.Router router := ctx.Router
t.Run("Home_Page_Renders", func(t *testing.T) { t.Run("Home_Page_Renders", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil) request := httptest.NewRequest("GET", "/", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK) assertStatus(t, recorder, http.StatusOK)
if !strings.Contains(rec.Body.String(), "<html") { if !strings.Contains(recorder.Body.String(), "<html") {
t.Error("Expected HTML content") t.Error("Expected HTML content")
} }
}) })
t.Run("Login_Form_Renders", func(t *testing.T) { t.Run("Login_Form_Renders", func(t *testing.T) {
req := httptest.NewRequest("GET", "/login", nil) request := httptest.NewRequest("GET", "/login", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK) assertStatus(t, recorder, http.StatusOK)
body := rec.Body.String() body := recorder.Body.String()
if !strings.Contains(body, "login") && !strings.Contains(body, "Login") { if !strings.Contains(body, "login") && !strings.Contains(body, "Login") {
t.Error("Expected login form content") t.Error("Expected login form content")
} }
}) })
t.Run("Register_Form_Renders", func(t *testing.T) { t.Run("Register_Form_Renders", func(t *testing.T) {
req := httptest.NewRequest("GET", "/register", nil) request := httptest.NewRequest("GET", "/register", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, recorder, http.StatusOK)
assertStatus(t, rec, http.StatusOK) body := recorder.Body.String()
body := rec.Body.String()
if !strings.Contains(body, "register") && !strings.Contains(body, "Register") { if !strings.Contains(body, "register") && !strings.Contains(body, "Register") {
t.Error("Expected register form content") t.Error("Expected register form content")
} }
}) })
t.Run("PageHandler_With_CSRF_Token", func(t *testing.T) { t.Run("PageHandler_With_CSRF_Token", func(t *testing.T) {
req := httptest.NewRequest("GET", "/register", nil) request := httptest.NewRequest("GET", "/register", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertCookie(t, rec, "csrf_token", "") assertCookie(t, recorder, "csrf_token", "")
}) })
t.Run("PageHandler_Form_Submission", func(t *testing.T) { t.Run("PageHandler_Form_Submission", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
getReq := httptest.NewRequest("GET", "/register", nil) getRequest := httptest.NewRequest("GET", "/register", nil)
getRec := httptest.NewRecorder() getRecorder := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq) router.ServeHTTP(getRecorder, getRequest)
cookies := getRec.Result().Cookies() cookies := getRecorder.Result().Cookies()
var csrfCookie *http.Cookie var csrfCookie *http.Cookie
for _, cookie := range cookies { for _, cookie := range cookies {
if cookie.Name == "csrf_token" { if cookie.Name == "csrf_token" {
@@ -85,33 +84,33 @@ func TestIntegration_PageHandler(t *testing.T) {
t.Fatal("Expected CSRF cookie") t.Fatal("Expected CSRF cookie")
} }
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("username", "page_form_user") requestBody.Set("username", "page_form_user")
reqBody.Set("email", "page_form@example.com") requestBody.Set("email", "page_form@example.com")
reqBody.Set("password", "SecurePass123!") requestBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", csrfCookie.Value) requestBody.Set("csrf_token", csrfCookie.Value)
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(csrfCookie) request.AddCookie(csrfCookie)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
}) })
t.Run("PageHandler_Authenticated_Access", func(t *testing.T) { t.Run("PageHandler_Authenticated_Access", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "page_auth_user", "page_auth@example.com") user := createUserWithCleanup(t, ctx, "page_auth_user", "page_auth@example.com")
req := httptest.NewRequest("GET", "/settings", nil) request := httptest.NewRequest("GET", "/settings", nil)
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK) assertStatus(t, recorder, http.StatusOK)
}) })
t.Run("PageHandler_Post_Display", func(t *testing.T) { t.Run("PageHandler_Post_Display", func(t *testing.T) {
@@ -120,34 +119,34 @@ func TestIntegration_PageHandler(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Page Test Post", "https://example.com/page-test") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Page Test Post", "https://example.com/page-test")
req := httptest.NewRequest("GET", "/posts/"+fmt.Sprintf("%d", post.ID), nil) request := httptest.NewRequest("GET", "/posts/"+fmt.Sprintf("%d", post.ID), nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK) assertStatus(t, recorder, http.StatusOK)
body := rec.Body.String() body := recorder.Body.String()
if !strings.Contains(body, "Page Test Post") { if !strings.Contains(body, "Page Test Post") {
t.Error("Expected post title in page") t.Error("Expected post title in page")
} }
}) })
t.Run("PageHandler_Search_Page", func(t *testing.T) { t.Run("PageHandler_Search_Page", func(t *testing.T) {
req := httptest.NewRequest("GET", "/search?q=test", nil) request := httptest.NewRequest("GET", "/search?q=test", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK) assertStatus(t, recorder, http.StatusOK)
}) })
t.Run("PageHandler_Error_Handling", func(t *testing.T) { t.Run("PageHandler_Error_Handling", func(t *testing.T) {
req := httptest.NewRequest("GET", "/nonexistent", nil) request := httptest.NewRequest("GET", "/nonexistent", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusNotFound) assertStatus(t, recorder, http.StatusNotFound)
}) })
} }

View File

@@ -1,8 +1,6 @@
package integration package integration
import ( import (
"bytes"
"encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@@ -32,17 +30,12 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatalf("Failed to create user: %v", err) t.Fatalf("Failed to create user: %v", err)
} }
reqBody := map[string]string{ reqBody := map[string]any{
"username_or_email": "reset_user", "username_or_email": "reset_user",
} }
body, _ := json.Marshal(reqBody) request := makePostRequestWithJSON(t, router, "/api/auth/forgot-password", reqBody)
req := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil { if response != nil {
if success, ok := response["success"].(bool); !ok || !success { if success, ok := response["success"].(bool); !ok || !success {
t.Error("Expected success=true") t.Error("Expected success=true")
@@ -77,18 +70,13 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatal("Expected password reset token") t.Fatal("Expected password reset token")
} }
reqBody := map[string]string{ reqBody := map[string]any{
"token": resetToken, "token": resetToken,
"new_password": "NewPassword123!", "new_password": "NewPassword123!",
} }
body, _ := json.Marshal(reqBody) request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
loginResult, err := ctx.AuthService.Login("reset_complete_user", "NewPassword123!") loginResult, err := ctx.AuthService.Login("reset_complete_user", "NewPassword123!")
if err != nil { if err != nil {
@@ -120,14 +108,14 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
reqBody := url.Values{} reqBody := url.Values{}
reqBody.Set("username_or_email", "page_reset_user") reqBody.Set("username_or_email", "page_reset_user")
reqBody.Set("csrf_token", csrfToken) reqBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/forgot-password", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/forgot-password", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
pageRouter.ServeHTTP(rec, req) pageRouter.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
resetToken := pageCtx.Suite.EmailSender.PasswordResetToken() resetToken := pageCtx.Suite.EmailSender.PasswordResetToken()
if resetToken == "" { if resetToken == "" {
@@ -166,33 +154,23 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatalf("Failed to update user: %v", err) t.Fatalf("Failed to update user: %v", err)
} }
reqBody := map[string]string{ reqBody := map[string]any{
"token": resetToken, "token": resetToken,
"new_password": "NewPassword123!", "new_password": "NewPassword123!",
} }
body, _ := json.Marshal(reqBody) request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusBadRequest)
assertErrorResponse(t, rec, http.StatusBadRequest)
}) })
t.Run("PasswordReset_InvalidToken", func(t *testing.T) { t.Run("PasswordReset_InvalidToken", func(t *testing.T) {
reqBody := map[string]string{ reqBody := map[string]any{
"token": "invalid-token", "token": "invalid-token",
"new_password": "NewPassword123!", "new_password": "NewPassword123!",
} }
body, _ := json.Marshal(reqBody) request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusBadRequest)
assertErrorResponse(t, rec, http.StatusBadRequest)
}) })
t.Run("PasswordReset_WeakPassword", func(t *testing.T) { t.Run("PasswordReset_WeakPassword", func(t *testing.T) {
@@ -214,18 +192,13 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
resetToken := ctx.Suite.EmailSender.PasswordResetToken() resetToken := ctx.Suite.EmailSender.PasswordResetToken()
reqBody := map[string]string{ reqBody := map[string]any{
"token": resetToken, "token": resetToken,
"new_password": "123", "new_password": "123",
} }
body, _ := json.Marshal(reqBody) request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusBadRequest)
assertErrorResponse(t, rec, http.StatusBadRequest)
}) })
t.Run("PasswordReset_EmailIntegration", func(t *testing.T) { t.Run("PasswordReset_EmailIntegration", func(t *testing.T) {
@@ -243,17 +216,12 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatalf("Failed to create user: %v", err) t.Fatalf("Failed to create user: %v", err)
} }
reqBody := map[string]string{ reqBody := map[string]any{
"username_or_email": "email_reset@example.com", "username_or_email": "email_reset@example.com",
} }
body, _ := json.Marshal(reqBody) request := makePostRequestWithJSON(t, freshCtx.Router, "/api/auth/forgot-password", reqBody)
req := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
resetToken := freshCtx.Suite.EmailSender.PasswordResetToken() resetToken := freshCtx.Suite.EmailSender.PasswordResetToken()
if resetToken == "" { if resetToken == "" {

View File

@@ -51,24 +51,24 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig) router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`)) request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
} }
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`)) request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests) assertErrorResponse(t, recorder, http.StatusTooManyRequests)
assertHeader(t, rec, "Retry-After", "") assertHeader(t, recorder, "Retry-After")
var response map[string]any var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err == nil { if err := json.NewDecoder(recorder.Body).Decode(&response); err == nil {
if _, exists := response["retry_after"]; !exists { if _, exists := response["retry_after"]; !exists {
t.Error("Expected retry_after in response") t.Error("Expected retry_after in response")
} }
@@ -81,17 +81,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig) router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
req := httptest.NewRequest("GET", "/api/posts", nil) request := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
} }
req := httptest.NewRequest("GET", "/api/posts", nil) request := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests) assertErrorResponse(t, recorder, http.StatusTooManyRequests)
}) })
t.Run("Health_RateLimit_Enforced", func(t *testing.T) { t.Run("Health_RateLimit_Enforced", func(t *testing.T) {
@@ -100,17 +100,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig) router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
req := httptest.NewRequest("GET", "/health", nil) request := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
} }
req := httptest.NewRequest("GET", "/health", nil) request := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests) assertErrorResponse(t, recorder, http.StatusTooManyRequests)
}) })
t.Run("Metrics_RateLimit_Enforced", func(t *testing.T) { t.Run("Metrics_RateLimit_Enforced", func(t *testing.T) {
@@ -119,17 +119,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig) router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
req := httptest.NewRequest("GET", "/metrics", nil) request := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
} }
req := httptest.NewRequest("GET", "/metrics", nil) request := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests) assertErrorResponse(t, recorder, http.StatusTooManyRequests)
}) })
t.Run("RateLimit_Different_Endpoints_Independent", func(t *testing.T) { t.Run("RateLimit_Different_Endpoints_Independent", func(t *testing.T) {
@@ -139,17 +139,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig) router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`)) request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
} }
req := httptest.NewRequest("GET", "/api/posts", nil) request := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK) assertStatus(t, recorder, http.StatusOK)
}) })
t.Run("RateLimit_With_Authentication", func(t *testing.T) { t.Run("RateLimit_With_Authentication", func(t *testing.T) {
@@ -166,20 +166,20 @@ func TestIntegration_RateLimiting(t *testing.T) {
user := createAuthenticatedUser(t, authService, suite.UserRepo, uniqueTestUsername(t, "ratelimit_auth"), uniqueTestEmail(t, "ratelimit_auth")) user := createAuthenticatedUser(t, authService, suite.UserRepo, uniqueTestUsername(t, "ratelimit_auth"), uniqueTestEmail(t, "ratelimit_auth"))
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
req := httptest.NewRequest("GET", "/api/auth/me", nil) request := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
} }
req := httptest.NewRequest("GET", "/api/auth/me", nil) request := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests) assertErrorResponse(t, recorder, http.StatusTooManyRequests)
}) })
} }

View File

@@ -17,35 +17,29 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
router := ctx.Router router := ctx.Router
t.Run("SecurityHeaders_Present", func(t *testing.T) { t.Run("SecurityHeaders_Present", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil) request := makeGetRequest(t, router, "/health")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK) assertHeader(t, request, "X-Content-Type-Options")
assertHeader(t, request, "X-Frame-Options")
assertHeader(t, rec, "X-Content-Type-Options", "") assertHeader(t, request, "X-XSS-Protection")
assertHeader(t, rec, "X-Frame-Options", "")
assertHeader(t, rec, "X-XSS-Protection", "")
}) })
t.Run("CORS_Headers_Present", func(t *testing.T) { t.Run("CORS_Headers_Present", func(t *testing.T) {
req := httptest.NewRequest("OPTIONS", "/api/posts", nil) request := httptest.NewRequest("OPTIONS", "/api/posts", nil)
req.Header.Set("Origin", "http://localhost:3000") request.Header.Set("Origin", "http://localhost:3000")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertHeader(t, rec, "Access-Control-Allow-Origin", "") assertHeader(t, recorder, "Access-Control-Allow-Origin")
}) })
t.Run("Logging_Middleware_Executes", func(t *testing.T) { t.Run("Logging_Middleware_Executes", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil) request := makeGetRequest(t, router, "/health")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) if request.Code == 0 {
if rec.Code == 0 {
t.Error("Expected logging middleware to execute") t.Error("Expected logging middleware to execute")
} }
}) })
@@ -53,27 +47,24 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
t.Run("RequestSizeLimit_Enforced", func(t *testing.T) { t.Run("RequestSizeLimit_Enforced", func(t *testing.T) {
user := createUserWithCleanup(t, ctx, "size_limit_user", "size_limit@example.com") user := createUserWithCleanup(t, ctx, "size_limit_user", "size_limit@example.com")
largeBody := strings.Repeat("a", 10*1024*1024) largeBody := strings.Repeat("a", 10*1024*1024)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(largeBody)) request := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(largeBody))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Code != http.StatusRequestEntityTooLarge && rec.Code != http.StatusBadRequest { if recorder.Code != http.StatusRequestEntityTooLarge && recorder.Code != http.StatusBadRequest {
t.Errorf("Expected status 413 or 400 for oversized request, got %d. Body: %s", rec.Code, rec.Body.String()) t.Errorf("Expected status 413 or 400 for oversized request, got %d. Body: %s", recorder.Code, recorder.Body.String())
} }
}) })
t.Run("DBMonitoring_Active", func(t *testing.T) { t.Run("DBMonitoring_Active", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil) request := makeGetRequest(t, router, "/health")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
var response map[string]any var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err == nil { if err := json.NewDecoder(request.Body).Decode(&response); err == nil {
if data, ok := response["data"].(map[string]any); ok { if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["database_stats"]; !exists { if _, exists := data["database_stats"]; !exists {
t.Error("Expected database_stats in health response") t.Error("Expected database_stats in health response")
@@ -83,12 +74,9 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
}) })
t.Run("Metrics_Middleware_Executes", func(t *testing.T) { t.Run("Metrics_Middleware_Executes", func(t *testing.T) {
req := httptest.NewRequest("GET", "/metrics", nil) request := makeGetRequest(t, router, "/metrics")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil { if response != nil {
if data, ok := response["data"].(map[string]any); ok { if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["database"]; !exists { if _, exists := data["database"]; !exists {
@@ -99,34 +87,25 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
}) })
t.Run("StaticFiles_Served", func(t *testing.T) { t.Run("StaticFiles_Served", func(t *testing.T) {
req := httptest.NewRequest("GET", "/robots.txt", nil) request := makeGetRequest(t, router, "/robots.txt")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK) if !strings.Contains(request.Body.String(), "User-agent") {
if !strings.Contains(rec.Body.String(), "User-agent") {
t.Error("Expected robots.txt content") t.Error("Expected robots.txt content")
} }
}) })
t.Run("API_Routes_Accessible", func(t *testing.T) { t.Run("API_Routes_Accessible", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil) request := makeGetRequest(t, router, "/api/posts")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("Health_Endpoint_Accessible", func(t *testing.T) { t.Run("Health_Endpoint_Accessible", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil) request := makeGetRequest(t, router, "/health")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil { if response != nil {
if success, ok := response["success"].(bool); !ok || !success { if success, ok := response["success"].(bool); !ok || !success {
t.Error("Expected success=true in health response") t.Error("Expected success=true in health response")
@@ -135,40 +114,33 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
}) })
t.Run("Middleware_Order_Correct", func(t *testing.T) { t.Run("Middleware_Order_Correct", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil) request := makeGetRequest(t, router, "/api/posts")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertHeader(t, request, "X-Content-Type-Options")
assertHeader(t, rec, "X-Content-Type-Options", "") if request.Code == 0 {
if rec.Code == 0 {
t.Error("Response should have status code") t.Error("Response should have status code")
} }
}) })
t.Run("Compression_Middleware_Active", func(t *testing.T) { t.Run("Compression_Middleware_Active", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil) request := httptest.NewRequest("GET", "/api/posts", nil)
req.Header.Set("Accept-Encoding", "gzip") request.Header.Set("Accept-Encoding", "gzip")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Header().Get("Content-Encoding") == "" { if recorder.Header().Get("Content-Encoding") == "" {
t.Log("Compression may not be applied to small responses") t.Log("Compression may not be applied to small responses")
} }
}) })
t.Run("Cache_Middleware_Active", func(t *testing.T) { t.Run("Cache_Middleware_Active", func(t *testing.T) {
req1 := httptest.NewRequest("GET", "/api/posts", nil) firstRequest := makeGetRequest(t, router, "/api/posts")
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
req2 := httptest.NewRequest("GET", "/api/posts", nil) secondRequest := makeGetRequest(t, router, "/api/posts")
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
if rec1.Code != rec2.Code { if firstRequest.Code != secondRequest.Code {
t.Error("Cached responses should have same status") t.Error("Cached responses should have same status")
} }
}) })
@@ -177,35 +149,23 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "auth_middleware_user", "auth_middleware@example.com") user := createUserWithCleanup(t, ctx, "auth_middleware_user", "auth_middleware@example.com")
req := httptest.NewRequest("GET", "/api/auth/me", nil) request := makeAuthenticatedGetRequest(t, router, "/api/auth/me", user, nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("RateLimit_Middleware_Integration", func(t *testing.T) { t.Run("RateLimit_Middleware_Integration", func(t *testing.T) {
rateLimitCtx := setupTestContext(t) rateLimitCtx := setupTestContext(t)
rateLimitRouter := rateLimitCtx.Router rateLimitRouter := rateLimitCtx.Router
for i := 0; i < 3; i++ { for range 3 {
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`)) request := makePostRequestWithJSON(t, rateLimitRouter, "/api/auth/login", map[string]any{"username": "test", "password": "test"})
req.Header.Set("Content-Type", "application/json") _ = request
rec := httptest.NewRecorder()
rateLimitRouter.ServeHTTP(rec, req)
} }
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`)) request := makePostRequestWithJSON(t, rateLimitRouter, "/api/auth/login", map[string]any{"username": "test", "password": "test"})
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
rateLimitRouter.ServeHTTP(rec, req) if request.Code == http.StatusTooManyRequests {
if rec.Code == http.StatusTooManyRequests {
t.Log("Rate limiting is working") t.Log("Rate limiting is working")
} }
}) })

View File

@@ -21,60 +21,60 @@ func TestIntegration_SessionManagement(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "session_pass_user", "session_pass@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "session_pass_user", "session_pass@example.com")
req1 := httptest.NewRequest("GET", "/api/auth/me", nil) firstRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
req1.Header.Set("Authorization", "Bearer "+user.Token) firstRequest.Header.Set("Authorization", "Bearer "+user.Token)
req1 = testutils.WithUserContext(req1, middleware.UserIDKey, user.User.ID) firstRequest = testutils.WithUserContext(firstRequest, middleware.UserIDKey, user.User.ID)
rec1 := httptest.NewRecorder() firstRecorder := httptest.NewRecorder()
router.ServeHTTP(rec1, req1) router.ServeHTTP(firstRecorder, firstRequest)
assertStatus(t, rec1, http.StatusOK) assertStatus(t, firstRecorder, http.StatusOK)
reqBody := map[string]string{ requestBody := map[string]string{
"current_password": "SecurePass123!", "current_password": "SecurePass123!",
"new_password": "NewSecurePass123!", "new_password": "NewSecurePass123!",
} }
body, _ := json.Marshal(reqBody) body, _ := json.Marshal(requestBody)
req2 := httptest.NewRequest("PUT", "/api/auth/password", bytes.NewBuffer(body)) secondRequest := httptest.NewRequest("PUT", "/api/auth/password", bytes.NewBuffer(body))
req2.Header.Set("Content-Type", "application/json") secondRequest.Header.Set("Content-Type", "application/json")
req2.Header.Set("Authorization", "Bearer "+user.Token) secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID) secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder() secondRecorder := httptest.NewRecorder()
router.ServeHTTP(rec2, req2) router.ServeHTTP(secondRecorder, secondRequest)
assertStatus(t, rec2, http.StatusOK) assertStatus(t, secondRecorder, http.StatusOK)
req3 := httptest.NewRequest("GET", "/api/auth/me", nil) thirdRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
req3.Header.Set("Authorization", "Bearer "+user.Token) thirdRequest.Header.Set("Authorization", "Bearer "+user.Token)
req3 = testutils.WithUserContext(req3, middleware.UserIDKey, user.User.ID) thirdRequest = testutils.WithUserContext(thirdRequest, middleware.UserIDKey, user.User.ID)
rec3 := httptest.NewRecorder() thirdRecorder := httptest.NewRecorder()
router.ServeHTTP(rec3, req3) router.ServeHTTP(thirdRecorder, thirdRequest)
assertErrorResponse(t, rec3, http.StatusUnauthorized) assertErrorResponse(t, thirdRecorder, http.StatusUnauthorized)
}) })
t.Run("Session_Invalidation_On_Account_Lock", func(t *testing.T) { t.Run("Session_Invalidation_On_Account_Lock", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "session_lock_user", "session_lock@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "session_lock_user", "session_lock@example.com")
req1 := httptest.NewRequest("GET", "/api/auth/me", nil) firstRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
req1.Header.Set("Authorization", "Bearer "+user.Token) firstRequest.Header.Set("Authorization", "Bearer "+user.Token)
req1 = testutils.WithUserContext(req1, middleware.UserIDKey, user.User.ID) firstRequest = testutils.WithUserContext(firstRequest, middleware.UserIDKey, user.User.ID)
rec1 := httptest.NewRecorder() firstRecorder := httptest.NewRecorder()
router.ServeHTTP(rec1, req1) router.ServeHTTP(firstRecorder, firstRequest)
assertStatus(t, rec1, http.StatusOK) assertStatus(t, firstRecorder, http.StatusOK)
if err := ctx.Suite.UserRepo.Lock(user.User.ID); err != nil { if err := ctx.Suite.UserRepo.Lock(user.User.ID); err != nil {
t.Fatalf("Failed to lock user: %v", err) t.Fatalf("Failed to lock user: %v", err)
} }
req2 := httptest.NewRequest("GET", "/api/auth/me", nil) secondRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
req2.Header.Set("Authorization", "Bearer "+user.Token) secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID) secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder() secondRecorder := httptest.NewRecorder()
router.ServeHTTP(rec2, req2) router.ServeHTTP(secondRecorder, secondRequest)
assertErrorResponse(t, rec2, http.StatusUnauthorized) assertErrorResponse(t, secondRecorder, http.StatusUnauthorized)
}) })
t.Run("Refresh_Token_Revocation", func(t *testing.T) { t.Run("Refresh_Token_Revocation", func(t *testing.T) {
@@ -90,48 +90,48 @@ func TestIntegration_SessionManagement(t *testing.T) {
t.Fatal("Expected refresh token") t.Fatal("Expected refresh token")
} }
reqBody := map[string]string{ requestBody := map[string]string{
"refresh_token": loginResult.RefreshToken, "refresh_token": loginResult.RefreshToken,
} }
body, _ := json.Marshal(reqBody) body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body)) request := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK) assertStatus(t, recorder, http.StatusOK)
if err := ctx.AuthService.RevokeRefreshToken(loginResult.RefreshToken); err != nil { if err := ctx.AuthService.RevokeRefreshToken(loginResult.RefreshToken); err != nil {
t.Fatalf("Failed to revoke token: %v", err) t.Fatalf("Failed to revoke token: %v", err)
} }
req2 := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body)) secondRequest := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
req2.Header.Set("Content-Type", "application/json") secondRequest.Header.Set("Content-Type", "application/json")
rec2 := httptest.NewRecorder() secondRecorder := httptest.NewRecorder()
router.ServeHTTP(rec2, req2) router.ServeHTTP(secondRecorder, secondRequest)
assertErrorResponse(t, rec2, http.StatusUnauthorized) assertErrorResponse(t, secondRecorder, http.StatusUnauthorized)
}) })
t.Run("Multiple_Sessions_Independent", func(t *testing.T) { t.Run("Multiple_Sessions_Independent", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user1 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user1", "multi_session1@example.com") firstUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user1", "multi_session1@example.com")
user2 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user2", "multi_session2@example.com") secondUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user2", "multi_session2@example.com")
req1 := httptest.NewRequest("GET", "/api/auth/me", nil) firstRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
req1.Header.Set("Authorization", "Bearer "+user1.Token) firstRequest.Header.Set("Authorization", "Bearer "+firstUser.Token)
req1 = testutils.WithUserContext(req1, middleware.UserIDKey, user1.User.ID) firstRequest = testutils.WithUserContext(firstRequest, middleware.UserIDKey, firstUser.User.ID)
rec1 := httptest.NewRecorder() firstRecorder := httptest.NewRecorder()
router.ServeHTTP(rec1, req1) router.ServeHTTP(firstRecorder, firstRequest)
req2 := httptest.NewRequest("GET", "/api/auth/me", nil) secondRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
req2.Header.Set("Authorization", "Bearer "+user2.Token) secondRequest.Header.Set("Authorization", "Bearer "+secondUser.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user2.User.ID) secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, secondUser.User.ID)
rec2 := httptest.NewRecorder() secondRecorder := httptest.NewRecorder()
router.ServeHTTP(rec2, req2) router.ServeHTTP(secondRecorder, secondRequest)
assertStatus(t, rec1, http.StatusOK) assertStatus(t, firstRecorder, http.StatusOK)
assertStatus(t, rec2, http.StatusOK) assertStatus(t, secondRecorder, http.StatusOK)
}) })
} }
@@ -144,17 +144,17 @@ func TestIntegration_AccountDeletion(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "del_flow_user", "del_flow@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "del_flow_user", "del_flow@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Test Post", "https://example.com") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Test Post", "https://example.com")
reqBody := map[string]string{} requestBody := map[string]string{}
body, _ := json.Marshal(reqBody) body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body)) request := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
response := assertJSONResponse(t, rec, http.StatusOK) response := assertJSONResponse(t, recorder, http.StatusOK)
if response == nil { if response == nil {
return return
} }
@@ -171,13 +171,13 @@ func TestIntegration_AccountDeletion(t *testing.T) {
"token": deletionToken, "token": deletionToken,
} }
confirmBodyBytes, _ := json.Marshal(confirmBody) confirmBodyBytes, _ := json.Marshal(confirmBody)
confirmReq := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes)) confirmRequest := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes))
confirmReq.Header.Set("Content-Type", "application/json") confirmRequest.Header.Set("Content-Type", "application/json")
confirmRec := httptest.NewRecorder() confirmRecorder := httptest.NewRecorder()
router.ServeHTTP(confirmRec, confirmReq) router.ServeHTTP(confirmRecorder, confirmRequest)
confirmResponse := assertJSONResponse(t, confirmRec, http.StatusOK) confirmResponse := assertJSONResponse(t, confirmRecorder, http.StatusOK)
if confirmResponse == nil { if confirmResponse == nil {
return return
} }
@@ -209,17 +209,17 @@ func TestIntegration_AccountDeletion(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "del_posts_user", "del_posts@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "del_posts_user", "del_posts@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Deletion Post", "https://example.com/deletion") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Deletion Post", "https://example.com/deletion")
reqBody := map[string]string{} requestBody := map[string]string{}
body, _ := json.Marshal(reqBody) body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body)) request := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
response := assertJSONResponse(t, rec, http.StatusOK) response := assertJSONResponse(t, recorder, http.StatusOK)
if response == nil { if response == nil {
return return
} }
@@ -237,13 +237,13 @@ func TestIntegration_AccountDeletion(t *testing.T) {
"delete_posts": true, "delete_posts": true,
} }
confirmBodyBytes, _ := json.Marshal(confirmBody) confirmBodyBytes, _ := json.Marshal(confirmBody)
confirmReq := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes)) confirmRequest := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes))
confirmReq.Header.Set("Content-Type", "application/json") confirmRequest.Header.Set("Content-Type", "application/json")
confirmRec := httptest.NewRecorder() confirmRecorder := httptest.NewRecorder()
router.ServeHTTP(confirmRec, confirmReq) router.ServeHTTP(confirmRecorder, confirmRequest)
confirmResponse := assertJSONResponse(t, confirmRec, http.StatusOK) confirmResponse := assertJSONResponse(t, confirmRecorder, http.StatusOK)
if confirmResponse == nil { if confirmResponse == nil {
return return
} }
@@ -275,12 +275,12 @@ func TestIntegration_MetricsCollection(t *testing.T) {
router := ctx.Router router := ctx.Router
t.Run("Metrics_Endpoint_Returns_Data", func(t *testing.T) { t.Run("Metrics_Endpoint_Returns_Data", func(t *testing.T) {
req := httptest.NewRequest("GET", "/metrics", nil) request := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
response := assertJSONResponse(t, rec, http.StatusOK) response := assertJSONResponse(t, recorder, http.StatusOK)
if response != nil { if response != nil {
if data, ok := response["data"].(map[string]any); ok { if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["database"]; !exists { if _, exists := data["database"]; !exists {
@@ -294,13 +294,13 @@ func TestIntegration_MetricsCollection(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "metrics_user", "metrics@example.com") createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "metrics_user", "metrics@example.com")
req := httptest.NewRequest("GET", "/metrics", nil) request := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
var response map[string]any var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err == nil { if err := json.NewDecoder(recorder.Body).Decode(&response); err == nil {
if data, ok := response["data"].(map[string]any); ok { if data, ok := response["data"].(map[string]any); ok {
if dbData, exists := data["database"].(map[string]any); exists { if dbData, exists := data["database"].(map[string]any); exists {
if _, hasQueries := dbData["total_queries"]; !hasQueries { if _, hasQueries := dbData["total_queries"]; !hasQueries {
@@ -323,7 +323,7 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
errors := make(chan error, 10) errors := make(chan error, 10)
for i := 0; i < 10; i++ { for idx := range 10 {
wg.Add(1) wg.Add(1)
go func(index int) { go func(index int) {
defer wg.Done() defer wg.Done()
@@ -334,18 +334,18 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
"content": "Concurrent test content", "content": "Concurrent test content",
} }
body, _ := json.Marshal(postBody) body, _ := json.Marshal(postBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Code != http.StatusCreated { if recorder.Code != http.StatusCreated {
errors <- fmt.Errorf("Post %d failed with status %d", index, rec.Code) errors <- fmt.Errorf("Post %d failed with status %d", index, recorder.Code)
} }
}(i) }(idx)
} }
wg.Wait() wg.Wait()
@@ -370,28 +370,26 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
errors := make(chan error, 5) errors := make(chan error, 5)
for i := 0; i < 5; i++ { for range 5 {
wg.Add(1) wg.Go(func() {
go func() {
defer wg.Done()
voteBody := map[string]string{ voteBody := map[string]string{
"type": "up", "type": "up",
} }
body, _ := json.Marshal(voteBody) body, _ := json.Marshal(voteBody)
req := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body)) request := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Code != http.StatusOK { if recorder.Code != http.StatusOK {
errors <- fmt.Errorf("Vote failed with status %d", rec.Code) errors <- fmt.Errorf("Vote failed with status %d", recorder.Code)
} }
}() })
} }
wg.Wait() wg.Wait()
@@ -411,20 +409,18 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
errors := make(chan error, 20) errors := make(chan error, 20)
for i := 0; i < 20; i++ { for range 20 {
wg.Add(1) wg.Go(func() {
go func() {
defer wg.Done()
req := httptest.NewRequest("GET", "/api/posts", nil) request := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Code != http.StatusOK { if recorder.Code != http.StatusOK {
errors <- fmt.Errorf("Read failed with status %d", rec.Code) errors <- fmt.Errorf("Read failed with status %d", recorder.Code)
} }
}() })
} }
wg.Wait() wg.Wait()

View File

@@ -38,6 +38,10 @@ func CompressionMiddlewareWithConfig(config *CompressionConfig) func(http.Handle
next.ServeHTTP(bufferedWriter, r) next.ServeHTTP(bufferedWriter, r)
if bufferedWriter.isRedirect {
return
}
if buf.Len() < config.MinSize { if buf.Len() < config.MinSize {
bufferedWriter.flush() bufferedWriter.flush()
w.Write(buf.Bytes()) w.Write(buf.Bytes())
@@ -73,9 +77,13 @@ type bufferedResponseWriter struct {
buffer *bytes.Buffer buffer *bytes.Buffer
statusCode int statusCode int
headerWritten bool headerWritten bool
isRedirect bool
} }
func (brw *bufferedResponseWriter) Write(b []byte) (int, error) { func (brw *bufferedResponseWriter) Write(b []byte) (int, error) {
if brw.isRedirect {
return brw.ResponseWriter.Write(b)
}
if !brw.headerWritten { if !brw.headerWritten {
brw.statusCode = http.StatusOK brw.statusCode = http.StatusOK
} }
@@ -87,6 +95,11 @@ func (brw *bufferedResponseWriter) WriteHeader(code int) {
return return
} }
brw.statusCode = code brw.statusCode = code
if isRedirect(code) {
brw.isRedirect = true
brw.ResponseWriter.WriteHeader(code)
brw.headerWritten = true
}
} }
func (brw *bufferedResponseWriter) Header() http.Header { func (brw *bufferedResponseWriter) Header() http.Header {
@@ -100,6 +113,10 @@ func (brw *bufferedResponseWriter) flush() {
} }
} }
func isRedirect(statusCode int) bool {
return statusCode >= 300 && statusCode < 400
}
func shouldCompress(r *http.Request, config *CompressionConfig) bool { func shouldCompress(r *http.Request, config *CompressionConfig) bool {
return r.Header.Get("Content-Encoding") == "" return r.Header.Get("Content-Encoding") == ""
} }

View File

@@ -27,7 +27,14 @@ func ValidationMiddleware() func(http.Handler) http.Handler {
dto := reflect.New(dtoType).Interface() dto := reflect.New(dtoType).Interface()
if err := json.NewDecoder(r.Body).Decode(dto); err != nil { if err := json.NewDecoder(r.Body).Decode(dto); err != nil {
http.Error(w, "Invalid JSON", http.StatusBadRequest) response := map[string]any{
"success": false,
"error": "Invalid JSON",
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(response)
return return
} }
@@ -77,3 +84,7 @@ func GetDTOTypeFromContext(ctx context.Context) reflect.Type {
func GetValidatedDTOFromContext(ctx context.Context) any { func GetValidatedDTOFromContext(ctx context.Context) any {
return ctx.Value(validatedDTOKey) return ctx.Value(validatedDTOKey)
} }
func SetValidatedDTOInContext(ctx context.Context, dto any) context.Context {
return context.WithValue(ctx, validatedDTOKey, dto)
}

View File

@@ -1,8 +1,10 @@
package server package server
import ( import (
"mime"
"net/http" "net/http"
"path/filepath" "path/filepath"
"strings"
"time" "time"
"goyco/internal/config" "goyco/internal/config"
@@ -124,7 +126,33 @@ func NewRouter(cfg RouterConfig) http.Handler {
staticDir = "./internal/static/" staticDir = "./internal/static/"
} }
router.Handle("/static/*", http.StripPrefix("/static/", http.FileServer(http.Dir(staticDir)))) staticFileServer := http.StripPrefix("/static/", http.FileServer(http.Dir(staticDir)))
router.Handle("/static/*", staticFileHandler(staticFileServer))
return router return router
} }
func staticFileHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
ext := filepath.Ext(path)
if ext == ".css" {
w.Header().Set("Content-Type", "text/css; charset=utf-8")
} else if ext == ".js" {
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
} else if ext == ".json" {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
} else if ext == ".ico" {
w.Header().Set("Content-Type", "image/x-icon")
} else if strings.HasPrefix(mime.TypeByExtension(ext), "image/") {
w.Header().Set("Content-Type", mime.TypeByExtension(ext))
} else if strings.HasPrefix(mime.TypeByExtension(ext), "font/") {
w.Header().Set("Content-Type", mime.TypeByExtension(ext))
} else if mimeType := mime.TypeByExtension(ext); mimeType != "" {
w.Header().Set("Content-Type", mimeType)
}
next.ServeHTTP(w, r)
})
}

View File

@@ -3,6 +3,7 @@ package server
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"goyco/internal/config" "goyco/internal/config"
@@ -105,9 +106,9 @@ func defaultRateLimitConfig() config.RateLimitConfig {
return testutils.AppTestConfig.RateLimit return testutils.AppTestConfig.RateLimit
} }
func TestAPIRootRouting(t *testing.T) { func createDefaultRouterConfig() RouterConfig {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{ return RouterConfig{
APIHandler: apiHandler, APIHandler: apiHandler,
AuthHandler: authHandler, AuthHandler: authHandler,
PostHandler: postHandler, PostHandler: postHandler,
@@ -115,7 +116,15 @@ func TestAPIRootRouting(t *testing.T) {
UserHandler: userHandler, UserHandler: userHandler,
AuthService: authService, AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(), RateLimitConfig: defaultRateLimitConfig(),
}) }
}
func createTestRouter(cfg RouterConfig) http.Handler {
return NewRouter(cfg)
}
func TestAPIRootRouting(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct { testCases := []struct {
name string name string
@@ -141,23 +150,23 @@ func TestAPIRootRouting(t *testing.T) {
} }
func TestProtectedRoutesRequireAuth(t *testing.T) { func TestProtectedRoutesRequireAuth(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() router := createTestRouter(createDefaultRouterConfig())
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
protectedRoutes := []struct { protectedRoutes := []struct {
method string method string
path string path string
}{ }{
{http.MethodGet, "/api/auth/me"}, {http.MethodGet, "/api/auth/me"},
{http.MethodPost, "/api/auth/logout"},
{http.MethodPost, "/api/auth/revoke"},
{http.MethodPost, "/api/auth/revoke-all"},
{http.MethodPut, "/api/auth/email"},
{http.MethodPut, "/api/auth/username"},
{http.MethodPut, "/api/auth/password"},
{http.MethodDelete, "/api/auth/account"},
{http.MethodPost, "/api/posts"}, {http.MethodPost, "/api/posts"},
{http.MethodPut, "/api/posts/1"},
{http.MethodDelete, "/api/posts/1"},
{http.MethodPost, "/api/posts/1/vote"}, {http.MethodPost, "/api/posts/1/vote"},
{http.MethodDelete, "/api/posts/1/vote"}, {http.MethodDelete, "/api/posts/1/vote"},
{http.MethodGet, "/api/posts/1/vote"}, {http.MethodGet, "/api/posts/1/vote"},
@@ -183,17 +192,9 @@ func TestProtectedRoutesRequireAuth(t *testing.T) {
} }
func TestRouterWithDebugMode(t *testing.T) { func TestRouterWithDebugMode(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
router := NewRouter(RouterConfig{ cfg.Debug = true
Debug: true, router := createTestRouter(cfg)
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/api", nil) request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -206,16 +207,9 @@ func TestRouterWithDebugMode(t *testing.T) {
} }
func TestRouterWithCacheDisabled(t *testing.T) { func TestRouterWithCacheDisabled(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
router := NewRouter(RouterConfig{ cfg.DisableCache = true
DisableCache: true, router := createTestRouter(cfg)
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
})
request := httptest.NewRequest(http.MethodGet, "/api", nil) request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -228,17 +222,9 @@ func TestRouterWithCacheDisabled(t *testing.T) {
} }
func TestRouterWithCompressionDisabled(t *testing.T) { func TestRouterWithCompressionDisabled(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
router := NewRouter(RouterConfig{ cfg.DisableCompression = true
DisableCompression: true, router := createTestRouter(cfg)
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/api", nil) request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -251,19 +237,9 @@ func TestRouterWithCompressionDisabled(t *testing.T) {
} }
func TestRouterWithCustomDBMonitor(t *testing.T) { func TestRouterWithCustomDBMonitor(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
customDBMonitor := middleware.NewInMemoryDBMonitor() cfg.DBMonitor = middleware.NewInMemoryDBMonitor()
router := createTestRouter(cfg)
router := NewRouter(RouterConfig{
DBMonitor: customDBMonitor,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/api", nil) request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -296,18 +272,9 @@ func TestRouterWithPageHandler(t *testing.T) {
} }
func TestRouterWithStaticDir(t *testing.T) { func TestRouterWithStaticDir(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
cfg.StaticDir = "/custom/static/path"
router := NewRouter(RouterConfig{ router := createTestRouter(cfg)
StaticDir: "/custom/static/path",
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/api", nil) request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -320,18 +287,9 @@ func TestRouterWithStaticDir(t *testing.T) {
} }
func TestRouterWithEmptyStaticDir(t *testing.T) { func TestRouterWithEmptyStaticDir(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
cfg.StaticDir = ""
router := NewRouter(RouterConfig{ router := createTestRouter(cfg)
StaticDir: "",
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/api", nil) request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -344,20 +302,11 @@ func TestRouterWithEmptyStaticDir(t *testing.T) {
} }
func TestRouterWithAllFeaturesDisabled(t *testing.T) { func TestRouterWithAllFeaturesDisabled(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
cfg.Debug = true
router := NewRouter(RouterConfig{ cfg.DisableCache = true
Debug: true, cfg.DisableCompression = true
DisableCache: true, router := createTestRouter(cfg)
DisableCompression: true,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/api", nil) request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -370,15 +319,9 @@ func TestRouterWithAllFeaturesDisabled(t *testing.T) {
} }
func TestRouterWithoutAPIHandler(t *testing.T) { func TestRouterWithoutAPIHandler(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, _, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
router := NewRouter(RouterConfig{ cfg.APIHandler = nil
AuthHandler: authHandler, router := createTestRouter(cfg)
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/api", nil) request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -391,17 +334,7 @@ func TestRouterWithoutAPIHandler(t *testing.T) {
} }
func TestRouterWithoutPageHandler(t *testing.T) { func TestRouterWithoutPageHandler(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() router := createTestRouter(createDefaultRouterConfig())
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/", nil) request := httptest.NewRequest(http.MethodGet, "/", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -414,17 +347,7 @@ func TestRouterWithoutPageHandler(t *testing.T) {
} }
func TestSwaggerRoute(t *testing.T) { func TestSwaggerRoute(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() router := createTestRouter(createDefaultRouterConfig())
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/swagger/", nil) request := httptest.NewRequest(http.MethodGet, "/swagger/", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -437,18 +360,9 @@ func TestSwaggerRoute(t *testing.T) {
} }
func TestStaticFileRoute(t *testing.T) { func TestStaticFileRoute(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
cfg.StaticDir = "../../internal/static/"
router := NewRouter(RouterConfig{ router := createTestRouter(cfg)
StaticDir: "../../internal/static/",
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/static/css/main.css", nil) request := httptest.NewRequest(http.MethodGet, "/static/css/main.css", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -461,17 +375,7 @@ func TestStaticFileRoute(t *testing.T) {
} }
func TestRouterConfiguration(t *testing.T) { func TestRouterConfiguration(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() router := createTestRouter(createDefaultRouterConfig())
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
if router == nil { if router == nil {
t.Error("Router should not be nil") t.Error("Router should not be nil")
@@ -487,29 +391,484 @@ func TestRouterConfiguration(t *testing.T) {
} }
} }
func TestRouterMiddlewareIntegration(t *testing.T) { func TestAllRoutesExist(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() router := createTestRouter(createDefaultRouterConfig())
router := NewRouter(RouterConfig{ publicRoutes := []struct {
APIHandler: apiHandler, method string
AuthHandler: authHandler, path string
PostHandler: postHandler, description string
VoteHandler: voteHandler, }{
UserHandler: userHandler, {http.MethodGet, "/api", "API info"},
AuthService: authService, {http.MethodGet, "/health", "Health check"},
RateLimitConfig: defaultRateLimitConfig(), {http.MethodGet, "/metrics", "Metrics"},
}) {http.MethodGet, "/robots.txt", "Robots.txt"},
{http.MethodGet, "/api/posts", "Get posts"},
if router == nil { {http.MethodGet, "/api/posts/search", "Search posts"},
t.Error("Router should not be nil") {http.MethodGet, "/api/posts/title", "Fetch title from URL"},
{http.MethodGet, "/api/posts/1", "Get post by ID"},
{http.MethodPost, "/api/auth/register", "Register"},
{http.MethodPost, "/api/auth/login", "Login"},
{http.MethodPost, "/api/auth/refresh", "Refresh token"},
{http.MethodGet, "/api/auth/confirm", "Confirm email"},
{http.MethodPost, "/api/auth/resend-verification", "Resend verification"},
{http.MethodPost, "/api/auth/forgot-password", "Forgot password"},
{http.MethodPost, "/api/auth/reset-password", "Reset password"},
{http.MethodPost, "/api/auth/account/confirm", "Confirm account deletion"},
} }
request := httptest.NewRequest(http.MethodGet, "/api", nil) protectedRoutes := []struct {
method string
path string
description string
}{
{http.MethodGet, "/api/auth/me", "Get current user"},
{http.MethodPost, "/api/auth/logout", "Logout"},
{http.MethodPost, "/api/auth/revoke", "Revoke token"},
{http.MethodPost, "/api/auth/revoke-all", "Revoke all tokens"},
{http.MethodPut, "/api/auth/email", "Update email"},
{http.MethodPut, "/api/auth/username", "Update username"},
{http.MethodPut, "/api/auth/password", "Update password"},
{http.MethodDelete, "/api/auth/account", "Delete account"},
{http.MethodPost, "/api/posts", "Create post"},
{http.MethodPut, "/api/posts/1", "Update post"},
{http.MethodDelete, "/api/posts/1", "Delete post"},
{http.MethodPost, "/api/posts/1/vote", "Cast vote"},
{http.MethodDelete, "/api/posts/1/vote", "Remove vote"},
{http.MethodGet, "/api/posts/1/vote", "Get user vote"},
{http.MethodGet, "/api/posts/1/votes", "Get post votes"},
{http.MethodGet, "/api/users", "Get users"},
{http.MethodPost, "/api/users", "Create user"},
{http.MethodGet, "/api/users/1", "Get user by ID"},
{http.MethodGet, "/api/users/1/posts", "Get user posts"},
}
for _, route := range publicRoutes {
t.Run(route.description+" "+route.method+" "+route.path, func(t *testing.T) {
invalidMethod := http.MethodPatch
switch route.method {
case http.MethodGet:
invalidMethod = http.MethodDelete
case http.MethodPost:
invalidMethod = http.MethodGet
}
request := httptest.NewRequest(invalidMethod, route.path, nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request) router.ServeHTTP(recorder, request)
if recorder.Code == 0 { routeExists := recorder.Code == http.StatusMethodNotAllowed
t.Error("Router should return a status code")
if !routeExists {
request = httptest.NewRequest(route.method, route.path, nil)
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code == http.StatusNotFound && route.path != "/api/posts/1" && route.path != "/robots.txt" {
t.Errorf("Route %s %s should exist, got 404", route.method, route.path)
}
}
})
}
for _, route := range protectedRoutes {
t.Run(route.description+" "+route.method+" "+route.path, func(t *testing.T) {
request := httptest.NewRequest(route.method, route.path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code == http.StatusNotFound {
t.Errorf("Route %s %s should exist, got 404", route.method, route.path)
}
if recorder.Code != http.StatusUnauthorized {
t.Errorf("Protected route %s %s should return 401 without auth, got %d", route.method, route.path, recorder.Code)
}
})
}
}
func TestRouteParameters(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
pathPattern string
testIDs []string
isProtected bool
}{
{
name: "Get post by ID",
method: http.MethodGet,
pathPattern: "/api/posts/{id}",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: false,
},
{
name: "Update post by ID",
method: http.MethodPut,
pathPattern: "/api/posts/{id}",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Delete post by ID",
method: http.MethodDelete,
pathPattern: "/api/posts/{id}",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Get user by ID",
method: http.MethodGet,
pathPattern: "/api/users/{id}",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
{
name: "Get user posts by user ID",
method: http.MethodGet,
pathPattern: "/api/users/{id}/posts",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
{
name: "Cast vote for post ID",
method: http.MethodPost,
pathPattern: "/api/posts/{id}/vote",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Remove vote for post ID",
method: http.MethodDelete,
pathPattern: "/api/posts/{id}/vote",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Get user vote for post ID",
method: http.MethodGet,
pathPattern: "/api/posts/{id}/vote",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
{
name: "Get post votes by post ID",
method: http.MethodGet,
pathPattern: "/api/posts/{id}/votes",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, id := range tc.testIDs {
path := replaceID(tc.pathPattern, id)
t.Run("ID_"+id, func(t *testing.T) {
request := httptest.NewRequest(http.MethodPatch, path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
routeExists := recorder.Code == http.StatusMethodNotAllowed
request = httptest.NewRequest(tc.method, path, nil)
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if !routeExists {
if recorder.Code == http.StatusNotFound {
t.Errorf("Route %s %s should exist with ID %s, got 404", tc.method, path, id)
return
}
}
if tc.isProtected {
if recorder.Code != http.StatusUnauthorized {
t.Errorf("Protected route %s %s should return 401 without auth, got %d", tc.method, path, recorder.Code)
}
}
})
}
})
}
}
func replaceID(pattern, id string) string {
return strings.Replace(pattern, "{id}", id, 1)
}
func TestInvalidRouteParameters(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
path string
expectedMin int
expectedMax int
isProtected bool
allow401 bool
}{
{
name: "Non-numeric post ID",
method: http.MethodGet,
path: "/api/posts/abc",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Negative post ID",
method: http.MethodGet,
path: "/api/posts/-1",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Zero post ID",
method: http.MethodGet,
path: "/api/posts/0",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusNotFound,
isProtected: false,
},
{
name: "Post ID with special characters",
method: http.MethodGet,
path: "/api/posts/123@456",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Post ID with encoded spaces",
method: http.MethodGet,
path: "/api/posts/12%2034",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Non-numeric user ID",
method: http.MethodGet,
path: "/api/users/xyz",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusUnauthorized,
isProtected: true,
allow401: true,
},
{
name: "Negative user ID",
method: http.MethodGet,
path: "/api/users/-5",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusUnauthorized,
isProtected: true,
allow401: true,
},
{
name: "Non-numeric post ID in vote route",
method: http.MethodGet,
path: "/api/posts/invalid/vote",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusUnauthorized,
isProtected: true,
allow401: true,
},
{
name: "Very large post ID",
method: http.MethodGet,
path: "/api/posts/999999999999",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusNotFound,
isProtected: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
request := httptest.NewRequest(tc.method, tc.path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if tc.isProtected && tc.allow401 {
if recorder.Code != http.StatusUnauthorized && (recorder.Code < tc.expectedMin || recorder.Code > tc.expectedMax) {
t.Errorf("Protected route %s %s with invalid parameter should return 401 or status between %d and %d, got %d", tc.method, tc.path, tc.expectedMin, tc.expectedMax, recorder.Code)
}
} else {
if recorder.Code < tc.expectedMin || recorder.Code > tc.expectedMax {
t.Errorf("Route %s %s should return status between %d and %d, got %d", tc.method, tc.path, tc.expectedMin, tc.expectedMax, recorder.Code)
}
if recorder.Code != http.StatusNotFound && recorder.Code < 400 {
t.Errorf("Route %s %s with invalid parameter should return error status (4xx), got %d", tc.method, tc.path, recorder.Code)
}
}
})
}
}
func TestQueryParameters(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
path string
queryParams string
expectRoute bool
}{
{
name: "Get posts with limit and offset",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=10&offset=5",
expectRoute: true,
},
{
name: "Get posts with only limit",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=20",
expectRoute: true,
},
{
name: "Get posts with only offset",
method: http.MethodGet,
path: "/api/posts",
queryParams: "offset=10",
expectRoute: true,
},
{
name: "Search posts with query parameter",
method: http.MethodGet,
path: "/api/posts/search",
queryParams: "q=test",
expectRoute: true,
},
{
name: "Search posts with query, limit, and offset",
method: http.MethodGet,
path: "/api/posts/search",
queryParams: "q=test&limit=15&offset=3",
expectRoute: true,
},
{
name: "Fetch title with URL parameter",
method: http.MethodGet,
path: "/api/posts/title",
queryParams: "url=https://example.com",
expectRoute: true,
},
{
name: "Confirm email with token parameter",
method: http.MethodGet,
path: "/api/auth/confirm",
queryParams: "token=abc123",
expectRoute: true,
},
{
name: "Get posts with invalid limit",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=abc",
expectRoute: true,
},
{
name: "Get posts with negative limit",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=-5",
expectRoute: true,
},
{
name: "Get posts with negative offset",
method: http.MethodGet,
path: "/api/posts",
queryParams: "offset=-10",
expectRoute: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
fullPath := tc.path
if tc.queryParams != "" {
fullPath += "?" + tc.queryParams
}
request := httptest.NewRequest(tc.method, fullPath, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if tc.expectRoute {
if recorder.Code == http.StatusNotFound {
t.Errorf("Route %s %s should exist with query parameters, got 404", tc.method, fullPath)
}
}
})
}
}
func TestRouteConflicts(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
path string
description string
}{
{
name: "posts/search should not match posts/{id}",
method: http.MethodGet,
path: "/api/posts/search",
description: "search route should be matched, not treated as ID",
},
{
name: "posts/title should not match posts/{id}",
method: http.MethodGet,
path: "/api/posts/title",
description: "title route should be matched, not treated as ID",
},
{
name: "posts/{id} should work with numeric ID",
method: http.MethodGet,
path: "/api/posts/123",
description: "numeric ID should match {id} route",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
request := httptest.NewRequest(tc.method, tc.path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
switch tc.path {
case "/api/posts/search":
if recorder.Code == http.StatusNotFound {
t.Errorf("%s: Route %s %s should exist (not 404), got %d", tc.description, tc.method, tc.path, recorder.Code)
}
case "/api/posts/title":
if recorder.Code == http.StatusNotFound {
t.Errorf("%s: Route %s %s should exist (not 404), got %d", tc.description, tc.method, tc.path, recorder.Code)
}
case "/api/posts/123":
if recorder.Code == http.StatusNotFound {
return
}
if recorder.Code < 400 {
t.Errorf("%s: Route %s %s should return 4xx or 5xx, got %d", tc.description, tc.method, tc.path, recorder.Code)
}
}
})
} }
} }

View File

@@ -1,12 +1,93 @@
package services package services
import ( import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt" "fmt"
"net/mail"
"strings"
"goyco/internal/config" "goyco/internal/config"
"goyco/internal/database" "goyco/internal/database"
) )
const (
defaultTokenExpirationHours = 24
verificationTokenBytes = 32
deletionTokenExpirationHours = 24
)
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrInvalidToken = errors.New("invalid or expired token")
ErrUsernameTaken = errors.New("username already exists")
ErrEmailTaken = errors.New("email already exists")
ErrInvalidEmail = errors.New("invalid email address")
ErrPasswordTooShort = errors.New("password too short")
ErrEmailNotVerified = errors.New("email not verified")
ErrAccountLocked = errors.New("account is locked")
ErrInvalidVerificationToken = errors.New("invalid verification token")
ErrEmailSenderUnavailable = errors.New("email sender not configured")
ErrDeletionEmailFailed = errors.New("account deletion email failed")
ErrInvalidDeletionToken = errors.New("invalid account deletion token")
ErrUserNotFound = errors.New("user not found")
ErrDeletionRequestNotFound = errors.New("deletion request not found")
)
type AuthResult struct {
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
RefreshToken string `json:"refresh_token" example:"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0"`
User *database.User `json:"user"`
}
type RegistrationResult struct {
User *database.User `json:"user"`
VerificationSent bool `json:"verification_sent"`
}
func normalizeEmail(email string) (string, error) {
trimmed := strings.TrimSpace(email)
if trimmed == "" {
return "", fmt.Errorf("email is required")
}
parsed, err := mail.ParseAddress(trimmed)
if err != nil {
return "", ErrInvalidEmail
}
return strings.ToLower(parsed.Address), nil
}
func generateVerificationToken() (string, string, error) {
buf := make([]byte, verificationTokenBytes)
if _, err := rand.Read(buf); err != nil {
return "", "", fmt.Errorf("generate verification token: %w", err)
}
token := hex.EncodeToString(buf)
hashed := HashVerificationToken(token)
return token, hashed, nil
}
func HashVerificationToken(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}
func sanitizeUser(user *database.User) *database.User {
if user == nil {
return nil
}
copy := *user
copy.Password = ""
copy.EmailVerificationToken = ""
return &copy
}
type AuthFacade struct { type AuthFacade struct {
registrationService *RegistrationService registrationService *RegistrationService
passwordResetService *PasswordResetService passwordResetService *PasswordResetService

View File

@@ -1,35 +0,0 @@
package services
import (
"errors"
"goyco/internal/database"
)
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrInvalidToken = errors.New("invalid or expired token")
ErrUsernameTaken = errors.New("username already exists")
ErrEmailTaken = errors.New("email already exists")
ErrInvalidEmail = errors.New("invalid email address")
ErrPasswordTooShort = errors.New("password too short")
ErrEmailNotVerified = errors.New("email not verified")
ErrAccountLocked = errors.New("account is locked")
ErrInvalidVerificationToken = errors.New("invalid verification token")
ErrEmailSenderUnavailable = errors.New("email sender not configured")
ErrDeletionEmailFailed = errors.New("account deletion email failed")
ErrInvalidDeletionToken = errors.New("invalid account deletion token")
ErrUserNotFound = errors.New("user not found")
ErrDeletionRequestNotFound = errors.New("deletion request not found")
)
type AuthResult struct {
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
RefreshToken string `json:"refresh_token" example:"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0"`
User *database.User `json:"user"`
}
type RegistrationResult struct {
User *database.User `json:"user"`
VerificationSent bool `json:"verification_sent"`
}

View File

@@ -1,59 +0,0 @@
package services
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/mail"
"strings"
"goyco/internal/database"
)
const (
defaultTokenExpirationHours = 24
verificationTokenBytes = 32
deletionTokenExpirationHours = 24
)
func normalizeEmail(email string) (string, error) {
trimmed := strings.TrimSpace(email)
if trimmed == "" {
return "", fmt.Errorf("email is required")
}
parsed, err := mail.ParseAddress(trimmed)
if err != nil {
return "", ErrInvalidEmail
}
return strings.ToLower(parsed.Address), nil
}
func generateVerificationToken() (string, string, error) {
buf := make([]byte, verificationTokenBytes)
if _, err := rand.Read(buf); err != nil {
return "", "", fmt.Errorf("generate verification token: %w", err)
}
token := hex.EncodeToString(buf)
hashed := HashVerificationToken(token)
return token, hashed, nil
}
func HashVerificationToken(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}
func sanitizeUser(user *database.User) *database.User {
if user == nil {
return nil
}
copy := *user
copy.Password = ""
copy.EmailVerificationToken = ""
return &copy
}

View File

@@ -2,43 +2,70 @@ package templates
import ( import (
"html/template" "html/template"
"io/fs"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestTemplateParsing(t *testing.T) { func templateFuncMap() template.FuncMap {
templateDir := "./" return template.FuncMap{
"formatTime": func(t time.Time) string {
if t.IsZero() {
return ""
}
return t.Format("02 Jan 2006 15:04")
},
"truncate": func(s string, length int) string {
if len(s) <= length {
return s
}
if length <= 3 {
return s[:length]
}
return s[:length-3] + "..."
},
"substr": func(s string, start, length int) string {
if start >= len(s) {
return ""
}
end := min(start+length, len(s))
return s[start:end]
},
"upper": strings.ToUpper,
}
}
var templateFiles []string func TestTemplateParsing(t *testing.T) {
err := filepath.WalkDir(templateDir, func(path string, d fs.DirEntry, err error) error { layoutPath := filepath.Join(".", "base.gohtml")
if err != nil { require.FileExists(t, layoutPath, "base layout is required for all templates")
return err
} partials, err := filepath.Glob(filepath.Join(".", "partials", "*.gohtml"))
if !d.IsDir() && filepath.Ext(path) == ".gohtml" {
templateFiles = append(templateFiles, path)
}
return nil
})
require.NoError(t, err) require.NoError(t, err)
tmpl := template.New("test") pages, err := filepath.Glob(filepath.Join(".", "*.gohtml"))
require.NoError(t, err)
require.NotEmpty(t, pages, "no page templates found")
tmpl = tmpl.Funcs(template.FuncMap{ for _, page := range pages {
"formatTime": func(any) string { return "2024-01-01" }, if filepath.Base(page) == "base.gohtml" {
"eq": func(a, b any) bool { return a == b }, continue
"ne": func(a, b any) bool { return a != b }, }
"len": func(s any) int { return 0 },
"range": func(s any) any { return s },
})
for _, file := range templateFiles { page := page
t.Run(file, func(t *testing.T) { t.Run(filepath.Base(page), func(t *testing.T) {
_, err := tmpl.ParseFiles(file) t.Parallel()
assert.NoError(t, err, "Template %s should parse without errors", file)
files := append([]string{layoutPath}, partials...)
files = append(files, page)
tmpl, err := template.New(filepath.Base(page)).Funcs(templateFuncMap()).ParseFiles(files...)
require.NoError(t, err)
require.NotNil(t, tmpl.Lookup("layout"), "layout template should be available")
require.NotNil(t, tmpl.Lookup("content"), "content block should be defined by page templates")
}) })
} }
} }

View File

@@ -44,5 +44,3 @@ GRANT ALL PRIVILEGES ON DATABASE goyco TO goyco;
EOF EOF
echo "PostgreSQL 18 installed, database 'goyco' and user 'goyco' set up." echo "PostgreSQL 18 installed, database 'goyco' and user 'goyco' set up."