Compare commits

...

109 Commits

Author SHA1 Message Date
07c6b89525 docs: remove project structure, boring and hard to maintain 2025-12-20 17:35:16 +01:00
817205d42f refactor: modernize using min() 2025-12-16 15:45:51 +01:00
199ac143a4 refactor: replace interface{} by any 2025-12-16 15:05:23 +01:00
aa7e259ed0 format: shfmt 2025-12-16 15:02:42 +01:00
4587609e17 refactor: create createTestRouter and test edge cases 2025-12-14 21:14:42 +01:00
33da6503e3 test: also test put/delete routes 2025-12-14 21:06:15 +01:00
cafc44ed77 test: add a test for route parameters 2025-12-14 21:04:36 +01:00
1480135e75 test: verified all routes to exist 2025-12-14 21:02:25 +01:00
02a764c736 clean: remove merged files 2025-12-14 20:52:14 +01:00
6834ad7764 refactor: merge facade, types and utils into one auth_service.go 2025-12-14 20:52:03 +01:00
dcf054046f test: fix parallel processor test expectations and setup 2025-12-10 07:30:27 +01:00
d2a788933d fix: track completed items in the main loop instead of using the index 2025-12-10 07:29:55 +01:00
18be3950dc clean: obsolete function 2025-12-09 22:07:30 +01:00
f9cb140e95 clean: removed the obsolete functions outputMessage and outputError 2025-12-09 22:06:12 +01:00
86d4835ccf feat: seed user is now uniq 2025-12-09 22:03:26 +01:00
feddb2ed43 test: new unit test for EnsureSeedUser 2025-12-09 22:03:16 +01:00
457b5c88e2 refactor: improve seed consistency validation 2025-12-09 21:53:03 +01:00
a8d363b2bf fix: templates now parse with the same func map as the page handler 2025-12-09 21:37:21 +01:00
0cd68e847c refactor: add a helper to centralize CSRF token retrieval 2025-12-09 15:58:28 +01:00
df6aeed713 docs: unocss it is 2025-12-04 20:43:30 +01:00
785faeb60c feat: update alpine to 3.23 2025-12-04 10:01:19 +01:00
0623c027ba docs: prepare CONTRIBUTING.md 2025-12-03 20:57:32 +01:00
d4e91b6034 refactor: complete refactor and better helpers use 2025-11-29 15:19:41 +01:00
7d46d3e81b clean: remove the unused expectedValue in assertHeader (always set to "") 2025-11-29 15:19:28 +01:00
216aaf3117 refactor: clean code and use new request helpers 2025-11-29 14:58:52 +01:00
435047ad0c refactor: clean code 2025-11-29 14:58:37 +01:00
b7ee8bd11d refactor: clean variable names and use new request helpers 2025-11-29 14:58:20 +01:00
040cd48be8 refactor: clean variables 2025-11-29 14:58:07 +01:00
2dd16e0e00 refactor: complete 2025-11-29 14:56:18 +01:00
d6db70cc79 refactor: clean code and variables, use new request helpers 2025-11-29 14:55:47 +01:00
58e10ade7d refactor: clean variable names and modernize code 2025-11-29 14:50:35 +01:00
7403a75d8e refactor: clean variable naming 2025-11-29 14:46:26 +01:00
b429bc11af refactor: clean code and use new request helpers 2025-11-29 14:41:38 +01:00
2ec5c28fb5 refactor: rename variables and clean code 2025-11-29 14:37:18 +01:00
3743a99e40 refactor: req -> request, rec -> recorder, reqBody -> requestBody... 2025-11-29 14:21:07 +01:00
5710921b87 refactor: use new request helpers 2025-11-29 14:17:25 +01:00
84d9c81484 refactor: rec -> recorder, req -> request and modernize loop 2025-11-29 14:15:07 +01:00
b0c2038927 feat: add new helpers to make requests properly in integration tests 2025-11-29 14:11:32 +01:00
fd88931146 refactor: variable names and modernize loop 2025-11-25 14:37:59 +01:00
6ce0f4dfad refactor: name variables 2025-11-25 10:13:10 +01:00
68e3dceefc refactor: name variables 2025-11-25 10:08:48 +01:00
cded14c526 fix: force correct mime types for static files after modifying compression middleware's buffering 2025-11-24 07:53:08 +01:00
cfca668ca6 docs: clean readme 2025-11-23 21:56:53 +01:00
279255b587 fix: don't let rate limiting fails the test 2025-11-23 21:48:11 +01:00
b83f8c2228 fix: update ValidationMiddleware to return a JSON error response when JSON decoding fails 2025-11-23 21:42:27 +01:00
aabc48128c fix: use router in handlers integration tests (for dto validation) 2025-11-23 15:10:55 +01:00
68716b977b fix: verify XSS sanitization in handler response instead of repository stub 2025-11-23 15:01:54 +01:00
dbe1600632 fix: indentation 2025-11-23 14:49:31 +01:00
458e25cf79 fix: modify compression middleware to pass through redirects immediately without buffering 2025-11-23 14:48:59 +01:00
d4595d8dbf fix: properly encoding the flash message in the redirect URL 2025-11-23 14:48:39 +01:00
c5418f4e4c docs: update swagger 2025-11-23 14:26:52 +01:00
db0369225e refactor: update references to VoteRequest 2025-11-23 14:26:45 +01:00
07ac965b3d refactor: use consistent naming (VoteRequest -> CastVoteRequest) 2025-11-23 14:26:19 +01:00
e2e5d42035 feat: add SetValidatedDTOInContext to support test helper functions 2025-11-23 14:22:59 +01:00
6e4b41894f fix: update test cases to use createCreatePostRequests 2025-11-23 14:22:35 +01:00
0a8ed2e27c fix: add explicite validation check for empty url, title and content length 2025-11-23 14:21:30 +01:00
216e8657f6 feat: add generic createRequestWithDTO along with helpers functions 2025-11-23 14:20:59 +01:00
fb7206c0a2 fix: test context handling 2025-11-23 14:20:09 +01:00
c25926514b fix: add explicit empty-field validation check in handlers 2025-11-23 14:19:54 +01:00
964785e494 docs: update swagger 2025-11-23 13:47:38 +01:00
9c67cd2a47 feat: update vote handler to use dto VoteRequest and update MountRoutes 2025-11-23 13:47:31 +01:00
8b5cc8e939 feat: add VoteRequest with its validation types 2025-11-23 13:46:51 +01:00
0e71b28615 feat: update CreateUser to use dto.RegisterRequest and update MountRoutes to apply validation middleware 2025-11-23 13:43:47 +01:00
cd740da57a feat: update methods to use validated DTOs and update MountRoutes 2025-11-23 13:43:14 +01:00
abe4a3dc88 feat: update handlers to use GetValidatedDTO instead of manual decoding and update MountRoutes to wrap handlers with WithValidation for all DTO-based routes 2025-11-23 13:42:52 +01:00
738243d945 feat: add ValidationMiddleware to RouteModuleConfig 2025-11-23 13:41:55 +01:00
4fbdfb6e4a feat: add two helpers function to retrieve validated DTOs from request context and to apply validation middleware 2025-11-23 13:41:07 +01:00
6bb3a78b88 feat: Add ValidationMiddleware to router configuration 2025-11-23 13:40:31 +01:00
54e37e59fc docs: update swagger 2025-11-23 13:35:00 +01:00
5d4b38ddc4 feat: add validation tags to request DTOs 2025-11-23 13:34:53 +01:00
7dc119ecde docs: update swagger 2025-11-23 13:17:14 +01:00
52c9f4a02b feat: add internal/dto to swagger directories 2025-11-23 13:16:44 +01:00
be91a135bc clean: empty line 2025-11-23 13:14:41 +01:00
2d7ff9778b feat: update swagger comments following dtos relocation 2025-11-23 13:14:07 +01:00
4ff3fd3583 refactor: remove UpdatePostRequest definition and update swagger comments 2025-11-23 13:13:53 +01:00
73121cad15 refactor: remove all request DTO, update swagger comments and update token related methods to use dto ones 2025-11-23 13:13:23 +01:00
c5bf1b2fd8 feat: locate post-related request DTOs 2025-11-23 13:12:36 +01:00
eedebe60d1 feat: locate auth-related request DTOs 2025-11-23 13:12:10 +01:00
80fb37371f update: fix go version and update alpine to 3.22 2025-11-23 10:44:28 +01:00
fea49fad8d fix: add missing method to mock 2025-11-21 17:07:26 +01:00
4b04461ebb style: minor formatting adjustments 2025-11-21 17:06:04 +01:00
533e8c3d46 feat: add GetByUsernamePrefixFn field and method to UserRepositoryStub 2025-11-21 17:05:48 +01:00
df568291f1 feat: add GetByUsernamePrefix implementation to MockUserRepository 2025-11-21 17:05:31 +01:00
81acce62b1 feat: add GetByUsernamePrefix method to interface and add implementation 2025-11-21 17:05:01 +01:00
989a61e7d5 feat: use getByUsernamePrefix to optimize findExistingSeedUser() 2025-11-21 17:04:35 +01:00
3ffd83b0fb feat: ignore docs in make format 2025-11-21 17:02:06 +01:00
62d466e4fa refactor: use go generics 2025-11-21 16:56:26 +01:00
0cd428d5d9 feat: use connection pooling instead of a single connection 2025-11-21 16:53:46 +01:00
5c239ad61d feat: add missing GetVoteCountsByPostID method to the errorVoteRepository test mock 2025-11-21 16:50:23 +01:00
01f2b1fe75 feat: remove loop and use GetVoteCountsByPostID 2025-11-21 16:48:48 +01:00
28134c101c feat: add GetVoteCountsByPostID to the mock for testing 2025-11-21 16:48:15 +01:00
2f78370d43 feat: GetVoteCountsByPostID: use a single sql query to returns up votes and down votes counts 2025-11-21 16:47:52 +01:00
39598a166d feat: remove redundat getbyemail call to reduce db query by 2 (1Q/user creation instead of 2) 2025-11-21 16:43:46 +01:00
fa9474d863 revert: db transaction use, avoiding the pgx RETURNING issue while maintaining data consistency 2025-11-21 16:31:06 +01:00
34a97994b3 feat: improve testing to use production code paths and better coverage 2025-11-21 16:26:21 +01:00
eb5f93ffd0 clean: remove duplicate sequential helpers 2025-11-21 16:25:27 +01:00
697f201d60 feat: use database transactions to ensure atomicity 2025-11-21 16:21:04 +01:00
f4ab8bda45 feat: transaction rollback test 2025-11-21 16:20:41 +01:00
65576cc623 feat: keep seeding fast and predictable even when parallelized 2025-11-21 16:16:35 +01:00
a5b4e9bf25 feat: update tests to pass precomputed hashes 2025-11-21 16:11:42 +01:00
c020517ccf feat: reduce hashing cost by removing redundant password hashing 2025-11-21 16:11:33 +01:00
4cdda3f944 feat: remove bcrypt and use a precompute hash 2025-11-21 16:11:08 +01:00
ff471cd5dd fix: loop 2025-11-21 15:39:08 +01:00
df5e67c7f3 feat: add idempotency tests 2025-11-21 15:34:08 +01:00
b2580d2380 feat: make seeding idempotente 2025-11-21 15:33:59 +01:00
4749213bf0 feat: update test to accept randomized seed user identities 2025-11-21 15:26:29 +01:00
6470425b96 feat: avoid unique constraint failures on repeat runs by randomizing seed identities 2025-11-21 15:26:05 +01:00
14ae6f815b feat: update tests to verify clamping 2025-11-21 15:21:05 +01:00
73083e4188 feat: check zero/negative value in seeding 2025-11-21 15:20:57 +01:00
65 changed files with 4426 additions and 3741 deletions

1
.prettierignore Normal file
View File

@@ -0,0 +1 @@
docs/

View File

@@ -1,4 +1,4 @@
ARG GO_VERSION=1.25.3 ARG GO_VERSION=1.25.4
# Building the binary using a golang alpine image # Building the binary using a golang alpine image
FROM golang:${GO_VERSION}-alpine AS go-builder FROM golang:${GO_VERSION}-alpine AS go-builder
@@ -11,7 +11,7 @@ ARG TARGETARCH=amd64
RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -ldflags="-s -w" -o /out/goyco ./cmd/goyco RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -ldflags="-s -w" -o /out/goyco ./cmd/goyco
# building the application image # building the application image
FROM alpine:3.21 FROM alpine:3.23
RUN addgroup -S goyco && adduser -S -G goyco goyco \ RUN addgroup -S goyco && adduser -S -G goyco goyco \
&& apk add --no-cache ca-certificates tzdata && apk add --no-cache ca-certificates tzdata
WORKDIR /app WORKDIR /app

View File

@@ -44,8 +44,8 @@ clean:
rm -fr dist/* rm -fr dist/*
format: format:
$(PRETTIER) -w . $(PRETTIER) -w . --ignore-path .prettierignore
$(GO) fmt ./... $(GO) fmt $(shell $(GO) list ./... | grep -v 'docs')
lint: lint:
$(GOLANGCI_LINT) run $(GOLANGCI_LINT) run

View File

@@ -354,44 +354,6 @@ It will start the application in development mode. You can also run it as a daem
Then, use `./bin/goyco` to manage the application and notably to seed the database with sample data. Then, use `./bin/goyco` to manage the application and notably to seed the database with sample data.
### Project Structure
```bash
goyco/
├── bin/ # Compiled binaries (created after build)
├── cmd/
│ └── goyco/ # Main CLI application entrypoint
├── docker/ # Docker Compose & related files
├── docs/ # Documentation and API specs
├── internal/
│ ├── config/ # Configuration management
│ ├── database/ # Database models and access
│ ├── dto/ # Data Transfer Objects (DTOs)
│ ├── e2e/ # End-to-end tests
│ ├── fuzz/ # Fuzz tests
│ ├── handlers/ # HTTP handlers
│ ├── integration/ # Integration tests
│ ├── middleware/ # HTTP middleware
│ ├── repositories/ # Data access layer
│ ├── security/ # Security and auth logic
│ ├── server/ # HTTP server implementation
│ ├── services/ # Business logic
│ ├── static/ # Static web assets
│ ├── templates/ # HTML templates
│ ├── testutils/ # Test helpers/utilities
│ ├── validation/ # Input validation
│ └── version/ # Version information
├── scripts/ # Utility/maintenance scripts
├── services/
│ └── goyco.service # Systemd service unit example
├── .env.example # Environment variable example
├── AUTHORS # Authors file
├── Dockerfile # Docker build file
├── LICENSE # License file
├── Makefile # Project build/test targets
└── README.md # This file
```
### Testing ### Testing
```bash ```bash
@@ -437,31 +399,10 @@ This will regenerate the swagger documentation and update the `docs/swagger.json
- [ ] add right management within the app - [ ] add right management within the app
- [ ] add an admin backoffice to manage rights, users, content and settings - [ ] add an admin backoffice to manage rights, users, content and settings
- [ ] add a way to run read-only communities - [ ] add a way to run read-only communities
- [ ] maybe use a css framework instead of raw css - [ ] migrate raw CSS to UnoCSS
- [ ] kubernetes deployment - [ ] kubernetes deployment
- [ ] store configuration in the database - [ ] store configuration in the database
## Contributing
Feedbacks are welcome!
But as it's a personal gitea and you cannot create accounts, feel free to contact me at <sandro@cazzaniga.fr> to get one.
Once you have it, follow the usual workflow:
1. Fork the repository
2. Create a feature branch
3. Make your changes
4. Add tests for new functionality
5. Ensure all tests pass
6. Submit a pull request
Then, I'll review your changes and merge them if they are good.
## License ## License
This project is licensed under the GNU General Public License v3.0 or later (GPLv3+). See the [LICENSE](LICENSE) file for details. This project is licensed under the GNU General Public License v3.0 or later (GPLv3+). See the [LICENSE](LICENSE) file for details.
---
**Goyco** - A modern news aggregation platform built with Go, PostgreSQL and most importantly, love.

View File

@@ -8,9 +8,10 @@ import (
"os" "os"
"sync" "sync"
"gorm.io/gorm"
"goyco/internal/config" "goyco/internal/config"
"goyco/internal/database" "goyco/internal/database"
"gorm.io/gorm"
) )
var ErrHelpRequested = errors.New("help requested") var ErrHelpRequested = errors.New("help requested")
@@ -40,11 +41,11 @@ var (
) )
func defaultDBConnector(cfg *config.Config) (*gorm.DB, func() error, error) { func defaultDBConnector(cfg *config.Config) (*gorm.DB, func() error, error) {
db, err := database.Connect(cfg) poolManager, err := database.ConnectWithPool(cfg)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
return db, func() error { return database.Close(db) }, nil return poolManager.GetDB(), func() error { return poolManager.Close() }, nil
} }
func SetDBConnector(connector DBConnector) { func SetDBConnector(connector DBConnector) {
@@ -118,26 +119,6 @@ func outputJSON(v interface{}) error {
return encoder.Encode(v) return encoder.Encode(v)
} }
func outputMessage(message string, args ...interface{}) {
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"message": fmt.Sprintf(message, args...),
})
} else {
fmt.Printf(message+"\n", args...)
}
}
func outputError(message string, args ...interface{}) {
if IsJSONOutput() {
outputJSON(map[string]interface{}{
"error": fmt.Sprintf(message, args...),
})
} else {
fmt.Fprintf(os.Stderr, message+"\n", args...)
}
}
func outputWarning(message string, args ...interface{}) { func outputWarning(message string, args ...interface{}) {
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]interface{}{

View File

@@ -5,9 +5,10 @@ import (
"fmt" "fmt"
"os" "os"
"gorm.io/gorm"
"goyco/internal/config" "goyco/internal/config"
"goyco/internal/database" "goyco/internal/database"
"gorm.io/gorm"
) )
func HandleMigrateCommand(cfg *config.Config, name string, args []string) error { func HandleMigrateCommand(cfg *config.Config, name string, args []string) error {
@@ -37,7 +38,7 @@ func runMigrateCommand(db *gorm.DB) error {
return fmt.Errorf("run migrations: %w", err) return fmt.Errorf("run migrations: %w", err)
} }
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "migrations_applied", "action": "migrations_applied",
"status": "success", "status": "success",
}) })

View File

@@ -2,14 +2,13 @@ package commands
import ( import (
"context" "context"
"crypto/rand" cryptoRand "crypto/rand"
"fmt" "fmt"
"math/big" "math/rand"
"runtime" "runtime"
"sync" "sync"
"time" "time"
"golang.org/x/crypto/bcrypt"
"goyco/internal/database" "goyco/internal/database"
"goyco/internal/repositories" "goyco/internal/repositories"
) )
@@ -17,93 +16,163 @@ import (
type ParallelProcessor struct { type ParallelProcessor struct {
maxWorkers int maxWorkers int
timeout time.Duration timeout time.Duration
passwordHash string
randSource *rand.Rand
randMu sync.Mutex
} }
func NewParallelProcessor() *ParallelProcessor { func NewParallelProcessor() *ParallelProcessor {
maxWorkers := max(min(runtime.NumCPU(), 8), 2) maxWorkers := max(min(runtime.NumCPU(), 8), 2)
seed := time.Now().UnixNano()
seedBytes := make([]byte, 8)
if _, err := cryptoRand.Read(seedBytes); err == nil {
seed = int64(seedBytes[0])<<56 | int64(seedBytes[1])<<48 | int64(seedBytes[2])<<40 | int64(seedBytes[3])<<32 |
int64(seedBytes[4])<<24 | int64(seedBytes[5])<<16 | int64(seedBytes[6])<<8 | int64(seedBytes[7])
}
return &ParallelProcessor{ return &ParallelProcessor{
maxWorkers: maxWorkers, maxWorkers: maxWorkers,
timeout: 30 * time.Second, timeout: 60 * time.Second,
randSource: rand.New(rand.NewSource(seed)),
} }
} }
func (p *ParallelProcessor) SetPasswordHash(hash string) {
p.passwordHash = hash
}
type indexedResult[T any] struct {
value T
index int
}
func processInParallel[T any](
ctx context.Context,
maxWorkers int,
count int,
processor func(index int) (T, error),
errorPrefix string,
progress *ProgressIndicator,
) ([]T, error) {
results := make(chan indexedResult[T], count)
errors := make(chan error, count)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup
for i := range count {
wg.Add(1)
go func(index int) {
defer wg.Done()
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
errors <- ctx.Err()
return
}
defer func() { <-semaphore }()
value, err := processor(index + 1)
if err != nil {
errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
return
}
results <- indexedResult[T]{value: value, index: index}
}(i)
}
go func() {
wg.Wait()
close(results)
close(errors)
}()
items := make([]T, count)
completed := 0
firstError := make(chan error, 1)
go func() {
for err := range errors {
if err != nil {
select {
case firstError <- err:
default:
}
return
}
}
}()
for completed < count {
select {
case result, ok := <-results:
if !ok {
return items, nil
}
items[result.index] = result.value
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-firstError:
return nil, err
case <-ctx.Done():
return nil, fmt.Errorf("timeout: %w", ctx.Err())
}
}
return items, nil
}
func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) { func (p *ParallelProcessor) CreateUsersInParallel(userRepo repositories.UserRepository, count int, progress *ProgressIndicator) ([]database.User, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout) ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel() defer cancel()
results := make(chan userResult, count) return processInParallel(ctx, p.maxWorkers, count,
errors := make(chan error, count) func(index int) (database.User, error) {
return p.createSingleUser(userRepo, index)
semaphore := make(chan struct{}, p.maxWorkers) },
var wg sync.WaitGroup "create user",
progress,
for i := range count { )
wg.Add(1)
go func(index int) {
defer wg.Done()
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
errors <- ctx.Err()
return
}
defer func() { <-semaphore }()
user, err := p.createSingleUser(userRepo, index+1)
if err != nil {
errors <- fmt.Errorf("create user %d: %w", index+1, err)
return
}
results <- userResult{user: user, index: index}
}(i)
}
go func() {
wg.Wait()
close(results)
close(errors)
}()
users := make([]database.User, count)
completed := 0
for {
select {
case result, ok := <-results:
if !ok {
return users, nil
}
users[result.index] = result.user
completed++
if progress != nil {
progress.Update(completed)
}
case err := <-errors:
if err != nil {
return nil, err
}
case <-ctx.Done():
return nil, fmt.Errorf("timeout creating users: %w", ctx.Err())
}
}
} }
func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepository, authorID uint, count int, progress *ProgressIndicator) ([]database.Post, error) { func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepository, authorID uint, count int, progress *ProgressIndicator) ([]database.Post, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout) ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel() defer cancel()
results := make(chan postResult, count) return processInParallel(ctx, p.maxWorkers, count,
func(index int) (database.Post, error) {
return p.createSinglePost(postRepo, authorID, index)
},
"create post",
progress,
)
}
func processItemsInParallel[T any, R any](
ctx context.Context,
maxWorkers int,
items []T,
processor func(index int, item T) (R, error),
errorPrefix string,
aggregator func(accumulator R, value R) R,
initialValue R,
progress *ProgressIndicator,
) (R, error) {
count := len(items)
results := make(chan indexedResult[R], count)
errors := make(chan error, count) errors := make(chan error, count)
semaphore := make(chan struct{}, p.maxWorkers) semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup var wg sync.WaitGroup
for i := range count { for i, item := range items {
wg.Add(1) wg.Add(1)
go func(index int) { go func(index int, item T) {
defer wg.Done() defer wg.Done()
select { select {
@@ -114,14 +183,14 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
} }
defer func() { <-semaphore }() defer func() { <-semaphore }()
post, err := p.createSinglePost(postRepo, authorID, index+1) value, err := processor(index, item)
if err != nil { if err != nil {
errors <- fmt.Errorf("create post %d: %w", index+1, err) errors <- fmt.Errorf("%s %d: %w", errorPrefix, index+1, err)
return return
} }
results <- postResult{post: post, index: index} results <- indexedResult[R]{value: value, index: index}
}(i) }(i, item)
} }
go func() { go func() {
@@ -130,43 +199,76 @@ func (p *ParallelProcessor) CreatePostsInParallel(postRepo repositories.PostRepo
close(errors) close(errors)
}() }()
posts := make([]database.Post, count) accumulator := initialValue
completed := 0 completed := 0
firstError := make(chan error, 1)
for { go func() {
for err := range errors {
if err != nil {
select {
case firstError <- err:
default:
}
return
}
}
}()
for completed < count {
select { select {
case result, ok := <-results: case result, ok := <-results:
if !ok { if !ok {
return posts, nil return accumulator, nil
} }
posts[result.index] = result.post accumulator = aggregator(accumulator, result.value)
completed++ completed++
if progress != nil { if progress != nil {
progress.Update(completed) progress.Update(completed)
} }
case err := <-errors: case err := <-firstError:
if err != nil { return initialValue, err
return nil, err
}
case <-ctx.Done(): case <-ctx.Done():
return nil, fmt.Errorf("timeout creating posts: %w", ctx.Err()) return initialValue, fmt.Errorf("timeout: %w", ctx.Err())
} }
} }
return accumulator, nil
} }
func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, progress *ProgressIndicator) (int, error) { func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int, progress *ProgressIndicator) (int, error) {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout) ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel() defer cancel()
results := make(chan voteResult, len(posts)) return processItemsInParallel(ctx, p.maxWorkers, posts,
errors := make(chan error, len(posts)) func(index int, post database.Post) (int, error) {
return p.createVotesForPost(voteRepo, users, post, avgVotesPerPost)
},
"create votes for post",
func(acc, val int) int { return acc + val },
0,
progress,
)
}
semaphore := make(chan struct{}, p.maxWorkers) func processItemsInParallelNoResult[T any](
ctx context.Context,
maxWorkers int,
items []T,
processor func(index int, item T) error,
errorFormatter func(index int, item T, err error) error,
progress *ProgressIndicator,
) error {
count := len(items)
errors := make(chan error, count)
completions := make(chan struct{}, count)
semaphore := make(chan struct{}, maxWorkers)
var wg sync.WaitGroup var wg sync.WaitGroup
for i, post := range posts { for i, item := range items {
wg.Add(1) wg.Add(1)
go func(index int, post database.Post) { go func(index int, item T) {
defer wg.Done() defer wg.Done()
select { select {
@@ -177,133 +279,115 @@ func (p *ParallelProcessor) CreateVotesInParallel(voteRepo repositories.VoteRepo
} }
defer func() { <-semaphore }() defer func() { <-semaphore }()
votes, err := p.createVotesForPost(voteRepo, users, post, avgVotesPerPost) err := processor(index, item)
if err != nil { if err != nil {
errors <- fmt.Errorf("create votes for post %d: %w", post.ID, err) if errorFormatter != nil {
errors <- errorFormatter(index, item, err)
} else {
errors <- fmt.Errorf("process item %d: %w", index+1, err)
}
return return
} }
results <- voteResult{votes: votes, index: index} completions <- struct{}{}
}(i, post) }(i, item)
} }
go func() { go func() {
wg.Wait() wg.Wait()
close(results)
close(errors) close(errors)
close(completions)
}() }()
totalVotes := 0
completed := 0 completed := 0
firstError := make(chan error, 1)
for { go func() {
for err := range errors {
if err != nil {
select { select {
case result, ok := <-results: case firstError <- err:
if !ok { default:
return totalVotes, nil }
return
}
}
}()
for completed < count {
select {
case _, ok := <-completions:
if !ok {
return nil
} }
totalVotes += result.votes
completed++ completed++
if progress != nil { if progress != nil {
progress.Update(completed) progress.Update(completed)
} }
case err := <-errors: case err := <-firstError:
if err != nil {
return 0, err
}
case <-ctx.Done():
return 0, fmt.Errorf("timeout creating votes: %w", ctx.Err())
}
}
}
func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
defer cancel()
errors := make(chan error, len(posts))
semaphore := make(chan struct{}, p.maxWorkers)
var wg sync.WaitGroup
for i, post := range posts {
wg.Add(1)
go func(index int, post database.Post) {
defer wg.Done()
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
errors <- ctx.Err()
return
}
defer func() { <-semaphore }()
err := p.updateSinglePostScore(postRepo, voteRepo, post)
if err != nil {
errors <- fmt.Errorf("update post %d scores: %w", post.ID, err)
return
}
if progress != nil {
progress.Update(index + 1)
}
}(i, post)
}
go func() {
wg.Wait()
close(errors)
}()
for err := range errors {
if err != nil {
return err return err
case <-ctx.Done():
return fmt.Errorf("timeout: %w", ctx.Err())
} }
} }
return nil return nil
} }
type userResult struct { func (p *ParallelProcessor) UpdatePostScoresInParallel(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post, progress *ProgressIndicator) error {
user database.User ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
index int defer cancel()
return processItemsInParallelNoResult(ctx, p.maxWorkers, posts,
func(index int, post database.Post) error {
return p.updateSinglePostScore(postRepo, voteRepo, post)
},
func(index int, post database.Post, err error) error {
return fmt.Errorf("update post %d scores: %w", post.ID, err)
},
progress,
)
} }
type postResult struct { func (p *ParallelProcessor) generateRandomIdentifier() string {
post database.Post const length = 12
index int const chars = "abcdefghijklmnopqrstuvwxyz0123456789"
identifier := make([]byte, length)
p.randMu.Lock()
for i := range identifier {
identifier[i] = chars[p.randSource.Intn(len(chars))]
} }
p.randMu.Unlock()
type voteResult struct { return string(identifier)
votes int
index int
} }
func (p *ParallelProcessor) createSingleUser(userRepo repositories.UserRepository, index int) (database.User, error) { func (p *ParallelProcessor) createSingleUser(userRepo repositories.UserRepository, index int) (database.User, error) {
username := fmt.Sprintf("user_%d", index) randomID := p.generateRandomIdentifier()
email := fmt.Sprintf("user_%d@goyco.local", index) username := fmt.Sprintf("user_%s", randomID)
password := "password123" email := fmt.Sprintf("user_%s@goyco.local", randomID)
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return database.User{}, fmt.Errorf("hash password: %w", err)
}
const maxRetries = 10
for range maxRetries {
user := &database.User{ user := &database.User{
Username: username, Username: username,
Email: email, Email: email,
Password: string(hashedPassword), Password: p.passwordHash,
EmailVerified: true, EmailVerified: true,
} }
if err := userRepo.Create(user); err != nil { if err := userRepo.Create(user); err != nil {
return database.User{}, fmt.Errorf("create user: %w", err) randomID = p.generateRandomIdentifier()
username = fmt.Sprintf("user_%s", randomID)
email = fmt.Sprintf("user_%s@goyco.local", randomID)
continue
} }
return *user, nil return *user, nil
} }
return database.User{}, fmt.Errorf("failed to create user after %d attempts", maxRetries)
}
func (p *ParallelProcessor) createSinglePost(postRepo repositories.PostRepository, authorID uint, index int) (database.Post, error) { func (p *ParallelProcessor) createSinglePost(postRepo repositories.PostRepository, authorID uint, index int) (database.Post, error) {
sampleTitles := []string{ sampleTitles := []string{
"Amazing JavaScript Framework", "Amazing JavaScript Framework",
@@ -347,11 +431,14 @@ func (p *ParallelProcessor) createSinglePost(postRepo repositories.PostRepositor
} }
domain := sampleDomains[index%len(sampleDomains)] domain := sampleDomains[index%len(sampleDomains)]
path := generateRandomPath() randomID := p.generateRandomIdentifier()
path := fmt.Sprintf("/article/%s", randomID)
url := fmt.Sprintf("https://%s%s", domain, path) url := fmt.Sprintf("https://%s%s", domain, path)
content := fmt.Sprintf("Autogenerated seed post #%d\n\nThis is sample content for testing purposes. The post discusses %s and provides valuable insights.", index, title) content := fmt.Sprintf("Autogenerated seed post #%d\n\nThis is sample content for testing purposes. The post discusses %s and provides valuable insights.", index, title)
const maxRetries = 10
for range maxRetries {
post := &database.Post{ post := &database.Post{
Title: title, Title: title,
URL: url, URL: url,
@@ -363,38 +450,50 @@ func (p *ParallelProcessor) createSinglePost(postRepo repositories.PostRepositor
} }
if err := postRepo.Create(post); err != nil { if err := postRepo.Create(post); err != nil {
return database.Post{}, fmt.Errorf("create post: %w", err) randomID = p.generateRandomIdentifier()
path = fmt.Sprintf("/article/%s", randomID)
url = fmt.Sprintf("https://%s%s", domain, path)
continue
} }
return *post, nil return *post, nil
} }
return database.Post{}, fmt.Errorf("failed to create post after %d attempts", maxRetries)
}
func (p *ParallelProcessor) createVotesForPost(voteRepo repositories.VoteRepository, users []database.User, post database.Post, avgVotesPerPost int) (int, error) { func (p *ParallelProcessor) createVotesForPost(voteRepo repositories.VoteRepository, users []database.User, post database.Post, avgVotesPerPost int) (int, error) {
voteCount, _ := rand.Int(rand.Reader, big.NewInt(int64(avgVotesPerPost*2)+1)) p.randMu.Lock()
numVotes := int(voteCount.Int64()) numVotes := p.randSource.Intn(avgVotesPerPost*2 + 1)
p.randMu.Unlock()
if numVotes == 0 && avgVotesPerPost > 0 { if numVotes == 0 && avgVotesPerPost > 0 {
chance, _ := rand.Int(rand.Reader, big.NewInt(5)) p.randMu.Lock()
if chance.Int64() > 0 { if p.randSource.Intn(5) > 0 {
numVotes = 1 numVotes = 1
} }
p.randMu.Unlock()
} }
totalVotes := 0 totalVotes := 0
usedUsers := make(map[uint]bool) usedUsers := make(map[uint]bool)
for i := 0; i < numVotes && len(usedUsers) < len(users); i++ { for i := 0; i < numVotes && len(usedUsers) < len(users); i++ {
userIdx, _ := rand.Int(rand.Reader, big.NewInt(int64(len(users)))) p.randMu.Lock()
user := users[userIdx.Int64()] userIdx := p.randSource.Intn(len(users))
p.randMu.Unlock()
user := users[userIdx]
if usedUsers[user.ID] { if usedUsers[user.ID] {
continue continue
} }
usedUsers[user.ID] = true usedUsers[user.ID] = true
voteTypeInt, _ := rand.Int(rand.Reader, big.NewInt(10)) p.randMu.Lock()
voteTypeInt := p.randSource.Intn(10)
p.randMu.Unlock()
var voteType database.VoteType var voteType database.VoteType
if voteTypeInt.Int64() < 7 { if voteTypeInt < 7 {
voteType = database.VoteUp voteType = database.VoteUp
} else { } else {
voteType = database.VoteDown voteType = database.VoteDown
@@ -406,8 +505,8 @@ func (p *ParallelProcessor) createVotesForPost(voteRepo repositories.VoteReposit
Type: voteType, Type: voteType,
} }
if err := voteRepo.Create(vote); err != nil { if err := voteRepo.CreateOrUpdate(vote); err != nil {
return totalVotes, fmt.Errorf("create vote: %w", err) return totalVotes, fmt.Errorf("create or update vote: %w", err)
} }
totalVotes++ totalVotes++

View File

@@ -2,15 +2,15 @@ package commands_test
import ( import (
"errors" "errors"
"fmt"
"sync" "sync"
"testing" "testing"
"golang.org/x/crypto/bcrypt"
"goyco/cmd/goyco/commands" "goyco/cmd/goyco/commands"
"goyco/internal/database" "goyco/internal/database"
"goyco/internal/repositories" "goyco/internal/repositories"
"goyco/internal/testutils" "goyco/internal/testutils"
"golang.org/x/crypto/bcrypt"
) )
func TestParallelProcessor_CreateUsersInParallel(t *testing.T) { func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
@@ -25,7 +25,7 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{ {
name: "creates users with deterministic fields", name: "creates users with required fields",
count: successCount, count: successCount,
repoFactory: func() repositories.UserRepository { repoFactory: func() repositories.UserRepository {
base := testutils.NewMockUserRepository() base := testutils.NewMockUserRepository()
@@ -37,14 +37,24 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
if len(got) != successCount { if len(got) != successCount {
t.Fatalf("expected %d users, got %d", successCount, len(got)) t.Fatalf("expected %d users, got %d", successCount, len(got))
} }
usernames := make(map[string]bool)
for i, user := range got { for i, user := range got {
expectedUsername := fmt.Sprintf("user_%d", i+1) if user.Username == "" {
expectedEmail := fmt.Sprintf("user_%d@goyco.local", i+1) t.Errorf("user %d expected non-empty username", i)
if user.Username != expectedUsername {
t.Errorf("user %d username mismatch: got %q want %q", i, user.Username, expectedUsername)
} }
if user.Email != expectedEmail { if len(user.Username) < 6 || user.Username[:5] != "user_" {
t.Errorf("user %d email mismatch: got %q want %q", i, user.Email, expectedEmail) t.Errorf("user %d username should start with 'user_', got %q", i, user.Username)
}
if usernames[user.Username] {
t.Errorf("user %d duplicate username: %q", i, user.Username)
}
usernames[user.Username] = true
if user.Email == "" {
t.Errorf("user %d expected non-empty email", i)
}
if len(user.Email) < 20 || user.Email[:5] != "user_" || user.Email[len(user.Email)-12:] != "@goyco.local" {
t.Errorf("user %d email should match pattern 'user_*@goyco.local', got %q", i, user.Email)
} }
if !user.EmailVerified { if !user.EmailVerified {
t.Errorf("user %d expected EmailVerified to be true", i) t.Errorf("user %d expected EmailVerified to be true", i)
@@ -83,6 +93,11 @@ func TestParallelProcessor_CreateUsersInParallel(t *testing.T) {
t.Parallel() t.Parallel()
repo := tt.repoFactory() repo := tt.repoFactory()
p := commands.NewParallelProcessor() p := commands.NewParallelProcessor()
passwordHash, err := bcrypt.GenerateFromPassword([]byte("password123"), bcrypt.DefaultCost)
if err != nil {
t.Fatalf("failed to generate password hash: %v", err)
}
p.SetPasswordHash(string(passwordHash))
got, gotErr := p.CreateUsersInParallel(repo, tt.count, tt.progress) got, gotErr := p.CreateUsersInParallel(repo, tt.count, tt.progress)
if gotErr != nil { if gotErr != nil {
if !tt.wantErr { if !tt.wantErr {

View File

@@ -1,20 +1,40 @@
package commands package commands
import ( import (
"crypto/rand" cryptoRand "crypto/rand"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"math/big" "math/rand"
"os" "os"
"sync"
"time"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"goyco/internal/config" "goyco/internal/config"
"goyco/internal/database" "goyco/internal/database"
"goyco/internal/repositories" "goyco/internal/repositories"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
) )
var (
seedRandSource *rand.Rand
seedRandOnce sync.Once
)
func initSeedRand() {
seedRandOnce.Do(func() {
seed := time.Now().UnixNano()
seedBytes := make([]byte, 8)
if _, err := cryptoRand.Read(seedBytes); err == nil {
seed = int64(seedBytes[0])<<56 | int64(seedBytes[1])<<48 | int64(seedBytes[2])<<40 | int64(seedBytes[3])<<32 |
int64(seedBytes[4])<<24 | int64(seedBytes[5])<<16 | int64(seedBytes[6])<<8 | int64(seedBytes[7])
}
seedRandSource = rand.New(rand.NewSource(seed))
})
}
func HandleSeedCommand(cfg *config.Config, name string, args []string) error { func HandleSeedCommand(cfg *config.Config, name string, args []string) error {
fs := newFlagSet(name, printSeedUsage) fs := newFlagSet(name, printSeedUsage)
if err := parseCommand(fs, args, name); err != nil { if err := parseCommand(fs, args, name); err != nil {
@@ -69,26 +89,58 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
return err return err
} }
originalUsers := *numUsers
originalPosts := *numPosts
originalVotesPerPost := *votesPerPost
if *numUsers < 0 { if *numUsers < 0 {
return fmt.Errorf("invalid value for --users: %d (must be >= 0)", *numUsers) if !IsJSONOutput() {
fmt.Fprintf(os.Stderr, "Warning: --users value %d is negative, clamping to 0\n", *numUsers)
} }
*numUsers = 0
}
if *numPosts <= 0 { if *numPosts <= 0 {
return fmt.Errorf("invalid value for --posts: %d (must be > 0)", *numPosts) if !IsJSONOutput() {
fmt.Fprintf(os.Stderr, "Warning: --posts value %d is too low, clamping to 1\n", *numPosts)
} }
*numPosts = 1
}
if *votesPerPost < 0 { if *votesPerPost < 0 {
return fmt.Errorf("invalid value for --votes-per-post: %d (must be >= 0)", *votesPerPost) if !IsJSONOutput() {
fmt.Fprintf(os.Stderr, "Warning: --votes-per-post value %d is negative, clamping to 0\n", *votesPerPost)
}
*votesPerPost = 0
}
if !IsJSONOutput() && (originalUsers != *numUsers || originalPosts != *numPosts || originalVotesPerPost != *votesPerPost) {
fmt.Fprintf(os.Stderr, "Using clamped values: --users=%d --posts=%d --votes-per-post=%d\n", *numUsers, *numPosts, *votesPerPost)
} }
if !IsJSONOutput() { if !IsJSONOutput() {
fmt.Println("Starting database seeding...") fmt.Println("Starting database seeding...")
} }
seedPassword := "seed-password"
userPassword := "password123"
seedPasswordHash, err := bcrypt.GenerateFromPassword([]byte(seedPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("precompute seed password hash: %w", err)
}
userPasswordHash, err := bcrypt.GenerateFromPassword([]byte(userPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("precompute user password hash: %w", err)
}
spinner := NewSpinner("Creating seed user") spinner := NewSpinner("Creating seed user")
if !IsJSONOutput() { if !IsJSONOutput() {
spinner.Spin() spinner.Spin()
} }
seedUser, err := ensureSeedUser(userRepo) seedUser, err := ensureSeedUser(userRepo, string(seedPasswordHash))
if err != nil { if err != nil {
if !IsJSONOutput() { if !IsJSONOutput() {
spinner.Complete() spinner.Complete()
@@ -101,6 +153,7 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
} }
processor := NewParallelProcessor() processor := NewParallelProcessor()
processor.SetPasswordHash(string(userPasswordHash))
var progress *ProgressIndicator var progress *ProgressIndicator
if !IsJSONOutput() && *numUsers > 0 { if !IsJSONOutput() && *numUsers > 0 {
@@ -149,13 +202,17 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
progress.Complete() progress.Complete()
} }
if err := validateSeedConsistency(voteRepo, allUsers, posts); err != nil {
return fmt.Errorf("seed consistency validation failed: %w", err)
}
if IsJSONOutput() { if IsJSONOutput() {
outputJSON(map[string]interface{}{ outputJSON(map[string]any{
"action": "seed_completed", "action": "seed_completed",
"users": len(allUsers), "users": len(allUsers),
"posts": len(posts), "posts": len(posts),
"votes": votes, "votes": votes,
"seed_user": map[string]interface{}{ "seed_user": map[string]any{
"id": seedUser.ID, "id": seedUser.ID,
"username": seedUser.Username, "username": seedUser.Username,
}, },
@@ -168,233 +225,95 @@ func seedDatabase(userRepo repositories.UserRepository, postRepo repositories.Po
return nil return nil
} }
func ensureSeedUser(userRepo repositories.UserRepository) (*database.User, error) { const (
seedUsername := "seed_admin" seedUsername = "seed_admin"
seedEmail := "seed_admin@goyco.local" seedEmail = "seed_admin@goyco.local"
seedPassword := "seed-password" )
user, err := userRepo.GetByEmail(seedEmail) func ensureSeedUser(userRepo repositories.UserRepository, passwordHash string) (*database.User, error) {
if err == nil { if user, err := userRepo.GetByUsername(seedUsername); err == nil {
return user, nil return user, nil
} }
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(seedPassword), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("hash password: %w", err)
}
user = &database.User{
Username: seedUsername,
Email: seedEmail,
Password: string(hashedPassword),
EmailVerified: true,
}
if err := userRepo.Create(user); err != nil {
return nil, fmt.Errorf("create seed user: %w", err)
}
return user, nil
}
func createRandomUsers(userRepo repositories.UserRepository, count int) ([]database.User, error) {
var users []database.User
for i := range count {
username := fmt.Sprintf("user_%d", i+1)
email := fmt.Sprintf("user_%d@goyco.local", i+1)
password := "password123"
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("hash password for user %d: %w", i+1, err)
}
user := &database.User{ user := &database.User{
Username: username, Username: seedUsername,
Email: email, Email: seedEmail,
Password: string(hashedPassword), Password: passwordHash,
EmailVerified: true, EmailVerified: true,
} }
if err := userRepo.Create(user); err != nil { if err := userRepo.Create(user); err != nil {
return nil, fmt.Errorf("create user %d: %w", i+1, err) return nil, fmt.Errorf("failed to create seed user: %w", err)
} }
users = append(users, *user) return user, nil
} }
return users, nil func getVoteCounts(voteRepo repositories.VoteRepository, postID uint) (int, int, error) {
return voteRepo.GetVoteCountsByPostID(postID)
} }
func createRandomPosts(postRepo repositories.PostRepository, authorID uint, count int) ([]database.Post, error) { func validateSeedConsistency(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post) error {
var posts []database.Post userIDSet := make(map[uint]struct{}, len(users))
for _, user := range users {
sampleTitles := []string{ userIDSet[user.ID] = struct{}{}
"Amazing JavaScript Framework",
"Python Best Practices",
"Go Performance Tips",
"Database Optimization",
"Web Security Guide",
"Machine Learning Basics",
"Cloud Architecture",
"DevOps Automation",
"API Design Patterns",
"Frontend Optimization",
"Backend Scaling",
"Container Orchestration",
"Microservices Architecture",
"Testing Strategies",
"Code Review Process",
"Version Control Best Practices",
"Continuous Integration",
"Monitoring and Alerting",
"Error Handling Patterns",
"Data Structures Explained",
} }
sampleDomains := []string{ postIDSet := make(map[uint]struct{}, len(posts))
"example.com", for _, post := range posts {
"techblog.org", postIDSet[post.ID] = struct{}{}
"devguide.net",
"programming.io",
"codeexamples.com",
"tutorialhub.org",
"bestpractices.dev",
"learnprogramming.net",
"codingtips.org",
"softwareengineering.com",
} }
for i := range count {
title := sampleTitles[i%len(sampleTitles)]
if i >= len(sampleTitles) {
title = fmt.Sprintf("%s - Part %d", title, (i/len(sampleTitles))+1)
}
domain := sampleDomains[i%len(sampleDomains)]
path := generateRandomPath()
url := fmt.Sprintf("https://%s%s", domain, path)
content := fmt.Sprintf("Autogenerated seed post #%d\n\nThis is sample content for testing purposes. The post discusses %s and provides valuable insights.", i+1, title)
post := &database.Post{
Title: title,
URL: url,
Content: content,
AuthorID: &authorID,
UpVotes: 0,
DownVotes: 0,
Score: 0,
}
if err := postRepo.Create(post); err != nil {
return nil, fmt.Errorf("create post %d: %w", i+1, err)
}
posts = append(posts, *post)
}
return posts, nil
}
func generateRandomPath() string {
pathLength, _ := rand.Int(rand.Reader, big.NewInt(20))
path := "/article/"
for i := int64(0); i < pathLength.Int64()+5; i++ {
randomChar, _ := rand.Int(rand.Reader, big.NewInt(26))
path += string(rune('a' + randomChar.Int64()))
}
return path
}
func createRandomVotes(voteRepo repositories.VoteRepository, users []database.User, posts []database.Post, avgVotesPerPost int) (int, error) {
totalVotes := 0
for _, post := range posts { for _, post := range posts {
voteCount, _ := rand.Int(rand.Reader, big.NewInt(int64(avgVotesPerPost*2)+1)) if err := validatePost(post, userIDSet); err != nil {
numVotes := int(voteCount.Int64()) return err
if numVotes == 0 && avgVotesPerPost > 0 {
chance, _ := rand.Int(rand.Reader, big.NewInt(5))
if chance.Int64() > 0 {
numVotes = 1
}
} }
usedUsers := make(map[uint]bool) votes, err := voteRepo.GetByPostID(post.ID)
for i := 0; i < numVotes && len(usedUsers) < len(users); i++ {
userIdx, _ := rand.Int(rand.Reader, big.NewInt(int64(len(users))))
user := users[userIdx.Int64()]
if usedUsers[user.ID] {
continue
}
usedUsers[user.ID] = true
voteTypeInt, _ := rand.Int(rand.Reader, big.NewInt(10))
var voteType database.VoteType
if voteTypeInt.Int64() < 7 {
voteType = database.VoteUp
} else {
voteType = database.VoteDown
}
vote := &database.Vote{
UserID: &user.ID,
PostID: post.ID,
Type: voteType,
}
if err := voteRepo.Create(vote); err != nil {
return totalVotes, fmt.Errorf("create vote for post %d: %w", post.ID, err)
}
totalVotes++
}
}
return totalVotes, nil
}
func updatePostScores(postRepo repositories.PostRepository, voteRepo repositories.VoteRepository, posts []database.Post) error {
for _, post := range posts {
upVotes, downVotes, err := getVoteCounts(voteRepo, post.ID)
if err != nil { if err != nil {
return fmt.Errorf("get vote counts for post %d: %w", post.ID, err) return fmt.Errorf("failed to retrieve votes for post %d: %w", post.ID, err)
} }
post.UpVotes = upVotes if err := validateVotesForPost(post.ID, votes, userIDSet, postIDSet); err != nil {
post.DownVotes = downVotes return err
post.Score = upVotes - downVotes
if err := postRepo.Update(&post); err != nil {
return fmt.Errorf("update post %d scores: %w", post.ID, err)
} }
} }
return nil return nil
} }
func getVoteCounts(voteRepo repositories.VoteRepository, postID uint) (int, int, error) { func validatePost(post database.Post, userIDSet map[uint]struct{}) error {
votes, err := voteRepo.GetByPostID(postID) if post.AuthorID == nil {
if err != nil { return fmt.Errorf("post %d has no author ID", post.ID)
return 0, 0, err
} }
upVotes := 0 if _, exists := userIDSet[*post.AuthorID]; !exists {
downVotes := 0 return fmt.Errorf("post %d references non-existent author ID %d", post.ID, *post.AuthorID)
}
return nil
}
func validateVotesForPost(postID uint, votes []database.Vote, userIDSet map[uint]struct{}, postIDSet map[uint]struct{}) error {
for _, vote := range votes { for _, vote := range votes {
switch vote.Type { if vote.PostID != postID {
case database.VoteUp: return fmt.Errorf("vote %d references post ID %d but was retrieved for post %d", vote.ID, vote.PostID, postID)
upVotes++ }
case database.VoteDown:
downVotes++ if _, exists := postIDSet[vote.PostID]; !exists {
return fmt.Errorf("vote %d references non-existent post ID %d", vote.ID, vote.PostID)
}
if vote.UserID != nil {
if _, exists := userIDSet[*vote.UserID]; !exists {
return fmt.Errorf("vote %d references non-existent user ID %d", vote.ID, *vote.UserID)
} }
} }
return upVotes, downVotes, nil if vote.Type != database.VoteUp && vote.Type != database.VoteDown {
return fmt.Errorf("vote %d has invalid type %q", vote.ID, vote.Type)
}
}
return nil
} }

View File

@@ -1,13 +1,16 @@
package commands package commands
import ( import (
"fmt"
"strings"
"testing" "testing"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"goyco/internal/database" "goyco/internal/database"
"goyco/internal/repositories" "goyco/internal/repositories"
"goyco/internal/testutils" "goyco/internal/testutils"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
) )
func TestSeedCommand(t *testing.T) { func TestSeedCommand(t *testing.T) {
@@ -21,39 +24,64 @@ func TestSeedCommand(t *testing.T) {
t.Fatalf("Failed to migrate database: %v", err) t.Fatalf("Failed to migrate database: %v", err)
} }
err = db.Transaction(func(tx *gorm.DB) error {
userRepo := repositories.NewUserRepository(db).WithTx(tx)
postRepo := repositories.NewPostRepository(db).WithTx(tx)
voteRepo := repositories.NewVoteRepository(db).WithTx(tx)
return seedDatabase(userRepo, postRepo, voteRepo, []string{"--users", "2", "--posts", "5", "--votes-per-post", "3"})
})
if err != nil {
t.Fatalf("Failed to seed database: %v", err)
}
userRepo := repositories.NewUserRepository(db) userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db) postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db) voteRepo := repositories.NewVoteRepository(db)
seedUser, err := ensureSeedUser(userRepo) users, err := userRepo.GetAll(100, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to ensure seed user: %v", err) t.Fatalf("Failed to get users: %v", err)
}
seedUserCount := 0
var seedUser *database.User
regularUserCount := 0
for i := range users {
if users[i].Username == "seed_admin" {
seedUserCount++
seedUser = &users[i]
} else if strings.HasPrefix(users[i].Username, "user_") {
regularUserCount++
}
}
if seedUserCount != 1 {
t.Errorf("Expected 1 seed user, got %d", seedUserCount)
}
if seedUser == nil {
t.Fatal("Expected seed user to be created")
} }
if seedUser.Username != "seed_admin" { if seedUser.Username != "seed_admin" {
t.Errorf("Expected username 'seed_admin', got '%s'", seedUser.Username) t.Errorf("Expected username to be 'seed_admin', got '%s'", seedUser.Username)
} }
if seedUser.Email != "seed_admin@goyco.local" { if seedUser.Email != "seed_admin@goyco.local" {
t.Errorf("Expected email 'seed_admin@goyco.local', got '%s'", seedUser.Email) t.Errorf("Expected email to be 'seed_admin@goyco.local', got '%s'", seedUser.Email)
} }
if !seedUser.EmailVerified { if !seedUser.EmailVerified {
t.Error("Expected seed user to be email verified") t.Error("Expected seed user to be email verified")
} }
users, err := createRandomUsers(userRepo, 2) if regularUserCount != 2 {
if err != nil { t.Errorf("Expected 2 regular users, got %d", regularUserCount)
t.Fatalf("Failed to create random users: %v", err)
} }
if len(users) != 2 { posts, err := postRepo.GetAll(100, 0)
t.Errorf("Expected 2 users, got %d", len(users))
}
posts, err := createRandomPosts(postRepo, seedUser.ID, 5)
if err != nil { if err != nil {
t.Fatalf("Failed to create random posts: %v", err) t.Fatalf("Failed to get posts: %v", err)
} }
if len(posts) != 5 { if len(posts) != 5 {
@@ -70,39 +98,49 @@ func TestSeedCommand(t *testing.T) {
if post.AuthorID == nil || *post.AuthorID != seedUser.ID { if post.AuthorID == nil || *post.AuthorID != seedUser.ID {
t.Errorf("Post %d has wrong author ID: expected %d, got %v", i, seedUser.ID, post.AuthorID) t.Errorf("Post %d has wrong author ID: expected %d, got %v", i, seedUser.ID, post.AuthorID)
} }
expectedScore := post.UpVotes - post.DownVotes
if post.Score != expectedScore {
t.Errorf("Post %d has incorrect score: expected %d, got %d", i, expectedScore, post.Score)
}
} }
allUsers := append([]database.User{*seedUser}, users...) voteCount, err := voteRepo.Count()
votes, err := createRandomVotes(voteRepo, allUsers, posts, 3)
if err != nil { if err != nil {
t.Fatalf("Failed to create random votes: %v", err) t.Fatalf("Failed to count votes: %v", err)
} }
if votes == 0 { if voteCount == 0 {
t.Error("Expected some votes to be created") t.Error("Expected some votes to be created")
} }
err = updatePostScores(postRepo, voteRepo, posts) for _, post := range posts {
postVotes, err := voteRepo.GetByPostID(post.ID)
if err != nil { if err != nil {
t.Fatalf("Failed to update post scores: %v", err) t.Errorf("Failed to get votes for post %d: %v", post.ID, err)
}
for i, post := range posts {
updatedPost, err := postRepo.GetByID(post.ID)
if err != nil {
t.Errorf("Failed to get updated post %d: %v", i, err)
continue continue
} }
expectedScore := updatedPost.UpVotes - updatedPost.DownVotes for _, vote := range postVotes {
if updatedPost.Score != expectedScore { if vote.PostID != post.ID {
t.Errorf("Post %d has incorrect score: expected %d, got %d", i, expectedScore, updatedPost.Score) t.Errorf("Vote has wrong post ID: expected %d, got %d", post.ID, vote.PostID)
}
if vote.UserID == nil {
t.Error("Vote has nil user ID")
}
} }
} }
} }
func TestGenerateRandomPath(t *testing.T) { func TestGenerateRandomPath(t *testing.T) {
path := generateRandomPath() initSeedRand()
pathLength := seedRandSource.Intn(20)
path := "/article/"
for i := 0; i < pathLength+5; i++ {
randomChar := seedRandSource.Intn(26)
path += string(rune('a' + randomChar))
}
if path == "" { if path == "" {
t.Error("Generated path should not be empty") t.Error("Generated path should not be empty")
@@ -112,7 +150,14 @@ func TestGenerateRandomPath(t *testing.T) {
t.Errorf("Generated path too short: %s", path) t.Errorf("Generated path too short: %s", path)
} }
secondPath := generateRandomPath() initSeedRand()
secondPathLength := seedRandSource.Intn(20)
secondPath := "/article/"
for i := 0; i < secondPathLength+5; i++ {
randomChar := seedRandSource.Intn(26)
secondPath += string(rune('a' + randomChar))
}
if path == secondPath { if path == secondPath {
t.Error("Generated paths should be different") t.Error("Generated paths should be different")
} }
@@ -179,47 +224,35 @@ func TestSeedDatabaseFlagParsing(t *testing.T) {
} }
}) })
t.Run("negative users value", func(t *testing.T) { t.Run("negative users value is clamped", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--users", "-1"}) err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--users", "-1", "--posts", "1"})
if err == nil { if err != nil {
t.Error("expected error for negative users value") t.Errorf("negative users should be clamped, not rejected. Got error: %v", err)
}
if err != nil && err.Error() != "invalid value for --users: -1 (must be >= 0)" {
t.Errorf("expected specific error message, got: %v", err)
} }
}) })
t.Run("negative posts value", func(t *testing.T) { t.Run("negative posts value is clamped", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--posts", "-5"}) err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--posts", "-5"})
if err == nil { if err != nil {
t.Error("expected error for negative posts value") t.Errorf("negative posts should be clamped, not rejected. Got error: %v", err)
}
if err != nil && err.Error() != "invalid value for --posts: -5 (must be > 0)" {
t.Errorf("expected specific error message, got: %v", err)
} }
}) })
t.Run("zero posts value", func(t *testing.T) { t.Run("zero posts value is clamped", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--posts", "0"}) err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--posts", "0"})
if err == nil { if err != nil {
t.Error("expected error for zero posts value") t.Errorf("zero posts should be clamped, not rejected. Got error: %v", err)
}
if err != nil && err.Error() != "invalid value for --posts: 0 (must be > 0)" {
t.Errorf("expected specific error message, got: %v", err)
} }
}) })
t.Run("negative votes-per-post value", func(t *testing.T) { t.Run("negative votes-per-post value is clamped", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--votes-per-post", "-10"}) err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--votes-per-post", "-10", "--posts", "1"})
if err == nil { if err != nil {
t.Error("expected error for negative votes-per-post value") t.Errorf("negative votes-per-post should be clamped, not rejected. Got error: %v", err)
}
if err != nil && err.Error() != "invalid value for --votes-per-post: -10 (must be >= 0)" {
t.Errorf("expected specific error message, got: %v", err)
} }
}) })
@@ -239,3 +272,262 @@ func TestSeedDatabaseFlagParsing(t *testing.T) {
} }
}) })
} }
func TestSeedCommandIdempotency(t *testing.T) {
dbName := fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
db, err := gorm.Open(sqlite.Open(dbName), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to connect to database: %v", err)
}
err = db.AutoMigrate(&database.User{}, &database.Post{}, &database.Vote{})
if err != nil {
t.Fatalf("Failed to migrate database: %v", err)
}
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
t.Run("first run creates seed user", func(t *testing.T) {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--users", "1", "--posts", "2"})
if err != nil {
t.Fatalf("First seed run failed: %v", err)
}
users, err := userRepo.GetAll(100, 0)
if err != nil {
t.Fatalf("Failed to get users: %v", err)
}
seedUserCount := 0
for _, user := range users {
if user.Username == "seed_admin" {
seedUserCount++
}
}
if seedUserCount != 1 {
t.Errorf("Expected exactly 1 seed user, got %d", seedUserCount)
}
})
t.Run("second run reuses seed user", func(t *testing.T) {
usersBefore, _ := userRepo.GetAll(100, 0)
seedUserBefore := findSeedUser(usersBefore)
if seedUserBefore == nil {
t.Fatal("No seed user found before second run")
}
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--users", "1", "--posts", "2"})
if err != nil {
t.Fatalf("Second seed run failed: %v", err)
}
usersAfter, _ := userRepo.GetAll(100, 0)
seedUserAfter := findSeedUser(usersAfter)
if seedUserAfter == nil {
t.Fatal("Seed user not found after second run")
}
if seedUserBefore.ID != seedUserAfter.ID {
t.Errorf("Expected seed user to be reused (ID %d), but got different user (ID %d)", seedUserBefore.ID, seedUserAfter.ID)
}
})
t.Run("database remains consistent after multiple runs", func(t *testing.T) {
for i := range 2 {
err := seedDatabase(userRepo, postRepo, voteRepo, []string{"--users", "0", "--posts", "1"})
if err != nil {
t.Fatalf("Seed run %d failed: %v", i+1, err)
}
}
users, _ := userRepo.GetAll(100, 0)
posts, _ := postRepo.GetAll(100, 0)
for _, post := range posts {
if post.AuthorID == nil {
t.Errorf("Post %d has no author", post.ID)
continue
}
authorExists := false
for _, user := range users {
if user.ID == *post.AuthorID {
authorExists = true
break
}
}
if !authorExists {
t.Errorf("Post %d has invalid author ID %d", post.ID, *post.AuthorID)
}
votes, _ := voteRepo.GetByPostID(post.ID)
for _, vote := range votes {
if vote.UserID != nil {
userExists := false
for _, user := range users {
if user.ID == *vote.UserID {
userExists = true
break
}
}
if !userExists {
t.Errorf("Vote %d has invalid user ID %d", vote.ID, *vote.UserID)
}
}
}
}
})
}
func findSeedUser(users []database.User) *database.User {
for i := range users {
if users[i].Username == "seed_admin" {
return &users[i]
}
}
return nil
}
func TestSeedCommandTransactionRollback(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to connect to database: %v", err)
}
err = db.AutoMigrate(&database.User{}, &database.Post{}, &database.Vote{})
if err != nil {
t.Fatalf("Failed to migrate database: %v", err)
}
userRepo := repositories.NewUserRepository(db)
postRepo := repositories.NewPostRepository(db)
voteRepo := repositories.NewVoteRepository(db)
t.Run("transaction rolls back on failure", func(t *testing.T) {
initialUserCount, _ := userRepo.Count()
initialPostCount, _ := postRepo.Count()
initialVoteCount, _ := voteRepo.Count()
err := db.Transaction(func(tx *gorm.DB) error {
txUserRepo := userRepo.WithTx(tx)
txPostRepo := postRepo.WithTx(tx)
txVoteRepo := voteRepo.WithTx(tx)
err := seedDatabase(txUserRepo, txPostRepo, txVoteRepo, []string{"--users", "2", "--posts", "3"})
if err != nil {
return err
}
return fmt.Errorf("simulated failure")
})
if err == nil {
t.Fatal("Expected transaction to fail")
}
finalUserCount, _ := userRepo.Count()
finalPostCount, _ := postRepo.Count()
finalVoteCount, _ := voteRepo.Count()
if finalUserCount != initialUserCount {
t.Errorf("Expected user count to remain %d after rollback, got %d", initialUserCount, finalUserCount)
}
if finalPostCount != initialPostCount {
t.Errorf("Expected post count to remain %d after rollback, got %d", initialPostCount, finalPostCount)
}
if finalVoteCount != initialVoteCount {
t.Errorf("Expected vote count to remain %d after rollback, got %d", initialVoteCount, finalVoteCount)
}
})
t.Run("transaction commits on success", func(t *testing.T) {
initialUserCount, _ := userRepo.Count()
initialPostCount, _ := postRepo.Count()
err := db.Transaction(func(tx *gorm.DB) error {
txUserRepo := userRepo.WithTx(tx)
txPostRepo := postRepo.WithTx(tx)
txVoteRepo := voteRepo.WithTx(tx)
return seedDatabase(txUserRepo, txPostRepo, txVoteRepo, []string{"--users", "1", "--posts", "1"})
})
if err != nil {
t.Fatalf("Expected transaction to succeed, got error: %v", err)
}
finalUserCount, _ := userRepo.Count()
finalPostCount, _ := postRepo.Count()
expectedUsers := initialUserCount + 2
expectedPosts := initialPostCount + 1
if finalUserCount < expectedUsers {
t.Errorf("Expected at least %d users after commit, got %d", expectedUsers, finalUserCount)
}
if finalPostCount < expectedPosts {
t.Errorf("Expected at least %d posts after commit, got %d", expectedPosts, finalPostCount)
}
})
}
func TestEnsureSeedUser(t *testing.T) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("Failed to connect to database: %v", err)
}
if err := db.AutoMigrate(&database.User{}); err != nil {
t.Fatalf("Failed to migrate database: %v", err)
}
userRepo := repositories.NewUserRepository(db)
passwordHash := "test_password_hash"
firstUser, err := ensureSeedUser(userRepo, passwordHash)
if err != nil {
t.Fatalf("Failed to create seed user: %v", err)
}
if firstUser.Username != "seed_admin" || firstUser.Email != "seed_admin@goyco.local" || firstUser.Password != passwordHash || !firstUser.EmailVerified {
t.Errorf("Invalid seed user: username=%s, email=%s, password matches=%v, emailVerified=%v",
firstUser.Username, firstUser.Email, firstUser.Password == passwordHash, firstUser.EmailVerified)
}
secondUser, err := ensureSeedUser(userRepo, "different_password_hash")
if err != nil {
t.Fatalf("Failed to reuse seed user: %v", err)
}
if firstUser.ID != secondUser.ID {
t.Errorf("Expected same user to be reused (ID %d), got different user (ID %d)", firstUser.ID, secondUser.ID)
}
for i := 0; i < 3; i++ {
if _, err := ensureSeedUser(userRepo, passwordHash); err != nil {
t.Fatalf("Call %d failed: %v", i+1, err)
}
}
users, err := userRepo.GetAll(100, 0)
if err != nil {
t.Fatalf("Failed to get users: %v", err)
}
seedUserCount := 0
for _, user := range users {
if user.Username == "seed_admin" {
seedUserCount++
}
}
if seedUserCount != 1 {
t.Errorf("Expected exactly 1 seed user, got %d", seedUserCount)
}
}

View File

@@ -96,21 +96,21 @@ func TestServerConfigurationFromConfig(t *testing.T) {
testServer := httptest.NewServer(srv.Handler) testServer := httptest.NewServer(srv.Handler)
defer testServer.Close() defer testServer.Close()
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil) request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
resp, err := http.DefaultClient.Do(req) response, err := http.DefaultClient.Do(request)
if err != nil { if err != nil {
t.Fatalf("Failed to make request: %v", err) t.Fatalf("Failed to make request: %v", err)
} }
defer func() { defer func() {
_ = resp.Body.Close() _ = response.Body.Close()
}() }()
if resp.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode) t.Errorf("Expected status 200, got %d", response.StatusCode)
} }
} }
@@ -208,27 +208,27 @@ func TestTLSWiringFromConfig(t *testing.T) {
}, },
} }
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil) request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
resp, err := client.Do(req) response, err := client.Do(request)
if err != nil { if err != nil {
t.Fatalf("Failed to make TLS request: %v", err) t.Fatalf("Failed to make TLS request: %v", err)
} }
defer func() { defer func() {
_ = resp.Body.Close() _ = response.Body.Close()
}() }()
if resp.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 over TLS, got %d", resp.StatusCode) t.Errorf("Expected status 200 over TLS, got %d", response.StatusCode)
} }
if resp.TLS == nil { if response.TLS == nil {
t.Error("Expected TLS connection info to be present in response") t.Error("Expected TLS connection info to be present in response")
} else if resp.TLS.Version < tls.VersionTLS12 { } else if response.TLS.Version < tls.VersionTLS12 {
t.Errorf("Expected TLS version 1.2 or higher, got %x", resp.TLS.Version) t.Errorf("Expected TLS version 1.2 or higher, got %x", response.TLS.Version)
} }
} }
} }
@@ -368,38 +368,38 @@ func TestServerInitializationFlow(t *testing.T) {
testServer := httptest.NewServer(srv.Handler) testServer := httptest.NewServer(srv.Handler)
defer testServer.Close() defer testServer.Close()
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil) request, err := http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/health", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
resp, err := http.DefaultClient.Do(req) response, err := http.DefaultClient.Do(request)
if err != nil { if err != nil {
t.Fatalf("Failed to make request: %v", err) t.Fatalf("Failed to make request: %v", err)
} }
defer func() { defer func() {
_ = resp.Body.Close() _ = response.Body.Close()
}() }()
if resp.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode) t.Errorf("Expected status 200, got %d", response.StatusCode)
} }
req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/api", nil) request, err = http.NewRequestWithContext(context.Background(), http.MethodGet, testServer.URL+"/api", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
resp, err = http.DefaultClient.Do(req) response, err = http.DefaultClient.Do(request)
if err != nil { if err != nil {
t.Fatalf("Failed to make request: %v", err) t.Fatalf("Failed to make request: %v", err)
} }
defer func() { defer func() {
_ = resp.Body.Close() _ = response.Body.Close()
}() }()
if resp.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
t.Errorf("Expected status 200 for API endpoint, got %d", resp.StatusCode) t.Errorf("Expected status 200 for API endpoint, got %d", response.StatusCode)
} }
} }

View File

@@ -111,7 +111,7 @@ const docTemplate = `{
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.ConfirmAccountDeletionRequest" "$ref": "#/definitions/dto.ConfirmAccountDeletionRequest"
} }
} }
], ],
@@ -212,7 +212,7 @@ const docTemplate = `{
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.UpdateEmailRequest" "$ref": "#/definitions/dto.UpdateEmailRequest"
} }
} }
], ],
@@ -276,7 +276,7 @@ const docTemplate = `{
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.ForgotPasswordRequest" "$ref": "#/definitions/dto.ForgotPasswordRequest"
} }
} }
], ],
@@ -316,7 +316,7 @@ const docTemplate = `{
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.LoginRequest" "$ref": "#/definitions/dto.LoginRequest"
} }
} }
], ],
@@ -453,7 +453,7 @@ const docTemplate = `{
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.UpdatePasswordRequest" "$ref": "#/definitions/dto.UpdatePasswordRequest"
} }
} }
], ],
@@ -505,7 +505,7 @@ const docTemplate = `{
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.RefreshTokenRequest" "$ref": "#/definitions/dto.RefreshTokenRequest"
} }
} }
], ],
@@ -563,7 +563,7 @@ const docTemplate = `{
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.RegisterRequest" "$ref": "#/definitions/dto.RegisterRequest"
} }
} }
], ],
@@ -615,7 +615,7 @@ const docTemplate = `{
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.ResendVerificationRequest" "$ref": "#/definitions/dto.ResendVerificationRequest"
} }
} }
], ],
@@ -685,7 +685,7 @@ const docTemplate = `{
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.ResetPasswordRequest" "$ref": "#/definitions/dto.ResetPasswordRequest"
} }
} }
], ],
@@ -736,7 +736,7 @@ const docTemplate = `{
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.RevokeTokenRequest" "$ref": "#/definitions/dto.RevokeTokenRequest"
} }
} }
], ],
@@ -833,7 +833,7 @@ const docTemplate = `{
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.UpdateUsernameRequest" "$ref": "#/definitions/dto.UpdateUsernameRequest"
} }
} }
], ],
@@ -945,7 +945,7 @@ const docTemplate = `{
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.CreatePostRequest" "$ref": "#/definitions/dto.CreatePostRequest"
} }
} }
], ],
@@ -1176,7 +1176,7 @@ const docTemplate = `{
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.UpdatePostRequest" "$ref": "#/definitions/dto.UpdatePostRequest"
} }
} }
], ],
@@ -1370,7 +1370,7 @@ const docTemplate = `{
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.VoteRequest" "$ref": "#/definitions/dto.CastVoteRequest"
} }
} }
], ],
@@ -1601,7 +1601,7 @@ const docTemplate = `{
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.RegisterRequest" "$ref": "#/definitions/dto.RegisterRequest"
} }
} }
], ],
@@ -1817,6 +1817,223 @@ const docTemplate = `{
} }
}, },
"definitions": { "definitions": {
"dto.CastVoteRequest": {
"type": "object",
"required": [
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
]
}
}
},
"dto.ConfirmAccountDeletionRequest": {
"type": "object",
"required": [
"token"
],
"properties": {
"delete_posts": {
"type": "boolean"
},
"token": {
"type": "string"
}
}
},
"dto.CreatePostRequest": {
"type": "object",
"required": [
"url"
],
"properties": {
"content": {
"type": "string",
"maxLength": 10000
},
"title": {
"type": "string",
"maxLength": 200,
"minLength": 3
},
"url": {
"type": "string",
"maxLength": 2048
}
}
},
"dto.ForgotPasswordRequest": {
"type": "object",
"required": [
"username_or_email"
],
"properties": {
"username_or_email": {
"type": "string"
}
}
},
"dto.LoginRequest": {
"type": "object",
"required": [
"password",
"username"
],
"properties": {
"password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"dto.RefreshTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"dto.RegisterRequest": {
"type": "object",
"required": [
"email",
"password",
"username"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
},
"password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"dto.ResendVerificationRequest": {
"type": "object",
"required": [
"email"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
}
}
},
"dto.ResetPasswordRequest": {
"type": "object",
"required": [
"new_password",
"token"
],
"properties": {
"new_password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"token": {
"type": "string"
}
}
},
"dto.RevokeTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"dto.UpdateEmailRequest": {
"type": "object",
"required": [
"email"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
}
}
},
"dto.UpdatePasswordRequest": {
"type": "object",
"required": [
"current_password",
"new_password"
],
"properties": {
"current_password": {
"type": "string"
},
"new_password": {
"type": "string",
"maxLength": 128,
"minLength": 8
}
}
},
"dto.UpdatePostRequest": {
"type": "object",
"required": [
"title"
],
"properties": {
"content": {
"type": "string",
"maxLength": 10000
},
"title": {
"type": "string",
"maxLength": 200,
"minLength": 3
}
}
},
"dto.UpdateUsernameRequest": {
"type": "object",
"required": [
"username"
],
"properties": {
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"handlers.APIInfo": { "handlers.APIInfo": {
"type": "object", "type": "object",
"properties": { "properties": {
@@ -1919,50 +2136,6 @@ const docTemplate = `{
} }
} }
}, },
"handlers.ConfirmAccountDeletionRequest": {
"type": "object",
"properties": {
"delete_posts": {
"type": "boolean"
},
"token": {
"type": "string"
}
}
},
"handlers.CreatePostRequest": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"title": {
"type": "string"
},
"url": {
"type": "string"
}
}
},
"handlers.ForgotPasswordRequest": {
"type": "object",
"properties": {
"username_or_email": {
"type": "string"
}
}
},
"handlers.LoginRequest": {
"type": "object",
"properties": {
"password": {
"type": "string"
},
"username": {
"type": "string"
}
}
},
"handlers.PostResponse": { "handlers.PostResponse": {
"type": "object", "type": "object",
"properties": { "properties": {
@@ -1978,101 +2151,6 @@ const docTemplate = `{
} }
} }
}, },
"handlers.RefreshTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"handlers.RegisterRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
},
"password": {
"type": "string"
},
"username": {
"type": "string"
}
}
},
"handlers.ResendVerificationRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
}
}
},
"handlers.ResetPasswordRequest": {
"type": "object",
"properties": {
"new_password": {
"type": "string"
},
"token": {
"type": "string"
}
}
},
"handlers.RevokeTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"handlers.UpdateEmailRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
}
}
},
"handlers.UpdatePasswordRequest": {
"type": "object",
"properties": {
"current_password": {
"type": "string"
},
"new_password": {
"type": "string"
}
}
},
"handlers.UpdatePostRequest": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"title": {
"type": "string"
}
}
},
"handlers.UpdateUsernameRequest": {
"type": "object",
"properties": {
"username": {
"type": "string"
}
}
},
"handlers.UserResponse": { "handlers.UserResponse": {
"type": "object", "type": "object",
"properties": { "properties": {
@@ -2088,21 +2166,6 @@ const docTemplate = `{
} }
} }
}, },
"handlers.VoteRequest": {
"description": "Vote request with type field. All votes are handled the same way.",
"type": "object",
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
],
"example": "up"
}
}
},
"handlers.VoteResponse": { "handlers.VoteResponse": {
"type": "object", "type": "object",
"properties": { "properties": {

View File

@@ -108,7 +108,7 @@
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.ConfirmAccountDeletionRequest" "$ref": "#/definitions/dto.ConfirmAccountDeletionRequest"
} }
} }
], ],
@@ -209,7 +209,7 @@
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.UpdateEmailRequest" "$ref": "#/definitions/dto.UpdateEmailRequest"
} }
} }
], ],
@@ -273,7 +273,7 @@
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.ForgotPasswordRequest" "$ref": "#/definitions/dto.ForgotPasswordRequest"
} }
} }
], ],
@@ -313,7 +313,7 @@
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.LoginRequest" "$ref": "#/definitions/dto.LoginRequest"
} }
} }
], ],
@@ -450,7 +450,7 @@
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.UpdatePasswordRequest" "$ref": "#/definitions/dto.UpdatePasswordRequest"
} }
} }
], ],
@@ -502,7 +502,7 @@
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.RefreshTokenRequest" "$ref": "#/definitions/dto.RefreshTokenRequest"
} }
} }
], ],
@@ -560,7 +560,7 @@
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.RegisterRequest" "$ref": "#/definitions/dto.RegisterRequest"
} }
} }
], ],
@@ -612,7 +612,7 @@
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.ResendVerificationRequest" "$ref": "#/definitions/dto.ResendVerificationRequest"
} }
} }
], ],
@@ -682,7 +682,7 @@
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.ResetPasswordRequest" "$ref": "#/definitions/dto.ResetPasswordRequest"
} }
} }
], ],
@@ -733,7 +733,7 @@
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.RevokeTokenRequest" "$ref": "#/definitions/dto.RevokeTokenRequest"
} }
} }
], ],
@@ -830,7 +830,7 @@
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.UpdateUsernameRequest" "$ref": "#/definitions/dto.UpdateUsernameRequest"
} }
} }
], ],
@@ -942,7 +942,7 @@
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.CreatePostRequest" "$ref": "#/definitions/dto.CreatePostRequest"
} }
} }
], ],
@@ -1173,7 +1173,7 @@
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.UpdatePostRequest" "$ref": "#/definitions/dto.UpdatePostRequest"
} }
} }
], ],
@@ -1367,7 +1367,7 @@
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.VoteRequest" "$ref": "#/definitions/dto.CastVoteRequest"
} }
} }
], ],
@@ -1598,7 +1598,7 @@
"in": "body", "in": "body",
"required": true, "required": true,
"schema": { "schema": {
"$ref": "#/definitions/handlers.RegisterRequest" "$ref": "#/definitions/dto.RegisterRequest"
} }
} }
], ],
@@ -1814,6 +1814,223 @@
} }
}, },
"definitions": { "definitions": {
"dto.CastVoteRequest": {
"type": "object",
"required": [
"type"
],
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
]
}
}
},
"dto.ConfirmAccountDeletionRequest": {
"type": "object",
"required": [
"token"
],
"properties": {
"delete_posts": {
"type": "boolean"
},
"token": {
"type": "string"
}
}
},
"dto.CreatePostRequest": {
"type": "object",
"required": [
"url"
],
"properties": {
"content": {
"type": "string",
"maxLength": 10000
},
"title": {
"type": "string",
"maxLength": 200,
"minLength": 3
},
"url": {
"type": "string",
"maxLength": 2048
}
}
},
"dto.ForgotPasswordRequest": {
"type": "object",
"required": [
"username_or_email"
],
"properties": {
"username_or_email": {
"type": "string"
}
}
},
"dto.LoginRequest": {
"type": "object",
"required": [
"password",
"username"
],
"properties": {
"password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"dto.RefreshTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"dto.RegisterRequest": {
"type": "object",
"required": [
"email",
"password",
"username"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
},
"password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"dto.ResendVerificationRequest": {
"type": "object",
"required": [
"email"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
}
}
},
"dto.ResetPasswordRequest": {
"type": "object",
"required": [
"new_password",
"token"
],
"properties": {
"new_password": {
"type": "string",
"maxLength": 128,
"minLength": 8
},
"token": {
"type": "string"
}
}
},
"dto.RevokeTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"dto.UpdateEmailRequest": {
"type": "object",
"required": [
"email"
],
"properties": {
"email": {
"type": "string",
"maxLength": 254
}
}
},
"dto.UpdatePasswordRequest": {
"type": "object",
"required": [
"current_password",
"new_password"
],
"properties": {
"current_password": {
"type": "string"
},
"new_password": {
"type": "string",
"maxLength": 128,
"minLength": 8
}
}
},
"dto.UpdatePostRequest": {
"type": "object",
"required": [
"title"
],
"properties": {
"content": {
"type": "string",
"maxLength": 10000
},
"title": {
"type": "string",
"maxLength": 200,
"minLength": 3
}
}
},
"dto.UpdateUsernameRequest": {
"type": "object",
"required": [
"username"
],
"properties": {
"username": {
"type": "string",
"maxLength": 50,
"minLength": 3
}
}
},
"handlers.APIInfo": { "handlers.APIInfo": {
"type": "object", "type": "object",
"properties": { "properties": {
@@ -1916,50 +2133,6 @@
} }
} }
}, },
"handlers.ConfirmAccountDeletionRequest": {
"type": "object",
"properties": {
"delete_posts": {
"type": "boolean"
},
"token": {
"type": "string"
}
}
},
"handlers.CreatePostRequest": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"title": {
"type": "string"
},
"url": {
"type": "string"
}
}
},
"handlers.ForgotPasswordRequest": {
"type": "object",
"properties": {
"username_or_email": {
"type": "string"
}
}
},
"handlers.LoginRequest": {
"type": "object",
"properties": {
"password": {
"type": "string"
},
"username": {
"type": "string"
}
}
},
"handlers.PostResponse": { "handlers.PostResponse": {
"type": "object", "type": "object",
"properties": { "properties": {
@@ -1975,101 +2148,6 @@
} }
} }
}, },
"handlers.RefreshTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"handlers.RegisterRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
},
"password": {
"type": "string"
},
"username": {
"type": "string"
}
}
},
"handlers.ResendVerificationRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
}
}
},
"handlers.ResetPasswordRequest": {
"type": "object",
"properties": {
"new_password": {
"type": "string"
},
"token": {
"type": "string"
}
}
},
"handlers.RevokeTokenRequest": {
"type": "object",
"required": [
"refresh_token"
],
"properties": {
"refresh_token": {
"type": "string",
"example": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."
}
}
},
"handlers.UpdateEmailRequest": {
"type": "object",
"properties": {
"email": {
"type": "string"
}
}
},
"handlers.UpdatePasswordRequest": {
"type": "object",
"properties": {
"current_password": {
"type": "string"
},
"new_password": {
"type": "string"
}
}
},
"handlers.UpdatePostRequest": {
"type": "object",
"properties": {
"content": {
"type": "string"
},
"title": {
"type": "string"
}
}
},
"handlers.UpdateUsernameRequest": {
"type": "object",
"properties": {
"username": {
"type": "string"
}
}
},
"handlers.UserResponse": { "handlers.UserResponse": {
"type": "object", "type": "object",
"properties": { "properties": {
@@ -2085,21 +2163,6 @@
} }
} }
}, },
"handlers.VoteRequest": {
"description": "Vote request with type field. All votes are handled the same way.",
"type": "object",
"properties": {
"type": {
"type": "string",
"enum": [
"up",
"down",
"none"
],
"example": "up"
}
}
},
"handlers.VoteResponse": { "handlers.VoteResponse": {
"type": "object", "type": "object",
"properties": { "properties": {

View File

@@ -1,5 +1,156 @@
basePath: /api basePath: /api
definitions: definitions:
dto.CastVoteRequest:
properties:
type:
enum:
- up
- down
- none
type: string
required:
- type
type: object
dto.ConfirmAccountDeletionRequest:
properties:
delete_posts:
type: boolean
token:
type: string
required:
- token
type: object
dto.CreatePostRequest:
properties:
content:
maxLength: 10000
type: string
title:
maxLength: 200
minLength: 3
type: string
url:
maxLength: 2048
type: string
required:
- url
type: object
dto.ForgotPasswordRequest:
properties:
username_or_email:
type: string
required:
- username_or_email
type: object
dto.LoginRequest:
properties:
password:
maxLength: 128
minLength: 8
type: string
username:
maxLength: 50
minLength: 3
type: string
required:
- password
- username
type: object
dto.RefreshTokenRequest:
properties:
refresh_token:
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
type: string
required:
- refresh_token
type: object
dto.RegisterRequest:
properties:
email:
maxLength: 254
type: string
password:
maxLength: 128
minLength: 8
type: string
username:
maxLength: 50
minLength: 3
type: string
required:
- email
- password
- username
type: object
dto.ResendVerificationRequest:
properties:
email:
maxLength: 254
type: string
required:
- email
type: object
dto.ResetPasswordRequest:
properties:
new_password:
maxLength: 128
minLength: 8
type: string
token:
type: string
required:
- new_password
- token
type: object
dto.RevokeTokenRequest:
properties:
refresh_token:
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
type: string
required:
- refresh_token
type: object
dto.UpdateEmailRequest:
properties:
email:
maxLength: 254
type: string
required:
- email
type: object
dto.UpdatePasswordRequest:
properties:
current_password:
type: string
new_password:
maxLength: 128
minLength: 8
type: string
required:
- current_password
- new_password
type: object
dto.UpdatePostRequest:
properties:
content:
maxLength: 10000
type: string
title:
maxLength: 200
minLength: 3
type: string
required:
- title
type: object
dto.UpdateUsernameRequest:
properties:
username:
maxLength: 50
minLength: 3
type: string
required:
- username
type: object
handlers.APIInfo: handlers.APIInfo:
properties: properties:
data: {} data: {}
@@ -70,34 +221,6 @@ definitions:
success: success:
type: boolean type: boolean
type: object type: object
handlers.ConfirmAccountDeletionRequest:
properties:
delete_posts:
type: boolean
token:
type: string
type: object
handlers.CreatePostRequest:
properties:
content:
type: string
title:
type: string
url:
type: string
type: object
handlers.ForgotPasswordRequest:
properties:
username_or_email:
type: string
type: object
handlers.LoginRequest:
properties:
password:
type: string
username:
type: string
type: object
handlers.PostResponse: handlers.PostResponse:
properties: properties:
data: {} data: {}
@@ -108,67 +231,6 @@ definitions:
success: success:
type: boolean type: boolean
type: object type: object
handlers.RefreshTokenRequest:
properties:
refresh_token:
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
type: string
required:
- refresh_token
type: object
handlers.RegisterRequest:
properties:
email:
type: string
password:
type: string
username:
type: string
type: object
handlers.ResendVerificationRequest:
properties:
email:
type: string
type: object
handlers.ResetPasswordRequest:
properties:
new_password:
type: string
token:
type: string
type: object
handlers.RevokeTokenRequest:
properties:
refresh_token:
example: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
type: string
required:
- refresh_token
type: object
handlers.UpdateEmailRequest:
properties:
email:
type: string
type: object
handlers.UpdatePasswordRequest:
properties:
current_password:
type: string
new_password:
type: string
type: object
handlers.UpdatePostRequest:
properties:
content:
type: string
title:
type: string
type: object
handlers.UpdateUsernameRequest:
properties:
username:
type: string
type: object
handlers.UserResponse: handlers.UserResponse:
properties: properties:
data: {} data: {}
@@ -179,17 +241,6 @@ definitions:
success: success:
type: boolean type: boolean
type: object type: object
handlers.VoteRequest:
description: Vote request with type field. All votes are handled the same way.
properties:
type:
enum:
- up
- down
- none
example: up
type: string
type: object
handlers.VoteResponse: handlers.VoteResponse:
properties: properties:
data: {} data: {}
@@ -268,7 +319,7 @@ paths:
name: request name: request
required: true required: true
schema: schema:
$ref: '#/definitions/handlers.ConfirmAccountDeletionRequest' $ref: '#/definitions/dto.ConfirmAccountDeletionRequest'
produces: produces:
- application/json - application/json
responses: responses:
@@ -331,7 +382,7 @@ paths:
name: request name: request
required: true required: true
schema: schema:
$ref: '#/definitions/handlers.UpdateEmailRequest' $ref: '#/definitions/dto.UpdateEmailRequest'
produces: produces:
- application/json - application/json
responses: responses:
@@ -375,7 +426,7 @@ paths:
name: request name: request
required: true required: true
schema: schema:
$ref: '#/definitions/handlers.ForgotPasswordRequest' $ref: '#/definitions/dto.ForgotPasswordRequest'
produces: produces:
- application/json - application/json
responses: responses:
@@ -401,7 +452,7 @@ paths:
name: request name: request
required: true required: true
schema: schema:
$ref: '#/definitions/handlers.LoginRequest' $ref: '#/definitions/dto.LoginRequest'
produces: produces:
- application/json - application/json
responses: responses:
@@ -485,7 +536,7 @@ paths:
name: request name: request
required: true required: true
schema: schema:
$ref: '#/definitions/handlers.UpdatePasswordRequest' $ref: '#/definitions/dto.UpdatePasswordRequest'
produces: produces:
- application/json - application/json
responses: responses:
@@ -523,7 +574,7 @@ paths:
name: request name: request
required: true required: true
schema: schema:
$ref: '#/definitions/handlers.RefreshTokenRequest' $ref: '#/definitions/dto.RefreshTokenRequest'
produces: produces:
- application/json - application/json
responses: responses:
@@ -561,7 +612,7 @@ paths:
name: request name: request
required: true required: true
schema: schema:
$ref: '#/definitions/handlers.RegisterRequest' $ref: '#/definitions/dto.RegisterRequest'
produces: produces:
- application/json - application/json
responses: responses:
@@ -595,7 +646,7 @@ paths:
name: request name: request
required: true required: true
schema: schema:
$ref: '#/definitions/handlers.ResendVerificationRequest' $ref: '#/definitions/dto.ResendVerificationRequest'
produces: produces:
- application/json - application/json
responses: responses:
@@ -641,7 +692,7 @@ paths:
name: request name: request
required: true required: true
schema: schema:
$ref: '#/definitions/handlers.ResetPasswordRequest' $ref: '#/definitions/dto.ResetPasswordRequest'
produces: produces:
- application/json - application/json
responses: responses:
@@ -672,7 +723,7 @@ paths:
name: request name: request
required: true required: true
schema: schema:
$ref: '#/definitions/handlers.RevokeTokenRequest' $ref: '#/definitions/dto.RevokeTokenRequest'
produces: produces:
- application/json - application/json
responses: responses:
@@ -735,7 +786,7 @@ paths:
name: request name: request
required: true required: true
schema: schema:
$ref: '#/definitions/handlers.UpdateUsernameRequest' $ref: '#/definitions/dto.UpdateUsernameRequest'
produces: produces:
- application/json - application/json
responses: responses:
@@ -809,7 +860,7 @@ paths:
name: request name: request
required: true required: true
schema: schema:
$ref: '#/definitions/handlers.CreatePostRequest' $ref: '#/definitions/dto.CreatePostRequest'
produces: produces:
- application/json - application/json
responses: responses:
@@ -933,7 +984,7 @@ paths:
name: request name: request
required: true required: true
schema: schema:
$ref: '#/definitions/handlers.UpdatePostRequest' $ref: '#/definitions/dto.UpdatePostRequest'
produces: produces:
- application/json - application/json
responses: responses:
@@ -1070,7 +1121,7 @@ paths:
name: request name: request
required: true required: true
schema: schema:
$ref: '#/definitions/handlers.VoteRequest' $ref: '#/definitions/dto.CastVoteRequest'
produces: produces:
- application/json - application/json
responses: responses:
@@ -1260,7 +1311,7 @@ paths:
name: request name: request
required: true required: true
schema: schema:
$ref: '#/definitions/handlers.RegisterRequest' $ref: '#/definitions/dto.RegisterRequest'
produces: produces:
- application/json - application/json
responses: responses:

View File

@@ -0,0 +1,51 @@
package dto
type LoginRequest struct {
Username string `json:"username" validate:"required,min=3,max=50"`
Password string `json:"password" validate:"required,min=8,max=128"`
}
type RegisterRequest struct {
Username string `json:"username" validate:"required,min=3,max=50"`
Email string `json:"email" validate:"required,email,max=254"`
Password string `json:"password" validate:"required,min=8,max=128"`
}
type ResendVerificationRequest struct {
Email string `json:"email" validate:"required,email,max=254"`
}
type ForgotPasswordRequest struct {
UsernameOrEmail string `json:"username_or_email" validate:"required"`
}
type ResetPasswordRequest struct {
Token string `json:"token" validate:"required"`
NewPassword string `json:"new_password" validate:"required,min=8,max=128"`
}
type UpdateEmailRequest struct {
Email string `json:"email" validate:"required,email,max=254"`
}
type UpdateUsernameRequest struct {
Username string `json:"username" validate:"required,min=3,max=50"`
}
type UpdatePasswordRequest struct {
CurrentPassword string `json:"current_password" validate:"required"`
NewPassword string `json:"new_password" validate:"required,min=8,max=128"`
}
type ConfirmAccountDeletionRequest struct {
Token string `json:"token" validate:"required"`
DeletePosts bool `json:"delete_posts"`
}
type RefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." validate:"required"`
}
type RevokeTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." validate:"required"`
}

View File

@@ -0,0 +1,12 @@
package dto
type CreatePostRequest struct {
Title string `json:"title" validate:"omitempty,min=3,max=200"`
URL string `json:"url" validate:"required,url,max=2048"`
Content string `json:"content" validate:"omitempty,max=10000"`
}
type UpdatePostRequest struct {
Title string `json:"title" validate:"required,min=3,max=200"`
Content string `json:"content" validate:"omitempty,max=10000"`
}

View File

@@ -6,6 +6,10 @@ import (
"goyco/internal/database" "goyco/internal/database"
) )
type CastVoteRequest struct {
Type string `json:"type" validate:"required,oneof=up down none"`
}
type VoteDTO struct { type VoteDTO struct {
ID uint `json:"id"` ID uint `json:"id"`
UserID *uint `json:"user_id,omitempty"` UserID *uint `json:"user_id,omitempty"`

View File

@@ -112,37 +112,37 @@ func newInMemoryRoundTripper(handler http.Handler) http.RoundTripper {
return &inMemoryRoundTripper{handler: handler} return &inMemoryRoundTripper{handler: handler}
} }
func (rt *inMemoryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func (rt *inMemoryRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
if rt == nil || rt.handler == nil { if rt == nil || rt.handler == nil {
return nil, fmt.Errorf("in-memory round tripper not initialized") return nil, fmt.Errorf("in-memory round tripper not initialized")
} }
var bodyBytes []byte var bodyBytes []byte
if req.Body != nil && req.Body != http.NoBody { if request.Body != nil && request.Body != http.NoBody {
defer req.Body.Close() defer request.Body.Close()
var err error var err error
bodyBytes, err = io.ReadAll(req.Body) bodyBytes, err = io.ReadAll(request.Body)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read request body: %w", err) return nil, fmt.Errorf("failed to read request body: %w", err)
} }
} }
if len(bodyBytes) > 0 { if len(bodyBytes) > 0 {
req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
} else { } else {
req.Body = http.NoBody request.Body = http.NoBody
} }
clonedReq := req.Clone(req.Context()) clonedRequest := request.Clone(request.Context())
if len(bodyBytes) > 0 { if len(bodyBytes) > 0 {
clonedReq.Body = io.NopCloser(bytes.NewReader(bodyBytes)) clonedRequest.Body = io.NopCloser(bytes.NewReader(bodyBytes))
} else { } else {
clonedReq.Body = http.NoBody clonedRequest.Body = http.NoBody
} }
clonedReq.RequestURI = clonedReq.URL.RequestURI() clonedRequest.RequestURI = clonedRequest.URL.RequestURI()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
rt.handler.ServeHTTP(recorder, clonedReq) rt.handler.ServeHTTP(recorder, clonedRequest)
resp := recorder.Result() resp := recorder.Result()
return resp, nil return resp, nil
} }
@@ -250,7 +250,7 @@ func tokenHash(token string) string {
func retryOnRateLimit(t *testing.T, maxRetries int, operation func() int) int { func retryOnRateLimit(t *testing.T, maxRetries int, operation func() int) int {
t.Helper() t.Helper()
for attempt := 0; attempt < maxRetries; attempt++ { for attempt := range maxRetries {
statusCode := operation() statusCode := operation()
if statusCode != http.StatusTooManyRequests { if statusCode != http.StatusTooManyRequests {
return statusCode return statusCode

View File

@@ -15,22 +15,22 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
ctx := setupTestContext(t) ctx := setupTestContext(t)
t.Run("compression_enabled_with_accept_encoding", func(t *testing.T) { t.Run("compression_enabled_with_accept_encoding", func(t *testing.T) {
req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
req.Header.Set("Accept-Encoding", "gzip") request.Header.Set("Accept-Encoding", "gzip")
testutils.WithStandardHeaders(req) testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp.Body.Close() defer response.Body.Close()
contentEncoding := resp.Header.Get("Content-Encoding") contentEncoding := response.Header.Get("Content-Encoding")
if contentEncoding == "gzip" { if contentEncoding == "gzip" {
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(response.Body)
if err != nil { if err != nil {
t.Fatalf("Failed to read response body: %v", err) t.Fatalf("Failed to read response body: %v", err)
} }
@@ -57,19 +57,19 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
}) })
t.Run("no_compression_without_accept_encoding", func(t *testing.T) { t.Run("no_compression_without_accept_encoding", func(t *testing.T) {
req, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) request, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
testutils.WithStandardHeaders(req) testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp.Body.Close() defer response.Body.Close()
contentEncoding := resp.Header.Get("Content-Encoding") contentEncoding := response.Header.Get("Content-Encoding")
if contentEncoding == "gzip" { if contentEncoding == "gzip" {
t.Error("Expected no compression without Accept-Encoding header") t.Error("Expected no compression without Accept-Encoding header")
} }
@@ -85,22 +85,22 @@ func TestE2E_CompressionMiddleware(t *testing.T) {
gz.Write([]byte(postData)) gz.Write([]byte(postData))
gz.Close() gz.Close()
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", &buf) request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", &buf)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-Encoding", "gzip") request.Header.Set("Content-Encoding", "gzip")
testutils.WithStandardHeaders(req) testutils.WithStandardHeaders(request)
req.Header.Set("Authorization", "Bearer "+authClient.Token) request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp.Body.Close() defer response.Body.Close()
switch resp.StatusCode { switch response.StatusCode {
case http.StatusBadRequest: case http.StatusBadRequest:
t.Log("Decompression middleware rejected invalid gzip") t.Log("Decompression middleware rejected invalid gzip")
case http.StatusCreated, http.StatusOK: case http.StatusCreated, http.StatusOK:
@@ -113,37 +113,37 @@ func TestE2E_CacheMiddleware(t *testing.T) {
ctx := setupTestContext(t) ctx := setupTestContext(t)
t.Run("cache_miss_then_hit", func(t *testing.T) { t.Run("cache_miss_then_hit", func(t *testing.T) {
req1, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
testutils.WithStandardHeaders(req1) testutils.WithStandardHeaders(firstRequest)
resp1, err := ctx.client.Do(req1) firstResponse, err := ctx.client.Do(firstRequest)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
resp1.Body.Close() firstResponse.Body.Close()
cacheStatus1 := resp1.Header.Get("X-Cache") firstCacheStatus := firstResponse.Header.Get("X-Cache")
if cacheStatus1 == "HIT" { if firstCacheStatus == "HIT" {
t.Log("First request was cached (unexpected but acceptable)") t.Log("First request was cached (unexpected but acceptable)")
} }
req2, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) secondRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
testutils.WithStandardHeaders(req2) testutils.WithStandardHeaders(secondRequest)
resp2, err := ctx.client.Do(req2) secondResponse, err := ctx.client.Do(secondRequest)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp2.Body.Close() defer secondResponse.Body.Close()
cacheStatus2 := resp2.Header.Get("X-Cache") secondCacheStatus := secondResponse.Header.Get("X-Cache")
if cacheStatus2 == "HIT" { if secondCacheStatus == "HIT" {
t.Log("Second request was served from cache") t.Log("Second request was served from cache")
} }
}) })
@@ -152,48 +152,48 @@ func TestE2E_CacheMiddleware(t *testing.T) {
testUser := ctx.createUserWithCleanup(t, "cacheuser", "StrongPass123!") testUser := ctx.createUserWithCleanup(t, "cacheuser", "StrongPass123!")
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!") authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
req1, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) firstRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
testutils.WithStandardHeaders(req1) testutils.WithStandardHeaders(firstRequest)
req1.Header.Set("Authorization", "Bearer "+authClient.Token) firstRequest.Header.Set("Authorization", "Bearer "+authClient.Token)
resp1, err := ctx.client.Do(req1) firstResponse, err := ctx.client.Do(firstRequest)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
resp1.Body.Close() firstResponse.Body.Close()
postData := `{"title":"Cache Invalidation Test","url":"https://example.com/cache","content":"Test"}` postData := `{"title":"Cache Invalidation Test","url":"https://example.com/cache","content":"Test"}`
req2, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData)) secondRequest, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
req2.Header.Set("Content-Type", "application/json") secondRequest.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req2) testutils.WithStandardHeaders(secondRequest)
req2.Header.Set("Authorization", "Bearer "+authClient.Token) secondRequest.Header.Set("Authorization", "Bearer "+authClient.Token)
resp2, err := ctx.client.Do(req2) secondResponse, err := ctx.client.Do(secondRequest)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
resp2.Body.Close() secondResponse.Body.Close()
req3, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil) thirdRequest, err := http.NewRequest("GET", ctx.baseURL+"/api/posts", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
testutils.WithStandardHeaders(req3) testutils.WithStandardHeaders(thirdRequest)
req3.Header.Set("Authorization", "Bearer "+authClient.Token) thirdRequest.Header.Set("Authorization", "Bearer "+authClient.Token)
resp3, err := ctx.client.Do(req3) thirdResponse, err := ctx.client.Do(thirdRequest)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp3.Body.Close() defer thirdResponse.Body.Close()
cacheStatus := resp3.Header.Get("X-Cache") cacheStatus := thirdResponse.Header.Get("X-Cache")
if cacheStatus == "HIT" { if cacheStatus == "HIT" {
t.Log("Cache was invalidated after POST") t.Log("Cache was invalidated after POST")
} }
@@ -204,23 +204,23 @@ func TestE2E_CSRFProtection(t *testing.T) {
ctx := setupTestContext(t) ctx := setupTestContext(t)
t.Run("csrf_protection_for_non_api_routes", func(t *testing.T) { t.Run("csrf_protection_for_non_api_routes", func(t *testing.T) {
req, err := http.NewRequest("POST", ctx.baseURL+"/auth/login", strings.NewReader(`{"username":"test","password":"test"}`)) request, err := http.NewRequest("POST", ctx.baseURL+"/auth/login", strings.NewReader(`{"username":"test","password":"test"}`))
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req) testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp.Body.Close() defer response.Body.Close()
if resp.StatusCode == http.StatusForbidden { if response.StatusCode == http.StatusForbidden {
t.Log("CSRF protection active for non-API routes") t.Log("CSRF protection active for non-API routes")
} else { } else {
t.Logf("CSRF check result: status %d", resp.StatusCode) t.Logf("CSRF check result: status %d", response.StatusCode)
} }
}) })
@@ -229,39 +229,39 @@ func TestE2E_CSRFProtection(t *testing.T) {
authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!") authClient := ctx.loginUser(t, testUser.Username, "StrongPass123!")
postData := `{"title":"CSRF Test","url":"https://example.com/csrf","content":"Test"}` postData := `{"title":"CSRF Test","url":"https://example.com/csrf","content":"Test"}`
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData)) request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req) testutils.WithStandardHeaders(request)
req.Header.Set("Authorization", "Bearer "+authClient.Token) request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp.Body.Close() defer response.Body.Close()
if resp.StatusCode == http.StatusForbidden { if response.StatusCode == http.StatusForbidden {
t.Error("API routes should bypass CSRF protection") t.Error("API routes should bypass CSRF protection")
} }
}) })
t.Run("csrf_allows_get_requests", func(t *testing.T) { t.Run("csrf_allows_get_requests", func(t *testing.T) {
req, err := http.NewRequest("GET", ctx.baseURL+"/auth/login", nil) request, err := http.NewRequest("GET", ctx.baseURL+"/auth/login", nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
testutils.WithStandardHeaders(req) testutils.WithStandardHeaders(request)
resp, err := ctx.client.Do(req) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp.Body.Close() defer response.Body.Close()
if resp.StatusCode == http.StatusForbidden { if response.StatusCode == http.StatusForbidden {
t.Error("GET requests should not require CSRF token") t.Error("GET requests should not require CSRF token")
} }
}) })
@@ -276,21 +276,21 @@ func TestE2E_RequestSizeLimit(t *testing.T) {
smallData := strings.Repeat("a", 100) smallData := strings.Repeat("a", 100)
postData := `{"title":"` + smallData + `","url":"https://example.com","content":"test"}` postData := `{"title":"` + smallData + `","url":"https://example.com","content":"test"}`
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData)) request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req) testutils.WithStandardHeaders(request)
req.Header.Set("Authorization", "Bearer "+authClient.Token) request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
t.Fatalf("Request failed: %v", err) t.Fatalf("Request failed: %v", err)
} }
defer resp.Body.Close() defer response.Body.Close()
if resp.StatusCode == http.StatusRequestEntityTooLarge { if response.StatusCode == http.StatusRequestEntityTooLarge {
t.Error("Small request should not exceed size limit") t.Error("Small request should not exceed size limit")
} }
}) })
@@ -301,24 +301,24 @@ func TestE2E_RequestSizeLimit(t *testing.T) {
largeData := strings.Repeat("a", 2*1024*1024) largeData := strings.Repeat("a", 2*1024*1024)
postData := `{"title":"test","url":"https://example.com","content":"` + largeData + `"}` postData := `{"title":"test","url":"https://example.com","content":"` + largeData + `"}`
req, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData)) request, err := http.NewRequest("POST", ctx.baseURL+"/api/posts", strings.NewReader(postData))
if err != nil { if err != nil {
t.Fatalf("Failed to create request: %v", err) t.Fatalf("Failed to create request: %v", err)
} }
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
testutils.WithStandardHeaders(req) testutils.WithStandardHeaders(request)
req.Header.Set("Authorization", "Bearer "+authClient.Token) request.Header.Set("Authorization", "Bearer "+authClient.Token)
resp, err := ctx.client.Do(req) response, err := ctx.client.Do(request)
if err != nil { if err != nil {
return return
} }
defer resp.Body.Close() defer response.Body.Close()
if resp.StatusCode == http.StatusRequestEntityTooLarge { if response.StatusCode == http.StatusRequestEntityTooLarge {
t.Log("Request size limit enforced correctly") t.Log("Request size limit enforced correctly")
} else { } else {
t.Logf("Request size limit check result: status %d", resp.StatusCode) t.Logf("Request size limit check result: status %d", response.StatusCode)
} }
}) })
} }

View File

@@ -64,62 +64,6 @@ type AuthUserSummary struct {
Locked bool `json:"locked" example:"false"` Locked bool `json:"locked" example:"false"`
} }
type LoginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
}
type RegisterRequest struct {
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"`
}
type CreatePostRequest struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
}
type ResendVerificationRequest struct {
Email string `json:"email"`
}
type ForgotPasswordRequest struct {
UsernameOrEmail string `json:"username_or_email"`
}
type ResetPasswordRequest struct {
Token string `json:"token"`
NewPassword string `json:"new_password"`
}
type UpdateEmailRequest struct {
Email string `json:"email"`
}
type UpdateUsernameRequest struct {
Username string `json:"username"`
}
type UpdatePasswordRequest struct {
CurrentPassword string `json:"current_password"`
NewPassword string `json:"new_password"`
}
type ConfirmAccountDeletionRequest struct {
Token string `json:"token"`
DeletePosts bool `json:"delete_posts"`
}
type RefreshTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." binding:"required"`
}
type RevokeTokenRequest struct {
RefreshToken string `json:"refresh_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..." binding:"required"`
}
func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.UserRepository) *AuthHandler { func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.UserRepository) *AuthHandler {
return &AuthHandler{ return &AuthHandler{
authService: authService, authService: authService,
@@ -132,7 +76,7 @@ func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.User
// @Tags auth // @Tags auth
// @Accept json // @Accept json
// @Produce json // @Produce json
// @Param request body LoginRequest true "Login credentials" // @Param request body dto.LoginRequest true "Login credentials"
// @Success 200 {object} AuthTokensResponse "Authentication successful" // @Success 200 {object} AuthTokensResponse "Authentication successful"
// @Failure 400 {object} AuthResponse "Invalid request data or validation failed" // @Failure 400 {object} AuthResponse "Invalid request data or validation failed"
// @Failure 401 {object} AuthResponse "Invalid credentials" // @Failure 401 {object} AuthResponse "Invalid credentials"
@@ -140,23 +84,15 @@ func NewAuthHandler(authService AuthServiceInterface, userRepo repositories.User
// @Failure 500 {object} AuthResponse "Internal server error" // @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/login [post] // @Router /api/auth/login [post]
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) { func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
var req struct { req, ok := GetValidatedDTO[dto.LoginRequest](r)
Username string `json:"username"` if !ok {
Password string `json:"password"` SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
}
if !DecodeJSONRequest(w, r, &req) {
return return
} }
username := security.SanitizeUsername(req.Username) username := security.SanitizeUsername(req.Username)
password := strings.TrimSpace(req.Password) password := strings.TrimSpace(req.Password)
if username == "" || password == "" {
SendErrorResponse(w, "Username and password are required", http.StatusBadRequest)
return
}
if err := validation.ValidatePassword(password); err != nil { if err := validation.ValidatePassword(password); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest) SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return return
@@ -175,20 +111,16 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
// @Tags auth // @Tags auth
// @Accept json // @Accept json
// @Produce json // @Produce json
// @Param request body RegisterRequest true "Registration data" // @Param request body dto.RegisterRequest true "Registration data"
// @Success 201 {object} AuthResponse "Registration successful" // @Success 201 {object} AuthResponse "Registration successful"
// @Failure 400 {object} AuthResponse "Invalid request data or validation failed" // @Failure 400 {object} AuthResponse "Invalid request data or validation failed"
// @Failure 409 {object} AuthResponse "Username or email already exists" // @Failure 409 {object} AuthResponse "Username or email already exists"
// @Failure 500 {object} AuthResponse "Internal server error" // @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/register [post] // @Router /api/auth/register [post]
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) { func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
var req struct { req, ok := GetValidatedDTO[dto.RegisterRequest](r)
Username string `json:"username"` if !ok {
Email string `json:"email"` SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
Password string `json:"password"`
}
if !DecodeJSONRequest(w, r, &req) {
return return
} }
@@ -196,11 +128,6 @@ func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
email := strings.TrimSpace(req.Email) email := strings.TrimSpace(req.Email)
password := strings.TrimSpace(req.Password) password := strings.TrimSpace(req.Password)
if username == "" || email == "" || password == "" {
SendErrorResponse(w, "Username, email, and password are required", http.StatusBadRequest)
return
}
username = security.SanitizeUsername(username) username = security.SanitizeUsername(username)
if err := validation.ValidateUsername(username); err != nil { if err := validation.ValidateUsername(username); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest) SendErrorResponse(w, err.Error(), http.StatusBadRequest)
@@ -280,7 +207,7 @@ func (h *AuthHandler) ConfirmEmail(w http.ResponseWriter, r *http.Request) {
// @Tags auth // @Tags auth
// @Accept json // @Accept json
// @Produce json // @Produce json
// @Param request body ResendVerificationRequest true "Email address" // @Param request body dto.ResendVerificationRequest true "Email address"
// @Success 200 {object} AuthResponse // @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse // @Failure 400 {object} AuthResponse
// @Failure 404 {object} AuthResponse // @Failure 404 {object} AuthResponse
@@ -290,15 +217,14 @@ func (h *AuthHandler) ConfirmEmail(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} AuthResponse // @Failure 500 {object} AuthResponse
// @Router /api/auth/resend-verification [post] // @Router /api/auth/resend-verification [post]
func (h *AuthHandler) ResendVerificationEmail(w http.ResponseWriter, r *http.Request) { func (h *AuthHandler) ResendVerificationEmail(w http.ResponseWriter, r *http.Request) {
var req struct { req, ok := GetValidatedDTO[dto.ResendVerificationRequest](r)
Email string `json:"email"` if !ok {
} SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
if !DecodeJSONRequest(w, r, &req) {
return return
} }
email := strings.TrimSpace(req.Email) email := strings.TrimSpace(req.Email)
if email == "" { if email == "" {
SendErrorResponse(w, "Email address is required", http.StatusBadRequest) SendErrorResponse(w, "Email address is required", http.StatusBadRequest)
return return
@@ -359,20 +285,19 @@ func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
// @Tags auth // @Tags auth
// @Accept json // @Accept json
// @Produce json // @Produce json
// @Param request body ForgotPasswordRequest true "Username or email" // @Param request body dto.ForgotPasswordRequest true "Username or email"
// @Success 200 {object} AuthResponse "Password reset email sent if account exists" // @Success 200 {object} AuthResponse "Password reset email sent if account exists"
// @Failure 400 {object} AuthResponse "Invalid request data" // @Failure 400 {object} AuthResponse "Invalid request data"
// @Router /api/auth/forgot-password [post] // @Router /api/auth/forgot-password [post]
func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Request) { func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Request) {
var req struct { req, ok := GetValidatedDTO[dto.ForgotPasswordRequest](r)
UsernameOrEmail string `json:"username_or_email"` if !ok {
} SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
if !DecodeJSONRequest(w, r, &req) {
return return
} }
usernameOrEmail := strings.TrimSpace(req.UsernameOrEmail) usernameOrEmail := strings.TrimSpace(req.UsernameOrEmail)
if usernameOrEmail == "" { if usernameOrEmail == "" {
SendErrorResponse(w, "Username or email is required", http.StatusBadRequest) SendErrorResponse(w, "Username or email is required", http.StatusBadRequest)
return return
@@ -389,18 +314,15 @@ func (h *AuthHandler) RequestPasswordReset(w http.ResponseWriter, r *http.Reques
// @Tags auth // @Tags auth
// @Accept json // @Accept json
// @Produce json // @Produce json
// @Param request body ResetPasswordRequest true "Password reset data" // @Param request body dto.ResetPasswordRequest true "Password reset data"
// @Success 200 {object} AuthResponse "Password reset successfully" // @Success 200 {object} AuthResponse "Password reset successfully"
// @Failure 400 {object} AuthResponse "Invalid or expired token, or validation failed" // @Failure 400 {object} AuthResponse "Invalid or expired token, or validation failed"
// @Failure 500 {object} AuthResponse "Internal server error" // @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/reset-password [post] // @Router /api/auth/reset-password [post]
func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) { func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
var req struct { req, ok := GetValidatedDTO[dto.ResetPasswordRequest](r)
Token string `json:"token"` if !ok {
NewPassword string `json:"new_password"` SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
}
if !DecodeJSONRequest(w, r, &req) {
return return
} }
@@ -408,17 +330,12 @@ func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
newPassword := strings.TrimSpace(req.NewPassword) newPassword := strings.TrimSpace(req.NewPassword)
if token == "" { if token == "" {
SendErrorResponse(w, "Reset token is required", http.StatusBadRequest) SendErrorResponse(w, "Token is required", http.StatusBadRequest)
return return
} }
if newPassword == "" { if err := validation.ValidatePassword(newPassword); err != nil {
SendErrorResponse(w, "New password is required", http.StatusBadRequest) SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return
}
if len(newPassword) < 8 {
SendErrorResponse(w, "Password must be at least 8 characters long", http.StatusBadRequest)
return return
} }
@@ -443,7 +360,7 @@ func (h *AuthHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
// @Accept json // @Accept json
// @Produce json // @Produce json
// @Security BearerAuth // @Security BearerAuth
// @Param request body UpdateEmailRequest true "New email address" // @Param request body dto.UpdateEmailRequest true "New email address"
// @Success 200 {object} AuthResponse // @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse // @Failure 400 {object} AuthResponse
// @Failure 401 {object} AuthResponse // @Failure 401 {object} AuthResponse
@@ -457,11 +374,9 @@ func (h *AuthHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
return return
} }
var req struct { req, ok := GetValidatedDTO[dto.UpdateEmailRequest](r)
Email string `json:"email"` if !ok {
} SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
if !DecodeJSONRequest(w, r, &req) {
return return
} }
@@ -498,7 +413,7 @@ func (h *AuthHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
// @Accept json // @Accept json
// @Produce json // @Produce json
// @Security BearerAuth // @Security BearerAuth
// @Param request body UpdateUsernameRequest true "New username" // @Param request body dto.UpdateUsernameRequest true "New username"
// @Success 200 {object} AuthResponse // @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse // @Failure 400 {object} AuthResponse
// @Failure 401 {object} AuthResponse // @Failure 401 {object} AuthResponse
@@ -511,11 +426,9 @@ func (h *AuthHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
return return
} }
var req struct { req, ok := GetValidatedDTO[dto.UpdateUsernameRequest](r)
Username string `json:"username"` if !ok {
} SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
if !DecodeJSONRequest(w, r, &req) {
return return
} }
@@ -548,7 +461,7 @@ func (h *AuthHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
// @Accept json // @Accept json
// @Produce json // @Produce json
// @Security BearerAuth // @Security BearerAuth
// @Param request body UpdatePasswordRequest true "Password update data" // @Param request body dto.UpdatePasswordRequest true "Password update data"
// @Success 200 {object} AuthResponse // @Success 200 {object} AuthResponse
// @Failure 400 {object} AuthResponse // @Failure 400 {object} AuthResponse
// @Failure 401 {object} AuthResponse // @Failure 401 {object} AuthResponse
@@ -560,12 +473,9 @@ func (h *AuthHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
return return
} }
var req struct { req, ok := GetValidatedDTO[dto.UpdatePasswordRequest](r)
CurrentPassword string `json:"current_password"` if !ok {
NewPassword string `json:"new_password"` SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
}
if !DecodeJSONRequest(w, r, &req) {
return return
} }
@@ -633,23 +543,21 @@ func (h *AuthHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
// @Tags auth // @Tags auth
// @Accept json // @Accept json
// @Produce json // @Produce json
// @Param request body ConfirmAccountDeletionRequest true "Account deletion data" // @Param request body dto.ConfirmAccountDeletionRequest true "Account deletion data"
// @Success 200 {object} AuthResponse "Account deleted successfully" // @Success 200 {object} AuthResponse "Account deleted successfully"
// @Failure 400 {object} AuthResponse "Invalid or expired token" // @Failure 400 {object} AuthResponse "Invalid or expired token"
// @Failure 503 {object} AuthResponse "Email delivery unavailable" // @Failure 503 {object} AuthResponse "Email delivery unavailable"
// @Failure 500 {object} AuthResponse "Internal server error" // @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/account/confirm [post] // @Router /api/auth/account/confirm [post]
func (h *AuthHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Request) { func (h *AuthHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Request) {
var req struct { req, ok := GetValidatedDTO[dto.ConfirmAccountDeletionRequest](r)
Token string `json:"token"` if !ok {
DeletePosts bool `json:"delete_posts"` SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
}
if !DecodeJSONRequest(w, r, &req) {
return return
} }
token := strings.TrimSpace(req.Token) token := strings.TrimSpace(req.Token)
if token == "" { if token == "" {
SendErrorResponse(w, "Deletion token is required", http.StatusBadRequest) SendErrorResponse(w, "Deletion token is required", http.StatusBadRequest)
return return
@@ -694,7 +602,7 @@ func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
// @Tags auth // @Tags auth
// @Accept json // @Accept json
// @Produce json // @Produce json
// @Param request body RefreshTokenRequest true "Refresh token data" // @Param request body dto.RefreshTokenRequest true "Refresh token data"
// @Success 200 {object} AuthTokensResponse "Token refreshed successfully" // @Success 200 {object} AuthTokensResponse "Token refreshed successfully"
// @Failure 400 {object} AuthResponse "Invalid request body or missing refresh token" // @Failure 400 {object} AuthResponse "Invalid request body or missing refresh token"
// @Failure 401 {object} AuthResponse "Invalid or expired refresh token" // @Failure 401 {object} AuthResponse "Invalid or expired refresh token"
@@ -702,13 +610,13 @@ func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} AuthResponse "Internal server error" // @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/refresh [post] // @Router /api/auth/refresh [post]
func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) { func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
var req RefreshTokenRequest req, ok := GetValidatedDTO[dto.RefreshTokenRequest](r)
if !ok {
if !DecodeJSONRequest(w, r, &req) { SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return return
} }
if strings.TrimSpace(req.RefreshToken) == "" { if req.RefreshToken == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest) SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return return
} }
@@ -727,20 +635,20 @@ func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
// @Accept json // @Accept json
// @Produce json // @Produce json
// @Security BearerAuth // @Security BearerAuth
// @Param request body RevokeTokenRequest true "Token revocation data" // @Param request body dto.RevokeTokenRequest true "Token revocation data"
// @Success 200 {object} AuthResponse "Token revoked successfully" // @Success 200 {object} AuthResponse "Token revoked successfully"
// @Failure 400 {object} AuthResponse "Invalid request body or missing refresh token" // @Failure 400 {object} AuthResponse "Invalid request body or missing refresh token"
// @Failure 401 {object} AuthResponse "Invalid or expired access token" // @Failure 401 {object} AuthResponse "Invalid or expired access token"
// @Failure 500 {object} AuthResponse "Internal server error" // @Failure 500 {object} AuthResponse "Internal server error"
// @Router /api/auth/revoke [post] // @Router /api/auth/revoke [post]
func (h *AuthHandler) RevokeToken(w http.ResponseWriter, r *http.Request) { func (h *AuthHandler) RevokeToken(w http.ResponseWriter, r *http.Request) {
var req RevokeTokenRequest req, ok := GetValidatedDTO[dto.RevokeTokenRequest](r)
if !ok {
if !DecodeJSONRequest(w, r, &req) { SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return return
} }
if strings.TrimSpace(req.RefreshToken) == "" { if req.RefreshToken == "" {
SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest) SendErrorResponse(w, "Refresh token is required", http.StatusBadRequest)
return return
} }
@@ -782,28 +690,28 @@ func (h *AuthHandler) RevokeAllTokens(w http.ResponseWriter, r *http.Request) {
func (h *AuthHandler) MountRoutes(r chi.Router, config RouteModuleConfig) { func (h *AuthHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
if config.GeneralRateLimit != nil { if config.GeneralRateLimit != nil {
rateLimited := config.GeneralRateLimit(r) rateLimited := config.GeneralRateLimit(r)
rateLimited.Post("/auth/refresh", h.RefreshToken) rateLimited.Post("/auth/refresh", WithValidation[dto.RefreshTokenRequest](config.ValidationMiddleware, h.RefreshToken))
rateLimited.Get("/auth/confirm", h.ConfirmEmail) rateLimited.Get("/auth/confirm", h.ConfirmEmail)
rateLimited.Post("/auth/resend-verification", h.ResendVerificationEmail) rateLimited.Post("/auth/resend-verification", WithValidation[dto.ResendVerificationRequest](config.ValidationMiddleware, h.ResendVerificationEmail))
} else { } else {
r.Post("/auth/refresh", h.RefreshToken) r.Post("/auth/refresh", WithValidation[dto.RefreshTokenRequest](config.ValidationMiddleware, h.RefreshToken))
r.Get("/auth/confirm", h.ConfirmEmail) r.Get("/auth/confirm", h.ConfirmEmail)
r.Post("/auth/resend-verification", h.ResendVerificationEmail) r.Post("/auth/resend-verification", WithValidation[dto.ResendVerificationRequest](config.ValidationMiddleware, h.ResendVerificationEmail))
} }
if config.AuthRateLimit != nil { if config.AuthRateLimit != nil {
rateLimited := config.AuthRateLimit(r) rateLimited := config.AuthRateLimit(r)
rateLimited.Post("/auth/register", h.Register) rateLimited.Post("/auth/register", WithValidation[dto.RegisterRequest](config.ValidationMiddleware, h.Register))
rateLimited.Post("/auth/login", h.Login) rateLimited.Post("/auth/login", WithValidation[dto.LoginRequest](config.ValidationMiddleware, h.Login))
rateLimited.Post("/auth/forgot-password", h.RequestPasswordReset) rateLimited.Post("/auth/forgot-password", WithValidation[dto.ForgotPasswordRequest](config.ValidationMiddleware, h.RequestPasswordReset))
rateLimited.Post("/auth/reset-password", h.ResetPassword) rateLimited.Post("/auth/reset-password", WithValidation[dto.ResetPasswordRequest](config.ValidationMiddleware, h.ResetPassword))
rateLimited.Post("/auth/account/confirm", h.ConfirmAccountDeletion) rateLimited.Post("/auth/account/confirm", WithValidation[dto.ConfirmAccountDeletionRequest](config.ValidationMiddleware, h.ConfirmAccountDeletion))
} else { } else {
r.Post("/auth/register", h.Register) r.Post("/auth/register", WithValidation[dto.RegisterRequest](config.ValidationMiddleware, h.Register))
r.Post("/auth/login", h.Login) r.Post("/auth/login", WithValidation[dto.LoginRequest](config.ValidationMiddleware, h.Login))
r.Post("/auth/forgot-password", h.RequestPasswordReset) r.Post("/auth/forgot-password", WithValidation[dto.ForgotPasswordRequest](config.ValidationMiddleware, h.RequestPasswordReset))
r.Post("/auth/reset-password", h.ResetPassword) r.Post("/auth/reset-password", WithValidation[dto.ResetPasswordRequest](config.ValidationMiddleware, h.ResetPassword))
r.Post("/auth/account/confirm", h.ConfirmAccountDeletion) r.Post("/auth/account/confirm", WithValidation[dto.ConfirmAccountDeletionRequest](config.ValidationMiddleware, h.ConfirmAccountDeletion))
} }
protected := r protected := r
@@ -816,10 +724,10 @@ func (h *AuthHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
protected.Get("/auth/me", h.Me) protected.Get("/auth/me", h.Me)
protected.Post("/auth/logout", h.Logout) protected.Post("/auth/logout", h.Logout)
protected.Post("/auth/revoke", h.RevokeToken) protected.Post("/auth/revoke", WithValidation[dto.RevokeTokenRequest](config.ValidationMiddleware, h.RevokeToken))
protected.Post("/auth/revoke-all", h.RevokeAllTokens) protected.Post("/auth/revoke-all", h.RevokeAllTokens)
protected.Put("/auth/email", h.UpdateEmail) protected.Put("/auth/email", WithValidation[dto.UpdateEmailRequest](config.ValidationMiddleware, h.UpdateEmail))
protected.Put("/auth/username", h.UpdateUsername) protected.Put("/auth/username", WithValidation[dto.UpdateUsernameRequest](config.ValidationMiddleware, h.UpdateUsername))
protected.Put("/auth/password", h.UpdatePassword) protected.Put("/auth/password", WithValidation[dto.UpdatePasswordRequest](config.ValidationMiddleware, h.UpdatePassword))
protected.Delete("/auth/account", h.DeleteAccount) protected.Delete("/auth/account", h.DeleteAccount)
} }

View File

@@ -252,8 +252,8 @@ func TestAuthHandlerLoginSuccess(t *testing.T) {
} }
handler := newAuthHandler(repo) handler := newAuthHandler(repo)
body := bytes.NewBufferString(`{"username":"user","password":"Password123!"}`) bodyStr := `{"username":"user","password":"Password123!"}`
request := httptest.NewRequest(http.MethodPost, "/api/auth/login", body) request := createLoginRequest(bodyStr)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.Login(recorder, request) handler.Login(recorder, request)
@@ -274,17 +274,17 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
handler := newAuthHandler(&testutils.UserRepositoryStub{}) handler := newAuthHandler(&testutils.UserRepositoryStub{})
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString("invalid")) request := createLoginRequest("invalid")
handler.Login(recorder, request) handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":" ","password":""}`)) request = createLoginRequest(`{"username":" ","password":""}`)
handler.Login(recorder, request) handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"WrongPass123!"}`)) request = createLoginRequest(`{"username":"user","password":"WrongPass123!"}`)
handler.Login(recorder, request) handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized) testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
@@ -294,7 +294,7 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
}} }}
handler = newAuthHandler(repo) handler = newAuthHandler(repo)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"Password123!"}`)) request = createLoginRequest(`{"username":"user","password":"Password123!"}`)
handler.Login(recorder, request) handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusForbidden) testutils.AssertHTTPStatus(t, recorder, http.StatusForbidden)
@@ -304,7 +304,7 @@ func TestAuthHandlerLoginErrors(t *testing.T) {
handler = newAuthHandler(repo) handler = newAuthHandler(repo)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBufferString(`{"username":"user","password":"Password123!"}`)) request = createLoginRequest(`{"username":"user","password":"Password123!"}`)
handler.Login(recorder, request) handler.Login(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusInternalServerError) testutils.AssertHTTPStatus(t, recorder, http.StatusInternalServerError)
@@ -330,8 +330,7 @@ func TestAuthHandlerRegisterSuccess(t *testing.T) {
return nil return nil
}}) }})
body := bytes.NewBufferString(`{"username":"newuser","email":"new@example.com","password":"Password123!"}`) request := createRegisterRequest(`{"username":"newuser","email":"new@example.com","password":"Password123!"}`)
request := httptest.NewRequest(http.MethodPost, "/api/auth/register", body)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.Register(recorder, request) handler.Register(recorder, request)
@@ -354,12 +353,12 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
handler := newAuthHandler(repo) handler := newAuthHandler(repo)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString("invalid")) request := createRegisterRequest("invalid")
handler.Register(recorder, request) handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"","email":"","password":""}`)) request = createRegisterRequest(`{"username":"","email":"","password":""}`)
handler.Register(recorder, request) handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -368,7 +367,7 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
}} }}
handler = newAuthHandler(repo) handler = newAuthHandler(repo)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"new","email":"taken@example.com","password":"Password123!"}`)) request = createRegisterRequest(`{"username":"new","email":"taken@example.com","password":"Password123!"}`)
handler.Register(recorder, request) handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict) testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
@@ -382,7 +381,7 @@ func TestAuthHandlerRegisterErrors(t *testing.T) {
} }
handler = newAuthHandler(repo) handler = newAuthHandler(repo)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBufferString(`{"username":"another","email":"taken@example.com","password":"Password123!"}`)) request = createRegisterRequest(`{"username":"another","email":"taken@example.com","password":"Password123!"}`)
handler.Register(recorder, request) handler.Register(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict) testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
} }
@@ -477,7 +476,7 @@ func TestAuthHandlerRequestPasswordReset(t *testing.T) {
}}) }})
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":"user@example.com"}`)) request := createForgotPasswordRequest(`{"username_or_email":"user@example.com"}`)
handler.RequestPasswordReset(recorder, request) handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK) testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
@@ -495,19 +494,19 @@ func TestAuthHandlerRequestPasswordReset(t *testing.T) {
}}) }})
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":"user"}`)) request = createForgotPasswordRequest(`{"username_or_email":"user"}`)
handler.RequestPasswordReset(recorder, request) handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK) testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`{"username_or_email":""}`)) request = createForgotPasswordRequest(`{"username_or_email":""}`)
handler.RequestPasswordReset(recorder, request) handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/forgot-password", bytes.NewBufferString(`invalid json`)) request = createForgotPasswordRequest(`invalid json`)
handler.RequestPasswordReset(recorder, request) handler.RequestPasswordReset(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -518,25 +517,25 @@ func TestAuthHandlerResetPassword(t *testing.T) {
handler := newAuthHandler(repo) handler := newAuthHandler(repo)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"new_password":"NewPassword123!"}`)) request := createResetPasswordRequest(`{"new_password":"NewPassword123!"}`)
handler.ResetPassword(recorder, request) handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"valid_token"}`)) request = createResetPasswordRequest(`{"token":"valid_token"}`)
handler.ResetPassword(recorder, request) handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"valid_token","new_password":"short"}`)) request = createResetPasswordRequest(`{"token":"valid_token","new_password":"short"}`)
handler.ResetPassword(recorder, request) handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`invalid json`)) request = createResetPasswordRequest(`invalid json`)
handler.ResetPassword(recorder, request) handler.ResetPassword(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
@@ -602,7 +601,7 @@ func TestAuthHandlerResetPasswordServiceOutcomes(t *testing.T) {
handler := newMockAuthHandler(repo, mockService) handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPost, "/api/auth/reset-password", bytes.NewBufferString(`{"token":"abc","new_password":"Password123!"}`)) request := createResetPasswordRequest(`{"token":"abc","new_password":"Password123!"}`)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.ResetPassword(recorder, request) handler.ResetPassword(recorder, request)
@@ -664,7 +663,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
userID: 1, userID: 1,
mockSetup: func(repo *testutils.UserRepositoryStub) {}, mockSetup: func(repo *testutils.UserRepositoryStub) {},
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body", expectedError: "Invalid request",
}, },
{ {
name: "empty email", name: "empty email",
@@ -702,7 +701,7 @@ func TestAuthHandlerUpdateEmail(t *testing.T) {
handler := newMockAuthHandler(repo, mockService) handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPut, "/api/auth/email", bytes.NewBufferString(tt.requestBody)) request := createUpdateEmailRequest(tt.requestBody)
if tt.userID > 0 { if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID) ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -789,7 +788,7 @@ func TestAuthHandlerUpdateUsername(t *testing.T) {
handler := newMockAuthHandler(repo, mockService) handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPut, "/api/auth/username", bytes.NewBufferString(tt.requestBody)) request := createUpdateUsernameRequest(tt.requestBody)
if tt.userID > 0 { if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID) ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -886,7 +885,7 @@ func TestAuthHandlerUpdatePassword(t *testing.T) {
tt.mockSetup(repo) tt.mockSetup(repo)
handler := newAuthHandler(repo) handler := newAuthHandler(repo)
request := httptest.NewRequest(http.MethodPut, "/api/auth/password", bytes.NewBufferString(tt.requestBody)) request := createUpdatePasswordRequest(tt.requestBody)
if tt.userID > 0 { if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID) ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -984,8 +983,7 @@ func TestAuthHandlerDeleteAccount(t *testing.T) {
func TestAuthHandlerResendVerificationEmail(t *testing.T) { func TestAuthHandlerResendVerificationEmail(t *testing.T) {
makeRequest := func(body string, setup func(*mockAuthService)) (*httptest.ResponseRecorder, AuthResponse) { makeRequest := func(body string, setup func(*mockAuthService)) (*httptest.ResponseRecorder, AuthResponse) {
request := httptest.NewRequest(http.MethodPost, "/api/auth/resend-verification", bytes.NewBufferString(body)) request := createResendVerificationRequest(body)
request = request.WithContext(context.Background())
repo := &testutils.UserRepositoryStub{} repo := &testutils.UserRepositoryStub{}
mockService := &mockAuthService{} mockService := &mockAuthService{}
@@ -1014,7 +1012,7 @@ func TestAuthHandlerResendVerificationEmail(t *testing.T) {
name: "invalid json", name: "invalid json",
body: "not-json", body: "not-json",
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body", expectedError: "Invalid request",
}, },
{ {
name: "missing email", name: "missing email",
@@ -1139,7 +1137,7 @@ func TestAuthHandlerConfirmAccountDeletion(t *testing.T) {
name: "invalid json", name: "invalid json",
body: "not-json", body: "not-json",
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusBadRequest,
expectedError: "Invalid request body", expectedError: "Invalid request",
}, },
{ {
name: "missing token", name: "missing token",
@@ -1209,7 +1207,7 @@ func TestAuthHandlerConfirmAccountDeletion(t *testing.T) {
handler := newMockAuthHandler(repo, mockService) handler := newMockAuthHandler(repo, mockService)
request := httptest.NewRequest(http.MethodPost, "/api/auth/account/confirm", bytes.NewBufferString(tt.body)) request := createConfirmAccountDeletionRequest(tt.body)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.ConfirmAccountDeletion(recorder, request) handler.ConfirmAccountDeletion(recorder, request)
@@ -1338,9 +1336,7 @@ func TestAuthHandler_ConcurrentAccess(t *testing.T) {
for i := 0; i < concurrency; i++ { for i := 0; i < concurrency; i++ {
go func() { go func() {
body := bytes.NewBufferString(`{"username":"testuser","password":"Password123!"}`) req := createLoginRequest(`{"username":"testuser","password":"Password123!"}`)
req := httptest.NewRequest("POST", "/api/auth/login", body)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
handler.Login(w, req) handler.Login(w, req)
@@ -1370,8 +1366,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
}, nil }, nil
} }
body := bytes.NewBufferString(`{"refresh_token":"valid_refresh_token"}`) req := createRefreshTokenRequest(`{"refresh_token":"valid_refresh_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1381,8 +1376,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
}) })
t.Run("Invalid_Request_Body", func(t *testing.T) { t.Run("Invalid_Request_Body", func(t *testing.T) {
body := bytes.NewBufferString(`invalid json`) req := createRefreshTokenRequest(`invalid json`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1392,8 +1386,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
}) })
t.Run("Missing_Refresh_Token", func(t *testing.T) { t.Run("Missing_Refresh_Token", func(t *testing.T) {
body := bytes.NewBufferString(`{"refresh_token":""}`) req := createRefreshTokenRequest(`{"refresh_token":""}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1407,8 +1400,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrRefreshTokenExpired return nil, services.ErrRefreshTokenExpired
} }
body := bytes.NewBufferString(`{"refresh_token":"expired_token"}`) req := createRefreshTokenRequest(`{"refresh_token":"expired_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1422,8 +1414,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrRefreshTokenInvalid return nil, services.ErrRefreshTokenInvalid
} }
body := bytes.NewBufferString(`{"refresh_token":"invalid_token"}`) req := createRefreshTokenRequest(`{"refresh_token":"invalid_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1437,8 +1428,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, services.ErrAccountLocked return nil, services.ErrAccountLocked
} }
body := bytes.NewBufferString(`{"refresh_token":"locked_token"}`) req := createRefreshTokenRequest(`{"refresh_token":"locked_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1452,8 +1442,7 @@ func TestAuthHandler_RefreshToken(t *testing.T) {
return nil, fmt.Errorf("internal error") return nil, fmt.Errorf("internal error")
} }
body := bytes.NewBufferString(`{"refresh_token":"error_token"}`) req := createRefreshTokenRequest(`{"refresh_token":"error_token"}`)
req := httptest.NewRequest("POST", "/api/auth/refresh", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1473,8 +1462,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
return nil return nil
} }
body := bytes.NewBufferString(`{"refresh_token":"token_to_revoke"}`) req := createRevokeTokenRequest(`{"refresh_token":"token_to_revoke"}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1484,8 +1472,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
}) })
t.Run("Invalid_Request_Body", func(t *testing.T) { t.Run("Invalid_Request_Body", func(t *testing.T) {
body := bytes.NewBufferString(`invalid json`) req := createRevokeTokenRequest(`invalid json`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1495,8 +1482,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
}) })
t.Run("Missing_Refresh_Token", func(t *testing.T) { t.Run("Missing_Refresh_Token", func(t *testing.T) {
body := bytes.NewBufferString(`{"refresh_token":""}`) req := createRevokeTokenRequest(`{"refresh_token":""}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -1510,8 +1496,7 @@ func TestAuthHandler_RevokeToken(t *testing.T) {
return fmt.Errorf("revoke failed") return fmt.Errorf("revoke failed")
} }
body := bytes.NewBufferString(`{"refresh_token":"token"}`) req := createRevokeTokenRequest(`{"refresh_token":"token"}`)
req := httptest.NewRequest("POST", "/api/auth/revoke", body)
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder() w := httptest.NewRecorder()

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"reflect"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -290,3 +291,24 @@ func HandleServiceError(w http.ResponseWriter, err error, defaultMsg string, def
SendErrorResponse(w, defaultMsg, defaultCode) SendErrorResponse(w, defaultMsg, defaultCode)
return false return false
} }
func GetValidatedDTO[T any](r *http.Request) (*T, bool) {
dtoVal := middleware.GetValidatedDTOFromContext(r.Context())
if dtoVal == nil {
return nil, false
}
dto, ok := dtoVal.(*T)
return dto, ok
}
func WithValidation[T any](validationMiddleware func(http.Handler) http.Handler, handler http.HandlerFunc) http.HandlerFunc {
if validationMiddleware == nil {
return handler
}
var zero T
dtoType := reflect.TypeOf(zero)
return func(w http.ResponseWriter, r *http.Request) {
ctx := middleware.SetDTOTypeInContext(r.Context(), dtoType)
validationMiddleware(handler).ServeHTTP(w, r.WithContext(ctx))
}
}

View File

@@ -1,6 +1,7 @@
package handlers package handlers
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
@@ -11,6 +12,7 @@ import (
"testing" "testing"
"goyco/internal/database" "goyco/internal/database"
"goyco/internal/dto"
"goyco/internal/middleware" "goyco/internal/middleware"
"goyco/internal/services" "goyco/internal/services"
"goyco/internal/testutils" "goyco/internal/testutils"
@@ -721,6 +723,74 @@ func TestDecodeJSONRequest(t *testing.T) {
} }
} }
func createRequestWithDTO[T any](method, url string, body []byte) *http.Request {
r := httptest.NewRequest(method, url, bytes.NewReader(body))
var dto T
if len(body) > 0 {
if err := json.NewDecoder(bytes.NewReader(body)).Decode(&dto); err != nil {
return r
}
}
ctx := middleware.SetValidatedDTOInContext(r.Context(), &dto)
return r.WithContext(ctx)
}
func createLoginRequest(body string) *http.Request {
return createRequestWithDTO[dto.LoginRequest](http.MethodPost, "/api/auth/login", []byte(body))
}
func createRegisterRequest(body string) *http.Request {
return createRequestWithDTO[dto.RegisterRequest](http.MethodPost, "/api/auth/register", []byte(body))
}
func createResendVerificationRequest(body string) *http.Request {
return createRequestWithDTO[dto.ResendVerificationRequest](http.MethodPost, "/api/auth/resend-verification", []byte(body))
}
func createForgotPasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.ForgotPasswordRequest](http.MethodPost, "/api/auth/forgot-password", []byte(body))
}
func createResetPasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.ResetPasswordRequest](http.MethodPost, "/api/auth/reset-password", []byte(body))
}
func createUpdateEmailRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdateEmailRequest](http.MethodPut, "/api/auth/email", []byte(body))
}
func createUpdateUsernameRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdateUsernameRequest](http.MethodPut, "/api/auth/username", []byte(body))
}
func createUpdatePasswordRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdatePasswordRequest](http.MethodPut, "/api/auth/password", []byte(body))
}
func createConfirmAccountDeletionRequest(body string) *http.Request {
return createRequestWithDTO[dto.ConfirmAccountDeletionRequest](http.MethodPost, "/api/auth/account/confirm", []byte(body))
}
func createRefreshTokenRequest(body string) *http.Request {
return createRequestWithDTO[dto.RefreshTokenRequest](http.MethodPost, "/api/auth/refresh", []byte(body))
}
func createRevokeTokenRequest(body string) *http.Request {
return createRequestWithDTO[dto.RevokeTokenRequest](http.MethodPost, "/api/auth/revoke", []byte(body))
}
func createCreatePostRequest(body string) *http.Request {
return createRequestWithDTO[dto.CreatePostRequest](http.MethodPost, "/api/posts", []byte(body))
}
func createUpdatePostRequest(body string) *http.Request {
return createRequestWithDTO[dto.UpdatePostRequest](http.MethodPut, "/api/posts/1", []byte(body))
}
func createVoteRequest(body string) *http.Request {
return createRequestWithDTO[dto.CastVoteRequest](http.MethodPost, "/api/posts/1/vote", []byte(body))
}
func TestParsePagination(t *testing.T) { func TestParsePagination(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"html/template" "html/template"
"net/http" "net/http"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
@@ -877,7 +878,8 @@ func (h *PageHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
func (h *PageHandler) Settings(w http.ResponseWriter, r *http.Request) { func (h *PageHandler) Settings(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r) user := h.currentUserWithLockCheck(w, r)
if user == nil { if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther) redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return return
} }
@@ -897,7 +899,8 @@ func (h *PageHandler) Settings(w http.ResponseWriter, r *http.Request) {
func (h *PageHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) { func (h *PageHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r) user := h.currentUserWithLockCheck(w, r)
if user == nil { if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther) redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return return
} }
@@ -960,13 +963,15 @@ func (h *PageHandler) UpdateEmail(w http.ResponseWriter, r *http.Request) {
} }
h.clearAuthCookie(w, r) h.clearAuthCookie(w, r)
http.Redirect(w, r, "/login?flash=Email updated. Check your inbox to confirm the new address. You will need to sign in again after verification.", http.StatusSeeOther) redirectURL := "/login?flash=" + url.QueryEscape("Email updated. Check your inbox to confirm the new address. You will need to sign in again after verification.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
} }
func (h *PageHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) { func (h *PageHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r) user := h.currentUserWithLockCheck(w, r)
if user == nil { if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther) redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return return
} }
@@ -1022,13 +1027,15 @@ func (h *PageHandler) UpdateUsername(w http.ResponseWriter, r *http.Request) {
return return
} }
http.Redirect(w, r, "/settings?flash=Username updated successfully.", http.StatusSeeOther) redirectURL := "/settings?flash=" + url.QueryEscape("Username updated successfully.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
} }
func (h *PageHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) { func (h *PageHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r) user := h.currentUserWithLockCheck(w, r)
if user == nil { if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther) redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return return
} }
@@ -1140,13 +1147,15 @@ func (h *PageHandler) UpdatePassword(w http.ResponseWriter, r *http.Request) {
return return
} }
http.Redirect(w, r, "/settings?flash=Password updated successfully.", http.StatusSeeOther) redirectURL := "/settings?flash=" + url.QueryEscape("Password updated successfully.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
} }
func (h *PageHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) { func (h *PageHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r) user := h.currentUserWithLockCheck(w, r)
if user == nil { if user == nil {
http.Redirect(w, r, "/login?flash=Sign in to manage your account", http.StatusSeeOther) redirectURL := "/login?flash=" + url.QueryEscape("Sign in to manage your account")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return return
} }
@@ -1204,7 +1213,8 @@ func (h *PageHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
return return
} }
http.Redirect(w, r, "/settings?flash=Check your inbox for a confirmation link to finish deleting your account.", http.StatusSeeOther) redirectURL := "/settings?flash=" + url.QueryEscape("Check your inbox for a confirmation link to finish deleting your account.")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
} }
func (h *PageHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Request) { func (h *PageHandler) ConfirmAccountDeletion(w http.ResponseWriter, r *http.Request) {
@@ -1328,7 +1338,8 @@ func (h *PageHandler) clearAuthCookie(w http.ResponseWriter, r *http.Request) {
func (h *PageHandler) Vote(w http.ResponseWriter, r *http.Request) { func (h *PageHandler) Vote(w http.ResponseWriter, r *http.Request) {
user := h.currentUserWithLockCheck(w, r) user := h.currentUserWithLockCheck(w, r)
if user == nil { if user == nil {
http.Redirect(w, r, "/login?flash=Please sign in to vote", http.StatusSeeOther) redirectURL := "/login?flash=" + url.QueryEscape("Please sign in to vote")
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
return return
} }

View File

@@ -36,11 +36,6 @@ func NewPostHandler(postRepo repositories.PostRepository, titleFetcher services.
type PostResponse = CommonResponse type PostResponse = CommonResponse
type UpdatePostRequest struct {
Title string `json:"title"`
Content string `json:"content"`
}
// @Summary Get posts // @Summary Get posts
// @Description Get a list of posts with pagination. Posts include vote statistics (up_votes, down_votes, score) and current user's vote status. // @Description Get a list of posts with pagination. Posts include vote statistics (up_votes, down_votes, score) and current user's vote status.
// @Tags posts // @Tags posts
@@ -111,7 +106,7 @@ func (h *PostHandler) GetPost(w http.ResponseWriter, r *http.Request) {
// @Accept json // @Accept json
// @Produce json // @Produce json
// @Security BearerAuth // @Security BearerAuth
// @Param request body CreatePostRequest true "Post data" // @Param request body dto.CreatePostRequest true "Post data"
// @Success 201 {object} PostResponse // @Success 201 {object} PostResponse
// @Failure 400 {object} PostResponse "Invalid request data or validation failed" // @Failure 400 {object} PostResponse "Invalid request data or validation failed"
// @Failure 401 {object} PostResponse "Authentication required" // @Failure 401 {object} PostResponse "Authentication required"
@@ -120,32 +115,9 @@ func (h *PostHandler) GetPost(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} PostResponse "Internal server error" // @Failure 500 {object} PostResponse "Internal server error"
// @Router /api/posts [post] // @Router /api/posts [post]
func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) { func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
var req struct { req, ok := GetValidatedDTO[dto.CreatePostRequest](r)
Title string `json:"title"` if !ok {
URL string `json:"url"` SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
Content string `json:"content"`
}
if !DecodeJSONRequest(w, r, &req) {
return
}
req.Title = security.SanitizeInput(req.Title)
req.URL = security.SanitizeURL(req.URL)
req.Content = security.SanitizePostContent(req.Content)
if req.URL == "" {
SendErrorResponse(w, "URL is required", http.StatusBadRequest)
return
}
if len(req.Title) > 200 {
SendErrorResponse(w, "Title must be no more than 200 characters", http.StatusBadRequest)
return
}
if len(req.Content) > 10000 {
SendErrorResponse(w, "Content must be no more than 10,000 characters", http.StatusBadRequest)
return return
} }
@@ -154,13 +126,20 @@ func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
return return
} }
title := req.Title title := security.SanitizeInput(req.Title)
url := security.SanitizeURL(req.URL)
content := security.SanitizePostContent(req.Content)
if url == "" {
SendErrorResponse(w, "Invalid URL", http.StatusBadRequest)
return
}
if title == "" && h.titleFetcher != nil { if title == "" && h.titleFetcher != nil {
titleCtx, cancel := context.WithTimeout(r.Context(), 7*time.Second) titleCtx, cancel := context.WithTimeout(r.Context(), 7*time.Second)
defer cancel() defer cancel()
fetchedTitle, err := h.titleFetcher.FetchTitle(titleCtx, req.URL) fetchedTitle, err := h.titleFetcher.FetchTitle(titleCtx, url)
if err != nil { if err != nil {
switch { switch {
case errors.Is(err, services.ErrUnsupportedScheme): case errors.Is(err, services.ErrUnsupportedScheme):
@@ -186,10 +165,20 @@ func (h *PostHandler) CreatePost(w http.ResponseWriter, r *http.Request) {
return return
} }
if len(title) > 200 {
SendErrorResponse(w, "Title must be at most 200 characters", http.StatusBadRequest)
return
}
if len(content) > 10000 {
SendErrorResponse(w, "Content must be at most 10000 characters", http.StatusBadRequest)
return
}
post := &database.Post{ post := &database.Post{
Title: title, Title: title,
URL: req.URL, URL: url,
Content: req.Content, Content: content,
AuthorID: &userID, AuthorID: &userID,
} }
@@ -257,7 +246,7 @@ func (h *PostHandler) SearchPosts(w http.ResponseWriter, r *http.Request) {
// @Produce json // @Produce json
// @Security BearerAuth // @Security BearerAuth
// @Param id path int true "Post ID" // @Param id path int true "Post ID"
// @Param request body UpdatePostRequest true "Post update data" // @Param request body dto.UpdatePostRequest true "Post update data"
// @Success 200 {object} PostResponse "Post updated successfully" // @Success 200 {object} PostResponse "Post updated successfully"
// @Failure 400 {object} PostResponse "Invalid request data or validation failed" // @Failure 400 {object} PostResponse "Invalid request data or validation failed"
// @Failure 401 {object} PostResponse "Authentication required" // @Failure 401 {object} PostResponse "Authentication required"
@@ -286,40 +275,27 @@ func (h *PostHandler) UpdatePost(w http.ResponseWriter, r *http.Request) {
return return
} }
var req struct { req, ok := GetValidatedDTO[dto.UpdatePostRequest](r)
Title string `json:"title"` if !ok {
Content string `json:"content"` SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
}
if !DecodeJSONRequest(w, r, &req) {
return return
} }
req.Title = security.SanitizeInput(req.Title) title := security.SanitizeInput(req.Title)
req.Content = security.SanitizePostContent(req.Content) content := security.SanitizePostContent(req.Content)
if len(req.Title) > 200 { if err := validation.ValidateTitle(title); err != nil {
SendErrorResponse(w, "Title must be no more than 200 characters", http.StatusBadRequest)
return
}
if len(req.Content) > 10000 {
SendErrorResponse(w, "Content must be no more than 10,000 characters", http.StatusBadRequest)
return
}
if err := validation.ValidateTitle(req.Title); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest) SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return return
} }
if err := validation.ValidateContent(req.Content); err != nil { if err := validation.ValidateContent(content); err != nil {
SendErrorResponse(w, err.Error(), http.StatusBadRequest) SendErrorResponse(w, err.Error(), http.StatusBadRequest)
return return
} }
post.Title = req.Title post.Title = title
post.Content = req.Content post.Content = content
if err := h.postRepo.Update(post); err != nil { if err := h.postRepo.Update(post); err != nil {
SendErrorResponse(w, "Failed to update post", http.StatusInternalServerError) SendErrorResponse(w, "Failed to update post", http.StatusInternalServerError)
@@ -458,7 +434,7 @@ func (h *PostHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
if config.GeneralRateLimit != nil { if config.GeneralRateLimit != nil {
protected = config.GeneralRateLimit(protected) protected = config.GeneralRateLimit(protected)
} }
protected.Post("/posts", h.CreatePost) protected.Post("/posts", WithValidation[dto.CreatePostRequest](config.ValidationMiddleware, h.CreatePost))
protected.Put("/posts/{id}", h.UpdatePost) protected.Put("/posts/{id}", WithValidation[dto.UpdatePostRequest](config.ValidationMiddleware, h.UpdatePost))
protected.Delete("/posts/{id}", h.DeletePost) protected.Delete("/posts/{id}", h.DeletePost)
} }

View File

@@ -69,9 +69,8 @@ func TestPostHandlerCreatePostWithTitleFetcher(t *testing.T) {
handler := NewPostHandler(repo, titleFetcher, nil) handler := NewPostHandler(repo, titleFetcher, nil)
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"url":"https://example.com","content":"Test content"}`)) request := createCreatePostRequest(`{"url":"https://example.com","content":"Test content"}`)
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1)) request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request) handler.CreatePost(recorder, request)
@@ -171,7 +170,7 @@ func TestPostHandlerUpdatePostUnauthorized(t *testing.T) {
handler := NewPostHandler(repo, nil, nil) handler := NewPostHandler(repo, nil, nil)
request := httptest.NewRequest(http.MethodPut, "/api/posts/1", bytes.NewBufferString(`{"title":"Updated Title","content":"Updated content"}`)) request := createUpdatePostRequest(`{"title":"Updated Title","content":"Updated content"}`)
request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1)) request = testutils.WithUserContext(request, middleware.UserIDKey, uint(1))
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
@@ -278,8 +277,7 @@ func TestPostHandlerCreatePostSuccess(t *testing.T) {
handler := NewPostHandler(repo, fetcher, nil) handler := NewPostHandler(repo, fetcher, nil)
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com","content":"Go"}`) request := createCreatePostRequest(`{"title":" ","url":"https://example.com","content":"Go"}`)
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(42)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(42))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -297,7 +295,7 @@ func TestPostHandlerCreatePostValidation(t *testing.T) {
handler := NewPostHandler(testutils.NewPostRepositoryStub(), &testutils.TitleFetcherStub{}, nil) handler := NewPostHandler(testutils.NewPostRepositoryStub(), &testutils.TitleFetcherStub{}, nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"","url":"","content":""}`)) request := createCreatePostRequest(`{"title":"","url":"","content":""}`)
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1))) request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
handler.CreatePost(recorder, request) handler.CreatePost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest { if recorder.Result().StatusCode != http.StatusBadRequest {
@@ -305,14 +303,14 @@ func TestPostHandlerCreatePostValidation(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`invalid json`)) request = createCreatePostRequest(`invalid json`)
handler.CreatePost(recorder, request) handler.CreatePost(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest { if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid JSON, got %d", recorder.Result().StatusCode) t.Fatalf("expected 400 for invalid JSON, got %d", recorder.Result().StatusCode)
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts", bytes.NewBufferString(`{"title":"ok","url":"https://example.com"}`)) request = createCreatePostRequest(`{"title":"ok","url":"https://example.com"}`)
handler.CreatePost(recorder, request) handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized) testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
} }
@@ -336,8 +334,7 @@ func TestPostHandlerCreatePostTitleFetcherErrors(t *testing.T) {
return "", tc.err return "", tc.err
}} }}
handler := NewPostHandler(repo, fetcher, nil) handler := NewPostHandler(repo, fetcher, nil)
body := bytes.NewBufferString(`{"title":" ","url":"https://example.com"}`) request := createCreatePostRequest(`{"title":" ","url":"https://example.com"}`)
request := httptest.NewRequest(http.MethodPost, "/api/posts", body)
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1))) request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -495,7 +492,7 @@ func TestPostHandlerUpdatePost(t *testing.T) {
} }
handler := NewPostHandler(repo, &testutils.TitleFetcherStub{}, nil) handler := NewPostHandler(repo, &testutils.TitleFetcherStub{}, nil)
request := httptest.NewRequest(http.MethodPut, "/api/posts/"+tt.postID, bytes.NewBufferString(tt.requestBody)) request := createUpdatePostRequest(tt.requestBody)
if tt.userID > 0 { if tt.userID > 0 {
ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID) ctx := context.WithValue(request.Context(), middleware.UserIDKey, tt.userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -669,6 +666,9 @@ func (e *errorVoteRepository) Delete(uint) error { ret
func (e *errorVoteRepository) Count() (int64, error) { return 0, nil } func (e *errorVoteRepository) Count() (int64, error) { return 0, nil }
func (e *errorVoteRepository) CountByPostID(uint) (int64, error) { return 0, nil } func (e *errorVoteRepository) CountByPostID(uint) (int64, error) { return 0, nil }
func (e *errorVoteRepository) CountByUserID(uint) (int64, error) { return 0, nil } func (e *errorVoteRepository) CountByUserID(uint) (int64, error) { return 0, nil }
func (e *errorVoteRepository) GetVoteCountsByPostID(uint) (int, int, error) {
return 0, 0, errors.New("database error")
}
func (e *errorVoteRepository) WithTx(*gorm.DB) repositories.VoteRepository { return e } func (e *errorVoteRepository) WithTx(*gorm.DB) repositories.VoteRepository { return e }
func TestPostHandler_EdgeCases(t *testing.T) { func TestPostHandler_EdgeCases(t *testing.T) {

View File

@@ -18,4 +18,5 @@ type RouteModuleConfig struct {
AuthRateLimit func(chi.Router) chi.Router AuthRateLimit func(chi.Router) chi.Router
CSRFMiddleware func(http.Handler) http.Handler CSRFMiddleware func(http.Handler) http.Handler
AuthMiddleware func(http.Handler) http.Handler AuthMiddleware func(http.Handler) http.Handler
ValidationMiddleware func(http.Handler) http.Handler
} }

View File

@@ -24,10 +24,6 @@ func TestPostHandler_XSSProtection_Comprehensive(t *testing.T) {
t.Run("XSS_"+payload[:minLen(20, len(payload))], func(t *testing.T) { t.Run("XSS_"+payload[:minLen(20, len(payload))], func(t *testing.T) {
repo := &testutils.PostRepositoryStub{ repo := &testutils.PostRepositoryStub{
CreateFn: func(post *database.Post) error { CreateFn: func(post *database.Post) error {
sanitizedTitle := security.SanitizeInput(payload)
if post.Title != sanitizedTitle {
t.Errorf("Expected sanitized title, got %q", post.Title)
}
return nil return nil
}, },
} }
@@ -41,14 +37,46 @@ func TestPostHandler_XSSProtection_Comprehensive(t *testing.T) {
} }
body, _ := json.Marshal(postData) body, _ := json.Marshal(postData)
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) request := createCreatePostRequest(string(body))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1))) request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.CreatePost(recorder, request) handler.CreatePost(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated) if recorder.Code != http.StatusCreated {
t.Errorf("Expected status %d, got %d. Body: %s", http.StatusCreated, recorder.Code, recorder.Body.String())
return
}
var response CommonResponse
if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode response: %v", err)
}
if !response.Success {
t.Errorf("Expected successful response, got error: %s", response.Error)
return
}
dataMap, ok := response.Data.(map[string]any)
if !ok {
t.Fatalf("Expected data to be a map, got %T", response.Data)
}
title, ok := dataMap["title"].(string)
if !ok {
t.Fatalf("Expected title to be a string, got %T", dataMap["title"])
}
expectedSanitized := security.SanitizeInput(payload)
if title != expectedSanitized {
t.Errorf("Expected sanitized title %q, got %q", expectedSanitized, title)
}
if title == payload {
t.Errorf("Title was not sanitized - original payload %q matches response %q", payload, title)
}
}) })
} }
} }
@@ -123,7 +151,7 @@ func TestPostHandler_InputValidation(t *testing.T) {
} }
body, _ := json.Marshal(postData) body, _ := json.Marshal(postData)
request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) request := createCreatePostRequest(string(body))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1))) request = request.WithContext(context.WithValue(request.Context(), middleware.UserIDKey, uint(1)))
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -230,7 +258,7 @@ func TestAuthHandler_PasswordValidation(t *testing.T) {
} }
body, _ := json.Marshal(registerData) body, _ := json.Marshal(registerData)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body)) request := createRegisterRequest(string(body))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -290,7 +318,7 @@ func TestAuthHandler_UsernameSanitization(t *testing.T) {
} }
body, _ := json.Marshal(registerData) body, _ := json.Marshal(registerData)
request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body)) request := createRegisterRequest(string(body))
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()

View File

@@ -91,7 +91,7 @@ func (h *UserHandler) GetUser(w http.ResponseWriter, r *http.Request) {
// @Accept json // @Accept json
// @Produce json // @Produce json
// @Security BearerAuth // @Security BearerAuth
// @Param request body RegisterRequest true "User data" // @Param request body dto.RegisterRequest true "User data"
// @Success 201 {object} UserResponse "User created successfully" // @Success 201 {object} UserResponse "User created successfully"
// @Failure 400 {object} UserResponse "Invalid request data or validation failed" // @Failure 400 {object} UserResponse "Invalid request data or validation failed"
// @Failure 401 {object} UserResponse "Authentication required" // @Failure 401 {object} UserResponse "Authentication required"
@@ -99,13 +99,9 @@ func (h *UserHandler) GetUser(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} UserResponse "Internal server error" // @Failure 500 {object} UserResponse "Internal server error"
// @Router /api/users [post] // @Router /api/users [post]
func (h *UserHandler) CreateUser(w http.ResponseWriter, r *http.Request) { func (h *UserHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
var req struct { req, ok := GetValidatedDTO[dto.RegisterRequest](r)
Username string `json:"username"` if !ok {
Email string `json:"email"` SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
Password string `json:"password"`
}
if !DecodeJSONRequest(w, r, &req) {
return return
} }
@@ -189,7 +185,7 @@ func (h *UserHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
} }
protected.Get("/users", h.GetUsers) protected.Get("/users", h.GetUsers)
protected.Post("/users", h.CreateUser) protected.Post("/users", WithValidation[dto.RegisterRequest](config.ValidationMiddleware, h.CreateUser))
protected.Get("/users/{id}", h.GetUser) protected.Get("/users/{id}", h.GetUser)
protected.Get("/users/{id}/posts", h.GetUserPosts) protected.Get("/users/{id}/posts", h.GetUserPosts)
} }

View File

@@ -1,7 +1,6 @@
package handlers package handlers
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -103,7 +102,7 @@ func TestUserHandlerCreateUser(t *testing.T) {
return nil return nil
}}) }})
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`)) request := createRegisterRequest(`{"username":"user","email":"user@example.com","password":"Password123!"}`)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
handler.CreateUser(recorder, request) handler.CreateUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusCreated) testutils.AssertHTTPStatus(t, recorder, http.StatusCreated)
@@ -126,14 +125,14 @@ func TestUserHandlerCreateUser(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString("invalid")) request = createRegisterRequest("invalid")
handler.CreateUser(recorder, request) handler.CreateUser(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest { if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid json, got %d", recorder.Result().StatusCode) t.Fatalf("expected 400 for invalid json, got %d", recorder.Result().StatusCode)
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"","email":"","password":""}`)) request = createRegisterRequest(`{"username":"","email":"","password":""}`)
handler.CreateUser(recorder, request) handler.CreateUser(recorder, request)
if recorder.Result().StatusCode != http.StatusBadRequest { if recorder.Result().StatusCode != http.StatusBadRequest {
t.Fatalf("expected 400 for missing fields, got %d", recorder.Result().StatusCode) t.Fatalf("expected 400 for missing fields, got %d", recorder.Result().StatusCode)
@@ -144,7 +143,7 @@ func TestUserHandlerCreateUser(t *testing.T) {
} }
handler = newUserHandler(repo) handler = newUserHandler(repo)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(`{"username":"user","email":"user@example.com","password":"Password123!"}`)) request = createRegisterRequest(`{"username":"user","email":"user@example.com","password":"Password123!"}`)
handler.CreateUser(recorder, request) handler.CreateUser(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusConflict) testutils.AssertHTTPStatus(t, recorder, http.StatusConflict)
} }
@@ -350,7 +349,7 @@ func TestUserHandler_PasswordValidation(t *testing.T) {
handler := NewUserHandler(repo, authService) handler := NewUserHandler(repo, authService)
requestBody := fmt.Sprintf(`{"username":"testuser","email":"test@example.com","password":"%s"}`, tt.password) requestBody := fmt.Sprintf(`{"username":"testuser","email":"test@example.com","password":"%s"}`, tt.password)
request := httptest.NewRequest(http.MethodPost, "/api/users", bytes.NewBufferString(requestBody)) request := createRegisterRequest(requestBody)
request.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()

View File

@@ -4,6 +4,7 @@ import (
"net/http" "net/http"
"goyco/internal/database" "goyco/internal/database"
"goyco/internal/dto"
"goyco/internal/services" "goyco/internal/services"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
@@ -39,11 +40,6 @@ func NewVoteHandler(voteService *services.VoteService) *VoteHandler {
} }
} }
// @Description Vote request with type field. All votes are handled the same way.
type VoteRequest struct {
Type string `json:"type" example:"up" enums:"up,down,none" description:"Vote type: 'up' for upvote, 'down' for downvote, 'none' to remove vote"`
}
type VoteResponse = CommonResponse type VoteResponse = CommonResponse
// @Summary Cast a vote on a post // @Summary Cast a vote on a post
@@ -62,7 +58,7 @@ type VoteResponse = CommonResponse
// @Produce json // @Produce json
// @Security BearerAuth // @Security BearerAuth
// @Param id path int true "Post ID" // @Param id path int true "Post ID"
// @Param request body VoteRequest true "Vote data (type: 'up', 'down', or 'none' to remove)" // @Param request body dto.CastVoteRequest true "Vote data (type: 'up', 'down', or 'none' to remove)"
// @Success 200 {object} VoteResponse "Vote cast successfully with updated post statistics" // @Success 200 {object} VoteResponse "Vote cast successfully with updated post statistics"
// @Failure 401 {object} VoteResponse "Authentication required" // @Failure 401 {object} VoteResponse "Authentication required"
// @Failure 400 {object} VoteResponse "Invalid request data or vote type" // @Failure 400 {object} VoteResponse "Invalid request data or vote type"
@@ -82,8 +78,9 @@ func (h *VoteHandler) CastVote(w http.ResponseWriter, r *http.Request) {
return return
} }
var req VoteRequest req, ok := GetValidatedDTO[dto.CastVoteRequest](r)
if !DecodeJSONRequest(w, r, &req) { if !ok {
SendErrorResponse(w, "Invalid request", http.StatusBadRequest)
return return
} }
@@ -286,7 +283,7 @@ func (h *VoteHandler) MountRoutes(r chi.Router, config RouteModuleConfig) {
protected = config.GeneralRateLimit(protected) protected = config.GeneralRateLimit(protected)
} }
protected.Post("/posts/{id}/vote", h.CastVote) protected.Post("/posts/{id}/vote", WithValidation[dto.CastVoteRequest](config.ValidationMiddleware, h.CastVote))
protected.Delete("/posts/{id}/vote", h.RemoveVote) protected.Delete("/posts/{id}/vote", h.RemoveVote)
protected.Get("/posts/{id}/vote", h.GetUserVote) protected.Get("/posts/{id}/vote", h.GetUserVote)
protected.Get("/posts/{id}/votes", h.GetPostVotes) protected.Get("/posts/{id}/votes", h.GetPostVotes)

View File

@@ -1,7 +1,6 @@
package handlers package handlers
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@@ -59,13 +58,13 @@ func TestVoteHandlerCastVote(t *testing.T) {
handler := newVoteHandlerWithRepos() handler := newVoteHandlerWithRepos()
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
handler.CastVote(recorder, request) handler.CastVote(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized) testutils.AssertHTTPStatus(t, recorder, http.StatusUnauthorized)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/abc/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "abc"}) request = testutils.WithURLParams(request, map[string]string{"id": "abc"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -73,7 +72,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`invalid`)) request = createVoteRequest(`invalid`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -83,7 +82,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"maybe"}`)) request = createVoteRequest(`{"type":"maybe"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -93,7 +92,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -101,7 +100,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusOK) testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`)) request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -111,7 +110,7 @@ func TestVoteHandlerCastVote(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`)) request = createVoteRequest(`{"type":"none"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -125,7 +124,7 @@ func TestVoteHandlerCastVotePostNotFound(t *testing.T) {
handler, _, posts := newVoteHandlerWithReposRefs() handler, _, posts := newVoteHandlerWithReposRefs()
delete(posts, 1) delete(posts, 1)
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -164,7 +163,7 @@ func TestVoteHandlerRemoveVote(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -202,7 +201,7 @@ func TestVoteHandlerRemoveVotePostNotFound(t *testing.T) {
func TestVoteHandlerRemoveVoteUnexpectedError(t *testing.T) { func TestVoteHandlerRemoveVoteUnexpectedError(t *testing.T) {
handler, voteRepo, _ := newVoteHandlerWithReposRefs() handler, voteRepo, _ := newVoteHandlerWithReposRefs()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -257,7 +256,7 @@ func TestVoteHandlerGetUserVote(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -301,7 +300,7 @@ func TestVoteHandlerGetPostVotes(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -311,7 +310,7 @@ func TestVoteHandlerGetPostVotes(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`)) request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -345,7 +344,7 @@ func TestVoteFlowRegression(t *testing.T) {
postID := "1" postID := "1"
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, userID) ctx := context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -363,7 +362,7 @@ func TestVoteFlowRegression(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`)) request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID) ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -373,7 +372,7 @@ func TestVoteFlowRegression(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"none"}`)) request = createVoteRequest(`{"type":"none"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID) ctx = context.WithValue(request.Context(), middleware.UserIDKey, userID)
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -404,7 +403,7 @@ func TestVoteFlowRegression(t *testing.T) {
postID := "1" postID := "1"
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request := createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -414,7 +413,7 @@ func TestVoteFlowRegression(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"down"}`)) request = createVoteRequest(`{"type":"down"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(2))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -424,7 +423,7 @@ func TestVoteFlowRegression(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"up"}`)) request = createVoteRequest(`{"type":"up"}`)
request = testutils.WithURLParams(request, map[string]string{"id": postID}) request = testutils.WithURLParams(request, map[string]string{"id": postID})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(3))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -452,7 +451,7 @@ func TestVoteFlowRegression(t *testing.T) {
t.Run("ErrorHandlingEdgeCases", func(t *testing.T) { t.Run("ErrorHandlingEdgeCases", func(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(``)) request := createVoteRequest(``)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx := context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -460,7 +459,7 @@ func TestVoteFlowRegression(t *testing.T) {
testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest) testutils.AssertHTTPStatus(t, recorder, http.StatusBadRequest)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{}`)) request = createVoteRequest(`{}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)
@@ -470,7 +469,7 @@ func TestVoteFlowRegression(t *testing.T) {
} }
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
request = httptest.NewRequest(http.MethodPost, "/api/posts/1/vote", bytes.NewBufferString(`{"type":"invalid"}`)) request = createVoteRequest(`{"type":"invalid"}`)
request = testutils.WithURLParams(request, map[string]string{"id": "1"}) request = testutils.WithURLParams(request, map[string]string{"id": "1"})
ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1)) ctx = context.WithValue(request.Context(), middleware.UserIDKey, uint(1))
request = request.WithContext(ctx) request = request.WithContext(ctx)

View File

@@ -53,25 +53,25 @@ func TestIntegration_Caching(t *testing.T) {
router := ctx.Router router := ctx.Router
t.Run("Cache_Hit_On_Repeated_Requests", func(t *testing.T) { t.Run("Cache_Hit_On_Repeated_Requests", func(t *testing.T) {
req1 := httptest.NewRequest("GET", "/api/posts", nil) firstRequest := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder() firstRecorder := httptest.NewRecorder()
router.ServeHTTP(rec1, req1) router.ServeHTTP(firstRecorder, firstRequest)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
req2 := httptest.NewRequest("GET", "/api/posts", nil) secondRequest := httptest.NewRequest("GET", "/api/posts", nil)
rec2 := httptest.NewRecorder() secondRecorder := httptest.NewRecorder()
router.ServeHTTP(rec2, req2) router.ServeHTTP(secondRecorder, secondRequest)
if rec1.Code != rec2.Code { if firstRecorder.Code != secondRecorder.Code {
t.Error("Cached responses should have same status code") t.Error("Cached responses should have same status code")
} }
if rec1.Body.String() != rec2.Body.String() { if firstRecorder.Body.String() != secondRecorder.Body.String() {
t.Error("Cached responses should have same body") t.Error("Cached responses should have same body")
} }
if rec2.Header().Get("X-Cache") != "HIT" { if secondRecorder.Header().Get("X-Cache") != "HIT" {
t.Log("Cache may not be enabled for this path or response may not be cacheable") t.Log("Cache may not be enabled for this path or response may not be cacheable")
} }
}) })
@@ -80,9 +80,9 @@ func TestIntegration_Caching(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "cache_post_user", "cache_post@example.com") user := createUserWithCleanup(t, ctx, "cache_post_user", "cache_post@example.com")
req1 := httptest.NewRequest("GET", "/api/posts", nil) firstRequest := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder() firstRecorder := httptest.NewRecorder()
router.ServeHTTP(rec1, req1) router.ServeHTTP(firstRecorder, firstRequest)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
@@ -92,12 +92,12 @@ func TestIntegration_Caching(t *testing.T) {
"content": "Test content", "content": "Test content",
} }
body, _ := json.Marshal(postBody) body, _ := json.Marshal(postBody)
req2 := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) secondRequest := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req2.Header.Set("Content-Type", "application/json") secondRequest.Header.Set("Content-Type", "application/json")
req2.Header.Set("Authorization", "Bearer "+user.Token) secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID) secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder() secondRecorder := httptest.NewRecorder()
router.ServeHTTP(rec2, req2) router.ServeHTTP(secondRecorder, secondRequest)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
@@ -105,17 +105,17 @@ func TestIntegration_Caching(t *testing.T) {
rec3 := httptest.NewRecorder() rec3 := httptest.NewRecorder()
router.ServeHTTP(rec3, req3) router.ServeHTTP(rec3, req3)
if rec1.Body.String() == rec3.Body.String() && rec1.Code == http.StatusOK && rec3.Code == http.StatusOK { if firstRecorder.Body.String() == rec3.Body.String() && firstRecorder.Code == http.StatusOK && rec3.Code == http.StatusOK {
t.Log("Cache invalidation may not be working or cache may not be enabled") t.Log("Cache invalidation may not be working or cache may not be enabled")
} }
}) })
t.Run("Cache_Headers_Present", func(t *testing.T) { t.Run("Cache_Headers_Present", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil) request := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Header().Get("Cache-Control") == "" && rec.Header().Get("X-Cache") == "" { if recorder.Header().Get("Cache-Control") == "" && recorder.Header().Get("X-Cache") == "" {
t.Log("Cache headers may not be present for all responses") t.Log("Cache headers may not be present for all responses")
} }
}) })
@@ -126,18 +126,18 @@ func TestIntegration_Caching(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Cache Delete Post", "https://example.com/cache-delete") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Cache Delete Post", "https://example.com/cache-delete")
req1 := httptest.NewRequest("GET", "/api/posts", nil) firstRequest := httptest.NewRequest("GET", "/api/posts", nil)
rec1 := httptest.NewRecorder() firstRecorder := httptest.NewRecorder()
router.ServeHTTP(rec1, req1) router.ServeHTTP(firstRecorder, firstRequest)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
req2 := httptest.NewRequest("DELETE", "/api/posts/"+fmt.Sprintf("%d", post.ID), nil) secondRequest := httptest.NewRequest("DELETE", "/api/posts/"+fmt.Sprintf("%d", post.ID), nil)
req2.Header.Set("Authorization", "Bearer "+user.Token) secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID) secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
req2 = testutils.WithURLParams(req2, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) secondRequest = testutils.WithURLParams(secondRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec2 := httptest.NewRecorder() secondRecorder := httptest.NewRecorder()
router.ServeHTTP(rec2, req2) router.ServeHTTP(secondRecorder, secondRequest)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
@@ -145,7 +145,7 @@ func TestIntegration_Caching(t *testing.T) {
rec3 := httptest.NewRecorder() rec3 := httptest.NewRecorder()
router.ServeHTTP(rec3, req3) router.ServeHTTP(rec3, req3)
if rec1.Body.String() == rec3.Body.String() && rec1.Code == http.StatusOK && rec3.Code == http.StatusOK { if firstRecorder.Body.String() == rec3.Body.String() && firstRecorder.Code == http.StatusOK && rec3.Code == http.StatusOK {
t.Log("Cache invalidation may not be working or cache may not be enabled") t.Log("Cache invalidation may not be working or cache may not be enabled")
} }
}) })

View File

@@ -1,15 +1,14 @@
package integration package integration
import ( import (
"bytes"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"strings"
"testing" "testing"
"time"
"goyco/internal/middleware" "goyco/internal/services"
"goyco/internal/testutils" "goyco/internal/testutils"
) )
@@ -20,17 +19,8 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "logout_user", "logout@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "logout_user", "logout@example.com")
reqBody := map[string]string{} request := makePostRequest(t, ctx.Router, "/api/auth/logout", map[string]any{}, user, nil)
body, _ := json.Marshal(reqBody) assertStatus(t, request, http.StatusOK)
req := httptest.NewRequest("POST", "/api/auth/logout", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("Auth_Revoke_Token_Endpoint", func(t *testing.T) { t.Run("Auth_Revoke_Token_Endpoint", func(t *testing.T) {
@@ -42,52 +32,23 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
t.Fatalf("Failed to login: %v", err) t.Fatalf("Failed to login: %v", err)
} }
reqBody := map[string]string{ request := makePostRequest(t, ctx.Router, "/api/auth/revoke", map[string]any{"refresh_token": loginResult.RefreshToken}, user, nil)
"refresh_token": loginResult.RefreshToken, assertStatus(t, request, http.StatusOK)
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/revoke", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("Auth_Revoke_All_Tokens_Endpoint", func(t *testing.T) { t.Run("Auth_Revoke_All_Tokens_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "revoke_all_user", "revoke_all@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "revoke_all_user", "revoke_all@example.com")
reqBody := map[string]string{} request := makePostRequest(t, ctx.Router, "/api/auth/revoke-all", map[string]any{}, user, nil)
body, _ := json.Marshal(reqBody) assertStatus(t, request, http.StatusOK)
req := httptest.NewRequest("POST", "/api/auth/revoke-all", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("Auth_Resend_Verification_Endpoint", func(t *testing.T) { t.Run("Auth_Resend_Verification_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{ request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/resend-verification", map[string]any{"email": "resend@example.com"})
"email": "resend@example.com", assertStatusRange(t, request, http.StatusOK, http.StatusNotFound)
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/resend-verification", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatusRange(t, rec, http.StatusOK, http.StatusNotFound)
}) })
t.Run("Auth_Confirm_Email_Endpoint", func(t *testing.T) { t.Run("Auth_Confirm_Email_Endpoint", func(t *testing.T) {
@@ -99,109 +60,66 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
token = "test-token" token = "test-token"
} }
req := httptest.NewRequest("GET", "/api/auth/confirm?token="+url.QueryEscape(token), nil) request := makeGetRequest(t, ctx.Router, "/api/auth/confirm?token="+url.QueryEscape(token))
rec := httptest.NewRecorder() assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
ctx.Router.ServeHTTP(rec, req)
assertStatusRange(t, rec, http.StatusOK, http.StatusBadRequest)
}) })
t.Run("Auth_Update_Email_Endpoint", func(t *testing.T) { t.Run("Auth_Update_Email_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "update_email_api_user", "update_email_api@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "update_email_api_user", "update_email_api@example.com")
reqBody := map[string]string{ request := makePutRequest(t, ctx.Router, "/api/auth/email", map[string]any{"email": "newemail@example.com"}, user, nil)
"email": "newemail@example.com",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if email, ok := data["email"].(string); ok && email != "newemail@example.com" { if email, ok := data["email"].(string); ok && email != "newemail@example.com" {
t.Errorf("Expected email to be updated, got %s", email) t.Errorf("Expected email to be updated, got %s", email)
} }
} }
}
}) })
t.Run("Auth_Update_Username_Endpoint", func(t *testing.T) { t.Run("Auth_Update_Username_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "update_username_api_user", "update_username_api@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "update_username_api_user", "update_username_api@example.com")
reqBody := map[string]string{ request := makePutRequest(t, ctx.Router, "/api/auth/username", map[string]any{"username": "new_username"}, user, nil)
"username": "new_username",
}
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", "/api/auth/username", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if username, ok := data["username"].(string); ok && username != "new_username" { if username, ok := data["username"].(string); ok && username != "new_username" {
t.Errorf("Expected username to be updated, got %s", username) t.Errorf("Expected username to be updated, got %s", username)
} }
} }
}
}) })
t.Run("Users_List_Endpoint", func(t *testing.T) { t.Run("Users_List_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_list_user", "users_list@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_list_user", "users_list@example.com")
req := httptest.NewRequest("GET", "/api/users", nil) request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/users", user, nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["users"]; !exists { if _, exists := data["users"]; !exists {
t.Error("Expected users in response") t.Error("Expected users in response")
} }
} }
}
}) })
t.Run("Users_Get_By_ID_Endpoint", func(t *testing.T) { t.Run("Users_Get_By_ID_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_get_user", "users_get@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_get_user", "users_get@example.com")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d", user.User.ID), nil) request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if userData, ok := data["user"].(map[string]any); ok { if userData, ok := data["user"].(map[string]any); ok {
if id, ok := userData["id"].(float64); ok && uint(id) != user.User.ID { if id, ok := userData["id"].(float64); ok && uint(id) != user.User.ID {
t.Errorf("Expected user ID %d, got %.0f", user.User.ID, id) t.Errorf("Expected user ID %d, got %.0f", user.User.ID, id)
} }
} }
} }
}
}) })
t.Run("Users_Get_Posts_Endpoint", func(t *testing.T) { t.Run("Users_Get_Posts_Endpoint", func(t *testing.T) {
@@ -210,17 +128,10 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "User Posts Test", "https://example.com/user-posts") testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "User Posts Test", "https://example.com/user-posts")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d/posts", user.User.ID), nil) request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d/posts", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if posts, ok := data["posts"].([]any); ok { if posts, ok := data["posts"].([]any); ok {
if len(posts) == 0 { if len(posts) == 0 {
t.Error("Expected at least one post in response") t.Error("Expected at least one post in response")
@@ -229,35 +140,24 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
t.Error("Expected posts array in response") t.Error("Expected posts array in response")
} }
} }
}
}) })
t.Run("Users_Create_Endpoint", func(t *testing.T) { t.Run("Users_Create_Endpoint", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_create_admin", "users_create_admin@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "users_create_admin", "users_create_admin@example.com")
reqBody := map[string]string{ request := makePostRequest(t, ctx.Router, "/api/users", map[string]any{
"username": "created_user", "username": "created_user",
"email": "created@example.com", "email": "created@example.com",
"password": "SecurePass123!", "password": "SecurePass123!",
} }, user, nil)
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/users", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusCreated)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusCreated)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["user"]; !exists { if _, exists := data["user"]; !exists {
t.Error("Expected user in response") t.Error("Expected user in response")
} }
} }
}
}) })
t.Run("Posts_Update_Endpoint", func(t *testing.T) { t.Run("Posts_Update_Endpoint", func(t *testing.T) {
@@ -266,30 +166,19 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Update Test Post", "https://example.com/update-test") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Update Test Post", "https://example.com/update-test")
reqBody := map[string]string{ request := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title", "title": "Updated Title",
"content": "Updated content", "content": "Updated content",
} }, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if postData, ok := data["post"].(map[string]any); ok { if postData, ok := data["post"].(map[string]any); ok {
if title, ok := postData["title"].(string); ok && title != "Updated Title" { if title, ok := postData["title"].(string); ok && title != "Updated Title" {
t.Errorf("Expected title 'Updated Title', got '%s'", title) t.Errorf("Expected title 'Updated Title', got '%s'", title)
} }
} }
} }
}
}) })
t.Run("Posts_Delete_Endpoint", func(t *testing.T) { t.Run("Posts_Delete_Endpoint", func(t *testing.T) {
@@ -298,20 +187,11 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Delete Test Post", "https://example.com/delete-test") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Delete Test Post", "https://example.com/delete-test")
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil) request := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req.Header.Set("Authorization", "Bearer "+user.Token) assertStatus(t, request, http.StatusOK)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) getRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID))
assertStatus(t, getRequest, http.StatusNotFound)
assertStatus(t, rec, http.StatusOK)
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
assertStatus(t, getRec, http.StatusNotFound)
}) })
t.Run("Votes_Get_All_Endpoint", func(t *testing.T) { t.Run("Votes_Get_All_Endpoint", func(t *testing.T) {
@@ -319,28 +199,11 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "votes_get_all_user", "votes_get_all@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "votes_get_all_user", "votes_get_all@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Votes Test Post", "https://example.com/votes-test") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Votes Test Post", "https://example.com/votes-test")
makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteBody := map[string]string{"type": "up"} response := assertJSONResponse(t, request, http.StatusOK)
voteBodyBytes, _ := json.Marshal(voteBody) if data, ok := getDataFromResponse(response); ok {
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBodyBytes))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
req := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if votes, ok := data["votes"].([]any); ok { if votes, ok := data["votes"].([]any); ok {
if len(votes) == 0 { if len(votes) == 0 {
t.Error("Expected at least one vote in response") t.Error("Expected at least one vote in response")
@@ -349,7 +212,6 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
t.Error("Expected votes array in response") t.Error("Expected votes array in response")
} }
} }
}
}) })
t.Run("Votes_Remove_Endpoint", func(t *testing.T) { t.Run("Votes_Remove_Endpoint", func(t *testing.T) {
@@ -358,49 +220,461 @@ func TestIntegration_CompleteAPIEndpoints(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Remove Test", "https://example.com/vote-remove") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Remove Test", "https://example.com/vote-remove")
voteBody := map[string]string{"type": "up"} makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteBodyBytes, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBodyBytes))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil) request := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req.Header.Set("Authorization", "Bearer "+user.Token) assertStatus(t, request, http.StatusOK)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("API_Info_Endpoint", func(t *testing.T) { t.Run("API_Info_Endpoint", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api", nil) request := makeGetRequest(t, ctx.Router, "/api")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["endpoints"]; !exists { if _, exists := data["endpoints"]; !exists {
t.Error("Expected endpoints in API info") t.Error("Expected endpoints in API info")
} }
} }
}
}) })
t.Run("Swagger_Documentation_Endpoint", func(t *testing.T) { t.Run("Swagger_Documentation_Endpoint", func(t *testing.T) {
req := httptest.NewRequest("GET", "/swagger/index.html", nil) request := makeGetRequest(t, ctx.Router, "/swagger/index.html")
rec := httptest.NewRecorder() assertStatusRange(t, request, http.StatusOK, http.StatusNotFound)
})
ctx.Router.ServeHTTP(rec, req) t.Run("Search_Endpoint_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "search_edge"), uniqueTestEmail(t, "search_edge"))
assertStatusRange(t, rec, http.StatusOK, http.StatusNotFound) testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Searchable Post One", "https://example.com/one")
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Searchable Post Two", "https://example.com/two")
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Different Content", "https://example.com/three")
t.Run("Empty_Search_Results", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=nonexistentterm12345")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) != 0 {
t.Errorf("Expected empty search results, got %d posts", len(posts))
}
}
if count, ok := data["count"].(float64); ok && count != 0 {
t.Errorf("Expected count 0, got %.0f", count)
}
}
})
t.Run("Search_With_Pagination", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=Searchable&limit=1&offset=0")
response := assertJSONResponse(t, request, http.StatusOK)
var firstPostID any
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) > 1 {
t.Errorf("Expected at most 1 post with limit=1, got %d", len(posts))
}
if len(posts) > 0 {
if post, ok := posts[0].(map[string]any); ok {
firstPostID = post["id"]
}
}
}
if limit, ok := data["limit"].(float64); ok && limit != 1 {
t.Errorf("Expected limit 1 in response, got %.0f", limit)
}
if offset, ok := data["offset"].(float64); ok && offset != 0 {
t.Errorf("Expected offset 0 in response, got %.0f", offset)
}
}
secondRequest := makeGetRequest(t, ctx.Router, "/api/posts/search?q=Searchable&limit=1&offset=1")
secondResponse := assertJSONResponse(t, secondRequest, http.StatusOK)
if data, ok := getDataFromResponse(secondResponse); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) > 1 {
t.Errorf("Expected at most 1 post with limit=1 and offset=1, got %d", len(posts))
}
if len(posts) > 0 && firstPostID != nil {
if post, ok := posts[0].(map[string]any); ok {
if post["id"] == firstPostID {
t.Error("Expected different post with offset=1, got same post as offset=0")
}
}
}
}
}
})
t.Run("Search_With_Special_Characters", func(t *testing.T) {
specialQueries := []string{
"Searchable%20Post",
"Searchable'Post",
"Searchable\"Post",
"Searchable;Post",
"Searchable--Post",
}
for _, query := range specialQueries {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q="+url.QueryEscape(query))
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
}
})
t.Run("Search_Empty_Query", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) != 0 {
t.Errorf("Expected empty results for empty query, got %d posts", len(posts))
}
}
if count, ok := data["count"].(float64); ok && count != 0 {
t.Errorf("Expected count 0 for empty query, got %.0f", count)
}
}
})
t.Run("Search_With_Very_Long_Query", func(t *testing.T) {
longQuery := strings.Repeat("a", 1000)
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q="+url.QueryEscape(longQuery))
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Search_Case_Insensitive", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/search?q=SEARCHABLE")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) == 0 {
t.Error("Expected case-insensitive search to find posts")
}
}
}
})
})
t.Run("Title_Fetch_Endpoint_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
t.Run("Missing_URL_Parameter", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/title")
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Empty_URL_Parameter", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url=")
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Invalid_URL_Format", func(t *testing.T) {
invalidURLs := []string{
"not-a-url",
"://invalid",
"http://",
"https://",
}
for _, invalidURL := range invalidURLs {
ctx.Suite.TitleFetcher.SetError(services.ErrUnsupportedScheme)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(invalidURL))
assertErrorResponse(t, request, http.StatusBadRequest)
}
})
t.Run("Unsupported_URL_Schemes", func(t *testing.T) {
unsupportedSchemes := []string{
"ftp://example.com",
"file:///etc/passwd",
"javascript:alert(1)",
"data:text/html,<script>alert(1)</script>",
}
for _, schemeURL := range unsupportedSchemes {
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(schemeURL))
assertErrorResponse(t, request, http.StatusBadRequest)
}
})
t.Run("SSRF_Protection_Localhost", func(t *testing.T) {
ssrfURLs := []string{
"http://localhost",
"http://127.0.0.1",
"http://127.0.0.1:8080",
"http://[::1]",
"http://0.0.0.0",
}
for _, ssrfURL := range ssrfURLs {
ctx.Suite.TitleFetcher.SetError(services.ErrSSRFBlocked)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(ssrfURL))
assertStatusRange(t, request, http.StatusBadRequest, http.StatusBadGateway)
}
})
t.Run("SSRF_Protection_Private_IPs", func(t *testing.T) {
privateIPs := []string{
"http://192.168.1.1",
"http://10.0.0.1",
"http://172.16.0.1",
}
for _, privateIP := range privateIPs {
ctx.Suite.TitleFetcher.SetError(services.ErrSSRFBlocked)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url="+url.QueryEscape(privateIP))
assertStatusRange(t, request, http.StatusBadRequest, http.StatusBadGateway)
}
})
t.Run("Title_Fetch_Error_Handling", func(t *testing.T) {
ctx.Suite.TitleFetcher.SetError(services.ErrTitleNotFound)
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url=https://example.com/notitle")
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Valid_URL_Success", func(t *testing.T) {
ctx.Suite.TitleFetcher.SetTitle("Valid Title")
request := makeGetRequest(t, ctx.Router, "/api/posts/title?url=https://example.com/valid")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if title, ok := data["title"].(string); ok {
if title != "Valid Title" {
t.Errorf("Expected title 'Valid Title', got '%s'", title)
}
} else {
t.Error("Expected title in response data")
}
}
})
})
t.Run("Get_User_Vote_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "vote_edge"), uniqueTestEmail(t, "vote_edge"))
secondUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "vote_edge2"), uniqueTestEmail(t, "vote_edge2"))
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Edge Test Post", "https://example.com/vote-edge")
t.Run("Get_Vote_When_User_Has_Voted", func(t *testing.T) {
voteRequest := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertStatus(t, voteRequest, http.StatusOK)
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if hasVote, ok := data["has_vote"].(bool); !ok || !hasVote {
t.Error("Expected has_vote to be true when user has voted")
}
if vote, ok := data["vote"]; !ok || vote == nil {
t.Error("Expected vote object when user has voted")
}
}
})
t.Run("Get_Vote_When_User_Has_Not_Voted", func(t *testing.T) {
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if hasVote, ok := data["has_vote"].(bool); ok {
if hasVote {
t.Error("Expected has_vote to be false when user has not voted")
}
} else {
t.Error("Expected has_vote field in response")
}
}
})
t.Run("Get_Vote_Invalid_Post_ID", func(t *testing.T) {
request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/posts/999999/vote", user, map[string]string{"id": "999999"})
if request.Code != http.StatusOK && request.Code != http.StatusNotFound {
t.Errorf("Expected status 200 or 404 for invalid post ID, got %d", request.Code)
}
})
t.Run("Get_Vote_Unauthenticated", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID))
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Get_Vote_Response_Structure", func(t *testing.T) {
request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
response := assertJSONResponse(t, request, http.StatusOK)
if success, ok := response["success"].(bool); !ok || !success {
t.Error("Expected success field to be true")
}
if data, ok := getDataFromResponse(response); ok {
if _, exists := data["has_vote"]; !exists {
t.Error("Expected has_vote field in response data")
}
if _, exists := data["is_anonymous"]; !exists {
t.Error("Expected is_anonymous field in response data")
}
} else {
t.Error("Expected data field in response")
}
})
})
t.Run("Refresh_Token_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
refreshUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "refresh_edge"), uniqueTestEmail(t, "refresh_edge"))
t.Run("Refresh_With_Expired_Token", func(t *testing.T) {
loginResult, err := ctx.AuthService.Login(refreshUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
refreshToken, err := ctx.Suite.RefreshTokenRepo.GetByTokenHash(testutils.HashVerificationToken(loginResult.RefreshToken))
if err != nil {
t.Fatalf("Failed to get refresh token: %v", err)
}
refreshToken.ExpiresAt = time.Now().Add(-1 * time.Hour)
if err := ctx.Suite.DB.Model(refreshToken).Update("expires_at", refreshToken.ExpiresAt).Error; err != nil {
t.Fatalf("Failed to expire refresh token: %v", err)
}
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Refresh_With_Revoked_Token", func(t *testing.T) {
loginResult, err := ctx.AuthService.Login(refreshUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
if err := ctx.AuthService.RevokeRefreshToken(loginResult.RefreshToken); err != nil {
t.Fatalf("Failed to revoke refresh token: %v", err)
}
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
assertErrorResponse(t, request, http.StatusUnauthorized)
})
t.Run("Refresh_With_Empty_Token", func(t *testing.T) {
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": ""})
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Refresh_With_Missing_Token_Field", func(t *testing.T) {
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{})
assertErrorResponse(t, request, http.StatusBadRequest)
})
t.Run("Refresh_Token_Rotation", func(t *testing.T) {
loginResult, err := ctx.AuthService.Login(refreshUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
originalRefreshToken := loginResult.RefreshToken
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": originalRefreshToken})
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if newAccessToken, ok := data["access_token"].(string); ok {
if newAccessToken == "" {
t.Error("Expected new access token in refresh response")
}
if newRefreshToken, ok := data["refresh_token"].(string); ok {
if newRefreshToken != "" && newRefreshToken == originalRefreshToken {
t.Log("Refresh token rotation may not be implemented (same token returned)")
}
}
}
}
})
t.Run("Refresh_After_Account_Lock", func(t *testing.T) {
lockedUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "refresh_lock"), uniqueTestEmail(t, "refresh_lock"))
loginResult, err := ctx.AuthService.Login(lockedUser.User.Username, "SecurePass123!")
if err != nil {
t.Fatalf("Failed to login: %v", err)
}
lockedUser.User.Locked = true
if err := ctx.Suite.UserRepo.Update(lockedUser.User); err != nil {
t.Fatalf("Failed to lock user: %v", err)
}
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
assertStatusRange(t, request, http.StatusUnauthorized, http.StatusForbidden)
})
t.Run("Refresh_With_Invalid_Token_Format", func(t *testing.T) {
request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": "invalid-token-format-12345"})
assertErrorResponse(t, request, http.StatusUnauthorized)
})
})
t.Run("Pagination_Edge_Cases", func(t *testing.T) {
ctx.Suite.EmailSender.Reset()
paginationUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, uniqueTestUsername(t, "pagination_edge"), uniqueTestEmail(t, "pagination_edge"))
for i := 0; i < 5; i++ {
testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, paginationUser.User.ID, fmt.Sprintf("Pagination Post %d", i), fmt.Sprintf("https://example.com/pag%d", i))
}
t.Run("Negative_Limit", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?limit=-1")
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Negative_Offset", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?offset=-1")
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Very_Large_Limit", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?limit=10000")
assertStatusRange(t, request, http.StatusOK, http.StatusBadRequest)
})
t.Run("Very_Large_Offset", func(t *testing.T) {
request := makeGetRequest(t, ctx.Router, "/api/posts?offset=10000")
response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
if posts, ok := data["posts"].([]any); ok {
if len(posts) > 0 {
t.Logf("Large offset returned %d posts (may be expected)", len(posts))
}
}
}
})
t.Run("Invalid_Pagination_Parameters", func(t *testing.T) {
invalidParams := []string{
"limit=abc",
"offset=xyz",
"limit=",
"offset=",
}
for _, param := range invalidParams {
request := makeGetRequest(t, ctx.Router, "/api/posts?"+param)
assertStatus(t, request, http.StatusOK)
}
})
}) })
} }

View File

@@ -1,17 +1,12 @@
package integration package integration
import ( import (
"bytes"
"compress/gzip" "compress/gzip"
"encoding/json"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
"goyco/internal/middleware"
"goyco/internal/testutils"
) )
func TestIntegration_Compression(t *testing.T) { func TestIntegration_Compression(t *testing.T) {
@@ -19,16 +14,16 @@ func TestIntegration_Compression(t *testing.T) {
router := ctx.Router router := ctx.Router
t.Run("Response_Compression_Gzip", func(t *testing.T) { t.Run("Response_Compression_Gzip", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil) request := httptest.NewRequest("GET", "/api/posts", nil)
req.Header.Set("Accept-Encoding", "gzip") request.Header.Set("Accept-Encoding", "gzip")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
contentEncoding := rec.Header().Get("Content-Encoding") contentEncoding := recorder.Header().Get("Content-Encoding")
if contentEncoding != "" && strings.Contains(contentEncoding, "gzip") { if contentEncoding != "" && strings.Contains(contentEncoding, "gzip") {
assertHeaderContains(t, rec, "Content-Encoding", "gzip") assertHeaderContains(t, recorder, "Content-Encoding", "gzip")
reader, err := gzip.NewReader(rec.Body) reader, err := gzip.NewReader(recorder.Body)
if err != nil { if err != nil {
t.Fatalf("Failed to create gzip reader: %v", err) t.Fatalf("Failed to create gzip reader: %v", err)
} }
@@ -48,14 +43,14 @@ func TestIntegration_Compression(t *testing.T) {
}) })
t.Run("Compression_Headers_Present", func(t *testing.T) { t.Run("Compression_Headers_Present", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil) request := httptest.NewRequest("GET", "/api/posts", nil)
req.Header.Set("Accept-Encoding", "gzip, deflate") request.Header.Set("Accept-Encoding", "gzip, deflate")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Header().Get("Vary") != "" { if recorder.Header().Get("Vary") != "" {
assertHeaderContains(t, rec, "Vary", "Accept-Encoding") assertHeaderContains(t, recorder, "Vary", "Accept-Encoding")
} else { } else {
t.Log("Vary header may not always be present") t.Log("Vary header may not always be present")
} }
@@ -67,25 +62,19 @@ func TestIntegration_StaticFiles(t *testing.T) {
router := ctx.Router router := ctx.Router
t.Run("Robots_Txt_Served", func(t *testing.T) { t.Run("Robots_Txt_Served", func(t *testing.T) {
req := httptest.NewRequest("GET", "/robots.txt", nil) request := makeGetRequest(t, router, "/robots.txt")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK) if !strings.Contains(request.Body.String(), "User-agent") {
if !strings.Contains(rec.Body.String(), "User-agent") {
t.Error("Expected robots.txt content") t.Error("Expected robots.txt content")
} }
}) })
t.Run("Static_Files_Security_Headers", func(t *testing.T) { t.Run("Static_Files_Security_Headers", func(t *testing.T) {
req := httptest.NewRequest("GET", "/robots.txt", nil) request := makeGetRequest(t, router, "/robots.txt")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) if request.Header().Get("X-Content-Type-Options") == "" {
if rec.Header().Get("X-Content-Type-Options") == "" {
t.Log("Security headers may not be applied to all static files") t.Log("Security headers may not be applied to all static files")
} }
}) })
@@ -101,32 +90,22 @@ func TestIntegration_URLMetadata(t *testing.T) {
ctx.Suite.TitleFetcher.SetTitle("Fetched Title") ctx.Suite.TitleFetcher.SetTitle("Fetched Title")
postBody := map[string]string{ postBody := map[string]any{
"title": "Test Post", "title": "Test Post",
"url": "https://example.com/metadata-test", "url": "https://example.com/metadata-test",
"content": "Test content", "content": "Test content",
} }
body, _ := json.Marshal(postBody) request := makePostRequest(t, router, "/api/posts", postBody, user, nil)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusCreated)
assertStatus(t, rec, http.StatusCreated)
}) })
t.Run("URL_Metadata_Endpoint", func(t *testing.T) { t.Run("URL_Metadata_Endpoint", func(t *testing.T) {
ctx.Suite.TitleFetcher.SetTitle("Endpoint Title") ctx.Suite.TitleFetcher.SetTitle("Endpoint Title")
req := httptest.NewRequest("GET", "/api/posts/title?url=https://example.com/test", nil) request := makeGetRequest(t, router, "/api/posts/title?url=https://example.com/test")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil { if response != nil {
if data, ok := response["data"].(map[string]any); ok { if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["title"]; !exists { if _, exists := data["title"]; !exists {

View File

@@ -1,14 +1,10 @@
package integration package integration
import ( import (
"bytes"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"goyco/internal/middleware"
"goyco/internal/testutils" "goyco/internal/testutils"
) )
@@ -22,33 +18,19 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Owner Post", "https://example.com/owner") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Owner Post", "https://example.com/owner")
updateBody := map[string]string{ request := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title", "title": "Updated Title",
"content": "Updated content", "content": "Updated content",
} }, otherUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
body, _ := json.Marshal(updateBody)
req := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body)) assertErrorResponse(t, request, http.StatusForbidden)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) request = makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title",
"content": "Updated content",
}, owner, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
assertErrorResponse(t, rec, http.StatusForbidden) assertStatus(t, request, http.StatusOK)
req = httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+owner.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, owner.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("Post_Delete_Authorization", func(t *testing.T) { t.Run("Post_Delete_Authorization", func(t *testing.T) {
@@ -58,47 +40,27 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Delete Post", "https://example.com/delete") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Delete Post", "https://example.com/delete")
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil) request := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), otherUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusForbidden)
assertErrorResponse(t, rec, http.StatusForbidden) request = makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), owner, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req = httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil) assertStatus(t, request, http.StatusOK)
req.Header.Set("Authorization", "Bearer "+owner.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, owner.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("User_Profile_Access_Authorization", func(t *testing.T) { t.Run("User_Profile_Access_Authorization", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user1 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user1", "profile_user1@example.com") firstUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user1", "profile_user1@example.com")
user2 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user2", "profile_user2@example.com") secondUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "profile_user2", "profile_user2@example.com")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d", user1.User.ID), nil) request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d", firstUser.User.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", firstUser.User.ID)})
req.Header.Set("Authorization", "Bearer "+user2.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user2.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user1.User.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
if data, ok := getDataFromResponse(response); ok {
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil {
if data, ok := response["data"].(map[string]any); ok {
if userData, ok := data["user"].(map[string]any); ok { if userData, ok := data["user"].(map[string]any); ok {
if id, ok := userData["id"].(float64); ok && uint(id) != user1.User.ID { if id, ok := userData["id"].(float64); ok && uint(id) != firstUser.User.ID {
t.Errorf("Expected user ID %d, got %.0f", user1.User.ID, id) t.Errorf("Expected user ID %d, got %.0f", firstUser.User.ID, id)
}
} }
} }
} }
@@ -109,24 +71,13 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "settings_auth_user", "settings_auth@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "settings_auth_user", "settings_auth@example.com")
otherUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "settings_auth_other", "settings_auth_other@example.com") otherUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "settings_auth_other", "settings_auth_other@example.com")
updateBody := map[string]string{ request := makePutRequest(t, ctx.Router, "/api/auth/email", map[string]any{"email": "newemail@example.com"}, otherUser, nil)
"email": "newemail@example.com",
}
body, _ := json.Marshal(updateBody)
req := httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body)) response := assertJSONResponse(t, request, http.StatusOK)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response == nil { if response == nil {
return return
} }
if data, ok := response["data"].(map[string]any); ok { if data, ok := getDataFromResponse(response); ok {
if userData, ok := data["user"].(map[string]any); ok { if userData, ok := data["user"].(map[string]any); ok {
if email, ok := userData["email"].(string); ok && email == "newemail@example.com" { if email, ok := userData["email"].(string); ok && email == "newemail@example.com" {
if id, ok := userData["id"].(float64); ok && uint(id) != otherUser.User.ID { if id, ok := userData["id"].(float64); ok && uint(id) != otherUser.User.ID {
@@ -136,20 +87,9 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
} }
} }
updateBody2 := map[string]string{ request = makePutRequest(t, ctx.Router, "/api/auth/email", map[string]any{"email": "anothernewemail@example.com"}, user, nil)
"email": "anothernewemail@example.com",
}
body2, _ := json.Marshal(updateBody2)
req = httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body2)) assertStatus(t, request, http.StatusOK)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("Vote_Authorization", func(t *testing.T) { t.Run("Vote_Authorization", func(t *testing.T) {
@@ -159,73 +99,42 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, postOwner.User.ID, "Vote Auth Post", "https://example.com/vote-auth") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, postOwner.User.ID, "Vote Auth Post", "https://example.com/vote-auth")
voteBody := map[string]string{"type": "up"} request := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
body, _ := json.Marshal(voteBody)
req := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body)) assertStatus(t, request, http.StatusOK)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) request = makePostRequestWithJSON(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"})
assertStatus(t, rec, http.StatusOK) assertErrorResponse(t, request, http.StatusUnauthorized)
req = httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
}) })
t.Run("Protected_Endpoint_Without_Auth", func(t *testing.T) { t.Run("Protected_Endpoint_Without_Auth", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}"))) request := makePostRequestWithJSON(t, ctx.Router, "/api/posts", map[string]any{})
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusUnauthorized)
assertErrorResponse(t, rec, http.StatusUnauthorized)
}) })
t.Run("Protected_Endpoint_With_Invalid_Token", func(t *testing.T) { t.Run("Protected_Endpoint_With_Invalid_Token", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}"))) request := makeRequest(t, ctx.Router, "POST", "/api/posts", []byte("{}"), map[string]string{"Content-Type": "application/json", "Authorization": "Bearer invalid-token"})
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer invalid-token")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusUnauthorized)
assertErrorResponse(t, rec, http.StatusUnauthorized)
}) })
t.Run("User_List_Authorization", func(t *testing.T) { t.Run("User_List_Authorization", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "list_auth_user", "list_auth@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "list_auth_user", "list_auth@example.com")
req := httptest.NewRequest("GET", "/api/users", nil) request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/users", user, nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK) request = makeGetRequest(t, ctx.Router, "/api/users")
req = httptest.NewRequest("GET", "/api/users", nil) assertErrorResponse(t, request, http.StatusUnauthorized)
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
}) })
t.Run("Refresh_Token_Authorization", func(t *testing.T) { t.Run("Refresh_Token_Authorization", func(t *testing.T) {
@@ -237,18 +146,9 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
t.Fatalf("Failed to login: %v", err) t.Fatalf("Failed to login: %v", err)
} }
refreshBody := map[string]string{ request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": loginResult.RefreshToken})
"refresh_token": loginResult.RefreshToken,
}
body, _ := json.Marshal(refreshBody)
req := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body)) response := assertJSONResponse(t, request, http.StatusOK)
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
response := assertJSONResponse(t, rec, http.StatusOK)
if response == nil { if response == nil {
return return
} }
@@ -260,17 +160,8 @@ func TestIntegration_CrossComponentAuthorization(t *testing.T) {
t.Error("Expected data field in refresh response") t.Error("Expected data field in refresh response")
} }
refreshBody = map[string]string{ request = makePostRequestWithJSON(t, ctx.Router, "/api/auth/refresh", map[string]any{"refresh_token": "invalid-refresh-token"})
"refresh_token": "invalid-refresh-token",
}
body, _ = json.Marshal(refreshBody)
req = httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body)) assertErrorResponse(t, request, http.StatusUnauthorized)
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertErrorResponse(t, rec, http.StatusUnauthorized)
}) })
} }

View File

@@ -14,166 +14,137 @@ func TestIntegration_CSRF_Protection(t *testing.T) {
ctx := setupPageHandlerTestContext(t) ctx := setupPageHandlerTestContext(t)
router := ctx.Router router := ctx.Router
t.Run("CSRF_Blocks_Form_Without_Token", func(t *testing.T) { getCSRFToken := func(t *testing.T, path string, cookies ...*http.Cookie) *http.Cookie {
reqBody := url.Values{} t.Helper()
reqBody.Set("username", "testuser")
reqBody.Set("email", "test@example.com")
reqBody.Set("password", "SecurePass123!")
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("GET", path, nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") for _, c := range cookies {
rec := httptest.NewRecorder() request.AddCookie(c)
router.ServeHTTP(rec, req)
if rec.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", rec.Code, rec.Body.String())
} }
if !strings.Contains(rec.Body.String(), "Invalid CSRF token") { recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
for _, cookie := range recorder.Result().Cookies() {
if cookie.Name == "csrf_token" {
return cookie
}
}
t.Fatalf("Expected CSRF cookie to be set for %s", path)
return nil
}
t.Run("CSRF_Blocks_Form_Without_Token", func(t *testing.T) {
requestBody := url.Values{}
requestBody.Set("username", "testuser")
requestBody.Set("email", "test@example.com")
requestBody.Set("password", "SecurePass123!")
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", recorder.Code, recorder.Body.String())
}
if !strings.Contains(recorder.Body.String(), "Invalid CSRF token") {
t.Error("Expected CSRF error message") t.Error("Expected CSRF error message")
} }
}) })
t.Run("CSRF_Allows_Form_With_Valid_Token", func(t *testing.T) { t.Run("CSRF_Allows_Form_With_Valid_Token", func(t *testing.T) {
getReq := httptest.NewRequest("GET", "/register", nil) csrfCookie := getCSRFToken(t, "/register")
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
cookies := getRec.Result().Cookies() requestBody := url.Values{}
var csrfCookie *http.Cookie requestBody.Set("username", "csrf_user")
for _, cookie := range cookies { requestBody.Set("email", "csrf@example.com")
if cookie.Name == "csrf_token" { requestBody.Set("password", "SecurePass123!")
csrfCookie = cookie requestBody.Set("csrf_token", csrfCookie.Value)
break
}
}
if csrfCookie == nil { request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
t.Fatal("Expected CSRF cookie to be set") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
} request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
csrfToken := csrfCookie.Value router.ServeHTTP(recorder, request)
reqBody := url.Values{} if recorder.Code == http.StatusForbidden {
reqBody.Set("username", "csrf_user")
reqBody.Set("email", "csrf@example.com")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code == http.StatusForbidden {
t.Error("Expected form submission with valid CSRF token to succeed") t.Error("Expected form submission with valid CSRF token to succeed")
} }
}) })
t.Run("CSRF_Allows_API_Requests", func(t *testing.T) { t.Run("CSRF_Allows_API_Requests", func(t *testing.T) {
reqBody := map[string]string{ requestBody := map[string]string{
"username": "api_user", "username": "api_user",
"email": "api@example.com", "email": "api@example.com",
"password": "SecurePass123!", "password": "SecurePass123!",
} }
body, _ := json.Marshal(reqBody) body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body)) request := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Code == http.StatusForbidden { if recorder.Code == http.StatusForbidden {
t.Error("Expected API requests to bypass CSRF protection") t.Error("Expected API requests to bypass CSRF protection")
} }
}) })
t.Run("CSRF_Blocks_Mismatched_Token", func(t *testing.T) { t.Run("CSRF_Blocks_Mismatched_Token", func(t *testing.T) {
getReq := httptest.NewRequest("GET", "/register", nil) csrfCookie := getCSRFToken(t, "/register")
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
cookies := getRec.Result().Cookies() requestBody := url.Values{}
var csrfCookie *http.Cookie requestBody.Set("username", "mismatch_user")
for _, cookie := range cookies { requestBody.Set("email", "mismatch@example.com")
if cookie.Name == "csrf_token" { requestBody.Set("password", "SecurePass123!")
csrfCookie = cookie requestBody.Set("csrf_token", "wrong-token")
break
request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", recorder.Code, recorder.Body.String())
} }
} if !strings.Contains(recorder.Body.String(), "Invalid CSRF token") {
if csrfCookie == nil {
t.Fatal("Expected CSRF cookie to be set")
}
reqBody := url.Values{}
reqBody.Set("username", "mismatch_user")
reqBody.Set("email", "mismatch@example.com")
reqBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", "wrong-token")
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusForbidden {
t.Errorf("Expected status 403, got %d. Body: %s", rec.Code, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "Invalid CSRF token") {
t.Error("Expected CSRF error message") t.Error("Expected CSRF error message")
} }
}) })
t.Run("CSRF_Allows_GET_Requests", func(t *testing.T) { t.Run("CSRF_Allows_GET_Requests", func(t *testing.T) {
req := httptest.NewRequest("GET", "/register", nil) request := httptest.NewRequest("GET", "/register", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Code == http.StatusForbidden { if recorder.Code == http.StatusForbidden {
t.Error("Expected GET requests to bypass CSRF protection") t.Error("Expected GET requests to bypass CSRF protection")
} }
}) })
t.Run("CSRF_Token_In_Header", func(t *testing.T) { t.Run("CSRF_Token_In_Header", func(t *testing.T) {
getReq := httptest.NewRequest("GET", "/register", nil) csrfCookie := getCSRFToken(t, "/register")
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
cookies := getRec.Result().Cookies() requestBody := url.Values{}
var csrfCookie *http.Cookie requestBody.Set("username", "header_user")
for _, cookie := range cookies { requestBody.Set("email", "header@example.com")
if cookie.Name == "csrf_token" { requestBody.Set("password", "SecurePass123!")
csrfCookie = cookie
break
}
}
if csrfCookie == nil { request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
t.Fatal("Expected CSRF cookie to be set") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
} request.Header.Set("X-CSRF-Token", csrfCookie.Value)
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
csrfToken := csrfCookie.Value router.ServeHTTP(recorder, request)
reqBody := url.Values{} if recorder.Code == http.StatusForbidden {
reqBody.Set("username", "header_user")
reqBody.Set("email", "header@example.com")
reqBody.Set("password", "SecurePass123!")
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("X-CSRF-Token", csrfToken)
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code == http.StatusForbidden {
t.Error("Expected CSRF token in header to be accepted") t.Error("Expected CSRF token in header to be accepted")
} }
}) })
@@ -182,41 +153,24 @@ func TestIntegration_CSRF_Protection(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "csrf_form_user", "csrf_form@example.com") user := createUserWithCleanup(t, ctx, "csrf_form_user", "csrf_form@example.com")
getReq := httptest.NewRequest("GET", "/posts/new", nil) authCookie := &http.Cookie{Name: "auth_token", Value: user.Token}
getReq.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) csrfCookie := getCSRFToken(t, "/posts/new", authCookie)
getRec := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq)
cookies := getRec.Result().Cookies() requestBody := url.Values{}
var csrfCookie *http.Cookie requestBody.Set("title", "CSRF Test Post")
for _, cookie := range cookies { requestBody.Set("url", "https://example.com/csrf-test")
if cookie.Name == "csrf_token" { requestBody.Set("content", "Test content")
csrfCookie = cookie requestBody.Set("csrf_token", csrfCookie.Value)
break
}
}
if csrfCookie == nil { request := httptest.NewRequest("POST", "/posts", strings.NewReader(requestBody.Encode()))
t.Fatal("Expected CSRF cookie to be set") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
} request.AddCookie(authCookie)
request.AddCookie(csrfCookie)
recorder := httptest.NewRecorder()
csrfToken := csrfCookie.Value router.ServeHTTP(recorder, request)
reqBody := url.Values{} if recorder.Code == http.StatusForbidden {
reqBody.Set("title", "CSRF Test Post")
reqBody.Set("url", "https://example.com/csrf-test")
reqBody.Set("content", "Test content")
reqBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/posts", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(csrfCookie)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code == http.StatusForbidden {
t.Error("Expected post creation with valid CSRF token to succeed") t.Error("Expected post creation with valid CSRF token to succeed")
} }
}) })

View File

@@ -19,27 +19,18 @@ func TestIntegration_DataConsistency(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "consistency_user", "consistency@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "consistency_user", "consistency@example.com")
postBody := map[string]string{ request := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "Consistency Test Post", "title": "Consistency Test Post",
"url": "https://example.com/consistency", "url": "https://example.com/consistency",
"content": "Test content", "content": "Test content",
} }, user, nil)
body, _ := json.Marshal(postBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) createResponse := assertJSONResponse(t, request, http.StatusCreated)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
createResponse := assertJSONResponse(t, rec, http.StatusCreated)
if createResponse == nil { if createResponse == nil {
return return
} }
postData, ok := createResponse["data"].(map[string]any) postData, ok := getDataFromResponse(createResponse)
if !ok { if !ok {
t.Fatal("Response missing data") t.Fatal("Response missing data")
} }
@@ -53,16 +44,14 @@ func TestIntegration_DataConsistency(t *testing.T) {
createdURL := postData["url"] createdURL := postData["url"]
createdContent := postData["content"] createdContent := postData["content"]
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%.0f", postID), nil) getRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%.0f", postID))
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
getResponse := assertJSONResponse(t, getRec, http.StatusOK) getResponse := assertJSONResponse(t, getRequest, http.StatusOK)
if getResponse == nil { if getResponse == nil {
return return
} }
getPostData, ok := getResponse["data"].(map[string]any) getPostData, ok := getDataFromResponse(getResponse)
if !ok { if !ok {
t.Fatal("Get response missing data") t.Fatal("Get response missing data")
} }
@@ -96,32 +85,17 @@ func TestIntegration_DataConsistency(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Consistency Post", "https://example.com/vote-consistency") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Consistency Post", "https://example.com/vote-consistency")
voteBody := map[string]string{"type": "up"} voteRequest := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
body, _ := json.Marshal(voteBody) assertStatus(t, voteRequest, http.StatusOK)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body)) getVotesRequest := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
assertStatus(t, voteRec, http.StatusOK) votesResponse := assertJSONResponse(t, getVotesRequest, http.StatusOK)
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK)
if votesResponse == nil { if votesResponse == nil {
return return
} }
votesData, ok := votesResponse["data"].(map[string]any) votesData, ok := getDataFromResponse(votesResponse)
if !ok { if !ok {
t.Fatal("Votes response missing data") t.Fatal("Votes response missing data")
} }
@@ -172,32 +146,21 @@ func TestIntegration_DataConsistency(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Original Title", "https://example.com/original") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Original Title", "https://example.com/original")
updateBody := map[string]string{ updateRequest := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{
"title": "Updated Title", "title": "Updated Title",
"content": "Updated content", "content": "Updated content",
} }, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
body, _ := json.Marshal(updateBody)
updateReq := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body)) assertStatus(t, updateRequest, http.StatusOK)
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+user.Token)
updateReq = testutils.WithUserContext(updateReq, middleware.UserIDKey, user.User.ID)
updateReq = testutils.WithURLParams(updateReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
updateRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(updateRec, updateReq)
assertStatus(t, updateRec, http.StatusOK) getRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID))
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil) getResponse := assertJSONResponse(t, getRequest, http.StatusOK)
getRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq)
getResponse := assertJSONResponse(t, getRec, http.StatusOK)
if getResponse == nil { if getResponse == nil {
return return
} }
getPostData, ok := getResponse["data"].(map[string]any) getPostData, ok := getDataFromResponse(getResponse)
if !ok { if !ok {
t.Fatal("Get response missing data") t.Fatal("Get response missing data")
} }
@@ -215,18 +178,12 @@ func TestIntegration_DataConsistency(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "user_posts_consistency", "user_posts_consistency@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "user_posts_consistency", "user_posts_consistency@example.com")
post1 := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 1", "https://example.com/post1") firstPost := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 1", "https://example.com/post1")
post2 := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 2", "https://example.com/post2") secondPost := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Post 2", "https://example.com/post2")
req := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d/posts", user.User.ID), nil) request := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d/posts", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
response := assertJSONResponse(t, rec, http.StatusOK)
if response == nil { if response == nil {
return return
} }
@@ -245,26 +202,26 @@ func TestIntegration_DataConsistency(t *testing.T) {
t.Errorf("Expected at least 2 posts, got %d", len(posts)) t.Errorf("Expected at least 2 posts, got %d", len(posts))
} }
foundPost1 := false foundFirstPost := false
foundPost2 := false foundSecondPost := false
for _, post := range posts { for _, post := range posts {
if postMap, ok := post.(map[string]any); ok { if postMap, ok := post.(map[string]any); ok {
if postID, ok := postMap["id"].(float64); ok { if postID, ok := postMap["id"].(float64); ok {
if uint(postID) == post1.ID { if uint(postID) == firstPost.ID {
foundPost1 = true foundFirstPost = true
} }
if uint(postID) == post2.ID { if uint(postID) == secondPost.ID {
foundPost2 = true foundSecondPost = true
} }
} }
} }
} }
if !foundPost1 { if !foundFirstPost {
t.Error("Post 1 not found in user posts") t.Error("Post 1 not found in user posts")
} }
if !foundPost2 { if !foundSecondPost {
t.Error("Post 2 not found in user posts") t.Error("Post 2 not found in user posts")
} }
}) })
@@ -275,20 +232,20 @@ func TestIntegration_DataConsistency(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Delete Consistency Post", "https://example.com/delete-consistency") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Delete Consistency Post", "https://example.com/delete-consistency")
deleteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil) deleteRequest := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
deleteReq.Header.Set("Authorization", "Bearer "+user.Token) deleteRequest.Header.Set("Authorization", "Bearer "+user.Token)
deleteReq = testutils.WithUserContext(deleteReq, middleware.UserIDKey, user.User.ID) deleteRequest = testutils.WithUserContext(deleteRequest, middleware.UserIDKey, user.User.ID)
deleteReq = testutils.WithURLParams(deleteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) deleteRequest = testutils.WithURLParams(deleteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
deleteRec := httptest.NewRecorder() deleteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(deleteRec, deleteReq) ctx.Router.ServeHTTP(deleteRecorder, deleteRequest)
assertStatus(t, deleteRec, http.StatusOK) assertStatus(t, deleteRecorder, http.StatusOK)
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil) getRequest := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder() getRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(getRec, getReq) ctx.Router.ServeHTTP(getRecorder, getRequest)
assertStatus(t, getRec, http.StatusNotFound) assertStatus(t, getRecorder, http.StatusNotFound)
}) })
t.Run("Vote_Removal_Consistency", func(t *testing.T) { t.Run("Vote_Removal_Consistency", func(t *testing.T) {
@@ -300,33 +257,33 @@ func TestIntegration_DataConsistency(t *testing.T) {
voteBody := map[string]string{"type": "up"} voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody) body, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body)) voteRequest := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteReq.Header.Set("Content-Type", "application/json") voteRequest.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token) voteRequest.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID) voteRequest = testutils.WithUserContext(voteRequest, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) voteRequest = testutils.WithURLParams(voteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder() voteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq) ctx.Router.ServeHTTP(voteRecorder, voteRequest)
assertStatus(t, voteRec, http.StatusOK) assertStatus(t, voteRecorder, http.StatusOK)
removeVoteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil) removeVoteRequest := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
removeVoteReq.Header.Set("Authorization", "Bearer "+user.Token) removeVoteRequest.Header.Set("Authorization", "Bearer "+user.Token)
removeVoteReq = testutils.WithUserContext(removeVoteReq, middleware.UserIDKey, user.User.ID) removeVoteRequest = testutils.WithUserContext(removeVoteRequest, middleware.UserIDKey, user.User.ID)
removeVoteReq = testutils.WithURLParams(removeVoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) removeVoteRequest = testutils.WithURLParams(removeVoteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
removeVoteRec := httptest.NewRecorder() removeVoteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(removeVoteRec, removeVoteReq) ctx.Router.ServeHTTP(removeVoteRecorder, removeVoteRequest)
assertStatus(t, removeVoteRec, http.StatusOK) assertStatus(t, removeVoteRecorder, http.StatusOK)
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil) getVotesRequest := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token) getVotesRequest.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID) getVotesRequest = testutils.WithUserContext(getVotesRequest, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) getVotesRequest = testutils.WithURLParams(getVotesRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder() getVotesRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq) ctx.Router.ServeHTTP(getVotesRecorder, getVotesRequest)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK) votesResponse := assertJSONResponse(t, getVotesRecorder, http.StatusOK)
if votesResponse == nil { if votesResponse == nil {
return return
} }

View File

@@ -22,41 +22,39 @@ func TestIntegration_EdgeCases(t *testing.T) {
expiredToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2MDAwMDAwMDB9.expired" expiredToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2MDAwMDAwMDB9.expired"
req := httptest.NewRequest("GET", "/api/auth/me", nil) request := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+expiredToken) request.Header.Set("Authorization", "Bearer "+expiredToken)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusUnauthorized) assertErrorResponse(t, recorder, http.StatusUnauthorized)
}) })
t.Run("Concurrent_Vote_Operations", func(t *testing.T) { t.Run("Concurrent_Vote_Operations", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user1 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "vote_user1", "vote1@example.com") firstUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "vote_user1", "vote1@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user1.User.ID, "Concurrent Vote Post", "https://example.com/concurrent") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, firstUser.User.ID, "Concurrent Vote Post", "https://example.com/concurrent")
var wg sync.WaitGroup var wg sync.WaitGroup
errors := make(chan error, 10) errors := make(chan error, 10)
for i := 0; i < 5; i++ { for range 5 {
wg.Add(1) wg.Go(func() {
go func() {
defer wg.Done()
voteBody := map[string]string{"type": "up"} voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody) body, _ := json.Marshal(voteBody)
req := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body)) request := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user1.Token) request.Header.Set("Authorization", "Bearer "+firstUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user1.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, firstUser.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) ctx.Router.ServeHTTP(recorder, request)
if rec.Code != http.StatusOK { if recorder.Code != http.StatusOK {
errors <- fmt.Errorf("unexpected status: %d", rec.Code) errors <- fmt.Errorf("unexpected status: %d", recorder.Code)
} }
}() })
} }
wg.Wait() wg.Wait()
@@ -72,8 +70,8 @@ func TestIntegration_EdgeCases(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "large_user", "large@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "large_user", "large@example.com")
largeContent := make([]byte, 10001) largeContent := make([]byte, 10001)
for i := range largeContent { for idx := range largeContent {
largeContent[i] = 'a' largeContent[idx] = 'a'
} }
postBody := map[string]string{ postBody := map[string]string{
@@ -82,36 +80,36 @@ func TestIntegration_EdgeCases(t *testing.T) {
"content": string(largeContent), "content": string(largeContent),
} }
body, _ := json.Marshal(postBody) body, _ := json.Marshal(postBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusBadRequest) assertErrorResponse(t, recorder, http.StatusBadRequest)
smallContent := make([]byte, 1000) smallContent := make([]byte, 1000)
for i := range smallContent { for idx := range smallContent {
smallContent[i] = 'a' smallContent[idx] = 'a'
} }
postBody2 := map[string]string{ secondPostBody := map[string]string{
"title": "Small Post", "title": "Small Post",
"url": "https://example.com/small", "url": "https://example.com/small",
"content": string(smallContent), "content": string(smallContent),
} }
body2, _ := json.Marshal(postBody2) secondBody, _ := json.Marshal(secondPostBody)
req2 := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body2)) secondRequest := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(secondBody))
req2.Header.Set("Content-Type", "application/json") secondRequest.Header.Set("Content-Type", "application/json")
req2.Header.Set("Authorization", "Bearer "+user.Token) secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID) secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder() secondRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec2, req2) ctx.Router.ServeHTTP(secondRecorder, secondRequest)
assertStatus(t, rec2, http.StatusCreated) assertStatus(t, secondRecorder, http.StatusCreated)
}) })
t.Run("Malformed_JSON_Payloads", func(t *testing.T) { t.Run("Malformed_JSON_Payloads", func(t *testing.T) {
@@ -127,15 +125,15 @@ func TestIntegration_EdgeCases(t *testing.T) {
} }
for _, payload := range malformedPayloads { for _, payload := range malformedPayloads {
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(payload)) request := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(payload))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusBadRequest) assertErrorResponse(t, recorder, http.StatusBadRequest)
} }
}) })
@@ -148,38 +146,36 @@ func TestIntegration_EdgeCases(t *testing.T) {
voteBody := map[string]string{"type": "up"} voteBody := map[string]string{"type": "up"}
body, _ := json.Marshal(voteBody) body, _ := json.Marshal(voteBody)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body)) voteRequest := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
voteReq.Header.Set("Content-Type", "application/json") voteRequest.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token) voteRequest.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID) voteRequest = testutils.WithUserContext(voteRequest, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) voteRequest = testutils.WithURLParams(voteRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder() voteRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq) ctx.Router.ServeHTTP(voteRecorder, voteRequest)
assertStatus(t, voteRec, http.StatusOK) assertStatus(t, voteRecorder, http.StatusOK)
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < 3; i++ { for range 3 {
wg.Add(1) wg.Go(func() {
go func() { request := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
defer wg.Done() request.Header.Set("Authorization", "Bearer "+user.Token)
req := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
req.Header.Set("Authorization", "Bearer "+user.Token) request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) recorder := httptest.NewRecorder()
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) ctx.Router.ServeHTTP(recorder, request)
rec := httptest.NewRecorder() })
ctx.Router.ServeHTTP(rec, req)
}()
} }
wg.Wait() wg.Wait()
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil) getVotesRequest := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token) getVotesRequest.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID) getVotesRequest = testutils.WithUserContext(getVotesRequest, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) getVotesRequest = testutils.WithURLParams(getVotesRequest, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder() getVotesRecorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq) ctx.Router.ServeHTTP(getVotesRecorder, getVotesRequest)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK) votesResponse := assertJSONResponse(t, getVotesRecorder, http.StatusOK)
if votesResponse != nil { if votesResponse != nil {
if data, ok := votesResponse["data"].(map[string]any); ok { if data, ok := votesResponse["data"].(map[string]any); ok {
if votes, ok := data["votes"].([]any); ok { if votes, ok := data["votes"].([]any); ok {

View File

@@ -1,14 +1,10 @@
package integration package integration
import ( import (
"bytes"
"encoding/json"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"goyco/internal/database" "goyco/internal/database"
"goyco/internal/middleware"
"goyco/internal/testutils" "goyco/internal/testutils"
) )
@@ -19,19 +15,14 @@ func TestIntegration_EmailService(t *testing.T) {
t.Run("Registration_Email_Sent", func(t *testing.T) { t.Run("Registration_Email_Sent", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{ reqBody := map[string]any{
"username": "email_reg_user", "username": "email_reg_user",
"email": "email_reg@example.com", "email": "email_reg@example.com",
"password": "SecurePass123!", "password": "SecurePass123!",
} }
body, _ := json.Marshal(reqBody) request := makePostRequestWithJSON(t, router, "/api/auth/register", reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusCreated)
assertStatus(t, rec, http.StatusCreated)
token := ctx.Suite.EmailSender.VerificationToken() token := ctx.Suite.EmailSender.VerificationToken()
if token == "" { if token == "" {
@@ -52,15 +43,10 @@ func TestIntegration_EmailService(t *testing.T) {
t.Fatalf("Failed to create user: %v", err) t.Fatalf("Failed to create user: %v", err)
} }
reqBody := map[string]string{ reqBody := map[string]any{
"username_or_email": "email_reset_user", "username_or_email": "email_reset_user",
} }
body, _ := json.Marshal(reqBody) makePostRequestWithJSON(t, router, "/api/auth/forgot-password", reqBody)
req := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
token := ctx.Suite.EmailSender.PasswordResetToken() token := ctx.Suite.EmailSender.PasswordResetToken()
if token == "" { if token == "" {
@@ -72,17 +58,9 @@ func TestIntegration_EmailService(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "email_del_user", "email_del@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "email_del_user", "email_del@example.com")
reqBody := map[string]string{} request := makeDeleteRequest(t, router, "/api/auth/account", user, nil)
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
token := ctx.Suite.EmailSender.DeletionToken() token := ctx.Suite.EmailSender.DeletionToken()
if token == "" { if token == "" {
@@ -94,17 +72,10 @@ func TestIntegration_EmailService(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "email_change_user", "email_change@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "email_change_user", "email_change@example.com")
reqBody := map[string]string{ reqBody := map[string]any{
"email": "newemail@example.com", "email": "newemail@example.com",
} }
body, _ := json.Marshal(reqBody) makePutRequest(t, router, "/api/auth/email", reqBody, user, nil)
req := httptest.NewRequest("PUT", "/api/auth/email", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
token := ctx.Suite.EmailSender.VerificationToken() token := ctx.Suite.EmailSender.VerificationToken()
if token == "" { if token == "" {
@@ -115,17 +86,12 @@ func TestIntegration_EmailService(t *testing.T) {
t.Run("Email_Template_Content", func(t *testing.T) { t.Run("Email_Template_Content", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{ reqBody := map[string]any{
"username": "template_user", "username": "template_user",
"email": "template@example.com", "email": "template@example.com",
"password": "SecurePass123!", "password": "SecurePass123!",
} }
body, _ := json.Marshal(reqBody) makePostRequestWithJSON(t, router, "/api/auth/register", reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
token := ctx.Suite.EmailSender.VerificationToken() token := ctx.Suite.EmailSender.VerificationToken()
if token == "" { if token == "" {

View File

@@ -1,7 +1,6 @@
package integration package integration
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -10,7 +9,7 @@ import (
"strings" "strings"
"testing" "testing"
"goyco/internal/middleware" "goyco/internal/database"
"goyco/internal/testutils" "goyco/internal/testutils"
) )
@@ -20,46 +19,34 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Run("Complete_Registration_To_Post_Creation_Journey", func(t *testing.T) { t.Run("Complete_Registration_To_Post_Creation_Journey", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
registerBody := map[string]string{ registerRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "journey_user", "username": "journey_user",
"email": "journey@example.com", "email": "journey@example.com",
"password": "SecurePass123!", "password": "SecurePass123!",
} })
body, _ := json.Marshal(registerBody)
registerReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
registerReq.Header.Set("Content-Type", "application/json")
registerRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(registerRec, registerReq)
assertStatus(t, registerRec, http.StatusCreated) assertStatus(t, registerRequest, http.StatusCreated)
verificationToken := ctx.Suite.EmailSender.VerificationToken() verificationToken := ctx.Suite.EmailSender.VerificationToken()
if verificationToken == "" { if verificationToken == "" {
t.Fatal("Verification token not sent") t.Fatal("Verification token not sent")
} }
confirmReq := httptest.NewRequest("GET", "/api/auth/confirm?token="+url.QueryEscape(verificationToken), nil) confirmRequest := makeGetRequest(t, ctx.Router, "/api/auth/confirm?token="+url.QueryEscape(verificationToken))
confirmRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(confirmRec, confirmReq)
assertStatus(t, confirmRec, http.StatusOK) assertStatus(t, confirmRequest, http.StatusOK)
loginBody := map[string]string{ loginRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/login", map[string]any{
"username": "journey_user", "username": "journey_user",
"password": "SecurePass123!", "password": "SecurePass123!",
} })
loginBodyBytes, _ := json.Marshal(loginBody)
loginReq := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBuffer(loginBodyBytes))
loginReq.Header.Set("Content-Type", "application/json")
loginRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(loginRec, loginReq)
loginResponse := assertJSONResponse(t, loginRec, http.StatusOK) loginResponse := assertJSONResponse(t, loginRequest, http.StatusOK)
if loginResponse == nil { if loginResponse == nil {
return return
} }
data, ok := loginResponse["data"].(map[string]any) data, ok := getDataFromResponse(loginResponse)
if !ok { if !ok {
t.Fatal("Login response missing data") t.Fatal("Login response missing data")
} }
@@ -67,8 +54,8 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
var token string var token string
if accessToken, ok := data["access_token"].(string); ok && accessToken != "" { if accessToken, ok := data["access_token"].(string); ok && accessToken != "" {
token = accessToken token = accessToken
} else if tokenVal, ok := data["token"].(string); ok && tokenVal != "" { } else if tokenValue, ok := data["token"].(string); ok && tokenValue != "" {
token = tokenVal token = tokenValue
} else { } else {
t.Fatal("Login response missing access_token or token") t.Fatal("Login response missing access_token or token")
} }
@@ -90,25 +77,19 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatalf("Login response missing user.id. Data: %+v", data) t.Fatalf("Login response missing user.id. Data: %+v", data)
} }
postBody := map[string]string{ postBodyBytes, _ := json.Marshal(map[string]any{
"title": "Journey Test Post", "title": "Journey Test Post",
"url": "https://example.com/journey", "url": "https://example.com/journey",
"content": "Test content", "content": "Test content",
} })
postBodyBytes, _ := json.Marshal(postBody) postRequest := makeAuthenticatedRequest(t, ctx.Router, "POST", "/api/posts", postBodyBytes, &authenticatedUser{User: &database.User{ID: uint(userID)}, Token: token}, nil)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBodyBytes))
postReq.Header.Set("Content-Type", "application/json")
postReq.Header.Set("Authorization", "Bearer "+token)
postReq = testutils.WithUserContext(postReq, middleware.UserIDKey, uint(userID))
postRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(postRec, postReq)
postResponse := assertJSONResponse(t, postRec, http.StatusCreated) postResponse := assertJSONResponse(t, postRequest, http.StatusCreated)
if postResponse == nil { if postResponse == nil {
return return
} }
postData, ok := postResponse["data"].(map[string]any) postData, ok := getDataFromResponse(postResponse)
if !ok { if !ok {
t.Fatal("Post response missing data") t.Fatal("Post response missing data")
} }
@@ -118,56 +99,39 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatal("Post response missing id") t.Fatal("Post response missing id")
} }
getPostReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%.0f", postID), nil) getPostRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%.0f", postID))
getPostRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getPostRec, getPostReq)
assertStatus(t, getPostRec, http.StatusOK) assertStatus(t, getPostRequest, http.StatusOK)
}) })
t.Run("Complete_Password_Reset_Journey", func(t *testing.T) { t.Run("Complete_Password_Reset_Journey", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "reset_journey_user", "reset_journey@example.com") createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "reset_journey_user", "reset_journey@example.com")
resetBody := map[string]string{ resetRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/forgot-password", map[string]any{
"username_or_email": "reset_journey@example.com", "username_or_email": "reset_journey@example.com",
} })
resetBodyBytes, _ := json.Marshal(resetBody)
resetReq := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(resetBodyBytes))
resetReq.Header.Set("Content-Type", "application/json")
resetRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(resetRec, resetReq)
assertStatus(t, resetRec, http.StatusOK) assertStatus(t, resetRequest, http.StatusOK)
resetToken := ctx.Suite.EmailSender.GetLastPasswordResetToken() resetToken := ctx.Suite.EmailSender.GetLastPasswordResetToken()
if resetToken == "" { if resetToken == "" {
t.Fatal("Password reset token not sent") t.Fatal("Password reset token not sent")
} }
newPasswordBody := map[string]string{ newPasswordRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/reset-password", map[string]any{
"token": resetToken, "token": resetToken,
"new_password": "NewSecurePass123!", "new_password": "NewSecurePass123!",
} })
newPasswordBodyBytes, _ := json.Marshal(newPasswordBody)
newPasswordReq := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(newPasswordBodyBytes))
newPasswordReq.Header.Set("Content-Type", "application/json")
newPasswordRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(newPasswordRec, newPasswordReq)
assertStatus(t, newPasswordRec, http.StatusOK) assertStatus(t, newPasswordRequest, http.StatusOK)
loginBody := map[string]string{ loginRequest := makePostRequestWithJSON(t, ctx.Router, "/api/auth/login", map[string]any{
"username": "reset_journey_user", "username": "reset_journey_user",
"password": "NewSecurePass123!", "password": "NewSecurePass123!",
} })
loginBodyBytes, _ := json.Marshal(loginBody)
loginReq := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBuffer(loginBodyBytes))
loginReq.Header.Set("Content-Type", "application/json")
loginRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(loginRec, loginReq)
assertStatus(t, loginRec, http.StatusOK) assertStatus(t, loginRequest, http.StatusOK)
}) })
t.Run("Complete_Vote_And_Unvote_Journey", func(t *testing.T) { t.Run("Complete_Vote_And_Unvote_Journey", func(t *testing.T) {
@@ -176,40 +140,21 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Journey Post", "https://example.com/vote-journey") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Vote Journey Post", "https://example.com/vote-journey")
voteBody := map[string]string{"type": "up"} voteRequest := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteBodyBytes, _ := json.Marshal(voteBody) assertStatus(t, voteRequest, http.StatusOK)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBodyBytes))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(voteRec, voteReq)
assertStatus(t, voteRec, http.StatusOK) getVotesRequest := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil) votesResponse := assertJSONResponse(t, getVotesRequest, http.StatusOK)
getVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getVotesReq = testutils.WithUserContext(getVotesReq, middleware.UserIDKey, user.User.ID)
getVotesReq = testutils.WithURLParams(getVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVotesRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getVotesRec, getVotesReq)
votesResponse := assertJSONResponse(t, getVotesRec, http.StatusOK)
if votesResponse == nil { if votesResponse == nil {
return return
} }
if data, ok := votesResponse["data"].(map[string]any); ok { if data, ok := getDataFromResponse(votesResponse); ok {
if votes, ok := data["votes"].([]any); ok && len(votes) > 0 { if votes, ok := data["votes"].([]any); ok && len(votes) > 0 {
unvoteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil) unvoteRequest := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
unvoteReq.Header.Set("Authorization", "Bearer "+user.Token)
unvoteReq = testutils.WithUserContext(unvoteReq, middleware.UserIDKey, user.User.ID)
unvoteReq = testutils.WithURLParams(unvoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
unvoteRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(unvoteRec, unvoteReq)
assertStatus(t, unvoteRec, http.StatusOK) assertStatus(t, unvoteRequest, http.StatusOK)
} }
} }
}) })
@@ -221,31 +166,31 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
csrfToken := getCSRFToken(t, pageRouter, "/register") csrfToken := getCSRFToken(t, pageRouter, "/register")
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("username", "page_journey_user") requestBody.Set("username", "page_journey_user")
reqBody.Set("email", "page_journey@example.com") requestBody.Set("email", "page_journey@example.com")
reqBody.Set("password", "SecurePass123!") requestBody.Set("password", "SecurePass123!")
reqBody.Set("password_confirm", "SecurePass123!") requestBody.Set("password_confirm", "SecurePass123!")
reqBody.Set("csrf_token", csrfToken) requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
pageRouter.ServeHTTP(rec, req) pageRouter.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
verificationToken := pageCtx.Suite.EmailSender.VerificationToken() verificationToken := pageCtx.Suite.EmailSender.VerificationToken()
if verificationToken == "" { if verificationToken == "" {
t.Fatal("Verification token not sent") t.Fatal("Verification token not sent")
} }
confirmReq := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(verificationToken), nil) confirmRequest := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(verificationToken), nil)
confirmRec := httptest.NewRecorder() confirmRecorder := httptest.NewRecorder()
pageRouter.ServeHTTP(confirmRec, confirmReq) pageRouter.ServeHTTP(confirmRecorder, confirmRequest)
assertStatusRange(t, confirmRec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, confirmRecorder, http.StatusOK, http.StatusSeeOther)
loginCSRFToken := getCSRFToken(t, pageRouter, "/login") loginCSRFToken := getCSRFToken(t, pageRouter, "/login")
@@ -254,15 +199,15 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
loginBody.Set("password", "SecurePass123!") loginBody.Set("password", "SecurePass123!")
loginBody.Set("csrf_token", loginCSRFToken) loginBody.Set("csrf_token", loginCSRFToken)
loginReq := httptest.NewRequest("POST", "/login", strings.NewReader(loginBody.Encode())) loginRequest := httptest.NewRequest("POST", "/login", strings.NewReader(loginBody.Encode()))
loginReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") loginRequest.Header.Set("Content-Type", "application/x-www-form-urlencoded")
loginReq.AddCookie(&http.Cookie{Name: "csrf_token", Value: loginCSRFToken}) loginRequest.AddCookie(&http.Cookie{Name: "csrf_token", Value: loginCSRFToken})
loginRec := httptest.NewRecorder() loginRecorder := httptest.NewRecorder()
pageRouter.ServeHTTP(loginRec, loginReq) pageRouter.ServeHTTP(loginRecorder, loginRequest)
assertStatus(t, loginRec, http.StatusSeeOther) assertStatus(t, loginRecorder, http.StatusSeeOther)
loginCookies := loginRec.Result().Cookies() loginCookies := loginRecorder.Result().Cookies()
var authToken string var authToken string
for _, cookie := range loginCookies { for _, cookie := range loginCookies {
if cookie.Name == "auth_token" { if cookie.Name == "auth_token" {
@@ -275,37 +220,31 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatal("Auth token not set after login") t.Fatal("Auth token not set after login")
} }
homeReq := httptest.NewRequest("GET", "/", nil) homeRequest := httptest.NewRequest("GET", "/", nil)
homeReq.AddCookie(&http.Cookie{Name: "auth_token", Value: authToken}) homeRequest.AddCookie(&http.Cookie{Name: "auth_token", Value: authToken})
homeRec := httptest.NewRecorder() homeRecorder := httptest.NewRecorder()
pageRouter.ServeHTTP(homeRec, homeReq) pageRouter.ServeHTTP(homeRecorder, homeRequest)
assertStatus(t, homeRec, http.StatusOK) assertStatus(t, homeRecorder, http.StatusOK)
}) })
t.Run("Complete_Post_Creation_And_Update_Journey", func(t *testing.T) { t.Run("Complete_Post_Creation_And_Update_Journey", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "post_update_journey_user", "post_update_journey@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "post_update_journey_user", "post_update_journey@example.com")
postBody := map[string]string{ postBodyBytes, _ := json.Marshal(map[string]any{
"title": "Original Title", "title": "Original Title",
"url": "https://example.com/original", "url": "https://example.com/original",
"content": "Original content", "content": "Original content",
} })
postBodyBytes, _ := json.Marshal(postBody) postRequest := makeAuthenticatedRequest(t, ctx.Router, "POST", "/api/posts", postBodyBytes, user, nil)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBodyBytes))
postReq.Header.Set("Content-Type", "application/json")
postReq.Header.Set("Authorization", "Bearer "+user.Token)
postReq = testutils.WithUserContext(postReq, middleware.UserIDKey, user.User.ID)
postRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(postRec, postReq)
postResponse := assertJSONResponse(t, postRec, http.StatusCreated) postResponse := assertJSONResponse(t, postRequest, http.StatusCreated)
if postResponse == nil { if postResponse == nil {
return return
} }
postData, ok := postResponse["data"].(map[string]any) postData, ok := getDataFromResponse(postResponse)
if !ok { if !ok {
t.Fatal("Post response missing data") t.Fatal("Post response missing data")
} }
@@ -315,34 +254,25 @@ func TestIntegration_EndToEndUserJourneys(t *testing.T) {
t.Fatal("Post response missing id") t.Fatal("Post response missing id")
} }
updateBody := map[string]string{ updateBodyBytes, _ := json.Marshal(map[string]any{
"title": "Updated Title", "title": "Updated Title",
"content": "Updated content", "content": "Updated content",
} })
updateBodyBytes, _ := json.Marshal(updateBody) updateRequest := makeAuthenticatedRequest(t, ctx.Router, "PUT", fmt.Sprintf("/api/posts/%.0f", postID), updateBodyBytes, user, map[string]string{"id": fmt.Sprintf("%.0f", postID)})
updateReq := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%.0f", postID), bytes.NewBuffer(updateBodyBytes))
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+user.Token)
updateReq = testutils.WithUserContext(updateReq, middleware.UserIDKey, user.User.ID)
updateReq = testutils.WithURLParams(updateReq, map[string]string{"id": fmt.Sprintf("%.0f", postID)})
updateRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(updateRec, updateReq)
updateResponse := assertJSONResponse(t, updateRec, http.StatusOK) updateResponse := assertJSONResponse(t, updateRequest, http.StatusOK)
if updateResponse == nil { if updateResponse == nil {
return return
} }
getPostReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%.0f", postID), nil) getPostRequest := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%.0f", postID))
getPostRec := httptest.NewRecorder()
ctx.Router.ServeHTTP(getPostRec, getPostReq)
getPostResponse := assertJSONResponse(t, getPostRec, http.StatusOK) getPostResponse := assertJSONResponse(t, getPostRequest, http.StatusOK)
if getPostResponse == nil { if getPostResponse == nil {
return return
} }
if data, ok := getPostResponse["data"].(map[string]any); ok { if data, ok := getDataFromResponse(getPostResponse); ok {
if post, ok := data["post"].(map[string]any); ok { if post, ok := data["post"].(map[string]any); ok {
if title, ok := post["title"].(string); ok && title != "Updated Title" { if title, ok := post["title"].(string); ok && title != "Updated Title" {
t.Errorf("Post title not updated: expected 'Updated Title', got '%s'", title) t.Errorf("Post title not updated: expected 'Updated Title', got '%s'", title)

View File

@@ -2,7 +2,6 @@ package integration
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -19,58 +18,46 @@ func TestIntegration_ErrorPropagation(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "json_error_user", "json_error@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "json_error_user", "json_error@example.com")
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("invalid json{"))) request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("invalid json{")))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusBadRequest) assertErrorResponse(t, recorder, http.StatusBadRequest)
}) })
t.Run("Validation_Error_Propagation", func(t *testing.T) { t.Run("Validation_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{ reqBody := map[string]any{
"username": "", "username": "",
"email": "invalid-email", "email": "invalid-email",
"password": "weak", "password": "weak",
} }
body, _ := json.Marshal(reqBody) request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusBadRequest)
assertErrorResponse(t, rec, http.StatusBadRequest)
}) })
t.Run("Database_Error_Propagation", func(t *testing.T) { t.Run("Database_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "db_error_user", "db_error@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "db_error_user", "db_error@example.com")
reqBody := map[string]string{ reqBody := map[string]any{
"title": "Test Post", "title": "Test Post",
"url": "https://example.com/test", "url": "https://example.com/test",
"content": "Test content", "content": "Test content",
} }
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) request := makePostRequest(t, ctx.Router, "/api/posts", reqBody, user, nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) if request.Code == http.StatusInternalServerError {
assertErrorResponse(t, request, http.StatusInternalServerError)
if rec.Code == http.StatusInternalServerError {
assertErrorResponse(t, rec, http.StatusInternalServerError)
} else { } else {
assertStatus(t, rec, http.StatusCreated) assertStatus(t, request, http.StatusCreated)
} }
}) })
@@ -78,34 +65,23 @@ func TestIntegration_ErrorPropagation(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "notfound_error_user", "notfound_error@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "notfound_error_user", "notfound_error@example.com")
req := httptest.NewRequest("GET", "/api/posts/999999", nil) request := makeAuthenticatedGetRequest(t, ctx.Router, "/api/posts/999999", user, map[string]string{"id": "999999"})
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": "999999"})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusNotFound)
assertErrorResponse(t, rec, http.StatusNotFound)
}) })
t.Run("Unauthorized_Error_Propagation", func(t *testing.T) { t.Run("Unauthorized_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{ reqBody := map[string]any{
"title": "Test Post", "title": "Test Post",
"url": "https://example.com/test", "url": "https://example.com/test",
"content": "Test content", "content": "Test content",
} }
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) request := makePostRequestWithJSON(t, ctx.Router, "/api/posts", reqBody)
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusUnauthorized)
assertErrorResponse(t, rec, http.StatusUnauthorized)
}) })
t.Run("Forbidden_Error_Propagation", func(t *testing.T) { t.Run("Forbidden_Error_Propagation", func(t *testing.T) {
@@ -115,79 +91,59 @@ func TestIntegration_ErrorPropagation(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Forbidden Post", "https://example.com/forbidden") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, owner.User.ID, "Forbidden Post", "https://example.com/forbidden")
updateBody := map[string]string{ updateBody := map[string]any{
"title": "Updated Title", "title": "Updated Title",
"content": "Updated content", "content": "Updated content",
} }
body, _ := json.Marshal(updateBody)
req := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(body)) request := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), updateBody, otherUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+otherUser.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, otherUser.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusForbidden)
assertErrorResponse(t, rec, http.StatusForbidden)
}) })
t.Run("Service_Error_Propagation", func(t *testing.T) { t.Run("Service_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
reqBody := map[string]string{ reqBody := map[string]any{
"username": "existing_user", "username": "existing_user",
"email": "existing@example.com", "email": "existing@example.com",
"password": "SecurePass123!", "password": "SecurePass123!",
} }
body, _ := json.Marshal(reqBody)
req := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body)) request := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", reqBody)
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatus(t, rec, http.StatusCreated) assertStatus(t, request, http.StatusCreated)
req = httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(body)) request = makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", reqBody)
req.Header.Set("Content-Type", "application/json")
rec = httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req)
assertStatusRange(t, rec, http.StatusBadRequest, http.StatusConflict) assertStatusRange(t, request, http.StatusBadRequest, http.StatusConflict)
assertErrorResponse(t, rec, rec.Code) assertErrorResponse(t, request, request.Code)
}) })
t.Run("Middleware_Error_Propagation", func(t *testing.T) { t.Run("Middleware_Error_Propagation", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}"))) request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer([]byte("{}")))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer expired.invalid.token") request.Header.Set("Authorization", "Bearer expired.invalid.token")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) ctx.Router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusUnauthorized) assertErrorResponse(t, recorder, http.StatusUnauthorized)
}) })
t.Run("Handler_Error_Response_Format", func(t *testing.T) { t.Run("Handler_Error_Response_Format", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
req := httptest.NewRequest("GET", "/api/nonexistent", nil) request := makeGetRequest(t, ctx.Router, "/api/nonexistent")
rec := httptest.NewRecorder()
ctx.Router.ServeHTTP(rec, req) if request.Code == http.StatusNotFound {
if request.Header().Get("Content-Type") == "application/json" {
if rec.Code == http.StatusNotFound { assertErrorResponse(t, request, http.StatusNotFound)
if rec.Header().Get("Content-Type") == "application/json" { } else if request.Body.Len() == 0 {
assertErrorResponse(t, rec, http.StatusNotFound)
} else {
if rec.Body.Len() == 0 {
t.Error("Expected error response body") t.Error("Expected error response body")
} }
} }
}
}) })
} }

View File

@@ -11,55 +11,35 @@ import (
"testing" "testing"
"time" "time"
"github.com/golang-jwt/jwt/v5"
"goyco/internal/database" "goyco/internal/database"
"goyco/internal/handlers" "goyco/internal/handlers"
"goyco/internal/middleware" "goyco/internal/middleware"
"goyco/internal/repositories" "goyco/internal/repositories"
"goyco/internal/services" "goyco/internal/services"
"goyco/internal/testutils" "goyco/internal/testutils"
"github.com/golang-jwt/jwt/v5"
) )
func TestIntegration_Handlers(t *testing.T) { func TestIntegration_Handlers(t *testing.T) {
suite := testutils.NewServiceSuite(t) ctx := setupTestContext(t)
authService := ctx.AuthService
authService, err := services.NewAuthFacadeForTest(testutils.AppTestConfig, suite.UserRepo, suite.PostRepo, suite.DeletionRepo, suite.RefreshTokenRepo, suite.EmailSender) emailSender := ctx.Suite.EmailSender
if err != nil { userRepo := ctx.Suite.UserRepo
t.Fatalf("Failed to create auth service: %v", err) postRepo := ctx.Suite.PostRepo
}
voteService := services.NewVoteService(suite.VoteRepo, suite.PostRepo, suite.DB)
emailSender := suite.EmailSender
userRepo := suite.UserRepo
postRepo := suite.PostRepo
titleFetcher := suite.TitleFetcher
authHandler := handlers.NewAuthHandler(authService, userRepo)
postHandler := handlers.NewPostHandler(postRepo, titleFetcher, voteService)
voteHandler := handlers.NewVoteHandler(voteService)
userHandler := handlers.NewUserHandler(userRepo, authService)
t.Run("Auth_Handler_Complete_Workflow", func(t *testing.T) { t.Run("Auth_Handler_Complete_Workflow", func(t *testing.T) {
emailSender.Reset() emailSender.Reset()
registerData := map[string]string{ registerResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "handler_user", "username": "handler_user",
"email": "handler@example.com", "email": "handler@example.com",
"password": "SecurePass123!", "password": "SecurePass123!",
} })
registerBody, _ := json.Marshal(registerData) if registerResponse.Code != http.StatusCreated {
registerReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(registerBody)) t.Errorf("Expected status 201, got %d", registerResponse.Code)
registerReq.Header.Set("Content-Type", "application/json")
registerResp := httptest.NewRecorder()
authHandler.Register(registerResp, registerReq)
if registerResp.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", registerResp.Code)
} }
var registerPayload map[string]any registerPayload := assertJSONResponse(t, registerResponse, http.StatusCreated)
if err := json.Unmarshal(registerResp.Body.Bytes(), &registerPayload); err != nil {
t.Fatalf("Failed to decode register response: %v", err)
}
if success, _ := registerPayload["success"].(bool); !success { if success, _ := registerPayload["success"].(bool); !success {
t.Fatalf("Expected register response success, got %v", registerPayload) t.Fatalf("Expected register response success, got %v", registerPayload)
} }
@@ -78,11 +58,9 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Failed to update user with mock token: %v", err) t.Fatalf("Failed to update user with mock token: %v", err)
} }
confirmReq := httptest.NewRequest(http.MethodGet, "/api/auth/confirm?token="+url.QueryEscape(mockToken), nil) confirmResponse := makeGetRequest(t, ctx.Router, "/api/auth/confirm?token="+url.QueryEscape(mockToken))
confirmResp := httptest.NewRecorder() if confirmResponse.Code != http.StatusOK {
authHandler.ConfirmEmail(confirmResp, confirmReq) t.Fatalf("Expected 200 when confirming email via handler, got %d", confirmResponse.Code)
if confirmResp.Code != http.StatusOK {
t.Fatalf("Expected 200 when confirming email via handler, got %d", confirmResp.Code)
} }
loginSeed := createAuthenticatedUser(t, authService, userRepo, "auth_handler_login", "auth_handler_login@example.com") loginSeed := createAuthenticatedUser(t, authService, userRepo, "auth_handler_login", "auth_handler_login@example.com")
@@ -92,92 +70,60 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Service login failed for seeded user: %v", err) t.Fatalf("Service login failed for seeded user: %v", err)
} }
meReq := httptest.NewRequest("GET", "/api/auth/me", nil) meResponse := makeAuthenticatedGetRequest(t, ctx.Router, "/api/auth/me", &authenticatedUser{User: loginSeed.User, Token: loginAuth.AccessToken}, nil)
meReq.Header.Set("Authorization", "Bearer "+loginAuth.AccessToken) if meResponse.Code != http.StatusOK {
meReq = testutils.WithUserContext(meReq, middleware.UserIDKey, loginSeed.User.ID) t.Errorf("Expected status 200, got %d", meResponse.Code)
meResp := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
if meResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", meResp.Code)
} }
}) })
t.Run("Auth_Handler_Security_Validation", func(t *testing.T) { t.Run("Auth_Handler_Security_Validation", func(t *testing.T) {
emailSender.Reset() emailSender.Reset()
weakData := map[string]string{ weakResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "weak_user", "username": "weak_user",
"email": "weak@example.com", "email": "weak@example.com",
"password": "123", "password": "123",
} })
weakBody, _ := json.Marshal(weakData) if weakResponse.Code != http.StatusBadRequest {
weakReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(weakBody)) t.Errorf("Expected status 400 for weak password, got %d", weakResponse.Code)
weakReq.Header.Set("Content-Type", "application/json")
weakResp := httptest.NewRecorder()
authHandler.Register(weakResp, weakReq)
if weakResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for weak password, got %d", weakResp.Code)
} }
var weakErrorResp map[string]any weakErrorResponse := assertJSONResponse(t, weakResponse, http.StatusBadRequest)
if err := json.Unmarshal(weakResp.Body.Bytes(), &weakErrorResp); err != nil { if success, _ := weakErrorResponse["success"].(bool); success {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := weakErrorResp["success"].(bool); success {
t.Error("Expected error response to have success=false") t.Error("Expected error response to have success=false")
} }
if errorMsg, ok := weakErrorResp["error"].(string); !ok || errorMsg == "" { if errorMsg, ok := weakErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message") t.Error("Expected error response to contain validation error message")
} }
invalidData := map[string]string{ invalidResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "invalid_user", "username": "invalid_user",
"email": "not-an-email", "email": "not-an-email",
"password": "SecurePass123!", "password": "SecurePass123!",
} })
invalidBody, _ := json.Marshal(invalidData) if invalidResponse.Code != http.StatusBadRequest {
invalidReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(invalidBody)) t.Errorf("Expected status 400 for invalid email, got %d", invalidResponse.Code)
invalidReq.Header.Set("Content-Type", "application/json")
invalidResp := httptest.NewRecorder()
authHandler.Register(invalidResp, invalidReq)
if invalidResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid email, got %d", invalidResp.Code)
} }
var invalidEmailErrorResp map[string]any invalidEmailErrorResponse := assertJSONResponse(t, invalidResponse, http.StatusBadRequest)
if err := json.Unmarshal(invalidResp.Body.Bytes(), &invalidEmailErrorResp); err != nil { if success, _ := invalidEmailErrorResponse["success"].(bool); success {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := invalidEmailErrorResp["success"].(bool); success {
t.Error("Expected error response to have success=false") t.Error("Expected error response to have success=false")
} }
if errorMsg, ok := invalidEmailErrorResp["error"].(string); !ok || errorMsg == "" { if errorMsg, ok := invalidEmailErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message") t.Error("Expected error response to contain validation error message")
} }
incompleteData := map[string]string{ incompleteResponse := makePostRequestWithJSON(t, ctx.Router, "/api/auth/register", map[string]any{
"username": "incomplete_user", "username": "incomplete_user",
} })
incompleteBody, _ := json.Marshal(incompleteData) if incompleteResponse.Code != http.StatusBadRequest {
incompleteReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(incompleteBody)) t.Errorf("Expected status 400 for missing fields, got %d", incompleteResponse.Code)
incompleteReq.Header.Set("Content-Type", "application/json")
incompleteResp := httptest.NewRecorder()
authHandler.Register(incompleteResp, incompleteReq)
if incompleteResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for missing fields, got %d", incompleteResp.Code)
} }
var incompleteErrorResp map[string]any incompleteErrorResponse := assertJSONResponse(t, incompleteResponse, http.StatusBadRequest)
if err := json.Unmarshal(incompleteResp.Body.Bytes(), &incompleteErrorResp); err != nil { if success, _ := incompleteErrorResponse["success"].(bool); success {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := incompleteErrorResp["success"].(bool); success {
t.Error("Expected error response to have success=false") t.Error("Expected error response to have success=false")
} }
if errorMsg, ok := incompleteErrorResp["error"].(string); !ok || errorMsg == "" { if errorMsg, ok := incompleteErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message") t.Error("Expected error response to contain validation error message")
} }
}) })
@@ -186,28 +132,17 @@ func TestIntegration_Handlers(t *testing.T) {
emailSender.Reset() emailSender.Reset()
user := createAuthenticatedUser(t, authService, userRepo, "post_user", "post@example.com") user := createAuthenticatedUser(t, authService, userRepo, "post_user", "post@example.com")
postData := map[string]string{ postResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "Handler Test Post", "title": "Handler Test Post",
"url": "https://example.com/handler-test", "url": "https://example.com/handler-test",
"content": "This is a handler test post", "content": "This is a handler test post",
} }, user, nil)
postBody, _ := json.Marshal(postData) if postResponse.Code != http.StatusCreated {
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBody)) t.Errorf("Expected status 201, got %d", postResponse.Code)
postReq.Header.Set("Content-Type", "application/json")
postReq.Header.Set("Authorization", "Bearer "+user.Token)
postReq = testutils.WithUserContext(postReq, middleware.UserIDKey, user.User.ID)
postResp := httptest.NewRecorder()
postHandler.CreatePost(postResp, postReq)
if postResp.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", postResp.Code)
} }
var postResult map[string]any postResult := assertJSONResponse(t, postResponse, http.StatusCreated)
if err := json.Unmarshal(postResp.Body.Bytes(), &postResult); err != nil { postDetails, ok := getDataFromResponse(postResult)
t.Fatalf("Failed to decode post response: %v", err)
}
postDetails, ok := postResult["data"].(map[string]any)
if !ok { if !ok {
t.Fatalf("Expected data object in post response, got %v", postResult) t.Fatalf("Expected data object in post response, got %v", postResult)
} }
@@ -216,87 +151,49 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatal("Expected post ID in response") t.Fatal("Expected post ID in response")
} }
getReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", int(postID)), nil) getResponse := makeGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", int(postID)))
getReq = testutils.WithURLParams(getReq, map[string]string{"id": fmt.Sprintf("%d", int(postID))}) if getResponse.Code != http.StatusOK {
getResp := httptest.NewRecorder() t.Errorf("Expected status 200, got %d", getResponse.Code)
postHandler.GetPost(getResp, getReq)
if getResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getResp.Code)
} }
postsReq := httptest.NewRequest("GET", "/api/posts", nil) postsResponse := makeGetRequest(t, ctx.Router, "/api/posts")
postsResp := httptest.NewRecorder() if postsResponse.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", postsResponse.Code)
postHandler.GetPosts(postsResp, postsReq)
if postsResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", postsResp.Code)
} }
searchReq := httptest.NewRequest("GET", "/api/posts/search?q=handler", nil) searchResponse := makeGetRequest(t, ctx.Router, "/api/posts/search?q=handler")
searchResp := httptest.NewRecorder() if searchResponse.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", searchResponse.Code)
postHandler.SearchPosts(searchResp, searchReq)
if searchResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", searchResp.Code)
} }
}) })
t.Run("Post_Handler_Security_Validation", func(t *testing.T) { t.Run("Post_Handler_Security_Validation", func(t *testing.T) {
emailSender.Reset() emailSender.Reset()
postData := map[string]string{ postResponse := makePostRequestWithJSON(t, ctx.Router, "/api/posts", map[string]any{
"title": "Unauthorized Post", "title": "Unauthorized Post",
"url": "https://example.com/unauthorized", "url": "https://example.com/unauthorized",
"content": "This should fail", "content": "This should fail",
} })
postBody, _ := json.Marshal(postData) authErrorResponse := assertJSONResponse(t, postResponse, http.StatusUnauthorized)
postReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(postBody)) if success, _ := authErrorResponse["success"].(bool); success {
postReq.Header.Set("Content-Type", "application/json")
postResp := httptest.NewRecorder()
postHandler.CreatePost(postResp, postReq)
if postResp.Code != http.StatusUnauthorized {
t.Errorf("Expected status 401 for unauthenticated post creation, got %d", postResp.Code)
}
var authErrorResp map[string]any
if err := json.Unmarshal(postResp.Body.Bytes(), &authErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := authErrorResp["success"].(bool); success {
t.Error("Expected error response to have success=false") t.Error("Expected error response to have success=false")
} }
if errorMsg, ok := authErrorResp["error"].(string); !ok || errorMsg == "" { if errorMsg, ok := authErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain authentication error message") t.Error("Expected error response to contain authentication error message")
} }
user := createAuthenticatedUser(t, authService, userRepo, "security_user", "security@example.com") user := createAuthenticatedUser(t, authService, userRepo, "security_user", "security@example.com")
invalidData := map[string]string{ invalidResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "", "title": "",
"url": "not-a-url", "url": "not-a-url",
"content": "Invalid post", "content": "Invalid post",
} }, user, nil)
invalidBody, _ := json.Marshal(invalidData) postValidationErrorResponse := assertJSONResponse(t, invalidResponse, http.StatusBadRequest)
invalidReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(invalidBody)) if success, _ := postValidationErrorResponse["success"].(bool); success {
invalidReq.Header.Set("Content-Type", "application/json")
invalidReq.Header.Set("Authorization", "Bearer "+user.Token)
invalidReq = testutils.WithUserContext(invalidReq, middleware.UserIDKey, user.User.ID)
invalidResp := httptest.NewRecorder()
postHandler.CreatePost(invalidResp, invalidReq)
if invalidResp.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid post data, got %d", invalidResp.Code)
}
var postValidationErrorResp map[string]any
if err := json.Unmarshal(invalidResp.Body.Bytes(), &postValidationErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := postValidationErrorResp["success"].(bool); success {
t.Error("Expected error response to have success=false") t.Error("Expected error response to have success=false")
} }
if errorMsg, ok := postValidationErrorResp["error"].(string); !ok || errorMsg == "" { if errorMsg, ok := postValidationErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain validation error message") t.Error("Expected error response to contain validation error message")
} }
}) })
@@ -306,161 +203,100 @@ func TestIntegration_Handlers(t *testing.T) {
user := createAuthenticatedUser(t, authService, userRepo, "vote_handler_user", "vote_handler@example.com") user := createAuthenticatedUser(t, authService, userRepo, "vote_handler_user", "vote_handler@example.com")
post := testutils.CreatePostWithRepo(t, postRepo, user.User.ID, "Vote Handler Test Post", "https://example.com/vote-handler") post := testutils.CreatePostWithRepo(t, postRepo, user.User.ID, "Vote Handler Test Post", "https://example.com/vote-handler")
voteData := map[string]string{ voteResponse := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
"type": "up", assertStatus(t, voteResponse, http.StatusOK)
}
voteBody, _ := json.Marshal(voteData)
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBody))
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteResp := httptest.NewRecorder()
voteHandler.CastVote(voteResp, voteReq) getVoteResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
if voteResp.Code != http.StatusOK { assertStatus(t, getVoteResponse, http.StatusOK)
t.Errorf("Expected status 200, got %d", voteResp.Code)
}
getVoteReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil) getPostVotesResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/votes", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVoteReq.Header.Set("Authorization", "Bearer "+user.Token) assertStatus(t, getPostVotesResponse, http.StatusOK)
getVoteReq = testutils.WithUserContext(getVoteReq, middleware.UserIDKey, user.User.ID)
getVoteReq = testutils.WithURLParams(getVoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getVoteResp := httptest.NewRecorder()
voteHandler.GetUserVote(getVoteResp, getVoteReq) removeVoteResponse := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), user, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
if getVoteResp.Code != http.StatusOK { assertStatus(t, removeVoteResponse, http.StatusOK)
t.Errorf("Expected status 200, got %d", getVoteResp.Code)
}
getPostVotesReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d/votes", post.ID), nil)
getPostVotesReq.Header.Set("Authorization", "Bearer "+user.Token)
getPostVotesReq = testutils.WithUserContext(getPostVotesReq, middleware.UserIDKey, user.User.ID)
getPostVotesReq = testutils.WithURLParams(getPostVotesReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getPostVotesResp := httptest.NewRecorder()
voteHandler.GetPostVotes(getPostVotesResp, getPostVotesReq)
if getPostVotesResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getPostVotesResp.Code)
}
removeVoteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d/vote", post.ID), nil)
removeVoteReq.Header.Set("Authorization", "Bearer "+user.Token)
removeVoteReq = testutils.WithUserContext(removeVoteReq, middleware.UserIDKey, user.User.ID)
removeVoteReq = testutils.WithURLParams(removeVoteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
removeVoteResp := httptest.NewRecorder()
voteHandler.RemoveVote(removeVoteResp, removeVoteReq)
if removeVoteResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", removeVoteResp.Code)
}
}) })
t.Run("User_Handler_Complete_Workflow", func(t *testing.T) { t.Run("User_Handler_Complete_Workflow", func(t *testing.T) {
emailSender.Reset() emailSender.Reset()
user := createAuthenticatedUser(t, authService, userRepo, "user_handler_user", "user_handler@example.com") user := createAuthenticatedUser(t, authService, userRepo, "user_handler_user", "user_handler@example.com")
usersReq := httptest.NewRequest("GET", "/api/users", nil) usersResponse := makeAuthenticatedGetRequest(t, ctx.Router, "/api/users", user, nil)
usersReq.Header.Set("Authorization", "Bearer "+user.Token) assertStatus(t, usersResponse, http.StatusOK)
usersReq = testutils.WithUserContext(usersReq, middleware.UserIDKey, user.User.ID)
usersResp := httptest.NewRecorder()
userHandler.GetUsers(usersResp, usersReq) getUserResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
if usersResp.Code != http.StatusOK { assertStatus(t, getUserResponse, http.StatusOK)
t.Errorf("Expected status 200, got %d", usersResp.Code)
}
getUserReq := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d", user.User.ID), nil) getUserPostsResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/users/%d/posts", user.User.ID), user, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
getUserReq.Header.Set("Authorization", "Bearer "+user.Token) assertStatus(t, getUserPostsResponse, http.StatusOK)
getUserReq = testutils.WithUserContext(getUserReq, middleware.UserIDKey, user.User.ID)
getUserReq = testutils.WithURLParams(getUserReq, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
getUserResp := httptest.NewRecorder()
userHandler.GetUser(getUserResp, getUserReq)
if getUserResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getUserResp.Code)
}
getUserPostsReq := httptest.NewRequest("GET", fmt.Sprintf("/api/users/%d/posts", user.User.ID), nil)
getUserPostsReq.Header.Set("Authorization", "Bearer "+user.Token)
getUserPostsReq = testutils.WithUserContext(getUserPostsReq, middleware.UserIDKey, user.User.ID)
getUserPostsReq = testutils.WithURLParams(getUserPostsReq, map[string]string{"id": fmt.Sprintf("%d", user.User.ID)})
getUserPostsResp := httptest.NewRecorder()
userHandler.GetUserPosts(getUserPostsResp, getUserPostsReq)
if getUserPostsResp.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", getUserPostsResp.Code)
}
}) })
t.Run("Error_Handling_Invalid_Requests", func(t *testing.T) { t.Run("Error_Handling_Invalid_Requests", func(t *testing.T) {
invalidJSONReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer([]byte("invalid json"))) middleware.StopAllRateLimiters()
invalidJSONReq.Header.Set("Content-Type", "application/json") ctx.Suite.EmailSender.Reset()
invalidJSONResp := httptest.NewRecorder()
authHandler.Register(invalidJSONResp, invalidJSONReq) invalidJSONResponse := makeRequest(t, ctx.Router, "POST", "/api/auth/register", []byte("invalid json"), map[string]string{"Content-Type": "application/json"})
if invalidJSONResp.Code != http.StatusBadRequest { jsonErrorResponse := assertJSONResponse(t, invalidJSONResponse, http.StatusBadRequest)
t.Errorf("Expected status 400 for invalid JSON, got %d", invalidJSONResp.Code) if success, _ := jsonErrorResponse["success"].(bool); success {
}
var jsonErrorResp map[string]any
if err := json.Unmarshal(invalidJSONResp.Body.Bytes(), &jsonErrorResp); err != nil {
t.Fatalf("Failed to decode error response: %v", err)
}
if success, _ := jsonErrorResp["success"].(bool); success {
t.Error("Expected error response to have success=false") t.Error("Expected error response to have success=false")
} }
if errorMsg, ok := jsonErrorResp["error"].(string); !ok || errorMsg == "" { if errorMsg, ok := jsonErrorResponse["error"].(string); !ok || errorMsg == "" {
t.Error("Expected error response to contain JSON parsing error message") t.Error("Expected error response to contain JSON parsing error message")
} }
missingCTData := map[string]string{ missingCTData := map[string]string{
"username": "missing_ct_user", "username": uniqueTestUsername(t, "missing_ct"),
"email": "missing_ct@example.com", "email": uniqueTestEmail(t, "missing_ct"),
"password": "SecurePass123!", "password": "SecurePass123!",
} }
missingCTBody, _ := json.Marshal(missingCTData) missingCTBody, _ := json.Marshal(missingCTData)
missingCTReq := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(missingCTBody)) missingCTRequest := httptest.NewRequest("POST", "/api/auth/register", bytes.NewBuffer(missingCTBody))
missingCTResp := httptest.NewRecorder() missingCTResponse := httptest.NewRecorder()
authHandler.Register(missingCTResp, missingCTReq) ctx.Router.ServeHTTP(missingCTResponse, missingCTRequest)
if missingCTResp.Code != http.StatusCreated { if missingCTResponse.Code == http.StatusTooManyRequests {
t.Errorf("Expected status 201, got %d", missingCTResp.Code) var rateLimitResponse map[string]any
if err := json.Unmarshal(missingCTResponse.Body.Bytes(), &rateLimitResponse); err != nil {
t.Errorf("Rate limited but response is not valid JSON: %v", err)
} else {
t.Logf("Rate limit hit (expected in full test suite run), but request was processed correctly (not rejected as invalid JSON)")
}
} else if missingCTResponse.Code != http.StatusCreated {
t.Errorf("Expected status 201, got %d", missingCTResponse.Code)
} }
invalidEndpointReq := httptest.NewRequest("GET", "/api/invalid/endpoint", nil) invalidEndpointRequest := httptest.NewRequest("GET", "/api/invalid/endpoint", nil)
invalidEndpointResp := httptest.NewRecorder() invalidEndpointResponse := httptest.NewRecorder()
authHandler.Me(invalidEndpointResp, invalidEndpointReq) ctx.Router.ServeHTTP(invalidEndpointResponse, invalidEndpointRequest)
if invalidEndpointResp.Code == http.StatusOK { if invalidEndpointResponse.Code == http.StatusOK {
t.Error("Expected error for invalid endpoint") t.Error("Expected error for invalid endpoint")
} }
}) })
t.Run("Security_Authentication_Bypass", func(t *testing.T) { t.Run("Security_Authentication_Bypass", func(t *testing.T) {
meReq := httptest.NewRequest("GET", "/api/auth/me", nil) meRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
meResp := httptest.NewRecorder() meResponse := httptest.NewRecorder()
authHandler.Me(meResp, meReq) ctx.Router.ServeHTTP(meResponse, meRequest)
if meResp.Code == http.StatusOK { if meResponse.Code == http.StatusOK {
t.Error("Expected error for unauthenticated request") t.Error("Expected error for unauthenticated request")
} }
invalidTokenReq := httptest.NewRequest("GET", "/api/auth/me", nil) invalidTokenRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
invalidTokenReq.Header.Set("Authorization", "Bearer invalid-token") invalidTokenRequest.Header.Set("Authorization", "Bearer invalid-token")
invalidTokenResp := httptest.NewRecorder() invalidTokenResponse := httptest.NewRecorder()
authHandler.Me(invalidTokenResp, invalidTokenReq) ctx.Router.ServeHTTP(invalidTokenResponse, invalidTokenRequest)
if invalidTokenResp.Code == http.StatusOK { if invalidTokenResponse.Code == http.StatusOK {
t.Error("Expected error for invalid token") t.Error("Expected error for invalid token")
} }
malformedTokenReq := httptest.NewRequest("GET", "/api/auth/me", nil) malformedTokenRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
malformedTokenReq.Header.Set("Authorization", "InvalidFormat token") malformedTokenRequest.Header.Set("Authorization", "InvalidFormat token")
malformedTokenResp := httptest.NewRecorder() malformedTokenResponse := httptest.NewRecorder()
authHandler.Me(malformedTokenResp, malformedTokenReq) ctx.Router.ServeHTTP(malformedTokenResponse, malformedTokenRequest)
if malformedTokenResp.Code == http.StatusOK { if malformedTokenResponse.Code == http.StatusOK {
t.Error("Expected error for malformed token") t.Error("Expected error for malformed token")
} }
}) })
@@ -468,32 +304,21 @@ func TestIntegration_Handlers(t *testing.T) {
t.Run("Security_Input_Sanitization", func(t *testing.T) { t.Run("Security_Input_Sanitization", func(t *testing.T) {
user := createAuthenticatedUser(t, authService, userRepo, "xss_user", "xss@example.com") user := createAuthenticatedUser(t, authService, userRepo, "xss_user", "xss@example.com")
xssData := map[string]string{ xssResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "<script>alert('xss')</script>", "title": "<script>alert('xss')</script>",
"url": "https://example.com/xss", "url": "https://example.com/xss",
"content": "XSS test content", "content": "<script>alert('xss')</script>",
} }, user, nil)
xssBody, _ := json.Marshal(xssData) if xssResponse.Code != http.StatusCreated {
xssReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(xssBody)) t.Errorf("Expected status 201 for XSS sanitization, got %d", xssResponse.Code)
xssReq.Header.Set("Content-Type", "application/json")
xssReq.Header.Set("Authorization", "Bearer "+user.Token)
xssReq = testutils.WithUserContext(xssReq, middleware.UserIDKey, user.User.ID)
xssResp := httptest.NewRecorder()
postHandler.CreatePost(xssResp, xssReq)
if xssResp.Code != http.StatusCreated {
t.Errorf("Expected status 201 for XSS sanitization, got %d", xssResp.Code)
} }
var xssResult map[string]any xssResult := assertJSONResponse(t, xssResponse, http.StatusCreated)
if err := json.Unmarshal(xssResp.Body.Bytes(), &xssResult); err != nil {
t.Fatalf("Failed to decode XSS response: %v", err)
}
if success, _ := xssResult["success"].(bool); !success { if success, _ := xssResult["success"].(bool); !success {
t.Error("Expected XSS response to have success=true") t.Error("Expected XSS response to have success=true")
} }
data, ok := xssResult["data"].(map[string]any) data, ok := getDataFromResponse(xssResult)
if !ok { if !ok {
t.Fatalf("Expected data object in XSS response, got %T", xssResult["data"]) t.Fatalf("Expected data object in XSS response, got %T", xssResult["data"])
} }
@@ -522,32 +347,21 @@ func TestIntegration_Handlers(t *testing.T) {
t.Errorf("Expected script tags to be HTML-escaped in content, got: %s", content) t.Errorf("Expected script tags to be HTML-escaped in content, got: %s", content)
} }
sqlData := map[string]string{ sqlResponse := makePostRequest(t, ctx.Router, "/api/posts", map[string]any{
"title": "'; DROP TABLE posts; --", "title": "'; DROP TABLE posts; --",
"url": "https://example.com/sql", "url": "https://example.com/sql",
"content": "SQL injection test", "content": "SQL injection test",
} }, user, nil)
sqlBody, _ := json.Marshal(sqlData) if sqlResponse.Code != http.StatusCreated {
sqlReq := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(sqlBody)) t.Errorf("Expected status 201 for SQL injection sanitization, got %d", sqlResponse.Code)
sqlReq.Header.Set("Content-Type", "application/json")
sqlReq.Header.Set("Authorization", "Bearer "+user.Token)
sqlReq = testutils.WithUserContext(sqlReq, middleware.UserIDKey, user.User.ID)
sqlResp := httptest.NewRecorder()
postHandler.CreatePost(sqlResp, sqlReq)
if sqlResp.Code != http.StatusCreated {
t.Errorf("Expected status 201 for SQL injection sanitization, got %d", sqlResp.Code)
} }
var sqlResult map[string]any sqlResult := assertJSONResponse(t, sqlResponse, http.StatusCreated)
if err := json.Unmarshal(sqlResp.Body.Bytes(), &sqlResult); err != nil {
t.Fatalf("Failed to decode SQL response: %v", err)
}
if success, _ := sqlResult["success"].(bool); !success { if success, _ := sqlResult["success"].(bool); !success {
t.Error("Expected SQL response to have success=true") t.Error("Expected SQL response to have success=true")
} }
sqlResponseData, ok := sqlResult["data"].(map[string]any) sqlResponseData, ok := getDataFromResponse(sqlResult)
if !ok { if !ok {
t.Fatalf("Expected data object in SQL response, got %T", sqlResult["data"]) t.Fatalf("Expected data object in SQL response, got %T", sqlResult["data"])
} }
@@ -579,63 +393,31 @@ func TestIntegration_Handlers(t *testing.T) {
t.Run("Authorization_User_Access_Control", func(t *testing.T) { t.Run("Authorization_User_Access_Control", func(t *testing.T) {
emailSender.Reset() emailSender.Reset()
user1 := createAuthenticatedUser(t, authService, userRepo, "auth_user1", "auth1@example.com") firstUser := createAuthenticatedUser(t, authService, userRepo, "auth_user1", "auth1@example.com")
user2 := createAuthenticatedUser(t, authService, userRepo, "auth_user2", "auth2@example.com") secondUser := createAuthenticatedUser(t, authService, userRepo, "auth_user2", "auth2@example.com")
post := testutils.CreatePostWithRepo(t, postRepo, user1.User.ID, "Private Post", "https://example.com/private") post := testutils.CreatePostWithRepo(t, postRepo, firstUser.User.ID, "Private Post", "https://example.com/private")
getPostReq := httptest.NewRequest("GET", fmt.Sprintf("/api/posts/%d", post.ID), nil) getPostResponse := makeAuthenticatedGetRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getPostReq.Header.Set("Authorization", "Bearer "+user2.Token) testutils.AssertHTTPStatus(t, getPostResponse, http.StatusOK)
getPostReq = testutils.WithUserContext(getPostReq, middleware.UserIDKey, user2.User.ID)
getPostReq = testutils.WithURLParams(getPostReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
getPostResp := httptest.NewRecorder()
postHandler.GetPost(getPostResp, getPostReq) updateResponse := makePutRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), map[string]any{"title": "Updated Title"}, secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
testutils.AssertHTTPStatus(t, getPostResp, http.StatusOK) testutils.AssertHTTPStatus(t, updateResponse, http.StatusForbidden)
updateData := map[string]string{ deleteResponse := makeDeleteRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d", post.ID), secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
"title": "Updated Title", testutils.AssertHTTPStatus(t, deleteResponse, http.StatusForbidden)
}
updateBody, _ := json.Marshal(updateData)
updateReq := httptest.NewRequest("PUT", fmt.Sprintf("/api/posts/%d", post.ID), bytes.NewBuffer(updateBody))
updateReq.Header.Set("Content-Type", "application/json")
updateReq.Header.Set("Authorization", "Bearer "+user2.Token)
updateReq = testutils.WithUserContext(updateReq, middleware.UserIDKey, user2.User.ID)
updateReq = testutils.WithURLParams(updateReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
updateResp := httptest.NewRecorder()
postHandler.UpdatePost(updateResp, updateReq)
testutils.AssertHTTPStatus(t, updateResp, http.StatusForbidden)
deleteReq := httptest.NewRequest("DELETE", fmt.Sprintf("/api/posts/%d", post.ID), nil)
deleteReq.Header.Set("Authorization", "Bearer "+user2.Token)
deleteReq = testutils.WithUserContext(deleteReq, middleware.UserIDKey, user2.User.ID)
deleteReq = testutils.WithURLParams(deleteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
deleteResp := httptest.NewRecorder()
postHandler.DeletePost(deleteResp, deleteReq)
testutils.AssertHTTPStatus(t, deleteResp, http.StatusForbidden)
}) })
t.Run("Authorization_Vote_Access_Control", func(t *testing.T) { t.Run("Authorization_Vote_Access_Control", func(t *testing.T) {
emailSender.Reset() emailSender.Reset()
user1 := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user1", "vote_auth1@example.com") firstUser := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user1", "vote_auth1@example.com")
user2 := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user2", "vote_auth2@example.com") secondUser := createAuthenticatedUser(t, authService, userRepo, "vote_auth_user2", "vote_auth2@example.com")
post := testutils.CreatePostWithRepo(t, postRepo, user1.User.ID, "Vote Auth Post", "https://example.com/vote-auth") post := testutils.CreatePostWithRepo(t, postRepo, firstUser.User.ID, "Vote Auth Post", "https://example.com/vote-auth")
voteData := map[string]string{"type": "up"} voteResponse := makePostRequest(t, ctx.Router, fmt.Sprintf("/api/posts/%d/vote", post.ID), map[string]any{"type": "up"}, secondUser, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteBody, _ := json.Marshal(voteData) if voteResponse.Code != http.StatusOK {
voteReq := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(voteBody)) t.Errorf("Users should be able to vote on any post, got %d", voteResponse.Code)
voteReq.Header.Set("Content-Type", "application/json")
voteReq.Header.Set("Authorization", "Bearer "+user2.Token)
voteReq = testutils.WithUserContext(voteReq, middleware.UserIDKey, user2.User.ID)
voteReq = testutils.WithURLParams(voteReq, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
voteResp := httptest.NewRecorder()
voteHandler.CastVote(voteResp, voteReq)
if voteResp.Code != http.StatusOK {
t.Errorf("Users should be able to vote on any post, got %d", voteResp.Code)
} }
}) })
@@ -664,12 +446,8 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Failed to generate expired token: %v", err) t.Fatalf("Failed to generate expired token: %v", err)
} }
meReq := httptest.NewRequest("GET", "/api/auth/me", nil) meResponse := makeRequest(t, ctx.Router, "GET", "/api/auth/me", nil, map[string]string{"Authorization": "Bearer " + expiredToken})
meReq.Header.Set("Authorization", "Bearer "+expiredToken) testutils.AssertHTTPStatus(t, meResponse, http.StatusUnauthorized)
meResp := httptest.NewRecorder()
authHandler.Me(meResp, meReq)
testutils.AssertHTTPStatus(t, meResp, http.StatusUnauthorized)
}) })
t.Run("Authorization_Token_Tampering", func(t *testing.T) { t.Run("Authorization_Token_Tampering", func(t *testing.T) {
@@ -678,12 +456,12 @@ func TestIntegration_Handlers(t *testing.T) {
tamperedToken := user.Token[:len(user.Token)-5] + "XXXXX" tamperedToken := user.Token[:len(user.Token)-5] + "XXXXX"
meReq := httptest.NewRequest("GET", "/api/auth/me", nil) meRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
meReq.Header.Set("Authorization", "Bearer "+tamperedToken) meRequest.Header.Set("Authorization", "Bearer "+tamperedToken)
meResp := httptest.NewRecorder() meResponse := httptest.NewRecorder()
authHandler.Me(meResp, meReq) ctx.Router.ServeHTTP(meResponse, meRequest)
testutils.AssertHTTPStatus(t, meResp, http.StatusUnauthorized) testutils.AssertHTTPStatus(t, meResponse, http.StatusUnauthorized)
}) })
t.Run("Authorization_Session_Version_Mismatch", func(t *testing.T) { t.Run("Authorization_Session_Version_Mismatch", func(t *testing.T) {
@@ -711,12 +489,12 @@ func TestIntegration_Handlers(t *testing.T) {
t.Fatalf("Failed to generate invalid token: %v", err) t.Fatalf("Failed to generate invalid token: %v", err)
} }
meReq := httptest.NewRequest("GET", "/api/auth/me", nil) meRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
meReq.Header.Set("Authorization", "Bearer "+invalidToken) meRequest.Header.Set("Authorization", "Bearer "+invalidToken)
meResp := httptest.NewRecorder() meResponse := httptest.NewRecorder()
authHandler.Me(meResp, meReq) ctx.Router.ServeHTTP(meResponse, meRequest)
testutils.AssertHTTPStatus(t, meResp, http.StatusUnauthorized) testutils.AssertHTTPStatus(t, meResponse, http.StatusUnauthorized)
}) })
} }
@@ -748,11 +526,9 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
} }
voteService := services.NewVoteService(voteRepo, postRepo, db) voteService := services.NewVoteService(voteRepo, postRepo, db)
apiHandler := handlers.NewAPIHandlerWithMonitoring(testutils.AppTestConfig, postRepo, userRepo, voteService, db, monitor) apiHandler := handlers.NewAPIHandlerWithMonitoring(testutils.AppTestConfig, postRepo, userRepo, voteService, db, monitor)
t.Run("Health endpoint includes database monitoring", func(t *testing.T) { t.Run("Health endpoint includes database monitoring", func(t *testing.T) {
user := &database.User{ user := &database.User{
Username: "monitoring_user", Username: "monitoring_user",
Email: "monitoring@example.com", Email: "monitoring@example.com",
@@ -765,7 +541,6 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
apiHandler.GetHealth(recorder, request) apiHandler.GetHealth(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK) testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
var response map[string]any var response map[string]any
@@ -777,7 +552,7 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
t.Error("Expected success to be true") t.Error("Expected success to be true")
} }
data, ok := response["data"].(map[string]any) data, ok := getDataFromResponse(response)
if !ok { if !ok {
t.Fatal("Expected data to be a map") t.Fatal("Expected data to be a map")
} }
@@ -813,7 +588,6 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
apiHandler.GetMetrics(recorder, request) apiHandler.GetMetrics(recorder, request)
testutils.AssertHTTPStatus(t, recorder, http.StatusOK) testutils.AssertHTTPStatus(t, recorder, http.StatusOK)
var response map[string]any var response map[string]any
@@ -825,7 +599,7 @@ func TestIntegration_DatabaseMonitoring(t *testing.T) {
t.Error("Expected success to be true") t.Error("Expected success to be true")
} }
data, ok := response["data"].(map[string]any) data, ok := getDataFromResponse(response)
if !ok { if !ok {
t.Fatal("Expected data to be a map") t.Fatal("Expected data to be a map")
} }

View File

@@ -1,6 +1,7 @@
package integration package integration
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -282,14 +283,14 @@ func setupPageHandlerTestContext(t *testing.T) *testContext {
func getCSRFToken(t *testing.T, router http.Handler, path string, cookies ...*http.Cookie) string { func getCSRFToken(t *testing.T, router http.Handler, path string, cookies ...*http.Cookie) string {
t.Helper() t.Helper()
req := httptest.NewRequest("GET", path, nil) request := httptest.NewRequest("GET", path, nil)
for _, cookie := range cookies { for _, cookie := range cookies {
req.AddCookie(cookie) request.AddCookie(cookie)
} }
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
cookieList := rec.Result().Cookies() cookieList := recorder.Result().Cookies()
for _, cookie := range cookieList { for _, cookie := range cookieList {
if cookie.Name == "csrf_token" { if cookie.Name == "csrf_token" {
return cookie.Value return cookie.Value
@@ -299,32 +300,32 @@ func getCSRFToken(t *testing.T, router http.Handler, path string, cookies ...*ht
return "" return ""
} }
func assertJSONResponse(t *testing.T, rec *httptest.ResponseRecorder, expectedStatus int) map[string]any { func assertJSONResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int) map[string]any {
t.Helper() t.Helper()
if rec.Code != expectedStatus { if recorder.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, rec.Code, rec.Body.String()) t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, recorder.Code, recorder.Body.String())
return nil return nil
} }
var response map[string]any var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err != nil { if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode response: %v. Body: %s", err, rec.Body.String()) t.Fatalf("Failed to decode response: %v. Body: %s", err, recorder.Body.String())
return nil return nil
} }
return response return response
} }
func assertErrorResponse(t *testing.T, rec *httptest.ResponseRecorder, expectedStatus int) { func assertErrorResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int) {
t.Helper() t.Helper()
if rec.Code != expectedStatus { if recorder.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, rec.Code, rec.Body.String()) t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, recorder.Code, recorder.Body.String())
return return
} }
var response map[string]any var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err != nil { if err := json.NewDecoder(recorder.Body).Decode(&response); err != nil {
t.Fatalf("Failed to decode error response: %v. Body: %s", err, rec.Body.String()) t.Fatalf("Failed to decode error response: %v. Body: %s", err, recorder.Body.String())
return return
} }
@@ -335,23 +336,23 @@ func assertErrorResponse(t *testing.T, rec *httptest.ResponseRecorder, expectedS
} }
} }
func assertStatus(t *testing.T, rec *httptest.ResponseRecorder, expectedStatus int) { func assertStatus(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int) {
t.Helper() t.Helper()
if rec.Code != expectedStatus { if recorder.Code != expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, rec.Code, rec.Body.String()) t.Errorf("Expected status %d, got %d. Body: %s", expectedStatus, recorder.Code, recorder.Body.String())
} }
} }
func assertStatusRange(t *testing.T, rec *httptest.ResponseRecorder, minStatus, maxStatus int) { func assertStatusRange(t *testing.T, recorder *httptest.ResponseRecorder, minStatus, maxStatus int) {
t.Helper() t.Helper()
if rec.Code < minStatus || rec.Code > maxStatus { if recorder.Code < minStatus || recorder.Code > maxStatus {
t.Errorf("Expected status between %d and %d, got %d. Body: %s", minStatus, maxStatus, rec.Code, rec.Body.String()) t.Errorf("Expected status between %d and %d, got %d. Body: %s", minStatus, maxStatus, recorder.Code, recorder.Body.String())
} }
} }
func assertCookie(t *testing.T, rec *httptest.ResponseRecorder, name, expectedValue string) { func assertCookie(t *testing.T, recorder *httptest.ResponseRecorder, name, expectedValue string) {
t.Helper() t.Helper()
cookies := rec.Result().Cookies() cookies := recorder.Result().Cookies()
for _, cookie := range cookies { for _, cookie := range cookies {
if cookie.Name == name { if cookie.Name == name {
if expectedValue != "" && cookie.Value != expectedValue { if expectedValue != "" && cookie.Value != expectedValue {
@@ -363,9 +364,9 @@ func assertCookie(t *testing.T, rec *httptest.ResponseRecorder, name, expectedVa
t.Errorf("Expected cookie %s not found", name) t.Errorf("Expected cookie %s not found", name)
} }
func assertCookieCleared(t *testing.T, rec *httptest.ResponseRecorder, name string) { func assertCookieCleared(t *testing.T, recorder *httptest.ResponseRecorder, name string) {
t.Helper() t.Helper()
cookies := rec.Result().Cookies() cookies := recorder.Result().Cookies()
for _, cookie := range cookies { for _, cookie := range cookies {
if cookie.Name == name { if cookie.Name == name {
if cookie.Value != "" { if cookie.Value != "" {
@@ -376,21 +377,17 @@ func assertCookieCleared(t *testing.T, rec *httptest.ResponseRecorder, name stri
} }
} }
func assertHeader(t *testing.T, rec *httptest.ResponseRecorder, name, expectedValue string) { func assertHeader(t *testing.T, recorder *httptest.ResponseRecorder, name string) {
t.Helper() t.Helper()
actualValue := rec.Header().Get(name) actualValue := recorder.Header().Get(name)
if expectedValue == "" {
if actualValue == "" { if actualValue == "" {
t.Errorf("Expected header %s to be present", name) t.Errorf("Expected header %s to be present", name)
} }
} else if actualValue != expectedValue {
t.Errorf("Expected header %s=%s, got %s", name, expectedValue, actualValue)
}
} }
func assertHeaderContains(t *testing.T, rec *httptest.ResponseRecorder, name, substring string) { func assertHeaderContains(t *testing.T, recorder *httptest.ResponseRecorder, name, substring string) {
t.Helper() t.Helper()
actualValue := rec.Header().Get(name) actualValue := recorder.Header().Get(name)
if !strings.Contains(actualValue, substring) { if !strings.Contains(actualValue, substring) {
t.Errorf("Expected header %s to contain %s, got %s", name, substring, actualValue) t.Errorf("Expected header %s to contain %s, got %s", name, substring, actualValue)
} }
@@ -450,3 +447,83 @@ func createUserWithCleanup(t *testing.T, ctx *testContext, username, email strin
}) })
return user return user
} }
func makeRequest(t *testing.T, router http.Handler, method, path string, body []byte, headers map[string]string) *httptest.ResponseRecorder {
t.Helper()
var requestBody *bytes.Buffer
if body != nil {
requestBody = bytes.NewBuffer(body)
} else {
requestBody = bytes.NewBuffer(nil)
}
request := httptest.NewRequest(method, path, requestBody)
for key, value := range headers {
request.Header.Set(key, value)
}
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
return recorder
}
func makeAuthenticatedRequest(t *testing.T, router http.Handler, method, path string, body []byte, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
var requestBody *bytes.Buffer
if body != nil {
requestBody = bytes.NewBuffer(body)
} else {
requestBody = bytes.NewBuffer(nil)
}
request := httptest.NewRequest(method, path, requestBody)
request.Header.Set("Authorization", "Bearer "+user.Token)
if body != nil {
request.Header.Set("Content-Type", "application/json")
}
request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
if urlParams != nil {
request = testutils.WithURLParams(request, urlParams)
}
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
return recorder
}
func makeGetRequest(t *testing.T, router http.Handler, path string) *httptest.ResponseRecorder {
t.Helper()
return makeRequest(t, router, "GET", path, nil, nil)
}
func makeAuthenticatedGetRequest(t *testing.T, router http.Handler, path string, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
return makeAuthenticatedRequest(t, router, "GET", path, nil, user, urlParams)
}
func makePostRequest(t *testing.T, router http.Handler, path string, body map[string]any, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
bodyBytes, _ := json.Marshal(body)
return makeAuthenticatedRequest(t, router, "POST", path, bodyBytes, user, urlParams)
}
func makePutRequest(t *testing.T, router http.Handler, path string, body map[string]any, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
bodyBytes, _ := json.Marshal(body)
return makeAuthenticatedRequest(t, router, "PUT", path, bodyBytes, user, urlParams)
}
func makeDeleteRequest(t *testing.T, router http.Handler, path string, user *authenticatedUser, urlParams map[string]string) *httptest.ResponseRecorder {
t.Helper()
return makeAuthenticatedRequest(t, router, "DELETE", path, nil, user, urlParams)
}
func makePostRequestWithJSON(t *testing.T, router http.Handler, path string, body map[string]any) *httptest.ResponseRecorder {
t.Helper()
bodyBytes, _ := json.Marshal(body)
return makeRequest(t, router, "POST", path, bodyBytes, map[string]string{"Content-Type": "application/json"})
}
func getDataFromResponse(response map[string]any) (map[string]any, bool) {
if response == nil {
return nil, false
}
data, ok := response["data"].(map[string]any)
return data, ok
}

View File

@@ -22,26 +22,26 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, authService, ctx.Suite.UserRepo, "settings_email_user", "settings_email@example.com") user := createAuthenticatedUser(t, authService, ctx.Suite.UserRepo, "settings_email_user", "settings_email@example.com")
getReq := httptest.NewRequest("GET", "/settings", nil) getRequest := httptest.NewRequest("GET", "/settings", nil)
getReq.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) getRequest.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
getRec := httptest.NewRecorder() getRecorder := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq) router.ServeHTTP(getRecorder, getRequest)
csrfToken := getCSRFToken(t, router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token}) csrfToken := getCSRFToken(t, router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("email", "newemail@example.com") requestBody.Set("email", "newemail@example.com")
reqBody.Set("csrf_token", csrfToken) requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/settings/email", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/settings/email", strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
}) })
t.Run("Settings_Username_Update_Form", func(t *testing.T) { t.Run("Settings_Username_Update_Form", func(t *testing.T) {
@@ -51,19 +51,19 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token}) csrfToken := getCSRFToken(t, router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("username", "new_username") requestBody.Set("username", "new_username")
reqBody.Set("csrf_token", csrfToken) requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/settings/username", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/settings/username", strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
}) })
t.Run("Settings_Password_Update_Form", func(t *testing.T) { t.Run("Settings_Password_Update_Form", func(t *testing.T) {
@@ -74,20 +74,20 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token}) csrfToken := getCSRFToken(t, freshCtx.Router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("current_password", "SecurePass123!") requestBody.Set("current_password", "SecurePass123!")
reqBody.Set("new_password", "NewSecurePass123!") requestBody.Set("new_password", "NewSecurePass123!")
reqBody.Set("csrf_token", csrfToken) requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/settings/password", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/settings/password", strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req) freshCtx.Router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
}) })
t.Run("Logout_Page_Handler", func(t *testing.T) { t.Run("Logout_Page_Handler", func(t *testing.T) {
@@ -98,19 +98,19 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token}) csrfToken := getCSRFToken(t, freshCtx.Router, "/settings", &http.Cookie{Name: "auth_token", Value: user.Token})
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("csrf_token", csrfToken) requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/logout", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/logout", strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req) freshCtx.Router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusSeeOther) assertStatus(t, recorder, http.StatusSeeOther)
assertCookieCleared(t, rec, "auth_token") assertCookieCleared(t, recorder, "auth_token")
}) })
t.Run("Resend_Verification_Page_Handler", func(t *testing.T) { t.Run("Resend_Verification_Page_Handler", func(t *testing.T) {
@@ -120,18 +120,18 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/resend-verification") csrfToken := getCSRFToken(t, freshCtx.Router, "/resend-verification")
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("email", "resend_page@example.com") requestBody.Set("email", "resend_page@example.com")
reqBody.Set("csrf_token", csrfToken) requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/resend-verification", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/resend-verification", strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req) freshCtx.Router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
}) })
t.Run("Post_Vote_Page_Handler", func(t *testing.T) { t.Run("Post_Vote_Page_Handler", func(t *testing.T) {
@@ -142,26 +142,26 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
post := testutils.CreatePostWithRepo(t, freshCtx.Suite.PostRepo, user.User.ID, "Vote Page Test", "https://example.com/vote-page") post := testutils.CreatePostWithRepo(t, freshCtx.Suite.PostRepo, user.User.ID, "Vote Page Test", "https://example.com/vote-page")
getReq := httptest.NewRequest("GET", fmt.Sprintf("/posts/%d", post.ID), nil) getRequest := httptest.NewRequest("GET", fmt.Sprintf("/posts/%d", post.ID), nil)
getRec := httptest.NewRecorder() getRecorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(getRec, getReq) freshCtx.Router.ServeHTTP(getRecorder, getRequest)
csrfToken := getCSRFToken(t, freshCtx.Router, fmt.Sprintf("/posts/%d", post.ID)) csrfToken := getCSRFToken(t, freshCtx.Router, fmt.Sprintf("/posts/%d", post.ID))
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("action", "up") requestBody.Set("action", "up")
reqBody.Set("csrf_token", csrfToken) requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", fmt.Sprintf("/posts/%d/vote", post.ID), strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", fmt.Sprintf("/posts/%d/vote", post.ID), strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req) freshCtx.Router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
}) })
t.Run("Login_Page_Handler_Workflow", func(t *testing.T) { t.Run("Login_Page_Handler_Workflow", func(t *testing.T) {
@@ -172,20 +172,20 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
csrfToken := getCSRFToken(t, freshCtx.Router, "/login") csrfToken := getCSRFToken(t, freshCtx.Router, "/login")
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("username", "login_page_user") requestBody.Set("username", "login_page_user")
reqBody.Set("password", "SecurePass123!") requestBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", csrfToken) requestBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/login", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/login", strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req) freshCtx.Router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusSeeOther) assertStatus(t, recorder, http.StatusSeeOther)
assertCookie(t, rec, "auth_token", "") assertCookie(t, recorder, "auth_token", "")
}) })
t.Run("Email_Confirmation_Page_Handler", func(t *testing.T) { t.Run("Email_Confirmation_Page_Handler", func(t *testing.T) {
@@ -198,11 +198,11 @@ func TestIntegration_PageHandlerFormWorkflows(t *testing.T) {
token = "test-token" token = "test-token"
} }
req := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(token), nil) request := httptest.NewRequest("GET", "/confirm?token="+url.QueryEscape(token), nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
}) })
} }

View File

@@ -16,63 +16,62 @@ func TestIntegration_PageHandler(t *testing.T) {
router := ctx.Router router := ctx.Router
t.Run("Home_Page_Renders", func(t *testing.T) { t.Run("Home_Page_Renders", func(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil) request := httptest.NewRequest("GET", "/", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK) assertStatus(t, recorder, http.StatusOK)
if !strings.Contains(rec.Body.String(), "<html") { if !strings.Contains(recorder.Body.String(), "<html") {
t.Error("Expected HTML content") t.Error("Expected HTML content")
} }
}) })
t.Run("Login_Form_Renders", func(t *testing.T) { t.Run("Login_Form_Renders", func(t *testing.T) {
req := httptest.NewRequest("GET", "/login", nil) request := httptest.NewRequest("GET", "/login", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK) assertStatus(t, recorder, http.StatusOK)
body := rec.Body.String() body := recorder.Body.String()
if !strings.Contains(body, "login") && !strings.Contains(body, "Login") { if !strings.Contains(body, "login") && !strings.Contains(body, "Login") {
t.Error("Expected login form content") t.Error("Expected login form content")
} }
}) })
t.Run("Register_Form_Renders", func(t *testing.T) { t.Run("Register_Form_Renders", func(t *testing.T) {
req := httptest.NewRequest("GET", "/register", nil) request := httptest.NewRequest("GET", "/register", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, recorder, http.StatusOK)
assertStatus(t, rec, http.StatusOK) body := recorder.Body.String()
body := rec.Body.String()
if !strings.Contains(body, "register") && !strings.Contains(body, "Register") { if !strings.Contains(body, "register") && !strings.Contains(body, "Register") {
t.Error("Expected register form content") t.Error("Expected register form content")
} }
}) })
t.Run("PageHandler_With_CSRF_Token", func(t *testing.T) { t.Run("PageHandler_With_CSRF_Token", func(t *testing.T) {
req := httptest.NewRequest("GET", "/register", nil) request := httptest.NewRequest("GET", "/register", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertCookie(t, rec, "csrf_token", "") assertCookie(t, recorder, "csrf_token", "")
}) })
t.Run("PageHandler_Form_Submission", func(t *testing.T) { t.Run("PageHandler_Form_Submission", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
getReq := httptest.NewRequest("GET", "/register", nil) getRequest := httptest.NewRequest("GET", "/register", nil)
getRec := httptest.NewRecorder() getRecorder := httptest.NewRecorder()
router.ServeHTTP(getRec, getReq) router.ServeHTTP(getRecorder, getRequest)
cookies := getRec.Result().Cookies() cookies := getRecorder.Result().Cookies()
var csrfCookie *http.Cookie var csrfCookie *http.Cookie
for _, cookie := range cookies { for _, cookie := range cookies {
if cookie.Name == "csrf_token" { if cookie.Name == "csrf_token" {
@@ -85,33 +84,33 @@ func TestIntegration_PageHandler(t *testing.T) {
t.Fatal("Expected CSRF cookie") t.Fatal("Expected CSRF cookie")
} }
reqBody := url.Values{} requestBody := url.Values{}
reqBody.Set("username", "page_form_user") requestBody.Set("username", "page_form_user")
reqBody.Set("email", "page_form@example.com") requestBody.Set("email", "page_form@example.com")
reqBody.Set("password", "SecurePass123!") requestBody.Set("password", "SecurePass123!")
reqBody.Set("csrf_token", csrfCookie.Value) requestBody.Set("csrf_token", csrfCookie.Value)
req := httptest.NewRequest("POST", "/register", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/register", strings.NewReader(requestBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(csrfCookie) request.AddCookie(csrfCookie)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
}) })
t.Run("PageHandler_Authenticated_Access", func(t *testing.T) { t.Run("PageHandler_Authenticated_Access", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "page_auth_user", "page_auth@example.com") user := createUserWithCleanup(t, ctx, "page_auth_user", "page_auth@example.com")
req := httptest.NewRequest("GET", "/settings", nil) request := httptest.NewRequest("GET", "/settings", nil)
req.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token}) request.AddCookie(&http.Cookie{Name: "auth_token", Value: user.Token})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK) assertStatus(t, recorder, http.StatusOK)
}) })
t.Run("PageHandler_Post_Display", func(t *testing.T) { t.Run("PageHandler_Post_Display", func(t *testing.T) {
@@ -120,34 +119,34 @@ func TestIntegration_PageHandler(t *testing.T) {
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Page Test Post", "https://example.com/page-test") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Page Test Post", "https://example.com/page-test")
req := httptest.NewRequest("GET", "/posts/"+fmt.Sprintf("%d", post.ID), nil) request := httptest.NewRequest("GET", "/posts/"+fmt.Sprintf("%d", post.ID), nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK) assertStatus(t, recorder, http.StatusOK)
body := rec.Body.String() body := recorder.Body.String()
if !strings.Contains(body, "Page Test Post") { if !strings.Contains(body, "Page Test Post") {
t.Error("Expected post title in page") t.Error("Expected post title in page")
} }
}) })
t.Run("PageHandler_Search_Page", func(t *testing.T) { t.Run("PageHandler_Search_Page", func(t *testing.T) {
req := httptest.NewRequest("GET", "/search?q=test", nil) request := httptest.NewRequest("GET", "/search?q=test", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK) assertStatus(t, recorder, http.StatusOK)
}) })
t.Run("PageHandler_Error_Handling", func(t *testing.T) { t.Run("PageHandler_Error_Handling", func(t *testing.T) {
req := httptest.NewRequest("GET", "/nonexistent", nil) request := httptest.NewRequest("GET", "/nonexistent", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusNotFound) assertStatus(t, recorder, http.StatusNotFound)
}) })
} }

View File

@@ -1,8 +1,6 @@
package integration package integration
import ( import (
"bytes"
"encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@@ -32,17 +30,12 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatalf("Failed to create user: %v", err) t.Fatalf("Failed to create user: %v", err)
} }
reqBody := map[string]string{ reqBody := map[string]any{
"username_or_email": "reset_user", "username_or_email": "reset_user",
} }
body, _ := json.Marshal(reqBody) request := makePostRequestWithJSON(t, router, "/api/auth/forgot-password", reqBody)
req := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil { if response != nil {
if success, ok := response["success"].(bool); !ok || !success { if success, ok := response["success"].(bool); !ok || !success {
t.Error("Expected success=true") t.Error("Expected success=true")
@@ -77,18 +70,13 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatal("Expected password reset token") t.Fatal("Expected password reset token")
} }
reqBody := map[string]string{ reqBody := map[string]any{
"token": resetToken, "token": resetToken,
"new_password": "NewPassword123!", "new_password": "NewPassword123!",
} }
body, _ := json.Marshal(reqBody) request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
loginResult, err := ctx.AuthService.Login("reset_complete_user", "NewPassword123!") loginResult, err := ctx.AuthService.Login("reset_complete_user", "NewPassword123!")
if err != nil { if err != nil {
@@ -120,14 +108,14 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
reqBody := url.Values{} reqBody := url.Values{}
reqBody.Set("username_or_email", "page_reset_user") reqBody.Set("username_or_email", "page_reset_user")
reqBody.Set("csrf_token", csrfToken) reqBody.Set("csrf_token", csrfToken)
req := httptest.NewRequest("POST", "/forgot-password", strings.NewReader(reqBody.Encode())) request := httptest.NewRequest("POST", "/forgot-password", strings.NewReader(reqBody.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken}) request.AddCookie(&http.Cookie{Name: "csrf_token", Value: csrfToken})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
pageRouter.ServeHTTP(rec, req) pageRouter.ServeHTTP(recorder, request)
assertStatusRange(t, rec, http.StatusOK, http.StatusSeeOther) assertStatusRange(t, recorder, http.StatusOK, http.StatusSeeOther)
resetToken := pageCtx.Suite.EmailSender.PasswordResetToken() resetToken := pageCtx.Suite.EmailSender.PasswordResetToken()
if resetToken == "" { if resetToken == "" {
@@ -166,33 +154,23 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatalf("Failed to update user: %v", err) t.Fatalf("Failed to update user: %v", err)
} }
reqBody := map[string]string{ reqBody := map[string]any{
"token": resetToken, "token": resetToken,
"new_password": "NewPassword123!", "new_password": "NewPassword123!",
} }
body, _ := json.Marshal(reqBody) request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusBadRequest)
assertErrorResponse(t, rec, http.StatusBadRequest)
}) })
t.Run("PasswordReset_InvalidToken", func(t *testing.T) { t.Run("PasswordReset_InvalidToken", func(t *testing.T) {
reqBody := map[string]string{ reqBody := map[string]any{
"token": "invalid-token", "token": "invalid-token",
"new_password": "NewPassword123!", "new_password": "NewPassword123!",
} }
body, _ := json.Marshal(reqBody) request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusBadRequest)
assertErrorResponse(t, rec, http.StatusBadRequest)
}) })
t.Run("PasswordReset_WeakPassword", func(t *testing.T) { t.Run("PasswordReset_WeakPassword", func(t *testing.T) {
@@ -214,18 +192,13 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
resetToken := ctx.Suite.EmailSender.PasswordResetToken() resetToken := ctx.Suite.EmailSender.PasswordResetToken()
reqBody := map[string]string{ reqBody := map[string]any{
"token": resetToken, "token": resetToken,
"new_password": "123", "new_password": "123",
} }
body, _ := json.Marshal(reqBody) request := makePostRequestWithJSON(t, router, "/api/auth/reset-password", reqBody)
req := httptest.NewRequest("POST", "/api/auth/reset-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertErrorResponse(t, request, http.StatusBadRequest)
assertErrorResponse(t, rec, http.StatusBadRequest)
}) })
t.Run("PasswordReset_EmailIntegration", func(t *testing.T) { t.Run("PasswordReset_EmailIntegration", func(t *testing.T) {
@@ -243,17 +216,12 @@ func TestIntegration_PasswordReset_CompleteFlow(t *testing.T) {
t.Fatalf("Failed to create user: %v", err) t.Fatalf("Failed to create user: %v", err)
} }
reqBody := map[string]string{ reqBody := map[string]any{
"username_or_email": "email_reset@example.com", "username_or_email": "email_reset@example.com",
} }
body, _ := json.Marshal(reqBody) request := makePostRequestWithJSON(t, freshCtx.Router, "/api/auth/forgot-password", reqBody)
req := httptest.NewRequest("POST", "/api/auth/forgot-password", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
freshCtx.Router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
resetToken := freshCtx.Suite.EmailSender.PasswordResetToken() resetToken := freshCtx.Suite.EmailSender.PasswordResetToken()
if resetToken == "" { if resetToken == "" {

View File

@@ -51,24 +51,24 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig) router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`)) request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
} }
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`)) request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests) assertErrorResponse(t, recorder, http.StatusTooManyRequests)
assertHeader(t, rec, "Retry-After", "") assertHeader(t, recorder, "Retry-After")
var response map[string]any var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err == nil { if err := json.NewDecoder(recorder.Body).Decode(&response); err == nil {
if _, exists := response["retry_after"]; !exists { if _, exists := response["retry_after"]; !exists {
t.Error("Expected retry_after in response") t.Error("Expected retry_after in response")
} }
@@ -81,17 +81,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig) router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 5; i++ { for i := 0; i < 5; i++ {
req := httptest.NewRequest("GET", "/api/posts", nil) request := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
} }
req := httptest.NewRequest("GET", "/api/posts", nil) request := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests) assertErrorResponse(t, recorder, http.StatusTooManyRequests)
}) })
t.Run("Health_RateLimit_Enforced", func(t *testing.T) { t.Run("Health_RateLimit_Enforced", func(t *testing.T) {
@@ -100,17 +100,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig) router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
req := httptest.NewRequest("GET", "/health", nil) request := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
} }
req := httptest.NewRequest("GET", "/health", nil) request := httptest.NewRequest("GET", "/health", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests) assertErrorResponse(t, recorder, http.StatusTooManyRequests)
}) })
t.Run("Metrics_RateLimit_Enforced", func(t *testing.T) { t.Run("Metrics_RateLimit_Enforced", func(t *testing.T) {
@@ -119,17 +119,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig) router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
req := httptest.NewRequest("GET", "/metrics", nil) request := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
} }
req := httptest.NewRequest("GET", "/metrics", nil) request := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests) assertErrorResponse(t, recorder, http.StatusTooManyRequests)
}) })
t.Run("RateLimit_Different_Endpoints_Independent", func(t *testing.T) { t.Run("RateLimit_Different_Endpoints_Independent", func(t *testing.T) {
@@ -139,17 +139,17 @@ func TestIntegration_RateLimiting(t *testing.T) {
router, _ := setupRateLimitRouter(t, rateLimitConfig) router, _ := setupRateLimitRouter(t, rateLimitConfig)
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`)) request := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
} }
req := httptest.NewRequest("GET", "/api/posts", nil) request := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK) assertStatus(t, recorder, http.StatusOK)
}) })
t.Run("RateLimit_With_Authentication", func(t *testing.T) { t.Run("RateLimit_With_Authentication", func(t *testing.T) {
@@ -166,20 +166,20 @@ func TestIntegration_RateLimiting(t *testing.T) {
user := createAuthenticatedUser(t, authService, suite.UserRepo, uniqueTestUsername(t, "ratelimit_auth"), uniqueTestEmail(t, "ratelimit_auth")) user := createAuthenticatedUser(t, authService, suite.UserRepo, uniqueTestUsername(t, "ratelimit_auth"), uniqueTestEmail(t, "ratelimit_auth"))
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
req := httptest.NewRequest("GET", "/api/auth/me", nil) request := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
} }
req := httptest.NewRequest("GET", "/api/auth/me", nil) request := httptest.NewRequest("GET", "/api/auth/me", nil)
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertErrorResponse(t, rec, http.StatusTooManyRequests) assertErrorResponse(t, recorder, http.StatusTooManyRequests)
}) })
} }

View File

@@ -17,35 +17,29 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
router := ctx.Router router := ctx.Router
t.Run("SecurityHeaders_Present", func(t *testing.T) { t.Run("SecurityHeaders_Present", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil) request := makeGetRequest(t, router, "/health")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK) assertHeader(t, request, "X-Content-Type-Options")
assertHeader(t, request, "X-Frame-Options")
assertHeader(t, rec, "X-Content-Type-Options", "") assertHeader(t, request, "X-XSS-Protection")
assertHeader(t, rec, "X-Frame-Options", "")
assertHeader(t, rec, "X-XSS-Protection", "")
}) })
t.Run("CORS_Headers_Present", func(t *testing.T) { t.Run("CORS_Headers_Present", func(t *testing.T) {
req := httptest.NewRequest("OPTIONS", "/api/posts", nil) request := httptest.NewRequest("OPTIONS", "/api/posts", nil)
req.Header.Set("Origin", "http://localhost:3000") request.Header.Set("Origin", "http://localhost:3000")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertHeader(t, rec, "Access-Control-Allow-Origin", "") assertHeader(t, recorder, "Access-Control-Allow-Origin")
}) })
t.Run("Logging_Middleware_Executes", func(t *testing.T) { t.Run("Logging_Middleware_Executes", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil) request := makeGetRequest(t, router, "/health")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) if request.Code == 0 {
if rec.Code == 0 {
t.Error("Expected logging middleware to execute") t.Error("Expected logging middleware to execute")
} }
}) })
@@ -53,27 +47,24 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
t.Run("RequestSizeLimit_Enforced", func(t *testing.T) { t.Run("RequestSizeLimit_Enforced", func(t *testing.T) {
user := createUserWithCleanup(t, ctx, "size_limit_user", "size_limit@example.com") user := createUserWithCleanup(t, ctx, "size_limit_user", "size_limit@example.com")
largeBody := strings.Repeat("a", 10*1024*1024) largeBody := strings.Repeat("a", 10*1024*1024)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(largeBody)) request := httptest.NewRequest("POST", "/api/posts", bytes.NewBufferString(largeBody))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Code != http.StatusRequestEntityTooLarge && rec.Code != http.StatusBadRequest { if recorder.Code != http.StatusRequestEntityTooLarge && recorder.Code != http.StatusBadRequest {
t.Errorf("Expected status 413 or 400 for oversized request, got %d. Body: %s", rec.Code, rec.Body.String()) t.Errorf("Expected status 413 or 400 for oversized request, got %d. Body: %s", recorder.Code, recorder.Body.String())
} }
}) })
t.Run("DBMonitoring_Active", func(t *testing.T) { t.Run("DBMonitoring_Active", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil) request := makeGetRequest(t, router, "/health")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
var response map[string]any var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err == nil { if err := json.NewDecoder(request.Body).Decode(&response); err == nil {
if data, ok := response["data"].(map[string]any); ok { if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["database_stats"]; !exists { if _, exists := data["database_stats"]; !exists {
t.Error("Expected database_stats in health response") t.Error("Expected database_stats in health response")
@@ -83,12 +74,9 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
}) })
t.Run("Metrics_Middleware_Executes", func(t *testing.T) { t.Run("Metrics_Middleware_Executes", func(t *testing.T) {
req := httptest.NewRequest("GET", "/metrics", nil) request := makeGetRequest(t, router, "/metrics")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil { if response != nil {
if data, ok := response["data"].(map[string]any); ok { if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["database"]; !exists { if _, exists := data["database"]; !exists {
@@ -99,34 +87,25 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
}) })
t.Run("StaticFiles_Served", func(t *testing.T) { t.Run("StaticFiles_Served", func(t *testing.T) {
req := httptest.NewRequest("GET", "/robots.txt", nil) request := makeGetRequest(t, router, "/robots.txt")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK) if !strings.Contains(request.Body.String(), "User-agent") {
if !strings.Contains(rec.Body.String(), "User-agent") {
t.Error("Expected robots.txt content") t.Error("Expected robots.txt content")
} }
}) })
t.Run("API_Routes_Accessible", func(t *testing.T) { t.Run("API_Routes_Accessible", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil) request := makeGetRequest(t, router, "/api/posts")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("Health_Endpoint_Accessible", func(t *testing.T) { t.Run("Health_Endpoint_Accessible", func(t *testing.T) {
req := httptest.NewRequest("GET", "/health", nil) request := makeGetRequest(t, router, "/health")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) response := assertJSONResponse(t, request, http.StatusOK)
response := assertJSONResponse(t, rec, http.StatusOK)
if response != nil { if response != nil {
if success, ok := response["success"].(bool); !ok || !success { if success, ok := response["success"].(bool); !ok || !success {
t.Error("Expected success=true in health response") t.Error("Expected success=true in health response")
@@ -135,40 +114,33 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
}) })
t.Run("Middleware_Order_Correct", func(t *testing.T) { t.Run("Middleware_Order_Correct", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil) request := makeGetRequest(t, router, "/api/posts")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertHeader(t, request, "X-Content-Type-Options")
assertHeader(t, rec, "X-Content-Type-Options", "") if request.Code == 0 {
if rec.Code == 0 {
t.Error("Response should have status code") t.Error("Response should have status code")
} }
}) })
t.Run("Compression_Middleware_Active", func(t *testing.T) { t.Run("Compression_Middleware_Active", func(t *testing.T) {
req := httptest.NewRequest("GET", "/api/posts", nil) request := httptest.NewRequest("GET", "/api/posts", nil)
req.Header.Set("Accept-Encoding", "gzip") request.Header.Set("Accept-Encoding", "gzip")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Header().Get("Content-Encoding") == "" { if recorder.Header().Get("Content-Encoding") == "" {
t.Log("Compression may not be applied to small responses") t.Log("Compression may not be applied to small responses")
} }
}) })
t.Run("Cache_Middleware_Active", func(t *testing.T) { t.Run("Cache_Middleware_Active", func(t *testing.T) {
req1 := httptest.NewRequest("GET", "/api/posts", nil) firstRequest := makeGetRequest(t, router, "/api/posts")
rec1 := httptest.NewRecorder()
router.ServeHTTP(rec1, req1)
req2 := httptest.NewRequest("GET", "/api/posts", nil) secondRequest := makeGetRequest(t, router, "/api/posts")
rec2 := httptest.NewRecorder()
router.ServeHTTP(rec2, req2)
if rec1.Code != rec2.Code { if firstRequest.Code != secondRequest.Code {
t.Error("Cached responses should have same status") t.Error("Cached responses should have same status")
} }
}) })
@@ -177,35 +149,23 @@ func TestIntegration_Router_FullMiddlewareChain(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createUserWithCleanup(t, ctx, "auth_middleware_user", "auth_middleware@example.com") user := createUserWithCleanup(t, ctx, "auth_middleware_user", "auth_middleware@example.com")
req := httptest.NewRequest("GET", "/api/auth/me", nil) request := makeAuthenticatedGetRequest(t, router, "/api/auth/me", user, nil)
req.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req) assertStatus(t, request, http.StatusOK)
assertStatus(t, rec, http.StatusOK)
}) })
t.Run("RateLimit_Middleware_Integration", func(t *testing.T) { t.Run("RateLimit_Middleware_Integration", func(t *testing.T) {
rateLimitCtx := setupTestContext(t) rateLimitCtx := setupTestContext(t)
rateLimitRouter := rateLimitCtx.Router rateLimitRouter := rateLimitCtx.Router
for i := 0; i < 3; i++ { for range 3 {
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`)) request := makePostRequestWithJSON(t, rateLimitRouter, "/api/auth/login", map[string]any{"username": "test", "password": "test"})
req.Header.Set("Content-Type", "application/json") _ = request
rec := httptest.NewRecorder()
rateLimitRouter.ServeHTTP(rec, req)
} }
req := httptest.NewRequest("POST", "/api/auth/login", bytes.NewBufferString(`{"username":"test","password":"test"}`)) request := makePostRequestWithJSON(t, rateLimitRouter, "/api/auth/login", map[string]any{"username": "test", "password": "test"})
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
rateLimitRouter.ServeHTTP(rec, req) if request.Code == http.StatusTooManyRequests {
if rec.Code == http.StatusTooManyRequests {
t.Log("Rate limiting is working") t.Log("Rate limiting is working")
} }
}) })

View File

@@ -21,60 +21,60 @@ func TestIntegration_SessionManagement(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "session_pass_user", "session_pass@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "session_pass_user", "session_pass@example.com")
req1 := httptest.NewRequest("GET", "/api/auth/me", nil) firstRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
req1.Header.Set("Authorization", "Bearer "+user.Token) firstRequest.Header.Set("Authorization", "Bearer "+user.Token)
req1 = testutils.WithUserContext(req1, middleware.UserIDKey, user.User.ID) firstRequest = testutils.WithUserContext(firstRequest, middleware.UserIDKey, user.User.ID)
rec1 := httptest.NewRecorder() firstRecorder := httptest.NewRecorder()
router.ServeHTTP(rec1, req1) router.ServeHTTP(firstRecorder, firstRequest)
assertStatus(t, rec1, http.StatusOK) assertStatus(t, firstRecorder, http.StatusOK)
reqBody := map[string]string{ requestBody := map[string]string{
"current_password": "SecurePass123!", "current_password": "SecurePass123!",
"new_password": "NewSecurePass123!", "new_password": "NewSecurePass123!",
} }
body, _ := json.Marshal(reqBody) body, _ := json.Marshal(requestBody)
req2 := httptest.NewRequest("PUT", "/api/auth/password", bytes.NewBuffer(body)) secondRequest := httptest.NewRequest("PUT", "/api/auth/password", bytes.NewBuffer(body))
req2.Header.Set("Content-Type", "application/json") secondRequest.Header.Set("Content-Type", "application/json")
req2.Header.Set("Authorization", "Bearer "+user.Token) secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID) secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder() secondRecorder := httptest.NewRecorder()
router.ServeHTTP(rec2, req2) router.ServeHTTP(secondRecorder, secondRequest)
assertStatus(t, rec2, http.StatusOK) assertStatus(t, secondRecorder, http.StatusOK)
req3 := httptest.NewRequest("GET", "/api/auth/me", nil) thirdRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
req3.Header.Set("Authorization", "Bearer "+user.Token) thirdRequest.Header.Set("Authorization", "Bearer "+user.Token)
req3 = testutils.WithUserContext(req3, middleware.UserIDKey, user.User.ID) thirdRequest = testutils.WithUserContext(thirdRequest, middleware.UserIDKey, user.User.ID)
rec3 := httptest.NewRecorder() thirdRecorder := httptest.NewRecorder()
router.ServeHTTP(rec3, req3) router.ServeHTTP(thirdRecorder, thirdRequest)
assertErrorResponse(t, rec3, http.StatusUnauthorized) assertErrorResponse(t, thirdRecorder, http.StatusUnauthorized)
}) })
t.Run("Session_Invalidation_On_Account_Lock", func(t *testing.T) { t.Run("Session_Invalidation_On_Account_Lock", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "session_lock_user", "session_lock@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "session_lock_user", "session_lock@example.com")
req1 := httptest.NewRequest("GET", "/api/auth/me", nil) firstRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
req1.Header.Set("Authorization", "Bearer "+user.Token) firstRequest.Header.Set("Authorization", "Bearer "+user.Token)
req1 = testutils.WithUserContext(req1, middleware.UserIDKey, user.User.ID) firstRequest = testutils.WithUserContext(firstRequest, middleware.UserIDKey, user.User.ID)
rec1 := httptest.NewRecorder() firstRecorder := httptest.NewRecorder()
router.ServeHTTP(rec1, req1) router.ServeHTTP(firstRecorder, firstRequest)
assertStatus(t, rec1, http.StatusOK) assertStatus(t, firstRecorder, http.StatusOK)
if err := ctx.Suite.UserRepo.Lock(user.User.ID); err != nil { if err := ctx.Suite.UserRepo.Lock(user.User.ID); err != nil {
t.Fatalf("Failed to lock user: %v", err) t.Fatalf("Failed to lock user: %v", err)
} }
req2 := httptest.NewRequest("GET", "/api/auth/me", nil) secondRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
req2.Header.Set("Authorization", "Bearer "+user.Token) secondRequest.Header.Set("Authorization", "Bearer "+user.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user.User.ID) secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, user.User.ID)
rec2 := httptest.NewRecorder() secondRecorder := httptest.NewRecorder()
router.ServeHTTP(rec2, req2) router.ServeHTTP(secondRecorder, secondRequest)
assertErrorResponse(t, rec2, http.StatusUnauthorized) assertErrorResponse(t, secondRecorder, http.StatusUnauthorized)
}) })
t.Run("Refresh_Token_Revocation", func(t *testing.T) { t.Run("Refresh_Token_Revocation", func(t *testing.T) {
@@ -90,48 +90,48 @@ func TestIntegration_SessionManagement(t *testing.T) {
t.Fatal("Expected refresh token") t.Fatal("Expected refresh token")
} }
reqBody := map[string]string{ requestBody := map[string]string{
"refresh_token": loginResult.RefreshToken, "refresh_token": loginResult.RefreshToken,
} }
body, _ := json.Marshal(reqBody) body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body)) request := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
assertStatus(t, rec, http.StatusOK) assertStatus(t, recorder, http.StatusOK)
if err := ctx.AuthService.RevokeRefreshToken(loginResult.RefreshToken); err != nil { if err := ctx.AuthService.RevokeRefreshToken(loginResult.RefreshToken); err != nil {
t.Fatalf("Failed to revoke token: %v", err) t.Fatalf("Failed to revoke token: %v", err)
} }
req2 := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body)) secondRequest := httptest.NewRequest("POST", "/api/auth/refresh", bytes.NewBuffer(body))
req2.Header.Set("Content-Type", "application/json") secondRequest.Header.Set("Content-Type", "application/json")
rec2 := httptest.NewRecorder() secondRecorder := httptest.NewRecorder()
router.ServeHTTP(rec2, req2) router.ServeHTTP(secondRecorder, secondRequest)
assertErrorResponse(t, rec2, http.StatusUnauthorized) assertErrorResponse(t, secondRecorder, http.StatusUnauthorized)
}) })
t.Run("Multiple_Sessions_Independent", func(t *testing.T) { t.Run("Multiple_Sessions_Independent", func(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
user1 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user1", "multi_session1@example.com") firstUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user1", "multi_session1@example.com")
user2 := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user2", "multi_session2@example.com") secondUser := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "multi_session_user2", "multi_session2@example.com")
req1 := httptest.NewRequest("GET", "/api/auth/me", nil) firstRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
req1.Header.Set("Authorization", "Bearer "+user1.Token) firstRequest.Header.Set("Authorization", "Bearer "+firstUser.Token)
req1 = testutils.WithUserContext(req1, middleware.UserIDKey, user1.User.ID) firstRequest = testutils.WithUserContext(firstRequest, middleware.UserIDKey, firstUser.User.ID)
rec1 := httptest.NewRecorder() firstRecorder := httptest.NewRecorder()
router.ServeHTTP(rec1, req1) router.ServeHTTP(firstRecorder, firstRequest)
req2 := httptest.NewRequest("GET", "/api/auth/me", nil) secondRequest := httptest.NewRequest("GET", "/api/auth/me", nil)
req2.Header.Set("Authorization", "Bearer "+user2.Token) secondRequest.Header.Set("Authorization", "Bearer "+secondUser.Token)
req2 = testutils.WithUserContext(req2, middleware.UserIDKey, user2.User.ID) secondRequest = testutils.WithUserContext(secondRequest, middleware.UserIDKey, secondUser.User.ID)
rec2 := httptest.NewRecorder() secondRecorder := httptest.NewRecorder()
router.ServeHTTP(rec2, req2) router.ServeHTTP(secondRecorder, secondRequest)
assertStatus(t, rec1, http.StatusOK) assertStatus(t, firstRecorder, http.StatusOK)
assertStatus(t, rec2, http.StatusOK) assertStatus(t, secondRecorder, http.StatusOK)
}) })
} }
@@ -144,17 +144,17 @@ func TestIntegration_AccountDeletion(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "del_flow_user", "del_flow@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "del_flow_user", "del_flow@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Test Post", "https://example.com") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Test Post", "https://example.com")
reqBody := map[string]string{} requestBody := map[string]string{}
body, _ := json.Marshal(reqBody) body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body)) request := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
response := assertJSONResponse(t, rec, http.StatusOK) response := assertJSONResponse(t, recorder, http.StatusOK)
if response == nil { if response == nil {
return return
} }
@@ -171,13 +171,13 @@ func TestIntegration_AccountDeletion(t *testing.T) {
"token": deletionToken, "token": deletionToken,
} }
confirmBodyBytes, _ := json.Marshal(confirmBody) confirmBodyBytes, _ := json.Marshal(confirmBody)
confirmReq := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes)) confirmRequest := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes))
confirmReq.Header.Set("Content-Type", "application/json") confirmRequest.Header.Set("Content-Type", "application/json")
confirmRec := httptest.NewRecorder() confirmRecorder := httptest.NewRecorder()
router.ServeHTTP(confirmRec, confirmReq) router.ServeHTTP(confirmRecorder, confirmRequest)
confirmResponse := assertJSONResponse(t, confirmRec, http.StatusOK) confirmResponse := assertJSONResponse(t, confirmRecorder, http.StatusOK)
if confirmResponse == nil { if confirmResponse == nil {
return return
} }
@@ -209,17 +209,17 @@ func TestIntegration_AccountDeletion(t *testing.T) {
user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "del_posts_user", "del_posts@example.com") user := createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "del_posts_user", "del_posts@example.com")
post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Deletion Post", "https://example.com/deletion") post := testutils.CreatePostWithRepo(t, ctx.Suite.PostRepo, user.User.ID, "Deletion Post", "https://example.com/deletion")
reqBody := map[string]string{} requestBody := map[string]string{}
body, _ := json.Marshal(reqBody) body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body)) request := httptest.NewRequest("DELETE", "/api/auth/account", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
response := assertJSONResponse(t, rec, http.StatusOK) response := assertJSONResponse(t, recorder, http.StatusOK)
if response == nil { if response == nil {
return return
} }
@@ -237,13 +237,13 @@ func TestIntegration_AccountDeletion(t *testing.T) {
"delete_posts": true, "delete_posts": true,
} }
confirmBodyBytes, _ := json.Marshal(confirmBody) confirmBodyBytes, _ := json.Marshal(confirmBody)
confirmReq := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes)) confirmRequest := httptest.NewRequest("POST", "/api/auth/account/confirm", bytes.NewBuffer(confirmBodyBytes))
confirmReq.Header.Set("Content-Type", "application/json") confirmRequest.Header.Set("Content-Type", "application/json")
confirmRec := httptest.NewRecorder() confirmRecorder := httptest.NewRecorder()
router.ServeHTTP(confirmRec, confirmReq) router.ServeHTTP(confirmRecorder, confirmRequest)
confirmResponse := assertJSONResponse(t, confirmRec, http.StatusOK) confirmResponse := assertJSONResponse(t, confirmRecorder, http.StatusOK)
if confirmResponse == nil { if confirmResponse == nil {
return return
} }
@@ -275,12 +275,12 @@ func TestIntegration_MetricsCollection(t *testing.T) {
router := ctx.Router router := ctx.Router
t.Run("Metrics_Endpoint_Returns_Data", func(t *testing.T) { t.Run("Metrics_Endpoint_Returns_Data", func(t *testing.T) {
req := httptest.NewRequest("GET", "/metrics", nil) request := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
response := assertJSONResponse(t, rec, http.StatusOK) response := assertJSONResponse(t, recorder, http.StatusOK)
if response != nil { if response != nil {
if data, ok := response["data"].(map[string]any); ok { if data, ok := response["data"].(map[string]any); ok {
if _, exists := data["database"]; !exists { if _, exists := data["database"]; !exists {
@@ -294,13 +294,13 @@ func TestIntegration_MetricsCollection(t *testing.T) {
ctx.Suite.EmailSender.Reset() ctx.Suite.EmailSender.Reset()
createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "metrics_user", "metrics@example.com") createAuthenticatedUser(t, ctx.AuthService, ctx.Suite.UserRepo, "metrics_user", "metrics@example.com")
req := httptest.NewRequest("GET", "/metrics", nil) request := httptest.NewRequest("GET", "/metrics", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
var response map[string]any var response map[string]any
if err := json.NewDecoder(rec.Body).Decode(&response); err == nil { if err := json.NewDecoder(recorder.Body).Decode(&response); err == nil {
if data, ok := response["data"].(map[string]any); ok { if data, ok := response["data"].(map[string]any); ok {
if dbData, exists := data["database"].(map[string]any); exists { if dbData, exists := data["database"].(map[string]any); exists {
if _, hasQueries := dbData["total_queries"]; !hasQueries { if _, hasQueries := dbData["total_queries"]; !hasQueries {
@@ -323,7 +323,7 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
errors := make(chan error, 10) errors := make(chan error, 10)
for i := 0; i < 10; i++ { for idx := range 10 {
wg.Add(1) wg.Add(1)
go func(index int) { go func(index int) {
defer wg.Done() defer wg.Done()
@@ -334,18 +334,18 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
"content": "Concurrent test content", "content": "Concurrent test content",
} }
body, _ := json.Marshal(postBody) body, _ := json.Marshal(postBody)
req := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body)) request := httptest.NewRequest("POST", "/api/posts", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Code != http.StatusCreated { if recorder.Code != http.StatusCreated {
errors <- fmt.Errorf("Post %d failed with status %d", index, rec.Code) errors <- fmt.Errorf("Post %d failed with status %d", index, recorder.Code)
} }
}(i) }(idx)
} }
wg.Wait() wg.Wait()
@@ -370,28 +370,26 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
errors := make(chan error, 5) errors := make(chan error, 5)
for i := 0; i < 5; i++ { for range 5 {
wg.Add(1) wg.Go(func() {
go func() {
defer wg.Done()
voteBody := map[string]string{ voteBody := map[string]string{
"type": "up", "type": "up",
} }
body, _ := json.Marshal(voteBody) body, _ := json.Marshal(voteBody)
req := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body)) request := httptest.NewRequest("POST", fmt.Sprintf("/api/posts/%d/vote", post.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json") request.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+user.Token) request.Header.Set("Authorization", "Bearer "+user.Token)
req = testutils.WithUserContext(req, middleware.UserIDKey, user.User.ID) request = testutils.WithUserContext(request, middleware.UserIDKey, user.User.ID)
req = testutils.WithURLParams(req, map[string]string{"id": fmt.Sprintf("%d", post.ID)}) request = testutils.WithURLParams(request, map[string]string{"id": fmt.Sprintf("%d", post.ID)})
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Code != http.StatusOK { if recorder.Code != http.StatusOK {
errors <- fmt.Errorf("Vote failed with status %d", rec.Code) errors <- fmt.Errorf("Vote failed with status %d", recorder.Code)
} }
}() })
} }
wg.Wait() wg.Wait()
@@ -411,20 +409,18 @@ func TestIntegration_ConcurrentRequests(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
errors := make(chan error, 20) errors := make(chan error, 20)
for i := 0; i < 20; i++ { for range 20 {
wg.Add(1) wg.Go(func() {
go func() {
defer wg.Done()
req := httptest.NewRequest("GET", "/api/posts", nil) request := httptest.NewRequest("GET", "/api/posts", nil)
rec := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(rec, req) router.ServeHTTP(recorder, request)
if rec.Code != http.StatusOK { if recorder.Code != http.StatusOK {
errors <- fmt.Errorf("Read failed with status %d", rec.Code) errors <- fmt.Errorf("Read failed with status %d", recorder.Code)
} }
}() })
} }
wg.Wait() wg.Wait()

View File

@@ -38,6 +38,10 @@ func CompressionMiddlewareWithConfig(config *CompressionConfig) func(http.Handle
next.ServeHTTP(bufferedWriter, r) next.ServeHTTP(bufferedWriter, r)
if bufferedWriter.isRedirect {
return
}
if buf.Len() < config.MinSize { if buf.Len() < config.MinSize {
bufferedWriter.flush() bufferedWriter.flush()
w.Write(buf.Bytes()) w.Write(buf.Bytes())
@@ -73,9 +77,13 @@ type bufferedResponseWriter struct {
buffer *bytes.Buffer buffer *bytes.Buffer
statusCode int statusCode int
headerWritten bool headerWritten bool
isRedirect bool
} }
func (brw *bufferedResponseWriter) Write(b []byte) (int, error) { func (brw *bufferedResponseWriter) Write(b []byte) (int, error) {
if brw.isRedirect {
return brw.ResponseWriter.Write(b)
}
if !brw.headerWritten { if !brw.headerWritten {
brw.statusCode = http.StatusOK brw.statusCode = http.StatusOK
} }
@@ -87,6 +95,11 @@ func (brw *bufferedResponseWriter) WriteHeader(code int) {
return return
} }
brw.statusCode = code brw.statusCode = code
if isRedirect(code) {
brw.isRedirect = true
brw.ResponseWriter.WriteHeader(code)
brw.headerWritten = true
}
} }
func (brw *bufferedResponseWriter) Header() http.Header { func (brw *bufferedResponseWriter) Header() http.Header {
@@ -100,6 +113,10 @@ func (brw *bufferedResponseWriter) flush() {
} }
} }
func isRedirect(statusCode int) bool {
return statusCode >= 300 && statusCode < 400
}
func shouldCompress(r *http.Request, config *CompressionConfig) bool { func shouldCompress(r *http.Request, config *CompressionConfig) bool {
return r.Header.Get("Content-Encoding") == "" return r.Header.Get("Content-Encoding") == ""
} }

View File

@@ -27,7 +27,14 @@ func ValidationMiddleware() func(http.Handler) http.Handler {
dto := reflect.New(dtoType).Interface() dto := reflect.New(dtoType).Interface()
if err := json.NewDecoder(r.Body).Decode(dto); err != nil { if err := json.NewDecoder(r.Body).Decode(dto); err != nil {
http.Error(w, "Invalid JSON", http.StatusBadRequest) response := map[string]any{
"success": false,
"error": "Invalid JSON",
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(response)
return return
} }
@@ -77,3 +84,7 @@ func GetDTOTypeFromContext(ctx context.Context) reflect.Type {
func GetValidatedDTOFromContext(ctx context.Context) any { func GetValidatedDTOFromContext(ctx context.Context) any {
return ctx.Value(validatedDTOKey) return ctx.Value(validatedDTOKey)
} }
func SetValidatedDTOInContext(ctx context.Context, dto any) context.Context {
return context.WithValue(ctx, validatedDTOKey, dto)
}

View File

@@ -28,6 +28,7 @@ type UserRepository interface {
Unlock(id uint) error Unlock(id uint) error
GetPosts(userID uint, limit, offset int) ([]database.Post, error) GetPosts(userID uint, limit, offset int) ([]database.Post, error)
GetDeletedUsers() ([]database.User, error) GetDeletedUsers() ([]database.User, error)
GetByUsernamePrefix(prefix string) (*database.User, error)
HardDeleteAll() (int64, error) HardDeleteAll() (int64, error)
Count() (int64, error) Count() (int64, error)
WithTx(tx *gorm.DB) UserRepository WithTx(tx *gorm.DB) UserRepository
@@ -240,6 +241,17 @@ func (r *userRepository) GetDeletedUsers() ([]database.User, error) {
return users, err return users, err
} }
func (r *userRepository) GetByUsernamePrefix(prefix string) (*database.User, error) {
var user database.User
err := r.db.
Where("username LIKE ? AND email LIKE ?", prefix+"%", prefix+"%@goyco.local").
First(&user).Error
if err != nil {
return nil, err
}
return &user, nil
}
func (r *userRepository) HardDeleteAll() (int64, error) { func (r *userRepository) HardDeleteAll() (int64, error) {
var totalDeleted int64 var totalDeleted int64
err := r.db.Transaction(func(tx *gorm.DB) error { err := r.db.Transaction(func(tx *gorm.DB) error {

View File

@@ -20,6 +20,7 @@ type VoteRepository interface {
Count() (int64, error) Count() (int64, error)
CountByPostID(postID uint) (int64, error) CountByPostID(postID uint) (int64, error)
CountByUserID(userID uint) (int64, error) CountByUserID(userID uint) (int64, error)
GetVoteCountsByPostID(postID uint) (upVotes int, downVotes int, err error)
WithTx(tx *gorm.DB) VoteRepository WithTx(tx *gorm.DB) VoteRepository
} }
@@ -144,3 +145,20 @@ func (r *voteRepository) Count() (int64, error) {
err := r.db.Model(&database.Vote{}).Count(&count).Error err := r.db.Model(&database.Vote{}).Count(&count).Error
return count, err return count, err
} }
func (r *voteRepository) GetVoteCountsByPostID(postID uint) (int, int, error) {
var result struct {
UpVotes int64
DownVotes int64
}
err := r.db.Model(&database.Vote{}).
Select("COUNT(CASE WHEN type = ? THEN 1 END) as up_votes, COUNT(CASE WHEN type = ? THEN 1 END) as down_votes", database.VoteUp, database.VoteDown).
Where("post_id = ?", postID).
Scan(&result).Error
if err != nil {
return 0, 0, err
}
return int(result.UpVotes), int(result.DownVotes), nil
}

View File

@@ -1,8 +1,10 @@
package server package server
import ( import (
"mime"
"net/http" "net/http"
"path/filepath" "path/filepath"
"strings"
"time" "time"
"goyco/internal/config" "goyco/internal/config"
@@ -73,6 +75,7 @@ func NewRouter(cfg RouterConfig) http.Handler {
}, },
CSRFMiddleware: middleware.CSRFMiddleware(), CSRFMiddleware: middleware.CSRFMiddleware(),
AuthMiddleware: middleware.NewAuth(cfg.AuthService), AuthMiddleware: middleware.NewAuth(cfg.AuthService),
ValidationMiddleware: middleware.ValidationMiddleware(),
} }
if cfg.PageHandler != nil { if cfg.PageHandler != nil {
@@ -123,7 +126,33 @@ func NewRouter(cfg RouterConfig) http.Handler {
staticDir = "./internal/static/" staticDir = "./internal/static/"
} }
router.Handle("/static/*", http.StripPrefix("/static/", http.FileServer(http.Dir(staticDir)))) staticFileServer := http.StripPrefix("/static/", http.FileServer(http.Dir(staticDir)))
router.Handle("/static/*", staticFileHandler(staticFileServer))
return router return router
} }
func staticFileHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
ext := filepath.Ext(path)
if ext == ".css" {
w.Header().Set("Content-Type", "text/css; charset=utf-8")
} else if ext == ".js" {
w.Header().Set("Content-Type", "application/javascript; charset=utf-8")
} else if ext == ".json" {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
} else if ext == ".ico" {
w.Header().Set("Content-Type", "image/x-icon")
} else if strings.HasPrefix(mime.TypeByExtension(ext), "image/") {
w.Header().Set("Content-Type", mime.TypeByExtension(ext))
} else if strings.HasPrefix(mime.TypeByExtension(ext), "font/") {
w.Header().Set("Content-Type", mime.TypeByExtension(ext))
} else if mimeType := mime.TypeByExtension(ext); mimeType != "" {
w.Header().Set("Content-Type", mimeType)
}
next.ServeHTTP(w, r)
})
}

View File

@@ -3,6 +3,7 @@ package server
import ( import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"goyco/internal/config" "goyco/internal/config"
@@ -105,9 +106,9 @@ func defaultRateLimitConfig() config.RateLimitConfig {
return testutils.AppTestConfig.RateLimit return testutils.AppTestConfig.RateLimit
} }
func TestAPIRootRouting(t *testing.T) { func createDefaultRouterConfig() RouterConfig {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers()
router := NewRouter(RouterConfig{ return RouterConfig{
APIHandler: apiHandler, APIHandler: apiHandler,
AuthHandler: authHandler, AuthHandler: authHandler,
PostHandler: postHandler, PostHandler: postHandler,
@@ -115,7 +116,15 @@ func TestAPIRootRouting(t *testing.T) {
UserHandler: userHandler, UserHandler: userHandler,
AuthService: authService, AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(), RateLimitConfig: defaultRateLimitConfig(),
}) }
}
func createTestRouter(cfg RouterConfig) http.Handler {
return NewRouter(cfg)
}
func TestAPIRootRouting(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct { testCases := []struct {
name string name string
@@ -141,23 +150,23 @@ func TestAPIRootRouting(t *testing.T) {
} }
func TestProtectedRoutesRequireAuth(t *testing.T) { func TestProtectedRoutesRequireAuth(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() router := createTestRouter(createDefaultRouterConfig())
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
protectedRoutes := []struct { protectedRoutes := []struct {
method string method string
path string path string
}{ }{
{http.MethodGet, "/api/auth/me"}, {http.MethodGet, "/api/auth/me"},
{http.MethodPost, "/api/auth/logout"},
{http.MethodPost, "/api/auth/revoke"},
{http.MethodPost, "/api/auth/revoke-all"},
{http.MethodPut, "/api/auth/email"},
{http.MethodPut, "/api/auth/username"},
{http.MethodPut, "/api/auth/password"},
{http.MethodDelete, "/api/auth/account"},
{http.MethodPost, "/api/posts"}, {http.MethodPost, "/api/posts"},
{http.MethodPut, "/api/posts/1"},
{http.MethodDelete, "/api/posts/1"},
{http.MethodPost, "/api/posts/1/vote"}, {http.MethodPost, "/api/posts/1/vote"},
{http.MethodDelete, "/api/posts/1/vote"}, {http.MethodDelete, "/api/posts/1/vote"},
{http.MethodGet, "/api/posts/1/vote"}, {http.MethodGet, "/api/posts/1/vote"},
@@ -183,17 +192,9 @@ func TestProtectedRoutesRequireAuth(t *testing.T) {
} }
func TestRouterWithDebugMode(t *testing.T) { func TestRouterWithDebugMode(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
router := NewRouter(RouterConfig{ cfg.Debug = true
Debug: true, router := createTestRouter(cfg)
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/api", nil) request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -206,16 +207,9 @@ func TestRouterWithDebugMode(t *testing.T) {
} }
func TestRouterWithCacheDisabled(t *testing.T) { func TestRouterWithCacheDisabled(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
router := NewRouter(RouterConfig{ cfg.DisableCache = true
DisableCache: true, router := createTestRouter(cfg)
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
})
request := httptest.NewRequest(http.MethodGet, "/api", nil) request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -228,17 +222,9 @@ func TestRouterWithCacheDisabled(t *testing.T) {
} }
func TestRouterWithCompressionDisabled(t *testing.T) { func TestRouterWithCompressionDisabled(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
router := NewRouter(RouterConfig{ cfg.DisableCompression = true
DisableCompression: true, router := createTestRouter(cfg)
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/api", nil) request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -251,19 +237,9 @@ func TestRouterWithCompressionDisabled(t *testing.T) {
} }
func TestRouterWithCustomDBMonitor(t *testing.T) { func TestRouterWithCustomDBMonitor(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
customDBMonitor := middleware.NewInMemoryDBMonitor() cfg.DBMonitor = middleware.NewInMemoryDBMonitor()
router := createTestRouter(cfg)
router := NewRouter(RouterConfig{
DBMonitor: customDBMonitor,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/api", nil) request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -296,18 +272,9 @@ func TestRouterWithPageHandler(t *testing.T) {
} }
func TestRouterWithStaticDir(t *testing.T) { func TestRouterWithStaticDir(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
cfg.StaticDir = "/custom/static/path"
router := NewRouter(RouterConfig{ router := createTestRouter(cfg)
StaticDir: "/custom/static/path",
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/api", nil) request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -320,18 +287,9 @@ func TestRouterWithStaticDir(t *testing.T) {
} }
func TestRouterWithEmptyStaticDir(t *testing.T) { func TestRouterWithEmptyStaticDir(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
cfg.StaticDir = ""
router := NewRouter(RouterConfig{ router := createTestRouter(cfg)
StaticDir: "",
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/api", nil) request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -344,20 +302,11 @@ func TestRouterWithEmptyStaticDir(t *testing.T) {
} }
func TestRouterWithAllFeaturesDisabled(t *testing.T) { func TestRouterWithAllFeaturesDisabled(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
cfg.Debug = true
router := NewRouter(RouterConfig{ cfg.DisableCache = true
Debug: true, cfg.DisableCompression = true
DisableCache: true, router := createTestRouter(cfg)
DisableCompression: true,
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/api", nil) request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -370,15 +319,9 @@ func TestRouterWithAllFeaturesDisabled(t *testing.T) {
} }
func TestRouterWithoutAPIHandler(t *testing.T) { func TestRouterWithoutAPIHandler(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, _, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
router := NewRouter(RouterConfig{ cfg.APIHandler = nil
AuthHandler: authHandler, router := createTestRouter(cfg)
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/api", nil) request := httptest.NewRequest(http.MethodGet, "/api", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -391,17 +334,7 @@ func TestRouterWithoutAPIHandler(t *testing.T) {
} }
func TestRouterWithoutPageHandler(t *testing.T) { func TestRouterWithoutPageHandler(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() router := createTestRouter(createDefaultRouterConfig())
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/", nil) request := httptest.NewRequest(http.MethodGet, "/", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -414,17 +347,7 @@ func TestRouterWithoutPageHandler(t *testing.T) {
} }
func TestSwaggerRoute(t *testing.T) { func TestSwaggerRoute(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() router := createTestRouter(createDefaultRouterConfig())
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/swagger/", nil) request := httptest.NewRequest(http.MethodGet, "/swagger/", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -437,18 +360,9 @@ func TestSwaggerRoute(t *testing.T) {
} }
func TestStaticFileRoute(t *testing.T) { func TestStaticFileRoute(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() cfg := createDefaultRouterConfig()
cfg.StaticDir = "../../internal/static/"
router := NewRouter(RouterConfig{ router := createTestRouter(cfg)
StaticDir: "../../internal/static/",
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
request := httptest.NewRequest(http.MethodGet, "/static/css/main.css", nil) request := httptest.NewRequest(http.MethodGet, "/static/css/main.css", nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -461,17 +375,7 @@ func TestStaticFileRoute(t *testing.T) {
} }
func TestRouterConfiguration(t *testing.T) { func TestRouterConfiguration(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() router := createTestRouter(createDefaultRouterConfig())
router := NewRouter(RouterConfig{
APIHandler: apiHandler,
AuthHandler: authHandler,
PostHandler: postHandler,
VoteHandler: voteHandler,
UserHandler: userHandler,
AuthService: authService,
RateLimitConfig: defaultRateLimitConfig(),
})
if router == nil { if router == nil {
t.Error("Router should not be nil") t.Error("Router should not be nil")
@@ -487,29 +391,484 @@ func TestRouterConfiguration(t *testing.T) {
} }
} }
func TestRouterMiddlewareIntegration(t *testing.T) { func TestAllRoutesExist(t *testing.T) {
authHandler, postHandler, voteHandler, userHandler, apiHandler, authService := setupTestHandlers() router := createTestRouter(createDefaultRouterConfig())
router := NewRouter(RouterConfig{ publicRoutes := []struct {
APIHandler: apiHandler, method string
AuthHandler: authHandler, path string
PostHandler: postHandler, description string
VoteHandler: voteHandler, }{
UserHandler: userHandler, {http.MethodGet, "/api", "API info"},
AuthService: authService, {http.MethodGet, "/health", "Health check"},
RateLimitConfig: defaultRateLimitConfig(), {http.MethodGet, "/metrics", "Metrics"},
}) {http.MethodGet, "/robots.txt", "Robots.txt"},
{http.MethodGet, "/api/posts", "Get posts"},
if router == nil { {http.MethodGet, "/api/posts/search", "Search posts"},
t.Error("Router should not be nil") {http.MethodGet, "/api/posts/title", "Fetch title from URL"},
{http.MethodGet, "/api/posts/1", "Get post by ID"},
{http.MethodPost, "/api/auth/register", "Register"},
{http.MethodPost, "/api/auth/login", "Login"},
{http.MethodPost, "/api/auth/refresh", "Refresh token"},
{http.MethodGet, "/api/auth/confirm", "Confirm email"},
{http.MethodPost, "/api/auth/resend-verification", "Resend verification"},
{http.MethodPost, "/api/auth/forgot-password", "Forgot password"},
{http.MethodPost, "/api/auth/reset-password", "Reset password"},
{http.MethodPost, "/api/auth/account/confirm", "Confirm account deletion"},
} }
request := httptest.NewRequest(http.MethodGet, "/api", nil) protectedRoutes := []struct {
method string
path string
description string
}{
{http.MethodGet, "/api/auth/me", "Get current user"},
{http.MethodPost, "/api/auth/logout", "Logout"},
{http.MethodPost, "/api/auth/revoke", "Revoke token"},
{http.MethodPost, "/api/auth/revoke-all", "Revoke all tokens"},
{http.MethodPut, "/api/auth/email", "Update email"},
{http.MethodPut, "/api/auth/username", "Update username"},
{http.MethodPut, "/api/auth/password", "Update password"},
{http.MethodDelete, "/api/auth/account", "Delete account"},
{http.MethodPost, "/api/posts", "Create post"},
{http.MethodPut, "/api/posts/1", "Update post"},
{http.MethodDelete, "/api/posts/1", "Delete post"},
{http.MethodPost, "/api/posts/1/vote", "Cast vote"},
{http.MethodDelete, "/api/posts/1/vote", "Remove vote"},
{http.MethodGet, "/api/posts/1/vote", "Get user vote"},
{http.MethodGet, "/api/posts/1/votes", "Get post votes"},
{http.MethodGet, "/api/users", "Get users"},
{http.MethodPost, "/api/users", "Create user"},
{http.MethodGet, "/api/users/1", "Get user by ID"},
{http.MethodGet, "/api/users/1/posts", "Get user posts"},
}
for _, route := range publicRoutes {
t.Run(route.description+" "+route.method+" "+route.path, func(t *testing.T) {
invalidMethod := http.MethodPatch
switch route.method {
case http.MethodGet:
invalidMethod = http.MethodDelete
case http.MethodPost:
invalidMethod = http.MethodGet
}
request := httptest.NewRequest(invalidMethod, route.path, nil)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request) router.ServeHTTP(recorder, request)
if recorder.Code == 0 { routeExists := recorder.Code == http.StatusMethodNotAllowed
t.Error("Router should return a status code")
if !routeExists {
request = httptest.NewRequest(route.method, route.path, nil)
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code == http.StatusNotFound && route.path != "/api/posts/1" && route.path != "/robots.txt" {
t.Errorf("Route %s %s should exist, got 404", route.method, route.path)
}
}
})
}
for _, route := range protectedRoutes {
t.Run(route.description+" "+route.method+" "+route.path, func(t *testing.T) {
request := httptest.NewRequest(route.method, route.path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if recorder.Code == http.StatusNotFound {
t.Errorf("Route %s %s should exist, got 404", route.method, route.path)
}
if recorder.Code != http.StatusUnauthorized {
t.Errorf("Protected route %s %s should return 401 without auth, got %d", route.method, route.path, recorder.Code)
}
})
}
}
func TestRouteParameters(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
pathPattern string
testIDs []string
isProtected bool
}{
{
name: "Get post by ID",
method: http.MethodGet,
pathPattern: "/api/posts/{id}",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: false,
},
{
name: "Update post by ID",
method: http.MethodPut,
pathPattern: "/api/posts/{id}",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Delete post by ID",
method: http.MethodDelete,
pathPattern: "/api/posts/{id}",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Get user by ID",
method: http.MethodGet,
pathPattern: "/api/users/{id}",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
{
name: "Get user posts by user ID",
method: http.MethodGet,
pathPattern: "/api/users/{id}/posts",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
{
name: "Cast vote for post ID",
method: http.MethodPost,
pathPattern: "/api/posts/{id}/vote",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Remove vote for post ID",
method: http.MethodDelete,
pathPattern: "/api/posts/{id}/vote",
testIDs: []string{"1", "42", "999"},
isProtected: true,
},
{
name: "Get user vote for post ID",
method: http.MethodGet,
pathPattern: "/api/posts/{id}/vote",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
{
name: "Get post votes by post ID",
method: http.MethodGet,
pathPattern: "/api/posts/{id}/votes",
testIDs: []string{"1", "42", "999", "12345"},
isProtected: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, id := range tc.testIDs {
path := replaceID(tc.pathPattern, id)
t.Run("ID_"+id, func(t *testing.T) {
request := httptest.NewRequest(http.MethodPatch, path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
routeExists := recorder.Code == http.StatusMethodNotAllowed
request = httptest.NewRequest(tc.method, path, nil)
recorder = httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if !routeExists {
if recorder.Code == http.StatusNotFound {
t.Errorf("Route %s %s should exist with ID %s, got 404", tc.method, path, id)
return
}
}
if tc.isProtected {
if recorder.Code != http.StatusUnauthorized {
t.Errorf("Protected route %s %s should return 401 without auth, got %d", tc.method, path, recorder.Code)
}
}
})
}
})
}
}
func replaceID(pattern, id string) string {
return strings.Replace(pattern, "{id}", id, 1)
}
func TestInvalidRouteParameters(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
path string
expectedMin int
expectedMax int
isProtected bool
allow401 bool
}{
{
name: "Non-numeric post ID",
method: http.MethodGet,
path: "/api/posts/abc",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Negative post ID",
method: http.MethodGet,
path: "/api/posts/-1",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Zero post ID",
method: http.MethodGet,
path: "/api/posts/0",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusNotFound,
isProtected: false,
},
{
name: "Post ID with special characters",
method: http.MethodGet,
path: "/api/posts/123@456",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Post ID with encoded spaces",
method: http.MethodGet,
path: "/api/posts/12%2034",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusBadRequest,
isProtected: false,
},
{
name: "Non-numeric user ID",
method: http.MethodGet,
path: "/api/users/xyz",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusUnauthorized,
isProtected: true,
allow401: true,
},
{
name: "Negative user ID",
method: http.MethodGet,
path: "/api/users/-5",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusUnauthorized,
isProtected: true,
allow401: true,
},
{
name: "Non-numeric post ID in vote route",
method: http.MethodGet,
path: "/api/posts/invalid/vote",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusUnauthorized,
isProtected: true,
allow401: true,
},
{
name: "Very large post ID",
method: http.MethodGet,
path: "/api/posts/999999999999",
expectedMin: http.StatusBadRequest,
expectedMax: http.StatusNotFound,
isProtected: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
request := httptest.NewRequest(tc.method, tc.path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if tc.isProtected && tc.allow401 {
if recorder.Code != http.StatusUnauthorized && (recorder.Code < tc.expectedMin || recorder.Code > tc.expectedMax) {
t.Errorf("Protected route %s %s with invalid parameter should return 401 or status between %d and %d, got %d", tc.method, tc.path, tc.expectedMin, tc.expectedMax, recorder.Code)
}
} else {
if recorder.Code < tc.expectedMin || recorder.Code > tc.expectedMax {
t.Errorf("Route %s %s should return status between %d and %d, got %d", tc.method, tc.path, tc.expectedMin, tc.expectedMax, recorder.Code)
}
if recorder.Code != http.StatusNotFound && recorder.Code < 400 {
t.Errorf("Route %s %s with invalid parameter should return error status (4xx), got %d", tc.method, tc.path, recorder.Code)
}
}
})
}
}
func TestQueryParameters(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
path string
queryParams string
expectRoute bool
}{
{
name: "Get posts with limit and offset",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=10&offset=5",
expectRoute: true,
},
{
name: "Get posts with only limit",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=20",
expectRoute: true,
},
{
name: "Get posts with only offset",
method: http.MethodGet,
path: "/api/posts",
queryParams: "offset=10",
expectRoute: true,
},
{
name: "Search posts with query parameter",
method: http.MethodGet,
path: "/api/posts/search",
queryParams: "q=test",
expectRoute: true,
},
{
name: "Search posts with query, limit, and offset",
method: http.MethodGet,
path: "/api/posts/search",
queryParams: "q=test&limit=15&offset=3",
expectRoute: true,
},
{
name: "Fetch title with URL parameter",
method: http.MethodGet,
path: "/api/posts/title",
queryParams: "url=https://example.com",
expectRoute: true,
},
{
name: "Confirm email with token parameter",
method: http.MethodGet,
path: "/api/auth/confirm",
queryParams: "token=abc123",
expectRoute: true,
},
{
name: "Get posts with invalid limit",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=abc",
expectRoute: true,
},
{
name: "Get posts with negative limit",
method: http.MethodGet,
path: "/api/posts",
queryParams: "limit=-5",
expectRoute: true,
},
{
name: "Get posts with negative offset",
method: http.MethodGet,
path: "/api/posts",
queryParams: "offset=-10",
expectRoute: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
fullPath := tc.path
if tc.queryParams != "" {
fullPath += "?" + tc.queryParams
}
request := httptest.NewRequest(tc.method, fullPath, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
if tc.expectRoute {
if recorder.Code == http.StatusNotFound {
t.Errorf("Route %s %s should exist with query parameters, got 404", tc.method, fullPath)
}
}
})
}
}
func TestRouteConflicts(t *testing.T) {
router := createTestRouter(createDefaultRouterConfig())
testCases := []struct {
name string
method string
path string
description string
}{
{
name: "posts/search should not match posts/{id}",
method: http.MethodGet,
path: "/api/posts/search",
description: "search route should be matched, not treated as ID",
},
{
name: "posts/title should not match posts/{id}",
method: http.MethodGet,
path: "/api/posts/title",
description: "title route should be matched, not treated as ID",
},
{
name: "posts/{id} should work with numeric ID",
method: http.MethodGet,
path: "/api/posts/123",
description: "numeric ID should match {id} route",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
request := httptest.NewRequest(tc.method, tc.path, nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, request)
switch tc.path {
case "/api/posts/search":
if recorder.Code == http.StatusNotFound {
t.Errorf("%s: Route %s %s should exist (not 404), got %d", tc.description, tc.method, tc.path, recorder.Code)
}
case "/api/posts/title":
if recorder.Code == http.StatusNotFound {
t.Errorf("%s: Route %s %s should exist (not 404), got %d", tc.description, tc.method, tc.path, recorder.Code)
}
case "/api/posts/123":
if recorder.Code == http.StatusNotFound {
return
}
if recorder.Code < 400 {
t.Errorf("%s: Route %s %s should return 4xx or 5xx, got %d", tc.description, tc.method, tc.path, recorder.Code)
}
}
})
} }
} }

View File

@@ -1,12 +1,93 @@
package services package services
import ( import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt" "fmt"
"net/mail"
"strings"
"goyco/internal/config" "goyco/internal/config"
"goyco/internal/database" "goyco/internal/database"
) )
const (
defaultTokenExpirationHours = 24
verificationTokenBytes = 32
deletionTokenExpirationHours = 24
)
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrInvalidToken = errors.New("invalid or expired token")
ErrUsernameTaken = errors.New("username already exists")
ErrEmailTaken = errors.New("email already exists")
ErrInvalidEmail = errors.New("invalid email address")
ErrPasswordTooShort = errors.New("password too short")
ErrEmailNotVerified = errors.New("email not verified")
ErrAccountLocked = errors.New("account is locked")
ErrInvalidVerificationToken = errors.New("invalid verification token")
ErrEmailSenderUnavailable = errors.New("email sender not configured")
ErrDeletionEmailFailed = errors.New("account deletion email failed")
ErrInvalidDeletionToken = errors.New("invalid account deletion token")
ErrUserNotFound = errors.New("user not found")
ErrDeletionRequestNotFound = errors.New("deletion request not found")
)
type AuthResult struct {
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
RefreshToken string `json:"refresh_token" example:"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0"`
User *database.User `json:"user"`
}
type RegistrationResult struct {
User *database.User `json:"user"`
VerificationSent bool `json:"verification_sent"`
}
func normalizeEmail(email string) (string, error) {
trimmed := strings.TrimSpace(email)
if trimmed == "" {
return "", fmt.Errorf("email is required")
}
parsed, err := mail.ParseAddress(trimmed)
if err != nil {
return "", ErrInvalidEmail
}
return strings.ToLower(parsed.Address), nil
}
func generateVerificationToken() (string, string, error) {
buf := make([]byte, verificationTokenBytes)
if _, err := rand.Read(buf); err != nil {
return "", "", fmt.Errorf("generate verification token: %w", err)
}
token := hex.EncodeToString(buf)
hashed := HashVerificationToken(token)
return token, hashed, nil
}
func HashVerificationToken(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}
func sanitizeUser(user *database.User) *database.User {
if user == nil {
return nil
}
copy := *user
copy.Password = ""
copy.EmailVerificationToken = ""
return &copy
}
type AuthFacade struct { type AuthFacade struct {
registrationService *RegistrationService registrationService *RegistrationService
passwordResetService *PasswordResetService passwordResetService *PasswordResetService

View File

@@ -1,35 +0,0 @@
package services
import (
"errors"
"goyco/internal/database"
)
var (
ErrInvalidCredentials = errors.New("invalid credentials")
ErrInvalidToken = errors.New("invalid or expired token")
ErrUsernameTaken = errors.New("username already exists")
ErrEmailTaken = errors.New("email already exists")
ErrInvalidEmail = errors.New("invalid email address")
ErrPasswordTooShort = errors.New("password too short")
ErrEmailNotVerified = errors.New("email not verified")
ErrAccountLocked = errors.New("account is locked")
ErrInvalidVerificationToken = errors.New("invalid verification token")
ErrEmailSenderUnavailable = errors.New("email sender not configured")
ErrDeletionEmailFailed = errors.New("account deletion email failed")
ErrInvalidDeletionToken = errors.New("invalid account deletion token")
ErrUserNotFound = errors.New("user not found")
ErrDeletionRequestNotFound = errors.New("deletion request not found")
)
type AuthResult struct {
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
RefreshToken string `json:"refresh_token" example:"a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6q7r8s9t0"`
User *database.User `json:"user"`
}
type RegistrationResult struct {
User *database.User `json:"user"`
VerificationSent bool `json:"verification_sent"`
}

View File

@@ -1,59 +0,0 @@
package services
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/mail"
"strings"
"goyco/internal/database"
)
const (
defaultTokenExpirationHours = 24
verificationTokenBytes = 32
deletionTokenExpirationHours = 24
)
func normalizeEmail(email string) (string, error) {
trimmed := strings.TrimSpace(email)
if trimmed == "" {
return "", fmt.Errorf("email is required")
}
parsed, err := mail.ParseAddress(trimmed)
if err != nil {
return "", ErrInvalidEmail
}
return strings.ToLower(parsed.Address), nil
}
func generateVerificationToken() (string, string, error) {
buf := make([]byte, verificationTokenBytes)
if _, err := rand.Read(buf); err != nil {
return "", "", fmt.Errorf("generate verification token: %w", err)
}
token := hex.EncodeToString(buf)
hashed := HashVerificationToken(token)
return token, hashed, nil
}
func HashVerificationToken(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:])
}
func sanitizeUser(user *database.User) *database.User {
if user == nil {
return nil
}
copy := *user
copy.Password = ""
copy.EmailVerificationToken = ""
return &copy
}

View File

@@ -7,9 +7,10 @@ import (
"sync" "sync"
"testing" "testing"
"gorm.io/gorm"
"goyco/internal/database" "goyco/internal/database"
"goyco/internal/repositories" "goyco/internal/repositories"
"gorm.io/gorm"
) )
type mockVoteRepo struct { type mockVoteRepo struct {
@@ -240,6 +241,25 @@ func (m *mockVoteRepo) CountByUserID(userID uint) (int64, error) {
return count, nil return count, nil
} }
func (m *mockVoteRepo) GetVoteCountsByPostID(postID uint) (int, int, error) {
m.mu.RLock()
defer m.mu.RUnlock()
upVotes := 0
downVotes := 0
for _, vote := range m.votes {
if vote.PostID == postID {
switch vote.Type {
case database.VoteUp:
upVotes++
case database.VoteDown:
downVotes++
}
}
}
return upVotes, downVotes, nil
}
func (m *mockVoteRepo) WithTx(tx *gorm.DB) repositories.VoteRepository { func (m *mockVoteRepo) WithTx(tx *gorm.DB) repositories.VoteRepository {
return m return m
} }

View File

@@ -2,43 +2,70 @@ package templates
import ( import (
"html/template" "html/template"
"io/fs"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestTemplateParsing(t *testing.T) { func templateFuncMap() template.FuncMap {
templateDir := "./" return template.FuncMap{
"formatTime": func(t time.Time) string {
if t.IsZero() {
return ""
}
return t.Format("02 Jan 2006 15:04")
},
"truncate": func(s string, length int) string {
if len(s) <= length {
return s
}
if length <= 3 {
return s[:length]
}
return s[:length-3] + "..."
},
"substr": func(s string, start, length int) string {
if start >= len(s) {
return ""
}
end := min(start+length, len(s))
return s[start:end]
},
"upper": strings.ToUpper,
}
}
var templateFiles []string func TestTemplateParsing(t *testing.T) {
err := filepath.WalkDir(templateDir, func(path string, d fs.DirEntry, err error) error { layoutPath := filepath.Join(".", "base.gohtml")
if err != nil { require.FileExists(t, layoutPath, "base layout is required for all templates")
return err
} partials, err := filepath.Glob(filepath.Join(".", "partials", "*.gohtml"))
if !d.IsDir() && filepath.Ext(path) == ".gohtml" {
templateFiles = append(templateFiles, path)
}
return nil
})
require.NoError(t, err) require.NoError(t, err)
tmpl := template.New("test") pages, err := filepath.Glob(filepath.Join(".", "*.gohtml"))
require.NoError(t, err)
require.NotEmpty(t, pages, "no page templates found")
tmpl = tmpl.Funcs(template.FuncMap{ for _, page := range pages {
"formatTime": func(any) string { return "2024-01-01" }, if filepath.Base(page) == "base.gohtml" {
"eq": func(a, b any) bool { return a == b }, continue
"ne": func(a, b any) bool { return a != b }, }
"len": func(s any) int { return 0 },
"range": func(s any) any { return s },
})
for _, file := range templateFiles { page := page
t.Run(file, func(t *testing.T) { t.Run(filepath.Base(page), func(t *testing.T) {
_, err := tmpl.ParseFiles(file) t.Parallel()
assert.NoError(t, err, "Template %s should parse without errors", file)
files := append([]string{layoutPath}, partials...)
files = append(files, page)
tmpl, err := template.New(filepath.Base(page)).Funcs(templateFuncMap()).ParseFiles(files...)
require.NoError(t, err)
require.NotNil(t, tmpl.Lookup("layout"), "layout template should be available")
require.NotNil(t, tmpl.Lookup("content"), "content block should be defined by page templates")
}) })
} }
} }

View File

@@ -422,6 +422,24 @@ func (m *MockUserRepository) GetDeletedUsers() ([]database.User, error) {
return []database.User{}, nil return []database.User{}, nil
} }
func (m *MockUserRepository) GetByUsernamePrefix(prefix string) (*database.User, error) {
m.mu.RLock()
defer m.mu.RUnlock()
for _, user := range m.users {
if len(user.Username) >= len(prefix) && user.Username[:len(prefix)] == prefix {
if len(user.Email) >= 13 && strings.HasSuffix(user.Email, "@goyco.local") {
emailPrefix := user.Email[:len(user.Email)-13]
if len(emailPrefix) >= len(prefix) && emailPrefix[:len(prefix)] == prefix {
return user, nil
}
}
}
}
return nil, gorm.ErrRecordNotFound
}
func (m *MockUserRepository) HardDeleteAll() (int64, error) { func (m *MockUserRepository) HardDeleteAll() (int64, error) {
if m.HardDeleteAllFunc != nil { if m.HardDeleteAllFunc != nil {
return m.HardDeleteAllFunc() return m.HardDeleteAllFunc()
@@ -965,6 +983,25 @@ func (m *MockVoteRepository) CountByUserID(userID uint) (int64, error) {
return count, nil return count, nil
} }
func (m *MockVoteRepository) GetVoteCountsByPostID(postID uint) (int, int, error) {
m.mu.RLock()
defer m.mu.RUnlock()
upVotes := 0
downVotes := 0
for _, vote := range m.votes {
if vote.PostID == postID {
switch vote.Type {
case database.VoteUp:
upVotes++
case database.VoteDown:
downVotes++
}
}
}
return upVotes, downVotes, nil
}
func (m *MockVoteRepository) Count() (int64, error) { func (m *MockVoteRepository) Count() (int64, error) {
m.mu.RLock() m.mu.RLock()
defer m.mu.RUnlock() defer m.mu.RUnlock()

View File

@@ -153,6 +153,7 @@ type UserRepositoryStub struct {
UnlockFn func(uint) error UnlockFn func(uint) error
GetPostsFn func(uint, int, int) ([]database.Post, error) GetPostsFn func(uint, int, int) ([]database.Post, error)
GetDeletedUsersFn func() ([]database.User, error) GetDeletedUsersFn func() ([]database.User, error)
GetByUsernamePrefixFn func(string) (*database.User, error)
HardDeleteAllFn func() (int64, error) HardDeleteAllFn func() (int64, error)
CountFn func() (int64, error) CountFn func() (int64, error)
WithTxFn func(*gorm.DB) repositories.UserRepository WithTxFn func(*gorm.DB) repositories.UserRepository
@@ -281,6 +282,13 @@ func (s *UserRepositoryStub) GetDeletedUsers() ([]database.User, error) {
return nil, nil return nil, nil
} }
func (s *UserRepositoryStub) GetByUsernamePrefix(prefix string) (*database.User, error) {
if s != nil && s.GetByUsernamePrefixFn != nil {
return s.GetByUsernamePrefixFn(prefix)
}
return nil, gorm.ErrRecordNotFound
}
func (s *UserRepositoryStub) HardDeleteAll() (int64, error) { func (s *UserRepositoryStub) HardDeleteAll() (int64, error) {
if s != nil && s.HardDeleteAllFn != nil { if s != nil && s.HardDeleteAllFn != nil {
return s.HardDeleteAllFn() return s.HardDeleteAllFn()

View File

@@ -30,7 +30,7 @@ if [ ! -d "docs" ]; then
mkdir -p docs mkdir -p docs
fi fi
SWAGGER_DIRECTORIES="cmd/goyco,internal/handlers" SWAGGER_DIRECTORIES="cmd/goyco,internal/handlers,internal/dto"
SWAGGER_MAIN_FILE="main.go" SWAGGER_MAIN_FILE="main.go"
SWAGGER_OUTPUT_DIR="docs" SWAGGER_OUTPUT_DIR="docs"

View File

@@ -44,5 +44,3 @@ GRANT ALL PRIVILEGES ON DATABASE goyco TO goyco;
EOF EOF
echo "PostgreSQL 18 installed, database 'goyco' and user 'goyco' set up." echo "PostgreSQL 18 installed, database 'goyco' and user 'goyco' set up."